From f24b9824e67f833bc78a5a08527cb48a8d053c66 Mon Sep 17 00:00:00 2001 From: Hayden Stainsby Date: Tue, 18 Jul 2023 09:43:25 +0200 Subject: [PATCH] rt: use optional non-zero value for task `owner_id` (#5876) We switch to using a `NonZeroU64` for the `id` field for `OwnedTasks` and `LocalOwnedTasks` lists. This allows the task header to contain an `Option` instead of a `u64` with a special meaning for 0. The size in memory will be the same thanks to Rust's niche optimization, but this solution is clearer in its intent. Co-authored-by: Alice Ryhl --- tokio/src/runtime/task/core.rs | 15 +++++++------ tokio/src/runtime/task/list.rs | 39 +++++++++++++++------------------- 2 files changed, 25 insertions(+), 29 deletions(-) diff --git a/tokio/src/runtime/task/core.rs b/tokio/src/runtime/task/core.rs index dbaa330937e..d62ea965659 100644 --- a/tokio/src/runtime/task/core.rs +++ b/tokio/src/runtime/task/core.rs @@ -17,6 +17,7 @@ use crate::runtime::task::state::State; use crate::runtime::task::{Id, Schedule}; use crate::util::linked_list; +use std::num::NonZeroU64; use std::pin::Pin; use std::ptr::NonNull; use std::task::{Context, Poll, Waker}; @@ -164,7 +165,7 @@ pub(crate) struct Header { /// This integer contains the id of the OwnedTasks or LocalOwnedTasks that /// this task is stored in. If the task is not in any list, should be the - /// id of the list that it was previously in, or zero if it has never been + /// id of the list that it was previously in, or `None` if it has never been /// in any list. /// /// Once a task has been bound to a list, it can never be bound to another @@ -173,7 +174,7 @@ pub(crate) struct Header { /// The id is not unset when removed from a list because we want to be able /// to read the id without synchronization, even if it is concurrently being /// removed from the list. - pub(super) owner_id: UnsafeCell, + pub(super) owner_id: UnsafeCell>, /// The tracing ID for this instrumented task. #[cfg(all(tokio_unstable, feature = "tracing"))] @@ -221,7 +222,7 @@ impl Cell { state, queue_next: UnsafeCell::new(None), vtable, - owner_id: UnsafeCell::new(0), + owner_id: UnsafeCell::new(None), #[cfg(all(tokio_unstable, feature = "tracing"))] tracing_id, } @@ -394,13 +395,13 @@ impl Header { } // safety: The caller must guarantee exclusive access to this field, and - // must ensure that the id is either 0 or the id of the OwnedTasks + // must ensure that the id is either `None` or the id of the OwnedTasks // containing this task. - pub(super) unsafe fn set_owner_id(&self, owner: u64) { - self.owner_id.with_mut(|ptr| *ptr = owner); + pub(super) unsafe fn set_owner_id(&self, owner: NonZeroU64) { + self.owner_id.with_mut(|ptr| *ptr = Some(owner)); } - pub(super) fn get_owner_id(&self) -> u64 { + pub(super) fn get_owner_id(&self) -> Option { // safety: If there are concurrent writes, then that write has violated // the safety requirements on `set_owner_id`. unsafe { self.owner_id.with(|ptr| *ptr) } diff --git a/tokio/src/runtime/task/list.rs b/tokio/src/runtime/task/list.rs index 930a0099d3d..1c32a1ef361 100644 --- a/tokio/src/runtime/task/list.rs +++ b/tokio/src/runtime/task/list.rs @@ -13,10 +13,11 @@ use crate::runtime::task::{JoinHandle, LocalNotified, Notified, Schedule, Task}; use crate::util::linked_list::{CountedLinkedList, Link, LinkedList}; use std::marker::PhantomData; +use std::num::NonZeroU64; // The id from the module below is used to verify whether a given task is stored // in this OwnedTasks, or some other task. The counter starts at one so we can -// use zero for tasks not owned by any list. +// use `None` for tasks not owned by any list. // // The safety checks in this file can technically be violated if the counter is // overflown, but the checks are not supposed to ever fail unless there is a @@ -28,10 +29,10 @@ cfg_has_atomic_u64! { static NEXT_OWNED_TASKS_ID: AtomicU64 = AtomicU64::new(1); - fn get_next_id() -> u64 { + fn get_next_id() -> NonZeroU64 { loop { let id = NEXT_OWNED_TASKS_ID.fetch_add(1, Ordering::Relaxed); - if id != 0 { + if let Some(id) = NonZeroU64::new(id) { return id; } } @@ -43,11 +44,11 @@ cfg_not_has_atomic_u64! { static NEXT_OWNED_TASKS_ID: AtomicU32 = AtomicU32::new(1); - fn get_next_id() -> u64 { + fn get_next_id() -> NonZeroU64 { loop { let id = NEXT_OWNED_TASKS_ID.fetch_add(1, Ordering::Relaxed); - if id != 0 { - return u64::from(id); + if let Some(id) = NonZeroU64::new(u64::from(id)) { + return id; } } } @@ -55,7 +56,7 @@ cfg_not_has_atomic_u64! { pub(crate) struct OwnedTasks { inner: Mutex>, - id: u64, + id: NonZeroU64, } struct CountedOwnedTasksInner { list: CountedLinkedList, as Link>::Target>, @@ -63,7 +64,7 @@ struct CountedOwnedTasksInner { } pub(crate) struct LocalOwnedTasks { inner: UnsafeCell>, - id: u64, + id: NonZeroU64, _not_send_or_sync: PhantomData<*const ()>, } struct OwnedTasksInner { @@ -127,7 +128,7 @@ impl OwnedTasks { /// a LocalNotified, giving the thread permission to poll this task. #[inline] pub(crate) fn assert_owner(&self, task: Notified) -> LocalNotified { - assert_eq!(task.header().get_owner_id(), self.id); + assert_eq!(task.header().get_owner_id(), Some(self.id)); // safety: All tasks bound to this OwnedTasks are Send, so it is safe // to poll it on this thread no matter what thread we are on. @@ -170,11 +171,9 @@ impl OwnedTasks { } pub(crate) fn remove(&self, task: &Task) -> Option> { - let task_id = task.header().get_owner_id(); - if task_id == 0 { - // The task is unowned. - return None; - } + // If the task's owner ID is `None` then it is not part of any list and + // doesn't need removing. + let task_id = task.header().get_owner_id()?; assert_eq!(task_id, self.id); @@ -257,11 +256,9 @@ impl LocalOwnedTasks { } pub(crate) fn remove(&self, task: &Task) -> Option> { - let task_id = task.header().get_owner_id(); - if task_id == 0 { - // The task is unowned. - return None; - } + // If the task's owner ID is `None` then it is not part of any list and + // doesn't need removing. + let task_id = task.header().get_owner_id()?; assert_eq!(task_id, self.id); @@ -275,7 +272,7 @@ impl LocalOwnedTasks { /// it to a LocalNotified, giving the thread permission to poll this task. #[inline] pub(crate) fn assert_owner(&self, task: Notified) -> LocalNotified { - assert_eq!(task.header().get_owner_id(), self.id); + assert_eq!(task.header().get_owner_id(), Some(self.id)); // safety: The task was bound to this LocalOwnedTasks, and the // LocalOwnedTasks is not Send or Sync, so we are on the right thread @@ -315,11 +312,9 @@ mod tests { #[test] fn test_id_not_broken() { let mut last_id = get_next_id(); - assert_ne!(last_id, 0); for _ in 0..1000 { let next_id = get_next_id(); - assert_ne!(next_id, 0); assert!(last_id < next_id); last_id = next_id; }