Get starlette request body in the middleware context

17k views Asked by At

I have such middleware

class RequestContext(BaseHTTPMiddleware):
    """Middleware that tags each request with a UUID and logs request and response."""

    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
        """Log the incoming request, call the next handler, and log the status code.

        Returns the downstream response with an ``X-Request-ID`` header added.
        """
        # NOTE: despite its name, `request_id` holds whatever request_ctx.set()
        # returns; it is only used for the request_ctx.reset() call below.
        request_id = request_ctx.set(str(uuid4()))  # generate uuid to request
        # The whole request body is awaited here, before the downstream app runs.
        body = await request.body()
        if body:
            logger.info(...)  # log request with body
        else:
            logger.info(...)  # log request without body

        response = await call_next(request)
        # Echo the value stored in request_ctx back to the client.
        response.headers['X-Request-ID'] = request_ctx.get()
        logger.info("%s" % (response.status_code))
        # Undo the request_ctx.set() made at the top of this call.
        request_ctx.reset(request_id)

        return response

So the line body = await request.body() freezes all requests that have body and I have 504 from all of them. How can I safely read the request body in this context? I just want to log request parameters.

7

There are 7 answers

3
Yagiz Degirmenci On BEST ANSWER

I would not create a middleware that inherits from BaseHTTPMiddleware, since it has some known issues. FastAPI gives you an opportunity to create your own routers with a custom route class; in my experience this approach works much better.

from fastapi import APIRouter, FastAPI, Request, Response, Body
from fastapi.routing import APIRoute

from typing import Callable, List
from uuid import uuid4


class ContextIncludedRoute(APIRoute):
    """Route class that prints the request body and tags the response with an ID."""

    def get_route_handler(self) -> Callable:
        """Wrap the stock route handler with body printing and ID tagging."""
        inner_handler = super().get_route_handler()

        async def custom_route_handler(request: Request) -> Response:
            new_id = str(uuid4())
            result: Response = await inner_handler(request)

            # The body is read after the endpoint has already run.
            payload = await request.body()
            if payload:
                print(payload)

            result.headers["Request-ID"] = new_id
            return result

        return custom_route_handler


# Routes registered on `router` are built with ContextIncludedRoute because it
# is passed as `route_class`.
app = FastAPI()
router = APIRouter(route_class=ContextIncludedRoute)


@router.post("/context")
async def non_default_router(bod: List[str] = Body(...)):
    """Return the list of strings received in the request body unchanged."""
    return bod


# Attach the router's routes to the application.
app.include_router(router)

Works as expected.

b'["string"]'
INFO:     127.0.0.1:49784 - "POST /context HTTP/1.1" 200 OK
0
SteveTheProgrammer On

In case you still wanted to use BaseHTTP, I recently ran into this problem and came up with a solution:

Middleware Code

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
import json
from .async_iterator_wrapper import async_iterator_wrapper as aiwrap

class some_middleware(BaseHTTPMiddleware):
    """Middleware that captures the response body so it can be logged."""

    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
        """Run the downstream app, capture its response body, and return the response.

        The response body is an async iterator that can only be consumed once,
        so after reading it the iterator is replaced with a fresh one over the
        same chunks before the response is handed back.
        """
        # --------------------------------
        # DO WHATEVER YOU NEED TO DO HERE
        # --------------------------------

        response = await call_next(request)

        # Consuming the response drains its body iterator ...
        sections = [section async for section in response.body_iterator]
        # ... so repair the response with a new async iterator over the chunks.
        response.body_iterator = aiwrap(sections)

        # Formatting response body for logging: join every chunk, not only the
        # first one, so JSON split across several chunks still parses.
        try:
            resp_body = json.loads(b"".join(sections).decode())
        except (TypeError, ValueError):
            # Not bytes, not UTF-8, or not JSON (this includes an empty body):
            # fall back to the raw representation of the chunks.
            resp_body = str(sections)

        # Log `resp_body` here with the logger of your choice.

        # The response must be returned, otherwise the client never gets one.
        return response

async_iterator_wrapper code (taken from the question "TypeError from Python 3 async for loop"):

class async_iterator_wrapper:
    """Expose a plain synchronous iterable through the async-iterator protocol."""

    def __init__(self, obj):
        # Obtain the iterator once; every __anext__ call advances this same one.
        self._it = iter(obj)

    def __aiter__(self):
        # The wrapper acts as its own async iterator.
        return self

    async def __anext__(self):
        """Return the next item, translating exhaustion into StopAsyncIteration."""
        try:
            item = next(self._it)
        except StopIteration:
            raise StopAsyncIteration
        else:
            return item

I really hope this can help someone! I found this very helpful for logging.

Big thanks to @Eddified for the aiwrap class

0
LoveToCode On

You can do this safely with a generic ASGI middleware:

from typing import Iterable, List, Protocol, Generator

import pytest

from starlette.responses import Response
from starlette.testclient import TestClient
from starlette.types import ASGIApp, Scope, Send, Receive, Message


class Logger(Protocol):
    """Structural type of the logger object that BodyLoggingMiddleware accepts."""

    def info(self, message: str) -> None:
        """Accept a single message string."""
        ...


class BodyLoggingMiddleware:
    """Pure ASGI middleware that logs the body of every HTTP request.

    The ``receive`` callable handed to the wrapped app is replaced by a wrapper
    that records each body chunk as the app reads it. Once the app finishes
    (normally or with an exception), any part of the body the app did not read
    is drained and the complete body is passed to ``logger.info``.
    """

    def __init__(
        self,
        app: ASGIApp,
        logger: Logger,
    ) -> None:
        self.app = app
        self.logger = logger

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        # Only HTTP requests carry a body; pass other scopes through untouched.
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        done = False  # True once the last body chunk or a disconnect was seen
        chunks: "List[bytes]" = []

        async def wrapped_receive() -> Message:
            nonlocal done
            message = await receive()
            if message["type"] == "http.disconnect":
                # Nothing more will arrive; stop draining.
                done = True
                return message
            body = message.get("body", b"")
            more_body = message.get("more_body", False)
            if not more_body:
                done = True
            chunks.append(body)
            return message

        try:
            await self.app(scope, wrapped_receive, send)
        finally:
            # The app may have read only part of the body (or raised), so
            # consume whatever is left before logging.
            while not done:
                await wrapped_receive()
            # Bodies are not guaranteed to be UTF-8. A strict decode would
            # raise from inside ``finally`` and replace the app's own result or
            # exception, so undecodable bytes are substituted instead.
            self.logger.info(b"".join(chunks).decode("utf-8", errors="replace"))


async def consume_body_app(scope: Scope, receive: Receive, send: Send) -> None:
    """ASGI app that reads the whole request body, then sends a default ``Response()``."""
    done = False
    while not done:
        msg = await receive()
        # ``more_body`` is optional and may be sent explicitly as False, so
        # test its value rather than its presence; a presence check would
        # issue one extra receive() after the final chunk.
        done = not msg.get("more_body", False)
    await Response()(scope, receive, send)


async def consume_partial_body_app(scope: Scope, receive: Receive, send: Send) -> None:
    """ASGI app that calls ``receive`` once, then sends a default ``Response()``."""
    # A single receive() call; any further body messages are left unread.
    await receive()
    await Response()(scope, receive, send)


class TestException(Exception):
    """Marker exception raised by the failing test apps."""

    # The name starts with "Test", so pytest would try to collect this class
    # as a test case and emit a collection warning; opt out explicitly.
    __test__ = False


async def consume_body_and_error_app(scope: Scope, receive: Receive, send: Send) -> None:
    """ASGI app that reads the whole request body and then raises ``TestException``."""
    done = False
    while not done:
        msg = await receive()
        # ``more_body`` is optional and may be sent explicitly as False, so
        # test its value rather than its presence; a presence check would
        # issue one extra receive() after the final chunk.
        done = not msg.get("more_body", False)
    raise TestException


async def consume_partial_body_and_error_app(scope: Scope, receive: Receive, send: Send) -> None:
    """ASGI app that calls ``receive`` once and then raises ``TestException``."""
    # A single receive() call; any further body messages are left unread.
    await receive()
    raise TestException


class TestLogger:
    """Logger stand-in that appends every message to a caller-supplied list."""

    # The name starts with "Test", so pytest would try to collect this class
    # as a test case and warn about its __init__; opt out explicitly.
    __test__ = False

    def __init__(self, recorder: List[str]) -> None:
        # The list is shared with the caller, which inspects it after the request.
        self.recorder = recorder

    def info(self, message: str) -> None:
        """Append *message* to the recorder list."""
        self.recorder.append(message)


@pytest.mark.parametrize(
    "chunks, expected_logs", [
        # Every chunk is bytes, matching the Iterable[bytes] annotation below.
        ([b"foo", b" ", b"bar", b" ", b"baz"], ["foo bar baz"]),
    ]
)
@pytest.mark.parametrize(
    "app",
    [consume_body_app, consume_partial_body_app]
)
def test_body_logging_middleware_no_errors(chunks: Iterable[bytes], expected_logs: Iterable[str], app: ASGIApp) -> None:
    """The full body is logged whether the app reads all of it or only part of it."""
    logs: List[str] = []
    client = TestClient(BodyLoggingMiddleware(app, TestLogger(logs)))

    # A generator makes the client send the body as a sequence of chunks.
    def chunk_gen() -> Generator[bytes, None, None]:
        yield from chunks

    resp = client.get("/", data=chunk_gen())
    assert resp.status_code == 200
    assert logs == expected_logs


@pytest.mark.parametrize(
    "chunks, expected_logs", [
        # Every chunk is bytes, matching the Iterable[bytes] annotation below.
        ([b"foo", b" ", b"bar", b" ", b"baz"], ["foo bar baz"]),
    ]
)
@pytest.mark.parametrize(
    "app",
    [consume_body_and_error_app, consume_partial_body_and_error_app]
)
def test_body_logging_middleware_with_errors(chunks: Iterable[bytes], expected_logs: Iterable[str], app: ASGIApp) -> None:
    """The full body is still logged when the wrapped app raises."""
    logs: List[str] = []
    client = TestClient(BodyLoggingMiddleware(app, TestLogger(logs)))

    # A generator makes the client send the body as a sequence of chunks.
    def chunk_gen() -> Generator[bytes, None, None]:
        yield from chunks

    with pytest.raises(TestException):
        client.get("/", data=chunk_gen())
    assert logs == expected_logs


if __name__ == "__main__":
    # Running this file directly hands its own absolute path to pytest.
    import os
    pytest.main(args=[os.path.abspath(__file__)])
1
Duccio On

The issue is in Uvicorn. The FastAPI/Starlette::Request class does cache the body, but the Uvicorn function RequestResponseCycle::request() does not, so if you instantiate two or more Request classes and ask for the body(), only the instance that asks for the body first will have a valid body.

I solved it by creating a replacement receive function that returns a cached copy of the received request message:

class LogRequestsMiddleware:
    """Pure ASGI middleware that buffers the request body so it can be read twice.

    The body is read from the server once, then replayed separately to the
    ``Request`` built here and to the wrapped application. Note that the whole
    body is held in memory.
    """

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        # Lifespan and websocket scopes have no request body; pass them through.
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        # Read every body chunk, not only the first message, so chunked
        # uploads are not truncated.
        chunks = []
        while True:
            message = await receive()
            if message["type"] != "http.request":
                # e.g. http.disconnect before the body completed: replay as-is.
                receive_cached_ = message
                break
            chunks.append(message.get("body", b""))
            if not message.get("more_body", False):
                receive_cached_ = {
                    "type": "http.request",
                    "body": b"".join(chunks),
                    "more_body": False,
                }
                break

        def make_receive() -> Receive:
            """Build a receive callable that replays the cached message once."""
            replayed = False

            async def receive_cached():
                nonlocal replayed
                if not replayed:
                    replayed = True
                    return receive_cached_
                # After the replay, defer to the server so disconnects are
                # still observed instead of the same message repeating forever.
                return await receive()

            return receive_cached

        request = Request(scope, receive=make_receive())

        # do what you need here, e.g. ``body = await request.body()``

        # The app gets its own replay, independent of what `request` consumed.
        await self.app(scope, make_receive(), send)


app.add_middleware(LogRequestsMiddleware)
0
bc30138 On

This solution has not been mentioned yet, but it worked for me:

from typing import Callable, Awaitable

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import StreamingResponse
from starlette.concurrency import iterate_in_threadpool

class LogStatsMiddleware(BaseHTTPMiddleware):
    """Middleware that logs the body of every response."""

    async def dispatch(  # type: ignore
        self, request: Request, call_next: Callable[[Request], Awaitable[StreamingResponse]],
    ) -> StreamingResponse:
        """Call the downstream app, log its response body, and return the response."""
        response = await call_next(request)
        # Reading the body exhausts the iterator, so collect the chunks ...
        response_body = [section async for section in response.body_iterator]
        # ... and give the response a fresh async iterator over the same chunks.
        response.body_iterator = iterate_in_threadpool(iter(response_body))
        # Join all chunks: the body may be empty (no first element to index)
        # or split across several chunks, and it need not be valid UTF-8.
        logging.info(
            "response_body=%s", b"".join(response_body).decode(errors="replace")
        )
        return response

def init_app(app):
    """Register LogStatsMiddleware on *app*."""
    app.add_middleware(LogStatsMiddleware)

iterate_in_threadpool turns a regular (synchronous) iterator into an async iterator.

If you look at the implementation of starlette.responses.StreamingResponse, you'll see that this function is used for exactly this purpose.

0
Tsvi Sabo On

If you only want to read the request parameters, the best solution I found is to implement a custom route class and pass it as the "route_class" argument when creating the fastapi.APIRouter, because parsing the request within middleware is considered problematic. As I understand it, the route handler is intended for attaching exception-handling logic to specific routers, but since it is invoked before every route call, you can also use it to access the Request argument.

Fastapi documentation

You could do something as follows:

class MyRequestLoggingRoute(APIRoute):
    """Route class that logs each request and reports validation errors with the body."""

    def get_route_handler(self) -> Callable:
        """Wrap the stock route handler with request logging."""
        original_route_handler = super().get_route_handler()

        async def custom_route_handler(request: Request) -> Response:
            body = await request.body()
            if body:
                logger.info(...)  # log request with body
            else:
                logger.info(...)  # log request without body
            try:
                return await original_route_handler(request)
            except RequestValidationError as exc:
                # The body may not be valid UTF-8; a strict decode here would
                # raise a second error while handling the first one.
                detail = {"errors": exc.errors(), "body": body.decode(errors="replace")}
                # Chain the original error so tracebacks show the real cause.
                raise HTTPException(status_code=422, detail=detail) from exc

        return custom_route_handler
0
Syed Hammad Ahmed On

It turns out that await request.json() can only be called once per request cycle. So if you need to access the request body in multiple middlewares (for filtering, authentication, etc.), a workaround is to create a custom middleware that copies the contents of the request body into request.state. This middleware should be loaded as early as necessary. Each subsequent middleware in the chain, or the controller, can then read the request body from request.state instead of calling await request.json() again. Here's an example:

class CopyRequestMiddleware(BaseHTTPMiddleware):
    """Middleware that stores the parsed JSON request body on ``request.state.body``."""

    async def dispatch(self, request: Request, call_next):
        """Parse the body as JSON, stash it on ``request.state``, and continue."""
        try:
            request_body = await request.json()
        except ValueError:
            # Empty or non-JSON bodies (e.g. a plain GET) fail to parse; store
            # None instead of failing every such request.
            request_body = None
        request.state.body = request_body

        response = await call_next(request)
        return response

class LogRequestMiddleware(BaseHTTPMiddleware):
    """Middleware that prints the request body stored on ``request.state``."""

    async def dispatch(self, request: Request, call_next):
        """Print ``request.state.body`` and pass the request down the chain."""
        # Reads the attribute that CopyRequestMiddleware assigns, so that
        # middleware has to have run before this one.
        print(request.state.body)
        return await call_next(request)

The controller will access request body from request.state as well

# In the controller: read the body that was stored on request.state.
request_body = request.state.body