diff --git a/applications/minotari_app_grpc/proto/network.proto b/applications/minotari_app_grpc/proto/network.proto index 6e15d127b5..5a0d6c603d 100644 --- a/applications/minotari_app_grpc/proto/network.proto +++ b/applications/minotari_app_grpc/proto/network.proto @@ -31,7 +31,7 @@ message NodeIdentity { bytes node_id = 3; } -message Peer{ +message Peer { /// Public key of the peer bytes public_key =1; /// NodeId of the peer @@ -46,7 +46,7 @@ message Peer{ string banned_reason= 7; google.protobuf.Timestamp offline_at = 8; /// Features supported by the peer - uint64 features = 9; + uint32 features = 9; /// used as information for more efficient protocol negotiation. repeated bytes supported_protocols = 11; /// User agent advertised by the peer diff --git a/applications/minotari_app_grpc/src/conversions/peer.rs b/applications/minotari_app_grpc/src/conversions/peer.rs index 25e9309de1..8948fe6dde 100644 --- a/applications/minotari_app_grpc/src/conversions/peer.rs +++ b/applications/minotari_app_grpc/src/conversions/peer.rs @@ -63,12 +63,12 @@ impl From for grpc::Peer { impl From for grpc::Address { fn from(address_with_stats: MultiaddrWithStats) -> Self { let address = address_with_stats.address().to_vec(); - let last_seen = match address_with_stats.last_seen { + let last_seen = match address_with_stats.last_seen() { Some(v) => v.to_string(), None => String::new(), }; - let connection_attempts = address_with_stats.connection_attempts; - let avg_latency = address_with_stats.avg_latency.as_secs(); + let connection_attempts = address_with_stats.connection_attempts(); + let avg_latency = address_with_stats.avg_latency().as_secs(); Self { address, last_seen, diff --git a/applications/minotari_node/src/commands/command/get_peer.rs b/applications/minotari_node/src/commands/command/get_peer.rs index ca2c8a2121..f9242244f9 100644 --- a/applications/minotari_node/src/commands/command/get_peer.rs +++ b/applications/minotari_node/src/commands/command/get_peer.rs @@ -90,14 +90,14 @@ impl CommandContext { println!( "- {} Score: {} - Source: {} Latency: {:?} - Last Seen: {} - Last Failure:{}", a.address(), - a.quality_score, - a.source, - a.avg_latency, - a.last_seen + a.quality_score(), + a.source(), + a.avg_latency(), + a.last_seen() .as_ref() .map(|t| t.to_string()) .unwrap_or_else(|| "Never".to_string()), - a.last_failed_reason.as_ref().unwrap_or(&"None".to_string()) + a.last_failed_reason().unwrap_or("None") ); }); println!("User agent: {}", peer.user_agent); diff --git a/base_layer/p2p/src/initialization.rs b/base_layer/p2p/src/initialization.rs index bf144c16be..5f22b7b422 100644 --- a/base_layer/p2p/src/initialization.rs +++ b/base_layer/p2p/src/initialization.rs @@ -555,11 +555,11 @@ impl ServiceInitializer for P2pInitializer { }) .set_liveness_check(config.listener_liveness_check_interval); - if config.allow_test_addresses || config.dht.allow_test_addresses { + if config.allow_test_addresses || config.dht.peer_validator_config.allow_test_addresses { // The default is false, so ensure that both settings are true in this case config.allow_test_addresses = true; - config.dht.allow_test_addresses = true; builder = builder.allow_test_addresses(); + config.dht.peer_validator_config = builder.peer_validator_config().clone(); } let (comms, dht) = configure_comms_and_dht(builder, &config, connector).await?; diff --git a/base_layer/wallet/src/storage/sqlite_db/wallet.rs b/base_layer/wallet/src/storage/sqlite_db/wallet.rs index 386a563efa..659d5eebe6 100644 --- a/base_layer/wallet/src/storage/sqlite_db/wallet.rs +++ b/base_layer/wallet/src/storage/sqlite_db/wallet.rs @@ -303,7 +303,7 @@ impl WalletSqliteDatabase { fn get_comms_features(&self, conn: &mut SqliteConnection) -> Result, WalletStorageError> { if let Some(key_str) = WalletSettingSql::get(&DbKey::CommsFeatures, conn)? { - let features = u64::from_str(&key_str).map_err(|e| WalletStorageError::ConversionError(e.to_string()))?; + let features = u32::from_str(&key_str).map_err(|e| WalletStorageError::ConversionError(e.to_string()))?; let peer_features = PeerFeatures::from_bits(features); Ok(peer_features) } else { diff --git a/comms/core/src/bans.rs b/comms/core/src/bans.rs new file mode 100644 index 0000000000..fa4c3e037e --- /dev/null +++ b/comms/core/src/bans.rs @@ -0,0 +1,27 @@ +// // Copyright 2023. The Tari Project +// // +// // Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// // following conditions are met: +// // +// // 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the +// following // disclaimer. +// // +// // 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// // following disclaimer in the documentation and/or other materials provided with the distribution. +// // +// // 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// // products derived from this software without specific prior written permission. +// // +// // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// // INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// // DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +use std::time::Duration; + +// TODO: consolidate ban durations +pub const BAN_DURATION_LONG: Duration = Duration::from_secs(2 * 60 * 60); +pub const BAN_DURATION_SHORT: Duration = Duration::from_secs(2 * 60); diff --git a/comms/core/src/builder/mod.rs b/comms/core/src/builder/mod.rs index 1975665809..4048905377 100644 --- a/comms/core/src/builder/mod.rs +++ b/comms/core/src/builder/mod.rs @@ -46,6 +46,7 @@ use crate::{ connectivity::{ConnectivityConfig, ConnectivityRequester}, multiaddr::Multiaddr, peer_manager::{NodeIdentity, PeerManager}, + peer_validator::PeerValidatorConfig, protocol::{NodeNetworkInfo, ProtocolExtensions}, tor, types::CommsDatabase, @@ -195,10 +196,31 @@ impl CommsBuilder { target: "comms::builder", "Test addresses are enabled! This is invalid and potentially insecure when running a production node." ); - self.connection_manager_config.allow_test_addresses = true; + self.connection_manager_config + .peer_validation_config + .allow_test_addresses = true; self } + /// Sets the PeerValidatorConfig - this will override previous calls to allow_test_addresses() with the value in + /// peer_validator_config.allow_test_addresses + pub fn with_peer_validator_config(mut self, config: PeerValidatorConfig) -> Self { + #[cfg(not(debug_assertions))] + if config.allow_test_addresses { + log::warn!( + target: "comms::builder", + "Test addresses are enabled! This is invalid and potentially insecure when running a production node." + ); + } + self.connection_manager_config.peer_validation_config = config; + self + } + + /// Returns the PeerValidatorConfig set in this builder + pub fn peer_validator_config(&self) -> &PeerValidatorConfig { + &self.connection_manager_config.peer_validation_config + } + /// Sets the address that the transport will listen on. The address must be compatible with the transport. pub fn with_listener_address(mut self, listener_address: Multiaddr) -> Self { self.connection_manager_config.listener_address = listener_address; diff --git a/comms/core/src/connection_manager/common.rs b/comms/core/src/connection_manager/common.rs index 7aeedca221..22ff811d99 100644 --- a/comms/core/src/connection_manager/common.rs +++ b/comms/core/src/connection_manager/common.rs @@ -20,17 +20,21 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use std::{convert::TryInto, net::Ipv6Addr}; +use std::{ + convert::{TryFrom, TryInto}, + time::Duration, +}; -use digest::Digest; use log::*; use tokio::io::{AsyncRead, AsyncWrite}; use crate::{ connection_manager::error::ConnectionManagerError, - multiaddr::{Multiaddr, Protocol}, - net_address::{MultiaddrWithStats, MultiaddressesWithStats, PeerAddressSource}, - peer_manager::{NodeId, NodeIdentity, Peer, PeerFlags, PeerIdentityClaim}, + multiaddr::Multiaddr, + net_address::{MultiaddressesWithStats, PeerAddressSource}, + peer_manager::{NodeId, NodeIdentity, Peer, PeerFeatures, PeerFlags, PeerIdentityClaim, PeerManagerError}, + peer_validator::{validate_peer_identity_claim, PeerValidatorConfig, PeerValidatorError}, + proto::identity::PeerIdentityMsg, protocol, protocol::{NodeNetworkInfo, ProtocolId}, types::CommsPublicKey, @@ -39,6 +43,33 @@ use crate::{ const LOG_TARGET: &str = "comms::connection_manager::common"; +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct ValidatedPeerIdentityExchange { + pub claim: PeerIdentityClaim, + pub metadata: PeerIdentityMetadata, +} + +impl ValidatedPeerIdentityExchange { + // getters + pub fn peer_features(&self) -> PeerFeatures { + self.claim.features + } + + pub fn supported_protocols(&self) -> &[ProtocolId] { + &self.metadata.supported_protocols + } + + pub fn user_agent(&self) -> &str { + self.metadata.user_agent.as_str() + } +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct PeerIdentityMetadata { + pub user_agent: String, + pub supported_protocols: Vec, +} + /// Performs the identity exchange protocol on the given socket. pub(super) async fn perform_identity_exchange< 'p, @@ -49,11 +80,11 @@ pub(super) async fn perform_identity_exchange< node_identity: &NodeIdentity, our_supported_protocols: P, network_info: NodeNetworkInfo, -) -> Result { +) -> Result { let peer_identity = protocol::identity_exchange(node_identity, our_supported_protocols, network_info, socket).await?; - Ok(peer_identity.try_into()?) + Ok(peer_identity) } /// Validate the peer identity info. @@ -65,25 +96,84 @@ pub(super) async fn perform_identity_exchange< /// /// If the `allow_test_addrs` parameter is true, loopback, local link and other addresses normally not considered valid /// for p2p comms will be accepted. -pub(super) async fn validate_peer_identity( +pub(super) fn validate_peer_identity_message( + config: &PeerValidatorConfig, authenticated_public_key: &CommsPublicKey, - peer_identity: &PeerIdentityClaim, - allow_test_addrs: bool, -) -> Result<(), ConnectionManagerError> { - validate_addresses(&peer_identity.addresses, allow_test_addrs)?; - if peer_identity.addresses.is_empty() { - return Err(ConnectionManagerError::PeerIdentityNoAddresses); + peer_identity_msg: PeerIdentityMsg, +) -> Result { + let PeerIdentityMsg { + addresses, + features, + supported_protocols, + user_agent, + identity_signature, + } = peer_identity_msg; + + // Perform basic length checks before parsing + if supported_protocols.len() > config.max_supported_protocols { + return Err(PeerValidatorError::PeerIdentityTooManyProtocols { + length: supported_protocols.len(), + max: config.max_supported_protocols, + } + .into()); + } + + if let Some(proto) = supported_protocols + .iter() + .find(|p| p.len() > config.max_protocol_id_length) + { + return Err(PeerValidatorError::PeerIdentityProtocolIdTooLong { + length: proto.len(), + max: config.max_protocol_id_length, + } + .into()); + } + + if addresses.is_empty() { + return Err(PeerValidatorError::PeerIdentityNoAddresses.into()); } - if !peer_identity.signature.is_valid( - authenticated_public_key, - peer_identity.features, - &peer_identity.addresses, - ) { - return Err(ConnectionManagerError::PeerIdentityInvalidSignature); + if addresses.len() > config.max_permitted_peer_addresses_per_claim { + return Err(PeerValidatorError::PeerIdentityTooManyAddresses { + length: addresses.len(), + max: config.max_permitted_peer_addresses_per_claim, + } + .into()); + } + + if user_agent.as_bytes().len() > config.max_user_agent_byte_length { + return Err(PeerValidatorError::PeerIdentityUserAgentTooLong { + length: user_agent.as_bytes().len(), + max: config.max_user_agent_byte_length, + } + .into()); } - Ok(()) + let supported_protocols = supported_protocols.into_iter().map(ProtocolId::from).collect(); + + let addresses = addresses + .into_iter() + .map(Multiaddr::try_from) + .collect::, _>>() + .map_err(|e| PeerManagerError::MultiaddrError(e.to_string()))?; + + let peer_identity_claim = PeerIdentityClaim { + addresses, + features: PeerFeatures::from_bits(features).ok_or(PeerManagerError::InvalidPeerFeatures { bits: features })?, + signature: identity_signature + .ok_or(PeerManagerError::MissingIdentitySignature)? + .try_into()?, + }; + + validate_peer_identity_claim(config, authenticated_public_key, &peer_identity_claim)?; + + Ok(ValidatedPeerIdentityExchange { + claim: peer_identity_claim, + metadata: PeerIdentityMetadata { + user_agent, + supported_protocols, + }, + }) } /// Validate the peer identity info. @@ -96,25 +186,15 @@ pub(super) async fn validate_peer_identity( /// /// If the `allow_test_addrs` parameter is true, loopback, local link and other addresses normally not considered valid /// for p2p comms will be accepted. -pub(super) async fn validate_and_add_peer_from_peer_identity( - peer_manager: &PeerManager, +pub(super) fn create_or_update_peer_from_validated_peer_identity( known_peer: Option, authenticated_public_key: CommsPublicKey, - peer_identity: &PeerIdentityClaim, - allow_test_addrs: bool, -) -> Result { + peer_identity: &ValidatedPeerIdentityExchange, +) -> Peer { let peer_node_id = NodeId::from_public_key(&authenticated_public_key); - let addresses = MultiaddressesWithStats::from_addresses_with_source( - peer_identity.addresses.clone(), - &PeerAddressSource::FromPeerConnection { - peer_identity_claim: peer_identity.clone(), - }, - ); - validate_addresses_and_source(&addresses, &authenticated_public_key, allow_test_addrs)?; - // Note: the peer will be merged in the db if it already exists - let peer = match known_peer { + match known_peer { Some(mut peer) => { debug!( target: LOG_TARGET, @@ -122,13 +202,13 @@ pub(super) async fn validate_and_add_peer_from_peer_identity( peer.node_id.short_str() ); peer.addresses - .update_addresses(&peer_identity.addresses, &PeerAddressSource::FromPeerConnection { - peer_identity_claim: peer_identity.clone(), + .update_addresses(&peer_identity.claim.addresses, &PeerAddressSource::FromPeerConnection { + peer_identity_claim: peer_identity.claim.clone(), }); - peer.features = peer_identity.features; - peer.supported_protocols = peer_identity.supported_protocols(); - peer.user_agent = peer_identity.user_agent().unwrap_or_default(); + peer.features = peer_identity.claim.features; + peer.supported_protocols = peer_identity.metadata.supported_protocols.clone(); + peer.user_agent = peer_identity.metadata.user_agent.clone(); peer }, @@ -139,25 +219,21 @@ pub(super) async fn validate_and_add_peer_from_peer_identity( peer_node_id.short_str() ); Peer::new( - authenticated_public_key.clone(), - peer_node_id.clone(), + authenticated_public_key, + peer_node_id, MultiaddressesWithStats::from_addresses_with_source( - peer_identity.addresses.clone(), + peer_identity.claim.addresses.clone(), &PeerAddressSource::FromPeerConnection { - peer_identity_claim: peer_identity.clone(), + peer_identity_claim: peer_identity.claim.clone(), }, ), PeerFlags::empty(), - peer_identity.features, - peer_identity.supported_protocols(), - peer_identity.user_agent().unwrap_or_default(), + peer_identity.peer_features(), + peer_identity.supported_protocols().to_vec(), + peer_identity.user_agent().to_string(), ) }, - }; - - peer_manager.add_peer(peer).await?; - - Ok(peer_node_id) + } } pub(super) async fn find_unbanned_peer( @@ -171,259 +247,37 @@ pub(super) async fn find_unbanned_peer( } } -/// Checks that the given peer addresses are well-formed and valid. If allow_test_addrs is false, all localhost and -/// memory addresses will be rejected. Also checks that the source (signature of the address) is correct -pub fn validate_addresses_and_source( - addresses: &MultiaddressesWithStats, - public_key: &CommsPublicKey, - allow_test_addrs: bool, -) -> Result<(), ConnectionManagerError> { - for addr in addresses.addresses() { - validate_address_and_source(public_key, addr, allow_test_addrs)?; - } - - Ok(()) -} - -/// Checks that the given peer addresses are well-formed and valid. If allow_test_addrs is false, all localhost and -/// memory addresses will be rejected. -pub fn validate_addresses(addresses: &[Multiaddr], allow_test_addrs: bool) -> Result<(), ConnectionManagerError> { - for addr in addresses { - validate_address(addr, allow_test_addrs)?; - } - - Ok(()) -} - -pub fn validate_address_and_source( - public_key: &CommsPublicKey, - addr: &MultiaddrWithStats, - allow_test_addrs: bool, -) -> Result<(), ConnectionManagerError> { - match addr.source { - PeerAddressSource::Config => (), - _ => { - let claim = addr - .source - .peer_identity_claim() - .ok_or(ConnectionManagerError::PeerIdentityInvalidSignature)?; - if !claim.signature.is_valid(public_key, claim.features, &claim.addresses) { - return Err(ConnectionManagerError::PeerIdentityInvalidSignature); - } - if !claim.addresses.contains(addr.address()) { - return Err(ConnectionManagerError::PeerIdentityInvalidSignature); - } - }, - } - validate_address(addr.address(), allow_test_addrs)?; - Ok(()) -} - -fn validate_address(addr: &Multiaddr, allow_test_addrs: bool) -> Result<(), ConnectionManagerError> { - let mut addr_iter = addr.iter(); - let proto = addr_iter - .next() - .ok_or_else(|| ConnectionManagerError::InvalidMultiaddr("Multiaddr was empty".to_string()))?; - - /// Returns [true] if the address is a unicast link-local address (fe80::/10). - #[inline] - const fn is_unicast_link_local(addr: &Ipv6Addr) -> bool { - (addr.segments()[0] & 0xffc0) == 0xfe80 - } - - match proto { - Protocol::Dns4(_) | Protocol::Dns6(_) | Protocol::Dnsaddr(_) => { - let tcp = addr_iter.next().ok_or_else(|| { - ConnectionManagerError::InvalidMultiaddr("Address does not include a TCP port".to_string()) - })?; - - validate_tcp_port(tcp)?; - expect_end_of_address(addr_iter) - }, - - Protocol::Ip4(addr) - if !allow_test_addrs && (addr.is_loopback() || addr.is_link_local() || addr.is_unspecified()) => - { - Err(ConnectionManagerError::InvalidMultiaddr( - "Non-global IP addresses are invalid".to_string(), - )) - }, - Protocol::Ip6(addr) - if !allow_test_addrs && (addr.is_loopback() || is_unicast_link_local(&addr) || addr.is_unspecified()) => - { - Err(ConnectionManagerError::InvalidMultiaddr( - "Non-global IP addresses are invalid".to_string(), - )) - }, - Protocol::Ip4(_) | Protocol::Ip6(_) => { - let tcp = addr_iter.next().ok_or_else(|| { - ConnectionManagerError::InvalidMultiaddr("Address does not include a TCP port".to_string()) - })?; - - validate_tcp_port(tcp)?; - expect_end_of_address(addr_iter) +pub(super) async fn ban_on_offence( + peer_manager: &PeerManager, + authenticated_public_key: &CommsPublicKey, + result: Result, +) -> Result { + match result { + Ok(t) => Ok(t), + Err(ConnectionManagerError::PeerValidationError(e)) => { + maybe_ban(peer_manager, authenticated_public_key, e.as_ban_duration(), e).await }, - Protocol::Memory(0) => Err(ConnectionManagerError::InvalidMultiaddr( - "Cannot connect to a zero memory port".to_string(), - )), - Protocol::Memory(_) if allow_test_addrs => expect_end_of_address(addr_iter), - Protocol::Memory(_) => Err(ConnectionManagerError::InvalidMultiaddr( - "Memory addresses are invalid".to_string(), - )), - // Zero-port onions should have already failed when parsing. Keep these checks here just in case. - Protocol::Onion(_, 0) => Err(ConnectionManagerError::InvalidMultiaddr( - "A zero onion port is not valid in the onion spec".to_string(), - )), - Protocol::Onion3(addr) if addr.port() == 0 => Err(ConnectionManagerError::InvalidMultiaddr( - "A zero onion port is not valid in the onion spec".to_string(), - )), - Protocol::Onion(_, _) => Err(ConnectionManagerError::OnionV2NotSupported), - Protocol::Onion3(addr) => { - expect_end_of_address(addr_iter)?; - validate_onion3_address(&addr) + Err(ConnectionManagerError::IdentityProtocolError(e)) => { + maybe_ban(peer_manager, authenticated_public_key, e.as_ban_duration(), e).await }, - p => Err(ConnectionManagerError::InvalidMultiaddr(format!( - "Unsupported address type '{}'", - p - ))), - } -} - -fn expect_end_of_address(mut iter: multiaddr::Iter<'_>) -> Result<(), ConnectionManagerError> { - match iter.next() { - Some(p) => Err(ConnectionManagerError::InvalidMultiaddr(format!( - "Unexpected multiaddress component '{}'", - p - ))), - None => Ok(()), - } -} - -fn validate_tcp_port(expected_tcp: Protocol) -> Result<(), ConnectionManagerError> { - match expected_tcp { - Protocol::Tcp(0) => Err(ConnectionManagerError::InvalidMultiaddr( - "Cannot connect to a zero TCP port".to_string(), - )), - Protocol::Tcp(_) => Ok(()), - p => Err(ConnectionManagerError::InvalidMultiaddr(format!( - "Expected TCP address component but got '{}'", - p - ))), - } -} - -/// Validates the onion3 version and checksum as per https://github.com/torproject/torspec/blob/main/rend-spec-v3.txt#LL2258C6-L2258C6 -fn validate_onion3_address(addr: &multiaddr::Onion3Addr<'_>) -> Result<(), ConnectionManagerError> { - const ONION3_PUBKEY_SIZE: usize = 32; - const ONION3_CHECKSUM_SIZE: usize = 2; - - let (pub_key, checksum_version) = addr.hash().split_at(ONION3_PUBKEY_SIZE); - let (checksum, version) = checksum_version.split_at(ONION3_CHECKSUM_SIZE); - - if version != b"\x03" { - return Err(ConnectionManagerError::InvalidMultiaddr( - "Invalid version in onion address".to_string(), - )); - } - - let calculated_checksum = sha3::Sha3_256::new() - .chain_update(".onion checksum") - .chain_update(pub_key) - .chain_update(version) - .finalize(); - - if calculated_checksum[..2] != *checksum { - return Err(ConnectionManagerError::InvalidMultiaddr( - "Invalid checksum in onion address".to_string(), - )); + Err(err) => Err(err), } - - Ok(()) } -#[cfg(test)] -mod test { - use multiaddr::multiaddr; - - use super::*; - - #[test] - fn validate_address_strict() { - let valid = [ - multiaddr!(Ip4([172, 0, 0, 1]), Tcp(1u16)), - multiaddr!(Ip6([172, 0, 0, 1, 1, 1, 1, 1]), Tcp(1u16)), - "/onion3/vww6ybal4bd7szmgncyruucpgfkqahzddi37ktceo3ah7ngmcopnpyyd:1234" - .parse() - .unwrap(), - multiaddr!(Dnsaddr("mike-magic-nodes.com"), Tcp(1u16)), - ]; - - let invalid = &[ - "/onion/aaimaq4ygg2iegci:1234".parse().unwrap(), - multiaddr!(Ip4([127, 0, 0, 1]), Tcp(1u16)), - multiaddr!(Ip4([169, 254, 0, 1]), Tcp(1u16)), - multiaddr!(Ip4([172, 0, 0, 1])), - "/onion/aaimaq4ygg2iegci:1234/http".parse().unwrap(), - multiaddr!(Dnsaddr("mike-magic-nodes.com")), - multiaddr!(Memory(1234u64)), - multiaddr!(Memory(0u64)), - ]; - - validate_addresses(&valid, false).unwrap(); - for addr in invalid { - validate_address(addr, false).unwrap_err(); - } - } - - #[test] - fn validate_address_allow_test_addrs() { - let valid = [ - multiaddr!(Ip4([127, 0, 0, 1]), Tcp(1u16)), - multiaddr!(Ip4([169, 254, 0, 1]), Tcp(1u16)), - multiaddr!(Ip4([172, 0, 0, 1]), Tcp(1u16)), - multiaddr!(Ip6([172, 0, 0, 1, 1, 1, 1, 1]), Tcp(1u16)), - "/onion3/vww6ybal4bd7szmgncyruucpgfkqahzddi37ktceo3ah7ngmcopnpyyd:1234" - .parse() - .unwrap(), - multiaddr!(Dnsaddr("mike-magic-nodes.com"), Tcp(1u16)), - multiaddr!(Memory(1234u64)), - ]; - - let invalid = &[ - "/onion/aaimaq4ygg2iegci:1234".parse().unwrap(), - multiaddr!(Ip4([172, 0, 0, 1])), - "/onion/aaimaq4ygg2iegci:1234/http".parse().unwrap(), - multiaddr!(Dnsaddr("mike-magic-nodes.com")), - multiaddr!(Memory(0u64)), - ]; - - validate_addresses(&valid, true).unwrap(); - for addr in invalid { - validate_address(addr, true).unwrap_err(); +async fn maybe_ban>( + peer_manager: &PeerManager, + authenticated_public_key: &CommsPublicKey, + ban_duration: Option, + err: E, +) -> Result { + if let Some(ban_duration) = ban_duration { + if let Err(err) = peer_manager + .ban_peer(authenticated_public_key, ban_duration, err.to_string()) + .await + { + error!(target: LOG_TARGET, "Failed to ban peer due to internal error: {}", err); } } - #[test] - fn validate_onion3_checksum() { - let valid: Multiaddr = "/onion3/vww6ybal4bd7szmgncyruucpgfkqahzddi37ktceo3ah7ngmcopnpyyd:1234" - .parse() - .unwrap(); - - validate_address(&valid, false).unwrap(); - - // Change one byte - let invalid: Multiaddr = "/onion3/www6ybal4bd7szmgncyruucpgfkqahzddi37ktceo3ah7ngmcopnpyyd:1234" - .parse() - .unwrap(); - - validate_address(&invalid, false).unwrap_err(); - - // Randomly generated - let invalid: Multiaddr = "/onion3/pd6sf3mqkkkfrn4rk5odgcr2j5sn7m523a4tm7pzpuotk2b7rpuhaeym:1234" - .parse() - .unwrap(); - - let err = validate_address(&invalid, false).unwrap_err(); - assert!(matches!(err, ConnectionManagerError::InvalidMultiaddr(_))); - } + Err(err.into()) } diff --git a/comms/core/src/connection_manager/dialer.rs b/comms/core/src/connection_manager/dialer.rs index e55b303205..028a1a8f3c 100644 --- a/comms/core/src/connection_manager/dialer.rs +++ b/comms/core/src/connection_manager/dialer.rs @@ -50,6 +50,7 @@ use crate::{ backoff::Backoff, connection_manager::{ common, + common::ValidatedPeerIdentityExchange, dial_state::DialState, manager::{ConnectionManagerConfig, ConnectionManagerEvent}, metrics, @@ -68,8 +69,15 @@ use crate::{ const LOG_TARGET: &str = "comms::connection_manager::dialer"; type DialResult = Result<(NoiseSocket, Multiaddr), ConnectionManagerError>; -type DialFuturesUnordered = - FuturesUnordered)>>; +type DialFuturesUnordered = FuturesUnordered< + BoxFuture< + 'static, + ( + DialState, + Result<(PeerConnection, ValidatedPeerIdentityExchange), ConnectionManagerError>, + ), + >, +>; #[derive(Debug)] pub(crate) enum DialerRequest { @@ -94,7 +102,7 @@ pub struct Dialer { conn_man_notifier: mpsc::Sender, shutdown: Option, pending_dial_requests: HashMap>>>, - our_supported_protocols: Vec, + our_supported_protocols: Arc>, } impl Dialer @@ -126,13 +134,13 @@ where conn_man_notifier, shutdown: Some(shutdown), pending_dial_requests: Default::default(), - our_supported_protocols: Vec::new(), + our_supported_protocols: Arc::new(Vec::new()), } } /// Set the supported protocols of this node to send to peers during the peer identity exchange pub fn set_supported_protocols(&mut self, our_supported_protocols: Vec) -> &mut Self { - self.our_supported_protocols = our_supported_protocols; + self.our_supported_protocols = Arc::new(our_supported_protocols); self } @@ -193,7 +201,7 @@ where fn resolve_pending_dials(&mut self, conn: PeerConnection) { let peer = conn.peer_node_id().clone(); - self.reply_to_pending_requests(&peer, &Ok(conn)); + self.reply_to_pending_requests(&peer, Ok(conn)); self.cancel_dial(&peer); } @@ -215,53 +223,43 @@ where async fn handle_dial_result( &mut self, mut dial_state: DialState, - dial_result: Result, + dial_result: Result<(PeerConnection, ValidatedPeerIdentityExchange), ConnectionManagerError>, ) { let node_id = dial_state.peer().node_id.clone(); metrics::pending_connections(Some(&node_id), ConnectionDirection::Outbound).inc(); - // try save the peer back to the peer manager - let peer = dial_state.peer_mut(); - if let Ok(peer_connection) = &dial_result { - if let Some(peer_identity) = peer_connection.peer_identity_claim() { - peer.update_addresses(&peer_identity.addresses, &PeerAddressSource::FromPeerConnection { - peer_identity_claim: peer_identity.clone(), + match dial_result { + Ok((conn, peer_identity)) => { + // try save the peer back to the peer manager + let peer = dial_state.peer_mut(); + peer.update_addresses(&peer_identity.claim.addresses, &PeerAddressSource::FromPeerConnection { + peer_identity_claim: peer_identity.claim.clone(), }); - if let Some(unverified_data) = &peer_identity.unverified_data { - for protocol in &unverified_data.supported_protocols { - if !peer.supported_protocols.contains(protocol) { - peer.supported_protocols.push(protocol.clone()); - } - } - if peer.user_agent != unverified_data.user_agent && !unverified_data.user_agent.is_empty() { - peer.user_agent = unverified_data.user_agent.clone(); - } - } - } else { - error!(target: LOG_TARGET, "No identity claim provided"); - let _ = dial_state - .send_reply(Err(ConnectionManagerError::PeerConnectionError( - "No identity claim provided".to_string(), - ))) - .map_err(|e| error!(target: LOG_TARGET, "Could not send reply to dial request: {:?}", e)); - } - } + peer.supported_protocols = peer_identity.metadata.supported_protocols; + peer.user_agent = peer_identity.metadata.user_agent; - let _ = self - .peer_manager - .add_peer(dial_state.peer().clone()) - .await - .map_err(|e| { - error!("Could not update peer data:{}", e); - let _ = dial_state - .send_reply(Err(ConnectionManagerError::PeerManagerError(e))) - .map_err(|e| error!(target: LOG_TARGET, "Could not send reply to dial request: {:?}", e)); - }); - match &dial_result { - Ok(conn) => { + let _ = self + .peer_manager + .add_peer(dial_state.peer().clone()) + .await + .map_err(|e| { + error!("Could not update peer data:{}", e); + let _ = dial_state + .send_reply(Err(ConnectionManagerError::PeerManagerError(e))) + .map_err(|e| error!(target: LOG_TARGET, "Could not send reply to dial request: {:?}", e)); + }); debug!(target: LOG_TARGET, "Successfully dialed peer '{}'", node_id); self.notify_connection_manager(ConnectionManagerEvent::PeerConnected(conn.clone().into())) - .await + .await; + + if dial_state.send_reply(Ok(conn.clone())).is_err() { + warn!( + target: LOG_TARGET, + "Reply oneshot was closed before dial response for peer '{}' was sent", node_id + ); + } + + self.reply_to_pending_requests(&node_id, Ok(conn)); }, Err(err) => { debug!( @@ -269,20 +267,20 @@ where "Failed to dial peer '{}' because '{:?}'", node_id, err ); self.notify_connection_manager(ConnectionManagerEvent::PeerConnectFailed(node_id.clone(), err.clone())) - .await + .await; + + if dial_state.send_reply(Err(err.clone())).is_err() { + warn!( + target: LOG_TARGET, + "Reply oneshot was closed before dial response for peer '{}' was sent", node_id + ); + } + self.reply_to_pending_requests(&node_id, Err(err)); }, } metrics::pending_connections(Some(&node_id), ConnectionDirection::Outbound).dec(); - if dial_state.send_reply(dial_result.clone()).is_err() { - warn!( - target: LOG_TARGET, - "Reply oneshot was closed before dial response for peer '{}' was sent", node_id - ); - } - - self.reply_to_pending_requests(&node_id, &dial_result); self.cancel_dial(&node_id); } @@ -297,7 +295,7 @@ where fn reply_to_pending_requests( &mut self, peer_node_id: &NodeId, - result: &Result, + result: Result, ) { self.pending_dial_requests .remove(peer_node_id) @@ -345,6 +343,7 @@ where let supported_protocols = self.our_supported_protocols.clone(); let noise_config = self.noise_config.clone(); let config = self.config.clone(); + let peer_manager = self.peer_manager.clone(); let span = span!(Level::TRACE, "handle_dial_peer_request_inner1"); let dial_fut = async move { @@ -369,7 +368,8 @@ where }; let result = Self::perform_socket_upgrade_procedure( - node_identity, + &peer_manager, + &node_identity, socket, addr.clone(), authenticated_public_key, @@ -418,34 +418,43 @@ where } async fn perform_socket_upgrade_procedure( - node_identity: Arc, + peer_manager: &PeerManager, + node_identity: &NodeIdentity, mut socket: NoiseSocket, dialed_addr: Multiaddr, authenticated_public_key: CommsPublicKey, conn_man_notifier: mpsc::Sender, - our_supported_protocols: Vec, + our_supported_protocols: Arc>, config: &ConnectionManagerConfig, cancel_signal: ShutdownSignal, - ) -> Result { + ) -> Result<(PeerConnection, ValidatedPeerIdentityExchange), ConnectionManagerError> { static CONNECTION_DIRECTION: ConnectionDirection = ConnectionDirection::Outbound; debug!( target: LOG_TARGET, "Starting peer identity exchange for peer with public key '{}'", authenticated_public_key ); - let peer_identity = common::perform_identity_exchange( + let peer_identity_result = common::perform_identity_exchange( &mut socket, - &node_identity, - &our_supported_protocols, + node_identity, + &*our_supported_protocols, config.network_info.clone(), ) - .await?; + .await; + let peer_identity = + common::ban_on_offence(peer_manager, &authenticated_public_key, peer_identity_result).await?; if cancel_signal.is_terminated() { return Err(ConnectionManagerError::DialCancelled); } - common::validate_peer_identity(&authenticated_public_key, &peer_identity, config.allow_test_addresses).await?; + let peer_identity_result = common::validate_peer_identity_message( + &config.peer_validation_config, + &authenticated_public_key, + peer_identity, + ); + let peer_identity = + common::ban_on_offence(peer_manager, &authenticated_public_key, peer_identity_result).await?; if cancel_signal.is_terminated() { return Err(ConnectionManagerError::DialCancelled); @@ -459,17 +468,18 @@ where return Err(ConnectionManagerError::DialCancelled); } - peer_connection::try_create( + let peer_connection = peer_connection::create( muxer, dialed_addr, NodeId::from_public_key(&authenticated_public_key), - peer_identity.features, + peer_identity.claim.features, CONNECTION_DIRECTION, conn_man_notifier, our_supported_protocols, - peer_identity.supported_protocols(), - peer_identity, - ) + peer_identity.metadata.supported_protocols.clone(), + ); + + Ok((peer_connection, peer_identity)) } async fn dial_peer_with_retry( diff --git a/comms/core/src/connection_manager/error.rs b/comms/core/src/connection_manager/error.rs index 0bac1ece21..ff0585520c 100644 --- a/comms/core/src/connection_manager/error.rs +++ b/comms/core/src/connection_manager/error.rs @@ -27,6 +27,7 @@ use crate::{ connection_manager::PeerConnectionRequest, noise, peer_manager::PeerManagerError, + peer_validator::PeerValidatorError, protocol::{IdentityProtocolError, ProtocolError}, }; @@ -81,14 +82,8 @@ pub enum ConnectionManagerError { NoiseProtocolTimeout, #[error("Listener oneshot cancelled")] ListenerOneshotCancelled, - #[error("Peer sent invalid identity signature")] - PeerIdentityInvalidSignature, - #[error("Peer did not provide an identity signature")] - PeerIdentityNoSignature, - #[error("Peer did not provide any public addresses")] - PeerIdentityNoAddresses, - #[error("Onion v2 is no longer supported")] - OnionV2NotSupported, + #[error("Peer validation error: {0}")] + PeerValidationError(#[from] PeerValidatorError), } impl From for ConnectionManagerError { diff --git a/comms/core/src/connection_manager/listener.rs b/comms/core/src/connection_manager/listener.rs index dc8db5143a..8699a11fc7 100644 --- a/comms/core/src/connection_manager/listener.rs +++ b/comms/core/src/connection_manager/listener.rs @@ -80,7 +80,7 @@ pub struct PeerListener { noise_config: NoiseConfig, peer_manager: Arc, node_identity: Arc, - our_supported_protocols: Vec, + our_supported_protocols: Arc>, liveness_session_count: Arc, on_listening: OneshotTrigger>, } @@ -108,7 +108,7 @@ where peer_manager, node_identity, shutdown_signal, - our_supported_protocols: Vec::new(), + our_supported_protocols: Arc::new(Vec::new()), bounded_executor: BoundedExecutor::new(config.max_simultaneous_inbound_connects), liveness_session_count: Arc::new(AtomicUsize::new(config.liveness_max_sessions)), config, @@ -127,7 +127,7 @@ where /// Set the supported protocols of this node to send to peers during the peer identity exchange pub fn set_supported_protocols(&mut self, our_supported_protocols: Vec) -> &mut Self { - self.our_supported_protocols = our_supported_protocols; + self.our_supported_protocols = Arc::new(our_supported_protocols); self } @@ -245,8 +245,8 @@ where Ok(WireMode::Comms(byte)) if byte == config.network_info.network_byte => { let this_node_id_str = node_identity.node_id().short_str(); let result = Self::perform_socket_upgrade_procedure( - node_identity, - peer_manager, + &node_identity, + &peer_manager, noise_config.clone(), conn_man_notifier.clone(), socket, @@ -335,16 +335,16 @@ where } async fn perform_socket_upgrade_procedure( - node_identity: Arc, - peer_manager: Arc, + node_identity: &NodeIdentity, + peer_manager: &PeerManager, noise_config: NoiseConfig, conn_man_notifier: mpsc::Sender, socket: TTransport::Output, peer_addr: Multiaddr, - our_supported_protocols: Vec, + our_supported_protocols: Arc>, config: &ConnectionManagerConfig, ) -> Result { - static CONNECTION_DIRECTION: ConnectionDirection = ConnectionDirection::Inbound; + const CONNECTION_DIRECTION: ConnectionDirection = ConnectionDirection::Inbound; debug!( target: LOG_TARGET, "Starting noise protocol upgrade for peer at address '{}'", peer_addr @@ -370,44 +370,56 @@ where ); // Check if we know the peer and if it is banned - let known_peer = common::find_unbanned_peer(&peer_manager, &authenticated_public_key).await?; + let known_peer = common::find_unbanned_peer(peer_manager, &authenticated_public_key).await?; debug!( target: LOG_TARGET, "Starting peer identity exchange for peer with public key '{}'", authenticated_public_key ); - let peer_identity = common::perform_identity_exchange( + let peer_identity_result = common::perform_identity_exchange( &mut noise_socket, - &node_identity, - &our_supported_protocols, + node_identity, + &*our_supported_protocols, config.network_info.clone(), ) - .await?; + .await; + + let peer_identity = + common::ban_on_offence(peer_manager, &authenticated_public_key, peer_identity_result).await?; + + let valid_peer_identity_result = common::validate_peer_identity_message( + &config.peer_validation_config, + &authenticated_public_key, + peer_identity, + ); + + let valid_peer_identity = + common::ban_on_offence(peer_manager, &authenticated_public_key, valid_peer_identity_result).await?; - let peer_node_id = common::validate_and_add_peer_from_peer_identity( - &peer_manager, + let peer = common::create_or_update_peer_from_validated_peer_identity( known_peer, authenticated_public_key, - &peer_identity, - config.allow_test_addresses, - ) - .await?; + &valid_peer_identity, + ); let muxer = Yamux::upgrade_connection(noise_socket, CONNECTION_DIRECTION) .map_err(|err| ConnectionManagerError::YamuxUpgradeFailure(err.to_string()))?; - peer_connection::try_create( + let conn = peer_connection::create( muxer, peer_addr, - peer_node_id, - peer_identity.features, + peer.node_id.clone(), + peer.features, CONNECTION_DIRECTION, conn_man_notifier, our_supported_protocols, - peer_identity.supported_protocols(), - peer_identity, - ) + valid_peer_identity.metadata.supported_protocols, + ); + + peer_manager.add_peer(peer).await?; + + Ok(conn) } async fn bind(&mut self) -> Result<(TTransport::Listener, Multiaddr), ConnectionManagerError> { diff --git a/comms/core/src/connection_manager/manager.rs b/comms/core/src/connection_manager/manager.rs index 4da6116a30..3ca4454804 100644 --- a/comms/core/src/connection_manager/manager.rs +++ b/comms/core/src/connection_manager/manager.rs @@ -47,6 +47,7 @@ use crate::{ multiplexing::Substream, noise::NoiseConfig, peer_manager::{NodeId, NodeIdentity, PeerManagerError}, + peer_validator::PeerValidatorConfig, protocol::{NodeNetworkInfo, ProtocolEvent, ProtocolId, Protocols}, transports::{TcpTransport, Transport}, PeerManager, @@ -106,9 +107,6 @@ pub struct ConnectionManagerConfig { /// The maximum number of connection tasks that will be spawned at the same time. Once this limit is reached, peers /// attempting to connect will have to wait for another connection attempt to complete. Default: 100 pub max_simultaneous_inbound_connects: usize, - /// Set to true to allow peers to send loopback, local-link and other addresses normally not considered valid for - /// peer-to-peer comms. Default: false - pub allow_test_addresses: bool, /// Version information for this node pub network_info: NodeNetworkInfo, /// The maximum time to wait for the first byte before closing the connection. Default: 45s @@ -122,6 +120,7 @@ pub struct ConnectionManagerConfig { /// If set, an additional TCP-only p2p listener will be started. This is useful for local wallet connections. /// Default: None (disabled) pub auxiliary_tcp_listener_address: Option, + pub peer_validation_config: PeerValidatorConfig, } impl Default for ConnectionManagerConfig { @@ -136,16 +135,12 @@ impl Default for ConnectionManagerConfig { max_dial_attempts: 1, max_simultaneous_inbound_connects: 100, network_info: Default::default(), - #[cfg(not(test))] - allow_test_addresses: false, - // This must always be true for internal crate tests - #[cfg(test)] - allow_test_addresses: true, liveness_max_sessions: 1, time_to_first_byte: Duration::from_secs(45), liveness_cidr_allowlist: vec![cidr::AnyIpCidr::V4("127.0.0.1/32".parse().unwrap())], liveness_self_check_interval: None, auxiliary_tcp_listener_address: None, + peer_validation_config: PeerValidatorConfig::default(), } } } diff --git a/comms/core/src/connection_manager/mod.rs b/comms/core/src/connection_manager/mod.rs index fd77998d65..fe3c1fe0ed 100644 --- a/comms/core/src/connection_manager/mod.rs +++ b/comms/core/src/connection_manager/mod.rs @@ -33,7 +33,6 @@ mod listener; mod metrics; mod common; -pub use common::{validate_address_and_source, validate_addresses, validate_addresses_and_source}; mod direction; pub use direction::ConnectionDirection; diff --git a/comms/core/src/connection_manager/peer_connection.rs b/comms/core/src/connection_manager/peer_connection.rs index be154db29e..f8b1fb26d4 100644 --- a/comms/core/src/connection_manager/peer_connection.rs +++ b/comms/core/src/connection_manager/peer_connection.rs @@ -40,11 +40,7 @@ use tokio::{ use tokio_stream::StreamExt; use tracing::{self, span, Instrument, Level}; -use super::{ - direction::ConnectionDirection, - error::{ConnectionManagerError, PeerConnectionError}, - manager::ConnectionManagerEvent, -}; +use super::{direction::ConnectionDirection, error::PeerConnectionError, manager::ConnectionManagerEvent}; #[cfg(feature = "rpc")] use crate::protocol::rpc::{ pool::RpcClientPool, @@ -59,7 +55,7 @@ use crate::{ framing, framing::CanonicalFraming, multiplexing::{Control, IncomingSubstreams, Substream, Yamux}, - peer_manager::{NodeId, PeerFeatures, PeerIdentityClaim}, + peer_manager::{NodeId, PeerFeatures}, protocol::{ProtocolId, ProtocolNegotiation}, utils::atomic_ref_counter::AtomicRefCounter, }; @@ -70,17 +66,16 @@ const PROTOCOL_NEGOTIATION_TIMEOUT: Duration = Duration::from_secs(5); static ID_COUNTER: AtomicUsize = AtomicUsize::new(0); -pub fn try_create( +pub fn create( connection: Yamux, peer_addr: Multiaddr, peer_node_id: NodeId, peer_features: PeerFeatures, direction: ConnectionDirection, event_notifier: mpsc::Sender, - our_supported_protocols: Vec, + our_supported_protocols: Arc>, their_supported_protocols: Vec, - peer_identity_claim: PeerIdentityClaim, -) -> Result { +) -> PeerConnection { trace!( target: LOG_TARGET, "(Peer={}) Socket successfully upgraded to multiplexed socket", @@ -98,7 +93,6 @@ pub fn try_create( peer_addr, direction, substream_counter, - peer_identity_claim, ); let peer_actor = PeerConnectionActor::new( id, @@ -112,7 +106,7 @@ pub fn try_create( ); tokio::spawn(peer_actor.run()); - Ok(peer_conn) + peer_conn } /// Request types for the PeerConnection actor. @@ -142,7 +136,6 @@ pub struct PeerConnection { started_at: Instant, substream_counter: AtomicRefCounter, handle_counter: Arc<()>, - peer_identity_claim: Option, } impl PeerConnection { @@ -154,7 +147,6 @@ impl PeerConnection { address: Multiaddr, direction: ConnectionDirection, substream_counter: AtomicRefCounter, - peer_identity_claim: PeerIdentityClaim, ) -> Self { Self { id, @@ -166,31 +158,6 @@ impl PeerConnection { started_at: Instant::now(), substream_counter, handle_counter: Arc::new(()), - peer_identity_claim: Some(peer_identity_claim), - } - } - - /// Should only be used in tests - pub(crate) fn unverified( - id: ConnectionId, - request_tx: mpsc::Sender, - peer_node_id: NodeId, - peer_features: PeerFeatures, - address: Multiaddr, - direction: ConnectionDirection, - substream_counter: AtomicRefCounter, - ) -> Self { - Self { - id, - request_tx, - peer_node_id, - peer_features, - address: Arc::new(address), - direction, - started_at: Instant::now(), - substream_counter, - handle_counter: Arc::new(()), - peer_identity_claim: None, } } @@ -206,6 +173,14 @@ impl PeerConnection { self.direction } + pub fn known_address(&self) -> Option<&Multiaddr> { + if self.direction.is_outbound() { + Some(self.address()) + } else { + None + } + } + pub fn address(&self) -> &Multiaddr { &self.address } @@ -236,10 +211,6 @@ impl PeerConnection { Arc::strong_count(&self.handle_counter) } - pub fn peer_identity_claim(&self) -> Option<&PeerIdentityClaim> { - self.peer_identity_claim.as_ref() - } - #[tracing::instrument(level = "trace", "peer_connection::open_substream", skip(self))] pub async fn open_substream( &mut self, @@ -375,7 +346,7 @@ impl PeerConnectionActor { connection: Yamux, request_rx: mpsc::Receiver, event_notifier: mpsc::Sender, - our_supported_protocols: Vec, + our_supported_protocols: Arc>, their_supported_protocols: Vec, ) -> Self { Self { @@ -386,9 +357,7 @@ impl PeerConnectionActor { incoming_substreams: connection.into_incoming(), request_rx, event_notifier, - // our_supported_protocols never changes so we make it cheap to clone (used in inbound_protocol_negotiations - // futures) - our_supported_protocols: Arc::new(our_supported_protocols), + our_supported_protocols, inbound_protocol_negotiations: FuturesUnordered::new(), their_supported_protocols, } diff --git a/comms/core/src/connection_manager/tests/listener_dialer.rs b/comms/core/src/connection_manager/tests/listener_dialer.rs index 3d1139e97a..a9ff5d4b12 100644 --- a/comms/core/src/connection_manager/tests/listener_dialer.rs +++ b/comms/core/src/connection_manager/tests/listener_dialer.rs @@ -246,8 +246,10 @@ async fn banned() { let err = reply_rx.await.unwrap().unwrap_err(); unpack_enum!(ConnectionManagerError::IdentityProtocolError(_err) = err); - unpack_enum!(ConnectionManagerEvent::PeerInboundConnectFailed(err) = event_rx.recv().await.unwrap()); - unpack_enum!(ConnectionManagerError::PeerBanned = err); + unpack_enum!( + ConnectionManagerEvent::PeerInboundConnectFailed(ConnectionManagerError::PeerBanned) = + event_rx.recv().await.unwrap() + ); shutdown.trigger(); diff --git a/comms/core/src/connectivity/manager.rs b/comms/core/src/connectivity/manager.rs index c4654ad50d..dc47466bd4 100644 --- a/comms/core/src/connectivity/manager.rs +++ b/comms/core/src/connectivity/manager.rs @@ -265,7 +265,7 @@ impl ConnectivityManagerActor { } else if let Err(err) = self.ban_peer(&node_id, duration, reason).await { error!(target: LOG_TARGET, "Error when banning peer: {:?}", err); } else { - // we ban the peer + // we banned the peer } }, AddPeerToAllowList(node_id) => { diff --git a/comms/core/src/connectivity/requester.rs b/comms/core/src/connectivity/requester.rs index d555e1c716..fc5cac676a 100644 --- a/comms/core/src/connectivity/requester.rs +++ b/comms/core/src/connectivity/requester.rs @@ -248,14 +248,14 @@ impl ConnectivityRequester { } /// Ban peer for the given Duration. The ban `reason` is persisted in the peer database for reference. - pub async fn ban_peer_until( + pub async fn ban_peer_until>( &mut self, node_id: NodeId, duration: Duration, - reason: String, + reason: T, ) -> Result<(), ConnectivityError> { self.sender - .send(ConnectivityRequest::BanPeer(node_id, duration, reason)) + .send(ConnectivityRequest::BanPeer(node_id, duration, reason.into())) .await .map_err(|_| ConnectivityError::ActorDisconnected)?; Ok(()) diff --git a/comms/core/src/lib.rs b/comms/core/src/lib.rs index 9bbd8a8fb2..20144b1fbc 100644 --- a/comms/core/src/lib.rs +++ b/comms/core/src/lib.rs @@ -18,7 +18,7 @@ mod builder; pub use builder::{CommsBuilder, CommsBuilderError, CommsNode, UnspawnedCommsNode}; pub mod connection_manager; -pub use connection_manager::{validate_addresses, PeerConnection, PeerConnectionError}; +pub use connection_manager::{PeerConnection, PeerConnectionError}; pub mod connectivity; @@ -51,6 +51,9 @@ pub mod types; #[macro_use] pub mod utils; +pub mod peer_validator; + +mod bans; pub mod test_utils; pub mod traits; diff --git a/comms/core/src/net_address/multiaddr_with_stats.rs b/comms/core/src/net_address/multiaddr_with_stats.rs index e5599fdc2c..b357168fd6 100644 --- a/comms/core/src/net_address/multiaddr_with_stats.rs +++ b/comms/core/src/net_address/multiaddr_with_stats.rs @@ -19,106 +19,27 @@ use crate::{peer_manager::PeerIdentityClaim, types::CommsPublicKey}; const MAX_LATENCY_SAMPLE_COUNT: u32 = 100; const MAX_INITIAL_DIAL_TIME_SAMPLE_COUNT: u32 = 100; +const HIGH_QUALITY_SCORE: i32 = 1000; #[derive(Debug, Eq, Clone, Deserialize, Serialize)] pub struct MultiaddrWithStats { address: Multiaddr, - pub last_seen: Option, - pub connection_attempts: u32, - pub avg_initial_dial_time: Duration, + last_seen: Option, + connection_attempts: u32, + avg_initial_dial_time: Duration, initial_dial_time_sample_count: u32, - pub avg_latency: Duration, + avg_latency: Duration, latency_sample_count: u32, - pub last_attempted: Option, - pub last_failed_reason: Option, - pub quality_score: i32, - pub source: PeerAddressSource, -} - -#[derive(Debug, Clone, Serialize, Deserialize, Eq)] -pub enum PeerAddressSource { - Config, - FromNodeIdentity { - peer_identity_claim: PeerIdentityClaim, - }, - FromPeerConnection { - peer_identity_claim: PeerIdentityClaim, - }, - FromDiscovery { - peer_identity_claim: PeerIdentityClaim, - }, - FromAnotherPeer { - peer_identity_claim: PeerIdentityClaim, - source_peer: CommsPublicKey, - }, - FromJoinMessage { - peer_identity_claim: PeerIdentityClaim, - }, -} - -impl PeerAddressSource { - pub fn is_config(&self) -> bool { - matches!(self, PeerAddressSource::Config) - } - - pub fn peer_identity_claim(&self) -> Option<&PeerIdentityClaim> { - match self { - PeerAddressSource::Config => None, - PeerAddressSource::FromNodeIdentity { peer_identity_claim } => Some(peer_identity_claim), - PeerAddressSource::FromPeerConnection { peer_identity_claim } => Some(peer_identity_claim), - PeerAddressSource::FromDiscovery { peer_identity_claim } => Some(peer_identity_claim), - PeerAddressSource::FromAnotherPeer { - peer_identity_claim, .. - } => Some(peer_identity_claim), - PeerAddressSource::FromJoinMessage { peer_identity_claim } => Some(peer_identity_claim), - } - } -} - -impl Display for PeerAddressSource { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - match self { - PeerAddressSource::Config => write!(f, "Config"), - PeerAddressSource::FromNodeIdentity { .. } => { - write!(f, "FromNodeIdentity") - }, - PeerAddressSource::FromPeerConnection { .. } => write!(f, "FromPeerConnection"), - PeerAddressSource::FromDiscovery { .. } => write!(f, "FromDiscovery"), - PeerAddressSource::FromAnotherPeer { .. } => write!(f, "FromAnotherPeer"), - PeerAddressSource::FromJoinMessage { .. } => write!(f, "FromJoinMessage"), - } - } -} - -impl PartialEq for PeerAddressSource { - fn eq(&self, other: &Self) -> bool { - match self { - PeerAddressSource::Config => { - matches!(other, PeerAddressSource::Config) - }, - PeerAddressSource::FromNodeIdentity { .. } => { - matches!(other, PeerAddressSource::FromNodeIdentity { .. }) - }, - PeerAddressSource::FromPeerConnection { .. } => { - matches!(other, PeerAddressSource::FromPeerConnection { .. }) - }, - PeerAddressSource::FromAnotherPeer { .. } => { - matches!(other, PeerAddressSource::FromAnotherPeer { .. }) - }, - PeerAddressSource::FromDiscovery { .. } => { - matches!(other, PeerAddressSource::FromDiscovery { .. }) - }, - PeerAddressSource::FromJoinMessage { .. } => { - matches!(other, PeerAddressSource::FromJoinMessage { .. }) - }, - } - } + last_attempted: Option, + last_failed_reason: Option, + quality_score: i32, + source: PeerAddressSource, } impl MultiaddrWithStats { /// Constructs a new net address with zero stats pub fn new(address: Multiaddr, source: PeerAddressSource) -> Self { - Self { + let mut addr = Self { address, last_seen: None, connection_attempts: 0, @@ -130,7 +51,9 @@ impl MultiaddrWithStats { last_failed_reason: None, quality_score: 0, source, - } + }; + addr.update_quality_score(); + addr } pub fn merge(&mut self, other: &Self) { @@ -173,7 +96,7 @@ impl MultiaddrWithStats { } }, } - self.calculate_quality_score(); + self.update_quality_score(); } pub fn address(&self) -> &Multiaddr { @@ -199,13 +122,14 @@ impl MultiaddrWithStats { pub fn update_latency(&mut self, latency_measurement: Duration) { self.last_seen = Some(Utc::now().naive_utc()); - self.avg_latency = - ((self.avg_latency * self.latency_sample_count) + latency_measurement) / (self.latency_sample_count + 1); + self.avg_latency = ((self.avg_latency.saturating_mul(self.latency_sample_count)) + .saturating_add(latency_measurement)) / + (self.latency_sample_count + 1); if self.latency_sample_count < MAX_LATENCY_SAMPLE_COUNT { self.latency_sample_count += 1; } - self.calculate_quality_score(); + self.update_quality_score(); } pub fn update_initial_dial_time(&mut self, initial_dial_time: Duration) { @@ -217,15 +141,16 @@ impl MultiaddrWithStats { if self.initial_dial_time_sample_count < MAX_INITIAL_DIAL_TIME_SAMPLE_COUNT { self.initial_dial_time_sample_count += 1; } - self.calculate_quality_score(); + self.update_quality_score(); } /// Mark that a successful interaction occurred with this address - pub fn mark_last_seen_now(&mut self) { + pub fn mark_last_seen_now(&mut self) -> &mut Self { self.last_seen = Some(Utc::now().naive_utc()); self.last_failed_reason = None; self.reset_connection_attempts(); - self.calculate_quality_score(); + self.update_quality_score(); + self } /// Reset the connection attempts on this net address for a later session of retries @@ -234,10 +159,24 @@ impl MultiaddrWithStats { } /// Mark that a connection could not be established with this net address - pub fn mark_failed_connection_attempt(&mut self, error_string: String) { + pub fn mark_failed_connection_attempt(&mut self, error_string: String) -> &mut Self { self.connection_attempts += 1; self.last_failed_reason = Some(error_string); - self.calculate_quality_score(); + self.update_quality_score(); + self + } + + #[cfg(test)] + pub fn mark_last_attempted(&mut self, timestamp: NaiveDateTime) -> &mut Self { + self.last_attempted = Some(timestamp); + self.update_quality_score(); + self + } + + pub fn mark_last_attempted_now(&mut self) -> &mut Self { + self.last_attempted = Some(Utc::now().naive_utc()); + self.update_quality_score(); + self } /// Get as a Multiaddr @@ -245,18 +184,21 @@ impl MultiaddrWithStats { self.clone().address } - fn calculate_quality_score(&mut self) { - // Try these first + fn calculate_quality_score(&self) -> i32 { + // If we have never seen or attempted the peer, we start with a high score to ensure that if self.last_seen.is_none() && self.last_attempted.is_none() { - self.quality_score = 1000; - return; + return HIGH_QUALITY_SCORE; } let mut score_self = 0; - // explicitly truncate the latency to avoid casting problems - let avg_latency_millis = i32::try_from(self.avg_latency.as_millis()).unwrap_or(i32::MAX); - score_self += cmp::max(0, 100 - (avg_latency_millis / 100)); + if self.avg_latency.as_millis() == 0 { + score_self += 100; + } else { + // explicitly truncate the latency to avoid casting problems + let avg_latency_millis = i32::try_from(self.avg_latency.as_millis()).unwrap_or(i32::MAX); + score_self += cmp::max(0, 100i32.saturating_sub(avg_latency_millis / 100)); + } let last_seen_seconds: i32 = self .last_seen @@ -265,18 +207,58 @@ impl MultiaddrWithStats { .unwrap_or(0) .try_into() .unwrap_or(i32::MAX); - score_self += cmp::max(0, 100 - last_seen_seconds); + score_self += cmp::max(0, 100i32.saturating_sub(last_seen_seconds)); if self.last_failed_reason.is_some() { score_self -= 100; } - self.quality_score = score_self; + score_self + } + + fn update_quality_score(&mut self) { + self.quality_score = self.calculate_quality_score(); } pub fn source(&self) -> &PeerAddressSource { &self.source } + + pub fn last_seen(&self) -> Option { + self.last_seen + } + + pub fn connection_attempts(&self) -> u32 { + self.connection_attempts + } + + pub fn avg_initial_dial_time(&self) -> Duration { + self.avg_initial_dial_time + } + + pub fn initial_dial_time_sample_count(&self) -> u32 { + self.initial_dial_time_sample_count + } + + pub fn avg_latency(&self) -> Duration { + self.avg_latency + } + + pub fn latency_sample_count(&self) -> u32 { + self.latency_sample_count + } + + pub fn last_attempted(&self) -> Option { + self.last_attempted + } + + pub fn last_failed_reason(&self) -> Option<&str> { + self.last_failed_reason.as_deref() + } + + pub fn quality_score(&self) -> i32 { + self.quality_score + } } // Reliability ordering of net addresses: prioritize net addresses according to previous successful connections, @@ -284,7 +266,7 @@ impl MultiaddrWithStats { // priority, this ordering switch allows searching for, and updating of net addresses to be performed more efficiently impl Ord for MultiaddrWithStats { fn cmp(&self, other: &MultiaddrWithStats) -> Ordering { - self.quality_score.cmp(&other.quality_score).reverse() + self.quality_score.cmp(&other.quality_score) } } @@ -306,12 +288,91 @@ impl Hash for MultiaddrWithStats { } } -impl fmt::Display for MultiaddrWithStats { +impl Display for MultiaddrWithStats { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{}", self.address) } } +#[derive(Debug, Clone, Serialize, Deserialize, Eq)] +pub enum PeerAddressSource { + Config, + FromNodeIdentity { + peer_identity_claim: PeerIdentityClaim, + }, + FromPeerConnection { + peer_identity_claim: PeerIdentityClaim, + }, + FromDiscovery { + peer_identity_claim: PeerIdentityClaim, + }, + FromAnotherPeer { + peer_identity_claim: PeerIdentityClaim, + source_peer: CommsPublicKey, + }, + FromJoinMessage { + peer_identity_claim: PeerIdentityClaim, + }, +} + +impl PeerAddressSource { + pub fn is_config(&self) -> bool { + matches!(self, PeerAddressSource::Config) + } + + pub fn peer_identity_claim(&self) -> Option<&PeerIdentityClaim> { + match self { + PeerAddressSource::Config => None, + PeerAddressSource::FromNodeIdentity { peer_identity_claim } => Some(peer_identity_claim), + PeerAddressSource::FromPeerConnection { peer_identity_claim } => Some(peer_identity_claim), + PeerAddressSource::FromDiscovery { peer_identity_claim } => Some(peer_identity_claim), + PeerAddressSource::FromAnotherPeer { + peer_identity_claim, .. + } => Some(peer_identity_claim), + PeerAddressSource::FromJoinMessage { peer_identity_claim } => Some(peer_identity_claim), + } + } +} + +impl Display for PeerAddressSource { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + PeerAddressSource::Config => write!(f, "Config"), + PeerAddressSource::FromNodeIdentity { .. } => { + write!(f, "FromNodeIdentity") + }, + PeerAddressSource::FromPeerConnection { .. } => write!(f, "FromPeerConnection"), + PeerAddressSource::FromDiscovery { .. } => write!(f, "FromDiscovery"), + PeerAddressSource::FromAnotherPeer { .. } => write!(f, "FromAnotherPeer"), + PeerAddressSource::FromJoinMessage { .. } => write!(f, "FromJoinMessage"), + } + } +} + +impl PartialEq for PeerAddressSource { + fn eq(&self, other: &Self) -> bool { + match self { + PeerAddressSource::Config => { + matches!(other, PeerAddressSource::Config) + }, + PeerAddressSource::FromNodeIdentity { .. } => { + matches!(other, PeerAddressSource::FromNodeIdentity { .. }) + }, + PeerAddressSource::FromPeerConnection { .. } => { + matches!(other, PeerAddressSource::FromPeerConnection { .. }) + }, + PeerAddressSource::FromAnotherPeer { .. } => { + matches!(other, PeerAddressSource::FromAnotherPeer { .. }) + }, + PeerAddressSource::FromDiscovery { .. } => { + matches!(other, PeerAddressSource::FromDiscovery { .. }) + }, + PeerAddressSource::FromJoinMessage { .. } => { + matches!(other, PeerAddressSource::FromJoinMessage { .. }) + }, + } + } +} #[cfg(test)] mod test { use std::time::Duration; @@ -359,4 +420,21 @@ mod test { net_address_with_stats.reset_connection_attempts(); assert_eq!(net_address_with_stats.connection_attempts, 0); } + + #[test] + fn test_calculate_quality_score() { + let address = "/ip4/123.0.0.123/tcp/8000".parse().unwrap(); + let mut address = MultiaddrWithStats::new(address, PeerAddressSource::Config); + assert_eq!(address.quality_score, 1000); + address.mark_last_seen_now(); + assert!(address.quality_score > 100); + address.mark_failed_connection_attempt("Testing".to_string()); + assert!(address.quality_score <= 100); + + let another_addr = "/ip4/1.0.0.1/tcp/8000".parse().unwrap(); + let another_addr = MultiaddrWithStats::new(another_addr, PeerAddressSource::Config); + assert_eq!(another_addr.quality_score, 1000); + + assert_eq!(another_addr.cmp(&address), Ordering::Greater); + } } diff --git a/comms/core/src/net_address/mutliaddresses_with_stats.rs b/comms/core/src/net_address/mutliaddresses_with_stats.rs index ca751c8d5c..ea08165ba3 100644 --- a/comms/core/src/net_address/mutliaddresses_with_stats.rs +++ b/comms/core/src/net_address/mutliaddresses_with_stats.rs @@ -2,12 +2,13 @@ // SPDX-License-Identifier: BSD-3-Clause use std::{ + cmp, fmt::{Display, Formatter}, ops::Index, time::Duration, }; -use chrono::{NaiveDateTime, Utc}; +use chrono::NaiveDateTime; use multiaddr::Multiaddr; use serde::{Deserialize, Serialize}; @@ -48,76 +49,39 @@ impl MultiaddressesWithStats { /// Provides the date and time of the last successful communication with this peer pub fn last_seen(&self) -> Option { - let mut latest_valid_datetime: Option = None; - for curr_address in &self.addresses { - if curr_address.last_seen.is_none() { - continue; - } - match latest_valid_datetime { - Some(latest_datetime) => { - if latest_datetime < curr_address.last_seen.unwrap() { - latest_valid_datetime = curr_address.last_seen; - } - }, - None => latest_valid_datetime = curr_address.last_seen, - } - } - latest_valid_datetime + self.addresses + .iter() + .max_by_key(|a| a.last_seen()) + .and_then(|a| a.last_seen()) } pub fn offline_at(&self) -> Option { - let mut earliest_offline_at: Option = None; - for curr_address in &self.addresses { - // At least one address is online - #[allow(clippy::question_mark)] - if curr_address.offline_at().is_none() { - return None; - } - match earliest_offline_at { - Some(earliest_datetime) => { - if earliest_datetime > curr_address.offline_at().unwrap() { - earliest_offline_at = curr_address.offline_at(); - } - }, - None => earliest_offline_at = curr_address.offline_at(), - } - } - earliest_offline_at + self.addresses + .iter() + .min_by_key(|a| a.offline_at()) + .and_then(|a| a.offline_at()) } /// Return the time of last attempted connection to this collection of addresses pub fn last_attempted(&self) -> Option { - let mut latest_valid_datetime: Option = None; - for curr_address in &self.addresses { - if curr_address.last_attempted.is_none() { - continue; - } - match latest_valid_datetime { - Some(latest_datetime) => { - if latest_datetime < curr_address.last_attempted.unwrap() { - latest_valid_datetime = curr_address.last_attempted; - } - }, - None => latest_valid_datetime = curr_address.last_attempted, - } - } - latest_valid_datetime + self.addresses + .iter() + .max_by_key(|a| a.last_attempted()) + .and_then(|a| a.last_attempted()) } /// Adds a new net address to the peer. This function will not add a duplicate if the address /// already exists. pub fn add_address(&mut self, net_address: &Multiaddr, source: &PeerAddressSource) { - if self.addresses.iter().any(|x| x.address() == net_address) { - self.addresses - .iter_mut() - .find(|x| x.address() == net_address) - .unwrap() - .update_source_if_better(source); + if let Some(addr_mut) = self.addresses.iter_mut().find(|x| x.address() == net_address) { + addr_mut.update_source_if_better(source); } else { self.addresses .push(MultiaddrWithStats::new(net_address.clone(), source.clone())); - self.addresses.sort(); } + + // Ensure that the addresses are sorted by quality + self.sort_addresses(); } pub fn contains(&self, net_address: &Multiaddr) -> bool { @@ -143,25 +107,21 @@ impl MultiaddressesWithStats { .push(MultiaddrWithStats::new(address.clone(), source.clone())); } - self.addresses.sort(); + self.sort_addresses(); + } + + /// Returns an iterator of addresses with states ordered from 'best' to 'worst' according to heuristics such as + /// failed connections and latency. + pub fn iter(&self) -> impl Iterator { + self.addresses.iter() } /// Returns an iterator of addresses ordered from 'best' to 'worst' according to heuristics such as failed /// connections and latency. - pub fn iter(&self) -> impl Iterator { + pub fn address_iter(&self) -> impl Iterator { self.addresses.iter().map(|addr| addr.address()) } - pub fn to_lexicographically_sorted(&self) -> Vec { - let mut addresses = self.iter().cloned().collect::>(); - addresses.sort_by(|a, b| { - let bytes_a = a.as_ref(); - let bytes_b = b.as_ref(); - bytes_a.cmp(bytes_b) - }); - addresses - } - pub fn merge(&mut self, other: &MultiaddressesWithStats) { for addr in &other.addresses { if let Some(existing) = self.find_address_mut(addr.address()) { @@ -185,7 +145,7 @@ impl MultiaddressesWithStats { match self.find_address_mut(address) { Some(addr) => { addr.update_latency(latency_measurement); - self.addresses.sort(); + self.sort_addresses(); true }, None => false, @@ -196,7 +156,7 @@ impl MultiaddressesWithStats { where F: FnOnce(&mut MultiaddrWithStats) { if let Some(addr) = self.find_address_mut(address) { f(addr); - self.addresses.sort(); + self.sort_addresses(); } } @@ -206,9 +166,8 @@ impl MultiaddressesWithStats { pub fn mark_last_seen_now(&mut self, address: &Multiaddr) -> bool { match self.find_address_mut(address) { Some(addr) => { - addr.mark_last_seen_now(); - addr.last_attempted = Some(Utc::now().naive_utc()); - self.addresses.sort(); + addr.mark_last_seen_now().mark_last_attempted_now(); + self.sort_addresses(); true }, None => false, @@ -222,8 +181,8 @@ impl MultiaddressesWithStats { match self.find_address_mut(address) { Some(addr) => { addr.mark_failed_connection_attempt(failed_reason); - addr.last_attempted = Some(Utc::now().naive_utc()); - self.addresses.sort(); + addr.mark_last_attempted_now(); + self.sort_addresses(); true }, None => false, @@ -237,7 +196,7 @@ impl MultiaddressesWithStats { for a in &mut self.addresses { a.reset_connection_attempts(); } - self.addresses.sort(); + self.sort_addresses(); } /// Returns the number of addresses @@ -257,6 +216,11 @@ impl MultiaddressesWithStats { pub fn addresses(&self) -> &[MultiaddrWithStats] { &self.addresses } + + /// Sort the addresses with the greatest quality score first + fn sort_addresses(&mut self) { + self.addresses.sort_by_key(|addr| cmp::Reverse(addr.quality_score())) + } } impl PartialEq for MultiaddressesWithStats { @@ -342,8 +306,8 @@ mod test { let desired_last_seen = net_addresses .addresses .iter() - .max_by_key(|a| a.last_seen) - .map(|a| a.last_seen.unwrap()); + .max_by_key(|a| a.last_seen()) + .map(|a| a.last_seen().unwrap()); let last_seen = net_addresses.last_seen(); assert_eq!(desired_last_seen.unwrap(), last_seen.unwrap()); } @@ -357,7 +321,7 @@ mod test { MultiaddressesWithStats::from_addresses_with_source(vec![net_address1.clone()], &PeerAddressSource::Config); net_addresses.add_address(&net_address2, &PeerAddressSource::Config); net_addresses.add_address(&net_address3, &PeerAddressSource::Config); - // Add duplicate address, test add_net_address is idempotent + // Add duplicate address, this resets the quality score net_addresses.add_address(&net_address2, &PeerAddressSource::Config); assert_eq!(net_addresses.addresses.len(), 3); assert_eq!(net_addresses.addresses[0].address(), &net_address1); @@ -375,7 +339,7 @@ mod test { net_addresses.add_address(&net_address2, &PeerAddressSource::Config); net_addresses.add_address(&net_address3, &PeerAddressSource::Config); - let priority_address = net_addresses.iter().next().unwrap(); + let priority_address = net_addresses.address_iter().next().unwrap(); assert_eq!(priority_address, &net_address1); net_addresses.mark_last_seen_now(&net_address1); @@ -384,11 +348,11 @@ mod test { assert!(net_addresses.update_latency(&net_address1, Duration::from_millis(250))); assert!(net_addresses.update_latency(&net_address2, Duration::from_millis(50))); assert!(net_addresses.update_latency(&net_address3, Duration::from_millis(100))); - let priority_address = net_addresses.iter().next().unwrap(); + let priority_address = net_addresses.address_iter().next().unwrap(); assert_eq!(priority_address, &net_address2); assert!(net_addresses.mark_failed_connection_attempt(&net_address2, "error".to_string())); - let priority_address = net_addresses.iter().next().unwrap(); + let priority_address = net_addresses.address_iter().next().unwrap(); assert_eq!(priority_address, &net_address3); } @@ -408,12 +372,12 @@ mod test { assert!(net_addresses.mark_failed_connection_attempt(&net_address3, "error".to_string())); assert!(net_addresses.mark_failed_connection_attempt(&net_address1, "error".to_string())); - assert_eq!(net_addresses.addresses[0].connection_attempts, 1); - assert_eq!(net_addresses.addresses[1].connection_attempts, 1); - assert_eq!(net_addresses.addresses[2].connection_attempts, 2); + assert_eq!(net_addresses.addresses[0].connection_attempts(), 1); + assert_eq!(net_addresses.addresses[1].connection_attempts(), 1); + assert_eq!(net_addresses.addresses[2].connection_attempts(), 2); net_addresses.reset_connection_attempts(); - assert_eq!(net_addresses.addresses[0].connection_attempts, 0); - assert_eq!(net_addresses.addresses[1].connection_attempts, 0); - assert_eq!(net_addresses.addresses[2].connection_attempts, 0); + assert_eq!(net_addresses.addresses[0].connection_attempts(), 0); + assert_eq!(net_addresses.addresses[1].connection_attempts(), 0); + assert_eq!(net_addresses.addresses[2].connection_attempts(), 0); } } diff --git a/comms/core/src/peer_manager/error.rs b/comms/core/src/peer_manager/error.rs index 33443745f8..ca8c72fe1e 100644 --- a/comms/core/src/peer_manager/error.rs +++ b/comms/core/src/peer_manager/error.rs @@ -22,9 +22,12 @@ use std::sync::PoisonError; +use multiaddr::Multiaddr; use tari_storage::KeyValStoreError; use thiserror::Error; +use crate::peer_manager::NodeId; + /// Error type for [PeerManager](super::PeerManager). #[derive(Debug, Error, Clone)] pub enum PeerManagerError { @@ -46,6 +49,10 @@ pub enum PeerManagerError { MultiaddrError(String), #[error("Unable to parse any of the network addresses offered by the connecting peer")] PeerIdentityNoValidAddresses, + #[error("Invalid peer feature bits '{bits:#x}'")] + InvalidPeerFeatures { bits: u32 }, + #[error("Address {address} not found for peer {node_id}")] + AddressNotFoundError { address: Multiaddr, node_id: NodeId }, } impl PeerManagerError { diff --git a/comms/core/src/peer_manager/manager.rs b/comms/core/src/peer_manager/manager.rs index 205ce3a123..f455d456e2 100644 --- a/comms/core/src/peer_manager/manager.rs +++ b/comms/core/src/peer_manager/manager.rs @@ -236,24 +236,25 @@ impl PeerManager { .closest_peers(node_id, n, excluded_peers, features) } - pub async fn mark_last_seen( - &self, - node_id: &NodeId, - addr: &Multiaddr, - source: &PeerAddressSource, - ) -> Result<(), PeerManagerError> { + pub async fn mark_last_seen(&self, node_id: &NodeId, addr: &Multiaddr) -> Result<(), PeerManagerError> { let mut lock = self.peer_storage.write().await; - let peer = lock.find_by_node_id(node_id)?; - match peer { - Some(mut peer) => { - // if we have an address, update it - peer.addresses.add_address(addr, source); - peer.addresses.mark_last_seen_now(addr); - lock.add_peer(peer)?; - Ok(()) - }, - None => Err(PeerManagerError::PeerNotFoundError), - } + let mut peer = lock + .find_by_node_id(node_id)? + .ok_or(PeerManagerError::PeerNotFoundError)?; + let source = peer + .addresses + .iter() + .find(|a| a.address() == addr) + .map(|a| a.source().clone()) + .ok_or_else(|| PeerManagerError::AddressNotFoundError { + address: addr.clone(), + node_id: node_id.clone(), + })?; + // if we have an address, update it + peer.addresses.add_address(addr, &source); + peer.addresses.mark_last_seen_now(addr); + lock.add_peer(peer)?; + Ok(()) } /// Fetch n random peers diff --git a/comms/core/src/peer_manager/mod.rs b/comms/core/src/peer_manager/mod.rs index 1eccafa8d2..dba01e9d69 100644 --- a/comms/core/src/peer_manager/mod.rs +++ b/comms/core/src/peer_manager/mod.rs @@ -70,9 +70,6 @@ //! let returned_peer = peer_manager.find_by_node_id(&node_id).unwrap(); //! ``` -/// The maximum size of the peer's user agent string. If the peer sends a longer string it is truncated. -const MAX_USER_AGENT_LEN: usize = 100; - mod error; pub use error::PeerManagerError; diff --git a/comms/core/src/peer_manager/node_identity.rs b/comms/core/src/peer_manager/node_identity.rs index 69c82818fe..b0a85f9f06 100644 --- a/comms/core/src/peer_manager/node_identity.rs +++ b/comms/core/src/peer_manager/node_identity.rs @@ -235,7 +235,6 @@ impl NodeIdentity { &self.public_addresses(), Utc::now(), ), - unverified_data: None, }; let peer = Peer::new( self.public_key().clone(), diff --git a/comms/core/src/peer_manager/peer.rs b/comms/core/src/peer_manager/peer.rs index b21fb14e23..3c4d9268c3 100644 --- a/comms/core/src/peer_manager/peer.rs +++ b/comms/core/src/peer_manager/peer.rs @@ -159,20 +159,11 @@ impl Peer { } pub fn last_connect_attempt(&self) -> Option { - let mut last_connected_attempt = None; - for address in self.addresses.addresses() { - if let Some(address_time) = address.last_attempted { - match last_connected_attempt { - Some(last_time) => { - if last_time < address_time { - last_connected_attempt = address.last_attempted - } - }, - None => last_connected_attempt = address.last_attempted, - } - } - } - last_connected_attempt + self.addresses + .addresses() + .iter() + .max_by_key(|a| a.last_attempted()) + .and_then(|a| a.last_attempted()) } /// Returns true if the peer is marked as offline @@ -212,11 +203,6 @@ impl Peer { .and_then(|dt| Utc::now().naive_utc().signed_duration_since(dt).to_std().ok()) } - /// Returns true if this peer has the given feature, otherwise false - pub fn has_features(&self, features: PeerFeatures) -> bool { - self.features.contains(features) - } - /// Returns the ban status of the peer pub fn is_banned(&self) -> bool { self.banned_until().is_some() diff --git a/comms/core/src/peer_manager/peer_features.rs b/comms/core/src/peer_manager/peer_features.rs index e41d30177c..70bff86df4 100644 --- a/comms/core/src/peer_manager/peer_features.rs +++ b/comms/core/src/peer_manager/peer_features.rs @@ -28,7 +28,7 @@ use serde::{Deserialize, Serialize}; bitflags! { /// Peer feature flags. These advertised the capabilities of peer nodes. #[derive(Serialize, Deserialize)] - pub struct PeerFeatures: u64 { + pub struct PeerFeatures: u32 { /// No capabilities const NONE = 0b0000_0000; /// Node is able to propagate messages diff --git a/comms/core/src/peer_manager/peer_identity_claim.rs b/comms/core/src/peer_manager/peer_identity_claim.rs index 78e5f06c47..7112f7bd82 100644 --- a/comms/core/src/peer_manager/peer_identity_claim.rs +++ b/comms/core/src/peer_manager/peer_identity_claim.rs @@ -26,9 +26,9 @@ use multiaddr::Multiaddr; use serde_derive::{Deserialize, Serialize}; use crate::{ - peer_manager::{IdentitySignature, PeerFeatures, PeerManagerError, MAX_USER_AGENT_LEN}, + peer_manager::{IdentitySignature, PeerFeatures, PeerManagerError}, proto::identity::PeerIdentityMsg, - protocol::ProtocolId, + types::CommsPublicKey, }; #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] @@ -36,51 +36,31 @@ pub struct PeerIdentityClaim { pub addresses: Vec, pub features: PeerFeatures, pub signature: IdentitySignature, - pub unverified_data: Option, } impl PeerIdentityClaim { - pub fn new( - addresses: Vec, - features: PeerFeatures, - signature: IdentitySignature, - unverified_data: Option, - ) -> Self { + pub fn new(addresses: Vec, features: PeerFeatures, signature: IdentitySignature) -> Self { Self { addresses, features, signature, - unverified_data, } } - pub fn supported_protocols(&self) -> Vec { - self.unverified_data - .as_ref() - .map(|d| d.supported_protocols.clone()) - .unwrap_or_default() - } - - pub fn user_agent(&self) -> Option { - self.unverified_data.as_ref().map(|d| d.user_agent.clone()) + pub fn is_valid(&self, public_key: &CommsPublicKey) -> bool { + self.signature.is_valid(public_key, self.features, &self.addresses) } } -#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] -pub struct PeerIdentityClaimUnverifiedData { - pub user_agent: String, - pub supported_protocols: Vec, -} - impl TryFrom for PeerIdentityClaim { type Error = PeerManagerError; fn try_from(value: PeerIdentityMsg) -> Result { - let addresses: Vec = value + let addresses = value .addresses - .iter() - .map(|addr_bytes| Multiaddr::try_from(addr_bytes.clone())) - .collect::>() + .into_iter() + .map(Multiaddr::try_from) + .collect::, _>>() .map_err(|e| PeerManagerError::MultiaddrError(e.to_string()))?; if addresses.is_empty() { @@ -88,24 +68,11 @@ impl TryFrom for PeerIdentityClaim { } let features = PeerFeatures::from_bits_truncate(value.features); - let supported_protocols = value - .supported_protocols - .iter() - .map(|p| bytes::Bytes::from(p.clone())) - .collect::>(); - - let mut user_agent = value.user_agent; - user_agent.truncate(MAX_USER_AGENT_LEN); - if let Some(signature) = value.identity_signature { Ok(Self { addresses, features, signature: signature.try_into()?, - unverified_data: Some(PeerIdentityClaimUnverifiedData { - user_agent, - supported_protocols, - }), }) } else { Err(PeerManagerError::MissingIdentitySignature) diff --git a/comms/core/src/peer_manager/peer_storage.rs b/comms/core/src/peer_manager/peer_storage.rs index dc611cdf7f..9a13381337 100644 --- a/comms/core/src/peer_manager/peer_storage.rs +++ b/comms/core/src/peer_manager/peer_storage.rs @@ -853,7 +853,7 @@ mod test { let mut not_active_peer = create_test_peer(PeerFeatures::COMMUNICATION_NODE, false); let address = not_active_peer.addresses.best().unwrap(); let mut address = MultiaddrWithStats::new(address.address().clone(), PeerAddressSource::Config); - address.last_seen = Some(NaiveDateTime::from_timestamp_opt(a_week_ago, 0).unwrap()); + address.mark_last_attempted(NaiveDateTime::from_timestamp_opt(a_week_ago, 0).unwrap()); not_active_peer .addresses .merge(&MultiaddressesWithStats::from(vec![address])); diff --git a/comms/core/src/peer_validator/config.rs b/comms/core/src/peer_validator/config.rs new file mode 100644 index 0000000000..c5e04a7806 --- /dev/null +++ b/comms/core/src/peer_validator/config.rs @@ -0,0 +1,54 @@ +// Copyright 2023 The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +use serde_derive::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] +pub struct PeerValidatorConfig { + /// The maximum size of the peer's user agent string. Some unicode characters use more than a single byte + /// and this specifies the maximum in bytes, as opposed to unicode characters. + pub max_user_agent_byte_length: usize, + pub max_permitted_peer_addresses_per_claim: usize, + pub max_supported_protocols: usize, + pub max_protocol_id_length: usize, + + /// Set to true to allow peers to send loopback, local-link and other addresses normally not considered valid for + /// peer-to-peer comms. Default: false + pub allow_test_addresses: bool, +} + +impl Default for PeerValidatorConfig { + fn default() -> Self { + Self { + max_user_agent_byte_length: 50, + max_permitted_peer_addresses_per_claim: 5, + max_supported_protocols: 20, + max_protocol_id_length: 50, + #[cfg(not(test))] + allow_test_addresses: false, + // This must always be true for internal crate tests + #[cfg(test)] + allow_test_addresses: true, + } + } +} diff --git a/comms/core/src/peer_validator/error.rs b/comms/core/src/peer_validator/error.rs new file mode 100644 index 0000000000..8afe01df4a --- /dev/null +++ b/comms/core/src/peer_validator/error.rs @@ -0,0 +1,56 @@ +// Copyright 2023, The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +use std::time::Duration; + +use crate::{bans::BAN_DURATION_LONG, peer_manager::NodeId}; + +/// Validation errors for peers shared on the network +#[derive(Debug, Clone, thiserror::Error)] +pub enum PeerValidatorError { + #[error("Peer signature was invalid for peer '{peer}'")] + InvalidPeerSignature { peer: NodeId }, + #[error("One or more peer addresses were invalid for '{peer}'")] + InvalidPeerAddresses { peer: NodeId }, + #[error("Peer '{peer}' was banned")] + PeerHasNoAddresses { peer: NodeId }, + #[error("Invalid multiaddr: {0}")] + InvalidMultiaddr(String), + #[error("No public addresses provided")] + PeerIdentityNoAddresses, + #[error("Onion v2 is deprecated and not supported")] + OnionV2NotSupported, + #[error("Peer provided too many supported protocols: expected max {max} but got {length}")] + PeerIdentityTooManyProtocols { length: usize, max: usize }, + #[error("Peer provided too many addresses: expected max {max} but got {length}")] + PeerIdentityTooManyAddresses { length: usize, max: usize }, + #[error("Peer provided a protocol id that exceeds the maximum length: expected max {max} but got {length}")] + PeerIdentityProtocolIdTooLong { length: usize, max: usize }, + #[error("Peer provided a user agent that exceeds the maximum length: expected max {max} but got {length}")] + PeerIdentityUserAgentTooLong { length: usize, max: usize }, +} + +impl PeerValidatorError { + pub fn as_ban_duration(&self) -> Option { + Some(BAN_DURATION_LONG) + } +} diff --git a/comms/core/src/peer_validator/helpers.rs b/comms/core/src/peer_validator/helpers.rs new file mode 100644 index 0000000000..99e1eaa386 --- /dev/null +++ b/comms/core/src/peer_validator/helpers.rs @@ -0,0 +1,288 @@ +// Copyright 2020, The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +use std::net::Ipv6Addr; + +use digest::Digest; + +use crate::{ + multiaddr::{Multiaddr, Protocol}, + peer_manager::{NodeId, PeerIdentityClaim}, + peer_validator::{error::PeerValidatorError, PeerValidatorConfig}, + types::CommsPublicKey, +}; + +/// Checks that the given peer addresses are well-formed and valid. If allow_test_addrs is false, all localhost and +/// memory addresses will be rejected. +pub fn validate_addresses(config: &PeerValidatorConfig, addresses: &[Multiaddr]) -> Result<(), PeerValidatorError> { + if addresses.is_empty() { + return Err(PeerValidatorError::PeerIdentityNoAddresses); + } + + if addresses.len() > config.max_permitted_peer_addresses_per_claim { + return Err(PeerValidatorError::PeerIdentityTooManyAddresses { + length: addresses.len(), + max: config.max_permitted_peer_addresses_per_claim, + }); + } + for addr in addresses { + validate_address(addr, config.allow_test_addresses)?; + } + + Ok(()) +} + +pub fn find_most_recent_claim<'a, I: IntoIterator>( + claims: I, +) -> Option<&'a PeerIdentityClaim> { + claims.into_iter().max_by_key(|c| c.signature.updated_at()) +} + +pub fn validate_peer_identity_claim( + config: &PeerValidatorConfig, + public_key: &CommsPublicKey, + claim: &PeerIdentityClaim, +) -> Result<(), PeerValidatorError> { + validate_addresses(config, &claim.addresses)?; + + if !claim.is_valid(public_key) { + return Err(PeerValidatorError::InvalidPeerSignature { + peer: NodeId::from_public_key(public_key), + }); + } + + Ok(()) +} +fn validate_address(addr: &Multiaddr, allow_test_addrs: bool) -> Result<(), PeerValidatorError> { + let mut addr_iter = addr.iter(); + let proto = addr_iter + .next() + .ok_or_else(|| PeerValidatorError::InvalidMultiaddr("Multiaddr was empty".to_string()))?; + + /// Returns [true] if the address is a unicast link-local address (fe80::/10). + /// Taken from stdlib + #[inline] + const fn is_unicast_link_local(addr: &Ipv6Addr) -> bool { + (addr.segments()[0] & 0xffc0) == 0xfe80 + } + + match proto { + Protocol::Dns4(_) | Protocol::Dns6(_) | Protocol::Dnsaddr(_) => { + let tcp = addr_iter.next().ok_or_else(|| { + PeerValidatorError::InvalidMultiaddr("Address does not include a TCP port".to_string()) + })?; + + validate_tcp_port(tcp)?; + expect_end_of_address(addr_iter) + }, + + Protocol::Ip4(addr) + if !allow_test_addrs && (addr.is_loopback() || addr.is_link_local() || addr.is_unspecified()) => + { + Err(PeerValidatorError::InvalidMultiaddr( + "Non-global IP addresses are invalid".to_string(), + )) + }, + Protocol::Ip6(addr) + if !allow_test_addrs && (addr.is_loopback() || is_unicast_link_local(&addr) || addr.is_unspecified()) => + { + Err(PeerValidatorError::InvalidMultiaddr( + "Non-global IP addresses are invalid".to_string(), + )) + }, + Protocol::Ip4(_) | Protocol::Ip6(_) => { + let tcp = addr_iter.next().ok_or_else(|| { + PeerValidatorError::InvalidMultiaddr("Address does not include a TCP port".to_string()) + })?; + + validate_tcp_port(tcp)?; + expect_end_of_address(addr_iter) + }, + Protocol::Memory(0) => Err(PeerValidatorError::InvalidMultiaddr( + "Cannot connect to a zero memory port".to_string(), + )), + Protocol::Memory(_) if allow_test_addrs => expect_end_of_address(addr_iter), + Protocol::Memory(_) => Err(PeerValidatorError::InvalidMultiaddr( + "Memory addresses are invalid".to_string(), + )), + // Zero-port onions should have already failed when parsing. Keep these checks here just in case. + Protocol::Onion(_, 0) => Err(PeerValidatorError::InvalidMultiaddr( + "A zero onion port is not valid in the onion spec".to_string(), + )), + Protocol::Onion3(addr) if addr.port() == 0 => Err(PeerValidatorError::InvalidMultiaddr( + "A zero onion port is not valid in the onion spec".to_string(), + )), + Protocol::Onion(_, _) => Err(PeerValidatorError::OnionV2NotSupported), + Protocol::Onion3(addr) => { + expect_end_of_address(addr_iter)?; + validate_onion3_address(&addr) + }, + p => Err(PeerValidatorError::InvalidMultiaddr(format!( + "Unsupported address type '{}'", + p + ))), + } +} + +fn expect_end_of_address(mut iter: multiaddr::Iter<'_>) -> Result<(), PeerValidatorError> { + match iter.next() { + Some(p) => Err(PeerValidatorError::InvalidMultiaddr(format!( + "Unexpected multiaddress component '{}'", + p + ))), + None => Ok(()), + } +} + +fn validate_tcp_port(expected_tcp: Protocol) -> Result<(), PeerValidatorError> { + match expected_tcp { + Protocol::Tcp(0) => Err(PeerValidatorError::InvalidMultiaddr( + "Cannot connect to a zero TCP port".to_string(), + )), + Protocol::Tcp(_) => Ok(()), + p => Err(PeerValidatorError::InvalidMultiaddr(format!( + "Expected TCP address component but got '{}'", + p + ))), + } +} + +/// Validates the onion3 version and checksum as per https://github.com/torproject/torspec/blob/main/rend-spec-v3.txt#LL2258C6-L2258C6 +fn validate_onion3_address(addr: &multiaddr::Onion3Addr<'_>) -> Result<(), PeerValidatorError> { + const ONION3_PUBKEY_SIZE: usize = 32; + const ONION3_CHECKSUM_SIZE: usize = 2; + + let (pub_key, checksum_version) = addr.hash().split_at(ONION3_PUBKEY_SIZE); + let (checksum, version) = checksum_version.split_at(ONION3_CHECKSUM_SIZE); + + if version != b"\x03" { + return Err(PeerValidatorError::InvalidMultiaddr( + "Invalid version in onion address".to_string(), + )); + } + + let calculated_checksum = sha3::Sha3_256::new() + .chain_update(".onion checksum") + .chain_update(pub_key) + .chain_update(version) + .finalize(); + + if calculated_checksum[..2] != *checksum { + return Err(PeerValidatorError::InvalidMultiaddr( + "Invalid checksum in onion address".to_string(), + )); + } + + Ok(()) +} + +#[cfg(test)] +mod test { + use multiaddr::multiaddr; + + use super::*; + use crate::peer_validator::error::PeerValidatorError; + + #[test] + fn validate_address_strict() { + let valid = [ + multiaddr!(Ip4([172, 0, 0, 1]), Tcp(1u16)), + multiaddr!(Ip6([172, 0, 0, 1, 1, 1, 1, 1]), Tcp(1u16)), + "/onion3/vww6ybal4bd7szmgncyruucpgfkqahzddi37ktceo3ah7ngmcopnpyyd:1234" + .parse() + .unwrap(), + multiaddr!(Dnsaddr("mike-magic-nodes.com"), Tcp(1u16)), + ]; + + let invalid = &[ + "/onion/aaimaq4ygg2iegci:1234".parse().unwrap(), + multiaddr!(Ip4([127, 0, 0, 1]), Tcp(1u16)), + multiaddr!(Ip4([169, 254, 0, 1]), Tcp(1u16)), + multiaddr!(Ip4([172, 0, 0, 1])), + "/onion/aaimaq4ygg2iegci:1234/http".parse().unwrap(), + multiaddr!(Dnsaddr("mike-magic-nodes.com")), + multiaddr!(Memory(1234u64)), + multiaddr!(Memory(0u64)), + ]; + + for addr in valid { + validate_address(&addr, false).unwrap(); + } + for addr in invalid { + validate_address(addr, false).unwrap_err(); + } + } + + #[test] + fn validate_address_allow_test_addrs() { + let valid = [ + multiaddr!(Ip4([127, 0, 0, 1]), Tcp(1u16)), + multiaddr!(Ip4([169, 254, 0, 1]), Tcp(1u16)), + multiaddr!(Ip4([172, 0, 0, 1]), Tcp(1u16)), + multiaddr!(Ip6([172, 0, 0, 1, 1, 1, 1, 1]), Tcp(1u16)), + "/onion3/vww6ybal4bd7szmgncyruucpgfkqahzddi37ktceo3ah7ngmcopnpyyd:1234" + .parse() + .unwrap(), + multiaddr!(Dnsaddr("mike-magic-nodes.com"), Tcp(1u16)), + multiaddr!(Memory(1234u64)), + ]; + + let invalid = &[ + "/onion/aaimaq4ygg2iegci:1234".parse().unwrap(), + multiaddr!(Ip4([172, 0, 0, 1])), + "/onion/aaimaq4ygg2iegci:1234/http".parse().unwrap(), + multiaddr!(Dnsaddr("mike-magic-nodes.com")), + multiaddr!(Memory(0u64)), + ]; + + for addr in valid { + validate_address(&addr, true).unwrap(); + } + for addr in invalid { + validate_address(addr, true).unwrap_err(); + } + } + + #[test] + fn validate_onion3_checksum() { + let valid: Multiaddr = "/onion3/vww6ybal4bd7szmgncyruucpgfkqahzddi37ktceo3ah7ngmcopnpyyd:1234" + .parse() + .unwrap(); + + validate_address(&valid, false).unwrap(); + + // Change one byte + let invalid: Multiaddr = "/onion3/www6ybal4bd7szmgncyruucpgfkqahzddi37ktceo3ah7ngmcopnpyyd:1234" + .parse() + .unwrap(); + + validate_address(&invalid, false).unwrap_err(); + + // Randomly generated + let invalid: Multiaddr = "/onion3/pd6sf3mqkkkfrn4rk5odgcr2j5sn7m523a4tm7pzpuotk2b7rpuhaeym:1234" + .parse() + .unwrap(); + + let err = validate_address(&invalid, false).unwrap_err(); + assert!(matches!(err, PeerValidatorError::InvalidMultiaddr(_))); + } +} diff --git a/comms/core/src/peer_validator/mod.rs b/comms/core/src/peer_validator/mod.rs new file mode 100644 index 0000000000..7179fb3766 --- /dev/null +++ b/comms/core/src/peer_validator/mod.rs @@ -0,0 +1,31 @@ +// Copyright 2023, The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +mod error; +pub use error::*; + +mod config; +pub use config::*; + +mod helpers; + +pub use helpers::*; diff --git a/comms/core/src/proto/identity.proto b/comms/core/src/proto/identity.proto index c1ba966691..6cf8f16557 100644 --- a/comms/core/src/proto/identity.proto +++ b/comms/core/src/proto/identity.proto @@ -9,7 +9,7 @@ package tari.comms.identity; message PeerIdentityMsg { repeated bytes addresses = 1; - uint64 features = 2; + uint32 features = 2; // Note: not part of the signature repeated bytes supported_protocols = 3; // Note: not part of the signature diff --git a/comms/core/src/protocol/identity.rs b/comms/core/src/protocol/identity.rs index add125a624..278d181810 100644 --- a/comms/core/src/protocol/identity.rs +++ b/comms/core/src/protocol/identity.rs @@ -31,10 +31,11 @@ use tokio::{ }; use crate::{ + bans::{BAN_DURATION_LONG, BAN_DURATION_SHORT}, message::MessageExt, peer_manager::NodeIdentity, proto::identity::PeerIdentityMsg, - protocol::{NodeNetworkInfo, ProtocolError, ProtocolId}, + protocol::{NodeNetworkInfo, ProtocolId}, }; const LOG_TARGET: &str = "comms::protocol::identity"; @@ -77,26 +78,35 @@ where socket.flush().await?; // Receive the connecting node's identity - let (version, msg_bytes) = time::timeout(Duration::from_secs(10), read_protocol_frame(socket)).await??; - let identity_msg = PeerIdentityMsg::decode(Bytes::from(msg_bytes))?; - - if version > network_info.major_version { - warn!( - target: LOG_TARGET, - "Peer sent mismatching major protocol version '{}'. This node has version '{}'", - version, - network_info.major_version - ); - return Err(IdentityProtocolError::ProtocolVersionMismatch); - } + let (_, msg_bytes) = time::timeout( + Duration::from_secs(10), + read_protocol_frame(socket, network_info.major_version), + ) + .await??; + debug!( + target: LOG_TARGET, + "Identity message received {} bytes", + msg_bytes.len() + ); + let identity_msg = PeerIdentityMsg::decode(Bytes::from(msg_bytes))?; Ok(identity_msg) } -async fn read_protocol_frame(socket: &mut S) -> Result<(u8, Vec), IdentityProtocolError> { +async fn read_protocol_frame( + socket: &mut S, + max_supported_version: u8, +) -> Result<(u8, Vec), IdentityProtocolError> { let mut buf = [0u8; 3]; socket.read_exact(&mut buf).await?; let version = buf[0]; + if version > max_supported_version { + return Err(IdentityProtocolError::UnsupportedProtocolVersion { + max_supported_version, + provided_version: version, + }); + } + let buf = [buf[1], buf[2]]; let len = u16::from_le_bytes(buf); if len > MAX_IDENTITY_PROTOCOL_MSG_SIZE { @@ -105,6 +115,7 @@ async fn read_protocol_frame(socket: &mut S) -> Result<(u8 got: len, }); } + let len = len as usize; let mut msg = vec![0u8; len]; socket.read_exact(&mut msg).await?; @@ -116,17 +127,17 @@ async fn write_protocol_frame( version: u8, msg_bytes: &[u8], ) -> Result<(), IdentityProtocolError> { - debug_assert!( - msg_bytes.len() <= MAX_IDENTITY_PROTOCOL_MSG_SIZE as usize, - "Sending identity protocol message of size {}, greater than {} bytes. This is a protocol violation", - msg_bytes.len(), - MAX_IDENTITY_PROTOCOL_MSG_SIZE - ); + if msg_bytes.len() > MAX_IDENTITY_PROTOCOL_MSG_SIZE as usize { + return Err(IdentityProtocolError::InvariantError(format!( + "Sending identity protocol message of size {}, greater than {} bytes. This is a protocol violation", + msg_bytes.len(), + MAX_IDENTITY_PROTOCOL_MSG_SIZE + ))); + } let len = u16::try_from(msg_bytes.len()).map_err(|_| { - IdentityProtocolError::ProtocolError(format!( - "Identity protocol attempted to send a message larger than u16::MAX bytes. len = {}", - msg_bytes.len() + IdentityProtocolError::InvariantError(format!( + "This node attempted to send a message of size greater than u16::MAX" )) })?; let version_bytes = [version]; @@ -148,31 +159,46 @@ async fn write_protocol_frame( pub enum IdentityProtocolError { #[error("IoError: {0}")] IoError(String), - #[error("ProtocolError: {0}")] - ProtocolError(String), + #[error("Possible bug: InvariantError {0}")] + InvariantError(String), #[error("ProtobufDecodeError: {0}")] ProtobufDecodeError(String), - #[error("Failed to encode protobuf message")] - ProtobufEncodingError, #[error("Peer unexpectedly closed the connection")] PeerUnexpectedCloseConnection, #[error("Timeout waiting for peer to send identity information")] Timeout, - #[error("Protocol version mismatch")] - ProtocolVersionMismatch, + #[error( + "Unsupported protocol version. Max supported version: {max_supported_version}, provided version: \ + {provided_version}" + )] + UnsupportedProtocolVersion { + max_supported_version: u8, + provided_version: u8, + }, #[error("Max identity protocol message size exceeded. Expected <= {expected} got {got}")] MaxMsgSizeExceeded { expected: u16, got: u16 }, } -impl From for IdentityProtocolError { - fn from(_: time::error::Elapsed) -> Self { - IdentityProtocolError::Timeout +impl IdentityProtocolError { + pub fn as_ban_duration(&self) -> Option { + match self { + // Don't ban + IdentityProtocolError::InvariantError(_) | IdentityProtocolError::IoError(_) => None, + // Long bans + IdentityProtocolError::ProtobufDecodeError(_) | IdentityProtocolError::MaxMsgSizeExceeded { .. } => { + Some(BAN_DURATION_LONG) + }, + // Short bans + IdentityProtocolError::PeerUnexpectedCloseConnection | + IdentityProtocolError::UnsupportedProtocolVersion { .. } | + IdentityProtocolError::Timeout => Some(BAN_DURATION_SHORT), + } } } -impl From for IdentityProtocolError { - fn from(err: ProtocolError) -> Self { - IdentityProtocolError::ProtocolError(err.to_string()) +impl From for IdentityProtocolError { + fn from(_: time::error::Elapsed) -> Self { + IdentityProtocolError::Timeout } } @@ -297,7 +323,7 @@ mod test { .await; let err = result1.unwrap_err(); - assert!(matches!(err, IdentityProtocolError::ProtocolVersionMismatch)); + assert!(matches!(err, IdentityProtocolError::UnsupportedProtocolVersion { .. })); // Passes because older versions are supported result2.unwrap(); diff --git a/comms/core/src/test_utils/mocks/peer_connection.rs b/comms/core/src/test_utils/mocks/peer_connection.rs index 3bc66066ce..b288339ef8 100644 --- a/comms/core/src/test_utils/mocks/peer_connection.rs +++ b/comms/core/src/test_utils/mocks/peer_connection.rs @@ -56,7 +56,7 @@ pub fn create_dummy_peer_connection(node_id: NodeId) -> (PeerConnection, mpsc::R let (tx, rx) = mpsc::channel(1); let addr = Multiaddr::from_str("/ip4/23.23.23.23/tcp/80").unwrap(); ( - PeerConnection::unverified( + PeerConnection::new( 1, tx, node_id, @@ -92,7 +92,7 @@ pub async fn create_peer_connection_mock_pair( rt_handle.spawn(mock.run()); ( - PeerConnection::unverified( + PeerConnection::new( // ID must be unique since it is used for connection equivalency, so we re-implement this in the mock ID_COUNTER.fetch_add(1, Ordering::SeqCst), tx1, @@ -103,7 +103,7 @@ pub async fn create_peer_connection_mock_pair( mock_state_in.substream_counter(), ), mock_state_in, - PeerConnection::unverified( + PeerConnection::new( ID_COUNTER.fetch_add(1, Ordering::SeqCst), tx2, peer1.node_id, diff --git a/comms/core/src/test_utils/test_node.rs b/comms/core/src/test_utils/test_node.rs index 38b614c41f..d1b4f5dcf3 100644 --- a/comms/core/src/test_utils/test_node.rs +++ b/comms/core/src/test_utils/test_node.rs @@ -35,6 +35,7 @@ use crate::{ multiplexing::Substream, noise::NoiseConfig, peer_manager::{NodeIdentity, PeerFeatures, PeerManager}, + peer_validator::PeerValidatorConfig, protocol::Protocols, transports::Transport, }; @@ -56,7 +57,10 @@ impl Default for TestNodeConfig { Self { connection_manager_config: ConnectionManagerConfig { - allow_test_addresses: true, + peer_validation_config: PeerValidatorConfig { + allow_test_addresses: true, + ..Default::default() + }, listener_address: "/memory/0".parse().unwrap(), ..Default::default() }, diff --git a/comms/dht/src/actor.rs b/comms/dht/src/actor.rs index 1562c48460..194acd6993 100644 --- a/comms/dht/src/actor.rs +++ b/comms/dht/src/actor.rs @@ -118,6 +118,18 @@ pub enum DhtRequest { public_key: CommsPublicKey, reply: oneshot::Sender>, }, + BanPeer { + public_key: CommsPublicKey, + severity: OffenceSeverity, + reason: String, + }, +} + +#[derive(Debug, Clone, Copy)] +pub enum OffenceSeverity { + Low, + Medium, + High, } impl Display for DhtRequest { @@ -143,6 +155,15 @@ impl Display for DhtRequest { write!(f, "SetMetadata (key={}, value={} bytes)", key, value.len()) }, DialDiscoverPeer { public_key, .. } => write!(f, "DialDiscoverPeer(public_key={})", public_key), + BanPeer { + public_key, + severity, + reason, + } => write!( + f, + "BanPeer (peer={:#.5}, severity={:?}, reason={})", + public_key, severity, reason + ), } } } @@ -232,6 +253,21 @@ impl DhtRequester { .await?; reply_rx.await.map_err(|_| DhtActorError::ReplyCanceled)? } + + pub async fn ban_peer(&mut self, public_key: CommsPublicKey, severity: OffenceSeverity, reason: T) { + if self + .sender + .send(DhtRequest::BanPeer { + public_key, + severity, + reason: reason.to_string(), + }) + .await + .is_err() + { + debug!(target: LOG_TARGET, "DhtActor is shut down and no longer responding to requests. This is expected during shutdown."); + } + } } /// DHT actor. Responsible for executing DHT-related tasks. @@ -345,6 +381,7 @@ impl DhtActor { } } + #[allow(clippy::too_many_lines)] fn request_handler(&mut self, request: DhtRequest) -> BoxFuture<'static, Result<(), DhtActorError>> { #[allow(clippy::enum_glob_use)] use DhtRequest::*; @@ -435,6 +472,20 @@ impl DhtActor { Ok(()) }) }, + BanPeer { + public_key, + severity, + reason, + } => { + let mut connectivity = self.connectivity.clone(); + let ban_duration = self.config.ban_duration_from_severity(severity); + Box::pin(async move { + connectivity + .ban_peer_until(NodeId::from_public_key(&public_key), ban_duration, reason) + .await?; + Ok(()) + }) + }, } } diff --git a/comms/dht/src/config.rs b/comms/dht/src/config.rs index e1cb70f211..45557090d4 100644 --- a/comms/dht/src/config.rs +++ b/comms/dht/src/config.rs @@ -24,8 +24,10 @@ use std::{path::Path, time::Duration}; use serde::{Deserialize, Serialize}; use tari_common::configuration::serializers; +use tari_comms::peer_validator::PeerValidatorConfig; use crate::{ + actor::OffenceSeverity, network_discovery::NetworkDiscoveryConfig, storage::DbConnectionUrl, store_forward::SafConfig, @@ -89,9 +91,7 @@ pub struct DhtConfig { /// Default: 30 mins #[serde(with = "serializers::seconds")] pub ban_duration_short: Duration, - /// This allows the use of test addresses in the network. - /// Default: false - pub allow_test_addresses: bool, + /// The maximum number of messages over `flood_ban_timespan` to allow before banning the peer (for /// `ban_duration_short`) Default: 100_000 messages pub flood_ban_max_msg_count: usize, @@ -106,6 +106,12 @@ pub struct DhtConfig { /// Default: 24 hours #[serde(with = "serializers::seconds")] pub offline_peer_cooldown: Duration, + /// The maximum number of peer claims accepted by this node. Only peer sync sends more than one claim. + /// Default: 5 + pub max_permitted_peer_claims: usize, + /// Configuration for peer validation + /// See [PeerValidatorConfig] + pub peer_validator_config: PeerValidatorConfig, } impl DhtConfig { @@ -133,7 +139,10 @@ impl DhtConfig { enabled: false, ..Default::default() }, - allow_test_addresses: true, + peer_validator_config: PeerValidatorConfig { + allow_test_addresses: true, + ..Default::default() + }, ..Default::default() } } @@ -142,6 +151,14 @@ impl DhtConfig { pub fn set_base_path>(&mut self, base_path: P) { self.database_url.set_base_path(base_path); } + + /// Returns a ban duration from the given severity + pub fn ban_duration_from_severity(&self, severity: OffenceSeverity) -> Duration { + match severity { + OffenceSeverity::Low | OffenceSeverity::Medium => self.ban_duration_short, + OffenceSeverity::High => self.ban_duration, + } + } } impl Default for DhtConfig { @@ -166,10 +183,11 @@ impl Default for DhtConfig { network_discovery: Default::default(), ban_duration: Duration::from_secs(6 * 60 * 60), ban_duration_short: Duration::from_secs(60 * 60), - allow_test_addresses: false, flood_ban_max_msg_count: 100_000, flood_ban_timespan: Duration::from_secs(100), + max_permitted_peer_claims: 5, offline_peer_cooldown: Duration::from_secs(24 * 60 * 60), + peer_validator_config: Default::default(), } } } diff --git a/comms/dht/src/connectivity/mod.rs b/comms/dht/src/connectivity/mod.rs index 37ab6f5bf8..2002691c9e 100644 --- a/comms/dht/src/connectivity/mod.rs +++ b/comms/dht/src/connectivity/mod.rs @@ -39,7 +39,6 @@ use std::{sync::Arc, time::Instant}; use log::*; pub use metrics::{MetricsCollector, MetricsCollectorHandle}; use tari_comms::{ - connection_manager::ConnectionDirection, connectivity::{ ConnectivityError, ConnectivityEvent, @@ -48,7 +47,6 @@ use tari_comms::{ ConnectivitySelection, }, multiaddr, - net_address::PeerAddressSource, peer_manager::{NodeDistance, NodeId, PeerManagerError, PeerQuery, PeerQuerySortBy}, NodeIdentity, PeerConnection, @@ -65,8 +63,6 @@ const LOG_TARGET: &str = "comms::dht::connectivity"; /// Error type for the DHT connectivity actor. #[derive(Debug, Error)] pub enum DhtConnectivityError { - #[error("Peer connection did not have a peer identity claim")] - PeerConnectionMissingPeerIdentityClaim, #[error("ConnectivityError: {0}")] ConnectivityError(#[from] ConnectivityError), #[error("PeerManagerError: {0}")] @@ -493,20 +489,9 @@ impl DhtConnectivity { } async fn handle_new_peer_connected(&mut self, conn: PeerConnection) -> Result<(), DhtConnectivityError> { - if conn.direction() == ConnectionDirection::Outbound { - if let Some(peer_identity_claim) = conn.peer_identity_claim() { - self.peer_manager - .mark_last_seen( - conn.peer_node_id(), - conn.address(), - &PeerAddressSource::FromPeerConnection { - peer_identity_claim: peer_identity_claim.clone(), - }, - ) - .await?; - } else { - return Err(DhtConnectivityError::PeerConnectionMissingPeerIdentityClaim); - } + // We can only mark the peer as seen if we know which address we are were about to connect to (Outbound). + if let Some(addr) = conn.known_address() { + self.peer_manager.mark_last_seen(conn.peer_node_id(), addr).await?; } if conn.peer_features().is_client() { debug!( diff --git a/comms/dht/src/dht.rs b/comms/dht/src/dht.rs index 4280bcd21c..62f2d5a0f7 100644 --- a/comms/dht/src/dht.rs +++ b/comms/dht/src/dht.rs @@ -95,7 +95,7 @@ pub struct Dht { dht_sender: mpsc::Sender, /// Sender for SAF requests saf_sender: mpsc::Sender, - /// Sender for SAF repsonse signals + /// Sender for SAF response signals saf_response_signal_sender: mpsc::Sender<()>, /// Sender for DHT discovery requests discovery_sender: mpsc::Sender, @@ -197,6 +197,7 @@ impl Dht { self.config.clone(), Arc::clone(&self.node_identity), Arc::clone(&self.peer_manager), + self.dht_requester(), self.outbound_requester(), request_receiver, shutdown_signal, @@ -321,6 +322,7 @@ impl Dht { self.store_and_forward_requester(), )) .layer(ForwardLayer::new( + self.dht_requester(), self.outbound_requester(), self.node_identity.features().contains(PeerFeatures::DHT_STORE_FORWARD), )) @@ -336,6 +338,7 @@ impl Dht { self.config.clone(), self.node_identity.clone(), self.peer_manager.clone(), + self.dht_requester(), self.discovery_service_requester(), self.outbound_requester(), )) diff --git a/comms/dht/src/discovery/error.rs b/comms/dht/src/discovery/error.rs index 8862c6b266..f5951873e7 100644 --- a/comms/dht/src/discovery/error.rs +++ b/comms/dht/src/discovery/error.rs @@ -24,7 +24,10 @@ use tari_comms::peer_manager::PeerManagerError; use thiserror::Error; use tokio::sync::mpsc::error::SendError; -use crate::outbound::{message::SendFailure, DhtOutboundError}; +use crate::{ + outbound::{message::SendFailure, DhtOutboundError}, + peer_validator::DhtPeerValidatorError, +}; #[derive(Debug, Error)] pub enum DhtDiscoveryError { @@ -48,6 +51,10 @@ pub enum DhtDiscoveryError { NoSignatureProvided, #[error("Invalid signature: {0}")] InvalidSignature(String), + #[error("Invalid discovery response: {details}")] + InvalidDiscoveryResponse { details: anyhow::Error }, + #[error("DHT peer validator error: {0}")] + PeerValidatorError(#[from] DhtPeerValidatorError), } impl DhtDiscoveryError { diff --git a/comms/dht/src/discovery/service.rs b/comms/dht/src/discovery/service.rs index 728525b9e9..fe1319dc82 100644 --- a/comms/dht/src/discovery/service.rs +++ b/comms/dht/src/discovery/service.rs @@ -20,22 +20,14 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use std::{ - collections::HashMap, - convert::{TryFrom, TryInto}, - sync::Arc, - time::Instant, -}; +use std::{collections::HashMap, convert::TryFrom, sync::Arc, time::Instant}; use log::*; use rand::{rngs::OsRng, RngCore}; use tari_comms::{ log_if_error, - multiaddr::Multiaddr, - net_address::PeerAddressSource, - peer_manager::{NodeId, NodeIdentity, Peer, PeerFeatures, PeerIdentityClaim, PeerManager}, + peer_manager::{NodeIdentity, Peer, PeerManager}, types::CommsPublicKey, - validate_addresses, }; use tari_shutdown::ShutdownSignal; use tari_utilities::{hex::Hex, ByteArray}; @@ -45,11 +37,15 @@ use tokio::{ }; use crate::{ + actor::OffenceSeverity, discovery::{requester::DhtDiscoveryRequest, DhtDiscoveryError}, envelope::{DhtMessageType, NodeDestination}, outbound::{OutboundEncryption, OutboundMessageRequester, SendMessageParams}, + peer_validator::{DhtPeerValidatorError, PeerValidator}, proto::dht::{DiscoveryMessage, DiscoveryResponseMessage}, + rpc::UnvalidatedPeerInfo, DhtConfig, + DhtRequester, }; const LOG_TARGET: &str = "comms::dht::discovery_service"; @@ -75,6 +71,7 @@ pub struct DhtDiscoveryService { node_identity: Arc, outbound_requester: OutboundMessageRequester, peer_manager: Arc, + dht: DhtRequester, request_rx: mpsc::Receiver, shutdown_signal: ShutdownSignal, inflight_discoveries: HashMap, @@ -85,6 +82,7 @@ impl DhtDiscoveryService { config: Arc, node_identity: Arc, peer_manager: Arc, + dht: DhtRequester, outbound_requester: OutboundMessageRequester, request_rx: mpsc::Receiver, shutdown_signal: ShutdownSignal, @@ -93,6 +91,7 @@ impl DhtDiscoveryService { config, outbound_requester, node_identity, + dht, peer_manager, shutdown_signal, request_rx, @@ -140,7 +139,14 @@ impl DhtDiscoveryService { ); }, - NotifyDiscoveryResponseReceived(discovery_msg) => self.handle_discovery_response(discovery_msg).await, + NotifyDiscoveryResponseReceived(discovery_msg) => { + if let Err(err) = self.handle_discovery_response(discovery_msg).await { + error!( + target: LOG_TARGET, + "Failed to handle discovery response message because '{:?}'", err + ); + } + }, } } @@ -167,13 +173,10 @@ impl DhtDiscoveryService { requests } - async fn handle_discovery_response(&mut self, discovery_msg: Box) { - trace!( - target: LOG_TARGET, - "Received discovery response message from {}", - discovery_msg.public_key.to_hex() - ); - + async fn handle_discovery_response( + &mut self, + discovery_msg: Box, + ) -> Result<(), DhtDiscoveryError> { match self.inflight_discoveries.remove(&discovery_msg.nonce) { Some(request) => { let DiscoveryRequestState { @@ -182,7 +185,30 @@ impl DhtDiscoveryService { start_ts, } = request; - let result = self.validate_then_add_peer(&public_key, discovery_msg).await; + // Make sure that the response is for the expected public key + if discovery_msg.public_key.as_bytes() != public_key.as_bytes() { + warn!( + target: LOG_TARGET, + "Received a discovery response does not match the expected public key '{:#.5}'", + public_key + ); + self.dht + .ban_peer( + *public_key, + OffenceSeverity::Medium, + "Received a discovery response does not match the public key we requested", + ) + .await; + + return Ok(()); + } + trace!( + target: LOG_TARGET, + "Received discovery response message from {}", + public_key + ); + + let result = self.validate_then_add_peer(discovery_msg).await; // Resolve any other pending discover requests if the peer was found match &result { @@ -228,48 +254,46 @@ impl DhtDiscoveryService { ); }, } + + Ok(()) } async fn validate_then_add_peer( &mut self, - public_key: &CommsPublicKey, discovery_msg: Box, ) -> Result { - let node_id = NodeId::from_public_key(public_key); - - let addresses: Vec = discovery_msg - .addresses - .into_iter() - .map(Multiaddr::try_from) - .collect::>() - .map_err(|e| DhtDiscoveryError::InvalidPeerMultiaddr(e.to_string()))?; - - validate_addresses(&addresses, self.config.allow_test_addresses) - .map_err(|err| DhtDiscoveryError::InvalidPeerMultiaddr(err.to_string()))?; - - let peer_identity_claim = PeerIdentityClaim::new( - addresses.clone(), - PeerFeatures::from_bits_truncate(discovery_msg.peer_features), - discovery_msg - .identity_signature - .ok_or(DhtDiscoveryError::NoSignatureProvided)? - .try_into() - .map_err(|e: anyhow::Error| DhtDiscoveryError::InvalidSignature(e.to_string()))?, - None, - ); - - let peer = self - .peer_manager - .add_or_update_online_peer( - public_key, - node_id, - addresses, - PeerFeatures::from_bits_truncate(discovery_msg.peer_features), - &PeerAddressSource::FromDiscovery { peer_identity_claim }, - ) + let validator = PeerValidator::new(&self.config); + let info = UnvalidatedPeerInfo::try_from(*discovery_msg) + .map_err(|e| DhtDiscoveryError::InvalidDiscoveryResponse { details: e })?; + let public_key = info.public_key.clone(); + let existing_peer = self.peer_manager.find_by_public_key(&public_key).await?; + let valid_peer = self + .ban_offence(&public_key, validator.validate_peer(info, existing_peer)) .await?; + self.peer_manager.add_peer(valid_peer.clone()).await?; - Ok(peer) + Ok(valid_peer) + } + + async fn ban_offence( + &mut self, + public_key: &CommsPublicKey, + result: Result, + ) -> Result { + match result { + Ok(peer) => Ok(peer), + // Banned is an interesting case - if the peer is banned and we reban them, it will modify the original + // ban either longer or shorter. It is possible for a banned peer to send a secret message to us + // through another peer. + // TODO: perhaps connectivity manager should only make longer bans when this is called, or do nothing if + // shorter. + Err(err @ DhtPeerValidatorError::Banned { .. }) => Err(err), + Err(err @ DhtPeerValidatorError::IdentityTooManyClaims { .. }) | + Err(err @ DhtPeerValidatorError::ValidatorError(_)) => { + self.dht.ban_peer(public_key.clone(), OffenceSeverity::High, &err).await; + Err(err) + }, + } } async fn initiate_peer_discovery( @@ -365,7 +389,7 @@ mod test { use crate::{ discovery::DhtDiscoveryRequester, outbound::mock::create_outbound_service_mock, - test_utils::{build_peer_manager, make_node_identity}, + test_utils::{build_peer_manager, create_dht_actor_mock, make_node_identity}, }; #[tokio::test] @@ -380,11 +404,13 @@ mod test { // Requester which timeout instantly let mut requester = DhtDiscoveryRequester::new(sender, Duration::from_millis(1)); let shutdown = Shutdown::new(); + let (dht, _mock) = create_dht_actor_mock(1); DhtDiscoveryService::new( Default::default(), node_identity, peer_manager, + dht, outbound_requester, receiver, shutdown.to_signal(), diff --git a/comms/dht/src/envelope.rs b/comms/dht/src/envelope.rs index 01a7a06c1c..e507cfc00a 100644 --- a/comms/dht/src/envelope.rs +++ b/comms/dht/src/envelope.rs @@ -117,13 +117,21 @@ impl DhtMessageType { } pub fn is_dht_message(self) -> bool { - self.is_dht_discovery() || matches!(self, DhtMessageType::DiscoveryResponse) || self.is_dht_join() + self.is_dht_discovery() || self.is_dht_discovery_response() || self.is_dht_join() + } + + pub fn is_forwardable(self) -> bool { + self.is_domain_message() || self.is_dht_discovery() || self.is_dht_join() } pub fn is_dht_discovery(self) -> bool { matches!(self, DhtMessageType::Discovery) } + pub fn is_dht_discovery_response(self) -> bool { + matches!(self, DhtMessageType::DiscoveryResponse) + } + pub fn is_dht_join(self) -> bool { matches!(self, DhtMessageType::Join) } diff --git a/comms/dht/src/inbound/dht_handler/layer.rs b/comms/dht/src/inbound/dht_handler/layer.rs index eebe5ffa1e..a7252905e2 100644 --- a/comms/dht/src/inbound/dht_handler/layer.rs +++ b/comms/dht/src/inbound/dht_handler/layer.rs @@ -26,12 +26,13 @@ use tari_comms::peer_manager::{NodeIdentity, PeerManager}; use tower::layer::Layer; use super::middleware::DhtHandlerMiddleware; -use crate::{discovery::DhtDiscoveryRequester, outbound::OutboundMessageRequester, DhtConfig}; +use crate::{discovery::DhtDiscoveryRequester, outbound::OutboundMessageRequester, DhtConfig, DhtRequester}; pub struct DhtHandlerLayer { config: Arc, peer_manager: Arc, node_identity: Arc, + dht: DhtRequester, outbound_service: OutboundMessageRequester, discovery_requester: DhtDiscoveryRequester, } @@ -41,6 +42,7 @@ impl DhtHandlerLayer { config: Arc, node_identity: Arc, peer_manager: Arc, + dht: DhtRequester, discovery_requester: DhtDiscoveryRequester, outbound_service: OutboundMessageRequester, ) -> Self { @@ -48,6 +50,7 @@ impl DhtHandlerLayer { config, peer_manager, node_identity, + dht, outbound_service, discovery_requester, } @@ -63,6 +66,7 @@ impl Layer for DhtHandlerLayer { Arc::clone(&self.node_identity), Arc::clone(&self.peer_manager), self.outbound_service.clone(), + self.dht.clone(), self.discovery_requester.clone(), self.config.clone(), ) diff --git a/comms/dht/src/inbound/dht_handler/middleware.rs b/comms/dht/src/inbound/dht_handler/middleware.rs index 8fe66f62e6..f829d45125 100644 --- a/comms/dht/src/inbound/dht_handler/middleware.rs +++ b/comms/dht/src/inbound/dht_handler/middleware.rs @@ -35,6 +35,7 @@ use crate::{ inbound::DecryptedDhtMessage, outbound::OutboundMessageRequester, DhtConfig, + DhtRequester, }; #[derive(Clone)] @@ -42,6 +43,7 @@ pub struct DhtHandlerMiddleware { next_service: S, peer_manager: Arc, node_identity: Arc, + dht: DhtRequester, outbound_service: OutboundMessageRequester, discovery_requester: DhtDiscoveryRequester, config: Arc, @@ -53,6 +55,7 @@ impl DhtHandlerMiddleware { node_identity: Arc, peer_manager: Arc, outbound_service: OutboundMessageRequester, + dht: DhtRequester, discovery_requester: DhtDiscoveryRequester, config: Arc, ) -> Self { @@ -60,6 +63,7 @@ impl DhtHandlerMiddleware { next_service, peer_manager, node_identity, + dht, outbound_service, discovery_requester, config, @@ -87,6 +91,7 @@ where Arc::clone(&self.peer_manager), self.outbound_service.clone(), Arc::clone(&self.node_identity), + self.dht.clone(), self.discovery_requester.clone(), message, self.config.clone(), diff --git a/comms/dht/src/inbound/dht_handler/task.rs b/comms/dht/src/inbound/dht_handler/task.rs index 250070a1f6..1f1cc91594 100644 --- a/comms/dht/src/inbound/dht_handler/task.rs +++ b/comms/dht/src/inbound/dht_handler/task.rs @@ -20,26 +20,12 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use std::{ - convert::{TryFrom, TryInto}, - sync::Arc, -}; +use std::{convert::TryInto, sync::Arc}; use log::*; use tari_comms::{ message::MessageExt, - multiaddr::Multiaddr, - net_address::{MultiaddressesWithStats, PeerAddressSource}, - peer_manager::{ - IdentitySignature, - NodeId, - NodeIdentity, - Peer, - PeerFeatures, - PeerFlags, - PeerIdentityClaim, - PeerManager, - }, + peer_manager::{NodeId, NodeIdentity, PeerManager}, pipeline::PipelineError, types::CommsPublicKey, OrNotFound, @@ -48,17 +34,19 @@ use tari_utilities::{hex::Hex, ByteArray}; use tower::{Service, ServiceExt}; use crate::{ + actor::OffenceSeverity, discovery::DhtDiscoveryRequester, envelope::NodeDestination, inbound::{error::DhtInboundError, message::DecryptedDhtMessage}, outbound::{OutboundMessageRequester, SendMessageParams}, - peer_validator::PeerValidator, + peer_validator::{DhtPeerValidatorError, PeerValidator}, proto::{ dht::{DiscoveryMessage, DiscoveryResponseMessage, JoinMessage}, envelope::DhtMessageType, }, - rpc::PeerInfo, + rpc::UnvalidatedPeerInfo, DhtConfig, + DhtRequester, }; const LOG_TARGET: &str = "comms::dht::dht_handler"; @@ -68,6 +56,7 @@ pub struct ProcessDhtMessage { peer_manager: Arc, outbound_service: OutboundMessageRequester, node_identity: Arc, + dht: DhtRequester, message: Option, discovery_requester: DhtDiscoveryRequester, config: Arc, @@ -81,6 +70,7 @@ where S: Service peer_manager: Arc, outbound_service: OutboundMessageRequester, node_identity: Arc, + dht: DhtRequester, discovery_requester: DhtDiscoveryRequester, message: DecryptedDhtMessage, config: Arc, @@ -90,6 +80,7 @@ where S: Service peer_manager, outbound_service, node_identity, + dht, discovery_requester, message: Some(message), config, @@ -152,6 +143,7 @@ where S: Service Ok(()) } + #[allow(clippy::too_many_lines)] async fn handle_join(&mut self, message: DecryptedDhtMessage) -> Result<(), DhtInboundError> { let DecryptedDhtMessage { decryption_result, @@ -162,64 +154,80 @@ where S: Service .. } = message; - let authenticated_pk = authenticated_origin.ok_or_else(|| { - DhtInboundError::OriginRequired("Authenticated origin is required for this message type".to_string()) - })?; + // Ban the source peer. They should not have propagated a DHT discover response. + let Some(authenticated_pk) = authenticated_origin else { + warn!( + target: LOG_TARGET, + "Received JoinMessage that did not have an authenticated origin from source peer {}. Banning source", source_peer + ); - if &authenticated_pk == self.node_identity.public_key() { + self .dht .ban_peer(source_peer.public_key.clone(), OffenceSeverity::Low, "Received JoinMessage that did not have an authenticated origin", ).await; + return Ok(()); + }; + + if authenticated_pk == *self.node_identity.public_key() { debug!(target: LOG_TARGET, "Received our own join message. Discarding it."); return Ok(()); } let body = decryption_result.expect("already checked that this message decrypted successfully"); - let join_msg = body - .decode_part::(0)? - .ok_or(DhtInboundError::InvalidMessageBody)?; + let join_msg = self + .ban_on_offence( + &authenticated_pk, + body.decode_part::(0) + .map_err(Into::into) + .and_then(|o| o.ok_or(DhtInboundError::InvalidMessageBody)), + ) + .await?; + + if join_msg.public_key.as_slice() != authenticated_pk.as_bytes() { + warn!( + target: LOG_TARGET, + "Received JoinMessage from peer that mismatches the authenticated origin. \ + This message was signed by another party which may be attempting to get other nodes banned. \ + Banning the message signer." + ); + + warn!( + target: LOG_TARGET, + "Authenticated origin: {:#.6}, Source: {:#.6}, join message: {}", + authenticated_pk, source_peer.public_key, join_msg.public_key.to_hex() + ); + self.dht + .ban_peer( + authenticated_pk, + OffenceSeverity::High, + "Received JoinMessage from peer with a public key that does not match the source peer", + ) + .await; + + return Ok(()); + } debug!( target: LOG_TARGET, "Received join Message from '{}' {}", authenticated_pk, join_msg ); - let addresses = join_msg - .addresses - .iter() - .filter_map(|addr| Multiaddr::try_from(addr.clone()).ok()) - .collect::>(); - - if addresses.is_empty() { - return Err(DhtInboundError::InvalidAddresses); - } - let node_id = NodeId::from_public_key(&authenticated_pk); - - let features = PeerFeatures::from_bits_truncate(join_msg.peer_features); - - let identity_signature: IdentitySignature = join_msg - .identity_signature - .map(IdentitySignature::try_from) - .transpose() - .map_err(|err| DhtInboundError::InvalidPeerIdentitySignature(err.to_string()))? - .ok_or(DhtInboundError::NoPeerIdentitySignature)?; - - let peer_identity_claim = PeerIdentityClaim::new(addresses.clone(), features, identity_signature, None); - - let new_peer = Peer::new( - authenticated_pk, - node_id.clone(), - MultiaddressesWithStats::from_addresses_with_source(addresses, &PeerAddressSource::FromJoinMessage { - peer_identity_claim, - }), - PeerFlags::empty(), - features, - vec![], - String::new(), - ); + let validator = PeerValidator::new(&self.config); + let maybe_existing = self.peer_manager.find_by_public_key(&authenticated_pk).await?; + let valid_peer = self + .ban_on_offence( + &authenticated_pk, + validator + .validate_peer(join_msg.try_into()?, maybe_existing) + .map_err(Into::into), + ) + .await?; - self.peer_manager.add_peer(new_peer.clone()).await?; - let origin_peer = self.peer_manager.find_by_node_id(&node_id).await.or_not_found()?; + let is_banned = valid_peer.is_banned(); + let valid_peer_node_id = valid_peer.node_id.clone(); + let valid_peer_public_key = valid_peer.public_key.clone(); + // Update peer details. If the peer is banned we preserve the ban but still allow them to update their claims. + self.peer_manager.add_peer(valid_peer).await?; // DO NOT propagate this peer if this node has banned them - if origin_peer.is_banned() { + if is_banned { debug!( target: LOG_TARGET, "Received Join request for banned peer. This join request will not be propagated." @@ -235,21 +243,19 @@ where S: Service return Ok(()); } - let origin_public_key = origin_peer.public_key; - // Only propagate a join that was not directly sent to this node if dht_header.destination != self.node_identity.public_key() { debug!( target: LOG_TARGET, "Propagating Join message from peer '{}'", - origin_peer.node_id.short_str() + valid_peer_node_id.short_str() ); // Propagate message to closer peers self.outbound_service .send_raw_no_wait( SendMessageParams::new() - .propagate(origin_public_key.clone().into(), vec![ - origin_peer.node_id, + .propagate(valid_peer_public_key.into(), vec![ + valid_peer_node_id, source_peer.node_id.clone(), ]) .with_debug_info("Propagating join message".to_string()) @@ -278,9 +284,55 @@ where S: Service .success() .expect("already checked that this message decrypted successfully"); - let discover_msg = msg - .decode_part::(0)? - .ok_or(DhtInboundError::InvalidMessageBody)?; + // Ban the source peer. They should not have propagated a DHT discover response. + let Some(authenticated_origin) = message.authenticated_origin.as_ref() else { + warn!( + target: LOG_TARGET, + "Received DiscoveryResponseMessage that did not have an authenticated origin: {}. Banning source", message + ); + self.dht .ban_peer( + message.source_peer.public_key.clone(), + OffenceSeverity::Low, + "Received DiscoveryResponseMessage that did not have an authenticated origin", + ).await; + + return Ok(()); + }; + + let discover_msg = self + .ban_on_offence( + authenticated_origin, + msg.decode_part::(0) + .map_err(Into::into) + .and_then(|o| o.ok_or(DhtInboundError::InvalidMessageBody)), + ) + .await?; + + if *authenticated_origin != message.source_peer.public_key || + authenticated_origin.as_bytes() != discover_msg.public_key.as_slice() + { + warn!( + target: LOG_TARGET, + "Received DiscoveryResponseMessage from peer that mismatches the discovery response. \ + This message was signed by another party which may be attempting to get other nodes banned. \ + Banning the message signer." + ); + + warn!( + target: LOG_TARGET, + "Authenticated origin: {:#.6}, Source: {:#.6}, discovery message: {}", + authenticated_origin, message.source_peer.public_key, discover_msg.public_key.to_hex() + ); + self.dht + .ban_peer( + authenticated_origin.clone(), + OffenceSeverity::High, + "Received DiscoveryResponseMessage from peer with a public key that does not match the source peer", + ) + .await; + + return Ok(()); + } self.discovery_requester .notify_discovery_response_received(discover_msg) @@ -294,34 +346,57 @@ where S: Service .success() .expect("already checked that this message decrypted successfully"); - let discover_msg = msg - .decode_part::(0)? - .ok_or(DhtInboundError::InvalidMessageBody)?; + let Some(authenticated_pk) = message.authenticated_origin.as_ref() else { + warn!( + target: LOG_TARGET, + "Received Discover that did not have an authenticated origin from source peer {}. Banning source", message.source_peer + ); + self.dht.ban_peer( + message.source_peer.public_key.clone(), + OffenceSeverity::Low, + "Received JoinMessage that did not have an authenticated origin", + ).await; + + return Ok(()); + }; + + let discover_msg = self + .ban_on_offence( + authenticated_pk, + msg.decode_part::(0) + .map_err(Into::into) + .and_then(|o| o.ok_or(DhtInboundError::InvalidMessageBody)), + ) + .await?; let nonce = discover_msg.nonce; - let authenticated_pk = message.authenticated_origin.ok_or_else(|| { - DhtInboundError::OriginRequired("Origin header required for Discovery message".to_string()) - })?; debug!( target: LOG_TARGET, "Received discovery message from '{}', forwarded by {}", authenticated_pk, message.source_peer ); - let new_peer: PeerInfo = discover_msg - .try_into() - .map_err(DhtInboundError::InvalidDiscoveryMessage)?; + let new_peer: UnvalidatedPeerInfo = self + .ban_on_offence( + authenticated_pk, + discover_msg + .try_into() + .map_err(DhtInboundError::InvalidDiscoveryMessage), + ) + .await?; let node_id = NodeId::from_public_key(&new_peer.public_key); - let peer_validator = PeerValidator::new(&self.peer_manager, &self.config); - peer_validator.validate_and_add_peer(new_peer).await?; + let peer_validator = PeerValidator::new(&self.config); + let maybe_existing_peer = self.peer_manager.find_by_public_key(&new_peer.public_key).await?; + let peer = peer_validator.validate_peer(new_peer, maybe_existing_peer)?; + self.peer_manager.add_peer(peer).await?; let origin_peer = self.peer_manager.find_by_node_id(&node_id).await.or_not_found()?; // Don't send a join request to the origin peer if they are banned if origin_peer.is_banned() { warn!( target: LOG_TARGET, - "Received Discovery request for banned peer '{}'. This request will be ignored.", node_id + "Received Discovery request for banned peer '{}'. Not propagating further.", node_id ); return Ok(()); } @@ -360,6 +435,7 @@ where S: Service .with_debug_info("Sending discovery response".to_string()) .with_destination(NodeDestination::Unknown) .with_dht_message_type(DhtMessageType::DiscoveryResponse) + .force_origin() .finish(), response, ) @@ -367,4 +443,43 @@ where S: Service Ok(()) } + + async fn ban_on_offence( + &mut self, + authenticated_pk: &CommsPublicKey, + result: Result, + ) -> Result { + match result { + Ok(r) => Ok(r), + Err(err) => { + match &err { + DhtInboundError::PeerValidatorError(err) => match err { + DhtPeerValidatorError::Banned { .. } => {}, + err @ DhtPeerValidatorError::ValidatorError(_) | + err @ DhtPeerValidatorError::IdentityTooManyClaims { .. } => { + self.dht + .ban_peer(authenticated_pk.clone(), OffenceSeverity::Medium, err) + .await; + }, + }, + err @ DhtInboundError::MessageError(_) | err @ DhtInboundError::InvalidMessageBody => { + self.dht + .ban_peer(authenticated_pk.clone(), OffenceSeverity::High, err) + .await; + }, + DhtInboundError::PeerManagerError(_) => {}, + DhtInboundError::DhtOutboundError(_) => {}, + DhtInboundError::DhtDiscoveryError(_) => {}, + DhtInboundError::OriginRequired(_) => {}, + err @ DhtInboundError::InvalidDiscoveryMessage(_) => { + self.dht + .ban_peer(authenticated_pk.clone(), OffenceSeverity::High, err) + .await; + }, + DhtInboundError::ConnectivityError(_) => {}, + } + Err(err) + }, + } + } } diff --git a/comms/dht/src/inbound/error.rs b/comms/dht/src/inbound/error.rs index 6e920bb99d..96a13179cf 100644 --- a/comms/dht/src/inbound/error.rs +++ b/comms/dht/src/inbound/error.rs @@ -20,18 +20,10 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use tari_comms::{ - message::MessageError, - peer_manager::{NodeId, PeerManagerError}, -}; +use tari_comms::{connectivity::ConnectivityError, message::MessageError, peer_manager::PeerManagerError}; use thiserror::Error; -use crate::{ - discovery::DhtDiscoveryError, - error::DhtEncryptError, - outbound::DhtOutboundError, - peer_validator::PeerValidatorError, -}; +use crate::{discovery::DhtDiscoveryError, outbound::DhtOutboundError, peer_validator::DhtPeerValidatorError}; #[derive(Debug, Error)] pub enum DhtInboundError { @@ -41,24 +33,16 @@ pub enum DhtInboundError { PeerManagerError(#[from] PeerManagerError), #[error("DhtOutboundError: {0}")] DhtOutboundError(#[from] DhtOutboundError), - #[error("DhtEncryptError: {0}")] - DhtEncryptError(#[from] DhtEncryptError), #[error("Message body invalid")] InvalidMessageBody, - #[error("All given addresses were invalid")] - InvalidAddresses, - #[error("One or more peer addresses were invalid for '{peer}'")] - InvalidPeerAddresses { peer: NodeId }, #[error("DhtDiscoveryError: {0}")] DhtDiscoveryError(#[from] DhtDiscoveryError), #[error("OriginRequired: {0}")] OriginRequired(String), - #[error("Invalid peer identity signature: {0}")] - InvalidPeerIdentitySignature(String), - #[error("No peer identity signature")] - NoPeerIdentitySignature, - #[error("Invalid peer: {0}")] - PeerValidatorError(#[from] PeerValidatorError), + #[error("Peer validation failed: {0}")] + PeerValidatorError(#[from] DhtPeerValidatorError), #[error("Invalid discovery message {0}")] InvalidDiscoveryMessage(#[from] anyhow::Error), + #[error("ConnectivityError: {0}")] + ConnectivityError(#[from] ConnectivityError), } diff --git a/comms/dht/src/inbound/forward.rs b/comms/dht/src/inbound/forward.rs index f197c8440a..37e04129d2 100644 --- a/comms/dht/src/inbound/forward.rs +++ b/comms/dht/src/inbound/forward.rs @@ -30,22 +30,26 @@ use tari_utilities::epoch_time::EpochTime; use tower::{layer::Layer, Service, ServiceExt}; use crate::{ + actor::OffenceSeverity, envelope::NodeDestination, inbound::{error::DhtInboundError, DecryptedDhtMessage}, outbound::{OutboundMessageRequester, SendMessageParams}, + DhtRequester, }; const LOG_TARGET: &str = "comms::dht::storeforward::forward"; /// This layer is responsible for forwarding messages which have failed to decrypt pub struct ForwardLayer { + dht: DhtRequester, outbound_service: OutboundMessageRequester, is_enabled: bool, } impl ForwardLayer { - pub fn new(outbound_service: OutboundMessageRequester, is_enabled: bool) -> Self { + pub fn new(dht: DhtRequester, outbound_service: OutboundMessageRequester, is_enabled: bool) -> Self { Self { + dht, outbound_service, is_enabled, } @@ -58,7 +62,7 @@ impl Layer for ForwardLayer { fn layer(&self, service: S) -> Self::Service { ForwardMiddleware::new( service, - // Pass in just the config item needed by the middleware for almost free copies + self.dht.clone(), self.outbound_service.clone(), self.is_enabled, ) @@ -71,14 +75,16 @@ impl Layer for ForwardLayer { #[derive(Clone)] pub struct ForwardMiddleware { next_service: S, + dht: DhtRequester, outbound_service: OutboundMessageRequester, is_enabled: bool, } impl ForwardMiddleware { - pub fn new(service: S, outbound_service: OutboundMessageRequester, is_enabled: bool) -> Self { + pub fn new(service: S, dht: DhtRequester, outbound_service: OutboundMessageRequester, is_enabled: bool) -> Self { Self { next_service: service, + dht, outbound_service, is_enabled, } @@ -101,6 +107,7 @@ where fn call(&mut self, message: DecryptedDhtMessage) -> Self::Future { let next_service = self.next_service.clone(); let outbound_service = self.outbound_service.clone(); + let dht = self.dht.clone(); let is_enabled = self.is_enabled; Box::pin(async move { if !is_enabled { @@ -119,7 +126,7 @@ where message.tag, message.dht_header.message_tag ); - let forwarder = Forwarder::new(next_service, outbound_service); + let forwarder = Forwarder::new(next_service, dht, outbound_service); forwarder.handle(message).await }) } @@ -129,13 +136,15 @@ where /// to the next service. struct Forwarder { next_service: S, + dht: DhtRequester, outbound_service: OutboundMessageRequester, } impl Forwarder { - pub fn new(service: S, outbound_service: OutboundMessageRequester) -> Self { + pub fn new(service: S, dht: DhtRequester, outbound_service: OutboundMessageRequester) -> Self { Self { next_service: service, + dht, outbound_service, } } @@ -152,7 +161,11 @@ where S: Service message.tag, message.dht_header.message_tag ); - self.forward(&message).await?; + + // Only forward DHT discovery, Join and any encrypted Domain messages + if message.dht_header.message_type.is_forwardable() { + self.forward(&message).await?; + } } // The message has been forwarded, but downstream middleware may be interested @@ -173,6 +186,7 @@ where S: Service dht_header, is_saf_stored, is_already_forwarded, + authenticated_origin, .. } = message; @@ -180,9 +194,26 @@ where S: Service // #banheuristic - the origin of this message was the destination. Two things are wrong here: // 1. The origin/destination should not have forwarded this (the destination node didnt do this // destination_matches_source check) - // 1. The source sent a message that the destination could not decrypt + // 1. The origin sent a message that the destination could not decrypt // The authenticated source should be banned (malicious), and origin should be temporarily banned // (bug?) + if let Some(authenticated_origin) = authenticated_origin { + self.dht + .ban_peer( + authenticated_origin.clone(), + OffenceSeverity::High, + "Received message from peer that is destined for that peer. This peer originally sent it.", + ) + .await; + } + self.dht + .ban_peer( + source_peer.public_key.clone(), + OffenceSeverity::Medium, + "Received message from peer that is destined for that peer. The source peer should not have sent \ + this message.", + ) + .await; debug!( target: LOG_TARGET, "Received message {} from peer '{}' that is destined for that peer. Discarding message (Trace: {})", @@ -222,8 +253,9 @@ where S: Service node_id, dht_header.message_tag ); debug!(target: LOG_TARGET, "{}", &debug_info); - send_params.with_debug_info(debug_info); - send_params.direct_or_closest_connected(node_id, excluded_peers); + send_params + .with_debug_info(debug_info) + .direct_or_closest_connected(node_id, excluded_peers); }, _ => { let debug_info = format!( @@ -231,8 +263,9 @@ where S: Service dht_header.destination, dht_header.message_tag ); debug!(target: LOG_TARGET, "{}", debug_info); - send_params.with_debug_info(debug_info); - send_params.propagate(dht_header.destination.clone(), excluded_peers); + send_params + .with_debug_info(debug_info) + .propagate(dht_header.destination.clone(), excluded_peers); }, }; @@ -266,15 +299,16 @@ mod test { use crate::{ envelope::DhtMessageFlags, outbound::mock::create_outbound_service_mock, - test_utils::{make_dht_inbound_message, make_node_identity, service_spy}, + test_utils::{create_dht_actor_mock, make_dht_inbound_message, make_node_identity, service_spy}, }; #[tokio::test] async fn decryption_succeeded() { let spy = service_spy(); let (oms_tx, _) = mpsc::channel(1); + let (dht, _mock) = create_dht_actor_mock(1); let oms = OutboundMessageRequester::new(oms_tx); - let mut service = ForwardLayer::new(oms, true).layer(spy.to_service::()); + let mut service = ForwardLayer::new(dht, oms, true).layer(spy.to_service::()); let node_identity = make_node_identity(); let inbound_msg = @@ -293,9 +327,10 @@ mod test { let spy = service_spy(); let (oms_requester, oms_mock) = create_outbound_service_mock(1); let oms_mock_state = oms_mock.get_state(); + let (dht, _mock) = create_dht_actor_mock(1); task::spawn(oms_mock.run()); - let mut service = ForwardLayer::new(oms_requester, true).layer(spy.to_service::()); + let mut service = ForwardLayer::new(dht, oms_requester, true).layer(spy.to_service::()); let sample_body = b"Lorem ipsum"; let inbound_msg = make_dht_inbound_message( diff --git a/comms/dht/src/logging_middleware.rs b/comms/dht/src/logging_middleware.rs index 9cfe176692..d3b2ffbcb6 100644 --- a/comms/dht/src/logging_middleware.rs +++ b/comms/dht/src/logging_middleware.rs @@ -20,12 +20,7 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use std::{ - borrow::Cow, - fmt::{Debug, Display}, - marker::PhantomData, - task::Poll, -}; +use std::{borrow::Cow, fmt::Display, marker::PhantomData, task::Poll}; use futures::task::Context; use log::*; @@ -80,7 +75,7 @@ impl<'a, S> MessageLoggingService<'a, S> { impl Service for MessageLoggingService<'_, S> where S: Service, - R: Display + Debug, + R: Display, { type Error = S::Error; type Future = S::Future; @@ -91,7 +86,7 @@ where } fn call(&mut self, msg: R) -> Self::Future { - debug!(target: LOG_TARGET, "{}{:?}", self.prefix_msg, msg); + debug!(target: LOG_TARGET, "{}{:#}", self.prefix_msg, msg); self.inner.call(msg) } } diff --git a/comms/dht/src/network_discovery/discovering.rs b/comms/dht/src/network_discovery/discovering.rs index d482ba5eac..d12a098123 100644 --- a/comms/dht/src/network_discovery/discovering.rs +++ b/comms/dht/src/network_discovery/discovering.rs @@ -35,13 +35,7 @@ use super::{ state_machine::{DhtNetworkDiscoveryRoundInfo, DiscoveryParams, NetworkDiscoveryContext, StateEvent}, NetworkDiscoveryError, }; -use crate::{ - peer_validator::{PeerValidator, PeerValidatorError}, - proto::rpc::GetPeersRequest, - rpc, - rpc::PeerInfo, - DhtConfig, -}; +use crate::{peer_validator::PeerValidator, proto::rpc::GetPeersRequest, rpc, rpc::UnvalidatedPeerInfo, DhtConfig}; const LOG_TARGET: &str = "comms::dht::network_discovery"; @@ -160,6 +154,19 @@ impl Discovering { .map(|v| u32::try_from(v).unwrap()) .unwrap_or_default(), include_clients: true, + max_claims: self.config().max_permitted_peer_claims.try_into().unwrap_or_else(|_| { + error!(target: LOG_TARGET, "Node configured to accept more than u32::MAX claims per peer"); + u32::MAX + }), + max_addresses_per_claim: self + .config() + .peer_validator_config + .max_permitted_peer_addresses_per_claim + .try_into() + .unwrap_or_else(|_| { + error!(target: LOG_TARGET, "Node configured to accept more than u32::MAX addresses per claim"); + u32::MAX + }), }) .await { @@ -194,7 +201,7 @@ impl Discovering { async fn validate_and_add_peer( &mut self, sync_peer: &NodeId, - new_peer: PeerInfo, + new_peer: UnvalidatedPeerInfo, ) -> Result<(), NetworkDiscoveryError> { let node_id = NodeId::from_public_key(&new_peer.public_key); if self.context.node_identity.node_id() == &node_id { @@ -202,18 +209,20 @@ impl Discovering { return Ok(()); } - let peer_validator = PeerValidator::new(self.peer_manager(), self.config()); + let maybe_existing_peer = self.peer_manager().find_by_public_key(&new_peer.public_key).await?; + let peer_exists = maybe_existing_peer.is_some(); - match peer_validator.validate_and_add_peer(new_peer).await { - Ok(true) => { - self.stats.num_new_peers += 1; - Ok(()) - }, - Ok(false) => { - self.stats.num_duplicate_peers += 1; + let peer_validator = PeerValidator::new(self.config()); + match peer_validator.validate_peer(new_peer, maybe_existing_peer) { + Ok(valid_peer) => { + if peer_exists { + self.stats.num_duplicate_peers += 1; + } else { + self.stats.num_new_peers += 1; + } + self.peer_manager().add_peer(valid_peer).await?; Ok(()) }, - Err(err @ PeerValidatorError::PeerManagerError(_)) => Err(err.into()), Err(err) => { warn!( target: LOG_TARGET, diff --git a/comms/dht/src/network_discovery/error.rs b/comms/dht/src/network_discovery/error.rs index 0b9a227908..7175be312f 100644 --- a/comms/dht/src/network_discovery/error.rs +++ b/comms/dht/src/network_discovery/error.rs @@ -22,7 +22,7 @@ use tari_comms::{connectivity::ConnectivityError, peer_manager::PeerManagerError, protocol::rpc::RpcError}; -use crate::peer_validator::PeerValidatorError; +use crate::peer_validator::DhtPeerValidatorError; #[derive(thiserror::Error, Debug)] pub enum NetworkDiscoveryError { @@ -35,5 +35,5 @@ pub enum NetworkDiscoveryError { #[error("No sync peers available")] NoSyncPeers, #[error("Sync peer sent invalid peer: {0}")] - PeerValidationError(#[from] PeerValidatorError), + PeerValidationError(#[from] DhtPeerValidatorError), } diff --git a/comms/dht/src/network_discovery/on_connect.rs b/comms/dht/src/network_discovery/on_connect.rs index 1728ac4b57..d0ac25cb66 100644 --- a/comms/dht/src/network_discovery/on_connect.rs +++ b/comms/dht/src/network_discovery/on_connect.rs @@ -37,7 +37,7 @@ use crate::{ peer_validator::PeerValidator, proto::rpc::GetPeersRequest, rpc, - rpc::PeerInfo, + rpc::UnvalidatedPeerInfo, DhtConfig, }; const LOG_TARGET: &str = "comms::dht::network_discovery:onconnect"; @@ -127,6 +127,19 @@ impl OnConnect { .get_peers(GetPeersRequest { n: NUM_FETCH_PEERS, include_clients: false, + max_claims: self.config().max_permitted_peer_claims.try_into().unwrap_or_else(|_| { + error!(target: LOG_TARGET, "Node configured to accept more than u32::MAX claims per peer"); + u32::MAX + }), + max_addresses_per_claim: self + .config() + .peer_validator_config + .max_permitted_peer_addresses_per_claim + .try_into() + .unwrap_or_else(|_| { + error!(target: LOG_TARGET, "Node configured to accept more than u32::MAX addresses per claim"); + u32::MAX + }), }) .await?; @@ -170,11 +183,14 @@ impl OnConnect { Ok(()) } - // Returns true if the peer was added - async fn validate_and_add_peer(&self, peer: PeerInfo) -> Result { - let peer_validator = PeerValidator::new(&self.context.peer_manager, self.config()); - - Ok(peer_validator.validate_and_add_peer(peer).await?) + /// Returns true if the peer is a new peer + async fn validate_and_add_peer(&self, peer: UnvalidatedPeerInfo) -> Result { + let peer_validator = PeerValidator::new(self.config()); + let maybe_existing_peer = self.context.peer_manager.find_by_public_key(&peer.public_key).await?; + let is_new_peer = maybe_existing_peer.is_none(); + let valid_peer = peer_validator.validate_peer(peer, maybe_existing_peer)?; + self.context.peer_manager.add_peer(valid_peer).await?; + Ok(is_new_peer) } #[inline] diff --git a/comms/dht/src/network_discovery/test.rs b/comms/dht/src/network_discovery/test.rs index f947f69e6d..b757a1a8a7 100644 --- a/comms/dht/src/network_discovery/test.rs +++ b/comms/dht/src/network_discovery/test.rs @@ -48,7 +48,7 @@ use crate::{ mod state_machine { use super::*; - use crate::rpc::PeerInfo; + use crate::rpc::UnvalidatedPeerInfo; async fn setup( mut config: DhtConfig, @@ -113,7 +113,7 @@ mod state_machine { }; let peers = iter::repeat_with(|| make_node_identity().to_peer()) .map(|p| GetPeersResponse { - peer: Some(PeerInfo::from(p).into()), + peer: Some(UnvalidatedPeerInfo::from_peer_limited_claims(p, 5, 5).into()), }) .take(NUM_PEERS) .collect(); diff --git a/comms/dht/src/outbound/message.rs b/comms/dht/src/outbound/message.rs index fc4547f3c3..226942f1b6 100644 --- a/comms/dht/src/outbound/message.rs +++ b/comms/dht/src/outbound/message.rs @@ -189,15 +189,21 @@ impl fmt::Display for DhtOutboundMessage { self.dht_flags, self.destination, self.tag, ) }); - write!( + + writeln!( f, - "\n---- Outgoing message ---- \nSize: {} byte(s)\nType: {}\nPeer: {}\nHeader: {}\n{}\n----\n{:?}\n", + "\n---- Outgoing message ---- \nSize: {} byte(s)\nType: {}\nPeer: {}\nHeader: {}\n{}\n----", self.body.len(), self.dht_message_type, self.destination, header_str, self.tag, - self.body - ) + )?; + + if f.alternate() { + write!(f, "(body omitted)") + } else { + write!(f, "{:?}", self.body) + } } } diff --git a/comms/dht/src/outbound/message_params.rs b/comms/dht/src/outbound/message_params.rs index 2fb1aabf0e..269d75ddc2 100644 --- a/comms/dht/src/outbound/message_params.rs +++ b/comms/dht/src/outbound/message_params.rs @@ -104,6 +104,7 @@ impl SendMessageParams { Default::default() } + /// Not currently used pub fn with_debug_info(&mut self, debug_info: String) -> &mut Self { self.params_mut().debug_info = Some(debug_info); self diff --git a/comms/dht/src/peer_validator.rs b/comms/dht/src/peer_validator.rs index f270559adf..9c4b37a958 100644 --- a/comms/dht/src/peer_validator.rs +++ b/comms/dht/src/peer_validator.rs @@ -22,103 +22,96 @@ use log::*; use tari_comms::{ - connection_manager::validate_address_and_source, - net_address::{MultiaddrWithStats, MultiaddressesWithStats, PeerAddressSource}, - peer_manager::{NodeId, Peer, PeerFlags, PeerManagerError}, - types::CommsPublicKey, - PeerManager, + net_address::{MultiaddressesWithStats, PeerAddressSource}, + peer_manager::{NodeId, Peer, PeerFlags}, + peer_validator, + peer_validator::{find_most_recent_claim, PeerValidatorError}, }; -use crate::{rpc::PeerInfo, DhtConfig}; +use crate::{rpc::UnvalidatedPeerInfo, DhtConfig}; -const LOG_TARGET: &str = "dht::network_discovery::peer_validator"; +const _LOG_TARGET: &str = "dht::network_discovery::peer_validator"; /// Validation errors for peers shared on the network #[derive(Debug, thiserror::Error)] -pub enum PeerValidatorError { - #[error("Node ID was invalid for peer '{peer}'")] - InvalidNodeId { peer: NodeId }, - #[error("Peer signature was invalid for peer '{peer}'")] - InvalidPeerSignature { peer: NodeId }, - #[error("One or more peer addresses were invalid for '{peer}'")] - InvalidPeerAddresses { peer: NodeId }, - #[error("Peer '{peer}' was banned")] - PeerHasNoAddresses { peer: NodeId }, - #[error("Peer manager error: {0}")] - PeerManagerError(#[from] PeerManagerError), +pub enum DhtPeerValidatorError { + #[error("Peer '{peer}' is banned: {reason}")] + Banned { peer: NodeId, reason: String }, + #[error(transparent)] + ValidatorError(#[from] PeerValidatorError), + #[error("Peer provided too many claims: expected max {max} but got {length}")] + IdentityTooManyClaims { length: usize, max: usize }, } /// Validator for Peers pub struct PeerValidator<'a> { - peer_manager: &'a PeerManager, config: &'a DhtConfig, } impl<'a> PeerValidator<'a> { /// Creates a new peer validator - pub fn new(peer_manager: &'a PeerManager, config: &'a DhtConfig) -> Self { - Self { peer_manager, config } + pub fn new(config: &'a DhtConfig) -> Self { + Self { config } } /// Validates the new peer against the current peer database. Returning true if a new peer was added and false if /// the peer already exists. - pub async fn validate_and_add_peer(&self, new_peer: PeerInfo) -> Result { - let node_id = NodeId::from_public_key(&new_peer.public_key); - - if new_peer.addresses.is_empty() { - return Err(PeerValidatorError::PeerHasNoAddresses { peer: node_id }); + pub fn validate_peer( + &self, + new_peer: UnvalidatedPeerInfo, + existing_peer: Option, + ) -> Result { + if new_peer.claims.is_empty() { + return Err(PeerValidatorError::PeerHasNoAddresses { + peer: NodeId::from_public_key(&new_peer.public_key), + } + .into()); } - let mut peer = Peer::new( - new_peer.public_key.clone(), - node_id.clone(), - MultiaddressesWithStats::new(vec![]), - PeerFlags::default(), - new_peer.peer_features, - new_peer.supported_protocols, - new_peer.user_agent, - ); - - for addr in new_peer.addresses { - let multiaddr_and_stats = MultiaddrWithStats::new(addr.address.clone(), PeerAddressSource::FromDiscovery { - peer_identity_claim: addr.peer_identity_claim, + + if new_peer.claims.len() > self.config.max_permitted_peer_claims { + return Err(DhtPeerValidatorError::IdentityTooManyClaims { + length: new_peer.claims.len(), + max: self.config.max_permitted_peer_claims, }); - match validate_address_and_source( - &new_peer.public_key, - &multiaddr_and_stats, - self.config.allow_test_addresses, - ) { - Ok(()) => { - peer.addresses - .add_address(multiaddr_and_stats.address(), multiaddr_and_stats.source()); - }, - Err(e) => { - warn!( - target: LOG_TARGET, - "Peer provided info on another peer that had a bad address or signature (new peer: {} \ - address: {}): error:{}. Ignoring.", - new_peer.public_key, - addr.address, - e - ); - }, + } + + if let Some(ref peer) = existing_peer { + if peer.is_banned() { + return Err(DhtPeerValidatorError::Banned { + peer: peer.node_id.clone(), + reason: peer.banned_reason.clone(), + }); } } - validate_node_id(&peer.public_key, &peer.node_id)?; - let exists = self.peer_manager.exists(&peer.public_key).await; + let most_recent_claim = find_most_recent_claim(&new_peer.claims).expect("new_peer.claims is not empty"); - self.peer_manager.add_peer(peer).await?; + let node_id = NodeId::from_public_key(&new_peer.public_key); - Ok(!exists) - } -} + let mut peer = existing_peer.unwrap_or_else(|| { + Peer::new( + new_peer.public_key.clone(), + node_id, + MultiaddressesWithStats::default(), + PeerFlags::default(), + most_recent_claim.features, + vec![], + String::new(), + ) + }); + + for claim in new_peer.claims { + peer_validator::validate_peer_identity_claim( + &self.config.peer_validator_config, + &new_peer.public_key, + &claim, + )?; + peer.update_addresses(&claim.addresses, &PeerAddressSource::FromDiscovery { + peer_identity_claim: claim.clone(), + }); + } -fn validate_node_id(public_key: &CommsPublicKey, node_id: &NodeId) -> Result { - let expected_node_id = NodeId::from_key(public_key); - if expected_node_id == *node_id { - Ok(expected_node_id) - } else { - Err(PeerValidatorError::InvalidNodeId { peer: node_id.clone() }) + Ok(peer) } } @@ -137,11 +130,10 @@ mod tests { use tari_utilities::ByteArray; use super::*; - use crate::test_utils::{build_peer_manager, make_node_identity}; + use crate::test_utils::make_node_identity; #[tokio::test] - async fn it_adds_a_valid_unsigned_peer() { - let peer_manager = build_peer_manager(); + async fn it_errors_with_invalid_signature() { let config = DhtConfig::default_local_test(); let node_identity = make_node_identity(); let mut peer = node_identity.to_peer(); @@ -159,26 +151,26 @@ mod tests { ), Default::default(), ), - unverified_data: None, }, }); - let validator = PeerValidator::new(&peer_manager, &config); - let is_new = validator.validate_and_add_peer(peer.clone().into()).await.unwrap(); - assert!(is_new); - assert!(peer_manager.exists(&peer.public_key).await); + let validator = PeerValidator::new(&config); + let err = validator + .validate_peer(UnvalidatedPeerInfo::from_peer_limited_claims(peer.clone(), 5, 5), None) + .unwrap_err(); + unpack_enum!(DhtPeerValidatorError::ValidatorError(PeerValidatorError::InvalidPeerSignature { .. }) = err); } #[tokio::test] async fn it_does_not_add_an_invalid_peer() { - let peer_manager = build_peer_manager(); let config = DhtConfig::default_local_test(); let node_identity = make_node_identity(); let mut peer = node_identity.to_peer(); // Peer MUST provide at least one address peer.addresses = MultiaddressesWithStats::new(vec![]); - let validator = PeerValidator::new(&peer_manager, &config); - let err = validator.validate_and_add_peer(peer.clone().into()).await.unwrap_err(); - unpack_enum!(PeerValidatorError::PeerHasNoAddresses { .. } = err); - assert!(!peer_manager.exists(&peer.public_key).await); + let validator = PeerValidator::new(&config); + let err = validator + .validate_peer(UnvalidatedPeerInfo::from_peer_limited_claims(peer, 5, 5), None) + .unwrap_err(); + unpack_enum!(DhtPeerValidatorError::ValidatorError(PeerValidatorError::PeerHasNoAddresses { .. }) = err); } } diff --git a/comms/dht/src/proto/dht.proto b/comms/dht/src/proto/dht.proto index 3b28ed44d5..217e1cf592 100644 --- a/comms/dht/src/proto/dht.proto +++ b/comms/dht/src/proto/dht.proto @@ -16,7 +16,7 @@ import "common.proto"; message JoinMessage { bytes public_key =1; repeated bytes addresses = 2; - uint64 peer_features = 3; + uint32 peer_features = 3; uint64 nonce = 4; tari.dht.common.IdentitySignature identity_signature = 5; } @@ -29,7 +29,7 @@ message JoinMessage { message DiscoveryMessage { bytes public_key =1; repeated bytes addresses = 2; - uint64 peer_features = 3; + uint32 peer_features = 3; uint64 nonce = 4; tari.dht.common.IdentitySignature identity_signature = 5; } @@ -37,7 +37,7 @@ message DiscoveryMessage { message DiscoveryResponseMessage { bytes public_key = 1; repeated bytes addresses = 2; - uint64 peer_features = 3; + uint32 peer_features = 3; uint64 nonce = 4; tari.dht.common.IdentitySignature identity_signature = 5; } diff --git a/comms/dht/src/proto/mod.rs b/comms/dht/src/proto/mod.rs index e74c500b02..e8eaf674a2 100644 --- a/comms/dht/src/proto/mod.rs +++ b/comms/dht/src/proto/mod.rs @@ -34,14 +34,9 @@ use tari_comms::{ types::{CommsPublicKey, CommsSecretKey, Signature}, NodeIdentity, }; -use tari_crypto::ristretto::RistrettoPublicKey; -use tari_utilities::{hex::Hex, ByteArray, ByteArrayError}; -use thiserror::Error; +use tari_utilities::{hex::Hex, ByteArray}; -use crate::{ - proto::dht::{DiscoveryMessage, JoinMessage}, - rpc::{PeerInfo, PeerInfoAddress}, -}; +use crate::{proto::dht::JoinMessage, rpc::UnvalidatedPeerInfo}; pub mod common { tari_comms::outdir_include!("tari.dht.common.rs"); @@ -87,9 +82,9 @@ impl fmt::Display for dht::JoinMessage { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, - "JoinMessage(PK = {}, Addresses = {:?}, Features = {:?})", + "JoinMessage(PK = {}, {} Addresses, Features = {:?})", self.public_key.to_hex(), - self.addresses, + self.addresses.len(), PeerFeatures::from_bits_truncate(self.peer_features), ) } @@ -97,79 +92,11 @@ impl fmt::Display for dht::JoinMessage { //---------------------------------- Rpc Message Conversions --------------------------------------------// -#[derive(Debug, Error, PartialEq)] -enum PeerInfoConvertError { - #[error("Could not convert into byte array: `{0}`")] - ByteArrayError(String), -} - -impl From for PeerInfoConvertError { - fn from(e: ByteArrayError) -> Self { - PeerInfoConvertError::ByteArrayError(e.to_string()) - } -} - -impl TryFrom for PeerInfo { - type Error = anyhow::Error; - - fn try_from(value: DiscoveryMessage) -> Result { - let identity_signature = value - .identity_signature - .ok_or_else(|| anyhow!("DiscoveryMessage missing peer_identity_claim"))? - .try_into()?; - - let identity_claim = PeerIdentityClaim { - addresses: value - .addresses - .iter() - .map(|a| Multiaddr::try_from(a.clone())) - .collect::>()?, - features: PeerFeatures::from_bits_truncate(value.peer_features), - signature: identity_signature, - unverified_data: None, - }; - - Ok(Self { - public_key: RistrettoPublicKey::from_bytes(&value.public_key) - .map_err(|e| PeerInfoConvertError::ByteArrayError(format!("{}", e)))?, - addresses: value - .addresses - .iter() - .map(|a| { - Ok(PeerInfoAddress { - address: Multiaddr::try_from(a.clone())?, - peer_identity_claim: identity_claim.clone(), - }) - }) - .collect::>()?, - peer_features: PeerFeatures::from_bits_truncate(value.peer_features), - supported_protocols: vec![], - user_agent: "".to_string(), - }) - } -} - -impl From for rpc::PeerInfo { - fn from(value: PeerInfo) -> Self { +impl From for rpc::PeerInfo { + fn from(value: UnvalidatedPeerInfo) -> Self { Self { public_key: value.public_key.to_vec(), - addresses: value.addresses.into_iter().map(Into::into).collect(), - peer_features: value.peer_features.bits(), - supported_protocols: value - .supported_protocols - .into_iter() - .map(|b| b.as_ref().to_vec()) - .collect(), - user_agent: value.user_agent, - } - } -} - -impl From for rpc::PeerInfoAddress { - fn from(value: PeerInfoAddress) -> Self { - Self { - address: value.address.to_vec(), - peer_identity_claim: Some(value.peer_identity_claim.into()), + claims: value.claims.into_iter().map(Into::into).collect(), } } } @@ -184,62 +111,34 @@ impl From for rpc::PeerIdentityClaim { } } -impl TryInto for rpc::PeerInfo { +impl TryFrom for UnvalidatedPeerInfo { type Error = anyhow::Error; - fn try_into(self) -> Result { - let public_key = CommsPublicKey::from_bytes(&self.public_key) - .map_err(|e| PeerInfoConvertError::ByteArrayError(format!("{}", e)))?; - let addresses = self - .addresses + fn try_from(value: rpc::PeerInfo) -> Result { + let public_key = + CommsPublicKey::from_bytes(&value.public_key).map_err(|e| anyhow!("PeerInfo invalid public key: {}", e))?; + let claims = value + .claims .into_iter() .map(TryInto::try_into) - .collect::, _>>()?; - let peer_features = PeerFeatures::from_bits_truncate(self.peer_features); - let supported_protocols = self - .supported_protocols - .into_iter() - .map(|b| b.try_into()) - .collect::, _>>()?; - Ok(PeerInfo { - public_key, - addresses, - peer_features, - user_agent: self.user_agent, - supported_protocols, - }) - } -} + .collect::>()?; -impl TryInto for rpc::PeerInfoAddress { - type Error = anyhow::Error; - - fn try_into(self) -> Result { - let address = Multiaddr::try_from(self.address)?; - let peer_identity_claim = self - .peer_identity_claim - .ok_or_else(|| anyhow::anyhow!("Missing peer identity claim"))? - .try_into()?; - - Ok(PeerInfoAddress { - address, - peer_identity_claim, - }) + Ok(Self { public_key, claims }) } } -impl TryInto for rpc::PeerIdentityClaim { +impl TryFrom for PeerIdentityClaim { type Error = anyhow::Error; - fn try_into(self) -> Result { - let addresses = self + fn try_from(value: rpc::PeerIdentityClaim) -> Result { + let addresses = value .addresses .into_iter() .filter_map(|addr| Multiaddr::try_from(addr).ok()) .collect::>(); - let features = PeerFeatures::from_bits_truncate(self.peer_features); - let signature = self + let features = PeerFeatures::from_bits(value.peer_features).ok_or_else(|| anyhow!("Invalid peer features"))?; + let signature = value .identity_signature .map(TryInto::try_into) .ok_or_else(|| anyhow::anyhow!("No signature"))??; @@ -247,7 +146,6 @@ impl TryInto for rpc::PeerIdentityClaim { addresses, features, signature, - unverified_data: None, }) } } @@ -258,10 +156,10 @@ impl TryFrom for IdentitySignature { fn try_from(value: common::IdentitySignature) -> Result { let version = u8::try_from(value.version) .map_err(|_| anyhow::anyhow!("Invalid peer identity signature version {}", value.version))?; - let public_nonce = CommsPublicKey::from_bytes(&value.public_nonce) - .map_err(|e| PeerInfoConvertError::ByteArrayError(format!("{}", e)))?; - let signature = CommsSecretKey::from_bytes(&value.signature) - .map_err(|e| PeerInfoConvertError::ByteArrayError(format!("{}", e)))?; + let public_nonce = + CommsPublicKey::from_bytes(&value.public_nonce).map_err(|e| anyhow!("Invalid public nonce: {}", e))?; + let signature = + CommsSecretKey::from_bytes(&value.signature).map_err(|e| anyhow!("Invalid signature: {}", e))?; let updated_at = NaiveDateTime::from_timestamp_opt(value.updated_at, 0) .ok_or_else(|| anyhow::anyhow!("updated_at overflowed"))?; let updated_at = DateTime::::from_utc(updated_at, Utc); diff --git a/comms/dht/src/proto/rpc.proto b/comms/dht/src/proto/rpc.proto index 3728c0012c..ab97260434 100644 --- a/comms/dht/src/proto/rpc.proto +++ b/comms/dht/src/proto/rpc.proto @@ -14,6 +14,8 @@ message GetCloserPeersRequest { repeated bytes excluded = 2; bytes closer_to = 3; bool include_clients = 4; + uint32 max_claims = 5; + uint32 max_addresses_per_claim = 6; } // `get_peers` request @@ -21,6 +23,8 @@ message GetPeersRequest { // The number of peers to return, 0 for all peers uint32 n = 1; bool include_clients = 2; + uint32 max_claims = 3; + uint32 max_addresses_per_claim = 4; } // GET peers response @@ -31,21 +35,11 @@ message GetPeersResponse { // Minimal peer information message PeerInfo { bytes public_key = 1; - repeated PeerInfoAddress addresses = 2; - uint64 peer_features = 3; - repeated bytes supported_protocols = 4; - // Note: not part of the signature - string user_agent = 5; - -} - -message PeerInfoAddress { - bytes address = 1; - PeerIdentityClaim peer_identity_claim = 2; + repeated PeerIdentityClaim claims = 2; } message PeerIdentityClaim { repeated bytes addresses = 1; - uint64 peer_features = 2; + uint32 peer_features = 2; tari.dht.common.IdentitySignature identity_signature = 3; } diff --git a/comms/dht/src/rpc/mod.rs b/comms/dht/src/rpc/mod.rs index 2b58fe62d5..2dd15f1440 100644 --- a/comms/dht/src/rpc/mod.rs +++ b/comms/dht/src/rpc/mod.rs @@ -37,7 +37,7 @@ use tari_comms_rpc_macros::tari_rpc; use crate::proto::rpc::{GetCloserPeersRequest, GetPeersRequest, GetPeersResponse}; mod peer_info; -pub use peer_info::{PeerInfo, PeerInfoAddress}; +pub use peer_info::UnvalidatedPeerInfo; #[tari_rpc(protocol_name = b"t/dht/1", server_struct = DhtService, client_struct = DhtClient)] pub trait DhtRpcService: Send + Sync + 'static { diff --git a/comms/dht/src/rpc/peer_info.rs b/comms/dht/src/rpc/peer_info.rs index 27a4044080..668596f75e 100644 --- a/comms/dht/src/rpc/peer_info.rs +++ b/comms/dht/src/rpc/peer_info.rs @@ -20,47 +20,145 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +use std::convert::{TryFrom, TryInto}; + +use anyhow::anyhow; use tari_comms::{ multiaddr::Multiaddr, peer_manager::{Peer, PeerFeatures, PeerIdentityClaim}, - protocol::ProtocolId, types::CommsPublicKey, }; +use tari_crypto::ristretto::RistrettoPublicKey; +use tari_utilities::ByteArray; + +use crate::proto::dht::{DiscoveryMessage, DiscoveryResponseMessage, JoinMessage}; -pub struct PeerInfo { +pub struct UnvalidatedPeerInfo { pub public_key: CommsPublicKey, - pub addresses: Vec, - pub peer_features: PeerFeatures, - pub user_agent: String, - pub supported_protocols: Vec, + pub claims: Vec, } -pub struct PeerInfoAddress { - pub address: Multiaddr, - pub peer_identity_claim: PeerIdentityClaim, -} +impl UnvalidatedPeerInfo { + pub fn from_peer_limited_claims(peer: Peer, max_claims: usize, max_addresse_per_claim: usize) -> Self { + let claims = peer + .addresses + .addresses() + .iter() + .filter_map(|addr| { + if addr.address().is_empty() { + return None; + } + + let claim = addr.source().peer_identity_claim()?; + + if claim.addresses.len() > max_addresse_per_claim { + return None; + } -impl From for PeerInfo { - fn from(peer: Peer) -> Self { - PeerInfo { + Some(claim) + }) + .take(max_claims) + .cloned() + .collect::>(); + + Self { public_key: peer.public_key, - addresses: peer - .addresses - .addresses() - .iter() - .filter_map(|addr| { - if addr.address().is_empty() { - return None; - } - addr.source.peer_identity_claim().map(|claim| PeerInfoAddress { - address: addr.address().clone(), - peer_identity_claim: claim.clone(), - }) - }) - .collect(), - peer_features: peer.features, - user_agent: peer.user_agent, - supported_protocols: peer.supported_protocols, + claims, } } } + +impl TryFrom for UnvalidatedPeerInfo { + type Error = anyhow::Error; + + fn try_from(value: DiscoveryMessage) -> Result { + let public_key = RistrettoPublicKey::from_bytes(&value.public_key) + .map_err(|e| anyhow!("DiscoveryMessage invalid public key: {}", e))?; + + let features = PeerFeatures::from_bits(value.peer_features) + .ok_or_else(|| anyhow!("Invalid peer features. Bits: {:#04x}", value.peer_features))?; + + let identity_signature = value + .identity_signature + .ok_or_else(|| anyhow!("DiscoveryMessage missing peer_identity_claim"))? + .try_into()?; + let identity_claim = PeerIdentityClaim { + addresses: value + .addresses + .into_iter() + .map(Multiaddr::try_from) + .collect::>()?, + features, + signature: identity_signature, + }; + + Ok(Self { + public_key, + claims: vec![identity_claim], + }) + } +} + +impl TryFrom for UnvalidatedPeerInfo { + type Error = anyhow::Error; + + fn try_from(value: DiscoveryResponseMessage) -> Result { + let public_key = RistrettoPublicKey::from_bytes(&value.public_key) + .map_err(|e| anyhow!("DiscoveryMessage invalid public key: {}", e))?; + + let features = PeerFeatures::from_bits(value.peer_features) + .ok_or_else(|| anyhow!("Invalid peer features. Bits: {:#04x}", value.peer_features))?; + + let identity_signature = value + .identity_signature + .ok_or_else(|| anyhow!("DiscoveryMessage missing peer_identity_claim"))? + .try_into()?; + + let identity_claim = PeerIdentityClaim { + addresses: value + .addresses + .into_iter() + .map(Multiaddr::try_from) + .collect::>()?, + features, + signature: identity_signature, + }; + + Ok(Self { + public_key, + claims: vec![identity_claim], + }) + } +} + +impl TryFrom for UnvalidatedPeerInfo { + type Error = anyhow::Error; + + fn try_from(value: JoinMessage) -> Result { + let public_key = RistrettoPublicKey::from_bytes(&value.public_key) + .map_err(|e| anyhow!("JoinMessage invalid public key: {}", e))?; + + let features = PeerFeatures::from_bits(value.peer_features) + .ok_or_else(|| anyhow!("Invalid peer features. Bits: {:#04x}", value.peer_features))?; + + let identity_signature = value + .identity_signature + .ok_or_else(|| anyhow!("JoinMessage missing peer_identity_claim"))? + .try_into()?; + + let identity_claim = PeerIdentityClaim { + addresses: value + .addresses + .into_iter() + .map(Multiaddr::try_from) + .collect::>()?, + features, + signature: identity_signature, + }; + + Ok(Self { + public_key, + claims: vec![identity_claim], + }) + } +} diff --git a/comms/dht/src/rpc/service.rs b/comms/dht/src/rpc/service.rs index d46df524b0..55c2c653f3 100644 --- a/comms/dht/src/rpc/service.rs +++ b/comms/dht/src/rpc/service.rs @@ -20,7 +20,7 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use std::{cmp, sync::Arc}; +use std::{cmp, convert::TryInto, sync::Arc}; use log::*; use tari_comms::{ @@ -34,7 +34,7 @@ use tokio::{sync::mpsc, task}; use crate::{ proto::rpc::{GetCloserPeersRequest, GetPeersRequest, GetPeersResponse}, - rpc::{DhtRpcService, PeerInfo}, + rpc::{DhtRpcService, UnvalidatedPeerInfo}, }; const LOG_TARGET: &str = "comms::dht::rpc"; @@ -51,7 +51,12 @@ impl DhtRpcServiceImpl { Self { peer_manager } } - pub fn stream_peers(&self, peers: Vec) -> Streaming { + pub fn stream_peers( + &self, + peers: Vec, + max_claims: usize, + max_addresses_per_claim: usize, + ) -> Streaming { if peers.is_empty() { return Streaming::empty(); } @@ -63,9 +68,10 @@ impl DhtRpcServiceImpl { let iter = peers .into_iter() .filter_map(|peer| { - let peer_info: PeerInfo = peer.into(); + let peer_info = + UnvalidatedPeerInfo::from_peer_limited_claims(peer, max_claims, max_addresses_per_claim); - if peer_info.addresses.is_empty() { + if peer_info.claims.is_empty() { None } else { Some(GetPeersResponse { @@ -100,6 +106,24 @@ impl DhtRpcService for DhtRpcServiceImpl { ))); } + let max_claims = message.max_claims.try_into().map_err(|_| + // This can't happen on a >= 32-bit arch + RpcStatus::bad_request("max_claims is too large"))?; + + if max_claims == 0 { + return Err(RpcStatus::bad_request("max_claims must be greater than zero")); + } + + let max_addresses_per_claim = message.max_addresses_per_claim.try_into().map_err(|_| + // This can't happen on a >= 32-bit arch + RpcStatus::bad_request("max_addresses_per_claim is too large"))?; + + if max_addresses_per_claim == 0 { + return Err(RpcStatus::bad_request( + "max_addresses_per_claim must be greater than zero", + )); + } + let node_id = if message.closer_to.is_empty() { request.context().peer_node_id().clone() } else { @@ -146,7 +170,7 @@ impl DhtRpcService for DhtRpcServiceImpl { node_id.short_str() ); - Ok(self.stream_peers(peers)) + Ok(self.stream_peers(peers, max_claims, max_addresses_per_claim)) } async fn get_peers(&self, request: Request) -> Result, RpcStatus> { @@ -157,6 +181,22 @@ impl DhtRpcService for DhtRpcServiceImpl { features = None; } + let max_claims = message.max_claims.try_into().map_err(|_| + // This can't happen on a >= 32-bit arch + RpcStatus::bad_request("max_claims is too large"))?; + if max_claims == 0 { + return Err(RpcStatus::bad_request("max_claims must be greater than zero")); + } + let max_addresses_per_claim = message.max_addresses_per_claim.try_into().map_err(|_| + // This can't happen on a >= 32-bit arch + RpcStatus::bad_request("max_addresses_per_claim is too large"))?; + + if max_addresses_per_claim == 0 { + return Err(RpcStatus::bad_request( + "max_addresses_per_claim must be greater than zero", + )); + } + let peers = self .peer_manager .discovery_syncing(message.n as usize, &excluded_peers, features) @@ -176,6 +216,6 @@ impl DhtRpcService for DhtRpcServiceImpl { node_id.short_str() ); - Ok(self.stream_peers(peers)) + Ok(self.stream_peers(peers, max_claims, max_addresses_per_claim)) } } diff --git a/comms/dht/src/rpc/test.rs b/comms/dht/src/rpc/test.rs index 5f74629bb3..5d245586cf 100644 --- a/comms/dht/src/rpc/test.rs +++ b/comms/dht/src/rpc/test.rs @@ -51,7 +51,7 @@ mod get_closer_peers { use std::borrow::BorrowMut; use super::*; - use crate::rpc::PeerInfo; + use crate::rpc::UnvalidatedPeerInfo; #[tokio::test] async fn it_returns_empty_peer_stream() { @@ -62,6 +62,8 @@ mod get_closer_peers { excluded: vec![], closer_to: node_identity.node_id().to_vec(), include_clients: false, + max_claims: 5, + max_addresses_per_claim: 5, }; let req = mock.request_with_context(node_identity.node_id().clone(), req); @@ -89,6 +91,8 @@ mod get_closer_peers { excluded: vec![], closer_to: node_identity.node_id().to_vec(), include_clients: false, + max_claims: 5, + max_addresses_per_claim: 5, }; let req = mock.request_with_context(node_identity.node_id().clone(), req); @@ -101,7 +105,7 @@ mod get_closer_peers { .map(Result::unwrap) .map(|r| r.peer.unwrap()) .map(|p| p.try_into().unwrap()) - .collect::>(); + .collect::>(); let mut dist = NodeDistance::zero(); for p in &peers { @@ -130,6 +134,8 @@ mod get_closer_peers { excluded: vec![], closer_to: node_identity.node_id().to_vec(), include_clients: false, + max_claims: 5, + max_addresses_per_claim: 5, }; let req = mock.request_with_context(node_identity.node_id().clone(), req); @@ -158,6 +164,8 @@ mod get_closer_peers { excluded: vec![excluded_peer.node_id().to_vec()], closer_to: node_identity.node_id().to_vec(), include_clients: true, + max_claims: 5, + max_addresses_per_claim: 5, }; let req = mock.request_with_context(node_identity.node_id().clone(), req); @@ -188,7 +196,7 @@ mod get_peers { use tari_comms::test_utils::node_identity::build_many_node_identities; use super::*; - use crate::{proto::rpc::GetPeersRequest, rpc::PeerInfo}; + use crate::{proto::rpc::GetPeersRequest, rpc::UnvalidatedPeerInfo}; #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn it_returns_empty_peer_stream() { @@ -197,6 +205,8 @@ mod get_peers { let req = GetPeersRequest { n: 10, include_clients: false, + max_claims: 5, + max_addresses_per_claim: 5, }; let req = mock.request_with_context(node_identity.node_id().clone(), req); @@ -222,6 +232,8 @@ mod get_peers { let req = GetPeersRequest { n: 5, include_clients: true, + max_claims: 5, + max_addresses_per_claim: 5, }; let peers_stream = service @@ -236,10 +248,10 @@ mod get_peers { .map(Result::unwrap) .map(|r| r.peer.unwrap()) .map(|p| p.try_into().unwrap()) - .collect::>(); + .collect::>(); - assert_eq!(peers.iter().filter(|p| p.peer_features.is_client()).count(), 2); - assert_eq!(peers.iter().filter(|p| p.peer_features.is_node()).count(), 3); + assert_eq!(peers.iter().filter(|p| p.claims[0].features.is_client()).count(), 2); + assert_eq!(peers.iter().filter(|p| p.claims[0].features.is_node()).count(), 3); } #[tokio::test(flavor = "multi_thread", worker_threads = 1)] @@ -258,6 +270,8 @@ mod get_peers { let req = GetPeersRequest { n: 3, include_clients: false, + max_claims: 5, + max_addresses_per_claim: 5, }; let peers_stream = service @@ -272,9 +286,9 @@ mod get_peers { .map(Result::unwrap) .map(|r| r.peer.unwrap()) .map(|p| p.try_into().unwrap()) - .collect::>(); + .collect::>(); - assert!(peers.iter().all(|p| p.peer_features.is_node())); + assert!(peers.iter().all(|p| p.claims[0].features.is_node())); } #[tokio::test(flavor = "multi_thread", worker_threads = 1)] @@ -294,6 +308,8 @@ mod get_peers { let req = GetPeersRequest { n: 2, include_clients: false, + max_claims: 5, + max_addresses_per_claim: 5, }; let req = mock.request_with_context(node_identity.node_id().clone(), req); diff --git a/comms/dht/src/test_utils/dht_actor_mock.rs b/comms/dht/src/test_utils/dht_actor_mock.rs index 08c1d56ed9..b8714ffa36 100644 --- a/comms/dht/src/test_utils/dht_actor_mock.rs +++ b/comms/dht/src/test_utils/dht_actor_mock.rs @@ -137,6 +137,7 @@ impl DhtActorMock { reply_tx.send(Ok(())).unwrap(); }, DialDiscoverPeer { .. } => unimplemented!(), + BanPeer { .. } => unimplemented!(), } } } diff --git a/comms/dht/tests/attacks.rs b/comms/dht/tests/attacks.rs new file mode 100644 index 0000000000..c7c460fc87 --- /dev/null +++ b/comms/dht/tests/attacks.rs @@ -0,0 +1,179 @@ +// // Copyright 2023. The Tari Project +// // +// // Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// // following conditions are met: +// // +// // 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the +// following // disclaimer. +// // +// // 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// // following disclaimer in the documentation and/or other materials provided with the distribution. +// // +// // 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// // products derived from this software without specific prior written permission. +// // +// // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// // INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// // DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +mod harness; +use std::{iter, time::Duration}; + +use harness::*; +use rand::{rngs::OsRng, Rng, RngCore}; +use tari_comms::{ + peer_manager::{IdentitySignature, PeerFeatures}, + NodeIdentity, +}; +use tari_comms_dht::{envelope::DhtMessageType, outbound::SendMessageParams}; +use tari_test_utils::async_assert_eventually; +use tari_utilities::ByteArray; + +#[tokio::test(flavor = "multi_thread")] +async fn large_join_messages_with_many_addresses() { + // Create 3 nodes where only Node B knows A and C, but A and C want to talk to each other + + // Node C knows no one + let node_c = make_node("node_C", PeerFeatures::COMMUNICATION_NODE, dht_config(), None).await; + // Node B knows about Node C + let node_b = make_node( + "node_B", + PeerFeatures::COMMUNICATION_NODE, + dht_config(), + Some(node_c.to_peer()), + ) + .await; + // Node A knows about Node B + let node_a = make_node( + "node_A", + PeerFeatures::COMMUNICATION_NODE, + dht_config(), + Some(node_b.to_peer()), + ) + .await; + + node_a + .comms + .connectivity() + .wait_for_connectivity(Duration::from_secs(10)) + .await + .unwrap(); + node_b + .comms + .connectivity() + .wait_for_connectivity(Duration::from_secs(10)) + .await + .unwrap(); + + let addresses = iter::repeat_with(random_multiaddr_bytes) + .take(900 * 1024 / 32) + .collect::>(); + let node_identity = (*node_a.node_identity()).clone(); + let message = JoinMessage::from_node_identity(&node_identity, addresses); + + node_a + .dht + .outbound_requester() + .send_message_no_header( + SendMessageParams::new() + .direct_node_id(node_b.node_identity().node_id().clone()) + .with_destination(node_c.comms.node_identity().public_key().clone().into()) + .with_dht_message_type(DhtMessageType::Join) + .force_origin() + .finish(), + message, + ) + .await + .unwrap(); + + let node_b_peer_manager = node_b.comms.peer_manager(); + let node_c_peer_manager = node_c.comms.peer_manager(); + + // Check that Node B bans node A + async_assert_eventually!( + node_b_peer_manager + .is_peer_banned(node_a.node_identity().node_id()) + .await + .unwrap(), + expect = true, + max_attempts = 10, + interval = Duration::from_secs(1) + ); + // Node B did not propagate + assert!(!node_c_peer_manager.exists(node_a.node_identity().public_key()).await); + + node_a.shutdown().await; + node_b.shutdown().await; + node_c.shutdown().await; +} + +// Copies of non-public the JoinMessage and IdentitySignature structs too allow this test to manipulate them +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct JoinMessage { + #[prost(bytes = "vec", tag = "1")] + pub public_key: Vec, + #[prost(bytes = "vec", repeated, tag = "2")] + pub addresses: Vec>, + #[prost(uint32, tag = "3")] + pub peer_features: u32, + #[prost(uint64, tag = "4")] + pub nonce: u64, + #[prost(message, optional, tag = "5")] + pub identity_signature: Option, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct IdentitySignatureProto { + #[prost(uint32, tag = "1")] + pub version: u32, + #[prost(bytes = "vec", tag = "2")] + pub signature: Vec, + #[prost(bytes = "vec", tag = "3")] + pub public_nonce: Vec, + /// The EPOCH timestamp used in the identity signature challenge + #[prost(int64, tag = "4")] + pub updated_at: i64, +} + +impl JoinMessage { + fn from_node_identity(node_identity: &NodeIdentity, raw_addresses: Vec>) -> Self { + Self { + public_key: node_identity.public_key().to_vec(), + addresses: raw_addresses, + peer_features: node_identity.features().bits(), + nonce: OsRng.next_u64(), + identity_signature: node_identity.identity_signature_read().as_ref().map(Into::into), + } + } +} + +impl From<&IdentitySignature> for IdentitySignatureProto { + fn from(identity_sig: &IdentitySignature) -> Self { + Self { + version: u32::from(identity_sig.version()), + signature: identity_sig.signature().get_signature().to_vec(), + public_nonce: identity_sig.signature().get_public_nonce().to_vec(), + updated_at: identity_sig.updated_at().timestamp(), + } + } +} + +fn random_port() -> u16 { + let mut rng = rand::thread_rng(); + rng.gen_range(1024..=65535) +} + +fn random_multiaddr_bytes() -> Vec { + let port = random_port(); + let mut rng = rand::thread_rng(); + + let mut bytes = Vec::with_capacity(7); + bytes.push(4); // IP4 code + bytes.extend([rng.gen::(), rng.gen(), rng.gen(), rng.gen()]); + bytes.push(6); // TCP code + bytes.extend(&port.to_be_bytes()); + + bytes +} diff --git a/comms/dht/tests/dht.rs b/comms/dht/tests/dht.rs index 2a4192fa83..fe0dbf2fc7 100644 --- a/comms/dht/tests/dht.rs +++ b/comms/dht/tests/dht.rs @@ -20,215 +20,23 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +mod harness; use std::{sync::Arc, time::Duration}; -use rand::rngs::OsRng; +use harness::*; use tari_comms::{ - backoff::ConstantBackoff, connectivity::ConnectivityEvent, message::MessageExt, - peer_manager::{NodeId, NodeIdentity, Peer, PeerFeatures}, - pipeline, - pipeline::SinkService, - protocol::messaging::{MessagingEvent, MessagingEventSender, MessagingProtocolExtension}, - transports::MemoryTransport, - types::CommsDatabase, - CommsBuilder, - CommsNode, + peer_manager::{NodeId, PeerFeatures}, + protocol::messaging::MessagingEvent, }; use tari_comms_dht::{ domain_message::OutboundDomainMessage, envelope::{DhtMessageType, NodeDestination}, event::DhtEvent, - inbound::DecryptedDhtMessage, outbound::{OutboundEncryption, SendMessageParams}, - DbConnectionUrl, - Dht, - DhtConfig, }; -use tari_shutdown::{Shutdown, ShutdownSignal}; -use tari_storage::{ - lmdb_store::{LMDBBuilder, LMDBConfig}, - LMDBWrapper, -}; -use tari_test_utils::{ - async_assert_eventually, - collect_try_recv, - paths::create_temporary_data_path, - random, - streams, - unpack_enum, -}; -use tokio::{ - sync::{broadcast, mpsc}, - time, -}; -use tower::ServiceBuilder; - -struct TestNode { - name: String, - comms: CommsNode, - dht: Dht, - inbound_messages: mpsc::Receiver, - messaging_events: broadcast::Sender>, - shutdown: Shutdown, -} - -impl TestNode { - pub fn node_identity(&self) -> Arc { - self.comms.node_identity() - } - - pub fn to_peer(&self) -> Peer { - self.comms.node_identity().to_peer() - } - - pub fn name(&self) -> &str { - &self.name - } - - pub async fn next_inbound_message(&mut self, timeout: Duration) -> Option { - time::timeout(timeout, self.inbound_messages.recv()).await.ok()? - } - - pub async fn shutdown(mut self) { - self.shutdown.trigger(); - self.comms.wait_until_shutdown().await; - } -} - -fn make_node_identity(features: PeerFeatures) -> Arc { - let port = MemoryTransport::acquire_next_memsocket_port(); - Arc::new(NodeIdentity::random( - &mut OsRng, - format!("/memory/{}", port).parse().unwrap(), - features, - )) -} - -fn create_peer_storage() -> CommsDatabase { - let database_name = random::string(8); - let datastore = LMDBBuilder::new() - .set_path(create_temporary_data_path()) - .set_env_config(LMDBConfig::default()) - .set_max_number_of_databases(1) - .add_database(&database_name, lmdb_zero::db::CREATE) - .build() - .unwrap(); - - let peer_database = datastore.get_handle(&database_name).unwrap(); - LMDBWrapper::new(Arc::new(peer_database)) -} - -async fn make_node>( - name: &str, - features: PeerFeatures, - dht_config: DhtConfig, - known_peers: I, -) -> TestNode { - let node_identity = make_node_identity(features); - make_node_with_node_identity(name, node_identity, dht_config, known_peers).await -} - -async fn make_node_with_node_identity>( - name: &str, - node_identity: Arc, - dht_config: DhtConfig, - known_peers: I, -) -> TestNode { - let (tx, inbound_messages) = mpsc::channel(10); - let shutdown = Shutdown::new(); - let (comms, dht, messaging_events) = setup_comms_dht( - node_identity, - create_peer_storage(), - tx, - known_peers.into_iter().collect(), - dht_config, - shutdown.to_signal(), - ) - .await; - - TestNode { - name: name.to_string(), - comms, - dht, - inbound_messages, - messaging_events, - shutdown, - } -} - -async fn setup_comms_dht( - node_identity: Arc, - storage: CommsDatabase, - inbound_tx: mpsc::Sender, - peers: Vec, - dht_config: DhtConfig, - shutdown_signal: ShutdownSignal, -) -> (CommsNode, Dht, MessagingEventSender) { - // Create inbound and outbound channels - let (outbound_tx, outbound_rx) = mpsc::channel(10); - - let comms = CommsBuilder::new() - .allow_test_addresses() - // In this case the listener address and the public address are the same (/memory/...) - .with_listener_address(node_identity.first_public_address().unwrap()) - .with_shutdown_signal(shutdown_signal) - .with_node_identity(node_identity) - .with_peer_storage(storage,None) - .with_min_connectivity(1) - .with_dial_backoff(ConstantBackoff::new(Duration::from_millis(100))) - .build() - .unwrap(); - - let dht = Dht::builder() - .with_config(dht_config) - .with_database_url(DbConnectionUrl::MemoryShared(random::string(8))) - .with_outbound_sender(outbound_tx) - .build( - comms.node_identity(), - comms.peer_manager(), - comms.connectivity(), - comms.shutdown_signal(), - ) - .await - .unwrap(); - - for peer in peers { - comms.peer_manager().add_peer(peer).await.unwrap(); - } - - let dht_outbound_layer = dht.outbound_middleware_layer(); - let pipeline = pipeline::Builder::new() - .with_outbound_pipeline(outbound_rx, |sink| { - ServiceBuilder::new().layer(dht_outbound_layer).service(sink) - }) - .max_concurrent_inbound_tasks(10) - .with_inbound_pipeline( - ServiceBuilder::new() - .layer(dht.inbound_middleware_layer()) - .service(SinkService::new(inbound_tx)), - ) - .build(); - - let (event_tx, _) = broadcast::channel(100); - let comms = comms - .add_protocol_extension(MessagingProtocolExtension::new(event_tx.clone(), pipeline)) - .spawn_with_transport(MemoryTransport) - .await - .unwrap(); - - (comms, dht, event_tx) -} - -fn dht_config() -> DhtConfig { - let mut config = DhtConfig::default_local_test(); - config.allow_test_addresses = true; - config.saf.auto_request = false; - config.discovery_request_timeout = Duration::from_secs(60); - config.num_neighbouring_nodes = 8; - config -} +use tari_test_utils::{async_assert_eventually, collect_try_recv, streams, unpack_enum}; #[tokio::test(flavor = "multi_thread", worker_threads = 1)] #[allow(non_snake_case)] diff --git a/comms/dht/tests/harness.rs b/comms/dht/tests/harness.rs new file mode 100644 index 0000000000..bc9abbc8d5 --- /dev/null +++ b/comms/dht/tests/harness.rs @@ -0,0 +1,215 @@ +// // Copyright 2023. The Tari Project +// // +// // Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// // following conditions are met: +// // +// // 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the +// following // disclaimer. +// // +// // 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// // following disclaimer in the documentation and/or other materials provided with the distribution. +// // +// // 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// // products derived from this software without specific prior written permission. +// // +// // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// // INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// // DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +use std::{sync::Arc, time::Duration}; + +use rand::rngs::OsRng; +use tari_comms::{ + backoff::ConstantBackoff, + peer_manager::{NodeIdentity, Peer, PeerFeatures}, + pipeline, + pipeline::SinkService, + protocol::messaging::{MessagingEvent, MessagingEventSender, MessagingProtocolExtension}, + transports::MemoryTransport, + types::CommsDatabase, + CommsBuilder, + CommsNode, +}; +use tari_comms_dht::{inbound::DecryptedDhtMessage, DbConnectionUrl, Dht, DhtConfig}; +use tari_shutdown::{Shutdown, ShutdownSignal}; +use tari_storage::{ + lmdb_store::{LMDBBuilder, LMDBConfig}, + LMDBWrapper, +}; +use tari_test_utils::{paths::create_temporary_data_path, random}; +use tokio::{ + sync::{broadcast, mpsc}, + time, +}; +use tower::ServiceBuilder; + +pub struct TestNode { + pub name: String, + pub comms: CommsNode, + pub dht: Dht, + pub inbound_messages: mpsc::Receiver, + pub messaging_events: broadcast::Sender>, + pub shutdown: Shutdown, +} + +impl TestNode { + pub fn node_identity(&self) -> Arc { + self.comms.node_identity() + } + + pub fn to_peer(&self) -> Peer { + self.comms.node_identity().to_peer() + } + + #[allow(dead_code)] + pub fn name(&self) -> &str { + &self.name + } + + #[allow(dead_code)] + pub async fn next_inbound_message(&mut self, timeout: Duration) -> Option { + time::timeout(timeout, self.inbound_messages.recv()).await.ok()? + } + + pub async fn shutdown(mut self) { + self.shutdown.trigger(); + self.comms.wait_until_shutdown().await; + } +} + +pub fn make_node_identity(features: PeerFeatures) -> Arc { + let port = MemoryTransport::acquire_next_memsocket_port(); + Arc::new(NodeIdentity::random( + &mut OsRng, + format!("/memory/{}", port).parse().unwrap(), + features, + )) +} + +pub fn create_peer_storage() -> CommsDatabase { + let database_name = random::string(8); + let datastore = LMDBBuilder::new() + .set_path(create_temporary_data_path()) + .set_env_config(LMDBConfig::default()) + .set_max_number_of_databases(1) + .add_database(&database_name, lmdb_zero::db::CREATE) + .build() + .unwrap(); + + let peer_database = datastore.get_handle(&database_name).unwrap(); + LMDBWrapper::new(Arc::new(peer_database)) +} + +pub async fn make_node>( + name: &str, + features: PeerFeatures, + dht_config: DhtConfig, + known_peers: I, +) -> TestNode { + let node_identity = make_node_identity(features); + make_node_with_node_identity(name, node_identity, dht_config, known_peers).await +} + +pub async fn make_node_with_node_identity>( + name: &str, + node_identity: Arc, + dht_config: DhtConfig, + known_peers: I, +) -> TestNode { + let (tx, inbound_messages) = mpsc::channel(10); + let shutdown = Shutdown::new(); + let (comms, dht, messaging_events) = setup_comms_dht( + node_identity, + create_peer_storage(), + tx, + known_peers.into_iter().collect(), + dht_config, + shutdown.to_signal(), + ) + .await; + + TestNode { + name: name.to_string(), + comms, + dht, + inbound_messages, + messaging_events, + shutdown, + } +} + +pub async fn setup_comms_dht( + node_identity: Arc, + storage: CommsDatabase, + inbound_tx: mpsc::Sender, + peers: Vec, + dht_config: DhtConfig, + shutdown_signal: ShutdownSignal, +) -> (CommsNode, Dht, MessagingEventSender) { + // Create inbound and outbound channels + let (outbound_tx, outbound_rx) = mpsc::channel(10); + + let comms = CommsBuilder::new() + .allow_test_addresses() + // In this case the listener address and the public address are the same (/memory/...) + .with_listener_address(node_identity.first_public_address().unwrap()) + .with_shutdown_signal(shutdown_signal) + .with_node_identity(node_identity) + .with_peer_storage(storage,None) + .with_min_connectivity(1) + .with_dial_backoff(ConstantBackoff::new(Duration::from_millis(100))) + .build() + .unwrap(); + + let dht = Dht::builder() + .with_config(dht_config) + .with_database_url(DbConnectionUrl::MemoryShared(random::string(8))) + .with_outbound_sender(outbound_tx) + .build( + comms.node_identity(), + comms.peer_manager(), + comms.connectivity(), + comms.shutdown_signal(), + ) + .await + .unwrap(); + + for peer in peers { + comms.peer_manager().add_peer(peer).await.unwrap(); + } + + let dht_outbound_layer = dht.outbound_middleware_layer(); + let pipeline = pipeline::Builder::new() + .with_outbound_pipeline(outbound_rx, |sink| { + ServiceBuilder::new().layer(dht_outbound_layer).service(sink) + }) + .max_concurrent_inbound_tasks(10) + .with_inbound_pipeline( + ServiceBuilder::new() + .layer(dht.inbound_middleware_layer()) + .service(SinkService::new(inbound_tx)), + ) + .build(); + + let (event_tx, _) = broadcast::channel(100); + let comms = comms + .add_protocol_extension(MessagingProtocolExtension::new(event_tx.clone(), pipeline)) + .spawn_with_transport(MemoryTransport) + .await + .unwrap(); + + (comms, dht, event_tx) +} + +pub fn dht_config() -> DhtConfig { + let mut config = DhtConfig::default_local_test(); + config.peer_validator_config.allow_test_addresses = true; + config.saf.auto_request = false; + config.discovery_request_timeout = Duration::from_secs(60); + config.num_neighbouring_nodes = 8; + config +}