Skip to content

Commit

Permalink
stream: add StreamNotifyClose (#4851)
Browse files Browse the repository at this point in the history
  • Loading branch information
aviramha authored Apr 16, 2023
1 parent 6037fae commit 9507f8b
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 0 deletions.
3 changes: 3 additions & 0 deletions tokio-stream/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,5 +96,8 @@ pub use pending::{pending, Pending};
mod stream_map;
pub use stream_map::StreamMap;

mod stream_close;
pub use stream_close::StreamNotifyClose;

#[doc(no_inline)]
pub use futures_core::Stream;
93 changes: 93 additions & 0 deletions tokio-stream/src/stream_close.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
use crate::Stream;
use pin_project_lite::pin_project;
use std::pin::Pin;
use std::task::{Context, Poll};

pin_project! {
/// A `Stream` that wraps the values in an `Option`.
///
/// Whenever the wrapped stream yields an item, this stream yields that item
/// wrapped in `Some`. When the inner stream ends, then this stream first
/// yields a `None` item, and then this stream will also end.
///
/// # Example
///
/// Using `StreamNotifyClose` to handle closed streams with `StreamMap`.
///
/// ```
/// use tokio_stream::{StreamExt, StreamMap, StreamNotifyClose};
///
/// #[tokio::main]
/// async fn main() {
/// let mut map = StreamMap::new();
/// let stream = StreamNotifyClose::new(tokio_stream::iter(vec![0, 1]));
/// let stream2 = StreamNotifyClose::new(tokio_stream::iter(vec![0, 1]));
/// map.insert(0, stream);
/// map.insert(1, stream2);
/// while let Some((key, val)) = map.next().await {
/// match val {
/// Some(val) => println!("got {val:?} from stream {key:?}"),
/// None => println!("stream {key:?} closed"),
/// }
/// }
/// }
/// ```
#[must_use = "streams do nothing unless polled"]
pub struct StreamNotifyClose<S> {
#[pin]
inner: Option<S>,
}
}

impl<S> StreamNotifyClose<S> {
/// Create a new `StreamNotifyClose`.
pub fn new(stream: S) -> Self {
Self {
inner: Some(stream),
}
}

/// Get back the inner `Stream`.
///
/// Returns `None` if the stream has reached its end.
pub fn into_inner(self) -> Option<S> {
self.inner
}
}

impl<S> Stream for StreamNotifyClose<S>
where
S: Stream,
{
type Item = Option<S::Item>;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
// We can't invoke poll_next after it ended, so we unset the inner stream as a marker.
match self
.as_mut()
.project()
.inner
.as_pin_mut()
.map(|stream| S::poll_next(stream, cx))
{
Some(Poll::Ready(Some(item))) => Poll::Ready(Some(Some(item))),
Some(Poll::Ready(None)) => {
self.project().inner.set(None);
Poll::Ready(Some(None))
}
Some(Poll::Pending) => Poll::Pending,
None => Poll::Ready(None),
}
}

#[inline]
fn size_hint(&self) -> (usize, Option<usize>) {
if let Some(inner) = &self.inner {
// We always return +1 because when there's stream there's atleast one more item.
let (l, u) = inner.size_hint();
(l.saturating_add(1), u.and_then(|u| u.checked_add(1)))
} else {
(0, Some(0))
}
}
}
30 changes: 30 additions & 0 deletions tokio-stream/src/stream_map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,18 @@ use std::task::{Context, Poll};
/// to be merged, it may be advisable to use tasks sending values on a shared
/// [`mpsc`] channel.
///
/// # Notes
///
/// `StreamMap` removes finished streams automatically, without alerting the user.
/// In some scenarios, the caller would want to know on closed streams.
/// To do this, use [`StreamNotifyClose`] as a wrapper to your stream.
/// It will return None when the stream is closed.
///
/// [`StreamExt::merge`]: crate::StreamExt::merge
/// [`mpsc`]: https://docs.rs/tokio/1.0/tokio/sync/mpsc/index.html
/// [`pin!`]: https://docs.rs/tokio/1.0/tokio/macro.pin.html
/// [`Box::pin`]: std::boxed::Box::pin
/// [`StreamNotifyClose`]: crate::StreamNotifyClose
///
/// # Examples
///
Expand Down Expand Up @@ -170,6 +178,28 @@ use std::task::{Context, Poll};
/// }
/// }
/// ```
///
/// Using `StreamNotifyClose` to handle closed streams with `StreamMap`.
///
/// ```
/// use tokio_stream::{StreamExt, StreamMap, StreamNotifyClose};
///
/// #[tokio::main]
/// async fn main() {
/// let mut map = StreamMap::new();
/// let stream = StreamNotifyClose::new(tokio_stream::iter(vec![0, 1]));
/// let stream2 = StreamNotifyClose::new(tokio_stream::iter(vec![0, 1]));
/// map.insert(0, stream);
/// map.insert(1, stream2);
/// while let Some((key, val)) = map.next().await {
/// match val {
/// Some(val) => println!("got {val:?} from stream {key:?}"),
/// None => println!("stream {key:?} closed"),
/// }
/// }
/// }
/// ```
#[derive(Debug)]
pub struct StreamMap<K, V> {
/// Streams stored in the map
Expand Down
11 changes: 11 additions & 0 deletions tokio-stream/tests/stream_close.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
use tokio_stream::{StreamExt, StreamNotifyClose};

#[tokio::test]
async fn basic_usage() {
let mut stream = StreamNotifyClose::new(tokio_stream::iter(vec![0, 1]));

assert_eq!(stream.next().await, Some(Some(0)));
assert_eq!(stream.next().await, Some(Some(1)));
assert_eq!(stream.next().await, Some(None));
assert_eq!(stream.next().await, None);
}

0 comments on commit 9507f8b

Please sign in to comment.