diff --git a/Cargo.lock b/Cargo.lock index 862b9ca59021dc..568dedb330336e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5326,6 +5326,7 @@ dependencies = [ "strum_macros", "tar", "tempfile", + "test-case", "thiserror", ] diff --git a/accounts-db/Cargo.toml b/accounts-db/Cargo.toml index b4fcceea000552..6ce4d2f087e72d 100644 --- a/accounts-db/Cargo.toml +++ b/accounts-db/Cargo.toml @@ -79,6 +79,7 @@ solana-accounts-db = { path = ".", features = ["dev-context-only-utils"] } solana-logger = { workspace = true } solana-sdk = { workspace = true, features = ["dev-context-only-utils"] } static_assertions = { workspace = true } +test-case = { workspace = true } [package.metadata.docs.rs] targets = ["x86_64-unknown-linux-gnu"] diff --git a/accounts-db/src/rolling_bit_field.rs b/accounts-db/src/rolling_bit_field.rs index 65d3ff76b54ae7..73a71d7a084110 100644 --- a/accounts-db/src/rolling_bit_field.rs +++ b/accounts-db/src/rolling_bit_field.rs @@ -2,7 +2,11 @@ //! Relies on there being a sliding window of key values. The key values continue to increase. //! Old key values are removed from the lesser values and do not accumulate. -use {bv::BitVec, solana_nohash_hasher::IntSet, solana_sdk::clock::Slot}; +mod iterators; +use { + bv::BitVec, iterators::RollingBitFieldOnesIter, solana_nohash_hasher::IntSet, + solana_sdk::clock::Slot, +}; #[derive(Debug, Default, AbiExample, Clone)] pub struct RollingBitField { @@ -283,6 +287,14 @@ impl RollingBitField { } all } + + /// Returns an iterator over the rolling bit field + /// + /// The iterator yields all the 'set' bits. + /// Note, the iteration order of the bits in 'excess' is not deterministic. + pub fn iter_ones(&self) -> RollingBitFieldOnesIter<'_> { + RollingBitFieldOnesIter::new(self) + } } #[cfg(test)] diff --git a/accounts-db/src/rolling_bit_field/iterators.rs b/accounts-db/src/rolling_bit_field/iterators.rs new file mode 100644 index 00000000000000..dd075037ee119c --- /dev/null +++ b/accounts-db/src/rolling_bit_field/iterators.rs @@ -0,0 +1,76 @@ +//! Iterators for RollingBitField + +use {super::RollingBitField, std::ops::Range}; + +/// Iterate over the 'set' bits of a RollingBitField +#[derive(Debug)] +pub struct RollingBitFieldOnesIter<'a> { + rolling_bit_field: &'a RollingBitField, + excess_iter: std::collections::hash_set::Iter<'a, u64>, + bit_range: Range, +} + +impl<'a> RollingBitFieldOnesIter<'a> { + #[must_use] + pub fn new(rolling_bit_field: &'a RollingBitField) -> Self { + Self { + rolling_bit_field, + excess_iter: rolling_bit_field.excess.iter(), + bit_range: rolling_bit_field.min..rolling_bit_field.max_exclusive, + } + } +} + +impl Iterator for RollingBitFieldOnesIter<'_> { + type Item = u64; + + fn next(&mut self) -> Option { + // Iterate over the excess first + if let Some(excess) = self.excess_iter.next() { + return Some(*excess); + } + + // Then iterate over the bit vec + loop { + // If there are no more bits in the range, then we've iterated over everything and are done + let Some(bit) = self.bit_range.next() else { + return None; + }; + + if self.rolling_bit_field.contains_assume_in_range(&bit) { + break Some(bit); + } + } + } +} + +#[cfg(test)] +mod tests { + use {super::*, test_case::test_case}; + + #[test_case(128, vec![]; "empty")] + #[test_case(128, vec![128_007, 128_017, 128_107]; "without excess")] + #[test_case(128, vec![128_007, 128_017, 128_107, 3, 30, 300]; "with excess")] + // Even though these values are within the range, in an absolute sense, + // they will wrap around after multiples of 16. + #[test_case(16, vec![35, 40, 45 ])] + #[test_case(16, vec![ 40, 45, 50 ])] + #[test_case(16, vec![ 45, 50, 55 ])] + #[test_case(16, vec![ 50, 55, 60 ])] + #[test_case(16, vec![ 55, 60, 65 ])] + #[test_case(16, vec![ 60, 65, 70])] + fn test_rolling_bit_field_ones_iter(num_bits: u64, mut expected: Vec) { + let mut rolling_bit_field = RollingBitField::new(num_bits); + for val in &expected { + rolling_bit_field.insert(*val); + } + + let mut actual: Vec<_> = rolling_bit_field.iter_ones().collect(); + + // Since iteration order of the 'excess' is not deterministic, sort the 'actual' + // and 'expected' vectors to ensure they can compare deterministically. + actual.sort_unstable(); + expected.sort_unstable(); + assert_eq!(actual, expected); + } +}