From 0b7169432b5f51efe5c167be418c2c50220e46a5 Mon Sep 17 00:00:00 2001 From: Sean McArthur <sean.monstar@gmail.com> Date: Mon, 16 Mar 2015 15:56:38 -0700 Subject: [PATCH] feat(server): add Expect 100-continue support Adds a new method to `Handler`, with a default implementation of always responding with a `100 Continue` when sent an expectation. Closes #369 --- src/header/common/expect.rs | 37 +++++++++ src/header/common/mod.rs | 2 + src/server/mod.rs | 148 +++++++++++++++++++++++++++--------- 3 files changed, 151 insertions(+), 36 deletions(-) create mode 100644 src/header/common/expect.rs diff --git a/src/header/common/expect.rs b/src/header/common/expect.rs new file mode 100644 index 0000000000..3725310c48 --- /dev/null +++ b/src/header/common/expect.rs @@ -0,0 +1,37 @@ +use std::fmt; + +use header::{Header, HeaderFormat}; + +/// The `Expect` header. +/// +/// > The "Expect" header field in a request indicates a certain set of +/// > behaviors (expectations) that need to be supported by the server in +/// > order to properly handle this request. The only such expectation +/// > defined by this specification is 100-continue. +/// > +/// > Expect = "100-continue" +#[derive(Copy, Clone, PartialEq, Debug)] +pub enum Expect { + /// The value `100-continue`. + Continue +} + +impl Header for Expect { + fn header_name() -> &'static str { + "Expect" + } + + fn parse_header(raw: &[Vec<u8>]) -> Option<Expect> { + if &[b"100-continue"] == raw { + Some(Expect::Continue) + } else { + None + } + } +} + +impl HeaderFormat for Expect { + fn fmt_header(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str("100-continue") + } +} diff --git a/src/header/common/mod.rs b/src/header/common/mod.rs index 3f644c3946..e0a423e3a0 100644 --- a/src/header/common/mod.rs +++ b/src/header/common/mod.rs @@ -20,6 +20,7 @@ pub use self::content_type::ContentType; pub use self::cookie::Cookie; pub use self::date::Date; pub use self::etag::Etag; +pub use self::expect::Expect; pub use self::expires::Expires; pub use self::host::Host; pub use self::if_match::IfMatch; @@ -160,6 +161,7 @@ mod content_length; mod content_type; mod date; mod etag; +mod expect; mod expires; mod host; mod if_match; diff --git a/src/server/mod.rs b/src/server/mod.rs index 3bc4343292..3407b9489c 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -1,5 +1,5 @@ //! HTTP Server -use std::io::{BufReader, BufWriter}; +use std::io::{BufReader, BufWriter, Write}; use std::marker::PhantomData; use std::net::{IpAddr, SocketAddr}; use std::path::Path; @@ -14,9 +14,12 @@ pub use net::{Fresh, Streaming}; use HttpError::HttpIoError; use {HttpResult}; -use header::Connection; +use header::{Headers, Connection, Expect}; use header::ConnectionOption::{Close, KeepAlive}; +use method::Method; use net::{NetworkListener, NetworkStream, HttpListener}; +use status::StatusCode; +use uri::RequestUri; use version::HttpVersion::{Http10, Http11}; use self::listener::ListenerPool; @@ -99,7 +102,7 @@ S: NetworkStream + Clone + Send> Server<'a, H, L> { debug!("threads = {:?}", threads); let pool = ListenerPool::new(listener.clone()); - let work = move |stream| keep_alive_loop(stream, &handler); + let work = move |mut stream| handle_connection(&mut stream, &handler); let guard = thread::scoped(move || pool.accept(work, threads)); @@ -111,7 +114,7 @@ S: NetworkStream + Clone + Send> Server<'a, H, L> { } -fn keep_alive_loop<'h, S, H>(mut stream: S, handler: &'h H) +fn handle_connection<'h, S, H>(mut stream: &mut S, handler: &'h H) where S: NetworkStream + Clone, H: Handler { debug!("Incoming stream"); let addr = match stream.peer_addr() { @@ -128,39 +131,45 @@ where S: NetworkStream + Clone, H: Handler { let mut keep_alive = true; while keep_alive { - keep_alive = handle_connection(addr, &mut rdr, &mut wrt, handler); - debug!("keep_alive = {:?}", keep_alive); - } -} + let req = match Request::new(&mut rdr, addr) { + Ok(req) => req, + Err(e@HttpIoError(_)) => { + debug!("ioerror in keepalive loop = {:?}", e); + break; + } + Err(e) => { + //TODO: send a 400 response + error!("request error = {:?}", e); + break; + } + }; -fn handle_connection<'a, 'aa, 'h, S, H>( - addr: SocketAddr, - rdr: &'a mut BufReader<&'aa mut NetworkStream>, - wrt: &mut BufWriter<S>, - handler: &'h H -) -> bool where 'aa: 'a, S: NetworkStream, H: Handler { - let mut res = Response::new(wrt); - let req = match Request::<'a, 'aa>::new(rdr, addr) { - Ok(req) => req, - Err(e@HttpIoError(_)) => { - debug!("ioerror in keepalive loop = {:?}", e); - return false; - } - Err(e) => { - //TODO: send a 400 response - error!("request error = {:?}", e); - return false; + if req.version == Http11 && req.headers.get() == Some(&Expect::Continue) { + let status = handler.check_continue((&req.method, &req.uri, &req.headers)); + match write!(&mut wrt, "{} {}\r\n\r\n", Http11, status) { + Ok(..) => (), + Err(e) => { + error!("error writing 100-continue: {:?}", e); + break; + } + } + + if status != StatusCode::Continue { + debug!("non-100 status ({}) for Expect 100 request", status); + break; + } } - }; - let keep_alive = match (req.version, req.headers.get::<Connection>()) { - (Http10, Some(conn)) if !conn.contains(&KeepAlive) => false, - (Http11, Some(conn)) if conn.contains(&Close) => false, - _ => true - }; - res.version = req.version; - handler.handle(req, res); - keep_alive + keep_alive = match (req.version, req.headers.get::<Connection>()) { + (Http10, Some(conn)) if !conn.contains(&KeepAlive) => false, + (Http11, Some(conn)) if conn.contains(&Close) => false, + _ => true + }; + let mut res = Response::new(&mut wrt); + res.version = req.version; + handler.handle(req, res); + debug!("keep_alive = {:?}", keep_alive); + } } /// A listening server, which can later be closed. @@ -184,11 +193,78 @@ pub trait Handler: Sync + Send { /// Receives a `Request`/`Response` pair, and should perform some action on them. /// /// This could reading from the request, and writing to the response. - fn handle<'a, 'aa, 'b, 's>(&'s self, Request<'aa, 'a>, Response<'b, Fresh>); + fn handle<'a, 'k>(&'a self, Request<'a, 'k>, Response<'a, Fresh>); + + /// Called when a Request includes a `Expect: 100-continue` header. + /// + /// By default, this will always immediately response with a `StatusCode::Continue`, + /// but can be overridden with custom behavior. + fn check_continue(&self, _: (&Method, &RequestUri, &Headers)) -> StatusCode { + StatusCode::Continue + } } impl<F> Handler for F where F: Fn(Request, Response<Fresh>), F: Sync + Send { - fn handle<'a, 'aa, 'b, 's>(&'s self, req: Request<'a, 'aa>, res: Response<'b, Fresh>) { + fn handle<'a, 'k>(&'a self, req: Request<'a, 'k>, res: Response<'a, Fresh>) { self(req, res) } } + +#[cfg(test)] +mod tests { + use header::Headers; + use method::Method; + use mock::MockStream; + use status::StatusCode; + use uri::RequestUri; + + use super::{Request, Response, Fresh, Handler, handle_connection}; + + #[test] + fn test_check_continue_default() { + let mut mock = MockStream::with_input(b"\ + POST /upload HTTP/1.1\r\n\ + Host: example.domain\r\n\ + Expect: 100-continue\r\n\ + Content-Length: 10\r\n\ + \r\n\ + 1234567890\ + "); + + fn handle(_: Request, res: Response<Fresh>) { + res.start().unwrap().end().unwrap(); + } + + handle_connection(&mut mock, &handle); + let cont = b"HTTP/1.1 100 Continue\r\n\r\n"; + assert_eq!(&mock.write[..cont.len()], cont); + let res = b"HTTP/1.1 200 OK\r\n"; + assert_eq!(&mock.write[cont.len()..cont.len() + res.len()], res); + } + + #[test] + fn test_check_continue_reject() { + struct Reject; + impl Handler for Reject { + fn handle<'a, 'k>(&'a self, _: Request<'a, 'k>, res: Response<'a, Fresh>) { + res.start().unwrap().end().unwrap(); + } + + fn check_continue(&self, _: (&Method, &RequestUri, &Headers)) -> StatusCode { + StatusCode::ExpectationFailed + } + } + + let mut mock = MockStream::with_input(b"\ + POST /upload HTTP/1.1\r\n\ + Host: example.domain\r\n\ + Expect: 100-continue\r\n\ + Content-Length: 10\r\n\ + \r\n\ + 1234567890\ + "); + + handle_connection(&mut mock, &Reject); + assert_eq!(mock.write, b"HTTP/1.1 417 Expectation Failed\r\n\r\n"); + } +}