From 8e8e9492c159edc866a635ff8da3ec61fc2dc2c5 Mon Sep 17 00:00:00 2001 From: behzad nouri Date: Thu, 7 Mar 2024 12:20:11 -0600 Subject: [PATCH] implements weighted shuffle using binary tree --- gossip/src/weighted_shuffle.rs | 163 +++++++++++++++++++++++---------- 1 file changed, 113 insertions(+), 50 deletions(-) diff --git a/gossip/src/weighted_shuffle.rs b/gossip/src/weighted_shuffle.rs index 250d1efb0f6800..4ae083328dc883 100644 --- a/gossip/src/weighted_shuffle.rs +++ b/gossip/src/weighted_shuffle.rs @@ -18,15 +18,11 @@ use { /// non-zero weighted indices. #[derive(Clone)] pub struct WeightedShuffle { - arr: Vec, // Underlying array implementing binary indexed tree. + arr: Vec, // Underlying array implementing binary tree. sum: T, // Current sum of weights, excluding already selected indices. zeros: Vec, // Indices of zero weighted entries. } -// The implementation uses binary indexed tree: -// https://en.wikipedia.org/wiki/Fenwick_tree -// to maintain cumulative sum of weights excluding already selected indices -// over self.arr. impl WeightedShuffle where T: Copy + Default + PartialOrd + AddAssign + CheckedAdd, @@ -34,36 +30,39 @@ where /// If weights are negative or overflow the total sum /// they are treated as zero. pub fn new(name: &'static str, weights: &[T]) -> Self { - let size = weights.len() + 1; let zero = ::default(); - let mut arr = vec![zero; size]; + let mut arr = vec![zero; get_tree_size(weights.len())]; let mut sum = zero; let mut zeros = Vec::default(); let mut num_negative = 0; let mut num_overflow = 0; - for (mut k, &weight) in (1usize..).zip(weights) { + for (k, &weight) in weights.iter().enumerate() { #[allow(clippy::neg_cmp_op_on_partial_ord)] // weight < zero does not work for NaNs. if !(weight >= zero) { - zeros.push(k - 1); + zeros.push(k); num_negative += 1; continue; } if weight == zero { - zeros.push(k - 1); + zeros.push(k); continue; } sum = match sum.checked_add(&weight) { Some(val) => val, None => { - zeros.push(k - 1); + zeros.push(k); num_overflow += 1; continue; } }; - while k < size { - arr[k] += weight; - k += k & k.wrapping_neg(); + let mut index = arr.len() + k; + while index != 0 { + let offset = index & 1; + index = (index - 1) >> 1; + if offset > 0 { + arr[index] += weight; + } } } if num_negative > 0 { @@ -80,54 +79,65 @@ impl WeightedShuffle where T: Copy + Default + PartialOrd + AddAssign + SubAssign + Sub, { - // Returns cumulative sum of current weights upto index k (inclusive). - fn cumsum(&self, mut k: usize) -> T { - let mut out = ::default(); - while k != 0 { - out += self.arr[k]; - k ^= k & k.wrapping_neg(); - } - out - } - // Removes given weight at index k. - fn remove(&mut self, mut k: usize, weight: T) { + fn remove(&mut self, k: usize, weight: T) { self.sum -= weight; - let size = self.arr.len(); - while k < size { - self.arr[k] -= weight; - k += k & k.wrapping_neg(); + let mut index = self.arr.len() + k; + while index != 0 { + let offset = index & 1; + index = (index - 1) >> 1; + if offset > 0 { + self.arr[index] -= weight; + } } } - // Returns smallest index such that self.cumsum(k) > val, + // Returns smallest index such that cumsum of weights[..=k] > val, // along with its respective weight. - fn search(&self, val: T) -> (/*index:*/ usize, /*weight:*/ T) { + fn search(&self, mut val: T) -> (/*index:*/ usize, /*weight:*/ T) { let zero = ::default(); debug_assert!(val >= zero); debug_assert!(val < self.sum); - let mut lo = (/*index:*/ 0, /*cumsum:*/ zero); - let mut hi = (self.arr.len() - 1, self.sum); - while lo.0 + 1 < hi.0 { - let k = lo.0 + (hi.0 - lo.0) / 2; - let sum = self.cumsum(k); - if sum <= val { - lo = (k, sum); + let mut index = 0; + let mut weight = self.sum; + while index < self.arr.len() { + if val < self.arr[index] { + weight = self.arr[index]; + index = (index << 1) + 1; } else { - hi = (k, sum); + weight -= self.arr[index]; + val -= self.arr[index]; + index = (index << 1) + 2; } } - debug_assert!(lo.1 <= val); - debug_assert!(hi.1 > val); - (hi.0, hi.1 - lo.1) + (index - self.arr.len(), weight) } - pub fn remove_index(&mut self, index: usize) { - let zero = ::default(); - let weight = self.cumsum(index + 1) - self.cumsum(index); - if weight != zero { - self.remove(index + 1, weight); - } else if let Some(index) = self.zeros.iter().position(|ix| *ix == index) { + pub fn remove_index(&mut self, k: usize) { + let mut index = self.arr.len() + k; + let mut weight = ::default(); // zero + while index != 0 { + let offset = index & 1; + index = (index - 1) >> 1; + if offset > 0 { + if self.arr[index] != weight { + self.remove(k, self.arr[index] - weight); + } else { + self.remove_zero(k); + } + return; + } + weight += self.arr[index]; + } + if self.sum != weight { + self.remove(k, self.sum - weight); + } else { + self.remove_zero(k); + } + } + + fn remove_zero(&mut self, k: usize) { + if let Some(index) = self.zeros.iter().position(|&ix| ix == k) { self.zeros.remove(index); } } @@ -143,7 +153,7 @@ where if self.sum > zero { let sample = ::Sampler::sample_single(zero, self.sum, rng); let (index, _weight) = WeightedShuffle::search(self, sample); - return Some(index - 1); + return Some(index); } if self.zeros.is_empty() { return None; @@ -164,7 +174,7 @@ where let sample = ::Sampler::sample_single(zero, self.sum, rng); let (index, weight) = WeightedShuffle::search(&self, sample); self.remove(index, weight); - return Some(index - 1); + return Some(index); } if self.zeros.is_empty() { return None; @@ -176,6 +186,19 @@ where } } +// Maps number of items to the "internal" size of the binary tree "implicitly" +// holding those items on the leaves. +fn get_tree_size(count: usize) -> usize { + let shift = usize::BITS + - count.leading_zeros() + - if count.is_power_of_two() && count != 1 { + 1 + } else { + 0 + }; + (1usize << shift) - 1 +} + #[cfg(test)] mod tests { use { @@ -218,6 +241,30 @@ mod tests { shuffle } + #[test] + fn test_get_tree_size() { + assert_eq!(get_tree_size(0), 0); + assert_eq!(get_tree_size(1), 1); + assert_eq!(get_tree_size(2), 1); + assert_eq!(get_tree_size(3), 3); + assert_eq!(get_tree_size(4), 3); + for count in 5..9 { + assert_eq!(get_tree_size(count), 7); + } + for count in 9..17 { + assert_eq!(get_tree_size(count), 15); + } + for count in 17..33 { + assert_eq!(get_tree_size(count), 31); + } + assert_eq!(get_tree_size((1 << 16) - 1), (1 << 16) - 1); + assert_eq!(get_tree_size(1 << 16), (1 << 16) - 1); + assert_eq!(get_tree_size((1 << 16) + 1), (1 << 17) - 1); + assert_eq!(get_tree_size((1 << 17) - 1), (1 << 17) - 1); + assert_eq!(get_tree_size(1 << 17), (1 << 17) - 1); + assert_eq!(get_tree_size((1 << 17) + 1), (1 << 18) - 1); + } + // Asserts that empty weights will return empty shuffle. #[test] fn test_weighted_shuffle_empty_weights() { @@ -357,4 +404,20 @@ mod tests { assert_eq!(shuffle.first(&mut rng), Some(shuffle_slow[0])); } } + + #[test] + fn test_weighted_shuffle_paranoid() { + let mut rng = rand::thread_rng(); + for size in 0..1351 { + let weights: Vec<_> = repeat_with(|| rng.gen_range(0..1000)).take(size).collect(); + let seed = rng.gen::<[u8; 32]>(); + let mut rng = ChaChaRng::from_seed(seed); + let shuffle_slow = weighted_shuffle_slow(&mut rng.clone(), weights.clone()); + let shuffle = WeightedShuffle::new("", &weights); + if size > 0 { + assert_eq!(shuffle.first(&mut rng.clone()), Some(shuffle_slow[0])); + } + assert_eq!(shuffle.shuffle(&mut rng).collect::>(), shuffle_slow); + } + } }