Skip to content

Commit

Permalink
util: Add rng utilities
Browse files Browse the repository at this point in the history
This adds new PRNG utilities that only use libstd and not the external
`rand` crate. This change's motivation are that in tower middleware that
need PRNG don't need the complexity and vast utilities of the `rand`
crate.

This adds a `Rng` trait which abstracts the simple PRNG features tower
needs. This also provides a `HasherRng` which uses the `RandomState`
type from libstd to generate random `u64` values. In addition, there is
an internal only `sample_inplace` which is used within the balance p2c
middleware to randomly pick a ready service. This implementation is
crate private since its quite specific to the balance implementation.

The goal of this in addition to the balance middlware getting `rand`
removed is for the upcoming `Retry` changes. The `next_f64` will be used
in the jitter portion of the backoff utilities in #685.
  • Loading branch information
LucioFranco committed Aug 23, 2022
1 parent 6d34340 commit f49ac5d
Show file tree
Hide file tree
Showing 4 changed files with 213 additions and 15 deletions.
6 changes: 4 additions & 2 deletions tower/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ full = [
]
# FIXME: Use weak dependency once available (https://github.com/rust-lang/cargo/issues/8832)
log = ["tracing/log"]
balance = ["discover", "load", "ready-cache", "make", "rand", "slab"]
balance = ["discover", "load", "ready-cache", "make", "slab"]
buffer = ["__common", "tokio/sync", "tokio/rt", "tokio-util", "tracing"]
discover = ["__common"]
filter = ["__common", "futures-util"]
Expand All @@ -72,7 +72,6 @@ futures-core = { version = "0.3", optional = true }
futures-util = { version = "0.3", default-features = false, features = ["alloc"], optional = true }
hdrhistogram = { version = "7.0", optional = true, default-features = false }
indexmap = { version = "1.0.2", optional = true }
rand = { version = "0.8", features = ["small_rng"], optional = true }
slab = { version = "0.4", optional = true }
tokio = { version = "1.6", optional = true, features = ["sync"] }
tokio-stream = { version = "0.1.0", optional = true }
Expand All @@ -88,9 +87,12 @@ tokio = { version = "1.6.2", features = ["macros", "sync", "test-util", "rt-mult
tokio-stream = "0.1"
tokio-test = "0.4"
tower-test = { version = "0.4", path = "../tower-test" }
tracing = { version = "0.1.2", default-features = false, features = ["std"] }
tracing-subscriber = { version = "0.3", default-features = false, features = ["fmt", "ansi"] }
http = "0.2"
lazy_static = "1.4.0"
rand = { version = "0.8", features = ["small_rng"] }
quickcheck = "1"

[package.metadata.docs.rs]
all-features = true
Expand Down
26 changes: 13 additions & 13 deletions tower/src/balance/p2c/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ use super::super::error;
use crate::discover::{Change, Discover};
use crate::load::Load;
use crate::ready_cache::{error::Failed, ReadyCache};
use crate::util::rng::{sample_inplace, HasherRng, Rng};
use futures_core::ready;
use futures_util::future::{self, TryFutureExt};
use pin_project_lite::pin_project;
use rand::{rngs::SmallRng, Rng, SeedableRng};
use std::hash::Hash;
use std::marker::PhantomData;
use std::{
Expand Down Expand Up @@ -39,7 +39,7 @@ where
services: ReadyCache<D::Key, D::Service, Req>,
ready_index: Option<usize>,

rng: SmallRng,
rng: Box<dyn Rng + Send + Sync>,

_req: PhantomData<Req>,
}
Expand Down Expand Up @@ -86,20 +86,20 @@ where
{
/// Constructs a load balancer that uses operating system entropy.
pub fn new(discover: D) -> Self {
Self::from_rng(discover, &mut rand::thread_rng()).expect("ThreadRNG must be valid")
Self::from_rng(discover, HasherRng::default())
}

/// Constructs a load balancer seeded with the provided random number generator.
pub fn from_rng<R: Rng>(discover: D, rng: R) -> Result<Self, rand::Error> {
let rng = SmallRng::from_rng(rng)?;
Ok(Self {
pub fn from_rng<R: Rng + Send + Sync + 'static>(discover: D, rng: R) -> Self {
let rng = Box::new(rng);
Self {
rng,
discover,
services: ReadyCache::default(),
ready_index: None,

_req: PhantomData,
})
}
}

/// Returns the number of endpoints currently tracked by the balancer.
Expand Down Expand Up @@ -185,14 +185,14 @@ where
len => {
// Get two distinct random indexes (in a random order) and
// compare the loads of the service at each index.
let idxs = rand::seq::index::sample(&mut self.rng, len, 2);
let idxs = sample_inplace(&mut self.rng, len as u32, 2);

let aidx = idxs.index(0);
let bidx = idxs.index(1);
let aidx = idxs[0];
let bidx = idxs[1];
debug_assert_ne!(aidx, bidx, "random indices must be distinct");

let aload = self.ready_index_load(aidx);
let bload = self.ready_index_load(bidx);
let aload = self.ready_index_load(aidx as usize);
let bload = self.ready_index_load(bidx as usize);
let chosen = if aload <= bload { aidx } else { bidx };

trace!(
Expand All @@ -203,7 +203,7 @@ where
chosen = if chosen == aidx { "a" } else { "b" },
"p2c",
);
Some(chosen)
Some(chosen as usize)
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions tower/src/util/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ mod ready;
mod service_fn;
mod then;

pub mod rng;

pub use self::{
and_then::{AndThen, AndThenLayer},
boxed::{BoxLayer, BoxService, UnsyncBoxService},
Expand Down
194 changes: 194 additions & 0 deletions tower/src/util/rng.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
//! Utilities for generating random numbers.
//!
//! This module provides a generic [`Rng`] trait and a [`HasherRng`] that
//! implements the trait based on [`RandomState`] or any other [`Hasher`].
//!
//! These utlities replace tower's internal usage of `rand` with these smaller
//! more light weight methods. Most of the implemenations are extracted from
//! their corresponding `rand` implementations.
use std::{
collections::hash_map::RandomState,
hash::{BuildHasher, Hasher},
ops::Range,
};

/// A simple [`PRNG`] trait for use within tower middleware.
pub trait Rng {
/// Generate a random [`u64`].
fn next_u64(&mut self) -> u64;

/// Generate a random [`f64`] between `[0, 1)`.
fn next_f64(&mut self) -> f64 {
// Borrowed from:
// https://github.com/rust-random/rand/blob/master/src/distributions/float.rs#L106
let float_size = std::mem::size_of::<f64>() as u32 * 8;
let precison = 52 + 1;
let scale = 1.0 / ((1u64 << precison) as f64);

let value = self.next_u64();
let value = value >> (float_size - precison);

scale * value as f64
}

/// Randomly pick a value within the range.
///
/// # Panic
///
/// - If start < end this will panic in debug mode.
fn next_range(&mut self, range: Range<u64>) -> u64 {
debug_assert!(
range.start < range.end,
"The range start must be smaller than the end"
);
let start = range.start;
let end = range.end;

let range = end - start;

let n = self.next_u64();

(n % range) + start
}
}

impl<R: Rng + ?Sized> Rng for Box<R> {
fn next_u64(&mut self) -> u64 {
(**self).next_u64()
}
}

/// A [`Rng`] implementation that uses a [`Hasher`] to generate the random
/// values. The implementation uses an internal counter to pass to the hasher
/// for each iteration of [`Rng::next_u64`].
///
/// # Default
///
/// This hasher has a default type of [`RandomState`] which just uses the
/// libstd method of getting a random u64.
#[derive(Debug)]
pub struct HasherRng<H = RandomState> {
hasher: H,
counter: u64,
}

impl HasherRng {
/// Create a new default [`HasherRng`].
pub fn new() -> Self {
HasherRng::default()
}
}

impl Default for HasherRng {
fn default() -> Self {
HasherRng::with_hasher(RandomState::default())
}
}

impl<H> HasherRng<H> {
/// Create a new [`HasherRng`] with the provided hasher.
pub fn with_hasher(hasher: H) -> Self {
HasherRng { hasher, counter: 0 }
}
}

impl<H> Rng for HasherRng<H>
where
H: BuildHasher,
{
fn next_u64(&mut self) -> u64 {
let mut hasher = self.hasher.build_hasher();
hasher.write_u64(self.counter);
self.counter = self.counter.wrapping_add(1);
hasher.finish()
}
}

/// An inplace sampler borrowed from the Rand implementation for use internally
/// for the balance middleware.
/// ref: https://github.com/rust-random/rand/blob/b73640705d6714509f8ceccc49e8df996fa19f51/src/seq/index.rs#L425
///
/// Docs from rand:
///
/// Randomly sample exactly `amount` indices from `0..length`, using an inplace
/// partial Fisher-Yates method.
/// Sample an amount of indices using an inplace partial fisher yates method.
///
/// This allocates the entire `length` of indices and randomizes only the first `amount`.
/// It then truncates to `amount` and returns.
///
/// This method is not appropriate for large `length` and potentially uses a lot
/// of memory; because of this we only implement for `u32` index (which improves
/// performance in all cases).
///
/// Set-up is `O(length)` time and memory and shuffling is `O(amount)` time.
pub(crate) fn sample_inplace<R: Rng>(rng: &mut R, length: u32, amount: u32) -> Vec<u32> {
debug_assert!(amount <= length);
let mut indices: Vec<u32> = Vec::with_capacity(length as usize);
indices.extend(0..length);
for i in 0..amount {
let j: u64 = rng.next_range(i as u64..length as u64);
indices.swap(i as usize, j as usize);
}
indices.truncate(amount as usize);
debug_assert_eq!(indices.len(), amount as usize);
indices
}

#[cfg(test)]
mod tests {
use super::*;
use quickcheck::*;

quickcheck! {
fn next_f64(counter: u64) -> TestResult {
let mut rng = HasherRng::default();
rng.counter = counter;
let n = rng.next_f64();

TestResult::from_bool(n < 1.0 && n >= 0.0)
}

fn next_range(counter: u64, range: Range<u64>) -> TestResult {
if range.start >= range.end{
return TestResult::discard();
}

let mut rng = HasherRng::default();
rng.counter = counter;

let n = rng.next_range(range.clone());

TestResult::from_bool(n >= range.start && (n < range.end || range.start == range.end))
}

fn sample_inplace(counter: u64, length: u32, amount: u32) -> TestResult {
if amount > length || length > u32::MAX {
return TestResult::discard();
}

let mut rng = HasherRng::default();
rng.counter = counter;

let indxs = super::sample_inplace(&mut rng, length, amount);

for indx in indxs {
if indx > length {
return TestResult::failed();
}
}

TestResult::passed()
}
}

#[test]
fn sample_inplace_boundaries() {
let mut r = HasherRng::default();

assert_eq!(super::sample_inplace(&mut r, 0, 0).len(), 0);
assert_eq!(super::sample_inplace(&mut r, 1, 0).len(), 0);
assert_eq!(super::sample_inplace(&mut r, 1, 1), vec![0]);
}
}

0 comments on commit f49ac5d

Please sign in to comment.