Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement requiements from #1067 by importing typed headers from 'headers' … #2867

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions core/http/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ memchr = "2"
stable-pattern = "0.1"
cookie = { version = "0.18", features = ["percent-encode"] }
state = "0.6"
headers = "0.4.0"

[dependencies.serde]
version = "1.0"
Expand Down
145 changes: 145 additions & 0 deletions core/http/src/header/header.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use core::str;
use std::borrow::{Borrow, Cow};
use std::fmt;

use headers::{Header as HHeader, HeaderValue};
use indexmap::IndexMap;

use crate::uncased::{Uncased, UncasedStr};
Expand Down Expand Up @@ -798,10 +800,153 @@ impl From<&cookie::Cookie<'_>> for Header<'static> {
}
}

/// A destination for `HeaderValue`s that can be used to accumulate
/// a single header value using from hyperium headers' decode protocol.
#[derive(Default)]
struct HeaderValueDestination {
value: Option<HeaderValue>,
count: usize,
}

impl <'r>HeaderValueDestination {
fn into_value(self) -> HeaderValue {
if let Some(value) = self.value {
// TODO: if value.count > 1, then log that multiple header values are
// generated by the typed header, but that the dropped.
value
} else {
// Perhaps log that the typed header didn't create any values.
// This won't happen in the current implementation (headers 0.4.0).
HeaderValue::from_static("")
}
}

fn into_header_string(self) -> Cow<'static, str> {
let value = self.into_value();
// TODO: Optimize if we know this is a static reference.
value.to_str().unwrap_or("").to_string().into()
}
}

impl Extend<HeaderValue> for HeaderValueDestination {
fn extend<T: IntoIterator<Item = HeaderValue>>(&mut self, iter: T) {
for value in iter {
self.count += 1;
if self.value.is_none() {
self.value = Some(value)
}
}
}
}

macro_rules! from_typed_header {
($($name:ident),*) => ($(
pub use headers::$name;

impl ::std::convert::From<self::$name> for Header<'static> {
fn from(header: self::$name) -> Self {
let mut destination = HeaderValueDestination::default();
header.encode(&mut destination);
let name = self::$name::name();
Header::new(name.as_str(), destination.into_header_string())
}
}
)*)
}

macro_rules! generic_from_typed_header {
($($name:ident<$bound:ident>),*) => ($(
pub use headers::$name;

impl <T1: 'static + $bound>::std::convert::From<self::$name<T1>>
for Header<'static> {
fn from(header: self::$name<T1>) -> Self {
let mut destination = HeaderValueDestination::default();
header.encode(&mut destination);
let name = self::$name::<T1>::name();
Header::new(name.as_str(), destination.into_header_string())
}
}
)*)
}

// The following headers from 'headers' 0.4 are not imported, since they are
// provided by other Rocket features.

// * ContentType, // Content-Type header, defined in RFC7231
// * Cookie, // Cookie header, defined in RFC6265
// * Host, // The Host header.
// * Location, // Location header, defined in RFC7231
// * SetCookie, // Set-Cookie header, defined RFC6265

from_typed_header! {
AcceptRanges, // Accept-Ranges header, defined in RFC7233
AccessControlAllowCredentials, // Access-Control-Allow-Credentials header, part of CORS
AccessControlAllowHeaders, // Access-Control-Allow-Headers header, part of CORS
AccessControlAllowMethods, // Access-Control-Allow-Methods header, part of CORS
AccessControlAllowOrigin, // The Access-Control-Allow-Origin response header, part of CORS
AccessControlExposeHeaders, // Access-Control-Expose-Headers header, part of CORS
AccessControlMaxAge, // Access-Control-Max-Age header, part of CORS
AccessControlRequestHeaders, // Access-Control-Request-Headers header, part of CORS
AccessControlRequestMethod, // Access-Control-Request-Method header, part of CORS
Age, // Age header, defined in RFC7234
Allow, // Allow header, defined in RFC7231
CacheControl, // Cache-Control header, defined in RFC7234 with extensions in RFC8246
Connection, // Connection header, defined in RFC7230
ContentDisposition, // A Content-Disposition header, (re)defined in RFC6266.
ContentEncoding, // Content-Encoding header, defined in RFC7231
ContentLength, // Content-Length header, defined in RFC7230
ContentLocation, // Content-Location header, defined in RFC7231
ContentRange, // Content-Range, described in RFC7233
Date, // Date header, defined in RFC7231
ETag, // ETag header, defined in RFC7232
Expect, // The Expect header.
Expires, // Expires header, defined in RFC7234
IfMatch, // If-Match header, defined in RFC7232
IfModifiedSince, // If-Modified-Since header, defined in RFC7232
IfNoneMatch, // If-None-Match header, defined in RFC7232
IfRange, // If-Range header, defined in RFC7233
IfUnmodifiedSince, // If-Unmodified-Since header, defined in RFC7232
LastModified, // Last-Modified header, defined in RFC7232
Origin, // The Origin header.
Pragma, // The Pragma header defined by HTTP/1.0.
Range, // Range header, defined in RFC7233
Referer, // Referer header, defined in RFC7231
ReferrerPolicy, // Referrer-Policy header, part of Referrer Policy
RetryAfter, // The Retry-After header.
SecWebsocketAccept, // The Sec-Websocket-Accept header.
SecWebsocketKey, // The Sec-Websocket-Key header.
SecWebsocketVersion, // The Sec-Websocket-Version header.
Server, // Server header, defined in RFC7231
StrictTransportSecurity, // StrictTransportSecurity header, defined in RFC6797
Te, // TE header, defined in RFC7230
TransferEncoding, // Transfer-Encoding header, defined in RFC7230
Upgrade, // Upgrade header, defined in RFC7230
UserAgent, // User-Agent header, defined in RFC7231
Vary // Vary header, defined in RFC7231
}

generic_from_typed_header! {
Authorization<Credentials>, // Authorization header, defined in RFC7235
ProxyAuthorization<Credentials> // Proxy-Authorization header, defined in RFC7235
}

pub use headers::authorization::Credentials;

#[cfg(test)]
mod tests {
use std::time::SystemTime;

use super::HeaderMap;

#[test]
fn add_typed_header() {
use super::LastModified;
let mut map = HeaderMap::new();
map.add(LastModified::from(SystemTime::now()));
assert!(map.get_one("last-modified").unwrap().contains("GMT"));
}

#[test]
fn case_insensitive_add_get() {
let mut map = HeaderMap::new();
Expand Down
13 changes: 12 additions & 1 deletion core/http/src/header/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,18 @@ mod proxy_proto;
pub use self::content_type::ContentType;
pub use self::accept::{Accept, QMediaType};
pub use self::media_type::MediaType;
pub use self::header::{Header, HeaderMap};
pub use self::header::{
Header, HeaderMap, AcceptRanges, AccessControlAllowCredentials,
AccessControlAllowHeaders, AccessControlAllowMethods, AccessControlAllowOrigin,
AccessControlExposeHeaders, AccessControlMaxAge, AccessControlRequestHeaders,
AccessControlRequestMethod, Age, Allow, CacheControl, Connection, ContentDisposition,
ContentEncoding, ContentLength, ContentLocation, ContentRange, Date, ETag, Expect,
Expires, IfMatch, IfModifiedSince, IfNoneMatch, IfRange, IfUnmodifiedSince,
LastModified, Origin, Pragma, Range, Referer, ReferrerPolicy, RetryAfter,
SecWebsocketAccept, SecWebsocketKey, SecWebsocketVersion, Server, StrictTransportSecurity,
Te, TransferEncoding, Upgrade, UserAgent, Vary, Authorization, ProxyAuthorization,
Credentials
};
pub use self::proxy_proto::ProxyProto;

pub(crate) use self::media_type::Source;
1 change: 1 addition & 0 deletions core/lib/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ http = "1"
bytes = "1.4"
hyper = { version = "1.1", default-features = false, features = ["http1", "server"] }
hyper-util = { version = "0.1.3", default-features = false, features = ["http1", "server", "tokio"] }
headers = "0.4.0"

# Non-optional, core dependencies from here on out.
yansi = { version = "1.0.1", features = ["detect-tty"] }
Expand Down
106 changes: 106 additions & 0 deletions core/lib/src/request/from_request_headers.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
use crate::{outcome::IntoOutcome, Request};
use super::FromRequest;

use headers::{Header as HHeader, HeaderValue as HHeaderValue};
use rocket_http::Status;

macro_rules! typed_headers_from_request {
($($name:ident),*) => ($(
pub use crate::http::$name;

#[rocket::async_trait]
impl<'r> FromRequest<'r> for $name {
type Error = headers::Error;
async fn from_request(req: &'r Request<'_>) ->
crate::request::Outcome<Self, Self::Error> {
req.headers().get($name::name().as_str()).next().or_forward(Status::NotFound)
.and_then(|h| HHeaderValue::from_str(h).or_error(Status::BadRequest))
.map_error(|(s, _)| (s, headers::Error::invalid()))
.and_then(|h| $name::decode(&mut std::iter::once(&h))
.or_forward(Status::BadRequest))
}
}
)*)
}

macro_rules! generic_typed_headers_from_request {
($($name:ident<$bound:ident>),*) => ($(
pub use crate::http::$name;

#[rocket::async_trait]
impl<'r, T1: 'static + $bound> FromRequest<'r> for $name<T1> {
type Error = headers::Error;
async fn from_request(req: &'r Request<'_>) -> crate::request::Outcome<Self, Self::Error> {
req.headers().get($name::<T1>::name().as_str()).next()
.or_forward(Status::NotFound)
.and_then(|h| HHeaderValue::from_str(h).or_error(Status::BadRequest))
.map_error(|(s, _)| (s, headers::Error::invalid()))
.and_then(|h| $name::decode(&mut std::iter::once(&h))
.or_forward(Status::BadRequest))
}
}
)*)
}

// The following headers from 'headers' 0.4 are not imported, since they are
// provided by other Rocket features.

// * ContentType, // Content-Type header, defined in RFC7231
// * Cookie, // Cookie header, defined in RFC6265
// * Host, // The Host header.
// * Location, // Location header, defined in RFC7231
// * SetCookie, // Set-Cookie header, defined RFC6265

typed_headers_from_request! {
AcceptRanges, // Accept-Ranges header, defined in RFC7233
AccessControlAllowCredentials, // Access-Control-Allow-Credentials header, part of CORS
AccessControlAllowHeaders, // Access-Control-Allow-Headers header, part of CORS
AccessControlAllowMethods, // Access-Control-Allow-Methods header, part of CORS
AccessControlAllowOrigin, // The Access-Control-Allow-Origin response header, part of CORS
AccessControlExposeHeaders, // Access-Control-Expose-Headers header, part of CORS
AccessControlMaxAge, // Access-Control-Max-Age header, part of CORS
AccessControlRequestHeaders, // Access-Control-Request-Headers header, part of CORS
AccessControlRequestMethod, // Access-Control-Request-Method header, part of CORS
Age, // Age header, defined in RFC7234
Allow, // Allow header, defined in RFC7231
CacheControl, // Cache-Control header, defined in RFC7234 with extensions in RFC8246
Connection, // Connection header, defined in RFC7230
ContentDisposition, // A Content-Disposition header, (re)defined in RFC6266.
ContentEncoding, // Content-Encoding header, defined in RFC7231
ContentLength, // Content-Length header, defined in RFC7230
ContentLocation, // Content-Location header, defined in RFC7231
ContentRange, // Content-Range, described in RFC7233
Date, // Date header, defined in RFC7231
ETag, // ETag header, defined in RFC7232
Expect, // The Expect header.
Expires, // Expires header, defined in RFC7234
IfMatch, // If-Match header, defined in RFC7232
IfModifiedSince, // If-Modified-Since header, defined in RFC7232
IfNoneMatch, // If-None-Match header, defined in RFC7232
IfRange, // If-Range header, defined in RFC7233
IfUnmodifiedSince, // If-Unmodified-Since header, defined in RFC7232
LastModified, // Last-Modified header, defined in RFC7232
Origin, // The Origin header.
Pragma, // The Pragma header defined by HTTP/1.0.
Range, // Range header, defined in RFC7233
Referer, // Referer header, defined in RFC7231
ReferrerPolicy, // Referrer-Policy header, part of Referrer Policy
RetryAfter, // The Retry-After header.
SecWebsocketAccept, // The Sec-Websocket-Accept header.
SecWebsocketKey, // The Sec-Websocket-Key header.
SecWebsocketVersion, // The Sec-Websocket-Version header.
Server, // Server header, defined in RFC7231
StrictTransportSecurity, // StrictTransportSecurity header, defined in RFC6797
Te, // TE header, defined in RFC7230
TransferEncoding, // Transfer-Encoding header, defined in RFC7230
Upgrade, // Upgrade header, defined in RFC7230
UserAgent, // User-Agent header, defined in RFC7231
Vary // Vary header, defined in RFC7231
}

pub use headers::authorization::Credentials;

generic_typed_headers_from_request! {
Authorization<Credentials>, // Authorization header, defined in RFC7235
ProxyAuthorization<Credentials> // Proxy-Authorization header, defined in RFC7235
}
1 change: 1 addition & 0 deletions core/lib/src/request/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
mod request;
mod from_param;
mod from_request;
mod from_request_headers;
mod atomic_method;

#[cfg(test)]
Expand Down
78 changes: 78 additions & 0 deletions core/lib/tests/typed-headers.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
#[macro_use]
extern crate rocket;

use std::time::{Duration, SystemTime};
use headers::IfModifiedSince;
use rocket::http::Expires;
use rocket_http::{Header, Status};

#[derive(Responder)]
struct MyResponse {
body: String,
expires: Expires,
}

#[get("/expires")]
fn index() -> MyResponse {
let some_future_time =
SystemTime::UNIX_EPOCH.checked_add(Duration::from_secs(60 * 60 * 24 * 365 * 100)).unwrap();

MyResponse {
body: "Hello, world!".into(),
expires: Expires::from(some_future_time)
}
}

#[get("/data")]
fn get_data_with_opt_header(since: Option<IfModifiedSince>) -> String {
if let Some(time) = since {
format!("GET after: {:}", time::OffsetDateTime::from(SystemTime::from(time)))
} else {
format!("Unconditional GET")
}
}

#[get("/data_since")]
fn get_data_with_header(since: IfModifiedSince) -> String {
format!("GET after: {:}", time::OffsetDateTime::from(SystemTime::from(since)))
}

#[test]
fn respond_with_typed_header() {
let rocket = rocket::build().mount(
"/",
routes![index, get_data_with_opt_header, get_data_with_header]);
let client = rocket::local::blocking::Client::debug(rocket).unwrap();

let response = client.get("/expires").dispatch();
assert_eq!(response.headers().get_one("Expires").unwrap(), "Sat, 07 Dec 2069 00:00:00 GMT");
}

#[test]
fn read_typed_header() {
let rocket = rocket::build().mount(
"/",
routes![index, get_data_with_opt_header, get_data_with_header]);
let client = rocket::local::blocking::Client::debug(rocket).unwrap();

let response = client.get("/data").dispatch();
assert_eq!(response.into_string().unwrap(), "Unconditional GET".to_string());

let response = client.get("/data")
.header(Header::new("if-modified-since", "Mon, 07 Dec 2020 00:00:00 GMT")).dispatch();
assert_eq!(response.into_string().unwrap(),
"GET after: 2020-12-07 0:00:00.0 +00:00:00".to_string());

let response = client.get("/data_since")
.header(Header::new("if-modified-since", "Tue, 08 Dec 2020 00:00:00 GMT")).dispatch();
assert_eq!(response.into_string().unwrap(),
"GET after: 2020-12-08 0:00:00.0 +00:00:00".to_string());

let response = client.get("/data_since")
.header(Header::new("if-modified-since", "WTF, 07 Dec 2020 00:00:00 GMT")).dispatch();
assert_eq!(response.status(), Status::BadRequest);

let response = client.get("/data_since")
.header(Header::new("if-modified-since", "\x0c , 07 Dec 2020 00:00:00 GMT")).dispatch();
assert_eq!(response.status(), Status::BadRequest);
}