Skip to content

Commit

Permalink
feat(blockifier): support reverts in native (#2271)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yoni-Starkware authored Nov 25, 2024
1 parent 145a96a commit 6b63197
Show file tree
Hide file tree
Showing 8 changed files with 68 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ pub fn execute_entry_point_call(
Some(builtin_costs),
&mut syscall_handler,
);
syscall_handler.finalize();

let call_result = execution_result.map_err(EntryPointExecutionError::NativeUnexpectedError)?;

Expand Down
26 changes: 9 additions & 17 deletions crates/blockifier/src/execution/native/syscall_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,19 +78,12 @@ impl<'state> NativeSyscallHandler<'state> {
entry_point: CallEntryPoint,
remaining_gas: &mut u64,
) -> SyscallResult<Retdata> {
let call_info = entry_point
.execute(self.base.state, self.base.context, remaining_gas)
.map_err(|e| self.handle_error(remaining_gas, e.into()))?;
let retdata = call_info.execution.retdata.clone();

if call_info.execution.failed {
let error = SyscallExecutionError::SyscallError { error_data: retdata.0 };
return Err(self.handle_error(remaining_gas, error));
}

self.base.inner_calls.push(call_info);
let raw_retdata = self
.base
.execute_inner_call(entry_point, remaining_gas)
.map_err(|e| self.handle_error(remaining_gas, e))?;

Ok(retdata)
Ok(Retdata(raw_retdata))
}

pub fn gas_costs(&self) -> &GasCosts {
Expand Down Expand Up @@ -228,6 +221,9 @@ impl<'state> NativeSyscallHandler<'state> {
}),
}
}
pub fn finalize(&mut self) {
self.base.finalize();
}
}

impl<'state> StarknetSyscallHandler for &mut NativeSyscallHandler<'state> {
Expand Down Expand Up @@ -471,11 +467,7 @@ impl<'state> StarknetSyscallHandler for &mut NativeSyscallHandler<'state> {

let key = StorageKey::try_from(address)
.map_err(|e| self.handle_error(remaining_gas, e.into()))?;
self.base.accessed_keys.insert(key);

let write_result =
self.base.state.set_storage_at(self.base.call.storage_address, key, value);
write_result.map_err(|e| self.handle_error(remaining_gas, e.into()))?;
self.base.storage_write(key, value).map_err(|e| self.handle_error(remaining_gas, e))?;

Ok(())
}
Expand Down
31 changes: 2 additions & 29 deletions crates/blockifier/src/execution/syscalls/hint_processor.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::any::Any;
use std::collections::{hash_map, HashMap};
use std::collections::HashMap;

use cairo_lang_casm::hints::{Hint, StarknetHint};
use cairo_lang_runner::casm_run::execute_core_hint_base;
Expand Down Expand Up @@ -65,7 +65,6 @@ use crate::execution::syscalls::{
storage_read,
storage_write,
StorageReadResponse,
StorageWriteResponse,
SyscallRequest,
SyscallRequestWrapper,
SyscallResponse,
Expand Down Expand Up @@ -643,34 +642,8 @@ impl<'a> SyscallHintProcessor<'a> {
Ok(StorageReadResponse { value })
}

pub fn set_contract_storage_at(
&mut self,
key: StorageKey,
value: Felt,
) -> SyscallResult<StorageWriteResponse> {
let contract_address = self.storage_address();

match self.base.original_values.entry(key) {
hash_map::Entry::Vacant(entry) => {
entry.insert(self.base.state.get_storage_at(contract_address, key)?);
}
hash_map::Entry::Occupied(_) => {}
}

self.base.accessed_keys.insert(key);
self.base.state.set_storage_at(contract_address, key, value)?;

Ok(StorageWriteResponse {})
}

pub fn finalize(&mut self) {
self.base
.context
.revert_infos
.0
.last_mut()
.expect("Missing contract revert info.")
.original_values = std::mem::take(&mut self.base.original_values);
self.base.finalize();
}
}

Expand Down
3 changes: 2 additions & 1 deletion crates/blockifier/src/execution/syscalls/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -612,7 +612,8 @@ pub fn storage_write(
syscall_handler: &mut SyscallHintProcessor<'_>,
_remaining_gas: &mut u64,
) -> SyscallResult<StorageWriteResponse> {
syscall_handler.set_contract_storage_at(request.address, request.value)
syscall_handler.base.storage_write(request.address, request.value)?;
Ok(StorageWriteResponse {})
}

// Keccak syscall.
Expand Down
27 changes: 26 additions & 1 deletion crates/blockifier/src/execution/syscalls/syscall_base.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::collections::{HashMap, HashSet};
use std::collections::{hash_map, HashMap, HashSet};
use std::convert::From;

use starknet_api::core::{ClassHash, ContractAddress};
Expand Down Expand Up @@ -99,6 +99,22 @@ impl<'state> SyscallHandlerBase<'state> {
Ok(self.state.get_storage_at(block_hash_contract_address, key)?)
}

pub fn storage_write(&mut self, key: StorageKey, value: Felt) -> SyscallResult<()> {
let contract_address = self.call.storage_address;

match self.original_values.entry(key) {
hash_map::Entry::Vacant(entry) => {
entry.insert(self.state.get_storage_at(contract_address, key)?);
}
hash_map::Entry::Occupied(_) => {}
}

self.accessed_keys.insert(key);
self.state.set_storage_at(contract_address, key, value)?;

Ok(())
}

pub fn execute_inner_call(
&mut self,
call: CallEntryPoint,
Expand Down Expand Up @@ -138,4 +154,13 @@ impl<'state> SyscallHandlerBase<'state> {

Ok(raw_retdata)
}

pub fn finalize(&mut self) {
self.context
.revert_infos
.0
.last_mut()
.expect("Missing contract revert info.")
.original_values = std::mem::take(&mut self.original_values);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ use crate::test_utils::{
BALANCE,
};

// TODO: Add test for native once reverts are supported.
#[test]
fn test_call_contract_that_panics() {
let test_contract = FeatureContract::TestContract(CairoVersion::Cairo1);
#[cfg_attr(feature = "cairo_native", test_case(CairoVersion::Native; "Native"))]
#[test_case(CairoVersion::Cairo1;"VM")]
fn test_call_contract_that_panics(cairo_version: CairoVersion) {
let test_contract = FeatureContract::TestContract(cairo_version);
let empty_contract = FeatureContract::Empty(CairoVersion::Cairo1);
let chain_info = &ChainInfo::create_for_testing();
let mut state = test_state(chain_info, BALANCE, &[(test_contract, 1), (empty_contract, 0)]);
Expand Down Expand Up @@ -61,6 +61,11 @@ fn test_call_contract_that_panics() {
);
assert!(inner_call.execution.events.is_empty());
assert!(inner_call.execution.l2_to_l1_messages.is_empty());

// Check that the tracked resource is SierraGas to make sure that Native is running.
for call in res.iter() {
assert_eq!(call.tracked_resource, TrackedResource::SierraGas);
}
}

#[cfg_attr(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,20 +83,16 @@ fn test_library_call_assert_fails(cairo_version: CairoVersion) {
..trivial_external_entry_point_new(test_contract)
};
let call_info = entry_point_call.execute_directly(&mut state).unwrap();
let expected_err_retdata = match test_contract.cairo_version() {
CairoVersion::Cairo0 | CairoVersion::Cairo1 => {
// 'x != y', 'ENTRYPOINT_FAILED'.
vec![felt!("0x7820213d2079"), felt!("0x454e545259504f494e545f4641494c4544")]
}
#[cfg(feature = "cairo_native")]
// 'x != y'.
CairoVersion::Native => vec![felt!("0x7820213d2079")],
};

assert_eq!(
call_info.execution,
CallExecution {
retdata: Retdata(expected_err_retdata),
retdata: Retdata(vec![
// 'x != y'.
felt!("0x7820213d2079"),
// 'ENTRYPOINT_FAILED'.
felt!("0x454e545259504f494e545f4641494c4544")
]),
gas_consumed: 150980,
failed: true,
..Default::default()
Expand Down
14 changes: 13 additions & 1 deletion crates/blockifier/src/transaction/account_transactions_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ use starknet_types_core::felt::Felt;
use crate::check_tx_execution_error_for_invalid_scenario;
use crate::context::{BlockContext, TransactionContext};
use crate::execution::call_info::CallInfo;
use crate::execution::contract_class::TrackedResource;
use crate::execution::entry_point::EntryPointExecutionContext;
use crate::execution::syscalls::SyscallSelector;
use crate::fee::fee_utils::{get_fee_by_gas_vector, get_sequencer_balance_keys};
Expand Down Expand Up @@ -1761,15 +1762,19 @@ fn test_revert_in_execute(
}

#[rstest]
#[cfg_attr(feature = "cairo_native", case::native(CairoVersion::Native))]
#[case::vm(CairoVersion::Cairo1)]
fn test_call_contract_that_panics(
#[case] cairo_version: CairoVersion,
mut block_context: BlockContext,
default_all_resource_bounds: ValidResourceBounds,
#[values(true, false)] enable_reverts: bool,
#[values("test_revert_helper", "bad_selector")] inner_selector: &str,
) {
// Override enable reverts.
block_context.versioned_constants.enable_reverts = enable_reverts;
let test_contract = FeatureContract::TestContract(CairoVersion::Cairo1);
let test_contract = FeatureContract::TestContract(cairo_version);
// TODO(Yoni): use `class_version` here once the feature contract fully supports Native.
let account = FeatureContract::AccountWithoutValidations(CairoVersion::Cairo1);
let chain_info = &block_context.chain_info;
let state = &mut test_state(chain_info, BALANCE, &[(test_contract, 1), (account, 1)]);
Expand Down Expand Up @@ -1809,4 +1814,11 @@ fn test_call_contract_that_panics(
// If reverts are enabled, `test_call_contract_revert` should catch it and ignore it.
// Otherwise, the transaction should revert.
assert_eq!(tx_execution_info.is_reverted(), !enable_reverts);

if enable_reverts {
// Check that the tracked resource is SierraGas to make sure that Native is running.
for call in tx_execution_info.execute_call_info.unwrap().iter() {
assert_eq!(call.tracked_resource, TrackedResource::SierraGas);
}
}
}

0 comments on commit 6b63197

Please sign in to comment.