From d35992d0198d733c251e133ecc35f2bca8540d96 Mon Sep 17 00:00:00 2001 From: Sean McArthur Date: Tue, 3 May 2016 20:45:43 -0700 Subject: [PATCH] feat(lib): switch to non-blocking (asynchronous) IO BREAKING CHANGE: This breaks a lot of the Client and Server APIs. Check the documentation for how Handlers can be used for asynchronous events. --- .travis.yml | 2 + Cargo.toml | 9 +- README.md | 54 +- benches/client.rs | 111 -- examples/client.rs | 81 +- examples/client_http2.rs | 34 - examples/headers.rs | 16 - examples/hello.rs | 45 +- examples/server.rs | 186 ++- examples/sync.rs | 278 ++++ src/buffer.rs | 164 --- src/client/connect.rs | 219 ++++ src/client/dns.rs | 84 ++ src/client/mod.rs | 984 +++++++------- src/client/proxy.rs | 240 ---- src/client/request.rs | 172 +-- src/client/response.rs | 96 +- src/error.rs | 22 +- .../access_control_allow_credentials.rs | 6 +- .../common/access_control_allow_origin.rs | 4 +- src/header/common/authorization.rs | 4 +- src/header/common/cache_control.rs | 4 +- src/header/common/content_disposition.rs | 4 +- src/header/common/content_length.rs | 4 +- src/header/common/cookie.rs | 4 +- src/header/common/expect.rs | 4 +- src/header/common/host.rs | 8 +- src/header/common/if_range.rs | 12 +- src/header/common/mod.rs | 18 +- src/header/common/pragma.rs | 4 +- src/header/common/prefer.rs | 4 +- src/header/common/preference_applied.rs | 8 +- src/header/common/range.rs | 5 +- src/header/common/set_cookie.rs | 6 +- .../common/strict_transport_security.rs | 4 +- src/header/common/transfer_encoding.rs | 7 + src/header/internals/cell.rs | 7 +- src/header/internals/item.rs | 16 +- src/header/mod.rs | 70 +- src/header/parsing.rs | 6 + src/http/buffer.rs | 120 ++ src/http/channel.rs | 96 ++ src/http/conn.rs | 915 +++++++++++++ src/http/h1.rs | 1137 ----------------- src/http/h1/decode.rs | 293 +++++ src/http/h1/encode.rs | 371 ++++++ src/http/h1/mod.rs | 136 ++ src/http/h1/parse.rs | 246 ++++ src/http/{h2.rs => h2/mod.rs} | 0 src/http/message.rs | 133 -- src/http/mod.rs | 331 ++++- src/lib.rs | 168 +-- src/method.rs | 6 + src/mock.rs | 348 ++--- src/net.rs | 910 ++++++------- src/server/listener.rs | 79 -- src/server/message.rs | 58 + src/server/mod.rs | 673 ++++------ src/server/request.rs | 347 +---- src/server/response.rs | 420 +----- src/status.rs | 6 + src/uri.rs | 8 +- src/version.rs | 17 +- tests/client.rs | 205 +++ tests/server.rs | 379 ++++++ 65 files changed, 5492 insertions(+), 4916 deletions(-) delete mode 100755 benches/client.rs delete mode 100644 examples/client_http2.rs delete mode 100644 examples/headers.rs create mode 100644 examples/sync.rs delete mode 100644 src/buffer.rs create mode 100644 src/client/connect.rs create mode 100644 src/client/dns.rs delete mode 100644 src/client/proxy.rs create mode 100644 src/http/buffer.rs create mode 100644 src/http/channel.rs create mode 100644 src/http/conn.rs delete mode 100644 src/http/h1.rs create mode 100644 src/http/h1/decode.rs create mode 100644 src/http/h1/encode.rs create mode 100644 src/http/h1/mod.rs create mode 100644 src/http/h1/parse.rs rename src/http/{h2.rs => h2/mod.rs} (100%) delete mode 100644 src/http/message.rs delete mode 100644 src/server/listener.rs create mode 100644 src/server/message.rs create mode 100644 tests/client.rs create mode 100644 tests/server.rs diff --git a/.travis.yml b/.travis.yml index 5475915c40..2afedf4be0 100644 --- a/.travis.yml +++ b/.travis.yml @@ -3,6 +3,8 @@ matrix: fast_finish: true include: - os: osx + rust: stable + env: FEATURES="--no-default-features --features security-framework" - rust: nightly env: FEATURES="--features nightly" - rust: beta diff --git a/Cargo.toml b/Cargo.toml index 709bced6a4..bf55b4fb46 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,13 +16,15 @@ httparse = "1.0" language-tags = "0.2" log = "0.3" mime = "0.2" -num_cpus = "0.2" +rotor = "0.6" rustc-serialize = "0.3" +spmc = "0.2" time = "0.1" traitobject = "0.0.1" typeable = "0.1" unicase = "1.0" url = "1.0" +vecio = "0.1" [dependencies.cookie] version = "0.2" @@ -40,16 +42,13 @@ optional = true version = "0.1.4" optional = true -[dependencies.solicit] -version = "0.4" -default-features = false - [dependencies.serde] version = "0.7" optional = true [dev-dependencies] env_logger = "0.3" +num_cpus = "0.2" [features] default = ["ssl"] diff --git a/README.md b/README.md index eb16134189..17c5c1aea5 100644 --- a/README.md +++ b/README.md @@ -10,12 +10,12 @@ A Modern HTTP library for Rust. ### Documentation -- [Stable](http://hyperium.github.io/hyper) +- [Released](http://hyperium.github.io/hyper) - [Master](http://hyperium.github.io/hyper/master) ## Overview -Hyper is a fast, modern HTTP implementation written in and for Rust. It +hyper is a fast, modern HTTP implementation written in and for Rust. It is a low-level typesafe abstraction over raw HTTP, providing an elegant layer over "stringly-typed" HTTP. @@ -23,53 +23,3 @@ Hyper offers both an HTTP client and server which can be used to drive complex web applications written entirely in Rust. The documentation is located at [http://hyperium.github.io/hyper](http://hyperium.github.io/hyper). - -## Example - -### Hello World Server: - -```rust -extern crate hyper; - -use hyper::Server; -use hyper::server::Request; -use hyper::server::Response; - -fn hello(_: Request, res: Response) { - res.send(b"Hello World!").unwrap(); -} - -fn main() { - Server::http("127.0.0.1:3000").unwrap() - .handle(hello).unwrap(); -} -``` - -### Client: - -```rust -extern crate hyper; - -use std::io::Read; - -use hyper::Client; -use hyper::header::Connection; - -fn main() { - // Create a client. - let client = Client::new(); - - // Creating an outgoing request. - let mut res = client.get("http://rust-lang.org/") - // set a header - .header(Connection::close()) - // let 'er go! - .send().unwrap(); - - // Read the Response. - let mut body = String::new(); - res.read_to_string(&mut body).unwrap(); - - println!("Response: {}", body); -} -``` diff --git a/benches/client.rs b/benches/client.rs deleted file mode 100755 index 7fb1d440ba..0000000000 --- a/benches/client.rs +++ /dev/null @@ -1,111 +0,0 @@ -#![deny(warnings)] -#![feature(test)] -extern crate hyper; - -extern crate test; - -use std::fmt; -use std::io::{self, Read, Write, Cursor}; -use std::net::SocketAddr; -use std::time::Duration; - -use hyper::net; - -static README: &'static [u8] = include_bytes!("../README.md"); - -struct MockStream { - read: Cursor> -} - -impl MockStream { - fn new() -> MockStream { - let head = b"HTTP/1.1 200 OK\r\nServer: Mock\r\n\r\n"; - let mut res = head.to_vec(); - res.extend_from_slice(README); - MockStream { - read: Cursor::new(res) - } - } -} - -impl Clone for MockStream { - fn clone(&self) -> MockStream { - MockStream { - read: Cursor::new(self.read.get_ref().clone()) - } - } -} - -impl Read for MockStream { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.read.read(buf) - } -} - -impl Write for MockStream { - fn write(&mut self, msg: &[u8]) -> io::Result { - // we're mocking, what do we care. - Ok(msg.len()) - } - fn flush(&mut self) -> io::Result<()> { - Ok(()) - } -} - -#[derive(Clone, Debug)] -struct Foo; - -impl hyper::header::Header for Foo { - fn header_name() -> &'static str { - "x-foo" - } - fn parse_header(_: &[Vec]) -> hyper::Result { - Err(hyper::Error::Header) - } -} - -impl hyper::header::HeaderFormat for Foo { - fn fmt_header(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - fmt.write_str("Bar") - } -} - -impl net::NetworkStream for MockStream { - fn peer_addr(&mut self) -> io::Result { - Ok("127.0.0.1:1337".parse().unwrap()) - } - fn set_read_timeout(&self, _: Option) -> io::Result<()> { - // can't time out - Ok(()) - } - fn set_write_timeout(&self, _: Option) -> io::Result<()> { - // can't time out - Ok(()) - } -} - -struct MockConnector; - -impl net::NetworkConnector for MockConnector { - type Stream = MockStream; - fn connect(&self, _: &str, _: u16, _: &str) -> hyper::Result { - Ok(MockStream::new()) - } -} - -#[bench] -fn bench_mock_hyper(b: &mut test::Bencher) { - let url = "http://127.0.0.1:1337/"; - b.iter(|| { - let mut req = hyper::client::Request::with_connector( - hyper::Get, hyper::Url::parse(url).unwrap(), &MockConnector - ).unwrap(); - req.headers_mut().set(Foo); - - let mut s = String::new(); - req - .start().unwrap() - .send().unwrap() - .read_to_string(&mut s).unwrap() - }); -} diff --git a/examples/client.rs b/examples/client.rs index 5467f52c7d..a66afb5af2 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -5,9 +5,61 @@ extern crate env_logger; use std::env; use std::io; +use std::sync::mpsc; +use std::time::Duration; -use hyper::Client; +use hyper::client::{Client, Request, Response, DefaultTransport as HttpStream}; use hyper::header::Connection; +use hyper::{Decoder, Encoder, Next}; + +#[derive(Debug)] +struct Dump(mpsc::Sender<()>); + +impl Drop for Dump { + fn drop(&mut self) { + let _ = self.0.send(()); + } +} + +fn read() -> Next { + Next::read().timeout(Duration::from_secs(10)) +} + +impl hyper::client::Handler for Dump { + fn on_request(&mut self, req: &mut Request) -> Next { + req.headers_mut().set(Connection::close()); + read() + } + + fn on_request_writable(&mut self, _encoder: &mut Encoder) -> Next { + read() + } + + fn on_response(&mut self, res: Response) -> Next { + println!("Response: {}", res.status()); + println!("Headers:\n{}", res.headers()); + read() + } + + fn on_response_readable(&mut self, decoder: &mut Decoder) -> Next { + match io::copy(decoder, &mut io::stdout()) { + Ok(0) => Next::end(), + Ok(_) => read(), + Err(e) => match e.kind() { + io::ErrorKind::WouldBlock => Next::read(), + _ => { + println!("ERROR: {}", e); + Next::end() + } + } + } + } + + fn on_error(&mut self, err: hyper::Error) -> Next { + println!("ERROR: {}", err); + Next::remove() + } +} fn main() { env_logger::init().unwrap(); @@ -20,26 +72,11 @@ fn main() { } }; - let client = match env::var("HTTP_PROXY") { - Ok(mut proxy) => { - // parse the proxy, message if it doesn't make sense - let mut port = 80; - if let Some(colon) = proxy.rfind(':') { - port = proxy[colon + 1..].parse().unwrap_or_else(|e| { - panic!("HTTP_PROXY is malformed: {:?}, port parse error: {}", proxy, e); - }); - proxy.truncate(colon); - } - Client::with_http_proxy(proxy, port) - }, - _ => Client::new() - }; - - let mut res = client.get(&*url) - .header(Connection::close()) - .send().unwrap(); + let (tx, rx) = mpsc::channel(); + let client = Client::new().expect("Failed to create a Client"); + client.request(url.parse().unwrap(), Dump(tx)).unwrap(); - println!("Response: {}", res.status); - println!("Headers:\n{}", res.headers); - io::copy(&mut res, &mut io::stdout()).unwrap(); + // wait till done + let _ = rx.recv(); + client.close(); } diff --git a/examples/client_http2.rs b/examples/client_http2.rs deleted file mode 100644 index c12f4567f9..0000000000 --- a/examples/client_http2.rs +++ /dev/null @@ -1,34 +0,0 @@ -#![deny(warnings)] -extern crate hyper; - -extern crate env_logger; - -use std::env; -use std::io; - -use hyper::Client; -use hyper::header::Connection; -use hyper::http::h2; - -fn main() { - env_logger::init().unwrap(); - - let url = match env::args().nth(1) { - Some(url) => url, - None => { - println!("Usage: client "); - return; - } - }; - - let client = Client::with_protocol(h2::new_protocol()); - - // `Connection: Close` is not a valid header for HTTP/2, but the client handles it gracefully. - let mut res = client.get(&*url) - .header(Connection::close()) - .send().unwrap(); - - println!("Response: {}", res.status); - println!("Headers:\n{}", res.headers); - io::copy(&mut res, &mut io::stdout()).unwrap(); -} diff --git a/examples/headers.rs b/examples/headers.rs deleted file mode 100644 index 24e5301f27..0000000000 --- a/examples/headers.rs +++ /dev/null @@ -1,16 +0,0 @@ -#![deny(warnings)] - -#[macro_use] -// TODO: only import header!, blocked by https://github.com/rust-lang/rust/issues/25003 -extern crate hyper; - -#[cfg(feature = "serde-serialization")] -extern crate serde; - -// A header in the form of `X-Foo: some random string` -header! { - (Foo, "X-Foo") => [String] -} - -fn main() { -} diff --git a/examples/hello.rs b/examples/hello.rs index fb2790bc6d..9b20d999e4 100644 --- a/examples/hello.rs +++ b/examples/hello.rs @@ -1,18 +1,53 @@ #![deny(warnings)] extern crate hyper; extern crate env_logger; +extern crate num_cpus; -use hyper::server::{Request, Response}; +use std::io::Write; + +use hyper::{Decoder, Encoder, Next}; +use hyper::net::{HttpStream, HttpListener}; +use hyper::server::{Server, Handler, Request, Response}; static PHRASE: &'static [u8] = b"Hello World!"; -fn hello(_: Request, res: Response) { - res.send(PHRASE).unwrap(); +struct Hello; + +impl Handler for Hello { + fn on_request(&mut self, _: Request) -> Next { + Next::write() + } + fn on_request_readable(&mut self, _: &mut Decoder) -> Next { + Next::write() + } + fn on_response(&mut self, response: &mut Response) -> Next { + use hyper::header::ContentLength; + response.headers_mut().set(ContentLength(PHRASE.len() as u64)); + Next::write() + } + fn on_response_writable(&mut self, encoder: &mut Encoder) -> Next { + let n = encoder.write(PHRASE).unwrap(); + debug_assert_eq!(n, PHRASE.len()); + Next::end() + } } fn main() { env_logger::init().unwrap(); - let _listening = hyper::Server::http("127.0.0.1:3000").unwrap() - .handle(hello); + + let listener = HttpListener::bind(&"127.0.0.1:3000".parse().unwrap()).unwrap(); + let mut handles = Vec::new(); + + for _ in 0..num_cpus::get() { + let listener = listener.try_clone().unwrap(); + handles.push(::std::thread::spawn(move || { + Server::new(listener) + .handle(|_| Hello).unwrap() + })); + } println!("Listening on http://127.0.0.1:3000"); + + for handle in handles { + handle.join().unwrap(); + } } diff --git a/examples/server.rs b/examples/server.rs index 06b68b37c0..3a2d8ccb83 100644 --- a/examples/server.rs +++ b/examples/server.rs @@ -1,47 +1,171 @@ #![deny(warnings)] extern crate hyper; extern crate env_logger; +#[macro_use] +extern crate log; -use std::io::copy; +use std::io::{self, Read, Write}; -use hyper::{Get, Post}; -use hyper::server::{Server, Request, Response}; -use hyper::uri::RequestUri::AbsolutePath; +use hyper::{Get, Post, StatusCode, RequestUri, Decoder, Encoder, Next}; +use hyper::header::ContentLength; +use hyper::net::HttpStream; +use hyper::server::{Server, Handler, Request, Response}; -macro_rules! try_return( - ($e:expr) => {{ - match $e { - Ok(v) => v, - Err(e) => { println!("Error: {}", e); return; } +struct Echo { + buf: Vec, + read_pos: usize, + write_pos: usize, + eof: bool, + route: Route, +} + +enum Route { + NotFound, + Index, + Echo(Body), +} + +#[derive(Clone, Copy)] +enum Body { + Len(u64), + Chunked +} + +static INDEX: &'static [u8] = b"Try POST /echo"; + +impl Echo { + fn new() -> Echo { + Echo { + buf: vec![0; 4096], + read_pos: 0, + write_pos: 0, + eof: false, + route: Route::NotFound, } - }} -); - -fn echo(mut req: Request, mut res: Response) { - match req.uri { - AbsolutePath(ref path) => match (&req.method, &path[..]) { - (&Get, "/") | (&Get, "/echo") => { - try_return!(res.send(b"Try POST /echo")); - return; + } +} + +impl Handler for Echo { + fn on_request(&mut self, req: Request) -> Next { + match *req.uri() { + RequestUri::AbsolutePath(ref path) => match (req.method(), &path[..]) { + (&Get, "/") | (&Get, "/echo") => { + info!("GET Index"); + self.route = Route::Index; + Next::write() + } + (&Post, "/echo") => { + info!("POST Echo"); + let mut is_more = true; + self.route = if let Some(len) = req.headers().get::() { + is_more = **len > 0; + Route::Echo(Body::Len(**len)) + } else { + Route::Echo(Body::Chunked) + }; + if is_more { + Next::read_and_write() + } else { + Next::write() + } + } + _ => Next::write(), }, - (&Post, "/echo") => (), // fall through, fighting mutable borrows - _ => { - *res.status_mut() = hyper::NotFound; - return; + _ => Next::write() + } + } + fn on_request_readable(&mut self, transport: &mut Decoder) -> Next { + match self.route { + Route::Echo(ref body) => { + if self.read_pos < self.buf.len() { + match transport.read(&mut self.buf[self.read_pos..]) { + Ok(0) => { + debug!("Read 0, eof"); + self.eof = true; + Next::write() + }, + Ok(n) => { + self.read_pos += n; + match *body { + Body::Len(max) if max <= self.read_pos as u64 => { + self.eof = true; + Next::write() + }, + _ => Next::read_and_write() + } + } + Err(e) => match e.kind() { + io::ErrorKind::WouldBlock => Next::read_and_write(), + _ => { + println!("read error {:?}", e); + Next::end() + } + } + } + } else { + Next::write() + } } - }, - _ => { - return; + _ => unreachable!() } - }; + } - let mut res = try_return!(res.start()); - try_return!(copy(&mut req, &mut res)); + fn on_response(&mut self, res: &mut Response) -> Next { + match self.route { + Route::NotFound => { + res.set_status(StatusCode::NotFound); + Next::end() + } + Route::Index => { + res.headers_mut().set(ContentLength(INDEX.len() as u64)); + Next::write() + } + Route::Echo(body) => { + if let Body::Len(len) = body { + res.headers_mut().set(ContentLength(len)); + } + Next::read_and_write() + } + } + } + + fn on_response_writable(&mut self, transport: &mut Encoder) -> Next { + match self.route { + Route::Index => { + transport.write(INDEX).unwrap(); + Next::end() + } + Route::Echo(..) => { + if self.write_pos < self.read_pos { + match transport.write(&self.buf[self.write_pos..self.read_pos]) { + Ok(0) => panic!("write ZERO"), + Ok(n) => { + self.write_pos += n; + Next::write() + } + Err(e) => match e.kind() { + io::ErrorKind::WouldBlock => Next::write(), + _ => { + println!("write error {:?}", e); + Next::end() + } + } + } + } else if !self.eof { + Next::read() + } else { + Next::end() + } + } + _ => unreachable!() + } + } } fn main() { env_logger::init().unwrap(); - let server = Server::http("127.0.0.1:1337").unwrap(); - let _guard = server.handle(echo); - println!("Listening on http://127.0.0.1:1337"); + let server = Server::http(&"127.0.0.1:1337".parse().unwrap()).unwrap(); + let (listening, server) = server.handle(|_| Echo::new()).unwrap(); + println!("Listening on http://{}", listening); + server.run(); } diff --git a/examples/sync.rs b/examples/sync.rs new file mode 100644 index 0000000000..6d4eacae64 --- /dev/null +++ b/examples/sync.rs @@ -0,0 +1,278 @@ +extern crate hyper; +extern crate env_logger; +extern crate time; + +use std::io::{self, Read, Write}; +use std::marker::PhantomData; +use std::thread; +use std::sync::{Arc, mpsc}; + +pub struct Server { + listening: hyper::server::Listening, +} + +pub struct Request<'a> { + #[allow(dead_code)] + inner: hyper::server::Request, + tx: &'a mpsc::Sender, + rx: &'a mpsc::Receiver>, + ctrl: &'a hyper::Control, +} + +impl<'a> Request<'a> { + fn new(inner: hyper::server::Request, tx: &'a mpsc::Sender, rx: &'a mpsc::Receiver>, ctrl: &'a hyper::Control) -> Request<'a> { + Request { + inner: inner, + tx: tx, + rx: rx, + ctrl: ctrl, + } + } +} + +impl<'a> io::Read for Request<'a> { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + self.tx.send(Action::Read(buf.as_mut_ptr(), buf.len())).unwrap(); + self.ctrl.ready(hyper::Next::read()).unwrap(); + self.rx.recv().unwrap() + } +} + +pub enum Fresh {} +pub enum Streaming {} + +pub struct Response<'a, W = Fresh> { + status: hyper::StatusCode, + headers: hyper::Headers, + version: hyper::HttpVersion, + tx: &'a mpsc::Sender, + rx: &'a mpsc::Receiver>, + ctrl: &'a hyper::Control, + _marker: PhantomData, +} + +impl<'a> Response<'a, Fresh> { + fn new(tx: &'a mpsc::Sender, rx: &'a mpsc::Receiver>, ctrl: &'a hyper::Control) -> Response<'a, Fresh> { + Response { + status: hyper::Ok, + headers: hyper::Headers::new(), + version: hyper::HttpVersion::Http11, + tx: tx, + rx: rx, + ctrl: ctrl, + _marker: PhantomData, + } + } + + pub fn start(self) -> io::Result> { + self.tx.send(Action::Respond(self.version.clone(), self.status.clone(), self.headers.clone())).unwrap(); + self.ctrl.ready(hyper::Next::write()).unwrap(); + let res = self.rx.recv().unwrap(); + res.map(move |_| Response { + status: self.status, + headers: self.headers, + version: self.version, + tx: self.tx, + rx: self.rx, + ctrl: self.ctrl, + _marker: PhantomData, + }) + } + + pub fn send(mut self, msg: &[u8]) -> io::Result<()> { + self.headers.set(hyper::header::ContentLength(msg.len() as u64)); + self.start().and_then(|mut res| res.write_all(msg)).map(|_| ()) + } +} + +impl<'a> Write for Response<'a, Streaming> { + fn write(&mut self, msg: &[u8]) -> io::Result { + self.tx.send(Action::Write(msg.as_ptr(), msg.len())).unwrap(); + self.ctrl.ready(hyper::Next::write()).unwrap(); + let res = self.rx.recv().unwrap(); + res + } + + fn flush(&mut self) -> io::Result<()> { + panic!("Response.flush() not impemented") + } +} + +struct SynchronousHandler { + req_tx: mpsc::Sender, + tx: mpsc::Sender>, + rx: mpsc::Receiver, + reading: Option<(*mut u8, usize)>, + writing: Option<(*const u8, usize)>, + respond: Option<(hyper::HttpVersion, hyper::StatusCode, hyper::Headers)> +} + +unsafe impl Send for SynchronousHandler {} + +impl SynchronousHandler { + fn next(&mut self) -> hyper::Next { + match self.rx.try_recv() { + Ok(Action::Read(ptr, len)) => { + self.reading = Some((ptr, len)); + hyper::Next::read() + }, + Ok(Action::Respond(ver, status, headers)) => { + self.respond = Some((ver, status, headers)); + hyper::Next::write() + }, + Ok(Action::Write(ptr, len)) => { + self.writing = Some((ptr, len)); + hyper::Next::write() + } + Err(mpsc::TryRecvError::Empty) => { + // we're too fast, the other thread hasn't had a chance to respond + hyper::Next::wait() + } + Err(mpsc::TryRecvError::Disconnected) => { + // they dropped it + // TODO: should finish up sending response, whatever it was + hyper::Next::end() + } + } + } + + fn reading(&mut self) -> Option<(*mut u8, usize)> { + self.reading.take().or_else(|| { + match self.rx.try_recv() { + Ok(Action::Read(ptr, len)) => { + Some((ptr, len)) + }, + _ => None + } + }) + } + + fn writing(&mut self) -> Option<(*const u8, usize)> { + self.writing.take().or_else(|| { + match self.rx.try_recv() { + Ok(Action::Write(ptr, len)) => { + Some((ptr, len)) + }, + _ => None + } + }) + } + fn respond(&mut self) -> Option<(hyper::HttpVersion, hyper::StatusCode, hyper::Headers)> { + self.respond.take().or_else(|| { + match self.rx.try_recv() { + Ok(Action::Respond(ver, status, headers)) => { + Some((ver, status, headers)) + }, + _ => None + } + }) + } +} + +impl hyper::server::Handler for SynchronousHandler { + fn on_request(&mut self, req: hyper::server::Request) -> hyper::Next { + if let Err(_) = self.req_tx.send(req) { + return hyper::Next::end(); + } + + self.next() + } + + fn on_request_readable(&mut self, decoder: &mut hyper::Decoder) -> hyper::Next { + if let Some(raw) = self.reading() { + let slice = unsafe { ::std::slice::from_raw_parts_mut(raw.0, raw.1) }; + if self.tx.send(decoder.read(slice)).is_err() { + return hyper::Next::end(); + } + } + self.next() + } + + fn on_response(&mut self, req: &mut hyper::server::Response) -> hyper::Next { + use std::iter::Extend; + if let Some(head) = self.respond() { + req.set_status(head.1); + req.headers_mut().extend(head.2.iter()); + if self.tx.send(Ok(0)).is_err() { + return hyper::Next::end(); + } + } else { + // wtf happened? + panic!("no head to respond with"); + } + self.next() + } + + fn on_response_writable(&mut self, encoder: &mut hyper::Encoder) -> hyper::Next { + if let Some(raw) = self.writing() { + let slice = unsafe { ::std::slice::from_raw_parts(raw.0, raw.1) }; + if self.tx.send(encoder.write(slice)).is_err() { + return hyper::Next::end(); + } + } + self.next() + } +} + +enum Action { + Read(*mut u8, usize), + Write(*const u8, usize), + Respond(hyper::HttpVersion, hyper::StatusCode, hyper::Headers), +} + +unsafe impl Send for Action {} + +trait Handler: Send + Sync + 'static { + fn handle(&self, req: Request, res: Response); +} + +impl Handler for F where F: Fn(Request, Response) + Send + Sync + 'static { + fn handle(&self, req: Request, res: Response) { + (self)(req, res) + } +} + +impl Server { + fn handle(addr: &str, handler: H) -> Server { + let handler = Arc::new(handler); + let (listening, server) = hyper::Server::http(&addr.parse().unwrap()).unwrap() + .handle(move |ctrl| { + let (req_tx, req_rx) = mpsc::channel(); + let (blocking_tx, blocking_rx) = mpsc::channel(); + let (async_tx, async_rx) = mpsc::channel(); + let handler = handler.clone(); + thread::Builder::new().name("handler-thread".into()).spawn(move || { + let req = Request::new(req_rx.recv().unwrap(), &blocking_tx, &async_rx, &ctrl); + let res = Response::new(&blocking_tx, &async_rx, &ctrl); + handler.handle(req, res); + }).unwrap(); + + SynchronousHandler { + req_tx: req_tx, + tx: async_tx, + rx: blocking_rx, + reading: None, + writing: None, + respond: None, + } + }).unwrap(); + thread::spawn(move || { + server.run(); + }); + Server { + listening: listening + } + } +} + +fn main() { + env_logger::init().unwrap(); + let s = Server::handle("127.0.0.1:0", |mut req: Request, res: Response| { + let mut body = [0; 256]; + let n = req.read(&mut body).unwrap(); + println!("!!!: received: {:?}", ::std::str::from_utf8(&body[..n]).unwrap()); + + res.send(b"Hello World!").unwrap(); + }); + println!("listening on {}", s.listening.addr()); +} diff --git a/src/buffer.rs b/src/buffer.rs deleted file mode 100644 index 905493aa6c..0000000000 --- a/src/buffer.rs +++ /dev/null @@ -1,164 +0,0 @@ -use std::cmp; -use std::io::{self, Read, BufRead}; - -pub struct BufReader { - inner: R, - buf: Vec, - pos: usize, - cap: usize, -} - -const INIT_BUFFER_SIZE: usize = 4096; -const MAX_BUFFER_SIZE: usize = 8192 + 4096 * 100; - -impl BufReader { - #[inline] - pub fn new(rdr: R) -> BufReader { - BufReader::with_capacity(rdr, INIT_BUFFER_SIZE) - } - - #[inline] - pub fn with_capacity(rdr: R, cap: usize) -> BufReader { - BufReader { - inner: rdr, - buf: vec![0; cap], - pos: 0, - cap: 0, - } - } - - #[inline] - pub fn get_ref(&self) -> &R { &self.inner } - - #[inline] - pub fn get_mut(&mut self) -> &mut R { &mut self.inner } - - #[inline] - pub fn get_buf(&self) -> &[u8] { - if self.pos < self.cap { - trace!("get_buf [u8; {}][{}..{}]", self.buf.len(), self.pos, self.cap); - &self.buf[self.pos..self.cap] - } else { - trace!("get_buf []"); - &[] - } - } - - #[inline] - pub fn into_inner(self) -> R { self.inner } - - #[inline] - pub fn read_into_buf(&mut self) -> io::Result { - self.maybe_reserve(); - let v = &mut self.buf; - trace!("read_into_buf buf[{}..{}]", self.cap, v.len()); - if self.cap < v.capacity() { - let nread = try!(self.inner.read(&mut v[self.cap..])); - self.cap += nread; - Ok(nread) - } else { - trace!("read_into_buf at full capacity"); - Ok(0) - } - } - - #[inline] - fn maybe_reserve(&mut self) { - let cap = self.buf.capacity(); - if self.cap == cap && cap < MAX_BUFFER_SIZE { - self.buf.reserve(cmp::min(cap * 4, MAX_BUFFER_SIZE) - cap); - let new = self.buf.capacity() - self.buf.len(); - trace!("reserved {}", new); - unsafe { grow_zerofill(&mut self.buf, new) } - } - } -} - -#[inline] -unsafe fn grow_zerofill(buf: &mut Vec, additional: usize) { - use std::ptr; - let len = buf.len(); - buf.set_len(len + additional); - ptr::write_bytes(buf.as_mut_ptr().offset(len as isize), 0, additional); -} - -impl Read for BufReader { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - if self.cap == self.pos && buf.len() >= self.buf.len() { - return self.inner.read(buf); - } - let nread = { - let mut rem = try!(self.fill_buf()); - try!(rem.read(buf)) - }; - self.consume(nread); - Ok(nread) - } -} - -impl BufRead for BufReader { - fn fill_buf(&mut self) -> io::Result<&[u8]> { - if self.pos == self.cap { - self.cap = try!(self.inner.read(&mut self.buf)); - self.pos = 0; - } - Ok(&self.buf[self.pos..self.cap]) - } - - #[inline] - fn consume(&mut self, amt: usize) { - self.pos = cmp::min(self.pos + amt, self.cap); - if self.pos == self.cap { - self.pos = 0; - self.cap = 0; - } - } -} - -#[cfg(test)] -mod tests { - - use std::io::{self, Read, BufRead}; - use super::BufReader; - - struct SlowRead(u8); - - impl Read for SlowRead { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - let state = self.0; - self.0 += 1; - (&match state % 3 { - 0 => b"foo", - 1 => b"bar", - _ => b"baz", - }[..]).read(buf) - } - } - - #[test] - fn test_consume_and_get_buf() { - let mut rdr = BufReader::new(SlowRead(0)); - rdr.read_into_buf().unwrap(); - rdr.consume(1); - assert_eq!(rdr.get_buf(), b"oo"); - rdr.read_into_buf().unwrap(); - rdr.read_into_buf().unwrap(); - assert_eq!(rdr.get_buf(), b"oobarbaz"); - rdr.consume(5); - assert_eq!(rdr.get_buf(), b"baz"); - rdr.consume(3); - assert_eq!(rdr.get_buf(), b""); - assert_eq!(rdr.pos, 0); - assert_eq!(rdr.cap, 0); - } - - #[test] - fn test_resize() { - let raw = b"hello world"; - let mut rdr = BufReader::with_capacity(&raw[..], 5); - rdr.read_into_buf().unwrap(); - assert_eq!(rdr.get_buf(), b"hello"); - rdr.read_into_buf().unwrap(); - assert_eq!(rdr.get_buf(), b"hello world"); - } -} diff --git a/src/client/connect.rs b/src/client/connect.rs new file mode 100644 index 0000000000..4c9a36013a --- /dev/null +++ b/src/client/connect.rs @@ -0,0 +1,219 @@ +use std::collections::hash_map::{HashMap, Entry}; +use std::hash::Hash; +use std::fmt; +use std::io; +use std::net::SocketAddr; + +use rotor::mio::tcp::TcpStream; +use url::Url; + +use net::{HttpStream, HttpsStream, Transport, SslClient}; +use super::dns::Dns; +use super::Registration; + +/// A connector creates a Transport to a remote address.. +pub trait Connect { + /// Type of Transport to create + type Output: Transport; + /// The key used to determine if an existing socket can be used. + type Key: Eq + Hash + Clone; + /// Returns the key based off the Url. + fn key(&self, &Url) -> Option; + /// Connect to a remote address. + fn connect(&mut self, &Url) -> io::Result; + /// Returns a connected socket and associated host. + fn connected(&mut self) -> Option<(Self::Key, io::Result)>; + #[doc(hidden)] + fn register(&mut self, Registration); +} + +type Scheme = String; +type Port = u16; + +/// A connector for the `http` scheme. +pub struct HttpConnector { + dns: Option, + threads: usize, + resolving: HashMap>, +} + +impl HttpConnector { + /// Set the number of resolver threads. + /// + /// Default is 4. + pub fn threads(mut self, threads: usize) -> HttpConnector { + debug_assert!(self.dns.is_none(), "setting threads after Dns is created does nothing"); + self.threads = threads; + self + } +} + +impl Default for HttpConnector { + fn default() -> HttpConnector { + HttpConnector { + dns: None, + threads: 4, + resolving: HashMap::new(), + } + } +} + +impl fmt::Debug for HttpConnector { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("HttpConnector") + .field("threads", &self.threads) + .field("resolving", &self.resolving) + .finish() + } +} + +impl Connect for HttpConnector { + type Output = HttpStream; + type Key = (&'static str, String, u16); + + fn key(&self, url: &Url) -> Option { + if url.scheme() == "http" { + Some(( + "http", + url.host_str().expect("http scheme must have host").to_owned(), + url.port().unwrap_or(80), + )) + } else { + None + } + } + + fn connect(&mut self, url: &Url) -> io::Result { + debug!("Http::connect({:?})", url); + if let Some(key) = self.key(url) { + let host = url.host_str().expect("http scheme must have a host"); + self.dns.as_ref().expect("dns workers lost").resolve(host); + self.resolving.entry(host.to_owned()).or_insert(Vec::new()).push(key.clone()); + Ok(key) + } else { + Err(io::Error::new(io::ErrorKind::InvalidInput, "scheme must be http")) + } + } + + fn connected(&mut self) -> Option<(Self::Key, io::Result)> { + let (host, addr) = match self.dns.as_ref().expect("dns workers lost").resolved() { + Ok(res) => res, + Err(_) => return None + }; + debug!("Http::resolved <- ({:?}, {:?})", host, addr); + match self.resolving.entry(host) { + Entry::Occupied(mut entry) => { + let resolved = entry.get_mut().remove(0); + if entry.get().is_empty() { + entry.remove(); + } + let port = resolved.2; + match addr { + Ok(addr) => { + Some((resolved, TcpStream::connect(&SocketAddr::new(addr, port)) + .map(HttpStream))) + }, + Err(e) => Some((resolved, Err(e))) + } + } + _ => { + trace!("^-- resolved but not in hashmap?"); + return None + } + } + } + + fn register(&mut self, reg: Registration) { + self.dns = Some(Dns::new(reg.notify, 4)); + } +} + +/// A connector that can protect HTTP streams using SSL. +#[derive(Debug, Default)] +pub struct HttpsConnector { + http: HttpConnector, + ssl: S +} + +impl HttpsConnector { + /// Create a new connector using the provided SSL implementation. + pub fn new(s: S) -> HttpsConnector { + HttpsConnector { + http: HttpConnector::default(), + ssl: s, + } + } +} + +impl Connect for HttpsConnector { + type Output = HttpsStream; + type Key = (&'static str, String, u16); + + fn key(&self, url: &Url) -> Option { + let scheme = match url.scheme() { + "http" => "http", + "https" => "https", + _ => return None + }; + Some(( + scheme, + url.host_str().expect("http scheme must have host").to_owned(), + url.port_or_known_default().expect("http scheme must have a port"), + )) + } + + fn connect(&mut self, url: &Url) -> io::Result { + debug!("Https::connect({:?})", url); + if let Some(key) = self.key(url) { + let host = url.host_str().expect("http scheme must have a host"); + self.http.dns.as_ref().expect("dns workers lost").resolve(host); + self.http.resolving.entry(host.to_owned()).or_insert(Vec::new()).push(key.clone()); + Ok(key) + } else { + Err(io::Error::new(io::ErrorKind::InvalidInput, "scheme must be http or https")) + } + } + + fn connected(&mut self) -> Option<(Self::Key, io::Result)> { + self.http.connected().map(|(key, res)| { + let res = res.and_then(|http| { + if key.0 == "https" { + self.ssl.wrap_client(http, &key.1) + .map(HttpsStream::Https) + .map_err(|e| match e { + ::Error::Io(e) => e, + e => io::Error::new(io::ErrorKind::Other, e) + }) + } else { + Ok(HttpsStream::Http(http)) + } + }); + (key, res) + }) + } + + fn register(&mut self, reg: Registration) { + self.http.register(reg); + } +} + +#[cfg(not(any(feature = "openssl", feature = "security-framework")))] +#[doc(hidden)] +pub type DefaultConnector = HttpConnector; + +#[cfg(all(feature = "openssl", not(feature = "security-framework")))] +#[doc(hidden)] +pub type DefaultConnector = HttpsConnector<::net::Openssl>; + +#[cfg(feature = "security-framework")] +#[doc(hidden)] +pub type DefaultConnector = HttpsConnector<::net::SecureTransportClient>; + +#[doc(hidden)] +pub type DefaultTransport = ::Output; + +fn _assert_defaults() { + fn _assert() where T: Connect, U: Transport {} + + _assert::(); +} diff --git a/src/client/dns.rs b/src/client/dns.rs new file mode 100644 index 0000000000..8a3579e612 --- /dev/null +++ b/src/client/dns.rs @@ -0,0 +1,84 @@ +use std::io; +use std::net::{IpAddr, ToSocketAddrs}; +use std::thread; + +use ::spmc; + +use http::channel; + +pub struct Dns { + tx: spmc::Sender, + rx: channel::Receiver, +} + +pub type Answer = (String, io::Result); + +impl Dns { + pub fn new(notify: (channel::Sender, channel::Receiver), threads: usize) -> Dns { + let (tx, rx) = spmc::channel(); + for _ in 0..threads { + work(rx.clone(), notify.0.clone()); + } + Dns { + tx: tx, + rx: notify.1, + } + } + + pub fn resolve>(&self, hostname: T) { + self.tx.send(hostname.into()).expect("Workers all died horribly"); + } + + pub fn resolved(&self) -> Result { + self.rx.try_recv() + } +} + +fn work(rx: spmc::Receiver, notify: channel::Sender) { + thread::spawn(move || { + let mut worker = Worker::new(rx, notify); + let rx = worker.rx.as_ref().expect("Worker lost rx"); + let notify = worker.notify.as_ref().expect("Worker lost notify"); + while let Ok(host) = rx.recv() { + debug!("resolve {:?}", host); + let res = match (&*host, 80).to_socket_addrs().map(|mut i| i.next()) { + Ok(Some(addr)) => (host, Ok(addr.ip())), + Ok(None) => (host, Err(io::Error::new(io::ErrorKind::Other, "no addresses found"))), + Err(e) => (host, Err(e)) + }; + + if let Err(_) = notify.send(res) { + break; + } + } + worker.shutdown = true; + }); +} + +struct Worker { + rx: Option>, + notify: Option>, + shutdown: bool, +} + +impl Worker { + fn new(rx: spmc::Receiver, notify: channel::Sender) -> Worker { + Worker { + rx: Some(rx), + notify: Some(notify), + shutdown: false, + } + } +} + +impl Drop for Worker { + fn drop(&mut self) { + if !self.shutdown { + trace!("Worker.drop panicked, restarting"); + work(self.rx.take().expect("Worker lost rx"), + self.notify.take().expect("Worker lost notify")); + } else { + trace!("Worker.drop shutdown, closing"); + } + } +} diff --git a/src/client/mod.rs b/src/client/mod.rs index c29fdc9f29..e29f5cd701 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -1,600 +1,607 @@ //! HTTP Client //! -//! # Usage -//! -//! The `Client` API is designed for most people to make HTTP requests. -//! It utilizes the lower level `Request` API. -//! -//! ## GET -//! -//! ```no_run -//! # use hyper::Client; -//! let client = Client::new(); -//! -//! let res = client.get("http://example.domain").send().unwrap(); -//! assert_eq!(res.status, hyper::Ok); -//! ``` -//! -//! The returned value is a `Response`, which provides easy access to -//! the `status`, the `headers`, and the response body via the `Read` -//! trait. -//! -//! ## POST -//! -//! ```no_run -//! # use hyper::Client; -//! let client = Client::new(); -//! -//! let res = client.post("http://example.domain") -//! .body("foo=bar") -//! .send() -//! .unwrap(); -//! assert_eq!(res.status, hyper::Ok); -//! ``` -//! -//! # Sync -//! -//! The `Client` implements `Sync`, so you can share it among multiple threads -//! and make multiple requests simultaneously. -//! -//! ```no_run -//! # use hyper::Client; -//! use std::sync::Arc; -//! use std::thread; -//! -//! // Note: an Arc is used here because `thread::spawn` creates threads that -//! // can outlive the main thread, so we must use reference counting to keep -//! // the Client alive long enough. Scoped threads could skip the Arc. -//! let client = Arc::new(Client::new()); -//! let clone1 = client.clone(); -//! let clone2 = client.clone(); -//! thread::spawn(move || { -//! clone1.get("http://example.domain").send().unwrap(); -//! }); -//! thread::spawn(move || { -//! clone2.post("http://example.domain/post").body("foo=bar").send().unwrap(); -//! }); -//! ``` -use std::borrow::Cow; -use std::default::Default; -use std::io::{self, copy, Read}; -use std::fmt; +//! The HTTP `Client` uses asynchronous IO, and utilizes the `Handler` trait +//! to convey when IO events are available for a given request. +use std::collections::HashMap; +use std::fmt; +use std::marker::PhantomData; +use std::sync::mpsc; +use std::thread; use std::time::Duration; -use url::Url; -use url::ParseError as UrlError; +use rotor::{self, Scope, EventSet, PollOpt}; -use header::{Headers, Header, HeaderFormat}; -use header::{ContentLength, Host, Location}; -use method::Method; -use net::{NetworkConnector, NetworkStream}; -use Error; +use header::Host; +use http::{self, Next, RequestHead}; +use net::Transport; +use uri::RequestUri; +use {Url}; -use self::proxy::tunnel; -pub use self::pool::Pool; +pub use self::connect::{Connect, DefaultConnector, HttpConnector, HttpsConnector, DefaultTransport}; pub use self::request::Request; pub use self::response::Response; -mod proxy; -pub mod pool; -pub mod request; -pub mod response; +mod connect; +mod dns; +//mod pool; +mod request; +mod response; -use http::Protocol; -use http::h1::Http11Protocol; - -/// A Client to use additional features with Requests. -/// -/// Clients can handle things such as: redirect policy, connection pooling. -pub struct Client { - protocol: Box, - redirect_policy: RedirectPolicy, - read_timeout: Option, - write_timeout: Option, - proxy: Option<(Cow<'static, str>, u16)> +/// A Client to make outgoing HTTP requests. +pub struct Client { + //handle: Option>, + tx: http::channel::Sender>, } -impl fmt::Debug for Client { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - fmt.debug_struct("Client") - .field("redirect_policy", &self.redirect_policy) - .field("read_timeout", &self.read_timeout) - .field("write_timeout", &self.write_timeout) - .field("proxy", &self.proxy) - .finish() +impl Clone for Client { + fn clone(&self) -> Client { + Client { + tx: self.tx.clone() + } } } -impl Client { +impl fmt::Debug for Client { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.pad("Client") + } +} - /// Create a new Client. - pub fn new() -> Client { - Client::with_pool_config(Default::default()) +impl Client { + /// Configure a Client. + /// + /// # Example + /// + /// ```dont_run + /// # use hyper::Client; + /// let client = Client::configure() + /// .keep_alive(true) + /// .max_sockets(10_000) + /// .build().unwrap(); + /// ``` + #[inline] + pub fn configure() -> Config { + Config::default() } - /// Create a new Client with a configured Pool Config. - pub fn with_pool_config(config: pool::Config) -> Client { - Client::with_connector(Pool::new(config)) + /*TODO + pub fn http() -> Config { + } - pub fn with_http_proxy(host: H, port: u16) -> Client - where H: Into> { - let host = host.into(); - let proxy = tunnel((host.clone(), port)); - let mut client = Client::with_connector(Pool::with_connector(Default::default(), proxy)); - client.proxy = Some((host, port)); - client + pub fn https() -> Config { + } + */ +} - /// Create a new client with a specific connector. - pub fn with_connector(connector: C) -> Client - where C: NetworkConnector + Send + Sync + 'static, S: NetworkStream + Send { - Client::with_protocol(Http11Protocol::with_connector(connector)) +impl::Output>> Client { + /// Create a new Client with the default config. + #[inline] + pub fn new() -> ::Result> { + Client::::configure().build() } +} - /// Create a new client with a specific `Protocol`. - pub fn with_protocol(protocol: P) -> Client { - Client { - protocol: Box::new(protocol), - redirect_policy: Default::default(), - read_timeout: None, - write_timeout: None, - proxy: None, +impl Client { + /// Create a new client with a specific connector. + fn configured(config: Config) -> ::Result> + where H: Handler, + T: Transport, + C: Connect + Send + 'static { + let mut rotor_config = rotor::Config::new(); + rotor_config.slab_capacity(config.max_sockets); + rotor_config.mio().notify_capacity(config.max_sockets); + let keep_alive = config.keep_alive; + let connect_timeout = config.connect_timeout; + let mut loop_ = try!(rotor::Loop::new(&rotor_config)); + let mut notifier = None; + let mut connector = config.connector; + { + let not = &mut notifier; + loop_.add_machine_with(move |scope| { + let (tx, rx) = http::channel::new(scope.notifier()); + let (dns_tx, dns_rx) = http::channel::share(&tx); + *not = Some(tx); + connector.register(Registration { + notify: (dns_tx, dns_rx), + }); + rotor::Response::ok(ClientFsm::Connector(connector, rx)) + }).unwrap(); } - } - /// Set the RedirectPolicy. - pub fn set_redirect_policy(&mut self, policy: RedirectPolicy) { - self.redirect_policy = policy; - } + let notifier = notifier.expect("loop.add_machine_with failed"); + let _handle = try!(thread::Builder::new().name("hyper-client".to_owned()).spawn(move || { + loop_.run(Context { + connect_timeout: connect_timeout, + keep_alive: keep_alive, + queue: HashMap::new(), + }).unwrap() + })); - /// Set the read timeout value for all requests. - pub fn set_read_timeout(&mut self, dur: Option) { - self.read_timeout = dur; + Ok(Client { + //handle: Some(handle), + tx: notifier, + }) } - /// Set the write timeout value for all requests. - pub fn set_write_timeout(&mut self, dur: Option) { - self.write_timeout = dur; + /// Build a new request using this Client. + /// + /// ## Error + /// + /// If the event loop thread has died, or the queue is full, a `ClientError` + /// will be returned. + pub fn request(&self, url: Url, handler: H) -> Result<(), ClientError> { + self.tx.send(Notify::Connect(url, handler)).map_err(|e| { + match e.0 { + Some(Notify::Connect(url, handler)) => ClientError(Some((url, handler))), + _ => ClientError(None) + } + }) } - /// Build a Get request. - pub fn get(&self, url: U) -> RequestBuilder { - self.request(Method::Get, url) + /// Close the Client loop. + pub fn close(self) { + // Most errors mean that the Receivers are already dead, which would + // imply the EventLoop panicked. + let _ = self.tx.send(Notify::Shutdown); } +} - /// Build a Head request. - pub fn head(&self, url: U) -> RequestBuilder { - self.request(Method::Head, url) - } +/// Configuration for a Client +#[derive(Debug, Clone)] +pub struct Config { + connect_timeout: Duration, + connector: C, + keep_alive: bool, + max_idle: usize, + max_sockets: usize, +} - /// Build a Patch request. - pub fn patch(&self, url: U) -> RequestBuilder { - self.request(Method::Patch, url) +impl Config where C: Connect + Send + 'static { + /// Set the `Connect` type to be used. + #[inline] + pub fn connector(self, val: CC) -> Config { + Config { + connect_timeout: self.connect_timeout, + connector: val, + keep_alive: self.keep_alive, + max_idle: self.max_idle, + max_sockets: self.max_sockets, + } } - /// Build a Post request. - pub fn post(&self, url: U) -> RequestBuilder { - self.request(Method::Post, url) + /// Enable or disable keep-alive mechanics. + /// + /// Default is enabled. + #[inline] + pub fn keep_alive(mut self, val: bool) -> Config { + self.keep_alive = val; + self } - /// Build a Put request. - pub fn put(&self, url: U) -> RequestBuilder { - self.request(Method::Put, url) + /// Set the max table size allocated for holding on to live sockets. + /// + /// Default is 1024. + #[inline] + pub fn max_sockets(mut self, val: usize) -> Config { + self.max_sockets = val; + self } - /// Build a Delete request. - pub fn delete(&self, url: U) -> RequestBuilder { - self.request(Method::Delete, url) + /// Set the timeout for connecting to a URL. + /// + /// Default is 10 seconds. + #[inline] + pub fn connect_timeout(mut self, val: Duration) -> Config { + self.connect_timeout = val; + self } + /// Construct the Client with this configuration. + #[inline] + pub fn build>(self) -> ::Result> { + Client::configured(self) + } +} - /// Build a new request using this Client. - pub fn request(&self, method: Method, url: U) -> RequestBuilder { - RequestBuilder { - client: self, - method: method, - url: url.into_url(), - body: None, - headers: None, +impl Default for Config { + fn default() -> Config { + Config { + connect_timeout: Duration::from_secs(10), + connector: DefaultConnector::default(), + keep_alive: true, + max_idle: 5, + max_sockets: 1024, } } } -impl Default for Client { - fn default() -> Client { Client::new() } +/// An error that can occur when trying to queue a request. +#[derive(Debug)] +pub struct ClientError(Option<(Url, H)>); + +impl ClientError { + /// If the event loop was down, the `Url` and `Handler` can be recovered + /// from this method. + pub fn recover(self) -> Option<(Url, H)> { + self.0 + } } -/// Options for an individual Request. -/// -/// One of these will be built for you if you use one of the convenience -/// methods, such as `get()`, `post()`, etc. -pub struct RequestBuilder<'a> { - client: &'a Client, - // We store a result here because it's good to keep RequestBuilder - // from being generic, but it is a nicer API to report the error - // from `send` (when other errors may be happening, so it already - // returns a `Result`). Why's it good to keep it non-generic? It - // stops downstream crates having to remonomorphise and recompile - // the code, which can take a while, since `send` is fairly large. - // (For an extreme example, a tiny crate containing - // `hyper::Client::new().get("x").send().unwrap();` took ~4s to - // compile with a generic RequestBuilder, but 2s with this scheme,) - url: Result, - headers: Option, - method: Method, - body: Option>, +impl ::std::error::Error for ClientError { + fn description(&self) -> &str { + "Cannot queue request" + } } -impl<'a> RequestBuilder<'a> { +impl fmt::Display for ClientError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str("Cannot queue request") + } +} - /// Set a request body to be sent. - pub fn body>>(mut self, body: B) -> RequestBuilder<'a> { - self.body = Some(body.into()); - self +/* +impl Drop for Client { + fn drop(&mut self) { + self.handle.take().map(|handle| handle.join()); } +} +*/ - /// Add additional headers to the request. - pub fn headers(mut self, headers: Headers) -> RequestBuilder<'a> { - self.headers = Some(headers); - self +/// A trait to react to client events that happen for each message. +/// +/// Each event handler returns it's desired `Next` action. +pub trait Handler: Send + 'static { + /// This event occurs first, triggering when a `Request` head can be written.. + fn on_request(&mut self, request: &mut Request) -> http::Next; + /// This event occurs each time the `Request` is ready to be written to. + fn on_request_writable(&mut self, request: &mut http::Encoder) -> http::Next; + /// This event occurs after the first time this handler signals `Next::read()`, + /// and a Response has been parsed. + fn on_response(&mut self, response: Response) -> http::Next; + /// This event occurs each time the `Response` is ready to be read from. + fn on_response_readable(&mut self, response: &mut http::Decoder) -> http::Next; + + /// This event occurs whenever an `Error` occurs outside of the other events. + /// + /// This could IO errors while waiting for events, or a timeout, etc. + fn on_error(&mut self, err: ::Error) -> http::Next { + debug!("default Handler.on_error({:?})", err); + http::Next::remove() + } + + /// This event occurs when this Handler has requested to remove the Transport. + fn on_remove(self, _transport: T) where Self: Sized { + debug!("default Handler.on_remove"); + } + + /// Receive a `Control` to manage waiting for this request. + fn on_control(&mut self, _: http::Control) { + debug!("default Handler.on_control()"); } +} - /// Add an individual new header to the request. - pub fn header(mut self, header: H) -> RequestBuilder<'a> { - { - let mut headers = match self.headers { - Some(ref mut h) => h, - None => { - self.headers = Some(Headers::new()); - self.headers.as_mut().unwrap() - } - }; +struct Message, T: Transport> { + handler: H, + url: Option, + _marker: PhantomData, +} + +impl, T: Transport> http::MessageHandler for Message { + type Message = http::ClientMessage; - headers.set(header); + fn on_outgoing(&mut self, head: &mut RequestHead) -> Next { + let url = self.url.take().expect("Message.url is missing"); + if let Some(host) = url.host_str() { + head.headers.set(Host { + hostname: host.to_owned(), + port: url.port(), + }); } - self + head.subject.1 = RequestUri::AbsolutePath(url.path().to_owned()); + let mut req = self::request::new(head); + self.handler.on_request(&mut req) } - /// Execute this request and receive a Response back. - pub fn send(self) -> ::Result { - let RequestBuilder { client, method, url, headers, body } = self; - let mut url = try!(url); - trace!("send method={:?}, url={:?}, client={:?}", method, url, client); - - let can_have_body = match method { - Method::Get | Method::Head => false, - _ => true - }; + fn on_encode(&mut self, transport: &mut http::Encoder) -> Next { + self.handler.on_request_writable(transport) + } - let mut body = if can_have_body { - body - } else { - None - }; + fn on_incoming(&mut self, head: http::ResponseHead) -> Next { + trace!("on_incoming {:?}", head); + let resp = response::new(head); + self.handler.on_response(resp) + } - loop { - let mut req = { - let (host, port) = try!(get_host_and_port(&url)); - let mut message = try!(client.protocol.new_message(&host, port, url.scheme())); - if url.scheme() == "http" && client.proxy.is_some() { - message.set_proxied(true); - } + fn on_decode(&mut self, transport: &mut http::Decoder) -> Next { + self.handler.on_response_readable(transport) + } - let mut headers = match headers { - Some(ref headers) => headers.clone(), - None => Headers::new(), - }; - headers.set(Host { - hostname: host.to_owned(), - port: Some(port), - }); - Request::with_headers_and_message(method.clone(), url.clone(), headers, message) - }; - - try!(req.set_write_timeout(client.write_timeout)); - try!(req.set_read_timeout(client.read_timeout)); - - match (can_have_body, body.as_ref()) { - (true, Some(body)) => match body.size() { - Some(size) => req.headers_mut().set(ContentLength(size)), - None => (), // chunked, Request will add it automatically - }, - (true, None) => req.headers_mut().set(ContentLength(0)), - _ => () // neither - } - let mut streaming = try!(req.start()); - body.take().map(|mut rdr| copy(&mut rdr, &mut streaming)); - let res = try!(streaming.send()); - if !res.status.is_redirection() { - return Ok(res) - } - debug!("redirect code {:?} for {}", res.status, url); + fn on_error(&mut self, error: ::Error) -> Next { + self.handler.on_error(error) + } - let loc = { - // punching borrowck here - let loc = match res.headers.get::() { - Some(&Location(ref loc)) => { - Some(url.join(loc)) - } - None => { - debug!("no Location header"); - // could be 304 Not Modified? - None - } - }; - match loc { - Some(r) => r, - None => return Ok(res) - } - }; - url = match loc { - Ok(u) => u, - Err(e) => { - debug!("Location header had invalid URI: {:?}", e); - return Ok(res); - } - }; - match client.redirect_policy { - // separate branches because they can't be one - RedirectPolicy::FollowAll => (), //continue - RedirectPolicy::FollowIf(cond) if cond(&url) => (), //continue - _ => return Ok(res), - } - } + fn on_remove(self, transport: T) { + self.handler.on_remove(transport); } } -/// An enum of possible body types for a Request. -pub enum Body<'a> { - /// A Reader does not necessarily know it's size, so it is chunked. - ChunkedBody(&'a mut (Read + 'a)), - /// For Readers that can know their size, like a `File`. - SizedBody(&'a mut (Read + 'a), u64), - /// A String has a size, and uses Content-Length. - BufBody(&'a [u8] , usize), +struct Context { + connect_timeout: Duration, + keep_alive: bool, + // idle: HashMap>, + queue: HashMap>>, } -impl<'a> Body<'a> { - fn size(&self) -> Option { - match *self { - Body::SizedBody(_, len) => Some(len), - Body::BufBody(_, len) => Some(len as u64), - _ => None +impl Context { + fn pop_queue(&mut self, key: &K) -> Queued { + let mut should_remove = false; + let queued = { + let mut vec = self.queue.get_mut(key).expect("handler not in queue for key"); + let queued = vec.remove(0); + if vec.is_empty() { + should_remove = true; + } + queued + }; + if should_remove { + self.queue.remove(key); } + queued } } -impl<'a> Read for Body<'a> { - #[inline] - fn read(&mut self, buf: &mut [u8]) -> io::Result { - match *self { - Body::ChunkedBody(ref mut r) => r.read(buf), - Body::SizedBody(ref mut r, _) => r.read(buf), - Body::BufBody(ref mut r, _) => Read::read(r, buf), +impl, T: Transport> http::MessageHandlerFactory for Context { + type Output = Message; + + fn create(&mut self, seed: http::Seed) -> Self::Output { + let key = seed.key(); + let queued = self.pop_queue(key); + let (url, mut handler) = (queued.url, queued.handler); + handler.on_control(seed.control()); + Message { + handler: handler, + url: Some(url), + _marker: PhantomData, } } } -impl<'a> Into> for &'a [u8] { - #[inline] - fn into(self) -> Body<'a> { - Body::BufBody(self, self.len()) - } +enum Notify { + Connect(Url, T), + Shutdown, } -impl<'a> Into> for &'a str { - #[inline] - fn into(self) -> Body<'a> { - self.as_bytes().into() - } +enum ClientFsm +where C: Connect, + C::Output: Transport, + H: Handler { + Connector(C, http::channel::Receiver>), + Socket(http::Conn>) } -impl<'a> Into> for &'a String { - #[inline] - fn into(self) -> Body<'a> { - self.as_bytes().into() +unsafe impl Send for ClientFsm +where + C: Connect + Send, + //C::Key, // Key doesn't need to be Send + C::Output: Transport, // Tranport doesn't need to be Send + H: Handler + Send +{} + +impl rotor::Machine for ClientFsm +where C: Connect, + C::Output: Transport, + H: Handler { + type Context = Context; + type Seed = (C::Key, C::Output); + + fn create(seed: Self::Seed, scope: &mut Scope) -> rotor::Response { + rotor_try!(scope.register(&seed.1, EventSet::writable(), PollOpt::level())); + rotor::Response::ok( + ClientFsm::Socket( + http::Conn::new(seed.0, seed.1, scope.notifier()) + .keep_alive(scope.keep_alive) + ) + ) + } + + fn ready(self, events: EventSet, scope: &mut Scope) -> rotor::Response { + match self { + ClientFsm::Connector(..) => { + unreachable!("Connector can never be ready") + }, + ClientFsm::Socket(conn) => { + match conn.ready(events, scope) { + Some((conn, None)) => rotor::Response::ok(ClientFsm::Socket(conn)), + Some((conn, Some(dur))) => { + rotor::Response::ok(ClientFsm::Socket(conn)) + .deadline(scope.now() + dur) + } + None => rotor::Response::done() + } + } + } } -} -impl<'a, R: Read> From<&'a mut R> for Body<'a> { - #[inline] - fn from(r: &'a mut R) -> Body<'a> { - Body::ChunkedBody(r) + fn spawned(self, scope: &mut Scope) -> rotor::Response { + match self { + ClientFsm::Connector(..) => self.connect(scope), + other => rotor::Response::ok(other) + } } -} - -/// A helper trait to convert common objects into a Url. -pub trait IntoUrl { - /// Consumes the object, trying to return a Url. - fn into_url(self) -> Result; -} -impl IntoUrl for Url { - fn into_url(self) -> Result { - Ok(self) + fn timeout(self, scope: &mut Scope) -> rotor::Response { + trace!("timeout now = {:?}", scope.now()); + match self { + ClientFsm::Connector(..) => { + let now = scope.now(); + let mut empty_keys = Vec::new(); + { + for (key, mut vec) in scope.queue.iter_mut() { + while !vec.is_empty() && vec[0].deadline <= now { + let mut queued = vec.remove(0); + let _ = queued.handler.on_error(::Error::Timeout); + } + if vec.is_empty() { + empty_keys.push(key.clone()); + } + } + } + for key in &empty_keys { + scope.queue.remove(key); + } + match self.deadline(scope) { + Some(deadline) => { + rotor::Response::ok(self).deadline(deadline) + }, + None => rotor::Response::ok(self) + } + } + ClientFsm::Socket(conn) => { + match conn.timeout(scope) { + Some((conn, None)) => rotor::Response::ok(ClientFsm::Socket(conn)), + Some((conn, Some(dur))) => { + rotor::Response::ok(ClientFsm::Socket(conn)) + .deadline(scope.now() + dur) + } + None => rotor::Response::done() + } + } + } } -} -impl<'a> IntoUrl for &'a str { - fn into_url(self) -> Result { - Url::parse(self) + fn wakeup(self, scope: &mut Scope) -> rotor::Response { + match self { + ClientFsm::Connector(..) => { + self.connect(scope) + }, + ClientFsm::Socket(conn) => match conn.wakeup(scope) { + Some((conn, None)) => rotor::Response::ok(ClientFsm::Socket(conn)), + Some((conn, Some(dur))) => { + rotor::Response::ok(ClientFsm::Socket(conn)) + .deadline(scope.now() + dur) + } + None => rotor::Response::done() + } + } } } -impl<'a> IntoUrl for &'a String { - fn into_url(self) -> Result { - Url::parse(self) +impl ClientFsm +where C: Connect, + C::Output: Transport, + H: Handler { + fn connect(self, scope: &mut rotor::Scope<::Context>) -> rotor::Response::Seed> { + match self { + ClientFsm::Connector(mut connector, rx) => { + if let Some((key, res)) = connector.connected() { + match res { + Ok(socket) => { + trace!("connected"); + return rotor::Response::spawn(ClientFsm::Connector(connector, rx), (key, socket)); + }, + Err(e) => { + trace!("connected error = {:?}", e); + let mut queued = scope.pop_queue(&key); + let _ = queued.handler.on_error(::Error::Io(e)); + } + } + } + loop { + match rx.try_recv() { + Ok(Notify::Connect(url, mut handler)) => { + // TODO: check pool for sockets to this domain + match connector.connect(&url) { + Ok(key) => { + let deadline = scope.now() + scope.connect_timeout; + scope.queue.entry(key).or_insert(Vec::new()).push(Queued { + deadline: deadline, + handler: handler, + url: url + }); + } + Err(e) => { + let _todo = handler.on_error(e.into()); + trace!("Connect error, next={:?}", _todo); + continue; + } + } + } + Ok(Notify::Shutdown) => { + scope.shutdown_loop(); + return rotor::Response::done() + }, + Err(mpsc::TryRecvError::Disconnected) => { + // if there is no way to send additional requests, + // what more can the loop do? i suppose we should + // shutdown. + scope.shutdown_loop(); + return rotor::Response::done() + } + Err(mpsc::TryRecvError::Empty) => { + // spurious wakeup or loop is done + let fsm = ClientFsm::Connector(connector, rx); + return match fsm.deadline(scope) { + Some(deadline) => { + rotor::Response::ok(fsm).deadline(deadline) + }, + None => rotor::Response::ok(fsm) + }; + } + } + } + }, + other => rotor::Response::ok(other) + } } -} -/// Behavior regarding how to handle redirects within a Client. -#[derive(Copy)] -pub enum RedirectPolicy { - /// Don't follow any redirects. - FollowNone, - /// Follow all redirects. - FollowAll, - /// Follow a redirect if the contained function returns true. - FollowIf(fn(&Url) -> bool), -} - -impl fmt::Debug for RedirectPolicy { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + fn deadline(&self, scope: &mut rotor::Scope<::Context>) -> Option { match *self { - RedirectPolicy::FollowNone => fmt.write_str("FollowNone"), - RedirectPolicy::FollowAll => fmt.write_str("FollowAll"), - RedirectPolicy::FollowIf(_) => fmt.write_str("FollowIf"), + ClientFsm::Connector(..) => { + let mut earliest = None; + for vec in scope.queue.values() { + for queued in vec { + match earliest { + Some(ref mut earliest) => { + if queued.deadline < *earliest { + *earliest = queued.deadline; + } + } + None => earliest = Some(queued.deadline) + } + } + } + trace!("deadline = {:?}, now = {:?}", earliest, scope.now()); + earliest + } + _ => None } } } -// This is a hack because of upstream typesystem issues. -impl Clone for RedirectPolicy { - fn clone(&self) -> RedirectPolicy { - *self - } -} - -impl Default for RedirectPolicy { - fn default() -> RedirectPolicy { - RedirectPolicy::FollowAll - } +struct Queued { + deadline: rotor::Time, + handler: H, + url: Url, } - -fn get_host_and_port(url: &Url) -> ::Result<(&str, u16)> { - let host = match url.host_str() { - Some(host) => host, - None => return Err(Error::Uri(UrlError::EmptyHost)) - }; - trace!("host={:?}", host); - let port = match url.port_or_known_default() { - Some(port) => port, - None => return Err(Error::Uri(UrlError::InvalidPort)) - }; - trace!("port={:?}", port); - Ok((host, port)) +#[doc(hidden)] +#[allow(missing_debug_implementations)] +pub struct Registration { + notify: (http::channel::Sender, http::channel::Receiver), } #[cfg(test)] mod tests { + /* use std::io::Read; use header::Server; - use http::h1::Http11Message; - use mock::{MockStream, MockSsl}; - use super::{Client, RedirectPolicy}; - use super::proxy::Proxy; + use super::{Client}; use super::pool::Pool; use url::Url; - mock_connector!(MockRedirectPolicy { - "http://127.0.0.1" => "HTTP/1.1 301 Redirect\r\n\ - Location: http://127.0.0.2\r\n\ - Server: mock1\r\n\ - \r\n\ - " - "http://127.0.0.2" => "HTTP/1.1 302 Found\r\n\ - Location: https://127.0.0.3\r\n\ - Server: mock2\r\n\ - \r\n\ - " - "https://127.0.0.3" => "HTTP/1.1 200 OK\r\n\ - Server: mock3\r\n\ - \r\n\ - " - }); - - - #[test] - fn test_proxy() { - use super::pool::PooledStream; - type MessageStream = PooledStream>; - mock_connector!(ProxyConnector { - b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n" - }); - let tunnel = Proxy { - connector: ProxyConnector, - proxy: ("example.proxy".into(), 8008), - ssl: MockSsl, - }; - let mut client = Client::with_connector(Pool::with_connector(Default::default(), tunnel)); - client.proxy = Some(("example.proxy".into(), 8008)); - let mut dump = vec![]; - client.get("http://127.0.0.1/foo/bar").send().unwrap().read_to_end(&mut dump).unwrap(); - - let box_message = client.protocol.new_message("127.0.0.1", 80, "http").unwrap(); - let message = box_message.downcast::().unwrap(); - let stream = message.into_inner().downcast::().unwrap().into_inner().into_normal().unwrap();; - - let s = ::std::str::from_utf8(&stream.write).unwrap(); - let request_line = "GET http://127.0.0.1/foo/bar HTTP/1.1\r\n"; - assert!(s.starts_with(request_line), "{:?} doesn't start with {:?}", s, request_line); - assert!(s.contains("Host: 127.0.0.1\r\n")); - } - - #[test] - fn test_proxy_tunnel() { - use super::pool::PooledStream; - type MessageStream = PooledStream>; - - mock_connector!(ProxyConnector { - b"HTTP/1.1 200 OK\r\n\r\n", - b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n" - }); - let tunnel = Proxy { - connector: ProxyConnector, - proxy: ("example.proxy".into(), 8008), - ssl: MockSsl, - }; - let mut client = Client::with_connector(Pool::with_connector(Default::default(), tunnel)); - client.proxy = Some(("example.proxy".into(), 8008)); - let mut dump = vec![]; - client.get("https://127.0.0.1/foo/bar").send().unwrap().read_to_end(&mut dump).unwrap(); - - let box_message = client.protocol.new_message("127.0.0.1", 443, "https").unwrap(); - let message = box_message.downcast::().unwrap(); - let stream = message.into_inner().downcast::().unwrap().into_inner().into_tunneled().unwrap(); - - let s = ::std::str::from_utf8(&stream.write).unwrap(); - let connect_line = "CONNECT 127.0.0.1:443 HTTP/1.1\r\nHost: 127.0.0.1:443\r\n\r\n"; - assert_eq!(&s[..connect_line.len()], connect_line); - - let s = &s[connect_line.len()..]; - let request_line = "GET /foo/bar HTTP/1.1\r\n"; - assert_eq!(&s[..request_line.len()], request_line); - assert!(s.contains("Host: 127.0.0.1\r\n")); - } - - #[test] - fn test_redirect_followall() { - let mut client = Client::with_connector(MockRedirectPolicy); - client.set_redirect_policy(RedirectPolicy::FollowAll); - - let res = client.get("http://127.0.0.1").send().unwrap(); - assert_eq!(res.headers.get(), Some(&Server("mock3".to_owned()))); - } - - #[test] - fn test_redirect_dontfollow() { - let mut client = Client::with_connector(MockRedirectPolicy); - client.set_redirect_policy(RedirectPolicy::FollowNone); - let res = client.get("http://127.0.0.1").send().unwrap(); - assert_eq!(res.headers.get(), Some(&Server("mock1".to_owned()))); - } - - #[test] - fn test_redirect_followif() { - fn follow_if(url: &Url) -> bool { - !url.as_str().contains("127.0.0.3") - } - let mut client = Client::with_connector(MockRedirectPolicy); - client.set_redirect_policy(RedirectPolicy::FollowIf(follow_if)); - let res = client.get("http://127.0.0.1").send().unwrap(); - assert_eq!(res.headers.get(), Some(&Server("mock2".to_owned()))); - } - mock_connector!(Issue640Connector { b"HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\n", b"GET", @@ -621,4 +628,5 @@ mod tests { client.post("http://127.0.0.1").send().unwrap().read_to_string(&mut s).unwrap(); assert_eq!(s, "POST"); } + */ } diff --git a/src/client/proxy.rs b/src/client/proxy.rs deleted file mode 100644 index 923d12eaa3..0000000000 --- a/src/client/proxy.rs +++ /dev/null @@ -1,240 +0,0 @@ -use std::borrow::Cow; -use std::io; -use std::net::{SocketAddr, Shutdown}; -use std::time::Duration; - -use method::Method; -use net::{NetworkConnector, HttpConnector, NetworkStream, SslClient}; - -#[cfg(all(feature = "openssl", not(feature = "security-framework")))] -pub fn tunnel(proxy: (Cow<'static, str>, u16)) -> Proxy { - Proxy { - connector: HttpConnector, - proxy: proxy, - ssl: Default::default() - } -} - -#[cfg(feature = "security-framework")] -pub fn tunnel(proxy: (Cow<'static, str>, u16)) -> Proxy { - Proxy { - connector: HttpConnector, - proxy: proxy, - ssl: Default::default() - } -} - -#[cfg(not(any(feature = "openssl", feature = "security-framework")))] -pub fn tunnel(proxy: (Cow<'static, str>, u16)) -> Proxy { - Proxy { - connector: HttpConnector, - proxy: proxy, - ssl: self::no_ssl::Plaintext, - } - -} - -pub struct Proxy -where C: NetworkConnector + Send + Sync + 'static, - C::Stream: NetworkStream + Send + Clone, - S: SslClient { - pub connector: C, - pub proxy: (Cow<'static, str>, u16), - pub ssl: S, -} - - -impl NetworkConnector for Proxy -where C: NetworkConnector + Send + Sync + 'static, - C::Stream: NetworkStream + Send + Clone, - S: SslClient { - type Stream = Proxied; - - fn connect(&self, host: &str, port: u16, scheme: &str) -> ::Result { - use httparse; - use std::io::{Read, Write}; - use ::version::HttpVersion::Http11; - trace!("{:?} proxy for '{}://{}:{}'", self.proxy, scheme, host, port); - match scheme { - "http" => { - self.connector.connect(self.proxy.0.as_ref(), self.proxy.1, "http") - .map(Proxied::Normal) - }, - "https" => { - let mut stream = try!(self.connector.connect(self.proxy.0.as_ref(), self.proxy.1, "http")); - trace!("{:?} CONNECT {}:{}", self.proxy, host, port); - try!(write!(&mut stream, "{method} {host}:{port} {version}\r\nHost: {host}:{port}\r\n\r\n", - method=Method::Connect, host=host, port=port, version=Http11)); - try!(stream.flush()); - let mut buf = [0; 1024]; - let mut n = 0; - while n < buf.len() { - n += try!(stream.read(&mut buf[n..])); - let mut headers = [httparse::EMPTY_HEADER; 10]; - let mut res = httparse::Response::new(&mut headers); - if try!(res.parse(&buf[..n])).is_complete() { - let code = res.code.expect("complete parsing lost code"); - if code >= 200 && code < 300 { - trace!("CONNECT success = {:?}", code); - return self.ssl.wrap_client(stream, host) - .map(Proxied::Tunneled) - } else { - trace!("CONNECT response = {:?}", code); - return Err(::Error::Status); - } - } - } - Err(::Error::TooLarge) - }, - _ => Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid scheme").into()) - } - } -} - -#[derive(Debug)] -pub enum Proxied { - Normal(T1), - Tunneled(T2) -} - -#[cfg(test)] -impl Proxied { - pub fn into_normal(self) -> Result { - match self { - Proxied::Normal(t1) => Ok(t1), - _ => Err(self) - } - } - - pub fn into_tunneled(self) -> Result { - match self { - Proxied::Tunneled(t2) => Ok(t2), - _ => Err(self) - } - } -} - -impl io::Read for Proxied { - #[inline] - fn read(&mut self, buf: &mut [u8]) -> io::Result { - match *self { - Proxied::Normal(ref mut t) => io::Read::read(t, buf), - Proxied::Tunneled(ref mut t) => io::Read::read(t, buf), - } - } -} - -impl io::Write for Proxied { - #[inline] - fn write(&mut self, buf: &[u8]) -> io::Result { - match *self { - Proxied::Normal(ref mut t) => io::Write::write(t, buf), - Proxied::Tunneled(ref mut t) => io::Write::write(t, buf), - } - } - - #[inline] - fn flush(&mut self) -> io::Result<()> { - match *self { - Proxied::Normal(ref mut t) => io::Write::flush(t), - Proxied::Tunneled(ref mut t) => io::Write::flush(t), - } - } -} - -impl NetworkStream for Proxied { - #[inline] - fn peer_addr(&mut self) -> io::Result { - match *self { - Proxied::Normal(ref mut s) => s.peer_addr(), - Proxied::Tunneled(ref mut s) => s.peer_addr() - } - } - - #[inline] - fn set_read_timeout(&self, dur: Option) -> io::Result<()> { - match *self { - Proxied::Normal(ref inner) => inner.set_read_timeout(dur), - Proxied::Tunneled(ref inner) => inner.set_read_timeout(dur) - } - } - - #[inline] - fn set_write_timeout(&self, dur: Option) -> io::Result<()> { - match *self { - Proxied::Normal(ref inner) => inner.set_write_timeout(dur), - Proxied::Tunneled(ref inner) => inner.set_write_timeout(dur) - } - } - - #[inline] - fn close(&mut self, how: Shutdown) -> io::Result<()> { - match *self { - Proxied::Normal(ref mut s) => s.close(how), - Proxied::Tunneled(ref mut s) => s.close(how) - } - } -} - -#[cfg(not(any(feature = "openssl", feature = "security-framework")))] -mod no_ssl { - use std::io; - use std::net::{Shutdown, SocketAddr}; - use std::time::Duration; - - use net::{SslClient, NetworkStream}; - - pub struct Plaintext; - - #[derive(Clone)] - pub enum Void {} - - impl io::Read for Void { - #[inline] - fn read(&mut self, _buf: &mut [u8]) -> io::Result { - match *self {} - } - } - - impl io::Write for Void { - #[inline] - fn write(&mut self, _buf: &[u8]) -> io::Result { - match *self {} - } - - #[inline] - fn flush(&mut self) -> io::Result<()> { - match *self {} - } - } - - impl NetworkStream for Void { - #[inline] - fn peer_addr(&mut self) -> io::Result { - match *self {} - } - - #[inline] - fn set_read_timeout(&self, _dur: Option) -> io::Result<()> { - match *self {} - } - - #[inline] - fn set_write_timeout(&self, _dur: Option) -> io::Result<()> { - match *self {} - } - - #[inline] - fn close(&mut self, _how: Shutdown) -> io::Result<()> { - match *self {} - } - } - - impl SslClient for Plaintext { - type Stream = Void; - - fn wrap_client(&self, _stream: T, _host: &str) -> ::Result { - Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid scheme").into()) - } - } -} diff --git a/src/client/request.rs b/src/client/request.rs index 8743373c5b..f4e67ca215 100644 --- a/src/client/request.rs +++ b/src/client/request.rs @@ -1,177 +1,52 @@ //! Client Requests -use std::marker::PhantomData; -use std::io::{self, Write}; -use std::time::Duration; - -use url::Url; - -use method::Method; use header::Headers; -use header::Host; -use net::{NetworkStream, NetworkConnector, DefaultConnector, Fresh, Streaming}; -use version; -use client::{Response, get_host_and_port}; +use http::RequestHead; +use method::Method; +use uri::RequestUri; +use version::HttpVersion; -use http::{HttpMessage, RequestHead}; -use http::h1::Http11Message; /// A client request to a remote server. -/// The W type tracks the state of the request, Fresh vs Streaming. -pub struct Request { - /// The target URI for this request. - pub url: Url, - - /// The HTTP version of this request. - pub version: version::HttpVersion, - - message: Box, - headers: Headers, - method: Method, - - _marker: PhantomData, +#[derive(Debug)] +pub struct Request<'a> { + head: &'a mut RequestHead } -impl Request { - /// Read the Request headers. +impl<'a> Request<'a> { + /// Read the Request Url. #[inline] - pub fn headers(&self) -> &Headers { &self.headers } + pub fn uri(&self) -> &RequestUri { &self.head.subject.1 } - /// Read the Request method. + /// Readthe Request Version. #[inline] - pub fn method(&self) -> Method { self.method.clone() } + pub fn version(&self) -> &HttpVersion { &self.head.version } - /// Set the write timeout. + /// Read the Request headers. #[inline] - pub fn set_write_timeout(&self, dur: Option) -> io::Result<()> { - self.message.set_write_timeout(dur) - } + pub fn headers(&self) -> &Headers { &self.head.headers } - /// Set the read timeout. + /// Read the Request method. #[inline] - pub fn set_read_timeout(&self, dur: Option) -> io::Result<()> { - self.message.set_read_timeout(dur) - } -} - -impl Request { - /// Create a new `Request` that will use the given `HttpMessage` for its communication - /// with the server. This implies that the given `HttpMessage` instance has already been - /// properly initialized by the caller (e.g. a TCP connection's already established). - pub fn with_message(method: Method, url: Url, message: Box) - -> ::Result> { - let mut headers = Headers::new(); - { - let (host, port) = try!(get_host_and_port(&url)); - headers.set(Host { - hostname: host.to_owned(), - port: Some(port), - }); - } - - Ok(Request::with_headers_and_message(method, url, headers, message)) - } - - #[doc(hidden)] - pub fn with_headers_and_message(method: Method, url: Url, headers: Headers, message: Box) - -> Request { - Request { - method: method, - headers: headers, - url: url, - version: version::HttpVersion::Http11, - message: message, - _marker: PhantomData, - } - } + pub fn method(&self) -> &Method { &self.head.subject.0 } - /// Create a new client request. - pub fn new(method: Method, url: Url) -> ::Result> { - let conn = DefaultConnector::default(); - Request::with_connector(method, url, &conn) - } - - /// Create a new client request with a specific underlying NetworkStream. - pub fn with_connector(method: Method, url: Url, connector: &C) - -> ::Result> where - C: NetworkConnector, - S: Into> { - let stream = { - let (host, port) = try!(get_host_and_port(&url)); - try!(connector.connect(host, port, url.scheme())).into() - }; - - Request::with_message(method, url, Box::new(Http11Message::with_stream(stream))) - } - - /// Consume a Fresh Request, writing the headers and method, - /// returning a Streaming Request. - pub fn start(mut self) -> ::Result> { - let head = match self.message.set_outgoing(RequestHead { - headers: self.headers, - method: self.method, - url: self.url, - }) { - Ok(head) => head, - Err(e) => { - let _ = self.message.close_connection(); - return Err(From::from(e)); - } - }; - - Ok(Request { - method: head.method, - headers: head.headers, - url: head.url, - version: self.version, - message: self.message, - _marker: PhantomData, - }) - } + /// Set the Method of this request. + #[inline] + pub fn set_method(&mut self, method: Method) { self.head.subject.0 = method; } /// Get a mutable reference to the Request headers. #[inline] - pub fn headers_mut(&mut self) -> &mut Headers { &mut self.headers } -} - - - -impl Request { - /// Completes writing the request, and returns a response to read from. - /// - /// Consumes the Request. - pub fn send(self) -> ::Result { - Response::with_message(self.url, self.message) - } + pub fn headers_mut(&mut self) -> &mut Headers { &mut self.head.headers } } -impl Write for Request { - #[inline] - fn write(&mut self, msg: &[u8]) -> io::Result { - match self.message.write(msg) { - Ok(n) => Ok(n), - Err(e) => { - let _ = self.message.close_connection(); - Err(e) - } - } - } - - #[inline] - fn flush(&mut self) -> io::Result<()> { - match self.message.flush() { - Ok(r) => Ok(r), - Err(e) => { - let _ = self.message.close_connection(); - Err(e) - } - } - } +pub fn new(head: &mut RequestHead) -> Request { + Request { head: head } } #[cfg(test)] mod tests { + /* use std::io::Write; use std::str::from_utf8; use url::Url; @@ -311,4 +186,5 @@ mod tests { .get_ref().downcast_ref::().unwrap() .is_closed); } + */ } diff --git a/src/client/response.rs b/src/client/response.rs index 05bdf4cdc6..bb20aa6d9e 100644 --- a/src/client/response.rs +++ b/src/client/response.rs @@ -1,82 +1,57 @@ //! Client Responses -use std::io::{self, Read}; - -use url::Url; - use header; -use net::NetworkStream; -use http::{self, RawStatus, ResponseHead, HttpMessage}; -use http::h1::Http11Message; +//use net::NetworkStream; +use http::{self, RawStatus}; use status; use version; +pub fn new(incoming: http::ResponseHead) -> Response { + trace!("Response::new"); + let status = status::StatusCode::from_u16(incoming.subject.0); + debug!("version={:?}, status={:?}", incoming.version, status); + debug!("headers={:?}", incoming.headers); + + Response { + status: status, + version: incoming.version, + headers: incoming.headers, + status_raw: incoming.subject, + } + +} + /// A response for a client request to a remote server. #[derive(Debug)] pub struct Response { - /// The status from the server. - pub status: status::StatusCode, - /// The headers from the server. - pub headers: header::Headers, - /// The HTTP version of this response from the server. - pub version: version::HttpVersion, - /// The final URL of this response. - pub url: Url, + status: status::StatusCode, + headers: header::Headers, + version: version::HttpVersion, status_raw: RawStatus, - message: Box, } impl Response { + /// Get the headers from the server. + #[inline] + pub fn headers(&self) -> &header::Headers { &self.headers } - /// Creates a new response from a server. - pub fn new(url: Url, stream: Box) -> ::Result { - trace!("Response::new"); - Response::with_message(url, Box::new(Http11Message::with_stream(stream))) - } - - /// Creates a new response received from the server on the given `HttpMessage`. - pub fn with_message(url: Url, mut message: Box) -> ::Result { - trace!("Response::with_message"); - let ResponseHead { headers, raw_status, version } = match message.get_incoming() { - Ok(head) => head, - Err(e) => { - let _ = message.close_connection(); - return Err(From::from(e)); - } - }; - let status = status::StatusCode::from_u16(raw_status.0); - debug!("version={:?}, status={:?}", version, status); - debug!("headers={:?}", headers); - - Ok(Response { - status: status, - version: version, - headers: headers, - url: url, - status_raw: raw_status, - message: message, - }) - } + /// Get the status from the server. + #[inline] + pub fn status(&self) -> &status::StatusCode { &self.status } /// Get the raw status code and reason. #[inline] - pub fn status_raw(&self) -> &RawStatus { - &self.status_raw - } -} + pub fn status_raw(&self) -> &RawStatus { &self.status_raw } -impl Read for Response { + /// Get the final URL of this response. #[inline] - fn read(&mut self, buf: &mut [u8]) -> io::Result { - match self.message.read(buf) { - Err(e) => { - let _ = self.message.close_connection(); - Err(e) - } - r => r - } - } + //pub fn url(&self) -> &Url { &self.url } + + /// Get the HTTP version of this response from the server. + #[inline] + pub fn version(&self) -> &version::HttpVersion { &self.version } } +/* impl Drop for Response { fn drop(&mut self) { // if not drained, theres old bits in the Reader. we can't reuse this, @@ -94,9 +69,11 @@ impl Drop for Response { } } } +*/ #[cfg(test)] mod tests { + /* use std::io::{self, Read}; use url::Url; @@ -230,4 +207,5 @@ mod tests { assert!(Response::new(url, Box::new(stream)).is_err()); } + */ } diff --git a/src/error.rs b/src/error.rs index ab164c2b82..ced60036b0 100644 --- a/src/error.rs +++ b/src/error.rs @@ -7,7 +7,6 @@ use std::string::FromUtf8Error; use httparse; use url; -use solicit::http::HttpError as Http2Error; #[cfg(feature = "openssl")] use openssl::ssl::error::SslError; @@ -18,10 +17,11 @@ use self::Error::{ Version, Header, Status, + Timeout, Io, Ssl, TooLarge, - Http2, + Incomplete, Utf8 }; @@ -42,14 +42,16 @@ pub enum Error { Header, /// A message head is too large to be reasonable. TooLarge, + /// A message reached EOF before being a complete message. + Incomplete, /// An invalid `Status`, such as `1337 ELITE`. Status, + /// A timeout occurred waiting for an IO event. + Timeout, /// An `io::Error` that occurred while trying to read or write to a network stream. Io(IoError), /// An error from a SSL library. Ssl(Box), - /// An HTTP/2-specific error, coming from the `solicit` library. - Http2(Http2Error), /// Parsing a field as string failed Utf8(Utf8Error), @@ -80,10 +82,11 @@ impl StdError for Error { Header => "Invalid Header provided", TooLarge => "Message head is too large", Status => "Invalid Status provided", + Incomplete => "Message is incomplete", + Timeout => "Timeout", Uri(ref e) => e.description(), Io(ref e) => e.description(), Ssl(ref e) => e.description(), - Http2(ref e) => e.description(), Utf8(ref e) => e.description(), Error::__Nonexhaustive(ref void) => match *void {} } @@ -94,7 +97,6 @@ impl StdError for Error { Io(ref error) => Some(error), Ssl(ref error) => Some(&**error), Uri(ref error) => Some(error), - Http2(ref error) => Some(error), _ => None, } } @@ -148,18 +150,11 @@ impl From for Error { } } -impl From for Error { - fn from(err: Http2Error) -> Error { - Error::Http2(err) - } -} - #[cfg(test)] mod tests { use std::error::Error as StdError; use std::io; use httparse; - use solicit::http::HttpError as Http2Error; use url; use super::Error; use super::Error::*; @@ -201,7 +196,6 @@ mod tests { from_and_cause!(io::Error::new(io::ErrorKind::Other, "other") => Io(..)); from_and_cause!(url::ParseError::EmptyHost => Uri(..)); - from_and_cause!(Http2Error::UnknownStreamId => Http2(..)); from!(httparse::Error::HeaderName => Header); from!(httparse::Error::HeaderName => Header); diff --git a/src/header/common/access_control_allow_credentials.rs b/src/header/common/access_control_allow_credentials.rs index 03ef893024..74f6496843 100644 --- a/src/header/common/access_control_allow_credentials.rs +++ b/src/header/common/access_control_allow_credentials.rs @@ -1,7 +1,7 @@ use std::fmt::{self, Display}; use std::str; use unicase::UniCase; -use header::{Header, HeaderFormat}; +use header::{Header}; /// `Access-Control-Allow-Credentials` header, part of /// [CORS](http://www.w3.org/TR/cors/#access-control-allow-headers-response-header) @@ -62,9 +62,7 @@ impl Header for AccessControlAllowCredentials { } Err(::Error::Header) } -} -impl HeaderFormat for AccessControlAllowCredentials { fn fmt_header(&self, f: &mut fmt::Formatter) -> fmt::Result { f.write_str("true") } @@ -86,4 +84,4 @@ mod test_access_control_allow_credentials { test_header!(not_bool, vec![b"false"], None); test_header!(only_single, vec![b"true", b"true"], None); test_header!(no_gibberish, vec!["\u{645}\u{631}\u{62d}\u{628}\u{627}".as_bytes()], None); -} \ No newline at end of file +} diff --git a/src/header/common/access_control_allow_origin.rs b/src/header/common/access_control_allow_origin.rs index 306966e867..418e9c8cda 100644 --- a/src/header/common/access_control_allow_origin.rs +++ b/src/header/common/access_control_allow_origin.rs @@ -1,6 +1,6 @@ use std::fmt::{self, Display}; -use header::{Header, HeaderFormat}; +use header::{Header}; /// The `Access-Control-Allow-Origin` response header, /// part of [CORS](http://www.w3.org/TR/cors/#access-control-allow-origin-response-header) @@ -70,9 +70,7 @@ impl Header for AccessControlAllowOrigin { _ => AccessControlAllowOrigin::Value(try!(String::from_utf8(value.clone()))) }) } -} -impl HeaderFormat for AccessControlAllowOrigin { fn fmt_header(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { AccessControlAllowOrigin::Any => f.write_str("*"), diff --git a/src/header/common/authorization.rs b/src/header/common/authorization.rs index 7d2d0c942b..d8997da8e2 100644 --- a/src/header/common/authorization.rs +++ b/src/header/common/authorization.rs @@ -3,7 +3,7 @@ use std::fmt::{self, Display}; use std::str::{FromStr, from_utf8}; use std::ops::{Deref, DerefMut}; use serialize::base64::{ToBase64, FromBase64, Standard, Config, Newline}; -use header::{Header, HeaderFormat}; +use header::{Header}; /// `Authorization` header, defined in [RFC7235](https://tools.ietf.org/html/rfc7235#section-4.2) /// @@ -97,9 +97,7 @@ impl Header for Authorization where ::Err: 'st } } } -} -impl HeaderFormat for Authorization where ::Err: 'static { fn fmt_header(&self, f: &mut fmt::Formatter) -> fmt::Result { if let Some(scheme) = ::scheme() { try!(write!(f, "{} ", scheme)) diff --git a/src/header/common/cache_control.rs b/src/header/common/cache_control.rs index adbcaf18b0..2b40685b1a 100644 --- a/src/header/common/cache_control.rs +++ b/src/header/common/cache_control.rs @@ -1,6 +1,6 @@ use std::fmt; use std::str::FromStr; -use header::{Header, HeaderFormat}; +use header::Header; use header::parsing::{from_comma_delimited, fmt_comma_delimited}; /// `Cache-Control` header, defined in [RFC7234](https://tools.ietf.org/html/rfc7234#section-5.2) @@ -62,9 +62,7 @@ impl Header for CacheControl { Err(::Error::Header) } } -} -impl HeaderFormat for CacheControl { fn fmt_header(&self, f: &mut fmt::Formatter) -> fmt::Result { fmt_comma_delimited(f, &self[..]) } diff --git a/src/header/common/content_disposition.rs b/src/header/common/content_disposition.rs index e6d38743d2..cf9d6f0913 100644 --- a/src/header/common/content_disposition.rs +++ b/src/header/common/content_disposition.rs @@ -11,7 +11,7 @@ use std::fmt; use unicase::UniCase; use url::percent_encoding; -use header::{Header, HeaderFormat, parsing}; +use header::{Header, parsing}; use header::parsing::{parse_extended_value, HTTP_VALUE}; use header::shared::Charset; @@ -144,9 +144,7 @@ impl Header for ContentDisposition { Ok(cd) }) } -} -impl HeaderFormat for ContentDisposition { #[inline] fn fmt_header(&self, f: &mut fmt::Formatter) -> fmt::Result { fmt::Display::fmt(&self, f) diff --git a/src/header/common/content_length.rs b/src/header/common/content_length.rs index 8c9707aeaa..7ffa3943af 100644 --- a/src/header/common/content_length.rs +++ b/src/header/common/content_length.rs @@ -1,6 +1,6 @@ use std::fmt; -use header::{HeaderFormat, Header, parsing}; +use header::{Header, parsing}; /// `Content-Length` header, defined in /// [RFC7230](http://tools.ietf.org/html/rfc7230#section-3.3.2) @@ -55,9 +55,7 @@ impl Header for ContentLength { .unwrap_or(Err(::Error::Header)) .map(ContentLength) } -} -impl HeaderFormat for ContentLength { #[inline] fn fmt_header(&self, f: &mut fmt::Formatter) -> fmt::Result { fmt::Display::fmt(&self.0, f) diff --git a/src/header/common/cookie.rs b/src/header/common/cookie.rs index e6e8261ee9..dbb95c9fbc 100644 --- a/src/header/common/cookie.rs +++ b/src/header/common/cookie.rs @@ -1,4 +1,4 @@ -use header::{Header, HeaderFormat, CookiePair, CookieJar}; +use header::{Header, CookiePair, CookieJar}; use std::fmt::{self, Display}; use std::str::from_utf8; @@ -61,9 +61,7 @@ impl Header for Cookie { Err(::Error::Header) } } -} -impl HeaderFormat for Cookie { fn fmt_header(&self, f: &mut fmt::Formatter) -> fmt::Result { let cookies = &self.0; for (i, cookie) in cookies.iter().enumerate() { diff --git a/src/header/common/expect.rs b/src/header/common/expect.rs index f380e5271a..69f1d46056 100644 --- a/src/header/common/expect.rs +++ b/src/header/common/expect.rs @@ -3,7 +3,7 @@ use std::str; use unicase::UniCase; -use header::{Header, HeaderFormat}; +use header::{Header}; /// The `Expect` header. /// @@ -53,9 +53,7 @@ impl Header for Expect { Err(::Error::Header) } } -} -impl HeaderFormat for Expect { fn fmt_header(&self, f: &mut fmt::Formatter) -> fmt::Result { f.write_str("100-continue") } diff --git a/src/header/common/host.rs b/src/header/common/host.rs index 7c41253241..0e180bdb6c 100644 --- a/src/header/common/host.rs +++ b/src/header/common/host.rs @@ -1,4 +1,4 @@ -use header::{Header, HeaderFormat}; +use header::{Header}; use std::fmt; use header::parsing::from_one_raw_str; @@ -52,9 +52,7 @@ impl Header for Host { // https://github.com/servo/rust-url/issues/42 let idx = { let slice = &s[..]; - let mut chars = slice.chars(); - chars.next(); - if chars.next().unwrap() == '[' { + if slice.starts_with('[') { match slice.rfind(']') { Some(idx) => { if slice.len() > idx + 2 { @@ -86,9 +84,7 @@ impl Header for Host { }) }) } -} -impl HeaderFormat for Host { fn fmt_header(&self, f: &mut fmt::Formatter) -> fmt::Result { match self.port { None | Some(80) | Some(443) => f.write_str(&self.hostname[..]), diff --git a/src/header/common/if_range.rs b/src/header/common/if_range.rs index 7399b3a40c..3828705693 100644 --- a/src/header/common/if_range.rs +++ b/src/header/common/if_range.rs @@ -1,5 +1,5 @@ use std::fmt::{self, Display}; -use header::{self, Header, HeaderFormat, EntityTag, HttpDate}; +use header::{self, Header, EntityTag, HttpDate}; /// `If-Range` header, defined in [RFC7233](http://tools.ietf.org/html/rfc7233#section-3.2) /// @@ -59,18 +59,16 @@ impl Header for IfRange { } fn parse_header(raw: &[Vec]) -> ::Result { let etag: ::Result = header::parsing::from_one_raw_str(raw); - if etag.is_ok() { - return Ok(IfRange::EntityTag(etag.unwrap())); + if let Ok(etag) = etag { + return Ok(IfRange::EntityTag(etag)); } let date: ::Result = header::parsing::from_one_raw_str(raw); - if date.is_ok() { - return Ok(IfRange::Date(date.unwrap())); + if let Ok(date) = date { + return Ok(IfRange::Date(date)); } Err(::Error::Header) } -} -impl HeaderFormat for IfRange { fn fmt_header(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { match *self { IfRange::EntityTag(ref x) => Display::fmt(x, f), diff --git a/src/header/common/mod.rs b/src/header/common/mod.rs index 62e712958e..6e9ed1d199 100644 --- a/src/header/common/mod.rs +++ b/src/header/common/mod.rs @@ -66,7 +66,7 @@ macro_rules! bench_header( use test::Bencher; use super::*; - use header::{Header, HeaderFormatter}; + use header::{Header}; #[bench] fn bench_parse(b: &mut Bencher) { @@ -79,7 +79,7 @@ macro_rules! bench_header( #[bench] fn bench_format(b: &mut Bencher) { let val: $ty = Header::parse_header(&$value[..]).unwrap(); - let fmt = HeaderFormatter(&val); + let fmt = ::header::HeaderFormatter(&val); b.iter(|| { format!("{}", fmt); }); @@ -222,15 +222,13 @@ macro_rules! header { fn parse_header(raw: &[Vec]) -> $crate::Result { $crate::header::parsing::from_comma_delimited(raw).map($id) } - } - impl $crate::header::HeaderFormat for $id { fn fmt_header(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { $crate::header::parsing::fmt_comma_delimited(f, &self.0[..]) } } impl ::std::fmt::Display for $id { fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { - use $crate::header::HeaderFormat; + use $crate::header::Header; self.fmt_header(f) } } @@ -250,15 +248,13 @@ macro_rules! header { fn parse_header(raw: &[Vec]) -> $crate::Result { $crate::header::parsing::from_comma_delimited(raw).map($id) } - } - impl $crate::header::HeaderFormat for $id { fn fmt_header(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { $crate::header::parsing::fmt_comma_delimited(f, &self.0[..]) } } impl ::std::fmt::Display for $id { fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { - use $crate::header::HeaderFormat; + use $crate::header::Header; self.fmt_header(f) } } @@ -277,8 +273,6 @@ macro_rules! header { fn parse_header(raw: &[Vec]) -> $crate::Result { $crate::header::parsing::from_one_raw_str(raw).map($id) } - } - impl $crate::header::HeaderFormat for $id { fn fmt_header(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { ::std::fmt::Display::fmt(&**self, f) } @@ -313,8 +307,6 @@ macro_rules! header { } $crate::header::parsing::from_comma_delimited(raw).map($id::Items) } - } - impl $crate::header::HeaderFormat for $id { fn fmt_header(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { match *self { $id::Any => f.write_str("*"), @@ -325,7 +317,7 @@ macro_rules! header { } impl ::std::fmt::Display for $id { fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { - use $crate::header::HeaderFormat; + use $crate::header::Header; self.fmt_header(f) } } diff --git a/src/header/common/pragma.rs b/src/header/common/pragma.rs index 39dee7e8c5..20f0922702 100644 --- a/src/header/common/pragma.rs +++ b/src/header/common/pragma.rs @@ -1,7 +1,7 @@ use std::fmt; use std::ascii::AsciiExt; -use header::{Header, HeaderFormat, parsing}; +use header::{Header, parsing}; /// The `Pragma` header defined by HTTP/1.0. /// @@ -52,9 +52,7 @@ impl Header for Pragma { } }) } -} -impl HeaderFormat for Pragma { fn fmt_header(&self, f: &mut fmt::Formatter) -> fmt::Result { f.write_str(match *self { Pragma::NoCache => "no-cache", diff --git a/src/header/common/prefer.rs b/src/header/common/prefer.rs index 71490da1be..1209dba3e4 100644 --- a/src/header/common/prefer.rs +++ b/src/header/common/prefer.rs @@ -1,6 +1,6 @@ use std::fmt; use std::str::FromStr; -use header::{Header, HeaderFormat}; +use header::{Header}; use header::parsing::{from_comma_delimited, fmt_comma_delimited}; /// `Prefer` header, defined in [RFC7240](http://tools.ietf.org/html/rfc7240) @@ -64,9 +64,7 @@ impl Header for Prefer { Err(::Error::Header) } } -} -impl HeaderFormat for Prefer { fn fmt_header(&self, f: &mut fmt::Formatter) -> fmt::Result { fmt_comma_delimited(f, &self[..]) } diff --git a/src/header/common/preference_applied.rs b/src/header/common/preference_applied.rs index fd1bc36fd0..bd208c1fd5 100644 --- a/src/header/common/preference_applied.rs +++ b/src/header/common/preference_applied.rs @@ -1,5 +1,5 @@ use std::fmt; -use header::{Header, HeaderFormat, Preference}; +use header::{Header, Preference}; use header::parsing::{from_comma_delimited, fmt_comma_delimited}; /// `Preference-Applied` header, defined in [RFC7240](http://tools.ietf.org/html/rfc7240) @@ -61,9 +61,7 @@ impl Header for PreferenceApplied { Err(::Error::Header) } } -} -impl HeaderFormat for PreferenceApplied { fn fmt_header(&self, f: &mut fmt::Formatter) -> fmt::Result { let preferences: Vec<_> = self.0.iter().map(|pref| match pref { // The spec ignores parameters in `Preferences-Applied` @@ -80,7 +78,7 @@ impl HeaderFormat for PreferenceApplied { #[cfg(test)] mod tests { - use header::{HeaderFormat, Preference}; + use header::{Header, Preference}; use super::*; #[test] @@ -90,7 +88,7 @@ mod tests { "foo".to_owned(), "bar".to_owned(), vec![("bar".to_owned(), "foo".to_owned()), ("buz".to_owned(), "".to_owned())] - )]) as &(HeaderFormat + Send + Sync)), + )]) as &(Header + Send + Sync)), "foo=bar".to_owned() ); } diff --git a/src/header/common/range.rs b/src/header/common/range.rs index 4eea1bcc46..69d68ef477 100644 --- a/src/header/common/range.rs +++ b/src/header/common/range.rs @@ -1,7 +1,7 @@ use std::fmt::{self, Display}; use std::str::FromStr; -use header::{Header, HeaderFormat}; +use header::Header; use header::parsing::{from_one_raw_str, from_comma_delimited}; /// `Range` header, defined in [RFC7233](https://tools.ietf.org/html/rfc7233#section-3.1) @@ -182,9 +182,6 @@ impl Header for Range { fn parse_header(raw: &[Vec]) -> ::Result { from_one_raw_str(raw) } -} - -impl HeaderFormat for Range { fn fmt_header(&self, f: &mut fmt::Formatter) -> fmt::Result { Display::fmt(self, f) diff --git a/src/header/common/set_cookie.rs b/src/header/common/set_cookie.rs index 88ac424c00..c04945089c 100644 --- a/src/header/common/set_cookie.rs +++ b/src/header/common/set_cookie.rs @@ -1,4 +1,4 @@ -use header::{Header, HeaderFormat, CookiePair, CookieJar}; +use header::{Header, CookiePair, CookieJar}; use std::fmt::{self, Display}; use std::str::from_utf8; @@ -104,10 +104,6 @@ impl Header for SetCookie { } } -} - -impl HeaderFormat for SetCookie { - fn fmt_header(&self, f: &mut fmt::Formatter) -> fmt::Result { for (i, cookie) in self.0.iter().enumerate() { if i != 0 { diff --git a/src/header/common/strict_transport_security.rs b/src/header/common/strict_transport_security.rs index 5497830765..25aa06c226 100644 --- a/src/header/common/strict_transport_security.rs +++ b/src/header/common/strict_transport_security.rs @@ -3,7 +3,7 @@ use std::str::{self, FromStr}; use unicase::UniCase; -use header::{Header, HeaderFormat, parsing}; +use header::{Header, parsing}; /// `StrictTransportSecurity` header, defined in [RFC6797](https://tools.ietf.org/html/rfc6797) /// @@ -127,9 +127,7 @@ impl Header for StrictTransportSecurity { fn parse_header(raw: &[Vec]) -> ::Result { parsing::from_one_raw_str(raw) } -} -impl HeaderFormat for StrictTransportSecurity { fn fmt_header(&self, f: &mut fmt::Formatter) -> fmt::Result { if self.include_subdomains { write!(f, "max-age={}; includeSubdomains", self.max_age) diff --git a/src/header/common/transfer_encoding.rs b/src/header/common/transfer_encoding.rs index 13ae54a34c..a5a93b3eb3 100644 --- a/src/header/common/transfer_encoding.rs +++ b/src/header/common/transfer_encoding.rs @@ -49,5 +49,12 @@ header! { } } +impl TransferEncoding { + /// Constructor for the most common Transfer-Encoding, `chunked`. + pub fn chunked() -> TransferEncoding { + TransferEncoding(vec![Encoding::Chunked]) + } +} + bench_header!(normal, TransferEncoding, { vec![b"chunked, gzip".to_vec()] }); bench_header!(ext, TransferEncoding, { vec![b"ext".to_vec()] }); diff --git a/src/header/internals/cell.rs b/src/header/internals/cell.rs index 38365b7314..fd15b1e9af 100644 --- a/src/header/internals/cell.rs +++ b/src/header/internals/cell.rs @@ -1,7 +1,6 @@ use std::any::{Any, TypeId}; use std::cell::UnsafeCell; use std::collections::HashMap; -use std::fmt; use std::mem; use std::ops::Deref; @@ -53,7 +52,7 @@ enum PtrMap { Many(HashMap) } -impl PtrMapCell { +impl PtrMapCell { #[inline] pub fn new() -> PtrMapCell { PtrMapCell(UnsafeCell::new(PtrMap::Empty)) @@ -114,12 +113,12 @@ impl PtrMapCell { let map = &*self.0.get(); match *map { PtrMap::One(_, ref one) => one, - _ => panic!("not PtrMap::One value, {:?}", *map) + _ => panic!("not PtrMap::One value") } } } -impl Clone for PtrMapCell where Box: Clone { +impl Clone for PtrMapCell where Box: Clone { #[inline] fn clone(&self) -> PtrMapCell { let cell = PtrMapCell::new(); diff --git a/src/header/internals/item.rs b/src/header/internals/item.rs index 5b1eff92a6..d60e692203 100644 --- a/src/header/internals/item.rs +++ b/src/header/internals/item.rs @@ -4,13 +4,13 @@ use std::fmt; use std::str::from_utf8; use super::cell::{OptCell, PtrMapCell}; -use header::{Header, HeaderFormat}; +use header::{Header}; #[derive(Clone)] pub struct Item { raw: OptCell>>, - typed: PtrMapCell + typed: PtrMapCell
} impl Item { @@ -23,7 +23,7 @@ impl Item { } #[inline] - pub fn new_typed(ty: Box) -> Item { + pub fn new_typed(ty: Box
) -> Item { let map = PtrMapCell::new(); unsafe { map.insert((*ty).get_type(), ty); } Item { @@ -52,7 +52,7 @@ impl Item { &raw[..] } - pub fn typed(&self) -> Option<&H> { + pub fn typed(&self) -> Option<&H> { let tid = TypeId::of::(); match self.typed.get(tid) { Some(val) => Some(val), @@ -68,7 +68,7 @@ impl Item { }.map(|typed| unsafe { typed.downcast_ref_unchecked() }) } - pub fn typed_mut(&mut self) -> Option<&mut H> { + pub fn typed_mut(&mut self) -> Option<&mut H> { let tid = TypeId::of::(); if self.typed.get_mut(tid).is_none() { match parse::(self.raw.as_ref().expect("item.raw must exist")) { @@ -83,11 +83,11 @@ impl Item { } #[inline] -fn parse(raw: &Vec>) -> - ::Result> { +fn parse(raw: &Vec>) -> + ::Result> { Header::parse_header(&raw[..]).map(|h: H| { // FIXME: Use Type ascription - let h: Box = Box::new(h); + let h: Box
= Box::new(h); h }) } diff --git a/src/header/mod.rs b/src/header/mod.rs index 53499f80f1..baedfb8e20 100644 --- a/src/header/mod.rs +++ b/src/header/mod.rs @@ -31,18 +31,17 @@ //! } //! ``` //! -//! This works well for simple "string" headers. But the header system -//! actually involves 2 parts: parsing, and formatting. If you need to -//! customize either part, you can do so. +//! This works well for simple "string" headers. If you need more control, +//! you can implement the trait directly. //! -//! ## `Header` and `HeaderFormat` +//! ## Implementing the `Header` trait //! //! Consider a Do Not Track header. It can be true or false, but it represents //! that via the numerals `1` and `0`. //! //! ``` //! use std::fmt; -//! use hyper::header::{Header, HeaderFormat}; +//! use hyper::header::Header; //! //! #[derive(Debug, Clone, Copy)] //! struct Dnt(bool); @@ -66,9 +65,7 @@ //! } //! Err(hyper::Error::Header) //! } -//! } //! -//! impl HeaderFormat for Dnt { //! fn fmt_header(&self, f: &mut fmt::Formatter) -> fmt::Result { //! if self.0 { //! f.write_str("1") @@ -113,11 +110,11 @@ type HeaderName = UniCase; /// /// This trait represents the construction and identification of headers, /// and contains trait-object unsafe methods. -pub trait Header: Clone + Any + Send + Sync { +pub trait Header: HeaderClone + Any + Typeable + Send + Sync { /// Returns the name of the header field this belongs to. /// /// This will become an associated constant once available. - fn header_name() -> &'static str; + fn header_name() -> &'static str where Self: Sized; /// Parse a header from a raw stream of bytes. /// /// It's possible that a request can include a header field more than once, @@ -125,35 +122,27 @@ pub trait Header: Clone + Any + Send + Sync { /// it's not necessarily the case that a Header is *allowed* to have more /// than one field value. If that's the case, you **should** return `None` /// if `raw.len() > 1`. - fn parse_header(raw: &[Vec]) -> ::Result; - -} - -/// A trait for any object that will represent a header field and value. -/// -/// This trait represents the formatting of a `Header` for output to a TcpStream. -pub trait HeaderFormat: fmt::Debug + HeaderClone + Any + Typeable + Send + Sync { + fn parse_header(raw: &[Vec]) -> ::Result where Self: Sized; /// Format a header to be output into a TcpStream. /// /// This method is not allowed to introduce an Err not produced /// by the passed-in Formatter. fn fmt_header(&self, f: &mut fmt::Formatter) -> fmt::Result; - } #[doc(hidden)] pub trait HeaderClone { - fn clone_box(&self) -> Box; + fn clone_box(&self) -> Box
; } -impl HeaderClone for T { +impl HeaderClone for T { #[inline] - fn clone_box(&self) -> Box { + fn clone_box(&self) -> Box
{ Box::new(self.clone()) } } -impl HeaderFormat + Send + Sync { +impl Header + Send + Sync { #[inline] unsafe fn downcast_ref_unchecked(&self) -> &T { mem::transmute(traitobject::data(self)) @@ -165,9 +154,9 @@ impl HeaderFormat + Send + Sync { } } -impl Clone for Box { +impl Clone for Box
{ #[inline] - fn clone(&self) -> Box { + fn clone(&self) -> Box
{ self.clone_box() } } @@ -183,6 +172,12 @@ pub struct Headers { data: HashMap } +impl Default for Headers { + fn default() -> Headers { + Headers::new() + } +} + impl Headers { /// Creates a new, empty headers map. @@ -212,8 +207,8 @@ impl Headers { /// Set a header field to the corresponding value. /// /// The field is determined by the type of the value being set. - pub fn set(&mut self, value: H) { - trace!("Headers.set( {:?}, {:?} )", header_name::(), value); + pub fn set(&mut self, value: H) { + trace!("Headers.set( {:?}, {:?} )", header_name::(), HeaderFormatter(&value)); self.data.insert(UniCase(CowStr(Cow::Borrowed(header_name::()))), Item::new_typed(Box::new(value))); } @@ -259,13 +254,13 @@ impl Headers { } /// Get a reference to the header field's value, if it exists. - pub fn get(&self) -> Option<&H> { + pub fn get(&self) -> Option<&H> { self.data.get(&UniCase(CowStr(Cow::Borrowed(header_name::())))) .and_then(Item::typed::) } /// Get a mutable reference to the header field's value, if it exists. - pub fn get_mut(&mut self) -> Option<&mut H> { + pub fn get_mut(&mut self) -> Option<&mut H> { self.data.get_mut(&UniCase(CowStr(Cow::Borrowed(header_name::())))) .and_then(Item::typed_mut::) } @@ -280,13 +275,13 @@ impl Headers { /// # let mut headers = Headers::new(); /// let has_type = headers.has::(); /// ``` - pub fn has(&self) -> bool { + pub fn has(&self) -> bool { self.data.contains_key(&UniCase(CowStr(Cow::Borrowed(header_name::())))) } /// Removes a header from the map, if one existed. /// Returns true if a header has been removed. - pub fn remove(&mut self) -> bool { + pub fn remove(&mut self) -> bool { trace!("Headers.remove( {:?} )", header_name::()); self.data.remove(&UniCase(CowStr(Cow::Borrowed(header_name::())))).is_some() } @@ -380,6 +375,7 @@ impl Deserialize for Headers { } /// An `Iterator` over the fields in a `Headers` map. +#[allow(missing_debug_implementations)] pub struct HeadersItems<'a> { inner: Iter<'a, HeaderName, Item> } @@ -410,7 +406,7 @@ impl<'a> HeaderView<'a> { /// Cast the value to a certain Header type. #[inline] - pub fn value(&self) -> Option<&'a H> { + pub fn value(&self) -> Option<&'a H> { self.1.typed::() } @@ -449,7 +445,7 @@ impl<'a> FromIterator> for Headers { } } -impl<'a> fmt::Display for &'a (HeaderFormat + Send + Sync) { +impl<'a> fmt::Display for &'a (Header + Send + Sync) { #[inline] fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { (**self).fmt_header(f) @@ -461,16 +457,16 @@ impl<'a> fmt::Display for &'a (HeaderFormat + Send + Sync) { /// This can be used like so: `format!("{}", HeaderFormatter(&header))` to /// get the representation of a Header which will be written to an /// outgoing `TcpStream`. -pub struct HeaderFormatter<'a, H: HeaderFormat>(pub &'a H); +pub struct HeaderFormatter<'a, H: Header>(pub &'a H); -impl<'a, H: HeaderFormat> fmt::Display for HeaderFormatter<'a, H> { +impl<'a, H: Header> fmt::Display for HeaderFormatter<'a, H> { #[inline] fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { self.0.fmt_header(f) } } -impl<'a, H: HeaderFormat> fmt::Debug for HeaderFormatter<'a, H> { +impl<'a, H: Header> fmt::Debug for HeaderFormatter<'a, H> { #[inline] fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { self.0.fmt_header(f) @@ -519,7 +515,7 @@ mod tests { use mime::Mime; use mime::TopLevel::Text; use mime::SubLevel::Plain; - use super::{Headers, Header, HeaderFormat, ContentLength, ContentType, + use super::{Headers, Header, ContentLength, ContentType, Accept, Host, qitem}; use httparse; @@ -597,9 +593,7 @@ mod tests { None => Err(::Error::Header), } } - } - impl HeaderFormat for CrazyLength { fn fmt_header(&self, f: &mut fmt::Formatter) -> fmt::Result { let CrazyLength(ref opt, ref value) = *self; write!(f, "{:?}, {:?}", opt, value) diff --git a/src/header/parsing.rs b/src/header/parsing.rs index 38c09db5e7..52425a78a4 100644 --- a/src/header/parsing.rs +++ b/src/header/parsing.rs @@ -137,6 +137,12 @@ define_encode_set! { } } +impl fmt::Debug for HTTP_VALUE { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.pad("HTTP_VALUE") + } +} + impl Display for ExtendedValue { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let encoded_value = diff --git a/src/http/buffer.rs b/src/http/buffer.rs new file mode 100644 index 0000000000..ae46491ac5 --- /dev/null +++ b/src/http/buffer.rs @@ -0,0 +1,120 @@ +use std::cmp; +use std::io::{self, Read}; +use std::ptr; + + +const INIT_BUFFER_SIZE: usize = 4096; +const MAX_BUFFER_SIZE: usize = 8192 + 4096 * 100; + +#[derive(Debug)] +pub struct Buffer { + vec: Vec, + read_pos: usize, + write_pos: usize, +} + +impl Buffer { + pub fn new() -> Buffer { + Buffer { + vec: Vec::new(), + read_pos: 0, + write_pos: 0, + } + } + + pub fn reset(&mut self) { + *self = Buffer::new() + } + + #[inline] + pub fn len(&self) -> usize { + self.read_pos - self.write_pos + } + + #[inline] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + #[inline] + pub fn bytes(&self) -> &[u8] { + &self.vec[self.write_pos..self.read_pos] + } + + #[inline] + pub fn consume(&mut self, pos: usize) { + debug_assert!(self.read_pos >= self.write_pos + pos); + self.write_pos += pos; + if self.write_pos == self.read_pos { + self.write_pos = 0; + self.read_pos = 0; + } + } + + pub fn read_from(&mut self, r: &mut R) -> io::Result { + self.maybe_reserve(); + let n = try!(r.read(&mut self.vec[self.read_pos..])); + self.read_pos += n; + Ok(n) + } + + #[inline] + fn maybe_reserve(&mut self) { + let cap = self.vec.len(); + if cap == 0 { + trace!("reserving initial {}", INIT_BUFFER_SIZE); + self.vec = vec![0; INIT_BUFFER_SIZE]; + } else if self.write_pos > 0 && self.read_pos == cap { + let count = self.read_pos - self.write_pos; + trace!("moving buffer bytes over by {}", count); + unsafe { + ptr::copy( + self.vec.as_ptr().offset(self.write_pos as isize), + self.vec.as_mut_ptr(), + count + ); + } + self.read_pos -= count; + self.write_pos = 0; + } else if self.read_pos == cap && cap < MAX_BUFFER_SIZE { + self.vec.reserve(cmp::min(cap * 4, MAX_BUFFER_SIZE) - cap); + let new = self.vec.capacity() - cap; + trace!("reserved {}", new); + unsafe { grow_zerofill(&mut self.vec, new) } + } + } + + pub fn wrap<'a, 'b: 'a, R: io::Read>(&'a mut self, reader: &'b mut R) -> BufReader<'a, R> { + BufReader { + buf: self, + reader: reader + } + } +} + +#[derive(Debug)] +pub struct BufReader<'a, R: io::Read + 'a> { + buf: &'a mut Buffer, + reader: &'a mut R +} + +impl<'a, R: io::Read> Read for BufReader<'a, R> { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + trace!("BufReader.read self={}, buf={}", self.buf.len(), buf.len()); + let n = try!(self.buf.bytes().read(buf)); + self.buf.consume(n); + if n == 0 { + self.buf.reset(); + self.reader.read(&mut buf[n..]) + } else { + Ok(n) + } + } +} + +#[inline] +unsafe fn grow_zerofill(buf: &mut Vec, additional: usize) { + let len = buf.len(); + buf.set_len(len + additional); + ptr::write_bytes(buf.as_mut_ptr(), 0, buf.len()); +} diff --git a/src/http/channel.rs b/src/http/channel.rs new file mode 100644 index 0000000000..ec80148fc8 --- /dev/null +++ b/src/http/channel.rs @@ -0,0 +1,96 @@ +use std::fmt; +use std::sync::{Arc, mpsc}; +use std::sync::atomic::{AtomicBool, Ordering}; +use ::rotor; + +pub use std::sync::mpsc::TryRecvError; + +pub fn new(notify: rotor::Notifier) -> (Sender, Receiver) { + let b = Arc::new(AtomicBool::new(false)); + let (tx, rx) = mpsc::channel(); + (Sender { + awake: b.clone(), + notify: notify, + tx: tx, + }, + Receiver { + awake: b, + rx: rx, + }) +} + +pub fn share(other: &Sender) -> (Sender, Receiver) { + let (tx, rx) = mpsc::channel(); + let notify = other.notify.clone(); + let b = other.awake.clone(); + (Sender { + awake: b.clone(), + notify: notify, + tx: tx, + }, + Receiver { + awake: b, + rx: rx, + }) +} + +pub struct Sender { + awake: Arc, + notify: rotor::Notifier, + tx: mpsc::Sender, +} + +impl Sender { + pub fn send(&self, val: T) -> Result<(), SendError> { + try!(self.tx.send(val)); + if !self.awake.swap(true, Ordering::SeqCst) { + try!(self.notify.wakeup()); + } + Ok(()) + } +} + +impl Clone for Sender { + fn clone(&self) -> Sender { + Sender { + awake: self.awake.clone(), + notify: self.notify.clone(), + tx: self.tx.clone(), + } + } +} + +impl fmt::Debug for Sender { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("Sender") + .field("notify", &self.notify) + .finish() + } +} + +#[derive(Debug)] +pub struct SendError(pub Option); + +impl From> for SendError { + fn from(e: mpsc::SendError) -> SendError { + SendError(Some(e.0)) + } +} + +impl From for SendError { + fn from(_e: rotor::WakeupError) -> SendError { + SendError(None) + } +} + +pub struct Receiver { + awake: Arc, + rx: mpsc::Receiver, +} + +impl Receiver { + pub fn try_recv(&self) -> Result { + self.awake.store(false, Ordering::Relaxed); + self.rx.try_recv() + } +} diff --git a/src/http/conn.rs b/src/http/conn.rs new file mode 100644 index 0000000000..b5eedb664f --- /dev/null +++ b/src/http/conn.rs @@ -0,0 +1,915 @@ +use std::borrow::Cow; +use std::fmt; +use std::hash::Hash; +use std::io; +use std::marker::PhantomData; +use std::mem; +use std::time::Duration; + +use rotor::{self, EventSet, PollOpt, Scope}; + +use http::{self, h1, Http1Message, Encoder, Decoder, Next, Next_, Reg, Control}; +use http::channel; +use http::internal::WriteBuf; +use http::buffer::Buffer; +use net::{Transport, Blocked}; +use version::HttpVersion; + +const MAX_BUFFER_SIZE: usize = 8192 + 4096 * 100; + +/// This handles a connection, which will have been established over a +/// Transport (like a socket), and will likely include multiple +/// `Message`s over HTTP. +/// +/// The connection will determine when a message begins and ends, creating +/// a new message `MessageHandler` for each one, as well as determine if this +/// connection can be kept alive after the message, or if it is complete. +pub struct Conn> { + buf: Buffer, + ctrl: (channel::Sender, channel::Receiver), + keep_alive_enabled: bool, + key: K, + state: State, + transport: T, +} + +impl> fmt::Debug for Conn { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("Conn") + .field("keep_alive_enabled", &self.keep_alive_enabled) + .field("state", &self.state) + .field("buf", &self.buf) + .finish() + } +} + +impl> Conn { + pub fn new(key: K, transport: T, notify: rotor::Notifier) -> Conn { + Conn { + buf: Buffer::new(), + ctrl: channel::new(notify), + keep_alive_enabled: true, + key: key, + state: State::Init, + transport: transport, + } + } + + pub fn keep_alive(mut self, val: bool) -> Conn { + self.keep_alive_enabled = val; + self + } + + /// Desired Register interest based on state of current connection. + /// + /// This includes the user interest, such as when they return `Next::read()`. + fn interest(&self) -> Reg { + match self.state { + State::Closed => Reg::Remove, + State::Init => { + ::Message::initial_interest().interest() + } + State::Http1(Http1 { reading: Reading::Closed, writing: Writing::Closed, .. }) => { + Reg::Remove + } + State::Http1(Http1 { ref reading, ref writing, .. }) => { + let read = match *reading { + Reading::Parse | + Reading::Body(..) => Reg::Read, + Reading::Init | + Reading::Wait(..) | + Reading::KeepAlive | + Reading::Closed => Reg::Wait + }; + + let write = match *writing { + Writing::Head | + Writing::Chunk(..) | + Writing::Ready(..) => Reg::Write, + Writing::Init | + Writing::Wait(..) | + Writing::KeepAlive => Reg::Wait, + Writing::Closed => Reg::Wait, + }; + + match (read, write) { + (Reg::Read, Reg::Write) => Reg::ReadWrite, + (Reg::Read, Reg::Wait) => Reg::Read, + (Reg::Wait, Reg::Write) => Reg::Write, + (Reg::Wait, Reg::Wait) => Reg::Wait, + _ => unreachable!("bad read/write reg combo") + } + } + } + } + + /// Actual register action. + /// + /// Considers the user interest(), but also compares if the underlying + /// transport is blocked(), and adjusts accordingly. + fn register(&self) -> Reg { + let reg = self.interest(); + match (reg, self.transport.blocked()) { + (Reg::Remove, _) | + (Reg::Wait, _) | + (_, None) => reg, + + (_, Some(Blocked::Read)) => Reg::Read, + (_, Some(Blocked::Write)) => Reg::Write, + } + } + + fn parse(&mut self) -> ::Result>::Message as Http1Message>::Incoming>> { + let n = try!(self.buf.read_from(&mut self.transport)); + if n == 0 { + trace!("parse eof"); + return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "parse eof").into()); + } + match try!(http::parse::<>::Message, _>(self.buf.bytes())) { + Some((head, len)) => { + trace!("parsed {} bytes out of {}", len, self.buf.len()); + self.buf.consume(len); + Ok(head) + }, + None => { + if self.buf.len() >= MAX_BUFFER_SIZE { + //TODO: Handler.on_too_large_error() + debug!("MAX_BUFFER_SIZE reached, closing"); + Err(::Error::TooLarge) + } else { + Err(io::Error::new(io::ErrorKind::WouldBlock, "incomplete parse").into()) + } + }, + } + } + + fn read>(&mut self, scope: &mut Scope, state: State) -> State { + match state { + State::Init => { + let head = match self.parse() { + Ok(head) => head, + Err(::Error::Io(e)) => match e.kind() { + io::ErrorKind::WouldBlock | + io::ErrorKind::Interrupted => return State::Init, + _ => { + debug!("io error trying to parse {:?}", e); + return State::Closed; + } + }, + Err(e) => { + //TODO: send proper error codes depending on error + trace!("parse eror: {:?}", e); + return State::Closed; + } + }; + match <>::Message as Http1Message>::decoder(&head) { + Ok(decoder) => { + trace!("decoder = {:?}", decoder); + let keep_alive = self.keep_alive_enabled && head.should_keep_alive(); + let mut handler = scope.create(Seed(&self.key, &self.ctrl.0)); + let next = handler.on_incoming(head); + trace!("handler.on_incoming() -> {:?}", next); + + match next.interest { + Next_::Read => self.read(scope, State::Http1(Http1 { + handler: handler, + reading: Reading::Body(decoder), + writing: Writing::Init, + keep_alive: keep_alive, + timeout: next.timeout, + _marker: PhantomData, + })), + Next_::Write => State::Http1(Http1 { + handler: handler, + reading: if decoder.is_eof() { + if keep_alive { + Reading::KeepAlive + } else { + Reading::Closed + } + } else { + Reading::Wait(decoder) + }, + writing: Writing::Head, + keep_alive: keep_alive, + timeout: next.timeout, + _marker: PhantomData, + }), + Next_::ReadWrite => self.read(scope, State::Http1(Http1 { + handler: handler, + reading: Reading::Body(decoder), + writing: Writing::Head, + keep_alive: keep_alive, + timeout: next.timeout, + _marker: PhantomData, + })), + Next_::Wait => State::Http1(Http1 { + handler: handler, + reading: Reading::Wait(decoder), + writing: Writing::Init, + keep_alive: keep_alive, + timeout: next.timeout, + _marker: PhantomData, + }), + Next_::End | + Next_::Remove => State::Closed, + } + }, + Err(e) => { + debug!("error creating decoder: {:?}", e); + //TODO: respond with 400 + State::Closed + } + } + }, + State::Http1(mut http1) => { + let next = match http1.reading { + Reading::Init => None, + Reading::Parse => match self.parse() { + Ok(head) => match <>::Message as Http1Message>::decoder(&head) { + Ok(decoder) => { + trace!("decoder = {:?}", decoder); + // if client request asked for keep alive, + // then it depends entirely on if the server agreed + if http1.keep_alive { + http1.keep_alive = head.should_keep_alive(); + } + let next = http1.handler.on_incoming(head); + http1.reading = Reading::Wait(decoder); + trace!("handler.on_incoming() -> {:?}", next); + Some(next) + }, + Err(e) => { + debug!("error creating decoder: {:?}", e); + //TODO: respond with 400 + return State::Closed; + } + }, + Err(::Error::Io(e)) => match e.kind() { + io::ErrorKind::WouldBlock | + io::ErrorKind::Interrupted => None, + _ => { + debug!("io error trying to parse {:?}", e); + return State::Closed; + } + }, + Err(e) => { + //TODO: send proper error codes depending on error + trace!("parse eror: {:?}", e); + return State::Closed; + } + }, + Reading::Body(ref mut decoder) => { + let wrapped = if !self.buf.is_empty() { + super::Trans::Buf(self.buf.wrap(&mut self.transport)) + } else { + super::Trans::Port(&mut self.transport) + }; + + Some(http1.handler.on_decode(&mut Decoder::h1(decoder, wrapped))) + }, + _ => { + trace!("Conn.on_readable State::Http1(reading = {:?})", http1.reading); + None + } + }; + let mut s = State::Http1(http1); + trace!("h1 read completed, next = {:?}", next); + if let Some(next) = next { + s.update(next); + } + trace!("h1 read completed, state = {:?}", s); + + let again = match s { + State::Http1(Http1 { reading: Reading::Body(ref encoder), .. }) => encoder.is_eof(), + _ => false + }; + + if again { + self.read(scope, s) + } else { + s + } + }, + State::Closed => { + error!("on_readable State::Closed"); + State::Closed + } + + } + } + + fn write>(&mut self, scope: &mut Scope, mut state: State) -> State { + let next = match state { + State::Init => { + // this could be a Client request, which writes first, so pay + // attention to the version written here, which will adjust + // our internal state to Http1 or Http2 + let mut handler = scope.create(Seed(&self.key, &self.ctrl.0)); + let mut head = http::MessageHead::default(); + let interest = handler.on_outgoing(&mut head); + if head.version == HttpVersion::Http11 { + let mut buf = Vec::new(); + let keep_alive = self.keep_alive_enabled && head.should_keep_alive(); + let mut encoder = H::Message::encode(head, &mut buf); + let writing = match interest.interest { + // user wants to write some data right away + // try to write the headers and the first chunk + // together, so they are in the same packet + Next_::Write | + Next_::ReadWrite => { + encoder.prefix(WriteBuf { + bytes: buf, + pos: 0 + }); + Writing::Ready(encoder) + }, + _ => Writing::Chunk(Chunk { + buf: Cow::Owned(buf), + pos: 0, + next: (encoder, interest.clone()) + }) + }; + state = State::Http1(Http1 { + reading: Reading::Init, + writing: writing, + handler: handler, + keep_alive: keep_alive, + timeout: interest.timeout, + _marker: PhantomData, + }) + } + Some(interest) + } + State::Http1(Http1 { ref mut handler, ref mut writing, ref mut keep_alive, .. }) => { + match *writing { + Writing::Init => { + trace!("Conn.on_writable Http1::Writing::Init"); + None + } + Writing::Head => { + let mut head = http::MessageHead::default(); + let interest = handler.on_outgoing(&mut head); + // if the request wants to close, server cannot stop it + if *keep_alive { + // if the request wants to stay alive, then it depends + // on the server to agree + *keep_alive = head.should_keep_alive(); + } + let mut buf = Vec::new(); + let mut encoder = <>::Message as Http1Message>::encode(head, &mut buf); + *writing = match interest.interest { + // user wants to write some data right away + // try to write the headers and the first chunk + // together, so they are in the same packet + Next_::Write | + Next_::ReadWrite => { + encoder.prefix(WriteBuf { + bytes: buf, + pos: 0 + }); + Writing::Ready(encoder) + }, + _ => Writing::Chunk(Chunk { + buf: Cow::Owned(buf), + pos: 0, + next: (encoder, interest.clone()) + }) + }; + Some(interest) + }, + Writing::Chunk(ref mut chunk) => { + trace!("Http1.Chunk on_writable"); + match self.transport.write(&chunk.buf.as_ref()[chunk.pos..]) { + Ok(n) => { + chunk.pos += n; + trace!("Http1.Chunk wrote={}, done={}", n, chunk.is_written()); + if chunk.is_written() { + Some(chunk.next.1.clone()) + } else { + None + } + }, + Err(e) => match e.kind() { + io::ErrorKind::WouldBlock | + io::ErrorKind::Interrupted => None, + _ => { + Some(handler.on_error(e.into())) + } + } + } + }, + Writing::Ready(ref mut encoder) => { + trace!("Http1.Ready on_writable"); + Some(handler.on_encode(&mut Encoder::h1(encoder, &mut self.transport))) + }, + Writing::Wait(..) => { + trace!("Conn.on_writable Http1::Writing::Wait"); + None + } + Writing::KeepAlive => { + trace!("Conn.on_writable Http1::Writing::KeepAlive"); + None + } + Writing::Closed => { + trace!("on_writable Http1::Writing::Closed"); + None + } + } + }, + State::Closed => { + trace!("on_writable State::Closed"); + None + } + }; + + if let Some(next) = next { + state.update(next); + } + state + } + + fn can_read_more(&self) -> bool { + match self.state { + State::Init => false, + _ => !self.buf.is_empty() + } + } + + pub fn ready(mut self, events: EventSet, scope: &mut Scope) -> Option<(Self, Option)> + where F: MessageHandlerFactory { + trace!("Conn::ready events='{:?}', blocked={:?}", events, self.transport.blocked()); + + if events.is_error() { + match self.transport.take_socket_error() { + Ok(_) => { + trace!("is_error, but not socket error"); + // spurious? + }, + Err(e) => self.on_error(e.into()) + } + } + + // if the user had an io interest, but the transport was blocked differently, + // the event needs to be translated to what the user was actually expecting. + // + // Example: + // - User asks for `Next::write(). + // - But transport is in the middle of renegotiating TLS, and is blocked on reading. + // - hyper should not wait on the `write` event, since epoll already + // knows it is writable. We would just loop a whole bunch, and slow down. + // - So instead, hyper waits on the event needed to unblock the transport, `read`. + // - Once epoll detects the transport is readable, it will alert hyper + // with a `readable` event. + // - hyper needs to translate that `readable` event back into a `write`, + // since that is actually what the Handler wants. + + let events = if let Some(blocked) = self.transport.blocked() { + let interest = self.interest(); + trace!("translating blocked={:?}, interest={:?}", blocked, interest); + match (blocked, interest) { + (Blocked::Read, Reg::Write) => EventSet::writable(), + (Blocked::Write, Reg::Read) => EventSet::readable(), + // otherwise, the transport was blocked on the same thing the user wanted + _ => events + } + } else { + events + }; + + if events.is_readable() { + self.on_readable(scope); + } + + if events.is_writable() { + self.on_writable(scope); + } + + let events = match self.register() { + Reg::Read => EventSet::readable(), + Reg::Write => EventSet::writable(), + Reg::ReadWrite => EventSet::readable() | EventSet::writable(), + Reg::Wait => EventSet::none(), + Reg::Remove => { + trace!("removing transport"); + let _ = scope.deregister(&self.transport); + self.on_remove(); + return None; + }, + }; + + if events.is_readable() && self.can_read_more() { + return self.ready(events, scope); + } + + trace!("scope.reregister({:?})", events); + match scope.reregister(&self.transport, events, PollOpt::level()) { + Ok(..) => { + let timeout = self.state.timeout(); + Some((self, timeout)) + }, + Err(e) => { + error!("error reregistering: {:?}", e); + None + } + } + } + + pub fn wakeup(mut self, scope: &mut Scope) -> Option<(Self, Option)> + where F: MessageHandlerFactory { + loop { + match self.ctrl.1.try_recv() { + Ok(next) => { + trace!("woke up with {:?}", next); + self.state.update(next); + }, + Err(_) => break + } + } + self.ready(EventSet::readable() | EventSet::writable(), scope) + } + + pub fn timeout(mut self, scope: &mut Scope) -> Option<(Self, Option)> + where F: MessageHandlerFactory { + //TODO: check if this was a spurious timeout? + self.on_error(::Error::Timeout); + self.ready(EventSet::none(), scope) + } + + fn on_error(&mut self, err: ::Error) { + debug!("on_error err = {:?}", err); + trace!("on_error state = {:?}", self.state); + let next = match self.state { + State::Init => Next::remove(), + State::Http1(ref mut http1) => http1.handler.on_error(err), + State::Closed => Next::remove(), + }; + self.state.update(next); + } + + fn on_remove(self) { + debug!("on_remove"); + match self.state { + State::Init | State::Closed => (), + State::Http1(http1) => http1.handler.on_remove(self.transport), + } + } + + fn on_readable(&mut self, scope: &mut Scope) + where F: MessageHandlerFactory { + trace!("on_readable -> {:?}", self.state); + let state = mem::replace(&mut self.state, State::Closed); + self.state = self.read(scope, state); + trace!("on_readable <- {:?}", self.state); + } + + fn on_writable(&mut self, scope: &mut Scope) + where F: MessageHandlerFactory { + trace!("on_writable -> {:?}", self.state); + let state = mem::replace(&mut self.state, State::Closed); + self.state = self.write(scope, state); + trace!("on_writable <- {:?}", self.state); + } +} + +enum State, T: Transport> { + Init, + /// Http1 will only ever use a connection to send and receive a single + /// message at a time. Once a H1 status has been determined, we will either + /// be reading or writing an H1 message, and optionally multiple if + /// keep-alive is true. + Http1(Http1), + /// Http2 allows multiplexing streams over a single connection. So even + /// when we've identified a certain message, we must always parse frame + /// head to determine if the incoming frame is part of a current message, + /// or a new one. This also means we could have multiple messages at once. + //Http2 {}, + Closed, +} + + +impl, T: Transport> State { + fn timeout(&self) -> Option { + match *self { + State::Init => None, + State::Http1(ref http1) => http1.timeout, + State::Closed => None, + } + } +} + +impl, T: Transport> fmt::Debug for State { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + State::Init => f.write_str("Init"), + State::Http1(ref h1) => f.debug_tuple("Http1") + .field(h1) + .finish(), + State::Closed => f.write_str("Closed") + } + } +} + +impl, T: Transport> State { + fn update(&mut self, next: Next) { + let timeout = next.timeout; + let state = mem::replace(self, State::Closed); + let new_state = match (state, next.interest) { + (_, Next_::Remove) => State::Closed, + (State::Closed, _) => State::Closed, + (State::Init, _) => State::Init, + (State::Http1(http1), Next_::End) => { + let reading = match http1.reading { + Reading::Body(ref decoder) if decoder.is_eof() => { + if http1.keep_alive { + Reading::KeepAlive + } else { + Reading::Closed + } + }, + Reading::KeepAlive => http1.reading, + _ => Reading::Closed, + }; + let writing = match http1.writing { + Writing::Ready(ref encoder) if encoder.is_eof() => { + if http1.keep_alive { + Writing::KeepAlive + } else { + Writing::Closed + } + }, + Writing::Ready(encoder) => { + if encoder.is_eof() { + if http1.keep_alive { + Writing::KeepAlive + } else { + Writing::Closed + } + } else if let Some(buf) = encoder.end() { + Writing::Chunk(Chunk { + buf: buf.bytes, + pos: buf.pos, + next: (h1::Encoder::length(0), Next::end()) + }) + } else { + Writing::Closed + } + } + Writing::Chunk(mut chunk) => { + if chunk.is_written() { + let encoder = chunk.next.0; + //TODO: de-dupe this code and from Writing::Ready + if encoder.is_eof() { + if http1.keep_alive { + Writing::KeepAlive + } else { + Writing::Closed + } + } else if let Some(buf) = encoder.end() { + Writing::Chunk(Chunk { + buf: buf.bytes, + pos: buf.pos, + next: (h1::Encoder::length(0), Next::end()) + }) + } else { + Writing::Closed + } + } else { + chunk.next.1 = next; + Writing::Chunk(chunk) + } + }, + _ => Writing::Closed, + }; + match (reading, writing) { + (Reading::KeepAlive, Writing::KeepAlive) => State::Init, + (reading, Writing::Chunk(chunk)) => { + State::Http1(Http1 { + reading: reading, + writing: Writing::Chunk(chunk), + .. http1 + }) + } + _ => State::Closed + } + }, + (State::Http1(mut http1), Next_::Read) => { + http1.reading = match http1.reading { + Reading::Init => Reading::Parse, + Reading::Wait(decoder) => Reading::Body(decoder), + same => same + }; + + http1.writing = match http1.writing { + Writing::Ready(encoder) => if encoder.is_eof() { + if http1.keep_alive { + Writing::KeepAlive + } else { + Writing::Closed + } + } else { + Writing::Wait(encoder) + }, + Writing::Chunk(chunk) => if chunk.is_written() { + Writing::Wait(chunk.next.0) + } else { + Writing::Chunk(chunk) + }, + same => same + }; + + State::Http1(http1) + }, + (State::Http1(mut http1), Next_::Write) => { + http1.writing = match http1.writing { + Writing::Wait(encoder) => Writing::Ready(encoder), + Writing::Init => Writing::Head, + Writing::Chunk(chunk) => if chunk.is_written() { + Writing::Ready(chunk.next.0) + } else { + Writing::Chunk(chunk) + }, + same => same + }; + + http1.reading = match http1.reading { + Reading::Body(decoder) => if decoder.is_eof() { + if http1.keep_alive { + Reading::KeepAlive + } else { + Reading::Closed + } + } else { + Reading::Wait(decoder) + }, + same => same + }; + State::Http1(http1) + }, + (State::Http1(mut http1), Next_::ReadWrite) => { + http1.reading = match http1.reading { + Reading::Init => Reading::Parse, + Reading::Wait(decoder) => Reading::Body(decoder), + same => same + }; + http1.writing = match http1.writing { + Writing::Wait(encoder) => Writing::Ready(encoder), + Writing::Init => Writing::Head, + Writing::Chunk(chunk) => if chunk.is_written() { + Writing::Ready(chunk.next.0) + } else { + Writing::Chunk(chunk) + }, + same => same + }; + State::Http1(http1) + }, + (State::Http1(mut http1), Next_::Wait) => { + http1.reading = match http1.reading { + Reading::Body(decoder) => Reading::Wait(decoder), + same => same + }; + + http1.writing = match http1.writing { + Writing::Ready(encoder) => Writing::Wait(encoder), + Writing::Chunk(chunk) => if chunk.is_written() { + Writing::Wait(chunk.next.0) + } else { + Writing::Chunk(chunk) + }, + same => same + }; + State::Http1(http1) + } + }; + let new_state = match new_state { + State::Http1(mut http1) => { + http1.timeout = timeout; + State::Http1(http1) + } + other => other + }; + mem::replace(self, new_state); + } +} + +// These Reading and Writing stuff should probably get moved into h1/message.rs + +struct Http1 { + handler: H, + reading: Reading, + writing: Writing, + keep_alive: bool, + timeout: Option, + _marker: PhantomData, +} + +impl fmt::Debug for Http1 { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("Http1") + .field("reading", &self.reading) + .field("writing", &self.writing) + .field("keep_alive", &self.keep_alive) + .field("timeout", &self.timeout) + .finish() + } +} + +#[derive(Debug)] +enum Reading { + Init, + Parse, + Body(h1::Decoder), + Wait(h1::Decoder), + KeepAlive, + Closed +} + +#[derive(Debug)] +enum Writing { + Init, + Head, + Chunk(Chunk) , + Ready(h1::Encoder), + Wait(h1::Encoder), + KeepAlive, + Closed +} + +#[derive(Debug)] +struct Chunk { + buf: Cow<'static, [u8]>, + pos: usize, + next: (h1::Encoder, Next), +} + +impl Chunk { + fn is_written(&self) -> bool { + self.pos >= self.buf.len() + } +} + +pub trait MessageHandler { + type Message: Http1Message; + fn on_incoming(&mut self, head: http::MessageHead<::Incoming>) -> Next; + fn on_outgoing(&mut self, head: &mut http::MessageHead<::Outgoing>) -> Next; + fn on_decode(&mut self, &mut http::Decoder) -> Next; + fn on_encode(&mut self, &mut http::Encoder) -> Next; + fn on_error(&mut self, err: ::Error) -> Next; + + fn on_remove(self, T) where Self: Sized; +} + +pub struct Seed<'a, K: Key + 'a>(&'a K, &'a channel::Sender); + +impl<'a, K: Key + 'a> Seed<'a, K> { + pub fn control(&self) -> Control { + Control { + tx: self.1.clone(), + } + } + + pub fn key(&self) -> &K { + &self.0 + } +} + + +pub trait MessageHandlerFactory { + type Output: MessageHandler; + + fn create(&mut self, seed: Seed) -> Self::Output; +} + +impl MessageHandlerFactory for F +where F: FnMut(Seed) -> H, + K: Key, + H: MessageHandler, + T: Transport { + type Output = H; + fn create(&mut self, seed: Seed) -> H { + self(seed) + } +} + +pub trait Key: Eq + Hash + Clone {} +impl Key for T {} + +#[cfg(test)] +mod tests { + /* TODO: + test when the underlying Transport of a Conn is blocked on an action that + differs from the desired interest(). + + Ex: + transport.blocked() == Some(Blocked::Read) + self.interest() == Reg::Write + + Should call `scope.register(EventSet::read())`, not with write + + #[test] + fn test_conn_register_when_transport_blocked() { + + } + */ +} diff --git a/src/http/h1.rs b/src/http/h1.rs deleted file mode 100644 index 77a3dda96d..0000000000 --- a/src/http/h1.rs +++ /dev/null @@ -1,1137 +0,0 @@ -//! Adapts the HTTP/1.1 implementation into the `HttpMessage` API. -use std::borrow::Cow; -use std::cmp::min; -use std::fmt; -use std::io::{self, Write, BufWriter, BufRead, Read}; -use std::net::Shutdown; -use std::time::Duration; - -use httparse; -use url::Position as UrlPosition; - -use buffer::BufReader; -use Error; -use header::{Headers, ContentLength, TransferEncoding}; -use header::Encoding::Chunked; -use method::{Method}; -use net::{NetworkConnector, NetworkStream}; -use status::StatusCode; -use version::HttpVersion; -use version::HttpVersion::{Http10, Http11}; -use uri::RequestUri; - -use self::HttpReader::{SizedReader, ChunkedReader, EofReader, EmptyReader}; -use self::HttpWriter::{ChunkedWriter, SizedWriter, EmptyWriter, ThroughWriter}; - -use http::{ - RawStatus, - Protocol, - HttpMessage, - RequestHead, - ResponseHead, -}; -use header; -use version; - -const MAX_INVALID_RESPONSE_BYTES: usize = 1024 * 128; - -#[derive(Debug)] -struct Wrapper { - obj: Option, -} - -impl Wrapper { - pub fn new(obj: T) -> Wrapper { - Wrapper { obj: Some(obj) } - } - - pub fn map_in_place(&mut self, f: F) where F: FnOnce(T) -> T { - let obj = self.obj.take().unwrap(); - let res = f(obj); - self.obj = Some(res); - } - - pub fn into_inner(self) -> T { self.obj.unwrap() } - pub fn as_mut(&mut self) -> &mut T { self.obj.as_mut().unwrap() } - pub fn as_ref(&self) -> &T { self.obj.as_ref().unwrap() } -} - -#[derive(Debug)] -enum Stream { - Idle(Box), - Writing(HttpWriter>>), - Reading(HttpReader>>), -} - -impl Stream { - fn writer_mut(&mut self) -> Option<&mut HttpWriter>>> { - match *self { - Stream::Writing(ref mut writer) => Some(writer), - _ => None, - } - } - fn reader_mut(&mut self) -> Option<&mut HttpReader>>> { - match *self { - Stream::Reading(ref mut reader) => Some(reader), - _ => None, - } - } - fn reader_ref(&self) -> Option<&HttpReader>>> { - match *self { - Stream::Reading(ref reader) => Some(reader), - _ => None, - } - } - - fn new(stream: Box) -> Stream { - Stream::Idle(stream) - } -} - -/// An implementation of the `HttpMessage` trait for HTTP/1.1. -#[derive(Debug)] -pub struct Http11Message { - is_proxied: bool, - method: Option, - stream: Wrapper, -} - -impl Write for Http11Message { - #[inline] - fn write(&mut self, buf: &[u8]) -> io::Result { - match self.stream.as_mut().writer_mut() { - None => Err(io::Error::new(io::ErrorKind::Other, - "Not in a writable state")), - Some(ref mut writer) => writer.write(buf), - } - } - #[inline] - fn flush(&mut self) -> io::Result<()> { - match self.stream.as_mut().writer_mut() { - None => Err(io::Error::new(io::ErrorKind::Other, - "Not in a writable state")), - Some(ref mut writer) => writer.flush(), - } - } -} - -impl Read for Http11Message { - #[inline] - fn read(&mut self, buf: &mut [u8]) -> io::Result { - match self.stream.as_mut().reader_mut() { - None => Err(io::Error::new(io::ErrorKind::Other, - "Not in a readable state")), - Some(ref mut reader) => reader.read(buf), - } - } -} - -impl HttpMessage for Http11Message { - fn set_outgoing(&mut self, mut head: RequestHead) -> ::Result { - let mut res = Err(Error::from(io::Error::new( - io::ErrorKind::Other, - ""))); - let mut method = None; - let is_proxied = self.is_proxied; - self.stream.map_in_place(|stream: Stream| -> Stream { - let stream = match stream { - Stream::Idle(stream) => stream, - _ => { - res = Err(Error::from(io::Error::new( - io::ErrorKind::Other, - "Message not idle, cannot start new outgoing"))); - return stream; - }, - }; - let mut stream = BufWriter::new(stream); - - { - let uri = if is_proxied { - head.url.as_ref() - } else { - &head.url[UrlPosition::BeforePath..UrlPosition::AfterQuery] - }; - - let version = version::HttpVersion::Http11; - debug!("request line: {:?} {:?} {:?}", head.method, uri, version); - match write!(&mut stream, "{} {} {}{}", - head.method, uri, version, LINE_ENDING) { - Err(e) => { - res = Err(From::from(e)); - // TODO What should we do if the BufWriter doesn't wanna - // relinquish the stream? - return Stream::Idle(stream.into_inner().ok().unwrap()); - }, - Ok(_) => {}, - }; - } - - let stream = { - let write_headers = |mut stream: BufWriter>, head: &RequestHead| { - debug!("headers={:?}", head.headers); - match write!(&mut stream, "{}{}", head.headers, LINE_ENDING) { - Ok(_) => Ok(stream), - Err(e) => { - Err((e, stream.into_inner().unwrap())) - } - } - }; - match head.method { - Method::Get | Method::Head => { - let writer = match write_headers(stream, &head) { - Ok(w) => w, - Err(e) => { - res = Err(From::from(e.0)); - return Stream::Idle(e.1); - } - }; - EmptyWriter(writer) - }, - _ => { - let mut chunked = true; - let mut len = 0; - - match head.headers.get::() { - Some(cl) => { - chunked = false; - len = **cl; - }, - None => () - }; - - // can't do in match above, thanks borrowck - if chunked { - let encodings = match head.headers.get_mut::() { - Some(encodings) => { - //TODO: check if chunked is already in encodings. use HashSet? - encodings.push(header::Encoding::Chunked); - false - }, - None => true - }; - - if encodings { - head.headers.set( - header::TransferEncoding(vec![header::Encoding::Chunked])) - } - } - - let stream = match write_headers(stream, &head) { - Ok(s) => s, - Err(e) => { - res = Err(From::from(e.0)); - return Stream::Idle(e.1); - }, - }; - - if chunked { - ChunkedWriter(stream) - } else { - SizedWriter(stream, len) - } - } - } - }; - - method = Some(head.method.clone()); - res = Ok(head); - Stream::Writing(stream) - }); - - self.method = method; - res - } - - fn get_incoming(&mut self) -> ::Result { - try!(self.flush_outgoing()); - let method = self.method.take().unwrap_or(Method::Get); - let mut res = Err(From::from( - io::Error::new(io::ErrorKind::Other, - "Read already in progress"))); - self.stream.map_in_place(|stream| { - let stream = match stream { - Stream::Idle(stream) => stream, - _ => { - // The message was already in the reading state... - // TODO Decide what happens in case we try to get a new incoming at that point - res = Err(From::from( - io::Error::new(io::ErrorKind::Other, - "Read already in progress"))); - return stream; - } - }; - - let expected_no_content = stream.previous_response_expected_no_content(); - trace!("previous_response_expected_no_content = {}", expected_no_content); - - let mut stream = BufReader::new(stream); - - let mut invalid_bytes_read = 0; - let head; - loop { - head = match parse_response(&mut stream) { - Ok(head) => head, - Err(::Error::Version) - if expected_no_content && invalid_bytes_read < MAX_INVALID_RESPONSE_BYTES => { - trace!("expected_no_content, found content"); - invalid_bytes_read += 1; - stream.consume(1); - continue; - } - Err(e) => { - res = Err(e); - return Stream::Idle(stream.into_inner()); - } - }; - break; - } - - let raw_status = head.subject; - let headers = head.headers; - - let is_empty = !should_have_response_body(&method, raw_status.0); - stream.get_mut().set_previous_response_expected_no_content(is_empty); - // According to https://tools.ietf.org/html/rfc7230#section-3.3.3 - // 1. HEAD reponses, and Status 1xx, 204, and 304 cannot have a body. - // 2. Status 2xx to a CONNECT cannot have a body. - // 3. Transfer-Encoding: chunked has a chunked body. - // 4. If multiple differing Content-Length headers or invalid, close connection. - // 5. Content-Length header has a sized body. - // 6. Not Client. - // 7. Read till EOF. - let reader = if is_empty { - EmptyReader(stream) - } else { - if let Some(&TransferEncoding(ref codings)) = headers.get() { - if codings.last() == Some(&Chunked) { - ChunkedReader(stream, None) - } else { - trace!("not chuncked. read till eof"); - EofReader(stream) - } - } else if let Some(&ContentLength(len)) = headers.get() { - SizedReader(stream, len) - } else if headers.has::() { - trace!("illegal Content-Length: {:?}", headers.get_raw("Content-Length")); - res = Err(Error::Header); - return Stream::Idle(stream.into_inner()); - } else { - trace!("neither Transfer-Encoding nor Content-Length"); - EofReader(stream) - } - }; - - trace!("Http11Message.reader = {:?}", reader); - - - res = Ok(ResponseHead { - headers: headers, - raw_status: raw_status, - version: head.version, - }); - - Stream::Reading(reader) - }); - res - } - - fn has_body(&self) -> bool { - match self.stream.as_ref().reader_ref() { - Some(&EmptyReader(..)) | - Some(&SizedReader(_, 0)) | - Some(&ChunkedReader(_, Some(0))) => false, - // specifically EofReader is always true - _ => true - } - } - - #[inline] - fn set_read_timeout(&self, dur: Option) -> io::Result<()> { - self.get_ref().set_read_timeout(dur) - } - - #[inline] - fn set_write_timeout(&self, dur: Option) -> io::Result<()> { - self.get_ref().set_write_timeout(dur) - } - - #[inline] - fn close_connection(&mut self) -> ::Result<()> { - try!(self.get_mut().close(Shutdown::Both)); - Ok(()) - } - - #[inline] - fn set_proxied(&mut self, val: bool) { - self.is_proxied = val; - } -} - -impl Http11Message { - /// Consumes the `Http11Message` and returns the underlying `NetworkStream`. - pub fn into_inner(self) -> Box { - match self.stream.into_inner() { - Stream::Idle(stream) => stream, - Stream::Writing(stream) => stream.into_inner().into_inner().unwrap(), - Stream::Reading(stream) => stream.into_inner().into_inner(), - } - } - - /// Gets a mutable reference to the underlying `NetworkStream`, regardless of the state of the - /// `Http11Message`. - pub fn get_ref(&self) -> &(NetworkStream + Send) { - match *self.stream.as_ref() { - Stream::Idle(ref stream) => &**stream, - Stream::Writing(ref stream) => &**stream.get_ref().get_ref(), - Stream::Reading(ref stream) => &**stream.get_ref().get_ref() - } - } - - /// Gets a mutable reference to the underlying `NetworkStream`, regardless of the state of the - /// `Http11Message`. - pub fn get_mut(&mut self) -> &mut (NetworkStream + Send) { - match *self.stream.as_mut() { - Stream::Idle(ref mut stream) => &mut **stream, - Stream::Writing(ref mut stream) => &mut **stream.get_mut().get_mut(), - Stream::Reading(ref mut stream) => &mut **stream.get_mut().get_mut() - } - } - - /// Creates a new `Http11Message` that will use the given `NetworkStream` for communicating to - /// the peer. - pub fn with_stream(stream: Box) -> Http11Message { - Http11Message { - is_proxied: false, - method: None, - stream: Wrapper::new(Stream::new(stream)), - } - } - - /// Flushes the current outgoing content and moves the stream into the `stream` property. - /// - /// TODO It might be sensible to lift this up to the `HttpMessage` trait itself... - pub fn flush_outgoing(&mut self) -> ::Result<()> { - let mut res = Ok(()); - self.stream.map_in_place(|stream| { - let writer = match stream { - Stream::Writing(writer) => writer, - _ => { - res = Ok(()); - return stream; - }, - }; - // end() already flushes - let raw = match writer.end() { - Ok(buf) => buf.into_inner().unwrap(), - Err(e) => { - res = Err(From::from(e.0)); - return Stream::Writing(e.1); - } - }; - Stream::Idle(raw) - }); - res - } -} - -/// The `Protocol` implementation provides HTTP/1.1 messages. -pub struct Http11Protocol { - connector: Connector, -} - -impl Protocol for Http11Protocol { - fn new_message(&self, host: &str, port: u16, scheme: &str) -> ::Result> { - let stream = try!(self.connector.connect(host, port, scheme)).into(); - - Ok(Box::new(Http11Message::with_stream(stream))) - } -} - -impl Http11Protocol { - /// Creates a new `Http11Protocol` instance that will use the given `NetworkConnector` for - /// establishing HTTP connections. - pub fn with_connector(c: C) -> Http11Protocol - where C: NetworkConnector + Send + Sync + 'static, - S: NetworkStream + Send { - Http11Protocol { - connector: Connector(Box::new(ConnAdapter(c))), - } - } -} - -struct ConnAdapter(C); - -impl + Send + Sync, S: NetworkStream + Send> - NetworkConnector for ConnAdapter { - type Stream = Box; - #[inline] - fn connect(&self, host: &str, port: u16, scheme: &str) - -> ::Result> { - Ok(try!(self.0.connect(host, port, scheme)).into()) - } -} - -struct Connector(Box> + Send + Sync>); - -impl NetworkConnector for Connector { - type Stream = Box; - #[inline] - fn connect(&self, host: &str, port: u16, scheme: &str) - -> ::Result> { - Ok(try!(self.0.connect(host, port, scheme)).into()) - } -} - - -/// Readers to handle different Transfer-Encodings. -/// -/// If a message body does not include a Transfer-Encoding, it *should* -/// include a Content-Length header. -pub enum HttpReader { - /// A Reader used when a Content-Length header is passed with a positive integer. - SizedReader(R, u64), - /// A Reader used when Transfer-Encoding is `chunked`. - ChunkedReader(R, Option), - /// A Reader used for responses that don't indicate a length or chunked. - /// - /// Note: This should only used for `Response`s. It is illegal for a - /// `Request` to be made with both `Content-Length` and - /// `Transfer-Encoding: chunked` missing, as explained from the spec: - /// - /// > If a Transfer-Encoding header field is present in a response and - /// > the chunked transfer coding is not the final encoding, the - /// > message body length is determined by reading the connection until - /// > it is closed by the server. If a Transfer-Encoding header field - /// > is present in a request and the chunked transfer coding is not - /// > the final encoding, the message body length cannot be determined - /// > reliably; the server MUST respond with the 400 (Bad Request) - /// > status code and then close the connection. - EofReader(R), - /// A Reader used for messages that should never have a body. - /// - /// See https://tools.ietf.org/html/rfc7230#section-3.3.3 - EmptyReader(R), -} - -impl HttpReader { - - /// Unwraps this HttpReader and returns the underlying Reader. - pub fn into_inner(self) -> R { - match self { - SizedReader(r, _) => r, - ChunkedReader(r, _) => r, - EofReader(r) => r, - EmptyReader(r) => r, - } - } - - /// Gets a borrowed reference to the underlying Reader. - pub fn get_ref(&self) -> &R { - match *self { - SizedReader(ref r, _) => r, - ChunkedReader(ref r, _) => r, - EofReader(ref r) => r, - EmptyReader(ref r) => r, - } - } - - /// Gets a mutable reference to the underlying Reader. - pub fn get_mut(&mut self) -> &mut R { - match *self { - SizedReader(ref mut r, _) => r, - ChunkedReader(ref mut r, _) => r, - EofReader(ref mut r) => r, - EmptyReader(ref mut r) => r, - } - } -} - -impl fmt::Debug for HttpReader { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - match *self { - SizedReader(_,rem) => write!(fmt, "SizedReader(remaining={:?})", rem), - ChunkedReader(_, None) => write!(fmt, "ChunkedReader(chunk_remaining=unknown)"), - ChunkedReader(_, Some(rem)) => write!(fmt, "ChunkedReader(chunk_remaining={:?})", rem), - EofReader(_) => write!(fmt, "EofReader"), - EmptyReader(_) => write!(fmt, "EmptyReader"), - } - } -} - -impl Read for HttpReader { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - match *self { - SizedReader(ref mut body, ref mut remaining) => { - trace!("Sized read, remaining={:?}", remaining); - if *remaining == 0 { - Ok(0) - } else { - let to_read = min(*remaining as usize, buf.len()); - let num = try!(body.read(&mut buf[..to_read])) as u64; - trace!("Sized read: {}", num); - if num > *remaining { - *remaining = 0; - } else if num == 0 { - return Err(io::Error::new(io::ErrorKind::Other, "early eof")); - } else { - *remaining -= num; - } - Ok(num as usize) - } - }, - ChunkedReader(ref mut body, ref mut opt_remaining) => { - let mut rem = match *opt_remaining { - Some(ref rem) => *rem, - // None means we don't know the size of the next chunk - None => try!(read_chunk_size(body)) - }; - trace!("Chunked read, remaining={:?}", rem); - - if rem == 0 { - *opt_remaining = Some(0); - - // chunk of size 0 signals the end of the chunked stream - // if the 0 digit was missing from the stream, it would - // be an InvalidInput error instead. - trace!("end of chunked"); - return Ok(0) - } - - let to_read = min(rem as usize, buf.len()); - let count = try!(body.read(&mut buf[..to_read])) as u64; - - if count == 0 { - *opt_remaining = Some(0); - return Err(io::Error::new(io::ErrorKind::Other, "early eof")); - } - - rem -= count; - *opt_remaining = if rem > 0 { - Some(rem) - } else { - try!(eat(body, LINE_ENDING.as_bytes())); - None - }; - Ok(count as usize) - }, - EofReader(ref mut body) => { - let r = body.read(buf); - trace!("eofread: {:?}", r); - r - }, - EmptyReader(_) => Ok(0) - } - } -} - -fn eat(rdr: &mut R, bytes: &[u8]) -> io::Result<()> { - let mut buf = [0]; - for &b in bytes.iter() { - match try!(rdr.read(&mut buf)) { - 1 if buf[0] == b => (), - _ => return Err(io::Error::new(io::ErrorKind::InvalidInput, - "Invalid characters found")), - } - } - Ok(()) -} - -/// Chunked chunks start with 1*HEXDIGIT, indicating the size of the chunk. -fn read_chunk_size(rdr: &mut R) -> io::Result { - macro_rules! byte ( - ($rdr:ident) => ({ - let mut buf = [0]; - match try!($rdr.read(&mut buf)) { - 1 => buf[0], - _ => return Err(io::Error::new(io::ErrorKind::InvalidInput, - "Invalid chunk size line")), - - } - }) - ); - let mut size = 0u64; - let radix = 16; - let mut in_ext = false; - let mut in_chunk_size = true; - loop { - match byte!(rdr) { - b@b'0'...b'9' if in_chunk_size => { - size *= radix; - size += (b - b'0') as u64; - }, - b@b'a'...b'f' if in_chunk_size => { - size *= radix; - size += (b + 10 - b'a') as u64; - }, - b@b'A'...b'F' if in_chunk_size => { - size *= radix; - size += (b + 10 - b'A') as u64; - }, - CR => { - match byte!(rdr) { - LF => break, - _ => return Err(io::Error::new(io::ErrorKind::InvalidInput, - "Invalid chunk size line")) - - } - }, - // If we weren't in the extension yet, the ";" signals its start - b';' if !in_ext => { - in_ext = true; - in_chunk_size = false; - }, - // "Linear white space" is ignored between the chunk size and the - // extension separator token (";") due to the "implied *LWS rule". - b'\t' | b' ' if !in_ext & !in_chunk_size => {}, - // LWS can follow the chunk size, but no more digits can come - b'\t' | b' ' if in_chunk_size => in_chunk_size = false, - // We allow any arbitrary octet once we are in the extension, since - // they all get ignored anyway. According to the HTTP spec, valid - // extensions would have a more strict syntax: - // (token ["=" (token | quoted-string)]) - // but we gain nothing by rejecting an otherwise valid chunk size. - ext if in_ext => { - todo!("chunk extension byte={}", ext); - }, - // Finally, if we aren't in the extension and we're reading any - // other octet, the chunk size line is invalid! - _ => { - return Err(io::Error::new(io::ErrorKind::InvalidInput, - "Invalid chunk size line")); - } - } - } - trace!("chunk size={:?}", size); - Ok(size) -} - -fn should_have_response_body(method: &Method, status: u16) -> bool { - trace!("should_have_response_body({:?}, {})", method, status); - match (method, status) { - (&Method::Head, _) | - (_, 100...199) | - (_, 204) | - (_, 304) | - (&Method::Connect, 200...299) => false, - _ => true - } -} - -/// Writers to handle different Transfer-Encodings. -pub enum HttpWriter { - /// A no-op Writer, used initially before Transfer-Encoding is determined. - ThroughWriter(W), - /// A Writer for when Transfer-Encoding includes `chunked`. - ChunkedWriter(W), - /// A Writer for when Content-Length is set. - /// - /// Enforces that the body is not longer than the Content-Length header. - SizedWriter(W, u64), - /// A writer that should not write any body. - EmptyWriter(W), -} - -impl HttpWriter { - /// Unwraps the HttpWriter and returns the underlying Writer. - #[inline] - pub fn into_inner(self) -> W { - match self { - ThroughWriter(w) => w, - ChunkedWriter(w) => w, - SizedWriter(w, _) => w, - EmptyWriter(w) => w, - } - } - - /// Access the inner Writer. - #[inline] - pub fn get_ref(&self) -> &W { - match *self { - ThroughWriter(ref w) => w, - ChunkedWriter(ref w) => w, - SizedWriter(ref w, _) => w, - EmptyWriter(ref w) => w, - } - } - - /// Access the inner Writer mutably. - /// - /// Warning: You should not write to this directly, as you can corrupt - /// the state. - #[inline] - pub fn get_mut(&mut self) -> &mut W { - match *self { - ThroughWriter(ref mut w) => w, - ChunkedWriter(ref mut w) => w, - SizedWriter(ref mut w, _) => w, - EmptyWriter(ref mut w) => w, - } - } - - /// Ends the HttpWriter, and returns the underlying Writer. - /// - /// A final `write_all()` is called with an empty message, and then flushed. - /// The ChunkedWriter variant will use this to write the 0-sized last-chunk. - #[inline] - pub fn end(mut self) -> Result> { - fn inner(w: &mut W) -> io::Result<()> { - try!(w.write(&[])); - w.flush() - } - - match inner(&mut self) { - Ok(..) => Ok(self.into_inner()), - Err(e) => Err(EndError(e, self)) - } - } -} - -#[derive(Debug)] -pub struct EndError(io::Error, HttpWriter); - -impl From> for io::Error { - fn from(e: EndError) -> io::Error { - e.0 - } -} - -impl Write for HttpWriter { - #[inline] - fn write(&mut self, msg: &[u8]) -> io::Result { - match *self { - ThroughWriter(ref mut w) => w.write(msg), - ChunkedWriter(ref mut w) => { - let chunk_size = msg.len(); - trace!("chunked write, size = {:?}", chunk_size); - try!(write!(w, "{:X}{}", chunk_size, LINE_ENDING)); - try!(w.write_all(msg)); - try!(w.write_all(LINE_ENDING.as_bytes())); - Ok(msg.len()) - }, - SizedWriter(ref mut w, ref mut remaining) => { - let len = msg.len() as u64; - if len > *remaining { - let len = *remaining; - *remaining = 0; - try!(w.write_all(&msg[..len as usize])); - Ok(len as usize) - } else { - *remaining -= len; - try!(w.write_all(msg)); - Ok(len as usize) - } - }, - EmptyWriter(..) => { - if !msg.is_empty() { - error!("Cannot include a body with this kind of message"); - } - Ok(0) - } - } - } - - #[inline] - fn flush(&mut self) -> io::Result<()> { - match *self { - ThroughWriter(ref mut w) => w.flush(), - ChunkedWriter(ref mut w) => w.flush(), - SizedWriter(ref mut w, _) => w.flush(), - EmptyWriter(ref mut w) => w.flush(), - } - } -} - -impl fmt::Debug for HttpWriter { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - match *self { - ThroughWriter(_) => write!(fmt, "ThroughWriter"), - ChunkedWriter(_) => write!(fmt, "ChunkedWriter"), - SizedWriter(_, rem) => write!(fmt, "SizedWriter(remaining={:?})", rem), - EmptyWriter(_) => write!(fmt, "EmptyWriter"), - } - } -} - -const MAX_HEADERS: usize = 100; - -/// Parses a request into an Incoming message head. -#[inline] -pub fn parse_request(buf: &mut BufReader) -> ::Result> { - parse::(buf) -} - -/// Parses a response into an Incoming message head. -#[inline] -pub fn parse_response(buf: &mut BufReader) -> ::Result> { - parse::(buf) -} - -fn parse, I>(rdr: &mut BufReader) -> ::Result> { - loop { - match try!(try_parse::(rdr)) { - httparse::Status::Complete((inc, len)) => { - rdr.consume(len); - return Ok(inc); - }, - _partial => () - } - match try!(rdr.read_into_buf()) { - 0 if rdr.get_buf().is_empty() => { - return Err(Error::Io(io::Error::new( - io::ErrorKind::ConnectionAborted, - "Connection closed" - ))) - }, - 0 => return Err(Error::TooLarge), - _ => () - } - } -} - -fn try_parse, I>(rdr: &mut BufReader) -> TryParseResult { - let mut headers = [httparse::EMPTY_HEADER; MAX_HEADERS]; - let buf = rdr.get_buf(); - if buf.len() == 0 { - return Ok(httparse::Status::Partial); - } - trace!("try_parse({:?})", buf); - ::try_parse(&mut headers, buf) -} - -#[doc(hidden)] -trait TryParse { - type Subject; - fn try_parse<'a>(headers: &'a mut [httparse::Header<'a>], buf: &'a [u8]) -> - TryParseResult; -} - -type TryParseResult = Result, usize)>, Error>; - -impl<'a> TryParse for httparse::Request<'a, 'a> { - type Subject = (Method, RequestUri); - - fn try_parse<'b>(headers: &'b mut [httparse::Header<'b>], buf: &'b [u8]) -> - TryParseResult<(Method, RequestUri)> { - trace!("Request.try_parse([Header; {}], [u8; {}])", headers.len(), buf.len()); - let mut req = httparse::Request::new(headers); - Ok(match try!(req.parse(buf)) { - httparse::Status::Complete(len) => { - trace!("Request.try_parse Complete({})", len); - httparse::Status::Complete((Incoming { - version: if req.version.unwrap() == 1 { Http11 } else { Http10 }, - subject: ( - try!(req.method.unwrap().parse()), - try!(req.path.unwrap().parse()) - ), - headers: try!(Headers::from_raw(req.headers)) - }, len)) - }, - httparse::Status::Partial => httparse::Status::Partial - }) - } -} - -impl<'a> TryParse for httparse::Response<'a, 'a> { - type Subject = RawStatus; - - fn try_parse<'b>(headers: &'b mut [httparse::Header<'b>], buf: &'b [u8]) -> - TryParseResult { - trace!("Response.try_parse([Header; {}], [u8; {}])", headers.len(), buf.len()); - let mut res = httparse::Response::new(headers); - Ok(match try!(res.parse(buf)) { - httparse::Status::Complete(len) => { - trace!("Response.try_parse Complete({})", len); - let code = res.code.unwrap(); - let reason = match StatusCode::from_u16(code).canonical_reason() { - Some(reason) if reason == res.reason.unwrap() => Cow::Borrowed(reason), - _ => Cow::Owned(res.reason.unwrap().to_owned()) - }; - httparse::Status::Complete((Incoming { - version: if res.version.unwrap() == 1 { Http11 } else { Http10 }, - subject: RawStatus(code, reason), - headers: try!(Headers::from_raw(res.headers)) - }, len)) - }, - httparse::Status::Partial => httparse::Status::Partial - }) - } -} - -/// An Incoming Message head. Includes request/status line, and headers. -#[derive(Debug)] -pub struct Incoming { - /// HTTP version of the message. - pub version: HttpVersion, - /// Subject (request line or status line) of Incoming message. - pub subject: S, - /// Headers of the Incoming message. - pub headers: Headers -} - -/// The `\r` byte. -pub const CR: u8 = b'\r'; -/// The `\n` byte. -pub const LF: u8 = b'\n'; -/// The bytes `\r\n`. -pub const LINE_ENDING: &'static str = "\r\n"; - -#[cfg(test)] -mod tests { - use std::error::Error; - use std::io::{self, Read, Write}; - - - use buffer::BufReader; - use mock::MockStream; - use http::HttpMessage; - - use super::{read_chunk_size, parse_request, parse_response, Http11Message}; - - #[test] - fn test_write_chunked() { - use std::str::from_utf8; - let mut w = super::HttpWriter::ChunkedWriter(Vec::new()); - w.write_all(b"foo bar").unwrap(); - w.write_all(b"baz quux herp").unwrap(); - let buf = w.end().unwrap(); - let s = from_utf8(buf.as_ref()).unwrap(); - assert_eq!(s, "7\r\nfoo bar\r\nD\r\nbaz quux herp\r\n0\r\n\r\n"); - } - - #[test] - fn test_write_sized() { - use std::str::from_utf8; - let mut w = super::HttpWriter::SizedWriter(Vec::new(), 8); - w.write_all(b"foo bar").unwrap(); - assert_eq!(w.write(b"baz").unwrap(), 1); - - let buf = w.end().unwrap(); - let s = from_utf8(buf.as_ref()).unwrap(); - assert_eq!(s, "foo barb"); - } - - #[test] - fn test_read_chunk_size() { - fn read(s: &str, result: u64) { - assert_eq!(read_chunk_size(&mut s.as_bytes()).unwrap(), result); - } - - fn read_err(s: &str) { - assert_eq!(read_chunk_size(&mut s.as_bytes()).unwrap_err().kind(), - io::ErrorKind::InvalidInput); - } - - read("1\r\n", 1); - read("01\r\n", 1); - read("0\r\n", 0); - read("00\r\n", 0); - read("A\r\n", 10); - read("a\r\n", 10); - read("Ff\r\n", 255); - read("Ff \r\n", 255); - // Missing LF or CRLF - read_err("F\rF"); - read_err("F"); - // Invalid hex digit - read_err("X\r\n"); - read_err("1X\r\n"); - read_err("-\r\n"); - read_err("-1\r\n"); - // Acceptable (if not fully valid) extensions do not influence the size - read("1;extension\r\n", 1); - read("a;ext name=value\r\n", 10); - read("1;extension;extension2\r\n", 1); - read("1;;; ;\r\n", 1); - read("2; extension...\r\n", 2); - read("3 ; extension=123\r\n", 3); - read("3 ;\r\n", 3); - read("3 ; \r\n", 3); - // Invalid extensions cause an error - read_err("1 invalid extension\r\n"); - read_err("1 A\r\n"); - read_err("1;no CRLF"); - } - - #[test] - fn test_read_sized_early_eof() { - let mut r = super::HttpReader::SizedReader(MockStream::with_input(b"foo bar"), 10); - let mut buf = [0u8; 10]; - assert_eq!(r.read(&mut buf).unwrap(), 7); - let e = r.read(&mut buf).unwrap_err(); - assert_eq!(e.kind(), io::ErrorKind::Other); - assert_eq!(e.description(), "early eof"); - } - - #[test] - fn test_read_chunked_early_eof() { - let mut r = super::HttpReader::ChunkedReader(MockStream::with_input(b"\ - 9\r\n\ - foo bar\ - "), None); - - let mut buf = [0u8; 10]; - assert_eq!(r.read(&mut buf).unwrap(), 7); - let e = r.read(&mut buf).unwrap_err(); - assert_eq!(e.kind(), io::ErrorKind::Other); - assert_eq!(e.description(), "early eof"); - } - - #[test] - fn test_message_get_incoming_invalid_content_length() { - let raw = MockStream::with_input( - b"HTTP/1.1 200 OK\r\nContent-Length: asdf\r\n\r\n"); - let mut msg = Http11Message::with_stream(Box::new(raw)); - assert!(msg.get_incoming().is_err()); - assert!(msg.close_connection().is_ok()); - } - - #[test] - fn test_parse_incoming() { - let mut raw = MockStream::with_input(b"GET /echo HTTP/1.1\r\nHost: hyper.rs\r\n\r\n"); - let mut buf = BufReader::new(&mut raw); - parse_request(&mut buf).unwrap(); - } - - #[test] - fn test_parse_raw_status() { - let mut raw = MockStream::with_input(b"HTTP/1.1 200 OK\r\n\r\n"); - let mut buf = BufReader::new(&mut raw); - let res = parse_response(&mut buf).unwrap(); - - assert_eq!(res.subject.1, "OK"); - - let mut raw = MockStream::with_input(b"HTTP/1.1 200 Howdy\r\n\r\n"); - let mut buf = BufReader::new(&mut raw); - let res = parse_response(&mut buf).unwrap(); - - assert_eq!(res.subject.1, "Howdy"); - } - - - #[test] - fn test_parse_tcp_closed() { - use std::io::ErrorKind; - use error::Error; - - let mut empty = MockStream::new(); - let mut buf = BufReader::new(&mut empty); - match parse_request(&mut buf) { - Err(Error::Io(ref e)) if e.kind() == ErrorKind::ConnectionAborted => (), - other => panic!("unexpected result: {:?}", other) - } - } - - #[cfg(feature = "nightly")] - use test::Bencher; - - #[cfg(feature = "nightly")] - #[bench] - fn bench_parse_incoming(b: &mut Bencher) { - let mut raw = MockStream::with_input(b"GET /echo HTTP/1.1\r\nHost: hyper.rs\r\n\r\n"); - let mut buf = BufReader::new(&mut raw); - b.iter(|| { - parse_request(&mut buf).unwrap(); - buf.get_mut().read.set_position(0); - }); - } -} diff --git a/src/http/h1/decode.rs b/src/http/h1/decode.rs new file mode 100644 index 0000000000..1cc4b586fb --- /dev/null +++ b/src/http/h1/decode.rs @@ -0,0 +1,293 @@ +use std::cmp; +use std::io::{self, Read}; + +use self::Kind::{Length, Chunked, Eof}; + +/// Decoders to handle different Transfer-Encodings. +/// +/// If a message body does not include a Transfer-Encoding, it *should* +/// include a Content-Length header. +#[derive(Debug, Clone)] +pub struct Decoder { + kind: Kind, +} + +impl Decoder { + pub fn length(x: u64) -> Decoder { + Decoder { + kind: Kind::Length(x) + } + } + + pub fn chunked() -> Decoder { + Decoder { + kind: Kind::Chunked(None) + } + } + + pub fn eof() -> Decoder { + Decoder { + kind: Kind::Eof(false) + } + } +} + +#[derive(Debug, Clone)] +enum Kind { + /// A Reader used when a Content-Length header is passed with a positive integer. + Length(u64), + /// A Reader used when Transfer-Encoding is `chunked`. + Chunked(Option), + /// A Reader used for responses that don't indicate a length or chunked. + /// + /// Note: This should only used for `Response`s. It is illegal for a + /// `Request` to be made with both `Content-Length` and + /// `Transfer-Encoding: chunked` missing, as explained from the spec: + /// + /// > If a Transfer-Encoding header field is present in a response and + /// > the chunked transfer coding is not the final encoding, the + /// > message body length is determined by reading the connection until + /// > it is closed by the server. If a Transfer-Encoding header field + /// > is present in a request and the chunked transfer coding is not + /// > the final encoding, the message body length cannot be determined + /// > reliably; the server MUST respond with the 400 (Bad Request) + /// > status code and then close the connection. + Eof(bool), +} + +impl Decoder { + pub fn is_eof(&self) -> bool { + trace!("is_eof? {:?}", self); + match self.kind { + Length(0) | + Chunked(Some(0)) | + Eof(true) => true, + _ => false + } + } +} + +impl Decoder { + pub fn decode(&mut self, body: &mut R, buf: &mut [u8]) -> io::Result { + match self.kind { + Length(ref mut remaining) => { + trace!("Sized read, remaining={:?}", remaining); + if *remaining == 0 { + Ok(0) + } else { + let to_read = cmp::min(*remaining as usize, buf.len()); + let num = try!(body.read(&mut buf[..to_read])) as u64; + trace!("Length read: {}", num); + if num > *remaining { + *remaining = 0; + } else if num == 0 { + return Err(io::Error::new(io::ErrorKind::Other, "early eof")); + } else { + *remaining -= num; + } + Ok(num as usize) + } + }, + Chunked(ref mut opt_remaining) => { + let mut rem = match *opt_remaining { + Some(ref rem) => *rem, + // None means we don't know the size of the next chunk + None => try!(read_chunk_size(body)) + }; + trace!("Chunked read, remaining={:?}", rem); + + if rem == 0 { + *opt_remaining = Some(0); + + // chunk of size 0 signals the end of the chunked stream + // if the 0 digit was missing from the stream, it would + // be an InvalidInput error instead. + trace!("end of chunked"); + return Ok(0) + } + + let to_read = cmp::min(rem as usize, buf.len()); + let count = try!(body.read(&mut buf[..to_read])) as u64; + + if count == 0 { + *opt_remaining = Some(0); + return Err(io::Error::new(io::ErrorKind::Other, "early eof")); + } + + rem -= count; + *opt_remaining = if rem > 0 { + Some(rem) + } else { + try!(eat(body, b"\r\n")); + None + }; + Ok(count as usize) + }, + Eof(ref mut is_eof) => { + match body.read(buf) { + Ok(0) => { + *is_eof = true; + Ok(0) + } + other => other + } + }, + } + } +} + +fn eat(rdr: &mut R, bytes: &[u8]) -> io::Result<()> { + let mut buf = [0]; + for &b in bytes.iter() { + match try!(rdr.read(&mut buf)) { + 1 if buf[0] == b => (), + _ => return Err(io::Error::new(io::ErrorKind::InvalidInput, + "Invalid characters found")), + } + } + Ok(()) +} + +/// Chunked chunks start with 1*HEXDIGIT, indicating the size of the chunk. +fn read_chunk_size(rdr: &mut R) -> io::Result { + macro_rules! byte ( + ($rdr:ident) => ({ + let mut buf = [0]; + match try!($rdr.read(&mut buf)) { + 1 => buf[0], + _ => return Err(io::Error::new(io::ErrorKind::InvalidInput, + "Invalid chunk size line")), + + } + }) + ); + let mut size = 0u64; + let radix = 16; + let mut in_ext = false; + let mut in_chunk_size = true; + loop { + match byte!(rdr) { + b@b'0'...b'9' if in_chunk_size => { + size *= radix; + size += (b - b'0') as u64; + }, + b@b'a'...b'f' if in_chunk_size => { + size *= radix; + size += (b + 10 - b'a') as u64; + }, + b@b'A'...b'F' if in_chunk_size => { + size *= radix; + size += (b + 10 - b'A') as u64; + }, + b'\r' => { + match byte!(rdr) { + b'\n' => break, + _ => return Err(io::Error::new(io::ErrorKind::InvalidInput, + "Invalid chunk size line")) + + } + }, + // If we weren't in the extension yet, the ";" signals its start + b';' if !in_ext => { + in_ext = true; + in_chunk_size = false; + }, + // "Linear white space" is ignored between the chunk size and the + // extension separator token (";") due to the "implied *LWS rule". + b'\t' | b' ' if !in_ext & !in_chunk_size => {}, + // LWS can follow the chunk size, but no more digits can come + b'\t' | b' ' if in_chunk_size => in_chunk_size = false, + // We allow any arbitrary octet once we are in the extension, since + // they all get ignored anyway. According to the HTTP spec, valid + // extensions would have a more strict syntax: + // (token ["=" (token | quoted-string)]) + // but we gain nothing by rejecting an otherwise valid chunk size. + _ext if in_ext => { + //TODO: chunk extension byte; + }, + // Finally, if we aren't in the extension and we're reading any + // other octet, the chunk size line is invalid! + _ => { + return Err(io::Error::new(io::ErrorKind::InvalidInput, + "Invalid chunk size line")); + } + } + } + trace!("chunk size={:?}", size); + Ok(size) +} + + +#[cfg(test)] +mod tests { + use std::error::Error; + use std::io; + use super::{Decoder, read_chunk_size}; + + #[test] + fn test_read_chunk_size() { + fn read(s: &str, result: u64) { + assert_eq!(read_chunk_size(&mut s.as_bytes()).unwrap(), result); + } + + fn read_err(s: &str) { + assert_eq!(read_chunk_size(&mut s.as_bytes()).unwrap_err().kind(), + io::ErrorKind::InvalidInput); + } + + read("1\r\n", 1); + read("01\r\n", 1); + read("0\r\n", 0); + read("00\r\n", 0); + read("A\r\n", 10); + read("a\r\n", 10); + read("Ff\r\n", 255); + read("Ff \r\n", 255); + // Missing LF or CRLF + read_err("F\rF"); + read_err("F"); + // Invalid hex digit + read_err("X\r\n"); + read_err("1X\r\n"); + read_err("-\r\n"); + read_err("-1\r\n"); + // Acceptable (if not fully valid) extensions do not influence the size + read("1;extension\r\n", 1); + read("a;ext name=value\r\n", 10); + read("1;extension;extension2\r\n", 1); + read("1;;; ;\r\n", 1); + read("2; extension...\r\n", 2); + read("3 ; extension=123\r\n", 3); + read("3 ;\r\n", 3); + read("3 ; \r\n", 3); + // Invalid extensions cause an error + read_err("1 invalid extension\r\n"); + read_err("1 A\r\n"); + read_err("1;no CRLF"); + } + + #[test] + fn test_read_sized_early_eof() { + let mut bytes = &b"foo bar"[..]; + let mut decoder = Decoder::length(10); + let mut buf = [0u8; 10]; + assert_eq!(decoder.decode(&mut bytes, &mut buf).unwrap(), 7); + let e = decoder.decode(&mut bytes, &mut buf).unwrap_err(); + assert_eq!(e.kind(), io::ErrorKind::Other); + assert_eq!(e.description(), "early eof"); + } + + #[test] + fn test_read_chunked_early_eof() { + let mut bytes = &b"\ + 9\r\n\ + foo bar\ + "[..]; + let mut decoder = Decoder::chunked(); + let mut buf = [0u8; 10]; + assert_eq!(decoder.decode(&mut bytes, &mut buf).unwrap(), 7); + let e = decoder.decode(&mut bytes, &mut buf).unwrap_err(); + assert_eq!(e.kind(), io::ErrorKind::Other); + assert_eq!(e.description(), "early eof"); + } +} diff --git a/src/http/h1/encode.rs b/src/http/h1/encode.rs new file mode 100644 index 0000000000..6cc7772ae4 --- /dev/null +++ b/src/http/h1/encode.rs @@ -0,0 +1,371 @@ +use std::borrow::Cow; +use std::cmp; +use std::io::{self, Write}; + +use http::internal::{AtomicWrite, WriteBuf}; + +/// Encoders to handle different Transfer-Encodings. +#[derive(Debug, Clone)] +pub struct Encoder { + kind: Kind, + prefix: Prefix, //Option>> +} + +#[derive(Debug, PartialEq, Clone)] +enum Kind { + /// An Encoder for when Transfer-Encoding includes `chunked`. + Chunked(Chunked), + /// An Encoder for when Content-Length is set. + /// + /// Enforces that the body is not longer than the Content-Length header. + Length(u64), +} + +impl Encoder { + pub fn chunked() -> Encoder { + Encoder { + kind: Kind::Chunked(Chunked::Init), + prefix: Prefix(None) + } + } + + pub fn length(len: u64) -> Encoder { + Encoder { + kind: Kind::Length(len), + prefix: Prefix(None) + } + } + + pub fn prefix(&mut self, prefix: WriteBuf>) { + self.prefix.0 = Some(prefix); + } + + pub fn is_eof(&self) -> bool { + if self.prefix.0.is_some() { + return false; + } + match self.kind { + Kind::Length(0) | + Kind::Chunked(Chunked::End) => true, + _ => false + } + } + + pub fn end(self) -> Option>> { + let trailer = self.trailer(); + let buf = self.prefix.0; + + match (buf, trailer) { + (Some(mut buf), Some(trailer)) => { + buf.bytes.extend_from_slice(trailer); + Some(WriteBuf { + bytes: Cow::Owned(buf.bytes), + pos: buf.pos, + }) + }, + (Some(buf), None) => Some(WriteBuf { + bytes: Cow::Owned(buf.bytes), + pos: buf.pos + }), + (None, Some(trailer)) => { + Some(WriteBuf { + bytes: Cow::Borrowed(trailer), + pos: 0, + }) + }, + (None, None) => None + } + } + + fn trailer(&self) -> Option<&'static [u8]> { + match self.kind { + Kind::Chunked(Chunked::Init) => { + Some(b"0\r\n\r\n") + } + _ => None + } + } + + pub fn encode(&mut self, w: &mut W, msg: &[u8]) -> io::Result { + match self.kind { + Kind::Chunked(ref mut chunked) => { + chunked.encode(w, &mut self.prefix, msg) + }, + Kind::Length(ref mut remaining) => { + let mut n = { + let max = cmp::min(*remaining as usize, msg.len()); + let slice = &msg[..max]; + + let prefix = self.prefix.0.as_ref().map(|buf| &buf.bytes[buf.pos..]).unwrap_or(b""); + + try!(w.write_atomic(&[prefix, slice])) + }; + + n = self.prefix.update(n); + if n == 0 { + return Err(io::Error::new(io::ErrorKind::WouldBlock, "would block")); + } + + *remaining -= n as u64; + Ok(n) + }, + } + } +} + +#[derive(Debug, PartialEq, Clone)] +enum Chunked { + Init, + Size(ChunkSize), + SizeCr, + SizeLf, + Body(usize), + BodyCr, + BodyLf, + End, +} + +impl Chunked { + fn encode(&mut self, w: &mut W, prefix: &mut Prefix, msg: &[u8]) -> io::Result { + match *self { + Chunked::Init => { + let mut size = ChunkSize { + bytes: [0; CHUNK_SIZE_MAX_BYTES], + pos: 0, + len: 0, + }; + trace!("chunked write, size = {:?}", msg.len()); + write!(&mut size, "{:X}", msg.len()) + .expect("CHUNK_SIZE_MAX_BYTES should fit any usize"); + *self = Chunked::Size(size); + } + Chunked::End => return Ok(0), + _ => {} + } + let mut n = { + let pieces = match *self { + Chunked::Init => unreachable!("Chunked::Init should have become Chunked::Size"), + Chunked::Size(ref size) => [ + prefix.0.as_ref().map(|buf| &buf.bytes[buf.pos..]).unwrap_or(b""), + &size.bytes[size.pos.into() .. size.len.into()], + &b"\r\n"[..], + msg, + &b"\r\n"[..], + ], + Chunked::SizeCr => [ + &b""[..], + &b""[..], + &b"\r\n"[..], + msg, + &b"\r\n"[..], + ], + Chunked::SizeLf => [ + &b""[..], + &b""[..], + &b"\n"[..], + msg, + &b"\r\n"[..], + ], + Chunked::Body(pos) => [ + &b""[..], + &b""[..], + &b""[..], + &msg[pos..], + &b"\r\n"[..], + ], + Chunked::BodyCr => [ + &b""[..], + &b""[..], + &b""[..], + &b""[..], + &b"\r\n"[..], + ], + Chunked::BodyLf => [ + &b""[..], + &b""[..], + &b""[..], + &b""[..], + &b"\n"[..], + ], + Chunked::End => unreachable!("Chunked::End shouldn't write more") + }; + try!(w.write_atomic(&pieces)) + }; + + if n > 0 { + n = prefix.update(n); + } + while n > 0 { + match *self { + Chunked::Init => unreachable!("Chunked::Init should have become Chunked::Size"), + Chunked::Size(mut size) => { + n = size.update(n); + if size.len == 0 { + *self = Chunked::SizeCr; + } else { + *self = Chunked::Size(size); + } + }, + Chunked::SizeCr => { + *self = Chunked::SizeLf; + n -= 1; + } + Chunked::SizeLf => { + *self = Chunked::Body(0); + n -= 1; + } + Chunked::Body(pos) => { + let left = msg.len() - pos; + if n >= left { + *self = Chunked::BodyCr; + n -= left; + } else { + *self = Chunked::Body(pos + n); + n = 0; + } + } + Chunked::BodyCr => { + *self = Chunked::BodyLf; + n -= 1; + } + Chunked::BodyLf => { + assert!(n == 1); + *self = if msg.len() == 0 { + Chunked::End + } else { + Chunked::Init + }; + n = 0; + }, + Chunked::End => unreachable!("Chunked::End shouldn't have any to write") + } + } + + match *self { + Chunked::Init | + Chunked::End => Ok(msg.len()), + _ => Err(io::Error::new(io::ErrorKind::WouldBlock, "chunked incomplete")) + } + } +} + +#[cfg(target_pointer_width = "32")] +const USIZE_BYTES: usize = 4; + +#[cfg(target_pointer_width = "64")] +const USIZE_BYTES: usize = 8; + +// each byte will become 2 hex +const CHUNK_SIZE_MAX_BYTES: usize = USIZE_BYTES * 2; + +#[derive(Clone, Copy)] +struct ChunkSize { + bytes: [u8; CHUNK_SIZE_MAX_BYTES], + pos: u8, + len: u8, +} + +impl ChunkSize { + fn update(&mut self, n: usize) -> usize { + let diff = (self.len - self.pos).into(); + if n >= diff { + self.pos = 0; + self.len = 0; + n - diff + } else { + self.pos += n as u8; // just verified it was a small usize + 0 + } + } +} + +impl ::std::fmt::Debug for ChunkSize { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + f.debug_struct("ChunkSize") + .field("bytes", &&self.bytes[..self.len.into()]) + .field("pos", &self.pos) + .finish() + } +} + +impl ::std::cmp::PartialEq for ChunkSize { + fn eq(&self, other: &ChunkSize) -> bool { + self.len == other.len && + self.pos == other.pos && + (&self.bytes[..]) == (&other.bytes[..]) + } +} + +impl io::Write for ChunkSize { + fn write(&mut self, msg: &[u8]) -> io::Result { + let n = (&mut self.bytes[self.len.into() ..]).write(msg) + .expect("&mut [u8].write() cannot error"); + self.len += n as u8; // safe because bytes is never bigger than 256 + Ok(n) + } + + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +} + +#[derive(Debug, Clone)] +struct Prefix(Option>>); + +impl Prefix { + fn update(&mut self, n: usize) -> usize { + if let Some(mut buf) = self.0.take() { + if buf.bytes.len() - buf.pos > n { + buf.pos += n; + self.0 = Some(buf); + 0 + } else { + let nbuf = buf.bytes.len() - buf.pos; + n - nbuf + } + } else { + n + } + } +} + +#[cfg(test)] +mod tests { + use super::Encoder; + use mock::{Async, Buf}; + + #[test] + fn test_write_chunked_sync() { + let mut dst = Buf::new(); + let mut encoder = Encoder::chunked(); + + encoder.encode(&mut dst, b"foo bar").unwrap(); + encoder.encode(&mut dst, b"baz quux herp").unwrap(); + encoder.encode(&mut dst, b"").unwrap(); + assert_eq!(&dst[..], &b"7\r\nfoo bar\r\nD\r\nbaz quux herp\r\n0\r\n\r\n"[..]); + } + + #[test] + fn test_write_chunked_async() { + let mut dst = Async::new(Buf::new(), 7); + let mut encoder = Encoder::chunked(); + + assert!(encoder.encode(&mut dst, b"foo bar").is_err()); + dst.block_in(6); + assert_eq!(7, encoder.encode(&mut dst, b"foo bar").unwrap()); + dst.block_in(30); + assert_eq!(13, encoder.encode(&mut dst, b"baz quux herp").unwrap()); + encoder.encode(&mut dst, b"").unwrap(); + assert_eq!(&dst[..], &b"7\r\nfoo bar\r\nD\r\nbaz quux herp\r\n0\r\n\r\n"[..]); + } + + #[test] + fn test_write_sized() { + let mut dst = Buf::new(); + let mut encoder = Encoder::length(8); + encoder.encode(&mut dst, b"foo bar").unwrap(); + assert_eq!(encoder.encode(&mut dst, b"baz").unwrap(), 1); + + assert_eq!(dst, b"foo barb"); + } +} diff --git a/src/http/h1/mod.rs b/src/http/h1/mod.rs new file mode 100644 index 0000000000..135e72e613 --- /dev/null +++ b/src/http/h1/mod.rs @@ -0,0 +1,136 @@ +/* +use std::fmt; +use std::io::{self, Write}; +use std::marker::PhantomData; +use std::sync::mpsc; + +use url::Url; +use tick; +use time::now_utc; + +use header::{self, Headers}; +use http::{self, conn}; +use method::Method; +use net::{Fresh, Streaming}; +use status::StatusCode; +use version::HttpVersion; +*/ + +pub use self::decode::Decoder; +pub use self::encode::Encoder; + +pub use self::parse::parse; + +mod decode; +mod encode; +mod parse; + +/* +fn should_have_response_body(method: &Method, status: u16) -> bool { + trace!("should_have_response_body({:?}, {})", method, status); + match (method, status) { + (&Method::Head, _) | + (_, 100...199) | + (_, 204) | + (_, 304) | + (&Method::Connect, 200...299) => false, + _ => true + } +} +*/ +/* +const MAX_INVALID_RESPONSE_BYTES: usize = 1024 * 128; +impl HttpMessage for Http11Message { + + fn get_incoming(&mut self) -> ::Result { + unimplemented!(); + /* + try!(self.flush_outgoing()); + let stream = match self.stream.take() { + Some(stream) => stream, + None => { + // The message was already in the reading state... + // TODO Decide what happens in case we try to get a new incoming at that point + return Err(From::from( + io::Error::new(io::ErrorKind::Other, + "Read already in progress"))); + } + }; + + let expected_no_content = stream.previous_response_expected_no_content(); + trace!("previous_response_expected_no_content = {}", expected_no_content); + + let mut stream = BufReader::new(stream); + + let mut invalid_bytes_read = 0; + let head; + loop { + head = match parse_response(&mut stream) { + Ok(head) => head, + Err(::Error::Version) + if expected_no_content && invalid_bytes_read < MAX_INVALID_RESPONSE_BYTES => { + trace!("expected_no_content, found content"); + invalid_bytes_read += 1; + stream.consume(1); + continue; + } + Err(e) => { + self.stream = Some(stream.into_inner()); + return Err(e); + } + }; + break; + } + + let raw_status = head.subject; + let headers = head.headers; + + let method = self.method.take().unwrap_or(Method::Get); + + let is_empty = !should_have_response_body(&method, raw_status.0); + stream.get_mut().set_previous_response_expected_no_content(is_empty); + // According to https://tools.ietf.org/html/rfc7230#section-3.3.3 + // 1. HEAD reponses, and Status 1xx, 204, and 304 cannot have a body. + // 2. Status 2xx to a CONNECT cannot have a body. + // 3. Transfer-Encoding: chunked has a chunked body. + // 4. If multiple differing Content-Length headers or invalid, close connection. + // 5. Content-Length header has a sized body. + // 6. Not Client. + // 7. Read till EOF. + self.reader = Some(if is_empty { + SizedReader(stream, 0) + } else { + if let Some(&TransferEncoding(ref codings)) = headers.get() { + if codings.last() == Some(&Chunked) { + ChunkedReader(stream, None) + } else { + trace!("not chuncked. read till eof"); + EofReader(stream) + } + } else if let Some(&ContentLength(len)) = headers.get() { + SizedReader(stream, len) + } else if headers.has::() { + trace!("illegal Content-Length: {:?}", headers.get_raw("Content-Length")); + return Err(Error::Header); + } else { + trace!("neither Transfer-Encoding nor Content-Length"); + EofReader(stream) + } + }); + + trace!("Http11Message.reader = {:?}", self.reader); + + + Ok(ResponseHead { + headers: headers, + raw_status: raw_status, + version: head.version, + }) + */ + } +} + + +*/ + + diff --git a/src/http/h1/parse.rs b/src/http/h1/parse.rs new file mode 100644 index 0000000000..4c08e80a8b --- /dev/null +++ b/src/http/h1/parse.rs @@ -0,0 +1,246 @@ +use std::borrow::Cow; +use std::io::Write; + +use httparse; + +use header::{self, Headers, ContentLength, TransferEncoding}; +use http::{MessageHead, RawStatus, Http1Message, ParseResult, Next, ServerMessage, ClientMessage, Next_, RequestLine}; +use http::h1::{Encoder, Decoder}; +use method::Method; +use status::StatusCode; +use version::HttpVersion::{Http10, Http11}; + +const MAX_HEADERS: usize = 100; +const AVERAGE_HEADER_SIZE: usize = 30; // totally scientific + +pub fn parse, I>(buf: &[u8]) -> ParseResult { + if buf.len() == 0 { + return Ok(None); + } + trace!("parse({:?})", buf); + ::parse(buf) +} + + + +impl Http1Message for ServerMessage { + type Incoming = RequestLine; + type Outgoing = StatusCode; + + fn initial_interest() -> Next { + Next::new(Next_::Read) + } + + fn parse(buf: &[u8]) -> ParseResult { + let mut headers = [httparse::EMPTY_HEADER; MAX_HEADERS]; + trace!("Request.parse([Header; {}], [u8; {}])", headers.len(), buf.len()); + let mut req = httparse::Request::new(&mut headers); + Ok(match try!(req.parse(buf)) { + httparse::Status::Complete(len) => { + trace!("Request.parse Complete({})", len); + Some((MessageHead { + version: if req.version.unwrap() == 1 { Http11 } else { Http10 }, + subject: RequestLine( + try!(req.method.unwrap().parse()), + try!(req.path.unwrap().parse()) + ), + headers: try!(Headers::from_raw(req.headers)) + }, len)) + }, + httparse::Status::Partial => None + }) + } + + fn decoder(head: &MessageHead) -> ::Result { + use ::header; + if let Some(&header::ContentLength(len)) = head.headers.get() { + Ok(Decoder::length(len)) + } else if head.headers.has::() { + //TODO: check for Transfer-Encoding: chunked + Ok(Decoder::chunked()) + } else { + Ok(Decoder::length(0)) + } + } + + + fn encode(mut head: MessageHead, dst: &mut Vec) -> Encoder { + use ::header; + trace!("writing head: {:?}", head); + + if !head.headers.has::() { + head.headers.set(header::Date(header::HttpDate(::time::now_utc()))); + } + + let mut is_chunked = true; + let mut body = Encoder::chunked(); + if let Some(cl) = head.headers.get::() { + body = Encoder::length(**cl); + is_chunked = false + } + + if is_chunked { + let encodings = match head.headers.get_mut::() { + Some(&mut header::TransferEncoding(ref mut encodings)) => { + if encodings.last() != Some(&header::Encoding::Chunked) { + encodings.push(header::Encoding::Chunked); + } + false + }, + None => true + }; + + if encodings { + head.headers.set(header::TransferEncoding(vec![header::Encoding::Chunked])); + } + } + + + let init_cap = 30 + head.headers.len() * AVERAGE_HEADER_SIZE; + dst.reserve(init_cap); + debug!("writing {:#?}", head.headers); + let _ = write!(dst, "{} {}\r\n{}\r\n", head.version, head.subject, head.headers); + + body + } +} + +impl Http1Message for ClientMessage { + type Incoming = RawStatus; + type Outgoing = RequestLine; + + + fn initial_interest() -> Next { + Next::new(Next_::Write) + } + + fn parse(buf: &[u8]) -> ParseResult { + let mut headers = [httparse::EMPTY_HEADER; MAX_HEADERS]; + trace!("Response.parse([Header; {}], [u8; {}])", headers.len(), buf.len()); + let mut res = httparse::Response::new(&mut headers); + Ok(match try!(res.parse(buf)) { + httparse::Status::Complete(len) => { + trace!("Response.try_parse Complete({})", len); + let code = res.code.unwrap(); + let reason = match StatusCode::from_u16(code).canonical_reason() { + Some(reason) if reason == res.reason.unwrap() => Cow::Borrowed(reason), + _ => Cow::Owned(res.reason.unwrap().to_owned()) + }; + Some((MessageHead { + version: if res.version.unwrap() == 1 { Http11 } else { Http10 }, + subject: RawStatus(code, reason), + headers: try!(Headers::from_raw(res.headers)) + }, len)) + }, + httparse::Status::Partial => None + }) + } + + fn decoder(inc: &MessageHead) -> ::Result { + use ::header; + // According to https://tools.ietf.org/html/rfc7230#section-3.3.3 + // 1. HEAD reponses, and Status 1xx, 204, and 304 cannot have a body. + // 2. Status 2xx to a CONNECT cannot have a body. + // + // First two steps taken care of before this method. + // + // 3. Transfer-Encoding: chunked has a chunked body. + // 4. If multiple differing Content-Length headers or invalid, close connection. + // 5. Content-Length header has a sized body. + // 6. Not Client. + // 7. Read till EOF. + if let Some(&header::TransferEncoding(ref codings)) = inc.headers.get() { + if codings.last() == Some(&header::Encoding::Chunked) { + Ok(Decoder::chunked()) + } else { + trace!("not chuncked. read till eof"); + Ok(Decoder::eof()) + } + } else if let Some(&header::ContentLength(len)) = inc.headers.get() { + Ok(Decoder::length(len)) + } else if inc.headers.has::() { + trace!("illegal Content-Length: {:?}", inc.headers.get_raw("Content-Length")); + Err(::Error::Header) + } else { + trace!("neither Transfer-Encoding nor Content-Length"); + Ok(Decoder::eof()) + } + } + + fn encode(mut head: MessageHead, dst: &mut Vec) -> Encoder { + trace!("writing head: {:?}", head); + + + let mut body = Encoder::length(0); + let expects_no_body = match head.subject.0 { + Method::Head | Method::Get | Method::Connect => true, + _ => false + }; + let mut chunked = false; + + if let Some(con_len) = head.headers.get::() { + body = Encoder::length(**con_len); + } else { + chunked = !expects_no_body; + } + + if chunked { + body = Encoder::chunked(); + let encodings = match head.headers.get_mut::() { + Some(encodings) => { + //TODO: check if Chunked already exists + encodings.push(header::Encoding::Chunked); + true + }, + None => false + }; + + if !encodings { + head.headers.set(TransferEncoding(vec![header::Encoding::Chunked])); + } + } + + let init_cap = 30 + head.headers.len() * AVERAGE_HEADER_SIZE; + dst.reserve(init_cap); + debug!("writing {:#?}", head.headers); + let _ = write!(dst, "{} {}\r\n{}\r\n", head.subject, head.version, head.headers); + + body + } +} + +#[cfg(test)] +mod tests { + use http; + use super::{parse}; + + #[test] + fn test_parse_request() { + let raw = b"GET /echo HTTP/1.1\r\nHost: hyper.rs\r\n\r\n"; + parse::(raw).unwrap(); + } + + #[test] + fn test_parse_raw_status() { + let raw = b"HTTP/1.1 200 OK\r\n\r\n"; + let (res, _) = parse::(raw).unwrap().unwrap(); + assert_eq!(res.subject.1, "OK"); + + let raw = b"HTTP/1.1 200 Howdy\r\n\r\n"; + let (res, _) = parse::(raw).unwrap().unwrap(); + assert_eq!(res.subject.1, "Howdy"); + } + + #[cfg(feature = "nightly")] + use test::Bencher; + + #[cfg(feature = "nightly")] + #[bench] + fn bench_parse_incoming(b: &mut Bencher) { + let raw = b"GET /echo HTTP/1.1\r\nHost: hyper.rs\r\n\r\n"; + b.iter(|| { + parse::(raw).unwrap() + }); + } + +} diff --git a/src/http/h2.rs b/src/http/h2/mod.rs similarity index 100% rename from src/http/h2.rs rename to src/http/h2/mod.rs diff --git a/src/http/message.rs b/src/http/message.rs deleted file mode 100644 index d983fafa03..0000000000 --- a/src/http/message.rs +++ /dev/null @@ -1,133 +0,0 @@ -//! Defines the `HttpMessage` trait that serves to encapsulate the operations of a single -//! request-response cycle on any HTTP connection. - -use std::any::{Any, TypeId}; -use std::fmt::Debug; -use std::io::{Read, Write}; -use std::mem; - -use std::io; -use std::time::Duration; - -use typeable::Typeable; - -use header::Headers; -use http::RawStatus; -use url::Url; - -use method; -use version; -use traitobject; - -/// The trait provides an API for creating new `HttpMessage`s depending on the underlying HTTP -/// protocol. -pub trait Protocol { - /// Creates a fresh `HttpMessage` bound to the given host, based on the given protocol scheme. - fn new_message(&self, host: &str, port: u16, scheme: &str) -> ::Result>; -} - -/// Describes a request. -#[derive(Clone, Debug)] -pub struct RequestHead { - /// The headers of the request - pub headers: Headers, - /// The method of the request - pub method: method::Method, - /// The URL of the request - pub url: Url, -} - -/// Describes a response. -#[derive(Clone, Debug)] -pub struct ResponseHead { - /// The headers of the reponse - pub headers: Headers, - /// The raw status line of the response - pub raw_status: RawStatus, - /// The HTTP/2 version which generated the response - pub version: version::HttpVersion, -} - -/// The trait provides an API for sending an receiving HTTP messages. -pub trait HttpMessage: Write + Read + Send + Any + Typeable + Debug { - /// Initiates a new outgoing request. - /// - /// Only the request's head is provided (in terms of the `RequestHead` struct). - /// - /// After this, the `HttpMessage` instance can be used as an `io::Write` in order to write the - /// body of the request. - fn set_outgoing(&mut self, head: RequestHead) -> ::Result; - /// Obtains the incoming response and returns its head (i.e. the `ResponseHead` struct) - /// - /// After this, the `HttpMessage` instance can be used as an `io::Read` in order to read out - /// the response body. - fn get_incoming(&mut self) -> ::Result; - /// Set the read timeout duration for this message. - fn set_read_timeout(&self, dur: Option) -> io::Result<()>; - /// Set the write timeout duration for this message. - fn set_write_timeout(&self, dur: Option) -> io::Result<()>; - /// Closes the underlying HTTP connection. - fn close_connection(&mut self) -> ::Result<()>; - /// Returns whether the incoming message has a body. - fn has_body(&self) -> bool; - /// Called when the Client wishes to use a Proxy. - fn set_proxied(&mut self, val: bool) { - // default implementation so as to not be a breaking change. - warn!("default set_proxied({:?})", val); - } -} - -impl HttpMessage { - unsafe fn downcast_ref_unchecked(&self) -> &T { - mem::transmute(traitobject::data(self)) - } - - unsafe fn downcast_mut_unchecked(&mut self) -> &mut T { - mem::transmute(traitobject::data_mut(self)) - } - - unsafe fn downcast_unchecked(self: Box) -> Box { - let raw: *mut HttpMessage = mem::transmute(self); - mem::transmute(traitobject::data_mut(raw)) - } -} - -impl HttpMessage { - /// Is the underlying type in this trait object a T? - #[inline] - pub fn is(&self) -> bool { - (*self).get_type() == TypeId::of::() - } - - /// If the underlying type is T, get a reference to the contained data. - #[inline] - pub fn downcast_ref(&self) -> Option<&T> { - if self.is::() { - Some(unsafe { self.downcast_ref_unchecked() }) - } else { - None - } - } - - /// If the underlying type is T, get a mutable reference to the contained - /// data. - #[inline] - pub fn downcast_mut(&mut self) -> Option<&mut T> { - if self.is::() { - Some(unsafe { self.downcast_mut_unchecked() }) - } else { - None - } - } - - /// If the underlying type is T, extract it. - #[inline] - pub fn downcast(self: Box) - -> Result, Box> { - if self.is::() { - Ok(unsafe { self.downcast_unchecked() }) - } else { - Err(self) - } - } -} diff --git a/src/http/mod.rs b/src/http/mod.rs index 396eec559d..252d874071 100644 --- a/src/http/mod.rs +++ b/src/http/mod.rs @@ -1,25 +1,196 @@ //! Pieces pertaining to the HTTP message protocol. use std::borrow::Cow; +use std::fmt; +use std::io::{self, Read, Write}; +use std::time::Duration; use header::Connection; use header::ConnectionOption::{KeepAlive, Close}; use header::Headers; +use method::Method; +use net::Transport; +use status::StatusCode; +use uri::RequestUri; use version::HttpVersion; use version::HttpVersion::{Http10, Http11}; #[cfg(feature = "serde-serialization")] use serde::{Deserialize, Deserializer, Serialize, Serializer}; -pub use self::message::{HttpMessage, RequestHead, ResponseHead, Protocol}; +pub use self::conn::{Conn, MessageHandler, MessageHandlerFactory, Seed, Key}; -pub mod h1; -pub mod h2; -pub mod message; +mod buffer; +pub mod channel; +mod conn; +mod h1; +//mod h2; + +/// Wraps a `Transport` to provide HTTP decoding when reading. +#[derive(Debug)] +pub struct Decoder<'a, T: Read + 'a>(DecoderImpl<'a, T>); + +/// Wraps a `Transport` to provide HTTP encoding when writing. +#[derive(Debug)] +pub struct Encoder<'a, T: Transport + 'a>(EncoderImpl<'a, T>); + +#[derive(Debug)] +enum DecoderImpl<'a, T: Read + 'a> { + H1(&'a mut h1::Decoder, Trans<'a, T>), +} + +#[derive(Debug)] +enum Trans<'a, T: Read + 'a> { + Port(&'a mut T), + Buf(self::buffer::BufReader<'a, T>) +} + +impl<'a, T: Read + 'a> Read for Trans<'a, T> { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + match *self { + Trans::Port(ref mut t) => t.read(buf), + Trans::Buf(ref mut b) => b.read(buf) + } + } +} + +#[derive(Debug)] +enum EncoderImpl<'a, T: Transport + 'a> { + H1(&'a mut h1::Encoder, &'a mut T), +} + +impl<'a, T: Read> Decoder<'a, T> { + fn h1(decoder: &'a mut h1::Decoder, transport: Trans<'a, T>) -> Decoder<'a, T> { + Decoder(DecoderImpl::H1(decoder, transport)) + } +} + +impl<'a, T: Transport> Encoder<'a, T> { + fn h1(encoder: &'a mut h1::Encoder, transport: &'a mut T) -> Encoder<'a, T> { + Encoder(EncoderImpl::H1(encoder, transport)) + } +} + +impl<'a, T: Read> Read for Decoder<'a, T> { + #[inline] + fn read(&mut self, buf: &mut [u8]) -> io::Result { + match self.0 { + DecoderImpl::H1(ref mut decoder, ref mut transport) => { + decoder.decode(transport, buf) + } + } + } +} + +impl<'a, T: Transport> Write for Encoder<'a, T> { + #[inline] + fn write(&mut self, data: &[u8]) -> io::Result { + if data.is_empty() { + return Ok(0); + } + match self.0 { + EncoderImpl::H1(ref mut encoder, ref mut transport) => { + encoder.encode(*transport, data) + } + } + } + + #[inline] + fn flush(&mut self) -> io::Result<()> { + match self.0 { + EncoderImpl::H1(_, ref mut transport) => { + transport.flush() + } + } + } +} + +/// Because privacy rules. Reasons. +/// https://github.com/rust-lang/rust/issues/30905 +mod internal { + use std::io::{self, Write}; + + #[derive(Debug, Clone)] + pub struct WriteBuf> { + pub bytes: T, + pub pos: usize, + } + + pub trait AtomicWrite { + fn write_atomic(&mut self, data: &[&[u8]]) -> io::Result; + } + + #[cfg(not(windows))] + impl AtomicWrite for T { + + fn write_atomic(&mut self, bufs: &[&[u8]]) -> io::Result { + self.writev(bufs) + } + + } + + #[cfg(windows)] + impl AtomicWrite for T { + fn write_atomic(&mut self, bufs: &[&[u8]]) -> io::Result { + let vec = bufs.concat(); + self.write(&vec) + } + } +} + +/// An Incoming Message head. Includes request/status line, and headers. +#[derive(Debug, Default)] +pub struct MessageHead { + /// HTTP version of the message. + pub version: HttpVersion, + /// Subject (request line or status line) of Incoming message. + pub subject: S, + /// Headers of the Incoming message. + pub headers: Headers +} + +/// An incoming request message. +pub type RequestHead = MessageHead; + +#[derive(Debug, Default)] +pub struct RequestLine(pub Method, pub RequestUri); + +impl fmt::Display for RequestLine { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{} {}", self.0, self.1) + } +} + +/// An incoming response message. +pub type ResponseHead = MessageHead; + +impl MessageHead { + pub fn should_keep_alive(&self) -> bool { + should_keep_alive(self.version, &self.headers) + } +} /// The raw status code and reason-phrase. #[derive(Clone, PartialEq, Debug)] pub struct RawStatus(pub u16, pub Cow<'static, str>); +impl fmt::Display for RawStatus { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{} {}", self.0, self.1) + } +} + +impl From for RawStatus { + fn from(status: StatusCode) -> RawStatus { + RawStatus(status.to_u16(), Cow::Borrowed(status.canonical_reason().unwrap_or(""))) + } +} + +impl Default for RawStatus { + fn default() -> RawStatus { + RawStatus(200, Cow::Borrowed("OK")) + } +} + #[cfg(feature = "serde-serialization")] impl Serialize for RawStatus { fn serialize(&self, serializer: &mut S) -> Result<(), S::Error> where S: Serializer { @@ -46,6 +217,158 @@ pub fn should_keep_alive(version: HttpVersion, headers: &Headers) -> bool { _ => true } } +pub type ParseResult = ::Result, usize)>>; + +pub fn parse, I>(rdr: &[u8]) -> ParseResult { + h1::parse::(rdr) +} + +// These 2 enums are not actually dead_code. They are used in the server and +// and client modules, respectively. However, their being used as associated +// types doesn't mark them as used, so the dead_code linter complains. + +#[allow(dead_code)] +#[derive(Debug)] +pub enum ServerMessage {} + +#[allow(dead_code)] +#[derive(Debug)] +pub enum ClientMessage {} + +pub trait Http1Message { + type Incoming; + type Outgoing: Default; + //TODO: replace with associated const when stable + fn initial_interest() -> Next; + fn parse(bytes: &[u8]) -> ParseResult; + fn decoder(head: &MessageHead) -> ::Result; + fn encode(head: MessageHead, dst: &mut Vec) -> h1::Encoder; + +} + +/// Used to signal desired events when working with asynchronous IO. +#[must_use] +#[derive(Clone)] +pub struct Next { + interest: Next_, + timeout: Option, +} + +impl fmt::Debug for Next { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + try!(write!(f, "Next::{:?}", &self.interest)); + match self.timeout { + Some(ref d) => write!(f, "({:?})", d), + None => Ok(()) + } + } +} + +#[derive(Debug, Clone, Copy)] +enum Next_ { + Read, + Write, + ReadWrite, + Wait, + End, + Remove, +} + +#[derive(Debug, Clone, Copy)] +enum Reg { + Read, + Write, + ReadWrite, + Wait, + Remove +} + +/// A notifier to wakeup a socket after having used `Next::wait()` +#[derive(Debug, Clone)] +pub struct Control { + tx: self::channel::Sender, +} + +impl Control { + /// Wakeup a waiting socket to listen for a certain event. + pub fn ready(&self, next: Next) -> Result<(), ControlError> { + //TODO: assert!( next.interest != Next_::Wait ) ? + self.tx.send(next).map_err(|_| ControlError(())) + } +} + +/// An error occured trying to tell a Control it is ready. +#[derive(Debug)] +pub struct ControlError(()); + +impl ::std::error::Error for ControlError { + fn description(&self) -> &str { + "Cannot wakeup event loop: loop is closed" + } +} + +impl fmt::Display for ControlError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str(::std::error::Error::description(self)) + } +} + +impl Next { + fn new(interest: Next_) -> Next { + Next { + interest: interest, + timeout: None, + } + } + + fn interest(&self) -> Reg { + match self.interest { + Next_::Read => Reg::Read, + Next_::Write => Reg::Write, + Next_::ReadWrite => Reg::ReadWrite, + Next_::Wait => Reg::Wait, + Next_::End => Reg::Remove, + Next_::Remove => Reg::Remove, + } + } + + /// Signals the desire to read from the transport. + pub fn read() -> Next { + Next::new(Next_::Read) + } + + /// Signals the desire to write to the transport. + pub fn write() -> Next { + Next::new(Next_::Write) + } + + /// Signals the desire to read and write to the transport. + pub fn read_and_write() -> Next { + Next::new(Next_::ReadWrite) + } + + /// Signals the desire to end the current HTTP message. + pub fn end() -> Next { + Next::new(Next_::End) + } + + /// Signals the desire to abruptly remove the current transport from the + /// event loop. + pub fn remove() -> Next { + Next::new(Next_::Remove) + } + + /// Signals the desire to wait until some future time before acting again. + pub fn wait() -> Next { + Next::new(Next_::Wait) + } + + /// Signals a maximum duration to be waited for the desired event. + pub fn timeout(mut self, dur: Duration) -> Next { + self.timeout = Some(dur); + self + } +} #[test] fn test_should_keep_alive() { diff --git a/src/lib.rs b/src/lib.rs index 72e96e22e7..9dc940a1d2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,7 @@ #![doc(html_root_url = "https://hyperium.github.io/hyper/")] -#![cfg_attr(test, deny(missing_docs))] -#![cfg_attr(test, deny(warnings))] +#![deny(missing_docs)] +#![deny(warnings)] +#![deny(missing_debug_implementations)] #![cfg_attr(all(test, feature = "nightly"), feature(test))] //! # Hyper @@ -9,125 +10,9 @@ //! is a low-level typesafe abstraction over raw HTTP, providing an elegant //! layer over "stringly-typed" HTTP. //! -//! Hyper offers both a [Client](client/index.html) and a -//! [Server](server/index.html) which can be used to drive complex web -//! applications written entirely in Rust. -//! -//! ## Internal Design -//! -//! Hyper is designed as a relatively low-level wrapper over raw HTTP. It should -//! allow the implementation of higher-level abstractions with as little pain as -//! possible, and should not irrevocably hide any information from its users. -//! -//! ### Common Functionality -//! -//! Functionality and code shared between the Server and Client implementations -//! can be found in `src` directly - this includes `NetworkStream`s, `Method`s, -//! `StatusCode`, and so on. -//! -//! #### Methods -//! -//! Methods are represented as a single `enum` to remain as simple as possible. -//! Extension Methods are represented as raw `String`s. A method's safety and -//! idempotence can be accessed using the `safe` and `idempotent` methods. -//! -//! #### StatusCode -//! -//! Status codes are also represented as a single, exhaustive, `enum`. This -//! representation is efficient, typesafe, and ergonomic as it allows the use of -//! `match` to disambiguate known status codes. -//! -//! #### Headers -//! -//! Hyper's [header](header/index.html) representation is likely the most -//! complex API exposed by Hyper. -//! -//! Hyper's headers are an abstraction over an internal `HashMap` and provides a -//! typesafe API for interacting with headers that does not rely on the use of -//! "string-typing." -//! -//! Each HTTP header in Hyper has an associated type and implementation of the -//! `Header` trait, which defines an HTTP headers name as a string, how to parse -//! that header, and how to format that header. -//! -//! Headers are then parsed from the string representation lazily when the typed -//! representation of a header is requested and formatted back into their string -//! representation when headers are written back to the client. -//! -//! #### NetworkStream and NetworkAcceptor -//! -//! These are found in `src/net.rs` and define the interface that acceptors and -//! streams must fulfill for them to be used within Hyper. They are by and large -//! internal tools and you should only need to mess around with them if you want to -//! mock or replace `TcpStream` and `TcpAcceptor`. -//! -//! ### Server -//! -//! Server-specific functionality, such as `Request` and `Response` -//! representations, are found in in `src/server`. -//! -//! #### Handler + Server -//! -//! A `Handler` in Hyper accepts a `Request` and `Response`. This is where -//! user-code can handle each connection. The server accepts connections in a -//! task pool with a customizable number of threads, and passes the Request / -//! Response to the handler. -//! -//! #### Request -//! -//! An incoming HTTP Request is represented as a struct containing -//! a `Reader` over a `NetworkStream`, which represents the body, headers, a remote -//! address, an HTTP version, and a `Method` - relatively standard stuff. -//! -//! `Request` implements `Reader` itself, meaning that you can ergonomically get -//! the body out of a `Request` using standard `Reader` methods and helpers. -//! -//! #### Response -//! -//! An outgoing HTTP Response is also represented as a struct containing a `Writer` -//! over a `NetworkStream` which represents the Response body in addition to -//! standard items such as the `StatusCode` and HTTP version. `Response`'s `Writer` -//! implementation provides a streaming interface for sending data over to the -//! client. -//! -//! One of the traditional problems with representing outgoing HTTP Responses is -//! tracking the write-status of the Response - have we written the status-line, -//! the headers, the body, etc.? Hyper tracks this information statically using the -//! type system and prevents you, using the type system, from writing headers after -//! you have started writing to the body or vice versa. -//! -//! Hyper does this through a phantom type parameter in the definition of Response, -//! which tracks whether you are allowed to write to the headers or the body. This -//! phantom type can have two values `Fresh` or `Streaming`, with `Fresh` -//! indicating that you can write the headers and `Streaming` indicating that you -//! may write to the body, but not the headers. -//! -//! ### Client -//! -//! Client-specific functionality, such as `Request` and `Response` -//! representations, are found in `src/client`. -//! -//! #### Request -//! -//! An outgoing HTTP Request is represented as a struct containing a `Writer` over -//! a `NetworkStream` which represents the Request body in addition to the standard -//! information such as headers and the request method. -//! -//! Outgoing Requests track their write-status in almost exactly the same way as -//! outgoing HTTP Responses do on the Server, so we will defer to the explanation -//! in the documentation for server Response. -//! -//! Requests expose an efficient streaming interface instead of a builder pattern, -//! but they also provide the needed interface for creating a builder pattern over -//! the API exposed by core Hyper. -//! -//! #### Response -//! -//! Incoming HTTP Responses are represented as a struct containing a `Reader` over -//! a `NetworkStream` and contain headers, a status, and an http version. They -//! implement `Reader` and can be read to get the data out of a `Response`. -//! - +//! Hyper provides both a [Client](client/index.html) and a +//! [Server](server/index.html), along with a +//! [typed Headers system](header/index.html). extern crate rustc_serialize as serialize; extern crate time; #[macro_use] extern crate url; @@ -142,10 +27,11 @@ extern crate serde; extern crate cookie; extern crate unicase; extern crate httparse; -extern crate num_cpus; +extern crate rotor; +extern crate spmc; extern crate traitobject; extern crate typeable; -extern crate solicit; +extern crate vecio; #[macro_use] extern crate language_tags; @@ -163,35 +49,31 @@ extern crate test; pub use url::Url; pub use client::Client; pub use error::{Result, Error}; -pub use method::Method::{Get, Head, Post, Delete}; -pub use status::StatusCode::{Ok, BadRequest, NotFound}; +pub use http::{Next, Encoder, Decoder, Control}; +pub use header::Headers; +pub use method::Method::{self, Get, Head, Post, Delete}; +pub use status::StatusCode::{self, Ok, BadRequest, NotFound}; pub use server::Server; +pub use uri::RequestUri; +pub use version::HttpVersion; pub use language_tags::LanguageTag; -macro_rules! todo( - ($($arg:tt)*) => (if cfg!(not(ndebug)) { - trace!("TODO: {:?}", format_args!($($arg)*)) - }) -); - -macro_rules! inspect( - ($name:expr, $value:expr) => ({ - let v = $value; - trace!("inspect: {:?} = {:?}", $name, v); - v - }) -); +macro_rules! rotor_try { + ($e:expr) => ({ + match $e { + Ok(v) => v, + Err(e) => return ::rotor::Response::error(e.into()) + } + }); +} #[cfg(test)] -#[macro_use] mod mock; -#[doc(hidden)] -pub mod buffer; pub mod client; pub mod error; pub mod method; pub mod header; -pub mod http; +mod http; pub mod net; pub mod server; pub mod status; @@ -203,6 +85,7 @@ pub mod mime { pub use mime_crate::*; } +/* #[allow(unconditional_recursion)] fn _assert_send() { _assert_send::(); @@ -216,3 +99,4 @@ fn _assert_sync() { _assert_sync::(); _assert_sync::(); } +*/ diff --git a/src/method.rs b/src/method.rs index 10d5569119..61768f9aa5 100644 --- a/src/method.rs +++ b/src/method.rs @@ -128,6 +128,12 @@ impl fmt::Display for Method { } } +impl Default for Method { + fn default() -> Method { + Method::Get + } +} + #[cfg(feature = "serde-serialization")] impl Serialize for Method { fn serialize(&self, serializer: &mut S) -> Result<(), S::Error> where S: Serializer { diff --git a/src/mock.rs b/src/mock.rs index ac70a5159e..e73d2d53ad 100644 --- a/src/mock.rs +++ b/src/mock.rs @@ -1,327 +1,127 @@ -use std::ascii::AsciiExt; -use std::io::{self, Read, Write, Cursor}; -use std::cell::RefCell; -use std::net::{SocketAddr, Shutdown}; -use std::sync::{Arc, Mutex}; -use std::time::Duration; -use std::cell::Cell; +use std::cmp; +use std::io::{self, Read, Write}; -use solicit::http::HttpScheme; -use solicit::http::transport::TransportStream; -use solicit::http::frame::{SettingsFrame, Frame}; -use solicit::http::connection::{HttpConnection, EndStream, DataChunk}; - -use header::Headers; -use net::{NetworkStream, NetworkConnector, SslClient}; - -#[derive(Clone, Debug)] -pub struct MockStream { - pub read: Cursor>, - next_reads: Vec>, - pub write: Vec, - pub is_closed: bool, - pub error_on_write: bool, - pub error_on_read: bool, - pub read_timeout: Cell>, - pub write_timeout: Cell>, -} - -impl PartialEq for MockStream { - fn eq(&self, other: &MockStream) -> bool { - self.read.get_ref() == other.read.get_ref() && self.write == other.write - } +#[derive(Debug)] +pub struct Buf { + vec: Vec } -impl MockStream { - pub fn new() -> MockStream { - MockStream::with_input(b"") - } - - pub fn with_input(input: &[u8]) -> MockStream { - MockStream::with_responses(vec![input]) - } - - pub fn with_responses(mut responses: Vec<&[u8]>) -> MockStream { - MockStream { - read: Cursor::new(responses.remove(0).to_vec()), - next_reads: responses.into_iter().map(|arr| arr.to_vec()).collect(), - write: vec![], - is_closed: false, - error_on_write: false, - error_on_read: false, - read_timeout: Cell::new(None), - write_timeout: Cell::new(None), +impl Buf { + pub fn new() -> Buf { + Buf { + vec: vec![] } } } -impl Read for MockStream { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - if self.error_on_read { - Err(io::Error::new(io::ErrorKind::Other, "mock error")) - } else { - match self.read.read(buf) { - Ok(n) => { - if self.read.position() as usize == self.read.get_ref().len() { - if self.next_reads.len() > 0 { - self.read = Cursor::new(self.next_reads.remove(0)); - } - } - Ok(n) - }, - r => r - } - } - } -} +impl ::std::ops::Deref for Buf { + type Target = [u8]; -impl Write for MockStream { - fn write(&mut self, msg: &[u8]) -> io::Result { - if self.error_on_write { - Err(io::Error::new(io::ErrorKind::Other, "mock error")) - } else { - Write::write(&mut self.write, msg) - } - } - - fn flush(&mut self) -> io::Result<()> { - Ok(()) + fn deref(&self) -> &[u8] { + &self.vec } } -impl NetworkStream for MockStream { - fn peer_addr(&mut self) -> io::Result { - Ok("127.0.0.1:1337".parse().unwrap()) - } - - fn set_read_timeout(&self, dur: Option) -> io::Result<()> { - self.read_timeout.set(dur); - Ok(()) - } - - fn set_write_timeout(&self, dur: Option) -> io::Result<()> { - self.write_timeout.set(dur); - Ok(()) - } - - fn close(&mut self, _how: Shutdown) -> io::Result<()> { - self.is_closed = true; - Ok(()) +impl> PartialEq for Buf { + fn eq(&self, other: &S) -> bool { + self.vec == other.as_ref() } } -/// A wrapper around a `MockStream` that allows one to clone it and keep an independent copy to the -/// same underlying stream. -#[derive(Clone)] -pub struct CloneableMockStream { - pub inner: Arc>, -} - -impl Write for CloneableMockStream { - fn write(&mut self, msg: &[u8]) -> io::Result { - self.inner.lock().unwrap().write(msg) +impl Write for Buf { + fn write(&mut self, data: &[u8]) -> io::Result { + self.vec.extend(data); + Ok(data.len()) } fn flush(&mut self) -> io::Result<()> { - self.inner.lock().unwrap().flush() + Ok(()) } } -impl Read for CloneableMockStream { +impl Read for Buf { fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.inner.lock().unwrap().read(buf) + (&*self.vec).read(buf) } } -impl TransportStream for CloneableMockStream { - fn try_split(&self) -> Result { - Ok(self.clone()) - } +impl ::vecio::Writev for Buf { + fn writev(&mut self, bufs: &[&[u8]]) -> io::Result { + let cap = bufs.iter().map(|buf| buf.len()).fold(0, |total, next| total + next); + let mut vec = Vec::with_capacity(cap); + for &buf in bufs { + vec.extend(buf); + } - fn close(&mut self) -> Result<(), io::Error> { - Ok(()) + self.write(&vec) } } -impl NetworkStream for CloneableMockStream { - fn peer_addr(&mut self) -> io::Result { - self.inner.lock().unwrap().peer_addr() - } - - fn set_read_timeout(&self, dur: Option) -> io::Result<()> { - self.inner.lock().unwrap().set_read_timeout(dur) - } - - fn set_write_timeout(&self, dur: Option) -> io::Result<()> { - self.inner.lock().unwrap().set_write_timeout(dur) - } - - fn close(&mut self, how: Shutdown) -> io::Result<()> { - NetworkStream::close(&mut *self.inner.lock().unwrap(), how) - } +#[derive(Debug)] +pub struct Async { + inner: T, + bytes_until_block: usize, } -impl CloneableMockStream { - pub fn with_stream(stream: MockStream) -> CloneableMockStream { - CloneableMockStream { - inner: Arc::new(Mutex::new(stream)), +impl Async { + pub fn new(inner: T, bytes: usize) -> Async { + Async { + inner: inner, + bytes_until_block: bytes } } -} -pub struct MockConnector; - -impl NetworkConnector for MockConnector { - type Stream = MockStream; - - fn connect(&self, _host: &str, _port: u16, _scheme: &str) -> ::Result { - Ok(MockStream::new()) + pub fn block_in(&mut self, bytes: usize) { + self.bytes_until_block = bytes; } } -/// new connectors must be created if you wish to intercept requests. -macro_rules! mock_connector ( - ($name:ident { - $($url:expr => $res:expr)* - }) => ( - - struct $name; - - impl $crate::net::NetworkConnector for $name { - type Stream = ::mock::MockStream; - fn connect(&self, host: &str, port: u16, scheme: &str) - -> $crate::Result<::mock::MockStream> { - use std::collections::HashMap; - debug!("MockStream::connect({:?}, {:?}, {:?})", host, port, scheme); - let mut map = HashMap::new(); - $(map.insert($url, $res);)* - - - let key = format!("{}://{}", scheme, host); - // ignore port for now - match map.get(&*key) { - Some(&res) => Ok($crate::mock::MockStream::with_input(res.as_bytes())), - None => panic!("{:?} doesn't know url {}", stringify!($name), key) - } - } - } - - ); - - ($name:ident { $($response:expr),+ }) => ( - struct $name; - - impl $crate::net::NetworkConnector for $name { - type Stream = $crate::mock::MockStream; - fn connect(&self, _: &str, _: u16, _: &str) - -> $crate::Result<$crate::mock::MockStream> { - Ok($crate::mock::MockStream::with_responses(vec![ - $($response),+ - ])) - } +impl Read for Async { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + if self.bytes_until_block == 0 { + Err(io::Error::new(io::ErrorKind::WouldBlock, "mock block")) + } else { + let n = cmp::min(self.bytes_until_block, buf.len()); + let n = try!(self.inner.read(&mut buf[..n])); + self.bytes_until_block -= n; + Ok(n) } - ); -); - -impl TransportStream for MockStream { - fn try_split(&self) -> Result { - Ok(self.clone()) - } - - fn close(&mut self) -> Result<(), io::Error> { - Ok(()) } } -impl MockStream { - /// Creates a new `MockStream` that will return the response described by the parameters as an - /// HTTP/2 response. This will also include the correct server preface. - pub fn new_http2_response(status: &[u8], headers: &Headers, body: Option>) - -> MockStream { - let resp_bytes = build_http2_response(status, headers, body); - MockStream::with_input(&resp_bytes) +impl Write for Async { + fn write(&mut self, data: &[u8]) -> io::Result { + if self.bytes_until_block == 0 { + Err(io::Error::new(io::ErrorKind::WouldBlock, "mock block")) + } else { + let n = cmp::min(self.bytes_until_block, data.len()); + let n = try!(self.inner.write(&data[..n])); + self.bytes_until_block -= n; + Ok(n) + } } -} -/// Builds up a sequence of bytes that represent a server's response based on the given parameters. -pub fn build_http2_response(status: &[u8], headers: &Headers, body: Option>) -> Vec { - let mut conn = HttpConnection::new(MockStream::new(), MockStream::new(), HttpScheme::Http); - // Server preface first - conn.sender.write(&SettingsFrame::new().serialize()).unwrap(); - - let mut resp_headers: Vec<_> = headers.iter().map(|h| { - (h.name().to_ascii_lowercase().into_bytes(), h.value_string().into_bytes()) - }).collect(); - resp_headers.insert(0, (b":status".to_vec(), status.into())); - - let end = if body.is_none() { - EndStream::Yes - } else { - EndStream::No - }; - conn.send_headers(resp_headers, 1, end).unwrap(); - if body.is_some() { - let chunk = DataChunk::new_borrowed(&body.as_ref().unwrap()[..], 1, EndStream::Yes); - conn.send_data(chunk).unwrap(); + fn flush(&mut self) -> io::Result<()> { + self.inner.flush() } - - conn.sender.write } -/// A mock connector that produces `MockStream`s that are set to return HTTP/2 responses. -/// -/// This means that the streams' payloads are fairly opaque byte sequences (as HTTP/2 is a binary -/// protocol), which can be understood only be HTTP/2 clients. -pub struct MockHttp2Connector { - /// The list of streams that the connector returns, in the given order. - pub streams: RefCell>, -} - -impl MockHttp2Connector { - /// Creates a new `MockHttp2Connector` with no streams. - pub fn new() -> MockHttp2Connector { - MockHttp2Connector { - streams: RefCell::new(Vec::new()), +impl ::vecio::Writev for Async { + fn writev(&mut self, bufs: &[&[u8]]) -> io::Result { + let cap = bufs.iter().map(|buf| buf.len()).fold(0, |total, next| total + next); + let mut vec = Vec::with_capacity(cap); + for &buf in bufs { + vec.extend(buf); } - } - - /// Adds a new `CloneableMockStream` to the end of the connector's stream queue. - /// - /// Streams are returned in a FIFO manner. - pub fn add_stream(&mut self, stream: CloneableMockStream) { - self.streams.borrow_mut().push(stream); - } - - /// Adds a new response stream that will be placed to the end of the connector's stream queue. - /// - /// Returns a separate `CloneableMockStream` that allows the user to inspect what is written - /// into the original stream. - pub fn new_response_stream(&mut self, status: &[u8], headers: &Headers, body: Option>) - -> CloneableMockStream { - let stream = MockStream::new_http2_response(status, headers, body); - let stream = CloneableMockStream::with_stream(stream); - let ret = stream.clone(); - self.add_stream(stream); - - ret - } -} -impl NetworkConnector for MockHttp2Connector { - type Stream = CloneableMockStream; - #[inline] - fn connect(&self, _host: &str, _port: u16, _scheme: &str) - -> ::Result { - Ok(self.streams.borrow_mut().remove(0)) + self.write(&vec) } } -#[derive(Debug, Default)] -pub struct MockSsl; +impl ::std::ops::Deref for Async { + type Target = [u8]; -impl SslClient for MockSsl { - type Stream = T; - fn wrap_client(&self, stream: T, _host: &str) -> ::Result { - Ok(stream) + fn deref(&self) -> &[u8] { + &self.inner } } diff --git a/src/net.rs b/src/net.rs index 7699c34c08..57b18d1cd1 100644 --- a/src/net.rs +++ b/src/net.rs @@ -1,412 +1,169 @@ //! A collection of traits abstracting over Listeners and Streams. -use std::any::{Any, TypeId}; -use std::fmt; -use std::io::{self, ErrorKind, Read, Write}; -use std::net::{SocketAddr, ToSocketAddrs, TcpStream, TcpListener, Shutdown}; -use std::mem; +use std::io::{self, Read, Write}; +use std::net::{SocketAddr}; -#[cfg(feature = "openssl")] -pub use self::openssl::{Openssl, OpensslClient}; - -use std::time::Duration; - -use typeable::Typeable; -use traitobject; - -/// The write-status indicating headers have not been written. -pub enum Fresh {} +use rotor::mio::tcp::{TcpStream, TcpListener}; +use rotor::mio::{Selector, Token, Evented, EventSet, PollOpt, TryAccept}; -/// The write-status indicating headers have been written. -pub enum Streaming {} - -/// An abstraction to listen for connections on a certain port. -pub trait NetworkListener: Clone { - /// The stream produced for each connection. - type Stream: NetworkStream + Send + Clone; - - /// Returns an iterator of streams. - fn accept(&mut self) -> ::Result; - - /// Get the address this Listener ended up listening on. - fn local_addr(&mut self) -> io::Result; +#[cfg(feature = "openssl")] +pub use self::openssl::{Openssl, OpensslStream}; - /// Returns an iterator over incoming connections. - fn incoming(&mut self) -> NetworkConnections { - NetworkConnections(self) - } -} +#[cfg(feature = "security-framework")] +pub use self::security_framework::{SecureTransport, SecureTransportClient, SecureTransportServer}; -/// An iterator wrapper over a `NetworkAcceptor`. -pub struct NetworkConnections<'a, N: NetworkListener + 'a>(&'a mut N); +/// A trait representing a socket transport that can be used in a Client or Server. +#[cfg(not(windows))] +pub trait Transport: Read + Write + Evented + ::vecio::Writev { + /// Takes a socket error when event polling notices an `events.is_error()`. + fn take_socket_error(&mut self) -> io::Result<()>; -impl<'a, N: NetworkListener + 'a> Iterator for NetworkConnections<'a, N> { - type Item = ::Result; - fn next(&mut self) -> Option<::Result> { - Some(self.0.accept()) + /// Returns if the this transport is blocked on read or write. + /// + /// By default, the user will declare whether they wish to wait on read + /// or write events. However, some transports, such as those protected by + /// TLS, may be blocked on reading before it can write, or vice versa. + fn blocked(&self) -> Option { + None } } -/// An abstraction over streams that a `Server` can utilize. -pub trait NetworkStream: Read + Write + Any + Send + Typeable { - /// Get the remote address of the underlying connection. - fn peer_addr(&mut self) -> io::Result; - - /// Set the maximum time to wait for a read to complete. - fn set_read_timeout(&self, dur: Option) -> io::Result<()>; - - /// Set the maximum time to wait for a write to complete. - fn set_write_timeout(&self, dur: Option) -> io::Result<()>; - - /// This will be called when Stream should no longer be kept alive. - #[inline] - fn close(&mut self, _how: Shutdown) -> io::Result<()> { - Ok(()) - } - - // Unsure about name and implementation... - - #[doc(hidden)] - fn set_previous_response_expected_no_content(&mut self, _expected: bool) { } +/// A trait representing a socket transport that can be used in a Client or Server. +#[cfg(windows)] +pub trait Transport: Read + Write + Evented { + /// Takes a socket error when event polling notices an `events.is_error()`. + fn take_socket_error(&mut self) -> io::Result<()>; - #[doc(hidden)] - fn previous_response_expected_no_content(&self) -> bool { - false + /// Returns if the this transport is blocked on read or write. + /// + /// By default, the user will declare whether they wish to wait on read + /// or write events. However, some transports, such as those protected by + /// TLS, may be blocked on reading before it can write, or vice versa. + fn blocked(&self) -> Option { + None } } -/// A connector creates a NetworkStream. -pub trait NetworkConnector { - /// Type of `Stream` to create - type Stream: Into>; - - /// Connect to a remote address. - fn connect(&self, host: &str, port: u16, scheme: &str) -> ::Result; -} - -impl From for Box { - fn from(s: T) -> Box { - Box::new(s) - } +/// Declares when a transport is blocked from any further action, until the +/// corresponding event has occured. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum Blocked { + /// Blocked on reading + Read, + /// blocked on writing + Write, } -impl fmt::Debug for Box { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - fmt.pad("Box") +impl Transport for HttpStream { + fn take_socket_error(&mut self) -> io::Result<()> { + self.0.take_socket_error() } } -impl NetworkStream { - unsafe fn downcast_ref_unchecked(&self) -> &T { - mem::transmute(traitobject::data(self)) - } - - unsafe fn downcast_mut_unchecked(&mut self) -> &mut T { - mem::transmute(traitobject::data_mut(self)) - } - - unsafe fn downcast_unchecked(self: Box) -> Box { - let raw: *mut NetworkStream = mem::transmute(self); - mem::transmute(traitobject::data_mut(raw)) - } +/// Accepts sockets asynchronously. +pub trait Accept: Evented { + /// The transport type that is accepted. + type Output: Transport; + /// Accept a socket from the listener, if it doesn not block. + fn accept(&self) -> io::Result>; + /// Return the local `SocketAddr` of this listener. + fn local_addr(&self) -> io::Result; } -impl NetworkStream { - /// Is the underlying type in this trait object a `T`? - #[inline] - pub fn is(&self) -> bool { - (*self).get_type() == TypeId::of::() - } +/// An alias to `mio::tcp::TcpStream`. +#[derive(Debug)] +pub struct HttpStream(pub TcpStream); - /// If the underlying type is `T`, get a reference to the contained data. +impl Read for HttpStream { #[inline] - pub fn downcast_ref(&self) -> Option<&T> { - if self.is::() { - Some(unsafe { self.downcast_ref_unchecked() }) - } else { - None - } + fn read(&mut self, buf: &mut [u8]) -> io::Result { + self.0.read(buf) } +} - /// If the underlying type is `T`, get a mutable reference to the contained - /// data. +impl Write for HttpStream { #[inline] - pub fn downcast_mut(&mut self) -> Option<&mut T> { - if self.is::() { - Some(unsafe { self.downcast_mut_unchecked() }) - } else { - None - } + fn write(&mut self, buf: &[u8]) -> io::Result { + self.0.write(buf) } - /// If the underlying type is `T`, extract it. #[inline] - pub fn downcast(self: Box) - -> Result, Box> { - if self.is::() { - Ok(unsafe { self.downcast_unchecked() }) - } else { - Err(self) - } - } -} - -impl NetworkStream + Send { - unsafe fn downcast_ref_unchecked(&self) -> &T { - mem::transmute(traitobject::data(self)) - } - - unsafe fn downcast_mut_unchecked(&mut self) -> &mut T { - mem::transmute(traitobject::data_mut(self)) - } - - unsafe fn downcast_unchecked(self: Box) -> Box { - let raw: *mut NetworkStream = mem::transmute(self); - mem::transmute(traitobject::data_mut(raw)) + fn flush(&mut self) -> io::Result<()> { + self.0.flush() } } -impl NetworkStream + Send { - /// Is the underlying type in this trait object a `T`? - #[inline] - pub fn is(&self) -> bool { - (*self).get_type() == TypeId::of::() - } - - /// If the underlying type is `T`, get a reference to the contained data. +impl Evented for HttpStream { #[inline] - pub fn downcast_ref(&self) -> Option<&T> { - if self.is::() { - Some(unsafe { self.downcast_ref_unchecked() }) - } else { - None - } + fn register(&self, selector: &mut Selector, token: Token, interest: EventSet, opts: PollOpt) -> io::Result<()> { + self.0.register(selector, token, interest, opts) } - /// If the underlying type is `T`, get a mutable reference to the contained - /// data. #[inline] - pub fn downcast_mut(&mut self) -> Option<&mut T> { - if self.is::() { - Some(unsafe { self.downcast_mut_unchecked() }) - } else { - None - } + fn reregister(&self, selector: &mut Selector, token: Token, interest: EventSet, opts: PollOpt) -> io::Result<()> { + self.0.reregister(selector, token, interest, opts) } - /// If the underlying type is `T`, extract it. #[inline] - pub fn downcast(self: Box) - -> Result, Box> { - if self.is::() { - Ok(unsafe { self.downcast_unchecked() }) - } else { - Err(self) - } + fn deregister(&self, selector: &mut Selector) -> io::Result<()> { + self.0.deregister(selector) } } -/// A `NetworkListener` for `HttpStream`s. -pub struct HttpListener(TcpListener); - -impl Clone for HttpListener { +#[cfg(not(windows))] +impl ::vecio::Writev for HttpStream { #[inline] - fn clone(&self) -> HttpListener { - HttpListener(self.0.try_clone().unwrap()) + fn writev(&mut self, bufs: &[&[u8]]) -> io::Result { + use ::vecio::Rawv; + self.0.writev(bufs) } } -impl From for HttpListener { - fn from(listener: TcpListener) -> HttpListener { - HttpListener(listener) - } -} +/// An alias to `mio::tcp::TcpListener`. +#[derive(Debug)] +pub struct HttpListener(pub TcpListener); impl HttpListener { - /// Start listening to an address over HTTP. - pub fn new(addr: To) -> ::Result { - Ok(HttpListener(try!(TcpListener::bind(addr)))) - } -} - -impl NetworkListener for HttpListener { - type Stream = HttpStream; - - #[inline] - fn accept(&mut self) -> ::Result { - Ok(HttpStream(try!(self.0.accept()).0)) - } - - #[inline] - fn local_addr(&mut self) -> io::Result { - self.0.local_addr() - } -} - -#[cfg(windows)] -impl ::std::os::windows::io::AsRawSocket for HttpListener { - fn as_raw_socket(&self) -> ::std::os::windows::io::RawSocket { - self.0.as_raw_socket() - } -} - -#[cfg(windows)] -impl ::std::os::windows::io::FromRawSocket for HttpListener { - unsafe fn from_raw_socket(sock: ::std::os::windows::io::RawSocket) -> HttpListener { - HttpListener(TcpListener::from_raw_socket(sock)) + /// Bind to a socket address. + pub fn bind(addr: &SocketAddr) -> io::Result { + TcpListener::bind(addr) + .map(HttpListener) } -} -#[cfg(unix)] -impl ::std::os::unix::io::AsRawFd for HttpListener { - fn as_raw_fd(&self) -> ::std::os::unix::io::RawFd { - self.0.as_raw_fd() + /// Try to duplicate the underlying listening socket. + pub fn try_clone(&self) -> io::Result { + self.0.try_clone().map(HttpListener) } } -#[cfg(unix)] -impl ::std::os::unix::io::FromRawFd for HttpListener { - unsafe fn from_raw_fd(fd: ::std::os::unix::io::RawFd) -> HttpListener { - HttpListener(TcpListener::from_raw_fd(fd)) - } -} -/// A wrapper around a `TcpStream`. -pub struct HttpStream(pub TcpStream); +impl Accept for HttpListener { + type Output = HttpStream; -impl Clone for HttpStream { #[inline] - fn clone(&self) -> HttpStream { - HttpStream(self.0.try_clone().unwrap()) - } -} - -impl fmt::Debug for HttpStream { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.write_str("HttpStream(_)") + fn accept(&self) -> io::Result> { + TryAccept::accept(&self.0).map(|ok| ok.map(HttpStream)) } -} -impl Read for HttpStream { #[inline] - fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.0.read(buf) - } -} - -impl Write for HttpStream { - #[inline] - fn write(&mut self, msg: &[u8]) -> io::Result { - self.0.write(msg) - } - #[inline] - fn flush(&mut self) -> io::Result<()> { - self.0.flush() - } -} - -#[cfg(windows)] -impl ::std::os::windows::io::AsRawSocket for HttpStream { - fn as_raw_socket(&self) -> ::std::os::windows::io::RawSocket { - self.0.as_raw_socket() - } -} - -#[cfg(windows)] -impl ::std::os::windows::io::FromRawSocket for HttpStream { - unsafe fn from_raw_socket(sock: ::std::os::windows::io::RawSocket) -> HttpStream { - HttpStream(TcpStream::from_raw_socket(sock)) - } -} - -#[cfg(unix)] -impl ::std::os::unix::io::AsRawFd for HttpStream { - fn as_raw_fd(&self) -> ::std::os::unix::io::RawFd { - self.0.as_raw_fd() - } -} - -#[cfg(unix)] -impl ::std::os::unix::io::FromRawFd for HttpStream { - unsafe fn from_raw_fd(fd: ::std::os::unix::io::RawFd) -> HttpStream { - HttpStream(TcpStream::from_raw_fd(fd)) + fn local_addr(&self) -> io::Result { + self.0.local_addr() } } -impl NetworkStream for HttpStream { +impl Evented for HttpListener { #[inline] - fn peer_addr(&mut self) -> io::Result { - self.0.peer_addr() + fn register(&self, selector: &mut Selector, token: Token, interest: EventSet, opts: PollOpt) -> io::Result<()> { + self.0.register(selector, token, interest, opts) } #[inline] - fn set_read_timeout(&self, dur: Option) -> io::Result<()> { - self.0.set_read_timeout(dur) + fn reregister(&self, selector: &mut Selector, token: Token, interest: EventSet, opts: PollOpt) -> io::Result<()> { + self.0.reregister(selector, token, interest, opts) } #[inline] - fn set_write_timeout(&self, dur: Option) -> io::Result<()> { - self.0.set_write_timeout(dur) - } - - #[inline] - fn close(&mut self, how: Shutdown) -> io::Result<()> { - match self.0.shutdown(how) { - Ok(_) => Ok(()), - // see https://github.com/hyperium/hyper/issues/508 - Err(ref e) if e.kind() == ErrorKind::NotConnected => Ok(()), - err => err - } - } -} - -/// A connector that will produce HttpStreams. -#[derive(Debug, Clone, Default)] -pub struct HttpConnector; - -impl NetworkConnector for HttpConnector { - type Stream = HttpStream; - - fn connect(&self, host: &str, port: u16, scheme: &str) -> ::Result { - let addr = &(host, port); - Ok(try!(match scheme { - "http" => { - debug!("http scheme"); - Ok(HttpStream(try!(TcpStream::connect(addr)))) - }, - _ => { - Err(io::Error::new(io::ErrorKind::InvalidInput, - "Invalid scheme for Http")) - } - })) - } -} - -/// A closure as a connector used to generate `TcpStream`s per request -/// -/// # Example -/// -/// Basic example: -/// -/// ```norun -/// Client::with_connector(|addr: &str, port: u16, scheme: &str| { -/// TcpStream::connect(&(addr, port)) -/// }); -/// ``` -/// -/// Example using `TcpBuilder` from the net2 crate if you want to configure your source socket: -/// -/// ```norun -/// Client::with_connector(|addr: &str, port: u16, scheme: &str| { -/// let b = try!(TcpBuilder::new_v4()); -/// try!(b.bind("127.0.0.1:0")); -/// b.connect(&(addr, port)) -/// }); -/// ``` -impl NetworkConnector for F where F: Fn(&str, u16, &str) -> io::Result { - type Stream = HttpStream; - - fn connect(&self, host: &str, port: u16, scheme: &str) -> ::Result { - Ok(HttpStream(try!((*self)(host, port, scheme)))) + fn deregister(&self, selector: &mut Selector) -> io::Result<()> { + self.0.deregister(selector) } } @@ -415,7 +172,7 @@ impl NetworkConnector for F where F: Fn(&str, u16, &str) -> io::Result ::Result; /// Wrap a server stream with SSL. @@ -423,22 +180,22 @@ pub trait Ssl { } /// An abstraction to allow any SSL implementation to be used with client-side HttpsStreams. -pub trait SslClient { +pub trait SslClient { /// The protected stream. - type Stream: NetworkStream + Send + Clone; + type Stream: Transport; /// Wrap a client stream with SSL. - fn wrap_client(&self, stream: T, host: &str) -> ::Result; + fn wrap_client(&self, stream: HttpStream, host: &str) -> ::Result; } /// An abstraction to allow any SSL implementation to be used with server-side HttpsStreams. -pub trait SslServer { +pub trait SslServer { /// The protected stream. - type Stream: NetworkStream + Send + Clone; + type Stream: Transport; /// Wrap a server stream with SSL. - fn wrap_server(&self, stream: T) -> ::Result; + fn wrap_server(&self, stream: HttpStream) -> ::Result; } -impl SslClient for S { +impl SslClient for S { type Stream = ::Stream; fn wrap_client(&self, stream: HttpStream, host: &str) -> ::Result { @@ -446,7 +203,7 @@ impl SslClient for S { } } -impl SslServer for S { +impl SslServer for S { type Stream = ::Stream; fn wrap_server(&self, stream: HttpStream) -> ::Result { @@ -454,16 +211,16 @@ impl SslServer for S { } } -/// A stream over the HTTP protocol, possibly protected by SSL. -#[derive(Debug, Clone)] -pub enum HttpsStream { +/// A stream over the HTTP protocol, possibly protected by TLS. +#[derive(Debug)] +pub enum HttpsStream { /// A plain text stream. Http(HttpStream), - /// A stream protected by SSL. + /// A stream protected by TLS. Https(S) } -impl Read for HttpsStream { +impl Read for HttpsStream { #[inline] fn read(&mut self, buf: &mut [u8]) -> io::Result { match *self { @@ -473,7 +230,7 @@ impl Read for HttpsStream { } } -impl Write for HttpsStream { +impl Write for HttpsStream { #[inline] fn write(&mut self, msg: &[u8]) -> io::Result { match *self { @@ -491,58 +248,100 @@ impl Write for HttpsStream { } } -impl NetworkStream for HttpsStream { +#[cfg(not(windows))] +impl ::vecio::Writev for HttpsStream { #[inline] - fn peer_addr(&mut self) -> io::Result { + fn writev(&mut self, bufs: &[&[u8]]) -> io::Result { match *self { - HttpsStream::Http(ref mut s) => s.peer_addr(), - HttpsStream::Https(ref mut s) => s.peer_addr() + HttpsStream::Http(ref mut s) => s.writev(bufs), + HttpsStream::Https(ref mut s) => s.writev(bufs) } } +} + + +#[cfg(unix)] +impl ::std::os::unix::io::AsRawFd for HttpStream { + #[inline] + fn as_raw_fd(&self) -> ::std::os::unix::io::RawFd { + self.0.as_raw_fd() + } +} +#[cfg(unix)] +impl ::std::os::unix::io::AsRawFd for HttpsStream { #[inline] - fn set_read_timeout(&self, dur: Option) -> io::Result<()> { + fn as_raw_fd(&self) -> ::std::os::unix::io::RawFd { match *self { - HttpsStream::Http(ref inner) => inner.0.set_read_timeout(dur), - HttpsStream::Https(ref inner) => inner.set_read_timeout(dur) + HttpsStream::Http(ref s) => s.as_raw_fd(), + HttpsStream::Https(ref s) => s.as_raw_fd(), } } +} +impl Evented for HttpsStream { #[inline] - fn set_write_timeout(&self, dur: Option) -> io::Result<()> { + fn register(&self, selector: &mut Selector, token: Token, interest: EventSet, opts: PollOpt) -> io::Result<()> { match *self { - HttpsStream::Http(ref inner) => inner.0.set_write_timeout(dur), - HttpsStream::Https(ref inner) => inner.set_write_timeout(dur) + HttpsStream::Http(ref s) => s.register(selector, token, interest, opts), + HttpsStream::Https(ref s) => s.register(selector, token, interest, opts), } } #[inline] - fn close(&mut self, how: Shutdown) -> io::Result<()> { + fn reregister(&self, selector: &mut Selector, token: Token, interest: EventSet, opts: PollOpt) -> io::Result<()> { match *self { - HttpsStream::Http(ref mut s) => s.close(how), - HttpsStream::Https(ref mut s) => s.close(how) + HttpsStream::Http(ref s) => s.reregister(selector, token, interest, opts), + HttpsStream::Https(ref s) => s.reregister(selector, token, interest, opts), + } + } + + #[inline] + fn deregister(&self, selector: &mut Selector) -> io::Result<()> { + match *self { + HttpsStream::Http(ref s) => s.deregister(selector), + HttpsStream::Https(ref s) => s.deregister(selector), + } + } +} + +impl Transport for HttpsStream { + #[inline] + fn take_socket_error(&mut self) -> io::Result<()> { + match *self { + HttpsStream::Http(ref mut s) => s.take_socket_error(), + HttpsStream::Https(ref mut s) => s.take_socket_error(), + } + } + + #[inline] + fn blocked(&self) -> Option { + match *self { + HttpsStream::Http(ref s) => s.blocked(), + HttpsStream::Https(ref s) => s.blocked(), } } } /// A Http Listener over SSL. -#[derive(Clone)] +#[derive(Debug)] pub struct HttpsListener { - listener: HttpListener, + listener: TcpListener, ssl: S, } -impl HttpsListener { +impl HttpsListener { /// Start listening to an address over HTTPS. - pub fn new(addr: To, ssl: S) -> ::Result> { - HttpListener::new(addr).map(|l| HttpsListener { + #[inline] + pub fn new(addr: &SocketAddr, ssl: S) -> io::Result> { + TcpListener::bind(addr).map(|l| HttpsListener { listener: l, ssl: ssl }) } /// Construct an HttpsListener from a bound `TcpListener`. - pub fn with_listener(listener: HttpListener, ssl: S) -> HttpsListener { + pub fn with_listener(listener: TcpListener, ssl: S) -> HttpsListener { HttpsListener { listener: listener, ssl: ssl @@ -550,56 +349,52 @@ impl HttpsListener { } } -impl NetworkListener for HttpsListener { - type Stream = S::Stream; +impl Accept for HttpsListener { + type Output = S::Stream; #[inline] - fn accept(&mut self) -> ::Result { - self.listener.accept().and_then(|s| self.ssl.wrap_server(s)) + fn accept(&self) -> io::Result> { + self.listener.accept().and_then(|s| match s { + Some((s, _)) => self.ssl.wrap_server(HttpStream(s)).map(Some).map_err(|e| { + match e { + ::Error::Io(e) => e, + _ => io::Error::new(io::ErrorKind::Other, e), + + } + }), + None => Ok(None), + }) } #[inline] - fn local_addr(&mut self) -> io::Result { + fn local_addr(&self) -> io::Result { self.listener.local_addr() } } -/// A connector that can protect HTTP streams using SSL. -#[derive(Debug, Default)] -pub struct HttpsConnector { - ssl: S, - connector: C, -} - -impl HttpsConnector { - /// Create a new connector using the provided SSL implementation. - pub fn new(s: S) -> HttpsConnector { - HttpsConnector::with_connector(s, HttpConnector) +impl Evented for HttpsListener { + #[inline] + fn register(&self, selector: &mut Selector, token: Token, interest: EventSet, opts: PollOpt) -> io::Result<()> { + self.listener.register(selector, token, interest, opts) } -} -impl HttpsConnector { - /// Create a new connector using the provided SSL implementation. - pub fn with_connector(s: S, connector: C) -> HttpsConnector { - HttpsConnector { ssl: s, connector: connector } + #[inline] + fn reregister(&self, selector: &mut Selector, token: Token, interest: EventSet, opts: PollOpt) -> io::Result<()> { + self.listener.reregister(selector, token, interest, opts) } -} -impl> NetworkConnector for HttpsConnector { - type Stream = HttpsStream; - - fn connect(&self, host: &str, port: u16, scheme: &str) -> ::Result { - let stream = try!(self.connector.connect(host, port, "http")); - if scheme == "https" { - debug!("https scheme"); - self.ssl.wrap_client(stream, host).map(HttpsStream::Https) - } else { - Ok(HttpsStream::Http(stream)) - } + #[inline] + fn deregister(&self, selector: &mut Selector) -> io::Result<()> { + self.listener.deregister(selector) } } +fn _assert_transport() { + fn _assert() {} + _assert::>(); +} +/* #[cfg(all(not(feature = "openssl"), not(feature = "security-framework")))] #[doc(hidden)] pub type DefaultConnector = HttpConnector; @@ -611,19 +406,24 @@ pub type DefaultConnector = HttpsConnector; #[cfg(all(feature = "security-framework", not(feature = "openssl")))] pub type DefaultConnector = HttpsConnector; +#[doc(hidden)] +pub type DefaultTransport = ::Output; +*/ + #[cfg(feature = "openssl")] mod openssl { - use std::io; - use std::net::{SocketAddr, Shutdown}; + use std::io::{self, Write}; use std::path::Path; - use std::sync::Arc; - use std::time::Duration; - use openssl::ssl::{Ssl, SslContext, SslStream, SslMethod, SSL_VERIFY_NONE, SSL_VERIFY_PEER, SSL_OP_NO_SSLV2, SSL_OP_NO_SSLV3, SSL_OP_NO_COMPRESSION}; + use rotor::mio::{Selector, Token, Evented, EventSet, PollOpt}; + + use openssl::ssl::{Ssl, SslContext, SslStream, SslMethod, SSL_VERIFY_PEER, SSL_OP_NO_SSLV2, SSL_OP_NO_SSLV3, SSL_OP_NO_COMPRESSION}; use openssl::ssl::error::StreamError as SslIoError; use openssl::ssl::error::SslError; + use openssl::ssl::error::Error as OpensslError; use openssl::x509::X509FileType; - use super::{NetworkStream, HttpStream}; + + use super::{HttpStream, Blocked}; /// An implementation of `Ssl` for OpenSSL. /// @@ -634,7 +434,7 @@ mod openssl { /// use hyper::net::Openssl; /// /// let ssl = Openssl::with_cert_and_key("/home/foo/cert", "/home/foo/key").unwrap(); - /// Server::https("0.0.0.0:443", ssl).unwrap(); + /// Server::https(&"0.0.0.0:443".parse().unwrap(), ssl).unwrap(); /// ``` /// /// For complete control, create a `SslContext` with the options you desire @@ -642,7 +442,7 @@ mod openssl { #[derive(Debug, Clone)] pub struct Openssl { /// The `SslContext` from openssl crate. - pub context: Arc + pub context: SslContext } /// A client-specific implementation of OpenSSL. @@ -662,26 +462,28 @@ mod openssl { } - impl super::SslClient for OpensslClient { - type Stream = SslStream; + impl super::SslClient for OpensslClient { + type Stream = OpensslStream; - fn wrap_client(&self, stream: T, host: &str) -> ::Result { + fn wrap_client(&self, stream: HttpStream, host: &str) -> ::Result { let mut ssl = try!(Ssl::new(&self.0)); try!(ssl.set_hostname(host)); let host = host.to_owned(); ssl.set_verify_callback(SSL_VERIFY_PEER, move |p, x| ::openssl_verify::verify_callback(&host, p, x)); - SslStream::connect(ssl, stream).map_err(From::from) + SslStream::connect(ssl, stream) + .map(openssl_stream) + .map_err(From::from) } } impl Default for Openssl { fn default() -> Openssl { Openssl { - context: Arc::new(SslContext::new(SslMethod::Sslv23).unwrap_or_else(|e| { + context: SslContext::new(SslMethod::Sslv23).unwrap_or_else(|e| { // if we cannot create a SslContext, that's because of a // serious problem. just crash. panic!("{}", e) - })) + }) } } } @@ -691,26 +493,27 @@ mod openssl { pub fn with_cert_and_key(cert: C, key: K) -> Result where C: AsRef, K: AsRef { let mut ctx = try!(SslContext::new(SslMethod::Sslv23)); - try!(ctx.set_cipher_list("DEFAULT")); + try!(ctx.set_cipher_list("ALL!EXPORT!EXPORT40!EXPORT56!aNULL!LOW!RC4@STRENGTH")); try!(ctx.set_certificate_file(cert.as_ref(), X509FileType::PEM)); try!(ctx.set_private_key_file(key.as_ref(), X509FileType::PEM)); - ctx.set_verify(SSL_VERIFY_NONE, None); - Ok(Openssl { context: Arc::new(ctx) }) + Ok(Openssl { context: ctx }) } } impl super::Ssl for Openssl { - type Stream = SslStream; + type Stream = OpensslStream; fn wrap_client(&self, stream: HttpStream, host: &str) -> ::Result { let ssl = try!(Ssl::new(&self.context)); try!(ssl.set_hostname(host)); - SslStream::connect(ssl, stream).map_err(From::from) + SslStream::connect(ssl, stream) + .map(openssl_stream) + .map_err(From::from) } fn wrap_server(&self, stream: HttpStream) -> ::Result { - match SslStream::accept(&*self.context, stream) { - Ok(ssl_stream) => Ok(ssl_stream), + match SslStream::accept(&self.context, stream) { + Ok(ssl_stream) => Ok(openssl_stream(ssl_stream)), Err(SslIoError(e)) => { Err(io::Error::new(io::ErrorKind::ConnectionAborted, e).into()) }, @@ -719,159 +522,220 @@ mod openssl { } } - impl NetworkStream for SslStream { + /// A transport protected by OpenSSL. + #[derive(Debug)] + pub struct OpensslStream { + stream: SslStream, + blocked: Option, + } + + fn openssl_stream(inner: SslStream) -> OpensslStream { + OpensslStream { + stream: inner, + blocked: None, + } + } + + impl io::Read for OpensslStream { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + self.blocked = None; + self.stream.ssl_read(buf).or_else(|e| match e { + OpensslError::ZeroReturn => Ok(0), + OpensslError::WantWrite(e) => { + self.blocked = Some(Blocked::Write); + Err(e) + }, + OpensslError::WantRead(e) | OpensslError::Stream(e) => Err(e), + e => Err(io::Error::new(io::ErrorKind::Other, e)) + }) + } + } + + impl io::Write for OpensslStream { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.blocked = None; + self.stream.ssl_write(buf).or_else(|e| match e { + OpensslError::ZeroReturn => Ok(0), + OpensslError::WantRead(e) => { + self.blocked = Some(Blocked::Read); + Err(e) + }, + OpensslError::WantWrite(e) | OpensslError::Stream(e) => Err(e), + e => Err(io::Error::new(io::ErrorKind::Other, e)) + }) + } + + fn flush(&mut self) -> io::Result<()> { + self.stream.flush() + } + } + + + impl Evented for OpensslStream { #[inline] - fn peer_addr(&mut self) -> io::Result { - self.get_mut().peer_addr() + fn register(&self, selector: &mut Selector, token: Token, interest: EventSet, opts: PollOpt) -> io::Result<()> { + self.stream.get_ref().register(selector, token, interest, opts) } #[inline] - fn set_read_timeout(&self, dur: Option) -> io::Result<()> { - self.get_ref().set_read_timeout(dur) + fn reregister(&self, selector: &mut Selector, token: Token, interest: EventSet, opts: PollOpt) -> io::Result<()> { + self.stream.get_ref().reregister(selector, token, interest, opts) } #[inline] - fn set_write_timeout(&self, dur: Option) -> io::Result<()> { - self.get_ref().set_write_timeout(dur) + fn deregister(&self, selector: &mut Selector) -> io::Result<()> { + self.stream.get_ref().deregister(selector) + } + } + + impl ::vecio::Writev for OpensslStream { + fn writev(&mut self, bufs: &[&[u8]]) -> io::Result { + let vec = bufs.concat(); + self.write(&vec) } + } - fn close(&mut self, how: Shutdown) -> io::Result<()> { - self.get_mut().close(how) + impl super::Transport for OpensslStream { + fn take_socket_error(&mut self) -> io::Result<()> { + self.stream.get_mut().take_socket_error() } } } #[cfg(feature = "security-framework")] -pub mod security_framework { - use std::io; - use std::fmt; - use std::sync::{Arc, Mutex}; - use std::net::{Shutdown, SocketAddr}; - use std::time::Duration; - use security_framework::secure_transport::{SslStream, ClientBuilder, ServerBuilder}; +mod security_framework { + use std::io::{self, Read, Write}; use error::Error; - use net::{SslClient, SslServer, HttpStream, NetworkStream}; + use net::{SslClient, SslServer, HttpStream, Transport, Blocked}; + + use security_framework::secure_transport::SslStream; + pub use security_framework::secure_transport::{ClientBuilder as SecureTransportClient, ServerBuilder as SecureTransportServer}; + use rotor::mio::{Selector, Token, Evented, EventSet, PollOpt}; - #[derive(Default)] - pub struct ClientWrapper(ClientBuilder); + impl SslClient for SecureTransportClient { + type Stream = SecureTransport; - impl ClientWrapper { - pub fn new(builder: ClientBuilder) -> ClientWrapper { - ClientWrapper(builder) + fn wrap_client(&self, stream: HttpStream, host: &str) -> ::Result { + match self.handshake(host, journal(stream)) { + Ok(s) => Ok(SecureTransport(s)), + Err(e) => Err(Error::Ssl(e.into())), + } } } - impl SslClient for ClientWrapper { - type Stream = Stream; + impl SslServer for SecureTransportServer { + type Stream = SecureTransport; - fn wrap_client(&self, stream: HttpStream, host: &str) -> ::Result { - match self.0.handshake(host, stream) { - Ok(s) => Ok(Stream(Arc::new(Mutex::new(s)))), + fn wrap_server(&self, stream: HttpStream) -> ::Result { + match self.handshake(journal(stream)) { + Ok(s) => Ok(SecureTransport(s)), Err(e) => Err(Error::Ssl(e.into())), } } } - #[derive(Clone)] - pub struct ServerWrapper(Arc); + /// A transport protected by Security Framework. + #[derive(Debug)] + pub struct SecureTransport(SslStream>); - impl ServerWrapper { - pub fn new(builder: ServerBuilder) -> ServerWrapper { - ServerWrapper(Arc::new(builder)) + impl io::Read for SecureTransport { + #[inline] + fn read(&mut self, buf: &mut [u8]) -> io::Result { + self.0.read(buf) } } - impl SslServer for ServerWrapper { - type Stream = Stream; + impl io::Write for SecureTransport { + #[inline] + fn write(&mut self, buf: &[u8]) -> io::Result { + self.0.write(buf) + } - fn wrap_server(&self, stream: HttpStream) -> ::Result { - match self.0.handshake(stream) { - Ok(s) => Ok(Stream(Arc::new(Mutex::new(s)))), - Err(e) => Err(Error::Ssl(e.into())), - } + #[inline] + fn flush(&mut self) -> io::Result<()> { + self.0.flush() } } - #[derive(Clone)] - pub struct Stream(Arc>>); - impl io::Read for Stream { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.0.lock().unwrap_or_else(|e| e.into_inner()).read(buf) + impl Evented for SecureTransport { + #[inline] + fn register(&self, selector: &mut Selector, token: Token, interest: EventSet, opts: PollOpt) -> io::Result<()> { + self.0.get_ref().inner.register(selector, token, interest, opts) } - fn read_to_end(&mut self, buf: &mut Vec) -> io::Result { - self.0.lock().unwrap_or_else(|e| e.into_inner()).read_to_end(buf) + #[inline] + fn reregister(&self, selector: &mut Selector, token: Token, interest: EventSet, opts: PollOpt) -> io::Result<()> { + self.0.get_ref().inner.reregister(selector, token, interest, opts) } - fn read_to_string(&mut self, buf: &mut String) -> io::Result { - self.0.lock().unwrap_or_else(|e| e.into_inner()).read_to_string(buf) + #[inline] + fn deregister(&self, selector: &mut Selector) -> io::Result<()> { + self.0.get_ref().inner.deregister(selector) } + } - fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> { - self.0.lock().unwrap_or_else(|e| e.into_inner()).read_exact(buf) + impl ::vecio::Writev for SecureTransport { + fn writev(&mut self, bufs: &[&[u8]]) -> io::Result { + let vec = bufs.concat(); + self.write(&vec) } } - impl io::Write for Stream { - fn write(&mut self, buf: &[u8]) -> io::Result { - self.0.lock().unwrap_or_else(|e| e.into_inner()).write(buf) + impl Transport for SecureTransport { + fn take_socket_error(&mut self) -> io::Result<()> { + self.0.get_mut().inner.take_socket_error() } - fn flush(&mut self) -> io::Result<()> { - self.0.lock().unwrap_or_else(|e| e.into_inner()).flush() + fn blocked(&self) -> Option { + self.0.get_ref().blocked } + } - fn write_all(&mut self, buf: &[u8]) -> io::Result<()> { - self.0.lock().unwrap_or_else(|e| e.into_inner()).write_all(buf) - } - fn write_fmt(&mut self, fmt: fmt::Arguments) -> io::Result<()> { - self.0.lock().unwrap_or_else(|e| e.into_inner()).write_fmt(fmt) - } + // Records if this object was blocked on reading or writing. + #[derive(Debug)] + struct Journal { + inner: T, + blocked: Option, } - impl NetworkStream for Stream { - fn peer_addr(&mut self) -> io::Result { - self.0.lock().unwrap_or_else(|e| e.into_inner()).get_mut().peer_addr() + fn journal(inner: T) -> Journal { + Journal { + inner: inner, + blocked: None, } + } - fn set_read_timeout(&self, dur: Option) -> io::Result<()> { - self.0.lock().unwrap_or_else(|e| e.into_inner()).get_mut().set_read_timeout(dur) + impl Read for Journal { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + self.blocked = None; + self.inner.read(buf).map_err(|e| match e.kind() { + io::ErrorKind::WouldBlock => { + self.blocked = Some(Blocked::Read); + e + }, + _ => e + }) } + } - fn set_write_timeout(&self, dur: Option) -> io::Result<()> { - self.0.lock().unwrap_or_else(|e| e.into_inner()).get_mut().set_write_timeout(dur) + impl Write for Journal { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.blocked = None; + self.inner.write(buf).map_err(|e| match e.kind() { + io::ErrorKind::WouldBlock => { + self.blocked = Some(Blocked::Write); + e + }, + _ => e + }) } - fn close(&mut self, how: Shutdown) -> io::Result<()> { - self.0.lock().unwrap_or_else(|e| e.into_inner()).get_mut().close(how) + fn flush(&mut self) -> io::Result<()> { + self.inner.flush() } } -} -#[cfg(test)] -mod tests { - use mock::MockStream; - use super::{NetworkStream}; - - #[test] - fn test_downcast_box_stream() { - // FIXME: Use Type ascription - let stream: Box = Box::new(MockStream::new()); - - let mock = stream.downcast::().ok().unwrap(); - assert_eq!(mock, Box::new(MockStream::new())); - } - - #[test] - fn test_downcast_unchecked_box_stream() { - // FIXME: Use Type ascription - let stream: Box = Box::new(MockStream::new()); - - let mock = unsafe { stream.downcast_unchecked::() }; - assert_eq!(mock, Box::new(MockStream::new())); - } } - diff --git a/src/server/listener.rs b/src/server/listener.rs deleted file mode 100644 index 16c58905d2..0000000000 --- a/src/server/listener.rs +++ /dev/null @@ -1,79 +0,0 @@ -use std::sync::{Arc, mpsc}; -use std::thread; - -use net::NetworkListener; - -pub struct ListenerPool { - acceptor: A -} - -impl ListenerPool { - /// Create a thread pool to manage the acceptor. - pub fn new(acceptor: A) -> ListenerPool { - ListenerPool { acceptor: acceptor } - } - - /// Runs the acceptor pool. Blocks until the acceptors are closed. - /// - /// ## Panics - /// - /// Panics if threads == 0. - pub fn accept(self, work: F, threads: usize) - where F: Fn(A::Stream) + Send + Sync + 'static { - assert!(threads != 0, "Can't accept on 0 threads."); - - let (super_tx, supervisor_rx) = mpsc::channel(); - - let work = Arc::new(work); - - // Begin work. - for _ in 0..threads { - spawn_with(super_tx.clone(), work.clone(), self.acceptor.clone()) - } - - // Monitor for panics. - // FIXME(reem): This won't ever exit since we still have a super_tx handle. - for _ in supervisor_rx.iter() { - spawn_with(super_tx.clone(), work.clone(), self.acceptor.clone()); - } - } -} - -fn spawn_with(supervisor: mpsc::Sender<()>, work: Arc, mut acceptor: A) -where A: NetworkListener + Send + 'static, - F: Fn(::Stream) + Send + Sync + 'static { - thread::spawn(move || { - let _sentinel = Sentinel::new(supervisor, ()); - - loop { - match acceptor.accept() { - Ok(stream) => work(stream), - Err(e) => { - error!("Connection failed: {}", e); - } - } - } - }); -} - -struct Sentinel { - value: Option, - supervisor: mpsc::Sender, -} - -impl Sentinel { - fn new(channel: mpsc::Sender, data: T) -> Sentinel { - Sentinel { - value: Some(data), - supervisor: channel, - } - } -} - -impl Drop for Sentinel { - fn drop(&mut self) { - // Respawn ourselves - let _ = self.supervisor.send(self.value.take().unwrap()); - } -} - diff --git a/src/server/message.rs b/src/server/message.rs new file mode 100644 index 0000000000..05fe7a8c9c --- /dev/null +++ b/src/server/message.rs @@ -0,0 +1,58 @@ +use std::marker::PhantomData; + + +use http::{self, Next}; +use net::Transport; + +use super::{Handler, request, response}; + +/// A MessageHandler for a Server. +/// +/// This should be really thin glue between http::MessageHandler and +/// server::Handler, but largely just providing the proper types one +/// would expect in a Server Handler. +pub struct Message, T: Transport> { + handler: H, + _marker: PhantomData +} + +impl, T: Transport> Message { + pub fn new(handler: H) -> Message { + Message { + handler: handler, + _marker: PhantomData, + } + } +} + +impl, T: Transport> http::MessageHandler for Message { + type Message = http::ServerMessage; + + fn on_incoming(&mut self, head: http::RequestHead) -> Next { + trace!("on_incoming {:?}", head); + let req = request::new(head); + self.handler.on_request(req) + } + + fn on_decode(&mut self, transport: &mut http::Decoder) -> Next { + self.handler.on_request_readable(transport) + } + + fn on_outgoing(&mut self, head: &mut http::MessageHead<::status::StatusCode>) -> Next { + let mut res = response::new(head); + self.handler.on_response(&mut res) + } + + fn on_encode(&mut self, transport: &mut http::Encoder) -> Next { + self.handler.on_response_writable(transport) + } + + fn on_error(&mut self, error: ::Error) -> Next { + self.handler.on_error(error) + } + + fn on_remove(self, transport: T) { + self.handler.on_remove(transport); + } +} + diff --git a/src/server/mod.rs b/src/server/mod.rs index 7299a055b7..1c1f5b7f12 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -1,510 +1,369 @@ //! HTTP Server //! -//! # Server -//! //! A `Server` is created to listen on port, parse HTTP requests, and hand -//! them off to a `Handler`. By default, the Server will listen across multiple -//! threads, but that can be configured to a single thread if preferred. -//! -//! # Handling requests -//! -//! You must pass a `Handler` to the Server that will handle requests. There is -//! a default implementation for `fn`s and closures, allowing you pass one of -//! those easily. -//! -//! -//! ```no_run -//! use hyper::server::{Server, Request, Response}; -//! -//! fn hello(req: Request, res: Response) { -//! // handle things here -//! } -//! -//! Server::http("0.0.0.0:0").unwrap().handle(hello).unwrap(); -//! ``` -//! -//! As with any trait, you can also define a struct and implement `Handler` -//! directly on your own type, and pass that to the `Server` instead. -//! -//! ```no_run -//! use std::sync::Mutex; -//! use std::sync::mpsc::{channel, Sender}; -//! use hyper::server::{Handler, Server, Request, Response}; -//! -//! struct SenderHandler { -//! sender: Mutex> -//! } -//! -//! impl Handler for SenderHandler { -//! fn handle(&self, req: Request, res: Response) { -//! self.sender.lock().unwrap().send("start").unwrap(); -//! } -//! } -//! -//! -//! let (tx, rx) = channel(); -//! Server::http("0.0.0.0:0").unwrap().handle(SenderHandler { -//! sender: Mutex::new(tx) -//! }).unwrap(); -//! ``` -//! -//! Since the `Server` will be listening on multiple threads, the `Handler` -//! must implement `Sync`: any mutable state must be synchronized. -//! -//! ```no_run -//! use std::sync::atomic::{AtomicUsize, Ordering}; -//! use hyper::server::{Server, Request, Response}; -//! -//! let counter = AtomicUsize::new(0); -//! Server::http("0.0.0.0:0").unwrap().handle(move |req: Request, res: Response| { -//! counter.fetch_add(1, Ordering::Relaxed); -//! }).unwrap(); -//! ``` -//! -//! # The `Request` and `Response` pair -//! -//! A `Handler` receives a pair of arguments, a `Request` and a `Response`. The -//! `Request` includes access to the `method`, `uri`, and `headers` of the -//! incoming HTTP request. It also implements `std::io::Read`, in order to -//! read any body, such as with `POST` or `PUT` messages. -//! -//! Likewise, the `Response` includes ways to set the `status` and `headers`, -//! and implements `std::io::Write` to allow writing the response body. -//! -//! ```no_run -//! use std::io; -//! use hyper::server::{Server, Request, Response}; -//! use hyper::status::StatusCode; -//! -//! Server::http("0.0.0.0:0").unwrap().handle(|mut req: Request, mut res: Response| { -//! match req.method { -//! hyper::Post => { -//! io::copy(&mut req, &mut res.start().unwrap()).unwrap(); -//! }, -//! _ => *res.status_mut() = StatusCode::MethodNotAllowed -//! } -//! }).unwrap(); -//! ``` -//! -//! ## An aside: Write Status -//! -//! The `Response` uses a phantom type parameter to determine its write status. -//! What does that mean? In short, it ensures you never write a body before -//! adding all headers, and never add a header after writing some of the body. -//! -//! This is often done in most implementations by include a boolean property -//! on the response, such as `headers_written`, checking that each time the -//! body has something to write, so as to make sure the headers are sent once, -//! and only once. But this has 2 downsides: -//! -//! 1. You are typically never notified that your late header is doing nothing. -//! 2. There's a runtime cost to checking on every write. -//! -//! Instead, hyper handles this statically, or at compile-time. A -//! `Response` includes a `headers_mut()` method, allowing you add more -//! headers. It also does not implement `Write`, so you can't accidentally -//! write early. Once the "head" of the response is correct, you can "send" it -//! out by calling `start` on the `Response`. This will return a new -//! `Response` object, that no longer has `headers_mut()`, but does -//! implement `Write`. +//! them off to a `Handler`. use std::fmt; -use std::io::{self, ErrorKind, BufWriter, Write}; -use std::net::{SocketAddr, ToSocketAddrs}; -use std::thread::{self, JoinHandle}; +use std::net::SocketAddr; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, Ordering}; use std::time::Duration; -use num_cpus; +use rotor::mio::{EventSet, PollOpt}; +use rotor::{self, Scope}; pub use self::request::Request; pub use self::response::Response; -pub use net::{Fresh, Streaming}; - -use Error; -use buffer::BufReader; -use header::{Headers, Expect, Connection}; -use http; -use method::Method; -use net::{NetworkListener, NetworkStream, HttpListener, HttpsListener, Ssl}; -use status::StatusCode; -use uri::RequestUri; -use version::HttpVersion::Http11; - -use self::listener::ListenerPool; +use http::{self, Next}; +use net::{Accept, HttpListener, HttpsListener, SslServer, Transport}; -pub mod request; -pub mod response; -mod listener; +mod request; +mod response; +mod message; -/// A server can listen on a TCP socket. -/// -/// Once listening, it will create a `Request`/`Response` pair for each -/// incoming connection, and hand them to the provided handler. -#[derive(Debug)] -pub struct Server { - listener: L, - timeouts: Timeouts, -} - -#[derive(Clone, Copy, Debug)] -struct Timeouts { - read: Option, - write: Option, - keep_alive: Option, +/// A configured `Server` ready to run. +pub struct ServerLoop where A: Accept, H: HandlerFactory { + inner: Option<(rotor::Loop>, Context)>, } -impl Default for Timeouts { - fn default() -> Timeouts { - Timeouts { - read: None, - write: None, - keep_alive: Some(Duration::from_secs(5)) - } +impl> fmt::Debug for ServerLoop { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.pad("ServerLoop") } } -macro_rules! try_option( - ($e:expr) => {{ - match $e { - Some(v) => v, - None => return None - } - }} -); +/// A Server that can accept incoming network requests. +#[derive(Debug)] +pub struct Server { + listener: T, + keep_alive: bool, + idle_timeout: Duration, + max_sockets: usize, +} -impl Server { - /// Creates a new server with the provided handler. +impl Server where T: Accept, T::Output: Transport { + /// Creates a new server with the provided Listener. #[inline] - pub fn new(listener: L) -> Server { + pub fn new(listener: T) -> Server { Server { listener: listener, - timeouts: Timeouts::default() + keep_alive: true, + idle_timeout: Duration::from_secs(10), + max_sockets: 4096, } } - /// Controls keep-alive for this server. - /// - /// The timeout duration passed will be used to determine how long - /// to keep the connection alive before dropping it. - /// - /// Passing `None` will disable keep-alive. + /// Enables or disables HTTP keep-alive. /// - /// Default is enabled with a 5 second timeout. - #[inline] - pub fn keep_alive(&mut self, timeout: Option) { - self.timeouts.keep_alive = timeout; + /// Default is true. + pub fn keep_alive(mut self, val: bool) -> Server { + self.keep_alive = val; + self } - /// Sets the read timeout for all Request reads. - pub fn set_read_timeout(&mut self, dur: Option) { - self.timeouts.read = dur; + /// Sets how long an idle connection will be kept before closing. + /// + /// Default is 10 seconds. + pub fn idle_timeout(mut self, val: Duration) -> Server { + self.idle_timeout = val; + self } - /// Sets the write timeout for all Response writes. - pub fn set_write_timeout(&mut self, dur: Option) { - self.timeouts.write = dur; + /// Sets the maximum open sockets for this Server. + /// + /// Default is 4096, but most servers can handle much more than this. + pub fn max_sockets(mut self, val: usize) -> Server { + self.max_sockets = val; + self } } -impl Server { - /// Creates a new server that will handle `HttpStream`s. - pub fn http(addr: To) -> ::Result> { - HttpListener::new(addr).map(Server::new) +impl Server { //::Output>> Server { + /// Creates a new HTTP server config listening on the provided address. + pub fn http(addr: &SocketAddr) -> ::Result> { + use ::rotor::mio::tcp::TcpListener; + TcpListener::bind(addr) + .map(HttpListener) + .map(Server::new) + .map_err(From::from) } } -impl Server> { - /// Creates a new server that will handle `HttpStream`s over SSL. + +impl Server> { + /// Creates a new server config that will handle `HttpStream`s over SSL. /// /// You can use any SSL implementation, as long as implements `hyper::net::Ssl`. - pub fn https(addr: A, ssl: S) -> ::Result>> { - HttpsListener::new(addr, ssl).map(Server::new) + pub fn https(addr: &SocketAddr, ssl: S) -> ::Result>> { + HttpsListener::new(addr, ssl) + .map(Server::new) + .map_err(From::from) } } -impl Server { + +impl Server where A::Output: Transport { /// Binds to a socket and starts handling connections. - pub fn handle(self, handler: H) -> ::Result { - self.handle_threads(handler, num_cpus::get() * 5 / 4) - } + pub fn handle(self, factory: H) -> ::Result<(Listening, ServerLoop)> + where H: HandlerFactory { + let addr = try!(self.listener.local_addr()); + let shutdown = Arc::new(AtomicBool::new(false)); + let shutdown_rx = shutdown.clone(); + + let mut config = rotor::Config::new(); + config.slab_capacity(self.max_sockets); + config.mio().notify_capacity(self.max_sockets); + let keep_alive = self.keep_alive; + let mut loop_ = rotor::Loop::new(&config).unwrap(); + let mut notifier = None; + { + let notifier = &mut notifier; + loop_.add_machine_with(move |scope| { + *notifier = Some(scope.notifier()); + rotor_try!(scope.register(&self.listener, EventSet::readable(), PollOpt::level())); + rotor::Response::ok(ServerFsm::Listener::(self.listener, shutdown_rx)) + }).unwrap(); + } + let notifier = notifier.expect("loop.add_machine failed"); - /// Binds to a socket and starts handling connections with the provided - /// number of threads. - pub fn handle_threads(self, handler: H, - threads: usize) -> ::Result { - handle(self, handler, threads) + let listening = Listening { + addr: addr, + shutdown: (shutdown, notifier), + }; + let server = ServerLoop { + inner: Some((loop_, Context { + keep_alive: keep_alive, + factory: factory + })) + }; + Ok((listening, server)) } } -fn handle(mut server: Server, handler: H, threads: usize) -> ::Result -where H: Handler + 'static, L: NetworkListener + Send + 'static { - let socket = try!(server.listener.local_addr()); - debug!("threads = {:?}", threads); - let pool = ListenerPool::new(server.listener); - let worker = Worker::new(handler, server.timeouts); - let work = move |mut stream| worker.handle_connection(&mut stream); - - let guard = thread::spawn(move || pool.accept(work, threads)); - - Ok(Listening { - _guard: Some(guard), - socket: socket, - }) -} - -struct Worker { - handler: H, - timeouts: Timeouts, +impl> ServerLoop { + /// Runs the server forever in this loop. + /// + /// This will block the current thread. + pub fn run(self) { + // drop will take care of it. + } } -impl Worker { - fn new(handler: H, timeouts: Timeouts) -> Worker { - Worker { - handler: handler, - timeouts: timeouts, - } +impl> Drop for ServerLoop { + fn drop(&mut self) { + self.inner.take().map(|(loop_, ctx)| { + let _ = loop_.run(ctx); + }); } +} - fn handle_connection(&self, mut stream: &mut S) where S: NetworkStream + Clone { - debug!("Incoming stream"); - - self.handler.on_connection_start(); - - if let Err(e) = self.set_timeouts(&*stream) { - error!("set_timeouts error: {:?}", e); - return; - } - - let addr = match stream.peer_addr() { - Ok(addr) => addr, - Err(e) => { - error!("Peer Name error: {:?}", e); - return; - } - }; - - // FIXME: Use Type ascription - let stream_clone: &mut NetworkStream = &mut stream.clone(); - let mut rdr = BufReader::new(stream_clone); - let mut wrt = BufWriter::new(stream); - - while self.keep_alive_loop(&mut rdr, &mut wrt, addr) { - if let Err(e) = self.set_read_timeout(*rdr.get_ref(), self.timeouts.keep_alive) { - error!("set_read_timeout keep_alive {:?}", e); - break; - } - } - - self.handler.on_connection_end(); +struct Context { + keep_alive: bool, + factory: F, +} - debug!("keep_alive loop ending for {}", addr); - } +impl, T: Transport> http::MessageHandlerFactory<(), T> for Context { + type Output = message::Message; - fn set_timeouts(&self, s: &NetworkStream) -> io::Result<()> { - try!(self.set_read_timeout(s, self.timeouts.read)); - self.set_write_timeout(s, self.timeouts.write) + fn create(&mut self, seed: http::Seed<()>) -> Self::Output { + message::Message::new(self.factory.create(seed.control())) } +} - fn set_write_timeout(&self, s: &NetworkStream, timeout: Option) -> io::Result<()> { - s.set_write_timeout(timeout) - } +enum ServerFsm +where A: Accept, + A::Output: Transport, + H: HandlerFactory { + Listener(A, Arc), + Conn(http::Conn<(), A::Output, message::Message>) +} - fn set_read_timeout(&self, s: &NetworkStream, timeout: Option) -> io::Result<()> { - s.set_read_timeout(timeout) +impl rotor::Machine for ServerFsm +where A: Accept, + A::Output: Transport, + H: HandlerFactory { + type Context = Context; + type Seed = A::Output; + + fn create(seed: Self::Seed, scope: &mut Scope) -> rotor::Response { + rotor_try!(scope.register(&seed, EventSet::readable(), PollOpt::level())); + rotor::Response::ok( + ServerFsm::Conn( + http::Conn::new((), seed, scope.notifier()) + .keep_alive(scope.keep_alive) + ) + ) } - fn keep_alive_loop(&self, mut rdr: &mut BufReader<&mut NetworkStream>, - wrt: &mut W, addr: SocketAddr) -> bool { - let req = match Request::new(rdr, addr) { - Ok(req) => req, - Err(Error::Io(ref e)) if e.kind() == ErrorKind::ConnectionAborted => { - trace!("tcp closed, cancelling keep-alive loop"); - return false; - } - Err(Error::Io(e)) => { - debug!("ioerror in keepalive loop = {:?}", e); - return false; - } - Err(e) => { - //TODO: send a 400 response - error!("request error = {:?}", e); - return false; + fn ready(self, events: EventSet, scope: &mut Scope) -> rotor::Response { + match self { + ServerFsm::Listener(listener, rx) => { + match listener.accept() { + Ok(Some(conn)) => { + rotor::Response::spawn(ServerFsm::Listener(listener, rx), conn) + }, + Ok(None) => rotor::Response::ok(ServerFsm::Listener(listener, rx)), + Err(e) => { + error!("listener accept error {}", e); + // usually fine, just keep listening + rotor::Response::ok(ServerFsm::Listener(listener, rx)) + } + } + }, + ServerFsm::Conn(conn) => { + match conn.ready(events, scope) { + Some((conn, None)) => rotor::Response::ok(ServerFsm::Conn(conn)), + Some((conn, Some(dur))) => { + rotor::Response::ok(ServerFsm::Conn(conn)) + .deadline(scope.now() + dur) + } + None => rotor::Response::done() + } } - }; - - if !self.handle_expect(&req, wrt) { - return false; - } - - if let Err(e) = req.set_read_timeout(self.timeouts.read) { - error!("set_read_timeout {:?}", e); - return false; - } - - let mut keep_alive = self.timeouts.keep_alive.is_some() && - http::should_keep_alive(req.version, &req.headers); - let version = req.version; - let mut res_headers = Headers::new(); - if !keep_alive { - res_headers.set(Connection::close()); - } - { - let mut res = Response::new(wrt, &mut res_headers); - res.version = version; - self.handler.handle(req, res); } + } - // if the request was keep-alive, we need to check that the server agrees - // if it wasn't, then the server cannot force it to be true anyways - if keep_alive { - keep_alive = http::should_keep_alive(version, &res_headers); + fn spawned(self, _scope: &mut Scope) -> rotor::Response { + match self { + ServerFsm::Listener(listener, rx) => { + match listener.accept() { + Ok(Some(conn)) => { + rotor::Response::spawn(ServerFsm::Listener(listener, rx), conn) + }, + Ok(None) => rotor::Response::ok(ServerFsm::Listener(listener, rx)), + Err(e) => { + error!("listener accept error {}", e); + // usually fine, just keep listening + rotor::Response::ok(ServerFsm::Listener(listener, rx)) + } + } + }, + sock => rotor::Response::ok(sock) } - debug!("keep_alive = {:?} for {}", keep_alive, addr); - keep_alive } - fn handle_expect(&self, req: &Request, wrt: &mut W) -> bool { - if req.version == Http11 && req.headers.get() == Some(&Expect::Continue) { - let status = self.handler.check_continue((&req.method, &req.uri, &req.headers)); - match write!(wrt, "{} {}\r\n\r\n", Http11, status).and_then(|_| wrt.flush()) { - Ok(..) => (), - Err(e) => { - error!("error writing 100-continue: {:?}", e); - return false; + fn timeout(self, scope: &mut Scope) -> rotor::Response { + match self { + ServerFsm::Listener(..) => unreachable!("Listener cannot timeout"), + ServerFsm::Conn(conn) => { + match conn.timeout(scope) { + Some((conn, None)) => rotor::Response::ok(ServerFsm::Conn(conn)), + Some((conn, Some(dur))) => { + rotor::Response::ok(ServerFsm::Conn(conn)) + .deadline(scope.now() + dur) + } + None => rotor::Response::done() } } + } + } - if status != StatusCode::Continue { - debug!("non-100 status ({}) for Expect 100 request", status); - return false; + fn wakeup(self, scope: &mut Scope) -> rotor::Response { + match self { + ServerFsm::Listener(lst, shutdown) => { + if shutdown.load(Ordering::Acquire) { + let _ = scope.deregister(&lst); + scope.shutdown_loop(); + rotor::Response::done() + } else { + rotor::Response::ok(ServerFsm::Listener(lst, shutdown)) + } + }, + ServerFsm::Conn(conn) => match conn.wakeup(scope) { + Some((conn, None)) => rotor::Response::ok(ServerFsm::Conn(conn)), + Some((conn, Some(dur))) => { + rotor::Response::ok(ServerFsm::Conn(conn)) + .deadline(scope.now() + dur) + } + None => rotor::Response::done() } } - - true } } -/// A listening server, which can later be closed. +/// A handle of the running server. pub struct Listening { - _guard: Option>, - /// The socket addresses that the server is bound to. - pub socket: SocketAddr, + addr: SocketAddr, + shutdown: (Arc, rotor::Notifier), } impl fmt::Debug for Listening { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "Listening {{ socket: {:?} }}", self.socket) + f.debug_struct("Listening") + .field("addr", &self.addr) + .field("closed", &self.shutdown.0.load(Ordering::Relaxed)) + .finish() } } -impl Drop for Listening { - fn drop(&mut self) { - let _ = self._guard.take().map(|g| g.join()); +impl fmt::Display for Listening { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fmt::Display::fmt(&self.addr, f) } } impl Listening { - /// Warning: This function doesn't work. The server remains listening after you called - /// it. See https://github.com/hyperium/hyper/issues/338 for more details. - /// + /// The address this server is listening on. + pub fn addr(&self) -> &SocketAddr { + &self.addr + } + /// Stop the server from listening to its socket address. - pub fn close(&mut self) -> ::Result<()> { - let _ = self._guard.take(); - debug!("closing server"); - Ok(()) + pub fn close(self) { + debug!("closing server {}", self); + self.shutdown.0.store(true, Ordering::Release); + self.shutdown.1.wakeup().unwrap(); } } -/// A handler that can handle incoming requests for a server. -pub trait Handler: Sync + Send { - /// Receives a `Request`/`Response` pair, and should perform some action on them. +/// A trait to react to server events that happen for each message. +/// +/// Each event handler returns it's desired `Next` action. +pub trait Handler { + /// This event occurs first, triggering when a `Request` has been parsed. + fn on_request(&mut self, request: Request) -> Next; + /// This event occurs each time the `Request` is ready to be read from. + fn on_request_readable(&mut self, request: &mut http::Decoder) -> Next; + /// This event occurs after the first time this handled signals `Next::write()`. + fn on_response(&mut self, response: &mut Response) -> Next; + /// This event occurs each time the `Response` is ready to be written to. + fn on_response_writable(&mut self, response: &mut http::Encoder) -> Next; + + /// This event occurs whenever an `Error` occurs outside of the other events. /// - /// This could reading from the request, and writing to the response. - fn handle<'a, 'k>(&'a self, Request<'a, 'k>, Response<'a, Fresh>); + /// This could IO errors while waiting for events, or a timeout, etc. + fn on_error(&mut self, err: ::Error) -> Next where Self: Sized { + debug!("default Handler.on_error({:?})", err); + http::Next::remove() + } - /// Called when a Request includes a `Expect: 100-continue` header. - /// - /// By default, this will always immediately response with a `StatusCode::Continue`, - /// but can be overridden with custom behavior. - fn check_continue(&self, _: (&Method, &RequestUri, &Headers)) -> StatusCode { - StatusCode::Continue + /// This event occurs when this Handler has requested to remove the Transport. + fn on_remove(self, _transport: T) where Self: Sized { + debug!("default Handler.on_remove"); } +} - /// This is run after a connection is received, on a per-connection basis (not a - /// per-request basis, as a connection with keep-alive may handle multiple - /// requests) - fn on_connection_start(&self) { } - /// This is run before a connection is closed, on a per-connection basis (not a - /// per-request basis, as a connection with keep-alive may handle multiple - /// requests) - fn on_connection_end(&self) { } +/// Used to create a `Handler` when a new message is received by the server. +pub trait HandlerFactory { + /// The `Handler` to use for the incoming message. + type Output: Handler; + /// Creates the associated `Handler`. + fn create(&mut self, ctrl: http::Control) -> Self::Output; } -impl Handler for F where F: Fn(Request, Response), F: Sync + Send { - fn handle<'a, 'k>(&'a self, req: Request<'a, 'k>, res: Response<'a, Fresh>) { - self(req, res) +impl HandlerFactory for F +where F: FnMut(http::Control) -> H, H: Handler, T: Transport { + type Output = H; + fn create(&mut self, ctrl: http::Control) -> H { + self(ctrl) } } #[cfg(test)] mod tests { - use header::Headers; - use method::Method; - use mock::MockStream; - use status::StatusCode; - use uri::RequestUri; - - use super::{Request, Response, Fresh, Handler, Worker}; - - #[test] - fn test_check_continue_default() { - let mut mock = MockStream::with_input(b"\ - POST /upload HTTP/1.1\r\n\ - Host: example.domain\r\n\ - Expect: 100-continue\r\n\ - Content-Length: 10\r\n\ - \r\n\ - 1234567890\ - "); - - fn handle(_: Request, res: Response) { - res.start().unwrap().end().unwrap(); - } - - Worker::new(handle, Default::default()).handle_connection(&mut mock); - let cont = b"HTTP/1.1 100 Continue\r\n\r\n"; - assert_eq!(&mock.write[..cont.len()], cont); - let res = b"HTTP/1.1 200 OK\r\n"; - assert_eq!(&mock.write[cont.len()..cont.len() + res.len()], res); - } - #[test] - fn test_check_continue_reject() { - struct Reject; - impl Handler for Reject { - fn handle<'a, 'k>(&'a self, _: Request<'a, 'k>, res: Response<'a, Fresh>) { - res.start().unwrap().end().unwrap(); - } - - fn check_continue(&self, _: (&Method, &RequestUri, &Headers)) -> StatusCode { - StatusCode::ExpectationFailed - } - } - - let mut mock = MockStream::with_input(b"\ - POST /upload HTTP/1.1\r\n\ - Host: example.domain\r\n\ - Expect: 100-continue\r\n\ - Content-Length: 10\r\n\ - \r\n\ - 1234567890\ - "); - - Worker::new(Reject, Default::default()).handle_connection(&mut mock); - assert_eq!(mock.write, &b"HTTP/1.1 417 Expectation Failed\r\n\r\n"[..]); - } } diff --git a/src/server/request.rs b/src/server/request.rs index aab33e2762..a6c624d39d 100644 --- a/src/server/request.rs +++ b/src/server/request.rs @@ -2,324 +2,75 @@ //! //! These are requests that a `hyper::Server` receives, and include its method, //! target URI, headers, and message body. -use std::io::{self, Read}; -use std::net::SocketAddr; -use std::time::Duration; +//use std::net::SocketAddr; -use buffer::BufReader; -use net::NetworkStream; -use version::{HttpVersion}; +use version::HttpVersion; use method::Method; -use header::{Headers, ContentLength, TransferEncoding}; -use http::h1::{self, Incoming, HttpReader}; -use http::h1::HttpReader::{SizedReader, ChunkedReader, EmptyReader}; +use header::Headers; +use http::{RequestHead, MessageHead, RequestLine}; use uri::RequestUri; -/// A request bundles several parts of an incoming `NetworkStream`, given to a `Handler`. -pub struct Request<'a, 'b: 'a> { - /// The IP address of the remote connection. - pub remote_addr: SocketAddr, - /// The `Method`, such as `Get`, `Post`, etc. - pub method: Method, - /// The headers of the incoming request. - pub headers: Headers, - /// The target request-uri for this request. - pub uri: RequestUri, - /// The version of HTTP for this request. - pub version: HttpVersion, - body: HttpReader<&'a mut BufReader<&'b mut NetworkStream>> -} +pub fn new(incoming: RequestHead) -> Request { + let MessageHead { version, subject: RequestLine(method, uri), headers } = incoming; + debug!("Request Line: {:?} {:?} {:?}", method, uri, version); + debug!("{:#?}", headers); + Request { + //remote_addr: addr, + method: method, + uri: uri, + headers: headers, + version: version, + } +} -impl<'a, 'b: 'a> Request<'a, 'b> { - /// Create a new Request, reading the StartLine and Headers so they are - /// immediately useful. - pub fn new(mut stream: &'a mut BufReader<&'b mut NetworkStream>, addr: SocketAddr) - -> ::Result> { +/// A request bundles several parts of an incoming `NetworkStream`, given to a `Handler`. +#[derive(Debug)] +pub struct Request { + // The IP address of the remote connection. + //remote_addr: SocketAddr, + method: Method, + headers: Headers, + uri: RequestUri, + version: HttpVersion, +} - let Incoming { version, subject: (method, uri), headers } = try!(h1::parse_request(stream)); - debug!("Request Line: {:?} {:?} {:?}", method, uri, version); - debug!("{:?}", headers); - let body = if headers.has::() { - match headers.get::() { - Some(&ContentLength(len)) => SizedReader(stream, len), - None => unreachable!() - } - } else if headers.has::() { - todo!("check for Transfer-Encoding: chunked"); - ChunkedReader(stream, None) - } else { - EmptyReader(stream) - }; +impl Request { + /// The `Method`, such as `Get`, `Post`, etc. + #[inline] + pub fn method(&self) -> &Method { &self.method } - Ok(Request { - remote_addr: addr, - method: method, - uri: uri, - headers: headers, - version: version, - body: body - }) - } + /// The headers of the incoming request. + #[inline] + pub fn headers(&self) -> &Headers { &self.headers } - /// Set the read timeout of the underlying NetworkStream. + /// The target request-uri for this request. #[inline] - pub fn set_read_timeout(&self, timeout: Option) -> io::Result<()> { - self.body.get_ref().get_ref().set_read_timeout(timeout) - } + pub fn uri(&self) -> &RequestUri { &self.uri } - /// Get a reference to the underlying `NetworkStream`. + /// The version of HTTP for this request. #[inline] - pub fn downcast_ref(&self) -> Option<&T> { - self.body.get_ref().get_ref().downcast_ref() - } + pub fn version(&self) -> &HttpVersion { &self.version } - /// Get a reference to the underlying Ssl stream, if connected - /// over HTTPS. - /// - /// # Example - /// - /// ```rust - /// # extern crate hyper; - /// # #[cfg(feature = "openssl")] - /// extern crate openssl; - /// # #[cfg(feature = "openssl")] - /// use openssl::ssl::SslStream; - /// use hyper::net::HttpStream; - /// # fn main() {} - /// # #[cfg(feature = "openssl")] - /// # fn doc_ssl(req: hyper::server::Request) { - /// let maybe_ssl = req.ssl::>(); - /// # } - /// ``` + /* + /// The target path of this Request. #[inline] - pub fn ssl(&self) -> Option<&T> { - use ::net::HttpsStream; - match self.downcast_ref() { - Some(&HttpsStream::Https(ref s)) => Some(s), + pub fn path(&self) -> Option<&str> { + match *self.uri { + RequestUri::AbsolutePath(ref s) => Some(s), + RequestUri::AbsoluteUri(ref url) => Some(&url[::url::Position::BeforePath..]), _ => None } } + */ - /// Deconstruct a Request into its constituent parts. - #[inline] - pub fn deconstruct(self) -> (SocketAddr, Method, Headers, - RequestUri, HttpVersion, - HttpReader<&'a mut BufReader<&'b mut NetworkStream>>) { - (self.remote_addr, self.method, self.headers, - self.uri, self.version, self.body) - } -} - -impl<'a, 'b> Read for Request<'a, 'b> { + /// Deconstruct this Request into its pieces. + /// + /// Modifying these pieces will have no effect on how hyper behaves. #[inline] - fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.body.read(buf) - } -} - -#[cfg(test)] -mod tests { - use buffer::BufReader; - use header::{Host, TransferEncoding, Encoding}; - use net::NetworkStream; - use mock::MockStream; - use super::Request; - - use std::io::{self, Read}; - use std::net::SocketAddr; - - fn sock(s: &str) -> SocketAddr { - s.parse().unwrap() - } - - fn read_to_string(mut req: Request) -> io::Result { - let mut s = String::new(); - try!(req.read_to_string(&mut s)); - Ok(s) - } - - #[test] - fn test_get_empty_body() { - let mut mock = MockStream::with_input(b"\ - GET / HTTP/1.1\r\n\ - Host: example.domain\r\n\ - \r\n\ - I'm a bad request.\r\n\ - "); - - // FIXME: Use Type ascription - let mock: &mut NetworkStream = &mut mock; - let mut stream = BufReader::new(mock); - - let req = Request::new(&mut stream, sock("127.0.0.1:80")).unwrap(); - assert_eq!(read_to_string(req).unwrap(), "".to_owned()); - } - - #[test] - fn test_get_with_body() { - let mut mock = MockStream::with_input(b"\ - GET / HTTP/1.1\r\n\ - Host: example.domain\r\n\ - Content-Length: 19\r\n\ - \r\n\ - I'm a good request.\r\n\ - "); - - // FIXME: Use Type ascription - let mock: &mut NetworkStream = &mut mock; - let mut stream = BufReader::new(mock); - - let req = Request::new(&mut stream, sock("127.0.0.1:80")).unwrap(); - assert_eq!(read_to_string(req).unwrap(), "I'm a good request.".to_owned()); - } - - #[test] - fn test_head_empty_body() { - let mut mock = MockStream::with_input(b"\ - HEAD / HTTP/1.1\r\n\ - Host: example.domain\r\n\ - \r\n\ - I'm a bad request.\r\n\ - "); - - // FIXME: Use Type ascription - let mock: &mut NetworkStream = &mut mock; - let mut stream = BufReader::new(mock); - - let req = Request::new(&mut stream, sock("127.0.0.1:80")).unwrap(); - assert_eq!(read_to_string(req).unwrap(), "".to_owned()); - } - - #[test] - fn test_post_empty_body() { - let mut mock = MockStream::with_input(b"\ - POST / HTTP/1.1\r\n\ - Host: example.domain\r\n\ - \r\n\ - I'm a bad request.\r\n\ - "); - - // FIXME: Use Type ascription - let mock: &mut NetworkStream = &mut mock; - let mut stream = BufReader::new(mock); - - let req = Request::new(&mut stream, sock("127.0.0.1:80")).unwrap(); - assert_eq!(read_to_string(req).unwrap(), "".to_owned()); - } - - #[test] - fn test_parse_chunked_request() { - let mut mock = MockStream::with_input(b"\ - POST / HTTP/1.1\r\n\ - Host: example.domain\r\n\ - Transfer-Encoding: chunked\r\n\ - \r\n\ - 1\r\n\ - q\r\n\ - 2\r\n\ - we\r\n\ - 2\r\n\ - rt\r\n\ - 0\r\n\ - \r\n" - ); - - // FIXME: Use Type ascription - let mock: &mut NetworkStream = &mut mock; - let mut stream = BufReader::new(mock); - - let req = Request::new(&mut stream, sock("127.0.0.1:80")).unwrap(); - - // The headers are correct? - match req.headers.get::() { - Some(host) => { - assert_eq!("example.domain", host.hostname); - }, - None => panic!("Host header expected!"), - }; - match req.headers.get::() { - Some(encodings) => { - assert_eq!(1, encodings.len()); - assert_eq!(Encoding::Chunked, encodings[0]); - } - None => panic!("Transfer-Encoding: chunked expected!"), - }; - // The content is correctly read? - assert_eq!(read_to_string(req).unwrap(), "qwert".to_owned()); - } - - /// Tests that when a chunk size is not a valid radix-16 number, an error - /// is returned. - #[test] - fn test_invalid_chunk_size_not_hex_digit() { - let mut mock = MockStream::with_input(b"\ - POST / HTTP/1.1\r\n\ - Host: example.domain\r\n\ - Transfer-Encoding: chunked\r\n\ - \r\n\ - X\r\n\ - 1\r\n\ - 0\r\n\ - \r\n" - ); - - // FIXME: Use Type ascription - let mock: &mut NetworkStream = &mut mock; - let mut stream = BufReader::new(mock); - - let req = Request::new(&mut stream, sock("127.0.0.1:80")).unwrap(); - - assert!(read_to_string(req).is_err()); - } - - /// Tests that when a chunk size contains an invalid extension, an error is - /// returned. - #[test] - fn test_invalid_chunk_size_extension() { - let mut mock = MockStream::with_input(b"\ - POST / HTTP/1.1\r\n\ - Host: example.domain\r\n\ - Transfer-Encoding: chunked\r\n\ - \r\n\ - 1 this is an invalid extension\r\n\ - 1\r\n\ - 0\r\n\ - \r\n" - ); - - // FIXME: Use Type ascription - let mock: &mut NetworkStream = &mut mock; - let mut stream = BufReader::new(mock); - - let req = Request::new(&mut stream, sock("127.0.0.1:80")).unwrap(); - - assert!(read_to_string(req).is_err()); - } - - /// Tests that when a valid extension that contains a digit is appended to - /// the chunk size, the chunk is correctly read. - #[test] - fn test_chunk_size_with_extension() { - let mut mock = MockStream::with_input(b"\ - POST / HTTP/1.1\r\n\ - Host: example.domain\r\n\ - Transfer-Encoding: chunked\r\n\ - \r\n\ - 1;this is an extension with a digit 1\r\n\ - 1\r\n\ - 0\r\n\ - \r\n" - ); - - // FIXME: Use Type ascription - let mock: &mut NetworkStream = &mut mock; - let mut stream = BufReader::new(mock); - - let req = Request::new(&mut stream, sock("127.0.0.1:80")).unwrap(); - - assert_eq!(read_to_string(req).unwrap(), "1".to_owned()); + pub fn deconstruct(self) -> (Method, RequestUri, HttpVersion, Headers) { + (self.method, self.uri, self.version, self.headers) } } diff --git a/src/server/response.rs b/src/server/response.rs index e0155c049f..95874781c9 100644 --- a/src/server/response.rs +++ b/src/server/response.rs @@ -2,431 +2,49 @@ //! //! These are responses sent by a `hyper::Server` to clients, after //! receiving a request. -use std::any::{Any, TypeId}; -use std::marker::PhantomData; -use std::mem; -use std::io::{self, Write}; -use std::ptr; -use std::thread; - -use time::now_utc; - use header; -use http::h1::{LINE_ENDING, HttpWriter}; -use http::h1::HttpWriter::{ThroughWriter, ChunkedWriter, SizedWriter, EmptyWriter}; -use status; -use net::{Fresh, Streaming}; +use http; +use status::StatusCode; use version; /// The outgoing half for a Tcp connection, created by a `Server` and given to a `Handler`. /// /// The default `StatusCode` for a `Response` is `200 OK`. -/// -/// There is a `Drop` implementation for `Response` that will automatically -/// write the head and flush the body, if the handler has not already done so, -/// so that the server doesn't accidentally leave dangling requests. #[derive(Debug)] -pub struct Response<'a, W: Any = Fresh> { - /// The HTTP version of this response. - pub version: version::HttpVersion, - // Stream the Response is writing to, not accessible through UnwrittenResponse - body: HttpWriter<&'a mut (Write + 'a)>, - // The status code for the request. - status: status::StatusCode, - // The outgoing headers on this response. - headers: &'a mut header::Headers, - - _writing: PhantomData +pub struct Response<'a> { + head: &'a mut http::MessageHead, } -impl<'a, W: Any> Response<'a, W> { - /// The status of this response. - #[inline] - pub fn status(&self) -> status::StatusCode { self.status } - +impl<'a> Response<'a> { /// The headers of this response. #[inline] - pub fn headers(&self) -> &header::Headers { &*self.headers } - - /// Construct a Response from its constituent parts. - #[inline] - pub fn construct(version: version::HttpVersion, - body: HttpWriter<&'a mut (Write + 'a)>, - status: status::StatusCode, - headers: &'a mut header::Headers) -> Response<'a, Fresh> { - Response { - status: status, - version: version, - body: body, - headers: headers, - _writing: PhantomData, - } - } - - /// Deconstruct this Response into its constituent parts. - #[inline] - pub fn deconstruct(self) -> (version::HttpVersion, HttpWriter<&'a mut (Write + 'a)>, - status::StatusCode, &'a mut header::Headers) { - unsafe { - let parts = ( - self.version, - ptr::read(&self.body), - self.status, - ptr::read(&self.headers) - ); - mem::forget(self); - parts - } - } - - fn write_head(&mut self) -> io::Result { - debug!("writing head: {:?} {:?}", self.version, self.status); - try!(write!(&mut self.body, "{} {}\r\n", self.version, self.status)); - - if !self.headers.has::() { - self.headers.set(header::Date(header::HttpDate(now_utc()))); - } - - let body_type = match self.status { - status::StatusCode::NoContent | status::StatusCode::NotModified => Body::Empty, - c if c.class() == status::StatusClass::Informational => Body::Empty, - _ => if let Some(cl) = self.headers.get::() { - Body::Sized(**cl) - } else { - Body::Chunked - } - }; - - // can't do in match above, thanks borrowck - if body_type == Body::Chunked { - let encodings = match self.headers.get_mut::() { - Some(&mut header::TransferEncoding(ref mut encodings)) => { - //TODO: check if chunked is already in encodings. use HashSet? - encodings.push(header::Encoding::Chunked); - false - }, - None => true - }; - - if encodings { - self.headers.set::( - header::TransferEncoding(vec![header::Encoding::Chunked])) - } - } - + pub fn headers(&self) -> &header::Headers { &self.head.headers } - debug!("headers [\n{:?}]", self.headers); - try!(write!(&mut self.body, "{}", self.headers)); - try!(write!(&mut self.body, "{}", LINE_ENDING)); - - Ok(body_type) - } -} - -impl<'a> Response<'a, Fresh> { - /// Creates a new Response that can be used to write to a network stream. - #[inline] - pub fn new(stream: &'a mut (Write + 'a), headers: &'a mut header::Headers) -> - Response<'a, Fresh> { - Response { - status: status::StatusCode::Ok, - version: version::HttpVersion::Http11, - headers: headers, - body: ThroughWriter(stream), - _writing: PhantomData, - } - } - - /// Writes the body and ends the response. - /// - /// This is a shortcut method for when you have a response with a fixed - /// size, and would only need a single `write` call normally. - /// - /// # Example - /// - /// ``` - /// # use hyper::server::Response; - /// fn handler(res: Response) { - /// res.send(b"Hello World!").unwrap(); - /// } - /// ``` - /// - /// The above is the same, but shorter, than the longer: - /// - /// ``` - /// # use hyper::server::Response; - /// use std::io::Write; - /// use hyper::header::ContentLength; - /// fn handler(mut res: Response) { - /// let body = b"Hello World!"; - /// res.headers_mut().set(ContentLength(body.len() as u64)); - /// let mut res = res.start().unwrap(); - /// res.write_all(body).unwrap(); - /// } - /// ``` + /// The status of this response. #[inline] - pub fn send(mut self, body: &[u8]) -> io::Result<()> { - self.headers.set(header::ContentLength(body.len() as u64)); - let mut stream = try!(self.start()); - try!(stream.write_all(body)); - stream.end() + pub fn status(&self) -> &StatusCode { + &self.head.subject } - /// Consume this Response, writing the Headers and Status and - /// creating a Response - pub fn start(mut self) -> io::Result> { - let body_type = try!(self.write_head()); - let (version, body, status, headers) = self.deconstruct(); - let stream = match body_type { - Body::Chunked => ChunkedWriter(body.into_inner()), - Body::Sized(len) => SizedWriter(body.into_inner(), len), - Body::Empty => EmptyWriter(body.into_inner()), - }; - - // "copy" to change the phantom type - Ok(Response { - version: version, - body: stream, - status: status, - headers: headers, - _writing: PhantomData, - }) - } - /// Get a mutable reference to the status. + /// The HTTP version of this response. #[inline] - pub fn status_mut(&mut self) -> &mut status::StatusCode { &mut self.status } + pub fn version(&self) -> &version::HttpVersion { &self.head.version } /// Get a mutable reference to the Headers. #[inline] - pub fn headers_mut(&mut self) -> &mut header::Headers { self.headers } -} - - -impl<'a> Response<'a, Streaming> { - /// Flushes all writing of a response to the client. - #[inline] - pub fn end(self) -> io::Result<()> { - trace!("ending"); - let (_, body, _, _) = self.deconstruct(); - try!(body.end()); - Ok(()) - } -} - -impl<'a> Write for Response<'a, Streaming> { - #[inline] - fn write(&mut self, msg: &[u8]) -> io::Result { - debug!("write {:?} bytes", msg.len()); - self.body.write(msg) - } + pub fn headers_mut(&mut self) -> &mut header::Headers { &mut self.head.headers } + /// Get a mutable reference to the status. #[inline] - fn flush(&mut self) -> io::Result<()> { - self.body.flush() + pub fn set_status(&mut self, status: StatusCode) { + self.head.subject = status; } } -#[derive(PartialEq)] -enum Body { - Chunked, - Sized(u64), - Empty, -} - -impl<'a, T: Any> Drop for Response<'a, T> { - fn drop(&mut self) { - if TypeId::of::() == TypeId::of::() { - if thread::panicking() { - self.status = status::StatusCode::InternalServerError; - } - - let mut body = match self.write_head() { - Ok(Body::Chunked) => ChunkedWriter(self.body.get_mut()), - Ok(Body::Sized(len)) => SizedWriter(self.body.get_mut(), len), - Ok(Body::Empty) => EmptyWriter(self.body.get_mut()), - Err(e) => { - debug!("error dropping request: {:?}", e); - return; - } - }; - end(&mut body); - } else { - end(&mut self.body); - }; - - - #[inline] - fn end(w: &mut W) { - match w.write(&[]) { - Ok(_) => match w.flush() { - Ok(_) => debug!("drop successful"), - Err(e) => debug!("error dropping request: {:?}", e) - }, - Err(e) => debug!("error dropping request: {:?}", e) - } - } - } -} - -#[cfg(test)] -mod tests { - use header::Headers; - use mock::MockStream; - use super::Response; - - macro_rules! lines { - ($s:ident = $($line:pat),+) => ({ - let s = String::from_utf8($s.write).unwrap(); - let mut lines = s.split_terminator("\r\n"); - - $( - match lines.next() { - Some($line) => (), - other => panic!("line mismatch: {:?} != {:?}", other, stringify!($line)) - } - )+ - - assert_eq!(lines.next(), None); - }) - } - - #[test] - fn test_fresh_start() { - let mut headers = Headers::new(); - let mut stream = MockStream::new(); - { - let res = Response::new(&mut stream, &mut headers); - res.start().unwrap().deconstruct(); - } - - lines! { stream = - "HTTP/1.1 200 OK", - _date, - _transfer_encoding, - "" - } - } - - #[test] - fn test_streaming_end() { - let mut headers = Headers::new(); - let mut stream = MockStream::new(); - { - let res = Response::new(&mut stream, &mut headers); - res.start().unwrap().end().unwrap(); - } - - lines! { stream = - "HTTP/1.1 200 OK", - _date, - _transfer_encoding, - "", - "0", - "" // empty zero body - } - } - - #[test] - fn test_fresh_drop() { - use status::StatusCode; - let mut headers = Headers::new(); - let mut stream = MockStream::new(); - { - let mut res = Response::new(&mut stream, &mut headers); - *res.status_mut() = StatusCode::NotFound; - } - - lines! { stream = - "HTTP/1.1 404 Not Found", - _date, - _transfer_encoding, - "", - "0", - "" // empty zero body - } - } - - // x86 windows msvc does not support unwinding - // See https://github.com/rust-lang/rust/issues/25869 - #[cfg(not(all(windows, target_arch="x86", target_env="msvc")))] - #[test] - fn test_fresh_drop_panicing() { - use std::thread; - use std::sync::{Arc, Mutex}; - - use status::StatusCode; - - let stream = MockStream::new(); - let stream = Arc::new(Mutex::new(stream)); - let inner_stream = stream.clone(); - let join_handle = thread::spawn(move || { - let mut headers = Headers::new(); - let mut stream = inner_stream.lock().unwrap(); - let mut res = Response::new(&mut *stream, &mut headers); - *res.status_mut() = StatusCode::NotFound; - - panic!("inside") - }); - - assert!(join_handle.join().is_err()); - - let stream = match stream.lock() { - Err(poisoned) => poisoned.into_inner().clone(), - Ok(_) => unreachable!() - }; - - lines! { stream = - "HTTP/1.1 500 Internal Server Error", - _date, - _transfer_encoding, - "", - "0", - "" // empty zero body - } - } - - - #[test] - fn test_streaming_drop() { - use std::io::Write; - use status::StatusCode; - let mut headers = Headers::new(); - let mut stream = MockStream::new(); - { - let mut res = Response::new(&mut stream, &mut headers); - *res.status_mut() = StatusCode::NotFound; - let mut stream = res.start().unwrap(); - stream.write_all(b"foo").unwrap(); - } - - lines! { stream = - "HTTP/1.1 404 Not Found", - _date, - _transfer_encoding, - "", - "3", - "foo", - "0", - "" // empty zero body - } - } - - #[test] - fn test_no_content() { - use status::StatusCode; - let mut headers = Headers::new(); - let mut stream = MockStream::new(); - { - let mut res = Response::new(&mut stream, &mut headers); - *res.status_mut() = StatusCode::NoContent; - res.start().unwrap(); - } - - lines! { stream = - "HTTP/1.1 204 No Content", - _date, - "" - } +/// Creates a new Response that can be used to write to a network stream. +pub fn new<'a>(head: &'a mut http::MessageHead) -> Response<'a> { + Response { + head: head } } diff --git a/src/status.rs b/src/status.rs index 5435182559..8a49bb7300 100644 --- a/src/status.rs +++ b/src/status.rs @@ -547,6 +547,12 @@ impl Ord for StatusCode { } } +impl Default for StatusCode { + fn default() -> StatusCode { + StatusCode::Ok + } +} + /// The class of an HTTP `status-code`. /// /// [RFC 7231, section 6 (Response Status Codes)](https://tools.ietf.org/html/rfc7231#section-6): diff --git a/src/uri.rs b/src/uri.rs index a5a9860fed..26f1ed1979 100644 --- a/src/uri.rs +++ b/src/uri.rs @@ -50,6 +50,12 @@ pub enum RequestUri { Star, } +impl Default for RequestUri { + fn default() -> RequestUri { + RequestUri::Star + } +} + impl FromStr for RequestUri { type Err = Error; @@ -67,7 +73,7 @@ impl FromStr for RequestUri { let mut temp = "http://".to_owned(); temp.push_str(s); try!(Url::parse(&temp[..])); - todo!("compare vs u.authority()"); + //TODO: compare vs u.authority()? Ok(RequestUri::Authority(s.to_owned())) } } diff --git a/src/version.rs b/src/version.rs index ad4c5c8745..fc5c840e2f 100644 --- a/src/version.rs +++ b/src/version.rs @@ -4,7 +4,7 @@ //! the `HttpVersion` enum. use std::fmt; -use self::HttpVersion::{Http09, Http10, Http11, Http20}; +use self::HttpVersion::{Http09, Http10, Http11, H2, H2c}; /// Represents a version of the HTTP spec. #[derive(PartialEq, PartialOrd, Copy, Clone, Eq, Ord, Hash, Debug)] @@ -15,8 +15,10 @@ pub enum HttpVersion { Http10, /// `HTTP/1.1` Http11, - /// `HTTP/2.0` - Http20 + /// `HTTP/2.0` over TLS + H2, + /// `HTTP/2.0` over cleartext + H2c, } impl fmt::Display for HttpVersion { @@ -25,7 +27,14 @@ impl fmt::Display for HttpVersion { Http09 => "HTTP/0.9", Http10 => "HTTP/1.0", Http11 => "HTTP/1.1", - Http20 => "HTTP/2.0", + H2 => "h2", + H2c => "h2c", }) } } + +impl Default for HttpVersion { + fn default() -> HttpVersion { + Http11 + } +} diff --git a/tests/client.rs b/tests/client.rs new file mode 100644 index 0000000000..c2c8271190 --- /dev/null +++ b/tests/client.rs @@ -0,0 +1,205 @@ +#![deny(warnings)] +extern crate hyper; + +use std::io::{self, Read, Write}; +use std::net::TcpListener; +use std::sync::mpsc; +use std::time::Duration; + +use hyper::client::{Handler, Request, Response, HttpConnector}; +use hyper::header; +use hyper::{Method, StatusCode, Next, Encoder, Decoder}; +use hyper::net::HttpStream; + +fn s(bytes: &[u8]) -> &str { + ::std::str::from_utf8(bytes.as_ref()).unwrap() +} + +#[derive(Debug)] +struct TestHandler { + opts: Opts, + tx: mpsc::Sender +} + +impl TestHandler { + fn new(opts: Opts) -> (TestHandler, mpsc::Receiver) { + let (tx, rx) = mpsc::channel(); + (TestHandler { + opts: opts, + tx: tx + }, rx) + } +} + +#[derive(Debug)] +enum Msg { + Head(Response), + Chunk(Vec), + Error(hyper::Error), +} + +fn read(opts: &Opts) -> Next { + if let Some(timeout) = opts.read_timeout { + Next::read().timeout(timeout) + } else { + Next::read() + } +} + +impl Handler for TestHandler { + fn on_request(&mut self, req: &mut Request) -> Next { + req.set_method(self.opts.method.clone()); + read(&self.opts) + } + + fn on_request_writable(&mut self, _encoder: &mut Encoder) -> Next { + read(&self.opts) + } + + fn on_response(&mut self, res: Response) -> Next { + use hyper::header; + // server responses can include a body until eof, if not size is specified + let mut has_body = true; + if let Some(len) = res.headers().get::() { + if **len == 0 { + has_body = false; + } + } + self.tx.send(Msg::Head(res)).unwrap(); + if has_body { + read(&self.opts) + } else { + Next::end() + } + } + + fn on_response_readable(&mut self, decoder: &mut Decoder) -> Next { + let mut v = vec![0; 512]; + match decoder.read(&mut v) { + Ok(n) => { + v.truncate(n); + self.tx.send(Msg::Chunk(v)).unwrap(); + if n == 0 { + Next::end() + } else { + read(&self.opts) + } + }, + Err(e) => match e.kind() { + io::ErrorKind::WouldBlock => read(&self.opts), + _ => panic!("io read error: {:?}", e) + } + } + } + + fn on_error(&mut self, err: hyper::Error) -> Next { + self.tx.send(Msg::Error(err)).unwrap(); + Next::remove() + } +} + +struct Client { + client: Option>, +} + +#[derive(Debug)] +struct Opts { + method: Method, + read_timeout: Option, +} + +impl Default for Opts { + fn default() -> Opts { + Opts { + method: Method::Get, + read_timeout: None, + } + } +} + +fn opts() -> Opts { + Opts::default() +} + +impl Opts { + fn method(mut self, method: Method) -> Opts { + self.method = method; + self + } + + fn read_timeout(mut self, timeout: Duration) -> Opts { + self.read_timeout = Some(timeout); + self + } +} + +impl Client { + fn request(&self, url: U, opts: Opts) -> mpsc::Receiver + where U: AsRef { + let (handler, rx) = TestHandler::new(opts); + self.client.as_ref().unwrap() + .request(url.as_ref().parse().unwrap(), handler).unwrap(); + rx + } +} + +impl Drop for Client { + fn drop(&mut self) { + self.client.take().map(|c| c.close()); + } +} + +fn client() -> Client { + let c = hyper::Client::::configure() + .connector(HttpConnector::default()) + .build().unwrap(); + Client { + client: Some(c), + } +} + + +#[test] +fn client_get() { + let server = TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = server.local_addr().unwrap(); + let client = client(); + let res = client.request(format!("http://{}/", addr), opts().method(Method::Get)); + + let mut inc = server.accept().unwrap().0; + inc.set_read_timeout(Some(Duration::from_secs(5))).unwrap(); + inc.set_write_timeout(Some(Duration::from_secs(5))).unwrap(); + let mut buf = [0; 4096]; + let n = inc.read(&mut buf).unwrap(); + let expected = format!("GET / HTTP/1.1\r\nHost: {}\r\n\r\n", addr); + assert_eq!(s(&buf[..n]), expected); + + inc.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n").unwrap(); + + if let Msg::Head(head) = res.recv().unwrap() { + assert_eq!(head.status(), &StatusCode::Ok); + assert_eq!(head.headers().get(), Some(&header::ContentLength(0))); + } else { + panic!("we lost the head!"); + } + //drop(inc); + + assert!(res.recv().is_err()); +} + +#[test] +fn client_read_timeout() { + let server = TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = server.local_addr().unwrap(); + let client = client(); + let res = client.request(format!("http://{}/", addr), opts().read_timeout(Duration::from_secs(3))); + + let mut inc = server.accept().unwrap().0; + let mut buf = [0; 4096]; + inc.read(&mut buf).unwrap(); + + match res.recv() { + Ok(Msg::Error(hyper::Error::Timeout)) => (), + other => panic!("expected timeout, actual: {:?}", other) + } +} diff --git a/tests/server.rs b/tests/server.rs new file mode 100644 index 0000000000..b92c4d088f --- /dev/null +++ b/tests/server.rs @@ -0,0 +1,379 @@ +#![deny(warnings)] +extern crate hyper; + +use std::net::{TcpStream, SocketAddr}; +use std::io::{self, Read, Write}; +use std::sync::mpsc; +use std::time::Duration; + +use hyper::{Next, Encoder, Decoder}; +use hyper::net::HttpStream; +use hyper::server::{Server, Handler, Request, Response}; + +struct Serve { + listening: Option, + msg_rx: mpsc::Receiver, + reply_tx: mpsc::Sender, +} + +impl Serve { + fn addr(&self) -> &SocketAddr { + self.listening.as_ref().unwrap().addr() + } + + /* + fn head(&self) -> Request { + unimplemented!() + } + */ + + fn body(&self) -> Vec { + let mut buf = vec![]; + while let Ok(Msg::Chunk(msg)) = self.msg_rx.try_recv() { + buf.extend(&msg); + } + buf + } + + fn reply(&self) -> ReplyBuilder { + ReplyBuilder { + tx: &self.reply_tx + } + } +} + +struct ReplyBuilder<'a> { + tx: &'a mpsc::Sender, +} + +impl<'a> ReplyBuilder<'a> { + fn status(self, status: hyper::StatusCode) -> Self { + self.tx.send(Reply::Status(status)).unwrap(); + self + } + + fn header(self, header: H) -> Self { + let mut headers = hyper::Headers::new(); + headers.set(header); + self.tx.send(Reply::Headers(headers)).unwrap(); + self + } + + fn body>(self, body: T) { + self.tx.send(Reply::Body(body.as_ref().into())).unwrap(); + } +} + +impl Drop for Serve { + fn drop(&mut self) { + self.listening.take().unwrap().close(); + } +} + +struct TestHandler { + tx: mpsc::Sender, + rx: mpsc::Receiver, + peeked: Option>, + timeout: Option, +} + +enum Reply { + Status(hyper::StatusCode), + Headers(hyper::Headers), + Body(Vec), +} + +enum Msg { + //Head(Request), + Chunk(Vec), +} + +impl TestHandler { + fn next(&self, next: Next) -> Next { + if let Some(dur) = self.timeout { + next.timeout(dur) + } else { + next + } + } +} + +impl Handler for TestHandler { + fn on_request(&mut self, _req: Request) -> Next { + //self.tx.send(Msg::Head(req)).unwrap(); + self.next(Next::read()) + } + + fn on_request_readable(&mut self, decoder: &mut Decoder) -> Next { + let mut vec = vec![0; 1024]; + match decoder.read(&mut vec) { + Ok(0) => { + self.next(Next::write()) + } + Ok(n) => { + vec.truncate(n); + self.tx.send(Msg::Chunk(vec)).unwrap(); + self.next(Next::read()) + } + Err(e) => match e.kind() { + io::ErrorKind::WouldBlock => self.next(Next::read()), + _ => panic!("test error: {}", e) + } + } + } + + fn on_response(&mut self, res: &mut Response) -> Next { + loop { + match self.rx.try_recv() { + Ok(Reply::Status(s)) => { + res.set_status(s); + }, + Ok(Reply::Headers(headers)) => { + use std::iter::Extend; + res.headers_mut().extend(headers.iter()); + }, + Ok(Reply::Body(body)) => { + self.peeked = Some(body); + }, + Err(..) => { + return if self.peeked.is_some() { + self.next(Next::write()) + } else { + self.next(Next::end()) + }; + }, + } + } + + } + + fn on_response_writable(&mut self, encoder: &mut Encoder) -> Next { + match self.peeked { + Some(ref body) => { + encoder.write(body).unwrap(); + self.next(Next::end()) + }, + None => self.next(Next::end()) + } + } +} + +fn serve() -> Serve { + serve_with_timeout(None) +} + +fn serve_with_timeout(dur: Option) -> Serve { + use std::thread; + + let (msg_tx, msg_rx) = mpsc::channel(); + let (reply_tx, reply_rx) = mpsc::channel(); + let mut reply_rx = Some(reply_rx); + let (listening, server) = Server::http(&"127.0.0.1:0".parse().unwrap()).unwrap() + .handle(move |_| TestHandler { + tx: msg_tx.clone(), + timeout: dur, + rx: reply_rx.take().unwrap(), + peeked: None, + }).unwrap(); + + + let thread_name = format!("test-server-{}: {:?}", listening.addr(), dur); + thread::Builder::new().name(thread_name).spawn(move || { + server.run(); + }).unwrap(); + + Serve { + listening: Some(listening), + msg_rx: msg_rx, + reply_tx: reply_tx, + } +} + +#[test] +fn server_get_should_ignore_body() { + let server = serve(); + + let mut req = TcpStream::connect(server.addr()).unwrap(); + req.write_all(b"\ + GET / HTTP/1.1\r\n\ + Host: example.domain\r\n\ + \r\n\ + I shouldn't be read.\r\n\ + ").unwrap(); + req.read(&mut [0; 256]).unwrap(); + + assert_eq!(server.body(), b""); +} + +#[test] +fn server_get_with_body() { + let server = serve(); + let mut req = TcpStream::connect(server.addr()).unwrap(); + req.write_all(b"\ + GET / HTTP/1.1\r\n\ + Host: example.domain\r\n\ + Content-Length: 19\r\n\ + \r\n\ + I'm a good request.\r\n\ + ").unwrap(); + req.read(&mut [0; 256]).unwrap(); + + // note: doesnt include trailing \r\n, cause Content-Length wasn't 21 + assert_eq!(server.body(), b"I'm a good request."); +} + +#[test] +fn server_get_fixed_response() { + let foo_bar = b"foo bar baz"; + let server = serve(); + server.reply() + .status(hyper::Ok) + .header(hyper::header::ContentLength(foo_bar.len() as u64)) + .body(foo_bar); + let mut req = TcpStream::connect(server.addr()).unwrap(); + req.write_all(b"\ + GET / HTTP/1.1\r\n\ + Host: example.domain\r\n\ + Connection: close\r\n + \r\n\ + ").unwrap(); + let mut body = String::new(); + req.read_to_string(&mut body).unwrap(); + let n = body.find("\r\n\r\n").unwrap() + 4; + + assert_eq!(&body[n..], "foo bar baz"); +} + +#[test] +fn server_get_chunked_response() { + let foo_bar = b"foo bar baz"; + let server = serve(); + server.reply() + .status(hyper::Ok) + .header(hyper::header::TransferEncoding::chunked()) + .body(foo_bar); + let mut req = TcpStream::connect(server.addr()).unwrap(); + req.write_all(b"\ + GET / HTTP/1.1\r\n\ + Host: example.domain\r\n\ + Connection: close\r\n + \r\n\ + ").unwrap(); + let mut body = String::new(); + req.read_to_string(&mut body).unwrap(); + let n = body.find("\r\n\r\n").unwrap() + 4; + + assert_eq!(&body[n..], "B\r\nfoo bar baz\r\n0\r\n\r\n"); +} + +#[test] +fn server_post_with_chunked_body() { + let server = serve(); + let mut req = TcpStream::connect(server.addr()).unwrap(); + req.write_all(b"\ + POST / HTTP/1.1\r\n\ + Host: example.domain\r\n\ + Transfer-Encoding: chunked\r\n\ + \r\n\ + 1\r\n\ + q\r\n\ + 2\r\n\ + we\r\n\ + 2\r\n\ + rt\r\n\ + 0\r\n\ + \r\n + ").unwrap(); + req.read(&mut [0; 256]).unwrap(); + + assert_eq!(server.body(), b"qwert"); +} + +/* +#[test] +fn server_empty_response() { + let server = serve(); + server.reply() + .status(hyper::Ok); + let mut req = TcpStream::connect(server.addr()).unwrap(); + req.write_all(b"\ + GET / HTTP/1.1\r\n\ + Host: example.domain\r\n\ + Connection: close\r\n + \r\n\ + ").unwrap(); + + let mut response = String::new(); + req.read_to_string(&mut response).unwrap(); + + assert_eq!(response, "foo"); + assert!(!response.contains("Transfer-Encoding: chunked\r\n")); + + let mut lines = response.lines(); + assert_eq!(lines.next(), Some("HTTP/1.1 200 OK")); + + let mut lines = lines.skip_while(|line| !line.is_empty()); + assert_eq!(lines.next(), Some("")); + assert_eq!(lines.next(), None); +} +*/ + +#[test] +fn server_empty_response_chunked() { + let server = serve(); + server.reply() + .status(hyper::Ok) + .body(""); + let mut req = TcpStream::connect(server.addr()).unwrap(); + req.write_all(b"\ + GET / HTTP/1.1\r\n\ + Host: example.domain\r\n\ + Connection: close\r\n + \r\n\ + ").unwrap(); + + let mut response = String::new(); + req.read_to_string(&mut response).unwrap(); + + assert!(response.contains("Transfer-Encoding: chunked\r\n")); + + let mut lines = response.lines(); + assert_eq!(lines.next(), Some("HTTP/1.1 200 OK")); + + let mut lines = lines.skip_while(|line| !line.is_empty()); + assert_eq!(lines.next(), Some("")); + // 0\r\n\r\n + assert_eq!(lines.next(), Some("0")); + assert_eq!(lines.next(), Some("")); + assert_eq!(lines.next(), None); +} + +#[test] +fn server_empty_response_chunked_without_calling_write() { + let server = serve(); + server.reply() + .status(hyper::Ok) + .header(hyper::header::TransferEncoding::chunked()); + let mut req = TcpStream::connect(server.addr()).unwrap(); + req.write_all(b"\ + GET / HTTP/1.1\r\n\ + Host: example.domain\r\n\ + Connection: close\r\n + \r\n\ + ").unwrap(); + + let mut response = String::new(); + req.read_to_string(&mut response).unwrap(); + + assert!(response.contains("Transfer-Encoding: chunked\r\n")); + + let mut lines = response.lines(); + assert_eq!(lines.next(), Some("HTTP/1.1 200 OK")); + + let mut lines = lines.skip_while(|line| !line.is_empty()); + assert_eq!(lines.next(), Some("")); + // 0\r\n\r\n + assert_eq!(lines.next(), Some("0")); + assert_eq!(lines.next(), Some("")); + assert_eq!(lines.next(), None); +}