From 3663513e41e6aede9e63458bf127aeaa6fd0da58 Mon Sep 17 00:00:00 2001 From: Stanimal Date: Wed, 17 Jun 2020 15:29:09 +0200 Subject: [PATCH] Made messaging pipeline optional in CommsBuilder Messaging pipeline is optional. If a MessageingPipeline is not specified in the builder, the messaging protocol, associated channels and tokio tasks will never be initialized/spawned. This allows a more light-weight comms to be initialized in tests that do not require the messaging pipeline (such as integration tests for components using e.g. RPC only) without the added cruft of setting up a _no-op_ messaging pipeline. --- comms/src/bounded_executor.rs | 4 +- comms/src/builder/comms_node.rs | 129 +++++++++++------- comms/src/builder/error.rs | 5 - comms/src/builder/mod.rs | 56 +------- comms/src/builder/tests.rs | 10 +- comms/src/connection_manager/dialer.rs | 13 +- comms/src/connection_manager/listener.rs | 11 +- comms/src/connection_manager/liveness.rs | 2 +- comms/src/connection_manager/manager.rs | 24 ++-- .../src/connection_manager/peer_connection.rs | 2 +- .../tests/listener_dialer.rs | 17 ++- comms/src/connection_manager/tests/manager.rs | 10 +- comms/src/multiplexing/yamux.rs | 2 +- comms/src/protocol/mod.rs | 2 +- comms/src/protocol/protocols.rs | 17 ++- comms/src/runtime.rs | 2 +- comms/src/test_utils/test_node.rs | 6 +- comms/src/tor/control_client/test_server.rs | 2 +- 18 files changed, 152 insertions(+), 162 deletions(-) diff --git a/comms/src/bounded_executor.rs b/comms/src/bounded_executor.rs index fb8d018398..1e17a88848 100644 --- a/comms/src/bounded_executor.rs +++ b/comms/src/bounded_executor.rs @@ -20,7 +20,7 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::runtime::current_executor; +use crate::runtime::current; use std::{future::Future, sync::Arc}; use tokio::{runtime, sync::Semaphore, task::JoinHandle}; @@ -42,7 +42,7 @@ impl BoundedExecutor { } pub fn from_current(num_permits: usize) -> Self { - Self::new(current_executor(), num_permits) + Self::new(current(), num_permits) } /// Spawn a future onto the Tokio runtime asynchronously blocking if there are too many diff --git a/comms/src/builder/comms_node.rs b/comms/src/builder/comms_node.rs index dd0c0923b0..0cce6de471 100644 --- a/comms/src/builder/comms_node.rs +++ b/comms/src/builder/comms_node.rs @@ -24,14 +24,17 @@ use super::{placeholder::PlaceholderService, CommsBuilderError, CommsShutdown}; use crate::{ backoff::BoxedBackoff, bounded_executor::BoundedExecutor, + builder::consts, connection_manager::{ConnectionManager, ConnectionManagerEvent, ConnectionManagerRequester}, connectivity::{ConnectivityManager, ConnectivityRequester}, message::InboundMessage, multiaddr::Multiaddr, + multiplexing::Substream, peer_manager::{NodeIdentity, PeerManager}, pipeline, - protocol::{messaging, messaging::MessagingProtocol}, + protocol::{messaging, messaging::MessagingProtocol, ProtocolNotifier, Protocols}, runtime, + runtime::task, tor, transports::Transport, }; @@ -58,13 +61,10 @@ pub struct BuiltCommsNode< pub connectivity_requester: ConnectivityRequester, pub messaging_pipeline: Option>, pub node_identity: Arc, - pub messaging: MessagingProtocol, - pub messaging_event_tx: messaging::MessagingEventSender, - pub inbound_message_rx: mpsc::Receiver, pub hidden_service: Option, - pub messaging_request_tx: mpsc::Sender, - pub shutdown: Shutdown, pub peer_manager: Arc, + pub protocols: Protocols, + pub shutdown: Shutdown, } impl BuiltCommsNode @@ -100,11 +100,8 @@ where connectivity_manager: self.connectivity_manager, connectivity_requester: self.connectivity_requester, node_identity: self.node_identity, - messaging: self.messaging, - messaging_event_tx: self.messaging_event_tx, - inbound_message_rx: self.inbound_message_rx, shutdown: self.shutdown, - messaging_request_tx: self.messaging_request_tx, + protocols: self.protocols, hidden_service: self.hidden_service, peer_manager: self.peer_manager, } @@ -131,19 +128,16 @@ where pub async fn spawn(self) -> Result { let BuiltCommsNode { - connection_manager, + mut connection_manager, connection_manager_requester, connection_manager_event_tx, connectivity_manager, connectivity_requester, messaging_pipeline, - messaging_request_tx, - inbound_message_rx, node_identity, shutdown, peer_manager, - messaging, - messaging_event_tx, + mut protocols, hidden_service, } = self; @@ -163,34 +157,48 @@ where "Your node's public address is '{}'", node_identity.public_address() ); - let messaging_pipeline = messaging_pipeline.ok_or(CommsBuilderError::MessagingPiplineNotProvided)?; + let mut complete_signals = Vec::new(); let events_stream = connection_manager_event_tx.subscribe(); - let conn_man_shutdown_signal = connection_manager.complete_signal(); - - let executor = runtime::current_executor(); + complete_signals.push(connection_manager.complete_signal()); // Connectivity manager - executor.spawn(connectivity_manager.create().run()); - executor.spawn(connection_manager.run()); - - // Spawn messaging protocol - let messaging_signal = messaging.complete_signal(); - executor.spawn(messaging.run()); - - // Spawn inbound pipeline - let bounded_executor = BoundedExecutor::new(executor.clone(), messaging_pipeline.max_concurrent_inbound_tasks); - let inbound = pipeline::Inbound::new( - bounded_executor, - inbound_message_rx, - messaging_pipeline.inbound, - shutdown.to_signal(), - ); - executor.spawn(inbound.run()); + task::spawn(connectivity_manager.create().run()); + + let mut messaging_event_tx = None; + if let Some(messaging_pipeline) = messaging_pipeline { + let (messaging, notifier, messaging_request_tx, inbound_message_rx, messaging_event_sender) = + initialize_messaging( + node_identity.clone(), + peer_manager.clone(), + connection_manager_requester.clone(), + shutdown.to_signal(), + ); + messaging_event_tx = Some(messaging_event_sender); + protocols.add(&[messaging::MESSAGING_PROTOCOL.clone()], notifier); + // Spawn messaging protocol + complete_signals.push(messaging.complete_signal()); + task::spawn(messaging.run()); + + // Spawn inbound pipeline + let bounded_executor = + BoundedExecutor::new(runtime::current(), messaging_pipeline.max_concurrent_inbound_tasks); + let inbound = pipeline::Inbound::new( + bounded_executor, + inbound_message_rx, + messaging_pipeline.inbound, + shutdown.to_signal(), + ); + task::spawn(inbound.run()); + + // Spawn outbound pipeline + let outbound = + pipeline::Outbound::new(runtime::current(), messaging_pipeline.outbound, messaging_request_tx); + task::spawn(outbound.run()); + } - // Spawn outbound pipeline - let outbound = pipeline::Outbound::new(executor.clone(), messaging_pipeline.outbound, messaging_request_tx); - executor.spawn(outbound.run()); + connection_manager.set_protocols(protocols); + task::spawn(connection_manager.run()); let listening_addr = Self::wait_listening(events_stream).await?; @@ -202,9 +210,9 @@ where listening_addr, node_identity, peer_manager, - messaging_event_tx, + messaging_event_tx: messaging_event_tx.unwrap_or_else(|| broadcast::channel(1).0), hidden_service, - complete_signals: vec![conn_man_shutdown_signal, messaging_signal], + complete_signals, }) } @@ -218,11 +226,6 @@ where Arc::clone(&self.node_identity) } - /// Return a subscription to OMS events. This will emit events sent _after_ this subscription was created. - pub fn subscribe_messaging_events(&self) -> messaging::MessagingEventReceiver { - self.messaging_event_tx.subscribe() - } - /// Return an owned copy of a ConnectionManagerRequester. Used to initiate connections to peers. pub fn connection_manager_requester(&self) -> ConnectionManagerRequester { self.connection_manager_requester.clone() @@ -298,11 +301,6 @@ impl CommsNode { self.messaging_event_tx.subscribe() } - /// Return a clone of the of the messaging event Sender to allow for other services to create subscriptions - pub fn message_event_sender(&self) -> messaging::MessagingEventSender { - self.messaging_event_tx.clone() - } - /// Return an owned copy of a ConnectionManagerRequester. Used to initiate connections to peers. pub fn connection_manager(&self) -> ConnectionManagerRequester { self.connection_manager_requester.clone() @@ -325,3 +323,34 @@ impl CommsNode { CommsShutdown::new(self.complete_signals) } } + +fn initialize_messaging( + node_identity: Arc, + peer_manager: Arc, + connection_manager_requester: ConnectionManagerRequester, + shutdown_signal: ShutdownSignal, +) -> ( + messaging::MessagingProtocol, + ProtocolNotifier, + mpsc::Sender, + mpsc::Receiver, + messaging::MessagingEventSender, +) +{ + let (proto_tx, proto_rx) = mpsc::channel(consts::MESSAGING_PROTOCOL_EVENTS_BUFFER_SIZE); + let (messaging_request_tx, messaging_request_rx) = mpsc::channel(consts::MESSAGING_REQUEST_BUFFER_SIZE); + let (inbound_message_tx, inbound_message_rx) = mpsc::channel(consts::INBOUND_MESSAGE_BUFFER_SIZE); + let (event_tx, _) = broadcast::channel(consts::MESSAGING_EVENTS_BUFFER_SIZE); + let messaging = MessagingProtocol::new( + connection_manager_requester, + peer_manager, + node_identity, + proto_rx, + messaging_request_rx, + event_tx.clone(), + inbound_message_tx, + shutdown_signal, + ); + + (messaging, proto_tx, messaging_request_tx, inbound_message_rx, event_tx) +} diff --git a/comms/src/builder/error.rs b/comms/src/builder/error.rs index cfad5cc51b..deb14842e1 100644 --- a/comms/src/builder/error.rs +++ b/comms/src/builder/error.rs @@ -33,11 +33,6 @@ pub enum CommsBuilderError { NodeIdentityNotSet, #[error("The PeerStorage was not provided to the CommsBuilder. Use `with_peer_storage` to set it.")] PeerStorageNotProvided, - #[error( - "The messaging pipeline was not provided to the CommsBuilder. Use `with_messaging_pipeline` to set it's \ - pipeline." - )] - MessagingPiplineNotProvided, #[error("Unable to receive a ConnectionManagerEvent within timeout")] ConnectionManagerEventStreamTimeout, #[error("ConnectionManagerEvent stream unexpectedly closed")] diff --git a/comms/src/builder/mod.rs b/comms/src/builder/mod.rs index fa4250d62f..a2027bfe83 100644 --- a/comms/src/builder/mod.rs +++ b/comms/src/builder/mod.rs @@ -57,12 +57,11 @@ use crate::{ ConnectivityRequest, ConnectivityRequester, }, - message::InboundMessage, multiaddr::Multiaddr, multiplexing::Substream, noise::NoiseConfig, peer_manager::{NodeIdentity, PeerManager}, - protocol::{messaging, messaging::MessagingProtocol, ProtocolNotification, Protocols}, + protocol::Protocols, tor, transports::{SocksTransport, TcpTransport, Transport}, types::CommsDatabase, @@ -249,37 +248,6 @@ where self } - fn make_messaging( - &self, - conn_man_requester: ConnectionManagerRequester, - peer_manager: Arc, - node_identity: Arc, - ) -> ( - messaging::MessagingProtocol, - mpsc::Sender>, - mpsc::Sender, - mpsc::Receiver, - messaging::MessagingEventSender, - ) - { - let (proto_tx, proto_rx) = mpsc::channel(consts::MESSAGING_PROTOCOL_EVENTS_BUFFER_SIZE); - let (messaging_request_tx, messaging_request_rx) = mpsc::channel(consts::MESSAGING_REQUEST_BUFFER_SIZE); - let (inbound_message_tx, inbound_message_rx) = mpsc::channel(consts::INBOUND_MESSAGE_BUFFER_SIZE); - let (event_tx, _) = broadcast::channel(consts::MESSAGING_EVENTS_BUFFER_SIZE); - let messaging = MessagingProtocol::new( - conn_man_requester, - peer_manager, - node_identity, - proto_rx, - messaging_request_rx, - event_tx.clone(), - inbound_message_tx, - self.shutdown.to_signal(), - ); - - (messaging, proto_tx, messaging_request_tx, inbound_message_rx, event_tx) - } - fn make_peer_manager(&mut self) -> Result, CommsBuilderError> { match self.peer_storage.take() { Some(storage) => { @@ -294,7 +262,6 @@ where &mut self, node_identity: Arc, peer_manager: Arc, - protocols: Protocols, request_rx: mpsc::Receiver, connection_manager_events_tx: broadcast::Sender>, ) -> ConnectionManager @@ -311,7 +278,6 @@ where request_rx, node_identity, peer_manager, - protocols, connection_manager_events_tx, self.shutdown.to_signal(), ) @@ -352,20 +318,8 @@ where let connection_manager_requester = ConnectionManagerRequester::new(conn_man_tx, connection_manager_event_tx.clone()); - let (messaging, messaging_proto_tx, messaging_request_tx, inbound_message_rx, messaging_event_tx) = self - .make_messaging( - connection_manager_requester.clone(), - peer_manager.clone(), - node_identity.clone(), - ); - //---------------------------------- Protocols --------------------------------------------// - let protocols = self - .protocols - .take() - .or_else(|| Some(Protocols::new())) - .map(move |protocols| protocols.add(&[messaging::MESSAGING_PROTOCOL.clone()], messaging_proto_tx)) - .expect("cannot fail"); + let protocols = self.protocols.take().unwrap_or_default(); //---------------------------------- ConnectivityManager --------------------------------------------// @@ -383,7 +337,6 @@ where let connection_manager = self.make_connection_manager( node_identity.clone(), peer_manager.clone(), - protocols, conn_man_rx, connection_manager_event_tx.clone(), ); @@ -394,13 +347,10 @@ where connection_manager_event_tx, connectivity_manager, connectivity_requester, - messaging_request_tx, messaging_pipeline: None, - messaging, - messaging_event_tx, - inbound_message_rx, node_identity, peer_manager, + protocols, hidden_service: self.hidden_service, shutdown: self.shutdown, }) diff --git a/comms/src/builder/tests.rs b/comms/src/builder/tests.rs index 0329ea2455..274611bd10 100644 --- a/comms/src/builder/tests.rs +++ b/comms/src/builder/tests.rs @@ -58,7 +58,7 @@ async fn spawn_node( let comms_node = CommsBuilder::new() // These calls are just to get rid of unused function warnings. // - .with_executor(runtime::current_executor()) + .with_executor(runtime::current()) .with_dial_backoff(ConstantBackoff::new(Duration::from_millis(500))) .on_shutdown(|| {}) // @@ -103,12 +103,14 @@ async fn peer_to_peer_custom_protocols() { // Setup test protocols let (test_sender, _test_protocol_rx1) = mpsc::channel(10); let (another_test_sender, mut another_test_protocol_rx1) = mpsc::channel(10); - let protocols1 = Protocols::new() + let mut protocols1 = Protocols::new(); + protocols1 .add(&[TEST_PROTOCOL], test_sender) .add(&[ANOTHER_TEST_PROTOCOL], another_test_sender); let (test_sender, mut test_protocol_rx2) = mpsc::channel(10); let (another_test_sender, _another_test_protocol_rx2) = mpsc::channel(10); - let protocols2 = Protocols::new() + let mut protocols2 = Protocols::new(); + protocols2 .add(&[TEST_PROTOCOL], test_sender) .add(&[ANOTHER_TEST_PROTOCOL], another_test_sender); @@ -288,7 +290,7 @@ async fn peer_to_peer_messaging_simultaneous() { .unwrap(); // Simultaneously send messages between the two nodes - let rt_handle = runtime::current_executor(); + let rt_handle = runtime::current(); let handle1 = rt_handle.spawn(async move { for i in 0..NUM_MSGS { let outbound_msg = OutboundMessage::new( diff --git a/comms/src/connection_manager/dialer.rs b/comms/src/connection_manager/dialer.rs index 9994c3caf6..85d9005027 100644 --- a/comms/src/connection_manager/dialer.rs +++ b/comms/src/connection_manager/dialer.rs @@ -84,7 +84,7 @@ pub struct Dialer { conn_man_notifier: mpsc::Sender, shutdown: Option, pending_dial_requests: HashMap>>>, - supported_protocols: Vec, + our_supported_protocols: Vec, } impl Dialer @@ -103,7 +103,6 @@ where backoff: TBackoff, request_rx: mpsc::Receiver, conn_man_notifier: mpsc::Sender, - supported_protocols: Vec, shutdown: ShutdownSignal, ) -> Self { @@ -119,10 +118,16 @@ where conn_man_notifier, shutdown: Some(shutdown), pending_dial_requests: Default::default(), - supported_protocols, + our_supported_protocols: Vec::new(), } } + /// Set the supported protocols of this node to send to peers during the peer identity exchange + pub fn set_supported_protocols(&mut self, our_supported_protocols: Vec) -> &mut Self { + self.our_supported_protocols = our_supported_protocols; + self + } + pub async fn run(mut self) { let mut pending_dials = FuturesUnordered::new(); let mut shutdown = self @@ -274,7 +279,7 @@ where let node_identity = Arc::clone(&self.node_identity); let peer_manager = self.peer_manager.clone(); let conn_man_notifier = self.conn_man_notifier.clone(); - let supported_protocols = self.supported_protocols.clone(); + let supported_protocols = self.our_supported_protocols.clone(); let noise_config = self.noise_config.clone(); let allow_test_addresses = self.config.allow_test_addresses; diff --git a/comms/src/connection_manager/listener.rs b/comms/src/connection_manager/listener.rs index 405ad1934d..cb8000ef00 100644 --- a/comms/src/connection_manager/listener.rs +++ b/comms/src/connection_manager/listener.rs @@ -85,7 +85,6 @@ where conn_man_notifier: mpsc::Sender, peer_manager: Arc, node_identity: Arc, - supported_protocols: Vec, shutdown_signal: ShutdownSignal, ) -> Self { @@ -97,13 +96,19 @@ where node_identity, shutdown_signal, listening_address: None, - our_supported_protocols: supported_protocols, + our_supported_protocols: Vec::new(), bounded_executor: BoundedExecutor::from_current(config.max_simultaneous_inbound_connects), liveness_session_count: Arc::new(AtomicUsize::new(config.liveness_max_sessions)), config, } } + /// Set the supported protocols of this node to send to peers during the peer identity exchange + pub fn set_supported_protocols(&mut self, our_supported_protocols: Vec) -> &mut Self { + self.our_supported_protocols = our_supported_protocols; + self + } + pub async fn run(mut self) { let mut shutdown_signal = self.shutdown_signal.clone(); @@ -182,7 +187,7 @@ where permit.fetch_sub(1, Ordering::SeqCst); let liveness = LivenessSession::new(socket); debug!(target: LOG_TARGET, "Started liveness session"); - runtime::current_executor().spawn(async move { + runtime::current().spawn(async move { future::select(liveness.run(), shutdown_signal).await; permit.fetch_add(1, Ordering::SeqCst); }); diff --git a/comms/src/connection_manager/liveness.rs b/comms/src/connection_manager/liveness.rs index 48b236eb13..350e9c5f85 100644 --- a/comms/src/connection_manager/liveness.rs +++ b/comms/src/connection_manager/liveness.rs @@ -57,7 +57,7 @@ mod test { async fn echos() { let (inbound, outbound) = MemorySocket::new_pair(); let liveness = LivenessSession::new(inbound); - let join_handle = runtime::current_executor().spawn(liveness.run()); + let join_handle = runtime::current().spawn(liveness.run()); let mut outbound = Framed::new(IoCompat::new(outbound), LinesCodec::new()); for _ in 0..10usize { outbound.send("ECHO".to_string()).await.unwrap() diff --git a/comms/src/connection_manager/manager.rs b/comms/src/connection_manager/manager.rs index 4d94d9e3d0..e0d6f7baf7 100644 --- a/comms/src/connection_manager/manager.rs +++ b/comms/src/connection_manager/manager.rs @@ -180,7 +180,6 @@ where request_rx: mpsc::Receiver, node_identity: Arc, peer_manager: Arc, - protocols: Protocols, connection_manager_events_tx: broadcast::Sender>, shutdown_signal: ShutdownSignal, ) -> Self @@ -189,8 +188,6 @@ where let (dialer_tx, dialer_rx) = mpsc::channel(DIALER_REQUEST_CHANNEL_SIZE); - let supported_protocols = protocols.get_supported_protocols(); - let listener = PeerListener::new( config.clone(), transport.clone(), @@ -198,7 +195,6 @@ where internal_event_tx.clone(), peer_manager.clone(), Arc::clone(&node_identity), - supported_protocols.clone(), shutdown_signal.clone(), ); @@ -211,7 +207,6 @@ where backoff, dialer_rx, internal_event_tx, - supported_protocols, shutdown_signal.clone(), ); @@ -221,7 +216,7 @@ where request_rx: request_rx.fuse(), node_identity, peer_manager, - protocols, + protocols: Protocols::new(), internal_event_rx: internal_event_rx.fuse(), dialer_tx, dialer: Some(dialer), @@ -234,6 +229,11 @@ where } } + pub fn set_protocols(&mut self, protocols: Protocols) -> &mut Self { + self.protocols = protocols; + self + } + pub fn complete_signal(&self) -> ShutdownSignal { self.complete_trigger.to_signal() } @@ -295,21 +295,23 @@ where } fn run_listener(&mut self) { - let listener = self + let mut listener = self .listener .take() .expect("ConnectionManager initialized without a listener"); - runtime::current_executor().spawn(listener.run()); + listener.set_supported_protocols(self.protocols.get_supported_protocols()); + runtime::current().spawn(listener.run()); } fn run_dialer(&mut self) { - let dialer = self + let mut dialer = self .dialer .take() .expect("ConnectionManager initialized without a dialer"); - runtime::current_executor().spawn(dialer.run()); + dialer.set_supported_protocols(self.protocols.get_supported_protocols()); + runtime::current().spawn(dialer.run()); } async fn handle_request(&mut self, request: ConnectionManagerRequest) { @@ -542,7 +544,7 @@ where linger.as_millis() ); - runtime::current_executor().spawn(async move { + runtime::current().spawn(async move { debug!( target: LOG_TARGET, "Waiting for linger period ({}ms) to expire...", diff --git a/comms/src/connection_manager/peer_connection.rs b/comms/src/connection_manager/peer_connection.rs index 578cb064fe..edb8e28989 100644 --- a/comms/src/connection_manager/peer_connection.rs +++ b/comms/src/connection_manager/peer_connection.rs @@ -88,7 +88,7 @@ pub fn create( event_notifier, our_supported_protocols, ); - runtime::current_executor().spawn(peer_actor.run()); + runtime::current().spawn(peer_actor.run()); Ok(peer_conn) } diff --git a/comms/src/connection_manager/tests/listener_dialer.rs b/comms/src/connection_manager/tests/listener_dialer.rs index e78cca73aa..9376f25571 100644 --- a/comms/src/connection_manager/tests/listener_dialer.rs +++ b/comms/src/connection_manager/tests/listener_dialer.rs @@ -66,7 +66,6 @@ async fn listen() -> Result<(), Box> { event_tx.clone(), peer_manager.into(), node_identity, - vec![], shutdown.to_signal(), ); @@ -98,7 +97,7 @@ async fn smoke() { let expected_proto = ProtocolId::from_static(b"/tari/test-proto"); let supported_protocols = vec![expected_proto.clone()]; let peer_manager1 = build_peer_manager(); - let listener = PeerListener::new( + let mut listener = PeerListener::new( ConnectionManagerConfig { listener_address: "/memory/0".parse().unwrap(), ..Default::default() @@ -108,9 +107,9 @@ async fn smoke() { event_tx.clone(), peer_manager1.clone().into(), node_identity1.clone(), - supported_protocols.clone(), shutdown.to_signal(), ); + listener.set_supported_protocols(supported_protocols.clone()); let listener_fut = rt_handle.spawn(listener.run()); @@ -118,7 +117,7 @@ async fn smoke() { let noise_config2 = NoiseConfig::new(node_identity2.clone()); let (mut request_tx, request_rx) = mpsc::channel(1); let peer_manager2 = build_peer_manager(); - let dialer = Dialer::new( + let mut dialer = Dialer::new( ConnectionManagerConfig::default(), node_identity2.clone(), peer_manager2.clone().into(), @@ -127,9 +126,9 @@ async fn smoke() { ConstantBackoff::new(Duration::from_millis(100)), request_rx, event_tx, - supported_protocols, shutdown.to_signal(), ); + dialer.set_supported_protocols(supported_protocols.clone()); let dialer_fut = rt_handle.spawn(dialer.run()); @@ -200,7 +199,7 @@ async fn banned() { let expected_proto = ProtocolId::from_static(b"/tari/test-proto"); let supported_protocols = vec![expected_proto.clone()]; let peer_manager1 = build_peer_manager(); - let listener = PeerListener::new( + let mut listener = PeerListener::new( ConnectionManagerConfig { listener_address: "/memory/0".parse().unwrap(), ..Default::default() @@ -210,9 +209,9 @@ async fn banned() { event_tx.clone(), peer_manager1.clone().into(), node_identity1.clone(), - supported_protocols.clone(), shutdown.to_signal(), ); + listener.set_supported_protocols(supported_protocols.clone()); let listener_fut = rt_handle.spawn(listener.run()); @@ -225,7 +224,7 @@ async fn banned() { let noise_config2 = NoiseConfig::new(node_identity2.clone()); let (mut request_tx, request_rx) = mpsc::channel(1); let peer_manager2 = build_peer_manager(); - let dialer = Dialer::new( + let mut dialer = Dialer::new( ConnectionManagerConfig::default(), node_identity2.clone(), peer_manager2.clone().into(), @@ -234,9 +233,9 @@ async fn banned() { ConstantBackoff::new(Duration::from_millis(100)), request_rx, event_tx, - supported_protocols, shutdown.to_signal(), ); + dialer.set_supported_protocols(supported_protocols); let dialer_fut = rt_handle.spawn(dialer.run()); diff --git a/comms/src/connection_manager/tests/manager.rs b/comms/src/connection_manager/tests/manager.rs index 707072c4f7..a715d5266c 100644 --- a/comms/src/connection_manager/tests/manager.rs +++ b/comms/src/connection_manager/tests/manager.rs @@ -65,7 +65,6 @@ async fn connect_to_nonexistent_peer() { request_rx, node_identity, peer_manager.into(), - Protocols::new(), event_tx, shutdown.to_signal(), ); @@ -98,26 +97,31 @@ async fn dial_success() { // Setup connection manager 1 let peer_manager1 = build_peer_manager(); + + let mut protocols = Protocols::new(); + protocols.add([TEST_PROTO], proto_tx1); let mut conn_man1 = build_connection_manager( TestNodeConfig { node_identity: node_identity1.clone(), ..Default::default() }, peer_manager1.clone(), - Protocols::new().add([TEST_PROTO], proto_tx1), + protocols, shutdown.to_signal(), ); conn_man1.wait_until_listening().await.unwrap(); let peer_manager2 = build_peer_manager(); + let mut protocols = Protocols::new(); + protocols.add([TEST_PROTO], proto_tx2); let mut conn_man2 = build_connection_manager( TestNodeConfig { node_identity: node_identity2.clone(), ..Default::default() }, peer_manager2.clone(), - Protocols::new().add([TEST_PROTO], proto_tx2), + protocols, shutdown.to_signal(), ); let mut subscription2 = conn_man2.get_event_subscription(); diff --git a/comms/src/multiplexing/yamux.rs b/comms/src/multiplexing/yamux.rs index 8f4ff6de65..5c3facd5ce 100644 --- a/comms/src/multiplexing/yamux.rs +++ b/comms/src/multiplexing/yamux.rs @@ -99,7 +99,7 @@ impl Yamux { let (incoming_tx, incoming_rx) = mpsc::channel(10); let stream = yamux::into_stream(connection).boxed(); let incoming = IncomingWorker::new(stream, incoming_tx, shutdown.to_signal()); - runtime::current_executor().spawn(incoming.run()); + runtime::current().spawn(incoming.run()); IncomingSubstreams::new(incoming_rx, counter, shutdown) } diff --git a/comms/src/protocol/mod.rs b/comms/src/protocol/mod.rs index 29902bb072..3a0f0a1fb7 100644 --- a/comms/src/protocol/mod.rs +++ b/comms/src/protocol/mod.rs @@ -30,7 +30,7 @@ mod negotiation; pub use negotiation::ProtocolNegotiation; mod protocols; -pub use protocols::{ProtocolEvent, ProtocolNotification, Protocols}; +pub use protocols::{ProtocolEvent, ProtocolNotification, ProtocolNotifier, Protocols}; pub mod messaging; diff --git a/comms/src/protocol/protocols.rs b/comms/src/protocol/protocols.rs index 0c2e9fd8fc..0ccbd8147c 100644 --- a/comms/src/protocol/protocols.rs +++ b/comms/src/protocol/protocols.rs @@ -27,6 +27,8 @@ use crate::{ use futures::{channel::mpsc, SinkExt}; use std::collections::HashMap; +pub type ProtocolNotifier = mpsc::Sender>; + #[derive(Debug, Clone)] pub enum ProtocolEvent { NewInboundSubstream(Box, TSubstream), @@ -45,7 +47,7 @@ impl ProtocolNotification { } pub struct Protocols { - protocols: HashMap>>, + protocols: HashMap>, } impl Clone for Protocols { @@ -69,12 +71,7 @@ impl Protocols { Default::default() } - pub fn add>( - mut self, - protocols: I, - notifier: mpsc::Sender>, - ) -> Self - { + pub fn add>(&mut self, protocols: I, notifier: ProtocolNotifier) -> &mut Self { self.protocols .extend(protocols.as_ref().iter().map(|p| (p.clone(), notifier.clone()))); self @@ -120,7 +117,8 @@ mod test { ProtocolId::from_static(b"/tari/test/1"), ProtocolId::from_static(b"/tari/test/2"), ]; - let protocols = Protocols::<()>::new().add(&protos, tx); + let mut protocols = Protocols::<()>::new(); + protocols.add(&protos, tx); assert!(protocols.get_supported_protocols().iter().all(|p| protos.contains(p))); } @@ -129,7 +127,8 @@ mod test { async fn notify() { let (tx, mut rx) = mpsc::channel(1); let protos = [ProtocolId::from_static(b"/tari/test/1")]; - let mut protocols = Protocols::<()>::new().add(&protos, tx); + let mut protocols = Protocols::<()>::new(); + protocols.add(&protos, tx); protocols .notify( diff --git a/comms/src/runtime.rs b/comms/src/runtime.rs index 1377912666..3d4907dfe7 100644 --- a/comms/src/runtime.rs +++ b/comms/src/runtime.rs @@ -31,6 +31,6 @@ pub use tokio_macros::test_basic; /// Return the current tokio executor. Panics if the tokio runtime is not started. #[inline] -pub fn current_executor() -> runtime::Handle { +pub fn current() -> runtime::Handle { runtime::Handle::current() } diff --git a/comms/src/test_utils/test_node.rs b/comms/src/test_utils/test_node.rs index f8b875c7b0..1ad94555b4 100644 --- a/comms/src/test_utils/test_node.rs +++ b/comms/src/test_utils/test_node.rs @@ -81,7 +81,7 @@ pub fn build_connection_manager( let requester = ConnectionManagerRequester::new(request_tx, event_tx.clone()); - let connection_manager = ConnectionManager::new( + let mut connection_manager = ConnectionManager::new( config.connection_manager_config, config.transport, noise_config, @@ -89,12 +89,12 @@ pub fn build_connection_manager( request_rx, config.node_identity, peer_manager.into(), - protocols, event_tx, shutdown, ); + connection_manager.set_protocols(protocols); - runtime::current_executor().spawn(connection_manager.run()); + runtime::current().spawn(connection_manager.run()); requester } diff --git a/comms/src/tor/control_client/test_server.rs b/comms/src/tor/control_client/test_server.rs index 8e46c9a31b..1feaf9aa1b 100644 --- a/comms/src/tor/control_client/test_server.rs +++ b/comms/src/tor/control_client/test_server.rs @@ -36,7 +36,7 @@ pub async fn spawn() -> (Multiaddr, State, MemorySocket) { let server = TorControlPortTestServer::new(socket_in); let state = server.get_shared_state(); - runtime::current_executor().spawn(server.run()); + runtime::current().spawn(server.run()); (addr, state, socket_out) }