Skip to content

Commit

Permalink
Replace async_trait with native async fns
Browse files Browse the repository at this point in the history
  • Loading branch information
gahag-cw authored and djc committed Dec 6, 2024
1 parent 2bde6f9 commit 38d74f7
Show file tree
Hide file tree
Showing 8 changed files with 35 additions and 39 deletions.
1 change: 0 additions & 1 deletion bb8/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ workspace = ".."
readme = "../README.md"

[dependencies]
async-trait = "0.1"
futures-util = { version = "0.3.2", default-features = false, features = ["alloc"] }
parking_lot = { version = "0.12", optional = true }
tokio = { version = "1.0", features = ["rt", "sync", "time"] }
Expand Down
20 changes: 12 additions & 8 deletions bb8/src/api.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use std::borrow::Cow;
use std::error;
use std::fmt;
use std::future::Future;
use std::marker::PhantomData;
use std::ops::{Deref, DerefMut};
use std::pin::Pin;
use std::time::Duration;

use async_trait::async_trait;

use crate::inner::PoolInner;
use crate::internals::Conn;

Expand Down Expand Up @@ -381,23 +381,24 @@ impl<M: ManageConnection> Builder<M> {
}

/// A trait which provides connection-specific functionality.
#[async_trait]
pub trait ManageConnection: Sized + Send + Sync + 'static {
/// The connection type this manager deals with.
type Connection: Send + 'static;
/// The error type returned by `Connection`s.
type Error: fmt::Debug + Send + 'static;

/// Attempts to create a new connection.
async fn connect(&self) -> Result<Self::Connection, Self::Error>;
fn connect(&self) -> impl Future<Output = Result<Self::Connection, Self::Error>> + Send;
/// Determines if the connection is still connected to the database.
async fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error>;
fn is_valid(
&self,
conn: &mut Self::Connection,
) -> impl Future<Output = Result<(), Self::Error>> + Send;
/// Synchronously determine if the connection is no longer usable, if possible.
fn has_broken(&self, conn: &mut Self::Connection) -> bool;
}

/// A trait which provides functionality to initialize a connection
#[async_trait]
pub trait CustomizeConnection<C: Send + 'static, E: 'static>:
fmt::Debug + Send + Sync + 'static
{
Expand All @@ -406,8 +407,11 @@ pub trait CustomizeConnection<C: Send + 'static, E: 'static>:
///
/// The default implementation simply returns `Ok(())`. If this method returns an
/// error, it will be forwarded to the configured error sink.
async fn on_acquire(&self, _connection: &mut C) -> Result<(), E> {
Ok(())
fn on_acquire<'a>(
&'a self,
_connection: &'a mut C,
) -> Pin<Box<dyn Future<Output = Result<(), E>> + Send + 'a>> {
Box::pin(async { Ok(()) })
}
}

Expand Down
21 changes: 8 additions & 13 deletions bb8/tests/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ use std::task::Poll;
use std::time::Duration;
use std::{error, fmt};

use async_trait::async_trait;
use futures_util::future::{err, lazy, ok, pending, ready, try_join_all, FutureExt};
use futures_util::stream::{FuturesUnordered, TryStreamExt};
use tokio::sync::oneshot;
Expand Down Expand Up @@ -43,7 +42,6 @@ impl<C> OkManager<C> {
}
}

#[async_trait]
impl<C> ManageConnection for OkManager<C>
where
C: Default + Send + Sync + 'static,
Expand Down Expand Up @@ -78,7 +76,6 @@ impl<C> NthConnectionFailManager<C> {
}
}

#[async_trait]
impl<C> ManageConnection for NthConnectionFailManager<C>
where
C: Default + Send + Sync + 'static,
Expand Down Expand Up @@ -214,7 +211,6 @@ struct BrokenConnectionManager<C> {
_c: PhantomData<C>,
}

#[async_trait]
impl<C: Default + Send + Sync + 'static> ManageConnection for BrokenConnectionManager<C> {
type Connection = C;
type Error = Error;
Expand Down Expand Up @@ -380,7 +376,6 @@ async fn test_now_invalid() {

struct Handler;

#[async_trait]
impl ManageConnection for Handler {
type Connection = FakeConnection;
type Error = Error;
Expand Down Expand Up @@ -689,7 +684,6 @@ async fn test_conns_drop_on_pool_drop() {

struct Handler;

#[async_trait]
impl ManageConnection for Handler {
type Connection = Connection;
type Error = Error;
Expand Down Expand Up @@ -741,7 +735,6 @@ async fn test_retry() {
struct Connection;
struct Handler;

#[async_trait]
impl ManageConnection for Handler {
type Connection = Connection;
type Error = Error;
Expand Down Expand Up @@ -787,7 +780,6 @@ async fn test_conn_fail_once() {
}
}

#[async_trait]
impl ManageConnection for Handler {
type Connection = Connection;
type Error = Error;
Expand Down Expand Up @@ -912,11 +904,15 @@ async fn test_customize_connection_acquire() {
count: AtomicUsize,
}

#[async_trait]
impl<E: 'static> CustomizeConnection<Connection, E> for CountingCustomizer {
async fn on_acquire(&self, connection: &mut Connection) -> Result<(), E> {
connection.custom_field = 1 + self.count.fetch_add(1, Ordering::SeqCst);
Ok(())
fn on_acquire<'a>(
&'a self,
connection: &'a mut Connection,
) -> Pin<Box<dyn Future<Output = Result<(), E>> + Send + 'a>> {
Box::pin(async move {
connection.custom_field = 1 + self.count.fetch_add(1, Ordering::SeqCst);
Ok(())
})
}
}

Expand Down Expand Up @@ -952,7 +948,6 @@ async fn test_broken_connections_dont_starve_pool() {
#[derive(Debug)]
struct Connection;

#[async_trait::async_trait]
impl bb8::ManageConnection for ConnectionManager {
type Connection = Connection;
type Error = Infallible;
Expand Down
1 change: 0 additions & 1 deletion postgres/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ repository = "https://github.com/djc/bb8"
"with-time-0_3" = ["tokio-postgres/with-time-0_3"]

[dependencies]
async-trait = "0.1"
bb8 = { version = "0.8", path = "../bb8" }
tokio = { version = "1.0.0", features = ["rt"] }
tokio-postgres = "0.7"
Expand Down
26 changes: 15 additions & 11 deletions postgres/examples/custom_state.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use std::collections::BTreeMap;
use std::future::Future;
use std::ops::Deref;
use std::pin::Pin;
use std::str::FromStr;

use async_trait::async_trait;
use bb8::{CustomizeConnection, Pool};
use bb8_postgres::PostgresConnectionManager;
use tokio_postgres::config::Config;
Expand Down Expand Up @@ -43,16 +44,20 @@ async fn main() {
#[derive(Debug)]
struct Customizer;

#[async_trait]
impl CustomizeConnection<CustomPostgresConnection, Error> for Customizer {
async fn on_acquire(&self, conn: &mut CustomPostgresConnection) -> Result<(), Error> {
conn.custom_state
.insert(QueryName::BasicSelect, conn.prepare("SELECT 1").await?);

conn.custom_state
.insert(QueryName::Addition, conn.prepare("SELECT 1 + 1 + 1").await?);

Ok(())
fn on_acquire<'a>(
&'a self,
conn: &'a mut CustomPostgresConnection,
) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + 'a>> {
Box::pin(async {
conn.custom_state
.insert(QueryName::BasicSelect, conn.prepare("SELECT 1").await?);

conn.custom_state
.insert(QueryName::Addition, conn.prepare("SELECT 1 + 1 + 1").await?);

Ok(())
})
}
}

Expand Down Expand Up @@ -96,7 +101,6 @@ where
}
}

#[async_trait]
impl<Tls> bb8::ManageConnection for CustomPostgresConnectionManager<Tls>
where
Tls: MakeTlsConnect<Socket> + Clone + Send + Sync + 'static,
Expand Down
2 changes: 0 additions & 2 deletions postgres/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
pub use bb8;
pub use tokio_postgres;

use async_trait::async_trait;
use tokio_postgres::config::Config;
use tokio_postgres::tls::{MakeTlsConnect, TlsConnect};
use tokio_postgres::{Client, Error, Socket};
Expand Down Expand Up @@ -45,7 +44,6 @@ where
}
}

#[async_trait]
impl<Tls> bb8::ManageConnection for PostgresConnectionManager<Tls>
where
Tls: MakeTlsConnect<Socket> + Clone + Send + Sync + 'static,
Expand Down
1 change: 0 additions & 1 deletion redis/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ license = "MIT"
repository = "https://github.com/djc/bb8"

[dependencies]
async-trait = "0.1"
bb8 = { version = "0.8", path = "../bb8" }
redis = { version = "0.27", default-features = false, features = ["tokio-comp"] }

Expand Down
2 changes: 0 additions & 2 deletions redis/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
pub use bb8;
pub use redis;

use async_trait::async_trait;
use redis::{aio::MultiplexedConnection, ErrorKind};
use redis::{Client, IntoConnectionInfo, RedisError};

Expand All @@ -58,7 +57,6 @@ impl RedisConnectionManager {
}
}

#[async_trait]
impl bb8::ManageConnection for RedisConnectionManager {
type Connection = MultiplexedConnection;
type Error = RedisError;
Expand Down

0 comments on commit 38d74f7

Please sign in to comment.