Skip to content

Commit

Permalink
Merge pull request #401 from hyperium/packets
Browse files Browse the repository at this point in the history
fix(http): read more before triggering TooLargeError
  • Loading branch information
seanmonstar committed Mar 28, 2015
2 parents 48700aa + cb59f60 commit 6c31ea8
Show file tree
Hide file tree
Showing 6 changed files with 189 additions and 43 deletions.
95 changes: 95 additions & 0 deletions src/buffer.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
use std::cmp;
use std::iter;
use std::io::{self, Read, BufRead, Cursor};

pub struct BufReader<R> {
buf: Cursor<Vec<u8>>,
inner: R
}

const INIT_BUFFER_SIZE: usize = 4096;
const MAX_BUFFER_SIZE: usize = 8192 + 4096 * 100;

impl<R: Read> BufReader<R> {
pub fn new(rdr: R) -> BufReader<R> {
BufReader::with_capacity(rdr, INIT_BUFFER_SIZE)
}

pub fn with_capacity(rdr: R, cap: usize) -> BufReader<R> {
BufReader {
buf: Cursor::new(Vec::with_capacity(cap)),
inner: rdr
}
}

pub fn get_ref(&self) -> &R { &self.inner }

pub fn get_mut(&mut self) -> &mut R { &mut self.inner }

pub fn get_buf(&self) -> &[u8] {
self.buf.get_ref()
}

pub fn into_inner(self) -> R { self.inner }

pub fn read_into_buf(&mut self) -> io::Result<usize> {
let v = self.buf.get_mut();
reserve(v);
let inner = &mut self.inner;
with_end_to_cap(v, |b| inner.read(b))
}
}

impl<R: Read> Read for BufReader<R> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
if self.buf.get_ref().len() == self.buf.position() as usize &&
buf.len() >= self.buf.get_ref().capacity() {
return self.inner.read(buf);
}
try!(self.fill_buf());
self.buf.read(buf)
}
}

impl<R: Read> BufRead for BufReader<R> {
fn fill_buf(&mut self) -> io::Result<&[u8]> {
if self.buf.position() as usize == self.buf.get_ref().len() {
self.buf.set_position(0);
let v = self.buf.get_mut();
v.truncate(0);
let inner = &mut self.inner;
try!(with_end_to_cap(v, |b| inner.read(b)));
}
self.buf.fill_buf()
}

fn consume(&mut self, amt: usize) {
self.buf.consume(amt)
}
}

fn with_end_to_cap<F>(v: &mut Vec<u8>, f: F) -> io::Result<usize>
where F: FnOnce(&mut [u8]) -> io::Result<usize>
{
let len = v.len();
let new_area = v.capacity() - len;
v.extend(iter::repeat(0).take(new_area));
match f(&mut v[len..]) {
Ok(n) => {
v.truncate(len + n);
Ok(n)
}
Err(e) => {
v.truncate(len);
Err(e)
}
}
}

#[inline]
fn reserve(v: &mut Vec<u8>) {
let cap = v.capacity();
if v.len() == cap {
v.reserve(cmp::min(cap * 4, MAX_BUFFER_SIZE) - cap);
}
}
6 changes: 4 additions & 2 deletions src/client/response.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
//! Client Responses
use std::io::{self, Read, BufReader};
use std::io::{self, Read};
use std::num::FromPrimitive;
use std::marker::PhantomData;

use buffer::BufReader;
use header;
use header::{ContentLength, TransferEncoding};
use header::Encoding::Chunked;
Expand Down Expand Up @@ -103,9 +104,10 @@ impl Read for Response {
mod tests {
use std::borrow::Cow::Borrowed;
use std::boxed::BoxAny;
use std::io::{self, Read, BufReader};
use std::io::{self, Read};
use std::marker::PhantomData;

use buffer::BufReader;
use header::Headers;
use header::TransferEncoding;
use header::Encoding;
Expand Down
118 changes: 81 additions & 37 deletions src/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@ use std::io::{self, Read, Write, BufRead};

use httparse;

use buffer::BufReader;
use header::Headers;
use method::Method;
use uri::RequestUri;
use version::HttpVersion::{self, Http10, Http11};
use HttpError:: HttpTooLargeError;
use HttpResult;
use {HttpError, HttpResult};

use self::HttpReader::{SizedReader, ChunkedReader, EofReader, EmptyReader};
use self::HttpWriter::{ThroughWriter, ChunkedWriter, SizedWriter, EmptyWriter};
Expand Down Expand Up @@ -307,56 +308,88 @@ impl<W: Write> Write for HttpWriter<W> {
}
}

const MAX_HEADERS: usize = 100;

/// Parses a request into an Incoming message head.
pub fn parse_request<T: BufRead>(buf: &mut T) -> HttpResult<Incoming<(Method, RequestUri)>> {
let (inc, len) = {
let slice = try!(buf.fill_buf());
let mut headers = [httparse::Header { name: "", value: b"" }; 64];
let mut req = httparse::Request::new(&mut headers);
match try!(req.parse(slice)) {
#[inline]
pub fn parse_request<R: Read>(buf: &mut BufReader<R>) -> HttpResult<Incoming<(Method, RequestUri)>> {
parse::<R, httparse::Request, (Method, RequestUri)>(buf)
}

/// Parses a response into an Incoming message head.
#[inline]
pub fn parse_response<R: Read>(buf: &mut BufReader<R>) -> HttpResult<Incoming<RawStatus>> {
parse::<R, httparse::Response, RawStatus>(buf)
}

fn parse<R: Read, T: TryParse<Subject=I>, I>(rdr: &mut BufReader<R>) -> HttpResult<Incoming<I>> {
loop {
match try!(try_parse::<R, T, I>(rdr)) {
httparse::Status::Complete((inc, len)) => {
rdr.consume(len);
return Ok(inc);
},
_partial => ()
}
match try!(rdr.read_into_buf()) {
0 => return Err(HttpTooLargeError),
_ => ()
}
}
}

fn try_parse<R: Read, T: TryParse<Subject=I>, I>(rdr: &mut BufReader<R>) -> TryParseResult<I> {
let mut headers = [httparse::EMPTY_HEADER; MAX_HEADERS];
<T as TryParse>::try_parse(&mut headers, rdr.get_buf())
}

#[doc(hidden)]
trait TryParse {
type Subject;
fn try_parse<'a>(headers: &'a mut [httparse::Header<'a>], buf: &'a [u8]) -> TryParseResult<Self::Subject>;
}

type TryParseResult<T> = Result<httparse::Status<(Incoming<T>, usize)>, HttpError>;

impl<'a> TryParse for httparse::Request<'a> {
type Subject = (Method, RequestUri);

fn try_parse<'b>(headers: &'b mut [httparse::Header<'b>], buf: &'b [u8]) -> TryParseResult<(Method, RequestUri)> {
let mut req = httparse::Request::new(headers);
Ok(match try!(req.parse(buf)) {
httparse::Status::Complete(len) => {
(Incoming {
httparse::Status::Complete((Incoming {
version: if req.version.unwrap() == 1 { Http11 } else { Http10 },
subject: (
try!(req.method.unwrap().parse()),
try!(req.path.unwrap().parse())
),
headers: try!(Headers::from_raw(req.headers))
}, len)
}, len))
},
_ => {
// request head is bigger than a BufRead's buffer? 400 that!
return Err(HttpTooLargeError)
}
}
};
buf.consume(len);
Ok(inc)
httparse::Status::Partial => httparse::Status::Partial
})
}
}

/// Parses a response into an Incoming message head.
pub fn parse_response<T: BufRead>(buf: &mut T) -> HttpResult<Incoming<RawStatus>> {
let (inc, len) = {
let mut headers = [httparse::Header { name: "", value: b"" }; 64];
let mut res = httparse::Response::new(&mut headers);
match try!(res.parse(try!(buf.fill_buf()))) {
impl<'a> TryParse for httparse::Response<'a> {
type Subject = RawStatus;

fn try_parse<'b>(headers: &'b mut [httparse::Header<'b>], buf: &'b [u8]) -> TryParseResult<RawStatus> {
let mut res = httparse::Response::new(headers);
Ok(match try!(res.parse(buf)) {
httparse::Status::Complete(len) => {
(Incoming {
httparse::Status::Complete((Incoming {
version: if res.version.unwrap() == 1 { Http11 } else { Http10 },
subject: RawStatus(
res.code.unwrap(), res.reason.unwrap().to_owned().into_cow()
),
headers: try!(Headers::from_raw(res.headers))
}, len)
}, len))
},
_ => {
// response head is bigger than a BufRead's buffer?
return Err(HttpTooLargeError)
}
}
};
buf.consume(len);
Ok(inc)
httparse::Status::Partial => httparse::Status::Partial
})
}
}

/// An Incoming Message head. Includes request/status line, and headers.
Expand Down Expand Up @@ -456,19 +489,30 @@ mod tests {
read_err("1;no CRLF");
}

#[test]
fn test_parse_incoming() {
use buffer::BufReader;
use mock::MockStream;

use super::parse_request;
let mut raw = MockStream::with_input(b"GET /echo HTTP/1.1\r\nHost: hyper.rs\r\n\r\n");
let mut buf = BufReader::new(&mut raw);
parse_request(&mut buf).unwrap();
}

use test::Bencher;

#[bench]
fn bench_parse_incoming(b: &mut Bencher) {
use std::io::BufReader;
use buffer::BufReader;
use mock::MockStream;

use super::parse_request;
let mut raw = MockStream::with_input(b"GET /echo HTTP/1.1\r\nHost: hyper.rs\r\n\r\n");
let mut buf = BufReader::new(&mut raw);
b.iter(|| {
let mut raw = MockStream::with_input(b"GET /echo HTTP/1.1\r\nHost: hyper.rs\r\n\r\n");
let mut buf = BufReader::new(&mut raw);

parse_request(&mut buf).unwrap();
buf.get_mut().read.set_position(0);
});
}
}
3 changes: 2 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,8 @@ macro_rules! inspect(
#[cfg(test)]
#[macro_use]
mod mock;

#[doc(hidden)]
pub mod buffer;
pub mod client;
pub mod error;
pub mod method;
Expand Down
4 changes: 3 additions & 1 deletion src/server/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//! HTTP Server
use std::io::{BufReader, BufWriter, Write};
use std::io::{BufWriter, Write};
use std::marker::PhantomData;
use std::net::{SocketAddr, ToSocketAddrs};
use std::path::Path;
Expand All @@ -14,6 +14,7 @@ pub use net::{Fresh, Streaming};

use HttpError::HttpIoError;
use {HttpResult};
use buffer::BufReader;
use header::{Headers, Connection, Expect};
use header::ConnectionOption::{Close, KeepAlive};
use method::Method;
Expand Down Expand Up @@ -227,6 +228,7 @@ mod tests {
Host: example.domain\r\n\
Expect: 100-continue\r\n\
Content-Length: 10\r\n\
Connection: close\r\n\
\r\n\
1234567890\
");
Expand Down
6 changes: 4 additions & 2 deletions src/server/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
//!
//! These are requests that a `hyper::Server` receives, and include its method,
//! target URI, headers, and message body.
use std::io::{self, Read, BufReader};
use std::io::{self, Read};
use std::net::SocketAddr;

use {HttpResult};
use buffer::BufReader;
use net::NetworkStream;
use version::{HttpVersion};
use method::Method::{self, Get, Head};
Expand Down Expand Up @@ -81,12 +82,13 @@ impl<'a, 'b> Read for Request<'a, 'b> {

#[cfg(test)]
mod tests {
use buffer::BufReader;
use header::{Host, TransferEncoding, Encoding};
use net::NetworkStream;
use mock::MockStream;
use super::Request;

use std::io::{self, Read, BufReader};
use std::io::{self, Read};
use std::net::SocketAddr;

fn sock(s: &str) -> SocketAddr {
Expand Down

0 comments on commit 6c31ea8

Please sign in to comment.