diff --git a/tower/Cargo.toml b/tower/Cargo.toml index f80b7b1fb..1c373f15a 100644 --- a/tower/Cargo.toml +++ b/tower/Cargo.toml @@ -47,7 +47,7 @@ full = [ ] # FIXME: Use weak dependency once available (https://github.com/rust-lang/cargo/issues/8832) log = ["tracing/log"] -balance = ["discover", "load", "ready-cache", "make", "rand", "slab"] +balance = ["discover", "load", "ready-cache", "make", "slab"] buffer = ["__common", "tokio/sync", "tokio/rt", "tokio-util", "tracing"] discover = ["__common"] filter = ["__common", "futures-util"] @@ -72,7 +72,6 @@ futures-core = { version = "0.3", optional = true } futures-util = { version = "0.3", default-features = false, features = ["alloc"], optional = true } hdrhistogram = { version = "7.0", optional = true, default-features = false } indexmap = { version = "1.0.2", optional = true } -rand = { version = "0.8", features = ["small_rng"], optional = true } slab = { version = "0.4", optional = true } tokio = { version = "1.6", optional = true, features = ["sync"] } tokio-stream = { version = "0.1.0", optional = true } @@ -88,9 +87,12 @@ tokio = { version = "1.6.2", features = ["macros", "sync", "test-util", "rt-mult tokio-stream = "0.1" tokio-test = "0.4" tower-test = { version = "0.4", path = "../tower-test" } +tracing = { version = "0.1.2", default-features = false, features = ["std"] } tracing-subscriber = { version = "0.3", default-features = false, features = ["fmt", "ansi"] } http = "0.2" lazy_static = "1.4.0" +rand = { version = "0.8", features = ["small_rng"] } +quickcheck = "1" [package.metadata.docs.rs] all-features = true diff --git a/tower/src/balance/p2c/service.rs b/tower/src/balance/p2c/service.rs index 48c6b45f1..a836b4960 100644 --- a/tower/src/balance/p2c/service.rs +++ b/tower/src/balance/p2c/service.rs @@ -2,10 +2,10 @@ use super::super::error; use crate::discover::{Change, Discover}; use crate::load::Load; use crate::ready_cache::{error::Failed, ReadyCache}; +use crate::util::rng::{sample_inplace, HasherRng, Rng}; use futures_core::ready; use futures_util::future::{self, TryFutureExt}; use pin_project_lite::pin_project; -use rand::{rngs::SmallRng, Rng, SeedableRng}; use std::hash::Hash; use std::marker::PhantomData; use std::{ @@ -39,7 +39,7 @@ where services: ReadyCache, ready_index: Option, - rng: SmallRng, + rng: Box, _req: PhantomData, } @@ -86,20 +86,20 @@ where { /// Constructs a load balancer that uses operating system entropy. pub fn new(discover: D) -> Self { - Self::from_rng(discover, &mut rand::thread_rng()).expect("ThreadRNG must be valid") + Self::from_rng(discover, HasherRng::default()) } /// Constructs a load balancer seeded with the provided random number generator. - pub fn from_rng(discover: D, rng: R) -> Result { - let rng = SmallRng::from_rng(rng)?; - Ok(Self { + pub fn from_rng(discover: D, rng: R) -> Self { + let rng = Box::new(rng); + Self { rng, discover, services: ReadyCache::default(), ready_index: None, _req: PhantomData, - }) + } } /// Returns the number of endpoints currently tracked by the balancer. @@ -185,14 +185,14 @@ where len => { // Get two distinct random indexes (in a random order) and // compare the loads of the service at each index. - let idxs = rand::seq::index::sample(&mut self.rng, len, 2); + let idxs = sample_inplace(&mut self.rng, len as u32, 2); - let aidx = idxs.index(0); - let bidx = idxs.index(1); + let aidx = idxs[0]; + let bidx = idxs[1]; debug_assert_ne!(aidx, bidx, "random indices must be distinct"); - let aload = self.ready_index_load(aidx); - let bload = self.ready_index_load(bidx); + let aload = self.ready_index_load(aidx as usize); + let bload = self.ready_index_load(bidx as usize); let chosen = if aload <= bload { aidx } else { bidx }; trace!( @@ -203,7 +203,7 @@ where chosen = if chosen == aidx { "a" } else { "b" }, "p2c", ); - Some(chosen) + Some(chosen as usize) } } } diff --git a/tower/src/util/mod.rs b/tower/src/util/mod.rs index dddf8ed7a..71253e99d 100644 --- a/tower/src/util/mod.rs +++ b/tower/src/util/mod.rs @@ -19,6 +19,8 @@ mod ready; mod service_fn; mod then; +pub mod rng; + pub use self::{ and_then::{AndThen, AndThenLayer}, boxed::{BoxLayer, BoxService, UnsyncBoxService}, diff --git a/tower/src/util/rng.rs b/tower/src/util/rng.rs new file mode 100644 index 000000000..2c58358fb --- /dev/null +++ b/tower/src/util/rng.rs @@ -0,0 +1,194 @@ +//! Utilities for generating random numbers. +//! +//! This module provides a generic [`Rng`] trait and a [`HasherRng`] that +//! implements the trait based on [`RandomState`] or any other [`Hasher`]. +//! +//! These utlities replace tower's internal usage of `rand` with these smaller +//! more light weight methods. Most of the implemenations are extracted from +//! their corresponding `rand` implementations. + +use std::{ + collections::hash_map::RandomState, + hash::{BuildHasher, Hasher}, + ops::Range, +}; + +/// A simple [`PRNG`] trait for use within tower middleware. +pub trait Rng { + /// Generate a random [`u64`]. + fn next_u64(&mut self) -> u64; + + /// Generate a random [`f64`] between `[0, 1)`. + fn next_f64(&mut self) -> f64 { + // Borrowed from: + // https://github.com/rust-random/rand/blob/master/src/distributions/float.rs#L106 + let float_size = std::mem::size_of::() as u32 * 8; + let precison = 52 + 1; + let scale = 1.0 / ((1u64 << precison) as f64); + + let value = self.next_u64(); + let value = value >> (float_size - precison); + + scale * value as f64 + } + + /// Randomly pick a value within the range. + /// + /// # Panic + /// + /// - If start < end this will panic in debug mode. + fn next_range(&mut self, range: Range) -> u64 { + debug_assert!( + range.start < range.end, + "The range start must be smaller than the end" + ); + let start = range.start; + let end = range.end; + + let range = end - start; + + let n = self.next_u64(); + + (n % range) + start + } +} + +impl Rng for Box { + fn next_u64(&mut self) -> u64 { + (**self).next_u64() + } +} + +/// A [`Rng`] implementation that uses a [`Hasher`] to generate the random +/// values. The implementation uses an internal counter to pass to the hasher +/// for each iteration of [`Rng::next_u64`]. +/// +/// # Default +/// +/// This hasher has a default type of [`RandomState`] which just uses the +/// libstd method of getting a random u64. +#[derive(Debug)] +pub struct HasherRng { + hasher: H, + counter: u64, +} + +impl HasherRng { + /// Create a new default [`HasherRng`]. + pub fn new() -> Self { + HasherRng::default() + } +} + +impl Default for HasherRng { + fn default() -> Self { + HasherRng::with_hasher(RandomState::default()) + } +} + +impl HasherRng { + /// Create a new [`HasherRng`] with the provided hasher. + pub fn with_hasher(hasher: H) -> Self { + HasherRng { hasher, counter: 0 } + } +} + +impl Rng for HasherRng +where + H: BuildHasher, +{ + fn next_u64(&mut self) -> u64 { + let mut hasher = self.hasher.build_hasher(); + hasher.write_u64(self.counter); + self.counter = self.counter.wrapping_add(1); + hasher.finish() + } +} + +/// An inplace sampler borrowed from the Rand implementation for use internally +/// for the balance middleware. +/// ref: https://github.com/rust-random/rand/blob/b73640705d6714509f8ceccc49e8df996fa19f51/src/seq/index.rs#L425 +/// +/// Docs from rand: +/// +/// Randomly sample exactly `amount` indices from `0..length`, using an inplace +/// partial Fisher-Yates method. +/// Sample an amount of indices using an inplace partial fisher yates method. +/// +/// This allocates the entire `length` of indices and randomizes only the first `amount`. +/// It then truncates to `amount` and returns. +/// +/// This method is not appropriate for large `length` and potentially uses a lot +/// of memory; because of this we only implement for `u32` index (which improves +/// performance in all cases). +/// +/// Set-up is `O(length)` time and memory and shuffling is `O(amount)` time. +pub(crate) fn sample_inplace(rng: &mut R, length: u32, amount: u32) -> Vec { + debug_assert!(amount <= length); + let mut indices: Vec = Vec::with_capacity(length as usize); + indices.extend(0..length); + for i in 0..amount { + let j: u64 = rng.next_range(i as u64..length as u64); + indices.swap(i as usize, j as usize); + } + indices.truncate(amount as usize); + debug_assert_eq!(indices.len(), amount as usize); + indices +} + +#[cfg(test)] +mod tests { + use super::*; + use quickcheck::*; + + quickcheck! { + fn next_f64(counter: u64) -> TestResult { + let mut rng = HasherRng::default(); + rng.counter = counter; + let n = rng.next_f64(); + + TestResult::from_bool(n < 1.0 && n >= 0.0) + } + + fn next_range(counter: u64, range: Range) -> TestResult { + if range.start >= range.end{ + return TestResult::discard(); + } + + let mut rng = HasherRng::default(); + rng.counter = counter; + + let n = rng.next_range(range.clone()); + + TestResult::from_bool(n >= range.start && (n < range.end || range.start == range.end)) + } + + fn sample_inplace(counter: u64, length: u32, amount: u32) -> TestResult { + if amount > length || length > u32::MAX { + return TestResult::discard(); + } + + let mut rng = HasherRng::default(); + rng.counter = counter; + + let indxs = super::sample_inplace(&mut rng, length, amount); + + for indx in indxs { + if indx > length { + return TestResult::failed(); + } + } + + TestResult::passed() + } + } + + #[test] + fn sample_inplace_boundaries() { + let mut r = HasherRng::default(); + + assert_eq!(super::sample_inplace(&mut r, 0, 0).len(), 0); + assert_eq!(super::sample_inplace(&mut r, 1, 0).len(), 0); + assert_eq!(super::sample_inplace(&mut r, 1, 1), vec![0]); + } +}