diff --git a/comms/dht/src/actor.rs b/comms/dht/src/actor.rs
index c2c2d4e52a..ce56e8765b 100644
--- a/comms/dht/src/actor.rs
+++ b/comms/dht/src/actor.rs
@@ -101,9 +101,9 @@ impl From<mpsc::SendError> for DhtActorError {
 pub enum DhtRequest {
     /// Send a Join request to the network
     SendJoin,
-    /// Inserts a message signature to the msg hash cache. This operation replies with a boolean
-    /// which is true if the signature already exists in the cache, otherwise false
-    MsgHashCacheInsert(Vec<u8>, CommsPublicKey, oneshot::Sender<bool>),
+    /// Inserts a message signature into the msg hash cache. This operation replies with the number of times this
+    /// message has been seen, including this sighting (the hit count)
+    MsgHashCacheInsert(Vec<u8>, CommsPublicKey, oneshot::Sender<u32>),
     /// Fetch selected peers according to the broadcast strategy
     SelectPeers(BroadcastStrategy, oneshot::Sender<Vec<Peer>>),
     GetMetadata(DhtMetadataKey, oneshot::Sender<Result<Option<Vec<u8>>, DhtActorError>>),
@@ -151,7 +151,7 @@ impl DhtRequester {
         &mut self,
         message_hash: Vec<u8>,
         public_key: CommsPublicKey,
-    ) -> Result<bool, DhtActorError> {
+    ) -> Result<u32, DhtActorError> {
         let (reply_tx, reply_rx) = oneshot::channel();
         self.sender
             .send(DhtRequest::MsgHashCacheInsert(message_hash, public_key, reply_tx))
@@ -268,7 +268,7 @@ impl DhtActor {
                 },
                 _ = dedup_cache_trim_ticker.select_next_some() => {
-                    if let Err(err) = self.msg_hash_dedup_cache.truncate().await {
+                    if let Err(err) = self.msg_hash_dedup_cache.trim_entries().await {
                         error!(target: LOG_TARGET, "Error when trimming message dedup cache: {:?}", err);
                     }
                 },
@@ -303,16 +303,16 @@ impl DhtActor {
             MsgHashCacheInsert(hash, public_key, reply_tx) => {
                 let msg_hash_cache = self.msg_hash_dedup_cache.clone();
                 Box::pin(async move {
-                    match msg_hash_cache.insert_body_hash_if_unique(hash, public_key).await {
-                        Ok(already_exists) => {
-                            let _ = reply_tx.send(already_exists).map_err(|_| DhtActorError::ReplyCanceled);
+                    match msg_hash_cache.add_body_hash(hash, public_key).await {
+                        Ok(hit_count) => {
+                            let _ = reply_tx.send(hit_count);
                         },
                         Err(err) => {
                             warn!(
                                 target: LOG_TARGET,
                                 "Unable to update message dedup cache because {:?}", err
                             );
-                            let _ = reply_tx.send(false).map_err(|_| DhtActorError::ReplyCanceled);
+                            let _ = reply_tx.send(0);
                         },
                     }
                     Ok(())
@@ -690,11 +690,9 @@ mod test {
         test_utils::{build_peer_manager, make_client_identity, make_node_identity},
     };
     use chrono::{DateTime, Utc};
-    use std::time::Duration;
     use tari_comms::test_utils::mocks::{create_connectivity_mock, create_peer_connection_mock_pair};
     use tari_shutdown::Shutdown;
     use tari_test_utils::random;
-    use tokio::time::delay_for;
 
     async fn db_connection() -> DbConnection {
         let conn = DbConnection::connect_memory(random::string(8)).await.unwrap();
@@ -756,21 +754,21 @@ mod test {
         actor.spawn();
 
         let signature = vec![1u8, 2, 3];
-        let is_dup = requester
+        let num_hits = requester
            .insert_message_hash(signature.clone(), CommsPublicKey::default())
            .await
            .unwrap();
-        assert!(!is_dup);
-        let is_dup = requester
+        assert_eq!(num_hits, 1);
+        let num_hits = requester
            .insert_message_hash(signature, CommsPublicKey::default())
            .await
            .unwrap();
-        assert!(is_dup);
-        let is_dup = requester
+        assert_eq!(num_hits, 2);
+        let num_hits = requester
            .insert_message_hash(Vec::new(), CommsPublicKey::default())
            .await
            .unwrap();
-        assert!(!is_dup);
+        assert_eq!(num_hits, 1);
     }
 
     #[tokio_macros::test_basic]
@@ -783,14 +781,12 @@ mod test {
         let (actor_tx, actor_rx) = mpsc::channel(1);
         let mut requester = DhtRequester::new(actor_tx);
         let outbound_requester = OutboundMessageRequester::new(out_tx);
-        let mut shutdown = Shutdown::new();
-        let trim_interval_ms = 500;
+        let shutdown = Shutdown::new();
         // Note: This must be equal or larger than the minimum dedup cache capacity for DedupCacheDatabase
-        let capacity = 120;
+        let capacity = 10;
         let actor = DhtActor::new(
             DhtConfig {
                 dedup_cache_capacity: capacity,
-                dedup_cache_trim_interval: Duration::from_millis(trim_interval_ms),
                 ..Default::default()
             },
             db_connection().await,
@@ -803,63 +799,61 @@ mod test {
         );
 
         // Create signatures for double the dedup cache capacity
-        let mut signatures: Vec<Vec<u8>> = Vec::new();
-        for i in 0..(capacity * 2) {
-            signatures.push(vec![1u8, 2, i as u8])
-        }
+        let signatures = (0..(capacity * 2)).map(|i| vec![1u8, 2, i as u8]).collect::<Vec<_>>();
 
-        // Pre-populate the dedup cache; everything should be accepted due to cleanup ticker not active yet
+        // Pre-populate the dedup cache; everything should be accepted because the cleanup ticker has not run yet
         for key in &signatures {
-            let is_dup = actor
+            let num_hits = actor
                 .msg_hash_dedup_cache
-                .insert_body_hash_if_unique(key.clone(), CommsPublicKey::default())
+                .add_body_hash(key.clone(), CommsPublicKey::default())
                 .await
                 .unwrap();
-            assert!(!is_dup);
+            assert_eq!(num_hits, 1);
         }
-        // Try to re-insert all; everything should be marked as duplicates due to cleanup ticker not active yet
+        // Try to re-insert all; every hash should have its hit count incremented
         for key in &signatures {
-            let is_dup = actor
+            let num_hits = actor
                 .msg_hash_dedup_cache
-                .insert_body_hash_if_unique(key.clone(), CommsPublicKey::default())
+                .add_body_hash(key.clone(), CommsPublicKey::default())
                 .await
                 .unwrap();
-            assert!(is_dup);
+            assert_eq!(num_hits, 2);
         }
 
-        // The cleanup ticker starts when the actor is spawned; the first cleanup event will fire immediately
+        let dedup_cache_db = actor.msg_hash_dedup_cache.clone();
+        // The cleanup ticker starts when the actor is spawned; the first cleanup event fires fairly soon after the
+        // task is running on a thread. To remove this race condition, we trim the cache in the test ourselves.
+        dedup_cache_db.trim_entries().await.unwrap();
         actor.spawn();
 
         // Verify that the last half of the signatures are still present in the cache
         for key in signatures.iter().take(capacity * 2).skip(capacity) {
-            let is_dup = requester
+            let num_hits = requester
                 .insert_message_hash(key.clone(), CommsPublicKey::default())
                 .await
                 .unwrap();
-            assert!(is_dup);
+            assert_eq!(num_hits, 3);
         }
         // Verify that the first half of the signatures have been removed and can be re-inserted into cache
         for key in signatures.iter().take(capacity) {
-            let is_dup = requester
+            let num_hits = requester
                 .insert_message_hash(key.clone(), CommsPublicKey::default())
                 .await
                 .unwrap();
-            assert!(!is_dup);
+            assert_eq!(num_hits, 1);
         }
 
-        // Let the trim period expire; this will trim the dedup cache to capacity
-        delay_for(Duration::from_millis(trim_interval_ms * 2)).await;
+        // Trim the database of excess entries
+        dedup_cache_db.trim_entries().await.unwrap();
 
         // Verify that the last half of the signatures have been removed and can be re-inserted into cache
         for key in signatures.iter().take(capacity * 2).skip(capacity) {
-            let is_dup = requester
+            let num_hits = requester
                 .insert_message_hash(key.clone(), CommsPublicKey::default())
                 .await
                 .unwrap();
-            assert!(!is_dup);
+            assert_eq!(num_hits, 1);
         }
-
-        shutdown.trigger().unwrap();
     }
 
     #[tokio_macros::test_basic]
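The heart of the change is visible in this first file: MsgHashCacheInsert now replies with an inclusive hit count (u32) instead of a boolean, so a first sighting answers 1, each repeat counts up from there, and a failed cache update answers 0. For illustration, a minimal standalone sketch of that request/reply contract (hypothetical names, not the tari-dht API; assumes only the futures crate):

// Sketch of the new reply contract: the actor answers with the inclusive hit count.
use futures::channel::oneshot;
use std::collections::HashMap;

struct DedupActor {
    hits: HashMap<Vec<u8>, u32>,
}

impl DedupActor {
    fn handle_insert(&mut self, hash: Vec<u8>, reply_tx: oneshot::Sender<u32>) {
        // First sighting replies 1, the first duplicate 2, and so on
        let count = self.hits.entry(hash).and_modify(|c| *c += 1).or_insert(1);
        // As in the actor above, a dropped receiver is simply ignored
        let _ = reply_tx.send(*count);
    }
}

fn main() {
    let mut actor = DedupActor { hits: HashMap::new() };
    let (tx, rx) = oneshot::channel();
    actor.handle_insert(vec![1, 2, 3], tx);
    assert_eq!(futures::executor::block_on(rx), Ok(1));
}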
diff --git a/comms/dht/src/builder.rs b/comms/dht/src/builder.rs
index 249ed3d369..382d28cf4a 100644
--- a/comms/dht/src/builder.rs
+++ b/comms/dht/src/builder.rs
@@ -99,6 +99,11 @@ impl DhtBuilder {
         self
     }
 
+    pub fn with_dedup_discard_hit_count(mut self, max_hit_count: usize) -> Self {
+        self.config.dedup_discard_hit_count = max_hit_count;
+        self
+    }
+
     pub fn with_num_random_nodes(mut self, n: usize) -> Self {
         self.config.num_random_nodes = n;
         self
diff --git a/comms/dht/src/config.rs b/comms/dht/src/config.rs
index 0612445dca..09c36dcef0 100644
--- a/comms/dht/src/config.rs
+++ b/comms/dht/src/config.rs
@@ -72,6 +72,9 @@ pub struct DhtConfig {
     /// The periodic trim interval for items in the message hash cache
     /// Default: 300s (5 mins)
     pub dedup_cache_trim_interval: Duration,
+    /// A message is discarded once its dedup hit count reaches this value
+    /// Default: 3
+    pub dedup_discard_hit_count: usize,
     /// The duration to wait for a peer discovery to complete before giving up.
     /// Default: 2 minutes
     pub discovery_request_timeout: Duration,
@@ -136,6 +139,7 @@ impl DhtConfig {
 
 impl Default for DhtConfig {
     fn default() -> Self {
+        // NB: please remember to update field comments to reflect these defaults
         Self {
             num_neighbouring_nodes: 8,
             num_random_nodes: 4,
@@ -151,6 +155,7 @@ impl Default for DhtConfig {
             saf_max_message_size: 512 * 1024,
             dedup_cache_capacity: 2_500,
             dedup_cache_trim_interval: Duration::from_secs(5 * 60),
+            dedup_discard_hit_count: 3,
             database_url: DbConnectionUrl::Memory,
             discovery_request_timeout: Duration::from_secs(2 * 60),
             connectivity_update_interval: Duration::from_secs(2 * 60),
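How the new knob behaves: with the default of 3, a message is propagated on its first two sightings and dropped from the third onward. A hypothetical mirror of just the fields touched here (the real DhtConfig has many more), showing the default and the builder-style override:

use std::time::Duration;

// Hypothetical mirror of the DhtConfig fields touched by this patch.
struct Config {
    dedup_cache_capacity: usize,
    dedup_cache_trim_interval: Duration,
    dedup_discard_hit_count: usize,
}

impl Default for Config {
    fn default() -> Self {
        Self {
            dedup_cache_capacity: 2_500,
            dedup_cache_trim_interval: Duration::from_secs(5 * 60),
            // Propagated on hits 1 and 2; discarded from the 3rd hit onward
            dedup_discard_hit_count: 3,
        }
    }
}

fn main() {
    // Equivalent of DhtBuilder::with_dedup_discard_hit_count(5)
    let config = Config {
        dedup_discard_hit_count: 5,
        ..Default::default()
    };
    assert_eq!(config.dedup_discard_hit_count, 5);
}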
diff --git a/comms/dht/src/dedup/dedup_cache.rs b/comms/dht/src/dedup/dedup_cache.rs
index f8f5f6fcbf..564206651d 100644
--- a/comms/dht/src/dedup/dedup_cache.rs
+++ b/comms/dht/src/dedup/dedup_cache.rs
@@ -24,15 +24,23 @@ use crate::{
     schema::dedup_cache,
     storage::{DbConnection, StorageError},
 };
-use chrono::Utc;
+use chrono::{NaiveDateTime, Utc};
 use diesel::{dsl, result::DatabaseErrorKind, ExpressionMethods, QueryDsl, RunQueryDsl};
 use log::*;
 use tari_comms::types::CommsPublicKey;
-use tari_crypto::tari_utilities::{hex::Hex, ByteArray};
-use tari_utilities::hex;
+use tari_crypto::tari_utilities::hex::Hex;
 
 const LOG_TARGET: &str = "comms::dht::dedup_cache";
 
+#[derive(Queryable, PartialEq, Debug)]
+struct DedupCacheEntry {
+    body_hash: String,
+    sender_public_key: String,
+    number_of_hits: i32,
+    stored_at: NaiveDateTime,
+    last_hit_at: NaiveDateTime,
+}
+
 #[derive(Clone)]
 pub struct DedupCacheDatabase {
     connection: DbConnection,
@@ -48,36 +56,24 @@ impl DedupCacheDatabase {
         Self { connection, capacity }
     }
 
-    /// Inserts and returns Ok(true) if the item already existed and Ok(false) if it didn't, also updating hit stats
-    pub async fn insert_body_hash_if_unique(
-        &self,
-        body_hash: Vec<u8>,
-        public_key: CommsPublicKey,
-    ) -> Result<bool, StorageError> {
-        let body_hash = hex::to_hex(&body_hash.as_bytes());
-        let public_key = public_key.to_hex();
-        match self
-            .insert_body_hash_or_update_stats(body_hash.clone(), public_key.clone())
-            .await
-        {
-            Ok(val) => {
-                if val == 0 {
-                    warn!(
-                        target: LOG_TARGET,
-                        "Unable to insert new entry into message dedup cache"
-                    );
-                }
-                Ok(false)
-            },
-            Err(e) => match e {
-                StorageError::UniqueViolation(_) => Ok(true),
-                _ => Err(e),
-            },
+    /// Adds the body hash to the cache, returning the number of hits (inclusive) that have been recorded for this
+    /// body hash
+    pub async fn add_body_hash(&self, body_hash: Vec<u8>, public_key: CommsPublicKey) -> Result<u32, StorageError> {
+        let hit_count = self
+            .insert_body_hash_or_update_stats(body_hash.to_hex(), public_key.to_hex())
+            .await?;
+
+        if hit_count == 0 {
+            warn!(
+                target: LOG_TARGET,
+                "Unable to insert new entry into message dedup cache"
+            );
         }
+        Ok(hit_count)
     }
 
     /// Trims the dedup cache to the configured limit by removing the oldest entries
-    pub async fn truncate(&self) -> Result<usize, StorageError> {
+    pub async fn trim_entries(&self) -> Result<usize, StorageError> {
         let capacity = self.capacity;
         self.connection
             .with_connection_async(move |conn| {
@@ -109,40 +105,46 @@ impl DedupCacheDatabase {
             .await
     }
 
-    // Insert new row into the table or update existing row in an atomic fashion; more than one thread can access this
-    // table at the same time.
+    /// Inserts a new row into the table or updates an existing row. Returns the number of hits for this body hash.
     async fn insert_body_hash_or_update_stats(
         &self,
         body_hash: String,
         public_key: String,
-    ) -> Result<usize, StorageError> {
+    ) -> Result<u32, StorageError> {
         self.connection
             .with_connection_async(move |conn| {
                 let insert_result = diesel::insert_into(dedup_cache::table)
                     .values((
-                        dedup_cache::body_hash.eq(body_hash.clone()),
-                        dedup_cache::sender_public_key.eq(public_key.clone()),
+                        dedup_cache::body_hash.eq(&body_hash),
+                        dedup_cache::sender_public_key.eq(&public_key),
                         dedup_cache::number_of_hits.eq(1),
                         dedup_cache::last_hit_at.eq(Utc::now().naive_utc()),
                    ))
                    .execute(conn);
                 match insert_result {
-                    Ok(val) => Ok(val),
+                    Ok(1) => Ok(1),
+                    Ok(n) => Err(StorageError::UnexpectedResult(format!(
+                        "Expected exactly one row to be inserted. Got {}",
+                        n
+                    ))),
                     Err(diesel::result::Error::DatabaseError(kind, e_info)) => match kind {
                         DatabaseErrorKind::UniqueViolation => {
                             // Update hit stats for the message
-                            let result =
-                                diesel::update(dedup_cache::table.filter(dedup_cache::body_hash.eq(&body_hash)))
-                                    .set((
-                                        dedup_cache::sender_public_key.eq(public_key),
-                                        dedup_cache::number_of_hits.eq(dedup_cache::number_of_hits + 1),
-                                        dedup_cache::last_hit_at.eq(Utc::now().naive_utc()),
-                                    ))
-                                    .execute(conn);
-                            match result {
-                                Ok(_) => Err(StorageError::UniqueViolation(body_hash)),
-                                Err(e) => Err(e.into()),
-                            }
+                            diesel::update(dedup_cache::table.filter(dedup_cache::body_hash.eq(&body_hash)))
+                                .set((
+                                    dedup_cache::sender_public_key.eq(&public_key),
+                                    dedup_cache::number_of_hits.eq(dedup_cache::number_of_hits + 1),
+                                    dedup_cache::last_hit_at.eq(Utc::now().naive_utc()),
+                                ))
+                                .execute(conn)?;
+                            // TODO: Diesel support for RETURNING statements would remove this query, but is not
+                            // available for Diesel + SQLite yet
+                            let hits = dedup_cache::table
+                                .select(dedup_cache::number_of_hits)
+                                .filter(dedup_cache::body_hash.eq(&body_hash))
+                                .get_result::<i32>(conn)?;
+
+                            Ok(hits as u32)
                        },
                        _ => Err(diesel::result::Error::DatabaseError(kind, e_info).into()),
                    },
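Taken together, add_body_hash and trim_entries define the contract the actor tests above rely on: every add returns the inclusive hit count, and trimming evicts the oldest entries beyond capacity so a re-added hash starts again at 1. A minimal in-memory sketch of that contract (illustrative names; the real store is SQLite behind Diesel):

// In-memory stand-in for the dedup cache contract.
use std::collections::HashMap;

#[derive(Default)]
struct DedupCache {
    capacity: usize,
    // body hash -> (hit count, insertion order)
    entries: HashMap<String, (u32, u64)>,
    next_seq: u64,
}

impl DedupCache {
    fn add_body_hash(&mut self, body_hash: &str) -> u32 {
        let seq = self.next_seq;
        self.next_seq += 1;
        let entry = self.entries.entry(body_hash.to_string()).or_insert((0, seq));
        entry.0 += 1;
        entry.0 // inclusive hit count: 1 on the first insert
    }

    // Like trim_entries: drop the oldest entries until we are within capacity
    fn trim_entries(&mut self) -> usize {
        let excess = self.entries.len().saturating_sub(self.capacity);
        let mut oldest: Vec<_> = self.entries.iter().map(|(k, v)| (v.1, k.clone())).collect();
        oldest.sort_unstable();
        for (_, key) in oldest.into_iter().take(excess) {
            self.entries.remove(&key);
        }
        excess
    }
}

fn main() {
    let mut cache = DedupCache { capacity: 1, ..Default::default() };
    assert_eq!(cache.add_body_hash("a"), 1);
    assert_eq!(cache.add_body_hash("a"), 2);
    assert_eq!(cache.add_body_hash("b"), 1);
    assert_eq!(cache.trim_entries(), 1); // "a" is oldest and is evicted
    assert_eq!(cache.add_body_hash("a"), 1); // the count restarts after eviction
}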
diff --git a/comms/dht/src/dedup/mod.rs b/comms/dht/src/dedup/mod.rs
index 5428277af0..a93ccfd980 100644
--- a/comms/dht/src/dedup/mod.rs
+++ b/comms/dht/src/dedup/mod.rs
@@ -47,13 +47,15 @@ fn hash_inbound_message(message: &DhtInboundMessage) -> Vec<u8> {
 pub struct DedupMiddleware<S> {
     next_service: S,
     dht_requester: DhtRequester,
+    discard_on_hit_count: usize,
 }
 
 impl<S> DedupMiddleware<S> {
-    pub fn new(service: S, dht_requester: DhtRequester) -> Self {
+    pub fn new(service: S, dht_requester: DhtRequester, discard_on_hit_count: usize) -> Self {
         Self {
             next_service: service,
             dht_requester,
+            discard_on_hit_count,
         }
     }
 }
@@ -71,9 +73,10 @@ where
         Poll::Ready(Ok(()))
     }
 
-    fn call(&mut self, message: DhtInboundMessage) -> Self::Future {
+    fn call(&mut self, mut message: DhtInboundMessage) -> Self::Future {
         let next_service = self.next_service.clone();
         let mut dht_requester = self.dht_requester.clone();
+        let discard_on_hit_count = self.discard_on_hit_count;
         Box::pin(async move {
             let hash = hash_inbound_message(&message);
             trace!(
@@ -83,14 +86,17 @@ where
                 message.tag,
                 message.dht_header.message_tag
             );
-            if dht_requester
+
+            message.dedup_hit_count = dht_requester
                 .insert_message_hash(hash, message.source_peer.public_key.clone())
-                .await?
-            {
+                .await?;
+
+            if message.dedup_hit_count as usize >= discard_on_hit_count {
                 trace!(
                     target: LOG_TARGET,
-                    "Received duplicate message {} from peer '{}' (Trace: {}). Message discarded.",
+                    "Received duplicate message {} (hit_count = {}) from peer '{}' (Trace: {}). Message discarded.",
                     message.tag,
+                    message.dedup_hit_count,
                     message.source_peer.node_id.short_str(),
                     message.dht_header.message_tag,
                 );
@@ -99,8 +105,9 @@ where
 
             trace!(
                 target: LOG_TARGET,
-                "Passing message {} onto next service (Trace: {})",
+                "Passing message {} (hit_count = {}) onto next service (Trace: {})",
                 message.tag,
+                message.dedup_hit_count,
                 message.dht_header.message_tag
             );
             next_service.oneshot(message).await
@@ -110,11 +117,15 @@ where
 
 pub struct DedupLayer {
     dht_requester: DhtRequester,
+    discard_on_hit_count: usize,
 }
 
 impl DedupLayer {
-    pub fn new(dht_requester: DhtRequester) -> Self {
-        Self { dht_requester }
+    pub fn new(dht_requester: DhtRequester, discard_on_hit_count: usize) -> Self {
+        Self {
+            dht_requester,
+            discard_on_hit_count,
+        }
     }
 }
 
@@ -122,7 +133,7 @@ impl<S> Layer<S> for DedupLayer {
     type Service = DedupMiddleware<S>;
 
     fn layer(&self, service: S) -> Self::Service {
-        DedupMiddleware::new(service, self.dht_requester.clone())
+        DedupMiddleware::new(service, self.dht_requester.clone(), self.discard_on_hit_count)
     }
 }
 
@@ -143,10 +154,10 @@ mod test {
         let (dht_requester, mock) = create_dht_actor_mock(1);
         let mock_state = mock.get_shared_state();
-        mock_state.set_signature_cache_insert(false);
+        mock_state.set_number_of_message_hits(1);
         rt.spawn(mock.run());
 
-        let mut dedup = DedupLayer::new(dht_requester).layer(spy.to_service::<PipelineError>());
+        let mut dedup = DedupLayer::new(dht_requester, 3).layer(spy.to_service::<PipelineError>());
 
         panic_context!(cx);
 
@@ -157,7 +168,7 @@ mod test {
         rt.block_on(dedup.call(msg.clone())).unwrap();
         assert_eq!(spy.call_count(), 1);
 
-        mock_state.set_signature_cache_insert(true);
+        mock_state.set_number_of_message_hits(3);
         rt.block_on(dedup.call(msg)).unwrap();
         assert_eq!(spy.call_count(), 1);
         // Drop dedup so that the DhtMock will stop running
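The middleware's discard rule is a single comparison between the message's hit count and the configured threshold. Extracted as a pure function (a sketch) to make the boundary cases explicit:

// The discard decision used by DedupMiddleware, in isolation.
fn should_discard(dedup_hit_count: u32, discard_on_hit_count: usize) -> bool {
    dedup_hit_count as usize >= discard_on_hit_count
}

fn main() {
    // With the default dedup_discard_hit_count of 3: the first two sightings
    // pass through and the third is dropped.
    assert!(!should_discard(1, 3));
    assert!(!should_discard(2, 3));
    assert!(should_discard(3, 3));
}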
diff --git a/comms/dht/src/dht.rs b/comms/dht/src/dht.rs
index dcdeea5730..32c778f8ca 100644
--- a/comms/dht/src/dht.rs
+++ b/comms/dht/src/dht.rs
@@ -26,6 +26,7 @@ use crate::{
     connectivity::{DhtConnectivity, MetricsCollector, MetricsCollectorHandle},
     discovery::{DhtDiscoveryRequest, DhtDiscoveryRequester, DhtDiscoveryService},
     event::{DhtEventReceiver, DhtEventSender},
+    filter,
     inbound,
     inbound::{DecryptedDhtMessage, DhtInboundMessage, MetricsLayer},
     logging_middleware::MessageLoggingLayer,
@@ -37,12 +38,11 @@ use crate::{
     storage::{DbConnection, StorageError},
     store_forward,
     store_forward::{StoreAndForwardError, StoreAndForwardRequest, StoreAndForwardRequester, StoreAndForwardService},
-    tower_filter,
     DedupLayer,
     DhtActorError,
     DhtConfig,
 };
-use futures::{channel::mpsc, future, Future};
+use futures::{channel::mpsc, Future};
 use log::*;
 use std::sync::Arc;
 use tari_comms::{
@@ -285,13 +285,14 @@ impl Dht {
         S: Service<DecryptedDhtMessage, Response = (), Error = PipelineError> + Clone + Send + Sync + 'static,
         S::Future: Send,
     {
-        // FIXME: There is an unresolved stack overflow issue on windows in debug mode during runtime, but not in
-        // release mode, related to the amount of layers. (issue #1416)
         ServiceBuilder::new()
             .layer(MetricsLayer::new(self.metrics_collector.clone()))
             .layer(inbound::DeserializeLayer::new(self.peer_manager.clone()))
-            .layer(DedupLayer::new(self.dht_requester()))
-            .layer(tower_filter::FilterLayer::new(self.unsupported_saf_messages_filter()))
+            .layer(DedupLayer::new(
+                self.dht_requester(),
+                self.config.dedup_discard_hit_count,
+            ))
+            .layer(filter::FilterLayer::new(self.unsupported_saf_messages_filter()))
             .layer(MessageLoggingLayer::new(format!(
                 "Inbound [{}]",
                 self.node_identity.node_id().short_str()
@@ -301,6 +302,7 @@ impl Dht {
                 self.node_identity.clone(),
                 self.connectivity.clone(),
             ))
+            .layer(filter::FilterLayer::new(filter_messages_to_rebroadcast))
             .layer(store_forward::StoreLayer::new(
                 self.config.clone(),
                 Arc::clone(&self.peer_manager),
@@ -363,34 +365,60 @@ impl Dht {
 
     /// Produces a filter predicate which disallows store and forward messages if that feature is not
     /// supported by the node.
-    fn unsupported_saf_messages_filter(
-        &self,
-    ) -> impl tower_filter::Predicate<DhtInboundMessage, Future = future::Ready<Result<(), anyhow::Error>>> + Clone + Send
-    {
+    fn unsupported_saf_messages_filter(&self) -> impl filter::Predicate<DhtInboundMessage> + Clone + Send {
         let node_identity = Arc::clone(&self.node_identity);
         move |msg: &DhtInboundMessage| {
             if node_identity.has_peer_features(PeerFeatures::DHT_STORE_FORWARD) {
-                return future::ready(Ok(()));
+                return true;
             }
             match msg.dht_header.message_type {
                 DhtMessageType::SafRequestMessages => {
                     // TODO: #banheuristic This is an indication of node misbehaviour
-                    debug!(
+                    warn!(
                         "Received store and forward message from PublicKey={}. Store and forward feature is not \
                          supported by this node. Discarding message.",
                         msg.source_peer.public_key
                     );
-                    future::ready(Err(anyhow::anyhow!(
-                        "Message filtered out because store and forward is not supported by this node",
-                    )))
+                    false
                 },
-                _ => future::ready(Ok(())),
+                _ => true,
             }
         }
     }
 }
 
+/// Provides the gossip filtering rules for an inbound message
+fn filter_messages_to_rebroadcast(msg: &DecryptedDhtMessage) -> bool {
+    // Let the message through if:
+    // it isn't a duplicate (normal message), or
+    let should_continue = !msg.is_duplicate() ||
+        (
+            // it is a duplicate domain message (i.e. not a DHT or SAF protocol message), and
+            msg.dht_header.message_type.is_domain_message() &&
+            // it has an unknown destination (e.g. complete transactions, blocks, misc. encrypted messages),
+            // in which case we let it proceed and it is re-propagated for another round
+            msg.dht_header.destination.is_unknown()
+        );
+
+    if should_continue {
+        // Pass the message on; downstream middleware may be interested in it
+        debug!(
+            target: LOG_TARGET,
+            "[filter_messages_to_rebroadcast] Passing message {} to next service (Trace: {})",
+            msg.tag,
+            msg.dht_header.message_tag
+        );
+        true
+    } else {
+        debug!(
+            target: LOG_TARGET,
+            "[filter_messages_to_rebroadcast] Discarding duplicate message {}", msg
+        );
+        false
+    }
+}
+
 #[cfg(test)]
 mod test {
     use crate::{
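filter_messages_to_rebroadcast boils down to one boolean rule: anything seen for the first time passes, and among duplicates only domain messages with an unknown destination are re-propagated. The skeleton of that rule (sketch, not the tari types):

// The rebroadcast rule reduced to its boolean skeleton.
fn allow_rebroadcast(is_duplicate: bool, is_domain_message: bool, destination_unknown: bool) -> bool {
    !is_duplicate || (is_domain_message && destination_unknown)
}

fn main() {
    assert!(allow_rebroadcast(false, false, false)); // first sighting always passes
    assert!(allow_rebroadcast(true, true, true)); // duplicate domain gossip goes another round
    assert!(!allow_rebroadcast(true, false, true)); // duplicate DHT/SAF protocol message is dropped
    assert!(!allow_rebroadcast(true, true, false)); // duplicate with a known destination is dropped
}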
Message discarded.", message.tag, + message.dedup_hit_count, message.source_peer.node_id.short_str(), message.dht_header.message_tag, ); @@ -99,8 +105,9 @@ where trace!( target: LOG_TARGET, - "Passing message {} onto next service (Trace: {})", + "Passing message {} (hit_count = {}) onto next service (Trace: {})", message.tag, + message.dedup_hit_count, message.dht_header.message_tag ); next_service.oneshot(message).await @@ -110,11 +117,15 @@ where pub struct DedupLayer { dht_requester: DhtRequester, + discard_on_hit_count: usize, } impl DedupLayer { - pub fn new(dht_requester: DhtRequester) -> Self { - Self { dht_requester } + pub fn new(dht_requester: DhtRequester, discard_on_hit_count: usize) -> Self { + Self { + dht_requester, + discard_on_hit_count, + } } } @@ -122,7 +133,7 @@ impl Layer for DedupLayer { type Service = DedupMiddleware; fn layer(&self, service: S) -> Self::Service { - DedupMiddleware::new(service, self.dht_requester.clone()) + DedupMiddleware::new(service, self.dht_requester.clone(), self.discard_on_hit_count) } } @@ -143,10 +154,10 @@ mod test { let (dht_requester, mock) = create_dht_actor_mock(1); let mock_state = mock.get_shared_state(); - mock_state.set_signature_cache_insert(false); + mock_state.set_number_of_message_hits(1); rt.spawn(mock.run()); - let mut dedup = DedupLayer::new(dht_requester).layer(spy.to_service::()); + let mut dedup = DedupLayer::new(dht_requester, 3).layer(spy.to_service::()); panic_context!(cx); @@ -157,7 +168,7 @@ mod test { rt.block_on(dedup.call(msg.clone())).unwrap(); assert_eq!(spy.call_count(), 1); - mock_state.set_signature_cache_insert(true); + mock_state.set_number_of_message_hits(3); rt.block_on(dedup.call(msg)).unwrap(); assert_eq!(spy.call_count(), 1); // Drop dedup so that the DhtMock will stop running diff --git a/comms/dht/src/dht.rs b/comms/dht/src/dht.rs index dcdeea5730..32c778f8ca 100644 --- a/comms/dht/src/dht.rs +++ b/comms/dht/src/dht.rs @@ -26,6 +26,7 @@ use crate::{ connectivity::{DhtConnectivity, MetricsCollector, MetricsCollectorHandle}, discovery::{DhtDiscoveryRequest, DhtDiscoveryRequester, DhtDiscoveryService}, event::{DhtEventReceiver, DhtEventSender}, + filter, inbound, inbound::{DecryptedDhtMessage, DhtInboundMessage, MetricsLayer}, logging_middleware::MessageLoggingLayer, @@ -37,12 +38,11 @@ use crate::{ storage::{DbConnection, StorageError}, store_forward, store_forward::{StoreAndForwardError, StoreAndForwardRequest, StoreAndForwardRequester, StoreAndForwardService}, - tower_filter, DedupLayer, DhtActorError, DhtConfig, }; -use futures::{channel::mpsc, future, Future}; +use futures::{channel::mpsc, Future}; use log::*; use std::sync::Arc; use tari_comms::{ @@ -285,13 +285,14 @@ impl Dht { S: Service + Clone + Send + Sync + 'static, S::Future: Send, { - // FIXME: There is an unresolved stack overflow issue on windows in debug mode during runtime, but not in - // release mode, related to the amount of layers. 
diff --git a/comms/dht/src/tower_filter/future.rs b/comms/dht/src/filter/future.rs
similarity index 66%
rename from comms/dht/src/tower_filter/future.rs
rename to comms/dht/src/filter/future.rs
index 78b2c613e6..4559aeaadf 100644
--- a/comms/dht/src/tower_filter/future.rs
+++ b/comms/dht/src/filter/future.rs
@@ -13,16 +13,15 @@ use tower::Service;
 /// Filtered response future
 #[pin_project]
 #[derive(Debug)]
-pub struct ResponseFuture<T, S, Request>
+pub struct ResponseFuture<S, Request>
 where S: Service<Request>
 {
     #[pin]
     /// Response future state
     state: State<Request, S::Future>,
 
-    #[pin]
-    /// Predicate future
-    check: T,
+    /// Predicate result
+    check: bool,
 
     /// Inner service
     service: S,
@@ -35,12 +34,10 @@ enum State<Request, U> {
     WaitResponse(#[pin] U),
 }
 
-impl<F, S, Request> ResponseFuture<F, S, Request>
-where
-    F: Future<Output = Result<(), PipelineError>>,
-    S: Service<Request>,
+impl<S, Request> ResponseFuture<S, Request>
+where S: Service<Request>
 {
-    pub(crate) fn new(request: Request, check: F, service: S) -> Self {
+    pub(crate) fn new(request: Request, check: bool, service: S) -> Self {
         ResponseFuture {
             state: State::Check(Some(request)),
             check,
@@ -49,10 +46,8 @@ where S: Service<Request>
         }
     }
 }
 
-impl<F, S, Request> Future for ResponseFuture<F, S, Request>
-where
-    F: Future<Output = Result<(), PipelineError>>,
-    S: Service<Request, Error = PipelineError>,
+impl<S, Request> Future for ResponseFuture<S, Request>
+where S: Service<Request, Response = (), Error = PipelineError>
 {
     type Output = Result<S::Response, PipelineError>;
 
     fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
@@ -66,15 +61,13 @@ where S: Service<Request, Response = (), Error = PipelineError>
                         .take()
                         .expect("we either give it back or leave State::Check once we take");
 
-                    // Poll predicate
-                    match this.check.as_mut().poll(cx)? {
-                        Poll::Ready(_) => {
+                    match this.check {
+                        true => {
                             let response = this.service.call(request);
                             this.state.set(State::WaitResponse(response));
                         },
-                        Poll::Pending => {
-                            this.state.set(State::Check(Some(request)));
-                            return Poll::Pending;
+                        false => {
+                            return Poll::Ready(Ok(()));
                         },
                     }
                 },
diff --git a/comms/dht/src/tower_filter/layer.rs b/comms/dht/src/filter/layer.rs
similarity index 100%
rename from comms/dht/src/tower_filter/layer.rs
rename to comms/dht/src/filter/layer.rs
diff --git a/comms/dht/src/tower_filter/mod.rs b/comms/dht/src/filter/mod.rs
similarity index 92%
rename from comms/dht/src/tower_filter/mod.rs
rename to comms/dht/src/filter/mod.rs
index d1df2f27a7..e7f168161b 100644
--- a/comms/dht/src/tower_filter/mod.rs
+++ b/comms/dht/src/filter/mod.rs
@@ -33,11 +33,11 @@ impl<T, U> Filter<T, U> {
 
 impl<T, U, Request> Service<Request> for Filter<T, U>
 where
-    T: Service<Request> + Clone,
+    T: Service<Request, Response = ()> + Clone,
     U: Predicate<Request>,
 {
     type Error = PipelineError;
-    type Future = ResponseFuture<U::Future, T, Request>;
+    type Future = ResponseFuture<T, Request>;
     type Response = T::Response;
 
     fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
diff --git a/comms/dht/src/filter/predicate.rs b/comms/dht/src/filter/predicate.rs
new file mode 100644
index 0000000000..024dee826d
--- /dev/null
+++ b/comms/dht/src/filter/predicate.rs
@@ -0,0 +1,13 @@
+/// Checks a request
+pub trait Predicate<Request> {
+    /// Check whether the given request should be forwarded.
+    fn check(&mut self, request: &Request) -> bool;
+}
+
+impl<F, T> Predicate<T> for F
+where F: Fn(&T) -> bool
+{
+    fn check(&mut self, request: &T) -> bool {
+        self(request)
+    }
+}
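Because the predicate is now synchronous, any Fn(&Request) -> bool closure can be used as a filter directly, with no futures involved. A self-contained usage sketch of the same trait shape:

// Self-contained sketch: the synchronous Predicate trait plus its blanket impl.
pub trait Predicate<Request> {
    fn check(&mut self, request: &Request) -> bool;
}

impl<F, T> Predicate<T> for F
where F: Fn(&T) -> bool
{
    fn check(&mut self, request: &T) -> bool {
        self(request)
    }
}

fn main() {
    // Any closure over &Request is a predicate
    let mut allow_short = |msg: &String| msg.len() <= 5;
    assert!(allow_short.check(&"hi".to_string()));
    assert!(!allow_short.check(&"too long".to_string()));
}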
diff --git a/comms/dht/src/inbound/dht_handler/task.rs b/comms/dht/src/inbound/dht_handler/task.rs
index f45507a905..99ddddef3a 100644
--- a/comms/dht/src/inbound/dht_handler/task.rs
+++ b/comms/dht/src/inbound/dht_handler/task.rs
@@ -88,6 +88,21 @@ where S: Service<DecryptedDhtMessage, Response = (), Error = PipelineError>
             return Ok(());
         }
 
+        if message.is_duplicate() {
+            debug!(
+                target: LOG_TARGET,
+                "Received message ({}) that has already been received {} time(s) from peer '{}', passing on (Trace: \
+                 {}, forwarded = {:?})",
+                message.tag,
+                message.dedup_hit_count,
+                message.source_peer.node_id.short_str(),
+                message.dht_header.message_tag,
+                message.is_already_forwarded
+            );
+            self.next_service.oneshot(message).await?;
+            return Ok(());
+        }
+
         trace!(
             target: LOG_TARGET,
             "Received DHT message type `{}` (Source peer: {}, Tag: {}, Trace: {})",
diff --git a/comms/dht/src/inbound/message.rs b/comms/dht/src/inbound/message.rs
index a49ae4b073..c9cdd103fd 100644
--- a/comms/dht/src/inbound/message.rs
+++ b/comms/dht/src/inbound/message.rs
@@ -43,6 +43,7 @@ pub struct DhtInboundMessage {
     pub dht_header: DhtMessageHeader,
     /// True if forwarded via store and forward, otherwise false
     pub is_saf_message: bool,
+    pub dedup_hit_count: u32,
     pub body: Vec<u8>,
 }
 
 impl DhtInboundMessage {
@@ -53,6 +54,7 @@ impl DhtInboundMessage {
             dht_header,
             source_peer,
             is_saf_message: false,
+            dedup_hit_count: 0,
             body,
         }
     }
@@ -62,11 +64,12 @@ impl Display for DhtInboundMessage {
     fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), fmt::Error> {
         write!(
             f,
-            "\n---- Inbound Message ---- \nSize: {} byte(s)\nType: {}\nPeer: {}\nHeader: {}\n{}\n----",
+            "\n---- Inbound Message ---- \nSize: {} byte(s)\nType: {}\nPeer: {}\nHit Count: {}\nHeader: {}\n{}\n----",
             self.body.len(),
             self.dht_header.message_type,
             self.source_peer,
+            self.dedup_hit_count,
             self.dht_header,
             self.tag,
         )
     }
@@ -86,6 +89,14 @@ pub struct DecryptedDhtMessage {
     pub is_saf_stored: Option<bool>,
     pub is_already_forwarded: bool,
     pub decryption_result: Result<EnvelopeBody, Vec<u8>>,
+    pub dedup_hit_count: u32,
+}
+
+impl DecryptedDhtMessage {
+    /// Returns true if this message has been received before, otherwise false if this is the first time
+    pub fn is_duplicate(&self) -> bool {
+        self.dedup_hit_count > 1
+    }
 }
 
 impl DecryptedDhtMessage {
@@ -104,6 +115,7 @@ impl DecryptedDhtMessage {
             is_saf_stored: None,
             is_already_forwarded: false,
             decryption_result: Ok(message_body),
+            dedup_hit_count: message.dedup_hit_count,
         }
     }
 
@@ -118,6 +130,7 @@ impl DecryptedDhtMessage {
             is_saf_stored: None,
             is_already_forwarded: false,
             decryption_result: Err(message.body),
+            dedup_hit_count: message.dedup_hit_count,
         }
     }
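Since the hit count is inclusive, the duplicate check must be strictly greater than one: a count of exactly 1 is the first sighting. The off-by-one, pinned down in a sketch:

// The inclusive hit count makes "duplicate" a strict > 1 test.
struct Message {
    dedup_hit_count: u32,
}

impl Message {
    fn is_duplicate(&self) -> bool {
        self.dedup_hit_count > 1
    }
}

fn main() {
    assert!(!Message { dedup_hit_count: 1 }.is_duplicate()); // first sighting
    assert!(Message { dedup_hit_count: 2 }.is_duplicate()); // seen before
}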
diff --git a/comms/dht/src/lib.rs b/comms/dht/src/lib.rs
index cab2f8ab6f..7ef238d16e 100644
--- a/comms/dht/src/lib.rs
+++ b/comms/dht/src/lib.rs
@@ -153,11 +153,11 @@ pub use storage::DbConnectionUrl;
 mod dedup;
 pub use dedup::DedupLayer;
 
+mod filter;
 mod logging_middleware;
 mod proto;
 mod rpc;
 mod schema;
-mod tower_filter;
 
 pub mod broadcast_strategy;
 pub mod domain_message;
diff --git a/comms/dht/src/outbound/broadcast.rs b/comms/dht/src/outbound/broadcast.rs
index 0aa9fab611..4921c3cddb 100644
--- a/comms/dht/src/outbound/broadcast.rs
+++ b/comms/dht/src/outbound/broadcast.rs
@@ -448,7 +448,7 @@ where S: Service<DhtOutboundMessage, Response = (), Error = PipelineError>
         Ok(messages.unzip())
     }
 
-    async fn add_to_dedup_cache(&mut self, body: &[u8], public_key: CommsPublicKey) -> Result<bool, DhtOutboundError> {
+    async fn add_to_dedup_cache(&mut self, body: &[u8], public_key: CommsPublicKey) -> Result<(), DhtOutboundError> {
         let hash = Challenge::new().chain(&body).finalize().to_vec();
         trace!(
             target: LOG_TARGET,
@@ -459,7 +459,8 @@ where S: Service<DhtOutboundMessage, Response = (), Error = PipelineError>
         self.dht_requester
             .insert_message_hash(hash, public_key)
             .await
-            .map_err(|_| DhtOutboundError::FailedToInsertMessageHash)
+            .map_err(|_| DhtOutboundError::FailedToInsertMessageHash)?;
+        Ok(())
     }
 
     fn process_encryption(
diff --git a/comms/dht/src/storage/error.rs b/comms/dht/src/storage/error.rs
index ab9f52f78d..f5bf4f0596 100644
--- a/comms/dht/src/storage/error.rs
+++ b/comms/dht/src/storage/error.rs
@@ -40,4 +40,6 @@ pub enum StorageError {
     ResultError(#[from] diesel::result::Error),
     #[error("MessageFormatError: {0}")]
     MessageFormatError(#[from] MessageFormatError),
+    #[error("Unexpected result: {0}")]
+    UnexpectedResult(String),
 }
diff --git a/comms/dht/src/store_forward/forward.rs b/comms/dht/src/store_forward/forward.rs
index 95ce5e2500..013917e6e4 100644
--- a/comms/dht/src/store_forward/forward.rs
+++ b/comms/dht/src/store_forward/forward.rs
@@ -153,7 +153,7 @@ where S: Service<DecryptedDhtMessage, Response = (), Error = PipelineError>
             self.forward(&message).await?;
         }
 
-        // The message has been forwarded, but other middleware may be interested (i.e. StoreMiddleware)
+        // The message has been forwarded, but downstream middleware may be interested
         trace!(
             target: LOG_TARGET,
             "Passing message {} to next service (Trace: {})",
@@ -205,8 +205,9 @@ where S: Service<DecryptedDhtMessage, Response = (), Error = PipelineError>
         }
 
         let body = decryption_result
-            .clone()
+            .as_ref()
             .err()
+            .cloned()
             .expect("previous check that decryption failed");
 
         let excluded_peers = vec![source_peer.node_id.clone()];
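The forward.rs change from .clone().err() to .as_ref().err().cloned() is a small efficiency fix: the old form cloned the entire Result, including an Ok payload that was immediately discarded, while the new form clones only the Err bytes. Both yield the same value:

// Same Option<Vec<u8>> either way; only the amount of cloning differs.
fn main() {
    let decryption_result: Result<Vec<u8>, Vec<u8>> = Err(vec![1, 2, 3]);

    // Old: clones the whole Result, then discards the Ok side
    let old: Option<Vec<u8>> = decryption_result.clone().err();
    // New: borrows, selects the Err side, and clones only that
    let new: Option<Vec<u8>> = decryption_result.as_ref().err().cloned();

    assert_eq!(old, new);
}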
diff --git a/comms/dht/src/store_forward/saf_handler/task.rs b/comms/dht/src/store_forward/saf_handler/task.rs
index f3ba852118..fc6fb04bce 100644
--- a/comms/dht/src/store_forward/saf_handler/task.rs
+++ b/comms/dht/src/store_forward/saf_handler/task.rs
@@ -103,6 +103,21 @@ where S: Service<DecryptedDhtMessage, Response = (), Error = PipelineError>
             .take()
             .expect("DhtInboundMessageTask initialized without message");
 
+        if message.is_duplicate() {
+            debug!(
+                target: LOG_TARGET,
+                "Received message ({}) that has already been received {} time(s) from peer '{}', passing on (Trace: \
+                 {}, forwarded = {:?})",
+                message.tag,
+                message.dedup_hit_count,
+                message.source_peer.node_id.short_str(),
+                message.dht_header.message_tag,
+                message.is_already_forwarded
+            );
+            self.next_service.oneshot(message).await?;
+            return Ok(());
+        }
+
         if message.dht_header.message_type.is_saf_message() && message.decryption_failed() {
             debug!(
                 target: LOG_TARGET,
@@ -460,7 +475,8 @@ where S: Service<DecryptedDhtMessage, Response = (), Error = PipelineError>
         public_key: CommsPublicKey,
     ) -> Result<(), StoreAndForwardError> {
         let msg_hash = Challenge::new().chain(body).finalize().to_vec();
-        if dht_requester.insert_message_hash(msg_hash, public_key).await? {
+        let hit_count = dht_requester.insert_message_hash(msg_hash, public_key).await?;
+        if hit_count > 1 {
             Err(StoreAndForwardError::DuplicateMessage)
         } else {
             Ok(())
diff --git a/comms/dht/src/store_forward/store.rs b/comms/dht/src/store_forward/store.rs
index 4393f36518..c23133f23d 100644
--- a/comms/dht/src/store_forward/store.rs
+++ b/comms/dht/src/store_forward/store.rs
@@ -122,16 +122,31 @@ where
     }
 
     fn call(&mut self, msg: DecryptedDhtMessage) -> Self::Future {
-        Box::pin(
-            StoreTask::new(
-                self.next_service.clone(),
-                self.config.clone(),
-                Arc::clone(&self.peer_manager),
-                Arc::clone(&self.node_identity),
-                self.saf_requester.clone(),
+        if msg.is_duplicate() {
+            trace!(
+                target: LOG_TARGET,
+                "Passing duplicate message {} to next service (Trace: {})",
+                msg.tag,
+                msg.dht_header.message_tag
+            );
+
+            let service = self.next_service.clone();
+            Box::pin(async move {
+                let service = service.ready_oneshot().await?;
+                service.oneshot(msg).await
+            })
+        } else {
+            Box::pin(
+                StoreTask::new(
+                    self.next_service.clone(),
+                    self.config.clone(),
+                    Arc::clone(&self.peer_manager),
+                    Arc::clone(&self.node_identity),
+                    self.saf_requester.clone(),
+                )
+                .handle(msg),
             )
-            .handle(msg),
-        )
+        }
     }
 }
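Both store-and-forward entry points now short-circuit on duplicates: the store middleware passes them straight to the next service without spawning a StoreTask, and the SAF handler refuses to re-store a body once its hit count exceeds one. That shared decision, in isolation (sketch):

// Sketch: whether a message should be considered for SAF storage,
// mirroring the hit_count > 1 checks above (illustrative types).
fn consider_for_store(dedup_hit_count: u32) -> Result<(), &'static str> {
    if dedup_hit_count > 1 {
        Err("DuplicateMessage")
    } else {
        Ok(())
    }
}

fn main() {
    assert_eq!(consider_for_store(1), Ok(()));
    assert_eq!(consider_for_store(2), Err("DuplicateMessage"));
}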
diff --git a/comms/dht/src/test_utils/dht_actor_mock.rs b/comms/dht/src/test_utils/dht_actor_mock.rs
index ccc53c5a1e..c3e4edc648 100644
--- a/comms/dht/src/test_utils/dht_actor_mock.rs
+++ b/comms/dht/src/test_utils/dht_actor_mock.rs
@@ -29,7 +29,7 @@ use futures::{channel::mpsc, stream::Fuse, StreamExt};
 use std::{
     collections::HashMap,
     sync::{
-        atomic::{AtomicBool, AtomicUsize, Ordering},
+        atomic::{AtomicUsize, Ordering},
         Arc,
         RwLock,
     },
@@ -44,7 +44,7 @@ pub fn create_dht_actor_mock(buf_size: usize) -> (DhtRequester, DhtActorMock) {
 
 #[derive(Default, Debug, Clone)]
 pub struct DhtMockState {
-    signature_cache_insert: Arc<AtomicBool>,
+    signature_cache_insert: Arc<AtomicUsize>,
     call_count: Arc<AtomicUsize>,
     select_peers: Arc<RwLock<Vec<Peer>>>,
     settings: Arc<RwLock<HashMap<String, Vec<u8>>>>,
@@ -52,13 +52,8 @@ impl DhtMockState {
     pub fn new() -> Self {
-        Self {
-            signature_cache_insert: Arc::new(AtomicBool::new(false)),
-            call_count: Arc::new(AtomicUsize::new(0)),
-            select_peers: Arc::new(RwLock::new(Vec::new())),
-            settings: Arc::new(RwLock::new(HashMap::new())),
-        }
+        Default::default()
     }
 
-    pub fn set_signature_cache_insert(&self, v: bool) -> &Self {
-        self.signature_cache_insert.store(v, Ordering::SeqCst);
+    pub fn set_number_of_message_hits(&self, v: u32) -> &Self {
+        self.signature_cache_insert.store(v as usize, Ordering::SeqCst);
         self
     }
@@ -113,7 +108,7 @@ impl DhtActorMock {
             SendJoin => {},
             MsgHashCacheInsert(_, _, reply_tx) => {
                 let v = self.state.signature_cache_insert.load(Ordering::SeqCst);
-                reply_tx.send(v).unwrap();
+                reply_tx.send(v as u32).unwrap();
             },
             SelectPeers(_, reply_tx) => {
                 let lock = self.state.select_peers.read().unwrap();
diff --git a/comms/dht/src/tower_filter/predicate.rs b/comms/dht/src/tower_filter/predicate.rs
deleted file mode 100644
index f86b9cc406..0000000000
--- a/comms/dht/src/tower_filter/predicate.rs
+++ /dev/null
@@ -1,25 +0,0 @@
-use std::future::Future;
-use tari_comms::pipeline::PipelineError;
-
-/// Checks a request
-pub trait Predicate<Request> {
-    /// The future returned by `check`.
-    type Future: Future<Output = Result<(), PipelineError>>;
-
-    /// Check whether the given request should be forwarded.
-    ///
-    /// If the future resolves with `Ok`, the request is forwarded to the inner service.
-    fn check(&mut self, request: &Request) -> Self::Future;
-}
-
-impl<F, T, U> Predicate<T> for F
-where
-    F: Fn(&T) -> U,
-    U: Future<Output = Result<(), PipelineError>>,
-{
-    type Future = U;
-
-    fn check(&mut self, request: &T) -> Self::Future {
-        self(request)
-    }
-}
diff --git a/comms/dht/tests/dht.rs b/comms/dht/tests/dht.rs
index a5aed09970..d917b8e192 100644
--- a/comms/dht/tests/dht.rs
+++ b/comms/dht/tests/dht.rs
@@ -113,19 +113,22 @@ fn create_peer_storage() -> CommsDatabase {
     LMDBWrapper::new(Arc::new(peer_database))
 }
 
-async fn make_node(features: PeerFeatures, seed_peer: Option<Peer>) -> TestNode {
+async fn make_node<I: IntoIterator<Item = Peer>>(features: PeerFeatures, known_peers: I) -> TestNode {
     let node_identity = make_node_identity(features);
-    make_node_with_node_identity(node_identity, seed_peer).await
+    make_node_with_node_identity(node_identity, known_peers).await
 }
 
-async fn make_node_with_node_identity(node_identity: Arc<NodeIdentity>, seed_peer: Option<Peer>) -> TestNode {
+async fn make_node_with_node_identity<I: IntoIterator<Item = Peer>>(
+    node_identity: Arc<NodeIdentity>,
+    known_peers: I,
+) -> TestNode {
     let (tx, inbound_messages) = mpsc::channel(10);
     let shutdown = Shutdown::new();
     let (comms, dht, messaging_events) = setup_comms_dht(
         node_identity,
         create_peer_storage(),
         tx,
-        seed_peer.into_iter().collect(),
+        known_peers.into_iter().collect(),
         shutdown.to_signal(),
     )
     .await;
@@ -173,6 +176,7 @@ async fn setup_comms_dht(
         .with_database_url(DbConnectionUrl::MemoryShared(random::string(8)))
         .with_discovery_timeout(Duration::from_secs(60))
         .with_num_neighbouring_nodes(8)
+        .with_dedup_discard_hit_count(3)
         .build()
         .await
         .unwrap();
@@ -482,8 +486,7 @@ async fn dht_propagate_dedup() {
         .dht
         .outbound_requester()
        .propagate(
-            // Node D is a client node, so an destination is required for domain messages
-            NodeDestination::Unknown, // NodeId(Box::new(node_D.node_identity().node_id().clone())),
+            NodeDestination::Unknown,
             OutboundEncryption::EncryptFor(Box::new(node_D.node_identity().public_key().clone())),
             vec![],
             out_msg,
@@ -496,6 +499,7 @@ async fn dht_propagate_dedup() {
         .await
         .expect("Node D expected an inbound message but it never arrived");
     assert!(msg.decryption_succeeded());
+    log::info!("Received message {}", msg.tag);
     let person = msg
         .decryption_result
         .unwrap()
@@ -536,6 +540,85 @@ async fn dht_propagate_dedup() {
     assert_eq!(count_messages_received(&received, &[&node_C_id]), 1);
 }
 
+#[tokio_macros::test]
+#[allow(non_snake_case)]
+async fn dht_repropagate() {
+    // Node C knows no one
+    let mut node_C = make_node(PeerFeatures::COMMUNICATION_NODE, []).await;
+    // Node B knows about Node C
+    let mut node_B = make_node(PeerFeatures::COMMUNICATION_NODE, [node_C.to_peer()]).await;
+    // Node A knows about Node B and C
+    let mut node_A = make_node(PeerFeatures::COMMUNICATION_NODE, [node_B.to_peer(), node_C.to_peer()]).await;
+    node_A.comms.peer_manager().add_peer(node_C.to_peer()).await.unwrap();
+    log::info!(
+        "NodeA = {}, NodeB = {}, Node C = {}",
+        node_A.node_identity().node_id().short_str(),
+        node_B.node_identity().node_id().short_str(),
+        node_C.node_identity().node_id().short_str(),
+    );
+
+    // Connect the peers that should be connected
+    async fn connect_nodes(node1: &mut TestNode, node2: &mut TestNode) {
+        node1
+            .comms
+            .connectivity()
+            .dial_peer(node2.node_identity().node_id().clone())
+            .await
+            .unwrap();
+    }
+    // Pre-connect the nodes; this makes message passing more deterministic
+    connect_nodes(&mut node_A, &mut node_B).await;
+    connect_nodes(&mut node_B, &mut node_C).await;
+    connect_nodes(&mut node_C, &mut node_A).await;
+
+    #[derive(Clone, PartialEq, ::prost::Message)]
+    struct Person {
+        #[prost(string, tag = "1")]
+        name: String,
+        #[prost(uint32, tag = "2")]
+        age: u32,
+    }
+
+    let out_msg = OutboundDomainMessage::new(123, Person {
+        name: "John Conway".into(),
+        age: 82,
+    });
+    node_A
+        .dht
+        .outbound_requester()
+        .propagate(
+            NodeDestination::Unknown,
+            OutboundEncryption::ClearText,
+            vec![],
+            out_msg.clone(),
+        )
+        .await
+        .unwrap();
+
+    let msg = node_B
+        .next_inbound_message(Duration::from_secs(10))
+        .await
+        .expect("Node expected an inbound message but it never arrived");
+    log::info!("Received message {}", msg.tag);
+
+    node_C
+        .dht
+        .outbound_requester()
+        .propagate(NodeDestination::Unknown, OutboundEncryption::ClearText, vec![], out_msg)
+        .await
+        .unwrap();
+
+    let msg = node_A
+        .next_inbound_message(Duration::from_secs(10))
+        .await
+        .expect("Node expected an inbound message but it never arrived");
+    log::info!("Received duplicate message {}", msg.tag);
+
+    node_A.shutdown().await;
+    node_B.shutdown().await;
+    node_C.shutdown().await;
+}
+
 #[tokio_macros::test]
 #[allow(non_snake_case)]
 async fn dht_propagate_message_contents_not_malleable_ban() {
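A final API note: the test helpers now take any IntoIterator of peers instead of Option<Peer>, which is what lets dht_repropagate seed zero, one, or two known peers with plain array literals. The pattern in miniature (hypothetical Peer type):

// Sketch: Option<Peer> generalised to IntoIterator<Item = Peer>.
struct Peer(&'static str);

fn make_node<I: IntoIterator<Item = Peer>>(known_peers: I) -> Vec<Peer> {
    known_peers.into_iter().collect()
}

fn main() {
    assert_eq!(make_node([]).len(), 0); // previously: None
    assert_eq!(make_node([Peer("b")]).len(), 1); // previously: Some(peer)
    assert_eq!(make_node([Peer("b"), Peer("c")]).len(), 2); // previously not expressible
}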