diff --git a/quic/s2n-quic-platform/src/socket.rs b/quic/s2n-quic-platform/src/socket.rs index 4cefdef350..d8d307870a 100644 --- a/quic/s2n-quic-platform/src/socket.rs +++ b/quic/s2n-quic-platform/src/socket.rs @@ -10,6 +10,7 @@ pub mod mmsg; pub mod msg; pub mod ring; pub mod std; +pub mod task; cfg_if! { if #[cfg(s2n_quic_platform_socket_mmsg)] { diff --git a/quic/s2n-quic-platform/src/socket/task.rs b/quic/s2n-quic-platform/src/socket/task.rs new file mode 100644 index 0000000000..336c294ea7 --- /dev/null +++ b/quic/s2n-quic-platform/src/socket/task.rs @@ -0,0 +1,9 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +pub mod events; +pub mod rx; +pub mod tx; + +pub use rx::Receiver; +pub use tx::Sender; diff --git a/quic/s2n-quic-platform/src/socket/task/events.rs b/quic/s2n-quic-platform/src/socket/task/events.rs new file mode 100644 index 0000000000..a1a31179f5 --- /dev/null +++ b/quic/s2n-quic-platform/src/socket/task/events.rs @@ -0,0 +1,161 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +// Some of the functions in these impls are not used on non-unix systems +#![cfg_attr(not(unix), allow(dead_code))] + +use crate::features::Gso; +use core::ops::ControlFlow; + +#[derive(Debug)] +pub struct TxEvents { + count: usize, + is_blocked: bool, + #[cfg_attr(not(s2n_quic_platform_gso), allow(dead_code))] + gso: Gso, +} + +impl TxEvents { + #[inline] + pub fn new(gso: Gso) -> Self { + Self { + count: 0, + is_blocked: false, + gso, + } + } + + /// Returns if the task is blocked + #[inline] + pub fn is_blocked(&self) -> bool { + self.is_blocked + } + + /// Returns if the task was blocked and resets the value + #[inline] + pub fn take_blocked(&mut self) -> bool { + core::mem::take(&mut self.is_blocked) + } + + /// Sets the task to blocked + #[inline] + pub fn blocked(&mut self) { + self.is_blocked = true; + } + + /// Returns and resets the number of messages sent + #[inline] + pub fn take_count(&mut self) -> usize { + core::mem::take(&mut self.count) + } +} + +impl crate::syscall::SocketEvents for TxEvents { + #[inline] + fn on_complete(&mut self, count: usize) -> ControlFlow<(), ()> { + // increment the total sent packets and reset our blocked status + self.count += count; + self.is_blocked = false; + ControlFlow::Continue(()) + } + + #[inline] + fn on_error(&mut self, error: ::std::io::Error) -> ControlFlow<(), ()> { + use std::io::ErrorKind::*; + + match error.kind() { + WouldBlock => { + // record that we're blocked + self.is_blocked = true; + ControlFlow::Break(()) + } + Interrupted => { + // if we got interrupted break and have the task try again + ControlFlow::Break(()) + } + #[cfg(s2n_quic_platform_gso)] + _ if error.raw_os_error() == Some(libc::EIO) => { + // on platforms that don't support GSO we need to disable it and mark the packet as + // "sent" even though we weren't able to. + self.count += 1; + + self.gso.disable(); + + // We `continue` instead of break because it's very unlikely the message would be + // accepted at a later time, so we just discard the packet. + ControlFlow::Continue(()) + } + _ => { + // ignore all other errors and just consider the packet sent + self.count += 1; + + // We `continue` instead of break because it's very unlikely the message would be + // accepted at a later time, so we just discard the packet. + ControlFlow::Continue(()) + } + } + } +} + +#[derive(Debug, Default)] +pub struct RxEvents { + count: usize, + is_blocked: bool, +} + +impl RxEvents { + /// Returns if the task is blocked + #[inline] + pub fn is_blocked(&self) -> bool { + self.is_blocked + } + + /// Returns if the task was blocked and resets the value + #[inline] + pub fn take_blocked(&mut self) -> bool { + core::mem::take(&mut self.is_blocked) + } + + /// Sets the task to blocked + #[inline] + pub fn blocked(&mut self) { + self.is_blocked = true; + } + + /// Returns and resets the number of messages sent + #[inline] + pub fn take_count(&mut self) -> usize { + core::mem::take(&mut self.count) + } +} + +impl crate::syscall::SocketEvents for RxEvents { + #[inline] + fn on_complete(&mut self, count: usize) -> ControlFlow<(), ()> { + // increment the total sent packets and reset our blocked status + self.count += count; + self.is_blocked = false; + ControlFlow::Continue(()) + } + + #[inline] + fn on_error(&mut self, error: ::std::io::Error) -> ControlFlow<(), ()> { + use std::io::ErrorKind::*; + + match error.kind() { + WouldBlock => { + // record that we're blocked + self.is_blocked = true; + ControlFlow::Break(()) + } + Interrupted => { + // if we got interrupted break and have the task try again + ControlFlow::Break(()) + } + _ => { + // ignore all other errors and have the task try again + ControlFlow::Break(()) + } + } + } +} diff --git a/quic/s2n-quic-platform/src/socket/task/rx.rs b/quic/s2n-quic-platform/src/socket/task/rx.rs new file mode 100644 index 0000000000..fd9c259010 --- /dev/null +++ b/quic/s2n-quic-platform/src/socket/task/rx.rs @@ -0,0 +1,120 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + message::Message, + socket::{ring::Producer, task::events}, +}; +use core::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; +use futures::ready; + +pub trait Socket { + type Error; + + fn recv( + &mut self, + cx: &mut Context, + entries: &mut [T], + events: &mut events::RxEvents, + ) -> Result<(), Self::Error>; +} + +pub struct Receiver> { + ring: Producer, + /// Implementation of a socket that fills free slots in the ring buffer + rx: S, + /// The number of messages that have been filled but not yet released to the consumer. + /// + /// This value is to avoid calling `release` too much and excessively waking up the consumer. + pending: u32, +} + +impl Receiver +where + T: Message + Unpin, + S: Socket + Unpin, +{ + #[inline] + pub fn new(ring: Producer, rx: S) -> Self { + Self { + ring, + rx, + pending: 0, + } + } + + #[inline] + fn poll_ring(&mut self, watermark: u32, cx: &mut Context) -> Poll> { + loop { + let count = match self.ring.poll_acquire(watermark, cx) { + Poll::Ready(count) => count, + Poll::Pending if self.pending == 0 => { + return if !self.ring.is_open() { + Err(()).into() + } else { + Poll::Pending + }; + } + Poll::Pending => 0, + }; + + // if the number of free slots increased since last time then yield + if count > self.pending { + return Ok(()).into(); + } + + // If there is no additional capacity available (i.e. we have filled all slots), + // then release those filled slots for the consumer to read from. Once + // the consumer reads, we will have spare capacity to populate again. + self.release(); + } + } + + #[inline] + fn release(&mut self) { + let to_release = core::mem::take(&mut self.pending); + self.ring.release(to_release); + } +} + +impl Future for Receiver +where + T: Message + Unpin, + S: Socket + Unpin, +{ + type Output = Option; + + #[inline] + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let this = self.get_mut(); + + let mut events = events::RxEvents::default(); + + while !events.take_blocked() { + if ready!(this.poll_ring(u32::MAX, cx)).is_err() { + return None.into(); + } + + // slice the ring data by the number of slots we've already filled + let entries = &mut this.ring.data()[this.pending as usize..]; + + // perform the recv syscall + match this.rx.recv(cx, entries, &mut events) { + Ok(()) => { + // increment the number of received messages + this.pending += events.take_count() as u32 + } + Err(err) => return Some(err).into(), + } + } + + // release any of the messages we wrote back to the consumer + this.release(); + + Poll::Pending + } +} diff --git a/quic/s2n-quic-platform/src/socket/task/tx.rs b/quic/s2n-quic-platform/src/socket/task/tx.rs new file mode 100644 index 0000000000..72ad65a368 --- /dev/null +++ b/quic/s2n-quic-platform/src/socket/task/tx.rs @@ -0,0 +1,121 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + features::Gso, + message::Message, + socket::{ring::Consumer, task::events}, +}; +use core::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; +use futures::ready; + +pub trait Socket { + type Error; + + fn send( + &mut self, + cx: &mut Context, + entries: &mut [T], + events: &mut events::TxEvents, + ) -> Result<(), Self::Error>; +} + +pub struct Sender> { + ring: Consumer, + /// Implementation of a socket that transmits filled slots in the ring buffer + tx: S, + /// The number of messages that have been transmitted but not yet released to the producer. + /// + /// This value is to avoid calling `release` too much and excessively waking up the producer. + pending: u32, + events: events::TxEvents, +} + +impl Sender +where + T: Message + Unpin, + S: Socket + Unpin, +{ + #[inline] + pub fn new(ring: Consumer, tx: S, gso: Gso) -> Self { + Self { + ring, + tx, + pending: 0, + events: events::TxEvents::new(gso), + } + } + + #[inline] + fn poll_ring(&mut self, watermark: u32, cx: &mut Context) -> Poll> { + loop { + let count = match self.ring.poll_acquire(watermark, cx) { + Poll::Ready(count) => count, + Poll::Pending if self.pending == 0 => { + return if !self.ring.is_open() { + Err(()).into() + } else { + Poll::Pending + }; + } + Poll::Pending => 0, + }; + + // if the number of free slots increased since last time then yield + if count > self.pending { + return Ok(()).into(); + } + + // If there is no additional capacity available (i.e. we have filled all slots), + // then release those filled slots for the consumer to read from. Once + // the consumer reads, we will have spare capacity to populate again. + self.release(); + } + } + + #[inline] + fn release(&mut self) { + let to_release = core::mem::take(&mut self.pending); + self.ring.release(to_release); + } +} + +impl Future for Sender +where + T: Message + Unpin, + S: Socket + Unpin, +{ + type Output = Option; + + #[inline] + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let this = self.get_mut(); + + while !this.events.take_blocked() { + if ready!(this.poll_ring(u32::MAX, cx)).is_err() { + return None.into(); + } + + // slice the ring data by the number of items we've already received + let entries = &mut this.ring.data()[this.pending as usize..]; + + // perform the send syscall + match this.tx.send(cx, entries, &mut this.events) { + Ok(()) => { + // increment the number of received messages + this.pending += this.events.take_count() as u32 + } + Err(err) => return Some(err).into(), + } + } + + // release any of the messages we wrote back to the consumer + this.release(); + + Poll::Pending + } +} diff --git a/quic/s2n-quic-platform/src/syscall.rs b/quic/s2n-quic-platform/src/syscall.rs index b8c4d7ccb1..d49891250d 100644 --- a/quic/s2n-quic-platform/src/syscall.rs +++ b/quic/s2n-quic-platform/src/syscall.rs @@ -23,9 +23,17 @@ pub enum SocketType { pub trait SocketEvents { /// Called when `count` packets are completed + /// + /// If `Continue` is returned, the socket will assume the packet was acceptable and continue + /// with the remaining packets. If `Break` is returned, the syscall stop looping and yield to + /// the caller. fn on_complete(&mut self, count: usize) -> ControlFlow<(), ()>; /// Called when an error occurs on a socket + /// + /// If `Continue` is returned, the socket will discard the packet and continue + /// with the remaining packets. If `Break` is returned, the syscall will assume the current + /// packet can be retried and yield to the caller. fn on_error(&mut self, error: io::Error) -> ControlFlow<(), ()>; }