From 64881ae05458f06261b2e7d0f790184678cc42b9 Mon Sep 17 00:00:00 2001 From: Michael Date: Wed, 6 Jul 2016 16:01:20 -0400 Subject: [PATCH] feat(headers): add origin header Add an Origin header so users may properly send CORS requests Addresses #651 --- src/header/common/host.rs | 70 +++++++++++----------- src/header/common/mod.rs | 2 + src/header/common/origin.rs | 112 ++++++++++++++++++++++++++++++++++++ 3 files changed, 146 insertions(+), 38 deletions(-) create mode 100644 src/header/common/origin.rs diff --git a/src/header/common/host.rs b/src/header/common/host.rs index 7c41253241..357122c62e 100644 --- a/src/header/common/host.rs +++ b/src/header/common/host.rs @@ -1,6 +1,8 @@ use header::{Header, HeaderFormat}; use std::fmt; +use std::str::FromStr; use header::parsing::from_one_raw_str; +use url::idna::domain_to_unicode; /// The `Host` header. /// @@ -47,44 +49,7 @@ impl Header for Host { } fn parse_header(raw: &[Vec]) -> ::Result { - from_one_raw_str(raw).and_then(|mut s: String| { - // FIXME: use rust-url to parse this - // https://github.com/servo/rust-url/issues/42 - let idx = { - let slice = &s[..]; - let mut chars = slice.chars(); - chars.next(); - if chars.next().unwrap() == '[' { - match slice.rfind(']') { - Some(idx) => { - if slice.len() > idx + 2 { - Some(idx + 1) - } else { - None - } - } - None => return Err(::Error::Header) // this is a bad ipv6 address... - } - } else { - slice.rfind(':') - } - }; - - let port = match idx { - Some(idx) => s[idx + 1..].parse().ok(), - None => None - }; - - match idx { - Some(idx) => s.truncate(idx), - None => () - } - - Ok(Host { - hostname: s, - port: port - }) - }) + from_one_raw_str(raw) } } @@ -97,6 +62,35 @@ impl HeaderFormat for Host { } } +impl fmt::Display for Host { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + self.fmt_header(f) + } +} + +impl FromStr for Host { + type Err = ::Error; + + fn from_str(s: &str) -> ::Result { + let (host_port, res) = domain_to_unicode(s); + if res.is_err() { + return Err(::Error::Header) + } + let idx = host_port.rfind(':'); + let port = idx.and_then( + |idx| s[idx + 1..].parse().ok() + ); + let hostname = match idx { + None => host_port, + Some(idx) => host_port[..idx].to_owned() + }; + Ok(Host { + hostname: hostname, + port: port + }) + } +} + #[cfg(test)] mod tests { use super::Host; diff --git a/src/header/common/mod.rs b/src/header/common/mod.rs index 279362cae9..6ec2efafb7 100644 --- a/src/header/common/mod.rs +++ b/src/header/common/mod.rs @@ -43,6 +43,7 @@ pub use self::if_unmodified_since::IfUnmodifiedSince; pub use self::if_range::IfRange; pub use self::last_modified::LastModified; pub use self::location::Location; +pub use self::origin::Origin; pub use self::pragma::Pragma; pub use self::prefer::{Prefer, Preference}; pub use self::preference_applied::PreferenceApplied; @@ -406,6 +407,7 @@ mod if_range; mod if_unmodified_since; mod last_modified; mod location; +mod origin; mod pragma; mod prefer; mod preference_applied; diff --git a/src/header/common/origin.rs b/src/header/common/origin.rs new file mode 100644 index 0000000000..3064dc9ed9 --- /dev/null +++ b/src/header/common/origin.rs @@ -0,0 +1,112 @@ +use header::{Header, Host, HeaderFormat}; +use std::fmt; +use std::str::FromStr; +use header::parsing::from_one_raw_str; + +/// The `Origin` header. +/// +/// The `Origin` header is a version of the `Referer` header that is used for all HTTP fetches and `POST`s whose CORS flag is set. +/// This header is often used to inform recipients of the security context of where the request was initiated. +/// +/// +/// Following the spec, https://fetch.spec.whatwg.org/#origin-header, the value of this header is composed of +/// a String (scheme), header::Host (host/port) +/// +/// # Examples +/// ``` +/// use hyper::header::{Headers, Origin}; +/// +/// let mut headers = Headers::new(); +/// headers.set( +/// Origin::new("http", "hyper.rs", None) +/// ); +/// ``` +/// ``` +/// use hyper::header::{Headers, Origin}; +/// +/// let mut headers = Headers::new(); +/// headers.set( +/// Origin::new("https", "wikipedia.org", Some(443)) +/// ); +/// ``` + +#[derive(Clone, Debug)] +pub struct Origin { + /// The scheme, such as http or https + pub scheme: String, + /// The host, such as Host{hostname: "hyper.rs".to_owned(), port: None} + pub host: Host, +} + +impl Origin { + pub fn new, H: Into>(scheme: S, hostname: H, port: Option) -> Origin{ + Origin { + scheme: scheme.into(), + host: Host { + hostname: hostname.into(), + port: port + } + } + } +} + +impl Header for Origin { + fn header_name() -> &'static str { + static NAME: &'static str = "Origin"; + NAME + } + + fn parse_header(raw: &[Vec]) -> ::Result { + from_one_raw_str(raw) + } +} + +impl FromStr for Origin { + type Err = ::Error; + + fn from_str(s: &str) -> ::Result { + let idx = match s.find("://") { + Some(idx) => idx, + None => return Err(::Error::Header) + }; + // idx + 3 because thats how long "://" is + let (scheme, etc) = (&s[..idx], &s[idx + 3..]); + let host = try!(Host::from_str(etc)); + + + Ok(Origin{ + scheme: scheme.to_owned(), + host: host + }) + } +} + +impl HeaderFormat for Origin { + fn fmt_header(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}://{}", self.scheme, self.host) + } +} + +impl PartialEq for Origin { + fn eq(&self, other: &Origin) -> bool { + self.scheme == other.scheme && self.host == other.host + } +} + + +#[cfg(test)] +mod tests { + use super::Origin; + use header::Header; + + #[test] + fn test_origin() { + let origin = Header::parse_header([b"http://foo.com".to_vec()].as_ref()); + assert_eq!(origin.ok(), Some(Origin::new("http", "foo.com", None))); + + let origin = Header::parse_header([b"https://foo.com:443".to_vec()].as_ref()); + assert_eq!(origin.ok(), Some(Origin::new("https", "foo.com", Some(443)))); + } +} + +bench_header!(bench, Origin, { vec![b"https://foo.com".to_vec()] });