diff --git a/comms/src/bounded_executor.rs b/comms/src/bounded_executor.rs index fb8d018398..1e17a88848 100644 --- a/comms/src/bounded_executor.rs +++ b/comms/src/bounded_executor.rs @@ -20,7 +20,7 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::runtime::current_executor; +use crate::runtime::current; use std::{future::Future, sync::Arc}; use tokio::{runtime, sync::Semaphore, task::JoinHandle}; @@ -42,7 +42,7 @@ impl BoundedExecutor { } pub fn from_current(num_permits: usize) -> Self { - Self::new(current_executor(), num_permits) + Self::new(current(), num_permits) } /// Spawn a future onto the Tokio runtime asynchronously blocking if there are too many diff --git a/comms/src/builder/comms_node.rs b/comms/src/builder/comms_node.rs index dd0c0923b0..0cce6de471 100644 --- a/comms/src/builder/comms_node.rs +++ b/comms/src/builder/comms_node.rs @@ -24,14 +24,17 @@ use super::{placeholder::PlaceholderService, CommsBuilderError, CommsShutdown}; use crate::{ backoff::BoxedBackoff, bounded_executor::BoundedExecutor, + builder::consts, connection_manager::{ConnectionManager, ConnectionManagerEvent, ConnectionManagerRequester}, connectivity::{ConnectivityManager, ConnectivityRequester}, message::InboundMessage, multiaddr::Multiaddr, + multiplexing::Substream, peer_manager::{NodeIdentity, PeerManager}, pipeline, - protocol::{messaging, messaging::MessagingProtocol}, + protocol::{messaging, messaging::MessagingProtocol, ProtocolNotifier, Protocols}, runtime, + runtime::task, tor, transports::Transport, }; @@ -58,13 +61,10 @@ pub struct BuiltCommsNode< pub connectivity_requester: ConnectivityRequester, pub messaging_pipeline: Option>, pub node_identity: Arc, - pub messaging: MessagingProtocol, - pub messaging_event_tx: messaging::MessagingEventSender, - pub inbound_message_rx: mpsc::Receiver, pub hidden_service: Option, - pub messaging_request_tx: mpsc::Sender, - pub shutdown: Shutdown, pub peer_manager: Arc, + pub protocols: Protocols, + pub shutdown: Shutdown, } impl BuiltCommsNode @@ -100,11 +100,8 @@ where connectivity_manager: self.connectivity_manager, connectivity_requester: self.connectivity_requester, node_identity: self.node_identity, - messaging: self.messaging, - messaging_event_tx: self.messaging_event_tx, - inbound_message_rx: self.inbound_message_rx, shutdown: self.shutdown, - messaging_request_tx: self.messaging_request_tx, + protocols: self.protocols, hidden_service: self.hidden_service, peer_manager: self.peer_manager, } @@ -131,19 +128,16 @@ where pub async fn spawn(self) -> Result { let BuiltCommsNode { - connection_manager, + mut connection_manager, connection_manager_requester, connection_manager_event_tx, connectivity_manager, connectivity_requester, messaging_pipeline, - messaging_request_tx, - inbound_message_rx, node_identity, shutdown, peer_manager, - messaging, - messaging_event_tx, + mut protocols, hidden_service, } = self; @@ -163,34 +157,48 @@ where "Your node's public address is '{}'", node_identity.public_address() ); - let messaging_pipeline = messaging_pipeline.ok_or(CommsBuilderError::MessagingPiplineNotProvided)?; + let mut complete_signals = Vec::new(); let events_stream = connection_manager_event_tx.subscribe(); - let conn_man_shutdown_signal = connection_manager.complete_signal(); - - let executor = runtime::current_executor(); + complete_signals.push(connection_manager.complete_signal()); // Connectivity manager - executor.spawn(connectivity_manager.create().run()); - executor.spawn(connection_manager.run()); - - // Spawn messaging protocol - let messaging_signal = messaging.complete_signal(); - executor.spawn(messaging.run()); - - // Spawn inbound pipeline - let bounded_executor = BoundedExecutor::new(executor.clone(), messaging_pipeline.max_concurrent_inbound_tasks); - let inbound = pipeline::Inbound::new( - bounded_executor, - inbound_message_rx, - messaging_pipeline.inbound, - shutdown.to_signal(), - ); - executor.spawn(inbound.run()); + task::spawn(connectivity_manager.create().run()); + + let mut messaging_event_tx = None; + if let Some(messaging_pipeline) = messaging_pipeline { + let (messaging, notifier, messaging_request_tx, inbound_message_rx, messaging_event_sender) = + initialize_messaging( + node_identity.clone(), + peer_manager.clone(), + connection_manager_requester.clone(), + shutdown.to_signal(), + ); + messaging_event_tx = Some(messaging_event_sender); + protocols.add(&[messaging::MESSAGING_PROTOCOL.clone()], notifier); + // Spawn messaging protocol + complete_signals.push(messaging.complete_signal()); + task::spawn(messaging.run()); + + // Spawn inbound pipeline + let bounded_executor = + BoundedExecutor::new(runtime::current(), messaging_pipeline.max_concurrent_inbound_tasks); + let inbound = pipeline::Inbound::new( + bounded_executor, + inbound_message_rx, + messaging_pipeline.inbound, + shutdown.to_signal(), + ); + task::spawn(inbound.run()); + + // Spawn outbound pipeline + let outbound = + pipeline::Outbound::new(runtime::current(), messaging_pipeline.outbound, messaging_request_tx); + task::spawn(outbound.run()); + } - // Spawn outbound pipeline - let outbound = pipeline::Outbound::new(executor.clone(), messaging_pipeline.outbound, messaging_request_tx); - executor.spawn(outbound.run()); + connection_manager.set_protocols(protocols); + task::spawn(connection_manager.run()); let listening_addr = Self::wait_listening(events_stream).await?; @@ -202,9 +210,9 @@ where listening_addr, node_identity, peer_manager, - messaging_event_tx, + messaging_event_tx: messaging_event_tx.unwrap_or_else(|| broadcast::channel(1).0), hidden_service, - complete_signals: vec![conn_man_shutdown_signal, messaging_signal], + complete_signals, }) } @@ -218,11 +226,6 @@ where Arc::clone(&self.node_identity) } - /// Return a subscription to OMS events. This will emit events sent _after_ this subscription was created. - pub fn subscribe_messaging_events(&self) -> messaging::MessagingEventReceiver { - self.messaging_event_tx.subscribe() - } - /// Return an owned copy of a ConnectionManagerRequester. Used to initiate connections to peers. pub fn connection_manager_requester(&self) -> ConnectionManagerRequester { self.connection_manager_requester.clone() @@ -298,11 +301,6 @@ impl CommsNode { self.messaging_event_tx.subscribe() } - /// Return a clone of the of the messaging event Sender to allow for other services to create subscriptions - pub fn message_event_sender(&self) -> messaging::MessagingEventSender { - self.messaging_event_tx.clone() - } - /// Return an owned copy of a ConnectionManagerRequester. Used to initiate connections to peers. pub fn connection_manager(&self) -> ConnectionManagerRequester { self.connection_manager_requester.clone() @@ -325,3 +323,34 @@ impl CommsNode { CommsShutdown::new(self.complete_signals) } } + +fn initialize_messaging( + node_identity: Arc, + peer_manager: Arc, + connection_manager_requester: ConnectionManagerRequester, + shutdown_signal: ShutdownSignal, +) -> ( + messaging::MessagingProtocol, + ProtocolNotifier, + mpsc::Sender, + mpsc::Receiver, + messaging::MessagingEventSender, +) +{ + let (proto_tx, proto_rx) = mpsc::channel(consts::MESSAGING_PROTOCOL_EVENTS_BUFFER_SIZE); + let (messaging_request_tx, messaging_request_rx) = mpsc::channel(consts::MESSAGING_REQUEST_BUFFER_SIZE); + let (inbound_message_tx, inbound_message_rx) = mpsc::channel(consts::INBOUND_MESSAGE_BUFFER_SIZE); + let (event_tx, _) = broadcast::channel(consts::MESSAGING_EVENTS_BUFFER_SIZE); + let messaging = MessagingProtocol::new( + connection_manager_requester, + peer_manager, + node_identity, + proto_rx, + messaging_request_rx, + event_tx.clone(), + inbound_message_tx, + shutdown_signal, + ); + + (messaging, proto_tx, messaging_request_tx, inbound_message_rx, event_tx) +} diff --git a/comms/src/builder/error.rs b/comms/src/builder/error.rs index cfad5cc51b..deb14842e1 100644 --- a/comms/src/builder/error.rs +++ b/comms/src/builder/error.rs @@ -33,11 +33,6 @@ pub enum CommsBuilderError { NodeIdentityNotSet, #[error("The PeerStorage was not provided to the CommsBuilder. Use `with_peer_storage` to set it.")] PeerStorageNotProvided, - #[error( - "The messaging pipeline was not provided to the CommsBuilder. Use `with_messaging_pipeline` to set it's \ - pipeline." - )] - MessagingPiplineNotProvided, #[error("Unable to receive a ConnectionManagerEvent within timeout")] ConnectionManagerEventStreamTimeout, #[error("ConnectionManagerEvent stream unexpectedly closed")] diff --git a/comms/src/builder/mod.rs b/comms/src/builder/mod.rs index fa4250d62f..a2027bfe83 100644 --- a/comms/src/builder/mod.rs +++ b/comms/src/builder/mod.rs @@ -57,12 +57,11 @@ use crate::{ ConnectivityRequest, ConnectivityRequester, }, - message::InboundMessage, multiaddr::Multiaddr, multiplexing::Substream, noise::NoiseConfig, peer_manager::{NodeIdentity, PeerManager}, - protocol::{messaging, messaging::MessagingProtocol, ProtocolNotification, Protocols}, + protocol::Protocols, tor, transports::{SocksTransport, TcpTransport, Transport}, types::CommsDatabase, @@ -249,37 +248,6 @@ where self } - fn make_messaging( - &self, - conn_man_requester: ConnectionManagerRequester, - peer_manager: Arc, - node_identity: Arc, - ) -> ( - messaging::MessagingProtocol, - mpsc::Sender>, - mpsc::Sender, - mpsc::Receiver, - messaging::MessagingEventSender, - ) - { - let (proto_tx, proto_rx) = mpsc::channel(consts::MESSAGING_PROTOCOL_EVENTS_BUFFER_SIZE); - let (messaging_request_tx, messaging_request_rx) = mpsc::channel(consts::MESSAGING_REQUEST_BUFFER_SIZE); - let (inbound_message_tx, inbound_message_rx) = mpsc::channel(consts::INBOUND_MESSAGE_BUFFER_SIZE); - let (event_tx, _) = broadcast::channel(consts::MESSAGING_EVENTS_BUFFER_SIZE); - let messaging = MessagingProtocol::new( - conn_man_requester, - peer_manager, - node_identity, - proto_rx, - messaging_request_rx, - event_tx.clone(), - inbound_message_tx, - self.shutdown.to_signal(), - ); - - (messaging, proto_tx, messaging_request_tx, inbound_message_rx, event_tx) - } - fn make_peer_manager(&mut self) -> Result, CommsBuilderError> { match self.peer_storage.take() { Some(storage) => { @@ -294,7 +262,6 @@ where &mut self, node_identity: Arc, peer_manager: Arc, - protocols: Protocols, request_rx: mpsc::Receiver, connection_manager_events_tx: broadcast::Sender>, ) -> ConnectionManager @@ -311,7 +278,6 @@ where request_rx, node_identity, peer_manager, - protocols, connection_manager_events_tx, self.shutdown.to_signal(), ) @@ -352,20 +318,8 @@ where let connection_manager_requester = ConnectionManagerRequester::new(conn_man_tx, connection_manager_event_tx.clone()); - let (messaging, messaging_proto_tx, messaging_request_tx, inbound_message_rx, messaging_event_tx) = self - .make_messaging( - connection_manager_requester.clone(), - peer_manager.clone(), - node_identity.clone(), - ); - //---------------------------------- Protocols --------------------------------------------// - let protocols = self - .protocols - .take() - .or_else(|| Some(Protocols::new())) - .map(move |protocols| protocols.add(&[messaging::MESSAGING_PROTOCOL.clone()], messaging_proto_tx)) - .expect("cannot fail"); + let protocols = self.protocols.take().unwrap_or_default(); //---------------------------------- ConnectivityManager --------------------------------------------// @@ -383,7 +337,6 @@ where let connection_manager = self.make_connection_manager( node_identity.clone(), peer_manager.clone(), - protocols, conn_man_rx, connection_manager_event_tx.clone(), ); @@ -394,13 +347,10 @@ where connection_manager_event_tx, connectivity_manager, connectivity_requester, - messaging_request_tx, messaging_pipeline: None, - messaging, - messaging_event_tx, - inbound_message_rx, node_identity, peer_manager, + protocols, hidden_service: self.hidden_service, shutdown: self.shutdown, }) diff --git a/comms/src/builder/tests.rs b/comms/src/builder/tests.rs index 0329ea2455..274611bd10 100644 --- a/comms/src/builder/tests.rs +++ b/comms/src/builder/tests.rs @@ -58,7 +58,7 @@ async fn spawn_node( let comms_node = CommsBuilder::new() // These calls are just to get rid of unused function warnings. // - .with_executor(runtime::current_executor()) + .with_executor(runtime::current()) .with_dial_backoff(ConstantBackoff::new(Duration::from_millis(500))) .on_shutdown(|| {}) // @@ -103,12 +103,14 @@ async fn peer_to_peer_custom_protocols() { // Setup test protocols let (test_sender, _test_protocol_rx1) = mpsc::channel(10); let (another_test_sender, mut another_test_protocol_rx1) = mpsc::channel(10); - let protocols1 = Protocols::new() + let mut protocols1 = Protocols::new(); + protocols1 .add(&[TEST_PROTOCOL], test_sender) .add(&[ANOTHER_TEST_PROTOCOL], another_test_sender); let (test_sender, mut test_protocol_rx2) = mpsc::channel(10); let (another_test_sender, _another_test_protocol_rx2) = mpsc::channel(10); - let protocols2 = Protocols::new() + let mut protocols2 = Protocols::new(); + protocols2 .add(&[TEST_PROTOCOL], test_sender) .add(&[ANOTHER_TEST_PROTOCOL], another_test_sender); @@ -288,7 +290,7 @@ async fn peer_to_peer_messaging_simultaneous() { .unwrap(); // Simultaneously send messages between the two nodes - let rt_handle = runtime::current_executor(); + let rt_handle = runtime::current(); let handle1 = rt_handle.spawn(async move { for i in 0..NUM_MSGS { let outbound_msg = OutboundMessage::new( diff --git a/comms/src/connection_manager/dialer.rs b/comms/src/connection_manager/dialer.rs index 9994c3caf6..85d9005027 100644 --- a/comms/src/connection_manager/dialer.rs +++ b/comms/src/connection_manager/dialer.rs @@ -84,7 +84,7 @@ pub struct Dialer { conn_man_notifier: mpsc::Sender, shutdown: Option, pending_dial_requests: HashMap>>>, - supported_protocols: Vec, + our_supported_protocols: Vec, } impl Dialer @@ -103,7 +103,6 @@ where backoff: TBackoff, request_rx: mpsc::Receiver, conn_man_notifier: mpsc::Sender, - supported_protocols: Vec, shutdown: ShutdownSignal, ) -> Self { @@ -119,10 +118,16 @@ where conn_man_notifier, shutdown: Some(shutdown), pending_dial_requests: Default::default(), - supported_protocols, + our_supported_protocols: Vec::new(), } } + /// Set the supported protocols of this node to send to peers during the peer identity exchange + pub fn set_supported_protocols(&mut self, our_supported_protocols: Vec) -> &mut Self { + self.our_supported_protocols = our_supported_protocols; + self + } + pub async fn run(mut self) { let mut pending_dials = FuturesUnordered::new(); let mut shutdown = self @@ -274,7 +279,7 @@ where let node_identity = Arc::clone(&self.node_identity); let peer_manager = self.peer_manager.clone(); let conn_man_notifier = self.conn_man_notifier.clone(); - let supported_protocols = self.supported_protocols.clone(); + let supported_protocols = self.our_supported_protocols.clone(); let noise_config = self.noise_config.clone(); let allow_test_addresses = self.config.allow_test_addresses; diff --git a/comms/src/connection_manager/listener.rs b/comms/src/connection_manager/listener.rs index 405ad1934d..cb8000ef00 100644 --- a/comms/src/connection_manager/listener.rs +++ b/comms/src/connection_manager/listener.rs @@ -85,7 +85,6 @@ where conn_man_notifier: mpsc::Sender, peer_manager: Arc, node_identity: Arc, - supported_protocols: Vec, shutdown_signal: ShutdownSignal, ) -> Self { @@ -97,13 +96,19 @@ where node_identity, shutdown_signal, listening_address: None, - our_supported_protocols: supported_protocols, + our_supported_protocols: Vec::new(), bounded_executor: BoundedExecutor::from_current(config.max_simultaneous_inbound_connects), liveness_session_count: Arc::new(AtomicUsize::new(config.liveness_max_sessions)), config, } } + /// Set the supported protocols of this node to send to peers during the peer identity exchange + pub fn set_supported_protocols(&mut self, our_supported_protocols: Vec) -> &mut Self { + self.our_supported_protocols = our_supported_protocols; + self + } + pub async fn run(mut self) { let mut shutdown_signal = self.shutdown_signal.clone(); @@ -182,7 +187,7 @@ where permit.fetch_sub(1, Ordering::SeqCst); let liveness = LivenessSession::new(socket); debug!(target: LOG_TARGET, "Started liveness session"); - runtime::current_executor().spawn(async move { + runtime::current().spawn(async move { future::select(liveness.run(), shutdown_signal).await; permit.fetch_add(1, Ordering::SeqCst); }); diff --git a/comms/src/connection_manager/liveness.rs b/comms/src/connection_manager/liveness.rs index 48b236eb13..350e9c5f85 100644 --- a/comms/src/connection_manager/liveness.rs +++ b/comms/src/connection_manager/liveness.rs @@ -57,7 +57,7 @@ mod test { async fn echos() { let (inbound, outbound) = MemorySocket::new_pair(); let liveness = LivenessSession::new(inbound); - let join_handle = runtime::current_executor().spawn(liveness.run()); + let join_handle = runtime::current().spawn(liveness.run()); let mut outbound = Framed::new(IoCompat::new(outbound), LinesCodec::new()); for _ in 0..10usize { outbound.send("ECHO".to_string()).await.unwrap() diff --git a/comms/src/connection_manager/manager.rs b/comms/src/connection_manager/manager.rs index 4d94d9e3d0..e0d6f7baf7 100644 --- a/comms/src/connection_manager/manager.rs +++ b/comms/src/connection_manager/manager.rs @@ -180,7 +180,6 @@ where request_rx: mpsc::Receiver, node_identity: Arc, peer_manager: Arc, - protocols: Protocols, connection_manager_events_tx: broadcast::Sender>, shutdown_signal: ShutdownSignal, ) -> Self @@ -189,8 +188,6 @@ where let (dialer_tx, dialer_rx) = mpsc::channel(DIALER_REQUEST_CHANNEL_SIZE); - let supported_protocols = protocols.get_supported_protocols(); - let listener = PeerListener::new( config.clone(), transport.clone(), @@ -198,7 +195,6 @@ where internal_event_tx.clone(), peer_manager.clone(), Arc::clone(&node_identity), - supported_protocols.clone(), shutdown_signal.clone(), ); @@ -211,7 +207,6 @@ where backoff, dialer_rx, internal_event_tx, - supported_protocols, shutdown_signal.clone(), ); @@ -221,7 +216,7 @@ where request_rx: request_rx.fuse(), node_identity, peer_manager, - protocols, + protocols: Protocols::new(), internal_event_rx: internal_event_rx.fuse(), dialer_tx, dialer: Some(dialer), @@ -234,6 +229,11 @@ where } } + pub fn set_protocols(&mut self, protocols: Protocols) -> &mut Self { + self.protocols = protocols; + self + } + pub fn complete_signal(&self) -> ShutdownSignal { self.complete_trigger.to_signal() } @@ -295,21 +295,23 @@ where } fn run_listener(&mut self) { - let listener = self + let mut listener = self .listener .take() .expect("ConnectionManager initialized without a listener"); - runtime::current_executor().spawn(listener.run()); + listener.set_supported_protocols(self.protocols.get_supported_protocols()); + runtime::current().spawn(listener.run()); } fn run_dialer(&mut self) { - let dialer = self + let mut dialer = self .dialer .take() .expect("ConnectionManager initialized without a dialer"); - runtime::current_executor().spawn(dialer.run()); + dialer.set_supported_protocols(self.protocols.get_supported_protocols()); + runtime::current().spawn(dialer.run()); } async fn handle_request(&mut self, request: ConnectionManagerRequest) { @@ -542,7 +544,7 @@ where linger.as_millis() ); - runtime::current_executor().spawn(async move { + runtime::current().spawn(async move { debug!( target: LOG_TARGET, "Waiting for linger period ({}ms) to expire...", diff --git a/comms/src/connection_manager/peer_connection.rs b/comms/src/connection_manager/peer_connection.rs index 578cb064fe..edb8e28989 100644 --- a/comms/src/connection_manager/peer_connection.rs +++ b/comms/src/connection_manager/peer_connection.rs @@ -88,7 +88,7 @@ pub fn create( event_notifier, our_supported_protocols, ); - runtime::current_executor().spawn(peer_actor.run()); + runtime::current().spawn(peer_actor.run()); Ok(peer_conn) } diff --git a/comms/src/connection_manager/tests/listener_dialer.rs b/comms/src/connection_manager/tests/listener_dialer.rs index e78cca73aa..9376f25571 100644 --- a/comms/src/connection_manager/tests/listener_dialer.rs +++ b/comms/src/connection_manager/tests/listener_dialer.rs @@ -66,7 +66,6 @@ async fn listen() -> Result<(), Box> { event_tx.clone(), peer_manager.into(), node_identity, - vec![], shutdown.to_signal(), ); @@ -98,7 +97,7 @@ async fn smoke() { let expected_proto = ProtocolId::from_static(b"/tari/test-proto"); let supported_protocols = vec![expected_proto.clone()]; let peer_manager1 = build_peer_manager(); - let listener = PeerListener::new( + let mut listener = PeerListener::new( ConnectionManagerConfig { listener_address: "/memory/0".parse().unwrap(), ..Default::default() @@ -108,9 +107,9 @@ async fn smoke() { event_tx.clone(), peer_manager1.clone().into(), node_identity1.clone(), - supported_protocols.clone(), shutdown.to_signal(), ); + listener.set_supported_protocols(supported_protocols.clone()); let listener_fut = rt_handle.spawn(listener.run()); @@ -118,7 +117,7 @@ async fn smoke() { let noise_config2 = NoiseConfig::new(node_identity2.clone()); let (mut request_tx, request_rx) = mpsc::channel(1); let peer_manager2 = build_peer_manager(); - let dialer = Dialer::new( + let mut dialer = Dialer::new( ConnectionManagerConfig::default(), node_identity2.clone(), peer_manager2.clone().into(), @@ -127,9 +126,9 @@ async fn smoke() { ConstantBackoff::new(Duration::from_millis(100)), request_rx, event_tx, - supported_protocols, shutdown.to_signal(), ); + dialer.set_supported_protocols(supported_protocols.clone()); let dialer_fut = rt_handle.spawn(dialer.run()); @@ -200,7 +199,7 @@ async fn banned() { let expected_proto = ProtocolId::from_static(b"/tari/test-proto"); let supported_protocols = vec![expected_proto.clone()]; let peer_manager1 = build_peer_manager(); - let listener = PeerListener::new( + let mut listener = PeerListener::new( ConnectionManagerConfig { listener_address: "/memory/0".parse().unwrap(), ..Default::default() @@ -210,9 +209,9 @@ async fn banned() { event_tx.clone(), peer_manager1.clone().into(), node_identity1.clone(), - supported_protocols.clone(), shutdown.to_signal(), ); + listener.set_supported_protocols(supported_protocols.clone()); let listener_fut = rt_handle.spawn(listener.run()); @@ -225,7 +224,7 @@ async fn banned() { let noise_config2 = NoiseConfig::new(node_identity2.clone()); let (mut request_tx, request_rx) = mpsc::channel(1); let peer_manager2 = build_peer_manager(); - let dialer = Dialer::new( + let mut dialer = Dialer::new( ConnectionManagerConfig::default(), node_identity2.clone(), peer_manager2.clone().into(), @@ -234,9 +233,9 @@ async fn banned() { ConstantBackoff::new(Duration::from_millis(100)), request_rx, event_tx, - supported_protocols, shutdown.to_signal(), ); + dialer.set_supported_protocols(supported_protocols); let dialer_fut = rt_handle.spawn(dialer.run()); diff --git a/comms/src/connection_manager/tests/manager.rs b/comms/src/connection_manager/tests/manager.rs index 707072c4f7..a715d5266c 100644 --- a/comms/src/connection_manager/tests/manager.rs +++ b/comms/src/connection_manager/tests/manager.rs @@ -65,7 +65,6 @@ async fn connect_to_nonexistent_peer() { request_rx, node_identity, peer_manager.into(), - Protocols::new(), event_tx, shutdown.to_signal(), ); @@ -98,26 +97,31 @@ async fn dial_success() { // Setup connection manager 1 let peer_manager1 = build_peer_manager(); + + let mut protocols = Protocols::new(); + protocols.add([TEST_PROTO], proto_tx1); let mut conn_man1 = build_connection_manager( TestNodeConfig { node_identity: node_identity1.clone(), ..Default::default() }, peer_manager1.clone(), - Protocols::new().add([TEST_PROTO], proto_tx1), + protocols, shutdown.to_signal(), ); conn_man1.wait_until_listening().await.unwrap(); let peer_manager2 = build_peer_manager(); + let mut protocols = Protocols::new(); + protocols.add([TEST_PROTO], proto_tx2); let mut conn_man2 = build_connection_manager( TestNodeConfig { node_identity: node_identity2.clone(), ..Default::default() }, peer_manager2.clone(), - Protocols::new().add([TEST_PROTO], proto_tx2), + protocols, shutdown.to_signal(), ); let mut subscription2 = conn_man2.get_event_subscription(); diff --git a/comms/src/multiplexing/yamux.rs b/comms/src/multiplexing/yamux.rs index 8f4ff6de65..5c3facd5ce 100644 --- a/comms/src/multiplexing/yamux.rs +++ b/comms/src/multiplexing/yamux.rs @@ -99,7 +99,7 @@ impl Yamux { let (incoming_tx, incoming_rx) = mpsc::channel(10); let stream = yamux::into_stream(connection).boxed(); let incoming = IncomingWorker::new(stream, incoming_tx, shutdown.to_signal()); - runtime::current_executor().spawn(incoming.run()); + runtime::current().spawn(incoming.run()); IncomingSubstreams::new(incoming_rx, counter, shutdown) } diff --git a/comms/src/protocol/mod.rs b/comms/src/protocol/mod.rs index 29902bb072..3a0f0a1fb7 100644 --- a/comms/src/protocol/mod.rs +++ b/comms/src/protocol/mod.rs @@ -30,7 +30,7 @@ mod negotiation; pub use negotiation::ProtocolNegotiation; mod protocols; -pub use protocols::{ProtocolEvent, ProtocolNotification, Protocols}; +pub use protocols::{ProtocolEvent, ProtocolNotification, ProtocolNotifier, Protocols}; pub mod messaging; diff --git a/comms/src/protocol/protocols.rs b/comms/src/protocol/protocols.rs index 0c2e9fd8fc..0ccbd8147c 100644 --- a/comms/src/protocol/protocols.rs +++ b/comms/src/protocol/protocols.rs @@ -27,6 +27,8 @@ use crate::{ use futures::{channel::mpsc, SinkExt}; use std::collections::HashMap; +pub type ProtocolNotifier = mpsc::Sender>; + #[derive(Debug, Clone)] pub enum ProtocolEvent { NewInboundSubstream(Box, TSubstream), @@ -45,7 +47,7 @@ impl ProtocolNotification { } pub struct Protocols { - protocols: HashMap>>, + protocols: HashMap>, } impl Clone for Protocols { @@ -69,12 +71,7 @@ impl Protocols { Default::default() } - pub fn add>( - mut self, - protocols: I, - notifier: mpsc::Sender>, - ) -> Self - { + pub fn add>(&mut self, protocols: I, notifier: ProtocolNotifier) -> &mut Self { self.protocols .extend(protocols.as_ref().iter().map(|p| (p.clone(), notifier.clone()))); self @@ -120,7 +117,8 @@ mod test { ProtocolId::from_static(b"/tari/test/1"), ProtocolId::from_static(b"/tari/test/2"), ]; - let protocols = Protocols::<()>::new().add(&protos, tx); + let mut protocols = Protocols::<()>::new(); + protocols.add(&protos, tx); assert!(protocols.get_supported_protocols().iter().all(|p| protos.contains(p))); } @@ -129,7 +127,8 @@ mod test { async fn notify() { let (tx, mut rx) = mpsc::channel(1); let protos = [ProtocolId::from_static(b"/tari/test/1")]; - let mut protocols = Protocols::<()>::new().add(&protos, tx); + let mut protocols = Protocols::<()>::new(); + protocols.add(&protos, tx); protocols .notify( diff --git a/comms/src/runtime.rs b/comms/src/runtime.rs index 1377912666..3d4907dfe7 100644 --- a/comms/src/runtime.rs +++ b/comms/src/runtime.rs @@ -31,6 +31,6 @@ pub use tokio_macros::test_basic; /// Return the current tokio executor. Panics if the tokio runtime is not started. #[inline] -pub fn current_executor() -> runtime::Handle { +pub fn current() -> runtime::Handle { runtime::Handle::current() } diff --git a/comms/src/test_utils/test_node.rs b/comms/src/test_utils/test_node.rs index f8b875c7b0..1ad94555b4 100644 --- a/comms/src/test_utils/test_node.rs +++ b/comms/src/test_utils/test_node.rs @@ -81,7 +81,7 @@ pub fn build_connection_manager( let requester = ConnectionManagerRequester::new(request_tx, event_tx.clone()); - let connection_manager = ConnectionManager::new( + let mut connection_manager = ConnectionManager::new( config.connection_manager_config, config.transport, noise_config, @@ -89,12 +89,12 @@ pub fn build_connection_manager( request_rx, config.node_identity, peer_manager.into(), - protocols, event_tx, shutdown, ); + connection_manager.set_protocols(protocols); - runtime::current_executor().spawn(connection_manager.run()); + runtime::current().spawn(connection_manager.run()); requester } diff --git a/comms/src/tor/control_client/test_server.rs b/comms/src/tor/control_client/test_server.rs index 8e46c9a31b..1feaf9aa1b 100644 --- a/comms/src/tor/control_client/test_server.rs +++ b/comms/src/tor/control_client/test_server.rs @@ -36,7 +36,7 @@ pub async fn spawn() -> (Multiaddr, State, MemorySocket) { let server = TorControlPortTestServer::new(socket_in); let state = server.get_shared_state(); - runtime::current_executor().spawn(server.run()); + runtime::current().spawn(server.run()); (addr, state, socket_out) }