Skip to content

Commit

Permalink
implements weighted shuffle using binary tree
Browse files Browse the repository at this point in the history
  • Loading branch information
behzadnouri committed Mar 14, 2024
1 parent 51dc7e6 commit 8e8e949
Showing 1 changed file with 113 additions and 50 deletions.
163 changes: 113 additions & 50 deletions gossip/src/weighted_shuffle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,52 +18,51 @@ use {
/// non-zero weighted indices.
#[derive(Clone)]
pub struct WeightedShuffle<T> {
arr: Vec<T>, // Underlying array implementing binary indexed tree.
arr: Vec<T>, // Underlying array implementing binary tree.
sum: T, // Current sum of weights, excluding already selected indices.
zeros: Vec<usize>, // Indices of zero weighted entries.
}

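// The implementation uses an implicit binary tree stored in self.arr
// (rather than a binary indexed tree,
// https://en.wikipedia.org/wiki/Fenwick_tree):
// the weights sit on the implicit leaves, and each internal node holds the
// sum of the weights in its left sub-tree, excluding already selected indices.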
impl<T> WeightedShuffle<T>
where
T: Copy + Default + PartialOrd + AddAssign + CheckedAdd,
{
/// If weights are negative or overflow the total sum
/// they are treated as zero.
pub fn new(name: &'static str, weights: &[T]) -> Self {
let size = weights.len() + 1;
let zero = <T as Default>::default();
let mut arr = vec![zero; size];
let mut arr = vec![zero; get_tree_size(weights.len())];
let mut sum = zero;
let mut zeros = Vec::default();
let mut num_negative = 0;
let mut num_overflow = 0;
for (mut k, &weight) in (1usize..).zip(weights) {
for (k, &weight) in weights.iter().enumerate() {
#[allow(clippy::neg_cmp_op_on_partial_ord)]
// weight < zero does not work for NaNs.
if !(weight >= zero) {
zeros.push(k - 1);
zeros.push(k);
num_negative += 1;
continue;
}
if weight == zero {
zeros.push(k - 1);
zeros.push(k);
continue;
}
sum = match sum.checked_add(&weight) {
Some(val) => val,
None => {
zeros.push(k - 1);
zeros.push(k);
num_overflow += 1;
continue;
}
};
while k < size {
arr[k] += weight;
k += k & k.wrapping_neg();
let mut index = arr.len() + k;
while index != 0 {
let offset = index & 1;
index = (index - 1) >> 1;
if offset > 0 {
arr[index] += weight;
}
}
}
if num_negative > 0 {
Expand All @@ -80,54 +79,65 @@ impl<T> WeightedShuffle<T>
where
T: Copy + Default + PartialOrd + AddAssign + SubAssign + Sub<Output = T>,
{
// Returns the cumulative sum of current weights up to index k (inclusive).
fn cumsum(&self, mut k: usize) -> T {
    let mut total = <T as Default>::default();
    while k > 0 {
        total += self.arr[k];
        // Clear the lowest set bit of k to step to the next node to add.
        k &= k - 1;
    }
    total
}

// Removes given weight at index k.
fn remove(&mut self, mut k: usize, weight: T) {
fn remove(&mut self, k: usize, weight: T) {
self.sum -= weight;
let size = self.arr.len();
while k < size {
self.arr[k] -= weight;
k += k & k.wrapping_neg();
let mut index = self.arr.len() + k;
while index != 0 {
let offset = index & 1;
index = (index - 1) >> 1;
if offset > 0 {
self.arr[index] -= weight;
}
}
}

// Returns smallest index such that self.cumsum(k) > val,
// Returns smallest index such that cumsum of weights[..=k] > val,
// along with its respective weight.
fn search(&self, val: T) -> (/*index:*/ usize, /*weight:*/ T) {
fn search(&self, mut val: T) -> (/*index:*/ usize, /*weight:*/ T) {
let zero = <T as Default>::default();
debug_assert!(val >= zero);
debug_assert!(val < self.sum);
let mut lo = (/*index:*/ 0, /*cumsum:*/ zero);
let mut hi = (self.arr.len() - 1, self.sum);
while lo.0 + 1 < hi.0 {
let k = lo.0 + (hi.0 - lo.0) / 2;
let sum = self.cumsum(k);
if sum <= val {
lo = (k, sum);
let mut index = 0;
let mut weight = self.sum;
while index < self.arr.len() {
if val < self.arr[index] {
weight = self.arr[index];
index = (index << 1) + 1;
} else {
hi = (k, sum);
weight -= self.arr[index];
val -= self.arr[index];
index = (index << 1) + 2;
}
}
debug_assert!(lo.1 <= val);
debug_assert!(hi.1 > val);
(hi.0, hi.1 - lo.1)
(index - self.arr.len(), weight)
}

pub fn remove_index(&mut self, index: usize) {
let zero = <T as Default>::default();
let weight = self.cumsum(index + 1) - self.cumsum(index);
if weight != zero {
self.remove(index + 1, weight);
} else if let Some(index) = self.zeros.iter().position(|ix| *ix == index) {
pub fn remove_index(&mut self, k: usize) {
let mut index = self.arr.len() + k;
let mut weight = <T as Default>::default(); // zero
while index != 0 {
let offset = index & 1;
index = (index - 1) >> 1;
if offset > 0 {
if self.arr[index] != weight {
self.remove(k, self.arr[index] - weight);
} else {
self.remove_zero(k);
}
return;
}
weight += self.arr[index];
}
if self.sum != weight {
self.remove(k, self.sum - weight);
} else {
self.remove_zero(k);
}
}

// Drops the first occurrence of index k from the list of zero weighted
// entries; does nothing if k is not in the list.
fn remove_zero(&mut self, k: usize) {
    let found = self.zeros.iter().position(|&entry| entry == k);
    if let Some(pos) = found {
        self.zeros.remove(pos);
    }
}
Expand All @@ -143,7 +153,7 @@ where
if self.sum > zero {
let sample = <T as SampleUniform>::Sampler::sample_single(zero, self.sum, rng);
let (index, _weight) = WeightedShuffle::search(self, sample);
return Some(index - 1);
return Some(index);
}
if self.zeros.is_empty() {
return None;
Expand All @@ -164,7 +174,7 @@ where
let sample = <T as SampleUniform>::Sampler::sample_single(zero, self.sum, rng);
let (index, weight) = WeightedShuffle::search(&self, sample);
self.remove(index, weight);
return Some(index - 1);
return Some(index);
}
if self.zeros.is_empty() {
return None;
Expand All @@ -176,6 +186,19 @@ where
}
}

// Maps the number of items to the number of "internal" nodes of the binary
// tree which "implicitly" holds those items on its leaves.
//
// The result is always of the form 2^shift - 1: zero for an empty input, one
// for a single item, and otherwise one less than the smallest power of two
// that is greater than or equal to count.
fn get_tree_size(count: usize) -> usize {
    // Number of bits needed to represent count (zero when count == 0).
    let bit_width = usize::BITS - count.leading_zeros();
    // An exact power of two needs one level less, except for count == 1
    // which still needs a single internal node above its only leaf.
    let exact_power = count.is_power_of_two() && count != 1;
    let shift = bit_width - u32::from(exact_power);
    (1usize << shift) - 1
}

#[cfg(test)]
mod tests {
use {
Expand Down Expand Up @@ -218,6 +241,30 @@ mod tests {
shuffle
}

// Pins get_tree_size for small counts and around power-of-two boundaries.
#[test]
fn test_get_tree_size() {
    // Hand-picked small counts, including the empty and single-item cases.
    for (count, expected) in [(0, 0), (1, 1), (2, 1), (3, 3), (4, 3)] {
        assert_eq!(get_tree_size(count), expected);
    }
    // Whole ranges of counts which share the same tree size.
    for (counts, expected) in [(5..9, 7), (9..17, 15), (17..33, 31)] {
        for count in counts {
            assert_eq!(get_tree_size(count), expected);
        }
    }
    // Just below, at, and just above 2^16.
    let pow16 = 1usize << 16;
    assert_eq!(get_tree_size(pow16 - 1), pow16 - 1);
    assert_eq!(get_tree_size(pow16), pow16 - 1);
    assert_eq!(get_tree_size(pow16 + 1), (pow16 << 1) - 1);
    // Just below, at, and just above 2^17.
    let pow17 = 1usize << 17;
    assert_eq!(get_tree_size(pow17 - 1), pow17 - 1);
    assert_eq!(get_tree_size(pow17), pow17 - 1);
    assert_eq!(get_tree_size(pow17 + 1), (pow17 << 1) - 1);
}

// Asserts that empty weights will return empty shuffle.
#[test]
fn test_weighted_shuffle_empty_weights() {
Expand Down Expand Up @@ -357,4 +404,20 @@ mod tests {
assert_eq!(shuffle.first(&mut rng), Some(shuffle_slow[0]));
}
}

// Cross-checks WeightedShuffle against weighted_shuffle_slow for every input
// length in 0..1351, using random integer weights drawn from 0..1000.
#[test]
fn test_weighted_shuffle_paranoid() {
let mut rng = rand::thread_rng();
for size in 0..1351 {
let weights: Vec<_> = repeat_with(|| rng.gen_range(0..1000)).take(size).collect();
// Draw a fresh seed from the thread rng; the seeded rng declared next shadows
// the thread rng for the rest of this iteration only.
let seed = rng.gen::<[u8; 32]>();
let mut rng = ChaChaRng::from_seed(seed);
// The slow shuffle, first() and shuffle() each get the seeded rng (or a clone
// of it) — presumably so that all three start from the same generator state.
let shuffle_slow = weighted_shuffle_slow(&mut rng.clone(), weights.clone());
let shuffle = WeightedShuffle::new("", &weights);
// Only compare the first element when there is one; with size == 0
// shuffle_slow is presumably empty and indexing [0] would not be valid.
if size > 0 {
assert_eq!(shuffle.first(&mut rng.clone()), Some(shuffle_slow[0]));
}
// The full shuffle order must match the slow implementation exactly.
assert_eq!(shuffle.shuffle(&mut rng).collect::<Vec<_>>(), shuffle_slow);
}
}
}

0 comments on commit 8e8e949

Please sign in to comment.