Skip to content

Commit

Permalink
Merge pull request #373 from nappa85/master
Browse files Browse the repository at this point in the history
  • Loading branch information
tyt2y3 authored Dec 14, 2021
2 parents 5656c49 + 9a34254 commit 7da5b6b
Show file tree
Hide file tree
Showing 15 changed files with 389 additions and 147 deletions.
9 changes: 5 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ async-trait = { version = "^0.1" }
chrono = { version = "^0", optional = true }
futures = { version = "^0.3" }
futures-util = { version = "^0.3" }
log = { version = "^0.4", optional = true }
tracing = "0.1"
rust_decimal = { version = "^1", optional = true }
sea-orm-macros = { version = "^0.4.2", path = "sea-orm-macros", optional = true }
sea-query = { version = "^0.19.4", features = ["thread-safe"] }
Expand All @@ -36,8 +36,9 @@ serde = { version = "^1.0", features = ["derive"] }
serde_json = { version = "^1", optional = true }
sqlx = { version = "^0.5", optional = true }
uuid = { version = "0.8", features = ["serde", "v4"], optional = true }
ouroboros = "0.11"
ouroboros = "0.14"
url = "^2.2"
once_cell = "1.8"

[dev-dependencies]
smol = { version = "^1.2" }
Expand All @@ -47,12 +48,12 @@ tokio = { version = "^1.6", features = ["full"] }
actix-rt = { version = "2.2.0" }
maplit = { version = "^1" }
rust_decimal_macros = { version = "^1" }
env_logger = { version = "^0.9" }
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
sea-orm = { path = ".", features = ["debug-print"] }
pretty_assertions = { version = "^0.7" }

[features]
debug-print = ["log"]
debug-print = []
default = [
"macros",
"mock",
Expand Down
5 changes: 2 additions & 3 deletions sea-orm-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -596,9 +596,8 @@ pub fn test(_: TokenStream, input: TokenStream) -> TokenStream {
#[test]
#(#attrs)*
fn #name() #ret {
let _ = ::env_logger::builder()
.filter_level(::log::LevelFilter::Debug)
.is_test(true)
let _ = ::tracing_subscriber::fmt()
.with_max_level(::tracing::Level::DEBUG)
.try_init();
crate::block_on!(async { #body })
}
Expand Down
32 changes: 29 additions & 3 deletions src/database/db_connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::{
StatementBuilder, TransactionError,
};
use sea_query::{MysqlQueryBuilder, PostgresQueryBuilder, QueryBuilder, SqliteQueryBuilder};
use tracing::instrument;
use std::{future::Future, pin::Pin};
use url::Url;

Expand Down Expand Up @@ -49,6 +50,7 @@ pub enum DatabaseBackend {

/// The same as [DatabaseBackend] just shorter :)
pub type DbBackend = DatabaseBackend;
#[derive(Debug)]
pub(crate) enum InnerConnection {
#[cfg(feature = "sqlx-mysql")]
MySql(PoolConnection<sqlx::MySql>),
Expand Down Expand Up @@ -104,6 +106,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection {
}
}

#[instrument(level = "trace")]
async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
match self {
#[cfg(feature = "sqlx-mysql")]
Expand All @@ -118,6 +121,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection {
}
}

#[instrument(level = "trace")]
async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
match self {
#[cfg(feature = "sqlx-mysql")]
Expand All @@ -132,6 +136,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection {
}
}

#[instrument(level = "trace")]
async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
match self {
#[cfg(feature = "sqlx-mysql")]
Expand All @@ -146,6 +151,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection {
}
}

#[instrument(level = "trace")]
fn stream(
&'a self,
stmt: Statement,
Expand All @@ -160,13 +166,14 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection {
DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.stream(stmt).await?,
#[cfg(feature = "mock")]
DatabaseConnection::MockDatabaseConnection(conn) => {
crate::QueryStream::from((Arc::clone(conn), stmt))
crate::QueryStream::from((Arc::clone(conn), stmt, None))
}
DatabaseConnection::Disconnected => panic!("Disconnected"),
})
})
}

#[instrument(level = "trace")]
async fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
match self {
#[cfg(feature = "sqlx-mysql")]
Expand All @@ -177,14 +184,15 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection {
DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.begin().await,
#[cfg(feature = "mock")]
DatabaseConnection::MockDatabaseConnection(conn) => {
DatabaseTransaction::new_mock(Arc::clone(conn)).await
DatabaseTransaction::new_mock(Arc::clone(conn), None).await
}
DatabaseConnection::Disconnected => panic!("Disconnected"),
}
}

/// Execute the function inside a transaction.
/// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
#[instrument(level = "trace", skip(_callback))]
async fn transaction<F, T, E>(&self, _callback: F) -> Result<T, TransactionError<E>>
where
F: for<'c> FnOnce(
Expand All @@ -205,7 +213,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection {
DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.transaction(_callback).await,
#[cfg(feature = "mock")]
DatabaseConnection::MockDatabaseConnection(conn) => {
let transaction = DatabaseTransaction::new_mock(Arc::clone(conn))
let transaction = DatabaseTransaction::new_mock(Arc::clone(conn), None)
.await
.map_err(TransactionError::Connection)?;
transaction.run(_callback).await
Expand Down Expand Up @@ -237,6 +245,24 @@ impl DatabaseConnection {
}
}

impl DatabaseConnection {
/// Sets a callback to metric this connection
pub fn set_metric_callback<F>(&mut self, callback: F)
where
F: Fn(&crate::metric::Info<'_>) + Send + Sync + 'static,
{
match self {
#[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.set_metric_callback(callback),
#[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.set_metric_callback(callback),
#[cfg(feature = "sqlx-sqlite")]
DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.set_metric_callback(callback),
_ => {},
}
}
}

impl DbBackend {
/// Check if the URI is the same as the specified database backend.
/// Returns true if they match.
Expand Down
6 changes: 6 additions & 0 deletions src/database/mock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::{
Statement,
};
use sea_query::{Value, ValueType, Values};
use tracing::instrument;
use std::{collections::BTreeMap, sync::Arc};

/// Defines a Mock database suitable for testing
Expand Down Expand Up @@ -89,6 +90,7 @@ impl MockDatabase {
}

impl MockDatabaseTrait for MockDatabase {
#[instrument(level = "trace")]
fn execute(&mut self, counter: usize, statement: Statement) -> Result<ExecResult, DbErr> {
if let Some(transaction) = &mut self.transaction {
transaction.push(statement);
Expand All @@ -104,6 +106,7 @@ impl MockDatabaseTrait for MockDatabase {
}
}

#[instrument(level = "trace")]
fn query(&mut self, counter: usize, statement: Statement) -> Result<Vec<QueryResult>, DbErr> {
if let Some(transaction) = &mut self.transaction {
transaction.push(statement);
Expand All @@ -122,6 +125,7 @@ impl MockDatabaseTrait for MockDatabase {
}
}

#[instrument(level = "trace")]
fn begin(&mut self) {
if self.transaction.is_some() {
self.transaction
Expand All @@ -133,6 +137,7 @@ impl MockDatabaseTrait for MockDatabase {
}
}

#[instrument(level = "trace")]
fn commit(&mut self) {
if self.transaction.is_some() {
if self.transaction.as_mut().unwrap().commit(self.db_backend) {
Expand All @@ -144,6 +149,7 @@ impl MockDatabaseTrait for MockDatabase {
}
}

#[instrument(level = "trace")]
fn rollback(&mut self) {
if self.transaction.is_some() {
if self.transaction.as_mut().unwrap().rollback(self.db_backend) {
Expand Down
2 changes: 2 additions & 0 deletions src/database/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ pub use db_connection::*;
pub use mock::*;
pub use statement::*;
pub use stream::*;
use tracing::instrument;
pub use transaction::*;

use crate::DbErr;
Expand Down Expand Up @@ -42,6 +43,7 @@ pub struct ConnectOptions {

impl Database {
/// Method to create a [DatabaseConnection] on a database
#[instrument(level = "trace", skip(opt))]
pub async fn connect<C>(opt: C) -> Result<DatabaseConnection, DbErr>
where
C: Into<ConnectOptions>,
Expand Down
101 changes: 57 additions & 44 deletions src/database/stream/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,43 +12,46 @@ use futures::TryStreamExt;
#[cfg(feature = "sqlx-dep")]
use sqlx::{pool::PoolConnection, Executor};

use tracing::instrument;

use crate::{DbErr, InnerConnection, QueryResult, Statement};

/// Creates a stream from a [QueryResult]
#[ouroboros::self_referencing]
pub struct QueryStream {
stmt: Statement,
conn: InnerConnection,
#[borrows(mut conn, stmt)]
metric_callback: Option<crate::metric::Callback>,
#[borrows(mut conn, stmt, metric_callback)]
#[not_covariant]
stream: Pin<Box<dyn Stream<Item = Result<QueryResult, DbErr>> + 'this>>,
}

#[cfg(feature = "sqlx-mysql")]
impl From<(PoolConnection<sqlx::MySql>, Statement)> for QueryStream {
fn from((conn, stmt): (PoolConnection<sqlx::MySql>, Statement)) -> Self {
QueryStream::build(stmt, InnerConnection::MySql(conn))
impl From<(PoolConnection<sqlx::MySql>, Statement, Option<crate::metric::Callback>)> for QueryStream {
fn from((conn, stmt, metric_callback): (PoolConnection<sqlx::MySql>, Statement, Option<crate::metric::Callback>)) -> Self {
QueryStream::build(stmt, InnerConnection::MySql(conn), metric_callback)
}
}

#[cfg(feature = "sqlx-postgres")]
impl From<(PoolConnection<sqlx::Postgres>, Statement)> for QueryStream {
fn from((conn, stmt): (PoolConnection<sqlx::Postgres>, Statement)) -> Self {
QueryStream::build(stmt, InnerConnection::Postgres(conn))
impl From<(PoolConnection<sqlx::Postgres>, Statement, Option<crate::metric::Callback>)> for QueryStream {
fn from((conn, stmt, metric_callback): (PoolConnection<sqlx::Postgres>, Statement, Option<crate::metric::Callback>)) -> Self {
QueryStream::build(stmt, InnerConnection::Postgres(conn), metric_callback)
}
}

#[cfg(feature = "sqlx-sqlite")]
impl From<(PoolConnection<sqlx::Sqlite>, Statement)> for QueryStream {
fn from((conn, stmt): (PoolConnection<sqlx::Sqlite>, Statement)) -> Self {
QueryStream::build(stmt, InnerConnection::Sqlite(conn))
impl From<(PoolConnection<sqlx::Sqlite>, Statement, Option<crate::metric::Callback>)> for QueryStream {
fn from((conn, stmt, metric_callback): (PoolConnection<sqlx::Sqlite>, Statement, Option<crate::metric::Callback>)) -> Self {
QueryStream::build(stmt, InnerConnection::Sqlite(conn), metric_callback)
}
}

#[cfg(feature = "mock")]
impl From<(Arc<crate::MockDatabaseConnection>, Statement)> for QueryStream {
fn from((conn, stmt): (Arc<crate::MockDatabaseConnection>, Statement)) -> Self {
QueryStream::build(stmt, InnerConnection::Mock(conn))
impl From<(Arc<crate::MockDatabaseConnection>, Statement, Option<crate::metric::Callback>)> for QueryStream {
fn from((conn, stmt, metric_callback): (Arc<crate::MockDatabaseConnection>, Statement, Option<crate::metric::Callback>)) -> Self {
QueryStream::build(stmt, InnerConnection::Mock(conn), metric_callback)
}
}

Expand All @@ -59,41 +62,51 @@ impl std::fmt::Debug for QueryStream {
}

impl QueryStream {
fn build(stmt: Statement, conn: InnerConnection) -> QueryStream {
#[instrument(level = "trace", skip(metric_callback))]
fn build(stmt: Statement, conn: InnerConnection, metric_callback: Option<crate::metric::Callback>) -> QueryStream {
QueryStreamBuilder {
stmt,
conn,
stream_builder: |conn, stmt| match conn {
#[cfg(feature = "sqlx-mysql")]
InnerConnection::MySql(c) => {
let query = crate::driver::sqlx_mysql::sqlx_query(stmt);
Box::pin(
c.fetch(query)
.map_ok(Into::into)
.map_err(crate::sqlx_error_to_query_err),
)
}
#[cfg(feature = "sqlx-postgres")]
InnerConnection::Postgres(c) => {
let query = crate::driver::sqlx_postgres::sqlx_query(stmt);
Box::pin(
c.fetch(query)
.map_ok(Into::into)
.map_err(crate::sqlx_error_to_query_err),
)
}
#[cfg(feature = "sqlx-sqlite")]
InnerConnection::Sqlite(c) => {
let query = crate::driver::sqlx_sqlite::sqlx_query(stmt);
Box::pin(
c.fetch(query)
.map_ok(Into::into)
.map_err(crate::sqlx_error_to_query_err),
)
metric_callback,
stream_builder: |conn, stmt, metric_callback| {
match conn {
#[cfg(feature = "sqlx-mysql")]
InnerConnection::MySql(c) => {
let query = crate::driver::sqlx_mysql::sqlx_query(stmt);
crate::metric::metric_ok!(metric_callback, stmt, {
Box::pin(
c.fetch(query)
.map_ok(Into::into)
.map_err(crate::sqlx_error_to_query_err),
)
})
}
#[cfg(feature = "sqlx-postgres")]
InnerConnection::Postgres(c) => {
let query = crate::driver::sqlx_postgres::sqlx_query(stmt);
crate::metric::metric_ok!(metric_callback, stmt, {
Box::pin(
c.fetch(query)
.map_ok(Into::into)
.map_err(crate::sqlx_error_to_query_err),
)
})
}
#[cfg(feature = "sqlx-sqlite")]
InnerConnection::Sqlite(c) => {
let query = crate::driver::sqlx_sqlite::sqlx_query(stmt);
crate::metric::metric_ok!(metric_callback, stmt, {
Box::pin(
c.fetch(query)
.map_ok(Into::into)
.map_err(crate::sqlx_error_to_query_err),
)
})
}
#[cfg(feature = "mock")]
InnerConnection::Mock(c) => c.fetch(stmt),
}
#[cfg(feature = "mock")]
InnerConnection::Mock(c) => c.fetch(stmt),
},
}
}
.build()
}
Expand Down
Loading

0 comments on commit 7da5b6b

Please sign in to comment.