I have a CNN model that classifies waveforms of shape (601, 3), where 601 is the number of timesteps and 3 is the number of channels, into noise or signal. It is as follows:
# imports
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
random_seed = 42
tf.random.set_seed(random_seed)
model = keras.Sequential([
    layers.Input(shape=(601, 3)),  # input shape for 1D data: (timesteps, channels)
    layers.Conv1D(32, kernel_size=16, activation='relu'),
    layers.Conv1D(64, kernel_size=16, activation='relu'),
    layers.Conv1D(128, kernel_size=16, activation='relu'),
    layers.Flatten(),
    layers.Dense(80, activation='relu'),
    layers.Dense(80, activation='relu'),
    layers.Dense(2, activation='softmax')
])
optimizer = keras.optimizers.Adam(learning_rate=0.001)
model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])
num_epochs = 40
batch_size = 48
history = model.fit(X_train, y_train_encoded, epochs=num_epochs, batch_size=batch_size,
validation_data=(X_test, y_test_encoded), verbose=2)
# X_train shape: (num_train_samples,601,3)
# X_test shape: (num_test_samples,601,3)
# y_train_encoded shape: (num_train_samples,2)
# y_test_encoded shape: (num_test_samples,2)
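In case the preprocessing matters: the labels are one-hot encoded to shape (num_samples, 2) and the waveforms are standardized per channel, along these lines (a minimal sketch with random arrays standing in for the real waveforms; the actual loading code is omitted):

```python
import numpy as np

rng = np.random.default_rng(42)

# Synthetic stand-ins for the real data: (num_samples, 601, 3)
X = rng.normal(size=(100, 601, 3)).astype("float32")
y = rng.integers(0, 2, size=100)  # 0 = noise, 1 = signal

# Standardize each channel using statistics over samples and timesteps
mean = X.mean(axis=(0, 1), keepdims=True)
std = X.std(axis=(0, 1), keepdims=True)
X_scaled = (X - mean) / std

# One-hot encode the labels to shape (num_samples, 2)
y_encoded = np.eye(2, dtype="float32")[y]

print(X_scaled.shape)   # (100, 601, 3)
print(y_encoded.shape)  # (100, 2)
```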
The above model runs fine: it converges around the 12th epoch and reaches an accuracy of over 99% by the end of training.
The problem arises when I try to convert this CNN into a Bayesian CNN by replacing the Conv1D and Dense layers with Convolution1DFlipout and DenseFlipout layers, respectively.
# imports
import tensorflow_probability as tfp
tfd = tfp.distributions
tfpl = tfp.layers
random_seed = 42
tf.random.set_seed(random_seed)
num_training_samples = X_train.shape[0]
kl_divergence_fn = lambda q, p, _: tfd.kl_divergence(q, p) / num_training_samples
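The lambda above divides the KL term by the dataset size so that the per-example regularization stays small relative to the likelihood. For intuition, here is the closed-form KL between two univariate Gaussians and what that scaling does to it (a plain-Python sketch with made-up posterior/prior values; the sample count is only approximate, back-computed from 3572 steps per epoch at batch size 48):

```python
import math

def kl_normal(m_q, s_q, m_p, s_p):
    """Closed-form KL(q || p) for univariate Gaussians N(m_q, s_q), N(m_p, s_p)."""
    return (math.log(s_p / s_q)
            + (s_q**2 + (m_q - m_p)**2) / (2.0 * s_p**2)
            - 0.5)

num_training_samples = 171_456  # roughly 3572 batches * 48 per batch

kl = kl_normal(0.1, 0.9, 0.0, 1.0)  # posterior slightly off the N(0, 1) prior
scaled = kl / num_training_samples  # per-example contribution to the loss

print(f"raw KL: {kl:.5f}, per-example KL: {scaled:.2e}")
```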
model = keras.Sequential([
    layers.Input(shape=(601, 3)),
    tfpl.Convolution1DFlipout(32, kernel_size=16, activation=tf.nn.relu,
                              kernel_divergence_fn=kl_divergence_fn,
                              bias_divergence_fn=kl_divergence_fn),
    tfpl.Convolution1DFlipout(64, kernel_size=16, activation=tf.nn.relu,
                              kernel_divergence_fn=kl_divergence_fn,
                              bias_divergence_fn=kl_divergence_fn),
    tfpl.Convolution1DFlipout(128, kernel_size=16, activation=tf.nn.relu,
                              kernel_divergence_fn=kl_divergence_fn,
                              bias_divergence_fn=kl_divergence_fn),
    layers.Flatten(),
    tfpl.DenseFlipout(80, activation=tf.nn.relu,
                      kernel_divergence_fn=kl_divergence_fn,
                      bias_divergence_fn=kl_divergence_fn),
    tfpl.DenseFlipout(80, activation=tf.nn.relu,
                      kernel_divergence_fn=kl_divergence_fn,
                      bias_divergence_fn=kl_divergence_fn),
    tfpl.DenseFlipout(2, activation=tf.nn.softmax,
                      kernel_divergence_fn=kl_divergence_fn,
                      bias_divergence_fn=kl_divergence_fn)
])
optimizer = keras.optimizers.Adam(learning_rate=0.001)
model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])
# Train the model (same as before)
num_epochs = 40
batch_size = 48
history = model.fit(X_train, y_train_encoded, epochs=num_epochs, batch_size=batch_size,
validation_data=(X_test, y_test_encoded), verbose=2)
This model doesn't seem to converge: the accuracy hovers around 50% for all 40 epochs. Could someone please help me figure this out?
Here is the model summary followed by the training log:
Layer (type)                       Output Shape          Param #
=================================================================
conv1d_flipout (Conv1DFlipout)     (None, 586, 32)       3104
conv1d_flipout_1 (Conv1DFlipout)   (None, 571, 64)       65600
conv1d_flipout_2 (Conv1DFlipout)   (None, 556, 128)      262272
flatten (Flatten)                  (None, 71168)         0
dense_flipout (DenseFlipout)       (None, 80)            11386960
dense_flipout_1 (DenseFlipout)     (None, 80)            12880
dense_flipout_2 (DenseFlipout)     (None, 2)             322
=================================================================
Total params: 11,731,138
Trainable params: 11,731,138
Non-trainable params: 0
_________________________________________________________________
Epoch 1/40
3572/3572 - 1937s - loss: 44.5086 - accuracy: 0.4996 - val_loss: 10.1754 - val_accuracy: 0.5016 - 1937s/epoch - 542ms/step
Epoch 2/40
3572/3572 - 1933s - loss: 3.7701 - accuracy: 0.4993 - val_loss: 1.8013 - val_accuracy: 0.4946 - 1933s/epoch - 541ms/step
Epoch 3/40
3572/3572 - 1923s - loss: 1.4620 - accuracy: 0.5002 - val_loss: 1.2377 - val_accuracy: 0.5034 - 1923s/epoch - 538ms/step
Epoch 4/40
3572/3572 - 1927s - loss: 1.1489 - accuracy: 0.4992 - val_loss: 1.0744 - val_accuracy: 0.4994 - 1927s/epoch - 540ms/step
Epoch 5/40
3572/3572 - 1972s - loss: 1.0328 - accuracy: 0.4990 - val_loss: 0.9811 - val_accuracy: 0.5031 - 1972s/epoch - 552ms/step
Epoch 6/40
3572/3572 - 1919s - loss: 0.9570 - accuracy: 0.4994 - val_loss: 0.9243 - val_accuracy: 0.4970 - 1919s/epoch - 537ms/step
Epoch 7/40
3572/3572 - 1963s - loss: 0.9232 - accuracy: 0.4995 - val_loss: 0.9010 - val_accuracy: 0.4969 - 1963s/epoch - 550ms/step
Epoch 8/40
3572/3572 - 1928s - loss: 0.8889 - accuracy: 0.5007 - val_loss: 0.8608 - val_accuracy: 0.4970 - 1928s/epoch - 540ms/step
Epoch 9/40
3572/3572 - 1928s - loss: 0.8496 - accuracy: 0.5029 - val_loss: 0.8410 - val_accuracy: 0.5030 - 1928s/epoch - 540ms/step
Epoch 10/40
3572/3572 - 1928s - loss: 0.8387 - accuracy: 0.4997 - val_loss: 0.8268 - val_accuracy: 0.5031 - 1928s/epoch - 540ms/step
Epoch 11/40
3572/3572 - 1976s - loss: 0.8187 - accuracy: 0.4998 - val_loss: 0.8114 - val_accuracy: 0.5033 - 1976s/epoch - 553ms/step
Epoch 12/40
3572/3572 - 1925s - loss: 0.8262 - accuracy: 0.4999 - val_loss: 0.8038 - val_accuracy: 0.5031 - 1925s/epoch - 539ms/step
Epoch 13/40
3572/3572 - 1926s - loss: 0.8069 - accuracy: 0.4993 - val_loss: 0.8509 - val_accuracy: 0.4968 - 1926s/epoch - 539ms/step
Epoch 14/40
3572/3572 - 1932s - loss: 0.8109 - accuracy: 0.5017 - val_loss: 0.8335 - val_accuracy: 0.4969 - 1932s/epoch - 541ms/step
Epoch 15/40
3572/3572 - 1973s - loss: 0.7914 - accuracy: 0.4991 - val_loss: 0.7814 - val_accuracy: 0.5031 - 1973s/epoch - 552ms/step
Epoch 16/40
3572/3572 - 1927s - loss: 0.8024 - accuracy: 0.4986 - val_loss: 0.7761 - val_accuracy: 0.4969 - 1927s/epoch - 540ms/step
Epoch 17/40
3572/3572 - 1927s - loss: 0.7867 - accuracy: 0.5000 - val_loss: 0.7676 - val_accuracy: 0.5031 - 1927s/epoch - 540ms/step
Epoch 18/40
3572/3572 - 1926s - loss: 0.7835 - accuracy: 0.5003 - val_loss: 0.7686 - val_accuracy: 0.4970 - 1926s/epoch - 539ms/step
Epoch 19/40
3572/3572 - 1933s - loss: 0.7838 - accuracy: 0.4994 - val_loss: 0.7704 - val_accuracy: 0.5030 - 1933s/epoch - 541ms/step
Epoch 20/40
3572/3572 - 1929s - loss: 0.8191 - accuracy: 0.5012 - val_loss: 0.7761 - val_accuracy: 0.4969 - 1929s/epoch - 540ms/step
Epoch 21/40
3572/3572 - 1931s - loss: 0.7637 - accuracy: 0.5027 - val_loss: 0.7741 - val_accuracy: 0.5031 - 1931s/epoch - 541ms/step
Epoch 22/40
3572/3572 - 1931s - loss: 0.7637 - accuracy: 0.4984 - val_loss: 0.7571 - val_accuracy: 0.4970 - 1931s/epoch - 541ms/step
Epoch 23/40
3572/3572 - 1926s - loss: 0.7640 - accuracy: 0.4983 - val_loss: 0.8398 - val_accuracy: 0.4969 - 1926s/epoch - 539ms/step
Epoch 24/40
3572/3572 - 1969s - loss: 0.7849 - accuracy: 0.4994 - val_loss: 0.7513 - val_accuracy: 0.5031 - 1969s/epoch - 551ms/step
Epoch 25/40
3572/3572 - 1932s - loss: 0.7741 - accuracy: 0.4988 - val_loss: 0.7600 - val_accuracy: 0.5031 - 1932s/epoch - 541ms/step
Epoch 26/40
3572/3572 - 1923s - loss: 0.8127 - accuracy: 0.5012 - val_loss: 0.7449 - val_accuracy: 0.5083 - 1923s/epoch - 538ms/step
Epoch 27/40
3572/3572 - 1925s - loss: 0.7586 - accuracy: 0.5002 - val_loss: 0.7445 - val_accuracy: 0.4969 - 1925s/epoch - 539ms/step
Epoch 28/40
3572/3572 - 1975s - loss: 0.7814 - accuracy: 0.4991 - val_loss: 0.7599 - val_accuracy: 0.4969 - 1975s/epoch - 553ms/step
Epoch 29/40
3572/3572 - 1926s - loss: 0.7589 - accuracy: 0.4994 - val_loss: 0.7458 - val_accuracy: 0.4968 - 1926s/epoch - 539ms/step
Epoch 30/40
3572/3572 - 1927s - loss: 0.7518 - accuracy: 0.5008 - val_loss: 0.7425 - val_accuracy: 0.4970 - 1927s/epoch - 540ms/step
Epoch 31/40
3572/3572 - 1976s - loss: 0.9064 - accuracy: 0.4973 - val_loss: 0.7523 - val_accuracy: 0.4970 - 1976s/epoch - 553ms/step
Epoch 32/40
3572/3572 - 1926s - loss: 0.7519 - accuracy: 0.5005 - val_loss: 0.7480 - val_accuracy: 0.4969 - 1926s/epoch - 539ms/step
Epoch 33/40
3572/3572 - 1962s - loss: 0.7544 - accuracy: 0.5018 - val_loss: 0.7468 - val_accuracy: 0.5031 - 1962s/epoch - 549ms/step
Epoch 34/40
3572/3572 - 1933s - loss: 0.7630 - accuracy: 0.4998 - val_loss: 0.7402 - val_accuracy: 0.4917 - 1933s/epoch - 541ms/step
Epoch 35/40
3572/3572 - 1974s - loss: 0.7534 - accuracy: 0.5004 - val_loss: 0.7421 - val_accuracy: 0.4970 - 1974s/epoch - 553ms/step
Epoch 36/40
3572/3572 - 1928s - loss: 0.7463 - accuracy: 0.4994 - val_loss: 0.7394 - val_accuracy: 0.5031 - 1928s/epoch - 540ms/step
Epoch 37/40
3572/3572 - 1920s - loss: 0.7945 - accuracy: 0.4983 - val_loss: 0.7901 - val_accuracy: 0.4969 - 1920s/epoch - 537ms/step
Epoch 38/40
3572/3572 - 1930s - loss: 0.7470 - accuracy: 0.4995 - val_loss: 0.7348 - val_accuracy: 0.4970 - 1930s/epoch - 540ms/step
Epoch 39/40
3572/3572 - 1923s - loss: 0.7387 - accuracy: 0.4999 - val_loss: 0.7436 - val_accuracy: 0.5031 - 1923s/epoch - 538ms/step
Epoch 40/40
3572/3572 - 1945s - loss: 0.7480 - accuracy: 0.5011 - val_loss: 0.7336 - val_accuracy: 0.4969 - 1945s/epoch - 544ms/step
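For what it's worth, the parameter counts in the summary are consistent with my understanding of the TFP defaults: each Flipout layer learns a mean and a scale for every kernel weight (twice the deterministic layer's kernel parameters), while the default bias posterior is a point estimate, so the bias adds its usual count. A quick arithmetic check:

```python
# Flipout parameter counts, assuming the TFP default mean-field posterior:
# 2 parameters (loc + scale) per kernel weight, 1 (loc) per bias weight.
def conv1d_flipout_params(kernel_size, in_ch, filters):
    kernel = kernel_size * in_ch * filters
    return 2 * kernel + filters

def dense_flipout_params(in_units, out_units):
    return 2 * in_units * out_units + out_units

counts = [
    conv1d_flipout_params(16, 3, 32),     # 3104
    conv1d_flipout_params(16, 32, 64),    # 65600
    conv1d_flipout_params(16, 64, 128),   # 262272
    dense_flipout_params(556 * 128, 80),  # 11386960 (Flatten: 556 * 128 = 71168)
    dense_flipout_params(80, 80),         # 12880
    dense_flipout_params(80, 2),          # 322
]
print(counts, sum(counts))  # total 11731138, matching the summary
```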