How to add a binary loss function to a multi-class segmentation task


I have a multi-class segmentation task (using a DeepLabV3+ model). I am currently using both Dice loss (for segmentation) and cross-entropy loss (for classification), combined into a weighted loss function:

```python
self.optimizer.zero_grad()

# The model returns segmentation logits and classification-head logits
seg_pred, clf_pred = self.model(x)

classes = class_count(y, seg_pred.shape[1])

# Dice loss: measures the overlap between the predicted segmentation mask and the ground-truth mask
sgm_loss = self.loss(seg_pred, y)
# Cross-entropy loss: multi-class classification loss on the classification-head output
clf_loss = self.clf_loss(clf_pred, classes)

# Weighted combination of the two losses
loss = ((1.0 - self.loss_weight) * clf_loss) + (self.loss_weight * sgm_loss)

loss.backward()
self.optimizer.step()
```
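The post does not show how `self.loss` is implemented. For reference, here is one common formulation of a multi-class soft Dice loss, together with the same weighted combination used in the training step. This is a minimal sketch, assuming `seg_pred` holds raw logits of shape (N, C, H, W) and `y` is an integer class mask of shape (N, H, W); `soft_dice_loss` and `weighted_loss` are hypothetical names, not the asker's code.

```python
import torch
import torch.nn.functional as F


def soft_dice_loss(logits: torch.Tensor, target: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Multi-class soft Dice loss (hypothetical stand-in for self.loss).

    logits: (N, C, H, W) raw scores; target: (N, H, W) integer class indices.
    """
    num_classes = logits.shape[1]
    probs = logits.softmax(dim=1)
    # One-hot encode the target and move the class axis next to the batch axis.
    one_hot = F.one_hot(target, num_classes).permute(0, 3, 1, 2).to(probs.dtype)
    dims = (0, 2, 3)  # sum over batch and spatial dims, keeping one score per class
    intersection = (probs * one_hot).sum(dims)
    cardinality = (probs + one_hot).sum(dims)
    dice = (2.0 * intersection + eps) / (cardinality + eps)
    return 1.0 - dice.mean()


def weighted_loss(seg_loss: torch.Tensor, clf_loss: torch.Tensor, loss_weight: float) -> torch.Tensor:
    # Same convex combination as the training step: loss_weight trades segmentation against classification.
    return (1.0 - loss_weight) * clf_loss + loss_weight * seg_loss
```

Because `loss_weight` interpolates between exactly two terms, adding a third term means introducing another weight.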

I want to add a further weighted term to the loss function. My application detects artefacts in images. The purpose of the binary loss would be to tell the model that deciding whether a pixel is an artefact at all matters more than deciding which type of artefact it is. In other words, if the model detects an artefact but gets its type wrong, that is better than labelling the pixel as normal.
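The requirement above can be made concrete with a toy example. This is a minimal sketch, assuming channel 0 of the segmentation output is the "normal" class and the remaining channels are artefact types (the constant `NORMAL_CLASS` and the function `artefact_probability` are hypothetical). Per-pixel cross-entropy only looks at the probability assigned to the true class, so it gives the same penalty to "wrong artefact type" and "called it normal", whereas the total probability of all non-normal classes tells the two mistakes apart.

```python
import torch
import torch.nn.functional as F

NORMAL_CLASS = 0  # hypothetical: assumes channel 0 is "normal", channels 1.. are artefact types


def artefact_probability(seg_logits: torch.Tensor) -> torch.Tensor:
    # P(artefact) = total softmax mass on all non-normal classes = 1 - P(normal)
    return 1.0 - seg_logits.softmax(dim=1)[:, NORMAL_CLASS]


# One pixel whose true label is artefact type 1, with C = 3 classes.
target = torch.tensor([[[1]]])                                  # (N, H, W)
wrong_type = torch.tensor([0.0, 0.0, 5.0]).view(1, 3, 1, 1)     # predicts artefact type 2
called_normal = torch.tensor([5.0, 0.0, 0.0]).view(1, 3, 1, 1)  # predicts normal

# Plain cross-entropy penalises both mistakes identically...
ce_wrong_type = F.cross_entropy(wrong_type, target)
ce_called_normal = F.cross_entropy(called_normal, target)

# ...but the collapsed artefact probability separates them.
p_wrong_type = artefact_probability(wrong_type)        # close to 1
p_called_normal = artefact_probability(called_normal)  # close to 0
```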

I was planning to use torch.nn.BCELoss, but I cannot work out the right (or best) way to convert the multi-class DeepLabV3+ outputs into a binary representation.
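One possible way to get a binary representation, sketched under the same assumption that channel 0 is "normal": sum the softmax probabilities of all artefact classes. In logit space this is log(p_artefact / p_normal) = logsumexp(artefact logits) - normal logit. The sketch swaps `torch.nn.BCELoss` for its logits-based counterpart, `F.binary_cross_entropy_with_logits` (the functional form of `torch.nn.BCEWithLogitsLoss`), which avoids applying a sigmoid and a log as separate steps. `NORMAL_CLASS`, `binary_artefact_logits`, `binary_artefact_loss`, `total_loss` and the default weights are all hypothetical, and the weights would need tuning.

```python
import torch
import torch.nn.functional as F

NORMAL_CLASS = 0  # hypothetical: assumes channel 0 is "normal", all other channels are artefact types


def binary_artefact_logits(seg_logits: torch.Tensor) -> torch.Tensor:
    """(N, C, H, W) multi-class logits -> (N, H, W) logit of "pixel is an artefact".

    logsumexp(artefact logits) - normal logit = log(p_artefact / p_normal),
    so its sigmoid equals 1 - softmax(seg_logits)[:, NORMAL_CLASS].
    """
    normal = seg_logits[:, NORMAL_CLASS]
    # Every channel except NORMAL_CLASS
    artefact = torch.cat(
        [seg_logits[:, :NORMAL_CLASS], seg_logits[:, NORMAL_CLASS + 1:]], dim=1
    )
    return torch.logsumexp(artefact, dim=1) - normal


def binary_artefact_loss(seg_logits: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """BCE between the collapsed prediction and an artefact/normal mask derived from y.

    y: (N, H, W) integer class mask, the same target the Dice loss uses.
    """
    target = (y != NORMAL_CLASS).to(seg_logits.dtype)
    # Numerically stable equivalent of sigmoid followed by BCELoss
    return F.binary_cross_entropy_with_logits(binary_artefact_logits(seg_logits), target)


def total_loss(sgm_loss, clf_loss, bin_loss, w_seg=0.5, w_clf=0.25, w_bin=0.25):
    # Hypothetical weights: raising w_bin penalises artefact-vs-normal mistakes
    # more heavily than confusing one artefact type with another.
    return w_seg * sgm_loss + w_clf * clf_loss + w_bin * bin_loss
```

In the training step this would become something like `bin_loss = binary_artefact_loss(seg_pred, y)` followed by `loss = total_loss(sgm_loss, clf_loss, bin_loss)`. The binary target is derived from the existing mask `y`, so no extra labels are needed, and the term only changes when probability mass moves between "normal" and "any artefact", not when it moves between artefact types.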
