Skip to content

Commit

Permalink
adjust CookieStore trait, make default store public as Jar
Browse files Browse the repository at this point in the history
  • Loading branch information
seanmonstar committed Mar 4, 2021
1 parent 1880f1b commit 9f174e3
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 46 deletions.
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ rustls-tls-native-roots = ["rustls-native-certs", "__rustls"]

blocking = ["futures-util/io", "tokio/rt-multi-thread", "tokio/sync"]

cookies = ["cookie_crate", "time"]
cookies = ["cookie_crate", "cookie_store", "time"]

gzip = ["async-compression", "async-compression/gzip", "tokio-util"]

Expand Down Expand Up @@ -117,6 +117,7 @@ rustls-native-certs = { version = "0.5", optional = true }

## cookies
cookie_crate = { version = "0.14", package = "cookie", optional = true }
cookie_store = { version = "0.12", optional = true }

time = { version = "0.2.11", optional = true }

Expand Down
31 changes: 13 additions & 18 deletions src/async_impl/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -449,10 +449,10 @@ impl ClientBuilder {
self
}

/// Set the persistent cookie store for the client.
/// Enable a persistent cookie store for the client.
///
/// Cookies received in responses will be passed to this store, and
/// additional requests will query this store for cookies.
/// Cookies received in responses will be preserved and included in
/// additional requests.
///
/// By default, no cookie store is used.
///
Expand All @@ -462,12 +462,12 @@ impl ClientBuilder {
#[cfg(feature = "cookies")]
#[cfg_attr(docsrs, doc(cfg(feature = "cookies")))]
pub fn cookie_store(mut self, enable: bool) -> ClientBuilder {
self.config.cookie_store = if enable {
Some(cookie::CookieStore::default())
if enable {
self.cookie_provider(Arc::new(cookie::Jar::default()))
} else {
None
};
self
self.config.cookie_store = None;
self
}
}

/// Set the persistent cookie store for the client.
Expand All @@ -484,9 +484,9 @@ impl ClientBuilder {
#[cfg_attr(docsrs, doc(cfg(feature = "cookies")))]
pub fn cookie_provider<C: cookie::CookieStore + 'static>(
mut self,
cookie_store: Option<Arc<C>>,
cookie_store: Arc<C>,
) -> ClientBuilder {
self.config.cookie_store = cookie_store.map(|a| a as _);
self.config.cookie_store = Some(cookie_store as _);
self
}

Expand Down Expand Up @@ -1452,8 +1452,7 @@ impl Future for PendingRequest {
let mut cookies =
cookie::extract_response_cookie_headers(&res.headers()).peekable();
if cookies.peek().is_some() {
let cookies = cookies.collect();
cookie_store.set_cookies(cookies, &self.url);
cookie_store.set_cookies(&mut cookies, &self.url);
}
}
}
Expand Down Expand Up @@ -1605,12 +1604,8 @@ fn make_referer(next: &Url, previous: &Url) -> Option<HeaderValue> {

#[cfg(feature = "cookies")]
fn add_cookie_header(headers: &mut HeaderMap, cookie_store: &dyn cookie::CookieStore, url: &Url) {
let header = cookie_store.cookies(url).join("; ");
if !header.is_empty() {
headers.insert(
crate::header::COOKIE,
HeaderValue::from_bytes(header.as_bytes()).unwrap(),
);
if let Some(header) = cookie_store.cookies(url) {
headers.insert(crate::header::COOKIE, header);
}
}

Expand Down
19 changes: 19 additions & 0 deletions src/blocking/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,25 @@ impl ClientBuilder {
self.with_inner(|inner| inner.cookie_store(enable))
}

/// Set the persistent cookie store for the client.
///
/// Cookies received in responses will be passed to this store, and
/// additional requests will query this store for cookies.
///
/// By default, no cookie store is used.
///
/// # Optional
///
/// This requires the optional `cookies` feature to be enabled.
#[cfg(feature = "cookies")]
#[cfg_attr(docsrs, doc(cfg(feature = "cookies")))]
pub fn cookie_provider<C: crate::cookie::CookieStore + 'static>(
self,
cookie_store: Arc<C>,
) -> ClientBuilder {
self.with_inner(|inner| inner.cookie_provider(cookie_store))
}

/// Enable auto gzip decompression by checking the `Content-Encoding` response header.
///
/// If auto gzip decompresson is turned on:
Expand Down
96 changes: 79 additions & 17 deletions src/cookie.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,36 @@
//! HTTP Cookies
use std::convert::TryInto;

use crate::header;
use std::fmt;
use std::sync::RwLock;
use std::time::SystemTime;

use crate::header::{HeaderValue, SET_COOKIE};
use bytes::Bytes;

/// Actions for a persistent cookie store providing session supprt.
pub trait CookieStore: Send + Sync {
/// Store a set of Set-Cookie header values recevied from `url`
fn set_cookies(&self, cookie_headers: &mut dyn Iterator<Item = &HeaderValue>, url: &url::Url);
/// Get any Cookie values in the store for `url`
fn cookies(&self, url: &url::Url) -> Option<HeaderValue>;
}

/// A single HTTP cookie.
pub struct Cookie<'a>(cookie_crate::Cookie<'a>);

/// A good default `CookieStore` implementation.
///
/// This is the implementation used when simply calling `cookie_store(true)`.
/// This type is exposed to allow creating one and filling it with some
/// existing cookies more easily, before creating a `Client`.
#[derive(Debug, Default)]
pub struct Jar(RwLock<cookie_store::CookieStore>);

// ===== impl Cookie =====

impl<'a> Cookie<'a> {
fn parse(value: &'a crate::header::HeaderValue) -> Result<Cookie<'a>, CookieParseError> {
fn parse(value: &'a HeaderValue) -> Result<Cookie<'a>, CookieParseError> {
std::str::from_utf8(value.as_bytes())
.map_err(cookie_crate::ParseError::from)
.and_then(cookie_crate::Cookie::parse)
Expand Down Expand Up @@ -80,30 +100,19 @@ impl<'a> fmt::Debug for Cookie<'a> {

pub(crate) fn extract_response_cookie_headers<'a>(
headers: &'a hyper::HeaderMap,
) -> impl Iterator<Item = &'a str> + 'a {
headers
.get_all(header::SET_COOKIE)
.iter()
.filter_map(|value| std::str::from_utf8(value.as_bytes()).ok())
) -> impl Iterator<Item = &'a HeaderValue> + 'a {
headers.get_all(SET_COOKIE).iter()
}

pub(crate) fn extract_response_cookies<'a>(
headers: &'a hyper::HeaderMap,
) -> impl Iterator<Item = Result<Cookie<'a>, CookieParseError>> + 'a {
headers
.get_all(header::SET_COOKIE)
.get_all(SET_COOKIE)
.iter()
.map(|value| Cookie::parse(value))
}

/// Actions for a persistent cookie store providing session supprt.
pub trait CookieStore: Send + Sync {
/// Store a set of Set-Cookie header values recevied from `url`
fn set_cookies(&self, cookie_headers: Vec<&str>, url: &url::Url);
/// Get any Cookie values in the store for `url`
fn cookies(&self, url: &url::Url) -> Vec<String>;
}

/// Error representing a parse failure of a 'Set-Cookie' header.
pub(crate) struct CookieParseError(cookie_crate::ParseError);

Expand All @@ -120,3 +129,56 @@ impl<'a> fmt::Display for CookieParseError {
}

impl std::error::Error for CookieParseError {}

// ===== impl Jar =====

impl Jar {
/// Add a cookie to this jar.
///
/// # Example
///
/// ```
/// use reqwest::{cookie::Jar, Url};
///
/// let cookie = "foo=bar; Domain=yolo.local";
/// let url = "https://yolo.local".parse::<Url>().unwrap();
///
/// let jar = Jar::default();
/// jar.add_cookie_str(cookie, &url);
///
/// // and now add to a `ClientBuilder`?
/// ```
pub fn add_cookie_str(&self, cookie: &str, url: &url::Url) {
let cookies = cookie_crate::Cookie::parse(cookie)
.ok()
.map(|c| c.into_owned())
.into_iter();
self.0.write().unwrap().store_response_cookies(cookies, url);
}
}

impl CookieStore for Jar {
fn set_cookies(&self, cookie_headers: &mut dyn Iterator<Item = &HeaderValue>, url: &url::Url) {
let iter =
cookie_headers.filter_map(|val| Cookie::parse(val).map(|c| c.0.into_owned()).ok());

self.0.write().unwrap().store_response_cookies(iter, url);
}

fn cookies(&self, url: &url::Url) -> Option<HeaderValue> {
let s = self
.0
.read()
.unwrap()
.get_request_cookies(url)
.map(|c| format!("{}={}", c.name(), c.value()))
.collect::<Vec<_>>()
.join("; ");

if s.is_empty() {
return None;
}

HeaderValue::from_maybe_shared(Bytes::from(s)).ok()
}
}
12 changes: 5 additions & 7 deletions tests/cookie.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ use std::sync::Arc;
mod support;
use support::*;

use cookie_store::CookieStoreMutex as CookieStore;

#[tokio::test]
async fn cookie_response_accessor() {
let server = server::http(move |_req| async move {
Expand Down Expand Up @@ -88,7 +86,7 @@ async fn cookie_store_simple() {
});

let client = reqwest::Client::builder()
.cookie_store(Some(Arc::new(CookieStore::default())))
.cookie_store(true)
.build()
.unwrap();

Expand Down Expand Up @@ -121,7 +119,7 @@ async fn cookie_store_overwrite_existing() {
});

let client = reqwest::Client::builder()
.cookie_store(Some(Arc::new(CookieStore::default())))
.cookie_store(true)
.build()
.unwrap();

Expand All @@ -146,7 +144,7 @@ async fn cookie_store_max_age() {
});

let client = reqwest::Client::builder()
.cookie_store(Some(Arc::new(CookieStore::default())))
.cookie_store(true)
.build()
.unwrap();
let url = format!("http://{}/", server.addr());
Expand All @@ -168,7 +166,7 @@ async fn cookie_store_expires() {
});

let client = reqwest::Client::builder()
.cookie_store(Some(Arc::new(CookieStore::default())))
.cookie_store(true)
.build()
.unwrap();

Expand All @@ -194,7 +192,7 @@ async fn cookie_store_path() {
});

let client = reqwest::Client::builder()
.cookie_store(Some(Arc::new(CookieStore::default())))
.cookie_store(true)
.build()
.unwrap();

Expand Down
4 changes: 1 addition & 3 deletions tests/redirect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ mod support;
use futures_util::stream::StreamExt;
use support::*;

use cookie_store::CookieStoreMutex as CookieStore;

#[tokio::test]
async fn test_redirect_301_and_302_and_303_changes_post_to_get() {
let client = reqwest::Client::new();
Expand Down Expand Up @@ -312,7 +310,7 @@ async fn test_redirect_302_with_set_cookies() {
let dst = format!("http://{}/{}", server.addr(), "dst");

let client = reqwest::ClientBuilder::new()
.cookie_store(Some(Arc::new(CookieStore::default())))
.cookie_store(true)
.build()
.unwrap();
let res = client.get(&url).send().await.unwrap();
Expand Down

0 comments on commit 9f174e3

Please sign in to comment.