Designing a Loss Function for a Multi-Label Classification Problem with a Fixed Number of Labels


Suppose there is a multi-label classification task in which I know that each image must have a fixed number of labels N. For example, suppose a dataset has 9 classes in total, grouped into three overall categories: category 1 contains 2 subclasses, category 2 contains 3 subclasses, and category 3 contains 4 subclasses. Each image is labelled with exactly 3 classes, one from each category. How should I design the loss function?
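
For concreteness, here is a minimal sketch of one target vector under this scheme; the index layout (category 1 at columns 0-1, category 2 at 2-4, category 3 at 5-8) is my own assumption for illustration:

import torch

# Hypothetical multi-hot target for a single image: exactly one active
# subclass per category (which subclass is active is chosen arbitrarily here).
target = torch.tensor([
    1., 0.,           # category 1 (2 subclasses): subclass 0 is active
    0., 1., 0.,       # category 2 (3 subclasses): subclass 1 is active
    0., 0., 1., 0.,   # category 3 (4 subclasses): subclass 2 is active
])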

I have two ideas. One is to use the torch.nn.MultiLabelSoftMarginLoss loss function, which is designed for the general multi-label case where different samples may have different numbers of labels (a usage sketch of this idea is shown after the code below). The other is to split the loss into three sub-losses, one per category, and apply BCE within each category. For example,

import torch
import torch.nn.functional as F

def custom_loss(output, target):
    """Average of per-category BCE-with-logits losses.

    `output` (logits) and `target` (multi-hot labels) both have shape
    (batch_size, 9), with columns ordered as category 1 (2 classes),
    category 2 (3 classes), category 3 (4 classes).
    """
    loss = 0.0
    category_sizes = [2, 3, 4]

    for i in range(len(category_sizes)):
        # Column range occupied by the i-th category
        category_start = sum(category_sizes[:i])
        category_end = category_start + category_sizes[i]

        category_output = output[:, category_start:category_end]
        category_target = target[:, category_start:category_end]

        # BCE with logits restricted to this category's classes
        category_loss = F.binary_cross_entropy_with_logits(category_output, category_target)

        loss += category_loss

    return loss / len(category_sizes)

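For comparison, here is a minimal sketch of the first idea, assuming logits and multi-hot targets of shape (batch_size, 9) in the same column order as above; the tensor names and the chosen target indices are placeholders:

import torch
import torch.nn as nn

# MultiLabelSoftMarginLoss applies a sigmoid to each logit and averages the
# binary cross-entropy over all 9 classes, ignoring the category structure.
criterion = nn.MultiLabelSoftMarginLoss()

logits = torch.randn(4, 9)   # raw model outputs for a batch of 4 images
targets = torch.zeros(4, 9)
targets[:, 0] = 1            # one class from category 1 (columns 0-1)
targets[:, 3] = 1            # one class from category 2 (columns 2-4)
targets[:, 6] = 1            # one class from category 3 (columns 5-8)

loss = criterion(logits, targets)
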
Which of these two ideas is better? Or is there a better loss-function design for a multi-label classification problem with a fixed number of labels?
