Skip to content

Commit

Permalink
Clean up RequestParts API
Browse files Browse the repository at this point in the history
In http-body 0.4.3 `BoxBody` implements `Default`. This allows us to
clean up the API of `RequestParts` quite a bit.
  • Loading branch information
davidpdrsn committed Aug 8, 2021
1 parent bc27b09 commit d8a6c76
Show file tree
Hide file tree
Showing 8 changed files with 112 additions and 40 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Removed `extract::UrlParams` and `extract::UrlParamsMap`. Use `extract::Path` instead
- `EmptyRouter` now requires the response body to implement `Send + Sync + 'static'` ([#108](https://github.com/tokio-rs/axum/pull/108))
- `ServiceExt` has been removed and its methods have been moved to `RoutingDsl` ([#160](https://github.com/tokio-rs/axum/pull/160))
- `extractor_middleware` now requires `RequestBody: Default` ([#167](https://github.com/tokio-rs/axum/pull/167))
- Convert `RequestAlreadyExtracted` to an enum with each possible error variant ([#167](https://github.com/tokio-rs/axum/pull/167))
- These future types have been moved
- `extract::extractor_middleware::ExtractorMiddlewareResponseFuture` moved
to `extract::extractor_middleware::future::ResponseFuture` ([#133](https://github.com/tokio-rs/axum/pull/133))
Expand Down
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ bitflags = "1.0"
bytes = "1.0"
futures-util = "0.3"
http = "0.2"
http-body = "0.4.2"
http-body = "0.4.3"
hyper = { version = "0.14", features = ["server", "tcp", "http1", "stream"] }
pin-project-lite = "0.2.7"
regex = "1.5"
Expand Down
10 changes: 10 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,16 @@ impl Error {
inner: error.into(),
}
}

pub(crate) fn downcast<T>(self) -> Result<T, Self>
where
T: StdError + 'static,
{
match self.inner.downcast::<T>() {
Ok(t) => Ok(*t),
Err(err) => Err(*err.downcast().unwrap()),
}
}
}

impl fmt::Display for Error {
Expand Down
8 changes: 5 additions & 3 deletions src/extract/extractor_middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ where
impl<S, E, ReqBody, ResBody> Service<Request<ReqBody>> for ExtractorMiddleware<S, E>
where
E: FromRequest<ReqBody> + 'static,
ReqBody: Send + 'static,
ReqBody: Default + Send + 'static,
S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
ResBody: http_body::Body<Data = Bytes> + Send + Sync + 'static,
ResBody::Error: Into<BoxError>,
Expand Down Expand Up @@ -212,6 +212,7 @@ impl<ReqBody, S, E, ResBody> Future for ResponseFuture<ReqBody, S, E>
where
E: FromRequest<ReqBody>,
S: Service<Request<ReqBody>, Response = Response<ResBody>>,
ReqBody: Default,
ResBody: http_body::Body<Data = Bytes> + Send + Sync + 'static,
ResBody::Error: Into<BoxError>,
{
Expand All @@ -223,12 +224,13 @@ where

let new_state = match this.state.as_mut().project() {
StateProj::Extracting { future } => {
let (mut req, extracted) = ready!(future.as_mut().poll(cx));
let (req, extracted) = ready!(future.as_mut().poll(cx));

match extracted {
Ok(_) => {
let mut svc = this.svc.take().expect("future polled after completion");
let future = svc.call(req.into_request());
let req = req.try_into_request().unwrap_or_default();
let future = svc.call(req);
State::Call { future }
}
Err(err) => {
Expand Down
35 changes: 24 additions & 11 deletions src/extract/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@
//!
//! [`body::Body`]: crate::body::Body
use crate::response::IntoResponse;
use crate::{response::IntoResponse, Error};
use async_trait::async_trait;
use http::{header, Extensions, HeaderMap, Method, Request, Uri, Version};
use rejection::*;
Expand Down Expand Up @@ -397,32 +397,45 @@ impl<B> RequestParts<B> {
}
}

#[allow(clippy::wrong_self_convention)]
pub(crate) fn into_request(&mut self) -> Request<B> {
pub(crate) fn try_into_request(self) -> Result<Request<B>, Error> {
let Self {
method,
uri,
version,
headers,
extensions,
body,
mut headers,
mut extensions,
mut body,
} = self;

let mut req = Request::new(body.take().expect("body already extracted"));
let mut req = if let Some(body) = body.take() {
Request::new(body)
} else {
return Err(Error::new(RequestAlreadyExtracted::BodyAlreadyExtracted(
BodyAlreadyExtracted,
)));
};

*req.method_mut() = method.clone();
*req.uri_mut() = uri.clone();
*req.version_mut() = *version;
*req.method_mut() = method;
*req.uri_mut() = uri;
*req.version_mut() = version;

if let Some(headers) = headers.take() {
*req.headers_mut() = headers;
} else {
return Err(Error::new(
RequestAlreadyExtracted::HeadersAlreadyExtracted(HeadersAlreadyExtracted),
));
}

if let Some(extensions) = extensions.take() {
*req.extensions_mut() = extensions;
} else {
return Err(Error::new(
RequestAlreadyExtracted::ExtensionsAlreadyExtracted(ExtensionsAlreadyExtracted),
));
}

req
Ok(req)
}

/// Gets a reference the request method.
Expand Down
25 changes: 16 additions & 9 deletions src/extract/rejection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@ use tower::BoxError;
define_rejection! {
#[status = INTERNAL_SERVER_ERROR]
#[body = "Extensions taken by other extractor"]
/// Rejection used if the method has been taken by another extractor.
/// Rejection used if the request extension has been taken by another
/// extractor.
pub struct ExtensionsAlreadyExtracted;
}

define_rejection! {
#[status = INTERNAL_SERVER_ERROR]
#[body = "Headers taken by other extractor"]
/// Rejection used if the URI has been taken by another extractor.
/// Rejection used if the headers has been taken by another extractor.
pub struct HeadersAlreadyExtracted;
}

Expand Down Expand Up @@ -94,13 +95,6 @@ define_rejection! {
pub struct BodyAlreadyExtracted;
}

define_rejection! {
#[status = INTERNAL_SERVER_ERROR]
#[body = "Cannot have two `Request<_>` extractors for a single handler"]
/// Rejection type used if you try and extract the request more than once.
pub struct RequestAlreadyExtracted;
}

define_rejection! {
#[status = BAD_REQUEST]
#[body = "Form requests must have `Content-Type: x-www-form-urlencoded`"]
Expand Down Expand Up @@ -272,6 +266,19 @@ composite_rejection! {
}
}

composite_rejection! {
/// Rejection used for [`Request<_>`].
///
/// Contains one variant for each way the [`Request<_>`] extractor can fail.
///
/// [`Request<_>`]: http::Request
pub enum RequestAlreadyExtracted {
BodyAlreadyExtracted,
HeadersAlreadyExtracted,
ExtensionsAlreadyExtracted,
}
}

/// Rejection used for [`ContentLengthLimit`](super::ContentLengthLimit).
///
/// Contains one variant for each way the
Expand Down
68 changes: 53 additions & 15 deletions src/extract/request_parts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,29 @@ where
type Rejection = RequestAlreadyExtracted;

async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let RequestParts {
method: _,
uri: _,
version: _,
headers,
extensions,
body,
} = req;

let all_parts = extensions.as_ref().zip(body.as_ref()).zip(headers.as_ref());

if all_parts.is_some() {
Ok(req.into_request())
} else {
Err(RequestAlreadyExtracted)
let req = std::mem::replace(
req,
RequestParts {
method: req.method.clone(),
version: req.version,
uri: req.uri.clone(),
headers: None,
extensions: None,
body: None,
},
);

let err = match req.try_into_request() {
Ok(req) => return Ok(req),
Err(err) => err,
};

match err.downcast::<RequestAlreadyExtracted>() {
Ok(err) => return Err(err),
Err(err) => unreachable!(
"Unexpected error type from `try_into_request`: `{:?}`. This is a bug in axum, please file an issue",
err,
),
}
}
}
Expand Down Expand Up @@ -251,3 +259,33 @@ where
Ok(string)
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::{body::Body, prelude::*, tests::*};
use http::StatusCode;

#[tokio::test]
async fn multiple_request_extractors() {
async fn handler(_: Request<Body>, _: Request<Body>) {}

let app = route("/", post(handler));

let addr = run_in_background(app).await;

let client = reqwest::Client::new();

let res = client
.post(format!("http://{}", addr))
.body("hi there")
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
assert_eq!(
res.text().await.unwrap(),
"Cannot have two request body extractors for a single handler"
);
}
}
2 changes: 1 addition & 1 deletion src/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,7 @@ async fn wrong_method_service() {
}

/// Run a `tower::Service` in the background and get a URI for it.
async fn run_in_background<S, ResBody>(svc: S) -> SocketAddr
pub(crate) async fn run_in_background<S, ResBody>(svc: S) -> SocketAddr
where
S: Service<Request<Body>, Response = Response<ResBody>> + Clone + Send + 'static,
ResBody: http_body::Body + Send + 'static,
Expand Down

0 comments on commit d8a6c76

Please sign in to comment.