I'm running a Yolov8 object detector with TorchServe. In my custom_handler, I'm trying to grab the detection output JSON and also get the image of the annotated bounding boxes.
When I run the code below, I get no errors, but no image is saved. I also tried making random files with Python's basic file IO and it doesn't create those files either.
Is saving the image directly here possible? If not, what's the best practice?
import logging
import os
from collections import Counter
from PIL import Image
import torch
from torchvision import transforms
from ultralytics import YOLO
from ts.torch_handler.object_detector import ObjectDetector
logger = logging.getLogger(__name__)
try:
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
except ImportError as error:
XLA_AVAILABLE = False
class Yolov8Handler(ObjectDetector):
image_processing = transforms.Compose(
[transforms.Resize(640), transforms.CenterCrop(640), transforms.ToTensor()]
)
def __init__(self):
super(Yolov8Handler, self).__init__()
def initialize(self, context):
# Set device type
if torch.cuda.is_available():
self.device = torch.device("cuda")
elif XLA_AVAILABLE:
self.device = xm.xla_device()
else:
self.device = torch.device("cpu")
# Load the model
properties = context.system_properties
self.manifest = context.manifest
model_dir = properties.get("model_dir")
self.model_pt_path = None
if "serializedFile" in self.manifest["model"]:
serialized_file = self.manifest["model"]["serializedFile"]
self.model_pt_path = os.path.join(model_dir, serialized_file)
self.model = self._load_torchscript_model(self.model_pt_path)
logger.debug("Model file %s loaded successfully", self.model_pt_path)
self.initialized = True
def _load_torchscript_model(self, model_pt_path):
"""Loads the PyTorch model and returns the NN model object.
Args:
model_pt_path (str): denotes the path of the model file.
Returns:
(NN Model Object) : Loads the model object.
"""
# TODO: remove this method if https://github.com/pytorch/text/issues/1793 gets resolved
model = YOLO(model_pt_path)
model.to(self.device)
return model
def postprocess(self, res):
output = []
for data in res:
classes = data.boxes.cls.tolist()
names = data.names
# Map to class names
classes = map(lambda cls: names[int(cls)], classes)
# Get a count of objects detected
result = Counter(classes)
output.append(dict(result))
img_array = data.plot()
im = Image.fromarray(img_array[..., ::-1])
im.save('./result.jpg')
f = open("random.txt", "w")
f.write("Save me!")
f.close()
return output
I debugged with a logger and found through os.getcwd() that TorchServe stores the files for the session in a directory inside of /tmp/models/
In my case, the file was stored in /tmp/models/b3c9cda84767441ab93c842245ee2dfb/result.jpg
The path can be specified inside of im.save() to a more appropriate directory