Skip to content

Commit

Permalink
Use chunked transfer encoding for callback fn
Browse files Browse the repository at this point in the history
  • Loading branch information
kornelski committed Jul 19, 2019
1 parent 5285ef8 commit aca14b5
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 23 deletions.
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -832,7 +832,7 @@ impl Mock {
self
}

/// Sets the body of the mock response dynamically. The response is buffered, so the `Content-Length` is handled automatically.
/// Sets the body of the mock response dynamically. The response will use chunked transfer encoding.
///
/// The function must be thread-safe. If it's a closure, it can't be borrowing its context.
/// Use `move` closures and `Arc` to share any data.
Expand Down
33 changes: 32 additions & 1 deletion src/response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ impl fmt::Debug for Body {
}

impl PartialEq for Body {
fn eq(&self, other: &Body) -> bool {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(Body::Bytes(ref a), Body::Bytes(ref b)) => a == b,
(Body::Fn(ref a), Body::Fn(ref b)) => Arc::ptr_eq(a, b),
Expand All @@ -45,6 +45,37 @@ impl Default for Response {
}
}

pub(crate) struct Chunked<W: io::Write> {
writer: W
}

impl<W: io::Write> Chunked<W> {
pub fn new(writer: W) -> Self {
Self {writer}
}

pub fn finish(mut self) -> io::Result<W> {
self.writer.write_all(b"0\r\n\r\n")?;
Ok(self.writer)
}
}

impl<W: io::Write> io::Write for Chunked<W> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
if buf.is_empty() {
return Ok(0);
}
self.writer.write_all(format!("{:x}\r\n", buf.len()).as_bytes())?;
self.writer.write_all(buf)?;
self.writer.write_all(b"\r\n")?;
Ok(buf.len())
}

fn flush(&mut self) -> io::Result<()> {
self.writer.flush()
}
}

#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature="cargo-clippy", allow(clippy::pub_enum_variant_names))]
pub enum Status {
Expand Down
37 changes: 19 additions & 18 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::net::{TcpListener, TcpStream, SocketAddr};
use std::sync::Mutex;
use std::sync::mpsc;
use {SERVER_ADDRESS_INTERNAL, Request, Mock};
use response::Body;
use response::{Chunked, Body};

impl Mock {
fn method_matches(&self, request: &Request) -> bool {
Expand Down Expand Up @@ -183,27 +183,28 @@ fn respond_bytes(
has_content_length_header = headers.iter().any(|(key, _)| key == "content-length");
}

let mut buffer;
let body = match body {
Some(Body::Bytes(bytes)) => Some(&bytes[..]),
Some(Body::Fn(cb)) => {
// we don't implement transfer-encoding: chunked, so need to buffer
buffer = Vec::new();
let _ = cb(&mut buffer);
Some(&buffer[..])
match body {
Some(Body::Bytes(bytes)) => if !has_content_length_header {
response.extend(format!("content-length: {}\r\n", bytes.len()).as_bytes());
},
Some(Body::Fn(_)) => {
response.extend(b"transfer-encoding: chunked\r\n");
},
None => None,
None => {},
};
if let Some(bytes) = body {
if !has_content_length_header {
response.extend(format!("content-length: {}\r\n", bytes.len()).as_bytes());
}
}
response.extend(b"\r\n");
let _ = stream.write(&response);
if let Some(bytes) = body {
let _ = stream.write(bytes);
}
match body {
Some(Body::Bytes(bytes)) => {
let _ = stream.write_all(bytes);
},
Some(Body::Fn(cb)) => {
let mut chunked = Chunked::new(&mut stream);
let _ = cb(&mut chunked);
let _ = chunked.finish();
},
None => {},
};
let _ = stream.flush();
}

Expand Down
53 changes: 50 additions & 3 deletions tests/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,24 +26,48 @@ fn parse_stream(stream: TcpStream, skip_body: bool) -> (String, Vec<String>, Str
reader.read_line(&mut status_line).unwrap();

let mut headers = vec![];
let mut content_length: u64 = 0;
let mut content_length: Option<u64> = None;
let mut is_chunked = false;
loop {
let mut header_line = String::new();
reader.read_line(&mut header_line).unwrap();

if header_line == "\r\n" { break }

if header_line.starts_with("transfer-encoding:") && header_line.contains("chunked") {
is_chunked = true;
}

if header_line.starts_with("content-length:") {
let mut parts = header_line.split(':');
content_length = u64::from_str(parts.nth(1).unwrap().trim()).unwrap();
content_length = Some(u64::from_str(parts.nth(1).unwrap().trim()).unwrap());
}

headers.push(header_line.trim_end().to_string());
}

let mut body = String::new();
if !skip_body {
reader.take(content_length).read_to_string(&mut body).unwrap();
if let Some(content_length) = content_length {
reader.take(content_length).read_to_string(&mut body).unwrap();
} else if is_chunked {
let mut chunk_size_buf = String::new();
loop {
chunk_size_buf.clear();
reader.read_line(&mut chunk_size_buf).unwrap();

let chunk_size = u64::from_str_radix(chunk_size_buf.trim_matches(|c| c == '\r' || c == '\n'), 16)
.expect("chunk size");
if chunk_size == 0 {
break;
}

(&mut reader).take(chunk_size).read_to_string(&mut body).unwrap();

chunk_size_buf.clear();
reader.read_line(&mut chunk_size_buf).unwrap();
}
}
}

(status_line, headers, body)
Expand Down Expand Up @@ -348,6 +372,29 @@ fn test_mock_with_custom_status() {
assert_eq!("HTTP/1.1 333 Custom\r\n", status_line);
}

#[test]
fn test_mock_with_body() {
let _m = mock("GET", "/")
.with_body("hello")
.create();

let (_, _, body) = request("GET /", "");
assert_eq!("hello", body);
}

#[test]
fn test_mock_with_fn_body() {
let _m = mock("GET", "/")
.with_body_from_fn(|w| {
w.write_all(b"hel")?;
w.write_all(b"lo")
})
.create();

let (_, _, body) = request("GET /", "");
assert_eq!("hello", body);
}

#[test]
fn test_mock_with_header() {
let _m = mock("GET", "/")
Expand Down

0 comments on commit aca14b5

Please sign in to comment.