Skip to content

Commit

Permalink
fix(comms): ensure that inbound messaging terminates on disconnect (#…
Browse files Browse the repository at this point in the history
…6653)

Description
---
fix(comms): ensure that inbound messaging terminates on disconnect

Motivation and Context
---
Ensure that the inbound messaging worker terminates when the peer
connection is disconnected.
This PR ensures this by using the `PeerConnection::on_disconnect`
future. It has not been confirmed
that the inbound worker failed to terminate before this change; however,
this PR guarantees that it terminates.

How Has This Been Tested?
---
Updated unit tests. 

What process can a PR reviewer use to test or verify this change?
---

<!-- Checklist -->
<!-- 1. Is the title of your PR in the form that would make nice release
notes? The title, excluding the conventional commit
tag, will be included exactly as is in the CHANGELOG, so please think
about it carefully. -->


Breaking Changes
---

- [x] None
- [ ] Requires data directory on base node to be deleted
- [ ] Requires hard fork
- [ ] Other - Please specify

<!-- Does this include a breaking change? If so, include this line as a
footer -->
<!-- BREAKING CHANGE: Describe what the user should do, e.g. delete a
database, resync the chain -->
  • Loading branch information
sdbondi authored Oct 29, 2024
1 parent 3d0b44a commit 47b4877
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 54 deletions.
16 changes: 13 additions & 3 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions comms/core/src/protocol/messaging/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ pub enum MessagingProtocolError {
PeerConnectionError(#[from] PeerConnectionError),
#[error("Failed to dial peer: {0}")]
PeerDialFailed(ConnectivityError),
#[error("Connectivity error: {0}")]
ConnectivityError(#[from] ConnectivityError),
#[error("IO Error: {0}")]
Io(io::Error),
#[error("Sender error: {0}")]
Expand Down
13 changes: 7 additions & 6 deletions comms/core/src/protocol/messaging/inbound.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,13 @@ use tokio::{
#[cfg(feature = "metrics")]
use super::metrics;
use super::{MessagingEvent, MessagingProtocol};
use crate::{message::InboundMessage, peer_manager::NodeId};
use crate::{message::InboundMessage, PeerConnection};

const LOG_TARGET: &str = "comms::protocol::messaging::inbound";

/// Inbound messaging actor. This is lazily spawned per peer when a peer requests a messaging session.
pub struct InboundMessaging {
peer: NodeId,
connection: PeerConnection,
inbound_message_tx: mpsc::Sender<InboundMessage>,
messaging_events_tx: broadcast::Sender<MessagingEvent>,
enable_message_received_event: bool,
Expand All @@ -48,14 +48,14 @@ pub struct InboundMessaging {

impl InboundMessaging {
pub fn new(
peer: NodeId,
connection: PeerConnection,
inbound_message_tx: mpsc::Sender<InboundMessage>,
messaging_events_tx: broadcast::Sender<MessagingEvent>,
enable_message_received_event: bool,
shutdown_signal: ShutdownSignal,
) -> Self {
Self {
peer,
connection,
inbound_message_tx,
messaging_events_tx,
enable_message_received_event,
Expand All @@ -65,7 +65,7 @@ impl InboundMessaging {

pub async fn run<S>(mut self, socket: S)
where S: AsyncRead + AsyncWrite + Unpin {
let peer = &self.peer;
let peer = self.connection.peer_node_id();
#[cfg(feature = "metrics")]
metrics::num_sessions().inc();
debug!(
Expand All @@ -75,13 +75,14 @@ impl InboundMessaging {
);

let stream = MessagingProtocol::framed(socket);
let stream = stream.take_until(self.connection.on_disconnect());
tokio::pin!(stream);

while let Either::Right((Some(result), _)) = future::select(self.shutdown_signal.wait(), stream.next()).await {
match result {
Ok(raw_msg) => {
#[cfg(feature = "metrics")]
metrics::inbound_message_count(&self.peer).inc();
metrics::inbound_message_count(self.connection.peer_node_id()).inc();
let msg_len = raw_msg.len();
let inbound_msg = InboundMessage::new(peer.clone(), raw_msg.freeze());
debug!(
Expand Down
26 changes: 20 additions & 6 deletions comms/core/src/protocol/messaging/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ use crate::{
ProtocolId,
ProtocolNotification,
},
PeerConnection,
};

const LOG_TARGET: &str = "comms::protocol::messaging";
Expand Down Expand Up @@ -203,7 +204,9 @@ impl MessagingProtocol {
},

Some(notification) = self.proto_notification.recv() => {
self.handle_protocol_notification(notification);
if let Err(err) = self.handle_protocol_notification(notification).await {
error!(target: LOG_TARGET, "handle_protocol_notification failed: {err}");
}
},

_ = &mut shutdown_signal => {
Expand Down Expand Up @@ -332,7 +335,8 @@ impl MessagingProtocol {
msg_tx
}

fn spawn_inbound_handler(&mut self, peer: NodeId, substream: Substream) {
fn spawn_inbound_handler(&mut self, conn: PeerConnection, substream: Substream) {
let peer = conn.peer_node_id().clone();
if let Some(handle) = self.active_inbound.get(&peer) {
if handle.is_finished() {
self.active_inbound.remove(&peer);
Expand All @@ -347,7 +351,7 @@ impl MessagingProtocol {
let messaging_events_tx = self.messaging_events_tx.clone();
let inbound_message_tx = self.inbound_message_tx.clone();
let inbound_messaging = InboundMessaging::new(
peer.clone(),
conn,
inbound_message_tx,
messaging_events_tx,
self.enable_message_received_event,
Expand All @@ -357,7 +361,10 @@ impl MessagingProtocol {
self.active_inbound.insert(peer, handle);
}

fn handle_protocol_notification(&mut self, notification: ProtocolNotification<Substream>) {
async fn handle_protocol_notification(
&mut self,
notification: ProtocolNotification<Substream>,
) -> Result<(), MessagingProtocolError> {
match notification.event {
// Peer negotiated to speak the messaging protocol with us
ProtocolEvent::NewInboundSubstream(node_id, substream) => {
Expand All @@ -366,10 +373,17 @@ impl MessagingProtocol {
"NewInboundSubstream for peer '{}'",
node_id.short_str()
);

self.spawn_inbound_handler(node_id, substream);
match self.connectivity.get_connection(node_id.clone()).await? {
Some(conn) => {
self.spawn_inbound_handler(conn, substream);
},
None => {
error!(target: LOG_TARGET, "No active connection for new inbound substream for node {node_id}");
},
}
},
}
Ok(())
}

async fn ban_peer<T: Display>(&mut self, peer_node_id: NodeId, reason: T) {
Expand Down
95 changes: 56 additions & 39 deletions comms/core/src/protocol/messaging/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ use crate::{
mocks::{create_connectivity_mock, create_peer_connection_mock_pair, ConnectivityManagerMockState},
node_id,
node_identity::build_node_identity,
transport,
},
types::{CommsDatabase, CommsPublicKey},
};
Expand Down Expand Up @@ -108,34 +107,47 @@ async fn spawn_messaging_protocol() -> (

#[tokio::test]
async fn new_inbound_substream_handling() {
let (peer_manager, _, _, proto_tx, _, mut inbound_msg_rx, mut events_rx, _shutdown) =
let (peer_manager, _, conn_man_mock, proto_tx, outbound_msg_tx, mut inbound_msg_rx, mut events_rx, _shutdown) =
spawn_messaging_protocol().await;

let expected_node_id = node_id::random();
let (_, pk) = CommsPublicKey::random_keypair(&mut OsRng);
peer_manager
.add_peer(Peer::new(
pk.clone(),
expected_node_id.clone(),
MultiaddressesWithStats::default(),
PeerFlags::empty(),
PeerFeatures::COMMUNICATION_CLIENT,
Default::default(),
Default::default(),
))
.await
.unwrap();
let peer1 = Peer::new(
pk.clone(),
expected_node_id.clone(),
MultiaddressesWithStats::default(),
PeerFlags::empty(),
PeerFeatures::COMMUNICATION_CLIENT,
Default::default(),
Default::default(),
);
peer_manager.add_peer(peer1.clone()).await.unwrap();

// Create connected memory sockets - we use each end of the connection as if they exist on different nodes
let (_, muxer_ours, mut muxer_theirs) = transport::build_multiplexed_connections().await;
let (_, pk) = CommsPublicKey::random_keypair(&mut OsRng);
let peer2 = Peer::new(
pk.clone(),
expected_node_id.clone(),
MultiaddressesWithStats::default(),
PeerFlags::empty(),
PeerFeatures::COMMUNICATION_CLIENT,
Default::default(),
Default::default(),
);

let stream_ours = muxer_ours.get_yamux_control().open_stream().await.unwrap();
let (_, conn1_state, conn2, _conn2_state) = create_peer_connection_mock_pair(peer1.clone(), peer2.clone()).await;

let mut framed_ours = MessagingProtocol::framed(stream_ours);
framed_ours.send(TEST_MSG1.clone()).await.unwrap();
conn_man_mock.add_active_connection(conn2).await;

let (reply_tx, _reply_rx) = oneshot::channel();
let out_msg = OutboundMessage {
tag: MessageTag::new(),
reply: reply_tx.into(),
peer_node_id: peer1.node_id.clone(),
body: TEST_MSG1.clone(),
};
outbound_msg_tx.send(out_msg).unwrap();

// Notify the messaging protocol that a new substream has been established that wants to talk the messaging.
let stream_theirs = muxer_theirs.incoming_mut().next().await.unwrap();
let stream_theirs = conn1_state.next_incoming_substream().await.unwrap();
proto_tx
.send(ProtocolNotification::new(
MESSAGING_PROTOCOL_ID.clone(),
Expand Down Expand Up @@ -352,30 +364,35 @@ async fn many_concurrent_send_message_requests_that_fail() {

#[tokio::test]
async fn new_inbound_substream_only_single_session_permitted() {
let (peer_manager, _, _, proto_tx, _, mut inbound_msg_rx, _, _shutdown) = spawn_messaging_protocol().await;
let (peer_manager, node_identity_1, conn_man_mock, proto_tx, _, mut inbound_msg_rx, _, _shutdown) =
spawn_messaging_protocol().await;

let expected_node_id = node_id::random();
let peer1 = node_identity_1.to_peer();

let (_, pk) = CommsPublicKey::random_keypair(&mut OsRng);
peer_manager
.add_peer(Peer::new(
pk.clone(),
expected_node_id.clone(),
MultiaddressesWithStats::default(),
PeerFlags::empty(),
PeerFeatures::COMMUNICATION_CLIENT,
Default::default(),
Default::default(),
))
.await
.unwrap();
let peer2 = Peer::new(
pk.clone(),
expected_node_id.clone(),
MultiaddressesWithStats::default(),
PeerFlags::empty(),
PeerFeatures::COMMUNICATION_CLIENT,
Default::default(),
Default::default(),
);
peer_manager.add_peer(peer2.clone()).await.unwrap();

let (conn1, conn1_state, _, conn2_state) = create_peer_connection_mock_pair(peer1.clone(), peer2.clone()).await;

conn_man_mock.add_active_connection(conn1).await;

// Create connected memory sockets - we use each end of the connection as if they exist on different nodes
let (_, muxer_ours, mut muxer_theirs) = transport::build_multiplexed_connections().await;
// let (_, muxer_ours, mut muxer_theirs) = transport::build_multiplexed_connections().await;
// Spawn a task to deal with incoming substreams
tokio::spawn({
let expected_node_id = expected_node_id.clone();
async move {
while let Some(stream_theirs) = muxer_theirs.incoming_mut().next().await {
while let Some(stream_theirs) = conn2_state.next_incoming_substream().await {
proto_tx
.send(ProtocolNotification::new(
MESSAGING_PROTOCOL_ID.clone(),
Expand All @@ -388,7 +405,7 @@ async fn new_inbound_substream_only_single_session_permitted() {
});

// Open first stream
let stream_ours = muxer_ours.get_yamux_control().open_stream().await.unwrap();
let stream_ours = conn1_state.open_substream().await.unwrap();
let mut framed_ours = MessagingProtocol::framed(stream_ours);
framed_ours.send(TEST_MSG1.clone()).await.unwrap();

Expand All @@ -401,7 +418,7 @@ async fn new_inbound_substream_only_single_session_permitted() {
assert_eq!(in_msg.body, TEST_MSG1);

// Check the second stream closes immediately
let stream_ours2 = muxer_ours.get_yamux_control().open_stream().await.unwrap();
let stream_ours2 = conn1_state.open_substream().await.unwrap();

let mut framed_ours2 = MessagingProtocol::framed(stream_ours2);
// Check that it eventually exits. The first send will initiate the substream and send. Once the other side closes
Expand Down Expand Up @@ -431,7 +448,7 @@ async fn new_inbound_substream_only_single_session_permitted() {
framed_ours.close().await.unwrap();

// Open another one for messaging
let stream_ours = muxer_ours.get_yamux_control().open_stream().await.unwrap();
let stream_ours = conn1_state.open_substream().await.unwrap();
let mut framed_ours = MessagingProtocol::framed(stream_ours);
framed_ours.send(TEST_MSG1.clone()).await.unwrap();

Expand Down

0 comments on commit 47b4877

Please sign in to comment.