diff --git a/src/body/body.rs b/src/body/body.rs index 87cf087426..71c180f07a 100644 --- a/src/body/body.rs +++ b/src/body/body.rs @@ -11,7 +11,6 @@ use common::Never; use super::{Chunk, Payload}; use super::internal::{FullDataArg, FullDataRet}; - type BodySender = mpsc::Sender>; /// A stream of `Chunk`s, used when receiving bodies. @@ -36,7 +35,7 @@ pub struct Body { enum Kind { Once(Option), Chan { - _close_tx: oneshot::Sender<()>, + abort_rx: oneshot::Receiver<()>, rx: mpsc::Receiver>, }, H2(h2::RecvStream), @@ -61,7 +60,7 @@ enum DelayEof { #[must_use = "Sender does nothing unless sent on"] #[derive(Debug)] pub struct Sender { - close_rx: oneshot::Receiver<()>, + abort_tx: oneshot::Sender<()>, tx: BodySender, } @@ -87,14 +86,14 @@ impl Body { #[inline] pub fn channel() -> (Sender, Body) { let (tx, rx) = mpsc::channel(0); - let (close_tx, close_rx) = oneshot::channel(); + let (abort_tx, abort_rx) = oneshot::channel(); let tx = Sender { - close_rx: close_rx, + abort_tx: abort_tx, tx: tx, }; let rx = Body::new(Kind::Chan { - _close_tx: close_tx, + abort_rx: abort_rx, rx: rx, }); @@ -189,11 +188,17 @@ impl Body { fn poll_inner(&mut self) -> Poll, ::Error> { match self.kind { Kind::Once(ref mut val) => Ok(Async::Ready(val.take())), - Kind::Chan { ref mut rx, .. } => match rx.poll().expect("mpsc cannot error") { - Async::Ready(Some(Ok(chunk))) => Ok(Async::Ready(Some(chunk))), - Async::Ready(Some(Err(err))) => Err(err), - Async::Ready(None) => Ok(Async::Ready(None)), - Async::NotReady => Ok(Async::NotReady), + Kind::Chan { ref mut rx, ref mut abort_rx } => { + if let Ok(Async::Ready(())) = abort_rx.poll() { + return Err(::Error::new_body_write("body write aborted")); + } + + match rx.poll().expect("mpsc cannot error") { + Async::Ready(Some(Ok(chunk))) => Ok(Async::Ready(Some(chunk))), + Async::Ready(Some(Err(err))) => Err(err), + Async::Ready(None) => Ok(Async::Ready(None)), + Async::NotReady => Ok(Async::NotReady), + } }, Kind::H2(ref mut h2) => { h2.poll() @@ -283,7 +288,7 @@ impl fmt::Debug for Body { impl Sender { /// Check to see if this `Sender` can send more data. pub fn poll_ready(&mut self) -> Poll<(), ::Error> { - match self.close_rx.poll() { + match self.abort_tx.poll_cancel() { Ok(Async::Ready(())) | Err(_) => return Err(::Error::new_closed()), Ok(Async::NotReady) => (), } @@ -303,6 +308,11 @@ impl Sender { .map_err(|err| err.into_inner().expect("just sent Ok")) } + /// Aborts the body in an abnormal fashion. + pub fn abort(self) { + let _ = self.abort_tx.send(()); + } + pub(crate) fn send_error(&mut self, err: ::Error) { let _ = self.tx.try_send(Err(err)); } diff --git a/src/error.rs b/src/error.rs index f1326e2d99..0eb064f2ef 100644 --- a/src/error.rs +++ b/src/error.rs @@ -265,7 +265,7 @@ impl StdError for Error { Kind::NewService => "calling user's new_service failed", Kind::Service => "error from user's server service", Kind::Body => "error reading a body from connection", - Kind::BodyWrite => "error write a body to connection", + Kind::BodyWrite => "error writing a body to connection", Kind::BodyUser => "error from user's Payload stream", Kind::Shutdown => "error shutting down connection", Kind::Http2 => "http2 general error", diff --git a/tests/client.rs b/tests/client.rs index 3436ca7eee..bd5958a7f8 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -1373,7 +1373,7 @@ mod conn { use tokio::net::TcpStream; use tokio_io::{AsyncRead, AsyncWrite}; - use hyper::{self, Request}; + use hyper::{self, Request, Body, Method}; use hyper::client::conn; use super::{s, tcp_connect, FutureHyperExt}; @@ -1424,6 +1424,53 @@ mod conn { res.join(rx).map(|r| r.0).wait().unwrap(); } + #[test] + fn aborted_body_isnt_completed() { + let _ = ::pretty_env_logger::try_init(); + let server = TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = server.local_addr().unwrap(); + let mut runtime = Runtime::new().unwrap(); + + let (tx, rx) = oneshot::channel(); + let server = thread::spawn(move || { + let mut sock = server.accept().unwrap().0; + sock.set_read_timeout(Some(Duration::from_secs(5))).unwrap(); + sock.set_write_timeout(Some(Duration::from_secs(5))).unwrap(); + let expected = "POST / HTTP/1.1\r\ntransfer-encoding: chunked\r\n\r\n5\r\nhello\r\n"; + let mut buf = vec![0; expected.len()]; + sock.read_exact(&mut buf).expect("read 1"); + assert_eq!(s(&buf), expected); + + let _ = tx.send(()); + + assert_eq!(sock.read(&mut buf).expect("read 2"), 0); + }); + + let tcp = tcp_connect(&addr).wait().unwrap(); + + let (mut client, conn) = conn::handshake(tcp).wait().unwrap(); + + runtime.spawn(conn.map(|_| ()).map_err(|e| panic!("conn error: {}", e))); + + let (mut sender, body) = Body::channel(); + let sender = thread::spawn(move || { + sender.send_data("hello".into()).ok().unwrap(); + rx.wait().unwrap(); + sender.abort(); + }); + + let req = Request::builder() + .method(Method::POST) + .uri("/") + .body(body) + .unwrap(); + let res = client.send_request(req); + res.wait().unwrap_err(); + + server.join().expect("server thread panicked"); + sender.join().expect("sender thread panicked"); + } + #[test] fn uri_absolute_form() { let server = TcpListener::bind("127.0.0.1:0").unwrap();