From 1a4e7862434942f8da6a62c6cabe8cb222f8d92b Mon Sep 17 00:00:00 2001
From: Felix Obenhuber <felix.obenhuber@esrlabs.com>
Date: Tue, 14 Feb 2023 17:40:38 +0100
Subject: [PATCH] Implement Stream for mpsc::Receiver. Add PollSender.

Implement `futures::future::Stream` for `mpsc::Receiver`. Add
`mpsc::PollSender` that wrapps a mpsc::Sender and implements `Sink`.
---
 Cargo.toml             |   2 +
 src/mpsc/async_impl.rs | 168 +++++++++++++++++++++++++++++++++++++++++
 src/mpsc/errors.rs     |   5 +-
 3 files changed, 173 insertions(+), 2 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index b615394..dd305bf 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -25,10 +25,12 @@ std = ["alloc", "parking_lot"]
 alloc = []
 default = ["std"]
 static = []
+stream = ["futures"]
 
 [dependencies]
 pin-project = "1"
 parking_lot = { version = "0.12", optional = true, default-features = false }
+futures = { version = "0.3", default_features = false, optional = true }
 
 [dev-dependencies]
 tokio = { version = "1.14.0", features = ["rt", "rt-multi-thread", "macros", "sync"] }
diff --git a/src/mpsc/async_impl.rs b/src/mpsc/async_impl.rs
index 9f3431b..d76b579 100644
--- a/src/mpsc/async_impl.rs
+++ b/src/mpsc/async_impl.rs
@@ -281,6 +281,161 @@ feature! {
         }
     }
 
+    feature! {
+        #![feature = "stream"]
+        use futures::sink::Sink;
+        use std::mem;
+
+        /// PollSender internal state.
+        enum PollSenderState<'a, T, R> {
+            Idle(Sender<T, R>),
+            Acquiring(Pin<Box<SendRefFuture<'a ,T, R>>>),
+            ReadyToSend(SendRef<'a, T>),
+            Closed,
+        }
+
+        /// A wrapper around mpsc::Sender that can be polled.
+        pub struct PollSender<'a, T, R> {
+            sender: Option<Sender<T, R>>,
+            state: PollSenderState<'a, T, R>,
+        }
+
+        impl<'a, T, R> PollSender<'a, T, R> where R: 'a + Recycle<T>, T: 'a {
+            /// Creates a new `PollSender`.
+            pub fn new(sender: Sender<T, R>) -> PollSender<'a, T, R> {
+                PollSender {
+                    sender: Some(sender.clone()),
+                    state: PollSenderState::Idle(sender),
+                }
+            }
+
+            /// Attempts to reserve a slot in the channel.
+            pub fn poll_reserve(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Closed>> {
+                loop {
+                    let (result, next_state) = match mem::replace(&mut self.state, PollSenderState::Closed) {
+                        PollSenderState::Idle(sender) => {
+                            let send_ref = Box::pin(SendRefFuture {
+                                core: &sender.inner.core,
+                                slots: sender.inner.slots.as_ref(),
+                                recycle: &sender.inner.recycle,
+                                state: State::Start,
+                                waiter: queue::Waiter::new(),
+                            });
+                            // Start trying to acquire a permit to reserve a slot for our send, and
+                            // immediately loop back around to poll it the first time.
+                            (None, PollSenderState::Acquiring(send_ref))
+                        }
+                        PollSenderState::Acquiring(mut f) => match f.as_mut().poll(cx) {
+                            // Channel has capacity.
+                            Poll::Ready(Ok(send_ref)) => {
+                                (Some(Poll::Ready(Ok(()))), PollSenderState::ReadyToSend(send_ref))
+                            }
+                            // Channel is closed.
+                            Poll::Ready(Err(e)) => (Some(Poll::Ready(Err(e))), PollSenderState::Closed),
+                            // Channel doesn't have capacity yet, so we need to wait.
+                            Poll::Pending => (Some(Poll::Pending), PollSenderState::Acquiring(f)),
+                        },
+                        // We're closed, either by choice or because the underlying sender was closed.
+                        s @ PollSenderState::Closed => (Some(Poll::Ready(Err(Closed(())))), s),
+                        // We're already ready to send an item.
+                        s @ PollSenderState::ReadyToSend(_) => (Some(Poll::Ready(Ok(()))), s),
+                    };
+
+                    self.state = next_state;
+                    if let Some(result) = result {
+                        return result;
+                    }
+                }
+            }
+
+            /// Sends an item to the channel.
+            ///
+            /// Before calling `send_item`, `poll_reserve` must be called with a successful return
+            /// value of `Poll::Ready(Ok(()))`.
+            ///
+            /// # Errors
+            ///
+            /// If the channel is closed, an error will be returned.  This is a permanent state.
+            ///
+            /// # Panics
+            ///
+            /// If `poll_reserve` was not successfully called prior to calling `send_item`, then this method
+            /// will panic.
+            pub fn send_item(&mut self, value: T) -> Result<(), Closed<T>> {
+                let (result, next_state) = match mem::replace(&mut self.state, PollSenderState::Closed) {
+                    PollSenderState::Idle(_) | PollSenderState::Acquiring(_) => {
+                        panic!("`send_item` called without first calling `poll_reserve`")
+                    }
+                    // We have a permit to send our item, so go ahead, which gets us our sender back.
+                    PollSenderState::ReadyToSend(mut send_ref) => {
+                        *send_ref = value;
+                        match &self.sender {
+                            Some(sender) => (Ok(()), PollSenderState::<T, R>::Idle(sender.clone())),
+                            None => (Ok(()), PollSenderState::Closed), // Closed in between.
+                        }
+                    },
+                    // We're closed, either by choice or because the underlying sender was closed.
+                    PollSenderState::Closed => (Err(Closed(value)), PollSenderState::Closed),
+                };
+
+                // Handle deferred closing if `close` was called between `poll_reserve` and `send_item`.
+                self.state = if self.sender.is_some() {
+                    next_state
+                } else {
+                    PollSenderState::Closed
+                };
+
+                result
+            }
+
+            /// Checks whether this sender is been closed.
+            ///
+            /// The underlying channel that this sender was wrapping may still be open.
+            pub fn is_closed(&'a self) -> bool {
+                matches!(self.state, PollSenderState::Closed) || self.sender.is_none()
+            }
+
+            /// Gets a reference to the `Sender` of the underlying channel.
+            ///
+            /// If `PollSender` has been closed, `None` is returned. The underlying channel that this sender
+            /// was wrapping may still be open.
+            pub fn get_ref(&self) -> Option<&Sender<T, R>> {
+                self.sender.as_ref()
+            }
+
+            /// Closes this sender.
+            ///
+            /// No more messages will be able to be sent from this sender, but the underlying channel will
+            /// remain open until all senders have dropped, or until the [`Receiver`] closes the channel.
+            pub fn close(&mut self) {
+                // Mark ourselves officially closed by dropping our main sender.
+                self.sender = None;
+                self.state = PollSenderState::Closed;
+            }
+        }
+
+        impl<'a, T, R> Sink<T> for PollSender<'a, T, R> where T: 'a + Default, R: 'a + Recycle<T> {
+            type Error = Closed<T>;
+
+            fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+                Pin::into_inner(self).poll_reserve(cx).map_err(|_| Closed(T::default()))
+            }
+
+            fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
+                Pin::into_inner(self).send_item(item)
+            }
+
+            fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+                Poll::Ready(Ok(()))
+            }
+
+            fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+                Pin::into_inner(self).close();
+                Poll::Ready(Ok(()))
+            }
+        }
+    }
+
     // === impl Receiver ===
 
     impl<T, R> Receiver<T, R> {
@@ -589,6 +744,19 @@ feature! {
         }
     }
 
+    feature! {
+        #![feature = "stream"]
+        use futures::stream::Stream;
+
+        impl<T: Default + Clone> Stream for Receiver<T> {
+            type Item = T;
+
+            fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+                self.poll_recv(cx)
+            }
+        }
+    }
+
     impl<T, R: fmt::Debug> fmt::Debug for Inner<T, R> {
         fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
             f.debug_struct("Inner")
diff --git a/src/mpsc/errors.rs b/src/mpsc/errors.rs
index a41df34..9a48c2e 100644
--- a/src/mpsc/errors.rs
+++ b/src/mpsc/errors.rs
@@ -72,13 +72,14 @@ pub enum TryRecvError {
 }
 
 /// Error returned by [`Sender::send`] or [`Sender::send_ref`] (and
-/// [`StaticSender::send`]/[`StaticSender::send_ref`]), if the
-/// [`Receiver`] half of the channel has been dropped.
+/// [`StaticSender::send`]/[`StaticSender::send_ref`]/[`PollSender::poll_reserve`]),
+/// if the [`Receiver`] half of the channel has been dropped.
 ///
 /// [`Sender::send`]: super::Sender::send
 /// [`Sender::send_ref`]: super::Sender::send_ref
 /// [`StaticSender::send`]: super::StaticSender::send
 /// [`StaticSender::send_ref`]: super::StaticSender::send_ref
+/// [`PollSender::poll_reserve`]: super::PollSender::poll_reserve
 /// [`Receiver`]: super::Receiver
 #[derive(PartialEq, Eq)]
 pub struct Closed<T = ()>(pub(crate) T);