From 79a74a1ea8854f339ea5cc2e7ec990ce417c641a Mon Sep 17 00:00:00 2001 From: Stanimal Date: Tue, 14 Apr 2020 10:55:53 +0200 Subject: [PATCH] [Store and forward] Private message storage and retrieval Storage and retreival of messages without an origin (PR #1686) All private/anonymous messages are stored as low priority (short TTL) messages. - New response type `Anonymous` for store and forward query responses - Added some more logging - DHT store and forward integration test - SAF messages can never be stored in store and forward --- comms/dht/src/actor.rs | 16 ++- comms/dht/src/config.rs | 8 +- comms/dht/src/inbound/dedup.rs | 4 + comms/dht/src/inbound/message.rs | 2 +- comms/dht/src/outbound/broadcast.rs | 1 + comms/dht/src/outbound/message.rs | 2 +- comms/dht/src/proto/store_forward.proto | 19 ++-- comms/dht/src/proto/tari.dht.store_forward.rs | 20 ++-- comms/dht/src/store_forward/database/mod.rs | 100 +++++++++++++++++- comms/dht/src/store_forward/error.rs | 2 + comms/dht/src/store_forward/message.rs | 2 + .../dht/src/store_forward/saf_handler/task.rs | 44 ++++++-- comms/dht/src/store_forward/service.rs | 33 ++++-- comms/dht/src/store_forward/store.rs | 14 +-- comms/dht/tests/dht.rs | 93 ++++++++++++++-- comms/src/peer_manager/node_id.rs | 15 +++ 16 files changed, 315 insertions(+), 60 deletions(-) diff --git a/comms/dht/src/actor.rs b/comms/dht/src/actor.rs index 3c30ea6215..5b85a48d90 100644 --- a/comms/dht/src/actor.rs +++ b/comms/dht/src/actor.rs @@ -299,9 +299,11 @@ impl<'a> DhtActor<'a> { }, SendRequestStoredMessages => { let node_identity = Arc::clone(&self.node_identity); + let peer_manager = Arc::clone(&self.peer_manager); let outbound_requester = self.outbound_requester.clone(); Box::pin(Self::request_stored_messages( node_identity, + peer_manager, outbound_requester, db, self.config.num_neighbouring_nodes, @@ -359,17 +361,29 @@ impl<'a> DhtActor<'a> { async fn request_stored_messages( node_identity: Arc, + peer_manager: Arc, mut outbound_requester: OutboundMessageRequester, db: DhtDatabase, num_neighbouring_nodes: usize, ) -> Result<(), DhtActorError> { - let request = db + let mut request = db .get_value(DhtSettingKey::SafLastRequestTimestamp) .await? .map(StoredMessagesRequest::since) .unwrap_or_else(StoredMessagesRequest::new); + // Calculate the network region threshold for our node id. + // i.e. "Give me all messages that are this close to my node ID" + let threshold = peer_manager + .calc_region_threshold( + node_identity.node_id(), + num_neighbouring_nodes, + PeerFeatures::DHT_STORE_FORWARD, + ) + .await?; + request.dist_threshold = threshold.to_vec(); + outbound_requester .send_message_no_header( SendMessageParams::new() diff --git a/comms/dht/src/config.rs b/comms/dht/src/config.rs index 0f920a10e6..a59ace1b65 100644 --- a/comms/dht/src/config.rs +++ b/comms/dht/src/config.rs @@ -28,7 +28,7 @@ pub const SAF_MSG_CACHE_STORAGE_CAPACITY: usize = 10_000; /// The default time-to-live duration used for storage of low priority messages by the Store-and-forward middleware pub const SAF_LOW_PRIORITY_MSG_STORAGE_TTL: Duration = Duration::from_secs(6 * 60 * 60); // 6 hours /// The default time-to-live duration used for storage of high priority messages by the Store-and-forward middleware -pub const SAF_HIGH_PRIORITY_MSG_STORAGE_TTL: Duration = Duration::from_secs(2 * 24 * 60 * 60); // 2 days +pub const SAF_HIGH_PRIORITY_MSG_STORAGE_TTL: Duration = Duration::from_secs(3 * 24 * 60 * 60); // 3 days /// The default number of peer nodes that a message has to be closer to, to be considered a neighbour pub const DEFAULT_NUM_NEIGHBOURING_NODES: usize = 10; @@ -55,9 +55,9 @@ pub struct DhtConfig { /// Default: 6 hours pub saf_low_priority_msg_storage_ttl: Duration, /// The time-to-live duration used for storage of high priority messages by the Store-and-forward middleware. - /// Default: 2 days + /// Default: 3 days pub saf_high_priority_msg_storage_ttl: Duration, - /// The limit on the message size to store in SAF storage in bytes. Default 500kb + /// The limit on the message size to store in SAF storage in bytes. Default 500 KiB pub saf_max_message_size: usize, /// The max capacity of the message hash cache /// Default: 1000 @@ -112,7 +112,7 @@ impl Default for DhtConfig { saf_msg_cache_storage_capacity: SAF_MSG_CACHE_STORAGE_CAPACITY, saf_low_priority_msg_storage_ttl: SAF_LOW_PRIORITY_MSG_STORAGE_TTL, saf_high_priority_msg_storage_ttl: SAF_HIGH_PRIORITY_MSG_STORAGE_TTL, - saf_max_message_size: 512 * 1024, // 512 kb + saf_max_message_size: 512 * 1024, // 500 KiB msg_hash_cache_capacity: 10_000, msg_hash_cache_ttl: Duration::from_secs(300), broadcast_cooldown_max_attempts: 3, diff --git a/comms/dht/src/inbound/dedup.rs b/comms/dht/src/inbound/dedup.rs index 15f09646fc..23ec6b93f4 100644 --- a/comms/dht/src/inbound/dedup.rs +++ b/comms/dht/src/inbound/dedup.rs @@ -26,6 +26,7 @@ use futures::{task::Context, Future}; use log::*; use std::task::Poll; use tari_comms::{pipeline::PipelineError, types::Challenge}; +use tari_crypto::tari_utilities::hex::Hex; use tower::{layer::Layer, Service, ServiceExt}; const LOG_TARGET: &str = "comms::dht::dedup"; @@ -77,6 +78,7 @@ where S: Service { trace!(target: LOG_TARGET, "Checking inbound message cache for duplicates"); let hash = Self::hash_message(&message); + trace!(target: LOG_TARGET, "Inserting message hash {}", hash.to_hex()); if dht_requester .insert_message_hash(hash) .await @@ -88,6 +90,8 @@ where S: Service ); return Ok(()); } + + trace!(target: LOG_TARGET, "Passing message onto next service"); next_service.oneshot(message).await } diff --git a/comms/dht/src/inbound/message.rs b/comms/dht/src/inbound/message.rs index 987c8cea76..3363c325f6 100644 --- a/comms/dht/src/inbound/message.rs +++ b/comms/dht/src/inbound/message.rs @@ -52,7 +52,7 @@ impl Display for DhtInboundMessage { fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> { write!( f, - "\n---- DhtInboundMessage ---- \nSize: {} byte(s)\nType: {}\nPeer: {}\nHeader: {}\n----", + "\n---- Inbound Message ---- \nSize: {} byte(s)\nType: {}\nPeer: {}\nHeader: {}\n----", self.body.len(), self.dht_header.message_type, self.source_peer, diff --git a/comms/dht/src/outbound/broadcast.rs b/comms/dht/src/outbound/broadcast.rs index c0b1672a07..c6ea76c4e1 100644 --- a/comms/dht/src/outbound/broadcast.rs +++ b/comms/dht/src/outbound/broadcast.rs @@ -222,6 +222,7 @@ where S: Service reply_tx: oneshot::Sender, ) -> Result, DhtOutboundError> { + trace!(target: LOG_TARGET, "Send params: {:?}", params); if params .broadcast_strategy .direct_public_key() diff --git a/comms/dht/src/outbound/message.rs b/comms/dht/src/outbound/message.rs index b266d1c573..a4d1f3f60b 100644 --- a/comms/dht/src/outbound/message.rs +++ b/comms/dht/src/outbound/message.rs @@ -168,7 +168,7 @@ impl fmt::Display for DhtOutboundMessage { let header_str = self .custom_header .as_ref() - .and_then(|h| Some(format!("{} (Propagated)", h))) + .map(|h| format!("{} (Propagated)", h)) .unwrap_or_else(|| { format!( "Network: {:?}, Flags: {:?}, Destination: {}", diff --git a/comms/dht/src/proto/store_forward.proto b/comms/dht/src/proto/store_forward.proto index c7d26c416f..896728fb8c 100644 --- a/comms/dht/src/proto/store_forward.proto +++ b/comms/dht/src/proto/store_forward.proto @@ -12,6 +12,7 @@ package tari.dht.store_forward; message StoredMessagesRequest { google.protobuf.Timestamp since = 1; uint32 request_id = 2; + bytes dist_threshold = 3; } // Storage for a single message envelope, including the date and time when the element was stored @@ -27,14 +28,16 @@ message StoredMessagesResponse { repeated StoredMessage messages = 1; uint32 request_id = 2; enum SafResponseType { - // All applicable messages - General = 0; - // Send messages explicitly addressed to the requesting node or within the requesting node's region - ExplicitlyAddressed = 1; - // Send Discovery messages that could be for the requester - Discovery = 2; - // Send Join messages that the requester could be interested in - Join = 3; + // Messages for the requested public key or node ID + ForMe = 0; + // Discovery messages that could be for the requester + Discovery = 1; + // Join messages that the requester could be interested in + Join = 2; + // Messages without an explicit destination and with an unidentified encrypted source + Anonymous = 3; + // Messages within the requesting node's region + InRegion = 4; } SafResponseType response_type = 3; } diff --git a/comms/dht/src/proto/tari.dht.store_forward.rs b/comms/dht/src/proto/tari.dht.store_forward.rs index e0b58b90b7..7e1e67d202 100644 --- a/comms/dht/src/proto/tari.dht.store_forward.rs +++ b/comms/dht/src/proto/tari.dht.store_forward.rs @@ -7,6 +7,8 @@ pub struct StoredMessagesRequest { pub since: ::std::option::Option<::prost_types::Timestamp>, #[prost(uint32, tag = "2")] pub request_id: u32, + #[prost(bytes, tag = "3")] + pub dist_threshold: std::vec::Vec, } /// Storage for a single message envelope, including the date and time when the element was stored #[derive(Clone, PartialEq, ::prost::Message)] @@ -34,13 +36,15 @@ pub mod stored_messages_response { #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum SafResponseType { - /// All applicable messages - General = 0, - /// Send messages explicitly addressed to the requesting node or within the requesting node's region - ExplicitlyAddressed = 1, - /// Send Discovery messages that could be for the requester - Discovery = 2, - /// Send Join messages that the requester could be interested in - Join = 3, + /// Messages for the requested public key + ForMe = 0, + /// Discovery messages that could be for the requester + Discovery = 1, + /// Join messages that the requester could be interested in + Join = 2, + /// Messages without an explicit destination and with an unidentified encrypted source + Anonymous = 3, + /// Messages within the requesting node's region + InRegion = 4, } } diff --git a/comms/dht/src/store_forward/database/mod.rs b/comms/dht/src/store_forward/database/mod.rs index 12b772a1d5..1b40f77a27 100644 --- a/comms/dht/src/store_forward/database/mod.rs +++ b/comms/dht/src/store_forward/database/mod.rs @@ -31,7 +31,10 @@ use crate::{ }; use chrono::{DateTime, NaiveDateTime, Utc}; use diesel::{BoolExpressionMethods, ExpressionMethods, QueryDsl, RunQueryDsl}; -use tari_comms::types::CommsPublicKey; +use tari_comms::{ + peer_manager::{node_id::NodeDistance, NodeId}, + types::CommsPublicKey, +}; use tari_crypto::tari_utilities::hex::Hex; pub struct StoreAndForwardDatabase { @@ -54,19 +57,106 @@ impl StoreAndForwardDatabase { .await } - pub async fn find_messages_for_public_key( + pub async fn find_messages_for_peer( &self, public_key: &CommsPublicKey, + node_id: &NodeId, since: Option>, limit: i64, ) -> Result, StorageError> { let pk_hex = public_key.to_hex(); + let node_id_hex = node_id.to_hex(); + self.connection + .with_connection_async::<_, Vec>(move |conn| { + let mut query = stored_messages::table + .select(stored_messages::all_columns) + .filter( + stored_messages::destination_pubkey + .eq(pk_hex) + .or(stored_messages::destination_node_id.eq(node_id_hex)), + ) + .into_boxed(); + + if let Some(since) = since { + query = query.filter(stored_messages::stored_at.ge(since.naive_utc())); + } + + query + .order_by(stored_messages::stored_at.asc()) + .limit(limit) + .get_results(conn) + .map_err(Into::into) + }) + .await + } + + pub async fn find_regional_messages( + &self, + node_id: &NodeId, + dist_threshold: Option>, + since: Option>, + limit: i64, + ) -> Result, StorageError> + { + let node_id_hex = node_id.to_hex(); + let results = self + .connection + .with_connection_async::<_, Vec>(move |conn| { + let mut query = stored_messages::table + .select(stored_messages::all_columns) + .filter(stored_messages::destination_node_id.ne(node_id_hex)) + .filter(stored_messages::destination_node_id.is_not_null()) + .filter(stored_messages::message_type.eq(DhtMessageType::None as i32)) + .into_boxed(); + + if let Some(since) = since { + query = query.filter(stored_messages::stored_at.ge(since.naive_utc())); + } + + query + .order_by(stored_messages::stored_at.asc()) + .limit(limit) + .get_results(conn) + .map_err(Into::into) + }) + .await?; + + match dist_threshold { + Some(dist_threshold) => { + // Filter node ids that are within the distance threshold from the source node id + let results = results + .into_iter() + // TODO: Investigate if we could do this in sqlite using XOR (^) + .filter(|message| match message.destination_node_id { + Some(ref dest_node_id) => match NodeId::from_hex(dest_node_id).ok() { + Some(dest_node_id) => { + &dest_node_id == node_id || &dest_node_id.distance(node_id) <= &*dist_threshold + }, + None => false, + }, + None => true, + }) + .collect(); + Ok(results) + }, + None => Ok(results), + } + } + + pub async fn find_anonymous_messages( + &self, + since: Option>, + limit: i64, + ) -> Result, StorageError> + { self.connection .with_connection_async(move |conn| { let mut query = stored_messages::table .select(stored_messages::all_columns) - .filter(stored_messages::destination_pubkey.eq(pk_hex)) + .filter(stored_messages::origin_pubkey.is_null()) + .filter(stored_messages::destination_pubkey.is_null()) + .filter(stored_messages::is_encrypted.eq(true)) .filter(stored_messages::message_type.eq(DhtMessageType::None as i32)) .into_boxed(); @@ -75,7 +165,7 @@ impl StoreAndForwardDatabase { } query - .order_by(stored_messages::stored_at.desc()) + .order_by(stored_messages::stored_at.asc()) .limit(limit) .get_results(conn) .map_err(Into::into) @@ -109,7 +199,7 @@ impl StoreAndForwardDatabase { } query - .order_by(stored_messages::stored_at.desc()) + .order_by(stored_messages::stored_at.asc()) .limit(limit) .get_results(conn) .map_err(Into::into) diff --git a/comms/dht/src/store_forward/error.rs b/comms/dht/src/store_forward/error.rs index e2c19ef830..e6777f60c9 100644 --- a/comms/dht/src/store_forward/error.rs +++ b/comms/dht/src/store_forward/error.rs @@ -71,4 +71,6 @@ pub enum StoreAndForwardError { /// The envelope version is invalid InvalidEnvelopeVersion, MalformedNodeId(ByteArrayError), + /// NodeDistance threshold was invalid + InvalidNodeDistanceThreshold, } diff --git a/comms/dht/src/store_forward/message.rs b/comms/dht/src/store_forward/message.rs index 41dc126022..6e0c130465 100644 --- a/comms/dht/src/store_forward/message.rs +++ b/comms/dht/src/store_forward/message.rs @@ -55,6 +55,7 @@ impl StoredMessagesRequest { Self { since: None, request_id: OsRng.next_u32(), + dist_threshold: Vec::new(), } } @@ -62,6 +63,7 @@ impl StoredMessagesRequest { Self { since: Some(datetime_to_timestamp(since)), request_id: OsRng.next_u32(), + dist_threshold: Vec::new(), } } } diff --git a/comms/dht/src/store_forward/saf_handler/task.rs b/comms/dht/src/store_forward/saf_handler/task.rs index 1e944b6050..270acac9f7 100644 --- a/comms/dht/src/store_forward/saf_handler/task.rs +++ b/comms/dht/src/store_forward/saf_handler/task.rs @@ -53,7 +53,7 @@ use prost::Message; use std::{convert::TryInto, sync::Arc}; use tari_comms::{ message::EnvelopeBody, - peer_manager::{NodeIdentity, Peer, PeerManager, PeerManagerError}, + peer_manager::{node_id::NodeDistance, NodeIdentity, Peer, PeerManager, PeerManagerError}, pipeline::PipelineError, types::{Challenge, CommsPublicKey}, utils::signature, @@ -171,17 +171,29 @@ where S: Service } let source_pubkey = Box::new(message.source_peer.public_key.clone()); + let source_node_id = Box::new(message.source_peer.node_id.clone()); // Compile a set of stored messages for the requesting peer - let mut query = FetchStoredMessageQuery::new(source_pubkey); + let mut query = FetchStoredMessageQuery::new(source_pubkey, source_node_id); + if let Some(since) = retrieve_msgs.since.map(timestamp_to_datetime) { query.since(since); } + if !retrieve_msgs.dist_threshold.is_empty() { + let dist_threshold = Box::new( + NodeDistance::from_bytes(&retrieve_msgs.dist_threshold) + .map_err(|_| StoreAndForwardError::InvalidNodeDistanceThreshold)?, + ); + query.with_dist_threshold(dist_threshold); + } + let response_types = vec![ + SafResponseType::ForMe, + SafResponseType::Anonymous, + SafResponseType::InRegion, SafResponseType::Discovery, SafResponseType::Join, - SafResponseType::ExplicitlyAddressed, ]; for resp_type in response_types { @@ -189,7 +201,7 @@ where S: Service let messages = self.saf_requester.fetch_messages(query.clone()).await?; if messages.is_empty() { - debug!( + info!( target: LOG_TARGET, "No {:?} stored messages for peer '{}'", resp_type, @@ -206,8 +218,9 @@ where S: Service info!( target: LOG_TARGET, - "Responding to received message retrieval request with {} message(s)", - stored_messages.messages().len() + "Responding to received message retrieval request with {} {:?} message(s)", + stored_messages.messages().len(), + resp_type ); self.outbound_service .send_message_no_header( @@ -240,8 +253,12 @@ where S: Service debug!( target: LOG_TARGET, - "Received {} stored messages from peer", - response.messages().len() + "Received {} stored messages of type {} from peer", + response.messages().len(), + SafResponseType::from_i32(response.response_type) + .as_ref() + .map(|t| format!("{:?}", t)) + .unwrap_or("".to_string()), ); let last_timestamp = self @@ -451,10 +468,21 @@ where S: Service "[store and forward] DHT header is invalid after validity check because it did not contain an \ ephemeral_public_key", ); + + trace!( + target: LOG_TARGET, + "Attempting to decrypt origin mac ({} byte(s))", + header.origin_mac.len() + ); let shared_secret = crypt::generate_ecdh_secret(node_identity.secret_key(), ephemeral_public_key); let decrypted = crypt::decrypt(&shared_secret, &header.origin_mac)?; let authenticated_pk = Self::authenticate_message(&decrypted, body)?; + trace!( + target: LOG_TARGET, + "Attempting to decrypt message body ({} byte(s))", + body.len() + ); let decrypted_bytes = crypt::decrypt(&shared_secret, body)?; let envelope_body = EnvelopeBody::decode(decrypted_bytes.as_slice()).map_err(|_| StoreAndForwardError::DecryptionFailed)?; diff --git a/comms/dht/src/store_forward/service.rs b/comms/dht/src/store_forward/service.rs index fb41681f1b..e8dbb11a80 100644 --- a/comms/dht/src/store_forward/service.rs +++ b/comms/dht/src/store_forward/service.rs @@ -41,7 +41,10 @@ use futures::{ }; use log::*; use std::{convert::TryFrom, time::Duration}; -use tari_comms::types::CommsPublicKey; +use tari_comms::{ + peer_manager::{node_id::NodeDistance, NodeId}, + types::CommsPublicKey, +}; use tari_shutdown::ShutdownSignal; use tokio::time; @@ -53,16 +56,20 @@ const CLEANUP_INTERVAL: Duration = Duration::from_secs(10 * 60); // 10 mins #[derive(Debug, Clone)] pub struct FetchStoredMessageQuery { public_key: Box, + node_id: Box, since: Option>, + dist_threshold: Option>, response_type: SafResponseType, } impl FetchStoredMessageQuery { - pub fn new(public_key: Box) -> Self { + pub fn new(public_key: Box, node_id: Box) -> Self { Self { public_key, + node_id, since: None, - response_type: SafResponseType::General, + response_type: SafResponseType::Anonymous, + dist_threshold: None, } } @@ -75,6 +82,11 @@ impl FetchStoredMessageQuery { self.response_type = response_type; self } + + pub fn with_dist_threshold(&mut self, dist_threshold: Box) -> &mut Self { + self.dist_threshold = Some(dist_threshold); + self + } } #[derive(Debug)] @@ -218,25 +230,28 @@ impl StoreAndForwardService { query: FetchStoredMessageQuery, ) -> SafResult> { + use SafResponseType::*; let limit = i64::try_from(self.config.saf_max_returned_messages) .ok() .or(Some(std::i64::MAX)) .unwrap(); let messages = match query.response_type { - SafResponseType::General => { - db.find_messages_for_public_key(&query.public_key, query.since, limit) + ForMe => { + db.find_messages_for_peer(&query.public_key, &query.node_id, query.since, limit) .await? }, - SafResponseType::Join => { + Join => { db.find_messages_of_type_for_pubkey(&query.public_key, DhtMessageType::Join, query.since, limit) .await? }, - SafResponseType::Discovery => { + Discovery => { db.find_messages_of_type_for_pubkey(&query.public_key, DhtMessageType::Discovery, query.since, limit) .await? }, - SafResponseType::ExplicitlyAddressed => { - db.find_messages_for_public_key(&query.public_key, query.since, limit) + + Anonymous => db.find_anonymous_messages(query.since, limit).await?, + InRegion => { + db.find_regional_messages(&query.node_id, query.dist_threshold, query.since, limit) .await? }, }; diff --git a/comms/dht/src/store_forward/store.rs b/comms/dht/src/store_forward/store.rs index 0baa43f6f4..b95c8ee636 100644 --- a/comms/dht/src/store_forward/store.rs +++ b/comms/dht/src/store_forward/store.rs @@ -213,6 +213,11 @@ where S: Service return Ok(None); } + if message.dht_header.message_type.is_saf_message() { + log_not_eligible("because it is a SAF message"); + return Ok(None); + } + if message .authenticated_origin() .map(|pk| pk == self.node_identity.public_key()) @@ -330,13 +335,8 @@ where S: Service use NodeDestination::*; match &message.dht_header.destination { Unknown => { - // No destination provided, only discovery messages are currently important enough to be stored - if message.dht_header.message_type.is_dht_discovery() { - Ok(Some(StoredMessagePriority::Low)) - } else { - log_not_eligible("destination is unknown, and message is not a Discovery"); - Ok(None) - } + // No destination provided, + Ok(Some(StoredMessagePriority::Low)) }, PublicKey(dest_public_key) => { // If we know the destination peer, keep the message for them diff --git a/comms/dht/tests/dht.rs b/comms/dht/tests/dht.rs index 521a07f6f8..e5d2230a13 100644 --- a/comms/dht/tests/dht.rs +++ b/comms/dht/tests/dht.rs @@ -20,28 +20,37 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use futures::channel::mpsc; +use futures::{channel::mpsc, StreamExt}; use rand::rngs::OsRng; use std::{sync::Arc, time::Duration}; use tari_comms::{ backoff::ConstantBackoff, + message::MessageExt, peer_manager::{NodeIdentity, Peer, PeerFeatures, PeerStorage}, pipeline, pipeline::SinkService, transports::MemoryTransport, types::CommsDatabase, + wrap_in_envelope_body, CommsBuilder, CommsNode, }; -use tari_comms_dht::{inbound::DecryptedDhtMessage, Dht, DhtBuilder}; +use tari_comms_dht::{ + envelope::NodeDestination, + inbound::DecryptedDhtMessage, + outbound::{OutboundEncryption, SendMessageParams}, + Dht, + DhtBuilder, +}; use tari_storage::{lmdb_store::LMDBBuilder, LMDBWrapper}; use tari_test_utils::{async_assert_eventually, paths::create_temporary_data_path, random}; +use tokio::time; use tower::ServiceBuilder; struct TestNode { comms: CommsNode, dht: Dht, - _ims_rx: mpsc::Receiver, + ims_rx: mpsc::Receiver, } impl TestNode { @@ -52,6 +61,10 @@ impl TestNode { pub fn to_peer(&self) -> Peer { self.comms.node_identity().to_peer() } + + pub async fn next_inbound_message(&mut self, timeout: Duration) -> Option { + time::timeout(timeout, self.ims_rx.next()).await.ok()? + } } fn make_node_identity(features: PeerFeatures) -> Arc { @@ -81,15 +94,14 @@ fn create_peer_storage(peers: Vec) -> CommsDatabase { async fn make_node(features: PeerFeatures, seed_peer: Option) -> TestNode { let node_identity = make_node_identity(features); + make_node_with_node_identity(node_identity, seed_peer).await +} +async fn make_node_with_node_identity(node_identity: Arc, seed_peer: Option) -> TestNode { let (tx, ims_rx) = mpsc::channel(1); let (comms, dht) = setup_comms_dht(node_identity, create_peer_storage(seed_peer.into_iter().collect()), tx).await; - TestNode { - comms, - dht, - _ims_rx: ims_rx, - } + TestNode { comms, dht, ims_rx } } async fn setup_comms_dht( @@ -234,3 +246,68 @@ async fn dht_discover_propagation() { node_C.comms.shutdown().await; node_D.comms.shutdown().await; } + +#[tokio_macros::test] +#[allow(non_snake_case)] +async fn dht_store_forward() { + let node_C_node_identity = make_node_identity(PeerFeatures::COMMUNICATION_NODE); + // Node B knows about Node C + let node_B = make_node(PeerFeatures::COMMUNICATION_NODE, None).await; + // Node A knows about Node B + let node_A = make_node(PeerFeatures::COMMUNICATION_NODE, Some(node_B.to_peer())).await; + + let params = SendMessageParams::new() + .neighbours(vec![]) + .with_encryption(OutboundEncryption::EncryptFor(Box::new( + node_C_node_identity.public_key().clone(), + ))) + .with_destination(NodeDestination::Unknown) + .finish(); + + let secret_msg1 = b"NCZW VUSX PNYM INHZ XMQX SFWX WLKJ AHSH"; + let secret_msg2 = b"NMCO CCAK UQPM KCSM HKSE INJU SBLK"; + + let mut node_B_msg_events = node_B.comms.subscribe_messaging_events(); + node_A + .dht + .outbound_requester() + .send_raw( + params.clone(), + wrap_in_envelope_body!(secret_msg1.to_vec()).to_encoded_bytes(), + ) + .await + .unwrap(); + node_A + .dht + .outbound_requester() + .send_raw(params, wrap_in_envelope_body!(secret_msg2.to_vec()).to_encoded_bytes()) + .await + .unwrap(); + + // Wait for node B to receive the 2 propagation messages + node_B_msg_events.next().await.unwrap().unwrap(); + node_B_msg_events.next().await.unwrap().unwrap(); + + let mut node_C = make_node_with_node_identity(node_C_node_identity, None).await; + node_C.comms.peer_manager().add_peer(node_B.to_peer()).await.unwrap(); + node_C.dht.dht_requester().send_request_stored_messages().await.unwrap(); + + let msg = node_C.next_inbound_message(Duration::from_secs(5)).await.unwrap(); + assert_eq!( + msg.authenticated_origin.as_ref().unwrap(), + node_A.comms.node_identity().public_key() + ); + let secret = msg.success().unwrap().decode_part::>(0).unwrap().unwrap(); + assert_eq!(secret, secret_msg1.to_vec()); + let msg = node_C.next_inbound_message(Duration::from_secs(5)).await.unwrap(); + assert_eq!( + msg.authenticated_origin.as_ref().unwrap(), + node_A.comms.node_identity().public_key() + ); + let secret = msg.success().unwrap().decode_part::>(0).unwrap().unwrap(); + assert_eq!(secret, secret_msg2.to_vec()); + + node_A.comms.shutdown().await; + node_B.comms.shutdown().await; + node_C.comms.shutdown().await; +} diff --git a/comms/src/peer_manager/node_id.rs b/comms/src/peer_manager/node_id.rs index b1466803b0..f777ea4f4a 100644 --- a/comms/src/peer_manager/node_id.rs +++ b/comms/src/peer_manager/node_id.rs @@ -94,6 +94,21 @@ impl TryFrom<&[u8]> for NodeDistance { } } +impl ByteArray for NodeDistance { + /// Try and convert the given byte array to a NodeDistance. Any failures (incorrect array length, + /// implementation-specific checks, etc) return a [ByteArrayError](enum.ByteArrayError.html). + fn from_bytes(bytes: &[u8]) -> Result { + bytes + .try_into() + .map_err(|err| ByteArrayError::ConversionError(format!("{:?}", err))) + } + + /// Return the NodeId as a byte array + fn as_bytes(&self) -> &[u8] { + self.0.as_ref() + } +} + impl fmt::Display for NodeDistance { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{}", to_hex(&self.0))