diff --git a/Cargo.toml b/Cargo.toml index 8ef85ebc..8e9493d8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,18 +1,16 @@ [package] -name = "socket2" -version = "0.5.7" +name = "socket2-plus" +version = "0.1.1" authors = [ "Alex Crichton ", - "Thomas de Zeeuw " + "Thomas de Zeeuw ", + "keepsimple1 " ] license = "MIT OR Apache-2.0" readme = "README.md" -repository = "https://github.com/rust-lang/socket2" -homepage = "https://github.com/rust-lang/socket2" -documentation = "https://docs.rs/socket2" +repository = "https://github.com/keepsimple1/socket2-plus" description = """ -Utilities for handling networking sockets with a maximal amount of configuration -possible intended. +A superset of socket2 that focuses on safe APIs """ keywords = ["io", "socket", "network"] categories = ["api-bindings", "network-programming"] diff --git a/README.md b/README.md index 8bb09495..26fa4349 100644 --- a/README.md +++ b/README.md @@ -1,67 +1,14 @@ -# Socket2 +# Socket2-plus -Socket2 is a crate that provides utilities for creating and using sockets. +This is a superset of [`socket2`](https://crates.io/crates/socket2) that aims to provide some additional APIs currently missing from `socket2`. This library can be used by a dropped-in replacement for socket2. -The goal of this crate is to create and use a socket using advanced -configuration options (those that are not available in the types in the standard -library) without using any unsafe code. +The following APIs are added in the first version: -This crate provides as direct as possible access to the system's functionality -for sockets, this means little effort to provide cross-platform utilities. It is -up to the user to know how to use sockets when using this crate. *If you don't -know how to create a socket using libc/system calls then this crate is not for -you*. Most, if not all, functions directly relate to the equivalent system call -with no error handling applied, so no handling errors such as `EINTR`. As a -result using this crate can be a little wordy, but it should give you maximal -flexibility over configuration of sockets. +- `recv_from_initialized` to support `recv_from` with a regular initialized buffer. +- `recvmsg_initialized` to support `recvmsg` with `MsgHdrInit` that has initialized buffers. +- Also support Windows for `recvmsg_initialized`. -See the [API documentation] for more. - -[API documentation]: https://docs.rs/socket2 - -# Branches - -Currently Socket2 supports two versions: v0.5 and v0.4. Version 0.5 is being -developed in the master branch. Version 0.4 is developed in the [v0.4.x branch] -branch. - -[v0.4.x branch]: https://github.com/rust-lang/socket2/tree/v0.4.x - -# OS support - -Socket2 attempts to support the same OS/architectures as Rust does, see -https://doc.rust-lang.org/nightly/rustc/platform-support.html. However this is -not always possible, below is current list of support OSs. - -*If your favorite OS is not on the list consider contributing it! See [issue -#78].* - -[issue #78]: https://github.com/rust-lang/socket2/issues/78 - -### Tier 1 - -These OSs are tested with each commit in the CI and must always pass the tests. -All functions/types/etc., excluding ones behind the `all` feature, must work on -these OSs. - -* Linux -* macOS -* Windows - -### Tier 2 - -These OSs are currently build in the CI, but not tested. Not all -functions/types/etc. may work on these OSs, even ones **not** behind the `all` -feature flag. - -* Android -* FreeBSD -* Fuchsia -* iOS -* illumos -* NetBSD -* Redox -* Solaris +This first version is forked from `socket2` v0.5.7. We plan to rebase to the latest `socket2` stable release regularly. # Minimum Supported Rust Version (MSRV) diff --git a/src/lib.rs b/src/lib.rs index 89dce26e..8a2870f3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -58,17 +58,17 @@ // Disallow warnings in examples. #![doc(test(attr(deny(warnings))))] -use std::fmt; #[cfg(not(target_os = "redox"))] -use std::io::IoSlice; +use std::io::{IoSlice, IoSliceMut}; #[cfg(not(target_os = "redox"))] use std::marker::PhantomData; #[cfg(not(target_os = "redox"))] use std::mem; use std::mem::MaybeUninit; -use std::net::SocketAddr; +use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr}; use std::ops::{Deref, DerefMut}; use std::time::Duration; +use std::{fmt, ptr}; /// Macro to implement `fmt::Debug` for a type, printing the constant names /// rather than a number. @@ -736,3 +736,242 @@ impl<'name, 'bufs, 'control> fmt::Debug for MsgHdrMut<'name, 'bufs, 'control> { "MsgHdrMut".fmt(fmt) } } + +/// Configuration of a `recvmsg(2)` system call with initialized buffers. +/// +/// This wraps `msghdr` on Unix and `WSAMSG` on Windows and supports +/// fully initialized buffers. +#[cfg(not(target_os = "redox"))] +pub struct MsgHdrInit { + inner: sys::msghdr, +} + +#[cfg(not(target_os = "redox"))] +impl MsgHdrInit { + /// Create a new `MsgHdrInit` with all empty/zero fields. + #[allow(clippy::new_without_default)] + pub fn new() -> MsgHdrInit { + // SAFETY: all zero is valid for `msghdr` and `WSAMSG`. + MsgHdrInit { + inner: unsafe { mem::zeroed() }, + } + } + + /// Set the mutable address buffer to store the source address. + /// + /// Corresponds to setting `msg_name` and `msg_namelen` on Unix and `name` + /// and `namelen` on Windows. + #[allow(clippy::needless_pass_by_ref_mut)] + pub fn with_addr(mut self, addr: &mut SockAddr) -> Self { + sys::set_msghdr_name(&mut self.inner, addr); + self + } + + /// Set the mutable array of buffers for receiving the message. + /// + /// Corresponds to setting `msg_iov` and `msg_iovlen` on Unix and `lpBuffers` + /// and `dwBufferCount` on Windows. + /// + /// For example: using only a single buffer of 1k bytes: + /// ```ignore + /// let mut buffer = vec![0; 1024]; + /// let mut buf_list = [IoSliceMut::new(&mut buffer)]; + /// ``` + pub fn with_buffers(mut self, buf_list: &mut [IoSliceMut<'_>]) -> Self { + sys::set_msghdr_iov( + &mut self.inner, + buf_list.as_mut_ptr().cast(), + buf_list.len(), + ); + self + } + + /// Set the mutable control buffer of the message. + /// + /// Corresponds to setting `msg_control` and `msg_controllen` on Unix and + /// `Control` on Windows. + pub fn with_control(mut self, buf: &mut [u8]) -> Self { + sys::set_msghdr_control(&mut self.inner, buf.as_mut_ptr().cast(), buf.len()); + self + } + + /// Returns the list of control message headers in the message. + /// + /// This decodes the control messages inside the ancillary data buffer. + pub fn cmsg_hdr_vec(&self) -> Vec> { + let mut cmsg_vec = Vec::new(); + + let mut cmsg = self.inner.cmsg_first_hdr(); + if !cmsg.is_null() { + let cmsg_hdr = unsafe { CMsgHdr { inner: &*cmsg } }; + cmsg_vec.push(cmsg_hdr); + + cmsg = self.inner.cmsg_next_hdr(unsafe { &*cmsg }); + while !cmsg.is_null() { + let cmsg_hdr = unsafe { CMsgHdr { inner: &*cmsg } }; + cmsg_vec.push(cmsg_hdr); + } + } + + cmsg_vec + } +} + +#[cfg(not(target_os = "redox"))] +impl fmt::Debug for MsgHdrInit { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + "MsgHdrInit".fmt(fmt) + } +} + +/// Common operations supported on `msghdr` +pub(crate) trait MsgHdrOps { + fn cmsg_first_hdr(&self) -> *mut sys::cmsghdr; + + fn cmsg_next_hdr(&self, cmsg: &sys::cmsghdr) -> *mut sys::cmsghdr; +} + +/// Reference of a control message header in the control buffer in `MsgHdrInit` +#[cfg(not(target_os = "redox"))] +pub struct CMsgHdr<'a> { + inner: &'a sys::cmsghdr, +} + +impl CMsgHdr<'_> { + /// Get the cmsg level + pub fn get_level(&self) -> CMsgLevel { + self.inner.cmsg_level + } + + /// Get the cmsg type + pub fn get_type(&self) -> CMsgType { + self.inner.cmsg_type + } + + /// Decode this header as IN_PKTINFO + pub fn as_pktinfo_v4(&self) -> Option { + if self.inner.cmsg_level != sys::IPPROTO_IP { + return None; + } + + if self.inner.cmsg_type != sys::IP_PKTINFO { + return None; + } + + let data_ptr = self.inner.cmsg_data(); + let pktinfo = unsafe { ptr::read_unaligned(data_ptr as *const sys::InPktInfo) }; + + #[cfg(not(windows))] + let addr_dst = Ipv4Addr::from(u32::from_be(pktinfo.ipi_addr.s_addr)); + + #[cfg(windows)] + let addr_dst = Ipv4Addr::from(u32::from_be(unsafe { pktinfo.ipi_addr.S_un.S_addr })); + + Some(PktInfoV4 { + if_index: pktinfo.ipi_ifindex as _, + addr_dst, + }) + } + + /// Decode this header as IN6_PKTINFO + pub fn as_recvpktinfo_v6(&self) -> Option { + if self.inner.cmsg_level != sys::IPPROTO_IPV6 { + return None; + } + + if self.inner.cmsg_type != sys::IPV6_PKTINFO { + return None; + } + + let data_ptr = self.inner.cmsg_data(); + let pktinfo = unsafe { ptr::read_unaligned(data_ptr as *const sys::In6PktInfo) }; + + #[cfg(windows)] + let addr_dst = Ipv6Addr::from(unsafe { pktinfo.ipi6_addr.u.Byte }); + + #[cfg(not(windows))] + let addr_dst = Ipv6Addr::from(pktinfo.ipi6_addr.s6_addr); + + Some(PktInfoV6 { + if_index: pktinfo.ipi6_ifindex as _, + addr_dst, + }) + } +} + +pub(crate) trait CMsgHdrOps { + /// Returns a pointer to the data portion of a cmsghdr. + fn cmsg_data(&self) -> *mut u8; +} + +/// Given a payload of `data_len`, returns the number of bytes a control message occupies. +/// i.e. it includes the header, the data and the alignments. +pub fn cmsg_space(data_len: usize) -> usize { + sys::_cmsg_space(data_len) +} + +#[cfg(not(target_os = "redox"))] +impl<'a> fmt::Debug for CMsgHdr<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "(len: {} level: {} type: {})", + self.inner.cmsg_len, self.inner.cmsg_level, self.inner.cmsg_type + ) + } +} + +const IN_PKTINFO_SIZE: usize = mem::size_of::(); +const IN6_PKTINFO_SIZE: usize = mem::size_of::(); + +/// Represents IN_PKTINFO structure. +#[derive(Debug)] +pub struct PktInfoV4 { + /// Interface index + pub if_index: u64, + + /// Header destination address + pub addr_dst: Ipv4Addr, +} + +impl PktInfoV4 { + /// The size in bytes for IPv4 pktinfo + pub const fn size() -> usize { + IN_PKTINFO_SIZE + } +} + +/// Represents IN6_PKTINFO structure. +#[derive(Debug)] +pub struct PktInfoV6 { + /// Interface index + pub if_index: u64, + + /// Header destination address + pub addr_dst: Ipv6Addr, +} + +impl PktInfoV6 { + /// The size in bytes for IPv6 pktinfo + pub const fn size() -> usize { + IN6_PKTINFO_SIZE + } +} + +/// Represents available protocols +pub type CMsgLevel = i32; + +/// constant for cmsg_level of IPPROTO_IP +pub const CMSG_LEVEL_IPPROTO_IP: CMsgLevel = sys::IPPROTO_IP; + +/// constant for cmsg_level of IPPROTO_IPV6 +pub const CMSG_LEVEL_IPPROTO_IPV6: CMsgLevel = sys::IPPROTO_IPV6; + +/// Represents available types of control messages. +pub type CMsgType = i32; + +/// constant for cmsghdr type +pub const CMSG_TYPE_IP_PKTINFO: CMsgType = sys::IP_PKTINFO; + +/// constant for cmsghdr type in IPv6 +pub const CMSG_TYPE_IPV6_PKTINFO: CMsgType = sys::IPV6_PKTINFO; diff --git a/src/socket.rs b/src/socket.rs index 21bcfcb2..4e18c073 100644 --- a/src/socket.rs +++ b/src/socket.rs @@ -23,7 +23,7 @@ use std::time::Duration; use crate::sys::{self, c_int, getsockopt, setsockopt, Bool}; #[cfg(all(unix, not(target_os = "redox")))] use crate::MsgHdrMut; -use crate::{Domain, Protocol, SockAddr, TcpKeepalive, Type}; +use crate::{Domain, MsgHdrInit, Protocol, SockAddr, TcpKeepalive, Type}; #[cfg(not(target_os = "redox"))] use crate::{MaybeUninitSlice, MsgHdr, RecvFlags}; @@ -74,6 +74,9 @@ use crate::{MaybeUninitSlice, MsgHdr, RecvFlags}; /// ``` pub struct Socket { inner: Inner, + + #[cfg(windows)] + wsarecvmsg: Option, } /// Store a `TcpStream` internally to take advantage of its niche optimizations on Unix platforms. @@ -87,25 +90,36 @@ impl Socket { /// function, often passed as mapping function, it's makes it very /// inconvenient to mark it as `unsafe`. pub(crate) fn from_raw(raw: sys::Socket) -> Socket { + let inner = unsafe { + // SAFETY: the caller must ensure that `raw` is a valid file + // descriptor, but when it isn't it could return I/O errors, or + // potentially close a fd it doesn't own. All of that isn't + // memory unsafe, so it's not desired but never memory unsafe or + // causes UB. + // + // However there is one exception. We use `TcpStream` to + // represent the `Socket` internally (see `Inner` type), + // `TcpStream` has a layout optimisation that doesn't allow for + // negative file descriptors (as those are always invalid). + // Violating this assumption (fd never negative) causes UB, + // something we don't want. So check for that we have this + // `assert!`. + #[cfg(unix)] + assert!(raw >= 0, "tried to create a `Socket` with an invalid fd"); + sys::socket_from_raw(raw) + }; + + #[cfg(windows)] + let wsarecvmsg = match sys::locate_wsarecvmsg(raw) { + Ok(fp) => Some(fp), + Err(_) => None, + }; + Socket { - inner: unsafe { - // SAFETY: the caller must ensure that `raw` is a valid file - // descriptor, but when it isn't it could return I/O errors, or - // potentially close a fd it doesn't own. All of that isn't - // memory unsafe, so it's not desired but never memory unsafe or - // causes UB. - // - // However there is one exception. We use `TcpStream` to - // represent the `Socket` internally (see `Inner` type), - // `TcpStream` has a layout optimisation that doesn't allow for - // negative file descriptors (as those are always invalid). - // Violating this assumption (fd never negative) causes UB, - // something we don't want. So check for that we have this - // `assert!`. - #[cfg(unix)] - assert!(raw >= 0, "tried to create a `Socket` with an invalid fd"); - sys::socket_from_raw(raw) - }, + inner, + + #[cfg(windows)] + wsarecvmsg, } } @@ -543,6 +557,16 @@ impl Socket { sys::recv_from(self.as_raw(), buf, flags) } + /// Receives data from the socket with `buf` that is fully initialized. + /// On success, returns the number of bytes read and the address from where the data came. + pub fn recv_from_initialized(&self, buf: &mut [u8]) -> io::Result<(usize, SockAddr)> { + // Safety: the `recv_from` implementation promises not to write uninitialised + // bytes to the buffer, so this casting is safe. + let buf_uninit = unsafe { &mut *(buf as *mut [u8] as *mut [MaybeUninit]) }; + + sys::recv_from(self.as_raw(), buf_uninit, 0) + } + /// Receives data from the socket. Returns the amount of bytes read, the /// [`RecvFlags`] and the remote address from the data is coming. Unlike /// [`recv_from`] this allows passing multiple buffers. @@ -642,6 +666,46 @@ impl Socket { sys::recvmsg(self.as_raw(), msg, flags) } + /// Receive a message from a socket using a message structure that is fully initialized. + #[cfg(all(unix, not(target_os = "redox")))] + #[cfg_attr(docsrs, doc(cfg(all(unix, not(target_os = "redox")))))] + pub fn recvmsg_initialized( + &self, + msg: &mut MsgHdrInit, + flags: sys::c_int, + ) -> io::Result { + sys::recvmsg_init(self.as_raw(), msg, flags) + } + + /// Recvmsg with initialized buffers + #[cfg(windows)] + pub fn recvmsg_initialized( + &self, + msg: &mut MsgHdrInit, + _flags: sys::c_int, + ) -> io::Result { + let wsarecvmsg = self.wsarecvmsg.ok_or(io::Error::new( + io::ErrorKind::NotFound, + "missing WSARECVMSG function", + ))?; + let mut read_bytes = 0; + let error_code = unsafe { + (wsarecvmsg)( + self.as_raw() as _, + &mut msg.inner, + &mut read_bytes, + std::ptr::null_mut(), + None, + ) + }; + + if error_code != 0 { + return Err(io::Error::last_os_error()); + } + + Ok(read_bytes as usize) + } + /// Sends data on the socket to a connected peer. /// /// This is typically used on TCP sockets or datagram sockets which have @@ -1643,6 +1707,25 @@ impl Socket { .map(|recv_tos| recv_tos > 0) } } + + /// Set IPv4 PKTINFO for this socket. + /// This should be called before the socket binds. + pub fn set_pktinfo_v4(&self) -> io::Result<()> { + let enable: i32 = 1; + unsafe { setsockopt(self.as_raw(), sys::IPPROTO_IP, sys::IP_PKTINFO, enable) } + } + + /// Set IPv6 PKTINFO for this socket. + /// This should be called before the socket binds. + pub fn set_recv_pktinfo_v6(&self) -> io::Result<()> { + #[cfg(not(windows))] + let optname = sys::IPV6_RECVPKTINFO; + + #[cfg(windows)] + let optname = sys::IPV6_PKTINFO; + + unsafe { setsockopt(self.as_raw(), sys::IPPROTO_IPV6, optname, 1) } + } } /// Socket options for IPv6 sockets, get/set using `IPPROTO_IPV6`. diff --git a/src/sys/unix.rs b/src/sys/unix.rs index 51ef4a5d..e766e0b6 100644 --- a/src/sys/unix.rs +++ b/src/sys/unix.rs @@ -245,6 +245,11 @@ pub(crate) use libc::{ ))] pub(crate) use libc::{TCP_KEEPCNT, TCP_KEEPINTVL}; +#[cfg(any(target_os = "macos", target_os = "linux"))] +pub(crate) use libc::{ + in6_pktinfo as In6PktInfo, in_pktinfo as InPktInfo, IPV6_PKTINFO, IPV6_RECVPKTINFO, IP_PKTINFO, +}; + // See this type in the Windows file. pub(crate) type Bool = c_int; @@ -711,6 +716,9 @@ pub(crate) fn unix_sockaddr(path: &Path) -> io::Result { #[cfg(not(target_os = "redox"))] pub(crate) use libc::msghdr; +#[cfg(not(target_os = "redox"))] +pub(crate) use libc::cmsghdr; + #[cfg(not(target_os = "redox"))] pub(crate) fn set_msghdr_name(msg: &mut msghdr, name: &SockAddr) { msg.msg_name = name.as_ptr() as *mut _; @@ -1109,6 +1117,36 @@ pub(crate) fn recvmsg( syscall!(recvmsg(fd, &mut msg.inner, flags)).map(|n| n as usize) } +use crate::{CMsgHdrOps, MsgHdrInit, MsgHdrOps}; +use libc::{CMSG_DATA, CMSG_FIRSTHDR, CMSG_NXTHDR, CMSG_SPACE}; + +#[cfg(not(target_os = "redox"))] +pub(crate) fn recvmsg_init(fd: Socket, msg: &mut MsgHdrInit, flags: c_int) -> io::Result { + syscall!(recvmsg(fd, &mut msg.inner, flags)).map(|n| n as usize) +} + +impl MsgHdrOps for msghdr { + fn cmsg_first_hdr(&self) -> *mut cmsghdr { + unsafe { CMSG_FIRSTHDR(self) } + } + + fn cmsg_next_hdr(&self, cmsg: &cmsghdr) -> *mut cmsghdr { + unsafe { CMSG_NXTHDR(self, cmsg) } + } +} + +impl CMsgHdrOps for cmsghdr { + fn cmsg_data(&self) -> *mut u8 { + unsafe { CMSG_DATA(self) } + } +} + +/// Given a payload of `data_len`, returns the number of bytes a control message occupies. +/// i.e. it includes the header, the data and the alignments. +pub(crate) fn _cmsg_space(data_len: usize) -> usize { + unsafe { CMSG_SPACE(data_len as _) as usize } +} + pub(crate) fn send(fd: Socket, buf: &[u8], flags: c_int) -> io::Result { syscall!(send( fd, diff --git a/src/sys/windows.rs b/src/sys/windows.rs index 11f2b7b0..80d468c9 100644 --- a/src/sys/windows.rs +++ b/src/sys/windows.rs @@ -24,11 +24,13 @@ use windows_sys::Win32::Foundation::{SetHandleInformation, HANDLE, HANDLE_FLAG_I use windows_sys::Win32::Networking::WinSock::SO_PROTOCOL_INFOW; use windows_sys::Win32::Networking::WinSock::{ self, tcp_keepalive, FIONBIO, IN6_ADDR, IN6_ADDR_0, INVALID_SOCKET, IN_ADDR, IN_ADDR_0, - POLLERR, POLLHUP, POLLRDNORM, POLLWRNORM, SD_BOTH, SD_RECEIVE, SD_SEND, SIO_KEEPALIVE_VALS, - SOCKET_ERROR, WSABUF, WSAEMSGSIZE, WSAESHUTDOWN, WSAPOLLFD, WSAPROTOCOL_INFOW, + LPFN_WSARECVMSG, LPWSAOVERLAPPED_COMPLETION_ROUTINE, POLLERR, POLLHUP, POLLRDNORM, POLLWRNORM, + SD_BOTH, SD_RECEIVE, SD_SEND, SIO_GET_EXTENSION_FUNCTION_POINTER, SIO_KEEPALIVE_VALS, + SOCKET_ERROR, WSAEMSGSIZE, WSAESHUTDOWN, WSAID_WSARECVMSG, WSAPOLLFD, WSAPROTOCOL_INFOW, WSA_FLAG_NO_HANDLE_INHERIT, WSA_FLAG_OVERLAPPED, }; use windows_sys::Win32::System::Threading::INFINITE; +use windows_sys::Win32::System::IO::OVERLAPPED; use crate::{MsgHdr, RecvFlags, SockAddr, TcpKeepalive, Type}; @@ -55,7 +57,8 @@ pub(crate) const SOCK_SEQPACKET: c_int = windows_sys::Win32::Networking::WinSock::SOCK_SEQPACKET as c_int; // Used in `Protocol`. pub(crate) use windows_sys::Win32::Networking::WinSock::{ - IPPROTO_ICMP, IPPROTO_ICMPV6, IPPROTO_TCP, IPPROTO_UDP, + CMSGHDR as cmsghdr, IN6_PKTINFO as In6PktInfo, IN_PKTINFO as InPktInfo, IPPROTO_ICMP, + IPPROTO_ICMPV6, IPPROTO_TCP, IPPROTO_UDP, IPV6_PKTINFO, IP_PKTINFO, WSABUF, }; // Used in `SockAddr`. pub(crate) use windows_sys::Win32::Networking::WinSock::{ @@ -192,6 +195,52 @@ impl<'a> MaybeUninitSlice<'a> { // Used in `MsgHdr`. pub(crate) use windows_sys::Win32::Networking::WinSock::WSAMSG as msghdr; +use crate::CMsgHdrOps; + +impl CMsgHdrOps for cmsghdr { + fn cmsg_data(&self) -> *mut u8 { + (self as *const _ as usize + cmsgdata_align(mem::size_of::())) as *mut u8 + } +} + +pub(crate) fn _cmsg_space(length: usize) -> usize { + cmsgdata_align(mem::size_of::() + cmsghdr_align(length)) +} + +// Helpers functions for `WinSock::WSAMSG` and `WinSock::CMSGHDR` are based on C macros from +// https://github.com/microsoft/win32metadata/blob/main/generation/WinSDK/RecompiledIdlHeaders/shared/ws2def.h#L741 +fn cmsghdr_align(length: usize) -> usize { + (length + mem::align_of::() - 1) & !(mem::align_of::() - 1) +} + +fn cmsgdata_align(length: usize) -> usize { + (length + mem::align_of::() - 1) & !(mem::align_of::() - 1) +} + +use crate::MsgHdrOps; + +impl MsgHdrOps for msghdr { + fn cmsg_first_hdr(&self) -> *mut cmsghdr { + if self.Control.len as usize >= mem::size_of::() { + self.Control.buf as *mut cmsghdr + } else { + ptr::null_mut::() + } + } + + fn cmsg_next_hdr(&self, cmsg: &cmsghdr) -> *mut cmsghdr { + let next = (cmsg as *const _ as usize + cmsghdr_align(cmsg.cmsg_len)) as *mut cmsghdr; + + // check if the end of the next cmsg overshoots the buf. + let max = self.Control.buf as usize + self.Control.len as usize; + if unsafe { next.offset(1) } as usize > max { + ptr::null_mut() + } else { + next + } + } +} + pub(crate) fn set_msghdr_name(msg: &mut msghdr, name: &SockAddr) { msg.name = name.as_ptr() as *mut _; msg.namelen = name.len(); @@ -686,6 +735,57 @@ pub(crate) fn sendmsg(socket: Socket, msg: &MsgHdr<'_, '_, '_>, flags: c_int) -> .map(|_| nsent as usize) } +pub(crate) type WSARecvMsgExtension = unsafe extern "system" fn( + s: Socket, + lpMsg: *mut msghdr, + lpdwNumberOfBytesRecvd: *mut u32, + lpOverlapped: *mut OVERLAPPED, + lpCompletionRoutine: LPWSAOVERLAPPED_COMPLETION_ROUTINE, +) -> i32; + +/// Find the WSARECVMSG function pointer +// +// This implementation is copied from: +// https://github.com/pixsper/socket-pktinfo/blob/3845f44eef707eaa3d34f9d4bc4ebcb6dc9c5959/src/win.rs#L44 +pub(crate) fn locate_wsarecvmsg(socket: Socket) -> io::Result { + let mut fn_pointer: usize = 0; + let mut byte_len: u32 = 0; + + let r = unsafe { + WinSock::WSAIoctl( + socket as _, + SIO_GET_EXTENSION_FUNCTION_POINTER, + &WSAID_WSARECVMSG as *const _ as *mut _, + mem::size_of_val(&WSAID_WSARECVMSG) as u32, + &mut fn_pointer as *const _ as *mut _, + mem::size_of_val(&fn_pointer) as u32, + &mut byte_len, + ptr::null_mut(), + None, + ) + }; + + if r != 0 { + return Err(io::Error::last_os_error()); + } + + if mem::size_of::() != byte_len as _ { + return Err(io::Error::new( + io::ErrorKind::Other, + "Locating fn pointer to WSARecvMsg returned different expected bytes", + )); + } + let cast_to_fn: LPFN_WSARECVMSG = unsafe { mem::transmute(fn_pointer) }; + + match cast_to_fn { + None => Err(io::Error::new( + io::ErrorKind::Other, + "WSARecvMsg extension not found", + )), + Some(extension) => Ok(extension), + } +} + /// Wrapper around `getsockopt` to deal with platform specific timeouts. pub(crate) fn timeout_opt(fd: Socket, lvl: c_int, name: i32) -> io::Result> { unsafe { getsockopt(fd, lvl, name).map(from_ms) } diff --git a/tests/socket.rs b/tests/socket.rs index 89b79f5f..22051e09 100644 --- a/tests/socket.rs +++ b/tests/socket.rs @@ -15,6 +15,7 @@ use std::fs::File; use std::io; #[cfg(not(any(target_os = "redox", target_os = "vita")))] use std::io::IoSlice; +use std::io::IoSliceMut; use std::io::Read; use std::io::Write; #[cfg(not(target_os = "vita"))] @@ -48,6 +49,12 @@ use std::thread; use std::time::Duration; use std::{env, fs}; +use socket2::cmsg_space; +use socket2::CMSG_LEVEL_IPPROTO_IP; +use socket2::CMSG_LEVEL_IPPROTO_IPV6; +use socket2::CMSG_TYPE_IPV6_PKTINFO; +use socket2::CMSG_TYPE_IP_PKTINFO; +use socket2::{MsgHdrInit, PktInfoV4, PktInfoV6}; #[cfg(windows)] use windows_sys::Win32::Foundation::{GetHandleInformation, HANDLE_FLAG_INHERIT}; @@ -746,6 +753,131 @@ fn send_from_recv_to_vectored() { assert_eq!(unsafe { assume_init(&swear) }, b"swear"); } +#[test] +fn send_to_recv_from_init() { + let (socket_a, socket_b) = udp_pair_unconnected(); + let addr_a = socket_a.local_addr().unwrap(); + let addr_b = socket_b.local_addr().unwrap(); + + let data = b"buf_init"; + let sent = socket_a.send_to(data, &addr_b).unwrap(); + assert_eq!(sent, data.len()); + + let mut buffer = vec![0; data.len()]; + let received = socket_b.recv_from_initialized(&mut buffer).unwrap(); + assert_eq!(received.0, data.len()); + assert_eq!(received.1, addr_a); + assert_eq!(&buffer, data); +} + +#[test] +fn sent_to_recvmsg_init_v6() { + let (socket_a, socket_b) = udp_pair_unconnected(); + let addr_a = socket_a.local_addr().unwrap(); + let addr_b = socket_b.local_addr().unwrap(); + + let data = b"sent_to_recvmsg_init"; + let sent = socket_a.send_to(data, &addr_b).unwrap(); + assert_eq!(sent, data.len()); + + let ipv4addr = SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 0); + let ipv6addr = SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1), 8080, 0, 0); + let mut sockaddr = if addr_b.is_ipv4() { + SockAddr::from(ipv4addr) + } else { + SockAddr::from(ipv6addr) + }; + + let mut buffer = vec![0; data.len()]; + let mut bufs = [IoSliceMut::new(&mut buffer)]; + let mut msg_control = vec![0; cmsg_space(PktInfoV6::size())]; + let mut msg = MsgHdrInit::new() + .with_addr(&mut sockaddr) + .with_buffers(&mut bufs) + .with_control(&mut msg_control); + + socket_b.set_recv_pktinfo_v6().unwrap(); + let received = socket_b.recvmsg_initialized(&mut msg, 0).unwrap(); + + assert_eq!(received, data.len()); + assert_eq!(sockaddr, addr_a); + assert_eq!(buffer, data); + + let cmsg_vec = msg.cmsg_hdr_vec(); + assert!(!cmsg_vec.is_empty()); + println!("cmsg vec: {:?}", cmsg_vec); + + let mut pktinfo_found = false; + for cmsg_hdr in cmsg_vec { + if cmsg_hdr.get_level() == CMSG_LEVEL_IPPROTO_IPV6 + && cmsg_hdr.get_type() == CMSG_TYPE_IPV6_PKTINFO + { + if let Some(ipv6_pktinfo) = cmsg_hdr.as_recvpktinfo_v6() { + pktinfo_found = true; + println!("control message: v6 pktinfo: {:?}", ipv6_pktinfo); + } + } + } + assert!(pktinfo_found); +} + +#[test] +fn sent_to_recvmsg_init_v4() { + let (socket_a, socket_b) = udp_pair_unconnected_v4(); + let addr_a = socket_a.local_addr().unwrap(); + let addr_b = socket_b.local_addr().unwrap(); + + // Send a message. + let data = b"sent_to_recvmsg_init_v4"; + let sent = socket_a.send_to(data, &addr_b).unwrap(); + + assert_eq!(sent, data.len()); + + let ipv4addr = SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 0); + let ipv6addr = SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1), 8080, 0, 0); + let mut sockaddr = if addr_b.is_ipv4() { + SockAddr::from(ipv4addr) + } else { + SockAddr::from(ipv6addr) + }; + + let mut buffer = vec![0; data.len()]; + let mut bufs = [IoSliceMut::new(&mut buffer)]; + let mut msg_control = vec![0; cmsg_space(PktInfoV4::size())]; + let mut msg = MsgHdrInit::new() + .with_addr(&mut sockaddr) + .with_buffers(&mut bufs) + .with_control(&mut msg_control); + + // Receive a mesage. + let received = socket_b.recvmsg_initialized(&mut msg, 0).unwrap(); + + // Verify the data received. + assert_eq!(received, data.len()); + assert_eq!(buffer, data); + + // Verify the source address. + assert_eq!(sockaddr, addr_a); + + // Verify the control message and the address that received the packet. + let cmsg_vec = msg.cmsg_hdr_vec(); + assert!(!cmsg_vec.is_empty()); + println!("cmsg vec: {:?}", cmsg_vec); + + let mut pktinfo_found = false; + for cmsg_hdr in cmsg_vec { + if cmsg_hdr.get_level() == CMSG_LEVEL_IPPROTO_IP + && cmsg_hdr.get_type() == CMSG_TYPE_IP_PKTINFO + { + if let Some(ip_pktinfo) = cmsg_hdr.as_pktinfo_v4() { + println!("control message: pktinfo: {:?}", ip_pktinfo); + pktinfo_found = true; + } + } + } + assert!(pktinfo_found); +} + #[test] #[cfg(not(any(target_os = "redox", target_os = "vita")))] fn sendmsg() { @@ -818,6 +950,40 @@ fn udp_pair_unconnected() -> (Socket, Socket) { let socket_a = Socket::new(Domain::IPV6, Type::DGRAM, None).unwrap(); let socket_b = Socket::new(Domain::IPV6, Type::DGRAM, None).unwrap(); + // Set the socket option before bind. + socket_b.set_recv_pktinfo_v6().unwrap(); + + socket_a.bind(&unspecified_addr.into()).unwrap(); + socket_b.bind(&unspecified_addr.into()).unwrap(); + + // Set low timeouts to prevent the tests from blocking. + socket_a + .set_read_timeout(Some(std::time::Duration::from_millis(10))) + .unwrap(); + socket_b + .set_read_timeout(Some(std::time::Duration::from_millis(10))) + .unwrap(); + socket_a + .set_write_timeout(Some(std::time::Duration::from_millis(10))) + .unwrap(); + socket_b + .set_write_timeout(Some(std::time::Duration::from_millis(10))) + .unwrap(); + + (socket_a, socket_b) +} + +/// Create a pair of non-connected UDP sockets suitable for unit tests. +#[cfg(not(any(target_os = "redox", target_os = "vita")))] +fn udp_pair_unconnected_v4() -> (Socket, Socket) { + // Use ephemeral ports assigned by the OS. + let unspecified_addr = SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0); + let socket_a = Socket::new(Domain::IPV4, Type::DGRAM, None).unwrap(); + let socket_b = Socket::new(Domain::IPV4, Type::DGRAM, None).unwrap(); + + // Set the socket option before bind. + socket_b.set_pktinfo_v4().unwrap(); + socket_a.bind(&unspecified_addr.into()).unwrap(); socket_b.bind(&unspecified_addr.into()).unwrap();