Skip to content

Commit

Permalink
Make sure nested services still see full URI (#166)
Browse files Browse the repository at this point in the history
They'd previously see the nested URI as we mutated the request. Now we
always route based on the nested URI (if present) without mutating the
request. Also meant we could get rid of `OriginalUri` which is nice.
  • Loading branch information
davidpdrsn authored Aug 8, 2021
1 parent 0674c91 commit bc27b09
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 28 deletions.
8 changes: 2 additions & 6 deletions src/extract/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@
//!
//! [`body::Body`]: crate::body::Body
use crate::{response::IntoResponse, routing::OriginalUri};
use crate::response::IntoResponse;
use async_trait::async_trait;
use http::{header, Extensions, HeaderMap, Method, Request, Uri, Version};
use rejection::*;
Expand Down Expand Up @@ -378,7 +378,7 @@ impl<B> RequestParts<B> {
let (
http::request::Parts {
method,
mut uri,
uri,
version,
headers,
extensions,
Expand All @@ -387,10 +387,6 @@ impl<B> RequestParts<B> {
body,
) = req.into_parts();

if let Some(original_uri) = extensions.get::<OriginalUri>() {
uri = original_uri.0.clone();
};

RequestParts {
method,
uri,
Expand Down
46 changes: 24 additions & 22 deletions src/routing/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ where
}

fn call(&mut self, mut req: Request<B>) -> Self::Future {
if let Some(captures) = self.pattern.full_match(req.uri().path()) {
if let Some(captures) = self.pattern.full_match(&req) {
insert_url_params(&mut req, captures);
let fut = self.svc.clone().oneshot(req);
RouteFuture::a(fut)
Expand Down Expand Up @@ -606,8 +606,8 @@ impl PathPattern {
}))
}

pub(crate) fn full_match(&self, path: &str) -> Option<Captures> {
self.do_match(path).and_then(|match_| {
pub(crate) fn full_match<B>(&self, req: &Request<B>) -> Option<Captures> {
self.do_match(req).and_then(|match_| {
if match_.full_match {
Some(match_.captures)
} else {
Expand All @@ -616,12 +616,18 @@ impl PathPattern {
})
}

pub(crate) fn prefix_match<'a>(&self, path: &'a str) -> Option<(&'a str, Captures)> {
self.do_match(path)
pub(crate) fn prefix_match<'a, B>(&self, req: &'a Request<B>) -> Option<(&'a str, Captures)> {
self.do_match(req)
.map(|match_| (match_.matched, match_.captures))
}

fn do_match<'a>(&self, path: &'a str) -> Option<Match<'a>> {
fn do_match<'a, B>(&self, req: &'a Request<B>) -> Option<Match<'a>> {
let path = if let Some(nested_uri) = req.extensions().get::<NestedUri>() {
nested_uri.0.path()
} else {
req.uri().path()
};

self.0.full_path_regex.captures(path).map(|captures| {
let matched = captures.get(0).unwrap();
let full_match = matched.as_str() == path;
Expand Down Expand Up @@ -864,16 +870,15 @@ where
}

fn call(&mut self, mut req: Request<B>) -> Self::Future {
if req.extensions().get::<OriginalUri>().is_none() {
let original_uri = OriginalUri(req.uri().clone());
req.extensions_mut().insert(original_uri);
}
let f = if let Some((prefix, captures)) = self.pattern.prefix_match(&req) {
let uri = if let Some(nested_uri) = req.extensions().get::<NestedUri>() {
&nested_uri.0
} else {
req.uri()
};

let f = if let Some((prefix, captures)) = self.pattern.prefix_match(req.uri().path()) {
let without_prefix = strip_prefix(req.uri(), prefix);
req.extensions_mut()
.insert(NestedUri(without_prefix.clone()));
*req.uri_mut() = without_prefix;
let without_prefix = strip_prefix(uri, prefix);
req.extensions_mut().insert(NestedUri(without_prefix));

insert_url_params(&mut req, captures);
let fut = self.svc.clone().oneshot(req);
Expand All @@ -887,11 +892,6 @@ where
}
}

/// `Nested` changes the incoming requests URI. This will be saved as an
/// extension so extractors can still access the original URI.
#[derive(Clone)]
pub(crate) struct OriginalUri(pub(crate) Uri);

fn strip_prefix(uri: &Uri, prefix: &str) -> Uri {
let path_and_query = if let Some(path_and_query) = uri.path_and_query() {
let new_path = if let Some(path) = path_and_query.path().strip_prefix(prefix) {
Expand Down Expand Up @@ -953,8 +953,9 @@ mod tests {

fn assert_match(route_spec: &'static str, path: &'static str) {
let route = PathPattern::new(route_spec);
let req = Request::builder().uri(path).body(()).unwrap();
assert!(
route.full_match(path).is_some(),
route.full_match(&req).is_some(),
"`{}` doesn't match `{}`",
path,
route_spec
Expand All @@ -963,8 +964,9 @@ mod tests {

fn refute_match(route_spec: &'static str, path: &'static str) {
let route = PathPattern::new(route_spec);
let req = Request::builder().uri(path).body(()).unwrap();
assert!(
route.full_match(path).is_none(),
route.full_match(&req).is_none(),
"`{}` did match `{}` (but shouldn't)",
path,
route_spec
Expand Down
33 changes: 33 additions & 0 deletions src/tests/nest.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use super::*;
use crate::body::box_body;
use std::collections::HashMap;

#[tokio::test]
Expand Down Expand Up @@ -149,13 +150,15 @@ async fn nested_url_extractor() {
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await.unwrap(), "/foo/bar/baz");

let res = client
.get(format!("http://{}/foo/bar/qux", addr))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await.unwrap(), "/foo/bar/qux");
}

Expand All @@ -181,5 +184,35 @@ async fn nested_url_nested_extractor() {
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await.unwrap(), "/baz");
}

#[tokio::test]
async fn nested_service_sees_original_uri() {
let app = nest(
"/foo",
nest(
"/bar",
route(
"/baz",
service_fn(|req: Request<Body>| async move {
let body = box_body(Body::from(req.uri().to_string()));
Ok::<_, Infallible>(Response::new(body))
}),
),
),
);

let addr = run_in_background(app).await;

let client = reqwest::Client::new();

let res = client
.get(format!("http://{}/foo/bar/baz", addr))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await.unwrap(), "/foo/bar/baz");
}

0 comments on commit bc27b09

Please sign in to comment.