Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(s2n-quic-xdp): implement IO traits for vectors of channel pairs #1761

Merged
merged 1 commit into from
May 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 0 additions & 15 deletions tools/xdp/s2n-quic-xdp/src/io.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,6 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

// TODO replace with `core::task::ready` once we bump MSRV to 1.64.0
// https://doc.rust-lang.org/core/task/macro.ready.html
//
// See https://github.com/aws/s2n-quic/issues/1750
macro_rules! ready {
($value:expr) => {
match $value {
::core::task::Poll::Ready(v) => v,
::core::task::Poll::Pending => {
return ::core::task::Poll::Pending;
}
}
};
}

pub mod rx;
pub mod tx;

Expand Down
167 changes: 116 additions & 51 deletions tools/xdp/s2n-quic-xdp/src/io/rx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ use crate::{
if_xdp::{RxTxDescriptor, UmemDescriptor},
umem::Umem,
};
use core::task::{Context, Poll};
use core::{
cell::UnsafeCell,
task::{Context, Poll},
};
use s2n_codec::DecoderBufferMut;
use s2n_quic_core::{
event,
Expand All @@ -28,18 +31,28 @@ pub trait ErrorLogger: Send {
}

pub struct Rx {
occupied: Occupied,
free: Free,
channels: UnsafeCell<Vec<(Occupied, Free)>>,
/// Store a vec of slices on the struct so we don't have to allocate every time `queue` is
/// called. Since this causes the type to be self-referential it does need a bit of unsafe code
/// to pull this off.
slices: UnsafeCell<
Vec<(
spsc::RecvSlice<'static, RxTxDescriptor>,
spsc::SendSlice<'static, UmemDescriptor>,
)>,
>,
umem: Umem,
error_logger: Option<Box<dyn ErrorLogger>>,
}

impl Rx {
/// Creates a RX IO interface for an s2n-quic endpoint
pub fn new(occupied: Occupied, free: Free, umem: Umem) -> Self {
pub fn new(channels: Vec<(Occupied, Free)>, umem: Umem) -> Self {
let slices = UnsafeCell::new(Vec::with_capacity(channels.len()));
let channels = UnsafeCell::new(channels);
Self {
occupied,
free,
channels,
slices,
umem,
error_logger: None,
}
Expand All @@ -60,13 +73,46 @@ impl rx::Rx for Rx {
#[inline]
fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
// poll both channels to make sure we can make progress in both
let free = self.free.poll_slice(cx);
let occupied = self.occupied.poll_slice(cx);

ready!(free)?;
ready!(occupied)?;
let mut is_any_ready = false;
let mut is_all_occupied_closed = true;
let mut is_all_free_closed = true;

for (occupied, free) in self.channels.get_mut() {
let mut is_ready = true;

macro_rules! ready {
($slice:ident, $closed:ident) => {
match $slice.poll_slice(cx) {
Poll::Ready(Ok(_)) => {
$closed = false;
}
Poll::Ready(Err(_)) => {
// defer returning an error until all slices return one
}
Poll::Pending => {
$closed = false;
is_ready = false
}
}
};
}

ready!(occupied, is_all_occupied_closed);
ready!(free, is_all_free_closed);

is_any_ready |= is_ready;
}

Poll::Ready(Ok(()))
if is_all_occupied_closed || is_all_free_closed {
return Err(spsc::ClosedError).into();
}

if is_any_ready {
Poll::Ready(Ok(()))
} else {
Poll::Pending
}
}

#[inline]
Expand All @@ -88,14 +134,21 @@ impl rx::Rx for Rx {
core::mem::transmute(self)
};

let occupied = this.occupied.slice();
let free = this.free.slice();
let slices = this.slices.get_mut();

for (occupied, free) in this.channels.get_mut().iter_mut() {
if occupied.is_empty() || free.capacity() == 0 {
continue;
}

slices.push((occupied.slice(), free.slice()));
}

let umem = &mut this.umem;
let error_logger = &mut this.error_logger;

let mut queue = Queue {
occupied,
free,
slices,
umem,
error_logger,
};
Expand All @@ -111,8 +164,10 @@ impl rx::Rx for Rx {
}

pub struct Queue<'a> {
occupied: spsc::RecvSlice<'a, RxTxDescriptor>,
free: spsc::SendSlice<'a, UmemDescriptor>,
slices: &'a mut Vec<(
spsc::RecvSlice<'a, RxTxDescriptor>,
spsc::SendSlice<'a, UmemDescriptor>,
)>,
umem: &'a mut Umem,
error_logger: &'a mut Option<Box<dyn ErrorLogger>>,
}
Expand All @@ -122,48 +177,58 @@ impl<'a> rx::Queue for Queue<'a> {

#[inline]
fn for_each<F: FnMut(datagram::Header<Self::Handle>, &mut [u8])>(&mut self, mut on_packet: F) {
// only pop as many items as we have capacity to free them
while self.free.capacity() > 0 {
let descriptor = match self.occupied.pop() {
Some(v) => v,
None => break,
};

let buffer = unsafe {
// Safety: this descriptor should be unique, assuming the tasks are functioning
// properly
self.umem.get_mut(descriptor)
};

// create a decoder from the descriptor's buffer
let decoder = DecoderBufferMut::new(buffer);

// try to decode the packet and emit the result
match decoder::decode_packet(decoder) {
Ok(Some((header, payload))) => {
on_packet(header, payload.into_less_safe_slice());
}
Ok(None) | Err(_) => {
// This shouldn't happen. If it does, the BPF program isn't properly validating
// packets before they get to userspace.
if let Some(error_logger) = self.error_logger.as_mut() {
error_logger.log_invalid_packet(buffer);
for (occupied, free) in self.slices.iter_mut() {
// only pop as many items as we have capacity to free them
while free.capacity() > 0 {
let descriptor = match occupied.pop() {
Some(v) => v,
None => break,
};

let buffer = unsafe {
// Safety: this descriptor should be unique, assuming the tasks are functioning
// properly
self.umem.get_mut(descriptor)
};

// create a decoder from the descriptor's buffer
let decoder = DecoderBufferMut::new(buffer);

// try to decode the packet and emit the result
match decoder::decode_packet(decoder) {
Ok(Some((header, payload))) => {
on_packet(header, payload.into_less_safe_slice());
}
Ok(None) | Err(_) => {
// This shouldn't happen. If it does, the BPF program isn't properly validating
// packets before they get to userspace.
if let Some(error_logger) = self.error_logger.as_mut() {
error_logger.log_invalid_packet(buffer);
}
}
}
}

// send the descriptor to the free queue
let result = self.free.push(descriptor.into());
// send the descriptor to the free queue
let result = free.push(descriptor.into());

debug_assert!(
result.is_ok(),
"free queue capacity should always exceed occupied"
);
debug_assert!(
result.is_ok(),
"free queue capacity should always exceed occupied"
);
}
}
}

#[inline]
fn is_empty(&self) -> bool {
self.occupied.is_empty()
self.slices.is_empty()
}
}

impl<'a> Drop for Queue<'a> {
#[inline]
fn drop(&mut self) {
// make sure we drop all of the slices to flush our changes
self.slices.clear();
}
}
90 changes: 56 additions & 34 deletions tools/xdp/s2n-quic-xdp/src/io/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,24 +28,41 @@ use s2n_quic_core::{
/// Tests the s2n-quic-core IO trait implementations by sending packets over spsc channels
#[tokio::test]
async fn tx_rx_test() {
let frame_count = 16;
let mut umem = Umem::builder();
umem.frame_count = 16;
umem.frame_count = frame_count;
umem.frame_size = 128;
let umem = umem.build().unwrap();

// send a various amount of packets for each test
for packets in [1, 100, 1000, 10_000] {
let (input, tx_input) = spsc::channel(16);
let (mut rx_free, tx_free) = spsc::channel(32);
let (tx_occupied, rx_occupied) = spsc::channel(16);
let (rx_output, output) = spsc::channel(32);
for input_counts in [1, 2] {
eprintln!("packets: {packets}, input_counts: {input_counts}");

rx_free.slice().extend(&mut umem.frames()).unwrap();
let mut rx_inputs = vec![];
let mut tx_outputs = vec![];

tokio::spawn(packet_gen(packets, input));
tokio::spawn(send(tx_free, tx_occupied, umem.clone(), tx_input));
tokio::spawn(recv(rx_occupied, rx_free, umem.clone(), rx_output));
packet_checker(packets, output).await;
let mut frames = umem.frames();

for _ in 0..input_counts {
let (mut rx_free, tx_free) = spsc::channel(32);
let (tx_occupied, rx_occupied) = spsc::channel(16);

let mut rx_frames = (&mut frames).take((frame_count / input_counts) as usize);
rx_free.slice().extend(&mut rx_frames).unwrap();

tx_outputs.push((tx_free, tx_occupied));
rx_inputs.push((rx_occupied, rx_free));
}

let (input, tx_input) = spsc::channel(16);
let (rx_output, output) = spsc::channel(32);

tokio::spawn(packet_gen(packets, input));
tokio::spawn(send(tx_outputs, umem.clone(), tx_input));
tokio::spawn(recv(rx_inputs, umem.clone(), rx_output));
packet_checker(packets, output).await;
}
}
}

Expand Down Expand Up @@ -111,13 +128,12 @@ async fn packet_gen(count: u32, mut output: spsc::Sender<Packet>) {

/// Sends packets over the TX queue from an input channel
async fn send(
free: tx::Free,
occupied: tx::Occupied,
outputs: Vec<(tx::Free, tx::Occupied)>,
umem: Umem,
mut input: spsc::Receiver<Packet>,
) {
let state = Default::default();
let mut tx = Tx::new(free, occupied, umem, state);
let mut tx = Tx::new(outputs, umem, state);

loop {
let res = select(input.acquire(), tx.ready()).await;
Expand Down Expand Up @@ -149,27 +165,32 @@ async fn send(

trace!("send finishing");

let (mut free, occupied) = tx.consume();
let channels = tx.consume();

// notify the recv task that there aren't going to be any more packets sent
drop(occupied);
let free: Vec<_> = channels
.into_iter()
.map(|(mut free, occupied)| {
// notify the recv task that there aren't going to be any more packets sent
drop(occupied);

// drain the free queue so the `recv` task doesn't shut down prematurely
while free.acquire().await.is_ok() {
free.slice().clear();
}
async move {
// drain the free queue so the `recv` task doesn't shut down prematurely
while free.acquire().await.is_ok() {
free.slice().clear();
}
}
})
.collect();

// wait until all of the futures finish
futures::future::join_all(free).await;

trace!("shutting down send");
}

/// Receives raw packets and converts them into [`Packet`]s, putting them on the `output` channel.
async fn recv(
occupied: rx::Occupied,
free: rx::Free,
umem: Umem,
mut output: spsc::Sender<Packet>,
) {
let mut rx = Rx::new(occupied, free, umem);
async fn recv(inputs: Vec<(rx::Occupied, rx::Free)>, umem: Umem, mut output: spsc::Sender<Packet>) {
let mut rx = Rx::new(inputs, umem);

while rx.ready().await.is_ok() {
trace!("recv ready");
Expand Down Expand Up @@ -198,24 +219,25 @@ async fn recv(

/// Checks that the received [`Packet`]s match the expected values
async fn packet_checker(total: u32, mut output: spsc::Receiver<Packet>) {
let mut expected = 0;
let mut actual = s2n_quic_core::interval_set::IntervalSet::default();

while output.acquire().await.is_ok() {
let mut output = output.slice();
while let Some(packet) = output.pop() {
trace!("output packet recv: {packet:?}");

assert_eq!(
packet.counter, expected,
"packet counter should be sequential"
);
expected += 1;
actual.insert_value(packet.counter).unwrap();
}

// we want to consume the output queue as fast as possible so the `recv` task doesn't have
// to block on the checker
}

assert_eq!(total, expected, "total output packets does not match input");
assert_eq!(
total as usize,
actual.count(),
"total output packets does not match input"
);
}

/// Randomly yields to other tasks
Expand Down
Loading