Skip to content

Commit

Permalink
Add Scheme extractor (#2507)
Browse files Browse the repository at this point in the history
  • Loading branch information
bengsparks authored Oct 19, 2024
1 parent 89fec69 commit ffeb4f9
Show file tree
Hide file tree
Showing 3 changed files with 160 additions and 0 deletions.
1 change: 1 addition & 0 deletions axum-extra/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ json-lines = [
]
multipart = ["dep:multer", "dep:fastrand"]
protobuf = ["dep:prost"]
scheme = []
query = ["dep:serde_html_form"]
tracing = ["axum-core/tracing", "axum/tracing"]
typed-header = ["dep:headers"]
Expand Down
7 changes: 7 additions & 0 deletions axum-extra/src/extract/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ mod query;
#[cfg(feature = "multipart")]
pub mod multipart;

#[cfg(feature = "scheme")]
mod scheme;

pub use self::{
cached::Cached, host::Host, optional_path::OptionalPath, with_rejection::WithRejection,
};
Expand All @@ -43,6 +46,10 @@ pub use self::query::{OptionalQuery, OptionalQueryRejection, Query, QueryRejecti
#[cfg(feature = "multipart")]
pub use self::multipart::Multipart;

#[cfg(feature = "scheme")]
#[doc(no_inline)]
pub use self::scheme::{Scheme, SchemeMissing};

#[cfg(feature = "json-deserializer")]
pub use self::json_deserializer::{
JsonDataError, JsonDeserializer, JsonDeserializerRejection, JsonSyntaxError,
Expand Down
152 changes: 152 additions & 0 deletions axum-extra/src/extract/scheme.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
//! Extractor that parses the scheme of a request.
//! See [`Scheme`] for more details.

use axum::{
extract::FromRequestParts,
response::{IntoResponse, Response},
};
use http::{
header::{HeaderMap, FORWARDED},
request::Parts,
};
const X_FORWARDED_PROTO_HEADER_KEY: &str = "X-Forwarded-Proto";

/// Extractor that resolves the scheme / protocol of a request.
///
/// The scheme is resolved through the following, in order:
/// - `Forwarded` header
/// - `X-Forwarded-Proto` header
/// - Request URI (If the request is an HTTP/2 request! e.g. use `--http2(-prior-knowledge)` with cURL)
///
/// Note that user agents can set the `X-Forwarded-Proto` header to arbitrary values so make
/// sure to validate them to avoid security issues.
#[derive(Debug, Clone)]
pub struct Scheme(pub String);

/// Rejection type used if the [`Scheme`] extractor is unable to
/// resolve a scheme.
#[derive(Debug)]
pub struct SchemeMissing;

impl IntoResponse for SchemeMissing {
fn into_response(self) -> Response {
(http::StatusCode::BAD_REQUEST, "No scheme found in request").into_response()
}
}

impl<S> FromRequestParts<S> for Scheme
where
S: Send + Sync,
{
type Rejection = SchemeMissing;

async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
// Within Forwarded header
if let Some(scheme) = parse_forwarded(&parts.headers) {
return Ok(Scheme(scheme.to_owned()));
}

// X-Forwarded-Proto
if let Some(scheme) = parts
.headers
.get(X_FORWARDED_PROTO_HEADER_KEY)
.and_then(|scheme| scheme.to_str().ok())
{
return Ok(Scheme(scheme.to_owned()));
}

// From parts of an HTTP/2 request
if let Some(scheme) = parts.uri.scheme_str() {
return Ok(Scheme(scheme.to_owned()));
}

Err(SchemeMissing)
}
}

fn parse_forwarded(headers: &HeaderMap) -> Option<&str> {
// if there are multiple `Forwarded` `HeaderMap::get` will return the first one
let forwarded_values = headers.get(FORWARDED)?.to_str().ok()?;

// get the first set of values
let first_value = forwarded_values.split(',').next()?;

// find the value of the `proto` field
first_value.split(';').find_map(|pair| {
let (key, value) = pair.split_once('=')?;
key.trim()
.eq_ignore_ascii_case("proto")
.then(|| value.trim().trim_matches('"'))
})
}

#[cfg(test)]
mod tests {
use super::*;
use crate::test_helpers::TestClient;
use axum::{routing::get, Router};
use http::header::HeaderName;

fn test_client() -> TestClient {
async fn scheme_as_body(Scheme(scheme): Scheme) -> String {
scheme
}

TestClient::new(Router::new().route("/", get(scheme_as_body)))
}

#[crate::test]
async fn forwarded_scheme_parsing() {
// the basic case
let headers = header_map(&[(FORWARDED, "host=192.0.2.60;proto=http;by=203.0.113.43")]);
let value = parse_forwarded(&headers).unwrap();
assert_eq!(value, "http");

// is case insensitive
let headers = header_map(&[(FORWARDED, "host=192.0.2.60;PROTO=https;by=203.0.113.43")]);
let value = parse_forwarded(&headers).unwrap();
assert_eq!(value, "https");

// multiple values in one header
let headers = header_map(&[(FORWARDED, "proto=ftp, proto=https")]);
let value = parse_forwarded(&headers).unwrap();
assert_eq!(value, "ftp");

// multiple header values
let headers = header_map(&[(FORWARDED, "proto=ftp"), (FORWARDED, "proto=https")]);
let value = parse_forwarded(&headers).unwrap();
assert_eq!(value, "ftp");
}

#[crate::test]
async fn x_forwarded_scheme_header() {
let original_scheme = "https";
let scheme = test_client()
.get("/")
.header(X_FORWARDED_PROTO_HEADER_KEY, original_scheme)
.await
.text()
.await;
assert_eq!(scheme, original_scheme);
}

#[crate::test]
async fn precedence_forwarded_over_x_forwarded() {
let scheme = test_client()
.get("/")
.header(X_FORWARDED_PROTO_HEADER_KEY, "https")
.header(FORWARDED, "proto=ftp")
.await
.text()
.await;
assert_eq!(scheme, "ftp");
}

fn header_map(values: &[(HeaderName, &str)]) -> HeaderMap {
let mut headers = HeaderMap::new();
for (key, value) in values {
headers.append(key, value.parse().unwrap());
}
headers
}
}

0 comments on commit ffeb4f9

Please sign in to comment.