How to share state between async-graphql and axum


I'm trying to build a GraphQL server using axum and async-graphql.

The biggest problem is that I can't find a good way to share data between async-graphql and axum. I need this because I want to pass the HTTP authentication header that axum receives into the GraphQL resolvers.

Can anyone solve this problem or suggest another way?

Cargo.toml

[package]
name = "app"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
axum = { version = "0.6.1", features = ["headers"] }
tokio = { version = "1.19.2", features = ["rt-multi-thread", "macros"] }
serde = { version = "1.0.136", features = ["derive"] }
async-graphql = { version = "5.0.4", features = ["chrono"] }
sqlx = { version = "0.6.2", features = [ "runtime-actix-native-tls", "postgres", "chrono" ] }
dotenv = "0.15.0"
tower-http = { version = "0.3.5", features = ["cors", "trace"] }
tokio-stream = "0.1.11"
chrono = "0.4.23"
jsonwebtoken = "8.2.0"
thiserror = "1.0.38"
async-trait = "0.1.60"

main.rs


mod db;
mod repositories;
mod resolvers;

use async_graphql::{
    http::{playground_source, GraphQLPlaygroundConfig},
    Request, Response, Schema,
};
use axum::{
    extract::{Extension, State},
    http::{
        header::{ACCEPT, AUTHORIZATION},
        HeaderValue, Method, Request as AxumRequest,
    },
    middleware::Next,
    response::{Html, IntoResponse, Response as AxumResponse},
    routing::get,
    Json, Router,
};
use dotenv::dotenv;
use resolvers::{QueryRoot, Subscription};
use std::net::SocketAddr;
use tower_http::cors::CorsLayer;

use crate::{db::DB, resolvers::Mutation};

pub type MainSchema = Schema<QueryRoot, Mutation, Subscription>;

async fn graphql_handler(schema: Extension<MainSchema>, req: Json<Request>) -> Json<Response> {
    schema.execute(req.0).await.into()
}

async fn graphql_playground() -> impl IntoResponse {
    Html(playground_source(GraphQLPlaygroundConfig::new("/")))
}

#[derive(Clone, Debug)]
pub struct AppState {
    db: DB,
    token: Option<String>,
}

impl AppState {
    fn set_token(mut self, token: String) {
        self.token = Some(token)
    }
}

async fn propagate_header<B>(
    State(state): State<&AppState>,
    req: AxumRequest<B>,
    next: Next<B>,
) -> AxumResponse {
    let token = req.headers().get("Authorization");

    if token.is_some() {
        // TODO: Put token in state.
    };

    next.run(req).await
}

#[tokio::main]
async fn main() {
    dotenv().ok();
    let server = async {
        let db = DB::new().await;
        let state = AppState { db, token: None };
        let schema = Schema::build(QueryRoot, Mutation, Subscription)
            // .limit_depth(5)
            .data(&state)
            .finish();

        let cors_layer = CorsLayer::new()
            .allow_origin("*".parse::<HeaderValue>().unwrap())
            .allow_methods([Method::GET, Method::POST, Method::OPTIONS])
            .allow_headers(vec![AUTHORIZATION, ACCEPT]);

        let app = Router::new()
            .route("/", get(graphql_playground).post(graphql_handler))
            .layer(cors_layer)
            .layer(axum::middleware::from_fn_with_state(
                &state,
                propagate_header,
            ))
            .layer(Extension(schema));

        let addr = SocketAddr::from(([0, 0, 0, 0], 8009));
        axum::Server::bind(&addr)
            .serve(app.into_make_service())
            .await
            .unwrap();
    };

    tokio::join!(server);
}

error

error[E0597]: `state` does not live long enough
  --> src/main.rs:71:19
   |
69 |           let schema = Schema::build(QueryRoot, Mutation, Subscription)
   |  ______________________-
70 | |             // .limit_depth(5)
71 | |             .data(&state)
   | |___________________^^^^^^- argument requires that `state` is borrowed for `'static`
   |                     |
   |                     borrowed value does not live long enough
...
93 |       };
   |       - `state` dropped here while still borrowed

There are 2 answers

Answer by Aaron

You'll want to use Arc here, see here.

let state = AppState { db, token: None };
let arced = Arc::new(state);
let schema = Schema::build(QueryRoot, Mutation, Subscription)
    // .limit_depth(5)
    .data(arced.clone())
    .finish();

... etc ...

The next issue I see is that your state seems meant to be mutable (judging by the token field), and an Arc alone only gives shared, read-only access to what it wraps. You'll likely need to add a Mutex as well, or some other lock (e.g. RwLock).
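A minimal, std-only sketch of the Arc + Mutex approach, under stated assumptions: AppState is simplified (db is a String instead of the question's DB type), and register_data is a hypothetical stand-in that copies the bound async-graphql's SchemaBuilder::data puts on its argument (D: Any + Send + Sync). Since Any implies 'static, a borrowed &state cannot satisfy it, which is what E0597 is complaining about, while an owned Arc can.

```rust
use std::any::Any;
use std::sync::{Arc, Mutex};
use std::thread;

// Simplified stand-in for the question's AppState; `db` is a String here
// so the sketch compiles on its own.
#[allow(dead_code)]
#[derive(Debug)]
struct AppState {
    db: String,
    token: Option<String>,
}

type SharedState = Arc<Mutex<AppState>>;

// Hypothetical helper with the same bound as `SchemaBuilder::data`
// (`D: Any + Send + Sync`). `Any` implies `'static`, so `&state` is rejected.
fn register_data<D: Any + Send + Sync>(data: D) -> D {
    data
}

// The Mutex supplies the mutable access that a bare Arc<AppState> cannot.
fn set_token(state: &SharedState, token: &str) {
    state.lock().unwrap().token = Some(token.to_string());
}

fn main() {
    let state: SharedState = Arc::new(Mutex::new(AppState {
        db: "db-handle".to_string(),
        token: None,
    }));

    // An owned Arc is 'static, so this compiles where `.data(&state)` did not.
    let schema_handle = register_data(Arc::clone(&state));

    // Another owner (standing in for the middleware) writes from another thread.
    let middleware_handle = Arc::clone(&state);
    thread::spawn(move || set_token(&middleware_handle, "abc"))
        .join()
        .unwrap();

    // Both handles point at the same AppState, so the write is visible here.
    println!("{:?}", schema_handle.lock().unwrap().token); // Some("abc")
}
```

This fixes the compile error, but note that every request still shares the single token slot; the next answer explains why that is a problem.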

Answer by Shu

You shouldn't put the token in the AppState.

AppState is a global state for your whole application. It contains information shared by all the connections to the service.

Putting the token there works while you test your application, because you are the only user. As soon as multiple users connect to your server, though, you can run into a race: a normal user's request arrives, but before it reaches the handler, an admin's request arrives and writes the admin token into the global state. The normal user's handler then runs with admin permissions.
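The interleaving described above can be replayed deterministically in plain Rust. This is a sketch with hypothetical names (GlobalState, middleware_writes, handler_reads); it shows that a Mutex makes each individual access safe, but it does not stop one request from reading another request's token.

```rust
use std::sync::{Arc, Mutex};

// Hypothetical global state with a single token slot, as in the question.
#[derive(Default)]
struct GlobalState {
    token: Option<String>,
}

// First half of a request: the middleware stores its token globally.
fn middleware_writes(state: &Arc<Mutex<GlobalState>>, token: &str) {
    state.lock().unwrap().token = Some(token.to_string());
}

// Second half of a request: the handler reads whatever token is there *now*.
fn handler_reads(state: &Arc<Mutex<GlobalState>>) -> Option<String> {
    state.lock().unwrap().token.clone()
}

fn main() {
    let state = Arc::new(Mutex::new(GlobalState::default()));

    // Replay one unlucky interleaving by hand:
    middleware_writes(&state, "user-token"); // normal user's request arrives
    middleware_writes(&state, "admin-token"); // admin's request arrives first
    let seen_by_user_handler = handler_reads(&state); // user's handler runs

    println!("{:?}", seen_by_user_handler); // Some("admin-token")
}
```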

The right way to solve your problem is to put the token (or, even better, the authenticated user) in the GraphQL context, which is built separately for each request.

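It can be done with a mix of a custom axum extractor (a type implementing FromRequestParts) and async-graphql's per-request data (Request::data).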

For example:

use axum::{
    async_trait,
    extract::{FromRequestParts, TypedHeader},
    headers::{authorization::Bearer, Authorization},
    http::request::Parts,
    response::{IntoResponse, Response},
    Extension, RequestPartsExt,
};

#[derive(Clone, Debug)]
struct AppState {
    db: DatabaseConnection,
}

#[derive(Default, Debug)]
struct AuthenticatedUser {
    name: String,
    // other_fields: ...
}

#[async_trait]
impl<S> FromRequestParts<S> for AuthenticatedUser
where
    S: Send + Sync,
{
    // axum::response::Response, not async_graphql::Response
    type Rejection = Response;

    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        let token =
            TypedHeader::<Authorization<Bearer>>::from_request_parts(parts, state)
                .await
                .map(|token| token.token().to_string())
                .unwrap_or_default();

        // token now contains your bearer token, or an empty string
        // if the header was not present

        // Extract the global state containing the db connection
        // (requires `.layer(Extension(state))` on the router)
        let Extension(state) = parts
            .extract::<Extension<AppState>>()
            .await
            .map_err(|err| err.into_response())?;

        // here you should actually perform the authorization
        // by using the `token` and the `state.db` to connect to the
        // db and check the user authentication, then fill the struct

        Ok(AuthenticatedUser {
            name: "Anonymous".into(),
            // ...
        })
    }
}
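The important detail in the extractor is the fallback: a missing or non-Bearer header produces an empty token instead of rejecting the request, so the actual authentication decision happens afterwards. A std-only sketch of that fallback, using a hypothetical bearer_token helper; unlike the real header parsing, this simplified version only recognizes the exact "Bearer " prefix.

```rust
// Hypothetical helper: return the bearer token from an Authorization header
// value, or an empty string when the header is absent or uses another scheme.
fn bearer_token(header: Option<&str>) -> String {
    header
        .and_then(|value| value.strip_prefix("Bearer "))
        .map(|token| token.trim().to_string())
        .unwrap_or_default()
}

fn main() {
    println!("{:?}", bearer_token(Some("Bearer abc123"))); // "abc123"
    println!("{:?}", bearer_token(Some("Basic dXNlcjpwYXNz"))); // ""
    println!("{:?}", bearer_token(None)); // ""
}
```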

Then, in your router:

let app = Router::new()
        .route("/gql", post(graphql_handler))
        .layer(Extension(state))
        .layer(Extension(schema));

In your graphql handler:

pub(crate) async fn graphql_handler(
    user: AuthenticatedUser,
    schema: Extension<ServiceSchema>,
    req: GraphQLRequest,
) -> GraphQLResponse {
    // pass the `user` into the Context data
    schema.execute(req.into_inner().data(user)).await.into()
}

And, finally, in your graphql query/mutation handler:

    pub async fn whatever(&self, ctx: &Context<'_>, id: i32) -> anyhow::Result<Option<Whatever>> {
        // extract the user from the Context data
        let user = ctx.data::<AuthenticatedUser>();
        ...
    }
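Why this avoids the race: the user is attached to one request, and resolvers look it up by type in that request's context. The sketch below is a hypothetical, much-simplified model of type-keyed request data (RequestData, insert and get are illustrative names, not async-graphql's implementation), showing how attaching data to one request and looking it up with a type parameter fit together.

```rust
use std::any::{Any, TypeId};
use std::collections::HashMap;

// Hypothetical per-request data store keyed by type.
#[derive(Default)]
struct RequestData {
    map: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
}

impl RequestData {
    // Analogous to attaching `.data(user)` to a single request.
    fn insert<T: Any + Send + Sync>(&mut self, value: T) {
        self.map.insert(TypeId::of::<T>(), Box::new(value));
    }

    // Analogous to `ctx.data::<T>()`: look up by type, error if absent.
    fn get<T: Any + Send + Sync>(&self) -> Result<&T, String> {
        self.map
            .get(&TypeId::of::<T>())
            .and_then(|boxed| boxed.downcast_ref::<T>())
            .ok_or_else(|| format!("data `{}` not found", std::any::type_name::<T>()))
    }
}

struct AuthenticatedUser {
    name: String,
}

fn main() {
    // Each request owns its own RequestData, so concurrent users
    // cannot overwrite each other's AuthenticatedUser.
    let mut alice_request = RequestData::default();
    alice_request.insert(AuthenticatedUser { name: "alice".into() });
    let anonymous_request = RequestData::default();

    println!("{}", alice_request.get::<AuthenticatedUser>().unwrap().name); // alice
    println!("{}", anonymous_request.get::<AuthenticatedUser>().is_err()); // true
}
```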