From 1446f7e0c1f05f33a0581abd51fa873c7652ab61 Mon Sep 17 00:00:00 2001 From: Sean Klein Date: Fri, 6 Oct 2023 12:28:40 -0700 Subject: [PATCH] Flatten error handling, reducing generics (#54) * Simplify a lot of the error propagation * Flatten error types --- examples/customize_connection.rs | 1 + examples/usage.rs | 5 +- src/async_traits.rs | 85 ++++++++++++++------------------ src/connection.rs | 14 +++--- 4 files changed, 47 insertions(+), 58 deletions(-) diff --git a/examples/customize_connection.rs b/examples/customize_connection.rs index b034b6a..0558327 100644 --- a/examples/customize_connection.rs +++ b/examples/customize_connection.rs @@ -15,6 +15,7 @@ impl bb8::CustomizeConnection for ConnectionCusto connection .batch_execute_async("please execute some raw sql for me") .await + .map_err(ConnectionError::from) } } diff --git a/examples/usage.rs b/examples/usage.rs index 8ee0136..2455891 100644 --- a/examples/usage.rs +++ b/examples/usage.rs @@ -1,6 +1,5 @@ -use async_bb8_diesel::{ - AsyncConnection, AsyncRunQueryDsl, AsyncSaveChangesDsl, ConnectionError, OptionalExtension, -}; +use async_bb8_diesel::{AsyncConnection, AsyncRunQueryDsl, AsyncSaveChangesDsl, ConnectionError}; +use diesel::OptionalExtension; use diesel::{pg::PgConnection, prelude::*}; table! { diff --git a/src/async_traits.rs b/src/async_traits.rs index 2c5cd13..4a19088 100644 --- a/src/async_traits.rs +++ b/src/async_traits.rs @@ -18,42 +18,37 @@ use tokio::task::spawn_blocking; /// An async variant of [`diesel::connection::SimpleConnection`]. #[async_trait] -pub trait AsyncSimpleConnection +pub trait AsyncSimpleConnection where Conn: 'static + SimpleConnection, { - async fn batch_execute_async(&self, query: &str) -> Result<(), ConnErr>; + async fn batch_execute_async(&self, query: &str) -> Result<(), DieselError>; } /// An async variant of [`diesel::connection::Connection`]. #[async_trait] -pub trait AsyncConnection: AsyncSimpleConnection +pub trait AsyncConnection: AsyncSimpleConnection where Conn: 'static + DieselConnection, - ConnErr: From + Send + 'static, Self: Send, { type OwnedConnection: Sync + Send + 'static; #[doc(hidden)] - async fn get_owned_connection(&self) -> Result; + async fn get_owned_connection(&self) -> Self::OwnedConnection; #[doc(hidden)] fn as_sync_conn(owned: &Self::OwnedConnection) -> MutexGuard<'_, Conn>; #[doc(hidden)] fn as_async_conn(owned: &Self::OwnedConnection) -> &SingleConnection; /// Runs the function `f` in an context where blocking is safe. - /// - /// Any error may be propagated through `f`, as long as that - /// error type may be constructed from `ConnErr` (as that error - /// type may also be generated). async fn run(&self, f: Func) -> Result where R: Send + 'static, - E: From + Send + 'static, + E: Send + 'static, Func: FnOnce(&mut Conn) -> Result + Send + 'static, { - let connection = self.get_owned_connection().await?; + let connection = self.get_owned_connection().await; Self::run_with_connection(connection, f).await } @@ -64,7 +59,7 @@ where ) -> Result where R: Send + 'static, - E: From + Send + 'static, + E: Send + 'static, Func: FnOnce(&mut Conn) -> Result + Send + 'static, { spawn_blocking(move || f(&mut *Self::as_sync_conn(&connection))) @@ -79,7 +74,7 @@ where ) -> Result where R: Send + 'static, - E: From + Send + 'static, + E: Send + 'static, Func: FnOnce(&mut Conn) -> Result + Send + 'static, { spawn_blocking(move || f(&mut *Self::as_sync_conn(&connection))) @@ -90,7 +85,7 @@ where async fn transaction(&self, f: Func) -> Result where R: Send + 'static, - E: From + From + Send + 'static, + E: From + Send + 'static, Func: FnOnce(&mut Conn) -> Result + Send + 'static, { self.run(|conn| conn.transaction(|c| f(c))).await @@ -99,13 +94,13 @@ where async fn transaction_async(&'a self, f: Func) -> Result where R: Send + 'static, - E: From + From + Send, + E: From + Send + 'static, Fut: Future> + Send, Func: FnOnce(SingleConnection) -> Fut + Send, { // Check out a connection once, and use it for the duration of the // operation. - let conn = Arc::new(self.get_owned_connection().await?); + let conn = Arc::new(self.get_owned_connection().await); // This function mimics the implementation of: // https://docs.diesel.rs/master/diesel/connection/trait.TransactionManager.html#method.transaction @@ -113,7 +108,7 @@ where // However, it modifies all callsites to instead issue // known-to-be-synchronous operations from an asynchronous context. Self::run_with_shared_connection(conn.clone(), |conn| { - Conn::TransactionManager::begin_transaction(conn).map_err(ConnErr::from) + Conn::TransactionManager::begin_transaction(conn).map_err(E::from) }) .await?; @@ -133,14 +128,14 @@ where match f(async_conn).await { Ok(value) => { Self::run_with_shared_connection(conn.clone(), |conn| { - Conn::TransactionManager::commit_transaction(conn).map_err(ConnErr::from) + Conn::TransactionManager::commit_transaction(conn).map_err(E::from) }) .await?; Ok(value) } Err(user_error) => { match Self::run_with_shared_connection(conn.clone(), |conn| { - Conn::TransactionManager::rollback_transaction(conn).map_err(ConnErr::from) + Conn::TransactionManager::rollback_transaction(conn).map_err(E::from) }) .await { @@ -154,30 +149,30 @@ where /// An async variant of [`diesel::query_dsl::RunQueryDsl`]. #[async_trait] -pub trait AsyncRunQueryDsl +pub trait AsyncRunQueryDsl where Conn: 'static + DieselConnection, { - async fn execute_async(self, asc: &AsyncConn) -> Result + async fn execute_async(self, asc: &AsyncConn) -> Result where Self: ExecuteDsl; - async fn load_async(self, asc: &AsyncConn) -> Result, E> + async fn load_async(self, asc: &AsyncConn) -> Result, DieselError> where U: Send + 'static, Self: LoadQuery<'static, Conn, U>; - async fn get_result_async(self, asc: &AsyncConn) -> Result + async fn get_result_async(self, asc: &AsyncConn) -> Result where U: Send + 'static, Self: LoadQuery<'static, Conn, U>; - async fn get_results_async(self, asc: &AsyncConn) -> Result, E> + async fn get_results_async(self, asc: &AsyncConn) -> Result, DieselError> where U: Send + 'static, Self: LoadQuery<'static, Conn, U>; - async fn first_async(self, asc: &AsyncConn) -> Result + async fn first_async(self, asc: &AsyncConn) -> Result where U: Send + 'static, Self: LimitDsl, @@ -185,61 +180,59 @@ where } #[async_trait] -impl AsyncRunQueryDsl for T +impl AsyncRunQueryDsl for T where T: 'static + Send + RunQueryDsl, Conn: 'static + DieselConnection, - AsyncConn: Send + Sync + AsyncConnection, - E: From + Send + 'static, + AsyncConn: Send + Sync + AsyncConnection, { - async fn execute_async(self, asc: &AsyncConn) -> Result + async fn execute_async(self, asc: &AsyncConn) -> Result where Self: ExecuteDsl, { - asc.run(|conn| self.execute(conn).map_err(E::from)).await + asc.run(|conn| self.execute(conn)).await } - async fn load_async(self, asc: &AsyncConn) -> Result, E> + async fn load_async(self, asc: &AsyncConn) -> Result, DieselError> where U: Send + 'static, Self: LoadQuery<'static, Conn, U>, { - asc.run(|conn| self.load(conn).map_err(E::from)).await + asc.run(|conn| self.load(conn)).await } - async fn get_result_async(self, asc: &AsyncConn) -> Result + async fn get_result_async(self, asc: &AsyncConn) -> Result where U: Send + 'static, Self: LoadQuery<'static, Conn, U>, { - asc.run(|conn| self.get_result(conn).map_err(E::from)).await + asc.run(|conn| self.get_result(conn)).await } - async fn get_results_async(self, asc: &AsyncConn) -> Result, E> + async fn get_results_async(self, asc: &AsyncConn) -> Result, DieselError> where U: Send + 'static, Self: LoadQuery<'static, Conn, U>, { - asc.run(|conn| self.get_results(conn).map_err(E::from)) - .await + asc.run(|conn| self.get_results(conn)).await } - async fn first_async(self, asc: &AsyncConn) -> Result + async fn first_async(self, asc: &AsyncConn) -> Result where U: Send + 'static, Self: LimitDsl, Limit: LoadQuery<'static, Conn, U>, { - asc.run(|conn| self.first(conn).map_err(E::from)).await + asc.run(|conn| self.first(conn)).await } } #[async_trait] -pub trait AsyncSaveChangesDsl +pub trait AsyncSaveChangesDsl where Conn: 'static + DieselConnection, { - async fn save_changes_async(self, asc: &AsyncConn) -> Result + async fn save_changes_async(self, asc: &AsyncConn) -> Result where Self: Sized, Conn: diesel::query_dsl::UpdateAndFetchResults, @@ -247,19 +240,17 @@ where } #[async_trait] -impl AsyncSaveChangesDsl for T +impl AsyncSaveChangesDsl for T where T: 'static + Send + Sync + diesel::SaveChangesDsl, Conn: 'static + DieselConnection, - AsyncConn: Send + Sync + AsyncConnection, - E: 'static + Send + From, + AsyncConn: Send + Sync + AsyncConnection, { - async fn save_changes_async(self, asc: &AsyncConn) -> Result + async fn save_changes_async(self, asc: &AsyncConn) -> Result where Conn: diesel::query_dsl::UpdateAndFetchResults, Output: Send + 'static, { - asc.run(|conn| self.save_changes(conn).map_err(E::from)) - .await + asc.run(|conn| self.save_changes(conn)).await } } diff --git a/src/connection.rs b/src/connection.rs index 03893b1..ff65b4e 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -1,6 +1,5 @@ //! An async wrapper around a [`diesel::Connection`]. -use crate::{ConnectionError, ConnectionResult}; use async_trait::async_trait; use diesel::r2d2::R2D2Connection; use std::sync::{Arc, Mutex, MutexGuard}; @@ -31,31 +30,30 @@ impl Connection { } #[async_trait] -impl crate::AsyncSimpleConnection for Connection +impl crate::AsyncSimpleConnection for Connection where Conn: 'static + R2D2Connection, { #[inline] - async fn batch_execute_async(&self, query: &str) -> ConnectionResult<()> { + async fn batch_execute_async(&self, query: &str) -> Result<(), diesel::result::Error> { let diesel_conn = Connection(self.0.clone()); let query = query.to_string(); task::spawn_blocking(move || diesel_conn.inner().batch_execute(&query)) .await .unwrap() // Propagate panics - .map_err(ConnectionError::from) } } #[async_trait] -impl crate::AsyncConnection for Connection +impl crate::AsyncConnection for Connection where Conn: 'static + R2D2Connection, - Connection: crate::AsyncSimpleConnection, + Connection: crate::AsyncSimpleConnection, { type OwnedConnection = Connection; - async fn get_owned_connection(&self) -> Result { - Ok(Connection(self.0.clone())) + async fn get_owned_connection(&self) -> Self::OwnedConnection { + Connection(self.0.clone()) } fn as_sync_conn(owned: &Self::OwnedConnection) -> MutexGuard<'_, Conn> {