Skip to content

Commit

Permalink
io: add a copy_bidirectional utility (#3572)
Browse files Browse the repository at this point in the history
  • Loading branch information
conblem authored Apr 12, 2021
1 parent 08f1b67 commit adad8fc
Show file tree
Hide file tree
Showing 5 changed files with 329 additions and 51 deletions.
2 changes: 1 addition & 1 deletion tokio/src/io/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ cfg_io_util! {
pub(crate) mod seek;
pub(crate) mod util;
pub use util::{
copy, copy_buf, duplex, empty, repeat, sink, AsyncBufReadExt, AsyncReadExt, AsyncSeekExt, AsyncWriteExt,
copy, copy_bidirectional, copy_buf, duplex, empty, repeat, sink, AsyncBufReadExt, AsyncReadExt, AsyncSeekExt, AsyncWriteExt,
BufReader, BufStream, BufWriter, DuplexStream, Empty, Lines, Repeat, Sink, Split, Take,
};
}
Expand Down
128 changes: 78 additions & 50 deletions tokio/src/io/util/copy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,85 @@ use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};

#[derive(Debug)]
pub(super) struct CopyBuffer {
read_done: bool,
pos: usize,
cap: usize,
amt: u64,
buf: Box<[u8]>,
}

impl CopyBuffer {
pub(super) fn new() -> Self {
Self {
read_done: false,
pos: 0,
cap: 0,
amt: 0,
buf: vec![0; 2048].into_boxed_slice(),
}
}

pub(super) fn poll_copy<R, W>(
&mut self,
cx: &mut Context<'_>,
mut reader: Pin<&mut R>,
mut writer: Pin<&mut W>,
) -> Poll<io::Result<u64>>
where
R: AsyncRead + ?Sized,
W: AsyncWrite + ?Sized,
{
loop {
// If our buffer is empty, then we need to read some data to
// continue.
if self.pos == self.cap && !self.read_done {
let me = &mut *self;
let mut buf = ReadBuf::new(&mut me.buf);
ready!(reader.as_mut().poll_read(cx, &mut buf))?;
let n = buf.filled().len();
if n == 0 {
self.read_done = true;
} else {
self.pos = 0;
self.cap = n;
}
}

// If our buffer has some data, let's write it out!
while self.pos < self.cap {
let me = &mut *self;
let i = ready!(writer.as_mut().poll_write(cx, &me.buf[me.pos..me.cap]))?;
if i == 0 {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::WriteZero,
"write zero byte into writer",
)));
} else {
self.pos += i;
self.amt += i as u64;
}
}

// If we've written all the data and we've seen EOF, flush out the
// data and finish the transfer.
if self.pos == self.cap && self.read_done {
ready!(writer.as_mut().poll_flush(cx))?;
return Poll::Ready(Ok(self.amt));
}
}
}
}

/// A future that asynchronously copies the entire contents of a reader into a
/// writer.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
struct Copy<'a, R: ?Sized, W: ?Sized> {
reader: &'a mut R,
read_done: bool,
writer: &'a mut W,
pos: usize,
cap: usize,
amt: u64,
buf: Box<[u8]>,
buf: CopyBuffer,
}

cfg_io_util! {
Expand All @@ -35,8 +102,8 @@ cfg_io_util! {
///
/// # Errors
///
/// The returned future will finish with an error will return an error
/// immediately if any call to `poll_read` or `poll_write` returns an error.
/// The returned future will return an error immediately if any call to
/// `poll_read` or `poll_write` returns an error.
///
/// # Examples
///
Expand All @@ -60,12 +127,8 @@ cfg_io_util! {
{
Copy {
reader,
read_done: false,
writer,
amt: 0,
pos: 0,
cap: 0,
buf: vec![0; 2048].into_boxed_slice(),
buf: CopyBuffer::new()
}.await
}
}
Expand All @@ -78,44 +141,9 @@ where
type Output = io::Result<u64>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
loop {
// If our buffer is empty, then we need to read some data to
// continue.
if self.pos == self.cap && !self.read_done {
let me = &mut *self;
let mut buf = ReadBuf::new(&mut me.buf);
ready!(Pin::new(&mut *me.reader).poll_read(cx, &mut buf))?;
let n = buf.filled().len();
if n == 0 {
self.read_done = true;
} else {
self.pos = 0;
self.cap = n;
}
}
let me = &mut *self;

// If our buffer has some data, let's write it out!
while self.pos < self.cap {
let me = &mut *self;
let i = ready!(Pin::new(&mut *me.writer).poll_write(cx, &me.buf[me.pos..me.cap]))?;
if i == 0 {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::WriteZero,
"write zero byte into writer",
)));
} else {
self.pos += i;
self.amt += i as u64;
}
}

// If we've written all the data and we've seen EOF, flush out the
// data and finish the transfer.
if self.pos == self.cap && self.read_done {
let me = &mut *self;
ready!(Pin::new(&mut *me.writer).poll_flush(cx))?;
return Poll::Ready(Ok(self.amt));
}
}
me.buf
.poll_copy(cx, Pin::new(&mut *me.reader), Pin::new(&mut *me.writer))
}
}
119 changes: 119 additions & 0 deletions tokio/src/io/util/copy_bidirectional.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
use super::copy::CopyBuffer;

use crate::io::{AsyncRead, AsyncWrite};

use std::future::Future;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};

enum TransferState {
Running(CopyBuffer),
ShuttingDown(u64),
Done(u64),
}

struct CopyBidirectional<'a, A: ?Sized, B: ?Sized> {
a: &'a mut A,
b: &'a mut B,
a_to_b: TransferState,
b_to_a: TransferState,
}

fn transfer_one_direction<A, B>(
cx: &mut Context<'_>,
state: &mut TransferState,
r: &mut A,
w: &mut B,
) -> Poll<io::Result<u64>>
where
A: AsyncRead + AsyncWrite + Unpin + ?Sized,
B: AsyncRead + AsyncWrite + Unpin + ?Sized,
{
let mut r = Pin::new(r);
let mut w = Pin::new(w);

loop {
match state {
TransferState::Running(buf) => {
let count = ready!(buf.poll_copy(cx, r.as_mut(), w.as_mut()))?;
*state = TransferState::ShuttingDown(count);
}
TransferState::ShuttingDown(count) => {
ready!(w.as_mut().poll_shutdown(cx))?;

*state = TransferState::Done(*count);
}
TransferState::Done(count) => return Poll::Ready(Ok(*count)),
}
}
}

impl<'a, A, B> Future for CopyBidirectional<'a, A, B>
where
A: AsyncRead + AsyncWrite + Unpin + ?Sized,
B: AsyncRead + AsyncWrite + Unpin + ?Sized,
{
type Output = io::Result<(u64, u64)>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
// Unpack self into mut refs to each field to avoid borrow check issues.
let CopyBidirectional {
a,
b,
a_to_b,
b_to_a,
} = &mut *self;

let a_to_b = transfer_one_direction(cx, a_to_b, &mut *a, &mut *b)?;
let b_to_a = transfer_one_direction(cx, b_to_a, &mut *b, &mut *a)?;

// It is not a problem if ready! returns early because transfer_one_direction for the
// other direction will keep returning TransferState::Done(count) in future calls to poll
let a_to_b = ready!(a_to_b);
let b_to_a = ready!(b_to_a);

Poll::Ready(Ok((a_to_b, b_to_a)))
}
}

/// Copies data in both directions between `a` and `b`.
///
/// This function returns a future that will read from both streams,
/// writing any data read to the opposing stream.
/// This happens in both directions concurrently.
///
/// If an EOF is observed on one stream, [`shutdown()`] will be invoked on
/// the other, and reading from that stream will stop. Copying of data in
/// the other direction will continue.
///
/// The future will complete successfully once both directions of communication has been shut down.
/// A direction is shut down when the reader reports EOF,
/// at which point [`shutdown()`] is called on the corresponding writer. When finished,
/// it will return a tuple of the number of bytes copied from a to b
/// and the number of bytes copied from b to a, in that order.
///
/// [`shutdown()`]: crate::io::AsyncWriteExt::shutdown
///
/// # Errors
///
/// The future will immediately return an error if any IO operation on `a`
/// or `b` returns an error. Some data read from either stream may be lost (not
/// written to the other stream) in this case.
///
/// # Return value
///
/// Returns a tuple of bytes copied `a` to `b` and bytes copied `b` to `a`.
pub async fn copy_bidirectional<A, B>(a: &mut A, b: &mut B) -> Result<(u64, u64), std::io::Error>
where
A: AsyncRead + AsyncWrite + Unpin + ?Sized,
B: AsyncRead + AsyncWrite + Unpin + ?Sized,
{
CopyBidirectional {
a,
b,
a_to_b: TransferState::Running(CopyBuffer::new()),
b_to_a: TransferState::Running(CopyBuffer::new()),
}
.await
}
3 changes: 3 additions & 0 deletions tokio/src/io/util/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ cfg_io_util! {
mod copy;
pub use copy::copy;

mod copy_bidirectional;
pub use copy_bidirectional::copy_bidirectional;

mod copy_buf;
pub use copy_buf::copy_buf;

Expand Down
Loading

0 comments on commit adad8fc

Please sign in to comment.