From 66c23c85dbc998471b9a9ac6be5cea8c36a289f9 Mon Sep 17 00:00:00 2001 From: Billy Chan Date: Wed, 10 Nov 2021 14:40:44 +0800 Subject: [PATCH] Revert MySQL & SQLite returning support --- .github/workflows/rust.yml | 2 ++ Cargo.toml | 1 - src/database/connection.rs | 10 +++--- src/database/db_connection.rs | 60 +++---------------------------- src/database/transaction.rs | 33 ++--------------- src/driver/sqlx_mysql.rs | 67 +++++------------------------------ src/driver/sqlx_sqlite.rs | 60 +++++-------------------------- src/executor/insert.rs | 6 ++-- src/executor/update.rs | 2 +- tests/returning_tests.rs | 15 ++++---- 10 files changed, 43 insertions(+), 213 deletions(-) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index c4d112352..2926e2ab8 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -288,6 +288,7 @@ jobs: name: Examples runs-on: ${{ matrix.os }} strategy: + fail-fast: false matrix: os: [ubuntu-latest] path: [basic, actix_example, actix4_example, axum_example, rocket_example] @@ -312,6 +313,7 @@ jobs: if: ${{ (needs.init.outputs.run-partial == 'true' && needs.init.outputs.run-issues == 'true') }} runs-on: ${{ matrix.os }} strategy: + fail-fast: false matrix: os: [ubuntu-latest] path: [86, 249, 262] diff --git a/Cargo.toml b/Cargo.toml index 3f0e0ea31..d6d41c35e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,7 +38,6 @@ sqlx = { version = "^0.5", optional = true } uuid = { version = "0.8", features = ["serde", "v4"], optional = true } ouroboros = "0.11" url = "^2.2" -regex = "^1" [dev-dependencies] smol = { version = "^1.2" } diff --git a/src/database/connection.rs b/src/database/connection.rs index 2a16156e9..c5730a4bd 100644 --- a/src/database/connection.rs +++ b/src/database/connection.rs @@ -45,11 +45,11 @@ pub trait ConnectionTrait<'a>: Sync { T: Send, E: std::error::Error + Send; - /// Check if the connection supports `RETURNING` syntax on insert - fn returning_on_insert(&self) -> bool; - - /// Check if the connection supports `RETURNING` syntax on update - fn returning_on_update(&self) -> bool; + /// Check if the connection supports `RETURNING` syntax on insert and update + fn support_returning(&self) -> bool { + let db_backend = self.get_database_backend(); + db_backend.support_returning() + } /// Check if the connection is a test connection for the Mock database fn is_mock_connection(&self) -> bool { diff --git a/src/database/db_connection.rs b/src/database/db_connection.rs index ab9d734a1..99de8633f 100644 --- a/src/database/db_connection.rs +++ b/src/database/db_connection.rs @@ -214,61 +214,6 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection { } } - fn returning_on_insert(&self) -> bool { - match self { - #[cfg(feature = "sqlx-mysql")] - DatabaseConnection::SqlxMySqlPoolConnection(conn) => { - // Supported if it's MariaDB on or after version 10.5.0 - // Not supported in all MySQL versions - conn.support_returning - } - #[cfg(feature = "sqlx-postgres")] - DatabaseConnection::SqlxPostgresPoolConnection(_) => { - // Supported by all Postgres versions - true - } - #[cfg(feature = "sqlx-sqlite")] - DatabaseConnection::SqlxSqlitePoolConnection(conn) => { - // Supported by SQLite on or after version 3.35.0 (2021-03-12) - conn.support_returning - } - #[cfg(feature = "mock")] - DatabaseConnection::MockDatabaseConnection(conn) => match conn.get_database_backend() { - DbBackend::MySql => false, - DbBackend::Postgres => true, - DbBackend::Sqlite => false, - }, - DatabaseConnection::Disconnected => panic!("Disconnected"), - } - } - - fn returning_on_update(&self) -> bool { - match self { - #[cfg(feature = "sqlx-mysql")] - DatabaseConnection::SqlxMySqlPoolConnection(_) => { - // Not supported in all MySQL & MariaDB versions - false - } - #[cfg(feature = "sqlx-postgres")] - DatabaseConnection::SqlxPostgresPoolConnection(_) => { - // Supported by all Postgres versions - true - } - #[cfg(feature = "sqlx-sqlite")] - DatabaseConnection::SqlxSqlitePoolConnection(conn) => { - // Supported by SQLite on or after version 3.35.0 (2021-03-12) - conn.support_returning - } - #[cfg(feature = "mock")] - DatabaseConnection::MockDatabaseConnection(conn) => match conn.get_database_backend() { - DbBackend::MySql => false, - DbBackend::Postgres => true, - DbBackend::Sqlite => false, - }, - DatabaseConnection::Disconnected => panic!("Disconnected"), - } - } - #[cfg(feature = "mock")] fn is_mock_connection(&self) -> bool { matches!(self, DatabaseConnection::MockDatabaseConnection(_)) @@ -322,6 +267,11 @@ impl DbBackend { Self::Sqlite => Box::new(SqliteQueryBuilder), } } + + /// Check if the database supports `RETURNING` syntax on insert and update + pub fn support_returning(&self) -> bool { + matches!(self, Self::Postgres) + } } #[cfg(test)] diff --git a/src/database/transaction.rs b/src/database/transaction.rs index cfce9d58a..f4a1b6787 100644 --- a/src/database/transaction.rs +++ b/src/database/transaction.rs @@ -16,7 +16,6 @@ pub struct DatabaseTransaction { conn: Arc>, backend: DbBackend, open: bool, - support_returning: bool, } impl std::fmt::Debug for DatabaseTransaction { @@ -29,12 +28,10 @@ impl DatabaseTransaction { #[cfg(feature = "sqlx-mysql")] pub(crate) async fn new_mysql( inner: PoolConnection, - support_returning: bool, ) -> Result { Self::begin( Arc::new(Mutex::new(InnerConnection::MySql(inner))), DbBackend::MySql, - support_returning, ) .await } @@ -46,7 +43,6 @@ impl DatabaseTransaction { Self::begin( Arc::new(Mutex::new(InnerConnection::Postgres(inner))), DbBackend::Postgres, - true, ) .await } @@ -54,12 +50,10 @@ impl DatabaseTransaction { #[cfg(feature = "sqlx-sqlite")] pub(crate) async fn new_sqlite( inner: PoolConnection, - support_returning: bool, ) -> Result { Self::begin( Arc::new(Mutex::new(InnerConnection::Sqlite(inner))), DbBackend::Sqlite, - support_returning, ) .await } @@ -69,28 +63,17 @@ impl DatabaseTransaction { inner: Arc, ) -> Result { let backend = inner.get_database_backend(); - Self::begin( - Arc::new(Mutex::new(InnerConnection::Mock(inner))), - backend, - match backend { - DbBackend::MySql => false, - DbBackend::Postgres => true, - DbBackend::Sqlite => false, - }, - ) - .await + Self::begin(Arc::new(Mutex::new(InnerConnection::Mock(inner))), backend).await } async fn begin( conn: Arc>, backend: DbBackend, - support_returning: bool, ) -> Result { let res = DatabaseTransaction { conn, backend, open: true, - support_returning, }; match *res.conn.lock().await { #[cfg(feature = "sqlx-mysql")] @@ -347,8 +330,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseTransaction { } async fn begin(&self) -> Result { - DatabaseTransaction::begin(Arc::clone(&self.conn), self.backend, self.support_returning) - .await + DatabaseTransaction::begin(Arc::clone(&self.conn), self.backend).await } /// Execute the function inside a transaction. @@ -365,17 +347,6 @@ impl<'a> ConnectionTrait<'a> for DatabaseTransaction { let transaction = self.begin().await.map_err(TransactionError::Connection)?; transaction.run(_callback).await } - - fn returning_on_insert(&self) -> bool { - self.support_returning - } - - fn returning_on_update(&self) -> bool { - match self.backend { - DbBackend::MySql => false, - _ => self.support_returning, - } - } } /// Defines errors for handling transaction failures diff --git a/src/driver/sqlx_mysql.rs b/src/driver/sqlx_mysql.rs index e29ee9f80..b2b89c680 100644 --- a/src/driver/sqlx_mysql.rs +++ b/src/driver/sqlx_mysql.rs @@ -1,9 +1,8 @@ -use regex::Regex; use std::{future::Future, pin::Pin}; use sqlx::{ mysql::{MySqlArguments, MySqlConnectOptions, MySqlQueryResult, MySqlRow}, - MySql, MySqlPool, Row, + MySql, MySqlPool, }; sea_query::sea_query_driver_mysql!(); @@ -11,7 +10,7 @@ use sea_query_driver_mysql::bind_query; use crate::{ debug_print, error::*, executor::*, ConnectOptions, DatabaseConnection, DatabaseTransaction, - DbBackend, QueryStream, Statement, TransactionError, + QueryStream, Statement, TransactionError, }; use super::sqlx_common::*; @@ -24,7 +23,6 @@ pub struct SqlxMySqlConnector; #[derive(Debug, Clone)] pub struct SqlxMySqlPoolConnection { pool: MySqlPool, - pub(crate) support_returning: bool, } impl SqlxMySqlConnector { @@ -44,7 +42,9 @@ impl SqlxMySqlConnector { opt.disable_statement_logging(); } if let Ok(pool) = options.pool_options().connect_with(opt).await { - into_db_connection(pool).await + Ok(DatabaseConnection::SqlxMySqlPoolConnection( + SqlxMySqlPoolConnection { pool }, + )) } else { Err(DbErr::Conn("Failed to connect.".to_owned())) } @@ -53,8 +53,8 @@ impl SqlxMySqlConnector { impl SqlxMySqlConnector { /// Instantiate a sqlx pool connection to a [DatabaseConnection] - pub async fn from_sqlx_mysql_pool(pool: MySqlPool) -> Result { - into_db_connection(pool).await + pub fn from_sqlx_mysql_pool(pool: MySqlPool) -> DatabaseConnection { + DatabaseConnection::SqlxMySqlPoolConnection(SqlxMySqlPoolConnection { pool }) } } @@ -129,7 +129,7 @@ impl SqlxMySqlPoolConnection { /// Bundle a set of SQL statements that execute together. pub async fn begin(&self) -> Result { if let Ok(conn) = self.pool.acquire().await { - DatabaseTransaction::new_mysql(conn, self.support_returning).await + DatabaseTransaction::new_mysql(conn).await } else { Err(DbErr::Query( "Failed to acquire connection from pool.".to_owned(), @@ -148,7 +148,7 @@ impl SqlxMySqlPoolConnection { E: std::error::Error + Send, { if let Ok(conn) = self.pool.acquire().await { - let transaction = DatabaseTransaction::new_mysql(conn, self.support_returning) + let transaction = DatabaseTransaction::new_mysql(conn) .await .map_err(|e| TransactionError::Connection(e))?; transaction.run(callback).await @@ -183,52 +183,3 @@ pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, MySql, MySq } query } - -async fn into_db_connection(pool: MySqlPool) -> Result { - let support_returning = parse_support_returning(&pool).await?; - Ok(DatabaseConnection::SqlxMySqlPoolConnection( - SqlxMySqlPoolConnection { - pool, - support_returning, - }, - )) -} - -async fn parse_support_returning(pool: &MySqlPool) -> Result { - let stmt = Statement::from_string( - DbBackend::MySql, - r#"SHOW VARIABLES LIKE "version""#.to_owned(), - ); - let query = sqlx_query(&stmt); - let row = query - .fetch_one(pool) - .await - .map_err(sqlx_error_to_query_err)?; - let version: String = row.try_get("Value").map_err(sqlx_error_to_query_err)?; - let support_returning = if !version.contains("MariaDB") { - // This is MySQL - // Not supported in all MySQL versions - false - } else { - // This is MariaDB - let regex = Regex::new(r"^(\d+)?.(\d+)?.(\*|\d+)").unwrap(); - let captures = regex.captures(&version).unwrap(); - macro_rules! parse_captures { - ( $idx: expr ) => { - captures.get($idx).map_or(0, |m| { - m.as_str() - .parse::() - .map_err(|e| DbErr::Conn(e.to_string())) - .unwrap() - }) - }; - } - let ver_major = parse_captures!(1); - let ver_minor = parse_captures!(2); - // Supported if it's MariaDB with version 10.5.0 or after - ver_major >= 10 && ver_minor >= 5 - }; - debug_print!("db_version: {}", version); - debug_print!("db_support_returning: {}", support_returning); - Ok(support_returning) -} diff --git a/src/driver/sqlx_sqlite.rs b/src/driver/sqlx_sqlite.rs index 4ea160e85..69eee5752 100644 --- a/src/driver/sqlx_sqlite.rs +++ b/src/driver/sqlx_sqlite.rs @@ -1,9 +1,8 @@ -use regex::Regex; use std::{future::Future, pin::Pin}; use sqlx::{ sqlite::{SqliteArguments, SqliteConnectOptions, SqliteQueryResult, SqliteRow}, - Row, Sqlite, SqlitePool, + Sqlite, SqlitePool, }; sea_query::sea_query_driver_sqlite!(); @@ -11,7 +10,7 @@ use sea_query_driver_sqlite::bind_query; use crate::{ debug_print, error::*, executor::*, ConnectOptions, DatabaseConnection, DatabaseTransaction, - DbBackend, QueryStream, Statement, TransactionError, + QueryStream, Statement, TransactionError, }; use super::sqlx_common::*; @@ -24,7 +23,6 @@ pub struct SqlxSqliteConnector; #[derive(Debug, Clone)] pub struct SqlxSqlitePoolConnection { pool: SqlitePool, - pub(crate) support_returning: bool, } impl SqlxSqliteConnector { @@ -48,7 +46,9 @@ impl SqlxSqliteConnector { options.max_connections(1); } if let Ok(pool) = options.pool_options().connect_with(opt).await { - into_db_connection(pool).await + Ok(DatabaseConnection::SqlxSqlitePoolConnection( + SqlxSqlitePoolConnection { pool }, + )) } else { Err(DbErr::Conn("Failed to connect.".to_owned())) } @@ -57,8 +57,8 @@ impl SqlxSqliteConnector { impl SqlxSqliteConnector { /// Instantiate a sqlx pool connection to a [DatabaseConnection] - pub async fn from_sqlx_sqlite_pool(pool: SqlitePool) -> Result { - into_db_connection(pool).await + pub fn from_sqlx_sqlite_pool(pool: SqlitePool) -> DatabaseConnection { + DatabaseConnection::SqlxSqlitePoolConnection(SqlxSqlitePoolConnection { pool }) } } @@ -133,7 +133,7 @@ impl SqlxSqlitePoolConnection { /// Bundle a set of SQL statements that execute together. pub async fn begin(&self) -> Result { if let Ok(conn) = self.pool.acquire().await { - DatabaseTransaction::new_sqlite(conn, self.support_returning).await + DatabaseTransaction::new_sqlite(conn).await } else { Err(DbErr::Query( "Failed to acquire connection from pool.".to_owned(), @@ -152,7 +152,7 @@ impl SqlxSqlitePoolConnection { E: std::error::Error + Send, { if let Ok(conn) = self.pool.acquire().await { - let transaction = DatabaseTransaction::new_sqlite(conn, self.support_returning) + let transaction = DatabaseTransaction::new_sqlite(conn) .await .map_err(|e| TransactionError::Connection(e))?; transaction.run(callback).await @@ -187,45 +187,3 @@ pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, Sqlite, Sql } query } - -async fn into_db_connection(pool: SqlitePool) -> Result { - let support_returning = parse_support_returning(&pool).await?; - Ok(DatabaseConnection::SqlxSqlitePoolConnection( - SqlxSqlitePoolConnection { - pool, - support_returning, - }, - )) -} - -async fn parse_support_returning(pool: &SqlitePool) -> Result { - let stmt = Statement::from_string( - DbBackend::Sqlite, - r#"SELECT sqlite_version() AS version"#.to_owned(), - ); - let query = sqlx_query(&stmt); - let row = query - .fetch_one(pool) - .await - .map_err(sqlx_error_to_query_err)?; - let version: String = row.try_get("version").map_err(sqlx_error_to_query_err)?; - let regex = Regex::new(r"^(\d+)?.(\d+)?.(\*|\d+)").unwrap(); - let captures = regex.captures(&version).unwrap(); - macro_rules! parse_captures { - ( $idx: expr ) => { - captures.get($idx).map_or(0, |m| { - m.as_str() - .parse::() - .map_err(|e| DbErr::Conn(e.to_string())) - .unwrap() - }) - }; - } - let ver_major = parse_captures!(1); - let ver_minor = parse_captures!(2); - // Supported if it's version 3.35.0 (2021-03-12) or after - let support_returning = ver_major >= 3 && ver_minor >= 35; - debug_print!("db_version: {}", version); - debug_print!("db_support_returning: {}", support_returning); - Ok(support_returning) -} diff --git a/src/executor/insert.rs b/src/executor/insert.rs index fde1a3ab4..a6dbcbd53 100644 --- a/src/executor/insert.rs +++ b/src/executor/insert.rs @@ -41,7 +41,7 @@ where { // so that self is dropped before entering await let mut query = self.query; - if db.returning_on_insert() && ::PrimaryKey::iter().count() > 0 { + if db.support_returning() && ::PrimaryKey::iter().count() > 0 { let mut returning = Query::select(); returning.columns( ::PrimaryKey::iter().map(|c| c.into_column_ref()), @@ -113,7 +113,7 @@ where { type PrimaryKey = <::Entity as EntityTrait>::PrimaryKey; type ValueTypeOf = as PrimaryKeyTrait>::ValueType; - let last_insert_id_opt = match db.returning_on_insert() { + let last_insert_id_opt = match db.support_returning() { true => { let cols = PrimaryKey::::iter() .map(|col| col.to_string()) @@ -147,7 +147,7 @@ where A: ActiveModelTrait, { let db_backend = db.get_database_backend(); - let found = match db.returning_on_insert() { + let found = match db.support_returning() { true => { let mut returning = Query::select(); returning.exprs(::Column::iter().map(|c| { diff --git a/src/executor/update.rs b/src/executor/update.rs index d27aa41d4..9870b10d9 100644 --- a/src/executor/update.rs +++ b/src/executor/update.rs @@ -90,7 +90,7 @@ where A: ActiveModelTrait, C: ConnectionTrait<'a>, { - match db.returning_on_update() { + match db.support_returning() { true => { let mut returning = Query::select(); returning.exprs(::Column::iter().map(|c| { diff --git a/tests/returning_tests.rs b/tests/returning_tests.rs index 55506399f..7fa0447b4 100644 --- a/tests/returning_tests.rs +++ b/tests/returning_tests.rs @@ -1,8 +1,8 @@ pub mod common; pub use common::{bakery_chain::*, setup::*, TestContext}; -use sea_orm::{entity::prelude::*, *}; -use sea_query::Query; +pub use sea_orm::{entity::prelude::*, *}; +pub use sea_query::Query; #[sea_orm_macros::test] #[cfg(any( @@ -37,7 +37,7 @@ async fn main() -> Result<(), DbErr> { create_tables(db).await?; - if db.returning_on_insert() { + if db.support_returning() { insert.returning(returning.clone()); let insert_res = db .query_one(builder.build(&insert)) @@ -46,11 +46,7 @@ async fn main() -> Result<(), DbErr> { let _id: i32 = insert_res.try_get("", "id")?; let _name: String = insert_res.try_get("", "name")?; let _profit_margin: f64 = insert_res.try_get("", "profit_margin")?; - } else { - let insert_res = db.execute(builder.build(&insert)).await?; - assert!(insert_res.rows_affected() > 0); - } - if db.returning_on_update() { + update.returning(returning.clone()); let update_res = db .query_one(builder.build(&update)) @@ -60,6 +56,9 @@ async fn main() -> Result<(), DbErr> { let _name: String = update_res.try_get("", "name")?; let _profit_margin: f64 = update_res.try_get("", "profit_margin")?; } else { + let insert_res = db.execute(builder.build(&insert)).await?; + assert!(insert_res.rows_affected() > 0); + let update_res = db.execute(builder.build(&update)).await?; assert!(update_res.rows_affected() > 0); }