From 2708c7f32d74b8003418b328bff54a3c3fb658e4 Mon Sep 17 00:00:00 2001 From: Zach Kolodny Date: Mon, 3 Jun 2024 22:42:15 -0400 Subject: [PATCH] refactor InMemoryCustomRefundStorage to leverage InMemoryStorage as a field --- src/tests/complex_tests/mod.rs | 2 +- src/tests/mod.rs | 6 +- src/tests/run_manually.rs | 5 +- src/tests/storage.rs | 94 +++++++++++++ src/tests/utils/mod.rs | 8 -- src/tests/utils/storage.rs | 216 ------------------------------ src/tests/utils/testing_tracer.rs | 2 +- 7 files changed, 102 insertions(+), 231 deletions(-) create mode 100644 src/tests/storage.rs delete mode 100644 src/tests/utils/storage.rs diff --git a/src/tests/complex_tests/mod.rs b/src/tests/complex_tests/mod.rs index 8cd689b8..5827f7dc 100644 --- a/src/tests/complex_tests/mod.rs +++ b/src/tests/complex_tests/mod.rs @@ -176,7 +176,7 @@ pub(crate) fn generate_base_layer( use crate::external_calls::run; use crate::toolset::GeometryConfig; - let mut storage_impl = InMemoryCustomRefundStorage::new(None); + let mut storage_impl = InMemoryStorage::new(); let mut tree = ZKSyncTestingTree::empty(); test_artifact.entry_point_address = diff --git a/src/tests/mod.rs b/src/tests/mod.rs index 33ccfe0f..3c206a44 100644 --- a/src/tests/mod.rs +++ b/src/tests/mod.rs @@ -5,6 +5,9 @@ pub mod complex_tests; pub mod run_manually; #[cfg(test)] pub mod simple_tests; +#[cfg(test)] +pub(crate) mod storage; +#[cfg(test)] pub(crate) mod utils; use crate::blake2::Blake2s256; @@ -22,7 +25,6 @@ use circuit_definitions::circuit_definitions::recursion_layer::ZkSyncRecursiveLa use circuit_definitions::ZkSyncDefaultRoundFunction; use std::alloc::Global; use std::collections::HashMap; -use utils::storage::InMemoryCustomRefundStorage; const ACCOUNT_CODE_STORAGE_ADDRESS: Address = H160([ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, @@ -35,7 +37,7 @@ const KNOWN_CODE_HASHES_ADDRESS: Address = H160([ ]); pub(crate) fn save_predeployed_contracts( - storage: &mut InMemoryCustomRefundStorage, + storage: &mut InMemoryStorage, tree: &mut impl BinarySparseStorageTree<256, 32, 32, 8, 32, Blake2s256, ZkSyncStorageLeaf>, contracts: &HashMap>, ) { diff --git a/src/tests/run_manually.rs b/src/tests/run_manually.rs index 98d74d39..326e215a 100644 --- a/src/tests/run_manually.rs +++ b/src/tests/run_manually.rs @@ -24,8 +24,7 @@ use crate::zkevm_circuits::base_structures::vm_state::GlobalContextWitness; use crate::zkevm_circuits::main_vm::main_vm_entry_point; use circuit_definitions::aux_definitions::witness_oracle::VmWitnessOracle; use circuit_definitions::zk_evm::vm_state::cycle; -use utils::storage::InMemoryCustomRefundStorage; -use utils::StorageRefund; +use storage::{InMemoryCustomRefundStorage, StorageRefund}; use zkevm_assembly::Assembly; #[test] @@ -261,7 +260,7 @@ pub(crate) fn run_with_options(entry_point_bytecode: Vec<[u8; 32]>, options: Opt let mut known_contracts = HashMap::new(); known_contracts.extend(options.other_contracts.iter().cloned()); - save_predeployed_contracts(&mut storage_impl, &mut tree, &known_contracts); + save_predeployed_contracts(&mut storage_impl.storage, &mut tree, &known_contracts); let mut basic_block_circuits = vec![]; diff --git a/src/tests/storage.rs b/src/tests/storage.rs new file mode 100644 index 00000000..b7bf9b49 --- /dev/null +++ b/src/tests/storage.rs @@ -0,0 +1,94 @@ +use std::{ + collections::{HashMap, HashSet}, + sync::{Arc, Mutex}, +}; + +use circuit_definitions::{ + ethereum_types::{Address, H160}, + zk_evm::{ + abstractions::{Storage, StorageAccessRefund}, + aux_structures::{LogQuery, PubdataCost, Timestamp}, + reference_impls::{event_sink::ApplicationData, memory::SimpleMemory}, + testing::{storage::InMemoryStorage, NUM_SHARDS}, + tracing::{ + AfterDecodingData, AfterExecutionData, BeforeExecutionData, Tracer, VmLocalStateData, + }, + vm_state::PrimitiveValue, + }, +}; +use zkevm_assembly::zkevm_opcode_defs::{ + decoding::{AllowedPcOrImm, EncodingModeProduction, VmEncodingMode}, + AddOpcode, DecodedOpcode, NopOpcode, Opcode, PtrOpcode, RetOpcode, MAX_PUBDATA_COST_PER_QUERY, + STORAGE_ACCESS_COLD_READ_COST, STORAGE_ACCESS_COLD_WRITE_COST, STORAGE_ACCESS_WARM_READ_COST, + STORAGE_ACCESS_WARM_WRITE_COST, STORAGE_AUX_BYTE, TRANSIENT_STORAGE_AUX_BYTE, +}; + +use crate::ethereum_types::U256; + +/// Enum holding the types of storage refunds +#[derive(Debug, Copy, Clone)] +pub(crate) enum StorageRefund { + Cold, + Warm, +} + +#[derive(Debug, Clone)] +pub struct InMemoryCustomRefundStorage { + pub storage: InMemoryStorage, + pub slot_refund: Option>>, +} + +impl InMemoryCustomRefundStorage { + pub fn new(slot_refund: Option>>) -> Self { + Self { + storage: InMemoryStorage::new(), + slot_refund, + } + } +} + +impl Storage for InMemoryCustomRefundStorage { + #[track_caller] + fn get_access_refund( + &mut self, // to avoid any hacks inside, like prefetch + _monotonic_cycle_counter: u32, + _partial_query: &LogQuery, + ) -> StorageAccessRefund { + match &self.slot_refund { + None => StorageAccessRefund::Cold, + Some(val) => { + let (refund_type, val) = *val.lock().unwrap(); + + match refund_type { + StorageRefund::Cold => dbg!(StorageAccessRefund::Cold), + StorageRefund::Warm => dbg!(StorageAccessRefund::Warm { ergs: val }), + } + } + } + } + + #[track_caller] + fn execute_partial_query( + &mut self, + monotonic_cycle_counter: u32, + query: LogQuery, + ) -> (LogQuery, PubdataCost) { + self.storage + .execute_partial_query(monotonic_cycle_counter, query) + } + + #[track_caller] + fn start_frame(&mut self, timestamp: Timestamp) { + self.storage.start_frame(timestamp) + } + + #[track_caller] + fn finish_frame(&mut self, timestamp: Timestamp, panicked: bool) { + self.storage.finish_frame(timestamp, panicked) + } + + #[track_caller] + fn start_new_tx(&mut self, timestamp: Timestamp) { + self.storage.start_new_tx(timestamp) + } +} diff --git a/src/tests/utils/mod.rs b/src/tests/utils/mod.rs index 0c946515..a34f300b 100644 --- a/src/tests/utils/mod.rs +++ b/src/tests/utils/mod.rs @@ -1,10 +1,2 @@ pub mod preprocess_asm; -pub mod storage; pub mod testing_tracer; - -/// Enum holding the types of storage refunds -#[derive(Debug, Copy, Clone)] -pub(crate) enum StorageRefund { - Cold, - Warm, -} diff --git a/src/tests/utils/storage.rs b/src/tests/utils/storage.rs deleted file mode 100644 index 27748ab4..00000000 --- a/src/tests/utils/storage.rs +++ /dev/null @@ -1,216 +0,0 @@ -use std::{ - collections::{HashMap, HashSet}, - sync::{Arc, Mutex}, -}; - -use circuit_definitions::{ - ethereum_types::{Address, H160}, - zk_evm::{ - abstractions::{Storage, StorageAccessRefund}, - aux_structures::{LogQuery, PubdataCost, Timestamp}, - reference_impls::{event_sink::ApplicationData, memory::SimpleMemory}, - testing::{storage::InMemoryStorage, NUM_SHARDS}, - tracing::{ - AfterDecodingData, AfterExecutionData, BeforeExecutionData, Tracer, VmLocalStateData, - }, - vm_state::PrimitiveValue, - }, -}; -use zkevm_assembly::zkevm_opcode_defs::{ - decoding::{AllowedPcOrImm, EncodingModeProduction, VmEncodingMode}, - AddOpcode, DecodedOpcode, NopOpcode, Opcode, PtrOpcode, RetOpcode, MAX_PUBDATA_COST_PER_QUERY, - STORAGE_ACCESS_COLD_READ_COST, STORAGE_ACCESS_COLD_WRITE_COST, STORAGE_ACCESS_WARM_READ_COST, - STORAGE_ACCESS_WARM_WRITE_COST, STORAGE_AUX_BYTE, TRANSIENT_STORAGE_AUX_BYTE, -}; - -use crate::ethereum_types::U256; - -use super::{ - preprocess_asm::{EXCEPTION_PREFIX, PRINT_PREFIX, PRINT_PTR_PREFIX, PRINT_REG_PREFIX}, - testing_tracer::{OutOfCircuitException, TestingTracer}, - StorageRefund, -}; - -#[derive(Debug, Clone)] -pub struct InMemoryCustomRefundStorage { - pub inner: [HashMap>; NUM_SHARDS], - pub inner_transient: [HashMap>; NUM_SHARDS], - pub cold_warm_markers: [HashMap>; NUM_SHARDS], - pub transient_cold_warm_markers: [HashMap>; NUM_SHARDS], // not used - pub frames_stack: Vec>, - pub slot_refund: Option>>, -} - -impl InMemoryCustomRefundStorage { - pub fn new(slot_refund: Option>>) -> Self { - Self { - inner: [(); NUM_SHARDS].map(|_| HashMap::default()), - inner_transient: [(); NUM_SHARDS].map(|_| HashMap::default()), - cold_warm_markers: [(); NUM_SHARDS].map(|_| HashMap::default()), - transient_cold_warm_markers: [(); NUM_SHARDS].map(|_| HashMap::default()), - frames_stack: vec![ApplicationData::empty()], - slot_refund, - } - } - - pub fn populate(&mut self, elements: Vec<(u8, Address, U256, U256)>) { - for (shard_id, address, key, value) in elements.into_iter() { - let shard_level_map = &mut self.inner[shard_id as usize]; - let address_level_map = shard_level_map.entry(address).or_default(); - address_level_map.insert(key, value); - } - } -} - -impl Storage for InMemoryCustomRefundStorage { - #[track_caller] - fn get_access_refund( - &mut self, // to avoid any hacks inside, like prefetch - _monotonic_cycle_counter: u32, - _partial_query: &LogQuery, - ) -> StorageAccessRefund { - match &self.slot_refund { - None => StorageAccessRefund::Cold, - Some(val) => { - let (refund_type, val) = *val.lock().unwrap(); - - match refund_type { - StorageRefund::Cold => dbg!(StorageAccessRefund::Cold), - StorageRefund::Warm => dbg!(StorageAccessRefund::Warm { ergs: val }), - } - } - } - } - - #[track_caller] - fn execute_partial_query( - &mut self, - _monotonic_cycle_counter: u32, - mut query: LogQuery, - ) -> (LogQuery, PubdataCost) { - let aux_byte = query.aux_byte; - let shard_level_map = if aux_byte == STORAGE_AUX_BYTE { - &mut self.inner[query.shard_id as usize] - } else { - &mut self.inner_transient[query.shard_id as usize] - }; - let shard_level_warm_map = if aux_byte == STORAGE_AUX_BYTE { - &mut self.cold_warm_markers[query.shard_id as usize] - } else { - &mut self.transient_cold_warm_markers[query.shard_id as usize] - }; - let frame_data = self.frames_stack.last_mut().expect("frame must be started"); - - assert!(!query.rollback); - if query.rw_flag { - // write, also append rollback - let address_level_map = shard_level_map.entry(query.address).or_default(); - let current_value = address_level_map - .get(&query.key) - .copied() - .unwrap_or(U256::zero()); - address_level_map.insert(query.key, query.written_value); - - // mark as warm, and return - let address_level_warm_map = shard_level_warm_map.entry(query.address).or_default(); - let warm = address_level_warm_map.contains(&query.key); - if !warm { - address_level_warm_map.insert(query.key); - } - query.read_value = current_value; - - frame_data.forward.push(query); - query.rollback = true; - frame_data.rollbacks.push(query); - query.rollback = false; - - let pubdata_cost = if aux_byte == STORAGE_AUX_BYTE { - PubdataCost(MAX_PUBDATA_COST_PER_QUERY) - } else { - PubdataCost(0i32) - }; - - (query, pubdata_cost) - } else { - // read, do not append to rollback - let address_level_map = shard_level_map.entry(query.address).or_default(); - let current_value = address_level_map - .get(&query.key) - .copied() - .unwrap_or(U256::zero()); - // mark as warm, and return - let address_level_warm_map = shard_level_warm_map.entry(query.address).or_default(); - let warm = address_level_warm_map.contains(&query.key); - if !warm { - address_level_warm_map.insert(query.key); - } - query.read_value = current_value; - frame_data.forward.push(query); - - (query, PubdataCost(0i32)) - } - } - - #[track_caller] - fn start_frame(&mut self, _timestamp: Timestamp) { - let new = ApplicationData::empty(); - self.frames_stack.push(new); - } - - #[track_caller] - fn finish_frame(&mut self, _timestamp: Timestamp, panicked: bool) { - // if we panic then we append forward and rollbacks to the forward of parent, - // otherwise we place rollbacks of child before rollbacks of the parent - let current_frame = self - .frames_stack - .pop() - .expect("frame must be started before finishing"); - let ApplicationData { forward, rollbacks } = current_frame; - let parent_data = self - .frames_stack - .last_mut() - .expect("parent_frame_must_exist"); - if panicked { - // perform actual rollback - for query in rollbacks.iter().rev() { - let LogQuery { - shard_id, - address, - key, - read_value, - written_value, - aux_byte, - .. - } = *query; - let shard_level_map = if aux_byte == STORAGE_AUX_BYTE { - &mut self.inner[shard_id as usize] - } else { - &mut self.inner_transient[shard_id as usize] - }; - let address_level_map = shard_level_map - .get_mut(&address) - .expect("must always exist on rollback"); - let current_value_ref = address_level_map - .get_mut(&key) - .expect("must always exist on rollback"); - assert_eq!(*current_value_ref, written_value); // compare current value - *current_value_ref = read_value; // write back an old value - } - - parent_data.forward.extend(forward); - // add to forward part, but in reverse order - parent_data.forward.extend(rollbacks.into_iter().rev()); - } else { - parent_data.forward.extend(forward); - // we need to prepend rollbacks. No reverse here, as we do not care yet! - parent_data.rollbacks.extend(rollbacks); - } - } - - #[track_caller] - fn start_new_tx(&mut self, _: Timestamp) { - for transient in self.inner_transient.iter_mut() { - transient.clear(); - } - } -} diff --git a/src/tests/utils/testing_tracer.rs b/src/tests/utils/testing_tracer.rs index 540a5cdc..22f60625 100644 --- a/src/tests/utils/testing_tracer.rs +++ b/src/tests/utils/testing_tracer.rs @@ -20,6 +20,7 @@ use zkevm_assembly::zkevm_opcode_defs::RetOpcode; use zkevm_assembly::zkevm_opcode_defs::REGISTERS_COUNT; use crate::ethereum_types::U256; +use crate::tests::storage::StorageRefund; use crate::zk_evm::reference_impls::memory::SimpleMemory; use crate::zk_evm::tracing::*; @@ -30,7 +31,6 @@ use crate::tests::utils::preprocess_asm::PRINT_REG_PREFIX; use super::preprocess_asm::STORAGE_REFUND_COLD_PREFIX; use super::preprocess_asm::STORAGE_REFUND_WARM_PREFIX; -use super::StorageRefund; #[derive(Debug, Clone, PartialEq, Default)] enum TracerState {