How to save a GAN model with a custom training loop?


I have implemented a custom WGAN model, and recently I refactored it so that it can be compiled and trained with the 'fit' method.

However, I cannot save the model. I can save the 'generator' and 'discriminator' models separately without any problem, but I cannot save the 'gan' model.

I tried 'model.save()', but it seems this is probably not possible because my custom training logic lives in 'train_step'.

So I used 'tf.saved_model.save(gan, "./GAN")'. That works; however, after saving it this way I can only load back a vanilla 'gan'.
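The loading behaviour described above is expected: 'tf.saved_model.save' records variables and traced tf.functions, not the Python class, so 'tf.saved_model.load' hands back a generic restored object rather than an instance of your model class. One common way around this is to keep the class in code, rebuild it the same way, and save/restore only its state with 'tf.train.Checkpoint'. Below is a minimal sketch under that assumption. 'WGAN', 'make_generator', 'make_critic', 'build_gan', 'make_checkpoint', the layer sizes and the clip value are all hypothetical stand-ins, not the code from the question.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf
from tensorflow import keras

LATENT_DIM, DATA_DIM = 8, 4  # hypothetical toy sizes


def make_generator():
    return keras.Sequential([keras.Input(shape=(LATENT_DIM,)),
                             keras.layers.Dense(16, activation="relu"),
                             keras.layers.Dense(DATA_DIM)], name="generator")


def make_critic():
    return keras.Sequential([keras.Input(shape=(DATA_DIM,)),
                             keras.layers.Dense(16, activation="relu"),
                             keras.layers.Dense(1)], name="critic")


class WGAN(keras.Model):
    """Hypothetical WGAN wrapper with its own compile() and train_step()."""

    def __init__(self, generator, critic, clip_value=0.01, **kwargs):
        super().__init__(**kwargs)
        self.generator = generator
        self.critic = critic
        self.clip_value = clip_value

    def compile(self, g_optimizer, d_optimizer):
        super().compile()
        self.g_optimizer = g_optimizer
        self.d_optimizer = d_optimizer

    def call(self, inputs, training=False):
        # Scores samples with the critic; defined so the wrapper is a complete model.
        return self.critic(inputs, training=training)

    def train_step(self, real):
        if isinstance(real, (tuple, list)):  # fit() may wrap the batch in a tuple
            real = real[0]
        batch = tf.shape(real)[0]
        # Critic step: minimise E[D(fake)] - E[D(real)].
        with tf.GradientTape() as tape:
            fake = self.generator(tf.random.normal((batch, LATENT_DIM)), training=True)
            d_loss = (tf.reduce_mean(self.critic(fake, training=True))
                      - tf.reduce_mean(self.critic(real, training=True)))
        grads = tape.gradient(d_loss, self.critic.trainable_weights)
        self.d_optimizer.apply_gradients(zip(grads, self.critic.trainable_weights))
        # Weight clipping, as in the original WGAN.
        for w in self.critic.trainable_weights:
            w.assign(tf.clip_by_value(w, -self.clip_value, self.clip_value))
        # Generator step: minimise -E[D(G(z))].
        with tf.GradientTape() as tape:
            fake = self.generator(tf.random.normal((batch, LATENT_DIM)), training=True)
            g_loss = -tf.reduce_mean(self.critic(fake, training=True))
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))
        return {"d_loss": d_loss, "g_loss": g_loss}


def build_gan():
    gan = WGAN(make_generator(), make_critic())
    gan.compile(g_optimizer=keras.optimizers.RMSprop(5e-5),
                d_optimizer=keras.optimizers.RMSprop(5e-5))
    return gan


def make_checkpoint(gan):
    # Track the sub-models and both optimizers explicitly so optimizer state is saved too.
    return tf.train.Checkpoint(generator=gan.generator, critic=gan.critic,
                               g_optimizer=gan.g_optimizer, d_optimizer=gan.d_optimizer)


real_data = np.random.normal(size=(64, DATA_DIM)).astype("float32")
dataset = tf.data.Dataset.from_tensor_slices(real_data).batch(16)

gan = build_gan()
gan.fit(dataset, epochs=1, verbose=0)
save_path = make_checkpoint(gan).save(os.path.join(tempfile.mkdtemp(), "wgan"))

# Later (e.g. in a new process): rebuild the same Python objects, then restore state.
restored = build_gan()
# Optimizer slot variables are only created on the first training step, so part of
# the checkpoint cannot be matched yet; expect_partial() silences those warnings.
make_checkpoint(restored).restore(save_path).expect_partial()
```

Checkpointing the generator, critic and optimizers explicitly, rather than the wrapper model, keeps the saved state independent of how Keras tracks attributes on a subclassed model. The restored object is a real 'WGAN' instance, so calling 'fit' on it continues training as usual.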

I know this is a general description, but can someone point me to some resources so I can investigate and resolve this issue? I am thinking one solution may be the checkpoint approach that some older posts suggest. However, I would then need to add checkpointing to my training loop, so that checkpoints (ckpt) are recorded during training (model.fit). I am wondering whether this is possible.
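Checkpointing during 'fit' should be possible without touching 'train_step': a Keras callback can write a 'tf.train.Checkpoint' at the end of every epoch through 'tf.train.CheckpointManager'. A minimal sketch follows; 'CheckpointEveryEpoch' and 'make_model' are hypothetical names, and the small regression model is only a stand-in for the GAN, since the callback sees the Checkpoint object and nothing else.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf
from tensorflow import keras


class CheckpointEveryEpoch(keras.callbacks.Callback):
    """Hypothetical callback: writes a tf.train.Checkpoint after every epoch of fit()."""

    def __init__(self, checkpoint, directory, max_to_keep=3):
        super().__init__()
        # CheckpointManager numbers the files and deletes all but the newest max_to_keep.
        self.manager = tf.train.CheckpointManager(checkpoint, directory,
                                                  max_to_keep=max_to_keep)

    def on_epoch_end(self, epoch, logs=None):
        # Keras passes a 0-based epoch index; number the checkpoints from 1.
        self.manager.save(checkpoint_number=epoch + 1)


def make_model():
    # Hypothetical stand-in for the GAN; any model plus optimizer is handled the same way.
    model = keras.Sequential([keras.Input(shape=(3,)), keras.layers.Dense(1)])
    model.compile(optimizer=keras.optimizers.Adam(1e-2), loss="mse")
    return model


x = np.random.normal(size=(32, 3)).astype("float32")
y = x.sum(axis=1, keepdims=True)
ckpt_dir = tempfile.mkdtemp()

model = make_model()
saver = CheckpointEveryEpoch(
    tf.train.Checkpoint(model=model, optimizer=model.optimizer), ckpt_dir)
model.fit(x, y, epochs=3, batch_size=8, verbose=0, callbacks=[saver])

# Resuming later: rebuild and compile, then restore the newest checkpoint.
resumed = make_model()
latest = tf.train.latest_checkpoint(ckpt_dir)
# Optimizer slot variables do not exist yet in the fresh model, hence expect_partial().
tf.train.Checkpoint(model=resumed, optimizer=resumed.optimizer).restore(latest).expect_partial()
```

To carry on from epoch 3, something like 'resumed.fit(x, y, epochs=6, initial_epoch=3, callbacks=[...])' would train epochs 4 to 6. For the WGAN, the Checkpoint would instead hold your generator, discriminator and both optimizers (the keyword names are up to you).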

Any advice would be appreciated.

-Xiaokuan
