diff --git a/base_layer/contacts/tests/contacts_service.rs b/base_layer/contacts/tests/contacts_service.rs index 73acc051ae..f960ca9bd7 100644 --- a/base_layer/contacts/tests/contacts_service.rs +++ b/base_layer/contacts/tests/contacts_service.rs @@ -88,6 +88,7 @@ pub fn setup_contacts_service( auto_request: true, ..Default::default() }, + excluded_dial_addresses: vec![], ..Default::default() }, allow_test_addresses: true, diff --git a/base_layer/p2p/src/initialization.rs b/base_layer/p2p/src/initialization.rs index 843f629d2d..7579d5954e 100644 --- a/base_layer/p2p/src/initialization.rs +++ b/base_layer/p2p/src/initialization.rs @@ -331,7 +331,8 @@ async fn configure_comms_and_dht( .with_listener_liveness_max_sessions(config.listener_liveness_max_sessions) .with_listener_liveness_allowlist_cidrs(listener_liveness_allowlist_cidrs) .with_dial_backoff(ConstantBackoff::new(Duration::from_millis(500))) - .with_peer_storage(peer_database, Some(file_lock)); + .with_peer_storage(peer_database, Some(file_lock)) + .with_excluded_dial_addresses(config.dht.excluded_dial_addresses.clone()); let mut comms = match config.auxiliary_tcp_listener_address { Some(ref addr) => builder.with_auxiliary_tcp_listener_address(addr.clone()).build()?, diff --git a/base_layer/wallet_ffi/src/lib.rs b/base_layer/wallet_ffi/src/lib.rs index 421d47486c..895163c93c 100644 --- a/base_layer/wallet_ffi/src/lib.rs +++ b/base_layer/wallet_ffi/src/lib.rs @@ -5327,6 +5327,7 @@ pub unsafe extern "C" fn comms_config_create( minimum_desired_tcpv4_node_ratio: 0.0, ..Default::default() }, + excluded_dial_addresses: vec![], ..Default::default() }, allow_test_addresses: true, diff --git a/common/config/presets/c_base_node_c.toml b/common/config/presets/c_base_node_c.toml index 36d9c6b5fc..4c7b147950 100644 --- a/common/config/presets/c_base_node_c.toml +++ b/common/config/presets/c_base_node_c.toml @@ -316,3 +316,5 @@ database_url = "data/base_node/dht.db" # In a situation where a node is not well-connected and many nodes are locally marked as offline, we can retry # peers that were previously tried. Default: 2 hours #offline_peer_cooldown = 7_200 # 2 * 60 * 60 +# Addresses that should never be dialed (default value = []) +#excluded_dial_addresses = ["/ip4/x.x.x.x/tcp/xxxx", "/ip4/x.y.x.y/tcp/xyxy"] diff --git a/common/config/presets/d_console_wallet.toml b/common/config/presets/d_console_wallet.toml index a8d3225ba0..93b5a8f920 100644 --- a/common/config/presets/d_console_wallet.toml +++ b/common/config/presets/d_console_wallet.toml @@ -168,7 +168,7 @@ event_channel_size = 3500 # peers can find you. # _NOTE_: If using the `tor` transport type, public_address will be ignored and an onion address will be # automatically configured -#public_addresses = ["/ip4/172.2.3.4/tcp/18189",] +#public_addresses = ["/ip4/172.2.3.4/tcp/18188",] # Optionally bind an additional TCP socket for inbound Tari P2P protocol commms. # Use cases include: @@ -360,3 +360,5 @@ network_discovery.initial_peer_sync_delay = 25 # In a situation where a node is not well-connected and many nodes are locally marked as offline, we can retry # peers that were previously tried. Default: 2 hours #offline_peer_cooldown = 7_200 # 2 * 60 * 60 +# Addresses that should never be dialed (default value = []) +#excluded_dial_addresses = ["/ip4/x.x.x.x/tcp/xxxx", "/ip4/x.y.x.y/tcp/xyxy"] diff --git a/comms/core/src/builder/mod.rs b/comms/core/src/builder/mod.rs index 5cae88e774..727d13cf6f 100644 --- a/comms/core/src/builder/mod.rs +++ b/comms/core/src/builder/mod.rs @@ -242,6 +242,11 @@ impl CommsBuilder { self } + pub fn with_excluded_dial_addresses(mut self, excluded_addresses: Vec) -> Self { + self.connection_manager_config.excluded_dial_addresses = excluded_addresses; + self + } + /// Restrict liveness sessions to certain address ranges (CIDR format). pub fn with_listener_liveness_allowlist_cidrs(mut self, cidrs: Vec) -> Self { self.connection_manager_config.liveness_cidr_allowlist = cidrs; diff --git a/comms/core/src/connection_manager/dialer.rs b/comms/core/src/connection_manager/dialer.rs index 32d399bb7c..245d3e4308 100644 --- a/comms/core/src/connection_manager/dialer.rs +++ b/comms/core/src/connection_manager/dialer.rs @@ -174,6 +174,7 @@ where fn handle_request(&mut self, pending_dials: &mut DialFuturesUnordered, request: DialerRequest) { use DialerRequest::{CancelPendingDial, Dial, NotifyNewInboundConnection}; debug!(target: LOG_TARGET, "Connection dialer got request: {:?}", request); + match request { Dial(peer, reply_tx) => { self.handle_dial_peer_request(pending_dials, peer, reply_tx); @@ -515,7 +516,7 @@ where tokio::select! { _ = delay => { debug!(target: LOG_TARGET, "[Attempt {}] Connecting to peer '{}'", current_state.num_attempts(), current_state.peer().node_id.short_str()); - match Self::dial_peer(current_state, &noise_config, ¤t_transport, config.network_info.network_wire_byte).await { + match Self::dial_peer(current_state, &noise_config, ¤t_transport, config.network_info.network_wire_byte, config.excluded_dial_addresses.clone()).await { (state, Ok((socket, addr))) => { debug!(target: LOG_TARGET, "Dial succeeded for peer '{}' after {} attempt(s)", state.peer().node_id.short_str(), state.num_attempts()); break (state, Ok((socket, addr))); @@ -524,6 +525,8 @@ where (state, Err(ConnectionManagerError::NoiseHandshakeError(e))) => break (state, Err(ConnectionManagerError::NoiseHandshakeError(e))), // Inflight dial was cancelled (state, Err(ConnectionManagerError::DialCancelled)) => break (state, Err(ConnectionManagerError::DialCancelled)), + // All public addresses for this peer are excluded + (state, Err(ConnectionManagerError::AllPeerAddressesAreExcluded(e))) => break (state, Err(ConnectionManagerError::AllPeerAddressesAreExcluded(e))), (state, Err(err)) => { debug!(target: LOG_TARGET, "Failed to dial peer {} | Attempt {} | Error: {}", state.peer().node_id.short_str(), state.num_attempts(), err); if state.num_attempts() >= config.max_dial_attempts { @@ -554,6 +557,7 @@ where noise_config: &NoiseConfig, transport: &TTransport, network_byte: u8, + excluded_dial_addresses: Vec, ) -> ( DialState, Result<(NoiseSocket, Multiaddr), ConnectionManagerError>, @@ -564,10 +568,7 @@ where .clone() .into_vec() .iter() - .filter(|&a| { - a == &"/memory/0".parse::().expect("will not fail") || // Used for tests, allowed - a != &ConnectionManagerConfig::default().listener_address // Not allowed to dial the default - }) + .filter(|&a| !excluded_dial_addresses.iter().any(|excluded| a == excluded)) .cloned() .collect::>(); if addresses.is_empty() { @@ -577,10 +578,17 @@ where "Dial - No more contactable addresses for peer '{}'", node_id_hex ); - return ( - dial_state, - Err(ConnectionManagerError::NoContactableAddressesForPeer(node_id_hex)), - ); + return if dial_state.peer().addresses.is_empty() { + ( + dial_state, + Err(ConnectionManagerError::NoContactableAddressesForPeer(node_id_hex)), + ) + } else { + ( + dial_state, + Err(ConnectionManagerError::AllPeerAddressesAreExcluded(node_id_hex)), + ) + }; } let cancel_signal = dial_state.get_cancel_signal(); for address in addresses { diff --git a/comms/core/src/connection_manager/error.rs b/comms/core/src/connection_manager/error.rs index 7d5b174329..d229aefadd 100644 --- a/comms/core/src/connection_manager/error.rs +++ b/comms/core/src/connection_manager/error.rs @@ -90,6 +90,8 @@ pub enum ConnectionManagerError { PeerValidationError(#[from] PeerValidatorError), #[error("No contactable addresses for peer {0} left")] NoContactableAddressesForPeer(String), + #[error("All peer addresses are excluded for peer {0}")] + AllPeerAddressesAreExcluded(String), #[error("Yamux error: {0}")] YamuxControlError(#[from] YamuxControlError), } diff --git a/comms/core/src/connection_manager/manager.rs b/comms/core/src/connection_manager/manager.rs index ff88f17501..67c28679cd 100644 --- a/comms/core/src/connection_manager/manager.rs +++ b/comms/core/src/connection_manager/manager.rs @@ -133,6 +133,8 @@ pub struct ConnectionManagerConfig { pub auxiliary_tcp_listener_address: Option, /// Peer validation configuration. See [PeerValidatorConfig] pub peer_validation_config: PeerValidatorConfig, + /// Addresses that should never be dialed + pub excluded_dial_addresses: Vec, } impl Default for ConnectionManagerConfig { @@ -154,6 +156,7 @@ impl Default for ConnectionManagerConfig { auxiliary_tcp_listener_address: None, peer_validation_config: PeerValidatorConfig::default(), noise_handshake_recv_timeout: Duration::from_secs(6), + excluded_dial_addresses: vec![], } } } diff --git a/comms/core/src/connection_manager/tests/listener_dialer.rs b/comms/core/src/connection_manager/tests/listener_dialer.rs index 6db482ac33..a1c244b838 100644 --- a/comms/core/src/connection_manager/tests/listener_dialer.rs +++ b/comms/core/src/connection_manager/tests/listener_dialer.rs @@ -256,3 +256,138 @@ async fn banned() { timeout(Duration::from_secs(5), dialer_fut).await.unwrap().unwrap(); } + +#[tokio::test] +async fn excluded_yes() { + let (event_tx, _event_rx) = mpsc::channel(10); + let mut shutdown = Shutdown::new(); + + let node_identity1 = build_node_identity(PeerFeatures::COMMUNICATION_NODE); + let noise_config1 = NoiseConfig::new(node_identity1.clone()); + let expected_proto = ProtocolId::from_static(b"/tari/test-proto"); + let supported_protocols = vec![expected_proto.clone()]; + let peer_manager1 = build_peer_manager(); + let mut listener = PeerListener::new( + Default::default(), + "/memory/0".parse().unwrap(), + MemoryTransport, + noise_config1.clone(), + event_tx.clone(), + peer_manager1.clone(), + node_identity1.clone(), + shutdown.to_signal(), + ); + listener.set_supported_protocols(supported_protocols.clone()); + + // Get the listener address of the peer + let address = listener.listen().await.unwrap(); + + let node_identity2 = build_node_identity(PeerFeatures::COMMUNICATION_NODE); + let noise_config2 = NoiseConfig::new(node_identity2.clone()); + let (request_tx, request_rx) = mpsc::channel(1); + let peer_manager2 = build_peer_manager(); + let connection_manager_config = ConnectionManagerConfig { + excluded_dial_addresses: vec![address.clone()], + ..Default::default() + }; + let mut dialer = Dialer::new( + connection_manager_config, + node_identity2.clone(), + peer_manager2.clone(), + MemoryTransport, + noise_config2.clone(), + ConstantBackoff::new(Duration::from_millis(100)), + request_rx, + event_tx.clone(), + shutdown.to_signal(), + ); + dialer.set_supported_protocols(supported_protocols.clone()); + + let dialer_fut = tokio::spawn(dialer.run()); + + let mut peer = node_identity1.to_peer(); + peer.addresses = MultiaddressesWithStats::from_addresses_with_source(vec![address], &PeerAddressSource::Config); + peer.set_id_for_test(1); + + let (reply_tx, reply_rx) = oneshot::channel(); + request_tx + .send(DialerRequest::Dial(Box::new(peer), Some(reply_tx))) + .await + .unwrap(); + + // Check that the dial failed. We're checking that the dial attempt was never made. + let res = reply_rx.await.unwrap(); + assert_eq!( + format!("{:?}", res), + format!("Err(AllPeerAddressesAreExcluded(\"{}\"))", node_identity1.node_id()) + ); + + shutdown.trigger(); + timeout(Duration::from_secs(5), dialer_fut).await.unwrap().unwrap(); +} + +#[tokio::test] +async fn excluded_no() { + let (event_tx, _event_rx) = mpsc::channel(10); + let mut shutdown = Shutdown::new(); + + let node_identity1 = build_node_identity(PeerFeatures::COMMUNICATION_NODE); + let noise_config1 = NoiseConfig::new(node_identity1.clone()); + let expected_proto = ProtocolId::from_static(b"/tari/test-proto"); + let supported_protocols = vec![expected_proto.clone()]; + let peer_manager1 = build_peer_manager(); + let mut listener = PeerListener::new( + Default::default(), + "/memory/0".parse().unwrap(), + MemoryTransport, + noise_config1.clone(), + event_tx.clone(), + peer_manager1.clone(), + node_identity1.clone(), + shutdown.to_signal(), + ); + listener.set_supported_protocols(supported_protocols.clone()); + + // Get the listener address of the peer + let address = listener.listen().await.unwrap(); + + let node_identity2 = build_node_identity(PeerFeatures::COMMUNICATION_NODE); + let noise_config2 = NoiseConfig::new(node_identity2.clone()); + let (request_tx, request_rx) = mpsc::channel(1); + let peer_manager2 = build_peer_manager(); + let connection_manager_config = ConnectionManagerConfig { + excluded_dial_addresses: vec![], + ..Default::default() + }; + let mut dialer = Dialer::new( + connection_manager_config, + node_identity2.clone(), + peer_manager2.clone(), + MemoryTransport, + noise_config2.clone(), + ConstantBackoff::new(Duration::from_millis(100)), + request_rx, + event_tx.clone(), + shutdown.to_signal(), + ); + dialer.set_supported_protocols(supported_protocols.clone()); + + let dialer_fut = tokio::spawn(dialer.run()); + + let mut peer = node_identity1.to_peer(); + peer.addresses = MultiaddressesWithStats::from_addresses_with_source(vec![address], &PeerAddressSource::Config); + peer.set_id_for_test(1); + + let (reply_tx, reply_rx) = oneshot::channel(); + request_tx + .send(DialerRequest::Dial(Box::new(peer), Some(reply_tx))) + .await + .unwrap(); + + // Check that the dial failed. We're checking that the dial attempt was never made. + let res = reply_rx.await.unwrap(); + assert!(res.is_ok()); + + shutdown.trigger(); + timeout(Duration::from_secs(5), dialer_fut).await.unwrap().unwrap(); +} diff --git a/comms/core/src/connectivity/manager.rs b/comms/core/src/connectivity/manager.rs index 8a59fb5ae8..899a64c0cb 100644 --- a/comms/core/src/connectivity/manager.rs +++ b/comms/core/src/connectivity/manager.rs @@ -659,6 +659,15 @@ impl ConnectivityManagerActor { let (node_id, mut new_status, connection) = match event { PeerDisconnected(_, node_id, minimized) => (node_id, ConnectionStatus::Disconnected(*minimized), None), PeerConnected(conn) => (conn.peer_node_id(), ConnectionStatus::Connected, Some(conn.clone())), + PeerConnectFailed(node_id, ConnectionManagerError::AllPeerAddressesAreExcluded(msg)) => { + debug!( + target: LOG_TARGET, + "Peer '{}' contains only excluded addresses ({})", + node_id, + msg + ); + (node_id, ConnectionStatus::Failed, None) + }, PeerConnectFailed(node_id, ConnectionManagerError::NoiseHandshakeError(msg)) => { if let Some(conn) = self.pool.get_connection(node_id) { warn!( diff --git a/comms/core/src/peer_manager/manager.rs b/comms/core/src/peer_manager/manager.rs index c1e5efc2f0..fe3e001c97 100644 --- a/comms/core/src/peer_manager/manager.rs +++ b/comms/core/src/peer_manager/manager.rs @@ -347,6 +347,17 @@ impl PeerManager { Ok(peer.features) } + pub async fn get_peer_multi_addresses( + &self, + node_id: &NodeId, + ) -> Result { + let peer = self + .find_by_node_id(node_id) + .await? + .ok_or(PeerManagerError::PeerNotFoundError)?; + Ok(peer.addresses) + } + /// This will store metadata inside of the metadata field in the peer provided by the nodeID. /// It will return None if the value was empty and the old value if the value was updated pub async fn set_peer_metadata( diff --git a/comms/dht/src/actor.rs b/comms/dht/src/actor.rs index 6f86bf3df3..57790c0d44 100644 --- a/comms/dht/src/actor.rs +++ b/comms/dht/src/actor.rs @@ -35,6 +35,7 @@ use log::*; use tari_comms::{ connection_manager::ConnectionManagerError, connectivity::{ConnectivityError, ConnectivityRequester, ConnectivitySelection}, + multiaddr::Multiaddr, peer_manager::{NodeId, NodeIdentity, PeerFeatures, PeerManager, PeerManagerError, PeerQuery, PeerQuerySortBy}, types::CommsPublicKey, PeerConnection, @@ -88,6 +89,8 @@ pub enum DhtActorError { ConnectivityError(#[from] ConnectivityError), #[error("Connectivity event stream closed")] ConnectivityEventStreamClosed, + #[error("All peer addresses are excluded")] + AllPeerAddressesAreExcluded, } impl From> for DhtActorError { @@ -381,6 +384,28 @@ impl DhtActor { } } + // Helper function to check if all peer addresses are excluded + async fn check_if_addresses_excluded( + excluded_dial_addresses: Vec, + peer_manager: &PeerManager, + node_id: NodeId, + ) -> Result<(), DhtActorError> { + if !excluded_dial_addresses.is_empty() { + let addresses = peer_manager.get_peer_multi_addresses(&node_id).await?; + if addresses + .iter() + .all(|addr| excluded_dial_addresses.contains(addr.address())) + { + warn!( + target: LOG_TARGET, + "All peer addresses are excluded. Not broadcasting join message." + ); + return Err(DhtActorError::AllPeerAddressesAreExcluded); + } + } + Ok(()) + } + #[allow(clippy::too_many_lines)] fn request_handler(&mut self, request: DhtRequest) -> BoxFuture<'static, Result<(), DhtActorError>> { #[allow(clippy::enum_glob_use)] @@ -388,8 +413,15 @@ impl DhtActor { match request { SendJoin => { let node_identity = Arc::clone(&self.node_identity); + let peer_manager = Arc::clone(&self.peer_manager); let outbound_requester = self.outbound_requester.clone(); - Box::pin(Self::broadcast_join(node_identity, outbound_requester)) + let excluded_dial_addresses = self.config.excluded_dial_addresses.clone(); + Box::pin(Self::broadcast_join( + node_identity, + peer_manager, + excluded_dial_addresses, + outbound_requester, + )) }, MsgHashCacheInsert { message_hash, @@ -465,7 +497,16 @@ impl DhtActor { let connectivity = self.connectivity.clone(); let discovery = self.discovery.clone(); let peer_manager = self.peer_manager.clone(); + let node_identity = self.node_identity.clone(); + let excluded_dial_addresses = self.config.excluded_dial_addresses.clone(); + Box::pin(async move { + DhtActor::check_if_addresses_excluded( + excluded_dial_addresses, + &peer_manager, + node_identity.node_id().clone(), + ) + .await?; let mut task = DiscoveryDialTask::new(connectivity, peer_manager, discovery); let result = task.run(public_key).await; let _result = reply.send(result); @@ -491,8 +532,16 @@ impl DhtActor { async fn broadcast_join( node_identity: Arc, + peer_manager: Arc, + excluded_dial_addresses: Vec, mut outbound_requester: OutboundMessageRequester, ) -> Result<(), DhtActorError> { + DhtActor::check_if_addresses_excluded( + excluded_dial_addresses, + peer_manager.as_ref(), + node_identity.node_id().clone(), + ) + .await?; let message = JoinMessage::from(&node_identity); debug!(target: LOG_TARGET, "Sending Join message to closest peers"); @@ -524,14 +573,14 @@ impl DhtActor { ) -> Result, DhtActorError> { #[allow(clippy::enum_glob_use)] use BroadcastStrategy::*; - match broadcast_strategy { + let peers = match broadcast_strategy { DirectNodeId(node_id) => { // Send to a particular peer matching the given node ID peer_manager .direct_identity_node_id(&node_id) .await .map(|peer| peer.map(|p| vec![p.node_id]).unwrap_or_default()) - .map_err(Into::into) + .map_err(Into::::into)? }, DirectPublicKey(public_key) => { // Send to a particular peer matching the given node ID @@ -539,16 +588,16 @@ impl DhtActor { .direct_identity_public_key(&public_key) .await .map(|peer| peer.map(|p| vec![p.node_id]).unwrap_or_default()) - .map_err(Into::into) + .map_err(Into::::into)? }, Flood(exclude) => { let peers = connectivity .select_connections(ConnectivitySelection::all_nodes(exclude)) .await?; - Ok(peers.into_iter().map(|p| p.peer_node_id().clone()).collect()) + peers.into_iter().map(|p| p.peer_node_id().clone()).collect() }, ClosestNodes(closest_request) => { - Self::select_closest_node_connected(closest_request, config, connectivity, peer_manager).await + Self::select_closest_node_connected(closest_request, config, connectivity, peer_manager.clone()).await? }, DirectOrClosestNodes(closest_request) => { // First check if a direct connection exists @@ -557,20 +606,22 @@ impl DhtActor { .await? .is_some() { - return Ok(vec![closest_request.node_id.clone()]); + vec![closest_request.node_id.clone()] + } else { + Self::select_closest_node_connected(closest_request, config, connectivity, peer_manager.clone()) + .await? } - Self::select_closest_node_connected(closest_request, config, connectivity, peer_manager).await }, Random(n, excluded) => { // Send to a random set of peers of size n that are Communication Nodes - Ok(peer_manager + peer_manager .random_peers(n, &excluded) .await? .into_iter() .map(|p| p.node_id) - .collect()) + .collect() }, - SelectedPeers(peers) => Ok(peers), + SelectedPeers(peers) => peers, Broadcast(exclude) => { let connections = connectivity .select_connections(ConnectivitySelection::random_nodes( @@ -597,7 +648,7 @@ impl DhtActor { candidates.len() ); - Ok(candidates) + candidates }, Propagate(destination, exclude) => { let dest_node_id = destination.to_derived_node_id(); @@ -687,8 +738,34 @@ impl DhtActor { candidates.iter().map(|n| n.short_str()).collect::>().join(", ") ); - Ok(candidates) + candidates }, + }; + if config.excluded_dial_addresses.is_empty() { + return Ok(peers); + }; + + let mut filtered_peers = Vec::with_capacity(peers.len()); + for id in &peers { + let addresses = peer_manager.get_peer_multi_addresses(id).await?; + if addresses + .iter() + .all(|addr| config.excluded_dial_addresses.contains(addr.address())) + { + trace!(target: LOG_TARGET, "Peer '{}' has only excluded addresses. Skipping.", id); + } else { + filtered_peers.push(id.clone()); + } + } + + if filtered_peers.is_empty() { + warn!( + target: LOG_TARGET, + "All selected peers have only excluded addresses. No peers will be selected." + ); + Err(DhtActorError::AllPeerAddressesAreExcluded) + } else { + Ok(filtered_peers) } } diff --git a/comms/dht/src/config.rs b/comms/dht/src/config.rs index ec502499bc..6f3053539a 100644 --- a/comms/dht/src/config.rs +++ b/comms/dht/src/config.rs @@ -24,7 +24,7 @@ use std::{path::Path, time::Duration}; use serde::{Deserialize, Serialize}; use tari_common::configuration::serializers; -use tari_comms::peer_validator::PeerValidatorConfig; +use tari_comms::{multiaddr::Multiaddr, peer_validator::PeerValidatorConfig}; use crate::{ actor::OffenceSeverity, @@ -115,6 +115,8 @@ pub struct DhtConfig { /// Configuration for peer validation /// See [PeerValidatorConfig] pub peer_validator_config: PeerValidatorConfig, + /// Addresses that should never be dialed + pub excluded_dial_addresses: Vec, } impl DhtConfig { @@ -193,6 +195,7 @@ impl Default for DhtConfig { max_permitted_peer_claims: 5, offline_peer_cooldown: Duration::from_secs(24 * 60 * 60), peer_validator_config: Default::default(), + excluded_dial_addresses: vec![], } } } diff --git a/comms/dht/src/connectivity/mod.rs b/comms/dht/src/connectivity/mod.rs index a10c674974..b0294e9184 100644 --- a/comms/dht/src/connectivity/mod.rs +++ b/comms/dht/src/connectivity/mod.rs @@ -368,7 +368,7 @@ impl DhtConnectivity { } async fn refresh_neighbour_pool(&mut self, try_revive_connections: bool) -> Result<(), DhtConnectivityError> { - self.remove_allow_list_peers_from_pools().await?; + self.remove_unmanaged_peers_from_pools().await?; let mut new_neighbours = self .fetch_neighbouring_peers(self.config.num_neighbouring_nodes, &[], try_revive_connections) .await?; @@ -457,7 +457,7 @@ impl DhtConnectivity { } async fn refresh_random_pool(&mut self) -> Result<(), DhtConnectivityError> { - self.remove_allow_list_peers_from_pools().await?; + self.remove_unmanaged_peers_from_pools().await?; let mut exclude = self.neighbours.clone(); if self.config.minimize_connections { exclude.extend(self.previous_random.iter().cloned()); @@ -505,7 +505,7 @@ impl DhtConnectivity { } async fn handle_new_peer_connected(&mut self, conn: PeerConnection) -> Result<(), DhtConnectivityError> { - self.remove_allow_list_peers_from_pools().await?; + self.remove_unmanaged_peers_from_pools().await?; if conn.peer_features().is_client() { debug!( target: LOG_TARGET, @@ -640,7 +640,7 @@ impl DhtConnectivity { "Failed to clear metrics for peer `{}`. Metric collector is shut down.", node_id ); }; - self.remove_allow_list_peers_from_pools().await?; + self.remove_unmanaged_peers_from_pools().await?; if !self.is_pool_peer(&node_id) { debug!(target: LOG_TARGET, "{} is not managed by the DHT. Ignoring", node_id); return Ok(()); @@ -662,7 +662,7 @@ impl DhtConnectivity { "Failed to clear metrics for peer `{}`. Metric collector is shut down.", node_id ); }; - self.remove_allow_list_peers_from_pools().await?; + self.remove_unmanaged_peers_from_pools().await?; if !self.is_pool_peer(&node_id) { debug!(target: LOG_TARGET, "{} is not managed by the DHT. Ignoring", node_id); return Ok(()); @@ -732,7 +732,7 @@ impl DhtConnectivity { } async fn replace_pool_peer(&mut self, current_peer: &NodeId) -> Result<(), DhtConnectivityError> { - self.remove_allow_list_peers_from_pools().await?; + self.remove_unmanaged_peers_from_pools().await?; if self.is_allow_list_peer(current_peer).await? { debug!( target: LOG_TARGET, @@ -853,6 +853,11 @@ impl DhtConnectivity { } } + async fn remove_unmanaged_peers_from_pools(&mut self) -> Result<(), DhtConnectivityError> { + self.remove_allow_list_peers_from_pools().await?; + self.remove_exlcuded_peers_from_pools().await + } + async fn remove_allow_list_peers_from_pools(&mut self) -> Result<(), DhtConnectivityError> { let allow_list = self.peer_allow_list().await?; self.neighbours.retain(|n| !allow_list.contains(n)); @@ -860,6 +865,35 @@ impl DhtConnectivity { Ok(()) } + async fn remove_exlcuded_peers_from_pools(&mut self) -> Result<(), DhtConnectivityError> { + if !self.config.excluded_dial_addresses.is_empty() { + let mut neighbours = Vec::with_capacity(self.neighbours.len()); + for peer in &self.neighbours { + let addresses = self.peer_manager.get_peer_multi_addresses(peer).await?; + if !addresses + .iter() + .all(|addr| self.config.excluded_dial_addresses.contains(addr.address())) + { + neighbours.push(peer.clone()); + } + } + self.neighbours = neighbours; + + let mut random_pool = Vec::with_capacity(self.random_pool.len()); + for peer in &self.random_pool { + let addresses = self.peer_manager.get_peer_multi_addresses(peer).await?; + if !addresses + .iter() + .all(|addr| self.config.excluded_dial_addresses.contains(addr.address())) + { + random_pool.push(peer.clone()); + } + } + self.random_pool = random_pool; + } + Ok(()) + } + async fn is_allow_list_peer(&mut self, node_id: &NodeId) -> Result { Ok(self.peer_allow_list().await?.contains(node_id)) }