Performance improvements for shuffle and partial_shuffle #1272

Merged
merged 11 commits into from
Jan 8, 2023
5 changes: 5 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -79,4 +79,9 @@ criterion = { version = "0.4" }
[[bench]]
name = "seq_choose"
path = "benches/seq_choose.rs"
harness = false

[[bench]]
name = "shuffle"
path = "benches/shuffle.rs"
harness = false
2 changes: 1 addition & 1 deletion benches/seq_choose.rs
@@ -1,4 +1,4 @@
// Copyright 2018-2022 Developers of the Rand project.
// Copyright 2018-2023 Developers of the Rand project.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
50 changes: 50 additions & 0 deletions benches/shuffle.rs
@@ -0,0 +1,50 @@
// Copyright 2018-2023 Developers of the Rand project.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use rand::prelude::*;
use rand::SeedableRng;

criterion_group!(
name = benches;
config = Criterion::default();
targets = bench
);
criterion_main!(benches);

pub fn bench(c: &mut Criterion) {
bench_rng::<rand_chacha::ChaCha12Rng>(c, "ChaCha12");
bench_rng::<rand_pcg::Pcg32>(c, "Pcg32");
bench_rng::<rand_pcg::Pcg64>(c, "Pcg64");
}

fn bench_rng<Rng: RngCore + SeedableRng>(c: &mut Criterion, rng_name: &'static str) {
for length in [1, 2, 3, 10, 100, 1000, 10000].map(|x| black_box(x)) {
c.bench_function(format!("shuffle_{length}_{rng_name}").as_str(), |b| {
let mut rng = Rng::seed_from_u64(123);
let mut vec: Vec<usize> = (0..length).collect();
b.iter(|| {
vec.shuffle(&mut rng);
vec[0]
})
});

if length >= 10 {
c.bench_function(
format!("partial_shuffle_{length}_{rng_name}").as_str(),
|b| {
let mut rng = Rng::seed_from_u64(123);
let mut vec: Vec<usize> = (0..length).collect();
b.iter(|| {
vec.partial_shuffle(&mut rng, length / 2);
vec[0]
})
},
);
}
}
}
8 changes: 8 additions & 0 deletions src/seq/coin_flipper.rs
@@ -1,3 +1,11 @@
// Copyright 2018-2023 Developers of the Rand project.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

use crate::RngCore;

pub(crate) struct CoinFlipper<R: RngCore> {
108 changes: 108 additions & 0 deletions src/seq/increasing_uniform.rs
@@ -0,0 +1,108 @@
// Copyright 2018-2023 Developers of the Rand project.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

use crate::{Rng, RngCore};

/// Similar to a Uniform distribution,
/// but after returning a number in the range [0,n], n is increased by 1.
pub(crate) struct IncreasingUniform<R: RngCore> {
pub rng: R,
n: u32,
// Chunk is a random number in [0, (n + 1) * (n + 2) *..* (n + chunk_remaining) )
chunk: u32,
chunk_remaining: u8,
}

impl<R: RngCore> IncreasingUniform<R> {
/// Create a dice roller.
/// The next item returned will be a random number in the range [0,n]
pub fn new(rng: R, n: u32) -> Self {
// If n = 0, the first number returned will always be 0
// so we don't need to generate a random number
let chunk_remaining = if n == 0 { 1 } else { 0 };
Self {
rng,
n,
chunk: 0,
chunk_remaining,
}
}

/// Returns a number in [0,n] and increments n by 1.
/// Generates new random bits as needed
/// Panics if `n >= u32::MAX`
#[inline]
pub fn next_index(&mut self) -> usize {
let next_n = self.n + 1;

// There's room for further optimisation here:
// gen_range uses rejection sampling (or other method; see #1196) to avoid bias.
// When the initial sample is biased for range 0..bound
// it may still be viable to use for a smaller bound
// (especially if small biases are considered acceptable).

let next_chunk_remaining = self.chunk_remaining.checked_sub(1).unwrap_or_else(|| {
// If the chunk is empty, generate a new chunk
let (bound, remaining) = calculate_bound_u32(next_n);
// bound = (n + 1) * (n + 2) *..* (n + remaining)
self.chunk = self.rng.gen_range(0..bound);
// Chunk is a random number in
// [0, (n + 1) * (n + 2) *..* (n + remaining) )

remaining - 1
});

let result = if next_chunk_remaining == 0 {
// `chunk` is a random number in the range [0..n+1)
// Because `chunk_remaining` is about to be set to zero
// we do not need to clear the chunk here
self.chunk as usize
} else {
// `chunk` is a random number in a range that is a multiple of n+1
// so r will be a random number in [0..n+1)
let r = self.chunk % next_n;
self.chunk /= next_n;
r as usize
Comment on lines +68 to +70

Member

There's probably also room for further optimisation here: modulus is a slow operation (see https://www.pcg-random.org/posts/bounded-rands.html).

Contributor Author

I did read that article, and it helped me find some of the optimizations I used for this. I also tried a method based on a bitmask, but it turned out about 50% slower than this. Obviously, I could easily have missed something.

};

self.chunk_remaining = next_chunk_remaining;
self.n = next_n;
result
}
}

#[inline]
/// Calculates `bound`, `count` such that bound = (m)*(m+1)*..*(m + count - 1)
fn calculate_bound_u32(m: u32) -> (u32, u8) {
debug_assert!(m > 0);
#[inline]
const fn inner(m: u32) -> (u32, u8) {
let mut product = m;
let mut current = m + 1;

loop {
if let Some(p) = u32::checked_mul(product, current) {
product = p;
current += 1;
} else {
// Count is at most 12 (reached when m is 1; it is 11 when m is 2), so it fits in a u8
let count = (current - m) as u8;
return (product, count);
}
}
}

const RESULT2: (u32, u8) = inner(2);
if m == 2 {
// Making this value a constant instead of recalculating it
// gives a significant (~50%) performance boost for small shuffles
return RESULT2;
}

inner(m)
}
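
The chunking idea in `increasing_uniform.rs` can be illustrated with a small standalone sketch. This is not the PR's code: `calculate_bound` and `split_chunk` are names chosen here for illustration, and no RNG is involved. The sketch assumes a single value drawn uniformly from [0, n * (n+1) * .. * (n+count-1)) and shows how repeated `%` and `/` turn it into `count` indices, the j-th lying in [0, n + j). Because this is a mixed-radix decomposition, every combination of indices corresponds to exactly one chunk value, which is why the indices stay uniform and independent.

```rust
/// Largest product m * (m+1) * .. * (m+count-1) that fits in a u32,
/// returned together with `count`. Requires 0 < m < u32::MAX.
fn calculate_bound(m: u32) -> (u32, u8) {
    assert!(m > 0);
    let mut product = m;
    let mut current = m + 1;
    // Keep multiplying by the next factor until the product would overflow.
    while let Some(p) = product.checked_mul(current) {
        product = p;
        current += 1;
    }
    // `current` is the first factor that did not fit.
    (product, (current - m) as u8)
}

/// Splits `chunk`, assumed to lie in [0, n * (n+1) * .. * (n+count-1)),
/// into `count` indices; the j-th index lies in [0, n + j).
fn split_chunk(mut chunk: u32, n: u32, count: u8) -> Vec<u32> {
    let mut out = Vec::with_capacity(count as usize);
    for j in 0..count as u32 {
        let modulus = n + j;
        out.push(chunk % modulus); // low "digit" in base `modulus`
        chunk /= modulus; // the remaining digits
    }
    out
}
```

For example, `calculate_bound(2)` returns `(479001600, 11)`, since 2 * 3 * .. * 12 = 12! is the last such product below 2^32, so one 32-bit draw can serve eleven consecutive swaps at the start of a shuffle.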
51 changes: 33 additions & 18 deletions src/seq/mod.rs
@@ -1,4 +1,4 @@
// Copyright 2018 Developers of the Rand project.
// Copyright 2018-2023 Developers of the Rand project.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
@@ -29,6 +29,8 @@ mod coin_flipper;
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
pub mod index;

mod increasing_uniform;

#[cfg(feature = "alloc")]
use core::ops::Index;

@@ -42,6 +44,7 @@ use crate::distributions::WeightedError;
use crate::Rng;

use self::coin_flipper::CoinFlipper;
use self::increasing_uniform::IncreasingUniform;

/// Extension trait on slices, providing random mutation and sampling methods.
///
@@ -620,10 +623,11 @@ impl<T> SliceRandom for [T] {
where
R: Rng + ?Sized,
{
for i in (1..self.len()).rev() {
// invariant: elements with index > i have been locked in place.
self.swap(i, gen_index(rng, i + 1));
if self.len() <= 1 {
// There is no need to shuffle an empty or single element slice
return;
}
self.partial_shuffle(rng, self.len());
}

fn partial_shuffle<R>(
@@ -632,19 +636,30 @@ impl<T> SliceRandom for [T] {
where
R: Rng + ?Sized,
{
// This applies Durstenfeld's algorithm for the
// [Fisher–Yates shuffle](https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_modern_algorithm)
// for an unbiased permutation, but exits early after choosing `amount`
// elements.

let len = self.len();
let end = if amount >= len { 0 } else { len - amount };
let m = self.len().saturating_sub(amount);

for i in (end..len).rev() {
// invariant: elements with index > i have been locked in place.
self.swap(i, gen_index(rng, i + 1));
// The algorithm below is based on Durstenfeld's algorithm for the
// [Fisher–Yates shuffle](https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_modern_algorithm)
// for an unbiased permutation.
// It ensures that the last `amount` elements of the slice
// are randomly selected from the whole slice.

// `IncreasingUniform::next_index()` is faster than `gen_index`
// but only works for 32 bit integers.
// So we must use the slow method if the slice is longer than that.
if self.len() < (u32::MAX as usize) {
wainwrightmark marked this conversation as resolved.
let mut chooser = IncreasingUniform::new(rng, m as u32);
for i in m..self.len() {
let index = chooser.next_index();
self.swap(i, index);
}
} else {
for i in m..self.len() {
Member

You need to reverse the iterator (both loops). Your code can only "choose" the last element of the list with probability 1/len when it should be m/len.

Contributor Author

Ermm, I'm pretty sure I've got this right. The last element gets swapped to a random place in the list, so it has an m/len probability of being in the first m elements. Earlier elements are more likely to be chosen initially but can get booted out by later ones. The test_shuffle test checks this, and I've also tried similar tests with longer lists and more runs.

The reason I don't reverse the iterator is that increasing_uniform needs i to increase, and a decreasing version would be more complicated.

Member

Okay. We previously reversed since this way the proof by induction is easier. But we can also prove this algorithm works.

First, let's not use m = len - amount since in the last PR we used m = amount. I'll continue to use end = len - amount.

Let's say we have a list elts = [e0, e1, .., ei, ..] of length len. Elements are "chosen" if they appear in elts[end..len] after the algorithm; additionally we need to show that this slice is fully shuffled.

Algorithm is:

for i in end..len {
    elts.swap(i, rng.sample_range(0..=i));
}

For any length, for amount = 0 or amount = 1, this is clearly correct. We'll prove by induction, assuming that the algorithm is already proven correct for amount-1 and len-1 (so that end does not change and the algorithm only has one last swap to perform).

Thus, for the first len-1 elements, we assume:

  • For any elt ei, we have P(ei in elts[0..end]) = end/(len-1) [here we say nothing about element order]
  • For any elt ei, for any k in end..(len-1), P(elts[k] = ei) = 1/(len-1), so P(ei in elts[end..len-1]) = (amount-1)/(len-1) [fully shuffled]

We perform the last step of the algorithm: let x = rng.sample_range(0..=len-1); elts.swap(len-1, x);. Now:

  • Any element in elts[0..end] is moved to elts[len-1] with probability 1/len, thus for any elt ei except e_last, P(ei in elts[0..end]) = end/(len-1) * (len-1)/len = end/len
  • For any k in end..(len-1), the chance that position k is not touched by the swap is (len-1)/len, thus for any elt ei except e_last, P(elts[k] = ei) = 1/(len-1) * (len-1)/len = 1/len
  • For any elt ei except e_last, P(elts[len-1] = ei) = 1/len, since ei moves there exactly when x is its current index
  • The previous two points together imply that for any elt ei except e_last, P(ei in elts[end..len]) = (amount-1+1)/len = amount/len
  • Element e_last may appear in any position with probability 1/len, so it also lands in elts[end..len] with probability amount/len

Thus each element has chance amount/len to appear in elts[end..len], and each position of this slice holds any given element with probability 1/len.

let index = gen_index(rng, i + 1);
self.swap(i, index);
}
}
let r = self.split_at_mut(end);
let r = self.split_at_mut(m);
(r.1, r.0)
}
}
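
The inductive argument in the review thread above can also be checked by brute force for small slices. The following is a minimal sketch under stated assumptions, not the PR's implementation: `partial_shuffle_with` and `enumerate_outcomes` are hypothetical helpers, and the closure `pick(i)` stands in for an RNG call that returns an index in 0..=i. Instead of sampling, the sketch enumerates every equally likely sequence of picks and counts where each element ends up.

```rust
/// Forward partial Fisher-Yates shuffle as discussed in the thread.
/// `pick(i)` stands in for the RNG and must return an index in 0..=i.
/// Returns (chosen, rest), where `chosen` is the last `amount` elements.
fn partial_shuffle_with<T>(
    elts: &mut [T],
    amount: usize,
    mut pick: impl FnMut(usize) -> usize,
) -> (&mut [T], &mut [T]) {
    let end = elts.len().saturating_sub(amount);
    for i in end..elts.len() {
        let x = pick(i);
        elts.swap(i, x);
    }
    let (rest, chosen) = elts.split_at_mut(end);
    (chosen, rest)
}

/// Runs the shuffle once for every possible sequence of picks on the
/// slice [0, 1, .., len-1]. Returns `counts[e][k]`, the number of
/// outcomes in which element `e` ends at position `k`, and the total
/// number of outcomes.
fn enumerate_outcomes(len: usize, amount: usize) -> (Vec<Vec<u64>>, u64) {
    let end = len.saturating_sub(amount);
    let mut counts = vec![vec![0u64; len]; len];
    // Step i has i + 1 possible picks, so the outcomes multiply.
    let total: u64 = (end..len).map(|i| (i + 1) as u64).product();
    for mut code in 0..total {
        let mut elts: Vec<usize> = (0..len).collect();
        // Decode `code` into one pick per step (mixed-radix digits).
        partial_shuffle_with(&mut elts, amount, |i| {
            let choices = i as u64 + 1;
            let x = (code % choices) as usize;
            code /= choices;
            x
        });
        for (k, &e) in elts.iter().enumerate() {
            counts[e][k] += 1;
        }
    }
    (counts, total)
}
```

If the argument holds, then for every element `e` and every position `k` in end..len, `counts[e][k] * len == total`, and the counts of `e` over positions 0..end sum to `total * end / len`. An exhaustive check like this only covers the small lengths it is run on; it complements the proof rather than replacing it.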
@@ -765,11 +780,11 @@ mod test {

let mut r = crate::test::rng(414);
nums.shuffle(&mut r);
assert_eq!(nums, [9, 5, 3, 10, 7, 12, 8, 11, 6, 4, 0, 2, 1]);
assert_eq!(nums, [5, 11, 0, 8, 7, 12, 6, 4, 9, 3, 1, 2, 10]);
nums = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
let res = nums.partial_shuffle(&mut r, 6);
assert_eq!(res.0, &mut [7, 4, 8, 6, 9, 3]);
assert_eq!(res.1, &mut [0, 1, 2, 12, 11, 5, 10]);
assert_eq!(res.0, &mut [7, 12, 6, 8, 1, 9]);
assert_eq!(res.1, &mut [0, 11, 2, 3, 4, 5, 10]);
}

#[derive(Clone)]