I have a CNN model that classifies waveforms of shape (601, 3), where 601 is the number of timesteps and 3 is the number of channels, as either noise or signal. It is as follows:

# imports
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

random_seed = 42

tf.random.set_seed(random_seed)

model = keras.Sequential([
    layers.Input(shape=(601, 3)),  # Input shape for 1D data
    layers.Conv1D(32, kernel_size=16, activation='relu'),
    layers.Conv1D(64, kernel_size=16, activation='relu'),
    layers.Conv1D(128, kernel_size=16, activation='relu'),
    layers.Flatten(),
    layers.Dense(80, activation='relu'),
    layers.Dense(80, activation='relu'),
    layers.Dense(2, activation='softmax')
])

optimizer = keras.optimizers.Adam(learning_rate=0.001)
model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])

num_epochs = 40
batch_size = 48

history = model.fit(X_train, y_train_encoded, epochs=num_epochs, batch_size=batch_size,
                    validation_data=(X_test, y_test_encoded), verbose=2)

# X_train shape: (num_train_samples,601,3)
# X_test shape: (num_test_samples,601,3)
# y_train_encoded shape: (num_train_samples,2)
# y_test_encoded shape: (num_test_samples,2)
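The post does not show how X_train and the encoded labels are built. A minimal sketch, assuming random placeholder waveforms with the shapes listed above and integer labels 0 = noise and 1 = signal (the sizes, the label convention and the data itself are all hypothetical), shows the one-hot layout that categorical_crossentropy expects:

```python
import numpy as np

# Hypothetical stand-in data with the shapes listed above; replace with real waveforms.
rng = np.random.default_rng(42)
num_train_samples, num_test_samples = 1000, 200  # hypothetical sizes

X_train = rng.standard_normal((num_train_samples, 601, 3)).astype(np.float32)
X_test = rng.standard_normal((num_test_samples, 601, 3)).astype(np.float32)

# Integer labels: 0 = noise, 1 = signal (assumed convention).
y_train = rng.integers(0, 2, size=num_train_samples)
y_test = rng.integers(0, 2, size=num_test_samples)

# One-hot encode to shape (num_samples, 2) for categorical_crossentropy.
y_train_encoded = np.eye(2, dtype=np.float32)[y_train]
y_test_encoded = np.eye(2, dtype=np.float32)[y_test]

print(X_train.shape, y_train_encoded.shape)
```

keras.utils.to_categorical(y_train, 2) produces the same encoding.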

The above model trains well: it converges by the 12th epoch and reaches over 99% accuracy after all 40 epochs.

The problem arises when I try to convert the above CNN to a Bayesian CNN by replacing the Conv1D and Dense layers with Convolution1DFlipout and DenseFlipout layers respectively.

# imports
import tensorflow_probability as tfp

tfd = tfp.distributions
tfpl = tfp.layers

random_seed = 42

tf.random.set_seed(random_seed)

num_training_samples = X_train.shape[0]
kl_divergence_fn = lambda q, p, _: tfd.kl_divergence(q, p) / num_training_samples
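kl_divergence_fn divides each layer's KL divergence by the number of training samples, so the KL penalty sits on the same per-example scale as the averaged cross-entropy. Assuming TFP's defaults for these layers (a mean-field normal posterior over the kernel and a standard-normal prior), tfd.kl_divergence(q, p) has a closed form. The NumPy sketch below reproduces that formula for illustration only; the function names and the example numbers are hypothetical, not part of TFP:

```python
import numpy as np

def kl_diag_normal_to_std_normal(loc, scale):
    """Sum over weights of KL( N(loc_i, scale_i^2) || N(0, 1) )."""
    loc = np.asarray(loc, dtype=np.float64)
    scale = np.asarray(scale, dtype=np.float64)
    # Per-weight closed form: 0.5 * (sigma^2 + mu^2 - 1) - ln(sigma)
    return float(np.sum(0.5 * (scale**2 + loc**2 - 1.0) - np.log(scale)))

def scaled_kl(loc, scale, num_training_samples):
    """Per-example KL, mirroring the division by N in kl_divergence_fn."""
    return kl_diag_normal_to_std_normal(loc, scale) / num_training_samples

# Illustrative numbers only: 1,000 weights with posterior std 0.05, N = 171,000.
loc = np.zeros(1000)
scale = np.full(1000, 0.05)
print(scaled_kl(loc, scale, 171_000))
```

The KL term grows linearly with the number of weights, and the first DenseFlipout layer in the summary below has 71,168 × 80 kernel weights.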

model = keras.Sequential([
    layers.Input(shape=(601, 3)),
    tfpl.Convolution1DFlipout(
        32, kernel_size=16, activation=tf.nn.relu, kernel_divergence_fn=kl_divergence_fn, bias_divergence_fn=kl_divergence_fn),
    tfpl.Convolution1DFlipout(
        64, kernel_size=16, activation=tf.nn.relu, kernel_divergence_fn=kl_divergence_fn, bias_divergence_fn=kl_divergence_fn),
    tfpl.Convolution1DFlipout(
        128, kernel_size=16, activation=tf.nn.relu, kernel_divergence_fn=kl_divergence_fn, bias_divergence_fn=kl_divergence_fn),
    layers.MaxPooling1D(pool_size=2),
    layers.Flatten(),
    tfpl.DenseFlipout(80, activation=tf.nn.relu, kernel_divergence_fn=kl_divergence_fn, bias_divergence_fn=kl_divergence_fn),
    tfpl.DenseFlipout(80, activation=tf.nn.relu, kernel_divergence_fn=kl_divergence_fn, bias_divergence_fn=kl_divergence_fn),
    tfpl.DenseFlipout(2, activation=tf.nn.softmax, kernel_divergence_fn=kl_divergence_fn, bias_divergence_fn=kl_divergence_fn)
])

optimizer = keras.optimizers.Adam(learning_rate=0.001)
model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])

# Train the model (same as before)
num_epochs = 40
batch_size = 48

history = model.fit(X_train, y_train_encoded, epochs=num_epochs, batch_size=batch_size,
                    validation_data=(X_test, y_test_encoded), verbose=2)

This model doesn't converge: training and validation accuracy stay at about 50% for all 40 epochs. Could someone please help me figure out why?

This is the model summary and the training log:

 Layer (type)                Output Shape              Param #   
=================================================================
 conv1d_flipout (Conv1DFlipo  (None, 586, 32)          3104      
 ut)                                                             
                                                                 
 conv1d_flipout_1 (Conv1DFli  (None, 571, 64)          65600     
 pout)                                                           
                                                                 
 conv1d_flipout_2 (Conv1DFli  (None, 556, 128)         262272    
 pout)                                                           
                                                                 
 flatten (Flatten)           (None, 71168)             0         
                                                                 
 dense_flipout (DenseFlipout  (None, 80)               11386960  
 )                                                               
                                                                 
 dense_flipout_1 (DenseFlipo  (None, 80)               12880     
 ut)                                                             
                                                                 
 dense_flipout_2 (DenseFlipo  (None, 2)                322       
 ut)                                                             
                                                                 
=================================================================
Total params: 11,731,138
Trainable params: 11,731,138
Non-trainable params: 0
_________________________________________________________________
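The shapes and parameter counts in this summary can be reproduced by hand, which also shows where the parameters live: Flatten passes 556 × 128 = 71,168 features into the first DenseFlipout layer, which holds about 97% of the 11.7M parameters. The summary lists no MaxPooling1D layer, and 71,168 matches a model without it; with MaxPooling1D(pool_size=2) the Flatten output would be 278 × 128 = 35,584. The sketch assumes TFP's default posteriors (two trainable variables per kernel weight, one per bias); the helper names are hypothetical:

```python
# Assumptions (TFP defaults): each kernel weight has a mean-field normal
# posterior with 2 trainable variables (loc, untransformed scale); each bias
# has a deterministic posterior with 1 variable (loc).

def conv1d_valid_length(length, kernel_size):
    # padding='valid', strides=1 (the defaults used in the question)
    return length - kernel_size + 1

def flipout_params(kernel_weights, biases):
    return 2 * kernel_weights + biases

counts = {}
length, channels = 601, 3
for i, filters in enumerate((32, 64, 128)):
    name = "conv1d_flipout" + (f"_{i}" if i else "")
    counts[name] = flipout_params(16 * channels * filters, filters)
    length, channels = conv1d_valid_length(length, 16), filters

flat = length * channels  # Flatten output with no pooling layer
for i, (n_in, units) in enumerate(((flat, 80), (80, 80), (80, 2))):
    name = "dense_flipout" + (f"_{i}" if i else "")
    counts[name] = flipout_params(n_in * units, units)

for name, n in counts.items():
    print(f"{name:18s}{n:>12,}")
print(length, flat, sum(counts.values()))  # → 556 71168 11731138
```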
Epoch 1/40
3572/3572 - 1937s - loss: 44.5086 - accuracy: 0.4996 - val_loss: 10.1754 - val_accuracy: 0.5016 - 1937s/epoch - 542ms/step
Epoch 2/40
3572/3572 - 1933s - loss: 3.7701 - accuracy: 0.4993 - val_loss: 1.8013 - val_accuracy: 0.4946 - 1933s/epoch - 541ms/step
Epoch 3/40
3572/3572 - 1923s - loss: 1.4620 - accuracy: 0.5002 - val_loss: 1.2377 - val_accuracy: 0.5034 - 1923s/epoch - 538ms/step
Epoch 4/40
3572/3572 - 1927s - loss: 1.1489 - accuracy: 0.4992 - val_loss: 1.0744 - val_accuracy: 0.4994 - 1927s/epoch - 540ms/step
Epoch 5/40
3572/3572 - 1972s - loss: 1.0328 - accuracy: 0.4990 - val_loss: 0.9811 - val_accuracy: 0.5031 - 1972s/epoch - 552ms/step
Epoch 6/40
3572/3572 - 1919s - loss: 0.9570 - accuracy: 0.4994 - val_loss: 0.9243 - val_accuracy: 0.4970 - 1919s/epoch - 537ms/step
Epoch 7/40
3572/3572 - 1963s - loss: 0.9232 - accuracy: 0.4995 - val_loss: 0.9010 - val_accuracy: 0.4969 - 1963s/epoch - 550ms/step
Epoch 8/40
3572/3572 - 1928s - loss: 0.8889 - accuracy: 0.5007 - val_loss: 0.8608 - val_accuracy: 0.4970 - 1928s/epoch - 540ms/step
Epoch 9/40
3572/3572 - 1928s - loss: 0.8496 - accuracy: 0.5029 - val_loss: 0.8410 - val_accuracy: 0.5030 - 1928s/epoch - 540ms/step
Epoch 10/40
3572/3572 - 1928s - loss: 0.8387 - accuracy: 0.4997 - val_loss: 0.8268 - val_accuracy: 0.5031 - 1928s/epoch - 540ms/step
Epoch 11/40
3572/3572 - 1976s - loss: 0.8187 - accuracy: 0.4998 - val_loss: 0.8114 - val_accuracy: 0.5033 - 1976s/epoch - 553ms/step
Epoch 12/40
3572/3572 - 1925s - loss: 0.8262 - accuracy: 0.4999 - val_loss: 0.8038 - val_accuracy: 0.5031 - 1925s/epoch - 539ms/step
Epoch 13/40
3572/3572 - 1926s - loss: 0.8069 - accuracy: 0.4993 - val_loss: 0.8509 - val_accuracy: 0.4968 - 1926s/epoch - 539ms/step
Epoch 14/40
3572/3572 - 1932s - loss: 0.8109 - accuracy: 0.5017 - val_loss: 0.8335 - val_accuracy: 0.4969 - 1932s/epoch - 541ms/step
Epoch 15/40
3572/3572 - 1973s - loss: 0.7914 - accuracy: 0.4991 - val_loss: 0.7814 - val_accuracy: 0.5031 - 1973s/epoch - 552ms/step
Epoch 16/40
3572/3572 - 1927s - loss: 0.8024 - accuracy: 0.4986 - val_loss: 0.7761 - val_accuracy: 0.4969 - 1927s/epoch - 540ms/step
Epoch 17/40
3572/3572 - 1927s - loss: 0.7867 - accuracy: 0.5000 - val_loss: 0.7676 - val_accuracy: 0.5031 - 1927s/epoch - 540ms/step
Epoch 18/40
3572/3572 - 1926s - loss: 0.7835 - accuracy: 0.5003 - val_loss: 0.7686 - val_accuracy: 0.4970 - 1926s/epoch - 539ms/step
Epoch 19/40
3572/3572 - 1933s - loss: 0.7838 - accuracy: 0.4994 - val_loss: 0.7704 - val_accuracy: 0.5030 - 1933s/epoch - 541ms/step
Epoch 20/40
3572/3572 - 1929s - loss: 0.8191 - accuracy: 0.5012 - val_loss: 0.7761 - val_accuracy: 0.4969 - 1929s/epoch - 540ms/step
Epoch 21/40
3572/3572 - 1931s - loss: 0.7637 - accuracy: 0.5027 - val_loss: 0.7741 - val_accuracy: 0.5031 - 1931s/epoch - 541ms/step
Epoch 22/40
3572/3572 - 1931s - loss: 0.7637 - accuracy: 0.4984 - val_loss: 0.7571 - val_accuracy: 0.4970 - 1931s/epoch - 541ms/step
Epoch 23/40
3572/3572 - 1926s - loss: 0.7640 - accuracy: 0.4983 - val_loss: 0.8398 - val_accuracy: 0.4969 - 1926s/epoch - 539ms/step
Epoch 24/40
3572/3572 - 1969s - loss: 0.7849 - accuracy: 0.4994 - val_loss: 0.7513 - val_accuracy: 0.5031 - 1969s/epoch - 551ms/step
Epoch 25/40
3572/3572 - 1932s - loss: 0.7741 - accuracy: 0.4988 - val_loss: 0.7600 - val_accuracy: 0.5031 - 1932s/epoch - 541ms/step
Epoch 26/40
3572/3572 - 1923s - loss: 0.8127 - accuracy: 0.5012 - val_loss: 0.7449 - val_accuracy: 0.5083 - 1923s/epoch - 538ms/step
Epoch 27/40
3572/3572 - 1925s - loss: 0.7586 - accuracy: 0.5002 - val_loss: 0.7445 - val_accuracy: 0.4969 - 1925s/epoch - 539ms/step
Epoch 28/40
3572/3572 - 1975s - loss: 0.7814 - accuracy: 0.4991 - val_loss: 0.7599 - val_accuracy: 0.4969 - 1975s/epoch - 553ms/step
Epoch 29/40
3572/3572 - 1926s - loss: 0.7589 - accuracy: 0.4994 - val_loss: 0.7458 - val_accuracy: 0.4968 - 1926s/epoch - 539ms/step
Epoch 30/40
3572/3572 - 1927s - loss: 0.7518 - accuracy: 0.5008 - val_loss: 0.7425 - val_accuracy: 0.4970 - 1927s/epoch - 540ms/step
Epoch 31/40
3572/3572 - 1976s - loss: 0.9064 - accuracy: 0.4973 - val_loss: 0.7523 - val_accuracy: 0.4970 - 1976s/epoch - 553ms/step
Epoch 32/40
3572/3572 - 1926s - loss: 0.7519 - accuracy: 0.5005 - val_loss: 0.7480 - val_accuracy: 0.4969 - 1926s/epoch - 539ms/step
Epoch 33/40
3572/3572 - 1962s - loss: 0.7544 - accuracy: 0.5018 - val_loss: 0.7468 - val_accuracy: 0.5031 - 1962s/epoch - 549ms/step
Epoch 34/40
3572/3572 - 1933s - loss: 0.7630 - accuracy: 0.4998 - val_loss: 0.7402 - val_accuracy: 0.4917 - 1933s/epoch - 541ms/step
Epoch 35/40
3572/3572 - 1974s - loss: 0.7534 - accuracy: 0.5004 - val_loss: 0.7421 - val_accuracy: 0.4970 - 1974s/epoch - 553ms/step
Epoch 36/40
3572/3572 - 1928s - loss: 0.7463 - accuracy: 0.4994 - val_loss: 0.7394 - val_accuracy: 0.5031 - 1928s/epoch - 540ms/step
Epoch 37/40
3572/3572 - 1920s - loss: 0.7945 - accuracy: 0.4983 - val_loss: 0.7901 - val_accuracy: 0.4969 - 1920s/epoch - 537ms/step
Epoch 38/40
3572/3572 - 1930s - loss: 0.7470 - accuracy: 0.4995 - val_loss: 0.7348 - val_accuracy: 0.4970 - 1930s/epoch - 540ms/step
Epoch 39/40
3572/3572 - 1923s - loss: 0.7387 - accuracy: 0.4999 - val_loss: 0.7436 - val_accuracy: 0.5031 - 1923s/epoch - 538ms/step
Epoch 40/40
3572/3572 - 1945s - loss: 0.7480 - accuracy: 0.5011 - val_loss: 0.7336 - val_accuracy: 0.4969 - 1945s/epoch - 544ms/step
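For reference, a two-class model that outputs [0.5, 0.5] for every example has a categorical cross-entropy of ln 2 ≈ 0.693. The logged loss also includes the scaled KL terms that the Flipout layers add to model.losses, which is part of why it stays above that value. The val_accuracy values alternate between about 0.4969 and 0.5031, which sum to 1; this is consistent with the model predicting a single class for every validation example. A small NumPy check of the ln 2 baseline, on a hypothetical balanced batch of 48 (matching batch_size):

```python
import numpy as np

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    # Mean over samples of -sum(y_true * log(y_pred)) for one-hot targets;
    # predictions are clipped away from 0 and 1, as Keras does.
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    return float(np.mean(-np.sum(y_true * np.log(y_pred), axis=1)))

# Hypothetical balanced batch; the model outputs [0.5, 0.5] for every example.
y_true = np.eye(2)[np.array([0, 1] * 24)]  # 48 one-hot labels
y_pred = np.full((48, 2), 0.5)
print(categorical_crossentropy(y_true, y_pred))  # ln 2 ≈ 0.6931
```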
