Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rt: use optional non-zero value for task owner_id #5876

Merged
merged 5 commits into from
Jul 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions tokio/src/runtime/task/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use crate::runtime::task::state::State;
use crate::runtime::task::{Id, Schedule};
use crate::util::linked_list;

use std::num::NonZeroU64;
use std::pin::Pin;
use std::ptr::NonNull;
use std::task::{Context, Poll, Waker};
Expand Down Expand Up @@ -164,7 +165,7 @@ pub(crate) struct Header {

/// This integer contains the id of the OwnedTasks or LocalOwnedTasks that
/// this task is stored in. If the task is not in any list, should be the
/// id of the list that it was previously in, or zero if it has never been
/// id of the list that it was previously in, or `None` if it has never been
/// in any list.
///
/// Once a task has been bound to a list, it can never be bound to another
Expand All @@ -173,7 +174,7 @@ pub(crate) struct Header {
/// The id is not unset when removed from a list because we want to be able
/// to read the id without synchronization, even if it is concurrently being
/// removed from the list.
pub(super) owner_id: UnsafeCell<u64>,
pub(super) owner_id: UnsafeCell<Option<NonZeroU64>>,

/// The tracing ID for this instrumented task.
#[cfg(all(tokio_unstable, feature = "tracing"))]
Expand Down Expand Up @@ -221,7 +222,7 @@ impl<T: Future, S: Schedule> Cell<T, S> {
state,
queue_next: UnsafeCell::new(None),
vtable,
owner_id: UnsafeCell::new(0),
owner_id: UnsafeCell::new(None),
#[cfg(all(tokio_unstable, feature = "tracing"))]
tracing_id,
}
Expand Down Expand Up @@ -394,13 +395,13 @@ impl Header {
}

// safety: The caller must guarantee exclusive access to this field, and
// must ensure that the id is either 0 or the id of the OwnedTasks
// must ensure that the id is either `None` or the id of the OwnedTasks
// containing this task.
pub(super) unsafe fn set_owner_id(&self, owner: u64) {
self.owner_id.with_mut(|ptr| *ptr = owner);
pub(super) unsafe fn set_owner_id(&self, owner: NonZeroU64) {
self.owner_id.with_mut(|ptr| *ptr = Some(owner));
}

pub(super) fn get_owner_id(&self) -> u64 {
pub(super) fn get_owner_id(&self) -> Option<NonZeroU64> {
// safety: If there are concurrent writes, then that write has violated
// the safety requirements on `set_owner_id`.
unsafe { self.owner_id.with(|ptr| *ptr) }
Expand Down
39 changes: 17 additions & 22 deletions tokio/src/runtime/task/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@ use crate::runtime::task::{JoinHandle, LocalNotified, Notified, Schedule, Task};
use crate::util::linked_list::{CountedLinkedList, Link, LinkedList};

use std::marker::PhantomData;
use std::num::NonZeroU64;

// The id from the module below is used to verify whether a given task is stored
// in this OwnedTasks, or some other task. The counter starts at one so we can
// use zero for tasks not owned by any list.
// use `None` for tasks not owned by any list.
//
// The safety checks in this file can technically be violated if the counter is
// overflown, but the checks are not supposed to ever fail unless there is a
Expand All @@ -28,10 +29,10 @@ cfg_has_atomic_u64! {

static NEXT_OWNED_TASKS_ID: AtomicU64 = AtomicU64::new(1);

fn get_next_id() -> u64 {
fn get_next_id() -> NonZeroU64 {
loop {
let id = NEXT_OWNED_TASKS_ID.fetch_add(1, Ordering::Relaxed);
if id != 0 {
if let Some(id) = NonZeroU64::new(id) {
return id;
}
}
Expand All @@ -43,27 +44,27 @@ cfg_not_has_atomic_u64! {

static NEXT_OWNED_TASKS_ID: AtomicU32 = AtomicU32::new(1);

fn get_next_id() -> u64 {
fn get_next_id() -> NonZeroU64 {
loop {
let id = NEXT_OWNED_TASKS_ID.fetch_add(1, Ordering::Relaxed);
if id != 0 {
return u64::from(id);
if let Some(id) = NonZeroU64::new(u64::from(id)) {
return id;
}
}
}
}

pub(crate) struct OwnedTasks<S: 'static> {
inner: Mutex<CountedOwnedTasksInner<S>>,
id: u64,
id: NonZeroU64,
}
struct CountedOwnedTasksInner<S: 'static> {
list: CountedLinkedList<Task<S>, <Task<S> as Link>::Target>,
closed: bool,
}
pub(crate) struct LocalOwnedTasks<S: 'static> {
inner: UnsafeCell<OwnedTasksInner<S>>,
id: u64,
id: NonZeroU64,
_not_send_or_sync: PhantomData<*const ()>,
}
struct OwnedTasksInner<S: 'static> {
Expand Down Expand Up @@ -127,7 +128,7 @@ impl<S: 'static> OwnedTasks<S> {
/// a LocalNotified, giving the thread permission to poll this task.
#[inline]
pub(crate) fn assert_owner(&self, task: Notified<S>) -> LocalNotified<S> {
assert_eq!(task.header().get_owner_id(), self.id);
assert_eq!(task.header().get_owner_id(), Some(self.id));

// safety: All tasks bound to this OwnedTasks are Send, so it is safe
// to poll it on this thread no matter what thread we are on.
Expand Down Expand Up @@ -170,11 +171,9 @@ impl<S: 'static> OwnedTasks<S> {
}

pub(crate) fn remove(&self, task: &Task<S>) -> Option<Task<S>> {
let task_id = task.header().get_owner_id();
if task_id == 0 {
// The task is unowned.
return None;
}
// If the task's owner ID is `None` then it is not part of any list and
// doesn't need removing.
let task_id = task.header().get_owner_id()?;

assert_eq!(task_id, self.id);

Expand Down Expand Up @@ -257,11 +256,9 @@ impl<S: 'static> LocalOwnedTasks<S> {
}

pub(crate) fn remove(&self, task: &Task<S>) -> Option<Task<S>> {
let task_id = task.header().get_owner_id();
if task_id == 0 {
// The task is unowned.
return None;
}
// If the task's owner ID is `None` then it is not part of any list and
// doesn't need removing.
let task_id = task.header().get_owner_id()?;

assert_eq!(task_id, self.id);

Expand All @@ -275,7 +272,7 @@ impl<S: 'static> LocalOwnedTasks<S> {
/// it to a LocalNotified, giving the thread permission to poll this task.
#[inline]
pub(crate) fn assert_owner(&self, task: Notified<S>) -> LocalNotified<S> {
assert_eq!(task.header().get_owner_id(), self.id);
assert_eq!(task.header().get_owner_id(), Some(self.id));

// safety: The task was bound to this LocalOwnedTasks, and the
// LocalOwnedTasks is not Send or Sync, so we are on the right thread
Expand Down Expand Up @@ -315,11 +312,9 @@ mod tests {
#[test]
fn test_id_not_broken() {
let mut last_id = get_next_id();
assert_ne!(last_id, 0);

for _ in 0..1000 {
let next_id = get_next_id();
assert_ne!(next_id, 0);
assert!(last_id < next_id);
last_id = next_id;
}
Expand Down