pytorch torch.load load_checkpoint and learning_rate


Following this Medium post, I understand how to save and load my model (or at least I think I do). They say the learning_rate is saved. However, looking at this person's code (it's a GitHub repo with lots of people watching and forking it, so I'm assuming it shouldn't be full of mistakes), the person writes:

def load_checkpoint(checkpoint_file, model, optimizer, lr):
    print("=> Loading checkpoint")
    checkpoint = torch.load(checkpoint_file, map_location=config.DEVICE)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    # If we don't do this then it will just have learning rate of old checkpoint
    # and it will lead to many hours of debugging \:
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

Why doesn't optimizer.load_state_dict(checkpoint["optimizer"]) give the learning rate of the old checkpoint? If it does (and I believe it does), why do they say it's a problem: "If we don't do this then it will just have learning rate of old checkpoint and it will lead to many hours of debugging"?

There is no learning rate decay anyway in the code. So should it even matter?

1 Answer

Answered by CuCaRot:

Why doesn't optimizer.load_state_dict(checkpoint["optimizer"]) give the learning rate of the old checkpoint?

In PyTorch, the learning rate is stored in the optimizer's param_groups (and therefore in the optimizer's state_dict), so optimizer.load_state_dict(checkpoint["optimizer"]) does restore the learning rate that was saved in the checkpoint. During training, the learning rate can be adjusted via torch.optim.lr_scheduler.
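
For example, a quick look in the interpreter (a throwaway model, just to show where the value lives and what ends up in the checkpoint):

>>> import torch
>>> model = torch.nn.Linear(5, 1, bias=False)
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
>>> optimizer.param_groups[0]["lr"]                    # the live learning rate
0.1
>>> optimizer.state_dict()["param_groups"][0]["lr"]    # the value that gets checkpointed
0.1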

If you want to resume training from the point where it stopped last time, the scheduler's state dict keeps all the information you need to continue: the strategy for adjusting the learning rate, the last epoch, the step count the model was on, and the last learning rate (which should match the optimizer's current learning rate). With that restored, your model can keep training just as if it had never stopped.

>>> import torch
>>> model = torch.nn.Linear(5, 1, bias=False)
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0)
>>> scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
>>> print(scheduler.state_dict())
{'step_size': 1, 'gamma': 0.1, 'base_lrs': [0.1], 'last_epoch': 0, '_step_count': 1, 'verbose': False, '_get_lr_called_within_step': False, '_last_lr': [0.1]}
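
If you do use a scheduler, a rough save/resume sketch (continuing from the snippet above, and assuming the checkpoint also stores scheduler.state_dict() under a "scheduler" key, which the repo in the question does not do) would look like:

    # save everything needed to resume, not just the model weights
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
    }
    torch.save(checkpoint, "checkpoint.pth.tar")

    # later, to resume exactly where training stopped
    checkpoint = torch.load("checkpoint.pth.tar")
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    scheduler.load_state_dict(checkpoint["scheduler"])  # restores last_epoch, _last_lr, etc.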

Why do they say it's a problem that "if we don't do this then it will just have learning rate of old checkpoint and it will lead to many hours of debugging"?

Normally, if you never touched the learning rate during training, the value restored from the checkpoint will be the same as the initial one. My guess is that they changed it in other projects and just want to make sure the learning rate has the value from the current config this time.
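
To make that concrete, here is a hypothetical scenario (the learning-rate values are made up) showing what the two lines protect against when the checkpoint and the current config disagree:

    # the checkpoint was written by a run that used lr=1e-5,
    # but the current config asks for lr=2e-4
    optimizer.load_state_dict(checkpoint["optimizer"])
    print(optimizer.param_groups[0]["lr"])   # 1e-05, silently inherited from the checkpoint

    for param_group in optimizer.param_groups:
        param_group["lr"] = 2e-4             # force the value from the current config
    print(optimizer.param_groups[0]["lr"])   # 0.0002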

The code you provided is from CycleGAN, but I also found the same function in ESRGAN, Pix2Pix, ProGAN, SRGAN, etc., so I think they reuse the same utils across multiple projects.

There is no learning rate decay anyway in the code. So should it even matter?

I found no learning rate scheduler in the CycleGAN code, so I believe it doesn't matter if you remove those lines, but only in this particular case.