Use hyper to pass IP address of incoming connection to stack of Services


I am trying to write a server using hyper that passes the remote (client) address of each incoming connection down to a stack of layers (which I have built using ServiceBuilder).

I have tried to use examples from the hyper docs and also this example from the Rust forums; however, both of these

  • pass down data to a single handler function, not a stack of service layers
  • have a return type of Result<Response, Infallible>, which I don't want (I want to be able to drop a connection without returning a response).

Here is one of my tries (I have tried several approaches):

use std::{
    net::SocketAddr,
    time::Duration,
};

use hyper::{
    Body, Request, Response, Server,
    server::conn::AddrStream,
    service::{
        make_service_fn,
        service_fn,
    },
};
use tower::{
    Service, ServiceBuilder,
    timeout::TimeoutLayer,
};

async fn dummy_handle(req: Request<Body>) -> Result<Response<Body>, hyper::Error> {
    let response_text = format!(
        "{:?} {} {}", req.version(), req.method(), req.uri()
    );
    let response = Response::new(Body::from(response_text));
    Ok(response)
}

#[tokio::main(flavor = "current_thread")]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let addr = SocketAddr::from(([127, 0, 0, 1], 8080));
    
    // Dummy stack of service layers I want to wrap.
    let service = ServiceBuilder::new()
        .layer(TimeoutLayer::new(Duration::from_millis(1000 * 60)))
        .service_fn(dummy_handle);
    
    let make_svc = make_service_fn(|socket: &AddrStream| {
        let remote_addr = socket.remote_addr();
        let mut inner_svc = service.clone();
        let outer_svc = service_fn(move |mut req: Request<Body>| async {
            req.extensions_mut().insert(remote_addr);
            inner_svc.call(req)
        });
        
        async move { outer_svc }
    });
    
    Server::bind(&addr)
        .serve(make_svc)
        .await?;
    
    Ok(())
}
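The idea behind this attempt is an outer per-connection service that inserts the SocketAddr into the request's extensions and then delegates to the layered stack, whose layers read it back. That idea can be sketched without hyper or tower. This is a simplified, synchronous model: Service, Extensions, Request, Handler, DenyList and WithRemoteAddr are hypothetical stand-ins, not the tower or http types. With the real crates, the corresponding calls are req.extensions_mut().insert(remote_addr) and req.extensions().get::<SocketAddr>().

```rust
use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::net::SocketAddr;

// Hypothetical stand-in for http::Extensions: a map keyed by type.
#[derive(Default)]
struct Extensions {
    map: HashMap<TypeId, Box<dyn Any>>,
}

impl Extensions {
    fn insert<T: 'static>(&mut self, val: T) {
        self.map.insert(TypeId::of::<T>(), Box::new(val));
    }

    fn get<T: 'static>(&self) -> Option<&T> {
        self.map
            .get(&TypeId::of::<T>())
            .and_then(|boxed| boxed.downcast_ref::<T>())
    }
}

// Hypothetical request type: a path plus extensions.
struct Request {
    path: String,
    extensions: Extensions,
}

// Simplified, synchronous stand-in for tower::Service (not the real trait).
trait Service {
    fn call(&mut self, req: Request) -> Result<String, String>;
}

// Innermost handler: reads the address that the outer wrapper inserted.
struct Handler;

impl Service for Handler {
    fn call(&mut self, req: Request) -> Result<String, String> {
        match req.extensions.get::<SocketAddr>() {
            Some(addr) => Ok(format!("{} from {}", req.path, addr)),
            None => Err("no remote address".to_string()),
        }
    }
}

// Middle layer: rejects one address by returning an error, not a response.
struct DenyList<S> {
    inner: S,
    denied: SocketAddr,
}

impl<S: Service> Service for DenyList<S> {
    fn call(&mut self, req: Request) -> Result<String, String> {
        if req.extensions.get::<SocketAddr>() == Some(&self.denied) {
            return Err("connection refused".to_string());
        }
        self.inner.call(req)
    }
}

// Outermost wrapper, built once per connection: inserts the remote address
// and then delegates to the whole inner stack.
struct WithRemoteAddr<S> {
    inner: S,
    remote_addr: SocketAddr,
}

impl<S: Service> Service for WithRemoteAddr<S> {
    fn call(&mut self, mut req: Request) -> Result<String, String> {
        req.extensions.insert(self.remote_addr);
        self.inner.call(req)
    }
}
```

Only the outer wrapper knows about the connection; every layer below it sees nothing but the request and its extensions.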

I understand full well that including error messages is helpful here; however, this is one of those cases where the Rust compiler spits out pages (or at least screenfuls) of cryptic stuff, so I am going to limit myself to a couple of choice examples.

First, I get this a lot:

type mismatch resolving `<impl Future<Output = [async output]> as Future>::Output == Result<_, _>`

For example, here it is followed by:

39 |            let outer_svc = service_fn(move |mut req: Request<Body>| async {
   |  _____________________________________-___________________________________-
   | | ____________________________________|
   | ||
40 | ||             req.extensions_mut().insert(remote_addr);
41 | ||             inner_svc.call(req)
42 | ||         });
   | ||         -
   | ||_________|
   | |__________the expected closure
   |            the expected `async` block
...
48 |            .serve(make_svc)
   |             ----- ^^^^^^^^ expected struct `service::util::ServiceFn`, found enum `Result`
   |             |
   |             required by a bound introduced by this call

And then the very next error message seems to be entirely contradictory:

[ several lines identical to above elided here ]

48  |            .serve(make_svc)
    |             ^^^^^ expected enum `Result`, found struct `service::util::ServiceFn`

I just can't figure out what the compiler wants from me.

Accepted answer by Field

Try this:

use std::{net::SocketAddr, time::Duration, convert::Infallible};

use hyper::{
    server::conn::AddrStream,
    service::{make_service_fn, service_fn},
    Body, Request, Response, Server,
};
use tower::{Service, ServiceBuilder};

async fn dummy_handle(req: Request<Body>) -> Result<Response<Body>, hyper::Error> {
    let response_text = format!("{:?} {} {}", req.version(), req.method(), req.uri());
    let response = Response::new(Body::from(response_text));
    Ok(response)
}

#[tokio::main(flavor = "current_thread")]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let addr = SocketAddr::from(([127, 0, 0, 1], 8080));

    // Dummy stack of service layers I want to wrap.
    let service = ServiceBuilder::new()
        .timeout(Duration::from_millis(1000 * 60))
        .service_fn(dummy_handle);

    let make_svc = make_service_fn(|socket: &AddrStream| {
        let remote_addr = socket.remote_addr();
        let mut inner_svc = service.clone();
        let outer_svc = service_fn(move |mut req: Request<Body>| {
            req.extensions_mut().insert(remote_addr);
            inner_svc.call(req)
        });

        async { Ok::<_, Infallible>(outer_svc) }
    });

    Server::bind(&addr).serve(make_svc).await?;

    Ok(())
}

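You were returning a future that resolves to another future:

move |mut req: Request<Body>| async {
    req.extensions_mut().insert(remote_addr);
    inner_svc.call(req)
}

The async block's output is the future returned by inner_svc.call(req), so the closure returns a Future<Output = impl Future<Output = Result<...>>>, whereas service_fn expects a closure that returns a Future<Output = Result<...>>.

Therefore, drop the async block and return the future from call directly:

move |mut req: Request<Body>| {
    req.extensions_mut().insert(remote_addr);
    inner_svc.call(req)
}

A similar requirement applies one level up: the closure passed to make_service_fn must return a future that resolves to a Result wrapping the service, which is why the code above ends with async { Ok::<_, Infallible>(outer_svc) } instead of async move { outer_svc }.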
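The future-returning-a-future problem can be reproduced with the standard library alone. In this sketch, call stands in for inner_svc.call(req), run_handler is a hypothetical helper carrying a simplified version of the bound service_fn puts on its closure, and poll_once is a minimal single-poll executor that only works because every future here is ready immediately. None of these are hyper or tower APIs.

```rust
use std::future::Future;
use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};

// A waker that does nothing; enough for futures that are ready on first poll.
fn noop_waker() -> Waker {
    fn clone(_: *const ()) -> RawWaker {
        RawWaker::new(std::ptr::null(), &VTABLE)
    }
    fn noop(_: *const ()) {}
    static VTABLE: RawWakerVTable = RawWakerVTable::new(clone, noop, noop, noop);
    // SAFETY: every vtable function ignores the (null) data pointer.
    unsafe { Waker::from_raw(RawWaker::new(std::ptr::null(), &VTABLE)) }
}

// Polls a future once and returns its output; panics if it is not ready.
fn poll_once<F: Future>(fut: F) -> F::Output {
    let mut fut = Box::pin(fut);
    let waker = noop_waker();
    let mut cx = Context::from_waker(&waker);
    match fut.as_mut().poll(&mut cx) {
        Poll::Ready(out) => out,
        Poll::Pending => panic!("future was not ready on the first poll"),
    }
}

// Hypothetical stand-in for inner_svc.call(req): a future resolving to a Result.
fn call(req: u32) -> impl Future<Output = Result<String, ()>> {
    async move { Ok::<String, ()>(format!("handled {}", req)) }
}

// The shape from the question: the async block's output is the future
// returned by call(), so it takes two polls to reach the Result.
fn wrapped_in_async(req: u32) -> impl Future<Output = impl Future<Output = Result<String, ()>>> {
    async move { call(req) }
}

// The shape from the answer: return call()'s future directly.
fn returned_directly(req: u32) -> impl Future<Output = Result<String, ()>> {
    call(req)
}

// Simplified version of service_fn's bound (hypothetical helper): the closure
// must return a future whose output is a Result. Passing wrapped_in_async
// here would not compile, because its output is a future, not a Result.
fn run_handler<F, Fut, T, E>(mut handler: F, req: u32) -> Result<T, E>
where
    F: FnMut(u32) -> Fut,
    Fut: Future<Output = Result<T, E>>,
{
    poll_once(handler(req))
}
```

The nested version only yields a Result after polling the outer future and then the inner one, which is the mismatch the compiler reports in the question.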

Answer by Michał Hanusek

If you do not need ServiceBuilder, you can do this:

use std::{
    net::SocketAddr,
};
use std::convert::Infallible;

use hyper::{
    Body, Request, Response, Server,
    server::conn::AddrStream,
    service::{
        make_service_fn,
        service_fn,
    },
};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>>
{
    let addr = SocketAddr::from(([127, 0, 0, 1], 8080));

    let make_svc = make_service_fn(|socket: &AddrStream| {
        let remote_addr = socket.remote_addr();
        async move {
            Ok::<_, Infallible>(service_fn(move |req: Request<Body>| async move {
                println!("remote_addr:  {:?}, request: {:?}", remote_addr, req);
                Ok::<_, Infallible>(
                    Response::new(Body::from(format!(
                        "{:?} {} {}", req.version(), req.method(), req.uri()
                    )))
                )
            }))
        }
    });

    let _ = Server::bind(&addr).serve(make_svc).await?;
    
    Ok(())
}
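The two nested closures above follow a factory pattern: the outer closure runs once per accepted connection, copies the address out of the &AddrStream, and moves it into the inner per-request closure. A std-only sketch of that shape, with no async and no hyper; Conn and make_handler are hypothetical stand-ins for AddrStream and the closure given to make_service_fn.

```rust
use std::convert::Infallible;
use std::net::SocketAddr;

// Hypothetical stand-in for hyper's AddrStream: only the peer address matters.
struct Conn {
    remote_addr: SocketAddr,
}

// Plays the role of the closure given to make_service_fn: it runs once per
// connection and returns a per-request handler owning a copy of the address.
fn make_handler(conn: &Conn) -> impl FnMut(&str) -> Result<String, Infallible> {
    let remote_addr = conn.remote_addr; // SocketAddr is Copy
    let mut served = 0u32; // per-connection state, moved into the handler too
    move |path: &str| {
        served += 1;
        Ok::<String, Infallible>(format!("#{} {} from {}", served, path, remote_addr))
    }
}
```

Each connection gets its own handler and its own copy of the address; anything meant to be shared across connections (such as a ServiceBuilder stack) has to be cloned inside the outer closure instead.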

Otherwise, if you want to keep ServiceBuilder, you can implement Service yourself:

use std::{
    net::SocketAddr,
    time::Duration,
    task::{Context, Poll},
    future::Future,
    pin::Pin
};

use hyper::{
    http,
    Body, Request, Response, Server
};

use tower::{
    timeout::TimeoutLayer,
    Service, ServiceBuilder,
};

#[derive(Debug)]
pub struct CustomService;

impl Service<Request<Body>> for CustomService {
    type Response = Response<Body>;
    type Error = http::Error;
    type Future = Pin<Box<dyn Future<Output=Result<Self::Response, Self::Error>> + Send>>;

    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { Poll::Ready(Ok(())) }

    fn call(&mut self, req: Request<Body>) -> Self::Future {
        let rsp = Response::builder();
        let body = Body::from(format!("{:?} {} {}", req.version(), req.method(), req.uri()));
        let rsp = rsp.status(200).body(body).unwrap();
        Box::pin(async {Ok(rsp) })
    }
}

#[derive(Debug)]
pub struct MakeService;

impl<T> Service<T> for MakeService {
    type Response = CustomService;
    type Error = std::io::Error;
    type Future = Pin<Box<dyn Future<Output=Result<Self::Response, Self::Error>> + Send>>;

    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { Ok(()).into() }

    fn call(&mut self, _: T) -> Self::Future {
        Box::pin(async { Ok(CustomService) })
    }
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>>
{
    let addr = SocketAddr::from(([127, 0, 0, 1], 8080));
    let service = ServiceBuilder::new()
        .layer(TimeoutLayer::new(Duration::from_secs(60)))
        .service(MakeService);
    let _ = Server::bind(&addr).serve(service).await?;
    Ok(())
}
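As written, MakeService ignores its target, so the remote address never reaches CustomService, and the TimeoutLayer wraps MakeService itself, which (as far as I can tell) limits how long creating each connection's service may take rather than how long each request may take. One way to address both is to build the per-connection stack inside the factory; in hyper 0.14 that would presumably mean implementing Service<&'a AddrStream> for MakeService and storing target.remote_addr() in CustomService. The std-only sketch below models that shape; Conn, the Service trait and Tagged (playing the role of a per-request layer such as a timeout) are hypothetical stand-ins, not the tower or hyper types.

```rust
use std::net::SocketAddr;

// Hypothetical stand-in for hyper's AddrStream.
struct Conn {
    remote_addr: SocketAddr,
}

// Simplified, synchronous stand-in for tower::Service (not the real trait).
trait Service<Req> {
    type Response;
    fn call(&mut self, req: Req) -> Self::Response;
}

// Per-connection service that keeps the peer address as a field.
struct CustomService {
    remote_addr: SocketAddr,
}

impl<'a> Service<&'a str> for CustomService {
    type Response = String;
    fn call(&mut self, path: &'a str) -> String {
        format!("{} from {}", path, self.remote_addr)
    }
}

// Hypothetical per-request layer, standing in for something like a timeout.
struct Tagged<S> {
    inner: S,
    tag: &'static str,
}

impl<'a, S: Service<&'a str, Response = String>> Service<&'a str> for Tagged<S> {
    type Response = String;
    fn call(&mut self, path: &'a str) -> String {
        format!("[{}] {}", self.tag, self.inner.call(path))
    }
}

// Factory called once per connection: it reads the address from the target
// and builds the per-request stack, so the layer wraps each request.
struct MakeService;

impl<'c> Service<&'c Conn> for MakeService {
    type Response = Tagged<CustomService>;
    fn call(&mut self, conn: &'c Conn) -> Self::Response {
        Tagged {
            inner: CustomService { remote_addr: conn.remote_addr },
            tag: "per-request",
        }
    }
}
```

The design choice is where the layer sits: applied to the factory's output, it runs on every request; applied to the factory itself, it runs once per connection.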