Predict only one class (person) in YOLACT/YOLACT++


I want to predict only one class, i.e. person, out of all the COCO classes (80 in YOLACT's default config) that the model checks for and predicts.

For reference, YOLACT is here: https://github.com/dbolya/yolact

The results are fine, and I suspect I only need a small change in one of the source files, but I can't figure out where.

There is an issue about this, and I followed its suggestion: I added the following four lines to yolact/layers/output_utils.py and changed nothing else:

# keep only detections of class 0 (person) and class 2 (car)
boxes = torch.cat((boxes[classes==0], boxes[classes==2]), dim=0)
scores = torch.cat((scores[classes==0], scores[classes==2]), dim=0)
masks = torch.cat((masks[classes==0], masks[classes==2]), dim=0)
classes = torch.cat((classes[classes==0], classes[classes==2]), dim=0)
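For reference, here is what those four lines do on small made-up CPU tensors (values and shapes are illustrative only; this does not reproduce the JIT error). They keep class 2 as well, which is car in COCO's class order, and concatenating per class regroups the detections by class rather than by score.

```python
import torch

# Made-up detections, sorted by score as YOLACT returns them.
classes = torch.tensor([2, 0, 5, 0, 2])            # 0 = person, 2 = car, 5 = another class
scores = torch.tensor([0.95, 0.90, 0.80, 0.70, 0.60])
boxes = torch.rand(5, 4)                           # [N, 4] toy boxes
masks = torch.rand(5, 8, 8)                        # [N, H, W] toy masks

# The four lines from the question. `classes` is reassigned last,
# so the earlier lines still index with the original class ids.
boxes = torch.cat((boxes[classes == 0], boxes[classes == 2]), dim=0)
scores = torch.cat((scores[classes == 0], scores[classes == 2]), dim=0)
masks = torch.cat((masks[classes == 0], masks[classes == 2]), dim=0)
classes = torch.cat((classes[classes == 0], classes[classes == 2]), dim=0)

print(classes.tolist())  # [0, 0, 2, 2]: persons first, then cars
```

The scores come out as 0.90, 0.70, 0.95, 0.60, so the result is no longer in descending score order.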

But it gives the following error:

RuntimeError: strides[cur - 1] == sizes[cur] * strides[cur] INTERNAL ASSERT FAILED at 
/opt/conda/conda-bld/pytorch_1573049310284/work/torch/csrc/jit/fuser/executor.cpp:175, 
please report a bug to PyTorch. 
The above operation failed in interpreter, with the following stack trace:

terminate called without an active exception
Aborted (core dumped)

I also tried adding the if condition mentioned in that issue, but I still get the error. I am using PyTorch 1.3.

Accepted answer by kHarshit:

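To output only a single class (person, id 0) at inference time, add

cur_scores[1:] *= 0

right after the line cur_scores = conf_preds[batch_idx, 1:, :] (line 83 of yolact/layers/functions/detection.py). That slice drops the background row, so row 0 of cur_scores is person, and the added line zeroes the scores of every other class before thresholding and NMS.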
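To see the effect of the added line, here is a sketch of the start of the per-image detection step on made-up tensors. It assumes YOLACT's COCO layout (81 score rows: background at index 0, then the 80 classes with person first); the sizes, seed and threshold are illustrative, not taken from the repository.

```python
import torch

# Illustrative sizes (assumption): 81 score rows for COCO (row 0 = background),
# one image and only 6 priors instead of the thousands YOLACT uses.
batch_size, num_classes, num_priors = 1, 81, 6
torch.manual_seed(0)
conf_preds = torch.rand(batch_size, num_classes, num_priors)
conf_thresh = 0.05  # illustrative confidence threshold

batch_idx = 0
cur_scores = conf_preds[batch_idx, 1:, :]  # drop background -> [80, num_priors]; row 0 is person
cur_scores[1:] *= 0                        # the suggested line: zero every class except person

# Roughly what the detection step does next: best score per prior, then thresholding.
conf_scores, best_class = torch.max(cur_scores, dim=0)
keep = conf_scores > conf_thresh

print(best_class.tolist())  # [0, 0, 0, 0, 0, 0]: only person can win now
```

Because cur_scores is a view into conf_preds, the in-place multiplication also changes conf_preds itself.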

Then running

!python eval.py --trained_model=weights/yolact_resnet50_54_800000.pth --score_threshold=0.15 --top_k=15 --image=input_image.png:output_image.png

produces single-class (person-only) output.

As the YOLACT author mentions in issue #218:

you can make the change to save on NMS computation, simply add cur_scores[<everything but your desired class>] *= 0

For the index, if you wanted only person (class 0), you can put 1:, but if you wanted another class than that you'd need to do 2 statements: one with :<class_idx> and the other with <class_idx>+1:. Then when you run eval, run it with --cross_class_nms=True and that'll remove all the other classes from NMS.
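The author's two-statement recipe for a class other than person can be wrapped in a small helper. keep_only_class is a hypothetical name, not part of YOLACT; it assumes cur_scores is the background-free slice from detection.py, so class_idx uses the same 0-based ids as the predicted classes.

```python
import torch

def keep_only_class(cur_scores: torch.Tensor, class_idx: int) -> None:
    """Hypothetical helper: zero every row of cur_scores except class_idx, in place.

    Assumes cur_scores has shape [num_classes - 1, num_priors] with the
    background row already removed (0 = person in COCO).
    """
    cur_scores[:class_idx] *= 0      # every class before the wanted one
    cur_scores[class_idx + 1:] *= 0  # every class after the wanted one

# Toy usage: 80 classes x 4 priors, all scores strictly positive.
scores = torch.rand(80, 4) + 0.1
keep_only_class(scores, 2)
print(torch.nonzero(scores.sum(dim=1)).flatten().tolist())  # [2]
```

With class_idx = 0 the first statement touches an empty slice, so the helper reduces to cur_scores[1:] *= 0.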

Another option is to filter the final detections in yolact/layers/output_utils.py instead, which is the approach the question tried.
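For the output_utils.py route, one way to filter the final detections is to build a single boolean mask and apply it to every tensor. keep_classes is a hypothetical helper; the tensor names follow the question's snippet, and where exactly it belongs inside output_utils.py is an assumption. It is not confirmed that this avoids the RuntimeError from the question on PyTorch 1.3.

```python
import torch

def keep_classes(classes, scores, boxes, masks, wanted=(0,)):
    """Hypothetical helper: keep only detections whose class id is in `wanted`.

    One mask is built from the original class ids and applied to every tensor,
    so the detections keep their original (score-sorted) order.
    """
    keep = torch.zeros_like(classes, dtype=torch.bool)
    for c in wanted:
        keep |= (classes == c)
    return classes[keep], scores[keep], boxes[keep], masks[keep]

# Toy usage with made-up detections: keep person (class 0) only.
classes = torch.tensor([2, 0, 5, 0])
scores = torch.tensor([0.9, 0.8, 0.7, 0.6])
boxes = torch.rand(4, 4)
masks = torch.rand(4, 8, 8)
c, s, b, m = keep_classes(classes, scores, boxes, masks)
print(c.tolist())  # [0, 0]
```

Unlike the torch.cat version in the question, the mask keeps the surviving detections in their original score order; pass wanted=(0, 2) to keep both person and car.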