From e15d8c6b7d01f78c3014c8261d0d3328ff08a44b Mon Sep 17 00:00:00 2001 From: Ryan Pate Date: Thu, 1 Feb 2024 11:46:54 -0800 Subject: [PATCH] fix(cheatcodes): Properly record call to create2 factory in state diff --- crates/cheatcodes/src/inspector.rs | 53 +++++++++++++++-- crates/forge/tests/it/repros.rs | 7 +++ testdata/repros/Issue6634.t.sol | 95 ++++++++++++++++++++++++++++++ 3 files changed, 151 insertions(+), 4 deletions(-) create mode 100644 testdata/repros/Issue6634.t.sol diff --git a/crates/cheatcodes/src/inspector.rs b/crates/cheatcodes/src/inspector.rs index a036ffa597e0d..07b32d71c9361 100644 --- a/crates/cheatcodes/src/inspector.rs +++ b/crates/cheatcodes/src/inspector.rs @@ -1237,7 +1237,13 @@ impl Inspector for Cheatcodes { // Apply the Create2 deployer if self.broadcast.is_some() || self.config.always_use_create_2_factory { - match apply_create2_deployer(data, call, self.prank.as_ref(), self.broadcast.as_ref()) { + match apply_create2_deployer( + data, + call, + self.prank.as_ref(), + self.broadcast.as_ref(), + &mut self.recorded_account_diffs_stack, + ) { Ok(_) => {} Err(err) => return (InstructionResult::Revert, None, gas, Error::encode(err)), }; @@ -1249,6 +1255,15 @@ impl Inspector for Cheatcodes { let address = self.allow_cheatcodes_on_create(data, call); // If `recordAccountAccesses` has been called, record the create if let Some(recorded_account_diffs_stack) = &mut self.recorded_account_diffs_stack { + // If the create scheme is create2, and the caller is the DEFAULT_CREATE2_DEPLOYER then + // we must add 1 to the depth to account for the call to the create2 factory. + let mut depth = data.journaled_state.depth(); + if let CreateScheme::Create2 { salt: _ } = call.scheme { + if call.caller == DEFAULT_CREATE2_DEPLOYER { + depth += 1; + } + } + // Record the create context as an account access and create a new vector to record all // subsequent account accesses recorded_account_diffs_stack.push(vec![AccountAccess { @@ -1269,7 +1284,7 @@ impl Inspector for Cheatcodes { deployedCode: vec![], // updated on create_end storageAccesses: vec![], // updated on create_end }, - depth: data.journaled_state.depth(), + depth, }]); } @@ -1432,17 +1447,47 @@ fn apply_create2_deployer( call: &mut CreateInputs, prank: Option<&Prank>, broadcast: Option<&Broadcast>, + diffs_stack: &mut Option>>, ) -> Result<(), DB::Error> { - if let CreateScheme::Create2 { salt: _ } = call.scheme { + if let CreateScheme::Create2 { salt } = call.scheme { let mut base_depth = 1; if let Some(prank) = &prank { base_depth = prank.depth; } else if let Some(broadcast) = &broadcast { base_depth = broadcast.depth; } + // If the create scheme is Create2 and the depth equals the broadcast/prank/default // depth, then use the default create2 factory as the deployer if data.journaled_state.depth() == base_depth { + if let Some(recorded_account_diffs_stack) = diffs_stack { + // If broadcasting, or the create2 factory option is enabled, then record + // the call to the create2 factory. We do not explicitly check this here + // because those cases are checked when the `apply_create2_deployer` function + // is called. + let calldata = [&salt.to_be_bytes::<32>()[..], &call.init_code[..]].concat(); + recorded_account_diffs_stack.push(vec![AccountAccess { + access: crate::Vm::AccountAccess { + chainInfo: crate::Vm::ChainInfo { + forkId: data.db.active_fork_id().unwrap_or_default(), + chainId: U256::from(data.env.cfg.chain_id), + }, + accessor: call.caller, + account: DEFAULT_CREATE2_DEPLOYER, + kind: crate::Vm::AccountAccessKind::Call, + initialized: true, + oldBalance: U256::ZERO, // updated on create_end + newBalance: U256::ZERO, // updated on create_end + value: call.value, + data: calldata, + reverted: false, + deployedCode: vec![], // updated on create_end + storageAccesses: vec![], // updated on create_end + }, + depth: data.journaled_state.depth(), + }]) + } + // Sanity checks for our CREATE2 deployer let info = &data.journaled_state.load_account(DEFAULT_CREATE2_DEPLOYER, data.db)?.0.info; @@ -1475,9 +1520,9 @@ fn process_broadcast_create( data: &mut EVMData<'_, DB>, call: &mut CreateInputs, ) -> (Bytes, Option
, u64) { + call.caller = broadcast_sender; match call.scheme { CreateScheme::Create => { - call.caller = broadcast_sender; (bytecode, None, data.journaled_state.account(broadcast_sender).info.nonce) } CreateScheme::Create2 { salt } => { diff --git a/crates/forge/tests/it/repros.rs b/crates/forge/tests/it/repros.rs index 7323d7ef131a4..8506a248de017 100644 --- a/crates/forge/tests/it/repros.rs +++ b/crates/forge/tests/it/repros.rs @@ -310,3 +310,10 @@ test_repro!(5529; |config| { cheats_config.always_use_create_2_factory = true; config.runner.cheats_config = std::sync::Arc::new(cheats_config); }); + +// https://github.com/foundry-rs/foundry/issues/6634 +test_repro!(6634; |config| { + let mut cheats_config = config.runner.cheats_config.as_ref().clone(); + cheats_config.always_use_create_2_factory = true; + config.runner.cheats_config = std::sync::Arc::new(cheats_config); +}); diff --git a/testdata/repros/Issue6634.t.sol b/testdata/repros/Issue6634.t.sol new file mode 100644 index 0000000000000..3b1acb9c18aa9 --- /dev/null +++ b/testdata/repros/Issue6634.t.sol @@ -0,0 +1,95 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +pragma solidity 0.8.18; + +import "ds-test/test.sol"; +import "../cheats/Vm.sol"; +import "../logs/console.sol"; + +contract Box { + uint256 public number; + + constructor(uint256 _number) { + number = _number; + } +} + +// https://github.com/foundry-rs/foundry/issues/6634 +contract Issue6634Test is DSTest { + Vm constant vm = Vm(HEVM_ADDRESS); + + function test_Create2FactoryCallRecordedInStandardTest() public { + address CREATE2_DEPLOYER = 0x4e59b44847b379578588920cA78FbF26c0B4956C; + + vm.startStateDiffRecording(); + Box a = new Box{salt: 0}(1); + + Vm.AccountAccess[] memory called = vm.stopAndReturnStateDiff(); + address addr = vm.computeCreate2Address( + 0, keccak256(abi.encodePacked(type(Box).creationCode, uint256(1))), address(CREATE2_DEPLOYER) + ); + assertEq(addr, called[1].account, "state diff contract address is not correct"); + assertEq(address(a), called[1].account, "returned address is not correct"); + + assertEq(called.length, 2, "incorrect length"); + assertEq(uint256(called[0].kind), uint256(Vm.AccountAccessKind.Call), "first AccountAccess is incorrect kind"); + assertEq(called[0].account, CREATE2_DEPLOYER, "first AccountAccess account is incorrect"); + assertEq(called[0].accessor, address(this), "first AccountAccess accessor is incorrect"); + assertEq( + uint256(called[1].kind), uint256(Vm.AccountAccessKind.Create), "second AccountAccess is incorrect kind" + ); + assertEq(called[1].accessor, CREATE2_DEPLOYER, "second AccountAccess accessor is incorrect"); + assertEq(called[1].account, address(a), "second AccountAccess account is incorrect"); + } + + function test_Create2FactoryCallRecordedWhenPranking() public { + address CREATE2_DEPLOYER = 0x4e59b44847b379578588920cA78FbF26c0B4956C; + address accessor = address(0x5555); + + vm.startPrank(accessor); + vm.startStateDiffRecording(); + Box a = new Box{salt: 0}(1); + + Vm.AccountAccess[] memory called = vm.stopAndReturnStateDiff(); + address addr = vm.computeCreate2Address( + 0, keccak256(abi.encodePacked(type(Box).creationCode, uint256(1))), address(CREATE2_DEPLOYER) + ); + assertEq(addr, called[1].account, "state diff contract address is not correct"); + assertEq(address(a), called[1].account, "returned address is not correct"); + + assertEq(called.length, 2, "incorrect length"); + assertEq(uint256(called[0].kind), uint256(Vm.AccountAccessKind.Call), "first AccountAccess is incorrect kind"); + assertEq(called[0].account, CREATE2_DEPLOYER, "first AccountAccess accout is incorrect"); + assertEq(called[0].accessor, accessor, "first AccountAccess accessor is incorrect"); + assertEq( + uint256(called[1].kind), uint256(Vm.AccountAccessKind.Create), "second AccountAccess is incorrect kind" + ); + assertEq(called[1].accessor, CREATE2_DEPLOYER, "second AccountAccess accessor is incorrect"); + assertEq(called[1].account, address(a), "second AccountAccess account is incorrect"); + } + + function test_Create2FactoryCallRecordedWhenBroadcasting() public { + address CREATE2_DEPLOYER = 0x4e59b44847b379578588920cA78FbF26c0B4956C; + address accessor = address(0x5555); + + vm.startBroadcast(accessor); + vm.startStateDiffRecording(); + Box a = new Box{salt: 0}(1); + + Vm.AccountAccess[] memory called = vm.stopAndReturnStateDiff(); + address addr = vm.computeCreate2Address( + 0, keccak256(abi.encodePacked(type(Box).creationCode, uint256(1))), address(CREATE2_DEPLOYER) + ); + assertEq(addr, called[1].account, "state diff contract address is not correct"); + assertEq(address(a), called[1].account, "returned address is not correct"); + + assertEq(called.length, 2, "incorrect length"); + assertEq(uint256(called[0].kind), uint256(Vm.AccountAccessKind.Call), "first AccountAccess is incorrect kind"); + assertEq(called[0].account, CREATE2_DEPLOYER, "first AccountAccess accout is incorrect"); + assertEq(called[0].accessor, accessor, "first AccountAccess accessor is incorrect"); + assertEq( + uint256(called[1].kind), uint256(Vm.AccountAccessKind.Create), "second AccountAccess is incorrect kind" + ); + assertEq(called[1].accessor, CREATE2_DEPLOYER, "second AccountAccess accessor is incorrect"); + assertEq(called[1].account, address(a), "second AccountAccess account is incorrect"); + } +}