From 0b7169432b5f51efe5c167be418c2c50220e46a5 Mon Sep 17 00:00:00 2001
From: Sean McArthur <sean.monstar@gmail.com>
Date: Mon, 16 Mar 2015 15:56:38 -0700
Subject: [PATCH] feat(server): add Expect 100-continue support

Adds a new method to `Handler`, with a default implementation of always
responding with a `100 Continue` when sent an expectation.

Closes #369
---
 src/header/common/expect.rs |  37 +++++++++
 src/header/common/mod.rs    |   2 +
 src/server/mod.rs           | 148 +++++++++++++++++++++++++++---------
 3 files changed, 151 insertions(+), 36 deletions(-)
 create mode 100644 src/header/common/expect.rs

diff --git a/src/header/common/expect.rs b/src/header/common/expect.rs
new file mode 100644
index 0000000000..3725310c48
--- /dev/null
+++ b/src/header/common/expect.rs
@@ -0,0 +1,37 @@
+use std::fmt;
+
+use header::{Header, HeaderFormat};
+
+/// The `Expect` header.
+///
+/// > The "Expect" header field in a request indicates a certain set of
+/// > behaviors (expectations) that need to be supported by the server in
+/// > order to properly handle this request.  The only such expectation
+/// > defined by this specification is 100-continue.
+/// >
+/// >    Expect  = "100-continue"
+#[derive(Copy, Clone, PartialEq, Debug)]
+pub enum Expect {
+    /// The value `100-continue`.
+    Continue
+}
+
+impl Header for Expect {
+    fn header_name() -> &'static str {
+        "Expect"
+    }
+
+    fn parse_header(raw: &[Vec<u8>]) -> Option<Expect> {
+        if &[b"100-continue"] == raw {
+            Some(Expect::Continue)
+        } else {
+            None
+        }
+    }
+}
+
+impl HeaderFormat for Expect {
+    fn fmt_header(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        f.write_str("100-continue")
+    }
+}
diff --git a/src/header/common/mod.rs b/src/header/common/mod.rs
index 3f644c3946..e0a423e3a0 100644
--- a/src/header/common/mod.rs
+++ b/src/header/common/mod.rs
@@ -20,6 +20,7 @@ pub use self::content_type::ContentType;
 pub use self::cookie::Cookie;
 pub use self::date::Date;
 pub use self::etag::Etag;
+pub use self::expect::Expect;
 pub use self::expires::Expires;
 pub use self::host::Host;
 pub use self::if_match::IfMatch;
@@ -160,6 +161,7 @@ mod content_length;
 mod content_type;
 mod date;
 mod etag;
+mod expect;
 mod expires;
 mod host;
 mod if_match;
diff --git a/src/server/mod.rs b/src/server/mod.rs
index 3bc4343292..3407b9489c 100644
--- a/src/server/mod.rs
+++ b/src/server/mod.rs
@@ -1,5 +1,5 @@
 //! HTTP Server
-use std::io::{BufReader, BufWriter};
+use std::io::{BufReader, BufWriter, Write};
 use std::marker::PhantomData;
 use std::net::{IpAddr, SocketAddr};
 use std::path::Path;
@@ -14,9 +14,12 @@ pub use net::{Fresh, Streaming};
 
 use HttpError::HttpIoError;
 use {HttpResult};
-use header::Connection;
+use header::{Headers, Connection, Expect};
 use header::ConnectionOption::{Close, KeepAlive};
+use method::Method;
 use net::{NetworkListener, NetworkStream, HttpListener};
+use status::StatusCode;
+use uri::RequestUri;
 use version::HttpVersion::{Http10, Http11};
 
 use self::listener::ListenerPool;
@@ -99,7 +102,7 @@ S: NetworkStream + Clone + Send> Server<'a, H, L> {
 
         debug!("threads = {:?}", threads);
         let pool = ListenerPool::new(listener.clone());
-        let work = move |stream| keep_alive_loop(stream, &handler);
+        let work = move |mut stream| handle_connection(&mut stream, &handler);
 
         let guard = thread::scoped(move || pool.accept(work, threads));
 
@@ -111,7 +114,7 @@ S: NetworkStream + Clone + Send> Server<'a, H, L> {
 }
 
 
-fn keep_alive_loop<'h, S, H>(mut stream: S, handler: &'h H)
+fn handle_connection<'h, S, H>(mut stream: &mut S, handler: &'h H)
 where S: NetworkStream + Clone, H: Handler {
     debug!("Incoming stream");
     let addr = match stream.peer_addr() {
@@ -128,39 +131,45 @@ where S: NetworkStream + Clone, H: Handler {
 
     let mut keep_alive = true;
     while keep_alive {
-        keep_alive = handle_connection(addr, &mut rdr, &mut wrt, handler);
-        debug!("keep_alive = {:?}", keep_alive);
-    }
-}
+        let req = match Request::new(&mut rdr, addr) {
+            Ok(req) => req,
+            Err(e@HttpIoError(_)) => {
+                debug!("ioerror in keepalive loop = {:?}", e);
+                break;
+            }
+            Err(e) => {
+                //TODO: send a 400 response
+                error!("request error = {:?}", e);
+                break;
+            }
+        };
 
-fn handle_connection<'a, 'aa, 'h, S, H>(
-    addr: SocketAddr,
-    rdr: &'a mut BufReader<&'aa mut NetworkStream>,
-    wrt: &mut BufWriter<S>,
-    handler: &'h H
-) -> bool where 'aa: 'a, S: NetworkStream, H: Handler {
-    let mut res = Response::new(wrt);
-    let req = match Request::<'a, 'aa>::new(rdr, addr) {
-        Ok(req) => req,
-        Err(e@HttpIoError(_)) => {
-            debug!("ioerror in keepalive loop = {:?}", e);
-            return false;
-        }
-        Err(e) => {
-            //TODO: send a 400 response
-            error!("request error = {:?}", e);
-            return false;
+        if req.version == Http11 && req.headers.get() == Some(&Expect::Continue) {
+            let status = handler.check_continue((&req.method, &req.uri, &req.headers));
+            match write!(&mut wrt, "{} {}\r\n\r\n", Http11, status) {
+                Ok(..) => (),
+                Err(e) => {
+                    error!("error writing 100-continue: {:?}", e);
+                    break;
+                }
+            }
+
+            if status != StatusCode::Continue {
+                debug!("non-100 status ({}) for Expect 100 request", status);
+                break;
+            }
         }
-    };
 
-    let keep_alive = match (req.version, req.headers.get::<Connection>()) {
-        (Http10, Some(conn)) if !conn.contains(&KeepAlive) => false,
-        (Http11, Some(conn)) if conn.contains(&Close)  => false,
-        _ => true
-    };
-    res.version = req.version;
-    handler.handle(req, res);
-    keep_alive
+        keep_alive = match (req.version, req.headers.get::<Connection>()) {
+            (Http10, Some(conn)) if !conn.contains(&KeepAlive) => false,
+            (Http11, Some(conn)) if conn.contains(&Close)  => false,
+            _ => true
+        };
+        let mut res = Response::new(&mut wrt);
+        res.version = req.version;
+        handler.handle(req, res);
+        debug!("keep_alive = {:?}", keep_alive);
+    }
 }
 
 /// A listening server, which can later be closed.
@@ -184,11 +193,78 @@ pub trait Handler: Sync + Send {
     /// Receives a `Request`/`Response` pair, and should perform some action on them.
     ///
     /// This could reading from the request, and writing to the response.
-    fn handle<'a, 'aa, 'b, 's>(&'s self, Request<'aa, 'a>, Response<'b, Fresh>);
+    fn handle<'a, 'k>(&'a self, Request<'a, 'k>, Response<'a, Fresh>);
+
+    /// Called when a Request includes a `Expect: 100-continue` header.
+    ///
+    /// By default, this will always immediately response with a `StatusCode::Continue`,
+    /// but can be overridden with custom behavior.
+    fn check_continue(&self, _: (&Method, &RequestUri, &Headers)) -> StatusCode {
+        StatusCode::Continue
+    }
 }
 
 impl<F> Handler for F where F: Fn(Request, Response<Fresh>), F: Sync + Send {
-    fn handle<'a, 'aa, 'b, 's>(&'s self, req: Request<'a, 'aa>, res: Response<'b, Fresh>) {
+    fn handle<'a, 'k>(&'a self, req: Request<'a, 'k>, res: Response<'a, Fresh>) {
         self(req, res)
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use header::Headers;
+    use method::Method;
+    use mock::MockStream;
+    use status::StatusCode;
+    use uri::RequestUri;
+
+    use super::{Request, Response, Fresh, Handler, handle_connection};
+
+    #[test]
+    fn test_check_continue_default() {
+        let mut mock = MockStream::with_input(b"\
+            POST /upload HTTP/1.1\r\n\
+            Host: example.domain\r\n\
+            Expect: 100-continue\r\n\
+            Content-Length: 10\r\n\
+            \r\n\
+            1234567890\
+        ");
+
+        fn handle(_: Request, res: Response<Fresh>) {
+            res.start().unwrap().end().unwrap();
+        }
+
+        handle_connection(&mut mock, &handle);
+        let cont = b"HTTP/1.1 100 Continue\r\n\r\n";
+        assert_eq!(&mock.write[..cont.len()], cont);
+        let res = b"HTTP/1.1 200 OK\r\n";
+        assert_eq!(&mock.write[cont.len()..cont.len() + res.len()], res);
+    }
+
+    #[test]
+    fn test_check_continue_reject() {
+        struct Reject;
+        impl Handler for Reject {
+            fn handle<'a, 'k>(&'a self, _: Request<'a, 'k>, res: Response<'a, Fresh>) {
+                res.start().unwrap().end().unwrap();
+            }
+
+            fn check_continue(&self, _: (&Method, &RequestUri, &Headers)) -> StatusCode {
+                StatusCode::ExpectationFailed
+            }
+        }
+
+        let mut mock = MockStream::with_input(b"\
+            POST /upload HTTP/1.1\r\n\
+            Host: example.domain\r\n\
+            Expect: 100-continue\r\n\
+            Content-Length: 10\r\n\
+            \r\n\
+            1234567890\
+        ");
+
+        handle_connection(&mut mock, &Reject);
+        assert_eq!(mock.write, b"HTTP/1.1 417 Expectation Failed\r\n\r\n");
+    }
+}