Skip to content

Commit

Permalink
refactor(iroh-net)!: remove async channel (#2620)
Browse files Browse the repository at this point in the history
## Description

Removes async-channel from iroh-net and replaces it with the tokio mpsc
channel.

This is a mergeable version of
#2614

## Breaking Changes

LocalSwarmDiscovery is no longer UnwindSafe

## Notes & open questions

Open question: I am using blocking_send from inside the windows
RouteMonitor. This depends on that the callback thread is not a tokio
thread. I have no idea if this is a reasonable assumption to make.
Otherwise we would have to bring back the runtime check condition or
just capture a runtime handle and spawn a task.

Note: what even is this? The only message there is is
NetworkMessage::Change, and there is no timestamp or anything. So could
I just use try_send? If the queue is full, there is already a Change in
it, so it is as good as the new one...

## Change checklist

- [x] Self-review.
- [ ] Documentation updates following the [style
guide](https://rust-lang.github.io/rfcs/1574-more-api-documentation-conventions.html#appendix-a-full-conventions-text),
if relevant.
- [ ] Tests if relevant.
- [ ] All breaking changes documented.

---------

Co-authored-by: dignifiedquire <[email protected]>
  • Loading branch information
rklaehn and dignifiedquire authored Aug 14, 2024
1 parent a5072c3 commit 74a527b
Show file tree
Hide file tree
Showing 10 changed files with 55 additions and 41 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 6 additions & 3 deletions iroh-net/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ workspace = true

[dependencies]
anyhow = { version = "1" }
async-channel = "2.3.1"
base64 = "0.22.1"
backoff = "0.4.0"
bytes = "1"
Expand Down Expand Up @@ -58,7 +57,6 @@ ring = "0.17"
rustls = { version = "0.21.11", default-features = false, features = ["dangerous_configuration"] }
serde = { version = "1", features = ["derive", "rc"] }
smallvec = "1.11.1"
swarm-discovery = { version = "0.2.1", optional = true }
socket2 = "0.5.3"
stun-rs = "0.1.5"
surge-ping = "0.8.0"
Expand Down Expand Up @@ -92,6 +90,11 @@ tokio-rustls-acme = { version = "0.3", optional = true }
iroh-metrics = { version = "0.22.0", path = "../iroh-metrics", default-features = false }
strum = { version = "0.26.2", features = ["derive"] }

# local_swarm_discovery
swarm-discovery = { version = "0.2.1", optional = true }
tokio-stream = { version = "0.1.15", optional = true }


[target.'cfg(any(target_os = "linux", target_os = "android"))'.dependencies]
netlink-packet-core = "0.7.0"
netlink-packet-route = "0.17.0"
Expand Down Expand Up @@ -140,7 +143,7 @@ iroh-relay = [
]
metrics = ["iroh-metrics/metrics"]
test-utils = ["iroh-relay"]
local_swarm_discovery = ["dep:swarm-discovery"]
local_swarm_discovery = ["dep:swarm-discovery", "dep:tokio-stream"]

[[bin]]
name = "iroh-relay"
Expand Down
41 changes: 24 additions & 17 deletions iroh-net/src/discovery/local_swarm_discovery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,12 @@ use std::{

use anyhow::Result;
use derive_more::FromStr;
use futures_lite::{stream::Boxed as BoxStream, StreamExt};
use futures_lite::stream::Boxed as BoxStream;
use tracing::{debug, error, trace, warn};

use async_channel::Sender;
use iroh_base::key::PublicKey;
use swarm_discovery::{Discoverer, DropGuard, IpClass, Peer};
use tokio::task::JoinSet;
use tokio::{sync::mpsc, task::JoinSet};

use crate::{
discovery::{Discovery, DiscoveryItem},
Expand All @@ -39,13 +38,13 @@ const DISCOVERY_DURATION: Duration = Duration::from_secs(10);
pub struct LocalSwarmDiscovery {
#[allow(dead_code)]
handle: AbortingJoinHandle<()>,
sender: Sender<Message>,
sender: mpsc::Sender<Message>,
}

#[derive(Debug)]
enum Message {
Discovery(String, Peer),
SendAddrs(NodeId, Sender<Result<DiscoveryItem>>),
SendAddrs(NodeId, mpsc::Sender<Result<DiscoveryItem>>),
ChangeLocalAddrs(AddrInfo),
Timeout(NodeId, usize),
}
Expand All @@ -62,7 +61,7 @@ impl LocalSwarmDiscovery {
/// This relies on [`tokio::runtime::Handle::current`] and will panic if called outside of the context of a tokio runtime.
pub fn new(node_id: NodeId) -> Result<Self> {
debug!("Creating new LocalSwarmDiscovery service");
let (send, recv) = async_channel::bounded(64);
let (send, mut recv) = mpsc::channel(64);
let task_sender = send.clone();
let rt = tokio::runtime::Handle::current();
let discovery = LocalSwarmDiscovery::spawn_discoverer(
Expand All @@ -75,19 +74,21 @@ impl LocalSwarmDiscovery {
let handle = tokio::spawn(async move {
let mut node_addrs: HashMap<PublicKey, Peer> = HashMap::default();
let mut last_id = 0;
let mut senders: HashMap<PublicKey, HashMap<usize, Sender<Result<DiscoveryItem>>>> =
HashMap::default();
let mut senders: HashMap<
PublicKey,
HashMap<usize, mpsc::Sender<Result<DiscoveryItem>>>,
> = HashMap::default();
let mut timeouts = JoinSet::new();
loop {
trace!(?node_addrs, "LocalSwarmDiscovery Service loop tick");
let msg = match recv.recv().await {
Err(err) => {
error!("LocalSwarmDiscovery service error: {err:?}");
None => {
error!("LocalSwarmDiscovery channel closed");
error!("closing LocalSwarmDiscovery");
timeouts.abort_all();
return;
}
Ok(msg) => msg,
Some(msg) => msg,
};
match msg {
Message::Discovery(discovered_node_id, peer_info) => {
Expand Down Expand Up @@ -189,20 +190,24 @@ impl LocalSwarmDiscovery {

fn spawn_discoverer(
node_id: PublicKey,
sender: Sender<Message>,
sender: mpsc::Sender<Message>,
socketaddrs: BTreeSet<SocketAddr>,
rt: &tokio::runtime::Handle,
) -> Result<DropGuard> {
let spawn_rt = rt.clone();
let callback = move |node_id: &str, peer: &Peer| {
trace!(
node_id,
?peer,
"Received peer information from LocalSwarmDiscovery"
);

sender
.send_blocking(Message::Discovery(node_id.to_string(), peer.clone()))
.ok();
let sender = sender.clone();
let node_id = node_id.to_string();
let peer = peer.clone();
spawn_rt.spawn(async move {
sender.send(Message::Discovery(node_id, peer)).await.ok();
});
};
let addrs = LocalSwarmDiscovery::socketaddrs_to_addrs(socketaddrs);
let mut discoverer =
Expand Down Expand Up @@ -247,15 +252,16 @@ impl From<&Peer> for DiscoveryItem {

impl Discovery for LocalSwarmDiscovery {
fn resolve(&self, _ep: Endpoint, node_id: NodeId) -> Option<BoxStream<Result<DiscoveryItem>>> {
let (send, recv) = async_channel::bounded(20);
let (send, recv) = mpsc::channel(20);
let discovery_sender = self.sender.clone();
tokio::spawn(async move {
discovery_sender
.send(Message::SendAddrs(node_id, send))
.await
.ok();
});
Some(recv.boxed())
let stream = tokio_stream::wrappers::ReceiverStream::new(recv);
Some(Box::pin(stream))
}

fn publish(&self, info: &AddrInfo) {
Expand All @@ -277,6 +283,7 @@ mod tests {
/// tests)
mod run_in_isolation {
use super::super::*;
use futures_lite::StreamExt;
use testresult::TestResult;

#[tokio::test]
Expand Down
15 changes: 8 additions & 7 deletions iroh-net/src/magicsock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ pub(crate) struct MagicSock {
proxy_url: Option<Url>,

/// Used for receiving relay messages.
relay_recv_receiver: async_channel::Receiver<RelayRecvResult>,
relay_recv_receiver: parking_lot::Mutex<mpsc::Receiver<RelayRecvResult>>,
/// Stores wakers, to be called when relay_recv_ch receives new data.
network_recv_wakers: parking_lot::Mutex<Option<Waker>>,
network_send_wakers: parking_lot::Mutex<Option<Waker>>,
Expand Down Expand Up @@ -788,12 +788,13 @@ impl MagicSock {
if self.is_closed() {
break;
}
match self.relay_recv_receiver.try_recv() {
Err(async_channel::TryRecvError::Empty) => {
let mut relay_recv_receiver = self.relay_recv_receiver.lock();
match relay_recv_receiver.try_recv() {
Err(mpsc::error::TryRecvError::Empty) => {
self.network_recv_wakers.lock().replace(cx.waker().clone());
break;
}
Err(async_channel::TryRecvError::Closed) => {
Err(mpsc::error::TryRecvError::Disconnected) => {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::NotConnected,
"connection closed",
Expand Down Expand Up @@ -1378,7 +1379,7 @@ impl Handle {
insecure_skip_relay_cert_verify,
} = opts;

let (relay_recv_sender, relay_recv_receiver) = async_channel::bounded(128);
let (relay_recv_sender, relay_recv_receiver) = mpsc::channel(128);

let (pconn4, pconn6) = bind(port)?;
let port = pconn4.port();
Expand Down Expand Up @@ -1412,7 +1413,7 @@ impl Handle {
local_addrs: std::sync::RwLock::new((ipv4_addr, ipv6_addr)),
closing: AtomicBool::new(false),
closed: AtomicBool::new(false),
relay_recv_receiver,
relay_recv_receiver: parking_lot::Mutex::new(relay_recv_receiver),
network_recv_wakers: parking_lot::Mutex::new(None),
network_send_wakers: parking_lot::Mutex::new(None),
actor_sender: actor_sender.clone(),
Expand Down Expand Up @@ -1704,7 +1705,7 @@ struct Actor {
relay_actor_sender: mpsc::Sender<RelayActorMessage>,
relay_actor_cancel_token: CancellationToken,
/// Channel to send received relay messages on, for processing.
relay_recv_sender: async_channel::Sender<RelayRecvResult>,
relay_recv_sender: mpsc::Sender<RelayRecvResult>,
/// When set, is an AfterFunc timer that will call MagicSock::do_periodic_stun.
periodic_re_stun_timer: time::Interval,
/// The `NetInfo` provided in the last call to `net_info_func`. It's used to deduplicate calls to netInfoFunc.
Expand Down
5 changes: 3 additions & 2 deletions iroh-net/src/magicsock/udp_conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ mod tests {

use super::*;
use anyhow::Result;
use tokio::sync::mpsc;

const ALPN: &[u8] = b"n0/test/1";

Expand Down Expand Up @@ -192,7 +193,7 @@ mod tests {
let (m2, _m2_key) = wrap_socket(m2)?;

let m1_addr = SocketAddr::new(network.local_addr(), m1.local_addr()?.port());
let (m1_send, m1_recv) = async_channel::bounded(8);
let (m1_send, mut m1_recv) = mpsc::channel(8);

let m1_task = tokio::task::spawn(async move {
if let Some(conn) = m1.accept().await {
Expand Down Expand Up @@ -220,7 +221,7 @@ mod tests {
drop(send_bi);

// make sure the right values arrived
let val = m1_recv.recv().await?;
let val = m1_recv.recv().await.unwrap();
assert_eq!(val, b"hello");

m1_task.await??;
Expand Down
6 changes: 3 additions & 3 deletions iroh-net/src/net/netmon/actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ pub(super) struct Actor {
/// OS specific monitor.
#[allow(dead_code)]
route_monitor: RouteMonitor,
mon_receiver: async_channel::Receiver<NetworkMessage>,
mon_receiver: mpsc::Receiver<NetworkMessage>,
actor_receiver: mpsc::Receiver<ActorMessage>,
actor_sender: mpsc::Sender<ActorMessage>,
/// Callback registry.
Expand All @@ -84,7 +84,7 @@ impl Actor {
let wall_time = Instant::now();

// Use flume channels, as tokio::mpsc is not safe to use across ffi boundaries.
let (mon_sender, mon_receiver) = async_channel::bounded(MON_CHAN_CAPACITY);
let (mon_sender, mon_receiver) = mpsc::channel(MON_CHAN_CAPACITY);
let route_monitor = RouteMonitor::new(mon_sender)?;
let (actor_sender, actor_receiver) = mpsc::channel(ACTOR_CHAN_CAPACITY);

Expand Down Expand Up @@ -129,7 +129,7 @@ impl Actor {
debounce_interval.reset_immediately();
}
}
Ok(_event) = self.mon_receiver.recv() => {
Some(_event) = self.mon_receiver.recv() => {
trace!("network activity detected");
last_event.replace(false);
debounce_interval.reset_immediately();
Expand Down
3 changes: 2 additions & 1 deletion iroh-net/src/net/netmon/android.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use anyhow::Result;
use tokio::sync::mpsc;

use super::actor::NetworkMessage;

#[derive(Debug)]
pub(super) struct RouteMonitor {}

impl RouteMonitor {
pub(super) fn new(_sender: async_channel::Sender<NetworkMessage>) -> Result<Self> {
pub(super) fn new(_sender: mpsc::Sender<NetworkMessage>) -> Result<Self> {
// Very sad monitor. Android doesn't allow us to do this

Ok(RouteMonitor {})
Expand Down
4 changes: 2 additions & 2 deletions iroh-net/src/net/netmon/bsd.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use anyhow::Result;
use tokio::{io::AsyncReadExt, task::JoinHandle};
use tokio::{io::AsyncReadExt, sync::mpsc, task::JoinHandle};
use tracing::{trace, warn};

#[cfg(any(target_os = "freebsd", target_os = "netbsd", target_os = "openbsd"))]
Expand All @@ -23,7 +23,7 @@ impl Drop for RouteMonitor {
}

impl RouteMonitor {
pub(super) fn new(sender: async_channel::Sender<NetworkMessage>) -> Result<Self> {
pub(super) fn new(sender: mpsc::Sender<NetworkMessage>) -> Result<Self> {
let socket = socket2::Socket::new(libc::AF_ROUTE.into(), socket2::Type::RAW, None)?;
socket.set_nonblocking(true)?;
let socket_std: std::os::unix::net::UnixStream = socket.into();
Expand Down
4 changes: 2 additions & 2 deletions iroh-net/src/net/netmon/linux.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use netlink_packet_core::NetlinkPayload;
use netlink_packet_route::{address, constants::*, route, RtnlMessage};
use netlink_sys::{AsyncSocket, SocketAddr};
use rtnetlink::new_connection;
use tokio::task::JoinHandle;
use tokio::{sync::mpsc, task::JoinHandle};
use tracing::{info, trace, warn};

use crate::net::ip::is_link_local;
Expand Down Expand Up @@ -49,7 +49,7 @@ macro_rules! get_nla {
}

impl RouteMonitor {
pub(super) fn new(sender: async_channel::Sender<NetworkMessage>) -> Result<Self> {
pub(super) fn new(sender: mpsc::Sender<NetworkMessage>) -> Result<Self> {
let (mut conn, mut _handle, mut messages) = new_connection()?;

// Specify flags to listen on.
Expand Down
7 changes: 4 additions & 3 deletions iroh-net/src/net/netmon/windows.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::{collections::HashMap, sync::Arc};

use anyhow::Result;
use libc::c_void;
use tokio::sync::mpsc;
use tracing::{trace, warn};
use windows::Win32::{
Foundation::{BOOLEAN, HANDLE as Handle},
Expand All @@ -19,21 +20,21 @@ pub(super) struct RouteMonitor {
}

impl RouteMonitor {
pub(super) fn new(sender: async_channel::Sender<NetworkMessage>) -> Result<Self> {
pub(super) fn new(sender: mpsc::Sender<NetworkMessage>) -> Result<Self> {
// Register two callbacks with the windows api
let mut cb_handler = CallbackHandler::default();

// 1. Unicast Address Changes
let s = sender.clone();
cb_handler.register_unicast_address_change_callback(Box::new(move || {
if let Err(err) = s.send_blocking(NetworkMessage::Change) {
if let Err(err) = s.blocking_send(NetworkMessage::Change) {
warn!("unable to send: unicast change notification: {:?}", err);
}
}))?;

// 2. Route Changes
cb_handler.register_route_change_callback(Box::new(move || {
if let Err(err) = sender.send_blocking(NetworkMessage::Change) {
if let Err(err) = sender.blocking_send(NetworkMessage::Change) {
warn!("unable to send: route change notification: {:?}", err);
}
}))?;
Expand Down

0 comments on commit 74a527b

Please sign in to comment.