TorchMetrics MultiClass accuracy for semantic segmentation

Let's use the following example for a semantic segmentation problem with TorchMetrics, where we predict tensors of shape (batch_size, classes, height, width):

import torch

# shape: (1, 3, 2, 2) => (batch_size, classes, height, width)
mask_multiclass_pred = torch.tensor(
    [[
        [
            # predictions for first class per pixel
            [0.85, 0.4],
            [0.4, 0.3],
        ],
        [
            # predictions for second class per pixel
            [0, 0.8],
            [0, 1],
        ],
        [
            # predictions for third class per pixel
            [0.8, 0.6],
            [0.7, 0.3],
        ]
    ]],
    dtype=torch.float32
)

Obviously, if we reduce this to the actual predicted classes as an index tensor:

# take the per-pixel argmax, then mark pixels whose best score is below 0.5 as -1 (no confident class)
reduced_pred = torch.argmax(mask_multiclass_pred, dim=1)
reduced_pred = torch.where(torch.amax(mask_multiclass_pred, dim=1) >= 0.5, reduced_pred, -1)

We get:

# shape: (1, 2, 2) => (batch_size, height, width)
tensor([[[0, 1],
         [2, 1]]])

...for the predictions.

Let's suppose the following is our ground truth for the labels, with shape (batch_size, height, width). The MulticlassAccuracy documentation suggests the targets should be (N, ...), i.e. batch_size plus extra dimensions, which in semantic segmentation are height & width:

# shape: (1, 2, 2) => (batch_size, height, width)
# as suggested by TorchMetrics targets should be (N, ...) where ... is the extra dimensions, in this case 2D => class per pixel
mask_multiclass_gt = torch.tensor(
    [
        [
            # class 0, 1, or 2 per pixel => (2, 2) shape for mask
            [0, 1],
            [0, 2],
        ],
    ],
    dtype=torch.int
)

Now, if we calculate the MulticlassAccuracy:

from torchmetrics.classification import MulticlassAccuracy

seg_acc_cls = MulticlassAccuracy(num_classes=3, top_k=1, average="none", multidim_average="global")
seg_acc_cls(mask_multiclass_pred, mask_multiclass_gt)

We get the following result:

# shape (3,) => one accuracy per class (3 classes)
tensor([0.5000, 1.0000, 0.0000])

Why is this the output?

For example, shouldn't the first class be 0.75 instead of 0.5? Because for the default threshold of 0.5 our reduced predictions for the first class would be:

[0, 1]   =>   [True,  False]
[2, 1]   =>   [False, False]

And obviously then we have 1 TP, 2 TN, and 1 FN. So we should have (1+2)/4?!

Likewise, the second class would be:

[0, 1]   =>   [False, True]
[2, 1]   =>   [False, True]

So again, we have 1 TP, but also 1 FP (lower right), and then 2 TN, which again should be (1 TP + 2TN)/4 = 0.75 and not 1.0.

For the 3rd class we would get these reduced predictions:

[0, 1]   =>   [False, False]
[2, 1]   =>   [True,  False]

Which should be 0 TP (the only ground-truth pixel for the third class is the lower right, which was predicted False), 1 FP (lower left), 1 FN (lower right), and 2 TN, so (0 + 2)/4 => 0.5.
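
For reference, these per-class counts can be tallied with a small sketch like this (reusing reduced_pred and mask_multiclass_gt from above; the loop is just for illustration):

# Tally binary confusion counts (TP/FP/FN/TN) per class from the reduced predictions.
for cls in range(3):
    pred_is_cls = reduced_pred == cls
    gt_is_cls = mask_multiclass_gt == cls
    tp = (pred_is_cls & gt_is_cls).sum().item()
    fp = (pred_is_cls & ~gt_is_cls).sum().item()
    fn = (~pred_is_cls & gt_is_cls).sum().item()
    tn = (~pred_is_cls & ~gt_is_cls).sum().item()
    print(f"class {cls}: TP={tp} FP={fp} FN={fn} TN={tn}")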

1 Answer

Answer by DerekG:

Seems like you're having mostly a definitional issue here. Multiclass classification accuracy (at least as defined in this package) is simply the per-class recall, i.e. TP/(TP+FN). True negatives are not taken into account in the scoring, or else sparse classes would have their accuracy dominated almost entirely by true negatives and would be fairly insensitive to the actual performance (TP and FN). For this metric, false positives do not directly impact accuracy (although, since it is a multiclass and not a multilabel problem, each pixel can have only one class, meaning that a FP in one class indirectly causes a FN in another class, so FPs are still reflected in the score).
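
To make that concrete, here is a minimal sketch (reusing torch and the mask_multiclass_pred / mask_multiclass_gt tensors from the question; the loop is just illustrative, not TorchMetrics internals) that computes TP/(TP+FN) per class by hand and reproduces the reported output:

# Per-class recall = TP / (TP + FN), computed from the argmax of the predictions.
# Reuses mask_multiclass_pred and mask_multiclass_gt from the question.
pred_classes = torch.argmax(mask_multiclass_pred, dim=1)
for cls in range(3):
    gt_is_cls = mask_multiclass_gt == cls
    tp = ((pred_classes == cls) & gt_is_cls).sum()
    fn = ((pred_classes != cls) & gt_is_cls).sum()
    print(f"class {cls}: recall = {(tp / (tp + fn)).item()}")
# prints 0.5, 1.0, 0.0 -- matching MulticlassAccuracy(average="none") above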

Personally, I find these multi-class / multi-label classification tasks, especially on segmentation, to be complex enough and metric definitions variable enough that I generally just re-implement them myself, so I know exactly what I'm calculating.
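
As an example, a minimal self-contained sketch of such a re-implementation might look like the following (the function name per_class_recall and its NaN behaviour for classes absent from the target are my own choices, not from any library):

import torch

def per_class_recall(preds: torch.Tensor, target: torch.Tensor, num_classes: int) -> torch.Tensor:
    # Per-class recall = TP / (TP + FN) for dense (segmentation-style) predictions.
    # preds: (N, C, H, W) scores or (N, H, W) class indices; target: (N, H, W) class indices.
    if preds.ndim == target.ndim + 1:
        preds = preds.argmax(dim=1)
    recalls = []
    for cls in range(num_classes):
        gt_mask = target == cls
        support = gt_mask.sum()
        if support == 0:
            # class not present in the target: recall is undefined, report NaN
            recalls.append(torch.tensor(float("nan")))
        else:
            tp = ((preds == cls) & gt_mask).sum()
            recalls.append(tp / support)
    return torch.stack(recalls)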