Skip to content

Commit

Permalink
process: support ChildStdin::poll_write_vectored on unix (#5216)
Browse files Browse the repository at this point in the history
  • Loading branch information
koivunej authored Nov 21, 2022
1 parent 6a2cd9a commit f15d14e
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 1 deletion.
1 change: 1 addition & 0 deletions tests-integration/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,4 @@ tokio = { path = "../tokio" }
tokio-test = { path = "../tokio-test", optional = true }
doc-comment = "0.3.1"
futures = { version = "0.3.0", features = ["async-await"] }
bytes = "1.0.0"
51 changes: 51 additions & 0 deletions tests-integration/tests/process_stdio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,54 @@ async fn pipe_from_one_command_to_another() {
assert!(second_status.expect("second status").success());
assert!(third_status.expect("third status").success());
}

#[tokio::test]
async fn vectored_writes() {
use bytes::{Buf, Bytes};
use std::{io::IoSlice, pin::Pin};
use tokio::io::AsyncWrite;

let mut cat = cat().spawn().unwrap();
let mut stdin = cat.stdin.take().unwrap();
let are_writes_vectored = stdin.is_write_vectored();
let mut stdout = cat.stdout.take().unwrap();

let write = async {
let mut input = Bytes::from_static(b"hello\n").chain(Bytes::from_static(b"world!\n"));
let mut writes_completed = 0;

futures::future::poll_fn(|cx| loop {
let mut slices = [IoSlice::new(&[]); 2];
let vectored = input.chunks_vectored(&mut slices);
if vectored == 0 {
return std::task::Poll::Ready(std::io::Result::Ok(()));
}
let n = futures::ready!(Pin::new(&mut stdin).poll_write_vectored(cx, &slices))?;
writes_completed += 1;
input.advance(n);
})
.await?;

drop(stdin);

std::io::Result::Ok(writes_completed)
};

let read = async {
let mut buffer = Vec::with_capacity(6 + 7);
stdout.read_to_end(&mut buffer).await?;
std::io::Result::Ok(buffer)
};

let (write, read, status) = future::join3(write, read, cat.wait()).await;

assert!(status.unwrap().success());

let writes_completed = write.unwrap();
// on unix our small payload should always fit in whatever default sized pipe with a single
// syscall. if multiple are used, then the forwarding does not work, or we are on a platform
// for which the `std` does not support vectored writes.
assert_eq!(writes_completed == 1, are_writes_vectored);

assert_eq!(&read.unwrap(), b"hello\nworld!\n");
}
2 changes: 1 addition & 1 deletion tokio/src/io/poll_evented.rs
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ feature! {
}
}

#[cfg(feature = "net")]
#[cfg(any(feature = "net", feature = "process"))]
pub(crate) fn poll_write_vectored<'a>(
&'a self,
cx: &mut Context<'_>,
Expand Down
12 changes: 12 additions & 0 deletions tokio/src/process/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1329,6 +1329,18 @@ impl AsyncWrite for ChildStdin {
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.inner).poll_shutdown(cx)
}

fn poll_write_vectored(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[io::IoSlice<'_>],
) -> Poll<Result<usize, io::Error>> {
Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
}

fn is_write_vectored(&self) -> bool {
self.inner.is_write_vectored()
}
}

impl AsyncRead for ChildStdout {
Expand Down
16 changes: 16 additions & 0 deletions tokio/src/process/unix/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,10 @@ impl<'a> io::Write for &'a Pipe {
fn flush(&mut self) -> io::Result<()> {
(&self.fd).flush()
}

fn write_vectored(&mut self, bufs: &[io::IoSlice<'_>]) -> io::Result<usize> {
(&self.fd).write_vectored(bufs)
}
}

impl AsRawFd for Pipe {
Expand Down Expand Up @@ -258,6 +262,18 @@ impl AsyncWrite for ChildStdio {
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}

fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[io::IoSlice<'_>],
) -> Poll<Result<usize, io::Error>> {
self.inner.poll_write_vectored(cx, bufs)
}

fn is_write_vectored(&self) -> bool {
true
}
}

impl AsyncRead for ChildStdio {
Expand Down

0 comments on commit f15d14e

Please sign in to comment.