pytorch lightning: validation with torchmetrics extremely slow


I recently tried to implement a DeepLabV3 training pipeline using PyTorch Lightning. I wanted to use the built-in torchmetrics.JaccardIndex as my evaluation metric. My LightningModule looks like this:

from timeit import default_timer as timer
from statistics import mean

import torch
import torchmetrics
from torch import nn
from pytorch_lightning import LightningModule
from torchvision.models.segmentation.deeplabv3 import deeplabv3_resnet50


class DeepLabV3LightningModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.model = deeplabv3_resnet50(
            num_classes=38,
            aux_loss=False
        )
        self.loss = nn.CrossEntropyLoss(ignore_index=255, reduction="mean")
        self.iou_metric = torchmetrics.JaccardIndex(
            task="multiclass", 
            threshold=0.5, 
            num_classes=38,
            average="macro",
        )

    def training_step(self, batch, batch_idx):
        imgs, masks = batch
        out = self.model(imgs)
        preds = out["out"]
        loss = self.loss(preds, masks)       
        return loss

    def validation_step(self, batch, batch_idx):
        imgs, masks = batch
        out = self.model(imgs)
        preds = out["out"]
        loss = self.loss(preds, masks)
        preds = torch.softmax(preds, dim=1)
        pred_labels = torch.argmax(preds, dim=1)
        
        # measure runtime of metric update
        start = timer()
        self.iou_metric.update(pred_labels, masks)
        elapsed = timer() - start
        return elapsed

    def validation_epoch_end(self, outputs):
        avg_runtime = round(mean(outputs), 4)
        print(f"GPU {self.local_rank}: {avg_runtime} seconds")

With this validation procedure, validation is extremely slow: on average, a single update() call on the metric takes 23.4 seconds. The first 3 updates are fast (<1 second), but after that every update is slow.
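
One note on the measurement itself: since CUDA kernels are launched asynchronously, a host-side timer around update() can also pick up GPU work that was queued earlier in the step (e.g. the forward pass). The numbers above come from the plain timer shown in validation_step; if needed, that timing block could instead use CUDA events, which only measure what runs between the two events on the GPU stream. This is just a sketch, not what produced the 23.4 s figure:

# Sketch: time only the metric update on the GPU stream with CUDA events
# (uses pred_labels / masks / self.iou_metric from validation_step above).
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

start_event.record()
self.iou_metric.update(pred_labels, masks)
end_event.record()

torch.cuda.synchronize()  # wait for both events before reading the timing
elapsed_ms = start_event.elapsed_time(end_event)  # GPU time in milliseconds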

I tried to reproduce this behavior in an MWE (minimal working example):

from timeit import default_timer as timer
from statistics import mean
import torchmetrics
import torch

num_classes = 38

iou_metric = torchmetrics.JaccardIndex(
    task="multiclass",
    threshold=0.5, 
    num_classes=num_classes,
    average="macro"
).to("cuda")

# dummy ground-truth labels in shape [b, h, w]; randint's high bound is exclusive
label_mask = torch.randint(low=0, high=num_classes, size=(8, 480, 640), device="cuda")

# dummy predicted labels in shape [b, h, w]
pred_mask = torch.randint(low=0, high=num_classes, size=(8, 480, 640), device="cuda")


runtime_hist = []
for i in range(100):
    start = timer()
    iou_metric.update(pred_mask, label_mask)  # update expects (preds, target)
    elapsed = timer() - start
    runtime_hist.append(elapsed)


avg_runtime = round(mean(runtime_hist), 2)
print(avg_runtime)

Here I get an average update duration of 0.03 seconds, so I do not see the extremely slow updates from my LightningModule above. Can someone explain this discrepancy?
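
One caveat about this MWE that I am not sure about: because CUDA calls return asynchronously, the 0.03 s could be mostly kernel-launch overhead rather than the actual cost of the update. A variant of the loop with explicit synchronization would look like the sketch below (this is not the loop that produced the 0.03 s figure):

# Variant of the loop above with explicit synchronization
# (reuses iou_metric, label_mask and pred_mask defined in the MWE).
runtime_hist = []
for _ in range(100):
    torch.cuda.synchronize()  # make sure nothing else is queued before timing
    start = timer()
    iou_metric.update(pred_mask, label_mask)
    torch.cuda.synchronize()  # wait until the update kernels have actually finished
    runtime_hist.append(timer() - start)

print(round(mean(runtime_hist), 4))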

Here is some information about my pytorch-lightning training pipeline:

  • OS: Ubuntu 20.04.4
  • CUDA 11.3
  • DDP training strategy (a rough Trainer sketch follows after this list)
  • GPUs: 4x V100
  • batch size: 8
  • image size (width x height): 640 x 480
  • number of workers in dataloader: 8
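
For completeness, the Trainer is created roughly like the sketch below; only accelerator, devices and strategy reflect the list above, the rest is a placeholder rather than my exact configuration:

from pytorch_lightning import Trainer

# Rough sketch of the Trainer setup; only accelerator/devices/strategy
# match the configuration listed above, everything else is a placeholder.
trainer = Trainer(
    accelerator="gpu",
    devices=4,
    strategy="ddp",
)
trainer.fit(model, datamodule=datamodule)  # model / datamodule defined elsewhere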

My package versions:

  • pytorch lightning: 1.8.4.post0 (installed via pip)
  • torch: 1.13.0
  • torchvision: 0.14.0
  • torchmetrics: 0.11.0
  • numpy: 1.23.5

Thanks so much! Lukas
