Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(iroh): Remove custom impl of SharedAbortingJoinHandle #2715

Merged
merged 3 commits into from
Sep 6, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 7 additions & 46 deletions iroh/src/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,15 @@
//! To shut down the node, call [`Node::shutdown`].
use std::collections::{BTreeMap, BTreeSet};
use std::fmt::Debug;
use std::future::Future;
use std::net::SocketAddr;
use std::path::{Path, PathBuf};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;

use anyhow::{anyhow, Result};
use futures_lite::future::Boxed as BoxFuture;
use futures_lite::StreamExt;
use futures_util::{future::Shared, FutureExt};
use futures_util::future::MapErr;
use futures_util::future::Shared;
use iroh_base::key::PublicKey;
use iroh_blobs::store::Store as BaoStore;
use iroh_blobs::util::local_pool::{LocalPool, LocalPoolHandle};
Expand All @@ -60,8 +57,9 @@ use iroh_net::key::SecretKey;
use iroh_net::{AddrInfo, Endpoint, NodeAddr};
use quic_rpc::transport::ServerEndpoint as _;
use quic_rpc::RpcServer;
use tokio::task::JoinSet;
use tokio::task::{JoinError, JoinSet};
use tokio_util::sync::CancellationToken;
use tokio_util::task::AbortOnDropHandle;
use tracing::{debug, error, info, info_span, trace, warn, Instrument};

use crate::node::nodes_storage::store_node_addrs;
Expand Down Expand Up @@ -106,10 +104,12 @@ pub type IrohServerEndpoint = quic_rpc::transport::boxed::ServerEndpoint<
#[derive(Debug, Clone)]
pub struct Node<D> {
inner: Arc<NodeInner<D>>,
task: SharedAbortingJoinHandle<()>,
task: Shared<MapErr<AbortOnDropHandle<()>, JoinErrToStr>>,
matheus23 marked this conversation as resolved.
Show resolved Hide resolved
protocols: Arc<ProtocolMap>,
}

pub(crate) type JoinErrToStr = Box<dyn Fn(JoinError) -> String + Send + Sync + 'static>;

#[derive(derive_more::Debug)]
struct NodeInner<D> {
db: D,
Expand Down Expand Up @@ -624,45 +624,6 @@ fn node_address_for_storage(info: RemoteInfo) -> Option<NodeAddr> {
}
}

/// A join handle that owns the task it is running, and aborts it when dropped.
/// It is cloneable and will abort when the last instance is dropped.
///
/// Please do not copy/use this elsewhere, try and use
/// [`tokio_util::task::AbortOnDropHandle`] instead.
#[derive(Debug, Clone)]
struct SharedAbortingJoinHandle<T: Clone + Send> {
fut: Shared<BoxFuture<std::result::Result<T, String>>>,
abort: Arc<tokio::task::AbortHandle>,
}

impl<T: Clone + Send + 'static> From<tokio::task::JoinHandle<T>> for SharedAbortingJoinHandle<T> {
fn from(handle: tokio::task::JoinHandle<T>) -> Self {
let abort = handle.abort_handle();
let fut: BoxFuture<std::result::Result<T, String>> =
Box::pin(async move { handle.await.map_err(|e| e.to_string()) });
Self {
fut: fut.shared(),
abort: Arc::new(abort),
}
}
}

impl<T: Clone + Send> Future for SharedAbortingJoinHandle<T> {
type Output = std::result::Result<T, String>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
Pin::new(&mut self.fut).poll(cx)
}
}

impl<T: Clone + Send> Drop for SharedAbortingJoinHandle<T> {
fn drop(&mut self) {
if Arc::strong_count(&self.abort) == 1 {
self.abort.abort();
}
}
}

#[cfg(test)]
mod tests {
use anyhow::{bail, Context};
Expand Down
12 changes: 9 additions & 3 deletions iroh/src/node/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use std::{

use anyhow::{Context, Result};
use futures_lite::StreamExt;
use futures_util::{FutureExt as _, TryFutureExt as _};
use iroh_base::key::SecretKey;
use iroh_blobs::{
downloader::Downloader,
Expand All @@ -28,7 +29,8 @@ use iroh_net::{

use quic_rpc::transport::{boxed::BoxableServerEndpoint, quinn::QuinnServerEndpoint};
use serde::{Deserialize, Serialize};
use tokio_util::sync::CancellationToken;
use tokio::task::JoinError;
use tokio_util::{sync::CancellationToken, task::AbortOnDropHandle};
use tracing::{debug, error_span, trace, Instrument};

use crate::{
Expand All @@ -42,7 +44,9 @@ use crate::{
util::{fs::load_secret_key, path::IrohPaths},
};

use super::{docs::DocsEngine, rpc_status::RpcStatus, IrohServerEndpoint, Node, NodeInner};
use super::{
docs::DocsEngine, rpc_status::RpcStatus, IrohServerEndpoint, JoinErrToStr, Node, NodeInner,
};

/// Default bind address for the node.
/// 11204 is "iroh" in leetspeak <https://simple.wikipedia.org/wiki/Leet>
Expand Down Expand Up @@ -853,7 +857,9 @@ impl<D: iroh_blobs::store::Store> ProtocolBuilder<D> {
let node = Node {
inner,
protocols,
task: task.into(),
task: AbortOnDropHandle::new(task)
.map_err(Box::new(|e: JoinError| e.to_string()) as JoinErrToStr)
.shared(),
};

// Wait for a single direct address update, to make sure
Expand Down
Loading