I am experimenting with some PyTorch code. I found some interesting results with cross entropy loss, using both PyTorch's binary cross entropy loss and its cross entropy loss.
import torch
import torch.nn as nn
X = torch.tensor([[1,0],[1,0],[0,1],[0,1]],dtype=torch.float)
softmax = nn.Softmax(dim=1)
bce_loss = nn.BCELoss()
ce_loss = nn.CrossEntropyLoss()
pred = softmax(X)
bce_loss(X,X) # tensor(0.)
bce_loss(pred,X) # tensor(0.3133)
bce_loss(pred,pred) # tensor(0.5822)
ce_loss(X,torch.argmax(X,dim=1)) # tensor(0.3133)
I expected the cross entropy loss to be zero when the input and the target are the same. Here X, pred and torch.argmax(X,dim=1) are the same tensor up to some transformations. This reasoning only held for bce_loss(X,X) # tensor(0.), whereas all the others resulted in a loss greater than zero. I expected the outputs of bce_loss(pred,X), bce_loss(pred,pred) and ce_loss(X,torch.argmax(X,dim=1)) to be zero as well.
What is the mistake here?
The reason that you are seeing this is that nn.CrossEntropyLoss accepts logits and targets, i.e. X should contain raw, unnormalized scores, but your X is already between 0 and 1. The logits for the correct class would need to be much bigger, because nn.CrossEntropyLoss applies the softmax internally; it works with logits so it can use the log-sum-exp trick. The way you are currently doing it, after X gets activated each prediction becomes about [0.73, 0.27], so the loss is -log(0.73) ≈ 0.3133 rather than 0.
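A quick way to check this (a minimal sketch reusing your X and ce_loss; the log_softmax/nll_loss decomposition and the 100 * X scaling are illustrations I added):

import torch
import torch.nn as nn
import torch.nn.functional as F

X = torch.tensor([[1, 0], [1, 0], [0, 1], [0, 1]], dtype=torch.float)
targets = torch.argmax(X, dim=1)
ce_loss = nn.CrossEntropyLoss()

# nn.CrossEntropyLoss is log_softmax followed by nll_loss, so the 0.3133 is just
# -log of the softmax probability at the target class: -log(0.7311) ~= 0.3133
print(F.softmax(X, dim=1)[0])                        # tensor([0.7311, 0.2689])
print(ce_loss(X, targets))                           # tensor(0.3133)
print(F.nll_loss(F.log_softmax(X, dim=1), targets))  # tensor(0.3133), same value

# Scaling the logits up pushes the softmax toward one-hot and the loss toward zero
print(ce_loss(100 * X, targets))                     # tensor(0.)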
The binary cross entropy example bce_loss(X,X) works since nn.BCELoss accepts already-activated inputs, i.e. probabilities between 0 and 1, and your X contains exact 0s and 1s. By the way, you probably want to use nn.Sigmoid for activating the inputs to binary cross entropy; for the 2-class example, softmax is also ok.
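Similarly for the binary case: bce_loss(pred,pred) cannot reach zero for soft values, because with the target equal to the prediction BCE reduces to the binary entropy of the prediction, which is only zero at exact 0s and 1s. A small sketch under the same setup (nn.BCEWithLogitsLoss and torch.sigmoid are my additions here, as the logits-based route):

import torch
import torch.nn as nn

X = torch.tensor([[1, 0], [1, 0], [0, 1], [0, 1]], dtype=torch.float)

bce_loss = nn.BCELoss()                   # expects probabilities in [0, 1]
bce_logits_loss = nn.BCEWithLogitsLoss()  # applies sigmoid internally, expects raw logits

pred = torch.softmax(X, dim=1)            # the pred from the question: [0.7311, 0.2689] per row
sig = torch.sigmoid(X)                    # sigmoid(1) ~= 0.7311, sigmoid(0) = 0.5

print(bce_loss(sig, X))                   # tensor(0.5032)
print(bce_logits_loss(X, X))              # tensor(0.5032), same value from the raw logits

# With target == prediction, BCE is the mean binary entropy of the predictions,
# which is why bce_loss(pred, pred) gave 0.5822 instead of 0
print(bce_loss(pred, pred))               # tensor(0.5822)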