diff --git a/CHANGELOG.md b/CHANGELOG.md index b1f74a89f4..de01c2f7f0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -85,6 +85,9 @@ This project adheres to [Semantic Versioning](https://semver.org/). ### Changed +- Reimplemented sendmmsg/recvmmsg to avoid allocations and with better API + (#[1744](https://github.com/nix-rust/nix/pull/1744)) + - Rewrote the aio module. The new module: * Does more type checking at compile time rather than runtime. * Gives the caller control over whether and when to `Box` an aio operation. diff --git a/src/sys/socket/mod.rs b/src/sys/socket/mod.rs index 461607d12f..04a6f937a4 100644 --- a/src/sys/socket/mod.rs +++ b/src/sys/socket/mod.rs @@ -1,21 +1,23 @@ //! Socket interface functions //! //! [Further reading](https://man7.org/linux/man-pages/man7/socket.7.html) -use cfg_if::cfg_if; -use crate::{Result, errno::Errno}; -use libc::{self, c_void, c_int, iovec, socklen_t, size_t, - CMSG_FIRSTHDR, CMSG_NXTHDR, CMSG_DATA, CMSG_LEN}; -use std::convert::TryInto; -use std::{mem, ptr, slice}; -use std::os::unix::io::RawFd; -#[cfg(feature = "net")] -use std::net; #[cfg(target_os = "linux")] #[cfg(feature = "uio")] use crate::sys::time::TimeSpec; #[cfg(feature = "uio")] use crate::sys::time::TimeVal; +use crate::{errno::Errno, Result}; +use cfg_if::cfg_if; +use libc::{ + self, c_int, c_void, iovec, size_t, socklen_t, CMSG_DATA, CMSG_FIRSTHDR, + CMSG_LEN, CMSG_NXTHDR, +}; +use std::convert::TryInto; use std::io::{IoSlice, IoSliceMut}; +#[cfg(feature = "net")] +use std::net; +use std::os::unix::io::RawFd; +use std::{mem, ptr, slice}; #[deny(missing_docs)] mod addr; @@ -32,60 +34,44 @@ pub use self::addr::{SockaddrLike, SockaddrStorage}; #[cfg(not(any(target_os = "illumos", target_os = "solaris")))] #[allow(deprecated)] -pub use self::addr::{ - AddressFamily, - SockAddr, - UnixAddr, -}; -#[allow(deprecated)] -#[cfg(not(any(target_os = "illumos", target_os = "solaris", target_os = "haiku")))] -#[cfg(feature = "net")] -pub use self::addr::{ - 
InetAddr, - IpAddr, - Ipv4Addr, - Ipv6Addr, - LinkAddr, - SockaddrIn, - SockaddrIn6 -}; +pub use self::addr::{AddressFamily, SockAddr, UnixAddr}; #[cfg(any(target_os = "illumos", target_os = "solaris"))] #[allow(deprecated)] +pub use self::addr::{AddressFamily, SockAddr, UnixAddr}; +#[allow(deprecated)] +#[cfg(not(any( + target_os = "illumos", + target_os = "solaris", + target_os = "haiku" +)))] +#[cfg(feature = "net")] pub use self::addr::{ - AddressFamily, - SockAddr, - UnixAddr, + InetAddr, IpAddr, Ipv4Addr, Ipv6Addr, LinkAddr, SockaddrIn, SockaddrIn6, }; #[allow(deprecated)] -#[cfg(any(target_os = "illumos", target_os = "solaris", target_os = "haiku"))] +#[cfg(any( + target_os = "illumos", + target_os = "solaris", + target_os = "haiku" +))] #[cfg(feature = "net")] pub use self::addr::{ - InetAddr, - IpAddr, - Ipv4Addr, - Ipv6Addr, - SockaddrIn, - SockaddrIn6 + InetAddr, IpAddr, Ipv4Addr, Ipv6Addr, SockaddrIn, SockaddrIn6, }; +#[cfg(any(target_os = "android", target_os = "linux"))] +pub use crate::sys::socket::addr::alg::AlgAddr; +#[cfg(any(target_os = "android", target_os = "linux"))] +pub use crate::sys::socket::addr::netlink::NetlinkAddr; #[cfg(any(target_os = "ios", target_os = "macos"))] #[cfg(feature = "ioctl")] pub use crate::sys::socket::addr::sys_control::SysControlAddr; #[cfg(any(target_os = "android", target_os = "linux"))] -pub use crate::sys::socket::addr::netlink::NetlinkAddr; -#[cfg(any(target_os = "android", target_os = "linux"))] -pub use crate::sys::socket::addr::alg::AlgAddr; -#[cfg(any(target_os = "android", target_os = "linux"))] pub use crate::sys::socket::addr::vsock::VsockAddr; #[cfg(feature = "uio")] pub use libc::{cmsghdr, msghdr}; -pub use libc::{ - sa_family_t, - sockaddr, - sockaddr_storage, - sockaddr_un, -}; +pub use libc::{sa_family_t, sockaddr, sockaddr_storage, sockaddr_un}; #[cfg(feature = "net")] pub use libc::{sockaddr_in, sockaddr_in6}; @@ -245,7 +231,7 @@ libc_bitflags! { } } -libc_bitflags!{ +libc_bitflags! 
{ /// Additional socket options pub struct SockFlag: c_int { /// Set non-blocking mode on the new socket @@ -280,7 +266,7 @@ libc_bitflags!{ } } -libc_bitflags!{ +libc_bitflags! { /// Flags for send/recv and their relatives pub struct MsgFlags: c_int { /// Sends or requests out-of-band data on sockets that support this notion @@ -462,7 +448,7 @@ cfg_if! { } } -cfg_if!{ +cfg_if! { if #[cfg(any( target_os = "dragonfly", target_os = "freebsd", @@ -581,15 +567,20 @@ macro_rules! cmsg_space { } #[derive(Clone, Copy, Debug, Eq, PartialEq)] -pub struct RecvMsg<'a, S> { +/// Contains outcome of sending or receiving a message +/// +/// Use [`cmsgs`][RecvMsg::cmsgs] to access all the control messages present, and +/// [`iovs`][RecvMsg::iovs] to access underlying io slices. +pub struct RecvMsg<'a, 's, S> { pub bytes: usize, cmsghdr: Option<&'a cmsghdr>, pub address: Option, pub flags: MsgFlags, + iobufs: std::marker::PhantomData<& 's()>, mhdr: msghdr, } -impl<'a, S> RecvMsg<'a, S> { +impl<'a, S> RecvMsg<'a, '_, S> { /// Iterate over the valid control messages pointed to by this /// msghdr. pub fn cmsgs(&self) -> CmsgIterator { @@ -1468,24 +1459,6 @@ pub fn sendmsg(fd: RawFd, iov: &[IoSlice<'_>], cmsgs: &[ControlMessage], Errno::result(ret).map(|r| r as usize) } -#[cfg(any( - target_os = "linux", - target_os = "android", - target_os = "freebsd", - target_os = "netbsd", -))] -#[derive(Debug)] -pub struct SendMmsgData<'a, I, C, S> - where - I: AsRef<[IoSlice<'a>]>, - C: AsRef<[ControlMessage<'a>]>, - S: SockaddrLike + 'a -{ - pub iov: I, - pub cmsgs: C, - pub addr: Option, - pub _lt: std::marker::PhantomData<&'a I>, -} /// An extension of `sendmsg` that allows the caller to transmit multiple /// messages on a socket using a single system call.
This has performance @@ -1510,51 +1483,66 @@ pub struct SendMmsgData<'a, I, C, S> target_os = "freebsd", target_os = "netbsd", ))] -pub fn sendmmsg<'a, I, C, S>( +pub fn sendmmsg<'a, XS, AS, C, I, S>( fd: RawFd, - data: impl std::iter::IntoIterator>, + data: &'a mut MultiHeaders, + slices: XS, + // one address per group of slices + addrs: AS, + // shared across all the messages + cmsgs: C, flags: MsgFlags -) -> Result> +) -> crate::Result> where + XS: IntoIterator, + AS: AsRef<[Option]>, I: AsRef<[IoSlice<'a>]> + 'a, C: AsRef<[ControlMessage<'a>]> + 'a, S: SockaddrLike + 'a { - let iter = data.into_iter(); - let size_hint = iter.size_hint(); - let reserve_items = size_hint.1.unwrap_or(size_hint.0); + let mut count = 0; - let mut output = Vec::::with_capacity(reserve_items); - let mut cmsgs_buffers = Vec::>::with_capacity(reserve_items); + for (i, ((slice, addr), mmsghdr)) in slices.into_iter().zip(addrs.as_ref()).zip(data.items.iter_mut() ).enumerate() { + let mut p = &mut mmsghdr.msg_hdr; + p.msg_iov = slice.as_ref().as_ptr() as *mut libc::iovec; + p.msg_iovlen = slice.as_ref().len() as _; - for d in iter { - let capacity: usize = d.cmsgs.as_ref().iter().map(|c| c.space()).sum(); - let mut cmsgs_buffer = vec![0u8; capacity]; + p.msg_namelen = addr.as_ref().map_or(0, S::len); + p.msg_name = addr.as_ref().map_or(ptr::null(), S::as_ptr) as _; - output.push(libc::mmsghdr { - msg_hdr: pack_mhdr_to_send( - &mut cmsgs_buffer, - &d.iov, - &d.cmsgs, - d.addr.as_ref() - ), - msg_len: 0, - }); - cmsgs_buffers.push(cmsgs_buffer); - }; + // Encode each cmsg. This must happen after initializing the header because + // CMSG_NEXT_HDR and friends read the msg_control and msg_controllen fields. 
+ // CMSG_FIRSTHDR is always safe + let mut pmhdr: *mut cmsghdr = unsafe { CMSG_FIRSTHDR(p) }; + for cmsg in cmsgs.as_ref() { + assert_ne!(pmhdr, ptr::null_mut()); + // Safe because we know that pmhdr is valid, and we initialized it with + // sufficient space + unsafe { cmsg.encode_into(pmhdr) }; + // Safe because mhdr is valid + pmhdr = unsafe { CMSG_NXTHDR(p, pmhdr) }; + } - let ret = unsafe { libc::sendmmsg(fd, output.as_mut_ptr(), output.len() as _, flags.bits() as _) }; + count = i+1; + } - let sent_messages = Errno::result(ret)? as usize; - let mut sent_bytes = Vec::with_capacity(sent_messages); + let sent = Errno::result(unsafe { + libc::sendmmsg( + fd, + data.items.as_mut_ptr(), + count as _, + flags.bits() as _ + ) + })? as usize; - for item in &output { - sent_bytes.push(item.msg_len as usize); - } + Ok(MultiResults { + rmm: data, + current_index: 0, + received: sent + }) - Ok(sent_bytes) } @@ -1565,138 +1553,345 @@ pub fn sendmmsg<'a, I, C, S>( target_os = "netbsd", ))] #[derive(Debug)] -pub struct RecvMmsgData<'a, I> +/// Preallocated structures needed for [`recvmmsg`] and [`sendmmsg`] functions +pub struct MultiHeaders { + // preallocated boxed slice of mmsghdr + items: Box<[libc::mmsghdr]>, + addresses: Box<[mem::MaybeUninit]>, + // while we are not using it directly - this is used to store control messages + // and we retain pointers to them inside items array + #[allow(dead_code)] + cmsg_buffers: Option>, + msg_controllen: usize, +} + +#[cfg(any( + target_os = "linux", + target_os = "android", + target_os = "freebsd", + target_os = "netbsd", +))] +impl MultiHeaders { + /// Preallocate structure used by [`recvmmsg`] and [`sendmmsg`] takes number of headers to preallocate + /// + /// `cmsg_buffer` should be created with [`cmsg_space!`] if needed + pub fn preallocate(num_slices: usize, cmsg_buffer: Option>) -> Self where - I: AsRef<[IoSliceMut<'a>]> + 'a, -{ - pub iov: I, - pub cmsg_buffer: Option<&'a mut Vec>, + S: Copy + SockaddrLike, + { + // we 
will be storing pointers to addresses inside mhdr - convert it into boxed + // slice so it can't be changed later by pushing anything into self.addresses + let mut addresses = vec![std::mem::MaybeUninit::uninit(); num_slices].into_boxed_slice(); + + let msg_controllen = cmsg_buffer.as_ref().map_or(0, |v| v.capacity()); + + // we'll need a cmsg_buffer for each slice, we preallocate a vector and split + // it into "slices" parts + let cmsg_buffers = + cmsg_buffer.map(|v| vec![0u8; v.capacity() * num_slices].into_boxed_slice()); + + let items = addresses + .iter_mut() + .enumerate() + .map(|(ix, address)| { + let (ptr, cap) = match &cmsg_buffers { + Some(v) => ((&v[ix * msg_controllen] as *const u8), msg_controllen), + None => (std::ptr::null(), 0), + }; + let msg_hdr = unsafe { pack_mhdr_to_receive(std::ptr::null(), 0, ptr, cap, address.as_mut_ptr()) }; + libc::mmsghdr { + msg_hdr, + msg_len: 0, + } + }) + .collect::>(); + + Self { + items: items.into_boxed_slice(), + addresses, + cmsg_buffers, + msg_controllen, + } + } } -/// An extension of `recvmsg` that allows the caller to receive multiple -/// messages from a socket using a single system call. This has -/// performance benefits for some applications. -/// -/// `iov` and `cmsg_buffer` should be constructed similarly to `recvmsg` +/// An extension of recvmsg that allows the caller to receive multiple messages from a socket using a single system call. /// -/// Multiple allocations are performed +/// This has performance benefits for some applications. /// -/// # Arguments +/// This method performs no allocations. /// -/// * `fd`: Socket file descriptor -/// * `data`: Struct that implements `IntoIterator` with `RecvMmsgData` items -/// * `flags`: Optional flags passed directly to the operating system. -/// -/// # RecvMmsgData -/// -/// * `iov`: Scatter-gather list of buffers to receive the message -/// * `cmsg_buffer`: Space to receive ancillary data.
Should be created by -/// [`cmsg_space!`](../../macro.cmsg_space.html) +/// Returns an iterator producing [`RecvMsg`], one per received message. Each `RecvMsg` can produce +/// iterators over [`IoSlice`] with [`iovs`][RecvMsg::iovs] and +/// `ControlMessageOwned` with [`cmsgs`][RecvMsg::cmsgs]. /// -/// # Returns -/// A `Vec` with multiple `RecvMsg`, one per received message +/// # Bugs (in underlying implementation, at least in Linux) +/// The timeout argument does not work as intended. The timeout is checked only after the receipt +/// of each datagram, so that if up to `vlen`-1 datagrams are received before the timeout expires, +/// but then no further datagrams are received, the call will block forever. /// -/// # References -/// - [`recvmsg`](fn.recvmsg.html) -/// - [`RecvMsg`](struct.RecvMsg.html) +/// If an error occurs after at least one message has been received, the call succeeds, and returns +/// the number of messages received. The error code is expected to be returned on a subsequent +/// call to recvmmsg(). In the current implementation, however, the error code can be +/// overwritten in the meantime by an unrelated network event on a socket, for example an +/// incoming ICMP packet.
+ +// On aarch64 linux using recvmmsg and trying to get hardware/kernel timestamps might not +// always produce the desired results - see https://github.com/nix-rust/nix/pull/1744 for more +// details + #[cfg(any( target_os = "linux", target_os = "android", target_os = "freebsd", target_os = "netbsd", ))] -#[allow(clippy::needless_collect)] // Complicated false positive -pub fn recvmmsg<'a, I, S>( +pub fn recvmmsg<'a, XS, S, I>( fd: RawFd, - data: impl std::iter::IntoIterator, - IntoIter=impl ExactSizeIterator + Iterator>>, + data: &'a mut MultiHeaders, + slices: XS, flags: MsgFlags, - timeout: Option -) -> Result>> - where - I: AsRef<[IoSliceMut<'a>]> + 'a, - S: Copy + SockaddrLike + 'a + mut timeout: Option, +) -> crate::Result> +where + XS: IntoIterator, + I: AsRef<[IoSliceMut<'a>]> + 'a, { - let iter = data.into_iter(); + let mut count = 0; + for (i, (slice, mmsghdr)) in slices.into_iter().zip(data.items.iter_mut()).enumerate() { + let mut p = &mut mmsghdr.msg_hdr; + p.msg_iov = slice.as_ref().as_ptr() as *mut libc::iovec; + p.msg_iovlen = slice.as_ref().len() as _; + count = i + 1; + } - let num_messages = iter.len(); + let timeout_ptr = timeout + .as_mut() + .map_or_else(std::ptr::null_mut, |t| t as *mut _ as *mut libc::timespec); - let mut output: Vec = Vec::with_capacity(num_messages); + let received = Errno::result(unsafe { + libc::recvmmsg( + fd, + data.items.as_mut_ptr(), + count as _, + flags.bits() as _, + timeout_ptr, + ) + })? as usize; - // Addresses should be pre-allocated. pack_mhdr_to_receive will store them - // as raw pointers, so we may not move them. Turn the vec into a boxed - // slice so we won't inadvertently reallocate the vec. 
- let mut addresses = vec![mem::MaybeUninit::uninit(); num_messages] - .into_boxed_slice(); + Ok(MultiResults { + rmm: data, + current_index: 0, + received, + }) +} - let results: Vec<_> = iter.enumerate().map(|(i, d)| { - let (msg_controllen, mhdr) = unsafe { - pack_mhdr_to_receive( - d.iov.as_ref(), - &mut d.cmsg_buffer, - addresses[i].as_mut_ptr(), +#[cfg(any( + target_os = "linux", + target_os = "android", + target_os = "freebsd", + target_os = "netbsd", +))] +#[derive(Debug)] +/// Iterator over results of [`recvmmsg`]/[`sendmmsg`] +/// +/// +pub struct MultiResults<'a, S> { + // preallocated structures + rmm: &'a MultiHeaders, + current_index: usize, + received: usize, +} + +#[cfg(any( + target_os = "linux", + target_os = "android", + target_os = "freebsd", + target_os = "netbsd", +))] +impl<'a, S> Iterator for MultiResults<'a, S> +where + S: Copy + SockaddrLike, +{ + type Item = RecvMsg<'a, 'a, S>; + + fn next(&mut self) -> Option { + if self.current_index >= self.received { + return None; + } + let mmsghdr = self.rmm.items[self.current_index]; + + // as long as we are not reading past the index written by recvmmsg - address + // will be initialized + let address = unsafe { self.rmm.addresses[self.current_index].assume_init() }; + + self.current_index += 1; + Some(unsafe { + read_mhdr( + mmsghdr.msg_hdr, + mmsghdr.msg_len as isize, + self.rmm.msg_controllen, + address, ) + }) + } +} + +impl<'a, S> RecvMsg<'_, 'a, S> { + /// Iterate over the filled io slices pointed by this msghdr + pub fn iovs(&self) -> IoSliceIterator<'a> { + IoSliceIterator { + index: 0, + remaining: self.bytes, + slices: unsafe { + // safe for as long as mhdr is properly initialized and references are valid.
+ // for multi messages API we initialize it with an empty + // slice and replace with a concrete buffer + // for single message API we hold a lifetime reference to ioslices + std::slice::from_raw_parts(self.mhdr.msg_iov as *const _, self.mhdr.msg_iovlen as _) + }, + } + } +} + +#[derive(Debug)] +pub struct IoSliceIterator<'a> { + index: usize, + remaining: usize, + slices: &'a [IoSlice<'a>], +} + +impl<'a> Iterator for IoSliceIterator<'a> { + type Item = &'a [u8]; + + fn next(&mut self) -> Option { + if self.index >= self.slices.len() { + return None; + } + let slice = &self.slices[self.index][..self.remaining.min(self.slices[self.index].len())]; + self.remaining -= slice.len(); + self.index += 1; + if slice.is_empty() { + return None; + } + + Some(slice) + } +} + +// test contains both recvmmsg and timestamping which is linux only +// there are existing tests for recvmmsg only in tests/ +#[cfg(target_os = "linux")] +#[cfg(test)] +mod test { + use crate::sys::socket::{AddressFamily, ControlMessageOwned}; + use crate::*; + use std::str::FromStr; + + #[cfg_attr(qemu, ignore)] + #[test] + fn test_recvmm2() -> crate::Result<()> { + use crate::sys::socket::{ + sendmsg, setsockopt, socket, sockopt::Timestamping, MsgFlags, SockFlag, SockType, + SockaddrIn, TimestampingFlag, }; + use std::io::{IoSlice, IoSliceMut}; - output.push( - libc::mmsghdr { - msg_hdr: mhdr, - msg_len: 0, - } - ); + let sock_addr = SockaddrIn::from_str("127.0.0.1:6790").unwrap(); - (msg_controllen, &mut d.cmsg_buffer) - }).collect(); + let ssock = socket( + AddressFamily::Inet, + SockType::Datagram, + SockFlag::empty(), + None, + )?; - let timeout = if let Some(mut t) = timeout { - t.as_mut() as *mut libc::timespec - } else { - ptr::null_mut() - }; + let rsock = socket( + AddressFamily::Inet, + SockType::Datagram, + SockFlag::SOCK_NONBLOCK, + None, + )?; + + crate::sys::socket::bind(rsock, &sock_addr)?; + + setsockopt(rsock, Timestamping, &TimestampingFlag::all())?; + + let sbuf = (0..400).map(|i| i
as u8).collect::>(); - let ret = unsafe { libc::recvmmsg(fd, output.as_mut_ptr(), output.len() as _, flags.bits() as _, timeout) }; - - let _ = Errno::result(ret)?; - - Ok(output - .into_iter() - .take(ret as usize) - .zip(addresses.iter().map(|addr| unsafe{addr.assume_init()})) - .zip(results.into_iter()) - .map(|((mmsghdr, address), (msg_controllen, cmsg_buffer))| { - // The cast is not unnecessary on all platforms. - #[allow(clippy::unnecessary_cast)] - unsafe { - read_mhdr( - mmsghdr.msg_hdr, - mmsghdr.msg_len as isize, - msg_controllen, - address, - cmsg_buffer - ) + let mut recv_buf = vec![0; 1024]; + + let mut recv_iovs = Vec::new(); + let mut pkt_iovs = Vec::new(); + + for (ix, chunk) in recv_buf.chunks_mut(256).enumerate() { + pkt_iovs.push(IoSliceMut::new(chunk)); + if ix % 2 == 1 { + recv_iovs.push(pkt_iovs); + pkt_iovs = Vec::new(); + } + } + drop(pkt_iovs); + + let flags = MsgFlags::empty(); + let iov1 = [IoSlice::new(&sbuf)]; + + let cmsg = cmsg_space!(crate::sys::socket::Timestamps); + sendmsg(ssock, &iov1, &[], flags, Some(&sock_addr)).unwrap(); + + let mut data = super::MultiHeaders::<()>::preallocate(recv_iovs.len(), Some(cmsg)); + + let t = sys::time::TimeSpec::from_duration(std::time::Duration::from_secs(10)); + + let recv = super::recvmmsg(rsock, &mut data, recv_iovs.iter(), flags, Some(t))?; + + for rmsg in recv { + #[cfg(not(any(qemu, target_arch = "aarch64")))] + let mut saw_time = false; + let mut recvd = 0; + for cmsg in rmsg.cmsgs() { + if let ControlMessageOwned::ScmTimestampsns(timestamps) = cmsg { + let ts = timestamps.system; + + let sys_time = + crate::time::clock_gettime(crate::time::ClockId::CLOCK_REALTIME)?; + let diff = if ts > sys_time { + ts - sys_time + } else { + sys_time - ts + }; + assert!(std::time::Duration::from(diff).as_secs() < 60); + #[cfg(not(any(qemu, target_arch = "aarch64")))] + { + saw_time = true; + } + } } - }) - .collect()) -} -unsafe fn read_mhdr<'a, 'b, S>( + #[cfg(not(any(qemu, target_arch = "aarch64")))] + 
assert!(saw_time); + + for iov in rmsg.iovs() { + recvd += iov.len(); + } + assert_eq!(recvd, 400); + } + + Ok(()) + } +} +unsafe fn read_mhdr<'a, 'i, S>( mhdr: msghdr, r: isize, msg_controllen: usize, address: S, - cmsg_buffer: &'a mut Option<&'b mut Vec> -) -> RecvMsg<'b, S> +) -> RecvMsg<'a, 'i, S> where S: SockaddrLike { // The cast is not unnecessary on all platforms. #[allow(clippy::unnecessary_cast)] let cmsghdr = { if mhdr.msg_controllen > 0 { - // got control message(s) - cmsg_buffer - .as_mut() - .unwrap() - .set_len(mhdr.msg_controllen as usize); debug_assert!(!mhdr.msg_control.is_null()); debug_assert!(msg_controllen >= mhdr.msg_controllen as usize); CMSG_FIRSTHDR(&mhdr as *const msghdr) @@ -1711,38 +1906,43 @@ unsafe fn read_mhdr<'a, 'b, S>( address: Some(address), flags: MsgFlags::from_bits_truncate(mhdr.msg_flags), mhdr, + iobufs: std::marker::PhantomData, } } -unsafe fn pack_mhdr_to_receive<'outer, 'inner, I, S>( - iov: I, - cmsg_buffer: &mut Option<&mut Vec>, +/// Pack pointers to various structures into msghdr +/// +/// # Safety +/// `iov_buffer` and `iov_buffer_len` must point to a slice +/// of `IoSliceMut` and number of available elements or be a null pointer and 0 +/// +/// `cmsg_buffer` and `cmsg_capacity` must point to a byte buffer used +/// to store control headers later or be a null pointer and 0 if control +/// headers are not used +/// +/// Buffers must remain valid for the whole lifetime of msghdr +unsafe fn pack_mhdr_to_receive( + iov_buffer: *const IoSliceMut, + iov_buffer_len: usize, + cmsg_buffer: *const u8, + cmsg_capacity: usize, address: *mut S, -) -> (usize, msghdr) +) -> msghdr where - I: AsRef<[IoSliceMut<'inner>]> + 'outer, - S: SockaddrLike + 'outer + S: SockaddrLike { - let (msg_control, msg_controllen) = cmsg_buffer.as_mut() - .map(|v| (v.as_mut_ptr(), v.capacity())) - .unwrap_or((ptr::null_mut(), 0)); - - let mhdr = { - // Musl's msghdr has private fields, so this is the only way to - // initialize it.
- let mut mhdr = mem::MaybeUninit::::zeroed(); - let p = mhdr.as_mut_ptr(); - (*p).msg_name = (*address).as_mut_ptr() as *mut c_void; - (*p).msg_namelen = S::size(); - (*p).msg_iov = iov.as_ref().as_ptr() as *mut iovec; - (*p).msg_iovlen = iov.as_ref().len() as _; - (*p).msg_control = msg_control as *mut c_void; - (*p).msg_controllen = msg_controllen as _; - (*p).msg_flags = 0; - mhdr.assume_init() - }; - - (msg_controllen, mhdr) + // Musl's msghdr has private fields, so this is the only way to + // initialize it. + let mut mhdr = mem::MaybeUninit::::zeroed(); + let p = mhdr.as_mut_ptr(); + (*p).msg_name = (*address).as_mut_ptr() as *mut c_void; + (*p).msg_namelen = S::size(); + (*p).msg_iov = iov_buffer as *mut iovec; + (*p).msg_iovlen = iov_buffer_len as _; + (*p).msg_control = cmsg_buffer as *mut c_void; + (*p).msg_controllen = cmsg_capacity as _; + (*p).msg_flags = 0; + mhdr.assume_init() } fn pack_mhdr_to_send<'a, I, C, S>( @@ -1814,24 +2014,27 @@ fn pack_mhdr_to_send<'a, I, C, S>( /// [recvmsg(2)](https://pubs.opengroup.org/onlinepubs/9699919799/functions/recvmsg.html) pub fn recvmsg<'a, 'outer, 'inner, S>(fd: RawFd, iov: &'outer mut [IoSliceMut<'inner>], mut cmsg_buffer: Option<&'a mut Vec>, - flags: MsgFlags) -> Result> - where S: SockaddrLike + 'a + flags: MsgFlags) -> Result> + where S: SockaddrLike + 'a, + 'inner: 'outer { let mut address = mem::MaybeUninit::uninit(); - let (msg_controllen, mut mhdr) = unsafe { - pack_mhdr_to_receive::<_, S>(iov, &mut cmsg_buffer, address.as_mut_ptr()) + let (msg_control, msg_controllen) = cmsg_buffer.as_mut() + .map(|v| (v.as_mut_ptr(), v.capacity())) + .unwrap_or((ptr::null_mut(), 0)); + let mut mhdr = unsafe { + pack_mhdr_to_receive(iov.as_ref().as_ptr(), iov.len(), msg_control, msg_controllen, address.as_mut_ptr()) }; let ret = unsafe { libc::recvmsg(fd, &mut mhdr, flags.bits()) }; let r = Errno::result(ret)?; - Ok(unsafe { read_mhdr(mhdr, r, msg_controllen, address.assume_init(), &mut cmsg_buffer) }) + Ok(unsafe { 
read_mhdr(mhdr, r, msg_controllen, address.assume_init()) }) } } - /// Create an endpoint for communication /// /// The `protocol` specifies a particular protocol to be used with the @@ -1842,7 +2045,12 @@ pub fn recvmsg<'a, 'outer, 'inner, S>(fd: RawFd, iov: &'outer mut [IoSliceMut<'i /// specified in this manner. /// /// [Further reading](https://pubs.opengroup.org/onlinepubs/9699919799/functions/socket.html) -pub fn socket>>(domain: AddressFamily, ty: SockType, flags: SockFlag, protocol: T) -> Result { +pub fn socket>>( + domain: AddressFamily, + ty: SockType, + flags: SockFlag, + protocol: T, +) -> Result { let protocol = match protocol.into() { None => 0, Some(p) => p as c_int, @@ -1862,8 +2070,12 @@ pub fn socket>>(domain: AddressFamily, ty: SockType /// Create a pair of connected sockets /// /// [Further reading](https://pubs.opengroup.org/onlinepubs/9699919799/functions/socketpair.html) -pub fn socketpair>>(domain: AddressFamily, ty: SockType, protocol: T, - flags: SockFlag) -> Result<(RawFd, RawFd)> { +pub fn socketpair>>( + domain: AddressFamily, + ty: SockType, + protocol: T, + flags: SockFlag, +) -> Result<(RawFd, RawFd)> { let protocol = match protocol.into() { None => 0, Some(p) => p as c_int, @@ -1877,7 +2089,9 @@ pub fn socketpair>>(domain: AddressFamily, ty: Sock let mut fds = [-1, -1]; - let res = unsafe { libc::socketpair(domain as c_int, ty, protocol, fds.as_mut_ptr()) }; + let res = unsafe { + libc::socketpair(domain as c_int, ty, protocol, fds.as_mut_ptr()) + }; Errno::result(res)?; Ok((fds[0], fds[1])) @@ -1896,9 +2110,7 @@ pub fn listen(sockfd: RawFd, backlog: usize) -> Result<()> { /// /// [Further reading](https://pubs.opengroup.org/onlinepubs/9699919799/functions/bind.html) pub fn bind(fd: RawFd, addr: &dyn SockaddrLike) -> Result<()> { - let res = unsafe { - libc::bind(fd, addr.as_ptr(), addr.len()) - }; + let res = unsafe { libc::bind(fd, addr.as_ptr(), addr.len()) }; Errno::result(res).map(drop) } @@ -1915,24 +2127,28 @@ pub fn 
accept(sockfd: RawFd) -> Result { /// Accept a connection on a socket /// /// [Further reading](https://man7.org/linux/man-pages/man2/accept.2.html) -#[cfg(any(all( - target_os = "android", - any( - target_arch = "aarch64", - target_arch = "x86", - target_arch = "x86_64" - ) - ), - target_os = "dragonfly", - target_os = "emscripten", - target_os = "freebsd", - target_os = "fuchsia", - target_os = "illumos", - target_os = "linux", - target_os = "netbsd", - target_os = "openbsd"))] +#[cfg(any( + all( + target_os = "android", + any( + target_arch = "aarch64", + target_arch = "x86", + target_arch = "x86_64" + ) + ), + target_os = "dragonfly", + target_os = "emscripten", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "illumos", + target_os = "linux", + target_os = "netbsd", + target_os = "openbsd" +))] pub fn accept4(sockfd: RawFd, flags: SockFlag) -> Result { - let res = unsafe { libc::accept4(sockfd, ptr::null_mut(), ptr::null_mut(), flags.bits()) }; + let res = unsafe { + libc::accept4(sockfd, ptr::null_mut(), ptr::null_mut(), flags.bits()) + }; Errno::result(res) } @@ -1941,9 +2157,7 @@ pub fn accept4(sockfd: RawFd, flags: SockFlag) -> Result { /// /// [Further reading](https://pubs.opengroup.org/onlinepubs/9699919799/functions/connect.html) pub fn connect(fd: RawFd, addr: &dyn SockaddrLike) -> Result<()> { - let res = unsafe { - libc::connect(fd, addr.as_ptr(), addr.len()) - }; + let res = unsafe { libc::connect(fd, addr.as_ptr(), addr.len()) }; Errno::result(res).map(drop) } @@ -1958,7 +2172,8 @@ pub fn recv(sockfd: RawFd, buf: &mut [u8], flags: MsgFlags) -> Result { sockfd, buf.as_ptr() as *mut c_void, buf.len() as size_t, - flags.bits()); + flags.bits(), + ); Errno::result(ret).map(|r| r as usize) } @@ -1969,9 +2184,10 @@ pub fn recv(sockfd: RawFd, buf: &mut [u8], flags: MsgFlags) -> Result { /// address of the sender. 
/// /// [Further reading](https://pubs.opengroup.org/onlinepubs/9699919799/functions/recvfrom.html) -pub fn recvfrom(sockfd: RawFd, buf: &mut [u8]) - -> Result<(usize, Option)> -{ +pub fn recvfrom( + sockfd: RawFd, + buf: &mut [u8], +) -> Result<(usize, Option)> { unsafe { let mut addr = mem::MaybeUninit::::uninit(); let mut len = mem::size_of_val(&addr) as socklen_t; @@ -1982,11 +2198,15 @@ pub fn recvfrom(sockfd: RawFd, buf: &mut [u8]) buf.len() as size_t, 0, addr.as_mut_ptr() as *mut libc::sockaddr, - &mut len as *mut socklen_t))? as usize; - - Ok((ret, T::from_raw( - addr.assume_init().as_ptr() as *const libc::sockaddr, - Some(len)) + &mut len as *mut socklen_t, + ))? as usize; + + Ok(( + ret, + T::from_raw( + addr.assume_init().as_ptr() as *const libc::sockaddr, + Some(len), + ), )) } } @@ -1994,7 +2214,12 @@ pub fn recvfrom(sockfd: RawFd, buf: &mut [u8]) /// Send a message to a socket /// /// [Further reading](https://pubs.opengroup.org/onlinepubs/9699919799/functions/sendto.html) -pub fn sendto(fd: RawFd, buf: &[u8], addr: &dyn SockaddrLike, flags: MsgFlags) -> Result { +pub fn sendto( + fd: RawFd, + buf: &[u8], + addr: &dyn SockaddrLike, + flags: MsgFlags, +) -> Result { let ret = unsafe { libc::sendto( fd, @@ -2002,7 +2227,7 @@ pub fn sendto(fd: RawFd, buf: &[u8], addr: &dyn SockaddrLike, flags: MsgFlags) - buf.len() as size_t, flags.bits(), addr.as_ptr(), - addr.len() + addr.len(), ) }; @@ -2014,7 +2239,12 @@ pub fn sendto(fd: RawFd, buf: &[u8], addr: &dyn SockaddrLike, flags: MsgFlags) - /// [Further reading](https://pubs.opengroup.org/onlinepubs/9699919799/functions/send.html) pub fn send(fd: RawFd, buf: &[u8], flags: MsgFlags) -> Result { let ret = unsafe { - libc::send(fd, buf.as_ptr() as *const c_void, buf.len() as size_t, flags.bits()) + libc::send( + fd, + buf.as_ptr() as *const c_void, + buf.len() as size_t, + flags.bits(), + ) }; Errno::result(ret).map(|r| r as usize) @@ -2027,7 +2257,7 @@ pub fn send(fd: RawFd, buf: &[u8], flags: MsgFlags) -> 
Result { */ /// Represents a socket option that can be retrieved. -pub trait GetSockOpt : Copy { +pub trait GetSockOpt: Copy { type Val; /// Look up the value of this socket option on the given socket. @@ -2035,7 +2265,7 @@ pub trait GetSockOpt : Copy { } /// Represents a socket option that can be set. -pub trait SetSockOpt : Clone { +pub trait SetSockOpt: Clone { type Val; /// Set the value of this socket option on the given socket. @@ -2066,7 +2296,11 @@ pub fn getsockopt(fd: RawFd, opt: O) -> Result { /// let res = setsockopt(fd, KeepAlive, &true); /// assert!(res.is_ok()); /// ``` -pub fn setsockopt(fd: RawFd, opt: O, val: &O::Val) -> Result<()> { +pub fn setsockopt( + fd: RawFd, + opt: O, + val: &O::Val, +) -> Result<()> { opt.set(fd, val) } @@ -2081,13 +2315,12 @@ pub fn getpeername(fd: RawFd) -> Result { let ret = libc::getpeername( fd, addr.as_mut_ptr() as *mut libc::sockaddr, - &mut len + &mut len, ); Errno::result(ret)?; - T::from_raw(addr.assume_init().as_ptr(), Some(len)) - .ok_or(Errno::EINVAL) + T::from_raw(addr.assume_init().as_ptr(), Some(len)).ok_or(Errno::EINVAL) } } @@ -2102,13 +2335,12 @@ pub fn getsockname(fd: RawFd) -> Result { let ret = libc::getsockname( fd, addr.as_mut_ptr() as *mut libc::sockaddr, - &mut len + &mut len, ); Errno::result(ret)?; - T::from_raw(addr.assume_init().as_ptr(), Some(len)) - .ok_or(Errno::EINVAL) + T::from_raw(addr.assume_init().as_ptr(), Some(len)).ok_or(Errno::EINVAL) } } @@ -2127,8 +2359,8 @@ pub fn getsockname(fd: RawFd) -> Result { #[allow(deprecated)] pub fn sockaddr_storage_to_addr( addr: &sockaddr_storage, - len: usize) -> Result { - + len: usize, +) -> Result { assert!(len <= mem::size_of::()); if len < mem::size_of_val(&addr.ss_family) { return Err(Errno::ENOTCONN); @@ -2146,18 +2378,14 @@ pub fn sockaddr_storage_to_addr( #[cfg(feature = "net")] libc::AF_INET6 => { assert!(len >= mem::size_of::()); - let sin6 = unsafe { - *(addr as *const _ as *const sockaddr_in6) - }; + let sin6 = unsafe { *(addr as 
*const _ as *const sockaddr_in6) }; Ok(SockAddr::Inet(InetAddr::V6(sin6))) } - libc::AF_UNIX => { - unsafe { - let sun = *(addr as *const _ as *const sockaddr_un); - let sun_len = len.try_into().unwrap(); - Ok(SockAddr::Unix(UnixAddr::from_raw_parts(sun, sun_len))) - } - } + libc::AF_UNIX => unsafe { + let sun = *(addr as *const _ as *const sockaddr_un); + let sun_len = len.try_into().unwrap(); + Ok(SockAddr::Unix(UnixAddr::from_raw_parts(sun, sun_len))) + }, #[cfg(any(target_os = "android", target_os = "linux"))] #[cfg(feature = "net")] libc::AF_PACKET => { @@ -2166,40 +2394,31 @@ pub fn sockaddr_storage_to_addr( // Apparently the Linux kernel can return smaller sizes when // the value in the last element of sockaddr_ll (`sll_addr`) is // smaller than the declared size of that field - let sll = unsafe { - *(addr as *const _ as *const sockaddr_ll) - }; + let sll = unsafe { *(addr as *const _ as *const sockaddr_ll) }; Ok(SockAddr::Link(LinkAddr(sll))) } #[cfg(any(target_os = "android", target_os = "linux"))] libc::AF_NETLINK => { use libc::sockaddr_nl; - let snl = unsafe { - *(addr as *const _ as *const sockaddr_nl) - }; + let snl = unsafe { *(addr as *const _ as *const sockaddr_nl) }; Ok(SockAddr::Netlink(NetlinkAddr(snl))) } #[cfg(any(target_os = "android", target_os = "linux"))] libc::AF_ALG => { use libc::sockaddr_alg; - let salg = unsafe { - *(addr as *const _ as *const sockaddr_alg) - }; + let salg = unsafe { *(addr as *const _ as *const sockaddr_alg) }; Ok(SockAddr::Alg(AlgAddr(salg))) } #[cfg(any(target_os = "android", target_os = "linux"))] libc::AF_VSOCK => { use libc::sockaddr_vm; - let svm = unsafe { - *(addr as *const _ as *const sockaddr_vm) - }; + let svm = unsafe { *(addr as *const _ as *const sockaddr_vm) }; Ok(SockAddr::Vsock(VsockAddr(svm))) } af => panic!("unexpected address family {}", af), } } - #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] pub enum Shutdown { /// Further receptions will be disallowed. 
@@ -2218,9 +2437,9 @@ pub fn shutdown(df: RawFd, how: Shutdown) -> Result<()> { use libc::shutdown; let how = match how { - Shutdown::Read => libc::SHUT_RD, + Shutdown::Read => libc::SHUT_RD, Shutdown::Write => libc::SHUT_WR, - Shutdown::Both => libc::SHUT_RDWR, + Shutdown::Both => libc::SHUT_RDWR, }; Errno::result(shutdown(df, how)).map(drop) diff --git a/test/sys/test_socket.rs b/test/sys/test_socket.rs index b4ca279d67..7ab60ecc28 100644 --- a/test/sys/test_socket.rs +++ b/test/sys/test_socket.rs @@ -501,31 +501,31 @@ mod recvfrom { rsock, ssock, move |s, m, flags| { - let iov = [IoSlice::new(m)]; - let mut msgs = vec![SendMmsgData { - iov: &iov, - cmsgs: &[], - addr: Some(sock_addr), - _lt: Default::default(), - }]; - let batch_size = 15; + let mut iovs = Vec::with_capacity(1 + batch_size); + let mut addrs = Vec::with_capacity(1 + batch_size); + let mut data = MultiHeaders::preallocate(1 + batch_size, None); + let iov = IoSlice::new(m); + // first chunk: + iovs.push([iov]); + addrs.push(Some(sock_addr)); for _ in 0..batch_size { - msgs.push(SendMmsgData { - iov: &iov, - cmsgs: &[], - addr: Some(sock_addr2), - _lt: Default::default(), - }); + iovs.push([iov]); + addrs.push(Some(sock_addr2)); } - sendmmsg(s, msgs.iter(), flags).map(move |sent_bytes| { - assert!(!sent_bytes.is_empty()); - for sent in &sent_bytes { - assert_eq!(*sent, m.len()); - } - sent_bytes.len() - }) + + let res = sendmmsg(s, &mut data, &iovs, addrs, [], flags)?; + let mut sent_messages = 0; + let mut sent_bytes = 0; + for item in res { + sent_messages += 1; + sent_bytes += item.bytes; + } + // + assert_eq!(sent_messages, iovs.len()); + assert_eq!(sent_bytes, sent_messages * m.len()); + Ok(sent_messages) }, |_, _| {}, ); @@ -577,21 +577,19 @@ mod recvfrom { // Buffers to receive exactly `NUM_MESSAGES_SENT` messages let mut receive_buffers = [[0u8; 32]; NUM_MESSAGES_SENT]; - let iovs: Vec<_> = receive_buffers - .iter_mut() - .map(|buf| [IoSliceMut::new(&mut buf[..])]) - .collect(); + 
msgs.extend( + receive_buffers + .iter_mut() + .map(|buf| [IoSliceMut::new(&mut buf[..])]), + ); - for iov in &iovs { - msgs.push_back(RecvMmsgData { - iov, - cmsg_buffer: None, - }) - } + let mut data = + MultiHeaders::::preallocate(msgs.len(), None); let res: Vec> = - recvmmsg(rsock, &mut msgs, MsgFlags::empty(), None) - .expect("recvmmsg"); + recvmmsg(rsock, &mut data, msgs.iter(), MsgFlags::empty(), None) + .expect("recvmmsg") + .collect(); assert_eq!(res.len(), DATA.len()); for RecvMsg { address, bytes, .. } in res.into_iter() { @@ -655,21 +653,26 @@ mod recvfrom { // will return when there are fewer than requested messages in the // kernel buffers when using `MSG_DONTWAIT`. let mut receive_buffers = [[0u8; 32]; NUM_MESSAGES_SENT + 2]; - let iovs: Vec<_> = receive_buffers - .iter_mut() - .map(|buf| [IoSliceMut::new(&mut buf[..])]) - .collect(); + msgs.extend( + receive_buffers + .iter_mut() + .map(|buf| [IoSliceMut::new(&mut buf[..])]), + ); - for iov in &iovs { - msgs.push_back(RecvMmsgData { - iov, - cmsg_buffer: None, - }) - } + let mut data = MultiHeaders::::preallocate( + NUM_MESSAGES_SENT + 2, + None, + ); - let res: Vec> = - recvmmsg(rsock, &mut msgs, MsgFlags::MSG_DONTWAIT, None) - .expect("recvmmsg"); + let res: Vec> = recvmmsg( + rsock, + &mut data, + msgs.iter(), + MsgFlags::MSG_DONTWAIT, + None, + ) + .expect("recvmmsg") + .collect(); assert_eq!(res.len(), NUM_MESSAGES_SENT); for RecvMsg { address, bytes, .. 
} in res.into_iter() { @@ -2205,14 +2208,13 @@ fn test_recvmmsg_timestampns() { assert_eq!(message.len(), l); // Receive the message let mut buffer = vec![0u8; message.len()]; - let mut cmsgspace = nix::cmsg_space!(TimeSpec); - let iov = [IoSliceMut::new(&mut buffer)]; - let mut data = vec![RecvMmsgData { - iov, - cmsg_buffer: Some(&mut cmsgspace), - }]; + let cmsgspace = nix::cmsg_space!(TimeSpec); + let iov = vec![[IoSliceMut::new(&mut buffer)]]; + let mut data = MultiHeaders::preallocate(1, Some(cmsgspace)); let r: Vec> = - recvmmsg(in_socket, &mut data, flags, None).unwrap(); + recvmmsg(in_socket, &mut data, iov.iter(), flags, None) + .unwrap() + .collect(); let rtime = match r[0].cmsgs().next() { Some(ControlMessageOwned::ScmTimestampns(rtime)) => rtime, Some(_) => panic!("Unexpected control message"),