From 32a06cd5074d1974c7f48ef2f962757ac07a5f53 Mon Sep 17 00:00:00 2001 From: FujiApple Date: Tue, 14 Mar 2023 20:39:59 +0800 Subject: [PATCH] feat: use socket2 for SockAddr --- src/tracing/net/platform/windows.rs | 163 ++++++---------------------- 1 file changed, 36 insertions(+), 127 deletions(-) diff --git a/src/tracing/net/platform/windows.rs b/src/tracing/net/platform/windows.rs index 98ddb3eb..695d3f34 100644 --- a/src/tracing/net/platform/windows.rs +++ b/src/tracing/net/platform/windows.rs @@ -4,18 +4,16 @@ use crate::tracing::net::channel::MAX_PACKET_SIZE; use crate::tracing::net::platform::windows::adapter::Adapters; use crate::tracing::net::socket::TracerSocket; use socket2::{Domain, Protocol, SockAddr, Type}; -use std::ffi::c_void; use std::io::{Error, ErrorKind, Result}; use std::mem::{size_of, zeroed}; -use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4}; use std::os::windows::prelude::AsRawSocket; use std::ptr::{addr_of, addr_of_mut, null_mut}; use std::time::Duration; use windows_sys::Win32::Foundation::{WAIT_FAILED, WAIT_TIMEOUT}; use windows_sys::Win32::Networking::WinSock::{ - AF_INET, AF_INET6, FD_CONNECT, FD_WRITE, ICMP_ERROR_INFO, IN6_ADDR, IN6_ADDR_0, IN_ADDR, - IN_ADDR_0, IPPROTO_RAW, IPPROTO_TCP, SIO_ROUTING_INTERFACE_QUERY, SOCKADDR_IN, SOCKADDR_IN6, - SOCKADDR_IN6_0, SOCKADDR_STORAGE, SOCKET_ERROR, SOL_SOCKET, SO_ERROR, SO_PORT_SCALABILITY, + AF_INET, AF_INET6, FD_CONNECT, FD_WRITE, ICMP_ERROR_INFO, IPPROTO_RAW, IPPROTO_TCP, + SIO_ROUTING_INTERFACE_QUERY, SOCKET_ERROR, SOL_SOCKET, SO_ERROR, SO_PORT_SCALABILITY, SO_REUSE_UNICASTPORT, TCP_FAIL_CONNECT_ON_ICMP_ERROR, TCP_ICMP_ERROR_INFO, WSABUF, WSADATA, WSAEADDRNOTAVAIL, WSAECONNREFUSED, WSAEHOSTUNREACH, WSAEINPROGRESS, WSA_IO_INCOMPLETE, WSA_IO_PENDING, @@ -102,7 +100,7 @@ pub struct Socket { inner: socket2::Socket, ol: Box, buf: Vec, - from: Box, + from: Box, } #[allow(clippy::cast_possible_wrap)] @@ -117,7 +115,7 @@ impl Socket { fn new(domain: Domain, ty: Type, protocol: Option) -> Result { let inner = socket2::Socket::new(domain, ty, protocol)?; - let from = Box::new(Self::new_sockaddr_storage()); + let from = Box::new(Self::new_sockaddr()); let ol = Box::new(Self::new_overlapped()); let buf = vec![0u8; MAX_PACKET_SIZE]; Ok(Self { @@ -196,7 +194,7 @@ impl Socket { fn is_err(res: i32) -> bool { res == SOCKET_ERROR && Error::last_os_error().raw_os_error() != Some(WSA_IO_PENDING) } - let mut fromlen = std::mem::size_of::() as i32; + let mut fromlen = self.from.len(); let wbuf = WSABUF { len: MAX_PACKET_SIZE as u32, buf: self.buf.as_mut_ptr(), @@ -208,7 +206,7 @@ impl Socket { 1, null_mut(), &mut 0, - addr_of_mut!(*self.from).cast(), + self.from.as_mut_ptr(), addr_of_mut!(fromlen), addr_of_mut!(*self.ol), None, @@ -241,10 +239,8 @@ impl Socket { unsafe { zeroed::() } } - #[allow(unsafe_code)] - fn new_sockaddr_storage() -> SOCKADDR_STORAGE { - // Safety: an all-zero value is valid for SOCKADDR_STORAGE. - unsafe { zeroed::() } + fn new_sockaddr() -> SockAddr { + SockAddr::from(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0))) } #[allow(unsafe_code)] @@ -272,7 +268,6 @@ impl Drop for Socket { #[allow(clippy::cast_possible_wrap)] impl TracerSocket for Socket { fn new_icmp_send_socket_ipv4() -> Result { - // let sock = Self::new(AF_INET, SOCK_RAW, IPPROTO_RAW)?; let sock = Self::new(Domain::IPV4, Type::RAW, Some(Protocol::from(IPPROTO_RAW)))?; sock.set_non_blocking(true)?; sock.set_header_included(true)?; @@ -425,9 +420,8 @@ impl TracerSocket for Socket { } fn recv_from(&mut self, buf: &mut [u8]) -> Result<(usize, Option)> { - let addr = sockaddrptr_to_ipaddr(addr_of_mut!(*self.from))?; let len = self.read(buf)?; - Ok((len, Some(SocketAddr::new(addr, 0)))) + Ok((len, self.from.as_socket())) } // TODO @@ -483,132 +477,39 @@ impl TracerSocket for Socket { } } -/// NOTE under Windows, we cannot use a bind connect/getsockname as "If the socket -/// is using a connectionless protocol, the address may not be available until I/O -/// occurs on the socket." We use `SIO_ROUTING_INTERFACE_QUERY` instead. +/// Determine the src `IpAddr` used for routing to a given target `IpAddr`. +/// +/// under Windows, we cannot use a bind connect/getsockname as "If the socket is using a connectionless protocol, the +/// address may not be available until I/O occurs on the socket.". Therefore we use `SIO_ROUTING_INTERFACE_QUERY` +/// instead. +/// +/// Note that the `WSAIoctl` call potentially returns multiple results (see +/// ), +/// and we currently choose the first one arbitrarily. #[allow(clippy::cast_sign_loss)] fn routing_interface_query(target: IpAddr) -> TraceResult { - let src: *mut c_void = [0; 1024].as_mut_ptr().cast(); + let mut src = Socket::new_sockaddr(); + let dest = SockAddr::from(SocketAddr::new(target, 0)); let mut bytes = 0; let socket = match target { IpAddr::V4(_) => Socket::new_udp_dgram_socket_ipv4(), IpAddr::V6(_) => Socket::new_udp_dgram_socket_ipv6(), }?; - let (dest, destlen) = socketaddr_to_sockaddr(SocketAddr::new(target, 0)); syscall!( WSAIoctl( socket.inner.as_raw_socket() as _, SIO_ROUTING_INTERFACE_QUERY, - addr_of!(dest).cast(), - destlen as u32, - src, - 1024, + dest.as_ptr().cast(), + dest.len() as u32, + src.as_mut_ptr().cast(), + src.len() as u32, addr_of_mut!(bytes), null_mut(), None, ), |res| res == SOCKET_ERROR )?; - // Note that the WSAIoctl call potentially returns multiple results (see - // ), - // TBD We choose the first one arbitrarily. - let sockaddr = src.cast::(); - sockaddrptr_to_ipaddr(sockaddr).map_err(TracerError::IoError) -} - -#[allow(unsafe_code)] -fn sockaddrptr_to_ipaddr(sockaddr: *mut SOCKADDR_STORAGE) -> Result { - // Safety: TODO - match sockaddr_to_socketaddr(unsafe { sockaddr.as_ref().unwrap() }) { - Err(e) => Err(e), - Ok(socketaddr) => match socketaddr { - SocketAddr::V4(socketaddrv4) => Ok(IpAddr::V4(*socketaddrv4.ip())), - SocketAddr::V6(socketaddrv6) => Ok(IpAddr::V6(*socketaddrv6.ip())), - }, - } -} - -#[allow(unsafe_code)] -fn sockaddr_to_socketaddr(sockaddr: &SOCKADDR_STORAGE) -> Result { - let ptr = sockaddr as *const SOCKADDR_STORAGE; - let af = sockaddr.ss_family; - if af == AF_INET { - let sockaddr_in_ptr = ptr.cast::(); - // Safety: TODO - let sockaddr_in = unsafe { *sockaddr_in_ptr }; - let ipv4addr = u32::from_be(unsafe { sockaddr_in.sin_addr.S_un.S_addr }); - let port = sockaddr_in.sin_port; - Ok(SocketAddr::V4(SocketAddrV4::new( - Ipv4Addr::from(ipv4addr), - port, - ))) - } else if af == AF_INET6 { - #[allow(clippy::cast_ptr_alignment)] - let sockaddr_in6_ptr = ptr.cast::(); - // Safety: TODO - let sockaddr_in6 = unsafe { *sockaddr_in6_ptr }; - // TODO: check endianness - // Safety: TODO - let ipv6addr = unsafe { sockaddr_in6.sin6_addr.u.Byte }; - let port = sockaddr_in6.sin6_port; - // Safety: TODO - let scope_id = unsafe { sockaddr_in6.Anonymous.sin6_scope_id }; - Ok(SocketAddr::V6(SocketAddrV6::new( - Ipv6Addr::from(ipv6addr), - port, - sockaddr_in6.sin6_flowinfo, - scope_id, - ))) - } else { - Err(Error::new( - ErrorKind::Unsupported, - format!("Unsupported address family: {af:?}"), - )) - } -} - -#[allow(unsafe_code)] -#[allow(clippy::cast_possible_wrap)] -#[must_use] -fn socketaddr_to_sockaddr(socketaddr: SocketAddr) -> (SOCKADDR_STORAGE, i32) { - #[repr(C)] - union SockAddr { - storage: SOCKADDR_STORAGE, - in4: SOCKADDR_IN, - in6: SOCKADDR_IN6, - } - - let sockaddr = match socketaddr { - SocketAddr::V4(socketaddrv4) => SockAddr { - in4: SOCKADDR_IN { - sin_family: AF_INET, - sin_port: socketaddrv4.port().to_be(), - sin_addr: IN_ADDR { - S_un: IN_ADDR_0 { - S_addr: u32::from(*socketaddrv4.ip()).to_be(), - }, - }, - sin_zero: [0; 8], - }, - }, - SocketAddr::V6(socketaddrv6) => SockAddr { - in6: SOCKADDR_IN6 { - sin6_family: AF_INET6, - sin6_port: socketaddrv6.port().to_be(), - sin6_flowinfo: socketaddrv6.flowinfo(), - sin6_addr: IN6_ADDR { - u: IN6_ADDR_0 { - Byte: socketaddrv6.ip().octets(), - }, - }, - Anonymous: SOCKADDR_IN6_0 { - sin6_scope_id: socketaddrv6.scope_id(), - }, - }, - }, - }; - - (unsafe { sockaddr.storage }, size_of::() as i32) + Ok(src.as_socket().unwrap().ip()) } fn lookup_interface_addr(adapters: &Adapters, name: &str) -> TraceResult { @@ -626,7 +527,7 @@ fn lookup_interface_addr(adapters: &Adapters, name: &str) -> TraceResult mod adapter { use crate::tracing::error::{TraceResult, TracerError}; - use crate::tracing::net::platform::windows::sockaddrptr_to_ipaddr; + use socket2::SockAddr; use std::io::Error; use std::marker::PhantomData; use std::net::IpAddr; @@ -745,7 +646,15 @@ mod adapter { let first_unicast = (*self.next).FirstUnicastAddress; let socket_address = (*first_unicast).Address; let sockaddr = socket_address.lpSockaddr; - sockaddrptr_to_ipaddr(sockaddr.cast()).ok()? + + // Safety: TODO + let (_, addr) = SockAddr::try_init(|s, _length| { + // TODO or memcpy? + *s = *sockaddr.cast(); + Ok(()) + }) + .unwrap(); + addr.as_socket().unwrap().ip() }; self.next = (*self.next).Next; Some(AdapterAddress {