Skip to content
This repository has been archived by the owner on Aug 21, 2024. It is now read-only.

Commit

Permalink
refacto: cleanup state traits, CachedState interior mutability
Browse files Browse the repository at this point in the history
  • Loading branch information
tdelabro committed Jan 3, 2024
1 parent 720139f commit 67e3575
Show file tree
Hide file tree
Showing 10 changed files with 394 additions and 225 deletions.
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,3 @@ target
*.DS_Store

tmp_venv/*

516 changes: 343 additions & 173 deletions crates/blockifier/src/state/cached_state.rs

Large diffs are not rendered by default.

33 changes: 20 additions & 13 deletions crates/blockifier/src/state/cached_state_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,21 @@ fn set_initial_state_values(
class_hash_initial_values: HashMap<ContractAddress, ClassHash>,
storage_initial_values: HashMap<StorageEntry, StarkFelt>,
) {
assert!(state.cache == StateCache::default(), "Cache already initialized.");

state.class_hash_to_class = class_hash_to_class;
state.cache.class_hash_initial_values.extend(class_hash_initial_values);
state.cache.nonce_initial_values.extend(nonce_initial_values);
state.cache.storage_initial_values.extend(storage_initial_values);
let cache = state.cache.get_mut();
assert_eq!(cache.class_hash_to_class, Default::default(), "Cache already initialized");
assert_eq!(cache.class_hash_initial_values, Default::default(), "Cache already initialized");
assert_eq!(cache.nonce_initial_values, Default::default(), "Cache already initialized");
assert_eq!(cache.storage_initial_values, Default::default(), "Cache already initialized");

cache.class_hash_to_class = class_hash_to_class;
cache.class_hash_initial_values.extend(class_hash_initial_values);
cache.nonce_initial_values.extend(nonce_initial_values);
cache.storage_initial_values.extend(storage_initial_values);
}

#[test]
fn get_uninitialized_storage_value() {
let mut state: CachedState<DictStateReader> = CachedState::default();
let state: CachedState<DictStateReader> = CachedState::default();
let contract_address = contract_address!("0x1");
let key = StorageKey(patricia_key!("0x10"));

Expand Down Expand Up @@ -95,7 +99,7 @@ fn cast_between_storage_mapping_types() {

#[test]
fn get_uninitialized_value() {
let mut state: CachedState<DictStateReader> = CachedState::default();
let state: CachedState<DictStateReader> = CachedState::default();
let contract_address = contract_address!("0x1");

assert_eq!(state.get_nonce_at(contract_address).unwrap(), Nonce::default());
Expand Down Expand Up @@ -137,7 +141,7 @@ fn get_and_increment_nonce() {
fn get_contract_class() {
// Positive flow.
let existing_class_hash = class_hash!(TEST_CLASS_HASH);
let mut state = deprecated_create_test_state();
let state = deprecated_create_test_state();
assert_eq!(
state.get_compiled_contract_class(existing_class_hash).unwrap(),
get_test_contract_class()
Expand All @@ -153,7 +157,7 @@ fn get_contract_class() {

#[test]
fn get_uninitialized_class_hash_value() {
let mut state: CachedState<DictStateReader> = CachedState::default();
let state: CachedState<DictStateReader> = CachedState::default();
let valid_contract_address = contract_address!("0x1");

assert_eq!(state.get_class_hash_at(valid_contract_address).unwrap(), ClassHash::default());
Expand Down Expand Up @@ -252,7 +256,7 @@ fn cached_state_state_diff_conversion() {
address_to_nonce: IndexMap::from_iter([(contract_address2, Nonce(StarkFelt::from(1_u64)))]),
};

assert_eq!(expected_state_diff, state.to_state_diff());
assert_eq!(expected_state_diff, state.cached_state_diff());
}

fn create_state_changes_for_test<S: StateReader>(
Expand Down Expand Up @@ -383,14 +387,17 @@ fn global_contract_cache_is_used() {
let mut state = CachedState::new(DictStateReader::default(), global_cache.clone());

// Assert local cache is initialized empty even if global cache is not empty.
assert!(state.class_hash_to_class.get(&class_hash).is_none());
assert!(state.cache.get_mut().class_hash_to_class.get(&class_hash).is_none());

// Check state uses the global cache.
assert_eq!(state.get_compiled_contract_class(class_hash).unwrap(), contract_class);
assert_eq!(global_cache.lock().unwrap().cache_hits().unwrap(), 1);
assert_eq!(global_cache.lock().unwrap().cache_size(), 1);
// Verify local cache is also updated.
assert_eq!(state.class_hash_to_class.get(&class_hash).unwrap(), &contract_class);
assert_eq!(
state.cache.get_mut().class_hash_to_class.get(&class_hash).unwrap(),
&contract_class
);

// Idempotency: getting the same class again uses the local cache.
assert_eq!(state.get_compiled_contract_class(class_hash).unwrap(), contract_class);
Expand Down
15 changes: 6 additions & 9 deletions crates/blockifier/src/state/state_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use starknet_api::state::StorageKey;
use crate::abi::abi_utils::get_fee_token_var_address;
use crate::abi::sierra_types::next_storage_key;
use crate::execution::contract_class::ContractClass;
use crate::state::cached_state::CommitmentStateDiff;
use crate::state::errors::StateError;

pub type StateResult<T> = Result<T, StateError>;
Expand All @@ -25,32 +24,32 @@ pub trait StateReader {
/// its address).
/// Default: 0 for an uninitialized contract address.
fn get_storage_at(
&mut self,
&self,
contract_address: ContractAddress,
key: StorageKey,
) -> StateResult<StarkFelt>;

/// Returns the nonce of the given contract instance.
/// Default: 0 for an uninitialized contract address.
fn get_nonce_at(&mut self, contract_address: ContractAddress) -> StateResult<Nonce>;
fn get_nonce_at(&self, contract_address: ContractAddress) -> StateResult<Nonce>;

/// Returns the class hash of the contract class at the given contract instance.
/// Default: 0 (uninitialized class hash) for an uninitialized contract address.
fn get_class_hash_at(&mut self, contract_address: ContractAddress) -> StateResult<ClassHash>;
fn get_class_hash_at(&self, contract_address: ContractAddress) -> StateResult<ClassHash>;

/// Returns the contract class of the given class hash.
fn get_compiled_contract_class(&mut self, class_hash: ClassHash) -> StateResult<ContractClass>;
fn get_compiled_contract_class(&self, class_hash: ClassHash) -> StateResult<ContractClass>;

/// Returns the compiled class hash of the given class hash.
fn get_compiled_class_hash(&mut self, class_hash: ClassHash) -> StateResult<CompiledClassHash>;
fn get_compiled_class_hash(&self, class_hash: ClassHash) -> StateResult<CompiledClassHash>;

/// Returns the storage value representing the balance (in fee token) at the given address.
// TODO(Dori, 1/7/2023): When a standard representation for large integers is set, change the
// return type to that.
// TODO(Dori, 1/9/2023): NEW_TOKEN_SUPPORT Determine fee token address based on tx version,
// once v3 is introduced.
fn get_fee_token_balance(
&mut self,
&self,
contract_address: ContractAddress,
fee_token_address: ContractAddress,
) -> Result<(StarkFelt, StarkFelt), StateError> {
Expand Down Expand Up @@ -101,6 +100,4 @@ pub trait State: StateReader {
class_hash: ClassHash,
compiled_class_hash: CompiledClassHash,
) -> StateResult<()>;

fn to_state_diff(&mut self) -> CommitmentStateDiff;
}
2 changes: 1 addition & 1 deletion crates/blockifier/src/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use starknet_api::transaction::{
use starknet_api::{calldata, contract_address, patricia_key, stark_felt};

use crate::abi::abi_utils::{get_fee_token_var_address, selector_from_name};
use crate::abi::constants::{self};
use crate::abi::constants;
use crate::execution::contract_class::{ContractClass, ContractClassV0};
use crate::execution::entry_point::{CallEntryPoint, CallType};
use crate::execution::execution_utils::felt_to_stark_felt;
Expand Down
10 changes: 5 additions & 5 deletions crates/blockifier/src/test_utils/dict_state_reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ pub struct DictStateReader {

impl StateReader for DictStateReader {
fn get_storage_at(
&mut self,
&self,
contract_address: ContractAddress,
key: StorageKey,
) -> StateResult<StarkFelt> {
Expand All @@ -30,27 +30,27 @@ impl StateReader for DictStateReader {
Ok(value)
}

fn get_nonce_at(&mut self, contract_address: ContractAddress) -> StateResult<Nonce> {
fn get_nonce_at(&self, contract_address: ContractAddress) -> StateResult<Nonce> {
let nonce = self.address_to_nonce.get(&contract_address).copied().unwrap_or_default();
Ok(nonce)
}

fn get_compiled_contract_class(&mut self, class_hash: ClassHash) -> StateResult<ContractClass> {
fn get_compiled_contract_class(&self, class_hash: ClassHash) -> StateResult<ContractClass> {
let contract_class = self.class_hash_to_class.get(&class_hash).cloned();
match contract_class {
Some(contract_class) => Ok(contract_class),
_ => Err(StateError::UndeclaredClassHash(class_hash)),
}
}

fn get_class_hash_at(&mut self, contract_address: ContractAddress) -> StateResult<ClassHash> {
fn get_class_hash_at(&self, contract_address: ContractAddress) -> StateResult<ClassHash> {
let class_hash =
self.address_to_class_hash.get(&contract_address).copied().unwrap_or_default();
Ok(class_hash)
}

fn get_compiled_class_hash(
&mut self,
&self,
class_hash: ClassHash,
) -> StateResult<starknet_api::core::CompiledClassHash> {
let compiled_class_hash =
Expand Down
17 changes: 8 additions & 9 deletions crates/blockifier/src/transaction/account_transactions_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -234,12 +234,10 @@ fn test_infinite_recursion(
if success {
assert!(tx_execution_info.revert_error.is_none());
} else {
assert!(
tx_execution_info
.revert_error
.unwrap()
.contains("RunResources has no remaining steps.")
);
assert!(tx_execution_info
.revert_error
.unwrap()
.contains("RunResources has no remaining steps."));
}
}

Expand Down Expand Up @@ -938,9 +936,10 @@ fn test_insufficient_max_fee_reverts(
.unwrap();
assert!(tx_execution_info3.is_reverted());
assert!(tx_execution_info3.actual_fee == actual_fee_depth1);
assert!(
tx_execution_info3.revert_error.unwrap().contains("RunResources has no remaining steps.")
);
assert!(tx_execution_info3
.revert_error
.unwrap()
.contains("RunResources has no remaining steps."));
}

#[rstest]
Expand Down
13 changes: 5 additions & 8 deletions crates/native_blockifier/src/state_readers/papyrus_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ impl PapyrusReader {
// Currently unused - will soon replace the same `impl` for `PapyrusStateReader`.
impl StateReader for PapyrusReader {
fn get_storage_at(
&mut self,
&self,
contract_address: ContractAddress,
key: StorageKey,
) -> StateResult<StarkFelt> {
Expand All @@ -47,7 +47,7 @@ impl StateReader for PapyrusReader {
.map_err(|error| StateError::StateReadError(error.to_string()))
}

fn get_nonce_at(&mut self, contract_address: ContractAddress) -> StateResult<Nonce> {
fn get_nonce_at(&self, contract_address: ContractAddress) -> StateResult<Nonce> {
let state_number = StateNumber(self.latest_block);
match self
.reader()?
Expand All @@ -60,7 +60,7 @@ impl StateReader for PapyrusReader {
}
}

fn get_class_hash_at(&mut self, contract_address: ContractAddress) -> StateResult<ClassHash> {
fn get_class_hash_at(&self, contract_address: ContractAddress) -> StateResult<ClassHash> {
let state_number = StateNumber(self.latest_block);
match self
.reader()?
Expand All @@ -75,7 +75,7 @@ impl StateReader for PapyrusReader {

/// Returns a V1 contract if found, or a V0 contract if a V1 contract is not
/// found, or an `Error` otherwise.
fn get_compiled_contract_class(&mut self, class_hash: ClassHash) -> StateResult<ContractClass> {
fn get_compiled_contract_class(&self, class_hash: ClassHash) -> StateResult<ContractClass> {
let state_number = StateNumber(self.latest_block);
let class_declaration_block_number = self
.reader()?
Expand Down Expand Up @@ -112,10 +112,7 @@ impl StateReader for PapyrusReader {
}
}

fn get_compiled_class_hash(
&mut self,
_class_hash: ClassHash,
) -> StateResult<CompiledClassHash> {
fn get_compiled_class_hash(&self, _class_hash: ClassHash) -> StateResult<CompiledClassHash> {
todo!()
}
}
10 changes: 5 additions & 5 deletions crates/native_blockifier/src/state_readers/py_state_reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ impl PyStateReader {

impl StateReader for PyStateReader {
fn get_storage_at(
&mut self,
&self,
contract_address: ContractAddress,
key: StorageKey,
) -> StateResult<StarkFelt> {
Expand All @@ -44,7 +44,7 @@ impl StateReader for PyStateReader {
.map_err(|err| StateError::StateReadError(err.to_string()))
}

fn get_nonce_at(&mut self, contract_address: ContractAddress) -> StateResult<Nonce> {
fn get_nonce_at(&self, contract_address: ContractAddress) -> StateResult<Nonce> {
Python::with_gil(|py| -> PyResult<PyFelt> {
let args = (ON_CHAIN_STORAGE_DOMAIN, PyFelt::from(contract_address));
self.state_reader_proxy.as_ref(py).call_method1("get_nonce_at", args)?.extract()
Expand All @@ -53,7 +53,7 @@ impl StateReader for PyStateReader {
.map_err(|err| StateError::StateReadError(err.to_string()))
}

fn get_class_hash_at(&mut self, contract_address: ContractAddress) -> StateResult<ClassHash> {
fn get_class_hash_at(&self, contract_address: ContractAddress) -> StateResult<ClassHash> {
Python::with_gil(|py| -> PyResult<PyFelt> {
let args = (PyFelt::from(contract_address),);
self.state_reader_proxy.as_ref(py).call_method1("get_class_hash_at", args)?.extract()
Expand All @@ -62,7 +62,7 @@ impl StateReader for PyStateReader {
.map_err(|err| StateError::StateReadError(err.to_string()))
}

fn get_compiled_contract_class(&mut self, class_hash: ClassHash) -> StateResult<ContractClass> {
fn get_compiled_contract_class(&self, class_hash: ClassHash) -> StateResult<ContractClass> {
Python::with_gil(|py| -> Result<ContractClass, PyErr> {
let args = (PyFelt::from(class_hash),);
let py_raw_compiled_class: PyRawCompiledClass = self
Expand All @@ -82,7 +82,7 @@ impl StateReader for PyStateReader {
})
}

fn get_compiled_class_hash(&mut self, class_hash: ClassHash) -> StateResult<CompiledClassHash> {
fn get_compiled_class_hash(&self, class_hash: ClassHash) -> StateResult<CompiledClassHash> {
Python::with_gil(|py| -> PyResult<PyFelt> {
let args = (PyFelt::from(class_hash),);
self.state_reader_proxy
Expand Down
2 changes: 1 addition & 1 deletion crates/native_blockifier/src/transaction_executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ impl<S: StateReader> TransactionExecutor<S> {
self.state.move_classes_to_global_cache();
}

PyStateDiff::from(self.state.to_state_diff())
PyStateDiff::from(self.state.cached_state_diff())
}

// Block pre-processing; see `block_execution::pre_process_block` documentation.
Expand Down

0 comments on commit 67e3575

Please sign in to comment.