Suppose I have a multi-label classification task where I know that each image always carries a fixed number of labels N. For example, the dataset has 9 classes in total, and these classes fall into three overall categories: category 1 contains 2 subclasses, category 2 contains 3 subclasses, and category 3 contains 4 subclasses. Each image is labelled with exactly 3 classes, one from each category (1, 2, 3). How should I design the loss function?
I have two ideas. One is to use the torch.nn.MultiLabelSoftMarginLoss loss function designed for ordinary multi-label classification tasks (where different samples may have different numbers of labels). The other is to split the loss into 3 sub-losses, one per category, and apply BCE to each sub-loss. For example:
```python
import torch
import torch.nn.functional as F

def custom_loss(output, target):
    loss = 0.0
    num_categories = 3
    category_sizes = [2, 3, 4]
    for i in range(num_categories):
        # Slice out the logit/target columns belonging to category i
        category_start = sum(category_sizes[:i])
        category_end = category_start + category_sizes[i]
        category_output = output[:, category_start:category_end]
        category_target = target[:, category_start:category_end]
        # BCE (with logits) restricted to this category's classes
        category_loss = F.binary_cross_entropy_with_logits(category_output, category_target)
        loss += category_loss
    return loss / num_categories
```
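For comparison, idea 1 would look like the sketch below (the batch size of 4 and the particular target columns are just made-up values for illustration):

```python
import torch

criterion = torch.nn.MultiLabelSoftMarginLoss()

output = torch.randn(4, 9)   # raw logits for a batch of 4 images, 9 classes
target = torch.zeros(4, 9)   # multi-hot targets: exactly one class per category
target[:, 0] = 1             # a class from category 1 (columns 0-1)
target[:, 3] = 1             # a class from category 2 (columns 2-4)
target[:, 6] = 1             # a class from category 3 (columns 5-8)

loss = criterion(output, target)   # scalar: sigmoid + BCE averaged over classes
```

As far as I can tell, MultiLabelSoftMarginLoss computes the same sigmoid-BCE as binary_cross_entropy_with_logits averaged over all 9 classes, so the two ideas differ mainly in weighting: idea 1 weights every class equally, while idea 2 weights every category equally regardless of how many subclasses it contains.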
Which of these two ideas is better? Or is there a better loss-function design for a multi-label classification problem with a fixed number of labels?