From 97662c611475937bfead418102818159b3218518 Mon Sep 17 00:00:00 2001 From: jtnunley Date: Fri, 10 Mar 2023 13:15:35 -0800 Subject: [PATCH 1/4] Push tasks directly to local queue, take 2 --- src/lib.rs | 152 +++++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 135 insertions(+), 17 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index d8e59ca..151609a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,6 +20,7 @@ #![warn(missing_docs, missing_debug_implementations, rust_2018_idioms)] +use std::cell::RefCell; use std::fmt; use std::future::Future; use std::marker::PhantomData; @@ -229,29 +230,51 @@ impl<'a> Executor<'a> { let runner = Runner::new(self.state()); let mut rng = fastrand::Rng::new(); - // A future that runs tasks forever. - let run_forever = async { - loop { - for _ in 0..200 { - let runnable = runner.runnable(&mut rng).await; - runnable.run(); - } - future::yield_now().await; - } - }; + // Set the local queue while we're running. + LocalQueue::set(&runner.local, { + let runner = &runner; + async move { + // A future that runs tasks forever. + let run_forever = async { + loop { + for _ in 0..200 { + let runnable = runner.runnable(&mut rng).await; + runnable.run(); + } + future::yield_now().await; + } + }; - // Run `future` and `run_forever` concurrently until `future` completes. - future.or(run_forever).await + // Run `future` and `run_forever` concurrently until `future` completes. + future.or(run_forever).await + } + }) + .await } /// Returns a function that schedules a runnable task when it gets woken up. fn schedule(&self) -> impl Fn(Runnable) + Send + Sync + 'static { let state = self.state().clone(); - // TODO(stjepang): If possible, push into the current local queue and notify the ticker. + // If possible, push into the current local queue and notify the ticker. move |runnable| { - state.queue.push(runnable).unwrap(); - state.notify(); + let mut runnable = Some(runnable); + + // Try to push into the local queue. + LocalQueue::with(|local_queue| { + if let Err(e) = local_queue.queue.push(runnable.take().unwrap()) { + runnable = Some(e.into_inner()); + return; + } + + local_queue.waker.wake_by_ref(); + }); + + // If the local queue push failed, just push to the global queue. + if let Some(runnable) = runnable { + state.queue.push(runnable).unwrap(); + state.notify(); + } } } @@ -819,6 +842,92 @@ impl Drop for Runner<'_> { } } +/// The state of the currently running local queue. +struct LocalQueue { + /// The concurrent queue. + queue: Arc>, + + /// The waker for the runnable. + waker: Waker, +} + +impl LocalQueue { + /// Run a function with the current local queue. + fn with(f: impl FnOnce(&LocalQueue) -> R) -> Option { + std::thread_local! { + /// The current local queue. + static LOCAL_QUEUE: RefCell> = RefCell::new(None); + } + + impl LocalQueue { + /// Run a function with a set local queue. + async fn set(queue: &Arc>, fut: F) -> F::Output + where + F: Future, + { + // Store the local queue and the current waker. + let mut old = with_waker(|waker| { + LOCAL_QUEUE.with(move |slot| { + let mut slot = slot.borrow_mut(); + slot.replace(LocalQueue { + queue: queue.clone(), + waker: waker.clone(), + }) + }) + }) + .await; + + // Restore the old local queue on drop. + let _guard = CallOnDrop(move || { + let old = old.take(); + LOCAL_QUEUE.with(move |slot| { + let mut slot = slot.borrow_mut(); + *slot = old; + }); + }); + + // Pin the future. + futures_lite::pin!(fut); + + // Run it such that the waker is updated every time it's polled. + future::poll_fn(move |cx| { + LOCAL_QUEUE + .try_with({ + let waker = cx.waker(); + move |slot| { + let mut slot = slot.borrow_mut(); + let qaw = slot.as_mut().expect("missing local queue"); + + // If we've been replaced, just ignore the slot. + if !Arc::ptr_eq(&qaw.queue, queue) { + return; + } + + // Update the waker, if it has changed. + if !qaw.waker.will_wake(waker) { + qaw.waker = waker.clone(); + } + } + }) + .ok(); + + // Poll the future. + fut.as_mut().poll(cx) + }) + .await + } + } + + LOCAL_QUEUE + .try_with(|local_queue| { + let local_queue = local_queue.borrow(); + local_queue.as_ref().map(f) + }) + .ok() + .flatten() + } +} + /// Steals some items from one queue into another. fn steal(src: &ConcurrentQueue, dest: &ConcurrentQueue) { // Half of `src`'s length rounded up. @@ -911,10 +1020,19 @@ fn debug_executor(executor: &Executor<'_>, name: &str, f: &mut fmt::Formatter<'_ } /// Runs a closure when dropped. -struct CallOnDrop(F); +struct CallOnDrop(F); -impl Drop for CallOnDrop { +impl Drop for CallOnDrop { fn drop(&mut self) { (self.0)(); } } + +/// Run a closure with the current waker. +fn with_waker R, R>(f: F) -> impl Future { + let mut f = Some(f); + future::poll_fn(move |cx| { + let f = f.take().unwrap(); + Poll::Ready(f(cx.waker())) + }) +} From 5d35c373fd4a97a56acd6fc3cd17d770b3c7d906 Mon Sep 17 00:00:00 2001 From: jtnunley Date: Fri, 10 Mar 2023 16:55:30 -0800 Subject: [PATCH 2/4] Conciseness --- src/lib.rs | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 151609a..3330b6a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -868,8 +868,7 @@ impl LocalQueue { // Store the local queue and the current waker. let mut old = with_waker(|waker| { LOCAL_QUEUE.with(move |slot| { - let mut slot = slot.borrow_mut(); - slot.replace(LocalQueue { + slot.borrow_mut().replace(LocalQueue { queue: queue.clone(), waker: waker.clone(), }) @@ -881,8 +880,7 @@ impl LocalQueue { let _guard = CallOnDrop(move || { let old = old.take(); LOCAL_QUEUE.with(move |slot| { - let mut slot = slot.borrow_mut(); - *slot = old; + *slot.borrow_mut() = old; }); }); @@ -919,10 +917,7 @@ impl LocalQueue { } LOCAL_QUEUE - .try_with(|local_queue| { - let local_queue = local_queue.borrow(); - local_queue.as_ref().map(f) - }) + .try_with(|local_queue| local_queue.borrow().as_ref().map(f)) .ok() .flatten() } From e93561c7a69cfa529ad650f1e5b5db633a17a65a Mon Sep 17 00:00:00 2001 From: jtnunley Date: Tue, 14 Mar 2023 09:29:47 -0700 Subject: [PATCH 3/4] Don't use with() in a destructor --- src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib.rs b/src/lib.rs index 3330b6a..1620c1e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -879,7 +879,7 @@ impl LocalQueue { // Restore the old local queue on drop. let _guard = CallOnDrop(move || { let old = old.take(); - LOCAL_QUEUE.with(move |slot| { + let _ = LOCAL_QUEUE.try_with(move |slot| { *slot.borrow_mut() = old; }); }); From e436d2b7f63e0c7b72a62213424854060f936c42 Mon Sep 17 00:00:00 2001 From: jtnunley Date: Wed, 22 Mar 2023 23:52:35 -0700 Subject: [PATCH 4/4] Nip some potentially unsound behavior in the bud --- src/lib.rs | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 1620c1e..904803f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -231,7 +231,7 @@ impl<'a> Executor<'a> { let mut rng = fastrand::Rng::new(); // Set the local queue while we're running. - LocalQueue::set(&runner.local, { + LocalQueue::set(self.state(), &runner.local, { let runner = &runner; async move { // A future that runs tasks forever. @@ -262,6 +262,11 @@ impl<'a> Executor<'a> { // Try to push into the local queue. LocalQueue::with(|local_queue| { + // Make sure that we don't accidentally push to an executor that isn't ours. + if !std::ptr::eq(local_queue.state, &*state) { + return; + } + if let Err(e) = local_queue.queue.push(runnable.take().unwrap()) { runnable = Some(e.into_inner()); return; @@ -844,6 +849,11 @@ impl Drop for Runner<'_> { /// The state of the currently running local queue. struct LocalQueue { + /// The pointer to the state of the executor. + /// + /// Used to make sure we don't push runnables to the wrong executor. + state: *const State, + /// The concurrent queue. queue: Arc>, @@ -861,7 +871,11 @@ impl LocalQueue { impl LocalQueue { /// Run a function with a set local queue. - async fn set(queue: &Arc>, fut: F) -> F::Output + async fn set( + state: &State, + queue: &Arc>, + fut: F, + ) -> F::Output where F: Future, { @@ -869,6 +883,7 @@ impl LocalQueue { let mut old = with_waker(|waker| { LOCAL_QUEUE.with(move |slot| { slot.borrow_mut().replace(LocalQueue { + state: state as *const State, queue: queue.clone(), waker: waker.clone(), })