Skip to content

Commit

Permalink
Flatten error handling, reducing generics (#54)
Browse files Browse the repository at this point in the history
* Simplify a lot of the error propagation

* Flatten error types
  • Loading branch information
smklein authored Oct 6, 2023
1 parent da04c08 commit 1446f7e
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 58 deletions.
1 change: 1 addition & 0 deletions examples/customize_connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ impl bb8::CustomizeConnection<DieselPgConn, ConnectionError> for ConnectionCusto
connection
.batch_execute_async("please execute some raw sql for me")
.await
.map_err(ConnectionError::from)
}
}

Expand Down
5 changes: 2 additions & 3 deletions examples/usage.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use async_bb8_diesel::{
AsyncConnection, AsyncRunQueryDsl, AsyncSaveChangesDsl, ConnectionError, OptionalExtension,
};
use async_bb8_diesel::{AsyncConnection, AsyncRunQueryDsl, AsyncSaveChangesDsl, ConnectionError};
use diesel::OptionalExtension;
use diesel::{pg::PgConnection, prelude::*};

table! {
Expand Down
85 changes: 38 additions & 47 deletions src/async_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,42 +18,37 @@ use tokio::task::spawn_blocking;

/// An async variant of [`diesel::connection::SimpleConnection`].
#[async_trait]
pub trait AsyncSimpleConnection<Conn, ConnErr>
pub trait AsyncSimpleConnection<Conn>
where
Conn: 'static + SimpleConnection,
{
async fn batch_execute_async(&self, query: &str) -> Result<(), ConnErr>;
async fn batch_execute_async(&self, query: &str) -> Result<(), DieselError>;
}

/// An async variant of [`diesel::connection::Connection`].
#[async_trait]
pub trait AsyncConnection<Conn, ConnErr>: AsyncSimpleConnection<Conn, ConnErr>
pub trait AsyncConnection<Conn>: AsyncSimpleConnection<Conn>
where
Conn: 'static + DieselConnection,
ConnErr: From<DieselError> + Send + 'static,
Self: Send,
{
type OwnedConnection: Sync + Send + 'static;

#[doc(hidden)]
async fn get_owned_connection(&self) -> Result<Self::OwnedConnection, ConnErr>;
async fn get_owned_connection(&self) -> Self::OwnedConnection;
#[doc(hidden)]
fn as_sync_conn(owned: &Self::OwnedConnection) -> MutexGuard<'_, Conn>;
#[doc(hidden)]
fn as_async_conn(owned: &Self::OwnedConnection) -> &SingleConnection<Conn>;

/// Runs the function `f` in an context where blocking is safe.
///
/// Any error may be propagated through `f`, as long as that
/// error type may be constructed from `ConnErr` (as that error
/// type may also be generated).
async fn run<R, E, Func>(&self, f: Func) -> Result<R, E>
where
R: Send + 'static,
E: From<ConnErr> + Send + 'static,
E: Send + 'static,
Func: FnOnce(&mut Conn) -> Result<R, E> + Send + 'static,
{
let connection = self.get_owned_connection().await?;
let connection = self.get_owned_connection().await;
Self::run_with_connection(connection, f).await
}

Expand All @@ -64,7 +59,7 @@ where
) -> Result<R, E>
where
R: Send + 'static,
E: From<ConnErr> + Send + 'static,
E: Send + 'static,
Func: FnOnce(&mut Conn) -> Result<R, E> + Send + 'static,
{
spawn_blocking(move || f(&mut *Self::as_sync_conn(&connection)))
Expand All @@ -79,7 +74,7 @@ where
) -> Result<R, E>
where
R: Send + 'static,
E: From<ConnErr> + Send + 'static,
E: Send + 'static,
Func: FnOnce(&mut Conn) -> Result<R, E> + Send + 'static,
{
spawn_blocking(move || f(&mut *Self::as_sync_conn(&connection)))
Expand All @@ -90,7 +85,7 @@ where
async fn transaction<R, E, Func>(&self, f: Func) -> Result<R, E>
where
R: Send + 'static,
E: From<DieselError> + From<ConnErr> + Send + 'static,
E: From<DieselError> + Send + 'static,
Func: FnOnce(&mut Conn) -> Result<R, E> + Send + 'static,
{
self.run(|conn| conn.transaction(|c| f(c))).await
Expand All @@ -99,21 +94,21 @@ where
async fn transaction_async<R, E, Func, Fut, 'a>(&'a self, f: Func) -> Result<R, E>
where
R: Send + 'static,
E: From<DieselError> + From<ConnErr> + Send,
E: From<DieselError> + Send + 'static,
Fut: Future<Output = Result<R, E>> + Send,
Func: FnOnce(SingleConnection<Conn>) -> Fut + Send,
{
// Check out a connection once, and use it for the duration of the
// operation.
let conn = Arc::new(self.get_owned_connection().await?);
let conn = Arc::new(self.get_owned_connection().await);

// This function mimics the implementation of:
// https://docs.diesel.rs/master/diesel/connection/trait.TransactionManager.html#method.transaction
//
// However, it modifies all callsites to instead issue
// known-to-be-synchronous operations from an asynchronous context.
Self::run_with_shared_connection(conn.clone(), |conn| {
Conn::TransactionManager::begin_transaction(conn).map_err(ConnErr::from)
Conn::TransactionManager::begin_transaction(conn).map_err(E::from)
})
.await?;

Expand All @@ -133,14 +128,14 @@ where
match f(async_conn).await {
Ok(value) => {
Self::run_with_shared_connection(conn.clone(), |conn| {
Conn::TransactionManager::commit_transaction(conn).map_err(ConnErr::from)
Conn::TransactionManager::commit_transaction(conn).map_err(E::from)
})
.await?;
Ok(value)
}
Err(user_error) => {
match Self::run_with_shared_connection(conn.clone(), |conn| {
Conn::TransactionManager::rollback_transaction(conn).map_err(ConnErr::from)
Conn::TransactionManager::rollback_transaction(conn).map_err(E::from)
})
.await
{
Expand All @@ -154,112 +149,108 @@ where

/// An async variant of [`diesel::query_dsl::RunQueryDsl`].
#[async_trait]
pub trait AsyncRunQueryDsl<Conn, AsyncConn, E>
pub trait AsyncRunQueryDsl<Conn, AsyncConn>
where
Conn: 'static + DieselConnection,
{
async fn execute_async(self, asc: &AsyncConn) -> Result<usize, E>
async fn execute_async(self, asc: &AsyncConn) -> Result<usize, DieselError>
where
Self: ExecuteDsl<Conn>;

async fn load_async<U>(self, asc: &AsyncConn) -> Result<Vec<U>, E>
async fn load_async<U>(self, asc: &AsyncConn) -> Result<Vec<U>, DieselError>
where
U: Send + 'static,
Self: LoadQuery<'static, Conn, U>;

async fn get_result_async<U>(self, asc: &AsyncConn) -> Result<U, E>
async fn get_result_async<U>(self, asc: &AsyncConn) -> Result<U, DieselError>
where
U: Send + 'static,
Self: LoadQuery<'static, Conn, U>;

async fn get_results_async<U>(self, asc: &AsyncConn) -> Result<Vec<U>, E>
async fn get_results_async<U>(self, asc: &AsyncConn) -> Result<Vec<U>, DieselError>
where
U: Send + 'static,
Self: LoadQuery<'static, Conn, U>;

async fn first_async<U>(self, asc: &AsyncConn) -> Result<U, E>
async fn first_async<U>(self, asc: &AsyncConn) -> Result<U, DieselError>
where
U: Send + 'static,
Self: LimitDsl,
Limit<Self>: LoadQuery<'static, Conn, U>;
}

#[async_trait]
impl<T, AsyncConn, Conn, E> AsyncRunQueryDsl<Conn, AsyncConn, E> for T
impl<T, AsyncConn, Conn> AsyncRunQueryDsl<Conn, AsyncConn> for T
where
T: 'static + Send + RunQueryDsl<Conn>,
Conn: 'static + DieselConnection,
AsyncConn: Send + Sync + AsyncConnection<Conn, E>,
E: From<DieselError> + Send + 'static,
AsyncConn: Send + Sync + AsyncConnection<Conn>,
{
async fn execute_async(self, asc: &AsyncConn) -> Result<usize, E>
async fn execute_async(self, asc: &AsyncConn) -> Result<usize, DieselError>
where
Self: ExecuteDsl<Conn>,
{
asc.run(|conn| self.execute(conn).map_err(E::from)).await
asc.run(|conn| self.execute(conn)).await
}

async fn load_async<U>(self, asc: &AsyncConn) -> Result<Vec<U>, E>
async fn load_async<U>(self, asc: &AsyncConn) -> Result<Vec<U>, DieselError>
where
U: Send + 'static,
Self: LoadQuery<'static, Conn, U>,
{
asc.run(|conn| self.load(conn).map_err(E::from)).await
asc.run(|conn| self.load(conn)).await
}

async fn get_result_async<U>(self, asc: &AsyncConn) -> Result<U, E>
async fn get_result_async<U>(self, asc: &AsyncConn) -> Result<U, DieselError>
where
U: Send + 'static,
Self: LoadQuery<'static, Conn, U>,
{
asc.run(|conn| self.get_result(conn).map_err(E::from)).await
asc.run(|conn| self.get_result(conn)).await
}

async fn get_results_async<U>(self, asc: &AsyncConn) -> Result<Vec<U>, E>
async fn get_results_async<U>(self, asc: &AsyncConn) -> Result<Vec<U>, DieselError>
where
U: Send + 'static,
Self: LoadQuery<'static, Conn, U>,
{
asc.run(|conn| self.get_results(conn).map_err(E::from))
.await
asc.run(|conn| self.get_results(conn)).await
}

async fn first_async<U>(self, asc: &AsyncConn) -> Result<U, E>
async fn first_async<U>(self, asc: &AsyncConn) -> Result<U, DieselError>
where
U: Send + 'static,
Self: LimitDsl,
Limit<Self>: LoadQuery<'static, Conn, U>,
{
asc.run(|conn| self.first(conn).map_err(E::from)).await
asc.run(|conn| self.first(conn)).await
}
}

#[async_trait]
pub trait AsyncSaveChangesDsl<Conn, AsyncConn, E>
pub trait AsyncSaveChangesDsl<Conn, AsyncConn>
where
Conn: 'static + DieselConnection,
{
async fn save_changes_async<Output>(self, asc: &AsyncConn) -> Result<Output, E>
async fn save_changes_async<Output>(self, asc: &AsyncConn) -> Result<Output, DieselError>
where
Self: Sized,
Conn: diesel::query_dsl::UpdateAndFetchResults<Self, Output>,
Output: Send + 'static;
}

#[async_trait]
impl<T, AsyncConn, Conn, E> AsyncSaveChangesDsl<Conn, AsyncConn, E> for T
impl<T, AsyncConn, Conn> AsyncSaveChangesDsl<Conn, AsyncConn> for T
where
T: 'static + Send + Sync + diesel::SaveChangesDsl<Conn>,
Conn: 'static + DieselConnection,
AsyncConn: Send + Sync + AsyncConnection<Conn, E>,
E: 'static + Send + From<DieselError>,
AsyncConn: Send + Sync + AsyncConnection<Conn>,
{
async fn save_changes_async<Output>(self, asc: &AsyncConn) -> Result<Output, E>
async fn save_changes_async<Output>(self, asc: &AsyncConn) -> Result<Output, DieselError>
where
Conn: diesel::query_dsl::UpdateAndFetchResults<Self, Output>,
Output: Send + 'static,
{
asc.run(|conn| self.save_changes(conn).map_err(E::from))
.await
asc.run(|conn| self.save_changes(conn)).await
}
}
14 changes: 6 additions & 8 deletions src/connection.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
//! An async wrapper around a [`diesel::Connection`].
use crate::{ConnectionError, ConnectionResult};
use async_trait::async_trait;
use diesel::r2d2::R2D2Connection;
use std::sync::{Arc, Mutex, MutexGuard};
Expand Down Expand Up @@ -31,31 +30,30 @@ impl<C> Connection<C> {
}

#[async_trait]
impl<Conn> crate::AsyncSimpleConnection<Conn, ConnectionError> for Connection<Conn>
impl<Conn> crate::AsyncSimpleConnection<Conn> for Connection<Conn>
where
Conn: 'static + R2D2Connection,
{
#[inline]
async fn batch_execute_async(&self, query: &str) -> ConnectionResult<()> {
async fn batch_execute_async(&self, query: &str) -> Result<(), diesel::result::Error> {
let diesel_conn = Connection(self.0.clone());
let query = query.to_string();
task::spawn_blocking(move || diesel_conn.inner().batch_execute(&query))
.await
.unwrap() // Propagate panics
.map_err(ConnectionError::from)
}
}

#[async_trait]
impl<Conn> crate::AsyncConnection<Conn, ConnectionError> for Connection<Conn>
impl<Conn> crate::AsyncConnection<Conn> for Connection<Conn>
where
Conn: 'static + R2D2Connection,
Connection<Conn>: crate::AsyncSimpleConnection<Conn, ConnectionError>,
Connection<Conn>: crate::AsyncSimpleConnection<Conn>,
{
type OwnedConnection = Connection<Conn>;

async fn get_owned_connection(&self) -> Result<Self::OwnedConnection, ConnectionError> {
Ok(Connection(self.0.clone()))
async fn get_owned_connection(&self) -> Self::OwnedConnection {
Connection(self.0.clone())
}

fn as_sync_conn(owned: &Self::OwnedConnection) -> MutexGuard<'_, Conn> {
Expand Down

0 comments on commit 1446f7e

Please sign in to comment.