Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make copy continue filling the buffer when writer stalls #5066

Merged
merged 7 commits into from
Oct 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions benches/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ edition = "2018"
[dependencies]
tokio = { version = "1.5.0", path = "../tokio", features = ["full"] }
bencher = "0.1.5"
rand = "0.8"
rand_chacha = "0.3"

[dev-dependencies]
tokio-util = { version = "0.7.0", path = "../tokio-util", features = ["full"] }
Expand Down Expand Up @@ -50,3 +52,8 @@ harness = false
name = "fs"
path = "fs.rs"
harness = false

[[bench]]
name = "copy"
path = "copy.rs"
harness = false
238 changes: 238 additions & 0 deletions benches/copy.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
use bencher::{benchmark_group, benchmark_main, Bencher};

use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha20Rng;

use tokio::io::{copy, repeat, AsyncRead, AsyncReadExt, AsyncWrite};
use tokio::time::{interval, Interval, MissedTickBehavior};

use std::task::Poll;
use std::time::Duration;

const KILO: usize = 1024;

// Tunable parameters if you want to change this benchmark. If reader and writer
// are matched in kilobytes per second, then this only exposes buffering to the
// benchmark.
const RNG_SEED: u64 = 0;
// How much data to copy in a single benchmark run
const SOURCE_SIZE: u64 = 256 * KILO as u64;
// Read side provides CHUNK_SIZE every READ_SERVICE_PERIOD. If it's not called
// frequently, it'll burst to catch up (representing OS buffers draining)
const CHUNK_SIZE: usize = 2 * KILO;
const READ_SERVICE_PERIOD: Duration = Duration::from_millis(1);
// Write side buffers up to WRITE_BUFFER, and flushes to disk every
// WRITE_SERVICE_PERIOD.
const WRITE_BUFFER: usize = 40 * KILO;
const WRITE_SERVICE_PERIOD: Duration = Duration::from_millis(20);
// How likely you are to have to wait for previously written data to be flushed
// because another writer claimed the buffer space
const PROBABILITY_FLUSH_WAIT: f64 = 0.1;

/// A slow writer that aims to simulate HDD behaviour under heavy load.
///
/// There is a limited buffer, which is fully drained on the next write after
/// a time limit is reached. Flush waits for the time limit to be reached
/// and then drains the buffer.
///
/// At random, the HDD will stall writers while it flushes out all buffers. If
/// this happens to you, you will be unable to write until the next time the
/// buffer is drained.
struct SlowHddWriter {
service_intervals: Interval,
blocking_rng: ChaCha20Rng,
buffer_size: usize,
buffer_used: usize,
}

impl SlowHddWriter {
fn new(service_interval: Duration, buffer_size: usize) -> Self {
let blocking_rng = ChaCha20Rng::seed_from_u64(RNG_SEED);
let mut service_intervals = interval(service_interval);
service_intervals.set_missed_tick_behavior(MissedTickBehavior::Delay);
Self {
service_intervals,
blocking_rng,
buffer_size,
buffer_used: 0,
}
}

fn service_write(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
// If we hit a service interval, the buffer can be cleared
let res = self.service_intervals.poll_tick(cx).map(|_| Ok(()));
if let Poll::Ready(_) = res {
self.buffer_used = 0;
}
res
}

fn write_bytes(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
writeable: usize,
) -> std::task::Poll<Result<usize, std::io::Error>> {
let service_res = self.as_mut().service_write(cx);

if service_res.is_pending() && self.blocking_rng.gen_bool(PROBABILITY_FLUSH_WAIT) {
return Poll::Pending;
}
let available = self.buffer_size - self.buffer_used;

if available == 0 {
assert!(service_res.is_pending());
Poll::Pending
} else {
let written = available.min(writeable);
self.buffer_used += written;
Poll::Ready(Ok(written))
}
}
}

impl Unpin for SlowHddWriter {}

impl AsyncWrite for SlowHddWriter {
fn poll_write(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<Result<usize, std::io::Error>> {
self.write_bytes(cx, buf.len())
}

fn poll_flush(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
self.service_write(cx)
}

fn poll_shutdown(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
self.service_write(cx)
}

fn poll_write_vectored(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> std::task::Poll<Result<usize, std::io::Error>> {
let writeable = bufs.into_iter().fold(0, |acc, buf| acc + buf.len());
self.write_bytes(cx, writeable)
}

fn is_write_vectored(&self) -> bool {
true
}
}

/// A reader that limits the maximum chunk it'll give you back
///
/// Simulates something reading from a slow link - you get one chunk per call,
/// and you are offered chunks on a schedule
struct ChunkReader {
data: Vec<u8>,
service_intervals: Interval,
}

impl ChunkReader {
fn new(chunk_size: usize, service_interval: Duration) -> Self {
let mut service_intervals = interval(service_interval);
service_intervals.set_missed_tick_behavior(MissedTickBehavior::Burst);
let data: Vec<u8> = std::iter::repeat(0).take(chunk_size).collect();
Self {
data,
service_intervals,
}
}
}

impl AsyncRead for ChunkReader {
fn poll_read(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
if self.service_intervals.poll_tick(cx).is_pending() {
return Poll::Pending;
}
buf.put_slice(&self.data[..buf.remaining().min(self.data.len())]);
Poll::Ready(Ok(()))
}
}

fn rt() -> tokio::runtime::Runtime {
tokio::runtime::Builder::new_current_thread()
.enable_time()
.build()
.unwrap()
}

fn copy_mem_to_mem(b: &mut Bencher) {
let rt = rt();

b.iter(|| {
let task = || async {
let mut source = repeat(0).take(SOURCE_SIZE);
let mut dest = Vec::new();
copy(&mut source, &mut dest).await.unwrap();
};

rt.block_on(task());
})
}

fn copy_mem_to_slow_hdd(b: &mut Bencher) {
let rt = rt();

b.iter(|| {
let task = || async {
let mut source = repeat(0).take(SOURCE_SIZE);
let mut dest = SlowHddWriter::new(WRITE_SERVICE_PERIOD, WRITE_BUFFER);
copy(&mut source, &mut dest).await.unwrap();
};

rt.block_on(task());
})
}

fn copy_chunk_to_mem(b: &mut Bencher) {
let rt = rt();
b.iter(|| {
let task = || async {
let mut source = ChunkReader::new(CHUNK_SIZE, READ_SERVICE_PERIOD).take(SOURCE_SIZE);
let mut dest = Vec::new();
copy(&mut source, &mut dest).await.unwrap();
};

rt.block_on(task());
})
}

fn copy_chunk_to_slow_hdd(b: &mut Bencher) {
let rt = rt();
b.iter(|| {
let task = || async {
let mut source = ChunkReader::new(CHUNK_SIZE, READ_SERVICE_PERIOD).take(SOURCE_SIZE);
let mut dest = SlowHddWriter::new(WRITE_SERVICE_PERIOD, WRITE_BUFFER);
copy(&mut source, &mut dest).await.unwrap();
};

rt.block_on(task());
})
}

benchmark_group!(
copy_bench,
copy_mem_to_mem,
copy_mem_to_slow_hdd,
copy_chunk_to_mem,
copy_chunk_to_slow_hdd,
);
benchmark_main!(copy_bench);
62 changes: 49 additions & 13 deletions tokio/src/io/util/copy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,51 @@ impl CopyBuffer {
}
}

fn poll_fill_buf<R>(
&mut self,
cx: &mut Context<'_>,
reader: Pin<&mut R>,
) -> Poll<io::Result<()>>
where
R: AsyncRead + ?Sized,
{
let me = &mut *self;
let mut buf = ReadBuf::new(&mut me.buf);
buf.set_filled(me.cap);

let res = reader.poll_read(cx, &mut buf);
if let Poll::Ready(Ok(_)) = res {
let filled_len = buf.filled().len();
me.read_done = me.cap == filled_len;
me.cap = filled_len;
}
res
}

fn poll_write_buf<R, W>(
&mut self,
cx: &mut Context<'_>,
mut reader: Pin<&mut R>,
mut writer: Pin<&mut W>,
) -> Poll<io::Result<usize>>
where
R: AsyncRead + ?Sized,
W: AsyncWrite + ?Sized,
{
let me = &mut *self;
match writer.as_mut().poll_write(cx, &me.buf[me.pos..me.cap]) {
Poll::Pending => {
// Top up the buffer towards full if we can read a bit more
// data - this should improve the chances of a large write
if !me.read_done && me.cap < me.buf.len() {
ready!(me.poll_fill_buf(cx, reader.as_mut()))?;
}
Poll::Pending
}
res => res,
}
}

pub(super) fn poll_copy<R, W>(
&mut self,
cx: &mut Context<'_>,
Expand All @@ -41,10 +86,10 @@ impl CopyBuffer {
// If our buffer is empty, then we need to read some data to
// continue.
if self.pos == self.cap && !self.read_done {
let me = &mut *self;
let mut buf = ReadBuf::new(&mut me.buf);
self.pos = 0;
self.cap = 0;

match reader.as_mut().poll_read(cx, &mut buf) {
match self.poll_fill_buf(cx, reader.as_mut()) {
Poll::Ready(Ok(_)) => (),
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
Poll::Pending => {
Expand All @@ -58,20 +103,11 @@ impl CopyBuffer {
return Poll::Pending;
}
}

let n = buf.filled().len();
if n == 0 {
self.read_done = true;
} else {
self.pos = 0;
self.cap = n;
}
}

// If our buffer has some data, let's write it out!
while self.pos < self.cap {
let me = &mut *self;
let i = ready!(writer.as_mut().poll_write(cx, &me.buf[me.pos..me.cap]))?;
let i = ready!(self.poll_write_buf(cx, reader.as_mut(), writer.as_mut()))?;
if i == 0 {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::WriteZero,
Expand Down