From ba687ed72e942030e796d4021159673072d70fe2 Mon Sep 17 00:00:00 2001 From: Aatif Syed Date: Thu, 31 Oct 2024 11:53:42 +0000 Subject: [PATCH 1/7] test: #246 --- tests/issues.rs | 93 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 93 insertions(+) create mode 100644 tests/issues.rs diff --git a/tests/issues.rs b/tests/issues.rs new file mode 100644 index 0000000..d4acae5 --- /dev/null +++ b/tests/issues.rs @@ -0,0 +1,93 @@ +#![cfg(all(feature = "tokio", feature = "zstd"))] +#![allow(clippy::unusual_byte_groupings)] + +use std::{ + io, + pin::Pin, + task::{ready, Context, Poll}, +}; + +use async_compression::tokio::write::ZstdEncoder; +use tokio::io::{AsyncWrite, AsyncWriteExt as _}; + +/// +#[tokio::test] +async fn issue_246() { + let mut zstd_encoder = Transparent::new(ZstdEncoder::new(DelayedShutdown::default())); + zstd_encoder.shutdown().await.unwrap(); +} + +pin_project_lite::pin_project! { + /// A simple wrapper struct that follows the [`AsyncWrite`] protocol. + struct Transparent { + #[pin] inner: T + } +} + +impl Transparent { + fn new(inner: T) -> Self { + Self { inner } + } +} + +impl AsyncWrite for Transparent { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.project().inner.poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.poll_flush(cx) + } + + /// To quote the [`AsyncWrite`] docs: + /// > Invocation of a shutdown implies an invocation of flush. + /// > Once this method returns Ready it implies that a flush successfully happened before the shutdown happened. + /// > That is, callers don't need to call flush before calling shutdown. + /// > They can rely that by calling shutdown any pending buffered data will be written out. + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + ready!(this.inner.as_mut().poll_flush(cx))?; + this.inner.poll_shutdown(cx) + } +} + +pin_project_lite::pin_project! { + /// Yields [`Poll::Pending`] the first time [`AsyncWrite::poll_shutdown`] is called. + #[derive(Default)] + struct DelayedShutdown { + contents: Vec, + num_times_shutdown_called: u8, + } +} + +impl AsyncWrite for DelayedShutdown { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let _ = cx; + self.project().contents.extend_from_slice(buf); + Poll::Ready(Ok(buf.len())) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let _ = cx; + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.project().num_times_shutdown_called { + it @ 0 => { + *it += 1; + cx.waker().wake_by_ref(); + Poll::Pending + } + _ => Poll::Ready(Ok(())), + } + } +} From 1f52b61e9214df230a8f3ab99f70fd321a515260 Mon Sep 17 00:00:00 2001 From: Aatif Syed Date: Thu, 31 Oct 2024 12:33:30 +0000 Subject: [PATCH 2/7] test: add tracing --- tests/issues.rs | 82 +++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 72 insertions(+), 10 deletions(-) diff --git a/tests/issues.rs b/tests/issues.rs index d4acae5..bc5270a 100644 --- a/tests/issues.rs +++ b/tests/issues.rs @@ -4,7 +4,7 @@ use std::{ io, pin::Pin, - task::{ready, Context, Poll}, + task::{Context, Poll}, }; use async_compression::tokio::write::ZstdEncoder; @@ -13,7 +13,8 @@ use tokio::io::{AsyncWrite, AsyncWriteExt as _}; /// #[tokio::test] async fn issue_246() { - let mut zstd_encoder = Transparent::new(ZstdEncoder::new(DelayedShutdown::default())); + let mut zstd_encoder = + Transparent::new(Trace::new(ZstdEncoder::new(DelayedShutdown::default()))); zstd_encoder.shutdown().await.unwrap(); } @@ -36,11 +37,17 @@ impl AsyncWrite for Transparent { cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - self.project().inner.poll_write(cx, buf) + eprintln!("Transparent::poll_write = ..."); + let ret = self.project().inner.poll_write(cx, buf); + eprintln!("Transparent::poll_write = {:?}", ret); + ret } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().inner.poll_flush(cx) + eprintln!("Transparent::poll_flush = ..."); + let ret = self.project().inner.poll_flush(cx); + eprintln!("Transparent::poll_flush = {:?}", ret); + ret } /// To quote the [`AsyncWrite`] docs: @@ -49,9 +56,15 @@ impl AsyncWrite for Transparent { /// > That is, callers don't need to call flush before calling shutdown. /// > They can rely that by calling shutdown any pending buffered data will be written out. fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + eprintln!("Transparent::poll_shutdown = ..."); let mut this = self.project(); - ready!(this.inner.as_mut().poll_flush(cx))?; - this.inner.poll_shutdown(cx) + let ret = match this.inner.as_mut().poll_flush(cx) { + Poll::Ready(Ok(())) => this.inner.poll_shutdown(cx), + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + Poll::Pending => Poll::Pending, + }; + eprintln!("Transparent::poll_shutdown = {:?}", ret); + ret } } @@ -70,24 +83,73 @@ impl AsyncWrite for DelayedShutdown { cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { + eprintln!("DelayedShutdown::poll_write = ..."); let _ = cx; self.project().contents.extend_from_slice(buf); - Poll::Ready(Ok(buf.len())) + let ret = Poll::Ready(Ok(buf.len())); + eprintln!("DelayedShutdown::poll_write = {:?}", ret); + ret } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + eprintln!("DelayedShutdown::poll_flush = ..."); let _ = cx; - Poll::Ready(Ok(())) + let ret = Poll::Ready(Ok(())); + eprintln!("DelayedShutdown::poll_flush = {:?}", ret); + ret } fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.project().num_times_shutdown_called { + eprintln!("DelayedShutdown::poll_shutdown = ..."); + let ret = match self.project().num_times_shutdown_called { it @ 0 => { *it += 1; cx.waker().wake_by_ref(); Poll::Pending } _ => Poll::Ready(Ok(())), - } + }; + eprintln!("DelayedShutdown::poll_shutdown = {:?}", ret); + ret + } +} + +pin_project_lite::pin_project! { + /// A wrapper which traces all calls + struct Trace { + #[pin] inner: T + } +} + +impl Trace { + fn new(inner: T) -> Self { + Self { inner } + } +} + +impl AsyncWrite for Trace { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + eprintln!("Trace::poll_write = ..."); + let ret = self.project().inner.poll_write(cx, buf); + eprintln!("Trace::poll_write = {:?}", ret); + ret + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + eprintln!("Trace::poll_flush = ..."); + let ret = self.project().inner.poll_flush(cx); + eprintln!("Trace::poll_flush = {:?}", ret); + ret + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + eprintln!("Trace::poll_shutdown = ..."); + let ret = self.project().inner.poll_shutdown(cx); + eprintln!("Trace::poll_shutdown = {:?}", ret); + ret } } From 5536c576af08ec1a7d47ae06ffb9bf825ed751b9 Mon Sep 17 00:00:00 2001 From: Aatif Syed Date: Thu, 31 Oct 2024 12:44:50 +0000 Subject: [PATCH 3/7] test: more tracing --- Cargo.toml | 2 ++ tests/issues.rs | 74 ++++++++++++++++++++----------------------------- 2 files changed, 32 insertions(+), 44 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index fbf3f11..e6422d9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -55,6 +55,8 @@ proptest-derive = "0.5" rand = "0.8.5" tokio = { version = "1.24.2", default-features = false, features = ["io-util", "macros", "rt-multi-thread", "io-std"] } tokio-util = { version = "0.7", default-features = false, features = ["io"] } +tracing = "0.1.40" +tracing-subscriber = "0.3.18" [[test]] name = "brotli" diff --git a/tests/issues.rs b/tests/issues.rs index bc5270a..e040887 100644 --- a/tests/issues.rs +++ b/tests/issues.rs @@ -1,18 +1,26 @@ #![cfg(all(feature = "tokio", feature = "zstd"))] -#![allow(clippy::unusual_byte_groupings)] use std::{ io, pin::Pin, - task::{Context, Poll}, + task::{ready, Context, Poll}, }; use async_compression::tokio::write::ZstdEncoder; use tokio::io::{AsyncWrite, AsyncWriteExt as _}; +use tracing_subscriber::fmt::format::FmtSpan; /// #[tokio::test] async fn issue_246() { + tracing_subscriber::fmt() + .without_time() + .with_ansi(false) + .with_level(false) + .with_test_writer() + .with_target(false) + .with_span_events(FmtSpan::NEW) + .init(); let mut zstd_encoder = Transparent::new(Trace::new(ZstdEncoder::new(DelayedShutdown::default()))); zstd_encoder.shutdown().await.unwrap(); @@ -32,22 +40,18 @@ impl Transparent { } impl AsyncWrite for Transparent { + #[tracing::instrument(name = "Transparent::poll_write", skip_all, ret)] fn poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - eprintln!("Transparent::poll_write = ..."); - let ret = self.project().inner.poll_write(cx, buf); - eprintln!("Transparent::poll_write = {:?}", ret); - ret + self.project().inner.poll_write(cx, buf) } + #[tracing::instrument(name = "Transparent::poll_flush", skip_all, ret)] fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - eprintln!("Transparent::poll_flush = ..."); - let ret = self.project().inner.poll_flush(cx); - eprintln!("Transparent::poll_flush = {:?}", ret); - ret + self.project().inner.poll_flush(cx) } /// To quote the [`AsyncWrite`] docs: @@ -55,16 +59,11 @@ impl AsyncWrite for Transparent { /// > Once this method returns Ready it implies that a flush successfully happened before the shutdown happened. /// > That is, callers don't need to call flush before calling shutdown. /// > They can rely that by calling shutdown any pending buffered data will be written out. + #[tracing::instrument(name = "Transparent::poll_shutdown", skip_all, ret)] fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - eprintln!("Transparent::poll_shutdown = ..."); let mut this = self.project(); - let ret = match this.inner.as_mut().poll_flush(cx) { - Poll::Ready(Ok(())) => this.inner.poll_shutdown(cx), - Poll::Ready(Err(e)) => Poll::Ready(Err(e)), - Poll::Pending => Poll::Pending, - }; - eprintln!("Transparent::poll_shutdown = {:?}", ret); - ret + ready!(this.inner.as_mut().poll_flush(cx))?; + this.inner.poll_shutdown(cx) } } @@ -78,39 +77,33 @@ pin_project_lite::pin_project! { } impl AsyncWrite for DelayedShutdown { + #[tracing::instrument(name = "DelayedShutdown::poll_write", skip_all, ret)] fn poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - eprintln!("DelayedShutdown::poll_write = ..."); let _ = cx; self.project().contents.extend_from_slice(buf); - let ret = Poll::Ready(Ok(buf.len())); - eprintln!("DelayedShutdown::poll_write = {:?}", ret); - ret + Poll::Ready(Ok(buf.len())) } + #[tracing::instrument(name = "DelayedShutdown::poll_flush", skip_all, ret)] fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - eprintln!("DelayedShutdown::poll_flush = ..."); let _ = cx; - let ret = Poll::Ready(Ok(())); - eprintln!("DelayedShutdown::poll_flush = {:?}", ret); - ret + Poll::Ready(Ok(())) } + #[tracing::instrument(name = "DelayedShutdown::poll_shutdown", skip_all, ret)] fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - eprintln!("DelayedShutdown::poll_shutdown = ..."); - let ret = match self.project().num_times_shutdown_called { + match self.project().num_times_shutdown_called { it @ 0 => { *it += 1; cx.waker().wake_by_ref(); Poll::Pending } _ => Poll::Ready(Ok(())), - }; - eprintln!("DelayedShutdown::poll_shutdown = {:?}", ret); - ret + } } } @@ -128,28 +121,21 @@ impl Trace { } impl AsyncWrite for Trace { + #[tracing::instrument(name = "Trace::poll_write", skip_all, ret)] fn poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - eprintln!("Trace::poll_write = ..."); - let ret = self.project().inner.poll_write(cx, buf); - eprintln!("Trace::poll_write = {:?}", ret); - ret + self.project().inner.poll_write(cx, buf) } - + #[tracing::instrument(name = "Trace::poll_flush", skip_all, ret)] fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - eprintln!("Trace::poll_flush = ..."); - let ret = self.project().inner.poll_flush(cx); - eprintln!("Trace::poll_flush = {:?}", ret); - ret + self.project().inner.poll_flush(cx) } + #[tracing::instrument(name = "Trace::poll_shutdown", skip_all, ret)] fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - eprintln!("Trace::poll_shutdown = ..."); - let ret = self.project().inner.poll_shutdown(cx); - eprintln!("Trace::poll_shutdown = {:?}", ret); - ret + self.project().inner.poll_shutdown(cx) } } From 1c6c40711adb6c82c0dc6581da2f7de2bf768796 Mon Sep 17 00:00:00 2001 From: Aatif Syed Date: Thu, 31 Oct 2024 16:30:14 +0000 Subject: [PATCH 4/7] test: doc --- tests/issues.rs | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/tests/issues.rs b/tests/issues.rs index e040887..13408d3 100644 --- a/tests/issues.rs +++ b/tests/issues.rs @@ -10,9 +10,21 @@ use async_compression::tokio::write::ZstdEncoder; use tokio::io::{AsyncWrite, AsyncWriteExt as _}; use tracing_subscriber::fmt::format::FmtSpan; +/// This issue covers our state machine being invalid when using adapters +/// like [`tokio_util::codec`]. +/// +/// After the first [`poll_shutdown`] call, +/// we must expect any number of [`poll_flush`] and [`poll_shutdown`] calls, +/// until [`poll_shutdown`] returns [`Poll::Ready`], +/// according to the documentation on [`AsyncWrite`]. +/// /// -#[tokio::test] -async fn issue_246() { +/// +/// [`tokio_util::codec`](https://docs.rs/tokio-util/latest/tokio_util/codec) +/// [`poll_shutdown`](AsyncWrite::poll_shutdown) +/// [`poll_flush`](AsyncWrite::poll_flush) +#[test] +fn issue_246() { tracing_subscriber::fmt() .without_time() .with_ansi(false) @@ -23,7 +35,7 @@ async fn issue_246() { .init(); let mut zstd_encoder = Transparent::new(Trace::new(ZstdEncoder::new(DelayedShutdown::default()))); - zstd_encoder.shutdown().await.unwrap(); + futures::executor::block_on(zstd_encoder.shutdown()).unwrap(); } pin_project_lite::pin_project! { From e2590606511e43e73b20c7d33cfba33728261679 Mon Sep 17 00:00:00 2001 From: Aatif Syed Date: Thu, 31 Oct 2024 16:37:57 +0000 Subject: [PATCH 5/7] chore: #[should_panic] --- tests/issues.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/issues.rs b/tests/issues.rs index 13408d3..a913a96 100644 --- a/tests/issues.rs +++ b/tests/issues.rs @@ -23,6 +23,7 @@ use tracing_subscriber::fmt::format::FmtSpan; /// [`tokio_util::codec`](https://docs.rs/tokio-util/latest/tokio_util/codec) /// [`poll_shutdown`](AsyncWrite::poll_shutdown) /// [`poll_flush`](AsyncWrite::poll_flush) +#[should_panic = "Flush after shutdown"] // TODO: this should be removed when the bug is fixed #[test] fn issue_246() { tracing_subscriber::fmt() From 8af0baa33053b701389ed526f017315e38df0d7c Mon Sep 17 00:00:00 2001 From: Aatif Syed Date: Fri, 1 Nov 2024 20:25:43 +0000 Subject: [PATCH 6/7] fix: Encoder state machine (#308) --- src/tokio/write/generic/encoder.rs | 151 ++++++++--------------------- tests/issues.rs | 17 ++-- 2 files changed, 50 insertions(+), 118 deletions(-) diff --git a/src/tokio/write/generic/encoder.rs b/src/tokio/write/generic/encoder.rs index f5a83aa..421f064 100644 --- a/src/tokio/write/generic/encoder.rs +++ b/src/tokio/write/generic/encoder.rs @@ -13,20 +13,13 @@ use futures_core::ready; use pin_project_lite::pin_project; use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf}; -#[derive(Debug)] -enum State { - Encoding, - Finishing, - Done, -} - pin_project! { #[derive(Debug)] pub struct Encoder { #[pin] writer: BufWriter, encoder: E, - state: State, + finished: bool } } @@ -35,7 +28,7 @@ impl Encoder { Self { writer: BufWriter::new(writer), encoder, - state: State::Encoding, + finished: false, } } } @@ -62,97 +55,6 @@ impl Encoder { } } -impl Encoder { - fn do_poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - input: &mut PartialBuffer<&[u8]>, - ) -> Poll> { - let mut this = self.project(); - - loop { - let output = ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?; - let mut output = PartialBuffer::new(output); - - *this.state = match this.state { - State::Encoding => { - this.encoder.encode(input, &mut output)?; - State::Encoding - } - - State::Finishing | State::Done => { - return Poll::Ready(Err(io::Error::new( - io::ErrorKind::Other, - "Write after shutdown", - ))) - } - }; - - let produced = output.written().len(); - this.writer.as_mut().produce(produced); - - if input.unwritten().is_empty() { - return Poll::Ready(Ok(())); - } - } - } - - fn do_poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut this = self.project(); - - loop { - let output = ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?; - let mut output = PartialBuffer::new(output); - - let done = match this.state { - State::Encoding => this.encoder.flush(&mut output)?, - - State::Finishing | State::Done => { - return Poll::Ready(Err(io::Error::new( - io::ErrorKind::Other, - "Flush after shutdown", - ))) - } - }; - - let produced = output.written().len(); - this.writer.as_mut().produce(produced); - - if done { - return Poll::Ready(Ok(())); - } - } - } - - fn do_poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut this = self.project(); - - loop { - let output = ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?; - let mut output = PartialBuffer::new(output); - - *this.state = match this.state { - State::Encoding | State::Finishing => { - if this.encoder.finish(&mut output)? { - State::Done - } else { - State::Finishing - } - } - - State::Done => State::Done, - }; - - let produced = output.written().len(); - this.writer.as_mut().produce(produced); - - if let State::Done = this.state { - return Poll::Ready(Ok(())); - } - } - } -} - impl AsyncWrite for Encoder { fn poll_write( self: Pin<&mut Self>, @@ -163,24 +65,55 @@ impl AsyncWrite for Encoder { return Poll::Ready(Ok(0)); } - let mut input = PartialBuffer::new(buf); + let mut this = self.project(); + + let mut encodeme = PartialBuffer::new(buf); - match self.do_poll_write(cx, &mut input)? { - Poll::Pending if input.written().is_empty() => Poll::Pending, - _ => Poll::Ready(Ok(input.written().len())), + loop { + let mut space = + PartialBuffer::new(ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?); + this.encoder.encode(&mut encodeme, &mut space)?; + let bytes_encoded = space.written().len(); + this.writer.as_mut().produce(bytes_encoded); + if encodeme.unwritten().is_empty() { + break; + } } + + Poll::Ready(Ok(encodeme.written().len())) } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - ready!(self.as_mut().do_poll_flush(cx))?; - ready!(self.project().writer.as_mut().poll_flush(cx))?; + let mut this = self.project(); + loop { + let mut space = + PartialBuffer::new(ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?); + let flushed = this.encoder.flush(&mut space)?; + let bytes_encoded = space.written().len(); + this.writer.as_mut().produce(bytes_encoded); + if flushed { + break; + } + } Poll::Ready(Ok(())) } fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - ready!(self.as_mut().do_poll_shutdown(cx))?; - ready!(self.project().writer.as_mut().poll_shutdown(cx))?; - Poll::Ready(Ok(())) + let mut this = self.project(); + if !*this.finished { + loop { + let mut space = + PartialBuffer::new(ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?); + let finished = this.encoder.finish(&mut space)?; + let bytes_encoded = space.written().len(); + this.writer.as_mut().produce(bytes_encoded); + if finished { + *this.finished = true; + break; + } + } + } + this.writer.poll_shutdown(cx) } } diff --git a/tests/issues.rs b/tests/issues.rs index a913a96..3d52cd4 100644 --- a/tests/issues.rs +++ b/tests/issues.rs @@ -23,7 +23,6 @@ use tracing_subscriber::fmt::format::FmtSpan; /// [`tokio_util::codec`](https://docs.rs/tokio-util/latest/tokio_util/codec) /// [`poll_shutdown`](AsyncWrite::poll_shutdown) /// [`poll_flush`](AsyncWrite::poll_flush) -#[should_panic = "Flush after shutdown"] // TODO: this should be removed when the bug is fixed #[test] fn issue_246() { tracing_subscriber::fmt() @@ -34,26 +33,26 @@ fn issue_246() { .with_target(false) .with_span_events(FmtSpan::NEW) .init(); - let mut zstd_encoder = - Transparent::new(Trace::new(ZstdEncoder::new(DelayedShutdown::default()))); + let mut zstd_encoder = Wrapper::new(Trace::new(ZstdEncoder::new(DelayedShutdown::default()))); futures::executor::block_on(zstd_encoder.shutdown()).unwrap(); } pin_project_lite::pin_project! { /// A simple wrapper struct that follows the [`AsyncWrite`] protocol. - struct Transparent { + /// This is a stand-in for combinators like `tokio_util::codec`s + struct Wrapper { #[pin] inner: T } } -impl Transparent { +impl Wrapper { fn new(inner: T) -> Self { Self { inner } } } -impl AsyncWrite for Transparent { - #[tracing::instrument(name = "Transparent::poll_write", skip_all, ret)] +impl AsyncWrite for Wrapper { + #[tracing::instrument(name = "Wrapper::poll_write", skip_all, ret)] fn poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -62,7 +61,7 @@ impl AsyncWrite for Transparent { self.project().inner.poll_write(cx, buf) } - #[tracing::instrument(name = "Transparent::poll_flush", skip_all, ret)] + #[tracing::instrument(name = "Wrapper::poll_flush", skip_all, ret)] fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.project().inner.poll_flush(cx) } @@ -72,7 +71,7 @@ impl AsyncWrite for Transparent { /// > Once this method returns Ready it implies that a flush successfully happened before the shutdown happened. /// > That is, callers don't need to call flush before calling shutdown. /// > They can rely that by calling shutdown any pending buffered data will be written out. - #[tracing::instrument(name = "Transparent::poll_shutdown", skip_all, ret)] + #[tracing::instrument(name = "Wrapper::poll_shutdown", skip_all, ret)] fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let mut this = self.project(); ready!(this.inner.as_mut().poll_flush(cx))?; From 94825ab8b84ba39281c534f1993006f19150cca1 Mon Sep 17 00:00:00 2001 From: Aatif Syed Date: Fri, 1 Nov 2024 20:32:48 +0000 Subject: [PATCH 7/7] refactor: loop -> while in Encoder --- src/tokio/write/generic/encoder.rs | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/src/tokio/write/generic/encoder.rs b/src/tokio/write/generic/encoder.rs index 421f064..c26dc24 100644 --- a/src/tokio/write/generic/encoder.rs +++ b/src/tokio/write/generic/encoder.rs @@ -100,18 +100,12 @@ impl AsyncWrite for Encoder { fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let mut this = self.project(); - if !*this.finished { - loop { - let mut space = - PartialBuffer::new(ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?); - let finished = this.encoder.finish(&mut space)?; - let bytes_encoded = space.written().len(); - this.writer.as_mut().produce(bytes_encoded); - if finished { - *this.finished = true; - break; - } - } + while !*this.finished { + let mut space = + PartialBuffer::new(ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?); + *this.finished = this.encoder.finish(&mut space)?; + let bytes_encoded = space.written().len(); + this.writer.as_mut().produce(bytes_encoded); } this.writer.poll_shutdown(cx) }