Run onnx model inference with FastAPI

165 views Asked by At

My project is convert a vision transformers to onnx format and use it for image classification. I have a fine tunned model turn to onnx format. Now I'm running a RestAPI with FastAPI, here is my code :

inf = Inference()

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.post('/prediction_multiple_pneumonia', response_model=List[int])
async def detect_pneumonia(images: List[UploadFile]=File(...)):
    """Pneumonia detection on single or multiple images
    Args:
        images (List[UploadFile], optional): List of images as bytes
    Returns:
        pred: List corresponding to images diagnostic (0 : no pneumonia, 1 : pneumonia)
    """
    images_bytes = []
    for image in images :
        images_bytes.append(image.file.read())
    pred = inf.onnx_inferences(images_bytes)
    return pred

Here is my Inference class :

class Inference():

    def __init__(self):
        self.ort_sess = onnxruntime.InferenceSession("./models/vit_model.onnx", providers=['CPUExecutionProvider'])
        self.input_name = self.ort_sess.get_inputs()[0].name
        self.output_name = self.ort_sess.get_outputs()[0].name
    
    def load_onnx_image(self, bytes):
        decoded = BytesIO(bytes)
        image = Image.open(decoded).convert("RGB")
        image = image.resize((224, 224), Image.BILINEAR)
        img = np.array(image, dtype=np.float32)
        img /= 255.0
        img = np.transpose(img, (2, 0, 1))
        return img
    
    def onnx_inferences(self, bytes):
        inputs = np.array([self.load_onnx_image(path) for path in bytes])
        outputs = self.ort_sess.run([self.output_name], {self.input_name: inputs})[0]
        logits = np.array(outputs)
        probabilities = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True)
        predicted_classes = np.argmax(probabilities, axis=1)
        return predicted_classes

But I get this error when sending request to my endpoint :

INFO:     127.0.0.1:50059 - "POST /prediction_multiple_pneumonia HTTP/1.1" 500 Internal Server Error
ERROR:    Exception in ASGI application
Traceback (most recent call last):
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/uvicorn/protocols/http/h11_impl.py", line 407, in run_asgi
    result = await app(  # type: ignore[func-returns-value]
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/uvicorn/middleware/proxy_headers.py", line 78, in __call__
    return await self.app(scope, receive, send)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/fastapi/applications.py", line 270, in __call__
    await super().__call__(scope, receive, send)
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/starlette/applications.py", line 124, in __call__
    await self.middleware_stack(scope, receive, send)
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/starlette/middleware/errors.py", line 184, in __call__
    raise exc
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/starlette/middleware/errors.py", line 162, in __call__
    await self.app(scope, receive, _send)
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/starlette/middleware/cors.py", line 92, in __call__
    await self.simple_response(scope, receive, send, request_headers=headers)
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/starlette/middleware/cors.py", line 147, in simple_response
    await self.app(scope, receive, send)
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/starlette/middleware/exceptions.py", line 79, in __call__
    raise exc
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/starlette/middleware/exceptions.py", line 68, in __call__
    await self.app(scope, receive, sender)
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/fastapi/middleware/asyncexitstack.py", line 21, in __call__
    raise e
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/fastapi/middleware/asyncexitstack.py", line 18, in __call__
    await self.app(scope, receive, send)
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/starlette/routing.py", line 706, in __call__
    await route.handle(scope, receive, send)
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/starlette/routing.py", line 276, in handle
    await self.app(scope, receive, send)
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/starlette/routing.py", line 66, in app
    response = await func(request)
               ^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/fastapi/routing.py", line 255, in app
    content = await serialize_response(
              ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/fastapi/routing.py", line 141, in serialize_response
    raise ValidationError(errors, field.type_)
pydantic.error_wrappers.ValidationError: <exception str() failed>

When I try it outside my FastAPI, it works fine, but with FastAPI I have some troubles... I don't know where is the problem... This code is running when I use a Random Forest model, but when I switch to my onnx model, it doesn't work.

0

There are 0 answers