From 3602573adb81bc6f2be980010abe6c39217b04bb Mon Sep 17 00:00:00 2001 From: Stanimal Date: Thu, 12 Aug 2021 10:56:48 +0400 Subject: [PATCH] feat(dht): allow messages to be repropagated for a number of rounds (gossip) Use the dedup cache hit count to allow certain duplicate messages through a configurable number of times. This improves mempool synchronization. --- comms/dht/examples/memorynet.rs | 4 +- comms/dht/src/actor.rs | 153 ++++++----- comms/dht/src/builder.rs | 5 + comms/dht/src/config.rs | 5 + comms/dht/src/dedup/dedup_cache.rs | 114 ++++---- comms/dht/src/dedup/mod.rs | 39 ++- comms/dht/src/dht.rs | 60 +++-- comms/dht/src/domain_message.rs | 2 +- comms/dht/src/envelope.rs | 6 +- .../src/{tower_filter => filter}/future.rs | 31 +-- .../dht/src/{tower_filter => filter}/layer.rs | 0 comms/dht/src/{tower_filter => filter}/mod.rs | 4 +- comms/dht/src/filter/predicate.rs | 13 + comms/dht/src/inbound/dht_handler/task.rs | 14 + comms/dht/src/inbound/message.rs | 15 +- comms/dht/src/lib.rs | 2 +- comms/dht/src/outbound/broadcast.rs | 22 +- comms/dht/src/outbound/error.rs | 4 +- comms/dht/src/outbound/message_params.rs | 10 +- comms/dht/src/storage/error.rs | 2 + comms/dht/src/store_forward/forward.rs | 5 +- .../dht/src/store_forward/saf_handler/task.rs | 17 +- comms/dht/src/store_forward/store.rs | 33 ++- comms/dht/src/test_utils/dht_actor_mock.rs | 23 +- comms/dht/src/tower_filter/predicate.rs | 25 -- comms/dht/tests/dht.rs | 248 ++++++++++++++++-- 26 files changed, 608 insertions(+), 248 deletions(-) rename comms/dht/src/{tower_filter => filter}/future.rs (66%) rename comms/dht/src/{tower_filter => filter}/layer.rs (100%) rename comms/dht/src/{tower_filter => filter}/mod.rs (92%) create mode 100644 comms/dht/src/filter/predicate.rs delete mode 100644 comms/dht/src/tower_filter/predicate.rs diff --git a/comms/dht/examples/memorynet.rs b/comms/dht/examples/memorynet.rs index 9cc28551bc..c09da9d9a7 100644 --- a/comms/dht/examples/memorynet.rs +++ b/comms/dht/examples/memorynet.rs @@ -55,9 +55,9 @@ use std::{iter::repeat_with, time::Duration}; use tari_comms::peer_manager::PeerFeatures; // Size of network -const NUM_NODES: usize = 6; +const NUM_NODES: usize = 30; // Must be at least 2 -const NUM_WALLETS: usize = 50; +const NUM_WALLETS: usize = 5; const QUIET_MODE: bool = true; /// Number of neighbouring nodes each node should include in the connection pool const NUM_NEIGHBOURING_NODES: usize = 8; diff --git a/comms/dht/src/actor.rs b/comms/dht/src/actor.rs index c2c2d4e52a..326d465458 100644 --- a/comms/dht/src/actor.rs +++ b/comms/dht/src/actor.rs @@ -51,6 +51,7 @@ use tari_comms::{ peer_manager::{NodeId, NodeIdentity, PeerFeatures, PeerManager, PeerManagerError, PeerQuery, PeerQuerySortBy}, types::CommsPublicKey, }; +use tari_crypto::tari_utilities::hex::Hex; use tari_shutdown::ShutdownSignal; use tari_utilities::message_format::{MessageFormat, MessageFormatError}; use thiserror::Error; @@ -101,9 +102,14 @@ impl From for DhtActorError { pub enum DhtRequest { /// Send a Join request to the network SendJoin, - /// Inserts a message signature to the msg hash cache. This operation replies with a boolean - /// which is true if the signature already exists in the cache, otherwise false - MsgHashCacheInsert(Vec, CommsPublicKey, oneshot::Sender), + /// Inserts a message signature to the msg hash cache. 
This operation replies with the number of times this message
+    /// has previously been seen (hit count)
+    MsgHashCacheInsert {
+        message_hash: Vec<u8>,
+        received_from: CommsPublicKey,
+        reply_tx: oneshot::Sender<u32>,
+    },
+    GetMsgHashHitCount(Vec<u8>, oneshot::Sender<u32>),
     /// Fetch selected peers according to the broadcast strategy
     SelectPeers(BroadcastStrategy, oneshot::Sender<Vec<Peer>>),
     GetMetadata(DhtMetadataKey, oneshot::Sender<Result<Option<Vec<u8>>, DhtActorError>>),
@@ -114,12 +120,22 @@ impl Display for DhtRequest {
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
         use DhtRequest::*;
         match self {
-            SendJoin => f.write_str("SendJoin"),
-            MsgHashCacheInsert(_, _, _) => f.write_str("MsgHashCacheInsert"),
-            SelectPeers(s, _) => f.write_str(&format!("SelectPeers (Strategy={})", s)),
-            GetMetadata(key, _) => f.write_str(&format!("GetMetadata (key={})", key)),
+            SendJoin => write!(f, "SendJoin"),
+            MsgHashCacheInsert {
+                message_hash,
+                received_from,
+                ..
+            } => write!(
+                f,
+                "MsgHashCacheInsert(message hash: {}, received from: {})",
+                message_hash.to_hex(),
+                received_from.to_hex(),
+            ),
+            GetMsgHashHitCount(hash, _) => write!(f, "GetMsgHashHitCount({})", hash.to_hex()),
+            SelectPeers(s, _) => write!(f, "SelectPeers (Strategy={})", s),
+            GetMetadata(key, _) => write!(f, "GetMetadata (key={})", key),
             SetMetadata(key, value, _) => {
-                f.write_str(&format!("SetMetadata (key={}, value={} bytes)", key, value.len()))
+                write!(f, "SetMetadata (key={}, value={} bytes)", key, value.len())
             },
         }
     }
@@ -147,14 +163,27 @@ impl DhtRequester {
         reply_rx.await.map_err(|_| DhtActorError::ReplyCanceled)
     }
 
-    pub async fn insert_message_hash(
+    pub async fn add_message_to_dedup_cache(
         &mut self,
         message_hash: Vec<u8>,
-        public_key: CommsPublicKey,
-    ) -> Result<bool, DhtActorError> {
+        received_from: CommsPublicKey,
+    ) -> Result<u32, DhtActorError> {
+        let (reply_tx, reply_rx) = oneshot::channel();
+        self.sender
+            .send(DhtRequest::MsgHashCacheInsert {
+                message_hash,
+                received_from,
+                reply_tx,
+            })
+            .await?;
+
+        reply_rx.await.map_err(|_| DhtActorError::ReplyCanceled)
+    }
+
+    pub async fn get_message_cache_hit_count(&mut self, message_hash: Vec<u8>) -> Result<u32, DhtActorError> {
         let (reply_tx, reply_rx) = oneshot::channel();
         self.sender
-            .send(DhtRequest::MsgHashCacheInsert(message_hash, public_key, reply_tx))
+            .send(DhtRequest::GetMsgHashHitCount(message_hash, reply_tx))
             .await?;
 
         reply_rx.await.map_err(|_| DhtActorError::ReplyCanceled)
     }
@@ -268,7 +297,7 @@ impl DhtActor {
             },
 
             _ = dedup_cache_trim_ticker.select_next_some() => {
-                if let Err(err) = self.msg_hash_dedup_cache.truncate().await {
+                if let Err(err) = self.msg_hash_dedup_cache.trim_entries().await {
                     error!(target: LOG_TARGET, "Error when trimming message dedup cache: {:?}", err);
                 }
             },
@@ -300,24 +329,36 @@ impl DhtActor {
                 let outbound_requester = self.outbound_requester.clone();
                 Box::pin(Self::broadcast_join(node_identity, outbound_requester))
             },
-            MsgHashCacheInsert(hash, public_key, reply_tx) => {
+            MsgHashCacheInsert {
+                message_hash,
+                received_from,
+                reply_tx,
+            } => {
                let msg_hash_cache = self.msg_hash_dedup_cache.clone();
                Box::pin(async move {
-                    match msg_hash_cache.insert_body_hash_if_unique(hash, public_key).await {
-                        Ok(already_exists) => {
-                            let _ = reply_tx.send(already_exists).map_err(|_| DhtActorError::ReplyCanceled);
+                    match msg_hash_cache.add_body_hash(message_hash, received_from).await {
+                        Ok(hit_count) => {
+                            let _ = reply_tx.send(hit_count);
                        },
                        Err(err) => {
                            warn!(
                                target: LOG_TARGET,
                                "Unable to update message dedup cache because {:?}", err
                            );
-                            let _ = reply_tx.send(false).map_err(|_| DhtActorError::ReplyCanceled);
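+                            // Review note (assumption, not in the original patch): replying with 0
+                            // on a failed cache insert means the message is treated as never seen,
+                            // so a storage error fails open (the message is processed) rather than
+                            // discarding it.
+                            let _ = 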
reply_tx.send(0); }, } Ok(()) }) }, + GetMsgHashHitCount(hash, reply_tx) => { + let msg_hash_cache = self.msg_hash_dedup_cache.clone(); + Box::pin(async move { + let hit_count = msg_hash_cache.get_hit_count(hash).await?; + let _ = reply_tx.send(hit_count); + Ok(()) + }) + }, SelectPeers(broadcast_strategy, reply_tx) => { let peer_manager = Arc::clone(&self.peer_manager); let node_identity = Arc::clone(&self.node_identity); @@ -690,11 +731,9 @@ mod test { test_utils::{build_peer_manager, make_client_identity, make_node_identity}, }; use chrono::{DateTime, Utc}; - use std::time::Duration; use tari_comms::test_utils::mocks::{create_connectivity_mock, create_peer_connection_mock_pair}; use tari_shutdown::Shutdown; use tari_test_utils::random; - use tokio::time::delay_for; async fn db_connection() -> DbConnection { let conn = DbConnection::connect_memory(random::string(8)).await.unwrap(); @@ -756,21 +795,21 @@ mod test { actor.spawn(); let signature = vec![1u8, 2, 3]; - let is_dup = requester - .insert_message_hash(signature.clone(), CommsPublicKey::default()) + let num_hits = requester + .add_message_to_dedup_cache(signature.clone(), CommsPublicKey::default()) .await .unwrap(); - assert!(!is_dup); - let is_dup = requester - .insert_message_hash(signature, CommsPublicKey::default()) + assert_eq!(num_hits, 1); + let num_hits = requester + .add_message_to_dedup_cache(signature, CommsPublicKey::default()) .await .unwrap(); - assert!(is_dup); - let is_dup = requester - .insert_message_hash(Vec::new(), CommsPublicKey::default()) + assert_eq!(num_hits, 2); + let num_hits = requester + .add_message_to_dedup_cache(Vec::new(), CommsPublicKey::default()) .await .unwrap(); - assert!(!is_dup); + assert_eq!(num_hits, 1); } #[tokio_macros::test_basic] @@ -783,14 +822,12 @@ mod test { let (actor_tx, actor_rx) = mpsc::channel(1); let mut requester = DhtRequester::new(actor_tx); let outbound_requester = OutboundMessageRequester::new(out_tx); - let mut shutdown = Shutdown::new(); - let trim_interval_ms = 500; + let shutdown = Shutdown::new(); // Note: This must be equal or larger than the minimum dedup cache capacity for DedupCacheDatabase - let capacity = 120; + let capacity = 10; let actor = DhtActor::new( DhtConfig { dedup_cache_capacity: capacity, - dedup_cache_trim_interval: Duration::from_millis(trim_interval_ms), ..Default::default() }, db_connection().await, @@ -803,63 +840,61 @@ mod test { ); // Create signatures for double the dedup cache capacity - let mut signatures: Vec> = Vec::new(); - for i in 0..(capacity * 2) { - signatures.push(vec![1u8, 2, i as u8]) - } + let signatures = (0..(capacity * 2)).map(|i| vec![1u8, 2, i as u8]).collect::>(); - // Pre-populate the dedup cache; everything should be accepted due to cleanup ticker not active yet + // Pre-populate the dedup cache; everything should be accepted because the cleanup ticker has not run yet for key in &signatures { - let is_dup = actor + let num_hits = actor .msg_hash_dedup_cache - .insert_body_hash_if_unique(key.clone(), CommsPublicKey::default()) + .add_body_hash(key.clone(), CommsPublicKey::default()) .await .unwrap(); - assert!(!is_dup); + assert_eq!(num_hits, 1); } - // Try to re-insert all; everything should be marked as duplicates due to cleanup ticker not active yet + // Try to re-insert all; all hashes should have incremented their hit count for key in &signatures { - let is_dup = actor + let num_hits = actor .msg_hash_dedup_cache - .insert_body_hash_if_unique(key.clone(), CommsPublicKey::default()) + .add_body_hash(key.clone(), 
CommsPublicKey::default()) .await .unwrap(); - assert!(is_dup); + assert_eq!(num_hits, 2); } - // The cleanup ticker starts when the actor is spawned; the first cleanup event will fire immediately + let dedup_cache_db = actor.msg_hash_dedup_cache.clone(); + // The cleanup ticker starts when the actor is spawned; the first cleanup event will fire fairly soon after the + // task is running on a thread. To remove this race condition, we trim the cache in the test. + dedup_cache_db.trim_entries().await.unwrap(); actor.spawn(); // Verify that the last half of the signatures are still present in the cache for key in signatures.iter().take(capacity * 2).skip(capacity) { - let is_dup = requester - .insert_message_hash(key.clone(), CommsPublicKey::default()) + let num_hits = requester + .add_message_to_dedup_cache(key.clone(), CommsPublicKey::default()) .await .unwrap(); - assert!(is_dup); + assert_eq!(num_hits, 3); } // Verify that the first half of the signatures have been removed and can be re-inserted into cache for key in signatures.iter().take(capacity) { - let is_dup = requester - .insert_message_hash(key.clone(), CommsPublicKey::default()) + let num_hits = requester + .add_message_to_dedup_cache(key.clone(), CommsPublicKey::default()) .await .unwrap(); - assert!(!is_dup); + assert_eq!(num_hits, 1); } - // Let the trim period expire; this will trim the dedup cache to capacity - delay_for(Duration::from_millis(trim_interval_ms * 2)).await; + // Trim the database of excess entries + dedup_cache_db.trim_entries().await.unwrap(); // Verify that the last half of the signatures have been removed and can be re-inserted into cache for key in signatures.iter().take(capacity * 2).skip(capacity) { - let is_dup = requester - .insert_message_hash(key.clone(), CommsPublicKey::default()) + let num_hits = requester + .add_message_to_dedup_cache(key.clone(), CommsPublicKey::default()) .await .unwrap(); - assert!(!is_dup); + assert_eq!(num_hits, 1); } - - shutdown.trigger().unwrap(); } #[tokio_macros::test_basic] diff --git a/comms/dht/src/builder.rs b/comms/dht/src/builder.rs index 249ed3d369..382d28cf4a 100644 --- a/comms/dht/src/builder.rs +++ b/comms/dht/src/builder.rs @@ -99,6 +99,11 @@ impl DhtBuilder { self } + pub fn with_dedup_discard_hit_count(mut self, max_hit_count: usize) -> Self { + self.config.dedup_discard_hit_count = max_hit_count; + self + } + pub fn with_num_random_nodes(mut self, n: usize) -> Self { self.config.num_random_nodes = n; self diff --git a/comms/dht/src/config.rs b/comms/dht/src/config.rs index 0612445dca..09c36dcef0 100644 --- a/comms/dht/src/config.rs +++ b/comms/dht/src/config.rs @@ -72,6 +72,9 @@ pub struct DhtConfig { /// The periodic trim interval for items in the message hash cache /// Default: 300s (5 mins) pub dedup_cache_trim_interval: Duration, + /// The number of hits a message is allowed before being discarded + /// Default: 3 + pub dedup_discard_hit_count: usize, /// The duration to wait for a peer discovery to complete before giving up. 
/// Default: 2 minutes
     pub discovery_request_timeout: Duration,
@@ -136,6 +139,7 @@ impl DhtConfig {
 impl Default for DhtConfig {
     fn default() -> Self {
+        // NB: Update the field doc comments above when changing these defaults
         Self {
             num_neighbouring_nodes: 8,
             num_random_nodes: 4,
@@ -151,6 +155,7 @@ impl Default for DhtConfig {
             saf_max_message_size: 512 * 1024,
             dedup_cache_capacity: 2_500,
             dedup_cache_trim_interval: Duration::from_secs(5 * 60),
+            dedup_discard_hit_count: 3,
             database_url: DbConnectionUrl::Memory,
             discovery_request_timeout: Duration::from_secs(2 * 60),
             connectivity_update_interval: Duration::from_secs(2 * 60),
diff --git a/comms/dht/src/dedup/dedup_cache.rs b/comms/dht/src/dedup/dedup_cache.rs
index f8f5f6fcbf..8364f020a0 100644
--- a/comms/dht/src/dedup/dedup_cache.rs
+++ b/comms/dht/src/dedup/dedup_cache.rs
@@ -24,15 +24,23 @@ use crate::{
     schema::dedup_cache,
     storage::{DbConnection, StorageError},
 };
-use chrono::Utc;
-use diesel::{dsl, result::DatabaseErrorKind, ExpressionMethods, QueryDsl, RunQueryDsl};
+use chrono::{NaiveDateTime, Utc};
+use diesel::{dsl, result::DatabaseErrorKind, ExpressionMethods, OptionalExtension, QueryDsl, RunQueryDsl};
 use log::*;
 use tari_comms::types::CommsPublicKey;
-use tari_crypto::tari_utilities::{hex::Hex, ByteArray};
-use tari_utilities::hex;
+use tari_crypto::tari_utilities::hex::Hex;
 
 const LOG_TARGET: &str = "comms::dht::dedup_cache";
 
+#[derive(Queryable, PartialEq, Debug)]
+struct DedupCacheEntry {
+    body_hash: String,
+    sender_public_key: String,
+    number_of_hits: i32,
+    stored_at: NaiveDateTime,
+    last_hit_at: NaiveDateTime,
+}
+
 #[derive(Clone)]
 pub struct DedupCacheDatabase {
     connection: DbConnection,
@@ -48,36 +56,40 @@ impl DedupCacheDatabase {
         Self { connection, capacity }
     }
 
-    /// Inserts and returns Ok(true) if the item already existed and Ok(false) if it didn't, also updating hit stats
-    pub async fn insert_body_hash_if_unique(
-        &self,
-        body_hash: Vec<u8>,
-        public_key: CommsPublicKey,
-    ) -> Result<bool, StorageError> {
-        let body_hash = hex::to_hex(&body_hash.as_bytes());
-        let public_key = public_key.to_hex();
-        match self
-            .insert_body_hash_or_update_stats(body_hash.clone(), public_key.clone())
-            .await
-        {
-            Ok(val) => {
-                if val == 0 {
-                    warn!(
-                        target: LOG_TARGET,
-                        "Unable to insert new entry into message dedup cache"
-                    );
-                }
-                Ok(false)
-            },
-            Err(e) => match e {
-                StorageError::UniqueViolation(_) => Ok(true),
-                _ => Err(e),
-            },
+    /// Adds the body hash to the cache, returning the number of hits recorded for this body hash, including this
+    /// one
+    pub async fn add_body_hash(&self, body_hash: Vec<u8>, public_key: CommsPublicKey) -> Result<u32, StorageError> {
+        let hit_count = self
+            .insert_body_hash_or_update_stats(body_hash.to_hex(), public_key.to_hex())
+            .await?;
+
+        if hit_count == 0 {
+            warn!(
+                target: LOG_TARGET,
+                "Unable to insert new entry into message dedup cache"
+            );
         }
+        Ok(hit_count)
+    }
+
+    pub async fn get_hit_count(&self, body_hash: Vec<u8>) -> Result<u32, StorageError> {
+        let hit_count = self
+            .connection
+            .with_connection_async(move |conn| {
+                dedup_cache::table
+                    .select(dedup_cache::number_of_hits)
+                    .filter(dedup_cache::body_hash.eq(&body_hash.to_hex()))
+                    .get_result::<i32>(conn)
+                    .optional()
+                    .map_err(Into::into)
+            })
+            .await?;
+
+        Ok(hit_count.unwrap_or(0) as u32)
     }
 
     /// Trims the dedup cache to the configured limit by removing the oldest entries
-    pub async fn truncate(&self) -> Result<usize, StorageError> {
+    pub async fn trim_entries(&self) -> Result<usize, StorageError> {
         let capacity = self.capacity;
         self.connection
             .with_connection_async(move |conn| {
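Note on the second hunk of this file (below): it holds the insert-or-update logic that drives the hit
counter. The production code inserts a row with number_of_hits = 1 and, on a unique-constraint
violation, issues an UPDATE and reads the count back. The same pattern can be sketched standalone; the
minimal illustration below uses rusqlite against an in-memory SQLite database and SQLite's ON CONFLICT
upsert to the same effect. The table definition is a hypothetical stand-in for the dedup_cache schema,
not this crate's actual Diesel schema module:

    use rusqlite::{params, Connection, Result};

    /// Insert the hash, or bump its hit count if it already exists, and return the new count.
    fn add_body_hash(conn: &Connection, body_hash: &str, sender: &str) -> Result<i64> {
        conn.execute(
            "INSERT INTO dedup_cache (body_hash, sender_public_key, number_of_hits, last_hit_at)
             VALUES (?1, ?2, 1, CURRENT_TIMESTAMP)
             ON CONFLICT(body_hash) DO UPDATE SET
                 number_of_hits = number_of_hits + 1,
                 sender_public_key = excluded.sender_public_key,
                 last_hit_at = CURRENT_TIMESTAMP",
            params![body_hash, sender],
        )?;
        // Older SQLite has no RETURNING clause (the TODO in the hunk below notes the same
        // limitation for Diesel + SQLite), so the count is read back with a second query.
        conn.query_row(
            "SELECT number_of_hits FROM dedup_cache WHERE body_hash = ?1",
            params![body_hash],
            |row| row.get(0),
        )
    }

    fn main() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        conn.execute(
            "CREATE TABLE dedup_cache (
                 body_hash TEXT PRIMARY KEY,
                 sender_public_key TEXT NOT NULL,
                 number_of_hits INTEGER NOT NULL,
                 last_hit_at TIMESTAMP NOT NULL
             )",
            [],
        )?;
        assert_eq!(add_body_hash(&conn, "abc123", "pk1")?, 1); // first sighting
        assert_eq!(add_body_hash(&conn, "abc123", "pk2")?, 2); // duplicate bumps the count
        Ok(())
    }

@@ -109,40 +121,46 @@ impl 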
DedupCacheDatabase { .await } - // Insert new row into the table or update existing row in an atomic fashion; more than one thread can access this - // table at the same time. + /// Insert new row into the table or updates an existing row. Returns the number of hits for this body hash. async fn insert_body_hash_or_update_stats( &self, body_hash: String, public_key: String, - ) -> Result { + ) -> Result { self.connection .with_connection_async(move |conn| { let insert_result = diesel::insert_into(dedup_cache::table) .values(( - dedup_cache::body_hash.eq(body_hash.clone()), - dedup_cache::sender_public_key.eq(public_key.clone()), + dedup_cache::body_hash.eq(&body_hash), + dedup_cache::sender_public_key.eq(&public_key), dedup_cache::number_of_hits.eq(1), dedup_cache::last_hit_at.eq(Utc::now().naive_utc()), )) .execute(conn); match insert_result { - Ok(val) => Ok(val), + Ok(1) => Ok(1), + Ok(n) => Err(StorageError::UnexpectedResult(format!( + "Expected exactly one row to be inserted. Got {}", + n + ))), Err(diesel::result::Error::DatabaseError(kind, e_info)) => match kind { DatabaseErrorKind::UniqueViolation => { // Update hit stats for the message - let result = - diesel::update(dedup_cache::table.filter(dedup_cache::body_hash.eq(&body_hash))) - .set(( - dedup_cache::sender_public_key.eq(public_key), - dedup_cache::number_of_hits.eq(dedup_cache::number_of_hits + 1), - dedup_cache::last_hit_at.eq(Utc::now().naive_utc()), - )) - .execute(conn); - match result { - Ok(_) => Err(StorageError::UniqueViolation(body_hash)), - Err(e) => Err(e.into()), - } + diesel::update(dedup_cache::table.filter(dedup_cache::body_hash.eq(&body_hash))) + .set(( + dedup_cache::sender_public_key.eq(&public_key), + dedup_cache::number_of_hits.eq(dedup_cache::number_of_hits + 1), + dedup_cache::last_hit_at.eq(Utc::now().naive_utc()), + )) + .execute(conn)?; + // TODO: Diesel support for RETURNING statements would remove this query, but is not + // available for Diesel + SQLite yet + let hits = dedup_cache::table + .select(dedup_cache::number_of_hits) + .filter(dedup_cache::body_hash.eq(&body_hash)) + .get_result::(conn)?; + + Ok(hits as u32) }, _ => Err(diesel::result::Error::DatabaseError(kind, e_info).into()), }, diff --git a/comms/dht/src/dedup/mod.rs b/comms/dht/src/dedup/mod.rs index 5428277af0..c5a078010f 100644 --- a/comms/dht/src/dedup/mod.rs +++ b/comms/dht/src/dedup/mod.rs @@ -47,13 +47,15 @@ fn hash_inbound_message(message: &DhtInboundMessage) -> Vec { pub struct DedupMiddleware { next_service: S, dht_requester: DhtRequester, + discard_on_hit_count: usize, } impl DedupMiddleware { - pub fn new(service: S, dht_requester: DhtRequester) -> Self { + pub fn new(service: S, dht_requester: DhtRequester, discard_on_hit_count: usize) -> Self { Self { next_service: service, dht_requester, + discard_on_hit_count, } } } @@ -71,9 +73,10 @@ where Poll::Ready(Ok(())) } - fn call(&mut self, message: DhtInboundMessage) -> Self::Future { + fn call(&mut self, mut message: DhtInboundMessage) -> Self::Future { let next_service = self.next_service.clone(); let mut dht_requester = self.dht_requester.clone(); + let discard_on_hit_count = self.discard_on_hit_count; Box::pin(async move { let hash = hash_inbound_message(&message); trace!( @@ -83,14 +86,17 @@ where message.tag, message.dht_header.message_tag ); - if dht_requester - .insert_message_hash(hash, message.source_peer.public_key.clone()) - .await? 
- { + + message.dedup_hit_count = dht_requester + .add_message_to_dedup_cache(hash, message.source_peer.public_key.clone()) + .await?; + + if message.dedup_hit_count as usize >= discard_on_hit_count { trace!( target: LOG_TARGET, - "Received duplicate message {} from peer '{}' (Trace: {}). Message discarded.", + "Received duplicate message {} (hit_count = {}) from peer '{}' (Trace: {}). Message discarded.", message.tag, + message.dedup_hit_count, message.source_peer.node_id.short_str(), message.dht_header.message_tag, ); @@ -99,8 +105,9 @@ where trace!( target: LOG_TARGET, - "Passing message {} onto next service (Trace: {})", + "Passing message {} (hit_count = {}) onto next service (Trace: {})", message.tag, + message.dedup_hit_count, message.dht_header.message_tag ); next_service.oneshot(message).await @@ -110,11 +117,15 @@ where pub struct DedupLayer { dht_requester: DhtRequester, + discard_on_hit_count: usize, } impl DedupLayer { - pub fn new(dht_requester: DhtRequester) -> Self { - Self { dht_requester } + pub fn new(dht_requester: DhtRequester, discard_on_hit_count: usize) -> Self { + Self { + dht_requester, + discard_on_hit_count, + } } } @@ -122,7 +133,7 @@ impl Layer for DedupLayer { type Service = DedupMiddleware; fn layer(&self, service: S) -> Self::Service { - DedupMiddleware::new(service, self.dht_requester.clone()) + DedupMiddleware::new(service, self.dht_requester.clone(), self.discard_on_hit_count) } } @@ -143,10 +154,10 @@ mod test { let (dht_requester, mock) = create_dht_actor_mock(1); let mock_state = mock.get_shared_state(); - mock_state.set_signature_cache_insert(false); + mock_state.set_number_of_message_hits(1); rt.spawn(mock.run()); - let mut dedup = DedupLayer::new(dht_requester).layer(spy.to_service::()); + let mut dedup = DedupLayer::new(dht_requester, 3).layer(spy.to_service::()); panic_context!(cx); @@ -157,7 +168,7 @@ mod test { rt.block_on(dedup.call(msg.clone())).unwrap(); assert_eq!(spy.call_count(), 1); - mock_state.set_signature_cache_insert(true); + mock_state.set_number_of_message_hits(3); rt.block_on(dedup.call(msg)).unwrap(); assert_eq!(spy.call_count(), 1); // Drop dedup so that the DhtMock will stop running diff --git a/comms/dht/src/dht.rs b/comms/dht/src/dht.rs index dcdeea5730..32c778f8ca 100644 --- a/comms/dht/src/dht.rs +++ b/comms/dht/src/dht.rs @@ -26,6 +26,7 @@ use crate::{ connectivity::{DhtConnectivity, MetricsCollector, MetricsCollectorHandle}, discovery::{DhtDiscoveryRequest, DhtDiscoveryRequester, DhtDiscoveryService}, event::{DhtEventReceiver, DhtEventSender}, + filter, inbound, inbound::{DecryptedDhtMessage, DhtInboundMessage, MetricsLayer}, logging_middleware::MessageLoggingLayer, @@ -37,12 +38,11 @@ use crate::{ storage::{DbConnection, StorageError}, store_forward, store_forward::{StoreAndForwardError, StoreAndForwardRequest, StoreAndForwardRequester, StoreAndForwardService}, - tower_filter, DedupLayer, DhtActorError, DhtConfig, }; -use futures::{channel::mpsc, future, Future}; +use futures::{channel::mpsc, Future}; use log::*; use std::sync::Arc; use tari_comms::{ @@ -285,13 +285,14 @@ impl Dht { S: Service + Clone + Send + Sync + 'static, S::Future: Send, { - // FIXME: There is an unresolved stack overflow issue on windows in debug mode during runtime, but not in - // release mode, related to the amount of layers. 
(issue #1416)
         ServiceBuilder::new()
             .layer(MetricsLayer::new(self.metrics_collector.clone()))
             .layer(inbound::DeserializeLayer::new(self.peer_manager.clone()))
-            .layer(DedupLayer::new(self.dht_requester()))
-            .layer(tower_filter::FilterLayer::new(self.unsupported_saf_messages_filter()))
+            .layer(DedupLayer::new(
+                self.dht_requester(),
+                self.config.dedup_discard_hit_count,
+            ))
+            .layer(filter::FilterLayer::new(self.unsupported_saf_messages_filter()))
             .layer(MessageLoggingLayer::new(format!(
                 "Inbound [{}]",
                 self.node_identity.node_id().short_str()
             )))
             .layer(inbound::DecryptionLayer::new(
                 self.config.clone(),
                 self.node_identity.clone(),
                 self.connectivity.clone(),
             ))
+            .layer(filter::FilterLayer::new(filter_messages_to_rebroadcast))
             .layer(store_forward::StoreLayer::new(
                 self.config.clone(),
                 Arc::clone(&self.peer_manager),
@@ -363,34 +365,60 @@ impl Dht {
 
     /// Produces a filter predicate which disallows store and forward messages if that feature is not
     /// supported by the node.
-    fn unsupported_saf_messages_filter(
-        &self,
-    ) -> impl tower_filter::Predicate<DhtInboundMessage, Future = future::Ready<Result<(), PipelineError>>> + Clone + Send
-    {
+    fn unsupported_saf_messages_filter(&self) -> impl filter::Predicate<DhtInboundMessage> + Clone + Send {
         let node_identity = Arc::clone(&self.node_identity);
         move |msg: &DhtInboundMessage| {
             if node_identity.has_peer_features(PeerFeatures::DHT_STORE_FORWARD) {
-                return future::ready(Ok(()));
+                return true;
             }
             match msg.dht_header.message_type {
                 DhtMessageType::SafRequestMessages => {
                     // TODO: #banheuristic This is an indication of node misbehaviour
-                    debug!(
+                    warn!(
                         "Received store and forward message from PublicKey={}. Store and forward feature is not \
                          supported by this node. Discarding message.",
                         msg.source_peer.public_key
                     );
-                    future::ready(Err(anyhow::anyhow!(
-                        "Message filtered out because store and forward is not supported by this node",
-                    )))
+                    false
                 },
-                _ => future::ready(Ok(())),
+                _ => true,
             }
         }
     }
 }
 
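Note on the hunk below: filter_messages_to_rebroadcast is the predicate that decides whether an
already-seen message is gossiped for another round. A self-contained sketch of that rule follows,
using a simplified stand-in for DecryptedDhtMessage (Msg, MsgType and Dest are invented for this
sketch; only the boolean logic mirrors the function added below):

    #[derive(PartialEq)]
    enum MsgType { Domain, Dht, Saf }
    #[derive(PartialEq)]
    enum Dest { Unknown, NodeId }

    struct Msg { hit_count: u32, msg_type: MsgType, dest: Dest }

    impl Msg {
        // Mirrors DecryptedDhtMessage::is_duplicate: the first hit is the original delivery.
        fn is_duplicate(&self) -> bool { self.hit_count > 1 }
    }

    fn filter_messages_to_rebroadcast(msg: &Msg) -> bool {
        // Pass everything seen for the first time; pass duplicates only when they are
        // domain (application) messages gossiped to an unknown destination.
        !msg.is_duplicate() || (msg.msg_type == MsgType::Domain && msg.dest == Dest::Unknown)
    }

    fn main() {
        // First delivery always passes.
        assert!(filter_messages_to_rebroadcast(&Msg { hit_count: 1, msg_type: MsgType::Dht, dest: Dest::NodeId }));
        // A duplicate gossip message passes and is re-propagated for another round;
        // the dedup layer's dedup_discard_hit_count bounds how many rounds it survives.
        assert!(filter_messages_to_rebroadcast(&Msg { hit_count: 2, msg_type: MsgType::Domain, dest: Dest::Unknown }));
        // Duplicate protocol (SAF/DHT) messages are dropped.
        assert!(!filter_messages_to_rebroadcast(&Msg { hit_count: 2, msg_type: MsgType::Saf, dest: Dest::NodeId }));
    }

+/// Provides the gossip filtering rules for an inbound message
+fn filter_messages_to_rebroadcast(msg: &DecryptedDhtMessage) -> bool {
+    // Let the message through if:
+    // it isn't a duplicate (normal message), or
+    let should_continue = !msg.is_duplicate() ||
+        (
+            // it is a duplicate domain message (i.e. not a DHT or SAF protocol message), and
+            msg.dht_header.message_type.is_domain_message() &&
+            // it has an unknown destination (e.g. complete transactions, blocks, misc. encrypted
+            // messages), in which case we let it proceed, re-propagating it for another round.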
+ msg.dht_header.destination.is_unknown() + ); + + if should_continue { + // The message has been forwarded, but downstream middleware may be interested + debug!( + target: LOG_TARGET, + "[filter_messages_to_rebroadcast] Passing message {} to next service (Trace: {})", + msg.tag, + msg.dht_header.message_tag + ); + true + } else { + debug!( + target: LOG_TARGET, + "[filter_messages_to_rebroadcast] Discarding duplicate message {}", msg + ); + false + } +} + #[cfg(test)] mod test { use crate::{ diff --git a/comms/dht/src/domain_message.rs b/comms/dht/src/domain_message.rs index 2fe7af16fe..f565882725 100644 --- a/comms/dht/src/domain_message.rs +++ b/comms/dht/src/domain_message.rs @@ -33,7 +33,7 @@ impl ToProtoEnum for i32 { } } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct OutboundDomainMessage { inner: T, message_type: i32, diff --git a/comms/dht/src/envelope.rs b/comms/dht/src/envelope.rs index 0b93546dbb..fb7e02050f 100644 --- a/comms/dht/src/envelope.rs +++ b/comms/dht/src/envelope.rs @@ -106,8 +106,12 @@ impl DhtMessageFlags { } impl DhtMessageType { + pub fn is_domain_message(self) -> bool { + matches!(self, DhtMessageType::None) + } + pub fn is_dht_message(self) -> bool { - self.is_dht_discovery() || self.is_dht_join() + self.is_dht_discovery() || matches!(self, DhtMessageType::DiscoveryResponse) || self.is_dht_join() } pub fn is_dht_discovery(self) -> bool { diff --git a/comms/dht/src/tower_filter/future.rs b/comms/dht/src/filter/future.rs similarity index 66% rename from comms/dht/src/tower_filter/future.rs rename to comms/dht/src/filter/future.rs index 78b2c613e6..4559aeaadf 100644 --- a/comms/dht/src/tower_filter/future.rs +++ b/comms/dht/src/filter/future.rs @@ -13,16 +13,15 @@ use tower::Service; /// Filtered response future #[pin_project] #[derive(Debug)] -pub struct ResponseFuture +pub struct ResponseFuture where S: Service { #[pin] /// Response future state state: State, - #[pin] - /// Predicate future - check: T, + /// Predicate result + check: bool, /// Inner service service: S, @@ -35,12 +34,10 @@ enum State { WaitResponse(#[pin] U), } -impl ResponseFuture -where - F: Future>, - S: Service, +impl ResponseFuture +where S: Service { - pub(crate) fn new(request: Request, check: F, service: S) -> Self { + pub(crate) fn new(request: Request, check: bool, service: S) -> Self { ResponseFuture { state: State::Check(Some(request)), check, @@ -49,10 +46,8 @@ where } } -impl Future for ResponseFuture -where - F: Future>, - S: Service, +impl Future for ResponseFuture +where S: Service { type Output = Result; @@ -66,15 +61,13 @@ where .take() .expect("we either give it back or leave State::Check once we take"); - // Poll predicate - match this.check.as_mut().poll(cx)? 
{ - Poll::Ready(_) => { + match this.check { + true => { let response = this.service.call(request); this.state.set(State::WaitResponse(response)); }, - Poll::Pending => { - this.state.set(State::Check(Some(request))); - return Poll::Pending; + false => { + return Poll::Ready(Ok(())); }, } }, diff --git a/comms/dht/src/tower_filter/layer.rs b/comms/dht/src/filter/layer.rs similarity index 100% rename from comms/dht/src/tower_filter/layer.rs rename to comms/dht/src/filter/layer.rs diff --git a/comms/dht/src/tower_filter/mod.rs b/comms/dht/src/filter/mod.rs similarity index 92% rename from comms/dht/src/tower_filter/mod.rs rename to comms/dht/src/filter/mod.rs index d1df2f27a7..e7f168161b 100644 --- a/comms/dht/src/tower_filter/mod.rs +++ b/comms/dht/src/filter/mod.rs @@ -33,11 +33,11 @@ impl Filter { impl Service for Filter where - T: Service + Clone, + T: Service + Clone, U: Predicate, { type Error = PipelineError; - type Future = ResponseFuture; + type Future = ResponseFuture; type Response = T::Response; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { diff --git a/comms/dht/src/filter/predicate.rs b/comms/dht/src/filter/predicate.rs new file mode 100644 index 0000000000..024dee826d --- /dev/null +++ b/comms/dht/src/filter/predicate.rs @@ -0,0 +1,13 @@ +/// Checks a request +pub trait Predicate { + /// Check whether the given request should be forwarded. + fn check(&mut self, request: &Request) -> bool; +} + +impl Predicate for F +where F: Fn(&T) -> bool +{ + fn check(&mut self, request: &T) -> bool { + self(request) + } +} diff --git a/comms/dht/src/inbound/dht_handler/task.rs b/comms/dht/src/inbound/dht_handler/task.rs index f45507a905..ec42bbd4fe 100644 --- a/comms/dht/src/inbound/dht_handler/task.rs +++ b/comms/dht/src/inbound/dht_handler/task.rs @@ -88,6 +88,20 @@ where S: Service return Ok(()); } + if message.is_duplicate() { + debug!( + target: LOG_TARGET, + "Received message ({}) that has already been received {} time(s). 
Last sent by peer '{}', passing on \ + to next service (Trace: {})", + message.tag, + message.dedup_hit_count, + message.source_peer.node_id.short_str(), + message.dht_header.message_tag, + ); + self.next_service.oneshot(message).await?; + return Ok(()); + } + trace!( target: LOG_TARGET, "Received DHT message type `{}` (Source peer: {}, Tag: {}, Trace: {})", diff --git a/comms/dht/src/inbound/message.rs b/comms/dht/src/inbound/message.rs index a49ae4b073..c9cdd103fd 100644 --- a/comms/dht/src/inbound/message.rs +++ b/comms/dht/src/inbound/message.rs @@ -43,6 +43,7 @@ pub struct DhtInboundMessage { pub dht_header: DhtMessageHeader, /// True if forwarded via store and forward, otherwise false pub is_saf_message: bool, + pub dedup_hit_count: u32, pub body: Vec, } impl DhtInboundMessage { @@ -53,6 +54,7 @@ impl DhtInboundMessage { dht_header, source_peer, is_saf_message: false, + dedup_hit_count: 0, body, } } @@ -62,11 +64,12 @@ impl Display for DhtInboundMessage { fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), fmt::Error> { write!( f, - "\n---- Inbound Message ---- \nSize: {} byte(s)\nType: {}\nPeer: {}\nHeader: {}\n{}\n----", + "\n---- Inbound Message ---- \nSize: {} byte(s)\nType: {}\nPeer: {}\nHit Count: {}\nHeader: {}\n{}\n----", self.body.len(), self.dht_header.message_type, self.source_peer, self.dht_header, + self.dedup_hit_count, self.tag, ) } @@ -86,6 +89,14 @@ pub struct DecryptedDhtMessage { pub is_saf_stored: Option, pub is_already_forwarded: bool, pub decryption_result: Result>, + pub dedup_hit_count: u32, +} + +impl DecryptedDhtMessage { + /// Returns true if this message has been received before, otherwise false if this is the first time + pub fn is_duplicate(&self) -> bool { + self.dedup_hit_count > 1 + } } impl DecryptedDhtMessage { @@ -104,6 +115,7 @@ impl DecryptedDhtMessage { is_saf_stored: None, is_already_forwarded: false, decryption_result: Ok(message_body), + dedup_hit_count: message.dedup_hit_count, } } @@ -118,6 +130,7 @@ impl DecryptedDhtMessage { is_saf_stored: None, is_already_forwarded: false, decryption_result: Err(message.body), + dedup_hit_count: message.dedup_hit_count, } } diff --git a/comms/dht/src/lib.rs b/comms/dht/src/lib.rs index cab2f8ab6f..7ef238d16e 100644 --- a/comms/dht/src/lib.rs +++ b/comms/dht/src/lib.rs @@ -153,11 +153,11 @@ pub use storage::DbConnectionUrl; mod dedup; pub use dedup::DedupLayer; +mod filter; mod logging_middleware; mod proto; mod rpc; mod schema; -mod tower_filter; pub mod broadcast_strategy; pub mod domain_message; diff --git a/comms/dht/src/outbound/broadcast.rs b/comms/dht/src/outbound/broadcast.rs index 0aa9fab611..3107b41fc5 100644 --- a/comms/dht/src/outbound/broadcast.rs +++ b/comms/dht/src/outbound/broadcast.rs @@ -251,6 +251,7 @@ where S: Service is_discovery_enabled, force_origin, dht_header, + tag, } = params; match self.select_peers(broadcast_strategy.clone()).await { @@ -320,6 +321,7 @@ where S: Service is_broadcast, body, Some(expires), + tag, ) .await { @@ -411,6 +413,7 @@ where S: Service is_broadcast: bool, body: Bytes, expires: Option>, + tag: Option, ) -> Result<(Vec, Vec), DhtOutboundError> { let dht_flags = encryption.flags() | extra_flags; @@ -424,7 +427,7 @@ where S: Service // Construct a DhtOutboundMessage for each recipient let messages = selected_peers.into_iter().map(|node_id| { let (reply_tx, reply_rx) = oneshot::channel(); - let tag = MessageTag::new(); + let tag = tag.unwrap_or_else(MessageTag::new); let send_state = MessageSendState::new(tag, reply_rx); ( DhtOutboundMessage { @@ -448,7 +451,7 
@@ where S: Service Ok(messages.unzip()) } - async fn add_to_dedup_cache(&mut self, body: &[u8], public_key: CommsPublicKey) -> Result { + async fn add_to_dedup_cache(&mut self, body: &[u8], public_key: CommsPublicKey) -> Result<(), DhtOutboundError> { let hash = Challenge::new().chain(&body).finalize().to_vec(); trace!( target: LOG_TARGET, @@ -456,10 +459,19 @@ where S: Service hash.to_hex(), ); - self.dht_requester - .insert_message_hash(hash, public_key) + // Do not count messages we've broadcast towards the total hit count + let hit_count = self + .dht_requester + .get_message_cache_hit_count(hash.clone()) .await - .map_err(|_| DhtOutboundError::FailedToInsertMessageHash) + .map_err(|err| DhtOutboundError::FailedToInsertMessageHash(err.to_string()))?; + if hit_count == 0 { + self.dht_requester + .add_message_to_dedup_cache(hash, public_key) + .await + .map_err(|err| DhtOutboundError::FailedToInsertMessageHash(err.to_string()))?; + } + Ok(()) } fn process_encryption( diff --git a/comms/dht/src/outbound/error.rs b/comms/dht/src/outbound/error.rs index 3f93dab043..6e1a8156c4 100644 --- a/comms/dht/src/outbound/error.rs +++ b/comms/dht/src/outbound/error.rs @@ -48,8 +48,8 @@ pub enum DhtOutboundError { SendToOurselves, #[error("Discovery process failed")] DiscoveryFailed, - #[error("Failed to insert message hash")] - FailedToInsertMessageHash, + #[error("Failed to insert message hash: {0}")] + FailedToInsertMessageHash(String), #[error("Failed to send message: {0}")] SendMessageFailed(SendFailure), #[error("No messages were queued for sending")] diff --git a/comms/dht/src/outbound/message_params.rs b/comms/dht/src/outbound/message_params.rs index 0ad00bbc4e..3b38272c38 100644 --- a/comms/dht/src/outbound/message_params.rs +++ b/comms/dht/src/outbound/message_params.rs @@ -27,7 +27,7 @@ use crate::{ proto::envelope::DhtMessageType, }; use std::{fmt, fmt::Display}; -use tari_comms::{peer_manager::NodeId, types::CommsPublicKey}; +use tari_comms::{message::MessageTag, peer_manager::NodeId, types::CommsPublicKey}; /// Configuration for outbound messages. /// @@ -66,6 +66,7 @@ pub struct FinalSendMessageParams { pub dht_message_type: DhtMessageType, pub dht_message_flags: DhtMessageFlags, pub dht_header: Option, + pub tag: Option, } impl Default for FinalSendMessageParams { @@ -79,6 +80,7 @@ impl Default for FinalSendMessageParams { force_origin: false, is_discovery_enabled: false, dht_header: None, + tag: None, } } } @@ -171,6 +173,12 @@ impl SendMessageParams { self } + /// Set the message trace tag + pub fn with_tag(&mut self, tag: MessageTag) -> &mut Self { + self.params_mut().tag = Some(tag); + self + } + /// Set destination field in message header. 
pub fn with_destination(&mut self, destination: NodeDestination) -> &mut Self { self.params_mut().destination = destination; diff --git a/comms/dht/src/storage/error.rs b/comms/dht/src/storage/error.rs index ab9f52f78d..f5bf4f0596 100644 --- a/comms/dht/src/storage/error.rs +++ b/comms/dht/src/storage/error.rs @@ -40,4 +40,6 @@ pub enum StorageError { ResultError(#[from] diesel::result::Error), #[error("MessageFormatError: {0}")] MessageFormatError(#[from] MessageFormatError), + #[error("Unexpected result: {0}")] + UnexpectedResult(String), } diff --git a/comms/dht/src/store_forward/forward.rs b/comms/dht/src/store_forward/forward.rs index 95ce5e2500..013917e6e4 100644 --- a/comms/dht/src/store_forward/forward.rs +++ b/comms/dht/src/store_forward/forward.rs @@ -153,7 +153,7 @@ where S: Service self.forward(&message).await?; } - // The message has been forwarded, but other middleware may be interested (i.e. StoreMiddleware) + // The message has been forwarded, but downstream middleware may be interested trace!( target: LOG_TARGET, "Passing message {} to next service (Trace: {})", @@ -205,8 +205,9 @@ where S: Service } let body = decryption_result - .clone() + .as_ref() .err() + .cloned() .expect("previous check that decryption failed"); let excluded_peers = vec![source_peer.node_id.clone()]; diff --git a/comms/dht/src/store_forward/saf_handler/task.rs b/comms/dht/src/store_forward/saf_handler/task.rs index f3ba852118..bcba88493c 100644 --- a/comms/dht/src/store_forward/saf_handler/task.rs +++ b/comms/dht/src/store_forward/saf_handler/task.rs @@ -103,6 +103,20 @@ where S: Service .take() .expect("DhtInboundMessageTask initialized without message"); + if message.is_duplicate() { + debug!( + target: LOG_TARGET, + "Received message ({}) that has already been received {} time(s). Last sent by peer '{}', passing on \ + (Trace: {})", + message.tag, + message.dedup_hit_count, + message.source_peer.node_id.short_str(), + message.dht_header.message_tag, + ); + self.next_service.oneshot(message).await?; + return Ok(()); + } + if message.dht_header.message_type.is_saf_message() && message.decryption_failed() { debug!( target: LOG_TARGET, @@ -460,7 +474,8 @@ where S: Service public_key: CommsPublicKey, ) -> Result<(), StoreAndForwardError> { let msg_hash = Challenge::new().chain(body).finalize().to_vec(); - if dht_requester.insert_message_hash(msg_hash, public_key).await? 
{ + let hit_count = dht_requester.add_message_to_dedup_cache(msg_hash, public_key).await?; + if hit_count > 1 { Err(StoreAndForwardError::DuplicateMessage) } else { Ok(()) diff --git a/comms/dht/src/store_forward/store.rs b/comms/dht/src/store_forward/store.rs index 4393f36518..c23133f23d 100644 --- a/comms/dht/src/store_forward/store.rs +++ b/comms/dht/src/store_forward/store.rs @@ -122,16 +122,31 @@ where } fn call(&mut self, msg: DecryptedDhtMessage) -> Self::Future { - Box::pin( - StoreTask::new( - self.next_service.clone(), - self.config.clone(), - Arc::clone(&self.peer_manager), - Arc::clone(&self.node_identity), - self.saf_requester.clone(), + if msg.is_duplicate() { + trace!( + target: LOG_TARGET, + "Passing duplicate message {} to next service (Trace: {})", + msg.tag, + msg.dht_header.message_tag + ); + + let service = self.next_service.clone(); + Box::pin(async move { + let service = service.ready_oneshot().await?; + service.oneshot(msg).await + }) + } else { + Box::pin( + StoreTask::new( + self.next_service.clone(), + self.config.clone(), + Arc::clone(&self.peer_manager), + Arc::clone(&self.node_identity), + self.saf_requester.clone(), + ) + .handle(msg), ) - .handle(msg), - ) + } } } diff --git a/comms/dht/src/test_utils/dht_actor_mock.rs b/comms/dht/src/test_utils/dht_actor_mock.rs index ccc53c5a1e..a64292b9ed 100644 --- a/comms/dht/src/test_utils/dht_actor_mock.rs +++ b/comms/dht/src/test_utils/dht_actor_mock.rs @@ -29,7 +29,7 @@ use futures::{channel::mpsc, stream::Fuse, StreamExt}; use std::{ collections::HashMap, sync::{ - atomic::{AtomicBool, AtomicUsize, Ordering}, + atomic::{AtomicUsize, Ordering}, Arc, RwLock, }, @@ -44,7 +44,7 @@ pub fn create_dht_actor_mock(buf_size: usize) -> (DhtRequester, DhtActorMock) { #[derive(Default, Debug, Clone)] pub struct DhtMockState { - signature_cache_insert: Arc, + signature_cache_insert: Arc, call_count: Arc, select_peers: Arc>>, settings: Arc>>>, @@ -52,16 +52,11 @@ pub struct DhtMockState { impl DhtMockState { pub fn new() -> Self { - Self { - signature_cache_insert: Arc::new(AtomicBool::new(false)), - call_count: Arc::new(AtomicUsize::new(0)), - select_peers: Arc::new(RwLock::new(Vec::new())), - settings: Arc::new(RwLock::new(HashMap::new())), - } + Default::default() } - pub fn set_signature_cache_insert(&self, v: bool) -> &Self { - self.signature_cache_insert.store(v, Ordering::SeqCst); + pub fn set_number_of_message_hits(&self, v: u32) -> &Self { + self.signature_cache_insert.store(v as usize, Ordering::SeqCst); self } @@ -111,9 +106,13 @@ impl DhtActorMock { self.state.inc_call_count(); match req { SendJoin => {}, - MsgHashCacheInsert(_, _, reply_tx) => { + MsgHashCacheInsert { reply_tx, .. } => { + let v = self.state.signature_cache_insert.load(Ordering::SeqCst); + reply_tx.send(v as u32).unwrap(); + }, + GetMsgHashHitCount(_, reply_tx) => { let v = self.state.signature_cache_insert.load(Ordering::SeqCst); - reply_tx.send(v).unwrap(); + reply_tx.send(v as u32).unwrap(); }, SelectPeers(_, reply_tx) => { let lock = self.state.select_peers.read().unwrap(); diff --git a/comms/dht/src/tower_filter/predicate.rs b/comms/dht/src/tower_filter/predicate.rs deleted file mode 100644 index f86b9cc406..0000000000 --- a/comms/dht/src/tower_filter/predicate.rs +++ /dev/null @@ -1,25 +0,0 @@ -use std::future::Future; -use tari_comms::pipeline::PipelineError; - -/// Checks a request -pub trait Predicate { - /// The future returned by `check`. - type Future: Future>; - - /// Check whether the given request should be forwarded. 
- /// - /// If the future resolves with `Ok`, the request is forwarded to the inner service. - fn check(&mut self, request: &Request) -> Self::Future; -} - -impl Predicate for F -where - F: Fn(&T) -> U, - U: Future>, -{ - type Future = U; - - fn check(&mut self, request: &T) -> Self::Future { - self(request) - } -} diff --git a/comms/dht/tests/dht.rs b/comms/dht/tests/dht.rs index a5aed09970..e963dd550a 100644 --- a/comms/dht/tests/dht.rs +++ b/comms/dht/tests/dht.rs @@ -46,6 +46,7 @@ use tari_comms_dht::{ DbConnectionUrl, Dht, DhtBuilder, + DhtConfig, }; use tari_shutdown::{Shutdown, ShutdownSignal}; use tari_storage::{ @@ -64,6 +65,7 @@ use tokio::{sync::broadcast, time}; use tower::ServiceBuilder; struct TestNode { + name: String, comms: CommsNode, dht: Dht, inbound_messages: mpsc::Receiver, @@ -80,6 +82,10 @@ impl TestNode { self.comms.node_identity().to_peer() } + pub fn name(&self) -> &str { + &self.name + } + pub async fn next_inbound_message(&mut self, timeout: Duration) -> Option { time::timeout(timeout, self.inbound_messages.next()).await.ok()? } @@ -113,24 +119,36 @@ fn create_peer_storage() -> CommsDatabase { LMDBWrapper::new(Arc::new(peer_database)) } -async fn make_node(features: PeerFeatures, seed_peer: Option) -> TestNode { +async fn make_node>( + name: &str, + features: PeerFeatures, + dht_config: DhtConfig, + known_peers: I, +) -> TestNode { let node_identity = make_node_identity(features); - make_node_with_node_identity(node_identity, seed_peer).await + make_node_with_node_identity(name, node_identity, dht_config, known_peers).await } -async fn make_node_with_node_identity(node_identity: Arc, seed_peer: Option) -> TestNode { +async fn make_node_with_node_identity>( + name: &str, + node_identity: Arc, + dht_config: DhtConfig, + known_peers: I, +) -> TestNode { let (tx, inbound_messages) = mpsc::channel(10); let shutdown = Shutdown::new(); let (comms, dht, messaging_events) = setup_comms_dht( node_identity, create_peer_storage(), tx, - seed_peer.into_iter().collect(), + known_peers.into_iter().collect(), + dht_config, shutdown.to_signal(), ) .await; TestNode { + name: name.to_string(), comms, dht, inbound_messages, @@ -145,6 +163,7 @@ async fn setup_comms_dht( storage: CommsDatabase, inbound_tx: mpsc::Sender, peers: Vec, + dht_config: DhtConfig, shutdown_signal: ShutdownSignal, ) -> (CommsNode, Dht, MessagingEventSender) { // Create inbound and outbound channels @@ -168,11 +187,8 @@ async fn setup_comms_dht( comms.connectivity(), comms.shutdown_signal(), ) - .local_test() - .set_auto_store_and_forward_requests(false) + .with_config(dht_config) .with_database_url(DbConnectionUrl::MemoryShared(random::string(8))) - .with_discovery_timeout(Duration::from_secs(60)) - .with_num_neighbouring_nodes(8) .build() .await .unwrap(); @@ -205,17 +221,38 @@ async fn setup_comms_dht( (comms, dht, event_tx) } +fn dht_config() -> DhtConfig { + let mut config = DhtConfig::default_local_test(); + config.allow_test_addresses = true; + config.saf_auto_request = false; + config.discovery_request_timeout = Duration::from_secs(60); + config.num_neighbouring_nodes = 8; + config +} + #[tokio_macros::test] #[allow(non_snake_case)] async fn dht_join_propagation() { // Create 3 nodes where only Node B knows A and C, but A and C want to talk to each other // Node C knows no one - let node_C = make_node(PeerFeatures::COMMUNICATION_NODE, None).await; + let node_C = make_node("node_C", PeerFeatures::COMMUNICATION_NODE, dht_config(), None).await; // Node B knows about Node C - let node_B = 
make_node(PeerFeatures::COMMUNICATION_NODE, Some(node_C.to_peer())).await; + let node_B = make_node( + "node_B", + PeerFeatures::COMMUNICATION_NODE, + dht_config(), + Some(node_C.to_peer()), + ) + .await; // Node A knows about Node B - let node_A = make_node(PeerFeatures::COMMUNICATION_NODE, Some(node_B.to_peer())).await; + let node_A = make_node( + "node_A", + PeerFeatures::COMMUNICATION_NODE, + dht_config(), + Some(node_B.to_peer()), + ) + .await; node_A .comms @@ -268,13 +305,31 @@ async fn dht_discover_propagation() { // Create 4 nodes where A knows B, B knows A and C, C knows B and D, and D knows C // Node D knows no one - let node_D = make_node(PeerFeatures::COMMUNICATION_CLIENT, None).await; + let node_D = make_node("node_D", PeerFeatures::COMMUNICATION_CLIENT, dht_config(), None).await; // Node C knows about Node D - let node_C = make_node(PeerFeatures::COMMUNICATION_NODE, Some(node_D.to_peer())).await; + let node_C = make_node( + "node_C", + PeerFeatures::COMMUNICATION_NODE, + dht_config(), + Some(node_D.to_peer()), + ) + .await; // Node B knows about Node C - let node_B = make_node(PeerFeatures::COMMUNICATION_NODE, Some(node_C.to_peer())).await; + let node_B = make_node( + "node_B", + PeerFeatures::COMMUNICATION_NODE, + dht_config(), + Some(node_C.to_peer()), + ) + .await; // Node A knows about Node B - let node_A = make_node(PeerFeatures::COMMUNICATION_NODE, Some(node_B.to_peer())).await; + let node_A = make_node( + "node_A", + PeerFeatures::COMMUNICATION_NODE, + dht_config(), + Some(node_B.to_peer()), + ) + .await; log::info!( "NodeA = {}, NodeB = {}, Node C = {}, Node D = {}", node_A.node_identity().node_id().short_str(), @@ -323,9 +378,15 @@ async fn dht_discover_propagation() { async fn dht_store_forward() { let node_C_node_identity = make_node_identity(PeerFeatures::COMMUNICATION_NODE); // Node B knows about Node C - let node_B = make_node(PeerFeatures::COMMUNICATION_NODE, None).await; + let node_B = make_node("node_B", PeerFeatures::COMMUNICATION_NODE, dht_config(), None).await; // Node A knows about Node B - let node_A = make_node(PeerFeatures::COMMUNICATION_NODE, Some(node_B.to_peer())).await; + let node_A = make_node( + "node_A", + PeerFeatures::COMMUNICATION_NODE, + dht_config(), + Some(node_B.to_peer()), + ) + .await; log::info!( "NodeA = {}, NodeB = {}, Node C = {}", node_A.node_identity().node_id().short_str(), @@ -372,7 +433,8 @@ async fn dht_store_forward() { // Wait for node B to receive 2 propagation messages collect_stream!(node_B_msg_events, take = 2, timeout = Duration::from_secs(20)); - let mut node_C = make_node_with_node_identity(node_C_node_identity, Some(node_B.to_peer())).await; + let mut node_C = + make_node_with_node_identity("node_C", node_C_node_identity, dht_config(), Some(node_B.to_peer())).await; let mut node_C_dht_events = node_C.dht.subscribe_dht_events(); let mut node_C_msg_events = node_C.messaging_events.subscribe(); // Ask node B for messages @@ -429,14 +491,36 @@ async fn dht_store_forward() { #[tokio_macros::test] #[allow(non_snake_case)] async fn dht_propagate_dedup() { + let mut config = dht_config(); + // For this test we want to exactly measure the path of a message, so we disable repropagation of messages (i.e. 
+ // discard on 2nd occurrence) + config.dedup_discard_hit_count = 2; // Node D knows no one - let mut node_D = make_node(PeerFeatures::COMMUNICATION_NODE, None).await; + let mut node_D = make_node("node_D", PeerFeatures::COMMUNICATION_NODE, config.clone(), None).await; // Node C knows about Node D - let mut node_C = make_node(PeerFeatures::COMMUNICATION_NODE, Some(node_D.to_peer())).await; + let mut node_C = make_node( + "node_C", + PeerFeatures::COMMUNICATION_NODE, + config.clone(), + Some(node_D.to_peer()), + ) + .await; // Node B knows about Node C - let mut node_B = make_node(PeerFeatures::COMMUNICATION_NODE, Some(node_C.to_peer())).await; + let mut node_B = make_node( + "node_B", + PeerFeatures::COMMUNICATION_NODE, + config.clone(), + Some(node_C.to_peer()), + ) + .await; // Node A knows about Node B and C - let mut node_A = make_node(PeerFeatures::COMMUNICATION_NODE, Some(node_B.to_peer())).await; + let mut node_A = make_node( + "node_A", + PeerFeatures::COMMUNICATION_NODE, + config.clone(), + Some(node_B.to_peer()), + ) + .await; node_A.comms.peer_manager().add_peer(node_C.to_peer()).await.unwrap(); log::info!( "NodeA = {}, NodeB = {}, Node C = {}, Node D = {}", @@ -482,8 +566,7 @@ async fn dht_propagate_dedup() { .dht .outbound_requester() .propagate( - // Node D is a client node, so an destination is required for domain messages - NodeDestination::Unknown, // NodeId(Box::new(node_D.node_identity().node_id().clone())), + NodeDestination::Unknown, OutboundEncryption::EncryptFor(Box::new(node_D.node_identity().public_key().clone())), vec![], out_msg, @@ -496,6 +579,7 @@ async fn dht_propagate_dedup() { .await .expect("Node D expected an inbound message but it never arrived"); assert!(msg.decryption_succeeded()); + log::info!("Received message {}", msg.tag); let person = msg .decryption_result .unwrap() @@ -536,14 +620,124 @@ async fn dht_propagate_dedup() { assert_eq!(count_messages_received(&received, &[&node_C_id]), 1); } +#[tokio_macros::test] +#[allow(non_snake_case)] +async fn dht_repropagate() { + let mut node_C = make_node("node_C", PeerFeatures::COMMUNICATION_NODE, dht_config(), []).await; + let mut node_B = make_node("node_B", PeerFeatures::COMMUNICATION_NODE, dht_config(), [ + node_C.to_peer() + ]) + .await; + let mut node_A = make_node("node_A", PeerFeatures::COMMUNICATION_NODE, dht_config(), [ + node_B.to_peer(), + node_C.to_peer(), + ]) + .await; + node_A.comms.peer_manager().add_peer(node_C.to_peer()).await.unwrap(); + node_B.comms.peer_manager().add_peer(node_C.to_peer()).await.unwrap(); + node_C.comms.peer_manager().add_peer(node_A.to_peer()).await.unwrap(); + node_C.comms.peer_manager().add_peer(node_B.to_peer()).await.unwrap(); + log::info!( + "NodeA = {}, NodeB = {}, Node C = {}", + node_A.node_identity().node_id().short_str(), + node_B.node_identity().node_id().short_str(), + node_C.node_identity().node_id().short_str(), + ); + + // Connect the peers that should be connected + async fn connect_nodes(node1: &mut TestNode, node2: &mut TestNode) { + node1 + .comms + .connectivity() + .dial_peer(node2.node_identity().node_id().clone()) + .await + .unwrap(); + } + // Pre-connect nodes, this helps message passing be more deterministic + connect_nodes(&mut node_A, &mut node_B).await; + connect_nodes(&mut node_A, &mut node_C).await; + connect_nodes(&mut node_B, &mut node_C).await; + + #[derive(Clone, PartialEq, ::prost::Message)] + struct Person { + #[prost(string, tag = "1")] + name: String, + #[prost(uint32, tag = "2")] + age: u32, + } + + let out_msg = 
OutboundDomainMessage::new(123, Person { + name: "Alan Turing".into(), + age: 41, + }); + node_A + .dht + .outbound_requester() + .propagate( + NodeDestination::Unknown, + OutboundEncryption::ClearText, + vec![], + out_msg.clone(), + ) + .await + .unwrap(); + + async fn receive_and_repropagate(node: &mut TestNode, out_msg: &OutboundDomainMessage) { + let msg = node + .next_inbound_message(Duration::from_secs(10)) + .await + .unwrap_or_else(|| panic!("{} expected an inbound message but it never arrived", node.name())); + log::info!("Received message {}", msg.tag); + + node.dht + .outbound_requester() + .send_message( + SendMessageParams::new() + .propagate(NodeDestination::Unknown, vec![]) + .with_destination(NodeDestination::Unknown) + .with_tag(msg.tag) + .finish(), + out_msg.clone(), + ) + .await + .unwrap() + .resolve() + .await + .unwrap(); + } + + // This relies on the DHT being set with .with_dedup_discard_hit_count(3) + receive_and_repropagate(&mut node_B, &out_msg).await; + receive_and_repropagate(&mut node_C, &out_msg).await; + receive_and_repropagate(&mut node_A, &out_msg).await; + receive_and_repropagate(&mut node_B, &out_msg).await; + receive_and_repropagate(&mut node_C, &out_msg).await; + + node_A.shutdown().await; + node_B.shutdown().await; + node_C.shutdown().await; +} + #[tokio_macros::test] #[allow(non_snake_case)] async fn dht_propagate_message_contents_not_malleable_ban() { - let node_C = make_node(PeerFeatures::COMMUNICATION_NODE, None).await; + let node_C = make_node("node_C", PeerFeatures::COMMUNICATION_NODE, dht_config(), None).await; // Node B knows about Node C - let mut node_B = make_node(PeerFeatures::COMMUNICATION_NODE, Some(node_C.to_peer())).await; + let mut node_B = make_node( + "node_B", + PeerFeatures::COMMUNICATION_NODE, + dht_config(), + Some(node_C.to_peer()), + ) + .await; // Node A knows about Node B - let node_A = make_node(PeerFeatures::COMMUNICATION_NODE, Some(node_B.to_peer())).await; + let node_A = make_node( + "node_A", + PeerFeatures::COMMUNICATION_NODE, + dht_config(), + Some(node_B.to_peer()), + ) + .await; node_A.comms.peer_manager().add_peer(node_C.to_peer()).await.unwrap(); log::info!( "NodeA = {}, NodeB = {}",