Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Drop the join waker of a task eagerly when the task completes and the… #6964

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 70 additions & 16 deletions tokio/src/runtime/task/harness.rs
Original file line number Diff line number Diff line change
Expand Up @@ -283,21 +283,33 @@ where
}

pub(super) fn drop_join_handle_slow(self) {
use super::state::TransitionToJoinHandleDrop;
// Try to unset `JOIN_INTEREST`. This must be done as a first step in
// case the task concurrently completed.
if self.state().unset_join_interested().is_err() {
// It is our responsibility to drop the output. This is critical as
// the task output may not be `Send` and as such must remain with
// the scheduler or `JoinHandle`. i.e. if the output remains in the
// task structure until the task is deallocated, it may be dropped
// by a Waker on any arbitrary thread.
//
// Panics are delivered to the user via the `JoinHandle`. Given that
// they are dropping the `JoinHandle`, we assume they are not
// interested in the panic and swallow it.
let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| {
self.core().drop_future_or_output();
}));
//
// TODO Create new bit/flag in state -> Set WantToDropJoinWaker in transition when failing
let transition = self.state().transition_to_join_handle_drop();
match transition {
TransitionToJoinHandleDrop::Failed => {
// It is our responsibility to drop the output. This is critical as
// the task output may not be `Send` and as such must remain with
// the scheduler or `JoinHandle`. i.e. if the output remains in the
// task structure until the task is deallocated, it may be dropped
// by a Waker on any arbitrary thread.
//
// Panics are delivered to the user via the `JoinHandle`. Given that
// they are dropping the `JoinHandle`, we assume they are not
// interested in the panic and swallow it.
let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| {
self.core().drop_future_or_output();
}));
}
TransitionToJoinHandleDrop::OkDropJoinWaker => unsafe {
// If there is a waker associated with this task when the `JoinHandle` is about to get
// dropped we want to also drop this waker if the task is already completed.
self.trailer().set_waker(None);
},
TransitionToJoinHandleDrop::OkDoNothing => (),
}

// Drop the `JoinHandle` reference, possibly deallocating the task
Expand All @@ -308,6 +320,7 @@ where

/// Completes the task. This method assumes that the state is RUNNING.
fn complete(self) {
use super::state::TransitionToTerminal;
// The future has completed and its output has been written to the task
// stage. We transition from running to complete.

Expand All @@ -332,8 +345,28 @@ where
// The task has completed execution and will no longer be scheduled.
let num_release = self.release();

if self.state().transition_to_terminal(num_release) {
self.dealloc();
match self.state().transition_to_terminal(num_release) {
TransitionToTerminal::OkDoNothing => (),
TransitionToTerminal::OkDealloc => {
self.dealloc();
}
TransitionToTerminal::FailedDropJoinWaker => {
// Safety: In this case we are the only one referencing the task and the active
// waker is the only one preventing the task from being deallocated so noone else
// will try to access the waker here.
unsafe {
self.trailer().set_waker(None);
}

match self.state().transition_to_terminal(num_release) {
TransitionToTerminal::OkDealloc => self.dealloc(),
// We do not expect this to happen since `TransitionToTerminal::DropJoinWaker`
// will only be returned when after dropping the JoinWaker the task can be
// safely. Because after this failed transition the COMPLETE bit is still set
// its fine to transition to terminal in two steps here
_ => (),
}
}
}
}

Expand Down Expand Up @@ -373,7 +406,7 @@ fn can_read_output(header: &Header, trailer: &Trailer, waker: &Waker) -> bool {

debug_assert!(snapshot.is_join_interested());

if !snapshot.is_complete() {
if !snapshot.is_complete() && !snapshot.is_terminal() {
// If the task is not complete, try storing the provided waker in the
// task's waker field.

Expand Down Expand Up @@ -439,6 +472,27 @@ fn set_join_waker(
res
}

fn unset_join_waker(
header: &Header,
trailer: &Trailer,
snapshot: Snapshot,
) -> Result<Snapshot, Snapshot> {
assert!(snapshot.is_join_interested());
assert!(snapshot.is_join_waker_set());

// Make sure the `JoinWaker` bit is unset before accessing the `waker` directly.
let res = header.state.unset_waker();
if res.is_ok() {
// Safety: Only the `JoinHandle` may set the `waker` field. When
// `JOIN_INTEREST` is **not** set, nothing else will touch the field.
unsafe {
trailer.set_waker(None);
}
}

res
}

enum PollFuture {
Complete,
Notified,
Expand Down
83 changes: 72 additions & 11 deletions tokio/src/runtime/task/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,12 @@ const JOIN_WAKER: usize = 0b10_000;
/// The task has been forcibly cancelled.
const CANCELLED: usize = 0b100_000;

const TERMINAL: usize = 0b1_000_000;

/// All bits.
const STATE_MASK: usize = LIFECYCLE_MASK | NOTIFIED | JOIN_INTEREST | JOIN_WAKER | CANCELLED;
// const STATE_MASK: usize = LIFECYCLE_MASK | NOTIFIED | JOIN_INTEREST | JOIN_WAKER | CANCELLED;
const STATE_MASK: usize =
LIFECYCLE_MASK | NOTIFIED | JOIN_INTEREST | JOIN_WAKER | CANCELLED | TERMINAL;

/// Bits used by the ref count portion of the state.
const REF_COUNT_MASK: usize = !STATE_MASK;
Expand Down Expand Up @@ -89,6 +93,20 @@ pub(crate) enum TransitionToNotifiedByRef {
Submit,
}

#[must_use]
pub(crate) enum TransitionToJoinHandleDrop {
Failed,
OkDoNothing,
OkDropJoinWaker,
}

#[must_use]
pub(crate) enum TransitionToTerminal {
OkDoNothing,
OkDealloc,
FailedDropJoinWaker,
}

/// All transitions are performed via RMW operations. This establishes an
/// unambiguous modification order.
impl State {
Expand Down Expand Up @@ -174,30 +192,69 @@ impl State {
})
}

pub(super) fn transition_to_join_handle_drop(&self) -> TransitionToJoinHandleDrop {
self.fetch_update_action(|mut snapshot| {
if snapshot.is_join_interested() {
snapshot.unset_join_interested()
}

if snapshot.is_complete() && !snapshot.is_terminal() {
(TransitionToJoinHandleDrop::Failed, None)
} else if snapshot.is_join_waker_set() {
snapshot.unset_join_waker();
(TransitionToJoinHandleDrop::OkDropJoinWaker, Some(snapshot))
} else {
(TransitionToJoinHandleDrop::OkDoNothing, Some(snapshot))
}
})
}

/// Transitions the task from `Running` -> `Complete`.
pub(super) fn transition_to_complete(&self) -> Snapshot {
const DELTA: usize = RUNNING | COMPLETE;

let prev = Snapshot(self.val.fetch_xor(DELTA, AcqRel));
assert!(prev.is_running());
assert!(!prev.is_complete());
assert!(!prev.is_terminal());

Snapshot(prev.0 ^ DELTA)
}

/// Transitions from `Complete` -> `Terminal`, decrementing the reference
/// count the specified number of times.
///
/// Returns true if the task should be deallocated.
pub(super) fn transition_to_terminal(&self, count: usize) -> bool {
let prev = Snapshot(self.val.fetch_sub(count * REF_ONE, AcqRel));
assert!(
prev.ref_count() >= count,
"current: {}, sub: {}",
prev.ref_count(),
count
);
prev.ref_count() == count
/// Returns `TransitionToTerminal::OkDoNothing` if transition was successful but the task can
/// not already be deallocated.
/// Returns `TransitionToTerminal::OkDealloc` if the task should be deallocated.
/// Returns `TransitionToTerminal::FailedDropJoinWaker` if the transition failed because of a
/// the join waker being the only last. In this case the reference count will not be decremented
/// but the `JOIN_WAKER` bit will be unset.
pub(super) fn transition_to_terminal(&self, count: usize) -> TransitionToTerminal {
self.fetch_update_action(|mut snapshot| {
assert!(!snapshot.is_running());
assert!(snapshot.is_complete());
assert!(!snapshot.is_terminal());
assert!(
snapshot.ref_count() >= count,
"current: {}, sub: {}",
snapshot.ref_count(),
count
);

if snapshot.ref_count() == count {
snapshot.0 -= count * REF_ONE;
snapshot.0 |= TERMINAL;
(TransitionToTerminal::OkDealloc, Some(snapshot))
} else if !snapshot.is_join_interested() && snapshot.is_join_waker_set() {
snapshot.unset_join_waker();
(TransitionToTerminal::FailedDropJoinWaker, Some(snapshot))
} else {
snapshot.0 -= count * REF_ONE;
snapshot.0 |= TERMINAL;
(TransitionToTerminal::OkDoNothing, Some(snapshot))
}
})
}

/// Transitions the state to `NOTIFIED`.
Expand Down Expand Up @@ -557,6 +614,10 @@ impl Snapshot {
self.0 & COMPLETE == COMPLETE
}

pub(super) fn is_terminal(self) -> bool {
self.0 & TERMINAL == TERMINAL
}

pub(super) fn is_join_interested(self) -> bool {
self.0 & JOIN_INTEREST == JOIN_INTEREST
}
Expand Down
18 changes: 18 additions & 0 deletions tokio/src/runtime/tests/loom_current_thread.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,24 @@ fn assert_no_unnecessary_polls() {
});
}

// #[test]
// fn new_test() {
// loom::model(|| {
// let rt = Builder::new_current_thread().build().unwrap();
//
// let jh = rt.spawn(async {});
//
// let bg = std::thread::spawn(move || {
// jh.poll();
// });
//
// rt.block_on(async {});
//
// rt.shutdown();
// bg.join();
// });
// }

struct BlockedFuture {
rx: Receiver<()>,
num_polls: Arc<AtomicUsize>,
Expand Down
36 changes: 36 additions & 0 deletions tokio/src/runtime/tests/loom_multi_thread.rs
Original file line number Diff line number Diff line change
Expand Up @@ -460,3 +460,39 @@ impl<T: Future> Future for Track<T> {
})
}
}

#[test]
fn timo_test() {
use crate::sync::mpsc::channel;

loom::model(|| {
let pool = mk_pool(2);

pool.block_on(async move {
let (tx, mut rx) = channel(1);

let (a_closer, mut wait_for_close_a) = channel::<()>(1);
let (b_closer, mut wait_for_close_b) = channel::<()>(1);

let a = spawn(async move {
let b = rx.recv().await.unwrap();

futures::future::select(
std::pin::pin!(b),
// std::pin::pin!(futures::future::ready(())),
std::pin::pin!(a_closer.send(())),
)
.await;
});

let b = spawn(async move {
let _ = a.await;
let _ = b_closer.send(()).await;
});

tx.send(b).await.unwrap();

futures::future::join(wait_for_close_a.recv(), wait_for_close_b.recv()).await;
});
});
}
2 changes: 1 addition & 1 deletion tokio/src/sync/tests/loom_atomic_waker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ fn test_panicky_waker() {
// which would otherwise log.
//
// We can't however leaved it uncommented, because it's global.
// panic::set_hook(Box::new(|_| ()));
panic::set_hook(Box::new(|_| ()));

const NUM_NOTIFY: usize = 2;

Expand Down
Loading