I am trying to train a GAN on a dataset of images so that it generates similar images. I have worked through various tutorials on training GANs, but my generator still produces random output even though the losses converge towards zero during training. My generator and discriminator code is below:
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.losses import BinaryCrossentropy

# Discriminator: 128x128x3 image -> single real/fake probability
def make_discriminator_model():
    model = tf.keras.Sequential()
    model.add(layers.Input((128, 128, 3)))
    model.add(layers.Conv2D(32, (4, 4), strides=(1, 1), padding='same', activation=layers.LeakyReLU(alpha=0.2)))
    model.add(layers.BatchNormalization())
    model.add(layers.Dropout(0.4))
    model.add(layers.Conv2D(64, (4, 4), strides=(2, 2), padding='same', activation=layers.LeakyReLU(alpha=0.2)))
    model.add(layers.Dropout(0.4))
    model.add(layers.Conv2D(32, (4, 4), strides=(1, 1), padding='same', activation=layers.LeakyReLU(alpha=0.2)))
    model.add(layers.BatchNormalization())
    model.add(layers.Dropout(0.4))
    model.add(layers.Conv2D(128, (4, 4), strides=(2, 2), padding='same', activation=layers.LeakyReLU(alpha=0.2)))
    model.add(layers.Flatten())
    model.add(layers.Dropout(0.4))
    model.add(layers.Dense(1, activation='sigmoid'))
    return model

discriminator = make_discriminator_model()
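As a quick sanity check (a sketch I use on the side, not part of the training code), the discriminator should map a batch of images to one probability per image:

# Sketch: the discriminator maps (batch, 128, 128, 3) -> (batch, 1)
test_imgs = tf.random.normal((4, 128, 128, 3))
print(discriminator(test_imgs, training=False).shape)  # expected: (4, 1)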
def make_generator_model():
    model = tf.keras.Sequential()
    # Project the 512-dim latent vector to a 16x16x128 feature map
    model.add(layers.Dense(16*16*128, input_dim=512, activation=layers.LeakyReLU(alpha=0.2)))
    model.add(layers.BatchNormalization())
    model.add(layers.Reshape((16, 16, 128)))
    # Upsample 16x16 -> 32x32 -> 64x64 -> 128x128
    model.add(layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same', activation=layers.LeakyReLU(alpha=0.2)))
    model.add(layers.BatchNormalization())
    model.add(layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same', activation=layers.LeakyReLU(alpha=0.2)))
    model.add(layers.BatchNormalization())
    model.add(layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same', activation=layers.LeakyReLU(alpha=0.2)))
    # Final 3-channel output in [-1, 1] via tanh
    model.add(layers.Conv2D(3, (4, 4), strides=(1, 1), padding='same', activation='tanh'))
    return model

generator = make_generator_model()
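Similarly, the generator should map a batch of 512-dimensional latent vectors to 128x128x3 images (again just a sanity-check sketch):

# Sketch: the generator maps (batch, 512) latent vectors -> (batch, 128, 128, 3)
test_noise = tf.random.normal((4, 512))
print(generator(test_noise, training=False).shape)  # expected: (4, 128, 128, 3)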
My images are 128x128, with pixel values normalized to the range -1 to 1, so the generator's tanh output matches the images being fed into the discriminator.
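For completeness, the scaling step looks roughly like this (a sketch; dataset stands in for my actual tf.data pipeline of uint8 images):

# Sketch of the preprocessing: map uint8 [0, 255] -> float32 [-1, 1]
def scale_images(images):
    return (tf.cast(images, tf.float32) - 127.5) / 127.5

dataset = dataset.map(scale_images).shuffle(1000).batch(128)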
I train the model with a custom train_step, which I have included below inside my GAN class:
class GAN(tf.keras.models.Model):
    def __init__(self, generator, discriminator, *args, **kwargs):
        # Pass through args and kwargs to base class
        super().__init__(*args, **kwargs)
        # Create attributes for gen and disc
        self.generator = generator
        self.discriminator = discriminator

    def compile(self, g_opt, d_opt, g_loss, d_loss, *args, **kwargs):
        # Compile with base class
        super().compile(*args, **kwargs)
        # Create attributes for losses and optimizers
        self.g_opt = g_opt
        self.d_opt = d_opt
        self.g_loss = g_loss
        self.d_loss = d_loss
    def train_step(self, batch):
        # Get the data
        real_images = batch
        fake_images = self.generator(tf.random.normal((128, 512, 1)), training=False)

        # Train the discriminator
        with tf.GradientTape() as d_tape:
            # Pass the real and fake images to the discriminator model
            yhat_real = self.discriminator(real_images, training=True)
            yhat_fake = self.discriminator(fake_images, training=True)
            yhat_realfake = tf.concat([yhat_real, yhat_fake], axis=0)
            # Create labels for real and fake images
            y_realfake = tf.concat([tf.ones_like(yhat_real), tf.zeros_like(yhat_fake)], axis=0)
            # Add noise to the labels (label-smoothing trick from ganhacks)
            noise_real = -0.15 * tf.random.uniform(tf.shape(yhat_real))
            noise_fake = 0.15 * tf.random.uniform(tf.shape(yhat_fake))
            y_realfake += tf.concat([noise_real, noise_fake], axis=0)
            # Calculate loss - binary cross-entropy
            total_d_loss = self.d_loss(y_realfake, yhat_realfake)
        # Apply backpropagation
        dgrad = d_tape.gradient(total_d_loss, self.discriminator.trainable_variables)
        self.d_opt.apply_gradients(zip(dgrad, self.discriminator.trainable_variables))

        # Train the generator
        with tf.GradientTape() as g_tape:
            # Generate some new images
            gen_images = self.generator(tf.random.normal((128, 512, 1)), training=True)
            # Get the discriminator's predictions on them
            predicted_labels = self.discriminator(gen_images, training=False)
            # Calculate loss - reward the generator for fooling the discriminator
            total_g_loss = self.g_loss(tf.ones_like(predicted_labels), predicted_labels)
        # Apply backprop
        ggrad = g_tape.gradient(total_g_loss, self.generator.trainable_variables)
        self.g_opt.apply_gradients(zip(ggrad, self.generator.trainable_variables))

        return {"d_loss": total_d_loss, "g_loss": total_g_loss}
g_opt = tf.keras.optimizers.Adam(learning_rate=0.0001)
d_opt = tf.keras.optimizers.Adam(learning_rate=0.00001)
g_loss = BinaryCrossentropy()
d_loss = BinaryCrossentropy()
gan = GAN(generator, discriminator)
gan.compile(g_opt, d_opt, g_loss, d_loss)
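Training is then launched with fit in the usual way (a sketch; dataset is the batched pipeline from above):

# Sketch of how training is launched
hist = gan.fit(dataset, epochs=300)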
I followed various tips and tricks from the following GitHub repository:
https://github.com/soumith/ganhacks
I took and adapted the subclassing approach from the following repository after watching the accompanying tutorial:
https://github.com/nicknochnack/GANBasics/blob/main/FashionGAN-Tutorial.ipynb
I have trained the GAN for up to 300 epochs and the generator still outputs random images, so I am not sure whether there is an issue with my loss calculation in train_step. The generator's output values also appear to be quite small and nowhere close to the real images, so I wonder whether my generator layers are doing something I am unaware of.
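To illustrate what I mean by "quite small", this is roughly how I inspect a generated sample (a sketch):

# Sketch: inspect the range of generated pixel values
sample = generator(tf.random.normal((1, 512)), training=False)
print("min:", tf.reduce_min(sample).numpy(), "max:", tf.reduce_max(sample).numpy())
# real images span roughly [-1, 1]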
I did separately train the discriminator on real and random images, and on its own (without any generator loss to consider) it trains and makes sensible predictions. During GAN training, the discriminator loss drops and then flattens out at a low value.
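That standalone check looked roughly like this (a sketch of what I ran; real_batch is a placeholder for one batch of real training images):

# Sketch of the standalone discriminator sanity check
disc = make_discriminator_model()
disc.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

real_batch = next(iter(dataset))                            # placeholder: one batch of real images
noise_batch = tf.random.uniform(real_batch.shape, -1., 1.)  # random "images" in [-1, 1]

x = tf.concat([real_batch, noise_batch], axis=0)
y = tf.concat([tf.ones((len(real_batch), 1)), tf.zeros((len(noise_batch), 1))], axis=0)
disc.fit(x, y, epochs=5)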
My question is: why is my generator not training?
- Could it be the loss function total_g_loss?
- Are the hidden layers of the generator too complex or too simple for image generation?
Any other tips to debug this would be greatly appreciated!