Skip to content

Commit

Permalink
runtime: minimize the amount of duplicated code (#3416)
Browse files Browse the repository at this point in the history
  • Loading branch information
Markus Westerlind authored Jan 29, 2021
1 parent 1f9765f commit 6f98872
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 104 deletions.
4 changes: 2 additions & 2 deletions tokio/src/runtime/queue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ impl<T> Local<T> {
// tasks and we are the only producer.
self.inner.buffer[i_idx].with_mut(|ptr| unsafe {
let ptr = (*ptr).as_ptr();
(*ptr).header().queue_next.with_mut(|ptr| *ptr = Some(next));
(*ptr).header().set_next(Some(next))
});
}

Expand Down Expand Up @@ -610,7 +610,7 @@ fn get_next(header: NonNull<task::Header>) -> Option<NonNull<task::Header>> {

fn set_next(header: NonNull<task::Header>, val: Option<NonNull<task::Header>>) {
unsafe {
header.as_ref().queue_next.with_mut(|ptr| *ptr = val);
header.as_ref().set_next(val);
}
}

Expand Down
24 changes: 16 additions & 8 deletions tokio/src/runtime/task/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,10 +249,10 @@ impl<T: Future> CoreStage<T> {
///
/// The caller must ensure it is safe to mutate the `stage` field.
pub(super) fn drop_future_or_output(&self) {
self.stage.with_mut(|ptr| {
// Safety: The caller ensures mutal exclusion to the field.
unsafe { *ptr = Stage::Consumed };
});
// Safety: the caller ensures mutual exclusion to the field.
unsafe {
self.set_stage(Stage::Consumed);
}
}

/// Store the task output
Expand All @@ -261,10 +261,10 @@ impl<T: Future> CoreStage<T> {
///
/// The caller must ensure it is safe to mutate the `stage` field.
pub(super) fn store_output(&self, output: super::Result<T::Output>) {
self.stage.with_mut(|ptr| {
// Safety: the caller ensures mutual exclusion to the field.
unsafe { *ptr = Stage::Finished(output) };
});
// Safety: the caller ensures mutual exclusion to the field.
unsafe {
self.set_stage(Stage::Finished(output));
}
}

/// Take the task output
Expand All @@ -283,6 +283,10 @@ impl<T: Future> CoreStage<T> {
}
})
}

unsafe fn set_stage(&self, stage: Stage<T>) {
self.stage.with_mut(|ptr| *ptr = stage)
}
}

cfg_rt_multi_thread! {
Expand All @@ -293,6 +297,10 @@ cfg_rt_multi_thread! {
let task = unsafe { RawTask::from_raw(self.into()) };
task.shutdown();
}

pub(crate) unsafe fn set_next(&self, next: Option<NonNull<Header>>) {
self.queue_next.with_mut(|ptr| *ptr = next);
}
}
}

Expand Down
56 changes: 28 additions & 28 deletions tokio/src/runtime/task/harness.rs
Original file line number Diff line number Diff line change
Expand Up @@ -403,44 +403,44 @@ fn poll_future<T: Future>(
snapshot: Snapshot,
cx: Context<'_>,
) -> PollFuture<T::Output> {
let res = panic::catch_unwind(panic::AssertUnwindSafe(|| {
struct Guard<'a, T: Future> {
core: &'a CoreStage<T>,
}
if snapshot.is_cancelled() {
PollFuture::Complete(Err(JoinError::cancelled()), snapshot.is_join_interested())
} else {
let res = panic::catch_unwind(panic::AssertUnwindSafe(|| {
struct Guard<'a, T: Future> {
core: &'a CoreStage<T>,
}

impl<T: Future> Drop for Guard<'_, T> {
fn drop(&mut self) {
self.core.drop_future_or_output();
impl<T: Future> Drop for Guard<'_, T> {
fn drop(&mut self) {
self.core.drop_future_or_output();
}
}
}

let guard = Guard { core };
let guard = Guard { core };

// If the task is cancelled, avoid polling it, instead signalling it
// is complete.
if snapshot.is_cancelled() {
Poll::Ready(Err(JoinError::cancelled()))
} else {
let res = guard.core.poll(cx);

// prevent the guard from dropping the future
mem::forget(guard);

res.map(Ok)
}
}));
match res {
Ok(Poll::Pending) => match header.state.transition_to_idle() {
Ok(snapshot) => {
if snapshot.is_notified() {
PollFuture::Notified
} else {
PollFuture::None
res
}));
match res {
Ok(Poll::Pending) => match header.state.transition_to_idle() {
Ok(snapshot) => {
if snapshot.is_notified() {
PollFuture::Notified
} else {
PollFuture::None
}
}
Err(_) => PollFuture::Complete(Err(cancel_task(core)), true),
},
Ok(Poll::Ready(ok)) => PollFuture::Complete(Ok(ok), snapshot.is_join_interested()),
Err(err) => {
PollFuture::Complete(Err(JoinError::panic(err)), snapshot.is_join_interested())
}
Err(_) => PollFuture::Complete(Err(cancel_task(core)), true),
},
Ok(Poll::Ready(ok)) => PollFuture::Complete(ok, snapshot.is_join_interested()),
Err(err) => PollFuture::Complete(Err(JoinError::panic(err)), snapshot.is_join_interested()),
}
}
}
120 changes: 54 additions & 66 deletions tokio/src/sync/oneshot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,42 @@ struct Inner<T> {
value: UnsafeCell<Option<T>>,

/// The task to notify when the receiver drops without consuming the value.
tx_task: UnsafeCell<MaybeUninit<Waker>>,
tx_task: Task,

/// The task to notify when the value is sent.
rx_task: UnsafeCell<MaybeUninit<Waker>>,
rx_task: Task,
}

struct Task(UnsafeCell<MaybeUninit<Waker>>);

impl Task {
unsafe fn will_wake(&self, cx: &mut Context<'_>) -> bool {
self.with_task(|w| w.will_wake(cx.waker()))
}

unsafe fn with_task<F, R>(&self, f: F) -> R
where
F: FnOnce(&Waker) -> R,
{
self.0.with(|ptr| {
let waker: *const Waker = (&*ptr).as_ptr();
f(&*waker)
})
}

unsafe fn drop_task(&self) {
self.0.with_mut(|ptr| {
let ptr: *mut Waker = (&mut *ptr).as_mut_ptr();
ptr.drop_in_place();
});
}

unsafe fn set_task(&self, cx: &mut Context<'_>) {
self.0.with_mut(|ptr| {
let ptr: *mut Waker = (&mut *ptr).as_mut_ptr();
ptr.write(cx.waker().clone());
});
}
}

#[derive(Clone, Copy)]
Expand Down Expand Up @@ -127,8 +159,8 @@ pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
let inner = Arc::new(Inner {
state: AtomicUsize::new(State::new().as_usize()),
value: UnsafeCell::new(None),
tx_task: UnsafeCell::new(MaybeUninit::uninit()),
rx_task: UnsafeCell::new(MaybeUninit::uninit()),
tx_task: Task(UnsafeCell::new(MaybeUninit::uninit())),
rx_task: Task(UnsafeCell::new(MaybeUninit::uninit())),
});

let tx = Sender {
Expand Down Expand Up @@ -188,9 +220,9 @@ impl<T> Sender<T> {
});

if !inner.complete() {
return Err(inner
.value
.with_mut(|ptr| unsafe { (*ptr).take() }.unwrap()));
unsafe {
return Err(inner.consume_value().unwrap());
}
}

Ok(())
Expand Down Expand Up @@ -357,7 +389,7 @@ impl<T> Sender<T> {
}

if state.is_tx_task_set() {
let will_notify = unsafe { inner.with_tx_task(|w| w.will_wake(cx.waker())) };
let will_notify = unsafe { inner.tx_task.will_wake(cx) };

if !will_notify {
state = State::unset_tx_task(&inner.state);
Expand All @@ -368,15 +400,15 @@ impl<T> Sender<T> {
coop.made_progress();
return Ready(());
} else {
unsafe { inner.drop_tx_task() };
unsafe { inner.tx_task.drop_task() };
}
}
}

if !state.is_tx_task_set() {
// Attempt to set the task
unsafe {
inner.set_tx_task(cx);
inner.tx_task.set_task(cx);
}

// Update the state
Expand Down Expand Up @@ -584,7 +616,7 @@ impl<T> Inner<T> {
if prev.is_rx_task_set() {
// TODO: Consume waker?
unsafe {
self.with_rx_task(Waker::wake_by_ref);
self.rx_task.with_task(Waker::wake_by_ref);
}
}

Expand All @@ -609,7 +641,7 @@ impl<T> Inner<T> {
Ready(Err(RecvError(())))
} else {
if state.is_rx_task_set() {
let will_notify = unsafe { self.with_rx_task(|w| w.will_wake(cx.waker())) };
let will_notify = unsafe { self.rx_task.will_wake(cx) };

// Check if the task is still the same
if !will_notify {
Expand All @@ -625,15 +657,15 @@ impl<T> Inner<T> {
None => Ready(Err(RecvError(()))),
};
} else {
unsafe { self.drop_rx_task() };
unsafe { self.rx_task.drop_task() };
}
}
}

if !state.is_rx_task_set() {
// Attempt to set the task
unsafe {
self.set_rx_task(cx);
self.rx_task.set_task(cx);
}

// Update the state
Expand All @@ -660,7 +692,7 @@ impl<T> Inner<T> {

if prev.is_tx_task_set() && !prev.is_complete() {
unsafe {
self.with_tx_task(Waker::wake_by_ref);
self.tx_task.with_task(Waker::wake_by_ref);
}
}
}
Expand All @@ -669,72 +701,28 @@ impl<T> Inner<T> {
unsafe fn consume_value(&self) -> Option<T> {
self.value.with_mut(|ptr| (*ptr).take())
}

unsafe fn with_rx_task<F, R>(&self, f: F) -> R
where
F: FnOnce(&Waker) -> R,
{
self.rx_task.with(|ptr| {
let waker: *const Waker = (&*ptr).as_ptr();
f(&*waker)
})
}

unsafe fn with_tx_task<F, R>(&self, f: F) -> R
where
F: FnOnce(&Waker) -> R,
{
self.tx_task.with(|ptr| {
let waker: *const Waker = (&*ptr).as_ptr();
f(&*waker)
})
}

unsafe fn drop_rx_task(&self) {
self.rx_task.with_mut(|ptr| {
let ptr: *mut Waker = (&mut *ptr).as_mut_ptr();
ptr.drop_in_place();
});
}

unsafe fn drop_tx_task(&self) {
self.tx_task.with_mut(|ptr| {
let ptr: *mut Waker = (&mut *ptr).as_mut_ptr();
ptr.drop_in_place();
});
}

unsafe fn set_rx_task(&self, cx: &mut Context<'_>) {
self.rx_task.with_mut(|ptr| {
let ptr: *mut Waker = (&mut *ptr).as_mut_ptr();
ptr.write(cx.waker().clone());
});
}

unsafe fn set_tx_task(&self, cx: &mut Context<'_>) {
self.tx_task.with_mut(|ptr| {
let ptr: *mut Waker = (&mut *ptr).as_mut_ptr();
ptr.write(cx.waker().clone());
});
}
}

unsafe impl<T: Send> Send for Inner<T> {}
unsafe impl<T: Send> Sync for Inner<T> {}

fn mut_load(this: &mut AtomicUsize) -> usize {
this.with_mut(|v| *v)
}

impl<T> Drop for Inner<T> {
fn drop(&mut self) {
let state = State(self.state.with_mut(|v| *v));
let state = State(mut_load(&mut self.state));

if state.is_rx_task_set() {
unsafe {
self.drop_rx_task();
self.rx_task.drop_task();
}
}

if state.is_tx_task_set() {
unsafe {
self.drop_tx_task();
self.tx_task.drop_task();
}
}
}
Expand Down

0 comments on commit 6f98872

Please sign in to comment.