Skip to content

Commit

Permalink
cors: Add CallNext to control when call next handlers (#899)
Browse files Browse the repository at this point in the history
* cors: Add `CallNext` to control when call next handlers

* Format Rust code using rustfmt

* fix error

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
chrislearn and github-actions[bot] authored Sep 11, 2024
1 parent 7344d2d commit bbbabda
Showing 1 changed file with 37 additions and 12 deletions.
49 changes: 37 additions & 12 deletions crates/cors/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ impl Cors {
/// Returns a new `CorsHandler` using current cors settings.
pub fn into_handler(self) -> CorsHandler {
self.ensure_usable_cors_rules();
CorsHandler(self)
CorsHandler::new(self, CallNext::default())
}

fn ensure_usable_cors_rules(&self) {
Expand Down Expand Up @@ -262,9 +262,29 @@ impl Cors {
}
}

/// Enum to control when to call next handler.
#[non_exhaustive]
#[derive(Default, Clone, Copy, Eq, PartialEq, Debug)]
pub enum CallNext {
/// Call next handlers before [`CorsHandler`] write data to response.
#[default]
Before,
/// Call next handlers after [`CorsHandler`] write data to response.
After,
}

/// CorsHandler
#[derive(Clone, Debug)]
pub struct CorsHandler(Cors);
pub struct CorsHandler {
cors: Cors,
call_next: CallNext,
}
impl CorsHandler {
/// Create a new `CorsHandler`.
pub fn new(cors: Cors, call_next: CallNext) -> Self {
Self { cors, call_next }
}
}

#[async_trait]
impl Handler for CorsHandler {
Expand All @@ -275,17 +295,19 @@ impl Handler for CorsHandler {
res: &mut Response,
ctrl: &mut FlowCtrl,
) {
let origin = req.headers().get(&header::ORIGIN);
if self.call_next == CallNext::Before {
ctrl.call_next(req, depot, res).await;
}

let origin = req.headers().get(&header::ORIGIN);
let mut headers = HeaderMap::new();

// These headers are applied to both preflight and subsequent regular CORS requests:
// https://fetch.spec.whatwg.org/#http-responses
headers.extend(self.cors.allow_origin.to_header(origin, req, depot));
headers.extend(self.cors.allow_credentials.to_header(origin, req, depot));

headers.extend(self.0.allow_origin.to_header(origin, req, depot));
headers.extend(self.0.allow_credentials.to_header(origin, req, depot));

let mut vary_headers = self.0.vary.values();
let mut vary_headers = self.cors.vary.values();
if let Some(first) = vary_headers.next() {
let mut header = match headers.entry(header::VARY) {
header::Entry::Occupied(_) => {
Expand All @@ -302,16 +324,19 @@ impl Handler for CorsHandler {
// Return results immediately upon preflight request
if req.method() == Method::OPTIONS {
// These headers are applied only to preflight requests
headers.extend(self.0.allow_methods.to_header(origin, req, depot));
headers.extend(self.0.allow_headers.to_header(origin, req, depot));
headers.extend(self.0.max_age.to_header(origin, req, depot));
headers.extend(self.cors.allow_methods.to_header(origin, req, depot));
headers.extend(self.cors.allow_headers.to_header(origin, req, depot));
headers.extend(self.cors.max_age.to_header(origin, req, depot));
res.status_code = Some(StatusCode::NO_CONTENT);
} else {
// This header is applied only to non-preflight requests
headers.extend(self.0.expose_headers.to_header(origin, req, depot));
headers.extend(self.cors.expose_headers.to_header(origin, req, depot));
}
res.headers_mut().extend(headers);
ctrl.call_next(req, depot, res).await;

if self.call_next == CallNext::After {
ctrl.call_next(req, depot, res).await;
}
}
}

Expand Down

0 comments on commit bbbabda

Please sign in to comment.