Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Inspect io wrappers #5033

Merged
merged 10 commits into from
Sep 28, 2022
122 changes: 122 additions & 0 deletions tokio-util/src/io/inspect.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
use futures_core::ready;
use pin_project_lite::pin_project;
use std::io::{IoSlice, Result};
use std::pin::Pin;
use std::task::{Context, Poll};

use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

pin_project! {
/// An adapter that lets you inspect the data that's being read.
///
/// This is useful for things like hashing data as it's read in.
pub struct InspectReader<R, F> {
#[pin]
reader: R,
f: F,
}
}

impl<R, F> InspectReader<R, F> {
/// Create a new InspectReader, wrapping `reader` and calling `f` for the
/// new data supplied by each read call.
pub fn new(reader: R, f: F) -> InspectReader<R, F>
where
R: AsyncRead,
F: FnMut(&[u8]),
{
InspectReader { reader, f }
}

/// Consumes the `InspectReader`, returning the wrapped reader
pub fn into_inner(self) -> R {
self.reader
}
}

impl<R: AsyncRead, F: FnMut(&[u8])> AsyncRead for InspectReader<R, F> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<Result<()>> {
let me = self.project();
let filled_length = buf.filled().len();
ready!(me.reader.poll_read(cx, buf))?;
(me.f)(&buf.filled()[filled_length..]);
Poll::Ready(Ok(()))
}
}

pin_project! {
/// An adapter that lets you inspect the data that's being written.
///
/// This is useful for things like hashing data as it's written out.
pub struct InspectWriter<W, F> {
#[pin]
writer: W,
f: F,
}
}

impl<W, F> InspectWriter<W, F> {
/// Create a new InspectWriter, wrapping `write` and calling `f` for the
/// data successfully written by each write call.
pub fn new(writer: W, f: F) -> InspectWriter<W, F>
where
W: AsyncWrite,
F: FnMut(&[u8]),
{
InspectWriter { writer, f }
}

/// Consumes the `InspectWriter`, returning the wrapped writer
pub fn into_inner(self) -> W {
self.writer
}
}

impl<W: AsyncWrite, F: FnMut(&[u8])> AsyncWrite for InspectWriter<W, F> {
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
let me = self.project();
let res = me.writer.poll_write(cx, buf);
if let Poll::Ready(Ok(count)) = res {
(me.f)(&buf[..count]);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if count is zero?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In current code, f gets passed an empty slice, and is expected to handle this - which is fine for hashing and teeing uses,

I would prefer to document this, but can put a guard in place to stop you getting an empty slice if you'd prefer that.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems to me that we should avoid empty slices.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will do so for both writes and reads - the user should know about EOF on read, or ENOSPC on write from the "main" write or read process, and thus be able to handle that without needing the inspection callback to be called.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Passing an empty slice on EOF read seems ok, but otherwise I would not expect the closure to ever get passed an empty slice.

res
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
let me = self.project();
me.writer.poll_flush(cx)
}

fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
let me = self.project();
me.writer.poll_shutdown(cx)
}

fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[IoSlice<'_>],
) -> Poll<Result<usize>> {
let me = self.project();
let res = me.writer.poll_write_vectored(cx, bufs);
if let Poll::Ready(Ok(mut count)) = res {
for buf in bufs {
let size = count.min(buf.len());
(me.f)(&buf[..size]);
count -= size;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if buf is empty?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then f gets called with an empty slice, since count.min(buf.len()) should be zero, and &buf[..0] is an empty slice.

I only see two cases in which this can happen, and in both cases I think f should be able to handle it (reasoning as above):

  1. The wrapped AsyncWrite genuinely returned a zero byte write. In this case, it is claiming to successfully write zero bytes, which is allowed in the contract (implying that the underlying object can't accept more data ever again).
  2. The caller supplied one or more empty buffers to poll_write_vectored. In this case, you're being told that the supplied empty buffer was indeed written, which is accurate (if not particularly helpful).

I will add a documentation note to warn users that f must cope with empty buffers to both wrappers.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Avoiding empty slices instead.

if count == 0 {
break;
}
}
}
res
}

fn is_write_vectored(&self) -> bool {
self.writer.is_write_vectored()
}
}
3 changes: 3 additions & 0 deletions tokio-util/src/io/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,17 @@
//! [`Body`]: https://docs.rs/hyper/0.13/hyper/struct.Body.html
//! [`AsyncRead`]: tokio::io::AsyncRead

mod inspect;
mod read_buf;
mod reader_stream;
mod stream_reader;

cfg_io_util! {
mod sync_bridge;
pub use self::sync_bridge::SyncIoBridge;
}

pub use self::inspect::{InspectReader, InspectWriter};
pub use self::read_buf::read_buf;
pub use self::reader_stream::ReaderStream;
pub use self::stream_reader::StreamReader;
Expand Down
194 changes: 194 additions & 0 deletions tokio-util/tests/io_inspect.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
use futures::future::poll_fn;
use std::{
io::IoSlice,
pin::Pin,
task::{Context, Poll},
};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
use tokio_util::io::{InspectReader, InspectWriter};

/// An AsyncRead implementation that works byte-by-byte, to catch out callers
/// who don't allow for `buf` being part-filled before the call
struct SmallReader {
contents: Vec<u8>,
}

impl Unpin for SmallReader {}

impl AsyncRead for SmallReader {
fn poll_read(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
if let Some(byte) = self.contents.pop() {
buf.put_slice(&[byte])
}
Poll::Ready(Ok(()))
}
}

#[tokio::test]
async fn read_tee() {
let contents = b"This could be really long, you know".to_vec();
let reader = SmallReader {
contents: contents.clone(),
};
let mut altout: Vec<u8> = Vec::new();
let mut teeout = Vec::new();
{
let mut tee = InspectReader::new(reader, |bytes| altout.extend(bytes));
tee.read_to_end(&mut teeout).await.unwrap();
}
assert_eq!(teeout, altout);
assert_eq!(altout.len(), contents.len());
}

/// An AsyncWrite implementation that works byte-by-byte for poll_write, and
/// that reads the whole of the first buffer plus one byte from the second in
/// poll_write_vectored.
///
/// This is designed to catch bugs in handling partially written buffers
#[derive(Debug)]
struct SmallWriter {
contents: Vec<u8>,
}

impl Unpin for SmallWriter {}

impl AsyncWrite for SmallWriter {
fn poll_write(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
// Just write one byte at a time
if buf.is_empty() {
return Poll::Ready(Ok(0));
}
self.contents.push(buf[0]);
Poll::Ready(Ok(1))
}

fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
Poll::Ready(Ok(()))
}

fn poll_shutdown(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
Poll::Ready(Ok(()))
}

fn poll_write_vectored(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
bufs: &[IoSlice<'_>],
) -> Poll<Result<usize, std::io::Error>> {
// Write all of the first buffer, then one byte from the second buffer
// This should trip up anything that doesn't correctly handle multiple
// buffers.
if bufs.is_empty() {
return Poll::Ready(Ok(0));
}
let mut written_len = bufs[0].len();
self.contents.extend_from_slice(&bufs[0]);

if bufs.len() > 1 {
let buf = bufs[1];
if !buf.is_empty() {
written_len += 1;
self.contents.push(buf[0]);
}
}
Poll::Ready(Ok(written_len))
}

fn is_write_vectored(&self) -> bool {
true
}
}

#[tokio::test]
async fn write_tee() {
let mut altout: Vec<u8> = Vec::new();
let mut writeout = SmallWriter {
contents: Vec::new(),
};
{
let mut tee = InspectWriter::new(&mut writeout, |bytes| altout.extend(bytes));
tee.write_all(b"A testing string, very testing")
.await
.unwrap();
}
assert_eq!(altout, writeout.contents);
}

// This is inefficient, but works well enough for test use.
// If you want something similar for real code, you'll want to avoid all the
// fun of manipulating `bufs` - ideally, by the time you read this,
// IoSlice::advance_slices will be stable, and you can use that.
async fn write_all_vectored<W: AsyncWrite + Unpin>(
mut writer: W,
mut bufs: Vec<Vec<u8>>,
) -> Result<usize, std::io::Error> {
let mut res = 0;
while !bufs.is_empty() {
let mut written = poll_fn(|cx| {
let bufs: Vec<IoSlice> = bufs.iter().map(|v| IoSlice::new(v)).collect();
Pin::new(&mut writer).poll_write_vectored(cx, &bufs)
})
.await?;
res += written;
while written > 0 {
let buf_len = bufs[0].len();
if buf_len <= written {
bufs.remove(0);
written -= buf_len;
} else {
let buf = &mut bufs[0];
let drain_len = written.min(buf.len());
buf.drain(..drain_len);
written -= drain_len;
}
}
}
Ok(res)
}

#[tokio::test]
async fn write_tee_vectored() {
let mut altout: Vec<u8> = Vec::new();
let mut writeout = SmallWriter {
contents: Vec::new(),
};
let original = b"A very long string split up";
let bufs: Vec<Vec<u8>> = original
.split(|b| b.is_ascii_whitespace())
.map(Vec::from)
.collect();
assert!(bufs.len() > 1);
let expected: Vec<u8> = {
let mut out = Vec::new();
for item in &bufs {
out.extend_from_slice(item)
}
out
};
{
let mut bufcount = 0;
let tee = InspectWriter::new(&mut writeout, |bytes| {
bufcount += 1;
altout.extend(bytes)
});

assert!(tee.is_write_vectored());

write_all_vectored(tee, bufs.clone()).await.unwrap();

assert!(bufcount >= bufs.len());
}
assert_eq!(altout, writeout.contents);
assert_eq!(writeout.contents, expected);
}