Skip to content

Commit

Permalink
feat: use socket2 for SockAddr
Browse files Browse the repository at this point in the history
  • Loading branch information
fujiapple852 committed Mar 15, 2023
1 parent 68a8695 commit 32a06cd
Showing 1 changed file with 36 additions and 127 deletions.
163 changes: 36 additions & 127 deletions src/tracing/net/platform/windows.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,16 @@ use crate::tracing::net::channel::MAX_PACKET_SIZE;
use crate::tracing::net::platform::windows::adapter::Adapters;
use crate::tracing::net::socket::TracerSocket;
use socket2::{Domain, Protocol, SockAddr, Type};
use std::ffi::c_void;
use std::io::{Error, ErrorKind, Result};
use std::mem::{size_of, zeroed};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4};
use std::os::windows::prelude::AsRawSocket;
use std::ptr::{addr_of, addr_of_mut, null_mut};
use std::time::Duration;
use windows_sys::Win32::Foundation::{WAIT_FAILED, WAIT_TIMEOUT};
use windows_sys::Win32::Networking::WinSock::{
AF_INET, AF_INET6, FD_CONNECT, FD_WRITE, ICMP_ERROR_INFO, IN6_ADDR, IN6_ADDR_0, IN_ADDR,
IN_ADDR_0, IPPROTO_RAW, IPPROTO_TCP, SIO_ROUTING_INTERFACE_QUERY, SOCKADDR_IN, SOCKADDR_IN6,
SOCKADDR_IN6_0, SOCKADDR_STORAGE, SOCKET_ERROR, SOL_SOCKET, SO_ERROR, SO_PORT_SCALABILITY,
AF_INET, AF_INET6, FD_CONNECT, FD_WRITE, ICMP_ERROR_INFO, IPPROTO_RAW, IPPROTO_TCP,
SIO_ROUTING_INTERFACE_QUERY, SOCKET_ERROR, SOL_SOCKET, SO_ERROR, SO_PORT_SCALABILITY,
SO_REUSE_UNICASTPORT, TCP_FAIL_CONNECT_ON_ICMP_ERROR, TCP_ICMP_ERROR_INFO, WSABUF, WSADATA,
WSAEADDRNOTAVAIL, WSAECONNREFUSED, WSAEHOSTUNREACH, WSAEINPROGRESS, WSA_IO_INCOMPLETE,
WSA_IO_PENDING,
Expand Down Expand Up @@ -102,7 +100,7 @@ pub struct Socket {
inner: socket2::Socket,
ol: Box<OVERLAPPED>,
buf: Vec<u8>,
from: Box<SOCKADDR_STORAGE>,
from: Box<SockAddr>,
}

#[allow(clippy::cast_possible_wrap)]
Expand All @@ -117,7 +115,7 @@ impl Socket {

fn new(domain: Domain, ty: Type, protocol: Option<Protocol>) -> Result<Self> {
let inner = socket2::Socket::new(domain, ty, protocol)?;
let from = Box::new(Self::new_sockaddr_storage());
let from = Box::new(Self::new_sockaddr());
let ol = Box::new(Self::new_overlapped());
let buf = vec![0u8; MAX_PACKET_SIZE];
Ok(Self {
Expand Down Expand Up @@ -196,7 +194,7 @@ impl Socket {
fn is_err(res: i32) -> bool {
res == SOCKET_ERROR && Error::last_os_error().raw_os_error() != Some(WSA_IO_PENDING)
}
let mut fromlen = std::mem::size_of::<SOCKADDR_STORAGE>() as i32;
let mut fromlen = self.from.len();
let wbuf = WSABUF {
len: MAX_PACKET_SIZE as u32,
buf: self.buf.as_mut_ptr(),
Expand All @@ -208,7 +206,7 @@ impl Socket {
1,
null_mut(),
&mut 0,
addr_of_mut!(*self.from).cast(),
self.from.as_mut_ptr(),
addr_of_mut!(fromlen),
addr_of_mut!(*self.ol),
None,
Expand Down Expand Up @@ -241,10 +239,8 @@ impl Socket {
unsafe { zeroed::<WSADATA>() }
}

#[allow(unsafe_code)]
fn new_sockaddr_storage() -> SOCKADDR_STORAGE {
// Safety: an all-zero value is valid for SOCKADDR_STORAGE.
unsafe { zeroed::<SOCKADDR_STORAGE>() }
fn new_sockaddr() -> SockAddr {
SockAddr::from(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)))
}

#[allow(unsafe_code)]
Expand Down Expand Up @@ -272,7 +268,6 @@ impl Drop for Socket {
#[allow(clippy::cast_possible_wrap)]
impl TracerSocket for Socket {
fn new_icmp_send_socket_ipv4() -> Result<Self> {
// let sock = Self::new(AF_INET, SOCK_RAW, IPPROTO_RAW)?;
let sock = Self::new(Domain::IPV4, Type::RAW, Some(Protocol::from(IPPROTO_RAW)))?;
sock.set_non_blocking(true)?;
sock.set_header_included(true)?;
Expand Down Expand Up @@ -425,9 +420,8 @@ impl TracerSocket for Socket {
}

fn recv_from(&mut self, buf: &mut [u8]) -> Result<(usize, Option<SocketAddr>)> {
let addr = sockaddrptr_to_ipaddr(addr_of_mut!(*self.from))?;
let len = self.read(buf)?;
Ok((len, Some(SocketAddr::new(addr, 0))))
Ok((len, self.from.as_socket()))
}

// TODO
Expand Down Expand Up @@ -483,132 +477,39 @@ impl TracerSocket for Socket {
}
}

/// NOTE under Windows, we cannot use a bind connect/getsockname as "If the socket
/// is using a connectionless protocol, the address may not be available until I/O
/// occurs on the socket." We use `SIO_ROUTING_INTERFACE_QUERY` instead.
/// Determine the src `IpAddr` used for routing to a given target `IpAddr`.
///
/// under Windows, we cannot use a bind connect/getsockname as "If the socket is using a connectionless protocol, the
/// address may not be available until I/O occurs on the socket.". Therefore we use `SIO_ROUTING_INTERFACE_QUERY`
/// instead.
///
/// Note that the `WSAIoctl` call potentially returns multiple results (see
/// <https://www.winsocketdotnetworkprogramming.com/winsock2programming/winsock2advancedsocketoptionioctl7h.html>),
/// and we currently choose the first one arbitrarily.
#[allow(clippy::cast_sign_loss)]
fn routing_interface_query(target: IpAddr) -> TraceResult<IpAddr> {
let src: *mut c_void = [0; 1024].as_mut_ptr().cast();
let mut src = Socket::new_sockaddr();
let dest = SockAddr::from(SocketAddr::new(target, 0));
let mut bytes = 0;
let socket = match target {
IpAddr::V4(_) => Socket::new_udp_dgram_socket_ipv4(),
IpAddr::V6(_) => Socket::new_udp_dgram_socket_ipv6(),
}?;
let (dest, destlen) = socketaddr_to_sockaddr(SocketAddr::new(target, 0));
syscall!(
WSAIoctl(
socket.inner.as_raw_socket() as _,
SIO_ROUTING_INTERFACE_QUERY,
addr_of!(dest).cast(),
destlen as u32,
src,
1024,
dest.as_ptr().cast(),
dest.len() as u32,
src.as_mut_ptr().cast(),
src.len() as u32,
addr_of_mut!(bytes),
null_mut(),
None,
),
|res| res == SOCKET_ERROR
)?;
// Note that the WSAIoctl call potentially returns multiple results (see
// <https://www.winsocketdotnetworkprogramming.com/winsock2programming/winsock2advancedsocketoptionioctl7h.html>),
// TBD We choose the first one arbitrarily.
let sockaddr = src.cast::<SOCKADDR_STORAGE>();
sockaddrptr_to_ipaddr(sockaddr).map_err(TracerError::IoError)
}

#[allow(unsafe_code)]
fn sockaddrptr_to_ipaddr(sockaddr: *mut SOCKADDR_STORAGE) -> Result<IpAddr> {
// Safety: TODO
match sockaddr_to_socketaddr(unsafe { sockaddr.as_ref().unwrap() }) {
Err(e) => Err(e),
Ok(socketaddr) => match socketaddr {
SocketAddr::V4(socketaddrv4) => Ok(IpAddr::V4(*socketaddrv4.ip())),
SocketAddr::V6(socketaddrv6) => Ok(IpAddr::V6(*socketaddrv6.ip())),
},
}
}

#[allow(unsafe_code)]
fn sockaddr_to_socketaddr(sockaddr: &SOCKADDR_STORAGE) -> Result<SocketAddr> {
let ptr = sockaddr as *const SOCKADDR_STORAGE;
let af = sockaddr.ss_family;
if af == AF_INET {
let sockaddr_in_ptr = ptr.cast::<SOCKADDR_IN>();
// Safety: TODO
let sockaddr_in = unsafe { *sockaddr_in_ptr };
let ipv4addr = u32::from_be(unsafe { sockaddr_in.sin_addr.S_un.S_addr });
let port = sockaddr_in.sin_port;
Ok(SocketAddr::V4(SocketAddrV4::new(
Ipv4Addr::from(ipv4addr),
port,
)))
} else if af == AF_INET6 {
#[allow(clippy::cast_ptr_alignment)]
let sockaddr_in6_ptr = ptr.cast::<SOCKADDR_IN6>();
// Safety: TODO
let sockaddr_in6 = unsafe { *sockaddr_in6_ptr };
// TODO: check endianness
// Safety: TODO
let ipv6addr = unsafe { sockaddr_in6.sin6_addr.u.Byte };
let port = sockaddr_in6.sin6_port;
// Safety: TODO
let scope_id = unsafe { sockaddr_in6.Anonymous.sin6_scope_id };
Ok(SocketAddr::V6(SocketAddrV6::new(
Ipv6Addr::from(ipv6addr),
port,
sockaddr_in6.sin6_flowinfo,
scope_id,
)))
} else {
Err(Error::new(
ErrorKind::Unsupported,
format!("Unsupported address family: {af:?}"),
))
}
}

#[allow(unsafe_code)]
#[allow(clippy::cast_possible_wrap)]
#[must_use]
fn socketaddr_to_sockaddr(socketaddr: SocketAddr) -> (SOCKADDR_STORAGE, i32) {
#[repr(C)]
union SockAddr {
storage: SOCKADDR_STORAGE,
in4: SOCKADDR_IN,
in6: SOCKADDR_IN6,
}

let sockaddr = match socketaddr {
SocketAddr::V4(socketaddrv4) => SockAddr {
in4: SOCKADDR_IN {
sin_family: AF_INET,
sin_port: socketaddrv4.port().to_be(),
sin_addr: IN_ADDR {
S_un: IN_ADDR_0 {
S_addr: u32::from(*socketaddrv4.ip()).to_be(),
},
},
sin_zero: [0; 8],
},
},
SocketAddr::V6(socketaddrv6) => SockAddr {
in6: SOCKADDR_IN6 {
sin6_family: AF_INET6,
sin6_port: socketaddrv6.port().to_be(),
sin6_flowinfo: socketaddrv6.flowinfo(),
sin6_addr: IN6_ADDR {
u: IN6_ADDR_0 {
Byte: socketaddrv6.ip().octets(),
},
},
Anonymous: SOCKADDR_IN6_0 {
sin6_scope_id: socketaddrv6.scope_id(),
},
},
},
};

(unsafe { sockaddr.storage }, size_of::<SockAddr>() as i32)
Ok(src.as_socket().unwrap().ip())
}

fn lookup_interface_addr(adapters: &Adapters, name: &str) -> TraceResult<IpAddr> {
Expand All @@ -626,7 +527,7 @@ fn lookup_interface_addr(adapters: &Adapters, name: &str) -> TraceResult<IpAddr>

mod adapter {
use crate::tracing::error::{TraceResult, TracerError};
use crate::tracing::net::platform::windows::sockaddrptr_to_ipaddr;
use socket2::SockAddr;
use std::io::Error;
use std::marker::PhantomData;
use std::net::IpAddr;
Expand Down Expand Up @@ -745,7 +646,15 @@ mod adapter {
let first_unicast = (*self.next).FirstUnicastAddress;
let socket_address = (*first_unicast).Address;
let sockaddr = socket_address.lpSockaddr;
sockaddrptr_to_ipaddr(sockaddr.cast()).ok()?

// Safety: TODO
let (_, addr) = SockAddr::try_init(|s, _length| {
// TODO or memcpy?
*s = *sockaddr.cast();
Ok(())
})
.unwrap();
addr.as_socket().unwrap().ip()
};
self.next = (*self.next).Next;
Some(AdapterAddress {
Expand Down

0 comments on commit 32a06cd

Please sign in to comment.