From 178eeba56611996caa4079c9b772e3deba1cce41 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20Ml=C3=A1dek?= Date: Fri, 22 Mar 2024 20:35:38 +0100 Subject: [PATCH] docs, tests, and nesting for matchit --- axum/src/docs/routing/without_v07_checks.md | 43 +++++++++++ axum/src/extract/matched_path.rs | 40 ++++++++++ axum/src/routing/mod.rs | 7 ++ axum/src/routing/path_router.rs | 84 ++++++++++++++++----- axum/src/routing/tests/mod.rs | 16 ++++ axum/src/routing/tests/nest.rs | 16 ++++ 6 files changed, 186 insertions(+), 20 deletions(-) create mode 100644 axum/src/docs/routing/without_v07_checks.md diff --git a/axum/src/docs/routing/without_v07_checks.md b/axum/src/docs/routing/without_v07_checks.md new file mode 100644 index 0000000000..1eb377e62b --- /dev/null +++ b/axum/src/docs/routing/without_v07_checks.md @@ -0,0 +1,43 @@ +Turn off checks for compatibility with route matching syntax from 0.7. + +This allows usage of paths starting with a colon `:` or an asterisk `*` which are otherwise prohibited. + +# Example + +```rust +use axum::{ + routing::get, + Router, +}; + +let app = Router::<()>::new() + .without_v07_checks() + .route("/:colon", get(|| async {})) + .route("/*asterisk", get(|| async {})); + +// Our app now accepts +// - GET /:colon +// - GET /*asterisk +# let _: Router = app; +``` + +Adding such routes without calling this method first will panic. + +```rust,should_panic +use axum::{ + routing::get, + Router, +}; + +// This panics... +let app = Router::<()>::new() + .route("/:colon", get(|| async {})); +``` + +# Merging + +When two routers are merged, v0.7 checks are disabled if both of the two routers had them also disabled. + +# Nesting + +Each router needs to have the checks explicitly disabled. Nesting a router with the checks either enabled or disabled has no effect on the outer router. diff --git a/axum/src/extract/matched_path.rs b/axum/src/extract/matched_path.rs index 3e9d29fd33..d51d36c2fe 100644 --- a/axum/src/extract/matched_path.rs +++ b/axum/src/extract/matched_path.rs @@ -351,4 +351,44 @@ mod tests { let res = client.get("/foo/bar").await; assert_eq!(res.status(), StatusCode::OK); } + + #[crate::test] + async fn matching_colon() { + let app = Router::new().without_v07_checks().route( + "/:foo", + get(|path: MatchedPath| async move { path.as_str().to_owned() }), + ); + + let client = TestClient::new(app); + + let res = client.get("/:foo").await; + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.text().await, "/:foo"); + + let res = client.get("/:bar").await; + assert_eq!(res.status(), StatusCode::NOT_FOUND); + + let res = client.get("/foo").await; + assert_eq!(res.status(), StatusCode::NOT_FOUND); + } + + #[crate::test] + async fn matching_asterisk() { + let app = Router::new().without_v07_checks().route( + "/*foo", + get(|path: MatchedPath| async move { path.as_str().to_owned() }), + ); + + let client = TestClient::new(app); + + let res = client.get("/*foo").await; + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.text().await, "/*foo"); + + let res = client.get("/*bar").await; + assert_eq!(res.status(), StatusCode::NOT_FOUND); + + let res = client.get("/foo").await; + assert_eq!(res.status(), StatusCode::NOT_FOUND); + } } diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index 68a3f295d7..58babe9f5f 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -154,6 +154,13 @@ where } } + #[doc = include_str!("../docs/routing/without_v07_checks.md")] + pub fn without_v07_checks(self) -> Self { + self.tap_inner_mut(|this| { + this.path_router.without_v07_checks(); + }) + } + #[doc = include_str!("../docs/routing/route.md")] #[track_caller] pub fn route(self, path: &str, method_router: MethodRouter) -> Self { diff --git a/axum/src/routing/path_router.rs b/axum/src/routing/path_router.rs index 32b27f18b0..8deceb25b2 100644 --- a/axum/src/routing/path_router.rs +++ b/axum/src/routing/path_router.rs @@ -14,6 +14,7 @@ pub(super) struct PathRouter { routes: HashMap>, node: Arc, prev_route_id: RouteId, + v7_checks: bool, } impl PathRouter @@ -32,26 +33,56 @@ where } } +fn validate_path(v7_checks: bool, path: &str) -> Result<(), &'static str> { + if path.is_empty() { + return Err("Paths must start with a `/`. Use \"/\" for root routes"); + } else if !path.starts_with('/') { + return Err("Paths must start with a `/`"); + } + + if v7_checks { + validate_v07_paths(path)?; + } + + Ok(()) +} + +fn validate_v07_paths(path: &str) -> Result<(), &'static str> { + path.split('/') + .find_map(|segment| { + if segment.starts_with(':') { + Some(Err( + "Path segments must not start with `:`. For capture groups, use \ + `{capture}`. If you meant to literally match a segment starting with \ + a colon, call `without_v07_checks` on the router.", + )) + } else if segment.starts_with('*') { + Some(Err( + "Path segments must not start with `*`. For wildcard capture, use \ + `{*wildcard}`. If you meant to literally match a segment starting with \ + an asterisk, call `without_v07_checks` on the router.", + )) + } else { + None + } + }) + .unwrap_or(Ok(())) +} + impl PathRouter where S: Clone + Send + Sync + 'static, { + pub(super) fn without_v07_checks(&mut self) { + self.v7_checks = false; + } + pub(super) fn route( &mut self, path: &str, method_router: MethodRouter, ) -> Result<(), Cow<'static, str>> { - fn validate_path(path: &str) -> Result<(), &'static str> { - if path.is_empty() { - return Err("Paths must start with a `/`. Use \"/\" for root routes"); - } else if !path.starts_with('/') { - return Err("Paths must start with a `/`"); - } - - Ok(()) - } - - validate_path(path)?; + validate_path(self.v7_checks, path)?; let endpoint = if let Some((route_id, Endpoint::MethodRouter(prev_method_router))) = self .node @@ -97,11 +128,7 @@ where path: &str, endpoint: Endpoint, ) -> Result<(), Cow<'static, str>> { - if path.is_empty() { - return Err("Paths must start with a `/`. Use \"/\" for root routes".into()); - } else if !path.starts_with('/') { - return Err("Paths must start with a `/`".into()); - } + validate_path(self.v7_checks, path)?; let id = self.next_route_id(); self.set_node(path, id)?; @@ -128,8 +155,12 @@ where routes, node, prev_route_id: _, + v7_checks, } = other; + // If either of the two did not allow paths starting with `:` or `*`, do not allow them for the merged router either. + self.v7_checks |= v7_checks; + for (id, route) in routes { let path = node .route_id_to_path @@ -165,12 +196,14 @@ where path_to_nest_at: &str, router: PathRouter, ) -> Result<(), Cow<'static, str>> { - let prefix = validate_nest_path(path_to_nest_at); + let prefix = validate_nest_path(self.v7_checks, path_to_nest_at); let PathRouter { routes, node, prev_route_id: _, + // Ignore the configuration of the nested router + v7_checks: _, } = router; for (id, endpoint) in routes { @@ -208,7 +241,7 @@ where T::Response: IntoResponse, T::Future: Send + 'static, { - let path = validate_nest_path(path_to_nest_at); + let path = validate_nest_path(self.v7_checks, path_to_nest_at); let prefix = path; let path = if path.ends_with('/') { @@ -258,6 +291,7 @@ where routes, node: self.node, prev_route_id: self.prev_route_id, + v7_checks: self.v7_checks, } } @@ -290,6 +324,7 @@ where routes, node: self.node, prev_route_id: self.prev_route_id, + v7_checks: self.v7_checks, } } @@ -312,6 +347,7 @@ where routes, node: self.node, prev_route_id: self.prev_route_id, + v7_checks: self.v7_checks, } } @@ -394,6 +430,7 @@ impl Default for PathRouter { routes: Default::default(), node: Default::default(), prev_route_id: RouteId(0), + v7_checks: true, } } } @@ -413,6 +450,7 @@ impl Clone for PathRouter { routes: self.routes.clone(), node: self.node.clone(), prev_route_id: self.prev_route_id, + v7_checks: self.v7_checks, } } } @@ -459,16 +497,22 @@ impl fmt::Debug for Node { } #[track_caller] -fn validate_nest_path(path: &str) -> &str { +fn validate_nest_path(v7_checks: bool, path: &str) -> &str { if path.is_empty() { // nesting at `""` and `"/"` should mean the same thing return "/"; } - if path.contains('*') { + if path.split('/').any(|segment| { + segment.starts_with("{*") && segment.ends_with('}') && !segment.ends_with("}}") + }) { panic!("Invalid route: nested routes cannot contain wildcards (*)"); } + if v7_checks { + validate_v07_paths(path).unwrap(); + } + path } diff --git a/axum/src/routing/tests/mod.rs b/axum/src/routing/tests/mod.rs index a49c022907..e3a9d238a7 100644 --- a/axum/src/routing/tests/mod.rs +++ b/axum/src/routing/tests/mod.rs @@ -1102,3 +1102,19 @@ async fn locks_mutex_very_little() { assert_eq!(num, 1); } } + +#[crate::test] +#[should_panic( + expected = "Path segments must not start with `:`. For capture groups, use `{capture}`. If you meant to literally match a segment starting with a colon, call `without_v07_checks` on the router." +)] +async fn colon_in_route() { + _ = Router::<()>::new().route("/:foo", get(|| async move {})); +} + +#[crate::test] +#[should_panic( + expected = "Path segments must not start with `*`. For wildcard capture, use `{*wildcard}`. If you meant to literally match a segment starting with an asterisk, call `without_v07_checks` on the router." +)] +async fn asterisk_in_route() { + _ = Router::<()>::new().route("/*foo", get(|| async move {})); +} diff --git a/axum/src/routing/tests/nest.rs b/axum/src/routing/tests/nest.rs index 2f4a03c843..1fd289ae87 100644 --- a/axum/src/routing/tests/nest.rs +++ b/axum/src/routing/tests/nest.rs @@ -418,3 +418,19 @@ nested_route_test!(nest_9, nest = "/a", route = "/a/", expected = "/a/a/"); nested_route_test!(nest_11, nest = "/a/", route = "/", expected = "/a/"); nested_route_test!(nest_12, nest = "/a/", route = "/a", expected = "/a/a"); nested_route_test!(nest_13, nest = "/a/", route = "/a/", expected = "/a/a/"); + +#[crate::test] +#[should_panic( + expected = "Path segments must not start with `:`. For capture groups, use `{capture}`. If you meant to literally match a segment starting with a colon, call `without_v07_checks` on the router." +)] +async fn colon_in_route() { + _ = Router::<()>::new().nest("/:foo", Router::new()); +} + +#[crate::test] +#[should_panic( + expected = "Path segments must not start with `*`. For wildcard capture, use `{*wildcard}`. If you meant to literally match a segment starting with an asterisk, call `without_v07_checks` on the router." +)] +async fn asterisk_in_route() { + _ = Router::<()>::new().nest("/*foo", Router::new()); +}