diff --git a/Cargo.lock b/Cargo.lock
index 5378414353..a511cede31 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3324,11 +3324,13 @@ name = "netwatch"
 version = "0.1.0"
 dependencies = [
  "anyhow",
+ "atomic-waker",
  "bytes",
  "derive_more",
  "futures-lite 2.5.0",
  "futures-sink",
  "futures-util",
+ "iroh-quinn-udp",
  "libc",
  "netdev",
  "netlink-packet-core",
diff --git a/iroh-net-report/src/reportgen/hairpin.rs b/iroh-net-report/src/reportgen/hairpin.rs
index dc730a7c9a..17fd49e4f5 100644
--- a/iroh-net-report/src/reportgen/hairpin.rs
+++ b/iroh-net-report/src/reportgen/hairpin.rs
@@ -121,7 +121,7 @@ impl Actor {
             .context("net_report actor gone")?;
         msg_response_rx.await.context("net_report actor died")?;

-        if let Err(err) = socket.send_to(&stun::request(txn), dst).await {
+        if let Err(err) = socket.send_to(&stun::request(txn), dst.into()).await {
             warn!(%dst, "failed to send hairpin check");
             return Err(err.into());
         }
diff --git a/iroh-net/src/magicsock.rs b/iroh-net/src/magicsock.rs
index f4870f8377..87964a3108 100644
--- a/iroh-net/src/magicsock.rs
+++ b/iroh-net/src/magicsock.rs
@@ -36,7 +36,7 @@ use futures_util::stream::BoxStream;
 use iroh_base::key::NodeId;
 use iroh_metrics::{inc, inc_by};
 use iroh_relay::protos::stun;
-use netwatch::{interfaces, ip::LocalAddresses, netmon};
+use netwatch::{interfaces, ip::LocalAddresses, netmon, UdpSocket};
 use quinn::AsyncUdpSocket;
 use rand::{seq::SliceRandom, Rng, SeedableRng};
 use smallvec::{smallvec, SmallVec};
@@ -441,11 +441,8 @@ impl MagicSock {
         // Right now however we have one single poller behaving the same for each
         // connection. It checks all paths and returns Poll::Ready as soon as any path is
         // ready.
-        let ipv4_poller = Arc::new(self.pconn4.clone()).create_io_poller();
-        let ipv6_poller = self
-            .pconn6
-            .as_ref()
-            .map(|sock| Arc::new(sock.clone()).create_io_poller());
+        let ipv4_poller = self.pconn4.create_io_poller();
+        let ipv6_poller = self.pconn6.as_ref().map(|sock| sock.create_io_poller());
         let relay_sender = self.relay_actor_sender.clone();
         Box::pin(IoPoller {
             ipv4_poller,
@@ -1091,10 +1088,9 @@ impl MagicSock {
                 Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
                     // This is the socket .try_send_disco_message_udp used.
                     let sock = self.conn_for_addr(dst)?;
-                    let sock = Arc::new(sock.clone());
-                    let mut poller = sock.create_io_poller();
-                    match poller.as_mut().poll_writable(cx)? {
-                        Poll::Ready(()) => continue,
+                    match sock.as_socket_ref().poll_writable(cx) {
+                        Poll::Ready(Ok(())) => continue,
+                        Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
                         Poll::Pending => return Poll::Pending,
                     }
                 }
@@ -1408,6 +1404,9 @@ impl Handle {
         let net_reporter =
             net_report::Client::new(Some(port_mapper.clone()), dns_resolver.clone())?;

+        let pconn4_sock = pconn4.as_socket();
+        let pconn6_sock = pconn6.as_ref().map(|p| p.as_socket());
+
         let (actor_sender, actor_receiver) = mpsc::channel(256);
         let (relay_actor_sender, relay_actor_receiver) = mpsc::channel(256);
         let (udp_disco_sender, mut udp_disco_receiver) = mpsc::channel(256);
@@ -1431,9 +1430,9 @@ impl Handle {
             ipv6_reported: Arc::new(AtomicBool::new(false)),
             relay_map,
             my_relay: Default::default(),
-            pconn4: pconn4.clone(),
-            pconn6: pconn6.clone(),
             net_reporter: net_reporter.addr(),
+            pconn4,
+            pconn6,
             disco_secrets: DiscoSecrets::default(),
             node_map,
             relay_actor_sender: relay_actor_sender.clone(),
@@ -1481,8 +1480,8 @@ impl Handle {
             periodic_re_stun_timer: new_re_stun_timer(false),
             net_info_last: None,
             port_mapper,
-            pconn4,
-            pconn6,
+            pconn4: pconn4_sock,
+            pconn6: pconn6_sock,
             no_v4_send: false,
             net_reporter,
             network_monitor,
@@ -1720,8 +1719,8 @@ struct Actor {
     net_info_last: Option<NetInfo>,

     // The underlying UDP sockets used to send/rcv packets.
-    pconn4: UdpConn,
-    pconn6: Option<UdpConn>,
+    pconn4: Arc<UdpSocket>,
+    pconn6: Option<Arc<UdpSocket>>,

     /// The NAT-PMP/PCP/UPnP prober/client, for requesting port mappings from NAT devices.
     port_mapper: portmapper::Client,
@@ -1861,6 +1860,14 @@ impl Actor {
                 debug!("link change detected: major? {}", is_major);

                 if is_major {
+                    if let Err(err) = self.pconn4.rebind() {
+                        warn!("failed to rebind Udp IPv4 socket: {:?}", err);
+                    };
+                    if let Some(ref pconn6) = self.pconn6 {
+                        if let Err(err) = pconn6.rebind() {
+                            warn!("failed to rebind Udp IPv6 socket: {:?}", err);
+                        };
+                    }
                     self.msock.dns_resolver.clear_cache();
                     self.msock.re_stun("link-change-major");
                     self.close_stale_relay_connections().await;
@@ -1893,14 +1900,6 @@ impl Actor {
                 self.port_mapper.deactivate();
                 self.relay_actor_cancel_token.cancel();

-                // Ignore errors from pconnN
-                // They will frequently have been closed already by a call to connBind.Close.
- debug!("stopping connections"); - if let Some(ref conn) = self.pconn6 { - conn.close().await.ok(); - } - self.pconn4.close().await.ok(); - debug!("shutdown complete"); return true; } @@ -2206,8 +2205,8 @@ impl Actor { } let relay_map = self.msock.relay_map.clone(); - let pconn4 = Some(self.pconn4.as_socket()); - let pconn6 = self.pconn6.as_ref().map(|p| p.as_socket()); + let pconn4 = Some(self.pconn4.clone()); + let pconn6 = self.pconn6.clone(); debug!("requesting net_report report"); match self @@ -3099,6 +3098,45 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_regression_network_change_rebind_wakes_connection_driver( + ) -> testresult::TestResult { + let _ = iroh_test::logging::setup(); + let m1 = MagicStack::new(RelayMode::Disabled).await?; + let m2 = MagicStack::new(RelayMode::Disabled).await?; + + println!("Net change"); + m1.endpoint.magic_sock().force_network_change(true).await; + tokio::time::sleep(Duration::from_secs(1)).await; // wait for socket rebinding + + let _guard = mesh_stacks(vec![m1.clone(), m2.clone()]).await?; + + let _handle = AbortOnDropHandle::new(tokio::spawn({ + let endpoint = m2.endpoint.clone(); + async move { + while let Some(incoming) = endpoint.accept().await { + println!("Incoming first conn!"); + let conn = incoming.await?; + conn.closed().await; + } + + testresult::TestResult::Ok(()) + } + })); + + println!("first conn!"); + let conn = m1 + .endpoint + .connect(m2.endpoint.node_addr().await?, ALPN) + .await?; + println!("Closing first conn"); + conn.close(0u32.into(), b"bye lolz"); + conn.closed().await; + println!("Closed first conn"); + + Ok(()) + } + #[tokio::test(flavor = "multi_thread")] async fn test_two_devices_roundtrip_network_change() -> Result<()> { time::timeout( diff --git a/iroh-net/src/magicsock/udp_conn.rs b/iroh-net/src/magicsock/udp_conn.rs index 2c23d44f5b..8626c3fcec 100644 --- a/iroh-net/src/magicsock/udp_conn.rs +++ b/iroh-net/src/magicsock/udp_conn.rs @@ -1,25 +1,22 @@ use std::{ fmt::Debug, - future::Future, io, net::SocketAddr, pin::Pin, sync::Arc, - task::{ready, Context, Poll}, + task::{Context, Poll}, }; use anyhow::{bail, Context as _}; use netwatch::UdpSocket; use quinn::AsyncUdpSocket; -use quinn_udp::{Transmit, UdpSockRef}; -use tokio::io::Interest; -use tracing::{debug, trace}; +use quinn_udp::Transmit; +use tracing::debug; /// A UDP socket implementing Quinn's [`AsyncUdpSocket`]. 
-#[derive(Clone, Debug)]
+#[derive(Debug, Clone)]
 pub struct UdpConn {
     io: Arc<UdpSocket>,
-    inner: Arc<quinn_udp::UdpSocketState>,
 }

 impl UdpConn {
@@ -27,43 +24,34 @@
         self.io.clone()
     }

+    pub(super) fn as_socket_ref(&self) -> &UdpSocket {
+        &self.io
+    }
+
     pub(super) fn bind(addr: SocketAddr) -> anyhow::Result<Self> {
         let sock = bind(addr)?;
-        let state = quinn_udp::UdpSocketState::new(quinn_udp::UdpSockRef::from(&sock))?;
-        Ok(Self {
-            io: Arc::new(sock),
-            inner: Arc::new(state),
-        })
+
+        Ok(Self { io: Arc::new(sock) })
     }

     pub fn port(&self) -> u16 {
         self.local_addr().map(|p| p.port()).unwrap_or_default()
     }

-    #[allow(clippy::unused_async)]
-    pub async fn close(&self) -> Result<(), io::Error> {
-        // Nothing to do atm
-        Ok(())
+    pub(super) fn create_io_poller(&self) -> Pin<Box<dyn quinn::UdpPoller>> {
+        Box::pin(IoPoller {
+            io: self.io.clone(),
+        })
     }
 }

 impl AsyncUdpSocket for UdpConn {
     fn create_io_poller(self: Arc<Self>) -> Pin<Box<dyn quinn::UdpPoller>> {
-        let sock = self.io.clone();
-        Box::pin(IoPoller {
-            next_waiter: move || {
-                let sock = sock.clone();
-                async move { sock.writable().await }
-            },
-            waiter: None,
-        })
+        (*self).create_io_poller()
     }

     fn try_send(&self, transmit: &Transmit<'_>) -> io::Result<()> {
-        self.io.try_io(Interest::WRITABLE, || {
-            let sock_ref = UdpSockRef::from(&self.io);
-            self.inner.send(sock_ref, transmit)
-        })
+        self.io.try_send_quinn(transmit)
     }

     fn poll_recv(
         &self,
         cx: &mut Context,
         bufs: &mut [io::IoSliceMut<'_>],
         meta: &mut [quinn_udp::RecvMeta],
     ) -> Poll<io::Result<usize>> {
-        loop {
-            ready!(self.io.poll_recv_ready(cx))?;
-            if let Ok(res) = self.io.try_io(Interest::READABLE, || {
-                self.inner.recv(Arc::as_ref(&self.io).into(), bufs, meta)
-            }) {
-                for meta in meta.iter().take(res) {
-                    trace!(
-                        src = %meta.addr,
-                        len = meta.len,
-                        count = meta.len / meta.stride,
-                        dst = %meta.dst_ip.map(|x| x.to_string()).unwrap_or_default(),
-                        "UDP recv"
-                    );
-                }
-
-                return Poll::Ready(Ok(res));
-            }
-        }
+        self.io.poll_recv_quinn(cx, bufs, meta)
     }

     fn local_addr(&self) -> io::Result<SocketAddr> {
@@ -97,15 +68,15 @@
     }

     fn may_fragment(&self) -> bool {
-        self.inner.may_fragment()
+        self.io.may_fragment()
     }

     fn max_transmit_segments(&self) -> usize {
-        self.inner.max_gso_segments()
+        self.io.max_gso_segments()
     }

     fn max_receive_segments(&self) -> usize {
-        self.inner.gro_segments()
+        self.io.gro_segments()
     }
 }

@@ -147,49 +118,14 @@ fn bind(mut addr: SocketAddr) -> anyhow::Result<UdpSocket> {
 }

 /// Poller for when the socket is writable.
-///
-/// The tricky part is that we only have `tokio::net::UdpSocket::writable()` to create the
-/// waiter we need, which does not return a named future type. In order to be able to store
-/// this waiter in a struct without boxing we need to specify the future itself as a type
-/// parameter, which we can only do if we introduce a second type parameter which returns
-/// the future. So we end up with a function which we do not need, but it makes the types
-/// work.
-#[derive(derive_more::Debug)]
-#[pin_project::pin_project]
-struct IoPoller<F, Fut>
-where
-    F: Fn() -> Fut + Send + Sync + 'static,
-    Fut: Future<Output = io::Result<()>> + Send + Sync + 'static,
-{
-    /// Function which can create a new waiter if there is none.
-    #[debug("next_waiter")]
-    next_waiter: F,
-    /// The waiter which tells us when the socket is writable.
- #[debug("waiter")] - #[pin] - waiter: Option, +#[derive(Debug)] +struct IoPoller { + io: Arc, } -impl quinn::UdpPoller for IoPoller -where - F: Fn() -> Fut + Send + Sync + 'static, - Fut: Future> + Send + Sync + 'static, -{ +impl quinn::UdpPoller for IoPoller { fn poll_writable(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - let mut this = self.project(); - if this.waiter.is_none() { - this.waiter.set(Some((this.next_waiter)())); - } - let result = this - .waiter - .as_mut() - .as_pin_mut() - .expect("just set") - .poll(cx); - if result.is_ready() { - this.waiter.set(None); - } - result + self.io.poll_writable(cx) } } diff --git a/net-tools/netwatch/Cargo.toml b/net-tools/netwatch/Cargo.toml index 38637d45b6..2a0050666d 100644 --- a/net-tools/netwatch/Cargo.toml +++ b/net-tools/netwatch/Cargo.toml @@ -14,6 +14,7 @@ workspace = true [dependencies] anyhow = { version = "1" } +atomic-waker = "1.1.2" bytes = "1.7" futures-lite = "2.3" futures-sink = "0.3.25" @@ -21,10 +22,22 @@ futures-util = "0.3.25" libc = "0.2.139" netdev = "0.30.0" once_cell = "1.18.0" +quinn-udp = { package = "iroh-quinn-udp", version = "0.5.5" } socket2 = "0.5.3" thiserror = "1" time = "0.3.20" -tokio = { version = "1", features = ["io-util", "macros", "sync", "rt", "net", "fs", "io-std", "signal", "process", "time"] } +tokio = { version = "1", features = [ + "io-util", + "macros", + "sync", + "rt", + "net", + "fs", + "io-std", + "signal", + "process", + "time", +] } tokio-util = { version = "0.7", features = ["rt"] } tracing = "0.1" @@ -36,12 +49,26 @@ rtnetlink = "0.13.0" [target.'cfg(target_os = "windows")'.dependencies] wmi = "0.13" -windows = { version = "0.51", features = ["Win32_NetworkManagement_IpHelper", "Win32_Foundation", "Win32_NetworkManagement_Ndis", "Win32_Networking_WinSock"] } +windows = { version = "0.51", features = [ + "Win32_NetworkManagement_IpHelper", + "Win32_Foundation", + "Win32_NetworkManagement_Ndis", + "Win32_Networking_WinSock", +] } serde = { version = "1", features = ["derive"] } derive_more = { version = "1.0.0", features = ["debug"] } [dev-dependencies] -tokio = { version = "1", features = ["io-util", "sync", "rt", "net", "fs", "macros", "time", "test-util"] } +tokio = { version = "1", features = [ + "io-util", + "sync", + "rt", + "net", + "fs", + "macros", + "time", + "test-util", +] } [package.metadata.docs.rs] all-features = true diff --git a/net-tools/netwatch/src/udp.rs b/net-tools/netwatch/src/udp.rs index 3aba36277f..ab9f130402 100644 --- a/net-tools/netwatch/src/udp.rs +++ b/net-tools/netwatch/src/udp.rs @@ -1,147 +1,910 @@ -use std::net::SocketAddr; +use std::{ + future::Future, + io, + net::SocketAddr, + pin::Pin, + sync::{atomic::AtomicBool, RwLock, RwLockReadGuard, TryLockError}, + task::{Context, Poll}, +}; -use anyhow::{ensure, Context, Result}; -use tracing::warn; +use atomic_waker::AtomicWaker; +use quinn_udp::Transmit; +use tokio::io::Interest; +use tracing::{debug, trace, warn}; use super::IpFamily; -/// Wrapper around a tokio UDP socket that handles the fact that -/// on drop `libc::close` can block for UDP sockets. +/// Wrapper around a tokio UDP socket. #[derive(Debug)] -pub struct UdpSocket(Option); +pub struct UdpSocket { + socket: RwLock, + recv_waker: AtomicWaker, + send_waker: AtomicWaker, + /// Set to true, when an error occurred, that means we need to rebind the socket. + is_broken: AtomicBool, +} /// UDP socket read/write buffer size (7MB). The value of 7MB is chosen as it /// is the max supported by a default configuration of macOS. 
 /// Some platforms will silently clamp the value.
 const SOCKET_BUFFER_SIZE: usize = 7 << 20;

 impl UdpSocket {
     /// Bind only Ipv4 on any interface.
-    pub fn bind_v4(port: u16) -> Result<Self> {
+    pub fn bind_v4(port: u16) -> io::Result<Self> {
         Self::bind(IpFamily::V4, port)
     }

     /// Bind only Ipv6 on any interface.
-    pub fn bind_v6(port: u16) -> Result<Self> {
+    pub fn bind_v6(port: u16) -> io::Result<Self> {
         Self::bind(IpFamily::V6, port)
     }

     /// Bind only Ipv4 on localhost.
-    pub fn bind_local_v4(port: u16) -> Result<Self> {
+    pub fn bind_local_v4(port: u16) -> io::Result<Self> {
         Self::bind_local(IpFamily::V4, port)
     }

     /// Bind only Ipv6 on localhost.
-    pub fn bind_local_v6(port: u16) -> Result<Self> {
+    pub fn bind_local_v6(port: u16) -> io::Result<Self> {
         Self::bind_local(IpFamily::V6, port)
     }

     /// Bind to the given port only on localhost.
-    pub fn bind_local(network: IpFamily, port: u16) -> Result<Self> {
+    pub fn bind_local(network: IpFamily, port: u16) -> io::Result<Self> {
         let addr = SocketAddr::new(network.local_addr(), port);
-        Self::bind_raw(addr).with_context(|| format!("{addr:?}"))
+        Self::bind_raw(addr)
     }

     /// Bind to the given port and listen on all interfaces.
-    pub fn bind(network: IpFamily, port: u16) -> Result<Self> {
+    pub fn bind(network: IpFamily, port: u16) -> io::Result<Self> {
         let addr = SocketAddr::new(network.unspecified_addr(), port);
-        Self::bind_raw(addr).with_context(|| format!("{addr:?}"))
+        Self::bind_raw(addr)
     }

     /// Bind to any provided [`SocketAddr`].
-    pub fn bind_full(addr: impl Into<SocketAddr>) -> Result<Self> {
+    pub fn bind_full(addr: impl Into<SocketAddr>) -> io::Result<Self> {
         Self::bind_raw(addr)
     }

-    fn bind_raw(addr: impl Into<SocketAddr>) -> Result<Self> {
-        let addr = addr.into();
+    /// Is the socket broken and in need of a rebind?
+    pub fn is_broken(&self) -> bool {
+        self.is_broken.load(std::sync::atomic::Ordering::Acquire)
+    }
+
+    /// Marks this socket as needing a rebind.
+    fn mark_broken(&self) {
+        self.is_broken
+            .store(true, std::sync::atomic::Ordering::Release);
+    }
+
+    /// Rebind the underlying socket.
+    pub fn rebind(&self) -> io::Result<()> {
+        {
+            let mut guard = self.socket.write().unwrap();
+            guard.rebind()?;
+
+            // Clear errors
+            self.is_broken
+                .store(false, std::sync::atomic::Ordering::Release);
+
+            drop(guard);
+        }
+
+        // wakeup
+        self.wake_all();
+
+        Ok(())
+    }
+
+    fn bind_raw(addr: impl Into<SocketAddr>) -> io::Result<Self> {
+        let socket = SocketState::bind(addr.into())?;
+
+        Ok(UdpSocket {
+            socket: RwLock::new(socket),
+            recv_waker: AtomicWaker::default(),
+            send_waker: AtomicWaker::default(),
+            is_broken: AtomicBool::new(false),
+        })
+    }
+
+    /// Receives a single datagram message on the socket from the remote address
+    /// to which it is connected. On success, returns the number of bytes read.
+    ///
+    /// The function must be called with a valid byte array `buf` of sufficient
+    /// size to hold the message bytes. If a message is too long to fit in the
+    /// supplied buffer, excess bytes may be discarded.
+    ///
+    /// The [`connect`] method will connect this socket to a remote address.
+    /// This method will fail if the socket is not connected.
+    ///
+    /// [`connect`]: method@Self::connect
+    pub fn recv<'a, 'b>(&'b self, buffer: &'a mut [u8]) -> RecvFut<'a, 'b> {
+        RecvFut {
+            socket: self,
+            buffer,
+        }
+    }
+
+    /// Receives a single datagram message on the socket. On success, returns
+    /// the number of bytes read and the origin.
+    ///
+    /// The function must be called with a valid byte array `buf` of sufficient
+    /// size to hold the message bytes. If a message is too long to fit in the
+    /// supplied buffer, excess bytes may be discarded.
+    pub fn recv_from<'a, 'b>(&'b self, buffer: &'a mut [u8]) -> RecvFromFut<'a, 'b> {
+        RecvFromFut {
+            socket: self,
+            buffer,
+        }
+    }
+
+    /// Sends data on the socket to the remote address that the socket is
+    /// connected to.
+    ///
+    /// The [`connect`] method will connect this socket to a remote address.
+    /// This method will fail if the socket is not connected.
+    ///
+    /// [`connect`]: method@Self::connect
+    ///
+    /// # Return
+    ///
+    /// On success, the number of bytes sent is returned, otherwise, the
+    /// encountered error is returned.
+    pub fn send<'a, 'b>(&'b self, buffer: &'a [u8]) -> SendFut<'a, 'b> {
+        SendFut {
+            socket: self,
+            buffer,
+        }
+    }
+
+    /// Sends data on the socket to the given address. On success, returns the
+    /// number of bytes written.
+    pub fn send_to<'a, 'b>(&'b self, buffer: &'a [u8], to: SocketAddr) -> SendToFut<'a, 'b> {
+        SendToFut {
+            socket: self,
+            buffer,
+            to,
+        }
+    }
+
+    /// Connects the UDP socket, setting the default destination for `send()`
+    /// and limiting packets that are read via `recv` to those coming from the
+    /// address specified in `addr`.
+    pub fn connect(&self, addr: SocketAddr) -> io::Result<()> {
+        tracing::info!("connecting to {}", addr);
+        let guard = self.socket.read().unwrap();
+        let (socket_tokio, _state) = guard.try_get_connected()?;
+
+        let sock_ref = socket2::SockRef::from(&socket_tokio);
+        sock_ref.connect(&socket2::SockAddr::from(addr))?;
+
+        Ok(())
+    }
+
+    /// Returns the local address of this socket.
+    pub fn local_addr(&self) -> io::Result<SocketAddr> {
+        let guard = self.socket.read().unwrap();
+        let (socket, _state) = guard.try_get_connected()?;
+
+        socket.local_addr()
+    }
+
+    /// Closes the socket, and waits for the underlying `libc::close` call to be finished.
+    pub async fn close(&self) {
+        let socket = self.socket.write().unwrap().close();
+        self.wake_all();
+        if let Some((sock, _)) = socket {
+            let std_sock = sock.into_std();
+            let res = tokio::runtime::Handle::current()
+                .spawn_blocking(move || {
+                    // Calls libc::close, which can block
+                    drop(std_sock);
+                })
+                .await;
+            if let Err(err) = res {
+                warn!("failed to close socket: {:?}", err);
+            }
+        }
+    }
+
+    /// Check if this socket is closed.
+    pub fn is_closed(&self) -> bool {
+        self.socket.read().unwrap().is_closed()
+    }
+
+    /// Handle potential read errors, updating internal state.
+    ///
+    /// Returns `Some(error)` if the error is fatal, otherwise `None`.
+    fn handle_read_error(&self, error: io::Error) -> Option<io::Error> {
+        match error.kind() {
+            io::ErrorKind::NotConnected => {
+                // This indicates the underlying socket is broken, and we should attempt to rebind it
+                self.mark_broken();
+                None
+            }
+            _ => Some(error),
+        }
+    }
+
+    /// Handle potential write errors, updating internal state.
+    ///
+    /// Returns `Some(error)` if the error is fatal, otherwise `None`.
+    fn handle_write_error(&self, error: io::Error) -> Option<io::Error> {
+        match error.kind() {
+            io::ErrorKind::BrokenPipe => {
+                // This indicates the underlying socket is broken, and we should attempt to rebind it
+                self.mark_broken();
+                None
+            }
+            _ => Some(error),
+        }
+    }
+
+    /// Try to get a read lock on the socket, but don't block while acquiring it.
+    fn poll_read_socket(
+        &self,
+        waker: &AtomicWaker,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<RwLockReadGuard<'_, SocketState>> {
+        let guard = match self.socket.try_read() {
+            Ok(guard) => guard,
+            Err(TryLockError::Poisoned(e)) => panic!("socket lock poisoned: {e}"),
+            Err(TryLockError::WouldBlock) => {
+                waker.register(cx.waker());
+
+                match self.socket.try_read() {
+                    Ok(guard) => {
+                        // we're actually fine, no need to cause a spurious wakeup
+                        waker.take();
+                        guard
+                    }
+                    Err(TryLockError::Poisoned(e)) => panic!("socket lock poisoned: {e}"),
+                    Err(TryLockError::WouldBlock) => {
+                        // Ok fine, we registered our waker and the lock really is
+                        // still held, so we can return pending.
+                        return Poll::Pending;
+                    }
+                }
+            }
+        };
+        Poll::Ready(guard)
+    }
+
+    fn wake_all(&self) {
+        self.recv_waker.wake();
+        self.send_waker.wake();
+    }
+
+    /// Checks if the socket needs a rebind, and if so, does it.
+    ///
+    /// Returns an error if a rebind was needed but failed.
+    fn maybe_rebind(&self) -> io::Result<()> {
+        if self.is_broken() {
+            self.rebind()?;
+        }
+        Ok(())
+    }
+
+    /// Poll for writable.
+    pub fn poll_writable(&self, cx: &mut std::task::Context<'_>) -> Poll<io::Result<()>> {
+        loop {
+            if let Err(err) = self.maybe_rebind() {
+                return Poll::Ready(Err(err));
+            }
+
+            let guard = futures_lite::ready!(self.poll_read_socket(&self.send_waker, cx));
+            let (socket, _state) = guard.try_get_connected()?;
+
+            match socket.poll_send_ready(cx) {
+                Poll::Pending => {
+                    self.send_waker.register(cx.waker());
+                    return Poll::Pending;
+                }
+                Poll::Ready(Ok(())) => return Poll::Ready(Ok(())),
+                Poll::Ready(Err(err)) => {
+                    if let Some(err) = self.handle_write_error(err) {
+                        return Poll::Ready(Err(err));
+                    }
+                    continue;
+                }
+            }
+        }
+    }
+
+    /// Send a quinn-based `Transmit`.
+    pub fn try_send_quinn(&self, transmit: &Transmit<'_>) -> io::Result<()> {
+        loop {
+            self.maybe_rebind()?;
+
+            let guard = match self.socket.try_read() {
+                Ok(guard) => guard,
+                Err(TryLockError::Poisoned(e)) => {
+                    panic!("lock poisoned: {:?}", e);
+                }
+                Err(TryLockError::WouldBlock) => {
+                    return Err(io::Error::new(io::ErrorKind::WouldBlock, ""));
+                }
+            };
+            let (socket, state) = guard.try_get_connected()?;
+
+            let res = socket.try_io(Interest::WRITABLE, || state.send(socket.into(), transmit));
+
+            match res {
+                Ok(()) => return Ok(()),
+                Err(err) => match self.handle_write_error(err) {
+                    Some(err) => return Err(err),
+                    None => {
+                        continue;
+                    }
+                },
+            }
+        }
+    }
+
+    /// quinn-based `poll_recv`.
+    pub fn poll_recv_quinn(
+        &self,
+        cx: &mut Context,
+        bufs: &mut [io::IoSliceMut<'_>],
+        meta: &mut [quinn_udp::RecvMeta],
+    ) -> Poll<io::Result<usize>> {
+        loop {
+            if let Err(err) = self.maybe_rebind() {
+                return Poll::Ready(Err(err));
+            }
+
+            let guard = futures_lite::ready!(self.poll_read_socket(&self.recv_waker, cx));
+            let (socket, state) = guard.try_get_connected()?;
+
+            match socket.poll_recv_ready(cx) {
+                Poll::Pending => {
+                    self.recv_waker.register(cx.waker());
+                    return Poll::Pending;
+                }
+                Poll::Ready(Ok(())) => {
+                    // We are ready to read, continue
+                }
+                Poll::Ready(Err(err)) => match self.handle_read_error(err) {
+                    Some(err) => return Poll::Ready(Err(err)),
+                    None => {
+                        continue;
+                    }
+                },
+            }
+
+            let res = socket.try_io(Interest::READABLE, || state.recv(socket.into(), bufs, meta));
+            match res {
+                Ok(count) => {
+                    for meta in meta.iter().take(count) {
+                        trace!(
+                            src = %meta.addr,
+                            len = meta.len,
+                            count = meta.len / meta.stride,
+                            dst = %meta.dst_ip.map(|x| x.to_string()).unwrap_or_default(),
+                            "UDP recv"
+                        );
+                    }
+                    return Poll::Ready(Ok(count));
+                }
+                Err(err) => {
+                    // ignore spurious wakeups
+                    if err.kind() == io::ErrorKind::WouldBlock {
+                        continue;
+                    }
+                    match self.handle_read_error(err) {
+                        Some(err) => return Poll::Ready(Err(err)),
+                        None => {
+                            continue;
+                        }
+                    }
+                }
+            }
+        }
+    }
+
+    /// Whether transmitted datagrams might get fragmented by the IP layer.
+    ///
+    /// Returns `false` on targets which employ e.g. the `IPV6_DONTFRAG` socket option.
+    pub fn may_fragment(&self) -> bool {
+        let guard = self.socket.read().unwrap();
+        guard.may_fragment()
+    }
+
+    /// The maximum number of segments which can be transmitted if a platform
+    /// supports Generic Send Offload (GSO).
+    ///
+    /// This is 1 if the platform doesn't support GSO. Subject to change if errors are detected
+    /// while using GSO.
+    pub fn max_gso_segments(&self) -> usize {
+        let guard = self.socket.read().unwrap();
+        guard.max_gso_segments()
+    }
+
+    /// The number of segments to read when GRO is enabled. Used as a factor to
+    /// compute the receive buffer size.
+    ///
+    /// Returns 1 if the platform doesn't support GRO.
+    pub fn gro_segments(&self) -> usize {
+        let guard = self.socket.read().unwrap();
+        guard.gro_segments()
+    }
+}
+
+/// Receive future.
+#[derive(Debug)]
+pub struct RecvFut<'a, 'b> {
+    socket: &'b UdpSocket,
+    buffer: &'a mut [u8],
+}
+
+impl Future for RecvFut<'_, '_> {
+    type Output = io::Result<usize>;
+
+    fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
+        let Self { socket, buffer } = &mut *self;
+
+        loop {
+            if let Err(err) = socket.maybe_rebind() {
+                return Poll::Ready(Err(err));
+            }
+
+            let guard = futures_lite::ready!(socket.poll_read_socket(&socket.recv_waker, cx));
+            let (inner_socket, _state) = guard.try_get_connected()?;
+
+            match inner_socket.poll_recv_ready(cx) {
+                Poll::Pending => {
+                    self.socket.recv_waker.register(cx.waker());
+                    return Poll::Pending;
+                }
+                Poll::Ready(Ok(())) => {
+                    let res = inner_socket.try_recv(buffer);
+                    if let Err(err) = res {
+                        if err.kind() == io::ErrorKind::WouldBlock {
+                            continue;
+                        }
+                        if let Some(err) = socket.handle_read_error(err) {
+                            return Poll::Ready(Err(err));
+                        }
+                        continue;
+                    }
+                    return Poll::Ready(res);
+                }
+                Poll::Ready(Err(err)) => {
+                    if let Some(err) = socket.handle_read_error(err) {
+                        return Poll::Ready(Err(err));
+                    }
+                    continue;
+                }
+            }
+        }
+    }
+}
+
+/// Receive-from future.
+#[derive(Debug)]
+pub struct RecvFromFut<'a, 'b> {
+    socket: &'b UdpSocket,
+    buffer: &'a mut [u8],
+}
+
+impl Future for RecvFromFut<'_, '_> {
+    type Output = io::Result<(usize, SocketAddr)>;
+
+    fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
+        let Self { socket, buffer } = &mut *self;
+
+        loop {
+            if let Err(err) = socket.maybe_rebind() {
+                return Poll::Ready(Err(err));
+            }
+
+            let guard = futures_lite::ready!(socket.poll_read_socket(&socket.recv_waker, cx));
+            let (inner_socket, _state) = guard.try_get_connected()?;
+
+            match inner_socket.poll_recv_ready(cx) {
+                Poll::Pending => {
+                    self.socket.recv_waker.register(cx.waker());
+                    return Poll::Pending;
+                }
+                Poll::Ready(Ok(())) => {
+                    let res = inner_socket.try_recv_from(buffer);
+                    if let Err(err) = res {
+                        if err.kind() == io::ErrorKind::WouldBlock {
+                            continue;
+                        }
+                        if let Some(err) = socket.handle_read_error(err) {
+                            return Poll::Ready(Err(err));
+                        }
+                        continue;
+                    }
+                    return Poll::Ready(res);
+                }
+                Poll::Ready(Err(err)) => {
+                    if let Some(err) = socket.handle_read_error(err) {
+                        return Poll::Ready(Err(err));
+                    }
+                    continue;
+                }
+            }
+        }
+    }
+}
+
+/// Send future.
+#[derive(Debug)]
+pub struct SendFut<'a, 'b> {
+    socket: &'b UdpSocket,
+    buffer: &'a [u8],
+}
+
+impl Future for SendFut<'_, '_> {
+    type Output = io::Result<usize>;
+
+    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
+        loop {
+            if let Err(err) = self.socket.maybe_rebind() {
+                return Poll::Ready(Err(err));
+            }
+
+            let guard =
+                futures_lite::ready!(self.socket.poll_read_socket(&self.socket.send_waker, cx));
+            let (socket, _state) = guard.try_get_connected()?;
+
+            match socket.poll_send_ready(cx) {
+                Poll::Pending => {
+                    self.socket.send_waker.register(cx.waker());
+                    return Poll::Pending;
+                }
+                Poll::Ready(Ok(())) => {
+                    let res = socket.try_send(self.buffer);
+                    if let Err(err) = res {
+                        if err.kind() == io::ErrorKind::WouldBlock {
+                            continue;
+                        }
+                        if let Some(err) = self.socket.handle_write_error(err) {
+                            return Poll::Ready(Err(err));
+                        }
+                        continue;
+                    }
+                    return Poll::Ready(res);
+                }
+                Poll::Ready(Err(err)) => {
+                    if let Some(err) = self.socket.handle_write_error(err) {
+                        return Poll::Ready(Err(err));
+                    }
+                    continue;
+                }
+            }
+        }
+    }
+}
+
+/// Send-to future.
+#[derive(Debug)]
+pub struct SendToFut<'a, 'b> {
+    socket: &'b UdpSocket,
+    buffer: &'a [u8],
+    to: SocketAddr,
+}
+
+impl Future for SendToFut<'_, '_> {
+    type Output = io::Result<usize>;
+
+    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
+        loop {
+            if let Err(err) = self.socket.maybe_rebind() {
+                return Poll::Ready(Err(err));
+            }
+
+            let guard =
+                futures_lite::ready!(self.socket.poll_read_socket(&self.socket.send_waker, cx));
+            let (socket, _state) = guard.try_get_connected()?;
+
+            match socket.poll_send_ready(cx) {
+                Poll::Pending => {
+                    self.socket.send_waker.register(cx.waker());
+                    return Poll::Pending;
+                }
+                Poll::Ready(Ok(())) => {
+                    let res = socket.try_send_to(self.buffer, self.to);
+                    if let Err(err) = res {
+                        if err.kind() == io::ErrorKind::WouldBlock {
+                            continue;
+                        }
+
+                        if let Some(err) = self.socket.handle_write_error(err) {
+                            return Poll::Ready(Err(err));
+                        }
+                        continue;
+                    }
+                    return Poll::Ready(res);
+                }
+                Poll::Ready(Err(err)) => {
+                    if let Some(err) = self.socket.handle_write_error(err) {
+                        return Poll::Ready(Err(err));
+                    }
+                    continue;
+                }
+            }
+        }
+    }
+}
+
+#[derive(Debug)]
+enum SocketState {
+    Connected {
+        socket: tokio::net::UdpSocket,
+        state: quinn_udp::UdpSocketState,
+        /// The addr we are binding to.
+        addr: SocketAddr,
+    },
+    Closed {
+        last_max_gso_segments: usize,
+        last_gro_segments: usize,
+        last_may_fragment: bool,
+    },
+}
+
+impl SocketState {
+    fn try_get_connected(
+        &self,
+    ) -> io::Result<(&tokio::net::UdpSocket, &quinn_udp::UdpSocketState)> {
+        match self {
+            Self::Connected {
+                socket,
+                state,
+                addr: _,
+            } => Ok((socket, state)),
+            Self::Closed { .. } => {
+                warn!("socket closed");
+                Err(io::Error::new(io::ErrorKind::BrokenPipe, "socket closed"))
+            }
+        }
+    }
+
+    fn bind(addr: SocketAddr) -> io::Result<Self> {
         let network = IpFamily::from(addr.ip());
         let socket = socket2::Socket::new(
             network.into(),
             socket2::Type::DGRAM,
             Some(socket2::Protocol::UDP),
-        )
-        .context("socket create")?;
+        )?;

         if let Err(err) = socket.set_recv_buffer_size(SOCKET_BUFFER_SIZE) {
-            warn!(
+            debug!(
                 "failed to set recv_buffer_size to {}: {:?}",
                 SOCKET_BUFFER_SIZE, err
             );
         }
         if let Err(err) = socket.set_send_buffer_size(SOCKET_BUFFER_SIZE) {
-            warn!(
+            debug!(
                 "failed to set send_buffer_size to {}: {:?}",
                 SOCKET_BUFFER_SIZE, err
             );
         }
         if network == IpFamily::V6 {
             // Avoid dualstack
-            socket.set_only_v6(true).context("only IPv6")?;
+            socket.set_only_v6(true)?;
         }

         // Binding must happen before calling quinn, otherwise `local_addr`
         // is not yet available on all OSes.
-        socket.bind(&addr.into()).context("binding")?;
+        socket.bind(&addr.into())?;

         // Ensure nonblocking
-        socket.set_nonblocking(true).context("nonblocking: true")?;
+        socket.set_nonblocking(true)?;
         let socket: std::net::UdpSocket = socket.into();

         // Convert into tokio UdpSocket
-        let socket = tokio::net::UdpSocket::from_std(socket).context("conversion to tokio")?;
-
-        if addr.port() != 0 {
-            let local_addr = socket.local_addr().context("local addr")?;
-            ensure!(
-                local_addr.port() == addr.port(),
-                "wrong port bound: {:?}: wanted: {} got {}",
-                network,
-                addr.port(),
-                local_addr.port(),
-            );
+        let socket = tokio::net::UdpSocket::from_std(socket)?;
+        let socket_ref = quinn_udp::UdpSockRef::from(&socket);
+        let socket_state = quinn_udp::UdpSocketState::new(socket_ref)?;
+
+        let local_addr = socket.local_addr()?;
+        if addr.port() != 0 && local_addr.port() != addr.port() {
+            return Err(io::Error::new(
+                io::ErrorKind::Other,
+                format!(
+                    "wrong port bound: {:?}: wanted: {} got {}",
+                    network,
+                    addr.port(),
+                    local_addr.port(),
+                ),
+            ));
         }
-        Ok(UdpSocket(Some(socket)))
+
+        Ok(Self::Connected {
+            socket,
+            state: socket_state,
+            addr: local_addr,
+        })
     }
-}

-#[cfg(unix)]
-impl std::os::fd::AsFd for UdpSocket {
-    fn as_fd(&self) -> std::os::fd::BorrowedFd<'_> {
-        self.0.as_ref().expect("not dropped").as_fd()
+    fn rebind(&mut self) -> io::Result<()> {
+        let (addr, closed_state) = match self {
+            Self::Connected { state, addr, .. } => {
+                let s = SocketState::Closed {
+                    last_max_gso_segments: state.max_gso_segments(),
+                    last_gro_segments: state.gro_segments(),
+                    last_may_fragment: state.may_fragment(),
+                };
+                (*addr, s)
+            }
+            Self::Closed { .. } => {
+                return Err(io::Error::new(
+                    io::ErrorKind::Other,
+                    "socket is closed and cannot be rebound",
+                ));
+            }
+        };
+        debug!("rebinding {}", addr);
+
+        *self = closed_state;
+        *self = Self::bind(addr)?;
+
+        Ok(())
     }
-}

-#[cfg(windows)]
-impl std::os::windows::io::AsSocket for UdpSocket {
-    fn as_socket(&self) -> std::os::windows::io::BorrowedSocket<'_> {
-        self.0.as_ref().expect("not dropped").as_socket()
+    fn is_closed(&self) -> bool {
+        matches!(self, Self::Closed { .. })
+    }
+
+    fn close(&mut self) -> Option<(tokio::net::UdpSocket, quinn_udp::UdpSocketState)> {
+        match self {
+            Self::Connected { state, .. } => {
+                let s = SocketState::Closed {
+                    last_max_gso_segments: state.max_gso_segments(),
+                    last_gro_segments: state.gro_segments(),
+                    last_may_fragment: state.may_fragment(),
+                };
+                let Self::Connected { socket, state, .. } = std::mem::replace(self, s) else {
+                    unreachable!("just checked");
+                };
+                Some((socket, state))
+            }
+            Self::Closed { .. } => None,
+        }
     }
-}

-impl From<tokio::net::UdpSocket> for UdpSocket {
-    fn from(socket: tokio::net::UdpSocket) -> Self {
-        Self(Some(socket))
+    fn may_fragment(&self) -> bool {
+        match self {
+            Self::Connected { state, .. } => state.may_fragment(),
+            Self::Closed {
+                last_may_fragment, ..
+            } => *last_may_fragment,
+        }
     }
-}

-impl std::ops::Deref for UdpSocket {
-    type Target = tokio::net::UdpSocket;
+    fn max_gso_segments(&self) -> usize {
+        match self {
+            Self::Connected { state, .. } => state.max_gso_segments(),
+            Self::Closed {
+                last_max_gso_segments,
+                ..
+            } => *last_max_gso_segments,
+        }
+    }

-    fn deref(&self) -> &Self::Target {
-        self.0.as_ref().expect("only removed on drop")
+    fn gro_segments(&self) -> usize {
+        match self {
+            Self::Connected { state, .. } => state.gro_segments(),
+            Self::Closed {
+                last_gro_segments, ..
+            } => *last_gro_segments,
+        }
     }
 }

 impl Drop for UdpSocket {
     fn drop(&mut self) {
-        let std_sock = self.0.take().expect("not yet dropped").into_std();
+        trace!("dropping UdpSocket");
+        if let Some((socket, _)) = self.socket.write().unwrap().close() {
+            if let Ok(handle) = tokio::runtime::Handle::try_current() {
+                // No wakeup after dropping the write lock here, since we're getting dropped.
+                // This will be `None` if `close` was called before.
+                let std_sock = socket.into_std();
+                handle.spawn_blocking(move || {
+                    // Calls libc::close, which can block
+                    drop(std_sock);
+                });
+            }
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use anyhow::Context;
+
+    use super::*;
+
+    #[tokio::test]
+    async fn test_reconnect() -> anyhow::Result<()> {
+        let (s_b, mut r_b) = tokio::sync::mpsc::channel(16);
+        let handle_a = tokio::task::spawn(async move {
+            let socket = UdpSocket::bind_local(IpFamily::V4, 0)?;
+            let addr = socket.local_addr()?;
+            s_b.send(addr).await?;
+            println!("socket bound to {:?}", addr);
+
+            let mut buffer = [0u8; 16];
+            for i in 0..100 {
+                println!("-- tick {i}");
+                let read = socket.recv_from(&mut buffer).await;
+                match read {
+                    Ok((count, addr)) => {
+                        println!("got {:?}", &buffer[..count]);
+                        println!("sending {:?} to {:?}", &buffer[..count], addr);
+                        socket.send_to(&buffer[..count], addr).await?;
+                    }
+                    Err(err) => {
+                        eprintln!("error reading: {:?}", err);
+                    }
+                }
+            }
+            socket.close().await;
+            anyhow::Ok(())
+        });
+
+        let socket = UdpSocket::bind_local(IpFamily::V4, 0)?;
+        let first_addr = socket.local_addr()?;
+        println!("socket2 bound to {:?}", socket.local_addr()?);
+
+        let addr = r_b.recv().await.unwrap();
-        // Only spawn_blocking if we are inside a tokio runtime, otherwise we just drop.
-        if let Ok(handle) = tokio::runtime::Handle::try_current() {
-            handle.spawn_blocking(move || {
-                // Calls libc::close, which can block
-                drop(std_sock);
-            });
+        let mut buffer = [0u8; 16];
+        for i in 0u8..100 {
+            println!("round one - {}", i);
+            socket.send_to(&[i][..], addr).await.context("send")?;
+            let (count, from) = socket.recv_from(&mut buffer).await.context("recv")?;
+            assert_eq!(addr, from);
+            assert_eq!(count, 1);
+            assert_eq!(buffer[0], i);
+
+            // check for errors
+            assert!(!socket.is_broken());
+
+            // rebind
+            socket.rebind()?;
+
+            // check that the socket has the same address as before
+            assert_eq!(socket.local_addr()?, first_addr);
         }
+
+        handle_a.await.ok();
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_udp_mark_broken() -> anyhow::Result<()> {
+        let socket_a = UdpSocket::bind_local(IpFamily::V4, 0)?;
+        let addr_a = socket_a.local_addr()?;
+        println!("socket bound to {:?}", addr_a);
+
+        let socket_b = UdpSocket::bind_local(IpFamily::V4, 0)?;
+        let addr_b = socket_b.local_addr()?;
+        println!("socket bound to {:?}", addr_b);
+
+        let handle = tokio::task::spawn(async move {
+            let mut buffer = [0u8; 16];
+            for _ in 0..2 {
+                match socket_b.recv_from(&mut buffer).await {
+                    Ok((count, addr)) => {
+                        println!("got {:?} from {:?}", &buffer[..count], addr);
+                    }
+                    Err(err) => {
+                        eprintln!("error recv: {:?}", err);
+                    }
+                }
+            }
+        });
+        socket_a.send_to(&[0][..], addr_b).await?;
+        socket_a.mark_broken();
+        assert!(socket_a.is_broken());
+        socket_a.send_to(&[0][..], addr_b).await?;
+        assert!(!socket_a.is_broken());
+
+        handle.await?;
+        Ok(())
     }
 }
diff --git a/net-tools/portmapper/src/nat_pmp.rs b/net-tools/portmapper/src/nat_pmp.rs
index a44c4aeb7e..b859729923 100644
--- a/net-tools/portmapper/src/nat_pmp.rs
+++ b/net-tools/portmapper/src/nat_pmp.rs
@@ -51,7 +51,7 @@ impl Mapping {
    ) -> anyhow::Result<Self> {
         // create the socket and send the request
         let socket = UdpSocket::bind_full((local_ip, 0))?;
-        socket.connect((gateway, protocol::SERVER_PORT)).await?;
+        socket.connect((gateway, protocol::SERVER_PORT).into())?;

         let req = Request::Mapping {
             proto: MapProtocol::Udp,
@@ -124,7 +124,7 @@ impl Mapping {

         // create the socket and send the request
         let socket = UdpSocket::bind_full((local_ip, 0))?;
-        socket.connect((gateway, protocol::SERVER_PORT)).await?;
+        socket.connect((gateway, protocol::SERVER_PORT).into())?;

         let req = Request::Mapping {
             proto: MapProtocol::Udp,
@@ -167,7 +167,7 @@ async fn probe_available_fallible(
 ) -> anyhow::Result<Response> {
     // create the socket and send the request
     let socket = UdpSocket::bind_full((local_ip, 0))?;
-    socket.connect((gateway, protocol::SERVER_PORT)).await?;
+    socket.connect((gateway, protocol::SERVER_PORT).into())?;

     let req = Request::ExternalAddress;
     socket.send(&req.encode()).await?;
diff --git a/net-tools/portmapper/src/pcp.rs b/net-tools/portmapper/src/pcp.rs
index 0f2fe789f5..2019bc3ca5 100644
--- a/net-tools/portmapper/src/pcp.rs
+++ b/net-tools/portmapper/src/pcp.rs
@@ -54,7 +54,7 @@ impl Mapping {
     ) -> anyhow::Result<Self> {
         // create the socket and send the request
         let socket = UdpSocket::bind_full((local_ip, 0))?;
-        socket.connect((gateway, protocol::SERVER_PORT)).await?;
+        socket.connect((gateway, protocol::SERVER_PORT).into())?;

         let mut nonce = [0u8; 12];
         rand::thread_rng().fill_bytes(&mut nonce);
@@ -144,7 +144,7 @@ impl Mapping {

         // create the socket and send the request
         let socket = UdpSocket::bind_full((local_ip, 0))?;
-        socket.connect((gateway, protocol::SERVER_PORT)).await?;
+        socket.connect((gateway, protocol::SERVER_PORT).into())?;

         let local_port = local_port.into();
         let req = protocol::Request::mapping(nonce, local_port, local_ip, None, None, 0);
@@ -188,7 +188,7 @@ async fn probe_available_fallible(
 ) -> anyhow::Result<Response> {
     // create the socket and send the request
     let socket = UdpSocket::bind_full((local_ip, 0))?;
-    socket.connect((gateway, protocol::SERVER_PORT)).await?;
+    socket.connect((gateway, protocol::SERVER_PORT).into())?;

     let req = protocol::Request::announce(local_ip.to_ipv6_mapped());
     socket.send(&req.encode()).await?;
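
Reviewer note on the locking scheme: every I/O path in the new `netwatch::UdpSocket` goes through `poll_read_socket`, which only ever calls `try_read` on the `RwLock` and parks the task in an `AtomicWaker` that `rebind()`/`close()` wake via `wake_all()`. The sketch below isolates that register-then-recheck pattern with a hypothetical `Shared<T>` wrapper (not part of this patch): the second `try_read` closes the race where the writer releases the lock between the failed first attempt and the waker registration.

```rust
use std::sync::{RwLock, RwLockReadGuard, TryLockError};
use std::task::{Context, Poll};

use atomic_waker::AtomicWaker;

/// Hypothetical stand-in for the patch's `UdpSocket`/`SocketState` pair.
struct Shared<T> {
    inner: RwLock<T>,
    waker: AtomicWaker,
}

impl<T> Shared<T> {
    /// Non-blocking read-lock acquisition, mirroring `poll_read_socket`.
    fn poll_read(&self, cx: &mut Context<'_>) -> Poll<RwLockReadGuard<'_, T>> {
        match self.inner.try_read() {
            Ok(guard) => Poll::Ready(guard),
            Err(TryLockError::Poisoned(e)) => panic!("lock poisoned: {e}"),
            Err(TryLockError::WouldBlock) => {
                // Register before re-checking: if the writer unlocks and calls
                // `wake` between our two `try_read`s, the registration is
                // already in place and the wakeup is not lost.
                self.waker.register(cx.waker());
                match self.inner.try_read() {
                    Ok(guard) => {
                        // Lock freed in the meantime; drop the registration so
                        // the writer's next `wake` is not spurious.
                        self.waker.take();
                        Poll::Ready(guard)
                    }
                    Err(TryLockError::Poisoned(e)) => panic!("lock poisoned: {e}"),
                    Err(TryLockError::WouldBlock) => Poll::Pending,
                }
            }
        }
    }

    /// Writer side: mutate, release the lock, then wake parked readers
    /// (cf. `rebind`/`close` calling `wake_all`).
    fn with_write(&self, f: impl FnOnce(&mut T)) {
        f(&mut self.inner.write().unwrap());
        self.waker.wake();
    }
}
```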
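And a minimal usage sketch of the resulting API, distilled from the `test_reconnect` test in this patch. The re-export paths (`netwatch::UdpSocket`, `netwatch::IpFamily`) are assumptions on my part; udp.rs itself only shows `use super::IpFamily`.

```rust
// Assumed re-exports; udp.rs only shows `use super::IpFamily`.
use netwatch::{IpFamily, UdpSocket};

#[tokio::main]
async fn main() -> std::io::Result<()> {
    // Bind to an ephemeral localhost port.
    let socket = UdpSocket::bind_local(IpFamily::V4, 0)?;
    let addr = socket.local_addr()?;

    // The I/O methods now return named futures (SendToFut, RecvFromFut, ...).
    socket.send_to(b"ping", addr).await?;
    let mut buf = [0u8; 16];
    let (n, from) = socket.recv_from(&mut buf).await?;
    assert_eq!((&buf[..n], from), (&b"ping"[..], addr));

    // After a network change the socket can be rebound in place; per
    // `SocketState::rebind`, the previously bound address (port included)
    // is reused, which is what `test_reconnect` asserts.
    socket.rebind()?;
    assert_eq!(socket.local_addr()?, addr);

    // `close` waits for the potentially blocking `libc::close` to finish.
    socket.close().await;
    Ok(())
}
```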