From e6e7affc1769723ac41500b7ba01ee74ad16b80d Mon Sep 17 00:00:00 2001 From: Noah Kennedy Date: Fri, 30 Aug 2024 17:52:29 +0100 Subject: [PATCH] rt: add LocalRuntime This change adds LocalRuntime, a new unstable runtime type which cannot be transferred across thread boundaries and supports spawn_local when called from the thread which owns the runtime. The initial set of docs for this are iffy. Documentation is absent right now at the module level, with the docs for the LocalRuntime struct itself being somewhat duplicative of those for the `Runtime` type. This can probably be addressed later as stabilization nears. This API has a few interesting implementation details: - because it was considered beneficial to reuse the same Handle as the normal runtime, it is possible to call spawn_local from a runtime context while on a different thread from the one which drives the runtime and owns it. This forces us to check the thread ID before attempting a local spawn. - An empty LocalOptions struct is passed into the build_local method in order to build the runtime. This will eventually have stuff in it like hooks. Relates to #6739. --- tokio/src/runtime/builder.rs | 72 +++- tokio/src/runtime/handle.rs | 28 +- tokio/src/runtime/local_runtime/mod.rs | 7 + tokio/src/runtime/local_runtime/options.rs | 9 + tokio/src/runtime/local_runtime/runtime.rs | 375 ++++++++++++++++++ tokio/src/runtime/mod.rs | 3 + .../runtime/scheduler/current_thread/mod.rs | 33 ++ tokio/src/runtime/scheduler/mod.rs | 42 ++ tokio/src/runtime/task/list.rs | 20 + tokio/src/task/local.rs | 44 +- 10 files changed, 618 insertions(+), 15 deletions(-) create mode 100644 tokio/src/runtime/local_runtime/mod.rs create mode 100644 tokio/src/runtime/local_runtime/options.rs create mode 100644 tokio/src/runtime/local_runtime/runtime.rs diff --git a/tokio/src/runtime/builder.rs b/tokio/src/runtime/builder.rs index b5bf35d69b4..e6c3bea6a86 100644 --- a/tokio/src/runtime/builder.rs +++ b/tokio/src/runtime/builder.rs @@ -3,11 +3,16 @@ use crate::runtime::handle::Handle; #[cfg(tokio_unstable)] use crate::runtime::TaskMeta; -use crate::runtime::{blocking, driver, Callback, HistogramBuilder, Runtime, TaskCallback}; +use crate::runtime::{ + blocking, driver, Callback, HistogramBuilder, LocalOptions, LocalRuntime, Runtime, TaskCallback, +}; use crate::util::rand::{RngSeed, RngSeedGenerator}; +use crate::runtime::blocking::BlockingPool; +use crate::runtime::scheduler::CurrentThread; use std::fmt; use std::io; +use std::thread::ThreadId; use std::time::Duration; /// Builds Tokio Runtime with custom configuration values. @@ -800,6 +805,29 @@ impl Builder { } } + /// Creates the configured `LocalRuntime`. + /// + /// The returned `LocalRuntime` instance is ready to spawn tasks. + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::Builder; + /// + /// let rt = Builder::new_current_thread().build_local(&mut Default::default()).unwrap(); + /// + /// rt.block_on(async { + /// println!("Hello from the Tokio runtime"); + /// }); + /// ``` + #[allow(unused_variables)] + pub fn build_local(&mut self, options: &mut LocalOptions) -> io::Result { + match &self.kind { + Kind::CurrentThread => self.build_current_thread_local_runtime(), + _ => panic!("Only current_thread is supported when building a local runtime"), + } + } + fn get_cfg(&self, workers: usize) -> driver::Cfg { driver::Cfg { enable_pause_time: match self.kind { @@ -1191,8 +1219,39 @@ impl Builder { } fn build_current_thread_runtime(&mut self) -> io::Result { - use crate::runtime::scheduler::{self, CurrentThread}; - use crate::runtime::{runtime::Scheduler, Config}; + use crate::runtime::runtime::Scheduler; + + let (scheduler, handle, blocking_pool) = + self.build_current_thread_runtime_components(None)?; + + Ok(Runtime::from_parts( + Scheduler::CurrentThread(scheduler), + handle, + blocking_pool, + )) + } + + fn build_current_thread_local_runtime(&mut self) -> io::Result { + use crate::runtime::local_runtime::LocalRuntimeScheduler; + + let tid = std::thread::current().id(); + + let (scheduler, handle, blocking_pool) = + self.build_current_thread_runtime_components(Some(tid))?; + + Ok(LocalRuntime::from_parts( + LocalRuntimeScheduler::CurrentThread(scheduler), + handle, + blocking_pool, + )) + } + + fn build_current_thread_runtime_components( + &mut self, + local_tid: Option, + ) -> io::Result<(CurrentThread, Handle, BlockingPool)> { + use crate::runtime::scheduler; + use crate::runtime::Config; let (driver, driver_handle) = driver::Driver::new(self.get_cfg(1))?; @@ -1227,17 +1286,14 @@ impl Builder { seed_generator: seed_generator_1, metrics_poll_count_histogram: self.metrics_poll_count_histogram_builder(), }, + local_tid, ); let handle = Handle { inner: scheduler::Handle::CurrentThread(handle), }; - Ok(Runtime::from_parts( - Scheduler::CurrentThread(scheduler), - handle, - blocking_pool, - )) + Ok((scheduler, handle, blocking_pool)) } fn metrics_poll_count_histogram_builder(&self) -> Option { diff --git a/tokio/src/runtime/handle.rs b/tokio/src/runtime/handle.rs index 7e3cd1504e5..5fa24b6fd16 100644 --- a/tokio/src/runtime/handle.rs +++ b/tokio/src/runtime/handle.rs @@ -248,8 +248,8 @@ impl Handle { /// # Panics /// /// This function panics if the provided future panics, if called within an - /// asynchronous execution context, or if a timer future is executed on a - /// runtime that has been shut down. + /// asynchronous execution context, or if a timer future is executed on a runtime that has been + /// shut down. /// /// # Examples /// @@ -345,6 +345,30 @@ impl Handle { self.inner.spawn(future, id) } + #[track_caller] + pub(crate) unsafe fn spawn_local_named( + &self, + future: F, + _name: Option<&str>, + ) -> JoinHandle + where + F: Future + 'static, + F::Output: 'static, + { + let id = crate::runtime::task::Id::next(); + #[cfg(all( + tokio_unstable, + tokio_taskdump, + feature = "rt", + target_os = "linux", + any(target_arch = "aarch64", target_arch = "x86", target_arch = "x86_64") + ))] + let future = super::task::trace::Trace::root(future); + #[cfg(all(tokio_unstable, feature = "tracing"))] + let future = crate::util::trace::task(future, "task", _name, id.as_u64()); + self.inner.spawn_local(future, id) + } + /// Returns the flavor of the current `Runtime`. /// /// # Examples diff --git a/tokio/src/runtime/local_runtime/mod.rs b/tokio/src/runtime/local_runtime/mod.rs new file mode 100644 index 00000000000..1ea7693f292 --- /dev/null +++ b/tokio/src/runtime/local_runtime/mod.rs @@ -0,0 +1,7 @@ +mod runtime; + +mod options; + +pub use options::LocalOptions; +pub use runtime::LocalRuntime; +pub(super) use runtime::LocalRuntimeScheduler; diff --git a/tokio/src/runtime/local_runtime/options.rs b/tokio/src/runtime/local_runtime/options.rs new file mode 100644 index 00000000000..1ff0d59b2cd --- /dev/null +++ b/tokio/src/runtime/local_runtime/options.rs @@ -0,0 +1,9 @@ +/// LocalRuntime-only config options +/// +/// Currently, there are no such options, but in the future, things like `!Send + !Sync` hooks may +/// be added. +#[derive(Default, Debug)] +#[non_exhaustive] +pub struct LocalOptions { + // todo add local hooks at a later point +} diff --git a/tokio/src/runtime/local_runtime/runtime.rs b/tokio/src/runtime/local_runtime/runtime.rs new file mode 100644 index 00000000000..277d225b95c --- /dev/null +++ b/tokio/src/runtime/local_runtime/runtime.rs @@ -0,0 +1,375 @@ +#![allow(irrefutable_let_patterns)] + +use crate::runtime::blocking::BlockingPool; +use crate::runtime::scheduler::CurrentThread; +use crate::runtime::{context, Builder, EnterGuard, Handle, BOX_FUTURE_THRESHOLD}; +use crate::task::JoinHandle; + +use std::future::Future; +use std::marker::PhantomData; +use std::time::Duration; + +/// A local Tokio runtime. +/// +/// This runtime is identical to a current_thread [runtime], save for not being `!Send + !Sync`, +/// and supporting spawn_local. +/// +/// For more general information on how to use runtimes, see the [module] docs. +/// +/// [runtime]: crate::runtime::Runtime +/// [module]: crate::runtime +#[derive(Debug)] +#[cfg_attr(docsrs, doc(cfg(tokio_unstable)))] +pub struct LocalRuntime { + /// Task scheduler + scheduler: LocalRuntimeScheduler, + + /// Handle to runtime, also contains driver handles + handle: Handle, + + /// Blocking pool handle, used to signal shutdown + blocking_pool: BlockingPool, + + /// Marker used to make this !Send and !Sync. + _phantom: PhantomData<*mut u8>, +} + +/// The runtime scheduler is always a current_thread scheduler right now. +#[derive(Debug)] +pub(crate) enum LocalRuntimeScheduler { + /// Execute all tasks on the current-thread. + CurrentThread(CurrentThread), +} + +impl LocalRuntime { + pub(crate) fn from_parts( + scheduler: LocalRuntimeScheduler, + handle: Handle, + blocking_pool: BlockingPool, + ) -> LocalRuntime { + LocalRuntime { + scheduler, + handle, + blocking_pool, + _phantom: Default::default(), + } + } + + /// Creates a new local runtime instance with default configuration values. + /// + /// This results in the scheduler, I/O driver, and time driver being + /// initialized. + /// + /// When a more complex configuration is necessary, the [runtime builder] may be used. + /// + /// See [module level][mod] documentation for more details. + /// + /// # Examples + /// + /// Creating a new `LocalRuntime` with default configuration values. + /// + /// ``` + /// use tokio::runtime::LocalRuntime; + /// + /// let rt = LocalRuntime::new() + /// .unwrap(); + /// + /// // Use the runtime... + /// ``` + /// + /// [mod]: crate::runtime + /// [runtime builder]: crate::runtime::Builder + pub fn new() -> std::io::Result { + Builder::new_current_thread() + .enable_all() + .build_local(&mut Default::default()) + } + + /// Returns a handle to the runtime's spawner. + /// + /// The returned handle can be used to spawn tasks that run on this runtime, and can + /// be cloned to allow moving the `Handle` to other threads. + /// + /// Local tasks cannot be spawned on this handle. + /// + /// Calling [`Handle::block_on`] on a handle to a `LocalRuntime` is error-prone. + /// Refer to the documentation of [`Handle::block_on`] for more. + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::LocalRuntime; + /// + /// let rt = LocalRuntime::new() + /// .unwrap(); + /// + /// let handle = rt.handle(); + /// + /// // Use the handle... + /// ``` + pub fn handle(&self) -> &Handle { + &self.handle + } + + /// Spawns a future onto the LocalRuntime. + /// + /// See the documentation for the equivalent method on [Runtime] for more information + /// + /// [Runtime]: crate::runtime::Runtime::spawn + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::LocalRuntime; + /// + /// # fn dox() { + /// // Create the runtime + /// let rt = LocalRuntime::new().unwrap(); + /// + /// // Spawn a future onto the runtime + /// rt.spawn(async { + /// println!("now running on a worker thread"); + /// }); + /// # } + /// ``` + #[track_caller] + pub fn spawn(&self, future: F) -> JoinHandle + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + if cfg!(debug_assertions) && std::mem::size_of::() > BOX_FUTURE_THRESHOLD { + self.handle.spawn_named(Box::pin(future), None) + } else { + self.handle.spawn_named(future, None) + } + } + + /// Spawns a task which isn't `!Send + Sync` on the runtime. + #[track_caller] + pub fn spawn_local(&self, future: F) -> JoinHandle + where + F: Future + 'static, + F::Output: 'static, + { + // safety: spawn_local can only be called from LocalRuntime, which this is + unsafe { + if cfg!(debug_assertions) && std::mem::size_of::() > BOX_FUTURE_THRESHOLD { + self.handle.spawn_local_named(Box::pin(future), None) + } else { + self.handle.spawn_local_named(future, None) + } + } + } + + /// Runs the provided function on an executor dedicated to blocking operations. + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::LocalRuntime; + /// + /// # fn dox() { + /// // Create the runtime + /// let rt = LocalRuntime::new().unwrap(); + /// + /// // Spawn a blocking function onto the runtime + /// rt.spawn_blocking(|| { + /// println!("now running on a worker thread"); + /// }); + /// # } + /// ``` + #[track_caller] + pub fn spawn_blocking(&self, func: F) -> JoinHandle + where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, + { + self.handle.spawn_blocking(func) + } + + /// Runs a future to completion on the Tokio runtime. This is the + /// runtime's entry point. + /// + /// See the documentation for the equivalent method on [Runtime] for more information. + /// + /// [Runtime]: crate::runtime::Runtime::block_on + /// + /// # Examples + /// + /// ```no_run + /// use tokio::runtime::LocalRuntime; + /// + /// // Create the runtime + /// let rt = LocalRuntime::new().unwrap(); + /// + /// // Execute the future, blocking the current thread until completion + /// rt.block_on(async { + /// println!("hello"); + /// }); + /// ``` + #[track_caller] + pub fn block_on(&self, future: F) -> F::Output { + if cfg!(debug_assertions) && std::mem::size_of::() > BOX_FUTURE_THRESHOLD { + self.block_on_inner(Box::pin(future)) + } else { + self.block_on_inner(future) + } + } + + #[track_caller] + fn block_on_inner(&self, future: F) -> F::Output { + #[cfg(all( + tokio_unstable, + tokio_taskdump, + feature = "rt", + target_os = "linux", + any(target_arch = "aarch64", target_arch = "x86", target_arch = "x86_64") + ))] + let future = super::task::trace::Trace::root(future); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let future = crate::util::trace::task( + future, + "block_on", + None, + crate::runtime::task::Id::next().as_u64(), + ); + + let _enter = self.enter(); + + if let LocalRuntimeScheduler::CurrentThread(exec) = &self.scheduler { + exec.block_on(&self.handle.inner, future) + } else { + unreachable!("LocalRuntime only supports current_thread") + } + } + + /// Enters the runtime context. + /// + /// This allows you to construct types that must have an executor + /// available on creation such as [`Sleep`] or [`TcpStream`]. It will + /// also allow you to call methods such as [`tokio::spawn`]. + /// + /// [`Sleep`]: struct@crate::time::Sleep + /// [`TcpStream`]: struct@crate::net::TcpStream + /// [`tokio::spawn`]: fn@crate::spawn + /// + /// # Example + /// + /// ``` + /// use tokio::runtime::LocalRuntime; + /// use tokio::task::JoinHandle; + /// + /// fn function_that_spawns(msg: String) -> JoinHandle<()> { + /// // Had we not used `rt.enter` below, this would panic. + /// tokio::spawn(async move { + /// println!("{}", msg); + /// }) + /// } + /// + /// fn main() { + /// let rt = LocalRuntime::new().unwrap(); + /// + /// let s = "Hello World!".to_string(); + /// + /// // By entering the context, we tie `tokio::spawn` to this executor. + /// let _guard = rt.enter(); + /// let handle = function_that_spawns(s); + /// + /// // Wait for the task before we end the test. + /// rt.block_on(handle).unwrap(); + /// } + /// ``` + pub fn enter(&self) -> EnterGuard<'_> { + self.handle.enter() + } + + /// Shuts down the runtime, waiting for at most `duration` for all spawned + /// work to stop. + /// + /// See the [struct level documentation](LocalRuntime#shutdown) for more details. + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::LocalRuntime; + /// use tokio::task; + /// + /// use std::thread; + /// use std::time::Duration; + /// + /// fn main() { + /// let runtime = LocalRuntime::new().unwrap(); + /// + /// runtime.block_on(async move { + /// task::spawn_blocking(move || { + /// thread::sleep(Duration::from_secs(10_000)); + /// }); + /// }); + /// + /// runtime.shutdown_timeout(Duration::from_millis(100)); + /// } + /// ``` + pub fn shutdown_timeout(mut self, duration: Duration) { + // Wakeup and shutdown all the worker threads + self.handle.inner.shutdown(); + self.blocking_pool.shutdown(Some(duration)); + } + + /// Shuts down the runtime, without waiting for any spawned work to stop. + /// + /// This can be useful if you want to drop a runtime from within another runtime. + /// Normally, dropping a runtime will block indefinitely for spawned blocking tasks + /// to complete, which would normally not be permitted within an asynchronous context. + /// By calling `shutdown_background()`, you can drop the runtime from such a context. + /// + /// Note however, that because we do not wait for any blocking tasks to complete, this + /// may result in a resource leak (in that any blocking tasks are still running until they + /// return. + /// + /// See the [struct level documentation](LocalRuntime#shutdown) for more details. + /// + /// This function is equivalent to calling `shutdown_timeout(Duration::from_nanos(0))`. + /// + /// ``` + /// use tokio::runtime::LocalRuntime; + /// + /// fn main() { + /// let runtime = LocalRuntime::new().unwrap(); + /// + /// runtime.block_on(async move { + /// let inner_runtime = LocalRuntime::new().unwrap(); + /// // ... + /// inner_runtime.shutdown_background(); + /// }); + /// } + /// ``` + pub fn shutdown_background(self) { + self.shutdown_timeout(Duration::from_nanos(0)); + } + + /// Returns a view that lets you get information about how the runtime + /// is performing. + pub fn metrics(&self) -> crate::runtime::RuntimeMetrics { + self.handle.metrics() + } +} + +#[allow(clippy::single_match)] // there are comments in the error branch, so we don't want if-let +impl Drop for LocalRuntime { + fn drop(&mut self) { + if let LocalRuntimeScheduler::CurrentThread(current_thread) = &mut self.scheduler { + // This ensures that tasks spawned on the current-thread + // runtime are dropped inside the runtime's context. + let _guard = context::try_set_current(&self.handle.inner); + current_thread.shutdown(&self.handle.inner); + } else { + unreachable!("LocalRuntime only supports current-thread") + } + } +} + +impl std::panic::UnwindSafe for LocalRuntime {} + +impl std::panic::RefUnwindSafe for LocalRuntime {} diff --git a/tokio/src/runtime/mod.rs b/tokio/src/runtime/mod.rs index 3f2467f6dbc..c8efbe2f1cd 100644 --- a/tokio/src/runtime/mod.rs +++ b/tokio/src/runtime/mod.rs @@ -372,6 +372,9 @@ cfg_rt! { pub use self::builder::UnhandledPanic; pub use crate::util::rand::RngSeed; + + mod local_runtime; + pub use local_runtime::{LocalRuntime, LocalOptions}; } cfg_taskdump! { diff --git a/tokio/src/runtime/scheduler/current_thread/mod.rs b/tokio/src/runtime/scheduler/current_thread/mod.rs index b8cb5b46ca5..06667c88fb5 100644 --- a/tokio/src/runtime/scheduler/current_thread/mod.rs +++ b/tokio/src/runtime/scheduler/current_thread/mod.rs @@ -18,6 +18,7 @@ use std::future::{poll_fn, Future}; use std::sync::atomic::Ordering::{AcqRel, Release}; use std::task::Poll::{Pending, Ready}; use std::task::Waker; +use std::thread::ThreadId; use std::time::Duration; use std::{fmt, thread}; @@ -47,6 +48,9 @@ pub(crate) struct Handle { /// User-supplied hooks to invoke for things pub(crate) task_hooks: TaskHooks, + + /// If this is a LocalRuntime, flags the owning thread ID. + pub(crate) local_tid: Option, } /// Data required for executing the scheduler. The struct is passed around to @@ -127,6 +131,7 @@ impl CurrentThread { blocking_spawner: blocking::Spawner, seed_generator: RngSeedGenerator, config: Config, + local_tid: Option, ) -> (CurrentThread, Arc) { let worker_metrics = WorkerMetrics::from_config(&config); worker_metrics.set_thread_id(thread::current().id()); @@ -152,6 +157,7 @@ impl CurrentThread { driver: driver_handle, blocking_spawner, seed_generator, + local_tid, }); let core = AtomicCell::new(Some(Box::new(Core { @@ -459,6 +465,33 @@ impl Handle { handle } + /// Spawn a task which isn't safe to send across thread boundaries onto the runtime. + /// + /// # Safety + pub(crate) unsafe fn spawn_local( + me: &Arc, + future: F, + id: crate::runtime::task::Id, + ) -> JoinHandle + where + F: crate::future::Future + 'static, + F::Output: 'static, + { + let (handle, notified) = me.shared.owned.bind_local(future, me.clone(), id); + + me.task_hooks.spawn(&TaskMeta { + #[cfg(tokio_unstable)] + id, + _phantom: Default::default(), + }); + + if let Some(notified) = notified { + me.schedule(notified); + } + + handle + } + /// Capture a snapshot of this runtime's state. #[cfg(all( tokio_unstable, diff --git a/tokio/src/runtime/scheduler/mod.rs b/tokio/src/runtime/scheduler/mod.rs index ada8efbad63..749d85525e5 100644 --- a/tokio/src/runtime/scheduler/mod.rs +++ b/tokio/src/runtime/scheduler/mod.rs @@ -113,6 +113,31 @@ cfg_rt! { match_flavor!(self, Handle(h) => &h.blocking_spawner) } + pub(crate) fn is_local(&self) -> bool { + match self { + Handle::CurrentThread(h) => h.local_tid.is_some(), + + #[cfg(feature = "rt-multi-thread")] + Handle::MultiThread(_) => false, + + #[cfg(all(tokio_unstable, feature = "rt-multi-thread"))] + Handle::MultiThreadAlt(_) => false, + } + } + + /// Returns true if this is a local runtime and the runtime is owned by the current thread. + pub(crate) fn can_spawn_local_on_local_runtime(&self) -> bool { + match self { + Handle::CurrentThread(h) => h.local_tid.map(|x| std::thread::current().id() == x).unwrap_or(false), + + #[cfg(feature = "rt-multi-thread")] + Handle::MultiThread(_) => false, + + #[cfg(all(tokio_unstable, feature = "rt-multi-thread"))] + Handle::MultiThreadAlt(_) => false, + } + } + pub(crate) fn spawn(&self, future: F, id: Id) -> JoinHandle where F: Future + Send + 'static, @@ -129,6 +154,23 @@ cfg_rt! { } } + /// Spawn a local task + /// + /// # Safety + /// This should only be called in LocalRuntime if the runtime has been verified to be owned + /// by the current thread. + pub(crate) unsafe fn spawn_local(&self, future: F, id: Id) -> JoinHandle + where + F: Future + 'static, + F::Output: 'static, + { + if let Handle::CurrentThread(h) = self { + current_thread::Handle::spawn_local(h, future, id) + } else { + panic!("Only current_thread and LocalSet have spawn_local internals implemented") + } + } + pub(crate) fn shutdown(&self) { match *self { Handle::CurrentThread(_) => {}, diff --git a/tokio/src/runtime/task/list.rs b/tokio/src/runtime/task/list.rs index 988d422836d..273ab60fb8c 100644 --- a/tokio/src/runtime/task/list.rs +++ b/tokio/src/runtime/task/list.rs @@ -102,6 +102,26 @@ impl OwnedTasks { (join, notified) } + /// Bind a task that isn't safe to transfer across thread boundaries. + /// + /// # Safety + /// Only use this in LocalRuntime where the task cannot move + pub(crate) unsafe fn bind_local( + &self, + task: T, + scheduler: S, + id: super::Id, + ) -> (JoinHandle, Option>) + where + S: Schedule, + T: Future + 'static, + T::Output: 'static, + { + let (task, notified, join) = super::new_task(task, scheduler, id); + let notified = unsafe { self.bind_inner(task, notified) }; + (join, notified) + } + /// The part of `bind` that's the same for every type of future. unsafe fn bind_inner(&self, task: Task, notified: Notified) -> Option> where diff --git a/tokio/src/task/local.rs b/tokio/src/task/local.rs index 90d4d3612e8..da256d78f16 100644 --- a/tokio/src/task/local.rs +++ b/tokio/src/task/local.rs @@ -320,7 +320,7 @@ impl<'a> Drop for LocalDataEnterGuard<'a> { } cfg_rt! { - /// Spawns a `!Send` future on the current [`LocalSet`]. + /// Spawns a `!Send` future on the current [`LocalSet`] or [`LocalRuntime`]. /// /// The spawned future will run on the same thread that called `spawn_local`. /// @@ -360,6 +360,7 @@ cfg_rt! { /// ``` /// /// [`LocalSet`]: struct@crate::task::LocalSet + /// [`LocalRuntime`]: struct@crate::runtime::LocalRuntime /// [`tokio::spawn`]: fn@crate::task::spawn #[track_caller] pub fn spawn_local(future: F) -> JoinHandle @@ -380,10 +381,43 @@ cfg_rt! { where F: Future + 'static, F::Output: 'static { - match CURRENT.with(|LocalData { ctx, .. }| ctx.get()) { - None => panic!("`spawn_local` called from outside of a `task::LocalSet`"), - Some(cx) => cx.spawn(future, name) - } + use crate::runtime::{context, task}; + + let res = context::with_current(|handle| { + Some(if handle.is_local() { + if !handle.can_spawn_local_on_local_runtime() { + return None; + } + #[cfg(all( + tokio_unstable, + tokio_taskdump, + feature = "rt", + target_os = "linux", + any( + target_arch = "aarch64", + target_arch = "x86", + target_arch = "x86_64" + ) + ))] + let future = task::trace::Trace::root(future); + let id = task::Id::next(); + let task = crate::util::trace::task(future, "task", name, id.as_u64()); + + // safety: we have verified that this is a LocalRuntime owned by the current thread + unsafe { handle.spawn_local(task, id) } + } else { + match CURRENT.with(|LocalData { ctx, .. }| ctx.get()) { + None => panic!("`spawn_local` called from outside of a `task::LocalSet` or LocalRuntime"), + Some(cx) => cx.spawn(future, name) + } + }) + }); + + match res { + Ok(None) => panic!("Local tasks can only be spawned on a LocalRuntime from the thread the runtime was created on"), + Ok(Some(join_handle)) => join_handle, + Err(e) => panic!("{}", e), + } } }