From f06a910d755164d77069dfcaccecedccb1c9d2f4 Mon Sep 17 00:00:00 2001 From: Yash Atreya <44857776+yash-atreya@users.noreply.github.com> Date: Mon, 14 Oct 2024 12:29:58 +0530 Subject: [PATCH] feat(`cheatcodes`): vm.rememberKeys (#9087) * feat(`cheatcodes`): vm.rememberKeys * docs + return addresses + test * remeberKeys with language * doc nits * cargo cheats * set script wallet in config if unset * nit * test --- crates/cheatcodes/assets/cheatcodes.json | 40 +++++++++ crates/cheatcodes/spec/src/vm.rs | 14 ++++ crates/cheatcodes/src/crypto.rs | 102 +++++++++++++++++++++-- crates/forge/tests/cli/script.rs | 38 +++++++++ crates/forge/tests/cli/test_cmd.rs | 55 ++++++++++++ testdata/cheats/Vm.sol | 2 + 6 files changed, 244 insertions(+), 7 deletions(-) diff --git a/crates/cheatcodes/assets/cheatcodes.json b/crates/cheatcodes/assets/cheatcodes.json index b41f736677bcb..554cae92a890a 100644 --- a/crates/cheatcodes/assets/cheatcodes.json +++ b/crates/cheatcodes/assets/cheatcodes.json @@ -7533,6 +7533,46 @@ "status": "stable", "safety": "safe" }, + { + "func": { + "id": "rememberKeys_0", + "description": "Derive a set number of wallets from a mnemonic at the derivation path `m/44'/60'/0'/0/{0..count}`.\nThe respective private keys are saved to the local forge wallet for later use and their addresses are returned.", + "declaration": "function rememberKeys(string calldata mnemonic, string calldata derivationPath, uint32 count) external returns (address[] memory keyAddrs);", + "visibility": "external", + "mutability": "", + "signature": "rememberKeys(string,string,uint32)", + "selector": "0x97cb9189", + "selectorBytes": [ + 151, + 203, + 145, + 137 + ] + }, + "group": "crypto", + "status": "stable", + "safety": "safe" + }, + { + "func": { + "id": "rememberKeys_1", + "description": "Derive a set number of wallets from a mnemonic in the specified language at the derivation path `m/44'/60'/0'/0/{0..count}`.\nThe respective private keys are saved to the local forge wallet for later use and their addresses are returned.", + "declaration": "function rememberKeys(string calldata mnemonic, string calldata derivationPath, string calldata language, uint32 count) external returns (address[] memory keyAddrs);", + "visibility": "external", + "mutability": "", + "signature": "rememberKeys(string,string,string,uint32)", + "selector": "0xf8d58eaf", + "selectorBytes": [ + 248, + 213, + 142, + 175 + ] + }, + "group": "crypto", + "status": "stable", + "safety": "safe" + }, { "func": { "id": "removeDir", diff --git a/crates/cheatcodes/spec/src/vm.rs b/crates/cheatcodes/spec/src/vm.rs index 2a342a04438ec..0fd8a62a7edc6 100644 --- a/crates/cheatcodes/spec/src/vm.rs +++ b/crates/cheatcodes/spec/src/vm.rs @@ -2407,6 +2407,20 @@ interface Vm { #[cheatcode(group = Crypto)] function rememberKey(uint256 privateKey) external returns (address keyAddr); + /// Derive a set number of wallets from a mnemonic at the derivation path `m/44'/60'/0'/0/{0..count}`. + /// + /// The respective private keys are saved to the local forge wallet for later use and their addresses are returned. + #[cheatcode(group = Crypto)] + function rememberKeys(string calldata mnemonic, string calldata derivationPath, uint32 count) external returns (address[] memory keyAddrs); + + /// Derive a set number of wallets from a mnemonic in the specified language at the derivation path `m/44'/60'/0'/0/{0..count}`. + /// + /// The respective private keys are saved to the local forge wallet for later use and their addresses are returned. + #[cheatcode(group = Crypto)] + function rememberKeys(string calldata mnemonic, string calldata derivationPath, string calldata language, uint32 count) + external + returns (address[] memory keyAddrs); + // -------- Uncategorized Utilities -------- /// Labels an address in call traces. diff --git a/crates/cheatcodes/src/crypto.rs b/crates/cheatcodes/src/crypto.rs index f080938ac5a28..e5e6f56e36609 100644 --- a/crates/cheatcodes/src/crypto.rs +++ b/crates/cheatcodes/src/crypto.rs @@ -1,6 +1,6 @@ //! Implementations of [`Crypto`](spec::Group::Crypto) Cheatcodes. -use crate::{Cheatcode, Cheatcodes, Result, Vm::*}; +use crate::{Cheatcode, Cheatcodes, Result, ScriptWallets, Vm::*}; use alloy_primitives::{keccak256, Address, B256, U256}; use alloy_signer::{Signer, SignerSync}; use alloy_signer_local::{ @@ -8,14 +8,16 @@ use alloy_signer_local::{ ChineseSimplified, ChineseTraditional, Czech, English, French, Italian, Japanese, Korean, Portuguese, Spanish, Wordlist, }, - MnemonicBuilder, PrivateKeySigner, + LocalSigner, MnemonicBuilder, PrivateKeySigner, }; use alloy_sol_types::SolValue; +use foundry_wallets::multi_wallet::MultiWallet; use k256::{ ecdsa::SigningKey, elliptic_curve::{bigint::ArrayEncoding, sec1::ToEncodedPoint}, }; use p256::ecdsa::{signature::hazmat::PrehashSigner, Signature, SigningKey as P256SigningKey}; +use std::sync::Arc; /// The BIP32 default derivation path prefix. const DEFAULT_DERIVATION_PATH_PREFIX: &str = "m/44'/60'/0'/0/"; @@ -89,14 +91,56 @@ impl Cheatcode for rememberKeyCall { fn apply(&self, state: &mut Cheatcodes) -> Result { let Self { privateKey } = self; let wallet = parse_wallet(privateKey)?; - let address = wallet.address(); - if let Some(script_wallets) = state.script_wallets() { - script_wallets.add_local_signer(wallet); - } + let address = inject_wallet(state, wallet); Ok(address.abi_encode()) } } +impl Cheatcode for rememberKeys_0Call { + fn apply(&self, state: &mut Cheatcodes) -> Result { + let Self { mnemonic, derivationPath, count } = self; + tracing::info!("Remembering {} keys", count); + let wallets = derive_wallets::(mnemonic, derivationPath, *count)?; + + tracing::info!("Adding {} keys to script wallets", count); + + let mut addresses = Vec::
::with_capacity(wallets.len()); + for wallet in wallets { + let addr = inject_wallet(state, wallet); + addresses.push(addr); + } + + Ok(addresses.abi_encode()) + } +} + +impl Cheatcode for rememberKeys_1Call { + fn apply(&self, state: &mut Cheatcodes) -> Result { + let Self { mnemonic, derivationPath, language, count } = self; + let wallets = derive_wallets_str(mnemonic, derivationPath, language, *count)?; + let mut addresses = Vec::
::with_capacity(wallets.len()); + for wallet in wallets { + let addr = inject_wallet(state, wallet); + addresses.push(addr); + } + + Ok(addresses.abi_encode()) + } +} + +fn inject_wallet(state: &mut Cheatcodes, wallet: LocalSigner) -> Address { + let address = wallet.address(); + if let Some(script_wallets) = state.script_wallets() { + script_wallets.add_local_signer(wallet); + } else { + // This is needed in case of testing scripts, wherein script wallets are not set on setup. + let script_wallets = ScriptWallets::new(MultiWallet::default(), None); + script_wallets.add_local_signer(wallet); + Arc::make_mut(&mut state.config).script_wallets = Some(script_wallets); + } + address +} + impl Cheatcode for sign_1Call { fn apply(&self, _state: &mut Cheatcodes) -> Result { let Self { privateKey, digest } = self; @@ -228,7 +272,7 @@ fn sign_with_wallet( } else if signers.len() == 1 { *signers.keys().next().unwrap() } else { - bail!("could not determine signer"); + bail!("could not determine signer, there are multiple signers available use vm.sign(signer, digest) to specify one"); }; let wallet = signers @@ -309,6 +353,50 @@ fn derive_key(mnemonic: &str, path: &str, index: u32) -> Result { Ok(private_key.abi_encode()) } +fn derive_wallets_str( + mnemonic: &str, + path: &str, + language: &str, + count: u32, +) -> Result>> { + match language { + "chinese_simplified" => derive_wallets::(mnemonic, path, count), + "chinese_traditional" => derive_wallets::(mnemonic, path, count), + "czech" => derive_wallets::(mnemonic, path, count), + "english" => derive_wallets::(mnemonic, path, count), + "french" => derive_wallets::(mnemonic, path, count), + "italian" => derive_wallets::(mnemonic, path, count), + "japanese" => derive_wallets::(mnemonic, path, count), + "korean" => derive_wallets::(mnemonic, path, count), + "portuguese" => derive_wallets::(mnemonic, path, count), + "spanish" => derive_wallets::(mnemonic, path, count), + _ => Err(fmt_err!("unsupported mnemonic language: {language:?}")), + } +} + +fn derive_wallets( + mnemonic: &str, + path: &str, + count: u32, +) -> Result>> { + let mut out = path.to_string(); + + if !out.ends_with('/') { + out.push('/'); + } + + let mut wallets = Vec::with_capacity(count as usize); + for idx in 0..count { + let wallet = MnemonicBuilder::::default() + .phrase(mnemonic) + .derivation_path(format!("{out}{idx}"))? + .build()?; + wallets.push(wallet); + } + + Ok(wallets) +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/forge/tests/cli/script.rs b/crates/forge/tests/cli/script.rs index e8f1466b911ef..81b7461205fe3 100644 --- a/crates/forge/tests/cli/script.rs +++ b/crates/forge/tests/cli/script.rs @@ -2108,3 +2108,41 @@ Script ran successfully. "#]]); }); + +forgetest_init!(can_remeber_keys, |prj, cmd| { + let script = prj + .add_source( + "Foo", + r#" +import "forge-std/Script.sol"; + +interface Vm { + function rememberKeys(string calldata mnemonic, string calldata derivationPath, uint32 count) external returns (address[] memory keyAddrs); +} + +contract WalletScript is Script { + function run() public { + string memory mnemonic = "test test test test test test test test test test test junk"; + string memory derivationPath = "m/44'/60'/0'/0/"; + address[] memory wallets = Vm(address(vm)).rememberKeys(mnemonic, derivationPath, 3); + for (uint256 i = 0; i < wallets.length; i++) { + console.log(wallets[i]); + } + } +}"#, + ) + .unwrap(); + cmd.arg("script").arg(script).assert_success().stdout_eq(str![[r#" +[COMPILING_FILES] with [SOLC_VERSION] +[SOLC_VERSION] [ELAPSED] +Compiler run successful! +Script ran successfully. +[GAS] + +== Logs == + 0xf39Fd6e51aad88F6F4ce6aB8827279cffFb92266 + 0x70997970C51812dc3A010C7d01b50e0d17dc79C8 + 0x3C44CdDdB6a900fa2b585dd299e03d12FA4293BC + +"#]]); +}); diff --git a/crates/forge/tests/cli/test_cmd.rs b/crates/forge/tests/cli/test_cmd.rs index ea3ce1911ba97..2cd258f661602 100644 --- a/crates/forge/tests/cli/test_cmd.rs +++ b/crates/forge/tests/cli/test_cmd.rs @@ -2219,3 +2219,58 @@ warning: specifying argument for --decode-internal is deprecated and will be rem "#]]); }); + +// Test a script that calls vm.rememberKeys +forgetest_init!(script_testing, |prj, cmd| { + prj + .add_source( + "Foo", + r#" +import "forge-std/Script.sol"; + +interface Vm { +function rememberKeys(string calldata mnemonic, string calldata derivationPath, uint32 count) external returns (address[] memory keyAddrs); +} + +contract WalletScript is Script { +function run() public { + string memory mnemonic = "test test test test test test test test test test test junk"; + string memory derivationPath = "m/44'/60'/0'/0/"; + address[] memory wallets = Vm(address(vm)).rememberKeys(mnemonic, derivationPath, 3); + for (uint256 i = 0; i < wallets.length; i++) { + console.log(wallets[i]); + } +} +} + +contract FooTest { + WalletScript public script; + + + function setUp() public { + script = new WalletScript(); + } + + function testWalletScript() public { + script.run(); + } +} + +"#, + ) + .unwrap(); + + cmd.args(["test", "--mt", "testWalletScript", "-vvv"]).assert_success().stdout_eq(str![[r#" +[COMPILING_FILES] with [SOLC_VERSION] +[SOLC_VERSION] [ELAPSED] +Compiler run successful! + +Ran 1 test for src/Foo.sol:FooTest +[PASS] testWalletScript() ([GAS]) +Logs: + 0xf39Fd6e51aad88F6F4ce6aB8827279cffFb92266 + 0x70997970C51812dc3A010C7d01b50e0d17dc79C8 + 0x3C44CdDdB6a900fa2b585dd299e03d12FA4293BC +... +"#]]); +}); diff --git a/testdata/cheats/Vm.sol b/testdata/cheats/Vm.sol index e47a79cc65ddb..bb9b07a6595c3 100644 --- a/testdata/cheats/Vm.sol +++ b/testdata/cheats/Vm.sol @@ -371,6 +371,8 @@ interface Vm { function record() external; function recordLogs() external; function rememberKey(uint256 privateKey) external returns (address keyAddr); + function rememberKeys(string calldata mnemonic, string calldata derivationPath, uint32 count) external returns (address[] memory keyAddrs); + function rememberKeys(string calldata mnemonic, string calldata derivationPath, string calldata language, uint32 count) external returns (address[] memory keyAddrs); function removeDir(string calldata path, bool recursive) external; function removeFile(string calldata path) external; function replace(string calldata input, string calldata from, string calldata to) external pure returns (string memory output);