Skip to content

Commit

Permalink
also support sendmmsg
Browse files Browse the repository at this point in the history
renames:
    RecvMMsg -> MultHdrs
    RecvMMsgItems -> MultiResults

Adding a lifetime reference to RecvMsg
The name is not 100% correct now, it can be useful
for both sending and receiving messages: to collect hardware
sending timestamps you need to use control messages as well
  • Loading branch information
pacak committed Jun 24, 2022
1 parent f185d6d commit 1ae005d
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 96 deletions.
155 changes: 83 additions & 72 deletions src/sys/socket/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -574,15 +574,20 @@ macro_rules! cmsg_space {
}

#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct RecvMsg<'a, S> {
/// Contains outcome of sending or receiving a message
///
/// Use [`cmsgs`][RecvMsg::cmsgs] to access all the control messages present, and
/// [`iovs`][RecvMsg::iovs`] to access underlying io slices.
pub struct RecvMsg<'a, 's, S> {
pub bytes: usize,
cmsghdr: Option<&'a cmsghdr>,
pub address: Option<S>,
pub flags: MsgFlags,
iobufs: std::marker::PhantomData<& 's()>,
mhdr: msghdr,
}

impl<'a, S> RecvMsg<'a, S> {
impl<'a, S> RecvMsg<'a, '_, S> {
/// Iterate over the valid control messages pointed to by this
/// msghdr.
pub fn cmsgs(&self) -> CmsgIterator {
Expand Down Expand Up @@ -1411,24 +1416,6 @@ pub fn sendmsg<S>(fd: RawFd, iov: &[IoSlice<'_>], cmsgs: &[ControlMessage],
Errno::result(ret).map(|r| r as usize)
}

#[cfg(any(
target_os = "linux",
target_os = "android",
target_os = "freebsd",
target_os = "netbsd",
))]
#[derive(Debug)]
pub struct SendMmsgData<'a, I, C, S>
where
I: AsRef<[IoSlice<'a>]>,
C: AsRef<[ControlMessage<'a>]>,
S: SockaddrLike + 'a
{
pub iov: I,
pub cmsgs: C,
pub addr: Option<S>,
pub _lt: std::marker::PhantomData<&'a I>,
}

/// An extension of `sendmsg` that allows the caller to transmit multiple
/// messages on a socket using a single system call. This has performance
Expand All @@ -1453,51 +1440,66 @@ pub struct SendMmsgData<'a, I, C, S>
target_os = "freebsd",
target_os = "netbsd",
))]
pub fn sendmmsg<'a, I, C, S>(
pub fn sendmmsg<'a, XS, AS, C, I, S>(
fd: RawFd,
data: impl std::iter::IntoIterator<Item=&'a SendMmsgData<'a, I, C, S>>,
data: &'a mut MultHdrs<S>,
slices: XS,
// one address per group of slices
addrs: AS,
// shared across all the messages
cmsgs: C,
flags: MsgFlags
) -> Result<Vec<usize>>
) -> crate::Result<MultiResults<'a, S>>
where
XS: IntoIterator<Item = I>,
AS: AsRef<[Option<S>]>,
I: AsRef<[IoSlice<'a>]> + 'a,
C: AsRef<[ControlMessage<'a>]> + 'a,
S: SockaddrLike + 'a
{
let iter = data.into_iter();

let size_hint = iter.size_hint();
let reserve_items = size_hint.1.unwrap_or(size_hint.0);
let mut count = 0;

let mut output = Vec::<libc::mmsghdr>::with_capacity(reserve_items);

let mut cmsgs_buffers = Vec::<Vec<u8>>::with_capacity(reserve_items);

for d in iter {
let capacity: usize = d.cmsgs.as_ref().iter().map(|c| c.space()).sum();
let mut cmsgs_buffer = vec![0u8; capacity];
for (i, ((slice, addr), mmsghdr)) in slices.into_iter().zip(addrs.as_ref()).zip(data.items.iter_mut() ).enumerate() {
let mut p = &mut mmsghdr.msg_hdr;
p.msg_iov = slice.as_ref().as_ptr() as *mut libc::iovec;
p.msg_iovlen = slice.as_ref().len() as _;

output.push(libc::mmsghdr {
msg_hdr: pack_mhdr_to_send(
&mut cmsgs_buffer,
&d.iov,
&d.cmsgs,
d.addr.as_ref()
),
msg_len: 0,
});
cmsgs_buffers.push(cmsgs_buffer);
};
(*p).msg_namelen = addr.as_ref().map_or(0, S::len);
(*p).msg_name = addr.as_ref().map_or(ptr::null(), S::as_ptr) as _;

// Encode each cmsg. This must happen after initializing the header because
// CMSG_NEXT_HDR and friends read the msg_control and msg_controllen fields.
// CMSG_FIRSTHDR is always safe
let mut pmhdr: *mut cmsghdr = unsafe { CMSG_FIRSTHDR(p) };
for cmsg in cmsgs.as_ref() {
assert_ne!(pmhdr, ptr::null_mut());
// Safe because we know that pmhdr is valid, and we initialized it with
// sufficient space
unsafe { cmsg.encode_into(pmhdr) };
// Safe because mhdr is valid
pmhdr = unsafe { CMSG_NXTHDR(p, pmhdr) };
}

let ret = unsafe { libc::sendmmsg(fd, output.as_mut_ptr(), output.len() as _, flags.bits() as _) };
count = i+1;
}

let sent_messages = Errno::result(ret)? as usize;
let mut sent_bytes = Vec::with_capacity(sent_messages);
let sent = Errno::result(unsafe {
libc::sendmmsg(
fd,
data.items.as_mut_ptr(),
count as _,
flags.bits() as _
)
})? as usize;

for item in &output {
sent_bytes.push(item.msg_len as usize);
}
Ok(MultiResults {
rmm: data,
current_index: 0,
received: sent
})

Ok(sent_bytes)
}


Expand All @@ -1508,8 +1510,8 @@ pub fn sendmmsg<'a, I, C, S>(
target_os = "netbsd",
))]
#[derive(Debug)]
/// Preallocated structures needed for [`recvmmsg`] function
pub struct RecvMMsg<S> {
/// Preallocated structures needed for [`recvmmsg`] and [`sendmmsg`] functions
pub struct MultHdrs<S> {
// preallocated boxed slice of mmsghdr
items: Box<[libc::mmsghdr]>,
addresses: Box<[mem::MaybeUninit<S>]>,
Expand All @@ -1526,8 +1528,8 @@ pub struct RecvMMsg<S> {
target_os = "freebsd",
target_os = "netbsd",
))]
impl<S> RecvMMsg<S> {
/// Preallocate structure used by [`recvmmsg`], takes number of headers to preallocate
impl<S> MultHdrs<S> {
/// Preallocate structure used by [`recvmmsg`] and [`sendmmsg`] takes number of headers to preallocate
///
/// `cmsg_buffer` should be created with [`cmsg_space!`] if needed
pub fn preallocate(num_slices: usize, cmsg_buffer: Option<Vec<u8>>) -> Self
Expand Down Expand Up @@ -1598,21 +1600,21 @@ impl<S> RecvMMsg<S> {
))]
pub fn recvmmsg<'a, XS, S, I>(
fd: RawFd,
data: &'a mut RecvMMsg<S>,
data: &'a mut MultHdrs<S>,
slices: XS,
flags: MsgFlags,
mut timeout: Option<crate::sys::time::TimeSpec>,
) -> crate::Result<RecvMMsgItems<'a, S>>
) -> crate::Result<MultiResults<'a, S>>
where
XS: ExactSizeIterator<Item = I>,
XS: IntoIterator<Item = I>,
I: AsRef<[IoSliceMut<'a>]>,
{
let count = std::cmp::min(slices.len(), data.items.len());

for (slice, mmsghdr) in slices.zip(data.items.iter_mut()) {
let mut count = 0;
for (i, (slice, mmsghdr)) in slices.into_iter().zip(data.items.iter_mut()).enumerate() {
let mut p = &mut mmsghdr.msg_hdr;
p.msg_iov = slice.as_ref().as_ptr() as *mut libc::iovec;
p.msg_iovlen = slice.as_ref().len() as _;
count = i + 1;
}

let timeout_ptr = timeout
Expand All @@ -1629,7 +1631,7 @@ where
)
})? as usize;

Ok(RecvMMsgItems {
Ok(MultiResults {
rmm: data,
current_index: 0,
received,
Expand All @@ -1643,9 +1645,12 @@ where
target_os = "netbsd",
))]
#[derive(Debug)]
pub struct RecvMMsgItems<'a, S> {
/// Iterator over results of [`recvmmsg`]/[`sendmmsg`]
///
///
pub struct MultiResults<'a, S> {
// preallocated structures
rmm: &'a RecvMMsg<S>,
rmm: &'a MultHdrs<S>,
current_index: usize,
received: usize,
}
Expand All @@ -1656,11 +1661,11 @@ pub struct RecvMMsgItems<'a, S> {
target_os = "freebsd",
target_os = "netbsd",
))]
impl<'a, S> Iterator for RecvMMsgItems<'a, S>
impl<'a, S> Iterator for MultiResults<'a, S>
where
S: Copy + SockaddrLike,
{
type Item = RecvMsg<'a, S>;
type Item = RecvMsg<'a, 'a, S>;

fn next(&mut self) -> Option<Self::Item> {
if self.current_index >= self.received {
Expand All @@ -1684,13 +1689,17 @@ where
}
}

impl<'a, S> RecvMsg<'a, S> {
impl<'a, S> RecvMsg<'_, 'a, S> {
/// Iterate over the filled io slices pointed by this msghdr
pub fn iovs(&self) -> IoSliceIterator {
pub fn iovs(&self) -> IoSliceIterator<'a> {
IoSliceIterator {
index: 0,
remaining: self.bytes,
slices: unsafe {
// safe for as long as mgdr is properly initialized and references are valid.
// for multi messages API we initialize it with an empty
// slice and replace with a concrete buffer
// for single message API we hold a lifetime reference to ioslices
std::slice::from_raw_parts(self.mhdr.msg_iov as *const _, self.mhdr.msg_iovlen as _)
},
}
Expand Down Expand Up @@ -1782,7 +1791,7 @@ mod test {
let cmsg = cmsg_space!(crate::sys::socket::Timestamps);
sendmsg(ssock, &iov1, &[], flags, Some(&sock_addr)).unwrap();

let mut data = super::RecvMMsg::<()>::preallocate(recv_iovs.len(), Some(cmsg));
let mut data = super::MultHdrs::<()>::preallocate(recv_iovs.len(), Some(cmsg));

let t = sys::time::TimeSpec::from_duration(std::time::Duration::from_secs(10));

Expand Down Expand Up @@ -1817,12 +1826,12 @@ mod test {
Ok(())
}
}
unsafe fn read_mhdr<'a, S>(
unsafe fn read_mhdr<'a, 'i, S>(
mhdr: msghdr,
r: isize,
msg_controllen: usize,
address: S,
) -> RecvMsg<'a, S>
) -> RecvMsg<'a, 'i, S>
where S: SockaddrLike
{
let cmsghdr = {
Expand All @@ -1841,6 +1850,7 @@ unsafe fn read_mhdr<'a, S>(
address: Some(address),
flags: MsgFlags::from_bits_truncate(mhdr.msg_flags),
mhdr,
iobufs: std::marker::PhantomData,
}
}

Expand Down Expand Up @@ -1948,8 +1958,9 @@ fn pack_mhdr_to_send<'a, I, C, S>(
/// [recvmsg(2)](https://pubs.opengroup.org/onlinepubs/9699919799/functions/recvmsg.html)
pub fn recvmsg<'a, 'outer, 'inner, S>(fd: RawFd, iov: &'outer mut [IoSliceMut<'inner>],
mut cmsg_buffer: Option<&'a mut Vec<u8>>,
flags: MsgFlags) -> Result<RecvMsg<'a, S>>
where S: SockaddrLike + 'a
flags: MsgFlags) -> Result<RecvMsg<'a, 'inner, S>>
where S: SockaddrLike + 'a,
'inner: 'outer
{
let mut address = mem::MaybeUninit::uninit();

Expand Down
48 changes: 24 additions & 24 deletions test/sys/test_socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -501,31 +501,31 @@ mod recvfrom {
rsock,
ssock,
move |s, m, flags| {
let iov = [IoSlice::new(m)];
let mut msgs = vec![SendMmsgData {
iov: &iov,
cmsgs: &[],
addr: Some(sock_addr),
_lt: Default::default(),
}];

let batch_size = 15;
let mut iovs = Vec::with_capacity(1 + batch_size);
let mut addrs = Vec::with_capacity(1 + batch_size);
let mut data = MultHdrs::preallocate(1 + batch_size, None);
let iov = IoSlice::new(m);
// first chunk:
iovs.push([iov]);
addrs.push(Some(sock_addr));

for _ in 0..batch_size {
msgs.push(SendMmsgData {
iov: &iov,
cmsgs: &[],
addr: Some(sock_addr2),
_lt: Default::default(),
});
iovs.push([iov]);
addrs.push(Some(sock_addr2));
}
sendmmsg(s, msgs.iter(), flags).map(move |sent_bytes| {
assert!(!sent_bytes.is_empty());
for sent in &sent_bytes {
assert_eq!(*sent, m.len());
}
sent_bytes.len()
})

let res = sendmmsg(s, &mut data, &iovs, addrs, &[], flags)?;
let mut sent_messages = 0;
let mut sent_bytes = 0;
for item in res {
sent_messages += 1;
sent_bytes += item.bytes;
}
//
assert_eq!(sent_messages, iovs.len());
assert_eq!(sent_bytes, sent_messages * m.len());
Ok(sent_messages)
},
|_, _| {},
);
Expand Down Expand Up @@ -582,7 +582,7 @@ mod recvfrom {
.iter_mut()
.map(|buf| [IoSliceMut::new(&mut buf[..])]),
);
let mut data = RecvMMsg::<SockaddrIn>::preallocate(msgs.len(), None);
let mut data = MultHdrs::<SockaddrIn>::preallocate(msgs.len(), None);

let res: Vec<RecvMsg<SockaddrIn>> =
recvmmsg(rsock, &mut data, msgs.iter(), MsgFlags::empty(), None)
Expand Down Expand Up @@ -658,7 +658,7 @@ mod recvfrom {
);

let mut data =
RecvMMsg::<SockaddrIn>::preallocate(NUM_MESSAGES_SENT + 2, None);
MultHdrs::<SockaddrIn>::preallocate(NUM_MESSAGES_SENT + 2, None);

let res: Vec<RecvMsg<SockaddrIn>> = recvmmsg(
rsock,
Expand Down Expand Up @@ -1943,7 +1943,7 @@ fn test_recvmmsg_timestampns() {
let mut buffer = vec![0u8; message.len()];
let cmsgspace = nix::cmsg_space!(TimeSpec);
let iov = vec![[IoSliceMut::new(&mut buffer)]];
let mut data = RecvMMsg::preallocate(1, Some(cmsgspace));
let mut data = MultHdrs::preallocate(1, Some(cmsgspace));
let r: Vec<RecvMsg<()>> =
recvmmsg(in_socket, &mut data, iov.iter(), flags, None)
.unwrap()
Expand Down

0 comments on commit 1ae005d

Please sign in to comment.