diff --git a/comms/src/builder/mod.rs b/comms/src/builder/mod.rs index 3b9452591d..68b8e0ee15 100644 --- a/comms/src/builder/mod.rs +++ b/comms/src/builder/mod.rs @@ -52,12 +52,13 @@ use crate::{ }, message::InboundMessage, multiaddr::Multiaddr, + multiplexing::Substream, noise::NoiseConfig, peer_manager::{NodeIdentity, PeerManager}, protocol::{messaging, messaging::MessagingProtocol, ProtocolNotification, Protocols}, tor, transports::{SocksTransport, TcpWithTorTransport, Transport}, - types::{CommsDatabase, CommsSubstream}, + types::CommsDatabase, }; use futures::{channel::mpsc, AsyncRead, AsyncWrite}; use log::*; @@ -73,7 +74,7 @@ pub struct CommsBuilder { node_identity: Option>, transport: Option, executor: Option, - protocols: Option>, + protocols: Option>, dial_backoff: Option, hidden_service: Option, connection_manager_config: ConnectionManagerConfig, @@ -220,7 +221,7 @@ where } } - pub fn with_protocols(mut self, protocols: Protocols) -> Self { + pub fn with_protocols(mut self, protocols: Protocols) -> Self { self.protocols = Some(protocols); self } @@ -238,7 +239,7 @@ where node_identity: Arc, ) -> ( messaging::MessagingProtocol, - mpsc::Sender>, + mpsc::Sender>, mpsc::Sender, mpsc::Receiver, messaging::MessagingEventSender, @@ -277,7 +278,7 @@ where &mut self, node_identity: Arc, peer_manager: Arc, - protocols: Protocols, + protocols: Protocols, request_rx: mpsc::Receiver, connection_manager_events_tx: broadcast::Sender>, ) -> ConnectionManager diff --git a/comms/src/builder/tests.rs b/comms/src/builder/tests.rs index 65576e4a40..9f933a7f68 100644 --- a/comms/src/builder/tests.rs +++ b/comms/src/builder/tests.rs @@ -27,6 +27,7 @@ use crate::{ memsocket, message::{InboundMessage, OutboundMessage}, multiaddr::{Multiaddr, Protocol}, + multiplexing::Substream, peer_manager::{Peer, PeerFeatures}, pipeline, pipeline::SinkService, @@ -34,7 +35,6 @@ use crate::{ runtime, test_utils::node_identity::build_node_identity, transports::MemoryTransport, - types::CommsSubstream, CommsNode, }; use bytes::Bytes; @@ -44,7 +44,7 @@ use tari_storage::HashmapDatabase; use tari_test_utils::{collect_stream, unpack_enum}; async fn spawn_node( - protocols: Protocols, + protocols: Protocols, ) -> (CommsNode, mpsc::Receiver, mpsc::Sender) { let addr = format!("/memory/{}", memsocket::acquire_next_memsocket_port()) .parse::() diff --git a/comms/src/connection_manager/manager.rs b/comms/src/connection_manager/manager.rs index 6b42e6ac45..0356a7e3f8 100644 --- a/comms/src/connection_manager/manager.rs +++ b/comms/src/connection_manager/manager.rs @@ -30,6 +30,7 @@ use super::{ }; use crate::{ backoff::Backoff, + multiplexing::Substream, noise::NoiseConfig, peer_manager::{NodeId, NodeIdentity}, protocol::{ProtocolEvent, ProtocolId, Protocols}, @@ -72,7 +73,7 @@ pub enum ConnectionManagerEvent { ListenFailed(ConnectionManagerError), // Substreams - NewInboundSubstream(Box, ProtocolId, yamux::Stream), + NewInboundSubstream(Box, ProtocolId, Substream), } impl fmt::Display for ConnectionManagerEvent { @@ -157,7 +158,7 @@ pub struct ConnectionManager { node_identity: Arc, active_connections: HashMap, shutdown_signal: Option, - protocols: Protocols, + protocols: Protocols, listener_address: Option, listening_notifiers: Vec>, connection_manager_events_tx: broadcast::Sender>, @@ -179,7 +180,7 @@ where request_rx: mpsc::Receiver, node_identity: Arc, peer_manager: Arc, - protocols: Protocols, + protocols: Protocols, connection_manager_events_tx: broadcast::Sender>, shutdown_signal: ShutdownSignal, ) -> Self diff --git a/comms/src/connection_manager/peer_connection.rs b/comms/src/connection_manager/peer_connection.rs index 3d96cba28e..31a5252882 100644 --- a/comms/src/connection_manager/peer_connection.rs +++ b/comms/src/connection_manager/peer_connection.rs @@ -26,11 +26,10 @@ use super::{ types::ConnectionDirection, }; use crate::{ - multiplexing::{IncomingSubstreams, Yamux}, + multiplexing::{Control, IncomingSubstreams, Substream, Yamux}, peer_manager::{NodeId, Peer, PeerFeatures}, protocol::{ProtocolId, ProtocolNegotiation}, runtime, - types::CommsSubstream, }; use futures::{ channel::{mpsc, oneshot}, @@ -92,7 +91,7 @@ pub enum PeerConnectionRequest { /// Open a new substream and negotiate the given protocol OpenSubstream( ProtocolId, - oneshot::Sender, PeerConnectionError>>, + oneshot::Sender, PeerConnectionError>>, ), /// Disconnect all substreams and close the transport connection Disconnect(bool, oneshot::Sender<()>), @@ -167,7 +166,7 @@ impl PeerConnection { pub async fn open_substream( &mut self, protocol_id: &ProtocolId, - ) -> Result, PeerConnectionError> + ) -> Result, PeerConnectionError> { let (reply_tx, reply_rx) = oneshot::channel(); self.request_tx @@ -220,7 +219,7 @@ pub struct PeerConnectionActor { direction: ConnectionDirection, incoming_substreams: Fuse, substream_shutdown: Option, - control: yamux::Control, + control: Control, event_notifier: mpsc::Sender, supported_protocols: Vec, shutdown: bool, @@ -309,7 +308,7 @@ impl PeerConnectionActor { } } - async fn handle_incoming_substream(&mut self, mut stream: yamux::Stream) -> Result<(), PeerConnectionError> { + async fn handle_incoming_substream(&mut self, mut stream: Substream) -> Result<(), PeerConnectionError> { let selected_protocol = ProtocolNegotiation::new(&mut stream) .negotiate_protocol_inbound(&self.supported_protocols) .await?; @@ -327,7 +326,7 @@ impl PeerConnectionActor { async fn open_negotiated_protocol_stream( &mut self, protocol: ProtocolId, - ) -> Result, PeerConnectionError> + ) -> Result, PeerConnectionError> { debug!( target: LOG_TARGET, diff --git a/comms/src/multiplexing/mod.rs b/comms/src/multiplexing/mod.rs index ad321f7a41..8ac082d5c4 100644 --- a/comms/src/multiplexing/mod.rs +++ b/comms/src/multiplexing/mod.rs @@ -21,4 +21,4 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. mod yamux; -pub use self::yamux::{Control, IncomingSubstreams, Yamux}; +pub use self::yamux::{ConnectionError, Control, IncomingSubstreams, Substream, Yamux}; diff --git a/comms/src/multiplexing/yamux.rs b/comms/src/multiplexing/yamux.rs index 00fc3affaf..79a0b2bfdb 100644 --- a/comms/src/multiplexing/yamux.rs +++ b/comms/src/multiplexing/yamux.rs @@ -33,20 +33,22 @@ use futures::{ StreamExt, }; use log::*; -use std::{io, pin::Pin, task::Poll}; +use std::{future::Future, io, pin::Pin, sync::Arc, task::Poll}; use tari_shutdown::{Shutdown, ShutdownSignal}; use yamux::Mode; type IncomingRx = mpsc::Receiver; type IncomingTx = mpsc::Sender; -pub type Control = yamux::Control; +// Reexport +pub use yamux::ConnectionError; const LOG_TARGET: &str = "comms::multiplexing::yamux"; pub struct Yamux { control: Control, incoming: IncomingSubstreams, + substream_counter: SubstreamCounter, } const MAX_BUFFER_SIZE: u32 = 8 * 1024 * 1024; // 8MB @@ -72,28 +74,37 @@ impl Yamux { config.set_max_buffer_size(MAX_BUFFER_SIZE as usize); config.set_receive_window(RECEIVE_WINDOW); + let substream_counter = SubstreamCounter::new(); let connection = yamux::Connection::new(socket, config, mode); - let control = connection.control(); - - let incoming = Self::spawn_incoming_stream_worker(connection); - - Ok(Self { control, incoming }) + let control = Control::new(connection.control(), substream_counter.clone()); + let incoming = Self::spawn_incoming_stream_worker(connection, substream_counter.clone()); + + Ok(Self { + control, + incoming, + substream_counter, + }) } // yamux@0.4 requires the incoming substream stream be polled in order to make progress on requests from it's // Control api. Here we spawn off a worker which will do this job - fn spawn_incoming_stream_worker(connection: yamux::Connection) -> IncomingSubstreams - where TSocket: AsyncRead + AsyncWrite + Unpin + Send + 'static { + fn spawn_incoming_stream_worker( + connection: yamux::Connection, + counter: SubstreamCounter, + ) -> IncomingSubstreams + where + TSocket: AsyncRead + AsyncWrite + Unpin + Send + 'static, + { let shutdown = Shutdown::new(); let (incoming_tx, incoming_rx) = mpsc::channel(10); let stream = yamux::into_stream(connection).boxed(); let incoming = IncomingWorker::new(stream, incoming_tx, shutdown.to_signal()); runtime::current_executor().spawn(incoming.run()); - IncomingSubstreams::new(incoming_rx, shutdown) + IncomingSubstreams::new(incoming_rx, counter, shutdown) } /// Get the yamux control struct - pub fn get_yamux_control(&self) -> yamux::Control { + pub fn get_yamux_control(&self) -> Control { self.control.clone() } @@ -107,19 +118,66 @@ impl Yamux { self.incoming } + /// Return the number of active substreams + pub fn substream_count(&self) -> usize { + self.substream_counter.count() + } + pub fn is_terminated(&self) -> bool { self.incoming.is_terminated() } } +#[derive(Clone)] +pub struct Control { + inner: yamux::Control, + substream_counter: SubstreamCounter, +} + +impl Control { + pub fn new(inner: yamux::Control, substream_counter: SubstreamCounter) -> Self { + Self { + inner, + substream_counter, + } + } + + /// Open a new stream to the remote. + pub async fn open_stream(&mut self) -> Result { + let stream = self.inner.open_stream().await?; + Ok(Substream { + stream, + counter_guard: self.substream_counter.new_guard(), + }) + } + + /// Close the connection. + pub fn close(&mut self) -> impl Future> + '_ { + self.inner.close() + } + + pub fn substream_count(&self) -> usize { + self.substream_counter.count() + } +} + pub struct IncomingSubstreams { inner: IncomingRx, + substream_counter: SubstreamCounter, shutdown: Shutdown, } impl IncomingSubstreams { - pub fn new(inner: IncomingRx, shutdown: Shutdown) -> Self { - Self { inner, shutdown } + pub fn new(inner: IncomingRx, substream_counter: SubstreamCounter, shutdown: Shutdown) -> Self { + Self { + inner, + substream_counter, + shutdown, + } + } + + pub fn substream_count(&self) -> usize { + self.substream_counter.count() } } @@ -130,10 +188,16 @@ impl FusedStream for IncomingSubstreams { } impl Stream for IncomingSubstreams { - type Item = yamux::Stream; + type Item = Substream; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.inner).poll_next(cx) + match futures::ready!(Pin::new(&mut self.inner).poll_next(cx)) { + Some(stream) => Poll::Ready(Some(Substream { + stream, + counter_guard: self.substream_counter.new_guard(), + })), + None => Poll::Ready(None), + } } } @@ -143,6 +207,32 @@ impl Drop for IncomingSubstreams { } } +#[derive(Debug)] +pub struct Substream { + stream: yamux::Stream, + counter_guard: CounterGuard, +} + +impl AsyncRead for Substream { + fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { + Pin::new(&mut self.stream).poll_read(cx, buf) + } +} + +impl AsyncWrite for Substream { + fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + Pin::new(&mut self.stream).poll_write(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.stream).poll_flush(cx) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.stream).poll_close(cx) + } +} + struct IncomingWorker { inner: S, sender: mpsc::Sender, @@ -209,6 +299,28 @@ where S: Stream> + Unpin } } +pub type CounterGuard = Arc<()>; +#[derive(Debug, Clone, Default)] +pub struct SubstreamCounter(Arc); + +impl SubstreamCounter { + pub fn new() -> Self { + Default::default() + } + + /// Create a new CounterGuard. Each of these counts 1 in the substream count + /// until it is dropped. + pub fn new_guard(&self) -> CounterGuard { + Arc::clone(&*self.0) + } + + /// Get the substream count + pub fn count(&self) -> usize { + // Substract one to account for the initial CounterGuard reference + Arc::strong_count(&*self.0) - 1 + } +} + #[cfg(test)] mod test { use crate::{connection_manager::ConnectionDirection, memsocket::MemorySocket, multiplexing::yamux::Yamux}; @@ -217,21 +329,21 @@ mod test { io::{AsyncReadExt, AsyncWriteExt}, StreamExt, }; - use std::io; - use tokio::runtime::Handle; + use std::{io, time::Duration}; + use tari_test_utils::collect_stream; + use tokio::task; #[tokio_macros::test_basic] async fn open_substream() -> io::Result<()> { let (dialer, listener) = MemorySocket::new_pair(); let msg = b"The Way of Kings"; - let rt_handle = Handle::current(); let dialer = Yamux::upgrade_connection(dialer, ConnectionDirection::Outbound) .await .unwrap(); let mut dialer_control = dialer.get_yamux_control(); - rt_handle.spawn(async move { + task::spawn(async move { let mut substream = dialer_control.open_stream().await.unwrap(); substream.write_all(msg).await.unwrap(); @@ -254,16 +366,49 @@ mod test { Ok(()) } + #[tokio_macros::test_basic] + async fn substream_count() { + const NUM_SUBSTREAMS: usize = 10; + let (dialer, listener) = MemorySocket::new_pair(); + + let dialer = Yamux::upgrade_connection(dialer, ConnectionDirection::Outbound) + .await + .unwrap(); + let mut dialer_control = dialer.get_yamux_control(); + + let substreams_out = task::spawn(async move { + let mut substreams = Vec::with_capacity(NUM_SUBSTREAMS); + for _ in 0..NUM_SUBSTREAMS { + substreams.push(dialer_control.open_stream().await.unwrap()); + } + substreams + }); + + let mut listener = Yamux::upgrade_connection(listener, ConnectionDirection::Inbound) + .await + .unwrap() + .incoming(); + let substreams_in = collect_stream!(&mut listener, take = NUM_SUBSTREAMS, timeout = Duration::from_secs(10)); + + assert_eq!(dialer.substream_count(), NUM_SUBSTREAMS); + assert_eq!(listener.substream_count(), NUM_SUBSTREAMS); + + drop(substreams_in); + drop(substreams_out); + + assert_eq!(dialer.substream_count(), 0); + assert_eq!(listener.substream_count(), 0); + } + #[tokio_macros::test_basic] async fn close() -> io::Result<()> { let (dialer, listener) = MemorySocket::new_pair(); let msg = b"Words of Radiance"; - let rt_handle = Handle::current(); let dialer = Yamux::upgrade_connection(dialer, ConnectionDirection::Outbound).await?; let mut dialer_control = dialer.get_yamux_control(); - rt_handle.spawn(async move { + task::spawn(async move { let mut substream = dialer_control.open_stream().await.unwrap(); substream.write_all(msg).await.unwrap(); @@ -278,9 +423,6 @@ mod test { .await? .incoming(); let mut substream = incoming.next().await.unwrap(); - rt_handle.spawn(async move { - incoming.next().await; - }); let mut buf = vec![0; msg.len()]; substream.read_exact(&mut buf).await?; @@ -300,7 +442,6 @@ mod test { #[tokio_macros::test_basic] async fn send_big_message() -> io::Result<()> { - let rt_handle = Handle::current(); #[allow(non_upper_case_globals)] static MiB: usize = 1 << 20; static MSG_LEN: usize = 16 * MiB; @@ -309,13 +450,11 @@ mod test { let dialer = Yamux::upgrade_connection(dialer, ConnectionDirection::Outbound).await?; let mut dialer_control = dialer.get_yamux_control(); - // The incoming stream must be polled for the control to work - rt_handle.spawn(async move { - dialer.incoming().next().await; - }); - rt_handle.spawn(async move { + task::spawn(async move { + assert_eq!(dialer_control.substream_count(), 0); let mut substream = dialer_control.open_stream().await.unwrap(); + assert_eq!(dialer_control.substream_count(), 1); let msg = vec![0x55u8; MSG_LEN]; substream.write_all(msg.as_slice()).await.unwrap(); @@ -331,10 +470,9 @@ mod test { let mut incoming = Yamux::upgrade_connection(listener, ConnectionDirection::Inbound) .await? .incoming(); + assert_eq!(incoming.substream_count(), 0); let mut substream = incoming.next().await.unwrap(); - rt_handle.spawn(async move { - incoming.next().await; - }); + assert_eq!(incoming.substream_count(), 1); let mut buf = vec![0u8; MSG_LEN]; substream.read_exact(&mut buf).await?; @@ -343,6 +481,9 @@ mod test { let msg = vec![0xAAu8; MSG_LEN]; substream.write_all(msg.as_slice()).await?; substream.close().await?; + drop(substream); + + assert_eq!(incoming.substream_count(), 0); Ok(()) } diff --git a/comms/src/protocol/messaging/inbound.rs b/comms/src/protocol/messaging/inbound.rs new file mode 100644 index 0000000000..adac9e894b --- /dev/null +++ b/comms/src/protocol/messaging/inbound.rs @@ -0,0 +1,118 @@ +// Copyright 2020, The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +use crate::{ + message::InboundMessage, + peer_manager::Peer, + protocol::messaging::{MessagingEvent, MessagingProtocol}, +}; +use futures::{channel::mpsc, AsyncRead, AsyncWrite, SinkExt, StreamExt}; +use log::*; +use std::sync::Arc; +use tokio::sync::broadcast; + +const LOG_TARGET: &str = "comms::protocol::messaging::inbound"; + +pub struct InboundMessaging { + peer: Arc, + inbound_message_tx: mpsc::Sender, + messaging_events_tx: broadcast::Sender>, +} + +impl InboundMessaging { + pub fn new( + peer: Arc, + inbound_message_tx: mpsc::Sender, + messaging_events_tx: broadcast::Sender>, + ) -> Self + { + Self { + peer, + inbound_message_tx, + messaging_events_tx, + } + } + + pub async fn run(mut self, socket: S) + where S: AsyncRead + AsyncWrite + Unpin { + let mut framed_socket = MessagingProtocol::framed(socket); + let peer = &self.peer; + while let Some(result) = framed_socket.next().await { + match result { + Ok(raw_msg) => { + trace!( + target: LOG_TARGET, + "Received message from peer '{}' ({} bytes)", + peer.node_id.short_str(), + raw_msg.len() + ); + + let inbound_msg = InboundMessage::new(Arc::clone(&peer), raw_msg.freeze()); + + let event = MessagingEvent::MessageReceived( + Box::new(inbound_msg.source_peer.node_id.clone()), + inbound_msg.tag, + ); + + if let Err(err) = self.inbound_message_tx.send(inbound_msg).await { + warn!( + target: LOG_TARGET, + "Failed to send InboundMessage for peer '{}' because '{}'", + peer.node_id.short_str(), + err + ); + + if err.is_disconnected() { + break; + } + } + + trace!(target: LOG_TARGET, "Inbound handler sending event '{}'", event); + if let Err(err) = self.messaging_events_tx.send(Arc::new(event)) { + trace!( + target: LOG_TARGET, + "Messaging event '{}' not sent for peer '{}' because there are no subscribers. \ + MessagingEvent dropped", + err.0, + peer.node_id.short_str(), + ); + } + }, + Err(err) => { + error!( + target: LOG_TARGET, + "Failed to receive from peer '{}' because '{}'", + peer.node_id.short_str(), + err + ); + break; + }, + } + } + + debug!( + target: LOG_TARGET, + "Inbound messaging handler for peer '{}' has stopped", + peer.node_id.short_str() + ); + } +} diff --git a/comms/src/protocol/messaging/mod.rs b/comms/src/protocol/messaging/mod.rs index af1f6d20a8..3281f670cb 100644 --- a/comms/src/protocol/messaging/mod.rs +++ b/comms/src/protocol/messaging/mod.rs @@ -21,6 +21,7 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. mod error; +mod inbound; mod outbound; mod protocol; diff --git a/comms/src/protocol/messaging/outbound.rs b/comms/src/protocol/messaging/outbound.rs index bb1038914d..c1890be978 100644 --- a/comms/src/protocol/messaging/outbound.rs +++ b/comms/src/protocol/messaging/outbound.rs @@ -24,8 +24,8 @@ use super::{error::MessagingProtocolError, MessagingEvent, MessagingProtocol, Se use crate::{ connection_manager::{ConnectionManagerError, ConnectionManagerRequester, NegotiatedSubstream, PeerConnection}, message::OutboundMessage, + multiplexing::Substream, peer_manager::{NodeId, NodeIdentity}, - types::CommsSubstream, }; use futures::{channel::mpsc, SinkExt, StreamExt}; use log::*; @@ -104,7 +104,7 @@ impl OutboundMessaging { async fn try_open_substream( &mut self, mut conn: PeerConnection, - ) -> Result, MessagingProtocolError> + ) -> Result, MessagingProtocolError> { match conn.open_substream(&MESSAGING_PROTOCOL).await { Ok(substream) => Ok(substream), @@ -122,7 +122,7 @@ impl OutboundMessaging { } } - async fn start_forwarding_messages(mut self, substream: CommsSubstream) -> Result<(), MessagingProtocolError> { + async fn start_forwarding_messages(mut self, substream: Substream) -> Result<(), MessagingProtocolError> { let mut framed = MessagingProtocol::framed(substream); while let Some(mut out_msg) = self.request_rx.next().await { trace!( diff --git a/comms/src/protocol/messaging/protocol.rs b/comms/src/protocol/messaging/protocol.rs index e8238bddb9..cbcc9d3806 100644 --- a/comms/src/protocol/messaging/protocol.rs +++ b/comms/src/protocol/messaging/protocol.rs @@ -25,10 +25,14 @@ use crate::{ compat::IoCompat, connection_manager::{ConnectionManagerEvent, ConnectionManagerRequester}, message::{InboundMessage, MessageTag, OutboundMessage}, + multiplexing::Substream, peer_manager::{NodeId, NodeIdentity, Peer, PeerManagerError}, - protocol::{messaging::outbound::OutboundMessaging, ProtocolEvent, ProtocolNotification}, + protocol::{ + messaging::{inbound::InboundMessaging, outbound::OutboundMessaging}, + ProtocolEvent, + ProtocolNotification, + }, runtime::current_executor, - types::CommsSubstream, PeerManager, }; use bytes::Bytes; @@ -94,7 +98,7 @@ pub struct MessagingProtocol { connection_manager_requester: ConnectionManagerRequester, node_identity: Arc, peer_manager: Arc, - proto_notification: Fuse>>, + proto_notification: Fuse>>, active_queues: HashMap, mpsc::UnboundedSender>, request_rx: Fuse>, messaging_events_tx: MessagingEventSender, @@ -115,7 +119,7 @@ impl MessagingProtocol { connection_manager_requester: ConnectionManagerRequester, peer_manager: Arc, node_identity: Arc, - proto_notification: mpsc::Receiver>, + proto_notification: mpsc::Receiver>, request_rx: mpsc::Receiver, messaging_events_tx: MessagingEventSender, inbound_message_tx: mpsc::Sender, @@ -346,80 +350,20 @@ impl MessagingProtocol { ) -> Result, MessagingProtocolError> { let (msg_tx, msg_rx) = mpsc::unbounded(); - executor.spawn( - OutboundMessaging::new(conn_man_requester, our_node_identity, events_tx, msg_rx, peer_node_id).run(), - ); + let outbound_messaging = + OutboundMessaging::new(conn_man_requester, our_node_identity, events_tx, msg_rx, peer_node_id); + executor.spawn(outbound_messaging.run()); Ok(msg_tx) } - async fn spawn_inbound_handler(&mut self, peer: Arc, substream: CommsSubstream) { + async fn spawn_inbound_handler(&mut self, peer: Arc, substream: Substream) { let messaging_events_tx = self.messaging_events_tx.clone(); - let mut inbound_message_tx = self.inbound_message_tx.clone(); - let mut framed_substream = Self::framed(substream); - - self.executor.spawn(async move { - while let Some(result) = framed_substream.next().await { - match result { - Ok(raw_msg) => { - trace!( - target: LOG_TARGET, - "Received message from peer '{}' ({} bytes)", - peer.node_id.short_str(), - raw_msg.len() - ); - - let inbound_msg = InboundMessage::new(Arc::clone(&peer), raw_msg.freeze()); - - let event = MessagingEvent::MessageReceived( - Box::new(inbound_msg.source_peer.node_id.clone()), - inbound_msg.tag, - ); - - if let Err(err) = inbound_message_tx.send(inbound_msg).await { - warn!( - target: LOG_TARGET, - "Failed to send InboundMessage for peer '{}' because '{}'", - peer.node_id.short_str(), - err - ); - - if err.is_disconnected() { - break; - } - } - - trace!(target: LOG_TARGET, "Inbound handler sending event '{}'", event); - if let Err(err) = messaging_events_tx.send(Arc::new(event)) { - debug!( - target: LOG_TARGET, - "Messaging event '{}' not sent for peer '{}' because there are no subscribers. \ - MessagingEvent dropped", - err.0, - peer.node_id.short_str(), - ); - } - }, - Err(err) => { - error!( - target: LOG_TARGET, - "Failed to receive from peer '{}' because '{}'", - peer.node_id.short_str(), - err - ); - break; - }, - } - } - - debug!( - target: LOG_TARGET, - "Inbound messaging handler for peer '{}' has stopped", - peer.node_id.short_str() - ); - }); + let inbound_message_tx = self.inbound_message_tx.clone(); + let inbound_messaging = InboundMessaging::new(peer, inbound_message_tx, messaging_events_tx); + self.executor.spawn(inbound_messaging.run(substream)); } - async fn handle_notification(&mut self, notification: ProtocolNotification) { + async fn handle_notification(&mut self, notification: ProtocolNotification) { debug_assert_eq!(notification.protocol, MESSAGING_PROTOCOL); match notification.event { // Peer negotiated to speak the messaging protocol with us @@ -437,7 +381,6 @@ impl MessagingProtocol { }, Err(PeerManagerError::PeerNotFoundError) => { // This should never happen if everything is working correctly - error!( target: LOG_TARGET, "[ThisNode={}] *** Could not find verified node_id '{}' in peer list. This should not \ diff --git a/comms/src/protocol/messaging/test.rs b/comms/src/protocol/messaging/test.rs index 51ed8ef475..45bc1e191b 100644 --- a/comms/src/protocol/messaging/test.rs +++ b/comms/src/protocol/messaging/test.rs @@ -29,6 +29,7 @@ use super::protocol::{ }; use crate::{ message::{InboundMessage, MessageTag, OutboundMessage}, + multiplexing::Substream, net_address::MultiaddressesWithStats, peer_manager::{NodeId, NodeIdentity, Peer, PeerFeatures, PeerFlags, PeerManager}, protocol::{messaging::SendFailReason, ProtocolEvent, ProtocolNotification}, @@ -38,7 +39,7 @@ use crate::{ node_identity::build_node_identity, transport, }, - types::{CommsDatabase, CommsPublicKey, CommsSubstream}, + types::{CommsDatabase, CommsPublicKey}, }; use bytes::Bytes; use futures::{ @@ -62,7 +63,7 @@ async fn spawn_messaging_protocol() -> ( Arc, Arc, ConnectionManagerMockState, - mpsc::Sender>, + mpsc::Sender>, mpsc::Sender, mpsc::Receiver, MessagingEventReceiver, diff --git a/comms/src/test_utils/mocks/peer_connection.rs b/comms/src/test_utils/mocks/peer_connection.rs index b3540d0040..90a0edeb36 100644 --- a/comms/src/test_utils/mocks/peer_connection.rs +++ b/comms/src/test_utils/mocks/peer_connection.rs @@ -29,7 +29,7 @@ use crate::{ PeerConnectionRequest, }, multiplexing, - multiplexing::{IncomingSubstreams, Yamux}, + multiplexing::{IncomingSubstreams, Substream, Yamux}, peer_manager::Peer, test_utils::transport, }; @@ -102,11 +102,11 @@ impl PeerConnectionMockState { self.call_count.load(Ordering::SeqCst) } - pub async fn open_substream(&self) -> Result { + pub async fn open_substream(&self) -> Result { self.mux_control.lock().await.open_stream().await.map_err(Into::into) } - pub async fn next_incoming_substream(&self) -> Option { + pub async fn next_incoming_substream(&self) -> Option { self.mux_incoming.lock().await.next().await } diff --git a/comms/src/test_utils/test_node.rs b/comms/src/test_utils/test_node.rs index a83bc207d0..f8b875c7b0 100644 --- a/comms/src/test_utils/test_node.rs +++ b/comms/src/test_utils/test_node.rs @@ -23,6 +23,7 @@ use crate::{ backoff::ConstantBackoff, connection_manager::{ConnectionManager, ConnectionManagerConfig, ConnectionManagerRequester}, + multiplexing::Substream, noise::NoiseConfig, peer_manager::{NodeIdentity, PeerFeatures, PeerManager}, protocol::Protocols, @@ -70,7 +71,7 @@ impl Default for TestNodeConfig { pub fn build_connection_manager( config: TestNodeConfig, peer_manager: Arc, - protocols: Protocols, + protocols: Protocols, shutdown: ShutdownSignal, ) -> ConnectionManagerRequester { diff --git a/comms/src/types.rs b/comms/src/types.rs index ec062d78d4..aa3e6e3105 100644 --- a/comms/src/types.rs +++ b/comms/src/types.rs @@ -49,5 +49,3 @@ pub type CommsDataStore = LMDBStore; pub type CommsDatabase = LMDBWrapper; #[cfg(test)] pub type CommsDatabase = HashmapDatabase; - -pub type CommsSubstream = yamux::Stream;