I am working with a U-Net in PyTorch Lightning. I can train the model successfully, but after training, when I try to load the model from a checkpoint, I get this error:
Complete Traceback:
Traceback (most recent call last):
File "src/train.py", line 269, in <module>
main(sys.argv[1:])
File "src/train.py", line 263, in main
model = Unet.load_from_checkpoint(checkpoint_callback.best_model_path)
File "/home/africa_wikilimo/miniconda3/envs/xarray_test/lib/python3.8/site-packages/pytorch_lightning/core/saving.py", line 153, in load_from_checkpoint
model = cls._load_model_state(checkpoint, *args, strict=strict, **kwargs)
File "/home/africa_wikilimo/miniconda3/envs/xarray_test/lib/python3.8/site-packages/pytorch_lightning/core/saving.py", line 190, in _load_model_state
model = cls(*cls_args, **cls_kwargs)
File "src/train.py", line 162, in __init__
self.inc = double_conv(self.n_channels, 64)
File "src/train.py", line 122, in double_conv
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
File "/home/africa_wikilimo/miniconda3/envs/xarray_test/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 406, in __init__
super(Conv2d, self).__init__(
File "/home/africa_wikilimo/miniconda3/envs/xarray_test/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 50, in __init__
if in_channels % groups != 0:
TypeError: unsupported operand type(s) for %: 'dict' and 'int'
I have searched the GitHub issues and forums but cannot figure out what the problem is. Please help.
Here is the code for my model and the checkpoint-loading step:
Model:
class Unet(pl.LightningModule):
    """U-Net LightningModule mapping ``n_channels`` input maps to ``n_classes`` outputs.

    The constructor arguments are recorded with ``save_hyperparameters`` so
    that ``Unet.load_from_checkpoint(path)`` can rebuild the model without the
    caller having to pass them again.

    Args:
        n_channels: Number of channels of the input tensor (first conv layer).
        n_classes: Number of output channels produced by the final 1x1 conv.
    """

    def __init__(self, n_channels, n_classes=5):
        super().__init__()
        # Record the constructor arguments in the checkpoint. Without this,
        # load_from_checkpoint() cannot recover them and n_channels ends up
        # bound to a dict, which makes nn.Conv2d fail with
        # "unsupported operand type(s) for %: 'dict' and 'int'".
        self.save_hyperparameters("n_channels", "n_classes")
        self.n_channels = n_channels
        self.n_classes = n_classes
        # NOTE(review): not read anywhere in this class; the `up` blocks below
        # are built with their default bilinear=False (transposed convolutions).
        self.bilinear = True
        # NOTE(review): assigning `self.logger` only works on Lightning versions
        # where `logger` is a plain attribute, and this also creates a W&B logger
        # every time the model is constructed (including load_from_checkpoint).
        # Consider passing the logger to pl.Trainer(logger=...) instead -- TODO
        # confirm against the installed Lightning version.
        self.logger = WandbLogger(name="Adam", project="pytorchlightning")

        def double_conv(in_channels, out_channels):
            """Two (3x3 conv -> BatchNorm -> ReLU) stages; padding=1 keeps H and W."""
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            )

        def down(in_channels, out_channels):
            """2x2 max-pool (halves H and W) followed by a double_conv."""
            return nn.Sequential(
                nn.MaxPool2d(2), double_conv(in_channels, out_channels)
            )

        class up(nn.Module):
            """Upsample ``x1``, pad it to ``x2``'s size, concatenate, then double_conv.

            ``in_channels`` is the channel count *after* concatenation, i.e. the
            upsampled tensor's channels plus the skip tensor's channels.
            """

            def __init__(self, in_channels, out_channels, bilinear=False):
                super().__init__()
                if bilinear:
                    self.up = nn.Upsample(
                        scale_factor=2, mode="bilinear", align_corners=True
                    )
                else:
                    self.up = nn.ConvTranspose2d(
                        in_channels // 2, in_channels // 2, kernel_size=2, stride=2
                    )
                self.conv = double_conv(in_channels, out_channels)

            def forward(self, x1, x2):
                x1 = self.up(x1)
                # Tensors are [?, C, H, W]; pad x1 (left/right, top/bottom) so its
                # spatial size matches the skip connection x2 before concatenating.
                diffY = x2.size()[2] - x1.size()[2]
                diffX = x2.size()[3] - x1.size()[3]
                x1 = F.pad(
                    x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2]
                )
                x = torch.cat([x2, x1], dim=1)
                return self.conv(x)

        # Encoder.
        self.inc = double_conv(self.n_channels, 64)
        self.down1 = down(64, 128)
        self.down2 = down(128, 256)
        self.down3 = down(256, 512)
        self.down4 = down(512, 512)
        # Decoder: each in_channels is (upsampled channels + skip channels),
        # e.g. up1 concatenates 512 (from down4) with 512 (from down3).
        self.up1 = up(1024, 256)
        self.up2 = up(512, 128)
        self.up3 = up(256, 64)
        self.up4 = up(128, 64)
        # 1x1 conv projecting the 64 decoder channels to n_classes outputs.
        self.out = nn.Conv2d(64, self.n_classes, kernel_size=1)

    def forward(self, x):
        """Run the encoder, then the decoder with skip connections; return the 1x1-conv output."""
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        return self.out(x)

    def training_step(self, batch, batch_nb):
        """Compute the MSE loss for one ``(x, y)`` training batch."""
        x, y = batch
        y_hat = self.forward(x)
        loss = self.MSE(y_hat, y)
        return {"loss": loss}

    def training_epoch_end(self, outputs):
        """Average the per-batch training losses and log them as ``train_loss``."""
        avg_train_loss = torch.stack([x["loss"] for x in outputs]).mean()
        self.logger.log_metrics({"train_loss": avg_train_loss})
        return {"average_loss": avg_train_loss}

    def test_step(self, batch, batch_nb):
        """Compute the MSE loss and predictions for one ``(x, y)`` test batch."""
        x, y = batch
        y_hat = self.forward(x)
        loss = self.MSE(y_hat, y)
        return {"test_loss": loss, "pred": y_hat}

    def test_end(self, outputs):
        """Average the per-batch test losses returned by ``test_step``."""
        avg_loss = torch.stack([x["test_loss"] for x in outputs]).mean()
        return {"avg_test_loss": avg_loss}

    def MSE(self, logits, labels):
        """Mean squared error over all elements of ``logits - labels``."""
        return torch.mean((logits - labels) ** 2)

    def configure_optimizers(self):
        """Adam over all parameters (lr=0.1, weight_decay=1e-8)."""
        return torch.optim.Adam(self.parameters(), lr=0.1, weight_decay=1e-8)
Main Function:
def main(expconfig):
    """Train a Unet on the climate dataset, then reload the best checkpoint.

    Args:
        expconfig: Command-line argument list; ``expconfig[0]`` is handed to
            ``Clima_Dataset``.
    """
    # Keep only the single best checkpoint, ranked by the lowest "loss".
    checkpoint_callback = ModelCheckpoint(
        filepath="/home/africa_wikilimo/data/model_checkpoint/",
        save_top_k=1,
        verbose=True,
        monitor="loss",
        mode="min",
        prefix="",
    )

    print("Initializing Climate Dataset....")
    dataset = Clima_Dataset(expconfig[0])

    print("Initializing train_loader....")
    loader = DataLoader(dataset, batch_size=2, num_workers=4)

    print("Initializing model...")
    model = Unet(n_channels=9, n_classes=5)

    print("Initializing Trainer....")
    # Options shared by the CPU and GPU set-ups; the GPU path adds its extras.
    trainer_kwargs = dict(max_epochs=1, checkpoint_callback=checkpoint_callback)
    if torch.cuda.is_available():
        model.cuda()
        trainer_kwargs.update(gpus=1, early_stop_callback=None)
    trainer = pl.Trainer(**trainer_kwargs)

    trainer.fit(model, train_dataloader=loader)
    print(checkpoint_callback.best_model_path)
    # NOTE(review): the reloaded model is not returned or used afterwards.
    model = Unet.load_from_checkpoint(checkpoint_callback.best_model_path)
Cause
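This happens because the model's constructor arguments (`n_channels` and `n_classes`) are not saved explicitly, so they cannot be restored from the checkpoint. When `load_from_checkpoint` re-creates the model (`cls(*cls_args, **cls_kwargs)` in the traceback), `n_channels` ends up holding a `dict` instead of an integer. `nn.Conv2d` then fails with `unsupported operand type(s) for %: 'dict' and 'int'`.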
Fix
You can resolve it by calling
self.save_hyperparameters('n_channels', 'n_classes')
in your Unet class's `__init__` method, after `super().__init__()`. Refer to the PyTorch Lightning hyperparameters documentation for more details on this method. `save_hyperparameters` saves the selected parameters along with the checkpoint (and in hparams.yaml), so `load_from_checkpoint` can pass them back to the constructor. Thanks to @Adrian Wälchli (awaelchli) from the PyTorch Lightning core contributors team, who suggested this fix when I faced the same issue.