I am a complete beginner and I'm trying to use DistilBERT to predict a masked token ([MASK]) for text classification. I want to compare the token DistilBERT predicts for the mask with the labels in the original dataset. The problem is that the prediction tensor for the mask covers the model's whole vocabulary of 30522 entries (classes), so to compare it with the original labels I have to expand the label tensor to 30522 classes before I can calculate the loss. Here is my evaluation code:
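One common way to avoid a 30522-wide label tensor is to go the other way: keep the dataset labels as small class indices and select only a few columns from the [MASK] logits. This assumes each class can be represented by a single vocabulary token (a "label word", sometimes called a verbalizer), which is an assumption about the task rather than something stated above. The token IDs 1111 and 2222 below are hypothetical placeholders; with a real tokenizer they would be looked up, e.g. with `tokenizer.convert_tokens_to_ids`. The logits are fake values standing in for model output.

```python
import torch
import torch.nn.functional as F

VOCAB_SIZE = 30522  # DistilBERT (uncased) vocabulary size
# Hypothetical label words: class 0 -> token ID 1111, class 1 -> token ID 2222
LABEL_WORD_IDS = torch.tensor([1111, 2222])


def class_logits_from_mask(mask_logits: torch.Tensor) -> torch.Tensor:
    """Reduce [batch, VOCAB_SIZE] logits at the [MASK] position to [batch, num_classes]."""
    return mask_logits[:, LABEL_WORD_IDS]


# Fake [MASK]-position logits for a batch of 3 examples
mask_logits = torch.zeros(3, VOCAB_SIZE)
mask_logits[0, 2222] = 4.0  # example 0 scores the class-1 label word highest
mask_logits[1, 1111] = 4.0  # example 1 scores the class-0 label word highest
mask_logits[2, 2222] = 4.0  # example 2 scores the class-1 label word highest
class_labels = torch.tensor([1, 0, 0])  # dataset labels stay as class indices

logits = class_logits_from_mask(mask_logits)  # shape [3, 2]
loss = F.cross_entropy(logits, class_labels)  # integer targets, no one-hot matrix needed
preds = logits.argmax(dim=-1)  # predicted class index per example
accuracy = (preds == class_labels).float().mean().item()
print(tuple(logits.shape), preds.tolist(), round(accuracy, 4), round(loss.item(), 4))
```

Because the labels never leave class-index form, `preds` and `class_labels` can be passed to `accuracy_score` or `f1_score` as 1-D arrays after `.cpu().numpy()`.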
```python
import torch
from sklearn.metrics import accuracy_score, f1_score, precision_recall_fscore_support, recall_score

# model, test_loader and device are defined elsewhere
with torch.no_grad():
    for test_batch in test_loader:
        test_input_ids = test_batch['input_ids'].to(device, dtype=torch.long)
        test_attention_mask = test_batch['attention_mask'].to(device)
        test_label_ids = test_batch['labels'].to(device)
        # Forward pass
        test_outputs = model(test_input_ids, attention_mask=test_attention_mask)
        test_prediction_logits = test_outputs.logits  # [batch, seq_len, 30522]
        # max over dim=1 takes the maximum across all sequence positions, not just the [MASK] one
        test_prediction_logits = test_prediction_logits.max(dim=1).values  # Now shape is [12, 30522]
        # Convert label_ids to a binary format suitable for MultiLabelSoftMarginLoss
        test_num_classes = test_prediction_logits.size(-1)  # find the total number of distinct labels
        labels = torch.zeros(len(test_label_ids), test_num_classes).to(device)  # Initialize the binary label matrix
        for i, label_id in enumerate(test_label_ids):
            labels[i, label_id] = 1  # Set the corresponding label positions to 1
        # Gather the predictions for the masked positions
        test_predicted_probs = torch.sigmoid(test_prediction_logits)  # Convert logits to probabilities
        test_predicted_labels = (test_predicted_probs > 0.5).int()  # Convert probabilities to binary labels
        all_preds = test_predicted_labels.detach().cpu().numpy()
        all_labels = test_predicted_labels.detach().cpu().numpy()  # NOTE: built from the predictions, not from `labels`
        accuracy = accuracy_score(all_labels, all_preds)
        recall = recall_score(all_labels, all_preds, average='macro', zero_division=0)
        f1 = f1_score(all_labels, all_preds, average='macro', zero_division=0)
        precisions, recalls, f1s, _ = precision_recall_fscore_support(all_labels, all_preds, average=None)
```
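The `max(dim=1)` step takes the maximum over every sequence position, so the result no longer refers to the [MASK] token specifically. A minimal sketch of selecting only the [MASK] row instead, assuming `model` is `DistilBertForMaskedLM` (whose `.logits` has shape [batch, seq_len, vocab_size]), exactly one [MASK] per example, and the uncased vocabulary in which [MASK] has ID 103. Random logits and hypothetical target IDs stand in for real model output; `logits_at_mask` is an illustrative helper name.

```python
import torch
import torch.nn.functional as F

MASK_TOKEN_ID = 103  # [MASK] in the BERT/DistilBERT uncased vocabulary (assumes an uncased model)


def logits_at_mask(logits: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
    """Return the [batch, vocab_size] logits at each example's single [MASK] position."""
    is_mask = input_ids == MASK_TOKEN_ID  # [batch, seq_len] bool
    if not bool((is_mask.sum(dim=1) == 1).all()):
        raise ValueError("expected exactly one [MASK] token per example")
    # Boolean indexing keeps one row per True entry, in batch order
    return logits[is_mask]


torch.manual_seed(0)
# 101 = [CLS], 102 = [SEP], 0 = [PAD]; the other IDs are arbitrary
input_ids = torch.tensor([[101, 2023, 103, 1012, 102, 0],
                          [101, 103, 2003, 1037, 1012, 102]])
logits = torch.randn(2, 6, 30522)  # stands in for model(...).logits
target_ids = torch.tensor([2204, 2919])  # hypothetical original tokens at the masks

mask_logits = logits_at_mask(logits, input_ids)  # [2, 30522]
loss = F.cross_entropy(mask_logits, target_ids)  # targets are token IDs, no one-hot
pred_ids = mask_logits.argmax(dim=-1)  # predicted token ID per example
correct = (pred_ids == target_ids).sum().item()
```

If the dataset's `labels` follow the Hugging Face masked-LM convention (original token ID at masked positions, -100 elsewhere), `F.cross_entropy(logits.reshape(-1, vocab_size), labels.reshape(-1), ignore_index=-100)` gives the loss over masked positions without a separate selection step.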
With 30522 classes, the predicted labels and the true labels don't seem to produce meaningful Accuracy, Recall, and F1 scores.
Here is the output:
```
Accuracy: 1.0 Epoch 1/10 | Average Loss: 4.395622090669349e-05, Recall: 6.552650547146321e-05, F1 Score: 6.552650547146321e-05
```
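On the metrics: when scikit-learn receives 2-D 0/1 arrays that are 30522 columns wide, it treats them as multilabel data, and `average='macro'` averages over all 30522 columns. Most columns never contain a positive, so with `zero_division=0` they each contribute 0 (for example, 2 / 30522 ≈ 6.55e-05, which would match only two columns ever being positive). The accuracy of 1.0 follows from `all_labels` being copied from `test_predicted_labels`. Passing 1-D arrays of predicted and true IDs (token IDs or class indices) makes scikit-learn treat the task as multiclass and average only over IDs that actually appear. A sketch, assuming each batch has already been reduced to such 1-D arrays and that results are accumulated over all batches before scoring; `evaluate_ids` is a hypothetical helper name and the IDs are made up.

```python
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_recall_fscore_support, recall_score


def evaluate_ids(batches):
    """Score 1-D integer ID arrays accumulated over all batches.

    batches: iterable of (pred_ids, true_ids) pairs, one pair per batch,
    e.g. (pred_ids.cpu().numpy(), target_ids.cpu().numpy()).
    """
    all_preds, all_labels = [], []
    for pred_ids, true_ids in batches:
        all_preds.append(np.asarray(pred_ids))
        all_labels.append(np.asarray(true_ids))  # true IDs, not the predictions
    y_pred = np.concatenate(all_preds)
    y_true = np.concatenate(all_labels)
    # 1-D integer arrays -> multiclass; macro averages over IDs present in y_true or y_pred
    _, _, f1s, _ = precision_recall_fscore_support(
        y_true, y_pred, average=None, zero_division=0)
    return {
        "accuracy": accuracy_score(y_true, y_pred),
        "recall": recall_score(y_true, y_pred, average="macro", zero_division=0),
        "f1": f1_score(y_true, y_pred, average="macro", zero_division=0),
        "per_class_f1": f1s,
    }


# Two hypothetical batches of predicted vs. true IDs
metrics = evaluate_ids([([5, 7], [5, 8]), ([9], [9])])
print(metrics["accuracy"], metrics["recall"], metrics["f1"])
```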
I'm writing this code myself and just want to get the task working in a simple way. Sorry if this is an unprofessional question, and thanks for any help!