Keras - NotImplementedError: When subclassing the `Model` class, you should implement a `call` method

818 views Asked by At

I modified the WGAN-GP from the keras website (https://keras.io/examples/generative/wgan_gp/) for what I needed: my network takes 3 consecutive images as input. The generator takes the one in the middle and tries to generate a similar image, the discriminator takes all 3 images and evaluates if the second is really the central one among the 3. However it gives me this error when I call fit():

NotImplementedError: When subclassing the `Model' class, you should implement a call method.

I am using tensorflow 2.3

Model:

import tensorflow as tf
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.models import Model
from tensorflow.keras import Input
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import Conv2DTranspose
from tensorflow.keras.layers import LeakyReLU, MaxPooling2D
from tensorflow.keras.layers import Activation, Flatten, Dense
from tensorflow.keras.layers import Concatenate
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.layers import concatenate
from tensorflow.keras.layers import LayerNormalization
from tensorflow.keras.utils import plot_model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.optimizers import SGD
import tensorflow.keras.backend as K
from tensorflow.keras.layers import GaussianNoise
from tensorflow import keras
import numpy as np

init = RandomNormal(mean=0.0, stddev=0.02)

GNOISE = 0.25
BATCH_SIZE = 32

def encodeImage(image):
    layers = []

    layer1 = Conv2D(32, 3, kernel_initializer=init, padding='same')(image)
    layer1 = BatchNormalization()(layer1)
    layer1 = LeakyReLU(alpha=0.01)(layer1)
    layer1 = Conv2D(32, 3, kernel_initializer=init, padding='same')(layer1)
    layer1 = BatchNormalization()(layer1)
    layer1 = LeakyReLU(alpha=0.01)(layer1)
    layer1 = GaussianNoise(GNOISE)(layer1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(layer1)
    layers.append(layer1)
    print(pool1.shape)

    layer2 = Conv2D(64, 3, kernel_initializer=init, padding='same')(pool1)
    layer2 = BatchNormalization()(layer2)
    layer2 = LeakyReLU(alpha=0.01)(layer2)
    layer2 = Conv2D(64, 3, kernel_initializer=init, padding='same')(layer2)
    layer2 = BatchNormalization()(layer2)
    layer2 = LeakyReLU(alpha=0.01)(layer2)
    layer2 = GaussianNoise(GNOISE)(layer2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(layer2)
    layers.append(layer2)
    print(pool2.shape)

    layer3 = Conv2D(64, 3, kernel_initializer=init, padding='same')(pool2)
    layer3 = BatchNormalization(momentum=0.8)(layer3)
    layer3 = LeakyReLU(alpha=0.01)(layer3)
    layer3 = Conv2D(64, 3, kernel_initializer=init, padding='same')(layer3)
    layer3 = BatchNormalization(momentum=0.8)(layer3)
    layer3 = LeakyReLU(alpha=0.01)(layer3)
    layer3 = GaussianNoise(GNOISE)(layer3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(layer3)
    layers.append(layer3)
    print(pool3.shape)

    layer4 = Conv2D(128, 3, kernel_initializer=init, padding='same')(pool3)
    layer4 = BatchNormalization(momentum=0.8)(layer4)
    layer4 = LeakyReLU(alpha=0.01)(layer4)
    layer4 = Conv2D(128, 3, kernel_initializer=init, padding='same')(layer4)
    layer4 = BatchNormalization(momentum=0.8)(layer4)
    layer4 = LeakyReLU(alpha=0.01)(layer4)
    layer4 = GaussianNoise(GNOISE)(layer4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(layer4)
    layers.append(layer4)
    print(pool4.shape)

    return pool4, layers


def decodeImage(image, layersA):
    layer1 = Conv2DTranspose(128, 3, 2, kernel_initializer=init, padding='same')(image)
    layer1 = Concatenate()([layer1, layersA[3]])
    layer1 = Conv2D(128, 3, kernel_initializer=init, padding='same')(layer1)
    layer1 = BatchNormalization(momentum=0.8)(layer1)
    layer1 = LeakyReLU(alpha=0.01)(layer1)
    layer1 = Dropout(0.2)(layer1, training=True)
    layer1 = Conv2D(128, 3, kernel_initializer=init, padding='same')(layer1)
    layer1 = BatchNormalization(momentum=0.8)(layer1)
    layer1 = LeakyReLU(alpha=0.01)(layer1)
    layer1 = GaussianNoise(GNOISE)(layer1)

    layer2 = Conv2DTranspose(64, 3, 2, kernel_initializer=init, padding='same')(layer1)
    layer2 = Concatenate()([layer2, layersA[2]])
    layer2 = Conv2D(64, 3, kernel_initializer=init, padding='same')(layer2)
    layer2 = BatchNormalization(momentum=0.8)(layer2)
    layer2 = LeakyReLU(alpha=0.01)(layer2)
    layer2 = Dropout(0.2)(layer2, training=True)
    layer2 = Conv2D(64, 3, kernel_initializer=init, padding='same')(layer2)
    layer2 = BatchNormalization(momentum=0.8)(layer2)
    layer2 = LeakyReLU(alpha=0.01)(layer2)
    layer2 = GaussianNoise(GNOISE)(layer2)


    layer3 = Conv2DTranspose(64, 3, 2, kernel_initializer=init, padding='same')(layer2)
    layer3 = Concatenate()([layer3, layersA[1]])
    layer3 = Conv2D(64, 3, kernel_initializer=init, padding='same')(layer3)
    layer3 = BatchNormalization(momentum=0.8)(layer3)
    layer3 = LeakyReLU(alpha=0.01)(layer3)
    layer3 = Dropout(0.2)(layer3, training=True)
    layer3 = Conv2D(64, 3, kernel_initializer=init, padding='same')(layer3)
    layer3 = BatchNormalization(momentum=0.8)(layer3)
    layer3 = LeakyReLU(alpha=0.01)(layer3)
    layer3 = GaussianNoise(GNOISE)(layer3)


    layer4 = Conv2DTranspose(32, 3, 2, kernel_initializer=init, padding='same')(layer3)
    layer4 = Concatenate()([layer4, layersA[0]])
    layer4 = Conv2D(32, 3, kernel_initializer=init, padding='same')(layer4)
    layer4 = BatchNormalization(momentum=0.8)(layer4)
    layer4 = LeakyReLU(alpha=0.01)(layer4)
    layer4 = Dropout(0.2)(layer4, training=True)
    layer4 = Conv2D(32, 3, kernel_initializer=init, padding='same')(layer4)
    layer4 = BatchNormalization(momentum=0.8)(layer4)
    layer4 = LeakyReLU(alpha=0.01)(layer4)
    layer4 = GaussianNoise(GNOISE)(layer4)

    outImage = Conv2D(1, 1, kernel_initializer=init, padding='same')(layer4)
    outImage = Activation('tanh')(outImage)

    return outImage


def defineGenerator(inputShape):
    realImage = Input(shape=inputShape)

    encodedA, layersA = encodeImage(realImage)

    combined = Conv2D(8, 3, kernel_initializer=init, input_shape=inputShape, padding='same')(encodedA)
    combined = BatchNormalization(momentum=0.8)(combined)
    combined = LeakyReLU(alpha=0.01)(combined)
    combined = Conv2D(8, 3, kernel_initializer=init, input_shape=inputShape, padding='same')(combined)
    combined = BatchNormalization(momentum=0.8)(combined)
    combined = LeakyReLU(alpha=0.01)(combined)

    outImage = decodeImage(combined, layersA)

    model = Model(realImage, outImage)

    return model



def defineDiscriminator(imageShape):

    imageA = Input(shape=imageShape)
    imageB = Input(shape=imageShape)
    imageC = Input(shape=imageShape)

    merged = Concatenate()([imageA, imageB, imageC])

    layer1 = Conv2D(32, 5, 2, padding='same', kernel_initializer=init)(merged)
    #layer1 = BatchNormalization(momentum=0.8)(layer1)
    layer1 = LayerNormalization()(layer1)
    layer1 = LeakyReLU(alpha=0.01)(layer1)
    layer1 = Dropout(0.2)(layer1, training=True)
    layer1 = Conv2D(32, 5, padding='same', kernel_initializer=init)(layer1)
    #layer1 = BatchNormalization(momentum=0.8)(layer1)
    layer1 = LayerNormalization()(layer1)
    layer1 = LeakyReLU(alpha=0.01)(layer1)
    layer1 = GaussianNoise(GNOISE)(layer1)

    layer2 = Conv2D(64, 5, 2, padding='same', kernel_initializer=init)(layer1)
    layer2 = LeakyReLU(alpha=0.01)(layer2)
    #layer2 = BatchNormalization(momentum=0.8)(layer2)
    layer2 = LayerNormalization()(layer2)
    layer2 = Dropout(0.2)(layer2, training=True)
    layer2 = Conv2D(64, 5, padding='same', kernel_initializer=init)(layer2)
    layer2 = LeakyReLU(alpha=0.01)(layer2)
    #layer2 = BatchNormalization(momentum=0.8)(layer2)
    layer2 = LayerNormalization()(layer2)
    layer2 = GaussianNoise(GNOISE)(layer2)

    layer3 = Conv2D(64, 5, 2, padding='same', kernel_initializer=init)(layer2)
    layer3 = LeakyReLU(alpha=0.01)(layer3)
    #layer3 = BatchNormalization(momentum=0.8)(layer3)
    layer3 = LayerNormalization()(layer3)
    layer3 = Dropout(0.3)(layer3, training=True)
    layer3 = Conv2D(64, 5, padding='same', kernel_initializer=init)(layer3)
    layer3 = LeakyReLU(alpha=0.01)(layer3)
    #layer3 = BatchNormalization(momentum=0.8)(layer3)
    layer3 = LayerNormalization()(layer3)
    layer3 = GaussianNoise(GNOISE)(layer3)

    layer4 = Conv2D(128, 5, 2, padding='same', kernel_initializer=init)(layer3)
    layer4 = LeakyReLU(alpha=0.01)(layer4)
    #layer4 = BatchNormalization(momentum=0.8)(layer4)
    layer4 = LayerNormalization()(layer4)
    layer4 = Dropout(0.3)(layer4, training=True)
    layer4 = Conv2D(128, 5, 2, padding='same', kernel_initializer=init)(layer4)
    layer4 = LeakyReLU(alpha=0.01)(layer4)
    #layer4 = BatchNormalization(momentum=0.8)(layer4)
    layer4 = LayerNormalization()(layer4)
    layer4 = GaussianNoise(GNOISE)(layer4)

    output = Flatten()(layer4)

    output = Dense(1, activation=None)(output)

    model = Model([imageA, imageB, imageC], output)
    
    return model

class WGAN(Model):
    def __init__(
        self,
        discriminator,
        generator,
        discriminator_extra_steps=3,
        gp_weight=10.0,
    ):
        super(WGAN, self).__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.d_steps = discriminator_extra_steps
        self.gp_weight = gp_weight

    def compile(self, d_optimizer, g_optimizer, d_loss_fn, g_loss_fn):
        super(WGAN, self).compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.d_loss_fn = d_loss_fn
        self.g_loss_fn = g_loss_fn

    def gradient_penalty(self, batch_size, imgA, imgB, imgC, fakeImg):
        """ Calculates the gradient penalty.

        This loss is calculated on an interpolated image
        and added to the discriminator loss.
        """
        # get the interplated image
        alpha = tf.random.normal([batch_size, 1, 1, 1], 0.0, 1.0)
        diff = fakeImg - imgB
        interpolated = imgB + alpha * diff

        with tf.GradientTape() as gp_tape:
            gp_tape.watch(interpolated)
            # 1. Get the discriminator output for this interpolated image.
            pred = self.discriminator([imgA, interpolated, imgC], training=True)

        # 2. Calculate the gradients w.r.t to this interpolated image.
        grads = gp_tape.gradient(pred, [interpolated])[0]
        # 3. Calcuate the norm of the gradients
        norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
        gp = tf.reduce_mean((norm - 1.0) ** 2)
        return gp

    @tf.function
    def train_step(self, real_images):
        #if isinstance(real_images, tuple):
        #    real_images = real_images[0]

        # Get the batch size
        batch_size = tf.shape(real_images)[0]

        firstImgs = real_images[0:batch_size-2:3]
        secondImgs = real_images[1:batch_size-1:3]
        thirdImgs = real_images[2:batch_size:3]

        # For each batch, we are going to perform the
        # following steps as laid out in the original paper.
        # 1. Train the generator and get the generator loss
        # 2. Train the discriminator and get the discriminator loss
        # 3. Calculate the gradient penalty
        # 4. Multiply this gradient penalty with a constant weight factor
        # 5. Add gradient penalty to the discriminator loss
        # 6. Return generator and discriminator losses as a loss dictionary.

        # Train discriminator first. The original paper recommends training
        # the discriminator for `x` more steps (typically 5) as compared to
        # one step of the generator. Here we will train it for 3 extra steps
        # as compared to 5 to reduce the training time.
        for i in range(self.d_steps):
            with tf.GradientTape() as tape:
                # Generate fake images from the latent vector
                fake_images = self.generator(secondImgs, training=True)
                # Get the logits for the fake images
                fake_logits = self.discriminator([firstImgs, fake_images, thirdImgs], training=True)
                # Get the logits for real images
                real_logits = self.discriminator([firstImgs, real_images, thirdImgs], training=True)

                # Calculate discriminator loss using fake and real logits
                d_cost = self.d_loss_fn(real_img=real_logits, fake_img=fake_logits)
                # Calculate the gradient penalty
                gp = self.gradient_penalty(batch_size, firstImgs, secondImgs, thirdImgs, fake_images)
                # Add the gradient penalty to the original discriminator loss
                d_loss = d_cost + gp * self.gp_weight

            # Get the gradients w.r.t the discriminator loss
            d_gradient = tape.gradient(d_loss, self.discriminator.trainable_variables)
            # Update the weights of the discriminator using the discriminator optimizer
            self.d_optimizer.apply_gradients(
                zip(d_gradient, self.discriminator.trainable_variables)
            )

        # Train the generator now.
        with tf.GradientTape() as tape:
            # Generate fake images using the generator
            generated_images = self.generator(secondImgs, training=True)
            # Get the discriminator logits for fake images
            gen_img_logits = self.discriminator([firstImgs, generated_images, thirdImgs], training=True)
            # Calculate the generator loss
            g_loss = self.g_loss_fn(gen_img_logits)

        # Get the gradients w.r.t the generator loss
        gen_gradient = tape.gradient(g_loss, self.generator.trainable_variables)
        # Update the weights of the generator using the generator optimizer
        self.g_optimizer.apply_gradients(
            zip(gen_gradient, self.generator.trainable_variables)
        )
        return {"d_loss": d_loss, "g_loss": g_loss}

Training:

def imgGenerator(dataset, batchSize, count):
    n = 0
    while n < count:
        randomFolder = np.random.randint(0, len(dataset))
        randomIndex = np.random.randint(0, len(dataset[randomFolder]) - batchSize)
        n += 1
        yield dataset[randomFolder][randomIndex:randomIndex+batchSize]

dataset = loadDataset()
gen = imgGenerator(dataset, 48, 1000)
print("Done")

# learning_rate=0.0002, beta_1=0.5 are recommened
generator_optimizer = keras.optimizers.Adam(
    learning_rate=0.0002, beta_1=0.5, beta_2=0.9
)
discriminator_optimizer = keras.optimizers.Adam(
    learning_rate=0.0002, beta_1=0.5, beta_2=0.9
)

# Define the loss functions to be used for discrimiator
# This should be (fake_loss - real_loss)
# We will add the gradient penalty later to this loss function
def discriminator_loss(real_img, fake_img):
    real_loss = tf.reduce_mean(real_img)
    fake_loss = tf.reduce_mean(fake_img)
    return fake_loss - real_loss


# Define the loss functions to be used for generator
def generator_loss(fake_img):
    return -tf.reduce_mean(fake_img)

# Epochs to train
epochs = 15

imageShape = (256, 256, 1)
g_model = defineGenerator(imageShape)
d_model = defineDiscriminator(imageShape)

# Get the wgan model
wgan = WGAN(
    discriminator=d_model,
    generator=g_model,
    discriminator_extra_steps=3,
)

# Compile the wgan model
wgan.compile(
    d_optimizer=discriminator_optimizer,
    g_optimizer=generator_optimizer,
    g_loss_fn=generator_loss,
    d_loss_fn=discriminator_loss,
)

# Start training
wgan.fit(gen, epochs=epochs)
g_model.save_weights("wgenerator_256x256_keras.h5")
d_model.save_weights("wdiscriminator_256x256_keras_2.h5")
0

There are 0 answers