Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(s2n-quic-dc): import latest changes #2267

Merged
merged 1 commit into from
Jul 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions dc/s2n-quic-dc/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,22 @@ exclude = ["corpus.tar.gz"]
testing = ["bolero-generator", "s2n-quic-core/testing"]

[dependencies]
arrayvec = "0.7"
atomic-waker = "1"
aws-lc-rs = "1"
bitflags = "2"
bolero-generator = { version = "0.11", optional = true }
bytes = "1"
crossbeam-channel = "0.5"
crossbeam-epoch = "0.9"
crossbeam-queue = { version = "0.3" }
flurry = "0.5"
libc = "0.2"
num-rational = { version = "0.4", default-features = false }
once_cell = "1"
pin-project-lite = "0.2"
rand = { version = "0.8", features = ["small_rng"] }
rand_chacha = "0.3"
s2n-codec = { version = "=0.41.0", path = "../../common/s2n-codec", default-features = false }
s2n-quic-core = { version = "=0.41.0", path = "../../quic/s2n-quic-core", default-features = false }
s2n-quic-platform = { version = "=0.41.0", path = "../../quic/s2n-quic-platform" }
Expand Down
86 changes: 86 additions & 0 deletions dc/s2n-quic-dc/src/clock.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

use core::{fmt, pin::Pin, task::Poll, time::Duration};
use s2n_quic_core::{ensure, time};
use tracing::trace;

pub mod tokio;
pub use time::clock::Cached;

pub use time::Timestamp;
pub type SleepHandle = Pin<Box<dyn Sleep>>;

pub trait Clock: 'static + Send + Sync + fmt::Debug + time::Clock {
fn sleep(&self, amount: Duration) -> (SleepHandle, Timestamp);
}

pub trait Sleep: Clock + core::future::Future<Output = ()> {
fn update(self: Pin<&mut Self>, target: Timestamp);
}

pub struct Timer {
/// The `Instant` at which the timer should expire
target: Option<Timestamp>,
/// The handle to the timer entry in the tokio runtime
sleep: Pin<Box<dyn Sleep>>,
}

impl fmt::Debug for Timer {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Timer")
.field("target", &self.target)
.finish()
}
}

impl Timer {
#[inline]
pub fn new(clock: &dyn Clock) -> Self {
/// We can't create a timer without first arming it to something, so just set it to 1s in
/// the future.
const INITIAL_TIMEOUT: Duration = Duration::from_secs(1);

Self::new_with_timeout(clock, INITIAL_TIMEOUT)
}

#[inline]
pub fn new_with_timeout(clock: &dyn Clock, timeout: Duration) -> Self {
let (sleep, target) = clock.sleep(timeout);
Self {
target: Some(target),
sleep,
}
}

#[inline]
pub fn cancel(&mut self) {
trace!(cancel = ?self.target);
self.target = None;
}
}

impl time::clock::Timer for Timer {
#[inline]
fn poll_ready(&mut self, cx: &mut core::task::Context) -> Poll<()> {
ensure!(self.target.is_some(), Poll::Ready(()));

let res = self.sleep.as_mut().poll(cx);

if res.is_ready() {
// clear the target after it fires, otherwise we'll endlessly wake up the task
self.target = None;
}

res
}

#[inline]
fn update(&mut self, target: Timestamp) {
// no need to update if it hasn't changed
ensure!(self.target != Some(target));

self.sleep.as_mut().update(target);
self.target = Some(target);
}
}
122 changes: 122 additions & 0 deletions dc/s2n-quic-dc/src/clock/tokio.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

use super::SleepHandle;
use core::{
fmt,
future::Future,
pin::Pin,
task::{Context, Poll},
time::Duration,
};
use pin_project_lite::pin_project;
use s2n_quic_core::{ready, time::Timestamp};
use tokio::time::{self, sleep_until, Instant};
use tracing::trace;

#[derive(Clone, Debug)]
pub struct Clock(Instant);

impl Default for Clock {
#[inline]
fn default() -> Self {
Self(Instant::now())
}
}

impl s2n_quic_core::time::Clock for Clock {
#[inline]
fn get_time(&self) -> Timestamp {
let time = self.0.elapsed();
unsafe { Timestamp::from_duration(time) }
}
}

pin_project!(
pub struct Sleep {
clock: Clock,
#[pin]
sleep: time::Sleep,
}
);

impl s2n_quic_core::time::Clock for Sleep {
#[inline]
fn get_time(&self) -> Timestamp {
self.clock.get_time()
}
}

impl Future for Sleep {
type Output = ();

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
ready!(core::future::Future::poll(this.sleep, cx));
Poll::Ready(())
}
}

impl super::Sleep for Sleep {
#[inline]
fn update(self: Pin<&mut Self>, target: Timestamp) {
let target = unsafe { target.as_duration() };

// floor the delay to milliseconds to reduce timer churn
let delay = Duration::from_millis(target.as_millis() as u64);

let target = self.clock.0 + delay;

// if the clock has changed let the sleep future know
trace!(update = ?target);
self.project().sleep.reset(target);
}
}

impl super::Clock for Sleep {
#[inline]
fn sleep(&self, amount: Duration) -> (SleepHandle, Timestamp) {
self.clock.sleep(amount)
}
}

impl fmt::Debug for Sleep {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Sleep")
.field("clock", &self.clock)
.field("sleep", &self.sleep)
.finish()
}
}

impl super::Clock for Clock {
#[inline]
fn sleep(&self, amount: Duration) -> (SleepHandle, Timestamp) {
let now = Instant::now();
let sleep = sleep_until(now + amount);
let sleep = Sleep {
clock: self.clone(),
sleep,
};
let sleep = Box::pin(sleep);
let target = now.saturating_duration_since(self.0);
let target = unsafe { Timestamp::from_duration(target) };
(sleep, target)
}
}

#[cfg(test)]
mod tests {
use crate::clock::{tokio::Clock, Timer};
use core::time::Duration;
use s2n_quic_core::time::{clock::Timer as _, Clock as _};

#[tokio::test]
async fn clock_test() {
let clock = Clock::default();
let mut timer = Timer::new(&clock);
timer.ready().await;
timer.update(clock.get_time() + Duration::from_secs(1));
timer.ready().await;
}
}
9 changes: 4 additions & 5 deletions dc/s2n-quic-dc/src/congestion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,10 @@ pub struct Controller {

impl Controller {
#[inline]
pub fn new(mtu: u16) -> Self {
let mut controller = BbrCongestionController::new(mtu, Default::default());
let publisher = &mut NoopPublisher;
controller.on_mtu_update(mtu, publisher);
Self { controller }
pub fn new(max_datagram_size: u16) -> Self {
Self {
controller: BbrCongestionController::new(max_datagram_size, Default::default()),
}
}

#[inline]
Expand Down
3 changes: 3 additions & 0 deletions dc/s2n-quic-dc/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0

pub mod allocator;
pub mod clock;
pub mod congestion;
pub mod control;
pub mod credentials;
Expand All @@ -11,8 +12,10 @@ pub mod msg;
pub mod packet;
pub mod path;
pub mod pool;
pub mod random;
pub mod recovery;
pub mod socket;
pub mod stream;
pub mod task;

pub use s2n_quic_core::dc::{Version, SUPPORTED_VERSIONS};
1 change: 1 addition & 0 deletions dc/s2n-quic-dc/src/msg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@
pub mod addr;
pub mod cmsg;
pub mod recv;
pub mod segment;
pub mod send;
4 changes: 1 addition & 3 deletions dc/s2n-quic-dc/src/msg/recv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ use s2n_quic_core::{
buffer::Deque as Buffer,
ensure,
inet::{ExplicitCongestionNotification, SocketAddress},
path::MaxMtu,
ready,
};
use std::{io, os::fd::AsRawFd};
Expand All @@ -32,8 +31,7 @@ impl fmt::Debug for Message {

impl Message {
#[inline]
pub fn new(max_mtu: MaxMtu) -> Self {
let max_mtu: u16 = max_mtu.into();
pub fn new(max_mtu: u16) -> Self {
let max_mtu = max_mtu as usize;
let buffer_len = cmsg::MAX_GRO_SEGMENTS * max_mtu;
// the recv syscall doesn't return more than this
Expand Down
98 changes: 98 additions & 0 deletions dc/s2n-quic-dc/src/msg/segment.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

use arrayvec::ArrayVec;
use core::ops::Deref;
use s2n_quic_core::{ensure, inet::ExplicitCongestionNotification};
use std::io::IoSlice;

/// The maximum number of segments in sendmsg calls
///
/// From <https://elixir.bootlin.com/linux/v6.8.7/source/include/uapi/linux/uio.h#L27>
/// > #define UIO_FASTIOV 8
pub const MAX_COUNT: usize = if cfg!(target_os = "linux") { 8 } else { 1 };

/// The maximum payload allowed in sendmsg calls
///
/// From <https://github.com/torvalds/linux/blob/8cd26fd90c1ad7acdcfb9f69ca99d13aa7b24561/net/ipv4/ip_output.c#L987-L995>
/// > Linux enforces a u16::MAX - IP_HEADER_LEN - UDP_HEADER_LEN
pub const MAX_TOTAL: u16 = u16::MAX - 50;

type Segments<'a> = ArrayVec<IoSlice<'a>, MAX_COUNT>;

pub struct Batch<'a> {
segments: Segments<'a>,
ecn: ExplicitCongestionNotification,
}

impl<'a> Deref for Batch<'a> {
type Target = [IoSlice<'a>];

#[inline]
fn deref(&self) -> &Self::Target {
&self.segments
}
}

impl<'a> Batch<'a> {
#[inline]
pub fn new<Q>(queue: Q) -> Self
where
Q: IntoIterator<Item = (ExplicitCongestionNotification, &'a [u8])>,
{
// this value is replaced by the first segment
let mut ecn = ExplicitCongestionNotification::Ect0;
let mut total_len = 0u16;
let mut segments = Segments::new();

for segment in queue {
let packet_len = segment.1.len();
debug_assert!(
packet_len <= u16::MAX as usize,
"segments should not exceed the maximum datagram size"
);
let packet_len = packet_len as u16;

// make sure the packet fits in u16::MAX
let Some(new_total_len) = total_len.checked_add(packet_len) else {
break;
};
// make sure we don't exceed the max allowed payload size
ensure!(new_total_len < MAX_TOTAL, break);

// track if the current segment is undersized from the previous
let mut undersized_segment = false;

// make sure we're compatible with the previous segment
if let Some(first_segment) = segments.first() {
ensure!(first_segment.len() >= packet_len as usize, break);
// this is the last segment we can push if the segment is undersized
undersized_segment = first_segment.len() > packet_len as usize;
// make sure ecn doesn't change with this transmission
ensure!(ecn == segment.0, break);
} else {
// update the ecn value with the first segment
ecn = segment.0;
}

// update the total len once we confirm this segment can be written
total_len = new_total_len;

let iovec = std::io::IoSlice::new(segment.1);
segments.push(iovec);

// if this segment was undersized, then bail
ensure!(!undersized_segment, break);

// make sure we have capacity before looping back around
ensure!(!segments.is_full(), break);
}

Self { segments, ecn }
}

#[inline]
pub fn ecn(&self) -> ExplicitCongestionNotification {
self.ecn
}
}
Loading
Loading