State dict not persisting to checkpoint file for custom callback


I'm attempting to save additional information and metrics from model training in a custom pytorch lightning callback. I've implemented the load_state_dict and state_dict functions as outlined in the documentation here. However, the information is not being saved to the model's checkpoint file.

I have the following simple custom callback implemented:

import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback

class MyCallback(Callback):

    def __init__(self) -> None:
        self.state = {"metric": None}

    def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:

        # do stuff with trainer and pl_module to get the metric
        updated_metric = ...

        self.state["metric"] = updated_metric

    def load_state_dict(self, state_dict):
        self.state.update(state_dict)

    def state_dict(self):
        return self.state.copy()
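In isolation, the `state_dict`/`load_state_dict` pair round-trips correctly. Here is a minimal sketch with the same method bodies, minus the Lightning base class (the `MyCallbackState` name and the `0.9` value are placeholders for illustration):

```python
class MyCallbackState:  # hypothetical stand-in, no Lightning dependency
    def __init__(self):
        self.state = {"metric": None}

    def load_state_dict(self, state_dict):
        self.state.update(state_dict)

    def state_dict(self):
        return self.state.copy()

# Simulate saving state from one instance and restoring it into another.
src = MyCallbackState()
src.state["metric"] = 0.9  # placeholder value
dst = MyCallbackState()
dst.load_state_dict(src.state_dict())
print(dst.state)  # {'metric': 0.9}
```

So the serialization logic itself seems fine; the problem appears to be *when* Lightning calls it.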

I initialize and train my trainer with the following callbacks:

trainer = pl.Trainer(..., callbacks=[EarlyStopping(...), ModelCheckpoint(...), MyCallback(...)])
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

After training, direct inspection of the trainer's MyCallback state (trainer.callbacks[<callback_idx>].state) shows that the state has been properly updated with updated_metric. However, if I try to load the checkpoint created during training, the callback state does not persist and shows up as {"metric": None}. The docs linked above seem to indicate that simply implementing load_state_dict and state_dict is enough to "persist the state effectively", so I'm not sure what I'm missing.

After training, I'm loading the checkpoint with:

loaded_state = torch.load("path/to/checkpoint.ckpt")

All my callbacks appear in the loaded state's "callbacks" field, but only the ModelCheckpoint and EarlyStopping callbacks have persisted state. After stepping through MyCallback in my debugger, it seems that state_dict is called before on_train_end. Based on state_dict's docstring, I know this method is called whenever a checkpoint is saved, which means the state dict is captured before on_train_end updates it. In light of this, I also tried moving my metric calculation into the on_save_checkpoint hook, but got the same result.
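The ordering I observed in the debugger can be sketched without Lightning at all. This is a minimal, library-free mock (the `save_checkpoint` helper and the `0.95` value are illustrative, not Lightning's actual internals), showing why a snapshot taken before on_train_end is stale:

```python
class MyCallback:
    def __init__(self):
        self.state = {"metric": None}

    def on_train_end(self):
        self.state["metric"] = 0.95  # placeholder for the real metric

    def state_dict(self):
        return self.state.copy()

saved = {}

def save_checkpoint(cb):
    # Mock of the checkpoint callback snapshotting callback state.
    saved["MyCallback"] = cb.state_dict()

cb = MyCallback()
save_checkpoint(cb)  # checkpoint written during fit, e.g. after validation
cb.on_train_end()    # state updated only after the checkpoint was written

print(saved)     # {'MyCallback': {'metric': None}} -- stale snapshot
print(cb.state)  # {'metric': 0.95} -- live object is up to date
```

This matches what I see: the in-memory callback has the updated metric, but the checkpoint on disk holds the pre-update copy.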

Any help would be greatly appreciated.
