How to implement shared mutable state in a Rust Axum backend server?

I am working on what can best be described as a ChatGPT clone website using Axum and HTMX. I'm really struggling to understand how shared mutable state can be implemented on an Axum backend. I've been referencing this example, but when I finally broke down and cloned it, it didn't even compile, so I'm worried I'm referencing outdated information. My current main() function is this:

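use std::sync::{Arc, RwLock};

use axum::{routing::get, Extension, Router};
use serde::{Deserialize, Serialize};

pub type SharedState = Arc<RwLock<AppState>>;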

#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct AppState {
    chats: Vec<Chat>,
    current_chat_idx: u8,
}

#[tokio::main]
async fn main() {
    let state = SharedState::default();
    let router: Router<()> = Router::new()
        .route("/", get(home))
        .route("/chat/:current_agent", get(chat))
        .route("/snippets/chat_list", get(chat_list))
        .route("/snippets/chat_buttons", get(chat_buttons))
        .route("/snippets/chat_history", get(chat_history))
        .layer(Extension(state))
        .with_state(AppState::default());
    axum::Server::bind(&"0.0.0.0:3000".parse().unwrap())
        .serve(router.into_make_service())
        .await
        .unwrap();
}

The above code compiles and runs, but I'm confused. If .layer() gives me access to the SharedState struct, why do I also need to pass another instance of AppState with the .with_state() method? Isn't this a waste of memory and completely pointless, apart from being the only way I've found to get this to compile?

I've written these endpoints using state: Extension<SharedState> as a parameter:

pub async fn chat_history(state: Extension<SharedState>) -> Html<String> {
    let mut shared_state = state.write().unwrap();
    let agent = shared_state.current_agent();

    let messages: Vec<MessageRender> = agent
        .memory
        .cache()
        .as_ref()
        .into_iter()
        .map(|m| MessageRender::from(m))
        .collect();

    let template = fs::read_to_string("src/html/snippets/chat_history.html").unwrap();
    let r = render!(&template, messages => messages);
    Html(r)
}
pub async fn chat_list(state: Extension<SharedState>) -> Html<String> {
    let shared_state = state.read().unwrap();
    let current_chat_name = &shared_state.current_chat().name;
    println!(
        "Chat list current name, idx: {}, {}",
        current_chat_name,
        // Reuse the read guard taken above; locking the std RwLock a second time here may deadlock or panic.
        shared_state.current_chat_idx
    );

    let template = fs::read_to_string("src/html/snippets/chat_list.html").unwrap();
    let r = render!(&template, chats => shared_state.chats, current_chat_name => current_chat_name);
    Html(r)
}

They both compile fine and behave as expected on the webpage. But when I try to compile this endpoint:

pub async fn post_prompt(
    state: Extension<super::SharedState>,
    Form(prompt): Form<PromptForm>,
) -> Html<String> {
    let mut shared_state = state.write().unwrap();
    let agent = shared_state.current_agent();

    let res = agent.prompt(prompt.user_input).await;
    if let Ok(response) = res {
        println!("{}", response);
    }
    Html(String::new())
}

I get this error:

error[E0277]: the trait bound `fn(axum::Extension<Arc<std::sync::RwLock<AppState>>>, Form<PromptForm>) -> impl Future<Output = axum::response::Html<std::string::String>> {api::post_prompt}: Handler<_, _, _>` is not satisfied
   --> src/main.rs:83:58
    |
83  |         .route("/chat/:current_agent/prompt_agent", post(post_prompt))
    |                                                     ---- ^^^^^^^^^^^ the trait `Handler<_, _, _>` is not implemented for fn item `fn(axum::Extension<Arc<std::sync::RwLock<AppState>>>, Form<PromptForm>) -> impl Future<Output = axum::response::Html<std::string::String>> {api::post_prompt}`
    |                                                     |
    |                                                     required by a bound introduced by this call
    |
    = help: the following other types implement trait `Handler<T, S, B>`:
              <Layered<L, H, T, S, B, B2> as Handler<T, S, B2>>
              <MethodRouter<S, B> as Handler<(), S, B>>

I tried removing the Form parameter and compiling with only the Extension parameter, but I got the same error. I would really appreciate some help.
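The error message never mentions `Send`, which makes it hard to diagnose. A minimal, std-only sketch of the underlying rule (no axum needed): axum only accepts handlers whose futures are `Send`, and a future that keeps a `std::sync::RwLock` guard alive across an `.await` is not `Send`. `slow_call`, `holds_guard`, `drops_guard`, `assert_send` and `block_on` are hypothetical helpers for illustration; `block_on` is a toy executor that only suits futures that never actually wait.

```rust
use std::future::Future;
use std::pin::pin;
use std::sync::{Arc, RwLock};
use std::task::{Context, Poll, Wake, Waker};

// Compile-time check: only accepts futures that are Send, like axum's Handler bound.
fn assert_send<F: Future + Send>(fut: F) -> F {
    fut
}

// Hypothetical stand-in for an async call such as agent.prompt(...).
async fn slow_call(n: u32) -> u32 {
    n + 1
}

// Passing this future to assert_send fails to compile:
// the std guard is still alive at the .await, and it is not Send.
//
// async fn holds_guard(state: Arc<RwLock<u32>>) -> u32 {
//     let guard = state.write().unwrap();
//     slow_call(*guard).await
// }

// Compiles: the guard lives only inside the inner block and is dropped before .await.
async fn drops_guard(state: Arc<RwLock<u32>>) -> u32 {
    let input = {
        let guard = state.read().unwrap();
        *guard
    };
    slow_call(input).await
}

// Toy executor: a no-op waker plus a busy poll loop, enough for futures that never wait.
struct NoopWaker;

impl Wake for NoopWaker {
    fn wake(self: Arc<Self>) {}
}

fn block_on<F: Future>(fut: F) -> F::Output {
    let waker = Waker::from(Arc::new(NoopWaker));
    let mut cx = Context::from_waker(&waker);
    let mut fut = pin!(fut);
    loop {
        if let Poll::Ready(out) = fut.as_mut().poll(&mut cx) {
            return out;
        }
    }
}

fn main() {
    let state = Arc::new(RwLock::new(41));
    let answer = block_on(assert_send(drops_guard(state)));
    println!("{answer}"); // prints 42
}
```

Scoping the guard like this only works when the awaited call doesn't need the locked data. In post_prompt, `agent` is borrowed from the guard during the `.await`, so the guard cannot be dropped first; that is the situation an async-aware lock such as `tokio::sync::RwLock` is designed for.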

Answer by voidkandy:

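The reason the post_prompt endpoint wasn't compiling was that it held the write guard of a std::sync::RwLock across an .await. That guard is not Send, so the handler's future is not Send either, and axum only implements Handler for functions whose futures are Send (the error message doesn't say this directly). Changing the std::sync::RwLock to a tokio::sync::RwLock fixed the issue, and the handler code now compiles. With tokio's lock, write() is async, so the line becomes state.write().await instead of state.write().unwrap().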