I modified the WGAN-GP from the keras website (https://keras.io/examples/generative/wgan_gp/) for what I needed: my network takes 3 consecutive images as input. The generator takes the one in the middle and tries to generate a similar image, the discriminator takes all 3 images and evaluates if the second is really the central one among the 3. However it gives me this error when I call fit():
NotImplementedError: When subclassing the `Model' class, you should implement a call method.
I am using tensorflow 2.3
Model:
import tensorflow as tf
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.models import Model
from tensorflow.keras import Input
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import Conv2DTranspose
from tensorflow.keras.layers import LeakyReLU, MaxPooling2D
from tensorflow.keras.layers import Activation, Flatten, Dense
from tensorflow.keras.layers import Concatenate
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.layers import concatenate
from tensorflow.keras.layers import LayerNormalization
from tensorflow.keras.utils import plot_model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.optimizers import SGD
import tensorflow.keras.backend as K
from tensorflow.keras.layers import GaussianNoise
from tensorflow import keras
import numpy as np
init = RandomNormal(mean=0.0, stddev=0.02)
GNOISE = 0.25
BATCH_SIZE = 32
def encodeImage(image):
layers = []
layer1 = Conv2D(32, 3, kernel_initializer=init, padding='same')(image)
layer1 = BatchNormalization()(layer1)
layer1 = LeakyReLU(alpha=0.01)(layer1)
layer1 = Conv2D(32, 3, kernel_initializer=init, padding='same')(layer1)
layer1 = BatchNormalization()(layer1)
layer1 = LeakyReLU(alpha=0.01)(layer1)
layer1 = GaussianNoise(GNOISE)(layer1)
pool1 = MaxPooling2D(pool_size=(2, 2))(layer1)
layers.append(layer1)
print(pool1.shape)
layer2 = Conv2D(64, 3, kernel_initializer=init, padding='same')(pool1)
layer2 = BatchNormalization()(layer2)
layer2 = LeakyReLU(alpha=0.01)(layer2)
layer2 = Conv2D(64, 3, kernel_initializer=init, padding='same')(layer2)
layer2 = BatchNormalization()(layer2)
layer2 = LeakyReLU(alpha=0.01)(layer2)
layer2 = GaussianNoise(GNOISE)(layer2)
pool2 = MaxPooling2D(pool_size=(2, 2))(layer2)
layers.append(layer2)
print(pool2.shape)
layer3 = Conv2D(64, 3, kernel_initializer=init, padding='same')(pool2)
layer3 = BatchNormalization(momentum=0.8)(layer3)
layer3 = LeakyReLU(alpha=0.01)(layer3)
layer3 = Conv2D(64, 3, kernel_initializer=init, padding='same')(layer3)
layer3 = BatchNormalization(momentum=0.8)(layer3)
layer3 = LeakyReLU(alpha=0.01)(layer3)
layer3 = GaussianNoise(GNOISE)(layer3)
pool3 = MaxPooling2D(pool_size=(2, 2))(layer3)
layers.append(layer3)
print(pool3.shape)
layer4 = Conv2D(128, 3, kernel_initializer=init, padding='same')(pool3)
layer4 = BatchNormalization(momentum=0.8)(layer4)
layer4 = LeakyReLU(alpha=0.01)(layer4)
layer4 = Conv2D(128, 3, kernel_initializer=init, padding='same')(layer4)
layer4 = BatchNormalization(momentum=0.8)(layer4)
layer4 = LeakyReLU(alpha=0.01)(layer4)
layer4 = GaussianNoise(GNOISE)(layer4)
pool4 = MaxPooling2D(pool_size=(2, 2))(layer4)
layers.append(layer4)
print(pool4.shape)
return pool4, layers
def decodeImage(image, layersA):
layer1 = Conv2DTranspose(128, 3, 2, kernel_initializer=init, padding='same')(image)
layer1 = Concatenate()([layer1, layersA[3]])
layer1 = Conv2D(128, 3, kernel_initializer=init, padding='same')(layer1)
layer1 = BatchNormalization(momentum=0.8)(layer1)
layer1 = LeakyReLU(alpha=0.01)(layer1)
layer1 = Dropout(0.2)(layer1, training=True)
layer1 = Conv2D(128, 3, kernel_initializer=init, padding='same')(layer1)
layer1 = BatchNormalization(momentum=0.8)(layer1)
layer1 = LeakyReLU(alpha=0.01)(layer1)
layer1 = GaussianNoise(GNOISE)(layer1)
layer2 = Conv2DTranspose(64, 3, 2, kernel_initializer=init, padding='same')(layer1)
layer2 = Concatenate()([layer2, layersA[2]])
layer2 = Conv2D(64, 3, kernel_initializer=init, padding='same')(layer2)
layer2 = BatchNormalization(momentum=0.8)(layer2)
layer2 = LeakyReLU(alpha=0.01)(layer2)
layer2 = Dropout(0.2)(layer2, training=True)
layer2 = Conv2D(64, 3, kernel_initializer=init, padding='same')(layer2)
layer2 = BatchNormalization(momentum=0.8)(layer2)
layer2 = LeakyReLU(alpha=0.01)(layer2)
layer2 = GaussianNoise(GNOISE)(layer2)
layer3 = Conv2DTranspose(64, 3, 2, kernel_initializer=init, padding='same')(layer2)
layer3 = Concatenate()([layer3, layersA[1]])
layer3 = Conv2D(64, 3, kernel_initializer=init, padding='same')(layer3)
layer3 = BatchNormalization(momentum=0.8)(layer3)
layer3 = LeakyReLU(alpha=0.01)(layer3)
layer3 = Dropout(0.2)(layer3, training=True)
layer3 = Conv2D(64, 3, kernel_initializer=init, padding='same')(layer3)
layer3 = BatchNormalization(momentum=0.8)(layer3)
layer3 = LeakyReLU(alpha=0.01)(layer3)
layer3 = GaussianNoise(GNOISE)(layer3)
layer4 = Conv2DTranspose(32, 3, 2, kernel_initializer=init, padding='same')(layer3)
layer4 = Concatenate()([layer4, layersA[0]])
layer4 = Conv2D(32, 3, kernel_initializer=init, padding='same')(layer4)
layer4 = BatchNormalization(momentum=0.8)(layer4)
layer4 = LeakyReLU(alpha=0.01)(layer4)
layer4 = Dropout(0.2)(layer4, training=True)
layer4 = Conv2D(32, 3, kernel_initializer=init, padding='same')(layer4)
layer4 = BatchNormalization(momentum=0.8)(layer4)
layer4 = LeakyReLU(alpha=0.01)(layer4)
layer4 = GaussianNoise(GNOISE)(layer4)
outImage = Conv2D(1, 1, kernel_initializer=init, padding='same')(layer4)
outImage = Activation('tanh')(outImage)
return outImage
def defineGenerator(inputShape):
realImage = Input(shape=inputShape)
encodedA, layersA = encodeImage(realImage)
combined = Conv2D(8, 3, kernel_initializer=init, input_shape=inputShape, padding='same')(encodedA)
combined = BatchNormalization(momentum=0.8)(combined)
combined = LeakyReLU(alpha=0.01)(combined)
combined = Conv2D(8, 3, kernel_initializer=init, input_shape=inputShape, padding='same')(combined)
combined = BatchNormalization(momentum=0.8)(combined)
combined = LeakyReLU(alpha=0.01)(combined)
outImage = decodeImage(combined, layersA)
model = Model(realImage, outImage)
return model
def defineDiscriminator(imageShape):
imageA = Input(shape=imageShape)
imageB = Input(shape=imageShape)
imageC = Input(shape=imageShape)
merged = Concatenate()([imageA, imageB, imageC])
layer1 = Conv2D(32, 5, 2, padding='same', kernel_initializer=init)(merged)
#layer1 = BatchNormalization(momentum=0.8)(layer1)
layer1 = LayerNormalization()(layer1)
layer1 = LeakyReLU(alpha=0.01)(layer1)
layer1 = Dropout(0.2)(layer1, training=True)
layer1 = Conv2D(32, 5, padding='same', kernel_initializer=init)(layer1)
#layer1 = BatchNormalization(momentum=0.8)(layer1)
layer1 = LayerNormalization()(layer1)
layer1 = LeakyReLU(alpha=0.01)(layer1)
layer1 = GaussianNoise(GNOISE)(layer1)
layer2 = Conv2D(64, 5, 2, padding='same', kernel_initializer=init)(layer1)
layer2 = LeakyReLU(alpha=0.01)(layer2)
#layer2 = BatchNormalization(momentum=0.8)(layer2)
layer2 = LayerNormalization()(layer2)
layer2 = Dropout(0.2)(layer2, training=True)
layer2 = Conv2D(64, 5, padding='same', kernel_initializer=init)(layer2)
layer2 = LeakyReLU(alpha=0.01)(layer2)
#layer2 = BatchNormalization(momentum=0.8)(layer2)
layer2 = LayerNormalization()(layer2)
layer2 = GaussianNoise(GNOISE)(layer2)
layer3 = Conv2D(64, 5, 2, padding='same', kernel_initializer=init)(layer2)
layer3 = LeakyReLU(alpha=0.01)(layer3)
#layer3 = BatchNormalization(momentum=0.8)(layer3)
layer3 = LayerNormalization()(layer3)
layer3 = Dropout(0.3)(layer3, training=True)
layer3 = Conv2D(64, 5, padding='same', kernel_initializer=init)(layer3)
layer3 = LeakyReLU(alpha=0.01)(layer3)
#layer3 = BatchNormalization(momentum=0.8)(layer3)
layer3 = LayerNormalization()(layer3)
layer3 = GaussianNoise(GNOISE)(layer3)
layer4 = Conv2D(128, 5, 2, padding='same', kernel_initializer=init)(layer3)
layer4 = LeakyReLU(alpha=0.01)(layer4)
#layer4 = BatchNormalization(momentum=0.8)(layer4)
layer4 = LayerNormalization()(layer4)
layer4 = Dropout(0.3)(layer4, training=True)
layer4 = Conv2D(128, 5, 2, padding='same', kernel_initializer=init)(layer4)
layer4 = LeakyReLU(alpha=0.01)(layer4)
#layer4 = BatchNormalization(momentum=0.8)(layer4)
layer4 = LayerNormalization()(layer4)
layer4 = GaussianNoise(GNOISE)(layer4)
output = Flatten()(layer4)
output = Dense(1, activation=None)(output)
model = Model([imageA, imageB, imageC], output)
return model
class WGAN(Model):
def __init__(
self,
discriminator,
generator,
discriminator_extra_steps=3,
gp_weight=10.0,
):
super(WGAN, self).__init__()
self.discriminator = discriminator
self.generator = generator
self.d_steps = discriminator_extra_steps
self.gp_weight = gp_weight
def compile(self, d_optimizer, g_optimizer, d_loss_fn, g_loss_fn):
super(WGAN, self).compile()
self.d_optimizer = d_optimizer
self.g_optimizer = g_optimizer
self.d_loss_fn = d_loss_fn
self.g_loss_fn = g_loss_fn
def gradient_penalty(self, batch_size, imgA, imgB, imgC, fakeImg):
""" Calculates the gradient penalty.
This loss is calculated on an interpolated image
and added to the discriminator loss.
"""
# get the interplated image
alpha = tf.random.normal([batch_size, 1, 1, 1], 0.0, 1.0)
diff = fakeImg - imgB
interpolated = imgB + alpha * diff
with tf.GradientTape() as gp_tape:
gp_tape.watch(interpolated)
# 1. Get the discriminator output for this interpolated image.
pred = self.discriminator([imgA, interpolated, imgC], training=True)
# 2. Calculate the gradients w.r.t to this interpolated image.
grads = gp_tape.gradient(pred, [interpolated])[0]
# 3. Calcuate the norm of the gradients
norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
gp = tf.reduce_mean((norm - 1.0) ** 2)
return gp
@tf.function
def train_step(self, real_images):
#if isinstance(real_images, tuple):
# real_images = real_images[0]
# Get the batch size
batch_size = tf.shape(real_images)[0]
firstImgs = real_images[0:batch_size-2:3]
secondImgs = real_images[1:batch_size-1:3]
thirdImgs = real_images[2:batch_size:3]
# For each batch, we are going to perform the
# following steps as laid out in the original paper.
# 1. Train the generator and get the generator loss
# 2. Train the discriminator and get the discriminator loss
# 3. Calculate the gradient penalty
# 4. Multiply this gradient penalty with a constant weight factor
# 5. Add gradient penalty to the discriminator loss
# 6. Return generator and discriminator losses as a loss dictionary.
# Train discriminator first. The original paper recommends training
# the discriminator for `x` more steps (typically 5) as compared to
# one step of the generator. Here we will train it for 3 extra steps
# as compared to 5 to reduce the training time.
for i in range(self.d_steps):
with tf.GradientTape() as tape:
# Generate fake images from the latent vector
fake_images = self.generator(secondImgs, training=True)
# Get the logits for the fake images
fake_logits = self.discriminator([firstImgs, fake_images, thirdImgs], training=True)
# Get the logits for real images
real_logits = self.discriminator([firstImgs, real_images, thirdImgs], training=True)
# Calculate discriminator loss using fake and real logits
d_cost = self.d_loss_fn(real_img=real_logits, fake_img=fake_logits)
# Calculate the gradient penalty
gp = self.gradient_penalty(batch_size, firstImgs, secondImgs, thirdImgs, fake_images)
# Add the gradient penalty to the original discriminator loss
d_loss = d_cost + gp * self.gp_weight
# Get the gradients w.r.t the discriminator loss
d_gradient = tape.gradient(d_loss, self.discriminator.trainable_variables)
# Update the weights of the discriminator using the discriminator optimizer
self.d_optimizer.apply_gradients(
zip(d_gradient, self.discriminator.trainable_variables)
)
# Train the generator now.
with tf.GradientTape() as tape:
# Generate fake images using the generator
generated_images = self.generator(secondImgs, training=True)
# Get the discriminator logits for fake images
gen_img_logits = self.discriminator([firstImgs, generated_images, thirdImgs], training=True)
# Calculate the generator loss
g_loss = self.g_loss_fn(gen_img_logits)
# Get the gradients w.r.t the generator loss
gen_gradient = tape.gradient(g_loss, self.generator.trainable_variables)
# Update the weights of the generator using the generator optimizer
self.g_optimizer.apply_gradients(
zip(gen_gradient, self.generator.trainable_variables)
)
return {"d_loss": d_loss, "g_loss": g_loss}
Training:
def imgGenerator(dataset, batchSize, count):
n = 0
while n < count:
randomFolder = np.random.randint(0, len(dataset))
randomIndex = np.random.randint(0, len(dataset[randomFolder]) - batchSize)
n += 1
yield dataset[randomFolder][randomIndex:randomIndex+batchSize]
dataset = loadDataset()
gen = imgGenerator(dataset, 48, 1000)
print("Done")
# learning_rate=0.0002, beta_1=0.5 are recommened
generator_optimizer = keras.optimizers.Adam(
learning_rate=0.0002, beta_1=0.5, beta_2=0.9
)
discriminator_optimizer = keras.optimizers.Adam(
learning_rate=0.0002, beta_1=0.5, beta_2=0.9
)
# Define the loss functions to be used for discrimiator
# This should be (fake_loss - real_loss)
# We will add the gradient penalty later to this loss function
def discriminator_loss(real_img, fake_img):
real_loss = tf.reduce_mean(real_img)
fake_loss = tf.reduce_mean(fake_img)
return fake_loss - real_loss
# Define the loss functions to be used for generator
def generator_loss(fake_img):
return -tf.reduce_mean(fake_img)
# Epochs to train
epochs = 15
imageShape = (256, 256, 1)
g_model = defineGenerator(imageShape)
d_model = defineDiscriminator(imageShape)
# Get the wgan model
wgan = WGAN(
discriminator=d_model,
generator=g_model,
discriminator_extra_steps=3,
)
# Compile the wgan model
wgan.compile(
d_optimizer=discriminator_optimizer,
g_optimizer=generator_optimizer,
g_loss_fn=generator_loss,
d_loss_fn=discriminator_loss,
)
# Start training
wgan.fit(gen, epochs=epochs)
g_model.save_weights("wgenerator_256x256_keras.h5")
d_model.save_weights("wdiscriminator_256x256_keras_2.h5")