Skip to content

Commit

Permalink
implements weighted shuffle using binary tree
Browse files Browse the repository at this point in the history
  • Loading branch information
behzadnouri committed Mar 11, 2024
1 parent f205d0e commit 9558d6d
Showing 1 changed file with 105 additions and 52 deletions.
157 changes: 105 additions & 52 deletions gossip/src/weighted_shuffle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,116 +18,117 @@ use {
/// non-zero weighted indices.
#[derive(Clone)]
pub struct WeightedShuffle<T> {
arr: Vec<T>, // Underlying array implementing binary indexed tree.
arr: Vec<T>, // Underlying array implementing binary tree.
sum: T, // Current sum of weights, excluding already selected indices.
msb: usize, // Most significant bit of indices.
zeros: Vec<usize>, // Indices of zero weighted entries.
}

// The implementation uses binary indexed tree:
// https://en.wikipedia.org/wiki/Fenwick_tree
// to maintain cumulative sum of weights excluding already selected indices
// over self.arr.
impl<T> WeightedShuffle<T>
where
T: Copy + Default + PartialOrd + AddAssign + CheckedAdd,
{
/// If weights are negative or overflow the total sum
/// they are treated as zero.
pub fn new(name: &'static str, weights: &[T]) -> Self {
let size = weights.len() + 1;
let zero = <T as Default>::default();
let mut arr = vec![zero; size];
let mut arr = vec![zero; get_tree_size(weights.len())];
let msb = (arr.len() + 1) >> 2;
let mut sum = zero;
let mut zeros = Vec::default();
let mut num_negative = 0;
let mut num_overflow = 0;
for (mut k, &weight) in (1usize..).zip(weights) {
for (k, &weight) in weights.iter().enumerate() {
#[allow(clippy::neg_cmp_op_on_partial_ord)]
// weight < zero does not work for NaNs.
if !(weight >= zero) {
zeros.push(k - 1);
zeros.push(k);
num_negative += 1;
continue;
}
if weight == zero {
zeros.push(k - 1);
zeros.push(k);
continue;
}
sum = match sum.checked_add(&weight) {
Some(val) => val,
None => {
zeros.push(k - 1);
zeros.push(k);
num_overflow += 1;
continue;
}
};
while k < size {
arr[k] += weight;
k += k & k.wrapping_neg();
}
let index = get_mask_bits(msb).fold(0, |index, mask| {
(index << 1)
+ if k & mask == 0 {
arr[index] += weight;
1
} else {
2
}
});
arr[index] = weight
}
if num_negative > 0 {
datapoint_error!("weighted-shuffle-negative", (name, num_negative, i64));
}
if num_overflow > 0 {
datapoint_error!("weighted-shuffle-overflow", (name, num_overflow, i64));
}
Self { arr, sum, zeros }
Self {
arr,
sum,
msb,
zeros,
}
}
}

impl<T> WeightedShuffle<T>
where
T: Copy + Default + PartialOrd + AddAssign + SubAssign + Sub<Output = T>,
{
// Returns cumulative sum of current weights upto index k (inclusive).
fn cumsum(&self, mut k: usize) -> T {
let mut out = <T as Default>::default();
while k != 0 {
out += self.arr[k];
k ^= k & k.wrapping_neg();
}
out
}

// Removes given weight at index k.
fn remove(&mut self, mut k: usize, weight: T) {
fn remove(&mut self, k: usize, weight: T) {
self.sum -= weight;
let size = self.arr.len();
while k < size {
self.arr[k] -= weight;
k += k & k.wrapping_neg();
}
let index = get_mask_bits(self.msb).fold(0, |index, mask| {
(index << 1)
+ if k & mask == 0 {
self.arr[index] -= weight;
1
} else {
2
}
});
self.arr[index] -= weight;
}

// Returns smallest index such that self.cumsum(k) > val,
// Returns smallest index such that cumsum of weights[..=k] > val,
// along with its respective weight.
fn search(&self, val: T) -> (/*index:*/ usize, /*weight:*/ T) {
fn search(&self, mut val: T) -> (/*index:*/ usize, /*weight:*/ T) {
let zero = <T as Default>::default();
debug_assert!(val >= zero);
debug_assert!(val < self.sum);
let mut lo = (/*index:*/ 0, /*cumsum:*/ zero);
let mut hi = (self.arr.len() - 1, self.sum);
while lo.0 + 1 < hi.0 {
let k = lo.0 + (hi.0 - lo.0) / 2;
let sum = self.cumsum(k);
if sum <= val {
lo = (k, sum);
let (index, k) = get_mask_bits(self.msb).fold((0, 0), |(index, k), mask| {
if val >= self.arr[k] {
val -= self.arr[k];
(index | mask, (k << 1) + 2)
} else {
hi = (k, sum);
(index, (k << 1) + 1)
}
}
debug_assert!(lo.1 <= val);
debug_assert!(hi.1 > val);
(hi.0, hi.1 - lo.1)
});
(index, self.arr[k])
}

pub fn remove_index(&mut self, index: usize) {
pub fn remove_index(&mut self, k: usize) {
let zero = <T as Default>::default();
let weight = self.cumsum(index + 1) - self.cumsum(index);
let index = get_mask_bits(self.msb).fold(0, |index, mask| {
(index << 1) + if k & mask == 0 { 1 } else { 2 }
});
let weight = self.arr[index];
if weight != zero {
self.remove(index + 1, weight);
} else if let Some(index) = self.zeros.iter().position(|ix| *ix == index) {
self.remove(k, weight);
} else if let Some(index) = self.zeros.iter().position(|ix| *ix == k) {
self.zeros.remove(index);
}
}
Expand All @@ -143,7 +144,7 @@ where
if self.sum > zero {
let sample = <T as SampleUniform>::Sampler::sample_single(zero, self.sum, rng);
let (index, _weight) = WeightedShuffle::search(self, sample);
return Some(index - 1);
return Some(index);
}
if self.zeros.is_empty() {
return None;
Expand All @@ -164,7 +165,7 @@ where
let sample = <T as SampleUniform>::Sampler::sample_single(zero, self.sum, rng);
let (index, weight) = WeightedShuffle::search(&self, sample);
self.remove(index, weight);
return Some(index - 1);
return Some(index);
}
if self.zeros.is_empty() {
return None;
Expand All @@ -176,6 +177,25 @@ where
}
}

// Maps number of items to the size of the binary tree
// holding those items on the leaves.
fn get_tree_size(count: usize) -> usize {
let shift = usize::BITS - count.leading_zeros()
+ if count.is_power_of_two() || count == 0 {
0
} else {
1
};
(1usize << shift) - 1
}

fn get_mask_bits(msb: usize) -> impl Iterator<Item = usize> {
debug_assert!(msb.is_power_of_two() || msb == 0);
std::iter::successors((msb != 0).then_some(msb), |&bit| {
(bit != 1).then_some(bit >> 1)
})
}

#[cfg(test)]
mod tests {
use {
Expand Down Expand Up @@ -218,6 +238,39 @@ mod tests {
shuffle
}

#[test]
fn test_get_tree_size() {
assert_eq!(get_tree_size(0), 0);
assert_eq!(get_tree_size(1), 1);
assert_eq!(get_tree_size(2), 3);
assert_eq!(get_tree_size(3), 7);
assert_eq!(get_tree_size(4), 7);
for count in 5..9 {
assert_eq!(get_tree_size(count), 15);
}
for count in 9..17 {
assert_eq!(get_tree_size(count), 31);
}
for count in 17..33 {
assert_eq!(get_tree_size(count), 63);
}
assert_eq!(get_tree_size((1 << 16) - 1), (1 << 17) - 1);
assert_eq!(get_tree_size(1 << 16), (1 << 17) - 1);
assert_eq!(get_tree_size((1 << 16) + 1), (1 << 18) - 1);
assert_eq!(get_tree_size((1 << 17) - 1), (1 << 18) - 1);
assert_eq!(get_tree_size(1 << 17), (1 << 18) - 1);
assert_eq!(get_tree_size((1 << 17) + 1), (1 << 19) - 1);
}

#[test]
fn test_get_mask_bits() {
assert_eq!(get_mask_bits(0).next(), None);
assert_eq!(get_mask_bits(1).collect::<Vec<_>>(), [1]);
assert_eq!(get_mask_bits(2).collect::<Vec<_>>(), [2, 1]);
assert_eq!(get_mask_bits(4).collect::<Vec<_>>(), [4, 2, 1]);
assert_eq!(get_mask_bits(8).collect::<Vec<_>>(), [8, 4, 2, 1]);
}

// Asserts that empty weights will return empty shuffle.
#[test]
fn test_weighted_shuffle_empty_weights() {
Expand Down

0 comments on commit 9558d6d

Please sign in to comment.