I have been running my code for several days but am stuck on a problem: when I train my adversarial autoencoder, the losses of both the generator and the discriminator decrease, and after a few epochs the discriminator's loss is very close to 0. What is the main problem? I would highly appreciate any help. Here is the code of the network:
    # Imports assumed by the snippets below. window_size (=100, see below),
    # batch_size, wasserstein_loss and the optimizers are globals defined elsewhere.
    import numpy as np
    import tensorflow as tf
    from tensorflow import keras
    from tensorflow.keras import layers
    from tensorflow.keras.layers import BatchNormalization

    # CUDNNLSTM is the cuDNN-backed LSTM; depending on the TF/Keras version it may
    # come from tf.compat.v1.keras.layers.CuDNNLSTM or keras.layers.CuDNNLSTM.
    CUDNNLSTM = tf.compat.v1.keras.layers.CuDNNLSTM

    def build_encoder_layer(input_shape, encoder_reshape_shape):
        input_layer = layers.Input(shape=input_shape)
        x = layers.Bidirectional(CUDNNLSTM(units=window_size, return_sequences=True))(input_layer)
        x = layers.Activation(activation='relu')(x)  # new part
        x = BatchNormalization()(x)
        x = layers.Dropout(rate=0.2)(x)
        # x = layers.Dropout(rate=0.2)(x)
        # new LSTM layer
        x = layers.Bidirectional(CUDNNLSTM(units=100, return_sequences=True))(x)
        x = layers.Activation(activation='relu')(x)
        x = BatchNormalization()(x)
        x = layers.Dropout(rate=0.2)(x)
        x = layers.Bidirectional(CUDNNLSTM(units=100, return_sequences=True))(x)
        x = layers.Activation(activation='relu')(x)
        x = BatchNormalization()(x)
        x = layers.Dropout(rate=0.2)(x)
        x = layers.Flatten()(x)
        x = layers.Dense(20)(x)
        x = layers.Reshape(target_shape=encoder_reshape_shape)(x)
        # x = layers.Activation(activation='tanh')(x)
        model = keras.models.Model(input_layer, x, name='encoder')
        return model
    def build_generator_layer(input_shape, generator_reshape_shape):
        input_layer = layers.Input(shape=input_shape)
        x = layers.Flatten()(input_layer)
        x = layers.Dense(generator_reshape_shape[0])(x)
        x = layers.Reshape(target_shape=generator_reshape_shape)(x)
        x = layers.Bidirectional(CUDNNLSTM(units=64, return_sequences=True), merge_mode='concat')(x)
        x = layers.Activation(activation='relu')(x)  # new part; dropout is added after the activation for the generator
        x = BatchNormalization()(x)
        x = layers.Dropout(rate=0.2)(x)
        x = layers.UpSampling1D(size=2)(x)
        # new LSTM layer
        x = layers.Bidirectional(CUDNNLSTM(units=64, return_sequences=True), merge_mode='concat')(x)
        x = layers.Activation(activation='relu')(x)
        x = BatchNormalization()(x)
        x = layers.Dropout(rate=0.2)(x)
        # new LSTM layer
        x = layers.Bidirectional(CUDNNLSTM(units=64, return_sequences=True), merge_mode='concat')(x)
        x = layers.Activation(activation='relu')(x)
        x = BatchNormalization()(x)
        x = layers.Dropout(rate=0.2)(x)
        x = layers.Bidirectional(CUDNNLSTM(units=64, return_sequences=True), merge_mode='concat')(x)
        x = layers.Activation(activation='relu')(x)
        x = BatchNormalization()(x)
        x = layers.Dropout(rate=0.2)(x)
        x = layers.TimeDistributed(layers.Dense(1))(x)
        x = layers.Activation(activation='tanh')(x)  # originally was relu
        model = keras.models.Model(input_layer, x, name='generator')
        return model
    def build_critic_x_layer(input_shape):
        input_layer = layers.Input(shape=input_shape)
        x = layers.Conv1D(filters=64, kernel_size=5)(input_layer)
        x = layers.LeakyReLU(alpha=0.2)(x)
        x = layers.Dropout(rate=0.25)(x)
        x = layers.Flatten()(x)
        x = layers.Dense(units=100)(x)
        model = keras.models.Model(input_layer, x, name='critic_x')
        return model
    def build_critic_z_layer(input_shape):
        input_layer = layers.Input(shape=input_shape)
        x = layers.Conv1D(filters=64, kernel_size=5)(input_layer)
        x = layers.LeakyReLU(alpha=0.2)(x)
        x = layers.Dropout(rate=0.2)(x)
        x = layers.Flatten()(x)
        model = keras.models.Model(input_layer, x, name='critic_z')
        return model
I also tried adding more convolutional layers to the critics, as well as different learning rates and batch sizes; in all cases the critics' losses still decrease.
The input data is processed into windows of window_size=100, and the latent space has size 20.
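Roughly, the preprocessing looks like this (simplified sketch with placeholder names, not my exact code):

    # Simplified sketch of the windowing; each sample fed to the encoder is
    # one window of shape (window_size, 1).
    def make_windows(values, window_size=100, step=1):
        windows = [values[i:i + window_size]
                   for i in range(0, len(values) - window_size + 1, step)]
        return np.asarray(windows, dtype=np.float32)[..., np.newaxis]  # (n_windows, 100, 1)

    # z is sampled from the prior, e.g. a standard normal of shape
    # (batch_size, 20, 1), to match encoder_reshape_shape.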
The per-batch training steps are written as follows (wasserstein_loss, batch_size, the models, and the optimizers are defined elsewhere):
    def critic_x_train_on_batch(x, z):
        # Loss
        with tf.GradientTape() as tape:
            valid_x = critic_x(x)
            x_ = generator(z)
            fake_x = critic_x(x_)
            # Interpolated samples for the gradient penalty
            alpha = tf.random.uniform([batch_size, 1, 1], 0.0, 1.0, dtype=tf.dtypes.float32)
            x_ = tf.cast(x_, dtype='float32')
            x = tf.cast(x, dtype='float32')
            interpolated = alpha * x + (1 - alpha) * x_
            with tf.GradientTape() as gp_tape:
                gp_tape.watch(interpolated)
                pred = critic_x(interpolated)
            grads = gp_tape.gradient(pred, interpolated)
            grad_norm = tf.norm(tf.reshape(grads, (batch_size, -1)), axis=1)
            gp_loss = 10.0 * tf.reduce_mean(tf.square(grad_norm - 1.))
            # grads = tf.square(grads)
            # ddx = tf.sqrt(tf.reduce_sum(grads, axis=np.arange(1, len(grads.shape))))
            # gp_loss = tf.reduce_mean((1.0 - ddx) ** 2)
            loss1 = wasserstein_loss(-tf.ones_like(valid_x), valid_x)  # real samples
            loss2 = wasserstein_loss(tf.ones_like(fake_x), fake_x)     # generated samples
            # loss = tf.add_n([loss1, loss2, gp_loss*10.0])
            loss = loss1 + loss2 + gp_loss
            # loss = tf.reduce_mean(loss)
        gradients = tape.gradient(loss, critic_x.trainable_weights)
        critic_x_optimizer.apply_gradients(zip(gradients, critic_x.trainable_weights))
        return loss
Critic Z
    def critic_z_train_on_batch(x, z):
        with tf.GradientTape() as tape:
            z_ = encoder(x)
            valid_z = critic_z(z)
            fake_z = critic_z(z_)  # <- critic_z
            # Interpolated
            alpha = tf.random.uniform([batch_size, 1, 1], 0.0, 1.0)
            interpolated = alpha * z + (1 - alpha) * z_
            with tf.GradientTape() as gp_tape:
                gp_tape.watch(interpolated)
                pred = critic_z(interpolated, training=True)
            grads = gp_tape.gradient(pred, interpolated)
            grad_norm = tf.norm(tf.reshape(grads, (batch_size, -1)), axis=1)
            gp_loss = 10.0 * tf.reduce_mean(tf.square(grad_norm - 1.))
            # grads = tf.square(grads)
            # ddx = tf.sqrt(tf.reduce_sum(grads, axis=np.arange(1, len(grads.shape))))
            # gp_loss = tf.reduce_mean((1.0 - ddx) ** 2)
            loss1 = wasserstein_loss(-tf.ones_like(valid_z), valid_z)
            loss2 = wasserstein_loss(tf.ones_like(fake_z), fake_z)
            loss = loss1 + loss2 + gp_loss
            # loss = tf.reduce_mean(loss)
        gradients = tape.gradient(loss, critic_z.trainable_weights)
        critic_z_optimizer.apply_gradients(zip(gradients, critic_z.trainable_weights))
        return loss
Generator Train
    @tf.function
    def enc_gen_train_on_batch(x, z):
        # Encoder update
        with tf.GradientTape() as enc_tape:
            z_gen_ = encoder(x, training=True)
            x_gen_ = generator(z, training=False)
            x_gen_rec = generator(z_gen_, training=False)
            fake_gen_x = critic_x(x_gen_, training=False)
            fake_gen_z = critic_z(z_gen_, training=False)
            loss1 = wasserstein_loss(fake_gen_x, -tf.ones_like(fake_gen_x))
            loss2 = wasserstein_loss(fake_gen_z, -tf.ones_like(fake_gen_z))
            loss3 = 10.0 * tf.reduce_mean(tf.keras.losses.MSE(x, x_gen_rec))  # reconstruction term
            enc_loss = loss1 + loss2 + loss3
            # enc_loss = loss3
        gradients_encoder = enc_tape.gradient(enc_loss, encoder.trainable_weights)
        encoder_optimizer.apply_gradients(zip(gradients_encoder, encoder.trainable_weights))

        # Generator update
        with tf.GradientTape() as gen_tape:
            z_gen_ = encoder(x, training=False)
            x_gen_ = generator(z, training=True)
            x_gen_rec = generator(z_gen_, training=True)
            fake_gen_x = critic_x(x_gen_, training=False)
            fake_gen_z = critic_z(z_gen_, training=False)
            loss1 = wasserstein_loss(fake_gen_x, -tf.ones_like(fake_gen_x))
            loss2 = wasserstein_loss(fake_gen_z, -tf.ones_like(fake_gen_z))
            loss3 = 10.0 * tf.reduce_mean(tf.keras.losses.MSE(x, x_gen_rec))  # reconstruction term
            gen_loss = loss1 + loss2 + loss3
            # gen_loss = loss3
        gradients_generator = gen_tape.gradient(gen_loss, generator.trainable_weights)
        generator_optimizer.apply_gradients(zip(gradients_generator, generator.trainable_weights))
        return enc_loss, gen_loss
I expect the network to reconstruct time series data. As far as I understand, the discriminator's loss should rise again after some epochs while the generator's loss keeps decreasing, but instead the losses evolve as follows:
Epoch: 1/30, [Dx loss: 1.098870038986206] [Dz loss: -0.10945891588926315] [E loss: 5.278852462768555] [G loss: 3.7521088123321533]
Epoch: 2/30, [Dx loss: 0.5344677567481995] [Dz loss: -0.13788911700248718] [E loss: 3.6803643703460693] [G loss: 3.213960886001587]
Epoch: 3/30, [Dx loss: 0.34554678201675415] [Dz loss: -0.1143367737531662] [E loss: 3.154308319091797] [G loss: 2.820709228515625]
Epoch: 4/30, [Dx loss: 0.2561565041542053] [Dz loss: -0.1354585587978363] [E loss: 3.0694801807403564] [G loss: 3.1366894245147705]
Epoch: 5/30, [Dx loss: 0.20118674635887146] [Dz loss: -0.1930069476366043] [E loss: 2.9434409141540527] [G loss: 3.0156397819519043]
Epoch: 6/30, [Dx loss: 0.16586175560951233] [Dz loss: -0.2611238360404968] [E loss: 3.078233480453491] [G loss: 2.937089443206787]
Epoch: 7/30, [Dx loss: 0.13812018930912018] [Dz loss: -0.2792683243751526] [E loss: 3.0081939697265625] [G loss: 2.8207926750183105]
Epoch: 8/30, [Dx loss: 0.11749842762947083] [Dz loss: -0.3390710949897766] [E loss: 3.2441465854644775] [G loss: 2.751972198486328]
Epoch: 9/30, [Dx loss: 0.09816166758537292] [Dz loss: -0.4120389521121979] [E loss: 3.4268479347229004] [G loss: 2.742722988128662]
This indicates that the Discriminator is overpowering the Generator, so no further useful signal is being learned. You can verify this by visualizing the backpropagated gradients. The usual remedy is additional regularization on the Discriminator (e.g., R1 regularization; see the sketch below).
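For reference, a minimal R1 sketch written against your critic_x (gamma and the reduction axes are assumptions for a (batch, time, channels) input; note that R1 is normally paired with a standard/non-saturating GAN loss rather than stacked on top of WGAN-GP):

    # Minimal R1 penalty sketch: penalize the critic's gradient norm on
    # *real* samples only. gamma is a hyperparameter (e.g. 10.0).
    def r1_penalty(critic, x_real, gamma=10.0):
        x_real = tf.convert_to_tensor(x_real, dtype=tf.float32)
        with tf.GradientTape() as tape:
            tape.watch(x_real)
            scores = critic(x_real, training=True)
        grads = tape.gradient(scores, x_real)                    # dD/dx on real data
        grad_sq = tf.reduce_sum(tf.square(grads), axis=[1, 2])   # per-sample squared norm
        return 0.5 * gamma * tf.reduce_mean(grad_sq)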
I would first try lowering the Discriminator's learning rate, e.g. to half of the Generator's, or even less. Some popular projects, such as StyleGAN3, also blur the images shown to the Discriminator for the first few kimgs (thousands of images).
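Concretely, that could look like this (example values only, reusing the optimizer names from your training code):

    # Example: critics learn at half the generator/encoder rate
    # (the 1e-4 / 5e-5 values are illustrative, not tuned).
    generator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4, beta_1=0.5, beta_2=0.9)
    encoder_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4, beta_1=0.5, beta_2=0.9)
    critic_x_optimizer = tf.keras.optimizers.Adam(learning_rate=5e-5, beta_1=0.5, beta_2=0.9)
    critic_z_optimizer = tf.keras.optimizers.Adam(learning_rate=5e-5, beta_1=0.5, beta_2=0.9)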
Another popular approach is FreezeD (if you start from a pre-trained model) or ADA (adaptive discriminator augmentation) on the Discriminator; you can check how they are implemented in the StyleGAN2-ADA paper or its GitHub repo.
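The FreezeD idea itself is simple to sketch (only meaningful if your critics start from pre-trained weights; n_freeze is a hyperparameter):

    # FreezeD sketch: freeze the lowest layers of a pre-trained critic and
    # only fine-tune the remaining ones. With trainable=False, those weights
    # drop out of critic_x.trainable_weights in your training step above.
    def freeze_lower_layers(critic, n_freeze=2):
        for layer in critic.layers[:n_freeze]:
            layer.trainable = False

    freeze_lower_layers(critic_x, n_freeze=2)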