diff --git a/CHANGELOG.md b/CHANGELOG.md index 2f3a26d1bf..9c7dadddf9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,6 +30,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Removed `extract::UrlParams` and `extract::UrlParamsMap`. Use `extract::Path` instead - `EmptyRouter` now requires the response body to implement `Send + Sync + 'static'` ([#108](https://github.com/tokio-rs/axum/pull/108)) - `ServiceExt` has been removed and its methods have been moved to `RoutingDsl` ([#160](https://github.com/tokio-rs/axum/pull/160)) +- `extractor_middleware` now requires `RequestBody: Default` ([#167](https://github.com/tokio-rs/axum/pull/167)) +- Convert `RequestAlreadyExtracted` to an enum with each possible error variant ([#167](https://github.com/tokio-rs/axum/pull/167)) - These future types have been moved - `extract::extractor_middleware::ExtractorMiddlewareResponseFuture` moved to `extract::extractor_middleware::future::ResponseFuture` ([#133](https://github.com/tokio-rs/axum/pull/133)) diff --git a/Cargo.toml b/Cargo.toml index 48496a2cf9..887bd7a6a1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,7 +22,7 @@ bitflags = "1.0" bytes = "1.0" futures-util = "0.3" http = "0.2" -http-body = "0.4.2" +http-body = "0.4.3" hyper = { version = "0.14", features = ["server", "tcp", "http1", "stream"] } pin-project-lite = "0.2.7" regex = "1.5" diff --git a/src/error.rs b/src/error.rs index 849d5ce848..c21a443c38 100644 --- a/src/error.rs +++ b/src/error.rs @@ -13,6 +13,16 @@ impl Error { inner: error.into(), } } + + pub(crate) fn downcast(self) -> Result + where + T: StdError + 'static, + { + match self.inner.downcast::() { + Ok(t) => Ok(*t), + Err(err) => Err(*err.downcast().unwrap()), + } + } } impl fmt::Display for Error { diff --git a/src/extract/extractor_middleware.rs b/src/extract/extractor_middleware.rs index 7f28f3cc9a..76d26c7858 100644 --- a/src/extract/extractor_middleware.rs +++ b/src/extract/extractor_middleware.rs @@ -152,7 +152,7 @@ where impl Service> for ExtractorMiddleware where E: FromRequest + 'static, - ReqBody: Send + 'static, + ReqBody: Default + Send + 'static, S: Service, Response = Response> + Clone, ResBody: http_body::Body + Send + Sync + 'static, ResBody::Error: Into, @@ -212,6 +212,7 @@ impl Future for ResponseFuture where E: FromRequest, S: Service, Response = Response>, + ReqBody: Default, ResBody: http_body::Body + Send + Sync + 'static, ResBody::Error: Into, { @@ -223,12 +224,13 @@ where let new_state = match this.state.as_mut().project() { StateProj::Extracting { future } => { - let (mut req, extracted) = ready!(future.as_mut().poll(cx)); + let (req, extracted) = ready!(future.as_mut().poll(cx)); match extracted { Ok(_) => { let mut svc = this.svc.take().expect("future polled after completion"); - let future = svc.call(req.into_request()); + let req = req.try_into_request().unwrap_or_default(); + let future = svc.call(req); State::Call { future } } Err(err) => { diff --git a/src/extract/mod.rs b/src/extract/mod.rs index 635bfd8959..bec4954af3 100644 --- a/src/extract/mod.rs +++ b/src/extract/mod.rs @@ -244,7 +244,7 @@ //! //! [`body::Body`]: crate::body::Body -use crate::response::IntoResponse; +use crate::{response::IntoResponse, Error}; use async_trait::async_trait; use http::{header, Extensions, HeaderMap, Method, Request, Uri, Version}; use rejection::*; @@ -397,32 +397,47 @@ impl RequestParts { } } - #[allow(clippy::wrong_self_convention)] - pub(crate) fn into_request(&mut self) -> Request { + // this method uses `Error` since we might make this method public one day and then + // `Error` is more flexible. + pub(crate) fn try_into_request(self) -> Result, Error> { let Self { method, uri, version, - headers, - extensions, - body, + mut headers, + mut extensions, + mut body, } = self; - let mut req = Request::new(body.take().expect("body already extracted")); + let mut req = if let Some(body) = body.take() { + Request::new(body) + } else { + return Err(Error::new(RequestAlreadyExtracted::BodyAlreadyExtracted( + BodyAlreadyExtracted, + ))); + }; - *req.method_mut() = method.clone(); - *req.uri_mut() = uri.clone(); - *req.version_mut() = *version; + *req.method_mut() = method; + *req.uri_mut() = uri; + *req.version_mut() = version; if let Some(headers) = headers.take() { *req.headers_mut() = headers; + } else { + return Err(Error::new( + RequestAlreadyExtracted::HeadersAlreadyExtracted(HeadersAlreadyExtracted), + )); } if let Some(extensions) = extensions.take() { *req.extensions_mut() = extensions; + } else { + return Err(Error::new( + RequestAlreadyExtracted::ExtensionsAlreadyExtracted(ExtensionsAlreadyExtracted), + )); } - req + Ok(req) } /// Gets a reference the request method. diff --git a/src/extract/rejection.rs b/src/extract/rejection.rs index 834a084355..5efcb3aee4 100644 --- a/src/extract/rejection.rs +++ b/src/extract/rejection.rs @@ -13,14 +13,15 @@ use tower::BoxError; define_rejection! { #[status = INTERNAL_SERVER_ERROR] #[body = "Extensions taken by other extractor"] - /// Rejection used if the method has been taken by another extractor. + /// Rejection used if the request extension has been taken by another + /// extractor. pub struct ExtensionsAlreadyExtracted; } define_rejection! { #[status = INTERNAL_SERVER_ERROR] #[body = "Headers taken by other extractor"] - /// Rejection used if the URI has been taken by another extractor. + /// Rejection used if the headers has been taken by another extractor. pub struct HeadersAlreadyExtracted; } @@ -94,13 +95,6 @@ define_rejection! { pub struct BodyAlreadyExtracted; } -define_rejection! { - #[status = INTERNAL_SERVER_ERROR] - #[body = "Cannot have two `Request<_>` extractors for a single handler"] - /// Rejection type used if you try and extract the request more than once. - pub struct RequestAlreadyExtracted; -} - define_rejection! { #[status = BAD_REQUEST] #[body = "Form requests must have `Content-Type: x-www-form-urlencoded`"] @@ -272,6 +266,19 @@ composite_rejection! { } } +composite_rejection! { + /// Rejection used for [`Request<_>`]. + /// + /// Contains one variant for each way the [`Request<_>`] extractor can fail. + /// + /// [`Request<_>`]: http::Request + pub enum RequestAlreadyExtracted { + BodyAlreadyExtracted, + HeadersAlreadyExtracted, + ExtensionsAlreadyExtracted, + } +} + /// Rejection used for [`ContentLengthLimit`](super::ContentLengthLimit). /// /// Contains one variant for each way the diff --git a/src/extract/request_parts.rs b/src/extract/request_parts.rs index fbc83d87e6..55c720293f 100644 --- a/src/extract/request_parts.rs +++ b/src/extract/request_parts.rs @@ -18,21 +18,29 @@ where type Rejection = RequestAlreadyExtracted; async fn from_request(req: &mut RequestParts) -> Result { - let RequestParts { - method: _, - uri: _, - version: _, - headers, - extensions, - body, - } = req; - - let all_parts = extensions.as_ref().zip(body.as_ref()).zip(headers.as_ref()); - - if all_parts.is_some() { - Ok(req.into_request()) - } else { - Err(RequestAlreadyExtracted) + let req = std::mem::replace( + req, + RequestParts { + method: req.method.clone(), + version: req.version, + uri: req.uri.clone(), + headers: None, + extensions: None, + body: None, + }, + ); + + let err = match req.try_into_request() { + Ok(req) => return Ok(req), + Err(err) => err, + }; + + match err.downcast::() { + Ok(err) => return Err(err), + Err(err) => unreachable!( + "Unexpected error type from `try_into_request`: `{:?}`. This is a bug in axum, please file an issue", + err, + ), } } } @@ -251,3 +259,33 @@ where Ok(string) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::{body::Body, prelude::*, tests::*}; + use http::StatusCode; + + #[tokio::test] + async fn multiple_request_extractors() { + async fn handler(_: Request, _: Request) {} + + let app = route("/", post(handler)); + + let addr = run_in_background(app).await; + + let client = reqwest::Client::new(); + + let res = client + .post(format!("http://{}", addr)) + .body("hi there") + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR); + assert_eq!( + res.text().await.unwrap(), + "Cannot have two request body extractors for a single handler" + ); + } +} diff --git a/src/tests/mod.rs b/src/tests/mod.rs index 58f1d031f0..9de386c5b1 100644 --- a/src/tests/mod.rs +++ b/src/tests/mod.rs @@ -605,7 +605,7 @@ async fn wrong_method_service() { } /// Run a `tower::Service` in the background and get a URI for it. -async fn run_in_background(svc: S) -> SocketAddr +pub(crate) async fn run_in_background(svc: S) -> SocketAddr where S: Service, Response = Response> + Clone + Send + 'static, ResBody: http_body::Body + Send + 'static,