Skip to content

Commit

Permalink
Merge pull request #80 from kornelski/master
Browse files Browse the repository at this point in the history
Allow generating response programmatically
  • Loading branch information
lipanski authored Jul 20, 2019
2 parents c5497bf + f9c1a94 commit 2ba5b6c
Show file tree
Hide file tree
Showing 5 changed files with 177 additions and 42 deletions.
46 changes: 29 additions & 17 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -481,8 +481,7 @@ mod diff;
type Request = request::Request;
type Response = response::Response;

use std::fs::File;
use std::io::Read;
use std::path::Path;
use std::convert::{From, Into};
use std::ops::Drop;
use std::fmt;
Expand All @@ -492,6 +491,8 @@ use regex::Regex;
use std::sync::{Mutex, LockResult, MutexGuard};
use std::cell::RefCell;
use percent_encoding::percent_decode;
use std::sync::Arc;
use std::io;

lazy_static! {
// A global lock that ensure all Mockito tests are run on a single thread.
Expand Down Expand Up @@ -624,8 +625,8 @@ impl Matcher {
value == other
},
Matcher::UrlEncoded(ref expected_field, ref expected_value) => {
other.split("&").map( |pair| {
let mut parts = pair.splitn(2, "=");
other.split('&').map( |pair| {
let mut parts = pair.splitn(2, '=');
let field = percent_decode(parts.next().unwrap().as_bytes()).decode_utf8_lossy();
let value = percent_decode(parts.next().unwrap_or("").as_bytes()).decode_utf8_lossy();

Expand Down Expand Up @@ -667,8 +668,8 @@ impl Mock {
match path.into() {
// We also allow setting the query as part of the path argument
// but we split it under the hood into `Matcher::Exact` elements.
Matcher::Exact(ref raw_path) if raw_path.contains("?") => {
let mut parts = raw_path.splitn(2, "?");
Matcher::Exact(ref raw_path) if raw_path.contains('?') => {
let mut parts = raw_path.splitn(2, '?');
(parts.next().unwrap().into(), parts.next().unwrap_or("").into())
},
other => {
Expand All @@ -679,8 +680,8 @@ impl Mock {
Self {
id: thread_rng().sample_iter(&Alphanumeric).take(24).collect(),
method: method.to_owned().to_uppercase(),
path: path,
query: query,
path,
query,
headers: Vec::new(),
body: Matcher::Any,
response: Response::default(),
Expand Down Expand Up @@ -827,8 +828,25 @@ impl Mock {
/// ```
///
pub fn with_body<StrOrBytes: AsRef<[u8]>>(mut self, body: StrOrBytes) -> Self {
self.response.body = body.as_ref().to_owned();
self.response.body = response::Body::Bytes(body.as_ref().to_owned());
self
}

/// Sets the body of the mock response dynamically. The response will use chunked transfer encoding.
///
/// The function must be thread-safe. If it's a closure, it can't be borrowing its context.
/// Use `move` closures and `Arc` to share any data.
///
/// ## Example
///
/// ```
/// use mockito::mock;
///
/// let _m = mock("GET", "/").with_body_from_fn(|w| w.write_all(b"hello world"));
/// ```
///
pub fn with_body_from_fn(mut self, cb: impl Fn(&mut dyn io::Write) -> io::Result<()> + Send + Sync + 'static) -> Self {
self.response.body = response::Body::Fn(Arc::new(cb));
self
}

Expand All @@ -844,14 +862,8 @@ impl Mock {
/// let _m = mock("GET", "/").with_body_from_file("tests/files/simple.http");
/// ```
///
pub fn with_body_from_file(mut self, path: &str) -> Self {
let mut file = File::open(path).unwrap();
let mut body = Vec::new();

file.read_to_end(&mut body).unwrap();

self.response.body = body;

pub fn with_body_from_file(mut self, path: impl AsRef<Path>) -> Self {
self.response.body = response::Body::Bytes(std::fs::read(path).unwrap());
self
}

Expand Down
2 changes: 1 addition & 1 deletion src/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ impl<'a> From<&'a TcpStream> for Request {
}

if let Some(a) = req.path {
let mut parts = a.splitn(2, "?");
let mut parts = a.splitn(2, '?');
request.path += parts.next().unwrap();
request.query += parts.next().unwrap_or("");
}
Expand Down
64 changes: 61 additions & 3 deletions src/response.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,78 @@
use std::sync::Arc;
use std::convert::From;
use std::fmt;
use std::io;

#[derive(Clone, Debug, PartialEq)]
pub struct Response {
pub(crate) struct Response {
pub status: Status,
pub headers: Vec<(String, String)>,
pub body: Vec<u8>,
pub body: Body,
}

#[derive(Clone)]
pub(crate) enum Body {
Bytes(Vec<u8>),
Fn(Arc<Fn(&mut dyn io::Write) -> io::Result<()> + Send + Sync + 'static>),
}

impl fmt::Debug for Body {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
Body::Bytes(ref b) => b.fmt(f),
Body::Fn(_) => f.write_str("<callback>"),
}
}
}

impl PartialEq for Body {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(Body::Bytes(ref a), Body::Bytes(ref b)) => a == b,
(Body::Fn(ref a), Body::Fn(ref b)) => Arc::ptr_eq(a, b),
_ => false,
}
}
}

impl Default for Response {
fn default() -> Self {
Self {
status: Status::Ok,
headers: Vec::new(),
body: Vec::new(),
body: Body::Bytes(Vec::new()),
}
}
}

pub(crate) struct Chunked<W: io::Write> {
writer: W
}

impl<W: io::Write> Chunked<W> {
pub fn new(writer: W) -> Self {
Self {writer}
}

pub fn finish(mut self) -> io::Result<W> {
self.writer.write_all(b"0\r\n\r\n")?;
Ok(self.writer)
}
}

impl<W: io::Write> io::Write for Chunked<W> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
if buf.is_empty() {
return Ok(0);
}
self.writer.write_all(format!("{:x}\r\n", buf.len()).as_bytes())?;
self.writer.write_all(buf)?;
self.writer.write_all(b"\r\n")?;
Ok(buf.len())
}

fn flush(&mut self) -> io::Result<()> {
self.writer.flush()
}
}

Expand Down
52 changes: 35 additions & 17 deletions src/server.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
use std::thread;
use std::io;
use std::io::Write;
use std::fmt::Display;
use std::net::{TcpListener, TcpStream, SocketAddr};
use std::sync::Mutex;
use std::sync::mpsc;
use {SERVER_ADDRESS_INTERNAL, Request, Mock};
use response::{Chunked, Body};

impl Mock {
fn method_matches(&self, request: &Request) -> bool {
Expand Down Expand Up @@ -157,16 +159,19 @@ fn respond(
headers: Option<&Vec<(String, String)>>,
body: Option<&str>
) {
respond_bytes(stream, version, status, headers, body.map(|s| s.as_bytes()))
let body = body.map(|s| Body::Bytes(s.as_bytes().to_owned()));
if let Err(e) = respond_bytes(stream, version, status, headers, body.as_ref()) {
eprintln!("warning: Mock response write error: {}", e);
}
}

fn respond_bytes(
mut stream: TcpStream,
version: (u8, u8),
status: impl Display,
headers: Option<&Vec<(String, String)>>,
body: Option<&[u8]>
) {
body: Option<&Body>
) -> io::Result<()> {
let mut response = Vec::from(format!("HTTP/{}.{} {}\r\n", version.0, version.1, status));
let mut has_content_length_header = false;

Expand All @@ -181,29 +186,42 @@ fn respond_bytes(
has_content_length_header = headers.iter().any(|(key, _)| key == "content-length");
}

if let Some(body) = body {
if !has_content_length_header {
response.extend(format!("content-length: {}\r\n\r\n", body.len()).as_bytes());
}

response.extend(body);
} else {
response.extend(b"\r\n");
}

let _ = stream.write(&response);
let _ = stream.flush();
match body {
Some(Body::Bytes(bytes)) => if !has_content_length_header {
response.extend(format!("content-length: {}\r\n", bytes.len()).as_bytes());
},
Some(Body::Fn(_)) => {
response.extend(b"transfer-encoding: chunked\r\n");
},
None => {},
};
response.extend(b"\r\n");
stream.write_all(&response)?;
match body {
Some(Body::Bytes(bytes)) => {
stream.write_all(bytes)?;
},
Some(Body::Fn(cb)) => {
let mut chunked = Chunked::new(&mut stream);
cb(&mut chunked)?;
chunked.finish()?;
},
None => {},
};
stream.flush()
}

fn respond_with_mock(stream: TcpStream, version: (u8, u8), mock: &Mock, skip_body: bool) {
let body =
if skip_body {
None
} else {
Some(&*mock.response.body)
Some(&mock.response.body)
};

respond_bytes(stream, version, &mock.response.status, Some(&mock.response.headers), body);
if let Err(e) = respond_bytes(stream, version, &mock.response.status, Some(&mock.response.headers), body) {
eprintln!("warning: Mock response write error: {}", e);
}
}

fn respond_with_mock_not_found(stream: TcpStream, version: (u8, u8)) {
Expand Down
55 changes: 51 additions & 4 deletions tests/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,24 +26,48 @@ fn parse_stream(stream: TcpStream, skip_body: bool) -> (String, Vec<String>, Str
reader.read_line(&mut status_line).unwrap();

let mut headers = vec![];
let mut content_length: u64 = 0;
let mut content_length: Option<u64> = None;
let mut is_chunked = false;
loop {
let mut header_line = String::new();
reader.read_line(&mut header_line).unwrap();

if header_line == "\r\n" { break }

if header_line.starts_with("transfer-encoding:") && header_line.contains("chunked") {
is_chunked = true;
}

if header_line.starts_with("content-length:") {
let mut parts = header_line.split(':');
content_length = u64::from_str(parts.nth(1).unwrap().trim()).unwrap();
content_length = Some(u64::from_str(parts.nth(1).unwrap().trim()).unwrap());
}

headers.push(header_line.trim_right().to_string());
headers.push(header_line.trim_end().to_string());
}

let mut body = String::new();
if !skip_body {
reader.take(content_length).read_to_string(&mut body).unwrap();
if let Some(content_length) = content_length {
reader.take(content_length).read_to_string(&mut body).unwrap();
} else if is_chunked {
let mut chunk_size_buf = String::new();
loop {
chunk_size_buf.clear();
reader.read_line(&mut chunk_size_buf).unwrap();

let chunk_size = u64::from_str_radix(chunk_size_buf.trim_matches(|c| c == '\r' || c == '\n'), 16)
.expect("chunk size");
if chunk_size == 0 {
break;
}

(&mut reader).take(chunk_size).read_to_string(&mut body).unwrap();

chunk_size_buf.clear();
reader.read_line(&mut chunk_size_buf).unwrap();
}
}
}

(status_line, headers, body)
Expand Down Expand Up @@ -348,6 +372,29 @@ fn test_mock_with_custom_status() {
assert_eq!("HTTP/1.1 333 Custom\r\n", status_line);
}

#[test]
fn test_mock_with_body() {
let _m = mock("GET", "/")
.with_body("hello")
.create();

let (_, _, body) = request("GET /", "");
assert_eq!("hello", body);
}

#[test]
fn test_mock_with_fn_body() {
let _m = mock("GET", "/")
.with_body_from_fn(|w| {
w.write_all(b"hel")?;
w.write_all(b"lo")
})
.create();

let (_, _, body) = request("GET /", "");
assert_eq!("hello", body);
}

#[test]
fn test_mock_with_header() {
let _m = mock("GET", "/")
Expand Down

0 comments on commit 2ba5b6c

Please sign in to comment.