I'm currently working on a YOLOv8-based project for object detection and segmentation. I want to extract binary masks for the objects detected by the YOLOv8 segmentation model, but when I read them via results_seg[0].masks.data I get a tensor that is entirely zeros.
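For reference, here is a stripped-down version of the call that shows the problem (the weights file and video path are just placeholders for my actual files):

import cv2
import torch
from ultralytics import YOLO

model2 = YOLO("yolov8n-seg.pt")        # placeholder name for my segmentation weights
cap = cv2.VideoCapture("video.mp4")    # placeholder video path
success, frame = cap.read()

results_seg = model2(frame)
masks = results_seg[0].masks.data      # expected: (N, H, W) tensor of 0/1 mask values
print(masks.sum())                     # sums to zero -- the whole tensor is zeros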
Here's the full code:
while cap.isOpened():
    # Read a frame from the video
    success, frame = cap.read()
    if success:
        # Run YOLOv8 tracking on the frame, persisting tracks between frames
        results_det = model1.track(frame, persist=True)
        results_seg = model2(frame)

        # Get array results
        masks = results_seg[0].masks.data
        boxes = results_seg[0].boxes.data
        print(masks)

        # Extract classes (column 5 of boxes.data is the class index)
        clss = boxes[:, 5]
        # Get indices of results where class is 0 (object in COCO)
        object_indices = torch.where(clss == 0)
        # Use these indices to extract the relevant masks
        object_masks = masks[object_indices]
        # print(clss)

        # Scale for visualizing results
        object_mask = torch.any(object_masks, dim=0).int() * 255
        object_mask_np = object_mask.cpu().numpy()
        object_mask_np = object_mask_np.astype(np.uint8)  # uint8 so cv2.imshow can render 0/255

        # Visualize the results on the frame
        annotated_frame = results_det[0].plot()
        seg_frame = results_seg[0].plot()

        try:
            # Get the boxes and track IDs
            boxes = results_det[0].boxes.xywh.cpu()
            # print(results_det[0].boxes)
            track_ids = results_det[0].boxes.id.int().cpu().tolist()

            # Plot the tracks
            for box, track_id in zip(boxes, track_ids):
                x, y, w, h = box
                track = track_history[track_id]
                track.append((float(x), float(y)))  # x, y center point
                if len(track) > 30:  # retain the last 30 center points per track
                    track.pop(0)

                # Draw the tracking lines
                points = np.hstack(track).astype(np.int32).reshape((-1, 1, 2))
                cv2.polylines(annotated_frame, [points], isClosed=False, color=(230, 230, 230), thickness=10)
        except:
            # Skip frames where no track IDs are available
            pass

        # Display the annotated frames
        cv2.imshow("YOLOv8 Tracking", annotated_frame)
        cv2.imshow("YOLOv8 segm", seg_frame)
        cv2.imshow("mask segm", object_mask_np)
results_seg[0].plot() works as expected and displays the segmentation results correctly. In the Ultralytics source it is defined as follows:
def plot(self, show_conf=True, line_width=None, font_size=None, font='Arial.ttf', pil=False, example='abc'):
    """
    Plots the detection results on an input RGB image. Accepts a numpy array (cv2) or a PIL Image.

    Args:
        show_conf (bool): Whether to show the detection confidence score.
        line_width (float, optional): The line width of the bounding boxes. If None, it is scaled to the image size.
        font_size (float, optional): The font size of the text. If None, it is scaled to the image size.
        font (str): The font to use for the text.
        pil (bool): Whether to return the image as a PIL Image.
        example (str): An example string to display. Useful for indicating the expected format of the output.

    Returns:
        (None) or (PIL.Image): If `pil` is True, a PIL Image is returned. Otherwise, nothing is returned.
    """
    annotator = Annotator(deepcopy(self.orig_img), line_width, font_size, font, pil, example)
    boxes = self.boxes
    masks = self.masks
    probs = self.probs
    names = self.names
    hide_labels, hide_conf = False, not show_conf

    if boxes is not None:
        for d in reversed(boxes):
            c, conf, id = int(d.cls), float(d.conf), None if d.id is None else int(d.id.item())
            name = ('' if id is None else f'id:{id} ') + names[c]
            label = None if hide_labels else (name if hide_conf else f'{name} {conf:.2f}')
            annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True))

    if masks is not None:
        im = torch.as_tensor(annotator.im, dtype=torch.float16, device=masks.data.device).permute(2, 0, 1).flip(0)
        if TORCHVISION_0_10:
            im = F.resize(im.contiguous(), masks.data.shape[1:], antialias=True) / 255
        else:
            im = F.resize(im.contiguous(), masks.data.shape[1:]) / 255
        annotator.masks(masks.data, colors=[colors(x, True) for x in boxes.cls], im_gpu=im)

    if probs is not None:
        n5 = min(len(names), 5)
        top5i = probs.argsort(0, descending=True)[:n5].tolist()  # top 5 indices
        text = f"{', '.join(f'{names[j] if names else j} {probs[j]:.2f}' for j in top5i)}, "
        annotator.text((32, 32), text, txt_color=(255, 255, 255))  # TODO: allow setting colors

    return np.asarray(annotator.im) if annotator.pil else annotator.im
I found that results_seg[0].boxes.data works properly and returns sensible bounding box information, so the segmentation results themselves seem to be in order; it is only the binary masks that come back empty.
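In case it helps, this is the quick sanity check I run on the two tensors (just a debugging sketch):

masks = results_seg[0].masks.data
boxes = results_seg[0].boxes.data

print("boxes:", boxes.shape)                                # e.g. (N, 6): x1, y1, x2, y2, conf, cls
print("classes:", boxes[:, 5].tolist())                     # class indices look correct
print("masks:", masks.shape, masks.dtype)                   # shape and dtype look reasonable
print("nonzero mask pixels:", int(masks.count_nonzero()))   # comes back 0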
Why am I getting a tensor full of zeros, and how can I fix it?