diff --git a/rust-runtime/aws-smithy-async/src/future/fn_stream.rs b/rust-runtime/aws-smithy-async/src/future/fn_stream.rs index 3fdb280d5f..46d8d3850c 100644 --- a/rust-runtime/aws-smithy-async/src/future/fn_stream.rs +++ b/rust-runtime/aws-smithy-async/src/future/fn_stream.rs @@ -43,7 +43,7 @@ pin_project! { #[pin] rx: rendezvous::Receiver, #[pin] - generator: F, + generator: Option, } } @@ -58,7 +58,7 @@ impl FnStream { let (tx, rx) = rendezvous::channel::(); Self { rx, - generator: generator(tx), + generator: Some(generator(tx)), } } } @@ -74,7 +74,11 @@ where match me.rx.poll_recv(cx) { Poll::Ready(item) => Poll::Ready(item), Poll::Pending => { - let _ = me.generator.poll(cx); + if let Some(generator) = me.generator.as_mut().as_pin_mut() { + if generator.poll(cx).is_ready() { + me.generator.set(None); + } + } Poll::Pending } } @@ -140,7 +144,10 @@ where #[cfg(test)] mod test { use crate::future::fn_stream::{FnStream, TryFlatMap}; + use futures_util::task::noop_waker_ref; + use std::future::Future; use std::sync::{Arc, Mutex}; + use std::task::Context; use std::time::Duration; use tokio_stream::StreamExt; @@ -165,6 +172,30 @@ mod test { assert_eq!(out, vec!["1", "2", "3"]); } + // smithy-rs#1902: there was a bug where we could continue to poll the generator after it + // had returned Poll::Ready. This test case leaks the tx half so that the channel stays open + // but the send side generator completes. By calling `poll` multiple times on the resulting future, + // we can trigger the bug and validate the fix. + #[tokio::test] + async fn fn_stream_doesnt_poll_after_done() { + let mut stream = FnStream::new(|tx| { + Box::pin(async move { + assert!(tx.send("blah").await.is_ok()); + Box::leak(Box::new(tx)); + }) + }); + assert_eq!(stream.next().await, Some("blah")); + let mut fut = Box::pin(stream.next()); + assert!(fut + .as_mut() + .poll(&mut Context::from_waker(noop_waker_ref())) + .is_pending()); + assert!(fut + .as_mut() + .poll(&mut Context::from_waker(noop_waker_ref())) + .is_pending()); + } + /// Tests that the generator will not advance until demand exists #[tokio::test] async fn waits_for_reader() {