From e689f20376d3e078f5d380902d39f8ae9c043486 Mon Sep 17 00:00:00 2001 From: Sean McArthur Date: Wed, 17 Jun 2015 13:17:56 -0700 Subject: [PATCH] fix(client): check for drained stream in Response::drop --- src/client/pool.rs | 18 +++++----------- src/client/response.rs | 48 ++++++++++++++++++++++++------------------ src/http/mod.rs | 1 + 3 files changed, 33 insertions(+), 34 deletions(-) diff --git a/src/client/pool.rs b/src/client/pool.rs index 45b3bb1a42..9736a6a5d2 100644 --- a/src/client/pool.rs +++ b/src/client/pool.rs @@ -116,7 +116,6 @@ impl, S: NetworkStream + Send> NetworkConnector fo Ok(PooledStream { inner: Some((key, conn)), is_closed: false, - is_drained: false, pool: self.inner.clone() }) } @@ -130,20 +129,13 @@ impl, S: NetworkStream + Send> NetworkConnector fo pub struct PooledStream { inner: Option<(Key, S)>, is_closed: bool, - is_drained: bool, pool: Arc>> } impl Read for PooledStream { #[inline] fn read(&mut self, buf: &mut [u8]) -> io::Result { - match self.inner.as_mut().unwrap().1.read(buf) { - Ok(0) => { - self.is_drained = true; - Ok(0) - } - r => r - } + self.inner.as_mut().unwrap().1.read(buf) } } @@ -174,8 +166,8 @@ impl NetworkStream for PooledStream { impl Drop for PooledStream { fn drop(&mut self) { - trace!("PooledStream.drop, is_closed={}, is_drained={}", self.is_closed, self.is_drained); - if !self.is_closed && self.is_drained { + trace!("PooledStream.drop, is_closed={}", self.is_closed); + if !self.is_closed { self.inner.take().map(|(key, conn)| { if let Ok(mut pool) = self.pool.lock() { pool.reuse(key, conn); @@ -205,13 +197,13 @@ mod tests { fn test_connect_and_drop() { let pool = mocked!(); let key = key("127.0.0.1", 3000, "http"); - pool.connect("127.0.0.1", 3000, "http").unwrap().is_drained = true; + pool.connect("127.0.0.1", 3000, "http").unwrap(); { let locked = pool.inner.lock().unwrap(); assert_eq!(locked.conns.len(), 1); assert_eq!(locked.conns.get(&key).unwrap().len(), 1); } - pool.connect("127.0.0.1", 3000, "http").unwrap().is_drained = true; //reused + pool.connect("127.0.0.1", 3000, "http").unwrap(); //reused { let locked = pool.inner.lock().unwrap(); assert_eq!(locked.conns.len(), 1); diff --git a/src/client/response.rs b/src/client/response.rs index 5e7522c96c..f3611cf5c5 100644 --- a/src/client/response.rs +++ b/src/client/response.rs @@ -19,6 +19,7 @@ pub struct Response { pub version: version::HttpVersion, status_raw: RawStatus, message: Box, + is_drained: bool, } impl Response { @@ -43,6 +44,7 @@ impl Response { headers: headers, message: message, status_raw: raw_status, + is_drained: false, }) } @@ -50,34 +52,46 @@ impl Response { pub fn status_raw(&self) -> &RawStatus { &self.status_raw } + } impl Read for Response { #[inline] fn read(&mut self, buf: &mut [u8]) -> io::Result { - let count = try!(self.message.read(buf)); + match self.message.read(buf) { + Ok(0) => { + self.is_drained = true; + Ok(0) + }, + r => r + } + } +} - if count == 0 { - if !http::should_keep_alive(self.version, &self.headers) { - try!(self.message.close_connection() - .map_err(|_| io::Error::new(io::ErrorKind::Other, - "Error closing connection"))); +impl Drop for Response { + fn drop(&mut self) { + // if not drained, theres old bits in the Reader. we can't reuse this, + // since those old bits would end up in new Responses + // + // otherwise, the response has been drained. we should check that the + // server has agreed to keep the connection open + trace!("Response.is_drained = {:?}", self.is_drained); + if !(self.is_drained && http::should_keep_alive(self.version, &self.headers)) { + trace!("closing connection"); + if let Err(e) = self.message.close_connection() { + error!("error closing connection: {}", e); } } - - Ok(count) } } #[cfg(test)] mod tests { - use std::borrow::Cow::Borrowed; use std::io::{self, Read}; - use header::Headers; use header::TransferEncoding; use header::Encoding; - use http::RawStatus; + use http::HttpMessage; use mock::MockStream; use status; use version; @@ -94,18 +108,10 @@ mod tests { #[test] fn test_into_inner() { - let res = Response { - status: status::StatusCode::Ok, - headers: Headers::new(), - version: version::HttpVersion::Http11, - message: Box::new(Http11Message::with_stream(Box::new(MockStream::new()))), - status_raw: RawStatus(200, Borrowed("OK")), - }; - - let message = res.message.downcast::().ok().unwrap(); + let message: Box = Box::new(Http11Message::with_stream(Box::new(MockStream::new()))); + let message = message.downcast::().ok().unwrap(); let b = message.into_inner().downcast::().ok().unwrap(); assert_eq!(b, Box::new(MockStream::new())); - } #[test] diff --git a/src/http/mod.rs b/src/http/mod.rs index 6571f073bf..e997f7c5cd 100644 --- a/src/http/mod.rs +++ b/src/http/mod.rs @@ -20,6 +20,7 @@ pub struct RawStatus(pub u16, pub Cow<'static, str>); /// Checks if a connection should be kept alive. #[inline] pub fn should_keep_alive(version: HttpVersion, headers: &Headers) -> bool { + trace!("should_keep_alive( {:?}, {:?} )", version, headers.get::()); match (version, headers.get::()) { (Http10, Some(conn)) if !conn.contains(&KeepAlive) => false, (Http11, Some(conn)) if conn.contains(&Close) => false,