From 52fb93dce9b79354fc1d48d80054c76965c5b08b Mon Sep 17 00:00:00 2001 From: Toby Lawrence Date: Wed, 9 Feb 2022 12:09:04 -0500 Subject: [PATCH] sync: refactored `PollSender` to fix a subtly broken `Sink` implementation (#4214) Signed-off-by: Toby Lawrence --- tokio-util/src/sync/mod.rs | 2 +- tokio-util/src/sync/mpsc.rs | 350 +++++++++++++++++++++--------------- tokio-util/tests/mpsc.rs | 200 ++++++++++++++++++--- 3 files changed, 379 insertions(+), 173 deletions(-) diff --git a/tokio-util/src/sync/mod.rs b/tokio-util/src/sync/mod.rs index 0b78a156cf3..9d2f40c32a2 100644 --- a/tokio-util/src/sync/mod.rs +++ b/tokio-util/src/sync/mod.rs @@ -6,7 +6,7 @@ pub use cancellation_token::{guard::DropGuard, CancellationToken, WaitForCancell mod intrusive_double_linked_list; mod mpsc; -pub use mpsc::PollSender; +pub use mpsc::{PollSendError, PollSender}; mod poll_semaphore; pub use poll_semaphore::PollSemaphore; diff --git a/tokio-util/src/sync/mpsc.rs b/tokio-util/src/sync/mpsc.rs index 7c79c98e5d0..ebaa4ae14e3 100644 --- a/tokio-util/src/sync/mpsc.rs +++ b/tokio-util/src/sync/mpsc.rs @@ -1,221 +1,283 @@ -use futures_core::ready; use futures_sink::Sink; use std::pin::Pin; -use std::sync::Arc; use std::task::{Context, Poll}; -use tokio::sync::mpsc::{error::SendError, Sender}; +use std::{fmt, mem}; +use tokio::sync::mpsc::OwnedPermit; +use tokio::sync::mpsc::Sender; use super::ReusableBoxFuture; -// This implementation was chosen over something based on permits because to get a -// `tokio::sync::mpsc::Permit` out of the `inner` future, you must transmute the -// lifetime on the permit to `'static`. +/// Error returned by the `PollSender` when the channel is closed. +#[derive(Debug)] +pub struct PollSendError(Option); + +impl PollSendError { + /// Consumes the stored value, if any. + /// + /// If this error was encountered when calling `start_send`/`send_item`, this will be the item + /// that the caller attempted to send. Otherwise, it will be `None`. 
+ pub fn into_inner(self) -> Option { + self.0 + } +} + +impl fmt::Display for PollSendError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "channel closed") + } +} + +impl std::error::Error for PollSendError {} + +#[derive(Debug)] +enum State { + Idle(Sender), + Acquiring, + ReadyToSend(OwnedPermit), + Closed, +} /// A wrapper around [`mpsc::Sender`] that can be polled. /// /// [`mpsc::Sender`]: tokio::sync::mpsc::Sender #[derive(Debug)] pub struct PollSender { - /// is none if closed - sender: Option>>, - is_sending: bool, - inner: ReusableBoxFuture>>, + sender: Option>, + state: State, + acquire: ReusableBoxFuture, PollSendError>>, } -// By reusing the same async fn for both Some and None, we make sure every -// future passed to ReusableBoxFuture has the same underlying type, and hence -// the same size and alignment. -async fn make_future(data: Option<(Arc>, T)>) -> Result<(), SendError> { +// Creates a future for acquiring a permit from the underlying channel. This is used to ensure +// there's capacity for a send to complete. +// +// By reusing the same async fn for both `Some` and `None`, we make sure every future passed to +// ReusableBoxFuture has the same underlying type, and hence the same size and alignment. +async fn make_acquire_future( + data: Option>, +) -> Result, PollSendError> { match data { - Some((sender, value)) => sender.send(value).await, - None => unreachable!( - "This future should not be pollable, as is_sending should be set to false." - ), + Some(sender) => sender + .reserve_owned() + .await + .map_err(|_| PollSendError(None)), + None => unreachable!("this future should not be pollable in this state"), } } impl PollSender { - /// Create a new `PollSender`. + /// Creates a new `PollSender`. 
pub fn new(sender: Sender) -> Self { Self { - sender: Some(Arc::new(sender)), - is_sending: false, - inner: ReusableBoxFuture::new(make_future(None)), + sender: Some(sender.clone()), + state: State::Idle(sender), + acquire: ReusableBoxFuture::new(make_acquire_future(None)), } } - /// Start sending a new item. - /// - /// This method panics if a send is currently in progress. To ensure that no - /// send is in progress, call `poll_send_done` first until it returns - /// `Poll::Ready`. - /// - /// If this method returns an error, that indicates that the channel is - /// closed. Note that this method is not guaranteed to return an error if - /// the channel is closed, but in that case the error would be reported by - /// the first call to `poll_send_done`. - pub fn start_send(&mut self, value: T) -> Result<(), SendError> { - if self.is_sending { - panic!("start_send called while not ready."); - } - match self.sender.clone() { - Some(sender) => { - self.inner.set(make_future(Some((sender, value)))); - self.is_sending = true; - Ok(()) - } - None => Err(SendError(value)), - } + fn take_state(&mut self) -> State { + mem::replace(&mut self.state, State::Closed) } - /// If a send is in progress, poll for its completion. If no send is in progress, - /// this method returns `Poll::Ready(Ok(()))`. + /// Attempts to prepare the sender to receive a value. /// - /// This method can return the following values: + /// This method must be called and return `Poll::Ready(Ok(()))` prior to each call to + /// `send_item`. /// - /// - `Poll::Ready(Ok(()))` if the in-progress send has been completed, or there is - /// no send in progress (even if the channel is closed). - /// - `Poll::Ready(Err(err))` if the in-progress send failed because the channel has - /// been closed. - /// - `Poll::Pending` if a send is in progress, but it could not complete now. 
+ /// This method returns `Poll::Ready` once the underlying channel is ready to receive a value, + /// by reserving a slot in the channel for the item to be sent. If this method returns + /// `Poll::Pending`, the current task is registered to be notified (via + /// `cx.waker().wake_by_ref()`) when `poll_reserve` should be called again. /// - /// When this method returns `Poll::Pending`, the current task is scheduled - /// to receive a wakeup when the message is sent, or when the entire channel - /// is closed (but not if just this sender is closed by - /// `close_this_sender`). Note that on multiple calls to `poll_send_done`, - /// only the `Waker` from the `Context` passed to the most recent call is - /// scheduled to receive a wakeup. + /// # Errors /// - /// If this method returns `Poll::Ready`, then `start_send` is guaranteed to - /// not panic. - pub fn poll_send_done(&mut self, cx: &mut Context<'_>) -> Poll>> { - if !self.is_sending { - return Poll::Ready(Ok(())); - } + /// If the channel is closed, an error will be returned. This is a permanent state. + pub fn poll_reserve(&mut self, cx: &mut Context<'_>) -> Poll>> { + loop { + let (result, next_state) = match self.take_state() { + State::Idle(sender) => { + // Start trying to acquire a permit to reserve a slot for our send, and + // immediately loop back around to poll it the first time. + self.acquire.set(make_acquire_future(Some(sender))); + (None, State::Acquiring) + } + State::Acquiring => match self.acquire.poll(cx) { + // Channel has capacity. + Poll::Ready(Ok(permit)) => { + (Some(Poll::Ready(Ok(()))), State::ReadyToSend(permit)) + } + // Channel is closed. + Poll::Ready(Err(e)) => (Some(Poll::Ready(Err(e))), State::Closed), + // Channel doesn't have capacity yet, so we need to wait. + Poll::Pending => (Some(Poll::Pending), State::Acquiring), + }, + // We're closed, either by choice or because the underlying sender was closed. 
+ s @ State::Closed => (Some(Poll::Ready(Err(PollSendError(None)))), s), + // We're already ready to send an item. + s @ State::ReadyToSend(_) => (Some(Poll::Ready(Ok(()))), s), + }; - let result = self.inner.poll(cx); - if result.is_ready() { - self.is_sending = false; - } - if let Poll::Ready(Err(_)) = &result { - self.sender = None; + self.state = next_state; + if let Some(result) = result { + return result; + } } - result } - /// Check whether the channel is ready to send more messages now. + /// Sends an item to the channel. /// - /// If this method returns `true`, then `start_send` is guaranteed to not - /// panic. + /// Before calling `send_item`, `poll_reserve` must be called with a successful return + /// value of `Poll::Ready(Ok(()))`. /// - /// If the channel is closed, this method returns `true`. - pub fn is_ready(&self) -> bool { - !self.is_sending - } + /// # Errors + /// + /// If the channel is closed, an error will be returned. This is a permanent state. + /// + /// # Panics + /// + /// If `poll_reserve` was not successfully called prior to calling `send_item`, then this method + /// will panic. + pub fn send_item(&mut self, value: T) -> Result<(), PollSendError> { + let (result, next_state) = match self.take_state() { + State::Idle(_) | State::Acquiring => { + panic!("`send_item` called without first calling `poll_reserve`") + } + // We have a permit to send our item, so go ahead, which gets us our sender back. + State::ReadyToSend(permit) => (Ok(()), State::Idle(permit.send(value))), + // We're closed, either by choice or because the underlying sender was closed. + State::Closed => (Err(PollSendError(Some(value))), State::Closed), + }; - /// Check whether the channel has been closed. - pub fn is_closed(&self) -> bool { - match &self.sender { - Some(sender) => sender.is_closed(), - None => true, - } + // Handle deferred closing if `close` was called between `poll_reserve` and `send_item`. 
+ self.state = if self.sender.is_some() { + next_state + } else { + State::Closed + }; + result } - /// Clone the underlying `Sender`. + /// Checks whether this sender has been closed. /// - /// If this method returns `None`, then the channel is closed. (But it is - /// not guaranteed to return `None` if the channel is closed.) - pub fn clone_inner(&self) -> Option> { - self.sender.as_ref().map(|sender| (&**sender).clone()) + /// The underlying channel that this sender was wrapping may still be open. + pub fn is_closed(&self) -> bool { + matches!(self.state, State::Closed) || self.sender.is_none() } - /// Access the underlying `Sender`. + /// Gets a reference to the `Sender` of the underlying channel. /// - /// If this method returns `None`, then the channel is closed. (But it is - /// not guaranteed to return `None` if the channel is closed.) - pub fn inner_ref(&self) -> Option<&Sender> { - self.sender.as_deref() + /// If `PollSender` has been closed, `None` is returned. The underlying channel that this sender + /// was wrapping may still be open. + pub fn get_ref(&self) -> Option<&Sender> { + self.sender.as_ref() } - // This operation is supported because it is required by the Sink trait. - /// Close this sender. No more messages can be sent from this sender. + /// Closes this sender. /// - /// Note that this only closes the channel from the view-point of this - /// sender. The channel remains open until all senders have gone away, or - /// until the [`Receiver`] closes the channel. + /// No more messages will be able to be sent from this sender, but the underlying channel will + /// remain open until all senders have dropped, or until the [`Receiver`] closes the channel. /// - /// If there is a send in progress when this method is called, that send is - /// unaffected by this operation, and `poll_send_done` can still be called - /// to complete that send. 
+ /// If a slot was previously reserved by calling `poll_reserve`, then a final call can be made + /// to `send_item` in order to consume the reserved slot. After that, no further sends will be + /// possible. If you do not intend to send another item, you can release the reserved slot back + /// to the underlying sender by calling [`abort_send`]. /// + /// [`abort_send`]: crate::sync::PollSender::abort_send /// [`Receiver`]: tokio::sync::mpsc::Receiver - pub fn close_this_sender(&mut self) { + pub fn close(&mut self) { + // Mark ourselves officially closed by dropping our main sender. self.sender = None; + + // If we're already idle, closed, or we haven't yet reserved a slot, we can quickly + // transition to the closed state. Otherwise, leave the existing permit in place for the + // caller if they want to complete the send. + match self.state { + State::Idle(_) => self.state = State::Closed, + State::Acquiring => { + self.acquire.set(make_acquire_future(None)); + self.state = State::Closed; + } + _ => {} + } } - /// Abort the current in-progress send, if any. + /// Aborts the current in-progress send, if any. /// - /// Returns `true` if a send was aborted. + /// Returns `true` if a send was aborted. If the sender was closed prior to calling + /// `abort_send`, then the sender will remain in the closed state, otherwise the sender will be + /// ready to attempt another send. pub fn abort_send(&mut self) -> bool { - if self.is_sending { - self.inner.set(make_future(None)); - self.is_sending = false; - true - } else { - false - } + // We may have been closed in the meantime, after a call to `poll_reserve` already + // succeeded. We'll check if `self.sender` is `None` to see if we should transition to the + // closed state when we actually abort a send, rather than resetting ourselves back to idle. + + let (result, next_state) = match self.take_state() { + // We're currently trying to reserve a slot to send into. 
+ State::Acquiring => { + // Replacing the future drops the in-flight one. + self.acquire.set(make_acquire_future(None)); + + // If we haven't closed yet, we have to clone our stored sender since we have no way + // to get it back from the acquire future we just dropped. + let state = match self.sender.clone() { + Some(sender) => State::Idle(sender), + None => State::Closed, + }; + (true, state) + } + // We got the permit. If we haven't closed yet, get the sender back. + State::ReadyToSend(permit) => { + let state = if self.sender.is_some() { + State::Idle(permit.release()) + } else { + State::Closed + }; + (true, state) + } + s => (false, s), + }; + + self.state = next_state; + result } } impl Clone for PollSender { - /// Clones this `PollSender`. The resulting clone will not have any - /// in-progress send operations, even if the current `PollSender` does. + /// Clones this `PollSender`. + /// + /// The resulting `PollSender` will have an initial state identical to calling `PollSender::new`. fn clone(&self) -> PollSender { + let (sender, state) = match self.sender.clone() { + Some(sender) => (Some(sender.clone()), State::Idle(sender)), + None => (None, State::Closed), + }; + Self { - sender: self.sender.clone(), - is_sending: false, - inner: ReusableBoxFuture::new(async { unreachable!() }), + sender, + state, + // We don't use `make_acquire_future` here because our relaxed bounds on `T` are not + // compatible with the transitive bounds required by `Sender`. + acquire: ReusableBoxFuture::new(async { unreachable!() }), } } } impl Sink for PollSender { - type Error = SendError; + type Error = PollSendError; - /// This is equivalent to calling [`poll_send_done`]. - /// - /// [`poll_send_done`]: PollSender::poll_send_done fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::into_inner(self).poll_send_done(cx) + Pin::into_inner(self).poll_reserve(cx) } - /// This is equivalent to calling [`poll_send_done`]. 
- /// - /// [`poll_send_done`]: PollSender::poll_send_done - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::into_inner(self).poll_send_done(cx) + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) } - /// This is equivalent to calling [`start_send`]. - /// - /// [`start_send`]: PollSender::start_send fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> { - Pin::into_inner(self).start_send(item) + Pin::into_inner(self).send_item(item) } - /// This method will first flush the `PollSender`, and then close it by - /// calling [`close_this_sender`]. - /// - /// If a send fails while flushing because the [`Receiver`] has gone away, - /// then this function returns an error. The channel is still successfully - /// closed in this situation. - /// - /// [`close_this_sender`]: PollSender::close_this_sender - /// [`Receiver`]: tokio::sync::mpsc::Receiver - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - ready!(self.as_mut().poll_flush(cx))?; - - Pin::into_inner(self).close_this_sender(); + fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Pin::into_inner(self).close(); Poll::Ready(Ok(())) } } diff --git a/tokio-util/tests/mpsc.rs b/tokio-util/tests/mpsc.rs index cb5df0f0b94..a3c164d3eca 100644 --- a/tokio-util/tests/mpsc.rs +++ b/tokio-util/tests/mpsc.rs @@ -5,53 +5,62 @@ use tokio_test::{assert_pending, assert_ready, assert_ready_err, assert_ready_ok use tokio_util::sync::PollSender; #[tokio::test] -async fn test_simple() { +async fn simple() { let (send, mut recv) = channel(3); let mut send = PollSender::new(send); for i in 1..=3i32 { - send.start_send(i).unwrap(); - assert_ready_ok!(spawn(poll_fn(|cx| send.poll_send_done(cx))).poll()); + let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx))); + assert_ready_ok!(reserve.poll()); + send.send_item(i).unwrap(); } - send.start_send(4).unwrap(); - let mut fourth_send = spawn(poll_fn(|cx| 
send.poll_send_done(cx))); - assert_pending!(fourth_send.poll()); + let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx))); + assert_pending!(reserve.poll()); + assert_eq!(recv.recv().await.unwrap(), 1); - assert!(fourth_send.is_woken()); - assert_ready_ok!(fourth_send.poll()); + assert!(reserve.is_woken()); + assert_ready_ok!(reserve.poll()); drop(recv); - // Here, start_send is not guaranteed to fail, but if it doesn't the first - // call to poll_send_done should. - if send.start_send(5).is_ok() { - assert_ready_err!(spawn(poll_fn(|cx| send.poll_send_done(cx))).poll()); - } + send.send_item(42).unwrap(); } #[tokio::test] -async fn test_abort() { +async fn repeated_poll_reserve() { + let (send, mut recv) = channel::(1); + let mut send = PollSender::new(send); + + let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx))); + assert_ready_ok!(reserve.poll()); + assert_ready_ok!(reserve.poll()); + send.send_item(1).unwrap(); + + assert_eq!(recv.recv().await.unwrap(), 1); +} + +#[tokio::test] +async fn abort_send() { let (send, mut recv) = channel(3); let mut send = PollSender::new(send); - let send2 = send.clone_inner().unwrap(); + let send2 = send.get_ref().cloned().unwrap(); for i in 1..=3i32 { - send.start_send(i).unwrap(); - assert_ready_ok!(spawn(poll_fn(|cx| send.poll_send_done(cx))).poll()); + let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx))); + assert_ready_ok!(reserve.poll()); + send.send_item(i).unwrap(); } - send.start_send(4).unwrap(); - { - let mut fourth_send = spawn(poll_fn(|cx| send.poll_send_done(cx))); - assert_pending!(fourth_send.poll()); - assert_eq!(recv.recv().await.unwrap(), 1); - assert!(fourth_send.is_woken()); - } + let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx))); + assert_pending!(reserve.poll()); + assert_eq!(recv.recv().await.unwrap(), 1); + assert!(reserve.is_woken()); + assert_ready_ok!(reserve.poll()); let mut send2_send = spawn(send2.send(5)); assert_pending!(send2_send.poll()); - send.abort_send(); + 
assert!(send.abort_send()); assert!(send2_send.is_woken()); assert_ready_ok!(send2_send.poll()); @@ -68,7 +77,7 @@ async fn close_sender_last() { let mut recv_task = spawn(recv.recv()); assert_pending!(recv_task.poll()); - send.close_this_sender(); + send.close(); assert!(recv_task.is_woken()); assert!(assert_ready!(recv_task.poll()).is_none()); @@ -77,13 +86,13 @@ async fn close_sender_last() { #[tokio::test] async fn close_sender_not_last() { let (send, mut recv) = channel::(3); - let send2 = send.clone(); let mut send = PollSender::new(send); + let send2 = send.get_ref().cloned().unwrap(); let mut recv_task = spawn(recv.recv()); assert_pending!(recv_task.poll()); - send.close_this_sender(); + send.close(); assert!(!recv_task.is_woken()); assert_pending!(recv_task.poll()); @@ -93,3 +102,138 @@ async fn close_sender_not_last() { assert!(recv_task.is_woken()); assert!(assert_ready!(recv_task.poll()).is_none()); } + +#[tokio::test] +async fn close_sender_before_reserve() { + let (send, mut recv) = channel::(3); + let mut send = PollSender::new(send); + + let mut recv_task = spawn(recv.recv()); + assert_pending!(recv_task.poll()); + + send.close(); + + assert!(recv_task.is_woken()); + assert!(assert_ready!(recv_task.poll()).is_none()); + + let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx))); + assert_ready_err!(reserve.poll()); +} + +#[tokio::test] +async fn close_sender_after_pending_reserve() { + let (send, mut recv) = channel::(1); + let mut send = PollSender::new(send); + + let mut recv_task = spawn(recv.recv()); + assert_pending!(recv_task.poll()); + + let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx))); + assert_ready_ok!(reserve.poll()); + send.send_item(1).unwrap(); + + assert!(recv_task.is_woken()); + + let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx))); + assert_pending!(reserve.poll()); + drop(reserve); + + send.close(); + + assert!(send.is_closed()); + let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx))); + 
assert_ready_err!(reserve.poll()); +} + +#[tokio::test] +async fn close_sender_after_successful_reserve() { + let (send, mut recv) = channel::(3); + let mut send = PollSender::new(send); + + let mut recv_task = spawn(recv.recv()); + assert_pending!(recv_task.poll()); + + let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx))); + assert_ready_ok!(reserve.poll()); + drop(reserve); + + send.close(); + assert!(send.is_closed()); + assert!(!recv_task.is_woken()); + assert_pending!(recv_task.poll()); + + let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx))); + assert_ready_ok!(reserve.poll()); +} + +#[tokio::test] +async fn abort_send_after_pending_reserve() { + let (send, mut recv) = channel::(1); + let mut send = PollSender::new(send); + + let mut recv_task = spawn(recv.recv()); + assert_pending!(recv_task.poll()); + + let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx))); + assert_ready_ok!(reserve.poll()); + send.send_item(1).unwrap(); + + assert_eq!(send.get_ref().unwrap().capacity(), 0); + assert!(!send.abort_send()); + + let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx))); + assert_pending!(reserve.poll()); + + assert!(send.abort_send()); + assert_eq!(send.get_ref().unwrap().capacity(), 0); +} + +#[tokio::test] +async fn abort_send_after_successful_reserve() { + let (send, mut recv) = channel::(1); + let mut send = PollSender::new(send); + + let mut recv_task = spawn(recv.recv()); + assert_pending!(recv_task.poll()); + + assert_eq!(send.get_ref().unwrap().capacity(), 1); + let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx))); + assert_ready_ok!(reserve.poll()); + assert_eq!(send.get_ref().unwrap().capacity(), 0); + + assert!(send.abort_send()); + assert_eq!(send.get_ref().unwrap().capacity(), 1); +} + +#[tokio::test] +async fn closed_when_receiver_drops() { + let (send, _) = channel::(1); + let mut send = PollSender::new(send); + + let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx))); + assert_ready_err!(reserve.poll()); 
+} + +#[should_panic] +#[test] +fn start_send_panics_when_idle() { + let (send, _) = channel::(3); + let mut send = PollSender::new(send); + + send.send_item(1).unwrap(); +} + +#[should_panic] +#[test] +fn start_send_panics_when_acquiring() { + let (send, _) = channel::(1); + let mut send = PollSender::new(send); + + let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx))); + assert_ready_ok!(reserve.poll()); + send.send_item(1).unwrap(); + + let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx))); + assert_pending!(reserve.poll()); + send.send_item(2).unwrap(); +}