Skip to content

Commit

Permalink
fix(iroh-blobs): properly handle Drop in local pool during shutdown (#…
Browse files Browse the repository at this point in the history
…2517)

## Description

The tokio_util LocalPoolHandle does not properly handle Drop during
shutdown. Its threads are just spawned as detached. So any drop impl
that runs in a local pool thread will be stopped as soon as the process
terminates. This can have some bad consequences if that drop operation
performs IO, like closing files and committing database transactions.

Here is where the threads get spawned. The `std::thread::JoinHandle`s
are just dropped.

https://docs.rs/tokio-util/latest/src/tokio_util/task/spawn_pinned.rs.html#381

Here is some discussion of the observed effects:
https://discord.com/channels/949724860232392765/1260571544414064670

LocalPoolHandle also, of course, is using an unbounded channel:

https://docs.rs/tokio-util/latest/src/tokio_util/task/spawn_pinned.rs.html#372

## Breaking Changes

Public interfaces using tokio_util::task::LocalPoolHandle will now use
our own LocalPool/LocalPoolHandle.

## Notes & open questions

Should we use an unbounded channel like tokio::spawn or
LocalPoolHandle::spawn_pinned? Seems like a big footgun. But if not, we
need to somehow handle when the queue is full.

<!-- Any notes, remarks or open questions you have to make about the PR.
-->

## Change checklist

- [x] Self-review.
- [x] Documentation updates following the [style
guide](https://rust-lang.github.io/rfcs/1574-more-api-documentation-conventions.html#appendix-a-full-conventions-text),
if relevant.
- [x] Tests if relevant.
- [x] All breaking changes documented.
  • Loading branch information
rklaehn authored Jul 22, 2024
1 parent 4c11c58 commit b4506b2
Show file tree
Hide file tree
Showing 16 changed files with 752 additions and 64 deletions.
3 changes: 1 addition & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion iroh-blobs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ iroh-metrics = { version = "0.20.0", path = "../iroh-metrics", optional = true }
iroh-net = { version = "0.20.0", path = "../iroh-net" }
num_cpus = "1.15.0"
parking_lot = { version = "0.12.1", optional = true }
pin-project = "1.1.5"
postcard = { version = "1", default-features = false, features = ["alloc", "use-std", "experimental-derive"] }
rand = "0.8"
range-collections = "0.4.0"
Expand All @@ -45,7 +46,7 @@ smallvec = { version = "1.10.0", features = ["serde", "const_new"] }
tempfile = { version = "3.10.0", optional = true }
thiserror = "1"
tokio = { version = "1", features = ["fs"] }
tokio-util = { version = "0.7", features = ["io-util", "io", "rt"] }
tokio-util = { version = "0.7", features = ["io-util", "io"] }
tracing = "0.1"
tracing-futures = "0.2.5"

Expand Down
5 changes: 2 additions & 3 deletions iroh-blobs/examples/provide-bytes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,10 @@
//! cargo run --example provide-bytes collection
//! To provide a collection (multiple blobs)
use anyhow::Result;
use tokio_util::task::LocalPoolHandle;
use tracing::warn;
use tracing_subscriber::{prelude::*, EnvFilter};

use iroh_blobs::{format::collection::Collection, Hash};
use iroh_blobs::{format::collection::Collection, util::local_pool::LocalPool, Hash};

mod connect;
use connect::{make_and_write_certs, make_server_endpoint, CERT_PATH};
Expand Down Expand Up @@ -82,7 +81,7 @@ async fn main() -> Result<()> {
println!("\nfetch the content using a stream by running the following example:\n\ncargo run --example fetch-stream {hash} \"{addr}\" {format}\n");

// create a new local pool handle with 1 worker thread
let lp = LocalPoolHandle::new(1);
let lp = LocalPool::single();

let accept_task = tokio::spawn(async move {
while let Some(incoming) = endpoint.accept().await {
Expand Down
6 changes: 3 additions & 3 deletions iroh-blobs/src/downloader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,13 @@ use tokio::{
sync::{mpsc, oneshot},
task::JoinSet,
};
use tokio_util::{sync::CancellationToken, task::LocalPoolHandle, time::delay_queue};
use tokio_util::{sync::CancellationToken, time::delay_queue};
use tracing::{debug, error_span, trace, warn, Instrument};

use crate::{
get::{db::DownloadProgress, Stats},
store::Store,
util::progress::ProgressSender,
util::{local_pool::LocalPoolHandle, progress::ProgressSender},
};

mod get;
Expand Down Expand Up @@ -338,7 +338,7 @@ impl Downloader {

service.run().instrument(error_span!("downloader", %me))
};
rt.spawn_pinned(create_future);
rt.spawn_detached(create_future);
Self {
next_id: Arc::new(AtomicU64::new(0)),
msg_tx,
Expand Down
55 changes: 36 additions & 19 deletions iroh-blobs/src/downloader/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ use iroh_net::key::SecretKey;

use crate::{
get::{db::BlobId, progress::TransferState},
util::progress::{FlumeProgressSender, IdGenerator},
util::{
local_pool::LocalPool,
progress::{FlumeProgressSender, IdGenerator},
},
};

use super::*;
Expand All @@ -23,7 +26,7 @@ impl Downloader {
dialer: dialer::TestingDialer,
getter: getter::TestingGetter,
concurrency_limits: ConcurrencyLimits,
) -> Self {
) -> (Self, LocalPool) {
Self::spawn_for_test_with_retry_config(
dialer,
getter,
Expand All @@ -37,21 +40,25 @@ impl Downloader {
getter: getter::TestingGetter,
concurrency_limits: ConcurrencyLimits,
retry_config: RetryConfig,
) -> Self {
) -> (Self, LocalPool) {
let (msg_tx, msg_rx) = mpsc::channel(super::SERVICE_CHANNEL_CAPACITY);

LocalPoolHandle::new(1).spawn_pinned(move || async move {
let lp = LocalPool::default();
lp.spawn_detached(move || async move {
// we want to see the logs of the service
let _guard = iroh_test::logging::setup();

let service = Service::new(getter, dialer, concurrency_limits, retry_config, msg_rx);
service.run().await
});

Downloader {
next_id: Arc::new(AtomicU64::new(0)),
msg_tx,
}
(
Downloader {
next_id: Arc::new(AtomicU64::new(0)),
msg_tx,
},
lp,
)
}
}

Expand All @@ -63,7 +70,8 @@ async fn smoke_test() {
let getter = getter::TestingGetter::default();
let concurrency_limits = ConcurrencyLimits::default();

let downloader = Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits);
let (downloader, _lp) =
Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits);

// send a request and make sure the peer is requested the corresponding download
let peer = SecretKey::generate().public();
Expand All @@ -88,7 +96,8 @@ async fn deduplication() {
getter.set_request_duration(Duration::from_secs(1));
let concurrency_limits = ConcurrencyLimits::default();

let downloader = Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits);
let (downloader, _lp) =
Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits);

let peer = SecretKey::generate().public();
let kind: DownloadKind = HashAndFormat::raw(Hash::new([0u8; 32])).into();
Expand Down Expand Up @@ -119,7 +128,8 @@ async fn cancellation() {
getter.set_request_duration(Duration::from_millis(500));
let concurrency_limits = ConcurrencyLimits::default();

let downloader = Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits);
let (downloader, _lp) =
Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits);

let peer = SecretKey::generate().public();
let kind_1: DownloadKind = HashAndFormat::raw(Hash::new([0u8; 32])).into();
Expand Down Expand Up @@ -158,7 +168,8 @@ async fn max_concurrent_requests_total() {
..Default::default()
};

let downloader = Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits);
let (downloader, _lp) =
Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits);

// send the downloads
let peer = SecretKey::generate().public();
Expand Down Expand Up @@ -201,7 +212,8 @@ async fn max_concurrent_requests_per_peer() {
..Default::default()
};

let downloader = Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits);
let (downloader, _lp) =
Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits);

// send the downloads
let peer = SecretKey::generate().public();
Expand Down Expand Up @@ -257,7 +269,8 @@ async fn concurrent_progress() {
}
.boxed()
}));
let downloader = Downloader::spawn_for_test(dialer.clone(), getter.clone(), Default::default());
let (downloader, _lp) =
Downloader::spawn_for_test(dialer.clone(), getter.clone(), Default::default());

let peer = SecretKey::generate().public();
let hash = Hash::new([0u8; 32]);
Expand Down Expand Up @@ -341,7 +354,8 @@ async fn long_queue() {
..Default::default()
};

let downloader = Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits);
let (downloader, _lp) =
Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits);
// send the downloads
let nodes = [
SecretKey::generate().public(),
Expand Down Expand Up @@ -370,7 +384,8 @@ async fn fail_while_running() {
let _guard = iroh_test::logging::setup();
let dialer = dialer::TestingDialer::default();
let getter = getter::TestingGetter::default();
let downloader = Downloader::spawn_for_test(dialer.clone(), getter.clone(), Default::default());
let (downloader, _lp) =
Downloader::spawn_for_test(dialer.clone(), getter.clone(), Default::default());
let blob_fail = HashAndFormat::raw(Hash::new([1u8; 32]));
let blob_success = HashAndFormat::raw(Hash::new([2u8; 32]));

Expand Down Expand Up @@ -407,7 +422,8 @@ async fn retry_nodes_simple() {
let _guard = iroh_test::logging::setup();
let dialer = dialer::TestingDialer::default();
let getter = getter::TestingGetter::default();
let downloader = Downloader::spawn_for_test(dialer.clone(), getter.clone(), Default::default());
let (downloader, _lp) =
Downloader::spawn_for_test(dialer.clone(), getter.clone(), Default::default());
let node = SecretKey::generate().public();
let dial_attempts = Arc::new(AtomicUsize::new(0));
let dial_attempts2 = dial_attempts.clone();
Expand All @@ -432,7 +448,7 @@ async fn retry_nodes_fail() {
max_retries_per_node: 3,
};

let downloader = Downloader::spawn_for_test_with_retry_config(
let (downloader, _lp) = Downloader::spawn_for_test_with_retry_config(
dialer.clone(),
getter.clone(),
Default::default(),
Expand Down Expand Up @@ -472,7 +488,8 @@ async fn retry_nodes_jump_queue() {
..Default::default()
};

let downloader = Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits);
let (downloader, _lp) =
Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits);

let good_node = SecretKey::generate().public();
let bad_node = SecretKey::generate().public();
Expand Down
4 changes: 2 additions & 2 deletions iroh-blobs/src/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@ use iroh_io::stats::{
use iroh_io::{AsyncSliceReader, AsyncStreamWriter, TokioStreamWriter};
use iroh_net::endpoint::{self, RecvStream, SendStream};
use serde::{Deserialize, Serialize};
use tokio_util::task::LocalPoolHandle;
use tracing::{debug, debug_span, info, trace, warn};
use tracing_futures::Instrument;

use crate::hashseq::parse_hash_seq;
use crate::protocol::{GetRequest, RangeSpec, Request};
use crate::store::*;
use crate::util::local_pool::LocalPoolHandle;
use crate::util::Tag;
use crate::{BlobFormat, Hash};

Expand Down Expand Up @@ -302,7 +302,7 @@ pub async fn handle_connection<D: Map, E: EventSender>(
};
events.send(Event::ClientConnected { connection_id }).await;
let db = db.clone();
rt.spawn_pinned(|| {
rt.spawn_detached(|| {
async move {
if let Err(err) = handle_stream(db, reader, writer).await {
warn!("error: {err:#?}",);
Expand Down
7 changes: 4 additions & 3 deletions iroh-blobs/src/store/bao_file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -878,7 +878,8 @@ mod tests {
decode_response_into_batch, local, make_wire_data, random_test_data, trickle, validate,
};
use tokio::task::JoinSet;
use tokio_util::task::LocalPoolHandle;

use crate::util::local_pool::LocalPool;

use super::*;

Expand Down Expand Up @@ -957,7 +958,7 @@ mod tests {
)),
hash.into(),
);
let local = LocalPoolHandle::new(4);
let local = LocalPool::default();
let mut tasks = Vec::new();
for i in 0..4 {
let file = handle.writer();
Expand All @@ -968,7 +969,7 @@ mod tests {
.map(io::Result::Ok)
.boxed();
let trickle = TokioStreamReader::new(tokio_util::io::StreamReader::new(trickle));
let task = local.spawn_pinned(move || async move {
let task = local.spawn(move || async move {
decode_response_into_batch(hash, IROH_BLOCK_SIZE, chunk_ranges, trickle, file).await
});
tasks.push(task);
Expand Down
11 changes: 7 additions & 4 deletions iroh-blobs/src/store/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@ use iroh_base::rpc::RpcError;
use iroh_io::AsyncSliceReader;
use serde::{Deserialize, Serialize};
use tokio::io::AsyncRead;
use tokio_util::task::LocalPoolHandle;

use crate::{
hashseq::parse_hash_seq,
protocol::RangeSpec,
util::{
local_pool::{self, LocalPool},
progress::{BoxedProgressSender, IdGenerator, ProgressSender},
Tag,
},
Expand Down Expand Up @@ -423,7 +423,10 @@ async fn validate_impl(
use futures_buffered::BufferedStreamExt;

let validate_parallelism: usize = num_cpus::get();
let lp = LocalPoolHandle::new(validate_parallelism);
let lp = LocalPool::new(local_pool::Config {
threads: validate_parallelism,
..Default::default()
});
let complete = store.blobs().await?.collect::<io::Result<Vec<_>>>()?;
let partial = store
.partial_blobs()
Expand All @@ -437,7 +440,7 @@ async fn validate_impl(
.map(|hash| {
let store = store.clone();
let tx = tx.clone();
lp.spawn_pinned(move || async move {
lp.spawn(move || async move {
let entry = store
.get(&hash)
.await?
Expand Down Expand Up @@ -486,7 +489,7 @@ async fn validate_impl(
.map(|hash| {
let store = store.clone();
let tx = tx.clone();
lp.spawn_pinned(move || async move {
lp.spawn(move || async move {
let entry = store
.get(&hash)
.await?
Expand Down
1 change: 1 addition & 0 deletions iroh-blobs/src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ pub mod progress;
pub use mem_or_file::MemOrFile;
mod sparse_mem_file;
pub use sparse_mem_file::SparseMemFile;
pub mod local_pool;

/// A tag
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, From, Into)]
Expand Down
Loading

0 comments on commit b4506b2

Please sign in to comment.