Training Loss When Resuming From a Checkpoint Explodes


I am trying to implement a function in my algorithm that allows me to resume training from a checkpoint. The problem is that when I resume training, my loss explodes by several orders of magnitude, from the order of 0.001 to 1000. I suspect that the learning rate is not being restored properly when training is resumed.

Here is my training function:

```python
def train_gray(epoch, data_loader, device, model, criterion, optimizer, i, path):
    train_loss = 0.0
    for data in data_loader:
        img, _ = data
        img = img.to(device)
        stand_dev = 0.0392

        noisy_img = add_noise(img, stand_dev, device)  # add_noise is defined elsewhere

        output = model(noisy_img, stand_dev)
        output = output[:,0:1,:,:]  # keep only the first output channel
        loss = criterion(output, img)

        optimizer.zero_grad()

        loss.backward()
        optimizer.step()
        train_loss += loss.item()*img.size(0)  # batch mean -> per-sample sum

    # Average over samples (the loss was summed per sample above), not over batches
    train_loss = train_loss/len(data_loader.dataset)

    print('Epoch: {} Complete \tTraining Loss: {:.6f}'.format(
        epoch,
        train_loss
        ))
    return train_loss
```
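A note on how the epoch loss is averaged: `nn.MSELoss` returns a per-batch mean, so the epoch average is the sum of `loss.item() * batch_size` divided by the number of samples, `len(data_loader.dataset)`. Dividing by `len(data_loader)` (the number of batches) instead reports the loss scaled up by roughly the batch size. The sketch below uses a hypothetical helper, `train_one_epoch`, that assumes each batch is an `(input, target)` pair; it leaves out the noise step because `add_noise` is not shown.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def train_one_epoch(model, data_loader, criterion, optimizer, device):
    """Run one training epoch and return the mean per-sample loss.

    Hypothetical helper: assumes (input, target) batches and a criterion
    that averages over the batch (the nn.MSELoss default).
    """
    model.train()  # training-mode behaviour for dropout / BatchNorm
    running = 0.0
    for inputs, targets in data_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()
        # Undo the per-batch mean so a smaller final batch is weighted correctly
        running += loss.item() * inputs.size(0)
    # Divide by the number of samples, not the number of batches
    return running / len(data_loader.dataset)
```

With a learning rate of 0 the weights never change, so the returned value should equal the mean squared error over the whole dataset, which is a quick way to sanity-check the normalisation.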

And here is my main function, which initialises my variables, loads a checkpoint, calls my training function, and saves a checkpoint after each epoch of training:


```python
import os
from datetime import datetime

import torch
from torch import nn

# UNetWithResnet50Encoder, load_dataset, train_gray, checkpoint and a
# module-level `device` are defined elsewhere in the script.


def main():
    now = datetime.now()
    current_time = now.strftime("%H_%M_%S")
    path = "/home/bledc/my_remote_folder/denoiser/models/{}_sigma_10_session2".format(current_time)
    os.mkdir(path)

    width = 256
    # height = 256
    num_epochs = 25
    batch_size = 4
    learning_rate = 0.0001

    data_loader = load_dataset(batch_size, width)

    model = UNetWithResnet50Encoder().to(device)
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(
        model.parameters(), lr=learning_rate, weight_decay=1e-5)

    start_epoch = 1  # used when training from scratch

    ############################################################################################
    # UNCOMMENT CODE BELOW TO RESUME TRAINING FROM A MODEL
    model_path = "/home/bledc/my_remote_folder/denoiser/models/resnet_sigma_10/model_epoch_10.pt"
    save_point = torch.load(model_path, map_location=device)
    model.load_state_dict(save_point['model_state_dict'])
    optimizer.load_state_dict(save_point['optimizer_state_dict'])
    start_epoch = save_point['epoch'] + 1  # the saved epoch has already been trained
    train_loss = save_point['train_loss']
    model.train()
    ############################################################################################

    for i in range(start_epoch, num_epochs+1):
        train_loss = train_gray(i, data_loader, device, model, criterion, optimizer, i, path)
        checkpoint(i, train_loss, model, optimizer, path)

    print("end")
```

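When resuming, the saved `'epoch'` refers to an epoch that has already finished, so training should continue from `epoch + 1`, and defaulting the start epoch to 1 keeps `main` working when no checkpoint is loaded. A minimal sketch, assuming the checkpoint dictionary layout written by `checkpoint()` (keys `'epoch'`, `'model_state_dict'`, `'optimizer_state_dict'`); `resume_from` is a hypothetical helper name, not part of PyTorch:

```python
import os
import tempfile

import torch
from torch import nn


def resume_from(model_path, model, optimizer, device):
    """Restore model/optimizer state and return the first epoch still to run.

    Hypothetical helper; call it after the model is on `device` and the
    optimizer has been built from the same parameters as when saving.
    """
    # map_location lets a checkpoint saved on one device load on another
    save_point = torch.load(model_path, map_location=device)
    model.load_state_dict(save_point['model_state_dict'])
    optimizer.load_state_dict(save_point['optimizer_state_dict'])
    model.train()
    # The saved epoch is complete, so continue with the next one
    return save_point['epoch'] + 1
```

In `main`, the loop would then become `for i in range(start_epoch, num_epochs + 1):` with `start_epoch = resume_from(model_path, model, optimizer, device)` when resuming.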
Lastly, here is my function to save checkpoints:

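```python
import torch


def checkpoint(epoch, train_loss, model, optimizer, path):
    # Save everything needed to resume: the epoch counter, model weights and
    # buffers, optimizer state (including its learning rate) and the last loss.
    torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': train_loss
            }, path+"/model_epoch_{}.pt".format(epoch))
    print("Epoch saved")
```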

If my problem is that I am not saving my learning rate, how would I do this?
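For reference, the learning rate is already part of the optimizer's state: `optimizer.state_dict()` contains a `'param_groups'` list whose entries include `'lr'`, and `optimizer.load_state_dict()` restores it, overriding the `lr` the new optimizer was constructed with. For Adam it also restores the per-parameter moment estimates. A minimal check using a toy `nn.Linear` model as a stand-in for the real network:

```python
import io

import torch
from torch import nn

# Toy stand-in model; any nn.Module behaves the same way here
model = nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)

# Take one step so Adam has per-parameter state (step count, moment estimates)
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()

# Save to an in-memory buffer instead of a file
buffer = io.BytesIO()
torch.save({'optimizer_state_dict': optimizer.state_dict()}, buffer)
buffer.seek(0)

# A fresh optimizer built with a different lr, then restored from the checkpoint
new_optimizer = torch.optim.Adam(model.parameters(), lr=1.0)
new_optimizer.load_state_dict(torch.load(buffer)['optimizer_state_dict'])

restored_lr = new_optimizer.param_groups[0]['lr']
restored_keys = sorted(new_optimizer.state[model.weight].keys())
print(restored_lr)    # 0.0001
print(restored_keys)  # ['exp_avg', 'exp_avg_sq', 'step']
```

If the checkpoint was written by `checkpoint()` as above, the learning rate should therefore already be restored along with the rest of the Adam state.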

Any help would be greatly appreciated.
Clement

Update: I'm fairly certain that the problem lies in my pretrained model. I am saving the optimiser every epoch, but the optimiser only holds information for the trainable layers. I hope to solve this soon and will post a more thorough answer once I figure out how to save and load the entire model.
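On the update: `model.state_dict()` contains every registered parameter and buffer (for example BatchNorm running statistics), whether or not `requires_grad` is set, so saving it already captures frozen pretrained layers. The part that depends on how things were built is the optimizer: `load_state_dict` matches saved state to parameters by position within each parameter group, so the optimizer must be constructed from the same parameter list in both sessions. If the group sizes differ, PyTorch raises a `ValueError`; if they happen to match but the order differs, the state can end up attached to the wrong tensors. A sketch with a hypothetical small model whose first layer is frozen:

```python
import torch
from torch import nn


def build():
    """Hypothetical stand-in for a network with a frozen pretrained part."""
    model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4), nn.Linear(4, 1))
    for p in model[0].parameters():  # freeze the "pretrained" layer
        p.requires_grad = False
    return model


# Session 1: optimizer over trainable parameters only
model = build()
optimizer = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], lr=1e-4)
state = {'model_state_dict': model.state_dict(),
         'optimizer_state_dict': optimizer.state_dict()}

# The model state_dict includes the frozen layer and the BatchNorm buffers
saved_keys = set(state['model_state_dict'].keys())

# Session 2: rebuild the optimizer with the SAME filter, then load both
model2 = build()
opt2 = torch.optim.Adam([p for p in model2.parameters() if p.requires_grad], lr=1e-4)
model2.load_state_dict(state['model_state_dict'])
opt2.load_state_dict(state['optimizer_state_dict'])

# An optimizer over ALL parameters has a different group size and is rejected
opt_all = torch.optim.Adam(build().parameters(), lr=1e-4)
try:
    opt_all.load_state_dict(state['optimizer_state_dict'])
    mismatch_detected = False
except ValueError:
    mismatch_detected = True
```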
