Skip to content

Commit

Permalink
fix(client): strip path from Uri before calling Connector (#2109)
Browse files Browse the repository at this point in the history
  • Loading branch information
seanmonstar authored Jan 13, 2020
1 parent a5720fa commit ba2a144
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 34 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ serde_derive = "1.0"
serde_json = "1.0"
tokio = { version = "0.2.2", features = ["fs", "macros", "io-std", "rt-util", "sync", "time", "test-util"] }
tokio-test = "0.2"
tower-util = "0.3"
url = "1.0"

[features]
Expand Down
36 changes: 21 additions & 15 deletions src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@
use std::fmt;
use std::mem;
use std::sync::Arc;
use std::time::Duration;

use futures_channel::oneshot;
Expand Down Expand Up @@ -230,14 +229,13 @@ where
other => return ResponseFuture::error_version(other),
};

let domain = match extract_domain(req.uri_mut(), is_http_connect) {
let pool_key = match extract_domain(req.uri_mut(), is_http_connect) {
Ok(s) => s,
Err(err) => {
return ResponseFuture::new(Box::new(future::err(err)));
}
};

let pool_key = Arc::new(domain);
ResponseFuture::new(Box::new(self.retryably_send_request(req, pool_key)))
}

Expand Down Expand Up @@ -281,7 +279,7 @@ where
mut req: Request<B>,
pool_key: PoolKey,
) -> impl Future<Output = Result<Response<Body>, ClientError<B>>> + Unpin {
let conn = self.connection_for(req.uri().clone(), pool_key);
let conn = self.connection_for(pool_key);

let set_host = self.config.set_host;
let executor = self.conn_builder.exec.clone();
Expand Down Expand Up @@ -377,7 +375,6 @@ where

fn connection_for(
&self,
uri: Uri,
pool_key: PoolKey,
) -> impl Future<Output = Result<Pooled<PoolClient<B>>, ClientError<B>>> {
// This actually races 2 different futures to try to get a ready
Expand All @@ -394,7 +391,7 @@ where
// connection future is spawned into the runtime to complete,
// and then be inserted into the pool as an idle connection.
let checkout = self.pool.checkout(pool_key.clone());
let connect = self.connect_to(uri, pool_key);
let connect = self.connect_to(pool_key);

let executor = self.conn_builder.exec.clone();
// The order of the `select` is depended on below...
Expand Down Expand Up @@ -455,7 +452,6 @@ where

fn connect_to(
&self,
uri: Uri,
pool_key: PoolKey,
) -> impl Lazy<Output = crate::Result<Pooled<PoolClient<B>>>> + Unpin {
let executor = self.conn_builder.exec.clone();
Expand All @@ -464,7 +460,7 @@ where
let ver = self.config.ver;
let is_ver_h2 = ver == Ver::Http2;
let connector = self.connector.clone();
let dst = uri;
let dst = domain_as_uri(pool_key.clone());
hyper_lazy(move || {
// Try to take a "connecting lock".
//
Expand Down Expand Up @@ -794,22 +790,22 @@ fn authority_form(uri: &mut Uri) {
};
}

fn extract_domain(uri: &mut Uri, is_http_connect: bool) -> crate::Result<String> {
fn extract_domain(uri: &mut Uri, is_http_connect: bool) -> crate::Result<PoolKey> {
let uri_clone = uri.clone();
match (uri_clone.scheme(), uri_clone.authority()) {
(Some(scheme), Some(auth)) => Ok(format!("{}://{}", scheme, auth)),
(Some(scheme), Some(auth)) => Ok((scheme.clone(), auth.clone())),
(None, Some(auth)) if is_http_connect => {
let scheme = match auth.port_u16() {
Some(443) => {
set_scheme(uri, Scheme::HTTPS);
"https"
Scheme::HTTPS
}
_ => {
set_scheme(uri, Scheme::HTTP);
"http"
Scheme::HTTP
}
};
Ok(format!("{}://{}", scheme, auth))
Ok((scheme, auth.clone()))
}
_ => {
debug!("Client requires absolute-form URIs, received: {:?}", uri);
Expand All @@ -818,6 +814,15 @@ fn extract_domain(uri: &mut Uri, is_http_connect: bool) -> crate::Result<String>
}
}

fn domain_as_uri((scheme, auth): PoolKey) -> Uri {
http::uri::Builder::new()
.scheme(scheme)
.authority(auth)
.path_and_query("/")
.build()
.expect("domain is valid Uri")
}

fn set_scheme(uri: &mut Uri, scheme: Scheme) {
debug_assert!(
uri.scheme().is_none(),
Expand Down Expand Up @@ -1126,7 +1131,8 @@ mod unit_tests {
#[test]
fn test_extract_domain_connect_no_port() {
let mut uri = "hyper.rs".parse().unwrap();
let domain = extract_domain(&mut uri, true).expect("extract domain");
assert_eq!(domain, "http://hyper.rs");
let (scheme, host) = extract_domain(&mut uri, true).expect("extract domain");
assert_eq!(scheme, *"http");
assert_eq!(host, "hyper.rs");
}
}
23 changes: 13 additions & 10 deletions src/client/pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ pub(super) enum Reservation<T> {
}

/// Simple type alias in case the key type needs to be adjusted.
pub(super) type Key = Arc<String>;
pub(super) type Key = (http::uri::Scheme, http::uri::Authority); //Arc<String>;

struct PoolInner<T> {
// A flag that a connection is being established, and the connection
Expand Down Expand Up @@ -755,7 +755,6 @@ impl<T> WeakOpt<T> {

#[cfg(test)]
mod tests {
use std::sync::Arc;
use std::task::Poll;
use std::time::Duration;

Expand Down Expand Up @@ -787,6 +786,10 @@ mod tests {
}
}

fn host_key(s: &str) -> Key {
(http::uri::Scheme::HTTP, s.parse().expect("host key"))
}

fn pool_no_timer<T>() -> Pool<T> {
pool_max_idle_no_timer(::std::usize::MAX)
}
Expand All @@ -807,7 +810,7 @@ mod tests {
#[tokio::test]
async fn test_pool_checkout_smoke() {
let pool = pool_no_timer();
let key = Arc::new("foo".to_string());
let key = host_key("foo");
let pooled = pool.pooled(c(key.clone()), Uniq(41));

drop(pooled);
Expand Down Expand Up @@ -839,7 +842,7 @@ mod tests {
#[tokio::test]
async fn test_pool_checkout_returns_none_if_expired() {
let pool = pool_no_timer();
let key = Arc::new("foo".to_string());
let key = host_key("foo");
let pooled = pool.pooled(c(key.clone()), Uniq(41));

drop(pooled);
Expand All @@ -854,7 +857,7 @@ mod tests {
#[tokio::test]
async fn test_pool_checkout_removes_expired() {
let pool = pool_no_timer();
let key = Arc::new("foo".to_string());
let key = host_key("foo");

pool.pooled(c(key.clone()), Uniq(41));
pool.pooled(c(key.clone()), Uniq(5));
Expand All @@ -876,7 +879,7 @@ mod tests {
#[test]
fn test_pool_max_idle_per_host() {
let pool = pool_max_idle_no_timer(2);
let key = Arc::new("foo".to_string());
let key = host_key("foo");

pool.pooled(c(key.clone()), Uniq(41));
pool.pooled(c(key.clone()), Uniq(5));
Expand Down Expand Up @@ -904,7 +907,7 @@ mod tests {
&Exec::Default,
);

let key = Arc::new("foo".to_string());
let key = host_key("foo");

pool.pooled(c(key.clone()), Uniq(41));
pool.pooled(c(key.clone()), Uniq(5));
Expand All @@ -929,7 +932,7 @@ mod tests {
use futures_util::FutureExt;

let pool = pool_no_timer();
let key = Arc::new("foo".to_string());
let key = host_key("foo");
let pooled = pool.pooled(c(key.clone()), Uniq(41));

let checkout = join(pool.checkout(key), async {
Expand All @@ -948,7 +951,7 @@ mod tests {
#[tokio::test]
async fn test_pool_checkout_drop_cleans_up_waiters() {
let pool = pool_no_timer::<Uniq<i32>>();
let key = Arc::new("localhost:12345".to_string());
let key = host_key("foo");

let mut checkout1 = pool.checkout(key.clone());
let mut checkout2 = pool.checkout(key.clone());
Expand Down Expand Up @@ -993,7 +996,7 @@ mod tests {
#[test]
fn pooled_drop_if_closed_doesnt_reinsert() {
let pool = pool_no_timer();
let key = Arc::new("localhost:12345".to_string());
let key = host_key("foo");
pool.pooled(
c(key.clone()),
CanClose {
Expand Down
33 changes: 24 additions & 9 deletions src/client/tests.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,30 @@
// FIXME: re-implement tests with `async/await`
/*
#![cfg(feature = "runtime")]
use std::io;

use futures_util::future;
use tokio::net::TcpStream;

use futures::{Async, Future, Stream};
use futures::future::poll_fn;
use futures::sync::oneshot;
use tokio::runtime::current_thread::Runtime;
use super::Client;

use crate::mock::MockConnector;
use super::*;
#[tokio::test]
async fn client_connect_uri_argument() {
let connector = tower_util::service_fn(|dst: http::Uri| {
assert_eq!(dst.scheme(), Some(&http::uri::Scheme::HTTP));
assert_eq!(dst.host(), Some("example.local"));
assert_eq!(dst.port(), None);
assert_eq!(dst.path(), "/", "path should be removed");

future::err::<TcpStream, _>(io::Error::new(io::ErrorKind::Other, "expect me"))
});

let client = Client::builder().build::<_, crate::Body>(connector);
let _ = client
.get("http://example.local/and/a/path".parse().unwrap())
.await
.expect_err("response should fail");
}

/*
// FIXME: re-implement tests with `async/await`
#[test]
fn retryable_request() {
let _ = pretty_env_logger::try_init();
Expand Down

0 comments on commit ba2a144

Please sign in to comment.