diff --git a/src/client/connect/dns.rs b/src/client/connect/dns.rs index 1ea7d511bf..03e30f098a 100644 --- a/src/client/connect/dns.rs +++ b/src/client/connect/dns.rs @@ -235,6 +235,10 @@ impl IpAddrs { pub(super) fn is_empty(&self) -> bool { self.iter.as_slice().is_empty() } + + pub(super) fn len(&self) -> usize { + self.iter.as_slice().len() + } } impl Iterator for IpAddrs { diff --git a/src/client/connect/http.rs b/src/client/connect/http.rs index effa6f7dc4..ad60fd7d7b 100644 --- a/src/client/connect/http.rs +++ b/src/client/connect/http.rs @@ -10,7 +10,7 @@ use futures_util::{TryFutureExt, FutureExt}; use net2::TcpBuilder; use tokio_net::driver::Handle; use tokio_net::tcp::TcpStream; -use tokio_timer::Delay; +use tokio_timer::{Delay, Timeout}; use crate::common::{Future, Pin, Poll, task}; use super::{Connect, Connected, Destination}; @@ -32,6 +32,7 @@ type ConnectFuture = Pin> + Send>> pub struct HttpConnector { enforce_http: bool, handle: Option, + connect_timeout: Option, happy_eyeballs_timeout: Option, keep_alive_timeout: Option, local_address: Option, @@ -101,6 +102,7 @@ impl HttpConnector { HttpConnector { enforce_http: true, handle: None, + connect_timeout: None, happy_eyeballs_timeout: Some(Duration::from_millis(300)), keep_alive_timeout: None, local_address: None, @@ -168,6 +170,17 @@ impl HttpConnector { self.local_address = addr; } + /// Set the connect timeout. + /// + /// If a domain resolves to multiple IP addresses, the timeout will be + /// evenly divided across them. + /// + /// Default is `None`. + #[inline] + pub fn set_connect_timeout(&mut self, dur: Option) { + self.connect_timeout = dur; + } + /// Set timeout for [RFC 6555 (Happy Eyeballs)][RFC 6555] algorithm. /// /// If hostname resolves to both IPv4 and IPv6 addresses and connection @@ -240,6 +253,7 @@ where HttpConnecting { state: State::Lazy(self.resolver.clone(), host.into(), self.local_address), handle: self.handle.clone(), + connect_timeout: self.connect_timeout, happy_eyeballs_timeout: self.happy_eyeballs_timeout, keep_alive_timeout: self.keep_alive_timeout, nodelay: self.nodelay, @@ -295,6 +309,7 @@ where let fut = HttpConnecting { state: State::Lazy(self.resolver.clone(), host.into(), self.local_address), handle: self.handle.clone(), + connect_timeout: self.connect_timeout, happy_eyeballs_timeout: self.happy_eyeballs_timeout, keep_alive_timeout: self.keep_alive_timeout, nodelay: self.nodelay, @@ -323,6 +338,7 @@ fn invalid_url(err: InvalidUrl, handle: &Option) -> HttpConn keep_alive_timeout: None, nodelay: false, port: 0, + connect_timeout: None, happy_eyeballs_timeout: None, reuse_address: false, send_buffer_size: None, @@ -357,6 +373,7 @@ impl StdError for InvalidUrl { pub struct HttpConnecting { state: State, handle: Option, + connect_timeout: Option, happy_eyeballs_timeout: Option, keep_alive_timeout: Option, nodelay: bool, @@ -389,7 +406,7 @@ where // skip resolving the dns and start connecting right away. if let Some(addrs) = dns::IpAddrs::try_parse(host, me.port) { state = State::Connecting(ConnectingTcp::new( - local_addr, addrs, me.happy_eyeballs_timeout, me.reuse_address)); + local_addr, addrs, me.connect_timeout, me.happy_eyeballs_timeout, me.reuse_address)); } else { let name = dns::Name::new(mem::replace(host, String::new())); state = State::Resolving(resolver.resolve(name), local_addr); @@ -403,7 +420,7 @@ where .collect(); let addrs = dns::IpAddrs::new(addrs); state = State::Connecting(ConnectingTcp::new( - local_addr, addrs, me.happy_eyeballs_timeout, me.reuse_address)); + local_addr, addrs, me.connect_timeout, me.happy_eyeballs_timeout, me.reuse_address)); }, State::Connecting(ref mut c) => { let sock = ready!(c.poll(cx, &me.handle))?; @@ -454,6 +471,7 @@ impl ConnectingTcp { fn new( local_addr: Option, remote_addrs: dns::IpAddrs, + connect_timeout: Option, fallback_timeout: Option, reuse_address: bool, ) -> ConnectingTcp { @@ -462,7 +480,7 @@ impl ConnectingTcp { if fallback_addrs.is_empty() { return ConnectingTcp { local_addr, - preferred: ConnectingTcpRemote::new(preferred_addrs), + preferred: ConnectingTcpRemote::new(preferred_addrs, connect_timeout), fallback: None, reuse_address, }; @@ -470,17 +488,17 @@ impl ConnectingTcp { ConnectingTcp { local_addr, - preferred: ConnectingTcpRemote::new(preferred_addrs), + preferred: ConnectingTcpRemote::new(preferred_addrs, connect_timeout), fallback: Some(ConnectingTcpFallback { delay: tokio_timer::delay_for(fallback_timeout), - remote: ConnectingTcpRemote::new(fallback_addrs), + remote: ConnectingTcpRemote::new(fallback_addrs, connect_timeout), }), reuse_address, } } else { ConnectingTcp { local_addr, - preferred: ConnectingTcpRemote::new(remote_addrs), + preferred: ConnectingTcpRemote::new(remote_addrs, connect_timeout), fallback: None, reuse_address, } @@ -495,13 +513,17 @@ struct ConnectingTcpFallback { struct ConnectingTcpRemote { addrs: dns::IpAddrs, + connect_timeout: Option, current: Option, } impl ConnectingTcpRemote { - fn new(addrs: dns::IpAddrs) -> Self { + fn new(addrs: dns::IpAddrs, connect_timeout: Option) -> Self { + let connect_timeout = connect_timeout.map(|t| t / (addrs.len() as u32)); + Self { addrs, + connect_timeout, current: None, } } @@ -530,14 +552,14 @@ impl ConnectingTcpRemote { err = Some(e); if let Some(addr) = self.addrs.next() { debug!("connecting to {}", addr); - *current = connect(&addr, local_addr, handle, reuse_address)?; + *current = connect(&addr, local_addr, handle, reuse_address, self.connect_timeout)?; continue; } } } } else if let Some(addr) = self.addrs.next() { debug!("connecting to {}", addr); - self.current = Some(connect(&addr, local_addr, handle, reuse_address)?); + self.current = Some(connect(&addr, local_addr, handle, reuse_address, self.connect_timeout)?); continue; } @@ -546,7 +568,13 @@ impl ConnectingTcpRemote { } } -fn connect(addr: &SocketAddr, local_addr: &Option, handle: &Option, reuse_address: bool) -> io::Result { +fn connect( + addr: &SocketAddr, + local_addr: &Option, + handle: &Option, + reuse_address: bool, + connect_timeout: Option, +) -> io::Result { let builder = match addr { &SocketAddr::V4(_) => TcpBuilder::new_v4()?, &SocketAddr::V6(_) => TcpBuilder::new_v6()?, @@ -581,10 +609,16 @@ fn connect(addr: &SocketAddr, local_addr: &Option, handle: &Option match Timeout::new(connect, timeout).await { + Ok(Ok(s)) => Ok(s), + Ok(Err(e)) => Err(e), + Err(e) => Err(io::Error::new(io::ErrorKind::TimedOut, e)), + } + None => connect.await, + } })) - - //Ok(Box::pin(TcpStream::connect_std(std_tcp, addr, &handle))) } impl ConnectingTcp { @@ -673,7 +707,6 @@ mod tests { }) } - #[test] fn test_errors_missing_scheme() { let mut rt = Runtime::new().unwrap(); @@ -765,7 +798,7 @@ mod tests { } let addrs = hosts.iter().map(|host| (host.clone(), addr.port()).into()).collect(); - let connecting_tcp = ConnectingTcp::new(None, dns::IpAddrs::new(addrs), Some(fallback_timeout), false); + let connecting_tcp = ConnectingTcp::new(None, dns::IpAddrs::new(addrs), None, Some(fallback_timeout), false); let fut = ConnectingTcpFuture(connecting_tcp); let start = Instant::now();