I get the same UAR and accuracy from my validation and training sets. How could this be happening?

I am building a practice model to learn how to use PyTorch. My dataset is imbalanced: one class holds 70% of the labels, so the model ends up guessing that class every time and scores 70% accuracy.
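
(For context, here is the kind of standalone sanity check I had in mind; it is a rough sketch outside my actual pipeline, using the same torchmetrics classes as in the code below. A "model" that only ever predicts the majority class should score about 0.70 on accuracy but only about 1/7 ≈ 0.14 on macro-averaged recall, which is what UAR is.)

import torch
import torchmetrics

# Hypothetical check, not part of my training code:
# 100 labels, 70 of them the majority class 0, and a predictor that always answers 0
labels = torch.cat([torch.zeros(70, dtype=torch.long),
                    torch.randint(1, 7, (30,))])
preds = torch.zeros(100, dtype=torch.long)

acc = torchmetrics.Accuracy(task='multiclass', num_classes=7)
uar = torchmetrics.Recall(task='multiclass', num_classes=7, average='macro')

print(acc(preds, labels))   # ~0.70
print(uar(preds, labels))   # ~0.14, since only class 0 has non-zero recall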

What I can't explain is why the UAR (unweighted average recall, i.e. macro-averaged recall) comes out identical to the accuracy. I have included the relevant code below; hopefully you can spot the issue where I cannot:

import torch
import torchmetrics
from torchmetrics import ConfusionMatrix

def val_epoch(epoch, model, criterion, loader, num_classes, device):
    '''Evaluate the model on the entire validation set.'''
    model.eval()

    # Initialize metrics
    epoch_loss = torchmetrics.MeanMetric().to(device)
    acc_metric = torchmetrics.Accuracy(task='multiclass', num_classes=7).to(device)
    epoch_recall = torchmetrics.Recall(average='macro', task='multiclass', num_classes=7, validate_args=False).to(device)
    # Initialize a confusion matrix torchmetrics object
    cmatrix = ConfusionMatrix(task='multiclass', num_classes=7).to(device)

    with torch.no_grad():
        for inputs, lbls in loader:
            inputs, lbls = inputs.to(device), lbls.to(device)

            # Obtain validation loss
            outputs = model(inputs)
            loss = criterion(outputs, lbls)

            # Accumulate metrics
            epoch_loss(loss)
            acc_metric(outputs, lbls)
            epoch_recall(outputs, lbls)
            # Accumulate confusion matrix
            cmatrix(outputs, lbls)

        # Calculate epoch metrics and store them in a dictionary for wandb
        metrics_dict = {
            'Loss_val': epoch_loss.compute(),
            'Accuracy_val': acc_metric.compute(),
            'UAR_val': epoch_recall.compute(),
        }

        # Compute the confusion matrix
        cm = cmatrix.compute()

        return metrics_dict, cm
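
(In case it helps, this is roughly how I have been inspecting the confusion matrix returned by val_epoch to check whether the predictions collapse onto one class. It is a rough sketch of my debugging, assuming torchmetrics' ConfusionMatrix puts the true labels on the rows.)

metrics_dict, cm = val_epoch(epoch, model, criterion, val_loader, num_classes, device)

# Per-class recall from the confusion matrix: diagonal (true positives) over row sums (class support);
# clamp avoids division by zero for classes with no samples
per_class_recall = cm.diag().float() / cm.sum(dim=1).clamp(min=1)
print(per_class_recall)          # six values near zero would mean majority-only guessing
print(per_class_recall.mean())   # should agree with metrics_dict['UAR_val'] if the averaging is truly macro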
    

from datetime import datetime

import tqdm as tq
import wandb

def train_model(model, train_loader, val_loader, optimizer, criterion, class_names, n_epochs, project_name, ident_str=None):
    num_classes = len(class_names)
    model.to(device)

    # Initialise Weights and Biases (wandb) project
    if ident_str is None:
        ident_str = datetime.now().strftime("%Y%m%d_%H%M%S")
    exp_name = f"{model.__class__.__name__}_{ident_str}"
    run = wandb.init(project=project_name, name=exp_name)

    try:
        # Train by iterating over epochs
        for epoch in tq.tqdm(range(n_epochs), total=n_epochs, desc='Epochs'):
            train_metrics_dict = train_epoch(epoch, model, optimizer, criterion,
                                             train_loader, num_classes, device)

            val_metrics_dict, cm = val_epoch(epoch, model, criterion,
                                             val_loader, num_classes, device)
            wandb.log({**train_metrics_dict, **val_metrics_dict})
    finally:
        run.finish()