diff --git a/dc/s2n-quic-dc/Cargo.toml b/dc/s2n-quic-dc/Cargo.toml index c77c4562a3..e3eae626a0 100644 --- a/dc/s2n-quic-dc/Cargo.toml +++ b/dc/s2n-quic-dc/Cargo.toml @@ -25,6 +25,7 @@ bytes = "1" crossbeam-channel = "0.5" crossbeam-epoch = "0.9" crossbeam-queue = { version = "0.3" } +event-listener-strategy = "0.5" flurry = "0.5" libc = "0.2" num-rational = { version = "0.4", default-features = false } diff --git a/dc/s2n-quic-dc/src/credentials.rs b/dc/s2n-quic-dc/src/credentials.rs index da1c44338d..57ac71223c 100644 --- a/dc/s2n-quic-dc/src/credentials.rs +++ b/dc/s2n-quic-dc/src/credentials.rs @@ -16,18 +16,7 @@ pub use s2n_quic_core::varint::VarInt as KeyId; pub mod testing; #[derive( - Clone, - Copy, - Default, - PartialEq, - Eq, - Hash, - AsBytes, - FromBytes, - FromZeroes, - Unaligned, - PartialOrd, - Ord, + Clone, Copy, Default, PartialEq, Eq, AsBytes, FromBytes, FromZeroes, Unaligned, PartialOrd, Ord, )] #[cfg_attr( any(test, feature = "testing"), @@ -36,6 +25,15 @@ pub mod testing; #[repr(C)] pub struct Id([u8; 16]); +impl std::hash::Hash for Id { + fn hash(&self, state: &mut H) { + // The ID has very high quality entropy already, so write just one half of it to keep hash + // costs as low as possible. For the main use of the Hash impl in the fixed-size ID map + // this translates to just directly using these bytes for the indexing. + state.write_u64(u64::from_ne_bytes(self.0[..8].try_into().unwrap())); + } +} + impl fmt::Debug for Id { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { format_args!("{:#01x}", u128::from_be_bytes(self.0)).fmt(f) diff --git a/dc/s2n-quic-dc/src/fixed_map.rs b/dc/s2n-quic-dc/src/fixed_map.rs index 576abeba5e..18848b9737 100644 --- a/dc/s2n-quic-dc/src/fixed_map.rs +++ b/dc/s2n-quic-dc/src/fixed_map.rs @@ -8,6 +8,7 @@ //! extent possible) reducing the likelihood. use core::{ + fmt::Debug, hash::Hash, sync::atomic::{AtomicU8, Ordering}, }; @@ -21,7 +22,7 @@ pub struct Map { impl Map where - K: Hash + Eq, + K: Hash + Eq + Debug, S: BuildHasher, { pub fn with_capacity(entries: usize, hasher: S) -> Self { @@ -108,7 +109,7 @@ struct Slot { impl Slot where - K: Hash + Eq, + K: Hash + Eq + Debug, { fn new() -> Self { Slot { @@ -139,6 +140,10 @@ where // If `new_key` isn't already in this slot, replace one of the existing entries with the // new key. For now we rotate through based on `next_write`. let replacement = self.next_write.fetch_add(1, Ordering::Relaxed) as usize % SLOT_CAPACITY; + tracing::trace!( + "evicting {:?} - bucket overflow", + values[replacement].as_mut().unwrap().0 + ); values[replacement] = Some((new_key, new_value)); None } diff --git a/dc/s2n-quic-dc/src/lib.rs b/dc/s2n-quic-dc/src/lib.rs index 908fbcf683..71b63dbd51 100644 --- a/dc/s2n-quic-dc/src/lib.rs +++ b/dc/s2n-quic-dc/src/lib.rs @@ -17,6 +17,7 @@ pub mod random; pub mod recovery; pub mod socket; pub mod stream; +pub mod sync; pub mod task; #[cfg(any(test, feature = "testing"))] diff --git a/dc/s2n-quic-dc/src/path/secret.rs b/dc/s2n-quic-dc/src/path/secret.rs index d5115b9697..a519349376 100644 --- a/dc/s2n-quic-dc/src/path/secret.rs +++ b/dc/s2n-quic-dc/src/path/secret.rs @@ -12,3 +12,13 @@ pub mod stateless_reset; pub use key::{open, seal}; pub use map::Map; + +/// The handshake operation may return immediately if state for the target is already cached, +/// or perform an actual handshake if not. +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum HandshakeKind { + /// Handshake was skipped because a secret was already present in the cache + Cached, + /// Handshake was performed to generate a new secret + Fresh, +} diff --git a/dc/s2n-quic-dc/src/path/secret/map.rs b/dc/s2n-quic-dc/src/path/secret/map.rs index af29cb440f..faf4ea3d15 100644 --- a/dc/s2n-quic-dc/src/path/secret/map.rs +++ b/dc/s2n-quic-dc/src/path/secret/map.rs @@ -22,6 +22,7 @@ use s2n_quic_core::{ }; use std::{ fmt, + hash::{BuildHasherDefault, Hasher}, net::{Ipv4Addr, SocketAddr}, sync::{ atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering}, @@ -51,6 +52,24 @@ pub struct Map { pub(super) state: Arc, } +#[derive(Default)] +pub(super) struct NoopIdHasher(Option); + +impl Hasher for NoopIdHasher { + fn finish(&self) -> u64 { + self.0.unwrap() + } + + fn write(&mut self, _bytes: &[u8]) { + unimplemented!() + } + + fn write_u64(&mut self, x: u64) { + debug_assert!(self.0.is_none()); + self.0 = Some(x); + } +} + // # Managing memory consumption // // For regular rotation with live peers, we retain at most two secrets: one derived from the most @@ -93,7 +112,7 @@ pub(super) struct State { pub(super) requested_handshakes: flurry::HashSet, // All known entries. - pub(super) ids: fixed_map::Map>, + pub(super) ids: fixed_map::Map, BuildHasherDefault>, pub(super) signer: stateless_reset::Signer, @@ -232,7 +251,7 @@ impl State { } impl Map { - pub fn new(signer: stateless_reset::Signer) -> Self { + pub fn new(signer: stateless_reset::Signer, capacity: usize) -> Self { // FIXME: Avoid unwrap and the whole socket. // // We only ever send on this socket - but we really should be sending on the same @@ -244,11 +263,11 @@ impl Map { control_socket.set_nonblocking(true).unwrap(); let state = State { // This is around 500MB with current entry size. - max_capacity: 500_000, + max_capacity: capacity, // FIXME: Allow configuring the rehandshake_period. rehandshake_period: Duration::from_secs(3600 * 24), - peers: fixed_map::Map::with_capacity(500_000, Default::default()), - ids: fixed_map::Map::with_capacity(500_000, Default::default()), + peers: fixed_map::Map::with_capacity(capacity, Default::default()), + ids: fixed_map::Map::with_capacity(capacity, Default::default()), requested_handshakes: Default::default(), cleaner: Cleaner::new(), signer, @@ -301,6 +320,19 @@ impl Map { Some((sealer, credentials, state.parameters.clone())) } + /// Retrieve a sealer by path secret ID. + /// + /// Generally callers should prefer to use one of the `pair` APIs; this is primarily useful for + /// "response" datagrams which want to be bound to the exact same shared secret. + /// + /// Note that unlike by-IP lookup this should typically not be done significantly after the + /// original secret was used for decryption. + pub fn seal_once_id(&self, id: Id) -> Option<(seal::Once, Credentials, ApplicationParams)> { + let state = self.state.ids.get_by_key(&id)?; + let (sealer, credentials) = state.uni_sealer(); + Some((sealer, credentials, state.parameters.clone())) + } + pub fn open_once( &self, credentials: &Credentials, @@ -485,7 +517,7 @@ impl Map { pub fn for_test_with_peers( peers: Vec<(schedule::Ciphersuite, dc::Version, SocketAddr)>, ) -> (Self, Vec) { - let provider = Self::new(stateless_reset::Signer::random()); + let provider = Self::new(stateless_reset::Signer::random(), peers.len() * 3); let mut secret = [0; 32]; aws_lc_rs::rand::fill(&mut secret).unwrap(); let mut stateless_reset = [0; control::TAG_LEN]; diff --git a/dc/s2n-quic-dc/src/path/secret/map/test.rs b/dc/s2n-quic-dc/src/path/secret/map/test.rs index f295c6e80b..b0a8fa41b8 100644 --- a/dc/s2n-quic-dc/src/path/secret/map/test.rs +++ b/dc/s2n-quic-dc/src/path/secret/map/test.rs @@ -32,7 +32,7 @@ fn fake_entry(peer: u16) -> Arc { #[test] fn cleans_after_delay() { let signer = stateless_reset::Signer::new(b"secret"); - let map = Map::new(signer); + let map = Map::new(signer, 50); // Stop background processing. We expect to manually invoke clean, and a background worker // might interfere with our state. @@ -60,7 +60,7 @@ fn cleans_after_delay() { #[test] fn thread_shutdown() { let signer = stateless_reset::Signer::new(b"secret"); - let map = Map::new(signer); + let map = Map::new(signer, 10); let state = Arc::downgrade(&map.state); drop(map); @@ -263,7 +263,7 @@ fn check_invariants() { let mut model = Model::default(); let signer = stateless_reset::Signer::new(b"secret"); - let mut map = Map::new(signer); + let mut map = Map::new(signer, 10_000); // Avoid background work interfering with testing. map.state.cleaner.stop(); @@ -293,7 +293,7 @@ fn check_invariants_no_overflow() { let mut model = Model::default(); let signer = stateless_reset::Signer::new(b"secret"); - let map = Map::new(signer); + let map = Map::new(signer, 10_000); // Avoid background work interfering with testing. map.state.cleaner.stop(); @@ -316,7 +316,7 @@ fn check_invariants_no_overflow() { #[ignore = "memory growth takes a long time to run"] fn no_memory_growth() { let signer = stateless_reset::Signer::new(b"secret"); - let map = Map::new(signer); + let map = Map::new(signer, 100_000); map.state.cleaner.stop(); for idx in 0..500_000 { // FIXME: this ends up 2**16 peers in the `peers` map @@ -325,10 +325,15 @@ fn no_memory_growth() { } #[test] -#[cfg(all(target_pointer_width = "64", target_os = "linux"))] fn entry_size() { + let mut should_check = true; + + should_check &= cfg!(target_pointer_width = "64"); + should_check &= cfg!(target_os = "linux"); + should_check &= std::env::var("S2N_QUIC_RUN_VERSION_SPECIFIC_TESTS").is_ok(); + // This gates to running only on specific GHA to reduce false positives. - if std::env::var("S2N_QUIC_RUN_VERSION_SPECIFIC_TESTS").is_ok() { + if should_check { assert_eq!(fake_entry(0).size(), 238); } } diff --git a/dc/s2n-quic-dc/src/stream.rs b/dc/s2n-quic-dc/src/stream.rs index a41b639eeb..1cd9465841 100644 --- a/dc/s2n-quic-dc/src/stream.rs +++ b/dc/s2n-quic-dc/src/stream.rs @@ -11,6 +11,7 @@ pub const DEFAULT_INFLIGHT_TIMEOUT: Duration = Duration::from_secs(5); pub const MAX_DATAGRAM_SIZE: usize = 1 << 15; // 32k pub mod application; +pub mod client; pub mod crypto; pub mod endpoint; pub mod environment; diff --git a/dc/s2n-quic-dc/src/stream/client.rs b/dc/s2n-quic-dc/src/stream/client.rs new file mode 100644 index 0000000000..5c924e6891 --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/client.rs @@ -0,0 +1,4 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +pub mod tokio; diff --git a/dc/s2n-quic-dc/src/stream/client/tokio.rs b/dc/s2n-quic-dc/src/stream/client/tokio.rs new file mode 100644 index 0000000000..9cd4eda9e0 --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/client/tokio.rs @@ -0,0 +1,121 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + path::secret, + stream::{ + application::Stream, + endpoint, + environment::tokio::{self as env, Environment}, + socket::Protocol, + }, +}; +use std::{io, net::SocketAddr}; +use tokio::net::TcpStream; + +/// Connects using the UDP transport layer +#[inline] +pub async fn connect_udp( + handshake_addr: SocketAddr, + handshake: H, + acceptor_addr: SocketAddr, + env: &Environment, + map: &secret::Map, +) -> io::Result +where + H: core::future::Future>, +{ + // ensure we have a secret for the peer + handshake.await?; + + let stream = endpoint::open_stream( + env, + handshake_addr.into(), + env::UdpUnbound(acceptor_addr.into()), + map, + None, + )?; + + // build the stream inside the application context + let mut stream = stream.build()?; + + debug_assert_eq!(stream.protocol(), Protocol::Udp); + + write_prelude(&mut stream).await?; + + Ok(stream) +} + +/// Connects using the TCP transport layer +#[inline] +pub async fn connect_tcp( + handshake_addr: SocketAddr, + handshake: H, + acceptor_addr: SocketAddr, + env: &Environment, + map: &secret::Map, +) -> io::Result +where + H: core::future::Future>, +{ + // Race TCP handshake with the TLS handshake + let (socket, _) = tokio::try_join!(TcpStream::connect(acceptor_addr), handshake,)?; + + let stream = endpoint::open_stream( + env, + handshake_addr.into(), + env::TcpRegistered(socket), + map, + None, + )?; + + // build the stream inside the application context + let mut stream = stream.build()?; + + debug_assert_eq!(stream.protocol(), Protocol::Tcp); + + write_prelude(&mut stream).await?; + + Ok(stream) +} + +/// Connects with a pre-existing TCP stream +/// +/// # Note +/// +/// The provided `map` must contain a shared secret for the `handshake_addr` +#[inline] +pub async fn connect_tcp_with( + handshake_addr: SocketAddr, + stream: TcpStream, + env: &Environment, + map: &secret::Map, +) -> io::Result { + let stream = endpoint::open_stream( + env, + handshake_addr.into(), + env::TcpRegistered(stream), + map, + None, + )?; + + // build the stream inside the application context + let mut stream = stream.build()?; + + debug_assert_eq!(stream.protocol(), Protocol::Tcp); + + write_prelude(&mut stream).await?; + + Ok(stream) +} + +#[inline] +async fn write_prelude(stream: &mut Stream) -> io::Result<()> { + // TODO should we actually write the prelude here or should we do late sealer binding on + // the first packet to reduce secret reordering on the peer + + stream + .write_from(&mut s2n_quic_core::buffer::reader::storage::Empty) + .await + .map(|_| ()) +} diff --git a/dc/s2n-quic-dc/src/stream/endpoint.rs b/dc/s2n-quic-dc/src/stream/endpoint.rs index df5e44ad4c..64efdc6dd5 100644 --- a/dc/s2n-quic-dc/src/stream/endpoint.rs +++ b/dc/s2n-quic-dc/src/stream/endpoint.rs @@ -19,10 +19,7 @@ use s2n_quic_core::{ inet::{ExplicitCongestionNotification, SocketAddress}, varint::VarInt, }; -use std::{ - io, - sync::{atomic::Ordering, Arc}, -}; +use std::{io, sync::Arc}; use tracing::{debug_span, Instrument as _}; type Result = core::result::Result; @@ -196,7 +193,7 @@ where let flow = flow::non_blocking::State::new(flow_offset); let path = send::path::Info { - max_datagram_size: parameters.max_datagram_size.load(Ordering::Relaxed), + max_datagram_size: parameters.max_datagram_size(), send_quantum, ecn: ExplicitCongestionNotification::Ect0, next_expected_control_packet: VarInt::ZERO, diff --git a/dc/s2n-quic-dc/src/stream/environment/tokio.rs b/dc/s2n-quic-dc/src/stream/environment/tokio.rs index 8e67f310f0..882f367a11 100644 --- a/dc/s2n-quic-dc/src/stream/environment/tokio.rs +++ b/dc/s2n-quic-dc/src/stream/environment/tokio.rs @@ -25,9 +25,15 @@ pub struct Builder { reader_rt: Option, writer_rt: Option, thread_name_prefix: Option, + threads: Option, } impl Builder { + pub fn with_threads(mut self, threads: usize) -> Self { + self.threads = Some(threads); + self + } + #[inline] pub fn build(self) -> io::Result { let clock = self.clock.unwrap_or_default(); @@ -36,20 +42,26 @@ impl Builder { let thread_name_prefix = self.thread_name_prefix.as_deref().unwrap_or("dc_quic"); - let reader_rt = self.reader_rt.map(>::Ok).unwrap_or_else(|| { - Ok(tokio::runtime::Builder::new_multi_thread() - .enable_all() - .thread_name(format!("{thread_name_prefix}::reader")) - .build()? - .into()) - })?; - let writer_rt = self.writer_rt.map(>::Ok).unwrap_or_else(|| { - Ok(tokio::runtime::Builder::new_multi_thread() + let make_rt = |suffix: &str, threads: Option| { + let mut builder = tokio::runtime::Builder::new_multi_thread(); + if let Some(threads) = threads { + builder.worker_threads(threads); + } + Ok(builder .enable_all() - .thread_name(format!("{thread_name_prefix}::writer")) + .thread_name(format!("{thread_name_prefix}::{suffix}")) .build()? .into()) - })?; + }; + + let reader_rt = self + .reader_rt + .map(>::Ok) + .unwrap_or_else(|| make_rt("reader", self.threads))?; + let writer_rt = self + .writer_rt + .map(>::Ok) + .unwrap_or_else(|| make_rt("writer", self.threads))?; Ok(Environment { clock, diff --git a/dc/s2n-quic-dc/src/stream/send/state.rs b/dc/s2n-quic-dc/src/stream/send/state.rs index e75cf290e5..9bd33af893 100644 --- a/dc/s2n-quic-dc/src/stream/send/state.rs +++ b/dc/s2n-quic-dc/src/stream/send/state.rs @@ -42,10 +42,7 @@ use s2n_quic_core::{ varint::VarInt, }; use slotmap::SlotMap; -use std::{ - collections::{BinaryHeap, VecDeque}, - sync::atomic::Ordering, -}; +use std::collections::{BinaryHeap, VecDeque}; use tracing::{debug, trace}; pub mod probe; @@ -121,7 +118,7 @@ pub struct PeerActivity { impl State { #[inline] pub fn new(stream_id: stream::Id, params: &ApplicationParams) -> Self { - let max_datagram_size = params.max_datagram_size.load(Ordering::Relaxed); + let max_datagram_size = params.max_datagram_size(); let initial_max_data = params.remote_max_data; let local_max_data = params.local_send_max_data; diff --git a/dc/s2n-quic-dc/src/stream/server.rs b/dc/s2n-quic-dc/src/stream/server.rs index ae6046f094..c4f4f494f0 100644 --- a/dc/s2n-quic-dc/src/stream/server.rs +++ b/dc/s2n-quic-dc/src/stream/server.rs @@ -7,6 +7,7 @@ use crate::{credentials::Credentials, msg::recv, packet}; use s2n_codec::{DecoderBufferMut, DecoderError}; pub mod handshake; +pub mod tokio; #[derive(Debug)] pub struct InitialPacket { diff --git a/dc/s2n-quic-dc/src/stream/server/tokio.rs b/dc/s2n-quic-dc/src/stream/server/tokio.rs new file mode 100644 index 0000000000..bf0ebd7ea1 --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/server/tokio.rs @@ -0,0 +1,7 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +pub mod accept; +pub mod stats; +pub mod tcp; +pub mod udp; diff --git a/dc/s2n-quic-dc/src/stream/server/tokio/accept.rs b/dc/s2n-quic-dc/src/stream/server/tokio/accept.rs new file mode 100644 index 0000000000..4e841cfada --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/server/tokio/accept.rs @@ -0,0 +1,143 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use super::stats; +use crate::{ + stream::{ + application::{Builder as StreamBuilder, Stream}, + environment::{tokio::Environment, Environment as _}, + }, + sync::channel, +}; +use core::time::Duration; +use s2n_quic_core::time::{Clock, Timestamp}; +use std::{io, net::SocketAddr}; +use tokio::time::sleep; + +#[derive(Clone, Copy, Default)] +pub enum Flavor { + #[default] + Fifo, + Lifo, +} + +pub type Sender = channel::Sender<(StreamBuilder, Timestamp)>; +pub type Receiver = channel::Receiver<(StreamBuilder, Timestamp)>; + +#[inline] +pub async fn accept(streams: &Receiver, stats: &stats::Sender) -> io::Result<(Stream, SocketAddr)> { + let (stream, queue_time) = streams.recv_front().await.map_err(|_err| { + io::Error::new( + io::ErrorKind::NotConnected, + "server acceptor runtime is no longer available", + ) + })?; + + let now = stream.shared.common.clock.get_time(); + let sojourn_time = now.saturating_duration_since(queue_time); + + // submit sojourn time statistics + stats.send(sojourn_time); + + // build the stream inside the application context + let stream = stream.build()?; + let remote_addr = stream.peer_addr()?; + + Ok((stream, remote_addr)) +} + +#[derive(Clone, Debug)] +pub struct Pruner { + /// Any sojourn duration multiplied by this value is unlikely to be accepted in time + pub sojourn_multiplier: u32, + + /// Don't prune anything under this amount, just so we can handle bursts of streams and + /// not prematurely drop things. + pub min_threshold: Duration, + + /// Anything older than this amount has likely timed out at this point. No need to hold + /// on to the stream any longer at this point + pub max_threshold: Duration, + + /// Minimum amount of time to sleep before pruning the queue again + pub min_sleep_time: Duration, + + /// Maximum amount of time to sleep before pruning the queue again + pub max_sleep_time: Duration, +} + +impl Default for Pruner { + fn default() -> Self { + Self { + sojourn_multiplier: 3, + min_threshold: Duration::from_millis(100), + max_threshold: Duration::from_secs(5), + min_sleep_time: Duration::from_millis(100), + max_sleep_time: Duration::from_secs(1), + } + } +} + +impl Pruner { + /// A task which prunes the accept queue to enforce a maximum sojourn time + pub async fn run( + self, + env: Environment, + channel: channel::WeakReceiver<(StreamBuilder, Timestamp)>, + stats: stats::Stats, + ) { + let Self { + sojourn_multiplier, + min_threshold, + max_threshold, + min_sleep_time, + max_sleep_time, + } = self; + + sleep(min_sleep_time).await; + + loop { + let now = env.clock().get_time(); + let smoothed_sojourn_time = stats.smoothed_sojourn_time(); + + // compute the oldest allowed queue time + let Some(queue_time_threshold) = now.checked_sub( + (smoothed_sojourn_time * sojourn_multiplier).clamp(min_threshold, max_threshold), + ) else { + sleep(min_sleep_time).await; + continue; + }; + + // Use optional locks to avoid lock contention. If there is contention on the channel, the + // old streams will naturally be pruned, since old ones will be dropped in favor of new + // ones. + let priority = channel::Priority::Optional; + + loop { + // pop off any items that have expired + let res = channel.pop_back_if(priority, |(_stream, queue_time)| { + queue_time.has_elapsed(queue_time_threshold) + }); + + match res { + // we pruned a stream + Ok(Some((stream, queue_time))) => { + tracing::debug!( + event = "accept::prune", + credentials = ?stream.shared.credentials(), + queue_duration = ?now.saturating_duration_since(queue_time), + ); + continue; + } + // no more streams left to prune + Ok(None) => break, + // the channel was closed + Err(_) => return, + } + } + + // wake up later based on the smoothed sojourn time + sleep(smoothed_sojourn_time.clamp(min_sleep_time, max_sleep_time)).await; + } + } +} diff --git a/dc/s2n-quic-dc/src/stream/server/tokio/stats.rs b/dc/s2n-quic-dc/src/stream/server/tokio/stats.rs new file mode 100644 index 0000000000..0a0785a066 --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/server/tokio/stats.rs @@ -0,0 +1,103 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use crate::sync::channel as chan; +use core::{ + sync::atomic::{AtomicU64, Ordering}, + time::Duration, +}; +use s2n_quic_core::{packet::number::PacketNumberSpace, recovery::RttEstimator, time::Clock}; +use std::sync::Arc; +use tokio::time::sleep; + +pub fn channel() -> (Sender, Worker, Stats) { + // TODO configure this queue depth? + let (send, recv) = chan::new(1024); + let sender = Sender(send); + let stats = Stats::default(); + let worker = Worker { + queue: recv, + stats: stats.clone(), + }; + (sender, worker, stats) +} + +#[derive(Clone)] +pub struct Sender(chan::Sender); + +impl Sender { + #[inline] + pub fn send(&self, sojourn_time: Duration) { + // prefer recent samples + let _ = self.0.send_back(sojourn_time); + } +} + +pub struct Worker { + queue: chan::Receiver, + stats: Stats, +} + +impl Worker { + pub async fn run(self, clock: C) { + let mut rtt_estimator = RttEstimator::new(Duration::from_secs(30)); + + let debounce = Duration::from_millis(5); + let timeout = Duration::from_millis(5); + + loop { + let Ok(sample) = self.queue.recv_back().await else { + break; + }; + + let now = clock.get_time(); + rtt_estimator.update_rtt( + Duration::ZERO, + sample, + now, + true, + PacketNumberSpace::ApplicationData, + ); + + // allow some more samples to come through + sleep(debounce).await; + + while let Ok(Some(sample)) = self.queue.try_recv_back() { + rtt_estimator.update_rtt( + Duration::ZERO, + sample, + now, + true, + PacketNumberSpace::ApplicationData, + ); + } + + self.stats.update(&rtt_estimator); + + // wait before taking a new sample to avoid spinning + sleep(timeout).await; + } + } +} + +#[derive(Clone, Default)] +pub struct Stats(Arc); + +impl Stats { + #[inline] + pub fn smoothed_sojourn_time(&self) -> Duration { + Duration::from_nanos(self.0.smoothed_sojourn_time.load(Ordering::Relaxed)) + } + + fn update(&self, rtt_estimator: &RttEstimator) { + let smoothed_rtt = rtt_estimator.smoothed_rtt().as_nanos().min(u64::MAX as _) as _; + self.0 + .smoothed_sojourn_time + .store(smoothed_rtt, Ordering::Relaxed); + } +} + +#[derive(Default)] +struct StatsState { + smoothed_sojourn_time: AtomicU64, +} diff --git a/dc/s2n-quic-dc/src/stream/server/tokio/tcp.rs b/dc/s2n-quic-dc/src/stream/server/tokio/tcp.rs new file mode 100644 index 0000000000..5d1f6da82a --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/server/tokio/tcp.rs @@ -0,0 +1,596 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use super::accept; +use crate::{ + msg, + path::secret, + stream::{ + endpoint, + environment::{ + tokio::{self as env, Environment}, + Environment as _, + }, + server, + socket::Socket, + }, +}; +use core::{ + future::poll_fn, + ops::ControlFlow, + pin::Pin, + task::{Context, Poll}, + time::Duration, +}; +use s2n_quic_core::{ + ensure, + packet::number::PacketNumberSpace, + ready, + recovery::RttEstimator, + time::{Clock, Timestamp}, +}; +use std::{collections::VecDeque, io}; +use tokio::{ + io::AsyncWrite as _, + net::{TcpListener, TcpStream}, +}; +use tracing::{debug, error, trace}; + +pub struct Acceptor { + sender: accept::Sender, + socket: TcpListener, + env: Environment, + secrets: secret::Map, + backlog: usize, + accept_flavor: accept::Flavor, +} + +impl Acceptor { + #[inline] + pub fn new( + socket: TcpListener, + sender: &accept::Sender, + env: &Environment, + secrets: &secret::Map, + backlog: usize, + accept_flavor: accept::Flavor, + ) -> Self { + Self { + sender: sender.clone(), + socket, + env: env.clone(), + secrets: secrets.clone(), + backlog, + accept_flavor, + } + } + + pub async fn run(self) { + let drop_guard = DropLog; + let mut fresh = FreshQueue::new(&self); + let mut workers = WorkerSet::new(&self); + let mut context = WorkerContext::new(&self); + + poll_fn(move |cx| { + let now = self.env.clock().get_time(); + + fresh.fill(cx, &self.socket); + trace!(accepted_backlog = fresh.len()); + + for socket in fresh.drain() { + workers.push(socket, now); + } + + trace!(pre_worker_count = workers.working.len()); + let res = workers.poll(cx, &mut context, now); + trace!(post_worker_count = workers.working.len()); + + workers.invariants(); + + if res.is_continue() { + Poll::Pending + } else { + Poll::Ready(()) + } + }) + .await; + + drop(drop_guard); + } +} + +/// Converts the kernel's TCP FIFO accept queue to LIFO +/// +/// This should produce overall better latencies in the case of overloaded queues. +struct FreshQueue { + queue: VecDeque, +} + +impl FreshQueue { + fn new(acceptor: &Acceptor) -> Self { + Self { + queue: VecDeque::with_capacity(acceptor.backlog), + } + } + + fn fill(&mut self, cx: &mut Context, listener: &TcpListener) { + // Allow draining the queue twice the capacity + // + // The idea here is to try and reduce the number of connections in the kernel's queue while + // bounding the amount of work we do in userspace. + // + // TODO: investigate getting the current length and dropping the front of the queue rather + // than pop/push with the userspace queue + let mut remaining = self.queue.capacity() * 2; + + while let Poll::Ready(res) = listener.poll_accept(cx) { + match res { + Ok((socket, _remote_addr)) => { + if self.queue.len() == self.queue.capacity() { + let _ = self.queue.pop_back(); + trace!("fresh backlog too full; dropping stream"); + } + // most recent streams go to the front of the line, since they're the most + // likely to be successfully processed + self.queue.push_front(socket); + } + Err(err) => { + // TODO submit to a separate error channel that the application can subscribe + // to + error!(listener_error = %err); + } + } + + remaining -= 1; + + if remaining == 0 { + return; + } + } + } + + fn len(&self) -> usize { + self.queue.len() + } + + fn drain(&mut self) -> impl Iterator + '_ { + self.queue.drain(..) + } +} + +struct WorkerSet { + /// A set of worker entries which process newly-accepted streams + workers: Box<[Worker]>, + /// FIFO queue for tracking free [`Worker`] entries + /// + /// None of the indices in this queue have associated sockets and are waiting to be assigned + /// for work. + free: VecDeque, + /// A list of [`Worker`] entries that are currently processing a socket + /// + /// This list is ordered by sojourn time, where the front of the list is the oldest. The front + /// will be the first to be reclaimed in the case of overload. + working: VecDeque, + /// Tracks the [sojourn time](https://en.wikipedia.org/wiki/Mean_sojourn_time) of processing + /// streams in worker entries. + sojourn_time: RttEstimator, +} + +impl WorkerSet { + #[inline] + pub fn new(acceptor: &Acceptor) -> Self { + let backlog = acceptor.backlog; + let mut workers = Vec::with_capacity(backlog); + let mut free = VecDeque::with_capacity(backlog); + let now = acceptor.env.clock().get_time(); + for idx in 0..backlog { + workers.push(Worker::new(now)); + free.push_back(idx); + } + Self { + workers: workers.into(), + free, + working: VecDeque::with_capacity(backlog), + // set the initial estimate high to avoid backlog churn before we get stable samples + sojourn_time: RttEstimator::new(Duration::from_secs(30)), + } + } + + #[inline] + pub fn push(&mut self, stream: TcpStream, now: Timestamp) { + let Some(idx) = self.next_worker(now) else { + // NOTE: we do not apply back pressure on the listener's `accept` since the aim is to + // keep that queue as short as possible so we can control the behavior in userspace. + // + // TODO: we need to investigate how this interacts with SYN cookies/retries and fast + // failure modes in kernel space. + trace!( + "could not find an available worker; dropping stream {:?}", + stream + ); + drop(stream); + return; + }; + self.workers[idx].push(stream, now); + self.working.push_back(idx); + } + + #[inline] + pub fn poll( + &mut self, + cx: &mut Context, + worker_cx: &mut WorkerContext, + now: Timestamp, + ) -> ControlFlow<()> { + let mut cf = ControlFlow::Continue(()); + + self.working.retain(|&idx| { + let worker = &mut self.workers[idx]; + let Poll::Ready(res) = worker.poll(cx, worker_cx, now) else { + // keep processing it + return true; + }; + + match res { + Ok(ControlFlow::Continue(())) => { + // update the accept_time estimate + let sample = worker.sojourn(now); + trace!(sojourn_sample = ?sample); + + self.sojourn_time.update_rtt( + Duration::ZERO, + worker.sojourn(now), + now, + true, + PacketNumberSpace::ApplicationData, + ); + } + Ok(ControlFlow::Break(())) => { + cf = ControlFlow::Break(()); + } + Err(err) => { + debug!(accept_stream_error = %err); + } + } + + // the worker is done so remove it from the working queue + self.free.push_back(idx); + false + }); + + cf + } + + #[inline] + fn next_worker(&mut self, now: Timestamp) -> Option { + // if we have a free worker then use that + if let Some(idx) = self.free.pop_front() { + trace!(op = %"next_worker", free = idx); + return Some(idx); + } + + let idx = *self.working.front().unwrap(); + let sojourn = self.workers[idx].sojourn(now); + + // if the worker's sojourn time exceeds the maximum, then reclaim it + if sojourn > self.max_sojourn_time() { + trace!(op = %"next_worker", injected = idx, ?sojourn); + return self.working.pop_front(); + } + + trace!(op = %"next_worker", ?sojourn, max_sojourn_time = ?self.max_sojourn_time()); + + None + } + + #[inline] + fn max_sojourn_time(&self) -> Duration { + // if we're double the smoothed sojourn time then the latency is already quite high on the + // stream - better to accept a new stream at this point + // + // FIXME: This currently hardcodes the min/max to try to avoid issues with very fast or + // very slow clients skewing our behavior too much, but it's not clear what the goal is. + (self.sojourn_time.smoothed_rtt() * 2).clamp(Duration::from_secs(1), Duration::from_secs(5)) + } + + #[cfg(not(debug_assertions))] + fn invariants(&self) {} + + #[cfg(debug_assertions)] + fn invariants(&self) { + for idx in 0..self.workers.len() { + let in_ready = self.free.contains(&idx); + let in_working = self.working.contains(&idx); + assert!( + in_working ^ in_ready, + "worker should either be in ready ({in_ready}) or working ({in_working}) list" + ); + } + + for idx in self.free.iter().copied() { + let worker = &self.workers[idx]; + assert!(worker.stream.is_none()); + assert!( + matches!(worker.state, WorkerState::Init), + "actual={:?}", + worker.state + ); + } + + let mut prev_queue_time = None; + for idx in self.working.iter().copied() { + let worker = &self.workers[idx]; + assert!(worker.stream.is_some()); + let queue_time = worker.queue_time; + if let Some(prev) = prev_queue_time { + assert!( + prev <= queue_time, + "front should be oldest; prev={prev:?}, queue_time={queue_time:?}" + ); + } + prev_queue_time = Some(queue_time); + } + } +} + +struct WorkerContext { + recv_buffer: msg::recv::Message, + sender: accept::Sender, + env: Environment, + secrets: secret::Map, + accept_flavor: accept::Flavor, +} + +impl WorkerContext { + fn new(acceptor: &Acceptor) -> Self { + Self { + recv_buffer: msg::recv::Message::new(u16::MAX), + sender: acceptor.sender.clone(), + env: acceptor.env.clone(), + secrets: acceptor.secrets.clone(), + accept_flavor: acceptor.accept_flavor, + } + } +} + +struct Worker { + queue_time: Timestamp, + stream: Option, + state: WorkerState, +} + +impl Worker { + pub fn new(now: Timestamp) -> Self { + Self { + queue_time: now, + stream: None, + state: WorkerState::Init, + } + } + + #[inline] + pub fn push(&mut self, stream: TcpStream, now: Timestamp) { + self.queue_time = now; + let prev_state = core::mem::replace(&mut self.state, WorkerState::Init); + let prev = core::mem::replace(&mut self.stream, Some(stream)); + + if prev.is_some() { + trace!(worker_prev_state = ?prev_state); + } + } + + #[inline] + pub fn poll( + &mut self, + cx: &mut Context, + context: &mut WorkerContext, + now: Timestamp, + ) -> Poll>> { + // if we don't have a stream then it's a bug in the worker impl - in production just return + // `Ready`, which will correct the state + if self.stream.is_none() { + debug_assert!( + false, + "Worker::poll should only be called with an active socket" + ); + return Poll::Ready(Ok(ControlFlow::Continue(()))); + } + + // make sure another worker didn't leave around a buffer + context.recv_buffer.clear(); + + let res = ready!(self + .state + .poll(cx, context, &mut self.stream, self.queue_time, now)); + + // if we're ready then reset the worker + self.state = WorkerState::Init; + self.stream = None; + + Poll::Ready(res) + } + + /// Returns the duration that the worker has been processing a stream + #[inline] + pub fn sojourn(&self, now: Timestamp) -> Duration { + now.saturating_duration_since(self.queue_time) + } +} + +#[derive(Debug)] +enum WorkerState { + /// Worker is waiting for a packet + Init, + /// Worker received a partial packet and is waiting on more data + Buffering(msg::recv::Message), + /// Worker encountered an error and is trying to send a response + Erroring { + offset: usize, + buffer: Vec, + error: io::Error, + }, +} + +impl WorkerState { + fn poll( + &mut self, + cx: &mut Context, + context: &mut WorkerContext, + stream: &mut Option, + queue_time: Timestamp, + now: Timestamp, + ) -> Poll>> { + loop { + // figure out where to put the received bytes + let (recv_buffer, recv_buffer_owned) = match self { + // borrow the context's recv buffer initially + WorkerState::Init => (&mut context.recv_buffer, false), + // we have our own recv buffer to use + WorkerState::Buffering(recv_buffer) => (recv_buffer, true), + // we encountered an error so try and send it back + WorkerState::Erroring { offset, buffer, .. } => { + let stream = Pin::new(stream.as_mut().unwrap()); + let len = ready!(stream.poll_write(cx, &buffer[*offset..]))?; + + *offset += len; + + // if we still need to send part of the buffer then loop back around + if *offset < buffer.len() { + continue; + } + + // io::Error doesn't implement clone so we have to take the error to return it + let WorkerState::Erroring { error, .. } = core::mem::replace(self, Self::Init) + else { + unreachable!() + }; + + return Err(error).into(); + } + }; + + // try to read an initial packet from the socket + let res = Self::poll_initial_packet(cx, stream.as_mut().unwrap(), recv_buffer); + + let Poll::Ready(res) = res else { + // if we got `Pending` but we don't own the recv buffer then we need to copy it + // into the worker so we can resume where we left off last time + if !recv_buffer_owned { + *self = WorkerState::Buffering(recv_buffer.take()); + }; + + return Poll::Pending; + }; + + let initial_packet = res?; + + debug!(?initial_packet); + + let stream_builder = match endpoint::accept_stream( + &context.env, + env::TcpReregistered(stream.take().unwrap()), + &initial_packet, + None, + Some(recv_buffer), + &context.secrets, + None, + ) { + Ok(stream) => stream, + Err(error) => { + if let Some(env::TcpReregistered(socket)) = error.peer { + if !error.secret_control.is_empty() { + // if we need to send an error then update the state and loop back + // around + *stream = Some(socket); + *self = WorkerState::Erroring { + offset: 0, + buffer: error.secret_control, + error: error.error, + }; + continue; + } + } + return Err(error.error).into(); + } + }; + + trace!( + enqueue_stream = ?stream_builder.shared.remote_ip(), + sojourn_time = ?now.saturating_duration_since(queue_time), + ); + + let item = (stream_builder, queue_time); + let res = match context.accept_flavor { + accept::Flavor::Fifo => context.sender.send_back(item), + accept::Flavor::Lifo => context.sender.send_front(item), + }; + + return Poll::Ready(Ok(match res { + Ok(prev) => { + if let Some((stream, queue_time)) = prev { + debug!( + event = "accept::prune", + credentials = ?stream.shared.credentials(), + queue_duration = ?now.saturating_duration_since(queue_time), + ); + } + ControlFlow::Continue(()) + } + Err(_err) => { + debug!("application accept queue dropped; shutting down"); + ControlFlow::Break(()) + } + })); + } + } + + #[inline] + fn poll_initial_packet( + cx: &mut Context, + stream: &mut TcpStream, + recv_buffer: &mut msg::recv::Message, + ) -> Poll> { + loop { + ensure!( + recv_buffer.payload_len() < 10_000, + Err(io::Error::new( + io::ErrorKind::InvalidData, + "prelude did not come in the first 10k bytes" + )) + .into() + ); + + let res = ready!(stream.poll_recv_buffer(cx, recv_buffer))?; + + match server::InitialPacket::peek(recv_buffer, 16) { + Ok(packet) => { + return Ok(packet).into(); + } + Err(s2n_codec::DecoderError::UnexpectedEof(_)) => { + // If at end of the stream, we're not going to succeed. End early. + if res == 0 { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "insufficient data in prelude before EOF", + )) + .into(); + } + // we don't have enough bytes buffered so try reading more + continue; + } + Err(err) => { + return Err(io::Error::new(io::ErrorKind::InvalidData, err.to_string())).into(); + } + } + } + } +} + +struct DropLog; + +impl Drop for DropLog { + #[inline] + fn drop(&mut self) { + debug!("acceptor task has been dropped"); + } +} diff --git a/dc/s2n-quic-dc/src/stream/server/tokio/udp.rs b/dc/s2n-quic-dc/src/stream/server/tokio/udp.rs new file mode 100644 index 0000000000..a7890cbd99 --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/server/tokio/udp.rs @@ -0,0 +1,148 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use super::accept; +use crate::{ + msg, + path::secret, + stream::{ + endpoint, + environment::{ + tokio::{self as env, Environment}, + Environment as _, + }, + server, + socket::{Ext as _, Socket}, + }, +}; +use core::ops::ControlFlow; +use s2n_quic_core::time::Clock as _; +use std::io; +use tracing::debug; + +pub struct Acceptor { + sender: accept::Sender, + socket: S, + recv_buffer: msg::recv::Message, + handshake: server::handshake::Map, + env: Environment, + secrets: secret::Map, + accept_flavor: accept::Flavor, +} + +impl Acceptor { + #[inline] + pub fn new( + socket: S, + sender: &accept::Sender, + env: &Environment, + secrets: &secret::Map, + accept_flavor: accept::Flavor, + ) -> Self { + Self { + sender: sender.clone(), + socket, + recv_buffer: msg::recv::Message::new(9000.try_into().unwrap()), + handshake: Default::default(), + env: env.clone(), + secrets: secrets.clone(), + accept_flavor, + } + } + + pub async fn run(mut self) { + loop { + match self.accept_one().await { + Ok(ControlFlow::Continue(())) => continue, + Ok(ControlFlow::Break(())) => break, + Err(err) => { + tracing::error!(acceptor_error = %err); + } + } + } + } + + async fn accept_one(&mut self) -> io::Result> { + let packet = self.recv_packet().await?; + + let now = self.env.clock().get_time(); + + let server::handshake::Outcome::Created { + receiver: handshake, + } = self.handshake.handle(&packet, &mut self.recv_buffer) + else { + return Ok(ControlFlow::Continue(())); + }; + + let remote_addr = self.recv_buffer.remote_address(); + let stream = match endpoint::accept_stream( + &self.env, + env::UdpUnbound(remote_addr), + &packet, + Some(handshake), + Some(&mut self.recv_buffer), + &self.secrets, + None, + ) { + Ok(stream) => stream, + Err(error) => { + tracing::trace!("send_start"); + + let addr = msg::addr::Addr::new(remote_addr); + let ecn = Default::default(); + let buffer = &[io::IoSlice::new(&error.secret_control)]; + + // ignore any errors since this is just for responding to invalid connect attempts + let _ = self.socket.try_send(&addr, ecn, buffer); + + tracing::trace!("send_finish"); + return Err(error.error); + } + }; + + let item = (stream, now); + let res = match self.accept_flavor { + accept::Flavor::Fifo => self.sender.send_back(item), + accept::Flavor::Lifo => self.sender.send_front(item), + }; + + match res { + Ok(prev) => { + if let Some((stream, queue_time)) = prev { + debug!( + event = "accept::prune", + credentials = ?stream.shared.credentials(), + queue_duration = ?now.saturating_duration_since(queue_time), + ); + } + + Ok(ControlFlow::Continue(())) + } + Err(_) => { + debug!("application accept queue dropped; shutting down"); + Ok(ControlFlow::Break(())) + } + } + } + + async fn recv_packet(&mut self) -> io::Result { + loop { + // discard any pending packets + self.recv_buffer.clear(); + tracing::trace!("recv_start"); + self.socket.recv_buffer(&mut self.recv_buffer).await?; + tracing::trace!("recv_finish"); + + match server::InitialPacket::peek(&mut self.recv_buffer, 16) { + Ok(initial_packet) => { + tracing::debug!(?initial_packet); + return Ok(initial_packet); + } + Err(initial_packet_error) => { + tracing::debug!(?initial_packet_error); + continue; + } + } + } + } +} diff --git a/dc/s2n-quic-dc/src/sync.rs b/dc/s2n-quic-dc/src/sync.rs new file mode 100644 index 0000000000..893bcc3842 --- /dev/null +++ b/dc/s2n-quic-dc/src/sync.rs @@ -0,0 +1,5 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +pub mod channel; +pub mod ring_deque; diff --git a/dc/s2n-quic-dc/src/sync/channel.rs b/dc/s2n-quic-dc/src/sync/channel.rs new file mode 100644 index 0000000000..a3141e67fd --- /dev/null +++ b/dc/s2n-quic-dc/src/sync/channel.rs @@ -0,0 +1,318 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use crate::sync::ring_deque::{self, RingDeque}; +use core::{fmt, marker::PhantomPinned, pin::Pin, task::Poll}; +use event_listener_strategy::{ + easy_wrapper, + event_listener::{Event, EventListener}, + EventListenerFuture, Strategy, +}; +use pin_project_lite::pin_project; +use s2n_quic_core::ready; +use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, Weak, +}; + +pub use ring_deque::{Closed, Priority}; + +pub fn new(cap: usize) -> (Sender, Receiver) { + assert!(cap >= 1, "capacity must be at least 2"); + + let channel = Arc::new(Channel { + queue: RingDeque::new(cap), + recv_ops: Event::new(), + sender_count: AtomicUsize::new(1), + receiver_count: AtomicUsize::new(1), + }); + + let s = Sender { + channel: channel.clone(), + }; + let r = Receiver { + listener: None, + channel, + _pin: PhantomPinned, + }; + (s, r) +} + +struct Channel { + queue: RingDeque, + recv_ops: Event, + sender_count: AtomicUsize, + receiver_count: AtomicUsize, +} + +impl Channel { + /// Closes the channel and notifies all blocked operations. + /// + /// Returns `Err` if this call has closed the channel and it was not closed already. + fn close(&self) -> Result<(), Closed> { + self.queue.close()?; + + // Notify all receive and send operations. + self.recv_ops.notify(usize::MAX); + + Ok(()) + } +} + +pub struct Sender { + channel: Arc>, +} + +impl Sender { + #[inline] + pub fn send_back(&self, msg: T) -> Result, Closed> { + let res = self.channel.queue.push_back(msg)?; + + // Notify a blocked receive operation. If the notified operation gets canceled, + // it will notify another blocked receive operation. + self.channel.recv_ops.notify_additional(1); + + Ok(res) + } + + #[inline] + pub fn send_front(&self, msg: T) -> Result, Closed> { + let res = self.channel.queue.push_front(msg)?; + + // Notify a blocked receive operation. If the notified operation gets canceled, + // it will notify another blocked receive operation. + self.channel.recv_ops.notify_additional(1); + + Ok(res) + } +} + +impl Drop for Sender { + fn drop(&mut self) { + // Decrement the sender count and close the channel if it drops down to zero. + if self.channel.sender_count.fetch_sub(1, Ordering::AcqRel) == 1 { + let _ = self.channel.close(); + } + } +} + +impl fmt::Debug for Sender { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Sender {{ .. }}") + } +} + +impl Clone for Sender { + fn clone(&self) -> Sender { + let count = self.channel.sender_count.fetch_add(1, Ordering::Relaxed); + + // Make sure the count never overflows, even if lots of sender clones are leaked. + assert!(count < usize::MAX / 2, "too many senders"); + + Sender { + channel: self.channel.clone(), + } + } +} + +pin_project! { + /// The receiving side of a channel. + /// + /// Receivers can be cloned and shared among threads. When all receivers associated with a channel + /// are dropped, the channel becomes closed. + /// + /// The channel can also be closed manually by calling [`Receiver::close()`]. + /// + /// Receivers implement the [`Stream`] trait. + pub struct Receiver { + // Inner channel state. + channel: Arc>, + + // Listens for a send or close event to unblock this stream. + listener: Option, + + // Keeping this type `!Unpin` enables future optimizations. + #[pin] + _pin: PhantomPinned + } + + impl PinnedDrop for Receiver { + fn drop(this: Pin<&mut Self>) { + let this = this.project(); + + // Decrement the receiver count and close the channel if it drops down to zero. + if this.channel.receiver_count.fetch_sub(1, Ordering::AcqRel) == 1 { + let _ = this.channel.close(); + } + } + } +} + +#[allow(dead_code)] // TODO remove this once the module is public +impl Receiver { + /// Attempts to receive a message from the front of the channel. + /// + /// If the channel is empty, or empty and closed, this method returns an error. + #[inline] + pub fn try_recv_front(&self) -> Result, Closed> { + self.channel.queue.pop_front() + } + + /// Attempts to receive a message from the back of the channel. + /// + /// If the channel is empty, or empty and closed, this method returns an error. + #[inline] + pub fn try_recv_back(&self) -> Result, Closed> { + self.channel.queue.pop_back() + } + + /// Receives a message from the front of the channel. + /// + /// If the channel is empty, this method waits until there is a message. + /// + /// If the channel is closed, this method receives a message or returns an error if there are + /// no more messages. + #[inline] + pub fn recv_front(&self) -> Recv<'_, T> { + Recv::_new(RecvInner { + receiver: self, + pop_end: PopEnd::Front, + listener: None, + _pin: PhantomPinned, + }) + } + + /// Receives a message from the back of the channel. + /// + /// If the channel is empty, this method waits until there is a message. + /// + /// If the channel is closed, this method receives a message or returns an error if there are + /// no more messages. + #[inline] + pub fn recv_back(&self) -> Recv<'_, T> { + Recv::_new(RecvInner { + receiver: self, + pop_end: PopEnd::Back, + listener: None, + _pin: PhantomPinned, + }) + } + + #[inline] + pub fn downgrade(&self) -> WeakReceiver { + WeakReceiver { + channel: Arc::downgrade(&self.channel), + } + } +} + +impl fmt::Debug for Receiver { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Receiver {{ .. }}") + } +} + +impl Clone for Receiver { + fn clone(&self) -> Receiver { + let count = self.channel.receiver_count.fetch_add(1, Ordering::Relaxed); + + // Make sure the count never overflows, even if lots of receiver clones are leaked. + assert!(count < usize::MAX / 2); + + Receiver { + channel: self.channel.clone(), + listener: None, + _pin: PhantomPinned, + } + } +} + +#[derive(Clone)] +pub struct WeakReceiver { + channel: Weak>, +} + +#[allow(dead_code)] // TODO remove this once the module is public +impl WeakReceiver { + #[inline] + pub fn pop_front_if(&self, priority: Priority, f: F) -> Result, Closed> + where + F: FnOnce(&T) -> bool, + { + let channel = self.channel.upgrade().ok_or(Closed)?; + channel.queue.pop_front_if(priority, f) + } + + #[inline] + pub fn pop_back_if(&self, priority: Priority, f: F) -> Result, Closed> + where + F: FnOnce(&T) -> bool, + { + let channel = self.channel.upgrade().ok_or(Closed)?; + channel.queue.pop_back_if(priority, f) + } +} + +easy_wrapper! { + /// A future returned by [`Receiver::recv()`]. + #[derive(Debug)] + #[must_use = "futures do nothing unless you `.await` or poll them"] + pub struct Recv<'a, T>(RecvInner<'a, T> => Result); + pub(crate) wait(); +} + +#[derive(Debug)] +enum PopEnd { + Front, + Back, +} + +pin_project! { + #[derive(Debug)] + #[project(!Unpin)] + struct RecvInner<'a, T> { + // Reference to the receiver. + receiver: &'a Receiver, + + pop_end: PopEnd, + + // Listener waiting on the channel. + listener: Option, + + // Keeping this type `!Unpin` enables future optimizations. + #[pin] + _pin: PhantomPinned + } +} + +impl<'a, T> EventListenerFuture for RecvInner<'a, T> { + type Output = Result; + + /// Run this future with the given `Strategy`. + fn poll_with_strategy<'x, S: Strategy<'x>>( + self: Pin<&mut Self>, + strategy: &mut S, + cx: &mut S::Context, + ) -> Poll> { + let this = self.project(); + + loop { + // Attempt to receive a message. + let message = match this.pop_end { + PopEnd::Front => this.receiver.try_recv_front(), + PopEnd::Back => this.receiver.try_recv_back(), + }?; + if let Some(msg) = message { + return Poll::Ready(Ok(msg)); + } + + // Receiving failed - now start listening for notifications or wait for one. + if this.listener.is_some() { + // Poll using the given strategy + ready!(S::poll(strategy, &mut *this.listener, cx)); + } else { + *this.listener = Some(this.receiver.channel.recv_ops.listen()); + } + } + } +} diff --git a/dc/s2n-quic-dc/src/sync/ring_deque.rs b/dc/s2n-quic-dc/src/sync/ring_deque.rs new file mode 100644 index 0000000000..bd952ace45 --- /dev/null +++ b/dc/s2n-quic-dc/src/sync/ring_deque.rs @@ -0,0 +1,168 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use s2n_quic_core::ensure; +use std::{ + collections::VecDeque, + sync::{Arc, Mutex}, +}; + +#[cfg(test)] +mod tests; + +#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] +pub struct Closed; + +#[derive(Clone, Copy, Debug, Default)] +pub enum Priority { + #[default] + Required, + Optional, +} + +pub struct RingDeque { + inner: Arc>>, +} + +impl Clone for RingDeque { + #[inline] + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + } + } +} + +#[allow(dead_code)] // TODO remove this once the module is public +impl RingDeque { + #[inline] + pub fn new(capacity: usize) -> Self { + let queue = VecDeque::with_capacity(capacity); + let inner = Inner { open: true, queue }; + let inner = Arc::new(Mutex::new(inner)); + RingDeque { inner } + } + + #[inline] + pub fn push_back(&self, value: T) -> Result, Closed> { + let mut inner = self.lock()?; + + let prev = if inner.queue.capacity() == inner.queue.len() { + inner.queue.pop_front() + } else { + None + }; + + inner.queue.push_back(value); + + Ok(prev) + } + + #[inline] + pub fn push_front(&self, value: T) -> Result, Closed> { + let mut inner = self.lock()?; + + let prev = if inner.queue.capacity() == inner.queue.len() { + inner.queue.pop_back() + } else { + None + }; + + inner.queue.push_front(value); + + Ok(prev) + } + + #[inline] + pub fn pop_back(&self) -> Result, Closed> { + let mut inner = self.lock()?; + Ok(inner.queue.pop_back()) + } + + #[inline] + pub fn pop_back_if(&self, priority: Priority, check: F) -> Result, Closed> + where + F: FnOnce(&T) -> bool, + { + let inner = match priority { + Priority::Required => Some(self.lock()?), + Priority::Optional => self.try_lock()?, + }; + + let Some(mut inner) = inner else { + return Ok(None); + }; + + let Some(back) = inner.queue.back() else { + return Ok(None); + }; + + if check(back) { + Ok(inner.queue.pop_back()) + } else { + Ok(None) + } + } + + #[inline] + pub fn pop_front(&self) -> Result, Closed> { + let mut inner = self.lock()?; + Ok(inner.queue.pop_front()) + } + + #[inline] + pub fn pop_front_if(&self, priority: Priority, check: F) -> Result, Closed> + where + F: FnOnce(&T) -> bool, + { + let inner = match priority { + Priority::Required => Some(self.lock()?), + Priority::Optional => self.try_lock()?, + }; + + let Some(mut inner) = inner else { + return Ok(None); + }; + + let Some(back) = inner.queue.front() else { + return Ok(None); + }; + + if check(back) { + Ok(inner.queue.pop_front()) + } else { + Ok(None) + } + } + + #[inline] + pub fn close(&self) -> Result<(), Closed> { + let mut inner = self.lock()?; + inner.open = false; + Ok(()) + } + + #[inline] + fn lock(&self) -> Result>, Closed> { + let inner = self.inner.lock().unwrap(); + ensure!(inner.open, Err(Closed)); + Ok(inner) + } + + #[inline] + fn try_lock(&self) -> Result>>, Closed> { + use std::sync::TryLockError; + let inner = match self.inner.try_lock() { + Ok(inner) => inner, + Err(TryLockError::WouldBlock) => return Ok(None), + Err(TryLockError::Poisoned(_)) => return Err(Closed), + }; + ensure!(inner.open, Err(Closed)); + Ok(Some(inner)) + } +} + +struct Inner { + open: bool, + queue: VecDeque, +} diff --git a/dc/s2n-quic-dc/src/sync/ring_deque/tests.rs b/dc/s2n-quic-dc/src/sync/ring_deque/tests.rs new file mode 100644 index 0000000000..058d41ef89 --- /dev/null +++ b/dc/s2n-quic-dc/src/sync/ring_deque/tests.rs @@ -0,0 +1,160 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use super::*; +use bolero::{check, TypeGenerator}; +use core::fmt; + +struct Model { + subject: RingDeque, + oracle: VecDeque, + open: bool, +} + +impl Default for Model { + fn default() -> Self { + Self::new(32) + } +} + +impl Model { + pub fn new(cap: usize) -> Self { + Self { + subject: RingDeque::new(cap), + oracle: VecDeque::with_capacity(cap), + open: true, + } + } +} + +impl Model { + pub fn pop_front(&mut self) -> Result, Closed> { + let expected = if self.open { + Ok(self.oracle.pop_front()) + } else { + Err(Closed) + }; + let actual = self.subject.pop_front(); + assert_eq!(expected, actual); + actual + } + + pub fn pop_back(&mut self) -> Result, Closed> { + let expected = if self.open { + Ok(self.oracle.pop_back()) + } else { + Err(Closed) + }; + let actual = self.subject.pop_back(); + assert_eq!(expected, actual); + actual + } + + pub fn push_front(&mut self, v: T) -> Result, Closed> { + let actual = self.subject.push_front(v.clone()); + let expected = if self.open { + let prev = if self.oracle.capacity() == self.oracle.len() { + self.oracle.pop_back() + } else { + None + }; + self.oracle.push_front(v); + Ok(prev) + } else { + Err(Closed) + }; + assert_eq!(expected, actual); + actual + } + + pub fn push_back(&mut self, v: T) -> Result, Closed> { + let actual = self.subject.push_back(v.clone()); + let expected = if self.open { + let prev = if self.oracle.capacity() == self.oracle.len() { + self.oracle.pop_front() + } else { + None + }; + self.oracle.push_back(v); + Ok(prev) + } else { + Err(Closed) + }; + assert_eq!(expected, actual); + actual + } + + pub fn close(&mut self) -> Result<(), Closed> { + let actual = self.subject.close(); + let expected = if self.open { + self.open = false; + Ok(()) + } else { + Err(Closed) + }; + assert_eq!(actual, expected); + actual + } +} + +#[derive(Clone, Copy, Debug, TypeGenerator)] +enum Operation { + PushFront, + PushBack, + PopFront, + PopBack, + Close, +} + +#[test] +fn model_test() { + check!().with_type::>().for_each(|ops| { + let mut v = 0; + let mut model = Model::::default(); + for op in ops { + match op { + Operation::PushFront => { + let _ = model.push_front(v); + v += 1; + } + Operation::PushBack => { + let _ = model.push_back(v); + v += 1; + } + Operation::PopFront => { + let _ = model.pop_front(); + } + Operation::PopBack => { + let _ = model.pop_back(); + } + Operation::Close => { + let _ = model.close(); + } + } + } + }) +} + +#[test] +fn overflow_front_test() { + let mut model = Model::new(4); + let _ = model.push_front(0); + let _ = model.push_front(1); + let _ = model.push_front(2); + let _ = model.push_front(3); + assert_eq!(model.push_front(4), Ok(Some(0))); + + assert_eq!(model.oracle.make_contiguous(), &[4, 3, 2, 1]); +} + +#[test] +fn overflow_back_test() { + let mut model = Model::new(4); + let _ = model.push_back(0); + let _ = model.push_back(1); + let _ = model.push_back(2); + let _ = model.push_back(3); + assert_eq!(model.push_back(4), Ok(Some(0))); + + assert_eq!(model.oracle.make_contiguous(), &[1, 2, 3, 4]); +} diff --git a/quic/s2n-quic-core/src/dc.rs b/quic/s2n-quic-core/src/dc.rs index 70c83d5b75..4a4a78c174 100644 --- a/quic/s2n-quic-core/src/dc.rs +++ b/quic/s2n-quic-core/src/dc.rs @@ -140,6 +140,10 @@ impl ApplicationParams { pub fn max_idle_timeout(&self) -> Option { Some(Duration::from_millis(self.max_idle_timeout?.get() as u64)) } + + pub fn max_datagram_size(&self) -> u16 { + self.max_datagram_size.load(Ordering::Relaxed) + } } #[cfg(test)]