diff --git a/tokio-util/src/codec/framed.rs b/tokio-util/src/codec/framed.rs index 3f7e4207d31..d89b8b6dc34 100644 --- a/tokio-util/src/codec/framed.rs +++ b/tokio-util/src/codec/framed.rs @@ -204,6 +204,26 @@ impl Framed { &mut self.inner.codec } + /// Maps the codec `U` to `C`, preserving the read and write buffers + /// wrapped by `Framed`. + /// + /// Note that care should be taken to not tamper with the underlying codec + /// as it may corrupt the stream of frames otherwise being worked with. + pub fn map_codec(self, map: F) -> Framed + where + F: FnOnce(U) -> C, + { + // This could be potentially simplified once rust-lang/rust#86555 hits stable + let parts = self.into_parts(); + Framed::from_parts(FramedParts { + io: parts.io, + codec: map(parts.codec), + read_buf: parts.read_buf, + write_buf: parts.write_buf, + _priv: (), + }) + } + /// Returns a mutable reference to the underlying codec wrapped by /// `Framed`. /// diff --git a/tokio-util/src/codec/framed_read.rs b/tokio-util/src/codec/framed_read.rs index d6e34dbafc4..184c567b498 100644 --- a/tokio-util/src/codec/framed_read.rs +++ b/tokio-util/src/codec/framed_read.rs @@ -108,6 +108,27 @@ impl FramedRead { &mut self.inner.codec } + /// Maps the decoder `D` to `C`, preserving the read buffer + /// wrapped by `Framed`. + pub fn map_decoder(self, map: F) -> FramedRead + where + F: FnOnce(D) -> C, + { + // This could be potentially simplified once rust-lang/rust#86555 hits stable + let FramedImpl { + inner, + state, + codec, + } = self.inner; + FramedRead { + inner: FramedImpl { + inner, + state, + codec: map(codec), + }, + } + } + /// Returns a mutable reference to the underlying decoder. pub fn decoder_pin_mut(self: Pin<&mut Self>) -> &mut D { self.project().inner.project().codec diff --git a/tokio-util/src/codec/framed_write.rs b/tokio-util/src/codec/framed_write.rs index b827d9736ac..aa4cec98201 100644 --- a/tokio-util/src/codec/framed_write.rs +++ b/tokio-util/src/codec/framed_write.rs @@ -88,6 +88,27 @@ impl FramedWrite { &mut self.inner.codec } + /// Maps the encoder `E` to `C`, preserving the write buffer + /// wrapped by `Framed`. + pub fn map_encoder(self, map: F) -> FramedWrite + where + F: FnOnce(E) -> C, + { + // This could be potentially simplified once rust-lang/rust#86555 hits stable + let FramedImpl { + inner, + state, + codec, + } = self.inner; + FramedWrite { + inner: FramedImpl { + inner, + state, + codec: map(codec), + }, + } + } + /// Returns a mutable reference to the underlying encoder. pub fn encoder_pin_mut(self: Pin<&mut Self>) -> &mut E { self.project().inner.project().codec diff --git a/tokio-util/tests/framed.rs b/tokio-util/tests/framed.rs index 54d78d9d2c0..ec8cdf00d09 100644 --- a/tokio-util/tests/framed.rs +++ b/tokio-util/tests/framed.rs @@ -12,7 +12,10 @@ use std::task::{Context, Poll}; const INITIAL_CAPACITY: usize = 8 * 1024; /// Encode and decode u32 values. -struct U32Codec; +#[derive(Default)] +struct U32Codec { + read_bytes: usize, +} impl Decoder for U32Codec { type Item = u32; @@ -24,6 +27,7 @@ impl Decoder for U32Codec { } let n = buf.split_to(4).get_u32(); + self.read_bytes += 4; Ok(Some(n)) } } @@ -39,6 +43,38 @@ impl Encoder for U32Codec { } } +/// Encode and decode u64 values. +#[derive(Default)] +struct U64Codec { + read_bytes: usize, +} + +impl Decoder for U64Codec { + type Item = u64; + type Error = io::Error; + + fn decode(&mut self, buf: &mut BytesMut) -> io::Result> { + if buf.len() < 8 { + return Ok(None); + } + + let n = buf.split_to(8).get_u64(); + self.read_bytes += 8; + Ok(Some(n)) + } +} + +impl Encoder for U64Codec { + type Error = io::Error; + + fn encode(&mut self, item: u64, dst: &mut BytesMut) -> io::Result<()> { + // Reserve space + dst.reserve(8); + dst.put_u64(item); + Ok(()) + } +} + /// This value should never be used struct DontReadIntoThis; @@ -63,18 +99,39 @@ impl tokio::io::AsyncRead for DontReadIntoThis { #[tokio::test] async fn can_read_from_existing_buf() { - let mut parts = FramedParts::new(DontReadIntoThis, U32Codec); + let mut parts = FramedParts::new(DontReadIntoThis, U32Codec::default()); parts.read_buf = BytesMut::from(&[0, 0, 0, 42][..]); let mut framed = Framed::from_parts(parts); let num = assert_ok!(framed.next().await.unwrap()); assert_eq!(num, 42); + assert_eq!(framed.codec().read_bytes, 4); +} + +#[tokio::test] +async fn can_read_from_existing_buf_after_codec_changed() { + let mut parts = FramedParts::new(DontReadIntoThis, U32Codec::default()); + parts.read_buf = BytesMut::from(&[0, 0, 0, 42, 0, 0, 0, 0, 0, 0, 0, 84][..]); + + let mut framed = Framed::from_parts(parts); + let num = assert_ok!(framed.next().await.unwrap()); + + assert_eq!(num, 42); + assert_eq!(framed.codec().read_bytes, 4); + + let mut framed = framed.map_codec(|codec| U64Codec { + read_bytes: codec.read_bytes, + }); + let num = assert_ok!(framed.next().await.unwrap()); + + assert_eq!(num, 84); + assert_eq!(framed.codec().read_bytes, 12); } #[test] fn external_buf_grows_to_init() { - let mut parts = FramedParts::new(DontReadIntoThis, U32Codec); + let mut parts = FramedParts::new(DontReadIntoThis, U32Codec::default()); parts.read_buf = BytesMut::from(&[0, 0, 0, 42][..]); let framed = Framed::from_parts(parts); @@ -85,7 +142,7 @@ fn external_buf_grows_to_init() { #[test] fn external_buf_does_not_shrink() { - let mut parts = FramedParts::new(DontReadIntoThis, U32Codec); + let mut parts = FramedParts::new(DontReadIntoThis, U32Codec::default()); parts.read_buf = BytesMut::from(&vec![0; INITIAL_CAPACITY * 2][..]); let framed = Framed::from_parts(parts); diff --git a/tokio-util/tests/framed_read.rs b/tokio-util/tests/framed_read.rs index 930a631d5d0..2a9e27e22f5 100644 --- a/tokio-util/tests/framed_read.rs +++ b/tokio-util/tests/framed_read.rs @@ -50,6 +50,22 @@ impl Decoder for U32Decoder { } } +struct U64Decoder; + +impl Decoder for U64Decoder { + type Item = u64; + type Error = io::Error; + + fn decode(&mut self, buf: &mut BytesMut) -> io::Result> { + if buf.len() < 8 { + return Ok(None); + } + + let n = buf.split_to(8).get_u64(); + Ok(Some(n)) + } +} + #[test] fn read_multi_frame_in_packet() { let mut task = task::spawn(()); @@ -84,6 +100,24 @@ fn read_multi_frame_across_packets() { }); } +#[test] +fn read_multi_frame_in_packet_after_codec_changed() { + let mut task = task::spawn(()); + let mock = mock! { + Ok(b"\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x08".to_vec()), + }; + let mut framed = FramedRead::new(mock, U32Decoder); + + task.enter(|cx, _| { + assert_read!(pin!(framed).poll_next(cx), 0x04); + + let mut framed = framed.map_decoder(|_| U64Decoder); + assert_read!(pin!(framed).poll_next(cx), 0x08); + + assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none()); + }); +} + #[test] fn read_not_ready() { let mut task = task::spawn(()); diff --git a/tokio-util/tests/framed_write.rs b/tokio-util/tests/framed_write.rs index 9ac6c1d11d4..259d9b0c9f3 100644 --- a/tokio-util/tests/framed_write.rs +++ b/tokio-util/tests/framed_write.rs @@ -39,6 +39,19 @@ impl Encoder for U32Encoder { } } +struct U64Encoder; + +impl Encoder for U64Encoder { + type Error = io::Error; + + fn encode(&mut self, item: u64, dst: &mut BytesMut) -> io::Result<()> { + // Reserve space + dst.reserve(8); + dst.put_u64(item); + Ok(()) + } +} + #[test] fn write_multi_frame_in_packet() { let mut task = task::spawn(()); @@ -65,6 +78,32 @@ fn write_multi_frame_in_packet() { }); } +#[test] +fn write_multi_frame_after_codec_changed() { + let mut task = task::spawn(()); + let mock = mock! { + Ok(b"\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x08".to_vec()), + }; + let mut framed = FramedWrite::new(mock, U32Encoder); + + task.enter(|cx, _| { + assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok()); + assert!(pin!(framed).start_send(0x04).is_ok()); + + let mut framed = framed.map_encoder(|_| U64Encoder); + assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok()); + assert!(pin!(framed).start_send(0x08).is_ok()); + + // Nothing written yet + assert_eq!(1, framed.get_ref().calls.len()); + + // Flush the writes + assert!(assert_ready!(pin!(framed).poll_flush(cx)).is_ok()); + + assert_eq!(0, framed.get_ref().calls.len()); + }); +} + #[test] fn write_hits_backpressure() { const ITER: usize = 2 * 1024;