How to add DeepSpeed Activation Checkpointing to an LLM for Fine-Tuning in PyTorch Lightning?


I'm trying to enable activation checkpointing for a T5-3b model to free up a significant amount of GPU memory. However, it's not quite clear how to implement this for an LLM. Based on the PTL docs, it should look something like this:

import deepspeed
import torch
import torch.nn as nn
from lightning.pytorch import LightningModule


class MyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.block_1 = nn.Sequential(nn.Linear(32, 32), nn.ReLU())
        self.block_2 = torch.nn.Linear(32, 2)

    def forward(self, x):
        # Use the DeepSpeed checkpointing function instead of calling the module directly.
        # Checkpointing self.block_1 means its activations are freed after use
        # and recomputed during the backward pass.
        x = deepspeed.checkpointing.checkpoint(self.block_1, x)
        return self.block_2(x)
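
If I'm reading the docs correctly, this is meant to be paired with the DeepSpeed strategy on the Trainer side, roughly like the sketch below (the stage, offloading, and device values are just illustrative, not my actual configuration):

# Rough sketch of the Trainer-side setup; the stage/offload/cpu_checkpointing
# values below are illustrative, not my actual configuration.
from lightning.pytorch import Trainer
from lightning.pytorch.strategies import DeepSpeedStrategy

model = MyModel()
trainer = Trainer(
    accelerator="gpu",
    devices=4,
    strategy=DeepSpeedStrategy(
        stage=3,
        offload_optimizer=True,  # offload optimizer states to CPU
        cpu_checkpointing=True,  # offload checkpointed activations to CPU
    ),
    precision=16,
)
trainer.fit(model)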

Here is my PTL LLM model code. I want to add the DeepSpeed checkpointing, so I've attempted it like this:

import deepspeed
import lightning.pytorch as pl


class T5FineTuner(pl.LightningModule):
    """PyTorch Lightning T5 model class"""

    def __init__(self, hparams, tokenizer, model):
        """Initializes a PyTorch Lightning T5 model"""
        super().__init__()
        self.hparams.update(vars(hparams))
        self.save_hyperparameters(self.hparams)

        self.model = model
        self.tokenizer = tokenizer
        self.outputdir = self.hparams.output_dir
        self.average_training_loss = None
        self.average_validation_loss = None
        self.save_only_last_epoch = self.hparams.save_only_last_epoch

    def forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None):
        # Wrap the underlying forward pass in DeepSpeed activation checkpointing
        return deepspeed.checkpointing.checkpoint(
            self._forward, input_ids, attention_mask, decoder_attention_mask, labels
        )

    def _forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None):
        output = self.model(
            input_ids,
            attention_mask=attention_mask,
            labels=labels,
            decoder_attention_mask=decoder_attention_mask,
        )

        return output.loss, output.logits

    def training_step(self, batch, batch_idx):
        """training step"""
        input_ids = batch["source_text_input_ids"]
        attention_mask = batch["source_text_attention_mask"]
        labels = batch["labels"]
        labels_attention_mask = batch["labels_attention_mask"]

        loss, outputs = self(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=labels_attention_mask,
            labels=labels,
        )

        self.log(
            "train_loss",
            loss,
            prog_bar=True,
            logger=True,
            on_epoch=True,
            on_step=True,
            sync_dist=True,
        )
        return loss
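
For reference, training is launched with a DeepSpeed strategy on the Trainer, roughly like the sketch below (the strategy name, precision, and the `args`/`tokenizer`/`t5_model`/`data_module` objects are simplified placeholders; the real values are in the linked example):

# Rough sketch of how training is launched; the strategy name, precision, and
# the args/tokenizer/t5_model/data_module objects are placeholders, not the
# exact values from the full example.
model = T5FineTuner(hparams=args, tokenizer=tokenizer, model=t5_model)
trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    strategy="deepspeed_stage_2",  # built-in DeepSpeed strategy alias
    precision=16,
    max_epochs=1,
)
trainer.fit(model, datamodule=data_module)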

Unfortunately, I get the following error:

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

The full reproducible example can be found here.
