From 204b10bf29981d12732bcc8c03adb0485bfd2413 Mon Sep 17 00:00:00 2001 From: behzad nouri Date: Sat, 23 Mar 2024 13:53:46 +0000 Subject: [PATCH] implements weighted shuffle using binary tree (#185) This is partial port of firedancer's implementation of weighted shuffle: https://github.com/firedancer-io/firedancer/blob/3401bfc26/src/ballet/wsample/fd_wsample.c Though Fenwick trees use less space, inverse queries require an additional O(log n) factor for binary search resulting an overall O(n log n log n) performance for weighted shuffle. This commit instead uses a binary tree where each node contains the sum of all weights in its left sub-tree. The weights themselves are implicitly stored at the leaves. Inverse queries and updates to the tree all can be done O(log n) resulting an overall O(n log n) weighted shuffle implementation. Based on benchmarks, this results in 24% improvement in WeightedShuffle::shuffle: Fenwick tree: test bench_weighted_shuffle_new ... bench: 36,686 ns/iter (+/- 191) test bench_weighted_shuffle_shuffle ... bench: 342,625 ns/iter (+/- 4,067) Binary tree: test bench_weighted_shuffle_new ... bench: 59,131 ns/iter (+/- 362) test bench_weighted_shuffle_shuffle ... bench: 260,194 ns/iter (+/- 11,195) Though WeightedShuffle::new is now slower, it generally can be cached and reused as in Turbine: https://github.com/anza-xyz/agave/blob/b3fd87fe8/turbine/src/cluster_nodes.rs#L68 Additionally the new code has better asymptotic performance. For example with 20_000 weights WeightedShuffle::shuffle is 31% faster: Fenwick tree: test bench_weighted_shuffle_new ... bench: 255,071 ns/iter (+/- 9,591) test bench_weighted_shuffle_shuffle ... bench: 2,466,058 ns/iter (+/- 9,873) Binary tree: test bench_weighted_shuffle_new ... bench: 830,727 ns/iter (+/- 10,210) test bench_weighted_shuffle_shuffle ... bench: 1,696,160 ns/iter (+/- 75,271) (cherry picked from commit b6d22374032683257a7d6633896445031ae1f20a) --- gossip/src/weighted_shuffle.rs | 186 +++++++++++++++++++++++---------- 1 file changed, 128 insertions(+), 58 deletions(-) diff --git a/gossip/src/weighted_shuffle.rs b/gossip/src/weighted_shuffle.rs index 250d1efb0f6800..7c12debce469e0 100644 --- a/gossip/src/weighted_shuffle.rs +++ b/gossip/src/weighted_shuffle.rs @@ -18,15 +18,14 @@ use { /// non-zero weighted indices. #[derive(Clone)] pub struct WeightedShuffle { - arr: Vec, // Underlying array implementing binary indexed tree. - sum: T, // Current sum of weights, excluding already selected indices. + // Underlying array implementing binary tree. + // tree[i] is the sum of weights in the left sub-tree of node i. + tree: Vec, + // Current sum of all weights, excluding already sampled ones. + weight: T, zeros: Vec, // Indices of zero weighted entries. } -// The implementation uses binary indexed tree: -// https://en.wikipedia.org/wiki/Fenwick_tree -// to maintain cumulative sum of weights excluding already selected indices -// over self.arr. impl WeightedShuffle where T: Copy + Default + PartialOrd + AddAssign + CheckedAdd, @@ -34,36 +33,39 @@ where /// If weights are negative or overflow the total sum /// they are treated as zero. pub fn new(name: &'static str, weights: &[T]) -> Self { - let size = weights.len() + 1; let zero = ::default(); - let mut arr = vec![zero; size]; + let mut tree = vec![zero; get_tree_size(weights.len())]; let mut sum = zero; let mut zeros = Vec::default(); let mut num_negative = 0; let mut num_overflow = 0; - for (mut k, &weight) in (1usize..).zip(weights) { + for (k, &weight) in weights.iter().enumerate() { #[allow(clippy::neg_cmp_op_on_partial_ord)] // weight < zero does not work for NaNs. if !(weight >= zero) { - zeros.push(k - 1); + zeros.push(k); num_negative += 1; continue; } if weight == zero { - zeros.push(k - 1); + zeros.push(k); continue; } sum = match sum.checked_add(&weight) { Some(val) => val, None => { - zeros.push(k - 1); + zeros.push(k); num_overflow += 1; continue; } }; - while k < size { - arr[k] += weight; - k += k & k.wrapping_neg(); + let mut index = tree.len() + k; + while index != 0 { + let offset = index & 1; + index = (index - 1) >> 1; + if offset > 0 { + tree[index] += weight; + } } } if num_negative > 0 { @@ -72,7 +74,11 @@ where if num_overflow > 0 { datapoint_error!("weighted-shuffle-overflow", (name, num_overflow, i64)); } - Self { arr, sum, zeros } + Self { + tree, + weight: sum, + zeros, + } } } @@ -80,54 +86,65 @@ impl WeightedShuffle where T: Copy + Default + PartialOrd + AddAssign + SubAssign + Sub, { - // Returns cumulative sum of current weights upto index k (inclusive). - fn cumsum(&self, mut k: usize) -> T { - let mut out = ::default(); - while k != 0 { - out += self.arr[k]; - k ^= k & k.wrapping_neg(); - } - out - } - // Removes given weight at index k. - fn remove(&mut self, mut k: usize, weight: T) { - self.sum -= weight; - let size = self.arr.len(); - while k < size { - self.arr[k] -= weight; - k += k & k.wrapping_neg(); + fn remove(&mut self, k: usize, weight: T) { + self.weight -= weight; + let mut index = self.tree.len() + k; + while index != 0 { + let offset = index & 1; + index = (index - 1) >> 1; + if offset > 0 { + self.tree[index] -= weight; + } } } - // Returns smallest index such that self.cumsum(k) > val, + // Returns smallest index such that cumsum of weights[..=k] > val, // along with its respective weight. - fn search(&self, val: T) -> (/*index:*/ usize, /*weight:*/ T) { + fn search(&self, mut val: T) -> (/*index:*/ usize, /*weight:*/ T) { let zero = ::default(); debug_assert!(val >= zero); - debug_assert!(val < self.sum); - let mut lo = (/*index:*/ 0, /*cumsum:*/ zero); - let mut hi = (self.arr.len() - 1, self.sum); - while lo.0 + 1 < hi.0 { - let k = lo.0 + (hi.0 - lo.0) / 2; - let sum = self.cumsum(k); - if sum <= val { - lo = (k, sum); + debug_assert!(val < self.weight); + let mut index = 0; + let mut weight = self.weight; + while index < self.tree.len() { + if val < self.tree[index] { + weight = self.tree[index]; + index = (index << 1) + 1; } else { - hi = (k, sum); + weight -= self.tree[index]; + val -= self.tree[index]; + index = (index << 1) + 2; } } - debug_assert!(lo.1 <= val); - debug_assert!(hi.1 > val); - (hi.0, hi.1 - lo.1) + (index - self.tree.len(), weight) } - pub fn remove_index(&mut self, index: usize) { - let zero = ::default(); - let weight = self.cumsum(index + 1) - self.cumsum(index); - if weight != zero { - self.remove(index + 1, weight); - } else if let Some(index) = self.zeros.iter().position(|ix| *ix == index) { + pub fn remove_index(&mut self, k: usize) { + let mut index = self.tree.len() + k; + let mut weight = ::default(); // zero + while index != 0 { + let offset = index & 1; + index = (index - 1) >> 1; + if offset > 0 { + if self.tree[index] != weight { + self.remove(k, self.tree[index] - weight); + } else { + self.remove_zero(k); + } + return; + } + weight += self.tree[index]; + } + if self.weight != weight { + self.remove(k, self.weight - weight); + } else { + self.remove_zero(k); + } + } + + fn remove_zero(&mut self, k: usize) { + if let Some(index) = self.zeros.iter().position(|&ix| ix == k) { self.zeros.remove(index); } } @@ -140,10 +157,10 @@ where // Equivalent to weighted_shuffle.shuffle(&mut rng).next() pub fn first(&self, rng: &mut R) -> Option { let zero = ::default(); - if self.sum > zero { - let sample = ::Sampler::sample_single(zero, self.sum, rng); + if self.weight > zero { + let sample = ::Sampler::sample_single(zero, self.weight, rng); let (index, _weight) = WeightedShuffle::search(self, sample); - return Some(index - 1); + return Some(index); } if self.zeros.is_empty() { return None; @@ -160,11 +177,11 @@ where pub fn shuffle(mut self, rng: &'a mut R) -> impl Iterator + 'a { std::iter::from_fn(move || { let zero = ::default(); - if self.sum > zero { - let sample = ::Sampler::sample_single(zero, self.sum, rng); + if self.weight > zero { + let sample = ::Sampler::sample_single(zero, self.weight, rng); let (index, weight) = WeightedShuffle::search(&self, sample); self.remove(index, weight); - return Some(index - 1); + return Some(index); } if self.zeros.is_empty() { return None; @@ -176,6 +193,19 @@ where } } +// Maps number of items to the "internal" size of the binary tree "implicitly" +// holding those items on the leaves. +fn get_tree_size(count: usize) -> usize { + let shift = usize::BITS + - count.leading_zeros() + - if count.is_power_of_two() && count != 1 { + 1 + } else { + 0 + }; + (1usize << shift) - 1 +} + #[cfg(test)] mod tests { use { @@ -218,6 +248,30 @@ mod tests { shuffle } + #[test] + fn test_get_tree_size() { + assert_eq!(get_tree_size(0), 0); + assert_eq!(get_tree_size(1), 1); + assert_eq!(get_tree_size(2), 1); + assert_eq!(get_tree_size(3), 3); + assert_eq!(get_tree_size(4), 3); + for count in 5..9 { + assert_eq!(get_tree_size(count), 7); + } + for count in 9..17 { + assert_eq!(get_tree_size(count), 15); + } + for count in 17..33 { + assert_eq!(get_tree_size(count), 31); + } + assert_eq!(get_tree_size((1 << 16) - 1), (1 << 16) - 1); + assert_eq!(get_tree_size(1 << 16), (1 << 16) - 1); + assert_eq!(get_tree_size((1 << 16) + 1), (1 << 17) - 1); + assert_eq!(get_tree_size((1 << 17) - 1), (1 << 17) - 1); + assert_eq!(get_tree_size(1 << 17), (1 << 17) - 1); + assert_eq!(get_tree_size((1 << 17) + 1), (1 << 18) - 1); + } + // Asserts that empty weights will return empty shuffle. #[test] fn test_weighted_shuffle_empty_weights() { @@ -357,4 +411,20 @@ mod tests { assert_eq!(shuffle.first(&mut rng), Some(shuffle_slow[0])); } } + + #[test] + fn test_weighted_shuffle_paranoid() { + let mut rng = rand::thread_rng(); + for size in 0..1351 { + let weights: Vec<_> = repeat_with(|| rng.gen_range(0..1000)).take(size).collect(); + let seed = rng.gen::<[u8; 32]>(); + let mut rng = ChaChaRng::from_seed(seed); + let shuffle_slow = weighted_shuffle_slow(&mut rng.clone(), weights.clone()); + let shuffle = WeightedShuffle::new("", &weights); + if size > 0 { + assert_eq!(shuffle.first(&mut rng.clone()), Some(shuffle_slow[0])); + } + assert_eq!(shuffle.shuffle(&mut rng).collect::>(), shuffle_slow); + } + } }