I tried to use Captum with PyTorch Lightning, but I am having issues when passing the LightningModule to Captum, since it seems to reshape the tensors in unexpected ways. For example, in the minimal example below, the Lightning code itself works without problems, but when I use IntegratedGradients with `n_steps >= 1` I get an error. The code of the LightningModule is not that important, I would say; I am more puzzled by the lines at the very bottom.
Does anyone know how to work around this?
from captum.attr import IntegratedGradients
from torch import nn, optim, rand, sum as tsum, reshape, device
import torch.nn.functional as F
from pytorch_lightning import seed_everything, LightningModule, Trainer
from torch.utils.data import DataLoader, Dataset
# Number of features in every sample.
SAMPLE_DIM = 3


class CustomDataset(Dataset):
    """Random samples in [-1, 1) labelled by the sign of their feature sum.

    Args:
        samples: Number of samples to generate.
        device: Device the data is stored on. Defaults to ``"cuda"`` (the
            previously hard-coded behaviour); pass ``"cpu"`` to use the
            dataset on a machine without a GPU.
    """

    def __init__(self, samples=42, device="cuda"):
        # Sample on the CPU generator first and move afterwards, so the
        # drawn values for a given seed do not depend on the target device.
        self.dataset = rand(samples, SAMPLE_DIM).to(device).float() * 2 - 1

    def __getitem__(self, index):
        """Return ``(features, label)``; label is 1.0 iff the features sum to > 0."""
        features = self.dataset[index]
        # The comparison result already lives on the same device as the
        # features, so no extra device transfer is needed.
        label = (tsum(features) > 0).float()
        return (features, label)

    def __len__(self):
        """Return the number of samples."""
        return self.dataset.size(0)
class OurModel(LightningModule):
    """Two stacked linear layers followed by a sigmoid, trained with BCE loss."""

    def __init__(self):
        super().__init__()
        # Network layers (kept in this order so parameter initialisation
        # consumes the random generator exactly as before).
        self.linear = nn.Linear(SAMPLE_DIM, 2048)
        self.linear2 = nn.Linear(2048, 1)
        self.output = nn.Sigmoid()
        # Hyper-parameters, stored as attributes (original note: meant to be
        # auto-tuned using lightning).
        self.batch_size = 512
        self.lr = 0.001

    def forward(self, x):
        """Apply both linear layers and the sigmoid, then flatten to 1-D."""
        hidden = self.linear(x)
        score = self.linear2(hidden)
        return reshape(self.output(score), (-1,))

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters using ``self.lr``."""
        adam = optim.Adam(self.parameters(), lr=self.lr)
        return adam

    def train_dataloader(self):
        """Return a shuffled loader over a fresh 1000-sample ``CustomDataset``."""
        train_set = CustomDataset(samples=1000)
        return DataLoader(train_set, batch_size=self.batch_size, shuffle=True)

    def training_step(self, batch, batch_nb):
        """Compute the binary cross-entropy of one ``(inputs, targets)`` batch.

        Returns a dict with the loss under ``'loss'`` and again under
        ``'log' -> 'train_loss'``. ``batch_nb`` is not used.
        """
        inputs, targets = batch
        predictions = self(inputs)
        bce = F.binary_cross_entropy(predictions, targets)
        return {'loss': bce, 'log': {'train_loss': bce}}
if __name__ == '__main__':
    seed_everything(42)
    # Use a separate name: rebinding ``device`` would overwrite the
    # ``torch.device`` class imported at the top of the file.
    cuda_device = device("cuda")
    model = OurModel().to(cuda_device)
    trainer = Trainer(max_epochs=2, min_epochs=1, auto_lr_find=False,
                      progress_bar_refresh_rate=10)
    trainer.fit(model)

    # Captum treats dimension 0 of every input as the batch dimension and
    # stacks the n_steps interpolated copies along it. A single sample of
    # shape (SAMPLE_DIM,) would therefore be read as SAMPLE_DIM scalar
    # examples, so add an explicit batch dimension: (1, SAMPLE_DIM).
    test_input = CustomDataset(samples=1)[0][0].unsqueeze(0).requires_grad_()
    ig = IntegratedGradients(model)
    # ``forward`` returns one scalar per example (shape (batch,)), so there
    # is no class column to select and ``target`` must be None.
    attr, delta = ig.attribute(test_input, target=None,
                               return_convergence_delta=True)
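The solution was to wrap the forward function. Make sure that the shape going into model.forward() is correct! Captum treats the first dimension of the input as the batch dimension, so a single sample needs an explicit batch dimension (shape (1, SAMPLE_DIM) rather than (SAMPLE_DIM,)).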