diff --git a/src/callframe.rs b/src/callframe.rs index f8b14357..b1775991 100644 --- a/src/callframe.rs +++ b/src/callframe.rs @@ -1,5 +1,9 @@ use crate::{ - decommit::is_kernel, heap::HeapId, program::Program, stack::Stack, world_diff::Snapshot, + decommit::is_kernel, + heap::HeapId, + program::Program, + stack::{Stack, StackSnapshot}, + world_diff::Snapshot, Instruction, }; use u256::H160; @@ -155,6 +159,50 @@ impl Callframe { .map(|f| f.previous_frame_gas) .sum::() } + + pub(crate) fn snapshot(&self) -> CallframeSnapshot { + CallframeSnapshot { + stack: self.stack.snapshot(), + + context_u128: self.context_u128, + sp: self.sp, + gas: self.gas, + near_calls: self.near_calls.clone(), + heap_size: self.heap_size, + aux_heap_size: self.aux_heap_size, + heaps_i_was_keeping_alive: self.heaps_i_am_keeping_alive.len(), + } + } + + /// Returns heaps that were created during the period that is rolled back + /// and thus can't be referenced anymore and should be deallocated. + pub(crate) fn rollback( + &mut self, + snapshot: CallframeSnapshot, + ) -> impl Iterator + '_ { + let CallframeSnapshot { + stack, + context_u128, + sp, + gas, + near_calls, + heap_size, + aux_heap_size, + heaps_i_was_keeping_alive, + } = snapshot; + + self.stack.rollback(stack); + + self.context_u128 = context_u128; + self.sp = sp; + self.gas = gas; + self.near_calls = near_calls; + self.heap_size = heap_size; + self.aux_heap_size = aux_heap_size; + + self.heaps_i_am_keeping_alive + .drain(heaps_i_was_keeping_alive..) + } } pub(crate) struct FrameRemnant { @@ -162,3 +210,18 @@ pub(crate) struct FrameRemnant { pub(crate) exception_handler: u16, pub(crate) snapshot: Snapshot, } + +/// Only contains the fields that can change (other than via tracer). +pub(crate) struct CallframeSnapshot { + stack: StackSnapshot, + + context_u128: u128, + sp: u16, + gas: u32, + near_calls: Vec, + + heap_size: u32, + aux_heap_size: u32, + + heaps_i_was_keeping_alive: usize, +} diff --git a/src/heap.rs b/src/heap.rs index 58e3a25b..97d357e9 100644 --- a/src/heap.rs +++ b/src/heap.rs @@ -1,5 +1,5 @@ use crate::instruction_handlers::HeapInterface; -use std::ops::{Index, IndexMut, Range}; +use std::ops::{Index, Range}; use u256::U256; use zkevm_opcode_defs::system_params::NEW_FRAME_MEMORY_STIPEND; @@ -21,8 +21,17 @@ impl HeapId { pub struct Heap(Vec); impl Heap { - pub fn reserve(&mut self, additional: usize) { - self.0.reserve_exact(additional); + fn write_u256(&mut self, start_address: u32, value: U256) { + let end = (start_address + 32) as usize; + if end > self.0.len() { + self.0.resize(end, 0); + } + + value.to_big_endian(&mut self.0[start_address as usize..end]); + } + + pub(crate) fn is_empty(&self) -> bool { + self.0.is_empty() } } @@ -39,16 +48,6 @@ impl HeapInterface for Heap { } U256::from_big_endian(&bytes) } - fn write_u256(&mut self, start_address: u32, value: U256) { - let end = (start_address + 32) as usize; - if end > self.0.len() { - self.0.resize(end, 0); - } - - let mut bytes = [0; 32]; - value.to_big_endian(&mut bytes); - self.0[start_address as usize..end].copy_from_slice(&bytes); - } fn read_range_big_endian(&self, range: Range) -> Vec { let end = (range.end as usize).min(self.0.len()); let mut result = vec![0; range.len()]; @@ -57,13 +56,13 @@ impl HeapInterface for Heap { } result } - fn memset(&mut self, src: &[u8]) { - self.0 = src.to_vec(); - } } #[derive(Debug, Clone)] -pub struct Heaps(Vec); +pub struct Heaps { + heaps: Vec, + bootloader_heap_rollback_info: Vec<(u32, U256)>, +} pub(crate) const CALLDATA_HEAP: HeapId = HeapId(1); pub const FIRST_HEAP: HeapId = HeapId(2); @@ -73,23 +72,50 @@ impl Heaps { pub(crate) fn new(calldata: Vec) -> Self { // The first heap can never be used because heap zero // means the current heap in precompile calls - Self(vec![ - Heap(vec![]), - Heap(calldata), - Heap(vec![]), - Heap(vec![]), - ]) + Self { + heaps: vec![Heap(vec![]), Heap(calldata), Heap(vec![]), Heap(vec![])], + bootloader_heap_rollback_info: vec![], + } } pub(crate) fn allocate(&mut self) -> HeapId { - let id = HeapId(self.0.len() as u32); - self.0 - .push(Heap(vec![0; NEW_FRAME_MEMORY_STIPEND as usize])); + self.allocate_inner(vec![0; NEW_FRAME_MEMORY_STIPEND as usize]) + } + + pub(crate) fn allocate_with_content(&mut self, content: &[u8]) -> HeapId { + self.allocate_inner(content.to_vec()) + } + + fn allocate_inner(&mut self, memory: Vec) -> HeapId { + let id = HeapId(self.heaps.len() as u32); + self.heaps.push(Heap(memory)); id } pub(crate) fn deallocate(&mut self, heap: HeapId) { - self.0[heap.0 as usize].0 = vec![]; + self.heaps[heap.0 as usize].0 = vec![]; + } + + pub fn write_u256(&mut self, heap: HeapId, start_address: u32, value: U256) { + if heap == FIRST_HEAP { + self.bootloader_heap_rollback_info + .push((start_address, self[heap].read_u256(start_address))); + } + self.heaps[heap.0 as usize].write_u256(start_address, value); + } + + pub(crate) fn snapshot(&self) -> usize { + self.bootloader_heap_rollback_info.len() + } + + pub(crate) fn rollback(&mut self, snapshot: usize) { + for (address, value) in self.bootloader_heap_rollback_info.drain(snapshot..).rev() { + self.heaps[FIRST_HEAP.0 as usize].write_u256(address, value); + } + } + + pub(crate) fn delete_history(&mut self) { + self.bootloader_heap_rollback_info.clear(); } } @@ -97,20 +123,16 @@ impl Index for Heaps { type Output = Heap; fn index(&self, index: HeapId) -> &Self::Output { - &self.0[index.0 as usize] - } -} - -impl IndexMut for Heaps { - fn index_mut(&mut self, index: HeapId) -> &mut Self::Output { - &mut self.0[index.0 as usize] + &self.heaps[index.0 as usize] } } impl PartialEq for Heaps { fn eq(&self, other: &Self) -> bool { - for i in 0..self.0.len().max(other.0.len()) { - if self.0.get(i).unwrap_or(&Heap(vec![])) != other.0.get(i).unwrap_or(&Heap(vec![])) { + for i in 0..self.heaps.len().max(other.heaps.len()) { + if self.heaps.get(i).unwrap_or(&Heap(vec![])) + != other.heaps.get(i).unwrap_or(&Heap(vec![])) + { return false; } } diff --git a/src/instruction_handlers/decommit.rs b/src/instruction_handlers/decommit.rs index 49b677fa..dfa3952a 100644 --- a/src/instruction_handlers/decommit.rs +++ b/src/instruction_handlers/decommit.rs @@ -8,7 +8,7 @@ use crate::{ Instruction, VirtualMachine, World, }; -use super::{common::instruction_boilerplate, HeapInterface}; +use super::common::instruction_boilerplate; fn decommit( vm: &mut VirtualMachine, @@ -38,9 +38,8 @@ fn decommit( vm.state.current_frame.gas += extra_cost; } - let heap = vm.state.heaps.allocate(); + let heap = vm.state.heaps.allocate_with_content(program.as_ref()); vm.state.current_frame.heaps_i_am_keeping_alive.push(heap); - vm.state.heaps[heap].memset(program.as_ref()); let value = FatPointer { offset: 0, diff --git a/src/instruction_handlers/heap_access.rs b/src/instruction_handlers/heap_access.rs index 88a1949f..b4d2d47c 100644 --- a/src/instruction_handlers/heap_access.rs +++ b/src/instruction_handlers/heap_access.rs @@ -7,7 +7,7 @@ use crate::{ fat_pointer::FatPointer, instruction::InstructionResult, state::State, - ExecutionEnd, Instruction, VirtualMachine, World, + ExecutionEnd, HeapId, Instruction, VirtualMachine, World, }; use std::ops::Range; use u256::U256; @@ -15,20 +15,18 @@ use u256::U256; pub trait HeapInterface { fn read_u256(&self, start_address: u32) -> U256; fn read_u256_partially(&self, range: Range) -> U256; - fn write_u256(&mut self, start_address: u32, value: U256); fn read_range_big_endian(&self, range: Range) -> Vec; - fn memset(&mut self, memory: &[u8]); } pub trait HeapFromState { - fn get_heap(state: &mut State) -> &mut impl HeapInterface; + fn get_heap(state: &State) -> HeapId; fn get_heap_size(state: &mut State) -> &mut u32; } pub struct Heap; impl HeapFromState for Heap { - fn get_heap(state: &mut State) -> &mut impl HeapInterface { - &mut state.heaps[state.current_frame.heap] + fn get_heap(state: &State) -> HeapId { + state.current_frame.heap } fn get_heap_size(state: &mut State) -> &mut u32 { &mut state.current_frame.heap_size @@ -37,8 +35,8 @@ impl HeapFromState for Heap { pub struct AuxHeap; impl HeapFromState for AuxHeap { - fn get_heap(state: &mut State) -> &mut impl HeapInterface { - &mut state.heaps[state.current_frame.aux_heap] + fn get_heap(state: &State) -> HeapId { + state.current_frame.aux_heap } fn get_heap_size(state: &mut State) -> &mut u32 { &mut state.current_frame.aux_heap_size @@ -72,7 +70,8 @@ fn load( return Ok(&PANIC); } - let value = H::get_heap(&mut vm.state).read_u256(address); + let heap = H::get_heap(&vm.state); + let value = vm.state.heaps[heap].read_u256(address); Register1::set(args, &mut vm.state, value); if INCREMENT { @@ -108,7 +107,8 @@ fn store, } +impl Heap { + fn write_u256(&mut self, start_address: u32, value: U256) { + assert!(self.write.is_none()); + self.write = Some((start_address, value)); + } + + pub(crate) fn is_empty(&self) -> bool { + unimplemented!() + } +} + impl HeapInterface for Heap { fn read_u256(&self, start_address: u32) -> U256 { assert!(self.write.is_none()); @@ -25,20 +36,10 @@ impl HeapInterface for Heap { U256::from_little_endian(&result) } - fn write_u256(&mut self, start_address: u32, value: U256) { - assert!(self.write.is_none()); - self.write = Some((start_address, value)); - } - fn read_range_big_endian(&self, _: std::ops::Range) -> Vec { // This is wrong, but this method is only used to get the final return value. vec![] } - - fn memset(&mut self, src: &[u8]) { - let u = U256::from_big_endian(src); - self.write_u256(0, u); - } } impl<'a> Arbitrary<'a> for Heap { @@ -69,6 +70,14 @@ impl Heaps { self.heap_id } + pub(crate) fn allocate_with_content(&mut self, content: &[u8]) -> HeapId { + let id = self.allocate(); + self.read + .get_mut(id) + .write_u256(0, U256::from_big_endian(content)); + id + } + pub(crate) fn deallocate(&mut self, _: HeapId) {} pub(crate) fn from_id( @@ -80,6 +89,22 @@ impl Heaps { read: u.arbitrary()?, }) } + + pub fn write_u256(&mut self, heap: HeapId, start_address: u32, value: U256) { + self.read.get_mut(heap).write_u256(start_address, value); + } + + pub(crate) fn snapshot(&self) -> usize { + unimplemented!() + } + + pub(crate) fn rollback(&mut self, _: usize) { + unimplemented!() + } + + pub(crate) fn delete_history(&mut self) { + unimplemented!() + } } impl Index for Heaps { @@ -90,12 +115,6 @@ impl Index for Heaps { } } -impl IndexMut for Heaps { - fn index_mut(&mut self, index: HeapId) -> &mut Self::Output { - self.read.get_mut(index) - } -} - impl PartialEq for Heaps { fn eq(&self, _: &Self) -> bool { false diff --git a/src/single_instruction_test/stack.rs b/src/single_instruction_test/stack.rs index 988bc4b5..802fc2eb 100644 --- a/src/single_instruction_test/stack.rs +++ b/src/single_instruction_test/stack.rs @@ -72,6 +72,14 @@ impl Stack { && (self.slot_written.is_none() || is_valid_tagged_value((self.value_written, self.pointer_tag_written))) } + + pub(crate) fn snapshot(&self) -> StackSnapshot { + unimplemented!() + } + + pub(crate) fn rollback(&mut self, _: StackSnapshot) { + unimplemented!() + } } #[derive(Default, Debug)] @@ -91,3 +99,5 @@ impl StackPool { pub fn recycle(&mut self, _: Box) {} } + +pub(crate) struct StackSnapshot; diff --git a/src/stack.rs b/src/stack.rs index dc8552b4..352afc82 100644 --- a/src/stack.rs +++ b/src/stack.rs @@ -53,6 +53,36 @@ impl Stack { pub(crate) fn clear_pointer_flag(&mut self, slot: u16) { self.pointer_flags.clear(slot); } + + pub(crate) fn snapshot(&self) -> StackSnapshot { + let dirty_prefix_end = NUMBER_OF_DIRTY_AREAS - self.dirty_areas.leading_zeros() as usize; + + StackSnapshot { + pointer_flags: self.pointer_flags.clone(), + dirty_areas: self.dirty_areas, + slots: self.slots[..DIRTY_AREA_SIZE * dirty_prefix_end].into(), + } + } + + pub(crate) fn rollback(&mut self, snapshot: StackSnapshot) { + let StackSnapshot { + pointer_flags, + dirty_areas, + slots, + } = snapshot; + + self.zero(); + + self.pointer_flags = pointer_flags; + self.dirty_areas = dirty_areas; + self.slots[..slots.len()].copy_from_slice(&slots); + } +} + +pub(crate) struct StackSnapshot { + pointer_flags: Bitset, + dirty_areas: u64, + slots: Box<[U256]>, } impl Clone for Box { diff --git a/src/state.rs b/src/state.rs index 3d8681d0..0cf21487 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1,6 +1,6 @@ use crate::{ addressing_modes::Addressable, - callframe::Callframe, + callframe::{Callframe, CallframeSnapshot}, fat_pointer::FatPointer, heap::{Heaps, CALLDATA_HEAP, FIRST_AUX_HEAP, FIRST_HEAP}, predication::Flags, @@ -106,6 +106,46 @@ impl State { pub(crate) fn get_context_u128(&self) -> u128 { self.current_frame.context_u128 } + + pub(crate) fn snapshot(&self) -> StateSnapshot { + assert!(self.heaps[self.current_frame.aux_heap].is_empty()); + StateSnapshot { + registers: self.registers, + register_pointer_flags: self.register_pointer_flags, + flags: self.flags.clone(), + bootloader_frame: self.current_frame.snapshot(), + bootloader_heap_snapshot: self.heaps.snapshot(), + transaction_number: self.transaction_number, + context_u128: self.context_u128, + } + } + + pub(crate) fn rollback(&mut self, snapshot: StateSnapshot) { + assert!(self.heaps[self.current_frame.aux_heap].is_empty()); + let StateSnapshot { + registers, + register_pointer_flags, + flags, + bootloader_frame, + bootloader_heap_snapshot, + transaction_number, + context_u128, + } = snapshot; + + for heap in self.current_frame.rollback(bootloader_frame) { + self.heaps.deallocate(heap); + } + self.heaps.rollback(bootloader_heap_snapshot); + self.registers = registers; + self.register_pointer_flags = register_pointer_flags; + self.flags = flags; + self.transaction_number = transaction_number; + self.context_u128 = context_u128; + } + + pub(crate) fn delete_history(&mut self) { + self.heaps.delete_history(); + } } impl Addressable for State { @@ -144,3 +184,17 @@ impl Addressable for State { self.current_frame.is_kernel } } + +pub(crate) struct StateSnapshot { + registers: [U256; 16], + register_pointer_flags: u16, + + flags: Flags, + + bootloader_frame: CallframeSnapshot, + + bootloader_heap_snapshot: usize, + transaction_number: u16, + + context_u128: u128, +} diff --git a/src/vm.rs b/src/vm.rs index 8243d016..f25ef91e 100644 --- a/src/vm.rs +++ b/src/vm.rs @@ -1,4 +1,5 @@ use crate::heap::HeapId; +use crate::state::StateSnapshot; use crate::world_diff::ExternalSnapshot; use crate::{ callframe::{Callframe, FrameRemnant}, @@ -173,19 +174,36 @@ impl VirtualMachine { /// # Panics /// Calling this function outside of the initial callframe is not allowed. pub fn snapshot(&self) -> VmSnapshot { - assert!(self.state.previous_frames.is_empty()); + assert!( + self.state.previous_frames.is_empty(), + "Snapshotting is only allowed in the bootloader!" + ); VmSnapshot { world_snapshot: self.world_diff.external_snapshot(), - state_snapshot: self.state.clone(), + state_snapshot: self.state.snapshot(), } } /// Returns the VM to the state it was in when the snapshot was created. /// # Panics /// Rolling back snapshots in anything but LIFO order may panic. + /// Rolling back outside the initial callframe will panic. pub fn rollback(&mut self, snapshot: VmSnapshot) { + assert!( + self.state.previous_frames.is_empty(), + "Rolling back is only allowed in the bootloader!" + ); self.world_diff.external_rollback(snapshot.world_snapshot); - self.state = snapshot.state_snapshot; + self.state.rollback(snapshot.state_snapshot); + } + + /// This must only be called when it is known that the VM cannot be rolled back, + /// so there must not be any external snapshots and the callstack + /// should ideally be empty, though in practice it sometimes contains + /// a near call inside the bootloader. + pub fn delete_history(&mut self) { + self.world_diff.delete_history(); + self.state.delete_history(); } #[allow(clippy::too_many_arguments)] @@ -306,5 +324,5 @@ impl VirtualMachine { pub struct VmSnapshot { world_snapshot: ExternalSnapshot, - state_snapshot: State, + state_snapshot: StateSnapshot, } diff --git a/src/world_diff.rs b/src/world_diff.rs index 8b89c1ab..9782be5b 100644 --- a/src/world_diff.rs +++ b/src/world_diff.rs @@ -315,13 +315,15 @@ impl WorldDiff { .rollback(snapshot.written_storage_slots); } - /// This must only be called when it is known that the VM cannot be rolled back, - /// so there must not be any external snapshots and the callstack - /// should ideally be empty, though in practice it sometimes contains - /// a near call inside the bootloader. - pub fn delete_history(&mut self) { + pub(crate) fn delete_history(&mut self) { self.storage_changes.delete_history(); + self.paid_changes.delete_history(); + self.transient_storage_changes.delete_history(); self.events.delete_history(); + self.l2_to_l1_logs.delete_history(); + self.pubdata.delete_history(); + self.storage_refunds.delete_history(); + self.pubdata_costs.delete_history(); self.decommitted_hashes.delete_history(); self.read_storage_slots.delete_history(); self.written_storage_slots.delete_history();