Training loss increases after resuming Lightning training with a different dataset


I shuffled my dataset and divided it into several parts. After training on the first subset, I noticed a loss spike at the very beginning of the next subset. I did not expect this, since the subsets do not differ much from each other. The spike occurs every time I switch to the next subset of data (see the image below).

[Image: loss spikes]
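The shuffle-and-split step described above can be sketched with the standard library alone. This is a minimal illustration, assuming the data is addressed by integer index; `split_indices`, the seed and the number of parts are hypothetical and not taken from the original setup. The data is shuffled once, cut into nearly equal chunks, and merging the chunks gives back the full set.

```python
import random


def split_indices(n_items, n_parts, seed=0):
    """Shuffle indices 0..n_items-1 once, then cut them into n_parts nearly equal chunks.

    Hypothetical helper for illustration only.
    """
    indices = list(range(n_items))
    random.Random(seed).shuffle(indices)  # one global shuffle before splitting
    size, extra = divmod(n_items, n_parts)
    parts = []
    start = 0
    for i in range(n_parts):
        end = start + size + (1 if i < extra else 0)  # the first `extra` parts get one more item
        parts.append(indices[start:end])
        start = end
    return parts


parts = split_indices(10, 3)
print([len(p) for p in parts])  # [4, 3, 3]
merged = sorted(i for p in parts for i in p)
print(merged == list(range(10)))  # True
```

With PyTorch, `torch.utils.data.Subset(dataset, part)` wraps one chunk of indices as a dataset, and `torch.utils.data.ConcatDataset` is the usual way to merge several datasets back into one training set.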

The problem still occurred when I scaled training down to a dummy size. Merging all subsets into one big training set made the loss go down smoothly without any spikes. Also, when I loaded the same subset a second time, the loss did not rise. This suggests something is wrong with the Lightning checkpoint. I load my dataset as follows:

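import torch

def load_dataset(file):
    # str.replace substitutes every "_" in file, e.g. "part_0.pt" -> "partsrc0.pt"
    src_file = file.replace("_", "src")
    tgt_file = file.replace("_", "tgt")
    src = torch.load(src_file)
    tgt = torch.load(tgt_file)
    # DummyDataset is a custom Dataset class that is not shown here
    return DummyDataset(src, tgt)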

from lightning.pytorch import LightningDataModule
from torch.utils.data import DataLoader

class TranslatorDataLoader(LightningDataModule):
    def __init__(self, train_dataset, dev_dataset, batch_size, num_workers=2):
        super().__init__()
        self.train_dataset = train_dataset
        self.dev_dataset = dev_dataset
        self.batch_size = batch_size
        self.num_workers = num_workers

    def train_dataloader(self):
        # training batches are reshuffled every epoch
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
            shuffle=True,
        )

    def val_dataloader(self):
        # validation order stays fixed
        return DataLoader(
            self.dev_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
            shuffle=False,
        )

Later, I instantiate the appropriate subset of data and resume training with:

trainer.fit(model, data_loader, ckpt_path=config["checkpoint"])
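`ckpt_path` has to point at the checkpoint written at the end of the previous subset. With `save_top_k`, the checkpoint directory can hold several files, and resuming from an older one would restart from earlier weights and optimizer state. The hypothetical helpers below pick the file with the highest epoch number; they assume Lightning's default `auto_insert_metric_name=True`, which turns the filename template used here into names like `epoch=03-train_loss=0.4321-dev_loss=0.5678.ckpt`. Alternatively, `ModelCheckpoint(save_last=True, ...)` makes Lightning also write a `last.ckpt`, which avoids parsing filenames at all.

```python
import re
from pathlib import Path

# Assumed filename shape: "epoch=NN-train_loss=...-dev_loss=....ckpt"
EPOCH_RE = re.compile(r"epoch=(\d+)")


def pick_latest(names):
    """Return the name with the highest epoch=NN field, or None if no name has one."""
    best_epoch, best_name = -1, None
    for name in names:
        match = EPOCH_RE.search(name)
        if match is None:
            continue  # e.g. "last.ckpt" carries no epoch number
        epoch = int(match.group(1))
        if epoch > best_epoch:
            best_epoch, best_name = epoch, name
    return best_name


def latest_checkpoint(checkpoint_dir):
    """Return the path (as str) of the highest-epoch .ckpt file in checkpoint_dir, or None."""
    names = [p.name for p in Path(checkpoint_dir).glob("*.ckpt")]
    name = pick_latest(names)
    return None if name is None else str(Path(checkpoint_dir) / name)
```

Usage would be `trainer.fit(model, data_loader, ckpt_path=latest_checkpoint(checkpoint_dir))`; a `None` result means no checkpoint was found, and `fit` then starts from scratch.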

My model checkpoint callback looks as follows:

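from lightning.pytorch.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    dirpath=checkpoint_dir,
    filename="{epoch:02d}-{train_loss:.4f}-{dev_loss:.4f}",
    save_top_k=config["save_top_k"],
    # no `monitor` is set, so checkpoints are not ranked by any metric
)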
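Each file this callback writes is a dictionary that `torch.load` can open directly, so what `ckpt_path` restores can also be checked without going through TensorBoard. The sketch below is hypothetical: `summarize_checkpoint` is an illustrative name, and it assumes the usual Lightning checkpoint keys (`epoch`, `global_step`, `optimizer_states`, `lr_schedulers`) plus the standard PyTorch optimizer and scheduler `state_dict()` layouts. It only reads the dictionary, so it can be tried on a hand-built dict as well.

```python
def summarize_checkpoint(ckpt):
    """Collect the training-progress fields a resumed run starts from (hypothetical helper)."""
    summary = {
        "epoch": ckpt.get("epoch"),
        "global_step": ckpt.get("global_step"),
        "optimizer_lrs": [],
        "scheduler_last_epochs": [],
    }
    for opt_state in ckpt.get("optimizer_states", []):
        # a torch optimizer state_dict keeps the current lr in each param group
        lrs = [group.get("lr") for group in opt_state.get("param_groups", [])]
        summary["optimizer_lrs"].append(lrs)
    for sched_state in ckpt.get("lr_schedulers", []):
        # torch LR schedulers count their own steps in "last_epoch"
        summary["scheduler_last_epochs"].append(sched_state.get("last_epoch"))
    return summary
```

For a real file, `torch.load(config["checkpoint"], map_location="cpu")` returns the dictionary to pass in. Comparing `global_step` and the scheduler's `last_epoch` with the values at the end of the previous subset shows whether both continue from where that run stopped.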

I investigated the optimizer state and the learning rate schedulers, but plotting them to TensorBoard revealed nothing suspicious. I am training on Google Colab with both torch and lightning at version 2.0.1. Does anyone know where the problem might be?
