From 751b7541621298e6aa45522687a98183138c7ebb Mon Sep 17 00:00:00 2001 From: Jacob Rothstein Date: Wed, 7 Feb 2024 20:10:14 -0800 Subject: [PATCH] feat: parse directly into trillium::Headers --- client/Cargo.toml | 1 + client/src/client.rs | 4 +- client/src/conn.rs | 60 +++++++-- client/tests/one_hundred_continue.rs | 24 ++-- http/Cargo.toml | 1 + http/src/conn.rs | 82 ++++++++++- http/src/headers.rs | 127 ++++++++++++------ http/src/headers/header_name.rs | 8 ++ http/src/headers/header_value.rs | 7 + http/src/headers/known_header_name.rs | 10 +- http/src/method.rs | 13 +- http/src/received_body.rs | 1 - http/src/received_body/chunked.rs | 107 ++++++++++----- http/src/status.rs | 16 ++- http/src/version.rs | 25 ++-- http/tests/corpus/1.response | 2 +- http/tests/corpus/10.response | 2 +- http/tests/corpus/2.response | 2 +- http/tests/corpus/3.response | 2 +- http/tests/corpus/4.response | 2 +- http/tests/corpus/5.response | 2 +- http/tests/corpus/6.response | 2 +- http/tests/corpus/7.response | 2 +- http/tests/corpus/8.response | 2 +- http/tests/corpus/9.response | 2 +- .../corpus/unsupported-http-version.error | 1 - .../corpus/unsupported-http-version.request | 3 - http/tests/one_hundred_continue.rs | 4 +- http/tests/unsafe_headers.rs | 2 +- 29 files changed, 374 insertions(+), 142 deletions(-) delete mode 100644 http/tests/corpus/unsupported-http-version.error delete mode 100644 http/tests/corpus/unsupported-http-version.request diff --git a/client/Cargo.toml b/client/Cargo.toml index 8953f7ab8e..6dfe4e3877 100644 --- a/client/Cargo.toml +++ b/client/Cargo.toml @@ -13,6 +13,7 @@ categories = ["web-programming", "web-programming::http-client"] [features] websockets = ["dep:trillium-websockets", "dep:thiserror"] json = ["dep:serde_json", "dep:serde", "dep:thiserror"] +parse = ["trillium-http/parse"] [dependencies] encoding_rs = "0.8.33" diff --git a/client/src/client.rs b/client/src/client.rs index ce2e7f407f..23bb2a84df 100644 --- a/client/src/client.rs +++ b/client/src/client.rs @@ -2,7 +2,7 @@ use crate::{Conn, IntoUrl, Pool, USER_AGENT}; use std::{fmt::Debug, sync::Arc, time::Duration}; use trillium_http::{ transport::BoxedTransport, HeaderName, HeaderValues, Headers, KnownHeaderName, Method, - ReceivedBodyState, + ReceivedBodyState, Version::Http1_1, }; use trillium_server_common::{ url::{Origin, Url}, @@ -169,6 +169,8 @@ impl Client { config: self.config.clone(), headers_finalized: false, timeout: self.timeout, + http_version: Http1_1, + max_head_length: 8 * 1024, } } diff --git a/client/src/conn.rs b/client/src/conn.rs index 4ee309c702..8cf8f7789a 100644 --- a/client/src/conn.rs +++ b/client/src/conn.rs @@ -9,23 +9,19 @@ use std::{ io::{ErrorKind, Write}, ops::{Deref, DerefMut}, pin::Pin, - str::FromStr, time::Duration, }; use trillium_http::{ - transport::BoxedTransport, - Body, Error, HeaderName, HeaderValue, HeaderValues, Headers, + transport::{BoxedTransport, Transport}, + Body, Error, HeaderName, HeaderValues, Headers, KnownHeaderName::{Connection, ContentLength, Expect, Host, TransferEncoding}, - Method, ReceivedBody, ReceivedBodyState, Result, Status, Upgrade, + Method, ReceivedBody, ReceivedBodyState, Result, Status, Upgrade, Version, }; use trillium_server_common::{ url::{Origin, Url}, - ArcedConnector, Connector, Transport, + ArcedConnector, Connector, }; -const MAX_HEADERS: usize = 128; -const MAX_HEAD_LENGTH: usize = 2 * 1024; - /** A wrapper error for [`trillium_http::Error`] or [`serde_json::Error`]. Only available when the `json` crate feature is @@ -62,6 +58,8 @@ pub struct Conn { pub(crate) config: ArcedConnector, pub(crate) headers_finalized: bool, pub(crate) timeout: Option, + pub(crate) http_version: Version, + pub(crate) max_head_length: usize, } /// default http user-agent header @@ -80,6 +78,8 @@ impl Debug for Conn { .field("buffer", &String::from_utf8_lossy(&self.buffer)) .field("response_body_state", &self.response_body_state) .field("config", &self.config) + .field("http_version", &self.http_version) + .field("max_head_length", &self.max_head_length) .finish() } } @@ -509,6 +509,14 @@ impl Conn { self } + /// returns the http version for this conn. + /// + /// prior to conn execution, this reflects the request http version that will be sent, and after + /// execution this reflects the server-indicated http version + pub fn http_version(&self) -> Version { + self.http_version + } + // --- everything below here is private --- fn finalize_headers(&mut self) -> Result<()> { @@ -615,7 +623,7 @@ impl Conn { } } - write!(buf, " HTTP/1.1\r\n")?; + write!(buf, " {}\r\n", self.http_version)?; for (name, values) in &self.request_headers { if !name.is_valid() { @@ -688,13 +696,18 @@ impl Conn { } } - if len >= MAX_HEAD_LENGTH { + if len >= self.max_head_length { return Err(Error::HeadersTooLong); } } } + #[cfg(not(feature = "parse"))] async fn parse_head(&mut self) -> Result<()> { + const MAX_HEADERS: usize = 128; + use crate::HeaderValue; + use std::str::FromStr; + let head_offset = self.read_head().await?; let mut headers = [httparse::EMPTY_HEADER; MAX_HEADERS]; let mut httparse_res = httparse::Response::new(&mut headers); @@ -717,7 +730,6 @@ impl Conn { self.status = httparse_res.code.map(|code| code.try_into().unwrap()); - self.response_headers.reserve(httparse_res.headers.len()); for header in httparse_res.headers { let header_name = HeaderName::from_str(header.name)?; let header_value = HeaderValue::from(header.value.to_owned()); @@ -730,6 +742,32 @@ impl Conn { Ok(()) } + #[cfg(feature = "parse")] + async fn parse_head(&mut self) -> Result<()> { + use std::str; + + let head_offset = self.read_head().await?; + + let space = memchr::memchr(b' ', &self.buffer[..head_offset]).ok_or(Error::InvalidHead)?; + self.http_version = str::from_utf8(&self.buffer[..space]) + .map_err(|_| Error::InvalidHead)? + .parse() + .map_err(|_| Error::InvalidHead)?; + self.status = Some(str::from_utf8(&self.buffer[space + 1..space + 4])?.parse()?); + let end_of_first_line = 2 + Finder::new("\r\n") + .find(&self.buffer[..head_offset]) + .ok_or(Error::InvalidHead)?; + + self.response_headers + .extend_parse(&self.buffer[end_of_first_line..head_offset]) + .map_err(|_| Error::InvalidHead)?; + + self.buffer.ignore_front(head_offset); + + self.validate_response_headers()?; + Ok(()) + } + async fn send_body_and_parse_head(&mut self) -> Result<()> { if self .request_headers diff --git a/client/tests/one_hundred_continue.rs b/client/tests/one_hundred_continue.rs index ba072cd273..4337211fb5 100644 --- a/client/tests/one_hundred_continue.rs +++ b/client/tests/one_hundred_continue.rs @@ -17,10 +17,10 @@ async fn extra_one_hundred_continue() -> TestResult { POST / HTTP/1.1\r Host: example.com\r Accept: */*\r - Expect: 100-continue\r - User-Agent: {USER_AGENT}\r Connection: close\r Content-Length: 4\r + Expect: 100-continue\r + User-Agent: {USER_AGENT}\r \r "}; @@ -37,9 +37,9 @@ async fn extra_one_hundred_continue() -> TestResult { let response_head = formatdoc! {" HTTP/1.1 200 Ok\r Date: {TEST_DATE}\r - Server: text\r Connection: close\r Content-Length: 20\r + Server: text\r \r response: 0123456789\ "}; @@ -71,10 +71,10 @@ async fn one_hundred_continue() -> TestResult { POST / HTTP/1.1\r Host: example.com\r Accept: */*\r - Expect: 100-continue\r - User-Agent: {USER_AGENT}\r Connection: close\r Content-Length: 4\r + Expect: 100-continue\r + User-Agent: {USER_AGENT}\r \r "}; @@ -87,9 +87,9 @@ async fn one_hundred_continue() -> TestResult { HTTP/1.1 200 Ok\r Date: {TEST_DATE}\r Accept: */*\r - Server: text\r Connection: close\r Content-Length: 20\r + Server: text\r \r response: 0123456789\ "}); @@ -115,8 +115,8 @@ async fn empty_body_no_100_continue() -> TestResult { POST / HTTP/1.1\r Host: example.com\r Accept: */*\r - User-Agent: {USER_AGENT}\r Connection: close\r + User-Agent: {USER_AGENT}\r \r "}; @@ -125,9 +125,9 @@ async fn empty_body_no_100_continue() -> TestResult { transport.write_all(formatdoc! {" HTTP/1.1 200 Ok\r Date: {TEST_DATE}\r - Server: text\r Connection: close\r Content-Length: 20\r + Server: text\r \r response: 0123456789\ "}); @@ -145,10 +145,10 @@ async fn two_small_continues() -> TestResult { POST / HTTP/1.1\r Host: example.com\r Accept: */*\r - Expect: 100-continue\r - User-Agent: {USER_AGENT}\r Connection: close\r Content-Length: 4\r + Expect: 100-continue\r + User-Agent: {USER_AGENT}\r \r "}; @@ -184,10 +184,10 @@ async fn little_continue_big_continue() -> TestResult { POST / HTTP/1.1\r Host: example.com\r Accept: */*\r - Expect: 100-continue\r - User-Agent: {USER_AGENT}\r Connection: close\r Content-Length: 4\r + Expect: 100-continue\r + User-Agent: {USER_AGENT}\r \r "}; diff --git a/http/Cargo.toml b/http/Cargo.toml index 3c4ed49a53..85fb2721ef 100644 --- a/http/Cargo.toml +++ b/http/Cargo.toml @@ -15,6 +15,7 @@ unstable = [] http-compat = ["dep:http0"] http-compat-1 = ["dep:http1"] serde = ["dep:serde"] +parse = [] [dependencies] encoding_rs = "0.8.33" diff --git a/http/src/conn.rs b/http/src/conn.rs index b76d005852..58213b3c00 100644 --- a/http/src/conn.rs +++ b/http/src/conn.rs @@ -5,7 +5,7 @@ use crate::{ liveness::{CancelOnDisconnect, LivenessFut}, received_body::ReceivedBodyState, util::encoding, - Body, BufWriter, Buffer, ConnectionStatus, Error, HeaderName, HeaderValue, Headers, HttpConfig, + Body, BufWriter, Buffer, ConnectionStatus, Error, Headers, HttpConfig, KnownHeaderName::{Connection, ContentLength, Date, Expect, Host, Server, TransferEncoding}, Method, ReceivedBody, Result, StateSet, Status, Swansong, Upgrade, Version, }; @@ -14,14 +14,13 @@ use futures_lite::{ future, io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}, }; -use httparse::{Request, EMPTY_HEADER}; use memchr::memmem::Finder; use std::{ fmt::{self, Debug, Formatter}, future::Future, net::IpAddr, pin::pin, - str::FromStr, + str, sync::Arc, time::{Instant, SystemTime}, }; @@ -612,6 +611,7 @@ where Self::new_internal(DEFAULT_CONFIG, transport, bytes.into(), swansong, None).await } + #[cfg(not(feature = "parse"))] async fn new_internal( http_config: HttpConfig, mut transport: Transport, @@ -619,6 +619,10 @@ where swansong: Swansong, shared_state: Option>, ) -> Result { + use crate::{HeaderName, HeaderValue}; + use httparse::{Request, EMPTY_HEADER}; + use std::str::FromStr; + let (head_size, start_time) = Self::head(&mut transport, &mut buffer, &swansong, &http_config).await?; @@ -633,6 +637,7 @@ where httparse::Error::Version => Error::InvalidVersion, _ => Error::InvalidHead, })?; + if status.is_partial() { return Err(Error::InvalidHead); } @@ -651,7 +656,7 @@ where _ => return Err(Error::InvalidVersion), }; - let mut request_headers = Headers::with_capacity(httparse_req.headers.len()); + let mut request_headers = Headers::new(); for header in httparse_req.headers { let header_name = HeaderName::from_str(header.name)?; let header_value = HeaderValue::from(header.value.to_owned()); @@ -666,8 +671,73 @@ where .to_owned(); log::trace!("received:\n{method} {path} {version}\n{request_headers}"); - let mut response_headers = - Headers::with_capacity(http_config.response_header_initial_capacity); + let mut response_headers = Headers::new(); + response_headers.insert(Server, SERVER); + + buffer.ignore_front(head_size); + + Ok(Self { + transport, + request_headers, + method, + version, + path, + buffer, + response_headers, + status: None, + state: StateSet::new(), + response_body: None, + request_body_state: ReceivedBodyState::Start, + secure: false, + swansong, + after_send: AfterSend::default(), + start_time, + peer_ip: None, + http_config, + shared_state, + }) + } + + #[cfg(feature = "parse")] + async fn new_internal( + http_config: HttpConfig, + mut transport: Transport, + mut buffer: Buffer, + swansong: Swansong, + shared_state: Option>, + ) -> Result { + let (head_size, start_time) = + Self::head(&mut transport, &mut buffer, &swansong, &http_config).await?; + + let first_line_index = Finder::new(b"\r\n") + .find(&buffer) + .ok_or(Error::InvalidHead)?; + + let (method, path, version) = match &memchr::memchr_iter(b' ', &buffer[..first_line_index]) + .collect::>()[..] + { + [first, second] => { + let (first, second) = (*first, *second); + let method = Method::parse(&buffer[0..first])?; + let path = str::from_utf8(&buffer[first + 1..second]) + .map_err(|_| Error::RequestPathMissing)? + .to_string(); + let version = Version::parse(&buffer[second + 1..first_line_index])?; + + if !matches!(version, Version::Http1_1 | Version::Http1_0) { + return Err(Error::UnsupportedVersion(version)); + } + + (method, path, version) + } + _ => return Err(Error::InvalidHead), + }; + + let request_headers = Headers::parse(&buffer[first_line_index + 2..])?; + + Self::validate_headers(&request_headers)?; + + let mut response_headers = Headers::new(); response_headers.insert(Server, SERVER); buffer.ignore_front(head_size); diff --git a/http/src/headers.rs b/http/src/headers.rs index 288dae988e..00d966dec9 100644 --- a/http/src/headers.rs +++ b/http/src/headers.rs @@ -10,6 +10,7 @@ pub use header_values::HeaderValues; pub use known_header_name::KnownHeaderName; use header_name::HeaderNameInner; +use memchr::memmem::Finder; use unknown_header_name::UnknownHeaderName; use hashbrown::{ @@ -17,16 +18,19 @@ use hashbrown::{ HashMap, }; use smartcow::SmartCow; +use std::collections::{btree_map::Entry as BTreeEntry, BTreeMap}; use std::{ fmt::{self, Debug, Display, Formatter}, - hash::{BuildHasherDefault, Hasher}, + hash::Hasher, }; +use crate::Error; + /// Trillium's header map type -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Default)] #[must_use] pub struct Headers { - known: HashMap>, + known: BTreeMap, unknown: HashMap, HeaderValues>, } @@ -45,12 +49,6 @@ impl serde::Serialize for Headers { } } -impl Default for Headers { - fn default() -> Self { - Self::with_capacity(15) - } -} - impl Display for Headers { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { for (n, v) in self { @@ -62,23 +60,86 @@ impl Display for Headers { } } +#[derive(Debug, Clone, Copy)] +pub struct ParseError; +impl Display for ParseError { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.write_str("parse error") + } +} + +fn is_tchar(c: u8) -> bool { + matches!( + c, + b'a'..=b'z' + | b'A'..=b'Z' + | b'0'..=b'9' + | b'!' + | b'#' + | b'$' + | b'%' + | b'&' + | b'\'' + | b'*' + | b'+' + | b'-' + | b'.' + | b'^' + | b'_' + | b'`' + | b'|' + | b'~' + ) +} + impl Headers { - /// Construct a new Headers, expecting to see at least this many known headers. - pub fn with_capacity(capacity: usize) -> Self { - Self { - known: HashMap::with_capacity_and_hasher(capacity, BuildHasherDefault::default()), - unknown: HashMap::with_capacity(0), + #[doc(hidden)] + pub fn extend_parse(&mut self, bytes: &[u8]) -> Result { + let newlines = Finder::new(b"\r\n").find_iter(bytes).collect::>(); + // self.reserve(newlines.len().saturating_sub(1)); + let mut new_header_count = 0; + let mut last_line = 0; + for newline in newlines { + if newline == last_line { + continue; + } + + let token_start = last_line; + let mut token_end = token_start; + while is_tchar(bytes[token_end]) { + token_end += 1; + } + + let header_name = HeaderName::parse(&bytes[token_start..token_end])?.to_owned(); + + if bytes[token_end] != b':' { + return Err(Error::InvalidHead); + } + + let mut value_start = token_end + 1; + while (bytes[value_start] as char).is_whitespace() { + value_start += 1; + } + + let header_value = HeaderValue::parse(&bytes[value_start..newline]); + self.append(header_name, header_value); + new_header_count += 1; + last_line = newline + 2; } + Ok(new_header_count) } - /// Construct a new headers with a default capacity of 15 known headers - pub fn new() -> Self { - Self::default() + #[cfg(feature = "parse")] + #[doc(hidden)] + pub fn parse(bytes: &[u8]) -> Result { + let mut headers = Headers::new(); + headers.extend_parse(bytes)?; + Ok(headers) } - /// Extend the capacity of the known headers map by this many - pub fn reserve(&mut self, additional: usize) { - self.known.reserve(additional); + /// Construct a new headers with a default capacity + pub fn new() -> Self { + Self::default() } /// Return an iterator over borrowed header names and header @@ -107,10 +168,10 @@ impl Headers { let value = value.into(); match name.into().0 { HeaderNameInner::KnownHeader(known) => match self.known.entry(known) { - Entry::Occupied(mut o) => { + BTreeEntry::Occupied(mut o) => { o.get_mut().extend(value); } - Entry::Vacant(v) => { + BTreeEntry::Vacant(v) => { v.insert(value); } }, @@ -129,13 +190,12 @@ impl Headers { /// A slightly more efficient way to combine two [`Headers`] than /// using [`Extend`] pub fn append_all(&mut self, other: Headers) { - self.known.reserve(other.known.len()); for (name, value) in other.known { match self.known.entry(name) { - Entry::Occupied(mut entry) => { + BTreeEntry::Occupied(mut entry) => { entry.get_mut().extend(value); } - Entry::Vacant(entry) => { + BTreeEntry::Vacant(entry) => { entry.insert(value); } } @@ -155,7 +215,6 @@ impl Headers { /// Combine two [`Headers`], replacing any existing header values pub fn insert_all(&mut self, other: Headers) { - self.known.reserve(other.known.len()); for (name, value) in other.known { self.known.insert(name, value); } @@ -365,12 +424,6 @@ where HV: Into, { fn extend>(&mut self, iter: T) { - let iter = iter.into_iter(); - match iter.size_hint() { - (additional, _) if additional > 0 => self.known.reserve(additional), - _ => {} - }; - for (name, values) in iter { self.append(name, values); } @@ -384,11 +437,7 @@ where { fn from_iter>(iter: T) -> Self { let iter = iter.into_iter(); - let mut headers = match iter.size_hint() { - (0, _) => Self::new(), - (n, _) => Self::with_capacity(n), - }; - + let mut headers = Self::new(); for (name, values) in iter { headers.append(name, values); } @@ -428,7 +477,7 @@ impl<'a> IntoIterator for &'a Headers { #[derive(Debug)] pub struct IntoIter { - known: hash_map::IntoIter, + known: std::collections::btree_map::IntoIter, unknown: hash_map::IntoIter, HeaderValues>, } @@ -454,7 +503,7 @@ impl From for IntoIter { #[derive(Debug)] pub struct Iter<'a> { - known: hash_map::Iter<'a, KnownHeaderName, HeaderValues>, + known: std::collections::btree_map::Iter<'a, KnownHeaderName, HeaderValues>, unknown: hash_map::Iter<'a, UnknownHeaderName<'static>, HeaderValues>, } diff --git a/http/src/headers/header_name.rs b/http/src/headers/header_name.rs index 656d43096a..dbd0f85704 100644 --- a/http/src/headers/header_name.rs +++ b/http/src/headers/header_name.rs @@ -14,6 +14,14 @@ use HeaderNameInner::{KnownHeader, UnknownHeader}; #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct HeaderName<'a>(pub(super) HeaderNameInner<'a>); +impl<'a> HeaderName<'a> { + pub(crate) fn parse(bytes: &'a [u8]) -> Result { + std::str::from_utf8(bytes) + .map_err(|_| Error::InvalidHeaderName) + .map(HeaderName::from) + } +} + #[cfg(feature = "serde")] impl serde::Serialize for HeaderName<'_> { fn serialize(&self, serializer: S) -> Result diff --git a/http/src/headers/header_value.rs b/http/src/headers/header_value.rs index f3c4389ff1..6d512ffc89 100644 --- a/http/src/headers/header_value.rs +++ b/http/src/headers/header_value.rs @@ -70,6 +70,13 @@ impl HeaderValue { } }) } + + pub(crate) fn parse(bytes: &[u8]) -> Self { + match std::str::from_utf8(bytes) { + Ok(s) => Self(Utf8(SmartCow::Owned(s.into()))), + Err(_) => Self(Bytes(bytes.into())), + } + } } impl Display for HeaderValue { diff --git a/http/src/headers/known_header_name.rs b/http/src/headers/known_header_name.rs index ba048c47ec..987d932dd8 100644 --- a/http/src/headers/known_header_name.rs +++ b/http/src/headers/known_header_name.rs @@ -4,7 +4,6 @@ use std::{ hash::Hash, str::FromStr, }; -use HeaderNameInner::{KnownHeader, UnknownHeader}; impl Display for KnownHeaderName { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { @@ -14,16 +13,13 @@ impl Display for KnownHeaderName { impl From for HeaderName<'_> { fn from(khn: KnownHeaderName) -> Self { - Self(KnownHeader(khn)) + Self(HeaderNameInner::KnownHeader(khn)) } } impl PartialEq> for KnownHeaderName { fn eq(&self, other: &HeaderName) -> bool { - match &other.0 { - KnownHeader(k) => self == k, - UnknownHeader(_) => false, - } + matches!(&other.0, HeaderNameInner::KnownHeader(k) if self == k) } } @@ -38,7 +34,7 @@ macro_rules! known_headers { /// represent as a u8. Use a `KnownHeaderName` variant instead /// of a &'static str anywhere possible, as it allows trillium /// to skip parsing the header entirely. - #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)] #[non_exhaustive] #[repr(u8)] pub enum KnownHeaderName { diff --git a/http/src/method.rs b/http/src/method.rs index bf4d1c4e2f..b4a8ade0f7 100644 --- a/http/src/method.rs +++ b/http/src/method.rs @@ -1,7 +1,7 @@ // originally from https://github.com/http-rs/http-types/blob/main/src/method.rs use std::{ fmt::{self, Display}, - str::FromStr, + str::{self, FromStr}, }; /// HTTP request methods. @@ -432,6 +432,13 @@ impl Method { Self::VersionControl => "VERSION-CONTROL", } } + + #[cfg(feature = "parse")] + pub(crate) fn parse(bytes: &[u8]) -> crate::Result { + str::from_utf8(bytes) + .map_err(|_| crate::Error::UnrecognizedMethod(String::from_utf8_lossy(bytes).into()))? + .parse() + } } impl Display for Method { @@ -485,9 +492,7 @@ impl FromStr for Method { "UPDATE" => Ok(Self::Update), "UPDATEREDIRECTREF" => Ok(Self::UpdateRedirectRef), "VERSION-CONTROL" => Ok(Self::VersionControl), - _ => Err(crate::Error::UnrecognizedMethod( - "Invalid HTTP method".into(), - )), + _ => Err(crate::Error::UnrecognizedMethod(s.to_string())), } } } diff --git a/http/src/received_body.rs b/http/src/received_body.rs index e4eee6b284..bf315d6aac 100644 --- a/http/src/received_body.rs +++ b/http/src/received_body.rs @@ -1,7 +1,6 @@ use crate::{copy, http_config::DEFAULT_CONFIG, Body, Buffer, HttpConfig, MutCow}; use encoding_rs::Encoding; use futures_lite::{ready, AsyncRead, AsyncReadExt, AsyncWrite}; -use httparse::{InvalidChunkSize, Status}; use std::{ fmt::{self, Debug, Formatter}, future::{Future, IntoFuture}, diff --git a/http/src/received_body/chunked.rs b/http/src/received_body/chunked.rs index c96275ff67..02c70eea83 100644 --- a/http/src/received_body/chunked.rs +++ b/http/src/received_body/chunked.rs @@ -1,8 +1,35 @@ +use std::io::ErrorKind::InvalidData; + use super::{ - io, ready, slice_from, AsyncRead, Buffer, Chunked, Context, End, ErrorKind, InvalidChunkSize, - PartialChunkSize, Pin, Ready, ReceivedBody, ReceivedBodyState, StateOutput, Status, + io, ready, slice_from, AsyncRead, Buffer, Chunked, Context, End, ErrorKind, PartialChunkSize, + Pin, Ready, ReceivedBody, ReceivedBodyState, StateOutput, }; +#[cfg(feature = "parse")] +fn parse_chunk_size(buf: &[u8]) -> Result, ()> { + use memchr::memmem::Finder; + use std::str; + + let Some(index) = memchr::memchr2(b';', b'\r', &buf[..buf.len().min(17)]) else { + return if buf.len() < 17 { Ok(None) } else { Err(()) }; + }; + let src = str::from_utf8(&buf[..index]).map_err(|_| ())?; + let chunk_size = u64::from_str_radix(src, 16).map_err(|_| ())?; + Ok(Finder::new("\r\n") + .find(&buf[index..]) + .map(|end| (index + end + 2, chunk_size + 2))) +} + +#[cfg(not(feature = "parse"))] +fn parse_chunk_size(buf: &[u8]) -> Result, ()> { + use httparse::{parse_chunk_size, Status}; + match parse_chunk_size(buf) { + Ok(Status::Complete((index, next_chunk))) => Ok(Some((index, next_chunk + 2))), + Ok(Status::Partial) => Ok(None), + Err(_) => Err(()), + } +} + impl<'conn, Transport> ReceivedBody<'conn, Transport> where Transport: AsyncRead + Unpin + Send + Sync + 'static, @@ -45,29 +72,18 @@ where self.buffer.extend_from_slice(&buf[..bytes]); - match httparse::parse_chunk_size(&self.buffer) { - Ok(Status::Complete((framing_bytes, remaining))) => { - self.buffer.ignore_front(framing_bytes); - Ready(Ok(( - if remaining == 0 { - End - } else { - Chunked { - remaining: remaining + 2, - total, - } - }, - 0, - ))) + Ready(match parse_chunk_size(&self.buffer) { + Ok(Some((used, remaining))) => { + self.buffer.ignore_front(used); + if remaining == 2 { + Ok((End, 0)) + } else { + Ok((Chunked { remaining, total }, 0)) + } } - - Ok(Status::Partial) => Ready(Ok((PartialChunkSize { total }, 0))), - - Err(InvalidChunkSize) => Ready(Err(io::Error::new( - ErrorKind::InvalidData, - "invalid chunk framing", - ))), - } + Ok(None) => Ok((PartialChunkSize { total }, 0)), + Err(()) => Err(io::Error::new(InvalidData, "invalid chunk size")), + }) } } @@ -113,14 +129,14 @@ pub(super) fn chunk_decode( }; } - match httparse::parse_chunk_size(buf_to_read) { - Ok(Status::Complete((framing_bytes, chunk_size))) => { + match parse_chunk_size(buf_to_read) { + Ok(Some((framing_bytes, chunk_size))) => { chunk_start += framing_bytes as u64; - chunk_end = (2 + chunk_start) + chunk_end = chunk_start .checked_add(chunk_size) - .ok_or_else(|| io::Error::new(ErrorKind::InvalidData, "chunk size too long"))?; + .ok_or_else(|| io::Error::new(InvalidData, "chunk size too long"))?; - if chunk_size == 0 { + if chunk_size == 2 { if let Some(buf) = slice_from(chunk_end, buf) { self_buffer.extend_from_slice(buf); } @@ -128,13 +144,13 @@ pub(super) fn chunk_decode( } } - Ok(Status::Partial) => { + Ok(None) => { self_buffer.extend_from_slice(buf_to_read); break PartialChunkSize { total }; } - Err(InvalidChunkSize) => { - return Err(io::Error::new(ErrorKind::InvalidData, "invalid chunk size")); + Err(()) => { + return Err(io::Error::new(InvalidData, "invalid chunk size")); } } }; @@ -324,6 +340,33 @@ mod tests { assert_decoded((7, "hello\r\n0\r\n\r\n"), (None, "hello", "")); } + #[test] + fn test_chunk_start_with_ext() { + assert_decoded((0, "5;abcdefg\r\n12345\r\n"), (Some(0), "12345", "")); + assert_decoded((0, "F;aaa\taaaaa\taaa aaa\r\n1"), (Some(14 + 2), "1", "")); + assert_decoded((0, "5;;;;;;;;;;;;;;;;\r\n123"), (Some(2 + 2), "123", "")); + assert_decoded( + (0, "1; a = b\"\" \r\nX\r\n1;;;\r\nX\r\n"), + (Some(0), "XX", ""), + ); + assert_decoded((0, "1\r\nX\r\n1;\r\nX\r\n1"), (Some(0), "XX", "1")); + assert_decoded((0, "FFF; 000\r\n"), (Some(0xfff + 2), "", "")); + assert_decoded((10, "hello"), (Some(5), "hello", "")); + assert_decoded( + (7, "hello\r\nA;111\r\n world"), + (Some(4 + 2), "hello world", ""), + ); + assert_decoded( + (0, "e\r\ntest test test\r\n0;00\r\n\r\n"), + (None, "test test test", ""), + ); + assert_decoded( + (0, "1;\r\n_\r\n0;\r\n\r\nnext request"), + (None, "_", "next request"), + ); + assert_decoded((7, "hello\r\n0;\r\n\r\n"), (None, "hello", "")); + } + #[test] fn read_string_and_read_bytes() { block_on(async { diff --git a/http/src/status.rs b/http/src/status.rs index 09bcf05215..6345cadbfa 100644 --- a/http/src/status.rs +++ b/http/src/status.rs @@ -1,6 +1,10 @@ // originally from https://github.com/http-rs/http-types/blob/main/src/status_code.rs use crate::Error; -use std::fmt::{self, Debug, Display}; +use std::{ + convert::TryFrom, + fmt::{self, Debug, Display}, + str::FromStr, +}; /// HTTP response status codes. /// @@ -632,3 +636,13 @@ impl Display for Status { write!(f, "{} {}", *self as u16, self.canonical_reason()) } } + +impl FromStr for Status { + type Err = Error; + + fn from_str(s: &str) -> Result { + u16::from_str(s) + .map_err(|_| Error::InvalidStatus)? + .try_into() + } +} diff --git a/http/src/version.rs b/http/src/version.rs index 1f5caf17c5..43dae034ea 100644 --- a/http/src/version.rs +++ b/http/src/version.rs @@ -1,6 +1,11 @@ // originally from https://github.com/http-rs/http-types/blob/main/src/version.rs -use std::{error::Error, fmt::Display, str::FromStr}; +use std::{ + fmt::Display, + str::{self, FromStr}, +}; + +use crate::Error; /// The version of the HTTP protocol in use. #[derive(Copy, Clone, Debug, Eq, Ord, PartialEq, PartialOrd)] @@ -69,19 +74,17 @@ impl Version { Version::Http3_0 => "HTTP/3", } } -} -#[derive(Debug, Clone)] -pub struct UnrecognizedVersion(String); -impl Display for UnrecognizedVersion { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_fmt(format_args!("unrecognized http version: {}", self.0)) + #[cfg(feature = "parse")] + pub(crate) fn parse(buf: &[u8]) -> crate::Result { + str::from_utf8(buf) + .map_err(|_| Error::InvalidVersion)? + .parse() } } -impl Error for UnrecognizedVersion {} impl FromStr for Version { - type Err = UnrecognizedVersion; + type Err = Error; fn from_str(s: &str) -> Result { match s { @@ -90,7 +93,7 @@ impl FromStr for Version { "HTTP/1.1" | "http/1.1" | "1.1" => Ok(Self::Http1_1), "HTTP/2" | "http/2" | "2" => Ok(Self::Http2_0), "HTTP/3" | "http/3" | "3" => Ok(Self::Http3_0), - _ => Err(UnrecognizedVersion(s.to_string())), + _ => Err(Error::InvalidVersion), } } } @@ -133,7 +136,7 @@ mod test { assert_eq!( "not a version".parse::().unwrap_err().to_string(), - "unrecognized http version: not a version" + "Invalid or missing version" ); } diff --git a/http/tests/corpus/1.response b/http/tests/corpus/1.response index 28436fa05a..efd1904aea 100644 --- a/http/tests/corpus/1.response +++ b/http/tests/corpus/1.response @@ -1,7 +1,7 @@ HTTP/1.1 200 OK\r\n Date: Tue, 21 Nov 2023 21:27:21 GMT\r\n -Server: corpus-test\r\n Content-Length: 186\r\n +Server: corpus-test\r\n \r\n ===request===\n method: POST\n diff --git a/http/tests/corpus/10.response b/http/tests/corpus/10.response index 98f6af74fe..e336a9fff2 100644 --- a/http/tests/corpus/10.response +++ b/http/tests/corpus/10.response @@ -1,7 +1,7 @@ HTTP/1.1 200 OK\r\n Date: Tue, 21 Nov 2023 21:27:21 GMT\r\n -Server: corpus-test\r\n Content-Length: 140\r\n +Server: corpus-test\r\n \r\n ===request===\n method: GET\n diff --git a/http/tests/corpus/2.response b/http/tests/corpus/2.response index 9b49164ca5..3dd43b1f9d 100644 --- a/http/tests/corpus/2.response +++ b/http/tests/corpus/2.response @@ -1,7 +1,7 @@ HTTP/1.1 200 OK\r\n Date: Tue, 21 Nov 2023 21:27:21 GMT\r\n -Server: corpus-test\r\n Content-Length: 286\r\n +Server: corpus-test\r\n \r\n ===request===\n method: POST\n diff --git a/http/tests/corpus/3.response b/http/tests/corpus/3.response index 9cdb728ac0..a0cb825a8d 100644 --- a/http/tests/corpus/3.response +++ b/http/tests/corpus/3.response @@ -1,7 +1,7 @@ HTTP/1.1 200 OK\r\n Date: Tue, 21 Nov 2023 21:27:21 GMT\r\n -Server: corpus-test\r\n Content-Length: 286\r\n +Server: corpus-test\r\n \r\n ===request===\n method: POST\n diff --git a/http/tests/corpus/4.response b/http/tests/corpus/4.response index 31f0e2f040..3002243e11 100644 --- a/http/tests/corpus/4.response +++ b/http/tests/corpus/4.response @@ -1,7 +1,7 @@ HTTP/1.1 200 OK\r\n Date: Tue, 21 Nov 2023 21:27:21 GMT\r\n -Server: corpus-test\r\n Content-Length: 286\r\n +Server: corpus-test\r\n \r\n ===request===\n method: POST\n diff --git a/http/tests/corpus/5.response b/http/tests/corpus/5.response index 3a9464bf5e..2557e0559e 100644 --- a/http/tests/corpus/5.response +++ b/http/tests/corpus/5.response @@ -1,7 +1,7 @@ HTTP/1.1 200 OK\r\n Date: Tue, 21 Nov 2023 21:27:21 GMT\r\n -Server: corpus-test\r\n Content-Length: 286\r\n +Server: corpus-test\r\n \r\n ===request===\n method: POST\n diff --git a/http/tests/corpus/6.response b/http/tests/corpus/6.response index 34bf7869db..1255ecf8ce 100644 --- a/http/tests/corpus/6.response +++ b/http/tests/corpus/6.response @@ -1,7 +1,7 @@ HTTP/1.1 200 OK\r\n Date: Tue, 21 Nov 2023 21:27:21 GMT\r\n -Server: corpus-test\r\n Content-Length: 286\r\n +Server: corpus-test\r\n \r\n ===request===\n method: POST\n diff --git a/http/tests/corpus/7.response b/http/tests/corpus/7.response index 00a6e0d674..285dc77c42 100644 --- a/http/tests/corpus/7.response +++ b/http/tests/corpus/7.response @@ -1,7 +1,7 @@ HTTP/1.1 200 OK\r\n Date: Tue, 21 Nov 2023 21:27:21 GMT\r\n -Server: corpus-test\r\n Content-Length: 286\r\n +Server: corpus-test\r\n \r\n ===request===\n method: POST\n diff --git a/http/tests/corpus/8.response b/http/tests/corpus/8.response index 1390ed2f07..f52acb960e 100644 --- a/http/tests/corpus/8.response +++ b/http/tests/corpus/8.response @@ -1,7 +1,7 @@ HTTP/1.1 200 OK\r\n Date: Tue, 21 Nov 2023 21:27:21 GMT\r\n -Server: corpus-test\r\n Content-Length: 159\r\n +Server: corpus-test\r\n \r\n ===request===\n method: PATCH\n diff --git a/http/tests/corpus/9.response b/http/tests/corpus/9.response index 61b49a393a..e4a95c7356 100644 --- a/http/tests/corpus/9.response +++ b/http/tests/corpus/9.response @@ -1,7 +1,7 @@ HTTP/1.1 200 OK\r\n Date: Tue, 21 Nov 2023 21:27:21 GMT\r\n -Server: corpus-test\r\n Content-Length: 170\r\n +Server: corpus-test\r\n \r\n ===request===\n method: PATCH\n diff --git a/http/tests/corpus/unsupported-http-version.error b/http/tests/corpus/unsupported-http-version.error deleted file mode 100644 index 200289f76d..0000000000 --- a/http/tests/corpus/unsupported-http-version.error +++ /dev/null @@ -1 +0,0 @@ -Invalid or missing version \ No newline at end of file diff --git a/http/tests/corpus/unsupported-http-version.request b/http/tests/corpus/unsupported-http-version.request deleted file mode 100644 index d11affff97..0000000000 --- a/http/tests/corpus/unsupported-http-version.request +++ /dev/null @@ -1,3 +0,0 @@ -GET / HTTP/0.9\r\n -\r\n - diff --git a/http/tests/one_hundred_continue.rs b/http/tests/one_hundred_continue.rs index 84f13fbd36..b2c6ac4ccd 100644 --- a/http/tests/one_hundred_continue.rs +++ b/http/tests/one_hundred_continue.rs @@ -44,9 +44,9 @@ async fn one_hundred_continue() -> TestResult { let expected_response = formatdoc! {" HTTP/1.1 200 OK\r Date: {TEST_DATE}\r - Server: {SERVER}\r Connection: close\r Content-Length: 20\r + Server: {SERVER}\r \r response: 0123456789\ "}; @@ -77,9 +77,9 @@ async fn one_hundred_continue_http_one_dot_zero() -> TestResult { let expected_response = formatdoc! {" HTTP/1.0 200 OK\r Date: {TEST_DATE}\r - Server: {SERVER}\r Connection: close\r Content-Length: 20\r + Server: {SERVER}\r \r response: 0123456789\ "}; diff --git a/http/tests/unsafe_headers.rs b/http/tests/unsafe_headers.rs index 59a537b4aa..1b7029c767 100644 --- a/http/tests/unsafe_headers.rs +++ b/http/tests/unsafe_headers.rs @@ -37,8 +37,8 @@ async fn bad_headers() -> TestResult { let expected_response = formatdoc! {" HTTP/1.1 200 OK\r Date: {TEST_DATE}\r - Server: {SERVER}\r Content-Length: 20\r + Server: {SERVER}\r \r response: 0123456789\ "};