The compute method of metric Multiclass Accuracy was called before the update method


This is my code:

def _train_one_epoch(self, epoch: int) -> None:
    self.model.train()
    train_running_loss = 0.0
    training_progress = 0
    # number of batches in one epoch
    total_batches = (
        len(self.dataset_config.train_data) // self.training_config.batch_size
    )
    for inputs, labels in self.train_loader:
        # progress bar
        training_progress += 1
        print(
            f"Training Progress: [-- {training_progress/total_batches*100:.2f}% --] in epoch: {epoch} from {self.training_config.epochs} epochs",
            end="\r",
        )
        # move the inputs and labels to device
        inputs = inputs.to(self.training_config.device)
        labels = labels.to(self.training_config.device)
        # zero the parameter gradients
        self.optimizer.zero_grad()
        # forward pass
        outputs = self.model(inputs)
        # calculate the loss
        loss = self.training_config.criterion(outputs, labels)
        # backward pass
        loss.backward()
        # update the weights
        self.optimizer.step()
        # Calculate the metrics
        train_running_loss += loss.item() * inputs.size(0)
        preds = torch.argmax(outputs, dim=1)
    train_epoch_loss = train_running_loss / len(self.dataset_config.train_data)

    # Metrics on all training data
    train_accuracy = self.accuracy.compute()
    train_precision = self.precision.compute()
    train_recall = self.recall.compute()
    train_f1_score = self.f1_score.compute()
    self.writer.add_scalar("Training/Loss", train_epoch_loss, epoch)
    self.writer.add_scalar("Training/Accuracy", train_accuracy, epoch)
    self.writer.add_scalar("Training/Precision", train_precision, epoch)
    self.writer.add_scalar("Training/Recall", train_recall, epoch)
    self.writer.add_scalar("Training/F1", train_f1_score, epoch)
    # Reset the metrics
    self.accuracy.reset()
    self.precision.reset()
    self.recall.reset()
    self.f1_score.reset()

After the warning, the accuracy and all the other metric values are printed as zero!

This is the output:

/home/sadegh/miniconda3/envs/ml/lib/python3.11/site-packages/torchmetrics/utilities/prints.py:36: UserWarning: The ``compute`` method of metric MulticlassAccuracy was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.
  warnings.warn(*args, **kwargs)
/home/sadegh/miniconda3/envs/ml/lib/python3.11/site-packages/torchmetrics/utilities/prints.py:36: UserWarning: The ``compute`` method of metric MulticlassPrecision was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.
  warnings.warn(*args, **kwargs)
/home/sadegh/miniconda3/envs/ml/lib/python3.11/site-packages/torchmetrics/utilities/prints.py:36: UserWarning: The ``compute`` method of metric MulticlassRecall was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.
  warnings.warn(*args, **kwargs)
/home/sadegh/miniconda3/envs/ml/lib/python3.11/site-packages/torchmetrics/utilities/prints.py:36: UserWarning: The ``compute`` method of metric MulticlassF1Score was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.
  warnings.warn(*args, **kwargs)
Validation Progress: [-- 9.68% --] in epoch: 8 from 25 epochs Epoch: 8/25 | Valid Loss: 1.9874 | Valid Accuracy: 0.0000 | Valid Precision: 0.0000 | Valid Recall: 0.0000 | Valid F1: 0.0000

I don't understand the warning. I checked the torchmetrics documentation, and its examples don't call the update method to use these metrics. So what is the problem?

Answer by BananaNosh:

According to

https://pytorch.org/torcheval/stable/generated/torcheval.metrics.MulticlassAccuracy.html

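you need to call update on each batch before calling compute. That page documents torcheval, but the torchmetrics classes named in your warning follow the same update/compute/reset pattern. So in your case, I would try: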

# move the inputs and labels to device
inputs = inputs.to(self.training_config.device)
labels = labels.to(self.training_config.device)
# zero the parameter gradients
self.optimizer.zero_grad()
# forward pass
outputs = self.model(inputs)
# accumulate this batch into each metric's state
self.accuracy.update(outputs, labels)
self.precision.update(outputs, labels)
self.recall.update(outputs, labels)
self.f1_score.update(outputs, labels)
...
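To see why every value prints as zero, here is a minimal pure-Python sketch of the stateful update/compute/reset pattern. `RunningAccuracy` is a hypothetical class for illustration, not torchmetrics' implementation: `compute()` only reads state accumulated by `update()`, so calling it with no updates has nothing to report and, in this sketch, falls back to 0.0. Like torchmetrics' multiclass metrics when they receive float scores of shape (N, C), `update()` here takes the argmax over classes, which is why passing `outputs` instead of `preds` works.

```python
from typing import Sequence


class RunningAccuracy:
    """Hypothetical accumulator imitating the update/compute/reset pattern
    (not torchmetrics code)."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        # Clear the accumulated state, e.g. at the end of each epoch.
        self.correct = 0
        self.total = 0

    def update(self, scores: Sequence[Sequence[float]], labels: Sequence[int]) -> None:
        # Turn each row of per-class scores (logits) into a class index via argmax,
        # then add this batch's counts to the running state.
        for row, label in zip(scores, labels):
            pred = max(range(len(row)), key=row.__getitem__)
            self.correct += int(pred == label)
            self.total += 1

    def compute(self) -> float:
        # Reads only the accumulated state; with no updates there is nothing
        # to divide, so this sketch returns 0.0 (like the zeros in the log).
        return self.correct / self.total if self.total else 0.0


acc = RunningAccuracy()
print(acc.compute())  # 0.0: compute() called before any update()

# Two "batches" of logits for a 3-class problem.
acc.update([[2.0, 0.1, 0.3], [0.2, 1.5, 0.1]], [0, 1])  # both correct
acc.update([[0.1, 0.2, 3.0], [1.0, 0.5, 0.2]], [2, 2])  # one of two correct
print(acc.compute())  # 0.75 over the whole "epoch"
acc.reset()  # start the next epoch from empty state
```

The key point is that nothing in the training loop ever feeds the metric objects, so their state stays empty no matter how many batches run.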

Hope this helps.
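As for the documentation: the torchmetrics examples usually call the metric object directly, as in `metric(preds, target)`. That call runs the module's `forward`, which both adds the batch to the accumulated state and returns the value for that batch alone, so the examples never need an explicit `update()`. Calling `self.accuracy(preds, labels)` inside the loop would therefore also work; the problem is only calling `compute()` without either. The sketch below imitates that behavior in plain Python; `CallableAccuracy` and `batch_counts` are hypothetical names, not torchmetrics code.

```python
from typing import Sequence, Tuple


def batch_counts(scores: Sequence[Sequence[float]], labels: Sequence[int]) -> Tuple[int, int]:
    # Hypothetical helper: correct argmax predictions and sample count for one batch.
    correct = sum(
        int(max(range(len(row)), key=row.__getitem__) == label)
        for row, label in zip(scores, labels)
    )
    return correct, len(labels)


class CallableAccuracy:
    """Hypothetical metric whose __call__ imitates the idea behind
    torchmetrics' forward() (not torchmetrics code)."""

    def __init__(self) -> None:
        self.correct = 0
        self.total = 0

    def __call__(self, scores: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
        correct, total = batch_counts(scores, labels)
        # Accumulate into the epoch-level state, as an update() would...
        self.correct += correct
        self.total += total
        # ...and also return the accuracy of this batch alone.
        return correct / total if total else 0.0

    def compute(self) -> float:
        # Epoch-level value over everything accumulated so far.
        return self.correct / self.total if self.total else 0.0


metric = CallableAccuracy()
print(metric([[2.0, 0.1], [0.3, 1.2]], [0, 1]))  # 1.0 for this batch
print(metric([[0.9, 0.1], [0.8, 0.2]], [1, 0]))  # 0.5 for this batch
print(metric.compute())  # 0.75 over both batches, without an explicit update()
```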