From 2f922848c3da03c26bf4230da8fcedf3bebcdca8 Mon Sep 17 00:00:00 2001 From: Sean McArthur Date: Thu, 5 Oct 2017 12:11:03 -0700 Subject: [PATCH] when receiving a GOAWAY, allow earlier streams to still process Once all active streams have finished, send a GOAWAY back and close the connection. --- src/frame/go_away.rs | 5 +-- src/proto/connection.rs | 70 ++++++++++++++++++++++++++++-------- src/proto/streams/send.rs | 25 +++++++------ src/proto/streams/store.rs | 2 +- src/proto/streams/streams.rs | 34 +++++++++++++++++- tests/stream_states.rs | 69 +++++++++++++++++++++++++++++++++++ 6 files changed, 174 insertions(+), 31 deletions(-) diff --git a/src/frame/go_away.rs b/src/frame/go_away.rs index 7dacaf52c..1af7acb0a 100644 --- a/src/frame/go_away.rs +++ b/src/frame/go_away.rs @@ -16,7 +16,6 @@ impl GoAway { } } - #[cfg(feature = "unstable")] pub fn last_stream_id(&self) -> StreamId { self.last_stream_id } @@ -27,9 +26,7 @@ impl GoAway { pub fn load(payload: &[u8]) -> Result { if payload.len() < 8 { - // Invalid payload len - // TODO: Handle error - unimplemented!(); + return Err(Error::BadFrameSize); } let (last_stream_id, _) = StreamId::parse(&payload[..4]); diff --git a/src/proto/connection.rs b/src/proto/connection.rs index c2b6fdf5b..ea3213521 100644 --- a/src/proto/connection.rs +++ b/src/proto/connection.rs @@ -20,6 +20,12 @@ where /// Tracks the connection level state transitions. state: State, + /// An error to report back once complete. + /// + /// This exists separately from State in order to support + /// graceful shutdown. + error: Option, + /// Read / write frame values codec: Codec>, @@ -41,14 +47,14 @@ enum State { /// Currently open in a sane state Open, - /// Waiting to send a GO_AWAY frame + /// Waiting to send a GOAWAY frame GoAway(frame::GoAway), /// The codec must be flushed Flush(Reason), - /// In an errored state - Error(Reason), + /// In a closed state + Closed(Reason), } impl Connection @@ -74,6 +80,7 @@ where }); Connection { state: State::Open, + error: None, codec: codec, ping_pong: PingPong::new(), settings: Settings::new(), @@ -118,10 +125,19 @@ where // This will also handle flushing `self.codec` try_ready!(self.streams.poll_complete(&mut self.codec)); + if self.error.is_some() { + if self.streams.num_active_streams() == 0 { + let id = self.streams.last_processed_id(); + let goaway = frame::GoAway::new(id, Reason::NoError); + self.state = State::GoAway(goaway); + continue; + } + } + return Ok(Async::NotReady); }, // Attempting to read a frame resulted in a connection level - // error. This is handled by setting a GO_AWAY frame followed by + // error. This is handled by setting a GOAWAY frame followed by // terminating the connection. Err(Connection(e)) => { debug!("Connection::poll; err={:?}", e); @@ -164,24 +180,45 @@ where // Ensure the codec is ready to accept the frame try_ready!(self.codec.poll_ready()); - // Buffer the GO_AWAY frame + // Buffer the GOAWAY frame self.codec .buffer(frame.into()) .ok() .expect("invalid GO_AWAY frame"); - // GO_AWAY sent, transition the connection to an errored state - self.state = State::Flush(frame.reason()); + // GOAWAY sent, transition the connection to a closed state + // Determine what error code should be returned to user. + let reason = if let Some(theirs) = self.error.take() { + let ours = frame.reason(); + match (ours, theirs) { + // If either side reported an error, return that + // to the user. + (Reason::NoError, err) | + (err, Reason::NoError) => err, + // If both sides reported an error, give their + // error back to th user. We assume our error + // was a consequence of their error, and less + // important. + (_, theirs) => theirs, + } + } else { + frame.reason() + }; + self.state = State::Flush(reason); }, State::Flush(reason) => { // Flush the codec try_ready!(self.codec.flush()); // Transition the state to error - self.state = State::Error(reason); + self.state = State::Closed(reason); }, - State::Error(reason) => { - return Err(reason.into()); + State::Closed(reason) => { + if let Reason::NoError = reason { + return Ok(Async::Ready(())); + } else { + return Err(reason.into()); + } }, } } @@ -215,11 +252,14 @@ where trace!("recv SETTINGS; frame={:?}", frame); self.settings.recv_settings(frame); }, - Some(GoAway(_)) => { - // TODO: handle the last_processed_id. Also, should this be - // handled as an error? - // let _ = RecvError::Proto(frame.reason()); - return Ok(().into()); + Some(GoAway(frame)) => { + trace!("recv GOAWAY; frame={:?}", frame); + // This should prevent starting new streams, + // but should allow continuing to process current streams + // until they are all EOS. Once they are, State should + // transition to GoAway. + self.streams.recv_goaway(&frame); + self.error = Some(frame.reason()); }, Some(Ping(frame)) => { trace!("recv PING; frame={:?}", frame); diff --git a/src/proto/streams/send.rs b/src/proto/streams/send.rs index 6e24912aa..3d0bbda37 100644 --- a/src/proto/streams/send.rs +++ b/src/proto/streams/send.rs @@ -101,22 +101,14 @@ where // Transition the state stream.state.set_reset(reason); - // Clear all pending outbound frames - self.prioritize.clear_queue(stream); - - // Reclaim all capacity assigned to the stream and re-assign it to the - // connection - let available = stream.send_flow.available(); - stream.send_flow.claim_capacity(available); + self.recv_err(stream); let frame = frame::Reset::new(stream.id, reason); trace!("send_reset -- queueing; frame={:?}", frame); self.prioritize.queue_frame(frame.into(), stream, task); - // Re-assign all capacity to the connection - self.prioritize - .assign_connection_capacity(available, stream); + } pub fn send_data( @@ -221,6 +213,19 @@ where Ok(()) } + pub fn recv_err(&mut self, stream: &mut store::Ptr) { + // Clear all pending outbound frames + self.prioritize.clear_queue(stream); + + // Reclaim all capacity assigned to the stream and re-assign it to the + // connection + let available = stream.send_flow.available(); + stream.send_flow.claim_capacity(available); + // Re-assign all capacity to the connection + self.prioritize + .assign_connection_capacity(available, stream); + } + pub fn apply_remote_settings( &mut self, settings: &frame::Settings, diff --git a/src/proto/streams/store.rs b/src/proto/streams/store.rs index 4f674d018..c38f5e218 100644 --- a/src/proto/streams/store.rs +++ b/src/proto/streams/store.rs @@ -222,7 +222,6 @@ where } } -#[cfg(feature = "unstable")] impl Store where P: Peer, @@ -231,6 +230,7 @@ where self.ids.len() } + #[cfg(feature = "unstable")] pub fn num_wired_streams(&self) -> usize { self.slab.len() } diff --git a/src/proto/streams/streams.rs b/src/proto/streams/streams.rs index 8d6c3c435..2cc56b8de 100644 --- a/src/proto/streams/streams.rs +++ b/src/proto/streams/streams.rs @@ -192,6 +192,7 @@ where .for_each(|stream| { counts.transition(stream, |_, stream| { actions.recv.recv_err(err, &mut *stream); + actions.send.recv_err(stream); Ok::<_, ()>(()) }) }) @@ -202,6 +203,37 @@ where last_processed_id } + pub fn recv_goaway(&mut self, frame: &frame::GoAway) { + let mut me = self.inner.lock().unwrap(); + let me = &mut *me; + + let actions = &mut me.actions; + let counts = &mut me.counts; + + let last_stream_id = frame.last_stream_id(); + let err = frame.reason().into(); + + me.store + .for_each(|stream| { + if stream.id > last_stream_id { + counts.transition(stream, |_, stream| { + actions.recv.recv_err(&err, &mut *stream); + actions.send.recv_err(stream); + Ok::<_, ()>(()) + }) + } else { + Ok::<_, ()>(()) + } + }) + .unwrap(); + + actions.conn_error = Some(err); + } + + pub fn last_processed_id(&self) -> StreamId { + self.inner.lock().unwrap().actions.recv.last_processed_id() + } + pub fn recv_window_update(&mut self, frame: frame::WindowUpdate) -> Result<(), RecvError> { let id = frame.stream_id(); let mut me = self.inner.lock().unwrap(); @@ -446,7 +478,6 @@ where } } -#[cfg(feature = "unstable")] impl Streams where B: Buf, @@ -457,6 +488,7 @@ where me.store.num_active_streams() } + #[cfg(feature = "unstable")] pub fn num_wired_streams(&self) -> usize { let me = self.inner.lock().unwrap(); me.store.num_wired_streams() diff --git a/tests/stream_states.rs b/tests/stream_states.rs index 30f31e1d1..7d8c3aa5b 100644 --- a/tests/stream_states.rs +++ b/tests/stream_states.rs @@ -298,6 +298,75 @@ fn configure_max_frame_size() { let _ = h2.join(srv).wait().expect("wait"); } +#[test] +fn recv_goaway_finishes_processed_streams() { + let _ = ::env_logger::init(); + let (io, srv) = mock::new(); + + let srv = srv.assert_client_handshake() + .unwrap() + .recv_settings() + .recv_frame( + frames::headers(1) + .request("GET", "https://example.com/") + .eos(), + ) + .recv_frame( + frames::headers(3) + .request("GET", "https://example.com/") + .eos(), + ) + .send_frame(frames::go_away(1)) + .send_frame(frames::headers(1).response(200)) + .send_frame(frames::data(1, vec![0; 16_384]).eos()) + // expecting a goaway of 0, since server never initiated a stream + .recv_frame(frames::go_away(0)); + //.close(); + + let h2 = Client::handshake(io) + .expect("handshake") + .and_then(|(mut client, h2)| { + let request = Request::builder() + .method(Method::GET) + .uri("https://example.com/") + .body(()) + .unwrap(); + + let req1 = client.send_request(request, true) + .unwrap() + .expect("response") + .and_then(|resp| { + assert_eq!(resp.status(), StatusCode::OK); + let body = resp.into_parts().1; + body.concat2().expect("body") + }) + .and_then(|buf| { + assert_eq!(buf.len(), 16_384); + Ok(()) + }); + + + // this request will trigger a goaway + let request = Request::builder() + .method(Method::GET) + .uri("https://example.com/") + .body(()) + .unwrap(); + let req2 = client.send_request(request, true) + .unwrap() + .then(|res| { + let err = res.unwrap_err(); + assert_eq!(err.to_string(), "protocol error: not a result of an error"); + Ok::<(), ()>(()) + }); + + h2.expect("client").join3(req1, req2) + }); + + + h2.join(srv).wait().expect("wait"); +} + /* #[test] fn send_data_after_headers_eos() {