I recently tried to implement a DeepLabV3 training pipeline using PyTorch Lightning. I wanted to use the built-in torchmetrics.JaccardIndex as my evaluation metric. My LightningModule looks like this:
from timeit import default_timer as timer
from statistics import mean

import torch
import torch.nn as nn
import torchmetrics
from pytorch_lightning import LightningModule
from torchvision.models.segmentation.deeplabv3 import deeplabv3_resnet50


class DeepLabV3LightningModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.model = deeplabv3_resnet50(num_classes=38, aux_loss=False)
        self.loss = nn.CrossEntropyLoss(ignore_index=255, reduction="mean")
        self.iou_metric = torchmetrics.JaccardIndex(
            task="multiclass",
            threshold=0.5,
            num_classes=38,
            average="macro",
        )

    def training_step(self, batch, batch_idx):
        imgs, masks = batch
        out = self.model(imgs)
        preds = out["out"]  # logits in shape [b, c, h, w]
        loss = self.loss(preds, masks)
        return loss

    def validation_step(self, batch, batch_idx):
        imgs, masks = batch
        out = self.model(imgs)
        preds = out["out"]
        loss = self.loss(preds, masks)  # computed but not logged here
        preds = torch.softmax(preds, dim=1)
        pred_labels = torch.argmax(preds, dim=1)  # hard labels in shape [b, h, w]
        # measure runtime of the metric update
        start = timer()
        self.iou_metric.update(pred_labels, masks)
        elapsed = timer() - start
        return elapsed

    def validation_epoch_end(self, outputs):
        avg_runtime = round(mean(outputs), 4)
        print(f"GPU {self.local_rank}: {avg_runtime} seconds")
With this validation procedure, validation is extremely slow: on average, a single metric update takes 23.4 seconds. However, the first 3 updates are very fast (<1 second); only afterwards do they become slow.
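Side note on how I time this: I do not synchronize CUDA around the timed region above, so the numbers measure whatever work the host happens to block on. A fully synchronized variant of the measurement would look roughly like this (a sketch, assuming the metric state lives on the GPU):

    # inside validation_step: bracket the timed region with explicit syncs
    torch.cuda.synchronize()  # flush all previously queued GPU work
    start = timer()
    self.iou_metric.update(pred_labels, masks)
    torch.cuda.synchronize()  # wait until the update's kernels have finished
    elapsed = timer() - start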
I tried to reproduce this behavior in an MWE:
from timeit import default_timer as timer
from statistics import mean

import torch
import torchmetrics

num_classes = 38
iou_metric = torchmetrics.JaccardIndex(
    task="multiclass",
    threshold=0.5,
    num_classes=num_classes,
    average="macro",
).to("cuda")

# dummy ground-truth labels in shape [b, h, w] (randint's high is exclusive)
label_mask = torch.randint(low=0, high=num_classes, size=(8, 480, 640), device="cuda")
# dummy predicted labels in shape [b, h, w]
pred_mask = torch.randint(low=0, high=num_classes, size=(8, 480, 640), device="cuda")

runtime_hist = []
for i in range(100):
    start = timer()
    # torchmetrics expects update(preds, target)
    iou_metric.update(pred_mask, label_mask)
    elapsed = timer() - start
    runtime_hist.append(elapsed)

avg_runtime = round(mean(runtime_hist), 2)
print(avg_runtime)
Here I get an average update duration of 0.03 seconds, so I do not see the extremely slow updates from my LightningModule above. Can someone help me with an explanation for this?
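As a side note on the API: as far as I understand, a torchmetrics Metric object can also be passed directly to self.log, in which case Lightning takes care of calling compute() and reset() at epoch end. A rough sketch of that variant (not what I run above):

    def validation_step(self, batch, batch_idx):
        imgs, masks = batch
        preds = self.model(imgs)["out"]
        pred_labels = torch.argmax(preds, dim=1)
        self.iou_metric.update(pred_labels, masks)
        # logging the metric object defers compute()/reset() to Lightning
        self.log("val_iou", self.iou_metric, on_step=False, on_epoch=True)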
Here is some information about my pytorch-lightning training pipeline (a sketch of a matching Trainer configuration follows the list):
- OS: Ubuntu 20.04.4
- CUDA 11.3
- DDP training strategy
- GPUs: 4x V100
- batch size: 8
- image size (width x height): 640 x 480
- number of workers in dataloader: 8
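For reference, a Trainer matching the settings above would look roughly like this; this is a sketch, my exact call may differ, and train_loader / val_loader stand in for my DataLoaders (batch_size=8, num_workers=8):

    from pytorch_lightning import Trainer

    # sketch: a Trainer matching the settings listed above
    trainer = Trainer(accelerator="gpu", devices=4, strategy="ddp")
    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)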
My package versions:
- pytorch lightning: 1.8.4.post0 (installed via pip)
- torch: 1.13.0
- torchvision: 0.14.0
- torchmetrics: 0.11.0
- numpy: 1.23.5
Thanks so much! Lukas