diff --git a/crates/provider/src/fillers/nonce.rs b/crates/provider/src/fillers/nonce.rs index 237cb7d9583..ad9f0aa60f8 100644 --- a/crates/provider/src/fillers/nonce.rs +++ b/crates/provider/src/fillers/nonce.rs @@ -7,8 +7,8 @@ use alloy_network::{Network, TransactionBuilder}; use alloy_primitives::Address; use alloy_transport::{Transport, TransportResult}; use dashmap::DashMap; +use futures::lock::Mutex; use std::sync::Arc; -use tokio::sync::Mutex; /// A [`TxFiller`] that fills nonces on transactions. /// @@ -41,7 +41,7 @@ use tokio::sync::Mutex; /// ``` #[derive(Clone, Debug, Default)] pub struct NonceFiller { - nonces: DashMap>>>, + nonces: DashMap>>, } impl TxFiller for NonceFiller { @@ -86,29 +86,31 @@ impl TxFiller for NonceFiller { impl NonceFiller { /// Get the next nonce for the given account. - async fn get_next_nonce(&self, provider: &P, from: Address) -> TransportResult + async fn get_next_nonce(&self, provider: &P, address: Address) -> TransportResult where P: Provider, N: Network, T: Transport + Clone, { - // locks dashmap internally for a short duration to clone the `Arc` - let mutex = Arc::clone(self.nonces.entry(from).or_default().value()); - - // locks the value (does not lock dashmap) - let mut nonce = mutex.lock().await; - match *nonce { - Some(ref mut nonce) => { - *nonce += 1; - Ok(*nonce) - } - None => { - // initialize the nonce if we haven't seen this account before - let initial_nonce = provider.get_transaction_count(from).await?; - *nonce = Some(initial_nonce); - Ok(initial_nonce) - } - } + // Use `u64::MAX` as a sentinel value to indicate that the nonce has not been fetched yet. + const NONE: u64 = u64::MAX; + + // Locks dashmap internally for a short duration to clone the `Arc`. + // We also don't want to hold the dashmap lock through the await point below. + let nonce = { + let rm = self.nonces.entry(address).or_insert_with(|| Arc::new(Mutex::new(NONE))); + Arc::clone(rm.value()) + }; + + let mut nonce = nonce.lock().await; + let new_nonce = if *nonce == NONE { + // Initialize the nonce if we haven't seen this account before. + provider.get_transaction_count(address).await? + } else { + *nonce + 1 + }; + *nonce = new_nonce; + Ok(new_nonce) } } @@ -119,6 +121,58 @@ mod tests { use alloy_primitives::{address, U256}; use alloy_rpc_types_eth::TransactionRequest; + async fn check_nonces(filler: &NonceFiller, provider: &P, address: Address, start: u64) + where + P: Provider, + N: Network, + T: Transport + Clone, + { + for i in start..start + 5 { + let nonce = filler.get_next_nonce(&provider, address).await.unwrap(); + assert_eq!(nonce, i); + } + } + + #[tokio::test] + async fn smoke_test() { + let filler = NonceFiller::default(); + let provider = ProviderBuilder::new().on_anvil(); + let address = Address::ZERO; + check_nonces(&filler, &provider, address, 0).await; + + #[cfg(feature = "anvil-api")] + { + use crate::ext::AnvilApi; + filler.nonces.clear(); + provider.anvil_set_nonce(address, U256::from(69)).await.unwrap(); + check_nonces(&filler, &provider, address, 69).await; + } + } + + #[tokio::test] + async fn concurrency() { + let filler = Arc::new(NonceFiller::default()); + let provider = Arc::new(ProviderBuilder::new().on_anvil()); + let address = Address::ZERO; + let tasks = (0..5) + .map(|_| { + let filler = Arc::clone(&filler); + let provider = Arc::clone(&provider); + tokio::spawn(async move { filler.get_next_nonce(&provider, address).await }) + }) + .collect::>(); + + let mut ns = Vec::new(); + for task in tasks { + ns.push(task.await.unwrap().unwrap()); + } + ns.sort_unstable(); + assert_eq!(ns, (0..5).collect::>()); + + assert_eq!(filler.nonces.len(), 1); + assert_eq!(*filler.nonces.get(&address).unwrap().value().lock().await, 4); + } + #[tokio::test] async fn no_nonce_if_sender_unset() { let provider = ProviderBuilder::new().with_nonce_management().on_anvil();