diff --git a/applications/tari_base_node/src/bootstrap.rs b/applications/tari_base_node/src/bootstrap.rs index daee21c21b..29b6404638 100644 --- a/applications/tari_base_node/src/bootstrap.rs +++ b/applications/tari_base_node/src/bootstrap.rs @@ -55,7 +55,7 @@ use tari_p2p::{ auto_update::{AutoUpdateConfig, SoftwareUpdaterService}, comms_connector::pubsub_connector, initialization, - initialization::{CommsConfig, P2pInitializer}, + initialization::{P2pConfig, P2pInitializer}, peer_seeds::SeedPeer, services::liveness::{LivenessConfig, LivenessInitializer}, }; @@ -236,8 +236,8 @@ where B: BlockchainBackend + 'static comms.add_protocol_extension(rpc_server) } - fn create_comms_config(&self) -> CommsConfig { - CommsConfig { + fn create_comms_config(&self) -> P2pConfig { + P2pConfig { network: self.config.network, node_identity: self.node_identity.clone(), transport_type: create_transport_type(self.config), diff --git a/applications/tari_console_wallet/src/init/mod.rs b/applications/tari_console_wallet/src/init/mod.rs index df26c2e4d0..f720aee65d 100644 --- a/applications/tari_console_wallet/src/init/mod.rs +++ b/applications/tari_console_wallet/src/init/mod.rs @@ -38,7 +38,7 @@ use tari_comms_dht::{DbConnectionUrl, DhtConfig}; use tari_core::transactions::CryptoFactories; use tari_p2p::{ auto_update::AutoUpdateConfig, - initialization::CommsConfig, + initialization::P2pConfig, peer_seeds::SeedPeer, transport::TransportType::Tor, DEFAULT_DNS_NAME_SERVER, @@ -326,7 +326,7 @@ pub async fn init_wallet( _ => transport_type, }; - let comms_config = CommsConfig { + let comms_config = P2pConfig { network: config.network, node_identity, user_agent: format!("tari/wallet/{}", env!("CARGO_PKG_VERSION")), diff --git a/base_layer/p2p/src/initialization.rs b/base_layer/p2p/src/initialization.rs index 762a4bce33..68cc39c710 100644 --- a/base_layer/p2p/src/initialization.rs +++ b/base_layer/p2p/src/initialization.rs @@ -62,7 +62,7 @@ use tari_comms::{ PeerManager, UnspawnedCommsNode, }; -use tari_comms_dht::{Dht, DhtBuilder, DhtConfig, DhtInitializationError}; +use tari_comms_dht::{Dht, DhtConfig, DhtInitializationError, DhtProtocolVersion}; use tari_service_framework::{async_trait, ServiceInitializationError, ServiceInitializer, ServiceInitializerContext}; use tari_shutdown::ShutdownSignal; use tari_storage::{ @@ -112,7 +112,7 @@ impl CommsInitializationError { /// Configuration for a comms node #[derive(Clone)] -pub struct CommsConfig { +pub struct P2pConfig { /// Path to the LMDB data files. pub datastore_path: PathBuf, /// Name to use for the peer database @@ -202,17 +202,17 @@ pub async fn initialize_local_test_comms( // Create outbound channel let (outbound_tx, outbound_rx) = mpsc::channel(10); - let dht = DhtBuilder::new( - comms.node_identity(), - comms.peer_manager(), - outbound_tx, - comms.connectivity(), - comms.shutdown_signal(), - ) - .local_test() - .with_discovery_timeout(discovery_request_timeout) - .build() - .await?; + let dht = Dht::builder() + .local_test() + .with_outbound_sender(outbound_tx) + .with_discovery_timeout(discovery_request_timeout) + .build( + comms.node_identity(), + comms.peer_manager(), + comms.connectivity(), + comms.shutdown_signal(), + ) + .await?; let dht_outbound_layer = dht.outbound_middleware_layer(); let (event_sender, _) = broadcast::channel(100); @@ -316,7 +316,7 @@ async fn initialize_hidden_service( async fn configure_comms_and_dht( builder: CommsBuilder, - config: &CommsConfig, + config: &P2pConfig, connector: InboundDomainConnector, ) -> Result<(UnspawnedCommsNode, Dht), CommsInitializationError> { let file_lock = acquire_exclusive_file_lock(&config.datastore_path)?; @@ -352,16 +352,15 @@ async fn configure_comms_and_dht( // Create outbound channel let (outbound_tx, outbound_rx) = mpsc::channel(config.outbound_buffer_size); - let dht = DhtBuilder::new( - node_identity.clone(), - peer_manager, - outbound_tx, - connectivity, - shutdown_signal, - ) - .with_config(config.dht.clone()) - .build() - .await?; + let mut dht = Dht::builder(); + dht.with_config(config.dht.clone()).with_outbound_sender(outbound_tx); + // TODO: remove this once enough weatherwax nodes have upgraded + if config.network == Network::Weatherwax { + dht.with_protocol_version(DhtProtocolVersion::v1()); + } + let dht = dht + .build(node_identity.clone(), peer_manager, connectivity, shutdown_signal) + .await?; let dht_outbound_layer = dht.outbound_middleware_layer(); @@ -449,12 +448,12 @@ async fn add_all_peers( } pub struct P2pInitializer { - config: CommsConfig, + config: P2pConfig, connector: Option, } impl P2pInitializer { - pub fn new(config: CommsConfig, connector: PubsubDomainConnector) -> Self { + pub fn new(config: P2pConfig, connector: PubsubDomainConnector) -> Self { Self { config, connector: Some(connector), diff --git a/base_layer/p2p/src/services/liveness/service.rs b/base_layer/p2p/src/services/liveness/service.rs index f3b7d163f4..bc1cd3bbd5 100644 --- a/base_layer/p2p/src/services/liveness/service.rs +++ b/base_layer/p2p/src/services/liveness/service.rs @@ -322,6 +322,7 @@ mod test { use tari_comms_dht::{ envelope::{DhtMessageHeader, DhtMessageType}, outbound::{DhtOutboundRequest, MessageSendState, SendMessageResponse}, + DhtProtocolVersion, }; use tari_crypto::keys::PublicKey; use tari_service_framework::reply_channel; @@ -435,8 +436,7 @@ mod test { ); DomainMessage { dht_header: DhtMessageHeader { - major: 0, - minor: 0, + version: DhtProtocolVersion::latest(), destination: Default::default(), origin_mac: Vec::new(), ephemeral_public_key: None, diff --git a/base_layer/p2p/src/test_utils.rs b/base_layer/p2p/src/test_utils.rs index c44f4fa5af..fb1ced2945 100644 --- a/base_layer/p2p/src/test_utils.rs +++ b/base_layer/p2p/src/test_utils.rs @@ -30,6 +30,7 @@ use tari_comms::{ use tari_comms_dht::{ envelope::{DhtMessageFlags, DhtMessageHeader, DhtMessageType, NodeDestination}, inbound::DhtInboundMessage, + DhtProtocolVersion, }; macro_rules! unwrap_oms_send_msg { @@ -59,8 +60,7 @@ pub fn make_node_identity() -> Arc { pub fn make_dht_header(trace: MessageTag) -> DhtMessageHeader { DhtMessageHeader { - major: 0, - minor: 0, + version: DhtProtocolVersion::latest(), destination: NodeDestination::Unknown, origin_mac: Vec::new(), ephemeral_public_key: None, diff --git a/base_layer/wallet/src/config.rs b/base_layer/wallet/src/config.rs index 908e7bbc56..e176836453 100644 --- a/base_layer/wallet/src/config.rs +++ b/base_layer/wallet/src/config.rs @@ -23,7 +23,7 @@ use std::time::Duration; use tari_core::{consensus::NetworkConsensus, transactions::CryptoFactories}; -use tari_p2p::{auto_update::AutoUpdateConfig, initialization::CommsConfig}; +use tari_p2p::{auto_update::AutoUpdateConfig, initialization::P2pConfig}; use crate::{ base_node_service::config::BaseNodeServiceConfig, @@ -35,7 +35,7 @@ pub const KEY_MANAGER_COMMS_SECRET_KEY_BRANCH_KEY: &str = "comms"; #[derive(Clone)] pub struct WalletConfig { - pub comms_config: CommsConfig, + pub comms_config: P2pConfig, pub factories: CryptoFactories, pub transaction_service_config: Option, pub output_manager_service_config: Option, @@ -51,7 +51,7 @@ pub struct WalletConfig { impl WalletConfig { #[allow(clippy::too_many_arguments)] pub fn new( - comms_config: CommsConfig, + comms_config: P2pConfig, factories: CryptoFactories, transaction_service_config: Option, output_manager_service_config: Option, diff --git a/base_layer/wallet/tests/support/comms_and_services.rs b/base_layer/wallet/tests/support/comms_and_services.rs index 1b1243d72b..a3d7609bc8 100644 --- a/base_layer/wallet/tests/support/comms_and_services.rs +++ b/base_layer/wallet/tests/support/comms_and_services.rs @@ -29,7 +29,7 @@ use tari_comms::{ types::CommsPublicKey, CommsNode, }; -use tari_comms_dht::{envelope::DhtMessageHeader, Dht}; +use tari_comms_dht::{envelope::DhtMessageHeader, Dht, DhtProtocolVersion}; use tari_p2p::{ comms_connector::InboundDomainConnector, domain_message::DomainMessage, @@ -77,8 +77,7 @@ pub fn create_dummy_message(inner: T, public_key: &CommsPublicKey) -> DomainM ); DomainMessage { dht_header: DhtMessageHeader { - major: Default::default(), - minor: Default::default(), + version: DhtProtocolVersion::latest(), ephemeral_public_key: None, origin_mac: Vec::new(), message_type: Default::default(), diff --git a/base_layer/wallet/tests/wallet/mod.rs b/base_layer/wallet/tests/wallet/mod.rs index 3fc3e06086..ca846366ae 100644 --- a/base_layer/wallet/tests/wallet/mod.rs +++ b/base_layer/wallet/tests/wallet/mod.rs @@ -53,7 +53,7 @@ use tari_core::transactions::{ transaction::OutputFeatures, CryptoFactories, }; -use tari_p2p::{initialization::CommsConfig, transport::TransportType, Network, DEFAULT_DNS_NAME_SERVER}; +use tari_p2p::{initialization::P2pConfig, transport::TransportType, Network, DEFAULT_DNS_NAME_SERVER}; use tari_shutdown::{Shutdown, ShutdownSignal}; use tari_test_utils::random; use tari_wallet::{ @@ -105,7 +105,7 @@ async fn create_wallet( ) -> Result { const NETWORK: Network = Network::Weatherwax; let node_identity = NodeIdentity::random(&mut OsRng, get_next_memory_address(), PeerFeatures::COMMUNICATION_NODE); - let comms_config = CommsConfig { + let comms_config = P2pConfig { network: NETWORK, node_identity: Arc::new(node_identity.clone()), transport_type: TransportType::Memory { @@ -685,7 +685,7 @@ async fn test_import_utxo() { ); let temp_dir = tempdir().unwrap(); let (connection, _temp_dir) = make_wallet_database_connection(None); - let comms_config = CommsConfig { + let comms_config = P2pConfig { network: Network::Weatherwax, node_identity: Arc::new(alice_identity.clone()), transport_type: TransportType::Tcp { diff --git a/base_layer/wallet_ffi/src/lib.rs b/base_layer/wallet_ffi/src/lib.rs index a727412d3d..178b9a98be 100644 --- a/base_layer/wallet_ffi/src/lib.rs +++ b/base_layer/wallet_ffi/src/lib.rs @@ -210,7 +210,7 @@ const LOG_TARGET: &str = "wallet_ffi"; pub type TariTransportType = tari_p2p::transport::TransportType; pub type TariPublicKey = tari_comms::types::CommsPublicKey; pub type TariPrivateKey = tari_comms::types::CommsSecretKey; -pub type TariCommsConfig = tari_p2p::initialization::CommsConfig; +pub type TariCommsConfig = tari_p2p::initialization::P2pConfig; pub type TariTransactionKernel = tari_core::transactions::transaction::TransactionKernel; pub struct TariContacts(Vec); diff --git a/comms/dht/examples/memory_net/utilities.rs b/comms/dht/examples/memory_net/utilities.rs index d44db988ae..bb6cf8f55a 100644 --- a/comms/dht/examples/memory_net/utilities.rs +++ b/comms/dht/examples/memory_net/utilities.rs @@ -54,7 +54,6 @@ use tari_comms_dht::{ inbound::DecryptedDhtMessage, outbound::OutboundEncryption, Dht, - DhtBuilder, DhtConfig, }; use tari_shutdown::{Shutdown, ShutdownSignal}; @@ -911,26 +910,26 @@ async fn setup_comms_dht( comms.peer_manager().add_peer(peer).await.unwrap(); } - let dht = DhtBuilder::new( - comms.node_identity(), - comms.peer_manager(), - outbound_tx, - comms.connectivity(), - comms.shutdown_signal(), - ) - .with_config(DhtConfig { - saf_auto_request, - auto_join: false, - discovery_request_timeout: Duration::from_secs(15), - num_neighbouring_nodes, - num_random_nodes, - propagation_factor, - network_discovery: Default::default(), - ..DhtConfig::default_local_test() - }) - .build() - .await - .unwrap(); + let dht = Dht::builder() + .with_config(DhtConfig { + saf_auto_request, + auto_join: false, + discovery_request_timeout: Duration::from_secs(15), + num_neighbouring_nodes, + num_random_nodes, + propagation_factor, + network_discovery: Default::default(), + ..DhtConfig::default_local_test() + }) + .with_outbound_sender(outbound_tx) + .build( + comms.node_identity(), + comms.peer_manager(), + comms.connectivity(), + comms.shutdown_signal(), + ) + .await + .unwrap(); let dht_outbound_layer = dht.outbound_middleware_layer(); let pipeline = pipeline::Builder::new() diff --git a/comms/dht/src/builder.rs b/comms/dht/src/builder.rs index bd7fb2521c..a35da81ee0 100644 --- a/comms/dht/src/builder.rs +++ b/comms/dht/src/builder.rs @@ -20,131 +20,143 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::{dht::DhtInitializationError, outbound::DhtOutboundRequest, DbConnectionUrl, Dht, DhtConfig}; -use std::{sync::Arc, time::Duration}; -use tari_comms::{ - connectivity::ConnectivityRequester, - peer_manager::{NodeIdentity, PeerManager}, +use crate::{ + dht::DhtInitializationError, + outbound::DhtOutboundRequest, + version::DhtProtocolVersion, + DbConnectionUrl, + Dht, + DhtConfig, }; +use std::{sync::Arc, time::Duration}; +use tari_comms::{connectivity::ConnectivityRequester, NodeIdentity, PeerManager}; use tari_shutdown::ShutdownSignal; use tokio::sync::mpsc; +#[derive(Debug, Clone, Default)] pub struct DhtBuilder { - node_identity: Arc, - peer_manager: Arc, config: DhtConfig, - outbound_tx: mpsc::Sender, - connectivity: ConnectivityRequester, - shutdown_signal: ShutdownSignal, + outbound_tx: Option>, } impl DhtBuilder { - pub fn new( - node_identity: Arc, - peer_manager: Arc, - outbound_tx: mpsc::Sender, - connectivity: ConnectivityRequester, - shutdown_signal: ShutdownSignal, - ) -> Self { + pub fn new() -> Self { Self { #[cfg(test)] config: DhtConfig::default_local_test(), #[cfg(not(test))] config: Default::default(), - node_identity, - peer_manager, - outbound_tx, - connectivity, - shutdown_signal, + outbound_tx: None, } } - pub fn with_config(mut self, config: DhtConfig) -> Self { + pub fn with_config(&mut self, config: DhtConfig) -> &mut Self { self.config = config; self } - pub fn local_test(mut self) -> Self { + pub fn local_test(&mut self) -> &mut Self { self.config = DhtConfig::default_local_test(); self } - pub fn set_auto_store_and_forward_requests(mut self, enabled: bool) -> Self { + pub fn with_protocol_version(&mut self, protocol_version: DhtProtocolVersion) -> &mut Self { + self.config.protocol_version = protocol_version; + self + } + + pub fn set_auto_store_and_forward_requests(&mut self, enabled: bool) -> &mut Self { self.config.saf_auto_request = enabled; self } - pub fn testnet(mut self) -> Self { + pub fn with_outbound_sender(&mut self, outbound_tx: mpsc::Sender) -> &mut Self { + self.outbound_tx = Some(outbound_tx); + self + } + + pub fn testnet(&mut self) -> &mut Self { self.config = DhtConfig::default_testnet(); self } - pub fn mainnet(mut self) -> Self { + pub fn mainnet(&mut self) -> &mut Self { self.config = DhtConfig::default_mainnet(); self } - pub fn with_database_url(mut self, database_url: DbConnectionUrl) -> Self { + pub fn with_database_url(&mut self, database_url: DbConnectionUrl) -> &mut Self { self.config.database_url = database_url; self } - pub fn with_dedup_cache_trim_interval(mut self, trim_interval: Duration) -> Self { + pub fn with_dedup_cache_trim_interval(&mut self, trim_interval: Duration) -> &mut Self { self.config.dedup_cache_trim_interval = trim_interval; self } - pub fn with_dedup_cache_capacity(mut self, capacity: usize) -> Self { + pub fn with_dedup_cache_capacity(&mut self, capacity: usize) -> &mut Self { self.config.dedup_cache_capacity = capacity; self } - pub fn with_dedup_discard_hit_count(mut self, max_hit_count: usize) -> Self { + pub fn with_dedup_discard_hit_count(&mut self, max_hit_count: usize) -> &mut Self { self.config.dedup_allowed_message_occurrences = max_hit_count; self } - pub fn with_num_random_nodes(mut self, n: usize) -> Self { + pub fn with_num_random_nodes(&mut self, n: usize) -> &mut Self { self.config.num_random_nodes = n; self } - pub fn with_num_neighbouring_nodes(mut self, n: usize) -> Self { + pub fn with_num_neighbouring_nodes(&mut self, n: usize) -> &mut Self { self.config.num_neighbouring_nodes = n; self } - pub fn with_propagation_factor(mut self, propagation_factor: usize) -> Self { + pub fn with_propagation_factor(&mut self, propagation_factor: usize) -> &mut Self { self.config.propagation_factor = propagation_factor; self } - pub fn with_broadcast_factor(mut self, broadcast_factor: usize) -> Self { + pub fn with_broadcast_factor(&mut self, broadcast_factor: usize) -> &mut Self { self.config.broadcast_factor = broadcast_factor; self } - pub fn with_discovery_timeout(mut self, timeout: Duration) -> Self { + pub fn with_discovery_timeout(&mut self, timeout: Duration) -> &mut Self { self.config.discovery_request_timeout = timeout; self } - pub fn enable_auto_join(mut self) -> Self { + pub fn enable_auto_join(&mut self) -> &mut Self { self.config.auto_join = true; self } /// Build and initialize a Dht object. /// - /// Will panic not in a tokio runtime context - pub async fn build(self) -> Result { + /// Will panic if not in a tokio runtime context + pub async fn build( + &mut self, + node_identity: Arc, + peer_manager: Arc, + connectivity: ConnectivityRequester, + shutdown_signal: ShutdownSignal, + ) -> Result { + let outbound_tx = self + .outbound_tx + .take() + .ok_or(DhtInitializationError::BuilderNoOutboundMessageSender)?; + Dht::initialize( - self.config, - self.node_identity, - self.peer_manager, - self.outbound_tx, - self.connectivity, - self.shutdown_signal, + self.config.clone(), + node_identity, + peer_manager, + outbound_tx, + connectivity, + shutdown_signal, ) .await } diff --git a/comms/dht/src/config.rs b/comms/dht/src/config.rs index 7e6fe21aa6..c38d0b5cb0 100644 --- a/comms/dht/src/config.rs +++ b/comms/dht/src/config.rs @@ -20,11 +20,13 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::{network_discovery::NetworkDiscoveryConfig, storage::DbConnectionUrl}; +use crate::{network_discovery::NetworkDiscoveryConfig, storage::DbConnectionUrl, version::DhtProtocolVersion}; use std::time::Duration; #[derive(Debug, Clone)] pub struct DhtConfig { + /// The major protocol version to use. Default: DhtProtocolVersion::latest() + pub protocol_version: DhtProtocolVersion, /// The `DbConnectionUrl` for the Dht database. Default: In-memory database pub database_url: DbConnectionUrl, /// The size of the buffer (channel) which holds pending outbound message requests. @@ -142,6 +144,7 @@ impl Default for DhtConfig { fn default() -> Self { // NB: please remember to update field comments to reflect these defaults Self { + protocol_version: DhtProtocolVersion::latest(), num_neighbouring_nodes: 8, num_random_nodes: 4, propagation_factor: 4, diff --git a/comms/dht/src/consts.rs b/comms/dht/src/consts.rs deleted file mode 100644 index d34232cf17..0000000000 --- a/comms/dht/src/consts.rs +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright 2019, The Tari Project -// -// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the -// following conditions are met: -// -// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following -// disclaimer. -// -// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the -// following disclaimer in the documentation and/or other materials provided with the distribution. -// -// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote -// products derived from this software without specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, -// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, -// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE -// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -/// Version for DHT envelope -pub const DHT_MAJOR_VERSION: u32 = 1; -pub const DHT_MINOR_VERSION: u32 = 0; diff --git a/comms/dht/src/crypt.rs b/comms/dht/src/crypt.rs index 18e81ba974..f0987907c9 100644 --- a/comms/dht/src/crypt.rs +++ b/comms/dht/src/crypt.rs @@ -20,19 +20,24 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::outbound::DhtOutboundError; +use crate::{ + envelope::{DhtMessageFlags, DhtMessageHeader, DhtMessageType, NodeDestination}, + outbound::DhtOutboundError, + version::DhtProtocolVersion, +}; use chacha20::{ cipher::{NewCipher, StreamCipher}, ChaCha20, Key, Nonce, }; +use digest::Digest; use rand::{rngs::OsRng, RngCore}; use std::mem::size_of; -use tari_comms::types::CommsPublicKey; +use tari_comms::types::{Challenge, CommsPublicKey}; use tari_crypto::{ keys::{DiffieHellmanSharedSecret, PublicKey}, - tari_utilities::ByteArray, + tari_utilities::{epoch_time::EpochTime, ByteArray}, }; pub fn generate_ecdh_secret(secret_key: &PK::K, public_key: &PK) -> PK @@ -78,6 +83,45 @@ pub fn encrypt(cipher_key: &CommsPublicKey, plain_text: &[u8]) -> Result Ok(ciphertext_integral_nonce) } +pub fn create_origin_mac_challenge(header: &DhtMessageHeader, body: &[u8]) -> Challenge { + create_origin_mac_challenge_parts( + header.version, + &header.destination, + &header.message_type, + header.flags, + header.expires, + header.ephemeral_public_key.as_ref(), + body, + ) +} + +pub fn create_origin_mac_challenge_parts( + protocol_version: DhtProtocolVersion, + destination: &NodeDestination, + message_type: &DhtMessageType, + flags: DhtMessageFlags, + expires: Option, + ephemeral_public_key: Option<&CommsPublicKey>, + body: &[u8], +) -> Challenge { + let mut mac_challenge = Challenge::new(); + // TODO: #testnetreset remove conditional + if protocol_version.as_major() > 1 { + mac_challenge.update(&protocol_version.to_bytes()); + mac_challenge.update(destination.to_inner_bytes().as_slice()); + mac_challenge.update(&(*message_type as i32).to_le_bytes()); + mac_challenge.update(&flags.bits().to_le_bytes()); + if let Some(t) = expires { + mac_challenge.update(&t.as_u64().to_le_bytes()); + } + if let Some(e_pk) = ephemeral_public_key.as_ref() { + mac_challenge.update(e_pk.as_bytes()); + } + } + mac_challenge.update(&body); + mac_challenge +} + #[cfg(test)] mod test { use super::*; diff --git a/comms/dht/src/dht.rs b/comms/dht/src/dht.rs index b78379dbba..1b762e02bc 100644 --- a/comms/dht/src/dht.rs +++ b/comms/dht/src/dht.rs @@ -40,6 +40,7 @@ use crate::{ store_forward::{StoreAndForwardError, StoreAndForwardRequest, StoreAndForwardRequester, StoreAndForwardService}, DedupLayer, DhtActorError, + DhtBuilder, DhtConfig, }; use futures::Future; @@ -71,6 +72,8 @@ pub enum DhtInitializationError { StoreAndForwardInitializationError(#[from] StoreAndForwardError), #[error("DhtActorInitializationError: {0}")] DhtActorInitializationError(#[from] DhtActorError), + #[error("Builder error: no outbound message sender set")] + BuilderNoOutboundMessageSender, } /// Responsible for starting the DHT actor, building the DHT middleware stack and as a factory @@ -102,7 +105,7 @@ pub struct Dht { } impl Dht { - pub async fn initialize( + pub(crate) async fn initialize( config: DhtConfig, node_identity: Arc, peer_manager: Arc, @@ -153,6 +156,10 @@ impl Dht { Ok(dht) } + pub fn builder() -> DhtBuilder { + DhtBuilder::new() + } + /// Create a DHT RPC service pub fn rpc_service(&self) -> rpc::DhtService { rpc::DhtService::new(rpc::DhtRpcServiceImpl::new(self.peer_manager.clone())) @@ -353,7 +360,7 @@ impl Dht { Arc::clone(&self.node_identity), self.dht_requester(), self.discovery_service_requester(), - self.config.saf_msg_validity, + &self.config, )) .layer(MessageLoggingLayer::new(format!( "Outbound [{}]", @@ -421,6 +428,7 @@ fn filter_messages_to_rebroadcast(msg: &DecryptedDhtMessage) -> bool { #[cfg(test)] mod test { + use super::*; use crate::{ crypt, envelope::DhtMessageFlags, @@ -434,7 +442,6 @@ mod test { make_node_identity, service_spy, }, - DhtBuilder, }; use std::{sync::Arc, time::Duration}; use tari_comms::{ @@ -460,17 +467,17 @@ mod test { let (out_tx, _) = mpsc::channel(10); let shutdown = Shutdown::new(); - let dht = DhtBuilder::new( - Arc::clone(&node_identity), - peer_manager, - out_tx, - connectivity, - shutdown.to_signal(), - ) - .local_test() - .build() - .await - .unwrap(); + let dht = Dht::builder() + .local_test() + .with_outbound_sender(out_tx) + .build( + Arc::clone(&node_identity), + peer_manager, + connectivity, + shutdown.to_signal(), + ) + .await + .unwrap(); let (out_tx, mut out_rx) = mpsc::channel(10); @@ -511,16 +518,16 @@ mod test { let (out_tx, _out_rx) = mpsc::channel(10); let shutdown = Shutdown::new(); - let dht = DhtBuilder::new( - Arc::clone(&node_identity), - peer_manager, - out_tx, - connectivity, - shutdown.to_signal(), - ) - .build() - .await - .unwrap(); + let dht = Dht::builder() + .with_outbound_sender(out_tx) + .build( + Arc::clone(&node_identity), + peer_manager, + connectivity, + shutdown.to_signal(), + ) + .await + .unwrap(); let (out_tx, mut out_rx) = mpsc::channel(10); @@ -563,16 +570,16 @@ mod test { let (oms_requester, oms_mock) = create_outbound_service_mock(1); // Send all outbound requests to the mock - let dht = DhtBuilder::new( - Arc::clone(&node_identity), - peer_manager, - oms_requester.get_mpsc_sender(), - connectivity, - shutdown.to_signal(), - ) - .build() - .await - .unwrap(); + let dht = Dht::builder() + .with_outbound_sender(oms_requester.get_mpsc_sender()) + .build( + Arc::clone(&node_identity), + peer_manager, + connectivity, + shutdown.to_signal(), + ) + .await + .unwrap(); let oms_mock_state = oms_mock.get_state(); task::spawn(oms_mock.run()); @@ -622,16 +629,16 @@ mod test { let (out_tx, _) = mpsc::channel(10); let shutdown = Shutdown::new(); - let dht = DhtBuilder::new( - Arc::clone(&node_identity), - peer_manager, - out_tx, - connectivity, - shutdown.to_signal(), - ) - .build() - .await - .unwrap(); + let dht = Dht::builder() + .with_outbound_sender(out_tx) + .build( + Arc::clone(&node_identity), + peer_manager, + connectivity, + shutdown.to_signal(), + ) + .await + .unwrap(); let spy = service_spy(); let mut service = dht.inbound_middleware_layer().layer(spy.to_service()); diff --git a/comms/dht/src/discovery/service.rs b/comms/dht/src/discovery/service.rs index 2cebf571b4..a07d96002f 100644 --- a/comms/dht/src/discovery/service.rs +++ b/comms/dht/src/discovery/service.rs @@ -29,11 +29,7 @@ use crate::{ }; use log::*; use rand::{rngs::OsRng, RngCore}; -use std::{ - collections::HashMap, - sync::Arc, - time::{Duration, Instant}, -}; +use std::{collections::HashMap, sync::Arc, time::Instant}; use tari_comms::{ log_if_error, peer_manager::{NodeId, NodeIdentity, Peer, PeerFeatures, PeerManager}, @@ -45,7 +41,6 @@ use tari_utilities::{hex::Hex, ByteArray}; use tokio::{ sync::{mpsc, oneshot}, task, - time, }; const LOG_TARGET: &str = "comms::dht::discovery_service"; @@ -327,8 +322,7 @@ impl DhtDiscoveryService { "Sending Discovery message for peer public key '{}' with destination {}", dest_public_key, destination ); - let send_states = self - .outbound_requester + self.outbound_requester .send_message_no_header( SendMessageParams::new() .broadcast(Vec::new()) @@ -343,32 +337,6 @@ impl DhtDiscoveryService { .await .map_err(DhtDiscoveryError::DiscoverySendFailed)?; - // Spawn a task to log how the sending of discovery went - task::spawn(async move { - debug!( - target: LOG_TARGET, - "Discovery sent to {} peer(s). Waiting to see how many got through.", - send_states.len() - ); - let result = time::timeout(Duration::from_secs(10), send_states.wait_percentage_success(0.51)).await; - match result { - Ok((succeeded, failed)) => { - let num_succeeded = succeeded.len(); - let num_failed = failed.len(); - - debug!( - target: LOG_TARGET, - "Discovery sent to a majority of neighbouring peers ({} succeeded, {} failed)", - num_succeeded, - num_failed - ); - }, - Err(_) => { - warn!(target: LOG_TARGET, "Failed to send discovery to a majority of peers"); - }, - } - }); - Ok(()) } } diff --git a/comms/dht/src/envelope.rs b/comms/dht/src/envelope.rs index c8b18d9e9b..4dc31f999e 100644 --- a/comms/dht/src/envelope.rs +++ b/comms/dht/src/envelope.rs @@ -37,6 +37,7 @@ use thiserror::Error; // Re-export applicable protos pub use crate::proto::envelope::{dht_header::Destination, DhtEnvelope, DhtHeader, DhtMessageType}; +use crate::version::DhtProtocolVersion; /// Utility function that converts a `chrono::DateTime` to a `prost::Timestamp` pub(crate) fn datetime_to_timestamp(datetime: DateTime) -> Timestamp { @@ -70,6 +71,8 @@ pub enum DhtMessageError { InvalidOrigin, #[error("Invalid or unrecognised DHT message type")] InvalidMessageType, + #[error("Invalid or unsupported DHT protocol version {0}")] + InvalidProtocolVersion(u32), #[error("Invalid or unrecognised network type")] InvalidNetwork, #[error("Invalid or unrecognised DHT message flags")] @@ -131,8 +134,7 @@ impl DhtMessageType { /// It is preferable to not to expose the generated prost structs publicly. #[derive(Clone, Debug, PartialEq, Eq)] pub struct DhtMessageHeader { - pub major: u32, - pub minor: u32, + pub version: DhtProtocolVersion, pub destination: NodeDestination, /// Encoded DhtOrigin. This can refer to the same peer that sent the message /// or another peer if the message is being propagated. @@ -158,8 +160,8 @@ impl Display for DhtMessageHeader { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { write!( f, - "DhtMessageHeader (Dest:{}, Type:{:?}, Flags:{:?}, Trace:{})", - self.destination, self.message_type, self.flags, self.message_tag + "DhtMessageHeader ({}, Dest:{}, Type:{:?}, Flags:{:?}, Trace:{})", + self.version, self.destination, self.message_type, self.flags, self.message_tag ) } } @@ -184,10 +186,10 @@ impl TryFrom for DhtMessageHeader { }; let expires: Option> = header.expires.map(timestamp_to_datetime); + let version = DhtProtocolVersion::try_from((header.major, header.minor))?; Ok(Self { - major: header.major, - minor: header.minor, + version, destination, origin_mac: header.origin_mac, ephemeral_public_key, @@ -214,8 +216,8 @@ impl From for DhtHeader { fn from(header: DhtMessageHeader) -> Self { let expires = header.expires.map(epochtime_to_datetime); Self { - major: header.major, - minor: header.minor, + major: header.version.as_major(), + minor: header.version.as_minor(), ephemeral_public_key: header .ephemeral_public_key .as_ref() diff --git a/comms/dht/src/inbound/decryption.rs b/comms/dht/src/inbound/decryption.rs index 65fd618f7f..08a2e61758 100644 --- a/comms/dht/src/inbound/decryption.rs +++ b/comms/dht/src/inbound/decryption.rs @@ -22,7 +22,7 @@ use crate::{ crypt, - envelope::{DhtMessageFlags, DhtMessageHeader}, + envelope::DhtMessageHeader, inbound::message::{DecryptedDhtMessage, DhtInboundMessage}, proto::envelope::OriginMac, DhtConfig, @@ -36,7 +36,7 @@ use tari_comms::{ message::EnvelopeBody, peer_manager::NodeIdentity, pipeline::PipelineError, - types::CommsPublicKey, + types::{Challenge, CommsPublicKey}, utils::signature, }; use tari_utilities::ByteArray; @@ -161,7 +161,10 @@ where S: Service let trace_id = message.dht_header.message_tag; let tag = message.tag; match Self::validate_and_decrypt_message(node_identity, message).await { - Ok(msg) => next_service.oneshot(msg).await, + Ok(msg) => { + trace!(target: LOG_TARGET, "Passing onto next service (Trace: {})", msg.tag); + next_service.oneshot(msg).await + }, Err(err @ OriginMacNotProvided) | Err(err @ EphemeralKeyNotProvided) | @@ -193,18 +196,10 @@ where S: Service ) -> Result { let dht_header = &message.dht_header; - let mut header_mac_bytes = Vec::with_capacity(256); - header_mac_bytes.extend_from_slice(&dht_header.major.to_le_bytes()); - header_mac_bytes.extend_from_slice(&dht_header.minor.to_le_bytes()); - header_mac_bytes.extend_from_slice(dht_header.destination.to_inner_bytes().as_slice()); - header_mac_bytes.extend_from_slice(&(dht_header.message_type as i32).to_le_bytes()); - header_mac_bytes.extend_from_slice(&dht_header.flags.bits().to_le_bytes()); - if let Some(t) = dht_header.expires { - header_mac_bytes.extend_from_slice(&t.as_u64().to_le_bytes()); - } + let mac_challenge = crypt::create_origin_mac_challenge(dht_header, &message.body); - if !dht_header.flags.contains(DhtMessageFlags::ENCRYPTED) { - return Self::success_not_encrypted(message, header_mac_bytes).await; + if !dht_header.flags.is_encrypted() { + return Self::success_not_encrypted(message, mac_challenge).await; } trace!( target: LOG_TARGET, @@ -224,11 +219,9 @@ where S: Service // Decrypt and verify the origin let authenticated_origin = match Self::attempt_decrypt_origin_mac(&shared_secret, dht_header) { Ok((public_key, signature)) => { - header_mac_bytes.extend_from_slice(e_pk.as_bytes()); - // If this fails, discard the message because we decrypted and deserialized the message with our shared // ECDH secret but the message could not be authenticated - Self::authenticate_origin_mac(&public_key, &signature, header_mac_bytes.as_slice(), &message.body)?; + Self::authenticate_origin_mac(&public_key, &signature, mac_challenge)?; public_key }, Err(err) => { @@ -319,11 +312,9 @@ where S: Service fn authenticate_origin_mac( public_key: &CommsPublicKey, signature: &[u8], - mac_header: &[u8], - body: &[u8], + challenge: Challenge, ) -> Result<(), DecryptionError> { - let mac_body = [mac_header, body].concat(); - if signature::verify(public_key, signature, mac_body) { + if signature::verify_challenge(public_key, signature, challenge) { Ok(()) } else { Err(DecryptionError::OriginMacInvalidSignature) @@ -366,7 +357,7 @@ where S: Service async fn success_not_encrypted( message: DhtInboundMessage, - header_mac_bytes: Vec, + mac_challenge: Challenge, ) -> Result { let authenticated_pk = if message.dht_header.origin_mac.is_empty() { None @@ -376,12 +367,7 @@ where S: Service let public_key = CommsPublicKey::from_bytes(&origin_mac.public_key) .map_err(|_| DecryptionError::OriginMacInvalidPublicKey)?; - Self::authenticate_origin_mac( - &public_key, - &origin_mac.signature, - header_mac_bytes.as_slice(), - &message.body, - )?; + Self::authenticate_origin_mac(&public_key, &origin_mac.signature, mac_challenge)?; Some(public_key) }; diff --git a/comms/dht/src/inbound/message.rs b/comms/dht/src/inbound/message.rs index c9cdd103fd..a2f75b755f 100644 --- a/comms/dht/src/inbound/message.rs +++ b/comms/dht/src/inbound/message.rs @@ -20,10 +20,7 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::{ - consts::DHT_MAJOR_VERSION, - envelope::{DhtMessageFlags, DhtMessageHeader}, -}; +use crate::envelope::{DhtMessageFlags, DhtMessageHeader}; use std::{ fmt, fmt::{Display, Formatter}, @@ -38,7 +35,6 @@ use tari_comms::{ #[derive(Debug, Clone)] pub struct DhtInboundMessage { pub tag: MessageTag, - pub version: u32, pub source_peer: Arc, pub dht_header: DhtMessageHeader, /// True if forwarded via store and forward, otherwise false @@ -50,7 +46,6 @@ impl DhtInboundMessage { pub fn new(tag: MessageTag, dht_header: DhtMessageHeader, source_peer: Arc, body: Vec) -> Self { Self { tag, - version: DHT_MAJOR_VERSION, dht_header, source_peer, is_saf_message: false, @@ -68,8 +63,8 @@ impl Display for DhtInboundMessage { self.body.len(), self.dht_header.message_type, self.source_peer, - self.dht_header, self.dedup_hit_count, + self.dht_header, self.tag, ) } @@ -79,7 +74,6 @@ impl Display for DhtInboundMessage { #[derive(Debug, Clone)] pub struct DecryptedDhtMessage { pub tag: MessageTag, - pub version: u32, /// The _connected_ peer which sent or forwarded this message. This may not be the peer /// which created this message. pub source_peer: Arc, @@ -97,6 +91,10 @@ impl DecryptedDhtMessage { pub fn is_duplicate(&self) -> bool { self.dedup_hit_count > 1 } + + pub fn major_version(&self) -> u32 { + self.dht_header.version.as_major() + } } impl DecryptedDhtMessage { @@ -107,7 +105,6 @@ impl DecryptedDhtMessage { ) -> Self { Self { tag: message.tag, - version: message.version, source_peer: message.source_peer, authenticated_origin, dht_header: message.dht_header, @@ -122,7 +119,6 @@ impl DecryptedDhtMessage { pub fn failed(message: DhtInboundMessage) -> Self { Self { tag: message.tag, - version: message.version, source_peer: message.source_peer, authenticated_origin: None, dht_header: message.dht_header, @@ -193,7 +189,7 @@ impl Display for DecryptedDhtMessage { f, "version = {}, origin = {}, decryption_result = {}, header = ({}), is_saf_message = {}, is_saf_stored = \ {:?}, source_peer = {}, tag = {}", - self.version, + self.major_version(), self.authenticated_origin .as_ref() .map(ToString::to_string) diff --git a/comms/dht/src/lib.rs b/comms/dht/src/lib.rs index 710b354f7b..e4a078da50 100644 --- a/comms/dht/src/lib.rs +++ b/comms/dht/src/lib.rs @@ -88,7 +88,7 @@ //! .build() //! .unwrap(); //! let peer_manager = comms.start().unwrap().peer_manager(); -//! let dht = DhtBuilder::new(node_identity, peer_manager).finish(); +//! let dht = Dht::builder().build(node_identity, peer_manager)?; //! //! let inbound_pipeline = ServicePipeline::new( //! comms_in_rx, @@ -135,7 +135,6 @@ pub use connectivity::MetricsCollectorHandle; mod config; pub use config::DhtConfig; -mod consts; mod crypt; mod dht; @@ -159,6 +158,9 @@ mod proto; mod rpc; mod schema; +mod version; +pub use version::DhtProtocolVersion; + pub mod broadcast_strategy; pub mod domain_message; pub mod envelope; diff --git a/comms/dht/src/outbound/broadcast.rs b/comms/dht/src/outbound/broadcast.rs index 75c271755a..46b398f21e 100644 --- a/comms/dht/src/outbound/broadcast.rs +++ b/comms/dht/src/outbound/broadcast.rs @@ -24,7 +24,6 @@ use super::{error::DhtOutboundError, message::DhtOutboundRequest}; use crate::{ actor::DhtRequester, broadcast_strategy::BroadcastStrategy, - consts::{DHT_MAJOR_VERSION, DHT_MINOR_VERSION}, crypt, discovery::DhtDiscoveryRequester, envelope::{datetime_to_epochtime, datetime_to_timestamp, DhtMessageFlags, DhtMessageHeader, NodeDestination}, @@ -35,6 +34,8 @@ use crate::{ SendMessageResponse, }, proto::envelope::{DhtMessageType, OriginMac}, + version::DhtProtocolVersion, + DhtConfig, }; use bytes::Bytes; use chrono::{DateTime, Utc}; @@ -47,7 +48,7 @@ use futures::{ }; use log::*; use rand::rngs::OsRng; -use std::{sync::Arc, task::Poll, time::Duration}; +use std::{sync::Arc, task::Poll}; use tari_comms::{ message::{MessageExt, MessageTag}, peer_manager::{NodeId, NodeIdentity, Peer}, @@ -70,6 +71,7 @@ pub struct BroadcastLayer { dht_discovery_requester: DhtDiscoveryRequester, node_identity: Arc, message_validity_window: chrono::Duration, + protocol_version: DhtProtocolVersion, } impl BroadcastLayer { @@ -77,14 +79,15 @@ impl BroadcastLayer { node_identity: Arc, dht_requester: DhtRequester, dht_discovery_requester: DhtDiscoveryRequester, - message_validity_window: Duration, + config: &DhtConfig, ) -> Self { BroadcastLayer { dht_requester, dht_discovery_requester, node_identity, - message_validity_window: chrono::Duration::from_std(message_validity_window) + message_validity_window: chrono::Duration::from_std(config.saf_msg_validity) .expect("message_validity_window is too large"), + protocol_version: config.protocol_version, } } } @@ -99,6 +102,7 @@ impl Layer for BroadcastLayer { self.dht_requester.clone(), self.dht_discovery_requester.clone(), self.message_validity_window, + self.protocol_version, ) } } @@ -112,6 +116,7 @@ pub struct BroadcastMiddleware { dht_discovery_requester: DhtDiscoveryRequester, node_identity: Arc, message_validity_window: chrono::Duration, + protocol_version: DhtProtocolVersion, } impl BroadcastMiddleware { @@ -121,6 +126,7 @@ impl BroadcastMiddleware { dht_requester: DhtRequester, dht_discovery_requester: DhtDiscoveryRequester, message_validity_window: chrono::Duration, + protocol_version: DhtProtocolVersion, ) -> Self { Self { next_service: service, @@ -128,6 +134,7 @@ impl BroadcastMiddleware { dht_discovery_requester, node_identity, message_validity_window, + protocol_version, } } } @@ -154,6 +161,7 @@ where self.dht_discovery_requester.clone(), msg, self.message_validity_window, + self.protocol_version, ) .handle(), ) @@ -167,6 +175,7 @@ struct BroadcastTask { dht_discovery_requester: DhtDiscoveryRequester, request: Option, message_validity_window: chrono::Duration, + protocol_version: DhtProtocolVersion, } type FinalMessageParts = (Option>, Option, Bytes); @@ -180,6 +189,7 @@ where S: Service dht_discovery_requester: DhtDiscoveryRequester, request: DhtOutboundRequest, message_validity_window: chrono::Duration, + protocol_version: DhtProtocolVersion, ) -> Self { Self { service, @@ -188,6 +198,7 @@ where S: Service dht_discovery_requester, request: Some(request), message_validity_window, + protocol_version, } } @@ -424,8 +435,8 @@ where S: Service force_origin, &destination, &dht_message_type, - &dht_flags, - expires_epochtime.as_ref(), + dht_flags, + expires_epochtime, body, )?; @@ -441,6 +452,7 @@ where S: Service let send_state = MessageSendState::new(tag, reply_rx); ( DhtOutboundMessage { + protocol_version: self.protocol_version, tag, destination_node_id: node_id, destination: destination.clone(), @@ -491,8 +503,8 @@ where S: Service include_origin: bool, destination: &NodeDestination, message_type: &DhtMessageType, - flags: &DhtMessageFlags, - expires: Option<&EpochTime>, + flags: DhtMessageFlags, + expires: Option, body: Bytes, ) -> Result { match encryption { @@ -504,18 +516,17 @@ where S: Service // Encrypt the message with the body let encrypted_body = crypt::encrypt(&shared_ephemeral_secret, &body)?; - let mut header_mac_bytes = Vec::with_capacity(256); - header_mac_bytes.extend_from_slice(&DHT_MAJOR_VERSION.to_le_bytes()); - header_mac_bytes.extend_from_slice(&DHT_MINOR_VERSION.to_le_bytes()); - header_mac_bytes.extend_from_slice(destination.to_inner_bytes().as_slice()); - header_mac_bytes.extend_from_slice(&(*message_type as i32).to_le_bytes()); - header_mac_bytes.extend_from_slice(&flags.bits().to_le_bytes()); - if let Some(t) = expires { - header_mac_bytes.extend_from_slice(&t.as_u64().to_le_bytes()); - } - header_mac_bytes.extend_from_slice(e_pk.as_bytes()); + let mac_challenge = crypt::create_origin_mac_challenge_parts( + self.protocol_version, + &destination, + message_type, + flags, + expires, + Some(&e_pk), + &encrypted_body, + ); // Sign the encrypted message - let origin_mac = create_origin_mac(&self.node_identity, header_mac_bytes.as_slice(), &encrypted_body)?; + let origin_mac = create_origin_mac(&self.node_identity, mac_challenge)?; // Encrypt and set the origin field let encrypted_origin_mac = crypt::encrypt(&shared_ephemeral_secret, &origin_mac)?; Ok(( @@ -528,16 +539,16 @@ where S: Service trace!(target: LOG_TARGET, "Encryption not requested for message"); if include_origin { - let mut header_mac_bytes = Vec::with_capacity(256); - header_mac_bytes.extend_from_slice(&DHT_MAJOR_VERSION.to_le_bytes()); - header_mac_bytes.extend_from_slice(&DHT_MINOR_VERSION.to_le_bytes()); - header_mac_bytes.extend_from_slice(destination.to_inner_bytes().as_slice()); - header_mac_bytes.extend_from_slice(&(*message_type as i32).to_le_bytes()); - header_mac_bytes.extend_from_slice(&flags.bits().to_le_bytes()); - if let Some(t) = expires { - header_mac_bytes.extend_from_slice(&t.as_u64().to_le_bytes()); - } - let origin_mac = create_origin_mac(&self.node_identity, &header_mac_bytes, &body)?; + let mac_challenge = crypt::create_origin_mac_challenge_parts( + self.protocol_version, + destination, + message_type, + flags, + expires, + None, + &body, + ); + let origin_mac = create_origin_mac(&self.node_identity, mac_challenge)?; Ok((None, Some(origin_mac.into()), body)) } else { Ok((None, None, body)) @@ -547,15 +558,8 @@ where S: Service } } -fn create_origin_mac( - node_identity: &NodeIdentity, - mac_header: &[u8], - body: &[u8], -) -> Result, DhtOutboundError> { - let mac_body = [mac_header, body].concat(); - - let signature = signature::sign(&mut OsRng, node_identity.secret_key().clone(), mac_body)?; - +fn create_origin_mac(node_identity: &NodeIdentity, mac_challenge: Challenge) -> Result, DhtOutboundError> { + let signature = signature::sign_challenge(&mut OsRng, node_identity.secret_key().clone(), mac_challenge)?; let mac = OriginMac { public_key: node_identity.public_key().to_vec(), signature: signature.to_binary()?, @@ -632,6 +636,7 @@ mod test { dht_requester, dht_discover_requester, chrono::Duration::seconds(10800), + DhtProtocolVersion::latest(), ); assert_send_static_service(&service); let (reply_tx, _reply_rx) = oneshot::channel(); @@ -674,6 +679,7 @@ mod test { dht_requester, dht_discover_requester, chrono::Duration::seconds(10800), + DhtProtocolVersion::latest(), ); let (reply_tx, reply_rx) = oneshot::channel(); @@ -722,6 +728,7 @@ mod test { dht_requester, dht_discover_requester, chrono::Duration::seconds(10800), + DhtProtocolVersion::latest(), ); let (reply_tx, reply_rx) = oneshot::channel(); diff --git a/comms/dht/src/outbound/message.rs b/comms/dht/src/outbound/message.rs index bb782dc2e5..74e8b8c720 100644 --- a/comms/dht/src/outbound/message.rs +++ b/comms/dht/src/outbound/message.rs @@ -23,6 +23,7 @@ use crate::{ envelope::{DhtMessageFlags, DhtMessageHeader, DhtMessageType, NodeDestination}, outbound::{message_params::FinalSendMessageParams, message_send_state::MessageSendStates}, + version::DhtProtocolVersion, }; use bytes::Bytes; use std::{fmt, fmt::Display, sync::Arc}; @@ -155,6 +156,7 @@ impl fmt::Display for DhtOutboundRequest { /// send a message #[derive(Debug)] pub struct DhtOutboundMessage { + pub protocol_version: DhtProtocolVersion, pub tag: MessageTag, pub destination_node_id: NodeId, pub custom_header: Option, diff --git a/comms/dht/src/outbound/serialize.rs b/comms/dht/src/outbound/serialize.rs index 86224dd926..9e8b324e96 100644 --- a/comms/dht/src/outbound/serialize.rs +++ b/comms/dht/src/outbound/serialize.rs @@ -21,7 +21,6 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::{ - consts::{DHT_MAJOR_VERSION, DHT_MINOR_VERSION}, outbound::message::DhtOutboundMessage, proto::envelope::{DhtEnvelope, DhtHeader}, }; @@ -66,6 +65,7 @@ where let next_service = self.inner.clone(); let DhtOutboundMessage { + protocol_version, tag, destination_node_id, custom_header, @@ -86,8 +86,8 @@ where destination_node_id.short_str() ); let dht_header = custom_header.map(DhtHeader::from).unwrap_or_else(|| DhtHeader { - major: DHT_MAJOR_VERSION, - minor: DHT_MINOR_VERSION, + major: protocol_version.as_major(), + minor: protocol_version.as_minor(), origin_mac: origin_mac.map(|b| b.to_vec()).unwrap_or_else(Vec::new), ephemeral_public_key: ephemeral_public_key.map(|e| e.to_vec()).unwrap_or_else(Vec::new), message_type: dht_message_type as i32, diff --git a/comms/dht/src/store_forward/database/stored_message.rs b/comms/dht/src/store_forward/database/stored_message.rs index 92272ac0cd..047ee2f605 100644 --- a/comms/dht/src/store_forward/database/stored_message.rs +++ b/comms/dht/src/store_forward/database/stored_message.rs @@ -50,7 +50,6 @@ pub struct NewStoredMessage { impl NewStoredMessage { pub fn try_construct(message: DecryptedDhtMessage, priority: StoredMessagePriority) -> Option { let DecryptedDhtMessage { - version, authenticated_origin, decryption_result, dht_header, @@ -63,7 +62,7 @@ impl NewStoredMessage { }; Some(Self { - version: version.try_into().ok()?, + version: dht_header.version.as_major().try_into().ok()?, origin_pubkey: authenticated_origin.as_ref().map(|pk| pk.to_hex()), message_type: dht_header.message_type as i32, destination_pubkey: dht_header.destination.public_key().map(|pk| pk.to_hex()), diff --git a/comms/dht/src/store_forward/forward.rs b/comms/dht/src/store_forward/forward.rs index 0332922bab..d19ec13b44 100644 --- a/comms/dht/src/store_forward/forward.rs +++ b/comms/dht/src/store_forward/forward.rs @@ -218,14 +218,14 @@ where S: Service (Some(node_id), Some(true)) => { debug!( target: LOG_TARGET, - "Forwarding SAF message directly to node: {}, Tag#{}", node_id, dht_header.message_tag + "Forwarding SAF message directly to node: {}, {}", node_id, dht_header.message_tag ); send_params.direct_or_closest_connected(node_id.clone(), excluded_peers); }, _ => { debug!( target: LOG_TARGET, - "Not storing this SAF message for {}, propagating it. Tag#{}", + "Propagating SAF message for {}, propagating it. {}", dht_header.destination, dht_header.message_tag ); diff --git a/comms/dht/src/store_forward/saf_handler/task.rs b/comms/dht/src/store_forward/saf_handler/task.rs index 3a1435f427..1733c02ab5 100644 --- a/comms/dht/src/store_forward/saf_handler/task.rs +++ b/comms/dht/src/store_forward/saf_handler/task.rs @@ -512,23 +512,11 @@ where S: Service header: &DhtMessageHeader, body: &[u8], ) -> Result<(Option, EnvelopeBody), StoreAndForwardError> { - let mut header_mac_bytes = Vec::with_capacity(256); - header_mac_bytes.extend_from_slice(&header.major.to_le_bytes()); - header_mac_bytes.extend_from_slice(&header.minor.to_le_bytes()); - header_mac_bytes.extend_from_slice(header.destination.to_inner_bytes().as_slice()); - header_mac_bytes.extend_from_slice(&(header.message_type as i32).to_le_bytes()); - header_mac_bytes.extend_from_slice(&header.flags.bits().to_le_bytes()); - - if let Some(t) = header.expires { - header_mac_bytes.extend_from_slice(&t.as_u64().to_le_bytes()); - } - if header.flags.contains(DhtMessageFlags::ENCRYPTED) { let ephemeral_public_key = header.ephemeral_public_key.as_ref().expect( "[store and forward] DHT header is invalid after validity check because it did not contain an \ ephemeral_public_key", ); - header_mac_bytes.extend_from_slice(&ephemeral_public_key.as_bytes()); trace!( target: LOG_TARGET, @@ -537,7 +525,8 @@ where S: Service ); let shared_secret = crypt::generate_ecdh_secret(node_identity.secret_key(), ephemeral_public_key); let decrypted = crypt::decrypt(&shared_secret, &header.origin_mac)?; - let authenticated_pk = Self::authenticate_message(&decrypted, header_mac_bytes.as_bytes(), body)?; + let mac_challenge = crypt::create_origin_mac_challenge(header, body); + let authenticated_pk = Self::authenticate_message(&decrypted, mac_challenge)?; trace!( target: LOG_TARGET, @@ -553,11 +542,8 @@ where S: Service Ok((Some(authenticated_pk), envelope_body)) } else { let authenticated_pk = if !header.origin_mac.is_empty() { - Some(Self::authenticate_message( - &header.origin_mac, - header_mac_bytes.as_bytes(), - body, - )?) + let mac_challenge = crypt::create_origin_mac_challenge(header, body); + Some(Self::authenticate_message(&header.origin_mac, mac_challenge)?) } else { None }; @@ -568,15 +554,13 @@ where S: Service fn authenticate_message( origin_mac_body: &[u8], - mac_header: &[u8], - body: &[u8], + challenge: Challenge, ) -> Result { let origin_mac = OriginMac::decode(origin_mac_body)?; let public_key = CommsPublicKey::from_bytes(&origin_mac.public_key).map_err(|_| StoreAndForwardError::InvalidOriginMac)?; - let full_mac_body = [mac_header, body].concat(); - if signature::verify(&public_key, &origin_mac.signature, full_mac_body) { + if signature::verify_challenge(&public_key, &origin_mac.signature, challenge) { Ok(public_key) } else { Err(StoreAndForwardError::InvalidOriginMac) @@ -621,7 +605,7 @@ mod test { stored_at: NaiveDateTime, ) -> StoredMessage { let body = message.into_bytes(); - let body_hash = hex::to_hex(&Challenge::new().chain(body.clone()).finalize()); + let body_hash = hex::to_hex(&Challenge::new().chain(&body).finalize()); StoredMessage { id: 1, version: 0, diff --git a/comms/dht/src/test_utils/makers.rs b/comms/dht/src/test_utils/makers.rs index 46bf0edc41..97c0d6fffb 100644 --- a/comms/dht/src/test_utils/makers.rs +++ b/comms/dht/src/test_utils/makers.rs @@ -25,6 +25,7 @@ use crate::{ inbound::DhtInboundMessage, outbound::message::DhtOutboundMessage, proto::envelope::{DhtEnvelope, DhtMessageType, OriginMac}, + version::DhtProtocolVersion, }; use rand::rngs::OsRng; use std::{convert::TryInto, sync::Arc}; @@ -33,7 +34,7 @@ use tari_comms::{ multiaddr::Multiaddr, peer_manager::{NodeId, NodeIdentity, Peer, PeerFeatures, PeerFlags, PeerManager}, transports::MemoryTransport, - types::{CommsDatabase, CommsPublicKey, CommsSecretKey}, + types::{Challenge, CommsDatabase, CommsPublicKey, CommsSecretKey}, utils::signature, Bytes, }; @@ -83,25 +84,29 @@ pub fn make_dht_header( } else { NodeDestination::Unknown }; + let mut origin_mac = Vec::new(); + + if include_origin { + let challenge = crypt::create_origin_mac_challenge_parts( + DhtProtocolVersion::latest(), + &destination, + &DhtMessageType::None, + flags, + None, + Some(&e_pk), + &message, + ); + origin_mac = make_valid_origin_mac(node_identity, challenge); + if flags.is_encrypted() { + let shared_secret = crypt::generate_ecdh_secret(e_sk, node_identity.public_key()); + origin_mac = crypt::encrypt(&shared_secret, &origin_mac).unwrap() + } + } DhtMessageHeader { - major: 0, - minor: 0, - destination: destination.clone(), + version: DhtProtocolVersion::latest(), + destination, ephemeral_public_key: if flags.is_encrypted() { Some(e_pk.clone()) } else { None }, - origin_mac: if include_origin { - let mut header_mac_bytes = Vec::with_capacity(256); - header_mac_bytes.extend_from_slice(&0u32.to_le_bytes()); - header_mac_bytes.extend_from_slice(&0u32.to_le_bytes()); - header_mac_bytes.extend_from_slice((destination).to_inner_bytes().as_bytes()); - header_mac_bytes.extend_from_slice(&(DhtMessageType::None as i32).to_le_bytes()); - header_mac_bytes.extend_from_slice(&flags.bits().to_le_bytes()); - if flags.is_encrypted() { - header_mac_bytes.extend_from_slice(&e_pk.as_bytes()); - } - make_valid_origin_mac(node_identity, &e_sk, header_mac_bytes.as_bytes(), message, flags) - } else { - Vec::new() - }, + origin_mac, message_type: DhtMessageType::None, flags, message_tag: trace, @@ -109,28 +114,15 @@ pub fn make_dht_header( } } -pub fn make_valid_origin_mac( - node_identity: &NodeIdentity, - e_sk: &CommsSecretKey, - mac_header: &[u8], - body: &[u8], - flags: DhtMessageFlags, -) -> Vec { - let mac_body = [mac_header, body].concat(); +pub fn make_valid_origin_mac(node_identity: &NodeIdentity, challenge: Challenge) -> Vec { let mac = OriginMac { public_key: node_identity.public_key().to_vec(), - signature: signature::sign(&mut OsRng, node_identity.secret_key().clone(), mac_body) + signature: signature::sign_challenge(&mut OsRng, node_identity.secret_key().clone(), challenge) .unwrap() .to_binary() .unwrap(), }; - let body = mac.to_encoded_bytes(); - if flags.is_encrypted() { - let shared_secret = crypt::generate_ecdh_secret(e_sk, node_identity.public_key()); - crypt::encrypt(&shared_secret, &body).unwrap() - } else { - body - } + mac.to_encoded_bytes() } pub fn make_dht_inbound_message( @@ -210,6 +202,7 @@ pub fn build_peer_manager() -> Arc { pub fn create_outbound_message(body: &[u8]) -> DhtOutboundMessage { let msg_tag = MessageTag::new(); DhtOutboundMessage { + protocol_version: DhtProtocolVersion::latest(), tag: msg_tag, destination_node_id: NodeId::default(), destination: Default::default(), diff --git a/comms/dht/src/version.rs b/comms/dht/src/version.rs new file mode 100644 index 0000000000..7c2b85b740 --- /dev/null +++ b/comms/dht/src/version.rs @@ -0,0 +1,99 @@ +// Copyright 2021, The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +use crate::envelope::DhtMessageError; +use std::{ + convert::{TryFrom, TryInto}, + fmt, + fmt::{Display, Formatter}, + io::Write, +}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum DhtProtocolVersion { + V1 { minor: u32 }, + V2 { minor: u32 }, +} + +impl DhtProtocolVersion { + pub fn latest() -> Self { + DhtProtocolVersion::v2() + } + + pub fn v1() -> Self { + DhtProtocolVersion::V1 { minor: 0 } + } + + pub fn v2() -> Self { + DhtProtocolVersion::V2 { minor: 0 } + } + + pub fn to_bytes(self) -> Vec { + let mut buf = Vec::with_capacity(4 * 2); + buf.write_all(&self.as_major().to_le_bytes()).unwrap(); + buf.write_all(&self.as_minor().to_le_bytes()).unwrap(); + buf + } + + pub fn as_major(&self) -> u32 { + use DhtProtocolVersion::*; + match self { + V1 { .. } => 1, + V2 { .. } => 2, + } + } + + pub fn as_minor(&self) -> u32 { + use DhtProtocolVersion::*; + match self { + V1 { minor } => *minor, + V2 { minor } => *minor, + } + } +} + +impl Display for DhtProtocolVersion { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "v{}.{}", self.as_major(), self.as_minor()) + } +} + +impl TryFrom for DhtProtocolVersion { + type Error = DhtMessageError; + + fn try_from(value: u32) -> Result { + (value, 0).try_into() + } +} + +impl TryFrom<(u32, u32)> for DhtProtocolVersion { + type Error = DhtMessageError; + + fn try_from((major, minor): (u32, u32)) -> Result { + use DhtProtocolVersion::*; + match major { + 0..=1 => Ok(V1 { minor }), + 2 => Ok(V2 { minor }), + n => Err(DhtMessageError::InvalidProtocolVersion(n)), + } + } +} diff --git a/comms/dht/tests/dht.rs b/comms/dht/tests/dht.rs index b924e759ec..cebedac101 100644 --- a/comms/dht/tests/dht.rs +++ b/comms/dht/tests/dht.rs @@ -44,7 +44,6 @@ use tari_comms_dht::{ outbound::{OutboundEncryption, SendMessageParams}, DbConnectionUrl, Dht, - DhtBuilder, DhtConfig, }; use tari_shutdown::{Shutdown, ShutdownSignal}; @@ -183,18 +182,18 @@ async fn setup_comms_dht( .build() .unwrap(); - let dht = DhtBuilder::new( - comms.node_identity(), - comms.peer_manager(), - outbound_tx, - comms.connectivity(), - comms.shutdown_signal(), - ) - .with_config(dht_config) - .with_database_url(DbConnectionUrl::MemoryShared(random::string(8))) - .build() - .await - .unwrap(); + let dht = Dht::builder() + .with_config(dht_config) + .with_database_url(DbConnectionUrl::MemoryShared(random::string(8))) + .with_outbound_sender(outbound_tx) + .build( + comms.node_identity(), + comms.peer_manager(), + comms.connectivity(), + comms.shutdown_signal(), + ) + .await + .unwrap(); for peer in peers { comms.peer_manager().add_peer(peer).await.unwrap(); @@ -816,10 +815,10 @@ async fn dht_propagate_message_contents_not_malleable_ban() { let node_B_node_id = node_B.node_identity().node_id().clone(); // Node C should ban node B - let banned_node_id = streams::assert_in_stream( + let banned_node_id = streams::assert_in_broadcast( &mut connectivity_events, - |r| match &*r.unwrap() { - ConnectivityEvent::PeerBanned(node_id) => Some(node_id.clone()), + |r| match r { + ConnectivityEvent::PeerBanned(node_id) => Some(node_id), _ => None, }, Duration::from_secs(10), @@ -832,14 +831,27 @@ async fn dht_propagate_message_contents_not_malleable_ban() { node_C.shutdown().await; } -#[tokio_macros::test] +#[tokio::test] #[allow(non_snake_case)] async fn dht_header_not_malleable() { - let node_C = make_node(PeerFeatures::COMMUNICATION_NODE, None).await; + env_logger::init(); + let node_C = make_node("node_C", PeerFeatures::COMMUNICATION_NODE, dht_config(), None).await; // Node B knows about Node C - let mut node_B = make_node(PeerFeatures::COMMUNICATION_NODE, Some(node_C.to_peer())).await; + let mut node_B = make_node( + "node_B", + PeerFeatures::COMMUNICATION_NODE, + dht_config(), + Some(node_C.to_peer()), + ) + .await; // Node A knows about Node B - let node_A = make_node(PeerFeatures::COMMUNICATION_NODE, Some(node_B.to_peer())).await; + let node_A = make_node( + "node_A", + PeerFeatures::COMMUNICATION_NODE, + dht_config(), + Some(node_B.to_peer()), + ) + .await; node_A.comms.peer_manager().add_peer(node_C.to_peer()).await.unwrap(); log::info!( "NodeA = {}, NodeB = {}", diff --git a/comms/src/protocol/messaging/outbound.rs b/comms/src/protocol/messaging/outbound.rs index 438f287b40..a2c4288b4e 100644 --- a/comms/src/protocol/messaging/outbound.rs +++ b/comms/src/protocol/messaging/outbound.rs @@ -83,7 +83,8 @@ impl OutboundMessaging { Ok(_) => { event!( Level::DEBUG, - "Outbound messaging for peer has stopped because the stream was closed" + "Outbound messaging for peer '{}' has stopped because the stream was closed", + peer_node_id.short_str() ); debug!( @@ -94,8 +95,9 @@ impl OutboundMessaging { }, Err(MessagingProtocolError::Inactivity) => { event!( - Level::ERROR, - "Outbound messaging for peer has stopped because it was inactive" + Level::DEBUG, + "Outbound messaging for peer '{}' has stopped because it was inactive", + peer_node_id.short_str() ); debug!( target: LOG_TARGET, diff --git a/comms/src/utils/signature.rs b/comms/src/utils/signature.rs index 2df4b678db..0eda6e8770 100644 --- a/comms/src/utils/signature.rs +++ b/comms/src/utils/signature.rs @@ -21,7 +21,6 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::types::{Challenge, CommsPublicKey}; -use blake2::digest::FixedOutput; use digest::Digest; use rand::{CryptoRng, Rng}; use tari_crypto::{ @@ -30,28 +29,22 @@ use tari_crypto::{ tari_utilities::message_format::MessageFormat, }; -pub fn sign( +pub fn sign_challenge( rng: &mut R, secret_key: ::K, - body: B, + challenge: Challenge, ) -> Result::K>, SchnorrSignatureError> where R: CryptoRng + Rng, - B: AsRef<[u8]>, { - let challenge = Challenge::new().chain(body).finalize_fixed(); let nonce = ::K::random(rng); - SchnorrSignature::sign(secret_key, nonce, challenge.as_slice()) + SchnorrSignature::sign(secret_key, nonce, &challenge.finalize()) } -/// Verify that the signature is valid for the message body -pub fn verify(public_key: &CommsPublicKey, signature: &[u8], body: B) -> bool -where B: AsRef<[u8]> { +/// Verify that the signature is valid for the challenge +pub fn verify_challenge(public_key: &CommsPublicKey, signature: &[u8], challenge: Challenge) -> bool { match SchnorrSignature::::K>::from_binary(signature) { - Ok(signature) => { - let challenge = Challenge::new().chain(body).finalize_fixed(); - signature.verify_challenge(public_key, challenge.as_slice()) - }, + Ok(signature) => signature.verify_challenge(public_key, &challenge.finalize()), Err(_) => false, } }