diff --git a/tokio/src/coop.rs b/tokio/src/coop.rs index 2aae093afa3..a45b86d7623 100644 --- a/tokio/src/coop.rs +++ b/tokio/src/coop.rs @@ -151,15 +151,49 @@ cfg_blocking_impl! { cfg_coop! { use std::task::{Context, Poll}; + #[must_use] + pub(crate) struct RestoreOnPending(Cell); + + impl RestoreOnPending { + pub(crate) fn made_progress(&self) { + self.0.set(Budget::unconstrained()); + } + } + + impl Drop for RestoreOnPending { + fn drop(&mut self) { + // Don't reset if budget was unconstrained or if we made progress. + // They are both represented as the remembered budget being unconstrained. + let budget = self.0.get(); + if !budget.is_unconstrained() { + CURRENT.with(|cell| { + cell.set(budget); + }); + } + } + } + /// Returns `Poll::Pending` if the current task has exceeded its budget and should yield. + /// + /// When you call this method, the current budget is decremented. However, to ensure that + /// progress is made every time a task is polled, the budget is automatically restored to its + /// former value if the returned `RestoreOnPending` is dropped. It is the caller's + /// responsibility to call `RestoreOnPending::made_progress` if it made progress, to ensure + /// that the budget empties appropriately. + /// + /// Note that `RestoreOnPending` restores the budget **as it was before `poll_proceed`**. + /// Therefore, if the budget is _further_ adjusted between when `poll_proceed` returns and + /// `RestRestoreOnPending` is dropped, those adjustments are erased unless the caller indicates + /// that progress was made. #[inline] - pub(crate) fn poll_proceed(cx: &mut Context<'_>) -> Poll<()> { + pub(crate) fn poll_proceed(cx: &mut Context<'_>) -> Poll { CURRENT.with(|cell| { let mut budget = cell.get(); if budget.decrement() { + let restore = RestoreOnPending(Cell::new(cell.get())); cell.set(budget); - Poll::Ready(()) + Poll::Ready(restore) } else { cx.waker().wake_by_ref(); Poll::Pending @@ -181,7 +215,11 @@ cfg_coop! { } else { true } - } + } + + fn is_unconstrained(self) -> bool { + self.0.is_none() + } } } @@ -200,21 +238,41 @@ mod test { assert!(get().0.is_none()); - assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx))); + let coop = assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx))); + assert!(get().0.is_none()); + drop(coop); assert!(get().0.is_none()); budget(|| { assert_eq!(get().0, Budget::initial().0); - assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx))); + + let coop = assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx))); + assert_eq!(get().0.unwrap(), Budget::initial().0.unwrap() - 1); + drop(coop); + // we didn't make progress + assert_eq!(get().0, Budget::initial().0); + + let coop = assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx))); + assert_eq!(get().0.unwrap(), Budget::initial().0.unwrap() - 1); + coop.made_progress(); + drop(coop); + // we _did_ make progress assert_eq!(get().0.unwrap(), Budget::initial().0.unwrap() - 1); - assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx))); + + let coop = assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx))); + assert_eq!(get().0.unwrap(), Budget::initial().0.unwrap() - 2); + coop.made_progress(); + drop(coop); assert_eq!(get().0.unwrap(), Budget::initial().0.unwrap() - 2); budget(|| { assert_eq!(get().0, Budget::initial().0); - assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx))); + let coop = assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx))); + assert_eq!(get().0.unwrap(), Budget::initial().0.unwrap() - 1); + coop.made_progress(); + drop(coop); assert_eq!(get().0.unwrap(), Budget::initial().0.unwrap() - 1); }); @@ -227,11 +285,13 @@ mod test { let n = get().0.unwrap(); for _ in 0..n { - assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx))); + let coop = assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx))); + coop.made_progress(); } let mut task = task::spawn(poll_fn(|cx| { - ready!(poll_proceed(cx)); + let coop = ready!(poll_proceed(cx)); + coop.made_progress(); Poll::Ready(()) })); diff --git a/tokio/src/io/registration.rs b/tokio/src/io/registration.rs index 6e7d84b4f90..152f19bdba5 100644 --- a/tokio/src/io/registration.rs +++ b/tokio/src/io/registration.rs @@ -177,11 +177,17 @@ impl Registration { /// This function will panic if called from outside of a task context. pub fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll> { // Keep track of task budget - ready!(crate::coop::poll_proceed(cx)); + let coop = ready!(crate::coop::poll_proceed(cx)); - let v = self.poll_ready(Direction::Read, Some(cx))?; + let v = self.poll_ready(Direction::Read, Some(cx)).map_err(|e| { + coop.made_progress(); + e + })?; match v { - Some(v) => Poll::Ready(Ok(v)), + Some(v) => { + coop.made_progress(); + Poll::Ready(Ok(v)) + } None => Poll::Pending, } } @@ -231,11 +237,17 @@ impl Registration { /// This function will panic if called from outside of a task context. pub fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll> { // Keep track of task budget - ready!(crate::coop::poll_proceed(cx)); + let coop = ready!(crate::coop::poll_proceed(cx)); - let v = self.poll_ready(Direction::Write, Some(cx))?; + let v = self.poll_ready(Direction::Write, Some(cx)).map_err(|e| { + coop.made_progress(); + e + })?; match v { - Some(v) => Poll::Ready(Ok(v)), + Some(v) => { + coop.made_progress(); + Poll::Ready(Ok(v)) + } None => Poll::Pending, } } diff --git a/tokio/src/process/mod.rs b/tokio/src/process/mod.rs index ab3dae1820d..647f4368387 100644 --- a/tokio/src/process/mod.rs +++ b/tokio/src/process/mod.rs @@ -708,7 +708,7 @@ where fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { // Keep track of task budget - ready!(crate::coop::poll_proceed(cx)); + let coop = ready!(crate::coop::poll_proceed(cx)); let ret = Pin::new(&mut self.inner).poll(cx); @@ -717,6 +717,10 @@ where self.kill_on_drop = false; } + if ret.is_ready() { + coop.made_progress(); + } + ret } } diff --git a/tokio/src/runtime/task/join.rs b/tokio/src/runtime/task/join.rs index fdcc346e5c1..3c4aabb2e84 100644 --- a/tokio/src/runtime/task/join.rs +++ b/tokio/src/runtime/task/join.rs @@ -102,7 +102,7 @@ impl Future for JoinHandle { let mut ret = Poll::Pending; // Keep track of task budget - ready!(crate::coop::poll_proceed(cx)); + let coop = ready!(crate::coop::poll_proceed(cx)); // Raw should always be set. If it is not, this is due to polling after // completion @@ -126,6 +126,10 @@ impl Future for JoinHandle { raw.try_read_output(&mut ret as *mut _ as *mut (), cx.waker()); } + if ret.is_ready() { + coop.made_progress(); + } + ret } } diff --git a/tokio/src/stream/iter.rs b/tokio/src/stream/iter.rs index d84929d7ec5..bc0388a1442 100644 --- a/tokio/src/stream/iter.rs +++ b/tokio/src/stream/iter.rs @@ -45,7 +45,8 @@ where type Item = I::Item; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - ready!(crate::coop::poll_proceed(cx)); + let coop = ready!(crate::coop::poll_proceed(cx)); + coop.made_progress(); Poll::Ready(self.iter.next()) } diff --git a/tokio/src/sync/batch_semaphore.rs b/tokio/src/sync/batch_semaphore.rs index 29f659a06f7..0a3724b4662 100644 --- a/tokio/src/sync/batch_semaphore.rs +++ b/tokio/src/sync/batch_semaphore.rs @@ -389,7 +389,7 @@ impl Future for Acquire<'_> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { // First, ensure the current task has enough budget to proceed. - ready!(crate::coop::poll_proceed(cx)); + let coop = ready!(crate::coop::poll_proceed(cx)); let (node, semaphore, needed, queued) = self.project(); @@ -399,6 +399,7 @@ impl Future for Acquire<'_> { Pending } Ready(r) => { + coop.made_progress(); r?; *queued = false; Ready(Ok(())) diff --git a/tokio/src/sync/mpsc/chan.rs b/tokio/src/sync/mpsc/chan.rs index 34663957883..148ee3ad766 100644 --- a/tokio/src/sync/mpsc/chan.rs +++ b/tokio/src/sync/mpsc/chan.rs @@ -277,7 +277,7 @@ where use super::block::Read::*; // Keep track of task budget - ready!(crate::coop::poll_proceed(cx)); + let coop = ready!(crate::coop::poll_proceed(cx)); self.inner.rx_fields.with_mut(|rx_fields_ptr| { let rx_fields = unsafe { &mut *rx_fields_ptr }; @@ -287,6 +287,7 @@ where match rx_fields.list.pop(&self.inner.tx) { Some(Value(value)) => { self.inner.semaphore.add_permit(); + coop.made_progress(); return Ready(Some(value)); } Some(Closed) => { @@ -297,6 +298,7 @@ where // which ensures that if dropping the tx handle is // visible, then all messages sent are also visible. assert!(self.inner.semaphore.is_idle()); + coop.made_progress(); return Ready(None); } None => {} // fall through @@ -314,6 +316,7 @@ where try_recv!(); if rx_fields.rx_closed && self.inner.semaphore.is_idle() { + coop.made_progress(); Ready(None) } else { Pending @@ -439,11 +442,15 @@ impl Semaphore for (crate::sync::semaphore_ll::Semaphore, usize) { permit: &mut Permit, ) -> Poll> { // Keep track of task budget - ready!(crate::coop::poll_proceed(cx)); + let coop = ready!(crate::coop::poll_proceed(cx)); permit .poll_acquire(cx, 1, &self.0) .map_err(|_| ClosedError::new()) + .map(move |r| { + coop.made_progress(); + r + }) } fn try_acquire(&self, permit: &mut Permit) -> Result<(), TrySendError> { diff --git a/tokio/src/sync/oneshot.rs b/tokio/src/sync/oneshot.rs index 62ad484eec3..4b033ac3adf 100644 --- a/tokio/src/sync/oneshot.rs +++ b/tokio/src/sync/oneshot.rs @@ -197,13 +197,14 @@ impl Sender { #[doc(hidden)] // TODO: remove pub fn poll_closed(&mut self, cx: &mut Context<'_>) -> Poll<()> { // Keep track of task budget - ready!(crate::coop::poll_proceed(cx)); + let coop = ready!(crate::coop::poll_proceed(cx)); let inner = self.inner.as_ref().unwrap(); let mut state = State::load(&inner.state, Acquire); if state.is_closed() { + coop.made_progress(); return Poll::Ready(()); } @@ -216,6 +217,7 @@ impl Sender { if state.is_closed() { // Set the flag again so that the waker is released in drop State::set_tx_task(&inner.state); + coop.made_progress(); return Ready(()); } else { unsafe { inner.drop_tx_task() }; @@ -233,6 +235,7 @@ impl Sender { state = State::set_tx_task(&inner.state); if state.is_closed() { + coop.made_progress(); return Ready(()); } } @@ -548,17 +551,19 @@ impl Inner { fn poll_recv(&self, cx: &mut Context<'_>) -> Poll> { // Keep track of task budget - ready!(crate::coop::poll_proceed(cx)); + let coop = ready!(crate::coop::poll_proceed(cx)); // Load the state let mut state = State::load(&self.state, Acquire); if state.is_complete() { + coop.made_progress(); match unsafe { self.consume_value() } { Some(value) => Ready(Ok(value)), None => Ready(Err(RecvError(()))), } } else if state.is_closed() { + coop.made_progress(); Ready(Err(RecvError(()))) } else { if state.is_rx_task_set() { @@ -572,6 +577,7 @@ impl Inner { // Set the flag again so that the waker is released in drop State::set_rx_task(&self.state); + coop.made_progress(); return match unsafe { self.consume_value() } { Some(value) => Ready(Ok(value)), None => Ready(Err(RecvError(()))), @@ -592,6 +598,7 @@ impl Inner { state = State::set_rx_task(&self.state); if state.is_complete() { + coop.made_progress(); match unsafe { self.consume_value() } { Some(value) => Ready(Ok(value)), None => Ready(Err(RecvError(()))), diff --git a/tokio/src/time/driver/registration.rs b/tokio/src/time/driver/registration.rs index b77357e7353..3a0b34501b0 100644 --- a/tokio/src/time/driver/registration.rs +++ b/tokio/src/time/driver/registration.rs @@ -40,9 +40,12 @@ impl Registration { pub(crate) fn poll_elapsed(&self, cx: &mut task::Context<'_>) -> Poll> { // Keep track of task budget - ready!(crate::coop::poll_proceed(cx)); + let coop = ready!(crate::coop::poll_proceed(cx)); - self.entry.poll_elapsed(cx) + self.entry.poll_elapsed(cx).map(move |r| { + coop.made_progress(); + r + }) } }