Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up RequestParts API #167

Merged
merged 1 commit into from
Aug 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Removed `extract::UrlParams` and `extract::UrlParamsMap`. Use `extract::Path` instead
- `EmptyRouter` now requires the response body to implement `Send + Sync + 'static'` ([#108](https://github.com/tokio-rs/axum/pull/108))
- `ServiceExt` has been removed and its methods have been moved to `RoutingDsl` ([#160](https://github.com/tokio-rs/axum/pull/160))
- `extractor_middleware` now requires `RequestBody: Default` ([#167](https://github.com/tokio-rs/axum/pull/167))
- Convert `RequestAlreadyExtracted` to an enum with each possible error variant ([#167](https://github.com/tokio-rs/axum/pull/167))
- These future types have been moved
- `extract::extractor_middleware::ExtractorMiddlewareResponseFuture` moved
to `extract::extractor_middleware::future::ResponseFuture` ([#133](https://github.com/tokio-rs/axum/pull/133))
Expand Down
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ bitflags = "1.0"
bytes = "1.0"
futures-util = "0.3"
http = "0.2"
http-body = "0.4.2"
http-body = "0.4.3"
hyper = { version = "0.14", features = ["server", "tcp", "http1", "stream"] }
pin-project-lite = "0.2.7"
regex = "1.5"
Expand Down
10 changes: 10 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,16 @@ impl Error {
inner: error.into(),
}
}

pub(crate) fn downcast<T>(self) -> Result<T, Self>
where
T: StdError + 'static,
{
match self.inner.downcast::<T>() {
Ok(t) => Ok(*t),
Err(err) => Err(*err.downcast().unwrap()),
}
}
}

impl fmt::Display for Error {
Expand Down
8 changes: 5 additions & 3 deletions src/extract/extractor_middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ where
impl<S, E, ReqBody, ResBody> Service<Request<ReqBody>> for ExtractorMiddleware<S, E>
where
E: FromRequest<ReqBody> + 'static,
ReqBody: Send + 'static,
ReqBody: Default + Send + 'static,
S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
ResBody: http_body::Body<Data = Bytes> + Send + Sync + 'static,
ResBody::Error: Into<BoxError>,
Expand Down Expand Up @@ -212,6 +212,7 @@ impl<ReqBody, S, E, ResBody> Future for ResponseFuture<ReqBody, S, E>
where
E: FromRequest<ReqBody>,
S: Service<Request<ReqBody>, Response = Response<ResBody>>,
ReqBody: Default,
ResBody: http_body::Body<Data = Bytes> + Send + Sync + 'static,
ResBody::Error: Into<BoxError>,
{
Expand All @@ -223,12 +224,13 @@ where

let new_state = match this.state.as_mut().project() {
StateProj::Extracting { future } => {
let (mut req, extracted) = ready!(future.as_mut().poll(cx));
let (req, extracted) = ready!(future.as_mut().poll(cx));

match extracted {
Ok(_) => {
let mut svc = this.svc.take().expect("future polled after completion");
let future = svc.call(req.into_request());
let req = req.try_into_request().unwrap_or_default();
let future = svc.call(req);
State::Call { future }
}
Err(err) => {
Expand Down
37 changes: 26 additions & 11 deletions src/extract/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@
//!
//! [`body::Body`]: crate::body::Body

use crate::response::IntoResponse;
use crate::{response::IntoResponse, Error};
use async_trait::async_trait;
use http::{header, Extensions, HeaderMap, Method, Request, Uri, Version};
use rejection::*;
Expand Down Expand Up @@ -397,32 +397,47 @@ impl<B> RequestParts<B> {
}
}

#[allow(clippy::wrong_self_convention)]
pub(crate) fn into_request(&mut self) -> Request<B> {
// this method uses `Error` since we might make this method public one day and then
// `Error` is more flexible.
pub(crate) fn try_into_request(self) -> Result<Request<B>, Error> {
let Self {
method,
uri,
version,
headers,
extensions,
body,
mut headers,
mut extensions,
mut body,
} = self;

let mut req = Request::new(body.take().expect("body already extracted"));
let mut req = if let Some(body) = body.take() {
Request::new(body)
} else {
return Err(Error::new(RequestAlreadyExtracted::BodyAlreadyExtracted(
BodyAlreadyExtracted,
)));
};

*req.method_mut() = method.clone();
*req.uri_mut() = uri.clone();
*req.version_mut() = *version;
*req.method_mut() = method;
*req.uri_mut() = uri;
*req.version_mut() = version;

if let Some(headers) = headers.take() {
*req.headers_mut() = headers;
} else {
return Err(Error::new(
RequestAlreadyExtracted::HeadersAlreadyExtracted(HeadersAlreadyExtracted),
));
}

if let Some(extensions) = extensions.take() {
*req.extensions_mut() = extensions;
} else {
return Err(Error::new(
RequestAlreadyExtracted::ExtensionsAlreadyExtracted(ExtensionsAlreadyExtracted),
));
}

req
Ok(req)
}

/// Gets a reference the request method.
Expand Down
25 changes: 16 additions & 9 deletions src/extract/rejection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@ use tower::BoxError;
define_rejection! {
#[status = INTERNAL_SERVER_ERROR]
#[body = "Extensions taken by other extractor"]
/// Rejection used if the method has been taken by another extractor.
/// Rejection used if the request extension has been taken by another
/// extractor.
pub struct ExtensionsAlreadyExtracted;
}

define_rejection! {
#[status = INTERNAL_SERVER_ERROR]
#[body = "Headers taken by other extractor"]
/// Rejection used if the URI has been taken by another extractor.
/// Rejection used if the headers has been taken by another extractor.
pub struct HeadersAlreadyExtracted;
}

Expand Down Expand Up @@ -94,13 +95,6 @@ define_rejection! {
pub struct BodyAlreadyExtracted;
}

define_rejection! {
#[status = INTERNAL_SERVER_ERROR]
#[body = "Cannot have two `Request<_>` extractors for a single handler"]
/// Rejection type used if you try and extract the request more than once.
pub struct RequestAlreadyExtracted;
}

define_rejection! {
#[status = BAD_REQUEST]
#[body = "Form requests must have `Content-Type: x-www-form-urlencoded`"]
Expand Down Expand Up @@ -272,6 +266,19 @@ composite_rejection! {
}
}

composite_rejection! {
/// Rejection used for [`Request<_>`].
///
/// Contains one variant for each way the [`Request<_>`] extractor can fail.
///
/// [`Request<_>`]: http::Request
pub enum RequestAlreadyExtracted {
BodyAlreadyExtracted,
HeadersAlreadyExtracted,
ExtensionsAlreadyExtracted,
}
}

/// Rejection used for [`ContentLengthLimit`](super::ContentLengthLimit).
///
/// Contains one variant for each way the
Expand Down
68 changes: 53 additions & 15 deletions src/extract/request_parts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,29 @@ where
type Rejection = RequestAlreadyExtracted;

async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let RequestParts {
method: _,
uri: _,
version: _,
headers,
extensions,
body,
} = req;

let all_parts = extensions.as_ref().zip(body.as_ref()).zip(headers.as_ref());

if all_parts.is_some() {
Ok(req.into_request())
} else {
Err(RequestAlreadyExtracted)
let req = std::mem::replace(
req,
RequestParts {
method: req.method.clone(),
version: req.version,
uri: req.uri.clone(),
headers: None,
extensions: None,
body: None,
},
);

let err = match req.try_into_request() {
Ok(req) => return Ok(req),
Err(err) => err,
};

match err.downcast::<RequestAlreadyExtracted>() {
Ok(err) => return Err(err),
Err(err) => unreachable!(
"Unexpected error type from `try_into_request`: `{:?}`. This is a bug in axum, please file an issue",
err,
),
}
}
}
Expand Down Expand Up @@ -251,3 +259,33 @@ where
Ok(string)
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::{body::Body, prelude::*, tests::*};
use http::StatusCode;

#[tokio::test]
async fn multiple_request_extractors() {
async fn handler(_: Request<Body>, _: Request<Body>) {}

let app = route("/", post(handler));

let addr = run_in_background(app).await;

let client = reqwest::Client::new();

let res = client
.post(format!("http://{}", addr))
.body("hi there")
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
assert_eq!(
res.text().await.unwrap(),
"Cannot have two request body extractors for a single handler"
);
}
}
2 changes: 1 addition & 1 deletion src/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,7 @@ async fn wrong_method_service() {
}

/// Run a `tower::Service` in the background and get a URI for it.
async fn run_in_background<S, ResBody>(svc: S) -> SocketAddr
pub(crate) async fn run_in_background<S, ResBody>(svc: S) -> SocketAddr
where
S: Service<Request<Body>, Response = Response<ResBody>> + Clone + Send + 'static,
ResBody: http_body::Body + Send + 'static,
Expand Down