diff --git a/src/filters/multipart.rs b/src/filters/multipart.rs index 00948cb22..36c612939 100644 --- a/src/filters/multipart.rs +++ b/src/filters/multipart.rs @@ -27,7 +27,7 @@ const DEFAULT_FORM_DATA_MAX_LENGTH: u64 = 1024 * 1024 * 2; /// Create with the `warp::multipart::form()` function. #[derive(Debug, Clone)] pub struct FormOptions { - max_length: u64, + max_length: Option, } /// A `Stream` of multipart/form-data `Part`s. @@ -50,7 +50,7 @@ pub struct Part { /// in turn is a `Stream` of bytes. pub fn form() -> FormOptions { FormOptions { - max_length: DEFAULT_FORM_DATA_MAX_LENGTH, + max_length: Some(DEFAULT_FORM_DATA_MAX_LENGTH), } } @@ -59,9 +59,10 @@ pub fn form() -> FormOptions { impl FormOptions { /// Set the maximum byte length allowed for this body. /// + /// `max_length(None)` means that maximum byte length is not checked. /// Defaults to 2MB. - pub fn max_length(mut self, max: u64) -> Self { - self.max_length = max; + pub fn max_length(mut self, max: impl Into>) -> Self { + self.max_length = max.into(); self } } @@ -83,8 +84,7 @@ impl FilterBase for FormOptions { future::ready(mime) }); - let filt = super::body::content_length_limit(self.max_length) - .and(boundary) + let filt = boundary .and(super::body::body()) .map(|boundary: String, body| { let body = BodyIoError(body); @@ -93,9 +93,15 @@ impl FilterBase for FormOptions { } }); - let fut = filt.filter(Internal); - - Box::pin(fut) + if let Some(max_length) = self.max_length { + Box::pin( + super::body::content_length_limit(max_length) + .and(filt) + .filter(Internal), + ) + } else { + Box::pin(filt.filter(Internal)) + } } } diff --git a/tests/multipart.rs b/tests/multipart.rs index 3172367bc..ed98232be 100644 --- a/tests/multipart.rs +++ b/tests/multipart.rs @@ -52,3 +52,51 @@ async fn form_fields() { assert_eq!(&vec[0].0, "foo"); assert_eq!(&vec[0].1, b"bar"); } + +#[tokio::test] +async fn max_length_is_enforced() { + let _ = pretty_env_logger::try_init(); + + let route = multipart::form() + .and_then(|_: multipart::FormData| async { Ok::<(), warp::Rejection>(()) }); + + let boundary = "--abcdef1234--"; + + let req = warp::test::request() + .method("POST") + // Note no content-length header + .header("transfer-encoding", "chunked") + .header( + "content-type", + format!("multipart/form-data; boundary={}", boundary), + ); + + // Intentionally don't add body, as it automatically also adds + // content-length header + let resp = req.filter(&route).await; + assert!(resp.is_err()); +} + +#[tokio::test] +async fn max_length_can_be_disabled() { + let _ = pretty_env_logger::try_init(); + + let route = multipart::form() + .max_length(None) + .and_then(|_: multipart::FormData| async { Ok::<(), warp::Rejection>(()) }); + + let boundary = "--abcdef1234--"; + + let req = warp::test::request() + .method("POST") + .header("transfer-encoding", "chunked") + .header( + "content-type", + format!("multipart/form-data; boundary={}", boundary), + ); + + // Intentionally don't add body, as it automatically also adds + // content-length header + let resp = req.filter(&route).await; + assert!(resp.is_ok()); +}