Strange behavior when calculating precision and recall with torchmetrics compared to manually calculated metrics


While training and testing my classification model, I ran into the following problem. The model was trained on a perfectly balanced dataset (binary classification with a 50%/50% class distribution). I got fairly high accuracy on the training set as well as on the val/test sets. However, although training precision and recall reach similarly high values, on the val/test sets these metrics drop significantly.
I got confused, so I decided to calculate these metrics manually using raw stat scores:

...
stat_scores = BinaryStatScores().to(device)  
tp, fp, tn, fn, _ = stat_scores(y_logits, y_real)
...
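
For reference, here is a minimal sketch of how accuracy, precision and recall follow from the raw counts. It is plain Python and assumes scalar counts (for example after calling `.item()` on the values returned by `BinaryStatScores`); the helper name `metrics_from_stat_scores` and returning 0.0 on a zero denominator are assumptions for illustration, not torchmetrics behavior.

```python
def metrics_from_stat_scores(tp: float, fp: float, tn: float, fn: float) -> dict:
    """Hypothetical helper: accuracy, precision and recall from confusion-matrix counts."""
    total = tp + fp + tn + fn
    # Assumed convention: a zero denominator yields 0.0 instead of raising
    accuracy = (tp + tn) / total if total else 0.0
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    return {"accuracy": accuracy, "precision": precision, "recall": recall}


# Made-up counts: 40 TP, 10 FP, 45 TN, 5 FN
m = metrics_from_stat_scores(40, 10, 45, 5)
print(m)  # accuracy 0.85, precision 0.8, recall ~0.889
```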

Compared to the torchmetrics values, the manually calculated accuracy is very similar, but precision and recall differ drastically. The comparison is shown below (blue: calculated with torchmetrics, orange: calculated manually; the x axis is the epoch).
[Figure: torchmetrics (blue) vs. manually calculated (orange) metrics per epoch]
And here are some comparisons of the calculated metrics for each split:
[Figures: three per-split comparisons of the calculated metrics]
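
One plausible explanation for this pattern, sketched under the assumption that the torchmetrics curves are the mean of per-batch values while the manual curves come from counts summed over the epoch: with equal batch sizes, the mean of per-batch accuracies equals the pooled accuracy, but precision and recall have denominators (TP + FP and TP + FN) that change from batch to batch, so their per-batch means do not. The two batches below are made-up counts for a class-sorted, unshuffled split, and the 0.0-on-zero-denominator rule is an assumed convention.

```python
def metrics(tp, fp, tn, fn):
    # Returns (accuracy, precision, recall); 0.0 on a zero denominator (assumed convention)
    total = tp + fp + tn + fn
    acc = (tp + tn) / total if total else 0.0
    prec = tp / (tp + fp) if tp + fp else 0.0
    rec = tp / (tp + fn) if tp + fn else 0.0
    return acc, prec, rec


# Hypothetical (tp, fp, tn, fn) counts for two batches of 8 samples each:
# batch 1 holds only negatives, batch 2 holds only positives.
batches = [(0, 1, 7, 0), (7, 0, 0, 1)]

# Strategy 1: compute metrics per batch, then average them
per_batch = [metrics(*b) for b in batches]
mean_acc, mean_prec, mean_rec = (sum(col) / len(per_batch) for col in zip(*per_batch))

# Strategy 2: sum the counts over all batches, then compute metrics once
pooled_counts = [sum(col) for col in zip(*batches)]
pooled_acc, pooled_prec, pooled_rec = metrics(*pooled_counts)

print(f"averaged: acc={mean_acc:.3f} prec={mean_prec:.3f} rec={mean_rec:.3f}")
print(f"pooled:   acc={pooled_acc:.3f} prec={pooled_prec:.3f} rec={pooled_rec:.3f}")
# averaged: acc=0.875 prec=0.500 rec=0.438
# pooled:   acc=0.875 prec=0.875 rec=0.875
```

Accuracy comes out identical either way, while averaged precision and recall are dragged down by the single-class batch.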

My code for calculating the metrics on the current batch:

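from typing import Tuple

import torch
from torchmetrics.classification import (
    BinaryAccuracy,
    BinaryAUROC,
    BinaryPrecision,
    BinaryRecall,
    BinaryStatScores,
)


def batch_metrics(
    y_real: torch.Tensor,
    y_logits: torch.Tensor,
    device: torch.device,
) -> Tuple[float, float, float, float, float, float, float, float]:
    """Calculates all necessary metrics on a single batch

    Tracked metrics: accuracy, precision, recall, AUROC and raw stat scores
    Args:
        y_real (torch.Tensor): torch tensor with real classes (0 or 1)
        y_logits (torch.Tensor): torch tensor with raw model outputs (logits)
        device (torch.device): A target device to compute on (e.g. "cuda" or "cpu").

    Returns:
        Tuple[float, float, float, float, float, float, float, float]:
            (accuracy, precision, recall, auroc, tp, fp, tn, fn) for this batch
    """
    thresholds = None
    # convert logits to probabilities in [0, 1]
    y_logits = torch.sigmoid(y_logits)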

    # Accuracy
    accuracy_metric = BinaryAccuracy().to(device)
    acc = accuracy_metric(y_logits, y_real).item()

    # Precision
    precision_metric = BinaryPrecision().to(device)
    precision = precision_metric(y_logits, y_real).item()

    # Recall
    recall_metric = BinaryRecall().to(device)
    recall = recall_metric(y_logits, y_real).item()

    #  Area Under the Receiver Operating Characteristic Curve
    roc_metric = BinaryAUROC(thresholds=thresholds).to(device)
    auroc = roc_metric(y_logits, y_real).item()

    # Stat scores
    stat_scores = BinaryStatScores().to(device)
    tp, fp, tn, fn, _ = stat_scores(y_logits, y_real)

    return (
        acc,
        precision,
        recall,
        auroc,
        tp.item(),
        fp.item(),
        tn.item(),
        fn.item(),
    )
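
A minimal sketch of epoch-level aggregation, assuming the 8-tuple returned by `batch_metrics` above: sum the per-batch TP/FP/TN/FN and compute precision and recall once at the end of the epoch. `EpochStatCounter` is a hypothetical helper, not part of torchmetrics. Within torchmetrics itself, the usual equivalent is to create each metric object once, call `metric.update(preds, target)` for every batch, `metric.compute()` at the end of the epoch, and `metric.reset()` before the next one.

```python
from dataclasses import dataclass


@dataclass
class EpochStatCounter:
    """Hypothetical helper that sums confusion-matrix counts over an epoch."""

    tp: float = 0.0
    fp: float = 0.0
    tn: float = 0.0
    fn: float = 0.0

    def update(self, batch_result: tuple) -> None:
        # batch_result is assumed to be the tuple returned by batch_metrics():
        # (acc, precision, recall, auroc, tp, fp, tn, fn); only the counts are used
        *_, tp, fp, tn, fn = batch_result
        self.tp += tp
        self.fp += fp
        self.tn += tn
        self.fn += fn

    def compute(self) -> dict:
        # Metrics from the pooled counts; 0.0 on a zero denominator (assumed convention)
        total = self.tp + self.fp + self.tn + self.fn
        return {
            "accuracy": (self.tp + self.tn) / total if total else 0.0,
            "precision": self.tp / (self.tp + self.fp) if self.tp + self.fp else 0.0,
            "recall": self.tp / (self.tp + self.fn) if self.tp + self.fn else 0.0,
        }


# Usage with two made-up batch results (the first four values are placeholders)
counter = EpochStatCounter()
counter.update((0.0, 0.0, 0.0, 0.0, 0, 1, 7, 0))
counter.update((0.0, 0.0, 0.0, 0.0, 7, 0, 0, 1))
print(counter.compute())
# {'accuracy': 0.875, 'precision': 0.875, 'recall': 0.875}
```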

So my question is: which of these values is the ground truth? :) Any hints about what I'm doing wrong would be appreciated.

Answer by Ikaryssik:

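I figured out that I had forgotten to shuffle the data in my test/val sets, so the per-batch metrics were erroneous: without shuffling, a batch can contain mostly or only one class, which makes per-batch precision and recall unreliable.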
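
A small sketch of why the missing shuffle matters, assuming (hypothetically) a balanced val split stored sorted by class, 500 negatives followed by 500 positives, and a batch size of 32: without shuffling, almost every batch contains a single class, so per-batch precision or recall is computed from one class only. In PyTorch, shuffling is enabled with `DataLoader(dataset, batch_size=32, shuffle=True)`; even with shuffling, summing the counts over the whole epoch before dividing is more robust than averaging per-batch precision and recall.

```python
import random


def count_single_class_batches(labels, batch_size):
    """Hypothetical helper: number of batches whose labels are all the same class."""
    count = 0
    for start in range(0, len(labels), batch_size):
        batch = labels[start:start + batch_size]
        if len(set(batch)) == 1:
            count += 1
    return count


# Hypothetical balanced split stored sorted by class
labels = [0] * 500 + [1] * 500
n_batches = (len(labels) + 32 - 1) // 32  # 32 batches, the last one has 8 samples

print("unshuffled:", count_single_class_batches(labels, 32), "of", n_batches)
# unshuffled: 31 of 32

shuffled = labels[:]
random.Random(0).shuffle(shuffled)  # fixed seed for reproducibility
print("shuffled:", count_single_class_batches(shuffled, 32), "of", n_batches)
```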