Skip to content

Commit

Permalink
Drop the join waker of a task eagerly when the task completes and the…
Browse files Browse the repository at this point in the history
…re is no

join interest
  • Loading branch information
tglane committed Nov 9, 2024
1 parent 53ea44b commit cba13d4
Show file tree
Hide file tree
Showing 5 changed files with 197 additions and 28 deletions.
86 changes: 70 additions & 16 deletions tokio/src/runtime/task/harness.rs
Original file line number Diff line number Diff line change
Expand Up @@ -283,21 +283,33 @@ where
}

pub(super) fn drop_join_handle_slow(self) {
use super::state::TransitionToJoinHandleDrop;
// Try to unset `JOIN_INTEREST`. This must be done as a first step in
// case the task concurrently completed.
if self.state().unset_join_interested().is_err() {
// It is our responsibility to drop the output. This is critical as
// the task output may not be `Send` and as such must remain with
// the scheduler or `JoinHandle`. i.e. if the output remains in the
// task structure until the task is deallocated, it may be dropped
// by a Waker on any arbitrary thread.
//
// Panics are delivered to the user via the `JoinHandle`. Given that
// they are dropping the `JoinHandle`, we assume they are not
// interested in the panic and swallow it.
let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| {
self.core().drop_future_or_output();
}));
//
// TODO Create new bit/flag in state -> Set WantToDropJoinWaker in transition when failing
let transition = self.state().transition_to_join_handle_drop();
match transition {
TransitionToJoinHandleDrop::Failed => {
// It is our responsibility to drop the output. This is critical as
// the task output may not be `Send` and as such must remain with
// the scheduler or `JoinHandle`. i.e. if the output remains in the
// task structure until the task is deallocated, it may be dropped
// by a Waker on any arbitrary thread.
//
// Panics are delivered to the user via the `JoinHandle`. Given that
// they are dropping the `JoinHandle`, we assume they are not
// interested in the panic and swallow it.
let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| {
self.core().drop_future_or_output();
}));
}
TransitionToJoinHandleDrop::OkDropJoinWaker => unsafe {
// If there is a waker associated with this task when the `JoinHandle` is about to get
// dropped we want to also drop this waker if the task is already completed.
self.trailer().set_waker(None);
},
TransitionToJoinHandleDrop::OkDoNothing => (),
}

// Drop the `JoinHandle` reference, possibly deallocating the task
Expand All @@ -308,6 +320,7 @@ where

/// Completes the task. This method assumes that the state is RUNNING.
fn complete(self) {
use super::state::TransitionToTerminal;
// The future has completed and its output has been written to the task
// stage. We transition from running to complete.

Expand All @@ -332,8 +345,28 @@ where
// The task has completed execution and will no longer be scheduled.
let num_release = self.release();

if self.state().transition_to_terminal(num_release) {
self.dealloc();
match self.state().transition_to_terminal(num_release) {
TransitionToTerminal::OkDoNothing => (),
TransitionToTerminal::OkDealloc => {
self.dealloc();
}
TransitionToTerminal::FailedDropJoinWaker => {
// Safety: In this case we are the only one referencing the task and the active
// waker is the only one preventing the task from being deallocated so noone else
// will try to access the waker here.
unsafe {
self.trailer().set_waker(None);
}

match self.state().transition_to_terminal(num_release) {
TransitionToTerminal::OkDealloc => self.dealloc(),
// We do not expect this to happen since `TransitionToTerminal::DropJoinWaker`
// will only be returned when after dropping the JoinWaker the task can be
// safely. Because after this failed transition the COMPLETE bit is still set
// its fine to transition to terminal in two steps here
_ => (),
}
}
}
}

Expand Down Expand Up @@ -373,7 +406,7 @@ fn can_read_output(header: &Header, trailer: &Trailer, waker: &Waker) -> bool {

debug_assert!(snapshot.is_join_interested());

if !snapshot.is_complete() {
if !snapshot.is_complete() && !snapshot.is_terminal() {
// If the task is not complete, try storing the provided waker in the
// task's waker field.

Expand Down Expand Up @@ -439,6 +472,27 @@ fn set_join_waker(
res
}

fn unset_join_waker(
header: &Header,
trailer: &Trailer,
snapshot: Snapshot,
) -> Result<Snapshot, Snapshot> {
assert!(snapshot.is_join_interested());
assert!(snapshot.is_join_waker_set());

// Make sure the `JoinWaker` bit is unset before accessing the `waker` directly.
let res = header.state.unset_waker();
if res.is_ok() {
// Safety: Only the `JoinHandle` may set the `waker` field. When
// `JOIN_INTEREST` is **not** set, nothing else will touch the field.
unsafe {
trailer.set_waker(None);
}
}

res
}

enum PollFuture {
Complete,
Notified,
Expand Down
83 changes: 72 additions & 11 deletions tokio/src/runtime/task/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,12 @@ const JOIN_WAKER: usize = 0b10_000;
/// The task has been forcibly cancelled.
const CANCELLED: usize = 0b100_000;

const TERMINAL: usize = 0b1_000_000;

/// All bits.
const STATE_MASK: usize = LIFECYCLE_MASK | NOTIFIED | JOIN_INTEREST | JOIN_WAKER | CANCELLED;
// const STATE_MASK: usize = LIFECYCLE_MASK | NOTIFIED | JOIN_INTEREST | JOIN_WAKER | CANCELLED;
const STATE_MASK: usize =
LIFECYCLE_MASK | NOTIFIED | JOIN_INTEREST | JOIN_WAKER | CANCELLED | TERMINAL;

/// Bits used by the ref count portion of the state.
const REF_COUNT_MASK: usize = !STATE_MASK;
Expand Down Expand Up @@ -89,6 +93,20 @@ pub(crate) enum TransitionToNotifiedByRef {
Submit,
}

#[must_use]
pub(crate) enum TransitionToJoinHandleDrop {
Failed,
OkDoNothing,
OkDropJoinWaker,
}

#[must_use]
pub(crate) enum TransitionToTerminal {
OkDoNothing,
OkDealloc,
FailedDropJoinWaker,
}

/// All transitions are performed via RMW operations. This establishes an
/// unambiguous modification order.
impl State {
Expand Down Expand Up @@ -174,30 +192,69 @@ impl State {
})
}

pub(super) fn transition_to_join_handle_drop(&self) -> TransitionToJoinHandleDrop {
self.fetch_update_action(|mut snapshot| {
if snapshot.is_join_interested() {
snapshot.unset_join_interested()
}

if snapshot.is_complete() && !snapshot.is_terminal() {
(TransitionToJoinHandleDrop::Failed, None)
} else if snapshot.is_join_waker_set() {
snapshot.unset_join_waker();
(TransitionToJoinHandleDrop::OkDropJoinWaker, Some(snapshot))
} else {
(TransitionToJoinHandleDrop::OkDoNothing, Some(snapshot))
}
})
}

/// Transitions the task from `Running` -> `Complete`.
pub(super) fn transition_to_complete(&self) -> Snapshot {
const DELTA: usize = RUNNING | COMPLETE;

let prev = Snapshot(self.val.fetch_xor(DELTA, AcqRel));
assert!(prev.is_running());
assert!(!prev.is_complete());
assert!(!prev.is_terminal());

Snapshot(prev.0 ^ DELTA)
}

/// Transitions from `Complete` -> `Terminal`, decrementing the reference
/// count the specified number of times.
///
/// Returns true if the task should be deallocated.
pub(super) fn transition_to_terminal(&self, count: usize) -> bool {
let prev = Snapshot(self.val.fetch_sub(count * REF_ONE, AcqRel));
assert!(
prev.ref_count() >= count,
"current: {}, sub: {}",
prev.ref_count(),
count
);
prev.ref_count() == count
/// Returns `TransitionToTerminal::OkDoNothing` if transition was successful but the task can
/// not already be deallocated.
/// Returns `TransitionToTerminal::OkDealloc` if the task should be deallocated.
/// Returns `TransitionToTerminal::FailedDropJoinWaker` if the transition failed because of a
/// the join waker being the only last. In this case the reference count will not be decremented
/// but the `JOIN_WAKER` bit will be unset.
pub(super) fn transition_to_terminal(&self, count: usize) -> TransitionToTerminal {
self.fetch_update_action(|mut snapshot| {
assert!(!snapshot.is_running());
assert!(snapshot.is_complete());
assert!(!snapshot.is_terminal());
assert!(
snapshot.ref_count() >= count,
"current: {}, sub: {}",
snapshot.ref_count(),
count
);

if snapshot.ref_count() == count {
snapshot.0 -= count * REF_ONE;
snapshot.0 |= TERMINAL;
(TransitionToTerminal::OkDealloc, Some(snapshot))
} else if !snapshot.is_join_interested() && snapshot.is_join_waker_set() {
snapshot.unset_join_waker();
(TransitionToTerminal::FailedDropJoinWaker, Some(snapshot))
} else {
snapshot.0 -= count * REF_ONE;
snapshot.0 |= TERMINAL;
(TransitionToTerminal::OkDoNothing, Some(snapshot))
}
})
}

/// Transitions the state to `NOTIFIED`.
Expand Down Expand Up @@ -557,6 +614,10 @@ impl Snapshot {
self.0 & COMPLETE == COMPLETE
}

pub(super) fn is_terminal(self) -> bool {
self.0 & TERMINAL == TERMINAL
}

pub(super) fn is_join_interested(self) -> bool {
self.0 & JOIN_INTEREST == JOIN_INTEREST
}
Expand Down
18 changes: 18 additions & 0 deletions tokio/src/runtime/tests/loom_current_thread.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,24 @@ fn assert_no_unnecessary_polls() {
});
}

// #[test]
// fn new_test() {
// loom::model(|| {
// let rt = Builder::new_current_thread().build().unwrap();
//
// let jh = rt.spawn(async {});
//
// let bg = std::thread::spawn(move || {
// jh.poll();
// });
//
// rt.block_on(async {});
//
// rt.shutdown();
// bg.join();
// });
// }

struct BlockedFuture {
rx: Receiver<()>,
num_polls: Arc<AtomicUsize>,
Expand Down
36 changes: 36 additions & 0 deletions tokio/src/runtime/tests/loom_multi_thread.rs
Original file line number Diff line number Diff line change
Expand Up @@ -460,3 +460,39 @@ impl<T: Future> Future for Track<T> {
})
}
}

#[test]
fn timo_test() {
use crate::sync::mpsc::channel;

loom::model(|| {
let pool = mk_pool(2);

pool.block_on(async move {
let (tx, mut rx) = channel(1);

let (a_closer, mut wait_for_close_a) = channel::<()>(1);
let (b_closer, mut wait_for_close_b) = channel::<()>(1);

let a = spawn(async move {
let b = rx.recv().await.unwrap();

futures::future::select(
std::pin::pin!(b),
// std::pin::pin!(futures::future::ready(())),
std::pin::pin!(a_closer.send(())),
)
.await;
});

let b = spawn(async move {
let _ = a.await;
let _ = b_closer.send(()).await;
});

tx.send(b).await.unwrap();

futures::future::join(wait_for_close_a.recv(), wait_for_close_b.recv()).await;
});
});
}
2 changes: 1 addition & 1 deletion tokio/src/sync/tests/loom_atomic_waker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ fn test_panicky_waker() {
// which would otherwise log.
//
// We can't however leaved it uncommented, because it's global.
// panic::set_hook(Box::new(|_| ()));
panic::set_hook(Box::new(|_| ()));

const NUM_NOTIFY: usize = 2;

Expand Down

0 comments on commit cba13d4

Please sign in to comment.