From fd84effdc0008b98090642de96977f9bce952eba Mon Sep 17 00:00:00 2001 From: Cameron Bytheway Date: Thu, 25 May 2023 18:14:46 -0600 Subject: [PATCH] refactor(s2n-quic-platform): simplify Message trait --- quic/s2n-quic-core/src/inet/datagram.rs | 2 + quic/s2n-quic-platform/src/message.rs | 77 +++-- quic/s2n-quic-platform/src/message/macros.rs | 89 ------ quic/s2n-quic-platform/src/message/mmsg.rs | 70 +--- quic/s2n-quic-platform/src/message/msg.rs | 302 +++--------------- quic/s2n-quic-platform/src/message/msg/ext.rs | 92 ++++++ .../src/message/msg/handle.rs | 113 +++++++ quic/s2n-quic-platform/src/message/queue.rs | 59 ++-- .../src/message/queue/slice.rs | 4 +- quic/s2n-quic-platform/src/message/simple.rs | 43 ++- 10 files changed, 345 insertions(+), 506 deletions(-) delete mode 100644 quic/s2n-quic-platform/src/message/macros.rs create mode 100644 quic/s2n-quic-platform/src/message/msg/ext.rs create mode 100644 quic/s2n-quic-platform/src/message/msg/handle.rs diff --git a/quic/s2n-quic-core/src/inet/datagram.rs b/quic/s2n-quic-core/src/inet/datagram.rs index d93d343cfa..d272412a30 100644 --- a/quic/s2n-quic-core/src/inet/datagram.rs +++ b/quic/s2n-quic-core/src/inet/datagram.rs @@ -32,4 +32,6 @@ pub struct AncillaryData { /// Correctly threading this value through to connections ensures packets end up on the same /// network interfaces and thereby have consistent MAC addresses. pub local_interface: Option, + /// Set when the packet buffer is an aggregate of multiple received packets + pub segment_size: u16, } diff --git a/quic/s2n-quic-platform/src/message.rs b/quic/s2n-quic-platform/src/message.rs index 1ef011d4a2..373a5201fe 100644 --- a/quic/s2n-quic-platform/src/message.rs +++ b/quic/s2n-quic-platform/src/message.rs @@ -1,27 +1,29 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -#[macro_use] -mod macros; +use core::ffi::c_void; +use s2n_quic_core::{inet::datagram, io::tx, path}; +#[cfg(any(s2n_quic_platform_socket_msg, s2n_quic_platform_socket_mmsg))] +pub mod cmsg; #[cfg(s2n_quic_platform_socket_mmsg)] pub mod mmsg; - #[cfg(s2n_quic_platform_socket_msg)] pub mod msg; - -#[cfg(any(s2n_quic_platform_socket_msg, s2n_quic_platform_socket_mmsg))] -pub mod cmsg; - pub mod queue; pub mod simple; -use core::ffi::c_void; -use s2n_quic_core::{ - inet::{datagram, ExplicitCongestionNotification, SocketAddress}, - io::tx, - path, -}; +pub mod default { + cfg_if::cfg_if! { + if #[cfg(s2n_quic_platform_socket_mmsg)] { + pub use super::mmsg::*; + } else if #[cfg(s2n_quic_platform_socket_msg)] { + pub use super::msg::*; + } else { + pub use super::simple::*; + } + } +} /// An abstract message that can be sent and received on a network pub trait Message { @@ -29,21 +31,6 @@ pub trait Message { const SUPPORTS_GSO: bool; - /// Returns the ECN values for the message - fn ecn(&self) -> ExplicitCongestionNotification; - - /// Sets the ECN values for the message - fn set_ecn(&mut self, ecn: ExplicitCongestionNotification, remote_address: &SocketAddress); - - /// Returns the `SocketAddress` for the message - fn remote_address(&self) -> Option; - - /// Sets the `SocketAddress` for the message - fn set_remote_address(&mut self, remote_address: &SocketAddress); - - /// Returns the path handle for the message - fn path_handle(&self) -> Option; - /// Returns the length of the payload fn payload_len(&self) -> usize; @@ -63,18 +50,11 @@ pub trait Message { /// This should used in scenarios where the data pointers are the same. fn replicate_fields_from(&mut self, other: &Self); - /// Returns a pointer for the message payload - fn payload_ptr(&self) -> *const u8; - /// Returns a mutable pointer for the message payload fn payload_ptr_mut(&mut self) -> *mut u8; - /// Returns a slice for the message payload - fn payload(&self) -> &[u8] { - unsafe { core::slice::from_raw_parts(self.payload_ptr(), self.payload_len()) } - } - /// Returns a mutable slice for the message payload + #[inline] fn payload_mut(&mut self) -> &mut [u8] { unsafe { core::slice::from_raw_parts_mut(self.payload_ptr_mut(), self.payload_len()) } } @@ -101,10 +81,7 @@ pub trait Message { } /// Reads the message as an RX packet - fn rx_read( - &mut self, - local_address: &path::LocalAddress, - ) -> Option<(datagram::Header, &mut [u8])>; + fn rx_read(&mut self, local_address: &path::LocalAddress) -> Option>; /// Writes the message into the TX packet fn tx_write>( @@ -113,6 +90,26 @@ pub trait Message { ) -> Result; } +pub struct RxMessage<'a, Handle: Copy> { + /// The received header for the message + pub header: datagram::Header, + /// The number of segments inside the message + pub segment_size: usize, + /// The full payload of the message + pub payload: &'a mut [u8], +} + +impl<'a, Handle: Copy> RxMessage<'a, Handle> { + #[inline] + pub fn for_each, &mut [u8])>(self, mut on_packet: F) { + debug_assert_ne!(self.segment_size, 0); + + for segment in self.payload.chunks_mut(self.segment_size) { + on_packet(self.header, segment); + } + } +} + /// A message ring used to back a queue pub trait Ring { /// The type of message that is stored in the ring diff --git a/quic/s2n-quic-platform/src/message/macros.rs b/quic/s2n-quic-platform/src/message/macros.rs deleted file mode 100644 index bde936c16f..0000000000 --- a/quic/s2n-quic-platform/src/message/macros.rs +++ /dev/null @@ -1,89 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -#![allow(unused_macros)] - -macro_rules! impl_message_delegate { - ($name:ident, $field:tt, $field_ty:ty) => { - impl $crate::message::Message for $name { - type Handle = <$field_ty as $crate::message::Message>::Handle; - - const SUPPORTS_GSO: bool = <$field_ty as $crate::message::Message>::SUPPORTS_GSO; - - fn ecn(&self) -> ExplicitCongestionNotification { - $crate::message::Message::ecn(&self.$field) - } - - fn set_ecn( - &mut self, - ecn: ExplicitCongestionNotification, - remote_address: &SocketAddress, - ) { - $crate::message::Message::set_ecn(&mut self.$field, ecn, remote_address) - } - - fn remote_address(&self) -> Option { - $crate::message::Message::remote_address(&self.$field) - } - - fn set_remote_address(&mut self, remote_address: &SocketAddress) { - $crate::message::Message::set_remote_address(&mut self.$field, remote_address) - } - - fn path_handle(&self) -> Option { - $crate::message::Message::path_handle(&self.$field) - } - - fn payload_len(&self) -> usize { - $crate::message::Message::payload_len(&self.$field) - } - - unsafe fn set_payload_len(&mut self, payload_len: usize) { - $crate::message::Message::set_payload_len(&mut self.$field, payload_len) - } - - fn can_gso>(&self, other: &mut M) -> bool { - $crate::message::Message::can_gso(&self.$field, other) - } - - fn set_segment_size(&mut self, size: usize) { - $crate::message::Message::set_segment_size(&mut self.$field, size) - } - - unsafe fn reset(&mut self, mtu: usize) { - $crate::message::Message::reset(&mut self.$field, mtu) - } - - fn replicate_fields_from(&mut self, other: &Self) { - $crate::message::Message::replicate_fields_from(&mut self.$field, &other.$field) - } - - fn payload_ptr(&self) -> *const u8 { - $crate::message::Message::payload_ptr(&self.$field) - } - - fn payload_ptr_mut(&mut self) -> *mut u8 { - $crate::message::Message::payload_ptr_mut(&mut self.$field) - } - - #[inline] - fn rx_read( - &mut self, - local_address: &s2n_quic_core::path::LocalAddress, - ) -> Option<( - s2n_quic_core::inet::datagram::Header, - &mut [u8], - )> { - $crate::message::Message::rx_read(&mut self.$field, local_address) - } - - #[inline] - fn tx_write>( - &mut self, - message: M, - ) -> Result { - $crate::message::Message::tx_write(&mut self.$field, message) - } - } - }; -} diff --git a/quic/s2n-quic-platform/src/message/mmsg.rs b/quic/s2n-quic-platform/src/message/mmsg.rs index 5767f65a73..be4dcc5b34 100644 --- a/quic/s2n-quic-platform/src/message/mmsg.rs +++ b/quic/s2n-quic-platform/src/message/mmsg.rs @@ -6,71 +6,18 @@ use crate::message::{ Message as MessageTrait, }; use alloc::vec::Vec; -use core::{fmt, mem::zeroed}; +use core::mem::zeroed; use libc::mmsghdr; -use s2n_quic_core::{ - inet::{datagram, ExplicitCongestionNotification, SocketAddress}, - io::tx, - path, -}; - -#[repr(transparent)] -pub struct Message(pub(crate) mmsghdr); +use s2n_quic_core::{io::tx, path}; +pub use libc::mmsghdr as Message; pub type Handle = msg::Handle; -impl_message_delegate!(Message, 0, mmsghdr); - -impl fmt::Debug for Message { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let alt = f.alternate(); - let mut s = f.debug_struct("mmsghdr"); - - s.field("remote_address", &self.remote_address()).field( - "ancillary_data", - &crate::message::cmsg::decode(&self.0.msg_hdr), - ); - - if alt { - s.field("payload", &self.payload()); - } else { - s.field("payload_len", &self.payload_len()); - } - - s.finish() - } -} - impl MessageTrait for mmsghdr { type Handle = Handle; const SUPPORTS_GSO: bool = libc::msghdr::SUPPORTS_GSO; - #[inline] - fn ecn(&self) -> ExplicitCongestionNotification { - self.msg_hdr.ecn() - } - - #[inline] - fn set_ecn(&mut self, ecn: ExplicitCongestionNotification, remote_address: &SocketAddress) { - self.msg_hdr.set_ecn(ecn, remote_address) - } - - #[inline] - fn remote_address(&self) -> Option { - self.msg_hdr.remote_address() - } - - #[inline] - fn set_remote_address(&mut self, remote_address: &SocketAddress) { - self.msg_hdr.set_remote_address(remote_address) - } - - #[inline] - fn path_handle(&self) -> Option { - self.msg_hdr.path_handle() - } - #[inline] fn payload_len(&self) -> usize { self.msg_len as usize @@ -99,11 +46,6 @@ impl MessageTrait for mmsghdr { self.msg_hdr.reset(mtu) } - #[inline] - fn payload_ptr(&self) -> *const u8 { - self.msg_hdr.payload_ptr() - } - #[inline] fn payload_ptr_mut(&mut self) -> *mut u8 { self.msg_hdr.payload_ptr_mut() @@ -119,7 +61,7 @@ impl MessageTrait for mmsghdr { fn rx_read( &mut self, local_address: &path::LocalAddress, - ) -> Option<(datagram::Header, &mut [u8])> { + ) -> Option> { unsafe { // We need to replicate the `msg_len` field to the inner type before delegating // Safety: The `msg_len` is associated with the same buffer as the `msg_hdr` @@ -172,9 +114,9 @@ impl Ring { .map(|msg_hdr| unsafe { let mut mmsghdr = zeroed::(); let payload_len = msg_hdr.payload_len(); - mmsghdr.msg_hdr = msg_hdr.0; + mmsghdr.msg_hdr = msg_hdr; mmsghdr.set_payload_len(payload_len); - Message(mmsghdr) + mmsghdr }) .collect(); diff --git a/quic/s2n-quic-platform/src/message/msg.rs b/quic/s2n-quic-platform/src/message/msg.rs index 491ad7b1bc..a8de12564e 100644 --- a/quic/s2n-quic-platform/src/message/msg.rs +++ b/quic/s2n-quic-platform/src/message/msg.rs @@ -4,180 +4,49 @@ use crate::message::{cmsg, cmsg::Encoder, Message as MessageTrait}; use alloc::vec::Vec; use core::{ - fmt, mem::{size_of, zeroed}, pin::Pin, }; use libc::{c_void, iovec, msghdr, sockaddr_in, sockaddr_in6, AF_INET, AF_INET6}; use s2n_quic_core::{ inet::{ - datagram, AncillaryData, ExplicitCongestionNotification, IpV4Address, IpV6Address, - SocketAddress, SocketAddressV4, SocketAddressV6, + datagram, ExplicitCongestionNotification, IpV4Address, IpV6Address, SocketAddress, + SocketAddressV4, SocketAddressV6, }, io::tx, - path::{self, Handle as _, LocalAddress, RemoteAddress}, + path::{self, Handle as _}, }; -#[cfg(any(test, feature = "generator"))] -use bolero_generator::*; - -#[repr(transparent)] -pub struct Message(pub(crate) msghdr); - -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -#[cfg_attr(any(test, feature = "generator"), derive(TypeGenerator))] -pub struct Handle { - pub remote_address: RemoteAddress, - pub local_address: LocalAddress, -} - -impl Handle { - #[inline] - fn with_ancillary_data(&mut self, ancillary_data: AncillaryData) { - self.local_address = ancillary_data.local_address; - } - - #[inline] - pub(crate) fn update_msg_hdr(self, msghdr: &mut msghdr) { - // when sending a packet, we start out with no cmsg items - msghdr.msg_controllen = 0; - - msghdr.set_remote_address(&self.remote_address.0); - - #[cfg(s2n_quic_platform_pktinfo)] - match self.local_address.0 { - SocketAddress::IpV4(addr) => { - use s2n_quic_core::inet::Unspecified; - - let ip = addr.ip(); - - if ip.is_unspecified() { - return; - } - - let mut pkt_info = unsafe { core::mem::zeroed::() }; - pkt_info.ipi_spec_dst.s_addr = u32::from_ne_bytes((*ip).into()); - - msghdr.encode_cmsg(libc::IPPROTO_IP, libc::IP_PKTINFO, pkt_info); - } - SocketAddress::IpV6(addr) => { - use s2n_quic_core::inet::Unspecified; - - let ip = addr.ip(); - - if ip.is_unspecified() { - return; - } - - let mut pkt_info = unsafe { core::mem::zeroed::() }; - - pkt_info.ipi6_addr.s6_addr = (*ip).into(); - - msghdr.encode_cmsg(libc::IPPROTO_IPV6, libc::IPV6_PKTINFO, pkt_info); - } - } - } -} - -impl path::Handle for Handle { - #[inline] - fn from_remote_address(remote_address: RemoteAddress) -> Self { - Self { - remote_address, - local_address: SocketAddressV4::UNSPECIFIED.into(), - } - } - - #[inline] - fn remote_address(&self) -> RemoteAddress { - self.remote_address - } - - #[inline] - fn local_address(&self) -> LocalAddress { - self.local_address - } - - #[inline] - fn eq(&self, other: &Self) -> bool { - let mut eq = true; - - // only compare local addresses if the OS returns them - if cfg!(s2n_quic_platform_pktinfo) { - eq &= self.local_address.eq(&other.local_address); - } - - eq && path::Handle::eq(&self.remote_address, &other.remote_address) - } - - #[inline] - fn strict_eq(&self, other: &Self) -> bool { - PartialEq::eq(self, other) - } - - #[inline] - fn maybe_update(&mut self, other: &Self) { - // once we discover our path, update the address local address - if self.local_address.port() == 0 { - self.local_address = other.local_address; - } - } -} +mod ext; +mod handle; -impl_message_delegate!(Message, 0, msghdr); +use ext::Ext as _; -impl fmt::Debug for Message { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let alt = f.alternate(); - let mut s = f.debug_struct("msghdr"); +pub use handle::Handle; +pub use libc::msghdr as Message; - s.field("remote_address", &self.remote_address()) - .field("anciliary_data", &cmsg::decode(&self.0)); - - if alt { - s.field("payload", &self.payload()); - } else { - s.field("payload_len", &self.payload_len()); - } - - s.finish() - } -} - -impl Message { - fn new( - iovec: *mut iovec, - msg_name: *mut c_void, - msg_namelen: usize, - msg_control: *mut c_void, - msg_controllen: usize, - ) -> Self { - let mut msghdr = unsafe { core::mem::zeroed::() }; - - msghdr.msg_iov = iovec; - msghdr.msg_iovlen = 1; // a single iovec is allocated per message - - msghdr.msg_name = msg_name; - msghdr.msg_namelen = msg_namelen as _; - - msghdr.msg_control = msg_control; - msghdr.msg_controllen = msg_controllen as _; +#[cfg(any(test, feature = "generator"))] +use bolero_generator::*; - Self(msghdr) - } +fn new( + iovec: *mut iovec, + msg_name: *mut c_void, + msg_namelen: usize, + msg_control: *mut c_void, + msg_controllen: usize, +) -> Message { + let mut msghdr = unsafe { core::mem::zeroed::() }; - #[inline] - pub(crate) fn header(msghdr: &msghdr) -> Option> { - let addr = msghdr.remote_address()?; - let mut path = Handle::from_remote_address(addr.into()); + msghdr.msg_iov = iovec; + msghdr.msg_iovlen = 1; // a single iovec is allocated per message - let ancillary_data = cmsg::decode(msghdr); - let ecn = ancillary_data.ecn; + msghdr.msg_name = msg_name; + msghdr.msg_namelen = msg_namelen as _; - path.with_ancillary_data(ancillary_data); + msghdr.msg_control = msg_control; + msghdr.msg_controllen = msg_controllen as _; - Some(datagram::Header { path, ecn }) - } + msghdr } impl MessageTrait for msghdr { @@ -185,82 +54,6 @@ impl MessageTrait for msghdr { const SUPPORTS_GSO: bool = cfg!(s2n_quic_platform_gso); - #[inline] - fn ecn(&self) -> ExplicitCongestionNotification { - let ancillary_data = cmsg::decode(self); - ancillary_data.ecn - } - - #[inline] - fn set_ecn(&mut self, ecn: ExplicitCongestionNotification, remote_address: &SocketAddress) { - if ecn == ExplicitCongestionNotification::NotEct { - return; - } - - let ecn = ecn as libc::c_int; - - // the remote address needs to be unmapped in order to set the appropriate cmsg - match remote_address.unmap() { - SocketAddress::IpV4(_) => { - // FreeBSD uses an unsigned_char for IP_TOS - // see https://svnweb.freebsd.org/base/stable/8/sys/netinet/ip_input.c?view=markup&pathrev=247944#l1716 - #[cfg(target_os = "freebsd")] - let ecn = ecn as libc::c_uchar; - - self.encode_cmsg(libc::IPPROTO_IP, libc::IP_TOS, ecn) - } - SocketAddress::IpV6(_) => self.encode_cmsg(libc::IPPROTO_IPV6, libc::IPV6_TCLASS, ecn), - }; - } - - #[inline] - fn remote_address(&self) -> Option { - debug_assert!(!self.msg_name.is_null()); - match self.msg_namelen as usize { - size if size == size_of::() => { - let sockaddr: &sockaddr_in = unsafe { &*(self.msg_name as *const _) }; - let port = sockaddr.sin_port.to_be(); - let addr: IpV4Address = sockaddr.sin_addr.s_addr.to_ne_bytes().into(); - Some(SocketAddressV4::new(addr, port).into()) - } - size if size == size_of::() => { - let sockaddr: &sockaddr_in6 = unsafe { &*(self.msg_name as *const _) }; - let port = sockaddr.sin6_port.to_be(); - let addr: IpV6Address = sockaddr.sin6_addr.s6_addr.into(); - Some(SocketAddressV6::new(addr, port).into()) - } - _ => None, - } - } - - #[inline] - fn set_remote_address(&mut self, remote_address: &SocketAddress) { - debug_assert!(!self.msg_name.is_null()); - - match remote_address { - SocketAddress::IpV4(addr) => { - let sockaddr: &mut sockaddr_in = unsafe { &mut *(self.msg_name as *mut _) }; - sockaddr.sin_family = AF_INET as _; - sockaddr.sin_port = addr.port().to_be(); - sockaddr.sin_addr.s_addr = u32::from_ne_bytes((*addr.ip()).into()); - self.msg_namelen = size_of::() as _; - } - SocketAddress::IpV6(addr) => { - let sockaddr: &mut sockaddr_in6 = unsafe { &mut *(self.msg_name as *mut _) }; - sockaddr.sin6_family = AF_INET6 as _; - sockaddr.sin6_port = addr.port().to_be(); - sockaddr.sin6_addr.s6_addr = (*addr.ip()).into(); - self.msg_namelen = size_of::() as _; - } - } - } - - #[inline] - fn path_handle(&self) -> Option { - let header = Message::header(self)?; - Some(header.path) - } - #[inline] fn payload_len(&self) -> usize { debug_assert!(!self.msg_iov.is_null()); @@ -275,7 +68,7 @@ impl MessageTrait for msghdr { #[inline] fn can_gso>(&self, other: &mut M) -> bool { - if let Some(header) = Message::header(self) { + if let Some((header, _cmsg)) = self.header() { let mut other_handle = *other.path_handle(); // when reading the header back from the msghdr, we don't know the port @@ -348,14 +141,6 @@ impl MessageTrait for msghdr { self.msg_controllen = other.msg_controllen; } - #[inline] - fn payload_ptr(&self) -> *const u8 { - unsafe { - let iovec = &*self.msg_iov; - iovec.iov_base as *const _ - } - } - #[inline] fn payload_ptr_mut(&mut self) -> *mut u8 { unsafe { @@ -368,8 +153,8 @@ impl MessageTrait for msghdr { fn rx_read( &mut self, local_address: &path::LocalAddress, - ) -> Option<(datagram::Header, &mut [u8])> { - let mut header = Message::header(self)?; + ) -> Option> { + let (mut header, cmsg) = self.header()?; // only copy the port if we are told the IP address if cfg!(s2n_quic_platform_pktinfo) { @@ -379,7 +164,20 @@ impl MessageTrait for msghdr { } let payload = self.payload_mut(); - Some((header, payload)) + + let segment_size = if cmsg.segment_size == 0 { + payload.len() + } else { + cmsg.segment_size as _ + }; + + let message = crate::message::RxMessage { + header, + segment_size, + payload, + }; + + Some(message) } #[inline] @@ -406,7 +204,7 @@ impl MessageTrait for msghdr { } pub struct Ring { - pub(crate) messages: Vec, + pub(crate) messages: Vec, pub(crate) storage: Storage, } @@ -496,7 +294,7 @@ impl Ring { iovec.iov_len = mtu; iovecs[index] = iovec; - let msg = Message::new( + let msg = new( (&mut iovecs[index]) as *mut _, (&mut msg_names[index]) as *mut _ as *mut _, size_of::(), @@ -508,7 +306,7 @@ impl Ring { } for index in 0..capacity { - messages.push(Message(messages[index].0)); + messages.push(messages[index]); } Self { @@ -579,7 +377,7 @@ mod tests { let mut iovec = unsafe { zeroed::() }; msghdr.msg_iov = &mut iovec; - let mut message = Message(msghdr); + let mut message = msghdr; check!() .with_type::() @@ -621,15 +419,15 @@ mod tests { msghdr.msg_controllen = cmsg_buf.len() as _; msghdr.msg_control = (&mut cmsg_buf[0]) as *mut u8 as _; - let mut message = Message(msghdr); + let mut message = msghdr; - handle.update_msg_hdr(&mut message.0); + handle.update_msg_hdr(&mut message); if segment_size > 1 { message.set_segment_size(segment_size); } - let header = Message::header(&message.0).unwrap(); + let (header, _cmsg) = message.header().unwrap(); assert_eq!(header.path.remote_address, handle.remote_address); @@ -642,7 +440,7 @@ mod tests { message.reset(0); } - let header = Message::header(&msghdr).unwrap(); + let (header, _cmsg) = msghdr.header().unwrap(); assert!(header.path.remote_address.is_unspecified()); }); } diff --git a/quic/s2n-quic-platform/src/message/msg/ext.rs b/quic/s2n-quic-platform/src/message/msg/ext.rs new file mode 100644 index 0000000000..518550ff99 --- /dev/null +++ b/quic/s2n-quic-platform/src/message/msg/ext.rs @@ -0,0 +1,92 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use super::*; + +pub trait Ext: cmsg::Encoder { + fn header(&self) -> Option<(datagram::Header, datagram::AncillaryData)>; + fn set_ecn(&mut self, ecn: ExplicitCongestionNotification, remote_address: &SocketAddress); + fn remote_address(&self) -> Option; + fn set_remote_address(&mut self, remote_address: &SocketAddress); +} + +impl Ext for msghdr { + #[inline] + fn header(&self) -> Option<(datagram::Header, datagram::AncillaryData)> { + let addr = self.remote_address()?; + let mut path = Handle::from_remote_address(addr.into()); + + let ancillary_data = cmsg::decode(self); + let ecn = ancillary_data.ecn; + + path.with_ancillary_data(ancillary_data); + + let header = datagram::Header { path, ecn }; + + Some((header, ancillary_data)) + } + + #[inline] + fn set_ecn(&mut self, ecn: ExplicitCongestionNotification, remote_address: &SocketAddress) { + if ecn == ExplicitCongestionNotification::NotEct { + return; + } + + let ecn = ecn as libc::c_int; + + // the remote address needs to be unmapped in order to set the appropriate cmsg + match remote_address.unmap() { + SocketAddress::IpV4(_) => { + // FreeBSD uses an unsigned_char for IP_TOS + // see https://svnweb.freebsd.org/base/stable/8/sys/netinet/ip_input.c?view=markup&pathrev=247944#l1716 + #[cfg(target_os = "freebsd")] + let ecn = ecn as libc::c_uchar; + + self.encode_cmsg(libc::IPPROTO_IP, libc::IP_TOS, ecn) + } + SocketAddress::IpV6(_) => self.encode_cmsg(libc::IPPROTO_IPV6, libc::IPV6_TCLASS, ecn), + }; + } + + #[inline] + fn remote_address(&self) -> Option { + debug_assert!(!self.msg_name.is_null()); + match self.msg_namelen as usize { + size if size == size_of::() => { + let sockaddr: &sockaddr_in = unsafe { &*(self.msg_name as *const _) }; + let port = sockaddr.sin_port.to_be(); + let addr: IpV4Address = sockaddr.sin_addr.s_addr.to_ne_bytes().into(); + Some(SocketAddressV4::new(addr, port).into()) + } + size if size == size_of::() => { + let sockaddr: &sockaddr_in6 = unsafe { &*(self.msg_name as *const _) }; + let port = sockaddr.sin6_port.to_be(); + let addr: IpV6Address = sockaddr.sin6_addr.s6_addr.into(); + Some(SocketAddressV6::new(addr, port).into()) + } + _ => None, + } + } + + #[inline] + fn set_remote_address(&mut self, remote_address: &SocketAddress) { + debug_assert!(!self.msg_name.is_null()); + + match remote_address { + SocketAddress::IpV4(addr) => { + let sockaddr: &mut sockaddr_in = unsafe { &mut *(self.msg_name as *mut _) }; + sockaddr.sin_family = AF_INET as _; + sockaddr.sin_port = addr.port().to_be(); + sockaddr.sin_addr.s_addr = u32::from_ne_bytes((*addr.ip()).into()); + self.msg_namelen = size_of::() as _; + } + SocketAddress::IpV6(addr) => { + let sockaddr: &mut sockaddr_in6 = unsafe { &mut *(self.msg_name as *mut _) }; + sockaddr.sin6_family = AF_INET6 as _; + sockaddr.sin6_port = addr.port().to_be(); + sockaddr.sin6_addr.s6_addr = (*addr.ip()).into(); + self.msg_namelen = size_of::() as _; + } + } + } +} diff --git a/quic/s2n-quic-platform/src/message/msg/handle.rs b/quic/s2n-quic-platform/src/message/msg/handle.rs new file mode 100644 index 0000000000..d7cec5698b --- /dev/null +++ b/quic/s2n-quic-platform/src/message/msg/handle.rs @@ -0,0 +1,113 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use super::ext::Ext as _; +use crate::message::cmsg::Encoder; +use libc::msghdr; +use s2n_quic_core::{ + inet::{AncillaryData, SocketAddress, SocketAddressV4}, + path::{self, LocalAddress, RemoteAddress}, +}; + +#[cfg(any(test, feature = "generator"))] +use bolero_generator::*; + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[cfg_attr(any(test, feature = "generator"), derive(TypeGenerator))] +pub struct Handle { + pub remote_address: RemoteAddress, + pub local_address: LocalAddress, +} + +impl Handle { + #[inline] + pub(super) fn with_ancillary_data(&mut self, ancillary_data: AncillaryData) { + self.local_address = ancillary_data.local_address; + } + + #[inline] + pub(super) fn update_msg_hdr(&self, msghdr: &mut msghdr) { + // when sending a packet, we start out with no cmsg items + msghdr.msg_controllen = 0; + + msghdr.set_remote_address(&self.remote_address.0); + + #[cfg(s2n_quic_platform_pktinfo)] + match self.local_address.0 { + SocketAddress::IpV4(addr) => { + use s2n_quic_core::inet::Unspecified; + + let ip = addr.ip(); + + if ip.is_unspecified() { + return; + } + + let mut pkt_info = unsafe { core::mem::zeroed::() }; + pkt_info.ipi_spec_dst.s_addr = u32::from_ne_bytes((*ip).into()); + + msghdr.encode_cmsg(libc::IPPROTO_IP, libc::IP_PKTINFO, pkt_info); + } + SocketAddress::IpV6(addr) => { + use s2n_quic_core::inet::Unspecified; + + let ip = addr.ip(); + + if ip.is_unspecified() { + return; + } + + let mut pkt_info = unsafe { core::mem::zeroed::() }; + + pkt_info.ipi6_addr.s6_addr = (*ip).into(); + + msghdr.encode_cmsg(libc::IPPROTO_IPV6, libc::IPV6_PKTINFO, pkt_info); + } + } + } +} + +impl path::Handle for Handle { + #[inline] + fn from_remote_address(remote_address: RemoteAddress) -> Self { + Self { + remote_address, + local_address: SocketAddressV4::UNSPECIFIED.into(), + } + } + + #[inline] + fn remote_address(&self) -> RemoteAddress { + self.remote_address + } + + #[inline] + fn local_address(&self) -> LocalAddress { + self.local_address + } + + #[inline] + fn eq(&self, other: &Self) -> bool { + let mut eq = true; + + // only compare local addresses if the OS returns them + if cfg!(s2n_quic_platform_pktinfo) { + eq &= self.local_address.eq(&other.local_address); + } + + eq && path::Handle::eq(&self.remote_address, &other.remote_address) + } + + #[inline] + fn strict_eq(&self, other: &Self) -> bool { + PartialEq::eq(self, other) + } + + #[inline] + fn maybe_update(&mut self, other: &Self) { + // once we discover our path, update the address local address + if self.local_address.port() == 0 { + self.local_address = other.local_address; + } + } +} diff --git a/quic/s2n-quic-platform/src/message/queue.rs b/quic/s2n-quic-platform/src/message/queue.rs index e382bd9e33..5d548e3852 100644 --- a/quic/s2n-quic-platform/src/message/queue.rs +++ b/quic/s2n-quic-platform/src/message/queue.rs @@ -164,28 +164,11 @@ mod tests { use super::*; use crate::{buffer::VecBuffer, message::Message}; use bolero::{check, generator::*}; - use s2n_quic_core::inet; + use s2n_quic_core::path::{self, Handle}; use std::collections::VecDeque; const MTU: usize = 1200; - - fn set(message: &mut M, value: u8, len: usize) { - assert_eq!( - message.payload_len(), - MTU, - "payload len should be reset for free messages" - ); - unsafe { - message.set_payload_len(len); - } - for b in message.payload_mut().iter_mut() { - *b = value; - } - } - - fn gen_address() -> impl ValueGenerator { - gen() - } + const MAX_PAYLOAD: usize = 32; #[derive(Clone, Copy, Debug, TypeGenerator)] enum Operation { @@ -195,11 +178,10 @@ mod tests { count: usize, /// Length of the payload to be pushed - #[generator(1..32)] + #[generator(1..MAX_PAYLOAD)] len: usize, - #[generator(gen_address())] - address: inet::SocketAddress, + address: path::RemoteAddress, /// true if the operation is successful success: bool, @@ -215,6 +197,7 @@ mod tests { } fn check(mut queue: Queue, capacity: usize, ops: &[Operation]) { + let mut payload_buffer = [0u8; MAX_PAYLOAD]; let mut oracle = VecDeque::new(); let mut value = 0u8; for op in ops { @@ -227,20 +210,22 @@ mod tests { } => { let mut free = queue.free_mut(); let count = count.min(free.len()); - let mut payload = value; // push messages onto the queue and the oracle for message in &mut free[..count] { - set(message, payload, len); - - message.set_remote_address(&address); - oracle.push_back((address, len, payload)); - payload = payload.wrapping_add(1); + for byte in &mut payload_buffer[..len] { + *byte = value; + } + + let address = Handle::from_remote_address(address); + let output = (address, &payload_buffer[..len]); + message.tx_write(output).unwrap(); + oracle.push_back((address, len, value)); + value = value.wrapping_add(1); } // if successful, finish the slice, otherwise cancel if success { - value = payload; free.finish(count); } else { oracle.drain((oracle.len() - count)..); @@ -265,15 +250,17 @@ mod tests { assert_eq!(capacity, queue.occupied_len() + queue.free_len()); // assert the queue matches the oracle - let occupied = queue.occupied_mut(); + let mut occupied = queue.occupied_mut(); assert_eq!(oracle.len(), occupied.len()); - for (message, (address, len, value)) in occupied.iter().zip(oracle.iter()) { - let address = *address; - - assert_eq!(message.remote_address(), Some(address)); - assert_eq!(message.payload_len(), *len); - assert!(message.payload().iter().all(|v| v == value)); + for (message, (address, len, value)) in occupied.iter_mut().zip(oracle.iter()) { + let local_address = LocalAddress(Default::default()); + let message = message.rx_read(&local_address).unwrap(); + message.for_each(|header, payload| { + assert!(header.path.eq(address)); + assert_eq!(payload.len(), *len); + assert!(payload.iter().all(|v| v == value)); + }); } } } diff --git a/quic/s2n-quic-platform/src/message/queue/slice.rs b/quic/s2n-quic-platform/src/message/queue/slice.rs index 5931fec3bd..f22f1fd762 100644 --- a/quic/s2n-quic-platform/src/message/queue/slice.rs +++ b/quic/s2n-quic-platform/src/message/queue/slice.rs @@ -277,8 +277,8 @@ impl<'a, Message: message::Message, B: Behavior, H: path::Handle> rx // iterate over the filled packets and invoke the callback for each one let messages = &mut self.messages[range]; for message in messages { - if let Some((header, payload)) = message.rx_read(self.local_address) { - on_packet(header, payload); + if let Some(message) = message.rx_read(self.local_address) { + message.for_each(&mut on_packet); } } diff --git a/quic/s2n-quic-platform/src/message/simple.rs b/quic/s2n-quic-platform/src/message/simple.rs index 1e70fc7dad..08865db5d8 100644 --- a/quic/s2n-quic-platform/src/message/simple.rs +++ b/quic/s2n-quic-platform/src/message/simple.rs @@ -7,7 +7,7 @@ use core::pin::Pin; use s2n_quic_core::{ inet::{datagram, ExplicitCongestionNotification, SocketAddress}, io::tx, - path::{self, Handle as _}, + path, }; /// A simple message type that holds an address and payload @@ -20,34 +20,28 @@ pub struct Message { payload_len: usize, } -pub type Handle = path::Tuple; - -impl MessageTrait for Message { - type Handle = Handle; - - const SUPPORTS_GSO: bool = false; - +impl Message { fn ecn(&self) -> ExplicitCongestionNotification { ExplicitCongestionNotification::default() } - fn set_ecn(&mut self, _ecn: ExplicitCongestionNotification, _remote_address: &SocketAddress) { - // the std UDP socket doesn't provide a method to set ECN - } - - fn remote_address(&self) -> Option { + pub(crate) fn remote_address(&self) -> Option { Some(self.address) } - fn set_remote_address(&mut self, remote_address: &SocketAddress) { + pub(crate) fn set_remote_address(&mut self, remote_address: &SocketAddress) { let remote_address = *remote_address; self.address = remote_address; } +} - fn path_handle(&self) -> Option { - Some(Handle::from_remote_address(self.address.into())) - } +pub type Handle = path::Tuple; + +impl MessageTrait for Message { + type Handle = Handle; + + const SUPPORTS_GSO: bool = false; fn payload_len(&self) -> usize { self.payload_len @@ -66,10 +60,6 @@ impl MessageTrait for Message { self.set_payload_len(mtu) } - fn payload_ptr(&self) -> *const u8 { - self.payload_ptr as *const _ - } - fn payload_ptr_mut(&mut self) -> *mut u8 { self.payload_ptr } @@ -84,7 +74,7 @@ impl MessageTrait for Message { fn rx_read( &mut self, local_address: &path::LocalAddress, - ) -> Option<(datagram::Header, &mut [u8])> { + ) -> Option> { let path = path::Tuple { remote_address: self.address.into(), local_address: *local_address, @@ -94,7 +84,14 @@ impl MessageTrait for Message { ecn: self.ecn(), }; let payload = self.payload_mut(); - Some((header, payload)) + + let message = super::RxMessage { + header, + segment_size: payload.len(), + payload, + }; + + Some(message) } #[inline]