Training ESRGAN: same tape for both generator and discriminator?


I am trying to reproduce the training of ESRGAN (or Real-ESRGAN: https://github.com/xinntao/ESRGAN) with a simplified version of the original code, which is quite complex. I need to train it on my own dataset, a biological one. However:

  • I have seen several codebases where a single tf.GradientTape(persistent=True) (called tape here) is shared by both the discriminator and the generator;
  • in general, however, it seems that they must have different tapes (see the two-tape sketch at the end of this post).

So I do not know whether ESRGAN is a special case in this respect.
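
As far as I understand, persistent=True only means the tape can be queried more than once before it is released. A tiny standalone sketch of that behaviour (my own toy example, unrelated to the ESRGAN code):

import tensorflow as tf

x = tf.Variable(3.0)
with tf.GradientTape(persistent=True) as tape:
    y = x * x  # y = x^2
    z = y * y  # z = x^4

# A non-persistent tape would raise an error on the second gradient() call.
dy_dx = tape.gradient(y, x)  # 2x   = 6.0
dz_dx = tape.gradient(z, x)  # 4x^3 = 108.0
del tape  # a persistent tape should be released explicitly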

I would like to know whether the following code is adequate or completely wrong.

In case it helps, here is the link to the full repository on GitHub: https://github.com/SalomePx/ESRGAN2

import tensorflow as tf

def train_step(lr, hr):
    # A single persistent tape records the forward pass of both networks,
    # so it can be queried once per model below.
    with tf.GradientTape(persistent=True) as tape:
        sr = generator(lr, training=True)             # super-resolved images
        hr_output = discriminator(hr, training=True)  # D scores on real images
        sr_output = discriminator(sr, training=True)  # D scores on generated images

        # Discriminator losses: regularization + adversarial loss
        losses_D = {}
        losses_D['reg'] = tf.reduce_sum(discriminator.losses)
        losses_D['gan'] = dis_loss_fn(hr_output, sr_output)

        # Generator losses: regularization, pixel, perceptual (feature), adversarial
        losses_G = {}
        losses_G['reg'] = tf.reduce_sum(generator.losses)
        losses_G['pixel'] = 1e-2 * pixel_loss_fn(hr, sr)
        losses_G['feature'] = 1.0 * fea_loss_fn(hr, sr)
        losses_G['gan'] = 5e-3 * gen_loss_fn(hr_output, sr_output)

        total_loss_G = tf.add_n(list(losses_G.values()))
        total_loss_D = tf.add_n(list(losses_D.values()))

    # persistent=True allows calling tape.gradient() twice on the same tape.
    grads_G = tape.gradient(total_loss_G, generator.trainable_variables)
    grads_D = tape.gradient(total_loss_D, discriminator.trainable_variables)
    del tape  # free the tape's resources once both gradients have been taken

    optimizer_G.apply_gradients(zip(grads_G, generator.trainable_variables))
    optimizer_D.apply_gradients(zip(grads_D, discriminator.trainable_variables))

    return total_loss_G, total_loss_D, losses_G, losses_D
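
For comparison, here is the two-tape variant I mentioned above, rewritten as a minimal sketch. It reuses the same generator, discriminator, optimizers, and loss functions as my code; only the tape handling changes:

def train_step_two_tapes(lr, hr):
    # Both tapes record the same forward pass, so each loss can later be
    # differentiated with respect to its own model's variables.
    with tf.GradientTape() as tape_G, tf.GradientTape() as tape_D:
        sr = generator(lr, training=True)
        hr_output = discriminator(hr, training=True)
        sr_output = discriminator(sr, training=True)

        total_loss_G = (tf.reduce_sum(generator.losses)
                        + 1e-2 * pixel_loss_fn(hr, sr)
                        + 1.0 * fea_loss_fn(hr, sr)
                        + 5e-3 * gen_loss_fn(hr_output, sr_output))
        total_loss_D = (tf.reduce_sum(discriminator.losses)
                        + dis_loss_fn(hr_output, sr_output))

    # Each tape is queried exactly once, so persistent=True is not needed
    # and the tapes are released automatically.
    grads_G = tape_G.gradient(total_loss_G, generator.trainable_variables)
    grads_D = tape_D.gradient(total_loss_D, discriminator.trainable_variables)
    optimizer_G.apply_gradients(zip(grads_G, generator.trainable_variables))
    optimizer_D.apply_gradients(zip(grads_D, discriminator.trainable_variables))
    return total_loss_G, total_loss_D

As far as I can tell, both versions should compute the same gradients, since in each case the gradients are taken from the same forward pass before either optimizer updates any variables.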