Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

coop: Undo budget decrement on Pending #2549

Merged
merged 4 commits into from
May 21, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 70 additions & 9 deletions tokio/src/coop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,15 +151,46 @@ cfg_blocking_impl! {
cfg_coop! {
use std::task::{Context, Poll};

#[must_use]
pub(crate) struct RestoreOnPending(Cell<Option<Budget>>);

impl RestoreOnPending {
pub(crate) fn made_progress(&self) {
jonhoo marked this conversation as resolved.
Show resolved Hide resolved
self.0.set(None);
}
}

impl Drop for RestoreOnPending {
fn drop(&mut self) {
if let Some(budget) = self.0.get() {
CURRENT.with(|cell| {
cell.set(budget);
});
}
}
}

/// Returns `Poll::Pending` if the current task has exceeded its budget and should yield.
///
/// When you call this method, the current budget is decremented. However, to ensure that
/// progress is made every time a task is polled, the budget is automatically restored to its
/// former value if the returned `RestoreOnPending` is dropped. It is the caller's
/// responsibility to call `RestoreOnPending::made_progress` if it made progress, to ensure
/// that the budget empties appropriately.
///
/// Note that `RestoreOnPending` restores the budget **as it was before `poll_proceed`**.
/// Therefore, if the budget is _further_ adjusted between when `poll_proceed` returns and
/// `RestRestoreOnPending` is dropped, those adjustments are erased unless the caller indicates
/// that progress was made.
#[inline]
pub(crate) fn poll_proceed(cx: &mut Context<'_>) -> Poll<()> {
pub(crate) fn poll_proceed(cx: &mut Context<'_>) -> Poll<RestoreOnPending> {
CURRENT.with(|cell| {
let mut budget = cell.get();

if budget.decrement() {
let restore = RestoreOnPending(Cell::new(cell.get().if_dynamic()));
jonhoo marked this conversation as resolved.
Show resolved Hide resolved
cell.set(budget);
Poll::Ready(())
Poll::Ready(restore)
} else {
cx.waker().wake_by_ref();
Poll::Pending
Expand All @@ -181,7 +212,15 @@ cfg_coop! {
} else {
true
}
}
}

fn if_dynamic(self) -> Option<Self> {
if self.0.is_some() {
Some(self)
} else {
None
}
}
}
}

Expand All @@ -200,21 +239,41 @@ mod test {

assert!(get().0.is_none());

assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx)));
let coop = assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx)));

assert!(get().0.is_none());
drop(coop);
assert!(get().0.is_none());

budget(|| {
assert_eq!(get().0, Budget::initial().0);
assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx)));

let coop = assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx)));
assert_eq!(get().0.unwrap(), Budget::initial().0.unwrap() - 1);
assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx)));
drop(coop);
// we didn't make progress
assert_eq!(get().0, Budget::initial().0);

let coop = assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx)));
assert_eq!(get().0.unwrap(), Budget::initial().0.unwrap() - 1);
coop.made_progress();
drop(coop);
// we _did_ make progress
assert_eq!(get().0.unwrap(), Budget::initial().0.unwrap() - 1);

let coop = assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx)));
assert_eq!(get().0.unwrap(), Budget::initial().0.unwrap() - 2);
coop.made_progress();
drop(coop);
assert_eq!(get().0.unwrap(), Budget::initial().0.unwrap() - 2);

budget(|| {
assert_eq!(get().0, Budget::initial().0);

assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx)));
let coop = assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx)));
assert_eq!(get().0.unwrap(), Budget::initial().0.unwrap() - 1);
coop.made_progress();
drop(coop);
assert_eq!(get().0.unwrap(), Budget::initial().0.unwrap() - 1);
});

Expand All @@ -227,11 +286,13 @@ mod test {
let n = get().0.unwrap();

for _ in 0..n {
assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx)));
let coop = assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx)));
coop.made_progress();
}

let mut task = task::spawn(poll_fn(|cx| {
ready!(poll_proceed(cx));
let coop = ready!(poll_proceed(cx));
coop.made_progress();
Poll::Ready(())
}));

Expand Down
24 changes: 18 additions & 6 deletions tokio/src/io/registration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,11 +177,17 @@ impl Registration {
/// This function will panic if called from outside of a task context.
pub fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<mio::Ready>> {
// Keep track of task budget
ready!(crate::coop::poll_proceed(cx));
let coop = ready!(crate::coop::poll_proceed(cx));

let v = self.poll_ready(Direction::Read, Some(cx))?;
let v = self.poll_ready(Direction::Read, Some(cx)).map_err(|e| {
coop.made_progress();
e
})?;
match v {
Some(v) => Poll::Ready(Ok(v)),
Some(v) => {
coop.made_progress();
Poll::Ready(Ok(v))
}
None => Poll::Pending,
}
}
Expand Down Expand Up @@ -231,11 +237,17 @@ impl Registration {
/// This function will panic if called from outside of a task context.
pub fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<mio::Ready>> {
// Keep track of task budget
ready!(crate::coop::poll_proceed(cx));
let coop = ready!(crate::coop::poll_proceed(cx));

let v = self.poll_ready(Direction::Write, Some(cx))?;
let v = self.poll_ready(Direction::Write, Some(cx)).map_err(|e| {
coop.made_progress();
e
})?;
match v {
Some(v) => Poll::Ready(Ok(v)),
Some(v) => {
coop.made_progress();
Poll::Ready(Ok(v))
}
None => Poll::Pending,
}
}
Expand Down
6 changes: 5 additions & 1 deletion tokio/src/process/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -708,7 +708,7 @@ where

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
// Keep track of task budget
ready!(crate::coop::poll_proceed(cx));
let coop = ready!(crate::coop::poll_proceed(cx));

let ret = Pin::new(&mut self.inner).poll(cx);

Expand All @@ -717,6 +717,10 @@ where
self.kill_on_drop = false;
}

if ret.is_ready() {
coop.made_progress();
}

ret
}
}
Expand Down
6 changes: 5 additions & 1 deletion tokio/src/runtime/task/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ impl<T> Future for JoinHandle<T> {
let mut ret = Poll::Pending;

// Keep track of task budget
ready!(crate::coop::poll_proceed(cx));
let coop = ready!(crate::coop::poll_proceed(cx));

// Raw should always be set. If it is not, this is due to polling after
// completion
Expand All @@ -126,6 +126,10 @@ impl<T> Future for JoinHandle<T> {
raw.try_read_output(&mut ret as *mut _ as *mut (), cx.waker());
}

if ret.is_ready() {
coop.made_progress();
}

ret
}
}
Expand Down
3 changes: 2 additions & 1 deletion tokio/src/stream/iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ where
type Item = I::Item;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<I::Item>> {
ready!(crate::coop::poll_proceed(cx));
let coop = ready!(crate::coop::poll_proceed(cx));
coop.made_progress();
Poll::Ready(self.iter.next())
}

Expand Down
3 changes: 2 additions & 1 deletion tokio/src/sync/batch_semaphore.rs
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ impl Future for Acquire<'_> {

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
// First, ensure the current task has enough budget to proceed.
ready!(crate::coop::poll_proceed(cx));
let coop = ready!(crate::coop::poll_proceed(cx));

let (node, semaphore, needed, queued) = self.project();

Expand All @@ -399,6 +399,7 @@ impl Future for Acquire<'_> {
Pending
}
Ready(r) => {
coop.made_progress();
r?;
*queued = false;
Ready(Ok(()))
Expand Down
11 changes: 9 additions & 2 deletions tokio/src/sync/mpsc/chan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ where
use super::block::Read::*;

// Keep track of task budget
ready!(crate::coop::poll_proceed(cx));
let coop = ready!(crate::coop::poll_proceed(cx));

self.inner.rx_fields.with_mut(|rx_fields_ptr| {
let rx_fields = unsafe { &mut *rx_fields_ptr };
Expand All @@ -287,6 +287,7 @@ where
match rx_fields.list.pop(&self.inner.tx) {
Some(Value(value)) => {
self.inner.semaphore.add_permit();
coop.made_progress();
return Ready(Some(value));
}
Some(Closed) => {
Expand All @@ -297,6 +298,7 @@ where
// which ensures that if dropping the tx handle is
// visible, then all messages sent are also visible.
assert!(self.inner.semaphore.is_idle());
coop.made_progress();
return Ready(None);
}
None => {} // fall through
Expand All @@ -314,6 +316,7 @@ where
try_recv!();

if rx_fields.rx_closed && self.inner.semaphore.is_idle() {
coop.made_progress();
Ready(None)
} else {
Pending
Expand Down Expand Up @@ -439,11 +442,15 @@ impl Semaphore for (crate::sync::semaphore_ll::Semaphore, usize) {
permit: &mut Permit,
) -> Poll<Result<(), ClosedError>> {
// Keep track of task budget
ready!(crate::coop::poll_proceed(cx));
let coop = ready!(crate::coop::poll_proceed(cx));

permit
.poll_acquire(cx, 1, &self.0)
.map_err(|_| ClosedError::new())
.map(move |r| {
coop.made_progress();
r
})
}

fn try_acquire(&self, permit: &mut Permit) -> Result<(), TrySendError> {
Expand Down
11 changes: 9 additions & 2 deletions tokio/src/sync/oneshot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,13 +197,14 @@ impl<T> Sender<T> {
#[doc(hidden)] // TODO: remove
pub fn poll_closed(&mut self, cx: &mut Context<'_>) -> Poll<()> {
// Keep track of task budget
ready!(crate::coop::poll_proceed(cx));
let coop = ready!(crate::coop::poll_proceed(cx));

let inner = self.inner.as_ref().unwrap();

let mut state = State::load(&inner.state, Acquire);

if state.is_closed() {
coop.made_progress();
return Poll::Ready(());
}

Expand All @@ -216,6 +217,7 @@ impl<T> Sender<T> {
if state.is_closed() {
// Set the flag again so that the waker is released in drop
State::set_tx_task(&inner.state);
coop.made_progress();
return Ready(());
} else {
unsafe { inner.drop_tx_task() };
Expand All @@ -233,6 +235,7 @@ impl<T> Sender<T> {
state = State::set_tx_task(&inner.state);

if state.is_closed() {
coop.made_progress();
return Ready(());
}
}
Expand Down Expand Up @@ -548,17 +551,19 @@ impl<T> Inner<T> {

fn poll_recv(&self, cx: &mut Context<'_>) -> Poll<Result<T, RecvError>> {
// Keep track of task budget
ready!(crate::coop::poll_proceed(cx));
let coop = ready!(crate::coop::poll_proceed(cx));

// Load the state
let mut state = State::load(&self.state, Acquire);

if state.is_complete() {
coop.made_progress();
match unsafe { self.consume_value() } {
Some(value) => Ready(Ok(value)),
None => Ready(Err(RecvError(()))),
}
} else if state.is_closed() {
coop.made_progress();
Ready(Err(RecvError(())))
} else {
if state.is_rx_task_set() {
Expand All @@ -572,6 +577,7 @@ impl<T> Inner<T> {
// Set the flag again so that the waker is released in drop
State::set_rx_task(&self.state);

coop.made_progress();
return match unsafe { self.consume_value() } {
Some(value) => Ready(Ok(value)),
None => Ready(Err(RecvError(()))),
Expand All @@ -592,6 +598,7 @@ impl<T> Inner<T> {
state = State::set_rx_task(&self.state);

if state.is_complete() {
coop.made_progress();
match unsafe { self.consume_value() } {
Some(value) => Ready(Ok(value)),
None => Ready(Err(RecvError(()))),
Expand Down
7 changes: 5 additions & 2 deletions tokio/src/time/driver/registration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,12 @@ impl Registration {

pub(crate) fn poll_elapsed(&self, cx: &mut task::Context<'_>) -> Poll<Result<(), Error>> {
// Keep track of task budget
ready!(crate::coop::poll_proceed(cx));
let coop = ready!(crate::coop::poll_proceed(cx));

self.entry.poll_elapsed(cx)
self.entry.poll_elapsed(cx).map(move |r| {
coop.made_progress();
r
})
}
}

Expand Down