Skip to content

Commit

Permalink
implements weighted shuffle using binary tree
Browse files Browse the repository at this point in the history
This is partial port of firedancer's implementation of weighted shuffle:
https://github.com/firedancer-io/firedancer/blob/3401bfc26/src/ballet/wsample/fd_wsample.c

Though Fenwick trees use less space, inverse queries require an
additional O(log n) factor for binary search resulting an overall
O(n log n log n) performance for weighted shuffle.

This commit instead uses a binary tree where each node contains the sum
of all weights in its left sub-tree. The weights themselves are
implicitly stored at the leaves. Inverse queries and updates to the tree
all can be done O(log n) resulting an overall O(n log n) weighted
shuffle implementation.

Based on benchmarks, this results in 24% improvement in
WeightedShuffle::shuffle:

Fenwick tree:
    test bench_weighted_shuffle_new     ... bench:      36,686 ns/iter (+/- 191)
    test bench_weighted_shuffle_shuffle ... bench:     342,625 ns/iter (+/- 4,067)

Binary tree:
    test bench_weighted_shuffle_new     ... bench:      59,131 ns/iter (+/- 362)
    test bench_weighted_shuffle_shuffle ... bench:     260,194 ns/iter (+/- 11,195)

Though WeightedShuffle::new is now slower, it generally can be cached
and reused as in Turbine:
https://github.com/anza-xyz/agave/blob/b3fd87fe8/turbine/src/cluster_nodes.rs#L68

Additionally the new code has better asymptotic performance. For
example with 20_000 weights WeightedShuffle::shuffle is 31% faster:

Fenwick tree:
    test bench_weighted_shuffle_new     ... bench:     255,071 ns/iter (+/- 9,591)
    test bench_weighted_shuffle_shuffle ... bench:   2,466,058 ns/iter (+/- 9,873)

Binary tree:
    test bench_weighted_shuffle_new     ... bench:     830,727 ns/iter (+/- 10,210)
    test bench_weighted_shuffle_shuffle ... bench:   1,696,160 ns/iter (+/- 75,271)
  • Loading branch information
behzadnouri committed Mar 14, 2024
1 parent b3fd87f commit 18531c9
Showing 1 changed file with 128 additions and 58 deletions.
186 changes: 128 additions & 58 deletions gossip/src/weighted_shuffle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,52 +18,54 @@ use {
/// non-zero weighted indices.
#[derive(Clone)]
pub struct WeightedShuffle<T> {
arr: Vec<T>, // Underlying array implementing binary indexed tree.
sum: T, // Current sum of weights, excluding already selected indices.
// Underlying array implementing binary tree.
// tree[i] is the sum of weights in the left sub-tree of node i.
tree: Vec<T>,
// Current sum of all weights, excluding already sampled ones.
weight: T,
zeros: Vec<usize>, // Indices of zero weighted entries.
}

// The implementation uses binary indexed tree:
// https://en.wikipedia.org/wiki/Fenwick_tree
// to maintain cumulative sum of weights excluding already selected indices
// over self.arr.
impl<T> WeightedShuffle<T>
where
T: Copy + Default + PartialOrd + AddAssign + CheckedAdd,
{
/// If weights are negative or overflow the total sum
/// they are treated as zero.
pub fn new(name: &'static str, weights: &[T]) -> Self {
let size = weights.len() + 1;
let zero = <T as Default>::default();
let mut arr = vec![zero; size];
let mut tree = vec![zero; get_tree_size(weights.len())];
let mut sum = zero;
let mut zeros = Vec::default();
let mut num_negative = 0;
let mut num_overflow = 0;
for (mut k, &weight) in (1usize..).zip(weights) {
for (k, &weight) in weights.iter().enumerate() {
#[allow(clippy::neg_cmp_op_on_partial_ord)]
// weight < zero does not work for NaNs.
if !(weight >= zero) {
zeros.push(k - 1);
zeros.push(k);
num_negative += 1;
continue;
}
if weight == zero {
zeros.push(k - 1);
zeros.push(k);
continue;
}
sum = match sum.checked_add(&weight) {
Some(val) => val,
None => {
zeros.push(k - 1);
zeros.push(k);
num_overflow += 1;
continue;
}
};
while k < size {
arr[k] += weight;
k += k & k.wrapping_neg();
let mut index = tree.len() + k;
while index != 0 {
let offset = index & 1;
index = (index - 1) >> 1;
if offset > 0 {
tree[index] += weight;
}
}
}
if num_negative > 0 {
Expand All @@ -72,62 +74,77 @@ where
if num_overflow > 0 {
datapoint_error!("weighted-shuffle-overflow", (name, num_overflow, i64));
}
Self { arr, sum, zeros }
Self {
tree,
weight: sum,
zeros,
}
}
}

impl<T> WeightedShuffle<T>
where
T: Copy + Default + PartialOrd + AddAssign + SubAssign + Sub<Output = T>,
{
// Returns cumulative sum of current weights upto index k (inclusive).
fn cumsum(&self, mut k: usize) -> T {
let mut out = <T as Default>::default();
while k != 0 {
out += self.arr[k];
k ^= k & k.wrapping_neg();
}
out
}

// Removes given weight at index k.
fn remove(&mut self, mut k: usize, weight: T) {
self.sum -= weight;
let size = self.arr.len();
while k < size {
self.arr[k] -= weight;
k += k & k.wrapping_neg();
fn remove(&mut self, k: usize, weight: T) {
self.weight -= weight;
let mut index = self.tree.len() + k;
while index != 0 {
let offset = index & 1;
index = (index - 1) >> 1;
if offset > 0 {
self.tree[index] -= weight;
}
}
}

// Returns smallest index such that self.cumsum(k) > val,
// Returns smallest index such that cumsum of weights[..=k] > val,
// along with its respective weight.
fn search(&self, val: T) -> (/*index:*/ usize, /*weight:*/ T) {
fn search(&self, mut val: T) -> (/*index:*/ usize, /*weight:*/ T) {
let zero = <T as Default>::default();
debug_assert!(val >= zero);
debug_assert!(val < self.sum);
let mut lo = (/*index:*/ 0, /*cumsum:*/ zero);
let mut hi = (self.arr.len() - 1, self.sum);
while lo.0 + 1 < hi.0 {
let k = lo.0 + (hi.0 - lo.0) / 2;
let sum = self.cumsum(k);
if sum <= val {
lo = (k, sum);
debug_assert!(val < self.weight);
let mut index = 0;
let mut weight = self.weight;
while index < self.tree.len() {
if val < self.tree[index] {
weight = self.tree[index];
index = (index << 1) + 1;
} else {
hi = (k, sum);
weight -= self.tree[index];
val -= self.tree[index];
index = (index << 1) + 2;
}
}
debug_assert!(lo.1 <= val);
debug_assert!(hi.1 > val);
(hi.0, hi.1 - lo.1)
(index - self.tree.len(), weight)
}

pub fn remove_index(&mut self, index: usize) {
let zero = <T as Default>::default();
let weight = self.cumsum(index + 1) - self.cumsum(index);
if weight != zero {
self.remove(index + 1, weight);
} else if let Some(index) = self.zeros.iter().position(|ix| *ix == index) {
pub fn remove_index(&mut self, k: usize) {
let mut index = self.tree.len() + k;
let mut weight = <T as Default>::default(); // zero
while index != 0 {
let offset = index & 1;
index = (index - 1) >> 1;
if offset > 0 {
if self.tree[index] != weight {
self.remove(k, self.tree[index] - weight);
} else {
self.remove_zero(k);
}
return;
}
weight += self.tree[index];
}
if self.weight != weight {
self.remove(k, self.weight - weight);
} else {
self.remove_zero(k);
}
}

fn remove_zero(&mut self, k: usize) {
if let Some(index) = self.zeros.iter().position(|&ix| ix == k) {
self.zeros.remove(index);
}
}
Expand All @@ -140,10 +157,10 @@ where
// Equivalent to weighted_shuffle.shuffle(&mut rng).next()
pub fn first<R: Rng>(&self, rng: &mut R) -> Option<usize> {
let zero = <T as Default>::default();
if self.sum > zero {
let sample = <T as SampleUniform>::Sampler::sample_single(zero, self.sum, rng);
if self.weight > zero {
let sample = <T as SampleUniform>::Sampler::sample_single(zero, self.weight, rng);
let (index, _weight) = WeightedShuffle::search(self, sample);
return Some(index - 1);
return Some(index);
}
if self.zeros.is_empty() {
return None;
Expand All @@ -160,11 +177,11 @@ where
pub fn shuffle<R: Rng>(mut self, rng: &'a mut R) -> impl Iterator<Item = usize> + 'a {
std::iter::from_fn(move || {
let zero = <T as Default>::default();
if self.sum > zero {
let sample = <T as SampleUniform>::Sampler::sample_single(zero, self.sum, rng);
if self.weight > zero {
let sample = <T as SampleUniform>::Sampler::sample_single(zero, self.weight, rng);
let (index, weight) = WeightedShuffle::search(&self, sample);
self.remove(index, weight);
return Some(index - 1);
return Some(index);
}
if self.zeros.is_empty() {
return None;
Expand All @@ -176,6 +193,19 @@ where
}
}

// Maps number of items to the "internal" size of the binary tree "implicitly"
// holding those items on the leaves.
fn get_tree_size(count: usize) -> usize {
let shift = usize::BITS
- count.leading_zeros()
- if count.is_power_of_two() && count != 1 {
1
} else {
0
};
(1usize << shift) - 1
}

#[cfg(test)]
mod tests {
use {
Expand Down Expand Up @@ -218,6 +248,30 @@ mod tests {
shuffle
}

#[test]
fn test_get_tree_size() {
assert_eq!(get_tree_size(0), 0);
assert_eq!(get_tree_size(1), 1);
assert_eq!(get_tree_size(2), 1);
assert_eq!(get_tree_size(3), 3);
assert_eq!(get_tree_size(4), 3);
for count in 5..9 {
assert_eq!(get_tree_size(count), 7);
}
for count in 9..17 {
assert_eq!(get_tree_size(count), 15);
}
for count in 17..33 {
assert_eq!(get_tree_size(count), 31);
}
assert_eq!(get_tree_size((1 << 16) - 1), (1 << 16) - 1);
assert_eq!(get_tree_size(1 << 16), (1 << 16) - 1);
assert_eq!(get_tree_size((1 << 16) + 1), (1 << 17) - 1);
assert_eq!(get_tree_size((1 << 17) - 1), (1 << 17) - 1);
assert_eq!(get_tree_size(1 << 17), (1 << 17) - 1);
assert_eq!(get_tree_size((1 << 17) + 1), (1 << 18) - 1);
}

// Asserts that empty weights will return empty shuffle.
#[test]
fn test_weighted_shuffle_empty_weights() {
Expand Down Expand Up @@ -357,4 +411,20 @@ mod tests {
assert_eq!(shuffle.first(&mut rng), Some(shuffle_slow[0]));
}
}

#[test]
fn test_weighted_shuffle_paranoid() {
let mut rng = rand::thread_rng();
for size in 0..1351 {
let weights: Vec<_> = repeat_with(|| rng.gen_range(0..1000)).take(size).collect();
let seed = rng.gen::<[u8; 32]>();
let mut rng = ChaChaRng::from_seed(seed);
let shuffle_slow = weighted_shuffle_slow(&mut rng.clone(), weights.clone());
let shuffle = WeightedShuffle::new("", &weights);
if size > 0 {
assert_eq!(shuffle.first(&mut rng.clone()), Some(shuffle_slow[0]));
}
assert_eq!(shuffle.shuffle(&mut rng).collect::<Vec<_>>(), shuffle_slow);
}
}
}

0 comments on commit 18531c9

Please sign in to comment.