diff --git a/base_layer/core/src/base_node/state_machine_service/states/header_sync.rs b/base_layer/core/src/base_node/state_machine_service/states/header_sync.rs index 7f85ed90c4..d19195b380 100644 --- a/base_layer/core/src/base_node/state_machine_service/states/header_sync.rs +++ b/base_layer/core/src/base_node/state_machine_service/states/header_sync.rs @@ -20,7 +20,7 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use std::{cmp::Ordering, time::Instant}; +use std::{cmp::Ordering, mem, time::Instant}; use log::*; use tari_common_types::chain_metadata::ChainMetadata; @@ -81,7 +81,7 @@ impl HeaderSyncState { shared.db.clone(), shared.consensus_rules.clone(), shared.connectivity.clone(), - &mut self.sync_peers, + mem::take(&mut self.sync_peers), shared.randomx_factory.clone(), &self.local_metadata, ); diff --git a/base_layer/core/src/base_node/state_machine_service/states/horizon_state_sync.rs b/base_layer/core/src/base_node/state_machine_service/states/horizon_state_sync.rs index 6cdd482697..87cb1a4813 100644 --- a/base_layer/core/src/base_node/state_machine_service/states/horizon_state_sync.rs +++ b/base_layer/core/src/base_node/state_machine_service/states/horizon_state_sync.rs @@ -91,7 +91,7 @@ impl HorizonStateSync { db, connectivity, rules, - &sync_peers, + sync_peers, horizon_sync_height, prover, validator, diff --git a/base_layer/core/src/base_node/sync/ban.rs b/base_layer/core/src/base_node/sync/ban.rs new file mode 100644 index 0000000000..9f3da126fd --- /dev/null +++ b/base_layer/core/src/base_node/sync/ban.rs @@ -0,0 +1,65 @@ +// Copyright 2023, The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
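+
+//! Centralised ban handling shared by the block, header, and horizon state synchronizers.
+//!
+//! A minimal usage sketch (`config`, `connectivity`, `node_id` and `err` are placeholders assumed to
+//! come from the surrounding synchronizer, as in the synchronizers updated below):
+//!
+//! ```ignore
+//! let mut ban_manager = PeerBanManager::new(config.clone(), connectivity.clone());
+//! let ban_reason = BlockSyncError::get_ban_reason(&err, config.short_ban_period, config.ban_period);
+//! ban_manager.ban_peer_if_required(&node_id, &ban_reason).await;
+//! ```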
+
+use log::*;
+use tari_comms::{connectivity::ConnectivityRequester, peer_manager::NodeId};
+
+use crate::{base_node::BlockchainSyncConfig, common::BanReason};
+
+const LOG_TARGET: &str = "c::bn::sync";
+
+// Sync peers are banned if there exists a ban reason for the error and the peer is not on the allow list for sync.
+
+pub struct PeerBanManager {
+    config: BlockchainSyncConfig,
+    connectivity: ConnectivityRequester,
+}
+
+impl PeerBanManager {
+    pub fn new(config: BlockchainSyncConfig, connectivity: ConnectivityRequester) -> Self {
+        Self { config, connectivity }
+    }
+
+    pub async fn ban_peer_if_required(&mut self, node_id: &NodeId, ban_reason: &Option<BanReason>) {
+        if let Some(ban) = ban_reason {
+            if self.config.forced_sync_peers.contains(node_id) {
+                debug!(
+                    target: LOG_TARGET,
+                    "Not banning peer that is on the allow list for sync. Ban reason = {}", ban.reason()
+                );
+                return;
+            }
+            debug!(target: LOG_TARGET, "Sync peer {} removed from the sync peer list because {}", node_id, ban.reason());
+
+            match self
+                .connectivity
+                .ban_peer_until(node_id.clone(), ban.ban_duration, ban.reason().to_string())
+                .await
+            {
+                Ok(_) => {
+                    warn!(target: LOG_TARGET, "Banned sync peer {} for {:?} because {}", node_id, ban.ban_duration, ban.reason())
+                },
+                Err(err) => error!(target: LOG_TARGET, "Failed to ban sync peer {}: {}", node_id, err),
+            }
+        }
+    }
+}
diff --git a/base_layer/core/src/base_node/sync/block_sync/synchronizer.rs b/base_layer/core/src/base_node/sync/block_sync/synchronizer.rs
index 99cc87d2b1..1ebf52cc24 100644
--- a/base_layer/core/src/base_node/sync/block_sync/synchronizer.rs
+++ b/base_layer/core/src/base_node/sync/block_sync/synchronizer.rs
@@ -37,12 +37,12 @@ use tracing;
 use super::error::BlockSyncError;
 use crate::{
     base_node::{
-        sync::{hooks::Hooks, rpc, SyncPeer},
+        sync::{ban::PeerBanManager, hooks::Hooks, rpc, SyncPeer},
         BlockchainSyncConfig,
     },
     blocks::{Block, ChainBlock},
     chain_storage::{async_db::AsyncBlockchainDb, BlockchainBackend},
-    common::{rolling_avg::RollingAverageTime, BanReason},
+    common::rolling_avg::RollingAverageTime,
     proto::base_node::SyncBlocksRequest,
     transactions::aggregated_body::AggregateBody,
     validation::{BlockBodyValidator, ValidationError},
@@ -57,6 +57,7 @@ pub struct BlockSynchronizer<B> {
     sync_peers: Vec<SyncPeer>,
     block_validator: Arc<dyn BlockBodyValidator<B>>,
     hooks: Hooks,
+    peer_ban_manager: PeerBanManager,
 }
 
 impl<B: BlockchainBackend + 'static> BlockSynchronizer<B> {
@@ -67,6 +68,7 @@
         sync_peers: Vec<SyncPeer>,
         block_validator: Arc<dyn BlockBodyValidator<B>>,
     ) -> Self {
+        let peer_ban_manager = PeerBanManager::new(config.clone(), connectivity.clone());
         Self {
             config,
             db,
@@ -74,6 +76,7 @@
             sync_peers,
             block_validator,
             hooks: Default::default(),
+            peer_ban_manager,
         }
     }
 
@@ -175,27 +178,24 @@
         );
         match self.synchronize_blocks(sync_peer, client, max_latency).await {
             Ok(_) => return Ok(()),
-
-            Err(err @ BlockSyncError::MaxLatencyExceeded { .. }) => {
-                warn!(target: LOG_TARGET, "{}", err);
-                latency_counter += 1;
-                self.ban_peer_if_required(
-                    node_id,
-                    &BlockSyncError::get_ban_reason(&err, self.config.short_ban_period, self.config.ban_period),
-                )
-                .await;
-                continue;
-            },
-
             Err(err) => {
                 warn!(target: LOG_TARGET, "{}", err);
-                self.ban_peer_if_required(
-                    node_id,
-                    &BlockSyncError::get_ban_reason(&err, self.config.short_ban_period, self.config.ban_period),
-                )
-                .await;
-                self.remove_sync_peer(node_id);
-                continue;
+                let ban_reason =
+                    BlockSyncError::get_ban_reason(&err, self.config.short_ban_period, self.config.ban_period);
+                if let Some(reason) = ban_reason {
+                    self.peer_ban_manager
+                        .ban_peer_if_required(node_id, &Some(reason.clone()))
+                        .await;
+
+                    if reason.ban_duration > self.config.short_ban_period {
+                        self.remove_sync_peer(node_id);
+                    }
+                }
+
+                if let BlockSyncError::MaxLatencyExceeded { .. } = err {
+                    latency_counter += 1;
+                }
             },
         }
     }
@@ -413,32 +413,7 @@
         Ok(())
     }
 
-    // Sync peers are banned if there exists a ban reason for the error and the peer is not on the allow list for sync.
     // Sync peers are also removed from the list of sync peers if the ban duration is longer than the short ban period.
-    async fn ban_peer_if_required(&mut self, node_id: &NodeId, ban_reason: &Option<BanReason>) {
-        if let Some(ban) = ban_reason {
-            if self.config.forced_sync_peers.contains(node_id) {
-                debug!(
-                    target: LOG_TARGET,
-                    "Not banning peer that is on the allow list for sync. Ban reason = {}", ban.reason()
-                );
-                return;
-            }
-            debug!(target: LOG_TARGET, "Sync peer {} removed from the sync peer list because {}", node_id, ban.reason());
-
-            match self
-                .connectivity
-                .ban_peer_until(node_id.clone(), ban.ban_duration, ban.reason().to_string())
-                .await
-            {
-                Ok(_) => {
-                    warn!(target: LOG_TARGET, "Banned sync peer {} for {:?} because {}", node_id, ban.ban_duration, ban.reason())
-                },
-                Err(err) => error!(target: LOG_TARGET, "Failed to ban sync peer {}: {}", node_id, err),
-            }
-        }
-    }
-
     fn remove_sync_peer(&mut self, node_id: &NodeId) {
         if let Some(pos) = self.sync_peers.iter().position(|p| p.node_id() == node_id) {
             self.sync_peers.remove(pos);
diff --git a/base_layer/core/src/base_node/sync/header_sync/error.rs b/base_layer/core/src/base_node/sync/header_sync/error.rs
index eaa0d4c63c..c06b9a6b14 100644
--- a/base_layer/core/src/base_node/sync/header_sync/error.rs
+++ b/base_layer/core/src/base_node/sync/header_sync/error.rs
@@ -28,10 +28,12 @@ use tari_comms::{
     protocol::rpc::{RpcError, RpcStatus},
 };
 
-use crate::{blocks::BlockError, chain_storage::ChainStorageError, validation::ValidationError};
+use crate::{blocks::BlockError, chain_storage::ChainStorageError, common::BanReason, validation::ValidationError};
 
 #[derive(Debug, thiserror::Error)]
 pub enum BlockHeaderSyncError {
+    #[error("No more sync peers available: {0}")]
+    NoMoreSyncPeers(String),
     #[error("RPC error: {0}")]
     RpcError(#[from] RpcError),
     #[error("RPC request failed: {0}")]
     RpcRequestError(#[from] RpcStatus),
@@ -77,6 +79,8 @@ pub enum BlockHeaderSyncError {
         actual: Option<u128>,
         local: u128,
     },
+    #[error("This peer sent too many headers ({0}) in response to a chain split request")]
+    PeerSentTooManyHeaders(usize),
    #[error("Peer {peer} exceeded maximum permitted sync latency. 
latency: {latency:.2?}s, max: {max_latency:.2?}s")] MaxLatencyExceeded { peer: NodeId, @@ -86,3 +90,42 @@ pub enum BlockHeaderSyncError { #[error("All sync peers exceeded max allowed latency")] AllSyncPeersExceedLatency, } + +impl BlockHeaderSyncError { + pub fn get_ban_reason(&self, short_ban: Duration, long_ban: Duration) -> Option { + match self { + // no ban + BlockHeaderSyncError::NoMoreSyncPeers(_) | + BlockHeaderSyncError::RpcError(_) | + BlockHeaderSyncError::RpcRequestError(_) | + BlockHeaderSyncError::SyncFailedAllPeers | + BlockHeaderSyncError::FailedToBan(_) | + BlockHeaderSyncError::AllSyncPeersExceedLatency | + BlockHeaderSyncError::ConnectivityError(_) | + BlockHeaderSyncError::NotInSync | + BlockHeaderSyncError::ChainStorageError(_) => None, + + // short ban + err @ BlockHeaderSyncError::MaxLatencyExceeded { .. } => Some(BanReason { + reason: format!("{}", err), + ban_duration: short_ban, + }), + + // long ban + err @ BlockHeaderSyncError::ReceivedInvalidHeader(_) | + err @ BlockHeaderSyncError::ValidationFailed(_) | + err @ BlockHeaderSyncError::FoundHashIndexOutOfRange(_, _) | + err @ BlockHeaderSyncError::StartHashNotFound(_) | + err @ BlockHeaderSyncError::InvalidBlockHeight { .. } | + err @ BlockHeaderSyncError::ChainSplitNotFound(_) | + err @ BlockHeaderSyncError::InvalidProtocolResponse(_) | + err @ BlockHeaderSyncError::ChainLinkBroken { .. } | + err @ BlockHeaderSyncError::BlockError(_) | + err @ BlockHeaderSyncError::PeerSentInaccurateChainMetadata { .. } | + err @ BlockHeaderSyncError::PeerSentTooManyHeaders(_) => Some(BanReason { + reason: format!("{}", err), + ban_duration: long_ban, + }), + } + } +} diff --git a/base_layer/core/src/base_node/sync/header_sync/synchronizer.rs b/base_layer/core/src/base_node/sync/header_sync/synchronizer.rs index 9374138197..0eef90cb2d 100644 --- a/base_layer/core/src/base_node/sync/header_sync/synchronizer.rs +++ b/base_layer/core/src/base_node/sync/header_sync/synchronizer.rs @@ -31,7 +31,7 @@ use tari_common_types::{chain_metadata::ChainMetadata, types::HashOutput}; use tari_comms::{ connectivity::ConnectivityRequester, peer_manager::NodeId, - protocol::rpc::{RpcClient, RpcError, RpcHandshakeError}, + protocol::rpc::{RpcClient, RpcError}, PeerConnection, }; use tari_utilities::hex::Hex; @@ -39,7 +39,7 @@ use tracing; use super::{validator::BlockHeaderSyncValidator, BlockHeaderSyncError}; use crate::{ - base_node::sync::{hooks::Hooks, rpc, BlockchainSyncConfig, SyncPeer}, + base_node::sync::{ban::PeerBanManager, hooks::Hooks, rpc, BlockchainSyncConfig, SyncPeer}, blocks::{BlockHeader, ChainBlock, ChainHeader}, chain_storage::{async_db::AsyncBlockchainDb, BlockchainBackend}, common::rolling_avg::RollingAverageTime, @@ -49,7 +49,6 @@ use crate::{ base_node as proto, base_node::{FindChainSplitRequest, SyncHeadersRequest}, }, - validation::ValidationError, }; const LOG_TARGET: &str = "c::bn::header_sync"; @@ -61,9 +60,10 @@ pub struct HeaderSynchronizer<'a, B> { db: AsyncBlockchainDb, header_validator: BlockHeaderSyncValidator, connectivity: ConnectivityRequester, - sync_peers: &'a mut [SyncPeer], + sync_peers: Vec, hooks: Hooks, local_metadata: &'a ChainMetadata, + peer_ban_manager: PeerBanManager, } impl<'a, B: BlockchainBackend + 'static> HeaderSynchronizer<'a, B> { @@ -72,10 +72,11 @@ impl<'a, B: BlockchainBackend + 'static> HeaderSynchronizer<'a, B> { db: AsyncBlockchainDb, consensus_rules: ConsensusManager, connectivity: ConnectivityRequester, - sync_peers: &'a mut [SyncPeer], + sync_peers: Vec, randomx_factory: 
RandomXFactory, local_metadata: &'a ChainMetadata, ) -> Self { + let peer_ban_manager = PeerBanManager::new(config.clone(), connectivity.clone()); Self { config, header_validator: BlockHeaderSyncValidator::new(db.clone(), consensus_rules, randomx_factory), @@ -84,6 +85,7 @@ impl<'a, B: BlockchainBackend + 'static> HeaderSynchronizer<'a, B> { sync_peers, hooks: Default::default(), local_metadata, + peer_ban_manager, } } @@ -134,67 +136,41 @@ impl<'a, B: BlockchainBackend + 'static> HeaderSynchronizer<'a, B> { "Attempting to sync headers ({} sync peers)", sync_peer_node_ids.len() ); + let mut latency_counter = 0usize; for (i, node_id) in sync_peer_node_ids.iter().enumerate() { match self.connect_and_attempt_sync(i, node_id, max_latency).await { Ok(peer) => return Ok(peer), - // Try another peer - Err(err @ BlockHeaderSyncError::NotInSync) => { - warn!(target: LOG_TARGET, "{}", err); - }, - - Err(err @ BlockHeaderSyncError::RpcError(RpcError::HandshakeError(RpcHandshakeError::TimedOut))) => { - warn!(target: LOG_TARGET, "{}", err); - self.ban_peer_short(node_id, BanReason::RpcNegotiationTimedOut).await?; - }, - Err(BlockHeaderSyncError::ValidationFailed(err)) => { - warn!(target: LOG_TARGET, "Block header validation failed: {}", err); - self.ban_peer_long(node_id, err.into()).await?; - }, - Err(BlockHeaderSyncError::ChainSplitNotFound(peer)) => { - warn!(target: LOG_TARGET, "Chain split not found for peer {}.", peer); - self.ban_peer_long(&peer, BanReason::ChainSplitNotFound).await?; - }, - Err(ref err @ BlockHeaderSyncError::PeerSentInaccurateChainMetadata { claimed, actual, local }) => { - warn!(target: LOG_TARGET, "{}", err); - self.ban_peer_long(node_id, BanReason::PeerSentInaccurateChainMetadata { - claimed, - actual: actual.unwrap_or(0), - local, - }) - .await?; - }, - Err(BlockHeaderSyncError::ChainLinkBroken { - height, - actual, - expected, - }) => { - let reason = BanReason::ChainLinkBroken { - height, - actual: actual.to_string(), - expected: expected.to_string(), - }; - warn!(target: LOG_TARGET, "Chain link broken: {}", reason); - self.ban_peer_long(node_id, reason).await?; - }, - Err(err @ BlockHeaderSyncError::RpcError(RpcError::ReplyTimeout)) | - Err(err @ BlockHeaderSyncError::MaxLatencyExceeded { .. }) => { - warn!(target: LOG_TARGET, "{}", err); - if i == self.sync_peers.len() - 1 { - return Err(BlockHeaderSyncError::AllSyncPeersExceedLatency); - } - continue; - }, - Err(err) => { - error!( - target: LOG_TARGET, - "Failed to synchronize headers from peer `{}`: {}", node_id, err + let ban_reason = BlockHeaderSyncError::get_ban_reason( + &err, + self.config.short_ban_period, + self.config.ban_period, ); + if let Some(reason) = ban_reason { + warn!(target: LOG_TARGET, "{}", err); + self.peer_ban_manager + .ban_peer_if_required(node_id, &Some(reason.clone())) + .await; + + if reason.ban_duration > self.config.short_ban_period { + self.remove_sync_peer(node_id); + } + } + + if let BlockHeaderSyncError::MaxLatencyExceeded { .. 
} = err { + latency_counter += 1; + } }, } } - Err(BlockHeaderSyncError::SyncFailedAllPeers) + if self.sync_peers.is_empty() { + Err(BlockHeaderSyncError::NoMoreSyncPeers("Header sync failed".to_string())) + } else if latency_counter >= self.sync_peers.len() { + Err(BlockHeaderSyncError::AllSyncPeersExceedLatency) + } else { + Err(BlockHeaderSyncError::SyncFailedAllPeers) + } } async fn connect_and_attempt_sync( @@ -251,35 +227,6 @@ impl<'a, B: BlockchainBackend + 'static> HeaderSynchronizer<'a, B> { Ok(conn) } - async fn ban_peer_long(&mut self, node_id: &NodeId, reason: BanReason) -> Result<(), BlockHeaderSyncError> { - self.ban_peer_for(node_id, reason, self.config.ban_period).await - } - - async fn ban_peer_short(&mut self, node_id: &NodeId, reason: BanReason) -> Result<(), BlockHeaderSyncError> { - self.ban_peer_for(node_id, reason, self.config.short_ban_period).await - } - - async fn ban_peer_for( - &mut self, - node_id: &NodeId, - reason: BanReason, - duration: Duration, - ) -> Result<(), BlockHeaderSyncError> { - if self.config.forced_sync_peers.contains(node_id) { - debug!( - target: LOG_TARGET, - "Not banning peer that is allowlisted for sync. Ban reason = {}", reason - ); - return Ok(()); - } - warn!(target: LOG_TARGET, "Banned sync peer because {}", reason); - self.connectivity - .ban_peer_until(node_id.clone(), duration, reason.to_string()) - .await - .map_err(BlockHeaderSyncError::FailedToBan)?; - Ok(()) - } - #[tracing::instrument(skip(self, client), err)] async fn attempt_sync( &mut self, @@ -421,17 +368,9 @@ impl<'a, B: BlockchainBackend + 'static> HeaderSynchronizer<'a, B> { }, }; if resp.headers.len() > NUM_INITIAL_HEADERS_TO_REQUEST { - self.ban_peer_long(peer, BanReason::PeerSentTooManyHeaders(resp.headers.len())) - .await?; - return Err(BlockHeaderSyncError::NotInSync); + return Err(BlockHeaderSyncError::PeerSentTooManyHeaders(resp.headers.len())); } if resp.fork_hash_index >= block_hashes.len() as u64 { - let _result = self - .ban_peer_long(peer, BanReason::SplitHashGreaterThanHashes { - fork_hash_index: resp.fork_hash_index, - num_block_hashes: block_hashes.len(), - }) - .await; return Err(BlockHeaderSyncError::FoundHashIndexOutOfRange( block_hashes.len() as u64, resp.fork_hash_index, @@ -522,11 +461,6 @@ impl<'a, B: BlockchainBackend + 'static> HeaderSynchronizer<'a, B> { // Basic sanity check that the peer sent tip height greater than the split. 
let split_height = local_tip_header.height().saturating_sub(steps_back); if remote_tip_height < split_height { - self.ban_peer_short(sync_peer.node_id(), BanReason::PeerSentInvalidTipHeight { - actual: remote_tip_height, - expected: split_height, - }) - .await?; return Err(BlockHeaderSyncError::InvalidProtocolResponse(format!( "Peer {} sent invalid remote tip height", sync_peer @@ -830,39 +764,13 @@ impl<'a, B: BlockchainBackend + 'static> HeaderSynchronizer<'a, B> { Ok(()) } -} -#[derive(Debug, thiserror::Error)] -enum BanReason { - #[error("This peer sent too many headers ({0}) in response to a chain split request")] - PeerSentTooManyHeaders(usize), - #[error("This peer sent an invalid tip height {actual} expected a height greater than or equal to {expected}")] - PeerSentInvalidTipHeight { actual: u64, expected: u64 }, - #[error( - "This peer sent a split hash index ({fork_hash_index}) greater than the number of block hashes sent \ - ({num_block_hashes})" - )] - SplitHashGreaterThanHashes { - fork_hash_index: u64, - num_block_hashes: usize, - }, - #[error("Peer sent invalid header: {0}")] - ValidationFailed(#[from] ValidationError), - #[error("Peer could not find the location of a chain split")] - ChainSplitNotFound, - #[error("Peer did not respond timeously during RPC negotiation")] - RpcNegotiationTimedOut, - #[error("Header at height {height} did not form a chain. Expected {actual} to equal the previous hash {expected}")] - ChainLinkBroken { - height: u64, - actual: String, - expected: String, - }, - #[error( - "Peer sent inaccurate chain metadata. Claimed {claimed} but validated difficulty was {actual}, while local \ - was {local}" - )] - PeerSentInaccurateChainMetadata { claimed: u128, actual: u128, local: u128 }, + // Sync peers are also removed from the list of sync peers if the ban duration is longer than the short ban period. + fn remove_sync_peer(&mut self, node_id: &NodeId) { + if let Some(pos) = self.sync_peers.iter().position(|p| p.node_id() == node_id) { + self.sync_peers.remove(pos); + } + } } struct ChainSplitInfo { diff --git a/base_layer/core/src/base_node/sync/horizon_state_sync/error.rs b/base_layer/core/src/base_node/sync/horizon_state_sync/error.rs index 9025bec4d5..9b3563acfd 100644 --- a/base_layer/core/src/base_node/sync/horizon_state_sync/error.rs +++ b/base_layer/core/src/base_node/sync/horizon_state_sync/error.rs @@ -35,6 +35,7 @@ use tokio::task; use crate::{ chain_storage::{ChainStorageError, MmrTree}, + common::BanReason, transactions::transaction_components::TransactionError, validation::ValidationError, }; @@ -88,6 +89,8 @@ pub enum HorizonSyncError { AllSyncPeersExceedLatency, #[error("FixedHash size error: {0}")] FixedHashSizeError(#[from] FixedHashSizeError), + #[error("No more sync peers available: {0}")] + NoMoreSyncPeers(String), } impl From for HorizonSyncError { @@ -101,3 +104,41 @@ impl From for HorizonSyncError { HorizonSyncError::RangeProofError(e.to_string()) } } + +impl HorizonSyncError { + pub fn get_ban_reason(&self, short_ban: Duration, long_ban: Duration) -> Option { + match self { + // no ban + HorizonSyncError::ChainStorageError(_) | + HorizonSyncError::NoSyncPeers | + HorizonSyncError::FailedSyncAllPeers | + HorizonSyncError::AllSyncPeersExceedLatency | + HorizonSyncError::ConnectivityError(_) | + HorizonSyncError::RpcError(_) | + HorizonSyncError::RpcStatus(_) | + HorizonSyncError::NoMoreSyncPeers(_) | + HorizonSyncError::JoinError(_) => None, + + // short ban + err @ HorizonSyncError::MaxLatencyExceeded { .. 
} => Some(BanReason {
+                reason: format!("{}", err),
+                ban_duration: short_ban,
+            }),
+
+            // long ban
+            err @ HorizonSyncError::IncorrectResponse(_) |
+            err @ HorizonSyncError::FinalStateValidationFailed(_) |
+            err @ HorizonSyncError::RangeProofError(_) |
+            err @ HorizonSyncError::InvalidMmrRoot { .. } |
+            err @ HorizonSyncError::InvalidMmrPosition { .. } |
+            err @ HorizonSyncError::ConversionError(_) |
+            err @ HorizonSyncError::MerkleMountainRangeError(_) |
+            err @ HorizonSyncError::ValidationError(_) |
+            err @ HorizonSyncError::FixedHashSizeError(_) |
+            err @ HorizonSyncError::TransactionError(_) => Some(BanReason {
+                reason: format!("{}", err),
+                ban_duration: long_ban,
+            }),
+        }
+    }
+}
diff --git a/base_layer/core/src/base_node/sync/horizon_state_sync/synchronizer.rs b/base_layer/core/src/base_node/sync/horizon_state_sync/synchronizer.rs
index 9f48570f85..6b4e6bef28 100644
--- a/base_layer/core/src/base_node/sync/horizon_state_sync/synchronizer.rs
+++ b/base_layer/core/src/base_node/sync/horizon_state_sync/synchronizer.rs
@@ -32,17 +32,14 @@ use croaring::Bitmap;
 use futures::{stream::FuturesUnordered, StreamExt};
 use log::*;
 use tari_common_types::types::{Commitment, RangeProofService};
-use tari_comms::{
-    connectivity::ConnectivityRequester,
-    peer_manager::NodeId,
-    protocol::rpc::{RpcClient, RpcError},
-};
+use tari_comms::{connectivity::ConnectivityRequester, peer_manager::NodeId, protocol::rpc::RpcClient, PeerConnection};
 use tari_crypto::{commitment::HomomorphicCommitment, tari_utilities::hex::Hex};
 use tokio::task;
 
 use super::error::HorizonSyncError;
 use crate::{
     base_node::sync::{
+        ban::PeerBanManager,
         hooks::Hooks,
         horizon_state_sync::{HorizonSyncInfo, HorizonSyncStatus},
         rpc,
@@ -81,11 +78,11 @@
 const LOG_TARGET: &str = "c::bn::state_machine_service::states::horizon_state_sync";
 
-pub struct HorizonStateSynchronization<'a, B> {
+pub struct HorizonStateSynchronization<B> {
     config: BlockchainSyncConfig,
     db: AsyncBlockchainDb<B>,
     rules: ConsensusManager,
-    sync_peers: &'a [SyncPeer],
+    sync_peers: Vec<SyncPeer>,
     horizon_sync_height: u64,
     prover: Arc<RangeProofService>,
     num_kernels: u64,
@@ -95,20 +92,22 @@
     connectivity: ConnectivityRequester,
     final_state_validator: Arc<dyn FinalHorizonStateValidation<B>>,
     max_latency: Duration,
+    peer_ban_manager: PeerBanManager,
 }
 
-impl<'a, B: BlockchainBackend + 'static> HorizonStateSynchronization<'a, B> {
+impl<B: BlockchainBackend + 'static> HorizonStateSynchronization<B> {
     #[allow(clippy::too_many_arguments)]
     pub fn new(
         config: BlockchainSyncConfig,
         db: AsyncBlockchainDb<B>,
         connectivity: ConnectivityRequester,
         rules: ConsensusManager,
-        sync_peers: &'a [SyncPeer],
+        sync_peers: Vec<SyncPeer>,
         horizon_sync_height: u64,
         prover: Arc<RangeProofService>,
         final_state_validator: Arc<dyn FinalHorizonStateValidation<B>>,
     ) -> Self {
+        let peer_ban_manager = PeerBanManager::new(config.clone(), connectivity.clone());
         Self {
             max_latency: config.initial_max_sync_latency,
             config,
@@ -123,6 +122,7 @@
             full_bitmap: None,
             hooks: Hooks::default(),
             final_state_validator,
+            peer_ban_manager,
         }
     }
 
@@ -179,97 +179,107 @@
     }
 
     async fn sync(&mut self, header: &BlockHeader) -> Result<(), HorizonSyncError> {
+        let sync_peer_node_ids = self.sync_peers.iter().map(|p| p.node_id()).cloned().collect::<Vec<_>>();
         info!(
             target: LOG_TARGET,
-            "Attempting to sync blocks({} sync peers)",
-            self.sync_peers.len()
+            "Attempting to sync horizon state ({} sync peers)",
+            sync_peer_node_ids.len()
         );
-        for (i, sync_peer) in self.sync_peers.iter().enumerate() {
-            self.hooks.call_on_starting_hook(sync_peer);
-            let mut connection = match self.connectivity.dial_peer(sync_peer.node_id().clone()).await {
-                Ok(conn) => conn,
+        let mut latency_counter = 0usize;
+        for (i, node_id) in sync_peer_node_ids.iter().enumerate() {
+            match self.connect_and_attempt_sync(i, node_id, header).await {
+                Ok(_) => return Ok(()),
+                // Try another peer
                 Err(err) => {
-                    warn!(target: LOG_TARGET, "Failed to connect to sync peer `{}`: {}", sync_peer.node_id(), err);
-                    continue;
-                },
-            };
-            let config = RpcClient::builder()
-                .with_deadline(self.config.rpc_deadline)
-                .with_deadline_grace_period(Duration::from_secs(3));
-            let mut client = match connection.connect_rpc_using_builder(config).await {
-                Ok(rpc_client) => rpc_client,
-                Err(err) => {
-                    warn!(target: LOG_TARGET, "Failed to establish RPC coonection with sync peer `{}`: {}", sync_peer.node_id(), err);
-                    continue;
-                },
-            };
-
-            match self.begin_sync(sync_peer.clone(), &mut client, header).await {
-                Ok(_) => match self.finalize_horizon_sync(sync_peer).await {
-                    Ok(_) => return Ok(()),
-                    Err(err) => {
-                        self.ban_peer_on_bannable_error(sync_peer, &err).await?;
-                        warn!(target: LOG_TARGET, "Error during sync:{}", err);
-                    },
-                },
-                Err(err @ HorizonSyncError::RpcError(RpcError::ReplyTimeout)) |
-                Err(err @ HorizonSyncError::MaxLatencyExceeded { .. }) => {
-                    self.ban_peer_on_bannable_error(sync_peer, &err).await?;
-                    warn!(target: LOG_TARGET, "{}", err);
-                    if i == self.sync_peers.len() - 1 {
-                        return Err(HorizonSyncError::AllSyncPeersExceedLatency);
+                    let ban_reason =
+                        HorizonSyncError::get_ban_reason(&err, self.config.short_ban_period, self.config.ban_period);
+
+                    if let Some(reason) = ban_reason {
+                        warn!(target: LOG_TARGET, "{}", err);
+                        self.peer_ban_manager
+                            .ban_peer_if_required(node_id, &Some(reason.clone()))
+                            .await;
+
+                        if reason.ban_duration > self.config.short_ban_period {
+                            self.remove_sync_peer(node_id);
+                        }
+                    }
+
+                    if let HorizonSyncError::MaxLatencyExceeded { .. } = err {
+                        latency_counter += 1;
                     }
-                },
-                Err(err) => {
-                    self.ban_peer_on_bannable_error(sync_peer, &err).await?;
-                    warn!(target: LOG_TARGET, "Error during sync:{}", err);
                 },
             }
         }
 
-        Err(HorizonSyncError::FailedSyncAllPeers)
+        if self.sync_peers.is_empty() {
+            Err(HorizonSyncError::NoMoreSyncPeers("Horizon sync failed".to_string()))
+        } else if latency_counter >= self.sync_peers.len() {
+            Err(HorizonSyncError::AllSyncPeersExceedLatency)
+        } else {
+            Err(HorizonSyncError::FailedSyncAllPeers)
+        }
     }
 
-    async fn ban_peer_on_bannable_error(
+    async fn connect_and_attempt_sync(
         &mut self,
-        peer: &SyncPeer,
-        error: &HorizonSyncError,
+        peer_index: usize,
+        node_id: &NodeId,
+        header: &BlockHeader,
     ) -> Result<(), HorizonSyncError> {
-        match error {
-            HorizonSyncError::ChainStorageError(_) |
-            HorizonSyncError::JoinError(_) |
-            HorizonSyncError::RpcError(_) |
-            HorizonSyncError::RpcStatus(_) |
-            HorizonSyncError::ConnectivityError(_) |
-            HorizonSyncError::NoSyncPeers |
-            HorizonSyncError::FailedSyncAllPeers |
-            HorizonSyncError::AllSyncPeersExceedLatency => {
-                // these are local errors so we dont ban die per
-            },
-            HorizonSyncError::MaxLatencyExceeded { .. } => {
-                warn!(target: LOG_TARGET, "Banned sync peer for short while because peer exceeded max latency: {}",error.to_string());
-                if let Err(err) = self
-                    .connectivity
-                    .ban_peer_until(peer.node_id().clone(), self.config.short_ban_period, error.to_string())
-                    .await
-                {
-                    error!(target: LOG_TARGET, "Failed to ban peer: {}", err);
-                }
-            },
-            _ => {
-                warn!(target: LOG_TARGET, "Banned sync peer for because: {}",error.to_string());
-                if let Err(err) = self
-                    .connectivity
-                    .ban_peer_until(peer.node_id().clone(), self.config.ban_period, error.to_string())
-                    .await
-                {
-                    error!(target: LOG_TARGET, "Failed to ban peer: {}", err);
-                }
-            },
+        {
+            let sync_peer = &self.sync_peers[peer_index];
+            self.hooks.call_on_starting_hook(sync_peer);
+        }
+
+        let mut conn = self.dial_sync_peer(node_id).await?;
+        debug!(
+            target: LOG_TARGET,
+            "Attempting to synchronize horizon state with `{}`", node_id
+        );
+
+        let config = RpcClient::builder()
+            .with_deadline(self.config.rpc_deadline)
+            .with_deadline_grace_period(Duration::from_secs(3));
+
+        let mut client = conn
+            .connect_rpc_using_builder::<rpc::BaseNodeSyncRpcClient>(config)
+            .await?;
+
+        let latency = client
+            .get_last_request_latency()
+            .expect("unreachable panic: last request latency must be set after connect");
+        self.sync_peers[peer_index].set_latency(latency);
+        if latency > self.max_latency {
+            return Err(HorizonSyncError::MaxLatencyExceeded {
+                peer: conn.peer_node_id().clone(),
+                latency,
+                max_latency: self.max_latency,
+            });
+        }
+
+        debug!(target: LOG_TARGET, "Sync peer latency is {:.2?}", latency);
+        let sync_peer = self.sync_peers[peer_index].clone();
+
+        self.begin_sync(sync_peer.clone(), &mut client, header).await?;
+        self.finalize_horizon_sync(&sync_peer).await?;
+
+        Ok(())
+    }
+
+    async fn dial_sync_peer(&self, node_id: &NodeId) -> Result<PeerConnection, HorizonSyncError> {
+        let timer = Instant::now();
+        debug!(target: LOG_TARGET, "Dialing {} sync peer", node_id);
+        let conn = self.connectivity.dial_peer(node_id.clone()).await?;
+        info!(
+            target: LOG_TARGET,
+            "Successfully dialed sync peer {} in {:.2?}",
+            node_id,
+            timer.elapsed()
+        );
+        Ok(conn)
+    }
+
     async fn begin_sync(
         &mut self,
         sync_peer: SyncPeer,
@@ -371,9 +381,9 @@
             kernel_hashes.push(kernel.hash());
 
             if mmr_position > end {
-                return Err(HorizonSyncError::IncorrectResponse(format!(
-                    "Peer sent too many kernels",
-                )));
+                return Err(HorizonSyncError::IncorrectResponse(
+                    "Peer sent too many kernels".to_string(),
+                ));
             }
 
             let mmr_position_u32 = u32::try_from(mmr_position).map_err(|_| HorizonSyncError::InvalidMmrPosition {
@@ -566,9 +576,9 @@
             let res: SyncUtxosResponse = response?;
 
             if mmr_position > end {
-                return Err(HorizonSyncError::IncorrectResponse(format!(
-                    "Peer sent too many outputs",
-                )));
+                return Err(HorizonSyncError::IncorrectResponse(
+                    "Peer sent too many outputs".to_string(),
+                ));
             }
 
             if res.mmr_index != 0 && res.mmr_index != mmr_position {
@@ -961,6 +971,13 @@
             .await?
     }
 
+    // Sync peers are also removed from the list of sync peers if the ban duration is longer than the short ban period.
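+    // NOTE: short-banned peers (e.g. latency offenders) deliberately stay in the list so they can be
+    // retried; only peers whose ban outlasts `short_ban_period` are removed by the caller in `sync()`.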
+    fn remove_sync_peer(&mut self, node_id: &NodeId) {
+        if let Some(pos) = self.sync_peers.iter().position(|p| p.node_id() == node_id) {
+            self.sync_peers.remove(pos);
+        }
+    }
+
     #[inline]
     fn db(&self) -> &AsyncBlockchainDb<B> {
         &self.db
diff --git a/base_layer/core/src/base_node/sync/mod.rs b/base_layer/core/src/base_node/sync/mod.rs
index 869ed48611..2193bb6482 100644
--- a/base_layer/core/src/base_node/sync/mod.rs
+++ b/base_layer/core/src/base_node/sync/mod.rs
@@ -20,6 +20,9 @@
 // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
 // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
+#[cfg(feature = "base_node")]
+pub mod ban;
+
 #[cfg(feature = "base_node")]
 mod config;
 #[cfg(feature = "base_node")]
diff --git a/base_layer/core/src/common/mod.rs b/base_layer/core/src/common/mod.rs
index 638c4e3040..b9a2d30dc2 100644
--- a/base_layer/core/src/common/mod.rs
+++ b/base_layer/core/src/common/mod.rs
@@ -42,6 +42,7 @@ pub type ConfidentialOutputHasher = DomainSeparatedConsensusHasher>,
    ) -> Result<(), DhtDiscoveryError> {
        let nonce = OsRng.next_u64();
+        if *dest_pubkey == *self.node_identity.public_key() {
+            let _result = reply_tx.send(Err(DhtDiscoveryError::CannotDiscoverThisNode));
+            return Ok(());
+        }
+
         if let Err(err) = self.send_discover(nonce, destination, dest_pubkey.clone()).await {
             let _result = reply_tx.send(Err(err));
             return Ok(());
diff --git a/comms/dht/src/envelope.rs b/comms/dht/src/envelope.rs
index e507cfc00a..28c87e5be1 100644
--- a/comms/dht/src/envelope.rs
+++ b/comms/dht/src/envelope.rs
@@ -67,6 +67,7 @@ pub(crate) fn epochtime_to_datetime(datetime: EpochTime) -> DateTime<Utc> {
     DateTime::from_utc(dt, Utc)
 }
 
+/// Message errors that should be verified by every node
 #[derive(Debug, Error)]
 pub enum DhtMessageError {
     #[error("Invalid node destination")]
     InvalidNodeDestination,
@@ -83,8 +84,10 @@ pub enum DhtMessageError {
     #[error("Invalid message flags")]
     InvalidMessageFlags,
     #[error("Invalid ephemeral public key")]
     InvalidEphemeralPublicKey,
-    #[error("Header was omitted from the message")]
+    #[error("Header is omitted from the message")]
     HeaderOmitted,
+    #[error("Message Body is empty")]
+    BodyEmpty,
 }
 
 impl fmt::Display for DhtMessageType {
@@ -157,12 +160,31 @@ pub struct DhtMessageHeader {
 }
 
 impl DhtMessageHeader {
-    pub fn is_valid(&self) -> bool {
+    /// Checks if the DHT header is semantically valid. For example, if the message is flagged as encrypted, but sets
+    /// an empty signature or provides no ephemeral public key, this returns false.
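+    /// Concretely, an encrypted header must have a known destination, an ephemeral public key and a
+    /// non-empty message signature; an unencrypted header is always semantically valid at this level.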
+ pub fn is_semantically_valid(&self) -> bool { + // If the message is encrypted: + // - it needs a destination + // - it needs an ephemeral public key + // - it needs a signature if self.flags.is_encrypted() { - !self.message_signature.is_empty() && self.ephemeral_public_key.is_some() - } else { - true + // Must have a destination + if self.destination.is_unknown() { + return false; + } + + // Must have an ephemeral public key + if self.ephemeral_public_key.is_none() { + return false; + } + + // Must have a signature + if self.message_signature.is_empty() { + return false; + } } + + true } } diff --git a/comms/dht/src/inbound/decryption.rs b/comms/dht/src/inbound/decryption.rs index d08185dc50..f54843b634 100644 --- a/comms/dht/src/inbound/decryption.rs +++ b/comms/dht/src/inbound/decryption.rs @@ -56,6 +56,8 @@ enum DecryptionError { MessageRejectDecryptionFailed, #[error("Failed to decode envelope body")] EnvelopeBodyDecodeFailed, + #[error("Bad clear-text message semantics")] + BadClearTextMessageSemantics, } /// This layer is responsible for attempting to decrypt inbound messages. @@ -294,36 +296,17 @@ where S: Service /// /// These failure modes are detectable by any node, so it is generally safe to ban an offending peer. fn initial_validation(message: DhtInboundMessage) -> Result { - // If an unencrypted message has no signature, it passes this validation automatically - if !message.dht_header.flags.is_encrypted() && message.dht_header.message_signature.is_empty() { - return Ok(ValidatedDhtInboundMessage::new(message, None)); - } - - // If the message is encrypted: - // - it must be nonempty - // - it needs a destination - // - it needs an ephemeral public key - // - it needs a signature - if message.dht_header.flags.is_encrypted() { - // Must be nonempty - if message.body.is_empty() { - return Err(DecryptionError::BadEncryptedMessageSemantics); - } - - // Must have a destination - if message.dht_header.destination.is_unknown() { - return Err(DecryptionError::BadEncryptedMessageSemantics); - } - - // Must have an ephemeral public key - if message.dht_header.ephemeral_public_key.is_none() { + if !message.is_semantically_valid() { + if message.dht_header.flags.is_encrypted() { return Err(DecryptionError::BadEncryptedMessageSemantics); + } else { + return Err(DecryptionError::BadClearTextMessageSemantics); } + } - // Must have a signature - if message.dht_header.message_signature.is_empty() { - return Err(DecryptionError::BadEncryptedMessageSemantics); - } + // If a signature is not present, the message is valid at this point + if message.dht_header.message_signature.is_empty() { + return Ok(ValidatedDhtInboundMessage::new(message, None)); } // If a signature is present, it must be valid diff --git a/comms/dht/src/inbound/dht_handler/task.rs b/comms/dht/src/inbound/dht_handler/task.rs index 1f1cc91594..819e14071e 100644 --- a/comms/dht/src/inbound/dht_handler/task.rs +++ b/comms/dht/src/inbound/dht_handler/task.rs @@ -161,7 +161,13 @@ where S: Service "Received JoinMessage that did not have an authenticated origin from source peer {}. 
Banning source", source_peer ); - self .dht .ban_peer(source_peer.public_key.clone(), OffenceSeverity::Low, "Received JoinMessage that did not have an authenticated origin", ).await; + self.dht + .ban_peer( + source_peer.public_key.clone(), + OffenceSeverity::Low, + "Received JoinMessage that did not have an authenticated origin", + ) + .await; return Ok(()); }; @@ -290,11 +296,13 @@ where S: Service target: LOG_TARGET, "Received DiscoveryResponseMessage that did not have an authenticated origin: {}. Banning source", message ); - self.dht .ban_peer( - message.source_peer.public_key.clone(), - OffenceSeverity::Low, - "Received DiscoveryResponseMessage that did not have an authenticated origin", - ).await; + self.dht + .ban_peer( + message.source_peer.public_key.clone(), + OffenceSeverity::Low, + "Received DiscoveryResponseMessage that did not have an authenticated origin", + ) + .await; return Ok(()); }; @@ -351,11 +359,13 @@ where S: Service target: LOG_TARGET, "Received Discover that did not have an authenticated origin from source peer {}. Banning source", message.source_peer ); - self.dht.ban_peer( - message.source_peer.public_key.clone(), - OffenceSeverity::Low, - "Received JoinMessage that did not have an authenticated origin", - ).await; + self.dht + .ban_peer( + message.source_peer.public_key.clone(), + OffenceSeverity::Low, + "Received JoinMessage that did not have an authenticated origin", + ) + .await; return Ok(()); }; diff --git a/comms/dht/src/inbound/forward.rs b/comms/dht/src/inbound/forward.rs index 37e04129d2..0b34f76785 100644 --- a/comms/dht/src/inbound/forward.rs +++ b/comms/dht/src/inbound/forward.rs @@ -186,26 +186,12 @@ where S: Service dht_header, is_saf_stored, is_already_forwarded, - authenticated_origin, .. } = message; if self.destination_matches_source(&dht_header.destination, source_peer) { - // #banheuristic - the origin of this message was the destination. Two things are wrong here: - // 1. The origin/destination should not have forwarded this (the destination node didnt do this - // destination_matches_source check) - // 1. The origin sent a message that the destination could not decrypt - // The authenticated source should be banned (malicious), and origin should be temporarily banned - // (bug?) - if let Some(authenticated_origin) = authenticated_origin { - self.dht - .ban_peer( - authenticated_origin.clone(), - OffenceSeverity::High, - "Received message from peer that is destined for that peer. 
This peer originally sent it.",
-                    )
-                    .await;
-            }
+            // The origin/destination should not have forwarded this (the source node didn't do this
+            // destination_matches_source check)
             self.dht
                 .ban_peer(
                     source_peer.public_key.clone(),
diff --git a/comms/dht/src/inbound/message.rs b/comms/dht/src/inbound/message.rs
index 5d89f98e98..0d811b2557 100644
--- a/comms/dht/src/inbound/message.rs
+++ b/comms/dht/src/inbound/message.rs
@@ -91,6 +91,22 @@ impl DhtInboundMessage {
             body,
         }
     }
+
+    pub fn is_semantically_valid(&self) -> bool {
+        if !self.dht_header.is_semantically_valid() {
+            return false;
+        }
+
+        // If the message is encrypted:
+        // - it must be nonempty
+        if self.dht_header.flags.is_encrypted() {
+            // Body must be nonempty
+            if self.body.is_empty() {
+                return false;
+            }
+        }
+        true
+    }
 }
 
 impl Display for DhtInboundMessage {
diff --git a/comms/dht/src/message_signature.rs b/comms/dht/src/message_signature.rs
index 975b5d9208..fa71033e93 100644
--- a/comms/dht/src/message_signature.rs
+++ b/comms/dht/src/message_signature.rs
@@ -123,7 +123,7 @@ pub struct ProtoMessageSignature {
 #[derive(Debug, thiserror::Error, PartialEq)]
 pub enum MessageSignatureError {
-    #[error("Failed to validate message signature")]
+    #[error("Message signature does not contain valid scalar bytes")]
     InvalidSignatureBytes,
     #[error("Message signature contained an invalid public nonce")]
     InvalidPublicNonceBytes,
diff --git a/comms/dht/src/proto/store_forward.proto b/comms/dht/src/proto/store_forward.proto
index be5d0dbe76..830004e5ea 100644
--- a/comms/dht/src/proto/store_forward.proto
+++ b/comms/dht/src/proto/store_forward.proto
@@ -15,6 +15,7 @@ package tari.dht.store_forward;
 message StoredMessagesRequest {
   google.protobuf.Timestamp since = 1;
   uint32 request_id = 2;
+  uint32 limit = 3;
 }
 
 // Storage for a single message envelope, including the date and time when the element was stored
diff --git a/comms/dht/src/store_forward/error.rs b/comms/dht/src/store_forward/error.rs
index 6d202cd6e5..d7f47ff225 100644
--- a/comms/dht/src/store_forward/error.rs
+++ b/comms/dht/src/store_forward/error.rs
@@ -27,14 +27,12 @@ use tari_comms::{
     message::MessageError,
     peer_manager::{NodeId, PeerManagerError},
 };
-use tari_utilities::{byte_array::ByteArrayError, epoch_time::EpochTime};
 use thiserror::Error;
 
 use crate::{
     actor::DhtActorError,
     envelope::DhtMessageError,
     error::DhtEncryptError,
-    inbound::DhtInboundError,
     message_signature::MessageSignatureError,
     outbound::DhtOutboundError,
     storage::StorageError,
@@ -55,14 +53,12 @@
     DhtEncryptError(#[from] DhtEncryptError),
     #[error("Received stored message has an invalid destination")]
     InvalidDestination,
-    #[error("DhtInboundError: {0}")]
-    DhtInboundError(#[from] DhtInboundError),
     #[error("Received stored message has an invalid origin signature: {0}")]
     InvalidMessageSignature(#[from] MessageSignatureError),
-    #[error("Invalid envelope body")]
-    InvalidEnvelopeBody,
-    #[error("DHT header is invalid")]
-    InvalidDhtHeader,
+    #[error("Envelope body is missing a required message part")]
+    EnvelopeBodyMissingMessagePart,
+    #[error("DHT header did not pass semantic validation rules")]
+    BadDhtHeaderSemanticallyInvalid,
     #[error("Unable to decrypt received stored message")]
     DecryptionFailed,
     #[error("DhtActorError: {0}")]
@@ -71,10 +67,8 @@
     DuplicateMessage,
     #[error("Unable to decode message: {0}")]
     DecodeError(#[from] DecodeError),
-    #[error("Dht header was not provided")]
-    DhtHeaderNotProvided,
-    #[error("The message was 
malformed")] - MalformedMessage, + #[error("The message envelope was malformed: {0}")] + MalformedEnvelopeBody(DecodeError), #[error("StorageError: {0}")] StorageError(#[from] StorageError), #[error("The store and forward service requester channel closed")] @@ -83,24 +77,16 @@ pub enum StoreAndForwardError { RequestCancelled, #[error("The {field} field was not valid, discarding SAF response: {details}")] InvalidSafResponseMessage { field: &'static str, details: String }, - #[error("The message has expired, not storing message in SAF db (expiry: {expired}, now: {now})")] - NotStoringExpiredMessage { expired: EpochTime, now: EpochTime }, - #[error("MalformedNodeId: {0}")] - MalformedNodeId(String), - #[error("DHT message type should not have been forwarded")] - InvalidDhtMessageType, - #[error("Failed to send request for store and forward messages: {0}")] - RequestMessagesFailed(DhtOutboundError), + #[error("DHT message type should not have been stored/forwarded")] + PeerSentDhtMessageViaSaf, + #[error("SAF message type should not have been stored/forwarded")] + PeerSentSafMessageViaSaf, #[error("Received SAF messages that were not requested")] ReceivedUnrequestedSafMessages, #[error("SAF messages received from peer {peer} after deadline. Received after {message_age:.2?}")] SafMessagesReceivedAfterDeadline { peer: NodeId, message_age: Duration }, #[error("Invalid SAF request: `stored_at` cannot be in the future")] StoredAtWasInFuture, -} - -impl From for StoreAndForwardError { - fn from(e: ByteArrayError) -> Self { - StoreAndForwardError::MalformedNodeId(e.to_string()) - } + #[error("Invariant error (POSSIBLE BUG): {0}")] + InvariantError(String), } diff --git a/comms/dht/src/store_forward/message.rs b/comms/dht/src/store_forward/message.rs index f74af32c61..bdce545204 100644 --- a/comms/dht/src/store_forward/message.rs +++ b/comms/dht/src/store_forward/message.rs @@ -40,14 +40,15 @@ impl StoredMessagesRequest { Self { since: None, request_id: OsRng.next_u32(), + limit: 0, } } - #[allow(unused)] pub fn since(since: DateTime) -> Self { Self { since: Some(datetime_to_timestamp(since)), request_id: OsRng.next_u32(), + limit: 0, } } } diff --git a/comms/dht/src/store_forward/saf_handler/task.rs b/comms/dht/src/store_forward/saf_handler/task.rs index 1247cc8236..4c987c5111 100644 --- a/comms/dht/src/store_forward/saf_handler/task.rs +++ b/comms/dht/src/store_forward/saf_handler/task.rs @@ -21,6 +21,7 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use std::{ + cmp, convert::{TryFrom, TryInto}, sync::Arc, }; @@ -41,10 +42,10 @@ use tokio::sync::mpsc; use tower::{Service, ServiceExt}; use crate::{ - actor::DhtRequester, + actor::{DhtRequester, OffenceSeverity}, crypt, dedup, - envelope::{timestamp_to_datetime, DhtMessageHeader, NodeDestination}, + envelope::{timestamp_to_datetime, DhtMessageError, DhtMessageHeader, NodeDestination}, inbound::{DecryptedDhtMessage, DhtInboundMessage}, message_signature::{MessageSignature, MessageSignatureError, ProtoMessageSignature}, outbound::{OutboundMessageRequester, SendMessageParams}, @@ -185,7 +186,7 @@ where S: Service let retrieve_msgs = msg .decode_part::(0)? 
- .ok_or(StoreAndForwardError::InvalidEnvelopeBody)?; + .ok_or(StoreAndForwardError::EnvelopeBodyMissingMessagePart)?; let source_pubkey = Box::new(message.source_peer.public_key.clone()); let source_node_id = Box::new(message.source_peer.node_id.clone()); @@ -193,6 +194,19 @@ where S: Service // Compile a set of stored messages for the requesting peer let mut query = FetchStoredMessageQuery::new(source_pubkey, source_node_id.clone()); + let max = u32::try_from(self.config.max_returned_messages).unwrap_or_else(|_| { + warn!(target: LOG_TARGET,"Your node is configured with an extremely high number for max_returned_messages. This will likely be disregarded by peers."); + u32::MAX + }); + // limit of 0 means no hard limit, though we still limit to our configured limit + if retrieve_msgs.limit == 0 { + query.with_limit(max); + } else { + // Return up to the limit. The limit cannot exceed our locally configured max_returned_messages setting. + // Returning less than requested is completely expected. + query.with_limit(cmp::min(retrieve_msgs.limit, max)); + } + let since = match retrieve_msgs.since.and_then(timestamp_to_datetime) { Some(since) => { debug!( @@ -273,7 +287,8 @@ where S: Service let message_tag = message.dht_header.message_tag; if let Err(err) = self.check_saf_messages_were_requested(&source_node_id).await { - // Peer sent SAF messages we didn't request?? #banheuristics + // Peer sent SAF messages we didn't request, it was cancelled locally or sent it more than 4 to 10 minutes + // late?? #banheuristics warn!(target: LOG_TARGET, "SAF response check failed: {}", err); return Ok(()); } @@ -283,7 +298,31 @@ where S: Service .expect("already checked that this message decrypted successfully"); let response = msg .decode_part::(0)? - .ok_or(StoreAndForwardError::InvalidEnvelopeBody)?; + .ok_or(StoreAndForwardError::EnvelopeBodyMissingMessagePart)?; + + if response.messages.len() > self.config.max_returned_messages { + warn!( + target: LOG_TARGET, + "Peer '{}' sent {} stored messages which is more than the maximum allowed of {}. Discarding \ + messages.", + source_node_id.short_str(), + response.messages.len(), + self.config.max_returned_messages + ); + self.dht_requester + .ban_peer( + message.source_peer.public_key.clone(), + OffenceSeverity::High, + format!( + "Peer sent too many stored messages ({} of {})", + response.messages.len(), + self.config.max_returned_messages + ), + ) + .await; + return Ok(()); + } + let source_peer = message.source_peer.clone(); debug!( @@ -298,65 +337,10 @@ where S: Service message_tag ); - let results = self + let successful_messages = self .process_incoming_stored_messages(source_peer.clone(), response.messages) .await?; - let successful_msgs_iter = results - .into_iter() - .map(|result| { - match &result { - Ok(msg) => { - trace!(target: LOG_TARGET, "Recv SAF message: {}", msg); - }, - // Failed decryption is acceptable, the message wasn't for this node so we - // simply discard the message. - Err(err @ StoreAndForwardError::DecryptionFailed) => { - debug!( - target: LOG_TARGET, - "Unable to decrypt stored message sent by {}: {}", - source_peer.node_id.short_str(), - err - ); - }, - // The peer that originally sent this message is not known to us. - Err(StoreAndForwardError::PeerManagerError(PeerManagerError::PeerNotFoundError)) => { - debug!(target: LOG_TARGET, "Origin peer not found. 
Discarding stored message."); - }, - - // Failed to send request to Dht Actor, something has gone very wrong - Err(StoreAndForwardError::DhtActorError(err)) => { - error!( - target: LOG_TARGET, - "DhtActor returned an error. {}. This could indicate a system malfunction.", err - ); - }, - // Duplicate message detected, no problem it happens. - Err(StoreAndForwardError::DuplicateMessage) => { - debug!( - target: LOG_TARGET, - "Store and forward received a duplicate message. Message discarded." - ); - }, - - // Every other error shouldn't happen if the sending node is behaving - Err(err) => { - // #banheuristics - warn!( - target: LOG_TARGET, - "SECURITY: invalid store and forward message was discarded from NodeId={}. Reason: {}. \ - These messages should never have been forwarded. This is a sign of a badly behaving node.", - source_peer.node_id.short_str(), - err - ); - }, - } - - result - }) - .filter(Result::is_ok) - .map(Result::unwrap); - // Let the SAF Service know we got a SAF response. let _ = self .saf_response_signal_sender @@ -365,7 +349,7 @@ where S: Service .map_err(|e| warn!(target: LOG_TARGET, "Error sending SAF response signal; {:?}", e)); self.next_service - .call_all(stream::iter(successful_msgs_iter)) + .call_all(stream::iter(successful_messages)) .unordered() .for_each(|service_result| { if let Err(err) = service_result { @@ -382,25 +366,33 @@ where S: Service &mut self, source_peer: Arc, messages: Vec, - ) -> Result>, StoreAndForwardError> { + ) -> Result, StoreAndForwardError> { let mut last_saf_received = self .dht_requester .get_metadata::>(DhtMetadataKey::LastSafMessageReceived) .await?; + // Allocations: the number of messages has already been bounds checked to be <= + // self.config.max_returned_messages let mut results = Vec::with_capacity(messages.len()); for msg in messages { let result = self .validate_and_decrypt_incoming_stored_message(Arc::clone(&source_peer), msg) .await; - if let Ok((_, stored_at)) = result.as_ref() { - if last_saf_received.as_ref().map(|dt| stored_at > dt).unwrap_or(true) { - last_saf_received = Some(*stored_at); - } + let Some(result) = self.process_saf_message_validation_result(&source_peer.public_key, result).await else { + // Logging of problems and banning are done inside process_saf_message. 
We can simply continue + continue; + }; + + // If the messages should no longer be processed because we banned the peer, we exit here on Err + let (msg, stored_at) = result?; + + if last_saf_received.as_ref().map(|dt| stored_at > *dt).unwrap_or(true) { + last_saf_received = Some(stored_at); } - results.push(result.map(|(msg, _)| msg)); + results.push(msg); } if let Some(last_saf_received) = last_saf_received { @@ -418,8 +410,12 @@ where S: Service message: ProtoStoredMessage, ) -> Result<(DecryptedDhtMessage, DateTime), StoreAndForwardError> { let node_identity = &self.node_identity; - if message.dht_header.is_none() { - return Err(StoreAndForwardError::DhtHeaderNotProvided); + let Some(dht_header) = message.dht_header else { + return Err(StoreAndForwardError::DhtMessageError(DhtMessageError::HeaderOmitted)); + }; + + if message.body.is_empty() { + return Err(StoreAndForwardError::DhtMessageError(DhtMessageError::BodyEmpty)); } let stored_at = message @@ -429,7 +425,7 @@ where S: Service NaiveDateTime::from_timestamp_opt(t.seconds, 0).ok_or_else(|| { StoreAndForwardError::InvalidSafResponseMessage { field: "stored_at", - details: "number of seconds provided represents more days than can fit in a u32" + details: "number of seconds provided represents more days than can fit in a NaiveDateTime" .to_string(), } })?, @@ -443,48 +439,37 @@ where S: Service return Err(StoreAndForwardError::StoredAtWasInFuture); } - let msg_hash = dedup::create_message_hash( - message - .dht_header - .as_ref() - .map(|h| h.message_signature.as_slice()) - .unwrap_or(&[]), - &message.body, - ); + let msg_hash = dedup::create_message_hash(&dht_header.message_signature, &message.body); - let dht_header: DhtMessageHeader = message - .dht_header - .expect("previously checked") - .try_into() - .map_err(StoreAndForwardError::DhtMessageError)?; + let dht_header: DhtMessageHeader = dht_header.try_into().map_err(StoreAndForwardError::DhtMessageError)?; - if !dht_header.is_valid() { - return Err(StoreAndForwardError::InvalidDhtHeader); + if !dht_header.is_semantically_valid() { + return Err(StoreAndForwardError::BadDhtHeaderSemanticallyInvalid); } let message_type = dht_header.message_type; if message_type.is_dht_message() { - if !message_type.is_dht_discovery() { - debug!( - target: LOG_TARGET, - "Discarding {} message from peer '{}'", - message_type, - source_peer.node_id.short_str() - ); - return Err(StoreAndForwardError::InvalidDhtMessageType); - } - if dht_header.destination.is_unknown() { - debug!( - target: LOG_TARGET, - "Discarding anonymous discovery message from peer '{}'", - source_peer.node_id.short_str() - ); - return Err(StoreAndForwardError::InvalidDhtMessageType); - } + debug!( + target: LOG_TARGET, + "Discarding {} message from peer '{}'", + message_type, + source_peer.node_id.short_str() + ); + return Err(StoreAndForwardError::PeerSentDhtMessageViaSaf); + } + + if message_type.is_saf_message() { + debug!( + target: LOG_TARGET, + "Discarding {} message from peer '{}'", + message_type, + source_peer.node_id.short_str() + ); + return Err(StoreAndForwardError::PeerSentSafMessageViaSaf); } // Check that the destination is either undisclosed, for us or for our network region - Self::check_destination(node_identity, &dht_header).await?; + Self::check_destination_for(node_identity.public_key(), &dht_header).await?; // Attempt to decrypt the message (if applicable), and deserialize it let (authenticated_pk, decrypted_body) = @@ -521,13 +506,13 @@ where S: Service } } - async fn check_destination( - node_identity: 
&NodeIdentity,
+    async fn check_destination_for(
+        public_key: &CommsPublicKey,
         dht_header: &DhtMessageHeader,
     ) -> Result<(), StoreAndForwardError> {
         let is_valid_destination = match &dht_header.destination {
             NodeDestination::Unknown => true,
-            NodeDestination::PublicKey(pk) => node_identity.public_key() == &**pk,
+            NodeDestination::PublicKey(pk) => *public_key == **pk,
         };
 
         if is_valid_destination {
@@ -568,11 +553,12 @@
         let envelope_body =
             EnvelopeBody::decode(decrypted_bytes.freeze()).map_err(|_| StoreAndForwardError::DecryptionFailed)?;
         if envelope_body.is_empty() {
-            return Err(StoreAndForwardError::InvalidEnvelopeBody);
+            return Err(StoreAndForwardError::EnvelopeBodyMissingMessagePart);
         }
 
         // Unmask the sender public key
-        let mask = crypt::generate_key_mask(&shared_ephemeral_secret)?;
+        let mask = crypt::generate_key_mask(&shared_ephemeral_secret)
+            .map_err(|e| StoreAndForwardError::InvariantError(e.to_string()))?;
         let mask_inverse = mask.invert().ok_or(StoreAndForwardError::DecryptionFailed)?;
         Ok((Some(mask_inverse * masked_sender_public_key), envelope_body))
     } else {
@@ -581,7 +567,7 @@
         } else {
             Some(Self::authenticate_message(&header.message_signature, header, body)?)
         };
-        let envelope_body = EnvelopeBody::decode(body).map_err(|_| StoreAndForwardError::MalformedMessage)?;
+        let envelope_body = EnvelopeBody::decode(body).map_err(StoreAndForwardError::MalformedEnvelopeBody)?;
         Ok((authenticated_pk, envelope_body))
     }
 }
@@ -615,6 +601,163 @@
             None => Err(StoreAndForwardError::ReceivedUnrequestedSafMessages),
         }
     }
+
+    #[allow(clippy::too_many_lines)]
+    pub async fn process_saf_message_validation_result(
+        &mut self,
+        source_peer: &CommsPublicKey,
+        result: Result<(DecryptedDhtMessage, DateTime<Utc>), StoreAndForwardError>,
+    ) -> Option<Result<(DecryptedDhtMessage, DateTime<Utc>), StoreAndForwardError>> {
+        match result {
+            Ok(t) => Some(Ok(t)),
+            // Failed decryption is acceptable, the message wasn't for this node so we
+            // simply discard the message.
+            Err(err @ StoreAndForwardError::DhtEncryptError(_)) | Err(err @ StoreAndForwardError::DecryptionFailed) => {
+                debug!(
+                    target: LOG_TARGET,
+                    "Unable to decrypt stored message sent by {}: {}",
+                    source_peer,
+                    err
+                );
+                None
+            },
+            // The peer that originally sent this message is not known to us.
+            Err(StoreAndForwardError::PeerManagerError(PeerManagerError::PeerNotFoundError)) => {
+                debug!(target: LOG_TARGET, "Origin peer not found. Discarding stored message.");
+                None
+            },
+            Err(StoreAndForwardError::PeerManagerError(PeerManagerError::BannedPeer)) => {
+                debug!(target: LOG_TARGET, "Origin peer was banned. Discarding stored message.");
+                None
+            },
+
+            // These aren't possible in this function if the code is correct.
+            Err(err @ StoreAndForwardError::InvariantError(_)) |
+            Err(err @ StoreAndForwardError::SafMessagesReceivedAfterDeadline { .. }) |
+            Err(err @ StoreAndForwardError::ReceivedUnrequestedSafMessages) => {
+                error!(target: LOG_TARGET, "BUG: unreachable error reached! {}", err);
+                None
+            },
+
+            // Internal errors
+            Err(err @ StoreAndForwardError::RequestCancelled) |
+            Err(err @ StoreAndForwardError::RequesterChannelClosed) |
+            Err(err @ StoreAndForwardError::DhtOutboundError(_)) |
+            Err(err @ StoreAndForwardError::StorageError(_)) |
+            Err(err @ StoreAndForwardError::PeerManagerError(_)) => {
+                error!(target: LOG_TARGET, "Internal error: {}", err);
+                None
+            },
+
+            // Failed to send request to Dht Actor, something has gone very wrong
+            Err(StoreAndForwardError::DhtActorError(err)) => {
+                error!(
+                    target: LOG_TARGET,
+                    "DhtActor returned an error. {}. 
+                    "DhtActor returned an error. {}. This could indicate a system malfunction.", err
+                );
+                None
+            },
+            // Duplicate message detected; this happens from time to time and is not a problem.
+            Err(StoreAndForwardError::DuplicateMessage) => {
+                debug!(
+                    target: LOG_TARGET,
+                    "Store and forward received a duplicate message. Message discarded."
+                );
+                None
+            },
+
+            // The decrypted message did not contain a required message part. The sender has no way to know this,
+            // so we can just ignore the message.
+            Err(StoreAndForwardError::EnvelopeBodyMissingMessagePart) => {
+                debug!(
+                    target: LOG_TARGET,
+                    "Received stored message from peer `{}` that is missing a required message part. Message \
+                     discarded.",
+                    source_peer
+                );
+                None
+            },
+
+            // Peer sent an invalid SAF reply
+            Err(err @ StoreAndForwardError::StoredAtWasInFuture) |
+            Err(err @ StoreAndForwardError::InvalidSafResponseMessage { .. }) => {
+                warn!(
+                    target: LOG_TARGET,
+                    "SECURITY: invalid store and forward message was discarded from NodeId={}. Reason: {}. \
+                     This is a sign of a badly behaving node.",
+                    source_peer,
+                    err
+                );
+                self.dht_requester
+                    .ban_peer(source_peer.clone(), OffenceSeverity::High, &err)
+                    .await;
+                Some(Err(err))
+            },
+
+            // Ban - peer sent us a message containing an invalid DhtHeader or encoded signature. They should
+            // have discarded this message.
+            Err(err @ StoreAndForwardError::DecodeError(_)) |
+            Err(err @ StoreAndForwardError::MessageError(_)) |
+            Err(err @ StoreAndForwardError::MalformedEnvelopeBody(_)) |
+            Err(err @ StoreAndForwardError::DhtMessageError(_)) => {
+                warn!(
+                    target: LOG_TARGET,
+                    "SECURITY: invalid store and forward message was discarded from NodeId={}. Reason: {}. \
+                     These messages should never have been forwarded. This is a sign of a badly behaving node.",
+                    source_peer,
+                    err
+                );
+                self.dht_requester
+                    .ban_peer(source_peer.clone(), OffenceSeverity::Medium, &err)
+                    .await;
+                Some(Err(err))
+            },
+
+            Err(err @ StoreAndForwardError::BadDhtHeaderSemanticallyInvalid) |
+            Err(err @ StoreAndForwardError::InvalidMessageSignature(_)) => {
+                warn!(
+                    target: LOG_TARGET,
+                    "SECURITY: invalid store and forward message was discarded from NodeId={}. Reason: {}. \
+                     These messages should never have been forwarded. This is a sign of a badly behaving node.",
+                    source_peer,
+                    err
+                );
+                self.dht_requester
+                    .ban_peer(source_peer.clone(), OffenceSeverity::High, &err)
+                    .await;
+                Some(Err(err))
+            },
+
+            // The destination for this message is not this node, so the sender should not have sent it
+            Err(err @ StoreAndForwardError::InvalidDestination) => {
+                warn!(
+                    target: LOG_TARGET,
+                    "SECURITY: invalid store and forward message was discarded from NodeId={}. Reason: {}. \
+                     These messages should never have been forwarded. This is a sign of a badly behaving node.",
+                    source_peer,
+                    err
+                );
+                self.dht_requester
+                    .ban_peer(source_peer.clone(), OffenceSeverity::High, &err)
+                    .await;
+                Some(Err(err))
+            },
+            Err(err @ StoreAndForwardError::PeerSentDhtMessageViaSaf) |
+            Err(err @ StoreAndForwardError::PeerSentSafMessageViaSaf) => {
+                warn!(
+                    target: LOG_TARGET,
+                    "SECURITY: invalid store and forward message was discarded from NodeId={}. Reason: {}. \
+                     These messages should never have been forwarded. This is a sign of a badly behaving node.",
+                    source_peer,
+                    err
+                );
+                self.dht_requester
+                    .ban_peer(source_peer.clone(), OffenceSeverity::High, &err)
+                    .await;
+                Some(Err(err))
+            },
+        }
+    }
 }
 
 #[cfg(test)]
@@ -832,7 +975,7 @@ mod test {
         let msg_a = wrap_in_envelope_body!(&b"A".to_vec());
 
         let inbound_msg_a =
-            make_dht_inbound_message(&node_identity, &msg_a, DhtMessageFlags::ENCRYPTED, true, false).unwrap();
+            make_dht_inbound_message(&node_identity, &msg_a, DhtMessageFlags::ENCRYPTED, true, true).unwrap();
         // Need to know the peer to process a stored message
         peer_manager
             .add_peer(Clone::clone(&*inbound_msg_a.source_peer))
            .await
@@ -841,7 +984,7 @@ mod test {
             .unwrap();
 
         let msg_b = wrap_in_envelope_body!(b"B".to_vec());
         let inbound_msg_b =
-            make_dht_inbound_message(&node_identity, &msg_b, DhtMessageFlags::ENCRYPTED, true, false).unwrap();
+            make_dht_inbound_message(&node_identity, &msg_b, DhtMessageFlags::ENCRYPTED, true, true).unwrap();
         // Need to know the peer to process a stored message
         peer_manager
             .add_peer(Clone::clone(&*inbound_msg_b.source_peer))
             .await
@@ -936,6 +1079,98 @@ mod test {
         assert_eq!(last_saf_received.second(), msg2_time.second());
     }
 
+    #[tokio::test]
+    #[allow(clippy::similar_names, clippy::too_many_lines)]
+    async fn rejected_with_bad_message_semantics() {
+        let spy = service_spy();
+        let (saf_requester, saf_mock_state) = create_store_and_forward_mock();
+
+        let peer_manager = build_peer_manager();
+        let (oms_tx, _) = mpsc::channel(1);
+
+        let node_identity = make_node_identity();
+
+        let msg_a = wrap_in_envelope_body!(&b"A".to_vec());
+
+        let inbound_msg_a =
+            make_dht_inbound_message(&node_identity, &msg_a, DhtMessageFlags::ENCRYPTED, true, false).unwrap();
+        // Need to know the peer to process a stored message
+        peer_manager
+            .add_peer(Clone::clone(&*inbound_msg_a.source_peer))
+            .await
+            .unwrap();
+
+        let msg_b = wrap_in_envelope_body!(b"B".to_vec());
+        let inbound_msg_b =
+            make_dht_inbound_message(&node_identity, &msg_b, DhtMessageFlags::ENCRYPTED, false, true).unwrap();
+        // Need to know the peer to process a stored message
+        peer_manager
+            .add_peer(Clone::clone(&*inbound_msg_b.source_peer))
+            .await
+            .unwrap();
+
+        let msg1_time = Utc::now()
+            .checked_sub_signed(chrono::Duration::from_std(Duration::from_secs(60)).unwrap())
+            .unwrap();
+        let msg1 = ProtoStoredMessage::new(0, inbound_msg_a.dht_header.clone(), inbound_msg_a.body, msg1_time);
+        let msg2_time = Utc::now()
+            .checked_sub_signed(chrono::Duration::from_std(Duration::from_secs(30)).unwrap())
+            .unwrap();
+        let msg2 = ProtoStoredMessage::new(0, inbound_msg_b.dht_header, inbound_msg_b.body, msg2_time);
+
+        let mut message = DecryptedDhtMessage::succeeded(
+            wrap_in_envelope_body!(StoredMessagesResponse {
+                messages: vec![msg1.clone(), msg2],
+                request_id: 123,
+                response_type: 0
+            }),
+            None,
+            make_dht_inbound_message(
+                &node_identity,
+                &b"Stored message".to_vec(),
+                DhtMessageFlags::ENCRYPTED,
+                false,
+                false,
+            )
+            .unwrap(),
+        );
+        message.dht_header.message_type = DhtMessageType::SafStoredMessages;
+
+        let (mut dht_requester, mock) = create_dht_actor_mock(1);
+        task::spawn(mock.run());
+        let (saf_response_signal_sender, _) = mpsc::channel(20);
+
+        assert!(dht_requester
+            .get_metadata::<DateTime<Utc>>(DhtMetadataKey::LastSafMessageReceived)
+            .await
+            .unwrap()
+            .is_none());
+
+        // Allow request inflight check to pass
+        saf_mock_state.set_request_inflight(Some(Duration::from_secs(10))).await;
+
+        let task = MessageHandlerTask::new(
+            Default::default(),
+            spy.to_service::<PipelineError>(),
+            saf_requester,
+            dht_requester.clone(),
+            OutboundMessageRequester::new(oms_tx),
+            node_identity,
+            message,
+            saf_response_signal_sender,
+        );
+
+        let err = task.run().await.unwrap_err();
+        assert!(matches!(
+            err.downcast_ref::<StoreAndForwardError>().unwrap(),
+            StoreAndForwardError::BadDhtHeaderSemanticallyInvalid
+        ));
+
+        assert_eq!(spy.call_count(), 0);
+        let requests = spy.take_requests();
+        assert_eq!(requests.len(), 0);
+    }
+
     #[tokio::test]
     async fn stored_at_in_future() {
         let spy = service_spy();
@@ -1021,7 +1256,7 @@ mod test {
         let msg_a = wrap_in_envelope_body!(&b"A".to_vec());
         let inbound_msg_a =
-            make_dht_inbound_message(&node_identity, &msg_a, DhtMessageFlags::ENCRYPTED, true, false).unwrap();
+            make_dht_inbound_message(&node_identity, &msg_a, DhtMessageFlags::ENCRYPTED, true, true).unwrap();
         peer_manager
             .add_peer(Clone::clone(&*inbound_msg_a.source_peer))
             .await
diff --git a/comms/dht/src/store_forward/service.rs b/comms/dht/src/store_forward/service.rs
index 3285795c87..d39132ff76 100644
--- a/comms/dht/src/store_forward/service.rs
+++ b/comms/dht/src/store_forward/service.rs
@@ -20,7 +20,11 @@
 // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
 // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
-use std::{convert::TryFrom, sync::Arc, time::Duration};
+use std::{
+    convert::{TryFrom, TryInto},
+    sync::Arc,
+    time::Duration,
+};
 
 use chrono::{DateTime, NaiveDateTime, Utc};
 use log::*;
@@ -67,6 +71,7 @@ pub struct FetchStoredMessageQuery {
     node_id: Box<NodeId>,
     since: Option<DateTime<Utc>>,
     response_type: SafResponseType,
+    limit: Option<u32>,
 }
 
 impl FetchStoredMessageQuery {
@@ -77,9 +82,16 @@ impl FetchStoredMessageQuery {
             node_id,
             since: None,
             response_type: SafResponseType::Anonymous,
+            limit: None,
         }
     }
 
+    /// Limit the number of messages returned
+    pub fn with_limit(&mut self, limit: u32) -> &mut Self {
+        self.limit = Some(limit);
+        self
+    }
+
     /// Modify query to only include messages since the given date.
     pub fn with_messages_since(&mut self, since: DateTime<Utc>) -> &mut Self {
         self.since = Some(since);
@@ -401,8 +413,7 @@ impl StoreAndForwardService {
             .finish(),
             request,
         )
-        .await
-        .map_err(StoreAndForwardError::RequestMessagesFailed)?;
+        .await?;
 
         Ok(())
     }
@@ -428,20 +439,21 @@ impl StoreAndForwardService {
             .finish(),
             request,
         )
-        .await
-        .map_err(StoreAndForwardError::RequestMessagesFailed)?;
+        .await?;
 
         Ok(())
     }
 
     async fn get_saf_request(&mut self) -> SafResult<StoredMessagesRequest> {
-        let request = self
+        let mut request = self
            .dht_requester
            .get_metadata(DhtMetadataKey::LastSafMessageReceived)
            .await?
            .map(StoredMessagesRequest::since)
            .unwrap_or_else(StoredMessagesRequest::new);
 
+        request.limit = self.config.max_returned_messages.try_into().unwrap_or(u32::MAX);
+
         Ok(request)
     }
 
@@ -473,9 +485,10 @@ impl StoreAndForwardService {
     fn handle_fetch_message_query(&self, query: &FetchStoredMessageQuery) -> SafResult<Vec<StoredMessage>> {
         use SafResponseType::{Anonymous, Discovery, ForMe, Join};
 
-        let limit = i64::try_from(self.config.max_returned_messages)
-            .ok()
-            .unwrap_or(std::i64::MAX);
+        let limit = query
+            .limit
+            .and_then(|v| i64::try_from(v).ok())
+            .unwrap_or(self.config.max_returned_messages as i64);
         let db = &self.database;
         let messages = match query.response_type {
             ForMe => db.find_messages_for_peer(&query.public_key, &query.node_id, query.since, limit)?,
diff --git a/comms/dht/src/store_forward/store.rs b/comms/dht/src/store_forward/store.rs
index 6c177cecfb..16283bede2 100644
--- a/comms/dht/src/store_forward/store.rs
+++ b/comms/dht/src/store_forward/store.rs
@@ -34,13 +34,7 @@ use tower::{layer::Layer, Service, ServiceExt};
 use super::StoreAndForwardRequester;
 use crate::{
     inbound::DecryptedDhtMessage,
-    store_forward::{
-        database::NewStoredMessage,
-        error::StoreAndForwardError,
-        message::StoredMessagePriority,
-        SafConfig,
-        SafResult,
-    },
+    store_forward::{database::NewStoredMessage, message::StoredMessagePriority, SafConfig, SafResult},
 };
 
 const LOG_TARGET: &str = "comms::dht::storeforward::store";
@@ -205,10 +199,12 @@ where S: Service + Se
         }
 
         message.set_saf_stored(false);
-        if let Some(priority) = self.get_storage_priority(&message).await? {
-            message.set_saf_stored(true);
-            let existing = self.store(priority, message.clone()).await?;
-            message.set_already_forwarded(existing);
+        if self.is_valid_for_storage(&message) {
+            if let Some(priority) = self.get_storage_priority(&message).await? {
+                message.set_saf_stored(true);
+                let existing = self.store(priority, message.clone()).await?;
+                message.set_already_forwarded(existing);
+            }
         }
 
         trace!(
@@ -222,6 +218,35 @@ where S: Service + Se
         service.oneshot(message).await
     }
 
+    fn is_valid_for_storage(&self, message: &DecryptedDhtMessage) -> bool {
+        if message.body_len() == 0 {
+            debug!(
+                target: LOG_TARGET,
+                "Message {} from peer '{}' not eligible for SAF storage because it has no body (Trace: {})",
+                message.tag,
+                message.source_peer.node_id.short_str(),
+                message.dht_header.message_tag
+            );
+            return false;
+        }
+
+        if let Some(expires) = message.dht_header.expires {
+            let now = EpochTime::now();
+            if expires < now {
+                debug!(
+                    target: LOG_TARGET,
+                    "Message {} from peer '{}' not eligible for SAF storage because it has expired (Trace: {})",
+                    message.tag,
+                    message.source_peer.node_id.short_str(),
+                    message.dht_header.message_tag
+                );
+                return false;
+            }
+        }
+
+        true
+    }
+
     async fn get_storage_priority(&self, message: &DecryptedDhtMessage) -> SafResult<Option<StoredMessagePriority>> {
         let log_not_eligible = |reason: &str| {
             debug!(
@@ -248,13 +273,8 @@ where S: Service + Se
             return Ok(None);
         }
 
-        if message.dht_header.message_type.is_dht_join() {
-            log_not_eligible("it is a join message");
-            return Ok(None);
-        }
-
-        if message.dht_header.message_type.is_dht_discovery() {
-            log_not_eligible("it is a discovery message");
+        if message.dht_header.message_type.is_dht_message() {
+            log_not_eligible(&format!("it is a DHT {} message", message.dht_header.message_type));
             return Ok(None);
         }
 
@@ -389,13 +409,6 @@ where S: Service + Se
             message.dht_header.message_tag,
         );
 
-        if let Some(expires) = message.dht_header.expires {
-            let now = EpochTime::now();
-            if expires < now {
-                return Err(StoreAndForwardError::NotStoringExpiredMessage { expired: expires, now });
-            }
-        }
-
         let stored_message = NewStoredMessage::new(message, priority);
         self.saf_requester.insert_message(stored_message).await
    }
diff --git a/comms/dht/tests/attacks.rs b/comms/dht/tests/attacks.rs
index c7c460fc87..319e17168c 100644
--- a/comms/dht/tests/attacks.rs
+++ b/comms/dht/tests/attacks.rs
@@ -99,7 +99,7 @@ async fn large_join_messages_with_many_addresses() {
        .await
        .unwrap(),
        expect = true,
-        max_attempts = 10,
+        max_attempts = 20,
        interval = Duration::from_secs(1)
    );
    // Node B did not propagate
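
A note on the shape of the new SAF error handling in task.rs: the large match in process_saf_message_validation_result implements a three-way policy. Valid messages pass through, failures that do not implicate the forwarding peer are dropped quietly, and failures that a well-behaved forwarder could never produce result in a ban, returning Some(Err(..)) so the caller's `result?` aborts the rest of the batch. Below is a minimal, self-contained sketch of that policy; the Triage and SafError types and the triage function are illustrative stand-ins, not part of the patch.

/// Sketch of the triage policy; OffenceSeverity mirrors the name used in the
/// patch, while SafError is a reduced, hypothetical stand-in for
/// StoreAndForwardError.
#[derive(Debug)]
enum OffenceSeverity {
    Medium,
    High,
}

#[derive(Debug)]
enum Triage {
    /// Drop quietly; the forwarding peer is not at fault.
    Discard,
    /// Ban the forwarding peer and propagate the error to abort the batch.
    Ban(OffenceSeverity),
}

#[derive(Debug)]
#[allow(dead_code)]
enum SafError {
    DecryptionFailed,
    DuplicateMessage,
    MalformedEnvelopeBody,
    BadDhtHeaderSemanticallyInvalid,
    InvalidDestination,
    PeerSentDhtMessageViaSaf,
}

fn triage(err: &SafError) -> Triage {
    use SafError::*;
    match err {
        // Not addressed to us, or already seen: expected in normal operation.
        DecryptionFailed | DuplicateMessage => Triage::Discard,
        // Structurally invalid payload the forwarder could have checked itself.
        MalformedEnvelopeBody => Triage::Ban(OffenceSeverity::Medium),
        // Protocol violations that a well-behaved node would never forward.
        BadDhtHeaderSemanticallyInvalid | InvalidDestination | PeerSentDhtMessageViaSaf => {
            Triage::Ban(OffenceSeverity::High)
        },
    }
}

fn main() {
    assert!(matches!(triage(&SafError::DuplicateMessage), Triage::Discard));
    assert!(matches!(
        triage(&SafError::InvalidDestination),
        Triage::Ban(OffenceSeverity::High)
    ));
}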
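A note on the fetch limit plumbing in service.rs: the requester now stamps StoredMessagesRequest with its own cap (max_returned_messages, saturated into a u32), and the responder honours a limit carried on the query, presumably populated from the incoming request, falling back to its local max_returned_messages when none is supplied. A small sketch of the responder-side resolution, under the assumption stated above; resolve_limit is a hypothetical name.

use std::convert::TryFrom;

/// Resolve how many stored messages to return for a fetch query. Mirrors the
/// new handle_fetch_message_query logic: honour the limit carried by the
/// query when present, otherwise fall back to the local configuration.
fn resolve_limit(requested: Option<u32>, max_returned_messages: usize) -> i64 {
    requested
        .and_then(|v| i64::try_from(v).ok())
        .unwrap_or(max_returned_messages as i64)
}

fn main() {
    // The requester asked for at most 50 messages; the responder honours it.
    assert_eq!(resolve_limit(Some(50), 500), 50);
    // No limit supplied (e.g. an older peer): fall back to the local maximum.
    assert_eq!(resolve_limit(None, 500), 500);
}

Worth noting as a design point: as written, the responder takes a supplied limit at face value rather than clamping it to its own maximum, so the local cap only applies when the requester omits the field.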
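A note on the storage eligibility change in store.rs: the expiry check moves from store-time, where it raised NotStoringExpiredMessage, to an up-front is_valid_for_storage predicate alongside a new empty-body check, so ineligible messages are now skipped with a debug log instead of surfacing as errors. A reduced sketch of the predicate follows; Msg is a hypothetical stand-in for DecryptedDhtMessage, with Unix-second timestamps in place of EpochTime.

use std::time::{SystemTime, UNIX_EPOCH};

/// Hypothetical stand-in for the fields that is_valid_for_storage reads from
/// DecryptedDhtMessage; expires_at models dht_header.expires as Unix seconds.
struct Msg {
    body_len: usize,
    expires_at: Option<u64>,
}

/// Mirrors the new is_valid_for_storage predicate: reject empty bodies and
/// already-expired messages before doing any storage-priority work.
fn is_valid_for_storage(msg: &Msg) -> bool {
    if msg.body_len == 0 {
        return false;
    }
    if let Some(expires) = msg.expires_at {
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|d| d.as_secs())
            .unwrap_or(0);
        if expires < now {
            return false;
        }
    }
    true
}

fn main() {
    assert!(!is_valid_for_storage(&Msg { body_len: 0, expires_at: None }));
    assert!(!is_valid_for_storage(&Msg { body_len: 42, expires_at: Some(0) }));
    assert!(is_valid_for_storage(&Msg { body_len: 42, expires_at: None }));
}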