From 36e66a50546347c6f9b74c6d3c26e8b910483a4b Mon Sep 17 00:00:00 2001 From: Sean McArthur Date: Mon, 22 Jan 2018 10:08:27 -0800 Subject: [PATCH] fix(lib): properly handle HTTP/1.0 remotes - Downgrades internal semantics to HTTP/1.0 if peer sends a message with 1.0 version. - If downgraded, chunked writers become EOF writers, with the connection closing once the writing is complete. - When downgraded, if keep-alive was wanted, the `Connection: keep-alive` header is added. Closes #1304 --- src/proto/conn.rs | 121 +++++++++++++++++++++++---------- src/proto/h1/encode.rs | 19 +++++- src/proto/h1/parse.rs | 13 ++-- tests/server.rs | 149 ++++++++++++++++++++++++++++++++++++++++- 4 files changed, 258 insertions(+), 44 deletions(-) diff --git a/src/proto/conn.rs b/src/proto/conn.rs index 13805acb16..bda987f400 100644 --- a/src/proto/conn.rs +++ b/src/proto/conn.rs @@ -45,6 +45,9 @@ where I: AsyncRead + AsyncWrite, read_task: None, reading: Reading::Init, writing: Writing::Init, + // We assume a modern world where the remote speaks HTTP/1.1. + // If they tell us otherwise, we'll downgrade in `read_head`. + version: Version::Http11, }, _marker: PhantomData, } @@ -189,43 +192,44 @@ where I: AsyncRead + AsyncWrite, } }; - match version { - HttpVersion::Http10 | HttpVersion::Http11 => { - let decoder = match T::decoder(&head, &mut self.state.method) { - Ok(d) => d, - Err(e) => { - debug!("decoder error = {:?}", e); - self.state.close_read(); - return Err(e); - } - }; - - debug!("incoming body is {}", decoder); - - self.state.busy(); - if head.expecting_continue() { - let msg = b"HTTP/1.1 100 Continue\r\n\r\n"; - self.state.writing = Writing::Continue(Cursor::new(msg)); - } - let wants_keep_alive = head.should_keep_alive(); - self.state.keep_alive &= wants_keep_alive; - let (body, reading) = if decoder.is_eof() { - (false, Reading::KeepAlive) - } else { - (true, Reading::Body(decoder)) - }; - self.state.reading = reading; - if !body { - self.try_keep_alive(); - } - Ok(Async::Ready(Some((head, body)))) - }, + self.state.version = match version { + HttpVersion::Http10 => Version::Http10, + HttpVersion::Http11 => Version::Http11, _ => { error!("unimplemented HTTP Version = {:?}", version); self.state.close_read(); - Err(::Error::Version) + return Err(::Error::Version); + } + }; + + let decoder = match T::decoder(&head, &mut self.state.method) { + Ok(d) => d, + Err(e) => { + debug!("decoder error = {:?}", e); + self.state.close_read(); + return Err(e); } + }; + + debug!("incoming body is {}", decoder); + + self.state.busy(); + if head.expecting_continue() { + let msg = b"HTTP/1.1 100 Continue\r\n\r\n"; + self.state.writing = Writing::Continue(Cursor::new(msg)); + } + let wants_keep_alive = head.should_keep_alive(); + self.state.keep_alive &= wants_keep_alive; + let (body, reading) = if decoder.is_eof() { + (false, Reading::KeepAlive) + } else { + (true, Reading::Body(decoder)) + }; + self.state.reading = reading; + if !body { + self.try_keep_alive(); } + Ok(Async::Ready(Some((head, body)))) } pub fn read_body(&mut self) -> Poll, io::Error> { @@ -414,11 +418,11 @@ where I: AsyncRead + AsyncWrite, } } - pub fn write_head(&mut self, head: super::MessageHead, body: bool) { + pub fn write_head(&mut self, mut head: super::MessageHead, body: bool) { debug_assert!(self.can_write_head()); - let wants_keep_alive = head.should_keep_alive(); - self.state.keep_alive &= wants_keep_alive; + self.enforce_version(&mut head); + let buf = self.io.write_buf_mut(); // if a 100-continue has started but not finished sending, tack the // remainder on to the start of the buffer. @@ -435,6 +439,36 @@ where I: AsyncRead + AsyncWrite, }; } + // If we know the remote speaks an older version, we try to fix up any messages + // to work with our older peer. + fn enforce_version(&mut self, head: &mut super::MessageHead) { + use header::Connection; + + let wants_keep_alive = if self.state.wants_keep_alive() { + let ka = head.should_keep_alive(); + self.state.keep_alive &= ka; + ka + } else { + false + }; + + match self.state.version { + Version::Http10 => { + // If the remote only knows HTTP/1.0, we should force ourselves + // to do only speak HTTP/1.0 as well. + head.version = HttpVersion::Http10; + if wants_keep_alive { + head.headers.set(Connection::keep_alive()); + } + }, + Version::Http11 => { + // If the remote speaks HTTP/1.1, then it *should* be fine with + // both HTTP/1.0 and HTTP/1.1 from us. So again, we just let + // the user's headers be. + } + } + } + pub fn write_body(&mut self, chunk: Option) -> StartSend, io::Error> { debug_assert!(self.can_write_body()); @@ -486,7 +520,7 @@ where I: AsyncRead + AsyncWrite, } } else { // end of stream, that means we should try to eof - match encoder.eof() { + match encoder.end() { Ok(Some(end)) => Writing::Ending(Cursor::new(end)), Ok(None) => Writing::KeepAlive, Err(_not_eof) => Writing::Closed, @@ -701,6 +735,7 @@ struct State { read_task: Option, reading: Reading, writing: Writing, + version: Version, } #[derive(Debug)] @@ -819,6 +854,14 @@ impl State { self.keep_alive.disable(); } + fn wants_keep_alive(&self) -> bool { + if let KA::Disabled = self.keep_alive.status() { + false + } else { + true + } + } + fn try_keep_alive(&mut self) { match (&self.reading, &self.writing) { (&Reading::KeepAlive, &Writing::KeepAlive) => { @@ -881,6 +924,12 @@ impl State { } } +#[derive(Debug, Clone, Copy)] +enum Version { + Http10, + Http11, +} + // The DebugFrame and DebugChunk are simple Debug implementations that allow // us to dump the frame into logs, without logging the entirety of the bytes. #[cfg(feature = "tokio-proto")] diff --git a/src/proto/h1/encode.rs b/src/proto/h1/encode.rs index 5033eba532..17f2724492 100644 --- a/src/proto/h1/encode.rs +++ b/src/proto/h1/encode.rs @@ -17,6 +17,11 @@ enum Kind { /// /// Enforces that the body is not longer than the Content-Length header. Length(u64), + /// An Encoder for when neither Content-Length nore Chunked encoding is set. + /// + /// This is mostly only used with HTTP/1.0 with a length. This kind requires + /// the connection to be closed when the body is finished. + Eof } impl Encoder { @@ -32,6 +37,12 @@ impl Encoder { } } + pub fn eof() -> Encoder { + Encoder { + kind: Kind::Eof, + } + } + pub fn is_eof(&self) -> bool { match self.kind { Kind::Length(0) | @@ -40,7 +51,7 @@ impl Encoder { } } - pub fn eof(&self) -> Result, NotEof> { + pub fn end(&self) -> Result, NotEof> { match self.kind { Kind::Length(0) => Ok(None), Kind::Chunked(Chunked::Init) => Ok(Some(b"0\r\n\r\n")), @@ -73,6 +84,12 @@ impl Encoder { trace!("encoded {} bytes, remaining = {}", n, remaining); Ok(n) }, + Kind::Eof => { + if msg.is_empty() { + return Ok(0); + } + w.write_atomic(&[msg]) + } } } } diff --git a/src/proto/h1/parse.rs b/src/proto/h1/parse.rs index 027d3496ab..0050394405 100644 --- a/src/proto/h1/parse.rs +++ b/src/proto/h1/parse.rs @@ -10,7 +10,7 @@ use proto::{MessageHead, RawStatus, Http1Transaction, ParseResult, use proto::h1::{Encoder, Decoder, date}; use method::Method; use status::StatusCode; -use version::HttpVersion::{Http10, Http11}; +use version::HttpVersion::{self, Http10, Http11}; const MAX_HEADERS: usize = 100; const AVERAGE_HEADER_SIZE: usize = 30; // totally scientific @@ -166,7 +166,7 @@ impl ServerTransaction { }; if has_body && can_have_body { - set_length(&mut head.headers) + set_length(head.version, &mut head.headers) } else { head.headers.remove::(); if can_have_body { @@ -302,7 +302,7 @@ impl Http1Transaction for ClientTransaction { impl ClientTransaction { fn set_length(head: &mut RequestHead, has_body: bool) -> Encoder { if has_body { - set_length(&mut head.headers) + set_length(head.version, &mut head.headers) } else { head.headers.remove::(); head.headers.remove::(); @@ -311,12 +311,12 @@ impl ClientTransaction { } } -fn set_length(headers: &mut Headers) -> Encoder { +fn set_length(version: HttpVersion, headers: &mut Headers) -> Encoder { let len = headers.get::().map(|n| **n); if let Some(len) = len { Encoder::length(len) - } else { + } else if version == Http11 { let encodings = match headers.get_mut::() { Some(&mut header::TransferEncoding(ref mut encodings)) => { if encodings.last() != Some(&header::Encoding::Chunked) { @@ -331,6 +331,9 @@ fn set_length(headers: &mut Headers) -> Encoder { headers.set(header::TransferEncoding(vec![header::Encoding::Chunked])); } Encoder::chunked() + } else { + headers.remove::(); + Encoder::eof() } } diff --git a/tests/server.rs b/tests/server.rs index 78675734f3..ec9fd76d3e 100644 --- a/tests/server.rs +++ b/tests/server.rs @@ -145,6 +145,75 @@ fn get_chunked_response() { assert_eq!(&body[n..], "B\r\nfoo bar baz\r\n0\r\n\r\n"); } +#[test] +fn get_auto_response() { + let foo_bar = b"foo bar baz"; + let server = serve(); + server.reply() + .status(hyper::Ok) + .body(foo_bar); + let mut req = connect(server.addr()); + req.write_all(b"\ + GET / HTTP/1.1\r\n\ + Host: example.domain\r\n\ + Connection: close\r\n\ + \r\n\ + ").unwrap(); + let mut body = String::new(); + req.read_to_string(&mut body).unwrap(); + + assert!(has_header(&body, "Transfer-Encoding: chunked")); + + let n = body.find("\r\n\r\n").unwrap() + 4; + assert_eq!(&body[n..], "B\r\nfoo bar baz\r\n0\r\n\r\n"); +} + +#[test] +fn http_10_get_auto_response() { + let foo_bar = b"foo bar baz"; + let server = serve(); + server.reply() + .status(hyper::Ok) + .body(foo_bar); + let mut req = connect(server.addr()); + req.write_all(b"\ + GET / HTTP/1.0\r\n\ + Host: example.domain\r\n\ + \r\n\ + ").unwrap(); + let mut body = String::new(); + req.read_to_string(&mut body).unwrap(); + + assert!(!has_header(&body, "Transfer-Encoding:")); + + let n = body.find("\r\n\r\n").unwrap() + 4; + assert_eq!(&body[n..], "foo bar baz"); +} + +#[test] +fn http_10_get_chunked_response() { + let foo_bar = b"foo bar baz"; + let server = serve(); + server.reply() + .status(hyper::Ok) + // this header should actually get removed + .header(hyper::header::TransferEncoding::chunked()) + .body(foo_bar); + let mut req = connect(server.addr()); + req.write_all(b"\ + GET / HTTP/1.0\r\n\ + Host: example.domain\r\n\ + \r\n\ + ").unwrap(); + let mut body = String::new(); + req.read_to_string(&mut body).unwrap(); + + assert!(!has_header(&body, "Transfer-Encoding:")); + + let n = body.find("\r\n\r\n").unwrap() + 4; + assert_eq!(&body[n..], "foo bar baz"); +} + #[test] fn get_chunked_response_with_ka() { let foo_bar = b"foo bar baz"; @@ -378,7 +447,6 @@ fn keep_alive() { req.write_all(b"\ GET / HTTP/1.1\r\n\ Host: example.domain\r\n\ - Connection: keep-alive\r\n\ \r\n\ ").expect("writing 1"); @@ -388,7 +456,6 @@ fn keep_alive() { if n < buf.len() { if &buf[n - foo_bar.len()..n] == foo_bar { break; - } else { } } } @@ -419,6 +486,57 @@ fn keep_alive() { } } +#[test] +fn http_10_keep_alive() { + let foo_bar = b"foo bar baz"; + let server = serve(); + server.reply() + .status(hyper::Ok) + .header(hyper::header::ContentLength(foo_bar.len() as u64)) + .body(foo_bar); + let mut req = connect(server.addr()); + req.write_all(b"\ + GET / HTTP/1.0\r\n\ + Host: example.domain\r\n\ + Connection: keep-alive\r\n\ + \r\n\ + ").expect("writing 1"); + + let mut buf = [0; 1024 * 8]; + loop { + let n = req.read(&mut buf[..]).expect("reading 1"); + if n < buf.len() { + if &buf[n - foo_bar.len()..n] == foo_bar { + break; + } + } + } + + // try again! + + let quux = b"zar quux"; + server.reply() + .status(hyper::Ok) + .header(hyper::header::ContentLength(quux.len() as u64)) + .body(quux); + req.write_all(b"\ + GET /quux HTTP/1.0\r\n\ + Host: example.domain\r\n\ + \r\n\ + ").expect("writing 2"); + + let mut buf = [0; 1024 * 8]; + loop { + let n = req.read(&mut buf[..]).expect("reading 2"); + assert!(n > 0, "n = {}", n); + if n < buf.len() { + if &buf[n - quux.len()..n] == quux { + break; + } + } + } +} + #[test] fn disable_keep_alive() { let foo_bar = b"foo bar baz"; @@ -574,6 +692,23 @@ fn pipeline_enabled() { assert_eq!(n, 0); } +#[test] +fn http_10_request_receives_http_10_response() { + let server = serve(); + + let mut req = connect(server.addr()); + req.write_all(b"\ + GET / HTTP/1.0\r\n\ + \r\n\ + ").unwrap(); + + let expected = "HTTP/1.0 200 OK\r\nContent-Length: 0\r\n"; + let mut buf = [0; 256]; + let n = req.read(&mut buf).unwrap(); + assert!(n >= expected.len(), "read: {:?} >= {:?}", n, expected.len()); + assert_eq!(s(&buf[..expected.len()]), expected); +} + #[test] fn disable_keep_alive_mid_request() { let mut core = Core::new().unwrap(); @@ -997,6 +1132,16 @@ fn serve_with_options(options: ServeOptions) -> Serve { } } +fn s(buf: &[u8]) -> &str { + ::std::str::from_utf8(buf).unwrap() +} + +fn has_header(msg: &str, name: &str) -> bool { + let n = msg.find("\r\n\r\n").unwrap_or(msg.len()); + + msg[..n].contains(name) +} + struct DebugStream { stream: T, _debug: D,