diff --git a/crates/vm2-interface/src/lib.rs b/crates/vm2-interface/src/lib.rs index 5e1a8fb..83e95a6 100644 --- a/crates/vm2-interface/src/lib.rs +++ b/crates/vm2-interface/src/lib.rs @@ -57,27 +57,27 @@ //! } //! //! trait Tracer { -//! fn before_instruction(&mut self, _state: &mut S) {} -//! fn after_instruction(&mut self, _state: &mut S) {} +//! fn before_instruction(&mut self, state: &mut S, storage: &mut S::StorageInterface) {} +//! fn after_instruction(&mut self, state: &mut S, storage: &mut S::StorageInterface) {} //! } //! //! impl Tracer for T { -//! fn before_instruction(&mut self, state: &mut S) { +//! fn before_instruction(&mut self, state: &mut S, storage: &mut S::StorageInterface) { //! match OP::VALUE { //! Opcode::NewOpcode => {} //! // Do this for every old opcode //! Opcode::NearCall => { -//! ::before_instruction::(self, state) +//! ::before_instruction::(self, state, storage) //! } //! } //! } -//! fn after_instruction(&mut self, _state: &mut S) {} +//! fn after_instruction(&mut self, state: &mut S, storage: &mut S::StorageInterface) {} //! } //! //! // Now you can use the new features by implementing TracerV2 //! struct MyTracer; //! impl Tracer for MyTracer { -//! fn before_instruction(&mut self, state: &mut S) { +//! fn before_instruction(&mut self, state: &mut S, _: &mut S::StorageInterface) { //! if OP::VALUE == Opcode::NewOpcode { //! state.get_some_new_field(); //! } diff --git a/crates/vm2-interface/src/state_interface.rs b/crates/vm2-interface/src/state_interface.rs index 7b817f4..9b72082 100644 --- a/crates/vm2-interface/src/state_interface.rs +++ b/crates/vm2-interface/src/state_interface.rs @@ -2,6 +2,9 @@ use primitive_types::{H160, U256}; /// Public interface of the VM state. Encompasses both read and write methods. pub trait StateInterface { + /// Storage interface required for operations that read storage. + type StorageInterface; + /// Reads a register with the specified zero-based index. Returns a value together with a pointer flag. fn read_register(&self, register: u8) -> (U256, bool); /// Sets a register with the specified zero-based index @@ -42,7 +45,12 @@ pub trait StateInterface { /// Iterates over storage slots read or written during VM execution. fn get_storage_state(&self) -> impl Iterator; /// Gets value of the specified storage slot. - fn get_storage(&mut self, address: H160, slot: U256) -> U256; + fn get_storage( + &mut self, + storage: &mut Self::StorageInterface, + address: H160, + slot: U256, + ) -> U256; /// Iterates over all transient storage slots set during VM execution. fn get_transient_storage_state(&self) -> impl Iterator; @@ -217,6 +225,8 @@ pub struct DummyState; #[cfg(test)] impl StateInterface for DummyState { + type StorageInterface = (); + fn read_register(&self, _: u8) -> (U256, bool) { unimplemented!() } @@ -277,7 +287,7 @@ impl StateInterface for DummyState { std::iter::empty() } - fn get_storage(&mut self, _: H160, _: U256) -> U256 { + fn get_storage(&mut self, _: &mut Self::StorageInterface, _: H160, _: U256) -> U256 { unimplemented!() } diff --git a/crates/vm2-interface/src/tracer_interface.rs b/crates/vm2-interface/src/tracer_interface.rs index 4717bd6..adbf35b 100644 --- a/crates/vm2-interface/src/tracer_interface.rs +++ b/crates/vm2-interface/src/tracer_interface.rs @@ -245,7 +245,7 @@ impl OpcodeType for opcodes::Ret { /// struct FarCallCounter(usize); /// /// impl Tracer for FarCallCounter { -/// fn before_instruction(&mut self, state: &mut S) { +/// fn before_instruction(&mut self, state: &mut S, _: &mut S::StorageInterface) { /// match OP::VALUE { /// Opcode::FarCall(_) => self.0 += 1, /// _ => {} @@ -257,14 +257,24 @@ pub trait Tracer { /// Executes logic before an instruction handler. /// /// The default implementation does nothing. - fn before_instruction(&mut self, state: &mut S) { + fn before_instruction( + &mut self, + state: &mut S, + storage: &mut S::StorageInterface, + ) { let _ = state; + let _ = storage; } /// Executes logic after an instruction handler. /// /// The default implementation does nothing. - fn after_instruction(&mut self, state: &mut S) { + fn after_instruction( + &mut self, + state: &mut S, + storage: &mut S::StorageInterface, + ) { let _ = state; + let _ = storage; } /// Provides cycle statistics for "complex" instructions from the prover perspective (mostly precompile calls). @@ -297,14 +307,22 @@ impl Tracer for () {} // Multiple tracers can be combined by building a linked list out of tuples. impl Tracer for (A, B) { - fn before_instruction(&mut self, state: &mut S) { - self.0.before_instruction::(state); - self.1.before_instruction::(state); + fn before_instruction( + &mut self, + state: &mut S, + storage: &mut S::StorageInterface, + ) { + self.0.before_instruction::(state, storage); + self.1.before_instruction::(state, storage); } - fn after_instruction(&mut self, state: &mut S) { - self.0.after_instruction::(state); - self.1.after_instruction::(state); + fn after_instruction( + &mut self, + state: &mut S, + storage: &mut S::StorageInterface, + ) { + self.0.after_instruction::(state, storage); + self.1.after_instruction::(state, storage); } fn on_extra_prover_cycles(&mut self, stats: CycleStats) { @@ -321,7 +339,11 @@ mod tests { struct FarCallCounter(usize); impl Tracer for FarCallCounter { - fn before_instruction(&mut self, _: &mut S) { + fn before_instruction( + &mut self, + _: &mut S, + _: &mut S::StorageInterface, + ) { if let super::Opcode::FarCall(CallingMode::Normal) = OP::VALUE { self.0 += 1; } @@ -332,13 +354,13 @@ mod tests { fn test_tracer() { let mut tracer = FarCallCounter(0); - tracer.before_instruction::(&mut DummyState); + tracer.before_instruction::(&mut DummyState, &mut ()); assert_eq!(tracer.0, 0); - tracer.before_instruction::, _>(&mut DummyState); + tracer.before_instruction::, _>(&mut DummyState, &mut ()); assert_eq!(tracer.0, 1); - tracer.before_instruction::, _>(&mut DummyState); + tracer.before_instruction::, _>(&mut DummyState, &mut ()); assert_eq!(tracer.0, 1); } @@ -346,12 +368,12 @@ mod tests { fn test_aggregate_tracer() { let mut tracer = (FarCallCounter(0), (FarCallCounter(0), FarCallCounter(0))); - tracer.before_instruction::(&mut DummyState); + tracer.before_instruction::(&mut DummyState, &mut ()); assert_eq!(tracer.0 .0, 0); assert_eq!(tracer.1 .0 .0, 0); assert_eq!(tracer.1 .1 .0, 0); - tracer.before_instruction::, _>(&mut DummyState); + tracer.before_instruction::, _>(&mut DummyState, &mut ()); assert_eq!(tracer.0 .0, 1); assert_eq!(tracer.1 .0 .0, 1); assert_eq!(tracer.1 .1 .0, 1); diff --git a/crates/vm2/src/instruction_handlers/common.rs b/crates/vm2/src/instruction_handlers/common.rs index c46ef47..8d78f1a 100644 --- a/crates/vm2/src/instruction_handlers/common.rs +++ b/crates/vm2/src/instruction_handlers/common.rs @@ -1,10 +1,7 @@ use zksync_vm2_interface::{opcodes, OpcodeType, Tracer}; use super::ret::free_panic; -use crate::{ - addressing_modes::Arguments, instruction::ExecutionStatus, tracing::VmAndWorld, VirtualMachine, - World, -}; +use crate::{addressing_modes::Arguments, instruction::ExecutionStatus, VirtualMachine, World}; #[inline(always)] pub(crate) fn boilerplate>( @@ -56,15 +53,15 @@ pub(crate) fn full_boilerplate>( } if args.predicate().satisfied(&vm.state.flags) { - tracer.before_instruction::(&mut VmAndWorld { vm, world }); + tracer.before_instruction::(vm, world); vm.state.current_frame.pc = unsafe { vm.state.current_frame.pc.add(1) }; let result = business_logic(vm, args, world, tracer); - tracer.after_instruction::(&mut VmAndWorld { vm, world }); + tracer.after_instruction::(vm, world); result } else { - tracer.before_instruction::(&mut VmAndWorld { vm, world }); + tracer.before_instruction::(vm, world); vm.state.current_frame.pc = unsafe { vm.state.current_frame.pc.add(1) }; - tracer.after_instruction::(&mut VmAndWorld { vm, world }); + tracer.after_instruction::(vm, world); ExecutionStatus::Running } } diff --git a/crates/vm2/src/instruction_handlers/ret.rs b/crates/vm2/src/instruction_handlers/ret.rs index 382949d..7a29d92 100644 --- a/crates/vm2/src/instruction_handlers/ret.rs +++ b/crates/vm2/src/instruction_handlers/ret.rs @@ -15,7 +15,6 @@ use crate::{ instruction::{ExecutionEnd, ExecutionStatus}, mode_requirements::ModeRequirements, predication::Flags, - tracing::VmAndWorld, Instruction, Predicate, VirtualMachine, World, }; @@ -144,13 +143,13 @@ pub(crate) fn free_panic>( world: &mut W, tracer: &mut T, ) -> ExecutionStatus { - tracer.before_instruction::, _>(&mut VmAndWorld { vm, world }); + tracer.before_instruction::, _>(vm, world); // args aren't used for panics unless TO_LABEL let result = naked_ret::( vm, &Arguments::new(Predicate::Always, 0, ModeRequirements::none()), ); - tracer.after_instruction::, _>(&mut VmAndWorld { vm, world }); + tracer.after_instruction::, _>(vm, world); result } @@ -162,7 +161,7 @@ pub(crate) fn panic_from_failed_far_call>( tracer: &mut T, exception_handler: u16, ) { - tracer.before_instruction::, _>(&mut VmAndWorld { vm, world }); + tracer.before_instruction::, _>(vm, world); // Gas is already subtracted in the far call code. // No need to roll back, as no changes are made in this "frame". @@ -172,7 +171,7 @@ pub(crate) fn panic_from_failed_far_call>( vm.state.flags = Flags::new(true, false, false); vm.state.current_frame.set_pc_from_u16(exception_handler); - tracer.after_instruction::, _>(&mut VmAndWorld { vm, world }); + tracer.after_instruction::, _>(vm, world); } fn invalid>( diff --git a/crates/vm2/src/tracing.rs b/crates/vm2/src/tracing.rs index dff5b62..0810563 100644 --- a/crates/vm2/src/tracing.rs +++ b/crates/vm2/src/tracing.rs @@ -12,48 +12,44 @@ use crate::{ VirtualMachine, World, }; -pub(crate) struct VmAndWorld<'a, T, W> { - pub vm: &'a mut VirtualMachine, - pub world: &'a mut W, -} +impl> StateInterface for VirtualMachine { + type StorageInterface = W; -impl> StateInterface for VmAndWorld<'_, T, W> { fn read_register(&self, register: u8) -> (U256, bool) { ( - self.vm.state.registers[register as usize], - self.vm.state.register_pointer_flags & (1 << register) != 0, + self.state.registers[register as usize], + self.state.register_pointer_flags & (1 << register) != 0, ) } fn set_register(&mut self, register: u8, value: U256, is_pointer: bool) { - self.vm.state.registers[register as usize] = value; + self.state.registers[register as usize] = value; - self.vm.state.register_pointer_flags &= !(1 << register); - self.vm.state.register_pointer_flags |= u16::from(is_pointer) << register; + self.state.register_pointer_flags &= !(1 << register); + self.state.register_pointer_flags |= u16::from(is_pointer) << register; } fn number_of_callframes(&self) -> usize { - self.vm - .state + self.state .previous_frames .iter() .map(|frame| frame.near_calls.len() + 1) .sum::() - + self.vm.state.current_frame.near_calls.len() + + self.state.current_frame.near_calls.len() + 1 } fn current_frame(&mut self) -> impl CallframeInterface + '_ { - let near_call = self.vm.state.current_frame.near_calls.len().checked_sub(1); + let near_call = self.state.current_frame.near_calls.len().checked_sub(1); CallframeWrapper { - frame: &mut self.vm.state.current_frame, + frame: &mut self.state.current_frame, near_call, } } fn callframe(&mut self, mut n: usize) -> impl CallframeInterface + '_ { - for far_frame in std::iter::once(&mut self.vm.state.current_frame) - .chain(self.vm.state.previous_frames.iter_mut().rev()) + for far_frame in std::iter::once(&mut self.state.current_frame) + .chain(self.state.previous_frames.iter_mut().rev()) { let near_calls = far_frame.near_calls.len(); match n.cmp(&near_calls) { @@ -76,19 +72,19 @@ impl> StateInterface for VmAndWorld<'_, T, W> { } fn read_heap_byte(&self, heap: HeapId, index: u32) -> u8 { - self.vm.state.heaps[heap].read_byte(index) + self.state.heaps[heap].read_byte(index) } fn read_heap_u256(&self, heap: HeapId, index: u32) -> U256 { - self.vm.state.heaps[heap].read_u256(index) + self.state.heaps[heap].read_u256(index) } fn write_heap_u256(&mut self, heap: HeapId, index: u32, value: U256) { - self.vm.state.heaps.write_u256(heap, index, value); + self.state.heaps.write_u256(heap, index, value); } fn flags(&self) -> Flags { - let flags = &self.vm.state.flags; + let flags = &self.state.flags; Flags { less_than: Predicate::IfLT.satisfied(flags), greater: Predicate::IfGT.satisfied(flags), @@ -97,50 +93,50 @@ impl> StateInterface for VmAndWorld<'_, T, W> { } fn set_flags(&mut self, flags: Flags) { - self.vm.state.flags = predication::Flags::new(flags.less_than, flags.equal, flags.greater); + self.state.flags = predication::Flags::new(flags.less_than, flags.equal, flags.greater); } fn transaction_number(&self) -> u16 { - self.vm.state.transaction_number + self.state.transaction_number } fn set_transaction_number(&mut self, value: u16) { - self.vm.state.transaction_number = value; + self.state.transaction_number = value; } fn context_u128_register(&self) -> u128 { - self.vm.state.context_u128 + self.state.context_u128 } fn set_context_u128_register(&mut self, value: u128) { - self.vm.state.context_u128 = value; + self.state.context_u128 = value; } fn get_storage_state(&self) -> impl Iterator { - self.vm - .world_diff + self.world_diff .get_storage_state() .iter() .map(|(key, value)| (*key, *value)) } - fn get_storage(&mut self, address: H160, slot: U256) -> U256 { - self.vm - .world_diff - .just_read_storage(self.world, address, slot) + fn get_storage( + &mut self, + storage: &mut Self::StorageInterface, + address: H160, + slot: U256, + ) -> U256 { + self.world_diff.just_read_storage(storage, address, slot) } fn get_transient_storage_state(&self) -> impl Iterator { - self.vm - .world_diff + self.world_diff .get_transient_storage_state() .iter() .map(|(key, value)| (*key, *value)) } fn get_transient_storage(&self, address: H160, slot: U256) -> U256 { - self.vm - .world_diff + self.world_diff .get_transient_storage_state() .get(&(address, slot)) .copied() @@ -148,25 +144,24 @@ impl> StateInterface for VmAndWorld<'_, T, W> { } fn write_transient_storage(&mut self, address: H160, slot: U256, value: U256) { - self.vm - .world_diff + self.world_diff .write_transient_storage(address, slot, value); } fn events(&self) -> impl Iterator { - self.vm.world_diff.events().iter().copied() + self.world_diff.events().iter().copied() } fn l2_to_l1_logs(&self) -> impl Iterator { - self.vm.world_diff.l2_to_l1_logs().iter().copied() + self.world_diff.l2_to_l1_logs().iter().copied() } fn pubdata(&self) -> i32 { - self.vm.world_diff.pubdata() + self.world_diff.pubdata() } fn set_pubdata(&mut self, value: i32) { - self.vm.world_diff.pubdata.0 = value; + self.world_diff.pubdata.0 = value; } }