diff --git a/Cargo.toml b/Cargo.toml index f485aca..c0f4a19 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,7 +15,6 @@ failure = "0.1.1" log = "0.4" trust-dns = "0.15" -trust-dns-proto = "0.5" serde = "1.0" serde_derive = "1.0" diff --git a/examples/get.rs b/examples/get.rs index da9e645..1e8c84e 100644 --- a/examples/get.rs +++ b/examples/get.rs @@ -1,8 +1,8 @@ +extern crate chrootable_https; extern crate env_logger; extern crate structopt; -extern crate chrootable_https; -use chrootable_https::{Resolver, Client}; +use chrootable_https::{Client, Resolver}; use std::io; use std::io::prelude::*; use std::time::Duration; @@ -10,7 +10,7 @@ use structopt::StructOpt; #[derive(Debug, StructOpt)] pub struct Args { - #[structopt(short="-t", long="--timeout")] + #[structopt(short = "-t", long = "--timeout")] timeout: Option, urls: Vec, } @@ -27,8 +27,13 @@ fn main() { } for url in &args.urls { - let reply = client.get(&url).expect("request failed"); + let reply = client + .get(&url) + .wait_for_response() + .expect("request failed"); eprintln!("{:#?}", reply); - io::stdout().write(&reply.body).expect("failed to write body"); + io::stdout() + .write(&reply.body) + .expect("failed to write body"); } } diff --git a/src/connector.rs b/src/connector.rs index 6db2235..112d6ec 100644 --- a/src/connector.rs +++ b/src/connector.rs @@ -1,57 +1,65 @@ -use hyper_rustls::HttpsConnector; -use hyper::rt::Future; -use hyper::client::connect::{self, Connect}; -use hyper::client::connect::HttpConnector; -use hyper::client::connect::Destination; +use ct_logs; +use dns::{DnsResolver, RecordType}; use futures::{future, Poll}; +use hyper::client::connect::Destination; +use hyper::client::connect::HttpConnector; +use hyper::client::connect::{self, Connect}; +use hyper::rt::Future; +use hyper_rustls::HttpsConnector; use rustls::ClientConfig; use webpki_roots; -use ct_logs; +use errors::Error; use std::io; use std::net::IpAddr; -use std::collections::HashMap; -use std::sync::{Arc, Mutex}; -use errors::Result; - +use std::sync::Arc; -pub struct Connector { +pub struct Connector { http: T, - // resolver: ResolverFuture, - records: Arc>>, + resolver: Arc, } -impl Connector { - pub fn resolve_dest(&self, mut dest: Destination) -> Result { - let ip = { - let cache = self.records.lock().unwrap(); - cache.get(dest.host()).map(|x| x.to_owned()) - }; - - let ip = match ip { - Some(IpAddr::V4(ip)) => ip.to_string(), - Some(IpAddr::V6(ip)) => format!("[{}]", ip), - None => bail!("host wasn't pre-resolved"), - }; - - dest.set_host(&ip)?; +impl Connector { + pub fn resolve_dest(&self, mut dest: Destination) -> Resolving { + let resolver = self.resolver.clone(); + let host = dest.host().to_string(); + + let resolve = future::lazy(move || { + resolver + .resolve(&host, RecordType::A) + }); + + let resolved = Box::new(resolve.and_then(move |record| { + // TODO: we might have more than one record available + match record.success()?.into_iter().next() { + Some(record) => { + let ip = match record { + IpAddr::V4(ip) => ip.to_string(), + IpAddr::V6(ip) => format!("[{}]", ip), + }; + + dest.set_host(&ip)?; + Ok(dest) + } + None => bail!("no record found"), + } + })); - Ok(dest) + Resolving(resolved) } } -impl Connector { - pub fn new(records: Arc>>) -> Connector { +impl Connector { + pub fn new(resolver: Arc) -> Connector { let mut http = HttpConnector::new(4); http.enforce_http(false); - Connector { - http, - records, - } + Connector { http, resolver } } - pub fn https(records: Arc>>) -> HttpsConnector> { - let http = Connector::new(records); + pub fn https( + resolver: Arc, + ) -> HttpsConnector> { + let http = Connector::new(resolver); let mut config = ClientConfig::new(); config @@ -63,13 +71,15 @@ impl Connector { } } -impl Connect for Connector +impl Connect for Connector where - T: Connect, + T: Connect, T: Clone, T: 'static, T::Transport: 'static, T::Future: 'static, + R: DnsResolver, + R: 'static, { type Transport = T::Transport; type Error = io::Error; @@ -77,41 +87,23 @@ where fn connect(&self, dest: connect::Destination) -> Self::Future { debug!("original destination: {:?}", dest); - let dest = match self.resolve_dest(dest) { - Ok(dest) => dest, - Err(err) => { - let err = io::Error::new(io::ErrorKind::Other, err.to_string()); - return Connecting(Box::new(future::err(err))); - }, - }; - debug!("resolved destination: {:?}", dest); - let connecting = self.http.connect(dest); - let fut = Box::new(connecting); - Connecting(fut) - - /* - // async implementation - // compiles but hangs forever - println!("creating resolve"); - let resolving = self.resolve_dest(&dest); + let resolving = self + .resolve_dest(dest) + .map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string())); let http = self.http.clone(); - println!("chaining resolve"); - let fut = Box::new(resolving.and_then(move |records| { - // unimplemented!() - println!("records: {:?}", records); + let fut = Box::new(resolving.and_then(move |dest| { + debug!("resolved destination: {:?}", dest); http.connect(dest) })); - println!("returning future"); + Connecting(fut) - */ } } -/// A Future representing work to connect to a URL -pub struct Connecting( - Box + Send>, -); +/// A Future representing work to connect to a URL. +#[must_use = "futures do nothing unless polled"] +pub struct Connecting(Box + Send>); impl Future for Connecting { type Item = (T, connect::Connected); @@ -121,3 +113,16 @@ impl Future for Connecting { self.0.poll() } } + +/// A Future representing work to resolve a DNS query. +#[must_use = "futures do nothing unless polled"] +pub struct Resolving(Box + Send>); + +impl Future for Resolving { + type Item = Destination; + type Error = Error; + + fn poll(&mut self) -> Poll { + self.0.poll() + } +} diff --git a/src/dns/mod.rs b/src/dns/mod.rs index e0d001a..f70ada3 100644 --- a/src/dns/mod.rs +++ b/src/dns/mod.rs @@ -1,32 +1,25 @@ use errors::*; -use std::time::Duration; -use std::net::{SocketAddr, IpAddr, Ipv4Addr, Ipv6Addr}; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; use std::result; use std::str::{self, FromStr}; +use std::time::Duration; -use futures::Future; use futures::Poll; +use futures::{future, Future}; use tokio::prelude::FutureExt; use tokio::runtime::Runtime; -use tokio::net::TcpStream; use trust_dns::client::ClientHandle; +use trust_dns::client::{Client, ClientConnection, SyncClient}; +use trust_dns::op::ResponseCode; use trust_dns::rr::rdata; use trust_dns::rr::record_data; pub use trust_dns::rr::record_type::RecordType; -use trust_dns::client::{Client, ClientConnection, ClientFuture, SyncClient}; -use trust_dns::udp::{UdpClientConnection, UdpClientStream}; -use trust_dns_proto::udp::UdpClientConnect; -use trust_dns::tcp::{TcpClientConnection, TcpClientStream}; -use trust_dns_proto::tcp::TcpClientConnect; -use trust_dns::op::{DnsResponse, ResponseCode}; use trust_dns::rr::{DNSClass, Name}; -use trust_dns::rr::dnssec::Signer; -use trust_dns_proto::DnsMultiplexer; -use trust_dns_proto::xfer; +use trust_dns::tcp::TcpClientConnection; +use trust_dns::udp::UdpClientConnection; pub mod system_conf; - #[derive(Debug, PartialEq, Serialize, Deserialize)] pub enum DnsError { FormErr, @@ -85,20 +78,19 @@ impl<'a> From<&'a record_data::RData> for RData { fn from(rdata: &'a record_data::RData) -> RData { use trust_dns::rr::record_data::RData::*; match rdata { - A(ip) => RData::A(ip.clone()), - AAAA(ip) => RData::AAAA(ip.clone()), + A(ip) => RData::A(ip.clone()), + AAAA(ip) => RData::AAAA(ip.clone()), CNAME(name) => RData::CNAME(name.to_string()), - MX(mx) => RData::MX((mx.preference(), mx.exchange().to_string())), - NS(ns) => RData::NS(ns.to_string()), - PTR(ptr) => RData::PTR(ptr.to_string()), - SOA(soa) => RData::SOA(soa.into()), - SRV(srv) => RData::SRV((srv.target().to_string(), srv.port())), - TXT(txt) => RData::TXT(txt.iter() - .fold(Vec::new(), |mut a, b| { - a.extend(b.iter()); - a - })), - _ => RData::Other("unknown".to_string()), + MX(mx) => RData::MX((mx.preference(), mx.exchange().to_string())), + NS(ns) => RData::NS(ns.to_string()), + PTR(ptr) => RData::PTR(ptr.to_string()), + SOA(soa) => RData::SOA(soa.into()), + SRV(srv) => RData::SRV((srv.target().to_string(), srv.port())), + TXT(txt) => RData::TXT(txt.iter().fold(Vec::new(), |mut a, b| { + a.extend(b.iter()); + a + })), + _ => RData::Other("unknown".to_string()), } } } @@ -111,7 +103,7 @@ pub struct SOA { refresh: i32, retry: i32, expire: i32, - minimum: u32 + minimum: u32, } impl<'a> From<&'a rdata::soa::SOA> for SOA { @@ -129,16 +121,18 @@ impl<'a> From<&'a rdata::soa::SOA> for SOA { } pub fn dns_name_to_string(name: &Name) -> Result { - let labels = name.iter() + let labels = name + .iter() .map(str::from_utf8) .collect::, _>>()?; Ok(labels.join(".")) } -pub trait DnsResolver { - fn resolve(&self, name: &str, query_type: RecordType) -> Result; +pub trait DnsResolver: Send + Sync { + fn resolve(&self, name: &str, query_type: RecordType) -> Resolving; } +/// An asynchronous DNS resolver. #[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)] pub struct Resolver { pub ns: Vec, @@ -148,18 +142,18 @@ pub struct Resolver { } impl Resolver { + /// Creates a new resolver using the [CloudFlare Authoritative DNS][cf] service. + /// + /// [cf]: https://www.cloudflare.com/learning/dns/what-is-1.1.1.1/ pub fn cloudflare() -> Resolver { Resolver { - ns: vec![ - "1.1.1.1:53".parse().unwrap(), - "1.0.0.1:53".parse().unwrap(), - ], + ns: vec!["1.1.1.1:53".parse().unwrap(), "1.0.0.1:53".parse().unwrap()], tcp: false, timeout: Some(Duration::from_secs(3)), } } - /// Create a new resolver from /etc/resolv.conf + /// Creates a new resolver from `/etc/resolv.conf`. pub fn from_system() -> Result { let ns = system_conf::read_system_conf()?; Ok(Resolver { @@ -169,65 +163,79 @@ impl Resolver { }) } + /// Sets a timeout within which each DNS query must complete. + /// + /// Default setting is no timeout. pub fn timeout(&mut self, timeout: Option) { self.timeout = timeout; } } impl Resolver { - fn resolve_with(&self, conn: T, name: &Name, query_type: RecordType) -> Result { + fn resolve_with(&self, conn: T, name: Name, query_type: RecordType) -> Resolving + where + T: ClientConnection, + { let client = SyncClient::new(conn); - - let mut reactor = Runtime::new()?; let (bg, mut client) = client.new_future(); - let rt = reactor - .spawn(bg); - - let fut = client.query(name.clone(), DNSClass::IN, query_type) - .map_err(Error::from); - - let response = match self.timeout { - Some(timeout) => rt.block_on(fut.timeout(timeout)) - .map_err(|x| match x.into_inner() { - Some(e) => e, - _ => format_err!("Dns query timed out"), - })?, - None => rt.block_on(fut)?, + + let query = future::lazy(move || { + tokio::executor::spawn(bg); + client + .query(name, DNSClass::IN, query_type) + .map_err(Error::from) + }); + + let response: Box + Send> = match self.timeout { + Some(ref timeout) => Box::new(query.timeout(*timeout).map_err(|e| { + e.into_inner() + .unwrap_or_else(|| format_err!("DNS query timed out")) + })), + None => Box::new(query), }; - Ok(response) + let reply = response.and_then(|response| { + let error = DnsError::from_response_code(&response.response_code()); + + let answers = response + .answers() + .iter() + .map(|x| { + let name = dns_name_to_string(x.name())?; + let rdata = x.rdata().into(); + Ok((name, rdata)) + }).collect::>>()?; + + Ok(DnsReply { answers, error }) + }); + + Resolving::new(reply) } } impl DnsResolver for Resolver { - fn resolve(&self, name: &str, query_type: RecordType) -> Result { - let name = Name::from_str(name)?; - - let address = self.ns.iter().next() - .ok_or_else(|| format_err!("No nameserver configured"))?; - - let response: DnsResponse = if self.tcp { - let conn = TcpClientConnection::new(*address)?; - self.resolve_with(conn, &name, query_type)? - } else { - let conn = UdpClientConnection::new(*address)?; - self.resolve_with(conn, &name, query_type)? + fn resolve(&self, name: &str, query_type: RecordType) -> Resolving { + let name = match Name::from_str(name) { + Ok(name) => name, + Err(e) => return Resolving::new(future::err(e.into())), }; - let error = DnsError::from_response_code(&response.response_code()); - - let answers = response.answers().iter() - .map(|x| { - let name = dns_name_to_string(x.name())?; - let rdata = x.rdata().into(); - Ok((name, rdata)) - }) - .collect::>>()?; + let address = match self.ns.first() { + Some(ref address) => *address, + None => return Resolving::new(future::err(format_err!("No nameserver configured"))), + }; - Ok(DnsReply { - answers, - error, - }) + if self.tcp { + match TcpClientConnection::new(*address) { + Ok(conn) => self.resolve_with(conn, name, query_type), + Err(e) => return Resolving::new(future::err(e.into())), + } + } else { + match UdpClientConnection::new(*address) { + Ok(conn) => self.resolve_with(conn, name, query_type), + Err(e) => return Resolving::new(future::err(e.into())), + } + } } } @@ -243,99 +251,39 @@ impl DnsReply { bail!("dns server returned error: {:?}", error) } - let ips = self.answers.iter() + let ips = self + .answers + .iter() .flat_map(|x| match x.1 { RData::A(ip) => Some(IpAddr::V4(ip.clone())), RData::AAAA(ip) => Some(IpAddr::V6(ip.clone())), _ => None, - }) - .collect(); + }).collect(); + Ok(ips) } } -#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)] -pub struct AsyncResolver { - pub ns: Vec, - #[serde(default)] - pub tcp: bool, -} - -impl AsyncResolver { - pub fn cloudflare() -> AsyncResolver { - AsyncResolver { - ns: vec![ - "1.1.1.1:53".parse().unwrap(), - "1.0.0.1:53".parse().unwrap(), - ], - tcp: false, - } - } - - /// Create a new resolver from /etc/resolv.conf - pub fn from_system() -> Result { - let ns = system_conf::read_system_conf()?; - Ok(AsyncResolver { - ns, - tcp: false, - }) - } - - fn resolve_with(conn: T, name: &Name, query_type: RecordType) -> Result<(ClientFuture, Resolving)> { - let client = SyncClient::new(conn); - - let (bg, mut client) = client.new_future(); - - let fut = client.query(name.clone(), DNSClass::IN, query_type) - .map_err(Error::from) - .and_then(|response| { - let error = DnsError::from_response_code(&response.response_code()); - - let answers = response.answers().iter() - .map(|x| { - let name = dns_name_to_string(x.name())?; - let rdata = x.rdata().into(); - Ok((name, rdata)) - }) - .collect::>>()?; - - Ok(DnsReply { - answers, - error, - }) - }); - - Ok((bg, Resolving(Box::new(fut)))) +/// A `Future` that represents a resolving DNS query. +#[must_use = "futures do nothing unless polled"] +pub struct Resolving(Box + Send>); + +impl Resolving { + /// Creates a new `Resolving` future. + pub(crate) fn new(inner: F) -> Self + where + F: Future + Send + 'static, + { + Resolving(Box::new(inner)) } - pub fn resolve(&self, name: &str, query_type: RecordType) -> Result<(AsyncResolverFuture, Resolving)> { - let name = Name::from_str(name)?; - - let address = self.ns.iter().next() - .ok_or_else(|| format_err!("No nameserver configured"))?; - - if self.tcp { - let conn = TcpClientConnection::new(*address)?; - let (bg, fut) = Self::resolve_with(conn, &name, query_type)?; - Ok((AsyncResolverFuture::Tcp(bg), fut)) - } else { - let conn = UdpClientConnection::new(*address)?; - let (bg, fut) = Self::resolve_with(conn, &name, query_type)?; - Ok((AsyncResolverFuture::Udp(bg), fut)) - } + /// Drives this future to completion, eventually returning a DNS reply. + pub fn wait_for_response(self) -> Result { + let mut rt = Runtime::new()?; + rt.block_on(self) } } -pub enum AsyncResolverFuture { - Udp(ClientFuture, DnsMultiplexer, xfer::DnsMultiplexerSerialResponse>), - Tcp(ClientFuture, Signer>, DnsMultiplexer, Signer>, xfer::DnsMultiplexerSerialResponse>), -} - -/// A Future representing work to connect to a URL -pub struct Resolving( - Box + Send>, -); - impl Future for Resolving { type Item = DnsReply; type Error = Error; @@ -345,21 +293,24 @@ impl Future for Resolving { } } - #[cfg(test)] mod tests { extern crate serde_json; use super::*; + use tokio::runtime::current_thread::Runtime; #[test] fn verify_dns_config() { + let mut runtime = Runtime::new().unwrap(); + let config = Resolver::from_system().expect("DnsConfig::from_system"); let json = serde_json::to_string(&config).expect("to json"); println!("{:?}", json); let resolver = serde_json::from_str::(&json).expect("to json"); - resolver.resolve("example.com", RecordType::A).expect("resolve failed"); + let fut = resolver.resolve("example.com", RecordType::A); + runtime.block_on(fut).expect("resolve failed"); } #[test] @@ -370,42 +321,56 @@ mod tests { #[test] fn verify_dns_query() { + let mut runtime = Runtime::new().unwrap(); let resolver = Resolver::from_system().expect("DnsConfig::from_system"); - let x = resolver.resolve("example.com", RecordType::A).expect("resolve failed"); + let fut = resolver.resolve("example.com", RecordType::A); + let x = runtime.block_on(fut).expect("resolve failed"); println!("{:?}", x); assert!(x.error.is_none()); } #[test] fn verify_dns_query_timeout() { + let mut runtime = Runtime::new().unwrap(); let resolver = Resolver { ns: vec!["1.2.3.4:53".parse().unwrap()], tcp: false, timeout: Some(Duration::from_millis(100)), }; - let x = resolver.resolve("example.com", RecordType::A); + let fut = resolver.resolve("example.com", RecordType::A); + let x = runtime.block_on(fut); assert!(x.is_err()); } #[test] fn verify_dns_query_nx() { + let mut runtime = Runtime::new().unwrap(); let resolver = Resolver::from_system().expect("DnsConfig::from_system"); - let x = resolver.resolve("nonexistant.example.com", RecordType::A).expect("resolve failed"); + let fut = resolver.resolve("nonexistant.example.com", RecordType::A); + let x = runtime.block_on(fut).expect("resolve failed"); println!("{:?}", x); - assert_eq!(x, DnsReply { - answers: Vec::new(), - error: Some(DnsError::NXDomain), - }); + assert_eq!( + x, + DnsReply { + answers: Vec::new(), + error: Some(DnsError::NXDomain), + } + ); } #[test] fn verify_dns_query_empty_cname() { + let mut runtime = Runtime::new().unwrap(); let resolver = Resolver::from_system().expect("DnsConfig::from_system"); - let x = resolver.resolve("example.com", RecordType::CNAME).expect("resolve failed"); + let fut = resolver.resolve("example.com", RecordType::CNAME); + let x = runtime.block_on(fut).expect("resolve failed"); println!("{:?}", x); - assert_eq!(x, DnsReply { - answers: Vec::new(), - error: None, - }); + assert_eq!( + x, + DnsReply { + answers: Vec::new(), + error: None, + } + ); } } diff --git a/src/dns/system_conf/unix.rs b/src/dns/system_conf/unix.rs index 5f7a6d3..2e499da 100644 --- a/src/dns/system_conf/unix.rs +++ b/src/dns/system_conf/unix.rs @@ -1,19 +1,19 @@ use errors::*; use resolv_conf; use std::fs; -use std::net::{SocketAddr, IpAddr}; - +use std::net::{IpAddr, SocketAddr}; pub fn read_system_conf() -> Result> { let r = fs::read("/etc/resolv.conf")?; let conf = resolv_conf::Config::parse(&r)?; - let ns = conf.nameservers.into_iter() + let ns = conf + .nameservers + .into_iter() .map(|x| match x { resolv_conf::ScopedIp::V4(x) => IpAddr::V4(x), resolv_conf::ScopedIp::V6(x, _) => IpAddr::V6(x), - }) - .map(|x| SocketAddr::new(x, 53)) + }).map(|x| SocketAddr::new(x, 53)) .collect(); Ok(ns) } diff --git a/src/dns/system_conf/windows.rs b/src/dns/system_conf/windows.rs index 36445fd..05e178f 100644 --- a/src/dns/system_conf/windows.rs +++ b/src/dns/system_conf/windows.rs @@ -2,7 +2,6 @@ use errors::*; use ipconfig::get_adapters; use std::net::SocketAddr; - pub fn read_system_conf() -> Result> { let ns = get_adapters()? .iter() diff --git a/src/lib.rs b/src/lib.rs index 5049236..70455c6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,54 +12,55 @@ //! let resolver = Resolver::cloudflare(); //! let client = Client::new(resolver); //! -//! let reply = client.get("https://httpbin.org/anything").expect("request failed"); +//! let reply = client.get("https://httpbin.org/anything").wait_for_response().expect("request failed"); //! println!("{:#?}", reply); //! ``` #![warn(unused_extern_crates)] -pub extern crate hyper; +extern crate bytes; +extern crate ct_logs; +extern crate futures; pub extern crate http; -extern crate tokio; -extern crate rustls; +pub extern crate hyper; extern crate hyper_rustls; -extern crate webpki_roots; -extern crate ct_logs; +extern crate rustls; +extern crate tokio; extern crate trust_dns; -extern crate trust_dns_proto; -extern crate futures; -extern crate bytes; -#[macro_use] extern crate serde_derive; -#[macro_use] extern crate failure; -#[macro_use] extern crate log; +extern crate webpki_roots; +#[macro_use] +extern crate serde_derive; +#[macro_use] +extern crate failure; +#[macro_use] +extern crate log; -#[cfg(unix)] -extern crate resolv_conf; #[cfg(windows)] extern crate ipconfig; +#[cfg(unix)] +extern crate resolv_conf; -pub use hyper::Body; -use http::response::Parts; +use bytes::Bytes; pub use http::header; -use hyper_rustls::HttpsConnector; -use hyper::rt::Future; -use hyper::client::connect::HttpConnector; +use http::response::Parts; pub use http::Request; -use bytes::Bytes; +use hyper::client::connect::HttpConnector; +use hyper::rt::Future; +pub use hyper::Body; +use hyper_rustls::HttpsConnector; -use tokio::runtime::Runtime; +use futures::{future, Poll, Stream}; use tokio::prelude::FutureExt; -use futures::{future, Stream}; +use tokio::runtime::Runtime; -use std::net::IpAddr; +pub use http::Uri; use std::collections::HashMap; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use std::time::Duration; -pub use http::Uri; mod connector; pub mod dns; use self::connector::Connector; -pub use dns::{Resolver, DnsResolver, RecordType}; +pub use dns::{DnsResolver, RecordType, Resolver}; pub mod errors { pub use failure::{Error, ResultExt}; @@ -67,113 +68,128 @@ pub mod errors { } pub use errors::*; - +/// A Client to make outgoing HTTP requests. +/// +/// Uses an specific DNS resolver. #[derive(Debug)] pub struct Client { - client: Arc>>>, - resolver: R, - records: Arc>>, + client: Arc>>>, timeout: Option, } -impl Client { - /// Create a new client with a specific dns resolver. +impl Client { + /// Create a new client with a specific DNS resolver. /// - /// This bypasses /etc/resolv.conf + /// This bypasses `/etc/resolv.conf`. pub fn new(resolver: R) -> Client { - let records = Arc::new(Mutex::new(HashMap::new())); - let https = Connector::https(records.clone()); + let https = Connector::https(Arc::new(resolver)); let client = hyper::Client::builder() .keep_alive(false) .build::<_, hyper::Body>(https); Client { client: Arc::new(client), - resolver, - records, timeout: None, } } - /// Set a timeout, default is no timeout + /// Set a timeout (default setting is no timeout). pub fn timeout(&mut self, timeout: Duration) { self.timeout = Some(timeout); } - /// Pre-populate the dns-cache. This function is usually called internally - pub fn pre_resolve(&self, uri: &Uri) -> Result<()> { - let host = match uri.host() { - Some(host) => host, - None => bail!("url has no host"), - }; - - let record = self.resolver.resolve(&host, RecordType::A)?; - match record.success()?.into_iter().next() { - Some(record) => { - // TODO: make sure we only add the records we want - let mut cache = self.records.lock().unwrap(); - cache.insert(host.to_string(), record); - }, - None => bail!("no record found"), - } - Ok(()) - } - - /// Shorthand function to do a GET request with [`HttpClient::request`] + /// Shorthand function to do a GET request with [`HttpClient::request`]. /// /// [`HttpClient::request`]: trait.HttpClient.html#tymethod.request - pub fn get(&self, url: &str) -> Result { - let url = url.parse::()?; + pub fn get(&self, url: &str) -> ResponseFuture { + let url = match url.parse::() { + Ok(url) => url, + Err(e) => return ResponseFuture::new(future::err(e.into())), + }; let mut request = Request::builder(); - let request = request.uri(url) - .body(Body::empty())?; + let request = match request.uri(url).body(Body::empty()) { + Ok(request) => request, + Err(e) => return ResponseFuture::new(future::err(e.into())), + }; self.request(request) } } impl Client { - /// Create a new client with the system resolver from /etc/resolv.conf + /// Create a new client with the system resolver from `/etc/resolv.conf`. pub fn with_system_resolver() -> Result> { let resolver = Resolver::from_system()?; Ok(Client::new(resolver)) } } +/// Generic abstraction over HTTP clients. pub trait HttpClient { - fn request(&self, request: Request) -> Result; + fn request(&self, request: Request) -> ResponseFuture; } -impl HttpClient for Client { - fn request(&self, request: Request) -> Result { - info!("sending request to {:?}", request.uri()); - self.pre_resolve(request.uri())?; - +impl HttpClient for Client { + fn request(&self, request: Request) -> ResponseFuture { let client = self.client.clone(); let timeout = self.timeout.clone(); - let mut rt = Runtime::new()?; - let fut = client.request(request) + info!("sending request to {:?}", request.uri()); + let fut = client.request(request).map_err(Error::from) .and_then(|res| { debug!("http response: {:?}", res); let (parts, body) = res.into_parts(); - let body = body.concat2(); + let body = body.concat2().map_err(Error::from); (future::ok(parts), body) - }); + }).map_err(|e| e.compat()); - let (parts, body) = match timeout { - Some(timeout) => rt.block_on(fut.timeout(timeout))?, - None => rt.block_on(fut)?, + let fut: Box + Send> = match timeout { + Some(timeout) => Box::new(fut.timeout(timeout).map_err(Error::from)), + None => Box::new(fut.map_err(Error::from)), }; - let body = body.into_bytes(); - let reply = Response::from((parts, body)); - info!("got reply {:?}", reply); - Ok(reply) + let reply = fut.and_then(|(parts, body)| { + let body = body.into_bytes(); + let reply = Response::from((parts, body)); + info!("got reply {:?}", reply); + Ok(reply) + }); + + ResponseFuture::new(reply) + } +} + +/// A `Future` that will resolve to an HTTP Response. +#[must_use = "futures do nothing unless polled"] +pub struct ResponseFuture(Box + Send>); + +impl ResponseFuture { + /// Creates a new `ResponseFuture`. + pub(crate) fn new(inner: F) -> Self + where + F: Future + Send + 'static, + { + ResponseFuture(Box::new(inner)) + } + + /// Drives this future to completion, eventually returning an HTTP response. + pub fn wait_for_response(self) -> Result { + let mut rt = Runtime::new()?; + rt.block_on(self) } } +impl Future for ResponseFuture { + type Item = Response; + type Error = Error; + + fn poll(&mut self) -> Poll { + self.0.poll() + } +} + +/// Represents an HTTP response. #[derive(Debug)] pub struct Response { pub status: u16, @@ -187,9 +203,12 @@ impl From<(Parts, Bytes)> for Response { let parts = x.0; let body = x.1; - let cookies = parts.headers.get_all("set-cookie").into_iter() - .flat_map(|x| x.to_str().map(|x| x.to_owned()).ok()) - .collect(); + let cookies = parts + .headers + .get_all("set-cookie") + .into_iter() + .flat_map(|x| x.to_str().map(|x| x.to_owned()).ok()) + .collect(); let mut headers = HashMap::new(); @@ -213,19 +232,21 @@ impl From<(Parts, Bytes)> for Response { } } - #[cfg(test)] mod tests { use super::*; use dns::Resolver; - use std::time::{Instant, Duration}; + use std::time::{Duration, Instant}; #[test] fn verify_200_http() { let resolver = Resolver::cloudflare(); let client = Client::new(resolver); - let reply = client.get("http://httpbin.org/anything").expect("request failed"); + let reply = client + .get("http://httpbin.org/anything") + .wait_for_response() + .expect("request failed"); assert_eq!(reply.status, 200); } @@ -234,14 +255,20 @@ mod tests { let resolver = Resolver::cloudflare(); let client = Client::new(resolver); - let reply = client.get("https://httpbin.org/anything").expect("request failed"); + let reply = client + .get("https://httpbin.org/anything") + .wait_for_response() + .expect("request failed"); assert_eq!(reply.status, 200); } #[test] fn verify_200_https_system_resolver() { let client = Client::with_system_resolver().expect("failed to create client"); - let reply = client.get("https://httpbin.org/anything").expect("request failed"); + let reply = client + .get("https://httpbin.org/anything") + .wait_for_response() + .expect("request failed"); assert_eq!(reply.status, 200); } @@ -250,7 +277,10 @@ mod tests { let resolver = Resolver::cloudflare(); let client = Client::new(resolver); - let reply = client.get("https://httpbin.org/redirect-to?url=/anything&status=302").expect("request failed"); + let reply = client + .get("https://httpbin.org/redirect-to?url=/anything&status=302") + .wait_for_response() + .expect("request failed"); assert_eq!(reply.status, 302); } @@ -262,7 +292,7 @@ mod tests { client.timeout(Duration::from_millis(250)); let start = Instant::now(); - let _reply = client.get("http://1.2.3.4").err(); + let _reply = client.get("http://1.2.3.4").wait_for_response().err(); let end = Instant::now(); assert!(end.duration_since(start) < Duration::from_secs(1));