Skip to content

Commit

Permalink
Add function for automatic transaction retry (#58)
Browse files Browse the repository at this point in the history
Attempts to implement https://www.cockroachlabs.com/docs/v23.1/advanced-client-side-transaction-retries

This functionality includes some cockroach-specific calls, so it exists behind a new `cockroach` feature flag.
  • Loading branch information
smklein authored Dec 6, 2023
1 parent 2d15684 commit ed7ab5e
Show file tree
Hide file tree
Showing 5 changed files with 444 additions and 34 deletions.
8 changes: 8 additions & 0 deletions .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@ jobs:
- name: Check style
run: cargo fmt -- --check

check-without-cockroach:
runs-on: ubuntu-latest
steps:
# actions/checkout@v2
- uses: actions/checkout@72f2cec99f417b1a1c5e2e88945068983b7965f9
- name: Cargo check
run: cargo check --no-default-features

build-and-test:
runs-on: ${{ matrix.os }}
strategy:
Expand Down
5 changes: 5 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@ license = "MIT"
repository = "https://github.com/oxidecomputer/async-bb8-diesel"
keywords = ["diesel", "r2d2", "pool", "tokio", "async"]

[features]
# Enables CockroachDB-specific functions.
cockroach = []
default = [ "cockroach" ]

[dependencies]
bb8 = "0.8"
async-trait = "0.1.73"
Expand Down
218 changes: 192 additions & 26 deletions src/async_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
use crate::connection::Connection as SingleConnection;
use async_trait::async_trait;
use diesel::{
connection::{Connection as DieselConnection, SimpleConnection, TransactionManager},
connection::{
Connection as DieselConnection, SimpleConnection, TransactionManager,
TransactionManagerStatus,
},
dsl::Limit,
query_dsl::{
methods::{ExecuteDsl, LimitDsl, LoadQuery},
Expand All @@ -25,21 +28,28 @@ where
async fn batch_execute_async(&self, query: &str) -> Result<(), DieselError>;
}

#[cfg(feature = "cockroach")]
fn retryable_error(err: &DieselError) -> bool {
use diesel::result::DatabaseErrorKind::SerializationFailure;
match err {
DieselError::DatabaseError(SerializationFailure, _boxed_error_information) => true,
_ => false,
}
}

/// An async variant of [`diesel::connection::Connection`].
#[async_trait]
pub trait AsyncConnection<Conn>: AsyncSimpleConnection<Conn>
where
Conn: 'static + DieselConnection,
Self: Send,
Self: Send + Sized + 'static,
{
type OwnedConnection: Sync + Send + 'static;

#[doc(hidden)]
async fn get_owned_connection(&self) -> Self::OwnedConnection;
fn get_owned_connection(&self) -> Self;
#[doc(hidden)]
fn as_sync_conn(owned: &Self::OwnedConnection) -> MutexGuard<'_, Conn>;
fn as_sync_conn(&self) -> MutexGuard<'_, Conn>;
#[doc(hidden)]
fn as_async_conn(owned: &Self::OwnedConnection) -> &SingleConnection<Conn>;
fn as_async_conn(&self) -> &SingleConnection<Conn>;

/// Runs the function `f` in an context where blocking is safe.
async fn run<R, E, Func>(&self, f: Func) -> Result<R, E>
Expand All @@ -48,40 +58,195 @@ where
E: Send + 'static,
Func: FnOnce(&mut Conn) -> Result<R, E> + Send + 'static,
{
let connection = self.get_owned_connection().await;
Self::run_with_connection(connection, f).await
let connection = self.get_owned_connection();
connection.run_with_connection(f).await
}

#[doc(hidden)]
async fn run_with_connection<R, E, Func>(
connection: Self::OwnedConnection,
f: Func,
) -> Result<R, E>
async fn run_with_connection<R, E, Func>(self, f: Func) -> Result<R, E>
where
R: Send + 'static,
E: Send + 'static,
Func: FnOnce(&mut Conn) -> Result<R, E> + Send + 'static,
{
spawn_blocking(move || f(&mut *Self::as_sync_conn(&connection)))
spawn_blocking(move || f(&mut *self.as_sync_conn()))
.await
.unwrap() // Propagate panics
}

#[doc(hidden)]
async fn run_with_shared_connection<R, E, Func>(
connection: Arc<Self::OwnedConnection>,
f: Func,
) -> Result<R, E>
async fn run_with_shared_connection<R, E, Func>(self: &Arc<Self>, f: Func) -> Result<R, E>
where
R: Send + 'static,
E: Send + 'static,
Func: FnOnce(&mut Conn) -> Result<R, E> + Send + 'static,
{
spawn_blocking(move || f(&mut *Self::as_sync_conn(&connection)))
let conn = self.clone();
spawn_blocking(move || f(&mut *conn.as_sync_conn()))
.await
.unwrap() // Propagate panics
}

#[doc(hidden)]
async fn transaction_depth(&self) -> Result<u32, DieselError> {
let conn = self.get_owned_connection();

Self::run_with_connection(conn, |conn| {
match Conn::TransactionManager::transaction_manager_status_mut(&mut *conn) {
TransactionManagerStatus::Valid(status) => {
Ok(status.transaction_depth().map(|d| d.into()).unwrap_or(0))
}
TransactionManagerStatus::InError => Err(DieselError::BrokenTransactionManager),
}
})
.await
}

// Diesel's "begin_transaction" chooses whether to issue "BEGIN" or a
// "SAVEPOINT" depending on the transaction depth.
//
// This method is a wrapper around that call, with validation that
// we're actually issuing the BEGIN statement here.
#[doc(hidden)]
async fn start_transaction(self: &Arc<Self>) -> Result<(), DieselError> {
if self.transaction_depth().await? != 0 {
return Err(DieselError::AlreadyInTransaction);
}
self.run_with_shared_connection(|conn| Conn::TransactionManager::begin_transaction(conn))
.await?;
Ok(())
}

// Diesel's "begin_transaction" chooses whether to issue "BEGIN" or a
// "SAVEPOINT" depending on the transaction depth.
//
// This method is a wrapper around that call, with validation that
// we're actually issuing our first SAVEPOINT here.
#[doc(hidden)]
async fn add_retry_savepoint(self: &Arc<Self>) -> Result<(), DieselError> {
match self.transaction_depth().await? {
0 => return Err(DieselError::NotInTransaction),
1 => (),
_ => return Err(DieselError::AlreadyInTransaction),
};

self.run_with_shared_connection(|conn| Conn::TransactionManager::begin_transaction(conn))
.await?;
Ok(())
}

#[doc(hidden)]
async fn commit_transaction(self: &Arc<Self>) -> Result<(), DieselError> {
self.run_with_shared_connection(|conn| Conn::TransactionManager::commit_transaction(conn))
.await?;
Ok(())
}

#[doc(hidden)]
async fn rollback_transaction(self: &Arc<Self>) -> Result<(), DieselError> {
self.run_with_shared_connection(|conn| {
Conn::TransactionManager::rollback_transaction(conn)
})
.await?;
Ok(())
}

/// Issues a function `f` as a transaction.
///
/// If it fails, asynchronously calls `retry` to decide if to retry.
///
/// This function throws an error if it is called from within an existing
/// transaction.
#[cfg(feature = "cockroach")]
async fn transaction_async_with_retry<R, Func, Fut, RetryFut, RetryFunc, 'a>(
&'a self,
f: Func,
retry: RetryFunc,
) -> Result<R, DieselError>
where
R: Send + 'static,
Fut: Future<Output = Result<R, DieselError>> + Send,
Func: Fn(SingleConnection<Conn>) -> Fut + Send + Sync,
RetryFut: Future<Output = bool> + Send,
RetryFunc: Fn() -> RetryFut + Send + Sync,
{
// Check out a connection once, and use it for the duration of the
// operation.
let conn = Arc::new(self.get_owned_connection());

// Refer to CockroachDB's guide on advanced client-side transaction
// retries for the full context:
// https://www.cockroachlabs.com/docs/v23.1/advanced-client-side-transaction-retries
//
// In short, they expect a particular name for this savepoint, but
// Diesel has Opinions on savepoint names, so we use this session
// variable to identify that any name is valid.
//
// TODO: It may be preferable to set this once per connection -- but
// that'll require more interaction with how sessions with the database
// are constructed.
Self::start_transaction(&conn).await?;
conn.run_with_shared_connection(|conn| {
conn.batch_execute("SET LOCAL force_savepoint_restart = true")
})
.await?;

loop {
// Add a SAVEPOINT to which we can later return.
Self::add_retry_savepoint(&conn).await?;

let async_conn = SingleConnection(Self::as_async_conn(&conn).0.clone());
match f(async_conn).await {
Ok(value) => {
// The user-level operation succeeded: try to commit the
// transaction by RELEASE-ing the retry savepoint.
if let Err(err) = Self::commit_transaction(&conn).await {
// Diesel's implementation of "commit_transaction"
// calls "rollback_transaction" in the error path.
//
// We're still in the transaction, but we at least
// tried to ROLLBACK to our savepoint.
if !retryable_error(&err) || !retry().await {
// Bail: ROLLBACK the initial BEGIN statement too.
let _ = Self::rollback_transaction(&conn).await;
return Err(err);
}
// ROLLBACK happened, we want to retry.
continue;
}

// Commit the top-level transaction too.
Self::commit_transaction(&conn).await?;
return Ok(value);
}
Err(user_error) => {
// The user-level operation failed: ROLLBACK to the retry
// savepoint.
if let Err(first_rollback_err) = Self::rollback_transaction(&conn).await {
// If we fail while rolling back, prioritize returning
// the ROLLBACK error over the user errors.
return match Self::rollback_transaction(&conn).await {
Ok(()) => Err(first_rollback_err),
Err(second_rollback_err) => Err(second_rollback_err),
};
}

// We rolled back to the retry savepoint, and now want to
// retry.
if retryable_error(&user_error) && retry().await {
continue;
}

// If we aren't retrying, ROLLBACK the BEGIN statement too.
return match Self::rollback_transaction(&conn).await {
Ok(()) => Err(user_error),
Err(err) => Err(err),
};
}
}
}
}

async fn transaction_async<R, E, Func, Fut, 'a>(&'a self, f: Func) -> Result<R, E>
where
R: Send + 'static,
Expand All @@ -91,14 +256,14 @@ where
{
// Check out a connection once, and use it for the duration of the
// operation.
let conn = Arc::new(self.get_owned_connection().await);
let conn = Arc::new(self.get_owned_connection());

// This function mimics the implementation of:
// https://docs.diesel.rs/master/diesel/connection/trait.TransactionManager.html#method.transaction
//
// However, it modifies all callsites to instead issue
// known-to-be-synchronous operations from an asynchronous context.
Self::run_with_shared_connection(conn.clone(), |conn| {
conn.run_with_shared_connection(|conn| {
Conn::TransactionManager::begin_transaction(conn).map_err(E::from)
})
.await?;
Expand All @@ -118,17 +283,18 @@ where
let async_conn = SingleConnection(Self::as_async_conn(&conn).0.clone());
match f(async_conn).await {
Ok(value) => {
Self::run_with_shared_connection(conn.clone(), |conn| {
conn.run_with_shared_connection(|conn| {
Conn::TransactionManager::commit_transaction(conn).map_err(E::from)
})
.await?;
Ok(value)
}
Err(user_error) => {
match Self::run_with_shared_connection(conn.clone(), |conn| {
Conn::TransactionManager::rollback_transaction(conn).map_err(E::from)
})
.await
match conn
.run_with_shared_connection(|conn| {
Conn::TransactionManager::rollback_transaction(conn).map_err(E::from)
})
.await
{
Ok(()) => Err(user_error),
Err(err) => Err(err),
Expand Down
16 changes: 9 additions & 7 deletions src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,19 @@ where
Conn: 'static + R2D2Connection,
Connection<Conn>: crate::AsyncSimpleConnection<Conn>,
{
type OwnedConnection = Connection<Conn>;

async fn get_owned_connection(&self) -> Self::OwnedConnection {
fn get_owned_connection(&self) -> Self {
Connection(self.0.clone())
}

fn as_sync_conn(owned: &Self::OwnedConnection) -> MutexGuard<'_, Conn> {
owned.inner()
// Accesses the connection synchronously, protected by a mutex.
//
// Avoid calling from asynchronous contexts.
fn as_sync_conn(&self) -> MutexGuard<'_, Conn> {
self.inner()
}

fn as_async_conn(owned: &Self::OwnedConnection) -> &Connection<Conn> {
owned
// TODO: Consider removing me.
fn as_async_conn(&self) -> &Connection<Conn> {
self
}
}
Loading

0 comments on commit ed7ab5e

Please sign in to comment.