Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: handle stream read error case by explicitly closing the substream #3321

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions applications/tari_base_node/src/command_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,8 @@ impl CommandHandler {

self.executor.spawn(async move {
let mut status_line = StatusLine::new();
let version = format!("v{}", consts::APP_VERSION_NUMBER);
status_line.add_field("", version);
let network = format!("{}", config.network);
status_line.add_field("", network);
status_line.add_field("", format!("v{}", consts::APP_VERSION_NUMBER));
status_line.add_field("", config.network);
status_line.add_field("State", state_info.borrow().state_info.short_desc());

let metadata = node.get_metadata().await.unwrap();
Expand Down
3 changes: 1 addition & 2 deletions base_layer/wallet/src/connectivity_service/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ use tari_comms::{
mocks::{create_connectivity_mock, ConnectivityManagerMockState},
node_identity::build_node_identity,
},
Substream,
};
use tari_shutdown::Shutdown;
use tari_test_utils::runtime::spawn_until_shutdown;
Expand All @@ -46,7 +45,7 @@ use tokio::{

async fn setup() -> (
WalletConnectivityHandle,
MockRpcServer<MockRpcImpl, Substream>,
MockRpcServer<MockRpcImpl>,
ConnectivityManagerMockState,
Shutdown,
) {
Expand Down
3 changes: 1 addition & 2 deletions base_layer/wallet/tests/output_manager_service/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ use tari_comms::{
node_identity::build_node_identity,
},
types::CommsSecretKey,
Substream,
};
use tari_core::{
base_node::rpc::BaseNodeWalletRpcServer,
Expand Down Expand Up @@ -97,7 +96,7 @@ async fn setup_output_manager_service<T: OutputManagerBackend + 'static>(
OutputManagerHandle,
Shutdown,
TransactionServiceHandle,
MockRpcServer<BaseNodeWalletRpcServer<BaseNodeWalletRpcMockService>, Substream>,
MockRpcServer<BaseNodeWalletRpcServer<BaseNodeWalletRpcMockService>>,
Arc<NodeIdentity>,
BaseNodeWalletRpcMockState,
ConnectivityManagerMockState,
Expand Down
5 changes: 2 additions & 3 deletions base_layer/wallet/tests/transaction_service/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ use tari_comms::{
},
types::CommsSecretKey,
CommsNode,
Substream,
};
use tari_comms_dht::outbound::mock::{
create_outbound_service_mock,
Expand Down Expand Up @@ -244,7 +243,7 @@ pub fn setup_transaction_service_no_comms(
Sender<DomainMessage<base_node_proto::BaseNodeServiceResponse>>,
Sender<DomainMessage<proto::TransactionCancelledMessage>>,
Shutdown,
MockRpcServer<BaseNodeWalletRpcServer<BaseNodeWalletRpcMockService>, Substream>,
MockRpcServer<BaseNodeWalletRpcServer<BaseNodeWalletRpcMockService>>,
Arc<NodeIdentity>,
BaseNodeWalletRpcMockState,
) {
Expand All @@ -268,7 +267,7 @@ pub fn setup_transaction_service_no_comms_and_oms_backend(
Sender<DomainMessage<base_node_proto::BaseNodeServiceResponse>>,
Sender<DomainMessage<proto::TransactionCancelledMessage>>,
Shutdown,
MockRpcServer<BaseNodeWalletRpcServer<BaseNodeWalletRpcMockService>, Substream>,
MockRpcServer<BaseNodeWalletRpcServer<BaseNodeWalletRpcMockService>>,
Arc<NodeIdentity>,
BaseNodeWalletRpcMockState,
) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ use tari_comms::{
},
types::CommsPublicKey,
NodeIdentity,
Substream,
};
use tari_comms_dht::outbound::mock::{create_outbound_service_mock, OutboundServiceMockState};
use tari_core::{
Expand Down Expand Up @@ -96,7 +95,7 @@ pub async fn setup(
TransactionServiceResources<TransactionServiceSqliteDatabase>,
ConnectivityManagerMockState,
OutboundServiceMockState,
MockRpcServer<BaseNodeWalletRpcServer<BaseNodeWalletRpcMockService>, Substream>,
MockRpcServer<BaseNodeWalletRpcServer<BaseNodeWalletRpcMockService>>,
Arc<NodeIdentity>,
BaseNodeWalletRpcMockState,
broadcast::Sender<Duration>,
Expand Down
6 changes: 3 additions & 3 deletions comms/rpc_macros/src/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,15 +194,15 @@ impl RpcCodeGenerator {
.collect::<TokenStream>();

let client_struct_body = quote! {
pub async fn connect<TSubstream>(framed: #dep_mod::CanonicalFraming<TSubstream>) -> Result<Self, #dep_mod::RpcError>
where TSubstream: #dep_mod::AsyncRead + #dep_mod::AsyncWrite + Unpin + Send + 'static {
pub async fn connect(framed: #dep_mod::CanonicalFraming<#dep_mod::Substream>) -> Result<Self, #dep_mod::RpcError> {
use #dep_mod::NamedProtocolService;
let inner = #dep_mod::RpcClient::connect(Default::default(), framed, Self::PROTOCOL_NAME.into()).await?;
Ok(Self { inner })
}

pub fn builder() -> #dep_mod::RpcClientBuilder<Self> {
#dep_mod::RpcClientBuilder::new()
use #dep_mod::NamedProtocolService;
#dep_mod::RpcClientBuilder::new().with_protocol_id(Self::PROTOCOL_NAME.into())
}

#client_methods
Expand Down
2 changes: 1 addition & 1 deletion comms/rpc_macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ mod options;
///
/// Generates Tari RPC "harness code" for a given trait.
///
/// ```no_run
/// ```no_run,ignore
/// # use tari_comms_rpc_macros::tari_rpc;
/// # use tari_comms::protocol::rpc::{Request, Streaming, Response, RpcStatus, RpcServer};
/// use tari_comms::{framing, memsocket::MemorySocket};
Expand Down
11 changes: 7 additions & 4 deletions comms/rpc_macros/tests/macro.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@ use prost::Message;
use std::{collections::HashMap, ops::AddAssign, sync::Arc};
use tari_comms::{
framing,
memsocket::MemorySocket,
message::MessageExt,
protocol::{
rpc,
rpc::{NamedProtocolService, Request, Response, RpcStatus, RpcStatusCode, Streaming},
},
test_utils::transport::build_multiplexed_connections,
};
use tari_comms_rpc_macros::tari_rpc;
use tari_test_utils::unpack_enum;
Expand Down Expand Up @@ -152,9 +152,12 @@ async fn it_returns_an_error_for_invalid_method_nums() {

#[tokio::test]
async fn it_generates_client_calls() {
let (sock_client, sock_server) = MemorySocket::new_pair();
let client = task::spawn(TestClient::connect(framing::canonical(sock_client, 1024)));
let mut sock_server = framing::canonical(sock_server, 1024);
let (_, sock_client, mut sock_server) = build_multiplexed_connections().await;
let client = task::spawn(TestClient::connect(framing::canonical(
sock_client.get_yamux_control().open_stream().await.unwrap(),
1024,
)));
let mut sock_server = framing::canonical(sock_server.incoming_mut().next().await.unwrap(), 1024);
let mut handshake = rpc::Handshake::new(&mut sock_server);
handshake.perform_server_handshake().await.unwrap();
// Wait for client to connect
Expand Down
37 changes: 29 additions & 8 deletions comms/src/memsocket/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ use futures::{
stream::{FusedStream, Stream},
task::{Context, Poll},
};
use log::*;
use std::{
cmp,
collections::{hash_map::Entry, HashMap},
Expand Down Expand Up @@ -433,6 +434,7 @@ impl AsyncRead for MemorySocket {
buf.advance(bytes_to_read);

current_buffer.advance(bytes_to_read);
trace!("reading {} bytes", bytes_to_read);

bytes_read += bytes_to_read;
}
Expand Down Expand Up @@ -462,11 +464,12 @@ impl AsyncRead for MemorySocket {

impl AsyncWrite for MemorySocket {
/// Attempt to write bytes from `buf` into the outgoing channel.
fn poll_write(mut self: Pin<&mut Self>, context: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
let len = buf.len();

match self.outgoing.poll_ready(context) {
match self.outgoing.poll_ready(cx) {
Poll::Ready(Ok(())) => {
trace!("writing {} bytes", len);
if let Err(e) = self.outgoing.start_send(Bytes::copy_from_slice(buf)) {
if e.is_disconnected() {
return Poll::Ready(Err(io::Error::new(ErrorKind::BrokenPipe, e)));
Expand All @@ -475,6 +478,7 @@ impl AsyncWrite for MemorySocket {
// Unbounded channels should only ever have "Disconnected" errors
unreachable!();
}
Poll::Ready(Ok(len))
},
Poll::Ready(Err(e)) => {
if e.is_disconnected() {
Expand All @@ -484,19 +488,18 @@ impl AsyncWrite for MemorySocket {
// Unbounded channels should only ever have "Disconnected" errors
unreachable!();
},
Poll::Pending => return Poll::Pending,
Poll::Pending => Poll::Pending,
}

Poll::Ready(Ok(len))
}

/// Attempt to flush the channel. Cannot Fail.
fn poll_flush(self: Pin<&mut Self>, _context: &mut Context) -> Poll<io::Result<()>> {
fn poll_flush(self: Pin<&mut Self>, _: &mut Context) -> Poll<io::Result<()>> {
trace!("flush");
Poll::Ready(Ok(()))
}

/// Attempt to close the channel. Cannot Fail.
fn poll_shutdown(self: Pin<&mut Self>, _context: &mut Context) -> Poll<io::Result<()>> {
fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context) -> Poll<io::Result<()>> {
self.outgoing.close_channel();

Poll::Ready(Ok(()))
Expand All @@ -506,7 +509,8 @@ impl AsyncWrite for MemorySocket {
#[cfg(test)]
mod test {
use super::*;
use crate::runtime;
use crate::{framing, runtime};
use futures::SinkExt;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio_stream::StreamExt;

Expand Down Expand Up @@ -705,4 +709,21 @@ mod test {

Ok(())
}

#[runtime::test]
async fn read_and_write_canonical_framing() -> io::Result<()> {
let (a, b) = MemorySocket::new_pair();
let mut a = framing::canonical(a, 1024);
let mut b = framing::canonical(b, 1024);

a.send(Bytes::from_static(b"frame-1")).await?;
b.send(Bytes::from_static(b"frame-2")).await?;
let msg = b.next().await.unwrap()?;
assert_eq!(&msg[..], b"frame-1");

let msg = a.next().await.unwrap()?;
assert_eq!(&msg[..], b"frame-2");

Ok(())
}
}
42 changes: 29 additions & 13 deletions comms/src/multiplexing/yamux.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ pub struct IncomingSubstreams {
}

impl IncomingSubstreams {
pub fn new(inner: IncomingRx, substream_counter: SubstreamCounter, shutdown: Shutdown) -> Self {
pub(self) fn new(inner: IncomingRx, substream_counter: SubstreamCounter, shutdown: Shutdown) -> Self {
Self {
inner,
substream_counter,
Expand Down Expand Up @@ -205,6 +205,12 @@ pub struct Substream {
counter_guard: CounterGuard,
}

impl Substream {
pub fn id(&self) -> yamux::StreamId {
self.stream.get_ref().id()
}
}

impl tokio::io::AsyncRead for Substream {
fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.stream).poll_read(cx, buf)
Expand Down Expand Up @@ -242,13 +248,17 @@ where TSocket: futures::AsyncRead + futures::AsyncWrite + Unpin + Send + 'static
}
}

#[tracing::instrument(name = "yamux::incoming_worker::run", skip(self))]
#[tracing::instrument(name = "yamux::incoming_worker::run", skip(self), fields(connection = %self.connection))]
pub async fn run(mut self) {
loop {
tokio::select! {
biased;

_ = &mut self.shutdown_signal => {
_ = self.shutdown_signal.wait() => {
debug!(
target: LOG_TARGET,
"{} Yamux connection shutdown", self.connection
);
let mut control = self.connection.control();
if let Err(err) = control.close().await {
error!(target: LOG_TARGET, "Failed to close yamux connection: {}", err);
Expand All @@ -259,31 +269,37 @@ where TSocket: futures::AsyncRead + futures::AsyncWrite + Unpin + Send + 'static
result = self.connection.next_stream() => {
match result {
Ok(Some(stream)) => {
event!(Level::TRACE, "yamux::stream received {}", stream);if self.sender.send(stream).await.is_err() {
event!(Level::TRACE, "yamux::incoming_worker::new_stream {}", stream);
if self.sender.send(stream).await.is_err() {
debug!(
target: LOG_TARGET,
"Incoming peer substream task is shutting down because the internal stream sender channel \
was closed"
"{} Incoming peer substream task is shutting down because the internal stream sender channel \
was closed",
self.connection
);
break;
}
},
Ok(None) =>{
debug!(
target: LOG_TARGET,
"Incoming peer substream completed. IncomingWorker exiting"
"{} Incoming peer substream completed. IncomingWorker exiting",
self.connection
);
break;
}
Err(err) => {
event!(
Level::ERROR,
"Incoming peer substream task received an error because '{}'",
err
);
error!(
Level::ERROR,
"{} Incoming peer substream task received an error because '{}'",
self.connection,
err
);
error!(
target: LOG_TARGET,
"Incoming peer substream task received an error because '{}'", err
"{} Incoming peer substream task received an error because '{}'",
self.connection,
err
);
break;
},
Expand Down
Loading