Skip to content

Commit

Permalink
rt: use optional non-zero value for task owner_id (#5876)
Browse files Browse the repository at this point in the history
We switch to using a `NonZeroU64` for the `id` field for `OwnedTasks`
and `LocalOwnedTasks` lists. This allows the task header to contain an
`Option<NonZeroU64>` instead of a `u64` with a special meaning for 0.

The size in memory will be the same thanks to Rust's niche optimization,
but this solution is clearer in its intent.

Co-authored-by: Alice Ryhl <[email protected]>
  • Loading branch information
hds and Darksonn authored Jul 18, 2023
1 parent 267a231 commit f24b982
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 29 deletions.
15 changes: 8 additions & 7 deletions tokio/src/runtime/task/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use crate::runtime::task::state::State;
use crate::runtime::task::{Id, Schedule};
use crate::util::linked_list;

use std::num::NonZeroU64;
use std::pin::Pin;
use std::ptr::NonNull;
use std::task::{Context, Poll, Waker};
Expand Down Expand Up @@ -164,7 +165,7 @@ pub(crate) struct Header {

/// This integer contains the id of the OwnedTasks or LocalOwnedTasks that
/// this task is stored in. If the task is not in any list, should be the
/// id of the list that it was previously in, or zero if it has never been
/// id of the list that it was previously in, or `None` if it has never been
/// in any list.
///
/// Once a task has been bound to a list, it can never be bound to another
Expand All @@ -173,7 +174,7 @@ pub(crate) struct Header {
/// The id is not unset when removed from a list because we want to be able
/// to read the id without synchronization, even if it is concurrently being
/// removed from the list.
pub(super) owner_id: UnsafeCell<u64>,
pub(super) owner_id: UnsafeCell<Option<NonZeroU64>>,

/// The tracing ID for this instrumented task.
#[cfg(all(tokio_unstable, feature = "tracing"))]
Expand Down Expand Up @@ -221,7 +222,7 @@ impl<T: Future, S: Schedule> Cell<T, S> {
state,
queue_next: UnsafeCell::new(None),
vtable,
owner_id: UnsafeCell::new(0),
owner_id: UnsafeCell::new(None),
#[cfg(all(tokio_unstable, feature = "tracing"))]
tracing_id,
}
Expand Down Expand Up @@ -394,13 +395,13 @@ impl Header {
}

// safety: The caller must guarantee exclusive access to this field, and
// must ensure that the id is either 0 or the id of the OwnedTasks
// must ensure that the id is either `None` or the id of the OwnedTasks
// containing this task.
pub(super) unsafe fn set_owner_id(&self, owner: u64) {
self.owner_id.with_mut(|ptr| *ptr = owner);
pub(super) unsafe fn set_owner_id(&self, owner: NonZeroU64) {
self.owner_id.with_mut(|ptr| *ptr = Some(owner));
}

pub(super) fn get_owner_id(&self) -> u64 {
pub(super) fn get_owner_id(&self) -> Option<NonZeroU64> {
// safety: If there are concurrent writes, then that write has violated
// the safety requirements on `set_owner_id`.
unsafe { self.owner_id.with(|ptr| *ptr) }
Expand Down
39 changes: 17 additions & 22 deletions tokio/src/runtime/task/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@ use crate::runtime::task::{JoinHandle, LocalNotified, Notified, Schedule, Task};
use crate::util::linked_list::{CountedLinkedList, Link, LinkedList};

use std::marker::PhantomData;
use std::num::NonZeroU64;

// The id from the module below is used to verify whether a given task is stored
// in this OwnedTasks, or some other task. The counter starts at one so we can
// use zero for tasks not owned by any list.
// use `None` for tasks not owned by any list.
//
// The safety checks in this file can technically be violated if the counter is
// overflown, but the checks are not supposed to ever fail unless there is a
Expand All @@ -28,10 +29,10 @@ cfg_has_atomic_u64! {

static NEXT_OWNED_TASKS_ID: AtomicU64 = AtomicU64::new(1);

fn get_next_id() -> u64 {
fn get_next_id() -> NonZeroU64 {
loop {
let id = NEXT_OWNED_TASKS_ID.fetch_add(1, Ordering::Relaxed);
if id != 0 {
if let Some(id) = NonZeroU64::new(id) {
return id;
}
}
Expand All @@ -43,27 +44,27 @@ cfg_not_has_atomic_u64! {

static NEXT_OWNED_TASKS_ID: AtomicU32 = AtomicU32::new(1);

fn get_next_id() -> u64 {
fn get_next_id() -> NonZeroU64 {
loop {
let id = NEXT_OWNED_TASKS_ID.fetch_add(1, Ordering::Relaxed);
if id != 0 {
return u64::from(id);
if let Some(id) = NonZeroU64::new(u64::from(id)) {
return id;
}
}
}
}

pub(crate) struct OwnedTasks<S: 'static> {
inner: Mutex<CountedOwnedTasksInner<S>>,
id: u64,
id: NonZeroU64,
}
struct CountedOwnedTasksInner<S: 'static> {
list: CountedLinkedList<Task<S>, <Task<S> as Link>::Target>,
closed: bool,
}
pub(crate) struct LocalOwnedTasks<S: 'static> {
inner: UnsafeCell<OwnedTasksInner<S>>,
id: u64,
id: NonZeroU64,
_not_send_or_sync: PhantomData<*const ()>,
}
struct OwnedTasksInner<S: 'static> {
Expand Down Expand Up @@ -127,7 +128,7 @@ impl<S: 'static> OwnedTasks<S> {
/// a LocalNotified, giving the thread permission to poll this task.
#[inline]
pub(crate) fn assert_owner(&self, task: Notified<S>) -> LocalNotified<S> {
assert_eq!(task.header().get_owner_id(), self.id);
assert_eq!(task.header().get_owner_id(), Some(self.id));

// safety: All tasks bound to this OwnedTasks are Send, so it is safe
// to poll it on this thread no matter what thread we are on.
Expand Down Expand Up @@ -170,11 +171,9 @@ impl<S: 'static> OwnedTasks<S> {
}

pub(crate) fn remove(&self, task: &Task<S>) -> Option<Task<S>> {
let task_id = task.header().get_owner_id();
if task_id == 0 {
// The task is unowned.
return None;
}
// If the task's owner ID is `None` then it is not part of any list and
// doesn't need removing.
let task_id = task.header().get_owner_id()?;

assert_eq!(task_id, self.id);

Expand Down Expand Up @@ -257,11 +256,9 @@ impl<S: 'static> LocalOwnedTasks<S> {
}

pub(crate) fn remove(&self, task: &Task<S>) -> Option<Task<S>> {
let task_id = task.header().get_owner_id();
if task_id == 0 {
// The task is unowned.
return None;
}
// If the task's owner ID is `None` then it is not part of any list and
// doesn't need removing.
let task_id = task.header().get_owner_id()?;

assert_eq!(task_id, self.id);

Expand All @@ -275,7 +272,7 @@ impl<S: 'static> LocalOwnedTasks<S> {
/// it to a LocalNotified, giving the thread permission to poll this task.
#[inline]
pub(crate) fn assert_owner(&self, task: Notified<S>) -> LocalNotified<S> {
assert_eq!(task.header().get_owner_id(), self.id);
assert_eq!(task.header().get_owner_id(), Some(self.id));

// safety: The task was bound to this LocalOwnedTasks, and the
// LocalOwnedTasks is not Send or Sync, so we are on the right thread
Expand Down Expand Up @@ -315,11 +312,9 @@ mod tests {
#[test]
fn test_id_not_broken() {
let mut last_id = get_next_id();
assert_ne!(last_id, 0);

for _ in 0..1000 {
let next_id = get_next_id();
assert_ne!(next_id, 0);
assert!(last_id < next_id);
last_id = next_id;
}
Expand Down

0 comments on commit f24b982

Please sign in to comment.