diff --git a/http/src/after_send.rs b/http/src/after_send.rs new file mode 100644 index 0000000000..f09dacba95 --- /dev/null +++ b/http/src/after_send.rs @@ -0,0 +1,50 @@ +#[derive(Debug, Clone, Copy, Eq, PartialEq)] +pub enum SendStatus { + Success, + Failure, +} +impl From for SendStatus { + fn from(success: bool) -> Self { + if success { + Self::Success + } else { + Self::Failure + } + } +} + +impl SendStatus { + pub fn is_success(self) -> bool { + SendStatus::Success == self + } +} + +#[derive(Default)] +pub(crate) struct AfterSend(Option>); + +impl AfterSend { + pub(crate) fn call(&mut self, send_status: SendStatus) { + if let Some(after_send) = self.0.take() { + after_send(send_status); + } + } + + pub(crate) fn append(&mut self, after_send: F) + where + F: FnOnce(SendStatus) + Send + Sync + 'static, + { + self.0 = Some(match self.0.take() { + Some(existing_after_send) => Box::new(move |ss| { + existing_after_send(ss); + after_send(ss); + }), + None => Box::new(after_send), + }); + } +} + +impl Drop for AfterSend { + fn drop(&mut self) { + self.call(SendStatus::Failure); + } +} diff --git a/http/src/conn.rs b/http/src/conn.rs index 7f16d22dc1..64dba27418 100644 --- a/http/src/conn.rs +++ b/http/src/conn.rs @@ -1,8 +1,11 @@ use crate::{ + after_send::{AfterSend, SendStatus}, copy, + http_config::DEFAULT_CONFIG, received_body::ReceivedBodyState, util::encoding, Body, BufWriter, ConnectionStatus, Error, HeaderName, HeaderValue, HeaderValues, Headers, + HttpConfig, KnownHeaderName::{Connection, ContentLength, Date, Expect, Host, Server, TransferEncoding}, Method, ReceivedBody, Result, StateSet, Status, Stopper, Upgrade, Version, }; @@ -16,63 +19,11 @@ use std::{ future::Future, net::IpAddr, str::FromStr, - sync::Arc, time::{Instant, SystemTime}, }; const SERVER: &str = concat!("trillium/", env!("CARGO_PKG_VERSION")); -#[derive(Debug, Clone, Copy, Eq, PartialEq)] -pub enum SendStatus { - Success, - Failure, -} -impl From for SendStatus { - fn from(success: bool) 
-> Self { - if success { - Self::Success - } else { - Self::Failure - } - } -} - -impl SendStatus { - pub fn is_success(self) -> bool { - SendStatus::Success == self - } -} - -#[derive(Default)] -pub(crate) struct AfterSend(Option>); - -impl AfterSend { - pub(crate) fn call(&mut self, send_status: SendStatus) { - if let Some(after_send) = self.0.take() { - after_send(send_status); - } - } - - pub(crate) fn append(&mut self, after_send: F) - where - F: FnOnce(SendStatus) + Send + Sync + 'static, - { - self.0 = Some(match self.0.take() { - Some(existing_after_send) => Box::new(move |ss| { - existing_after_send(ss); - after_send(ss); - }), - None => Box::new(after_send), - }); - } -} - -impl Drop for AfterSend { - fn drop(&mut self) { - self.call(SendStatus::Failure); - } -} - /** A http connection Unlike in other rust http implementations, this struct represents both @@ -96,7 +47,7 @@ pub struct Conn { pub(crate) after_send: AfterSend, pub(crate) start_time: Instant, pub(crate) peer_ip: Option, - pub(crate) http_config: Arc, + pub(crate) http_config: HttpConfig, } impl Debug for Conn { @@ -126,77 +77,6 @@ impl Debug for Conn { } } -#[derive(Clone, Debug)] -pub struct HttpConfig { - stopper: Stopper, - write_buffer_len: usize, - read_buffer_len: usize, - max_head_len: usize, - max_headers: usize, - initial_header_capacity: usize, - copy_loops_per_yield: usize, -} - -// impl HttpConfig { -// pub fn with_stopper(mut self, stopper: Stopper) -> Self { -// self.stopper = stopper; -// self -// } - -// pub fn with_write_buffer_len(mut self, write_buffer_len: usize) -> Self { -// self.write_buffer_len = write_buffer_len; -// self -// } - -// pub fn with_read_buffer_len(mut self, read_buffer_len: usize) -> Self { -// self.read_buffer_len = read_buffer_len; -// self -// } - -// pub fn with_max_head_len(mut self, max_head_len: usize) -> Self { -// self.max_head_len = max_head_len; -// self -// } - -// pub fn with_max_headers(mut self, max_headers: usize) -> Self { -// 
self.max_headers = max_headers; -// self -// } - -// pub fn with_initial_header_capacity(mut self, initial_header_capacity: usize) -> Self { -// self.initial_header_capacity = initial_header_capacity; -// self -// } - -// pub fn with_copy_loops_per_yield(mut self, copy_loops_per_yield: usize) -> Self { -// self.copy_loops_per_yield = copy_loops_per_yield; -// self -// } -// } - -impl Default for HttpConfig { - fn default() -> Self { - Self { - stopper: Stopper::default(), - write_buffer_len: 512, - read_buffer_len: 128, - max_head_len: 8 * 1024, - max_headers: 128, - initial_header_capacity: 16, - copy_loops_per_yield: 16, - } - } -} - -impl From for HttpConfig { - fn from(stopper: Stopper) -> Self { - Self { - stopper, - ..Self::default() - } - } -} - impl Conn where Transport: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static, @@ -230,39 +110,20 @@ where F: Fn(Conn) -> Fut, Fut: Future> + Send, { - Self::map_with_config(transport, Arc::new(stopper.into()), handler).await + Self::map_with_config(DEFAULT_CONFIG, stopper, transport, handler).await } - /// read any number of new `Conn`s from the transport and call the - /// provided handler function until either the connection is closed or - /// an upgrade is requested. A return value of Ok(None) indicates a - /// closed connection, while a return value of Ok(Some(upgrade)) - /// represents an upgrade. - /// - /// See the documentation for [`Conn`] for a full example. - /// - /// # Errors - /// - /// This will return an error variant if: - /// - /// * there is an io error when reading from the underlying transport - /// * headers are too long - /// * we are unable to parse some aspect of the request - /// * the request is an unsupported http version - /// * we cannot make sense of the headers, such as if there is a - /// `content-length` header as well as a `transfer-encoding: chunked` - /// header. 
- async fn map_with_config( + http_config: HttpConfig, + stopper: Stopper, transport: Transport, - config: Arc, handler: F, ) -> Result>> where F: Fn(Conn) -> Fut, Fut: Future> + Send, { - let mut conn = Conn::new_with_config(transport, None, config).await?; + let mut conn = Conn::new_with_config(http_config, transport, None, stopper).await?; loop { conn = match handler(conn).await.send().await? { @@ -598,7 +459,7 @@ where bytes: Option>, stopper: Stopper, ) -> Result { - Self::new_with_config(transport, bytes, Arc::new(stopper.into())).await + Self::new_with_config(DEFAULT_CONFIG, transport, bytes, stopper).await } /// # Create a new `Conn` @@ -621,13 +482,13 @@ where /// `content-length` header as well as a `transfer-encoding: chunked` /// header. async fn new_with_config( + http_config: HttpConfig, transport: Transport, bytes: Option>, - http_config: Arc, + stopper: Stopper, ) -> Result { - let stopper = http_config.stopper.clone(); let (transport, buf, extra_bytes, start_time) = - Self::head(transport, bytes, &http_config).await?; + Self::head(transport, bytes, &stopper, &http_config).await?; let buffer = if extra_bytes.is_empty() { None @@ -795,6 +656,7 @@ where async fn head( mut transport: Transport, bytes: Option>, + stopper: &Stopper, http_config: &HttpConfig, ) -> Result<(Transport, Vec, Vec, Instant)> { let mut buf = bytes.unwrap_or_default(); @@ -805,10 +667,6 @@ where let mut resize_by = http_config.read_buffer_len; loop { if len >= http_config.max_head_len { - log::error!( - "headers were {len}, greater than {}", - http_config.max_head_len - ); return Err(Error::HeadersTooLong); } @@ -816,8 +674,7 @@ where buf.resize(buf.len() + resize_by, 0); resize_by *= 2; if len == 0 { - http_config - .stopper + stopper .stop_future(transport.read(&mut buf[..])) .await .ok_or(Error::Closed)?? 
diff --git a/http/src/error.rs b/http/src/error.rs index 3f9bfaeb9e..56826b5aca 100644 --- a/http/src/error.rs +++ b/http/src/error.rs @@ -80,10 +80,16 @@ pub enum Error { #[error("unexpected header: {0}")] UnexpectedHeader(&'static str), - /// for security reasons, we do not allow request headers beyond - /// 8kb. - #[error("Head byte length should be less than 8kb")] + /// to mitigate against malicious http clients, we do not allow request headers beyond this + /// length. + #[error("Headers were malformed or longer than allowed")] HeadersTooLong, + + /// to mitigate against malicious http clients, we do not read received bodies beyond this + /// length to memory. If you need to receive longer bodies, use the Stream or AsyncRead + /// implementation on ReceivedBody + #[error("Received body too long. Maximum {0} bytes")] + ReceivedBodyTooLong(u64), } /// this crate's result type diff --git a/http/src/http_config.rs b/http/src/http_config.rs new file mode 100644 index 0000000000..3fb84615d1 --- /dev/null +++ b/http/src/http_config.rs @@ -0,0 +1,67 @@ +#![allow(dead_code)] + +#[derive(Clone, Copy, Debug)] +pub struct HttpConfig { + pub(crate) write_buffer_len: usize, + pub(crate) read_buffer_len: usize, + pub(crate) max_head_len: usize, + pub(crate) max_headers: usize, + pub(crate) initial_header_capacity: usize, + pub(crate) copy_loops_per_yield: usize, + pub(crate) received_body_max_len: u64, + pub(crate) received_body_initial_len: usize, +} + +impl HttpConfig { + pub(crate) fn with_write_buffer_len(mut self, write_buffer_len: usize) -> Self { + self.write_buffer_len = write_buffer_len; + self + } + + pub(crate) fn with_read_buffer_len(mut self, read_buffer_len: usize) -> Self { + self.read_buffer_len = read_buffer_len; + self + } + + pub(crate) fn with_max_head_len(mut self, max_head_len: usize) -> Self { + self.max_head_len = max_head_len; + self + } + + pub(crate) fn with_max_headers(mut self, max_headers: usize) -> Self { + self.max_headers = max_headers; + 
self + } + + pub(crate) fn with_initial_header_capacity(mut self, initial_header_capacity: usize) -> Self { + self.initial_header_capacity = initial_header_capacity; + self + } + + pub(crate) fn with_copy_loops_per_yield(mut self, copy_loops_per_yield: usize) -> Self { + self.copy_loops_per_yield = copy_loops_per_yield; + self + } + + pub(crate) fn with_received_body_max_len(mut self, received_body_max_len: u64) -> Self { + self.received_body_max_len = received_body_max_len; + self + } +} + +impl Default for HttpConfig { + fn default() -> Self { + DEFAULT_CONFIG + } +} + +pub const DEFAULT_CONFIG: HttpConfig = HttpConfig { + write_buffer_len: 512, + read_buffer_len: 128, + max_head_len: 8 * 1024, + max_headers: 128, + initial_header_capacity: 16, + copy_loops_per_yield: 16, + received_body_max_len: 524_288_000u64, + received_body_initial_len: 128, +}; diff --git a/http/src/lib.rs b/http/src/lib.rs index 9fb8194613..d0fd737e49 100644 --- a/http/src/lib.rs +++ b/http/src/lib.rs @@ -135,3 +135,8 @@ pub(crate) use copy::copy; mod bufwriter; pub(crate) use bufwriter::BufWriter; + +mod http_config; +pub(crate) use http_config::HttpConfig; + +pub(crate) mod after_send; diff --git a/http/src/received_body.rs b/http/src/received_body.rs index 34371fb305..6f28eee40d 100644 --- a/http/src/received_body.rs +++ b/http/src/received_body.rs @@ -1,20 +1,21 @@ -use crate::{Body, MutCow}; +use crate::{copy, http_config::DEFAULT_CONFIG, Body, HttpConfig, MutCow}; use encoding_rs::Encoding; -use futures_lite::{io, ready, AsyncRead, AsyncReadExt, AsyncWrite, Stream}; -use httparse::Status; +use futures_lite::{ready, AsyncRead, AsyncReadExt, AsyncWrite, Stream}; +use httparse::{InvalidChunkSize, Status}; use std::{ - convert::TryInto, fmt::{self, Formatter}, - future::IntoFuture, - io::ErrorKind, + future::{Future, IntoFuture}, + io::{self, ErrorKind}, iter, pin::Pin, task::{Context, Poll}, }; - use Poll::{Pending, Ready}; use ReceivedBodyState::{Chunked, End, FixedLength, Start}; 
+#[cfg(test)] +mod tests; + macro_rules! trace { ($s:literal, $($arg:tt)+) => ( log::trace!(concat!(":{} ", $s), line!(), $($arg)+); @@ -35,6 +36,20 @@ let body = conn.request_body().await; assert_eq!(body.read_string().await?, "hello"); # trillium_http::Result::Ok(()) }).unwrap(); ``` + +## Bounds checking + +Every `ReceivedBody` has a maximum length beyond which it will return an error, expressed as a +u64. To override this on the specific `ReceivedBody`, use [`ReceivedBody::with_max_len`] or +[`ReceivedBody::set_max_len`] + +The default maximum length is currently set to 500mb. In the next semver-minor release, this value +will decrease substantially. + +## Large chunks, small read buffers + +Attempting to read a chunked body with a buffer that is shorter than the chunk size in hex will +result in an error. This limitation is temporary. */ pub struct ReceivedBody<'conn, Transport> { @@ -44,6 +59,14 @@ pub struct ReceivedBody<'conn, Transport> { state: MutCow<'conn, ReceivedBodyState>, on_completion: Option>, encoding: &'static Encoding, + max_len: u64, + initial_len: usize, + copy_loops_per_yield: usize, +} + +fn slice_from(min: u64, buf: &[u8]) -> Option<&[u8]> { + buf.get(usize::try_from(min).unwrap_or(usize::MAX)..) 
+ .filter(|buf| !buf.is_empty()) } impl<'conn, Transport> ReceivedBody<'conn, Transport> @@ -59,6 +82,28 @@ where state: impl Into>, on_completion: Option>, encoding: &'static Encoding, + ) -> Self { + Self::new_with_config( + content_length, + buffer, + transport, + state, + on_completion, + encoding, + &DEFAULT_CONFIG, + ) + } + + #[allow(missing_docs)] + #[doc(hidden)] + fn new_with_config( + content_length: Option, + buffer: impl Into>>>, + transport: impl Into>, + state: impl Into>, + on_completion: Option>, + encoding: &'static Encoding, + config: &HttpConfig, ) -> Self { Self { content_length, @@ -67,6 +112,9 @@ where state: state.into(), on_completion, encoding, + max_len: config.received_body_max_len, + initial_len: config.received_body_initial_len, + copy_loops_per_yield: config.copy_loops_per_yield, } } @@ -104,6 +152,10 @@ where /// /// This will return an error if there is an IO error on the /// underlying transport such as a disconnect + /// + /// This will also return an error if the length exceeds the maximum length. To override this + /// value on this specific body, use [`ReceivedBody::with_max_len`] or + /// [`ReceivedBody::set_max_len`] pub async fn read_string(self) -> crate::Result { let encoding = self.encoding(); let bytes = self.read_bytes().await?; @@ -118,22 +170,44 @@ where .unwrap_or_default() } - /// Similar to [`ReceivedBody::read_string`], but returns the raw - /// bytes. This is useful for bodies that are not text. + /// Set the maximum length that can be read from this body before error + pub fn set_max_len(&mut self, max_len: u64) { + self.max_len = max_len; + } + + /// chainable setter for the maximum length that can be read from this body before error + #[must_use] + pub fn with_max_len(mut self, max_len: u64) -> Self { + self.set_max_len(max_len); + self + } + + /// Similar to [`ReceivedBody::read_string`], but returns the raw bytes. This is useful for + /// bodies that are not text. 
/// - /// You can use this in conjunction with `encoding` if you need - /// different handling of malformed character encoding than the lossy - /// conversion provided by [`ReceivedBody::read_string`]. + /// You can use this in conjunction with `encoding` if you need different handling of malformed + /// character encoding than the lossy conversion provided by [`ReceivedBody::read_string`]. /// /// # Errors /// - /// This will return an error if there is an IO error on the - /// underlying transport such as a disconnect + /// This will return an error if there is an IO error on the underlying transport such as a + /// disconnect + /// + /// This will also return an error if the length exceeds + /// [`received_body_max_len`][HttpConfig::with_received_body_max_len]. To override this value on + /// this specific body, use [`ReceivedBody::with_max_len`] or [`ReceivedBody::set_max_len`] pub async fn read_bytes(mut self) -> crate::Result> { let mut vec = if let Some(len) = self.content_length { - Vec::with_capacity(len.try_into().unwrap_or(usize::max_value())) + if len > self.max_len { + return Err(crate::Error::ReceivedBodyTooLong(self.max_len)); + } + + let len = usize::try_from(len) + .map_err(|_| crate::Error::ReceivedBodyTooLong(self.max_len))?; + + Vec::with_capacity(len) } else { - Vec::new() + Vec::with_capacity(self.initial_len) }; self.read_to_end(&mut vec).await?; @@ -141,37 +215,35 @@ where } /** - returns the character encoding of this body, usually - determined from the content type (mime-type) of the associated - Conn. + returns the character encoding of this body, usually determined from the content type + (mime-type) of the associated Conn. 
*/ pub fn encoding(&self) -> &'static Encoding { self.encoding } fn read_raw(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { - if let Some(transport) = self.transport.as_mut() { - read_raw(&mut self.buffer, &mut **transport, cx, buf) + if let Some(transport) = self.transport.as_deref_mut() { + read_raw(&mut self.buffer, transport, cx, buf) } else { Ready(Err(ErrorKind::NotConnected.into())) } } /** - Consumes the remainder of this body from the underlying transport - by reading it to the end and discarding the contents. This is - important for http1.1 keepalive, but most of the time you do not - need to directly call this. It returns the number of bytes - consumed. + Consumes the remainder of this body from the underlying transport by reading it to the end and + discarding the contents. This is important for http1.1 keepalive, but most of the time you do + not need to directly call this. It returns the number of bytes consumed. # Errors - This will return an [`std::io::Result::Err`] if there is an io - error on the underlying transport, such as a disconnect + This will return an [`std::io::Result::Err`] if there is an io error on the underlying + transport, such as a disconnect */ #[allow(clippy::missing_errors_doc)] // false positive pub async fn drain(self) -> io::Result { - io::copy(self, io::sink()).await + let copy_loops_per_yield = self.copy_loops_per_yield; + copy(self, futures_lite::io::sink(), copy_loops_per_yield).await } } @@ -181,7 +253,7 @@ where { type Output = crate::Result; - type IntoFuture = Pin + Send + 'a>>; + type IntoFuture = Pin + Send + 'a>>; fn into_future(self) -> Self::IntoFuture { Box::pin(async move { self.read_string().await }) @@ -233,16 +305,12 @@ where } } -#[allow( - // the clippy::only_used_in_recursion seems like a false positive, - // it thinks `total` is unused - clippy::only_used_in_recursion, - clippy::cast_possible_truncation -)] fn chunk_decode( - remaining: usize, - mut total: usize, + remaining: u64, + mut 
chunk_total: u64, + mut total: u64, buf: &mut [u8], + max_len: u64, ) -> io::Result<(ReceivedBodyState, usize, Option>)> { if buf.is_empty() { return Err(io::Error::new( @@ -251,47 +319,43 @@ fn chunk_decode( )); } let mut ranges_to_keep = vec![]; - let mut chunk_start = 0; + let mut chunk_start = 0u64; let mut chunk_end = remaining; let (request_body_state, unused) = loop { if chunk_end > 2 { - let keep_end = buf.len().min(chunk_end - 2); - ranges_to_keep.push(chunk_start..keep_end); - total += keep_end - chunk_start; + let keep_start = usize::try_from(chunk_start).unwrap_or(usize::MAX); + let keep_end = buf + .len() + .min(usize::try_from(chunk_end - 2).unwrap_or(usize::MAX)); + ranges_to_keep.push(keep_start..keep_end); + let new_bytes = (keep_end - keep_start) as u64; + chunk_total += new_bytes; + total += new_bytes; + if total > max_len { + return Err(io::Error::new(ErrorKind::Unsupported, "content too long")); + } } chunk_start = chunk_end; - if chunk_start >= buf.len() { + let Some(buf_to_read) = slice_from(chunk_start, buf) else { break ( Chunked { - remaining: (chunk_start - buf.len()), + remaining: (chunk_start - buf.len() as u64), + chunk_total, total, }, None, ); - } + }; - match httparse::parse_chunk_size(&buf[chunk_start..]) { + match httparse::parse_chunk_size(buf_to_read) { Ok(Status::Complete((framing_bytes, chunk_size))) => { - chunk_start += framing_bytes; - // the #[allow(clippy::cast_possible_truncation)] - // applied to this function is for the following - // line. This may in fact be a bug, and we do not - // handle chunks that are longer than a usize. There - // is no reason 32bit platforms should not be able - // to stream chunks longer than u32::MAX. 
- chunk_end = 2 + chunk_start + chunk_size as usize; + chunk_start += framing_bytes as u64; + chunk_end = 2 + chunk_start + chunk_size; if chunk_size == 0 { - break ( - End, - if chunk_end < buf.len() { - Some(buf[chunk_end..].to_vec()) - } else { - None - }, - ); + break (End, slice_from(chunk_end, buf).map(Vec::from)); } } @@ -299,25 +363,15 @@ fn chunk_decode( break ( Chunked { remaining: 0, + chunk_total, total, }, - if chunk_start < buf.len() { - Some(buf[chunk_start..].to_vec()) - } else { - None - }, + slice_from(chunk_start, buf).map(Vec::from), ); } - Err(httparse::InvalidChunkSize) => { - log::error!( - "invalid chunk size in buffer:\n\n {}", - String::from_utf8_lossy(&buf[chunk_start..]) - ); - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "invalid chunk size", - )); + Err(InvalidChunkSize) => { + return Err(io::Error::new(ErrorKind::InvalidData, "invalid chunk size")); } } }; @@ -369,7 +423,6 @@ impl<'conn, Transport> AsyncRead for ReceivedBody<'conn, Transport> where Transport: AsyncRead + Unpin + Send + Sync + 'static, { - #[allow(clippy::cast_possible_truncation)] fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -381,13 +434,21 @@ where match self.content_length { Some(0) => End, - Some(total_length) => FixedLength { + Some(total_length) if total_length < self.max_len => FixedLength { current_index: 0, total_length, }, + Some(_) => { + return Ready(Err(io::Error::new( + ErrorKind::Unsupported, + "content too long", + ))) + } + None => Chunked { remaining: 0, + chunk_total: 0, total: 0, }, }, @@ -395,9 +456,26 @@ where None, ), - Chunked { remaining, total } => { + Chunked { + remaining, + chunk_total, + total, + } => { let bytes = ready!(self.read_raw(cx, buf)?); - chunk_decode(remaining, total, &mut buf[..bytes])? + let source_buf = &mut buf[..bytes]; + match chunk_decode(remaining, chunk_total, total, source_buf, self.max_len)? { + (Chunked { remaining: 0, .. 
}, 0, Some(unused)) + if unused.len() == buf.len() => + { + // we didn't use any of the bytes, which would result in a pathological loop + return Ready(Err(io::Error::new( + ErrorKind::Unsupported, + "read buffer too short", + ))); + } + + other => other, + } } FixedLength { @@ -405,14 +483,7 @@ where total_length, } => { let len = buf.len(); - // the #[allow(clippy::cast_possible_truncation)] on - // this function is for this line. this is ok because - // we are taking the min(remaining bytes, buffer - // length) and the buffer length will always be - // usize. As a result, if we truncate when we cast, - // the result of the min function will always be the - // same - let remaining = (total_length - current_index) as usize; + let remaining = usize::try_from(total_length - current_index).unwrap_or(usize::MAX); let buf = &mut buf[..len.min(remaining)]; let bytes = ready!(self.read_raw(cx, buf)?); let current_index = current_index + bytes as u64; @@ -471,9 +542,9 @@ impl<'conn, Transport> fmt::Debug for ReceivedBody<'conn, Transport> { } } -#[derive(Debug, Clone, Copy, Eq, PartialEq)] -/// the current read state of this body -#[derive(Default)] +#[derive(Debug, Clone, Copy, Eq, PartialEq, Default)] +#[allow(missing_docs)] +#[doc(hidden)] pub enum ReceivedBodyState { /// initial state #[default] @@ -484,11 +555,15 @@ pub enum ReceivedBodyState { Chunked { /// remaining indicates the bytes left _in the current /// chunk_. initial state is zero. - remaining: usize, - /// total indicates the size of the current chunk or zero to + remaining: u64, + + /// chunk_total indicates the size of the current chunk or zero to /// indicate that we expect to read a chunk size at the start /// of the next bytes. initial state is zero. - total: usize, + chunk_total: u64, + + /// total indicates the absolute number of bytes read from all chunks + total: u64, }, /// read state for a fixed-length body. 
@@ -515,136 +590,3 @@ where Body::new_streaming(rb, len) } } - -// This is commented out because I do not have use for it anymore and -// as I was writing out the documentation I realized it was a footgun -// without a clear use case. I'm retaining it in the code in case it's -// useful later, though -// -// impl<'conn, Transport> ReceivedBody<'conn, Transport> -// where -// Transport: AsyncRead + Unpin + Send + Sync + Clone + 'static, -// { -// /** -// When the transport is Clone, this allows the creation of an owned -// body without taking the original transport away from the Conn. -// -// Caution: You -// probably don't want to use this if it can be avoided, as it opens -// up the potential for two different bodies reading from the same -// transport, and rust will not protect you from those mistakes. -// */ -// pub fn into_owned_by_cloning_transport(mut self) -> ReceivedBody<'static, Transport> { -// ReceivedBody { -// content_length: self.content_length, -// buffer: MutCow::Owned(self.buffer.take()), -// transport: self.transport.map(|transport| MutCow::Owned((*transport).clone())), -// state: MutCow::Owned(*self.state), -// on_completion: self.on_completion, -// encoding: self.encoding, -// } -// } -// } - -#[cfg(test)] -mod chunk_decode { - use super::{chunk_decode, ReceivedBody, ReceivedBodyState}; - use encoding_rs::UTF_8; - use futures_lite::{io::Cursor, AsyncRead, AsyncReadExt}; - - fn assert_decoded(input: (usize, &str), expected_output: (Option, &str, Option<&str>)) { - let (remaining, input_data) = input; - - let mut buf = input_data.to_string().into_bytes(); - - let (output_state, bytes, unused) = chunk_decode(remaining, 0, &mut buf).unwrap(); - - assert_eq!( - ( - match output_state { - ReceivedBodyState::Chunked { remaining, .. 
} => Some(remaining), - ReceivedBodyState::End => None, - _ => panic!("unexpected output state {output_state:?}"), - }, - &*String::from_utf8_lossy(&buf[0..bytes]), - unused.as_deref().map(String::from_utf8_lossy).as_deref() - ), - expected_output - ); - } - - async fn read_with_buffers_of_size(reader: &mut R, size: usize) -> crate::Result - where - R: AsyncRead + Unpin, - { - let mut return_buffer = vec![]; - loop { - let mut buf = vec![0; size]; - match reader.read(&mut buf).await? { - 0 => break Ok(String::from_utf8_lossy(&return_buffer).into()), - bytes_read => return_buffer.extend_from_slice(&buf[..bytes_read]), - } - } - } - - fn full_decode_with_size( - input: &str, - poll_size: usize, - ) -> crate::Result<(String, ReceivedBody<'static, Cursor<&str>>)> { - let mut rb = ReceivedBody::new( - None, - None, - Cursor::new(input), - ReceivedBodyState::Chunked { - remaining: 0, - total: 0, - }, - None, - UTF_8, - ); - - let output = trillium_testing::block_on(read_with_buffers_of_size(&mut rb, poll_size))?; - Ok((output, rb)) - } - - #[test] - fn test_full_decode() { - env_logger::try_init().ok(); - - for size in 3..50 { - let input = "5\r\n12345\r\n1\r\na\r\n2\r\nbc\r\n3\r\ndef\r\n0\r\n"; - let (output, _) = full_decode_with_size(input, size).unwrap(); - assert_eq!(output, "12345abcdef", "size: {size}"); - - let input = "7\r\nMozilla\r\n9\r\nDeveloper\r\n7\r\nNetwork\r\n0\r\n\r\n"; - let (output, _) = full_decode_with_size(input, size).unwrap(); - assert_eq!(output, "MozillaDeveloperNetwork", "size: {size}"); - - assert!(full_decode_with_size("", size).is_err()); - } - } - - #[test] - fn test_chunk_start() { - assert_decoded((0, "5\r\n12345\r\n"), (Some(0), "12345", None)); - assert_decoded((0, "F\r\n1"), (Some(14 + 2), "1", None)); - assert_decoded((0, "5\r\n123"), (Some(2 + 2), "123", None)); - assert_decoded((0, "1\r\nX\r\n1\r\nX\r\n"), (Some(0), "XX", None)); - assert_decoded((0, "1\r\nX\r\n1\r\nX\r\n1"), (Some(0), "XX", Some("1"))); - assert_decoded((0, 
"FFF\r\n"), (Some(0xfff + 2), "", None)); - assert_decoded((10, "hello"), (Some(5), "hello", None)); - assert_decoded( - (7, "hello\r\nA\r\n world"), - (Some(4 + 2), "hello world", None), - ); - assert_decoded( - (0, "e\r\ntest test test\r\n0\r\n\r\n"), - (None, "test test test", None), - ); - assert_decoded( - (0, "1\r\n_\r\n0\r\n\r\nnext request"), - (None, "_", Some("next request")), - ); - assert_decoded((7, "hello\r\n0\r\n\r\n"), (None, "hello", None)); - } -} diff --git a/http/src/received_body/tests.rs b/http/src/received_body/tests.rs new file mode 100644 index 0000000000..7e1a098101 --- /dev/null +++ b/http/src/received_body/tests.rs @@ -0,0 +1,3 @@ +use super::*; +mod chunked; +mod fixed_length; diff --git a/http/src/received_body/tests/chunked.rs b/http/src/received_body/tests/chunked.rs new file mode 100644 index 0000000000..f7149373ad --- /dev/null +++ b/http/src/received_body/tests/chunked.rs @@ -0,0 +1,220 @@ +use super::{chunk_decode, ReceivedBody, ReceivedBodyState}; +use crate::{http_config::DEFAULT_CONFIG, HttpConfig}; +use encoding_rs::UTF_8; +use futures_lite::{io::Cursor, AsyncRead, AsyncReadExt}; +use trillium_testing::block_on; + +fn assert_decoded( + (remaining, input_data): (u64, &str), + expected_output: (Option, &str, Option<&str>), +) { + let mut buf = input_data.to_string().into_bytes(); + + let (output_state, bytes, unused) = chunk_decode( + remaining, + 0, + 0, + &mut buf, + DEFAULT_CONFIG.received_body_max_len, + ) + .unwrap(); + + assert_eq!( + ( + match output_state { + ReceivedBodyState::Chunked { remaining, .. 
} => Some(remaining), + ReceivedBodyState::End => None, + _ => panic!("unexpected output state {output_state:?}"), + }, + &*String::from_utf8_lossy(&buf[0..bytes]), + unused.as_deref().map(String::from_utf8_lossy).as_deref() + ), + expected_output + ); +} + +async fn read_with_buffers_of_size(reader: &mut R, size: usize) -> crate::Result +where + R: AsyncRead + Unpin, +{ + let mut return_buffer = vec![]; + loop { + let mut buf = vec![0; size]; + match reader.read(&mut buf).await? { + 0 => break Ok(String::from_utf8_lossy(&return_buffer).into()), + bytes_read => return_buffer.extend_from_slice(&buf[..bytes_read]), + } + } +} + +fn new_with_config(input: String, config: &HttpConfig) -> ReceivedBody<'_, Cursor> { + ReceivedBody::new_with_config( + None, + None, + Cursor::new(input), + ReceivedBodyState::Start, + None, + UTF_8, + config, + ) +} + +async fn decode_with_config( + input: String, + poll_size: usize, + config: &HttpConfig, +) -> crate::Result { + let mut rb = new_with_config(input, config); + read_with_buffers_of_size(&mut rb, poll_size).await +} + +async fn decode(input: String, poll_size: usize) -> crate::Result { + decode_with_config(input, poll_size, &DEFAULT_CONFIG).await +} + +#[test] +fn test_full_decode() { + block_on(async { + for size in 3..50 { + let input = "5\r\n12345\r\n1\r\na\r\n2\r\nbc\r\n3\r\ndef\r\n0\r\n"; + let output = decode(input.into(), size).await.unwrap(); + assert_eq!(output, "12345abcdef", "size: {size}"); + + let input = "7\r\nMozilla\r\n9\r\nDeveloper\r\n7\r\nNetwork\r\n0\r\n\r\n"; + let output = decode(input.into(), size).await.unwrap(); + assert_eq!(output, "MozillaDeveloperNetwork", "size: {size}"); + + assert!(decode(String::new(), size).await.is_err()); + } + }); +} + +async fn build_chunked_body(input: String) -> String { + let mut output = Vec::with_capacity(10); + let len = crate::copy( + crate::Body::new_streaming(Cursor::new(input), None), + &mut output, + 16, + ) + .await + .unwrap(); + + 
output.truncate(len.try_into().unwrap()); + String::from_utf8(output).unwrap() +} + +#[test] +fn test_read_buffer_too_short() { + block_on(async { + let input = "test ".repeat(50); + let chunked = build_chunked_body(input.clone()).await; + assert!(chunked.starts_with("FA\r\n")); + + for size in 1..4 { + assert!(decode(chunked.clone(), size).await.is_err()); + } + + for size in 4..10 { + assert_eq!(&decode(chunked.clone(), size).await.unwrap(), &input); + } + }); +} + +#[test] +fn test_max_len() { + block_on(async { + let input = build_chunked_body("test ".repeat(10)).await; + + for size in 4..10 { + assert!(decode_with_config( + input.clone(), + size, + &HttpConfig::default().with_received_body_max_len(5) + ) + .await + .is_err()); + + assert!( + decode_with_config(input.clone(), size, &HttpConfig::default()) + .await + .is_ok() + ); + } + }); +} + +#[test] +fn test_chunk_start() { + assert_decoded((0, "5\r\n12345\r\n"), (Some(0), "12345", None)); + assert_decoded((0, "F\r\n1"), (Some(14 + 2), "1", None)); + assert_decoded((0, "5\r\n123"), (Some(2 + 2), "123", None)); + assert_decoded((0, "1\r\nX\r\n1\r\nX\r\n"), (Some(0), "XX", None)); + assert_decoded((0, "1\r\nX\r\n1\r\nX\r\n1"), (Some(0), "XX", Some("1"))); + assert_decoded((0, "FFF\r\n"), (Some(0xfff + 2), "", None)); + assert_decoded((10, "hello"), (Some(5), "hello", None)); + assert_decoded( + (7, "hello\r\nA\r\n world"), + (Some(4 + 2), "hello world", None), + ); + assert_decoded( + (0, "e\r\ntest test test\r\n0\r\n\r\n"), + (None, "test test test", None), + ); + assert_decoded( + (0, "1\r\n_\r\n0\r\n\r\nnext request"), + (None, "_", Some("next request")), + ); + assert_decoded((7, "hello\r\n0\r\n\r\n"), (None, "hello", None)); +} + +#[test] +fn read_string_and_read_bytes() { + block_on(async { + let content = build_chunked_body("test ".repeat(100)).await; + assert_eq!( + new_with_config(content.clone(), &DEFAULT_CONFIG) + .read_string() + .await + .unwrap() + .len(), + 500 + ); + + assert_eq!( + 
new_with_config(content.clone(), &DEFAULT_CONFIG) +                .read_bytes() +                .await +                .unwrap() +                .len(), +            500 +        ); + +        assert!(new_with_config( +            content.clone(), +            &DEFAULT_CONFIG.with_received_body_max_len(400) +        ) +        .read_string() +        .await +        .is_err()); + +        assert!(new_with_config( +            content.clone(), +            &DEFAULT_CONFIG.with_received_body_max_len(400) +        ) +        .read_bytes() +        .await +        .is_err()); + +        assert!(new_with_config(content.clone(), &DEFAULT_CONFIG) +            .with_max_len(400) +            .read_bytes() +            .await +            .is_err()); + +        assert!(new_with_config(content.clone(), &DEFAULT_CONFIG) +            .with_max_len(400) +            .read_string() +            .await +            .is_err()); +    }); +} diff --git a/http/src/received_body/tests/fixed_length.rs b/http/src/received_body/tests/fixed_length.rs new file mode 100644 index 0000000000..b00b8b33c4 --- /dev/null +++ b/http/src/received_body/tests/fixed_length.rs @@ -0,0 +1,115 @@ +use super::{ReceivedBody, ReceivedBodyState}; +use crate::{http_config::DEFAULT_CONFIG, HttpConfig}; +use encoding_rs::UTF_8; +use futures_lite::{future::block_on, io::Cursor, AsyncRead, AsyncReadExt}; + +fn new_with_config(input: String, config: &HttpConfig) -> ReceivedBody<'static, Cursor<String>> { +    ReceivedBody::new_with_config( +        Some(input.len() as u64), +        None, +        Cursor::new(input), +        ReceivedBodyState::Start, +        None, +        UTF_8, +        config, +    ) +} + +fn decode_with_config( +    input: String, +    poll_size: usize, +    config: &HttpConfig, +) -> crate::Result<String> { +    let mut rb = new_with_config(input, config); + +    block_on(read_with_buffers_of_size(&mut rb, poll_size)) +} + +async fn read_with_buffers_of_size<R>(reader: &mut R, size: usize) -> crate::Result<String> +where +    R: AsyncRead + Unpin, +{ +    let mut return_buffer = vec![]; +    loop { +        let mut buf = vec![0; size]; +        match reader.read(&mut buf).await? 
{ + 0 => break Ok(String::from_utf8_lossy(&return_buffer).into()), + bytes_read => return_buffer.extend_from_slice(&buf[..bytes_read]), + } + } +} + +#[test] +fn test() { + for size in 3..50 { + let input = "12345abcdef"; + let output = decode_with_config(input.into(), size, &DEFAULT_CONFIG).unwrap(); + assert_eq!(output, "12345abcdef", "size: {size}"); + + let input = "MozillaDeveloperNetwork"; + let output = decode_with_config(input.into(), size, &DEFAULT_CONFIG).unwrap(); + assert_eq!(output, "MozillaDeveloperNetwork", "size: {size}"); + + assert!(decode_with_config(String::new(), size, &DEFAULT_CONFIG).is_ok()); + + let input = "MozillaDeveloperNetwork"; + assert!(decode_with_config( + input.into(), + size, + &DEFAULT_CONFIG.with_received_body_max_len(5) + ) + .is_err()); + } +} + +#[test] +fn read_string_and_read_bytes() { + block_on(async { + let content = "test ".repeat(1000); + assert_eq!( + new_with_config(content.clone(), &DEFAULT_CONFIG) + .read_string() + .await + .unwrap() + .len(), + 5000 + ); + + assert_eq!( + new_with_config(content.clone(), &DEFAULT_CONFIG) + .read_bytes() + .await + .unwrap() + .len(), + 5000 + ); + + assert!(new_with_config( + content.clone(), + &DEFAULT_CONFIG.with_received_body_max_len(750) + ) + .read_string() + .await + .is_err()); + + assert!(new_with_config( + content.clone(), + &DEFAULT_CONFIG.with_received_body_max_len(750) + ) + .read_bytes() + .await + .is_err()); + + assert!(new_with_config(content.clone(), &DEFAULT_CONFIG) + .with_max_len(750) + .read_bytes() + .await + .is_err()); + + assert!(new_with_config(content.clone(), &DEFAULT_CONFIG) + .with_max_len(750) + .read_string() + .await + .is_err()); + }); +} diff --git a/http/src/synthetic.rs b/http/src/synthetic.rs index c957e07747..77eee1b3ca 100644 --- a/http/src/synthetic.rs +++ b/http/src/synthetic.rs @@ -1,13 +1,10 @@ use crate::{ - conn::{AfterSend, HttpConfig}, - received_body::ReceivedBodyState, - transport::Transport, - Conn, Headers, KnownHeaderName, 
Method, StateSet, Stopper, Version, + after_send::AfterSend, http_config::DEFAULT_CONFIG, received_body::ReceivedBodyState, + transport::Transport, Conn, Headers, KnownHeaderName, Method, StateSet, Stopper, Version, }; use futures_lite::io::{AsyncRead, AsyncWrite, Result}; use std::{ pin::Pin, - sync::Arc, task::{Context, Poll}, time::Instant, }; @@ -143,7 +140,7 @@ impl Conn { after_send: AfterSend::default(), start_time: Instant::now(), peer_ip: None, - http_config: Arc::new(HttpConfig::default()), + http_config: DEFAULT_CONFIG, } }