From ed7ab5ef0513ba303d33efd41d3e9e381169d59b Mon Sep 17 00:00:00 2001 From: Sean Klein Date: Tue, 5 Dec 2023 16:10:43 -0800 Subject: [PATCH] Add function for automatic transaction retry (#58) Attempts to implement https://www.cockroachlabs.com/docs/v23.1/advanced-client-side-transaction-retries This functionality includes some cockroach-specific calls, so it exists behind a new `cockroach` feature flag. --- .github/workflows/rust.yml | 8 ++ Cargo.toml | 5 + src/async_traits.rs | 218 +++++++++++++++++++++++++++++----- src/connection.rs | 16 +-- tests/test.rs | 231 ++++++++++++++++++++++++++++++++++++- 5 files changed, 444 insertions(+), 34 deletions(-) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index c8b3c70..e4f6e9f 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -23,6 +23,14 @@ jobs: - name: Check style run: cargo fmt -- --check + check-without-cockroach: + runs-on: ubuntu-latest + steps: + # actions/checkout@v2 + - uses: actions/checkout@72f2cec99f417b1a1c5e2e88945068983b7965f9 + - name: Cargo check + run: cargo check --no-default-features + build-and-test: runs-on: ${{ matrix.os }} strategy: diff --git a/Cargo.toml b/Cargo.toml index 7102210..77eb826 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,6 +8,11 @@ license = "MIT" repository = "https://github.com/oxidecomputer/async-bb8-diesel" keywords = ["diesel", "r2d2", "pool", "tokio", "async"] +[features] +# Enables CockroachDB-specific functions. +cockroach = [] +default = [ "cockroach" ] + [dependencies] bb8 = "0.8" async-trait = "0.1.73" diff --git a/src/async_traits.rs b/src/async_traits.rs index f8ddfd8..d9543c7 100644 --- a/src/async_traits.rs +++ b/src/async_traits.rs @@ -3,7 +3,10 @@ use crate::connection::Connection as SingleConnection; use async_trait::async_trait; use diesel::{ - connection::{Connection as DieselConnection, SimpleConnection, TransactionManager}, + connection::{ + Connection as DieselConnection, SimpleConnection, TransactionManager, + TransactionManagerStatus, + }, dsl::Limit, query_dsl::{ methods::{ExecuteDsl, LimitDsl, LoadQuery}, @@ -25,21 +28,28 @@ where async fn batch_execute_async(&self, query: &str) -> Result<(), DieselError>; } +#[cfg(feature = "cockroach")] +fn retryable_error(err: &DieselError) -> bool { + use diesel::result::DatabaseErrorKind::SerializationFailure; + match err { + DieselError::DatabaseError(SerializationFailure, _boxed_error_information) => true, + _ => false, + } +} + /// An async variant of [`diesel::connection::Connection`]. #[async_trait] pub trait AsyncConnection: AsyncSimpleConnection where Conn: 'static + DieselConnection, - Self: Send, + Self: Send + Sized + 'static, { - type OwnedConnection: Sync + Send + 'static; - #[doc(hidden)] - async fn get_owned_connection(&self) -> Self::OwnedConnection; + fn get_owned_connection(&self) -> Self; #[doc(hidden)] - fn as_sync_conn(owned: &Self::OwnedConnection) -> MutexGuard<'_, Conn>; + fn as_sync_conn(&self) -> MutexGuard<'_, Conn>; #[doc(hidden)] - fn as_async_conn(owned: &Self::OwnedConnection) -> &SingleConnection; + fn as_async_conn(&self) -> &SingleConnection; /// Runs the function `f` in an context where blocking is safe. async fn run(&self, f: Func) -> Result @@ -48,40 +58,195 @@ where E: Send + 'static, Func: FnOnce(&mut Conn) -> Result + Send + 'static, { - let connection = self.get_owned_connection().await; - Self::run_with_connection(connection, f).await + let connection = self.get_owned_connection(); + connection.run_with_connection(f).await } #[doc(hidden)] - async fn run_with_connection( - connection: Self::OwnedConnection, - f: Func, - ) -> Result + async fn run_with_connection(self, f: Func) -> Result where R: Send + 'static, E: Send + 'static, Func: FnOnce(&mut Conn) -> Result + Send + 'static, { - spawn_blocking(move || f(&mut *Self::as_sync_conn(&connection))) + spawn_blocking(move || f(&mut *self.as_sync_conn())) .await .unwrap() // Propagate panics } #[doc(hidden)] - async fn run_with_shared_connection( - connection: Arc, - f: Func, - ) -> Result + async fn run_with_shared_connection(self: &Arc, f: Func) -> Result where R: Send + 'static, E: Send + 'static, Func: FnOnce(&mut Conn) -> Result + Send + 'static, { - spawn_blocking(move || f(&mut *Self::as_sync_conn(&connection))) + let conn = self.clone(); + spawn_blocking(move || f(&mut *conn.as_sync_conn())) .await .unwrap() // Propagate panics } + #[doc(hidden)] + async fn transaction_depth(&self) -> Result { + let conn = self.get_owned_connection(); + + Self::run_with_connection(conn, |conn| { + match Conn::TransactionManager::transaction_manager_status_mut(&mut *conn) { + TransactionManagerStatus::Valid(status) => { + Ok(status.transaction_depth().map(|d| d.into()).unwrap_or(0)) + } + TransactionManagerStatus::InError => Err(DieselError::BrokenTransactionManager), + } + }) + .await + } + + // Diesel's "begin_transaction" chooses whether to issue "BEGIN" or a + // "SAVEPOINT" depending on the transaction depth. + // + // This method is a wrapper around that call, with validation that + // we're actually issuing the BEGIN statement here. + #[doc(hidden)] + async fn start_transaction(self: &Arc) -> Result<(), DieselError> { + if self.transaction_depth().await? != 0 { + return Err(DieselError::AlreadyInTransaction); + } + self.run_with_shared_connection(|conn| Conn::TransactionManager::begin_transaction(conn)) + .await?; + Ok(()) + } + + // Diesel's "begin_transaction" chooses whether to issue "BEGIN" or a + // "SAVEPOINT" depending on the transaction depth. + // + // This method is a wrapper around that call, with validation that + // we're actually issuing our first SAVEPOINT here. + #[doc(hidden)] + async fn add_retry_savepoint(self: &Arc) -> Result<(), DieselError> { + match self.transaction_depth().await? { + 0 => return Err(DieselError::NotInTransaction), + 1 => (), + _ => return Err(DieselError::AlreadyInTransaction), + }; + + self.run_with_shared_connection(|conn| Conn::TransactionManager::begin_transaction(conn)) + .await?; + Ok(()) + } + + #[doc(hidden)] + async fn commit_transaction(self: &Arc) -> Result<(), DieselError> { + self.run_with_shared_connection(|conn| Conn::TransactionManager::commit_transaction(conn)) + .await?; + Ok(()) + } + + #[doc(hidden)] + async fn rollback_transaction(self: &Arc) -> Result<(), DieselError> { + self.run_with_shared_connection(|conn| { + Conn::TransactionManager::rollback_transaction(conn) + }) + .await?; + Ok(()) + } + + /// Issues a function `f` as a transaction. + /// + /// If it fails, asynchronously calls `retry` to decide if to retry. + /// + /// This function throws an error if it is called from within an existing + /// transaction. + #[cfg(feature = "cockroach")] + async fn transaction_async_with_retry( + &'a self, + f: Func, + retry: RetryFunc, + ) -> Result + where + R: Send + 'static, + Fut: Future> + Send, + Func: Fn(SingleConnection) -> Fut + Send + Sync, + RetryFut: Future + Send, + RetryFunc: Fn() -> RetryFut + Send + Sync, + { + // Check out a connection once, and use it for the duration of the + // operation. + let conn = Arc::new(self.get_owned_connection()); + + // Refer to CockroachDB's guide on advanced client-side transaction + // retries for the full context: + // https://www.cockroachlabs.com/docs/v23.1/advanced-client-side-transaction-retries + // + // In short, they expect a particular name for this savepoint, but + // Diesel has Opinions on savepoint names, so we use this session + // variable to identify that any name is valid. + // + // TODO: It may be preferable to set this once per connection -- but + // that'll require more interaction with how sessions with the database + // are constructed. + Self::start_transaction(&conn).await?; + conn.run_with_shared_connection(|conn| { + conn.batch_execute("SET LOCAL force_savepoint_restart = true") + }) + .await?; + + loop { + // Add a SAVEPOINT to which we can later return. + Self::add_retry_savepoint(&conn).await?; + + let async_conn = SingleConnection(Self::as_async_conn(&conn).0.clone()); + match f(async_conn).await { + Ok(value) => { + // The user-level operation succeeded: try to commit the + // transaction by RELEASE-ing the retry savepoint. + if let Err(err) = Self::commit_transaction(&conn).await { + // Diesel's implementation of "commit_transaction" + // calls "rollback_transaction" in the error path. + // + // We're still in the transaction, but we at least + // tried to ROLLBACK to our savepoint. + if !retryable_error(&err) || !retry().await { + // Bail: ROLLBACK the initial BEGIN statement too. + let _ = Self::rollback_transaction(&conn).await; + return Err(err); + } + // ROLLBACK happened, we want to retry. + continue; + } + + // Commit the top-level transaction too. + Self::commit_transaction(&conn).await?; + return Ok(value); + } + Err(user_error) => { + // The user-level operation failed: ROLLBACK to the retry + // savepoint. + if let Err(first_rollback_err) = Self::rollback_transaction(&conn).await { + // If we fail while rolling back, prioritize returning + // the ROLLBACK error over the user errors. + return match Self::rollback_transaction(&conn).await { + Ok(()) => Err(first_rollback_err), + Err(second_rollback_err) => Err(second_rollback_err), + }; + } + + // We rolled back to the retry savepoint, and now want to + // retry. + if retryable_error(&user_error) && retry().await { + continue; + } + + // If we aren't retrying, ROLLBACK the BEGIN statement too. + return match Self::rollback_transaction(&conn).await { + Ok(()) => Err(user_error), + Err(err) => Err(err), + }; + } + } + } + } + async fn transaction_async(&'a self, f: Func) -> Result where R: Send + 'static, @@ -91,14 +256,14 @@ where { // Check out a connection once, and use it for the duration of the // operation. - let conn = Arc::new(self.get_owned_connection().await); + let conn = Arc::new(self.get_owned_connection()); // This function mimics the implementation of: // https://docs.diesel.rs/master/diesel/connection/trait.TransactionManager.html#method.transaction // // However, it modifies all callsites to instead issue // known-to-be-synchronous operations from an asynchronous context. - Self::run_with_shared_connection(conn.clone(), |conn| { + conn.run_with_shared_connection(|conn| { Conn::TransactionManager::begin_transaction(conn).map_err(E::from) }) .await?; @@ -118,17 +283,18 @@ where let async_conn = SingleConnection(Self::as_async_conn(&conn).0.clone()); match f(async_conn).await { Ok(value) => { - Self::run_with_shared_connection(conn.clone(), |conn| { + conn.run_with_shared_connection(|conn| { Conn::TransactionManager::commit_transaction(conn).map_err(E::from) }) .await?; Ok(value) } Err(user_error) => { - match Self::run_with_shared_connection(conn.clone(), |conn| { - Conn::TransactionManager::rollback_transaction(conn).map_err(E::from) - }) - .await + match conn + .run_with_shared_connection(|conn| { + Conn::TransactionManager::rollback_transaction(conn).map_err(E::from) + }) + .await { Ok(()) => Err(user_error), Err(err) => Err(err), diff --git a/src/connection.rs b/src/connection.rs index ff65b4e..1f00b10 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -50,17 +50,19 @@ where Conn: 'static + R2D2Connection, Connection: crate::AsyncSimpleConnection, { - type OwnedConnection = Connection; - - async fn get_owned_connection(&self) -> Self::OwnedConnection { + fn get_owned_connection(&self) -> Self { Connection(self.0.clone()) } - fn as_sync_conn(owned: &Self::OwnedConnection) -> MutexGuard<'_, Conn> { - owned.inner() + // Accesses the connection synchronously, protected by a mutex. + // + // Avoid calling from asynchronous contexts. + fn as_sync_conn(&self) -> MutexGuard<'_, Conn> { + self.inner() } - fn as_async_conn(owned: &Self::OwnedConnection) -> &Connection { - owned + // TODO: Consider removing me. + fn as_async_conn(&self) -> &Connection { + self } } diff --git a/tests/test.rs b/tests/test.rs index 3cba9f9..7aad3af 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -2,7 +2,9 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at https://mozilla.org/MPL/2.0/. -use async_bb8_diesel::{AsyncConnection, AsyncRunQueryDsl, AsyncSaveChangesDsl, ConnectionError}; +use async_bb8_diesel::{ + AsyncConnection, AsyncRunQueryDsl, AsyncSaveChangesDsl, AsyncSimpleConnection, ConnectionError, +}; use diesel::OptionalExtension; use diesel::{pg::PgConnection, prelude::*}; @@ -151,6 +153,233 @@ async fn test_transaction() { test_end(crdb).await; } +#[tokio::test] +async fn test_transaction_automatic_retry_success_case() { + let crdb = test_start().await; + + let manager = async_bb8_diesel::ConnectionManager::::new(&crdb.pg_config().url); + let pool = bb8::Pool::builder().build(manager).await.unwrap(); + let conn = pool.get().await.unwrap(); + + use user::dsl; + + // Transaction that can retry but does not need to. + assert_eq!(conn.transaction_depth().await.unwrap(), 0); + conn.transaction_async_with_retry( + |conn| async move { + assert!(conn.transaction_depth().await.unwrap() > 0); + diesel::insert_into(dsl::user) + .values((dsl::id.eq(3), dsl::name.eq("Sally"))) + .execute_async(&conn) + .await?; + Ok(()) + }, + || async { panic!("Should not attempt to retry this operation") }, + ) + .await + .expect("Transaction failed"); + assert_eq!(conn.transaction_depth().await.unwrap(), 0); + + test_end(crdb).await; +} + +#[tokio::test] +async fn test_transaction_automatic_retry_explicit_rollback() { + let crdb = test_start().await; + + let manager = async_bb8_diesel::ConnectionManager::::new(&crdb.pg_config().url); + let pool = bb8::Pool::builder().build(manager).await.unwrap(); + let conn = pool.get().await.unwrap(); + + use std::sync::{Arc, Mutex}; + + let transaction_attempted_count = Arc::new(Mutex::new(0)); + let should_retry_query_count = Arc::new(Mutex::new(0)); + + // Test a transaction that: + // + // 1. Retries on the first call + // 2. Explicitly rolls back on the second call + assert_eq!(conn.transaction_depth().await.unwrap(), 0); + let err = conn + .transaction_async_with_retry( + |_conn| { + let transaction_attempted_count = transaction_attempted_count.clone(); + async move { + let mut count = transaction_attempted_count.lock().unwrap(); + *count += 1; + + if *count < 2 { + eprintln!("test: Manually restarting txn"); + return Err::<(), _>(diesel::result::Error::DatabaseError( + diesel::result::DatabaseErrorKind::SerializationFailure, + Box::new("restart transaction".to_string()), + )); + } + eprintln!("test: Manually rolling back txn"); + return Err(diesel::result::Error::RollbackTransaction); + } + }, + || async { + *should_retry_query_count.lock().unwrap() += 1; + true + }, + ) + .await + .expect_err("Transaction should have failed"); + + assert_eq!(err, diesel::result::Error::RollbackTransaction); + assert_eq!(conn.transaction_depth().await.unwrap(), 0); + + // The transaction closure should have been attempted twice, but + // we should have only asked whether or not to retry once -- after + // the first failure, but not the second. + assert_eq!(*transaction_attempted_count.lock().unwrap(), 2); + assert_eq!(*should_retry_query_count.lock().unwrap(), 1); + + test_end(crdb).await; +} + +#[tokio::test] +async fn test_transaction_automatic_retry_injected_errors() { + let crdb = test_start().await; + + let manager = async_bb8_diesel::ConnectionManager::::new(&crdb.pg_config().url); + let pool = bb8::Pool::builder().build(manager).await.unwrap(); + let conn = pool.get().await.unwrap(); + + use std::sync::{Arc, Mutex}; + + let transaction_attempted_count = Arc::new(Mutex::new(0)); + let should_retry_query_count = Arc::new(Mutex::new(0)); + + // Tests a transaction that is forced to retry by CockroachDB. + // + // By setting this session variable, we expect that: + // - "any statement executed inside of an explicit transaction (with the + // exception of SET statements) will return a transaction retry error." + // - "after the 3rd retry error, the transaction will proceed as + // normal" + // + // See: https://www.cockroachlabs.com/docs/v23.1/transaction-retry-error-example#test-transaction-retry-logic + // for more details + const EXPECTED_ERR_COUNT: usize = 3; + conn.batch_execute_async("SET inject_retry_errors_enabled = true") + .await + .expect("Failed to inject error"); + assert_eq!(conn.transaction_depth().await.unwrap(), 0); + conn.transaction_async_with_retry( + |conn| { + let transaction_attempted_count = transaction_attempted_count.clone(); + async move { + *transaction_attempted_count.lock().unwrap() += 1; + + use user::dsl; + let _ = diesel::insert_into(dsl::user) + .values((dsl::id.eq(0), dsl::name.eq("Jim"))) + .execute_async(&conn) + .await?; + Ok(()) + } + }, + || async { + *should_retry_query_count.lock().unwrap() += 1; + true + }, + ) + .await + .expect("Transaction should have succeeded"); + assert_eq!(conn.transaction_depth().await.unwrap(), 0); + + // The transaction closure should have been attempted twice, but + // we should have only asked whether or not to retry once -- after + // the first failure, but not the second. + assert_eq!( + *transaction_attempted_count.lock().unwrap(), + EXPECTED_ERR_COUNT + 1 + ); + assert_eq!( + *should_retry_query_count.lock().unwrap(), + EXPECTED_ERR_COUNT + ); + + test_end(crdb).await; +} + +#[tokio::test] +async fn test_transaction_automatic_retry_does_not_retry_non_retryable_errors() { + let crdb = test_start().await; + + let manager = async_bb8_diesel::ConnectionManager::::new(&crdb.pg_config().url); + let pool = bb8::Pool::builder().build(manager).await.unwrap(); + let conn = pool.get().await.unwrap(); + + // Test a transaction that: + // + // Fails with a non-retryable error. It should exit immediately. + assert_eq!(conn.transaction_depth().await.unwrap(), 0); + assert_eq!( + conn.transaction_async_with_retry( + |_| async { Err::<(), _>(diesel::result::Error::NotFound) }, + || async { panic!("Should not attempt to retry this operation") } + ) + .await + .expect_err("Transaction should have failed"), + diesel::result::Error::NotFound, + ); + assert_eq!(conn.transaction_depth().await.unwrap(), 0); + + test_end(crdb).await; +} + +#[tokio::test] +async fn test_transaction_automatic_retry_nested_transactions_fail() { + let crdb = test_start().await; + + let manager = async_bb8_diesel::ConnectionManager::::new(&crdb.pg_config().url); + let pool = bb8::Pool::builder().build(manager).await.unwrap(); + let conn = pool.get().await.unwrap(); + + #[derive(Debug, PartialEq)] + struct OnlyReturnFromOuterTransaction {} + + // This outer transaction should succeed immediately... + assert_eq!(conn.transaction_depth().await.unwrap(), 0); + assert_eq!( + OnlyReturnFromOuterTransaction {}, + conn.transaction_async_with_retry( + |conn| async move { + // ... but this inner transaction should fail! We do not support + // retryable nested transactions. + let err = conn + .transaction_async_with_retry( + |_| async { + panic!("Shouldn't run"); + + // Adding this unreachable statement for type inference + #[allow(unreachable_code)] + Ok(()) + }, + || async { panic!("Shouldn't retry inner transaction") }, + ) + .await + .expect_err("Nested transaction should have failed"); + assert_eq!(err, diesel::result::Error::AlreadyInTransaction); + + // We still want to show that control exists within the outer + // transaction, so we explicitly return here. + Ok(OnlyReturnFromOuterTransaction {}) + }, + || async { panic!("Shouldn't retry outer transaction") }, + ) + .await + .expect("Transaction should have succeeded") + ); + assert_eq!(conn.transaction_depth().await.unwrap(), 0); + + test_end(crdb).await; +} + #[tokio::test] async fn test_transaction_custom_error() { let crdb = test_start().await;