diff --git a/rs/execution_environment/src/query_handler/query_cache.rs b/rs/execution_environment/src/query_handler/query_cache.rs index a3b91c1c421..4dbc8bb0d67 100644 --- a/rs/execution_environment/src/query_handler/query_cache.rs +++ b/rs/execution_environment/src/query_handler/query_cache.rs @@ -114,7 +114,7 @@ impl CountBytes for QueryCache { impl Default for QueryCache { fn default() -> Self { QueryCache { - cache: Mutex::new(LruCache::new((u64::MAX / 2).into())), + cache: Mutex::new(LruCache::unbounded()), } } } @@ -134,6 +134,7 @@ impl QueryCache { if value.is_valid(env) { return Some(value.result()); } else { + // Remove the invalid entry. cache.pop(key); } } @@ -141,7 +142,6 @@ impl QueryCache { } pub(crate) fn insert(&self, key: EntryKey, value: EntryValue) { - let size = (key.count_bytes() + value.count_bytes()) as u64; - self.cache.lock().unwrap().put(key, value, size.into()); + self.cache.lock().unwrap().push(key, value); } } diff --git a/rs/utils/lru_cache/Cargo.toml b/rs/utils/lru_cache/Cargo.toml index 2d0491d958e..652eaa2f655 100644 --- a/rs/utils/lru_cache/Cargo.toml +++ b/rs/utils/lru_cache/Cargo.toml @@ -7,4 +7,4 @@ edition = "2021" [dependencies] ic-types = { path = "../../types/types" } -lru = { version = "0.7.1", default-features = false } \ No newline at end of file +lru = { version = "0.7.1", default-features = false } diff --git a/rs/utils/lru_cache/src/lib.rs b/rs/utils/lru_cache/src/lib.rs index 9761c4c1b01..02d4ab0fd7f 100644 --- a/rs/utils/lru_cache/src/lib.rs +++ b/rs/utils/lru_cache/src/lib.rs @@ -4,59 +4,70 @@ use std::hash::Hash; /// The upper bound on cache item size and cache capacity. /// It is needed to ensure that all arithmetic operations /// do not overflow. -const MAX_SIZE: NumBytes = NumBytes::new(u64::MAX / 2); +const MAX_SIZE: usize = usize::MAX / 2; /// A cache with bounded memory capacity that evicts items using the /// least-recently used eviction policy. It guarantees that the sum of /// sizes of the cached items does not exceed the pre-configured capacity. pub struct LruCache where - K: Eq + Hash, + K: CountBytes + Eq + Hash, + V: CountBytes, { - cache: lru::LruCache, - capacity: NumBytes, - size: NumBytes, + cache: lru::LruCache, + capacity: usize, + size: usize, } impl CountBytes for LruCache where - K: Eq + Hash, + K: CountBytes + Eq + Hash, + V: CountBytes, { fn count_bytes(&self) -> usize { - self.size.get() as usize + self.size } } impl LruCache where - K: Eq + Hash, + K: CountBytes + Eq + Hash, + V: CountBytes, { /// Constructs a new LRU cache with the given capacity. /// The capacity must not exceed `MAX_SIZE = (2^63 - 1)`. pub fn new(capacity: NumBytes) -> Self { + let capacity = capacity.get() as usize; assert!(capacity <= MAX_SIZE); let lru_cache = Self { cache: lru::LruCache::unbounded(), capacity, - size: NumBytes::new(0), + size: 0, }; lru_cache.check_invariants(); lru_cache } + /// Creates a new LRU Cache that never automatically evicts items. + pub fn unbounded() -> Self { + Self::new(NumBytes::new(MAX_SIZE as u64)) + } + /// Returns the value corresponding to the given key. /// It also marks the item as the most-recently used. pub fn get(&mut self, key: &K) -> Option<&V> { - self.cache.get(key).map(|(value, _size)| value) + self.cache.get(key) } /// Inserts or updates the item with the given key. /// It also marks the item as the most-recently used. /// The size parameter specifies the size of the item, /// which must not exceed `MAX_SIZE = (2^63 - 1)`. - pub fn put(&mut self, key: K, value: V, size: NumBytes) { + pub fn push(&mut self, key: K, value: V) { + let size = key.count_bytes() + value.count_bytes(); assert!(size <= MAX_SIZE); - if let Some((_, prev_size)) = self.cache.put(key, (value, size)) { + if let Some((prev_key, prev_value)) = self.cache.push(key, value) { + let prev_size = prev_key.count_bytes() + prev_value.count_bytes(); debug_assert!(self.size >= prev_size); // This cannot underflow because we know that `self.size` is // the sum of sizes of all items in the cache. @@ -64,7 +75,7 @@ where } // This cannot overflow because we know that // `self.size <= self.capacity <= MAX_SIZE` - // and `size <= MAX_SIZE == u64::MAX / 2`. + // and `size <= MAX_SIZE == usize::MAX / 2`. self.size += size; self.evict(); self.check_invariants(); @@ -73,7 +84,8 @@ where /// Removes and returns the value corresponding to the key from the cache or /// `None` if it does not exist. pub fn pop(&mut self, key: &K) -> Option { - if let Some((value, size)) = self.cache.pop(key) { + if let Some((key, value)) = self.cache.pop_entry(key) { + let size = key.count_bytes() + value.count_bytes(); debug_assert!(self.size >= size); self.size -= size; self.check_invariants(); @@ -86,7 +98,7 @@ where /// Clears the cache by removing all items. pub fn clear(&mut self) { self.cache.clear(); - self.size = NumBytes::new(0); + self.size = 0; self.check_invariants(); } @@ -94,7 +106,8 @@ where fn evict(&mut self) { while self.size > self.capacity { match self.cache.pop_lru() { - Some((_k, (_v, size))) => { + Some((key, value)) => { + let size = key.count_bytes() + value.count_bytes(); debug_assert!(self.size >= size); // This cannot underflow because we know that `self.size` is // the sum of sizes of all items in the cache. @@ -106,7 +119,13 @@ where } fn check_invariants(&self) { - debug_assert_eq!(self.size, self.cache.iter().map(|(_k, (_v, s))| *s).sum()); + debug_assert_eq!( + self.size, + self.cache + .iter() + .map(|(key, value)| key.count_bytes() + value.count_bytes()) + .sum::() + ); debug_assert!(self.size <= self.capacity); } } @@ -115,105 +134,180 @@ where mod tests { use super::*; + #[derive(Debug, Eq, Hash, PartialEq)] + struct ValueSize(u32, usize); + + impl CountBytes for ValueSize { + fn count_bytes(&self) -> usize { + self.1 + } + } + + #[derive(Debug, Eq, Hash, PartialEq)] + struct Key(u32); + + impl CountBytes for Key { + fn count_bytes(&self) -> usize { + 0 + } + } + #[test] fn lru_cache_single_entry() { - let mut lru = LruCache::::new(NumBytes::new(10)); + let mut lru = LruCache::::new(NumBytes::new(10)); - assert!(lru.get(&0).is_none()); + assert!(lru.get(&Key(0)).is_none()); - lru.put(0, 42, NumBytes::new(10)); - assert_eq!(*lru.get(&0).unwrap(), 42); + lru.push(Key(0), ValueSize(42, 10)); + assert_eq!(*lru.get(&Key(0)).unwrap(), ValueSize(42, 10)); - lru.put(0, 42, NumBytes::new(11)); - assert!(lru.get(&0).is_none()); + lru.push(Key(0), ValueSize(42, 11)); + assert!(lru.get(&Key(0)).is_none()); - lru.put(0, 24, NumBytes::new(10)); - assert_eq!(*lru.get(&0).unwrap(), 24); + lru.push(Key(0), ValueSize(24, 10)); + assert_eq!(*lru.get(&Key(0)).unwrap(), ValueSize(24, 10)); } #[test] fn lru_cache_multiple_entries() { - let mut lru = LruCache::::new(NumBytes::new(10)); + let mut lru = LruCache::::new(NumBytes::new(10)); for i in 0..20 { - lru.put(i, i, NumBytes::new(1)); + lru.push(Key(i), ValueSize(i, 1)); } for i in 0..20 { - let result = lru.get(&i); + let result = lru.get(&Key(i)); if i < 10 { assert!(result.is_none()); } else { - assert_eq!(*result.unwrap(), i); + assert_eq!(*result.unwrap(), ValueSize(i, 1)); } } } #[test] - fn lru_cache_eviction() { - let mut lru = LruCache::::new(NumBytes::new(10)); + fn lru_cache_value_eviction() { + let mut lru = LruCache::::new(NumBytes::new(10)); + + assert!(lru.get(&Key(0)).is_none()); + + lru.push(Key(0), ValueSize(42, 10)); + assert_eq!(*lru.get(&Key(0)).unwrap(), ValueSize(42, 10)); + + lru.push(Key(1), ValueSize(20, 0)); + assert_eq!(*lru.get(&Key(0)).unwrap(), ValueSize(42, 10)); + assert_eq!(*lru.get(&Key(1)).unwrap(), ValueSize(20, 0)); + + lru.push(Key(2), ValueSize(10, 10)); + assert!(lru.get(&Key(0)).is_none()); + assert_eq!(*lru.get(&Key(1)).unwrap(), ValueSize(20, 0)); + assert_eq!(*lru.get(&Key(2)).unwrap(), ValueSize(10, 10)); + + lru.push(Key(3), ValueSize(30, 10)); + assert!(lru.get(&Key(1)).is_none()); + assert!(lru.get(&Key(2)).is_none()); + assert_eq!(*lru.get(&Key(3)).unwrap(), ValueSize(30, 10)); + + lru.push(Key(3), ValueSize(60, 5)); + assert_eq!(*lru.get(&Key(3)).unwrap(), ValueSize(60, 5)); + + lru.push(Key(4), ValueSize(40, 5)); + assert_eq!(*lru.get(&Key(3)).unwrap(), ValueSize(60, 5)); + assert_eq!(*lru.get(&Key(4)).unwrap(), ValueSize(40, 5)); + + lru.push(Key(4), ValueSize(40, 10)); + assert!(lru.get(&Key(3)).is_none()); + assert_eq!(*lru.get(&Key(4)).unwrap(), ValueSize(40, 10)); + } + + #[test] + fn lru_cache_key_eviction() { + let mut lru = LruCache::::new(NumBytes::new(10)); - assert!(lru.get(&0).is_none()); + assert!(lru.get(&ValueSize(0, 10)).is_none()); - lru.put(0, 42, NumBytes::new(10)); - assert_eq!(*lru.get(&0).unwrap(), 42); + lru.push(ValueSize(0, 10), ValueSize(42, 0)); + assert_eq!(*lru.get(&ValueSize(0, 10)).unwrap(), ValueSize(42, 0)); - lru.put(1, 20, NumBytes::new(0)); - assert_eq!(*lru.get(&0).unwrap(), 42); - assert_eq!(*lru.get(&1).unwrap(), 20); + lru.push(ValueSize(1, 0), ValueSize(20, 0)); + assert_eq!(*lru.get(&ValueSize(0, 10)).unwrap(), ValueSize(42, 0)); + assert_eq!(*lru.get(&ValueSize(1, 0)).unwrap(), ValueSize(20, 0)); - lru.put(2, 10, NumBytes::new(10)); - assert!(lru.get(&0).is_none()); - assert_eq!(*lru.get(&1).unwrap(), 20); - assert_eq!(*lru.get(&2).unwrap(), 10); + lru.push(ValueSize(2, 10), ValueSize(10, 0)); + assert!(lru.get(&ValueSize(0, 10)).is_none()); + assert_eq!(*lru.get(&ValueSize(1, 0)).unwrap(), ValueSize(20, 0)); + assert_eq!(*lru.get(&ValueSize(2, 10)).unwrap(), ValueSize(10, 0)); + + lru.push(ValueSize(3, 10), ValueSize(30, 0)); + assert!(lru.get(&ValueSize(1, 0)).is_none()); + assert!(lru.get(&ValueSize(2, 10)).is_none()); + assert_eq!(*lru.get(&ValueSize(3, 10)).unwrap(), ValueSize(30, 0)); + + lru.push(ValueSize(3, 5), ValueSize(60, 0)); + assert_eq!(*lru.get(&ValueSize(3, 5)).unwrap(), ValueSize(60, 0)); + + lru.push(ValueSize(4, 5), ValueSize(40, 0)); + assert_eq!(*lru.get(&ValueSize(3, 5)).unwrap(), ValueSize(60, 0)); + assert_eq!(*lru.get(&ValueSize(4, 5)).unwrap(), ValueSize(40, 0)); + + lru.push(ValueSize(4, 10), ValueSize(40, 0)); + assert!(lru.get(&ValueSize(3, 5)).is_none()); + assert_eq!(*lru.get(&ValueSize(4, 10)).unwrap(), ValueSize(40, 0)); + } + + #[test] + fn lru_cache_key_and_value_eviction() { + let mut lru = LruCache::::new(NumBytes::new(10)); - lru.put(3, 30, NumBytes::new(10)); - assert!(lru.get(&1).is_none()); - assert!(lru.get(&2).is_none()); - assert_eq!(*lru.get(&3).unwrap(), 30); + lru.push(ValueSize(0, 5), ValueSize(42, 5)); + assert_eq!(*lru.get(&ValueSize(0, 5)).unwrap(), ValueSize(42, 5)); - lru.put(3, 60, NumBytes::new(5)); - assert_eq!(*lru.get(&3).unwrap(), 60); + lru.push(ValueSize(1, 0), ValueSize(20, 0)); + assert_eq!(*lru.get(&ValueSize(0, 5)).unwrap(), ValueSize(42, 5)); + assert_eq!(*lru.get(&ValueSize(1, 0)).unwrap(), ValueSize(20, 0)); - lru.put(4, 40, NumBytes::new(5)); - assert_eq!(*lru.get(&3).unwrap(), 60); - assert_eq!(*lru.get(&4).unwrap(), 40); + lru.push(ValueSize(2, 5), ValueSize(10, 5)); + assert!(lru.get(&ValueSize(0, 5)).is_none()); + assert_eq!(*lru.get(&ValueSize(1, 0)).unwrap(), ValueSize(20, 0)); + assert_eq!(*lru.get(&ValueSize(2, 5)).unwrap(), ValueSize(10, 5)); - lru.put(4, 40, NumBytes::new(10)); - assert!(lru.get(&3).is_none()); - assert_eq!(*lru.get(&4).unwrap(), 40); + lru.push(ValueSize(3, 5), ValueSize(30, 5)); + assert!(lru.get(&ValueSize(1, 0)).is_none()); + assert!(lru.get(&ValueSize(2, 5)).is_none()); + assert_eq!(*lru.get(&ValueSize(3, 5)).unwrap(), ValueSize(30, 5)); } #[test] fn lru_cache_clear() { - let mut lru = LruCache::::new(NumBytes::new(10)); - lru.put(0, 0, NumBytes::new(10)); + let mut lru = LruCache::::new(NumBytes::new(10)); + lru.push(Key(0), ValueSize(0, 10)); lru.clear(); - assert!(lru.get(&0).is_none()); + assert!(lru.get(&Key(0)).is_none()); } #[test] fn lru_cache_pop() { - let mut lru = LruCache::::new(NumBytes::new(10)); - lru.put(0, 0, NumBytes::new(5)); - lru.put(1, 1, NumBytes::new(5)); - lru.pop(&0); - assert!(lru.get(&0).is_none()); - assert!(lru.get(&1).is_some()); - lru.pop(&1); - assert!(lru.get(&1).is_none()); + let mut lru = LruCache::::new(NumBytes::new(10)); + lru.push(Key(0), ValueSize(0, 5)); + lru.push(Key(1), ValueSize(1, 5)); + lru.pop(&Key(0)); + assert!(lru.get(&Key(0)).is_none()); + assert!(lru.get(&Key(1)).is_some()); + lru.pop(&Key(1)); + assert!(lru.get(&Key(1)).is_none()); } #[test] fn lru_cache_count_bytes() { - let mut lru = LruCache::::new(NumBytes::new(10)); - lru.put(0, 0, NumBytes::new(4)); + let mut lru = LruCache::::new(NumBytes::new(10)); + lru.push(Key(0), ValueSize(0, 4)); assert_eq!(4, lru.count_bytes()); - lru.put(1, 1, NumBytes::new(6)); + lru.push(Key(1), ValueSize(1, 6)); assert_eq!(10, lru.count_bytes()); - lru.pop(&0); + lru.pop(&Key(0)); assert_eq!(6, lru.count_bytes()); - lru.pop(&1); + lru.pop(&Key(1)); assert_eq!(0, lru.count_bytes()); } }