diff --git a/src/impls.rs b/src/impls.rs
index 7c00def6f2..64a3702ccb 100644
--- a/src/impls.rs
+++ b/src/impls.rs
@@ -20,14 +20,13 @@
 //! non-reproducible sources (e.g. `OsRng`) need not bother with it.
 
 // TODO: eventually these should be exported somehow
-#![allow(unused)]
 
 use core::intrinsics::transmute;
 use core::ptr::copy_nonoverlapping;
-use core::slice;
+use core::{fmt, slice};
 use core::cmp::min;
 use core::mem::size_of;
-use RngCore;
+use {RngCore, BlockRngCore, CryptoRng, SeedableRng, Error};
 
 /// Implement `next_u64` via `next_u32`, little-endian order.
 pub fn next_u64_via_u32<R: RngCore + ?Sized>(rng: &mut R) -> u64 {
@@ -167,4 +166,182 @@ pub fn next_u64_via_fill<R: RngCore + ?Sized>(rng: &mut R) -> u64 {
     impl_uint_from_fill!(rng, u64, 8)
 }
 
+/// Wrapper around PRNGs that implement [`BlockRngCore`] to keep a results
+/// buffer and offer the methods from [`RngCore`].
+///
+/// `BlockRng` has optimized methods to read from the output array that the
+/// algorithms of many cryptographic RNGs generate natively. It also handles
+/// the bookkeeping of when to generate a new batch of values.
+///
+/// `next_u32` simply indexes the array. `next_u64` tries to read two `u32`
+/// values at a time if possible, and handles edge cases like when only one
+/// value is left. `try_fill_bytes` goes one step further and, where
+/// possible, uses the [`BlockRngCore`] implementation to write the results
+/// directly to the destination slice. No generated values are ever thrown
+/// away.
+///
+/// Although `BlockRngCore::generate` can return a `Result`, we assume all
+/// PRNGs to be infallible, and the `Result` to have only a signaling
+/// function. Therefore, an error is only reported by `try_fill_bytes`; all
+/// other methods squelch it.
+///
+/// For easy initialization `BlockRng` also implements [`SeedableRng`].
+///
+/// [`BlockRngCore`]: ../BlockRngCore.t.html
+/// [`RngCore`]: ../RngCore.t.html
+/// [`SeedableRng`]: ../SeedableRng.t.html
+#[derive(Clone)]
+pub struct BlockRng<R: BlockRngCore<u32>> {
+    pub core: R,
+    pub results: R::Results,
+    pub index: usize,
+}
+
+// Custom Debug implementation that does not expose the contents of `results`.
+impl<R: BlockRngCore<u32> + fmt::Debug> fmt::Debug for BlockRng<R> {
+    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
+        fmt.debug_struct("BlockRng")
+           .field("core", &self.core)
+           .field("result_len", &self.results.as_ref().len())
+           .field("index", &self.index)
+           .finish()
+    }
+}
+
+impl<R: BlockRngCore<u32>> RngCore for BlockRng<R> {
+    #[inline(always)]
+    fn next_u32(&mut self) -> u32 {
+        if self.index >= self.results.as_ref().len() {
+            let _ = self.core.generate(&mut self.results);
+            self.index = 0;
+        }
+
+        let value = self.results.as_ref()[self.index];
+        self.index += 1;
+        value
+    }
+
+    #[inline(always)]
+    fn next_u64(&mut self) -> u64 {
+        let len = self.results.as_ref().len();
+
+        let index = self.index;
+        if index < len - 1 {
+            self.index += 2;
+            // Read a u64 from the current index
+            if cfg!(any(target_arch = "x86", target_arch = "x86_64")) {
+                unsafe { *(&self.results.as_ref()[index] as *const u32 as *const u64) }
+            } else {
+                let x = self.results.as_ref()[index] as u64;
+                let y = self.results.as_ref()[index + 1] as u64;
+                (y << 32) | x
+            }
+        } else if index >= len {
+            // Buffer exhausted: generate a new batch and read the first two
+            // values.
+            let _ = self.core.generate(&mut self.results);
+            self.index = 2;
+            let x = self.results.as_ref()[0] as u64;
+            let y = self.results.as_ref()[1] as u64;
+            (y << 32) | x
+        } else {
+            // One value left: use it as the low word, and generate a new
+            // batch for the high word.
+            let x = self.results.as_ref()[len - 1] as u64;
+            let _ = self.core.generate(&mut self.results);
+            self.index = 1;
+            let y = self.results.as_ref()[0] as u64;
+            (y << 32) | x
+        }
+    }
+
+    fn fill_bytes(&mut self, dest: &mut [u8]) {
+        let _ = self.try_fill_bytes(dest);
+    }
+
+    // As an optimization we try to write directly into the output buffer.
+    // This is only enabled for platforms where unaligned writes are known to
+    // be safe and fast.
+    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
+    fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> {
+        let mut filled = 0;
+        let mut res = Ok(());
+
+        // Continue filling from the current set of results
+        if self.index < self.results.as_ref().len() {
+            let (consumed_u32, filled_u8) =
+                fill_via_u32_chunks(&self.results.as_ref()[self.index..],
+                                    dest);
+
+            self.index += consumed_u32;
+            filled += filled_u8;
+        }
+        // If `dest` was satisfied from the buffer alone, keep the remaining
+        // values for later instead of throwing them away.
+        if filled == dest.len() { return res; }
+
+        let len_remainder =
+            (dest.len() - filled) % (self.results.as_ref().len() * 4);
+        let len_direct = dest.len() - len_remainder;
+
+        while filled < len_direct {
+            let dest_u32: &mut R::Results = unsafe {
+                ::core::mem::transmute(dest[filled..].as_mut_ptr())
+            };
+            let res2 = self.core.generate(dest_u32);
+            if res2.is_err() && res.is_ok() { res = res2 };
+            filled += self.results.as_ref().len() * 4;
+        }
+        self.index = self.results.as_ref().len();
+
+        if len_remainder > 0 {
+            let res2 = self.core.generate(&mut self.results);
+            if res2.is_err() && res.is_ok() { res = res2 };
+
+            let (consumed_u32, _) =
+                fill_via_u32_chunks(self.results.as_ref(),
+                                    &mut dest[filled..]);
+
+            self.index = consumed_u32;
+        }
+        res
+    }
+
+    #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
+    fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> {
+        let mut res = Ok(());
+        let mut read_len = 0;
+        while read_len < dest.len() {
+            if self.index >= self.results.as_ref().len() {
+                let res2 = self.core.generate(&mut self.results);
+                if res2.is_err() && res.is_ok() { res = res2 };
+                self.index = 0;
+            }
+            let (consumed_u32, filled_u8) =
+                fill_via_u32_chunks(&self.results.as_ref()[self.index..],
+                                    &mut dest[read_len..]);
+
+            self.index += consumed_u32;
+            read_len += filled_u8;
+        }
+        res
+    }
+}
+
+impl<R: BlockRngCore<u32> + SeedableRng> SeedableRng for BlockRng<R> {
+    type Seed = R::Seed;
+
+    fn from_seed(seed: Self::Seed) -> Self {
+        let results_empty = R::Results::default();
+        Self {
+            core: R::from_seed(seed),
+            index: results_empty.as_ref().len(), // generate on first use
+            results: results_empty,
+        }
+    }
+
+    fn from_rng<RNG: RngCore>(rng: &mut RNG) -> Result<Self, Error> {
+        let results_empty = R::Results::default();
+        Ok(Self {
+            core: R::from_rng(rng)?,
+            index: results_empty.as_ref().len(), // generate on first use
+            results: results_empty,
+        })
+    }
+}
+
+impl<R: BlockRngCore<u32> + CryptoRng> CryptoRng for BlockRng<R> {}
+
 // TODO: implement tests for the above
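
As a sketch of how a backend plugs into `BlockRng` (and of the kind of test the TODO above asks for), here is a minimal example. `CounterCore`, its `state` field, and the test itself are made up for illustration; the `BlockRngCore<u32>` bounds it relies on (`Results: AsRef<[u32]> + Default`, `generate` returning `Result<(), Error>`) are inferred from how this diff uses the trait, not taken from its actual definition elsewhere in the crate.

```rust
use {BlockRngCore, Error, RngCore};
use impls::BlockRng;

/// Toy generator that fills its results buffer with a running counter.
/// Not remotely random; it only exercises the buffering and bookkeeping.
struct CounterCore {
    state: u32,
}

impl BlockRngCore<u32> for CounterCore {
    // `[u32; 16]` satisfies the assumed `AsRef<[u32]> + Default` bounds.
    type Results = [u32; 16];

    fn generate(&mut self, results: &mut Self::Results) -> Result<(), Error> {
        for r in results.iter_mut() {
            self.state = self.state.wrapping_add(1);
            *r = self.state;
        }
        Ok(()) // infallible; the `Result` is purely a signal
    }
}

#[test]
fn block_rng_bookkeeping() {
    let mut rng = BlockRng {
        core: CounterCore { state: 0 },
        results: [0u32; 16],
        index: 16, // empty buffer: forces `generate` on first use
    };

    // Triggers `generate`, then reads results[0].
    assert_eq!(rng.next_u32(), 1);
    // Reads results[1] and results[2], combined little-endian.
    assert_eq!(rng.next_u64(), (3u64 << 32) | 2);

    // Continues from results[3]; no generated values are skipped.
    let mut buf = [0u8; 8];
    rng.fill_bytes(&mut buf);
    assert_eq!(buf, [4, 0, 0, 0, 5, 0, 0, 0]);
}
```

The `index: 16` initialization uses the same trick as `from_seed`/`from_rng`: an out-of-range index marks the buffer as empty, so the first read lazily generates the first block.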