From c1220675907c7ea82236b1ccd73507fdaebb19d5 Mon Sep 17 00:00:00 2001 From: Cameron Bytheway Date: Mon, 8 Jul 2024 13:25:05 -0600 Subject: [PATCH] feat(s2n-quic-dc): import latest changes --- dc/s2n-quic-dc/Cargo.toml | 4 + dc/s2n-quic-dc/src/clock.rs | 86 +++++ dc/s2n-quic-dc/src/clock/tokio.rs | 122 +++++++ dc/s2n-quic-dc/src/congestion.rs | 9 +- dc/s2n-quic-dc/src/lib.rs | 3 + dc/s2n-quic-dc/src/msg.rs | 1 + dc/s2n-quic-dc/src/msg/recv.rs | 4 +- dc/s2n-quic-dc/src/msg/segment.rs | 98 ++++++ dc/s2n-quic-dc/src/path.rs | 37 +- dc/s2n-quic-dc/src/path/secret/key.rs | 6 + dc/s2n-quic-dc/src/path/secret/map.rs | 27 +- dc/s2n-quic-dc/src/path/secret/map/test.rs | 4 +- dc/s2n-quic-dc/src/random.rs | 77 +++++ dc/s2n-quic-dc/src/stream.rs | 3 + dc/s2n-quic-dc/src/stream/crypto.rs | 93 +++++ dc/s2n-quic-dc/src/stream/pacer.rs | 47 +++ dc/s2n-quic-dc/src/stream/recv.rs | 27 +- dc/s2n-quic-dc/src/stream/recv/packet.rs | 6 +- dc/s2n-quic-dc/src/stream/send.rs | 2 + dc/s2n-quic-dc/src/stream/send/application.rs | 4 +- dc/s2n-quic-dc/src/stream/send/flow.rs | 9 +- .../src/stream/send/flow/blocking.rs | 2 +- .../src/stream/send/flow/non_blocking.rs | 2 +- dc/s2n-quic-dc/src/stream/send/path.rs | 42 ++- dc/s2n-quic-dc/src/stream/send/queue.rs | 326 ++++++++++++++++++ dc/s2n-quic-dc/src/stream/send/shared.rs | 121 +++++++ dc/s2n-quic-dc/src/stream/send/worker.rs | 11 +- dc/s2n-quic-dc/src/stream/socket.rs | 35 ++ .../src/stream/socket/application.rs | 92 +++++ .../src/stream/socket/application/builder.rs | 72 ++++ dc/s2n-quic-dc/src/stream/socket/fd.rs | 37 ++ dc/s2n-quic-dc/src/stream/socket/fd/tcp.rs | 71 ++++ dc/s2n-quic-dc/src/stream/socket/fd/udp.rs | 145 ++++++++ dc/s2n-quic-dc/src/stream/socket/handle.rs | 181 ++++++++++ dc/s2n-quic-dc/src/stream/socket/tokio.rs | 7 + dc/s2n-quic-dc/src/stream/socket/tokio/tcp.rs | 144 ++++++++ dc/s2n-quic-dc/src/stream/socket/tokio/udp.rs | 212 ++++++++++++ dc/s2n-quic-dc/src/stream/socket/tracing.rs | 155 +++++++++ 
dc/s2n-quic-dc/src/task.rs | 4 + dc/s2n-quic-dc/src/task/waker.rs | 4 + dc/s2n-quic-dc/src/task/waker/worker.rs | 126 +++++++ quic/s2n-quic-core/src/macros.rs | 2 +- quic/s2n-quic-core/src/time/clock.rs | 8 +- 43 files changed, 2353 insertions(+), 115 deletions(-) create mode 100644 dc/s2n-quic-dc/src/clock.rs create mode 100644 dc/s2n-quic-dc/src/clock/tokio.rs create mode 100644 dc/s2n-quic-dc/src/msg/segment.rs create mode 100644 dc/s2n-quic-dc/src/random.rs create mode 100644 dc/s2n-quic-dc/src/stream/crypto.rs create mode 100644 dc/s2n-quic-dc/src/stream/pacer.rs create mode 100644 dc/s2n-quic-dc/src/stream/send/queue.rs create mode 100644 dc/s2n-quic-dc/src/stream/send/shared.rs create mode 100644 dc/s2n-quic-dc/src/stream/socket.rs create mode 100644 dc/s2n-quic-dc/src/stream/socket/application.rs create mode 100644 dc/s2n-quic-dc/src/stream/socket/application/builder.rs create mode 100644 dc/s2n-quic-dc/src/stream/socket/fd.rs create mode 100644 dc/s2n-quic-dc/src/stream/socket/fd/tcp.rs create mode 100644 dc/s2n-quic-dc/src/stream/socket/fd/udp.rs create mode 100644 dc/s2n-quic-dc/src/stream/socket/handle.rs create mode 100644 dc/s2n-quic-dc/src/stream/socket/tokio.rs create mode 100644 dc/s2n-quic-dc/src/stream/socket/tokio/tcp.rs create mode 100644 dc/s2n-quic-dc/src/stream/socket/tokio/udp.rs create mode 100644 dc/s2n-quic-dc/src/stream/socket/tracing.rs create mode 100644 dc/s2n-quic-dc/src/task.rs create mode 100644 dc/s2n-quic-dc/src/task/waker.rs create mode 100644 dc/s2n-quic-dc/src/task/waker/worker.rs diff --git a/dc/s2n-quic-dc/Cargo.toml b/dc/s2n-quic-dc/Cargo.toml index d1fa274f6e..b6cf820e86 100644 --- a/dc/s2n-quic-dc/Cargo.toml +++ b/dc/s2n-quic-dc/Cargo.toml @@ -14,18 +14,22 @@ exclude = ["corpus.tar.gz"] testing = ["bolero-generator", "s2n-quic-core/testing"] [dependencies] +arrayvec = "0.7" atomic-waker = "1" aws-lc-rs = "1" bitflags = "2" bolero-generator = { version = "0.11", optional = true } bytes = "1" crossbeam-channel = "0.5" 
+crossbeam-epoch = "0.9" crossbeam-queue = { version = "0.3" } flurry = "0.5" libc = "0.2" num-rational = { version = "0.4", default-features = false } once_cell = "1" +pin-project-lite = "0.2" rand = { version = "0.8", features = ["small_rng"] } +rand_chacha = "0.3" s2n-codec = { version = "=0.41.0", path = "../../common/s2n-codec", default-features = false } s2n-quic-core = { version = "=0.41.0", path = "../../quic/s2n-quic-core", default-features = false } s2n-quic-platform = { version = "=0.41.0", path = "../../quic/s2n-quic-platform" } diff --git a/dc/s2n-quic-dc/src/clock.rs b/dc/s2n-quic-dc/src/clock.rs new file mode 100644 index 0000000000..a8bf3cdcbc --- /dev/null +++ b/dc/s2n-quic-dc/src/clock.rs @@ -0,0 +1,86 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use core::{fmt, pin::Pin, task::Poll, time::Duration}; +use s2n_quic_core::{ensure, time}; +use tracing::trace; + +pub mod tokio; +pub use time::clock::Cached; + +pub use time::Timestamp; +pub type SleepHandle = Pin>; + +pub trait Clock: 'static + Send + Sync + fmt::Debug + time::Clock { + fn sleep(&self, amount: Duration) -> (SleepHandle, Timestamp); +} + +pub trait Sleep: Clock + core::future::Future { + fn update(self: Pin<&mut Self>, target: Timestamp); +} + +pub struct Timer { + /// The `Instant` at which the timer should expire + target: Option, + /// The handle to the timer entry in the tokio runtime + sleep: Pin>, +} + +impl fmt::Debug for Timer { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Timer") + .field("target", &self.target) + .finish() + } +} + +impl Timer { + #[inline] + pub fn new(clock: &dyn Clock) -> Self { + /// We can't create a timer without first arming it to something, so just set it to 1s in + /// the future. 
+ const INITIAL_TIMEOUT: Duration = Duration::from_secs(1); + + Self::new_with_timeout(clock, INITIAL_TIMEOUT) + } + + #[inline] + pub fn new_with_timeout(clock: &dyn Clock, timeout: Duration) -> Self { + let (sleep, target) = clock.sleep(timeout); + Self { + target: Some(target), + sleep, + } + } + + #[inline] + pub fn cancel(&mut self) { + trace!(cancel = ?self.target); + self.target = None; + } +} + +impl time::clock::Timer for Timer { + #[inline] + fn poll_ready(&mut self, cx: &mut core::task::Context) -> Poll<()> { + ensure!(self.target.is_some(), Poll::Ready(())); + + let res = self.sleep.as_mut().poll(cx); + + if res.is_ready() { + // clear the target after it fires, otherwise we'll endlessly wake up the task + self.target = None; + } + + res + } + + #[inline] + fn update(&mut self, target: Timestamp) { + // no need to update if it hasn't changed + ensure!(self.target != Some(target)); + + self.sleep.as_mut().update(target); + self.target = Some(target); + } +} diff --git a/dc/s2n-quic-dc/src/clock/tokio.rs b/dc/s2n-quic-dc/src/clock/tokio.rs new file mode 100644 index 0000000000..e80ff5d7cf --- /dev/null +++ b/dc/s2n-quic-dc/src/clock/tokio.rs @@ -0,0 +1,122 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +use super::SleepHandle; +use core::{ + fmt, + future::Future, + pin::Pin, + task::{Context, Poll}, + time::Duration, +}; +use pin_project_lite::pin_project; +use s2n_quic_core::{ready, time::Timestamp}; +use tokio::time::{self, sleep_until, Instant}; +use tracing::trace; + +#[derive(Clone, Debug)] +pub struct Clock(Instant); + +impl Default for Clock { + #[inline] + fn default() -> Self { + Self(Instant::now()) + } +} + +impl s2n_quic_core::time::Clock for Clock { + #[inline] + fn get_time(&self) -> Timestamp { + let time = self.0.elapsed(); + unsafe { Timestamp::from_duration(time) } + } +} + +pin_project!( + pub struct Sleep { + clock: Clock, + #[pin] + sleep: time::Sleep, + } +); + +impl s2n_quic_core::time::Clock for Sleep { + #[inline] + fn get_time(&self) -> Timestamp { + self.clock.get_time() + } +} + +impl Future for Sleep { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + ready!(core::future::Future::poll(this.sleep, cx)); + Poll::Ready(()) + } +} + +impl super::Sleep for Sleep { + #[inline] + fn update(self: Pin<&mut Self>, target: Timestamp) { + let target = unsafe { target.as_duration() }; + + // floor the delay to milliseconds to reduce timer churn + let delay = Duration::from_millis(target.as_millis() as u64); + + let target = self.clock.0 + delay; + + // if the clock has changed let the sleep future know + trace!(update = ?target); + self.project().sleep.reset(target); + } +} + +impl super::Clock for Sleep { + #[inline] + fn sleep(&self, amount: Duration) -> (SleepHandle, Timestamp) { + self.clock.sleep(amount) + } +} + +impl fmt::Debug for Sleep { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Sleep") + .field("clock", &self.clock) + .field("sleep", &self.sleep) + .finish() + } +} + +impl super::Clock for Clock { + #[inline] + fn sleep(&self, amount: Duration) -> (SleepHandle, Timestamp) { + let now = 
Instant::now(); + let sleep = sleep_until(now + amount); + let sleep = Sleep { + clock: self.clone(), + sleep, + }; + let sleep = Box::pin(sleep); + let target = now.saturating_duration_since(self.0); + let target = unsafe { Timestamp::from_duration(target) }; + (sleep, target) + } +} + +#[cfg(test)] +mod tests { + use crate::clock::{tokio::Clock, Timer}; + use core::time::Duration; + use s2n_quic_core::time::{clock::Timer as _, Clock as _}; + + #[tokio::test] + async fn clock_test() { + let clock = Clock::default(); + let mut timer = Timer::new(&clock); + timer.ready().await; + timer.update(clock.get_time() + Duration::from_secs(1)); + timer.ready().await; + } +} diff --git a/dc/s2n-quic-dc/src/congestion.rs b/dc/s2n-quic-dc/src/congestion.rs index 9fdd9a9432..8d181554da 100644 --- a/dc/s2n-quic-dc/src/congestion.rs +++ b/dc/s2n-quic-dc/src/congestion.rs @@ -19,11 +19,10 @@ pub struct Controller { impl Controller { #[inline] - pub fn new(mtu: u16) -> Self { - let mut controller = BbrCongestionController::new(mtu, Default::default()); - let publisher = &mut NoopPublisher; - controller.on_mtu_update(mtu, publisher); - Self { controller } + pub fn new(max_datagram_size: u16) -> Self { + Self { + controller: BbrCongestionController::new(max_datagram_size, Default::default()), + } } #[inline] diff --git a/dc/s2n-quic-dc/src/lib.rs b/dc/s2n-quic-dc/src/lib.rs index e98c61d057..0ce8c3e3af 100644 --- a/dc/s2n-quic-dc/src/lib.rs +++ b/dc/s2n-quic-dc/src/lib.rs @@ -2,6 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 pub mod allocator; +pub mod clock; pub mod congestion; pub mod control; pub mod credentials; @@ -11,8 +12,10 @@ pub mod msg; pub mod packet; pub mod path; pub mod pool; +pub mod random; pub mod recovery; pub mod socket; pub mod stream; +pub mod task; pub use s2n_quic_core::dc::{Version, SUPPORTED_VERSIONS}; diff --git a/dc/s2n-quic-dc/src/msg.rs b/dc/s2n-quic-dc/src/msg.rs index 25a3b74412..a34cbc8845 100644 --- a/dc/s2n-quic-dc/src/msg.rs +++ 
b/dc/s2n-quic-dc/src/msg.rs @@ -4,4 +4,5 @@ pub mod addr; pub mod cmsg; pub mod recv; +pub mod segment; pub mod send; diff --git a/dc/s2n-quic-dc/src/msg/recv.rs b/dc/s2n-quic-dc/src/msg/recv.rs index 5649b88b75..be692ef0bd 100644 --- a/dc/s2n-quic-dc/src/msg/recv.rs +++ b/dc/s2n-quic-dc/src/msg/recv.rs @@ -8,7 +8,6 @@ use s2n_quic_core::{ buffer::Deque as Buffer, ensure, inet::{ExplicitCongestionNotification, SocketAddress}, - path::MaxMtu, ready, }; use std::{io, os::fd::AsRawFd}; @@ -32,8 +31,7 @@ impl fmt::Debug for Message { impl Message { #[inline] - pub fn new(max_mtu: MaxMtu) -> Self { - let max_mtu: u16 = max_mtu.into(); + pub fn new(max_mtu: u16) -> Self { let max_mtu = max_mtu as usize; let buffer_len = cmsg::MAX_GRO_SEGMENTS * max_mtu; // the recv syscall doesn't return more than this diff --git a/dc/s2n-quic-dc/src/msg/segment.rs b/dc/s2n-quic-dc/src/msg/segment.rs new file mode 100644 index 0000000000..60185ac8e6 --- /dev/null +++ b/dc/s2n-quic-dc/src/msg/segment.rs @@ -0,0 +1,98 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +use arrayvec::ArrayVec; +use core::ops::Deref; +use s2n_quic_core::{ensure, inet::ExplicitCongestionNotification}; +use std::io::IoSlice; + +/// The maximum number of segments in sendmsg calls +/// +/// From the Linux kernel header `include/uapi/linux/uio.h`: +/// > #define UIO_FASTIOV 8 +pub const MAX_COUNT: usize = if cfg!(target_os = "linux") { 8 } else { 1 }; + +/// The maximum payload allowed in sendmsg calls +/// +/// From <https://github.com/torvalds/linux/blob/8cd26fd90c1ad7acdcfb9f69ca99d13aa7b24561/net/ipv4/ip_output.c#L987-L995> +/// > Linux enforces a u16::MAX - IP_HEADER_LEN - UDP_HEADER_LEN +pub const MAX_TOTAL: u16 = u16::MAX - 50; + +type Segments<'a> = ArrayVec, MAX_COUNT>; + +pub struct Batch<'a> { + segments: Segments<'a>, + ecn: ExplicitCongestionNotification, +} + +impl<'a> Deref for Batch<'a> { + type Target = [IoSlice<'a>]; + + #[inline] + fn deref(&self) -> &Self::Target { + &self.segments + } +} + +impl<'a> Batch<'a> { + #[inline] + pub fn new(queue: Q) -> Self + where + Q: IntoIterator, + { + // this value is replaced by the first segment + let mut ecn = ExplicitCongestionNotification::Ect0; + let mut total_len = 0u16; + let mut segments = Segments::new(); + + for segment in queue { + let packet_len = segment.1.len(); + debug_assert!( + packet_len <= u16::MAX as usize, + "segments should not exceed the maximum datagram size" + ); + let packet_len = packet_len as u16; + + // make sure the packet fits in u16::MAX + let Some(new_total_len) = total_len.checked_add(packet_len) else { + break; + }; + // make sure we don't exceed the max allowed payload size + ensure!(new_total_len < MAX_TOTAL, break); + + // track if the current segment is undersized from the previous + let mut undersized_segment = false; + + // make sure we're compatible with the previous segment + if let Some(first_segment) = segments.first() { + ensure!(first_segment.len() >= packet_len as usize, break); + // this is the last segment we can push if the segment is undersized + undersized_segment = first_segment.len() > packet_len as usize; + // make sure ecn doesn't change with this transmission 
+ ensure!(ecn == segment.0, break); + } else { + // update the ecn value with the first segment + ecn = segment.0; + } + + // update the total len once we confirm this segment can be written + total_len = new_total_len; + + let iovec = std::io::IoSlice::new(segment.1); + segments.push(iovec); + + // if this segment was undersized, then bail + ensure!(!undersized_segment, break); + + // make sure we have capacity before looping back around + ensure!(!segments.is_full(), break); + } + + Self { segments, ecn } + } + + #[inline] + pub fn ecn(&self) -> ExplicitCongestionNotification { + self.ecn + } +} diff --git a/dc/s2n-quic-dc/src/path.rs b/dc/s2n-quic-dc/src/path.rs index 9a98ad630b..f4fa3014a1 100644 --- a/dc/s2n-quic-dc/src/path.rs +++ b/dc/s2n-quic-dc/src/path.rs @@ -1,47 +1,12 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -use s2n_quic_core::{ - path::{Handle, MaxMtu, Tuple}, - varint::VarInt, -}; +use s2n_quic_core::path::{Handle, Tuple}; pub mod secret; #[cfg(any(test, feature = "testing"))] pub mod testing; -pub static DEFAULT_MAX_DATA: once_cell::sync::Lazy = once_cell::sync::Lazy::new(|| { - std::env::var("DC_QUIC_DEFAULT_MAX_DATA") - .ok() - .and_then(|v| v.parse().ok()) - .unwrap_or(1u32 << 25) - .into() -}); - -pub static DEFAULT_MTU: once_cell::sync::Lazy = once_cell::sync::Lazy::new(|| { - let default_mtu = if cfg!(target_os = "linux") { - 8940 - } else { - 1450 - }; - - std::env::var("DC_QUIC_DEFAULT_MTU") - .ok() - .and_then(|v| v.parse().ok()) - .unwrap_or(default_mtu) - .try_into() - .unwrap() -}); - -pub static DEFAULT_IDLE_TIMEOUT: once_cell::sync::Lazy = once_cell::sync::Lazy::new(|| { - std::env::var("DC_QUIC_DEFAULT_IDLE_TIMEOUT") - .ok() - .and_then(|v| v.parse().ok()) - .unwrap_or(crate::stream::DEFAULT_IDLE_TIMEOUT.as_secs()) - .try_into() - .unwrap() -}); - pub trait Controller { type Handle: Handle; diff --git a/dc/s2n-quic-dc/src/path/secret/key.rs 
b/dc/s2n-quic-dc/src/path/secret/key.rs index f85f330489..ef6af3b8b3 100644 --- a/dc/s2n-quic-dc/src/path/secret/key.rs +++ b/dc/s2n-quic-dc/src/path/secret/key.rs @@ -87,6 +87,12 @@ impl Opener { Ok(()) } + + #[doc(hidden)] + #[cfg(any(test, feature = "testing"))] + pub fn dedup_check(&self) -> decrypt::Result { + self.dedup.check(&self.opener) + } } impl decrypt::Key for Opener { diff --git a/dc/s2n-quic-dc/src/path/secret/map.rs b/dc/s2n-quic-dc/src/path/secret/map.rs index 5a64cc4e95..02fdfb7fd1 100644 --- a/dc/s2n-quic-dc/src/path/secret/map.rs +++ b/dc/s2n-quic-dc/src/path/secret/map.rs @@ -79,7 +79,7 @@ pub(super) struct State { // can use a single map to store both kinds and treat them identically. // // In the future it's likely we'll want to build bidirectional support in which case splitting - // this into two maps (per the discussino in "Managing memory consumption" above) will be + // this into two maps (per the discussion in "Managing memory consumption" above) will be // needed. 
pub(super) peers: flurry::HashMap>, @@ -497,7 +497,7 @@ impl Map { secret, sender, receiver_shared.clone().new_receiver(), - testing::test_application_params(), + dc::testing::TEST_APPLICATION_PARAMS, ); let entry = Arc::new(entry); provider.insert(entry); @@ -524,7 +524,7 @@ impl Map { secret, sender, receiver, - testing::test_application_params(), + dc::testing::TEST_APPLICATION_PARAMS, ); self.insert(Arc::new(entry)); } @@ -549,6 +549,12 @@ impl Map { } } } + + #[doc(hidden)] + #[cfg(any(test, feature = "testing"))] + pub fn handled_control_packets(&self) -> usize { + self.state.handled_control_packets.load(Ordering::Relaxed) + } } impl receiver::Error { @@ -871,20 +877,5 @@ impl dc::Path for HandshakingPath { } } -#[cfg(any(test, feature = "testing"))] -pub mod testing { - use s2n_quic_core::{ - connection::Limits, dc::ApplicationParams, transport::parameters::InitialFlowControlLimits, - }; - - pub fn test_application_params() -> ApplicationParams { - ApplicationParams::new( - s2n_quic_core::path::MaxMtu::default().into(), - &InitialFlowControlLimits::default(), - &Limits::default(), - ) - } -} - #[cfg(test)] mod test; diff --git a/dc/s2n-quic-dc/src/path/secret/map/test.rs b/dc/s2n-quic-dc/src/path/secret/map/test.rs index 08f3abafa3..b0a2e7c1cd 100644 --- a/dc/s2n-quic-dc/src/path/secret/map/test.rs +++ b/dc/s2n-quic-dc/src/path/secret/map/test.rs @@ -24,7 +24,7 @@ fn fake_entry(peer: u16) -> Arc { ), sender::State::new([0; 16]), receiver::State::without_shared(), - super::testing::test_application_params(), + dc::testing::TEST_APPLICATION_PARAMS, )) } @@ -137,7 +137,7 @@ impl Model { secret, sender::State::new(stateless_reset), state.state.receiver_shared.clone().new_receiver(), - super::testing::test_application_params(), + dc::testing::TEST_APPLICATION_PARAMS, ))); self.invariants.insert(Invariant::ContainsIp(ip)); diff --git a/dc/s2n-quic-dc/src/random.rs b/dc/s2n-quic-dc/src/random.rs new file mode 100644 index 0000000000..198facda78 --- /dev/null +++ 
b/dc/s2n-quic-dc/src/random.rs @@ -0,0 +1,77 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use rand::{rngs::adapter::ReseedingRng, RngCore, SeedableRng}; +use rand_chacha::ChaChaCore; + +pub use s2n_quic_core::random::*; + +struct AwsLc; + +impl RngCore for AwsLc { + #[inline] + fn next_u32(&mut self) -> u32 { + let mut v = [0; 4]; + self.fill_bytes(&mut v); + u32::from_ne_bytes(v) + } + + #[inline] + fn next_u64(&mut self) -> u64 { + let mut v = [0; 8]; + self.fill_bytes(&mut v); + u64::from_ne_bytes(v) + } + + #[inline] + fn fill_bytes(&mut self, dest: &mut [u8]) { + self.try_fill_bytes(dest).unwrap() + } + + #[inline] + fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> { + aws_lc_rs::rand::fill(dest).map_err(rand::Error::new) + } +} + +pub struct Random { + public: ReseedingRng, + private: ReseedingRng, +} + +impl Default for Random { + #[inline] + fn default() -> Self { + Self { + public: build_rng(), + private: build_rng(), + } + } +} + +// Constructs a `ReseedingRng` with a ChaCha RNG initially seeded from the OS, +// that will reseed from the OS after RESEED_THRESHOLD is exceeded +fn build_rng() -> ReseedingRng { + // Number of generated bytes after which to reseed the public and private random + // generators. 
+ // + // This value is based on THREAD_RNG_RESEED_THRESHOLD from + // [rand::rngs::thread.rs](https://github.com/rust-random/rand/blob/ef75e56cf5824d33c55622bf84a70ec6e22761ba/src/rngs/thread.rs#L39) + const RESEED_THRESHOLD: u64 = 1024 * 64; + + let prng = ChaChaCore::from_rng(AwsLc) + .unwrap_or_else(|err| panic!("could not initialize random generator: {err}")); + ReseedingRng::new(prng, RESEED_THRESHOLD, AwsLc) +} + +impl Generator for Random { + #[inline] + fn public_random_fill(&mut self, dest: &mut [u8]) { + self.public.fill_bytes(dest); + } + + #[inline] + fn private_random_fill(&mut self, dest: &mut [u8]) { + self.private.fill_bytes(dest); + } +} diff --git a/dc/s2n-quic-dc/src/stream.rs b/dc/s2n-quic-dc/src/stream.rs index 84484bb9e7..ef3a82b922 100644 --- a/dc/s2n-quic-dc/src/stream.rs +++ b/dc/s2n-quic-dc/src/stream.rs @@ -8,12 +8,15 @@ pub const DEFAULT_IDLE_TIMEOUT: Duration = Duration::from_secs(30); /// The maximum time a send stream will wait for ACKs from inflight packets pub const DEFAULT_INFLIGHT_TIMEOUT: Duration = Duration::from_secs(5); +pub mod crypto; +pub mod pacer; pub mod packet_map; pub mod packet_number; pub mod processing; pub mod recv; pub mod send; pub mod server; +pub mod socket; bitflags::bitflags! { #[derive(Clone, Copy, Debug, PartialEq, Eq)] diff --git a/dc/s2n-quic-dc/src/stream/crypto.rs b/dc/s2n-quic-dc/src/stream/crypto.rs new file mode 100644 index 0000000000..fa054b8446 --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/crypto.rs @@ -0,0 +1,93 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +use crate::path::secret::{Map, Opener, Sealer}; +use core::{fmt, sync::atomic::Ordering}; +use crossbeam_epoch::{pin, Atomic}; + +// TODO support key updates +pub struct Crypto { + sealer: Atomic, + opener: Atomic, + map: Map, +} + +impl fmt::Debug for Crypto { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Crypto") + .field("sealer", &self.sealer) + .field("opener", &self.opener) + .finish() + } +} + +impl Crypto { + #[inline] + pub fn new(sealer: Sealer, opener: Opener, map: &Map) -> Self { + let sealer = Atomic::new(sealer); + let opener = Atomic::new(opener); + Self { + sealer, + opener, + map: map.clone(), + } + } + + #[inline(always)] + pub fn tag_len(&self) -> usize { + 16 + } + + #[inline] + pub fn map(&self) -> &Map { + &self.map + } + + #[inline] + pub fn seal_with(&self, seal: impl FnOnce(&Sealer) -> R) -> R { + let pin = pin(); + let sealer = self.seal_pin(&pin); + seal(sealer) + } + + #[inline] + fn seal_pin<'a>(&self, pin: &'a crossbeam_epoch::Guard) -> &'a Sealer { + let sealer = self.sealer.load(Ordering::Acquire, pin); + unsafe { sealer.deref() } + } + + #[inline] + pub fn open_with(&self, open: impl FnOnce(&Opener) -> R) -> R { + let pin = pin(); + let opener = self.open_pin(&pin); + open(opener) + } + + #[inline] + fn open_pin<'a>(&self, pin: &'a crossbeam_epoch::Guard) -> &'a Opener { + let opener = self.opener.load(Ordering::Acquire, pin); + unsafe { opener.deref() } + } +} + +impl Drop for Crypto { + #[inline] + fn drop(&mut self) { + use crossbeam_epoch::Shared; + let pin = pin(); + let sealer = self.sealer.swap(Shared::null(), Ordering::AcqRel, &pin); + let opener = self.opener.swap(Shared::null(), Ordering::AcqRel, &pin); + + // no need to drop either one + if sealer.is_null() && opener.is_null() { + return; + } + + unsafe { + pin.defer_unchecked(move || { + drop(sealer.try_into_owned()); + drop(opener.try_into_owned()); + }) + } + } +} diff --git 
a/dc/s2n-quic-dc/src/stream/pacer.rs b/dc/s2n-quic-dc/src/stream/pacer.rs new file mode 100644 index 0000000000..d6216dea61 --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/pacer.rs @@ -0,0 +1,47 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use core::task::{Context, Poll}; +use s2n_quic_core::time::{Clock, Timestamp}; +use tracing::trace; + +#[derive(Default)] +pub struct Naive { + transmissions_without_yield: u8, + yield_window: Option, +} + +impl Naive { + #[inline] + pub fn poll_pacing(&mut self, cx: &mut Context, clock: &C) -> Poll<()> { + if self.transmissions_without_yield < 5 { + trace!("pass"); + self.transmissions_without_yield += 1; + return Poll::Ready(()); + } + + // reset the counter + self.transmissions_without_yield = 0; + + // record the time that we yielded + let now = clock.get_time(); + let prev_yield_window = core::mem::replace( + &mut self.yield_window, + Some(now + core::time::Duration::from_millis(1)), + ); + + // if the current time falls outside of the previous window then don't actually yield - the + // application isn't sending at that rate + if let Some(yield_window) = prev_yield_window { + if now > yield_window { + trace!("underflow"); + self.transmissions_without_yield += 1; + return Poll::Ready(()); + } + } + + trace!("yield"); + cx.waker().wake_by_ref(); + Poll::Pending + } +} diff --git a/dc/s2n-quic-dc/src/stream/recv.rs b/dc/s2n-quic-dc/src/stream/recv.rs index 6d656cd717..8dfe9ca8ec 100644 --- a/dc/s2n-quic-dc/src/stream/recv.rs +++ b/dc/s2n-quic-dc/src/stream/recv.rs @@ -4,6 +4,7 @@ use super::{TransportFeatures, DEFAULT_IDLE_TIMEOUT}; use crate::{ allocator::Allocator, + clock, crypto::{decrypt, encrypt, UninitSlice}, packet::{control, stream}, }; @@ -118,7 +119,7 @@ impl Receiver { where B: buffer::Duplex, C: buffer::writer::Storage, - Clk: Clock, + Clk: Clock + ?Sized, { // try copying the out_buf into the application chunk, if possible if 
chunk.has_remaining_capacity() && !out_buf.buffer_is_empty() { @@ -223,7 +224,7 @@ impl Receiver { ) -> Result<(), Error> where D: decrypt::Key, - Clk: Clock, + Clk: Clock + ?Sized, B: buffer::Duplex, { probes::on_stream_packet( @@ -262,7 +263,7 @@ impl Receiver { ) -> Result<(), Error> where D: decrypt::Key, - Clk: Clock, + Clk: Clock + ?Sized, B: buffer::Duplex, { use buffer::reader::Storage as _; @@ -320,7 +321,7 @@ impl Receiver { ) -> Result<(), Error> where D: decrypt::Key, - Clk: Clock, + Clk: Clock + ?Sized, { // ensure the packet is authentic before processing it let res = packet.decrypt_in_place(crypto); @@ -353,7 +354,7 @@ impl Receiver { ) -> Result<(), Error> where D: decrypt::Key, - Clk: Clock, + Clk: Clock + ?Sized, { // ensure the packet is authentic before processing it let res = packet.decrypt(crypto, payload_out); @@ -395,7 +396,7 @@ impl Receiver { clock: &Clk, ) -> Result<(), Error> where - Clk: Clock, + Clk: Clock + ?Sized, { tracing::trace!( stream_id = %packet.stream_id(), @@ -496,7 +497,7 @@ impl Receiver { } #[inline] - fn update_idle_timer(&mut self, clock: &Clk) { + fn update_idle_timer(&mut self, clock: &Clk) { let target = clock.get_time() + self.idle_timeout; self.idle_timer.set(target); @@ -552,7 +553,7 @@ impl Receiver { #[inline] pub fn on_timeout(&mut self, clock: &Clk, load_last_activity: Ld) where - Clk: Clock, + Clk: Clock + ?Sized, Ld: FnOnce() -> Timestamp, { let now = clock.get_time(); @@ -584,7 +585,7 @@ impl Receiver { #[inline] fn poll_idle_timer(&mut self, clock: &Clk, load_last_activity: Ld) -> Poll<()> where - Clk: Clock, + Clk: Clock + ?Sized, Ld: FnOnce() -> Timestamp, { let now = clock.get_time(); @@ -622,7 +623,7 @@ impl Receiver { ) where E: encrypt::Key, A: Allocator, - Clk: Clock, + Clk: Clock + ?Sized, { (if self.error.is_none() { Self::on_transmit_ack @@ -634,7 +635,7 @@ impl Receiver { source_control_port, output, // avoid querying the clock for every transmitted packet - 
&s2n_quic_core::time::Cached::new(clock), + &clock::Cached::new(clock), ) } @@ -648,7 +649,7 @@ impl Receiver { ) where E: encrypt::Key, A: Allocator, - Clk: Clock, + Clk: Clock + ?Sized, { ensure!(self.should_transmit()); @@ -755,7 +756,7 @@ impl Receiver { ) where E: encrypt::Key, A: Allocator, - Clk: Clock, + Clk: Clock + ?Sized, { ensure!(self.should_transmit()); diff --git a/dc/s2n-quic-dc/src/stream/recv/packet.rs b/dc/s2n-quic-dc/src/stream/recv/packet.rs index 4476e6a45d..16e4e90b20 100644 --- a/dc/s2n-quic-dc/src/stream/recv/packet.rs +++ b/dc/s2n-quic-dc/src/stream/recv/packet.rs @@ -4,7 +4,7 @@ use super::*; use s2n_quic_core::buffer::{reader, writer, Reader}; -pub struct Packet<'a, 'p, D: decrypt::Key, C: Clock> { +pub struct Packet<'a, 'p, D: decrypt::Key, C: Clock + ?Sized> { pub packet: &'a mut stream::decoder::Packet<'p>, pub payload_cursor: usize, pub is_decrypted_in_place: bool, @@ -14,7 +14,7 @@ pub struct Packet<'a, 'p, D: decrypt::Key, C: Clock> { pub receiver: &'a mut Receiver, } -impl<'a, 'p, D: decrypt::Key, C: Clock> reader::Storage for Packet<'a, 'p, D, C> { +impl<'a, 'p, D: decrypt::Key, C: Clock + ?Sized> reader::Storage for Packet<'a, 'p, D, C> { type Error = Error; #[inline] @@ -88,7 +88,7 @@ impl<'a, 'p, D: decrypt::Key, C: Clock> reader::Storage for Packet<'a, 'p, D, C> } } -impl<'a, 'p, D: decrypt::Key, C: Clock> Reader for Packet<'a, 'p, D, C> { +impl<'a, 'p, D: decrypt::Key, C: Clock + ?Sized> Reader for Packet<'a, 'p, D, C> { #[inline] fn current_offset(&self) -> VarInt { self.packet.stream_offset() + self.payload_cursor diff --git a/dc/s2n-quic-dc/src/stream/send.rs b/dc/s2n-quic-dc/src/stream/send.rs index 5d22de8cf7..dcdd1635e2 100644 --- a/dc/s2n-quic-dc/src/stream/send.rs +++ b/dc/s2n-quic-dc/src/stream/send.rs @@ -8,6 +8,8 @@ pub mod filter; pub mod flow; pub mod path; pub mod probes; +pub mod queue; +pub mod shared; pub mod transmission; pub mod worker; diff --git a/dc/s2n-quic-dc/src/stream/send/application.rs 
b/dc/s2n-quic-dc/src/stream/send/application.rs index 6b2631dfad..a89fecc15a 100644 --- a/dc/s2n-quic-dc/src/stream/send/application.rs +++ b/dc/s2n-quic-dc/src/stream/send/application.rs @@ -77,7 +77,7 @@ impl State { let buffer_len = { let estimated_len = reader.buffered_len() + max_header_len; - (path.mtu as usize).min(estimated_len) + (path.max_datagram_size as usize).min(estimated_len) }; message.push(buffer_len, |buffer| { @@ -108,7 +108,7 @@ impl State { ); // buffer is clamped to u16::MAX so this is safe to cast without loss - let _: u16 = path.mtu; + let _: u16 = path.max_datagram_size; let packet_len = packet_len as u16; let payload_len = reader.consumed_len() as u16; total_payload_len += payload_len as usize; diff --git a/dc/s2n-quic-dc/src/stream/send/flow.rs b/dc/s2n-quic-dc/src/stream/send/flow.rs index 982f95c01b..5e516e2494 100644 --- a/dc/s2n-quic-dc/src/stream/send/flow.rs +++ b/dc/s2n-quic-dc/src/stream/send/flow.rs @@ -5,12 +5,7 @@ use s2n_quic_core::varint::VarInt; pub mod blocking; pub mod non_blocking; - -/// The maximum payload allowed in sendmsg calls -/// -/// > https://github.com/torvalds/linux/blob/8cd26fd90c1ad7acdcfb9f69ca99d13aa7b24561/net/ipv4/ip_output.c#L987-L995 -/// > Linux enforces a u16::MAX - IP_HEADER_LEN - UDP_HEADER_LEN -pub const MAX_PAYLOAD: u16 = u16::MAX - 50; +pub use crate::msg::segment::MAX_TOTAL; /// Flow credits acquired by an application request #[derive(Debug)] @@ -40,7 +35,7 @@ impl Request { /// Clamps the request with the given number of credits #[inline] pub fn clamp(&mut self, credits: u64) { - let len = self.len.min(credits.min(MAX_PAYLOAD as _) as usize); + let len = self.len.min(credits.min(MAX_TOTAL as _) as usize); // if we didn't acquire the entire len, then clear the `is_fin` flag if self.len != len { diff --git a/dc/s2n-quic-dc/src/stream/send/flow/blocking.rs b/dc/s2n-quic-dc/src/stream/send/flow/blocking.rs index 42afaa6e01..c6bddf205e 100644 --- a/dc/s2n-quic-dc/src/stream/send/flow/blocking.rs 
+++ b/dc/s2n-quic-dc/src/stream/send/flow/blocking.rs @@ -134,7 +134,7 @@ mod tests { let expected_len = VarInt::from_u16(u16::MAX); let state = State::new(initial_offset); let path_info = path::Info { - mtu: 1500, + max_datagram_size: 1500, send_quantum: 10, ecn: Default::default(), next_expected_control_packet: Default::default(), diff --git a/dc/s2n-quic-dc/src/stream/send/flow/non_blocking.rs b/dc/s2n-quic-dc/src/stream/send/flow/non_blocking.rs index 21ae41bdd2..32d048c6d0 100644 --- a/dc/s2n-quic-dc/src/stream/send/flow/non_blocking.rs +++ b/dc/s2n-quic-dc/src/stream/send/flow/non_blocking.rs @@ -211,7 +211,7 @@ mod tests { let expected_len = VarInt::from_u16(u16::MAX); let state = Arc::new(State::new(initial_offset)); let path_info = path::Info { - mtu: 1500, + max_datagram_size: 1500, send_quantum: 10, ecn: Default::default(), next_expected_control_packet: Default::default(), diff --git a/dc/s2n-quic-dc/src/stream/send/path.rs b/dc/s2n-quic-dc/src/stream/send/path.rs index 0b53b71a60..9ffd10e944 100644 --- a/dc/s2n-quic-dc/src/stream/send/path.rs +++ b/dc/s2n-quic-dc/src/stream/send/path.rs @@ -23,7 +23,11 @@ impl State { #[inline] pub fn new(info: Info) -> Self { Self { - info: AtomicU64::new(Self::encode_info(info.ecn, info.send_quantum, info.mtu)), + info: AtomicU64::new(Self::encode_info( + info.ecn, + info.send_quantum, + info.max_datagram_size, + )), next_expected_control_packet: AtomicU64::new( info.next_expected_control_packet.as_u64(), ), @@ -35,7 +39,7 @@ impl State { pub fn load(&self) -> Info { // use relaxed since it's ok to be slightly out of sync with the current MTU/send_quantum let data = self.info.load(Ordering::Relaxed); - let (ecn, send_quantum, mtu) = Self::decode_info(data); + let (ecn, send_quantum, max_datagram_size) = Self::decode_info(data); let next_expected_control_packet = self.next_expected_control_packet.load(Ordering::Relaxed); @@ -43,7 +47,7 @@ impl State { VarInt::new(next_expected_control_packet).unwrap_or(VarInt::MAX); 
Info { - mtu, + max_datagram_size, send_quantum, ecn, next_expected_control_packet, @@ -51,14 +55,19 @@ impl State { } #[inline] - pub fn update_info(&self, ecn: ExplicitCongestionNotification, send_quantum: u8, mtu: u16) { - let info = Self::encode_info(ecn, send_quantum, mtu); + pub fn update_info( + &self, + ecn: ExplicitCongestionNotification, + send_quantum: u8, + max_datagram_size: u16, + ) { + let info = Self::encode_info(ecn, send_quantum, max_datagram_size); self.info.store(info, Ordering::Relaxed); } #[inline] fn decode_info(mut data: u64) -> (ExplicitCongestionNotification, u8, u16) { - let mtu = data as u16; + let max_datagram_size = data as u16; data >>= 16; let send_quantum = data as u8; @@ -72,11 +81,15 @@ impl State { debug_assert_eq!(data, 0, "unexpected extra data"); - (ecn, send_quantum, mtu) + (ecn, send_quantum, max_datagram_size) } #[inline] - fn encode_info(ecn: ExplicitCongestionNotification, send_quantum: u8, mtu: u16) -> u64 { + fn encode_info( + ecn: ExplicitCongestionNotification, + send_quantum: u8, + max_datagram_size: u16, + ) -> u64 { let mut data = 0u64; data |= ecn as u8 as u64; @@ -85,7 +98,7 @@ impl State { data |= send_quantum as u64; data <<= 16; - data |= mtu as u64; + data |= max_datagram_size as u64; data } @@ -99,7 +112,7 @@ impl State { #[derive(Clone, Copy, Debug)] pub struct Info { - pub mtu: u16, + pub max_datagram_size: u16, pub send_quantum: u8, pub ecn: ExplicitCongestionNotification, pub next_expected_control_packet: VarInt, @@ -110,7 +123,7 @@ impl Info { #[inline] pub fn max_flow_credits(&self, max_header_len: usize, max_segments: usize) -> u64 { // trim off the headers since those don't count for flow control - let max_payload_size_per_segment = self.mtu as usize - max_header_len; + let max_payload_size_per_segment = self.max_datagram_size as usize - max_header_len; // clamp the number of segments we can transmit in a single burst let max_segments = max_segments.min(self.send_quantum as usize); @@ -132,10 +145,11 
@@ mod tests { check!() .with_type() .cloned() - .for_each(|(ecn, send_quantum, mtu)| { - let actual = State::decode_info(State::encode_info(ecn, send_quantum, mtu)); + .for_each(|(ecn, send_quantum, max_datagram_size)| { + let actual = + State::decode_info(State::encode_info(ecn, send_quantum, max_datagram_size)); - assert_eq!((ecn, send_quantum, mtu), actual); + assert_eq!((ecn, send_quantum, max_datagram_size), actual); }) } } diff --git a/dc/s2n-quic-dc/src/stream/send/queue.rs b/dc/s2n-quic-dc/src/stream/send/queue.rs new file mode 100644 index 0000000000..d15bbd811d --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/send/queue.rs @@ -0,0 +1,326 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + msg::{addr, segment}, + stream::{ + send::{ + application::{self, transmission}, + buffer, worker, + }, + socket::Socket, + }, +}; +use bytes::buf::UninitSlice; +use core::task::{Context, Poll}; +use s2n_quic_core::{assume, buffer::reader, ensure, inet::ExplicitCongestionNotification, ready}; +use s2n_quic_platform::features::Gso; +use std::{collections::VecDeque, io}; + +/// An enqueued segment waiting to be transmitted on the socket +#[derive(Debug)] +pub struct Segment { + ecn: ExplicitCongestionNotification, + buffer: buffer::Segment, + offset: u16, +} + +impl Segment { + #[inline] + fn as_slice(&self) -> &[u8] { + &self.buffer[self.offset as usize..] 
+ } +} + +pub struct Message<'a> { + batch: &'a mut Option>, + queue: &'a mut Queue, + max_segments: usize, + segment_alloc: &'a buffer::Allocator, +} + +impl<'a> application::Message for Message<'a> { + #[inline] + fn max_segments(&self) -> usize { + self.max_segments + } + + #[inline] + fn push transmission::Event<()>>( + &mut self, + buffer_len: usize, + p: P, + ) { + let mut buffer = self.segment_alloc.alloc(buffer_len); + + let transmission = { + let buffer = buffer.make_mut(); + + debug_assert!(buffer.capacity() >= buffer_len); + + let slice = UninitSlice::uninit(buffer.spare_capacity_mut()); + + let transmission = p(slice); + + unsafe { + let packet_len = transmission.info.packet_len; + assume!(buffer.capacity() >= packet_len as usize); + buffer.set_len(packet_len as usize); + } + + transmission + }; + + let transmission::Event { + packet_number, + info, + has_more_app_data, + } = transmission; + + let ecn = info.ecn; + + if let Some(batch) = self.batch.as_mut() { + let info = info.map(|_| buffer.clone()); + + batch.push(transmission::Event { + packet_number, + info, + has_more_app_data, + }); + } + + self.queue.segments.push_back(Segment { + ecn, + buffer, + offset: 0, + }); + } +} + +#[derive(Debug, Default)] +pub struct Queue { + /// Holds any segments that haven't been flushed to the socket + segments: VecDeque, + /// How many bytes we've accepted from the caller of `poll_write`, but actually returned + /// `Poll::Pending` for. This many bytes will be skipped the next time `poll_write` is called. + /// + /// This functionality ensures that we don't return to the application until we've flushed all + /// outstanding packets to the underlying socket. Experience has shown applications rely on + /// TCP's behavior, which never really requires `flush` or `shutdown` to progress the stream. 
+ accepted_len: usize, +} + +impl Queue { + #[inline] + pub fn is_empty(&self) -> bool { + self.segments.is_empty() + } + + #[inline] + pub fn accepted_len(&self) -> usize { + self.accepted_len + } + + #[inline] + pub fn push_buffer( + &mut self, + buf: &mut B, + batch: &mut Option>, + max_segments: usize, + segment_alloc: &buffer::Allocator, + push: F, + ) -> Result<(), E> + where + B: reader::Storage, + F: FnOnce(&mut Message, &mut reader::storage::Tracked) -> Result<(), E>, + { + let mut message = Message { + batch, + queue: self, + max_segments, + segment_alloc, + }; + + let mut buf = buf.track_read(); + + push(&mut message, &mut buf)?; + + // record how many bytes we encrypted/buffered so we only return Ready once everything has + // been flushed + self.accepted_len += buf.consumed_len(); + + Ok(()) + } + + #[inline] + pub fn poll_flush( + &mut self, + cx: &mut Context, + limit: usize, + socket: &S, + addr: &addr::Addr, + segment_alloc: &buffer::Allocator, + gso: &Gso, + ) -> Poll> + where + S: ?Sized + Socket, + { + ready!(self.poll_flush_segments(cx, socket, addr, segment_alloc, gso))?; + + // Consume accepted credits + let accepted = limit.min(self.accepted_len); + self.accepted_len -= accepted; + Poll::Ready(Ok(accepted)) + } + + #[inline] + fn poll_flush_segments( + &mut self, + cx: &mut Context, + socket: &S, + addr: &addr::Addr, + segment_alloc: &buffer::Allocator, + gso: &Gso, + ) -> Poll> + where + S: ?Sized + Socket, + { + ensure!(!self.segments.is_empty(), Poll::Ready(Ok(()))); + + let default_addr = addr::Addr::new(Default::default()); + + let addr = if socket.features().is_connected() { + // no need to load the socket addr if the stream is already connected + &default_addr + } else { + addr + }; + + if socket.features().is_stream() { + self.poll_flush_segments_stream(cx, socket, addr, segment_alloc) + } else { + self.poll_flush_segments_datagram(cx, socket, addr, segment_alloc, gso) + } + } + + #[inline] + fn poll_flush_segments_stream( + &mut 
self, + cx: &mut Context, + socket: &S, + addr: &addr::Addr, + segment_alloc: &buffer::Allocator, + ) -> Poll> + where + S: ?Sized + Socket, + { + while !self.segments.is_empty() { + let segments = segment::Batch::new(self.segments.iter().map(|v| (v.ecn, v.as_slice()))); + + let ecn = segments.ecn(); + let res = ready!(socket.poll_send(cx, addr, ecn, &segments)); + + drop(segments); + + match res { + Ok(written_len) => { + self.consume_segments(written_len, segment_alloc); + + // keep trying to drain the buffer + continue; + } + Err(err) => { + // the socket encountered an error so clear everything out since we're shutting + // down + self.segments.clear(); + self.accepted_len = 0; + return Err(err).into(); + } + } + } + + Ok(()).into() + } + + #[inline] + fn consume_segments(&mut self, consumed: usize, segment_alloc: &buffer::Allocator) { + ensure!(consumed > 0); + + let mut remaining = consumed; + + while let Some(mut segment) = self.segments.pop_front() { + if let Some(r) = remaining.checked_sub(segment.as_slice().len()) { + remaining = r; + + // try to reuse the buffer for future allocations + segment_alloc.free(segment.buffer); + + // if we don't have any remaining bytes to pop then we're done + ensure!(remaining > 0, break); + + continue; + } + + segment.offset += core::mem::take(&mut remaining) as u16; + + debug_assert!(!segment.as_slice().is_empty()); + + self.segments.push_front(segment); + break; + } + + debug_assert_eq!( + remaining, 0, + "consumed ({consumed}) with too many bytes remaining ({remaining})" + ); + } + + #[inline] + fn poll_flush_segments_datagram( + &mut self, + cx: &mut Context, + socket: &S, + addr: &addr::Addr, + segment_alloc: &buffer::Allocator, + gso: &Gso, + ) -> Poll> + where + S: ?Sized + Socket, + { + let mut max_segments = gso.max_segments(); + + while !self.segments.is_empty() { + // construct all of the segments we're going to send in this batch + let segments = segment::Batch::new( + self.segments + .iter() + .map(|v| (v.ecn, 
v.as_slice())) + .take(max_segments), + ); + + let ecn = segments.ecn(); + let res = match ready!(socket.poll_send(cx, addr, ecn, &segments)) { + Ok(_) => Ok(()), + Err(error) => { + if gso.handle_socket_error(&error).is_some() { + // update the max_segments value if it was changed due to the error + max_segments = 1; + } + Err(error) + } + }; + + // consume the segments that we transmitted + let segment_count = segments.len(); + drop(segments); + for segment in self.segments.drain(..segment_count) { + // try to reuse the buffer for future allocations + segment_alloc.free(segment.buffer); + } + + res?; + } + + Ok(()).into() + } +} diff --git a/dc/s2n-quic-dc/src/stream/send/shared.rs b/dc/s2n-quic-dc/src/stream/send/shared.rs new file mode 100644 index 0000000000..879d3c62bf --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/send/shared.rs @@ -0,0 +1,121 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + stream::{ + packet_number, + send::{ + application::transmission, buffer, error::Error, flow, path, queue::Queue, + worker::Transmission, + }, + }, + task::waker::worker::Waker as WorkerWaker, +}; +use core::{ + fmt, + sync::atomic::{AtomicU64, Ordering}, +}; +use crossbeam_queue::SegQueue; +use s2n_quic_core::recovery::bandwidth::Bandwidth; +use tracing::trace; + +#[derive(Debug)] +pub struct Message { + /// The event being submitted to the worker + pub event: Event, +} + +#[derive(Debug)] +pub enum Event { + Shutdown { queue: Queue, is_panicking: bool }, +} + +pub struct State { + pub flow: flow::non_blocking::State, + pub packet_number: packet_number::Counter, + pub path: path::State, + pub worker_waker: WorkerWaker, + bandwidth: AtomicU64, + /// A channel sender for pushing transmission information to the worker task + /// + /// We use an unbounded sender since we already rely on flow control to apply backpressure + worker_queue: SegQueue, + pub application_transmission_queue: 
transmission::Queue, + pub segment_alloc: buffer::Allocator, +} + +impl fmt::Debug for State { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("send::shared::State") + .field("flow", &self.flow) + .field("packet_number", &self.packet_number) + .field("path", &self.path) + .finish() + } +} + +impl State { + #[inline] + pub fn new( + flow: flow::non_blocking::State, + path: path::Info, + bandwidth: Option, + ) -> Self { + let path = path::State::new(path); + let bandwidth = bandwidth.map(|v| v.serialize()).unwrap_or(u64::MAX).into(); + Self { + flow, + packet_number: Default::default(), + path, + bandwidth, + // this will get set once the waker spawns + worker_waker: Default::default(), + worker_queue: Default::default(), + application_transmission_queue: Default::default(), + segment_alloc: Default::default(), + } + } + + #[inline] + pub fn bandwidth(&self) -> Bandwidth { + Bandwidth::deserialize(self.bandwidth.load(Ordering::Relaxed)) + } + + #[inline] + pub fn set_bandwidth(&self, value: Bandwidth) { + self.bandwidth.store(value.serialize(), Ordering::Relaxed); + } + + #[inline] + pub fn pop_worker_message(&self) -> Option { + self.worker_queue.pop() + } + + #[inline] + pub fn push_to_worker(&self, transmissions: Vec) -> Result<(), Error> { + trace!(event = "transmission", len = transmissions.len()); + self.application_transmission_queue + .push_batch(transmissions); + + self.worker_waker.wake(); + + Ok(()) + } + + #[inline] + pub fn shutdown(&self, queue: Queue, is_panicking: bool) { + trace!( + event = "shutdown", + queue = queue.accepted_len(), + is_panicking = is_panicking + ); + let message = Message { + event: Event::Shutdown { + queue, + is_panicking, + }, + }; + self.worker_queue.push(message); + self.worker_waker.wake(); + } +} diff --git a/dc/s2n-quic-dc/src/stream/send/worker.rs b/dc/s2n-quic-dc/src/stream/send/worker.rs index 8900970a2c..7301b6a1f0 100644 --- a/dc/s2n-quic-dc/src/stream/send/worker.rs +++ 
b/dc/s2n-quic-dc/src/stream/send/worker.rs @@ -102,7 +102,7 @@ pub struct Worker { pub max_data: VarInt, pub local_max_data_window: VarInt, pub peer_activity: Option, - pub mtu: u16, + pub max_datagram_size: u16, pub max_sent_segment_size: u16, } @@ -114,7 +114,7 @@ pub struct PeerActivity { impl Worker { #[inline] pub fn new(stream_id: stream::Id, params: &ApplicationParams) -> Self { - let mtu = params.max_datagram_size; + let max_datagram_size = params.max_datagram_size; let initial_max_data = params.remote_max_data; let local_max_data = params.local_send_max_data; @@ -122,7 +122,7 @@ impl Worker { let mut unacked_ranges = IntervalSet::new(); unacked_ranges.insert(VarInt::ZERO..=VarInt::MAX).unwrap(); - let cca = congestion::Controller::new(mtu); + let cca = congestion::Controller::new(max_datagram_size); let max_sent_offset = VarInt::ZERO; Self { @@ -154,7 +154,7 @@ impl Worker { max_data: initial_max_data, local_max_data_window: local_max_data, peer_activity: None, - mtu, + max_datagram_size, max_sent_segment_size: 0, } } @@ -195,7 +195,8 @@ impl Worker { pub fn send_quantum_packets(&self) -> u8 { // TODO use div_ceil when we're on 1.73+ MSRV // https://doc.rust-lang.org/std/primitive.u64.html#method.div_ceil - let send_quantum = (self.cca.send_quantum() as u64 + self.mtu as u64 - 1) / self.mtu as u64; + let send_quantum = (self.cca.send_quantum() as u64 + self.max_datagram_size as u64 - 1) + / self.max_datagram_size as u64; send_quantum.try_into().unwrap_or(u8::MAX) } diff --git a/dc/s2n-quic-dc/src/stream/socket.rs b/dc/s2n-quic-dc/src/stream/socket.rs new file mode 100644 index 0000000000..d2ab7406c5 --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/socket.rs @@ -0,0 +1,35 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +use super::TransportFeatures; + +pub mod application; +mod fd; +mod handle; +mod tokio; +mod tracing; + +pub use self::tracing::Tracing; +pub use crate::socket::*; +pub use application::Application; +pub use handle::{Ext, Flags, Socket}; + +pub type ArcApplication = std::sync::Arc; + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[non_exhaustive] +pub enum Protocol { + Tcp, + Udp, + Other(&'static str), +} + +impl Protocol { + s2n_quic_core::state::is!(is_tcp, Tcp); + s2n_quic_core::state::is!(is_udp, Udp); + + #[inline] + pub fn is_other(&self) -> bool { + matches!(self, Self::Other(_)) + } +} diff --git a/dc/s2n-quic-dc/src/stream/socket/application.rs b/dc/s2n-quic-dc/src/stream/socket/application.rs new file mode 100644 index 0000000000..b588aa56a4 --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/socket/application.rs @@ -0,0 +1,92 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use super::{Protocol, Socket, TransportFeatures}; +use std::sync::Arc; + +pub mod builder; + +pub use builder::Builder; + +pub trait Application: 'static + Send + Sync { + fn protocol(&self) -> Protocol; + + fn features(&self) -> TransportFeatures; + + fn write_application(&self) -> &dyn Socket; + + fn read_application(&self) -> &dyn Socket; +} + +impl Application for Arc { + #[inline] + fn protocol(&self) -> Protocol { + (**self).protocol() + } + + #[inline] + fn features(&self) -> TransportFeatures { + (**self).features() + } + + #[inline] + fn write_application(&self) -> &dyn Socket { + (**self).write_application() + } + + #[inline] + fn read_application(&self) -> &dyn Socket { + (**self).read_application() + } +} + +pub struct Single(S); + +impl Application for Single { + #[inline] + fn protocol(&self) -> Protocol { + self.0.protocol() + } + + #[inline] + fn features(&self) -> TransportFeatures { + self.0.features() + } + + #[inline] + fn write_application(&self) -> &dyn 
Socket { + &self.0 + } + + #[inline] + fn read_application(&self) -> &dyn Socket { + &self.0 + } +} + +pub struct Pair { + read: S, + write: S, +} + +impl Application for Pair { + #[inline] + fn protocol(&self) -> Protocol { + self.read.protocol() + } + + #[inline] + fn features(&self) -> TransportFeatures { + self.read.features() + } + + #[inline] + fn write_application(&self) -> &dyn Socket { + &self.write + } + + #[inline] + fn read_application(&self) -> &dyn Socket { + &self.read + } +} diff --git a/dc/s2n-quic-dc/src/stream/socket/application/builder.rs b/dc/s2n-quic-dc/src/stream/socket/application/builder.rs new file mode 100644 index 0000000000..e2586d6144 --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/socket/application/builder.rs @@ -0,0 +1,72 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use super::super::{ArcApplication, Tracing}; +use std::{io, sync::Arc}; +use tokio::io::unix::AsyncFd; + +pub trait Builder: 'static + Send + Sync { + fn build(self: Box) -> io::Result; +} + +impl Builder for std::net::UdpSocket { + #[inline] + fn build(self: Box) -> io::Result { + let v = AsyncFd::new(*self)?; + let v = Tracing(v); + let v = super::Single(v); + let v = Arc::new(v); + Ok(v) + } +} + +impl Builder for std::net::TcpStream { + #[inline] + fn build(self: Box) -> io::Result { + let v = tokio::net::TcpStream::from_std(*self)?; + let v = Tracing(v); + let v = super::Single(v); + let v = Arc::new(v); + Ok(v) + } +} + +impl Builder for tokio::net::TcpStream { + #[inline] + fn build(self: Box) -> io::Result { + let v = Tracing(*self); + let v = super::Single(v); + let v = Arc::new(v); + Ok(v) + } +} + +impl Builder for Arc { + #[inline] + fn build(self: Box) -> io::Result { + // TODO avoid the Box> indirection here? 
+ let v = AsyncFd::new(*self)?; + let v = Tracing(v); + let v = super::Single(v); + let v = Arc::new(v); + Ok(v) + } +} + +pub struct UdpPair { + pub reader: Arc, + pub writer: Arc, +} + +impl Builder for UdpPair { + #[inline] + fn build(self: Box) -> io::Result { + let read = AsyncFd::new(self.reader)?; + let read = Tracing(read); + let write = AsyncFd::new(self.writer)?; + let write = Tracing(write); + let v = super::Pair { read, write }; + let v = Arc::new(v); + Ok(v) + } +} diff --git a/dc/s2n-quic-dc/src/stream/socket/fd.rs b/dc/s2n-quic-dc/src/stream/socket/fd.rs new file mode 100644 index 0000000000..885a5cfffa --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/socket/fd.rs @@ -0,0 +1,37 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use s2n_quic_core::ensure; +use std::{io, os::fd::AsRawFd}; + +pub mod tcp; +pub mod udp; + +pub type Flags = libc::c_int; + +#[inline] +pub fn peek(fd: &T) -> io::Result +where + T: AsRawFd, +{ + libc_call(|| unsafe { + let flags = libc::MSG_PEEK | libc::MSG_TRUNC; + + // macos doesn't seem to support MSG_TRUNC so we need to give it at least 1 byte + if cfg!(target_os = "macos") { + let mut buf = [0u8]; + libc::recv(fd.as_raw_fd(), buf.as_mut_ptr() as *mut _, 1, flags) as _ + } else { + libc::recv(fd.as_raw_fd(), core::ptr::null_mut(), 0, flags) as _ + } + }) +} + +#[inline] +pub fn libc_call(call: impl FnOnce() -> isize) -> io::Result { + let res = call(); + + ensure!(res >= 0, Err(io::Error::last_os_error())); + + Ok(res as _) +} diff --git a/dc/s2n-quic-dc/src/stream/socket/fd/tcp.rs b/dc/s2n-quic-dc/src/stream/socket/fd/tcp.rs new file mode 100644 index 0000000000..6ecc11feff --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/socket/fd/tcp.rs @@ -0,0 +1,71 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+// SPDX-License-Identifier: Apache-2.0
+
+use super::{libc_call, Flags};
+use std::{
+    io::{self, IoSlice, IoSliceMut},
+    os::fd::AsRawFd,
+};
+
+pub use super::peek;
+
+/// Receives into `segments` from the provided socket, returning the number of bytes read
+#[inline]
+pub fn recv<T>(fd: &T, segments: &mut [IoSliceMut], flags: Flags) -> io::Result<usize>
+where
+    T: AsRawFd,
+{
+    let fd = fd.as_raw_fd();
+
+    // if we only have a single segment then use recv, which should be slightly cheaper than recvmsg
+    libc_call(|| if segments.len() == 1 {
+        let segment: &mut [u8] = &mut segments[0];
+        let buf = segment.as_mut_ptr() as *mut _; // the kernel writes through this pointer
+        let len = segment.len() as _;
+        unsafe { libc::recv(fd, buf, len, flags) }
+    } else {
+        let mut msg = unsafe { core::mem::zeroed::<libc::msghdr>() };
+
+        msg.msg_iov = segments.as_mut_ptr() as *mut _;
+        msg.msg_iovlen = segments.len() as _;
+
+        unsafe { libc::recvmsg(fd, &mut msg, flags) }
+    } as _)
+}
+
+/// Sends `segments` on the provided socket, returning the number of bytes written
+#[inline]
+pub fn send<T>(fd: &T, segments: &[IoSlice]) -> io::Result<usize>
+where
+    T: AsRawFd,
+{
+    debug_assert!(!segments.is_empty());
+
+    let fd = fd.as_raw_fd();
+    let flags = Flags::default();
+
+    // if we only have a single segment then use send, which should be slightly cheaper than sendmsg
+    libc_call(|| if segments.len() == 1 {
+        let segment: &[u8] = &segments[0];
+        let buf = segment.as_ptr() as *const _;
+        let len = segment.len() as _;
+        unsafe { libc::send(fd, buf, len, flags) }
+    } else {
+        let mut msg = unsafe { core::mem::zeroed::<libc::msghdr>() };
+
+        // msghdr wants a `*mut iovec` but it doesn't actually end up mutating it
+        msg.msg_iov = segments.as_ptr() as *mut IoSlice as *mut _;
+        msg.msg_iovlen = segments.len() as _;
+
+        unsafe { libc::sendmsg(fd, &msg, flags) }
+    } as _)
+}
+
+#[inline]
+pub fn shutdown<T>(fd: &T) -> io::Result<()>
+where
+    T: AsRawFd,
+{
+    libc_call(|| unsafe { libc::shutdown(fd.as_raw_fd(), libc::SHUT_WR) as _ })?;
+    Ok(())
+}
diff --git a/dc/s2n-quic-dc/src/stream/socket/fd/udp.rs b/dc/s2n-quic-dc/src/stream/socket/fd/udp.rs new
file mode 100644 index 0000000000..a47704eaa4 --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/socket/fd/udp.rs @@ -0,0 +1,145 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use super::{libc_call, Flags}; +use crate::msg::{ + addr::Addr, + cmsg::{self, Encoder}, +}; +use s2n_quic_core::inet::ExplicitCongestionNotification; +use std::{ + io::{self, IoSlice, IoSliceMut}, + os::fd::AsRawFd, +}; + +pub use super::peek; + +#[inline] +pub fn recv( + fd: &T, + addr: &mut Addr, + cmsg: &mut cmsg::Receiver, + buffer: &mut [IoSliceMut], + flags: Flags, +) -> io::Result +where + T: AsRawFd, +{ + recv_msghdr(addr, cmsg, buffer, |msghdr| { + libc_call(|| unsafe { libc::recvmsg(fd.as_raw_fd(), msghdr, flags) as _ }) + }) +} + +/// Constructs a msghdr for receiving +#[inline] +fn recv_msghdr( + addr: &mut Addr, + cmsg: &mut cmsg::Receiver, + segments: &mut [IoSliceMut], + exec: impl FnOnce(&mut libc::msghdr) -> io::Result, +) -> io::Result { + debug_assert!(!segments.is_empty()); + + let mut msg = unsafe { core::mem::zeroed::() }; + + addr.recv_with_msg(&mut msg); + + // setup cmsg info + let mut cmsg_storage = cmsg::Storage::<{ cmsg::DECODER_LEN }>::default(); + + msg.msg_control = cmsg_storage.as_mut_ptr() as *mut _; + msg.msg_controllen = cmsg_storage.len() as _; + + msg.msg_iov = segments.as_ptr() as *mut IoSliceMut as *mut _; + msg.msg_iovlen = segments.len() as _; + + let len = exec(&mut msg)?; + + cmsg.with_msg(&msg); + + Ok(len) +} + +#[inline] +pub fn send( + fd: &T, + addr: &Addr, + ecn: ExplicitCongestionNotification, + buffer: &[IoSlice], +) -> io::Result +where + T: AsRawFd, +{ + send_msghdr(addr, ecn, buffer, |msghdr| { + libc_call(|| unsafe { libc::sendmsg(fd.as_raw_fd(), msghdr, 0) as _ }) + }) +} + +/// Constructs a msghdr for sending +#[inline] +fn send_msghdr( + addr: &Addr, + ecn: ExplicitCongestionNotification, + segments: &[IoSlice], + exec: impl FnOnce(&libc::msghdr) -> io::Result, +) -> 
io::Result { + debug_assert!(!segments.is_empty()); + + let mut msg = unsafe { core::mem::zeroed::() }; + + addr.send_with_msg(&mut msg); + + // make sure we constructed a valid iovec + #[cfg(debug_assertions)] + check_send_iovec(segments); + + // setup cmsg info + let mut cmsg_storage = cmsg::Storage::<{ cmsg::ENCODER_LEN }>::default(); + let mut cmsg = cmsg_storage.encoder(); + if ecn != ExplicitCongestionNotification::NotEct { + // TODO enable this once we consolidate s2n-quic-core crates + // let _ = cmsg.encode_ecn(ecn, &addr); + } + + if segments.len() > 1 { + let _ = cmsg.encode_gso(segments[0].len() as _); + } + + if !cmsg.is_empty() { + msg.msg_control = cmsg.as_mut_ptr() as *mut _; + msg.msg_controllen = cmsg.len() as _; + } + + msg.msg_iov = segments.as_ptr() as *mut IoSlice as *mut _; + msg.msg_iovlen = segments.len() as _; + + exec(&mut msg) +} + +#[cfg(debug_assertions)] +fn check_send_iovec(segments: &[T]) +where + T: core::ops::Deref, +{ + let mut total_len = 0; + let mut segment_size = None; + let mut can_accept_more = true; + + for segment in segments { + assert!(can_accept_more); + + if let Some(expected_len) = segment_size { + assert!(expected_len >= segment.len()); + // we can only have more segments if the current one matches the previous + can_accept_more = expected_len == segment.len(); + } else { + segment_size = Some(segment.len()); + } + total_len += segment.len(); + } + + assert!( + total_len <= u16::MAX as usize, + "payloads should not exceed 2^16" + ); +} diff --git a/dc/s2n-quic-dc/src/stream/socket/handle.rs b/dc/s2n-quic-dc/src/stream/socket/handle.rs new file mode 100644 index 0000000000..5eb4a69693 --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/socket/handle.rs @@ -0,0 +1,181 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +use super::{Protocol, TransportFeatures}; +use crate::msg::{self, addr::Addr, cmsg}; +use core::task::{Context, Poll}; +use s2n_quic_core::inet::ExplicitCongestionNotification; +use std::{ + io::{self, IoSlice, IoSliceMut}, + net::SocketAddr, + sync::Arc, +}; + +pub type Flags = libc::c_int; + +pub trait Socket: 'static + Send + Sync { + /// Returns the local address for the socket + fn local_addr(&self) -> io::Result; + + /// Returns the local port for the socket + #[inline] + fn local_port(&self) -> io::Result { + Ok(self.local_addr()?.port()) + } + + fn protocol(&self) -> Protocol; + + /// Returns the [`TransportFeatures`] that the socket supports + fn features(&self) -> TransportFeatures; + + /// Returns the amount of buffered data on the socket + fn poll_peek_len(&self, cx: &mut Context) -> Poll>; + + #[inline] + fn poll_recv_buffer( + &self, + cx: &mut Context, + msg: &mut msg::recv::Message, + ) -> Poll> { + #[cfg(debug_assertions)] + if !self.features().is_stream() { + assert!( + msg.is_empty(), + "receive buffer should be empty for datagram protocols" + ); + } + + msg.poll_recv_with(|addr, cmsg, buffer| self.poll_recv(cx, addr, cmsg, buffer)) + } + + /// Receives data on the socket + fn poll_recv( + &self, + cx: &mut Context, + addr: &mut Addr, + cmsg: &mut cmsg::Receiver, + buffer: &mut [IoSliceMut], + ) -> Poll>; + + #[inline] + fn try_send_buffer(&self, msg: &mut msg::send::Message) -> io::Result { + msg.send_with(|addr, ecn, iov| self.try_send(addr, ecn, iov)) + } + + /// Tries to send data on the socket, returning `Err(WouldBlock)` if none could be sent. 
+ fn try_send( + &self, + addr: &Addr, + ecn: ExplicitCongestionNotification, + buffer: &[IoSlice], + ) -> io::Result; + + #[inline] + fn poll_send_buffer( + &self, + cx: &mut Context, + msg: &mut msg::send::Message, + ) -> Poll> { + msg.poll_send_with(|addr, ecn, iov| self.poll_send(cx, addr, ecn, iov)) + } + + /// Sends data on the socket + fn poll_send( + &self, + cx: &mut Context, + addr: &Addr, + ecn: ExplicitCongestionNotification, + buffer: &[IoSlice], + ) -> Poll>; + + /// Shuts down the sender half of the socket, if a concept exists + fn send_finish(&self) -> io::Result<()>; +} + +pub trait Ext: Socket { + #[inline] + fn recv_buffer<'a>(&'a self, msg: &'a mut msg::recv::Message) -> ExtRecvBuffer<'a, Self> { + ExtRecvBuffer { socket: self, msg } + } +} + +pub struct ExtRecvBuffer<'a, T: Socket + ?Sized> { + socket: &'a T, + msg: &'a mut msg::recv::Message, +} + +impl<'a, T: Socket> core::future::Future for ExtRecvBuffer<'a, T> { + type Output = io::Result; + + fn poll(mut self: core::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.socket.poll_recv_buffer(cx, self.msg) + } +} + +impl Ext for T {} + +macro_rules! 
impl_box { + ($b:ident) => { + impl Socket for $b { + #[inline(always)] + fn local_addr(&self) -> io::Result { + (**self).local_addr() + } + + #[inline(always)] + fn protocol(&self) -> Protocol { + (**self).protocol() + } + + #[inline(always)] + fn features(&self) -> TransportFeatures { + (**self).features() + } + + #[inline(always)] + fn poll_peek_len(&self, cx: &mut Context) -> Poll> { + (**self).poll_peek_len(cx) + } + + #[inline(always)] + fn poll_recv( + &self, + cx: &mut Context, + addr: &mut Addr, + cmsg: &mut cmsg::Receiver, + buffer: &mut [IoSliceMut], + ) -> Poll> { + (**self).poll_recv(cx, addr, cmsg, buffer) + } + + #[inline(always)] + fn try_send( + &self, + addr: &Addr, + ecn: ExplicitCongestionNotification, + buffer: &[IoSlice], + ) -> io::Result { + (**self).try_send(addr, ecn, buffer) + } + + #[inline(always)] + fn poll_send( + &self, + cx: &mut Context, + addr: &Addr, + ecn: ExplicitCongestionNotification, + buffer: &[IoSlice], + ) -> Poll> { + (**self).poll_send(cx, addr, ecn, buffer) + } + + #[inline(always)] + fn send_finish(&self) -> io::Result<()> { + (**self).send_finish() + } + } + }; +} + +impl_box!(Box); +impl_box!(Arc); diff --git a/dc/s2n-quic-dc/src/stream/socket/tokio.rs b/dc/s2n-quic-dc/src/stream/socket/tokio.rs new file mode 100644 index 0000000000..2b10060f8b --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/socket/tokio.rs @@ -0,0 +1,7 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use super::{fd, Protocol, Socket, TransportFeatures}; + +mod tcp; +mod udp; diff --git a/dc/s2n-quic-dc/src/stream/socket/tokio/tcp.rs b/dc/s2n-quic-dc/src/stream/socket/tokio/tcp.rs new file mode 100644 index 0000000000..51f676666b --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/socket/tokio/tcp.rs @@ -0,0 +1,144 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +use super::{ + fd::{tcp, Flags}, + Protocol, Socket, TransportFeatures, +}; +use crate::msg::{addr::Addr, cmsg}; +use core::task::{Context, Poll}; +use s2n_quic_core::{inet::ExplicitCongestionNotification, ready}; +use std::{ + io::{self, IoSlice, IoSliceMut}, + net::SocketAddr, +}; +use tokio::{io::Interest, net::TcpStream}; + +impl Socket for TcpStream { + #[inline] + fn local_addr(&self) -> io::Result { + (*self).local_addr() + } + + #[inline] + fn protocol(&self) -> Protocol { + Protocol::Udp + } + + #[inline] + fn features(&self) -> TransportFeatures { + TransportFeatures::TCP + } + + #[inline] + fn poll_peek_len(&self, cx: &mut Context) -> Poll> { + loop { + ready!(self.poll_read_ready(cx))?; + + let res = self.try_io(Interest::READABLE, || tcp::peek(self)); + + match res { + Ok(len) => return Ok(len).into(), + Err(ref e) if e.kind() == io::ErrorKind::Interrupted => { + // try the operation again if we were interrupted + continue; + } + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + // register the waker + continue; + } + Err(err) => return Err(err).into(), + } + } + } + + #[inline] + fn poll_recv( + &self, + cx: &mut Context, + _addr: &mut Addr, + cmsg: &mut cmsg::Receiver, + buffer: &mut [IoSliceMut], + ) -> Poll> { + loop { + ready!(self.poll_read_ready(cx))?; + + let flags = Flags::default(); + let res = self.try_io(Interest::READABLE, || tcp::recv(self, buffer, flags)); + + match res { + Ok(len) => { + // we don't need ECN markings from TCP since it handles that logic for us + cmsg.set_ecn(ExplicitCongestionNotification::NotEct); + + // TCP doesn't have segments so just set it to 0 (which will indicate a single + // stream of bytes) + cmsg.set_segment_len(0); + + return Ok(len).into(); + } + Err(ref e) if e.kind() == io::ErrorKind::Interrupted => { + // try the operation again if we were interrupted + continue; + } + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + // register the waker + 
continue; + } + Err(err) => return Err(err).into(), + } + } + } + + #[inline] + fn try_send( + &self, + _addr: &Addr, + _ecn: ExplicitCongestionNotification, + buffer: &[IoSlice], + ) -> io::Result { + loop { + match tcp::send(self, buffer) { + Err(ref e) if e.kind() == io::ErrorKind::Interrupted => { + // try the operation again if we were interrupted + continue; + } + res => return res, + } + } + } + + #[inline] + fn poll_send( + &self, + cx: &mut Context, + _addr: &Addr, + _ecn: ExplicitCongestionNotification, + buffer: &[IoSlice], + ) -> Poll> { + loop { + ready!(self.poll_write_ready(cx))?; + + let res = self.try_io(Interest::WRITABLE, || tcp::send(self, buffer)); + + match res { + Ok(len) => return Ok(len).into(), + Err(ref e) if e.kind() == io::ErrorKind::Interrupted => { + // try the operation again if we were interrupted + continue; + } + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + // register the waker + continue; + } + Err(err) => return Err(err).into(), + } + } + } + + #[inline] + fn send_finish(&self) -> io::Result<()> { + // AsyncWrite::poll_shutdown requires a `&mut self` so we just use libc directly + tcp::shutdown(self) + } +} diff --git a/dc/s2n-quic-dc/src/stream/socket/tokio/udp.rs b/dc/s2n-quic-dc/src/stream/socket/tokio/udp.rs new file mode 100644 index 0000000000..d84ea4919f --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/socket/tokio/udp.rs @@ -0,0 +1,212 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+// SPDX-License-Identifier: Apache-2.0
+
+use super::{
+    fd::{udp, Flags},
+    Protocol, Socket, TransportFeatures,
+};
+use crate::msg::{addr::Addr, cmsg};
+use core::task::{Context, Poll};
+use s2n_quic_core::{ensure, inet::ExplicitCongestionNotification, ready};
+use std::{
+    io::{self, IoSlice, IoSliceMut},
+    net::SocketAddr,
+    os::fd::AsRawFd,
+};
+use tokio::io::unix::{AsyncFd, TryIoError};
+
+trait UdpSocket: 'static + AsRawFd + Send + Sync {
+    fn local_addr(&self) -> io::Result<SocketAddr>;
+}
+
+impl UdpSocket for std::net::UdpSocket {
+    #[inline]
+    fn local_addr(&self) -> io::Result<SocketAddr> {
+        (*self).local_addr()
+    }
+}
+
+impl UdpSocket for tokio::net::UdpSocket {
+    #[inline]
+    fn local_addr(&self) -> io::Result<SocketAddr> {
+        (*self).local_addr()
+    }
+}
+
+impl<T: UdpSocket> UdpSocket for std::sync::Arc<T> {
+    #[inline]
+    fn local_addr(&self) -> io::Result<SocketAddr> {
+        (**self).local_addr()
+    }
+}
+
+impl<T: UdpSocket> UdpSocket for Box<T> {
+    #[inline]
+    fn local_addr(&self) -> io::Result<SocketAddr> {
+        (**self).local_addr()
+    }
+}
+
+impl<T> Socket for AsyncFd<T>
+where
+    T: UdpSocket,
+{
+    #[inline]
+    fn local_addr(&self) -> io::Result<SocketAddr> {
+        self.get_ref().local_addr()
+    }
+
+    #[inline]
+    fn protocol(&self) -> Protocol {
+        Protocol::Udp
+    }
+
+    #[inline]
+    fn features(&self) -> TransportFeatures {
+        TransportFeatures::UDP
+    }
+
+    #[inline]
+    fn poll_peek_len(&self, cx: &mut Context) -> Poll<io::Result<usize>> {
+        loop {
+            let mut socket = ready!(self.poll_read_ready(cx))?;
+
+            let res = socket.try_io(udp::peek);
+
+            match res {
+                Ok(Ok(len)) => return Ok(len).into(),
+                Ok(Err(ref e)) if e.kind() == io::ErrorKind::Interrupted => {
+                    // try the operation again if we were interrupted
+                    continue;
+                }
+                Ok(Err(err)) => return Err(err).into(),
+                Err(err) => {
+                    // we got a WouldBlock so loop back around to register the waker
+                    let _: TryIoError = err;
+                    continue;
+                }
+            }
+        }
+    }
+
+    #[inline]
+    fn poll_recv(
+        &self,
+        cx: &mut Context,
+        addr: &mut Addr,
+        cmsg: &mut cmsg::Receiver,
+        buffer: &mut [IoSliceMut],
+    ) -> Poll<io::Result<usize>> {
+        // no point in receiving empty packets
+        ensure!(!buffer.is_empty(), Ok(0).into());
+
+        debug_assert!(
+            buffer.iter().any(|s| !s.is_empty()),
+            "trying to recv into an empty buffer"
+        );
+
+        loop {
+            let mut socket = ready!(self.poll_read_ready(cx))?;
+            let flags = Flags::default();
+
+            let res = socket.try_io(|fd| udp::recv(fd, addr, cmsg, buffer, flags));
+
+            match res {
+                Ok(Ok(0)) => {
+                    // no point in processing empty UDP packets
+                    continue;
+                }
+                Ok(Ok(len)) => return Ok(len).into(),
+                Ok(Err(ref e)) if e.kind() == io::ErrorKind::Interrupted => {
+                    // try the operation again if we were interrupted
+                    continue;
+                }
+                Ok(Err(err)) => return Err(err).into(),
+                Err(err) => {
+                    // we got a WouldBlock so loop back around to register the waker
+                    let _: TryIoError = err;
+                    continue;
+                }
+            }
+        }
+    }
+
+    #[inline]
+    fn try_send(
+        &self,
+        addr: &Addr,
+        ecn: ExplicitCongestionNotification,
+        buffer: &[IoSlice],
+    ) -> io::Result<usize> {
+        // no point in sending empty packets
+        ensure!(!buffer.is_empty(), Ok(0));
+
+        debug_assert!(
+            buffer.iter().any(|s| !s.is_empty()),
+            "trying to send from an empty buffer"
+        );
+
+        debug_assert!(
+            addr.get().port() != 0,
+            "cannot send packet to unspecified port"
+        );
+
+        loop {
+            match udp::send(self.get_ref(), addr, ecn, buffer) {
+                Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {
+                    // try the operation again if we were interrupted
+                    continue;
+                }
+                res => return res,
+            }
+        }
+    }
+
+    #[inline]
+    fn poll_send(
+        &self,
+        cx: &mut Context,
+        addr: &Addr,
+        ecn: ExplicitCongestionNotification,
+        buffer: &[IoSlice],
+    ) -> Poll<io::Result<usize>> {
+        // no point in sending empty packets
+        ensure!(!buffer.is_empty(), Ok(0).into());
+
+        debug_assert!(
+            buffer.iter().any(|s| !s.is_empty()),
+            "trying to send from an empty buffer"
+        );
+
+        debug_assert!(
+            addr.get().port() != 0,
+            "cannot send packet to unspecified port"
+        );
+
+        loop {
+            let mut socket = ready!(self.poll_write_ready(cx))?;
+
+            let res = socket.try_io(|fd| udp::send(fd, addr, ecn, buffer));
+
+            match res {
+                Ok(Ok(len)) => return Ok(len).into(),
+                Ok(Err(ref e)) if e.kind() == io::ErrorKind::Interrupted => {
+                    // try the operation again if we were interrupted
+                    continue;
+                }
+                Ok(Err(err)) => return Err(err).into(),
+                Err(err) => {
+                    // we got a WouldBlock so loop back around to register the waker
+                    let _: TryIoError = err;
+                    continue;
+                }
+            }
+        }
+    }
+
+    #[inline]
+    fn send_finish(&self) -> io::Result<()> {
+        // UDP sockets don't need a shut down
+        Ok(())
+    }
+}
diff --git a/dc/s2n-quic-dc/src/stream/socket/tracing.rs b/dc/s2n-quic-dc/src/stream/socket/tracing.rs
new file mode 100644
index 0000000000..85bca014c5
--- /dev/null
+++ b/dc/s2n-quic-dc/src/stream/socket/tracing.rs
@@ -0,0 +1,155 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+use super::{Protocol, Socket, TransportFeatures};
+use crate::msg::{addr::Addr, cmsg};
+use core::task::{Context, Poll};
+use s2n_quic_core::inet::ExplicitCongestionNotification;
+use std::{
+    io::{self, IoSlice, IoSliceMut},
+    net::SocketAddr,
+};
+use tracing::trace;
+
+pub struct Tracing<S: Socket>(pub S);
+
+impl<S: Socket> Socket for Tracing<S> {
+    #[inline(always)]
+    fn local_addr(&self) -> io::Result<SocketAddr> {
+        self.0.local_addr()
+    }
+
+    #[inline]
+    fn protocol(&self) -> Protocol {
+        self.0.protocol()
+    }
+
+    #[inline(always)]
+    fn features(&self) -> TransportFeatures {
+        self.0.features()
+    }
+
+    #[inline(always)]
+    fn poll_peek_len(&self, cx: &mut Context) -> Poll<io::Result<usize>> {
+        let result = self.0.poll_peek_len(cx);
+
+        trace!(
+            operation = %"poll_peek_len",
+            protocol = ?self.protocol(),
+            local_addr = ?self.local_addr(),
+            result = ?result,
+        );
+
+        result
+    }
+
+    #[inline(always)]
+    fn poll_recv(
+        &self,
+        cx: &mut Context,
+        addr: &mut Addr,
+        cmsg: &mut cmsg::Receiver,
+        buffer: &mut [IoSliceMut],
+    ) -> Poll<io::Result<usize>> {
+        let result = self.0.poll_recv(cx, addr, cmsg, buffer);
+
+        match &result {
+            Poll::Ready(Ok(_)) => trace!(
+                operation = %"poll_recv",
+                protocol = ?self.protocol(),
+                local_addr = ?self.local_addr(),
+                remote_addr = ?addr,
+                ecn = ?cmsg.ecn(),
+                segments = buffer.len(),
+                segment_len = cmsg.segment_len(),
+                buffer_len = {
+                    let v: usize = buffer.iter().map(|s| s.len()).sum();
+                    v
+                },
+                result = ?result,
+            ),
+            _ => trace!(
+                operation = %"poll_recv",
+                protocol = ?self.protocol(),
+                local_addr = ?self.local_addr(),
+                segments = buffer.len(),
+                buffer_len = {
+                    let v: usize = buffer.iter().map(|s| s.len()).sum();
+                    v
+                },
+                result = ?result,
+            ),
+        }
+
+        result
+    }
+
+    #[inline(always)]
+    fn try_send(
+        &self,
+        addr: &Addr,
+        ecn: ExplicitCongestionNotification,
+        buffer: &[IoSlice],
+    ) -> io::Result<usize> {
+        let result = self.0.try_send(addr, ecn, buffer);
+
+        trace!(
+            operation = %"try_send",
+            protocol = ?self.protocol(),
+            local_addr = ?self.local_addr(),
+            remote_addr = ?addr,
+            ?ecn,
+            segments = buffer.len(),
+            segment_len = buffer.first().map_or(0, |s| s.len()),
+            buffer_len = {
+                let v: usize = buffer.iter().map(|s| s.len()).sum();
+                v
+            },
+            result = ?result,
+        );
+
+        result
+    }
+
+    #[inline(always)]
+    fn poll_send(
+        &self,
+        cx: &mut Context,
+        addr: &Addr,
+        ecn: ExplicitCongestionNotification,
+        buffer: &[IoSlice],
+    ) -> Poll<io::Result<usize>> {
+        let result = self.0.poll_send(cx, addr, ecn, buffer);
+
+        trace!(
+            operation = %"poll_send",
+            protocol = ?self.protocol(),
+            local_addr = ?self.local_addr(),
+            remote_addr = ?addr,
+            ?ecn,
+            segments = buffer.len(),
+            segment_len = buffer.first().map_or(0, |s| s.len()),
+            buffer_len = {
+                let v: usize = buffer.iter().map(|s| s.len()).sum();
+                v
+            },
+            result = ?result,
+        );
+
+        result
+    }
+
+    #[inline(always)]
+    fn send_finish(&self) -> io::Result<()> {
+        let result = self.0.send_finish();
+
+        trace!(
+            operation = %"send_finish",
+            protocol = ?self.protocol(),
+            local_addr = ?self.local_addr(),
+            result = ?result,
+        );
+
+        result
+    }
+}
diff --git a/dc/s2n-quic-dc/src/task.rs b/dc/s2n-quic-dc/src/task.rs
new file mode 100644
index 0000000000..a2fc259e2b
--- /dev/null
+++ b/dc/s2n-quic-dc/src/task.rs
@@ -0,0 +1,4 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+pub mod waker;
diff --git a/dc/s2n-quic-dc/src/task/waker.rs b/dc/s2n-quic-dc/src/task/waker.rs
new file mode 100644
index 0000000000..d0c7fd5260
--- /dev/null
+++ b/dc/s2n-quic-dc/src/task/waker.rs
@@ -0,0 +1,4 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+pub mod worker;
diff --git a/dc/s2n-quic-dc/src/task/waker/worker.rs b/dc/s2n-quic-dc/src/task/waker/worker.rs
new file mode 100644
index 0000000000..2568b160a0
--- /dev/null
+++ b/dc/s2n-quic-dc/src/task/waker/worker.rs
@@ -0,0 +1,126 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+use core::{
+    sync::atomic::{AtomicU8, Ordering},
+    task,
+};
+use crossbeam_epoch::{pin, Atomic};
+use s2n_quic_core::{ensure, state::is};
+
+/// An atomic waker that doesn't change often
+#[derive(Debug)]
+pub struct Waker {
+    waker: Atomic<task::Waker>,
+    has_woken: AtomicU8,
+}
+
+impl Default for Waker {
+    #[inline]
+    fn default() -> Self {
+        Self {
+            waker: Atomic::null(),
+            has_woken: AtomicU8::new(Status::Sleeping.as_u8()),
+        }
+    }
+}
+
+impl Waker {
+    #[inline]
+    pub fn update(&self, waker: &task::Waker) {
+        let pin = crossbeam_epoch::pin();
+
+        let waker = crossbeam_epoch::Owned::new(waker.clone()).into_shared(&pin);
+        let prev = self.waker.swap(waker, Ordering::AcqRel, &pin);
+
+        ensure!(!prev.is_null());
+
+        unsafe {
+            pin.defer_unchecked(move || {
+                drop(prev.try_into_owned());
+            })
+        }
+    }
+
+    #[inline]
+    pub fn wake(&self) {
+        let status = self.swap_status(Status::PendingWork, Ordering::Acquire);
+
+        // we only need to `wake_by_ref` if the worker is sleeping
+        ensure!(matches!(status, Status::Sleeping));
+
+        let guard = crossbeam_epoch::pin();
+        let waker = self.waker.load(Ordering::Acquire, &guard);
+        let Some(waker) = (unsafe { waker.as_ref() }) else {
+            return;
+        };
+        waker.wake_by_ref();
+    }
+
+    /// Called when the worker wakes
+    #[inline]
+    pub fn on_worker_wake(&self) {
+        self.swap_status(Status::Working, Ordering::Release);
+    }
+
+    /// Called before the worker sleeps
+    ///
+    /// Returns the previous [`Status`]
+    #[inline]
+    pub fn on_worker_sleep(&self) -> Status {
+        self.swap_status(Status::Sleeping, Ordering::Release)
+    }
+
+    #[inline]
+    fn swap_status(&self, status: Status, ordering: Ordering) -> Status {
+        Status::from_u8(self.has_woken.swap(status.as_u8(), ordering))
+    }
+}
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+pub enum Status {
+    Sleeping,
+    PendingWork,
+    Working,
+}
+
+impl Status {
+    is!(is_sleeping, Sleeping);
+    is!(is_pending_work, PendingWork);
+    is!(is_working, Working);
+
+    #[inline]
+    fn as_u8(self) -> u8 {
+        match self {
+            Self::Sleeping => 0,
+            Self::PendingWork => 1,
+            Self::Working => 2,
+        }
+    }
+
+    #[inline]
+    fn from_u8(v: u8) -> Self {
+        match v {
+            1 => Self::PendingWork,
+            2 => Self::Working,
+            _ => Self::Sleeping,
+        }
+    }
+}
+
+impl Drop for Waker {
+    #[inline]
+    fn drop(&mut self) {
+        let pin = pin();
+        let waker = crossbeam_epoch::Shared::null();
+        let prev = self.waker.swap(waker, Ordering::AcqRel, &pin);
+
+        ensure!(!prev.is_null());
+
+        unsafe {
+            pin.defer_unchecked(move || {
+                drop(prev.try_into_owned());
+            })
+        }
+    }
+}
diff --git a/quic/s2n-quic-core/src/macros.rs b/quic/s2n-quic-core/src/macros.rs
index 057cc9095e..28dbbdcbad 100644
--- a/quic/s2n-quic-core/src/macros.rs
+++ b/quic/s2n-quic-core/src/macros.rs
@@ -153,7 +153,7 @@ macro_rules! ensure {
 /// Implements a future that wraps `T::poll_ready` and yields after ready
 macro_rules!
impl_ready_future { ($name:ident, $fut:ident, $output:ty) => { - pub struct $fut<'a, T>(&'a mut T); + pub struct $fut<'a, T: ?Sized>(&'a mut T); impl<'a, T: $name> core::future::Future for $fut<'a, T> { type Output = $output; diff --git a/quic/s2n-quic-core/src/time/clock.rs b/quic/s2n-quic-core/src/time/clock.rs index 2ecf3fc03b..c4f764d3e2 100644 --- a/quic/s2n-quic-core/src/time/clock.rs +++ b/quic/s2n-quic-core/src/time/clock.rs @@ -27,7 +27,7 @@ pub trait ClockWithTimer: Clock { fn timer(&self) -> Self::Timer; } -pub trait Timer: Sized { +pub trait Timer { #[inline] fn ready(&mut self) -> TimerReady { TimerReady(self) @@ -57,12 +57,12 @@ impl Clock for Timestamp { } /// A clock that caches the time query for the inner clock -pub struct Cached<'a, C: Clock> { +pub struct Cached<'a, C: Clock + ?Sized> { clock: &'a C, cached_value: core::cell::Cell>, } -impl<'a, C: Clock> Cached<'a, C> { +impl<'a, C: Clock + ?Sized> Cached<'a, C> { #[inline] pub fn new(clock: &'a C) -> Self { Self { @@ -72,7 +72,7 @@ impl<'a, C: Clock> Cached<'a, C> { } } -impl<'a, C: Clock> Clock for Cached<'a, C> { +impl<'a, C: Clock + ?Sized> Clock for Cached<'a, C> { #[inline] fn get_time(&self) -> Timestamp { if let Some(time) = self.cached_value.get() {