I wrote the following class to perform instance segmentation and return the masks of a given class. The code seems to be running randomly and it's not deterministic. The labels printed (as well as the number of labels) change at every execution even if I am running the code on the same input image containing a single person. Is there a problem in how I load the weights? The code is not printing any warning nor exception. Note that I am running the code on the CPU.
import numpy as np
import torch
from torch import Tensor
from torchvision.models.detection import retinanet_resnet50_fpn_v2, RetinaNet_ResNet50_FPN_V2_Weights
import torchvision.transforms as T
import PIL
from PIL import Image
class RetinaNet:
def __init__(self, weights: RetinaNet_ResNet50_FPN_V2_Weights = RetinaNet_ResNet50_FPN_V2_Weights.COCO_V1):
# Load the pre-trained DeepLabV3 model
self.weights = weights
self.model = retinanet_resnet50_fpn_v2(
pretrained=RetinaNet_ResNet50_FPN_V2_Weights
)
self.model.eval()
# Check if a GPU is available and if not, use a CPU
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.model.to(self.device)
# Define the transformation
self.transform = T.Compose([
T.ToTensor(),
])
def infer_on_image(self, image: PIL.Image.Image, label: str) -> Tensor:
# Transform image
input_tensor = self.transform(image)
input_tensor = input_tensor.unsqueeze(0)
input_tensor.to(self.device)
# Run model
with torch.no_grad():
predictions = self.model(input_tensor)
# Post-processing to create masks for requested label
label_index = self.get_label_index(label)
boxes = predictions[0]['boxes'][predictions[0]['labels'] == label_index]
print('labels', predictions[0]['labels']) # random output
masks = torch.zeros((len(boxes), input_tensor.shape[1], input_tensor.shape[2]), dtype=torch.uint8)
for i, box in enumerate(boxes.cpu().numpy()):
x1, y1, x2, y2 = map(int, box)
masks[i, y1:y2, x1:x2] = 1
return masks
def get_label_index(self,label: str) -> int:
return self.weights.value.meta['categories'].index(label)
def get_label(self, label_index: int) -> str:
return self.weights.value.meta['categories'][label_index]
@staticmethod
def load_image(file_path: str) -> PIL.Image.Image:
return Image.open(file_path).convert("RGB")
if __name__ == '__main__':
from matplotlib import pyplot as plt
image_path = 'person.jpg'
# Run inference
retinanet = RetinaNet()
masks = retinanet.infer_on_image(
image=retinanet.load_image(image_path),
label='person'
)
# Plot image
plt.imshow(retinanet.load_image(image_path))
plt.show()
# PLot mask
for i, mask in enumerate(masks):
mask = mask.unsqueeze(2)
plt.title(f'mask {i}')
plt.imshow(mask)
plt.show()
For me, I always implement the below script and reproduce exactly same result, except for using DDP.
At the start point of
__main__
script,DDP dataloader samplers with asynchronous task reproduces the different data augmentation by time. It can be handled with some tricks, but not used on my way.
At the implementation of
dataloader
class,