I'm trying to enable activation checkpointing for a T5-3b model in order to free up a significant amount of GPU memory. However, it's not clear to me how to implement this for an LLM. Based on the PyTorch Lightning (PTL) docs, it should look something like this:
import torch
from torch import nn
import deepspeed
from lightning.pytorch import LightningModule, Trainer


class MyModel(LightningModule):
    ...

    def __init__(self):
        super().__init__()
        self.block_1 = nn.Sequential(nn.Linear(32, 32), nn.ReLU())
        self.block_2 = nn.Linear(32, 2)

    def forward(self, x):
        # Use the DeepSpeed checkpointing function instead of calling the module directly.
        # Checkpointing self.block_1 means its activations are freed after the forward pass
        # and recomputed during the backward pass.
        x = deepspeed.checkpointing.checkpoint(self.block_1, x)
        return self.block_2(x)
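As far as I can tell, this is the same pattern as PyTorch's built-in torch.utils.checkpoint utility. A toy version just to illustrate the call shape (not T5-specific, and only my understanding of how it works) would be:

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

block_1 = nn.Sequential(nn.Linear(32, 32), nn.ReLU())
block_2 = nn.Linear(32, 2)

x = torch.randn(8, 32, requires_grad=True)  # float input that participates in autograd

# block_1's activations are not stored; they are recomputed during backward
h = checkpoint(block_1, x, use_reentrant=False)
out = block_2(h).sum()
out.backward()  # gradients flow through the checkpointed block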
Here is my PTL LLM model code. I want to add DeepSpeed checkpointing, so I've attempted it like this:
import deepspeed
import pytorch_lightning as pl


class T5FineTuner(pl.LightningModule):
    """PyTorch Lightning T5 model class"""

    def __init__(self, hparams, tokenizer, model):
        """Initializes a PyTorch Lightning T5 model"""
        super().__init__()
        self.hparams.update(vars(hparams))
        self.save_hyperparameters(self.hparams)
        self.model = model
        self.tokenizer = tokenizer
        self.outputdir = self.hparams.output_dir
        self.average_training_loss = None
        self.average_validation_loss = None
        self.save_only_last_epoch = self.hparams.save_only_last_epoch

    def forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None):
        # Wrap the whole forward pass (including the loss computation) in DeepSpeed checkpointing
        return deepspeed.checkpointing.checkpoint(
            self._forward, input_ids, attention_mask, decoder_attention_mask, labels
        )

    def _forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None):
        output = self.model(
            input_ids,
            attention_mask=attention_mask,
            labels=labels,
            decoder_attention_mask=decoder_attention_mask,
        )
        return output.loss, output.logits

    def training_step(self, batch, batch_idx):
        """Training step"""
        input_ids = batch["source_text_input_ids"]
        attention_mask = batch["source_text_attention_mask"]
        labels = batch["labels"]
        labels_attention_mask = batch["labels_attention_mask"]

        loss, outputs = self(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=labels_attention_mask,
            labels=labels,
        )
        self.log(
            "train_loss",
            loss,
            prog_bar=True,
            logger=True,
            on_epoch=True,
            on_step=True,
            sync_dist=True,
        )
        return loss
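For context, this is roughly how everything is wired up (a simplified sketch with placeholder hparams and a dummy dataloader, not the exact script from the reproducible example linked below):

import argparse

import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset
from transformers import T5ForConditionalGeneration, T5TokenizerFast


class DummyDataset(Dataset):
    """Stand-in dataset that yields the batch keys training_step expects."""

    def __init__(self, tokenizer, n=8):
        src = tokenizer(["summarize: some placeholder input text"] * n,
                        padding="max_length", max_length=32, return_tensors="pt")
        tgt = tokenizer(["placeholder summary"] * n,
                        padding="max_length", max_length=8, return_tensors="pt")
        self.items = [
            {
                "source_text_input_ids": src.input_ids[i],
                "source_text_attention_mask": src.attention_mask[i],
                "labels": tgt.input_ids[i],
                "labels_attention_mask": tgt.attention_mask[i],
            }
            for i in range(n)
        ]

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]


hparams = argparse.Namespace(output_dir="outputs", save_only_last_epoch=False)
tokenizer = T5TokenizerFast.from_pretrained("t5-3b")
model = T5ForConditionalGeneration.from_pretrained("t5-3b")

module = T5FineTuner(hparams, tokenizer, model)
trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    precision=16,
    strategy="deepspeed_stage_2",  # Lightning's built-in DeepSpeed strategy
    max_epochs=1,
)
trainer.fit(module, DataLoader(DummyDataset(tokenizer), batch_size=2))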
Unfortunately, I get the following error:
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
The full reproducible example can be found here.
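In case it helps narrow things down, my (possibly wrong) guess is that this is the usual failure mode of reentrant-style checkpointing when none of the tensors passed to the checkpoint call require grad, which is the case here since input_ids and the attention masks are integer tensors. A minimal standalone sketch of that situation, using plain torch.utils.checkpoint rather than the DeepSpeed version:

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

layer = nn.Linear(4, 4)
x = torch.randn(2, 4)  # requires_grad=False, analogous to integer input_ids / masks

# With reentrant checkpointing, autograd only tracks the tensors passed in
# explicitly; if none of them require grad, the output has no grad_fn.
loss = checkpoint(lambda t: layer(t).sum(), x, use_reentrant=True)
loss.backward()  # raises: element 0 of tensors does not require grad ...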