How to load a GAN checkpoint properly in PyTorch?


I trained a GAN on 256x256 images, essentially extending the code from PyTorch's own DCGAN tutorial to handle higher-resolution images. The model and optimizer initialization looks like this:

import torch
import torch.optim as optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

gen = Generator(...).to(device)
disc = Discriminator(...).to(device)

opt_gen = optim.Adam(gen.parameters(), ...)
opt_disc = optim.Adam(disc.parameters(), ...)

gen.train()
disc.train()
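For readers who want to run the initialization above on its own, here is a minimal sketch with tiny, hypothetical stand-ins for `Generator` and `Discriminator`: two DCGAN-style layers each, producing 8x8 images rather than the 256x256 architecture from the question, whose constructor arguments are elided. The latent size `Z_DIM`, the feature width, and the Adam settings `lr=2e-4, betas=(0.5, 0.999)` are assumptions in the style of the DCGAN tutorial, not values taken from the question.

```python
import torch
import torch.nn as nn
import torch.optim as optim

Z_DIM = 100  # hypothetical latent size


class Generator(nn.Module):
    """Hypothetical stand-in: maps (N, Z_DIM, 1, 1) noise to (N, 3, 8, 8) images."""

    def __init__(self, z_dim=Z_DIM, features=16):
        super().__init__()
        self.net = nn.Sequential(
            nn.ConvTranspose2d(z_dim, features, 4, stride=1, padding=0, bias=False),  # 1x1 -> 4x4
            nn.BatchNorm2d(features),
            nn.ReLU(True),
            nn.ConvTranspose2d(features, 3, 4, stride=2, padding=1, bias=False),  # 4x4 -> 8x8
            nn.Tanh(),  # pixel values in [-1, 1]
        )

    def forward(self, z):
        return self.net(z)


class Discriminator(nn.Module):
    """Hypothetical stand-in: maps (N, 3, 8, 8) images to (N, 1, 1, 1) probabilities."""

    def __init__(self, features=16):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, features, 4, stride=2, padding=1),  # 8x8 -> 4x4
            nn.LeakyReLU(0.2),
            nn.Conv2d(features, 1, 4, stride=1, padding=0),  # 4x4 -> 1x1
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.net(x)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

gen = Generator().to(device)
disc = Discriminator().to(device)

# Assumed DCGAN-tutorial-style optimizer settings.
opt_gen = optim.Adam(gen.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_disc = optim.Adam(disc.parameters(), lr=2e-4, betas=(0.5, 0.999))

gen.train()
disc.train()
```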

The GAN produced good-quality samples. A few times per epoch, I generated some images from the same fixed input vector, fixed_noise, and viewed them in TensorBoard using SummaryWriter:

with torch.no_grad():
    fake = gen(fixed_noise)

    img_grid_real = torchvision.utils.make_grid(
        real[:NUM_VISUALIZATION_SAMPLES], normalize=True
    )
    img_grid_fake = torchvision.utils.make_grid(
        fake[:NUM_VISUALIZATION_SAMPLES], normalize=True
    )

    writer_real.add_image("Real", img_grid_real, global_step=step)
    writer_fake.add_image("Fake", img_grid_fake, global_step=step)
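The question does not show how `fixed_noise` is created. Assuming it follows the DCGAN tutorial pattern of drawing a batch of standard-normal latent vectors once, a sketch like the one below draws it from a dedicated, seeded RNG so the same vectors can be recreated in a later run. `Z_DIM`, `NUM_FIXED`, and `make_fixed_noise` are hypothetical names. PyTorch does not guarantee identical random streams across releases, so saving the tensor itself with `torch.save` is the more robust way to reuse it on another machine.

```python
import torch

Z_DIM = 100     # hypothetical latent size
NUM_FIXED = 32  # hypothetical number of fixed samples


def make_fixed_noise(seed, device="cpu"):
    """Draw the fixed latent batch from a private, seeded CPU RNG."""
    # A dedicated generator leaves the global RNG state untouched.
    g = torch.Generator(device="cpu").manual_seed(seed)
    # Sample on the CPU, then move the tensor to the target device.
    return torch.randn(NUM_FIXED, Z_DIM, 1, 1, generator=g).to(device)


fixed_noise = make_fixed_noise(seed=0)
```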

I saved the GAN after each training epoch like this:

checkpoint = {
    "gen_state": gen.state_dict(),
    "gen_optimizer": opt_gen.state_dict(),
    "disc_state": disc.state_dict(),
    "disc_optimizer": opt_disc.state_dict()
}
torch.save(checkpoint, f"checkpoints/checkpoint_{epoch_number}.pth.tar")
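The saving step can be wrapped in a small helper. This is a sketch, assuming the same four checkpoint keys as above; `save_checkpoint` is a hypothetical helper, and tiny `nn.Linear` layers stand in for the real networks. It creates the target directory first, because `torch.save` does not create missing directories, and takes one optimizer step so the saved Adam state is not empty. A temporary directory is used so the example runs anywhere.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim


def save_checkpoint(path, gen, disc, opt_gen, opt_disc):
    """Save both networks and both optimizers under the keys used in the question."""
    os.makedirs(os.path.dirname(path), exist_ok=True)  # torch.save won't create dirs
    checkpoint = {
        "gen_state": gen.state_dict(),
        "gen_optimizer": opt_gen.state_dict(),
        "disc_state": disc.state_dict(),
        "disc_optimizer": opt_disc.state_dict(),
    }
    torch.save(checkpoint, path)


# Hypothetical tiny stand-ins for the real networks.
gen, disc = nn.Linear(4, 4), nn.Linear(4, 1)
opt_gen = optim.Adam(gen.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_disc = optim.Adam(disc.parameters(), lr=2e-4, betas=(0.5, 0.999))

# One optimizer step so the Adam state (exp_avg, exp_avg_sq) is populated.
loss = disc(gen(torch.randn(8, 4))).mean()
opt_gen.zero_grad()
opt_disc.zero_grad()
loss.backward()
opt_gen.step()
opt_disc.step()

ckpt_root = tempfile.mkdtemp()
epoch_number = 10
path = os.path.join(ckpt_root, "checkpoints", f"checkpoint_{epoch_number}.pth.tar")
save_checkpoint(path, gen, disc, opt_gen, opt_disc)
```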

Up to this point, I had trained the GAN on a CentOS 7.9 machine with an NVIDIA T4 GPU, using PyTorch 1.11.0. I then rsync'd a few checkpoints (saved as described above) to my personal machine (Windows 10, NVIDIA GTX 1050 Ti, PyTorch 1.10.1). Using the exact same class definitions for the GAN, and initializing the models the same way (cf. the first code snippet, except that I did not call .train() on them), I loaded a checkpoint like this:

checkpoint = torch.load("checkpoints/checkpoint_10.pth.tar")
gen.load_state_dict(checkpoint["gen_state"])
opt_gen.load_state_dict(checkpoint["gen_optimizer"])
disc.load_state_dict(checkpoint["disc_state"])
opt_disc.load_state_dict(checkpoint["disc_optimizer"])
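A hedged sketch of the loading side, written as an in-memory round trip with hypothetical `nn.Linear` stand-ins. `torch.load(..., map_location=device)` remaps tensors that were saved on one device (for example a CUDA GPU) onto whatever device the loading machine has, and `load_state_dict` uses strict key matching by default, returning the lists of missing and unexpected keys. The final lines also show that a freshly constructed module already has `training=True`, so leaving out the explicit `.train()` calls still leaves the networks in training mode.

```python
import io

import torch
import torch.nn as nn
import torch.optim as optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def make_models():
    """Hypothetical stand-ins; the real Generator/Discriminator would go here."""
    gen, disc = nn.Linear(4, 4).to(device), nn.Linear(4, 1).to(device)
    opt_gen = optim.Adam(gen.parameters(), lr=2e-4, betas=(0.5, 0.999))
    opt_disc = optim.Adam(disc.parameters(), lr=2e-4, betas=(0.5, 0.999))
    return gen, disc, opt_gen, opt_disc


# "Training machine": build models and serialize a checkpoint into memory.
src_gen, src_disc, src_opt_gen, src_opt_disc = make_models()
buffer = io.BytesIO()
torch.save({
    "gen_state": src_gen.state_dict(),
    "gen_optimizer": src_opt_gen.state_dict(),
    "disc_state": src_disc.state_dict(),
    "disc_optimizer": src_opt_disc.state_dict(),
}, buffer)
buffer.seek(0)

# "Loading machine": fresh models, then restore every piece of state.
gen, disc, opt_gen, opt_disc = make_models()
checkpoint = torch.load(buffer, map_location=device)  # remap tensors to this device
result = gen.load_state_dict(checkpoint["gen_state"])  # strict=True raises on key mismatch
disc.load_state_dict(checkpoint["disc_state"])
opt_gen.load_state_dict(checkpoint["gen_optimizer"])
opt_disc.load_state_dict(checkpoint["disc_optimizer"])

# The restored weights should match the saved ones exactly.
for name, tensor in src_gen.state_dict().items():
    assert torch.equal(gen.state_dict()[name], tensor)

print(gen.training)         # True: new modules start in training mode
print(result.missing_keys)  # []
```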

I then used the same code as in the second code snippet to generate some images on my machine with the loaded checkpoint. This yielded garbage output.


I tried every checkpoint I had, and all of them produced nonsense. I searched the PyTorch forums and found related questions (1, 2, 3), but none of them helped.

Am I saving/loading the model wrong?
