diff --git a/comms/examples/stress/node.rs b/comms/examples/stress/node.rs index d060d18071..84a5f9ffbd 100644 --- a/comms/examples/stress/node.rs +++ b/comms/examples/stress/node.rs @@ -38,6 +38,7 @@ use tari_comms::{ NodeIdentity, Substream, }; +use tari_shutdown::ShutdownSignal; use tari_storage::{ lmdb_store::{LMDBBuilder, LMDBConfig}, LMDBWrapper, @@ -51,6 +52,7 @@ pub async fn create( port: u16, tor_identity: Option<TorIdentity>, is_tcp: bool, + shutdown_signal: ShutdownSignal, ) -> Result< ( CommsNode, @@ -94,6 +96,7 @@ pub async fn create( let builder = CommsBuilder::new() .allow_test_addresses() + .with_shutdown_signal(shutdown_signal) .with_node_identity(node_identity.clone()) .with_dial_backoff(ConstantBackoff::new(Duration::from_secs(0))) .with_peer_storage(peer_database, None) diff --git a/comms/examples/stress/service.rs b/comms/examples/stress/service.rs index 45e2bc0fd3..e9683d045e 100644 --- a/comms/examples/stress/service.rs +++ b/comms/examples/stress/service.rs @@ -41,6 +41,7 @@ use tari_comms::{ Substream, }; use tari_crypto::tari_utilities::hex::Hex; +use tari_shutdown::Shutdown; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, sync::{mpsc, oneshot, RwLock}, @@ -54,6 +55,7 @@ pub fn start_service( protocol_notif: mpsc::Receiver<ProtocolNotification<Substream>>, inbound_rx: mpsc::Receiver<InboundMessage>, outbound_tx: mpsc::Sender<OutboundMessage>, + shutdown: Shutdown, ) -> (JoinHandle<Result<(), Error>>, mpsc::Sender<StressTestServiceRequest>) { let node_identity = comms_node.node_identity(); let (request_tx, request_rx) = mpsc::channel(1); @@ -65,7 +67,14 @@ pub fn start_service( comms_node.listening_address(), ); - let service = StressTestService::new(request_rx, comms_node, protocol_notif, inbound_rx, outbound_tx); + let service = StressTestService::new( + request_rx, + comms_node, + protocol_notif, + inbound_rx, + outbound_tx, + shutdown, + ); (task::spawn(service.start()), request_tx) } @@ -138,10 +147,11 @@ struct StressTestService { request_rx: mpsc::Receiver<StressTestServiceRequest>, comms_node: CommsNode, protocol_notif: mpsc::Receiver<ProtocolNotification<Substream>>, - shutdown: bool, inbound_rx: Arc<RwLock<mpsc::Receiver<InboundMessage>>>, outbound_tx: mpsc::Sender<OutboundMessage>, + + shutdown: Shutdown, } impl StressTestService { @@ -151,12 +161,13 @@ impl StressTestService { protocol_notif: mpsc::Receiver<ProtocolNotification<Substream>>, inbound_rx: mpsc::Receiver<InboundMessage>, outbound_tx: mpsc::Sender<OutboundMessage>, + shutdown: Shutdown, ) -> Self { Self { request_rx, comms_node, protocol_notif, - shutdown: false, + shutdown, inbound_rx: Arc::new(RwLock::new(inbound_rx)), outbound_tx, } @@ -164,6 +175,7 @@ impl StressTestService { async fn start(mut self) -> Result<(), Error> { let mut events = self.comms_node.subscribe_connectivity_events(); + let mut shutdown_signal = self.shutdown.to_signal(); loop { tokio::select! { @@ -180,10 +192,9 @@ impl StressTestService { Some(notif) = self.protocol_notif.recv() => { self.handle_protocol_notification(notif).await; }, - } - - if self.shutdown { - break; + _ = shutdown_signal.wait() => { + break; + } } } @@ -197,7 +208,7 @@ impl StressTestService { match request { BeginProtocol(peer, protocol, reply) => self.begin_protocol(peer, protocol, reply).await?, Shutdown => { - self.shutdown = true; + self.shutdown.trigger(); }, } diff --git a/comms/examples/stress_test.rs b/comms/examples/stress_test.rs index 3a0c04f020..7e15ed879c 100644 --- a/comms/examples/stress_test.rs +++ b/comms/examples/stress_test.rs @@ -27,6 +27,7 @@ use crate::stress::{node, prompt::parse_from_short_str, service, service::Stress use futures::{future, future::Either}; use std::{env, net::Ipv4Addr, path::Path, process, sync::Arc, time::Duration}; use tari_crypto::tari_utilities::message_format::MessageFormat; +use tari_shutdown::Shutdown; use tempfile::Builder; use tokio::{sync::oneshot, time}; @@ -85,10 +86,19 @@ async fn run() -> Result<(), Error> { let tor_identity = tor_identity_path.as_ref().and_then(load_json); let node_identity = node_identity_path.as_ref().and_then(load_json).map(Arc::new); + let shutdown = Shutdown::new(); let temp_dir = Builder::new().prefix("stress-test").tempdir().unwrap(); - let (comms_node, protocol_notif, inbound_rx, outbound_tx) = - node::create(node_identity, temp_dir.as_ref(), public_ip, port, tor_identity, is_tcp).await?; + let (comms_node, protocol_notif, inbound_rx, outbound_tx) = node::create( + node_identity, + temp_dir.as_ref(), + public_ip, + port, + tor_identity, + is_tcp, + shutdown.to_signal(), + ) + .await?; if let Some(node_identity_path) = node_identity_path.as_ref() { save_json(comms_node.node_identity_ref(), node_identity_path)?; } @@ -99,7 +109,7 @@ async fn run() -> Result<(), Error> { } println!("Stress test service started!"); - let (handle, requester) = service::start_service(comms_node, protocol_notif, inbound_rx, outbound_tx); + let (handle, requester) = service::start_service(comms_node, protocol_notif, inbound_rx, outbound_tx, shutdown); let mut last_peer = peer.as_ref().and_then(parse_from_short_str); diff --git a/comms/src/builder/comms_node.rs b/comms/src/builder/comms_node.rs index abd71e8952..f7a47b16f0 100644 --- a/comms/src/builder/comms_node.rs +++ b/comms/src/builder/comms_node.rs @@ -190,7 +190,7 @@ impl UnspawnedCommsNode { connection_manager.add_protocols(protocols); //---------------------------------- Spawn Actors --------------------------------------------// - connectivity_manager.create().spawn(); + connectivity_manager.spawn(); connection_manager.spawn(); info!(target: LOG_TARGET, "Hello from comms!"); diff --git a/comms/src/connection_manager/tests/manager.rs b/comms/src/connection_manager/tests/manager.rs index 910280b4cf..306aa9a63b 100644 --- a/comms/src/connection_manager/tests/manager.rs +++ b/comms/src/connection_manager/tests/manager.rs @@ -391,6 +391,8 @@ async fn dial_cancelled() { ..Default::default() }; config.connection_manager_config.network_info.user_agent = "node1".to_string(); + // To ensure that dial takes a long time so that we can test cancelling it + config.connection_manager_config.max_dial_attempts = 100; config }, MemoryTransport, diff --git a/comms/src/connectivity/config.rs b/comms/src/connectivity/config.rs index 743c8dd741..fa6919f626 100644 --- a/comms/src/connectivity/config.rs +++ b/comms/src/connectivity/config.rs @@ -29,12 +29,12 @@ pub struct ConnectivityConfig { /// Default: 30% pub min_connectivity: f32, /// Interval to check the connection pool, including reaping inactive connections and retrying failed managed peer - /// connections. Default: 30s + /// connections. Default: 60s pub connection_pool_refresh_interval: Duration, /// True if connection reaping is enabled, otherwise false (default: true) pub is_connection_reaping_enabled: bool, /// The minimum age of the connection before it can be reaped. This prevents a connection that has just been - /// established from being reaped due to inactivity. + /// established from being reaped due to inactivity. Default: 20 minutes pub reaper_min_inactive_age: Duration, /// The number of connection failures before a peer is considered offline /// Default: 1 @@ -48,8 +48,8 @@ impl Default for ConnectivityConfig { fn default() -> Self { Self { min_connectivity: 0.3, - connection_pool_refresh_interval: Duration::from_secs(30), - reaper_min_inactive_age: Duration::from_secs(60), + connection_pool_refresh_interval: Duration::from_secs(60), + reaper_min_inactive_age: Duration::from_secs(20 * 60), is_connection_reaping_enabled: true, max_failures_mark_offline: 2, connection_tie_break_linger: Duration::from_secs(2), diff --git a/comms/src/connectivity/manager.rs b/comms/src/connectivity/manager.rs index e172b53c2d..1f892fa71c 100644 --- a/comms/src/connectivity/manager.rs +++ b/comms/src/connectivity/manager.rs @@ -80,7 +80,7 @@ pub struct ConnectivityManager { } impl ConnectivityManager { - pub fn create(self) -> ConnectivityManagerActor { + pub fn spawn(self) -> JoinHandle<()> { ConnectivityManagerActor { config: self.config, status: ConnectivityStatus::Initializing, @@ -90,12 +90,11 @@ impl ConnectivityManager { event_tx: self.event_tx, connection_stats: HashMap::new(), node_identity: self.node_identity, - managed_peers: Vec::new(), - - shutdown_signal: Some(self.shutdown_signal), pool: ConnectionPool::new(), + shutdown_signal: self.shutdown_signal, } + .spawn() } } @@ -137,19 +136,18 @@ impl fmt::Display for ConnectivityStatus { } } -pub struct ConnectivityManagerActor { +struct ConnectivityManagerActor { config: ConnectivityConfig, status: ConnectivityStatus, request_rx: mpsc::Receiver<ConnectivityRequest>, connection_manager: ConnectionManagerRequester, node_identity: Arc<NodeIdentity>, - shutdown_signal: Option<ShutdownSignal>, peer_manager: Arc<PeerManager>, event_tx: ConnectivityEventTx, connection_stats: HashMap<NodeId, PeerConnectionStats>, - managed_peers: Vec<NodeId>, pool: ConnectionPool, + shutdown_signal: ShutdownSignal, } impl ConnectivityManagerActor { @@ -160,10 +158,6 @@ impl ConnectivityManagerActor { #[tracing::instrument(name = "connectivity_manager_actor::run", skip(self))] pub async fn run(mut self) { info!(target: LOG_TARGET, "ConnectivityManager started"); - let mut shutdown_signal = self - .shutdown_signal - .take() - .expect("ConnectivityManager initialized without a shutdown_signal"); let mut connection_manager_events = self.connection_manager.get_event_subscription(); @@ -199,7 +193,7 @@ impl ConnectivityManagerActor { } }, - _ = &mut shutdown_signal => { + _ = self.shutdown_signal.wait() => { info!(target: LOG_TARGET, "ConnectivityManager is shutting down because it received the shutdown signal"); self.disconnect_all().await; break; diff --git a/comms/src/connectivity/test.rs b/comms/src/connectivity/test.rs index 948d083e94..552d3aee40 100644 --- a/comms/src/connectivity/test.rs +++ b/comms/src/connectivity/test.rs @@ -76,7 +76,6 @@ fn setup_connectivity_manager( peer_manager: peer_manager.clone(), shutdown_signal: shutdown.to_signal(), } - .create() .spawn(); ( diff --git a/comms/src/multiplexing/yamux.rs b/comms/src/multiplexing/yamux.rs index 17558133f2..1723033739 100644 --- a/comms/src/multiplexing/yamux.rs +++ b/comms/src/multiplexing/yamux.rs @@ -136,10 +136,12 @@ impl Control { /// Open a new stream to the remote. pub async fn open_stream(&mut self) -> Result<Substream, ConnectionError> { + // Ensure that this counts as used while the substream is being opened + let counter_guard = self.substream_counter.new_guard(); let stream = self.inner.open_stream().await?; Ok(Substream { stream: stream.compat(), - counter_guard: self.substream_counter.new_guard(), + counter_guard, }) }