How does cross entropy loss work in PyTorch?


I am experimenting with some PyTorch code and found some results with cross entropy loss that I did not expect. I have used both PyTorch's binary cross entropy loss (nn.BCELoss) and cross entropy loss (nn.CrossEntropyLoss).

import torch
import torch.nn as nn

X = torch.tensor([[1,0],[1,0],[0,1],[0,1]],dtype=torch.float)
softmax = nn.Softmax(dim=1)


bce_loss = nn.BCELoss()
ce_loss = nn.CrossEntropyLoss()

pred = softmax(X)

bce_loss(X,X) # tensor(0.)
bce_loss(pred,X) # tensor(0.3133)
bce_loss(pred,pred) # tensor(0.5822)

ce_loss(X,torch.argmax(X,dim=1)) # tensor(0.3133)

I expected the cross entropy loss to be zero when the prediction matches the target, since X, pred and torch.argmax(X, dim=1) carry the same information up to a transformation. That reasoning only held for bce_loss(X, X), which gives tensor(0.); everything else produced a loss greater than zero. I expected bce_loss(pred, X), bce_loss(pred, pred) and ce_loss(X, torch.argmax(X, dim=1)) to be zero as well.

What is the mistake here?


1 Answer

Answered by Hristo Vrigazov:

The reason you are seeing this is that nn.CrossEntropyLoss expects logits and targets, i.e. X should contain raw, unnormalised scores, but your X is already between 0 and 1. The values would have to be much larger, because they are passed through softmax internally and only then squashed into the 0-1 range.

ce_loss(X * 1000, torch.argmax(X,dim=1)) # tensor(0.)

nn.CrossEntropyLoss works with logits so it can use the log-sum-exp trick (it applies log_softmax internally for numerical stability).
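
To make that concrete, here is a small check (not part of the original answer) showing that nn.CrossEntropyLoss on raw scores matches nn.NLLLoss on explicit log-probabilities:

import torch
import torch.nn as nn
import torch.nn.functional as F

X = torch.tensor([[1,0],[1,0],[0,1],[0,1]], dtype=torch.float)
target = torch.argmax(X, dim=1)

# CrossEntropyLoss applies log_softmax to its input internally ...
ce = nn.CrossEntropyLoss()(X, target)                  # tensor(0.3133)
# ... so it matches NLLLoss on explicit log-probabilities
nll = nn.NLLLoss()(F.log_softmax(X, dim=1), target)    # tensor(0.3133)
print(torch.allclose(ce, nll))                         # True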

With the values you are currently passing in, after the internal softmax each prediction row becomes roughly [0.73, 0.27] instead of [1, 0], which is why the loss is 0.3133 rather than zero.
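
You can reproduce the numbers quoted above with a one-row check, just to illustrate the point:

import torch

row = torch.tensor([1.0, 0.0])
probs = torch.softmax(row, dim=0)   # tensor([0.7311, 0.2689])
# the loss contribution is -log(p) at the correct class (index 0)
print(-torch.log(probs[0]))         # tensor(0.3133)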

The binary cross entropy example bce_loss(X, X) works because nn.BCELoss expects already-activated probabilities, and X happens to contain exact 0s and 1s. bce_loss(pred, X) is nonzero because pred is roughly [0.73, 0.27] rather than [1, 0], and bce_loss(pred, pred) is nonzero because it equals the entropy of pred, which is only zero for hard 0/1 probabilities. By the way, you would normally activate binary cross entropy logits with nn.Sigmoid (or use nn.BCEWithLogitsLoss, which applies the sigmoid for you); for the 2-class case softmax is also fine.
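
As a rough sketch of the usual binary setup (the logits and targets below are made up for illustration), nn.BCEWithLogitsLoss applies the sigmoid internally and matches torch.sigmoid followed by nn.BCELoss; the last two lines also reproduce the bce_loss(pred, pred) value from the question as the entropy of pred:

import torch
import torch.nn as nn

logits = torch.tensor([[2.0], [-1.5], [0.3]])   # hypothetical raw scores
targets = torch.tensor([[1.0], [0.0], [1.0]])

# activating manually with sigmoid ...
loss_a = nn.BCELoss()(torch.sigmoid(logits), targets)
# ... is equivalent to BCEWithLogitsLoss, which is also more numerically stable
loss_b = nn.BCEWithLogitsLoss()(logits, targets)
print(torch.allclose(loss_a, loss_b))           # True

# bce_loss(pred, pred) is the entropy of pred: zero only for hard 0/1 values
pred = torch.softmax(torch.tensor([[1.0, 0.0]]), dim=1)
print(nn.BCELoss()(pred, pred))                 # tensor(0.5822)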