Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(client): add a Connection Pool #486

Merged
merged 1 commit into from
Apr 30, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,17 @@ use status::StatusClass::Redirection;
use {Url, HttpResult};
use HttpError::HttpUriError;

pub use self::pool::Pool;
pub use self::request::Request;
pub use self::response::Response;

pub mod pool;
pub mod request;
pub mod response;

/// A Client to use additional features with Requests.
///
/// Clients can handle things such as: redirect policy.
/// Clients can handle things such as: redirect policy, connection pooling.
pub struct Client {
connector: Connector,
redirect_policy: RedirectPolicy,
Expand All @@ -64,7 +66,12 @@ impl Client {

/// Create a new Client.
pub fn new() -> Client {
Client::with_connector(HttpConnector(None))
Client::with_pool_config(Default::default())
}

/// Create a new Client with a configured Pool Config.
pub fn with_pool_config(config: pool::Config) -> Client {
Client::with_connector(Pool::new(config))
}

/// Create a new client with a specific connector.
Expand All @@ -78,7 +85,10 @@ impl Client {

/// Set the SSL verifier callback for use with OpenSSL.
pub fn set_ssl_verifier(&mut self, verifier: ContextVerifier) {
self.connector = with_connector(HttpConnector(Some(verifier)));
self.connector = with_connector(Pool::with_connector(
Default::default(),
HttpConnector(Some(verifier))
));
}

/// Set the RedirectPolicy.
Expand Down
227 changes: 227 additions & 0 deletions src/client/pool.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
//! Client Connection Pooling
use std::borrow::ToOwned;
use std::collections::HashMap;
use std::io::{self, Read, Write};
use std::net::{SocketAddr, Shutdown};
use std::sync::{Arc, Mutex};

use net::{NetworkConnector, NetworkStream, HttpConnector};

/// The `NetworkConnector` that behaves as a connection pool used by hyper's `Client`.
pub struct Pool<C: NetworkConnector> {
connector: C,
inner: Arc<Mutex<PoolImpl<<C as NetworkConnector>::Stream>>>
}

/// Config options for the `Pool`.
#[derive(Debug)]
pub struct Config {
/// The maximum idle connections *per host*.
pub max_idle: usize,
}

impl Default for Config {
#[inline]
fn default() -> Config {
Config {
max_idle: 5,
}
}
}

#[derive(Debug)]
struct PoolImpl<S> {
conns: HashMap<Key, Vec<S>>,
config: Config,
}

type Key = (String, u16, Scheme);

fn key<T: Into<Scheme>>(host: &str, port: u16, scheme: T) -> Key {
(host.to_owned(), port, scheme.into())
}

#[derive(Clone, PartialEq, Eq, Debug, Hash)]
enum Scheme {
Http,
Https,
Other(String)
}

impl<'a> From<&'a str> for Scheme {
fn from(s: &'a str) -> Scheme {
match s {
"http" => Scheme::Http,
"https" => Scheme::Https,
s => Scheme::Other(String::from(s))
}
}
}

impl Pool<HttpConnector> {
/// Creates a `Pool` with an `HttpConnector`.
#[inline]
pub fn new(config: Config) -> Pool<HttpConnector> {
Pool::with_connector(config, HttpConnector(None))
}
}

impl<C: NetworkConnector> Pool<C> {
/// Creates a `Pool` with a specified `NetworkConnector`.
#[inline]
pub fn with_connector(config: Config, connector: C) -> Pool<C> {
Pool {
connector: connector,
inner: Arc::new(Mutex::new(PoolImpl {
conns: HashMap::new(),
config: config,
}))
}
}

/// Clear all idle connections from the Pool, closing them.
#[inline]
pub fn clear_idle(&mut self) {
self.inner.lock().unwrap().conns.clear();
}
}

impl<S> PoolImpl<S> {
fn reuse(&mut self, key: Key, conn: S) {
trace!("reuse {:?}", key);
let conns = self.conns.entry(key).or_insert(vec![]);
if conns.len() < self.config.max_idle {
conns.push(conn);
}
}
}

impl<C: NetworkConnector<Stream=S>, S: NetworkStream + Send> NetworkConnector for Pool<C> {
type Stream = PooledStream<S>;
fn connect(&mut self, host: &str, port: u16, scheme: &str) -> io::Result<PooledStream<S>> {
let key = key(host, port, scheme);
let mut locked = self.inner.lock().unwrap();
let mut should_remove = false;
let conn = match locked.conns.get_mut(&key) {
Some(ref mut vec) => {
should_remove = vec.len() == 1;
vec.pop().unwrap()
}
_ => try!(self.connector.connect(host, port, scheme))
};
if should_remove {
locked.conns.remove(&key);
}
Ok(PooledStream {
inner: Some((key, conn)),
is_closed: false,
is_drained: false,
pool: self.inner.clone()
})
}
}

/// A Stream that will try to be returned to the Pool when dropped.
pub struct PooledStream<S> {
inner: Option<(Key, S)>,
is_closed: bool,
is_drained: bool,
pool: Arc<Mutex<PoolImpl<S>>>
}

impl<S: NetworkStream> Read for PooledStream<S> {
#[inline]
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self.inner.as_mut().unwrap().1.read(buf) {
Ok(0) => {
self.is_drained = true;
Ok(0)
}
r => r
}
}
}

impl<S: NetworkStream> Write for PooledStream<S> {
#[inline]
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.inner.as_mut().unwrap().1.write(buf)
}

#[inline]
fn flush(&mut self) -> io::Result<()> {
self.inner.as_mut().unwrap().1.flush()
}
}

impl<S: NetworkStream> NetworkStream for PooledStream<S> {
#[inline]
fn peer_addr(&mut self) -> io::Result<SocketAddr> {
self.inner.as_mut().unwrap().1.peer_addr()
}

#[inline]
fn close(&mut self, how: Shutdown) -> io::Result<()> {
self.is_closed = true;
self.inner.as_mut().unwrap().1.close(how)
}
}

impl<S> Drop for PooledStream<S> {
fn drop(&mut self) {
trace!("PooledStream.drop, is_closed={}, is_drained={}", self.is_closed, self.is_drained);
if !self.is_closed && self.is_drained {
self.inner.take().map(|(key, conn)| {
if let Ok(mut pool) = self.pool.lock() {
pool.reuse(key, conn);
}
// else poisoned, give up
});
}
}
}

#[cfg(test)]
mod tests {
use std::net::Shutdown;
use mock::MockConnector;
use net::{NetworkConnector, NetworkStream};

use super::{Pool, key};

macro_rules! mocked {
() => ({
Pool::with_connector(Default::default(), MockConnector)
})
}

#[test]
fn test_connect_and_drop() {
let mut pool = mocked!();
let key = key("127.0.0.1", 3000, "http");
pool.connect("127.0.0.1", 3000, "http").unwrap().is_drained = true;
{
let locked = pool.inner.lock().unwrap();
assert_eq!(locked.conns.len(), 1);
assert_eq!(locked.conns.get(&key).unwrap().len(), 1);
}
pool.connect("127.0.0.1", 3000, "http").unwrap().is_drained = true; //reused
{
let locked = pool.inner.lock().unwrap();
assert_eq!(locked.conns.len(), 1);
assert_eq!(locked.conns.get(&key).unwrap().len(), 1);
}
}

#[test]
fn test_closed() {
let mut pool = mocked!();
let mut stream = pool.connect("127.0.0.1", 3000, "http").unwrap();
stream.close(Shutdown::Both).unwrap();
drop(stream);
let locked = pool.inner.lock().unwrap();
assert_eq!(locked.conns.len(), 0);
}


}
8 changes: 6 additions & 2 deletions src/client/request.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
//! Client Requests
use std::marker::PhantomData;
use std::io::{self, Write, BufWriter};
use std::net::Shutdown;

use url::Url;

use method::{self, Method};
use header::Headers;
use header::{self, Host};
use net::{NetworkStream, NetworkConnector, HttpConnector, Fresh, Streaming};
use http::{HttpWriter, LINE_ENDING};
use http::{self, HttpWriter, LINE_ENDING};
use http::HttpWriter::{ThroughWriter, ChunkedWriter, SizedWriter, EmptyWriter};
use version;
use HttpResult;
Expand Down Expand Up @@ -154,7 +155,10 @@ impl Request<Streaming> {
///
/// Consumes the Request.
pub fn send(self) -> HttpResult<Response> {
let raw = try!(self.body.end()).into_inner().unwrap(); // end() already flushes
let mut raw = try!(self.body.end()).into_inner().unwrap(); // end() already flushes
if !http::should_keep_alive(self.version, &self.headers) {
try!(raw.close(Shutdown::Write));
}
Response::new(raw)
}
}
Expand Down
5 changes: 5 additions & 0 deletions src/client/response.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
//! Client Responses
use std::io::{self, Read};
use std::marker::PhantomData;
use std::net::Shutdown;

use buffer::BufReader;
use header;
Expand Down Expand Up @@ -42,6 +43,10 @@ impl Response {
debug!("version={:?}, status={:?}", head.version, status);
debug!("headers={:?}", headers);

if !http::should_keep_alive(head.version, &headers) {
try!(stream.get_mut().close(Shutdown::Write));
}

let body = if headers.has::<TransferEncoding>() {
match headers.get::<TransferEncoding>() {
Some(&TransferEncoding(ref codings)) => {
Expand Down
12 changes: 11 additions & 1 deletion src/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ use std::fmt;
use httparse;

use buffer::BufReader;
use header::Headers;
use header::{Headers, Connection};
use header::ConnectionOption::{Close, KeepAlive};
use method::Method;
use status::StatusCode;
use uri::RequestUri;
Expand Down Expand Up @@ -443,6 +444,15 @@ pub const LINE_ENDING: &'static str = "\r\n";
#[derive(Clone, PartialEq, Debug)]
pub struct RawStatus(pub u16, pub Cow<'static, str>);

/// Checks if a connection should be kept alive.
pub fn should_keep_alive(version: HttpVersion, headers: &Headers) -> bool {
match (version, headers.get::<Connection>()) {
(Http10, Some(conn)) if !conn.contains(&KeepAlive) => false,
(Http11, Some(conn)) if conn.contains(&Close) => false,
_ => true
}
}

#[cfg(test)]
mod tests {
use std::io::{self, Write};
Expand Down
16 changes: 15 additions & 1 deletion src/net.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
use std::any::{Any, TypeId};
use std::fmt;
use std::io::{self, Read, Write};
use std::net::{SocketAddr, ToSocketAddrs, TcpStream, TcpListener};
use std::net::{SocketAddr, ToSocketAddrs, TcpStream, TcpListener, Shutdown};
use std::mem;
use std::path::Path;
use std::sync::Arc;
Expand Down Expand Up @@ -57,6 +57,10 @@ impl<'a, N: NetworkListener + 'a> Iterator for NetworkConnections<'a, N> {
pub trait NetworkStream: Read + Write + Any + Send + Typeable {
/// Get the remote address of the underlying connection.
fn peer_addr(&mut self) -> io::Result<SocketAddr>;
/// This will be called when Stream should no longer be kept alive.
fn close(&mut self, _how: Shutdown) -> io::Result<()> {
Ok(())
}
}

/// A connector creates a NetworkStream.
Expand Down Expand Up @@ -123,6 +127,7 @@ impl NetworkStream + Send {
}

/// If the underlying type is T, extract it.
#[inline]
pub fn downcast<T: Any>(self: Box<NetworkStream + Send>)
-> Result<Box<T>, Box<NetworkStream + Send>> {
if self.is::<T>() {
Expand Down Expand Up @@ -277,12 +282,21 @@ impl Write for HttpStream {
}

impl NetworkStream for HttpStream {
#[inline]
fn peer_addr(&mut self) -> io::Result<SocketAddr> {
match *self {
HttpStream::Http(ref mut inner) => inner.0.peer_addr(),
HttpStream::Https(ref mut inner) => inner.get_mut().0.peer_addr()
}
}

#[inline]
fn close(&mut self, how: Shutdown) -> io::Result<()> {
match *self {
HttpStream::Http(ref mut inner) => inner.0.shutdown(how),
HttpStream::Https(ref mut inner) => inner.get_mut().0.shutdown(how)
}
}
}

/// A connector that will produce HttpStreams.
Expand Down
Loading