PyTorch: CrossEntropyLoss, changing class weight does not change the computed loss

According to the documentation for cross entropy loss, the weighted loss is computed by multiplying each sample's original loss by the weight of its target class.
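
For reference, the per-sample formula in the docs (as I read it) is:

loss(x, class) = weight[class] * (-x[class] + log(sum_j exp(x[j])))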

However, in the PyTorch implementation, the class weight seems to have no effect on the final loss value unless it is set to zero. Following is the code:

from torch import nn
import torch

logits = torch.FloatTensor([
    [0.1, 0.9],
])
label = torch.LongTensor([0])

criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor([1, 1]))
loss = criterion(logits, label)
print(loss.item()) # result: 1.1711

# Change class weight for the first class to 0.1
criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor([0.1, 1]))
loss = criterion(logits, label)
print(loss.item()) # result: 1.1711, should be 0.11711

# Change weight for first class to 0
criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor([0, 1]))
loss = criterion(logits, label)
print(loss.item()) # result: 0

As illustrated in the code, the class weight seems to have no effect unless it is set to 0; this behavior contradicts the documentation.
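
As a sanity check (assuming a PyTorch version that supports the reduction argument), the weight does show up if the averaging step is skipped, so whatever is happening seems to be in the reduction rather than in the per-sample loss:

# Per-sample losses, before any averaging/reduction
criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor([0.1, 1]), reduction='none')
loss = criterion(logits, label)
print(loss)  # expected: tensor([0.1171]) -- the class weight is applied here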

Updates
I implemented a version of weighted cross entropy which is in my eyes the "correct" way to do it.

import torch
from torch import nn

def weighted_cross_entropy(logits, label, weight=None):
    assert len(logits.size()) == 2
    batch_size, label_num = logits.size()
    assert (batch_size == label.size(0))

    if weight is None:
        weight = torch.ones(label_num).float()

    assert (label_num == weight.size(0))

    # -x[class]: pick out the logit of the target class for each sample
    x_terms = -torch.gather(logits, 1, label.unsqueeze(1)).squeeze(1)
    # log(sum_j exp(x[j])): log-sum-exp over all classes
    log_terms = torch.log(torch.sum(torch.exp(logits), dim=1))

    # per-sample weight, looked up from the target class
    weights = torch.gather(weight, 0, label).float()

    return torch.mean((x_terms+log_terms)*weights)

logits = torch.FloatTensor([
    [0.1, 0.9],
    [0.0, 0.1],
])

label = torch.LongTensor([0, 1])

neg_weight = 0.1

weight = torch.FloatTensor([neg_weight, 1])

criterion = nn.CrossEntropyLoss(weight=weight)
loss = criterion(logits, label)

print(loss.item()) # result: 0.69227
print(weighted_cross_entropy(logits, label, weight).item()) # result: 0.38075

What I did was to multiply each instance in the batch by the weight of its target class. The result is still different from the original PyTorch implementation, which makes me wonder how PyTorch actually implements this.
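
For what it's worth, the two numbers agree if I normalize by the sum of the per-sample weights instead of the batch size, i.e. take a weighted mean. This is only my guess at what the built-in reduction does, so the helper below is an assumption, not something taken from the PyTorch source:

# Hypothetical variant: weighted mean, normalized by the sum of the
# gathered per-sample weights instead of the batch size.
def weighted_cross_entropy_weighted_mean(logits, label, weight):
    x_terms = -torch.gather(logits, 1, label.unsqueeze(1)).squeeze(1)
    log_terms = torch.log(torch.sum(torch.exp(logits), dim=1))
    weights = torch.gather(weight, 0, label).float()
    return torch.sum((x_terms + log_terms) * weights) / torch.sum(weights)

print(weighted_cross_entropy_weighted_mean(logits, label, weight).item())
# expected: ~0.69227, matching nn.CrossEntropyLoss above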
