From c657042c36e05807459af666f8024c106ea55afe Mon Sep 17 00:00:00 2001 From: James Prestwich Date: Thu, 18 Apr 2024 23:27:38 +0100 Subject: [PATCH] feature: WalletProvider (#569) * feature: WalletProvider * feature: signer_mut * test: a basic one * test: bubbles up * refactor: renames for clarity --- crates/network/src/ethereum/signer.rs | 6 +- crates/network/src/transaction/signer.rs | 10 +- crates/provider/src/fillers/join_fill.rs | 17 ++++ crates/provider/src/fillers/mod.rs | 4 +- crates/provider/src/fillers/signer.rs | 14 ++- crates/provider/src/lib.rs | 3 + crates/provider/src/wallet.rs | 113 +++++++++++++++++++++++ 7 files changed, 156 insertions(+), 11 deletions(-) create mode 100644 crates/provider/src/wallet.rs diff --git a/crates/network/src/ethereum/signer.rs b/crates/network/src/ethereum/signer.rs index 8aae8168952..adc94d09ead 100644 --- a/crates/network/src/ethereum/signer.rs +++ b/crates/network/src/ethereum/signer.rs @@ -100,15 +100,15 @@ impl NetworkSigner for EthereumSigner where N: Network, { - fn default_signer(&self) -> Address { + fn default_signer_address(&self) -> Address { self.default } - fn is_signer_for(&self, address: &Address) -> bool { + fn has_signer_for(&self, address: &Address) -> bool { self.secp_signers.contains_key(address) } - fn signers(&self) -> impl Iterator { + fn signer_addresses(&self) -> impl Iterator { self.secp_signers.keys().copied() } diff --git a/crates/network/src/transaction/signer.rs b/crates/network/src/transaction/signer.rs index e9913914b05..a9d70493697 100644 --- a/crates/network/src/transaction/signer.rs +++ b/crates/network/src/transaction/signer.rs @@ -20,13 +20,13 @@ pub trait NetworkSigner: std::fmt::Debug + Send + Sync { /// Get the default signer address. This address should be used /// in [`NetworkSigner::sign_transaction_from`] when no specific signer is /// specified. - fn default_signer(&self) -> Address; + fn default_signer_address(&self) -> Address; /// Return true if the signer contains a credential for the given address. - fn is_signer_for(&self, address: &Address) -> bool; + fn has_signer_for(&self, address: &Address) -> bool; /// Return an iterator of all signer addresses. - fn signers(&self) -> impl Iterator; + fn signer_addresses(&self) -> impl Iterator; /// Asynchronously sign an unsigned transaction, with a specified /// credential. @@ -41,7 +41,7 @@ pub trait NetworkSigner: std::fmt::Debug + Send + Sync { &self, tx: N::UnsignedTx, ) -> impl_future!(>) { - self.sign_transaction_from(self.default_signer(), tx) + self.sign_transaction_from(self.default_signer_address(), tx) } /// Asynchronously sign a transaction request, using the sender specified @@ -50,7 +50,7 @@ pub trait NetworkSigner: std::fmt::Debug + Send + Sync { &self, request: N::TransactionRequest, ) -> alloy_signer::Result { - let sender = request.from().unwrap_or_else(|| self.default_signer()); + let sender = request.from().unwrap_or_else(|| self.default_signer_address()); let tx = request.build_unsigned().map_err(|(_, e)| alloy_signer::Error::other(e))?; self.sign_transaction_from(sender, tx).await } diff --git a/crates/provider/src/fillers/join_fill.rs b/crates/provider/src/fillers/join_fill.rs index 857ac2ab077..8a0ff963c7c 100644 --- a/crates/provider/src/fillers/join_fill.rs +++ b/crates/provider/src/fillers/join_fill.rs @@ -23,6 +23,23 @@ impl JoinFill { pub const fn new(left: L, right: R) -> Self { Self { left, right } } + + /// Get a reference to the left filler. + pub const fn left(&self) -> &L { + &self.left + } + + /// Get a reference to the right filler. + pub const fn right(&self) -> &R { + &self.right + } + + /// Get a mutable reference to the left filler. + /// + /// NB: this function exists to enable the [`crate::WalletProvider`] impl. + pub(crate) fn right_mut(&mut self) -> &mut R { + &mut self.right + } } impl JoinFill { diff --git a/crates/provider/src/fillers/mod.rs b/crates/provider/src/fillers/mod.rs index b5dac499532..36298394618 100644 --- a/crates/provider/src/fillers/mod.rs +++ b/crates/provider/src/fillers/mod.rs @@ -213,8 +213,8 @@ where T: Transport + Clone, N: Network, { - inner: P, - filler: F, + pub(crate) inner: P, + pub(crate) filler: F, _pd: PhantomData (T, N)>, } diff --git a/crates/provider/src/fillers/signer.rs b/crates/provider/src/fillers/signer.rs index a6fe89702ad..c1e04c525d4 100644 --- a/crates/provider/src/fillers/signer.rs +++ b/crates/provider/src/fillers/signer.rs @@ -31,6 +31,18 @@ pub struct SignerFiller { signer: S, } +impl AsRef for SignerFiller { + fn as_ref(&self) -> &S { + &self.signer + } +} + +impl AsMut for SignerFiller { + fn as_mut(&mut self) -> &mut S { + &mut self.signer + } +} + impl SignerFiller { /// Creates a new signing layer with the given signer. pub const fn new(signer: S) -> Self { @@ -82,7 +94,7 @@ where }; if builder.from().is_none() { - builder.set_from(self.signer.default_signer()); + builder.set_from(self.signer.default_signer_address()); if !builder.can_build() { return Ok(SendableTx::Builder(builder)); } diff --git a/crates/provider/src/lib.rs b/crates/provider/src/lib.rs index ac39e56ac7e..1749ef360ba 100644 --- a/crates/provider/src/lib.rs +++ b/crates/provider/src/lib.rs @@ -49,6 +49,9 @@ pub use heart::{PendingTransaction, PendingTransactionBuilder, PendingTransactio mod provider; pub use provider::{FilterPollerBuilder, Provider, RootProvider}; +mod wallet; +pub use wallet::WalletProvider; + pub mod admin; pub mod debug; pub mod txpool; diff --git a/crates/provider/src/wallet.rs b/crates/provider/src/wallet.rs new file mode 100644 index 00000000000..03c1950f7eb --- /dev/null +++ b/crates/provider/src/wallet.rs @@ -0,0 +1,113 @@ +use crate::{ + fillers::{FillProvider, JoinFill, SignerFiller, TxFiller}, + Provider, +}; +use alloy_network::{Ethereum, Network, NetworkSigner}; +use alloy_primitives::Address; +use alloy_transport::Transport; + +/// Trait for Providers, Fill stacks, etc, which contain [`NetworkSigner`]. +pub trait WalletProvider { + /// The underlying [`NetworkSigner`] type contained in this stack. + type Signer: NetworkSigner; + + /// Get a reference to the underlying signer. + fn signer(&self) -> &Self::Signer; + + /// Get a mutable reference to the underlying signer. + fn signer_mut(&mut self) -> &mut Self::Signer; + + /// Get the default signer address. + fn default_signer_address(&self) -> Address { + self.signer().default_signer_address() + } + + /// Check if the signer can sign for the given address. + fn has_signer_for(&self, address: &Address) -> bool { + self.signer().has_signer_for(address) + } + + /// Get an iterator of all signer addresses. Note that because the signer + /// always has at least one address, this iterator will always have at least + /// one element. + fn signer_addresses(&self) -> impl Iterator { + self.signer().signer_addresses() + } +} + +impl WalletProvider for SignerFiller +where + S: NetworkSigner + Clone, + N: Network, +{ + type Signer = S; + + #[inline(always)] + fn signer(&self) -> &Self::Signer { + self.as_ref() + } + + #[inline(always)] + fn signer_mut(&mut self) -> &mut Self::Signer { + self.as_mut() + } +} + +impl WalletProvider for JoinFill +where + R: WalletProvider, + N: Network, +{ + type Signer = R::Signer; + + #[inline(always)] + fn signer(&self) -> &Self::Signer { + self.right().signer() + } + + #[inline(always)] + fn signer_mut(&mut self) -> &mut Self::Signer { + self.right_mut().signer_mut() + } +} + +impl WalletProvider for FillProvider +where + F: TxFiller + WalletProvider, + P: Provider, + T: Transport + Clone, + N: Network, +{ + type Signer = F::Signer; + + #[inline(always)] + fn signer(&self) -> &Self::Signer { + self.filler.signer() + } + + #[inline(always)] + fn signer_mut(&mut self) -> &mut Self::Signer { + self.filler.signer_mut() + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::ProviderBuilder; + + #[test] + fn basic_usage() { + let (provider, _anvil) = ProviderBuilder::new().on_anvil_with_signer(); + + assert_eq!(provider.default_signer_address(), provider.signer_addresses().next().unwrap()); + } + + #[test] + fn bubbles_through_fillers() { + let (provider, _anvil) = + ProviderBuilder::new().with_recommended_fillers().on_anvil_with_signer(); + + assert_eq!(provider.default_signer_address(), provider.signer_addresses().next().unwrap()); + } +}