Skip to content

Commit

Permalink
cors: Don't overwrite headers set by the inner service
Browse files Browse the repository at this point in the history
  • Loading branch information
jplatte committed Aug 17, 2023
1 parent c969dbe commit 37ec041
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 2 deletions.
45 changes: 43 additions & 2 deletions tower-http/src/cors/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ mod allow_origin;
mod allow_private_network;
mod expose_headers;
mod max_age;
#[cfg(test)]
mod tests;
mod vary;

pub use self::{
Expand Down Expand Up @@ -681,8 +683,7 @@ where
match self.project().inner.project() {
KindProj::CorsCall { future, headers } => {
let mut response: Response<B> = ready!(future.poll(cx))?;
response.headers_mut().extend(headers.drain());

header_map_append_all(response.headers_mut(), headers.drain());
Poll::Ready(Ok(response))
}
KindProj::PreflightCall { headers } => {
Expand All @@ -695,6 +696,46 @@ where
}
}

// There is an Extend implementation for HeaderMap, but we can't use it because
// it would delete existing header values for any header names that are in
// `headers`.
fn header_map_append_all(
header_map: &mut HeaderMap,
mut cors_header_iter: impl Iterator<Item = (Option<HeaderName>, HeaderValue)>,
) {
if let Some((first_name, first_value)) = cors_header_iter.next() {
let first_name =
first_name.expect("first item of HeaderMap Drain iterator must have a name");

let mut entry = header_map.entry(first_name);
entry = entry_append(entry, first_value);

for (name, value) in cors_header_iter {
// If name is `None`, this iterator item is a new value
// for the same key, reuse the previous entry.
// If it is `Some(name)`, get a new entry for that name
// and start operating on that.
if let Some(name) = name {
entry = header_map.entry(name);
}
entry = entry_append(entry, value);
}
}
}

fn entry_append(
entry: header::Entry<'_, HeaderValue>,
first_value: HeaderValue,
) -> header::Entry<'_, HeaderValue> {
match entry {
header::Entry::Occupied(mut o) => {
o.append(first_value);
header::Entry::Occupied(o)
}
header::Entry::Vacant(v) => header::Entry::Occupied(v.insert_entry(first_value)),
}
}

fn ensure_usable_cors_rules(layer: &CorsLayer) {
if layer.allow_credentials.is_true() {
assert!(
Expand Down
29 changes: 29 additions & 0 deletions tower-http/src/cors/tests.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
use std::convert::Infallible;

use http::{header, HeaderValue, Request, Response};
use hyper::Body;
use tower::{service_fn, util::ServiceExt, Layer};

use crate::cors::CorsLayer;

#[tokio::test]
async fn vary_set_by_inner_service() {
const CUSTOM_VARY_HEADERS: HeaderValue = HeaderValue::from_static("accept, accept-encoding");
const PERMISSIVE_CORS_VARY_HEADERS: HeaderValue = HeaderValue::from_static(
"origin, access-control-request-method, access-control-request-headers",
);

async fn inner_svc(_: Request<Body>) -> Result<Response<Body>, Infallible> {
Ok(Response::builder()
.header(header::VARY, CUSTOM_VARY_HEADERS)
.body(Body::empty())
.unwrap())
}

let svc = CorsLayer::permissive().layer(service_fn(inner_svc));
let res = svc.oneshot(Request::new(Body::empty())).await.unwrap();
let mut vary_headers = res.headers().get_all(header::VARY).into_iter();
assert_eq!(vary_headers.next(), Some(&CUSTOM_VARY_HEADERS));
assert_eq!(vary_headers.next(), Some(&PERMISSIVE_CORS_VARY_HEADERS));
assert_eq!(vary_headers.next(), None);
}

0 comments on commit 37ec041

Please sign in to comment.