While training and testing my classification model I ran into the following problem:
The model was trained on a perfectly balanced dataset (binary classification with a 50%/50% class distribution). I got fairly high accuracy on the training set as well as on the val/test sets. However, although training precision and recall are just as high, on the val/test sets these metrics drop significantly.
I got confused, so I decided to calculate these metrics manually from the raw stat scores:
stat_scores = BinaryStatScores().to(device)
tp, fp, tn, fn, _ = stat_scores(y_logits, y_real)
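For reference, the manual calculation from these counts can be sketched roughly as below. The helper name `metrics_from_counts` is hypothetical, and returning 0.0 when a denominator is zero is an assumed convention, not necessarily what was used for the plots.

```python
def metrics_from_counts(tp: float, fp: float, tn: float, fn: float) -> dict:
    """Compute accuracy, precision and recall from confusion-matrix counts."""
    total = tp + fp + tn + fn
    # Assumed convention: a metric with a zero denominator is reported as 0.0.
    accuracy = (tp + tn) / total if total else 0.0
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    return {"accuracy": accuracy, "precision": precision, "recall": recall}


# Made-up counts, only to show the call.
print(metrics_from_counts(tp=40, fp=10, tn=45, fn=5))
```

The values returned by `BinaryStatScores` are tensors, so they can be converted with `.item()` before being passed in.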
Compared to the original metrics, accuracy is pretty similar, but precision and recall differ drastically. The comparison is depicted below (blue: calculated with torchmetrics, orange: calculated manually; the x axis is the epoch).
And here are some comparisons between the calculated metrics in each split:
My code for calculating metrics on the current batch:
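from typing import Tuple

import torch
from torchmetrics.classification import (
    BinaryAccuracy,
    BinaryAUROC,
    BinaryPrecision,
    BinaryRecall,
    BinaryStatScores,
)


def batch_metrics(
    y_real: torch.Tensor,
    y_logits: torch.Tensor,
    device: torch.device,
) -> Tuple[float, float, float, float, float, float, float, float]:
    """Calculates all necessary metrics on a single batch.

    Tracked metrics: accuracy, precision, recall

    Args:
        y_real (torch.Tensor): torch tensor with real classes
        y_logits (torch.Tensor): torch tensor with raw model outputs (logits)
        device (torch.device): A target device to compute on (e.g. "cuda" or "cpu").

    Returns:
        Tuple[float, float, float, float, float, float, float, float]: A tuple of metrics
    """
    thresholds = None  # None makes BinaryAUROC use the exact, non-binned computation
    # Convert logits to probabilities
    y_probs = torch.sigmoid(y_logits)
    # Accuracy
    accuracy_metric = BinaryAccuracy().to(device)
    acc = accuracy_metric(y_probs, y_real).item()
    # Precision
    precision_metric = BinaryPrecision().to(device)
    precision = precision_metric(y_probs, y_real).item()
    # Recall
    recall_metric = BinaryRecall().to(device)
    recall = recall_metric(y_probs, y_real).item()
    # Area Under the Receiver Operating Characteristic Curve
    roc_metric = BinaryAUROC(thresholds=thresholds).to(device)
    auroc = roc_metric(y_probs, y_real).item()
    # Stat scores: tp, fp, tn, fn; the fifth value (support) is ignored
    stat_scores = BinaryStatScores().to(device)
    tp, fp, tn, fn, _ = stat_scores(y_probs, y_real)
    # NOTE: the order of the returned values is an assumption based on the
    # 8-float return annotation; the return statement was cut off in the post.
    return (acc, precision, recall, auroc, tp.item(), fp.item(), tn.item(), fn.item())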
So my question is: which of these values is the ground truth? :) And could you give me some hints about where my mistake is?
I've figured out that I forgot to shuffle the data in my test/val sets, so the per-batch metrics were erroneous...
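To illustrate the effect, here is a minimal pure-Python sketch with made-up labels and predictions (not my real data). It assumes the per-batch values are averaged over batches and that a batch with a zero denominator counts as 0.0; the helper names `counts` and `precision_recall` are hypothetical. With class-sorted data some batches contain no positives at all, so their precision and recall come out as 0 and pull the average down, while metrics computed from the accumulated counts are unaffected.

```python
def counts(y_true, y_pred):
    """Return (tp, fp, fn) for binary labels and hard predictions."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    return tp, fp, fn


def precision_recall(tp, fp, fn):
    # Assumed convention: a zero denominator yields 0.0.
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    return precision, recall


# Made-up, class-sorted (unshuffled) data: 8 negatives followed by 8 positives.
y_true = [0] * 8 + [1] * 8
# One false positive among the negatives, one false negative among the positives.
y_pred = [0, 0, 0, 1, 0, 0, 0, 0] + [1, 1, 1, 1, 1, 0, 1, 1]
batch_size = 4

per_batch = []
total_tp = total_fp = total_fn = 0
for start in range(0, len(y_true), batch_size):
    end = start + batch_size
    tp, fp, fn = counts(y_true[start:end], y_pred[start:end])
    per_batch.append(precision_recall(tp, fp, fn))
    # Accumulate raw counts so the metrics can be computed once at the end.
    total_tp += tp
    total_fp += fp
    total_fn += fn

# Mean of per-batch metrics vs. metrics from the accumulated counts.
avg_precision = sum(p for p, _ in per_batch) / len(per_batch)
avg_recall = sum(r for _, r in per_batch) / len(per_batch)
global_precision, global_recall = precision_recall(total_tp, total_fp, total_fn)

print("mean of per-batch:", avg_precision, avg_recall)
print("from accumulated counts:", global_precision, global_recall)
```

In this toy example accuracy comes out the same either way, because with equal-sized batches every sample carries the same weight; that matches accuracy looking fine while precision and recall did not. Besides shuffling, another option is to keep one torchmetrics metric object per epoch, call `update()` on every batch and `compute()` once at the end, since metric objects accumulate their state across `update()` calls.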