fix(rpc)!: read from substream while streaming to check for interruptions #3548

Merged
57 changes: 48 additions & 9 deletions comms/src/protocol/rpc/client.rs
@@ -587,13 +587,7 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send + StreamId
self.protocol_name(),
);
rx.close();
// RPC is strictly request/response
// If the client drops the RpcClient request at this point after the, we have two options:
// 1. Obey the protocol: receive the response
// 2. Error out and immediately close the session (seems brittle and may be unexpected)
// Option 1 has the disadvantage when receiving large/many streamed responses, however if all client handles
// have been dropped, then read_reply will exit early the stream will close and the server-side
// can exit early
return Ok(());
}

if let Err(err) = self.send_request(req).await {
@@ -603,6 +597,17 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send + StreamId
}

loop {
if self.shutdown_signal.is_triggered() {
debug!(
target: LOG_TARGET,
"[{}, stream_id: {}, req_id: {}] Client connector closed. Quitting stream early",
self.protocol_name(),
self.stream_id(),
request_id
);
break;
}

let resp = match self.read_response(request_id).await {
Ok(resp) => {
let latency = start.elapsed();
@@ -667,10 +672,20 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send + StreamId
warn!(
target: LOG_TARGET,
"(stream={}) Response receiver was dropped before the response/stream could complete for \
protocol {}, the stream will continue until completed",
protocol {}, interrupting the stream. ",
self.stream_id(),
self.protocol_name()
);
let req = proto::rpc::RpcRequest {
request_id: request_id as u32,
method,
deadline: self.config.deadline.map(|t| t.as_secs()).unwrap_or(0),
flags: RpcMessageFlags::FIN.bits().into(),
payload: vec![],
};

self.send_request(req).await?;
break;
} else {
let _ = response_tx.send(Ok(resp)).await;
}
@@ -714,7 +729,31 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send + StreamId

async fn read_response(&mut self, request_id: u16) -> Result<proto::rpc::RpcResponse, RpcError> {
let mut reader = RpcResponseReader::new(&mut self.framed, self.config, request_id);
let resp = reader.read_response().await?;

let mut num_ignored = 0;
let resp = loop {
match reader.read_response().await {
Ok(resp) => break resp,
Err(RpcError::ResponseIdDidNotMatchRequest { actual, expected })
if actual.saturating_add(1) == request_id =>
{
warn!(
target: LOG_TARGET,
"Possible delayed response received for previous request {}", actual
);
num_ignored += 1;

// Be lenient for a number of messages that may have been buffered to come through for the previous
// request.
const MAX_ALLOWED_IGNORED: usize = 5;
if num_ignored > MAX_ALLOWED_IGNORED {
return Err(RpcError::ResponseIdDidNotMatchRequest { actual, expected });
}
continue;
}
Err(err) => return Err(err),
}
};
Ok(resp)
}

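The leniency in read_response above boils down to one predicate: a late response is skipped (at most MAX_ALLOWED_IGNORED, i.e. 5, times) only when it answers the request immediately before the current one. A standalone sketch of that rule, with a hypothetical helper name and example values:

fn is_late_response_from_previous(actual: u16, current: u16) -> bool {
    // Mirrors the guard `actual.saturating_add(1) == request_id` used above.
    actual.saturating_add(1) == current
}

fn main() {
    assert!(is_late_response_from_previous(6, 7)); // answers request 6 while 7 is in flight: ignored
    assert!(!is_late_response_from_previous(5, 7)); // anything older still surfaces ResponseIdDidNotMatchRequest
    assert!(!is_late_response_from_previous(u16::MAX, 0)); // saturating_add prevents a wrap-around false match
}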
4 changes: 4 additions & 0 deletions comms/src/protocol/rpc/server/error.rs
@@ -35,10 +35,14 @@ pub enum RpcServerError {
MaximumSessionsReached,
#[error("Internal service request canceled")]
RequestCanceled,
#[error("Stream was closed by remote")]
StreamClosedByRemote,
#[error("Handshake error: {0}")]
HandshakeError(#[from] RpcHandshakeError),
#[error("Service not found for protocol `{0}`")]
ProtocolServiceNotFound(String),
#[error("Unexpected incoming message")]
UnexpectedIncomingMessage,
}

impl From<oneshot::error::RecvError> for RpcServerError {
32 changes: 30 additions & 2 deletions comms/src/protocol/rpc/server/mod.rs
@@ -63,15 +63,18 @@ use crate::{
Bytes,
Substream,
};
use futures::{stream, SinkExt, StreamExt};
use futures::{future, stream, SinkExt, StreamExt};
use prost::Message;
use std::{
borrow::Cow,
future::Future,
pin::Pin,
sync::Arc,
task::Poll,
time::{Duration, Instant},
};
use tokio::{sync::mpsc, time};
use tokio_stream::Stream;
use tower::Service;
use tower_make::MakeService;
use tracing::{debug, error, instrument, span, trace, warn, Instrument, Level};
@@ -502,6 +505,11 @@ where
}

let msg_flags = RpcMessageFlags::from_bits_truncate(decoded_msg.flags as u8);

if msg_flags.contains(RpcMessageFlags::FIN) {
debug!(target: LOG_TARGET, "({}) Client sent FIN.", self.logging_context_string);
return Ok(());
}
if msg_flags.contains(RpcMessageFlags::ACK) {
debug!(
target: LOG_TARGET,
@@ -594,6 +602,12 @@ where
.map(|resp| Bytes::from(resp.to_encoded_bytes()));

loop {
// Check if the client interrupted the outgoing stream
if let Err(err) = self.check_interruptions().await {
warn!(target: LOG_TARGET, "{}", err);
break;
}

let next_item = log_timing(
self.logging_context_string.clone(),
request_id,
@@ -602,7 +616,7 @@
);
match time::timeout(deadline, next_item).await {
Ok(Some(msg)) => {
trace!(
debug!(
target: LOG_TARGET,
"({}) Sending body len = {}",
self.logging_context_string,
@@ -630,6 +644,20 @@
Ok(())
}

async fn check_interruptions(&mut self) -> Result<(), RpcServerError> {
let check = future::poll_fn(|cx| match Pin::new(&mut self.framed).poll_next(cx) {
Poll::Ready(Some(Ok(_))) => Poll::Ready(Some(RpcServerError::UnexpectedIncomingMessage)),
Poll::Ready(Some(Err(err))) => Poll::Ready(Some(RpcServerError::from(err))),
Poll::Ready(None) => Poll::Ready(Some(RpcServerError::StreamClosedByRemote)),
Poll::Pending => Poll::Ready(None),
})
.await;
match check {
Some(err) => Err(err),
None => Ok(()),
}
}

fn create_request_context(&self, request_id: u32) -> RequestContext {
RequestContext::new(request_id, self.node_id.clone(), Box::new(self.comms_provider.clone()))
}
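The check_interruptions method above inverts the usual meaning of Poll::Pending: while the server is busy writing the outgoing stream, "nothing readable on the inbound half" is the healthy case, so the poll_fn future resolves immediately instead of waiting for a frame. A self-contained sketch of the same pattern over a generic stream (the function name and error strings are illustrative, not part of this crate's API):

use futures::{future, Stream, StreamExt};
use std::task::Poll;

// Non-blocking peek at the inbound half of a substream. Returns immediately:
// Ok(()) when nothing has arrived (keep streaming), Err(..) when the peer
// sent an unexpected frame or closed the substream.
async fn check_interrupted<S>(inbound: &mut S) -> Result<(), &'static str>
where
    S: Stream + Unpin,
{
    let interruption = future::poll_fn(|cx| match inbound.poll_next_unpin(cx) {
        Poll::Ready(Some(_)) => Poll::Ready(Some("unexpected inbound message")),
        Poll::Ready(None) => Poll::Ready(Some("substream closed by remote")),
        // Nothing buffered: resolve immediately rather than waiting for a frame.
        Poll::Pending => Poll::Ready(None),
    })
    .await;
    match interruption {
        Some(reason) => Err(reason),
        None => Ok(()),
    }
}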
5 changes: 4 additions & 1 deletion comms/src/protocol/rpc/test/greeting_service.rs
@@ -172,7 +172,10 @@ impl GreetingRpc for GreetingService {
tokio::spawn(async move {
for _ in 0..num_items {
time::sleep(Duration::from_millis(delay_ms)).await;
tx.send(Ok(item.clone())).await.unwrap();
if tx.send(Ok(item.clone())).await.is_err() {
log::info!("stream was interrupted");
break;
}
}
});

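The greeting service only learns about the interruption through the channel: once the server side stops reading and drops the body receiver, send() starts failing and the producer task breaks out of its loop. A tiny standalone illustration of that behaviour with tokio's mpsc channel:

use tokio::sync::mpsc;

#[tokio::main]
async fn main() {
    let (tx, rx) = mpsc::channel::<u32>(1);
    // Dropping the receiver is the moral equivalent of the server abandoning
    // the response stream after a client FIN.
    drop(rx);
    // send() now errors, which is what the service loop checks before breaking.
    assert!(tx.send(1).await.is_err());
}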
50 changes: 49 additions & 1 deletion comms/src/protocol/rpc/test/smoke.rs
@@ -62,6 +62,7 @@ use tari_test_utils::unpack_enum;
use tokio::{
sync::{mpsc, RwLock},
task,
time,
};

pub(super) async fn setup_service<T: GreetingRpc>(
Expand Down Expand Up @@ -389,7 +390,7 @@ async fn stream_still_works_after_cancel() {
// Request was sent
assert_eq!(service_impl.call_count(), 1);

// Subsequent call still works, after waiting for the previous one
// Subsequent call still works
let resp = client
.slow_stream(SlowStreamRequest {
num_items: 100,
@@ -403,3 +404,50 @@
r.unwrap();
});
}

#[runtime::test]
async fn stream_interruption_handling() {
let service_impl = GreetingService::default();
let (mut muxer, _outbound, _, _, _shutdown) = setup(service_impl.clone(), 1).await;
let socket = muxer.incoming_mut().next().await.unwrap();

let framed = framing::canonical(socket, 1024);
let mut client = GreetingClient::builder()
.with_deadline(Duration::from_secs(5))
.connect(framed)
.await
.unwrap();

let mut resp = client
.slow_stream(SlowStreamRequest {
num_items: 10000,
item_size: 100,
delay_ms: 100,
})
.await
.unwrap();

let _ = resp.next().await.unwrap().unwrap();
// Drop it before the stream is finished
drop(resp);

// Subsequent call still works, without waiting
let mut resp = client
.slow_stream(SlowStreamRequest {
num_items: 100,
item_size: 100,
delay_ms: 1,
})
.await
.unwrap();

let next_fut = resp.next();
tokio::pin!(next_fut);
// Allow 10 seconds, if the previous stream is still streaming, it will take a while for this stream to start and
// the timeout will expire
time::timeout(Duration::from_secs(10), next_fut)
.await
.unwrap()
.unwrap()
.unwrap();
}