My project is convert a vision transformers to onnx format and use it for image classification. I have a fine tunned model turn to onnx format. Now I'm running a RestAPI with FastAPI, here is my code :
inf = Inference()
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.post('/prediction_multiple_pneumonia', response_model=List[int])
async def detect_pneumonia(images: List[UploadFile]=File(...)):
"""Pneumonia detection on single or multiple images
Args:
images (List[UploadFile], optional): List of images as bytes
Returns:
pred: List corresponding to images diagnostic (0 : no pneumonia, 1 : pneumonia)
"""
images_bytes = []
for image in images :
images_bytes.append(image.file.read())
pred = inf.onnx_inferences(images_bytes)
return pred
Here is my Inference class :
class Inference():
def __init__(self):
self.ort_sess = onnxruntime.InferenceSession("./models/vit_model.onnx", providers=['CPUExecutionProvider'])
self.input_name = self.ort_sess.get_inputs()[0].name
self.output_name = self.ort_sess.get_outputs()[0].name
def load_onnx_image(self, bytes):
decoded = BytesIO(bytes)
image = Image.open(decoded).convert("RGB")
image = image.resize((224, 224), Image.BILINEAR)
img = np.array(image, dtype=np.float32)
img /= 255.0
img = np.transpose(img, (2, 0, 1))
return img
def onnx_inferences(self, bytes):
inputs = np.array([self.load_onnx_image(path) for path in bytes])
outputs = self.ort_sess.run([self.output_name], {self.input_name: inputs})[0]
logits = np.array(outputs)
probabilities = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True)
predicted_classes = np.argmax(probabilities, axis=1)
return predicted_classes
But I get this error when sending request to my endpoint :
INFO: 127.0.0.1:50059 - "POST /prediction_multiple_pneumonia HTTP/1.1" 500 Internal Server Error
ERROR: Exception in ASGI application
Traceback (most recent call last):
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/uvicorn/protocols/http/h11_impl.py", line 407, in run_asgi
result = await app( # type: ignore[func-returns-value]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/uvicorn/middleware/proxy_headers.py", line 78, in __call__
return await self.app(scope, receive, send)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/fastapi/applications.py", line 270, in __call__
await super().__call__(scope, receive, send)
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/starlette/applications.py", line 124, in __call__
await self.middleware_stack(scope, receive, send)
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/starlette/middleware/errors.py", line 184, in __call__
raise exc
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/starlette/middleware/errors.py", line 162, in __call__
await self.app(scope, receive, _send)
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/starlette/middleware/cors.py", line 92, in __call__
await self.simple_response(scope, receive, send, request_headers=headers)
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/starlette/middleware/cors.py", line 147, in simple_response
await self.app(scope, receive, send)
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/starlette/middleware/exceptions.py", line 79, in __call__
raise exc
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/starlette/middleware/exceptions.py", line 68, in __call__
await self.app(scope, receive, sender)
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/fastapi/middleware/asyncexitstack.py", line 21, in __call__
raise e
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/fastapi/middleware/asyncexitstack.py", line 18, in __call__
await self.app(scope, receive, send)
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/starlette/routing.py", line 706, in __call__
await route.handle(scope, receive, send)
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/starlette/routing.py", line 276, in handle
await self.app(scope, receive, send)
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/starlette/routing.py", line 66, in app
response = await func(request)
^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/fastapi/routing.py", line 255, in app
content = await serialize_response(
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/fastapi/routing.py", line 141, in serialize_response
raise ValidationError(errors, field.type_)
pydantic.error_wrappers.ValidationError: <exception str() failed>
When I try it outside my FastAPI, it works fine, but with FastAPI I have some troubles... I don't know where is the problem... This code is running when I use a Random Forest model, but when I switch to my onnx model, it doesn't work.