diff --git a/futures-util/src/stream/stream/mod.rs b/futures-util/src/stream/stream/mod.rs index 84b0acd862..9d6ec56724 100644 --- a/futures-util/src/stream/stream/mod.rs +++ b/futures-util/src/stream/stream/mod.rs @@ -146,6 +146,10 @@ mod zip; #[allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/57411 pub use self::zip::Zip; +mod unzip; +#[allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/57411 +pub use self::unzip::{unzip, UnzipLeft, UnzipRight}; + #[cfg(feature = "alloc")] mod chunks; #[cfg(feature = "alloc")] @@ -1182,6 +1186,40 @@ pub trait StreamExt: Stream { assert_stream::<(Self::Item, St::Item), _>(Zip::new(self, other)) } + /// An adapter for unzipping a stream of tuples (T1, T2). + /// + /// Returns two streams, left stream and right stream. + /// You can drop one of them and the other will still work. Underlying stream + /// Will be dropped only when both of the child streams are dropped. + /// + /// # Examples + /// + /// ``` + /// # futures::executor::block_on(async { + /// use futures::stream::{self, StreamExt}; + /// + /// let stream = stream::iter(vec![(1, 2), (3, 4), (5, 6), (7, 8)]); + /// + /// let (left, right) = stream.unzip(); + /// let left = left.collect::>().await; + /// let right = right.collect::>().await; + /// assert_eq!(vec![1, 3, 5, 7], left); + /// assert_eq!(vec![2, 4, 6, 8], right); + /// # }); + /// ``` + /// + fn unzip(self) -> (UnzipLeft, UnzipRight) + where + Self: Stream, + Self: Sized, + { + let (left, right) = unzip(self); + ( + assert_stream::(left), + assert_stream::(right), + ) + } + /// Adapter for chaining two streams. /// /// The resulting stream emits elements from the first stream, and when diff --git a/futures-util/src/stream/stream/unzip.rs b/futures-util/src/stream/stream/unzip.rs new file mode 100644 index 0000000000..9f143c8cd1 --- /dev/null +++ b/futures-util/src/stream/stream/unzip.rs @@ -0,0 +1,177 @@ +use crate::task::AtomicWaker; +use alloc::sync::{Arc, Weak}; +use core::pin::Pin; +use futures_core::stream::{FusedStream, Stream}; +use futures_core::task::{Context, Poll}; +use pin_project::{pin_project, pinned_drop}; +use std::sync::mpsc; + +/// SAFETY: safe because only one of two unzipped streams is guaranteed +/// to be accessing underlying stream. This is guaranteed by mpsc. Right +/// stream will access underlying stream only if Sender (or left stream) +/// is dropped in which case try_recv returns disconnected error. +unsafe fn poll_unzipped( + stream: Pin<&mut Arc>, + cx: &mut Context<'_>, +) -> Poll> +where + S: Stream, +{ + stream + .map_unchecked_mut(|x| &mut *(Arc::as_ptr(x) as *mut S)) + .poll_next(cx) +} + +#[pin_project(PinnedDrop)] +#[derive(Debug)] +#[must_use = "streams do nothing unless polled"] +pub struct UnzipLeft +where + S: Stream, +{ + #[pin] + stream: Arc, + right_waker: Weak, + right_queue: mpsc::Sender>, +} + +impl UnzipLeft +where + S: Stream, +{ + fn send_to_right(&self, value: Option) { + if let Some(right_waker) = self.right_waker.upgrade() { + // if right_waker.upgrade() succeeds, then right is not + // dropped so send won't fail. + let _ = self.right_queue.send(value); + right_waker.wake(); + } + } +} + +impl FusedStream for UnzipLeft +where + S: Stream + FusedStream, +{ + fn is_terminated(&self) -> bool { + self.stream.as_ref().is_terminated() + } +} + +impl Stream for UnzipLeft +where + S: Stream, +{ + type Item = T1; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.as_mut().project(); + + // SAFETY: for safety details see comment for function: poll_unzipped + if let Some(value) = ready!(unsafe { poll_unzipped(this.stream, cx) }) { + self.send_to_right(Some(value.1)); + return Poll::Ready(Some(value.0)); + } + self.send_to_right(None); + Poll::Ready(None) + } +} + +#[pinned_drop] +impl PinnedDrop for UnzipLeft +where + S: Stream, +{ + fn drop(self: Pin<&mut Self>) { + let this = self.project(); + // wake right stream if it isn't dropped + if let Some(right_waker) = this.right_waker.upgrade() { + // drop right_queue sender to cause rx.try_recv to return + // TryRecvError::Disconnected, so that right knows left is + // dropped and now it should take over polling the base stream. + drop(this.right_queue); + right_waker.wake(); + } + } +} + +#[pin_project] +#[derive(Debug)] +#[must_use = "streams do nothing unless polled"] +pub struct UnzipRight +where + S: Stream, +{ + #[pin] + stream: Arc, + waker: Arc, + queue: mpsc::Receiver>, + is_done: bool, +} + +impl FusedStream for UnzipRight +where + S: FusedStream, +{ + fn is_terminated(&self) -> bool { + self.is_done + } +} + +impl Stream for UnzipRight +where + S: Stream, +{ + type Item = T2; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + this.waker.register(&cx.waker().clone()); + + match this.queue.try_recv() { + Ok(value) => { + // can't know if more items are in the queue so wake the task + // again while there are items. Will cause extra wake though. + cx.waker().clone().wake(); + if value.is_none() { + *this.is_done = true; + } + Poll::Ready(value) + } + Err(mpsc::TryRecvError::Disconnected) => { + // if left is dropped, it is no longer polling the base stream + // so right should poll it instead. + // SAFETY: for safety details see comment for function: poll_unzipped + if let Some(value) = ready!(unsafe { poll_unzipped(this.stream, cx) }) { + return Poll::Ready(Some(value.1)); + } + *this.is_done = true; + Poll::Ready(None) + } + _ => Poll::Pending, + } + } +} + +pub fn unzip(stream: S) -> (UnzipLeft, UnzipRight) +where + S: Stream, +{ + let base_stream = Arc::new(stream); + let waker = Arc::new(AtomicWaker::new()); + let (tx, rx) = mpsc::channel::>(); + + ( + UnzipLeft { + stream: base_stream.clone(), + right_waker: Arc::downgrade(&waker), + right_queue: tx, + }, + UnzipRight { + stream: base_stream.clone(), + waker: waker, + queue: rx, + is_done: false, + }, + ) +} diff --git a/futures/tests/stream_unzip.rs b/futures/tests/stream_unzip.rs new file mode 100644 index 0000000000..2adf6e2d08 --- /dev/null +++ b/futures/tests/stream_unzip.rs @@ -0,0 +1,78 @@ +use futures::stream::{self, StreamExt}; +use futures::executor::{block_on_stream}; + +fn fail_on_thread_panic() { + std::panic::set_hook(Box::new(move |panic_info: &std::panic::PanicInfo| { + println!("{}", panic_info.to_string()); + std::process::exit(1); + })); +} + +fn sample_stream(start: usize, end: usize) -> futures_util::stream::Iter> { + let list_iter = (start..end) + .filter(|&x| x % 2 == 1) + .map(|x| (x, x + 1)); + + return stream::iter(list_iter.collect::>()); +} + +#[test] +fn left_dropped_before_first_poll() { + let (_, mut s2) = { + let (s1, s2) = sample_stream(1, 2).unzip(); + (block_on_stream(s1), block_on_stream(s2)) + }; + + assert_eq!(s2.next(), Some(2)); + assert_eq!(s2.next(), None); +} + +#[test] +fn left_dropped_after_polled() { + fail_on_thread_panic(); + + let (mut s1, mut s2) = { + let (s1, s2) = sample_stream(1, 4).unzip(); + (block_on_stream(s1), block_on_stream(s2)) + }; + + + let t1 = std::thread::spawn(move || { + assert_eq!(s1.next(), Some(1)); + drop(s1); + }); + + let t2 = std::thread::spawn(move || { + assert_eq!(s2.next(), Some(2)); + assert_eq!(s2.next(), Some(4)); + assert_eq!(s2.next(), None); + }); + + let _ = t1.join(); + let _ = t2.join(); +} + +#[test] +fn right_dropped_before_first_poll() { + let (mut s1, _) = { + let (s1, s2) = sample_stream(1, 2).unzip(); + (block_on_stream(s1), block_on_stream(s2)) + }; + + assert_eq!(s1.next(), Some(1)); + assert_eq!(s1.next(), None); +} + +#[test] +fn right_dropped_after_polled() { + let (mut s1, mut s2) = { + let (s1, s2) = sample_stream(1, 4).unzip(); + (block_on_stream(s1), block_on_stream(s2)) + }; + + assert_eq!(s1.next(), Some(1)); + assert_eq!(s2.next(), Some(2)); + drop(s2); + assert_eq!(s1.next(), Some(3)); + assert_eq!(s1.next(), None); +} \ No newline at end of file