http patch: only buffer the top of the response #400

Merged · 2 commits · Oct 23, 2023
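As the diff implements it, BufWriter now buffers only the head of the response: bytes accumulate in the buffer while nothing has been sent to the inner writer yet and the new write still fits within the buffer's capacity. Once anything has gone out, pending buffered bytes and the new bytes are written together (a vectored write in the async code) and nothing is copied back into the buffer. A minimal synchronous sketch of that policy follows; HeadBufWriter is a hypothetical illustration over std::io::Write, not the crate's async BufWriter.

use std::io::{Result, Write};

// Illustrative sketch only: a synchronous simplification of the async BufWriter
// in this diff. HeadBufWriter is a hypothetical name, not part of trillium-http.
struct HeadBufWriter<W> {
    inner: W,
    buffer: Vec<u8>,
    written_to_inner: usize,
}

impl<W: Write> HeadBufWriter<W> {
    fn with_capacity(capacity: usize, inner: W) -> Self {
        Self {
            inner,
            buffer: Vec::with_capacity(capacity),
            written_to_inner: 0,
        }
    }
}

impl<W: Write> Write for HeadBufWriter<W> {
    fn write(&mut self, additional: &[u8]) -> Result<usize> {
        // Nothing has reached `inner` yet and the new bytes still fit: keep
        // buffering the head of the response.
        if self.written_to_inner == 0
            && self.buffer.len() + additional.len() <= self.buffer.capacity()
        {
            self.buffer.extend_from_slice(additional);
            return Ok(additional.len());
        }
        // Otherwise push out whatever is still pending in the buffer...
        let pending = &self.buffer[self.written_to_inner..];
        if !pending.is_empty() {
            self.inner.write_all(pending)?;
            self.written_to_inner = self.buffer.len();
        }
        // ...and from here on hand writes straight to `inner`; the head is
        // never re-buffered.
        self.inner.write(additional)
    }

    fn flush(&mut self) -> Result<()> {
        let pending = &self.buffer[self.written_to_inner..];
        if !pending.is_empty() {
            self.inner.write_all(pending)?;
            self.written_to_inner = self.buffer.len();
        }
        self.inner.flush()
    }
}

fn main() -> Result<()> {
    let mut writer = HeadBufWriter::with_capacity(8, Vec::new());
    writer.write_all(b"head")?; // fits: buffered, nothing sent yet
    writer.write_all(b" and a longer body")?; // overflows: head and body go to inner
    writer.flush()?;
    assert_eq!(&writer.inner[..], &b"head and a longer body"[..]);
    Ok(())
}

The async version in bufwriter.rs below follows the same cases, plus a loop so that a vectored write that only drained pending buffered bytes retries before reporting progress on the new bytes.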
2 changes: 2 additions & 0 deletions http/Cargo.toml
@@ -43,6 +43,8 @@ trillium-client = { path = "../client" }
trillium-smol = { path = "../smol" }
trillium-testing = { path = "../testing" }
trillium-http = { path = ".", features = ["http-compat"] }
pretty_assertions = "1.4.0"
fastrand = "2.0.1"

[dev-dependencies.tokio]
version = "1.29.1"
299 changes: 269 additions & 30 deletions http/src/bufwriter.rs
@@ -1,7 +1,7 @@
use futures_lite::{AsyncRead, AsyncWrite};
use std::{
fmt,
io::{Error, ErrorKind, Result},
io::{Error, ErrorKind, IoSlice, Result},
pin::Pin,
task::{ready, Context, Poll},
};
@@ -11,31 +11,31 @@ use trillium_macros::AsyncRead;
pub(crate) struct BufWriter<W> {
#[async_read]
inner: W,
buf: Vec<u8>,
written: usize,
buffer: Vec<u8>,
written_to_inner: usize,
}

impl<W: AsyncWrite + Unpin> BufWriter<W> {
pub(crate) fn new_with_buffer(buf: Vec<u8>, inner: W) -> Self {
pub(crate) fn new_with_buffer(buffer: Vec<u8>, inner: W) -> Self {
Self {
inner,
buf,
written: 0,
buffer,
written_to_inner: 0,
}
}

fn poll_flush_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
fn poll_flush_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<usize>> {
let Self {
inner,
buf,
written,
buffer,
written_to_inner,
} = &mut *self;

let len = buf.len();
let mut ret = Ok(());
let len = buffer.len();
let mut ret = Ok(0);

while *written < len {
let buf = &buf[*written..];
while *written_to_inner < len {
let buf = &buffer[*written_to_inner..];
match ready!(Pin::new(&mut *inner).poll_write(cx, buf)) {
Ok(0) => {
ret = Err(Error::new(
@@ -44,7 +44,7 @@ impl<W: AsyncWrite + Unpin> BufWriter<W> {
));
break;
}
Ok(n) => *written += n,
Ok(n) => *written_to_inner += n,
Err(ref e) if e.kind() == ErrorKind::Interrupted => {}
Err(e) => {
ret = Err(e);
@@ -53,11 +53,6 @@ impl<W: AsyncWrite + Unpin> BufWriter<W> {
}
}

if *written > 0 {
buf.drain(..*written);
}
*written = 0;

Poll::Ready(ret)
}
}
@@ -66,8 +61,8 @@ impl<W: fmt::Debug> fmt::Debug for BufWriter<W> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("BufWriter")
.field("writer", &self.inner)
.field("buf", &self.buf)
.field("written", &self.written)
.field("buf", &self.buffer)
.field("written", &self.written_to_inner)
.finish()
}
}
@@ -76,16 +71,37 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for BufWriter<W> {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
additional: &[u8],
) -> Poll<Result<usize>> {
if self.buf.len() + buf.len() > self.buf.capacity() {
ready!(self.as_mut().poll_flush_buf(cx))?;
}
if buf.len() >= self.buf.capacity() {
Pin::new(&mut self.inner).poll_write(cx, buf)
} else {
self.buf.extend_from_slice(buf);
Poll::Ready(Ok(buf.len()))
let Self {
inner,
buffer,
written_to_inner,
} = &mut *self;
loop {
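// One pass may only manage to flush pending buffered bytes; loop until we can report progress on `additional` itself.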
let len = buffer.len();
let pending_buffer = &buffer[len.min(*written_to_inner)..];
let pending_bytes = pending_buffer.len();
let new_bytes = additional.len();
let new_len_would_be = len + new_bytes;
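// Nothing has been written to `inner` yet and the new bytes still fit: just buffer them.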
if *written_to_inner == 0 && new_len_would_be <= buffer.capacity() {
buffer.extend_from_slice(additional);
return Poll::Ready(Ok(additional.len()));
} else if !pending_buffer.is_empty() {
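// Part of the buffer is still unwritten: send it and `additional` together in one vectored write, but only report the bytes taken from `additional`.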
let written = ready!(Pin::new(&mut *inner).poll_write_vectored(
cx,
&[IoSlice::new(pending_buffer), IoSlice::new(additional)]
))?;
*written_to_inner += written;
let written_from_additional = written.saturating_sub(pending_bytes);
if written_from_additional != 0 {
return Poll::Ready(Ok(written_from_additional));
}
} else {
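// The buffer has already been written out in full: bypass it and write `additional` directly.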
let written = ready!(Pin::new(&mut *inner).poll_write(cx, additional))?;
*written_to_inner += written;
return Poll::Ready(Ok(written));
}
}
}

@@ -99,3 +115,226 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for BufWriter<W> {
Pin::new(&mut self.inner).poll_close(cx)
}
}

#[cfg(test)]
mod tests {
use futures_lite::AsyncWriteExt;
use pretty_assertions::assert_eq;

use super::*;
#[derive(Default)]
struct TestWrite {
writes: Vec<Vec<u8>>,
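// When set, caps how many bytes a single poll_write accepts, simulating short writes.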
max_write: Option<usize>,
}
impl AsyncWrite for TestWrite {
fn poll_write(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize>> {
let written = self.max_write.map_or(buf.len(), |mw| mw.min(buf.len()));
self.writes.push(buf[..written].to_vec());
Poll::Ready(Ok(written))
}

fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[IoSlice<'_>],
) -> Poll<Result<usize>> {
self.poll_write(cx, &bufs.iter().map(|s| &**s).collect::<Vec<_>>().concat())
}

fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
Poll::Ready(Ok(()))
}

fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
Poll::Ready(Ok(()))
}
}

impl TestWrite {
fn new(max_write: Option<usize>) -> Self {
Self {
max_write,
..Self::default()
}
}

fn data(&self) -> Vec<u8> {
self.writes.concat()
}
}

fn rand_bytes<const LEN: usize>() -> [u8; LEN] {
std::array::from_fn(|_| fastrand::u8(..))
}

#[test]
fn entire_content_shorter_than_capacity() {
futures_lite::future::block_on(async {
let data = rand_bytes::<90>();
let mut tw = TestWrite::new(None);
let mut bw = BufWriter::new_with_buffer(Vec::with_capacity(100), &mut tw);
bw.write_all(&data).await.unwrap();
assert_eq!(bw.inner.writes.len(), 0);
bw.flush().await.unwrap();
assert_eq!(&bw.inner.writes, &[&data]);
});
}

#[test]
fn longer_than_capacity_but_still_a_single_write() {
futures_lite::future::block_on(async {
let data = rand_bytes::<200>();
let mut tw = TestWrite::new(None);
let mut bw = BufWriter::new_with_buffer(Vec::with_capacity(100), &mut tw);
bw.write_all(&data).await.unwrap();
assert_eq!(&bw.inner.writes, &[&data]);
bw.flush().await.unwrap();
assert_eq!(&bw.inner.writes, &[&data]);
});
}

#[test]
fn multiple_writes() {
futures_lite::future::block_on(async {
let data = rand_bytes::<250>();
let mut tw = TestWrite::new(None);
let mut bw = BufWriter::new_with_buffer(Vec::with_capacity(100), &mut tw);
bw.write_all(&data[..200]).await.unwrap();
bw.write_all(&data[200..]).await.unwrap();
assert_eq!(&bw.inner.writes, &[&data[..200], &data[200..]]);
bw.flush().await.unwrap();
assert_eq!(&bw.inner.writes, &[&data[..200], &data[200..]]);
});
}

#[test]
fn overflow_is_vectored() {
futures_lite::future::block_on(async {
let data = rand_bytes::<101>();
let mut tw = TestWrite::new(None);
let mut bw = BufWriter::new_with_buffer(Vec::with_capacity(100), &mut tw);
bw.write_all(&data[..50]).await.unwrap();
bw.write_all(&data[50..]).await.unwrap();
assert_eq!(&bw.inner.writes, &[&data]);
bw.flush().await.unwrap();
assert_eq!(&bw.inner.writes, &[&data]);
});
}

#[test]
fn max_write() {
futures_lite::future::block_on(async {
let data = rand_bytes::<200>();
let mut tw = TestWrite::new(Some(50));
let mut bw = BufWriter::new_with_buffer(Vec::with_capacity(100), &mut tw);
bw.write_all(&data[..10]).await.unwrap();
bw.write_all(&data[10..20]).await.unwrap();
bw.write_all(&data[20..45]).await.unwrap();
bw.write_all(&data[45..125]).await.unwrap();
bw.write_all(&data[125..]).await.unwrap();
for write in &bw.inner.writes {
println!(
"{}",
write
.iter()
.map(u8::to_string)
.collect::<Vec<_>>()
.join(",")
);
}
assert_eq!(
&bw.inner.writes,
&[
&data[0..50],
&data[50..100],
&data[100..125],
&data[125..175],
&data[175..]
]
);
bw.flush().await.unwrap();
assert_eq!(&bw.inner.data(), &data);
});
}

#[test]
fn write_boundary_is_exactly_buffer_len() {
futures_lite::future::block_on(async {
let data = rand_bytes::<200>();
let mut tw = TestWrite::new(Some(50));
let mut bw = BufWriter::new_with_buffer(Vec::with_capacity(100), &mut tw);
bw.write_all(&data[..10]).await.unwrap();
bw.write_all(&data[10..20]).await.unwrap();
bw.write_all(&data[20..50]).await.unwrap();
bw.write_all(&data[50..125]).await.unwrap();
bw.write_all(&data[125..]).await.unwrap();
assert_eq!(
&bw.inner.writes,
&[
&data[0..50],
&data[50..100],
&data[100..125],
&data[125..175],
&data[175..]
]
);
bw.flush().await.unwrap();
assert_eq!(&bw.inner.data(), &data);
});
}

#[test]
fn buffer_is_exactly_full() {
futures_lite::future::block_on(async {
let data = rand_bytes::<200>();
let mut tw = TestWrite::new(None);
let mut bw = BufWriter::new_with_buffer(Vec::with_capacity(100), &mut tw);
bw.write_all(&data[..100]).await.unwrap();
bw.write_all(&data[100..]).await.unwrap();
assert_eq!(&bw.inner.writes, &[&data]);
bw.flush().await.unwrap();
assert_eq!(&bw.inner.data(), &data);
});
}

fn test_x<const SIZE: usize>(capacity: usize, max_write: Option<usize>, split: usize) {
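// Writes SIZE random bytes in two chunks split at `split` through a BufWriter with the given
// capacity and per-write cap, then checks the inner writer received everything in order;
// repeated 100 times with fresh data.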
futures_lite::future::block_on(async {
for _ in 0..100 {
let data = rand_bytes::<SIZE>();
let mut tw = TestWrite::new(max_write);
let mut bw = BufWriter::new_with_buffer(Vec::with_capacity(capacity), &mut tw);
bw.write_all(&data[..split]).await.unwrap();
bw.write_all(&data[split..]).await.unwrap();
bw.flush().await.unwrap();
assert_eq!(
&bw.inner.data(),
&data,
"test_x({},{:?},{split})",
bw.buffer.capacity(),
bw.inner.max_write
);
}
});
}

#[test]
fn known_bad() {
test_x::<200>(188, Some(47), 123);
}

#[test]
fn random() {
for _ in 0..100 {
test_x::<200>(
fastrand::usize(1..200),
Some(fastrand::usize(1..200)),
fastrand::usize(1..200),
);
}
}
}