Using Captum with PyTorch Lightning?


So I tried to use Captum with PyTorch Lightning. I am having issues when passing the module to Captum, since it seems to do some odd reshaping of the tensors. For example, in the minimal example below the Lightning code trains without problems, but as soon as I use IntegratedGradients with "n_steps >= 1" I get an error. The code of the LightningModule is not that important, I would say; I am more interested in the line at the very bottom.

Does anyone know how to work around this?

from captum.attr import IntegratedGradients
from torch import nn, optim, rand, sum as tsum, reshape, device
import torch.nn.functional as F
from pytorch_lightning import seed_everything, LightningModule, Trainer
from torch.utils.data import DataLoader, Dataset

SAMPLE_DIM = 3


class CustomDataset(Dataset):
    def __init__(self, samples=42):
        self.dataset = rand(samples, SAMPLE_DIM).cuda().float() * 2 - 1

    def __getitem__(self, index):
        return (self.dataset[index], (tsum(self.dataset[index]) > 0).cuda().float())

    def __len__(self):
        return self.dataset.size()[0]


class OurModel(LightningModule):
    def __init__(self):
        super(OurModel, self).__init__()
        # Network layers
        self.linear = nn.Linear(SAMPLE_DIM, 2048)
        self.linear2 = nn.Linear(2048, 1)
        self.output = nn.Sigmoid()
        # Hyper-parameters that we will auto-tune using Lightning!
        self.lr = 0.001
        self.batch_size = 512

    def forward(self, x):
        x = self.linear(x)
        x = self.linear2(x)
        output = self.output(x)
        return reshape(output, (-1,))

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.lr)

    def train_dataloader(self):
        loader = DataLoader(CustomDataset(samples=1000), batch_size=self.batch_size, shuffle=True)
        return loader

    def training_step(self, batch, batch_nb):
        x, y = batch
        loss = F.binary_cross_entropy(self(x), y)
        return {'loss': loss, 'log': {'train_loss': loss}}


if __name__ == '__main__':
    seed_everything(42)
    device = device("cuda")
    model = OurModel().to(device)
    trainer = Trainer(max_epochs=2, min_epochs=1, auto_lr_find=False,
                      progress_bar_refresh_rate=10)
    trainer.fit(model)
    # OK, now the problem: test_input is a single sample of shape (SAMPLE_DIM,), i.e. it has no batch dimension
    test_input = CustomDataset(samples=1).__getitem__(0)[0].requires_grad_()
    ig = IntegratedGradients(model)
    attr, delta = ig.attribute(test_input, target=1, return_convergence_delta=True)

1 Answer

Answered by SLuck (best answer):

The solution was to wrap the forward function. Captum concatenates the n_steps scaled inputs along dimension 0, so when the input has no batch dimension the forward function receives a flat 1-D tensor. Make sure the shape going into model.forward() is correct:

# Solution: a wrapper around the model's forward function
import torch

STEP_AMOUNT = 50  # number of integration steps (50 is Captum's default)


def modified_f(in_vec):
    # Captum concatenates the scaled inputs along dim 0, so in_vec arrives
    # as a flat tensor of shape (n_steps * SAMPLE_DIM,)
    print("IN:", in_vec.size())
    x = torch.reshape(in_vec, (in_vec.size(0) // SAMPLE_DIM, SAMPLE_DIM))
    print("x:", x.size())

    res = model.forward(x)
    print("res:", res.size())
    # Give the output an explicit column dimension: (batch, 1)
    res = torch.reshape(res, (res.size(0), 1))
    print("res2:", res.size())

    return res


ig = IntegratedGradients(modified_f)
attr, delta = ig.attribute(test_input, return_convergence_delta=True, n_steps=STEP_AMOUNT)
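
An alternative that avoids the wrapper entirely (a minimal sketch, not from the accepted answer) is to give the test sample an explicit batch dimension before calling attribute; Captum's expansion along dim 0 then keeps the (batch, SAMPLE_DIM) layout, and the model's own forward can be used directly:

# Sketch of an alternative: add a batch dimension instead of wrapping forward
test_input = CustomDataset(samples=1).__getitem__(0)[0].unsqueeze(0).requires_grad_()  # shape (1, SAMPLE_DIM)

ig = IntegratedGradients(model)
# The model returns one value per sample, so no target index is passed
attr, delta = ig.attribute(test_input, return_convergence_delta=True, n_steps=50)

With a (1, SAMPLE_DIM) input, the scaled features Captum builds have shape (n_steps, SAMPLE_DIM), which is exactly what nn.Linear(SAMPLE_DIM, 2048) expects.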