How to update GAN Generator and Discriminator asynchronously in Tensorflow?


I want to develop a GAN with Tensorflow, with the Generator being an autoencoder and the Discriminator a Convolutional Neural Net with binary output. There is no problem developing the autoencoder and the CNN separately, but my idea is to train each of the two components (Discriminator and Generator) for one epoch at a time, repeat this cycle for 1000 epochs, and keep the weights from the previous training epoch for the next one. How can I operationalize this?


There are 2 answers

razimbres (BEST ANSWER)

I solved the problem. In fact, I want the output of the autoencoder to be the input of the CNN, connecting the two parts of the GAN and updating the weights in a 1:1 proportion. I noticed I had to take special care to keep the Generator's and the Discriminator's losses in separate variables; otherwise, at the start of the second loop iteration the Generator's loss tensor would be replaced by a float, the last loss value produced by the Discriminator.
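To make that pitfall concrete: if the evaluated loss value is stored back into the Python variable that holds the loss tensor, the next sess.run call receives a float instead of a tensor and fails. A minimal sketch (the names are illustrative, not taken from the code below):

# BUG: 'loss' starts out as a tf.Tensor, but after this line it is a Python float,
# so the next iteration's sess.run([optimizer, loss]) raises a TypeError
_, loss = sess.run([optimizer, loss])

# Fix: keep the tensor ('loss') and its evaluated value ('l') in separate variables
_, l = sess.run([optimizer, loss])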

Here's the code:

with tf.Session() as sess:
    sess.run(init)
    for i in range(1, num_steps+1):

        # Generator (denoising autoencoder) training step
        batch_x, batch_y = next_batch(batch_size, x_train_noisy, x_train)
        _, l = sess.run([optimizer, loss], feed_dict={X: batch_x.reshape(n, 784),
                                                      Y: batch_y})
        if i % display_step == 0 or i == 1:
            print('Epoch %i: Denoising Loss: %f' % (i, l))

        # the output of the Generator becomes the input of the Discriminator
        output = sess.run([decoder_op], feed_dict={X: x_train})
        x_train2 = np.array(output).reshape(n, 784).astype(np.float64)

        # Discriminator (CNN) training step
        batch_x2, batch_y2 = next_batch(batch_size, x_train2, y_train)
        sess.run(train_op, feed_dict={X2: batch_x2.reshape(n, 784), Y2: batch_y2, keep_prob: 0.8})
        if i % display_step == 0 or i == 1:
            loss3, acc = sess.run([loss_op2, accuracy], feed_dict={X2: batch_x2,
                                                                   Y2: batch_y2,
                                                                   keep_prob: 1.0})
            print("Epoch " + str(i) + ", CNN Loss= " +
                  "{:.4f}".format(loss3) + ", Training Accuracy= " + "{:.3f}".format(acc))

This way the asynchronous update can be operationalized in a 1:1, 1:5, 5:1 (Discriminator : Generator) or any other proportion, as sketched below.
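For example, a sketch of a 5:1 (Discriminator : Generator) schedule, reusing the names from the code above; the ratio constant and the inner loop are illustrative additions, not part of my original code:

disc_steps_per_gen_step = 5  # assumed ratio, adjust as needed

for i in range(1, num_steps+1):
    # one Generator step
    batch_x, batch_y = next_batch(batch_size, x_train_noisy, x_train)
    sess.run([optimizer, loss], feed_dict={X: batch_x.reshape(n, 784), Y: batch_y})

    # regenerate the Discriminator's training data from the current Generator
    output = sess.run([decoder_op], feed_dict={X: x_train})
    x_train2 = np.array(output).reshape(n, 784).astype(np.float64)

    # several Discriminator steps per Generator step
    for _ in range(disc_steps_per_gen_step):
        batch_x2, batch_y2 = next_batch(batch_size, x_train2, y_train)
        sess.run(train_op, feed_dict={X2: batch_x2.reshape(n, 784),
                                      Y2: batch_y2, keep_prob: 0.8})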

Lior

If you have two ops called train_step_generator and train_step_discriminator (each of which is, for example, of the form tf.train.AdamOptimizer().minimize(loss) with an appropriate loss for each), then your training loop should follow a structure similar to this:

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(1000):
        if epoch % 2 == 0:  # train discriminator on even epochs
            for i in range(training_set_size // batch_size):  # integer division for Python 3
                z_ = np.random.normal(0, 1, batch_size)  # this is the input to the generator
                batch = get_next_batch(batch_size)
                sess.run(train_step_discriminator, feed_dict={z: z_, x: batch})
        else:  # train generator on odd epochs
            for i in range(training_set_size // batch_size):
                z_ = np.random.normal(0, 1, batch_size)  # this is the input to the generator
                sess.run(train_step_generator, feed_dict={z: z_})

The weights will persist between iterations.
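One detail worth adding about defining those two ops: each minimize() call should receive a var_list, so that the discriminator step does not also update the generator's weights and vice versa (by default, minimize() updates every trainable variable the loss depends on). A sketch under assumptions; the scope names 'generator' and 'discriminator' and the loss names are illustrative, not from the answer above:

# Collect each sub-network's variables by the variable scope it was built under
gen_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
disc_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')

# Each optimizer only touches its own network's weights
train_step_generator = tf.train.AdamOptimizer().minimize(
    generator_loss, var_list=gen_vars)
train_step_discriminator = tf.train.AdamOptimizer().minimize(
    discriminator_loss, var_list=disc_vars)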