Skip to content

Commit

Permalink
Add pool retrieval to sync_db_pools.
Browse files Browse the repository at this point in the history
Generates a new method on attributed types, `pool()`, which returns an
opaque reference to a type that can be used to get pooled connections.

Also adds a code-generated example to the crate docs which includes
real, proper function signatures and fully checked examples.

Resolves rwf2#1884.
Closes rwf2#1972.
  • Loading branch information
SergioBenitez committed May 24, 2022
1 parent 5cb70ec commit 04819d8
Show file tree
Hide file tree
Showing 8 changed files with 308 additions and 68 deletions.
52 changes: 27 additions & 25 deletions contrib/sync_db_pools/codegen/src/database.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
use proc_macro::TokenStream;
use devise::{Spanned, Result, ext::SpanDiagnosticExt};

use crate::syn::{Fields, Data, Type, LitStr, DeriveInput, Ident, Visibility};
use crate::syn;

#[derive(Debug)]
struct DatabaseInvocation {
/// The attributes on the attributed structure.
attrs: Vec<syn::Attribute>,
/// The name of the structure on which `#[database(..)] struct This(..)` was invoked.
type_name: Ident,
type_name: syn::Ident,
/// The visibility of the structure on which `#[database(..)] struct This(..)` was invoked.
visibility: Visibility,
visibility: syn::Visibility,
/// The database name as passed in via #[database('database name')].
db_name: String,
/// The type inside the structure: struct MyDb(ThisType).
connection_type: Type,
connection_type: syn::Type,
}

const EXAMPLE: &str = "example: `struct MyDatabase(diesel::SqliteConnection);`";
Expand All @@ -24,27 +26,28 @@ const NO_GENERIC_STRUCTS: &str = "`database` attribute cannot be applied to stru

fn parse_invocation(attr: TokenStream, input: TokenStream) -> Result<DatabaseInvocation> {
let attr_stream2 = crate::proc_macro2::TokenStream::from(attr);
let string_lit = crate::syn::parse2::<LitStr>(attr_stream2)?;
let string_lit = crate::syn::parse2::<syn::LitStr>(attr_stream2)?;

let input = crate::syn::parse::<DeriveInput>(input).unwrap();
let input = crate::syn::parse::<syn::DeriveInput>(input).unwrap();
if !input.generics.params.is_empty() {
return Err(input.generics.span().error(NO_GENERIC_STRUCTS));
}

let structure = match input.data {
Data::Struct(s) => s,
syn::Data::Struct(s) => s,
_ => return Err(input.span().error(ONLY_ON_STRUCTS_MSG))
};

let inner_type = match structure.fields {
Fields::Unnamed(ref fields) if fields.unnamed.len() == 1 => {
syn::Fields::Unnamed(ref fields) if fields.unnamed.len() == 1 => {
let first = fields.unnamed.first().expect("checked length");
first.ty.clone()
}
_ => return Err(structure.fields.span().error(ONLY_UNNAMED_FIELDS).help(EXAMPLE))
};

Ok(DatabaseInvocation {
attrs: input.attrs,
type_name: input.ident,
visibility: input.vis,
db_name: string_lit.value(),
Expand All @@ -59,6 +62,7 @@ pub fn database_attr(attr: TokenStream, input: TokenStream) -> Result<TokenStrea
// Store everything we're going to need to generate code.
let conn_type = &invocation.connection_type;
let name = &invocation.db_name;
let attrs = &invocation.attrs;
let guard_type = &invocation.type_name;
let vis = &invocation.visibility;
let fairing_name = format!("'{}' Database Pool", name);
Expand All @@ -69,7 +73,7 @@ pub fn database_attr(attr: TokenStream, input: TokenStream) -> Result<TokenStrea
let rocket = quote!(#root::rocket);

let request_guard_type = quote_spanned! { span =>
#vis struct #guard_type(#root::Connection<Self, #conn_type>);
#(#attrs)* #vis struct #guard_type(#root::Connection<Self, #conn_type>);
};

let pool = quote_spanned!(span => #root::ConnectionPool<Self, #conn_type>);
Expand All @@ -79,32 +83,30 @@ pub fn database_attr(attr: TokenStream, input: TokenStream) -> Result<TokenStrea
#request_guard_type

impl #guard_type {
/// Returns a fairing that initializes the associated database
/// connection pool.
/// Returns a fairing that initializes the database connection pool.
pub fn fairing() -> impl #rocket::fairing::Fairing {
<#pool>::fairing(#fairing_name, #name)
}

/// Retrieves a connection of type `Self` from the `rocket`
/// instance. Returns `Some` as long as `Self::fairing()` has been
/// attached.
pub async fn get_one<P>(__rocket: &#rocket::Rocket<P>) -> Option<Self>
where P: #rocket::Phase,
{
<#pool>::get_one(&__rocket).await.map(Self)
/// Returns an opaque type that represents the connection pool
/// backing connections of type `Self`.
pub fn pool<P: #rocket::Phase>(__rocket: &#rocket::Rocket<P>) -> Option<&#pool> {
<#pool>::pool(&__rocket)
}

/// Runs the provided closure on a thread from a threadpool. The
/// closure will be passed an `&mut r2d2::PooledConnection`.
/// `.await`ing the return value of this function yields the value
/// returned by the closure.
/// Runs the provided function `__f` in an async-safe blocking
/// thread.
pub async fn run<F, R>(&self, __f: F) -> R
where
F: FnOnce(&mut #conn_type) -> R + Send + 'static,
R: Send + 'static,
where F: FnOnce(&mut #conn_type) -> R + Send + 'static,
R: Send + 'static,
{
self.0.run(__f).await
}

/// Retrieves a connection of type `Self` from the `rocket` instance.
pub async fn get_one<P: #rocket::Phase>(__rocket: &#rocket::Rocket<P>) -> Option<Self> {
<#pool>::get_one(&__rocket).await.map(Self)
}
}

#[#rocket::async_trait]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,20 @@ note: required by a bound in `rocket_sync_db_pools::Connection`
| pub struct Connection<K, C: Poolable> {
| ^^^^^^^^ required by this bound in `rocket_sync_db_pools::Connection`

error[E0277]: the trait bound `Unknown: Poolable` is not satisfied
--> tests/ui-fail-nightly/database-types.rs:5:1
|
5 | #[database("foo")]
| ^^^^^^^^^^^^^^^^^^ the trait `Poolable` is not implemented for `Unknown`
|
= help: the trait `Poolable` is implemented for `SqliteConnection`
note: required by a bound in `ConnectionPool`
--> $WORKSPACE/contrib/sync_db_pools/lib/src/connection.rs
|
| pub struct ConnectionPool<K, C: Poolable> {
| ^^^^^^^^ required by this bound in `ConnectionPool`
= note: this error originates in the attribute macro `database` (in Nightly builds, run with -Z macro-backtrace for more info)

error[E0277]: the trait bound `Vec<i32>: Poolable` is not satisfied
--> tests/ui-fail-nightly/database-types.rs:9:10
|
Expand All @@ -23,3 +37,17 @@ note: required by a bound in `rocket_sync_db_pools::Connection`
|
| pub struct Connection<K, C: Poolable> {
| ^^^^^^^^ required by this bound in `rocket_sync_db_pools::Connection`

error[E0277]: the trait bound `Vec<i32>: Poolable` is not satisfied
--> tests/ui-fail-nightly/database-types.rs:8:1
|
8 | #[database("foo")]
| ^^^^^^^^^^^^^^^^^^ the trait `Poolable` is not implemented for `Vec<i32>`
|
= help: the trait `Poolable` is implemented for `SqliteConnection`
note: required by a bound in `ConnectionPool`
--> $WORKSPACE/contrib/sync_db_pools/lib/src/connection.rs
|
| pub struct ConnectionPool<K, C: Poolable> {
| ^^^^^^^^ required by this bound in `ConnectionPool`
= note: this error originates in the attribute macro `database` (in Nightly builds, run with -Z macro-backtrace for more info)
3 changes: 3 additions & 0 deletions contrib/sync_db_pools/lib/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,5 +43,8 @@ version = "0.5.0-rc.2"
path = "../../../core/lib"
default-features = false

[build-dependencies]
version_check = "0.9.1"

[package.metadata.docs.rs]
all-features = true
5 changes: 5 additions & 0 deletions contrib/sync_db_pools/lib/build.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
fn main() {
if let Some(true) = version_check::is_feature_flaggable() {
println!("cargo:rustc-cfg=nightly");
}
}
19 changes: 9 additions & 10 deletions contrib/sync_db_pools/lib/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,36 +93,36 @@ impl<K: 'static, C: Poolable> ConnectionPool<K, C> {
})
}

async fn get(&self) -> Result<Connection<K, C>, ()> {
pub async fn get(&self) -> Option<Connection<K, C>> {
let duration = std::time::Duration::from_secs(self.config.timeout as u64);
let permit = match timeout(duration, self.semaphore.clone().acquire_owned()).await {
Ok(p) => p.expect("internal invariant broken: semaphore should not be closed"),
Err(_) => {
error_!("database connection retrieval timed out");
return Err(());
return None;
}
};

let pool = self.pool.as_ref().cloned()
.expect("internal invariant broken: self.pool is Some");

match run_blocking(move || pool.get_timeout(duration)).await {
Ok(c) => Ok(Connection {
Ok(c) => Some(Connection {
connection: Arc::new(Mutex::new(Some(c))),
permit: Some(permit),
_marker: PhantomData,
}),
Err(e) => {
error_!("failed to get a database connection: {}", e);
Err(())
None
}
}
}

#[inline]
pub async fn get_one<P: Phase>(rocket: &Rocket<P>) -> Option<Connection<K, C>> {
match rocket.state::<Self>() {
Some(pool) => match pool.get().await.ok() {
match Self::pool(rocket) {
Some(pool) => match pool.get().await {
Some(conn) => Some(conn),
None => {
error_!("no connections available for `{}`", std::any::type_name::<K>());
Expand All @@ -137,13 +137,12 @@ impl<K: 'static, C: Poolable> ConnectionPool<K, C> {
}

#[inline]
pub async fn get_pool<P: Phase>(rocket: &Rocket<P>) -> Option<Self> {
rocket.state::<Self>().cloned()
pub fn pool<P: Phase>(rocket: &Rocket<P>) -> Option<&Self> {
rocket.state::<Self>()
}
}

impl<K: 'static, C: Poolable> Connection<K, C> {
#[inline]
pub async fn run<F, R>(&self, f: F) -> R
where F: FnOnce(&mut C) -> R + Send + 'static,
R: Send + 'static,
Expand Down Expand Up @@ -207,7 +206,7 @@ impl<'r, K: 'static, C: Poolable> FromRequest<'r> for Connection<K, C> {
#[inline]
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, ()> {
match request.rocket().state::<ConnectionPool<K, C>>() {
Some(c) => c.get().await.into_outcome(Status::ServiceUnavailable),
Some(c) => c.get().await.into_outcome((Status::ServiceUnavailable, ())),
None => {
error_!("Missing database fairing for `{}`", std::any::type_name::<K>());
Outcome::Failure((Status::InternalServerError, ()))
Expand Down
Loading

0 comments on commit 04819d8

Please sign in to comment.