From b0308a22cc4fc2bd485eb6d5edd56c3b2d6db72e Mon Sep 17 00:00:00 2001 From: Jacob Rothstein Date: Wed, 29 May 2024 17:48:06 -0700 Subject: [PATCH] feat: add Headers::entry interface as well as various other header related tidying, including improving header related test coverage --- http/src/conn.rs | 13 +- http/src/headers.rs | 240 +++++---------- http/src/headers/entry.rs | 238 +++++++++++++++ http/src/headers/header_name.rs | 40 ++- http/src/headers/header_value.rs | 71 ++++- http/src/headers/header_values.rs | 87 ++++-- http/src/headers/unknown_header_name.rs | 6 + http/tests/headers.rs | 388 ++++++++++++++++++++++++ http/tests/serialize_headers.rs | 4 +- 9 files changed, 875 insertions(+), 212 deletions(-) create mode 100644 http/src/headers/entry.rs create mode 100644 http/tests/headers.rs diff --git a/http/src/conn.rs b/http/src/conn.rs index c8296b9a4c..a85c0d7630 100644 --- a/http/src/conn.rs +++ b/http/src/conn.rs @@ -795,13 +795,14 @@ where .try_insert_with(Date, || httpdate::fmt_http_date(SystemTime::now())); if !matches!(self.status, Some(Status::NotModified | Status::NoContent)) { - if let Some(len) = self.body_len() { - self.response_headers - .try_insert(ContentLength, len.to_string()); - } + let has_content_length = if let Some(len) = self.body_len() { + self.response_headers.try_insert(ContentLength, len); + true + } else { + self.response_headers.has_header(ContentLength) + }; - if self.version == Version::Http1_1 && !self.response_headers.has_header(ContentLength) - { + if self.version == Version::Http1_1 && !has_content_length { self.response_headers.insert(TransferEncoding, "chunked"); } else { self.response_headers.remove(TransferEncoding); diff --git a/http/src/headers.rs b/http/src/headers.rs index 6885bd97fd..4b3252bffa 100644 --- a/http/src/headers.rs +++ b/http/src/headers.rs @@ -1,20 +1,21 @@ +mod entry; mod header_name; mod header_value; mod header_values; mod known_header_name; mod unknown_header_name; +pub use entry::{Entry, OccupiedEntry, VacantEntry}; pub use header_name::HeaderName; pub use header_value::HeaderValue; pub use header_values::HeaderValues; pub use known_header_name::KnownHeaderName; use header_name::HeaderNameInner; -use memchr::memmem::Finder; use unknown_header_name::UnknownHeaderName; use hashbrown::{ - hash_map::{self, Entry}, + hash_map::{self, Entry as HashbrownEntry}, HashMap, }; use smartcow::SmartCow; @@ -22,12 +23,9 @@ use std::collections::{ btree_map::{self, Entry as BTreeEntry}, BTreeMap, }; -use std::{ - fmt::{self, Debug, Display, Formatter}, - hash::Hasher, -}; +use std::fmt::{self, Debug, Display, Formatter}; -use crate::Error; +use crate::headers::entry::{OccupiedEntryInner, VacantEntryInner}; /// Trillium's header map type #[derive(Debug, Clone, PartialEq, Eq, Default)] @@ -63,14 +61,7 @@ impl Display for Headers { } } -#[derive(Debug, Clone, Copy)] -pub struct ParseError; -impl Display for ParseError { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - f.write_str("parse error") - } -} - +#[cfg(feature = "parse")] fn is_tchar(c: u8) -> bool { matches!( c, @@ -95,9 +86,12 @@ fn is_tchar(c: u8) -> bool { ) } +#[cfg(feature = "parse")] impl Headers { #[doc(hidden)] - pub fn extend_parse(&mut self, bytes: &[u8]) -> Result { + pub fn extend_parse(&mut self, bytes: &[u8]) -> Result { + use memchr::memmem::Finder; + let newlines = Finder::new(b"\r\n").find_iter(bytes).collect::>(); // self.reserve(newlines.len().saturating_sub(1)); let mut new_header_count = 0; @@ -132,14 +126,15 @@ impl Headers { Ok(new_header_count) } - #[cfg(feature = "parse")] #[doc(hidden)] pub fn parse(bytes: &[u8]) -> Result { let mut headers = Headers::new(); headers.extend_parse(bytes)?; Ok(headers) } +} +impl Headers { /// Construct a new headers with a default capacity pub fn new() -> Self { Self::default() @@ -167,27 +162,14 @@ impl Headers { /// there is already a header with the same name, the new values /// will be added to the existing ones. To replace any existing /// values, use [`Headers::insert`] - pub fn append(&mut self, name: impl Into>, value: impl Into) { - let value = value.into(); - match name.into().0 { - HeaderNameInner::KnownHeader(known) => match self.known.entry(known) { - BTreeEntry::Occupied(mut o) => { - o.get_mut().extend(value); - } - BTreeEntry::Vacant(v) => { - v.insert(value); - } - }, - - HeaderNameInner::UnknownHeader(unknown) => match self.unknown.entry(unknown) { - Entry::Occupied(mut o) => { - o.get_mut().extend(value); - } - Entry::Vacant(v) => { - v.insert(value); - } - }, - } + /// + /// Identical to [`headers.entry(name).append(values)`][Entry::append] + pub fn append( + &mut self, + name: impl Into>, + values: impl Into, + ) -> &mut HeaderValues { + self.entry(name).append(values) } /// A slightly more efficient way to combine two [`Headers`] than @@ -206,10 +188,10 @@ impl Headers { for (name, value) in other.unknown { match self.unknown.entry(name) { - Entry::Occupied(mut entry) => { + HashbrownEntry::Occupied(mut entry) => { entry.get_mut().extend(value); } - Entry::Vacant(entry) => { + HashbrownEntry::Vacant(entry) => { entry.insert(value); } } @@ -230,35 +212,63 @@ impl Headers { /// Add a header value or header values into this header map. If a /// header already exists with the same name, it will be /// replaced. To combine, see [`Headers::append`] - pub fn insert(&mut self, name: impl Into>, value: impl Into) { - let value = value.into(); - match name.into().0 { - HeaderNameInner::KnownHeader(known) => { - self.known.insert(known, value); - } - - HeaderNameInner::UnknownHeader(unknown) => { - self.unknown.insert(unknown, value); - } - } + pub fn insert( + &mut self, + name: impl Into>, + values: impl Into, + ) { + self.entry(name).insert(values); } /// Add a header value or header values into this header map if /// and only if there is not already a header with the same name. + /// + /// Identical to [`headers.entry(name).or_insert(default)`][Entry::or_insert] pub fn try_insert( &mut self, name: impl Into>, - value: impl Into, + values: impl Into, ) { - let value = value.into(); + self.entry(name).or_insert(values); + } + + /// if a key does not exist already, execute the provided function and insert a value + /// + /// Identical to [`headers.entry(name).or_insert_with(values)`][Entry::or_insert_with] + pub fn try_insert_with( + &mut self, + name: impl Into>, + values: impl FnOnce() -> V, + ) -> &mut HeaderValues + where + V: Into, + { + self.entry(name).or_insert_with(values) + } + + /// Return a view into the entry for this header name, whether or not it is populated. + /// + /// See also [`Entry`] + pub fn entry(&mut self, name: impl Into>) -> Entry<'_> { match name.into().0 { - HeaderNameInner::KnownHeader(known) => { - self.known.entry(known).or_insert(value); - } + HeaderNameInner::KnownHeader(known) => match self.known.entry(known) { + BTreeEntry::Vacant(vacant) => { + Entry::Vacant(VacantEntry(VacantEntryInner::Known(vacant))) + } + BTreeEntry::Occupied(occupied) => { + Entry::Occupied(OccupiedEntry(OccupiedEntryInner::Known(occupied))) + } + }, - HeaderNameInner::UnknownHeader(unknown) => { - self.unknown.entry(unknown).or_insert(value); - } + HeaderNameInner::UnknownHeader(unknown) => match self.unknown.entry(unknown) { + HashbrownEntry::Occupied(occupied) => { + Entry::Occupied(OccupiedEntry(OccupiedEntryInner::Unknown(occupied))) + } + + HashbrownEntry::Vacant(vacant) => { + Entry::Vacant(VacantEntry(VacantEntryInner::Unknown(vacant))) + } + }, } } @@ -275,16 +285,15 @@ impl Headers { self.get_values(name).and_then(HeaderValues::as_lower) } - /// Retrieves a singular header value from this header map. If - /// there are several headers with the same name, this follows the - /// behavior defined at [`HeaderValues::one`]. Returns None if there is no header with the provided header name + /// Retrieves a singular header value from this header map. If there are several headers with + /// the same name, this follows the behavior defined at [`HeaderValues::one`]. Returns None if + /// there is no header with the provided header name pub fn get<'a>(&self, name: impl Into>) -> Option<&HeaderValue> { self.get_values(name).and_then(HeaderValues::one) } - /// Takes all headers with the provided header name out of this - /// header map and returns them. Returns None if the header did - /// not have an entry in this map. + /// Takes all headers with the provided header name out of this header map and returns + /// them. Returns None if the header did not have an entry in this map. pub fn remove<'a>(&mut self, name: impl Into>) -> Option { match name.into().0 { HeaderNameInner::KnownHeader(known) => self.known.remove(&known), @@ -374,53 +383,29 @@ See documentation for deprecation rationale"] } /// Chainable method to remove a header - pub fn without_header(mut self, name: impl Into>) -> Self { + pub fn without_header<'a>(mut self, name: impl Into>) -> Self { self.remove(name); self } /// Chainable method to remove multiple headers by name - pub fn without_headers(mut self, names: I) -> Self + pub fn without_headers<'a, I, H>(mut self, names: I) -> Self where I: IntoIterator, - H: Into>, + H: Into>, { self.remove_all(names); self } /// remove multiple headers by name - pub fn remove_all(&mut self, names: I) + pub fn remove_all<'a, I, H>(&mut self, names: I) where I: IntoIterator, - H: Into>, + H: Into>, { - for header in names { - self.remove(header.into()); - } - } - - /// if a key does not exist already, execute the provided function and insert a value - /// - /// this can be useful to avoid calculating an unnecessary header value, or checking for the - /// presence of a key before insertion - pub fn try_insert_with(&mut self, name: impl Into>, values_fn: F) - where - F: Fn() -> V, - V: Into, - { - match name.into().0 { - HeaderNameInner::KnownHeader(known) => { - self.known - .entry(known) - .or_insert_with(|| values_fn().into()); - } - - HeaderNameInner::UnknownHeader(unknown) => { - self.unknown - .entry(unknown) - .or_insert_with(|| values_fn().into()); - } + for name in names { + self.remove(name); } } } @@ -453,25 +438,6 @@ where } } -#[derive(Default)] -struct DirectHasher(u8); - -impl Hasher for DirectHasher { - fn write(&mut self, _: &[u8]) { - unreachable!("KnownHeaderName calls write_u64"); - } - - #[inline] - fn write_u8(&mut self, i: u8) { - self.0 = i; - } - - #[inline] - fn finish(&self) -> u64 { - u64::from(self.0) - } -} - impl<'a> IntoIterator for &'a Headers { type Item = (HeaderName<'a>, &'a HeaderValues); @@ -499,6 +465,7 @@ impl Iterator for IntoIter { .or_else(|| unknown.next().map(|(k, v)| (HeaderName::from(k), v))) } } + impl From for IntoIter { fn from(value: Headers) -> Self { Self { @@ -544,46 +511,3 @@ impl IntoIterator for Headers { self.into() } } - -#[cfg(test)] -mod tests { - use crate::{Headers, KnownHeaderName}; - - #[test] - fn header_names_are_case_insensitive_for_access_but_retain_initial_case_in_headers() { - let mut headers = Headers::new(); - headers.insert("my-Header-name", "initial-value"); - headers.insert("my-Header-NAME", "my-header-value"); - - assert_eq!(headers.len(), 1); - - assert_eq!( - headers.get_str("My-Header-Name").unwrap(), - "my-header-value" - ); - - headers.append("mY-hEaDer-NaMe", "second-value"); - assert_eq!( - headers.get_values("my-header-name").unwrap(), - ["my-header-value", "second-value"].as_slice() - ); - - assert_eq!( - headers.iter().next().unwrap().0.to_string(), - "my-Header-name" - ); - - assert!(headers.remove("my-HEADER-name").is_some()); - assert!(headers.is_empty()); - } - - #[test] - fn value_case_insensitive_comparison() { - let mut headers = Headers::new(); - headers.insert(KnownHeaderName::Upgrade, "WebSocket"); - headers.insert(KnownHeaderName::Connection, "upgrade"); - - assert!(headers.eq_ignore_ascii_case(KnownHeaderName::Upgrade, "websocket")); - assert!(headers.eq_ignore_ascii_case(KnownHeaderName::Connection, "Upgrade")); - } -} diff --git a/http/src/headers/entry.rs b/http/src/headers/entry.rs new file mode 100644 index 0000000000..2e5fffcfa8 --- /dev/null +++ b/http/src/headers/entry.rs @@ -0,0 +1,238 @@ +use super::{HeaderName, HeaderValues, KnownHeaderName, UnknownHeaderName}; +use hashbrown::hash_map; +use std::{ + collections::btree_map, + fmt::{self, Debug, Formatter}, +}; + +/// A view into the storage for a particular header name +#[derive(Debug)] +pub enum Entry<'a> { + /// A mutable view into the location that header values would be inserted into this + /// [`Headers`][super::Headers] for the specified `HeaderName` + Vacant(VacantEntry<'a>), + /// A mutable view into the header values are stored for the specified `HeaderName` + Occupied(OccupiedEntry<'a>), +} + +/// A view into a vacant entry in particular `Headers` for a given `HeaderName`. +/// +/// It is part of the [`Entry`] enum. +pub struct VacantEntry<'a>(pub(super) VacantEntryInner<'a>); + +impl<'a> Debug for VacantEntry<'a> { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("VacantEntry") + .field("name", &self.name()) + .finish() + } +} + +pub(super) enum VacantEntryInner<'a> { + Known(btree_map::VacantEntry<'a, KnownHeaderName, HeaderValues>), + Unknown(hash_map::VacantEntry<'a, UnknownHeaderName<'static>, HeaderValues>), +} + +/// A view into an occupied entry in particular `Headers` for a given `HeaderName`. +/// +/// It is part of the [`Entry`] enum. +pub struct OccupiedEntry<'a>(pub(super) OccupiedEntryInner<'a>); + +impl<'a> Debug for OccupiedEntry<'a> { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("OccupiedEntry") + .field("name", &self.name()) + .field("values", &self.values()) + .finish() + } +} + +pub(super) enum OccupiedEntryInner<'a> { + Known(btree_map::OccupiedEntry<'a, KnownHeaderName, HeaderValues>), + Unknown(hash_map::OccupiedEntry<'a, UnknownHeaderName<'static>, HeaderValues>), +} + +impl<'a> Entry<'a> { + /// retrieve the `HeaderName` for this entry. + pub fn name(&self) -> HeaderName<'_> { + match self { + Entry::Vacant(v) => v.name(), + Entry::Occupied(o) => o.name(), + } + } + + /// Sets the value of the entry, and returns a mutable reference to the inserted value. + /// + /// Note that this drops any previous occupied value. + pub fn insert(self, values: impl Into) -> &'a mut HeaderValues { + match self { + Entry::Vacant(v) => v.insert(values), + Entry::Occupied(mut o) => { + o.insert(values); + o.into_mut() + } + } + } + + /// Sets the value of the entry if it is vacant, and appends the new header values to the + /// previous ones if occupied. Returns a mutable reference to the inserted or updated value. + pub fn append(self, values: impl Into) -> &'a mut HeaderValues { + match self { + Entry::Vacant(v) => v.insert(values), + Entry::Occupied(mut o) => { + o.values_mut().extend(values); + o.into_mut() + } + } + } + + /// Sets the value if previously vacant. The provided function is only executed if the entry is + /// vacant. + pub fn or_insert_with>( + self, + values: impl FnOnce() -> V, + ) -> &'a mut HeaderValues { + match self { + Self::Occupied(entry) => entry.into_mut(), + Self::Vacant(entry) => entry.insert(values()), + } + } + + /// Sets the value if previously vacant. See also [`Entry::or_insert_with`] if constructing the + /// default value is expensive. + pub fn or_insert(self, values: impl Into) -> &'a mut HeaderValues { + match self { + Entry::Vacant(vacant) => vacant.insert(values), + Entry::Occupied(occupied) => occupied.into_mut(), + } + } + + /// Provides in-place mutable access to an occupied entry before any + /// potential inserts with [`Entry::or_insert`] or [`Entry::or_insert_with`]. + pub fn and_modify(self, f: impl FnOnce(&mut HeaderValues)) -> Self { + match self { + Self::Occupied(mut entry) => { + f(entry.values_mut()); + Self::Occupied(entry) + } + Self::Vacant(entry) => Self::Vacant(entry), + } + } + + /// Predicate to determine if this is a `VacantEntry` + pub fn is_vacant(&self) -> bool { + matches!(self, Self::Vacant(_)) + } + + /// Predicate to determine if this is an `OccupiedEntry` + pub fn is_occupied(&self) -> bool { + matches!(self, Self::Occupied(_)) + } + + /// Return the `OccupiedEntry`, if this entry is occupied + pub fn occupied(self) -> Option> { + match self { + Entry::Vacant(_) => None, + Entry::Occupied(o) => Some(o), + } + } + + /// Return the `VacantEntry`, if this entry is vacant + pub fn vacant(self) -> Option> { + match self { + Entry::Vacant(v) => Some(v), + Entry::Occupied(_) => None, + } + } +} + +impl<'a> VacantEntry<'a> { + /// Retrieves the `HeaderName` for this `Entry`. + pub fn name(&self) -> HeaderName<'_> { + match &self.0 { + VacantEntryInner::Known(k) => (*k.key()).into(), + VacantEntryInner::Unknown(u) => u.key().reborrow().into(), + } + } + + /// Replace this `VacantEntry` with a value, returning a mutable reference to that value. + pub fn insert(self, values: impl Into) -> &'a mut HeaderValues { + match self.0 { + VacantEntryInner::Known(k) => k.insert(values.into()), + VacantEntryInner::Unknown(u) => u.insert(values.into()), + } + } +} + +impl<'a> OccupiedEntry<'a> { + /// Retrieves the `HeaderName` for this `Entry` + pub fn name(&self) -> HeaderName<'_> { + match &self.0 { + OccupiedEntryInner::Known(known) => (*known.key()).into(), + OccupiedEntryInner::Unknown(unknown) => unknown.key().reborrow().into(), + } + } + + /// Borrows the `HeaderValues` for this `Entry` + pub fn values(&self) -> &HeaderValues { + match &self.0 { + OccupiedEntryInner::Known(known) => known.get(), + OccupiedEntryInner::Unknown(unknown) => unknown.get(), + } + } + + /// Mutate the `HeaderValues` for this `Entry` + pub fn values_mut(&mut self) -> &mut HeaderValues { + match &mut self.0 { + OccupiedEntryInner::Known(known) => known.get_mut(), + OccupiedEntryInner::Unknown(unknown) => unknown.get_mut(), + } + } + + /// Take ownership of the `HeaderName` and `HeaderValues` represented by this entry, removing it + /// from the `Headers` + pub fn remove_entry(self) -> (HeaderName<'static>, HeaderValues) { + match self.0 { + OccupiedEntryInner::Known(known) => { + let (n, v) = known.remove_entry(); + (n.into(), v) + } + OccupiedEntryInner::Unknown(unknown) => { + let (n, v) = unknown.remove_entry(); + (n.into(), v) + } + } + } + + /// Take ownership of the `HeaderValues` contained in this entry, removing it from the `Headers` + pub fn remove(self) -> HeaderValues { + match self.0 { + OccupiedEntryInner::Known(known) => known.remove(), + OccupiedEntryInner::Unknown(unknown) => unknown.remove(), + } + } + + /// Converts this `OccupiedEntry` into a mutable reference to the value in the entry with a + /// lifetime bound to the Headers itself. + /// + /// If you need multiple references to the `OccupiedEntry`, see [`OccupiedEntry::values_mut`] + pub fn into_mut(self) -> &'a mut HeaderValues { + match self.0 { + OccupiedEntryInner::Known(k) => k.into_mut(), + OccupiedEntryInner::Unknown(u) => u.into_mut(), + } + } + + /// Sets the value of the entry, and returns the entry's old value. + pub fn insert(&mut self, values: impl Into) -> HeaderValues { + match &mut self.0 { + OccupiedEntryInner::Known(k) => k.insert(values.into()), + OccupiedEntryInner::Unknown(u) => u.insert(values.into()), + } + } + + /// Adds additional `HeaderValues` to the existing `HeaderValues` in this entry. + pub fn append(&mut self, values: impl Into) { + self.values_mut().extend(values); + } +} diff --git a/http/src/headers/header_name.rs b/http/src/headers/header_name.rs index dbd0f85704..4c5d42f8a0 100644 --- a/http/src/headers/header_name.rs +++ b/http/src/headers/header_name.rs @@ -11,9 +11,16 @@ use HeaderNameInner::{KnownHeader, UnknownHeader}; /// The name of a http header. This can be either a /// [`KnownHeaderName`] or a string representation of an unknown /// header. -#[derive(Clone, Debug, PartialEq, Eq, Hash)] +#[derive(Clone, PartialEq, Eq, Hash)] pub struct HeaderName<'a>(pub(super) HeaderNameInner<'a>); +impl<'a> Debug for HeaderName<'a> { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + Debug::fmt(&self.0, f) + } +} + +#[cfg(feature = "parse")] impl<'a> HeaderName<'a> { pub(crate) fn parse(bytes: &'a [u8]) -> Result { std::str::from_utf8(bytes) @@ -32,13 +39,22 @@ impl serde::Serialize for HeaderName<'_> { } } -#[derive(Clone, Debug, PartialEq, Eq, Hash)] +#[derive(Clone, PartialEq, Eq, Hash)] pub(super) enum HeaderNameInner<'a> { /// A `KnownHeaderName` KnownHeader(KnownHeaderName), UnknownHeader(UnknownHeaderName<'a>), } +impl<'a> Debug for HeaderNameInner<'a> { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + Self::KnownHeader(known) => Debug::fmt(known, f), + Self::UnknownHeader(unknown) => Debug::fmt(unknown, f), + } + } +} + impl<'a> HeaderName<'a> { /// Convert a potentially-borrowed headername to a static /// headername _by value_. @@ -50,6 +66,14 @@ impl<'a> HeaderName<'a> { }) } + /// Turn a `&'b HeaderName<'a>` into a `HeaderName<'b>` + pub fn reborrow<'b: 'a>(&'b self) -> HeaderName<'b> { + match self.0 { + KnownHeader(khn) => khn.into(), + UnknownHeader(ref uhn) => uhn.reborrow().into(), + } + } + /// Convert a potentially-borrowed headername to a static /// headername _by cloning if needed from a borrow_. If you have /// ownership of a headername with a non-static lifetime, it is @@ -89,6 +113,18 @@ impl PartialEq for &HeaderName<'_> { } } +impl PartialEq for HeaderName<'_> { + fn eq(&self, other: &str) -> bool { + self.as_ref() == other + } +} + +impl PartialEq<&str> for HeaderName<'_> { + fn eq(&self, other: &&str) -> bool { + self.as_ref() == *other + } +} + impl From for HeaderName<'static> { fn from(s: String) -> Self { Self(match s.parse::() { diff --git a/http/src/headers/header_value.rs b/http/src/headers/header_value.rs index 6d512ffc89..311961f2d8 100644 --- a/http/src/headers/header_value.rs +++ b/http/src/headers/header_value.rs @@ -1,24 +1,16 @@ use smallvec::SmallVec; use smartcow::SmartCow; +use smartstring::SmartString; use std::{ borrow::Cow, - fmt::{Debug, Display, Formatter}, + fmt::{Debug, Display, Formatter, Write}, }; use HeaderValueInner::{Bytes, Utf8}; /// A `HeaderValue` represents the right hand side of a single `name: /// value` pair. -#[derive(Eq, PartialEq, Clone)] -pub struct HeaderValue(HeaderValueInner); - -impl HeaderValue { - /// determine if this header contains no unsafe characters (\r, \n, \0) - /// - /// since 0.3.12 - pub fn is_valid(&self) -> bool { - memchr::memchr3(b'\r', b'\n', 0, self.as_ref()).is_none() - } -} +#[derive(Eq, PartialEq, Clone, Ord, PartialOrd)] +pub struct HeaderValue(pub(crate) HeaderValueInner); #[derive(Eq, PartialEq, Clone)] pub(crate) enum HeaderValueInner { @@ -26,6 +18,19 @@ pub(crate) enum HeaderValueInner { Bytes(SmallVec<[u8; 32]>), } +impl PartialOrd for HeaderValueInner { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} +impl Ord for HeaderValueInner { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + let this: &[u8] = self.as_ref(); + let that: &[u8] = other.as_ref(); + this.cmp(that) + } +} + #[cfg(feature = "serde")] impl serde::Serialize for HeaderValue { fn serialize(&self, serializer: S) -> Result @@ -49,6 +54,13 @@ impl Debug for HeaderValue { } impl HeaderValue { + /// determine if this header contains no unsafe characters (\r, \n, \0) + /// + /// since 0.3.12 + pub fn is_valid(&self) -> bool { + memchr::memchr3(b'\r', b'\n', 0, self.as_ref()).is_none() + } + /// Returns this header value as a &str if it is utf8, None /// otherwise. If you need to convert non-utf8 bytes to a string /// somehow, match directly on the `HeaderValue` as an enum and @@ -70,7 +82,10 @@ impl HeaderValue { } }) } +} +#[cfg(feature = "parse")] +impl HeaderValue { pub(crate) fn parse(bytes: &[u8]) -> Self { match std::str::from_utf8(bytes) { Ok(s) => Self(Utf8(SmartCow::Owned(s.into()))), @@ -124,15 +139,43 @@ impl From<&'static str> for HeaderValue { } } -impl AsRef<[u8]> for HeaderValue { +macro_rules! delegate_from_to_format { + ($($t:ty),*) => { + $( + impl From<$t> for HeaderValue { + fn from(value: $t) -> Self { + format_args!("{value}").into() + } + } + )* + }; +} + +delegate_from_to_format!(usize, u64, u16, u32, i32, i64); + +impl From> for HeaderValue { + fn from(value: std::fmt::Arguments<'_>) -> Self { + let mut s = SmartString::new(); + s.write_fmt(value).unwrap(); + Self(Utf8(SmartCow::Owned(s))) + } +} + +impl AsRef<[u8]> for HeaderValueInner { fn as_ref(&self) -> &[u8] { - match &self.0 { + match self { Utf8(utf8) => utf8.as_bytes(), Bytes(b) => b, } } } +impl AsRef<[u8]> for HeaderValue { + fn as_ref(&self) -> &[u8] { + self.0.as_ref() + } +} + impl PartialEq<&str> for HeaderValue { fn eq(&self, other: &&str) -> bool { self.as_str() == Some(*other) diff --git a/http/src/headers/header_values.rs b/http/src/headers/header_values.rs index e76164812c..d977a1f459 100644 --- a/http/src/headers/header_values.rs +++ b/http/src/headers/header_values.rs @@ -22,13 +22,13 @@ impl Deref for HeaderValues { #[cfg(feature = "serde")] impl serde::Serialize for HeaderValues { - fn serialize(&self, serializer: S) -> std::prelude::v1::Result + fn serialize(&self, serializer: S) -> std::result::Result where S: serde::Serializer, { - match self.one() { - Some(one) => one.serialize(serializer), - None => self.0.serialize(serializer), + match &**self { + [one] => one.serialize(serializer), + several => several.serialize(serializer), } } } @@ -139,47 +139,74 @@ impl HeaderValues { // } // } -impl From<&'static [u8]> for HeaderValues { - fn from(value: &'static [u8]) -> Self { - HeaderValue::from(value).into() - } -} +macro_rules! delegate_from_to_header_value { + ($($t:ty),*) => { + $( + impl From<$t> for HeaderValues { + fn from(value: $t) -> Self { + HeaderValue::from(value).into() + } + } + )* + }; +} + +delegate_from_to_header_value!( + &'static [u8], + Vec, + String, + usize, + u64, + u16, + u32, + i32, + i64, + Cow<'static, str>, + &'static str, + std::fmt::Arguments<'_> +); -impl From> for HeaderValues { - fn from(value: Vec) -> Self { - HeaderValue::from(value).into() +impl From for HeaderValues { + fn from(v: HeaderValue) -> Self { + Self(smallvec![v]) } } -impl From for HeaderValues { - fn from(value: String) -> Self { - HeaderValue::from(value).into() +impl From<[HV; N]> for HeaderValues +where + HV: Into, +{ + fn from(v: [HV; N]) -> Self { + Self(v.into_iter().map(Into::into).collect()) } } -impl From<&'static str> for HeaderValues { - fn from(value: &'static str) -> Self { - HeaderValue::from(value).into() +impl From> for HeaderValues +where + HV: Into, +{ + fn from(value: Vec) -> Self { + Self(value.into_iter().map(Into::into).collect()) } } -impl From> for HeaderValues { - fn from(value: Cow<'static, str>) -> Self { - HeaderValue::from(value).into() +impl<'a, HV> From<&'a [HV]> for HeaderValues +where + &'a HV: Into, +{ + fn from(value: &'a [HV]) -> Self { + Self(value.iter().map(Into::into).collect()) } } -impl From for HeaderValues { - fn from(v: HeaderValue) -> Self { - Self(smallvec![v]) +impl PartialEq for HeaderValues { + fn eq(&self, other: &str) -> bool { + self.as_str().is_some_and(|v| v == other) } } -impl From> for HeaderValues -where - HV: Into, -{ - fn from(v: Vec) -> Self { - Self(v.into_iter().map(Into::into).collect()) +impl PartialEq<&str> for HeaderValues { + fn eq(&self, other: &&str) -> bool { + self == *other } } diff --git a/http/src/headers/unknown_header_name.rs b/http/src/headers/unknown_header_name.rs index 2aaeb62650..b8f373a049 100644 --- a/http/src/headers/unknown_header_name.rs +++ b/http/src/headers/unknown_header_name.rs @@ -80,6 +80,12 @@ impl UnknownHeaderName<'_> { } } +impl<'a> UnknownHeaderName<'a> { + pub(crate) fn reborrow<'b: 'a>(&'b self) -> UnknownHeaderName<'b> { + Self(self.0.borrow()) + } +} + impl From for UnknownHeaderName<'static> { fn from(value: String) -> Self { Self(value.into()) diff --git a/http/tests/headers.rs b/http/tests/headers.rs new file mode 100644 index 0000000000..1dbbece237 --- /dev/null +++ b/http/tests/headers.rs @@ -0,0 +1,388 @@ +use indoc::indoc; +use pretty_assertions::{assert_eq, assert_str_eq}; +use test_harness::test; +use trillium_http::{ + Headers, + KnownHeaderName::{self, ContentLength}, +}; + +#[test] +fn known_entry() { + let mut headers = Headers::new(); + let header_name = ContentLength; + let entry = headers.entry(header_name); + assert!(entry.is_vacant()); + assert!(!entry.is_occupied()); + + assert_eq!(entry.name(), header_name); + + assert_eq!( + "Vacant(VacantEntry { name: ContentLength })", + format!("{entry:?}") + ); + + entry.insert("value"); + + let entry = headers.entry(header_name); + assert_eq!(entry.name(), header_name); + + assert_str_eq!(headers.get_str(header_name).unwrap(), "value"); + assert_eq!(headers.entry(header_name).or_insert("ignored"), "value"); + assert_str_eq!( + r#"Occupied(OccupiedEntry { name: ContentLength, values: "value" })"#, + format!("{:?}", headers.entry(header_name)) + ); + + assert_eq!(headers.entry(header_name).insert("new-value"), "new-value"); + assert_str_eq!(headers.get_str(header_name).unwrap(), "new-value"); + + headers.remove(header_name); + assert!(!headers.has_header(header_name)); + assert_eq!( + headers + .entry(header_name) + .or_insert_with(|| String::from("generated-value")), + "generated-value" + ); + + assert_eq!( + **headers.entry(header_name).append("appended-header-value"), + ["generated-value", "appended-header-value"] + ); + + headers.remove(header_name); + assert_eq!( + **headers.entry(header_name).append("appended-header-value"), + ["appended-header-value"] + ); + + let occupied = headers.entry(header_name).occupied().unwrap(); + assert_eq!(occupied.remove(), "appended-header-value"); + + headers.insert(header_name, "some-value"); + let mut occupied = headers.entry(header_name).occupied().unwrap(); + occupied.append("another-value"); + let (n, v) = occupied.remove_entry(); + assert_eq!(n, header_name); + assert_eq!(*v, ["some-value", "another-value"]); +} + +#[test] +fn unknown_entry() { + let mut headers = Headers::new(); + let header_name = "x-unknown-header"; + let entry = headers.entry(header_name); + assert!(entry.is_vacant()); + assert!(!entry.is_occupied()); + assert_eq!(entry.name(), header_name); + let entry = entry.and_modify(|_| panic!("never called")); + assert!(entry.occupied().is_none()); + assert!(headers.entry(header_name).vacant().is_some()); + let entry = headers.entry(header_name); + + assert_str_eq!( + r#"Vacant(VacantEntry { name: "x-unknown-header" })"#, + format!("{entry:?}") + ); + + entry.insert("value"); + + let entry = headers.entry(header_name); + assert!(!entry.is_vacant()); + assert!(entry.is_occupied()); + + assert!(entry.occupied().is_some()); + assert!(headers.entry(header_name).vacant().is_none()); + + assert_str_eq!(headers.get_str(header_name).unwrap(), "value"); + assert_eq!(headers.entry(header_name).or_insert("ignored"), "value"); + assert_str_eq!( + r#"Occupied(OccupiedEntry { name: "x-unknown-header", values: "value" })"#, + format!("{:?}", headers.entry(header_name)) + ); + + let entry = headers.entry(header_name); + assert_eq!(entry.name(), header_name); + + assert_eq!(headers.entry(header_name).insert("new-value"), "new-value"); + assert_str_eq!(headers.get_str(header_name).unwrap(), "new-value"); + + headers.remove(header_name); + assert!(!headers.has_header(header_name)); + assert_eq!( + headers + .entry(header_name) + .or_insert_with(|| String::from("generated-value")), + "generated-value" + ); + + assert_eq!( + **headers.entry(header_name).append("appended-header-value"), + ["generated-value", "appended-header-value"] + ); + + assert_eq!( + **headers + .entry(header_name) + .and_modify(|values| values.sort()) + .or_insert(""), + ["appended-header-value", "generated-value"] + ); + + headers.remove(header_name); + assert_eq!( + **headers.entry(header_name).append("appended-header-value"), + ["appended-header-value"] + ); + + let occupied = headers.entry(header_name).occupied().unwrap(); + assert_eq!(occupied.remove(), "appended-header-value"); + + headers.insert(header_name, "some-value"); + let mut occupied = headers.entry(header_name).occupied().unwrap(); + occupied.append("another-value"); + let (n, v) = occupied.remove_entry(); + assert_eq!(n, header_name); + assert_eq!(*v, ["some-value", "another-value"]); +} + +#[test] +fn headers_known() { + let mut headers = Headers::new(); + let header_name = ContentLength; + assert!(headers.is_empty()); + assert_eq!(headers.len(), 0); + assert!(!headers.has_header(header_name)); + + headers.insert(header_name, 100); + assert!(!headers.is_empty()); + assert_eq!(headers.len(), 1); + assert!(headers.has_header(header_name)); + + assert_str_eq!("Content-Length: 100\r\n", format!("{headers}")); + assert_str_eq!( + r#"Headers { known: {ContentLength: "100"}, unknown: {} }"#, + format!("{headers:?}") + ); + assert_eq!(**headers.get_values(header_name).unwrap(), ["100"]); + headers.try_insert(header_name, "ignored"); + assert_eq!(**headers.get_values(header_name).unwrap(), ["100"]); + + headers.append(header_name, "second value"); + assert!(!headers.is_empty()); + assert_eq!(headers.len(), 1); + + assert_str_eq!( + "Content-Length: 100\r\nContent-Length: second value\r\n", + format!("{headers}") + ); + assert_str_eq!( + r#"Headers { known: {ContentLength: ["100", "second value"]}, unknown: {} }"#, + format!("{headers:?}") + ); + assert_eq!( + **headers.get_values(header_name).unwrap(), + ["100", "second value"] + ); + + headers.try_insert(header_name, "ignored"); + assert_eq!( + **headers.get_values(header_name).unwrap(), + ["100", "second value"] + ); + + headers.remove(header_name); + headers.try_insert(header_name, "INSERTED"); + assert_eq!(headers.get_values(header_name).unwrap(), "INSERTED"); + assert_str_eq!(headers.get_str(header_name).unwrap(), "INSERTED"); + assert!(headers.eq_ignore_ascii_case(header_name, "inserted")); + assert!(!headers.eq_ignore_ascii_case(header_name, "other")); +} + +#[test] +fn bulk_header_operations() { + let headers = Headers::from_iter([("Content-Length", 1), ("x-unknown-header", 2)]) + .with_inserted_header("x-other", 1) + .with_inserted_header("other-Header", format_args!("1 + 2 = {}", 1 + 2)) + .with_inserted_header(KnownHeaderName::Host, "host") + .with_inserted_header(KnownHeaderName::Host, "other-host") + .with_appended_header(KnownHeaderName::Server, String::from("server")) + .with_appended_header(KnownHeaderName::Server, "x") + .without_header("x-Unknown-Header") + .without_header(ContentLength); + + assert_str_eq!( + headers.to_string(), + "Host: other-host\r\nServer: server\r\nServer: x\r\nother-Header: 1 + 2 = 3\r\nx-other: 1\r\n" + ); + + assert_str_eq!( + headers + .without_headers(["x-unknown-header", "server"]) + .to_string(), + "Host: other-host\r\nother-Header: 1 + 2 = 3\r\nx-other: 1\r\n" + ); +} + +#[test] +fn combining_headers() { + let headers_a = Headers::from_iter([ + ("a", "b"), + ("c", "d"), + ("host", "is a known header"), + ("server", "known"), + ]); + let headers_b = Headers::from_iter([ + ("A", "E"), + ("C", "F"), + ("HOST", "also known"), + ("SERVER", "also known"), + ("new-unknown", "only in b"), + ("Content-TYPE", "also only in b"), + ]); + + let mut extended = headers_a.clone(); + extended.extend(headers_b.clone()); + assert_str_eq!( + indoc! {" + Host: is a known header\r + Host: also known\r + Content-Type: also only in b\r + Server: known\r + Server: also known\r + new-unknown: only in b\r + a: b\r + a: E\r + c: d\r + c: F\r + "}, + extended.to_string(), + ); + + let mut insert_all = headers_a.clone(); + insert_all.insert_all(headers_b.clone()); + assert_str_eq!( + indoc! {" + Host: also known\r + Content-Type: also only in b\r + Server: also known\r + c: F\r + new-unknown: only in b\r + a: E\r + "}, + insert_all.to_string(), + ); + + let mut append_all = headers_a.clone(); + append_all.append_all(headers_b.clone()); + assert_str_eq!( + indoc! {" + Host: is a known header\r + Host: also known\r + Content-Type: also only in b\r + Server: known\r + Server: also known\r + new-unknown: only in b\r + a: b\r + a: E\r + c: d\r + c: F\r + "}, + append_all.to_string(), + ); +} + +#[test] +fn headers_unknown() { + let mut headers = Headers::new(); + let header_name = "x-unknown-header"; + assert!(headers.is_empty()); + assert_eq!(headers.len(), 0); + assert!(!headers.has_header(header_name)); + + headers.insert(header_name, 100); + assert!(!headers.is_empty()); + assert_eq!(headers.len(), 1); + assert!(headers.has_header(header_name)); + + assert_str_eq!("x-unknown-header: 100\r\n", format!("{headers}")); + assert_str_eq!( + r#"Headers { known: {}, unknown: {"x-unknown-header": "100"} }"#, + format!("{headers:?}") + ); + assert_eq!(**headers.get_values(header_name).unwrap(), ["100"]); + headers.try_insert(header_name, "ignored"); + assert_eq!(**headers.get_values(header_name).unwrap(), ["100"]); + + headers.append(header_name, "second value"); + assert!(!headers.is_empty()); + assert_eq!(headers.len(), 1); + + assert_str_eq!( + "x-unknown-header: 100\r\nx-unknown-header: second value\r\n", + format!("{headers}") + ); + assert_str_eq!( + r#"Headers { known: {}, unknown: {"x-unknown-header": ["100", "second value"]} }"#, + format!("{headers:?}") + ); + assert_eq!( + **headers.get_values(header_name).unwrap(), + ["100", "second value"] + ); + + headers.try_insert(header_name, "ignored"); + assert_eq!( + **headers.get_values(header_name).unwrap(), + ["100", "second value"] + ); + + headers.remove(header_name); + headers.try_insert(header_name, "INSERTED"); + assert_eq!(headers.get_values(header_name).unwrap(), "INSERTED"); + assert_str_eq!(headers.get_str(header_name).unwrap(), "INSERTED"); + assert!(headers.eq_ignore_ascii_case(header_name, "inserted")); + assert!(!headers.eq_ignore_ascii_case(header_name, "other")); + + headers.remove(header_name); + headers.append(header_name, "inserted"); + assert_str_eq!(headers.get_str(header_name).unwrap(), "inserted"); +} + +#[test] +fn header_names_are_case_insensitive_for_access_but_retain_initial_case_in_headers() { + let mut headers = Headers::new(); + headers.insert("my-Header-name", "initial-value"); + headers.insert("my-Header-NAME", "my-header-value"); + + assert_eq!(headers.len(), 1); + + assert_eq!( + headers.get_str("My-Header-Name").unwrap(), + "my-header-value" + ); + + headers.append("mY-hEaDer-NaMe", "second-value"); + assert_eq!( + headers.get_values("my-header-name").unwrap(), + ["my-header-value", "second-value"].as_slice() + ); + + assert_eq!( + headers.iter().next().unwrap().0.to_string(), + "my-Header-name" + ); + + assert!(headers.remove("my-HEADER-name").is_some()); + assert!(headers.is_empty()); +} + +#[test] +fn value_case_insensitive_comparison() { + let mut headers = Headers::new(); + headers.insert(KnownHeaderName::Upgrade, "WebSocket"); + headers.insert(KnownHeaderName::Connection, "upgrade"); + + assert!(headers.eq_ignore_ascii_case(KnownHeaderName::Upgrade, "websocket")); + assert!(headers.eq_ignore_ascii_case(KnownHeaderName::Connection, "Upgrade")); +} diff --git a/http/tests/serialize_headers.rs b/http/tests/serialize_headers.rs index a3ab20feea..7531dff072 100644 --- a/http/tests/serialize_headers.rs +++ b/http/tests/serialize_headers.rs @@ -7,7 +7,7 @@ fn header_serialization() { headers.insert( "non-utf8", vec![ - 0xC0, 0xC1, 0xF5, 0xF6, 0xF7, 0xF8, 0xF9, 0xFA, 0xFB, 0xFC, 0xFD, 0xFE, 0xFF, + 0xC0u8, 0xC1, 0xF5, 0xF6, 0xF7, 0xF8, 0xF9, 0xFA, 0xFB, 0xFC, 0xFD, 0xFE, 0xFF, ], ); headers.insert("multi-values", "value1"); @@ -17,7 +17,7 @@ fn header_serialization() { serde_json::json!({ "Accept": "Known", "non-utf8": vec![ - 0xC0, 0xC1, 0xF5, 0xF6, 0xF7, 0xF8, 0xF9, 0xFA, 0xFB, 0xFC, 0xFD, 0xFE, 0xFF, + 0xC0u8, 0xC1, 0xF5, 0xF6, 0xF7, 0xF8, 0xF9, 0xFA, 0xFB, 0xFC, 0xFD, 0xFE, 0xFF, ], "multi-values": ["value1", "value2", "value3"] }),