From af585c03cec9710548818413aea8f4dc2463ef90 Mon Sep 17 00:00:00 2001 From: Stanimal Date: Tue, 7 Apr 2020 11:12:55 +0200 Subject: [PATCH] Use ephemeral key for private messages (e.g Discovery) Added ephemeral key ECDH encryption for private messages. An "origin Message Authentication Code (MAC)" has been introduced to the message envelope header. Given, `k_e` - Ephemeral secret key `G` - Ristretto generator point `k_r` - Receiver secret key `K_s` - Sender public key `SK` - symmetric shared encryption key Sender: 1. Generate an ephemeral secret key `k_e` and ephemeral public key `k_eG` 1. Create a Diffie-hellman encryption key using the recipient public key `SK = k_e * k_rG` 1. Use `SK` to encrypt the message body (Chacha20) 1. Generate an origin MAC containing the sender public key and signature (Schnorr) that signs the encrypted message body 1. Encode (protobuf) the `OriginMac` structure and encrypt it using `SK` 1. Set the `origin_mac` field in the envelope header to the resulting cipher text from step 5. 1. Set the `ephemeral_public_key` in the envelope header Receiver: 1. Generate `SK = k_r * k_eG` 1. Attempt to decrypt and decode the `origin_mac` field using `SK` 1. If successful, validate the message body using the signature component 1. The receiver has verified the sender authenticity and message integrity - Removed non-private options for discovery --- applications/tari_base_node/src/parser.rs | 4 +- .../chain_metadata_service/service.rs | 7 +- .../core/src/base_node/service/initializer.rs | 1 + .../core/src/mempool/service/initializer.rs | 4 +- .../src/comms_connector/inbound_connector.rs | 48 ++--- .../p2p/src/comms_connector/peer_message.rs | 19 +- base_layer/p2p/src/domain_message.rs | 25 +-- .../p2p/src/services/liveness/service.rs | 17 +- base_layer/p2p/src/services/utils.rs | 1 + base_layer/p2p/src/test_utils.rs | 26 +-- .../tests/output_manager_service/service.rs | 8 +- .../tests/support/comms_and_services.rs | 4 +- .../tests/transaction_service/service.rs | 54 ++--- comms/dht/examples/memorynet.rs | 55 +---- .../down.sql | 1 + .../up.sql | 20 ++ comms/dht/src/dht.rs | 67 ++---- comms/dht/src/discovery/requester.rs | 77 ++----- comms/dht/src/discovery/service.rs | 58 ++--- comms/dht/src/envelope.rs | 130 +++--------- comms/dht/src/inbound/decryption.rs | 198 ++++++++++++------ comms/dht/src/inbound/dedup.rs | 11 +- comms/dht/src/inbound/deserialize.rs | 81 +++---- comms/dht/src/inbound/dht_handler/task.rs | 35 ++-- comms/dht/src/inbound/message.rs | 26 ++- comms/dht/src/inbound/validate.rs | 124 +++-------- comms/dht/src/outbound/broadcast.rs | 78 +++---- comms/dht/src/outbound/encryption.rs | 133 ++++++------ comms/dht/src/outbound/message.rs | 66 +++--- comms/dht/src/outbound/mock.rs | 11 +- comms/dht/src/outbound/mod.rs | 23 +- comms/dht/src/outbound/requester.rs | 6 +- comms/dht/src/outbound/serialize.rs | 149 +++++-------- comms/dht/src/proto/envelope.proto | 16 +- comms/dht/src/proto/tari.dht.envelope.rs | 19 +- comms/dht/src/schema.rs | 3 +- .../store_forward/database/stored_message.rs | 36 ++-- comms/dht/src/store_forward/error.rs | 4 +- comms/dht/src/store_forward/forward.rs | 17 +- comms/dht/src/store_forward/message.rs | 4 +- .../dht/src/store_forward/saf_handler/task.rs | 162 +++++++------- comms/dht/src/store_forward/store.rs | 88 ++++---- .../dht/src/test_utils/dht_discovery_mock.rs | 3 +- comms/dht/src/test_utils/makers.rs | 116 ++++++++-- .../src/test_utils/store_and_forward_mock.rs | 1 - comms/dht/tests/dht.rs | 9 +- comms/src/connection_manager/manager.rs | 26 +-- comms/src/message/envelope.rs | 16 +- comms/src/message/error.rs | 4 +- comms/src/message/mod.rs | 9 +- comms/src/peer_manager/peer.rs | 14 +- comms/src/pipeline/error.rs | 8 +- comms/src/protocol/identity.rs | 3 +- 53 files changed, 982 insertions(+), 1143 deletions(-) create mode 100644 comms/dht/migrations/2020-04-07-161148_remove_origin_signature/down.sql create mode 100644 comms/dht/migrations/2020-04-07-161148_remove_origin_signature/up.sql diff --git a/applications/tari_base_node/src/parser.rs b/applications/tari_base_node/src/parser.rs index 3894a2a3ea..4d5e5f1d78 100644 --- a/applications/tari_base_node/src/parser.rs +++ b/applications/tari_base_node/src/parser.rs @@ -50,7 +50,7 @@ use tari_comms::{ types::CommsPublicKey, NodeIdentity, }; -use tari_comms_dht::{envelope::NodeDestination, DhtDiscoveryRequester}; +use tari_comms_dht::DhtDiscoveryRequester; use tari_core::{ base_node::LocalNodeCommsInterface, blocks::BlockHeader, @@ -486,7 +486,7 @@ impl Parser { self.executor.spawn(async move { let start = Instant::now(); println!("🌎 Peer discovery started."); - match dht.discover_peer(dest_pubkey, None, NodeDestination::Unknown).await { + match dht.discover_peer(dest_pubkey).await { Ok(p) => { let end = Instant::now(); println!("⚡️ Discovery succeeded in {}ms!", (end - start).as_millis()); diff --git a/base_layer/core/src/base_node/chain_metadata_service/service.rs b/base_layer/core/src/base_node/chain_metadata_service/service.rs index a3107f9e50..1b4bab76d0 100644 --- a/base_layer/core/src/base_node/chain_metadata_service/service.rs +++ b/base_layer/core/src/base_node/chain_metadata_service/service.rs @@ -120,7 +120,7 @@ impl ChainMetadataService { /// Send this node's metadata to async fn update_liveness_chain_metadata(&mut self) -> Result<(), ChainMetadataSyncError> { let chain_metadata = self.base_node.get_metadata().await?; - let bytes = proto::ChainMetadata::from(chain_metadata).to_encoded_bytes()?; + let bytes = proto::ChainMetadata::from(chain_metadata).to_encoded_bytes(); self.liveness .set_pong_metadata_entry(MetadataKey::ChainMetadata, bytes) .await?; @@ -300,10 +300,7 @@ mod test { let (liveness_handle, _) = create_p2p_liveness_mock(1); let mut metadata = Metadata::new(); let proto_chain_metadata = create_sample_proto_chain_metadata(); - metadata.insert( - MetadataKey::ChainMetadata, - proto_chain_metadata.to_encoded_bytes().unwrap(), - ); + metadata.insert(MetadataKey::ChainMetadata, proto_chain_metadata.to_encoded_bytes()); let node_id = NodeId::new(); let pong_event = PongEvent { diff --git a/base_layer/core/src/base_node/service/initializer.rs b/base_layer/core/src/base_node/service/initializer.rs index 643ac067de..d2aed71668 100644 --- a/base_layer/core/src/base_node/service/initializer.rs +++ b/base_layer/core/src/base_node/service/initializer.rs @@ -136,6 +136,7 @@ async fn extract_block(msg: Arc) -> Option> { Some(DomainMessage { source_peer: msg.source_peer.clone(), dht_header: msg.dht_header.clone(), + authenticated_origin: msg.authenticated_origin.clone(), inner: block, }) }, diff --git a/base_layer/core/src/mempool/service/initializer.rs b/base_layer/core/src/mempool/service/initializer.rs index e173a026e8..8d62cc2dfc 100644 --- a/base_layer/core/src/mempool/service/initializer.rs +++ b/base_layer/core/src/mempool/service/initializer.rs @@ -121,10 +121,9 @@ async fn extract_transaction(msg: Arc) -> Option { let tx = match Transaction::try_from(tx) { Err(e) => { - let origin = msg.origin_public_key(); warn!( target: LOG_TARGET, - "Inbound transaction message from {} was ill-formed. {}", origin, e + "Inbound transaction message from {} was ill-formed. {}", msg.source_peer.public_key, e ); return None; }, @@ -133,6 +132,7 @@ async fn extract_transaction(msg: Arc) -> Option InboundDomainConnector { let DecryptedDhtMessage { source_peer, dht_header, + authenticated_origin, .. } = inbound_message; let peer_message = PeerMessage { message_header: header, source_peer: Clone::clone(&*source_peer), + authenticated_origin, dht_header, body: msg_bytes, }; @@ -141,21 +143,17 @@ mod test { use crate::test_utils::{make_dht_inbound_message, make_node_identity}; use futures::{channel::mpsc, executor::block_on, StreamExt}; use tari_comms::{message::MessageExt, wrap_in_envelope_body}; - use tari_comms_dht::{domain_message::MessageHeader, envelope::DhtMessageFlags}; + use tari_comms_dht::domain_message::MessageHeader; use tower::ServiceExt; #[tokio_macros::test_basic] async fn handle_message() { let (tx, mut rx) = mpsc::channel(1); let header = MessageHeader::new(123); - let msg = wrap_in_envelope_body!(header, b"my message".to_vec()).unwrap(); - - let inbound_message = make_dht_inbound_message( - &make_node_identity(), - msg.to_encoded_bytes().unwrap(), - DhtMessageFlags::empty(), - ); - let decrypted = DecryptedDhtMessage::succeeded(msg, inbound_message); + let msg = wrap_in_envelope_body!(header, b"my message".to_vec()); + + let inbound_message = make_dht_inbound_message(&make_node_identity(), msg.to_encoded_bytes()); + let decrypted = DecryptedDhtMessage::succeeded(msg, None, inbound_message); InboundDomainConnector::new(tx).oneshot(decrypted).await.unwrap(); let peer_message = block_on(rx.next()).unwrap(); @@ -167,14 +165,10 @@ mod test { async fn send_on_sink() { let (tx, mut rx) = mpsc::channel(1); let header = MessageHeader::new(123); - let msg = wrap_in_envelope_body!(header, b"my message".to_vec()).unwrap(); + let msg = wrap_in_envelope_body!(header, b"my message".to_vec()); - let inbound_message = make_dht_inbound_message( - &make_node_identity(), - msg.to_encoded_bytes().unwrap(), - DhtMessageFlags::empty(), - ); - let decrypted = DecryptedDhtMessage::succeeded(msg, inbound_message); + let inbound_message = make_dht_inbound_message(&make_node_identity(), msg.to_encoded_bytes()); + let decrypted = DecryptedDhtMessage::succeeded(msg, None, inbound_message); InboundDomainConnector::new(tx).send(decrypted).await.unwrap(); @@ -187,14 +181,10 @@ mod test { async fn handle_message_fail_deserialize() { let (tx, mut rx) = mpsc::channel(1); let header = b"dodgy header".to_vec(); - let msg = wrap_in_envelope_body!(header, b"message".to_vec()).unwrap(); - - let inbound_message = make_dht_inbound_message( - &make_node_identity(), - msg.to_encoded_bytes().unwrap(), - DhtMessageFlags::empty(), - ); - let decrypted = DecryptedDhtMessage::succeeded(msg, inbound_message); + let msg = wrap_in_envelope_body!(header, b"message".to_vec()); + + let inbound_message = make_dht_inbound_message(&make_node_identity(), msg.to_encoded_bytes()); + let decrypted = DecryptedDhtMessage::succeeded(msg, None, inbound_message); InboundDomainConnector::new(tx).oneshot(decrypted).await.unwrap_err(); assert!(rx.try_next().unwrap().is_none()); @@ -206,13 +196,9 @@ mod test { // from it's call function let (tx, _) = mpsc::channel(1); let header = MessageHeader::new(123); - let msg = wrap_in_envelope_body!(header, b"my message".to_vec()).unwrap(); - let inbound_message = make_dht_inbound_message( - &make_node_identity(), - msg.to_encoded_bytes().unwrap(), - DhtMessageFlags::empty(), - ); - let decrypted = DecryptedDhtMessage::succeeded(msg, inbound_message); + let msg = wrap_in_envelope_body!(header, b"my message".to_vec()); + let inbound_message = make_dht_inbound_message(&make_node_identity(), msg.to_encoded_bytes()); + let decrypted = DecryptedDhtMessage::succeeded(msg, None, inbound_message); let result = InboundDomainConnector::new(tx).oneshot(decrypted).await; assert!(result.is_err()); } diff --git a/base_layer/p2p/src/comms_connector/peer_message.rs b/base_layer/p2p/src/comms_connector/peer_message.rs index f1d53c205f..4a62e393b6 100644 --- a/base_layer/p2p/src/comms_connector/peer_message.rs +++ b/base_layer/p2p/src/comms_connector/peer_message.rs @@ -34,28 +34,13 @@ pub struct PeerMessage { pub source_peer: Peer, /// Domain message header pub message_header: MessageHeader, + /// This messages authenticated origin, otherwise None + pub authenticated_origin: Option, /// Serialized message data pub body: Vec, } impl PeerMessage { - pub fn new(dht_header: DhtMessageHeader, source_peer: Peer, message_header: MessageHeader, body: Vec) -> Self { - Self { - body, - message_header, - dht_header, - source_peer, - } - } - - pub fn origin_public_key(&self) -> &CommsPublicKey { - self.dht_header - .origin - .as_ref() - .map(|o| &o.public_key) - .unwrap_or(&self.source_peer.public_key) - } - pub fn decode_message(&self) -> Result where T: prost::Message + Default { let msg = T::decode(self.body.as_slice())?; diff --git a/base_layer/p2p/src/domain_message.rs b/base_layer/p2p/src/domain_message.rs index 7aea2a50c8..0299a466fb 100644 --- a/base_layer/p2p/src/domain_message.rs +++ b/base_layer/p2p/src/domain_message.rs @@ -32,6 +32,8 @@ pub struct DomainMessage { /// This DHT header of this message. If `DhtMessageHeader::origin_public_key` is different from the /// `source_peer.public_key`, this message was forwarded. pub dht_header: DhtMessageHeader, + /// The authenticated origin public key of this message or None a message origin was not provided. + pub authenticated_origin: Option, /// The domain-level message pub inner: T, } @@ -48,32 +50,15 @@ impl DomainMessage { /// Consumes this object returning the public key of the original sender of this message and the message itself pub fn into_origin_and_inner(self) -> (CommsPublicKey, T) { let inner = self.inner; - let pk = self - .dht_header - .origin - .map(|o| o.public_key) - .unwrap_or(self.source_peer.public_key); + let pk = self.authenticated_origin.unwrap_or(self.source_peer.public_key); (pk, inner) } - /// Returns true of this message was forwarded from another peer, otherwise false - pub fn is_forwarded(&self) -> bool { - self.dht_header - .origin - .as_ref() - // If the source and origin are different, then the message was forwarded - .map(|o| o.public_key != self.source_peer.public_key) - // Otherwise, if no origin is specified, the message was sent directly from the peer - .unwrap_or(false) - } - /// Returns the public key that sent this message. If no origin is specified, then the source peer /// sent this message. pub fn origin_public_key(&self) -> &CommsPublicKey { - self.dht_header - .origin + self.authenticated_origin .as_ref() - .map(|o| &o.public_key) .unwrap_or(&self.source_peer.public_key) } @@ -88,6 +73,7 @@ impl DomainMessage { DomainMessage { source_peer: self.source_peer, dht_header: self.dht_header, + authenticated_origin: self.authenticated_origin, inner, } } @@ -103,6 +89,7 @@ impl DomainMessage { Ok(DomainMessage { source_peer: self.source_peer, dht_header: self.dht_header, + authenticated_origin: self.authenticated_origin, inner, }) } diff --git a/base_layer/p2p/src/services/liveness/service.rs b/base_layer/p2p/src/services/liveness/service.rs index 8d71de1c83..fe9ce3fe9d 100644 --- a/base_layer/p2p/src/services/liveness/service.rs +++ b/base_layer/p2p/src/services/liveness/service.rs @@ -513,13 +513,16 @@ mod test { &[], ); DomainMessage { - dht_header: DhtMessageHeader::new( - Default::default(), - DhtMessageType::None, - None, - Network::LocalTest, - Default::default(), - ), + dht_header: DhtMessageHeader { + version: 0, + destination: Default::default(), + origin_mac: Vec::new(), + ephemeral_public_key: None, + message_type: DhtMessageType::None, + network: Network::LocalTest, + flags: Default::default(), + }, + authenticated_origin: None, source_peer, inner, } diff --git a/base_layer/p2p/src/services/utils.rs b/base_layer/p2p/src/services/utils.rs index 111926ab58..70699d9f10 100644 --- a/base_layer/p2p/src/services/utils.rs +++ b/base_layer/p2p/src/services/utils.rs @@ -43,6 +43,7 @@ where T: prost::Message + Default { Ok(DomainMessage { source_peer: serialized.source_peer.clone(), dht_header: serialized.dht_header.clone(), + authenticated_origin: serialized.authenticated_origin.clone(), inner: serialized.decode_message()?, }) } diff --git a/base_layer/p2p/src/test_utils.rs b/base_layer/p2p/src/test_utils.rs index 1582449e63..0b90f38724 100644 --- a/base_layer/p2p/src/test_utils.rs +++ b/base_layer/p2p/src/test_utils.rs @@ -25,13 +25,11 @@ use std::sync::Arc; use tari_comms::{ multiaddr::Multiaddr, peer_manager::{NodeIdentity, Peer, PeerFeatures, PeerFlags}, - utils::signature, }; use tari_comms_dht::{ - envelope::{DhtMessageFlags, DhtMessageHeader, DhtMessageOrigin, DhtMessageType, Network, NodeDestination}, + envelope::{DhtMessageFlags, DhtMessageHeader, DhtMessageType, Network, NodeDestination}, inbound::DhtInboundMessage, }; -use tari_crypto::tari_utilities::message_format::MessageFormat; macro_rules! unwrap_oms_send_msg { ($var:expr, reply_value=$reply_value:expr) => { @@ -61,31 +59,21 @@ pub fn make_node_identity() -> Arc { ) } -pub fn make_dht_header(node_identity: &NodeIdentity, message: &Vec, flags: DhtMessageFlags) -> DhtMessageHeader { +pub fn make_dht_header() -> DhtMessageHeader { DhtMessageHeader { version: 0, destination: NodeDestination::Unknown, - origin: Some(DhtMessageOrigin { - public_key: node_identity.public_key().clone(), - signature: signature::sign(&mut OsRng, node_identity.secret_key().clone(), message) - .unwrap() - .to_binary() - .unwrap(), - }), + origin_mac: Vec::new(), + ephemeral_public_key: None, message_type: DhtMessageType::None, network: Network::LocalTest, - flags, + flags: DhtMessageFlags::NONE, } } -pub fn make_dht_inbound_message( - node_identity: &NodeIdentity, - message: Vec, - flags: DhtMessageFlags, -) -> DhtInboundMessage -{ +pub fn make_dht_inbound_message(node_identity: &NodeIdentity, message: Vec) -> DhtInboundMessage { DhtInboundMessage::new( - make_dht_header(node_identity, &message, flags), + make_dht_header(), Arc::new(Peer::new( node_identity.public_key().clone(), node_identity.node_id().clone(), diff --git a/base_layer/wallet/tests/output_manager_service/service.rs b/base_layer/wallet/tests/output_manager_service/service.rs index a91a1055d6..0660d4665e 100644 --- a/base_layer/wallet/tests/output_manager_service/service.rs +++ b/base_layer/wallet/tests/output_manager_service/service.rs @@ -680,8 +680,8 @@ fn test_startup_utxo_scan() { let output3 = UnblindedOutput::new(MicroTari::from(value3), key3, None); runtime.block_on(oms.add_output(output3.clone())).unwrap(); - let call = outbound_service.pop_call().unwrap(); - let envelope_body = EnvelopeBody::decode(&mut call.1.as_slice()).unwrap(); + let (_, body) = outbound_service.pop_call().unwrap(); + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); let bn_request: BaseNodeProto::BaseNodeServiceRequest = envelope_body .decode_part::(1) .unwrap() @@ -755,8 +755,8 @@ fn test_startup_utxo_scan() { runtime.block_on(oms.sync_with_base_node()).unwrap(); - let call = outbound_service.pop_call().unwrap(); - let envelope_body = EnvelopeBody::decode(&mut call.1.as_slice()).unwrap(); + let (_, body) = outbound_service.pop_call().unwrap(); + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); let bn_request: BaseNodeProto::BaseNodeServiceRequest = envelope_body .decode_part::(1) .unwrap() diff --git a/base_layer/wallet/tests/support/comms_and_services.rs b/base_layer/wallet/tests/support/comms_and_services.rs index a8d50ed86a..13a6afd2da 100644 --- a/base_layer/wallet/tests/support/comms_and_services.rs +++ b/base_layer/wallet/tests/support/comms_and_services.rs @@ -77,13 +77,15 @@ pub fn create_dummy_message(inner: T, public_key: &CommsPublicKey) -> DomainM ); DomainMessage { dht_header: DhtMessageHeader { - origin: None, + ephemeral_public_key: None, + origin_mac: Vec::new(), version: Default::default(), message_type: Default::default(), flags: Default::default(), network: Network::LocalTest, destination: Default::default(), }, + authenticated_origin: None, source_peer: peer_source, inner, } diff --git a/base_layer/wallet/tests/transaction_service/service.rs b/base_layer/wallet/tests/transaction_service/service.rs index 4bdf23d1ff..1ec4b344b2 100644 --- a/base_layer/wallet/tests/transaction_service/service.rs +++ b/base_layer/wallet/tests/transaction_service/service.rs @@ -791,7 +791,7 @@ fn test_accepting_unknown_tx_id_and_malformed_reply(1) .unwrap() @@ -954,7 +954,7 @@ fn finalize_tx_with_incorrect_pubkey(al .wait_call_count(1, Duration::from_secs(10)) .unwrap(); let (_, body) = alice_outbound_service.pop_call().unwrap(); - let envelope_body = EnvelopeBody::decode(body.as_slice()).unwrap(); + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); let recipient_reply: RecipientSignedMessage = envelope_body .decode_part::(1) .unwrap() @@ -1058,7 +1058,7 @@ fn finalize_tx_with_missing_output(alic .wait_call_count(1, Duration::from_secs(10)) .unwrap(); let (_, body) = alice_outbound_service.pop_call().unwrap(); - let envelope_body = EnvelopeBody::decode(body.as_slice()).unwrap(); + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); let recipient_reply: RecipientSignedMessage = envelope_body .decode_part::(1) .unwrap() @@ -1448,8 +1448,8 @@ fn transaction_mempool_broadcast() { )) .unwrap(); - let call = alice_outbound_service.pop_call().unwrap(); - let envelope_body = EnvelopeBody::decode(&mut call.1.as_slice()).unwrap(); + let (_, body) = alice_outbound_service.pop_call().unwrap(); + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); let tx_sender_msg: TransactionSenderMessage = envelope_body .decode_part::(1) .unwrap() @@ -1478,8 +1478,8 @@ fn transaction_mempool_broadcast() { timeout = Duration::from_secs(20) ) }); - let call = bob_outbound_service.pop_call().unwrap(); - let envelope_body = EnvelopeBody::decode(&mut call.1.as_slice()).unwrap(); + let (_, body) = bob_outbound_service.pop_call().unwrap(); + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); let tx_reply_msg: RecipientSignedMessage = envelope_body .decode_part::(1) .unwrap() @@ -1526,8 +1526,8 @@ fn transaction_mempool_broadcast() { assert_eq!(alice_completed_tx.status, TransactionStatus::Completed); - let call = alice_outbound_service.pop_call().unwrap(); - let envelope_body = EnvelopeBody::decode(&mut call.1.as_slice()).unwrap(); + let (_, body) = alice_outbound_service.pop_call().unwrap(); + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); let msr = envelope_body .decode_part::(1) .unwrap() @@ -1537,9 +1537,9 @@ fn transaction_mempool_broadcast() { let _ = alice_outbound_service.pop_call().unwrap(); // burn a mempool request let _ = alice_outbound_service.pop_call().unwrap(); // burn a mempool request - let call = alice_outbound_service.pop_call().unwrap(); // this should be the sending of the finalized tx to the receiver + let (_, body) = alice_outbound_service.pop_call().unwrap(); // this should be the sending of the finalized tx to the receiver - let envelope_body = EnvelopeBody::decode(&mut call.1.as_slice()).unwrap(); + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); let tx_finalized = envelope_body .decode_part::(1) .unwrap() @@ -1772,8 +1772,8 @@ fn transaction_base_node_monitoring() { )) .unwrap(); - let call = alice_outbound_service.pop_call().unwrap(); - let envelope_body = EnvelopeBody::decode(&mut call.1.as_slice()).unwrap(); + let (_, body) = alice_outbound_service.pop_call().unwrap(); + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); let tx_sender_msg: TransactionSenderMessage = envelope_body .decode_part::(1) .unwrap() @@ -1802,8 +1802,8 @@ fn transaction_base_node_monitoring() { timeout = Duration::from_secs(20) ) }); - let call = bob_outbound_service.pop_call().unwrap(); - let envelope_body = EnvelopeBody::decode(&mut call.1.as_slice()).unwrap(); + let (_, body) = bob_outbound_service.pop_call().unwrap(); + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); let tx_reply_msg: RecipientSignedMessage = envelope_body .decode_part::(1) .unwrap() @@ -1845,8 +1845,8 @@ fn transaction_base_node_monitoring() { )) .unwrap(); - let call = alice_outbound_service.pop_call().unwrap(); - let envelope_body = EnvelopeBody::decode(&mut call.1.as_slice()).unwrap(); + let (_, body) = alice_outbound_service.pop_call().unwrap(); + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); let tx_sender_msg: TransactionSenderMessage = envelope_body .decode_part::(1) .unwrap() @@ -1875,8 +1875,8 @@ fn transaction_base_node_monitoring() { timeout = Duration::from_secs(20) ) }); - let call = bob_outbound_service.pop_call().unwrap(); - let envelope_body = EnvelopeBody::decode(&mut call.1.as_slice()).unwrap(); + let (_, body) = bob_outbound_service.pop_call().unwrap(); + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); let tx_reply_msg: RecipientSignedMessage = envelope_body .decode_part::(1) .unwrap() @@ -1912,8 +1912,8 @@ fn transaction_base_node_monitoring() { let _ = alice_outbound_service.pop_call().unwrap(); // burn a base node request - let call = alice_outbound_service.pop_call().unwrap(); - let envelope_body = EnvelopeBody::decode(&mut call.1.as_slice()).unwrap(); + let (_, body) = alice_outbound_service.pop_call().unwrap(); + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); let msr = envelope_body .decode_part::(1) .unwrap() @@ -2362,8 +2362,8 @@ fn transaction_cancellation_when_not_in_mempool() { )) .unwrap(); - let call = alice_outbound_service.pop_call().unwrap(); - let envelope_body = EnvelopeBody::decode(&mut call.1.as_slice()).unwrap(); + let (_, body) = alice_outbound_service.pop_call().unwrap(); + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); let tx_sender_msg: TransactionSenderMessage = envelope_body .decode_part::(1) .unwrap() @@ -2392,8 +2392,8 @@ fn transaction_cancellation_when_not_in_mempool() { timeout = Duration::from_secs(30) ) }); - let call = bob_outbound_service.pop_call().unwrap(); - let envelope_body = EnvelopeBody::decode(&mut call.1.as_slice()).unwrap(); + let (_, body) = bob_outbound_service.pop_call().unwrap(); + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); let tx_reply_msg: RecipientSignedMessage = envelope_body .decode_part::(1) .unwrap() @@ -2429,8 +2429,8 @@ fn transaction_cancellation_when_not_in_mempool() { let _ = alice_outbound_service.pop_call().unwrap(); // burn a base node request - let call = alice_outbound_service.pop_call().unwrap(); - let envelope_body = EnvelopeBody::decode(&mut call.1.as_slice()).unwrap(); + let (_, body) = alice_outbound_service.pop_call().unwrap(); + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); let msr = envelope_body .decode_part::(1) .unwrap() diff --git a/comms/dht/examples/memorynet.rs b/comms/dht/examples/memorynet.rs index ef04d93a96..527019d4ab 100644 --- a/comms/dht/examples/memorynet.rs +++ b/comms/dht/examples/memorynet.rs @@ -68,8 +68,7 @@ use tari_comms::{ ConnectionManagerEvent, PeerConnection, }; -use tari_comms_dht::{envelope::NodeDestination, inbound::DecryptedDhtMessage, Dht, DhtBuilder}; -use tari_crypto::tari_utilities::ByteArray; +use tari_comms_dht::{inbound::DecryptedDhtMessage, Dht, DhtBuilder}; use tari_storage::{lmdb_store::LMDBBuilder, LMDBWrapper}; use tari_test_utils::{paths::create_temporary_data_path, random}; use tokio::{runtime, time}; @@ -216,17 +215,7 @@ async fn main() { peer_list_summary(&wallets).await; - total_messages += discovery(&wallets, &mut messaging_events_rx, false, true).await; - - take_a_break().await; - total_messages += drain_messaging_events(&mut messaging_events_rx, false).await; - - total_messages += discovery(&wallets, &mut messaging_events_rx, true, false).await; - - take_a_break().await; - total_messages += drain_messaging_events(&mut messaging_events_rx, false).await; - - total_messages += discovery(&wallets, &mut messaging_events_rx, false, false).await; + total_messages += discovery(&wallets, &mut messaging_events_rx).await; take_a_break().await; total_messages += drain_messaging_events(&mut messaging_events_rx, false).await; @@ -247,13 +236,7 @@ async fn shutdown_all(nodes: Vec) { future::join_all(tasks).await; } -async fn discovery( - wallets: &[TestNode], - messaging_events_rx: &mut MessagingEventRx, - use_network_region: bool, - use_destination_node_id: bool, -) -> usize -{ +async fn discovery(wallets: &[TestNode], messaging_events_rx: &mut MessagingEventRx) -> usize { let mut successes = 0; let mut total_messages = 0; let mut total_time = Duration::from_secs(0); @@ -265,30 +248,11 @@ async fn discovery( peer_list_summary(&[wallet1, wallet2]).await; - let mut destination = NodeDestination::Unknown; - if use_network_region { - let mut new_node_id = [0; 13]; - let node_id = wallet2.get_node_id(); - let buf = &mut new_node_id[..10]; - buf.copy_from_slice(&node_id.as_bytes()[..10]); - let regional_node_id = NodeId::from_bytes(&new_node_id).unwrap(); - destination = NodeDestination::NodeId(Box::new(regional_node_id)); - } - - let mut node_id_dest = None; - if use_destination_node_id { - node_id_dest = Some(wallet2.get_node_id()); - } - let start = Instant::now(); let discovery_result = wallet1 .dht .discovery_service_requester() - .discover_peer( - Box::new(wallet2.node_identity().public_key().clone()), - node_id_dest, - destination, - ) + .discover_peer(Box::new(wallet2.node_identity().public_key().clone())) .await; let end = Instant::now(); @@ -387,15 +351,13 @@ async fn do_store_and_forward_discovery( node_identity.public_key(), ); let mut first_wallet_discovery_req = wallets[0].dht.discovery_service_requester(); + + let start = Instant::now(); let discovery_task = runtime::Handle::current().spawn({ let node_identity = node_identity.clone(); async move { first_wallet_discovery_req - .discover_peer( - Box::new(node_identity.public_key().clone()), - Some(node_identity.node_id().clone()), - node_identity.public_key().clone().into(), - ) + .discover_peer(Box::new(node_identity.public_key().clone())) .await } }); @@ -410,7 +372,6 @@ async fn do_store_and_forward_discovery( let (comms, dht) = setup_comms_dht(node_identity, create_peer_storage(all_peers), tx).await; let wallet = TestNode::new(comms, dht, None, ims_rx, messaging_tx); - let start = Instant::now(); wallet.dht.dht_requester().send_request_stored_messages().await.unwrap(); total_messages += match discovery_task.await.unwrap() { @@ -497,8 +458,8 @@ fn connection_manager_logger( PeerConnectWillClose(_, node_id, direction) => { println!( "'{}' will disconnect {} connection to '{}'", - direction, get_name(node_id), + direction, node_name, ); }, diff --git a/comms/dht/migrations/2020-04-07-161148_remove_origin_signature/down.sql b/comms/dht/migrations/2020-04-07-161148_remove_origin_signature/down.sql new file mode 100644 index 0000000000..a4b1198712 --- /dev/null +++ b/comms/dht/migrations/2020-04-07-161148_remove_origin_signature/down.sql @@ -0,0 +1 @@ +-- No going back diff --git a/comms/dht/migrations/2020-04-07-161148_remove_origin_signature/up.sql b/comms/dht/migrations/2020-04-07-161148_remove_origin_signature/up.sql new file mode 100644 index 0000000000..d672ff7ac5 --- /dev/null +++ b/comms/dht/migrations/2020-04-07-161148_remove_origin_signature/up.sql @@ -0,0 +1,20 @@ +DROP TABLE stored_messages; + +CREATE TABLE stored_messages( + id INTEGER NOT NULL PRIMARY KEY, + version INT NOT NULL, + origin_pubkey TEXT, + message_type INT NOT NULL, + destination_pubkey TEXT, + destination_node_id TEXT, + header BLOB NOT NULL, + body BLOB NOT NULL, + is_encrypted BOOLEAN NOT NULL CHECK (is_encrypted IN (0,1)), + priority INT NOT NULL, + stored_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP +); + +CREATE INDEX idx_stored_messages_destination_pubkey ON stored_messages (destination_pubkey); +CREATE INDEX idx_stored_messages_destination_node_id ON stored_messages (destination_node_id); +CREATE INDEX idx_stored_messages_stored_at ON stored_messages (stored_at); +CREATE INDEX idx_stored_messages_priority ON stored_messages (priority); diff --git a/comms/dht/src/dht.rs b/comms/dht/src/dht.rs index 2e5d6c2182..4453d5fbbb 100644 --- a/comms/dht/src/dht.rs +++ b/comms/dht/src/dht.rs @@ -190,11 +190,8 @@ impl Dht { S::Future: Send, { let builder = ServiceBuilder::new() - .layer(inbound::DeserializeLayer::new()) - .layer(inbound::ValidateLayer::new( - self.config.network, - self.outbound_requester(), - )) + .layer(inbound::DeserializeLayer) + .layer(inbound::ValidateLayer::new(self.config.network)) .layer(inbound::DedupLayer::new(self.dht_requester())); // FIXME: There is an unresolved stack overflow issue on windows. Seems that we've reached the limit on stack @@ -262,7 +259,7 @@ impl Dht { )) .layer(MessageLoggingLayer::new("Outbound message: ")) .layer(outbound::EncryptionLayer::new(Arc::clone(&self.node_identity))) - .layer(outbound::SerializeLayer::new(Arc::clone(&self.node_identity))) + .layer(outbound::SerializeLayer) .into_inner() } @@ -348,14 +345,9 @@ mod test { let mut service = dht.inbound_middleware_layer().layer(SinkService::new(out_tx)); - let msg = wrap_in_envelope_body!(b"secret".to_vec()).unwrap(); - let dht_envelope = make_dht_envelope( - &node_identity, - msg.to_encoded_bytes().unwrap(), - DhtMessageFlags::empty(), - ); - let inbound_message = - make_comms_inbound_message(&node_identity, dht_envelope.to_encoded_bytes().unwrap().into()); + let msg = wrap_in_envelope_body!(b"secret".to_vec()); + let dht_envelope = make_dht_envelope(&node_identity, msg.to_encoded_bytes(), DhtMessageFlags::empty(), false); + let inbound_message = make_comms_inbound_message(&node_identity, dht_envelope.to_encoded_bytes().into()); let msg = { service.call(inbound_message).await.unwrap(); @@ -376,7 +368,7 @@ mod test { let (connection_manager, _) = create_connection_manager_mock(1); // Dummy out channel, we are not testing outbound here. - let (out_tx, _) = mpsc::channel(10); + let (out_tx, _out_rx) = mpsc::channel(10); let shutdown = Shutdown::new(); let dht = DhtBuilder::new( @@ -392,13 +384,10 @@ mod test { let mut service = dht.inbound_middleware_layer().layer(SinkService::new(out_tx)); - let msg = wrap_in_envelope_body!(b"secret".to_vec()).unwrap(); + let msg = wrap_in_envelope_body!(b"secret".to_vec()); // Encrypt for self - let ecdh_key = crypt::generate_ecdh_secret(node_identity.secret_key(), node_identity.public_key()); - let encrypted_bytes = crypt::encrypt(&ecdh_key, &msg.to_encoded_bytes().unwrap()).unwrap(); - let dht_envelope = make_dht_envelope(&node_identity, encrypted_bytes, DhtMessageFlags::ENCRYPTED); - let inbound_message = - make_comms_inbound_message(&node_identity, dht_envelope.to_encoded_bytes().unwrap().into()); + let dht_envelope = make_dht_envelope(&node_identity, msg.to_encoded_bytes(), DhtMessageFlags::ENCRYPTED, true); + let inbound_message = make_comms_inbound_message(&node_identity, dht_envelope.to_encoded_bytes().into()); let msg = { service.call(inbound_message).await.unwrap(); @@ -437,25 +426,17 @@ mod test { let mut service = dht.inbound_middleware_layer().layer(SinkService::new(next_service_tx)); - let msg = wrap_in_envelope_body!(b"unencrypteable".to_vec()).unwrap(); + let msg = wrap_in_envelope_body!(b"unencrypteable".to_vec()); // Encrypt for someone else let node_identity2 = make_node_identity(); let ecdh_key = crypt::generate_ecdh_secret(node_identity2.secret_key(), node_identity2.public_key()); - let encrypted_bytes = crypt::encrypt(&ecdh_key, &msg.to_encoded_bytes().unwrap()).unwrap(); - let dht_envelope = make_dht_envelope(&node_identity, encrypted_bytes, DhtMessageFlags::ENCRYPTED); - - let origin_sig = dht_envelope - .header - .as_ref() - .unwrap() - .origin - .as_ref() - .unwrap() - .signature - .clone(); - let inbound_message = - make_comms_inbound_message(&node_identity, dht_envelope.to_encoded_bytes().unwrap().into()); + let encrypted_bytes = crypt::encrypt(&ecdh_key, &msg.to_encoded_bytes()).unwrap(); + let dht_envelope = make_dht_envelope(&node_identity, encrypted_bytes, DhtMessageFlags::ENCRYPTED, true); + + let origin_mac = dht_envelope.header.as_ref().unwrap().origin_mac.clone(); + assert_eq!(origin_mac.is_empty(), false); + let inbound_message = make_comms_inbound_message(&node_identity, dht_envelope.to_encoded_bytes().into()); service.call(inbound_message).await.unwrap(); @@ -463,7 +444,7 @@ mod test { let (params, _) = oms_mock_state.pop_call().unwrap(); // Check that OMS got a request to forward with the original Dht Header - assert_eq!(params.dht_header.unwrap().origin.unwrap().signature, origin_sig); + assert_eq!(params.dht_header.unwrap().origin_mac, origin_mac); // Check the next service was not called assert!(next_service_rx.try_next().is_err()); @@ -494,18 +475,14 @@ mod test { let mut service = dht.inbound_middleware_layer().layer(SinkService::new(next_service_tx)); - let msg = wrap_in_envelope_body!(b"secret".to_vec()).unwrap(); - let mut dht_envelope = make_dht_envelope( - &node_identity, - msg.to_encoded_bytes().unwrap(), - DhtMessageFlags::empty(), - ); + let msg = wrap_in_envelope_body!(b"secret".to_vec()); + let mut dht_envelope = + make_dht_envelope(&node_identity, msg.to_encoded_bytes(), DhtMessageFlags::empty(), false); dht_envelope.header.as_mut().and_then(|header| { header.message_type = DhtMessageType::SafStoredMessages as i32; Some(header) }); - let inbound_message = - make_comms_inbound_message(&node_identity, dht_envelope.to_encoded_bytes().unwrap().into()); + let inbound_message = make_comms_inbound_message(&node_identity, dht_envelope.to_encoded_bytes().into()); service.call(inbound_message).await.unwrap_err(); // This seems like the best way to tell that an open channel is empty without the test blocking indefinitely diff --git a/comms/dht/src/discovery/requester.rs b/comms/dht/src/discovery/requester.rs index 4f97425eda..d43fb1dea3 100644 --- a/comms/dht/src/discovery/requester.rs +++ b/comms/dht/src/discovery/requester.rs @@ -20,7 +20,7 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::{discovery::DhtDiscoveryError, envelope::NodeDestination, proto::dht::DiscoveryResponseMessage}; +use crate::{discovery::DhtDiscoveryError, proto::dht::DiscoveryResponseMessage}; use futures::{ channel::{mpsc, oneshot}, SinkExt, @@ -29,51 +29,12 @@ use std::{ fmt::{Display, Error, Formatter}, time::Duration, }; -use tari_comms::{ - peer_manager::{NodeId, Peer}, - types::CommsPublicKey, -}; +use tari_comms::{peer_manager::Peer, types::CommsPublicKey}; use tokio::time; -#[derive(Debug)] -pub struct DiscoverPeerRequest { - /// The public key of the peer to be discovered. The message will be encrypted with a DH shared - /// secret using this public key. - pub dest_public_key: Box, - /// The node id of the peer to be discovered, if it is known. Providing the `NodeId` allows - /// discovery requests to reach their destination more quickly. - pub dest_node_id: Option, - /// The destination to include in the comms header. - /// `Undisclosed` will require nodes to propagate the message across the network, presumably eventually - /// reaching the destination node (the node which can decrypt the message). This will happen without - /// any intermediary nodes knowing who is being searched for. - /// `NodeId` will direct the discovery request closer to the destination or network region. - /// `PublicKey` will be propagated across the network. If any node knows the peer, the request can be - /// forwarded to them immediately. However, more nodes will know that this node is being searched for - /// which may slightly compromise privacy. - pub destination: NodeDestination, -} - -impl Display for DiscoverPeerRequest { - fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> { - f.debug_struct("DiscoverPeerRequest") - .field("dest_public_key", &format!("{}", self.dest_public_key)) - .field( - "dest_node_id", - &self - .dest_node_id - .as_ref() - .map(|node_id| format!("{}", node_id)) - .unwrap_or_else(|| "None".to_string()), - ) - .field("destination", &format!("{}", self.destination)) - .finish() - } -} - #[derive(Debug)] pub enum DhtDiscoveryRequest { - DiscoverPeer(Box<(DiscoverPeerRequest, oneshot::Sender>)>), + DiscoverPeer(Box, oneshot::Sender>), NotifyDiscoveryResponseReceived(Box), } @@ -81,7 +42,7 @@ impl Display for DhtDiscoveryRequest { fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> { use DhtDiscoveryRequest::*; match self { - DiscoverPeer(boxed) => write!(f, "DiscoverPeer({})", boxed.0), + DiscoverPeer(public_key, _) => write!(f, "DiscoverPeer({})", public_key), NotifyDiscoveryResponseReceived(discovery_resp) => { write!(f, "NotifyDiscoveryResponseReceived({:#?})", discovery_resp) }, @@ -103,21 +64,25 @@ impl DhtDiscoveryRequester { } } - pub async fn discover_peer( - &mut self, - dest_public_key: Box, - dest_node_id: Option, - destination: NodeDestination, - ) -> Result - { + /// Initiate a peer discovery + /// + /// ## Arguments + /// - `dest_public_key` - The public key of he recipient used to create a shared ECDH key which in turn is used to + /// encrypt the discovery message + /// - `destination` - The `NodeDestination` to use in the DhtHeader when sending a discovery message. + /// - `Unknown` destination will maintain complete privacy, the trade off is that discovery needs to propagate + /// the entire network to reach the destination and so may take longer + /// - `NodeId` Instruct propagation nodes to direct the message to peers closer to the given NodeId. The `NodeId` + /// may be directed to a region close to the real destination (somewhat private) or directed at a particular + /// node (not private) + /// - `PublicKey` if any node on the network knows this public key, the message will be directed to that node. + /// This sacrifices privacy for more efficient discovery in terms of network bandwidth and may result in + /// quicker discovery times. + pub async fn discover_peer(&mut self, dest_public_key: Box) -> Result { let (reply_tx, reply_rx) = oneshot::channel(); - let request = DiscoverPeerRequest { - dest_public_key, - dest_node_id, - destination, - }; + self.sender - .send(DhtDiscoveryRequest::DiscoverPeer(Box::new((request, reply_tx)))) + .send(DhtDiscoveryRequest::DiscoverPeer(dest_public_key, reply_tx)) .await?; time::timeout( diff --git a/comms/dht/src/discovery/service.rs b/comms/dht/src/discovery/service.rs index 6aea69ebce..3d0a1661d8 100644 --- a/comms/dht/src/discovery/service.rs +++ b/comms/dht/src/discovery/service.rs @@ -21,11 +21,8 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::{ - discovery::{ - requester::{DhtDiscoveryRequest, DiscoverPeerRequest}, - DhtDiscoveryError, - }, - envelope::{DhtMessageType, NodeDestination}, + discovery::{requester::DhtDiscoveryRequest, DhtDiscoveryError}, + envelope::DhtMessageType, outbound::{OutboundEncryption, OutboundMessageRequester, SendMessageParams}, proto::dht::{DiscoveryMessage, DiscoveryResponseMessage}, DhtConfig, @@ -146,11 +143,10 @@ impl DhtDiscoveryService { async fn handle_request(&mut self, request: DhtDiscoveryRequest) { use DhtDiscoveryRequest::*; match request { - DiscoverPeer(boxed) => { - let (request, reply_tx) = *boxed; + DiscoverPeer(dest_pubkey, reply_tx) => { log_if_error!( target: LOG_TARGET, - self.initiate_peer_discovery(request, reply_tx).await, + self.initiate_peer_discovery(dest_pubkey, reply_tx).await, "Failed to initiate a discovery request because '{error}'", ); }, @@ -183,16 +179,9 @@ impl DhtDiscoveryService { peer", peer.node_id.short_str() ); - // Attempt to discover them - let request = DiscoverPeerRequest { - dest_public_key: Box::new(peer.public_key), - // TODO: This should be the node region, not the node id - dest_node_id: Some(peer.node_id), - destination: Default::default(), - }; // Don't need to be notified for this discovery let (reply_tx, _) = oneshot::channel(); - if let Err(err) = self.initiate_peer_discovery(request, reply_tx).await { + if let Err(err) = self.initiate_peer_discovery(Box::new(peer.public_key), reply_tx).await { error!(target: LOG_TARGET, "Error sending discovery message: {:?}", err); } } @@ -381,13 +370,12 @@ impl DhtDiscoveryService { async fn initiate_peer_discovery( &mut self, - discovery_request: DiscoverPeerRequest, + dest_pubkey: Box, reply_tx: oneshot::Sender>, ) -> Result<(), DhtDiscoveryError> { let nonce = OsRng.next_u64(); - let public_key = discovery_request.dest_public_key.clone(); - self.send_discover(nonce, discovery_request).await?; + self.send_discover(nonce, dest_pubkey.clone()).await?; let inflight_count = self.inflight_discoveries.len(); @@ -406,7 +394,7 @@ impl DhtDiscoveryService { // Add the new inflight request. self.inflight_discoveries - .insert(nonce, DiscoveryRequestState::new(public_key, reply_tx)); + .insert(nonce, DiscoveryRequestState::new(dest_pubkey, reply_tx)); trace!( target: LOG_TARGET, @@ -420,25 +408,9 @@ impl DhtDiscoveryService { async fn send_discover( &mut self, nonce: u64, - discovery_request: DiscoverPeerRequest, + dest_public_key: Box, ) -> Result<(), DhtDiscoveryError> { - let DiscoverPeerRequest { - dest_node_id, - dest_public_key, - destination, - } = discovery_request; - - // If the destination node is is known, send to the closest peers we know. Otherwise... - let network_location_node_id = dest_node_id - .or_else(|| match &destination { - // ... if the destination is undisclosed or a public key, send discover to our closest peers - NodeDestination::Unknown | NodeDestination::PublicKey(_) => Some(self.node_identity.node_id().clone()), - // otherwise, send it to the closest peers to the given NodeId destination we know - NodeDestination::NodeId(node_id) => Some(*node_id.clone()), - }) - .expect("cannot fail"); - let discover_msg = DiscoveryMessage { node_id: self.node_identity.node_id().to_vec(), addresses: vec![self.node_identity.public_address().to_string()], @@ -447,19 +419,13 @@ impl DhtDiscoveryService { }; debug!( target: LOG_TARGET, - "Sending Discovery message for Node Id: {}", destination + "Sending Discovery message for peer public key '{}'", dest_public_key ); self.outbound_requester .send_message_no_header( SendMessageParams::new() - .closest( - network_location_node_id, - self.config.num_neighbouring_nodes, - Vec::new(), - PeerFeatures::empty(), - ) - .with_destination(destination) + .neighbours_include_clients(Vec::new()) .with_encryption(OutboundEncryption::EncryptFor(dest_public_key)) .with_dht_message_type(DhtMessageType::Discovery) .finish(), @@ -512,7 +478,7 @@ mod test { rt.spawn(service.run()); let dest_public_key = Box::new(CommsPublicKey::default()); - let result = rt.block_on(requester.discover_peer(dest_public_key.clone(), None, NodeDestination::Unknown)); + let result = rt.block_on(requester.discover_peer(dest_public_key.clone())); assert!(result.unwrap_err().is_timeout()); diff --git a/comms/dht/src/envelope.rs b/comms/dht/src/envelope.rs index ae1d0d6893..8e58296d59 100644 --- a/comms/dht/src/envelope.rs +++ b/comms/dht/src/envelope.rs @@ -20,7 +20,6 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::{consts::DHT_ENVELOPE_HEADER_VERSION, proto::envelope::DhtOrigin}; use bitflags::bitflags; use derive_error::Error; use serde::{Deserialize, Serialize}; @@ -29,11 +28,12 @@ use std::{ fmt, fmt::Display, }; -use tari_comms::{peer_manager::NodeId, types::CommsPublicKey, utils::signature}; -use tari_crypto::tari_utilities::{hex::Hex, ByteArray, ByteArrayError}; +use tari_comms::{peer_manager::NodeId, types::CommsPublicKey}; +use tari_crypto::tari_utilities::{ByteArray, ByteArrayError}; // Re-export applicable protos pub use crate::proto::envelope::{dht_header::Destination, DhtEnvelope, DhtHeader, DhtMessageType, Network}; +use bytes::Bytes; #[derive(Debug, Error)] pub enum DhtMessageError { @@ -47,6 +47,8 @@ pub enum DhtMessageError { InvalidNetwork, /// Invalid or unrecognised DHT message flags InvalidMessageFlags, + /// Invalid ephemeral public key + InvalidEphemeralPublicKey, /// Header was omitted from the message HeaderOmitted, } @@ -103,71 +105,27 @@ impl DhtMessageType { } } -#[derive(Clone, PartialEq, Eq)] -pub struct DhtMessageOrigin { - pub public_key: CommsPublicKey, - pub signature: Vec, -} - -impl fmt::Debug for DhtMessageOrigin { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.debug_struct("DhtMessageOrigin") - .field("public_key", &self.public_key.to_hex()) - .field("signature", &self.signature.to_hex()) - .finish() - } -} - -impl TryFrom for DhtMessageOrigin { - type Error = DhtMessageError; - - fn try_from(value: DhtOrigin) -> Result { - Ok(Self { - public_key: CommsPublicKey::from_bytes(&value.public_key).map_err(|_| DhtMessageError::InvalidOrigin)?, - signature: value.signature, - }) - } -} - -impl From for DhtOrigin { - fn from(value: DhtMessageOrigin) -> Self { - Self { - public_key: value.public_key.to_vec(), - signature: value.signature, - } - } -} - /// This struct mirrors the protobuf version of DhtHeader but is more ergonomic to work with. /// It is preferable to not to expose the generated prost structs publicly. #[derive(Clone, Debug, PartialEq, Eq)] pub struct DhtMessageHeader { pub version: u32, pub destination: NodeDestination, - /// Origin of the message. This can refer to the same peer that sent the message - /// or another peer if the message should be forwarded. - pub origin: Option, + /// Encoded DhtOrigin. This can refer to the same peer that sent the message + /// or another peer if the message is being propagated. + pub origin_mac: Vec, + pub ephemeral_public_key: Option, pub message_type: DhtMessageType, pub network: Network, pub flags: DhtMessageFlags, } impl DhtMessageHeader { - pub fn new( - destination: NodeDestination, - message_type: DhtMessageType, - origin: Option, - network: Network, - flags: DhtMessageFlags, - ) -> Self - { - Self { - version: DHT_ENVELOPE_HEADER_VERSION, - destination, - origin, - message_type, - network, - flags, + pub fn is_valid(&self) -> bool { + if self.flags.contains(DhtMessageFlags::ENCRYPTED) { + !self.origin_mac.is_empty() && self.ephemeral_public_key.is_some() + } else { + true } } } @@ -176,8 +134,8 @@ impl Display for DhtMessageHeader { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { write!( f, - "DhtMessageHeader (Dest:{}, Origin:{:?}, Type:{:?}, Network:{:?}, Flags:{:?})", - self.destination, self.origin, self.message_type, self.network, self.flags + "DhtMessageHeader (Dest:{}, Type:{:?}, Network:{:?}, Flags:{:?})", + self.destination, self.message_type, self.network, self.flags ) } } @@ -193,15 +151,20 @@ impl TryFrom for DhtMessageHeader { .map(Option::unwrap) .ok_or_else(|| DhtMessageError::InvalidDestination)?; - let origin = match header.origin { - Some(origin) => Some(origin.try_into()?), - None => None, + let ephemeral_public_key = if header.ephemeral_public_key.is_empty() { + None + } else { + Some( + CommsPublicKey::from_bytes(&header.ephemeral_public_key) + .map_err(|_| DhtMessageError::InvalidEphemeralPublicKey)?, + ) }; Ok(Self { version: header.version, destination, - origin, + origin_mac: header.origin_mac, + ephemeral_public_key, message_type: DhtMessageType::from_i32(header.message_type) .ok_or_else(|| DhtMessageError::InvalidMessageType)?, network: Network::from_i32(header.network).ok_or_else(|| DhtMessageError::InvalidNetwork)?, @@ -225,7 +188,12 @@ impl From for DhtHeader { fn from(header: DhtMessageHeader) -> Self { Self { version: header.version, - origin: header.origin.map(Into::into), + ephemeral_public_key: header + .ephemeral_public_key + .as_ref() + .map(ByteArray::to_vec) + .unwrap_or_else(Vec::new), + origin_mac: header.origin_mac, destination: Some(header.destination.into()), message_type: header.message_type as i32, network: header.network as i32, @@ -235,44 +203,12 @@ impl From for DhtHeader { } impl DhtEnvelope { - pub fn new(header: DhtHeader, body: Vec) -> Self { + pub fn new(header: DhtHeader, body: Bytes) -> Self { Self { header: Some(header), - body, + body: body.to_vec(), } } - - /// Returns true if the header and origin are present, otherwise false - pub fn has_origin(&self) -> bool { - self.header.as_ref().map(|h| h.origin.is_some()).unwrap_or(false) - } - - /// Verifies the origin signature and returns true if it is valid. - /// - /// This method panics if called on an envelope without an origin. This should be checked before calling this - /// function by using the `DhtEnvelope::has_origin` method - pub fn is_origin_signature_valid(&self) -> bool { - self.header - .as_ref() - .and_then(|header| { - let origin = header - .origin - .as_ref() - .expect("call is_origin_signature_valid on envelope without origin"); - - CommsPublicKey::from_bytes(&origin.public_key) - .map(|pk| (pk, &origin.signature)) - .ok() - }) - .map(|(origin_public_key, origin_signature)| { - match signature::verify(&origin_public_key, origin_signature, &self.body) { - Ok(is_valid) => is_valid, - // error means that the signature could not deserialize, so is invalid - Err(_) => false, - } - }) - .unwrap_or(false) - } } /// Represents the ways a destination node can be represented. diff --git a/comms/dht/src/inbound/decryption.rs b/comms/dht/src/inbound/decryption.rs index ade9cb1cc7..18797eb380 100644 --- a/comms/dht/src/inbound/decryption.rs +++ b/comms/dht/src/inbound/decryption.rs @@ -22,18 +22,43 @@ use crate::{ crypt, - envelope::DhtMessageFlags, + envelope::{DhtMessageFlags, DhtMessageHeader}, inbound::message::{DecryptedDhtMessage, DhtInboundMessage}, + proto::envelope::OriginMac, }; +use derive_error::Error; use futures::{task::Context, Future}; use log::*; use prost::Message; use std::{sync::Arc, task::Poll}; -use tari_comms::{message::EnvelopeBody, peer_manager::NodeIdentity, pipeline::PipelineError}; +use tari_comms::{ + message::EnvelopeBody, + peer_manager::NodeIdentity, + pipeline::PipelineError, + types::CommsPublicKey, + utils::signature, +}; +use tari_crypto::tari_utilities::ByteArray; use tower::{layer::Layer, Service, ServiceExt}; const LOG_TARGET: &str = "comms::middleware::decryption"; +#[derive(Error, Debug)] +enum DecryptionError { + /// Failed to validate origin MAC signature + OriginMacInvalidSignature, + /// Origin MAC contained an invalid public key + OriginMacInvalidPublicKey, + /// Origin MAC not provided for encrypted message + OriginMacNotProvided, + /// Failed to decrypt origin MAC + OriginMacDecryptedFailed, + /// Failed to decode clear-text origin MAC + OriginMacClearTextDecodeFailed, + /// Failed to decrypt message body + MessageBodyDecryptionFailed, +} + /// This layer is responsible for attempting to decrypt inbound messages. pub struct DecryptionLayer { node_identity: Arc, @@ -96,20 +121,42 @@ where S: Service ) -> Result<(), PipelineError> { let dht_header = &message.dht_header; + if !dht_header.flags.contains(DhtMessageFlags::ENCRYPTED) { return Self::success_not_encrypted(next_service, message).await; } - let origin = dht_header - .origin + let e_pk = dht_header + .ephemeral_public_key .as_ref() - // TODO: #banheuristics - this should not have been sent/propagated - .ok_or_else(|| "Message origin field is required for encrypted messages")?; + // TODO: #banheuristic - encrypted message sent without ephemeral public key + .ok_or("Ephemeral public key not provided for encrypted message")?; + + let shared_secret = crypt::generate_ecdh_secret(node_identity.secret_key(), e_pk); - debug!(target: LOG_TARGET, "Attempting to decrypt message"); - let shared_secret = crypt::generate_ecdh_secret(node_identity.secret_key(), &origin.public_key); - match crypt::decrypt(&shared_secret, &message.body) { - Ok(decrypted) => Self::decryption_succeeded(next_service, &node_identity, message, &decrypted).await, + // Decrypt and verify the origin + let authenticated_origin = match Self::attempt_decrypt_origin_mac(&shared_secret, dht_header) { + Ok(origin) => { + // If this fails, discard the message because we decrypted and deserialized the message with our shared + // ECDH secret but the message could not be authenticated + Self::authenticate_origin_mac(&origin, &message.body).map_err(PipelineError::from_debug)? + }, + Err(err) => { + debug!(target: LOG_TARGET, "Unable to decrypt message origin: {}", err); + return Self::decryption_failed(next_service, &node_identity, message).await; + }, + }; + + debug!( + target: LOG_TARGET, + "Attempting to decrypt message body from origin public key '{}'", authenticated_origin + ); + match Self::attempt_decrypt_message_body(&shared_secret, &message.body) { + Ok(message_body) => { + debug!(target: LOG_TARGET, "Message successfully decrypted"); + let msg = DecryptedDhtMessage::succeeded(message_body, Some(authenticated_origin), message); + next_service.oneshot(msg).await + }, Err(err) => { debug!(target: LOG_TARGET, "Unable to decrypt message: {}", err); Self::decryption_failed(next_service, &node_identity, message).await @@ -117,57 +164,81 @@ where S: Service } } - async fn decryption_succeeded( - next_service: S, - node_identity: &NodeIdentity, - message: DhtInboundMessage, - decrypted: &[u8], - ) -> Result<(), PipelineError> + fn attempt_decrypt_origin_mac( + shared_secret: &CommsPublicKey, + dht_header: &DhtMessageHeader, + ) -> Result { + let encrypted_origin_mac = Some(&dht_header.origin_mac) + .filter(|b| !b.is_empty()) + // TODO: #banheuristic - this should not have been sent/propagated + .ok_or_else(|| DecryptionError::OriginMacNotProvided)?; + let decrypted_bytes = crypt::decrypt(shared_secret, encrypted_origin_mac) + .map_err(|_| DecryptionError::OriginMacDecryptedFailed)?; + OriginMac::decode(decrypted_bytes.as_slice()).map_err(|_| DecryptionError::OriginMacDecryptedFailed) + } + + fn authenticate_origin_mac(origin: &OriginMac, body: &[u8]) -> Result { + let public_key = + CommsPublicKey::from_bytes(&origin.public_key).map_err(|_| DecryptionError::OriginMacInvalidPublicKey)?; + if signature::verify(&public_key, &origin.signature, body).unwrap_or(false) { + Ok(public_key) + } else { + Err(DecryptionError::OriginMacInvalidSignature) + } + } + + fn attempt_decrypt_message_body( + shared_secret: &CommsPublicKey, + message_body: &[u8], + ) -> Result + { + let decrypted = + crypt::decrypt(shared_secret, message_body).map_err(|_| DecryptionError::MessageBodyDecryptionFailed)?; // Deserialization into an EnvelopeBody is done here to determine if the // decryption produced valid bytes or not. - let result = EnvelopeBody::decode(decrypted).and_then(|body| { - // Check if we received a body length of zero - // - // In addition to a peer sending a zero-length EnvelopeBody, decoding can erroneously succeed - // if the decrypted bytes happen to be valid protobuf encoding. This is very possible and - // the decrypt_inbound_fail test below _will_ sporadically fail without the following check. - // This is because proto3 will set fields to their default value if they don't exist in a valid encoding. - // - // For the parts of EnvelopeBody to be erroneously populated with bytes, all of these - // conditions would have to be true: - // 1. field type == 2 (length-delimited) - // 2. field number == 1 - // 3. the subsequent byte(s) would have to be varint-encoded length which does not overflow - // 4. the rest of the bytes would have to be valid protobuf encoding - // - // The chance of this happening is extremely negligible. - if body.is_empty() { - return Err(prost::DecodeError::new("EnvelopeBody has no parts")); - } - Ok(body) - }); - match result { - Ok(deserialized) => { - debug!(target: LOG_TARGET, "Message successfully decrypted"); - let msg = DecryptedDhtMessage::succeeded(deserialized, message); - next_service.oneshot(msg).await - }, - Err(err) => { - debug!(target: LOG_TARGET, "Unable to deserialize message: {}", err); - Self::decryption_failed(next_service, &node_identity, message).await - }, - } + EnvelopeBody::decode(decrypted.as_slice()) + .and_then(|body| { + // Check if we received a body length of zero + // + // In addition to a peer sending a zero-length EnvelopeBody, decoding can erroneously succeed + // if the decrypted bytes happen to be valid protobuf encoding. This is very possible and + // the decrypt_inbound_fail test below _will_ sporadically fail without the following check. + // This is because proto3 will set fields to their default value if they don't exist in a valid + // encoding. + // + // For the parts of EnvelopeBody to be erroneously populated with bytes, all of these + // conditions would have to be true: + // 1. field type == 2 (length-delimited) + // 2. field number == 1 + // 3. the subsequent byte(s) would have to be varint-encoded length which does not overflow + // 4. the rest of the bytes would have to be valid protobuf encoding + // + // The chance of this happening is extremely negligible. + if body.is_empty() { + return Err(prost::DecodeError::new("EnvelopeBody has no parts")); + } + Ok(body) + }) + .map_err(|_| DecryptionError::MessageBodyDecryptionFailed) } async fn success_not_encrypted(next_service: S, message: DhtInboundMessage) -> Result<(), PipelineError> { + let authenticated_pk = if message.dht_header.origin_mac.is_empty() { + None + } else { + let origin_mac = OriginMac::decode(message.dht_header.origin_mac.as_slice()) + .map_err(|_| PipelineError::from_debug(DecryptionError::OriginMacClearTextDecodeFailed))?; + Some(Self::authenticate_origin_mac(&origin_mac, &message.body).map_err(PipelineError::from_debug)?) + }; + match EnvelopeBody::decode(message.body.as_slice()) { Ok(deserialized) => { debug!( target: LOG_TARGET, "Message is not encrypted. Passing onto next service" ); - let msg = DecryptedDhtMessage::succeeded(deserialized, message); + let msg = DecryptedDhtMessage::succeeded(deserialized, authenticated_pk, message); next_service.oneshot(msg).await }, Err(err) => { @@ -194,8 +265,9 @@ where S: Service // TODO: #banheuristic - the origin of this message sent this node a message we could not decrypt warn!( target: LOG_TARGET, - "Received message from peer '{}' that is destined for that peer. Discarding message", - message.dht_header.origin.as_ref().expect("already checked").public_key + "Received message from peer '{}' that is destined for this node that could not be decrypted. \ + Discarding message", + message.source_peer.node_id ); return Err( "Message rejected because this node could not decrypt a message that was addressed to it".into(), @@ -241,10 +313,13 @@ mod test { let node_identity = make_node_identity(); let mut service = DecryptionService::new(inner, Arc::clone(&node_identity)); - let plain_text_msg = wrap_in_envelope_body!(Vec::new()).unwrap(); - let secret_key = crypt::generate_ecdh_secret(node_identity.secret_key(), node_identity.public_key()); - let encrypted = crypt::encrypt(&secret_key, &plain_text_msg.to_encoded_bytes().unwrap()).unwrap(); - let inbound_msg = make_dht_inbound_message(&node_identity, encrypted, DhtMessageFlags::ENCRYPTED); + let plain_text_msg = wrap_in_envelope_body!(b"Secret plans".to_vec()); + let inbound_msg = make_dht_inbound_message( + &node_identity, + plain_text_msg.to_encoded_bytes(), + DhtMessageFlags::ENCRYPTED, + true, + ); block_on(service.call(inbound_msg)).unwrap(); let decrypted = result.lock().unwrap().take().unwrap(); @@ -262,14 +337,16 @@ mod test { let node_identity = make_node_identity(); let mut service = DecryptionService::new(inner, Arc::clone(&node_identity)); - let nonsense = "Cannot Decrypt this".as_bytes().to_vec(); - let inbound_msg = make_dht_inbound_message(&node_identity, nonsense.clone(), DhtMessageFlags::ENCRYPTED); + let some_secret = "Super secret message".as_bytes().to_vec(); + let some_other_node_identity = make_node_identity(); + let inbound_msg = + make_dht_inbound_message(&some_other_node_identity, some_secret, DhtMessageFlags::ENCRYPTED, true); - block_on(service.call(inbound_msg)).unwrap(); + block_on(service.call(inbound_msg.clone())).unwrap(); let decrypted = result.lock().unwrap().take().unwrap(); assert_eq!(decrypted.decryption_succeeded(), false); - assert_eq!(decrypted.decryption_result.unwrap_err(), nonsense); + assert_eq!(decrypted.decryption_result.unwrap_err(), inbound_msg.body); } #[test] @@ -283,7 +360,8 @@ mod test { let mut service = DecryptionService::new(inner, Arc::clone(&node_identity)); let nonsense = "Cannot Decrypt this".as_bytes().to_vec(); - let mut inbound_msg = make_dht_inbound_message(&node_identity, nonsense.clone(), DhtMessageFlags::ENCRYPTED); + let mut inbound_msg = + make_dht_inbound_message(&node_identity, nonsense.clone(), DhtMessageFlags::ENCRYPTED, true); inbound_msg.dht_header.destination = node_identity.public_key().clone().into(); let err = block_on(service.call(inbound_msg)).unwrap_err(); diff --git a/comms/dht/src/inbound/dedup.rs b/comms/dht/src/inbound/dedup.rs index 25a6253dc5..15f09646fc 100644 --- a/comms/dht/src/inbound/dedup.rs +++ b/comms/dht/src/inbound/dedup.rs @@ -26,7 +26,6 @@ use futures::{task::Context, Future}; use log::*; use std::task::Poll; use tari_comms::{pipeline::PipelineError, types::Challenge}; -use tari_crypto::tari_utilities::hex::Hex; use tower::{layer::Layer, Service, ServiceExt}; const LOG_TARGET: &str = "comms::dht::dedup"; @@ -85,13 +84,7 @@ where S: Service { warn!( target: LOG_TARGET, - "Received duplicate message from peer {} (origin={:?}). Message discarded.", - message.source_peer.node_id, - message - .dht_header - .origin - .map(|o| o.public_key.to_hex()) - .unwrap_or_else(|| "".to_string()), + "Received duplicate message from peer {}. Message discarded.", message.source_peer.node_id, ); return Ok(()); } @@ -148,7 +141,7 @@ mod test { assert!(dedup.poll_ready(&mut cx).is_ready()); let node_identity = make_node_identity(); - let msg = make_dht_inbound_message(&node_identity, Vec::new(), DhtMessageFlags::empty()); + let msg = make_dht_inbound_message(&node_identity, Vec::new(), DhtMessageFlags::empty(), false); rt.block_on(dedup.call(msg.clone())).unwrap(); assert_eq!(spy.call_count(), 1); diff --git a/comms/dht/src/inbound/deserialize.rs b/comms/dht/src/inbound/deserialize.rs index cf94c2d471..57ac7e39e2 100644 --- a/comms/dht/src/inbound/deserialize.rs +++ b/comms/dht/src/inbound/deserialize.rs @@ -58,51 +58,38 @@ where S: Service + Clon Poll::Ready(Ok(())) } - fn call(&mut self, msg: InboundMessage) -> Self::Future { - Self::deserialize(self.next_service.clone(), msg) - } -} - -impl DhtDeserializeMiddleware -where S: Service -{ - pub async fn deserialize(next_service: S, message: InboundMessage) -> Result<(), PipelineError> { - trace!(target: LOG_TARGET, "Deserializing InboundMessage"); - - let InboundMessage { - source_peer, mut body, .. - } = message; - - match DhtEnvelope::decode(&mut body) { - Ok(dht_envelope) => { - trace!(target: LOG_TARGET, "Deserialization succeeded. Checking signatures"); - if dht_envelope.has_origin() { - if dht_envelope.is_origin_signature_valid() { - trace!(target: LOG_TARGET, "Origin signature validation passed."); - } else { - // TODO: #banheuristic - // The origin signature is not valid, this message should never have been sent - warn!( - target: LOG_TARGET, - "SECURITY: Origin signature verification failed. Discarding message from NodeId {}", - source_peer.node_id - ); - return Ok(()); - } - } - - let inbound_msg = DhtInboundMessage::new( - dht_envelope.header.try_into().map_err(PipelineError::from_debug)?, - source_peer, - dht_envelope.body, - ); - - next_service.oneshot(inbound_msg).await - }, - Err(err) => { - error!(target: LOG_TARGET, "DHT deserialization failed: {}", err); - Err(PipelineError::from_debug(err)) - }, + fn call(&mut self, message: InboundMessage) -> Self::Future { + let next_service = self.next_service.clone(); + async move { + trace!(target: LOG_TARGET, "Deserializing InboundMessage"); + + let InboundMessage { + source_peer, mut body, .. + } = message; + + if body.is_empty() { + return Err(format!("Received empty message from peer '{}'", source_peer) + .as_str() + .into()); + } + + match DhtEnvelope::decode(&mut body) { + Ok(dht_envelope) => { + trace!(target: LOG_TARGET, "Deserialization succeeded."); + + let inbound_msg = DhtInboundMessage::new( + dht_envelope.header.try_into().map_err(PipelineError::from_debug)?, + source_peer, + dht_envelope.body, + ); + + next_service.oneshot(inbound_msg).await + }, + Err(err) => { + error!(target: LOG_TARGET, "DHT deserialization failed: {}", err); + Err(PipelineError::from_debug(err)) + }, + } } } } @@ -144,10 +131,10 @@ mod test { assert!(deserialize.poll_ready(&mut cx).is_ready()); let node_identity = make_node_identity(); - let dht_envelope = make_dht_envelope(&node_identity, b"A".to_vec(), DhtMessageFlags::empty()); + let dht_envelope = make_dht_envelope(&node_identity, b"A".to_vec(), DhtMessageFlags::empty(), false); block_on(deserialize.call(make_comms_inbound_message( &node_identity, - dht_envelope.to_encoded_bytes().unwrap().into(), + dht_envelope.to_encoded_bytes().into(), ))) .unwrap(); diff --git a/comms/dht/src/inbound/dht_handler/task.rs b/comms/dht/src/inbound/dht_handler/task.rs index 9a3dc372d1..43601edea2 100644 --- a/comms/dht/src/inbound/dht_handler/task.rs +++ b/comms/dht/src/inbound/dht_handler/task.rs @@ -175,20 +175,20 @@ where S: Service decryption_result, dht_header, source_peer, + authenticated_origin, .. } = message; - let origin = dht_header - .origin - .as_ref() - .ok_or_else(|| DhtInboundError::OriginRequired("Origin is required for this message type".to_string()))?; + let authenticated_pk = authenticated_origin.ok_or_else(|| { + DhtInboundError::OriginRequired("Authenticated origin is required for this message type".to_string()) + })?; - if &origin.public_key == self.node_identity.public_key() { + if &authenticated_pk == self.node_identity.public_key() { trace!(target: LOG_TARGET, "Received our own join message. Discarding it."); return Ok(()); } - trace!(target: LOG_TARGET, "Received Join Message from {}", origin.public_key); + trace!(target: LOG_TARGET, "Received Join Message from {}", authenticated_pk); let body = decryption_result.expect("already checked that this message decrypted successfully"); let join_msg = body @@ -205,11 +205,11 @@ where S: Service return Err(DhtInboundError::InvalidAddresses); } - let node_id = self.validate_raw_node_id(&origin.public_key, &join_msg.node_id)?; + let node_id = self.validate_raw_node_id(&authenticated_pk, &join_msg.node_id)?; let origin_peer = self .add_or_update_peer( - &origin.public_key, + &authenticated_pk, node_id, addresses, PeerFeatures::from_bits_truncate(join_msg.peer_features), @@ -262,12 +262,12 @@ where S: Service .closest( origin_peer.node_id, self.config.num_neighbouring_nodes, - vec![origin.public_key.clone(), source_peer.public_key.clone()], + vec![authenticated_pk, source_peer.public_key.clone()], PeerFeatures::MESSAGE_PROPAGATION, ) .with_dht_header(dht_header) .finish(), - body.to_encoded_bytes()?, + body.to_encoded_bytes(), ) .await?; @@ -300,10 +300,9 @@ where S: Service target: LOG_TARGET, "Received Discover Response Message from {}", message - .dht_header - .origin + .authenticated_origin .as_ref() - .map(|o| o.public_key.to_hex()) + .map(|pk| pk.to_hex()) .unwrap_or_else(|| "".to_string()) ); @@ -331,13 +330,13 @@ where S: Service .decode_part::(0)? .ok_or_else(|| DhtInboundError::InvalidMessageBody)?; - let origin = message.dht_header.origin.ok_or_else(|| { + let authenticated_pk = message.authenticated_origin.ok_or_else(|| { DhtInboundError::OriginRequired("Origin header required for Discovery message".to_string()) })?; info!( target: LOG_TARGET, - "Received discovery message from '{}'", origin.public_key, + "Received discovery message from '{}'", authenticated_pk ); let addresses = discover_msg @@ -350,10 +349,10 @@ where S: Service return Err(DhtInboundError::InvalidAddresses); } - let node_id = self.validate_raw_node_id(&origin.public_key, &discover_msg.node_id)?; + let node_id = self.validate_raw_node_id(&authenticated_pk, &discover_msg.node_id)?; let origin_peer = self .add_or_update_peer( - &origin.public_key, + &authenticated_pk, node_id, addresses, PeerFeatures::from_bits_truncate(discover_msg.peer_features), @@ -370,7 +369,7 @@ where S: Service } // Send the origin the current nodes latest contact info - self.send_discovery_response(origin.public_key, discover_msg.nonce) + self.send_discovery_response(origin_peer.public_key, discover_msg.nonce) .await?; Ok(()) diff --git a/comms/dht/src/inbound/message.rs b/comms/dht/src/inbound/message.rs index 6224bf340c..987c8cea76 100644 --- a/comms/dht/src/inbound/message.rs +++ b/comms/dht/src/inbound/message.rs @@ -68,17 +68,24 @@ pub struct DecryptedDhtMessage { /// The _connected_ peer which sent or forwarded this message. This may not be the peer /// which created this message. pub source_peer: Arc, + pub authenticated_origin: Option, pub dht_header: DhtMessageHeader, pub decryption_result: Result>, } impl DecryptedDhtMessage { - pub fn succeeded(decrypted_message: EnvelopeBody, message: DhtInboundMessage) -> Self { + pub fn succeeded( + message_body: EnvelopeBody, + authenticated_origin: Option, + message: DhtInboundMessage, + ) -> Self + { Self { version: message.version, source_peer: message.source_peer, + authenticated_origin, dht_header: message.dht_header, - decryption_result: Ok(decrypted_message), + decryption_result: Ok(message_body), } } @@ -86,6 +93,7 @@ impl DecryptedDhtMessage { Self { version: message.version, source_peer: message.source_peer, + authenticated_origin: None, dht_header: message.dht_header, decryption_result: Err(message.body), } @@ -115,12 +123,8 @@ impl DecryptedDhtMessage { self.decryption_result.is_err() } - pub fn origin_public_key(&self) -> &CommsPublicKey { - self.dht_header - .origin - .as_ref() - .map(|o| &o.public_key) - .unwrap_or(&self.source_peer.public_key) + pub fn authenticated_origin(&self) -> Option<&CommsPublicKey> { + self.authenticated_origin.as_ref() } /// Returns true if the message is or was encrypted by @@ -128,11 +132,11 @@ impl DecryptedDhtMessage { self.dht_header.flags.contains(DhtMessageFlags::ENCRYPTED) } - pub fn has_origin(&self) -> bool { - self.dht_header.origin.is_some() + pub fn has_origin_mac(&self) -> bool { + !self.dht_header.origin_mac.is_empty() } - pub fn body_size(&self) -> usize { + pub fn body_len(&self) -> usize { match self.decryption_result.as_ref() { Ok(b) => b.total_size(), Err(b) => b.len(), diff --git a/comms/dht/src/inbound/validate.rs b/comms/dht/src/inbound/validate.rs index c4fa1fce28..b6132af821 100644 --- a/comms/dht/src/inbound/validate.rs +++ b/comms/dht/src/inbound/validate.rs @@ -20,19 +20,11 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::{ - inbound::DhtInboundMessage, - outbound::{OutboundMessageRequester, SendMessageParams}, - proto::{ - dht::{RejectMessage, RejectMessageReason}, - envelope::{DhtMessageType, Network}, - }, -}; +use crate::{inbound::DhtInboundMessage, proto::envelope::Network}; use futures::{task::Context, Future}; use log::*; use std::task::Poll; -use tari_comms::{message::MessageExt, pipeline::PipelineError}; -use tari_crypto::tari_utilities::ByteArray; +use tari_comms::pipeline::PipelineError; use tower::{layer::Layer, Service, ServiceExt}; const LOG_TARGET: &str = "comms::dht::validate"; @@ -45,15 +37,13 @@ const LOG_TARGET: &str = "comms::dht::validate"; pub struct ValidateMiddleware { next_service: S, target_network: Network, - outbound_requester: OutboundMessageRequester, } impl ValidateMiddleware { - pub fn new(service: S, target_network: Network, outbound_requester: OutboundMessageRequester) -> Self { + pub fn new(service: S, target_network: Network) -> Self { Self { next_service: service, target_network, - outbound_requester, } } } @@ -70,76 +60,40 @@ where S: Service + Clon Poll::Ready(Ok(())) } - fn call(&mut self, msg: DhtInboundMessage) -> Self::Future { - Self::process_message( - self.next_service.clone(), - self.target_network, - self.outbound_requester.clone(), - msg, - ) - } -} - -impl ValidateMiddleware -where S: Service -{ - pub async fn process_message( - next_service: S, - target_network: Network, - mut outbound_requester: OutboundMessageRequester, - message: DhtInboundMessage, - ) -> Result<(), PipelineError> - { - trace!( - target: LOG_TARGET, - "Checking the message target network is '{:?}'", - target_network - ); - if message.dht_header.network == target_network { - next_service.oneshot(message).await?; - } else { - debug!( + fn call(&mut self, message: DhtInboundMessage) -> Self::Future { + let next_service = self.next_service.clone(); + let target_network = self.target_network; + async move { + trace!( target: LOG_TARGET, - "Message is for another network (want = {:?} got = {:?}). Explicitly rejecting the message.", - target_network, - message.dht_header.network + "Checking the message target network is '{:?}'", + target_network ); - outbound_requester - .send_raw( - SendMessageParams::new() - .direct_public_key(message.source_peer.public_key.clone()) - .with_dht_message_type(DhtMessageType::RejectMsg) - .finish(), - RejectMessage { - signature: message - .dht_header - .origin - .map(|o| o.public_key.to_vec()) - .unwrap_or_default(), - reason: RejectMessageReason::UnsupportedNetwork as i32, - } - .to_encoded_bytes() - .map_err(PipelineError::from_debug)?, - ) - .await - .map_err(PipelineError::from_debug)?; - } - Ok(()) + if message.dht_header.network == target_network && message.dht_header.is_valid() { + next_service.oneshot(message).await?; + } else { + debug!( + target: LOG_TARGET, + "Message is for another network (want = {:?} got = {:?}) or message header is invalid. Discarding \ + the message.", + target_network, + message.dht_header.network + ); + } + + Ok(()) + } } } pub struct ValidateLayer { target_network: Network, - outbound_requester: OutboundMessageRequester, } impl ValidateLayer { - pub fn new(target_network: Network, outbound_requester: OutboundMessageRequester) -> Self { - Self { - target_network, - outbound_requester, - } + pub fn new(target_network: Network) -> Self { + Self { target_network } } } @@ -147,7 +101,7 @@ impl Layer for ValidateLayer { type Service = ValidateMiddleware; fn layer(&self, service: S) -> Self::Service { - ValidateMiddleware::new(service, self.target_network, self.outbound_requester.clone()) + ValidateMiddleware::new(service, self.target_network) } } @@ -155,8 +109,7 @@ impl Layer for ValidateLayer { mod test { use super::*; use crate::{ - envelope::{DhtMessageFlags, DhtMessageType}, - outbound::mock::create_outbound_service_mock, + envelope::DhtMessageFlags, test_utils::{make_dht_inbound_message, make_node_identity, service_spy}, }; use tari_test_utils::panic_context; @@ -167,18 +120,13 @@ mod test { let mut rt = Runtime::new().unwrap(); let spy = service_spy(); - let (out_requester, mock) = create_outbound_service_mock(1); - let mock_state = mock.get_state(); - rt.spawn(mock.run()); - - let mut validate = - ValidateLayer::new(Network::LocalTest, out_requester).layer(spy.to_service::()); + let mut validate = ValidateLayer::new(Network::LocalTest).layer(spy.to_service::()); panic_context!(cx); assert!(validate.poll_ready(&mut cx).is_ready()); let node_identity = make_node_identity(); - let mut msg = make_dht_inbound_message(&node_identity, Vec::new(), DhtMessageFlags::empty()); + let mut msg = make_dht_inbound_message(&node_identity, Vec::new(), DhtMessageFlags::empty(), false); msg.dht_header.network = Network::MainNet; rt.block_on(validate.call(msg.clone())).unwrap(); @@ -188,17 +136,5 @@ mod test { rt.block_on(validate.call(msg.clone())).unwrap(); assert_eq!(spy.call_count(), 1); - - let calls = mock_state.take_calls(); - assert_eq!(calls.len(), 1); - let params = calls[0].0.clone(); - assert_eq!(params.dht_message_type, DhtMessageType::RejectMsg); - assert_eq!( - params.broadcast_strategy.direct_public_key().unwrap(), - node_identity.public_key() - ); - - // Drop validate so that the mock will stop running - drop(validate); } } diff --git a/comms/dht/src/outbound/broadcast.rs b/comms/dht/src/outbound/broadcast.rs index 33bc492f0b..c0b1672a07 100644 --- a/comms/dht/src/outbound/broadcast.rs +++ b/comms/dht/src/outbound/broadcast.rs @@ -25,7 +25,7 @@ use crate::{ actor::DhtRequester, broadcast_strategy::BroadcastStrategy, discovery::DhtDiscoveryRequester, - envelope::{DhtMessageFlags, DhtMessageHeader, DhtMessageOrigin, NodeDestination}, + envelope::{DhtMessageFlags, DhtMessageHeader, NodeDestination}, outbound::{ message::{DhtOutboundMessage, OutboundEncryption}, message_params::FinalSendMessageParams, @@ -33,6 +33,7 @@ use crate::{ }, proto::envelope::{DhtMessageType, Network}, }; +use bytes::Bytes; use futures::{ channel::oneshot, future, @@ -43,7 +44,8 @@ use futures::{ use log::*; use std::{sync::Arc, task::Poll}; use tari_comms::{ - peer_manager::{NodeId, NodeIdentity, Peer}, + message::MessageTag, + peer_manager::{NodeIdentity, Peer}, pipeline::PipelineError, types::CommsPublicKey, }; @@ -216,7 +218,7 @@ where S: Service async fn handle_send_message( &mut self, params: FinalSendMessageParams, - body: Vec, + body: Bytes, reply_tx: oneshot::Sender, ) -> Result, DhtOutboundError> { @@ -344,24 +346,8 @@ where S: Service dest_public_key ); - // TODO: This works because we know that all non-DAN node IDs are/should be derived from the public key. - // Once the DAN launches, this may not be the case and we'll need to query the blockchain for the node id - let derived_node_id = NodeId::from_key(&*dest_public_key).ok(); - - // TODO: Target a general region instead of the actual destination node id - let regional_destination = derived_node_id - .as_ref() - .map(Clone::clone) - .map(Box::new) - .map(NodeDestination::NodeId) - .unwrap_or_else(|| NodeDestination::Unknown); - // Peer not found, let's try and discover it - match self - .dht_discovery_requester - .discover_peer(dest_public_key, derived_node_id, regional_destination) - .await - { + match self.dht_discovery_requester.discover_peer(dest_public_key).await { // Peer found! Ok(peer) => { if peer.is_banned() { @@ -397,42 +383,28 @@ where S: Service custom_header: Option, extra_flags: DhtMessageFlags, force_origin: bool, - body: Vec, + body: Bytes, ) -> Result, DhtOutboundError> { let dht_flags = encryption.flags() | extra_flags; - // Create a DHT header - let dht_header = custom_header - .or_else(|| { - // The origin is specified if encryption is turned on, otherwise it is not - let origin = if force_origin || encryption.is_encrypt() { - Some(DhtMessageOrigin { - // Origin public key used to identify the origin and verify the signature - public_key: self.node_identity.public_key().clone(), - // Signing will happen later in the pipeline (SerializeMiddleware), left empty to prevent double - // work - signature: Vec::new(), - }) - } else { - None - }; - - Some(DhtMessageHeader::new( - // Final destination for this message - destination, - dht_message_type, - origin, - self.target_network, - dht_flags, - )) - }) - .expect("always Some"); - - // Construct a MessageEnvelope for each recipient + // Construct a DhtOutboundMessage for each recipient let messages = selected_peers .into_iter() - .map(|peer| DhtOutboundMessage::new(peer, dht_header.clone(), encryption.clone(), body.clone())) + .map(|peer| DhtOutboundMessage { + tag: MessageTag::new(), + destination_peer: peer, + destination: destination.clone(), + dht_message_type, + network: self.target_network, + dht_flags, + custom_header: custom_header.clone(), + include_origin: force_origin || encryption.is_encrypt(), + encryption: encryption.clone(), + body: body.clone(), + ephemeral_public_key: None, + origin_mac: None, + }) .collect::>(); Ok(messages) @@ -518,7 +490,7 @@ mod test { rt.block_on(service.call(DhtOutboundRequest::SendMessage( Box::new(SendMessageParams::new().flood().finish()), - "custom_msg".as_bytes().to_vec(), + "custom_msg".as_bytes().into(), reply_tx, ))) .unwrap(); @@ -568,7 +540,7 @@ mod test { .with_discovery(false) .finish(), ), - "custom_msg".as_bytes().to_vec(), + Bytes::from_static(b"custom_msg"), reply_tx, )), ) @@ -619,7 +591,7 @@ mod test { .direct_public_key(peer_to_discover.public_key.clone()) .finish(), ), - "custom_msg".as_bytes().to_vec(), + "custom_msg".as_bytes().into(), reply_tx, )), ) diff --git a/comms/dht/src/outbound/encryption.rs b/comms/dht/src/outbound/encryption.rs index 659b6ca067..9b3f1be881 100644 --- a/comms/dht/src/outbound/encryption.rs +++ b/comms/dht/src/outbound/encryption.rs @@ -23,11 +23,23 @@ use crate::{ crypt, outbound::message::{DhtOutboundMessage, OutboundEncryption}, + proto::envelope::OriginMac, }; use futures::{task::Context, Future}; use log::*; +use rand::rngs::OsRng; use std::{sync::Arc, task::Poll}; -use tari_comms::{peer_manager::NodeIdentity, pipeline::PipelineError}; +use tari_comms::{ + message::MessageExt, + peer_manager::NodeIdentity, + pipeline::PipelineError, + types::CommsPublicKey, + utils::signature, +}; +use tari_crypto::{ + keys::PublicKey, + tari_utilities::{message_format::MessageFormat, ByteArray}, +}; use tower::{layer::Layer, Service, ServiceExt}; const LOG_TARGET: &str = "comms::middleware::encryption"; @@ -79,34 +91,55 @@ where S: Service + Clo Poll::Ready(Ok(())) } - fn call(&mut self, msg: DhtOutboundMessage) -> Self::Future { - Self::handle_message(self.inner.clone(), Arc::clone(&self.node_identity), msg) + fn call(&mut self, mut message: DhtOutboundMessage) -> Self::Future { + let next_service = self.inner.clone(); + let node_identity = Arc::clone(&self.node_identity); + async move { + trace!(target: LOG_TARGET, "DHT Message flags: {:?}", message.dht_flags); + match &message.encryption { + OutboundEncryption::EncryptFor(public_key) => { + debug!(target: LOG_TARGET, "Encrypting message for {}", public_key); + // Generate ephemeral public/private key pair and ECDH shared secret + let (e_sk, e_pk) = CommsPublicKey::random_keypair(&mut OsRng); + let shared_ephemeral_secret = crypt::generate_ecdh_secret(&e_sk, &**public_key); + // Encrypt the message with the body + let encrypted_body = + crypt::encrypt(&shared_ephemeral_secret, &message.body).map_err(PipelineError::from_debug)?; + + // Sign the encrypted message + let origin_mac = create_origin_mac(&node_identity, &encrypted_body)?; + // Encrypt and set the origin field + let encrypted_origin_mac = + crypt::encrypt(&shared_ephemeral_secret, &origin_mac).map_err(PipelineError::from_debug)?; + message + .with_origin_mac(encrypted_origin_mac) + .with_ephemeral_public_key(e_pk) + .set_body(encrypted_body.into()); + }, + OutboundEncryption::None => { + debug!(target: LOG_TARGET, "Encryption not requested for message"); + + if message.include_origin && message.custom_header.is_none() { + let origin_mac = create_origin_mac(&node_identity, &message.body)?; + message.with_origin_mac(origin_mac); + } + }, + }; + + next_service.oneshot(message).await + } } } -impl EncryptionService -where S: Service -{ - async fn handle_message( - next_service: S, - node_identity: Arc, - mut message: DhtOutboundMessage, - ) -> Result<(), PipelineError> - { - trace!(target: LOG_TARGET, "DHT Message flags: {:?}", message.dht_header.flags); - match &message.encryption { - OutboundEncryption::EncryptFor(public_key) => { - debug!(target: LOG_TARGET, "Encrypting message for {}", public_key); - let shared_secret = crypt::generate_ecdh_secret(node_identity.secret_key(), &**public_key); - message.body = crypt::encrypt(&shared_secret, &message.body).map_err(PipelineError::from_debug)?; - }, - OutboundEncryption::None => { - debug!(target: LOG_TARGET, "Encryption not requested for message"); - }, - }; - - next_service.oneshot(message).await - } +fn create_origin_mac(node_identity: &NodeIdentity, body: &[u8]) -> Result, PipelineError> { + let signature = + signature::sign(&mut OsRng, node_identity.secret_key().clone(), body).map_err(PipelineError::from_debug)?; + + let mac = OriginMac { + public_key: node_identity.public_key().to_vec(), + signature: signature.to_binary().map_err(PipelineError::from_debug)?, + }; + Ok(mac.to_encoded_bytes()) } #[cfg(test)] @@ -114,14 +147,10 @@ mod test { use super::*; use crate::{ envelope::DhtMessageFlags, - test_utils::{make_dht_header, make_node_identity, service_spy}, + test_utils::{create_outbound_message, make_node_identity, service_spy}, }; use futures::executor::block_on; - use tari_comms::{ - net_address::MultiaddressesWithStats, - peer_manager::{NodeId, Peer, PeerFeatures, PeerFlags}, - types::CommsPublicKey, - }; + use tari_comms::{peer_manager::NodeId, types::CommsPublicKey}; use tari_test_utils::panic_context; #[test] @@ -133,25 +162,14 @@ mod test { panic_context!(cx); assert!(encryption.poll_ready(&mut cx).is_ready()); - let body = b"A".to_vec(); - let msg = DhtOutboundMessage::new( - Peer::new( - CommsPublicKey::default(), - NodeId::default(), - MultiaddressesWithStats::new(vec![]), - PeerFlags::empty(), - PeerFeatures::COMMUNICATION_NODE, - &[], - ), - make_dht_header(&node_identity, &body, DhtMessageFlags::empty()), - OutboundEncryption::None, - body.clone(), - ); + let body = b"A"; + let msg = create_outbound_message(body); block_on(encryption.call(msg)).unwrap(); let msg = spy.pop_request().unwrap(); - assert_eq!(msg.body, body); + assert_eq!(msg.body.to_vec(), body); assert_eq!(msg.destination_peer.node_id, NodeId::default()); + assert!(msg.ephemeral_public_key.is_none()) } #[test] @@ -163,24 +181,15 @@ mod test { panic_context!(cx); assert!(encryption.poll_ready(&mut cx).is_ready()); - let body = b"A".to_vec(); - let msg = DhtOutboundMessage::new( - Peer::new( - CommsPublicKey::default(), - NodeId::default(), - MultiaddressesWithStats::new(vec![]), - PeerFlags::empty(), - PeerFeatures::COMMUNICATION_NODE, - &[], - ), - make_dht_header(&node_identity, &body, DhtMessageFlags::ENCRYPTED), - OutboundEncryption::EncryptFor(Box::new(CommsPublicKey::default())), - body.clone(), - ); + let body = b"A"; + let mut msg = create_outbound_message(body); + msg.dht_flags = DhtMessageFlags::ENCRYPTED; + msg.encryption = OutboundEncryption::EncryptFor(Box::new(CommsPublicKey::default())); block_on(encryption.call(msg)).unwrap(); let msg = spy.pop_request().unwrap(); - assert_ne!(msg.body, body); + assert_ne!(msg.body.to_vec(), body); assert_eq!(msg.destination_peer.node_id, NodeId::default()); + assert!(msg.ephemeral_public_key.is_some()) } } diff --git a/comms/dht/src/outbound/message.rs b/comms/dht/src/outbound/message.rs index a3f4229793..b266d1c573 100644 --- a/comms/dht/src/outbound/message.rs +++ b/comms/dht/src/outbound/message.rs @@ -21,9 +21,10 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::{ - envelope::{DhtMessageFlags, DhtMessageHeader}, + envelope::{DhtMessageFlags, DhtMessageHeader, DhtMessageType, Network, NodeDestination}, outbound::message_params::FinalSendMessageParams, }; +use bytes::Bytes; use futures::channel::oneshot; use std::{fmt, fmt::Display}; use tari_comms::{message::MessageTag, peer_manager::Peer, types::CommsPublicKey}; @@ -114,11 +115,7 @@ impl SendMessageResponse { #[derive(Debug)] pub enum DhtOutboundRequest { /// Send a message using the given broadcast strategy - SendMessage( - Box, - Vec, - oneshot::Sender, - ), + SendMessage(Box, Bytes, oneshot::Sender), } impl fmt::Display for DhtOutboundRequest { @@ -137,41 +134,54 @@ impl fmt::Display for DhtOutboundRequest { pub struct DhtOutboundMessage { pub tag: MessageTag, pub destination_peer: Peer, - pub dht_header: DhtMessageHeader, + pub custom_header: Option, pub encryption: OutboundEncryption, - pub body: Vec, + pub body: Bytes, + pub ephemeral_public_key: Option, + pub origin_mac: Option>, + pub include_origin: bool, + pub destination: NodeDestination, + pub dht_message_type: DhtMessageType, + pub network: Network, + pub dht_flags: DhtMessageFlags, } impl DhtOutboundMessage { - /// Create a new DhtOutboundMessage - pub fn new( - destination_peer: Peer, - dht_header: DhtMessageHeader, - encryption: OutboundEncryption, - body: Vec, - ) -> Self - { - Self { - tag: MessageTag::new(), - destination_peer, - dht_header, - encryption, - body, - } + pub fn with_ephemeral_public_key(&mut self, ephemeral_public_key: CommsPublicKey) -> &mut Self { + self.ephemeral_public_key = Some(ephemeral_public_key); + self + } + + pub fn with_origin_mac(&mut self, origin_mac: Vec) -> &mut Self { + self.origin_mac = Some(origin_mac); + self + } + + pub fn set_body(&mut self, body: Bytes) -> &mut Self { + self.body = body; + self } } impl fmt::Display for DhtOutboundMessage { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + let header_str = self + .custom_header + .as_ref() + .and_then(|h| Some(format!("{} (Propagated)", h))) + .unwrap_or_else(|| { + format!( + "Network: {:?}, Flags: {:?}, Destination: {}", + self.network, self.dht_flags, self.destination + ) + }); write!( f, - "\n---- DhtOutboundMessage ---- \nSize: {} byte(s)\nType: {}\nPeer: {}\nHeader: {} \nFlags: \ - {:?}\nEncryption: {}\n{}\n----", + "\n---- Outgoing message ---- \nSize: {} byte(s)\nType: {}\nPeer: {}\nHeader: {}\nEncryption: {}\n{}\n----", self.body.len(), - self.dht_header.message_type, + self.dht_message_type, self.destination_peer, - self.dht_header, - self.dht_header.flags, + header_str, self.encryption, self.tag ) diff --git a/comms/dht/src/outbound/mock.rs b/comms/dht/src/outbound/mock.rs index 89cc7c8e65..1ca43c822c 100644 --- a/comms/dht/src/outbound/mock.rs +++ b/comms/dht/src/outbound/mock.rs @@ -26,6 +26,7 @@ use crate::outbound::{ DhtOutboundRequest, OutboundMessageRequester, }; +use bytes::Bytes; use futures::{channel::mpsc, stream::Fuse, StreamExt}; use std::{ sync::{Arc, Condvar, Mutex, RwLock}, @@ -44,7 +45,7 @@ pub fn create_outbound_service_mock(size: usize) -> (OutboundMessageRequester, O #[derive(Clone, Default)] pub struct OutboundServiceMockState { #[allow(clippy::type_complexity)] - calls: Arc)>>>, + calls: Arc>>, next_response: Arc>>, call_count_cond_var: Arc, } @@ -88,7 +89,7 @@ impl OutboundServiceMockState { /// Wait for a call to be added or timeout. /// /// An error will be returned if the timeout expires. - pub fn wait_pop_call(&self, timeout: Duration) -> Result<(FinalSendMessageParams, Vec), String> { + pub fn wait_pop_call(&self, timeout: Duration) -> Result<(FinalSendMessageParams, Bytes), String> { let call_guard = acquire_lock!(self.calls); let (mut call_guard, timeout) = self .call_count_cond_var @@ -106,16 +107,16 @@ impl OutboundServiceMockState { acquire_write_lock!(self.next_response).take() } - pub fn add_call(&self, req: (FinalSendMessageParams, Vec)) { + pub fn add_call(&self, req: (FinalSendMessageParams, Bytes)) { acquire_lock!(self.calls).push(req); self.call_count_cond_var.notify_all(); } - pub fn take_calls(&self) -> Vec<(FinalSendMessageParams, Vec)> { + pub fn take_calls(&self) -> Vec<(FinalSendMessageParams, Bytes)> { acquire_lock!(self.calls).drain(..).collect() } - pub fn pop_call(&self) -> Option<(FinalSendMessageParams, Vec)> { + pub fn pop_call(&self) -> Option<(FinalSendMessageParams, Bytes)> { acquire_lock!(self.calls).pop() } } diff --git a/comms/dht/src/outbound/mod.rs b/comms/dht/src/outbound/mod.rs index 354910eb81..3a5d58fd01 100644 --- a/comms/dht/src/outbound/mod.rs +++ b/comms/dht/src/outbound/mod.rs @@ -21,22 +21,25 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. mod broadcast; +pub use broadcast::BroadcastLayer; + mod encryption; +pub use encryption::EncryptionLayer; + mod error; +pub use error::DhtOutboundError; + pub(crate) mod message; +pub use message::{DhtOutboundRequest, OutboundEncryption, SendMessageResponse}; + mod message_params; +pub use message_params::SendMessageParams; + mod requester; +pub use requester::OutboundMessageRequester; + mod serialize; +pub use serialize::SerializeLayer; #[cfg(any(test, feature = "test-mocks"))] pub mod mock; - -pub use self::{ - broadcast::BroadcastLayer, - encryption::EncryptionLayer, - error::DhtOutboundError, - message::{DhtOutboundRequest, OutboundEncryption, SendMessageResponse}, - message_params::SendMessageParams, - requester::OutboundMessageRequester, - serialize::SerializeLayer, -}; diff --git a/comms/dht/src/outbound/requester.rs b/comms/dht/src/outbound/requester.rs index 9eb6eb0025..64e0793a01 100644 --- a/comms/dht/src/outbound/requester.rs +++ b/comms/dht/src/outbound/requester.rs @@ -195,7 +195,7 @@ impl OutboundMessageRequester { message ); } - let body = wrap_in_envelope_body!(message.to_header(), message.into_inner())?.to_encoded_bytes()?; + let body = wrap_in_envelope_body!(message.to_header(), message.into_inner()).to_encoded_bytes(); self.send_raw(params, body).await } @@ -211,7 +211,7 @@ impl OutboundMessageRequester { if cfg!(debug_assertions) { trace!(target: LOG_TARGET, "Send Message: {} {:?}", params, message); } - let body = wrap_in_envelope_body!(message)?.to_encoded_bytes()?; + let body = wrap_in_envelope_body!(message).to_encoded_bytes(); self.send_raw(params, body).await } @@ -224,7 +224,7 @@ impl OutboundMessageRequester { { let (reply_tx, reply_rx) = oneshot::channel(); self.sender - .send(DhtOutboundRequest::SendMessage(Box::new(params), body, reply_tx)) + .send(DhtOutboundRequest::SendMessage(Box::new(params), body.into(), reply_tx)) .await?; reply_rx diff --git a/comms/dht/src/outbound/serialize.rs b/comms/dht/src/outbound/serialize.rs index 1815756046..ce761d9464 100644 --- a/comms/dht/src/outbound/serialize.rs +++ b/comms/dht/src/outbound/serialize.rs @@ -20,19 +20,20 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::{outbound::message::DhtOutboundMessage, proto::envelope::DhtEnvelope}; +use crate::{ + consts::DHT_ENVELOPE_HEADER_VERSION, + outbound::message::DhtOutboundMessage, + proto::envelope::{DhtEnvelope, DhtHeader}, +}; use futures::{task::Context, Future}; use log::*; -use rand::rngs::OsRng; -use std::{sync::Arc, task::Poll}; +use std::task::Poll; use tari_comms::{ message::{MessageExt, OutboundMessage}, - peer_manager::NodeIdentity, pipeline::PipelineError, - utils::signature, Bytes, }; -use tari_crypto::tari_utilities::message_format::MessageFormat; +use tari_crypto::tari_utilities::ByteArray; use tower::{layer::Layer, Service, ServiceExt}; const LOG_TARGET: &str = "comms::dht::serialize"; @@ -40,15 +41,11 @@ const LOG_TARGET: &str = "comms::dht::serialize"; #[derive(Clone)] pub struct SerializeMiddleware { inner: S, - node_identity: Arc, } impl SerializeMiddleware { - pub fn new(service: S, node_identity: Arc) -> Self { - Self { - inner: service, - node_identity, - } + pub fn new(service: S) -> Self { + Self { inner: service } } } @@ -64,68 +61,51 @@ where S: Service + Clone Poll::Ready(Ok(())) } - fn call(&mut self, msg: DhtOutboundMessage) -> Self::Future { - Self::serialize(self.inner.clone(), Arc::clone(&self.node_identity), msg) - } -} - -impl SerializeMiddleware -where S: Service -{ - pub async fn serialize( - next_service: S, - node_identity: Arc, - message: DhtOutboundMessage, - ) -> Result<(), PipelineError> - { - debug!(target: LOG_TARGET, "Serializing outbound message {:?}", message.tag); - - let DhtOutboundMessage { - mut dht_header, - body, - destination_peer, - .. - } = message; - - // The message is being forwarded if the origin public_key is specified and it is not this node - let is_forwarded = dht_header - .origin - .as_ref() - .map(|o| &o.public_key != node_identity.public_key()) - .unwrap_or(false); - - // If forwarding the message, the DhtHeader already has a signature that should not change - if is_forwarded { - debug!( - target: LOG_TARGET, - "Message ({}) is being forwarded so this node will NOT signed it", message.tag - ); - } else { - // Sign the body if the origin public key was previously specified. - if let Some(origin) = dht_header.origin.as_mut() { - let signature = signature::sign(&mut OsRng, node_identity.secret_key().clone(), &body) - .map_err(PipelineError::from_debug)?; - origin.signature = signature.to_binary().map_err(PipelineError::from_debug)?; - } + fn call(&mut self, message: DhtOutboundMessage) -> Self::Future { + let next_service = self.inner.clone(); + async move { + debug!(target: LOG_TARGET, "Serializing outbound message {:?}", message.tag); + + let DhtOutboundMessage { + tag, + destination_peer, + custom_header, + body, + ephemeral_public_key, + destination, + dht_message_type, + network, + dht_flags, + origin_mac, + .. + } = message; + + let dht_header = custom_header.map(DhtHeader::from).unwrap_or_else(|| DhtHeader { + version: DHT_ENVELOPE_HEADER_VERSION, + origin_mac: origin_mac.unwrap_or_else(Vec::new), + ephemeral_public_key: ephemeral_public_key.map(|e| e.to_vec()).unwrap_or_else(Vec::new), + message_type: dht_message_type as i32, + network: network as i32, + flags: dht_flags.bits(), + destination: Some(destination.into()), + }); + + let envelope = DhtEnvelope::new(dht_header, body); + + let body = Bytes::from(envelope.to_encoded_bytes()); + + next_service + .oneshot(OutboundMessage::with_tag(tag, destination_peer.node_id, body)) + .await } - - let envelope = DhtEnvelope::new(dht_header.into(), body); - - let body = Bytes::from(envelope.to_encoded_bytes().map_err(PipelineError::from_debug)?); - - next_service - .oneshot(OutboundMessage::with_tag(message.tag, destination_peer.node_id, body)) - .await } } -pub struct SerializeLayer { - node_identity: Arc, -} +pub struct SerializeLayer; impl SerializeLayer { - pub fn new(node_identity: Arc) -> Self { - Self { node_identity } + pub fn new() -> Self { + Self } } @@ -133,50 +113,29 @@ impl Layer for SerializeLayer { type Service = SerializeMiddleware; fn layer(&self, service: S) -> Self::Service { - SerializeMiddleware::new(service, Arc::clone(&self.node_identity)) + SerializeMiddleware::new(service) } } #[cfg(test)] mod test { use super::*; - use crate::{ - envelope::DhtMessageFlags, - outbound::OutboundEncryption, - test_utils::{make_dht_header, make_node_identity, service_spy}, - }; + use crate::test_utils::{create_outbound_message, service_spy}; use futures::executor::block_on; use prost::Message; - use tari_comms::{ - net_address::MultiaddressesWithStats, - peer_manager::{NodeId, Peer, PeerFeatures, PeerFlags}, - types::CommsPublicKey, - }; + use tari_comms::peer_manager::NodeId; use tari_test_utils::panic_context; #[test] fn serialize() { let spy = service_spy(); - let node_identity = make_node_identity(); - let mut serialize = SerializeLayer::new(Arc::clone(&node_identity)).layer(spy.to_service::()); + let mut serialize = SerializeLayer.layer(spy.to_service::()); panic_context!(cx); assert!(serialize.poll_ready(&mut cx).is_ready()); - let body = b"A".to_vec(); - let msg = DhtOutboundMessage::new( - Peer::new( - CommsPublicKey::default(), - NodeId::default(), - MultiaddressesWithStats::new(vec![]), - PeerFlags::empty(), - PeerFeatures::COMMUNICATION_NODE, - &[], - ), - make_dht_header(&node_identity, &body, DhtMessageFlags::empty()), - OutboundEncryption::None, - body, - ); + let body = b"A"; + let msg = create_outbound_message(body); block_on(serialize.call(msg)).unwrap(); let mut msg = spy.pop_request().unwrap(); diff --git a/comms/dht/src/proto/envelope.proto b/comms/dht/src/proto/envelope.proto index 43f1815cee..5a3071dfa3 100644 --- a/comms/dht/src/proto/envelope.proto +++ b/comms/dht/src/proto/envelope.proto @@ -32,14 +32,17 @@ message DhtHeader { } // Origin public key of the message. This can be the same peer that sent the message - // or another peer if the message should be forwarded. This is optional but must be specified + // or another peer if the message should be forwarded. This is optional but MUST be specified // if the ENCRYPTED flag is set. - DhtOrigin origin = 5; + // If an ephemeral_public_key is specified, this MUST be encrypted using a derived ECDH shared key + bytes origin_mac = 5; + // Ephemeral public key component of the ECDH shared key. MUST be specified if the ENCRYPTED flag is set. + bytes ephemeral_public_key = 6; // The type of message - DhtMessageType message_type = 6; + DhtMessageType message_type = 7; // The network for which this message is intended (e.g. TestNet, MainNet etc.) - Network network = 7; - uint32 flags = 8; + Network network = 8; + uint32 flags = 9; } enum Network { @@ -56,7 +59,8 @@ message DhtEnvelope { bytes body = 2; } -message DhtOrigin { +// The Message Authentication Code (MAC) message format of the decrypted `DhtHeader::origin_mac` field +message OriginMac { bytes public_key = 1; bytes signature = 2; } \ No newline at end of file diff --git a/comms/dht/src/proto/tari.dht.envelope.rs b/comms/dht/src/proto/tari.dht.envelope.rs index f7e24f31b2..734bee953f 100644 --- a/comms/dht/src/proto/tari.dht.envelope.rs +++ b/comms/dht/src/proto/tari.dht.envelope.rs @@ -3,17 +3,21 @@ pub struct DhtHeader { #[prost(uint32, tag = "1")] pub version: u32, /// Origin public key of the message. This can be the same peer that sent the message - /// or another peer if the message should be forwarded. This is optional but must be specified + /// or another peer if the message should be forwarded. This is optional but MUST be specified /// if the ENCRYPTED flag is set. - #[prost(message, optional, tag = "5")] - pub origin: ::std::option::Option, + /// If an ephemeral_public_key is specified, this MUST be encrypted using a derived ECDH shared key + #[prost(bytes, tag = "5")] + pub origin_mac: std::vec::Vec, + /// Ephemeral public key component of the ECDH shared key. MUST be specified if the ENCRYPTED flag is set. + #[prost(bytes, tag = "6")] + pub ephemeral_public_key: std::vec::Vec, /// The type of message - #[prost(enumeration = "DhtMessageType", tag = "6")] + #[prost(enumeration = "DhtMessageType", tag = "7")] pub message_type: i32, /// The network for which this message is intended (e.g. TestNet, MainNet etc.) - #[prost(enumeration = "Network", tag = "7")] + #[prost(enumeration = "Network", tag = "8")] pub network: i32, - #[prost(uint32, tag = "8")] + #[prost(uint32, tag = "9")] pub flags: u32, #[prost(oneof = "dht_header::Destination", tags = "2, 3, 4")] pub destination: ::std::option::Option, @@ -40,8 +44,9 @@ pub struct DhtEnvelope { #[prost(bytes, tag = "2")] pub body: std::vec::Vec, } +/// The Message Authentication Code (MAC) message format of the decrypted `DhtHeader::origin_mac` field #[derive(Clone, PartialEq, ::prost::Message)] -pub struct DhtOrigin { +pub struct OriginMac { #[prost(bytes, tag = "1")] pub public_key: std::vec::Vec, #[prost(bytes, tag = "2")] diff --git a/comms/dht/src/schema.rs b/comms/dht/src/schema.rs index 2ebbfb28b0..9f2f9910b1 100644 --- a/comms/dht/src/schema.rs +++ b/comms/dht/src/schema.rs @@ -10,8 +10,7 @@ table! { stored_messages (id) { id -> Integer, version -> Integer, - origin_pubkey -> Text, - origin_signature -> Text, + origin_pubkey -> Nullable, message_type -> Integer, destination_pubkey -> Nullable, destination_node_id -> Nullable, diff --git a/comms/dht/src/store_forward/database/stored_message.rs b/comms/dht/src/store_forward/database/stored_message.rs index d3d9826674..0aa8ce3f92 100644 --- a/comms/dht/src/store_forward/database/stored_message.rs +++ b/comms/dht/src/store_forward/database/stored_message.rs @@ -21,7 +21,7 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::{ - envelope::DhtMessageHeader, + inbound::DecryptedDhtMessage, proto::envelope::DhtHeader, schema::stored_messages, store_forward::message::StoredMessagePriority, @@ -35,8 +35,7 @@ use tari_crypto::tari_utilities::hex::Hex; #[table_name = "stored_messages"] pub struct NewStoredMessage { pub version: i32, - pub origin_pubkey: String, - pub origin_signature: String, + pub origin_pubkey: Option, pub message_type: i32, pub destination_pubkey: Option, pub destination_node_id: Option, @@ -47,27 +46,33 @@ pub struct NewStoredMessage { } impl NewStoredMessage { - pub fn try_construct( - version: u32, - dht_header: DhtMessageHeader, - priority: StoredMessagePriority, - body: Vec, - ) -> Option - { + pub fn try_construct(message: DecryptedDhtMessage, priority: StoredMessagePriority) -> Option { + let DecryptedDhtMessage { + version, + authenticated_origin, + decryption_result, + dht_header, + .. + } = message; + + let body = match decryption_result { + Ok(envelope_body) => envelope_body.to_encoded_bytes(), + Err(encrypted_body) => encrypted_body, + }; + Some(Self { version: version.try_into().ok()?, - origin_pubkey: dht_header.origin.as_ref().map(|o| o.public_key.to_hex())?, - origin_signature: dht_header.origin.as_ref().map(|o| o.signature.to_hex())?, + origin_pubkey: authenticated_origin.as_ref().map(|pk| pk.to_hex()), message_type: dht_header.message_type as i32, destination_pubkey: dht_header.destination.public_key().map(|pk| pk.to_hex()), destination_node_id: dht_header.destination.node_id().map(|node_id| node_id.to_hex()), - body, is_encrypted: dht_header.flags.is_encrypted(), priority: priority as i32, header: { let dht_header: DhtHeader = dht_header.into(); - dht_header.to_encoded_bytes().ok()? + dht_header.to_encoded_bytes() }, + body, }) } } @@ -76,8 +81,7 @@ impl NewStoredMessage { pub struct StoredMessage { pub id: i32, pub version: i32, - pub origin_pubkey: String, - pub origin_signature: String, + pub origin_pubkey: Option, pub message_type: i32, pub destination_pubkey: Option, pub destination_node_id: Option, diff --git a/comms/dht/src/store_forward/error.rs b/comms/dht/src/store_forward/error.rs index a5573b11f6..e2c19ef830 100644 --- a/comms/dht/src/store_forward/error.rs +++ b/comms/dht/src/store_forward/error.rs @@ -36,9 +36,11 @@ pub enum StoreAndForwardError { /// Received stored message has an invalid destination InvalidDestination, /// Received stored message has an invalid origin signature - InvalidSignature, + InvalidOriginMac, /// Invalid envelope body InvalidEnvelopeBody, + /// DHT header is invalid + InvalidDhtHeader, /// Received stored message which is not encrypted StoredMessageNotEncrypted, /// Unable to decrypt received stored message diff --git a/comms/dht/src/store_forward/forward.rs b/comms/dht/src/store_forward/forward.rs index beeee69754..e6983d7ed5 100644 --- a/comms/dht/src/store_forward/forward.rs +++ b/comms/dht/src/store_forward/forward.rs @@ -147,6 +147,7 @@ where S: Service source_peer, decryption_result, dht_header, + authenticated_origin, .. } = message; @@ -171,8 +172,8 @@ where S: Service .expect("previous check that decryption failed"); let mut excluded_peers = vec![source_peer.public_key.clone()]; - if let Some(origin) = dht_header.origin.as_ref() { - excluded_peers.push(origin.public_key.clone()); + if let Some(pk) = authenticated_origin.as_ref() { + excluded_peers.push(pk.clone()); } let mut message_params = self.get_send_params(&dht_header, excluded_peers).await?; @@ -271,8 +272,13 @@ mod test { let oms = OutboundMessageRequester::new(oms_tx); let mut service = ForwardLayer::new(peer_manager, oms).layer(spy.to_service::()); - let inbound_msg = make_dht_inbound_message(&make_node_identity(), b"".to_vec(), DhtMessageFlags::empty()); - let msg = DecryptedDhtMessage::succeeded(wrap_in_envelope_body!(Vec::new()).unwrap(), inbound_msg); + let node_identity = make_node_identity(); + let inbound_msg = make_dht_inbound_message(&node_identity, b"".to_vec(), DhtMessageFlags::empty(), false); + let msg = DecryptedDhtMessage::succeeded( + wrap_in_envelope_body!(Vec::new()), + Some(node_identity.public_key().clone()), + inbound_msg, + ); block_on(service.call(msg)).unwrap(); assert!(spy.is_called()); assert!(oms_rx.try_next().is_err()); @@ -289,7 +295,8 @@ mod test { let mut service = ForwardLayer::new(peer_manager, oms_requester).layer(spy.to_service::()); - let inbound_msg = make_dht_inbound_message(&make_node_identity(), b"".to_vec(), DhtMessageFlags::empty()); + let inbound_msg = + make_dht_inbound_message(&make_node_identity(), b"".to_vec(), DhtMessageFlags::empty(), false); let msg = DecryptedDhtMessage::failed(inbound_msg); rt.block_on(service.call(msg)).unwrap(); assert!(spy.is_called()); diff --git a/comms/dht/src/store_forward/message.rs b/comms/dht/src/store_forward/message.rs index 5cdcc01c3b..41dc126022 100644 --- a/comms/dht/src/store_forward/message.rs +++ b/comms/dht/src/store_forward/message.rs @@ -68,11 +68,11 @@ impl StoredMessagesRequest { #[cfg(test)] impl StoredMessage { - pub fn new(version: u32, dht_header: crate::envelope::DhtMessageHeader, encrypted_body: Vec) -> Self { + pub fn new(version: u32, dht_header: crate::envelope::DhtMessageHeader, body: Vec) -> Self { Self { version, dht_header: Some(dht_header.into()), - body: encrypted_body, + body, stored_at: Some(datetime_to_timestamp(Utc::now())), } } diff --git a/comms/dht/src/store_forward/saf_handler/task.rs b/comms/dht/src/store_forward/saf_handler/task.rs index 9a0274037d..1e944b6050 100644 --- a/comms/dht/src/store_forward/saf_handler/task.rs +++ b/comms/dht/src/store_forward/saf_handler/task.rs @@ -24,11 +24,11 @@ use crate::{ actor::DhtRequester, config::DhtConfig, crypt, - envelope::{DhtMessageFlags, DhtMessageHeader, DhtMessageOrigin, NodeDestination}, + envelope::{DhtMessageFlags, DhtMessageHeader, NodeDestination}, inbound::{DecryptedDhtMessage, DhtInboundMessage}, outbound::{OutboundMessageRequester, SendMessageParams}, proto::{ - envelope::DhtMessageType, + envelope::{DhtMessageType, OriginMac}, store_forward::{ stored_messages_response::SafResponseType, StoredMessage as ProtoStoredMessage, @@ -55,9 +55,10 @@ use tari_comms::{ message::EnvelopeBody, peer_manager::{NodeIdentity, Peer, PeerManager, PeerManagerError}, pipeline::PipelineError, - types::Challenge, + types::{Challenge, CommsPublicKey}, utils::signature, }; +use tari_crypto::tari_utilities::ByteArray; use tower::{Service, ServiceExt}; const LOG_TARGET: &str = "comms::dht::store_forward"; @@ -373,6 +374,10 @@ where S: Service .try_into() .map_err(StoreAndForwardError::DhtMessageError)?; + if !dht_header.is_valid() { + return Err(StoreAndForwardError::InvalidDhtHeader); + } + if dht_header.message_type.is_dht_message() { trace!( target: LOG_TARGET, @@ -382,26 +387,22 @@ where S: Service ); } - let dht_flags = dht_header.flags; - - let origin = dht_header - .origin - .as_ref() - .ok_or_else(|| StoreAndForwardError::MessageOriginRequired)?; - - // Check that the destination is either undisclosed + // Check that the destination is either undisclosed, for us or for our network region Self::check_destination(&config, &peer_manager, &node_identity, &dht_header).await?; - // Verify the signature - Self::check_signature(origin, &message.body)?; // Check that the message has not already been received. Self::check_duplicate(&mut dht_requester, &message.body).await?; // Attempt to decrypt the message (if applicable), and deserialize it - let decrypted_body = Self::maybe_decrypt_and_deserialize(&node_identity, origin, dht_flags, &message.body)?; + let (authenticated_pk, decrypted_body) = + Self::authenticate_and_decrypt_if_required(&node_identity, &dht_header, &message.body)?; let inbound_msg = DhtInboundMessage::new(dht_header, Arc::clone(&source_peer), message.body); - Ok(DecryptedDhtMessage::succeeded(decrypted_body, inbound_msg)) + Ok(DecryptedDhtMessage::succeeded( + decrypted_body, + authenticated_pk, + inbound_msg, + )) } } @@ -439,34 +440,54 @@ where S: Service } } - fn check_signature(origin: &DhtMessageOrigin, body: &[u8]) -> Result<(), StoreAndForwardError> { - signature::verify(&origin.public_key, &origin.signature, body) - .map_err(|_| StoreAndForwardError::InvalidSignature) - .and_then(|is_valid| { - if is_valid { - Ok(()) - } else { - Err(StoreAndForwardError::InvalidSignature) - } - }) - } - - fn maybe_decrypt_and_deserialize( + fn authenticate_and_decrypt_if_required( node_identity: &NodeIdentity, - origin: &DhtMessageOrigin, - flags: DhtMessageFlags, + header: &DhtMessageHeader, body: &[u8], - ) -> Result + ) -> Result<(Option, EnvelopeBody), StoreAndForwardError> { - if flags.contains(DhtMessageFlags::ENCRYPTED) { - let shared_secret = crypt::generate_ecdh_secret(node_identity.secret_key(), &origin.public_key); + if header.flags.contains(DhtMessageFlags::ENCRYPTED) { + let ephemeral_public_key = header.ephemeral_public_key.as_ref().expect( + "[store and forward] DHT header is invalid after validity check because it did not contain an \ + ephemeral_public_key", + ); + let shared_secret = crypt::generate_ecdh_secret(node_identity.secret_key(), ephemeral_public_key); + let decrypted = crypt::decrypt(&shared_secret, &header.origin_mac)?; + let authenticated_pk = Self::authenticate_message(&decrypted, body)?; + let decrypted_bytes = crypt::decrypt(&shared_secret, body)?; - EnvelopeBody::decode(decrypted_bytes.as_slice()).map_err(|_| StoreAndForwardError::DecryptionFailed) + let envelope_body = + EnvelopeBody::decode(decrypted_bytes.as_slice()).map_err(|_| StoreAndForwardError::DecryptionFailed)?; + if envelope_body.is_empty() { + return Err(StoreAndForwardError::InvalidEnvelopeBody); + } + Ok((Some(authenticated_pk), envelope_body)) } else { - // Malformed cleartext messages should never have been forwarded by the peer - EnvelopeBody::decode(body).map_err(|_| StoreAndForwardError::MalformedMessage) + let authenticated_pk = if !header.origin_mac.is_empty() { + Some(Self::authenticate_message(&header.origin_mac, body)?) + } else { + None + }; + let envelope_body = EnvelopeBody::decode(body).map_err(|_| StoreAndForwardError::MalformedMessage)?; + Ok((authenticated_pk, envelope_body)) } } + + fn authenticate_message(origin_mac_body: &[u8], body: &[u8]) -> Result { + let origin_mac = OriginMac::decode(origin_mac_body)?; + let public_key = + CommsPublicKey::from_bytes(&origin_mac.public_key).map_err(|_| StoreAndForwardError::InvalidOriginMac)?; + signature::verify(&public_key, &origin_mac.signature, body) + .map_err(|_| StoreAndForwardError::InvalidOriginMac) + .and_then(|is_valid| { + if is_valid { + Ok(()) + } else { + Err(StoreAndForwardError::InvalidOriginMac) + } + })?; + Ok(public_key) + } } #[cfg(test)] @@ -481,6 +502,7 @@ mod test { create_store_and_forward_mock, make_dht_header, make_dht_inbound_message, + make_keypair, make_node_identity, make_peer_manager, service_spy, @@ -500,12 +522,11 @@ mod test { StoredMessage { id: 1, version: 0, - origin_pubkey: node_identity.public_key().to_hex(), - origin_signature: String::new(), + origin_pubkey: Some(node_identity.public_key().to_hex()), message_type: DhtMessageType::None as i32, destination_pubkey: None, destination_node_id: None, - header: DhtHeader::from(dht_header).to_encoded_bytes().unwrap(), + header: DhtHeader::from(dht_header).to_encoded_bytes(), body: b"A".to_vec(), is_encrypted: false, priority: StoredMessagePriority::High as i32, @@ -525,15 +546,22 @@ mod test { let node_identity = make_node_identity(); // Recent message - let dht_header = make_dht_header(&node_identity, &[], DhtMessageFlags::empty()); + let (e_sk, e_pk) = make_keypair(); + let dht_header = make_dht_header(&node_identity, &e_pk, &e_sk, &[], DhtMessageFlags::empty(), false); mock_state .add_message(make_stored_message(&node_identity, dht_header)) .await; let since = Utc::now().checked_sub_signed(chrono::Duration::seconds(60)).unwrap(); let mut message = DecryptedDhtMessage::succeeded( - wrap_in_envelope_body!(StoredMessagesRequest::since(since)).unwrap(), - make_dht_inbound_message(&node_identity, vec![], DhtMessageFlags::ENCRYPTED), + wrap_in_envelope_body!(StoredMessagesRequest::since(since)), + None, + make_dht_inbound_message( + &node_identity, + b"Keep this for others please".to_vec(), + DhtMessageFlags::ENCRYPTED, + true, + ), ); message.dht_header.message_type = DhtMessageType::SafRequestMessages; @@ -554,6 +582,7 @@ mod test { rt_handle.spawn(task.run()); let (_, body) = unwrap_oms_send_msg!(oms_rx.next().await.unwrap()); + let body = body.to_vec(); let body = EnvelopeBody::decode(body.as_slice()).unwrap(); let msg = body.decode_part::(0).unwrap().unwrap(); assert_eq!(msg.messages().len(), 1); @@ -578,56 +607,43 @@ mod test { let node_identity = make_node_identity(); - let shared_key = crypt::generate_ecdh_secret(node_identity.secret_key(), node_identity.public_key()); - let msg_a = crypt::encrypt( - &shared_key, - &wrap_in_envelope_body!(&b"A".to_vec()) - .unwrap() - .to_encoded_bytes() - .unwrap(), - ) - .unwrap(); - - let inbound_msg_a = make_dht_inbound_message(&node_identity, msg_a.clone(), DhtMessageFlags::ENCRYPTED); + let msg_a = wrap_in_envelope_body!(&b"A".to_vec()).to_encoded_bytes(); + + let inbound_msg_a = make_dht_inbound_message(&node_identity, msg_a.clone(), DhtMessageFlags::ENCRYPTED, true); // Need to know the peer to process a stored message peer_manager .add_peer(Clone::clone(&*inbound_msg_a.source_peer)) .await .unwrap(); - let msg_b = crypt::encrypt( - &shared_key, - &wrap_in_envelope_body!(b"B".to_vec()) - .unwrap() - .to_encoded_bytes() - .unwrap(), - ) - .unwrap(); - - let inbound_msg_b = make_dht_inbound_message(&node_identity, msg_b.clone(), DhtMessageFlags::ENCRYPTED); + + let msg_b = &wrap_in_envelope_body!(b"B".to_vec()).to_encoded_bytes(); + let inbound_msg_b = make_dht_inbound_message(&node_identity, msg_b.clone(), DhtMessageFlags::ENCRYPTED, true); // Need to know the peer to process a stored message peer_manager .add_peer(Clone::clone(&*inbound_msg_b.source_peer)) .await .unwrap(); - let msg1 = ProtoStoredMessage::new(0, inbound_msg_a.dht_header.clone(), msg_a); - let msg2 = ProtoStoredMessage::new(0, inbound_msg_b.dht_header, msg_b); + let msg1 = ProtoStoredMessage::new(0, inbound_msg_a.dht_header.clone(), inbound_msg_a.body); + let msg2 = ProtoStoredMessage::new(0, inbound_msg_b.dht_header, inbound_msg_b.body); // Cleartext message - let clear_msg = wrap_in_envelope_body!(b"Clear".to_vec()) - .unwrap() - .to_encoded_bytes() - .unwrap(); + let clear_msg = wrap_in_envelope_body!(b"Clear".to_vec()).to_encoded_bytes(); let clear_header = - make_dht_inbound_message(&node_identity, clear_msg.clone(), DhtMessageFlags::empty()).dht_header; + make_dht_inbound_message(&node_identity, clear_msg.clone(), DhtMessageFlags::empty(), false).dht_header; let msg_clear = ProtoStoredMessage::new(0, clear_header, clear_msg); let mut message = DecryptedDhtMessage::succeeded( wrap_in_envelope_body!(StoredMessagesResponse { messages: vec![msg1.clone(), msg2, msg_clear], request_id: 123, response_type: 0 - }) - .unwrap(), - make_dht_inbound_message(&node_identity, vec![], DhtMessageFlags::ENCRYPTED), + }), + None, + make_dht_inbound_message( + &node_identity, + b"Stored message".to_vec(), + DhtMessageFlags::ENCRYPTED, + true, + ), ); message.dht_header.message_type = DhtMessageType::SafStoredMessages; diff --git a/comms/dht/src/store_forward/store.rs b/comms/dht/src/store_forward/store.rs index dd99f2bb94..0baa43f6f4 100644 --- a/comms/dht/src/store_forward/store.rs +++ b/comms/dht/src/store_forward/store.rs @@ -37,7 +37,6 @@ use futures::{task::Context, Future}; use log::*; use std::{sync::Arc, task::Poll}; use tari_comms::{ - message::MessageExt, peer_manager::{NodeId, NodeIdentity, PeerManager}, pipeline::PipelineError, }; @@ -205,17 +204,21 @@ where S: Service ); }; - if message.body_size() > self.config.saf_max_message_size { + if message.body_len() > self.config.saf_max_message_size { log_not_eligible(&format!( "the message body exceeded the maximum storage size (body size={}, max={})", - message.body_size(), + message.body_len(), self.config.saf_max_message_size )); return Ok(None); } - if message.origin_public_key() == self.node_identity.public_key() { - log_not_eligible("not storing message from this node"); + if message + .authenticated_origin() + .map(|pk| pk == self.node_identity.public_key()) + .unwrap_or(false) + { + log_not_eligible("this message originates from this node"); return Ok(None); } @@ -223,12 +226,12 @@ where S: Service // The message decryption was successful, or the message was not encrypted Some(_) => { // If the message doesnt have an origin we wont store it - if !message.has_origin() { - log_not_eligible("it does not have an origin"); + if !message.has_origin_mac() { + log_not_eligible("it is encrypted and does not have an origin MAC"); return Ok(None); } - // If this node decrypted the message, no need to store it + // If this node decrypted the message (message.success() above), no need to store it if message.is_encrypted() { log_not_eligible("the message was encrypted for this node"); return Ok(None); @@ -251,12 +254,12 @@ where S: Service }, // This node could not decrypt the message None => { - if !message.has_origin() { - // TODO: #banheuristic - the source should not have propagated this message + if !message.has_origin_mac() { + // TODO: #banheuristic - the source peer should not have propagated this message warn!( target: LOG_TARGET, - "Store task received an encrypted message with no source. This message is invalid and should \ - not be stored or propagated. Dropping message. Sent by node '{}'", + "Store task received an encrypted message with no origin MAC. This message is invalid and \ + should not be stored or propagated. Dropping message. Sent by node '{}'", message.source_peer.node_id.short_str() ); return Ok(None); @@ -387,26 +390,14 @@ where S: Service } async fn store(&mut self, priority: StoredMessagePriority, message: DecryptedDhtMessage) -> SafResult<()> { - let DecryptedDhtMessage { - version, - decryption_result, - dht_header, - .. - } = message; - - let body = match decryption_result { - Ok(body) => body.to_encoded_bytes()?, - Err(encrypted_body) => encrypted_body, - }; - debug!( target: LOG_TARGET, "Storing message from peer '{}' ({} bytes)", message.source_peer.node_id.short_str(), - body.len() + message.body_len(), ); - let stored_message = NewStoredMessage::try_construct(version, dht_header, priority, body) + let stored_message = NewStoredMessage::try_construct(message, priority) .ok_or_else(|| StoreAndForwardError::InvalidStoreMessage)?; self.saf_requester.insert_message(stored_message).await?; @@ -430,7 +421,7 @@ mod test { }; use chrono::Utc; use std::time::Duration; - use tari_comms::wrap_in_envelope_body; + use tari_comms::{message::MessageExt, wrap_in_envelope_body}; use tari_crypto::tari_utilities::hex::Hex; use tari_test_utils::async_assert_eventually; @@ -444,9 +435,9 @@ mod test { let mut service = StoreLayer::new(Default::default(), peer_manager, node_identity, requester) .layer(spy.to_service::()); - let mut inbound_msg = make_dht_inbound_message(&make_node_identity(), b"".to_vec(), DhtMessageFlags::empty()); - inbound_msg.dht_header.origin = None; - let msg = DecryptedDhtMessage::succeeded(wrap_in_envelope_body!(Vec::new()).unwrap(), inbound_msg); + let inbound_msg = + make_dht_inbound_message(&make_node_identity(), b"".to_vec(), DhtMessageFlags::empty(), false); + let msg = DecryptedDhtMessage::succeeded(wrap_in_envelope_body!(Vec::new()), None, inbound_msg); service.call(msg).await.unwrap(); assert!(spy.is_called()); let messages = mock_state.get_messages().await; @@ -465,14 +456,18 @@ mod test { addresses: vec![], peer_features: 0, } - .to_encoded_bytes() - .unwrap(); + .to_encoded_bytes(); let mut service = StoreLayer::new(Default::default(), peer_manager, node_identity, requester) .layer(spy.to_service::()); - let inbound_msg = make_dht_inbound_message(&make_node_identity(), b"".to_vec(), DhtMessageFlags::empty()); + let sender_identity = make_node_identity(); + let inbound_msg = make_dht_inbound_message(&sender_identity, b"".to_vec(), DhtMessageFlags::empty(), true); - let mut msg = DecryptedDhtMessage::succeeded(wrap_in_envelope_body!(join_msg_bytes).unwrap(), inbound_msg); + let mut msg = DecryptedDhtMessage::succeeded( + wrap_in_envelope_body!(join_msg_bytes), + Some(sender_identity.public_key().clone()), + inbound_msg, + ); msg.dht_header.message_type = DhtMessageType::Join; service.call(msg).await.unwrap(); assert!(spy.is_called()); @@ -499,8 +494,18 @@ mod test { let mut service = StoreLayer::new(Default::default(), peer_manager, node_identity, requester) .layer(spy.to_service::()); - let inbound_msg = make_dht_inbound_message(&make_node_identity(), b"".to_vec(), DhtMessageFlags::ENCRYPTED); - let msg = DecryptedDhtMessage::succeeded(wrap_in_envelope_body!(b"secret".to_vec()).unwrap(), inbound_msg); + let msg_node_identity = make_node_identity(); + let inbound_msg = make_dht_inbound_message( + &msg_node_identity, + b"This shouldnt be stored".to_vec(), + DhtMessageFlags::ENCRYPTED, + true, + ); + let msg = DecryptedDhtMessage::succeeded( + wrap_in_envelope_body!(b"secret".to_vec()), + Some(msg_node_identity.public_key().clone()), + inbound_msg, + ); service.call(msg).await.unwrap(); assert!(spy.is_called()); @@ -518,7 +523,12 @@ mod test { let mut service = StoreLayer::new(Default::default(), peer_manager, node_identity, requester) .layer(spy.to_service::()); - let mut inbound_msg = make_dht_inbound_message(&origin_node_identity, b"".to_vec(), DhtMessageFlags::empty()); + let mut inbound_msg = make_dht_inbound_message( + &origin_node_identity, + b"Will you keep this for me?".to_vec(), + DhtMessageFlags::ENCRYPTED, + true, + ); inbound_msg.dht_header.destination = NodeDestination::PublicKey(Box::new(origin_node_identity.public_key().clone())); let msg = DecryptedDhtMessage::failed(inbound_msg.clone()); @@ -534,8 +544,8 @@ mod test { let message = mock_state.get_messages().await.remove(0); assert_eq!( - message.origin_signature, - inbound_msg.dht_header.origin.unwrap().signature.to_hex() + message.destination_pubkey.unwrap(), + origin_node_identity.public_key().to_hex() ); let duration = Utc::now().naive_utc().signed_duration_since(message.stored_at); assert!(duration.num_seconds() <= 5); diff --git a/comms/dht/src/test_utils/dht_discovery_mock.rs b/comms/dht/src/test_utils/dht_discovery_mock.rs index ae07a83cd0..89c5b6bd06 100644 --- a/comms/dht/src/test_utils/dht_discovery_mock.rs +++ b/comms/dht/src/test_utils/dht_discovery_mock.rs @@ -102,8 +102,7 @@ impl DhtDiscoveryMock { trace!(target: LOG_TARGET, "DhtDiscoveryMock received request {:?}", req); self.state.inc_call_count(); match req { - DiscoverPeer(boxed) => { - let (_, reply_tx) = *boxed; + DiscoverPeer(_, reply_tx) => { let lock = self.state.discover_peer.read().unwrap(); reply_tx.send(Ok(lock.clone())).unwrap(); }, diff --git a/comms/dht/src/test_utils/makers.rs b/comms/dht/src/test_utils/makers.rs index 862ecb474e..e56e1d76a3 100644 --- a/comms/dht/src/test_utils/makers.rs +++ b/comms/dht/src/test_utils/makers.rs @@ -20,21 +20,27 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::{ - envelope::{DhtMessageFlags, DhtMessageHeader, DhtMessageOrigin, NodeDestination}, + crypt, + envelope::{DhtMessageFlags, DhtMessageHeader, NodeDestination}, inbound::DhtInboundMessage, - proto::envelope::{DhtEnvelope, DhtMessageType, Network}, + outbound::message::DhtOutboundMessage, + proto::envelope::{DhtEnvelope, DhtMessageType, Network, OriginMac}, }; use rand::rngs::OsRng; -use std::sync::Arc; +use std::{convert::TryInto, sync::Arc}; use tari_comms::{ - message::InboundMessage, + message::{InboundMessage, MessageExt, MessageTag}, multiaddr::Multiaddr, - peer_manager::{NodeIdentity, Peer, PeerFeatures, PeerFlags, PeerManager}, - types::CommsDatabase, + net_address::MultiaddressesWithStats, + peer_manager::{NodeId, NodeIdentity, Peer, PeerFeatures, PeerFlags, PeerManager}, + types::{CommsDatabase, CommsPublicKey, CommsSecretKey}, utils::signature, Bytes, }; -use tari_crypto::tari_utilities::message_format::MessageFormat; +use tari_crypto::{ + keys::PublicKey, + tari_utilities::{message_format::MessageFormat, ByteArray}, +}; use tari_storage::lmdb_store::LMDBBuilder; use tari_test_utils::{paths::create_temporary_data_path, random}; @@ -86,31 +92,63 @@ pub fn make_comms_inbound_message(node_identity: &NodeIdentity, message: Bytes) ) } -pub fn make_dht_header(node_identity: &NodeIdentity, message: &[u8], flags: DhtMessageFlags) -> DhtMessageHeader { +pub fn make_dht_header( + node_identity: &NodeIdentity, + e_pk: &CommsPublicKey, + e_sk: &CommsSecretKey, + message: &[u8], + flags: DhtMessageFlags, + include_origin: bool, +) -> DhtMessageHeader +{ DhtMessageHeader { version: 0, destination: NodeDestination::Unknown, - origin: Some(DhtMessageOrigin { - public_key: node_identity.public_key().clone(), - signature: signature::sign(&mut OsRng, node_identity.secret_key().clone(), message) - .unwrap() - .to_binary() - .unwrap(), - }), + ephemeral_public_key: if flags.is_encrypted() { Some(e_pk.clone()) } else { None }, + origin_mac: if include_origin { + make_valid_origin_mac(node_identity, &e_sk, message, flags) + } else { + Vec::new() + }, message_type: DhtMessageType::None, network: Network::LocalTest, flags, } } +pub fn make_valid_origin_mac( + node_identity: &NodeIdentity, + e_sk: &CommsSecretKey, + body: &[u8], + flags: DhtMessageFlags, +) -> Vec +{ + let mac = OriginMac { + public_key: node_identity.public_key().to_vec(), + signature: signature::sign(&mut OsRng, node_identity.secret_key().clone(), body) + .unwrap() + .to_binary() + .unwrap(), + }; + let body = mac.to_encoded_bytes(); + if flags.is_encrypted() { + let shared_secret = crypt::generate_ecdh_secret(e_sk, node_identity.public_key()); + crypt::encrypt(&shared_secret, &body).unwrap() + } else { + body + } +} + pub fn make_dht_inbound_message( node_identity: &NodeIdentity, body: Vec, flags: DhtMessageFlags, + include_origin: bool, ) -> DhtInboundMessage { + let envelope = make_dht_envelope(node_identity, body, flags, include_origin); DhtInboundMessage::new( - make_dht_header(node_identity, &body, flags), + envelope.header.unwrap().try_into().unwrap(), Arc::new(Peer::new( node_identity.public_key().clone(), node_identity.node_id().clone(), @@ -119,12 +157,28 @@ pub fn make_dht_inbound_message( PeerFeatures::COMMUNICATION_NODE, &[], )), - body, + envelope.body, ) } -pub fn make_dht_envelope(node_identity: &NodeIdentity, message: Vec, flags: DhtMessageFlags) -> DhtEnvelope { - DhtEnvelope::new(make_dht_header(node_identity, &message, flags).into(), message) +pub fn make_keypair() -> (CommsSecretKey, CommsPublicKey) { + CommsPublicKey::random_keypair(&mut OsRng) +} + +pub fn make_dht_envelope( + node_identity: &NodeIdentity, + mut message: Vec, + flags: DhtMessageFlags, + include_origin: bool, +) -> DhtEnvelope +{ + let (e_sk, e_pk) = make_keypair(); + if flags.is_encrypted() { + let shared_secret = crypt::generate_ecdh_secret(&e_sk, node_identity.public_key()); + message = crypt::encrypt(&shared_secret, &message).unwrap(); + } + let header = make_dht_header(node_identity, &e_pk, &e_sk, &message, flags, include_origin).into(); + DhtEnvelope::new(header, message.into()) } pub fn make_peer_manager() -> Arc { @@ -144,3 +198,27 @@ pub fn make_peer_manager() -> Arc { .map(Arc::new) .unwrap() } + +pub fn create_outbound_message(body: &[u8]) -> DhtOutboundMessage { + DhtOutboundMessage { + tag: MessageTag::new(), + destination_peer: Peer::new( + CommsPublicKey::default(), + NodeId::default(), + MultiaddressesWithStats::new(vec![]), + PeerFlags::empty(), + PeerFeatures::COMMUNICATION_NODE, + &[], + ), + destination: Default::default(), + dht_message_type: Default::default(), + network: Network::LocalTest, + dht_flags: Default::default(), + custom_header: None, + include_origin: false, + encryption: Default::default(), + body: body.to_vec().into(), + ephemeral_public_key: None, + origin_mac: None, + } +} diff --git a/comms/dht/src/test_utils/store_and_forward_mock.rs b/comms/dht/src/test_utils/store_and_forward_mock.rs index e338b323c4..66fe181a0b 100644 --- a/comms/dht/src/test_utils/store_and_forward_mock.rs +++ b/comms/dht/src/test_utils/store_and_forward_mock.rs @@ -120,7 +120,6 @@ impl StoreAndForwardMock { id: OsRng.next_u32() as i32, version: msg.version, origin_pubkey: msg.origin_pubkey, - origin_signature: msg.origin_signature, message_type: msg.message_type, destination_pubkey: msg.destination_pubkey, destination_node_id: msg.destination_node_id, diff --git a/comms/dht/tests/dht.rs b/comms/dht/tests/dht.rs index 75960b9f9c..521a07f6f8 100644 --- a/comms/dht/tests/dht.rs +++ b/comms/dht/tests/dht.rs @@ -33,7 +33,7 @@ use tari_comms::{ CommsBuilder, CommsNode, }; -use tari_comms_dht::{envelope::NodeDestination, inbound::DecryptedDhtMessage, Dht, DhtBuilder}; +use tari_comms_dht::{inbound::DecryptedDhtMessage, Dht, DhtBuilder}; use tari_storage::{lmdb_store::LMDBBuilder, LMDBWrapper}; use tari_test_utils::{async_assert_eventually, paths::create_temporary_data_path, random}; use tower::ServiceBuilder; @@ -213,12 +213,7 @@ async fn dht_discover_propagation() { node_A .dht .discovery_service_requester() - .discover_peer( - Box::new(node_D.node_identity().public_key().clone()), - None, - // Sending to a nonsense NodeId, this should still propagate towards D in a network of 4 - NodeDestination::NodeId(Box::new(Default::default())), - ) + .discover_peer(Box::new(node_D.node_identity().public_key().clone())) .await .unwrap(); diff --git a/comms/src/connection_manager/manager.rs b/comms/src/connection_manager/manager.rs index dc5c54afe0..40b9bb934a 100644 --- a/comms/src/connection_manager/manager.rs +++ b/comms/src/connection_manager/manager.rs @@ -412,7 +412,7 @@ where self.send_dialer_request(DialerRequest::CancelPendingDial(node_id.clone())) .await; - match self.active_connections.remove(&node_id) { + match self.active_connections.get(&node_id) { Some(existing_conn) => { debug!( target: LOG_TARGET, @@ -421,7 +421,7 @@ where existing_conn.peer_node_id() ); - if self.tie_break_existing_connection(&existing_conn, &new_conn) { + if self.tie_break_existing_connection(existing_conn, &new_conn) { debug!( target: LOG_TARGET, "Disconnecting existing {} connection to peer '{}' because of simultaneous dial", @@ -434,8 +434,14 @@ where Box::new(existing_conn.peer_node_id().clone()), existing_conn.direction(), )); + + // Replace existing connection with new one + let existing_conn = self + .active_connections + .insert(node_id, new_conn.clone()) + .expect("Already checked"); + self.delayed_disconnect(existing_conn); - self.active_connections.insert(node_id, new_conn.clone()); self.publish_event(PeerConnected(new_conn)); } else { debug!( @@ -447,7 +453,6 @@ where ); self.delayed_disconnect(new_conn); - self.active_connections.insert(node_id, existing_conn); } }, None => { @@ -586,19 +591,8 @@ where return; } - if let Err(err) = self.dialer_tx.try_send(DialerRequest::Dial(Box::new(peer), reply_tx)) { + if let Err(err) = self.dialer_tx.send(DialerRequest::Dial(Box::new(peer), reply_tx)).await { error!(target: LOG_TARGET, "Failed to send request to dialer because '{}'", err); - // TODO: If the channel is full - we'll fail to dial. This function should block until the dial - // request channel has cleared - - if let DialerRequest::Dial(_, reply_tx) = err.into_inner() { - log_if_error_fmt!( - target: LOG_TARGET, - reply_tx.send(Err(ConnectionManagerError::EstablisherChannelError)), - "Failed to send dial peer result for peer '{}'", - node_id.short_str() - ); - } } }, Err(err) => { diff --git a/comms/src/message/envelope.rs b/comms/src/message/envelope.rs index ec1076eb3f..61a2c0d749 100644 --- a/comms/src/message/envelope.rs +++ b/comms/src/message/envelope.rs @@ -31,22 +31,10 @@ macro_rules! wrap_in_envelope_body { ($($e:expr),+) => {{ use $crate::message::MessageExt; let mut envelope_body = $crate::message::EnvelopeBody::new(); - let mut error = None; $( - match $e.to_encoded_bytes() { - Ok(bytes) => envelope_body.push_part(bytes), - Err(err) => { - if error.is_none() { - error = Some(err); - } - } - } + envelope_body.push_part($e.to_encoded_bytes()); )* - - match error { - Some(err) => Err(err), - None => Ok(envelope_body), - } + envelope_body }} } diff --git a/comms/src/message/error.rs b/comms/src/message/error.rs index 76b23f2761..f56c4cdb4d 100644 --- a/comms/src/message/error.rs +++ b/comms/src/message/error.rs @@ -22,7 +22,7 @@ use crate::peer_manager::node_id::NodeIdError; use derive_error::Error; -use prost::{DecodeError, EncodeError}; +use prost::DecodeError; use tari_crypto::{ signatures::SchnorrSignatureError, tari_utilities::{ciphers::cipher::CipherError, message_format::MessageFormatError}, @@ -53,8 +53,6 @@ pub enum MessageError { InvalidHeaderPublicKey, /// Failed to decode protobuf message DecodeError(DecodeError), - /// Failed to encode protobuf message - EncodeError(EncodeError), /// Failed to decode message part of envelope body EnvelopeBodyDecodeFailed, } diff --git a/comms/src/message/mod.rs b/comms/src/message/mod.rs index edeff8a439..3f05d33a0e 100644 --- a/comms/src/message/mod.rs +++ b/comms/src/message/mod.rs @@ -78,11 +78,14 @@ pub use tag::MessageTag; pub trait MessageExt: prost::Message { /// Encodes a message, allocating the buffer on the heap as necessary - fn to_encoded_bytes(&self) -> Result, MessageError> + fn to_encoded_bytes(&self) -> Vec where Self: Sized { let mut buf = Vec::with_capacity(self.encoded_len()); - self.encode(&mut buf)?; - Ok(buf) + self.encode(&mut buf).expect( + "prost::Message::encode documentation says it is infallible unless the buffer has insufficient capacity. \ + This buffer's capacity was set with encoded_len", + ); + buf } } impl MessageExt for T {} diff --git a/comms/src/peer_manager/peer.rs b/comms/src/peer_manager/peer.rs index d97fadbb4e..b3a89ad077 100644 --- a/comms/src/peer_manager/peer.rs +++ b/comms/src/peer_manager/peer.rs @@ -213,16 +213,22 @@ impl Peer { /// Display Peer as `[peer_id]: ` impl Display for Peer { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + let flags_str = if self.flags == PeerFlags::empty() { + "".to_string() + } else { + format!("{:?}", self.flags) + }; f.write_str(&format!( "{}[{}] PK={} {} {:?} {}", - if self.is_banned() { "BANNED " } else { "" }, + flags_str, self.node_id.short_str(), self.public_key, self.addresses - .address_iter() - .next() + .addresses + .iter() .map(ToString::to_string) - .unwrap_or_else(|| "".to_string()), + .collect::>() + .join(","), match self.features { PeerFeatures::COMMUNICATION_NODE => "BASE_NODE".to_string(), PeerFeatures::COMMUNICATION_CLIENT => "WALLET".to_string(), diff --git a/comms/src/pipeline/error.rs b/comms/src/pipeline/error.rs index ccb42e506b..f225b27760 100644 --- a/comms/src/pipeline/error.rs +++ b/comms/src/pipeline/error.rs @@ -22,7 +22,6 @@ use std::{error, fmt}; -#[derive(Debug)] pub struct PipelineError { err_string: String, } @@ -35,6 +34,13 @@ impl PipelineError { } } +impl fmt::Debug for PipelineError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("PipelineError: ")?; + f.write_str(&self.err_string) + } +} + impl From<&str> for PipelineError { fn from(s: &str) -> Self { Self { diff --git a/comms/src/protocol/identity.rs b/comms/src/protocol/identity.rs index 78f734a77c..bd52299620 100644 --- a/comms/src/protocol/identity.rs +++ b/comms/src/protocol/identity.rs @@ -88,8 +88,7 @@ where features: node_identity.features().bits(), supported_protocols, } - .to_encoded_bytes() - .map_err(|_| IdentityProtocolError::ProtobufEncodingError)?; + .to_encoded_bytes(); sink.send(msg_bytes.into()).await?; sink.close().await?;