diff --git a/vm/rust/src/juno_state_reader.rs b/vm/rust/src/juno_state_reader.rs index cd02e8fbc1..ed5a79e2ff 100644 --- a/vm/rust/src/juno_state_reader.rs +++ b/vm/rust/src/juno_state_reader.rs @@ -35,18 +35,22 @@ extern "C" { -> *const c_char; } -static CLASS_CACHE: Lazy>> = Lazy::new(|| { - Mutex::new(SizedCache::with_size(128)) -}); +struct CachedContractClass { + pub definition: ContractClass, + pub cached_on_height: u64, +} +static CLASS_CACHE: Lazy>> = + Lazy::new(|| Mutex::new(SizedCache::with_size(128))); pub struct JunoStateReader { pub handle: usize, // uintptr_t equivalent + pub height: u64, } impl JunoStateReader { - pub fn new(handle: usize) -> Self { - Self { handle } + pub fn new(handle: usize, height: u64) -> Self { + Self { handle, height } } } @@ -115,7 +119,11 @@ impl StateReader for JunoStateReader { class_hash: &ClassHash, ) -> StateResult { if let Some(cached_class) = CLASS_CACHE.lock().unwrap().cache_get(class_hash) { - return Ok(cached_class.clone()) + // skip the cache if it comes from a height higher than ours. Class might be undefined on the height + // that we are reading from right now. + if cached_class.cached_on_height <= self.height { + return Ok(cached_class.definition.clone()); + } } let class_hash_bytes = felt_to_byte_array(&class_hash.0); @@ -126,7 +134,13 @@ impl StateReader for JunoStateReader { let json_str = unsafe { CStr::from_ptr(ptr) }.to_str().unwrap(); let contract_class = contract_class_from_json_str(json_str); if let Ok(class) = &contract_class { - CLASS_CACHE.lock().unwrap().cache_set(*class_hash, class.clone()); + CLASS_CACHE.lock().unwrap().cache_set( + *class_hash, + CachedContractClass { + definition: class.clone(), + cached_on_height: self.height, + }, + ); } unsafe { JunoFree(ptr as *const c_void) }; diff --git a/vm/rust/src/lib.rs b/vm/rust/src/lib.rs index 8c65efb2ce..93a8f3af07 100644 --- a/vm/rust/src/lib.rs +++ b/vm/rust/src/lib.rs @@ -67,7 +67,7 @@ pub extern "C" fn cairoVMCall( block_timestamp: c_ulonglong, chain_id: *const c_char, ) { - let reader = JunoStateReader::new(reader_handle); + let reader = JunoStateReader::new(reader_handle, block_number); let contract_addr_felt = ptr_to_felt(contract_address); let class_hash = if class_hash.is_null() { None @@ -148,7 +148,7 @@ pub extern "C" fn cairoVMExecute( gas_price: *const c_uchar, legacy_json: c_uchar, ) { - let reader = JunoStateReader::new(reader_handle); + let reader = JunoStateReader::new(reader_handle, block_number); let chain_id_str = unsafe { CStr::from_ptr(chain_id) }.to_str().unwrap(); let txn_json_str = unsafe { CStr::from_ptr(txns_json) }.to_str().unwrap(); let txns_and_query_bits: Result, serde_json::Error> =