From e82c7d672c76efbbcfd9f528c0b876daaadee642 Mon Sep 17 00:00:00 2001 From: Cameron Bytheway Date: Thu, 11 May 2023 16:45:45 -0600 Subject: [PATCH] feat(s2n-quic-xdp): implement IO traits for vectors of channel pairs --- tools/xdp/s2n-quic-xdp/src/io.rs | 15 --- tools/xdp/s2n-quic-xdp/src/io/rx.rs | 167 +++++++++++++++++-------- tools/xdp/s2n-quic-xdp/src/io/tests.rs | 90 ++++++++----- tools/xdp/s2n-quic-xdp/src/io/tx.rs | 146 ++++++++++++++++----- 4 files changed, 285 insertions(+), 133 deletions(-) diff --git a/tools/xdp/s2n-quic-xdp/src/io.rs b/tools/xdp/s2n-quic-xdp/src/io.rs index 1c15c04a80..ffdc912d02 100644 --- a/tools/xdp/s2n-quic-xdp/src/io.rs +++ b/tools/xdp/s2n-quic-xdp/src/io.rs @@ -1,21 +1,6 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -// TODO replace with `core::task::ready` once we bump MSRV to 1.64.0 -// https://doc.rust-lang.org/core/task/macro.ready.html -// -// See https://github.com/aws/s2n-quic/issues/1750 -macro_rules! ready { - ($value:expr) => { - match $value { - ::core::task::Poll::Ready(v) => v, - ::core::task::Poll::Pending => { - return ::core::task::Poll::Pending; - } - } - }; -} - pub mod rx; pub mod tx; diff --git a/tools/xdp/s2n-quic-xdp/src/io/rx.rs b/tools/xdp/s2n-quic-xdp/src/io/rx.rs index 870ee947ff..6a4cf70b56 100644 --- a/tools/xdp/s2n-quic-xdp/src/io/rx.rs +++ b/tools/xdp/s2n-quic-xdp/src/io/rx.rs @@ -5,7 +5,10 @@ use crate::{ if_xdp::{RxTxDescriptor, UmemDescriptor}, umem::Umem, }; -use core::task::{Context, Poll}; +use core::{ + cell::UnsafeCell, + task::{Context, Poll}, +}; use s2n_codec::DecoderBufferMut; use s2n_quic_core::{ event, @@ -28,18 +31,28 @@ pub trait ErrorLogger: Send { } pub struct Rx { - occupied: Occupied, - free: Free, + channels: UnsafeCell>, + /// Store a vec of slices on the struct so we don't have to allocate every time `queue` is + /// called. Since this causes the type to be self-referential it does need a bit of unsafe code + /// to pull this off. + slices: UnsafeCell< + Vec<( + spsc::RecvSlice<'static, RxTxDescriptor>, + spsc::SendSlice<'static, UmemDescriptor>, + )>, + >, umem: Umem, error_logger: Option>, } impl Rx { /// Creates a RX IO interface for an s2n-quic endpoint - pub fn new(occupied: Occupied, free: Free, umem: Umem) -> Self { + pub fn new(channels: Vec<(Occupied, Free)>, umem: Umem) -> Self { + let slices = UnsafeCell::new(Vec::with_capacity(channels.len())); + let channels = UnsafeCell::new(channels); Self { - occupied, - free, + channels, + slices, umem, error_logger: None, } @@ -60,13 +73,46 @@ impl rx::Rx for Rx { #[inline] fn poll_ready(&mut self, cx: &mut Context) -> Poll> { // poll both channels to make sure we can make progress in both - let free = self.free.poll_slice(cx); - let occupied = self.occupied.poll_slice(cx); - ready!(free)?; - ready!(occupied)?; + let mut is_any_ready = false; + let mut is_all_occupied_closed = true; + let mut is_all_free_closed = true; + + for (occupied, free) in self.channels.get_mut() { + let mut is_ready = true; + + macro_rules! ready { + ($slice:ident, $closed:ident) => { + match $slice.poll_slice(cx) { + Poll::Ready(Ok(_)) => { + $closed = false; + } + Poll::Ready(Err(_)) => { + // defer returning an error until all slices return one + } + Poll::Pending => { + $closed = false; + is_ready = false + } + } + }; + } + + ready!(occupied, is_all_occupied_closed); + ready!(free, is_all_free_closed); + + is_any_ready |= is_ready; + } - Poll::Ready(Ok(())) + if is_all_occupied_closed || is_all_free_closed { + return Err(spsc::ClosedError).into(); + } + + if is_any_ready { + Poll::Ready(Ok(())) + } else { + Poll::Pending + } } #[inline] @@ -88,14 +134,21 @@ impl rx::Rx for Rx { core::mem::transmute(self) }; - let occupied = this.occupied.slice(); - let free = this.free.slice(); + let slices = this.slices.get_mut(); + + for (occupied, free) in this.channels.get_mut().iter_mut() { + if occupied.is_empty() || free.capacity() == 0 { + continue; + } + + slices.push((occupied.slice(), free.slice())); + } + let umem = &mut this.umem; let error_logger = &mut this.error_logger; let mut queue = Queue { - occupied, - free, + slices, umem, error_logger, }; @@ -111,8 +164,10 @@ impl rx::Rx for Rx { } pub struct Queue<'a> { - occupied: spsc::RecvSlice<'a, RxTxDescriptor>, - free: spsc::SendSlice<'a, UmemDescriptor>, + slices: &'a mut Vec<( + spsc::RecvSlice<'a, RxTxDescriptor>, + spsc::SendSlice<'a, UmemDescriptor>, + )>, umem: &'a mut Umem, error_logger: &'a mut Option>, } @@ -122,48 +177,58 @@ impl<'a> rx::Queue for Queue<'a> { #[inline] fn for_each, &mut [u8])>(&mut self, mut on_packet: F) { - // only pop as many items as we have capacity to free them - while self.free.capacity() > 0 { - let descriptor = match self.occupied.pop() { - Some(v) => v, - None => break, - }; - - let buffer = unsafe { - // Safety: this descriptor should be unique, assuming the tasks are functioning - // properly - self.umem.get_mut(descriptor) - }; - - // create a decoder from the descriptor's buffer - let decoder = DecoderBufferMut::new(buffer); - - // try to decode the packet and emit the result - match decoder::decode_packet(decoder) { - Ok(Some((header, payload))) => { - on_packet(header, payload.into_less_safe_slice()); - } - Ok(None) | Err(_) => { - // This shouldn't happen. If it does, the BPF program isn't properly validating - // packets before they get to userspace. - if let Some(error_logger) = self.error_logger.as_mut() { - error_logger.log_invalid_packet(buffer); + for (occupied, free) in self.slices.iter_mut() { + // only pop as many items as we have capacity to free them + while free.capacity() > 0 { + let descriptor = match occupied.pop() { + Some(v) => v, + None => break, + }; + + let buffer = unsafe { + // Safety: this descriptor should be unique, assuming the tasks are functioning + // properly + self.umem.get_mut(descriptor) + }; + + // create a decoder from the descriptor's buffer + let decoder = DecoderBufferMut::new(buffer); + + // try to decode the packet and emit the result + match decoder::decode_packet(decoder) { + Ok(Some((header, payload))) => { + on_packet(header, payload.into_less_safe_slice()); + } + Ok(None) | Err(_) => { + // This shouldn't happen. If it does, the BPF program isn't properly validating + // packets before they get to userspace. + if let Some(error_logger) = self.error_logger.as_mut() { + error_logger.log_invalid_packet(buffer); + } } } - } - // send the descriptor to the free queue - let result = self.free.push(descriptor.into()); + // send the descriptor to the free queue + let result = free.push(descriptor.into()); - debug_assert!( - result.is_ok(), - "free queue capacity should always exceed occupied" - ); + debug_assert!( + result.is_ok(), + "free queue capacity should always exceed occupied" + ); + } } } #[inline] fn is_empty(&self) -> bool { - self.occupied.is_empty() + self.slices.is_empty() + } +} + +impl<'a> Drop for Queue<'a> { + #[inline] + fn drop(&mut self) { + // make sure we drop all of the slices to flush our changes + self.slices.clear(); } } diff --git a/tools/xdp/s2n-quic-xdp/src/io/tests.rs b/tools/xdp/s2n-quic-xdp/src/io/tests.rs index 2480be431f..ed521f6ad6 100644 --- a/tools/xdp/s2n-quic-xdp/src/io/tests.rs +++ b/tools/xdp/s2n-quic-xdp/src/io/tests.rs @@ -28,24 +28,41 @@ use s2n_quic_core::{ /// Tests the s2n-quic-core IO trait implementations by sending packets over spsc channels #[tokio::test] async fn tx_rx_test() { + let frame_count = 16; let mut umem = Umem::builder(); - umem.frame_count = 16; + umem.frame_count = frame_count; umem.frame_size = 128; let umem = umem.build().unwrap(); // send a various amount of packets for each test for packets in [1, 100, 1000, 10_000] { - let (input, tx_input) = spsc::channel(16); - let (mut rx_free, tx_free) = spsc::channel(32); - let (tx_occupied, rx_occupied) = spsc::channel(16); - let (rx_output, output) = spsc::channel(32); + for input_counts in [1, 2] { + eprintln!("packets: {packets}, input_counts: {input_counts}"); - rx_free.slice().extend(&mut umem.frames()).unwrap(); + let mut rx_inputs = vec![]; + let mut tx_outputs = vec![]; - tokio::spawn(packet_gen(packets, input)); - tokio::spawn(send(tx_free, tx_occupied, umem.clone(), tx_input)); - tokio::spawn(recv(rx_occupied, rx_free, umem.clone(), rx_output)); - packet_checker(packets, output).await; + let mut frames = umem.frames(); + + for _ in 0..input_counts { + let (mut rx_free, tx_free) = spsc::channel(32); + let (tx_occupied, rx_occupied) = spsc::channel(16); + + let mut rx_frames = (&mut frames).take((frame_count / input_counts) as usize); + rx_free.slice().extend(&mut rx_frames).unwrap(); + + tx_outputs.push((tx_free, tx_occupied)); + rx_inputs.push((rx_occupied, rx_free)); + } + + let (input, tx_input) = spsc::channel(16); + let (rx_output, output) = spsc::channel(32); + + tokio::spawn(packet_gen(packets, input)); + tokio::spawn(send(tx_outputs, umem.clone(), tx_input)); + tokio::spawn(recv(rx_inputs, umem.clone(), rx_output)); + packet_checker(packets, output).await; + } } } @@ -111,13 +128,12 @@ async fn packet_gen(count: u32, mut output: spsc::Sender) { /// Sends packets over the TX queue from an input channel async fn send( - free: tx::Free, - occupied: tx::Occupied, + outputs: Vec<(tx::Free, tx::Occupied)>, umem: Umem, mut input: spsc::Receiver, ) { let state = Default::default(); - let mut tx = Tx::new(free, occupied, umem, state); + let mut tx = Tx::new(outputs, umem, state); loop { let res = select(input.acquire(), tx.ready()).await; @@ -149,27 +165,32 @@ async fn send( trace!("send finishing"); - let (mut free, occupied) = tx.consume(); + let channels = tx.consume(); - // notify the recv task that there aren't going to be any more packets sent - drop(occupied); + let free: Vec<_> = channels + .into_iter() + .map(|(mut free, occupied)| { + // notify the recv task that there aren't going to be any more packets sent + drop(occupied); - // drain the free queue so the `recv` task doesn't shut down prematurely - while free.acquire().await.is_ok() { - free.slice().clear(); - } + async move { + // drain the free queue so the `recv` task doesn't shut down prematurely + while free.acquire().await.is_ok() { + free.slice().clear(); + } + } + }) + .collect(); + + // wait until all of the futures finish + futures::future::join_all(free).await; trace!("shutting down send"); } /// Receives raw packets and converts them into [`Packet`]s, putting them on the `output` channel. -async fn recv( - occupied: rx::Occupied, - free: rx::Free, - umem: Umem, - mut output: spsc::Sender, -) { - let mut rx = Rx::new(occupied, free, umem); +async fn recv(inputs: Vec<(rx::Occupied, rx::Free)>, umem: Umem, mut output: spsc::Sender) { + let mut rx = Rx::new(inputs, umem); while rx.ready().await.is_ok() { trace!("recv ready"); @@ -198,24 +219,25 @@ async fn recv( /// Checks that the received [`Packet`]s match the expected values async fn packet_checker(total: u32, mut output: spsc::Receiver) { - let mut expected = 0; + let mut actual = s2n_quic_core::interval_set::IntervalSet::default(); + while output.acquire().await.is_ok() { let mut output = output.slice(); while let Some(packet) = output.pop() { trace!("output packet recv: {packet:?}"); - assert_eq!( - packet.counter, expected, - "packet counter should be sequential" - ); - expected += 1; + actual.insert_value(packet.counter).unwrap(); } // we want to consume the output queue as fast as possible so the `recv` task doesn't have // to block on the checker } - assert_eq!(total, expected, "total output packets does not match input"); + assert_eq!( + total as usize, + actual.count(), + "total output packets does not match input" + ); } /// Randomly yields to other tasks diff --git a/tools/xdp/s2n-quic-xdp/src/io/tx.rs b/tools/xdp/s2n-quic-xdp/src/io/tx.rs index e7d26bdc69..6827027351 100644 --- a/tools/xdp/s2n-quic-xdp/src/io/tx.rs +++ b/tools/xdp/s2n-quic-xdp/src/io/tx.rs @@ -5,7 +5,10 @@ use crate::{ if_xdp::{RxTxDescriptor, UmemDescriptor}, umem::Umem, }; -use core::task::{Context, Poll}; +use core::{ + cell::UnsafeCell, + task::{Context, Poll}, +}; use s2n_codec::{Encoder as _, EncoderBuffer}; use s2n_quic_core::{ event, @@ -18,8 +21,16 @@ pub type Free = spsc::Receiver; pub type Occupied = spsc::Sender; pub struct Tx { - free: Free, - occupied: Occupied, + channels: UnsafeCell>, + /// Store a vec of slices on the struct so we don't have to allocate every time `queue` is + /// called. Since this causes the type to be self-referential it does need a bit of unsafe code + /// to pull this off. + slices: UnsafeCell< + Vec<( + spsc::RecvSlice<'static, UmemDescriptor>, + spsc::SendSlice<'static, RxTxDescriptor>, + )>, + >, umem: Umem, encoder: encoder::State, is_full: bool, @@ -27,13 +38,15 @@ pub struct Tx { impl Tx { /// Creates a TX IO interface for an s2n-quic endpoint - pub fn new(free: Free, occupied: Occupied, umem: Umem, encoder: encoder::State) -> Self { + pub fn new(channels: Vec<(Free, Occupied)>, umem: Umem, encoder: encoder::State) -> Self { + let slices = UnsafeCell::new(Vec::with_capacity(channels.len())); + let channels = UnsafeCell::new(channels); Self { - occupied, - free, + channels, + slices, umem, encoder, - is_full: false, + is_full: true, } } @@ -41,8 +54,8 @@ impl Tx { /// /// This is used for internal tests only. #[cfg(test)] - pub fn consume(self) -> (Free, Occupied) { - (self.free, self.occupied) + pub fn consume(self) -> Vec<(Free, Occupied)> { + self.channels.into_inner() } } @@ -53,19 +66,51 @@ impl tx::Tx for Tx { #[inline] fn poll_ready(&mut self, cx: &mut Context) -> Poll> { - // poll both channels to make sure we can make progress in both - let free = self.free.poll_slice(cx); - let occupied = self.occupied.poll_slice(cx); - - ready!(free)?; - ready!(occupied)?; - - // we only need to wake up if the queue was previously completely filled up + // If we didn't fill up the queue then we don't need to poll for capacity if !self.is_full { return Poll::Pending; } - Poll::Ready(Ok(())) + // poll both channels to make sure we can make progress in both + let mut is_any_ready = false; + let mut is_all_free_closed = true; + let mut is_all_occupied_closed = true; + + for (free, occupied) in self.channels.get_mut() { + let mut is_ready = true; + + macro_rules! ready { + ($slice:ident, $closed:ident) => { + match $slice.poll_slice(cx) { + Poll::Ready(Ok(_)) => { + $closed = false; + } + Poll::Ready(Err(_)) => { + // defer returning an error until all slices return one + } + Poll::Pending => { + $closed = false; + is_ready = false + } + } + }; + } + + ready!(occupied, is_all_occupied_closed); + ready!(free, is_all_free_closed); + + is_any_ready |= is_ready; + } + + if is_all_free_closed || is_all_occupied_closed { + return Err(spsc::ClosedError).into(); + } + + if is_any_ready { + Poll::Ready(Ok(())) + } else { + Poll::Pending + } } #[inline] @@ -87,25 +132,37 @@ impl tx::Tx for Tx { core::mem::transmute(self) }; - let mut free = this.free.slice(); - let mut occupied = this.occupied.slice(); + let slices = this.slices.get_mut(); + + let mut capacity = 0; - // if we were full, then try to synchronize the peer's queues - if this.is_full { + for (free, occupied) in this.channels.get_mut().iter_mut() { + let mut free = free.slice(); + let mut occupied = occupied.slice(); + + // try to synchronize the peer's queues let _ = free.sync(); let _ = occupied.sync(); + + if free.is_empty() || occupied.capacity() == 0 { + continue; + } + + capacity += free.len().min(occupied.capacity()); + slices.push((free, occupied)); } // update our full status - this.is_full = free.is_empty() || occupied.capacity() == 0; + this.is_full = slices.is_empty(); let umem = &mut this.umem; let encoder = &mut this.encoder; let is_full = &mut this.is_full; let mut queue = Queue { - free, - occupied, + slices, + slice_index: 0, + capacity, umem, encoder, is_full, @@ -122,8 +179,12 @@ impl tx::Tx for Tx { } pub struct Queue<'a> { - free: spsc::RecvSlice<'a, UmemDescriptor>, - occupied: spsc::SendSlice<'a, RxTxDescriptor>, + slices: &'a mut Vec<( + spsc::RecvSlice<'a, UmemDescriptor>, + spsc::SendSlice<'a, RxTxDescriptor>, + )>, + slice_index: usize, + capacity: usize, umem: &'a mut Umem, encoder: &'a mut encoder::State, is_full: &'a mut bool, @@ -141,12 +202,17 @@ impl<'a> tx::Queue for Queue<'a> { M: tx::Message, { // if we're at capacity, then return an error - if *self.is_full { + if self.capacity == 0 { return Err(tx::Error::AtCapacity); } + let (free, occupied) = unsafe { + // Safety: the slice index should always be in bounds + self.slices.get_unchecked_mut(self.slice_index) + }; + // take the first free descriptor, we should have at least one item - let (head, _) = self.free.peek(); + let (head, _) = free.peek(); let descriptor = head[0]; let buffer = unsafe { @@ -166,7 +232,7 @@ impl<'a> tx::Queue for Queue<'a> { let descriptor = descriptor.with_len(len as _); // push the descriptor on so it can be transmitted - let result = self.occupied.push(descriptor); + let result = occupied.push(descriptor); debug_assert!( result.is_ok(), @@ -174,10 +240,16 @@ impl<'a> tx::Queue for Queue<'a> { ); // make sure we give capacity back to the free queue - self.free.release(1); + free.release(1); + + // if this slice is at capacity then increment the index and try the next one + if free.is_empty() || occupied.capacity() == 0 { + self.slice_index += 1; + } // check to see if we're full now - *self.is_full = !self.has_capacity(); + self.capacity -= 1; + *self.is_full = self.capacity == 0; // let the caller know how big the payload was let outcome = tx::Outcome { @@ -190,6 +262,14 @@ impl<'a> tx::Queue for Queue<'a> { #[inline] fn capacity(&self) -> usize { - self.free.len().min(self.occupied.capacity()) + self.capacity + } +} + +impl<'a> Drop for Queue<'a> { + #[inline] + fn drop(&mut self) { + // make sure we drop all of the slices to flush our changes + self.slices.clear(); } }