Skip to content
This repository has been archived by the owner on Apr 13, 2020. It is now read-only.

Commit

Permalink
Merge pull request fxbox#110 from ferjm/cors.nicer.api
Browse files Browse the repository at this point in the history
Better way to define paths for CORS endpoints. r=fabrice
  • Loading branch information
ferjm committed Mar 1, 2016
2 parents a061a9e + 6bd0c07 commit c28bb63
Showing 1 changed file with 18 additions and 10 deletions.
28 changes: 18 additions & 10 deletions src/service_router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,18 @@ use iron::status::Status;
use router::Router;
use unicase::UniCase;

type Endpoint = (&'static[Method], &'static[&'static str]);
type Endpoint = (&'static[Method], &'static str);

struct CORS;

impl CORS {
// Only endpoints listed here will allow CORS.
// Endpoints containing a variable path part can use '*' like in:
// &["bar", "*"] for a URL like https://foo.com/bar/123
// Endpoints containing a variable path part can use ':foo' like in:
// "/foo/:bar" for a URL like https://domain.com/foo/123 where 123 is
// variable.
pub const ENDPOINTS: &'static[Endpoint] = &[
(&[Method::Get], &["list.json"]),
(&[Method::Get, Method::Post, Method::Put], &["*", "*"])
(&[Method::Get], "list.json"),
(&[Method::Get, Method::Post, Method::Put], ":service/:command")
];
}

Expand All @@ -39,14 +40,19 @@ impl AfterMiddleware for CORS {
continue;
}

let path: Vec<&str> = if path.starts_with('/') {
path[1..].split('/').collect()
} else {
path[0..].split('/').collect()
};

if path.len() != req.url.path.len() {
continue;
}

for (i, path) in path.iter().enumerate() {
for (i, req_path) in req.url.path.iter().enumerate() {
is_cors_endpoint = false;
if req.url.path[i] != path.to_owned() &&
"*" != path.to_owned() {
if req_path != path[i] && !path[i].starts_with(':') {
break;
}
is_cors_endpoint = true;
Expand Down Expand Up @@ -161,8 +167,10 @@ describe! service_router {
for endpoint in CORS::ENDPOINTS {
let (_, path) = *endpoint;
let path = "http://localhost:3000/".to_owned() +
&(path.join("/").replace("*", "foo"));
let response = request::options(&path, Headers::new(), &service_router).unwrap();
&(path.replace(":", "foo"));
let response = request::options(&path,
Headers::new(),
&service_router).unwrap();
let headers = &response.headers;
assert!(headers.has::<headers::AccessControlAllowOrigin>());
assert!(headers.has::<headers::AccessControlAllowHeaders>());
Expand Down

0 comments on commit c28bb63

Please sign in to comment.