Detectron2: Log training and validation loss

I want to train a detectron2 model in AzureML, which lets you log metrics. By default, detectron2 logs several losses (total loss, classification loss, bounding-box regression loss, etc.). However, I do not fully understand which loss this is (training or validation) or how it prevents overfitting (does it keep the weights that achieved the lowest validation loss?). To better understand the training process, I want to log both the training and the validation loss in AzureML, but I am unsure whether I am doing it correctly. I read that one can create a hook (https://github.com/facebookresearch/detectron2/issues/810), although I'm not sure what that entails exactly (I'm new to this). Currently I have something like this:

# After setting up the cfg

import torch
from detectron2.engine import DefaultTrainer, HookBase
from detectron2.data import build_detection_train_loader
import detectron2.utils.comm as comm

class TrainingLoss(HookBase):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg.clone()
        self._loader = iter(build_detection_train_loader(self.cfg))

    def after_step(self):
        data = next(self._loader)
        with torch.no_grad():
            loss_dict = self.trainer.model(data)

            losses = sum(loss_dict.values())
            assert torch.isfinite(losses).all(), loss_dict

            loss_dict_reduced = {"train_" + k: v.item() for k, v in
                                 comm.reduce_dict(loss_dict).items()}
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())
            if comm.is_main_process():
                self.trainer.storage.put_scalars(total_train_loss=losses_reduced,
                                                 **loss_dict_reduced)

            print(f"Training Loss (Iteration {self.trainer.iter}): {losses_reduced}")

class ValidationLoss(HookBase):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg.clone()
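        # Build a train-style loader over the test split so batches include annotations
        # (note: this loader also applies the training-time augmentations)
        self.cfg.DATASETS.TRAIN = cfg.DATASETS.TEST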
        self._loader = iter(build_detection_train_loader(self.cfg))

    def after_step(self):
        data = next(self._loader)
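        # The model is still in training mode, so it returns a dict of losses;
        # no_grad() keeps this extra forward pass out of the backward graph
        with torch.no_grad():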
            loss_dict = self.trainer.model(data)

            losses = sum(loss_dict.values())
            assert torch.isfinite(losses).all(), loss_dict

            loss_dict_reduced = {"val_" + k: v.item() for k, v in
                                 comm.reduce_dict(loss_dict).items()}
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())
            if comm.is_main_process():
                self.trainer.storage.put_scalars(total_val_loss=losses_reduced,
                                                 **loss_dict_reduced)

            print(f"Vali Loss (Iteration {self.trainer.iter}): {losses_reduced}")



trainer = DefaultTrainer(cfg)
val_loss = ValidationLoss(cfg)
train_loss = TrainingLoss(cfg)
trainer.register_hooks([val_loss])
trainer.register_hooks([train_loss])
trainer.resume_or_load(resume=False)
trainer.train()

During training, it prints something like the following:

[Screenshot of the console output during training]

The numbers may be hard to read in the screenshot, but no matter which loss I print (training or validation), it does not match what detectron2 logs by default. For example, the total_loss I calculated for iteration 99 is 1.936, whereas detectron2 logs 2.097. The screenshot shows the training loss I calculated, but when I print the calculated validation loss instead, it is also a bit off.
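A likely part of the explanation (assuming a reasonably recent detectron2 version): the default console line comes from `CommonMetricPrinter`, which prints the median of each loss over the last 20 iterations rather than the value from the current iteration. In addition, the hooks above compute their loss on a different batch than the one the trainer just used, and only after the optimizer has already updated the weights for that step. The plain-Python sketch below illustrates only the windowed-median effect; `WindowedHistory` is a hypothetical class, not detectron2's `HistoryBuffer`, and the loss values are made up.

```python
from collections import deque
from statistics import median


class WindowedHistory:
    """Hypothetical stand-in for a per-metric history that reports a windowed median."""

    def __init__(self, window_size: int = 20):
        # Keep only the most recent `window_size` values
        self._values = deque(maxlen=window_size)

    def update(self, value: float) -> None:
        self._values.append(value)

    def latest(self) -> float:
        return self._values[-1]

    def median(self) -> float:
        return median(self._values)


# Made-up per-iteration losses, purely to show the effect
losses = [2.6, 2.4, 2.3, 2.1, 1.9, 2.2, 1.8, 2.0, 1.6, 1.7]
history = WindowedHistory(window_size=5)
for loss in losses:
    history.update(loss)

print("latest value:", history.latest())
print("median of last 5:", history.median())
```

With a window of 5, the printed value is 1.8 even though the most recent loss was 1.7; detectron2's default printer smooths over 20 iterations in the same way, so a single hand-computed value will rarely match it.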

Does anyone know how to properly log these metrics? How does detectron2 actually calculate the loss? And does it save the weights that achieved the lowest validation loss, or simply the weights at the end of the last iteration?
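On the last point: the losses `DefaultTrainer` logs are training losses, computed on the current training batch. It does not pick weights by validation loss; its `PeriodicCheckpointer` saves a checkpoint every `cfg.SOLVER.CHECKPOINT_PERIOD` iterations and a final `model_final.pth` when training ends, so overfitting is not guarded against automatically. Keeping the lowest-validation-loss weights has to be added (recent detectron2 versions also include `hooks.BestCheckpointer`, which tracks a metric from the event storage; check that your version has it). The selection rule itself is small. Below is a plain-Python sketch where `BestLossTracker` and its `save_fn` callback are hypothetical; in a real hook, `save_fn` could wrap something like `self.trainer.checkpointer.save("model_best")`.

```python
from typing import Callable, Optional


class BestLossTracker:
    """Hypothetical helper: calls save_fn whenever a new lowest validation loss is seen."""

    def __init__(self, save_fn: Callable[[int], None]):
        self.best_loss: Optional[float] = None
        self.best_iter: Optional[int] = None
        self._save_fn = save_fn

    def update(self, iteration: int, val_loss: float) -> bool:
        """Record val_loss; save and return True only if it beats the best so far."""
        if self.best_loss is None or val_loss < self.best_loss:
            self.best_loss = val_loss
            self.best_iter = iteration
            self._save_fn(iteration)
            return True
        return False


saved_at = []
# saved_at.append stands in for a real checkpoint save
tracker = BestLossTracker(save_fn=saved_at.append)
# Made-up (iteration, validation loss) pairs
for it, loss in [(99, 1.20), (199, 1.05), (299, 1.10), (399, 0.98), (499, 1.01)]:
    tracker.update(it, loss)

print(tracker.best_iter, tracker.best_loss)
print(saved_at)
```

Only iterations 99, 199 and 399 trigger a save, so the last saved file always holds the best weights seen so far rather than the final ones.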

Answer by Prayag Pawar:

To log the training and validation losses correctly, you do not need two separate hook classes. Instead, you can modify the existing hooks or write a single hook class that handles both training and validation, selected by a flag:

import torch
from detectron2.engine import DefaultTrainer, HookBase
from detectron2.data import build_detection_train_loader
import detectron2.utils.comm as comm

class LossHook(HookBase):
    def __init__(self, cfg, is_validation=False):
        super().__init__()
        self.cfg = cfg.clone()
        if is_validation:
            # Build a train-style loader over the test split so batches include annotations
            self.cfg.DATASETS.TRAIN = self.cfg.DATASETS.TEST
        self._loader = iter(build_detection_train_loader(self.cfg))
        self.loss_prefix = "val_" if is_validation else "train_"

    def after_step(self):
        data = next(self._loader)
        with torch.no_grad():
            loss_dict = self.trainer.model(data)

            losses = sum(loss_dict.values())
            assert torch.isfinite(losses).all(), loss_dict

            loss_dict_reduced = {self.loss_prefix + k: v.item() for k, v in
                                 comm.reduce_dict(loss_dict).items()}
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())
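            if comm.is_main_process():
                # Prefix the total as well, so it is not mixed into detectron2's own "total_loss" history
                scalars = dict(loss_dict_reduced)
                scalars[self.loss_prefix + "total_loss"] = losses_reduced
                self.trainer.storage.put_scalars(**scalars)

            print(f"{self.loss_prefix.rstrip('_').capitalize()} loss (iteration {self.trainer.iter}): {losses_reduced}")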

#training code
trainer = DefaultTrainer(cfg)
train_loss_hook = LossHook(cfg, is_validation=False)
val_loss_hook = LossHook(cfg, is_validation=True)
trainer.register_hooks([train_loss_hook, val_loss_hook])
trainer.resume_or_load(resume=False)
trainer.train()
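A note on hook order, since the question mentions being new to hooks: a hook is an object whose methods (`before_train`, `before_step`, `after_step`, `after_train`) the trainer calls at fixed points of its loop, in registration order. `DefaultTrainer` registers its own default hooks first, with `PeriodicWriter` (which prints and writes metrics) last, so hooks added later with `register_hooks` run after the writer within each step. Their values can therefore reach the writer one step late, or be skipped by writers that only emit values newer than their last write. The toy model below is NOT detectron2 code; `ToyTrainer`, `ToyWriter` and `ToyLossHook` are hypothetical names that only mimic the call order.

```python
from typing import Dict, List


class ToyHook:
    """Minimal stand-in for a hook: the trainer calls after_step once per iteration."""

    def after_step(self, trainer: "ToyTrainer") -> None:
        pass


class ToyLossHook(ToyHook):
    """Puts a fake, decreasing validation loss into the shared storage at every step."""

    def after_step(self, trainer: "ToyTrainer") -> None:
        trainer.storage["val_loss"] = 1.0 / (trainer.iter + 1)


class ToyWriter(ToyHook):
    """Records what it finds in storage, the way a metrics writer would print it."""

    def __init__(self) -> None:
        self.seen: List[Dict[str, float]] = []

    def after_step(self, trainer: "ToyTrainer") -> None:
        self.seen.append(dict(trainer.storage))


class ToyTrainer:
    def __init__(self, hooks: List[ToyHook]) -> None:
        self.hooks = hooks  # called in list order, like registration order
        self.storage: Dict[str, float] = {}
        self.iter = 0

    def train(self, max_iter: int) -> None:
        for self.iter in range(max_iter):
            # (a real trainer would run one optimization step here)
            for hook in self.hooks:
                hook.after_step(self)


# Writer registered first (like DefaultTrainer's defaults), loss hook appended later
late_writer = ToyWriter()
ToyTrainer([late_writer, ToyLossHook()]).train(max_iter=3)
print(late_writer.seen)  # [{}, {'val_loss': 1.0}, {'val_loss': 0.5}]

# Writer moved to the end: it sees the value computed in the same step
writer_last = ToyWriter()
ToyTrainer([ToyLossHook(), writer_last]).train(max_iter=3)
print(writer_last.seen)
```

In real detectron2 code, the same idea means moving the `PeriodicWriter` to the end of the trainer's hook list after calling `register_hooks`. That list lives in the private attribute `trainer._hooks`, so treat any such reordering as version-dependent.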