From 9e80118309d660d338e23cad3c143a3adb09e48d Mon Sep 17 00:00:00 2001 From: Michael Baikov Date: Mon, 13 Jun 2022 11:48:24 +0800 Subject: [PATCH] reimplement recvmsg New implementation performs no allocations after all the necessary structures are created, removes potentially unsound code that was used by the old version (see below) and adds a bit more documentation about bugs in how timeout is actually handled ``` let timeout = if let Some(mut t) = timeout { t.as_mut() as *mut libc::timespec } else { ptr::null_mut() }; ``` --- src/sys/socket/mod.rs | 367 ++++++++++++++++++++++++++++++---------- test/sys/test_socket.rs | 40 ++--- 2 files changed, 292 insertions(+), 115 deletions(-) diff --git a/src/sys/socket/mod.rs b/src/sys/socket/mod.rs index a45dd9ed67..bc60e30bde 100644 --- a/src/sys/socket/mod.rs +++ b/src/sys/socket/mod.rs @@ -1508,122 +1508,315 @@ pub fn sendmmsg<'a, I, C, S>( target_os = "netbsd", ))] #[derive(Debug)] -pub struct RecvMmsgData<'a, I> +/// Preallocated structures needed for [`recvmmsg`] function +pub struct RecvMMsg { + // preallocated boxed slice of mmsghdr + items: Box<[libc::mmsghdr]>, + addresses: Box<[mem::MaybeUninit]>, + // while we are not using it directly - this is used to store control messages + // and we retain pointers to them inside items array + #[allow(dead_code)] + cmsg_buffers: Option>, + msg_controllen: usize, +} + +#[cfg(any( + target_os = "linux", + target_os = "android", + target_os = "freebsd", + target_os = "netbsd", +))] +impl RecvMMsg { + /// Preallocate structure used by [`recvmmsg`], takes number of headers to preallocate + /// + /// `cmsg_buffer` should be created with [`cmsg_space!`] if needed + pub fn preallocate(num_slices: usize, cmsg_buffer: Option>) -> Self where - I: AsRef<[IoSliceMut<'a>]> + 'a, -{ - pub iov: I, - pub cmsg_buffer: Option<&'a mut Vec>, + S: Copy + SockaddrLike, + { + // we will be storing pointers to addresses inside mhdr - convert it into boxed + // slice so it can'be changed later by 
pushing anything into self.addresses + let mut addresses = vec![std::mem::MaybeUninit::uninit(); num_slices].into_boxed_slice(); + + let msg_controllen = cmsg_buffer.as_ref().map_or(0, |v| v.capacity()); + + // we'll need a cmsg_buffer for each slice, we preallocate a vector and split + // it into "slices" parts + let cmsg_buffers = + cmsg_buffer.map(|v| vec![0u8; v.capacity() * num_slices].into_boxed_slice()); + + let items = addresses + .iter_mut() + .enumerate() + .map(|(ix, address)| { + let (ptr, cap) = match &cmsg_buffers { + Some(v) => ((&v[ix * msg_controllen] as *const u8), msg_controllen), + None => (std::ptr::null(), 0), + }; + let msg_hdr = unsafe { pack_mhdr_to_receive(std::ptr::null(), 0, ptr, cap, address.as_mut_ptr()) }; + libc::mmsghdr { + msg_hdr, + msg_len: 0, + } + }) + .collect::>(); + + Self { + items: items.into_boxed_slice(), + addresses, + cmsg_buffers, + msg_controllen, + } + } } -/// An extension of `recvmsg` that allows the caller to receive multiple -/// messages from a socket using a single system call. This has -/// performance benefits for some applications. +/// An extension of recvmsg that allows the caller to receive multiple messages from a socket using a single system call. /// -/// `iov` and `cmsg_buffer` should be constructed similarly to `recvmsg` +/// This has performance benefits for some applications. /// -/// Multiple allocations are performed +/// This method performs no allications. /// -/// # Arguments -/// -/// * `fd`: Socket file descriptor -/// * `data`: Struct that implements `IntoIterator` with `RecvMmsgData` items -/// * `flags`: Optional flags passed directly to the operating system. -/// -/// # RecvMmsgData -/// -/// * `iov`: Scatter-gather list of buffers to receive the message -/// * `cmsg_buffer`: Space to receive ancillary data. Should be created by -/// [`cmsg_space!`](../../macro.cmsg_space.html) +/// Returns an iterator producing [`RecvMsg`], one per received messages. 
Each `RecvMsg` can produce +/// iterators over [`IoSlice`] with [`iovs`][RecvMsg::iovs] and +/// `ControlMessageOwned` with [`cmsgs`][RecvMsg::cmsgs]. /// -/// # Returns -/// A `Vec` with multiple `RecvMsg`, one per received message +/// # Bugs (in underlying implementation, at least in Linux) +/// The timeout argument does not work as intended. The timeout is checked only after the receipt +/// of each datagram, so that if up to `vlen`-1 datagrams are received before the timeout expires, +/// but then no further datagrams are received, the call will block forever. /// -/// # References -/// - [`recvmsg`](fn.recvmsg.html) -/// - [`RecvMsg`](struct.RecvMsg.html) +/// If an error occurs after at least one message has been received, the call succeeds, and returns +/// the number of messages received. The error code is expected to be returned on a subsequent +/// call to recvmmsg(). In the current implementation, however, the error code can be +/// overwritten in the meantime by an unrelated network event on a socket, for example an +/// incoming ICMP packet.
#[cfg(any( target_os = "linux", target_os = "android", target_os = "freebsd", target_os = "netbsd", ))] -#[allow(clippy::needless_collect)] // Complicated false positive -pub fn recvmmsg<'a, I, S>( +pub fn recvmmsg<'a, XS, S, I>( fd: RawFd, - data: impl std::iter::IntoIterator, - IntoIter=impl ExactSizeIterator + Iterator>>, + data: &'a mut RecvMMsg, + slices: XS, flags: MsgFlags, - timeout: Option -) -> Result>> - where - I: AsRef<[IoSliceMut<'a>]> + 'a, - S: Copy + SockaddrLike + 'a + mut timeout: Option, +) -> crate::Result> +where + XS: ExactSizeIterator, + I: AsRef<[IoSliceMut<'a>]>, { - let iter = data.into_iter(); + let count = std::cmp::min(slices.len(), data.items.len()); + + for (slice, mmsghdr) in slices.zip(data.items.iter_mut()) { + let mut p = &mut mmsghdr.msg_hdr; + p.msg_iov = slice.as_ref().as_ptr() as *mut libc::iovec; + p.msg_iovlen = slice.as_ref().len() as _; + } + + let timeout_ptr = timeout + .as_mut() + .map_or_else(std::ptr::null_mut, |t| t as *mut _ as *mut libc::timespec); + + let received = Errno::result(unsafe { + libc::recvmmsg( + fd, + data.items.as_mut_ptr(), + count as _, + flags.bits() as _, + timeout_ptr, + ) + })? as usize; + + Ok(RecvMMsgItems { + rmm: data, + current_index: 0, + received, + }) +} + +#[cfg(any( + target_os = "linux", + target_os = "android", + target_os = "freebsd", + target_os = "netbsd", +))] +#[derive(Debug)] +pub struct RecvMMsgItems<'a, S> { + // preallocated structures + rmm: &'a RecvMMsg, + current_index: usize, + received: usize, +} + +#[cfg(any( + target_os = "linux", + target_os = "android", + target_os = "freebsd", + target_os = "netbsd", +))] +impl<'a, S> Iterator for RecvMMsgItems<'a, S> +where + S: Copy + SockaddrLike, +{ + type Item = RecvMsg<'a, S>; - let num_messages = iter.len(); - - let mut output: Vec = Vec::with_capacity(num_messages); - - // Addresses should be pre-allocated. pack_mhdr_to_receive will store them - // as raw pointers, so we may not move them. 
Turn the vec into a boxed - // slice so we won't inadvertently reallocate the vec. - let mut addresses = vec![mem::MaybeUninit::uninit(); num_messages] - .into_boxed_slice(); - - let results: Vec<_> = iter.enumerate().map(|(i, d)| { - let (msg_control, msg_controllen) = d.cmsg_buffer.as_mut() - .map(|v| (v.as_mut_ptr(), v.capacity())) - .unwrap_or((ptr::null_mut(), 0)); - let mhdr = unsafe { - pack_mhdr_to_receive( - d.iov.as_ref().as_ptr(), - d.iov.as_ref().len(), - msg_control, - msg_controllen, - addresses[i].as_mut_ptr(), + fn next(&mut self) -> Option { + if self.current_index >= self.received { + return None; + } + let mmsghdr = self.rmm.items[self.current_index]; + + // as long as we are not reading past the index writen by recvmmsg - address + // will be initialized + let address = unsafe { self.rmm.addresses[self.current_index].assume_init() }; + + self.current_index += 1; + Some(unsafe { + read_mhdr( + mmsghdr.msg_hdr, + mmsghdr.msg_len as isize, + self.rmm.msg_controllen, + address, ) + }) + } +} + +impl<'a, S> RecvMsg<'a, S> { + /// Iterate over the filled io slices pointed by this msghdr + pub fn iovs(&self) -> IoSliceIterator { + IoSliceIterator { + index: 0, + remaining: self.bytes, + slices: unsafe { + std::slice::from_raw_parts(self.mhdr.msg_iov as *const _, self.mhdr.msg_iovlen as _) + }, + } + } +} + +#[derive(Debug)] +pub struct IoSliceIterator<'a> { + index: usize, + remaining: usize, + slices: &'a [IoSlice<'a>], +} + +impl<'a> Iterator for IoSliceIterator<'a> { + type Item = &'a [u8]; + + fn next(&mut self) -> Option { + if self.index >= self.slices.len() { + return None; + } + let slice = &self.slices[self.index][..self.remaining.min(self.slices[self.index].len())]; + self.remaining -= slice.len(); + self.index += 1; + if slice.is_empty() { + return None; + } + + Some(slice) + } +} + +// test contains both recvmmsg and timestaping which is linux only +// there are existing tests for recvmmsg only in tests/ +#[cfg(target_os = "linux")] 
+#[cfg(test)] +mod test { + use crate::sys::socket::{AddressFamily, ControlMessageOwned}; + use crate::*; + use std::str::FromStr; + + #[cfg_attr(qemu, ignore)] + #[test] + fn test_recvmm2() -> crate::Result<()> { + use crate::sys::socket::{ + sendmsg, setsockopt, socket, sockopt::Timestamping, MsgFlags, SockFlag, SockType, + SockaddrIn, TimestampingFlag, }; + use std::io::{IoSlice, IoSliceMut}; + + let sock_addr = SockaddrIn::from_str("127.0.0.1:6790").unwrap(); + + let ssock = socket( + AddressFamily::Inet, + SockType::Datagram, + SockFlag::empty(), + None, + )?; + + let rsock = socket( + AddressFamily::Inet, + SockType::Datagram, + SockFlag::SOCK_NONBLOCK, + None, + )?; + + crate::sys::socket::bind(rsock, &sock_addr)?; + + setsockopt(rsock, Timestamping, &TimestampingFlag::all())?; - output.push( - libc::mmsghdr { - msg_hdr: mhdr, - msg_len: 0, + let sbuf = (0..400).map(|i| i as u8).collect::>(); + + let mut recv_buf = vec![0; 1024]; + + let mut recv_iovs = Vec::new(); + let mut pkt_iovs = Vec::new(); + + for (ix, chunk) in recv_buf.chunks_mut(256).enumerate() { + pkt_iovs.push(IoSliceMut::new(chunk)); + if ix % 2 == 1 { + recv_iovs.push(pkt_iovs); + pkt_iovs = Vec::new(); } - ); + } + drop(pkt_iovs); - msg_controllen as usize - }).collect(); + let flags = MsgFlags::empty(); + let iov1 = [IoSlice::new(&sbuf)]; - let timeout = if let Some(mut t) = timeout { - t.as_mut() as *mut libc::timespec - } else { - ptr::null_mut() - }; + let cmsg = cmsg_space!(crate::sys::socket::Timestamps); + sendmsg(ssock, &iov1, &[], flags, Some(&sock_addr)).unwrap(); - let ret = unsafe { libc::recvmmsg(fd, output.as_mut_ptr(), output.len() as _, flags.bits() as _, timeout) }; + let mut data = super::RecvMMsg::<()>::preallocate(recv_iovs.len(), Some(cmsg)); - let _ = Errno::result(ret)?; + let t = sys::time::TimeSpec::from_duration(std::time::Duration::from_secs(10)); - Ok(output - .into_iter() - .take(ret as usize) - .zip(addresses.iter().map(|addr| unsafe{addr.assume_init()})) - 
.zip(results.into_iter()) - .map(|((mmsghdr, address), msg_controllen)| { - unsafe { - read_mhdr( - mmsghdr.msg_hdr, - mmsghdr.msg_len as isize, - msg_controllen, - address, - ) + let recv = super::recvmmsg(rsock, &mut data, recv_iovs.iter(), flags, Some(t))?; + + for rmsg in recv { + let mut saw_time = false; + let mut recvd = 0; + for cmsg in rmsg.cmsgs() { + if let ControlMessageOwned::ScmTimestampsns(timestamps) = cmsg { + let ts = timestamps.system; + + let sys_time = + crate::time::clock_gettime(crate::time::ClockId::CLOCK_REALTIME)?; + let diff = if ts > sys_time { + ts - sys_time + } else { + sys_time - ts + }; + assert!(std::time::Duration::from(diff).as_secs() < 60); + saw_time = true; + } } - }) - .collect()) -} + assert!(saw_time); + + for iov in rmsg.iovs() { + recvd += iov.len(); + } + assert_eq!(recvd, 400); + } + Ok(()) + } +} unsafe fn read_mhdr<'a, S>( mhdr: msghdr, r: isize, diff --git a/test/sys/test_socket.rs b/test/sys/test_socket.rs index c742960ae8..4ff58d8017 100644 --- a/test/sys/test_socket.rs +++ b/test/sys/test_socket.rs @@ -521,18 +521,12 @@ mod recvfrom { // Buffers to receive exactly `NUM_MESSAGES_SENT` messages let mut receive_buffers = [[0u8; 32]; NUM_MESSAGES_SENT]; - let iovs: Vec<_> = receive_buffers.iter_mut().map(|buf| { + msgs.extend(receive_buffers.iter_mut().map(|buf| { [IoSliceMut::new(&mut buf[..])] - }).collect(); + })); + let mut data = RecvMMsg::::preallocate(msgs.len(), None); - for iov in &iovs { - msgs.push_back(RecvMmsgData { - iov, - cmsg_buffer: None, - }) - }; - - let res: Vec> = recvmmsg(rsock, &mut msgs, MsgFlags::empty(), None).expect("recvmmsg"); + let res: Vec> = recvmmsg(rsock, &mut data, msgs.iter(), MsgFlags::empty(), None).expect("recvmmsg").collect(); assert_eq!(res.len(), DATA.len()); for RecvMsg { address, bytes, .. } in res.into_iter() { @@ -592,18 +586,13 @@ mod recvfrom { // will return when there are fewer than requested messages in the // kernel buffers when using `MSG_DONTWAIT`. 
let mut receive_buffers = [[0u8; 32]; NUM_MESSAGES_SENT + 2]; - let iovs: Vec<_> = receive_buffers.iter_mut().map(|buf| { + msgs.extend(receive_buffers.iter_mut().map(|buf| { [IoSliceMut::new(&mut buf[..])] - }).collect(); + })); - for iov in &iovs { - msgs.push_back(RecvMmsgData { - iov, - cmsg_buffer: None, - }) - }; + let mut data = RecvMMsg::::preallocate(NUM_MESSAGES_SENT + 2, None); - let res: Vec> = recvmmsg(rsock, &mut msgs, MsgFlags::MSG_DONTWAIT, None).expect("recvmmsg"); + let res: Vec> = recvmmsg(rsock, &mut data, msgs.iter(), MsgFlags::MSG_DONTWAIT, None).expect("recvmmsg").collect(); assert_eq!(res.len(), NUM_MESSAGES_SENT); for RecvMsg { address, bytes, .. } in res.into_iter() { @@ -1710,15 +1699,10 @@ fn test_recvmmsg_timestampns() { assert_eq!(message.len(), l); // Receive the message let mut buffer = vec![0u8; message.len()]; - let mut cmsgspace = nix::cmsg_space!(TimeSpec); - let iov = [IoSliceMut::new(&mut buffer)]; - let mut data = vec![ - RecvMmsgData { - iov, - cmsg_buffer: Some(&mut cmsgspace), - }, - ]; - let r: Vec> = recvmmsg(in_socket, &mut data, flags, None).unwrap(); + let cmsgspace = nix::cmsg_space!(TimeSpec); + let iov = vec![[IoSliceMut::new(&mut buffer)]]; + let mut data = RecvMMsg::preallocate(1, Some(cmsgspace)); + let r: Vec> = recvmmsg(in_socket, &mut data, iov.iter(), flags, None).unwrap().collect(); let rtime = match r[0].cmsgs().next() { Some(ControlMessageOwned::ScmTimestampns(rtime)) => rtime, Some(_) => panic!("Unexpected control message"),