I am writing a service in Rust and came up with the following solution to abstract the business logic from the database implementation:
use std::ops::AsyncFnOnce;
use async_trait::async_trait;
/// A database transaction that is finished by consuming it with either
/// [`commit`](DbTransaction::commit) or [`rollback`](DbTransaction::rollback).
#[async_trait]
pub trait DbTransaction {
    /// Handle type that code running inside the transaction works with.
    type Db: Send + Sync;

    /// Lends out the handle this transaction runs on.
    ///
    /// # Errors
    /// Implementation-defined; callers are expected to propagate it.
    fn db(&mut self) -> anyhow::Result<&mut Self::Db>;

    /// Commits the transaction, consuming it.
    async fn commit(self) -> anyhow::Result<()>;

    /// Rolls the transaction back, consuming it.
    async fn rollback(self) -> anyhow::Result<()>;
}
/// Source of database handles and of transactions over the same handle type.
#[async_trait]
pub trait DbPool {
    /// Owned handle type returned by [`acquire`](DbPool::acquire).
    type Db: Send + Sync;

    /// Transaction type. Its `Db` is forced to equal this pool's `Db`, so code written
    /// against `Db` can run both on an acquired handle and inside a transaction.
    type Transaction: Send + Sync + DbTransaction<Db = Self::Db>;

    /// Obtains an owned handle.
    async fn acquire(&self) -> anyhow::Result<Self::Db>;

    /// Begins a new transaction.
    async fn transaction(&self) -> anyhow::Result<Self::Transaction>;
}
pub async fn run_transaction<Trans, Res, Scope>(mut trans: Trans, scope: Scope) -> Result<Res, anyhow::Error>
where Trans: DbTransaction,
Scope: AsyncFnOnce(&mut Trans::Db) -> Result<Res, anyhow::Error>
{
let result = scope(trans.db()?).await;
match result {
Ok(res) => {
trans.commit().await?;
Ok(res)
},
Err(err) => {
if let Err(rollback_err) = trans.rollback().await {
eprintln!("rollback error: {}", rollback_err)
}
Err(err)
}
}
}
In another module of the same package, I implemented these traits for the sqlx library:
use std::{env, marker::PhantomData};
use async_trait::async_trait;
use dotenvy::dotenv;
use sqlx::{PgConnection, PgPool, Postgres, Transaction};
use crate::domain::core::db::{DbPool, DbTransaction};
/// Database handle used by the Postgres adapter.
///
/// This is only an alias: `PgDb` *is* sqlx's `PgConnection`. Any trait impl written "for
/// `PgDb`" is therefore an impl for that foreign type, and it is subject to the orphan rule.
pub type PgDb = PgConnection;
/// Postgres [`DbTransaction`]: a thin wrapper around an sqlx `Transaction`.
pub struct PgTransaction<'c>(Transaction<'c, Postgres>);
#[async_trait]
impl<'a> DbTransaction for PgTransaction<'a> {
type Db = PgDb;
fn db(&mut self) -> Result<&mut Self::Db, anyhow::Error> {
Ok(&mut *self.0) // MUST RETURN REFERENCE
}
async fn commit(self) -> Result<(), anyhow::Error> {
self.0.commit().await?;
Ok(())
}
async fn rollback(self) -> Result<(), anyhow::Error> {
self.0.rollback().await?;
Ok(())
}
}
/// Postgres [`DbPool`] over an sqlx `PgPool`.
///
/// The `'p` parameter only exists so that the pool's transaction type can be
/// `PgTransaction<'p>`. No borrow is actually stored; the second field is a `PhantomData` marker.
pub struct PgPoolImpl<'p>(PgPool, PhantomData<&'p mut PgPool>);
/// [`DbPool`] implementation that delegates to the wrapped sqlx `PgPool` (`self.0`).
#[async_trait]
impl<'a> DbPool for PgPoolImpl<'a> {
    type Db = PgDb;
    type Transaction = PgTransaction<'a>;

    /// Checks a connection out of the pool and detaches it, so the caller owns a
    /// standalone `PgConnection` that is no longer managed by the pool.
    ///
    /// # Errors
    /// Returns the sqlx error if no connection could be acquired.
    async fn acquire(&self) -> Result<PgDb, anyhow::Error> {
        // Go through the inner `PgPool`: `self.acquire()` would resolve to this very
        // trait method and recurse forever.
        let pooled = self.0.acquire().await?;
        // NOTE(review): per the sqlx docs, `detach` lets the pool open a replacement, so
        // detached connections are not counted against `max_connections`.
        Ok(pooled.detach())
    }

    /// Begins a transaction on a connection from the pool.
    async fn transaction(&self) -> Result<Self::Transaction, anyhow::Error> {
        Ok(PgTransaction(self.0.begin().await?))
    }
}
/// Builds a [`PgPoolImpl`] connected to the database named by `DATABASE_URL`.
///
/// A `.env` file is loaded first if present. Failing to load it is ignored, so the variable
/// may also come from the real process environment.
///
/// # Panics
/// Panics if `DATABASE_URL` is unset (or not valid Unicode), or if the pool cannot connect.
pub async fn establish_connection<'a>() -> PgPoolImpl<'a> {
    // Best effort: a missing `.env` file is not an error.
    dotenv().ok();
    let db_url = env::var("DATABASE_URL").expect("DATABASE_URL is missing");
    let pool = PgPool::connect(&db_url)
        .await
        .expect("failed to connect to pool");
    PgPoolImpl(pool, PhantomData)
}
This makes it easy to write services, and tests for them, without being tied to a specific database, like this:
use std::marker::PhantomData;
use std::sync::Arc;
use async_trait::async_trait;
use crate::domain::core::db::{run_transaction, DbPool};
use crate::domain::internal::verification::code_generator::VerificationCodeGenerator;
use crate::domain::models::register_requests::RegisterRequest;
use crate::domain::models::users::{CreateUserInfo, UserRole, UserUniqueInfo};
use crate::domain::repositories::register_requests::RegisterRequestsRepo;
use crate::domain::repositories::users::{UserConflict, UsersRepo};
use crate::domain::services::{ServiceError, ServiceResult};
/// Business-level reasons why a registration can be refused.
pub enum RegisterUserError {
    /// The requested login/email clashes with an existing user, as reported by
    /// `UsersRepo::get_conflict`.
    Conflict(UserConflict),
}
/// Input for registering a new user.
///
/// NOTE(review): this carries the raw password, so avoid deriving `Debug` here (or redact
/// the field) to keep it out of logs.
pub struct RegisterUserInfo {
    /// Requested login; must not clash with an existing user.
    pub login: String,
    /// Requested email; must not clash with an existing user.
    pub email: String,
    /// Password as submitted. It is passed unchanged into `CreateUserInfo.password`.
    /// NOTE(review): confirm it is hashed before it is persisted.
    pub password: String,
}
/// Everything the registration service needs from its pool: a thread-safe [`DbPool`] whose
/// handle type implements both `UsersRepo` and `RegisterRequestsRepo`.
///
/// NOTE: trait aliases require the nightly `trait_alias` feature.
pub trait RegisterUserServiceRepo<Db> = DbPool<Db = Db> + Send + Sync
where
    Db: UsersRepo + RegisterRequestsRepo + Send + Sync;
pub struct RegisterUserServiceImpl<Db, Repo: RegisterUserServiceRepo<Db>> {
_phantom: PhantomData<Db>,
repo: Repo,
verification_code_gen: Arc<dyn VerificationCodeGenerator + Sync + Send>
}
/// Use case: registering a new user account.
#[async_trait]
pub trait RegisterUserService {
    /// Registers the user described by `info` and returns the id assigned to it.
    async fn register_user_and_get_id(
        &self,
        info: RegisterUserInfo,
    ) -> ServiceResult<uuid::Uuid, RegisterUserError>;
}
#[async_trait]
impl<Db, Repo: RegisterUserServiceRepo<Db>> RegisterUserService for RegisterUserServiceImpl<Db, Repo> {
async fn register_user_and_get_id(&self, info: RegisterUserInfo) -> ServiceResult<uuid::Uuid, RegisterUserError> {
run_transaction(self.repo.transaction().await?, async |db| -> Result<_, anyhow::Error> {
if let Some(conflict) = db.get_conflict(UserUniqueInfo {
login: info.login.clone(),
email: info.email.clone(),
}).await? {
return Ok(Err(ServiceError::Service(RegisterUserError::Conflict(conflict))));
}
let id = db.add_user_and_get_id(CreateUserInfo {
login: info.login,
role: UserRole::User,
password: info.password,
email: info.email,
}).await?;
db.add_request(RegisterRequest {
user_id: id,
verification_code: self.verification_code_gen.generate(),
}).await?;
Ok(Ok(id))
}).await?
}
}
The problem is that this only works while DbTransaction and PgDb live in the same package. If I split the code into separate Rust packages (crates, not modules), the compiler reports that I cannot implement the foreign traits DbTransaction and DbPool for PgDb, because PgDb is just an alias of the foreign type PgConnection. Rust's orphan rule forbids implementing a foreign trait for a foreign type.
What should I change so that the packages can be split? I suspect PgDb has to change somehow, but I don't see how. I tried making PgDb a struct that wraps PgConnection, but that runs into another problem. DbPool::acquire has to return a new PgDb (which owns a freshly acquired connection). DbTransaction::db has to return a mutable reference to a PgDb (so that the transaction keeps ownership of its connection). With a wrapper struct, I cannot do the latter in PgTransaction. `*self.0` is a PgConnection owned by the sqlx Transaction (reached through `Transaction::deref`). Wrapping it in PgDb would mean copying the connection out of the transaction and returning a reference to the copy. That is impossible because PgConnection does not implement Copy.