Skip to content

Commit

Permalink
feat(cheatcodes): vm.rememberKeys (foundry-rs#9087)
Browse files Browse the repository at this point in the history
* feat(`cheatcodes`): vm.rememberKeys

* docs + return addresses + test

* remeberKeys with language

* doc nits

* cargo cheats

* set script wallet in config if unset

* nit

* test
  • Loading branch information
yash-atreya authored and rplusq committed Nov 29, 2024
1 parent 21eadeb commit f06a910
Show file tree
Hide file tree
Showing 6 changed files with 244 additions and 7 deletions.
40 changes: 40 additions & 0 deletions crates/cheatcodes/assets/cheatcodes.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 14 additions & 0 deletions crates/cheatcodes/spec/src/vm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2407,6 +2407,20 @@ interface Vm {
#[cheatcode(group = Crypto)]
function rememberKey(uint256 privateKey) external returns (address keyAddr);

/// Derive a set number of wallets from a mnemonic at the derivation path `m/44'/60'/0'/0/{0..count}`.
///
/// The respective private keys are saved to the local forge wallet for later use and their addresses are returned.
#[cheatcode(group = Crypto)]
function rememberKeys(string calldata mnemonic, string calldata derivationPath, uint32 count) external returns (address[] memory keyAddrs);

/// Derive a set number of wallets from a mnemonic in the specified language at the derivation path `m/44'/60'/0'/0/{0..count}`.
///
/// The respective private keys are saved to the local forge wallet for later use and their addresses are returned.
#[cheatcode(group = Crypto)]
function rememberKeys(string calldata mnemonic, string calldata derivationPath, string calldata language, uint32 count)
external
returns (address[] memory keyAddrs);

// -------- Uncategorized Utilities --------

/// Labels an address in call traces.
Expand Down
102 changes: 95 additions & 7 deletions crates/cheatcodes/src/crypto.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
//! Implementations of [`Crypto`](spec::Group::Crypto) Cheatcodes.
use crate::{Cheatcode, Cheatcodes, Result, Vm::*};
use crate::{Cheatcode, Cheatcodes, Result, ScriptWallets, Vm::*};
use alloy_primitives::{keccak256, Address, B256, U256};
use alloy_signer::{Signer, SignerSync};
use alloy_signer_local::{
coins_bip39::{
ChineseSimplified, ChineseTraditional, Czech, English, French, Italian, Japanese, Korean,
Portuguese, Spanish, Wordlist,
},
MnemonicBuilder, PrivateKeySigner,
LocalSigner, MnemonicBuilder, PrivateKeySigner,
};
use alloy_sol_types::SolValue;
use foundry_wallets::multi_wallet::MultiWallet;
use k256::{
ecdsa::SigningKey,
elliptic_curve::{bigint::ArrayEncoding, sec1::ToEncodedPoint},
};
use p256::ecdsa::{signature::hazmat::PrehashSigner, Signature, SigningKey as P256SigningKey};
use std::sync::Arc;

/// The BIP32 default derivation path prefix.
const DEFAULT_DERIVATION_PATH_PREFIX: &str = "m/44'/60'/0'/0/";
Expand Down Expand Up @@ -89,14 +91,56 @@ impl Cheatcode for rememberKeyCall {
fn apply(&self, state: &mut Cheatcodes) -> Result {
let Self { privateKey } = self;
let wallet = parse_wallet(privateKey)?;
let address = wallet.address();
if let Some(script_wallets) = state.script_wallets() {
script_wallets.add_local_signer(wallet);
}
let address = inject_wallet(state, wallet);
Ok(address.abi_encode())
}
}

impl Cheatcode for rememberKeys_0Call {
fn apply(&self, state: &mut Cheatcodes) -> Result {
let Self { mnemonic, derivationPath, count } = self;
tracing::info!("Remembering {} keys", count);
let wallets = derive_wallets::<English>(mnemonic, derivationPath, *count)?;

tracing::info!("Adding {} keys to script wallets", count);

let mut addresses = Vec::<Address>::with_capacity(wallets.len());
for wallet in wallets {
let addr = inject_wallet(state, wallet);
addresses.push(addr);
}

Ok(addresses.abi_encode())
}
}

impl Cheatcode for rememberKeys_1Call {
fn apply(&self, state: &mut Cheatcodes) -> Result {
let Self { mnemonic, derivationPath, language, count } = self;
let wallets = derive_wallets_str(mnemonic, derivationPath, language, *count)?;
let mut addresses = Vec::<Address>::with_capacity(wallets.len());
for wallet in wallets {
let addr = inject_wallet(state, wallet);
addresses.push(addr);
}

Ok(addresses.abi_encode())
}
}

fn inject_wallet(state: &mut Cheatcodes, wallet: LocalSigner<SigningKey>) -> Address {
let address = wallet.address();
if let Some(script_wallets) = state.script_wallets() {
script_wallets.add_local_signer(wallet);
} else {
// This is needed in case of testing scripts, wherein script wallets are not set on setup.
let script_wallets = ScriptWallets::new(MultiWallet::default(), None);
script_wallets.add_local_signer(wallet);
Arc::make_mut(&mut state.config).script_wallets = Some(script_wallets);
}
address
}

impl Cheatcode for sign_1Call {
fn apply(&self, _state: &mut Cheatcodes) -> Result {
let Self { privateKey, digest } = self;
Expand Down Expand Up @@ -228,7 +272,7 @@ fn sign_with_wallet(
} else if signers.len() == 1 {
*signers.keys().next().unwrap()
} else {
bail!("could not determine signer");
bail!("could not determine signer, there are multiple signers available use vm.sign(signer, digest) to specify one");
};

let wallet = signers
Expand Down Expand Up @@ -309,6 +353,50 @@ fn derive_key<W: Wordlist>(mnemonic: &str, path: &str, index: u32) -> Result {
Ok(private_key.abi_encode())
}

fn derive_wallets_str(
mnemonic: &str,
path: &str,
language: &str,
count: u32,
) -> Result<Vec<LocalSigner<SigningKey>>> {
match language {
"chinese_simplified" => derive_wallets::<ChineseSimplified>(mnemonic, path, count),
"chinese_traditional" => derive_wallets::<ChineseTraditional>(mnemonic, path, count),
"czech" => derive_wallets::<Czech>(mnemonic, path, count),
"english" => derive_wallets::<English>(mnemonic, path, count),
"french" => derive_wallets::<French>(mnemonic, path, count),
"italian" => derive_wallets::<Italian>(mnemonic, path, count),
"japanese" => derive_wallets::<Japanese>(mnemonic, path, count),
"korean" => derive_wallets::<Korean>(mnemonic, path, count),
"portuguese" => derive_wallets::<Portuguese>(mnemonic, path, count),
"spanish" => derive_wallets::<Spanish>(mnemonic, path, count),
_ => Err(fmt_err!("unsupported mnemonic language: {language:?}")),
}
}

fn derive_wallets<W: Wordlist>(
mnemonic: &str,
path: &str,
count: u32,
) -> Result<Vec<LocalSigner<SigningKey>>> {
let mut out = path.to_string();

if !out.ends_with('/') {
out.push('/');
}

let mut wallets = Vec::with_capacity(count as usize);
for idx in 0..count {
let wallet = MnemonicBuilder::<W>::default()
.phrase(mnemonic)
.derivation_path(format!("{out}{idx}"))?
.build()?;
wallets.push(wallet);
}

Ok(wallets)
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
38 changes: 38 additions & 0 deletions crates/forge/tests/cli/script.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2108,3 +2108,41 @@ Script ran successfully.
"#]]);
});

forgetest_init!(can_remeber_keys, |prj, cmd| {
let script = prj
.add_source(
"Foo",
r#"
import "forge-std/Script.sol";
interface Vm {
function rememberKeys(string calldata mnemonic, string calldata derivationPath, uint32 count) external returns (address[] memory keyAddrs);
}
contract WalletScript is Script {
function run() public {
string memory mnemonic = "test test test test test test test test test test test junk";
string memory derivationPath = "m/44'/60'/0'/0/";
address[] memory wallets = Vm(address(vm)).rememberKeys(mnemonic, derivationPath, 3);
for (uint256 i = 0; i < wallets.length; i++) {
console.log(wallets[i]);
}
}
}"#,
)
.unwrap();
cmd.arg("script").arg(script).assert_success().stdout_eq(str![[r#"
[COMPILING_FILES] with [SOLC_VERSION]
[SOLC_VERSION] [ELAPSED]
Compiler run successful!
Script ran successfully.
[GAS]
== Logs ==
0xf39Fd6e51aad88F6F4ce6aB8827279cffFb92266
0x70997970C51812dc3A010C7d01b50e0d17dc79C8
0x3C44CdDdB6a900fa2b585dd299e03d12FA4293BC
"#]]);
});
55 changes: 55 additions & 0 deletions crates/forge/tests/cli/test_cmd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2219,3 +2219,58 @@ warning: specifying argument for --decode-internal is deprecated and will be rem
"#]]);
});

// Test a script that calls vm.rememberKeys
forgetest_init!(script_testing, |prj, cmd| {
prj
.add_source(
"Foo",
r#"
import "forge-std/Script.sol";
interface Vm {
function rememberKeys(string calldata mnemonic, string calldata derivationPath, uint32 count) external returns (address[] memory keyAddrs);
}
contract WalletScript is Script {
function run() public {
string memory mnemonic = "test test test test test test test test test test test junk";
string memory derivationPath = "m/44'/60'/0'/0/";
address[] memory wallets = Vm(address(vm)).rememberKeys(mnemonic, derivationPath, 3);
for (uint256 i = 0; i < wallets.length; i++) {
console.log(wallets[i]);
}
}
}
contract FooTest {
WalletScript public script;
function setUp() public {
script = new WalletScript();
}
function testWalletScript() public {
script.run();
}
}
"#,
)
.unwrap();

cmd.args(["test", "--mt", "testWalletScript", "-vvv"]).assert_success().stdout_eq(str![[r#"
[COMPILING_FILES] with [SOLC_VERSION]
[SOLC_VERSION] [ELAPSED]
Compiler run successful!
Ran 1 test for src/Foo.sol:FooTest
[PASS] testWalletScript() ([GAS])
Logs:
0xf39Fd6e51aad88F6F4ce6aB8827279cffFb92266
0x70997970C51812dc3A010C7d01b50e0d17dc79C8
0x3C44CdDdB6a900fa2b585dd299e03d12FA4293BC
...
"#]]);
});
2 changes: 2 additions & 0 deletions testdata/cheats/Vm.sol

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit f06a910

Please sign in to comment.