diff --git a/src/database/connection.rs b/src/database/connection.rs index 4919ecb25..8df3d7839 100644 --- a/src/database/connection.rs +++ b/src/database/connection.rs @@ -31,6 +31,9 @@ pub trait ConnectionTrait: Sync { fn is_mock_connection(&self) -> bool { false } + + /// Explicitly close the database connection + async fn close(self) -> Result<(), DbErr>; } /// Stream query results diff --git a/src/database/db_connection.rs b/src/database/db_connection.rs index 7fd51633b..affd8c999 100644 --- a/src/database/db_connection.rs +++ b/src/database/db_connection.rs @@ -162,6 +162,25 @@ impl ConnectionTrait for DatabaseConnection { fn is_mock_connection(&self) -> bool { matches!(self, DatabaseConnection::MockDatabaseConnection(_)) } + + async fn close(self) -> Result<(), DbErr> { + match self { + #[cfg(feature = "sqlx-mysql")] + DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.close().await, + #[cfg(feature = "sqlx-postgres")] + DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.close().await, + #[cfg(feature = "sqlx-sqlite")] + DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.close().await, + #[cfg(feature = "mock")] + DatabaseConnection::MockDatabaseConnection(conn) => { + // Nothing to cleanup, we just consume the `DatabaseConnection` + Ok(()) + } + DatabaseConnection::Disconnected => { + Err(DbErr::Conn(RuntimeErr::Internal("Disconnected".to_owned()))) + } + } + } } #[async_trait::async_trait] diff --git a/src/database/transaction.rs b/src/database/transaction.rs index 85fd304a6..8324b84bc 100644 --- a/src/database/transaction.rs +++ b/src/database/transaction.rs @@ -386,6 +386,10 @@ impl ConnectionTrait for DatabaseTransaction { _ => unreachable!(), } } + + async fn close(self) -> Result<(), DbErr> { + self.rollback().await + } } impl StreamTrait for DatabaseTransaction { diff --git a/src/driver/sqlx_mysql.rs b/src/driver/sqlx_mysql.rs index 73dfa59fd..fc734482f 100644 --- a/src/driver/sqlx_mysql.rs +++ b/src/driver/sqlx_mysql.rs @@ -185,6 +185,11 @@ impl SqlxMySqlPoolConnection { { self.metric_callback = Some(Arc::new(callback)); } + + /// Explicitly close the MySQL connection + pub async fn close(self) -> Result<(), DbErr> { + Ok(self.pool.close().await) + } } impl From for QueryResult { diff --git a/src/driver/sqlx_postgres.rs b/src/driver/sqlx_postgres.rs index 167b34b85..2b5816aaf 100644 --- a/src/driver/sqlx_postgres.rs +++ b/src/driver/sqlx_postgres.rs @@ -200,6 +200,11 @@ impl SqlxPostgresPoolConnection { { self.metric_callback = Some(Arc::new(callback)); } + + /// Explicitly close the Postgres connection + pub async fn close(self) -> Result<(), DbErr> { + Ok(self.pool.close().await) + } } impl From for QueryResult { diff --git a/src/driver/sqlx_sqlite.rs b/src/driver/sqlx_sqlite.rs index b09f1bd0d..13b857242 100644 --- a/src/driver/sqlx_sqlite.rs +++ b/src/driver/sqlx_sqlite.rs @@ -192,6 +192,11 @@ impl SqlxSqlitePoolConnection { { self.metric_callback = Some(Arc::new(callback)); } + + /// Explicitly close the SQLite connection + pub async fn close(self) -> Result<(), DbErr> { + Ok(self.pool.close().await) + } } impl From for QueryResult {