diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index feac2c2e43..2dc24d86c3 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -682,3 +682,30 @@ jobs:
           name: "dhat / report"
           status: "success"
           url: "${{ steps.s3.outputs.URL }}"
+
+  loom:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        crate: [quic/s2n-quic-core]
+    steps:
+      - uses: actions/checkout@v3
+        with:
+          submodules: true
+
+      - uses: actions-rs/toolchain@v1.0.7
+        id: toolchain
+        with:
+          toolchain: stable
+          profile: minimal
+          override: true
+
+      - uses: camshaft/rust-cache@v1
+        with:
+          key: ${{ matrix.crate }}
+
+      - name: ${{ matrix.crate }}
+        # run the tests in release mode since some of the loom models can be expensive
+        run: cd ${{ matrix.crate }} && cargo test --release loom
+        env:
+          RUSTFLAGS: --cfg loom -Cdebug-assertions
diff --git a/quic/s2n-quic-bench/Cargo.toml b/quic/s2n-quic-bench/Cargo.toml
index 6566022156..8947b1dc22 100644
--- a/quic/s2n-quic-bench/Cargo.toml
+++ b/quic/s2n-quic-bench/Cargo.toml
@@ -10,6 +10,7 @@ publish = false
 [dependencies]
 criterion = { version = "0.4", features = ["html_reports"] }
+crossbeam-channel = { version = "0.5" }
 s2n-codec = { path = "../../common/s2n-codec", features = ["testing"] }
 s2n-quic-core = { path = "../s2n-quic-core", features = ["testing"] }
 s2n-quic-crypto = { path = "../s2n-quic-crypto", features = ["testing"] }
diff --git a/quic/s2n-quic-bench/src/lib.rs b/quic/s2n-quic-bench/src/lib.rs
index f29787a73e..7facdc8fee 100644
--- a/quic/s2n-quic-bench/src/lib.rs
+++ b/quic/s2n-quic-bench/src/lib.rs
@@ -7,6 +7,7 @@ mod buffer;
 mod crypto;
 mod frame;
 mod packet;
+mod sync;
 mod varint;
 
 pub fn benchmarks(c: &mut Criterion) {
@@ -14,5 +15,6 @@ pub fn benchmarks(c: &mut Criterion) {
     crypto::benchmarks(c);
     frame::benchmarks(c);
     packet::benchmarks(c);
+    sync::benchmarks(c);
     varint::benchmarks(c);
 }
diff --git a/quic/s2n-quic-bench/src/sync.rs b/quic/s2n-quic-bench/src/sync.rs
new file mode 100644
index 0000000000..1d0d468563
--- /dev/null
+++ b/quic/s2n-quic-bench/src/sync.rs
@@ -0,0 +1,80 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+use criterion::{BenchmarkId, Criterion, Throughput};
+use crossbeam_channel::bounded;
+use s2n_quic_core::sync::spsc;
+
+pub fn benchmarks(c: &mut Criterion) {
+    spsc_benches(c);
+}
+
+fn spsc_benches(c: &mut Criterion) {
+    let mut group = c.benchmark_group("spsc");
+
+    for i in [1, 64, 1024, 4096] {
+        group.throughput(Throughput::Elements(i as _));
+        group.bench_with_input(BenchmarkId::new("s2n/send_recv", i), &i, |b, input| {
+            let (mut sender, mut receiver) = spsc::channel(*input);
+            b.iter(|| {
+                {
+                    let mut slice = sender.try_slice().unwrap().unwrap();
+                    while slice.push(123usize).is_ok() {}
+                }
+
+                {
+                    let mut slice = receiver.try_slice().unwrap().unwrap();
+                    while slice.pop().is_some() {}
+                }
+            });
+        });
+        group.bench_with_input(
+            BenchmarkId::new("crossbeam/send_recv", i),
+            &i,
+            |b, input| {
+                let (sender, receiver) = bounded(*input);
+                b.iter(|| {
+                    {
+                        while sender.try_send(123usize).is_ok() {}
+                    }
+
+                    {
+                        while receiver.try_recv().is_ok() {}
+                    }
+                });
+            },
+        );
+
+        group.bench_with_input(BenchmarkId::new("s2n/send_recv_iter", i), &i, |b, input| {
+            let (mut sender, mut receiver) = spsc::channel(*input);
+            b.iter(|| {
+                {
+                    let mut slice = sender.try_slice().unwrap().unwrap();
+                    let _ = slice.extend(&mut core::iter::repeat(123usize));
+                }
+
+                {
+                    let mut slice = receiver.try_slice().unwrap().unwrap();
+                    slice.clear();
+                }
+            });
+        });
+        group.bench_with_input(
+            BenchmarkId::new("crossbeam/send_recv_iter", i),
+            &i,
+            |b, input| {
+                let (sender, receiver) = bounded(*input);
+                b.iter(|| {
+                    {
+                        while sender.try_send(123usize).is_ok() {}
+                    }
+
+                    {
+                        for _ in receiver.try_iter() {}
+                    }
+                });
+            },
+        );
+    }
+    group.finish();
+}
diff --git a/quic/s2n-quic-core/Cargo.toml b/quic/s2n-quic-core/Cargo.toml
index 97941299e2..6e97a63268 100644
--- a/quic/s2n-quic-core/Cargo.toml
+++ b/quic/s2n-quic-core/Cargo.toml
@@ -12,7 +12,7 @@ exclude = ["corpus.tar.gz"]
 
 [features]
 default = ["alloc", "std"]
-alloc = []
+alloc = ["atomic-waker", "cache-padded"]
 std = ["alloc", "once_cell"]
 testing = ["std", "generator", "s2n-codec/testing", "checked-counters", "insta", "futures-test"]
 generator = ["bolero-generator"]
@@ -20,9 +20,11 @@ checked-counters = []
 event-tracing = ["tracing"]
 
 [dependencies]
+atomic-waker = { version = "1", optional = true }
 bolero-generator = { version = "0.8", default-features = false, optional = true }
 byteorder = { version = "1", default-features = false }
 bytes = { version = "1", default-features = false }
+cache-padded = { version = "1", optional = true }
 hex-literal = "0.3"
 # used for event snapshot testing - needs an internal API so we require a minimum version
 insta = { version = ">=1.12", features = ["json"], optional = true }
@@ -40,7 +42,11 @@ once_cell = { version = "1", optional = true }
 bolero = "0.8"
 bolero-generator = { version = "0.8", default-features = false }
 insta = { version = "1", features = ["json"] }
+futures = "0.3"
 futures-test = "0.3"
 ip_network = "0.4"
 plotters = { version = "0.3", default-features = false, features = ["svg_backend", "line_series"] }
 s2n-codec = { path = "../../common/s2n-codec", features = ["testing"] }
+
+[target.'cfg(loom)'.dev-dependencies]
+loom = { version = "0.5", features = ["checkpoint", "futures"] }
diff --git a/quic/s2n-quic-core/src/lib.rs b/quic/s2n-quic-core/src/lib.rs
index 252500d800..d7bc8dc99e 100644
--- a/quic/s2n-quic-core/src/lib.rs
+++ b/quic/s2n-quic-core/src/lib.rs
@@ -6,6 +6,30 @@
 #[cfg(feature = "alloc")]
 extern crate alloc;
 
+/// Asserts that a boolean expression is true at runtime, only if debug_assertions are enabled.
+///
+/// Otherwise, the compiler is told to assume that the expression is always true and can perform
+/// additional optimizations.
+///
+/// # Safety
+///
+/// The caller _must_ ensure the condition can never be false; otherwise the compiler
+/// may optimize based on false assumptions and behave incorrectly.
+#[macro_export]
+macro_rules! assume {
+    ($cond:expr) => {
+        $crate::assume!($cond, "assumption failed: {}", stringify!($cond));
+    };
+    ($cond:expr $(, $fmtarg:expr)* $(,)?) => {
+        let v = $cond;
+
+        debug_assert!(v $(, $fmtarg)*);
+        if cfg!(not(debug_assertions)) && !v {
+            core::hint::unreachable_unchecked();
+        }
+    };
+}
+
 pub mod ack;
 pub mod application;
 #[cfg(feature = "alloc")]
@@ -33,8 +57,12 @@ pub mod recovery;
 pub mod slice;
 pub mod stateless_reset;
 pub mod stream;
+pub mod sync;
 pub mod time;
 pub mod token;
 pub mod transmission;
 pub mod transport;
 pub mod varint;
+
+#[cfg(any(test, feature = "testing"))]
+pub mod testing;
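For reference, `assume!` is a drop-in replacement for the crate-private `unsafe_assert!` macro removed from `s2n-quic-crypto` at the end of this diff. A minimal sketch of the intended usage pattern (the function here is hypothetical, not part of the change):

```rust
/// Hypothetical caller: returns the first half of a slice without a bounds check.
fn first_half(bytes: &[u8]) -> &[u8] {
    let mid = bytes.len() / 2;
    unsafe {
        // Safety: `mid <= bytes.len()` holds by construction; debug builds
        // check it, release builds turn it into an optimizer hint
        s2n_quic_core::assume!(mid <= bytes.len());
        bytes.get_unchecked(..mid)
    }
}
```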
diff --git a/quic/s2n-quic-core/src/slice.rs b/quic/s2n-quic-core/src/slice.rs
index bf6ddb974c..0acf1e056c 100644
--- a/quic/s2n-quic-core/src/slice.rs
+++ b/quic/s2n-quic-core/src/slice.rs
@@ -91,7 +91,8 @@ where
 #[cfg(test)]
 mod tests {
     use super::*;
-    use bolero::{check, generator::*};
+    use crate::testing::InlineVec;
+    use bolero::check;
 
     fn assert_eq_slices<A, B>(a: &[A], b: &[B])
     where
@@ -128,28 +129,6 @@ mod tests {
         }
     }
 
-    #[derive(Clone, Copy, Debug, TypeGenerator)]
-    struct InlineVec<T, const LEN: usize> {
-        values: [T; LEN],
-
-        #[generator(_code = "0..LEN")]
-        len: usize,
-    }
-
-    impl<T, const LEN: usize> core::ops::Deref for InlineVec<T, LEN> {
-        type Target = [T];
-
-        fn deref(&self) -> &Self::Target {
-            &self.values[..self.len]
-        }
-    }
-
-    impl<T, const LEN: usize> core::ops::DerefMut for InlineVec<T, LEN> {
-        fn deref_mut(&mut self) -> &mut Self::Target {
-            &mut self.values[..self.len]
-        }
-    }
-
     const LEN: usize = if cfg!(kani) { 2 } else { 32 };
 
     #[test]
diff --git a/quic/s2n-quic-core/src/sync.rs b/quic/s2n-quic-core/src/sync.rs
new file mode 100644
index 0000000000..8881093b07
--- /dev/null
+++ b/quic/s2n-quic-core/src/sync.rs
@@ -0,0 +1,5 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+#[cfg(feature = "alloc")]
+pub mod spsc;
diff --git a/quic/s2n-quic-core/src/sync/__fuzz___/sync__spsc__tests__model/corpus.tar.gz b/quic/s2n-quic-core/src/sync/__fuzz___/sync__spsc__tests__model/corpus.tar.gz
new file mode 100644
index 0000000000..ce0bd7950b
--- /dev/null
+++ b/quic/s2n-quic-core/src/sync/__fuzz___/sync__spsc__tests__model/corpus.tar.gz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:16959cd0f6caa76f919caf98d976bd25db44c813ca60c684892f4db108694b90
+size 993280
diff --git a/quic/s2n-quic-core/src/sync/spsc.rs b/quic/s2n-quic-core/src/sync/spsc.rs
new file mode 100644
index 0000000000..7527471708
--- /dev/null
+++ b/quic/s2n-quic-core/src/sync/spsc.rs
@@ -0,0 +1,42 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+mod primitive;
+mod recv;
+mod send;
+mod slice;
+mod state;
+
+use slice::*;
+use state::*;
+
+pub use recv::{Receiver, RecvSlice};
+pub use send::{SendSlice, Sender};
+
+#[inline]
+pub fn channel<T>(capacity: usize) -> (Sender<T>, Receiver<T>) {
+    let state = State::new(capacity);
+    let sender = Sender(state.clone());
+    let receiver = Receiver(state);
+    (sender, receiver)
+}
+
+#[cfg(test)]
+mod tests;
+
+type Result<T, Error = ClosedError> = core::result::Result<T, Error>;
+
+#[derive(Clone, Copy, Debug)]
+pub struct ClosedError;
+
+#[derive(Clone, Copy, Debug)]
+pub enum PushError<T> {
+    Full(T),
+    Closed,
+}
+
+impl<T> From<ClosedError> for PushError<T> {
+    fn from(_error: ClosedError) -> Self {
+        Self::Closed
+    }
+}
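To make the shape of the new API concrete, here is a minimal round-trip sketch (hypothetical usage, not part of the diff). Note that cursor updates are only published to the peer when the slice is dropped:

```rust
use s2n_quic_core::sync::spsc;

fn round_trip() {
    let (mut sender, mut receiver) = spsc::channel::<u32>(4);

    // `Ok(None)` means no capacity right now; `Err` means the peer was dropped
    if let Ok(Some(mut slice)) = sender.try_slice() {
        slice.push(1).unwrap();
        slice.push(2).unwrap();
    } // dropping the slice persists the tail cursor and wakes the receiver

    if let Ok(Some(mut slice)) = receiver.try_slice() {
        assert_eq!(slice.pop(), Some(1));
        assert_eq!(slice.pop(), Some(2));
    } // dropping the slice persists the head cursor and wakes the sender
}
```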
acquire_filled { + () => { + match self.0.acquire_filled() { + Ok(true) => { + let cursor = self.0.cursor; + return Ok(RecvSlice(&mut self.0, cursor)).into(); + } + Ok(false) => { + // the queue is full + } + Err(err) => { + // the channel was closed + return Err(err).into(); + } + } + }; + } + + // check capacity before registering a waker + acquire_filled!(); + + // register the waker + self.0.receiver.register(cx.waker()); + + // check once more to avoid a loss of notification + acquire_filled!(); + + Poll::Pending + } + + #[inline] + pub fn try_slice(&mut self) -> Result>> { + Ok(if self.0.acquire_filled()? { + let cursor = self.0.cursor; + Some(RecvSlice(&mut self.0, cursor)) + } else { + None + }) + } +} + +impl Drop for Receiver { + #[inline] + fn drop(&mut self) { + self.0.close(Side::Receiver); + } +} + +pub struct RecvSlice<'a, T>(&'a mut State, Cursor); + +impl<'a, T> RecvSlice<'a, T> { + #[inline] + pub fn peek(&mut self) -> (&mut [T], &mut [T]) { + let _ = self.0.acquire_filled(); + let (slice, _) = self.0.as_pairs(); + unsafe { + // Safety: the first pair of returned slices is the `initialized` half + slice.assume_init().into_mut() + } + } + + #[inline] + pub fn pop(&mut self) -> Option { + if self.0.cursor.is_empty() && !self.0.acquire_filled().unwrap_or(false) { + return None; + } + + let (pair, _) = self.0.as_pairs(); + let value = unsafe { + // Safety: the state's cursor indicates that the first slot contains initialized data + pair.take(0) + }; + self.0.cursor.increment_head(1); + Some(value) + } + + #[inline] + pub fn clear(&mut self) -> usize { + // don't try to `acquire_filled` so the caller can observe any updates through peek/pop + + let (pair, _) = self.0.as_pairs(); + let len = pair.len(); + + for entry in pair.iter() { + unsafe { + // Safety: the state's cursor indicates that each slot in the `iter` contains data + let _ = entry.take(); + } + } + + self.0.cursor.increment_head(len); + + len + } + + #[inline] + pub fn len(&self) -> usize { + self.0.cursor.recv_len() + } + + #[inline] + pub fn is_empty(&self) -> bool { + self.0.cursor.is_empty() + } +} + +impl<'a, T> Drop for RecvSlice<'a, T> { + #[inline] + fn drop(&mut self) { + self.0.persist_head(self.1); + } +} diff --git a/quic/s2n-quic-core/src/sync/spsc/send.rs b/quic/s2n-quic-core/src/sync/spsc/send.rs new file mode 100644 index 0000000000..bfa549544c --- /dev/null +++ b/quic/s2n-quic-core/src/sync/spsc/send.rs @@ -0,0 +1,123 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use super::{state::Side, Cursor, PushError, Result, State}; +use core::task::{Context, Poll}; + +pub struct Sender(pub(super) State); + +impl Sender { + #[inline] + pub fn capacity(&self) -> usize { + self.0.cursor.capacity() + } + + #[inline] + pub fn poll_slice(&mut self, cx: &mut Context) -> Poll>> { + macro_rules! acquire_capacity { + () => { + match self.0.acquire_capacity() { + Ok(true) => { + let cursor = self.0.cursor; + return Ok(SendSlice(&mut self.0, cursor)).into(); + } + Ok(false) => { + // the queue is full + } + Err(err) => { + // the channel was closed + return Err(err).into(); + } + } + }; + } + + // check capacity before registering a waker + acquire_capacity!(); + + // register the waker + self.0.sender.register(cx.waker()); + + // check once more to avoid a loss of notification + acquire_capacity!(); + + Poll::Pending + } + + #[inline] + pub fn try_slice(&mut self) -> Result>> { + Ok(if self.0.acquire_capacity()? 
diff --git a/quic/s2n-quic-core/src/sync/spsc/send.rs b/quic/s2n-quic-core/src/sync/spsc/send.rs
new file mode 100644
index 0000000000..bfa549544c
--- /dev/null
+++ b/quic/s2n-quic-core/src/sync/spsc/send.rs
@@ -0,0 +1,123 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+use super::{state::Side, Cursor, PushError, Result, State};
+use core::task::{Context, Poll};
+
+pub struct Sender<T>(pub(super) State<T>);
+
+impl<T> Sender<T> {
+    #[inline]
+    pub fn capacity(&self) -> usize {
+        self.0.cursor.capacity()
+    }
+
+    #[inline]
+    pub fn poll_slice(&mut self, cx: &mut Context) -> Poll<Result<SendSlice<T>>> {
+        macro_rules! acquire_capacity {
+            () => {
+                match self.0.acquire_capacity() {
+                    Ok(true) => {
+                        let cursor = self.0.cursor;
+                        return Ok(SendSlice(&mut self.0, cursor)).into();
+                    }
+                    Ok(false) => {
+                        // the queue is full
+                    }
+                    Err(err) => {
+                        // the channel was closed
+                        return Err(err).into();
+                    }
+                }
+            };
+        }
+
+        // check capacity before registering a waker
+        acquire_capacity!();
+
+        // register the waker
+        self.0.sender.register(cx.waker());
+
+        // check once more to avoid a loss of notification
+        acquire_capacity!();
+
+        Poll::Pending
+    }
+
+    #[inline]
+    pub fn try_slice(&mut self) -> Result<Option<SendSlice<T>>> {
+        Ok(if self.0.acquire_capacity()? {
+            let cursor = self.0.cursor;
+            Some(SendSlice(&mut self.0, cursor))
+        } else {
+            None
+        })
+    }
+}
+
+impl<T> Drop for Sender<T> {
+    #[inline]
+    fn drop(&mut self) {
+        self.0.close(Side::Sender);
+    }
+}
+
+pub struct SendSlice<'a, T>(&'a mut State<T>, Cursor);
+
+impl<'a, T> SendSlice<'a, T> {
+    #[inline]
+    pub fn push(&mut self, value: T) -> Result<(), PushError<T>> {
+        if self.0.cursor.is_full() && !self.0.acquire_capacity()? {
+            return Err(PushError::Full(value));
+        }
+
+        let (_, pair) = self.0.as_pairs();
+
+        unsafe {
+            // Safety: the second pair of slices contains uninitialized memory and the cursor
+            // indicates we have capacity to write at least one value
+            pair.write(0, value);
+        }
+
+        self.0.cursor.increment_tail(1);
+
+        Ok(())
+    }
+
+    pub fn extend<I: Iterator<Item = T>>(&mut self, iter: &mut I) -> Result<()> {
+        if self.0.acquire_capacity()? {
+            let (_, pair) = self.0.as_pairs();
+
+            let mut idx = 0;
+            let capacity = self.capacity();
+
+            while idx < capacity {
+                if let Some(value) = iter.next() {
+                    unsafe {
+                        // Safety: the second pair of slices contains uninitialized memory
+                        pair.write(idx, value);
+                    }
+                    idx += 1;
+                } else {
+                    break;
+                }
+            }
+
+            self.0.cursor.increment_tail(idx);
+        }
+
+        Ok(())
+    }
+
+    #[inline]
+    pub fn capacity(&self) -> usize {
+        self.0.cursor.send_capacity()
+    }
+}
+
+impl<'a, T> Drop for SendSlice<'a, T> {
+    #[inline]
+    fn drop(&mut self) {
+        self.0.persist_tail(self.1);
+    }
+}
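`SendSlice::extend` only pulls as many items as the slice has capacity for, which is why it takes `&mut I`: leftover items stay in the iterator for a later slice. Also note that `State::new` (in `state.rs` below) rounds the ring up to a power of two, so the usable capacity may exceed the requested one. A sketch of both behaviors (hypothetical usage, not part of the diff):

```rust
use s2n_quic_core::sync::spsc;

fn extend_is_bounded_by_capacity() {
    let (mut sender, _receiver) = spsc::channel::<u32>(8);
    // query the real capacity rather than assuming the requested value
    let capacity = sender.capacity() as u32;

    let mut source = 0u32..1_000;
    if let Ok(Some(mut slice)) = sender.try_slice() {
        slice.extend(&mut source).unwrap();
        // the slice is now full; the rest of `source` was left untouched
        assert_eq!(slice.capacity(), 0);
    }
    assert_eq!(source.next(), Some(capacity));
}
```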
diff --git a/quic/s2n-quic-core/src/sync/spsc/slice.rs b/quic/s2n-quic-core/src/sync/spsc/slice.rs
new file mode 100644
index 0000000000..ca0e814247
--- /dev/null
+++ b/quic/s2n-quic-core/src/sync/spsc/slice.rs
@@ -0,0 +1,173 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+use core::{cell::UnsafeCell, mem::MaybeUninit, ops::Deref};
+
+#[repr(transparent)]
+pub struct Cell<T>(MaybeUninit<UnsafeCell<T>>);
+
+impl<T> Cell<T> {
+    #[inline]
+    pub unsafe fn write(&self, value: T) {
+        UnsafeCell::raw_get(self.0.as_ptr()).write(value);
+    }
+
+    #[inline]
+    pub unsafe fn take(&self) -> T {
+        self.0.assume_init_ref().get().read()
+    }
+}
+
+#[derive(Debug)]
+pub struct Slice<'a, T>(pub(super) &'a [T]);
+
+impl<'a, T> Slice<'a, Cell<T>> {
+    /// Assumes that the slice of [`Cell`]s is initialized and converts it to a slice of
+    /// [`UnsafeCell`]s.
+    ///
+    /// See [`core::mem::MaybeUninit::assume_init`]
+    #[inline]
+    pub unsafe fn assume_init(self) -> Slice<'a, UnsafeCell<T>> {
+        Slice(&*(self.0 as *const [Cell<T>] as *const [UnsafeCell<T>]))
+    }
+
+    /// Writes a value into the cell at the provided index
+    ///
+    /// # Safety
+    ///
+    /// The cell at `index` must be uninitialized and the caller must have synchronized access.
+    #[inline]
+    pub unsafe fn write(&self, index: usize, value: T) {
+        self.0.get_unchecked(index).write(value)
+    }
+
+    /// Reads and takes the memory of the cell at the provided index
+    ///
+    /// # Safety
+    ///
+    /// The cell at `index` must be initialized and the caller must have synchronized access.
+    #[inline]
+    pub unsafe fn take(&self, index: usize) -> T {
+        self.0.get_unchecked(index).take()
+    }
+}
+
+impl<'a, T> Slice<'a, UnsafeCell<T>> {
+    /// Converts the slice of [`UnsafeCell`]s into a mutable slice
+    ///
+    /// # Safety
+    ///
+    /// The slice must be exclusively owned, otherwise data races may occur.
+    #[inline]
+    pub unsafe fn into_mut(self) -> &'a mut [T] {
+        let ptr = self.0.as_ptr() as *mut T;
+        let len = self.0.len();
+        core::slice::from_raw_parts_mut(ptr, len)
+    }
+}
+
+impl<'a, T> Deref for Slice<'a, T> {
+    type Target = [T];
+
+    #[inline]
+    fn deref(&self) -> &[T] {
+        self.0
+    }
+}
+
+impl<'a, T: PartialEq> PartialEq<[T]> for Slice<'a, UnsafeCell<T>> {
+    #[inline]
+    fn eq(&self, other: &[T]) -> bool {
+        if self.len() != other.len() {
+            return false;
+        }
+
+        for (a, b) in self.iter().zip(other) {
+            if unsafe { &*a.get() } != b {
+                return false;
+            }
+        }
+
+        true
+    }
+}
+
+impl<'a, T: PartialEq> PartialEq<Slice<'a, UnsafeCell<T>>> for [T] {
+    #[inline]
+    fn eq(&self, other: &Slice<'a, UnsafeCell<T>>) -> bool {
+        other.eq(self)
+    }
+}
+
+impl<'a, T: PartialEq> PartialEq<Slice<'a, UnsafeCell<T>>> for &[T] {
+    #[inline]
+    fn eq(&self, other: &Slice<'a, UnsafeCell<T>>) -> bool {
+        other.eq(self)
+    }
+}
+
+#[derive(Debug)]
+pub struct Pair<S> {
+    pub head: S,
+    pub tail: S,
+}
+
+impl<'a, T> Pair<Slice<'a, Cell<T>>> {
+    #[inline]
+    pub unsafe fn assume_init(self) -> Pair<Slice<'a, UnsafeCell<T>>> {
+        Pair {
+            head: self.head.assume_init(),
+            tail: self.tail.assume_init(),
+        }
+    }
+
+    #[inline]
+    pub unsafe fn write(&self, index: usize, value: T) {
+        self.cell(index).write(value)
+    }
+
+    #[inline]
+    pub unsafe fn take(&self, index: usize) -> T {
+        self.cell(index).take()
+    }
+
+    unsafe fn cell(&self, index: usize) -> &Cell<T> {
+        if let Some(cell) = self.head.0.get(index) {
+            cell
+        } else {
+            assume!(
+                index >= self.head.0.len(),
+                "index must always be equal or greater than the `head` len"
+            );
+            let index = index - self.head.0.len();
+
+            assume!(
+                self.tail.get(index).is_some(),
+                "index must be in-bounds for the `tail` slice: head={}, tail={}, index={}",
+                self.head.0.len(),
+                self.tail.0.len(),
+                index
+            );
+            self.tail.get_unchecked(index)
+        }
+    }
+
+    #[inline]
+    pub fn iter(&self) -> impl Iterator<Item = &Cell<T>> {
+        self.head.0.iter().chain(self.tail.0)
+    }
+
+    #[inline]
+    pub fn len(&self) -> usize {
+        self.head.len() + self.tail.len()
+    }
+}
+
+impl<'a, T> Pair<Slice<'a, UnsafeCell<T>>> {
+    #[inline]
+    pub unsafe fn into_mut(self) -> (&'a mut [T], &'a mut [T]) {
+        let head = self.head.into_mut();
+        let tail = self.tail.into_mut();
+        (head, tail)
+    }
+}
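`Pair::cell` maps a logical index across the two discontiguous halves of the ring. The same arithmetic on plain slices, as a standalone sketch (the helper is hypothetical, for illustration only):

```rust
/// A logical index first covers `head`, then continues into `tail`.
fn pair_get<'a, T>(head: &'a [T], tail: &'a [T], index: usize) -> Option<&'a T> {
    if let Some(value) = head.get(index) {
        Some(value)
    } else {
        // `head.get` returned `None`, so `index >= head.len()` and the
        // subtraction cannot underflow
        tail.get(index - head.len())
    }
}

#[test]
fn pair_get_spans_both_halves() {
    // a filled region that wraps around the end of the buffer
    let (head, tail) = ([4, 5, 6], [7, 8]);
    assert_eq!(pair_get(&head, &tail, 2), Some(&6));
    assert_eq!(pair_get(&head, &tail, 3), Some(&7)); // crosses into `tail`
    assert_eq!(pair_get(&head, &tail, 5), None);
}
```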
diff --git a/quic/s2n-quic-core/src/sync/spsc/state.rs b/quic/s2n-quic-core/src/sync/spsc/state.rs
new file mode 100644
index 0000000000..d1072d1ec8
--- /dev/null
+++ b/quic/s2n-quic-core/src/sync/spsc/state.rs
@@ -0,0 +1,566 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+use super::{
+    primitive::{AtomicBool, AtomicUsize, AtomicWaker, IsZst, Ordering},
+    Cell, ClosedError, Result, Slice,
+};
+use alloc::alloc::Layout;
+use cache_padded::CachePadded;
+use core::{
+    fmt,
+    marker::PhantomData,
+    ops::Deref,
+    panic::{RefUnwindSafe, UnwindSafe},
+    ptr::NonNull,
+};
+
+type Pair<'a, T> = super::Pair<Slice<'a, Cell<T>>>;
+
+const MINIMUM_CAPACITY: usize = 2;
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+pub enum Side {
+    Sender,
+    Receiver,
+}
+
+#[derive(Clone, Copy)]
+pub struct Cursor {
+    head: usize,
+    tail: usize,
+    capacity: usize,
+}
+
+impl fmt::Debug for Cursor {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        f.debug_struct("Cursor")
+            .field("head", &self.head)
+            .field("tail", &self.tail)
+            .field("len", &self.len())
+            .field("capacity", &self.capacity())
+            .field("is_empty", &self.is_empty())
+            .field("is_full", &self.is_full())
+            .field("is_contiguous", &self.is_contiguous())
+            .finish()
+    }
+}
+
+impl Cursor {
+    #[inline]
+    fn new(capacity: usize) -> Self {
+        Self {
+            head: 0,
+            tail: 0,
+            capacity,
+        }
+    }
+
+    #[inline]
+    fn invariants(&self) {
+        unsafe {
+            assume!(
+                self.capacity >= MINIMUM_CAPACITY,
+                "the capacity must be at least the MINIMUM_CAPACITY value"
+            );
+            assume!(
+                self.head < self.capacity,
+                "the `head` pointer should be strictly less than the capacity"
+            );
+            assume!(
+                self.tail < self.capacity,
+                "the `tail` pointer should be strictly less than the capacity"
+            );
+            let len = count(self.head, self.tail, self.capacity);
+            assume!(
+                len < self.capacity,
+                "the computed `len` should be strictly less than the capacity"
+            );
+        }
+    }
+
+    #[inline]
+    pub fn capacity(&self) -> usize {
+        self.invariants();
+        // To make cursor management easier, we never allow the callers to hit the total capacity.
+        // We also account for this when allocating the state by adding 1 to the requested capacity.
+        self.capacity - 1
+    }
+
+    #[inline]
+    fn cap(&self) -> usize {
+        self.invariants();
+        self.capacity
+    }
+
+    #[inline]
+    pub fn len(&self) -> usize {
+        self.invariants();
+        count(self.head, self.tail, self.cap())
+    }
+
+    #[inline]
+    pub fn is_empty(&self) -> bool {
+        self.invariants();
+        self.tail == self.head
+    }
+
+    #[inline]
+    pub fn is_full(&self) -> bool {
+        self.invariants();
+        count(self.tail, self.head, self.cap()) == 1
+    }
+
+    #[inline]
+    pub fn recv_len(&self) -> usize {
+        self.invariants();
+        self.len()
+    }
+
+    #[inline]
+    pub fn send_capacity(&self) -> usize {
+        self.invariants();
+        self.capacity() - self.recv_len()
+    }
+
+    #[inline]
+    pub fn increment_head(&mut self, n: usize) {
+        self.invariants();
+        unsafe {
+            assume!(
+                n <= self.capacity(),
+                "n should never exceed the total capacity"
+            );
+        }
+        self.head = self.wrap_add(self.head, n);
+        self.invariants();
+    }
+
+    #[inline]
+    pub fn increment_tail(&mut self, n: usize) {
+        self.invariants();
+        unsafe {
+            assume!(
+                n <= self.capacity(),
+                "n should never exceed the total capacity"
+            );
+        }
+        self.tail = self.wrap_add(self.tail, n);
+        self.invariants();
+    }
+
+    #[inline]
+    fn wrap_add(&self, idx: usize, addend: usize) -> usize {
+        wrap_index(idx.wrapping_add(addend), self.cap())
+    }
+
+    #[inline]
+    fn is_contiguous(&self) -> bool {
+        self.tail >= self.head
+    }
+}
+
+/// Returns the index in the underlying buffer for a given logical element index.
+#[inline]
+fn wrap_index(index: usize, size: usize) -> usize {
+    // size is always a power of 2
+    unsafe {
+        assume!(
+            size.is_power_of_two(),
+            "The calculations in the lengths rely on the capacity being a power of 2"
+        );
+        assume!(
+            size >= MINIMUM_CAPACITY,
+            "The calculations in the lengths rely on the capacity being at least {}",
+            MINIMUM_CAPACITY
+        );
+    }
+    index & (size - 1)
+}
+
+/// Calculate the number of elements left to be read in the buffer
+#[inline]
+fn count(head: usize, tail: usize, size: usize) -> usize {
+    // size is always a power of 2
+    unsafe {
+        assume!(
+            size.is_power_of_two(),
+            "The calculations in the lengths rely on the capacity being a power of 2"
+        );
+        assume!(
+            size >= MINIMUM_CAPACITY,
+            "The calculations in the lengths rely on the capacity being at least {}",
+            MINIMUM_CAPACITY
+        );
+    }
+    (tail.wrapping_sub(head)) & (size - 1)
+}
+
+/// The synchronized state between two peers
+///
+/// The internal design of the cursor management is based on [`alloc::collections::VecDeque`].
+pub struct State<T> {
+    header: NonNull<Header<T>>,
+    pub cursor: Cursor,
+}
+
+impl<T> fmt::Debug for State<T> {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        f.debug_struct("State")
+            .field("header", self.deref())
+            .field("cursor", &self.cursor)
+            .finish()
+    }
+}
+
+/// Safety: synchronization of state is managed through atomic values
+unsafe impl<T: Send> Send for State<T> {}
+
+/// Safety: synchronization of state is managed through atomic values
+unsafe impl<T: Send> Sync for State<T> {}
+
+/// The data behind the header pointer itself is unwind safe
+impl<T: RefUnwindSafe> UnwindSafe for State<T> {}
+
+impl<T> Clone for State<T> {
+    #[inline]
+    fn clone(&self) -> Self {
+        Self {
+            header: self.header,
+            cursor: self.cursor,
+        }
+    }
+}
+
+impl<T> Deref for State<T> {
+    type Target = Header<T>;
+
+    #[inline]
+    fn deref(&self) -> &Self::Target {
+        unsafe { self.header.as_ref() }
+    }
+}
+
+impl<T> State<T> {
+    #[inline]
+    pub fn new(capacity: usize) -> Self {
+        // If we're sending a zero-sized type, set the capacity to the maximum value, since we're
+        // not sending any data and just coordinating cursors at this point
+        let capacity = if T::IS_ZST {
+            // The total capacity must be a power of two
+            usize::MAX / 2 + 1
+        } else {
+            // Add 1 to the requested capacity so it's easier to manage cursor wrapping
+            core::cmp::max(capacity + 1, MINIMUM_CAPACITY).next_power_of_two()
+        };
+        let header = Header::alloc(capacity).expect("could not allocate channel");
+        let cursor = Cursor::new(capacity);
+        Self { header, cursor }
+    }
+
+    /// Tries to acquire more unfilled slots on the channel
+    ///
+    /// If the channel is closed, an error is returned. If the channel has at least one slot of
+    /// capacity, `true` is returned. Otherwise `false` is returned.
+    #[inline]
+    pub fn acquire_capacity(&mut self) -> Result<bool> {
+        if !self.open.load(Ordering::Acquire) {
+            return Err(ClosedError);
+        }
+
+        if !self.cursor.is_full() {
+            return Ok(true);
+        }
+
+        // update the cached version
+        self.cursor.head = self.head.load(Ordering::Acquire);
+
+        let is_full = self.cursor.is_full();
+
+        Ok(!is_full)
+    }
+
+    /// Tries to acquire more filled slots on the channel
+    ///
+    /// If the channel is closed and drained, an error is returned. If the channel has at least
+    /// one filled slot, `true` is returned. Otherwise `false` is returned.
+    #[inline]
+    pub fn acquire_filled(&mut self) -> Result<bool> {
+        if !self.cursor.is_empty() {
+            return Ok(true);
+        }
+
+        self.cursor.tail = self.tail.load(Ordering::Acquire);
+
+        if !self.cursor.is_empty() {
+            return Ok(true);
+        }
+
+        if !self.open.load(Ordering::Acquire) {
+            // make one more effort to load the remaining items
+            self.cursor.tail = self.tail.load(Ordering::Acquire);
+
+            if !self.cursor.is_empty() {
+                return Ok(true);
+            }
+
+            return Err(ClosedError);
+        }
+
+        Ok(false)
+    }
+
+    /// Notifies the peer of `head` updates for the given cursor
+    #[inline]
+    pub fn persist_head(&self, prev: Cursor) {
+        // nothing changed
+        if prev.head == self.cursor.head {
+            return;
+        }
+
+        self.head.store(self.cursor.head, Ordering::Release);
+
+        self.sender.wake();
+    }
+
+    /// Notifies the peer of `tail` updates for the given cursor
+    #[inline]
+    pub fn persist_tail(&self, prev: Cursor) {
+        // nothing changed
+        if prev.tail == self.cursor.tail {
+            return;
+        }
+
+        self.tail.store(self.cursor.tail, Ordering::Release);
+
+        self.receiver.wake();
+    }
+
+    #[inline]
+    fn data(&self) -> &[Cell<T>] {
+        unsafe {
+            // Safety: the state must still be allocated and the cursor in bounds
+            let ptr = self.data_ptr();
+            let capacity = self.cursor.capacity;
+            core::slice::from_raw_parts(ptr, capacity)
+        }
+    }
+
+    #[inline]
+    fn data_ptr(&self) -> *const Cell<T> {
+        unsafe {
+            // If the type is zero-sized, no need to calculate offsets
+            if T::IS_ZST {
+                return NonNull::<Cell<T>>::dangling().as_ptr();
+            }
+
+            // Safety: the state must still be allocated and the cursor in bounds
+            let capacity = self.cursor.capacity;
+            let (_, offset) = Header::<T>::layout_unchecked(capacity);
+
+            let ptr = self.header.as_ptr() as *const u8;
+            let ptr = ptr.add(offset);
+            ptr as *const Cell<T>
+        }
+    }
+
+    /// Closes one side of the channel and notifies the peer of the event
+    #[inline]
+    pub fn close(&mut self, side: Side) {
+        // notify the other side that we've closed the channel
+        match side {
+            Side::Sender => self.receiver.wake(),
+            Side::Receiver => self.sender.wake(),
+        }
+
+        let was_open = self.open.swap(false, Ordering::SeqCst);
+
+        // make sure the peer is notified before fully dropping the contents
+        match side {
+            Side::Sender => self.receiver.wake(),
+            Side::Receiver => self.sender.wake(),
+        }
+
+        if !was_open {
+            unsafe {
+                // Safety: closing is synchronized between the two peers through atomic
+                // variables. At this point both sides have agreed on the final state.
+                self.drop_contents();
+            }
+        }
+    }
+
+    /// Returns the channel slots as two pairs of filled and unfilled slices
+    #[inline]
+    pub fn as_pairs(&self) -> (Pair<T>, Pair<T>) {
+        let data = self.data();
+        self.data_to_pairs(data)
+    }
+
+    #[inline]
+    fn data_to_pairs<'a>(&self, data: &'a [Cell<T>]) -> (Pair<'a, T>, Pair<'a, T>) {
+        self.cursor.invariants();
+
+        let head = self.cursor.head;
+        let tail = self.cursor.tail;
+
+        let (filled, unfilled) = if self.cursor.is_contiguous() {
+            unsafe {
+                assume!(data.len() >= tail, "data must span the tail length");
+            }
+            let (data, unfilled_head) = data.split_at(tail);
+
+            unsafe {
+                assume!(data.len() >= head, "data must span the head length");
+            }
+            let (unfilled_tail, filled_head) = data.split_at(head);
+
+            let filled = Pair {
+                head: Slice(filled_head),
+                tail: Slice(&[]),
+            };
+            let unfilled = Pair {
+                head: Slice(unfilled_head),
+                tail: Slice(unfilled_tail),
+            };
+            (filled, unfilled)
+        } else {
+            unsafe {
+                assume!(data.len() >= head, "data must span the head length");
+            }
+            let (data, filled_head) = data.split_at(head);
+
+            unsafe {
+                assume!(data.len() >= tail, "data must span the tail length");
+            }
+            let (filled_tail, unfilled_head) = data.split_at(tail);
+
+            let filled = Pair {
+                head: Slice(filled_head),
+                tail: Slice(filled_tail),
+            };
+            let unfilled = Pair {
+                head: Slice(unfilled_head),
+                tail: Slice(&[]),
+            };
+            (filled, unfilled)
+        };
+
+        unsafe {
+            assume!(
+                filled.len() == self.cursor.recv_len(),
+                "filled len should agree with the cursor len {} == {}\n{:?}",
+                filled.len(),
+                self.cursor.recv_len(),
+                self.cursor
+            );
+        }
+
+        (filled, unfilled)
+    }
+
+    /// Frees the contents of the channel
+    ///
+    /// # Safety
+    ///
+    /// Each side must have synchronized and agreed on the final state before calling this
+    #[inline]
+    unsafe fn drop_contents(&mut self) {
+        // refresh the cursor from the shared state
+        self.cursor.head = self.head.load(Ordering::Acquire);
+        self.cursor.tail = self.tail.load(Ordering::Acquire);
+
+        // release all of the filled data
+        let (filled, _unfilled) = self.as_pairs();
+        if !T::IS_ZST {
+            for cell in filled.iter() {
+                drop(cell.take());
+            }
+        }
+
+        // make sure we free any stored wakers
+        let header = self.header.as_mut();
+        drop(header.receiver.take());
+        drop(header.sender.take());
+
+        // free the header
+        let ptr = self.header.as_ptr() as *mut u8;
+        let capacity = self.cursor.capacity;
+        let (layout, _offset) = Header::<T>::layout_unchecked(capacity);
+        alloc::alloc::dealloc(ptr, layout)
+    }
+}
+
+pub struct Header<T> {
+    head: CachePadded<AtomicUsize>,
+    tail: CachePadded<AtomicUsize>,
+    open: CachePadded<AtomicBool>,
+    pub receiver: AtomicWaker,
+    pub sender: AtomicWaker,
+    data: PhantomData<T>,
+}
+
+impl<T> fmt::Debug for Header<T> {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        f.debug_struct("Header")
+            .field("head", &self.head.load(Ordering::Relaxed))
+            .field("tail", &self.tail.load(Ordering::Relaxed))
+            .field("open", &self.open.load(Ordering::Relaxed))
+            .field("receiver", &self.receiver)
+            .field("sender", &self.sender)
+            .finish()
+    }
+}
+
+impl<T> Header<T> {
+    /// Allocates a header and data slice for the given capacity
+    fn alloc(capacity: usize) -> Option<NonNull<Self>> {
+        unsafe {
+            // Safety: we assume that `alloc` gives us a valid pointer to write to
+            let (layout, _offset) = Self::layout(capacity).ok()?;
+            let state = alloc::alloc::alloc(layout);
+            let state = state as *mut Self;
+            let state = NonNull::new(state)?;
+
+            state.as_ptr().write(Self::new());
+
+            Some(state)
+        }
+    }
+
+    #[inline]
+    fn new() -> Self {
+        Self {
+            head: CachePadded::new(AtomicUsize::new(0)),
+            tail: CachePadded::new(AtomicUsize::new(0)),
+            sender: AtomicWaker::new(),
+            receiver: AtomicWaker::new(),
+            open: CachePadded::new(AtomicBool::new(true)),
+            data: PhantomData,
+        }
+    }
+
+    /// Computes the checked layout for the header
+    #[inline]
+    fn layout(capacity: usize) -> Result<(Layout, usize), alloc::alloc::LayoutError> {
+        let header_layout = Layout::new::<Self>();
+        // A slice of cells is allocated in the same region as the header
+        let data_layout = Layout::array::<Cell<T>>(capacity)?;
+        let (layout, offset) = header_layout.extend(data_layout)?;
+        Ok((layout, offset))
+    }
+
+    /// Computes the memory layout of the header without checking its validity
+    ///
+    /// # Safety
+    ///
+    /// The layout must have been previously checked before calling this.
+    #[inline]
+    unsafe fn layout_unchecked(capacity: usize) -> (Layout, usize) {
+        if let Ok(v) = Self::layout(capacity) {
+            v
+        } else {
+            core::hint::unreachable_unchecked()
+        }
+    }
+}
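Because the total size is a power of two and one slot is always left empty, the head and tail indices alone are enough to distinguish empty from full. A standalone worked sketch of the `count`/`wrap_index` arithmetic above (it mirrors the code rather than importing it):

```rust
fn count(head: usize, tail: usize, size: usize) -> usize {
    debug_assert!(size.is_power_of_two());
    tail.wrapping_sub(head) & (size - 1)
}

#[test]
fn count_handles_wrapping() {
    let size = 8; // usable capacity is 7
    assert_eq!(count(0, 0, size), 0); // empty: head == tail
    assert_eq!(count(0, 5, size), 5); // contiguous: slots [0, 5) are filled
    assert_eq!(count(6, 2, size), 4); // wrapped: slots 6, 7, 0, 1 are filled
    assert_eq!(count(3, 2, size), 7); // full: only size - 1 slots may be used
}
```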
diff --git a/quic/s2n-quic-core/src/sync/spsc/tests.rs b/quic/s2n-quic-core/src/sync/spsc/tests.rs
new file mode 100644
index 0000000000..2f799db950
--- /dev/null
+++ b/quic/s2n-quic-core/src/sync/spsc/tests.rs
@@ -0,0 +1,512 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+use super::*;
+use bolero::{check, generator::*};
+use core::task::{Context, Poll, Waker};
+use futures_test::task::{new_count_waker, AwokenCount};
+use std::collections::VecDeque;
+
+#[derive(Clone, Copy, Debug, TypeGenerator)]
+enum Operation {
+    Push(u16),
+    AsyncPush(u16),
+    Pop(u16),
+    AsyncPop(u16),
+    Clear,
+    DropSend,
+    DropRecv,
+}
+
+struct WakeState {
+    waker: Waker,
+    wake_count: AwokenCount,
+    snapshot: Option<usize>,
+}
+
+impl Default for WakeState {
+    fn default() -> Self {
+        let (waker, wake_count) = new_count_waker();
+        Self {
+            waker,
+            wake_count,
+            snapshot: None,
+        }
+    }
+}
+
+impl WakeState {
+    fn context(&self) -> Context {
+        Context::from_waker(&self.waker)
+    }
+
+    fn snapshot(&mut self) {
+        self.snapshot = Some(self.wake_count.get());
+    }
+
+    fn assert_wake(&mut self) {
+        if let Some(prev) = self.snapshot.take() {
+            assert_eq!(self.wake_count.get(), prev + 1);
+        }
+    }
+}
+
+struct Model<T> {
+    oracle: VecDeque<T>,
+    send: Option<Sender<T>>,
+    send_waker: WakeState,
+    recv: Option<Receiver<T>>,
+    recv_waker: WakeState,
+    capacity: usize,
+}
+
+impl<T: Clone + PartialEq + core::fmt::Debug> Model<T> {
+    fn new(capacity: usize) -> Self {
+        let (send, recv) = channel(capacity);
+        let capacity = send.capacity();
+        Self {
+            oracle: Default::default(),
+            send: Some(send),
+            send_waker: Default::default(),
+            recv: Some(recv),
+            recv_waker: Default::default(),
+            capacity,
+        }
+    }
+
+    fn apply_all(&mut self, operations: &[Operation], mut generator: impl FnMut() -> T) {
+        for op in operations {
+            self.apply(*op, &mut generator);
+        }
+    }
+
+    fn apply(&mut self, operation: Operation, mut generator: impl FnMut() -> T) {
+        match operation {
+            Operation::Push(count) => {
+                let mut did_push = false;
+                if let Some(send) = self.send.as_mut() {
+                    match send.try_slice() {
+                        Ok(Some(mut slice)) => {
+                            for _ in 0..count {
+                                let value = generator();
+                                if slice.push(value.clone()).is_ok() {
+                                    self.oracle.push_back(value);
+                                    did_push = true;
+                                }
+                            }
+                        }
+                        Ok(None) => {
+                            assert_eq!(
+                                self.oracle.len(),
+                                self.capacity,
+                                "slice should return None when at capacity"
+                            );
+                        }
+                        Err(_) => {
+                            assert!(self.recv.is_none());
+                        }
+                    }
+                }
+
+                if did_push {
+                    self.recv_waker.assert_wake();
+                }
+            }
+            Operation::AsyncPush(count) => {
+                let mut did_push = false;
+                if let Some(send) = self.send.as_mut() {
+                    match send.poll_slice(&mut self.send_waker.context()) {
+                        Poll::Ready(Ok(mut slice)) => {
+                            for _ in 0..count {
+                                let value = generator();
+                                if slice.push(value.clone()).is_ok() {
+                                    self.oracle.push_back(value);
+
+                                    did_push = true;
+                                }
+                            }
+                        }
+                        Poll::Ready(Err(_)) => {
+                            assert!(self.recv.is_none());
+                        }
+                        Poll::Pending => {
+                            assert_eq!(
+                                self.oracle.len(),
+                                self.capacity,
+                                "slice should return Pending when at capacity"
+                            );
+                            self.send_waker.snapshot();
+                        }
+                    }
+                }
+
+                if did_push {
+                    self.recv_waker.assert_wake();
+                }
+            }
+            Operation::Pop(count) => {
+                let mut did_pop = false;
+                if let Some(recv) = self.recv.as_mut() {
+                    match recv.try_slice() {
+                        Ok(Some(mut slice)) => {
+                            for _ in 0..count {
+                                let value = slice.pop();
+                                assert_eq!(value, self.oracle.pop_front());
+                                did_pop |= value.is_some();
+                            }
+                        }
+                        Ok(None) => {
+                            assert!(self.oracle.is_empty());
+                        }
+                        Err(_) => {
+                            assert!(self.send.is_none());
+                            assert!(self.oracle.is_empty());
+                        }
+                    }
+                }
+
+                if did_pop {
+                    self.send_waker.assert_wake();
+                }
+            }
+            Operation::AsyncPop(count) => {
+                let mut did_pop = false;
+                if let Some(recv) = self.recv.as_mut() {
+                    match recv.poll_slice(&mut self.recv_waker.context()) {
+                        Poll::Ready(Ok(mut slice)) => {
+                            for _ in 0..count {
+                                let value = slice.pop();
+                                assert_eq!(value, self.oracle.pop_front());
+                                did_pop |= value.is_some();
+                            }
+                        }
+                        Poll::Ready(Err(_)) => {
+                            assert!(self.send.is_none());
+                            assert!(self.oracle.is_empty());
+                        }
+                        Poll::Pending => {
+                            assert!(self.oracle.is_empty());
+                            self.recv_waker.snapshot();
+                        }
+                    }
+                }
+
+                if did_pop {
+                    self.send_waker.assert_wake();
+                }
+            }
+            Operation::Clear => {
+                let mut did_pop = false;
+                if let Some(recv) = self.recv.as_mut() {
+                    match recv.try_slice() {
+                        Ok(Some(mut slice)) => {
+                            // we need to pull in the latest values to clear everything
+                            while slice.clear() > 0 {
+                                did_pop = true;
+                                let _ = slice.peek();
+                            }
+                            self.oracle.clear();
+                        }
+                        Ok(None) => {
+                            assert!(self.oracle.is_empty());
+                        }
+                        Err(_) => {
+                            assert!(self.send.is_none());
+                            assert!(self.oracle.is_empty());
+                        }
+                    }
+                }
+
+                if did_pop {
+                    self.send_waker.assert_wake();
+                }
+            }
+            Operation::DropSend => {
+                self.send = None;
+            }
+            Operation::DropRecv => {
+                self.recv = None;
+            }
+        }
+    }
+
+    fn finish(mut self) {
+        loop {
+            self.apply(Operation::Pop(u16::MAX), || unimplemented!());
+            if self.oracle.is_empty() || self.recv.is_none() {
+                return;
+            }
+        }
+    }
+}
+
+#[cfg(any(kani, miri))]
+type Operations = crate::testing::InlineVec<Operation, 2>;
+#[cfg(not(any(kani, miri)))]
+type Operations = Vec<Operation>;
+
+#[test]
+fn model() {
+    let max_capacity = if cfg!(any(kani, miri)) { 2 } else { 64 };
+
+    let generator = (1usize..max_capacity, gen::<Operations>());
+
+    check!()
+        .with_generator(generator)
+        .for_each(|(capacity, ops)| {
+            let mut model = Model::new(*capacity);
+            let mut cursor = 0;
+            let generator = || {
+                let v = cursor;
+                cursor += 1;
+                Box::new(v)
+            };
+
+            model.apply_all(ops, generator);
+            model.finish();
+        })
+}
+
+#[test]
+fn model_zst() {
+    let max_capacity = if cfg!(any(kani, miri)) { 2 } else { 64 };
+
+    let generator = (1usize..max_capacity, gen::<Operations>());
+
+    check!()
+        .with_generator(generator)
+        .for_each(|(capacity, ops)| {
+            let mut model = Model::new(*capacity);
+            let generator = || ();
+
+            model.apply_all(ops, generator);
+            model.finish();
+        })
+}
+
+#[test]
+// TODO enable this once https://github.com/model-checking/kani/pull/2172 is merged and released
+// #[cfg_attr(kani, kani::proof, kani::unwind(3))]
+fn alloc_test() {
+    let capacity = if cfg!(any(kani, miri)) {
+        1usize..3
+    } else {
+        1usize..4096
+    };
+
+    check!()
+        .with_generator((capacity, gen::<u64>()))
+        .cloned()
+        .for_each(|(capacity, push_value)| {
+            let (mut send, mut recv) = channel(capacity);
+
+            // kani takes a very long time to check the push/pop functionality so bail for now
+            if cfg!(not(kani_slow)) {
+                return;
+            }
+
+            send.try_slice().unwrap().unwrap().push(push_value).unwrap();
+
+            let pop_value = recv.try_slice().unwrap().unwrap().pop().unwrap();
+            assert_eq!(pop_value, push_value);
+        })
+}
+
+#[cfg(not(loom))]
+mod loom {
+    pub use std::*;
+
+    pub mod future {
+        pub use futures::executor::block_on;
+    }
+
+    pub fn model<F: Fn() -> R, R>(f: F) -> R {
+        f()
+    }
+}
+
+const CAPACITY: usize = if cfg!(loom) { 2 } else { 10 };
+const BATCH_COUNT: usize = if cfg!(loom) { 2 } else { 100 };
+const BATCH_SIZE: usize = if cfg!(loom) { 3 } else { 20 };
+const EXPECTED_COUNT: usize = BATCH_COUNT * BATCH_SIZE;
+
+// TODO The async rx loom tests seem to take an unbounded amount of time if the batch count/size
+// is anything bigger than 1. Ideally, these sizes would be bigger to test more permutations of
+// orderings, so we should investigate what's causing loom to spin endlessly.
+const ASYNC_RX_BATCH_COUNT: usize = if cfg!(loom) { 1 } else { BATCH_COUNT };
+const ASYNC_RX_BATCH_SIZE: usize = if cfg!(loom) { 1 } else { BATCH_SIZE };
+const ASYNC_RX_EXPECTED_COUNT: usize = ASYNC_RX_BATCH_COUNT * ASYNC_RX_BATCH_SIZE;
+
+#[test]
+fn loom_spin_tx_spin_rx_test() {
+    loom_scenario(
+        CAPACITY,
+        |send| loom_spin_tx(send, BATCH_COUNT, BATCH_SIZE),
+        |recv| loom_spin_rx(recv, EXPECTED_COUNT),
+    )
+}
+
+#[test]
+fn loom_spin_tx_async_rx_test() {
+    loom_scenario(
+        CAPACITY,
+        |send| loom_spin_tx(send, ASYNC_RX_BATCH_COUNT, ASYNC_RX_BATCH_SIZE),
+        |recv| loom_async_rx(recv, ASYNC_RX_EXPECTED_COUNT),
+    )
+}
+
+#[test]
+fn loom_async_tx_spin_rx_test() {
+    loom_scenario(
+        CAPACITY,
+        |send| loom_async_tx(send, BATCH_COUNT, BATCH_SIZE),
+        |recv| loom_spin_rx(recv, EXPECTED_COUNT),
+    )
+}
+
+#[test]
+fn loom_async_tx_async_rx_test() {
+    loom_scenario(
+        CAPACITY,
+        |send| loom_async_tx(send, ASYNC_RX_BATCH_COUNT, ASYNC_RX_BATCH_SIZE),
+        |recv| loom_async_rx(recv, ASYNC_RX_EXPECTED_COUNT),
+    )
+}
+
+fn loom_scenario(capacity: usize, sender: fn(Sender<u32>), receiver: fn(Receiver<u32>)) {
+    loom::model(move || {
+        let (send, recv) = channel(capacity);
+
+        let a = loom::thread::spawn(move || sender(send));
+
+        let b = loom::thread::spawn(move || receiver(recv));
+
+        // loom tests will still run after returning so we don't need to join
+        if cfg!(not(loom)) {
+            a.join().unwrap();
+            b.join().unwrap();
+        }
+    });
+}
+
+fn loom_spin_rx(mut recv: Receiver<u32>, expected: usize) {
+    use loom::hint;
+
+    let mut value = 0u32;
+    loop {
+        match recv.try_slice() {
+            Ok(Some(mut slice)) => {
+                while let Some(actual) = slice.pop() {
+                    assert_eq!(actual, value);
+                    value += 1;
+                }
+            }
+            Ok(None) => hint::spin_loop(),
+            Err(_) => {
+                assert_eq!(value as usize, expected);
+                return;
+            }
+        }
+    }
+}
+
+fn loom_async_rx(mut recv: Receiver<u32>, expected: usize) {
+    use futures::{future::poll_fn, ready};
+
+    loom::future::block_on(async move {
+        let mut value = 0u32;
+        poll_fn(|cx| loop {
+            match ready!(recv.poll_slice(cx)) {
+                Ok(mut slice) => {
+                    while let Some(actual) = slice.pop() {
+                        assert_eq!(actual, value);
+                        value += 1;
+                    }
+                }
+                Err(_err) => return Poll::Ready(()),
+            }
+        })
+        .await;
+
+        assert_eq!(value as usize, expected);
+    });
+}
+
+fn loom_spin_tx(mut send: Sender<u32>, batch_count: usize, batch_size: usize) {
+    use loom::hint;
+
+    let max_value = (batch_count * batch_size) as u32;
+    let mut value = 0u32;
+
+    'done: while max_value > value {
+        let mut remaining = batch_size;
+        while remaining > 0 {
+            match send.try_slice() {
+                Ok(Some(mut slice)) => {
+                    let num_items = remaining;
+                    for _ in 0..num_items {
+                        if slice.push(value).is_err() {
+                            hint::spin_loop();
+                            continue;
+                        }
+                        value += 1;
+                        remaining -= 1;
+                    }
+                }
+                Ok(None) => {
+                    // we don't have capacity to send so yield the thread
+                    hint::spin_loop();
+                }
+                Err(_) => {
+                    // The peer dropped the channel so bail
+                    break 'done;
+                }
+            }
+        }
+    }
+
+    assert_eq!(value, max_value);
+}
+
+fn loom_async_tx(mut send: Sender<u32>, batch_count: usize, batch_size: usize) {
+    use futures::{future::poll_fn, ready};
+
+    loom::future::block_on(async move {
+        let max_value = (batch_count * batch_size) as u32;
+        let mut value = 0u32;
+
+        while max_value > value {
+            let mut remaining = batch_size;
+            let result = poll_fn(|cx| loop {
+                return match ready!(send.poll_slice(cx)) {
+                    Ok(mut slice) => {
+                        let num_items = remaining;
+                        for _ in 0..num_items {
+                            if slice.push(value).is_err() {
+                                // try polling the slice capacity again
+                                break;
+                            }
+                            value += 1;
+                            remaining -= 1;
+                        }
+
+                        if remaining > 0 {
+                            continue;
+                        }
+
+                        Ok(())
+                    }
+                    Err(err) => Err(err),
+                }
+                .into();
+            })
+            .await;
+
+            if result.is_err() {
+                return;
+            }
+        }
+
+        assert_eq!(value, max_value);
+    });
+}
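The same spin-based pattern the loom models check also works against plain OS threads, since `State` is `Send`/`Sync`. A smoke-test sketch mirroring `loom_spin_tx`/`loom_spin_rx` (hypothetical, not part of the diff):

```rust
use s2n_quic_core::sync::spsc;

fn cross_thread_smoke() {
    const COUNT: u32 = 1_000;
    let (mut send, mut recv) = spsc::channel::<u32>(16);

    let producer = std::thread::spawn(move || {
        let mut value = 0;
        while value < COUNT {
            if let Ok(Some(mut slice)) = send.try_slice() {
                while value < COUNT && slice.push(value).is_ok() {
                    value += 1;
                }
            }
            std::hint::spin_loop();
        }
        // dropping `send` here closes the channel
    });

    let mut expected = 0;
    loop {
        match recv.try_slice() {
            Ok(Some(mut slice)) => {
                while let Some(actual) = slice.pop() {
                    assert_eq!(actual, expected);
                    expected += 1;
                }
            }
            Ok(None) => std::hint::spin_loop(),
            // the sender was dropped and the queue is fully drained
            Err(_closed) => break,
        }
    }

    assert_eq!(expected, COUNT);
    producer.join().unwrap();
}
```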
diff --git a/quic/s2n-quic-core/src/testing.rs b/quic/s2n-quic-core/src/testing.rs
new file mode 100644
index 0000000000..9cdba7606c
--- /dev/null
+++ b/quic/s2n-quic-core/src/testing.rs
@@ -0,0 +1,26 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+use bolero_generator::TypeGenerator;
+
+#[derive(Clone, Copy, Debug, TypeGenerator)]
+pub struct InlineVec<T, const LEN: usize> {
+    values: [T; LEN],
+
+    #[generator(_code = "0..LEN")]
+    len: usize,
+}
+
+impl<T, const LEN: usize> core::ops::Deref for InlineVec<T, LEN> {
+    type Target = [T];
+
+    fn deref(&self) -> &Self::Target {
+        &self.values[..self.len]
+    }
+}
+
+impl<T, const LEN: usize> core::ops::DerefMut for InlineVec<T, LEN> {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        &mut self.values[..self.len]
+    }
+}
diff --git a/quic/s2n-quic-crypto/src/aes/x86.rs b/quic/s2n-quic-crypto/src/aes/x86.rs
index 5544812593..4dd66bb31d 100644
--- a/quic/s2n-quic-crypto/src/aes/x86.rs
+++ b/quic/s2n-quic-crypto/src/aes/x86.rs
@@ -6,6 +6,7 @@ use crate::{
     arch::*,
     block::{BatchMut, Block, Zeroed},
 };
+use s2n_quic_core::assume;
 use zeroize::Zeroize;
 
 #[cfg(any(test, feature = "testing"))]
@@ -67,7 +68,7 @@ impl super::aes128::EncryptionKey for EncryptionKey {
     #[inline(always)]
     fn keyround(&self, index: usize) -> &Self::KeyRound {
         unsafe {
-            unsafe_assert!(index < N);
+            assume!(index < N);
             self.0.get_unchecked(index)
         }
     }
@@ -80,7 +81,7 @@ impl super::aes256::EncryptionKey for EncryptionKey {
     #[inline(always)]
     fn keyround(&self, index: usize) -> &Self::KeyRound {
         unsafe {
-            unsafe_assert!(index < N);
+            assume!(index < N);
             self.0.get_unchecked(index)
         }
     }
@@ -96,7 +97,7 @@ impl super::aes128::DecryptionKey for DecryptionKey {
     #[inline(always)]
     fn keyround(&self, index: usize) -> &Self::KeyRound {
         unsafe {
-            unsafe_assert!(index < N);
+            assume!(index < N);
             self.0.get_unchecked(index)
         }
     }
@@ -109,7 +110,7 @@ impl super::aes256::DecryptionKey for DecryptionKey {
     #[inline(always)]
     fn keyround(&self, index: usize) -> &Self::KeyRound {
         unsafe {
-            unsafe_assert!(index < N);
+            assume!(index < N);
             self.0.get_unchecked(index)
         }
     }
diff --git a/quic/s2n-quic-crypto/src/aesgcm/generic.rs b/quic/s2n-quic-crypto/src/aesgcm/generic.rs
index a7091171cb..11340a7537 100644
--- a/quic/s2n-quic-crypto/src/aesgcm/generic.rs
+++ b/quic/s2n-quic-crypto/src/aesgcm/generic.rs
@@ -16,6 +16,7 @@ use core::{
     marker::PhantomData,
     sync::atomic::{compiler_fence, Ordering},
 };
+use s2n_quic_core::assume;
 use zeroize::Zeroize;
 
 pub struct AesGcm {
@@ -135,7 +136,7 @@ where
         // then fill the last index with the ek0 block
         if can_interleave_ek0 && payload_block_count < N {
             unsafe {
-                unsafe_assert!(cipher_blocks.len() > N - 1);
+                assume!(cipher_blocks.len() > N - 1);
                 *cipher_blocks.get_unchecked_mut(N - 1) = ek0;
                 did_interleave_ek0 = true;
             }
@@ -151,7 +152,7 @@ where
             }
 
             let block = unsafe {
-                unsafe_assert!(idx < ghash_blocks.len());
+                assume!(idx < ghash_blocks.len());
                 ghash_blocks.get_unchecked(idx)
             };
             self.ghash.update(&mut ghash_state, block);
@@ -183,14 +184,14 @@ where
             |idx, block| {
                 // XOR the cipher blocks into the payload
                 let ghash_block = unsafe {
-                    unsafe_assert!(payload.len() >= BLOCK_LEN);
+                    assume!(payload.len() >= BLOCK_LEN);
                     let payload_block = payload.read_block();
                     payload.xor_block(payload_block, *block)
                 };
 
                 // move the cipher blocks to be hashed on the next batch
                 unsafe {
-                    unsafe_assert!(idx < ghash_blocks.len());
+                    assume!(idx < ghash_blocks.len());
                     *ghash_blocks.get_unchecked_mut(idx) = ghash_block;
                 }
             },
@@ -201,7 +202,7 @@ where
         }
 
         unsafe {
-            unsafe_assert!(
+            assume!(
                 partial_blocks <= N,
                 "only a single batch should be left to process"
             );
@@ -219,13 +220,13 @@ where
                 // XOR the cipher blocks into the payload
                 let ghash_block = if idx == last_block_idx {
                     unsafe {
-                        unsafe_assert!(0 < payload.len() && payload.len() < BLOCK_LEN);
+                        assume!(0 < payload.len() && payload.len() < BLOCK_LEN);
                         let payload_block = payload.read_last_block(payload_rem);
                         payload.xor_last_block(payload_block, *cipher_block, payload_rem)
                     }
                 } else {
                     unsafe {
-                        unsafe_assert!(payload.len() >= BLOCK_LEN);
+                        assume!(payload.len() >= BLOCK_LEN);
                         let payload_block = payload.read_block();
                         payload.xor_block(payload_block, *cipher_block)
                     }
@@ -244,7 +245,7 @@ where
         );
         if can_interleave_ek0 {
             ek0 = unsafe {
-                unsafe_assert!(cipher_blocks.len() > N - 1);
+                assume!(cipher_blocks.len() > N - 1);
                 *cipher_blocks.get_unchecked(N - 1)
             };
         } else {
@@ -332,7 +333,7 @@ where
         unsafe {
             // since QUIC short packets only contain small AAD values, we can limit the
            // amount of work to a single batch size.
-            unsafe_assert!(
+            assume!(
                 len <= N * BLOCK_LEN,
                 "aad cannot exceed {} bytes; got {}",
                 N * BLOCK_LEN,
diff --git a/quic/s2n-quic-crypto/src/aesgcm/x86.rs b/quic/s2n-quic-crypto/src/aesgcm/x86.rs
index 6721d18214..080b9dbaf4 100644
--- a/quic/s2n-quic-crypto/src/aesgcm/x86.rs
+++ b/quic/s2n-quic-crypto/src/aesgcm/x86.rs
@@ -9,6 +9,7 @@ use crate::{
         Block,
     },
 };
+use s2n_quic_core::assume;
 
 #[cfg(any(test, feature = "testing"))]
 pub mod testing;
@@ -21,13 +22,13 @@ impl Payload<__m128i> for &mut [u8] {
 
     #[inline(always)]
     unsafe fn read_block(&self) -> __m128i {
-        unsafe_assert!(self.len() >= BLOCK_LEN);
+        assume!(self.len() >= BLOCK_LEN);
         _mm_loadu_si128(*self as *const _ as *const _)
     }
 
     #[inline(always)]
     unsafe fn xor_block(&mut self, cleartext_block: __m128i, aes_block: __m128i) -> __m128i {
-        unsafe_assert!(self.len() >= BLOCK_LEN);
+        assume!(self.len() >= BLOCK_LEN);
         let addr = *self as *mut [u8] as *mut u8;
 
         // read the cleartext block and XOR it with the provided AES block
@@ -46,8 +47,8 @@ impl Payload<__m128i> for &mut [u8] {
 
     #[inline(always)]
     unsafe fn read_last_block(&self, len: usize) -> __m128i {
-        unsafe_assert!(0 < len && len < BLOCK_LEN);
-        unsafe_assert!(self.len() == len);
+        assume!(0 < len && len < BLOCK_LEN);
+        assume!(self.len() == len);
         __m128i::from_slice(self)
     }
 
@@ -58,8 +59,8 @@ impl Payload<__m128i> for &mut [u8] {
         aes_block: __m128i,
         len: usize,
     ) -> __m128i {
-        unsafe_assert!(0 < len && len < BLOCK_LEN);
-        unsafe_assert!(self.len() == len);
+        assume!(0 < len && len < BLOCK_LEN);
+        assume!(self.len() == len);
         let addr = *self as *mut [u8] as *mut u8;
 
         let xored = cleartext_block.xor(aes_block.mask(len));
diff --git a/quic/s2n-quic-crypto/src/block/x86.rs b/quic/s2n-quic-crypto/src/block/x86.rs
index daad763a11..241c31c1b2 100644
--- a/quic/s2n-quic-crypto/src/block/x86.rs
+++ b/quic/s2n-quic-crypto/src/block/x86.rs
@@ -6,6 +6,7 @@ use crate::{
     block::{Batch, BatchMut, Block, Zeroed},
 };
 use core::mem::size_of;
+use s2n_quic_core::assume;
 
 pub const LEN: usize = size_of::<__m128i>();
@@ -126,7 +127,7 @@ impl M128iExt for __m128i {
     fn into_slice(self, bytes: &mut [u8]) {
         unsafe {
             debug_assert!(Avx2::is_supported());
-            unsafe_assert!(bytes.len() <= LEN);
+            assume!(bytes.len() <= LEN);
             copy_128(
                 &self as *const _ as *const u8,
                 bytes.as_mut_ptr(),
@@ -139,7 +140,7 @@ impl M128iExt for __m128i {
     fn mask(self, len: usize) -> Self {
         unsafe {
             debug_assert!(Avx2::is_supported());
-            unsafe_assert!(0 < len && len < LEN);
+            assume!(0 < len && len < LEN);
 
             // compute a mask that can be shifted to only include a `len` of bytes
             const MASK: [u8; 31] = {
diff --git a/quic/s2n-quic-crypto/src/cipher_suite.rs b/quic/s2n-quic-crypto/src/cipher_suite.rs
index 660a44c64a..aad0ba6d18 100644
--- a/quic/s2n-quic-crypto/src/cipher_suite.rs
+++ b/quic/s2n-quic-crypto/src/cipher_suite.rs
@@ -4,7 +4,10 @@
 use crate::{aead::Aead, header_key::HeaderKey, iv};
 use ::ring::{aead, hkdf};
 use core::fmt;
-use s2n_quic_core::crypto::{label, CryptoError};
+use s2n_quic_core::{
+    assume,
+    crypto::{label, CryptoError},
+};
 use zeroize::{Zeroize, Zeroizing};
 
 mod negotiated;
@@ -141,7 +144,7 @@ macro_rules! impl_cipher_suite {
                 use core::convert::TryInto;
                 let res = (&tag[..]).try_into();
                 unsafe {
-                    unsafe_assert!(res.is_ok());
+                    assume!(res.is_ok());
                 }
                 res.unwrap()
             };
@@ -170,7 +173,7 @@ macro_rules! impl_cipher_suite {
                 use core::convert::TryInto;
                 let res = tag.try_into();
                 unsafe {
-                    unsafe_assert!(res.is_ok());
+                    assume!(res.is_ok());
                 }
                 res.unwrap()
             };
diff --git a/quic/s2n-quic-crypto/src/ghash/x86/precomputed.rs b/quic/s2n-quic-crypto/src/ghash/x86/precomputed.rs
index b8fc24c9f8..0327470045 100644
--- a/quic/s2n-quic-crypto/src/ghash/x86/precomputed.rs
+++ b/quic/s2n-quic-crypto/src/ghash/x86/precomputed.rs
@@ -10,6 +10,7 @@ use crate::{
         KEY_LEN,
     },
 };
+use s2n_quic_core::assume;
 use zeroize::{DefaultIsZeroes, Zeroize};
 
 impl ghash::GHash for P {
@@ -76,7 +77,7 @@ impl Powers for Allocated {
     #[inline(always)]
     fn power(&self, index: usize) -> &H {
         unsafe {
-            unsafe_assert!(index < self.state.len());
+            assume!(index < self.state.len());
             self.state.get_unchecked(index)
         }
    }
@@ -125,7 +126,7 @@ impl Powers for Array {
     #[inline(always)]
     fn power(&self, index: usize) -> &H {
         unsafe {
-            unsafe_assert!(index < self.state.len());
+            assume!(index < self.state.len());
             self.state.get_unchecked(index)
         }
     }
@@ -169,7 +170,7 @@ impl State {
     fn update(&self, powers: &P, b: &__m128i) -> Self {
         unsafe {
             debug_assert!(Avx2::is_supported());
-            unsafe_assert!(
+            assume!(
                 self.power != 0,
                 "update called more than requested capacity"
             );
@@ -208,7 +209,7 @@ impl State {
 
         unsafe {
             debug_assert!(Avx2::is_supported());
-            unsafe_assert!(
+            assume!(
                 power == 0,
                 "ghash update count incorrect: remaining {}",
                 power
diff --git a/quic/s2n-quic-crypto/src/lib.rs b/quic/s2n-quic-crypto/src/lib.rs
index bc397b37fc..dd1ef21277 100644
--- a/quic/s2n-quic-crypto/src/lib.rs
+++ b/quic/s2n-quic-crypto/src/lib.rs
@@ -1,24 +1,6 @@
 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 // SPDX-License-Identifier: Apache-2.0
 
-/// Asserts that a boolean expression is true at runtime, only if debug_assertions are enabled.
-///
-/// Otherwise, the compiler is told to assume that the expression is always true and can perform
-/// additional optimizations.
-macro_rules! unsafe_assert {
-    ($cond:expr) => {
-        unsafe_assert!($cond, "assumption failed: {}", stringify!($cond));
-    };
-    ($cond:expr $(, $fmtarg:expr)* $(,)?) => {
-        let v = $cond;
-
-        debug_assert!(v $(, $fmtarg)*);
-        if cfg!(not(debug_assertions)) && !v {
-            core::hint::unreachable_unchecked();
-        }
-    };
-}
-
 #[macro_use]
 mod negotiated;
 #[macro_use]