diff --git a/benches/end_to_end.rs b/benches/end_to_end.rs index 83b090fa44..3b68543790 100644 --- a/benches/end_to_end.rs +++ b/benches/end_to_end.rs @@ -8,14 +8,13 @@ extern crate tokio_core; use std::net::SocketAddr; -use futures::{future, Future, Stream}; +use futures::{Future, Stream}; use tokio_core::reactor::{Core, Handle}; use tokio_core::net::TcpListener; use hyper::client; use hyper::header::{ContentLength, ContentType}; use hyper::Method; -use hyper::server::{self, Service}; #[bench] @@ -42,13 +41,15 @@ fn get_one_at_a_time(b: &mut test::Bencher) { #[bench] fn post_one_at_a_time(b: &mut test::Bencher) { + extern crate pretty_env_logger; + let _ = pretty_env_logger::try_init(); let mut core = Core::new().unwrap(); let handle = core.handle(); let addr = spawn_hello(&handle); let client = hyper::Client::new(&handle); - let url: hyper::Uri = format!("http://{}/get", addr).parse().unwrap(); + let url: hyper::Uri = format!("http://{}/post", addr).parse().unwrap(); let post = "foo bar baz quux"; b.bytes = 180 * 2 + post.len() as u64 + PHRASE.len() as u64; @@ -69,35 +70,32 @@ fn post_one_at_a_time(b: &mut test::Bencher) { static PHRASE: &'static [u8] = include_bytes!("../CHANGELOG.md"); //b"Hello, World!"; -#[derive(Clone, Copy)] -struct Hello; - -impl Service for Hello { - type Request = server::Request; - type Response = server::Response; - type Error = hyper::Error; - type Future = future::FutureResult; - fn call(&self, _req: Self::Request) -> Self::Future { - future::ok( - server::Response::new() - .with_header(ContentLength(PHRASE.len() as u64)) - .with_header(ContentType::plaintext()) - .with_body(PHRASE) - ) - } - -} - fn spawn_hello(handle: &Handle) -> SocketAddr { + use hyper::server::{const_service, service_fn, NewService, Request, Response}; let addr = "127.0.0.1:0".parse().unwrap(); let listener = TcpListener::bind(&addr, handle).unwrap(); let addr = listener.local_addr().unwrap(); let handle2 = handle.clone(); let http = hyper::server::Http::::new(); + + let service = const_service(service_fn(|req: Request| { + req.body() + .concat2() + .map(|_| { + Response::::new() + .with_header(ContentLength(PHRASE.len() as u64)) + .with_header(ContentType::plaintext()) + .with_body(PHRASE) + }) + })); + + let mut conns = 0; handle.spawn(listener.incoming().for_each(move |(socket, _addr)| { + conns += 1; + assert_eq!(conns, 1, "should only need 1 connection"); handle2.spawn( - http.serve_connection(socket, Hello) + http.serve_connection(socket, service.new_service()?) .map(|_| ()) .map_err(|_| ()) ); diff --git a/src/client/cancel.rs b/src/client/cancel.rs deleted file mode 100644 index bc29dbf60c..0000000000 --- a/src/client/cancel.rs +++ /dev/null @@ -1,154 +0,0 @@ -use std::sync::Arc; -use std::sync::atomic::{AtomicBool, Ordering}; - -use futures::{Async, Future, Poll}; -use futures::task::{self, Task}; - -use common::Never; - -use self::lock::Lock; - -#[derive(Clone)] -pub struct Cancel { - inner: Arc, -} - -pub struct Canceled { - inner: Arc, -} - -struct Inner { - is_canceled: AtomicBool, - task: Lock>, -} - -impl Cancel { - pub fn new() -> (Cancel, Canceled) { - let inner = Arc::new(Inner { - is_canceled: AtomicBool::new(false), - task: Lock::new(None), - }); - let inner2 = inner.clone(); - ( - Cancel { - inner: inner, - }, - Canceled { - inner: inner2, - }, - ) - } - - pub fn cancel(&self) { - if !self.inner.is_canceled.swap(true, Ordering::SeqCst) { - if let Some(mut locked) = self.inner.task.try_lock() { - if let Some(task) = locked.take() { - task.notify(); - } - } - // if we couldn't take the lock, Canceled was trying to park. - // After parking, it will check is_canceled one last time, - // so we can just stop here. - } - } - - pub fn is_canceled(&self) -> bool { - self.inner.is_canceled.load(Ordering::SeqCst) - } -} - -impl Canceled { - pub fn cancel(&self) { - self.inner.is_canceled.store(true, Ordering::SeqCst); - } -} - -impl Future for Canceled { - type Item = (); - type Error = Never; - - fn poll(&mut self) -> Poll { - if self.inner.is_canceled.load(Ordering::SeqCst) { - Ok(Async::Ready(())) - } else { - if let Some(mut locked) = self.inner.task.try_lock() { - if locked.is_none() { - // it's possible a Cancel just tried to cancel on another thread, - // and we just missed it. Once we have the lock, we should check - // one more time before parking this task and going away. - if self.inner.is_canceled.load(Ordering::SeqCst) { - return Ok(Async::Ready(())); - } - *locked = Some(task::current()); - } - Ok(Async::NotReady) - } else { - // if we couldn't take the lock, then a Cancel taken has it. - // The *ONLY* reason is because it is in the process of canceling. - Ok(Async::Ready(())) - } - } - } -} - -impl Drop for Canceled { - fn drop(&mut self) { - self.cancel(); - } -} - - -// a sub module just to protect unsafety -mod lock { - use std::cell::UnsafeCell; - use std::ops::{Deref, DerefMut}; - use std::sync::atomic::{AtomicBool, Ordering}; - - pub struct Lock { - is_locked: AtomicBool, - value: UnsafeCell, - } - - impl Lock { - pub fn new(val: T) -> Lock { - Lock { - is_locked: AtomicBool::new(false), - value: UnsafeCell::new(val), - } - } - - pub fn try_lock(&self) -> Option> { - if !self.is_locked.swap(true, Ordering::SeqCst) { - Some(Locked { lock: self }) - } else { - None - } - } - } - - unsafe impl Send for Lock {} - unsafe impl Sync for Lock {} - - pub struct Locked<'a, T: 'a> { - lock: &'a Lock, - } - - impl<'a, T> Deref for Locked<'a, T> { - type Target = T; - fn deref(&self) -> &T { - unsafe { &*self.lock.value.get() } - } - } - - impl<'a, T> DerefMut for Locked<'a, T> { - fn deref_mut(&mut self) -> &mut T { - unsafe { &mut *self.lock.value.get() } - } - } - - impl<'a, T> Drop for Locked<'a, T> { - fn drop(&mut self) { - self.lock.is_locked.store(false, Ordering::SeqCst); - } - } -} diff --git a/src/client/conn.rs b/src/client/conn.rs new file mode 100644 index 0000000000..5c8dc8b56e --- /dev/null +++ b/src/client/conn.rs @@ -0,0 +1,436 @@ +//! Lower-level client connection API. +//! +//! The types in thie module are to provide a lower-level API based around a +//! single connection. Connecting to a host, pooling connections, and the like +//! are not handled at this level. This module provides the building blocks to +//! customize those things externally. +//! +//! If don't have need to manage connections yourself, consider using the +//! higher-level [Client](super) API. +use std::fmt; +use std::marker::PhantomData; + +use bytes::Bytes; +use futures::{Async, Future, Poll, Stream}; +use futures::future::{self, Either}; +use tokio_io::{AsyncRead, AsyncWrite}; + +use proto; +use super::{dispatch, Request, Response}; + +/// Returns a `Handshake` future over some IO. +/// +/// This is a shortcut for `Builder::new().handshake(io)`. +pub fn handshake(io: T) -> Handshake +where + T: AsyncRead + AsyncWrite, +{ + Builder::new() + .handshake(io) +} + +/// The sender side of an established connection. +pub struct SendRequest { + dispatch: dispatch::Sender, ::Response>, + +} + +/// A future that processes all HTTP state for the IO object. +/// +/// In most cases, this should just be spawned into an executor, so that it +/// can process incoming and outgoing messages, notice hangups, and the like. +#[must_use = "futures do nothing unless polled"] +pub struct Connection +where + T: AsyncRead + AsyncWrite, + B: Stream + 'static, + B::Item: AsRef<[u8]>, +{ + inner: proto::dispatch::Dispatcher< + proto::dispatch::Client, + B, + T, + B::Item, + proto::ClientUpgradeTransaction, + >, +} + + +/// A builder to configure an HTTP connection. +/// +/// After setting options, the builder is used to create a `Handshake` future. +#[derive(Clone, Debug)] +pub struct Builder { + h1_writev: bool, +} + +/// A future setting up HTTP over an IO object. +/// +/// If successful, yields a `(SendRequest, Connection)` pair. +#[must_use = "futures do nothing unless polled"] +pub struct Handshake { + inner: HandshakeInner, +} + +/// A future returned by `SendRequest::send_request`. +/// +/// Yields a `Response` if successful. +#[must_use = "futures do nothing unless polled"] +pub struct ResponseFuture { + // for now, a Box is used to hide away the internal `B` + // that can be returned if canceled + inner: Box>, +} + +/// Deconstructed parts of a `Connection`. +/// +/// This allows taking apart a `Connection` at a later time, in order to +/// reclaim the IO object, and additional related pieces. +#[derive(Debug)] +pub struct Parts { + /// The original IO object used in the handshake. + pub io: T, + /// A buffer of bytes that have been read but not processed as HTTP. + /// + /// For instance, if the `Connection` is used for an HTTP upgrade request, + /// it is possible the server sent back the first bytes of the new protocol + /// along with the response upgrade. + /// + /// You will want to check for any existing bytes if you plan to continue + /// communicating on the IO object. + pub read_buf: Bytes, + _inner: (), +} + +// internal client api + +#[must_use = "futures do nothing unless polled"] +pub(super) struct HandshakeNoUpgrades { + inner: HandshakeInner, +} + +struct HandshakeInner { + builder: Builder, + io: Option, + _marker: PhantomData<(B, R)>, +} + +// ===== impl SendRequest + +impl SendRequest +{ + /// Polls to determine whether this sender can be used yet for a request. + /// + /// If the associated connection is closed, this returns an Error. + pub fn poll_ready(&mut self) -> Poll<(), ::Error> { + self.dispatch.poll_ready() + } + + pub(super) fn is_closed(&self) -> bool { + self.dispatch.is_closed() + } +} + +impl SendRequest +where + B: Stream + 'static, + B::Item: AsRef<[u8]>, +{ + /// Sends a `Request` on the associated connection. + /// + /// Returns a future that if successful, yields the `Response`. + pub fn send_request(&mut self, req: Request) -> ResponseFuture { + let inner = self.send_request_retryable(req).map_err(|e| { + let (err, _orig_req) = e; + err + }); + ResponseFuture { + inner: Box::new(inner), + } + } + + //TODO: replace with `impl Future` when stable + pub(crate) fn send_request_retryable(&mut self, req: Request) -> Box)>)>> { + let (head, body) = proto::request::split(req); + let inner = match self.dispatch.send((head, body)) { + Ok(rx) => { + Either::A(rx.then(move |res| { + match res { + Ok(Ok(res)) => Ok(res), + Ok(Err(err)) => Err(err), + // this is definite bug if it happens, but it shouldn't happen! + Err(_) => panic!("dispatch dropped without returning error"), + } + })) + }, + Err(req) => { + debug!("connection was not ready"); + let err = ::Error::new_canceled(Some("connection was not ready")); + Either::B(future::err((err, Some(req)))) + } + }; + Box::new(inner) + } +} + +/* TODO(0.12.0): when we change from tokio-service to tower. +impl Service for SendRequest { + type Request = Request; + type Response = Response; + type Error = ::Error; + type Future = ResponseFuture; + + fn call(&self, req: Self::Request) -> Self::Future { + + } +} +*/ + +impl fmt::Debug for SendRequest { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("SendRequest") + .finish() + } +} + +// ===== impl Connection + +impl Connection +where + T: AsyncRead + AsyncWrite, + B: Stream + 'static, + B::Item: AsRef<[u8]>, +{ + /// Return the inner IO object, and additional information. + pub fn into_parts(self) -> Parts { + let (io, read_buf) = self.inner.into_inner(); + Parts { + io: io, + read_buf: read_buf, + _inner: (), + } + } + + /// Poll the connection for completion, but without calling `shutdown` + /// on the underlying IO. + /// + /// This is useful to allow running a connection while doing an HTTP + /// upgrade. Once the upgrade is completed, the connection would be "done", + /// but it is not desired to actally shutdown the IO object. Instead you + /// would take it back using `into_parts`. + pub fn poll_without_shutdown(&mut self) -> Poll<(), ::Error> { + self.inner.poll_without_shutdown() + } +} + +impl Future for Connection +where + T: AsyncRead + AsyncWrite, + B: Stream + 'static, + B::Item: AsRef<[u8]>, +{ + type Item = (); + type Error = ::Error; + + fn poll(&mut self) -> Poll { + self.inner.poll() + } +} + +impl fmt::Debug for Connection +where + T: AsyncRead + AsyncWrite + fmt::Debug, + B: Stream + 'static, + B::Item: AsRef<[u8]>, +{ + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("Connection") + .finish() + } +} + +// ===== impl Builder + +impl Builder { + /// Creates a new connection builder. + #[inline] + pub fn new() -> Builder { + Builder { + h1_writev: true, + } + } + + pub(super) fn h1_writev(&mut self, enabled: bool) -> &mut Builder { + self.h1_writev = enabled; + self + } + + /// Constructs a connection with the configured options and IO. + #[inline] + pub fn handshake(&self, io: T) -> Handshake + where + T: AsyncRead + AsyncWrite, + B: Stream + 'static, + B::Item: AsRef<[u8]>, + { + Handshake { + inner: HandshakeInner { + builder: self.clone(), + io: Some(io), + _marker: PhantomData, + } + } + } + + pub(super) fn handshake_no_upgrades(&self, io: T) -> HandshakeNoUpgrades + where + T: AsyncRead + AsyncWrite, + B: Stream + 'static, + B::Item: AsRef<[u8]>, + { + HandshakeNoUpgrades { + inner: HandshakeInner { + builder: self.clone(), + io: Some(io), + _marker: PhantomData, + } + } + } +} + +// ===== impl Handshake + +impl Future for Handshake +where + T: AsyncRead + AsyncWrite, + B: Stream + 'static, + B::Item: AsRef<[u8]>, +{ + type Item = (SendRequest, Connection); + type Error = ::Error; + + fn poll(&mut self) -> Poll { + self.inner.poll() + .map(|async| { + async.map(|(tx, dispatch)| { + (tx, Connection { inner: dispatch }) + }) + }) + } +} + +impl fmt::Debug for Handshake { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("Handshake") + .finish() + } +} + +impl Future for HandshakeNoUpgrades +where + T: AsyncRead + AsyncWrite, + B: Stream + 'static, + B::Item: AsRef<[u8]>, +{ + type Item = (SendRequest, proto::dispatch::Dispatcher< + proto::dispatch::Client, + B, + T, + B::Item, + proto::ClientTransaction, + >); + type Error = ::Error; + + fn poll(&mut self) -> Poll { + self.inner.poll() + } +} + +impl Future for HandshakeInner +where + T: AsyncRead + AsyncWrite, + B: Stream + 'static, + B::Item: AsRef<[u8]>, + R: proto::Http1Transaction< + Incoming=proto::RawStatus, + Outgoing=proto::RequestLine, + >, +{ + type Item = (SendRequest, proto::dispatch::Dispatcher< + proto::dispatch::Client, + B, + T, + B::Item, + R, + >); + type Error = ::Error; + + fn poll(&mut self) -> Poll { + let io = self.io.take().expect("polled more than once"); + let (tx, rx) = dispatch::channel(); + let mut conn = proto::Conn::new(io); + if !self.builder.h1_writev { + conn.set_write_strategy_flatten(); + } + let dispatch = proto::dispatch::Dispatcher::new(proto::dispatch::Client::new(rx), conn); + Ok(Async::Ready(( + SendRequest { + dispatch: tx, + }, + dispatch, + ))) + } +} + +// ===== impl ResponseFuture + +impl Future for ResponseFuture { + type Item = Response; + type Error = ::Error; + + #[inline] + fn poll(&mut self) -> Poll { + self.inner.poll() + } +} + +impl fmt::Debug for ResponseFuture { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("ResponseFuture") + .finish() + } +} + +// assert trait markers + +trait AssertSend: Send {} +trait AssertSendSync: Send + Sync {} + + +#[doc(hidden)] +impl AssertSendSync for SendRequest {} + +#[doc(hidden)] +impl AssertSend for Connection +where + T: AsyncRead + AsyncWrite, + B: Stream, + B::Item: AsRef<[u8]> + Send, +{} + +#[doc(hidden)] +impl AssertSendSync for Connection +where + T: AsyncRead + AsyncWrite, + B: Stream, + B::Item: AsRef<[u8]> + Send + Sync, +{} + +#[doc(hidden)] +impl AssertSendSync for Builder {} + +// TODO: This could be done by using a dispatch channel that doesn't +// return the `B` on Error, removing the possibility of contains some !Send +// thing. +//#[doc(hidden)] +//impl AssertSend for ResponseFuture {} diff --git a/src/client/dispatch.rs b/src/client/dispatch.rs index 5c6267a092..80ed8e1c63 100644 --- a/src/client/dispatch.rs +++ b/src/client/dispatch.rs @@ -1,60 +1,64 @@ -use futures::{Async, Future, Poll, Stream}; +use futures::{Async, Poll, Stream}; use futures::sync::{mpsc, oneshot}; use common::Never; -use super::cancel::{Cancel, Canceled}; +use super::signal; pub type Callback = oneshot::Sender)>>; pub type Promise = oneshot::Receiver)>>; pub fn channel() -> (Sender, Receiver) { - let (tx, rx) = mpsc::unbounded(); - let (cancel, canceled) = Cancel::new(); + let (tx, rx) = mpsc::channel(0); + let (giver, taker) = signal::new(); let tx = Sender { - cancel: cancel, + giver: giver, inner: tx, }; let rx = Receiver { - canceled: canceled, inner: rx, + taker: taker, }; (tx, rx) } pub struct Sender { - cancel: Cancel, - inner: mpsc::UnboundedSender<(T, Callback)>, + // The Giver helps watch that the the Receiver side has been polled + // when the queue is empty. This helps us know when a request and + // response have been fully processed, and a connection is ready + // for more. + giver: signal::Giver, + inner: mpsc::Sender<(T, Callback)>, } impl Sender { - pub fn is_closed(&self) -> bool { - self.cancel.is_canceled() + pub fn poll_ready(&mut self) -> Poll<(), ::Error> { + match self.inner.poll_ready() { + Ok(Async::Ready(())) => { + // there's room in the queue, but does the Connection + // want a message yet? + self.giver.poll_want() + .map_err(|_| ::Error::Closed) + }, + Ok(Async::NotReady) => Ok(Async::NotReady), + Err(_) => Err(::Error::Closed), + } } - pub fn cancel(&self) { - self.cancel.cancel(); + pub fn is_closed(&self) -> bool { + self.giver.is_canceled() } - pub fn send(&self, val: T) -> Result, T> { + pub fn send(&mut self, val: T) -> Result, T> { let (tx, rx) = oneshot::channel(); - self.inner.unbounded_send((val, tx)) + self.inner.try_send((val, tx)) .map(move |_| rx) .map_err(|e| e.into_inner().0) } } -impl Clone for Sender { - fn clone(&self) -> Sender { - Sender { - cancel: self.cancel.clone(), - inner: self.inner.clone(), - } - } -} - pub struct Receiver { - canceled: Canceled, - inner: mpsc::UnboundedReceiver<(T, Callback)>, + inner: mpsc::Receiver<(T, Callback)>, + taker: signal::Taker, } impl Stream for Receiver { @@ -62,16 +66,20 @@ impl Stream for Receiver { type Error = Never; fn poll(&mut self) -> Poll, Self::Error> { - if let Async::Ready(()) = self.canceled.poll()? { - return Ok(Async::Ready(None)); + match self.inner.poll() { + Ok(Async::Ready(item)) => Ok(Async::Ready(item)), + Ok(Async::NotReady) => { + self.taker.want(); + Ok(Async::NotReady) + } + Err(()) => unreachable!("mpsc never errors"), } - self.inner.poll().map_err(|()| unreachable!("mpsc never errors")) } } impl Drop for Receiver { fn drop(&mut self) { - self.canceled.cancel(); + self.taker.cancel(); self.inner.close(); // This poll() is safe to call in `Drop`, because we've @@ -84,8 +92,7 @@ impl Drop for Receiver { // - NotReady: unreachable // - Err: unreachable while let Ok(Async::Ready(Some((val, cb)))) = self.inner.poll() { - // maybe in future, we pass the value along with the error? - let _ = cb.send(Err((::Error::new_canceled(None), Some(val)))); + let _ = cb.send(Err((::Error::new_canceled(None::<::Error>), Some(val)))); } } @@ -109,7 +116,7 @@ mod tests { future::lazy(|| { #[derive(Debug)] struct Custom(i32); - let (tx, rx) = super::channel::(); + let (mut tx, rx) = super::channel::(); let promise = tx.send(Custom(43)).unwrap(); drop(rx); @@ -128,8 +135,8 @@ mod tests { #[cfg(feature = "nightly")] #[bench] - fn cancelable_queue_throughput(b: &mut test::Bencher) { - let (tx, mut rx) = super::channel::(); + fn giver_queue_throughput(b: &mut test::Bencher) { + let (mut tx, mut rx) = super::channel::(); b.iter(move || { ::futures::future::lazy(|| { @@ -149,7 +156,7 @@ mod tests { #[cfg(feature = "nightly")] #[bench] - fn cancelable_queue_not_ready(b: &mut test::Bencher) { + fn giver_queue_not_ready(b: &mut test::Bencher) { let (_tx, mut rx) = super::channel::(); b.iter(move || { @@ -163,11 +170,11 @@ mod tests { #[cfg(feature = "nightly")] #[bench] - fn cancelable_queue_cancel(b: &mut test::Bencher) { - let (tx, _rx) = super::channel::(); + fn giver_queue_cancel(b: &mut test::Bencher) { + let (_tx, rx) = super::channel::(); b.iter(move || { - tx.cancel(); + rx.taker.cancel(); }) } } diff --git a/src/client/mod.rs b/src/client/mod.rs index 0114f4e053..0ad70b06d5 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -1,14 +1,14 @@ //! HTTP Client -use std::cell::Cell; use std::fmt; use std::io; use std::marker::PhantomData; use std::rc::Rc; +use std::sync::Arc; use std::time::Duration; use futures::{Async, Future, Poll, Stream}; -use futures::future::{self, Either, Executor}; +use futures::future::{self, Executor}; #[cfg(feature = "compat")] use http; use tokio::reactor::Handle; @@ -28,7 +28,7 @@ pub use self::connect::{HttpConnector, Connect}; use self::background::{bg, Background}; -mod cancel; +pub mod conn; mod connect; //TODO(easy): move cancel and dispatch into common instead pub(crate) mod dispatch; @@ -36,6 +36,7 @@ mod dns; mod pool; #[cfg(feature = "compat")] pub mod compat; +mod signal; #[cfg(test)] mod tests; @@ -44,7 +45,7 @@ pub struct Client { connector: Rc, executor: Exec, h1_writev: bool, - pool: Pool>, + pool: Pool>, retry_canceled_requests: bool, set_host: bool, } @@ -191,77 +192,73 @@ where C: Connect, //TODO: replace with `impl Future` when stable fn send_request(&self, req: Request, domain: &Uri) -> Box>> { + //fn send_request(&self, req: Request, domain: &Uri) -> Box> { let url = req.uri().clone(); - let (head, body) = request::split(req); let checkout = self.pool.checkout(domain.as_ref()); let connect = { let executor = self.executor.clone(); let pool = self.pool.clone(); - let pool_key = Rc::new(domain.to_string()); + let pool_key = Arc::new(domain.to_string()); let h1_writev = self.h1_writev; let connector = self.connector.clone(); future::lazy(move || { connector.connect(url) + .from_err() .and_then(move |io| { - let (tx, rx) = dispatch::channel(); - let tx = HyperClient { + conn::Builder::new() + .h1_writev(h1_writev) + .handshake_no_upgrades(io) + }).and_then(move |(tx, conn)| { + executor.execute(conn.map_err(|e| debug!("client connection error: {}", e)))?; + Ok(pool.pooled(pool_key, PoolClient { tx: tx, - should_close: Cell::new(true), - }; - let pooled = pool.pooled(pool_key, tx); - let mut conn = proto::Conn::<_, _, proto::ClientTransaction, _>::new(io, pooled.clone()); - if !h1_writev { - conn.set_write_strategy_flatten(); - } - let dispatch = proto::dispatch::Dispatcher::new(proto::dispatch::Client::new(rx), conn); - executor.execute(dispatch.map_err(|e| debug!("client connection error: {}", e)))?; - Ok(pooled) + })) }) }) }; let race = checkout.select(connect) - .map(|(client, _work)| client) + .map(|(pooled, _work)| pooled) .map_err(|(e, _checkout)| { // the Pool Checkout cannot error, so the only error // is from the Connector // XXX: should wait on the Checkout? Problem is // that if the connector is failing, it may be that we // never had a pooled stream at all - ClientError::Normal(e.into()) + ClientError::Normal(e) }); - let resp = race.and_then(move |client| { - let conn_reused = client.is_reused(); - match client.tx.send((head, body)) { - Ok(rx) => { - client.should_close.set(false); - Either::A(rx.then(move |res| { - match res { - Ok(Ok(res)) => Ok(res), - Ok(Err((err, orig_req))) => Err(match orig_req { - Some(req) => ClientError::Canceled { - connection_reused: conn_reused, - reason: err, - req: req, - }, - None => ClientError::Normal(err), - }), - // this is definite bug if it happens, but it shouldn't happen! - Err(_) => panic!("dispatch dropped without returning error"), + + let executor = self.executor.clone(); + let resp = race.and_then(move |mut pooled| { + let conn_reused = pooled.is_reused(); + let fut = pooled.tx.send_request_retryable(req) + .map_err(move |(err, orig_req)| { + if let Some(req) = orig_req { + ClientError::Canceled { + connection_reused: conn_reused, + reason: err, + req: req, } - })) - }, - Err(req) => { - debug!("pooled connection was not ready"); - let err = ClientError::Canceled { - connection_reused: conn_reused, - reason: ::Error::new_canceled(None), - req: req, - }; - Either::B(future::err(err)) - } - } + } else { + ClientError::Normal(err) + } + }); + + // when pooled is dropped, it will try to insert back into the + // pool. To delay that, spawn a future that completes once the + // sender is ready again. + // + // This *should* only be once the related `Connection` has polled + // for a new request to start. + // + // If the executor doesn't have room, oh well. Things will likely + // be blowing up soon, but this specific task isn't required. + let _ = executor.execute(future::poll_fn(move || { + pooled.tx.poll_ready().map_err(|_| ()) + })); + + fut }); Box::new(resp) @@ -373,35 +370,19 @@ where } } -struct HyperClient { - should_close: Cell, - tx: dispatch::Sender, ::Response>, +struct PoolClient { + tx: conn::SendRequest, } -impl Clone for HyperClient { - fn clone(&self) -> HyperClient { - HyperClient { - tx: self.tx.clone(), - should_close: self.should_close.clone(), - } - } -} - -impl self::pool::Closed for HyperClient { +impl self::pool::Closed for PoolClient +where + B: 'static, +{ fn is_closed(&self) -> bool { self.tx.is_closed() } } -impl Drop for HyperClient { - fn drop(&mut self) { - if self.should_close.get() { - self.should_close.set(false); - self.tx.cancel(); - } - } -} - pub(crate) enum ClientError { Normal(::Error), Canceled { diff --git a/src/client/pool.rs b/src/client/pool.rs index 9dfe527e4e..7531b77e44 100644 --- a/src/client/pool.rs +++ b/src/client/pool.rs @@ -1,19 +1,15 @@ -use std::cell::{Cell, RefCell}; use std::collections::{HashMap, VecDeque}; use std::fmt; -use std::io; -use std::ops::{Deref, DerefMut, BitAndAssign}; -use std::rc::{Rc, Weak}; +use std::ops::{Deref, DerefMut}; +use std::sync::{Arc, Mutex, Weak}; use std::time::{Duration, Instant}; use futures::{Future, Async, Poll, Stream}; +use futures::sync::oneshot; use tokio::reactor::{Handle, Interval}; -use relay; - -use proto::{KeepAlive, KA}; pub struct Pool { - inner: Rc>>, + inner: Arc>>, } // Before using a pooled connection, make sure the sender is not dead. @@ -29,7 +25,7 @@ struct PoolInner { enabled: bool, // These are internal Conns sitting in the event loop in the KeepAlive // state, waiting to receive a new Request to send on the socket. - idle: HashMap, Vec>>, + idle: HashMap, Vec>>, // These are outstanding Checkouts that are waiting for a socket to be // able to send a Request one. This is used when "racing" for a new // connection. @@ -39,7 +35,7 @@ struct PoolInner { // this list is checked for any parked Checkouts, and tries to notify // them that the Conn could be used instead of waiting for a brand new // connection. - parked: HashMap, VecDeque>>>, + parked: HashMap, VecDeque>>, timeout: Option, // Used to prevent multiple intervals from being spawned to clear // expired connections. @@ -49,10 +45,10 @@ struct PoolInner { expired_timer_spawned: bool, } -impl Pool { +impl Pool { pub fn new(enabled: bool, timeout: Option) -> Pool { Pool { - inner: Rc::new(RefCell::new(PoolInner { + inner: Arc::new(Mutex::new(PoolInner { enabled: enabled, idle: HashMap::new(), parked: HashMap::new(), @@ -61,74 +57,33 @@ impl Pool { })), } } +} +impl Pool { pub fn checkout(&self, key: &str) -> Checkout { Checkout { - key: Rc::new(key.to_owned()), + key: Arc::new(key.to_owned()), pool: self.clone(), parked: None, } } - fn put(&mut self, key: Rc, entry: Entry) { - trace!("Pool::put {:?}", key); - let mut inner = self.inner.borrow_mut(); - let mut remove_parked = false; - let mut entry = Some(entry); - if let Some(parked) = inner.parked.get_mut(&key) { - while let Some(tx) = parked.pop_front() { - if tx.is_canceled() { - trace!("Pool::put removing canceled parked {:?}", key); - } else { - tx.complete(entry.take().unwrap()); - break; - } - /* - match tx.send(entry.take().unwrap()) { - Ok(()) => break, - Err(e) => { - trace!("Pool::put removing canceled parked {:?}", key); - entry = Some(e); - } - } - */ - } - remove_parked = parked.is_empty(); - } - if remove_parked { - inner.parked.remove(&key); - } - - match entry { - Some(entry) => { - debug!("pooling idle connection for {:?}", key); - inner.idle.entry(key) - .or_insert(Vec::new()) - .push(entry); - } - None => trace!("Pool::put found parked {:?}", key), - } - } - - fn take(&self, key: &Rc) -> Option> { + fn take(&self, key: &Arc) -> Option> { let entry = { - let mut inner = self.inner.borrow_mut(); + let mut inner = self.inner.lock().unwrap(); let expiration = Expiration::new(inner.timeout); let mut should_remove = false; let entry = inner.idle.get_mut(key).and_then(|list| { trace!("take; url = {:?}, expiration = {:?}", key, expiration.0); while let Some(entry) = list.pop() { - match entry.status.get() { - TimedKA::Idle(idle_at) if !expiration.expires(idle_at) => { - if !entry.value.is_closed() { - should_remove = list.is_empty(); - return Some(entry); - } - }, - _ => {}, + if !expiration.expires(entry.idle_at) { + if !entry.value.is_closed() { + should_remove = list.is_empty(); + return Some(entry); + } } trace!("removing unacceptable pooled {:?}", key); - // every other case the Entry should just be dropped + // every other case the Idle should just be dropped // 1. Idle but expired // 2. Busy (something else somehow took it?) // 3. Disabled don't reuse of course @@ -143,72 +98,102 @@ impl Pool { entry }; - entry.map(|e| self.reuse(key, e)) + entry.map(|e| self.reuse(key, e.value)) } - pub fn pooled(&self, key: Rc, value: T) -> Pooled { + pub fn pooled(&self, key: Arc, value: T) -> Pooled { Pooled { - entry: Entry { - value: value, - is_reused: false, - status: Rc::new(Cell::new(TimedKA::Busy)), - }, + is_reused: false, key: key, - pool: Rc::downgrade(&self.inner), + pool: Arc::downgrade(&self.inner), + value: Some(value) } } - fn is_enabled(&self) -> bool { - self.inner.borrow().enabled - } - - fn reuse(&self, key: &Rc, mut entry: Entry) -> Pooled { + fn reuse(&self, key: &Arc, value: T) -> Pooled { debug!("reuse idle connection for {:?}", key); - entry.is_reused = true; - entry.status.set(TimedKA::Busy); Pooled { - entry: entry, + is_reused: true, key: key.clone(), - pool: Rc::downgrade(&self.inner), + pool: Arc::downgrade(&self.inner), + value: Some(value), } } - fn park(&mut self, key: Rc, tx: relay::Sender>) { + fn park(&mut self, key: Arc, tx: oneshot::Sender) { trace!("park; waiting for idle connection: {:?}", key); - self.inner.borrow_mut() + self.inner.lock().unwrap() .parked.entry(key) .or_insert(VecDeque::new()) .push_back(tx); } } -impl Pool { +impl PoolInner { + fn put(&mut self, key: Arc, value: T) { + if !self.enabled { + return; + } + trace!("Pool::put {:?}", key); + let mut remove_parked = false; + let mut value = Some(value); + if let Some(parked) = self.parked.get_mut(&key) { + while let Some(tx) = parked.pop_front() { + if !tx.is_canceled() { + match tx.send(value.take().unwrap()) { + Ok(()) => break, + Err(e) => { + value = Some(e); + } + } + } + + trace!("Pool::put removing canceled parked {:?}", key); + } + remove_parked = parked.is_empty(); + } + if remove_parked { + self.parked.remove(&key); + } + + match value { + Some(value) => { + debug!("pooling idle connection for {:?}", key); + self.idle.entry(key) + .or_insert(Vec::new()) + .push(Idle { + value: value, + idle_at: Instant::now(), + }); + } + None => trace!("Pool::put found parked {:?}", key), + } + } +} + +impl PoolInner { /// Any `FutureResponse`s that were created will have made a `Checkout`, /// and possibly inserted into the pool that it is waiting for an idle /// connection. If a user ever dropped that future, we need to clean out /// those parked senders. - fn clean_parked(&mut self, key: &Rc) { - let mut inner = self.inner.borrow_mut(); - + fn clean_parked(&mut self, key: &Arc) { let mut remove_parked = false; - if let Some(parked) = inner.parked.get_mut(key) { + if let Some(parked) = self.parked.get_mut(key) { parked.retain(|tx| { !tx.is_canceled() }); remove_parked = parked.is_empty(); } if remove_parked { - inner.parked.remove(key); + self.parked.remove(key); } } } -impl Pool { - fn clear_expired(&self) { - let mut inner = self.inner.borrow_mut(); - - let dur = if let Some(dur) = inner.timeout { +impl PoolInner { + fn clear_expired(&mut self) { + let dur = if let Some(dur) = self.timeout { dur } else { return @@ -217,19 +202,13 @@ impl Pool { let now = Instant::now(); //self.last_idle_check_at = now; - inner.idle.retain(|_key, values| { + self.idle.retain(|_key, values| { - values.retain(|val| { - if val.value.is_closed() { + values.retain(|entry| { + if entry.value.is_closed() { return false; } - match val.status.get() { - TimedKA::Idle(idle_at) if now - idle_at < dur => { - true - }, - _ => false, - } - //now - val.idle_at < dur + now - entry.idle_at < dur }); // returning false evicts this key/val @@ -241,28 +220,30 @@ impl Pool { impl Pool { pub(super) fn spawn_expired_interval(&self, handle: &Handle) { - let mut inner = self.inner.borrow_mut(); + let dur = { + let mut inner = self.inner.lock().unwrap(); - if !inner.enabled { - return; - } + if !inner.enabled { + return; + } - if inner.expired_timer_spawned { - return; - } - inner.expired_timer_spawned = true; + if inner.expired_timer_spawned { + return; + } + inner.expired_timer_spawned = true; - let dur = if let Some(dur) = inner.timeout { - dur - } else { - return + if let Some(dur) = inner.timeout { + dur + } else { + return + } }; let interval = Interval::new(dur, handle) .expect("reactor is gone"); handle.spawn(IdleInterval { interval: interval, - pool: Rc::downgrade(&self.inner), + pool: Arc::downgrade(&self.inner), }); } } @@ -275,121 +256,83 @@ impl Clone for Pool { } } -#[derive(Clone)] -pub struct Pooled { - entry: Entry, - key: Rc, - pool: Weak>>, +pub struct Pooled { + value: Option, + is_reused: bool, + key: Arc, + pool: Weak>>, } -impl Pooled { +impl Pooled { pub fn is_reused(&self) -> bool { - self.entry.is_reused + self.is_reused + } + + fn as_ref(&self) -> &T { + self.value.as_ref().expect("not dropped") + } + + fn as_mut(&mut self) -> &mut T { + self.value.as_mut().expect("not dropped") } } -impl Deref for Pooled { +impl Deref for Pooled { type Target = T; fn deref(&self) -> &T { - &self.entry.value + self.as_ref() } } -impl DerefMut for Pooled { +impl DerefMut for Pooled { fn deref_mut(&mut self) -> &mut T { - &mut self.entry.value + self.as_mut() } } -impl KeepAlive for Pooled { - fn busy(&mut self) { - self.entry.status.set(TimedKA::Busy); - } - - fn disable(&mut self) { - self.entry.status.set(TimedKA::Disabled); - } - - fn idle(&mut self) { - let previous = self.status(); - self.entry.status.set(TimedKA::Idle(Instant::now())); - if let KA::Idle = previous { - trace!("Pooled::idle already idle"); - return; - } - self.entry.is_reused = true; - if let Some(inner) = self.pool.upgrade() { - let mut pool = Pool { - inner: inner, - }; - if pool.is_enabled() { - pool.put(self.key.clone(), self.entry.clone()); +impl Drop for Pooled { + fn drop(&mut self) { + if let Some(value) = self.value.take() { + if let Some(inner) = self.pool.upgrade() { + if let Ok(mut inner) = inner.lock() { + inner.put(self.key.clone(), value); + } } else { - trace!("keepalive disabled, dropping pooled ({:?})", self.key); - self.disable(); + trace!("pool dropped, dropping pooled ({:?})", self.key); } - } else { - trace!("pool dropped, dropping pooled ({:?})", self.key); - self.disable(); - } - } - - fn status(&self) -> KA { - match self.entry.status.get() { - TimedKA::Idle(_) => KA::Idle, - TimedKA::Busy => KA::Busy, - TimedKA::Disabled => KA::Disabled, } } } -impl fmt::Debug for Pooled { +impl fmt::Debug for Pooled { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.debug_struct("Pooled") - .field("status", &self.entry.status.get()) .field("key", &self.key) .finish() } } -impl BitAndAssign for Pooled { - fn bitand_assign(&mut self, enabled: bool) { - if !enabled { - self.disable(); - } - } -} - -#[derive(Clone)] -struct Entry { +struct Idle { + idle_at: Instant, value: T, - is_reused: bool, - status: Rc>, -} - -#[derive(Clone, Copy, Debug)] -enum TimedKA { - Idle(Instant), - Busy, - Disabled, } pub struct Checkout { - key: Rc, + key: Arc, pool: Pool, - parked: Option>>, + parked: Option>, } struct NotParked; -impl Checkout { +impl Checkout { fn poll_parked(&mut self) -> Poll, NotParked> { let mut drop_parked = false; if let Some(ref mut rx) = self.parked { match rx.poll() { - Ok(Async::Ready(entry)) => { - if !entry.value.is_closed() { - return Ok(Async::Ready(self.pool.reuse(&self.key, entry))); + Ok(Async::Ready(value)) => { + if !value.is_closed() { + return Ok(Async::Ready(self.pool.reuse(&self.key, value))); } drop_parked = true; }, @@ -405,7 +348,7 @@ impl Checkout { fn park(&mut self) { if self.parked.is_none() { - let (tx, mut rx) = relay::channel(); + let (tx, mut rx) = oneshot::channel(); let _ = rx.poll(); // park this task self.pool.park(self.key.clone(), tx); self.parked = Some(rx); @@ -413,9 +356,9 @@ impl Checkout { } } -impl Future for Checkout { +impl Future for Checkout { type Item = Pooled; - type Error = io::Error; + type Error = ::Error; fn poll(&mut self) -> Poll { match self.poll_parked() { @@ -437,7 +380,9 @@ impl Future for Checkout { impl Drop for Checkout { fn drop(&mut self) { self.parked.take(); - self.pool.clean_parked(&self.key); + if let Ok(mut inner) = self.pool.inner.lock() { + inner.clean_parked(&self.key); + } } } @@ -458,7 +403,7 @@ impl Expiration { struct IdleInterval { interval: Interval, - pool: Weak>>, + pool: Weak>>, } impl Future for IdleInterval { @@ -470,22 +415,22 @@ impl Future for IdleInterval { try_ready!(self.interval.poll().map_err(|_| unreachable!("interval cannot error"))); if let Some(inner) = self.pool.upgrade() { - let pool = Pool { inner: inner }; - pool.clear_expired(); - } else { - return Ok(Async::Ready(())); + if let Ok(mut inner) = inner.lock() { + inner.clear_expired(); + continue; + } } + return Ok(Async::Ready(())); } } } #[cfg(test)] mod tests { - use std::rc::Rc; + use std::sync::Arc; use std::time::Duration; use futures::{Async, Future}; use futures::future; - use proto::KeepAlive; use super::{Closed, Pool}; impl Closed for i32 { @@ -497,9 +442,10 @@ mod tests { #[test] fn test_pool_checkout_smoke() { let pool = Pool::new(true, Some(Duration::from_secs(5))); - let key = Rc::new("foo".to_string()); - let mut pooled = pool.pooled(key.clone(), 41); - pooled.idle(); + let key = Arc::new("foo".to_string()); + let pooled = pool.pooled(key.clone(), 41); + + drop(pooled); match pool.checkout(&key).poll().unwrap() { Async::Ready(pooled) => assert_eq!(*pooled, 41), @@ -510,11 +456,11 @@ mod tests { #[test] fn test_pool_checkout_returns_none_if_expired() { future::lazy(|| { - let pool = Pool::new(true, Some(Duration::from_secs(1))); - let key = Rc::new("foo".to_string()); - let mut pooled = pool.pooled(key.clone(), 41); - pooled.idle(); - ::std::thread::sleep(pool.inner.borrow().timeout.unwrap()); + let pool = Pool::new(true, Some(Duration::from_millis(100))); + let key = Arc::new("foo".to_string()); + let pooled = pool.pooled(key.clone(), 41); + drop(pooled); + ::std::thread::sleep(pool.inner.lock().unwrap().timeout.unwrap()); assert!(pool.checkout(&key).poll().unwrap().is_not_ready()); ::futures::future::ok::<(), ()>(()) }).wait().unwrap(); @@ -522,26 +468,23 @@ mod tests { #[test] fn test_pool_checkout_removes_expired() { - let pool = Pool::new(true, Some(Duration::from_secs(1))); - let key = Rc::new("foo".to_string()); + future::lazy(|| { + let pool = Pool::new(true, Some(Duration::from_millis(100))); + let key = Arc::new("foo".to_string()); - let mut pooled1 = pool.pooled(key.clone(), 41); - pooled1.idle(); - let mut pooled2 = pool.pooled(key.clone(), 5); - pooled2.idle(); - let mut pooled3 = pool.pooled(key.clone(), 99); - pooled3.idle(); + pool.pooled(key.clone(), 41); + pool.pooled(key.clone(), 5); + pool.pooled(key.clone(), 99); + assert_eq!(pool.inner.lock().unwrap().idle.get(&key).map(|entries| entries.len()), Some(3)); + ::std::thread::sleep(pool.inner.lock().unwrap().timeout.unwrap()); - assert_eq!(pool.inner.borrow().idle.get(&key).map(|entries| entries.len()), Some(3)); - ::std::thread::sleep(pool.inner.borrow().timeout.unwrap()); + // checkout.poll() should clean out the expired + pool.checkout(&key).poll().unwrap(); + assert!(pool.inner.lock().unwrap().idle.get(&key).is_none()); - pooled1.idle(); - pooled2.idle(); // idle after sleep, not expired - pool.checkout(&key).poll().unwrap(); - assert_eq!(pool.inner.borrow().idle.get(&key).map(|entries| entries.len()), Some(1)); - pool.checkout(&key).poll().unwrap(); - assert!(pool.inner.borrow().idle.get(&key).is_none()); + Ok::<(), ()>(()) + }).wait().unwrap(); } #[test] @@ -549,16 +492,13 @@ mod tests { let mut core = ::tokio::reactor::Core::new().unwrap(); let pool = Pool::new(true, Some(Duration::from_millis(100))); pool.spawn_expired_interval(&core.handle()); - let key = Rc::new("foo".to_string()); + let key = Arc::new("foo".to_string()); - let mut pooled1 = pool.pooled(key.clone(), 41); - pooled1.idle(); - let mut pooled2 = pool.pooled(key.clone(), 5); - pooled2.idle(); - let mut pooled3 = pool.pooled(key.clone(), 99); - pooled3.idle(); + pool.pooled(key.clone(), 41); + pool.pooled(key.clone(), 5); + pool.pooled(key.clone(), 99); - assert_eq!(pool.inner.borrow().idle.get(&key).map(|entries| entries.len()), Some(3)); + assert_eq!(pool.inner.lock().unwrap().idle.get(&key).map(|entries| entries.len()), Some(3)); let timeout = ::tokio::reactor::Timeout::new( Duration::from_millis(400), // allow for too-good resolution @@ -566,49 +506,48 @@ mod tests { ).unwrap(); core.run(timeout).unwrap(); - assert!(pool.inner.borrow().idle.get(&key).is_none()); + assert!(pool.inner.lock().unwrap().idle.get(&key).is_none()); } #[test] fn test_pool_checkout_task_unparked() { let pool = Pool::new(true, Some(Duration::from_secs(10))); - let key = Rc::new("foo".to_string()); - let pooled1 = pool.pooled(key.clone(), 41); + let key = Arc::new("foo".to_string()); + let pooled = pool.pooled(key.clone(), 41); - let mut pooled = pooled1.clone(); let checkout = pool.checkout(&key).join(future::lazy(move || { // the checkout future will park first, // and then this lazy future will be polled, which will insert // the pooled back into the pool // // this test makes sure that doing so will unpark the checkout - pooled.idle(); + drop(pooled); Ok(()) })).map(|(entry, _)| entry); - assert_eq!(*checkout.wait().unwrap(), *pooled1); + assert_eq!(*checkout.wait().unwrap(), 41); } #[test] fn test_pool_checkout_drop_cleans_up_parked() { future::lazy(|| { - let pool = Pool::new(true, Some(Duration::from_secs(10))); - let key = Rc::new("localhost:12345".to_string()); - let _pooled1 = pool.pooled(key.clone(), 41); + let pool = Pool::::new(true, Some(Duration::from_secs(10))); + let key = Arc::new("localhost:12345".to_string()); + let mut checkout1 = pool.checkout(&key); let mut checkout2 = pool.checkout(&key); // first poll needed to get into Pool's parked checkout1.poll().unwrap(); - assert_eq!(pool.inner.borrow().parked.get(&key).unwrap().len(), 1); + assert_eq!(pool.inner.lock().unwrap().parked.get(&key).unwrap().len(), 1); checkout2.poll().unwrap(); - assert_eq!(pool.inner.borrow().parked.get(&key).unwrap().len(), 2); + assert_eq!(pool.inner.lock().unwrap().parked.get(&key).unwrap().len(), 2); // on drop, clean up Pool drop(checkout1); - assert_eq!(pool.inner.borrow().parked.get(&key).unwrap().len(), 1); + assert_eq!(pool.inner.lock().unwrap().parked.get(&key).unwrap().len(), 1); drop(checkout2); - assert!(pool.inner.borrow().parked.get(&key).is_none()); + assert!(pool.inner.lock().unwrap().parked.get(&key).is_none()); ::futures::future::ok::<(), ()>(()) }).wait().unwrap(); diff --git a/src/client/signal.rs b/src/client/signal.rs new file mode 100644 index 0000000000..2ddf67f7a5 --- /dev/null +++ b/src/client/signal.rs @@ -0,0 +1,188 @@ +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; + +use futures::{Async, Poll}; +use futures::task::{self, Task}; + +use self::lock::Lock; + +pub fn new() -> (Giver, Taker) { + let inner = Arc::new(Inner { + state: AtomicUsize::new(STATE_IDLE), + task: Lock::new(None), + }); + let inner2 = inner.clone(); + ( + Giver { + inner: inner, + }, + Taker { + inner: inner2, + }, + ) +} + +#[derive(Clone)] +pub struct Giver { + inner: Arc, +} + +pub struct Taker { + inner: Arc, +} + +const STATE_IDLE: usize = 0; +const STATE_WANT: usize = 1; +const STATE_GIVE: usize = 2; +const STATE_CLOSED: usize = 3; + +struct Inner { + state: AtomicUsize, + task: Lock>, +} + +impl Giver { + pub fn poll_want(&mut self) -> Poll<(), ()> { + loop { + let state = self.inner.state.load(Ordering::SeqCst); + match state { + STATE_WANT => { + // only set to IDLE if it is still Want + self.inner.state.compare_and_swap( + STATE_WANT, + STATE_IDLE, + Ordering::SeqCst, + ); + return Ok(Async::Ready(())) + }, + STATE_GIVE => { + // we're already waiting, return + return Ok(Async::NotReady) + } + STATE_CLOSED => return Err(()), + // Taker doesn't want anything yet, so park. + _ => { + if let Some(mut locked) = self.inner.task.try_lock() { + + // While we have the lock, try to set to GIVE. + let old = self.inner.state.compare_and_swap( + STATE_IDLE, + STATE_GIVE, + Ordering::SeqCst, + ); + // If it's not still IDLE, something happened! + // Go around the loop again. + if old == STATE_IDLE { + *locked = Some(task::current()); + return Ok(Async::NotReady) + } + } else { + // if we couldn't take the lock, then a Taker has it. + // The *ONLY* reason is because it is in the process of notifying us + // of its want. + // + // We need to loop again to see what state it was changed to. + } + }, + } + } + } + + pub fn is_canceled(&self) -> bool { + self.inner.state.load(Ordering::SeqCst) == STATE_CLOSED + } +} + +impl Taker { + pub fn cancel(&self) { + self.signal(STATE_CLOSED) + } + + pub fn want(&self) { + self.signal(STATE_WANT) + } + + fn signal(&self, state: usize) { + let old_state = self.inner.state.swap(state, Ordering::SeqCst); + match old_state { + STATE_WANT | STATE_CLOSED | STATE_IDLE => (), + _ => { + loop { + if let Some(mut locked) = self.inner.task.try_lock() { + if let Some(task) = locked.take() { + task.notify(); + } + return; + } else { + // if we couldn't take the lock, then a Giver has it. + // The *ONLY* reason is because it is in the process of parking. + // + // We need to loop and take the lock so we can notify this task. + } + } + }, + } + } +} + +impl Drop for Taker { + fn drop(&mut self) { + self.cancel(); + } +} + + +// a sub module just to protect unsafety +mod lock { + use std::cell::UnsafeCell; + use std::ops::{Deref, DerefMut}; + use std::sync::atomic::{AtomicBool, Ordering}; + + pub struct Lock { + is_locked: AtomicBool, + value: UnsafeCell, + } + + impl Lock { + pub fn new(val: T) -> Lock { + Lock { + is_locked: AtomicBool::new(false), + value: UnsafeCell::new(val), + } + } + + pub fn try_lock(&self) -> Option> { + if !self.is_locked.swap(true, Ordering::SeqCst) { + Some(Locked { lock: self }) + } else { + None + } + } + } + + unsafe impl Send for Lock {} + unsafe impl Sync for Lock {} + + pub struct Locked<'a, T: 'a> { + lock: &'a Lock, + } + + impl<'a, T> Deref for Locked<'a, T> { + type Target = T; + fn deref(&self) -> &T { + unsafe { &*self.lock.value.get() } + } + } + + impl<'a, T> DerefMut for Locked<'a, T> { + fn deref_mut(&mut self) -> &mut T { + unsafe { &mut *self.lock.value.get() } + } + } + + impl<'a, T> Drop for Locked<'a, T> { + fn drop(&mut self) { + self.lock.is_locked.store(false, Ordering::SeqCst); + } + } +} diff --git a/src/client/tests.rs b/src/client/tests.rs index 6aecb65502..7d2157ddf9 100644 --- a/src/client/tests.rs +++ b/src/client/tests.rs @@ -45,3 +45,47 @@ fn retryable_request() { core.run(res2.join(srv2)).expect("res2"); } + +#[test] +fn conn_reset_after_write() { + let _ = pretty_env_logger::try_init(); + let mut core = Core::new().unwrap(); + + let mut connector = MockConnector::new(); + + let sock1 = connector.mock("http://mock.local/a"); + + let client = Client::configure() + .connector(connector) + .build(&core.handle()); + + + { + let res1 = client.get("http://mock.local/a".parse().unwrap()); + let srv1 = poll_fn(|| { + try_ready!(sock1.read(&mut [0u8; 512])); + try_ready!(sock1.write(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")); + Ok(Async::Ready(())) + }); + core.run(res1.join(srv1)).expect("res1"); + } + + let res2 = client.get("http://mock.local/a".parse().unwrap()); + let mut sock1 = Some(sock1); + let srv2 = poll_fn(|| { + // We purposefully keep the socket open until the client + // has written the second request, and THEN disconnect. + // + // Not because we expect servers to be jerks, but to trigger + // state where we write on an assumedly good connetion, and + // only reset the close AFTER we wrote bytes. + try_ready!(sock1.as_mut().unwrap().read(&mut [0u8; 512])); + sock1.take(); + Ok(Async::Ready(())) + }); + let err = core.run(res2.join(srv2)).expect_err("res2"); + match err { + ::Error::Incomplete => (), + other => panic!("expected Incomplete, found {:?}", other) + } +} diff --git a/src/error.rs b/src/error.rs index fc63f7d258..b6f846a585 100644 --- a/src/error.rs +++ b/src/error.rs @@ -17,6 +17,7 @@ use self::Error::{ Status, Timeout, Upgrade, + Closed, Cancel, Io, TooLarge, @@ -50,6 +51,8 @@ pub enum Error { Upgrade, /// A pending item was dropped before ever being processed. Cancel(Canceled), + /// Indicates a connection is closed. + Closed, /// An `io::Error` that occurred while trying to read or write to a network stream. Io(IoError), /// Parsing a field as string failed @@ -60,9 +63,9 @@ pub enum Error { } impl Error { - pub(crate) fn new_canceled(cause: Option) -> Error { + pub(crate) fn new_canceled>>(cause: Option) -> Error { Error::Cancel(Canceled { - cause: cause.map(Box::new), + cause: cause.map(Into::into), }) } } @@ -75,7 +78,7 @@ impl Error { /// fulfilled with this error, signaling the `Request` was never started. #[derive(Debug)] pub struct Canceled { - cause: Option>, + cause: Option>, } impl Canceled { @@ -121,6 +124,7 @@ impl StdError for Error { Incomplete => "message is incomplete", Timeout => "timeout", Upgrade => "unsupported protocol upgrade", + Closed => "connection is closed", Cancel(ref e) => e.description(), Uri(ref e) => e.description(), Io(ref e) => e.description(), diff --git a/src/proto/body.rs b/src/proto/body.rs index f921118742..8f23378285 100644 --- a/src/proto/body.rs +++ b/src/proto/body.rs @@ -1,3 +1,5 @@ +use std::fmt; + use bytes::Bytes; use futures::{Async, AsyncSink, Future, Poll, Sink, StartSend, Stream}; use futures::sync::{mpsc, oneshot}; @@ -13,11 +15,12 @@ pub type BodySender = mpsc::Sender>; /// A `Stream` for `Chunk`s used in requests and responses. #[must_use = "streams do nothing unless polled"] -#[derive(Debug)] -pub struct Body(Inner); +pub struct Body { + kind: Kind, +} #[derive(Debug)] -enum Inner { +enum Kind { #[cfg(feature = "tokio-proto")] Tokio(TokioBody), Chan { @@ -40,7 +43,7 @@ impl Body { /// Return an empty body stream #[inline] pub fn empty() -> Body { - Body(Inner::Empty) + Body::new(Kind::Empty) } /// Return a body stream with an associated sender half @@ -60,11 +63,32 @@ impl Body { /// there are more chunks immediately. #[inline] pub fn is_empty(&self) -> bool { - match self.0 { - Inner::Empty => true, + match self.kind { + Kind::Empty => true, _ => false, } } + + fn new(kind: Kind) -> Body { + Body { + kind: kind, + } + } + + fn poll_inner(&mut self) -> Poll, ::Error> { + match self.kind { + #[cfg(feature = "tokio-proto")] + Kind::Tokio(ref mut rx) => rx.poll(), + Kind::Chan { ref mut rx, .. } => match rx.poll().expect("mpsc cannot error") { + Async::Ready(Some(Ok(chunk))) => Ok(Async::Ready(Some(chunk))), + Async::Ready(Some(Err(err))) => Err(err), + Async::Ready(None) => Ok(Async::Ready(None)), + Async::NotReady => Ok(Async::NotReady), + }, + Kind::Once(ref mut val) => Ok(Async::Ready(val.take())), + Kind::Empty => Ok(Async::Ready(None)), + } + } } impl Default for Body { @@ -80,18 +104,15 @@ impl Stream for Body { #[inline] fn poll(&mut self) -> Poll, ::Error> { - match self.0 { - #[cfg(feature = "tokio-proto")] - Inner::Tokio(ref mut rx) => rx.poll(), - Inner::Chan { ref mut rx, .. } => match rx.poll().expect("mpsc cannot error") { - Async::Ready(Some(Ok(chunk))) => Ok(Async::Ready(Some(chunk))), - Async::Ready(Some(Err(err))) => Err(err), - Async::Ready(None) => Ok(Async::Ready(None)), - Async::NotReady => Ok(Async::NotReady), - }, - Inner::Once(ref mut val) => Ok(Async::Ready(val.take())), - Inner::Empty => Ok(Async::Ready(None)), - } + self.poll_inner() + } +} + +impl fmt::Debug for Body { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_tuple("Body") + .field(&self.kind) + .finish() } } @@ -105,7 +126,7 @@ pub fn channel() -> (ChunkSender, Body) { close_rx_check: true, tx: tx, }; - let rx = Body(Inner::Chan { + let rx = Body::new(Kind::Chan { close_tx: close_tx, rx: rx, }); @@ -142,24 +163,24 @@ impl ChunkSender { feat_server_proto! { impl From for tokio_proto::streaming::Body { fn from(b: Body) -> tokio_proto::streaming::Body { - match b.0 { - Inner::Tokio(b) => b, - Inner::Chan { close_tx, rx } => { + match b.kind { + Kind::Tokio(b) => b, + Kind::Chan { close_tx, rx } => { // disable knowing if the Rx gets dropped, since we cannot // pass this tx along. let _ = close_tx.send(false); rx.into() }, - Inner::Once(Some(chunk)) => TokioBody::from(chunk), - Inner::Once(None) | - Inner::Empty => TokioBody::empty(), + Kind::Once(Some(chunk)) => TokioBody::from(chunk), + Kind::Once(None) | + Kind::Empty => TokioBody::empty(), } } } impl From> for Body { fn from(tokio_body: tokio_proto::streaming::Body) -> Body { - Body(Inner::Tokio(tokio_body)) + Body::new(Kind::Tokio(tokio_body)) } } } @@ -168,7 +189,7 @@ impl From>> for Body { #[inline] fn from(src: mpsc::Receiver>) -> Body { let (tx, _) = oneshot::channel(); - Body(Inner::Chan { + Body::new(Kind::Chan { close_tx: tx, rx: src, }) @@ -178,7 +199,7 @@ impl From>> for Body { impl From for Body { #[inline] fn from (chunk: Chunk) -> Body { - Body(Inner::Once(Some(chunk))) + Body::new(Kind::Once(Some(chunk))) } } diff --git a/src/proto/h1/conn.rs b/src/proto/h1/conn.rs index 329279bb66..581c8233f2 100644 --- a/src/proto/h1/conn.rs +++ b/src/proto/h1/conn.rs @@ -2,6 +2,7 @@ use std::fmt; use std::io::{self}; use std::marker::PhantomData; +use bytes::Bytes; use futures::{Async, AsyncSink, Poll, StartSend}; #[cfg(feature = "tokio-proto")] use futures::{Sink, Stream}; @@ -10,7 +11,7 @@ use tokio_io::{AsyncRead, AsyncWrite}; #[cfg(feature = "tokio-proto")] use tokio_proto::streaming::pipeline::{Frame, Transport}; -use proto::{Chunk, Http1Transaction, MessageHead}; +use proto::{Chunk, Decode, Http1Transaction, MessageHead}; use super::io::{Cursor, Buffered}; use super::{EncodedBuf, Encoder, Decoder}; use method::Method; @@ -24,24 +25,34 @@ use version::HttpVersion; /// The connection will determine when a message begins and ends as well as /// determine if this connection can be kept alive after the message, /// or if it is complete. -pub struct Conn { +pub struct Conn { io: Buffered>>, - state: State, + state: State, _marker: PhantomData } -impl Conn +/* +impl Conn +where I: AsyncRead + AsyncWrite, + B: AsRef<[u8]>, +{ + pub fn new_client(io: I) -> Conn { + Conn::new(io) + } +} +*/ + +impl Conn where I: AsyncRead + AsyncWrite, B: AsRef<[u8]>, T: Http1Transaction, - K: KeepAlive { - pub fn new(io: I, keep_alive: K) -> Conn { + pub fn new(io: I) -> Conn { Conn { io: Buffered::new(io), state: State { error: None, - keep_alive: keep_alive, + keep_alive: KA::Busy, method: None, read_task: None, reading: Reading::Init, @@ -66,6 +77,10 @@ where I: AsyncRead + AsyncWrite, self.io.set_write_strategy_flatten(); } + pub fn into_inner(self) -> (I, Bytes) { + self.io.into_inner() + } + #[cfg(feature = "tokio-proto")] fn poll_incoming(&mut self) -> Poll, Chunk, ::Error>>, io::Error> { trace!("Conn::poll_incoming()"); @@ -205,8 +220,16 @@ where I: AsyncRead + AsyncWrite, }; let decoder = match T::decoder(&head, &mut self.state.method) { - Ok(Some(d)) => d, - Ok(None) => { + Ok(Decode::Normal(d)) => { + d + }, + Ok(Decode::Final(d)) => { + trace!("final decoder, HTTP ending"); + debug_assert!(d.is_eof()); + self.state.close_read(); + d + }, + Ok(Decode::Ignore) => { // likely a 1xx message that we can ignore continue; } @@ -232,7 +255,11 @@ where I: AsyncRead + AsyncWrite, } else { (true, Reading::Body(decoder)) }; - self.state.reading = reading; + if let Reading::Closed = self.state.reading { + // actually want an `if not let ...` + } else { + self.state.reading = reading; + } if !body { self.try_keep_alive(); } @@ -434,6 +461,12 @@ where I: AsyncRead + AsyncWrite, } pub fn can_write_head(&self) -> bool { + if !T::should_read_first() { + match self.state.reading { + Reading::Closed => return false, + _ => {}, + } + } match self.state.writing { Writing::Init => true, _ => false @@ -456,6 +489,10 @@ where I: AsyncRead + AsyncWrite, pub fn write_head(&mut self, mut head: MessageHead, body: bool) { debug_assert!(self.can_write_head()); + if !T::should_read_first() { + self.state.busy(); + } + self.enforce_version(&mut head); let buf = self.io.write_buf_mut(); @@ -623,11 +660,10 @@ where I: AsyncRead + AsyncWrite, // ==== tokio_proto impl ==== #[cfg(feature = "tokio-proto")] -impl Stream for Conn +impl Stream for Conn where I: AsyncRead + AsyncWrite, B: AsRef<[u8]>, T: Http1Transaction, - K: KeepAlive, T::Outgoing: fmt::Debug { type Item = Frame, Chunk, ::Error>; type Error = io::Error; @@ -642,11 +678,10 @@ where I: AsyncRead + AsyncWrite, } #[cfg(feature = "tokio-proto")] -impl Sink for Conn +impl Sink for Conn where I: AsyncRead + AsyncWrite, B: AsRef<[u8]>, T: Http1Transaction, - K: KeepAlive, T::Outgoing: fmt::Debug { type SinkItem = Frame, B, ::Error>; type SinkError = io::Error; @@ -711,14 +746,13 @@ where I: AsyncRead + AsyncWrite, } #[cfg(feature = "tokio-proto")] -impl Transport for Conn +impl Transport for Conn where I: AsyncRead + AsyncWrite + 'static, B: AsRef<[u8]> + 'static, T: Http1Transaction + 'static, - K: KeepAlive + 'static, T::Outgoing: fmt::Debug {} -impl, T, K: KeepAlive> fmt::Debug for Conn { +impl, T> fmt::Debug for Conn { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.debug_struct("Conn") .field("state", &self.state) @@ -727,9 +761,9 @@ impl, T, K: KeepAlive> fmt::Debug for Conn { } } -struct State { +struct State { error: Option<::Error>, - keep_alive: K, + keep_alive: KA, method: Option, read_task: Option, reading: Reading, @@ -752,12 +786,12 @@ enum Writing { Closed, } -impl fmt::Debug for State { +impl fmt::Debug for State { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.debug_struct("State") .field("reading", &self.reading) .field("writing", &self.writing) - .field("keep_alive", &self.keep_alive.status()) + .field("keep_alive", &self.keep_alive) .field("error", &self.error) //.field("method", &self.method) .field("read_task", &self.read_task) @@ -786,15 +820,8 @@ impl ::std::ops::BitAndAssign for KA { } } -pub trait KeepAlive: fmt::Debug + ::std::ops::BitAndAssign { - fn busy(&mut self); - fn disable(&mut self); - fn idle(&mut self); - fn status(&self) -> KA; -} - #[derive(Clone, Copy, Debug)] -pub enum KA { +enum KA { Idle, Busy, Disabled, @@ -806,7 +833,7 @@ impl Default for KA { } } -impl KeepAlive for KA { +impl KA { fn idle(&mut self) { *self = KA::Idle; } @@ -824,7 +851,7 @@ impl KeepAlive for KA { } } -impl State { +impl State { fn close(&mut self) { trace!("State::close()"); self.reading = Reading::Closed; @@ -976,7 +1003,7 @@ mod tests { let good_message = b"GET / HTTP/1.1\r\n\r\n".to_vec(); let len = good_message.len(); let io = AsyncIo::new_buf(good_message, len); - let mut conn = Conn::<_, proto::Chunk, ServerTransaction>::new(io, Default::default()); + let mut conn = Conn::<_, proto::Chunk, ServerTransaction>::new(io); match conn.poll().unwrap() { Async::Ready(Some(Frame::Message { message, body: false })) => { @@ -994,7 +1021,7 @@ mod tests { let _: Result<(), ()> = future::lazy(|| { let good_message = b"GET / HTTP/1.1\r\nHost: foo.bar\r\n\r\n".to_vec(); let io = AsyncIo::new_buf(good_message, 10); - let mut conn = Conn::<_, proto::Chunk, ServerTransaction>::new(io, Default::default()); + let mut conn = Conn::<_, proto::Chunk, ServerTransaction>::new(io); assert!(conn.poll().unwrap().is_not_ready()); conn.io.io_mut().block_in(50); let async = conn.poll().unwrap(); @@ -1010,7 +1037,7 @@ mod tests { #[test] fn test_conn_init_read_eof_idle() { let io = AsyncIo::new_buf(vec![], 1); - let mut conn = Conn::<_, proto::Chunk, ServerTransaction>::new(io, Default::default()); + let mut conn = Conn::<_, proto::Chunk, ServerTransaction>::new(io); conn.state.idle(); match conn.poll().unwrap() { @@ -1022,7 +1049,7 @@ mod tests { #[test] fn test_conn_init_read_eof_idle_partial_parse() { let io = AsyncIo::new_buf(b"GET / HTTP/1.1".to_vec(), 100); - let mut conn = Conn::<_, proto::Chunk, ServerTransaction>::new(io, Default::default()); + let mut conn = Conn::<_, proto::Chunk, ServerTransaction>::new(io); conn.state.idle(); match conn.poll() { @@ -1036,7 +1063,7 @@ mod tests { let _: Result<(), ()> = future::lazy(|| { // server ignores let io = AsyncIo::new_eof(); - let mut conn = Conn::<_, proto::Chunk, ServerTransaction>::new(io, Default::default()); + let mut conn = Conn::<_, proto::Chunk, ServerTransaction>::new(io); conn.state.busy(); match conn.poll().unwrap() { @@ -1046,7 +1073,7 @@ mod tests { // client let io = AsyncIo::new_eof(); - let mut conn = Conn::<_, proto::Chunk, ClientTransaction>::new(io, Default::default()); + let mut conn = Conn::<_, proto::Chunk, ClientTransaction>::new(io); conn.state.busy(); match conn.poll() { @@ -1061,7 +1088,7 @@ mod tests { fn test_conn_body_finish_read_eof() { let _: Result<(), ()> = future::lazy(|| { let io = AsyncIo::new_eof(); - let mut conn = Conn::<_, proto::Chunk, ClientTransaction>::new(io, Default::default()); + let mut conn = Conn::<_, proto::Chunk, ClientTransaction>::new(io); conn.state.busy(); conn.state.writing = Writing::KeepAlive; conn.state.reading = Reading::Body(Decoder::length(0)); @@ -1086,7 +1113,7 @@ mod tests { fn test_conn_message_empty_body_read_eof() { let _: Result<(), ()> = future::lazy(|| { let io = AsyncIo::new_buf(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n".to_vec(), 1024); - let mut conn = Conn::<_, proto::Chunk, ClientTransaction>::new(io, Default::default()); + let mut conn = Conn::<_, proto::Chunk, ClientTransaction>::new(io); conn.state.busy(); conn.state.writing = Writing::KeepAlive; @@ -1110,7 +1137,7 @@ mod tests { fn test_conn_read_body_end() { let _: Result<(), ()> = future::lazy(|| { let io = AsyncIo::new_buf(b"POST / HTTP/1.1\r\nContent-Length: 5\r\n\r\n12345".to_vec(), 1024); - let mut conn = Conn::<_, proto::Chunk, ServerTransaction>::new(io, Default::default()); + let mut conn = Conn::<_, proto::Chunk, ServerTransaction>::new(io); conn.state.busy(); match conn.poll() { @@ -1140,7 +1167,7 @@ mod tests { #[test] fn test_conn_closed_read() { let io = AsyncIo::new_buf(vec![], 0); - let mut conn = Conn::<_, proto::Chunk, ServerTransaction>::new(io, Default::default()); + let mut conn = Conn::<_, proto::Chunk, ServerTransaction>::new(io); conn.state.close(); match conn.poll().unwrap() { @@ -1155,7 +1182,7 @@ mod tests { let _ = pretty_env_logger::try_init(); let _: Result<(), ()> = future::lazy(|| { let io = AsyncIo::new_buf(vec![], 0); - let mut conn = Conn::<_, proto::Chunk, ServerTransaction>::new(io, Default::default()); + let mut conn = Conn::<_, proto::Chunk, ServerTransaction>::new(io); let max = super::super::io::DEFAULT_MAX_BUFFER_SIZE + 4096; conn.state.writing = Writing::Body(Encoder::length((max * 2) as u64)); @@ -1180,7 +1207,7 @@ mod tests { fn test_conn_body_write_chunked() { let _: Result<(), ()> = future::lazy(|| { let io = AsyncIo::new_buf(vec![], 4096); - let mut conn = Conn::<_, proto::Chunk, ServerTransaction>::new(io, Default::default()); + let mut conn = Conn::<_, proto::Chunk, ServerTransaction>::new(io); conn.state.writing = Writing::Body(Encoder::chunked()); assert!(conn.start_send(Frame::Body { chunk: Some("headers".into()) }).unwrap().is_ready()); @@ -1193,7 +1220,7 @@ mod tests { fn test_conn_body_flush() { let _: Result<(), ()> = future::lazy(|| { let io = AsyncIo::new_buf(vec![], 1024 * 1024 * 5); - let mut conn = Conn::<_, proto::Chunk, ServerTransaction>::new(io, Default::default()); + let mut conn = Conn::<_, proto::Chunk, ServerTransaction>::new(io); conn.state.writing = Writing::Body(Encoder::length(1024 * 1024)); assert!(conn.start_send(Frame::Body { chunk: Some(vec![b'a'; 1024 * 1024].into()) }).unwrap().is_ready()); assert!(!conn.can_buffer_body()); @@ -1230,7 +1257,7 @@ mod tests { // test that once writing is done, unparks let f = future::lazy(|| { let io = AsyncIo::new_buf(vec![], 4096); - let mut conn = Conn::<_, proto::Chunk, ServerTransaction>::new(io, Default::default()); + let mut conn = Conn::<_, proto::Chunk, ServerTransaction>::new(io); conn.state.reading = Reading::KeepAlive; assert!(conn.poll().unwrap().is_not_ready()); @@ -1244,7 +1271,7 @@ mod tests { // test that flushing when not waiting on read doesn't unpark let f = future::lazy(|| { let io = AsyncIo::new_buf(vec![], 4096); - let mut conn = Conn::<_, proto::Chunk, ServerTransaction>::new(io, Default::default()); + let mut conn = Conn::<_, proto::Chunk, ServerTransaction>::new(io); conn.state.writing = Writing::KeepAlive; assert!(conn.poll_complete().unwrap().is_ready()); Ok::<(), ()>(()) @@ -1255,7 +1282,7 @@ mod tests { // test that flushing and writing isn't done doesn't unpark let f = future::lazy(|| { let io = AsyncIo::new_buf(vec![], 4096); - let mut conn = Conn::<_, proto::Chunk, ServerTransaction>::new(io, Default::default()); + let mut conn = Conn::<_, proto::Chunk, ServerTransaction>::new(io); conn.state.reading = Reading::KeepAlive; assert!(conn.poll().unwrap().is_not_ready()); conn.state.writing = Writing::Body(Encoder::length(5_000)); @@ -1268,7 +1295,7 @@ mod tests { #[test] fn test_conn_closed_write() { let io = AsyncIo::new_buf(vec![], 0); - let mut conn = Conn::<_, proto::Chunk, ServerTransaction>::new(io, Default::default()); + let mut conn = Conn::<_, proto::Chunk, ServerTransaction>::new(io); conn.state.close(); match conn.start_send(Frame::Body { chunk: Some(b"foobar".to_vec().into()) }) { @@ -1282,7 +1309,7 @@ mod tests { #[test] fn test_conn_write_empty_chunk() { let io = AsyncIo::new_buf(vec![], 0); - let mut conn = Conn::<_, proto::Chunk, ServerTransaction>::new(io, Default::default()); + let mut conn = Conn::<_, proto::Chunk, ServerTransaction>::new(io); conn.state.writing = Writing::KeepAlive; assert!(conn.start_send(Frame::Body { chunk: None }).unwrap().is_ready()); diff --git a/src/proto/h1/dispatch.rs b/src/proto/h1/dispatch.rs index 03ec04acce..7023dc29cc 100644 --- a/src/proto/h1/dispatch.rs +++ b/src/proto/h1/dispatch.rs @@ -1,15 +1,16 @@ use std::io; +use bytes::Bytes; use futures::{Async, AsyncSink, Future, Poll, Stream}; use futures::sync::oneshot; use tokio_io::{AsyncRead, AsyncWrite}; use tokio_service::Service; -use proto::{Body, Conn, KeepAlive, Http1Transaction, MessageHead, RequestHead, ResponseHead}; +use proto::{Body, Conn, Http1Transaction, MessageHead, RequestHead, ResponseHead}; use ::StatusCode; -pub struct Dispatcher { - conn: Conn, +pub struct Dispatcher { + conn: Conn, dispatch: D, body_tx: Option<::proto::body::ChunkSender>, body_rx: Option, @@ -40,16 +41,15 @@ pub type ClientMsg = (RequestHead, Option); type ClientRx = ::client::dispatch::Receiver, ::Response>; -impl Dispatcher +impl Dispatcher where D: Dispatch, PollBody=Bs, RecvItem=MessageHead>, I: AsyncRead + AsyncWrite, B: AsRef<[u8]>, T: Http1Transaction, - K: KeepAlive, Bs: Stream, { - pub fn new(dispatch: D, conn: Conn) -> Self { + pub fn new(dispatch: D, conn: Conn) -> Self { Dispatcher { conn: conn, dispatch: dispatch, @@ -63,15 +63,44 @@ where self.conn.disable_keep_alive() } - fn poll2(&mut self) -> Poll<(), ::Error> { + pub fn into_inner(self) -> (I, Bytes) { + self.conn.into_inner() + } + + /// The "Future" poll function. Runs this dispatcher until the + /// connection is shutdown, or an error occurs. + pub fn poll_until_shutdown(&mut self) -> Poll<(), ::Error> { + self.poll_catch(true) + } + + /// Run this dispatcher until HTTP says this connection is done, + /// but don't call `AsyncWrite::shutdown` on the underlying IO. + /// + /// This is useful for HTTP upgrades. + pub fn poll_without_shutdown(&mut self) -> Poll<(), ::Error> { + self.poll_catch(false) + } + + fn poll_catch(&mut self, should_shutdown: bool) -> Poll<(), ::Error> { + self.poll_inner(should_shutdown).or_else(|e| { + // An error means we're shutting down either way. + // We just try to give the error to the user, + // and close the connection with an Ok. If we + // cannot give it to the user, then return the Err. + self.dispatch.recv_msg(Err(e)).map(Async::Ready) + }) + } + + fn poll_inner(&mut self, should_shutdown: bool) -> Poll<(), ::Error> { self.poll_read()?; self.poll_write()?; self.poll_flush()?; if self.is_done() { - try_ready!(self.conn.shutdown()); + if should_shutdown { + try_ready!(self.conn.shutdown()); + } self.conn.take_error()?; - trace!("Dispatch::poll done"); Ok(Async::Ready(())) } else { Ok(Async::NotReady) @@ -183,7 +212,7 @@ where loop { if self.is_closing { return Ok(Async::Ready(())); - } else if self.body_rx.is_none() && self.dispatch.should_poll() { + } else if self.body_rx.is_none() && self.conn.can_write_head() && self.dispatch.should_poll() { if let Some((head, body)) = try_ready!(self.dispatch.poll_msg()) { self.conn.write_head(head, body.is_some()); self.body_rx = body; @@ -257,13 +286,12 @@ where } -impl Future for Dispatcher +impl Future for Dispatcher where D: Dispatch, PollBody=Bs, RecvItem=MessageHead>, I: AsyncRead + AsyncWrite, B: AsRef<[u8]>, T: Http1Transaction, - K: KeepAlive, Bs: Stream, { type Item = (); @@ -271,14 +299,7 @@ where #[inline] fn poll(&mut self) -> Poll { - trace!("Dispatcher::poll"); - self.poll2().or_else(|e| { - // An error means we're shutting down either way. - // We just try to give the error to the user, - // and close the connection with an Ok. If we - // cannot give it to the user, then return the Err. - self.dispatch.recv_msg(Err(e)).map(Async::Ready) - }) + self.poll_until_shutdown() } } @@ -445,8 +466,8 @@ mod tests { let _ = pretty_env_logger::try_init(); ::futures::lazy(|| { let io = AsyncIo::new_buf(b"HTTP/1.1 200 OK\r\n\r\n".to_vec(), 100); - let (tx, rx) = ::client::dispatch::channel(); - let conn = Conn::<_, ::Chunk, ClientTransaction>::new(io, Default::default()); + let (mut tx, rx) = ::client::dispatch::channel(); + let conn = Conn::<_, ::Chunk, ClientTransaction>::new(io); let mut dispatcher = Dispatcher::new(Client::new(rx), conn); let req = RequestHead { diff --git a/src/proto/h1/io.rs b/src/proto/h1/io.rs index 585ea94b45..4b112a37e1 100644 --- a/src/proto/h1/io.rs +++ b/src/proto/h1/io.rs @@ -152,6 +152,10 @@ where }) } + pub fn into_inner(self) -> (T, Bytes) { + (self.io, self.read_buf.freeze()) + } + pub fn io_mut(&mut self) -> &mut T { &mut self.io } diff --git a/src/proto/h1/mod.rs b/src/proto/h1/mod.rs index 04e34c5806..4ff94ee8e0 100644 --- a/src/proto/h1/mod.rs +++ b/src/proto/h1/mod.rs @@ -1,4 +1,4 @@ -pub use self::conn::{Conn, KeepAlive, KA}; +pub use self::conn::Conn; pub use self::decode::Decoder; pub use self::encode::{EncodedBuf, Encoder}; diff --git a/src/proto/h1/role.rs b/src/proto/h1/role.rs index 40324e1a7f..de569e2270 100644 --- a/src/proto/h1/role.rs +++ b/src/proto/h1/role.rs @@ -5,8 +5,8 @@ use httparse; use bytes::{BytesMut, Bytes}; use header::{self, Headers, ContentLength, TransferEncoding}; -use proto::{MessageHead, RawStatus, Http1Transaction, ParseResult, - ServerTransaction, ClientTransaction, RequestLine, RequestHead}; +use proto::{Decode, MessageHead, RawStatus, Http1Transaction, ParseResult, + RequestLine, RequestHead}; use proto::h1::{Encoder, Decoder, date}; use method::Method; use status::StatusCode; @@ -15,7 +15,19 @@ use version::HttpVersion::{Http10, Http11}; const MAX_HEADERS: usize = 100; const AVERAGE_HEADER_SIZE: usize = 30; // totally scientific -impl Http1Transaction for ServerTransaction { +// There are 2 main roles, Client and Server. +// +// There is 1 modifier, OnUpgrade, which can wrap Client and Server, +// to signal that HTTP upgrades are not supported. + +pub struct Client(T); + +pub struct Server(T); + +impl Http1Transaction for Server +where + T: OnUpgrade, +{ type Incoming = RequestLine; type Outgoing = StatusCode; @@ -72,7 +84,7 @@ impl Http1Transaction for ServerTransaction { }, len))) } - fn decoder(head: &MessageHead, method: &mut Option) -> ::Result> { + fn decoder(head: &MessageHead, method: &mut Option) -> ::Result { use ::header; *method = Some(head.subject.0.clone()); @@ -95,37 +107,40 @@ impl Http1Transaction for ServerTransaction { debug!("HTTP/1.0 has Transfer-Encoding header"); Err(::Error::Header) } else if encodings.last() == Some(&header::Encoding::Chunked) { - Ok(Some(Decoder::chunked())) + Ok(Decode::Normal(Decoder::chunked())) } else { debug!("request with transfer-encoding header, but not chunked, bad request"); Err(::Error::Header) } } else if let Some(&header::ContentLength(len)) = head.headers.get() { - Ok(Some(Decoder::length(len))) + Ok(Decode::Normal(Decoder::length(len))) } else if head.headers.has::() { debug!("illegal Content-Length: {:?}", head.headers.get_raw("Content-Length")); Err(::Error::Header) } else { - Ok(Some(Decoder::length(0))) + Ok(Decode::Normal(Decoder::length(0))) } } fn encode(mut head: MessageHead, has_body: bool, method: &mut Option, dst: &mut Vec) -> ::Result { - trace!("ServerTransaction::encode has_body={}, method={:?}", has_body, method); + trace!("Server::encode has_body={}, method={:?}", has_body, method); // hyper currently doesn't support returning 1xx status codes as a Response // This is because Service only allows returning a single Response, and // so if you try to reply with a e.g. 100 Continue, you have no way of // replying with the latter status code response. - let ret = if head.subject.is_informational() { + let ret = if ::StatusCode::SwitchingProtocols == head.subject { + T::on_encode_upgrade(&mut head) + .map(|_| Server::set_length(&mut head, has_body, method.as_ref())) + } else if head.subject.is_informational() { error!("response with 1xx status code not supported"); head = MessageHead::default(); head.subject = ::StatusCode::InternalServerError; head.headers.set(ContentLength(0)); Err(::Error::Status) } else { - Ok(ServerTransaction::set_length(&mut head, has_body, method.as_ref())) + Ok(Server::set_length(&mut head, has_body, method.as_ref())) }; @@ -179,7 +194,7 @@ impl Http1Transaction for ServerTransaction { } } -impl ServerTransaction { +impl Server<()> { fn set_length(head: &mut MessageHead, has_body: bool, method: Option<&Method>) -> Encoder { // these are here thanks to borrowck // `if method == Some(&Method::Get)` says the RHS doesn't live long enough @@ -214,7 +229,10 @@ impl ServerTransaction { } } -impl Http1Transaction for ClientTransaction { +impl Http1Transaction for Client +where + T: OnUpgrade, +{ type Incoming = RawStatus; type Outgoing = RequestLine; @@ -262,7 +280,7 @@ impl Http1Transaction for ClientTransaction { }, len))) } - fn decoder(inc: &MessageHead, method: &mut Option) -> ::Result> { + fn decoder(inc: &MessageHead, method: &mut Option) -> ::Result { // According to https://tools.ietf.org/html/rfc7230#section-3.3.3 // 1. HEAD responses, and Status 1xx, 204, and 304 cannot have a body. // 2. Status 2xx to a CONNECT cannot have a body. @@ -274,30 +292,29 @@ impl Http1Transaction for ClientTransaction { match inc.subject.0 { 101 => { - debug!("received 101 upgrade response, not supported"); - return Err(::Error::Upgrade); + return T::on_decode_upgrade().map(Decode::Final); }, 100...199 => { trace!("ignoring informational response: {}", inc.subject.0); - return Ok(None); + return Ok(Decode::Ignore); }, 204 | - 304 => return Ok(Some(Decoder::length(0))), + 304 => return Ok(Decode::Normal(Decoder::length(0))), _ => (), } match *method { Some(Method::Head) => { - return Ok(Some(Decoder::length(0))); + return Ok(Decode::Normal(Decoder::length(0))); } Some(Method::Connect) => match inc.subject.0 { 200...299 => { - return Ok(Some(Decoder::length(0))); + return Ok(Decode::Final(Decoder::length(0))); }, _ => {}, }, Some(_) => {}, None => { - trace!("ClientTransaction::decoder is missing the Method"); + trace!("Client::decoder is missing the Method"); } } @@ -307,28 +324,28 @@ impl Http1Transaction for ClientTransaction { debug!("HTTP/1.0 has Transfer-Encoding header"); Err(::Error::Header) } else if codings.last() == Some(&header::Encoding::Chunked) { - Ok(Some(Decoder::chunked())) + Ok(Decode::Normal(Decoder::chunked())) } else { trace!("not chunked. read till eof"); - Ok(Some(Decoder::eof())) + Ok(Decode::Normal(Decoder::eof())) } } else if let Some(&header::ContentLength(len)) = inc.headers.get() { - Ok(Some(Decoder::length(len))) + Ok(Decode::Normal(Decoder::length(len))) } else if inc.headers.has::() { debug!("illegal Content-Length: {:?}", inc.headers.get_raw("Content-Length")); Err(::Error::Header) } else { trace!("neither Transfer-Encoding nor Content-Length"); - Ok(Some(Decoder::eof())) + Ok(Decode::Normal(Decoder::eof())) } } fn encode(mut head: MessageHead, has_body: bool, method: &mut Option, dst: &mut Vec) -> ::Result { - trace!("ClientTransaction::encode has_body={}, method={:?}", has_body, method); + trace!("Client::encode has_body={}, method={:?}", has_body, method); *method = Some(head.subject.0.clone()); - let body = ClientTransaction::set_length(&mut head, has_body); + let body = Client::set_length(&mut head, has_body); let init_cap = 30 + head.headers.len() * AVERAGE_HEADER_SIZE; dst.reserve(init_cap); @@ -351,7 +368,7 @@ impl Http1Transaction for ClientTransaction { } } -impl ClientTransaction { +impl Client<()> { fn set_length(head: &mut RequestHead, has_body: bool) -> Encoder { if has_body { let can_chunked = head.version == Http11 @@ -393,6 +410,42 @@ fn set_length(headers: &mut Headers, can_chunked: bool) -> Encoder { } } +pub trait OnUpgrade { + fn on_encode_upgrade(head: &mut MessageHead) -> ::Result<()>; + fn on_decode_upgrade() -> ::Result; +} + +pub enum YesUpgrades {} + +pub enum NoUpgrades {} + +impl OnUpgrade for YesUpgrades { + fn on_encode_upgrade(_head: &mut MessageHead) -> ::Result<()> { + Ok(()) + } + + fn on_decode_upgrade() -> ::Result { + debug!("101 response received, upgrading"); + // 101 upgrades always have no body + Ok(Decoder::length(0)) + } +} + +impl OnUpgrade for NoUpgrades { + fn on_encode_upgrade(head: &mut MessageHead) -> ::Result<()> { + error!("response with 101 status code not supported"); + *head = MessageHead::default(); + head.subject = ::StatusCode::InternalServerError; + head.headers.set(ContentLength(0)); + Err(::Error::Status) + } + + fn on_decode_upgrade() -> ::Result { + debug!("received 101 upgrade response, not supported"); + return Err(::Error::Upgrade); + } +} + #[derive(Clone, Copy)] struct HeaderIndices { name: (usize, usize), @@ -456,16 +509,43 @@ fn extend(dst: &mut Vec, data: &[u8]) { mod tests { use bytes::BytesMut; - use proto::{MessageHead, ServerTransaction, ClientTransaction, Http1Transaction}; + use proto::{Decode, MessageHead}; + use super::{Decoder, Server as S, Client as C, NoUpgrades, Http1Transaction}; use header::{ContentLength, TransferEncoding}; + type Server = S; + type Client = C; + + impl Decode { + fn final_(self) -> Decoder { + match self { + Decode::Final(d) => d, + other => panic!("expected Final, found {:?}", other), + } + } + + fn normal(self) -> Decoder { + match self { + Decode::Normal(d) => d, + other => panic!("expected Normal, found {:?}", other), + } + } + + fn ignore(self) { + match self { + Decode::Ignore => {}, + other => panic!("expected Ignore, found {:?}", other), + } + } + } + #[test] fn test_parse_request() { extern crate pretty_env_logger; let _ = pretty_env_logger::try_init(); let mut raw = BytesMut::from(b"GET /echo HTTP/1.1\r\nHost: hyper.rs\r\n\r\n".to_vec()); let expected_len = raw.len(); - let (req, len) = ServerTransaction::parse(&mut raw).unwrap().unwrap(); + let (req, len) = Server::parse(&mut raw).unwrap().unwrap(); assert_eq!(len, expected_len); assert_eq!(req.subject.0, ::Method::Get); assert_eq!(req.subject.1, "/echo"); @@ -481,7 +561,7 @@ mod tests { let _ = pretty_env_logger::try_init(); let mut raw = BytesMut::from(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n".to_vec()); let expected_len = raw.len(); - let (req, len) = ClientTransaction::parse(&mut raw).unwrap().unwrap(); + let (req, len) = Client::parse(&mut raw).unwrap().unwrap(); assert_eq!(len, expected_len); assert_eq!(req.subject.0, 200); assert_eq!(req.subject.1, "OK"); @@ -493,17 +573,17 @@ mod tests { #[test] fn test_parse_request_errors() { let mut raw = BytesMut::from(b"GET htt:p// HTTP/1.1\r\nHost: hyper.rs\r\n\r\n".to_vec()); - ServerTransaction::parse(&mut raw).unwrap_err(); + Server::parse(&mut raw).unwrap_err(); } #[test] fn test_parse_raw_status() { let mut raw = BytesMut::from(b"HTTP/1.1 200 OK\r\n\r\n".to_vec()); - let (res, _) = ClientTransaction::parse(&mut raw).unwrap().unwrap(); + let (res, _) = Client::parse(&mut raw).unwrap().unwrap(); assert_eq!(res.subject.1, "OK"); let mut raw = BytesMut::from(b"HTTP/1.1 200 Howdy\r\n\r\n".to_vec()); - let (res, _) = ClientTransaction::parse(&mut raw).unwrap().unwrap(); + let (res, _) = Client::parse(&mut raw).unwrap().unwrap(); assert_eq!(res.subject.1, "Howdy"); } @@ -516,32 +596,32 @@ mod tests { let mut head = MessageHead::<::proto::RequestLine>::default(); head.subject.0 = ::Method::Get; - assert_eq!(Decoder::length(0), ServerTransaction::decoder(&head, method).unwrap().unwrap()); + assert_eq!(Decoder::length(0), Server::decoder(&head, method).unwrap().normal()); assert_eq!(*method, Some(::Method::Get)); head.subject.0 = ::Method::Post; - assert_eq!(Decoder::length(0), ServerTransaction::decoder(&head, method).unwrap().unwrap()); + assert_eq!(Decoder::length(0), Server::decoder(&head, method).unwrap().normal()); assert_eq!(*method, Some(::Method::Post)); head.headers.set(TransferEncoding::chunked()); - assert_eq!(Decoder::chunked(), ServerTransaction::decoder(&head, method).unwrap().unwrap()); + assert_eq!(Decoder::chunked(), Server::decoder(&head, method).unwrap().normal()); // transfer-encoding and content-length = chunked head.headers.set(ContentLength(10)); - assert_eq!(Decoder::chunked(), ServerTransaction::decoder(&head, method).unwrap().unwrap()); + assert_eq!(Decoder::chunked(), Server::decoder(&head, method).unwrap().normal()); head.headers.remove::(); - assert_eq!(Decoder::length(10), ServerTransaction::decoder(&head, method).unwrap().unwrap()); + assert_eq!(Decoder::length(10), Server::decoder(&head, method).unwrap().normal()); head.headers.set_raw("Content-Length", vec![b"5".to_vec(), b"5".to_vec()]); - assert_eq!(Decoder::length(5), ServerTransaction::decoder(&head, method).unwrap().unwrap()); + assert_eq!(Decoder::length(5), Server::decoder(&head, method).unwrap().normal()); head.headers.set_raw("Content-Length", vec![b"10".to_vec(), b"11".to_vec()]); - ServerTransaction::decoder(&head, method).unwrap_err(); + Server::decoder(&head, method).unwrap_err(); head.headers.remove::(); head.headers.set_raw("Transfer-Encoding", "gzip"); - ServerTransaction::decoder(&head, method).unwrap_err(); + Server::decoder(&head, method).unwrap_err(); // http/1.0 @@ -549,14 +629,14 @@ mod tests { head.headers.clear(); // 1.0 requests can only have bodies if content-length is set - assert_eq!(Decoder::length(0), ServerTransaction::decoder(&head, method).unwrap().unwrap()); + assert_eq!(Decoder::length(0), Server::decoder(&head, method).unwrap().normal()); head.headers.set(TransferEncoding::chunked()); - ServerTransaction::decoder(&head, method).unwrap_err(); + Server::decoder(&head, method).unwrap_err(); head.headers.remove::(); head.headers.set(ContentLength(15)); - assert_eq!(Decoder::length(15), ServerTransaction::decoder(&head, method).unwrap().unwrap()); + assert_eq!(Decoder::length(15), Server::decoder(&head, method).unwrap().normal()); } #[test] @@ -567,64 +647,64 @@ mod tests { let mut head = MessageHead::<::proto::RawStatus>::default(); head.subject.0 = 204; - assert_eq!(Decoder::length(0), ClientTransaction::decoder(&head, method).unwrap().unwrap()); + assert_eq!(Decoder::length(0), Client::decoder(&head, method).unwrap().normal()); head.subject.0 = 304; - assert_eq!(Decoder::length(0), ClientTransaction::decoder(&head, method).unwrap().unwrap()); + assert_eq!(Decoder::length(0), Client::decoder(&head, method).unwrap().normal()); head.subject.0 = 200; - assert_eq!(Decoder::eof(), ClientTransaction::decoder(&head, method).unwrap().unwrap()); + assert_eq!(Decoder::eof(), Client::decoder(&head, method).unwrap().normal()); *method = Some(::Method::Head); - assert_eq!(Decoder::length(0), ClientTransaction::decoder(&head, method).unwrap().unwrap()); + assert_eq!(Decoder::length(0), Client::decoder(&head, method).unwrap().normal()); *method = Some(::Method::Connect); - assert_eq!(Decoder::length(0), ClientTransaction::decoder(&head, method).unwrap().unwrap()); + assert_eq!(Decoder::length(0), Client::decoder(&head, method).unwrap().final_()); // CONNECT receiving non 200 can have a body head.subject.0 = 404; head.headers.set(ContentLength(10)); - assert_eq!(Decoder::length(10), ClientTransaction::decoder(&head, method).unwrap().unwrap()); + assert_eq!(Decoder::length(10), Client::decoder(&head, method).unwrap().normal()); head.headers.remove::(); *method = Some(::Method::Get); head.headers.set(TransferEncoding::chunked()); - assert_eq!(Decoder::chunked(), ClientTransaction::decoder(&head, method).unwrap().unwrap()); + assert_eq!(Decoder::chunked(), Client::decoder(&head, method).unwrap().normal()); // transfer-encoding and content-length = chunked head.headers.set(ContentLength(10)); - assert_eq!(Decoder::chunked(), ClientTransaction::decoder(&head, method).unwrap().unwrap()); + assert_eq!(Decoder::chunked(), Client::decoder(&head, method).unwrap().normal()); head.headers.remove::(); - assert_eq!(Decoder::length(10), ClientTransaction::decoder(&head, method).unwrap().unwrap()); + assert_eq!(Decoder::length(10), Client::decoder(&head, method).unwrap().normal()); head.headers.set_raw("Content-Length", vec![b"5".to_vec(), b"5".to_vec()]); - assert_eq!(Decoder::length(5), ClientTransaction::decoder(&head, method).unwrap().unwrap()); + assert_eq!(Decoder::length(5), Client::decoder(&head, method).unwrap().normal()); head.headers.set_raw("Content-Length", vec![b"10".to_vec(), b"11".to_vec()]); - ClientTransaction::decoder(&head, method).unwrap_err(); + Client::decoder(&head, method).unwrap_err(); head.headers.clear(); // 1xx status codes head.subject.0 = 100; - assert!(ClientTransaction::decoder(&head, method).unwrap().is_none()); + Client::decoder(&head, method).unwrap().ignore(); head.subject.0 = 103; - assert!(ClientTransaction::decoder(&head, method).unwrap().is_none()); + Client::decoder(&head, method).unwrap().ignore(); // 101 upgrade not supported yet head.subject.0 = 101; - ClientTransaction::decoder(&head, method).unwrap_err(); + Client::decoder(&head, method).unwrap_err(); head.subject.0 = 200; // http/1.0 head.version = ::HttpVersion::Http10; - assert_eq!(Decoder::eof(), ClientTransaction::decoder(&head, method).unwrap().unwrap()); + assert_eq!(Decoder::eof(), Client::decoder(&head, method).unwrap().normal()); head.headers.set(TransferEncoding::chunked()); - ClientTransaction::decoder(&head, method).unwrap_err(); + Client::decoder(&head, method).unwrap_err(); } #[cfg(feature = "nightly")] @@ -656,7 +736,7 @@ mod tests { b.bytes = len as u64; b.iter(|| { - ServerTransaction::parse(&mut raw).unwrap(); + Server::parse(&mut raw).unwrap(); restart(&mut raw, len); }); @@ -688,7 +768,7 @@ mod tests { b.iter(|| { let mut vec = Vec::new(); - ServerTransaction::encode(head.clone(), true, &mut None, &mut vec).unwrap(); + Server::encode(head.clone(), true, &mut None, &mut vec).unwrap(); assert_eq!(vec.len(), len); ::test::black_box(vec); }) diff --git a/src/proto/mod.rs b/src/proto/mod.rs index 24f077afba..49ed32100f 100644 --- a/src/proto/mod.rs +++ b/src/proto/mod.rs @@ -16,7 +16,7 @@ pub use self::body::Body; #[cfg(feature = "tokio-proto")] pub use self::body::TokioBody; pub use self::chunk::Chunk; -pub use self::h1::{dispatch, Conn, KeepAlive, KA}; +pub use self::h1::{dispatch, Conn}; mod body; mod chunk; @@ -134,17 +134,16 @@ pub fn expecting_continue(version: HttpVersion, headers: &Headers) -> bool { ret } -#[derive(Debug)] -pub enum ServerTransaction {} +pub type ServerTransaction = h1::role::Server; -#[derive(Debug)] -pub enum ClientTransaction {} +pub type ClientTransaction = h1::role::Client; +pub type ClientUpgradeTransaction = h1::role::Client; pub trait Http1Transaction { type Incoming; type Outgoing: Default; fn parse(bytes: &mut BytesMut) -> ParseResult; - fn decoder(head: &MessageHead, method: &mut Option<::Method>) -> ::Result>; + fn decoder(head: &MessageHead, method: &mut Option<::Method>) -> ::Result; fn encode(head: MessageHead, has_body: bool, method: &mut Option, dst: &mut Vec) -> ::Result; fn on_error(err: &::Error) -> Option>; @@ -154,6 +153,16 @@ pub trait Http1Transaction { pub type ParseResult = ::Result, usize)>>; +#[derive(Debug)] +pub enum Decode { + /// Decode normally. + Normal(h1::Decoder), + /// After this decoder is done, HTTP is done. + Final(h1::Decoder), + /// A header block that should be ignored, like unknown 1xx responses. + Ignore, +} + #[test] fn test_should_keep_alive() { let mut headers = Headers::new(); diff --git a/src/server/mod.rs b/src/server/mod.rs index 58883fb775..b1416d16f0 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -134,7 +134,6 @@ where I, ::Item, proto::ServerTransaction, - proto::KA, >, } @@ -310,12 +309,10 @@ impl + 'static> Http { I: AsyncRead + AsyncWrite, { - let ka = if self.keep_alive { - proto::KA::Busy - } else { - proto::KA::Disabled - }; - let mut conn = proto::Conn::new(io, ka); + let mut conn = proto::Conn::new(io); + if !self.keep_alive { + conn.disable_keep_alive(); + } conn.set_flush_pipeline(self.pipeline); if let Some(max) = self.max_buf_size { conn.set_max_buf_size(max); diff --git a/src/server/server_proto.rs b/src/server/server_proto.rs index 484f8ecf2d..8e46d3442a 100644 --- a/src/server/server_proto.rs +++ b/src/server/server_proto.rs @@ -99,12 +99,10 @@ impl ServerProto for Http #[inline] fn bind_transport(&self, io: T) -> Self::BindTransport { - let ka = if self.keep_alive { - proto::KA::Busy - } else { - proto::KA::Disabled - }; - let mut conn = proto::Conn::new(io, ka); + let mut conn = proto::Conn::new(io); + if !self.keep_alive { + conn.disable_keep_alive(); + } conn.set_flush_pipeline(self.pipeline); if let Some(max) = self.max_buf_size { conn.set_max_buf_size(max); diff --git a/tests/client.rs b/tests/client.rs index 5883677033..ca41416f9b 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -737,7 +737,7 @@ mod dispatch_impl { use std::time::Duration; use futures::{self, Future}; - use futures::sync::oneshot; + use futures::sync::{mpsc, oneshot}; use tokio_core::reactor::{Timeout}; use tokio_core::net::TcpStream; use tokio_io::{AsyncRead, AsyncWrite}; @@ -758,9 +758,9 @@ mod dispatch_impl { let addr = server.local_addr().unwrap(); let mut core = Core::new().unwrap(); let handle = core.handle(); - let closes = Arc::new(AtomicUsize::new(0)); + let (closes_tx, closes) = mpsc::channel(10); let client = Client::configure() - .connector(DebugConnector::with_http_and_closes(HttpConnector::new(1, &core.handle()), closes.clone())) + .connector(DebugConnector::with_http_and_closes(HttpConnector::new(1, &core.handle()), closes_tx)) .build(&handle); let (tx1, rx1) = oneshot::channel(); @@ -787,7 +787,7 @@ mod dispatch_impl { let rx = rx1.map_err(|_| hyper::Error::Io(io::Error::new(io::ErrorKind::Other, "thread panicked"))); core.run(res.join(rx).map(|r| r.0)).unwrap(); - assert_eq!(closes.load(Ordering::Relaxed), 1); + core.run(closes.into_future()).unwrap().0.expect("closes"); } #[test] @@ -799,7 +799,7 @@ mod dispatch_impl { let addr = server.local_addr().unwrap(); let mut core = Core::new().unwrap(); let handle = core.handle(); - let closes = Arc::new(AtomicUsize::new(0)); + let (closes_tx, closes) = mpsc::channel(10); let (tx1, rx1) = oneshot::channel(); @@ -819,7 +819,7 @@ mod dispatch_impl { let res = { let client = Client::configure() - .connector(DebugConnector::with_http_and_closes(HttpConnector::new(1, &handle), closes.clone())) + .connector(DebugConnector::with_http_and_closes(HttpConnector::new(1, &handle), closes_tx)) .build(&handle); client.get(uri).and_then(move |res| { assert_eq!(res.status(), hyper::StatusCode::Ok); @@ -833,7 +833,7 @@ mod dispatch_impl { let rx = rx1.map_err(|_| hyper::Error::Io(io::Error::new(io::ErrorKind::Other, "thread panicked"))); core.run(res.join(rx).map(|r| r.0)).unwrap(); - assert_eq!(closes.load(Ordering::Relaxed), 1); + core.run(closes.into_future()).unwrap().0.expect("closes"); } @@ -845,7 +845,7 @@ mod dispatch_impl { let addr = server.local_addr().unwrap(); let mut core = Core::new().unwrap(); let handle = core.handle(); - let closes = Arc::new(AtomicUsize::new(0)); + let (closes_tx, mut closes) = mpsc::channel(10); let (tx1, rx1) = oneshot::channel(); let (_client_drop_tx, client_drop_rx) = oneshot::channel::<()>(); @@ -869,7 +869,7 @@ mod dispatch_impl { let uri = format!("http://{}/a", addr).parse().unwrap(); let client = Client::configure() - .connector(DebugConnector::with_http_and_closes(HttpConnector::new(1, &handle), closes.clone())) + .connector(DebugConnector::with_http_and_closes(HttpConnector::new(1, &handle), closes_tx)) .build(&handle); let res = client.get(uri).and_then(move |res| { assert_eq!(res.status(), hyper::StatusCode::Ok); @@ -879,13 +879,25 @@ mod dispatch_impl { core.run(res.join(rx).map(|r| r.0)).unwrap(); // not closed yet, just idle - assert_eq!(closes.load(Ordering::Relaxed), 0); + { + futures::future::poll_fn(|| { + assert!(closes.poll()?.is_not_ready()); + Ok::<_, ()>(().into()) + }).wait().unwrap(); + } drop(client); - core.run(Timeout::new(Duration::from_millis(100), &handle).unwrap()).unwrap(); - assert_eq!(closes.load(Ordering::Relaxed), 1); + let t = Timeout::new(Duration::from_millis(100), &handle).unwrap() + .map(|_| panic!("time out")); + let close = closes.into_future() + .map(|(opt, _)| { + opt.expect("closes"); + }) + .map_err(|_| panic!("closes dropped")); + let _ = core.run(t.select(close)); } + #[test] fn drop_response_future_closes_in_progress_connection() { let _ = pretty_env_logger::try_init(); @@ -894,7 +906,7 @@ mod dispatch_impl { let addr = server.local_addr().unwrap(); let mut core = Core::new().unwrap(); let handle = core.handle(); - let closes = Arc::new(AtomicUsize::new(0)); + let (closes_tx, closes) = mpsc::channel(10); let (tx1, rx1) = oneshot::channel(); let (_client_drop_tx, client_drop_rx) = oneshot::channel::<()>(); @@ -918,7 +930,7 @@ mod dispatch_impl { let res = { let client = Client::configure() - .connector(DebugConnector::with_http_and_closes(HttpConnector::new(1, &handle), closes.clone())) + .connector(DebugConnector::with_http_and_closes(HttpConnector::new(1, &handle), closes_tx)) .build(&handle); client.get(uri) }; @@ -926,9 +938,14 @@ mod dispatch_impl { //let rx = rx1.map_err(|_| hyper::Error::Io(io::Error::new(io::ErrorKind::Other, "thread panicked"))); core.run(res.select2(rx1)).unwrap(); // res now dropped - core.run(Timeout::new(Duration::from_millis(100), &handle).unwrap()).unwrap(); - - assert_eq!(closes.load(Ordering::Relaxed), 1); + let t = Timeout::new(Duration::from_millis(100), &handle).unwrap() + .map(|_| panic!("time out")); + let close = closes.into_future() + .map(|(opt, _)| { + opt.expect("closes"); + }) + .map_err(|_| panic!("closes dropped")); + let _ = core.run(t.select(close)); } #[test] @@ -939,7 +956,7 @@ mod dispatch_impl { let addr = server.local_addr().unwrap(); let mut core = Core::new().unwrap(); let handle = core.handle(); - let closes = Arc::new(AtomicUsize::new(0)); + let (closes_tx, closes) = mpsc::channel(10); let (tx1, rx1) = oneshot::channel(); let (_client_drop_tx, client_drop_rx) = oneshot::channel::<()>(); @@ -962,7 +979,7 @@ mod dispatch_impl { let res = { let client = Client::configure() - .connector(DebugConnector::with_http_and_closes(HttpConnector::new(1, &handle), closes.clone())) + .connector(DebugConnector::with_http_and_closes(HttpConnector::new(1, &handle), closes_tx)) .build(&handle); // notably, havent read body yet client.get(uri) @@ -970,9 +987,15 @@ mod dispatch_impl { let rx = rx1.map_err(|_| hyper::Error::Io(io::Error::new(io::ErrorKind::Other, "thread panicked"))); core.run(res.join(rx).map(|r| r.0)).unwrap(); - core.run(Timeout::new(Duration::from_millis(100), &handle).unwrap()).unwrap(); - assert_eq!(closes.load(Ordering::Relaxed), 1); + let t = Timeout::new(Duration::from_millis(100), &handle).unwrap() + .map(|_| panic!("time out")); + let close = closes.into_future() + .map(|(opt, _)| { + opt.expect("closes"); + }) + .map_err(|_| panic!("closes dropped")); + let _ = core.run(t.select(close)); } #[test] @@ -984,7 +1007,7 @@ mod dispatch_impl { let addr = server.local_addr().unwrap(); let mut core = Core::new().unwrap(); let handle = core.handle(); - let closes = Arc::new(AtomicUsize::new(0)); + let (closes_tx, closes) = mpsc::channel(10); let (tx1, rx1) = oneshot::channel(); @@ -1001,7 +1024,7 @@ mod dispatch_impl { let uri = format!("http://{}/a", addr).parse().unwrap(); let client = Client::configure() - .connector(DebugConnector::with_http_and_closes(HttpConnector::new(1, &handle), closes.clone())) + .connector(DebugConnector::with_http_and_closes(HttpConnector::new(1, &handle), closes_tx)) .keep_alive(false) .build(&handle); let res = client.get(uri).and_then(move |res| { @@ -1011,7 +1034,14 @@ mod dispatch_impl { let rx = rx1.map_err(|_| hyper::Error::Io(io::Error::new(io::ErrorKind::Other, "thread panicked"))); core.run(res.join(rx).map(|r| r.0)).unwrap(); - assert_eq!(closes.load(Ordering::Relaxed), 1); + let t = Timeout::new(Duration::from_millis(100), &handle).unwrap() + .map(|_| panic!("time out")); + let close = closes.into_future() + .map(|(opt, _)| { + opt.expect("closes"); + }) + .map_err(|_| panic!("closes dropped")); + let _ = core.run(t.select(close)); } #[test] @@ -1023,7 +1053,7 @@ mod dispatch_impl { let addr = server.local_addr().unwrap(); let mut core = Core::new().unwrap(); let handle = core.handle(); - let closes = Arc::new(AtomicUsize::new(0)); + let (closes_tx, closes) = mpsc::channel(10); let (tx1, rx1) = oneshot::channel(); @@ -1040,19 +1070,24 @@ mod dispatch_impl { let uri = format!("http://{}/a", addr).parse().unwrap(); let client = Client::configure() - .connector(DebugConnector::with_http_and_closes(HttpConnector::new(1, &handle), closes.clone())) + .connector(DebugConnector::with_http_and_closes(HttpConnector::new(1, &handle), closes_tx)) .build(&handle); let res = client.get(uri).and_then(move |res| { assert_eq!(res.status(), hyper::StatusCode::Ok); res.body().concat2() }); let rx = rx1.map_err(|_| hyper::Error::Io(io::Error::new(io::ErrorKind::Other, "thread panicked"))); - - let timeout = Timeout::new(Duration::from_millis(200), &handle).unwrap(); - let rx = rx.and_then(move |_| timeout.map_err(|e| e.into())); core.run(res.join(rx).map(|r| r.0)).unwrap(); - assert_eq!(closes.load(Ordering::Relaxed), 1); + + let t = Timeout::new(Duration::from_millis(100), &handle).unwrap() + .map(|_| panic!("time out")); + let close = closes.into_future() + .map(|(opt, _)| { + opt.expect("closes"); + }) + .map_err(|_| panic!("closes dropped")); + let _ = core.run(t.select(close)); } #[test] @@ -1113,7 +1148,7 @@ mod dispatch_impl { let addr = server.local_addr().unwrap(); let mut core = Core::new().unwrap(); let handle = core.handle(); - let closes = Arc::new(AtomicUsize::new(0)); + let (closes_tx, closes) = mpsc::channel(10); let (tx1, rx1) = oneshot::channel(); @@ -1130,7 +1165,7 @@ mod dispatch_impl { let uri = format!("http://{}/a", addr).parse().unwrap(); let client = Client::configure() - .connector(DebugConnector::with_http_and_closes(HttpConnector::new(1, &handle), closes.clone())) + .connector(DebugConnector::with_http_and_closes(HttpConnector::new(1, &handle), closes_tx)) .executor(handle.clone()); let res = client.get(uri).and_then(move |res| { assert_eq!(res.status(), hyper::StatusCode::Ok); @@ -1142,75 +1177,58 @@ mod dispatch_impl { let rx = rx.and_then(move |_| timeout.map_err(|e| e.into())); core.run(res.join(rx).map(|r| r.0)).unwrap(); - assert_eq!(closes.load(Ordering::Relaxed), 1); + + let t = Timeout::new(Duration::from_millis(100), &handle).unwrap() + .map(|_| panic!("time out")); + let close = closes.into_future() + .map(|(opt, _)| { + opt.expect("closes"); + }) + .map_err(|_| panic!("closes dropped")); + let _ = core.run(t.select(close)); } #[test] - fn idle_conn_prevents_connect_call() { + fn connect_call_is_lazy() { + // We especially don't want connects() triggered if there's + // idle connections that the Checkout would have found let _ = pretty_env_logger::try_init(); let server = TcpListener::bind("127.0.0.1:0").unwrap(); let addr = server.local_addr().unwrap(); - let mut core = Core::new().unwrap(); + let core = Core::new().unwrap(); let handle = core.handle(); let connector = DebugConnector::new(&handle); let connects = connector.connects.clone(); - let (tx1, rx1) = oneshot::channel(); - let (tx2, rx2) = oneshot::channel(); - - thread::spawn(move || { - let mut sock = server.accept().unwrap().0; - sock.set_read_timeout(Some(Duration::from_secs(5))).unwrap(); - sock.set_write_timeout(Some(Duration::from_secs(5))).unwrap(); - let mut buf = [0; 4096]; - sock.read(&mut buf).expect("read 1"); - sock.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n").unwrap(); - let _ = tx1.send(()); - - sock.read(&mut buf).expect("read 2"); - sock.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n").unwrap(); - let _ = tx2.send(()); - }); - let uri: hyper::Uri = format!("http://{}/a", addr).parse().unwrap(); let client = Client::configure() .connector(connector) .build(&handle); - let res = client.get(uri.clone()).and_then(move |res| { - assert_eq!(res.status(), hyper::StatusCode::Ok); - res.body().concat2() - }); - let rx = rx1.map_err(|_| hyper::Error::Io(io::Error::new(io::ErrorKind::Other, "thread panicked"))); - core.run(res.join(rx).map(|r| r.0)).unwrap(); - assert_eq!(connects.load(Ordering::Relaxed), 1); - - let res2 = client.get(uri).and_then(move |res| { - assert_eq!(res.status(), hyper::StatusCode::Ok); - res.body().concat2() - }); - let rx = rx2.map_err(|_| hyper::Error::Io(io::Error::new(io::ErrorKind::Other, "thread panicked"))); - core.run(res2.join(rx).map(|r| r.0)).unwrap(); - - assert_eq!(connects.load(Ordering::Relaxed), 1); + assert_eq!(connects.load(Ordering::Relaxed), 0); + let _fut = client.get(uri); + // internal Connect::connect should have been lazy, and not + // triggered an actual connect yet. + assert_eq!(connects.load(Ordering::Relaxed), 0); } struct DebugConnector { http: HttpConnector, - closes: Arc, + closes: mpsc::Sender<()>, connects: Arc, } impl DebugConnector { fn new(handle: &Handle) -> DebugConnector { let http = HttpConnector::new(1, handle); - DebugConnector::with_http_and_closes(http, Arc::new(AtomicUsize::new(0))) + let (tx, _) = mpsc::channel(10); + DebugConnector::with_http_and_closes(http, tx) } - fn with_http_and_closes(http: HttpConnector, closes: Arc) -> DebugConnector { + fn with_http_and_closes(http: HttpConnector, closes: mpsc::Sender<()>) -> DebugConnector { DebugConnector { http: http, closes: closes, @@ -1234,11 +1252,11 @@ mod dispatch_impl { } } - struct DebugStream(TcpStream, Arc); + struct DebugStream(TcpStream, mpsc::Sender<()>); impl Drop for DebugStream { fn drop(&mut self) { - self.1.fetch_add(1, Ordering::SeqCst); + let _ = self.1.try_send(()); } } @@ -1266,3 +1284,301 @@ mod dispatch_impl { impl AsyncRead for DebugStream {} } + +mod conn { + use std::io::{self, Read, Write}; + use std::net::TcpListener; + use std::thread; + use std::time::Duration; + + use futures::{Async, Future, Poll, Stream}; + use futures::future::poll_fn; + use futures::sync::oneshot; + use tokio_core::reactor::{Core, Timeout}; + use tokio_core::net::TcpStream; + use tokio_io::{AsyncRead, AsyncWrite}; + + use hyper::{self, Method, Request}; + use hyper::client::conn; + + #[test] + fn get() { + let server = TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = server.local_addr().unwrap(); + let mut core = Core::new().unwrap(); + let handle = core.handle(); + + let (tx1, rx1) = oneshot::channel(); + + thread::spawn(move || { + let mut sock = server.accept().unwrap().0; + sock.set_read_timeout(Some(Duration::from_secs(5))).unwrap(); + sock.set_write_timeout(Some(Duration::from_secs(5))).unwrap(); + let mut buf = [0; 4096]; + sock.read(&mut buf).expect("read 1"); + sock.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n").unwrap(); + let _ = tx1.send(()); + }); + + let tcp = core.run(TcpStream::connect(&addr, &handle)).unwrap(); + + let (mut client, conn) = core.run(conn::handshake(tcp)).unwrap(); + + handle.spawn(conn.map(|_| ()).map_err(|e| panic!("conn error: {}", e))); + + let uri = format!("http://{}/a", addr).parse().unwrap(); + let req = Request::new(Method::Get, uri); + + let res = client.send_request(req).and_then(move |res| { + assert_eq!(res.status(), hyper::StatusCode::Ok); + res.body().concat2() + }); + let rx = rx1.map_err(|_| hyper::Error::Io(io::Error::new(io::ErrorKind::Other, "thread panicked"))); + + let timeout = Timeout::new(Duration::from_millis(200), &handle).unwrap(); + let rx = rx.and_then(move |_| timeout.map_err(|e| e.into())); + core.run(res.join(rx).map(|r| r.0)).unwrap(); + } + + #[test] + fn pipeline() { + let server = TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = server.local_addr().unwrap(); + let mut core = Core::new().unwrap(); + let handle = core.handle(); + + let (tx1, rx1) = oneshot::channel(); + + thread::spawn(move || { + let mut sock = server.accept().unwrap().0; + sock.set_read_timeout(Some(Duration::from_secs(5))).unwrap(); + sock.set_write_timeout(Some(Duration::from_secs(5))).unwrap(); + let mut buf = [0; 4096]; + sock.read(&mut buf).expect("read 1"); + sock.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n").unwrap(); + let _ = tx1.send(()); + }); + + let tcp = core.run(TcpStream::connect(&addr, &handle)).unwrap(); + + let (mut client, conn) = core.run(conn::handshake(tcp)).unwrap(); + + handle.spawn(conn.map(|_| ()).map_err(|e| panic!("conn error: {}", e))); + + let uri = format!("http://{}/a", addr).parse().unwrap(); + let req = Request::new(Method::Get, uri); + let res1 = client.send_request(req).and_then(move |res| { + assert_eq!(res.status(), hyper::StatusCode::Ok); + res.body().concat2() + }); + + // pipelined request will hit NotReady, and thus should return an Error::Cancel + let uri = format!("http://{}/b", addr).parse().unwrap(); + let req = Request::new(Method::Get, uri); + let res2 = client.send_request(req) + .then(|result| { + let err = result.expect_err("res2"); + match err { + hyper::Error::Cancel(..) => (), + other => panic!("expected Cancel, found {:?}", other), + } + Ok(()) + }); + + let rx = rx1.map_err(|_| hyper::Error::Io(io::Error::new(io::ErrorKind::Other, "thread panicked"))); + + let timeout = Timeout::new(Duration::from_millis(200), &handle).unwrap(); + let rx = rx.and_then(move |_| timeout.map_err(|e| e.into())); + core.run(res1.join(res2).join(rx).map(|r| r.0)).unwrap(); + } + + #[test] + fn upgrade() { + use tokio_io::io::{read_to_end, write_all}; + let _ = ::pretty_env_logger::try_init(); + + let server = TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = server.local_addr().unwrap(); + let mut core = Core::new().unwrap(); + let handle = core.handle(); + + let (tx1, rx1) = oneshot::channel(); + + thread::spawn(move || { + let mut sock = server.accept().unwrap().0; + sock.set_read_timeout(Some(Duration::from_secs(5))).unwrap(); + sock.set_write_timeout(Some(Duration::from_secs(5))).unwrap(); + let mut buf = [0; 4096]; + sock.read(&mut buf).expect("read 1"); + sock.write_all(b"\ + HTTP/1.1 101 Switching Protocols\r\n\ + Upgrade: foobar\r\n\ + \r\n\ + foobar=ready\ + ").unwrap(); + let _ = tx1.send(()); + + let n = sock.read(&mut buf).expect("read 2"); + assert_eq!(&buf[..n], b"foo=bar"); + sock.write_all(b"bar=foo").expect("write 2"); + }); + + let tcp = core.run(TcpStream::connect(&addr, &handle)).unwrap(); + + let io = DebugStream { + tcp: tcp, + shutdown_called: false, + }; + + let (mut client, mut conn) = core.run(conn::handshake(io)).unwrap(); + + { + let until_upgrade = poll_fn(|| { + conn.poll_without_shutdown() + }); + + let uri = format!("http://{}/a", addr).parse().unwrap(); + let req = Request::new(Method::Get, uri); + let res = client.send_request(req).and_then(move |res| { + assert_eq!(res.status(), hyper::StatusCode::SwitchingProtocols); + assert_eq!(res.headers().get_raw("Upgrade").unwrap(), "foobar"); + res.body().concat2() + }); + + let rx = rx1.map_err(|_| hyper::Error::Io(io::Error::new(io::ErrorKind::Other, "thread panicked"))); + + let timeout = Timeout::new(Duration::from_millis(200), &handle).unwrap(); + let rx = rx.and_then(move |_| timeout.map_err(|e| e.into())); + core.run(until_upgrade.join(res).join(rx).map(|r| r.0)).unwrap(); + + // should not be ready now + core.run(poll_fn(|| { + assert!(client.poll_ready().unwrap().is_not_ready()); + Ok::<_, ()>(Async::Ready(())) + })).unwrap(); + } + + let parts = conn.into_parts(); + let io = parts.io; + let buf = parts.read_buf; + + assert_eq!(buf, b"foobar=ready"[..]); + assert!(!io.shutdown_called, "upgrade shouldn't shutdown AsyncWrite"); + assert!(client.poll_ready().is_err()); + + let io = core.run(write_all(io, b"foo=bar")).unwrap().0; + let vec = core.run(read_to_end(io, vec![])).unwrap().1; + assert_eq!(vec, b"bar=foo"); + } + + #[test] + fn connect_method() { + use tokio_io::io::{read_to_end, write_all}; + let _ = ::pretty_env_logger::try_init(); + + let server = TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = server.local_addr().unwrap(); + let mut core = Core::new().unwrap(); + let handle = core.handle(); + + let (tx1, rx1) = oneshot::channel(); + + thread::spawn(move || { + let mut sock = server.accept().unwrap().0; + sock.set_read_timeout(Some(Duration::from_secs(5))).unwrap(); + sock.set_write_timeout(Some(Duration::from_secs(5))).unwrap(); + let mut buf = [0; 4096]; + sock.read(&mut buf).expect("read 1"); + sock.write_all(b"\ + HTTP/1.1 200 OK\r\n\ + \r\n\ + foobar=ready\ + ").unwrap(); + let _ = tx1.send(()); + + let n = sock.read(&mut buf).expect("read 2"); + assert_eq!(&buf[..n], b"foo=bar", "sock read 2 bytes"); + sock.write_all(b"bar=foo").expect("write 2"); + }); + + let tcp = core.run(TcpStream::connect(&addr, &handle)).unwrap(); + + let io = DebugStream { + tcp: tcp, + shutdown_called: false, + }; + + let (mut client, mut conn) = core.run(conn::handshake(io)).unwrap(); + + { + let until_tunneled = poll_fn(|| { + conn.poll_without_shutdown() + }); + + let uri = format!("http://{}", addr).parse().unwrap(); + let req = Request::new(Method::Connect, uri); + let res = client.send_request(req) + .and_then(move |res| { + assert_eq!(res.status(), hyper::StatusCode::Ok); + res.body().concat2() + }) + .map(|body| { + assert_eq!(body.as_ref(), b""); + }); + + let rx = rx1.map_err(|_| hyper::Error::Io(io::Error::new(io::ErrorKind::Other, "thread panicked"))); + + let timeout = Timeout::new(Duration::from_millis(200), &handle).unwrap(); + let rx = rx.and_then(move |_| timeout.map_err(|e| e.into())); + core.run(until_tunneled.join(res).join(rx).map(|r| r.0)).unwrap(); + + // should not be ready now + core.run(poll_fn(|| { + assert!(client.poll_ready().unwrap().is_not_ready()); + Ok::<_, ()>(Async::Ready(())) + })).unwrap(); + } + + let parts = conn.into_parts(); + let io = parts.io; + let buf = parts.read_buf; + + assert_eq!(buf, b"foobar=ready"[..]); + assert!(!io.shutdown_called, "tunnel shouldn't shutdown AsyncWrite"); + assert!(client.poll_ready().is_err()); + + let io = core.run(write_all(io, b"foo=bar")).unwrap().0; + let vec = core.run(read_to_end(io, vec![])).unwrap().1; + assert_eq!(vec, b"bar=foo"); + } + + struct DebugStream { + tcp: TcpStream, + shutdown_called: bool, + } + + impl Write for DebugStream { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.tcp.write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + self.tcp.flush() + } + } + + impl AsyncWrite for DebugStream { + fn shutdown(&mut self) -> Poll<(), io::Error> { + self.shutdown_called = true; + AsyncWrite::shutdown(&mut self.tcp) + } + } + + impl Read for DebugStream { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + self.tcp.read(buf) + } + } + + impl AsyncRead for DebugStream {} +}