diff --git a/Cargo.toml b/Cargo.toml index 19f7573851..ec0e4d7767 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -79,4 +79,9 @@ criterion = { version = "0.4" } [[bench]] name = "seq_choose" path = "benches/seq_choose.rs" +harness = false + +[[bench]] +name = "shuffle" +path = "benches/shuffle.rs" harness = false \ No newline at end of file diff --git a/benches/seq_choose.rs b/benches/seq_choose.rs index 44b4bdf972..2c34d77ced 100644 --- a/benches/seq_choose.rs +++ b/benches/seq_choose.rs @@ -1,4 +1,4 @@ -// Copyright 2018-2022 Developers of the Rand project. +// Copyright 2018-2023 Developers of the Rand project. // // Licensed under the Apache License, Version 2.0 or the MIT license diff --git a/benches/shuffle.rs b/benches/shuffle.rs new file mode 100644 index 0000000000..3d6878219f --- /dev/null +++ b/benches/shuffle.rs @@ -0,0 +1,50 @@ +// Copyright 2018-2023 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use rand::prelude::*; +use rand::SeedableRng; + +criterion_group!( +name = benches; +config = Criterion::default(); +targets = bench +); +criterion_main!(benches); + +pub fn bench(c: &mut Criterion) { + bench_rng::(c, "ChaCha12"); + bench_rng::(c, "Pcg32"); + bench_rng::(c, "Pcg64"); +} + +fn bench_rng(c: &mut Criterion, rng_name: &'static str) { + for length in [1, 2, 3, 10, 100, 1000, 10000].map(|x| black_box(x)) { + c.bench_function(format!("shuffle_{length}_{rng_name}").as_str(), |b| { + let mut rng = Rng::seed_from_u64(123); + let mut vec: Vec = (0..length).collect(); + b.iter(|| { + vec.shuffle(&mut rng); + vec[0] + }) + }); + + if length >= 10 { + c.bench_function( + format!("partial_shuffle_{length}_{rng_name}").as_str(), + |b| { + let mut rng = Rng::seed_from_u64(123); + let mut vec: Vec = (0..length).collect(); + b.iter(|| { + vec.partial_shuffle(&mut rng, length / 2); + vec[0] + }) + }, + ); + } + } +} diff --git a/src/seq/coin_flipper.rs b/src/seq/coin_flipper.rs index 77c18ded43..05f18d71b2 100644 --- a/src/seq/coin_flipper.rs +++ b/src/seq/coin_flipper.rs @@ -1,3 +1,11 @@ +// Copyright 2018-2023 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + use crate::RngCore; pub(crate) struct CoinFlipper { diff --git a/src/seq/increasing_uniform.rs b/src/seq/increasing_uniform.rs new file mode 100644 index 0000000000..3208c656fb --- /dev/null +++ b/src/seq/increasing_uniform.rs @@ -0,0 +1,108 @@ +// Copyright 2018-2023 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use crate::{Rng, RngCore}; + +/// Similar to a Uniform distribution, +/// but after returning a number in the range [0,n], n is increased by 1. +pub(crate) struct IncreasingUniform { + pub rng: R, + n: u32, + // Chunk is a random number in [0, (n + 1) * (n + 2) *..* (n + chunk_remaining) ) + chunk: u32, + chunk_remaining: u8, +} + +impl IncreasingUniform { + /// Create a dice roller. + /// The next item returned will be a random number in the range [0,n] + pub fn new(rng: R, n: u32) -> Self { + // If n = 0, the first number returned will always be 0 + // so we don't need to generate a random number + let chunk_remaining = if n == 0 { 1 } else { 0 }; + Self { + rng, + n, + chunk: 0, + chunk_remaining, + } + } + + /// Returns a number in [0,n] and increments n by 1. + /// Generates new random bits as needed + /// Panics if `n >= u32::MAX` + #[inline] + pub fn next_index(&mut self) -> usize { + let next_n = self.n + 1; + + // There's room for further optimisation here: + // gen_range uses rejection sampling (or other method; see #1196) to avoid bias. + // When the initial sample is biased for range 0..bound + // it may still be viable to use for a smaller bound + // (especially if small biases are considered acceptable). + + let next_chunk_remaining = self.chunk_remaining.checked_sub(1).unwrap_or_else(|| { + // If the chunk is empty, generate a new chunk + let (bound, remaining) = calculate_bound_u32(next_n); + // bound = (n + 1) * (n + 2) *..* (n + remaining) + self.chunk = self.rng.gen_range(0..bound); + // Chunk is a random number in + // [0, (n + 1) * (n + 2) *..* (n + remaining) ) + + remaining - 1 + }); + + let result = if next_chunk_remaining == 0 { + // `chunk` is a random number in the range [0..n+1) + // Because `chunk_remaining` is about to be set to zero + // we do not need to clear the chunk here + self.chunk as usize + } else { + // `chunk` is a random number in a range that is a multiple of n+1 + // so r will be a random number in [0..n+1) + let r = self.chunk % next_n; + self.chunk /= next_n; + r as usize + }; + + self.chunk_remaining = next_chunk_remaining; + self.n = next_n; + result + } +} + +#[inline] +/// Calculates `bound`, `count` such that bound (m)*(m+1)*..*(m + remaining - 1) +fn calculate_bound_u32(m: u32) -> (u32, u8) { + debug_assert!(m > 0); + #[inline] + const fn inner(m: u32) -> (u32, u8) { + let mut product = m; + let mut current = m + 1; + + loop { + if let Some(p) = u32::checked_mul(product, current) { + product = p; + current += 1; + } else { + // Count has a maximum value of 13 for when min is 1 or 2 + let count = (current - m) as u8; + return (product, count); + } + } + } + + const RESULT2: (u32, u8) = inner(2); + if m == 2 { + // Making this value a constant instead of recalculating it + // gives a significant (~50%) performance boost for small shuffles + return RESULT2; + } + + inner(m) +} diff --git a/src/seq/mod.rs b/src/seq/mod.rs index e1286105c5..d9b38e920d 100644 --- a/src/seq/mod.rs +++ b/src/seq/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2018 Developers of the Rand project. +// Copyright 2018-2023 Developers of the Rand project. // // Licensed under the Apache License, Version 2.0 or the MIT license @@ -29,6 +29,8 @@ mod coin_flipper; #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] pub mod index; +mod increasing_uniform; + #[cfg(feature = "alloc")] use core::ops::Index; @@ -42,6 +44,7 @@ use crate::distributions::WeightedError; use crate::Rng; use self::coin_flipper::CoinFlipper; +use self::increasing_uniform::IncreasingUniform; /// Extension trait on slices, providing random mutation and sampling methods. /// @@ -620,10 +623,11 @@ impl SliceRandom for [T] { where R: Rng + ?Sized, { - for i in (1..self.len()).rev() { - // invariant: elements with index > i have been locked in place. - self.swap(i, gen_index(rng, i + 1)); + if self.len() <= 1 { + // There is no need to shuffle an empty or single element slice + return; } + self.partial_shuffle(rng, self.len()); } fn partial_shuffle( @@ -632,19 +636,30 @@ impl SliceRandom for [T] { where R: Rng + ?Sized, { - // This applies Durstenfeld's algorithm for the - // [Fisher–Yates shuffle](https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_modern_algorithm) - // for an unbiased permutation, but exits early after choosing `amount` - // elements. - - let len = self.len(); - let end = if amount >= len { 0 } else { len - amount }; + let m = self.len().saturating_sub(amount); - for i in (end..len).rev() { - // invariant: elements with index > i have been locked in place. - self.swap(i, gen_index(rng, i + 1)); + // The algorithm below is based on Durstenfeld's algorithm for the + // [Fisher–Yates shuffle](https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_modern_algorithm) + // for an unbiased permutation. + // It ensures that the last `amount` elements of the slice + // are randomly selected from the whole slice. + + //`IncreasingUniform::next_index()` is faster than `gen_index` + //but only works for 32 bit integers + //So we must use the slow method if the slice is longer than that. + if self.len() < (u32::MAX as usize) { + let mut chooser = IncreasingUniform::new(rng, m as u32); + for i in m..self.len() { + let index = chooser.next_index(); + self.swap(i, index); + } + } else { + for i in m..self.len() { + let index = gen_index(rng, i + 1); + self.swap(i, index); + } } - let r = self.split_at_mut(end); + let r = self.split_at_mut(m); (r.1, r.0) } } @@ -765,11 +780,11 @@ mod test { let mut r = crate::test::rng(414); nums.shuffle(&mut r); - assert_eq!(nums, [9, 5, 3, 10, 7, 12, 8, 11, 6, 4, 0, 2, 1]); + assert_eq!(nums, [5, 11, 0, 8, 7, 12, 6, 4, 9, 3, 1, 2, 10]); nums = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]; let res = nums.partial_shuffle(&mut r, 6); - assert_eq!(res.0, &mut [7, 4, 8, 6, 9, 3]); - assert_eq!(res.1, &mut [0, 1, 2, 12, 11, 5, 10]); + assert_eq!(res.0, &mut [7, 12, 6, 8, 1, 9]); + assert_eq!(res.1, &mut [0, 11, 2, 3, 4, 5, 10]); } #[derive(Clone)]