From a5a7526d28518785c53215717e04f4e7d1418aff Mon Sep 17 00:00:00 2001 From: dignifiedquire Date: Sat, 16 Nov 2024 13:47:48 +0100 Subject: [PATCH 01/38] wip: rebinding socket functionality to deal with broken pipes --- iroh-net-report/src/reportgen/hairpin.rs | 2 +- iroh-net/src/magicsock/udp_conn.rs | 16 +- net-tools/netwatch/src/udp.rs | 352 +++++++++++++++++++---- net-tools/portmapper/src/nat_pmp.rs | 6 +- net-tools/portmapper/src/pcp.rs | 6 +- 5 files changed, 312 insertions(+), 70 deletions(-) diff --git a/iroh-net-report/src/reportgen/hairpin.rs b/iroh-net-report/src/reportgen/hairpin.rs index dc730a7c9a..17fd49e4f5 100644 --- a/iroh-net-report/src/reportgen/hairpin.rs +++ b/iroh-net-report/src/reportgen/hairpin.rs @@ -121,7 +121,7 @@ impl Actor { .context("net_report actor gone")?; msg_response_rx.await.context("net_report actor died")?; - if let Err(err) = socket.send_to(&stun::request(txn), dst).await { + if let Err(err) = socket.send_to(&stun::request(txn), dst.into()).await { warn!(%dst, "failed to send hairpin check"); return Err(err.into()); } diff --git a/iroh-net/src/magicsock/udp_conn.rs b/iroh-net/src/magicsock/udp_conn.rs index 2c23d44f5b..3ab6412d1a 100644 --- a/iroh-net/src/magicsock/udp_conn.rs +++ b/iroh-net/src/magicsock/udp_conn.rs @@ -29,7 +29,10 @@ impl UdpConn { pub(super) fn bind(addr: SocketAddr) -> anyhow::Result { let sock = bind(addr)?; - let state = quinn_udp::UdpSocketState::new(quinn_udp::UdpSockRef::from(&sock))?; + let state = sock.with_socket(move |socket| { + quinn_udp::UdpSocketState::new(quinn_udp::UdpSockRef::from(socket)) + })?; + Ok(Self { io: Arc::new(sock), inner: Arc::new(state), @@ -61,8 +64,10 @@ impl AsyncUdpSocket for UdpConn { fn try_send(&self, transmit: &Transmit<'_>) -> io::Result<()> { self.io.try_io(Interest::WRITABLE, || { - let sock_ref = UdpSockRef::from(&self.io); - self.inner.send(sock_ref, transmit) + self.io.with_socket(|io| { + let sock_ref = UdpSockRef::from(io); + self.inner.send(sock_ref, transmit) + }) }) } @@ -73,9 +78,10 @@ impl AsyncUdpSocket for UdpConn { meta: &mut [quinn_udp::RecvMeta], ) -> Poll> { loop { - ready!(self.io.poll_recv_ready(cx))?; + ready!(self.io.with_socket(|io| io.poll_recv_ready(cx)))?; if let Ok(res) = self.io.try_io(Interest::READABLE, || { - self.inner.recv(Arc::as_ref(&self.io).into(), bufs, meta) + self.io + .with_socket(|io| self.inner.recv(io.into(), bufs, meta)) }) { for meta in meta.iter().take(res) { trace!( diff --git a/net-tools/netwatch/src/udp.rs b/net-tools/netwatch/src/udp.rs index 3aba36277f..b90944019c 100644 --- a/net-tools/netwatch/src/udp.rs +++ b/net-tools/netwatch/src/udp.rs @@ -1,4 +1,10 @@ -use std::net::SocketAddr; +use std::{ + future::Future, + net::SocketAddr, + pin::Pin, + sync::{atomic::AtomicBool, Arc, RwLock}, + task::Poll, +}; use anyhow::{ensure, Context, Result}; use tracing::warn; @@ -8,7 +14,14 @@ use super::IpFamily; /// Wrapper around a tokio UDP socket that handles the fact that /// on drop `libc::close` can block for UDP sockets. #[derive(Debug)] -pub struct UdpSocket(Option); +pub struct UdpSocket { + // TODO: can we drop the Arc and use lifetimes in the futures? + socket: Arc>>, + /// The addr we are binding to. + addr: SocketAddr, + /// Set to true, when an error occured, that means we need to rebind the socket. + is_broken: AtomicBool, +} /// UDP socket read/write buffer size (7MB). The value of 7MB is chosen as it /// is the max supported by a default configuration of macOS. Some platforms will silently clamp the value. @@ -51,93 +64,316 @@ impl UdpSocket { Self::bind_raw(addr) } + /// Is the socket broken and needs a rebind? + pub fn needs_rebind(&self) -> bool { + self.is_broken.load(std::sync::atomic::Ordering::SeqCst) + } + + /// Marks this socket as needing a rebind + pub fn mark_broken(&self) { + self.is_broken + .store(true, std::sync::atomic::Ordering::SeqCst); + } + + /// Rebind the underlying socket. + pub async fn rebind(&self) -> Result<()> { + // Remove old socket + { + let mut guard = self.socket.write().unwrap(); + let std_sock = guard.take().expect("not yet dropped").into_std(); + tokio::runtime::Handle::current() + .spawn_blocking(move || { + // Calls libc::close, which can block + drop(std_sock); + }) + .await?; + } + + // Prepare new socket + let new_socket = inner_bind(self.addr)?; + + // Insert new socket + self.socket.write().unwrap().replace(new_socket); + + // Clear errors + self.is_broken + .store(false, std::sync::atomic::Ordering::SeqCst); + + Ok(()) + } + fn bind_raw(addr: impl Into) -> Result { let addr = addr.into(); - let network = IpFamily::from(addr.ip()); - let socket = socket2::Socket::new( - network.into(), - socket2::Type::DGRAM, - Some(socket2::Protocol::UDP), - ) - .context("socket create")?; - - if let Err(err) = socket.set_recv_buffer_size(SOCKET_BUFFER_SIZE) { - warn!( - "failed to set recv_buffer_size to {}: {:?}", - SOCKET_BUFFER_SIZE, err - ); + let socket = inner_bind(addr)?; + + Ok(UdpSocket { + socket: Arc::new(RwLock::new(Some(socket))), + addr, + is_broken: AtomicBool::new(false), + }) + } + + /// Use the socket + pub fn with_socket(&self, f: F) -> T + where + F: FnOnce(&tokio::net::UdpSocket) -> T, + { + let guard = self.socket.read().unwrap(); + let socket = guard.as_ref().expect("missing socket"); + f(socket) + } + + pub fn try_io( + &self, + interest: tokio::io::Interest, + f: impl FnOnce() -> std::io::Result, + ) -> std::io::Result { + let guard = self.socket.read().unwrap(); + let socket = guard.as_ref().expect("missing socket"); + socket.try_io(interest, f) + } + + pub fn writable(&self) -> WritableFut { + WritableFut { + socket: self.socket.clone(), } - if let Err(err) = socket.set_send_buffer_size(SOCKET_BUFFER_SIZE) { - warn!( - "failed to set send_buffer_size to {}: {:?}", - SOCKET_BUFFER_SIZE, err - ); + } + + /// TODO + pub fn recv<'a>(&self, buffer: &'a mut [u8]) -> RecvFut<'a> { + RecvFut { + socket: self.socket.clone(), + buffer, + } + } + + /// TODO + pub fn recv_from<'a>(&self, buffer: &'a mut [u8]) -> RecvFromFut<'a> { + RecvFromFut { + socket: self.socket.clone(), + buffer, + } + } + + /// TODO + pub fn send<'a>(&self, buffer: &'a [u8]) -> SendFut<'a> { + SendFut { + socket: self.socket.clone(), + buffer, } - if network == IpFamily::V6 { - // Avoid dualstack - socket.set_only_v6(true).context("only IPv6")?; + } + + /// TODO + pub fn send_to<'a>(&self, buffer: &'a [u8], to: SocketAddr) -> SendToFut<'a> { + SendToFut { + socket: self.socket.clone(), + buffer, + to, } + } + + /// TODO + pub fn connect(&self, addr: SocketAddr) -> std::io::Result<()> { + let mut guard = self.socket.write().unwrap(); + // dance around to make non async connect work + let socket_tokio = guard.take().expect("missing socket"); + let socket_std = socket_tokio.into_std()?; + socket_std.connect(addr)?; + let socket_tokio = tokio::net::UdpSocket::from_std(socket_std)?; + guard.replace(socket_tokio); + Ok(()) + } + + pub fn local_addr(&self) -> std::io::Result { + let guard = self.socket.read().unwrap(); + let socket = guard.as_ref().expect("missing socket"); + socket.local_addr() + } +} + +/// Receive future +#[derive(Debug)] +pub struct RecvFut<'a> { + socket: Arc>>, + buffer: &'a mut [u8], +} + +impl<'a> Future for RecvFut<'a> { + type Output = std::io::Result; - // Binding must happen before calling quinn, otherwise `local_addr` - // is not yet available on all OSes. - socket.bind(&addr.into()).context("binding")?; + fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { + let Self { socket, buffer } = &mut *self; + let guard = socket.read().unwrap(); + let socket = guard.as_ref().expect("missing socket"); - // Ensure nonblocking - socket.set_nonblocking(true).context("nonblocking: true")?; + match socket.poll_recv_ready(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(Ok(())) => { + let res = socket.try_recv(buffer); + Poll::Ready(res) + } + Poll::Ready(Err(err)) => Poll::Ready(Err(err)), + } + } +} + +/// Receive future +#[derive(Debug)] +pub struct RecvFromFut<'a> { + socket: Arc>>, + buffer: &'a mut [u8], +} - let socket: std::net::UdpSocket = socket.into(); +impl<'a> Future for RecvFromFut<'a> { + type Output = std::io::Result<(usize, SocketAddr)>; - // Convert into tokio UdpSocket - let socket = tokio::net::UdpSocket::from_std(socket).context("conversion to tokio")?; + fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { + let Self { socket, buffer } = &mut *self; + let guard = socket.read().unwrap(); + let socket = guard.as_ref().expect("missing socket"); - if addr.port() != 0 { - let local_addr = socket.local_addr().context("local addr")?; - ensure!( - local_addr.port() == addr.port(), - "wrong port bound: {:?}: wanted: {} got {}", - network, - addr.port(), - local_addr.port(), - ); + match socket.poll_recv_ready(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(Ok(())) => { + let res = socket.try_recv_from(buffer); + Poll::Ready(res) + } + Poll::Ready(Err(err)) => Poll::Ready(Err(err)), } - Ok(UdpSocket(Some(socket))) } } -#[cfg(unix)] -impl std::os::fd::AsFd for UdpSocket { - fn as_fd(&self) -> std::os::fd::BorrowedFd<'_> { - self.0.as_ref().expect("not dropped").as_fd() +/// Writable future +#[derive(Debug)] +pub struct WritableFut { + socket: Arc>>, +} + +impl Future for WritableFut { + type Output = std::io::Result<()>; + + fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { + let guard = self.socket.read().unwrap(); + let socket = guard.as_ref().expect("missing socket"); + + socket.poll_send_ready(cx) } } -#[cfg(windows)] -impl std::os::windows::io::AsSocket for UdpSocket { - fn as_socket(&self) -> std::os::windows::io::BorrowedSocket<'_> { - self.0.as_ref().expect("not dropped").as_socket() +/// Send future +#[derive(Debug)] +pub struct SendFut<'a> { + socket: Arc>>, + buffer: &'a [u8], +} + +impl<'a> Future for SendFut<'a> { + type Output = std::io::Result; + + fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { + let guard = self.socket.read().unwrap(); + let socket = guard.as_ref().expect("missing socket"); + + match socket.poll_send_ready(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(Ok(())) => { + let res = socket.try_send(self.buffer); + Poll::Ready(res) + } + Poll::Ready(Err(err)) => Poll::Ready(Err(err)), + } } } -impl From for UdpSocket { - fn from(socket: tokio::net::UdpSocket) -> Self { - Self(Some(socket)) +/// Send future +#[derive(Debug)] +pub struct SendToFut<'a> { + socket: Arc>>, + buffer: &'a [u8], + to: SocketAddr, +} + +impl<'a> Future for SendToFut<'a> { + type Output = std::io::Result; + + fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { + let guard = self.socket.read().unwrap(); + let socket = guard.as_ref().expect("missing socket"); + + match socket.poll_send_ready(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(Ok(())) => { + let res = socket.try_send_to(self.buffer, self.to); + Poll::Ready(res) + } + Poll::Ready(Err(err)) => Poll::Ready(Err(err)), + } } } -impl std::ops::Deref for UdpSocket { - type Target = tokio::net::UdpSocket; +fn inner_bind(addr: SocketAddr) -> Result { + let network = IpFamily::from(addr.ip()); + let socket = socket2::Socket::new( + network.into(), + socket2::Type::DGRAM, + Some(socket2::Protocol::UDP), + ) + .context("socket create")?; + + if let Err(err) = socket.set_recv_buffer_size(SOCKET_BUFFER_SIZE) { + warn!( + "failed to set recv_buffer_size to {}: {:?}", + SOCKET_BUFFER_SIZE, err + ); + } + if let Err(err) = socket.set_send_buffer_size(SOCKET_BUFFER_SIZE) { + warn!( + "failed to set send_buffer_size to {}: {:?}", + SOCKET_BUFFER_SIZE, err + ); + } + if network == IpFamily::V6 { + // Avoid dualstack + socket.set_only_v6(true).context("only IPv6")?; + } + + // Binding must happen before calling quinn, otherwise `local_addr` + // is not yet available on all OSes. + socket.bind(&addr.into()).context("binding")?; + + // Ensure nonblocking + socket.set_nonblocking(true).context("nonblocking: true")?; - fn deref(&self) -> &Self::Target { - self.0.as_ref().expect("only removed on drop") + let socket: std::net::UdpSocket = socket.into(); + + // Convert into tokio UdpSocket + let socket = tokio::net::UdpSocket::from_std(socket).context("conversion to tokio")?; + + if addr.port() != 0 { + let local_addr = socket.local_addr().context("local addr")?; + ensure!( + local_addr.port() == addr.port(), + "wrong port bound: {:?}: wanted: {} got {}", + network, + addr.port(), + local_addr.port(), + ); } + + Ok(socket) } impl Drop for UdpSocket { fn drop(&mut self) { - let std_sock = self.0.take().expect("not yet dropped").into_std(); - // Only spawn_blocking if we are inside a tokio runtime, otherwise we just drop. if let Ok(handle) = tokio::runtime::Handle::try_current() { + let std_sock = self + .socket + .write() + .unwrap() + .take() + .expect("not yet dropped") + .into_std(); handle.spawn_blocking(move || { // Calls libc::close, which can block drop(std_sock); diff --git a/net-tools/portmapper/src/nat_pmp.rs b/net-tools/portmapper/src/nat_pmp.rs index a44c4aeb7e..b859729923 100644 --- a/net-tools/portmapper/src/nat_pmp.rs +++ b/net-tools/portmapper/src/nat_pmp.rs @@ -51,7 +51,7 @@ impl Mapping { ) -> anyhow::Result { // create the socket and send the request let socket = UdpSocket::bind_full((local_ip, 0))?; - socket.connect((gateway, protocol::SERVER_PORT)).await?; + socket.connect((gateway, protocol::SERVER_PORT).into())?; let req = Request::Mapping { proto: MapProtocol::Udp, @@ -124,7 +124,7 @@ impl Mapping { // create the socket and send the request let socket = UdpSocket::bind_full((local_ip, 0))?; - socket.connect((gateway, protocol::SERVER_PORT)).await?; + socket.connect((gateway, protocol::SERVER_PORT).into())?; let req = Request::Mapping { proto: MapProtocol::Udp, @@ -167,7 +167,7 @@ async fn probe_available_fallible( ) -> anyhow::Result { // create the socket and send the request let socket = UdpSocket::bind_full((local_ip, 0))?; - socket.connect((gateway, protocol::SERVER_PORT)).await?; + socket.connect((gateway, protocol::SERVER_PORT).into())?; let req = Request::ExternalAddress; socket.send(&req.encode()).await?; diff --git a/net-tools/portmapper/src/pcp.rs b/net-tools/portmapper/src/pcp.rs index 0f2fe789f5..2019bc3ca5 100644 --- a/net-tools/portmapper/src/pcp.rs +++ b/net-tools/portmapper/src/pcp.rs @@ -54,7 +54,7 @@ impl Mapping { ) -> anyhow::Result { // create the socket and send the request let socket = UdpSocket::bind_full((local_ip, 0))?; - socket.connect((gateway, protocol::SERVER_PORT)).await?; + socket.connect((gateway, protocol::SERVER_PORT).into())?; let mut nonce = [0u8; 12]; rand::thread_rng().fill_bytes(&mut nonce); @@ -144,7 +144,7 @@ impl Mapping { // create the socket and send the request let socket = UdpSocket::bind_full((local_ip, 0))?; - socket.connect((gateway, protocol::SERVER_PORT)).await?; + socket.connect((gateway, protocol::SERVER_PORT).into())?; let local_port = local_port.into(); let req = protocol::Request::mapping(nonce, local_port, local_ip, None, None, 0); @@ -188,7 +188,7 @@ async fn probe_available_fallible( ) -> anyhow::Result { // create the socket and send the request let socket = UdpSocket::bind_full((local_ip, 0))?; - socket.connect((gateway, protocol::SERVER_PORT)).await?; + socket.connect((gateway, protocol::SERVER_PORT).into())?; let req = protocol::Request::announce(local_ip.to_ipv6_mapped()); socket.send(&req.encode()).await?; From e435086900f326eadf3e6956ed55285ae0dd91d0 Mon Sep 17 00:00:00 2001 From: dignifiedquire Date: Mon, 18 Nov 2024 11:17:39 +0100 Subject: [PATCH 02/38] wip test --- net-tools/netwatch/src/udp.rs | 198 +++++++++++++++++++++++++++++----- 1 file changed, 172 insertions(+), 26 deletions(-) diff --git a/net-tools/netwatch/src/udp.rs b/net-tools/netwatch/src/udp.rs index b90944019c..91b7ce4069 100644 --- a/net-tools/netwatch/src/udp.rs +++ b/net-tools/netwatch/src/udp.rs @@ -1,5 +1,6 @@ use std::{ future::Future, + io::ErrorKind, net::SocketAddr, pin::Pin, sync::{atomic::AtomicBool, Arc, RwLock}, @@ -24,7 +25,7 @@ pub struct UdpSocket { } /// UDP socket read/write buffer size (7MB). The value of 7MB is chosen as it -/// is the max supported by a default configuration of macOS. Some platforms will silently clamp the value. +/// is the ma supported by a default configuration of macOS. Some platforms will silently clamp the value. const SOCKET_BUFFER_SIZE: usize = 7 << 20; impl UdpSocket { /// Bind only Ipv4 on any interface. @@ -206,13 +207,29 @@ impl<'a> Future for RecvFut<'a> { let guard = socket.read().unwrap(); let socket = guard.as_ref().expect("missing socket"); - match socket.poll_recv_ready(cx) { - Poll::Pending => Poll::Pending, - Poll::Ready(Ok(())) => { - let res = socket.try_recv(buffer); - Poll::Ready(res) + loop { + println!("looping"); + match socket.poll_recv_ready(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(Ok(())) => { + let res = socket.try_recv(buffer); + dbg!(&res); + if let Err(err) = res { + if err.kind() == ErrorKind::WouldBlock { + continue; + } + return Poll::Ready(Err(err)); + } + return Poll::Ready(res); + } + Poll::Ready(Err(err)) => { + dbg!(&err); + if err.kind() == ErrorKind::WouldBlock { + continue; + } + return Poll::Ready(Err(err)); + } } - Poll::Ready(Err(err)) => Poll::Ready(Err(err)), } } } @@ -227,18 +244,29 @@ pub struct RecvFromFut<'a> { impl<'a> Future for RecvFromFut<'a> { type Output = std::io::Result<(usize, SocketAddr)>; - fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { + fn poll(mut self: Pin<&mut Self>, c: &mut std::task::Context<'_>) -> Poll { let Self { socket, buffer } = &mut *self; let guard = socket.read().unwrap(); let socket = guard.as_ref().expect("missing socket"); - match socket.poll_recv_ready(cx) { + match socket.poll_recv_ready(c) { Poll::Pending => Poll::Pending, Poll::Ready(Ok(())) => { let res = socket.try_recv_from(buffer); + if let Err(err) = res { + if err.kind() == ErrorKind::WouldBlock { + return Poll::Pending; + } + return Poll::Ready(Err(err)); + } Poll::Ready(res) } - Poll::Ready(Err(err)) => Poll::Ready(Err(err)), + Poll::Ready(Err(err)) => { + if err.kind() == ErrorKind::WouldBlock { + return Poll::Pending; + } + Poll::Ready(Err(err)) + } } } } @@ -252,11 +280,11 @@ pub struct WritableFut { impl Future for WritableFut { type Output = std::io::Result<()>; - fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { + fn poll(self: Pin<&mut Self>, c: &mut std::task::Context<'_>) -> Poll { let guard = self.socket.read().unwrap(); let socket = guard.as_ref().expect("missing socket"); - socket.poll_send_ready(cx) + socket.poll_send_ready(c) } } @@ -270,17 +298,25 @@ pub struct SendFut<'a> { impl<'a> Future for SendFut<'a> { type Output = std::io::Result; - fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { + fn poll(self: Pin<&mut Self>, c: &mut std::task::Context<'_>) -> Poll { let guard = self.socket.read().unwrap(); let socket = guard.as_ref().expect("missing socket"); - match socket.poll_send_ready(cx) { - Poll::Pending => Poll::Pending, - Poll::Ready(Ok(())) => { - let res = socket.try_send(self.buffer); - Poll::Ready(res) + loop { + match socket.poll_send_ready(c) { + Poll::Pending => return Poll::Pending, + Poll::Ready(Ok(())) => { + let res = socket.try_send(self.buffer); + if let Err(err) = res { + if err.kind() == ErrorKind::WouldBlock { + continue; + } + return Poll::Ready(Err(err)); + } + return Poll::Ready(res); + } + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), } - Poll::Ready(Err(err)) => Poll::Ready(Err(err)), } } } @@ -296,17 +332,32 @@ pub struct SendToFut<'a> { impl<'a> Future for SendToFut<'a> { type Output = std::io::Result; - fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { + fn poll(self: Pin<&mut Self>, c: &mut std::task::Context<'_>) -> Poll { let guard = self.socket.read().unwrap(); let socket = guard.as_ref().expect("missing socket"); - match socket.poll_send_ready(cx) { - Poll::Pending => Poll::Pending, - Poll::Ready(Ok(())) => { - let res = socket.try_send_to(self.buffer, self.to); - Poll::Ready(res) + println!("sending to: {:?}", self.to); + loop { + match dbg!(socket.poll_send_ready(c)) { + Poll::Pending => return Poll::Pending, + Poll::Ready(Ok(())) => { + let res = socket.try_send_to(self.buffer, self.to); + dbg!(&res); + if let Err(err) = res { + if err.kind() == ErrorKind::WouldBlock { + continue; + } + return Poll::Ready(Err(err)); + } + return Poll::Ready(res); + } + Poll::Ready(Err(err)) => { + if err.kind() == ErrorKind::WouldBlock { + continue; + } + return Poll::Ready(Err(err)); + } } - Poll::Ready(Err(err)) => Poll::Ready(Err(err)), } } } @@ -381,3 +432,98 @@ impl Drop for UdpSocket { } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_reconnect() -> anyhow::Result<()> { + let (s_a, mut r_a) = tokio::sync::mpsc::channel(16); + let (s_b, mut r_b) = tokio::sync::mpsc::channel(16); + let handle_a = tokio::task::spawn(async move { + let socket = UdpSocket::bind_local(IpFamily::V4, 0)?; + let addr = socket.local_addr()?; + s_b.send(addr).await?; + println!("socket bound to {:?}", addr); + + let mut buffer = [0u8; 16]; + loop { + tokio::select! { + biased; + + Some(_) = r_a.recv() => { + println!("disconnecting"); + break; + } + read = socket.recv_from(&mut buffer) => { + match read { + Ok((count, addr)) => { + println!("got {:?}", &buffer[..count]); + println!("sending {:?} to {:?}", &buffer[..count], addr); + socket.send_to(&buffer[..count], addr).await?; + } + Err(err) => { + eprintln!("error reading: {:?}", err); + } + } + } + } + } + + r_a.recv().await.unwrap(); + // restart after the second message + println!("reconnecting"); + loop { + match socket.recv(&mut buffer).await { + Ok(count) => { + println!("got {:?}", &buffer[..count]); + } + Err(err) => { + eprintln!("error reading: {:?}", err); + } + } + } + + anyhow::Ok(()) + }); + + let socket = UdpSocket::bind_local(IpFamily::V4, 0)?; + println!("socket2 bound to {:?}", socket.local_addr()?); + let addr = r_b.recv().await.unwrap(); + + socket.connect(addr)?; + let mut buffer = [0u8; 16]; + for i in 0u8..100 { + println!("round one - {}", i); + socket.send(&[i][..]).await.context("send")?; + let count = socket.recv(&mut buffer).await.context("recv")?; + assert_eq!(count, 1); + assert_eq!(buffer[0], i); + } + + // interrupt + s_a.send(()).await?; + + // keep sending, should fail + for i in 0u8..10 { + let res = socket.send(&[i][..]).await; + println!("send: {:?}", res); + assert!(res.is_err()); + } + // restart + s_a.send(()).await?; + // keep sending, should succeed + for i in 0u8..10 { + socket.send(&[i][..]).await?; + let count = socket.recv(&mut buffer).await?; + assert_eq!(count, 1); + assert_eq!(buffer[0], i); + } + + handle_a.abort(); + handle_a.await.ok(); + + Ok(()) + } +} From c343e0e75eef9bfff097cf8cd542d2a5bca80a85 Mon Sep 17 00:00:00 2001 From: dignifiedquire Date: Mon, 18 Nov 2024 13:30:38 +0100 Subject: [PATCH 03/38] simplify iopoller --- iroh-net/src/magicsock/udp_conn.rs | 53 ++++-------------------------- 1 file changed, 6 insertions(+), 47 deletions(-) diff --git a/iroh-net/src/magicsock/udp_conn.rs b/iroh-net/src/magicsock/udp_conn.rs index 3ab6412d1a..5470c44f56 100644 --- a/iroh-net/src/magicsock/udp_conn.rs +++ b/iroh-net/src/magicsock/udp_conn.rs @@ -1,6 +1,5 @@ use std::{ fmt::Debug, - future::Future, io, net::SocketAddr, pin::Pin, @@ -52,13 +51,8 @@ impl UdpConn { impl AsyncUdpSocket for UdpConn { fn create_io_poller(self: Arc) -> Pin> { - let sock = self.io.clone(); Box::pin(IoPoller { - next_waiter: move || { - let sock = sock.clone(); - async move { sock.writable().await } - }, - waiter: None, + io: self.io.clone(), }) } @@ -153,49 +147,14 @@ fn bind(mut addr: SocketAddr) -> anyhow::Result { } /// Poller for when the socket is writable. -/// -/// The tricky part is that we only have `tokio::net::UdpSocket::writable()` to create the -/// waiter we need, which does not return a named future type. In order to be able to store -/// this waiter in a struct without boxing we need to specify the future itself as a type -/// parameter, which we can only do if we introduce a second type parameter which returns -/// the future. So we end up with a function which we do not need, but it makes the types -/// work. -#[derive(derive_more::Debug)] -#[pin_project::pin_project] -struct IoPoller -where - F: Fn() -> Fut + Send + Sync + 'static, - Fut: Future> + Send + Sync + 'static, -{ - /// Function which can create a new waiter if there is none. - #[debug("next_waiter")] - next_waiter: F, - /// The waiter which tells us when the socket is writable. - #[debug("waiter")] - #[pin] - waiter: Option, +#[derive(Debug)] +struct IoPoller { + io: Arc, } -impl quinn::UdpPoller for IoPoller -where - F: Fn() -> Fut + Send + Sync + 'static, - Fut: Future> + Send + Sync + 'static, -{ +impl quinn::UdpPoller for IoPoller { fn poll_writable(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - let mut this = self.project(); - if this.waiter.is_none() { - this.waiter.set(Some((this.next_waiter)())); - } - let result = this - .waiter - .as_mut() - .as_pin_mut() - .expect("just set") - .poll(cx); - if result.is_ready() { - this.waiter.set(None); - } - result + self.io.with_socket(|io| io.poll_send_ready(cx)) } } From df8c966cba596e5f84a5f73cc8c955314087cf30 Mon Sep 17 00:00:00 2001 From: dignifiedquire Date: Mon, 18 Nov 2024 13:56:49 +0100 Subject: [PATCH 04/38] fixup basic test and impl --- net-tools/netwatch/src/udp.rs | 176 ++++++++++++++-------------------- 1 file changed, 74 insertions(+), 102 deletions(-) diff --git a/net-tools/netwatch/src/udp.rs b/net-tools/netwatch/src/udp.rs index 91b7ce4069..52b0fd30c8 100644 --- a/net-tools/netwatch/src/udp.rs +++ b/net-tools/netwatch/src/udp.rs @@ -66,7 +66,7 @@ impl UdpSocket { } /// Is the socket broken and needs a rebind? - pub fn needs_rebind(&self) -> bool { + pub fn is_broken(&self) -> bool { self.is_broken.load(std::sync::atomic::Ordering::SeqCst) } @@ -77,17 +77,13 @@ impl UdpSocket { } /// Rebind the underlying socket. - pub async fn rebind(&self) -> Result<()> { + pub fn rebind(&self) -> Result<()> { // Remove old socket { let mut guard = self.socket.write().unwrap(); - let std_sock = guard.take().expect("not yet dropped").into_std(); - tokio::runtime::Handle::current() - .spawn_blocking(move || { - // Calls libc::close, which can block - drop(std_sock); - }) - .await?; + let socket = guard.take().expect("not yet dropped"); + + drop(socket); } // Prepare new socket @@ -104,8 +100,12 @@ impl UdpSocket { } fn bind_raw(addr: impl Into) -> Result { - let addr = addr.into(); + let mut addr = addr.into(); let socket = inner_bind(addr)?; + if addr.port() == 0 { + // update to use selected port + addr.set_port(socket.local_addr()?.port()); + } Ok(UdpSocket { socket: Arc::new(RwLock::new(Some(socket))), @@ -185,11 +185,32 @@ impl UdpSocket { Ok(()) } + /// Returns the local address of this socket. pub fn local_addr(&self) -> std::io::Result { let guard = self.socket.read().unwrap(); let socket = guard.as_ref().expect("missing socket"); socket.local_addr() } + + /// Closes the socket, and waits for the underlying `libc::close` call to be finished. + pub async fn close(self) { + let std_sock = self + .socket + .write() + .unwrap() + .take() + .expect("not yet dropped") + .into_std(); + let res = tokio::runtime::Handle::current() + .spawn_blocking(move || { + // Calls libc::close, which can block + drop(std_sock); + }) + .await; + if let Err(err) = res { + warn!("failed to close socket: {:?}", err); + } + } } /// Receive future @@ -222,13 +243,7 @@ impl<'a> Future for RecvFut<'a> { } return Poll::Ready(res); } - Poll::Ready(Err(err)) => { - dbg!(&err); - if err.kind() == ErrorKind::WouldBlock { - continue; - } - return Poll::Ready(Err(err)); - } + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), } } } @@ -244,28 +259,26 @@ pub struct RecvFromFut<'a> { impl<'a> Future for RecvFromFut<'a> { type Output = std::io::Result<(usize, SocketAddr)>; - fn poll(mut self: Pin<&mut Self>, c: &mut std::task::Context<'_>) -> Poll { + fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { let Self { socket, buffer } = &mut *self; let guard = socket.read().unwrap(); let socket = guard.as_ref().expect("missing socket"); - match socket.poll_recv_ready(c) { - Poll::Pending => Poll::Pending, - Poll::Ready(Ok(())) => { - let res = socket.try_recv_from(buffer); - if let Err(err) = res { - if err.kind() == ErrorKind::WouldBlock { - return Poll::Pending; + loop { + match dbg!(socket.poll_recv_ready(cx)) { + Poll::Pending => return Poll::Pending, + Poll::Ready(Ok(())) => { + let res = socket.try_recv_from(buffer); + dbg!(&res); + if let Err(err) = res { + if err.kind() == ErrorKind::WouldBlock { + continue; + } + return Poll::Ready(Err(err)); } - return Poll::Ready(Err(err)); - } - Poll::Ready(res) - } - Poll::Ready(Err(err)) => { - if err.kind() == ErrorKind::WouldBlock { - return Poll::Pending; + return Poll::Ready(res); } - Poll::Ready(Err(err)) + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), } } } @@ -351,12 +364,7 @@ impl<'a> Future for SendToFut<'a> { } return Poll::Ready(res); } - Poll::Ready(Err(err)) => { - if err.kind() == ErrorKind::WouldBlock { - continue; - } - return Poll::Ready(Err(err)); - } + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), } } } @@ -418,17 +426,14 @@ impl Drop for UdpSocket { fn drop(&mut self) { // Only spawn_blocking if we are inside a tokio runtime, otherwise we just drop. if let Ok(handle) = tokio::runtime::Handle::try_current() { - let std_sock = self - .socket - .write() - .unwrap() - .take() - .expect("not yet dropped") - .into_std(); - handle.spawn_blocking(move || { - // Calls libc::close, which can block - drop(std_sock); - }); + if let Some(socket) = self.socket.write().unwrap().take() { + // this will be empty if `close` was called before + let std_sock = socket.into_std(); + handle.spawn_blocking(move || { + // Calls libc::close, which can block + drop(std_sock); + }); + } } } } @@ -439,7 +444,6 @@ mod tests { #[tokio::test] async fn test_reconnect() -> anyhow::Result<()> { - let (s_a, mut r_a) = tokio::sync::mpsc::channel(16); let (s_b, mut r_b) = tokio::sync::mpsc::channel(16); let handle_a = tokio::task::spawn(async move { let socket = UdpSocket::bind_local(IpFamily::V4, 0)?; @@ -448,80 +452,48 @@ mod tests { println!("socket bound to {:?}", addr); let mut buffer = [0u8; 16]; - loop { - tokio::select! { - biased; - - Some(_) = r_a.recv() => { - println!("disconnecting"); - break; - } - read = socket.recv_from(&mut buffer) => { - match read { - Ok((count, addr)) => { - println!("got {:?}", &buffer[..count]); - println!("sending {:?} to {:?}", &buffer[..count], addr); - socket.send_to(&buffer[..count], addr).await?; - } - Err(err) => { - eprintln!("error reading: {:?}", err); - } - } - } - } - } - - r_a.recv().await.unwrap(); - // restart after the second message - println!("reconnecting"); - loop { - match socket.recv(&mut buffer).await { - Ok(count) => { + for i in 0..100 { + println!("-- tick {i}"); + let read = socket.recv_from(&mut buffer).await; + match read { + Ok((count, addr)) => { println!("got {:?}", &buffer[..count]); + println!("sending {:?} to {:?}", &buffer[..count], addr); + socket.send_to(&buffer[..count], addr).await?; } Err(err) => { eprintln!("error reading: {:?}", err); } } } - + socket.close().await; anyhow::Ok(()) }); let socket = UdpSocket::bind_local(IpFamily::V4, 0)?; + let first_addr = socket.local_addr()?; println!("socket2 bound to {:?}", socket.local_addr()?); let addr = r_b.recv().await.unwrap(); - socket.connect(addr)?; let mut buffer = [0u8; 16]; for i in 0u8..100 { println!("round one - {}", i); - socket.send(&[i][..]).await.context("send")?; - let count = socket.recv(&mut buffer).await.context("recv")?; + socket.send_to(&[i][..], addr).await.context("send")?; + let (count, from) = socket.recv_from(&mut buffer).await.context("recv")?; + assert_eq!(addr, from); assert_eq!(count, 1); assert_eq!(buffer[0], i); - } - // interrupt - s_a.send(()).await?; + // check for errors + assert!(!socket.is_broken()); - // keep sending, should fail - for i in 0u8..10 { - let res = socket.send(&[i][..]).await; - println!("send: {:?}", res); - assert!(res.is_err()); - } - // restart - s_a.send(()).await?; - // keep sending, should succeed - for i in 0u8..10 { - socket.send(&[i][..]).await?; - let count = socket.recv(&mut buffer).await?; - assert_eq!(count, 1); - assert_eq!(buffer[0], i); + // rebind + socket.rebind()?; + + // check that the socket has the same address as before + assert_eq!(socket.local_addr()?, first_addr); } - handle_a.abort(); handle_a.await.ok(); Ok(()) From 02ff112cc9ca461aed1b378cb9a53d2651b76691 Mon Sep 17 00:00:00 2001 From: dignifiedquire Date: Mon, 18 Nov 2024 13:57:53 +0100 Subject: [PATCH 05/38] cleanup --- net-tools/netwatch/src/udp.rs | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/net-tools/netwatch/src/udp.rs b/net-tools/netwatch/src/udp.rs index 52b0fd30c8..ec5b37c224 100644 --- a/net-tools/netwatch/src/udp.rs +++ b/net-tools/netwatch/src/udp.rs @@ -229,12 +229,10 @@ impl<'a> Future for RecvFut<'a> { let socket = guard.as_ref().expect("missing socket"); loop { - println!("looping"); match socket.poll_recv_ready(cx) { Poll::Pending => return Poll::Pending, Poll::Ready(Ok(())) => { let res = socket.try_recv(buffer); - dbg!(&res); if let Err(err) = res { if err.kind() == ErrorKind::WouldBlock { continue; @@ -265,11 +263,10 @@ impl<'a> Future for RecvFromFut<'a> { let socket = guard.as_ref().expect("missing socket"); loop { - match dbg!(socket.poll_recv_ready(cx)) { + match socket.poll_recv_ready(cx) { Poll::Pending => return Poll::Pending, Poll::Ready(Ok(())) => { let res = socket.try_recv_from(buffer); - dbg!(&res); if let Err(err) = res { if err.kind() == ErrorKind::WouldBlock { continue; @@ -345,17 +342,15 @@ pub struct SendToFut<'a> { impl<'a> Future for SendToFut<'a> { type Output = std::io::Result; - fn poll(self: Pin<&mut Self>, c: &mut std::task::Context<'_>) -> Poll { + fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { let guard = self.socket.read().unwrap(); let socket = guard.as_ref().expect("missing socket"); - println!("sending to: {:?}", self.to); loop { - match dbg!(socket.poll_send_ready(c)) { + match socket.poll_send_ready(cx) { Poll::Pending => return Poll::Pending, Poll::Ready(Ok(())) => { let res = socket.try_send_to(self.buffer, self.to); - dbg!(&res); if let Err(err) = res { if err.kind() == ErrorKind::WouldBlock { continue; From 175d69cfb3412fb5d30402392d449ca68234d35c Mon Sep 17 00:00:00 2001 From: dignifiedquire Date: Mon, 18 Nov 2024 14:10:19 +0100 Subject: [PATCH 06/38] reduce cloning and add `rebind` method to UdpConn --- iroh-net/src/magicsock.rs | 37 +++++++++++------------------- iroh-net/src/magicsock/udp_conn.rs | 32 +++++++++++++++++--------- 2 files changed, 35 insertions(+), 34 deletions(-) diff --git a/iroh-net/src/magicsock.rs b/iroh-net/src/magicsock.rs index f4870f8377..b4ed4e1976 100644 --- a/iroh-net/src/magicsock.rs +++ b/iroh-net/src/magicsock.rs @@ -36,7 +36,7 @@ use futures_util::stream::BoxStream; use iroh_base::key::NodeId; use iroh_metrics::{inc, inc_by}; use iroh_relay::protos::stun; -use netwatch::{interfaces, ip::LocalAddresses, netmon}; +use netwatch::{interfaces, ip::LocalAddresses, netmon, UdpSocket}; use quinn::AsyncUdpSocket; use rand::{seq::SliceRandom, Rng, SeedableRng}; use smallvec::{smallvec, SmallVec}; @@ -441,11 +441,8 @@ impl MagicSock { // Right now however we have one single poller behaving the same for each // connection. It checks all paths and returns Poll::Ready as soon as any path is // ready. - let ipv4_poller = Arc::new(self.pconn4.clone()).create_io_poller(); - let ipv6_poller = self - .pconn6 - .as_ref() - .map(|sock| Arc::new(sock.clone()).create_io_poller()); + let ipv4_poller = self.pconn4.create_io_poller(); + let ipv6_poller = self.pconn6.as_ref().map(|sock| sock.create_io_poller()); let relay_sender = self.relay_actor_sender.clone(); Box::pin(IoPoller { ipv4_poller, @@ -1091,7 +1088,6 @@ impl MagicSock { Err(err) if err.kind() == io::ErrorKind::WouldBlock => { // This is the socket .try_send_disco_message_udp used. let sock = self.conn_for_addr(dst)?; - let sock = Arc::new(sock.clone()); let mut poller = sock.create_io_poller(); match poller.as_mut().poll_writable(cx)? { Poll::Ready(()) => continue, @@ -1408,6 +1404,9 @@ impl Handle { let net_reporter = net_report::Client::new(Some(port_mapper.clone()), dns_resolver.clone())?; + let pconn4_sock = pconn4.as_socket(); + let pconn6_sock = pconn6.as_ref().map(|p| p.as_socket()); + let (actor_sender, actor_receiver) = mpsc::channel(256); let (relay_actor_sender, relay_actor_receiver) = mpsc::channel(256); let (udp_disco_sender, mut udp_disco_receiver) = mpsc::channel(256); @@ -1431,9 +1430,9 @@ impl Handle { ipv6_reported: Arc::new(AtomicBool::new(false)), relay_map, my_relay: Default::default(), - pconn4: pconn4.clone(), - pconn6: pconn6.clone(), net_reporter: net_reporter.addr(), + pconn4, + pconn6, disco_secrets: DiscoSecrets::default(), node_map, relay_actor_sender: relay_actor_sender.clone(), @@ -1481,8 +1480,8 @@ impl Handle { periodic_re_stun_timer: new_re_stun_timer(false), net_info_last: None, port_mapper, - pconn4, - pconn6, + pconn4: pconn4_sock, + pconn6: pconn6_sock, no_v4_send: false, net_reporter, network_monitor, @@ -1720,8 +1719,8 @@ struct Actor { net_info_last: Option, // The underlying UDP sockets used to send/rcv packets. - pconn4: UdpConn, - pconn6: Option, + pconn4: Arc, + pconn6: Option>, /// The NAT-PMP/PCP/UPnP prober/client, for requesting port mappings from NAT devices. port_mapper: portmapper::Client, @@ -1893,14 +1892,6 @@ impl Actor { self.port_mapper.deactivate(); self.relay_actor_cancel_token.cancel(); - // Ignore errors from pconnN - // They will frequently have been closed already by a call to connBind.Close. - debug!("stopping connections"); - if let Some(ref conn) = self.pconn6 { - conn.close().await.ok(); - } - self.pconn4.close().await.ok(); - debug!("shutdown complete"); return true; } @@ -2206,8 +2197,8 @@ impl Actor { } let relay_map = self.msock.relay_map.clone(); - let pconn4 = Some(self.pconn4.as_socket()); - let pconn6 = self.pconn6.as_ref().map(|p| p.as_socket()); + let pconn4 = Some(self.pconn4.clone()); + let pconn6 = self.pconn6.clone(); debug!("requesting net_report report"); match self diff --git a/iroh-net/src/magicsock/udp_conn.rs b/iroh-net/src/magicsock/udp_conn.rs index 5470c44f56..0431943791 100644 --- a/iroh-net/src/magicsock/udp_conn.rs +++ b/iroh-net/src/magicsock/udp_conn.rs @@ -15,10 +15,10 @@ use tokio::io::Interest; use tracing::{debug, trace}; /// A UDP socket implementing Quinn's [`AsyncUdpSocket`]. -#[derive(Clone, Debug)] +#[derive(Debug)] pub struct UdpConn { io: Arc, - inner: Arc, + inner: quinn_udp::UdpSocketState, } impl UdpConn { @@ -28,32 +28,42 @@ impl UdpConn { pub(super) fn bind(addr: SocketAddr) -> anyhow::Result { let sock = bind(addr)?; - let state = sock.with_socket(move |socket| { + let state = sock.with_socket(|socket| { quinn_udp::UdpSocketState::new(quinn_udp::UdpSockRef::from(socket)) })?; Ok(Self { io: Arc::new(sock), - inner: Arc::new(state), + inner: state, }) } + pub(super) fn rebind(&mut self) -> anyhow::Result<()> { + // Rebind underlying socket + self.io.rebind()?; + + // update socket state + let new_state = self.io.with_socket(|socket| { + quinn_udp::UdpSocketState::new(quinn_udp::UdpSockRef::from(socket)) + })?; + self.inner = new_state; + Ok(()) + } + pub fn port(&self) -> u16 { self.local_addr().map(|p| p.port()).unwrap_or_default() } - #[allow(clippy::unused_async)] - pub async fn close(&self) -> Result<(), io::Error> { - // Nothing to do atm - Ok(()) + pub(super) fn create_io_poller(&self) -> Pin> { + Box::pin(IoPoller { + io: self.io.clone(), + }) } } impl AsyncUdpSocket for UdpConn { fn create_io_poller(self: Arc) -> Pin> { - Box::pin(IoPoller { - io: self.io.clone(), - }) + (&*self).create_io_poller() } fn try_send(&self, transmit: &Transmit<'_>) -> io::Result<()> { From 2d3451709ac3284687edb63b4101a63ed05126b1 Mon Sep 17 00:00:00 2001 From: dignifiedquire Date: Mon, 18 Nov 2024 16:38:31 +0100 Subject: [PATCH 07/38] locking is hard --- net-tools/netwatch/src/udp.rs | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/net-tools/netwatch/src/udp.rs b/net-tools/netwatch/src/udp.rs index ec5b37c224..3c7cd5465e 100644 --- a/net-tools/netwatch/src/udp.rs +++ b/net-tools/netwatch/src/udp.rs @@ -79,10 +79,9 @@ impl UdpSocket { /// Rebind the underlying socket. pub fn rebind(&self) -> Result<()> { // Remove old socket + let mut guard = self.socket.write().unwrap(); { - let mut guard = self.socket.write().unwrap(); let socket = guard.take().expect("not yet dropped"); - drop(socket); } @@ -90,7 +89,7 @@ impl UdpSocket { let new_socket = inner_bind(self.addr)?; // Insert new socket - self.socket.write().unwrap().replace(new_socket); + guard.replace(new_socket); // Clear errors self.is_broken @@ -102,10 +101,8 @@ impl UdpSocket { fn bind_raw(addr: impl Into) -> Result { let mut addr = addr.into(); let socket = inner_bind(addr)?; - if addr.port() == 0 { - // update to use selected port - addr.set_port(socket.local_addr()?.port()); - } + // update to use selected port + addr.set_port(socket.local_addr()?.port()); Ok(UdpSocket { socket: Arc::new(RwLock::new(Some(socket))), From 082e94849ef1bdd11855b52921cc2c5370e221f2 Mon Sep 17 00:00:00 2001 From: dignifiedquire Date: Tue, 19 Nov 2024 13:47:36 +0100 Subject: [PATCH 08/38] start handling read errors --- net-tools/netwatch/src/udp.rs | 117 ++++++++++++++++++++++------------ 1 file changed, 78 insertions(+), 39 deletions(-) diff --git a/net-tools/netwatch/src/udp.rs b/net-tools/netwatch/src/udp.rs index 3c7cd5465e..6515333c6e 100644 --- a/net-tools/netwatch/src/udp.rs +++ b/net-tools/netwatch/src/udp.rs @@ -3,7 +3,7 @@ use std::{ io::ErrorKind, net::SocketAddr, pin::Pin, - sync::{atomic::AtomicBool, Arc, RwLock}, + sync::{atomic::AtomicBool, RwLock}, task::Poll, }; @@ -17,7 +17,7 @@ use super::IpFamily; #[derive(Debug)] pub struct UdpSocket { // TODO: can we drop the Arc and use lifetimes in the futures? - socket: Arc>>, + socket: RwLock>, /// The addr we are binding to. addr: SocketAddr, /// Set to true, when an error occured, that means we need to rebind the socket. @@ -105,7 +105,7 @@ impl UdpSocket { addr.set_port(socket.local_addr()?.port()); Ok(UdpSocket { - socket: Arc::new(RwLock::new(Some(socket))), + socket: RwLock::new(Some(socket)), addr, is_broken: AtomicBool::new(false), }) @@ -131,40 +131,38 @@ impl UdpSocket { socket.try_io(interest, f) } - pub fn writable(&self) -> WritableFut { - WritableFut { - socket: self.socket.clone(), - } + pub fn writable(&self) -> WritableFut<'_> { + WritableFut { socket: self } } /// TODO - pub fn recv<'a>(&self, buffer: &'a mut [u8]) -> RecvFut<'a> { + pub fn recv<'a, 'b>(&'b self, buffer: &'a mut [u8]) -> RecvFut<'a, 'b> { RecvFut { - socket: self.socket.clone(), + socket: self, buffer, } } /// TODO - pub fn recv_from<'a>(&self, buffer: &'a mut [u8]) -> RecvFromFut<'a> { + pub fn recv_from<'a, 'b>(&'b self, buffer: &'a mut [u8]) -> RecvFromFut<'a, 'b> { RecvFromFut { - socket: self.socket.clone(), + socket: self, buffer, } } /// TODO - pub fn send<'a>(&self, buffer: &'a [u8]) -> SendFut<'a> { + pub fn send<'a, 'b>(&'b self, buffer: &'a [u8]) -> SendFut<'a, 'b> { SendFut { - socket: self.socket.clone(), + socket: self, buffer, } } /// TODO - pub fn send_to<'a>(&self, buffer: &'a [u8], to: SocketAddr) -> SendToFut<'a> { + pub fn send_to<'a, 'b>(&'b self, buffer: &'a [u8], to: SocketAddr) -> SendToFut<'a, 'b> { SendToFut { - socket: self.socket.clone(), + socket: self, buffer, to, } @@ -208,21 +206,46 @@ impl UdpSocket { warn!("failed to close socket: {:?}", err); } } + + /// Handle potential read errors, updating internal state. + /// + /// Returns `Some(error)` if the error is fatal otherwise `None. + fn handle_read_error(&self, error: std::io::Error) -> Option { + let kind = error.kind(); + match kind { + std::io::ErrorKind::BrokenPipe => { + // This indicates the underlying socket is broken, and we should attempt to rebind it + self.mark_broken(); + match self.rebind() { + Ok(()) => None, + Err(err) => { + // Return original error, if we failed to rebind + warn!( + "failed to rebind socket, after error: {:?}: {:?}", + error, err + ); + Some(error) + } + } + } + _ => Some(error), + } + } } /// Receive future #[derive(Debug)] -pub struct RecvFut<'a> { - socket: Arc>>, +pub struct RecvFut<'a, 'b> { + socket: &'b UdpSocket, buffer: &'a mut [u8], } -impl<'a> Future for RecvFut<'a> { +impl Future for RecvFut<'_, '_> { type Output = std::io::Result; fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { let Self { socket, buffer } = &mut *self; - let guard = socket.read().unwrap(); + let guard = socket.socket.read().unwrap(); let socket = guard.as_ref().expect("missing socket"); loop { @@ -246,17 +269,17 @@ impl<'a> Future for RecvFut<'a> { /// Receive future #[derive(Debug)] -pub struct RecvFromFut<'a> { - socket: Arc>>, +pub struct RecvFromFut<'a, 'b> { + socket: &'b UdpSocket, buffer: &'a mut [u8], } -impl<'a> Future for RecvFromFut<'a> { +impl Future for RecvFromFut<'_, '_> { type Output = std::io::Result<(usize, SocketAddr)>; fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { let Self { socket, buffer } = &mut *self; - let guard = socket.read().unwrap(); + let guard = socket.socket.read().unwrap(); let socket = guard.as_ref().expect("missing socket"); loop { @@ -280,15 +303,15 @@ impl<'a> Future for RecvFromFut<'a> { /// Writable future #[derive(Debug)] -pub struct WritableFut { - socket: Arc>>, +pub struct WritableFut<'a> { + socket: &'a UdpSocket, } -impl Future for WritableFut { +impl Future for WritableFut<'_> { type Output = std::io::Result<()>; fn poll(self: Pin<&mut Self>, c: &mut std::task::Context<'_>) -> Poll { - let guard = self.socket.read().unwrap(); + let guard = self.socket.socket.read().unwrap(); let socket = guard.as_ref().expect("missing socket"); socket.poll_send_ready(c) @@ -297,16 +320,16 @@ impl Future for WritableFut { /// Send future #[derive(Debug)] -pub struct SendFut<'a> { - socket: Arc>>, +pub struct SendFut<'a, 'b> { + socket: &'b UdpSocket, buffer: &'a [u8], } -impl<'a> Future for SendFut<'a> { +impl Future for SendFut<'_, '_> { type Output = std::io::Result; fn poll(self: Pin<&mut Self>, c: &mut std::task::Context<'_>) -> Poll { - let guard = self.socket.read().unwrap(); + let guard = self.socket.socket.read().unwrap(); let socket = guard.as_ref().expect("missing socket"); loop { @@ -318,11 +341,19 @@ impl<'a> Future for SendFut<'a> { if err.kind() == ErrorKind::WouldBlock { continue; } - return Poll::Ready(Err(err)); + if let Some(err) = self.socket.handle_read_error(err) { + return Poll::Ready(Err(err)); + } + continue; } return Poll::Ready(res); } - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + Poll::Ready(Err(err)) => { + if let Some(err) = self.socket.handle_read_error(err) { + return Poll::Ready(Err(err)); + } + continue; + } } } } @@ -330,17 +361,17 @@ impl<'a> Future for SendFut<'a> { /// Send future #[derive(Debug)] -pub struct SendToFut<'a> { - socket: Arc>>, +pub struct SendToFut<'a, 'b> { + socket: &'b UdpSocket, buffer: &'a [u8], to: SocketAddr, } -impl<'a> Future for SendToFut<'a> { +impl Future for SendToFut<'_, '_> { type Output = std::io::Result; fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { - let guard = self.socket.read().unwrap(); + let guard = self.socket.socket.read().unwrap(); let socket = guard.as_ref().expect("missing socket"); loop { @@ -352,11 +383,19 @@ impl<'a> Future for SendToFut<'a> { if err.kind() == ErrorKind::WouldBlock { continue; } - return Poll::Ready(Err(err)); + if let Some(err) = self.socket.handle_read_error(err) { + return Poll::Ready(Err(err)); + } + continue; } return Poll::Ready(res); } - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + Poll::Ready(Err(err)) => { + if let Some(err) = self.socket.handle_read_error(err) { + return Poll::Ready(Err(err)); + } + continue; + } } } } From e90d1826fd206c698bdbfaacbc0d0caf9c715d74 Mon Sep 17 00:00:00 2001 From: dignifiedquire Date: Tue, 19 Nov 2024 13:50:28 +0100 Subject: [PATCH 09/38] fix: hold locks for less time --- net-tools/netwatch/src/udp.rs | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/net-tools/netwatch/src/udp.rs b/net-tools/netwatch/src/udp.rs index 6515333c6e..fbb178160c 100644 --- a/net-tools/netwatch/src/udp.rs +++ b/net-tools/netwatch/src/udp.rs @@ -245,10 +245,11 @@ impl Future for RecvFut<'_, '_> { fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { let Self { socket, buffer } = &mut *self; - let guard = socket.socket.read().unwrap(); - let socket = guard.as_ref().expect("missing socket"); loop { + let guard = socket.socket.read().unwrap(); + let socket = guard.as_ref().expect("missing socket"); + match socket.poll_recv_ready(cx) { Poll::Pending => return Poll::Pending, Poll::Ready(Ok(())) => { @@ -279,10 +280,11 @@ impl Future for RecvFromFut<'_, '_> { fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { let Self { socket, buffer } = &mut *self; - let guard = socket.socket.read().unwrap(); - let socket = guard.as_ref().expect("missing socket"); loop { + let guard = socket.socket.read().unwrap(); + let socket = guard.as_ref().expect("missing socket"); + match socket.poll_recv_ready(cx) { Poll::Pending => return Poll::Pending, Poll::Ready(Ok(())) => { @@ -329,10 +331,10 @@ impl Future for SendFut<'_, '_> { type Output = std::io::Result; fn poll(self: Pin<&mut Self>, c: &mut std::task::Context<'_>) -> Poll { - let guard = self.socket.socket.read().unwrap(); - let socket = guard.as_ref().expect("missing socket"); - loop { + let guard = self.socket.socket.read().unwrap(); + let socket = guard.as_ref().expect("missing socket"); + match socket.poll_send_ready(c) { Poll::Pending => return Poll::Pending, Poll::Ready(Ok(())) => { @@ -341,6 +343,7 @@ impl Future for SendFut<'_, '_> { if err.kind() == ErrorKind::WouldBlock { continue; } + drop(guard); // make sure we are not holding a lock before handling the error if let Some(err) = self.socket.handle_read_error(err) { return Poll::Ready(Err(err)); } @@ -349,6 +352,7 @@ impl Future for SendFut<'_, '_> { return Poll::Ready(res); } Poll::Ready(Err(err)) => { + drop(guard); // make sure we are not holding a lock before handling the error if let Some(err) = self.socket.handle_read_error(err) { return Poll::Ready(Err(err)); } @@ -371,10 +375,10 @@ impl Future for SendToFut<'_, '_> { type Output = std::io::Result; fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { - let guard = self.socket.socket.read().unwrap(); - let socket = guard.as_ref().expect("missing socket"); - loop { + let guard = self.socket.socket.read().unwrap(); + let socket = guard.as_ref().expect("missing socket"); + match socket.poll_send_ready(cx) { Poll::Pending => return Poll::Pending, Poll::Ready(Ok(())) => { @@ -383,6 +387,7 @@ impl Future for SendToFut<'_, '_> { if err.kind() == ErrorKind::WouldBlock { continue; } + drop(guard); // make sure we are not holding a lock before handling the error if let Some(err) = self.socket.handle_read_error(err) { return Poll::Ready(Err(err)); } @@ -391,6 +396,7 @@ impl Future for SendToFut<'_, '_> { return Poll::Ready(res); } Poll::Ready(Err(err)) => { + drop(guard); // make sure we are not holding a lock before handling the error if let Some(err) = self.socket.handle_read_error(err) { return Poll::Ready(Err(err)); } From c81428406f6f11e8cbb4ee550c1687adf51fff63 Mon Sep 17 00:00:00 2001 From: dignifiedquire Date: Tue, 19 Nov 2024 14:04:00 +0100 Subject: [PATCH 10/38] move rebind call --- net-tools/netwatch/src/udp.rs | 69 ++++++++++++++++++++++++++--------- 1 file changed, 52 insertions(+), 17 deletions(-) diff --git a/net-tools/netwatch/src/udp.rs b/net-tools/netwatch/src/udp.rs index fbb178160c..445a948b56 100644 --- a/net-tools/netwatch/src/udp.rs +++ b/net-tools/netwatch/src/udp.rs @@ -211,22 +211,11 @@ impl UdpSocket { /// /// Returns `Some(error)` if the error is fatal otherwise `None. fn handle_read_error(&self, error: std::io::Error) -> Option { - let kind = error.kind(); - match kind { + match error.kind() { std::io::ErrorKind::BrokenPipe => { // This indicates the underlying socket is broken, and we should attempt to rebind it self.mark_broken(); - match self.rebind() { - Ok(()) => None, - Err(err) => { - // Return original error, if we failed to rebind - warn!( - "failed to rebind socket, after error: {:?}: {:?}", - error, err - ); - Some(error) - } - } + None } _ => Some(error), } @@ -343,7 +332,6 @@ impl Future for SendFut<'_, '_> { if err.kind() == ErrorKind::WouldBlock { continue; } - drop(guard); // make sure we are not holding a lock before handling the error if let Some(err) = self.socket.handle_read_error(err) { return Poll::Ready(Err(err)); } @@ -352,7 +340,6 @@ impl Future for SendFut<'_, '_> { return Poll::Ready(res); } Poll::Ready(Err(err)) => { - drop(guard); // make sure we are not holding a lock before handling the error if let Some(err) = self.socket.handle_read_error(err) { return Poll::Ready(Err(err)); } @@ -376,6 +363,22 @@ impl Future for SendToFut<'_, '_> { fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { loop { + // check if the socket needs a rebind + if self.socket.is_broken() { + match self.socket.rebind() { + Ok(()) => { + // all good + } + Err(err) => { + warn!("failed to rebind socket: {:?}", err); + // TODO: improve error + let err = + std::io::Error::new(std::io::ErrorKind::BrokenPipe, err.to_string()); + return Poll::Ready(Err(err)); + } + } + } + let guard = self.socket.socket.read().unwrap(); let socket = guard.as_ref().expect("missing socket"); @@ -387,7 +390,7 @@ impl Future for SendToFut<'_, '_> { if err.kind() == ErrorKind::WouldBlock { continue; } - drop(guard); // make sure we are not holding a lock before handling the error + if let Some(err) = self.socket.handle_read_error(err) { return Poll::Ready(Err(err)); } @@ -396,7 +399,6 @@ impl Future for SendToFut<'_, '_> { return Poll::Ready(res); } Poll::Ready(Err(err)) => { - drop(guard); // make sure we are not holding a lock before handling the error if let Some(err) = self.socket.handle_read_error(err) { return Poll::Ready(Err(err)); } @@ -535,4 +537,37 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_udp_mark_broken() -> anyhow::Result<()> { + let socket_a = UdpSocket::bind_local(IpFamily::V4, 0)?; + let addr_a = socket_a.local_addr()?; + println!("socket bound to {:?}", addr_a); + + let socket_b = UdpSocket::bind_local(IpFamily::V4, 0)?; + let addr_b = socket_b.local_addr()?; + println!("socket bound to {:?}", addr_b); + + let handle = tokio::task::spawn(async move { + let mut buffer = [0u8; 16]; + for _ in 0..2 { + match socket_b.recv_from(&mut buffer).await { + Ok((count, addr)) => { + println!("got {:?} from {:?}", &buffer[..count], addr); + } + Err(err) => { + eprintln!("error recv: {:?}", err); + } + } + } + }); + socket_a.send_to(&[0][..], addr_b).await?; + socket_a.mark_broken(); + assert!(socket_a.is_broken()); + socket_a.send_to(&[0][..], addr_b).await?; + assert!(!socket_a.is_broken()); + + handle.await?; + Ok(()) + } } From 9441b27902937a389f46050d8da4ee75576c3589 Mon Sep 17 00:00:00 2001 From: dignifiedquire Date: Tue, 19 Nov 2024 14:23:02 +0100 Subject: [PATCH 11/38] handle read and write erros --- net-tools/netwatch/src/udp.rs | 136 +++++++++++++++++++++++++++++----- 1 file changed, 119 insertions(+), 17 deletions(-) diff --git a/net-tools/netwatch/src/udp.rs b/net-tools/netwatch/src/udp.rs index 445a948b56..a26f115aa2 100644 --- a/net-tools/netwatch/src/udp.rs +++ b/net-tools/netwatch/src/udp.rs @@ -211,6 +211,20 @@ impl UdpSocket { /// /// Returns `Some(error)` if the error is fatal otherwise `None. fn handle_read_error(&self, error: std::io::Error) -> Option { + match error.kind() { + std::io::ErrorKind::NotConnected => { + // This indicates the underlying socket is broken, and we should attempt to rebind it + self.mark_broken(); + None + } + _ => Some(error), + } + } + + /// Handle potential write errors, updating internal state. + /// + /// Returns `Some(error)` if the error is fatal otherwise `None. + fn handle_write_error(&self, error: std::io::Error) -> Option { match error.kind() { std::io::ErrorKind::BrokenPipe => { // This indicates the underlying socket is broken, and we should attempt to rebind it @@ -236,22 +250,46 @@ impl Future for RecvFut<'_, '_> { let Self { socket, buffer } = &mut *self; loop { + // check if the socket needs a rebind + if socket.is_broken() { + match socket.rebind() { + Ok(()) => { + // all good + } + Err(err) => { + warn!("failed to rebind socket: {:?}", err); + // TODO: improve error + let err = + std::io::Error::new(std::io::ErrorKind::BrokenPipe, err.to_string()); + return Poll::Ready(Err(err)); + } + } + } + let guard = socket.socket.read().unwrap(); - let socket = guard.as_ref().expect("missing socket"); + let inner_socket = guard.as_ref().expect("missing socket"); - match socket.poll_recv_ready(cx) { + match inner_socket.poll_recv_ready(cx) { Poll::Pending => return Poll::Pending, Poll::Ready(Ok(())) => { - let res = socket.try_recv(buffer); + let res = inner_socket.try_recv(buffer); if let Err(err) = res { if err.kind() == ErrorKind::WouldBlock { continue; } - return Poll::Ready(Err(err)); + if let Some(err) = socket.handle_read_error(err) { + return Poll::Ready(Err(err)); + } + continue; } return Poll::Ready(res); } - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + Poll::Ready(Err(err)) => { + if let Some(err) = socket.handle_read_error(err) { + return Poll::Ready(Err(err)); + } + continue; + } } } } @@ -271,22 +309,45 @@ impl Future for RecvFromFut<'_, '_> { let Self { socket, buffer } = &mut *self; loop { + // check if the socket needs a rebind + if socket.is_broken() { + match socket.rebind() { + Ok(()) => { + // all good + } + Err(err) => { + warn!("failed to rebind socket: {:?}", err); + // TODO: improve error + let err = + std::io::Error::new(std::io::ErrorKind::BrokenPipe, err.to_string()); + return Poll::Ready(Err(err)); + } + } + } let guard = socket.socket.read().unwrap(); - let socket = guard.as_ref().expect("missing socket"); + let inner_socket = guard.as_ref().expect("missing socket"); - match socket.poll_recv_ready(cx) { + match inner_socket.poll_recv_ready(cx) { Poll::Pending => return Poll::Pending, Poll::Ready(Ok(())) => { - let res = socket.try_recv_from(buffer); + let res = inner_socket.try_recv_from(buffer); if let Err(err) = res { if err.kind() == ErrorKind::WouldBlock { continue; } - return Poll::Ready(Err(err)); + if let Some(err) = socket.handle_read_error(err) { + return Poll::Ready(Err(err)); + } + continue; } return Poll::Ready(res); } - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + Poll::Ready(Err(err)) => { + if let Some(err) = socket.handle_read_error(err) { + return Poll::Ready(Err(err)); + } + continue; + } } } } @@ -302,10 +363,36 @@ impl Future for WritableFut<'_> { type Output = std::io::Result<()>; fn poll(self: Pin<&mut Self>, c: &mut std::task::Context<'_>) -> Poll { - let guard = self.socket.socket.read().unwrap(); - let socket = guard.as_ref().expect("missing socket"); + loop { + // check if the socket needs a rebind + if self.socket.is_broken() { + match self.socket.rebind() { + Ok(()) => { + // all good + } + Err(err) => { + warn!("failed to rebind socket: {:?}", err); + // TODO: improve error + let err = + std::io::Error::new(std::io::ErrorKind::BrokenPipe, err.to_string()); + return Poll::Ready(Err(err)); + } + } + } + let guard = self.socket.socket.read().unwrap(); + let socket = guard.as_ref().expect("missing socket"); - socket.poll_send_ready(c) + match socket.poll_send_ready(c) { + Poll::Pending => return Poll::Pending, + Poll::Ready(Ok(())) => return Poll::Ready(Ok(())), + Poll::Ready(Err(err)) => { + if let Some(err) = self.socket.handle_write_error(err) { + return Poll::Ready(Err(err)); + } + continue; + } + } + } } } @@ -321,6 +408,21 @@ impl Future for SendFut<'_, '_> { fn poll(self: Pin<&mut Self>, c: &mut std::task::Context<'_>) -> Poll { loop { + // check if the socket needs a rebind + if self.socket.is_broken() { + match self.socket.rebind() { + Ok(()) => { + // all good + } + Err(err) => { + warn!("failed to rebind socket: {:?}", err); + // TODO: improve error + let err = + std::io::Error::new(std::io::ErrorKind::BrokenPipe, err.to_string()); + return Poll::Ready(Err(err)); + } + } + } let guard = self.socket.socket.read().unwrap(); let socket = guard.as_ref().expect("missing socket"); @@ -332,7 +434,7 @@ impl Future for SendFut<'_, '_> { if err.kind() == ErrorKind::WouldBlock { continue; } - if let Some(err) = self.socket.handle_read_error(err) { + if let Some(err) = self.socket.handle_write_error(err) { return Poll::Ready(Err(err)); } continue; @@ -340,7 +442,7 @@ impl Future for SendFut<'_, '_> { return Poll::Ready(res); } Poll::Ready(Err(err)) => { - if let Some(err) = self.socket.handle_read_error(err) { + if let Some(err) = self.socket.handle_write_error(err) { return Poll::Ready(Err(err)); } continue; @@ -391,7 +493,7 @@ impl Future for SendToFut<'_, '_> { continue; } - if let Some(err) = self.socket.handle_read_error(err) { + if let Some(err) = self.socket.handle_write_error(err) { return Poll::Ready(Err(err)); } continue; @@ -399,7 +501,7 @@ impl Future for SendToFut<'_, '_> { return Poll::Ready(res); } Poll::Ready(Err(err)) => { - if let Some(err) = self.socket.handle_read_error(err) { + if let Some(err) = self.socket.handle_write_error(err) { return Poll::Ready(Err(err)); } continue; From c09f6dd433835e27a3f415f7d370b85707f9d2b9 Mon Sep 17 00:00:00 2001 From: dignifiedquire Date: Tue, 19 Nov 2024 14:45:50 +0100 Subject: [PATCH 12/38] handle errors in udp_conn as well --- iroh-net/src/magicsock/udp_conn.rs | 121 +++++++++++++++++++++-------- net-tools/netwatch/src/udp.rs | 71 +++++++++-------- 2 files changed, 128 insertions(+), 64 deletions(-) diff --git a/iroh-net/src/magicsock/udp_conn.rs b/iroh-net/src/magicsock/udp_conn.rs index 0431943791..6fecbb7a5a 100644 --- a/iroh-net/src/magicsock/udp_conn.rs +++ b/iroh-net/src/magicsock/udp_conn.rs @@ -3,8 +3,8 @@ use std::{ io, net::SocketAddr, pin::Pin, - sync::Arc, - task::{ready, Context, Poll}, + sync::{Arc, RwLock}, + task::{Context, Poll}, }; use anyhow::{bail, Context as _}; @@ -12,13 +12,13 @@ use netwatch::UdpSocket; use quinn::AsyncUdpSocket; use quinn_udp::{Transmit, UdpSockRef}; use tokio::io::Interest; -use tracing::{debug, trace}; +use tracing::{debug, trace, warn}; /// A UDP socket implementing Quinn's [`AsyncUdpSocket`]. #[derive(Debug)] pub struct UdpConn { io: Arc, - inner: quinn_udp::UdpSocketState, + inner: RwLock, } impl UdpConn { @@ -34,11 +34,11 @@ impl UdpConn { Ok(Self { io: Arc::new(sock), - inner: state, + inner: RwLock::new(state), }) } - pub(super) fn rebind(&mut self) -> anyhow::Result<()> { + pub(super) fn rebind(&self) -> anyhow::Result<()> { // Rebind underlying socket self.io.rebind()?; @@ -46,7 +46,7 @@ impl UdpConn { let new_state = self.io.with_socket(|socket| { quinn_udp::UdpSocketState::new(quinn_udp::UdpSockRef::from(socket)) })?; - self.inner = new_state; + *self.inner.write().unwrap() = new_state; Ok(()) } @@ -67,12 +67,38 @@ impl AsyncUdpSocket for UdpConn { } fn try_send(&self, transmit: &Transmit<'_>) -> io::Result<()> { - self.io.try_io(Interest::WRITABLE, || { - self.io.with_socket(|io| { - let sock_ref = UdpSockRef::from(io); - self.inner.send(sock_ref, transmit) - }) - }) + loop { + if self.io.is_broken() { + match self.rebind() { + Ok(()) => { + // all good + } + Err(err) => { + warn!("failed to rebind socket: {:?}", err); + // TODO: improve error + let err = + std::io::Error::new(std::io::ErrorKind::BrokenPipe, err.to_string()); + return Err(err); + } + } + } + + let res = self.io.try_io(Interest::WRITABLE, || { + self.io.with_socket(|io| { + let sock_ref = UdpSockRef::from(io); + self.inner.read().unwrap().send(sock_ref, transmit) + }) + }); + match res { + Ok(()) => return Ok(()), + Err(err) => match self.io.handle_write_error(err) { + Some(err) => return Err(err), + None => { + continue; + } + }, + } + } } fn poll_recv( @@ -82,22 +108,55 @@ impl AsyncUdpSocket for UdpConn { meta: &mut [quinn_udp::RecvMeta], ) -> Poll> { loop { - ready!(self.io.with_socket(|io| io.poll_recv_ready(cx)))?; - if let Ok(res) = self.io.try_io(Interest::READABLE, || { - self.io - .with_socket(|io| self.inner.recv(io.into(), bufs, meta)) - }) { - for meta in meta.iter().take(res) { - trace!( - src = %meta.addr, - len = meta.len, - count = meta.len / meta.stride, - dst = %meta.dst_ip.map(|x| x.to_string()).unwrap_or_default(), - "UDP recv" - ); + if self.io.is_broken() { + match self.rebind() { + Ok(()) => { + // all good + } + Err(err) => { + warn!("failed to rebind socket: {:?}", err); + // TODO: improve error + let err = + std::io::Error::new(std::io::ErrorKind::NotConnected, err.to_string()); + return Poll::Ready(Err(err)); + } } + } + + match self.io.with_socket(|io| io.poll_recv_ready(cx)) { + Poll::Pending => return Poll::Pending, + Poll::Ready(Ok(())) => {} + Poll::Ready(Err(err)) => match self.io.handle_read_error(err) { + Some(err) => return Poll::Ready(Err(err)), + None => { + continue; + } + }, + } - return Poll::Ready(Ok(res)); + let res = self.io.try_io(Interest::READABLE, || { + self.io + .with_socket(|io| self.inner.read().unwrap().recv(io.into(), bufs, meta)) + }); + match res { + Ok(count) => { + for meta in meta.iter().take(count) { + trace!( + src = %meta.addr, + len = meta.len, + count = meta.len / meta.stride, + dst = %meta.dst_ip.map(|x| x.to_string()).unwrap_or_default(), + "UDP recv" + ); + } + return Poll::Ready(Ok(count)); + } + Err(err) => match self.io.handle_read_error(err) { + Some(err) => return Poll::Ready(Err(err)), + None => { + continue; + } + }, } } } @@ -107,15 +166,15 @@ impl AsyncUdpSocket for UdpConn { } fn may_fragment(&self) -> bool { - self.inner.may_fragment() + self.inner.read().unwrap().may_fragment() } fn max_transmit_segments(&self) -> usize { - self.inner.max_gso_segments() + self.inner.read().unwrap().max_gso_segments() } fn max_receive_segments(&self) -> usize { - self.inner.gro_segments() + self.inner.read().unwrap().gro_segments() } } @@ -164,7 +223,7 @@ struct IoPoller { impl quinn::UdpPoller for IoPoller { fn poll_writable(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - self.io.with_socket(|io| io.poll_send_ready(cx)) + self.io.poll_writable(cx) } } diff --git a/net-tools/netwatch/src/udp.rs b/net-tools/netwatch/src/udp.rs index a26f115aa2..58c44c3b81 100644 --- a/net-tools/netwatch/src/udp.rs +++ b/net-tools/netwatch/src/udp.rs @@ -210,7 +210,7 @@ impl UdpSocket { /// Handle potential read errors, updating internal state. /// /// Returns `Some(error)` if the error is fatal otherwise `None. - fn handle_read_error(&self, error: std::io::Error) -> Option { + pub fn handle_read_error(&self, error: std::io::Error) -> Option { match error.kind() { std::io::ErrorKind::NotConnected => { // This indicates the underlying socket is broken, and we should attempt to rebind it @@ -224,7 +224,7 @@ impl UdpSocket { /// Handle potential write errors, updating internal state. /// /// Returns `Some(error)` if the error is fatal otherwise `None. - fn handle_write_error(&self, error: std::io::Error) -> Option { + pub fn handle_write_error(&self, error: std::io::Error) -> Option { match error.kind() { std::io::ErrorKind::BrokenPipe => { // This indicates the underlying socket is broken, and we should attempt to rebind it @@ -234,6 +234,40 @@ impl UdpSocket { _ => Some(error), } } + + /// Poll for writable + pub fn poll_writable(&self, cx: &mut std::task::Context<'_>) -> Poll> { + loop { + // check if the socket needs a rebind + if self.is_broken() { + match self.rebind() { + Ok(()) => { + // all good + } + Err(err) => { + warn!("failed to rebind socket: {:?}", err); + // TODO: improve error + let err = + std::io::Error::new(std::io::ErrorKind::BrokenPipe, err.to_string()); + return Poll::Ready(Err(err)); + } + } + } + let guard = self.socket.read().unwrap(); + let socket = guard.as_ref().expect("missing socket"); + + match socket.poll_send_ready(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(Ok(())) => return Poll::Ready(Ok(())), + Poll::Ready(Err(err)) => { + if let Some(err) = self.handle_write_error(err) { + return Poll::Ready(Err(err)); + } + continue; + } + } + } + } } /// Receive future @@ -362,37 +396,8 @@ pub struct WritableFut<'a> { impl Future for WritableFut<'_> { type Output = std::io::Result<()>; - fn poll(self: Pin<&mut Self>, c: &mut std::task::Context<'_>) -> Poll { - loop { - // check if the socket needs a rebind - if self.socket.is_broken() { - match self.socket.rebind() { - Ok(()) => { - // all good - } - Err(err) => { - warn!("failed to rebind socket: {:?}", err); - // TODO: improve error - let err = - std::io::Error::new(std::io::ErrorKind::BrokenPipe, err.to_string()); - return Poll::Ready(Err(err)); - } - } - } - let guard = self.socket.socket.read().unwrap(); - let socket = guard.as_ref().expect("missing socket"); - - match socket.poll_send_ready(c) { - Poll::Pending => return Poll::Pending, - Poll::Ready(Ok(())) => return Poll::Ready(Ok(())), - Poll::Ready(Err(err)) => { - if let Some(err) = self.socket.handle_write_error(err) { - return Poll::Ready(Err(err)); - } - continue; - } - } - } + fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { + self.socket.poll_writable(cx) } } From d13879bc9c5141fd1418eab61c3fa658056582fa Mon Sep 17 00:00:00 2001 From: dignifiedquire Date: Tue, 19 Nov 2024 14:57:59 +0100 Subject: [PATCH 13/38] fix: handle spurious wakeups in udp_conn --- iroh-net/src/magicsock/udp_conn.rs | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/iroh-net/src/magicsock/udp_conn.rs b/iroh-net/src/magicsock/udp_conn.rs index 6fecbb7a5a..d3b59e1303 100644 --- a/iroh-net/src/magicsock/udp_conn.rs +++ b/iroh-net/src/magicsock/udp_conn.rs @@ -151,12 +151,18 @@ impl AsyncUdpSocket for UdpConn { } return Poll::Ready(Ok(count)); } - Err(err) => match self.io.handle_read_error(err) { - Some(err) => return Poll::Ready(Err(err)), - None => { + Err(err) => { + // ignore spurious wakeups + if err.kind() == std::io::ErrorKind::WouldBlock { continue; } - }, + match self.io.handle_read_error(err) { + Some(err) => return Poll::Ready(Err(err)), + None => { + continue; + } + } + } } } } From d698980099a708972e6f184d098cca5cd2e42342 Mon Sep 17 00:00:00 2001 From: dignifiedquire Date: Tue, 19 Nov 2024 14:58:52 +0100 Subject: [PATCH 14/38] and another one --- iroh-net/src/magicsock/udp_conn.rs | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/iroh-net/src/magicsock/udp_conn.rs b/iroh-net/src/magicsock/udp_conn.rs index d3b59e1303..0feee3905f 100644 --- a/iroh-net/src/magicsock/udp_conn.rs +++ b/iroh-net/src/magicsock/udp_conn.rs @@ -91,12 +91,17 @@ impl AsyncUdpSocket for UdpConn { }); match res { Ok(()) => return Ok(()), - Err(err) => match self.io.handle_write_error(err) { - Some(err) => return Err(err), - None => { + Err(err) => { + if err.kind() == std::io::ErrorKind::WouldBlock { continue; } - }, + match self.io.handle_write_error(err) { + Some(err) => return Err(err), + None => { + continue; + } + } + } } } } From eff40decdda3f7dc1b7c599e3536b389646a43af Mon Sep 17 00:00:00 2001 From: dignifiedquire Date: Tue, 19 Nov 2024 16:54:58 +0100 Subject: [PATCH 15/38] fixup clippy --- iroh-net/src/magicsock/udp_conn.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/iroh-net/src/magicsock/udp_conn.rs b/iroh-net/src/magicsock/udp_conn.rs index 0feee3905f..687be40f8b 100644 --- a/iroh-net/src/magicsock/udp_conn.rs +++ b/iroh-net/src/magicsock/udp_conn.rs @@ -63,7 +63,7 @@ impl UdpConn { impl AsyncUdpSocket for UdpConn { fn create_io_poller(self: Arc) -> Pin> { - (&*self).create_io_poller() + (*self).create_io_poller() } fn try_send(&self, transmit: &Transmit<'_>) -> io::Result<()> { From dce54e227db14bc561efc864ad9d192365c95a91 Mon Sep 17 00:00:00 2001 From: dignifiedquire Date: Tue, 19 Nov 2024 17:45:26 +0100 Subject: [PATCH 16/38] spelling fix --- net-tools/netwatch/src/udp.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/net-tools/netwatch/src/udp.rs b/net-tools/netwatch/src/udp.rs index 58c44c3b81..9f532281eb 100644 --- a/net-tools/netwatch/src/udp.rs +++ b/net-tools/netwatch/src/udp.rs @@ -20,7 +20,7 @@ pub struct UdpSocket { socket: RwLock>, /// The addr we are binding to. addr: SocketAddr, - /// Set to true, when an error occured, that means we need to rebind the socket. + /// Set to true, when an error occurred, that means we need to rebind the socket. is_broken: AtomicBool, } From 6847d2de8d1b0be9badf6e4e820e3cac31dacd37 Mon Sep 17 00:00:00 2001 From: dignifiedquire Date: Wed, 20 Nov 2024 18:45:32 +0100 Subject: [PATCH 17/38] rebind sockets on major network changes --- iroh-net/src/magicsock.rs | 8 ++++++++ net-tools/netwatch/src/udp.rs | 3 ++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/iroh-net/src/magicsock.rs b/iroh-net/src/magicsock.rs index b4ed4e1976..13f7426940 100644 --- a/iroh-net/src/magicsock.rs +++ b/iroh-net/src/magicsock.rs @@ -1860,6 +1860,14 @@ impl Actor { debug!("link change detected: major? {}", is_major); if is_major { + if let Err(err) = self.pconn4.rebind() { + warn!("failed to rebind socket v4: {:?}", err); + } + if let Some(ref pconn6) = self.pconn6 { + if let Err(err) = pconn6.rebind() { + warn!("failed to rebind socket v4: {:?}", err); + } + } self.msock.dns_resolver.clear_cache(); self.msock.re_stun("link-change-major"); self.close_stale_relay_connections().await; diff --git a/net-tools/netwatch/src/udp.rs b/net-tools/netwatch/src/udp.rs index 9f532281eb..06a5c127c5 100644 --- a/net-tools/netwatch/src/udp.rs +++ b/net-tools/netwatch/src/udp.rs @@ -8,7 +8,7 @@ use std::{ }; use anyhow::{ensure, Context, Result}; -use tracing::warn; +use tracing::{debug, warn}; use super::IpFamily; @@ -78,6 +78,7 @@ impl UdpSocket { /// Rebind the underlying socket. pub fn rebind(&self) -> Result<()> { + debug!("rebinding {}", self.addr); // Remove old socket let mut guard = self.socket.write().unwrap(); { From 84741ba28c09b65fbe777f7c3da5fedbddf7d561 Mon Sep 17 00:00:00 2001 From: dignifiedquire Date: Thu, 21 Nov 2024 12:09:44 +0100 Subject: [PATCH 18/38] actually shut down sockets --- iroh-net/src/magicsock.rs | 5 ++ iroh-net/src/magicsock/udp_conn.rs | 26 ++++--- net-tools/netwatch/src/udp.rs | 109 +++++++++++++++++++++-------- 3 files changed, 102 insertions(+), 38 deletions(-) diff --git a/iroh-net/src/magicsock.rs b/iroh-net/src/magicsock.rs index 13f7426940..4a79493a2f 100644 --- a/iroh-net/src/magicsock.rs +++ b/iroh-net/src/magicsock.rs @@ -1514,6 +1514,11 @@ impl Handle { } self.msock.closing.store(true, Ordering::Relaxed); self.msock.actor_sender.send(ActorMessage::Shutdown).await?; + self.msock.pconn4.close().await; + if let Some(ref conn) = self.msock.pconn6 { + conn.close().await; + } + self.msock.closed.store(true, Ordering::SeqCst); self.msock.direct_addrs.addrs.shutdown(); diff --git a/iroh-net/src/magicsock/udp_conn.rs b/iroh-net/src/magicsock/udp_conn.rs index 687be40f8b..f90794130f 100644 --- a/iroh-net/src/magicsock/udp_conn.rs +++ b/iroh-net/src/magicsock/udp_conn.rs @@ -30,7 +30,7 @@ impl UdpConn { let sock = bind(addr)?; let state = sock.with_socket(|socket| { quinn_udp::UdpSocketState::new(quinn_udp::UdpSockRef::from(socket)) - })?; + })??; Ok(Self { io: Arc::new(sock), @@ -45,7 +45,7 @@ impl UdpConn { // update socket state let new_state = self.io.with_socket(|socket| { quinn_udp::UdpSocketState::new(quinn_udp::UdpSockRef::from(socket)) - })?; + })??; *self.inner.write().unwrap() = new_state; Ok(()) } @@ -59,6 +59,11 @@ impl UdpConn { io: self.io.clone(), }) } + + /// Closes the socket for good + pub async fn close(&self) { + self.io.close().await; + } } impl AsyncUdpSocket for UdpConn { @@ -90,8 +95,9 @@ impl AsyncUdpSocket for UdpConn { }) }); match res { - Ok(()) => return Ok(()), - Err(err) => { + Ok(Ok(())) => return Ok(()), + Err(err) => return Err(err), // closed error + Ok(Err(err)) => { if err.kind() == std::io::ErrorKind::WouldBlock { continue; } @@ -129,14 +135,15 @@ impl AsyncUdpSocket for UdpConn { } match self.io.with_socket(|io| io.poll_recv_ready(cx)) { - Poll::Pending => return Poll::Pending, - Poll::Ready(Ok(())) => {} - Poll::Ready(Err(err)) => match self.io.handle_read_error(err) { + Ok(Poll::Pending) => return Poll::Pending, + Ok(Poll::Ready(Ok(()))) => {} + Ok(Poll::Ready(Err(err))) => match self.io.handle_read_error(err) { Some(err) => return Poll::Ready(Err(err)), None => { continue; } }, + Err(err) => return Poll::Ready(Err(err)), } let res = self.io.try_io(Interest::READABLE, || { @@ -144,7 +151,7 @@ impl AsyncUdpSocket for UdpConn { .with_socket(|io| self.inner.read().unwrap().recv(io.into(), bufs, meta)) }); match res { - Ok(count) => { + Ok(Ok(count)) => { for meta in meta.iter().take(count) { trace!( src = %meta.addr, @@ -156,7 +163,7 @@ impl AsyncUdpSocket for UdpConn { } return Poll::Ready(Ok(count)); } - Err(err) => { + Ok(Err(err)) => { // ignore spurious wakeups if err.kind() == std::io::ErrorKind::WouldBlock { continue; @@ -168,6 +175,7 @@ impl AsyncUdpSocket for UdpConn { } } } + Err(err) => return Poll::Ready(Err(err)), } } } diff --git a/net-tools/netwatch/src/udp.rs b/net-tools/netwatch/src/udp.rs index 06a5c127c5..12b78be1d5 100644 --- a/net-tools/netwatch/src/udp.rs +++ b/net-tools/netwatch/src/udp.rs @@ -7,7 +7,7 @@ use std::{ task::Poll, }; -use anyhow::{ensure, Context, Result}; +use anyhow::{bail, ensure, Context, Result}; use tracing::{debug, warn}; use super::IpFamily; @@ -82,7 +82,9 @@ impl UdpSocket { // Remove old socket let mut guard = self.socket.write().unwrap(); { - let socket = guard.take().expect("not yet dropped"); + let Some(socket) = guard.take() else { + bail!("cannot rebind closed socket"); + }; drop(socket); } @@ -113,13 +115,18 @@ impl UdpSocket { } /// Use the socket - pub fn with_socket(&self, f: F) -> T + pub fn with_socket(&self, f: F) -> std::io::Result where F: FnOnce(&tokio::net::UdpSocket) -> T, { let guard = self.socket.read().unwrap(); - let socket = guard.as_ref().expect("missing socket"); - f(socket) + let Some(socket) = guard.as_ref() else { + return Err(std::io::Error::new( + std::io::ErrorKind::BrokenPipe, + "socket closed", + )); + }; + Ok(f(socket)) } pub fn try_io( @@ -128,7 +135,12 @@ impl UdpSocket { f: impl FnOnce() -> std::io::Result, ) -> std::io::Result { let guard = self.socket.read().unwrap(); - let socket = guard.as_ref().expect("missing socket"); + let Some(socket) = guard.as_ref() else { + return Err(std::io::Error::new( + std::io::ErrorKind::BrokenPipe, + "socket closed", + )); + }; socket.try_io(interest, f) } @@ -173,7 +185,13 @@ impl UdpSocket { pub fn connect(&self, addr: SocketAddr) -> std::io::Result<()> { let mut guard = self.socket.write().unwrap(); // dance around to make non async connect work - let socket_tokio = guard.take().expect("missing socket"); + let Some(socket_tokio) = guard.take() else { + return Err(std::io::Error::new( + std::io::ErrorKind::BrokenPipe, + "socket closed", + )); + }; + let socket_std = socket_tokio.into_std()?; socket_std.connect(addr)?; let socket_tokio = tokio::net::UdpSocket::from_std(socket_std)?; @@ -184,30 +202,38 @@ impl UdpSocket { /// Returns the local address of this socket. pub fn local_addr(&self) -> std::io::Result { let guard = self.socket.read().unwrap(); - let socket = guard.as_ref().expect("missing socket"); + let Some(socket) = guard.as_ref() else { + return Err(std::io::Error::new( + std::io::ErrorKind::BrokenPipe, + "socket closed", + )); + }; + socket.local_addr() } /// Closes the socket, and waits for the underlying `libc::close` call to be finished. - pub async fn close(self) { - let std_sock = self - .socket - .write() - .unwrap() - .take() - .expect("not yet dropped") - .into_std(); - let res = tokio::runtime::Handle::current() - .spawn_blocking(move || { - // Calls libc::close, which can block - drop(std_sock); - }) - .await; - if let Err(err) = res { - warn!("failed to close socket: {:?}", err); + pub async fn close(&self) { + let socket = self.socket.write().unwrap().take(); + if let Some(sock) = socket { + let std_sock = sock.into_std(); + let res = tokio::runtime::Handle::current() + .spawn_blocking(move || { + // Calls libc::close, which can block + drop(std_sock); + }) + .await; + if let Err(err) = res { + warn!("failed to close socket: {:?}", err); + } } } + /// Check if this socket is closed. + pub fn is_closed(&self) -> bool { + self.socket.read().unwrap().is_none() + } + /// Handle potential read errors, updating internal state. /// /// Returns `Some(error)` if the error is fatal otherwise `None. @@ -255,7 +281,12 @@ impl UdpSocket { } } let guard = self.socket.read().unwrap(); - let socket = guard.as_ref().expect("missing socket"); + let Some(socket) = guard.as_ref() else { + return Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::BrokenPipe, + "socket closed", + ))); + }; match socket.poll_send_ready(cx) { Poll::Pending => return Poll::Pending, @@ -302,7 +333,12 @@ impl Future for RecvFut<'_, '_> { } let guard = socket.socket.read().unwrap(); - let inner_socket = guard.as_ref().expect("missing socket"); + let Some(inner_socket) = guard.as_ref() else { + return Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::BrokenPipe, + "socket closed", + ))); + }; match inner_socket.poll_recv_ready(cx) { Poll::Pending => return Poll::Pending, @@ -360,7 +396,12 @@ impl Future for RecvFromFut<'_, '_> { } } let guard = socket.socket.read().unwrap(); - let inner_socket = guard.as_ref().expect("missing socket"); + let Some(inner_socket) = guard.as_ref() else { + return Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::BrokenPipe, + "socket closed", + ))); + }; match inner_socket.poll_recv_ready(cx) { Poll::Pending => return Poll::Pending, @@ -430,7 +471,12 @@ impl Future for SendFut<'_, '_> { } } let guard = self.socket.socket.read().unwrap(); - let socket = guard.as_ref().expect("missing socket"); + let Some(socket) = guard.as_ref() else { + return Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::BrokenPipe, + "socket closed", + ))); + }; match socket.poll_send_ready(c) { Poll::Pending => return Poll::Pending, @@ -488,7 +534,12 @@ impl Future for SendToFut<'_, '_> { } let guard = self.socket.socket.read().unwrap(); - let socket = guard.as_ref().expect("missing socket"); + let Some(socket) = guard.as_ref() else { + return Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::BrokenPipe, + "socket closed", + ))); + }; match socket.poll_send_ready(cx) { Poll::Pending => return Poll::Pending, From 914d5efccf6924910cc0b1db808e1756e3b8f3ca Mon Sep 17 00:00:00 2001 From: dignifiedquire Date: Fri, 22 Nov 2024 11:11:15 +0100 Subject: [PATCH 19/38] why can't we have nice things? --- iroh-net/src/magicsock.rs | 5 ----- iroh-net/src/magicsock/udp_conn.rs | 5 ----- 2 files changed, 10 deletions(-) diff --git a/iroh-net/src/magicsock.rs b/iroh-net/src/magicsock.rs index 4a79493a2f..13f7426940 100644 --- a/iroh-net/src/magicsock.rs +++ b/iroh-net/src/magicsock.rs @@ -1514,11 +1514,6 @@ impl Handle { } self.msock.closing.store(true, Ordering::Relaxed); self.msock.actor_sender.send(ActorMessage::Shutdown).await?; - self.msock.pconn4.close().await; - if let Some(ref conn) = self.msock.pconn6 { - conn.close().await; - } - self.msock.closed.store(true, Ordering::SeqCst); self.msock.direct_addrs.addrs.shutdown(); diff --git a/iroh-net/src/magicsock/udp_conn.rs b/iroh-net/src/magicsock/udp_conn.rs index f90794130f..efd374c6ef 100644 --- a/iroh-net/src/magicsock/udp_conn.rs +++ b/iroh-net/src/magicsock/udp_conn.rs @@ -59,11 +59,6 @@ impl UdpConn { io: self.io.clone(), }) } - - /// Closes the socket for good - pub async fn close(&self) { - self.io.close().await; - } } impl AsyncUdpSocket for UdpConn { From 1b763d9ffd01e3773aab5e7bcc5fcfc004870164 Mon Sep 17 00:00:00 2001 From: dignifiedquire Date: Fri, 22 Nov 2024 12:34:52 +0100 Subject: [PATCH 20/38] fix error handling --- iroh-net/src/magicsock.rs | 9 +++------ iroh-net/src/magicsock/udp_conn.rs | 26 ++++++++++++++++++-------- 2 files changed, 21 insertions(+), 14 deletions(-) diff --git a/iroh-net/src/magicsock.rs b/iroh-net/src/magicsock.rs index 13f7426940..76d2f403ab 100644 --- a/iroh-net/src/magicsock.rs +++ b/iroh-net/src/magicsock.rs @@ -1860,13 +1860,10 @@ impl Actor { debug!("link change detected: major? {}", is_major); if is_major { - if let Err(err) = self.pconn4.rebind() { - warn!("failed to rebind socket v4: {:?}", err); - } + // Only mark them as broken to trigger a rebind when used again + self.pconn4.mark_broken(); if let Some(ref pconn6) = self.pconn6 { - if let Err(err) = pconn6.rebind() { - warn!("failed to rebind socket v4: {:?}", err); - } + pconn6.mark_broken(); } self.msock.dns_resolver.clear_cache(); self.msock.re_stun("link-change-major"); diff --git a/iroh-net/src/magicsock/udp_conn.rs b/iroh-net/src/magicsock/udp_conn.rs index efd374c6ef..946857bd85 100644 --- a/iroh-net/src/magicsock/udp_conn.rs +++ b/iroh-net/src/magicsock/udp_conn.rs @@ -89,10 +89,9 @@ impl AsyncUdpSocket for UdpConn { self.inner.read().unwrap().send(sock_ref, transmit) }) }); - match res { - Ok(Ok(())) => return Ok(()), - Err(err) => return Err(err), // closed error - Ok(Err(err)) => { + match flatten(res) { + Ok(()) => return Ok(()), + Err(err) => { if err.kind() == std::io::ErrorKind::WouldBlock { continue; } @@ -145,8 +144,9 @@ impl AsyncUdpSocket for UdpConn { self.io .with_socket(|io| self.inner.read().unwrap().recv(io.into(), bufs, meta)) }); - match res { - Ok(Ok(count)) => { + + match flatten(res) { + Ok(count) => { for meta in meta.iter().take(count) { trace!( src = %meta.addr, @@ -158,7 +158,7 @@ impl AsyncUdpSocket for UdpConn { } return Poll::Ready(Ok(count)); } - Ok(Err(err)) => { + Err(err) => { // ignore spurious wakeups if err.kind() == std::io::ErrorKind::WouldBlock { continue; @@ -170,7 +170,6 @@ impl AsyncUdpSocket for UdpConn { } } } - Err(err) => return Poll::Ready(Err(err)), } } } @@ -229,6 +228,17 @@ fn bind(mut addr: SocketAddr) -> anyhow::Result { bail!("failed to bind any ports on {:?} (tried {:?})", addr, ports); } +/// Flatten a result +fn flatten( + result: std::result::Result, E>, +) -> std::result::Result { + match result { + Ok(Ok(res)) => Ok(res), + Ok(Err(err)) => Err(err), + Err(err) => Err(err), + } +} + /// Poller for when the socket is writable. #[derive(Debug)] struct IoPoller { From bf71e0807f9f3cb90939b432f04fe0c39595a230 Mon Sep 17 00:00:00 2001 From: dignifiedquire Date: Fri, 22 Nov 2024 13:34:52 +0100 Subject: [PATCH 21/38] refactor: move quinn_udp logic into udpsockert --- iroh-net/src/magicsock.rs | 9 +- iroh-net/src/magicsock/udp_conn.rs | 146 ++---------------- net-tools/netwatch/Cargo.toml | 32 +++- net-tools/netwatch/src/udp.rs | 233 +++++++++++++++++++++++------ 4 files changed, 235 insertions(+), 185 deletions(-) diff --git a/iroh-net/src/magicsock.rs b/iroh-net/src/magicsock.rs index 76d2f403ab..b79dff9459 100644 --- a/iroh-net/src/magicsock.rs +++ b/iroh-net/src/magicsock.rs @@ -1860,10 +1860,13 @@ impl Actor { debug!("link change detected: major? {}", is_major); if is_major { - // Only mark them as broken to trigger a rebind when used again - self.pconn4.mark_broken(); + if let Err(err) = self.pconn4.rebind() { + warn!("failed to rebind Udp IPv4 socket: {:?}", err); + }; if let Some(ref pconn6) = self.pconn6 { - pconn6.mark_broken(); + if let Err(err) = pconn6.rebind() { + warn!("failed to rebind Udp IPv6 socket: {:?}", err); + }; } self.msock.dns_resolver.clear_cache(); self.msock.re_stun("link-change-major"); diff --git a/iroh-net/src/magicsock/udp_conn.rs b/iroh-net/src/magicsock/udp_conn.rs index 946857bd85..2498711bee 100644 --- a/iroh-net/src/magicsock/udp_conn.rs +++ b/iroh-net/src/magicsock/udp_conn.rs @@ -3,22 +3,20 @@ use std::{ io, net::SocketAddr, pin::Pin, - sync::{Arc, RwLock}, + sync::Arc, task::{Context, Poll}, }; use anyhow::{bail, Context as _}; use netwatch::UdpSocket; use quinn::AsyncUdpSocket; -use quinn_udp::{Transmit, UdpSockRef}; -use tokio::io::Interest; -use tracing::{debug, trace, warn}; +use quinn_udp::Transmit; +use tracing::debug; /// A UDP socket implementing Quinn's [`AsyncUdpSocket`]. -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct UdpConn { io: Arc, - inner: RwLock, } impl UdpConn { @@ -28,26 +26,8 @@ impl UdpConn { pub(super) fn bind(addr: SocketAddr) -> anyhow::Result { let sock = bind(addr)?; - let state = sock.with_socket(|socket| { - quinn_udp::UdpSocketState::new(quinn_udp::UdpSockRef::from(socket)) - })??; - - Ok(Self { - io: Arc::new(sock), - inner: RwLock::new(state), - }) - } - - pub(super) fn rebind(&self) -> anyhow::Result<()> { - // Rebind underlying socket - self.io.rebind()?; - // update socket state - let new_state = self.io.with_socket(|socket| { - quinn_udp::UdpSocketState::new(quinn_udp::UdpSockRef::from(socket)) - })??; - *self.inner.write().unwrap() = new_state; - Ok(()) + Ok(Self { io: Arc::new(sock) }) } pub fn port(&self) -> u16 { @@ -67,43 +47,7 @@ impl AsyncUdpSocket for UdpConn { } fn try_send(&self, transmit: &Transmit<'_>) -> io::Result<()> { - loop { - if self.io.is_broken() { - match self.rebind() { - Ok(()) => { - // all good - } - Err(err) => { - warn!("failed to rebind socket: {:?}", err); - // TODO: improve error - let err = - std::io::Error::new(std::io::ErrorKind::BrokenPipe, err.to_string()); - return Err(err); - } - } - } - - let res = self.io.try_io(Interest::WRITABLE, || { - self.io.with_socket(|io| { - let sock_ref = UdpSockRef::from(io); - self.inner.read().unwrap().send(sock_ref, transmit) - }) - }); - match flatten(res) { - Ok(()) => return Ok(()), - Err(err) => { - if err.kind() == std::io::ErrorKind::WouldBlock { - continue; - } - match self.io.handle_write_error(err) { - Some(err) => return Err(err), - None => { - continue; - } - } - } - } - } + self.io.try_send_quinn(transmit) } fn poll_recv( @@ -112,66 +56,7 @@ impl AsyncUdpSocket for UdpConn { bufs: &mut [io::IoSliceMut<'_>], meta: &mut [quinn_udp::RecvMeta], ) -> Poll> { - loop { - if self.io.is_broken() { - match self.rebind() { - Ok(()) => { - // all good - } - Err(err) => { - warn!("failed to rebind socket: {:?}", err); - // TODO: improve error - let err = - std::io::Error::new(std::io::ErrorKind::NotConnected, err.to_string()); - return Poll::Ready(Err(err)); - } - } - } - - match self.io.with_socket(|io| io.poll_recv_ready(cx)) { - Ok(Poll::Pending) => return Poll::Pending, - Ok(Poll::Ready(Ok(()))) => {} - Ok(Poll::Ready(Err(err))) => match self.io.handle_read_error(err) { - Some(err) => return Poll::Ready(Err(err)), - None => { - continue; - } - }, - Err(err) => return Poll::Ready(Err(err)), - } - - let res = self.io.try_io(Interest::READABLE, || { - self.io - .with_socket(|io| self.inner.read().unwrap().recv(io.into(), bufs, meta)) - }); - - match flatten(res) { - Ok(count) => { - for meta in meta.iter().take(count) { - trace!( - src = %meta.addr, - len = meta.len, - count = meta.len / meta.stride, - dst = %meta.dst_ip.map(|x| x.to_string()).unwrap_or_default(), - "UDP recv" - ); - } - return Poll::Ready(Ok(count)); - } - Err(err) => { - // ignore spurious wakeups - if err.kind() == std::io::ErrorKind::WouldBlock { - continue; - } - match self.io.handle_read_error(err) { - Some(err) => return Poll::Ready(Err(err)), - None => { - continue; - } - } - } - } - } + self.io.poll_recv_quinn(cx, bufs, meta) } fn local_addr(&self) -> io::Result { @@ -179,15 +64,15 @@ impl AsyncUdpSocket for UdpConn { } fn may_fragment(&self) -> bool { - self.inner.read().unwrap().may_fragment() + self.io.may_fragment().unwrap_or_default() } fn max_transmit_segments(&self) -> usize { - self.inner.read().unwrap().max_gso_segments() + self.io.max_transmit_segments().unwrap_or_default() } fn max_receive_segments(&self) -> usize { - self.inner.read().unwrap().gro_segments() + self.io.max_receive_segments().unwrap_or_default() } } @@ -228,17 +113,6 @@ fn bind(mut addr: SocketAddr) -> anyhow::Result { bail!("failed to bind any ports on {:?} (tried {:?})", addr, ports); } -/// Flatten a result -fn flatten( - result: std::result::Result, E>, -) -> std::result::Result { - match result { - Ok(Ok(res)) => Ok(res), - Ok(Err(err)) => Err(err), - Err(err) => Err(err), - } -} - /// Poller for when the socket is writable. #[derive(Debug)] struct IoPoller { diff --git a/net-tools/netwatch/Cargo.toml b/net-tools/netwatch/Cargo.toml index 38637d45b6..26a3013ff6 100644 --- a/net-tools/netwatch/Cargo.toml +++ b/net-tools/netwatch/Cargo.toml @@ -21,10 +21,22 @@ futures-util = "0.3.25" libc = "0.2.139" netdev = "0.30.0" once_cell = "1.18.0" +quinn-udp = { package = "iroh-quinn-udp", version = "0.5.5" } socket2 = "0.5.3" thiserror = "1" time = "0.3.20" -tokio = { version = "1", features = ["io-util", "macros", "sync", "rt", "net", "fs", "io-std", "signal", "process", "time"] } +tokio = { version = "1", features = [ + "io-util", + "macros", + "sync", + "rt", + "net", + "fs", + "io-std", + "signal", + "process", + "time", +] } tokio-util = { version = "0.7", features = ["rt"] } tracing = "0.1" @@ -36,12 +48,26 @@ rtnetlink = "0.13.0" [target.'cfg(target_os = "windows")'.dependencies] wmi = "0.13" -windows = { version = "0.51", features = ["Win32_NetworkManagement_IpHelper", "Win32_Foundation", "Win32_NetworkManagement_Ndis", "Win32_Networking_WinSock"] } +windows = { version = "0.51", features = [ + "Win32_NetworkManagement_IpHelper", + "Win32_Foundation", + "Win32_NetworkManagement_Ndis", + "Win32_Networking_WinSock", +] } serde = { version = "1", features = ["derive"] } derive_more = { version = "1.0.0", features = ["debug"] } [dev-dependencies] -tokio = { version = "1", features = ["io-util", "sync", "rt", "net", "fs", "macros", "time", "test-util"] } +tokio = { version = "1", features = [ + "io-util", + "sync", + "rt", + "net", + "fs", + "macros", + "time", + "test-util", +] } [package.metadata.docs.rs] all-features = true diff --git a/net-tools/netwatch/src/udp.rs b/net-tools/netwatch/src/udp.rs index 12b78be1d5..695250bbed 100644 --- a/net-tools/netwatch/src/udp.rs +++ b/net-tools/netwatch/src/udp.rs @@ -4,20 +4,20 @@ use std::{ net::SocketAddr, pin::Pin, sync::{atomic::AtomicBool, RwLock}, - task::Poll, + task::{Context, Poll}, }; -use anyhow::{bail, ensure, Context, Result}; -use tracing::{debug, warn}; +use anyhow::{bail, ensure, Context as _, Result}; +use quinn_udp::Transmit; +use tokio::io::Interest; +use tracing::{debug, trace, warn}; use super::IpFamily; -/// Wrapper around a tokio UDP socket that handles the fact that -/// on drop `libc::close` can block for UDP sockets. +/// Wrapper around a tokio UDP socket. #[derive(Debug)] pub struct UdpSocket { - // TODO: can we drop the Arc and use lifetimes in the futures? - socket: RwLock>, + socket: RwLock>, /// The addr we are binding to. addr: SocketAddr, /// Set to true, when an error occurred, that means we need to rebind the socket. @@ -71,7 +71,7 @@ impl UdpSocket { } /// Marks this socket as needing a rebind - pub fn mark_broken(&self) { + fn mark_broken(&self) { self.is_broken .store(true, std::sync::atomic::Ordering::SeqCst); } @@ -105,7 +105,7 @@ impl UdpSocket { let mut addr = addr.into(); let socket = inner_bind(addr)?; // update to use selected port - addr.set_port(socket.local_addr()?.port()); + addr.set_port(socket.0.local_addr()?.port()); Ok(UdpSocket { socket: RwLock::new(Some(socket)), @@ -117,35 +117,17 @@ impl UdpSocket { /// Use the socket pub fn with_socket(&self, f: F) -> std::io::Result where - F: FnOnce(&tokio::net::UdpSocket) -> T, + F: FnOnce(&tokio::net::UdpSocket, &quinn_udp::UdpSocketState) -> T, { let guard = self.socket.read().unwrap(); - let Some(socket) = guard.as_ref() else { + let Some((socket, state)) = guard.as_ref() else { + warn!("socket closed"); return Err(std::io::Error::new( std::io::ErrorKind::BrokenPipe, "socket closed", )); }; - Ok(f(socket)) - } - - pub fn try_io( - &self, - interest: tokio::io::Interest, - f: impl FnOnce() -> std::io::Result, - ) -> std::io::Result { - let guard = self.socket.read().unwrap(); - let Some(socket) = guard.as_ref() else { - return Err(std::io::Error::new( - std::io::ErrorKind::BrokenPipe, - "socket closed", - )); - }; - socket.try_io(interest, f) - } - - pub fn writable(&self) -> WritableFut<'_> { - WritableFut { socket: self } + Ok(f(socket, state)) } /// TODO @@ -183,9 +165,11 @@ impl UdpSocket { /// TODO pub fn connect(&self, addr: SocketAddr) -> std::io::Result<()> { + tracing::info!("connectnig to {}", addr); let mut guard = self.socket.write().unwrap(); // dance around to make non async connect work - let Some(socket_tokio) = guard.take() else { + let Some((socket_tokio, state)) = guard.take() else { + warn!("socket closed"); return Err(std::io::Error::new( std::io::ErrorKind::BrokenPipe, "socket closed", @@ -195,14 +179,15 @@ impl UdpSocket { let socket_std = socket_tokio.into_std()?; socket_std.connect(addr)?; let socket_tokio = tokio::net::UdpSocket::from_std(socket_std)?; - guard.replace(socket_tokio); + guard.replace((socket_tokio, state)); Ok(()) } /// Returns the local address of this socket. pub fn local_addr(&self) -> std::io::Result { let guard = self.socket.read().unwrap(); - let Some(socket) = guard.as_ref() else { + let Some((socket, _)) = guard.as_ref() else { + warn!("socket closed"); return Err(std::io::Error::new( std::io::ErrorKind::BrokenPipe, "socket closed", @@ -215,7 +200,7 @@ impl UdpSocket { /// Closes the socket, and waits for the underlying `libc::close` call to be finished. pub async fn close(&self) { let socket = self.socket.write().unwrap().take(); - if let Some(sock) = socket { + if let Some((sock, _)) = socket { let std_sock = sock.into_std(); let res = tokio::runtime::Handle::current() .spawn_blocking(move || { @@ -281,7 +266,8 @@ impl UdpSocket { } } let guard = self.socket.read().unwrap(); - let Some(socket) = guard.as_ref() else { + let Some((socket, _state)) = guard.as_ref() else { + warn!("socket closed"); return Poll::Ready(Err(std::io::Error::new( std::io::ErrorKind::BrokenPipe, "socket closed", @@ -300,6 +286,160 @@ impl UdpSocket { } } } + + /// Send a quinn based `Transmit`. + pub fn try_send_quinn(&self, transmit: &Transmit<'_>) -> std::io::Result<()> { + loop { + // check if the socket needs a rebind + if self.is_broken() { + match self.rebind() { + Ok(()) => { + // all good + } + Err(err) => { + warn!("failed to rebind socket: {:?}", err); + // TODO: improve error + let err = + std::io::Error::new(std::io::ErrorKind::BrokenPipe, err.to_string()); + return Err(err); + } + } + } + let guard = self.socket.read().unwrap(); + let Some((socket, state)) = guard.as_ref() else { + warn!("socket closed"); + return Err(std::io::Error::new( + std::io::ErrorKind::BrokenPipe, + "socket closed", + )); + }; + + let res = socket.try_io(Interest::WRITABLE, || state.send(socket.into(), transmit)); + + match res { + Ok(()) => return Ok(()), + Err(err) => match self.handle_write_error(err) { + Some(err) => return Err(err), + None => { + continue; + } + }, + } + } + } + + /// quinn based `poll_recv` + pub fn poll_recv_quinn( + &self, + cx: &mut Context, + bufs: &mut [std::io::IoSliceMut<'_>], + meta: &mut [quinn_udp::RecvMeta], + ) -> Poll> { + loop { + // check if the socket needs a rebind + if self.is_broken() { + match self.rebind() { + Ok(()) => { + // all good + } + Err(err) => { + warn!("failed to rebind socket: {:?}", err); + // TODO: improve error + let err = + std::io::Error::new(std::io::ErrorKind::BrokenPipe, err.to_string()); + return Poll::Ready(Err(err)); + } + } + } + let guard = self.socket.read().unwrap(); + let Some((socket, state)) = guard.as_ref() else { + warn!("socket closed"); + return Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::BrokenPipe, + "socket closed", + ))); + }; + + match socket.poll_recv_ready(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(Ok(())) => { + // We are ready to read, continue + } + Poll::Ready(Err(err)) => match self.handle_read_error(err) { + Some(err) => return Poll::Ready(Err(err)), + None => { + continue; + } + }, + } + + match state.recv(socket.into(), bufs, meta) { + Ok(count) => { + for meta in meta.iter().take(count) { + trace!( + src = %meta.addr, + len = meta.len, + count = meta.len / meta.stride, + dst = %meta.dst_ip.map(|x| x.to_string()).unwrap_or_default(), + "UDP recv" + ); + } + return Poll::Ready(Ok(count)); + } + Err(err) => { + // ignore spurious wakeups + if err.kind() == std::io::ErrorKind::WouldBlock { + continue; + } + match self.handle_read_error(err) { + Some(err) => return Poll::Ready(Err(err)), + None => { + continue; + } + } + } + } + } + } + + /// TODO + pub fn may_fragment(&self) -> std::io::Result { + let guard = self.socket.read().unwrap(); + let Some((_, state)) = guard.as_ref() else { + warn!("socket closed"); + return Err(std::io::Error::new( + std::io::ErrorKind::BrokenPipe, + "socket closed", + )); + }; + Ok(state.may_fragment()) + } + + /// TODO + pub fn max_transmit_segments(&self) -> std::io::Result { + let guard = self.socket.read().unwrap(); + let Some((_, state)) = guard.as_ref() else { + warn!("socket closed"); + return Err(std::io::Error::new( + std::io::ErrorKind::BrokenPipe, + "socket closed", + )); + }; + Ok(state.max_gso_segments()) + } + + /// TODO + pub fn max_receive_segments(&self) -> std::io::Result { + let guard = self.socket.read().unwrap(); + let Some((_, state)) = guard.as_ref() else { + warn!("socket closed"); + return Err(std::io::Error::new( + std::io::ErrorKind::BrokenPipe, + "socket closed", + )); + }; + Ok(state.gro_segments()) + } } /// Receive future @@ -333,7 +473,8 @@ impl Future for RecvFut<'_, '_> { } let guard = socket.socket.read().unwrap(); - let Some(inner_socket) = guard.as_ref() else { + let Some((inner_socket, _)) = guard.as_ref() else { + warn!("socket closed"); return Poll::Ready(Err(std::io::Error::new( std::io::ErrorKind::BrokenPipe, "socket closed", @@ -396,7 +537,8 @@ impl Future for RecvFromFut<'_, '_> { } } let guard = socket.socket.read().unwrap(); - let Some(inner_socket) = guard.as_ref() else { + let Some((inner_socket, _)) = guard.as_ref() else { + warn!("socket closed"); return Poll::Ready(Err(std::io::Error::new( std::io::ErrorKind::BrokenPipe, "socket closed", @@ -471,7 +613,8 @@ impl Future for SendFut<'_, '_> { } } let guard = self.socket.socket.read().unwrap(); - let Some(socket) = guard.as_ref() else { + let Some((socket, _)) = guard.as_ref() else { + warn!("socket closed"); return Poll::Ready(Err(std::io::Error::new( std::io::ErrorKind::BrokenPipe, "socket closed", @@ -534,7 +677,8 @@ impl Future for SendToFut<'_, '_> { } let guard = self.socket.socket.read().unwrap(); - let Some(socket) = guard.as_ref() else { + let Some((socket, _)) = guard.as_ref() else { + warn!("socket closed"); return Poll::Ready(Err(std::io::Error::new( std::io::ErrorKind::BrokenPipe, "socket closed", @@ -568,7 +712,7 @@ impl Future for SendToFut<'_, '_> { } } -fn inner_bind(addr: SocketAddr) -> Result { +fn inner_bind(addr: SocketAddr) -> Result<(tokio::net::UdpSocket, quinn_udp::UdpSocketState)> { let network = IpFamily::from(addr.ip()); let socket = socket2::Socket::new( network.into(), @@ -605,6 +749,8 @@ fn inner_bind(addr: SocketAddr) -> Result { // Convert into tokio UdpSocket let socket = tokio::net::UdpSocket::from_std(socket).context("conversion to tokio")?; + let socket_ref = quinn_udp::UdpSockRef::from(&socket); + let socket_state = quinn_udp::UdpSocketState::new(socket_ref)?; if addr.port() != 0 { let local_addr = socket.local_addr().context("local addr")?; @@ -617,14 +763,15 @@ fn inner_bind(addr: SocketAddr) -> Result { ); } - Ok(socket) + Ok((socket, socket_state)) } impl Drop for UdpSocket { fn drop(&mut self) { + debug!("dropping UdpSocket"); // Only spawn_blocking if we are inside a tokio runtime, otherwise we just drop. if let Ok(handle) = tokio::runtime::Handle::try_current() { - if let Some(socket) = self.socket.write().unwrap().take() { + if let Some((socket, _)) = self.socket.write().unwrap().take() { // this will be empty if `close` was called before let std_sock = socket.into_std(); handle.spawn_blocking(move || { From dd7428debad803d77d934686d6e6705ad0b912c5 Mon Sep 17 00:00:00 2001 From: dignifiedquire Date: Mon, 25 Nov 2024 17:26:25 +0100 Subject: [PATCH 22/38] fix: use try_io --- Cargo.lock | 1 + net-tools/netwatch/src/udp.rs | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/Cargo.lock b/Cargo.lock index 5378414353..2123caceda 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3329,6 +3329,7 @@ dependencies = [ "futures-lite 2.5.0", "futures-sink", "futures-util", + "iroh-quinn-udp", "libc", "netdev", "netlink-packet-core", diff --git a/net-tools/netwatch/src/udp.rs b/net-tools/netwatch/src/udp.rs index 695250bbed..4e0c161b5c 100644 --- a/net-tools/netwatch/src/udp.rs +++ b/net-tools/netwatch/src/udp.rs @@ -373,7 +373,8 @@ impl UdpSocket { }, } - match state.recv(socket.into(), bufs, meta) { + let res = socket.try_io(Interest::READABLE, || state.recv(socket.into(), bufs, meta)); + match res { Ok(count) => { for meta in meta.iter().take(count) { trace!( From 21500352572696294e8591fc6f1b4c8bc4482332 Mon Sep 17 00:00:00 2001 From: dignifiedquire Date: Mon, 25 Nov 2024 18:02:25 +0100 Subject: [PATCH 23/38] waky waky --- net-tools/netwatch/src/udp.rs | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/net-tools/netwatch/src/udp.rs b/net-tools/netwatch/src/udp.rs index 4e0c161b5c..82d9471382 100644 --- a/net-tools/netwatch/src/udp.rs +++ b/net-tools/netwatch/src/udp.rs @@ -3,8 +3,8 @@ use std::{ io::ErrorKind, net::SocketAddr, pin::Pin, - sync::{atomic::AtomicBool, RwLock}, - task::{Context, Poll}, + sync::{atomic::AtomicBool, Mutex, RwLock}, + task::{Context, Poll, Waker}, }; use anyhow::{bail, ensure, Context as _, Result}; @@ -18,6 +18,7 @@ use super::IpFamily; #[derive(Debug)] pub struct UdpSocket { socket: RwLock>, + recv_waker: Mutex>, /// The addr we are binding to. addr: SocketAddr, /// Set to true, when an error occurred, that means we need to rebind the socket. @@ -98,6 +99,11 @@ impl UdpSocket { self.is_broken .store(false, std::sync::atomic::Ordering::SeqCst); + // wakup + if let Some(waker) = self.recv_waker.lock().unwrap().take() { + waker.wake(); + } + Ok(()) } @@ -109,6 +115,7 @@ impl UdpSocket { Ok(UdpSocket { socket: RwLock::new(Some(socket)), + recv_waker: Mutex::new(None), addr, is_broken: AtomicBool::new(false), }) @@ -361,7 +368,10 @@ impl UdpSocket { }; match socket.poll_recv_ready(cx) { - Poll::Pending => return Poll::Pending, + Poll::Pending => { + self.recv_waker.lock().unwrap().replace(cx.waker().clone()); + return Poll::Pending; + } Poll::Ready(Ok(())) => { // We are ready to read, continue } From d24ab7229795c23ccc6b4eea217a501bc4c27bda Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Mon, 25 Nov 2024 18:15:19 +0100 Subject: [PATCH 24/38] Simpler regression test case --- iroh-net/src/magicsock.rs | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/iroh-net/src/magicsock.rs b/iroh-net/src/magicsock.rs index b79dff9459..d2d757106b 100644 --- a/iroh-net/src/magicsock.rs +++ b/iroh-net/src/magicsock.rs @@ -3098,6 +3098,45 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_regression_network_change_rebind_wakes_connection_driver( + ) -> testresult::TestResult { + iroh_test::logging::setup(); + let m1 = MagicStack::new(RelayMode::Disabled).await?; + let m2 = MagicStack::new(RelayMode::Disabled).await?; + + println!("Net change"); + m1.endpoint.magic_sock().force_network_change(true).await; + tokio::time::sleep(Duration::from_secs(1)).await; // wait for socket rebinding + + let _guard = mesh_stacks(vec![m1.clone(), m2.clone()]).await?; + + let _handle = AbortOnDropHandle::new(tokio::spawn({ + let endpoint = m2.endpoint.clone(); + async move { + while let Some(incoming) = endpoint.accept().await { + println!("Incoming first conn!"); + let conn = incoming.await?; + conn.closed().await; + } + + testresult::TestResult::Ok(()) + } + })); + + println!("first conn!"); + let conn = m1 + .endpoint + .connect(m2.endpoint.node_addr().await?, ALPN) + .await?; + println!("Closing first conn"); + conn.close(0u32.into(), b"bye lolz"); + conn.closed().await; + println!("Closed first conn"); + + Ok(()) + } + #[tokio::test(flavor = "multi_thread")] async fn test_two_devices_roundtrip_network_change() -> Result<()> { time::timeout( From 6a0212c2b38438be0090c2b2ceeeeba350a9ed42 Mon Sep 17 00:00:00 2001 From: dignifiedquire Date: Mon, 25 Nov 2024 18:52:52 +0100 Subject: [PATCH 25/38] use atomic waker, and register properly --- Cargo.lock | 1 + iroh-net/src/magicsock.rs | 2 +- net-tools/netwatch/Cargo.toml | 1 + net-tools/netwatch/src/udp.rs | 49 +++++++++++++++++++++++------------ 4 files changed, 36 insertions(+), 17 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 2123caceda..a511cede31 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3324,6 +3324,7 @@ name = "netwatch" version = "0.1.0" dependencies = [ "anyhow", + "atomic-waker", "bytes", "derive_more", "futures-lite 2.5.0", diff --git a/iroh-net/src/magicsock.rs b/iroh-net/src/magicsock.rs index d2d757106b..62fa829b14 100644 --- a/iroh-net/src/magicsock.rs +++ b/iroh-net/src/magicsock.rs @@ -3101,7 +3101,7 @@ mod tests { #[tokio::test] async fn test_regression_network_change_rebind_wakes_connection_driver( ) -> testresult::TestResult { - iroh_test::logging::setup(); + let _ = iroh_test::logging::setup(); let m1 = MagicStack::new(RelayMode::Disabled).await?; let m2 = MagicStack::new(RelayMode::Disabled).await?; diff --git a/net-tools/netwatch/Cargo.toml b/net-tools/netwatch/Cargo.toml index 26a3013ff6..2a0050666d 100644 --- a/net-tools/netwatch/Cargo.toml +++ b/net-tools/netwatch/Cargo.toml @@ -14,6 +14,7 @@ workspace = true [dependencies] anyhow = { version = "1" } +atomic-waker = "1.1.2" bytes = "1.7" futures-lite = "2.3" futures-sink = "0.3.25" diff --git a/net-tools/netwatch/src/udp.rs b/net-tools/netwatch/src/udp.rs index 82d9471382..8a50fd89cc 100644 --- a/net-tools/netwatch/src/udp.rs +++ b/net-tools/netwatch/src/udp.rs @@ -3,11 +3,12 @@ use std::{ io::ErrorKind, net::SocketAddr, pin::Pin, - sync::{atomic::AtomicBool, Mutex, RwLock}, - task::{Context, Poll, Waker}, + sync::{atomic::AtomicBool, RwLock}, + task::{Context, Poll}, }; use anyhow::{bail, ensure, Context as _, Result}; +use atomic_waker::AtomicWaker; use quinn_udp::Transmit; use tokio::io::Interest; use tracing::{debug, trace, warn}; @@ -18,7 +19,8 @@ use super::IpFamily; #[derive(Debug)] pub struct UdpSocket { socket: RwLock>, - recv_waker: Mutex>, + recv_waker: AtomicWaker, + send_waker: AtomicWaker, /// The addr we are binding to. addr: SocketAddr, /// Set to true, when an error occurred, that means we need to rebind the socket. @@ -99,10 +101,9 @@ impl UdpSocket { self.is_broken .store(false, std::sync::atomic::Ordering::SeqCst); - // wakup - if let Some(waker) = self.recv_waker.lock().unwrap().take() { - waker.wake(); - } + // wakeup + self.recv_waker.wake(); + self.send_waker.wake(); Ok(()) } @@ -115,7 +116,8 @@ impl UdpSocket { Ok(UdpSocket { socket: RwLock::new(Some(socket)), - recv_waker: Mutex::new(None), + recv_waker: AtomicWaker::default(), + send_waker: AtomicWaker::default(), addr, is_broken: AtomicBool::new(false), }) @@ -282,7 +284,10 @@ impl UdpSocket { }; match socket.poll_send_ready(cx) { - Poll::Pending => return Poll::Pending, + Poll::Pending => { + self.send_waker.register(cx.waker()); + return Poll::Pending; + } Poll::Ready(Ok(())) => return Poll::Ready(Ok(())), Poll::Ready(Err(err)) => { if let Some(err) = self.handle_write_error(err) { @@ -369,7 +374,7 @@ impl UdpSocket { match socket.poll_recv_ready(cx) { Poll::Pending => { - self.recv_waker.lock().unwrap().replace(cx.waker().clone()); + self.recv_waker.register(cx.waker()); return Poll::Pending; } Poll::Ready(Ok(())) => { @@ -493,7 +498,10 @@ impl Future for RecvFut<'_, '_> { }; match inner_socket.poll_recv_ready(cx) { - Poll::Pending => return Poll::Pending, + Poll::Pending => { + self.socket.recv_waker.register(cx.waker()); + return Poll::Pending; + } Poll::Ready(Ok(())) => { let res = inner_socket.try_recv(buffer); if let Err(err) = res { @@ -557,7 +565,10 @@ impl Future for RecvFromFut<'_, '_> { }; match inner_socket.poll_recv_ready(cx) { - Poll::Pending => return Poll::Pending, + Poll::Pending => { + self.socket.recv_waker.register(cx.waker()); + return Poll::Pending; + } Poll::Ready(Ok(())) => { let res = inner_socket.try_recv_from(buffer); if let Err(err) = res { @@ -606,7 +617,7 @@ pub struct SendFut<'a, 'b> { impl Future for SendFut<'_, '_> { type Output = std::io::Result; - fn poll(self: Pin<&mut Self>, c: &mut std::task::Context<'_>) -> Poll { + fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { loop { // check if the socket needs a rebind if self.socket.is_broken() { @@ -632,8 +643,11 @@ impl Future for SendFut<'_, '_> { ))); }; - match socket.poll_send_ready(c) { - Poll::Pending => return Poll::Pending, + match socket.poll_send_ready(cx) { + Poll::Pending => { + self.socket.send_waker.register(cx.waker()); + return Poll::Pending; + } Poll::Ready(Ok(())) => { let res = socket.try_send(self.buffer); if let Err(err) = res { @@ -697,7 +711,10 @@ impl Future for SendToFut<'_, '_> { }; match socket.poll_send_ready(cx) { - Poll::Pending => return Poll::Pending, + Poll::Pending => { + self.socket.send_waker.register(cx.waker()); + return Poll::Pending; + } Poll::Ready(Ok(())) => { let res = socket.try_send_to(self.buffer, self.to); if let Err(err) = res { From 95f29c805adb8a17fbe0132429f4fa6df326616c Mon Sep 17 00:00:00 2001 From: dignifiedquire Date: Mon, 25 Nov 2024 19:01:51 +0100 Subject: [PATCH 26/38] fixup --- net-tools/netwatch/src/udp.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/net-tools/netwatch/src/udp.rs b/net-tools/netwatch/src/udp.rs index 8a50fd89cc..f3a38a191d 100644 --- a/net-tools/netwatch/src/udp.rs +++ b/net-tools/netwatch/src/udp.rs @@ -28,7 +28,7 @@ pub struct UdpSocket { } /// UDP socket read/write buffer size (7MB). The value of 7MB is chosen as it -/// is the ma supported by a default configuration of macOS. Some platforms will silently clamp the value. +/// is the max supported by a default configuration of macOS. Some platforms will silently clamp the value. const SOCKET_BUFFER_SIZE: usize = 7 << 20; impl UdpSocket { /// Bind only Ipv4 on any interface. From 6194890c367f2143718c12ee4ea1d05194fa41cf Mon Sep 17 00:00:00 2001 From: dignifiedquire Date: Mon, 25 Nov 2024 20:18:31 +0100 Subject: [PATCH 27/38] some more cleanup --- iroh-net/src/magicsock.rs | 6 +++--- iroh-net/src/magicsock/udp_conn.rs | 4 ++++ net-tools/netwatch/src/udp.rs | 34 ++---------------------------- 3 files changed, 9 insertions(+), 35 deletions(-) diff --git a/iroh-net/src/magicsock.rs b/iroh-net/src/magicsock.rs index 62fa829b14..87964a3108 100644 --- a/iroh-net/src/magicsock.rs +++ b/iroh-net/src/magicsock.rs @@ -1088,9 +1088,9 @@ impl MagicSock { Err(err) if err.kind() == io::ErrorKind::WouldBlock => { // This is the socket .try_send_disco_message_udp used. let sock = self.conn_for_addr(dst)?; - let mut poller = sock.create_io_poller(); - match poller.as_mut().poll_writable(cx)? { - Poll::Ready(()) => continue, + match sock.as_socket_ref().poll_writable(cx) { + Poll::Ready(Ok(())) => continue, + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), Poll::Pending => return Poll::Pending, } } diff --git a/iroh-net/src/magicsock/udp_conn.rs b/iroh-net/src/magicsock/udp_conn.rs index 2498711bee..2f993d8950 100644 --- a/iroh-net/src/magicsock/udp_conn.rs +++ b/iroh-net/src/magicsock/udp_conn.rs @@ -24,6 +24,10 @@ impl UdpConn { self.io.clone() } + pub(super) fn as_socket_ref(&self) -> &UdpSocket { + &self.io + } + pub(super) fn bind(addr: SocketAddr) -> anyhow::Result { let sock = bind(addr)?; diff --git a/net-tools/netwatch/src/udp.rs b/net-tools/netwatch/src/udp.rs index f3a38a191d..af2109b49c 100644 --- a/net-tools/netwatch/src/udp.rs +++ b/net-tools/netwatch/src/udp.rs @@ -123,22 +123,6 @@ impl UdpSocket { }) } - /// Use the socket - pub fn with_socket(&self, f: F) -> std::io::Result - where - F: FnOnce(&tokio::net::UdpSocket, &quinn_udp::UdpSocketState) -> T, - { - let guard = self.socket.read().unwrap(); - let Some((socket, state)) = guard.as_ref() else { - warn!("socket closed"); - return Err(std::io::Error::new( - std::io::ErrorKind::BrokenPipe, - "socket closed", - )); - }; - Ok(f(socket, state)) - } - /// TODO pub fn recv<'a, 'b>(&'b self, buffer: &'a mut [u8]) -> RecvFut<'a, 'b> { RecvFut { @@ -231,7 +215,7 @@ impl UdpSocket { /// Handle potential read errors, updating internal state. /// /// Returns `Some(error)` if the error is fatal otherwise `None. - pub fn handle_read_error(&self, error: std::io::Error) -> Option { + fn handle_read_error(&self, error: std::io::Error) -> Option { match error.kind() { std::io::ErrorKind::NotConnected => { // This indicates the underlying socket is broken, and we should attempt to rebind it @@ -245,7 +229,7 @@ impl UdpSocket { /// Handle potential write errors, updating internal state. /// /// Returns `Some(error)` if the error is fatal otherwise `None. - pub fn handle_write_error(&self, error: std::io::Error) -> Option { + fn handle_write_error(&self, error: std::io::Error) -> Option { match error.kind() { std::io::ErrorKind::BrokenPipe => { // This indicates the underlying socket is broken, and we should attempt to rebind it @@ -593,20 +577,6 @@ impl Future for RecvFromFut<'_, '_> { } } -/// Writable future -#[derive(Debug)] -pub struct WritableFut<'a> { - socket: &'a UdpSocket, -} - -impl Future for WritableFut<'_> { - type Output = std::io::Result<()>; - - fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { - self.socket.poll_writable(cx) - } -} - /// Send future #[derive(Debug)] pub struct SendFut<'a, 'b> { From 724b92e308c67b5b578976fbabeb43d594252c5d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Tue, 26 Nov 2024 11:38:09 +0100 Subject: [PATCH 28/38] Don't block when locking in poll --- net-tools/netwatch/src/udp.rs | 62 ++++++++++++++++++++++++++++++----- 1 file changed, 53 insertions(+), 9 deletions(-) diff --git a/net-tools/netwatch/src/udp.rs b/net-tools/netwatch/src/udp.rs index af2109b49c..8d2ecf9cd5 100644 --- a/net-tools/netwatch/src/udp.rs +++ b/net-tools/netwatch/src/udp.rs @@ -3,7 +3,7 @@ use std::{ io::ErrorKind, net::SocketAddr, pin::Pin, - sync::{atomic::AtomicBool, RwLock}, + sync::{atomic::AtomicBool, RwLock, RwLockReadGuard, TryLockError}, task::{Context, Poll}, }; @@ -101,9 +101,9 @@ impl UdpSocket { self.is_broken .store(false, std::sync::atomic::Ordering::SeqCst); + drop(guard); // wakeup - self.recv_waker.wake(); - self.send_waker.wake(); + self.wake_all(); Ok(()) } @@ -173,6 +173,10 @@ impl UdpSocket { socket_std.connect(addr)?; let socket_tokio = tokio::net::UdpSocket::from_std(socket_std)?; guard.replace((socket_tokio, state)); + + drop(guard); + self.wake_all(); + Ok(()) } @@ -193,6 +197,7 @@ impl UdpSocket { /// Closes the socket, and waits for the underlying `libc::close` call to be finished. pub async fn close(&self) { let socket = self.socket.write().unwrap().take(); + self.wake_all(); if let Some((sock, _)) = socket { let std_sock = sock.into_std(); let res = tokio::runtime::Handle::current() @@ -240,6 +245,41 @@ impl UdpSocket { } } + /// Try to get a read lock for the sockets, but don't block for trying to acquire it. + fn poll_read_socket( + &self, + waker: &AtomicWaker, + cx: &mut std::task::Context<'_>, + ) -> Poll>> { + let guard = match self.socket.try_read() { + Ok(guard) => guard, + Err(TryLockError::Poisoned(e)) => panic!("socket lock poisoned: {e}"), + Err(TryLockError::WouldBlock) => { + waker.register(cx.waker()); + + match self.socket.try_read() { + Ok(guard) => { + // we're actually fine, no need to cause a spurious wakeup + waker.take(); + guard + } + Err(TryLockError::Poisoned(e)) => panic!("socket lock poisoned: {e}"), + Err(TryLockError::WouldBlock) => { + // Ok fine, we registered our waker, the lock is really closed, + // we can return pending. + return Poll::Pending; + } + } + } + }; + Poll::Ready(guard) + } + + fn wake_all(&self) { + self.recv_waker.wake(); + self.send_waker.wake(); + } + /// Poll for writable pub fn poll_writable(&self, cx: &mut std::task::Context<'_>) -> Poll> { loop { @@ -258,7 +298,8 @@ impl UdpSocket { } } } - let guard = self.socket.read().unwrap(); + + let guard = futures_lite::ready!(self.poll_read_socket(&self.send_waker, cx)); let Some((socket, _state)) = guard.as_ref() else { warn!("socket closed"); return Poll::Ready(Err(std::io::Error::new( @@ -347,7 +388,7 @@ impl UdpSocket { } } } - let guard = self.socket.read().unwrap(); + let guard = futures_lite::ready!(self.poll_read_socket(&self.recv_waker, cx)); let Some((socket, state)) = guard.as_ref() else { warn!("socket closed"); return Poll::Ready(Err(std::io::Error::new( @@ -472,7 +513,7 @@ impl Future for RecvFut<'_, '_> { } } - let guard = socket.socket.read().unwrap(); + let guard = futures_lite::ready!(socket.poll_read_socket(&socket.recv_waker, cx)); let Some((inner_socket, _)) = guard.as_ref() else { warn!("socket closed"); return Poll::Ready(Err(std::io::Error::new( @@ -539,7 +580,7 @@ impl Future for RecvFromFut<'_, '_> { } } } - let guard = socket.socket.read().unwrap(); + let guard = futures_lite::ready!(socket.poll_read_socket(&socket.recv_waker, cx)); let Some((inner_socket, _)) = guard.as_ref() else { warn!("socket closed"); return Poll::Ready(Err(std::io::Error::new( @@ -604,7 +645,8 @@ impl Future for SendFut<'_, '_> { } } } - let guard = self.socket.socket.read().unwrap(); + let guard = + futures_lite::ready!(self.socket.poll_read_socket(&self.socket.send_waker, cx)); let Some((socket, _)) = guard.as_ref() else { warn!("socket closed"); return Poll::Ready(Err(std::io::Error::new( @@ -671,7 +713,8 @@ impl Future for SendToFut<'_, '_> { } } - let guard = self.socket.socket.read().unwrap(); + let guard = + futures_lite::ready!(self.socket.poll_read_socket(&self.socket.send_waker, cx)); let Some((socket, _)) = guard.as_ref() else { warn!("socket closed"); return Poll::Ready(Err(std::io::Error::new( @@ -769,6 +812,7 @@ impl Drop for UdpSocket { debug!("dropping UdpSocket"); // Only spawn_blocking if we are inside a tokio runtime, otherwise we just drop. if let Ok(handle) = tokio::runtime::Handle::try_current() { + // No wakeup after dropping write lock here, since we're getting dropped. if let Some((socket, _)) = self.socket.write().unwrap().take() { // this will be empty if `close` was called before let std_sock = socket.into_std(); From 003294e0260ca0c80e066fffe3ac309bae75e39a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Tue, 26 Nov 2024 13:38:16 +0100 Subject: [PATCH 29/38] Use `socket2::SockRef` for `connect` instead of holding a write lock --- net-tools/netwatch/src/udp.rs | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/net-tools/netwatch/src/udp.rs b/net-tools/netwatch/src/udp.rs index 8d2ecf9cd5..8f59ee7891 100644 --- a/net-tools/netwatch/src/udp.rs +++ b/net-tools/netwatch/src/udp.rs @@ -159,9 +159,8 @@ impl UdpSocket { /// TODO pub fn connect(&self, addr: SocketAddr) -> std::io::Result<()> { tracing::info!("connectnig to {}", addr); - let mut guard = self.socket.write().unwrap(); - // dance around to make non async connect work - let Some((socket_tokio, state)) = guard.take() else { + let guard = self.socket.read().unwrap(); + let Some((socket_tokio, _state)) = guard.as_ref() else { warn!("socket closed"); return Err(std::io::Error::new( std::io::ErrorKind::BrokenPipe, @@ -169,13 +168,8 @@ impl UdpSocket { )); }; - let socket_std = socket_tokio.into_std()?; - socket_std.connect(addr)?; - let socket_tokio = tokio::net::UdpSocket::from_std(socket_std)?; - guard.replace((socket_tokio, state)); - - drop(guard); - self.wake_all(); + let sock_ref = socket2::SockRef::from(&socket_tokio); + sock_ref.connect(&socket2::SockAddr::from(addr))?; Ok(()) } From 605fd16bb37f9b2a8696bbcd14c45c17416f9181 Mon Sep 17 00:00:00 2001 From: dignifiedquire Date: Tue, 26 Nov 2024 14:55:14 +0100 Subject: [PATCH 30/38] cleanup socket code --- iroh-net/src/magicsock/udp_conn.rs | 4 +- net-tools/netwatch/src/udp.rs | 509 +++++++++++++---------------- 2 files changed, 237 insertions(+), 276 deletions(-) diff --git a/iroh-net/src/magicsock/udp_conn.rs b/iroh-net/src/magicsock/udp_conn.rs index 2f993d8950..7267cf529b 100644 --- a/iroh-net/src/magicsock/udp_conn.rs +++ b/iroh-net/src/magicsock/udp_conn.rs @@ -72,11 +72,11 @@ impl AsyncUdpSocket for UdpConn { } fn max_transmit_segments(&self) -> usize { - self.io.max_transmit_segments().unwrap_or_default() + self.io.max_gso_segments().unwrap_or_default() } fn max_receive_segments(&self) -> usize { - self.io.max_receive_segments().unwrap_or_default() + self.io.gro_segments().unwrap_or_default() } } diff --git a/net-tools/netwatch/src/udp.rs b/net-tools/netwatch/src/udp.rs index 8f59ee7891..91af26f2e8 100644 --- a/net-tools/netwatch/src/udp.rs +++ b/net-tools/netwatch/src/udp.rs @@ -18,11 +18,9 @@ use super::IpFamily; /// Wrapper around a tokio UDP socket. #[derive(Debug)] pub struct UdpSocket { - socket: RwLock>, + socket: RwLock, recv_waker: AtomicWaker, send_waker: AtomicWaker, - /// The addr we are binding to. - addr: SocketAddr, /// Set to true, when an error occurred, that means we need to rebind the socket. is_broken: AtomicBool, } @@ -81,27 +79,17 @@ impl UdpSocket { /// Rebind the underlying socket. pub fn rebind(&self) -> Result<()> { - debug!("rebinding {}", self.addr); - // Remove old socket - let mut guard = self.socket.write().unwrap(); { - let Some(socket) = guard.take() else { - bail!("cannot rebind closed socket"); - }; - drop(socket); - } + let mut guard = self.socket.write().unwrap(); + guard.rebind()?; - // Prepare new socket - let new_socket = inner_bind(self.addr)?; + // Clear errors + self.is_broken + .store(false, std::sync::atomic::Ordering::SeqCst); - // Insert new socket - guard.replace(new_socket); - - // Clear errors - self.is_broken - .store(false, std::sync::atomic::Ordering::SeqCst); + drop(guard); + } - drop(guard); // wakeup self.wake_all(); @@ -109,21 +97,27 @@ impl UdpSocket { } fn bind_raw(addr: impl Into) -> Result { - let mut addr = addr.into(); - let socket = inner_bind(addr)?; - // update to use selected port - addr.set_port(socket.0.local_addr()?.port()); + let socket = SocketState::bind(addr.into())?; Ok(UdpSocket { - socket: RwLock::new(Some(socket)), + socket: RwLock::new(socket), recv_waker: AtomicWaker::default(), send_waker: AtomicWaker::default(), - addr, is_broken: AtomicBool::new(false), }) } - /// TODO + /// Receives a single datagram message on the socket from the remote address + /// to which it is connected. On success, returns the number of bytes read. + /// + /// The function must be called with valid byte array `buf` of sufficient + /// size to hold the message bytes. If a message is too long to fit in the + /// supplied buffer, excess bytes may be discarded. + /// + /// The [`connect`] method will connect this socket to a remote address. + /// This method will fail if the socket is not connected. + /// + /// [`connect`]: method@Self::connect pub fn recv<'a, 'b>(&'b self, buffer: &'a mut [u8]) -> RecvFut<'a, 'b> { RecvFut { socket: self, @@ -131,7 +125,13 @@ impl UdpSocket { } } - /// TODO + /// Receives a single datagram message on the socket. On success, returns + /// the number of bytes read and the origin. + /// + /// The function must be called with valid byte array `buf` of sufficient + /// size to hold the message bytes. If a message is too long to fit in the + /// supplied buffer, excess bytes may be discarded. + /// pub fn recv_from<'a, 'b>(&'b self, buffer: &'a mut [u8]) -> RecvFromFut<'a, 'b> { RecvFromFut { socket: self, @@ -139,7 +139,18 @@ impl UdpSocket { } } - /// TODO + /// Sends data on the socket to the remote address that the socket is + /// connected to. + /// + /// The [`connect`] method will connect this socket to a remote address. + /// This method will fail if the socket is not connected. + /// + /// [`connect`]: method@Self::connect + /// + /// # Return + /// + /// On success, the number of bytes sent is returned, otherwise, the + /// encountered error is returned. pub fn send<'a, 'b>(&'b self, buffer: &'a [u8]) -> SendFut<'a, 'b> { SendFut { socket: self, @@ -147,7 +158,8 @@ impl UdpSocket { } } - /// TODO + /// Sends data on the socket to the given address. On success, returns the + /// number of bytes written. pub fn send_to<'a, 'b>(&'b self, buffer: &'a [u8], to: SocketAddr) -> SendToFut<'a, 'b> { SendToFut { socket: self, @@ -156,17 +168,13 @@ impl UdpSocket { } } - /// TODO + /// Connects the UDP socket setting the default destination for send() and + /// limiting packets that are read via `recv` from the address specified in + /// `addr`. pub fn connect(&self, addr: SocketAddr) -> std::io::Result<()> { tracing::info!("connectnig to {}", addr); let guard = self.socket.read().unwrap(); - let Some((socket_tokio, _state)) = guard.as_ref() else { - warn!("socket closed"); - return Err(std::io::Error::new( - std::io::ErrorKind::BrokenPipe, - "socket closed", - )); - }; + let (socket_tokio, _state) = guard.try_get_connected()?; let sock_ref = socket2::SockRef::from(&socket_tokio); sock_ref.connect(&socket2::SockAddr::from(addr))?; @@ -177,20 +185,14 @@ impl UdpSocket { /// Returns the local address of this socket. pub fn local_addr(&self) -> std::io::Result { let guard = self.socket.read().unwrap(); - let Some((socket, _)) = guard.as_ref() else { - warn!("socket closed"); - return Err(std::io::Error::new( - std::io::ErrorKind::BrokenPipe, - "socket closed", - )); - }; + let (socket, _state) = guard.try_get_connected()?; socket.local_addr() } /// Closes the socket, and waits for the underlying `libc::close` call to be finished. pub async fn close(&self) { - let socket = self.socket.write().unwrap().take(); + let socket = self.socket.write().unwrap().close(); self.wake_all(); if let Some((sock, _)) = socket { let std_sock = sock.into_std(); @@ -208,7 +210,7 @@ impl UdpSocket { /// Check if this socket is closed. pub fn is_closed(&self) -> bool { - self.socket.read().unwrap().is_none() + self.socket.read().unwrap().is_closed() } /// Handle potential read errors, updating internal state. @@ -244,7 +246,7 @@ impl UdpSocket { &self, waker: &AtomicWaker, cx: &mut std::task::Context<'_>, - ) -> Poll>> { + ) -> Poll> { let guard = match self.socket.try_read() { Ok(guard) => guard, Err(TryLockError::Poisoned(e)) => panic!("socket lock poisoned: {e}"), @@ -274,33 +276,35 @@ impl UdpSocket { self.send_waker.wake(); } + /// Checks if the socket needs a rebind, and if so does it. + /// + /// Returns an error if the rebind is needed, but failed. + fn maybe_rebind(&self) -> std::io::Result<()> { + if self.is_broken() { + match self.rebind() { + Ok(()) => { + // all good + } + Err(err) => { + warn!("failed to rebind socket: {:?}", err); + // TODO: improve error + let err = std::io::Error::new(std::io::ErrorKind::BrokenPipe, err.to_string()); + return Err(err); + } + } + } + Ok(()) + } + /// Poll for writable pub fn poll_writable(&self, cx: &mut std::task::Context<'_>) -> Poll> { loop { - // check if the socket needs a rebind - if self.is_broken() { - match self.rebind() { - Ok(()) => { - // all good - } - Err(err) => { - warn!("failed to rebind socket: {:?}", err); - // TODO: improve error - let err = - std::io::Error::new(std::io::ErrorKind::BrokenPipe, err.to_string()); - return Poll::Ready(Err(err)); - } - } + if let Err(err) = self.maybe_rebind() { + return Poll::Ready(Err(err)); } let guard = futures_lite::ready!(self.poll_read_socket(&self.send_waker, cx)); - let Some((socket, _state)) = guard.as_ref() else { - warn!("socket closed"); - return Poll::Ready(Err(std::io::Error::new( - std::io::ErrorKind::BrokenPipe, - "socket closed", - ))); - }; + let (socket, _state) = guard.try_get_connected()?; match socket.poll_send_ready(cx) { Poll::Pending => { @@ -321,29 +325,18 @@ impl UdpSocket { /// Send a quinn based `Transmit`. pub fn try_send_quinn(&self, transmit: &Transmit<'_>) -> std::io::Result<()> { loop { - // check if the socket needs a rebind - if self.is_broken() { - match self.rebind() { - Ok(()) => { - // all good - } - Err(err) => { - warn!("failed to rebind socket: {:?}", err); - // TODO: improve error - let err = - std::io::Error::new(std::io::ErrorKind::BrokenPipe, err.to_string()); - return Err(err); - } + self.maybe_rebind()?; + + let guard = match self.socket.try_read() { + Ok(guard) => guard, + Err(TryLockError::Poisoned(e)) => { + panic!("lock poisoned: {:?}", e); + } + Err(TryLockError::WouldBlock) => { + return Err(std::io::Error::new(std::io::ErrorKind::WouldBlock, "")); } - } - let guard = self.socket.read().unwrap(); - let Some((socket, state)) = guard.as_ref() else { - warn!("socket closed"); - return Err(std::io::Error::new( - std::io::ErrorKind::BrokenPipe, - "socket closed", - )); }; + let (socket, state) = guard.try_get_connected()?; let res = socket.try_io(Interest::WRITABLE, || state.send(socket.into(), transmit)); @@ -367,29 +360,12 @@ impl UdpSocket { meta: &mut [quinn_udp::RecvMeta], ) -> Poll> { loop { - // check if the socket needs a rebind - if self.is_broken() { - match self.rebind() { - Ok(()) => { - // all good - } - Err(err) => { - warn!("failed to rebind socket: {:?}", err); - // TODO: improve error - let err = - std::io::Error::new(std::io::ErrorKind::BrokenPipe, err.to_string()); - return Poll::Ready(Err(err)); - } - } + if let Err(err) = self.maybe_rebind() { + return Poll::Ready(Err(err)); } + let guard = futures_lite::ready!(self.poll_read_socket(&self.recv_waker, cx)); - let Some((socket, state)) = guard.as_ref() else { - warn!("socket closed"); - return Poll::Ready(Err(std::io::Error::new( - std::io::ErrorKind::BrokenPipe, - "socket closed", - ))); - }; + let (socket, state) = guard.try_get_connected()?; match socket.poll_recv_ready(cx) { Poll::Pending => { @@ -437,42 +413,33 @@ impl UdpSocket { } } - /// TODO + /// Whether transmitted datagrams might get fragmented by the IP layer + /// + /// Returns `false` on targets which employ e.g. the `IPV6_DONTFRAG` socket option. pub fn may_fragment(&self) -> std::io::Result { let guard = self.socket.read().unwrap(); - let Some((_, state)) = guard.as_ref() else { - warn!("socket closed"); - return Err(std::io::Error::new( - std::io::ErrorKind::BrokenPipe, - "socket closed", - )); - }; + let (_, state) = guard.try_get_connected()?; Ok(state.may_fragment()) } - /// TODO - pub fn max_transmit_segments(&self) -> std::io::Result { + /// The maximum amount of segments which can be transmitted if a platform + /// supports Generic Send Offload (GSO). + /// + /// This is 1 if the platform doesn't support GSO. Subject to change if errors are detected + /// while using GSO. + pub fn max_gso_segments(&self) -> std::io::Result { let guard = self.socket.read().unwrap(); - let Some((_, state)) = guard.as_ref() else { - warn!("socket closed"); - return Err(std::io::Error::new( - std::io::ErrorKind::BrokenPipe, - "socket closed", - )); - }; + let (_, state) = guard.try_get_connected()?; Ok(state.max_gso_segments()) } - /// TODO - pub fn max_receive_segments(&self) -> std::io::Result { + /// The number of segments to read when GRO is enabled. Used as a factor to + /// compute the receive buffer size. + /// + /// Returns 1 if the platform doesn't support GRO. + pub fn gro_segments(&self) -> std::io::Result { let guard = self.socket.read().unwrap(); - let Some((_, state)) = guard.as_ref() else { - warn!("socket closed"); - return Err(std::io::Error::new( - std::io::ErrorKind::BrokenPipe, - "socket closed", - )); - }; + let (_, state) = guard.try_get_connected()?; Ok(state.gro_segments()) } } @@ -491,30 +458,12 @@ impl Future for RecvFut<'_, '_> { let Self { socket, buffer } = &mut *self; loop { - // check if the socket needs a rebind - if socket.is_broken() { - match socket.rebind() { - Ok(()) => { - // all good - } - Err(err) => { - warn!("failed to rebind socket: {:?}", err); - // TODO: improve error - let err = - std::io::Error::new(std::io::ErrorKind::BrokenPipe, err.to_string()); - return Poll::Ready(Err(err)); - } - } + if let Err(err) = socket.maybe_rebind() { + return Poll::Ready(Err(err)); } let guard = futures_lite::ready!(socket.poll_read_socket(&socket.recv_waker, cx)); - let Some((inner_socket, _)) = guard.as_ref() else { - warn!("socket closed"); - return Poll::Ready(Err(std::io::Error::new( - std::io::ErrorKind::BrokenPipe, - "socket closed", - ))); - }; + let (inner_socket, _state) = guard.try_get_connected()?; match inner_socket.poll_recv_ready(cx) { Poll::Pending => { @@ -559,29 +508,12 @@ impl Future for RecvFromFut<'_, '_> { let Self { socket, buffer } = &mut *self; loop { - // check if the socket needs a rebind - if socket.is_broken() { - match socket.rebind() { - Ok(()) => { - // all good - } - Err(err) => { - warn!("failed to rebind socket: {:?}", err); - // TODO: improve error - let err = - std::io::Error::new(std::io::ErrorKind::BrokenPipe, err.to_string()); - return Poll::Ready(Err(err)); - } - } + if let Err(err) = socket.maybe_rebind() { + return Poll::Ready(Err(err)); } + let guard = futures_lite::ready!(socket.poll_read_socket(&socket.recv_waker, cx)); - let Some((inner_socket, _)) = guard.as_ref() else { - warn!("socket closed"); - return Poll::Ready(Err(std::io::Error::new( - std::io::ErrorKind::BrokenPipe, - "socket closed", - ))); - }; + let (inner_socket, _state) = guard.try_get_connected()?; match inner_socket.poll_recv_ready(cx) { Poll::Pending => { @@ -624,30 +556,13 @@ impl Future for SendFut<'_, '_> { fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { loop { - // check if the socket needs a rebind - if self.socket.is_broken() { - match self.socket.rebind() { - Ok(()) => { - // all good - } - Err(err) => { - warn!("failed to rebind socket: {:?}", err); - // TODO: improve error - let err = - std::io::Error::new(std::io::ErrorKind::BrokenPipe, err.to_string()); - return Poll::Ready(Err(err)); - } - } + if let Err(err) = self.socket.maybe_rebind() { + return Poll::Ready(Err(err)); } + let guard = futures_lite::ready!(self.socket.poll_read_socket(&self.socket.send_waker, cx)); - let Some((socket, _)) = guard.as_ref() else { - warn!("socket closed"); - return Poll::Ready(Err(std::io::Error::new( - std::io::ErrorKind::BrokenPipe, - "socket closed", - ))); - }; + let (socket, _state) = guard.try_get_connected()?; match socket.poll_send_ready(cx) { Poll::Pending => { @@ -691,31 +606,13 @@ impl Future for SendToFut<'_, '_> { fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { loop { - // check if the socket needs a rebind - if self.socket.is_broken() { - match self.socket.rebind() { - Ok(()) => { - // all good - } - Err(err) => { - warn!("failed to rebind socket: {:?}", err); - // TODO: improve error - let err = - std::io::Error::new(std::io::ErrorKind::BrokenPipe, err.to_string()); - return Poll::Ready(Err(err)); - } - } + if let Err(err) = self.socket.maybe_rebind() { + return Poll::Ready(Err(err)); } let guard = futures_lite::ready!(self.socket.poll_read_socket(&self.socket.send_waker, cx)); - let Some((socket, _)) = guard.as_ref() else { - warn!("socket closed"); - return Poll::Ready(Err(std::io::Error::new( - std::io::ErrorKind::BrokenPipe, - "socket closed", - ))); - }; + let (socket, _state) = guard.try_get_connected()?; match socket.poll_send_ready(cx) { Poll::Pending => { @@ -747,74 +644,138 @@ impl Future for SendToFut<'_, '_> { } } -fn inner_bind(addr: SocketAddr) -> Result<(tokio::net::UdpSocket, quinn_udp::UdpSocketState)> { - let network = IpFamily::from(addr.ip()); - let socket = socket2::Socket::new( - network.into(), - socket2::Type::DGRAM, - Some(socket2::Protocol::UDP), - ) - .context("socket create")?; +#[derive(Debug)] +enum SocketState { + Connected { + socket: tokio::net::UdpSocket, + state: quinn_udp::UdpSocketState, + /// The addr we are binding to. + addr: SocketAddr, + }, + Closed, +} - if let Err(err) = socket.set_recv_buffer_size(SOCKET_BUFFER_SIZE) { - warn!( - "failed to set recv_buffer_size to {}: {:?}", - SOCKET_BUFFER_SIZE, err - ); - } - if let Err(err) = socket.set_send_buffer_size(SOCKET_BUFFER_SIZE) { - warn!( - "failed to set send_buffer_size to {}: {:?}", - SOCKET_BUFFER_SIZE, err - ); - } - if network == IpFamily::V6 { - // Avoid dualstack - socket.set_only_v6(true).context("only IPv6")?; +impl SocketState { + fn try_get_connected( + &self, + ) -> std::io::Result<(&tokio::net::UdpSocket, &quinn_udp::UdpSocketState)> { + match self { + Self::Connected { + socket, + state, + addr: _, + } => Ok((socket, state)), + Self::Closed => { + warn!("socket closed"); + Err(std::io::Error::new( + std::io::ErrorKind::BrokenPipe, + "socket closed", + )) + } + } } - // Binding must happen before calling quinn, otherwise `local_addr` - // is not yet available on all OSes. - socket.bind(&addr.into()).context("binding")?; + fn bind(addr: SocketAddr) -> Result { + let network = IpFamily::from(addr.ip()); + let socket = socket2::Socket::new( + network.into(), + socket2::Type::DGRAM, + Some(socket2::Protocol::UDP), + ) + .context("socket create")?; + + if let Err(err) = socket.set_recv_buffer_size(SOCKET_BUFFER_SIZE) { + warn!( + "failed to set recv_buffer_size to {}: {:?}", + SOCKET_BUFFER_SIZE, err + ); + } + if let Err(err) = socket.set_send_buffer_size(SOCKET_BUFFER_SIZE) { + warn!( + "failed to set send_buffer_size to {}: {:?}", + SOCKET_BUFFER_SIZE, err + ); + } + if network == IpFamily::V6 { + // Avoid dualstack + socket.set_only_v6(true).context("only IPv6")?; + } + + // Binding must happen before calling quinn, otherwise `local_addr` + // is not yet available on all OSes. + socket.bind(&addr.into()).context("binding")?; - // Ensure nonblocking - socket.set_nonblocking(true).context("nonblocking: true")?; + // Ensure nonblocking + socket.set_nonblocking(true).context("nonblocking: true")?; - let socket: std::net::UdpSocket = socket.into(); + let socket: std::net::UdpSocket = socket.into(); - // Convert into tokio UdpSocket - let socket = tokio::net::UdpSocket::from_std(socket).context("conversion to tokio")?; - let socket_ref = quinn_udp::UdpSockRef::from(&socket); - let socket_state = quinn_udp::UdpSocketState::new(socket_ref)?; + // Convert into tokio UdpSocket + let socket = tokio::net::UdpSocket::from_std(socket).context("conversion to tokio")?; + let socket_ref = quinn_udp::UdpSockRef::from(&socket); + let socket_state = quinn_udp::UdpSocketState::new(socket_ref)?; - if addr.port() != 0 { let local_addr = socket.local_addr().context("local addr")?; - ensure!( - local_addr.port() == addr.port(), - "wrong port bound: {:?}: wanted: {} got {}", - network, - addr.port(), - local_addr.port(), - ); + if addr.port() != 0 { + ensure!( + local_addr.port() == addr.port(), + "wrong port bound: {:?}: wanted: {} got {}", + network, + addr.port(), + local_addr.port(), + ); + } + + Ok(Self::Connected { + socket, + state: socket_state, + addr: local_addr, + }) + } + + fn rebind(&mut self) -> Result<()> { + let addr = match self { + Self::Connected { addr, .. } => *addr, + Self::Closed => { + bail!("socket is closed and cannot be rebound"); + } + }; + debug!("rebinding {}", addr); + + *self = SocketState::Closed; + *self = Self::bind(addr)?; + + Ok(()) + } + + fn is_closed(&self) -> bool { + matches!(self, Self::Closed) } - Ok((socket, socket_state)) + fn close(&mut self) -> Option<(tokio::net::UdpSocket, quinn_udp::UdpSocketState)> { + match std::mem::replace(self, SocketState::Closed) { + Self::Connected { socket, state, .. } => Some((socket, state)), + Self::Closed => None, + } + } } impl Drop for UdpSocket { fn drop(&mut self) { - debug!("dropping UdpSocket"); - // Only spawn_blocking if we are inside a tokio runtime, otherwise we just drop. - if let Ok(handle) = tokio::runtime::Handle::try_current() { - // No wakeup after dropping write lock here, since we're getting dropped. - if let Some((socket, _)) = self.socket.write().unwrap().take() { - // this will be empty if `close` was called before - let std_sock = socket.into_std(); - handle.spawn_blocking(move || { - // Calls libc::close, which can block - drop(std_sock); - }); + trace!("dropping UdpSocket"); + match self.socket.write().unwrap().close() { + Some((socket, _)) => { + if let Ok(handle) = tokio::runtime::Handle::try_current() { + // No wakeup after dropping write lock here, since we're getting dropped. + // this will be empty if `close` was called before + let std_sock = socket.into_std(); + handle.spawn_blocking(move || { + // Calls libc::close, which can block + drop(std_sock); + }); + } } + None => {} } } } From bd572c87e2912e67ebf921380069235587391529 Mon Sep 17 00:00:00 2001 From: dignifiedquire Date: Tue, 26 Nov 2024 15:00:32 +0100 Subject: [PATCH 31/38] less warning --- net-tools/netwatch/src/udp.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/net-tools/netwatch/src/udp.rs b/net-tools/netwatch/src/udp.rs index 91af26f2e8..a9ffe5ffd8 100644 --- a/net-tools/netwatch/src/udp.rs +++ b/net-tools/netwatch/src/udp.rs @@ -685,13 +685,13 @@ impl SocketState { .context("socket create")?; if let Err(err) = socket.set_recv_buffer_size(SOCKET_BUFFER_SIZE) { - warn!( + debug!( "failed to set recv_buffer_size to {}: {:?}", SOCKET_BUFFER_SIZE, err ); } if let Err(err) = socket.set_send_buffer_size(SOCKET_BUFFER_SIZE) { - warn!( + debug!( "failed to set send_buffer_size to {}: {:?}", SOCKET_BUFFER_SIZE, err ); From db5da67e3f89261bc291c69f00bf803518d21002 Mon Sep 17 00:00:00 2001 From: dignifiedquire Date: Tue, 26 Nov 2024 17:17:53 +0100 Subject: [PATCH 32/38] store last socket state --- iroh-net/src/magicsock/udp_conn.rs | 6 +-- net-tools/netwatch/src/udp.rs | 84 +++++++++++++++++++++++------- 2 files changed, 68 insertions(+), 22 deletions(-) diff --git a/iroh-net/src/magicsock/udp_conn.rs b/iroh-net/src/magicsock/udp_conn.rs index 7267cf529b..8626c3fcec 100644 --- a/iroh-net/src/magicsock/udp_conn.rs +++ b/iroh-net/src/magicsock/udp_conn.rs @@ -68,15 +68,15 @@ impl AsyncUdpSocket for UdpConn { } fn may_fragment(&self) -> bool { - self.io.may_fragment().unwrap_or_default() + self.io.may_fragment() } fn max_transmit_segments(&self) -> usize { - self.io.max_gso_segments().unwrap_or_default() + self.io.max_gso_segments() } fn max_receive_segments(&self) -> usize { - self.io.gro_segments().unwrap_or_default() + self.io.gro_segments() } } diff --git a/net-tools/netwatch/src/udp.rs b/net-tools/netwatch/src/udp.rs index a9ffe5ffd8..95e51f473a 100644 --- a/net-tools/netwatch/src/udp.rs +++ b/net-tools/netwatch/src/udp.rs @@ -416,10 +416,9 @@ impl UdpSocket { /// Whether transmitted datagrams might get fragmented by the IP layer /// /// Returns `false` on targets which employ e.g. the `IPV6_DONTFRAG` socket option. - pub fn may_fragment(&self) -> std::io::Result { + pub fn may_fragment(&self) -> bool { let guard = self.socket.read().unwrap(); - let (_, state) = guard.try_get_connected()?; - Ok(state.may_fragment()) + guard.may_fragment() } /// The maximum amount of segments which can be transmitted if a platform @@ -427,20 +426,18 @@ impl UdpSocket { /// /// This is 1 if the platform doesn't support GSO. Subject to change if errors are detected /// while using GSO. - pub fn max_gso_segments(&self) -> std::io::Result { + pub fn max_gso_segments(&self) -> usize { let guard = self.socket.read().unwrap(); - let (_, state) = guard.try_get_connected()?; - Ok(state.max_gso_segments()) + guard.max_gso_segments() } /// The number of segments to read when GRO is enabled. Used as a factor to /// compute the receive buffer size. /// /// Returns 1 if the platform doesn't support GRO. - pub fn gro_segments(&self) -> std::io::Result { + pub fn gro_segments(&self) -> usize { let guard = self.socket.read().unwrap(); - let (_, state) = guard.try_get_connected()?; - Ok(state.gro_segments()) + guard.gro_segments() } } @@ -652,7 +649,11 @@ enum SocketState { /// The addr we are binding to. addr: SocketAddr, }, - Closed, + Closed { + last_max_gso_segments: usize, + last_gro_segments: usize, + last_may_fragment: bool, + }, } impl SocketState { @@ -665,7 +666,7 @@ impl SocketState { state, addr: _, } => Ok((socket, state)), - Self::Closed => { + Self::Closed { .. } => { warn!("socket closed"); Err(std::io::Error::new( std::io::ErrorKind::BrokenPipe, @@ -734,28 +735,73 @@ impl SocketState { } fn rebind(&mut self) -> Result<()> { - let addr = match self { - Self::Connected { addr, .. } => *addr, - Self::Closed => { + let (addr, closed_state) = match self { + Self::Connected { state, addr, .. } => { + let s = SocketState::Closed { + last_max_gso_segments: state.max_gso_segments(), + last_gro_segments: state.gro_segments(), + last_may_fragment: state.may_fragment(), + }; + (*addr, s) + } + Self::Closed { .. } => { bail!("socket is closed and cannot be rebound"); } }; debug!("rebinding {}", addr); - *self = SocketState::Closed; + *self = closed_state; *self = Self::bind(addr)?; Ok(()) } fn is_closed(&self) -> bool { - matches!(self, Self::Closed) + matches!(self, Self::Closed { .. }) } fn close(&mut self) -> Option<(tokio::net::UdpSocket, quinn_udp::UdpSocketState)> { - match std::mem::replace(self, SocketState::Closed) { - Self::Connected { socket, state, .. } => Some((socket, state)), - Self::Closed => None, + match self { + Self::Connected { state, .. } => { + let s = SocketState::Closed { + last_max_gso_segments: state.max_gso_segments(), + last_gro_segments: state.gro_segments(), + last_may_fragment: state.may_fragment(), + }; + let Self::Connected { socket, state, .. } = std::mem::replace(self, s) else { + unreachable!("just checked"); + }; + Some((socket, state)) + } + Self::Closed { .. } => None, + } + } + + fn may_fragment(&self) -> bool { + match self { + Self::Connected { state, .. } => state.may_fragment(), + Self::Closed { + last_may_fragment, .. + } => *last_may_fragment, + } + } + + fn max_gso_segments(&self) -> usize { + match self { + Self::Connected { state, .. } => state.max_gso_segments(), + Self::Closed { + last_max_gso_segments, + .. + } => *last_max_gso_segments, + } + } + + fn gro_segments(&self) -> usize { + match self { + Self::Connected { state, .. } => state.gro_segments(), + Self::Closed { + last_gro_segments, .. + } => *last_gro_segments, } } } From b6a42991059605e1eac50f376918bae714d3673f Mon Sep 17 00:00:00 2001 From: dignifiedquire Date: Tue, 26 Nov 2024 17:20:09 +0100 Subject: [PATCH 33/38] update atomic ops --- net-tools/netwatch/src/udp.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/net-tools/netwatch/src/udp.rs b/net-tools/netwatch/src/udp.rs index 95e51f473a..97c3e08420 100644 --- a/net-tools/netwatch/src/udp.rs +++ b/net-tools/netwatch/src/udp.rs @@ -68,13 +68,13 @@ impl UdpSocket { /// Is the socket broken and needs a rebind? pub fn is_broken(&self) -> bool { - self.is_broken.load(std::sync::atomic::Ordering::SeqCst) + self.is_broken.load(std::sync::atomic::Ordering::Acquire) } /// Marks this socket as needing a rebind fn mark_broken(&self) { self.is_broken - .store(true, std::sync::atomic::Ordering::SeqCst); + .store(true, std::sync::atomic::Ordering::Release); } /// Rebind the underlying socket. @@ -85,7 +85,7 @@ impl UdpSocket { // Clear errors self.is_broken - .store(false, std::sync::atomic::Ordering::SeqCst); + .store(false, std::sync::atomic::Ordering::Release); drop(guard); } From f56e95dee2f59fd840b91fd7b1f739e0882668d0 Mon Sep 17 00:00:00 2001 From: dignifiedquire Date: Tue, 26 Nov 2024 17:21:55 +0100 Subject: [PATCH 34/38] happy clippy --- net-tools/netwatch/src/udp.rs | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/net-tools/netwatch/src/udp.rs b/net-tools/netwatch/src/udp.rs index 97c3e08420..ae788bf06b 100644 --- a/net-tools/netwatch/src/udp.rs +++ b/net-tools/netwatch/src/udp.rs @@ -809,19 +809,16 @@ impl SocketState { impl Drop for UdpSocket { fn drop(&mut self) { trace!("dropping UdpSocket"); - match self.socket.write().unwrap().close() { - Some((socket, _)) => { - if let Ok(handle) = tokio::runtime::Handle::try_current() { - // No wakeup after dropping write lock here, since we're getting dropped. - // this will be empty if `close` was called before - let std_sock = socket.into_std(); - handle.spawn_blocking(move || { - // Calls libc::close, which can block - drop(std_sock); - }); - } + if let Some((socket, _)) = self.socket.write().unwrap().close() { + if let Ok(handle) = tokio::runtime::Handle::try_current() { + // No wakeup after dropping write lock here, since we're getting dropped. + // this will be empty if `close` was called before + let std_sock = socket.into_std(); + handle.spawn_blocking(move || { + // Calls libc::close, which can block + drop(std_sock); + }); } - None => {} } } } From 62cbe4d1959b88123b9c793b3795bb1fd1db2898 Mon Sep 17 00:00:00 2001 From: dignifiedquire Date: Tue, 26 Nov 2024 17:40:32 +0100 Subject: [PATCH 35/38] normalize errors to std::io --- net-tools/netwatch/src/udp.rs | 128 ++++++++++++++++------------------ 1 file changed, 61 insertions(+), 67 deletions(-) diff --git a/net-tools/netwatch/src/udp.rs b/net-tools/netwatch/src/udp.rs index ae788bf06b..a91f9593ca 100644 --- a/net-tools/netwatch/src/udp.rs +++ b/net-tools/netwatch/src/udp.rs @@ -1,13 +1,12 @@ use std::{ future::Future, - io::ErrorKind, + io, net::SocketAddr, pin::Pin, sync::{atomic::AtomicBool, RwLock, RwLockReadGuard, TryLockError}, task::{Context, Poll}, }; -use anyhow::{bail, ensure, Context as _, Result}; use atomic_waker::AtomicWaker; use quinn_udp::Transmit; use tokio::io::Interest; @@ -30,39 +29,39 @@ pub struct UdpSocket { const SOCKET_BUFFER_SIZE: usize = 7 << 20; impl UdpSocket { /// Bind only Ipv4 on any interface. - pub fn bind_v4(port: u16) -> Result { + pub fn bind_v4(port: u16) -> io::Result { Self::bind(IpFamily::V4, port) } /// Bind only Ipv6 on any interface. - pub fn bind_v6(port: u16) -> Result { + pub fn bind_v6(port: u16) -> io::Result { Self::bind(IpFamily::V6, port) } /// Bind only Ipv4 on localhost. - pub fn bind_local_v4(port: u16) -> Result { + pub fn bind_local_v4(port: u16) -> io::Result { Self::bind_local(IpFamily::V4, port) } /// Bind only Ipv6 on localhost. - pub fn bind_local_v6(port: u16) -> Result { + pub fn bind_local_v6(port: u16) -> io::Result { Self::bind_local(IpFamily::V6, port) } /// Bind to the given port only on localhost. - pub fn bind_local(network: IpFamily, port: u16) -> Result { + pub fn bind_local(network: IpFamily, port: u16) -> io::Result { let addr = SocketAddr::new(network.local_addr(), port); - Self::bind_raw(addr).with_context(|| format!("{addr:?}")) + Self::bind_raw(addr) } /// Bind to the given port and listen on all interfaces. - pub fn bind(network: IpFamily, port: u16) -> Result { + pub fn bind(network: IpFamily, port: u16) -> io::Result { let addr = SocketAddr::new(network.unspecified_addr(), port); - Self::bind_raw(addr).with_context(|| format!("{addr:?}")) + Self::bind_raw(addr) } /// Bind to any provided [`SocketAddr`]. - pub fn bind_full(addr: impl Into) -> Result { + pub fn bind_full(addr: impl Into) -> io::Result { Self::bind_raw(addr) } @@ -78,7 +77,7 @@ impl UdpSocket { } /// Rebind the underlying socket. - pub fn rebind(&self) -> Result<()> { + pub fn rebind(&self) -> io::Result<()> { { let mut guard = self.socket.write().unwrap(); guard.rebind()?; @@ -96,7 +95,7 @@ impl UdpSocket { Ok(()) } - fn bind_raw(addr: impl Into) -> Result { + fn bind_raw(addr: impl Into) -> io::Result { let socket = SocketState::bind(addr.into())?; Ok(UdpSocket { @@ -171,7 +170,7 @@ impl UdpSocket { /// Connects the UDP socket setting the default destination for send() and /// limiting packets that are read via `recv` from the address specified in /// `addr`. - pub fn connect(&self, addr: SocketAddr) -> std::io::Result<()> { + pub fn connect(&self, addr: SocketAddr) -> io::Result<()> { tracing::info!("connectnig to {}", addr); let guard = self.socket.read().unwrap(); let (socket_tokio, _state) = guard.try_get_connected()?; @@ -183,7 +182,7 @@ impl UdpSocket { } /// Returns the local address of this socket. - pub fn local_addr(&self) -> std::io::Result { + pub fn local_addr(&self) -> io::Result { let guard = self.socket.read().unwrap(); let (socket, _state) = guard.try_get_connected()?; @@ -216,9 +215,9 @@ impl UdpSocket { /// Handle potential read errors, updating internal state. /// /// Returns `Some(error)` if the error is fatal otherwise `None. - fn handle_read_error(&self, error: std::io::Error) -> Option { + fn handle_read_error(&self, error: io::Error) -> Option { match error.kind() { - std::io::ErrorKind::NotConnected => { + io::ErrorKind::NotConnected => { // This indicates the underlying socket is broken, and we should attempt to rebind it self.mark_broken(); None @@ -230,9 +229,9 @@ impl UdpSocket { /// Handle potential write errors, updating internal state. /// /// Returns `Some(error)` if the error is fatal otherwise `None. - fn handle_write_error(&self, error: std::io::Error) -> Option { + fn handle_write_error(&self, error: io::Error) -> Option { match error.kind() { - std::io::ErrorKind::BrokenPipe => { + io::ErrorKind::BrokenPipe => { // This indicates the underlying socket is broken, and we should attempt to rebind it self.mark_broken(); None @@ -279,25 +278,15 @@ impl UdpSocket { /// Checks if the socket needs a rebind, and if so does it. /// /// Returns an error if the rebind is needed, but failed. - fn maybe_rebind(&self) -> std::io::Result<()> { + fn maybe_rebind(&self) -> io::Result<()> { if self.is_broken() { - match self.rebind() { - Ok(()) => { - // all good - } - Err(err) => { - warn!("failed to rebind socket: {:?}", err); - // TODO: improve error - let err = std::io::Error::new(std::io::ErrorKind::BrokenPipe, err.to_string()); - return Err(err); - } - } + self.rebind()?; } Ok(()) } /// Poll for writable - pub fn poll_writable(&self, cx: &mut std::task::Context<'_>) -> Poll> { + pub fn poll_writable(&self, cx: &mut std::task::Context<'_>) -> Poll> { loop { if let Err(err) = self.maybe_rebind() { return Poll::Ready(Err(err)); @@ -323,7 +312,7 @@ impl UdpSocket { } /// Send a quinn based `Transmit`. - pub fn try_send_quinn(&self, transmit: &Transmit<'_>) -> std::io::Result<()> { + pub fn try_send_quinn(&self, transmit: &Transmit<'_>) -> io::Result<()> { loop { self.maybe_rebind()?; @@ -333,7 +322,7 @@ impl UdpSocket { panic!("lock poisoned: {:?}", e); } Err(TryLockError::WouldBlock) => { - return Err(std::io::Error::new(std::io::ErrorKind::WouldBlock, "")); + return Err(io::Error::new(io::ErrorKind::WouldBlock, "")); } }; let (socket, state) = guard.try_get_connected()?; @@ -356,9 +345,9 @@ impl UdpSocket { pub fn poll_recv_quinn( &self, cx: &mut Context, - bufs: &mut [std::io::IoSliceMut<'_>], + bufs: &mut [io::IoSliceMut<'_>], meta: &mut [quinn_udp::RecvMeta], - ) -> Poll> { + ) -> Poll> { loop { if let Err(err) = self.maybe_rebind() { return Poll::Ready(Err(err)); @@ -399,7 +388,7 @@ impl UdpSocket { } Err(err) => { // ignore spurious wakeups - if err.kind() == std::io::ErrorKind::WouldBlock { + if err.kind() == io::ErrorKind::WouldBlock { continue; } match self.handle_read_error(err) { @@ -449,7 +438,7 @@ pub struct RecvFut<'a, 'b> { } impl Future for RecvFut<'_, '_> { - type Output = std::io::Result; + type Output = io::Result; fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { let Self { socket, buffer } = &mut *self; @@ -470,7 +459,7 @@ impl Future for RecvFut<'_, '_> { Poll::Ready(Ok(())) => { let res = inner_socket.try_recv(buffer); if let Err(err) = res { - if err.kind() == ErrorKind::WouldBlock { + if err.kind() == io::ErrorKind::WouldBlock { continue; } if let Some(err) = socket.handle_read_error(err) { @@ -499,7 +488,7 @@ pub struct RecvFromFut<'a, 'b> { } impl Future for RecvFromFut<'_, '_> { - type Output = std::io::Result<(usize, SocketAddr)>; + type Output = io::Result<(usize, SocketAddr)>; fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { let Self { socket, buffer } = &mut *self; @@ -520,7 +509,7 @@ impl Future for RecvFromFut<'_, '_> { Poll::Ready(Ok(())) => { let res = inner_socket.try_recv_from(buffer); if let Err(err) = res { - if err.kind() == ErrorKind::WouldBlock { + if err.kind() == io::ErrorKind::WouldBlock { continue; } if let Some(err) = socket.handle_read_error(err) { @@ -549,7 +538,7 @@ pub struct SendFut<'a, 'b> { } impl Future for SendFut<'_, '_> { - type Output = std::io::Result; + type Output = io::Result; fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { loop { @@ -569,7 +558,7 @@ impl Future for SendFut<'_, '_> { Poll::Ready(Ok(())) => { let res = socket.try_send(self.buffer); if let Err(err) = res { - if err.kind() == ErrorKind::WouldBlock { + if err.kind() == io::ErrorKind::WouldBlock { continue; } if let Some(err) = self.socket.handle_write_error(err) { @@ -599,7 +588,7 @@ pub struct SendToFut<'a, 'b> { } impl Future for SendToFut<'_, '_> { - type Output = std::io::Result; + type Output = io::Result; fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { loop { @@ -619,7 +608,7 @@ impl Future for SendToFut<'_, '_> { Poll::Ready(Ok(())) => { let res = socket.try_send_to(self.buffer, self.to); if let Err(err) = res { - if err.kind() == ErrorKind::WouldBlock { + if err.kind() == io::ErrorKind::WouldBlock { continue; } @@ -659,7 +648,7 @@ enum SocketState { impl SocketState { fn try_get_connected( &self, - ) -> std::io::Result<(&tokio::net::UdpSocket, &quinn_udp::UdpSocketState)> { + ) -> io::Result<(&tokio::net::UdpSocket, &quinn_udp::UdpSocketState)> { match self { Self::Connected { socket, @@ -668,22 +657,18 @@ impl SocketState { } => Ok((socket, state)), Self::Closed { .. } => { warn!("socket closed"); - Err(std::io::Error::new( - std::io::ErrorKind::BrokenPipe, - "socket closed", - )) + Err(io::Error::new(io::ErrorKind::BrokenPipe, "socket closed")) } } } - fn bind(addr: SocketAddr) -> Result { + fn bind(addr: SocketAddr) -> io::Result { let network = IpFamily::from(addr.ip()); let socket = socket2::Socket::new( network.into(), socket2::Type::DGRAM, Some(socket2::Protocol::UDP), - ) - .context("socket create")?; + )?; if let Err(err) = socket.set_recv_buffer_size(SOCKET_BUFFER_SIZE) { debug!( @@ -699,32 +684,36 @@ impl SocketState { } if network == IpFamily::V6 { // Avoid dualstack - socket.set_only_v6(true).context("only IPv6")?; + socket.set_only_v6(true)?; } // Binding must happen before calling quinn, otherwise `local_addr` // is not yet available on all OSes. - socket.bind(&addr.into()).context("binding")?; + socket.bind(&addr.into())?; // Ensure nonblocking - socket.set_nonblocking(true).context("nonblocking: true")?; + socket.set_nonblocking(true)?; let socket: std::net::UdpSocket = socket.into(); // Convert into tokio UdpSocket - let socket = tokio::net::UdpSocket::from_std(socket).context("conversion to tokio")?; + let socket = tokio::net::UdpSocket::from_std(socket)?; let socket_ref = quinn_udp::UdpSockRef::from(&socket); let socket_state = quinn_udp::UdpSocketState::new(socket_ref)?; - let local_addr = socket.local_addr().context("local addr")?; + let local_addr = socket.local_addr()?; if addr.port() != 0 { - ensure!( - local_addr.port() == addr.port(), - "wrong port bound: {:?}: wanted: {} got {}", - network, - addr.port(), - local_addr.port(), - ); + if local_addr.port() != addr.port() { + return Err(io::Error::new( + io::ErrorKind::Other, + format!( + "wrong port bound: {:?}: wanted: {} got {}", + network, + addr.port(), + local_addr.port(), + ), + )); + } } Ok(Self::Connected { @@ -734,7 +723,7 @@ impl SocketState { }) } - fn rebind(&mut self) -> Result<()> { + fn rebind(&mut self) -> io::Result<()> { let (addr, closed_state) = match self { Self::Connected { state, addr, .. } => { let s = SocketState::Closed { @@ -745,7 +734,10 @@ impl SocketState { (*addr, s) } Self::Closed { .. } => { - bail!("socket is closed and cannot be rebound"); + return Err(io::Error::new( + io::ErrorKind::Other, + "socket is closed and cannot be rebound", + )); } }; debug!("rebinding {}", addr); @@ -827,6 +819,8 @@ impl Drop for UdpSocket { mod tests { use super::*; + use anyhow::Context; + #[tokio::test] async fn test_reconnect() -> anyhow::Result<()> { let (s_b, mut r_b) = tokio::sync::mpsc::channel(16); From 3c06274e37f85813d2e02f0e034d71814a79af9c Mon Sep 17 00:00:00 2001 From: dignifiedquire Date: Tue, 26 Nov 2024 17:44:09 +0100 Subject: [PATCH 36/38] fmt.. --- net-tools/netwatch/src/udp.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/net-tools/netwatch/src/udp.rs b/net-tools/netwatch/src/udp.rs index a91f9593ca..bd6616ed78 100644 --- a/net-tools/netwatch/src/udp.rs +++ b/net-tools/netwatch/src/udp.rs @@ -130,7 +130,6 @@ impl UdpSocket { /// The function must be called with valid byte array `buf` of sufficient /// size to hold the message bytes. If a message is too long to fit in the /// supplied buffer, excess bytes may be discarded. - /// pub fn recv_from<'a, 'b>(&'b self, buffer: &'a mut [u8]) -> RecvFromFut<'a, 'b> { RecvFromFut { socket: self, From cc8d28603bf99bc0a4127fa577f42a71d795e9d6 Mon Sep 17 00:00:00 2001 From: dignifiedquire Date: Tue, 26 Nov 2024 17:44:39 +0100 Subject: [PATCH 37/38] ... --- net-tools/netwatch/src/udp.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/net-tools/netwatch/src/udp.rs b/net-tools/netwatch/src/udp.rs index bd6616ed78..e8bf8abb72 100644 --- a/net-tools/netwatch/src/udp.rs +++ b/net-tools/netwatch/src/udp.rs @@ -816,10 +816,10 @@ impl Drop for UdpSocket { #[cfg(test)] mod tests { - use super::*; - use anyhow::Context; + use super::*; + #[tokio::test] async fn test_reconnect() -> anyhow::Result<()> { let (s_b, mut r_b) = tokio::sync::mpsc::channel(16); From 23b4eb77703aaa5d48c3d473f406ff48ca595c51 Mon Sep 17 00:00:00 2001 From: dignifiedquire Date: Tue, 26 Nov 2024 18:13:55 +0100 Subject: [PATCH 38/38] happy clippy --- net-tools/netwatch/src/udp.rs | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/net-tools/netwatch/src/udp.rs b/net-tools/netwatch/src/udp.rs index e8bf8abb72..ab9f130402 100644 --- a/net-tools/netwatch/src/udp.rs +++ b/net-tools/netwatch/src/udp.rs @@ -701,18 +701,16 @@ impl SocketState { let socket_state = quinn_udp::UdpSocketState::new(socket_ref)?; let local_addr = socket.local_addr()?; - if addr.port() != 0 { - if local_addr.port() != addr.port() { - return Err(io::Error::new( - io::ErrorKind::Other, - format!( - "wrong port bound: {:?}: wanted: {} got {}", - network, - addr.port(), - local_addr.port(), - ), - )); - } + if addr.port() != 0 && local_addr.port() != addr.port() { + return Err(io::Error::new( + io::ErrorKind::Other, + format!( + "wrong port bound: {:?}: wanted: {} got {}", + network, + addr.port(), + local_addr.port(), + ), + )); } Ok(Self::Connected {