diff --git a/comms/src/protocol/rpc/client.rs b/comms/src/protocol/rpc/client.rs index 2f19d4646b..d3cf5b9df5 100644 --- a/comms/src/protocol/rpc/client.rs +++ b/comms/src/protocol/rpc/client.rs @@ -587,13 +587,7 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send + StreamId self.protocol_name(), ); rx.close(); - // RPC is strictly request/response - // If the client drops the RpcClient request at this point after the, we have two options: - // 1. Obey the protocol: receive the response - // 2. Error out and immediately close the session (seems brittle and may be unexpected) - // Option 1 has the disadvantage when receiving large/many streamed responses, however if all client handles - // have been dropped, then read_reply will exit early the stream will close and the server-side - // can exit early + return Ok(()); } if let Err(err) = self.send_request(req).await { @@ -603,6 +597,17 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send + StreamId } loop { + if self.shutdown_signal.is_triggered() { + debug!( + target: LOG_TARGET, + "[{}, stream_id: {}, req_id: {}] Client connector closed. Quitting stream early", + self.protocol_name(), + self.stream_id(), + request_id + ); + break; + } + let resp = match self.read_response(request_id).await { Ok(resp) => { let latency = start.elapsed(); @@ -667,10 +672,20 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send + StreamId warn!( target: LOG_TARGET, "(stream={}) Response receiver was dropped before the response/stream could complete for \ - protocol {}, the stream will continue until completed", + protocol {}, interrupting the stream. ", self.stream_id(), self.protocol_name() ); + let req = proto::rpc::RpcRequest { + request_id: request_id as u32, + method, + deadline: self.config.deadline.map(|t| t.as_secs()).unwrap_or(0), + flags: RpcMessageFlags::FIN.bits().into(), + payload: vec![], + }; + + self.send_request(req).await?; + break; } else { let _ = response_tx.send(Ok(resp)).await; } @@ -714,7 +729,31 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send + StreamId async fn read_response(&mut self, request_id: u16) -> Result { let mut reader = RpcResponseReader::new(&mut self.framed, self.config, request_id); - let resp = reader.read_response().await?; + + let mut num_ignored = 0; + let resp = loop { + match reader.read_response().await { + Ok(resp) => break resp, + Err(RpcError::ResponseIdDidNotMatchRequest { actual, expected }) + if actual.saturating_add(1) == request_id => + { + warn!( + target: LOG_TARGET, + "Possible delayed response received for previous request {}", actual + ); + num_ignored += 1; + + // Be lenient for a number of messages that may have been buffered to come through for the previous + // request. + const MAX_ALLOWED_IGNORED: usize = 5; + if num_ignored > MAX_ALLOWED_IGNORED { + return Err(RpcError::ResponseIdDidNotMatchRequest { actual, expected }); + } + continue; + } + Err(err) => return Err(err), + } + }; Ok(resp) } diff --git a/comms/src/protocol/rpc/server/error.rs b/comms/src/protocol/rpc/server/error.rs index 6972cec60b..f5cc145de2 100644 --- a/comms/src/protocol/rpc/server/error.rs +++ b/comms/src/protocol/rpc/server/error.rs @@ -35,10 +35,14 @@ pub enum RpcServerError { MaximumSessionsReached, #[error("Internal service request canceled")] RequestCanceled, + #[error("Stream was closed by remote")] + StreamClosedByRemote, #[error("Handshake error: {0}")] HandshakeError(#[from] RpcHandshakeError), #[error("Service not found for protocol `{0}`")] ProtocolServiceNotFound(String), + #[error("Unexpected incoming message")] + UnexpectedIncomingMessage, } impl From for RpcServerError { diff --git a/comms/src/protocol/rpc/server/mod.rs b/comms/src/protocol/rpc/server/mod.rs index a78cd146b0..974db6136d 100644 --- a/comms/src/protocol/rpc/server/mod.rs +++ b/comms/src/protocol/rpc/server/mod.rs @@ -63,15 +63,18 @@ use crate::{ Bytes, Substream, }; -use futures::{stream, SinkExt, StreamExt}; +use futures::{future, stream, SinkExt, StreamExt}; use prost::Message; use std::{ borrow::Cow, future::Future, + pin::Pin, sync::Arc, + task::Poll, time::{Duration, Instant}, }; use tokio::{sync::mpsc, time}; +use tokio_stream::Stream; use tower::Service; use tower_make::MakeService; use tracing::{debug, error, instrument, span, trace, warn, Instrument, Level}; @@ -502,6 +505,11 @@ where } let msg_flags = RpcMessageFlags::from_bits_truncate(decoded_msg.flags as u8); + + if msg_flags.contains(RpcMessageFlags::FIN) { + debug!(target: LOG_TARGET, "({}) Client sent FIN.", self.logging_context_string); + return Ok(()); + } if msg_flags.contains(RpcMessageFlags::ACK) { debug!( target: LOG_TARGET, @@ -594,6 +602,12 @@ where .map(|resp| Bytes::from(resp.to_encoded_bytes())); loop { + // Check if the client interrupted the outgoing stream + if let Err(err) = self.check_interruptions().await { + warn!(target: LOG_TARGET, "{}", err); + break; + } + let next_item = log_timing( self.logging_context_string.clone(), request_id, @@ -602,7 +616,7 @@ where ); match time::timeout(deadline, next_item).await { Ok(Some(msg)) => { - trace!( + debug!( target: LOG_TARGET, "({}) Sending body len = {}", self.logging_context_string, @@ -630,6 +644,20 @@ where Ok(()) } + async fn check_interruptions(&mut self) -> Result<(), RpcServerError> { + let check = future::poll_fn(|cx| match Pin::new(&mut self.framed).poll_next(cx) { + Poll::Ready(Some(Ok(_))) => Poll::Ready(Some(RpcServerError::UnexpectedIncomingMessage)), + Poll::Ready(Some(Err(err))) => Poll::Ready(Some(RpcServerError::from(err))), + Poll::Ready(None) => Poll::Ready(Some(RpcServerError::StreamClosedByRemote)), + Poll::Pending => Poll::Ready(None), + }) + .await; + match check { + Some(err) => Err(err), + None => Ok(()), + } + } + fn create_request_context(&self, request_id: u32) -> RequestContext { RequestContext::new(request_id, self.node_id.clone(), Box::new(self.comms_provider.clone())) } diff --git a/comms/src/protocol/rpc/test/greeting_service.rs b/comms/src/protocol/rpc/test/greeting_service.rs index f66221b5ae..9153db59bf 100644 --- a/comms/src/protocol/rpc/test/greeting_service.rs +++ b/comms/src/protocol/rpc/test/greeting_service.rs @@ -172,7 +172,10 @@ impl GreetingRpc for GreetingService { tokio::spawn(async move { for _ in 0..num_items { time::sleep(Duration::from_millis(delay_ms)).await; - tx.send(Ok(item.clone())).await.unwrap(); + if tx.send(Ok(item.clone())).await.is_err() { + log::info!("stream was interrupted"); + break; + } } }); diff --git a/comms/src/protocol/rpc/test/smoke.rs b/comms/src/protocol/rpc/test/smoke.rs index 96c85d1ff7..8ca1bfce52 100644 --- a/comms/src/protocol/rpc/test/smoke.rs +++ b/comms/src/protocol/rpc/test/smoke.rs @@ -62,6 +62,7 @@ use tari_test_utils::unpack_enum; use tokio::{ sync::{mpsc, RwLock}, task, + time, }; pub(super) async fn setup_service( @@ -389,7 +390,7 @@ async fn stream_still_works_after_cancel() { // Request was sent assert_eq!(service_impl.call_count(), 1); - // Subsequent call still works, after waiting for the previous one + // Subsequent call still works let resp = client .slow_stream(SlowStreamRequest { num_items: 100, @@ -403,3 +404,50 @@ async fn stream_still_works_after_cancel() { r.unwrap(); }); } + +#[runtime::test] +async fn stream_interruption_handling() { + let service_impl = GreetingService::default(); + let (mut muxer, _outbound, _, _, _shutdown) = setup(service_impl.clone(), 1).await; + let socket = muxer.incoming_mut().next().await.unwrap(); + + let framed = framing::canonical(socket, 1024); + let mut client = GreetingClient::builder() + .with_deadline(Duration::from_secs(5)) + .connect(framed) + .await + .unwrap(); + + let mut resp = client + .slow_stream(SlowStreamRequest { + num_items: 10000, + item_size: 100, + delay_ms: 100, + }) + .await + .unwrap(); + + let _ = resp.next().await.unwrap().unwrap(); + // Drop it before the stream is finished + drop(resp); + + // Subsequent call still works, without waiting + let mut resp = client + .slow_stream(SlowStreamRequest { + num_items: 100, + item_size: 100, + delay_ms: 1, + }) + .await + .unwrap(); + + let next_fut = resp.next(); + tokio::pin!(next_fut); + // Allow 10 seconds, if the previous stream is still streaming, it will take a while for this stream to start and + // the timeout will expire + time::timeout(Duration::from_secs(10), next_fut) + .await + .unwrap() + .unwrap() + .unwrap(); +}