Skip to content

Commit

Permalink
fix(cheatcodes): Properly record call to create2 factory in state diff
Browse files Browse the repository at this point in the history
  • Loading branch information
RPate97 committed Feb 22, 2024
1 parent 9fe9a3f commit e15d8c6
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 4 deletions.
53 changes: 49 additions & 4 deletions crates/cheatcodes/src/inspector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1237,7 +1237,13 @@ impl<DB: DatabaseExt> Inspector<DB> for Cheatcodes {

// Apply the Create2 deployer
if self.broadcast.is_some() || self.config.always_use_create_2_factory {
match apply_create2_deployer(data, call, self.prank.as_ref(), self.broadcast.as_ref()) {
match apply_create2_deployer(
data,
call,
self.prank.as_ref(),
self.broadcast.as_ref(),
&mut self.recorded_account_diffs_stack,
) {
Ok(_) => {}
Err(err) => return (InstructionResult::Revert, None, gas, Error::encode(err)),
};
Expand All @@ -1249,6 +1255,15 @@ impl<DB: DatabaseExt> Inspector<DB> for Cheatcodes {
let address = self.allow_cheatcodes_on_create(data, call);
// If `recordAccountAccesses` has been called, record the create
if let Some(recorded_account_diffs_stack) = &mut self.recorded_account_diffs_stack {
// If the create scheme is create2, and the caller is the DEFAULT_CREATE2_DEPLOYER then
// we must add 1 to the depth to account for the call to the create2 factory.
let mut depth = data.journaled_state.depth();
if let CreateScheme::Create2 { salt: _ } = call.scheme {
if call.caller == DEFAULT_CREATE2_DEPLOYER {
depth += 1;
}
}

// Record the create context as an account access and create a new vector to record all
// subsequent account accesses
recorded_account_diffs_stack.push(vec![AccountAccess {
Expand All @@ -1269,7 +1284,7 @@ impl<DB: DatabaseExt> Inspector<DB> for Cheatcodes {
deployedCode: vec![], // updated on create_end
storageAccesses: vec![], // updated on create_end
},
depth: data.journaled_state.depth(),
depth,
}]);
}

Expand Down Expand Up @@ -1432,17 +1447,47 @@ fn apply_create2_deployer<DB: DatabaseExt>(
call: &mut CreateInputs,
prank: Option<&Prank>,
broadcast: Option<&Broadcast>,
diffs_stack: &mut Option<Vec<Vec<AccountAccess>>>,
) -> Result<(), DB::Error> {
if let CreateScheme::Create2 { salt: _ } = call.scheme {
if let CreateScheme::Create2 { salt } = call.scheme {
let mut base_depth = 1;
if let Some(prank) = &prank {
base_depth = prank.depth;
} else if let Some(broadcast) = &broadcast {
base_depth = broadcast.depth;
}

// If the create scheme is Create2 and the depth equals the broadcast/prank/default
// depth, then use the default create2 factory as the deployer
if data.journaled_state.depth() == base_depth {
if let Some(recorded_account_diffs_stack) = diffs_stack {
// If broadcasting, or the create2 factory option is enabled, then record
// the call to the create2 factory. We do not explicitly check this here
// because those cases are checked when the `apply_create2_deployer` function
// is called.
let calldata = [&salt.to_be_bytes::<32>()[..], &call.init_code[..]].concat();
recorded_account_diffs_stack.push(vec![AccountAccess {
access: crate::Vm::AccountAccess {
chainInfo: crate::Vm::ChainInfo {
forkId: data.db.active_fork_id().unwrap_or_default(),
chainId: U256::from(data.env.cfg.chain_id),
},
accessor: call.caller,
account: DEFAULT_CREATE2_DEPLOYER,
kind: crate::Vm::AccountAccessKind::Call,
initialized: true,
oldBalance: U256::ZERO, // updated on create_end
newBalance: U256::ZERO, // updated on create_end
value: call.value,
data: calldata,
reverted: false,
deployedCode: vec![], // updated on create_end
storageAccesses: vec![], // updated on create_end
},
depth: data.journaled_state.depth(),
}])
}

// Sanity checks for our CREATE2 deployer
let info =
&data.journaled_state.load_account(DEFAULT_CREATE2_DEPLOYER, data.db)?.0.info;
Expand Down Expand Up @@ -1475,9 +1520,9 @@ fn process_broadcast_create<DB: DatabaseExt>(
data: &mut EVMData<'_, DB>,
call: &mut CreateInputs,
) -> (Bytes, Option<Address>, u64) {
call.caller = broadcast_sender;
match call.scheme {
CreateScheme::Create => {
call.caller = broadcast_sender;
(bytecode, None, data.journaled_state.account(broadcast_sender).info.nonce)
}
CreateScheme::Create2 { salt } => {
Expand Down
7 changes: 7 additions & 0 deletions crates/forge/tests/it/repros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -310,3 +310,10 @@ test_repro!(5529; |config| {
cheats_config.always_use_create_2_factory = true;
config.runner.cheats_config = std::sync::Arc::new(cheats_config);
});

// https://github.com/foundry-rs/foundry/issues/6634
test_repro!(6634; |config| {
let mut cheats_config = config.runner.cheats_config.as_ref().clone();
cheats_config.always_use_create_2_factory = true;
config.runner.cheats_config = std::sync::Arc::new(cheats_config);
});
95 changes: 95 additions & 0 deletions testdata/repros/Issue6634.t.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
// SPDX-License-Identifier: MIT OR Apache-2.0
pragma solidity 0.8.18;

import "ds-test/test.sol";
import "../cheats/Vm.sol";
import "../logs/console.sol";

contract Box {
uint256 public number;

constructor(uint256 _number) {
number = _number;
}
}

// https://github.com/foundry-rs/foundry/issues/6634
contract Issue6634Test is DSTest {
Vm constant vm = Vm(HEVM_ADDRESS);

function test_Create2FactoryCallRecordedInStandardTest() public {
address CREATE2_DEPLOYER = 0x4e59b44847b379578588920cA78FbF26c0B4956C;

vm.startStateDiffRecording();
Box a = new Box{salt: 0}(1);

Vm.AccountAccess[] memory called = vm.stopAndReturnStateDiff();
address addr = vm.computeCreate2Address(
0, keccak256(abi.encodePacked(type(Box).creationCode, uint256(1))), address(CREATE2_DEPLOYER)
);
assertEq(addr, called[1].account, "state diff contract address is not correct");
assertEq(address(a), called[1].account, "returned address is not correct");

assertEq(called.length, 2, "incorrect length");
assertEq(uint256(called[0].kind), uint256(Vm.AccountAccessKind.Call), "first AccountAccess is incorrect kind");
assertEq(called[0].account, CREATE2_DEPLOYER, "first AccountAccess account is incorrect");
assertEq(called[0].accessor, address(this), "first AccountAccess accessor is incorrect");
assertEq(
uint256(called[1].kind), uint256(Vm.AccountAccessKind.Create), "second AccountAccess is incorrect kind"
);
assertEq(called[1].accessor, CREATE2_DEPLOYER, "second AccountAccess accessor is incorrect");
assertEq(called[1].account, address(a), "second AccountAccess account is incorrect");
}

function test_Create2FactoryCallRecordedWhenPranking() public {
address CREATE2_DEPLOYER = 0x4e59b44847b379578588920cA78FbF26c0B4956C;
address accessor = address(0x5555);

vm.startPrank(accessor);
vm.startStateDiffRecording();
Box a = new Box{salt: 0}(1);

Vm.AccountAccess[] memory called = vm.stopAndReturnStateDiff();
address addr = vm.computeCreate2Address(
0, keccak256(abi.encodePacked(type(Box).creationCode, uint256(1))), address(CREATE2_DEPLOYER)
);
assertEq(addr, called[1].account, "state diff contract address is not correct");
assertEq(address(a), called[1].account, "returned address is not correct");

assertEq(called.length, 2, "incorrect length");
assertEq(uint256(called[0].kind), uint256(Vm.AccountAccessKind.Call), "first AccountAccess is incorrect kind");
assertEq(called[0].account, CREATE2_DEPLOYER, "first AccountAccess accout is incorrect");
assertEq(called[0].accessor, accessor, "first AccountAccess accessor is incorrect");
assertEq(
uint256(called[1].kind), uint256(Vm.AccountAccessKind.Create), "second AccountAccess is incorrect kind"
);
assertEq(called[1].accessor, CREATE2_DEPLOYER, "second AccountAccess accessor is incorrect");
assertEq(called[1].account, address(a), "second AccountAccess account is incorrect");
}

function test_Create2FactoryCallRecordedWhenBroadcasting() public {
address CREATE2_DEPLOYER = 0x4e59b44847b379578588920cA78FbF26c0B4956C;
address accessor = address(0x5555);

vm.startBroadcast(accessor);
vm.startStateDiffRecording();
Box a = new Box{salt: 0}(1);

Vm.AccountAccess[] memory called = vm.stopAndReturnStateDiff();
address addr = vm.computeCreate2Address(
0, keccak256(abi.encodePacked(type(Box).creationCode, uint256(1))), address(CREATE2_DEPLOYER)
);
assertEq(addr, called[1].account, "state diff contract address is not correct");
assertEq(address(a), called[1].account, "returned address is not correct");

assertEq(called.length, 2, "incorrect length");
assertEq(uint256(called[0].kind), uint256(Vm.AccountAccessKind.Call), "first AccountAccess is incorrect kind");
assertEq(called[0].account, CREATE2_DEPLOYER, "first AccountAccess accout is incorrect");
assertEq(called[0].accessor, accessor, "first AccountAccess accessor is incorrect");
assertEq(
uint256(called[1].kind), uint256(Vm.AccountAccessKind.Create), "second AccountAccess is incorrect kind"
);
assertEq(called[1].accessor, CREATE2_DEPLOYER, "second AccountAccess accessor is incorrect");
assertEq(called[1].account, address(a), "second AccountAccess account is incorrect");
}
}

0 comments on commit e15d8c6

Please sign in to comment.