Torchvision RetinaNet predicts unwanted class background


I want to fine-tune the pretrained RetinaNet from torchvision on my custom dataset, which has 2 classes (not counting background). To do this, I made the following modifications:

import math

import torch
import torchvision

num_classes = 3  # number of object classes + background class

model = torchvision.models.detection.retinanet_resnet50_fpn(pretrained=True)
# replace the classification layer
in_features = model.head.classification_head.conv[0].in_channels
num_anchors = model.head.classification_head.num_anchors
model.head.classification_head.num_classes = num_classes

cls_logits = torch.nn.Conv2d(in_features, num_anchors * num_classes, kernel_size=3, stride=1, padding=1)
torch.nn.init.normal_(cls_logits.weight, std=0.01)  # as per PyTorch code
torch.nn.init.constant_(cls_logits.bias, -math.log((1 - 0.01) / 0.01))  # as per PyTorch code
# assign the new classification layer to the model
model.head.classification_head.cls_logits = cls_logits

The problem is that I get detections for class 0, which is the background, regardless of whether num_classes is 2 or 3.

I tried to understand the source code, but could not find anything like this step from the fasterrcnn roi_heads:

# remove predictions with the background label
boxes = boxes[:, 1:]
scores = scores[:, 1:]
labels = labels[:, 1:]

How can I solve this problem? Any help would be greatly appreciated!
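
For context, my reading of the torchvision source is that the RetinaNet head scores each class with an independent sigmoid and takes the label as the flat index modulo num_classes. If that reading is right, there is no reserved background slot and label 0 is an ordinary object class. The sketch below is a plain-Python illustration of that decoding, not torchvision's actual code. The names decode_topk and shift_labels_to_zero_based are hypothetical, and the label shift assumes the dataset numbers its classes 1..K.

```python
import math


def decode_topk(logits, num_classes, k):
    """Pick the k highest-scoring (anchor, class) pairs.

    `logits` is a flat list of length num_anchors * num_classes, laid out
    anchor by anchor (an assumed layout for this illustration).
    Returns a list of (anchor_index, label, score) tuples.
    """
    # Independent sigmoid per class: no softmax, no background column.
    scores = [1.0 / (1.0 + math.exp(-x)) for x in logits]
    top = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
    # The label is the position within the anchor's class block, so it
    # ranges over 0 .. num_classes - 1 and 0 is a real class.
    return [(i // num_classes, i % num_classes, scores[i]) for i in top]


def shift_labels_to_zero_based(labels):
    """Map dataset labels 1..K to 0..K-1 (assumes labels start at 1)."""
    return [label - 1 for label in labels]


# Two anchors, two classes: anchor 0 favours class 0, anchor 1 favours class 1.
detections = decode_topk([2.0, -3.0, -1.0, 0.5], num_classes=2, k=2)
for anchor, label, score in detections:
    print(f"anchor={anchor} label={label} score={score:.3f}")
```

Under that assumption, a dataset with 2 classes would use num_classes = 2 with target labels 0 and 1, and any detection labelled 0 would be the first object class rather than background.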
