diff --git a/tokio-util/src/task/mod.rs b/tokio-util/src/task/mod.rs index a5f94a898e2..e37015a4e3c 100644 --- a/tokio-util/src/task/mod.rs +++ b/tokio-util/src/task/mod.rs @@ -10,3 +10,6 @@ pub use spawn_pinned::LocalPoolHandle; #[cfg(tokio_unstable)] #[cfg_attr(docsrs, doc(cfg(all(tokio_unstable, feature = "rt"))))] pub use join_map::{JoinMap, JoinMapKeys}; + +pub mod task_tracker; +pub use task_tracker::TaskTracker; diff --git a/tokio-util/src/task/task_tracker.rs b/tokio-util/src/task/task_tracker.rs new file mode 100644 index 00000000000..d8f3bb4859a --- /dev/null +++ b/tokio-util/src/task/task_tracker.rs @@ -0,0 +1,719 @@ +//! Types related to the [`TaskTracker`] collection. +//! +//! See the documentation of [`TaskTracker`] for more information. + +use pin_project_lite::pin_project; +use std::fmt; +use std::future::Future; +use std::pin::Pin; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; +use std::task::{Context, Poll}; +use tokio::sync::{futures::Notified, Notify}; + +#[cfg(feature = "rt")] +use tokio::{ + runtime::Handle, + task::{JoinHandle, LocalSet}, +}; + +/// A task tracker used for waiting until tasks exit. +/// +/// This is usually used together with [`CancellationToken`] to implement [graceful shutdown]. The +/// `CancellationToken` is used to signal to tasks that they should shut down, and the +/// `TaskTracker` is used to wait for them to finish shutting down. +/// +/// The `TaskTracker` will also keep track of a `closed` boolean. This is used to handle the case +/// where the `TaskTracker` is empty, but we don't want to shut down yet. This means that the +/// [`wait`] method will wait until *both* of the following happen at the same time: +/// +/// * The `TaskTracker` must be closed using the [`close`] method. +/// * The `TaskTracker` must be empty, that is, all tasks that it is tracking must have exited. +/// +/// When a call to [`wait`] returns, it is guaranteed that all tracked tasks have exited and that +/// the destructor of the future has finished running. However, there might be a short amount of +/// time where [`JoinHandle::is_finished`] returns false. +/// +/// # Comparison to `JoinSet` +/// +/// The main Tokio crate has a similar collection known as [`JoinSet`]. The `JoinSet` type has a +/// lot more features than `TaskTracker`, so `TaskTracker` should only be used when one of its +/// unique features is required: +/// +/// 1. When tasks exit, a `TaskTracker` will allow the task to immediately free its memory. +/// 2. By not closing the `TaskTracker`, [`wait`] will be prevented from from returning even if +/// the `TaskTracker` is empty. +/// 3. A `TaskTracker` does not require mutable access to insert tasks. +/// 4. A `TaskTracker` can be cloned to share it with many tasks. +/// +/// The first point is the most important one. A [`JoinSet`] keeps track of the return value of +/// every inserted task. This means that if the caller keeps inserting tasks and never calls +/// [`join_next`], then their return values will keep building up and consuming memory, _even if_ +/// most of the tasks have already exited. This can cause the process to run out of memory. With a +/// `TaskTracker`, this does not happen. Once tasks exit, they are immediately removed from the +/// `TaskTracker`. +/// +/// # Examples +/// +/// For more examples, please see the topic page on [graceful shutdown]. +/// +/// ## Spawn tasks and wait for them to exit +/// +/// This is a simple example. For this case, [`JoinSet`] should probably be used instead. +/// +/// ``` +/// use tokio_util::task::TaskTracker; +/// +/// #[tokio::main] +/// async fn main() { +/// let tracker = TaskTracker::new(); +/// +/// for i in 0..10 { +/// tracker.spawn(async move { +/// println!("Task {} is running!", i); +/// }); +/// } +/// // Once we spawned everything, we close the tracker. +/// tracker.close(); +/// +/// // Wait for everything to finish. +/// tracker.wait().await; +/// +/// println!("This is printed after all of the tasks."); +/// } +/// ``` +/// +/// ## Wait for tasks to exit +/// +/// This example shows the intended use-case of `TaskTracker`. It is used together with +/// [`CancellationToken`] to implement graceful shutdown. +/// ``` +/// use tokio_util::sync::CancellationToken; +/// use tokio_util::task::TaskTracker; +/// use tokio::time::{self, Duration}; +/// +/// async fn background_task(num: u64) { +/// for i in 0..10 { +/// time::sleep(Duration::from_millis(100*num)).await; +/// println!("Background task {} in iteration {}.", num, i); +/// } +/// } +/// +/// #[tokio::main] +/// # async fn _hidden() {} +/// # #[tokio::main(flavor = "current_thread", start_paused = true)] +/// async fn main() { +/// let tracker = TaskTracker::new(); +/// let token = CancellationToken::new(); +/// +/// for i in 0..10 { +/// let token = token.clone(); +/// tracker.spawn(async move { +/// // Use a `tokio::select!` to kill the background task if the token is +/// // cancelled. +/// tokio::select! { +/// () = background_task(i) => { +/// println!("Task {} exiting normally.", i); +/// }, +/// () = token.cancelled() => { +/// // Do some cleanup before we really exit. +/// time::sleep(Duration::from_millis(50)).await; +/// println!("Task {} finished cleanup.", i); +/// }, +/// } +/// }); +/// } +/// +/// // Spawn a background task that will send the shutdown signal. +/// { +/// let tracker = tracker.clone(); +/// tokio::spawn(async move { +/// // Normally you would use something like ctrl-c instead of +/// // sleeping. +/// time::sleep(Duration::from_secs(2)).await; +/// tracker.close(); +/// token.cancel(); +/// }); +/// } +/// +/// // Wait for all tasks to exit. +/// tracker.wait().await; +/// +/// println!("All tasks have exited now."); +/// } +/// ``` +/// +/// [`CancellationToken`]: crate::sync::CancellationToken +/// [`JoinHandle::is_finished`]: tokio::task::JoinHandle::is_finished +/// [`JoinSet`]: tokio::task::JoinSet +/// [`close`]: Self::close +/// [`join_next`]: tokio::task::JoinSet::join_next +/// [`wait`]: Self::wait +/// [graceful shutdown]: https://tokio.rs/tokio/topics/shutdown +pub struct TaskTracker { + inner: Arc, +} + +/// Represents a task tracked by a [`TaskTracker`]. +#[must_use] +#[derive(Debug)] +pub struct TaskTrackerToken { + task_tracker: TaskTracker, +} + +struct TaskTrackerInner { + /// Keeps track of the state. + /// + /// The lowest bit is whether the task tracker is closed. + /// + /// The rest of the bits count the number of tracked tasks. + state: AtomicUsize, + /// Used to notify when the last task exits. + on_last_exit: Notify, +} + +pin_project! { + /// A future that is tracked as a task by a [`TaskTracker`]. + /// + /// The associated [`TaskTracker`] cannot complete until this future is dropped. + /// + /// This future is returned by [`TaskTracker::track_future`]. + #[must_use = "futures do nothing unless polled"] + pub struct TrackedFuture { + #[pin] + future: F, + token: TaskTrackerToken, + } +} + +pin_project! { + /// A future that completes when the [`TaskTracker`] is empty and closed. + /// + /// This future is returned by [`TaskTracker::wait`]. + #[must_use = "futures do nothing unless polled"] + pub struct TaskTrackerWaitFuture<'a> { + #[pin] + future: Notified<'a>, + inner: Option<&'a TaskTrackerInner>, + } +} + +impl TaskTrackerInner { + #[inline] + fn new() -> Self { + Self { + state: AtomicUsize::new(0), + on_last_exit: Notify::new(), + } + } + + #[inline] + fn is_closed_and_empty(&self) -> bool { + // If empty and closed bit set, then we are done. + // + // The acquire load will synchronize with the release store of any previous call to + // `set_closed` and `drop_task`. + self.state.load(Ordering::Acquire) == 1 + } + + #[inline] + fn set_closed(&self) -> bool { + // The AcqRel ordering makes the closed bit behave like a `Mutex` for synchronization + // purposes. We do this because it makes the return value of `TaskTracker::{close,reopen}` + // more meaningful for the user. Without these orderings, this assert could fail: + // ``` + // // thread 1 + // some_other_atomic.store(true, Relaxed); + // tracker.close(); + // + // // thread 2 + // if tracker.reopen() { + // assert!(some_other_atomic.load(Relaxed)); + // } + // ``` + // However, with the AcqRel ordering, we establish a happens-before relationship from the + // call to `close` and the later call to `reopen` that returned true. + let state = self.state.fetch_or(1, Ordering::AcqRel); + + // If there are no tasks, and if it was not already closed: + if state == 0 { + self.notify_now(); + } + + (state & 1) == 0 + } + + #[inline] + fn set_open(&self) -> bool { + // See `set_closed` regarding the AcqRel ordering. + let state = self.state.fetch_and(!1, Ordering::AcqRel); + (state & 1) == 1 + } + + #[inline] + fn add_task(&self) { + self.state.fetch_add(2, Ordering::Relaxed); + } + + #[inline] + fn drop_task(&self) { + let state = self.state.fetch_sub(2, Ordering::Release); + + // If this was the last task and we are closed: + if state == 3 { + self.notify_now(); + } + } + + #[cold] + fn notify_now(&self) { + // Insert an acquire fence. This matters for `drop_task` but doesn't matter for + // `set_closed` since it already uses AcqRel. + // + // This synchronizes with the release store of any other call to `drop_task`, and with the + // release store in the call to `set_closed`. That ensures that everything that happened + // before those other calls to `drop_task` or `set_closed` will be visible after this load, + // and those things will also be visible to anything woken by the call to `notify_waiters`. + self.state.load(Ordering::Acquire); + + self.on_last_exit.notify_waiters(); + } +} + +impl TaskTracker { + /// Creates a new `TaskTracker`. + /// + /// The `TaskTracker` will start out as open. + #[must_use] + pub fn new() -> Self { + Self { + inner: Arc::new(TaskTrackerInner::new()), + } + } + + /// Waits until this `TaskTracker` is both closed and empty. + /// + /// If the `TaskTracker` is already closed and empty when this method is called, then it + /// returns immediately. + /// + /// The `wait` future is resistant against [ABA problems][aba]. That is, if the `TaskTracker` + /// becomes both closed and empty for a short amount of time, then it is guarantee that all + /// `wait` futures that were created before the short time interval will trigger, even if they + /// are not polled during that short time interval. + /// + /// # Cancel safety + /// + /// This method is cancel safe. + /// + /// However, the resistance against [ABA problems][aba] is lost when using `wait` as the + /// condition in a `tokio::select!` loop. + /// + /// [aba]: https://en.wikipedia.org/wiki/ABA_problem + #[inline] + pub fn wait(&self) -> TaskTrackerWaitFuture<'_> { + TaskTrackerWaitFuture { + future: self.inner.on_last_exit.notified(), + inner: if self.inner.is_closed_and_empty() { + None + } else { + Some(&self.inner) + }, + } + } + + /// Close this `TaskTracker`. + /// + /// This allows [`wait`] futures to complete. It does not prevent you from spawning new tasks. + /// + /// Returns `true` if this closed the `TaskTracker`, or `false` if it was already closed. + /// + /// [`wait`]: Self::wait + #[inline] + pub fn close(&self) -> bool { + self.inner.set_closed() + } + + /// Reopen this `TaskTracker`. + /// + /// This prevents [`wait`] futures from completing even if the `TaskTracker` is empty. + /// + /// Returns `true` if this reopened the `TaskTracker`, or `false` if it was already open. + /// + /// [`wait`]: Self::wait + #[inline] + pub fn reopen(&self) -> bool { + self.inner.set_open() + } + + /// Returns `true` if this `TaskTracker` is [closed](Self::close). + #[inline] + #[must_use] + pub fn is_closed(&self) -> bool { + (self.inner.state.load(Ordering::Acquire) & 1) != 0 + } + + /// Returns the number of tasks tracked by this `TaskTracker`. + #[inline] + #[must_use] + pub fn len(&self) -> usize { + self.inner.state.load(Ordering::Acquire) >> 1 + } + + /// Returns `true` if there are no tasks in this `TaskTracker`. + #[inline] + #[must_use] + pub fn is_empty(&self) -> bool { + self.inner.state.load(Ordering::Acquire) <= 1 + } + + /// Spawn the provided future on the current Tokio runtime, and track it in this `TaskTracker`. + /// + /// This is equivalent to `tokio::spawn(tracker.track_future(task))`. + #[inline] + #[track_caller] + #[cfg(feature = "rt")] + #[cfg_attr(docsrs, doc(cfg(feature = "rt")))] + pub fn spawn(&self, task: F) -> JoinHandle + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + tokio::task::spawn(self.track_future(task)) + } + + /// Spawn the provided future on the provided Tokio runtime, and track it in this `TaskTracker`. + /// + /// This is equivalent to `handle.spawn(tracker.track_future(task))`. + #[inline] + #[track_caller] + #[cfg(feature = "rt")] + #[cfg_attr(docsrs, doc(cfg(feature = "rt")))] + pub fn spawn_on(&self, task: F, handle: &Handle) -> JoinHandle + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + handle.spawn(self.track_future(task)) + } + + /// Spawn the provided future on the current [`LocalSet`], and track it in this `TaskTracker`. + /// + /// This is equivalent to `tokio::task::spawn_local(tracker.track_future(task))`. + /// + /// [`LocalSet`]: tokio::task::LocalSet + #[inline] + #[track_caller] + #[cfg(feature = "rt")] + #[cfg_attr(docsrs, doc(cfg(feature = "rt")))] + pub fn spawn_local(&self, task: F) -> JoinHandle + where + F: Future + 'static, + F::Output: 'static, + { + tokio::task::spawn_local(self.track_future(task)) + } + + /// Spawn the provided future on the provided [`LocalSet`], and track it in this `TaskTracker`. + /// + /// This is equivalent to `local_set.spawn_local(tracker.track_future(task))`. + /// + /// [`LocalSet`]: tokio::task::LocalSet + #[inline] + #[track_caller] + #[cfg(feature = "rt")] + #[cfg_attr(docsrs, doc(cfg(feature = "rt")))] + pub fn spawn_local_on(&self, task: F, local_set: &LocalSet) -> JoinHandle + where + F: Future + 'static, + F::Output: 'static, + { + local_set.spawn_local(self.track_future(task)) + } + + /// Spawn the provided blocking task on the current Tokio runtime, and track it in this `TaskTracker`. + /// + /// This is equivalent to `tokio::task::spawn_blocking(tracker.track_future(task))`. + #[inline] + #[track_caller] + #[cfg(feature = "rt")] + #[cfg(not(target_family = "wasm"))] + #[cfg_attr(docsrs, doc(cfg(feature = "rt")))] + pub fn spawn_blocking(&self, task: F) -> JoinHandle + where + F: FnOnce() -> T, + F: Send + 'static, + T: Send + 'static, + { + let token = self.token(); + tokio::task::spawn_blocking(move || { + let res = task(); + drop(token); + res + }) + } + + /// Spawn the provided blocking task on the provided Tokio runtime, and track it in this `TaskTracker`. + /// + /// This is equivalent to `handle.spawn_blocking(tracker.track_future(task))`. + #[inline] + #[track_caller] + #[cfg(feature = "rt")] + #[cfg(not(target_family = "wasm"))] + #[cfg_attr(docsrs, doc(cfg(feature = "rt")))] + pub fn spawn_blocking_on(&self, task: F, handle: &Handle) -> JoinHandle + where + F: FnOnce() -> T, + F: Send + 'static, + T: Send + 'static, + { + let token = self.token(); + handle.spawn_blocking(move || { + let res = task(); + drop(token); + res + }) + } + + /// Track the provided future. + /// + /// The returned [`TrackedFuture`] will count as a task tracked by this collection, and will + /// prevent calls to [`wait`] from returning until the task is dropped. + /// + /// The task is removed from the collection when it is dropped, not when [`poll`] returns + /// [`Poll::Ready`]. + /// + /// # Examples + /// + /// Track a future spawned with [`tokio::spawn`]. + /// + /// ``` + /// # async fn my_async_fn() {} + /// use tokio_util::task::TaskTracker; + /// + /// # #[tokio::main(flavor = "current_thread")] + /// # async fn main() { + /// let tracker = TaskTracker::new(); + /// + /// tokio::spawn(tracker.track_future(my_async_fn())); + /// # } + /// ``` + /// + /// Track a future spawned on a [`JoinSet`]. + /// ``` + /// # async fn my_async_fn() {} + /// use tokio::task::JoinSet; + /// use tokio_util::task::TaskTracker; + /// + /// # #[tokio::main(flavor = "current_thread")] + /// # async fn main() { + /// let tracker = TaskTracker::new(); + /// let mut join_set = JoinSet::new(); + /// + /// join_set.spawn(tracker.track_future(my_async_fn())); + /// # } + /// ``` + /// + /// [`JoinSet`]: tokio::task::JoinSet + /// [`Poll::Pending`]: std::task::Poll::Pending + /// [`poll`]: std::future::Future::poll + /// [`wait`]: Self::wait + #[inline] + pub fn track_future(&self, future: F) -> TrackedFuture { + TrackedFuture { + future, + token: self.token(), + } + } + + /// Creates a [`TaskTrackerToken`] representing a task tracked by this `TaskTracker`. + /// + /// This token is a lower-level utility than the spawn methods. Each token is considered to + /// correspond to a task. As long as the token exists, the `TaskTracker` cannot complete. + /// Furthermore, the count returned by the [`len`] method will include the tokens in the count. + /// + /// Dropping the token indicates to the `TaskTracker` that the task has exited. + /// + /// [`len`]: TaskTracker::len + #[inline] + pub fn token(&self) -> TaskTrackerToken { + self.inner.add_task(); + TaskTrackerToken { + task_tracker: self.clone(), + } + } + + /// Returns `true` if both task trackers correspond to the same set of tasks. + /// + /// # Examples + /// + /// ``` + /// use tokio_util::task::TaskTracker; + /// + /// let tracker_1 = TaskTracker::new(); + /// let tracker_2 = TaskTracker::new(); + /// let tracker_1_clone = tracker_1.clone(); + /// + /// assert!(TaskTracker::ptr_eq(&tracker_1, &tracker_1_clone)); + /// assert!(!TaskTracker::ptr_eq(&tracker_1, &tracker_2)); + /// ``` + #[inline] + #[must_use] + pub fn ptr_eq(left: &TaskTracker, right: &TaskTracker) -> bool { + Arc::ptr_eq(&left.inner, &right.inner) + } +} + +impl Default for TaskTracker { + /// Creates a new `TaskTracker`. + /// + /// The `TaskTracker` will start out as open. + #[inline] + fn default() -> TaskTracker { + TaskTracker::new() + } +} + +impl Clone for TaskTracker { + /// Returns a new `TaskTracker` that tracks the same set of tasks. + /// + /// Since the new `TaskTracker` shares the same set of tasks, changes to one set are visible in + /// all other clones. + /// + /// # Examples + /// + /// ``` + /// use tokio_util::task::TaskTracker; + /// + /// #[tokio::main] + /// # async fn _hidden() {} + /// # #[tokio::main(flavor = "current_thread")] + /// async fn main() { + /// let tracker = TaskTracker::new(); + /// let cloned = tracker.clone(); + /// + /// // Spawns on `tracker` are visible in `cloned`. + /// tracker.spawn(std::future::pending::<()>()); + /// assert_eq!(cloned.len(), 1); + /// + /// // Spawns on `cloned` are visible in `tracker`. + /// cloned.spawn(std::future::pending::<()>()); + /// assert_eq!(tracker.len(), 2); + /// + /// // Calling `close` is visible to `cloned`. + /// tracker.close(); + /// assert!(cloned.is_closed()); + /// + /// // Calling `reopen` is visible to `tracker`. + /// cloned.reopen(); + /// assert!(!tracker.is_closed()); + /// } + /// ``` + #[inline] + fn clone(&self) -> TaskTracker { + Self { + inner: self.inner.clone(), + } + } +} + +fn debug_inner(inner: &TaskTrackerInner, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let state = inner.state.load(Ordering::Acquire); + let is_closed = (state & 1) != 0; + let len = state >> 1; + + f.debug_struct("TaskTracker") + .field("len", &len) + .field("is_closed", &is_closed) + .field("inner", &(inner as *const TaskTrackerInner)) + .finish() +} + +impl fmt::Debug for TaskTracker { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + debug_inner(&self.inner, f) + } +} + +impl TaskTrackerToken { + /// Returns the [`TaskTracker`] that this token is associated with. + #[inline] + #[must_use] + pub fn task_tracker(&self) -> &TaskTracker { + &self.task_tracker + } +} + +impl Clone for TaskTrackerToken { + /// Returns a new `TaskTrackerToken` associated with the same [`TaskTracker`]. + /// + /// This is equivalent to `token.task_tracker().token()`. + #[inline] + fn clone(&self) -> TaskTrackerToken { + self.task_tracker.token() + } +} + +impl Drop for TaskTrackerToken { + /// Dropping the token indicates to the [`TaskTracker`] that the task has exited. + #[inline] + fn drop(&mut self) { + self.task_tracker.inner.drop_task(); + } +} + +impl Future for TrackedFuture { + type Output = F::Output; + + #[inline] + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.project().future.poll(cx) + } +} + +impl fmt::Debug for TrackedFuture { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("TrackedFuture") + .field("future", &self.future) + .field("task_tracker", self.token.task_tracker()) + .finish() + } +} + +impl<'a> Future for TaskTrackerWaitFuture<'a> { + type Output = (); + + #[inline] + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + let me = self.project(); + + let inner = match me.inner.as_ref() { + None => return Poll::Ready(()), + Some(inner) => inner, + }; + + let ready = inner.is_closed_and_empty() || me.future.poll(cx).is_ready(); + if ready { + *me.inner = None; + Poll::Ready(()) + } else { + Poll::Pending + } + } +} + +impl<'a> fmt::Debug for TaskTrackerWaitFuture<'a> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + struct Helper<'a>(&'a TaskTrackerInner); + + impl fmt::Debug for Helper<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + debug_inner(self.0, f) + } + } + + f.debug_struct("TaskTrackerWaitFuture") + .field("future", &self.future) + .field("task_tracker", &self.inner.map(Helper)) + .finish() + } +} diff --git a/tokio-util/tests/task_tracker.rs b/tokio-util/tests/task_tracker.rs new file mode 100644 index 00000000000..f0eb2442ca6 --- /dev/null +++ b/tokio-util/tests/task_tracker.rs @@ -0,0 +1,178 @@ +#![warn(rust_2018_idioms)] + +use tokio_test::{assert_pending, assert_ready, task}; +use tokio_util::task::TaskTracker; + +#[test] +fn open_close() { + let tracker = TaskTracker::new(); + assert!(!tracker.is_closed()); + assert!(tracker.is_empty()); + assert_eq!(tracker.len(), 0); + + tracker.close(); + assert!(tracker.is_closed()); + assert!(tracker.is_empty()); + assert_eq!(tracker.len(), 0); + + tracker.reopen(); + assert!(!tracker.is_closed()); + tracker.reopen(); + assert!(!tracker.is_closed()); + + assert!(tracker.is_empty()); + assert_eq!(tracker.len(), 0); + + tracker.close(); + assert!(tracker.is_closed()); + tracker.close(); + assert!(tracker.is_closed()); + + assert!(tracker.is_empty()); + assert_eq!(tracker.len(), 0); +} + +#[test] +fn token_len() { + let tracker = TaskTracker::new(); + + let mut tokens = Vec::new(); + for i in 0..10 { + assert_eq!(tracker.len(), i); + tokens.push(tracker.token()); + } + + assert!(!tracker.is_empty()); + assert_eq!(tracker.len(), 10); + + for (i, token) in tokens.into_iter().enumerate() { + drop(token); + assert_eq!(tracker.len(), 9 - i); + } +} + +#[test] +fn notify_immediately() { + let tracker = TaskTracker::new(); + tracker.close(); + + let mut wait = task::spawn(tracker.wait()); + assert_ready!(wait.poll()); +} + +#[test] +fn notify_immediately_on_reopen() { + let tracker = TaskTracker::new(); + tracker.close(); + + let mut wait = task::spawn(tracker.wait()); + tracker.reopen(); + assert_ready!(wait.poll()); +} + +#[test] +fn notify_on_close() { + let tracker = TaskTracker::new(); + + let mut wait = task::spawn(tracker.wait()); + + assert_pending!(wait.poll()); + tracker.close(); + assert_ready!(wait.poll()); +} + +#[test] +fn notify_on_close_reopen() { + let tracker = TaskTracker::new(); + + let mut wait = task::spawn(tracker.wait()); + + assert_pending!(wait.poll()); + tracker.close(); + tracker.reopen(); + assert_ready!(wait.poll()); +} + +#[test] +fn notify_on_last_task() { + let tracker = TaskTracker::new(); + tracker.close(); + let token = tracker.token(); + + let mut wait = task::spawn(tracker.wait()); + assert_pending!(wait.poll()); + drop(token); + assert_ready!(wait.poll()); +} + +#[test] +fn notify_on_last_task_respawn() { + let tracker = TaskTracker::new(); + tracker.close(); + let token = tracker.token(); + + let mut wait = task::spawn(tracker.wait()); + assert_pending!(wait.poll()); + drop(token); + let token2 = tracker.token(); + assert_ready!(wait.poll()); + drop(token2); +} + +#[test] +fn no_notify_on_respawn_if_open() { + let tracker = TaskTracker::new(); + let token = tracker.token(); + + let mut wait = task::spawn(tracker.wait()); + assert_pending!(wait.poll()); + drop(token); + let token2 = tracker.token(); + assert_pending!(wait.poll()); + drop(token2); +} + +#[test] +fn close_during_exit() { + const ITERS: usize = 5; + + for close_spot in 0..=ITERS { + let tracker = TaskTracker::new(); + let tokens: Vec<_> = (0..ITERS).map(|_| tracker.token()).collect(); + + let mut wait = task::spawn(tracker.wait()); + + for (i, token) in tokens.into_iter().enumerate() { + assert_pending!(wait.poll()); + if i == close_spot { + tracker.close(); + assert_pending!(wait.poll()); + } + drop(token); + } + + if close_spot == ITERS { + assert_pending!(wait.poll()); + tracker.close(); + } + + assert_ready!(wait.poll()); + } +} + +#[test] +fn notify_many() { + let tracker = TaskTracker::new(); + + let mut waits: Vec<_> = (0..10).map(|_| task::spawn(tracker.wait())).collect(); + + for wait in &mut waits { + assert_pending!(wait.poll()); + } + + tracker.close(); + + for wait in &mut waits { + assert_ready!(wait.poll()); + } +}