I was training an MLP through variational inference for a regression task on a small dataset with one feature. The network works and the training loss goes down, but the validation loss has random spikes and I do not understand how to avoid them.
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

tfd = tfp.distributions

# Negative log-likelihood of the targets under the predicted distribution.
NLL = lambda y, p_y: -p_y.log_prob(y)

def create_flipout_bnn_model(train_size):
    # Map the two network outputs to a Normal: mean and softplus-scaled std.
    def normal_sp(params):
        return tfd.Normal(loc=params[:, 0:1],
                          scale=1e-3 + tf.math.softplus(0.05 * params[:, 1:2]))

    # Scale each layer's KL term by the training-set size so the total
    # KL penalty matches a per-example ELBO.
    kernel_divergence_fn = lambda q, p, _: tfd.kl_divergence(q, p) / train_size
    bias_divergence_fn = lambda q, p, _: tfd.kl_divergence(q, p) / train_size

    inputs = Input(shape=(1,), name="input_layer")
    hidden = tfp.layers.DenseFlipout(30,
                                     kernel_divergence_fn=kernel_divergence_fn,
                                     bias_divergence_fn=bias_divergence_fn,
                                     activation="relu",
                                     name="DenseFlipout_layer_1")(inputs)
    hidden = tfp.layers.DenseFlipout(30,
                                     kernel_divergence_fn=kernel_divergence_fn,
                                     bias_divergence_fn=bias_divergence_fn,
                                     activation="relu",
                                     name="DenseFlipout_layer_2")(hidden)
    hidden = tfp.layers.DenseFlipout(30,
                                     kernel_divergence_fn=kernel_divergence_fn,
                                     bias_divergence_fn=bias_divergence_fn,
                                     activation="relu",
                                     name="DenseFlipout_layer_3")(hidden)
    params = tfp.layers.DenseFlipout(2,
                                     kernel_divergence_fn=kernel_divergence_fn,
                                     bias_divergence_fn=bias_divergence_fn,
                                     name="DenseFlipout_layer_5")(hidden)
    dist = tfp.layers.DistributionLambda(normal_sp, name="normal_sp")(params)
    return Model(inputs=inputs, outputs=dist)

# train_size, X_train/y_train and X_val/y_val come from the dataset
# (not shown in the question).
batch_size = train_size  # full-batch training
callback = tf.keras.callbacks.EarlyStopping(monitor="val_loss",
                                            patience=1830,
                                            restore_best_weights=True)

flipout_BNN = create_flipout_bnn_model(train_size=train_size)
flipout_BNN.compile(optimizer=Adam(learning_rate=0.002),
                    loss=NLL,
                    metrics=[tf.keras.metrics.RootMeanSquaredError()],
                    jit_compile=True)
flipout_BNN.summary()

history_flipout_BNN = flipout_BNN.fit(X_train, y_train,
                                      epochs=30000,
                                      verbose=0,
                                      batch_size=batch_size,
                                      validation_data=(X_val, y_val),
                                      callbacks=[callback])
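Since the snippet uses train_size and the train/validation arrays without showing where they come from, here is a minimal sketch that makes it runnable end to end; the synthetic sine data and the 100/20 split are assumptions for illustration, not the original dataset:

import numpy as np

# Hypothetical one-feature regression data (the original dataset is not shown).
rng = np.random.default_rng(0)
X = rng.uniform(-3.0, 3.0, size=(120, 1)).astype("float32")
y = (np.sin(X) + 0.1 * rng.normal(size=X.shape)).astype("float32")

X_train, y_train = X[:100], y[:100]
X_val, y_val = X[100:], y[100:]
train_size = len(X_train)

With these definitions executed first, the snippet above runs end to end.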
The result of the training is almost always something like this:
The fluctuating val_loss could be due to outliers in the dataset. To tackle this situation, we can normalize the dataset and use loss=tf.keras.losses.MeanAbsoluteError() for better results (see the sketch after this paragraph).
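A minimal sketch of that suggestion, assuming the X_train/y_train and X_val/y_val arrays from the question; scikit-learn's StandardScaler is used here as one possible way to normalize (the scaler choice is an assumption, not something this answer prescribes):

from sklearn.preprocessing import StandardScaler

# Standardize inputs and targets to zero mean and unit variance.
x_scaler = StandardScaler().fit(X_train)
y_scaler = StandardScaler().fit(y_train)
X_train_n, X_val_n = x_scaler.transform(X_train), x_scaler.transform(X_val)
y_train_n, y_val_n = y_scaler.transform(y_train), y_scaler.transform(y_val)

model = create_flipout_bnn_model(train_size=len(X_train_n))
# With the default DistributionLambda settings, the model output is converted
# to a tensor by sampling, so MAE is computed against a draw from the
# predictive distribution rather than its mean.
model.compile(optimizer=Adam(learning_rate=0.002),
              loss=tf.keras.losses.MeanAbsoluteError(),
              metrics=[tf.keras.metrics.RootMeanSquaredError()])
model.fit(X_train_n, y_train_n,
          epochs=30000, verbose=0,
          batch_size=len(X_train_n),
          validation_data=(X_val_n, y_val_n))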
I have tried replicating the above code using one of the datasets from this page. Please have a look at the output (attaching the gist for your reference):
Output: