Skip to content

Commit

Permalink
Make IfWatch pollable using manual futures.
Browse files Browse the repository at this point in the history
  • Loading branch information
dvc94ch committed Feb 23, 2021
1 parent caba1d0 commit 1a14359
Show file tree
Hide file tree
Showing 7 changed files with 78 additions and 53 deletions.
3 changes: 2 additions & 1 deletion examples/if_watch.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use if_watch::IfWatcher;
use std::pin::Pin;

fn main() {
env_logger::init();
futures_lite::future::block_on(async {
let mut set = IfWatcher::new().await.unwrap();
loop {
println!("Got event {:?}", set.next().await);
println!("Got event {:?}", Pin::new(&mut set).await);
}
});
}
21 changes: 15 additions & 6 deletions src/fallback.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
use crate::IfEvent;
use async_io::Timer;
use futures_lite::StreamExt;
use if_addrs::IfAddr;
use ipnet::{IpNet, Ipv4Net, Ipv6Net};
use std::io::Result;
use std::{
collections::{HashSet, VecDeque},
future::Future,
pin::Pin,
task::{Context, Poll},
time::{Duration, Instant},
};

Expand Down Expand Up @@ -51,15 +53,22 @@ impl IfWatcher {
pub fn iter(&self) -> impl Iterator<Item = &IpNet> {
self.addrs.iter()
}
}

impl Future for IfWatcher {
type Output = Result<IfEvent>;

/// Returns a future for the next event.
pub async fn next(&mut self) -> Result<IfEvent> {
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
loop {
if let Some(event) = self.queue.pop_front() {
return Ok(event);
return Poll::Ready(Ok(event));
}
if Pin::new(&mut self.ticker).poll_next(cx).is_pending() {
return Poll::Pending;
}
if let Err(err) = self.resync() {
return Poll::Ready(Err(err));
}
self.ticker.next().await;
self.resync()?;
}
}
}
Expand Down
28 changes: 12 additions & 16 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
#![deny(warnings)]

pub use ipnet::IpNet;
use std::future::Future;
use std::io::Result;
use std::pin::Pin;
use std::task::{Context, Poll};

#[cfg(not(any(unix, windows)))]
compile_error!("Only Unix and Windows are supported");
Expand Down Expand Up @@ -45,33 +48,26 @@ impl IfWatcher {
pub fn iter(&self) -> impl Iterator<Item = &IpNet> {
self.0.iter()
}
}

impl Future for IfWatcher {
type Output = Result<IfEvent>;

/// Returns a future for the next event.
pub async fn next(&mut self) -> Result<IfEvent> {
self.0.next().await
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
Pin::new(&mut self.0).poll(cx)
}
}

#[cfg(test)]
mod tests {
use super::*;
use futures_lite::future::poll_fn;
use std::{future::Future, pin::Pin, task::Poll};

#[test]
fn test_ip_watch() {
futures_lite::future::block_on(async {
let mut set = IfWatcher::new().await.unwrap();
poll_fn(|cx| loop {
let next = set.next();
futures_lite::pin!(next);
if let Poll::Ready(Ok(ev)) = Pin::new(&mut next).poll(cx) {
println!("Got event {:?}", ev);
continue;
}
return Poll::Ready(());
})
.await;
let event = Pin::new(&mut set).await.unwrap();
println!("Got event {:?}", event);
});
}

Expand All @@ -81,7 +77,7 @@ mod tests {
fn is_send<T: Send>(_: T) {}
is_send(IfWatcher::new());
is_send(IfWatcher::new().await.unwrap());
is_send(IfWatcher::new().await.unwrap().next());
is_send(Pin::new(&mut IfWatcher::new().await.unwrap()));
});
}
}
32 changes: 23 additions & 9 deletions src/unix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,13 @@ macro_rules! errno {

use crate::IfEvent;
use async_io::Async;
use futures_lite::Future;
use ipnet::IpNet;
use std::collections::{HashSet, VecDeque};
use std::io::Result;
use std::os::unix::prelude::*;
use std::pin::Pin;
use std::task::{Context, Poll};

mod aligned_buffer;

Expand All @@ -35,7 +38,7 @@ mod linux;
type Watcher = linux::NetlinkSocket;

#[derive(Debug)]
struct Fd(RawFd);
pub struct Fd(RawFd);

impl Fd {
pub fn new(fd: RawFd) -> Result<Async<Self>> {
Expand Down Expand Up @@ -83,25 +86,36 @@ impl IfWatcher {
pub fn iter(&self) -> impl Iterator<Item = &IpNet> {
self.addrs.iter()
}
}

impl Future for IfWatcher {
type Output = Result<IfEvent>;

/// Returns a future for the next event.
pub async fn next(&mut self) -> Result<IfEvent> {
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let me = Pin::into_inner(self);
loop {
while let Some(event) = self.queue.pop_front() {
while let Some(event) = me.queue.pop_front() {
match event {
IfEvent::Up(inet) => {
if self.addrs.insert(inet) {
return Ok(event);
if me.addrs.insert(inet) {
return Poll::Ready(Ok(event));
}
}
IfEvent::Down(inet) => {
if self.addrs.remove(&inet) {
return Ok(event);
if me.addrs.remove(&inet) {
return Poll::Ready(Ok(event));
}
}
}
}
self.watcher.recv_event(&mut self.queue).await?;
if me.watcher.fd().poll_readable(cx).is_pending() {
return Poll::Pending;
}
let fut = me.watcher.recv_event(&mut me.queue);
futures_lite::pin!(fut);
if let Poll::Ready(Err(err)) = fut.poll(cx) {
return Poll::Ready(Err(err));
}
}
}
}
8 changes: 6 additions & 2 deletions src/unix/linux.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,16 @@ impl NetlinkSocket {
}
}

pub fn fd(&self) -> &Async<Fd> {
&self.fd
}

pub async fn send_getaddr(&mut self) -> Result<()> {
#[repr(C)]
struct Nlmsg {
hdr: libc::nlmsghdr,
msg: rtnetlink::ifaddrmsg,
};
}
if self.seqnum == u32::max_value() {
self.seqnum = 1;
} else {
Expand Down Expand Up @@ -184,7 +188,7 @@ fn read_ifaddrmsg<'a>(queue: &mut VecDeque<IfEvent>, ty: i32, msg: &mut U32Align
}
let ip = iter
.filter_map(|e| match e {
rtnetlink::RtaMessage::IPAddr(e) => Some(e),
rtnetlink::RtaMessage::IpAddr(e) => Some(e),
rtnetlink::RtaMessage::Other => None,
})
.next()
Expand Down
4 changes: 2 additions & 2 deletions src/unix/linux/rtnetlink.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ pub(crate) fn read_msg<'a, M: FromBuffer>(
}

pub(crate) enum RtaMessage {
IPAddr(std::net::IpAddr),
IpAddr(std::net::IpAddr),
Other,
}

Expand All @@ -79,7 +79,7 @@ impl Iterator for RtaIterator<'_> {
let (attr, buf): (rtattr, _) = self.0.read()?;
Some(match attr.rta_type {
libc::RTA_DST => match buf.try_into().ok() {
Some(e) => RtaMessage::IPAddr(e),
Some(e) => RtaMessage::IpAddr(e),
None => RtaMessage::Other,
},
other => {
Expand Down
35 changes: 18 additions & 17 deletions src/windows.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
use crate::IfEvent;
use futures::task::AtomicWaker;
use futures_lite::future::poll_fn;
use if_addrs::IfAddr;
use ipnet::{IpNet, Ipv4Net, Ipv6Net};
use std::{
collections::{HashSet, VecDeque},
future::Future,
pin::Pin,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
task::Poll,
task::{Context, Poll},
};
use winapi::shared::{
netioapi::{
Expand Down Expand Up @@ -71,23 +72,23 @@ impl IfWatcher {
pub fn iter(&self) -> impl Iterator<Item = &IpNet> {
self.addrs.iter()
}
}

/// Returns a future for the next event.
pub async fn next(&mut self) -> std::io::Result<IfEvent> {
poll_fn(|cx| {
self.waker.register(cx.waker());
if self.resync.swap(false, Ordering::Relaxed) {
if let Err(error) = self.resync() {
return Poll::Ready(Err(error));
}
}
if let Some(event) = self.queue.pop_front() {
Poll::Ready(Ok(event))
} else {
Poll::Pending
impl Future for IfWatcher {
type Output = Result<IfEvent>;

fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
self.waker.register(cx.waker());
if self.resync.swap(false, Ordering::Relaxed) {
if let Err(error) = self.resync() {
return Poll::Ready(Err(error));
}
})
.await
}
if let Some(event) = self.queue.pop_front() {
Poll::Ready(Ok(event))
} else {
Poll::Pending
}
}
}

Expand Down

0 comments on commit 1a14359

Please sign in to comment.