From adcce224d56648c0e7791e8d594faf90a9b956e4 Mon Sep 17 00:00:00 2001 From: Roland Kuhn Date: Wed, 20 Apr 2022 18:23:52 +0200 Subject: [PATCH] src/linux: Use select over select! and return BrokenPipe error (#15) --- src/linux.rs | 42 +++++++++++++++++++++++++++++------------- 1 file changed, 29 insertions(+), 13 deletions(-) diff --git a/src/linux.rs b/src/linux.rs index 89964bb..c0afed1 100644 --- a/src/linux.rs +++ b/src/linux.rs @@ -1,7 +1,7 @@ use crate::{IfEvent, IpNet, Ipv4Net, Ipv6Net}; use fnv::FnvHashSet; use futures::channel::mpsc::UnboundedReceiver; -use futures::future::FutureExt; +use futures::future::Either; use futures::stream::{Stream, TryStreamExt}; use rtnetlink::constants::{RTMGRP_IPV4_IFADDR, RTMGRP_IPV6_IFADDR}; use rtnetlink::packet::address::nlas::Nla; @@ -12,6 +12,7 @@ use std::collections::VecDeque; use std::future::Future; use std::io::{Error, ErrorKind, Result}; use std::net::{Ipv4Addr, Ipv6Addr}; +use std::ops::DerefMut; use std::pin::Pin; use std::task::{Context, Poll}; @@ -41,19 +42,28 @@ impl IfWatcher { let mut queue = VecDeque::default(); loop { - futures::select! { - msg = stream.try_next().fuse() => match msg { - Ok(Some(msg)) => { - for net in iter_nets(msg) { - if addrs.insert(net) { - queue.push_back(IfEvent::Up(net)); + let fut = futures::future::select(conn, stream.try_next()); + match fut.await { + Either::Left(_) => { + return Err(std::io::Error::new( + ErrorKind::BrokenPipe, + "rtnetlink socket closed", + )) + } + Either::Right((x, c)) => { + conn = c; + match x { + Ok(Some(msg)) => { + for net in iter_nets(msg) { + if addrs.insert(net) { + queue.push_back(IfEvent::Up(net)); + } } } - }, - Ok(None) => break, - Err(err) => return Err(Error::new(ErrorKind::Other, err)), - }, - _r = (&mut conn).fuse() => {} + Ok(None) => break, + Err(err) => return Err(Error::new(ErrorKind::Other, err)), + } + } } } Ok(Self { @@ -89,7 +99,13 @@ impl Future for IfWatcher { type Output = Result; fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { - while Pin::new(&mut self.conn).poll(cx).is_ready() {} + log::trace!("polling IfWatcher {:p}", self.deref_mut()); + if Pin::new(&mut self.conn).poll(cx).is_ready() { + return Poll::Ready(Err(std::io::Error::new( + ErrorKind::BrokenPipe, + "rtnetlink socket closed", + ))); + } while let Poll::Ready(Some((message, _))) = Pin::new(&mut self.messages).poll_next(cx) { match message.payload { NetlinkPayload::Error(err) => return Poll::Ready(Err(err.to_io())),