Validation loss goes up and down [variational inference]


I am training an MLP with variational inference (TensorFlow Probability Flipout layers) for a regression task on a small dataset with one feature. The network works and the training loss goes down, but the validation loss has random spikes, and I don't understand how to avoid them.

import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

tfd = tfp.distributions  # alias used by normal_sp (missing in the original)


def NLL(y, distr):
    # Assumed definition (not shown in the original): negative log-likelihood
    # of the targets under the distribution returned by the model
    return -distr.log_prob(y)


def create_flipout_bnn_model(train_size):
    def normal_sp(params):
        # Column 0 is the mean; column 1 goes through softplus to give a
        # positive scale, with a floor of 1e-3
        return tfd.Normal(loc=params[:, 0:1],
                          scale=1e-3 + tf.math.softplus(0.05 * params[:, 1:2]))

    # KL(posterior || prior) for each kernel, divided by the number of training
    # examples so it is on the same per-example scale as the batch-averaged NLL
    kernel_divergence_fn = lambda q, p, _: tfd.kl_divergence(q, p) / train_size

    inputs = Input(shape=(1,), name="input_layer")  # no spaces in layer names

    hidden = tfp.layers.DenseFlipout(30, kernel_divergence_fn=kernel_divergence_fn,
                                     activation="relu", name="DenseFlipout_layer_1")(inputs)
    hidden = tfp.layers.DenseFlipout(30, kernel_divergence_fn=kernel_divergence_fn,
                                     activation="relu", name="DenseFlipout_layer_2")(hidden)
    hidden = tfp.layers.DenseFlipout(30, kernel_divergence_fn=kernel_divergence_fn,
                                     activation="relu", name="DenseFlipout_layer_3")(hidden)
    # Two outputs per example: the mean and the (pre-softplus) scale
    params = tfp.layers.DenseFlipout(2, kernel_divergence_fn=kernel_divergence_fn,
                                     name="DenseFlipout_layer_4")(hidden)
    dist = tfp.layers.DistributionLambda(normal_sp, name="normal_sp")(params)

    return Model(inputs=inputs, outputs=dist)


# X_train, y_train, X_val, y_val are assumed to be loaded elsewhere
train_size = len(X_train)
batch_size = train_size  # full-batch training: one gradient step per epoch

callback = tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=1830,
                                            restore_best_weights=True)
flipout_BNN = create_flipout_bnn_model(train_size=train_size)
flipout_BNN.compile(optimizer=Adam(learning_rate=0.002), jit_compile=True,
                    loss=NLL, metrics=[tf.keras.metrics.RootMeanSquaredError()])
flipout_BNN.summary()
history_flipout_BNN = flipout_BNN.fit(X_train, y_train, epochs=30000, verbose=0,
                                      batch_size=batch_size,
                                      validation_data=(X_val, y_val),
                                      callbacks=[callback])
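One way to see how a single validation point can produce a spike with this loss: the Gaussian NLL divides the squared residual by 2σ², so if the network predicts a small scale for a point it then misses, that point alone can dominate val_loss. Below is a minimal pure-Python sketch; the helper names are hypothetical, and it only mirrors the scale mapping of normal_sp rather than calling TFP.

```python
import math


def softplus(x):
    # Numerically stable log(1 + exp(x))
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def normal_sp_scale(raw):
    # Same mapping as normal_sp in the question: 1e-3 + softplus(0.05 * raw)
    return 1e-3 + softplus(0.05 * raw)


def gaussian_nll(y, mu, sigma):
    # -log N(y | mu, sigma): a log-scale term plus squared residual / (2 sigma^2)
    return 0.5 * math.log(2 * math.pi * sigma ** 2) + (y - mu) ** 2 / (2 * sigma ** 2)


# The same residual (y - mu = 1.0) under a wide, a narrow, and a floor-level scale
for sigma in (1.0, 0.1, normal_sp_scale(-400.0)):
    print(f"sigma={sigma:.4g}  nll={gaussian_nll(1.0, 0.0, sigma):.4g}")
```

The 1e-3 floor in normal_sp only caps this term at roughly residual²/(2·10⁻⁶), so on a small validation set a few such points could plausibly explain the spikes; that is an assumption about this data, not something the question confirms.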

The result of training almost always looks like this:

[image: loss curves]

How can I avoid this issue?

Answer by AudioBubble

The fluctuating val_loss could be caused by outliers in the dataset. To mitigate this, normalize the dataset and use loss=tf.keras.losses.MeanAbsoluteError(), which is less sensitive to outliers than a squared-error loss. I tried replicating the above code using one of the datasets from this page.
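The two suggestions above can be sketched in plain Python with hypothetical helper names (this is not a Keras API): standardize using statistics from the training split only, then compare how one outlying target moves MAE versus MSE.

```python
import statistics


def fit_standardizer(train_values):
    # Mean and population standard deviation from the TRAINING split only
    mean = statistics.fmean(train_values)
    std = statistics.pstdev(train_values)
    return mean, (std if std > 0 else 1.0)  # guard against a constant feature


def standardize(values, mean, std):
    # Apply the training-set statistics to any split (train, validation, test)
    return [(v - mean) / std for v in values]


def mae(y_true, y_pred):
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


def mse(y_true, y_pred):
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


x_train = [1.0, 2.0, 3.0, 4.0]
x_val = [2.5, 5.0]
mean, std = fit_standardizer(x_train)
print(standardize(x_train, mean, std), standardize(x_val, mean, std))

# One outlying target: MAE grows linearly with the residual, MSE quadratically
preds = [0.0, 0.0, 0.0, 0.0]
clean = [1.0, 1.0, 1.0, 1.0]
with_outlier = [1.0, 1.0, 1.0, 11.0]
print(mae(clean, preds), mse(clean, preds))                # 1.0 1.0
print(mae(with_outlier, preds), mse(with_outlier, preds))  # 3.5 31.0
```

In Keras, tf.keras.layers.Normalization with adapt() called on the training features is a built-in way to do the same standardization. Reusing the training statistics for the validation split avoids leaking validation data into preprocessing.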

Please have a look at the output (gist attached for reference):

history_flipout_BNN = flipout_BNN.fit(train_features,train_labels, epochs=500, verbose=2, batch_size=batch_size,validation_data=(test_features,test_labels),callbacks=[callback] )

Output:

Epoch 401/500
10/10 - 0s - loss: 26.9614 - mean_absolute_error: 5.7971 - val_loss: 25.3911 - val_mean_absolute_error: 4.2861 - 99ms/epoch - 10ms/step
Epoch 402/500
10/10 - 0s - loss: 26.0816 - mean_absolute_error: 5.0223 - val_loss: 25.5811 - val_mean_absolute_error: 4.5798 - 85ms/epoch - 8ms/step
Epoch 403/500
10/10 - 0s - loss: 25.8087 - mean_absolute_error: 4.8529 - val_loss: 26.4341 - val_mean_absolute_error: 5.5357 - 90ms/epoch - 9ms/step
Epoch 404/500
10/10 - 0s - loss: 25.3384 - mean_absolute_error: 4.4862 - val_loss: 25.0166 - val_mean_absolute_error: 4.2234 - 92ms/epoch - 9ms/step
Epoch 405/500
10/10 - 0s - loss: 25.5521 - mean_absolute_error: 4.8063 - val_loss: 25.1751 - val_mean_absolute_error: 4.4889 - 103ms/epoch - 10ms/step
Epoch 406/500
10/10 - 0s - loss: 25.2258 - mean_absolute_error: 4.5854 - val_loss: 25.0371 - val_mean_absolute_error: 4.4541 - 89ms/epoch - 9ms/step
Epoch 407/500
10/10 - 0s - loss: 25.1053 - mean_absolute_error: 4.5674 - val_loss: 25.2423 - val_mean_absolute_error: 4.7611 - 74ms/epoch - 7ms/step
Epoch 408/500
10/10 - 0s - loss: 25.4199 - mean_absolute_error: 4.9825 - val_loss: 25.0862 - val_mean_absolute_error: 4.7047 - 73ms/epoch - 7ms/step
Epoch 409/500
10/10 - 0s - loss: 25.2182 - mean_absolute_error: 4.8814 - val_loss: 25.0651 - val_mean_absolute_error: 4.7841 - 75ms/epoch - 8ms/step
Epoch 410/500
10/10 - 0s - loss: 24.6748 - mean_absolute_error: 4.4359 - val_loss: 24.5389 - val_mean_absolute_error: 4.3531 - 93ms/epoch - 9ms/step
Epoch 411/500
10/10 - 0s - loss: 25.5475 - mean_absolute_error: 5.4029 - val_loss: 24.3203 - val_mean_absolute_error: 4.2277 - 98ms/epoch - 10ms/step
Epoch 412/500
10/10 - 0s - loss: 24.5997 - mean_absolute_error: 4.5493 - val_loss: 24.5846 - val_mean_absolute_error: 4.5878 - 86ms/epoch - 9ms/step
Epoch 413/500
10/10 - 0s - loss: 24.7902 - mean_absolute_error: 4.8364 - val_loss: 24.7510 - val_mean_absolute_error: 4.8514 - 73ms/epoch - 7ms/step
Epoch 414/500
10/10 - 0s - loss: 24.5510 - mean_absolute_error: 4.6942 - val_loss: 24.8243 - val_mean_absolute_error: 5.0216 - 74ms/epoch - 7ms/step
Epoch 415/500
10/10 - 0s - loss: 25.2334 - mean_absolute_error: 5.4727 - val_loss: 24.1006 - val_mean_absolute_error: 4.3917 - 93ms/epoch - 9ms/step
Epoch 416/500
10/10 - 0s - loss: 25.1920 - mean_absolute_error: 5.5221 - val_loss: 24.9322 - val_mean_absolute_error: 5.3105 - 71ms/epoch - 7ms/step
Epoch 417/500
10/10 - 0s - loss: 24.0554 - mean_absolute_error: 4.4716 - val_loss: 23.5270 - val_mean_absolute_error: 3.9934 - 91ms/epoch - 9ms/step
Epoch 418/500
10/10 - 0s - loss: 24.3465 - mean_absolute_error: 4.8528 - val_loss: 24.0905 - val_mean_absolute_error: 4.6466 - 68ms/epoch - 7ms/step
Epoch 419/500
10/10 - 0s - loss: 23.8700 - mean_absolute_error: 4.4655 - val_loss: 25.7271 - val_mean_absolute_error: 6.3734 - 91ms/epoch - 9ms/step
Epoch 420/500
10/10 - 0s - loss: 23.9039 - mean_absolute_error: 4.5902 - val_loss: 23.9026 - val_mean_absolute_error: 4.6395 - 73ms/epoch - 7ms/step
Epoch 421/500
10/10 - 0s - loss: 23.8635 - mean_absolute_error: 4.6409 - val_loss: 23.6810 - val_mean_absolute_error: 4.5092 - 80ms/epoch - 8ms/step
Epoch 422/500
10/10 - 0s - loss: 23.6667 - mean_absolute_error: 4.5341 - val_loss: 23.3594 - val_mean_absolute_error: 4.2762 - 95ms/epoch - 9ms/step
Epoch 423/500
10/10 - 0s - loss: 23.7698 - mean_absolute_error: 4.7262 - val_loss: 23.7548 - val_mean_absolute_error: 4.7614 - 91ms/epoch - 9ms/step
Epoch 424/500
10/10 - 0s - loss: 23.4459 - mean_absolute_error: 4.4925 - val_loss: 23.1056 - val_mean_absolute_error: 4.2023 - 106ms/epoch - 11ms/step
Epoch 425/500
10/10 - 0s - loss: 23.0812 - mean_absolute_error: 4.2165 - val_loss: 25.4889 - val_mean_absolute_error: 6.6729 - 86ms/epoch - 9ms/step
Epoch 426/500
10/10 - 0s - loss: 23.5379 - mean_absolute_error: 4.7600 - val_loss: 23.3109 - val_mean_absolute_error: 4.5815 - 81ms/epoch - 8ms/step
Epoch 427/500
10/10 - 0s - loss: 23.0973 - mean_absolute_error: 4.4067 - val_loss: 23.1382 - val_mean_absolute_error: 4.4958 - 86ms/epoch - 9ms/step
Epoch 428/500
10/10 - 0s - loss: 23.1789 - mean_absolute_error: 4.5741 - val_loss: 22.8976 - val_mean_absolute_error: 4.3397 - 95ms/epoch - 10ms/step
Epoch 429/500
10/10 - 0s - loss: 23.5457 - mean_absolute_error: 5.0253 - val_loss: 23.5527 - val_mean_absolute_error: 5.0789 - 72ms/epoch - 7ms/step
Epoch 430/500
10/10 - 0s - loss: 23.3083 - mean_absolute_error: 4.8688 - val_loss: 23.1191 - val_mean_absolute_error: 4.7235 - 66ms/epoch - 7ms/step
Epoch 431/500
10/10 - 0s - loss: 22.9547 - mean_absolute_error: 4.5943 - val_loss: 23.0954 - val_mean_absolute_error: 4.7804 - 79ms/epoch - 8ms/step
Epoch 432/500
10/10 - 0s - loss: 23.2755 - mean_absolute_error: 4.9968 - val_loss: 22.9616 - val_mean_absolute_error: 4.7283 - 88ms/epoch - 9ms/step
Epoch 433/500
10/10 - 0s - loss: 23.5949 - mean_absolute_error: 5.3955 - val_loss: 22.3164 - val_mean_absolute_error: 4.1575 - 90ms/epoch - 9ms/step
Epoch 434/500
10/10 - 0s - loss: 22.9300 - mean_absolute_error: 4.8030 - val_loss: 22.6514 - val_mean_absolute_error: 4.5655 - 76ms/epoch - 8ms/step
Epoch 435/500
10/10 - 0s - loss: 22.8438 - mean_absolute_error: 4.7912 - val_loss: 22.4905 - val_mean_absolute_error: 4.4806 - 89ms/epoch - 9ms/step
Epoch 436/500
10/10 - 0s - loss: 22.4169 - mean_absolute_error: 4.4417 - val_loss: 22.8025 - val_mean_absolute_error: 4.8712 - 83ms/epoch - 8ms/step
Epoch 437/500
10/10 - 0s - loss: 22.5422 - mean_absolute_error: 4.6457 - val_loss: 24.9054 - val_mean_absolute_error: 7.0521 - 69ms/epoch - 7ms/step
Epoch 438/500
10/10 - 0s - loss: 22.5137 - mean_absolute_error: 4.6931 - val_loss: 21.7106 - val_mean_absolute_error: 3.9314 - 85ms/epoch - 9ms/step
Epoch 439/500
10/10 - 0s - loss: 22.3283 - mean_absolute_error: 4.5823 - val_loss: 22.6495 - val_mean_absolute_error: 4.9456 - 99ms/epoch - 10ms/step
Epoch 440/500
10/10 - 0s - loss: 22.0736 - mean_absolute_error: 4.4034 - val_loss: 22.3924 - val_mean_absolute_error: 4.7656 - 70ms/epoch - 7ms/step
Epoch 441/500
10/10 - 0s - loss: 22.3100 - mean_absolute_error: 4.7172 - val_loss: 22.6311 - val_mean_absolute_error: 5.0808 - 85ms/epoch - 8ms/step
Epoch 442/500
10/10 - 0s - loss: 23.0256 - mean_absolute_error: 5.5076 - val_loss: 22.9733 - val_mean_absolute_error: 5.4954 - 75ms/epoch - 7ms/step
Epoch 443/500
10/10 - 0s - loss: 21.6798 - mean_absolute_error: 4.2348 - val_loss: 21.6242 - val_mean_absolute_error: 4.2211 - 91ms/epoch - 9ms/step
Epoch 444/500
10/10 - 0s - loss: 21.9015 - mean_absolute_error: 4.5319 - val_loss: 22.2187 - val_mean_absolute_error: 4.8911 - 83ms/epoch - 8ms/step
Epoch 445/500
10/10 - 0s - loss: 21.8007 - mean_absolute_error: 4.5058 - val_loss: 21.4597 - val_mean_absolute_error: 4.2052 - 109ms/epoch - 11ms/step
Epoch 446/500
10/10 - 0s - loss: 21.7802 - mean_absolute_error: 4.5573 - val_loss: 21.0135 - val_mean_absolute_error: 3.8313 - 91ms/epoch - 9ms/step
Epoch 447/500
10/10 - 0s - loss: 21.8448 - mean_absolute_error: 4.6948 - val_loss: 21.4467 - val_mean_absolute_error: 4.3360 - 71ms/epoch - 7ms/step
Epoch 448/500
10/10 - 0s - loss: 21.4740 - mean_absolute_error: 4.3944 - val_loss: 21.8361 - val_mean_absolute_error: 4.7958 - 87ms/epoch - 9ms/step
Epoch 449/500
10/10 - 0s - loss: 21.3762 - mean_absolute_error: 4.3665 - val_loss: 21.1680 - val_mean_absolute_error: 4.1976 - 97ms/epoch - 10ms/step
Epoch 450/500
10/10 - 0s - loss: 21.1940 - mean_absolute_error: 4.2551 - val_loss: 21.6621 - val_mean_absolute_error: 4.7631 - 83ms/epoch - 8ms/step
Epoch 451/500
10/10 - 0s - loss: 21.2462 - mean_absolute_error: 4.3787 - val_loss: 21.3224 - val_mean_absolute_error: 4.4936 - 74ms/epoch - 7ms/step
Epoch 452/500
10/10 - 0s - loss: 21.4716 - mean_absolute_error: 4.6725 - val_loss: 21.3193 - val_mean_absolute_error: 4.5576 - 84ms/epoch - 8ms/step
Epoch 453/500
10/10 - 0s - loss: 21.1026 - mean_absolute_error: 4.3706 - val_loss: 21.8985 - val_mean_absolute_error: 5.2052 - 72ms/epoch - 7ms/step
Epoch 454/500
10/10 - 0s - loss: 20.9130 - mean_absolute_error: 4.2509 - val_loss: 20.8512 - val_mean_absolute_error: 4.2281 - 95ms/epoch - 9ms/step
Epoch 455/500
10/10 - 0s - loss: 20.7118 - mean_absolute_error: 4.1189 - val_loss: 24.0372 - val_mean_absolute_error: 7.4822 - 105ms/epoch - 10ms/step
Epoch 456/500
10/10 - 0s - loss: 20.9174 - mean_absolute_error: 4.3919 - val_loss: 20.5719 - val_mean_absolute_error: 4.0837 - 99ms/epoch - 10ms/step
Epoch 457/500
10/10 - 0s - loss: 21.0356 - mean_absolute_error: 4.5764 - val_loss: 21.6790 - val_mean_absolute_error: 5.2558 - 73ms/epoch - 7ms/step
Epoch 458/500
10/10 - 0s - loss: 21.0807 - mean_absolute_error: 4.6846 - val_loss: 20.1755 - val_mean_absolute_error: 3.8142 - 93ms/epoch - 9ms/step
Epoch 459/500
10/10 - 0s - loss: 20.3427 - mean_absolute_error: 4.0099 - val_loss: 45.3344 - val_mean_absolute_error: 29.0383 - 70ms/epoch - 7ms/step
Epoch 460/500
10/10 - 0s - loss: 20.8348 - mean_absolute_error: 4.5675 - val_loss: 20.6027 - val_mean_absolute_error: 4.3719 - 69ms/epoch - 7ms/step
Epoch 461/500
10/10 - 0s - loss: 20.4903 - mean_absolute_error: 4.2877 - val_loss: 20.9089 - val_mean_absolute_error: 4.7427 - 69ms/epoch - 7ms/step
Epoch 462/500
10/10 - 0s - loss: 20.3484 - mean_absolute_error: 4.2115 - val_loss: 20.4108 - val_mean_absolute_error: 4.3104 - 86ms/epoch - 9ms/step
Epoch 463/500
10/10 - 0s - loss: 22.3694 - mean_absolute_error: 6.2942 - val_loss: 20.5144 - val_mean_absolute_error: 4.4713 - 71ms/epoch - 7ms/step
Epoch 464/500
10/10 - 0s - loss: 20.8265 - mean_absolute_error: 4.8048 - val_loss: 20.5560 - val_mean_absolute_error: 4.5620 - 69ms/epoch - 7ms/step
Epoch 465/500
10/10 - 0s - loss: 20.3374 - mean_absolute_error: 4.3669 - val_loss: 20.2208 - val_mean_absolute_error: 4.2813 - 77ms/epoch - 8ms/step
Epoch 466/500
10/10 - 0s - loss: 20.5049 - mean_absolute_error: 4.5907 - val_loss: 19.7859 - val_mean_absolute_error: 3.9031 - 96ms/epoch - 10ms/step
Epoch 467/500
10/10 - 0s - loss: 20.2043 - mean_absolute_error: 4.3468 - val_loss: 19.7356 - val_mean_absolute_error: 3.9101 - 123ms/epoch - 12ms/step
Epoch 468/500
10/10 - 0s - loss: 19.8643 - mean_absolute_error: 4.0647 - val_loss: 22.0424 - val_mean_absolute_error: 6.2759 - 73ms/epoch - 7ms/step
Epoch 469/500
10/10 - 0s - loss: 20.0671 - mean_absolute_error: 4.3277 - val_loss: 20.1854 - val_mean_absolute_error: 4.4798 - 77ms/epoch - 8ms/step
Epoch 470/500
10/10 - 0s - loss: 20.0209 - mean_absolute_error: 4.3407 - val_loss: 20.3919 - val_mean_absolute_error: 4.7435 - 89ms/epoch - 9ms/step
Epoch 471/500
10/10 - 0s - loss: 20.0756 - mean_absolute_error: 4.4520 - val_loss: 19.9796 - val_mean_absolute_error: 4.3877 - 76ms/epoch - 8ms/step
Epoch 472/500
10/10 - 0s - loss: 20.0424 - mean_absolute_error: 4.4758 - val_loss: 19.5151 - val_mean_absolute_error: 3.9802 - 118ms/epoch - 12ms/step
Epoch 473/500
10/10 - 0s - loss: 19.6751 - mean_absolute_error: 4.1645 - val_loss: 19.7584 - val_mean_absolute_error: 4.2778 - 73ms/epoch - 7ms/step
Epoch 474/500
10/10 - 0s - loss: 20.1219 - mean_absolute_error: 4.6639 - val_loss: 19.6155 - val_mean_absolute_error: 4.1857 - 70ms/epoch - 7ms/step
Epoch 475/500
10/10 - 0s - loss: 19.7404 - mean_absolute_error: 4.3327 - val_loss: 20.8408 - val_mean_absolute_error: 5.4616 - 76ms/epoch - 8ms/step
Epoch 476/500
10/10 - 0s - loss: 19.8373 - mean_absolute_error: 4.4821 - val_loss: 19.5820 - val_mean_absolute_error: 4.2573 - 70ms/epoch - 7ms/step
Epoch 477/500
10/10 - 0s - loss: 19.4083 - mean_absolute_error: 4.1082 - val_loss: 19.3071 - val_mean_absolute_error: 4.0379 - 88ms/epoch - 9ms/step
Epoch 478/500
10/10 - 0s - loss: 19.4480 - mean_absolute_error: 4.2032 - val_loss: 19.2325 - val_mean_absolute_error: 4.0184 - 120ms/epoch - 12ms/step
Epoch 479/500
10/10 - 0s - loss: 19.6966 - mean_absolute_error: 4.5051 - val_loss: 18.7808 - val_mean_absolute_error: 3.6175 - 91ms/epoch - 9ms/step
Epoch 480/500
10/10 - 0s - loss: 19.0331 - mean_absolute_error: 3.8921 - val_loss: 19.1469 - val_mean_absolute_error: 4.0343 - 72ms/epoch - 7ms/step
Epoch 481/500
10/10 - 0s - loss: 19.5920 - mean_absolute_error: 4.5013 - val_loss: 19.6969 - val_mean_absolute_error: 4.6326 - 69ms/epoch - 7ms/step
Epoch 482/500
10/10 - 0s - loss: 19.1715 - mean_absolute_error: 4.1272 - val_loss: 19.1224 - val_mean_absolute_error: 4.1040 - 74ms/epoch - 7ms/step
Epoch 483/500
10/10 - 0s - loss: 19.8812 - mean_absolute_error: 4.8832 - val_loss: 18.6707 - val_mean_absolute_error: 3.6980 - 95ms/epoch - 10ms/step
Epoch 484/500
10/10 - 0s - loss: 19.1581 - mean_absolute_error: 4.2063 - val_loss: 19.3597 - val_mean_absolute_error: 4.4352 - 78ms/epoch - 8ms/step
Epoch 485/500
10/10 - 0s - loss: 19.0230 - mean_absolute_error: 4.1206 - val_loss: 18.4774 - val_mean_absolute_error: 3.6028 - 92ms/epoch - 9ms/step
Epoch 486/500
10/10 - 0s - loss: 18.8798 - mean_absolute_error: 4.0279 - val_loss: 20.5281 - val_mean_absolute_error: 5.7055 - 87ms/epoch - 9ms/step
Epoch 487/500
10/10 - 0s - loss: 19.0226 - mean_absolute_error: 4.2233 - val_loss: 19.2094 - val_mean_absolute_error: 4.4396 - 72ms/epoch - 7ms/step
Epoch 488/500
10/10 - 0s - loss: 18.9576 - mean_absolute_error: 4.2107 - val_loss: 19.1106 - val_mean_absolute_error: 4.3923 - 77ms/epoch - 8ms/step
Epoch 489/500
10/10 - 0s - loss: 19.3164 - mean_absolute_error: 4.6182 - val_loss: 20.5644 - val_mean_absolute_error: 5.8904 - 85ms/epoch - 8ms/step
Epoch 490/500
10/10 - 0s - loss: 18.5034 - mean_absolute_error: 3.8501 - val_loss: 19.6553 - val_mean_absolute_error: 5.0289 - 84ms/epoch - 8ms/step
Epoch 491/500
10/10 - 0s - loss: 18.9104 - mean_absolute_error: 4.3049 - val_loss: 18.4100 - val_mean_absolute_error: 3.8302 - 89ms/epoch - 9ms/step
Epoch 492/500
10/10 - 0s - loss: 19.3977 - mean_absolute_error: 4.8379 - val_loss: 19.0681 - val_mean_absolute_error: 4.5329 - 89ms/epoch - 9ms/step
Epoch 493/500
10/10 - 0s - loss: 18.5063 - mean_absolute_error: 3.9908 - val_loss: 18.5971 - val_mean_absolute_error: 4.1069 - 72ms/epoch - 7ms/step
Epoch 494/500
10/10 - 0s - loss: 23.5121 - mean_absolute_error: 9.0426 - val_loss: 18.0131 - val_mean_absolute_error: 3.5696 - 92ms/epoch - 9ms/step
Epoch 495/500
10/10 - 0s - loss: 18.6455 - mean_absolute_error: 4.2199 - val_loss: 18.5258 - val_mean_absolute_error: 4.1229 - 78ms/epoch - 8ms/step
Epoch 496/500
10/10 - 0s - loss: 18.8162 - mean_absolute_error: 4.4317 - val_loss: 18.3150 - val_mean_absolute_error: 3.9539 - 75ms/epoch - 8ms/step
Epoch 497/500
10/10 - 0s - loss: 18.5742 - mean_absolute_error: 4.2320 - val_loss: 18.9121 - val_mean_absolute_error: 4.5930 - 69ms/epoch - 7ms/step
Epoch 498/500
10/10 - 0s - loss: 18.7194 - mean_absolute_error: 4.4183 - val_loss: 18.6576 - val_mean_absolute_error: 4.3785 - 74ms/epoch - 7ms/step
Epoch 499/500
10/10 - 0s - loss: 18.4750 - mean_absolute_error: 4.2142 - val_loss: 17.9634 - val_mean_absolute_error: 3.7261 - 108ms/epoch - 11ms/step
Epoch 500/500
10/10 - 0s - loss: 17.9538 - mean_absolute_error: 3.7362 - val_loss: 17.6494 - val_mean_absolute_error: 3.4570 - 111ms/epoch - 11ms/step