diff --git a/futures-macro/src/lib.rs b/futures-macro/src/lib.rs index d1cbc3ce94..fa93e48f1f 100644 --- a/futures-macro/src/lib.rs +++ b/futures-macro/src/lib.rs @@ -19,6 +19,7 @@ use proc_macro::TokenStream; mod executor; mod join; mod select; +mod stream_select; /// The `join!` macro. #[cfg_attr(fn_like_proc_macro, proc_macro)] @@ -54,3 +55,12 @@ pub fn select_biased_internal(input: TokenStream) -> TokenStream { pub fn test_internal(input: TokenStream, item: TokenStream) -> TokenStream { crate::executor::test(input, item) } + +/// The `stream_select!` macro. +#[cfg_attr(fn_like_proc_macro, proc_macro)] +#[cfg_attr(not(fn_like_proc_macro), proc_macro_hack::proc_macro_hack)] +pub fn stream_select_internal(input: TokenStream) -> TokenStream { + crate::stream_select::stream_select(input.into()) + .unwrap_or_else(syn::Error::into_compile_error) + .into() +} diff --git a/futures-macro/src/stream_select.rs b/futures-macro/src/stream_select.rs new file mode 100644 index 0000000000..9927b53073 --- /dev/null +++ b/futures-macro/src/stream_select.rs @@ -0,0 +1,113 @@ +use proc_macro2::TokenStream; +use quote::{format_ident, quote, ToTokens}; +use syn::{parse::Parser, punctuated::Punctuated, Expr, Index, Token}; + +/// The `stream_select!` macro. +pub(crate) fn stream_select(input: TokenStream) -> Result { + let args = Punctuated::::parse_terminated.parse2(input)?; + if args.len() < 2 { + return Ok(quote! { + compile_error!("stream select macro needs at least two arguments.") + }); + } + let generic_idents = (0..args.len()).map(|i| format_ident!("_{}", i)).collect::>(); + let field_idents = (0..args.len()).map(|i| format_ident!("__{}", i)).collect::>(); + let field_idents_2 = (0..args.len()).map(|i| format_ident!("___{}", i)).collect::>(); + let field_indices = (0..args.len()).map(Index::from).collect::>(); + let args = args.iter().map(|e| e.to_token_stream()); + + Ok(quote! { + { + #[derive(Debug)] + struct StreamSelect<#(#generic_idents),*> (#(Option<#generic_idents>),*); + + enum StreamEnum<#(#generic_idents),*> { + #( + #generic_idents(#generic_idents) + ),*, + None, + } + + impl __futures_crate::stream::Stream for StreamEnum<#(#generic_idents),*> + where #(#generic_idents: __futures_crate::stream::Stream + ::std::marker::Unpin,)* + { + type Item = ITEM; + + fn poll_next(mut self: ::std::pin::Pin<&mut Self>, cx: &mut __futures_crate::task::Context<'_>) -> __futures_crate::task::Poll> { + match self.get_mut() { + #( + Self::#generic_idents(#generic_idents) => ::std::pin::Pin::new(#generic_idents).poll_next(cx) + ),*, + Self::None => panic!("StreamEnum::None should never be polled!"), + } + } + } + + impl __futures_crate::stream::Stream for StreamSelect<#(#generic_idents),*> + where #(#generic_idents: __futures_crate::stream::Stream + ::std::marker::Unpin,)* + { + type Item = ITEM; + + fn poll_next(mut self: ::std::pin::Pin<&mut Self>, cx: &mut __futures_crate::task::Context<'_>) -> __futures_crate::task::Poll> { + let Self(#(ref mut #field_idents),*) = self.get_mut(); + #( + let mut #field_idents_2 = false; + )* + let mut any_pending = false; + { + let mut stream_array = [#(#field_idents.as_mut().map(|f| StreamEnum::#generic_idents(f)).unwrap_or(StreamEnum::None)),*]; + __futures_crate::async_await::shuffle(&mut stream_array); + + for mut s in stream_array { + if let StreamEnum::None = s { + continue; + } else { + match __futures_crate::stream::Stream::poll_next(::std::pin::Pin::new(&mut s), cx) { + r @ __futures_crate::task::Poll::Ready(Some(_)) => { + return r; + }, + __futures_crate::task::Poll::Pending => { + any_pending = true; + }, + __futures_crate::task::Poll::Ready(None) => { + match s { + #( + StreamEnum::#generic_idents(_) => { #field_idents_2 = true; } + ),*, + StreamEnum::None => panic!("StreamEnum::None should never be polled!"), + } + }, + } + } + } + } + #( + if #field_idents_2 { + *#field_idents = None; + } + )* + if any_pending { + __futures_crate::task::Poll::Pending + } else { + __futures_crate::task::Poll::Ready(None) + } + } + + fn size_hint(&self) -> (usize, Option) { + let mut s = (0, Some(0)); + #( + if let Some(new_hint) = self.#field_indices.as_ref().map(|s| s.size_hint()) { + s.0 += new_hint.0; + // We can change this out for `.zip` when the MSRV is 1.46.0 or higher. + s.1 = s.1.and_then(|a| new_hint.1.map(|b| a + b)); + } + )* + s + } + } + + StreamSelect(#(Some(#args)),*) + + } + }) +} diff --git a/futures-util/src/async_await/mod.rs b/futures-util/src/async_await/mod.rs index 5f5d4aca3f..7276da227a 100644 --- a/futures-util/src/async_await/mod.rs +++ b/futures-util/src/async_await/mod.rs @@ -30,6 +30,13 @@ mod select_mod; #[cfg(feature = "async-await-macro")] pub use self::select_mod::*; +// Primary export is a macro +#[cfg(feature = "async-await-macro")] +mod stream_select_mod; +#[allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/64762 +#[cfg(feature = "async-await-macro")] +pub use self::stream_select_mod::*; + #[cfg(feature = "std")] #[cfg(feature = "async-await-macro")] mod random; diff --git a/futures-util/src/async_await/stream_select_mod.rs b/futures-util/src/async_await/stream_select_mod.rs new file mode 100644 index 0000000000..7743406dab --- /dev/null +++ b/futures-util/src/async_await/stream_select_mod.rs @@ -0,0 +1,45 @@ +//! The `stream_select` macro. + +#[cfg(feature = "std")] +#[allow(unreachable_pub)] +#[doc(hidden)] +#[cfg_attr(not(fn_like_proc_macro), proc_macro_hack::proc_macro_hack(support_nested))] +pub use futures_macro::stream_select_internal; + +/// Combines several streams, all producing the same `Item` type, into one stream. +/// This is similar to `select_all` but does not require the streams to all be the same type. +/// It also keeps the streams inline, and does not require `Box`s to be allocated. +/// Streams passed to this macro must be `Unpin`. +/// +/// If multiple streams are ready, one will be pseudo randomly selected at runtime. +/// +/// This macro is gated behind the `async-await` feature of this library, which is activated by default. +/// Note that `stream_select!` relies on `proc-macro-hack`, and may require to set the compiler's recursion +/// limit very high, e.g. `#![recursion_limit="1024"]`. +/// +/// # Examples +/// +/// ``` +/// # futures::executor::block_on(async { +/// use futures::{stream, StreamExt, stream_select}; +/// let endless_ints = |i| stream::iter(vec![i].into_iter().cycle()).fuse(); +/// +/// let mut endless_numbers = stream_select!(endless_ints(1i32), endless_ints(2), endless_ints(3)); +/// match endless_numbers.next().await { +/// Some(1) => println!("Got a 1"), +/// Some(2) => println!("Got a 2"), +/// Some(3) => println!("Got a 3"), +/// _ => unreachable!(), +/// } +/// # }); +/// ``` +#[cfg(feature = "std")] +#[macro_export] +macro_rules! stream_select { + ($($tokens:tt)*) => {{ + use $crate::__private as __futures_crate; + $crate::stream_select_internal! { + $( $tokens )* + } + }} +} diff --git a/futures/src/lib.rs b/futures/src/lib.rs index 287696f845..37476615b9 100644 --- a/futures/src/lib.rs +++ b/futures/src/lib.rs @@ -137,6 +137,11 @@ pub use futures_util::{join, pending, poll, select_biased, try_join}; // Async-a #[doc(inline)] pub use futures_util::{future, never, sink, stream, task}; +#[cfg(feature = "std")] +#[cfg(feature = "async-await")] +pub use futures_util::stream_select; + +#[cfg_attr(feature = "cfg-target-has-atomic", cfg(target_has_atomic = "ptr"))] #[cfg(feature = "alloc")] #[doc(inline)] pub use futures_channel as channel; diff --git a/futures/tests/async_await_macros.rs b/futures/tests/async_await_macros.rs index 19833d01ca..ce1f3a3379 100644 --- a/futures/tests/async_await_macros.rs +++ b/futures/tests/async_await_macros.rs @@ -4,7 +4,9 @@ use futures::future::{self, poll_fn, FutureExt}; use futures::sink::SinkExt; use futures::stream::StreamExt; use futures::task::{Context, Poll}; -use futures::{join, pending, pin_mut, poll, select, select_biased, try_join}; +use futures::{ + join, pending, pin_mut, poll, select, select_biased, stream, stream_select, try_join, +}; use std::mem; #[test] @@ -308,6 +310,42 @@ fn select_on_mutable_borrowing_future_with_same_borrow_in_block_and_default() { }); } +#[test] +#[allow(unused_assignments)] +fn stream_select() { + // stream_select! macro + block_on(async { + let endless_ints = |i| stream::iter(vec![i].into_iter().cycle()); + + let mut endless_ones = stream_select!(endless_ints(1i32), stream::pending()); + assert_eq!(endless_ones.next().await, Some(1)); + assert_eq!(endless_ones.next().await, Some(1)); + + let mut finite_list = + stream_select!(stream::iter(vec![1].into_iter()), stream::iter(vec![1].into_iter())); + assert_eq!(finite_list.next().await, Some(1)); + assert_eq!(finite_list.next().await, Some(1)); + assert_eq!(finite_list.next().await, None); + + let endless_mixed = stream_select!(endless_ints(1i32), endless_ints(2), endless_ints(3)); + // Take 1000, and assert a somewhat even distribution of values. + // The fairness is randomized, but over 1000 samples we should be pretty close to even. + // This test may be a bit flaky. Feel free to adjust the margins as you see fit. + let mut count = 0; + let results = endless_mixed + .take_while(move |_| { + count += 1; + let ret = count < 1000; + async move { ret } + }) + .collect::>() + .await; + assert!(results.iter().filter(|x| **x == 1).count() >= 299); + assert!(results.iter().filter(|x| **x == 2).count() >= 299); + assert!(results.iter().filter(|x| **x == 3).count() >= 299); + }); +} + #[test] fn join_size() { let fut = async {