How to set a different concurrency limit for each service in Tonic?

As we know, Tonic uses the Service and Layer mechanisms of the Tower framework. If we want to add a concurrency limit to the services on the server side, we can do this:

let layer = tower::ServiceBuilder::new()
    .load_shed()
    .concurrency_limit(2)
    .into_inner();

Server::builder()
    .layer(layer)
    .add_service(greeter)
    .add_service(echo)
    .serve(addr)
    .await?;

With the code above, is my understanding correct that the greeter and echo services each get a maximum concurrency of 2, giving a total cap of 4 concurrent requests per connection, because the layer is applied to each service identically?
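
The two possible readings of that setup (one limit shared by both services, or one limit per service) can be made concrete with a small std-only sketch. `Limiter` and `Permit` are hypothetical stand-ins written for this illustration, not Tower's `ConcurrencyLimit` or `LoadShed` types, and the sketch does not settle which reading matches Tonic's `layer` behaviour. It only models "reject the call when the cap is reached".

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

/// Hypothetical stand-in for a load-shedding concurrency limit.
#[derive(Clone)]
struct Limiter {
    in_flight: Arc<AtomicUsize>,
    max: usize,
}

/// Frees its slot when dropped, i.e. when the request finishes.
struct Permit(Arc<AtomicUsize>);

impl Drop for Permit {
    fn drop(&mut self) {
        self.0.fetch_sub(1, Ordering::SeqCst);
    }
}

impl Limiter {
    fn new(max: usize) -> Self {
        Self { in_flight: Arc::new(AtomicUsize::new(0)), max }
    }

    /// Returns a permit if a slot is free, or `None` if the call is shed.
    fn try_acquire(&self) -> Option<Permit> {
        let max = self.max;
        self.in_flight
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |n| {
                if n < max { Some(n + 1) } else { None }
            })
            .ok()
            .map(|_| Permit(Arc::clone(&self.in_flight)))
    }
}

fn main() {
    // Reading 1: both services share one limiter, so the cap of 2 is a total.
    let shared = Limiter::new(2);
    let (greeter, echo) = (shared.clone(), shared.clone());
    let _g1 = greeter.try_acquire();
    let _e1 = echo.try_acquire();
    println!("shared: third call accepted = {}", greeter.try_acquire().is_some());

    // Reading 2: one limiter per service, so each has its own cap (3 and 2).
    let greeter = Limiter::new(3);
    let echo = Limiter::new(2);
    let held_g: Vec<_> = (0..4).filter_map(|_| greeter.try_acquire()).collect();
    let held_e: Vec<_> = (0..3).filter_map(|_| echo.try_acquire()).collect();
    println!("per-service: greeter = {}, echo = {}", held_g.len(), held_e.len());
}
```

Cloning a `Limiter` shares its counter, which is what distinguishes the first reading from the second.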

Now what if I want different concurrency limits for the greeter and echo services, say 3 for one and 2 for the other? I couldn't find a corresponding method in Tonic's documentation or code. With a plain Tower Service it is possible, because I can wrap the two services separately, for example:

let greeter = tower::ServiceBuilder::new()
    .load_shed()    
    .concurrency_limit(3)
    .service(greeter);

let echo = tower::ServiceBuilder::new()
    .load_shed()    
    .concurrency_limit(2)
    .service(echo);

Server::builder()
    .add_service(greeter)
    .add_service(echo)
    .serve(addr)
    .await?;

But that won't work, because add_service() requires the type of its argument to implement the tonic::server::NamedService trait.

I could implement this trait in a newtype, but it only defines a const NAME element. I'm not very familiar with generic programming in Rust and haven't found a way to implement it yet.

Does anyone understand this, or has anyone implemented something similar? Any help is appreciated, thank you!

There are 2 answers

2
Njuguna Mureithi

You are almost there. The NamedService implementation you are looking for is normally generated at build time by tonic-build. You should have a build.rs:

fn main() -> Result<(), Box<dyn std::error::Error>> {
    tonic_build::compile_protos("proto/service.proto")?;
    Ok(())
}

The proto file might look like:

syntax = "proto3";

package helloworld;

// The greeting service definition.
service Greeter {
    // Sends a greeting
    rpc SayHello (HelloRequest) returns (HelloReply) {}
}

// The request message containing the user's name.
message HelloRequest {
    string name = 1;
}

// The response message containing the greetings
message HelloReply {
    string message = 1;
}

This would generate some code like this:

//---snip--
impl<T: Greeter> tonic::server::NamedService for GreeterServer<T> {
    const NAME: &'static str = "helloworld.Greeter";
}

You would still need to implement the Greeter trait yourself:

// Implement the service skeleton for the "Greeter" service
// defined in the proto
#[derive(Debug, Default)]
pub struct MyGreeter {}

// Implement the service function(s) defined in the proto
// for the Greeter service (SayHello...)
#[tonic::async_trait]
impl Greeter for MyGreeter {
    async fn say_hello(
        &self,
        request: Request<HelloRequest>,
    ) -> Result<Response<HelloReply>, Status> {
        println!("Received request from: {:?}", request);

        // `HelloReply` is the response message defined in the proto file
        let response = HelloReply {
            message: format!("Hello {}!", request.into_inner().name),
        };

        Ok(Response::new(response))
    }
}

You would bring it all together with the server:

let greeter = MyGreeter::default();
Server::builder()
    .add_service(GreeterServer::new(greeter))
    .serve(addr)
    .await?;

With that working, you still face one problem: wrapping the generated server in Tower layers produces a new type that no longer implements NamedService. You may want to create a wrapper type that forwards NAME:

#[derive(Clone)]
struct GreetWrapper<Service>(Service);

impl<S: NamedService> NamedService for GreetWrapper<S> {
    const NAME: &'static str = S::NAME;
}

fn main() {
    let greeter = MyGreeter::default();
    let greeter = tower::ServiceBuilder::new()
        .load_shed()
        .concurrency_limit(3)
        .service(GreeterServer::new(greeter));
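    // Sketch only: for this call to compile, the wrapper must also implement
    // tower::Service, and the layered type passed in must satisfy the
    // `S: NamedService` bound above.
    tonic::transport::Server::builder().add_service(GreetWrapper(greeter));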
}
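
One caveat with the wrapper above: the type returned by the Tower builder wraps the generated server, and Tower's middleware types do not implement Tonic's trait themselves, so a bound on the wrapped type alone is not enough. A minimal std-only sketch of one way around this follows. `Named`, `Limited` and `ForwardName` are hypothetical stand-ins (for `NamedService`, a middleware type, and the wrapper respectively), and `GreeterServer` here is a dummy struct, not the generated type. The idea is to take the name from a second type parameter that exists only at the type level.

```rust
use std::marker::PhantomData;

/// Local stand-in for `tonic::server::NamedService` (assumed to have the
/// same shape: a single associated const).
trait Named {
    const NAME: &'static str;
}

/// Dummy stand-in for a generated server type.
struct GreeterServer;

impl Named for GreeterServer {
    const NAME: &'static str = "helloworld.Greeter";
}

/// Stand-in for a middleware type that does NOT implement `Named`.
#[allow(dead_code)]
struct Limited<S>(S);

/// Holds the layered service and remembers the original service type `O`
/// without storing a value of it.
#[allow(dead_code)]
struct ForwardName<S, O> {
    inner: S,
    _origin: PhantomData<O>,
}

impl<S, O> ForwardName<S, O> {
    fn new(inner: S) -> Self {
        Self { inner, _origin: PhantomData }
    }
}

// The name comes from `O`, so `S` needs no `Named` bound at all.
impl<S, O: Named> Named for ForwardName<S, O> {
    const NAME: &'static str = O::NAME;
}

/// Plays the role of a function that, like `add_service`, requires `Named`.
fn route_name<T: Named>(_svc: &T) -> &'static str {
    T::NAME
}

fn main() {
    let layered = Limited(GreeterServer);
    let wrapped = ForwardName::<_, GreeterServer>::new(layered);
    println!("{}", route_name(&wrapped));
}
```

The real wrapper would additionally have to implement `tower::Service` by delegating to `inner`, which this sketch leaves out.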

Here is an example that compiles:

/*
[dependencies]
tower =  { version = "0.4", features = ["full"] }
tonic = "*"
http = "*"
*/
use tonic::codegen::*;

#[derive(Clone)]
struct Greet<Service>(Service);

impl<S> tonic::transport::NamedService for Greet<S> {
    const NAME: &'static str = "Greet";
}

impl<B, S> Service<http::Request<B>> for Greet<S>
where
    B: Body + Send + 'static,
    B::Error: Into<StdError> + Send + 'static,
{
    type Response = http::Response<tonic::body::BoxBody>;
    type Error = std::convert::Infallible;
    type Future = BoxFuture<Self::Response, Self::Error>;
    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }
    fn call(&mut self, _req: http::Request<B>) -> Self::Future {
        // Something like 
        // self.0.call(req).await.unwrap()
        todo!()
    }
}

fn main() {
    let svc = tower::service_fn(|_request: http::Response<()>| async move { todo!() });
    let greeter = tower::ServiceBuilder::new()
        .load_shed()
        .concurrency_limit(3)
        .service(svc);
    tonic::transport::Server::builder().add_service(Greet(greeter));
}

Here is the link to compiling code.

Hopefully you can get it working. Happy greeting!

2
Bill.W

I have resolved the issue with the NamedService trait. The following code compiles, but there are still semantic issues.

use futures::{future::BoxFuture, FutureExt};
use std::{
    convert::Infallible,
    marker::PhantomData,
    task::{Context, Poll},
};
use tonic::{  // tonic-0.9.2
    transport::{NamedService, Server},
    Request, Response, Status,
};
use tower::{BoxError, Service};

#[derive(Clone)]
struct NamedWrapper<S, O> {
    inner: S,
    _origin_svc_type: PhantomData<O>,
}

impl<S, O> NamedService for NamedWrapper<S, O>
where
    O: NamedService,
{
    const NAME: &'static str = O::NAME;
}

impl<S, O, Req> Service<Req> for NamedWrapper<S, O>
where
    S: Service<Req> + 'static,
    S::Response: Send + 'static,
    S::Future: Send + 'static,
    S::Error: Into<BoxError> + std::fmt::Debug + Send + 'static,
{
    type Response = S::Response;
    // `Server::add_service(svc)` requires the error associated type of `svc`
    // to be `Infallible`.
    type Error = Infallible;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx).map(|r| {
            if let Err(e) = r {
                // How to handle this error? We need to report it, but the Infallible error type leaves no way to return it.
                eprintln!("{:?}", e);
            }
            Ok(())
        })
    }

    fn call(&mut self, req: Req) -> Self::Future {
        let fut = self.inner.call(req).boxed();
        Box::pin(async move {
            // The inner service can fail here. Ideally a `Status` would be returned, but the Infallible error type prevents that.
            // In this example, `load_shed` returns an `Overloaded` error when the limit is reached, so this unwrap panics.
            Ok(fut.await.unwrap())
        })
    }
}

impl<T, O> NamedWrapper<T, O>
where
    O: NamedService,
{
    pub fn new(svc: T, ost: PhantomData<O>) -> Self {
        Self {
            inner: svc,
            _origin_svc_type: ost,
        }
    }
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let addr = "[::1]:50051".parse().unwrap();

    let greeter = tower::ServiceBuilder::new()
        .load_shed()
        .concurrency_limit(3)
        .service(GreeterServer::new(MyGreeter::default()));
    let greeter = NamedWrapper::new(greeter, PhantomData::<GreeterServer<MyGreeter>>);

    let echo = tower::ServiceBuilder::new()
        .load_shed()
        .concurrency_limit(2)
        .service(EchoServer::new(MyEcho::default()));
    let echo = NamedWrapper::new(echo, PhantomData::<EchoServer<MyEcho>>);

    Server::builder()
        .add_service(greeter)
        .add_service(echo)
        .serve(addr)
        .await?;

    Ok(())
}
// The remaining code is omitted; see the multiplex example in Tonic: https://github.com/hyperium/tonic/blob/master/examples/src/multiplex/server.rs
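
A possible direction for the remaining semantic issue in `call` is to turn the inner error into a response instead of unwrapping it, so the `Infallible` error type is never actually needed. The sketch below is std-only and uses hypothetical stand-ins: `Reply` for the HTTP response, `Overloaded` for the load-shed error, and `absorb` for the conversion. In real Tonic code the counterpart would be building a `tonic::Status` and converting it to an HTTP response (the 0.9 line appears to offer `Status::to_http` for this, but check the version in use), and it assumes the wrapper's `Response` type is that HTTP response rather than a generic `S::Response`.

```rust
use std::convert::Infallible;
use std::fmt::Display;

/// Hypothetical stand-in for an HTTP/gRPC response.
#[derive(Debug, PartialEq)]
struct Reply {
    grpc_status: u32,
    message: String,
}

/// Stand-in for the error a load-shedding layer might return.
#[derive(Debug)]
struct Overloaded;

impl Display for Overloaded {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "service overloaded")
    }
}

/// Turns a fallible result into an infallible one by rendering the error
/// as a reply, instead of panicking with `unwrap()`.
fn absorb<E: Display>(result: Result<Reply, E>) -> Result<Reply, Infallible> {
    Ok(result.unwrap_or_else(|e| Reply {
        // 8 is the gRPC RESOURCE_EXHAUSTED status code.
        grpc_status: 8,
        message: e.to_string(),
    }))
}

fn main() {
    let ok = absorb::<Overloaded>(Ok(Reply { grpc_status: 0, message: "Hello!".into() }));
    let shed = absorb(Err(Overloaded));
    println!("{:?}", ok);
    println!("{:?}", shed);
}
```

The same conversion does not apply directly to `poll_ready`, which has no response to return, so an error seen there would have to be handled some other way, for example by remembering it and reporting it from the next `call`.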