diff --git a/applications/tari_base_node/src/parser.rs b/applications/tari_base_node/src/parser.rs index 3894a2a3ea..4d5e5f1d78 100644 --- a/applications/tari_base_node/src/parser.rs +++ b/applications/tari_base_node/src/parser.rs @@ -50,7 +50,7 @@ use tari_comms::{ types::CommsPublicKey, NodeIdentity, }; -use tari_comms_dht::{envelope::NodeDestination, DhtDiscoveryRequester}; +use tari_comms_dht::DhtDiscoveryRequester; use tari_core::{ base_node::LocalNodeCommsInterface, blocks::BlockHeader, @@ -486,7 +486,7 @@ impl Parser { self.executor.spawn(async move { let start = Instant::now(); println!("🌎 Peer discovery started."); - match dht.discover_peer(dest_pubkey, None, NodeDestination::Unknown).await { + match dht.discover_peer(dest_pubkey).await { Ok(p) => { let end = Instant::now(); println!("⚡️ Discovery succeeded in {}ms!", (end - start).as_millis()); diff --git a/base_layer/core/src/base_node/chain_metadata_service/service.rs b/base_layer/core/src/base_node/chain_metadata_service/service.rs index a3107f9e50..1b4bab76d0 100644 --- a/base_layer/core/src/base_node/chain_metadata_service/service.rs +++ b/base_layer/core/src/base_node/chain_metadata_service/service.rs @@ -120,7 +120,7 @@ impl ChainMetadataService { /// Send this node's metadata to async fn update_liveness_chain_metadata(&mut self) -> Result<(), ChainMetadataSyncError> { let chain_metadata = self.base_node.get_metadata().await?; - let bytes = proto::ChainMetadata::from(chain_metadata).to_encoded_bytes()?; + let bytes = proto::ChainMetadata::from(chain_metadata).to_encoded_bytes(); self.liveness .set_pong_metadata_entry(MetadataKey::ChainMetadata, bytes) .await?; @@ -300,10 +300,7 @@ mod test { let (liveness_handle, _) = create_p2p_liveness_mock(1); let mut metadata = Metadata::new(); let proto_chain_metadata = create_sample_proto_chain_metadata(); - metadata.insert( - MetadataKey::ChainMetadata, - proto_chain_metadata.to_encoded_bytes().unwrap(), - ); + metadata.insert(MetadataKey::ChainMetadata, proto_chain_metadata.to_encoded_bytes()); let node_id = NodeId::new(); let pong_event = PongEvent { diff --git a/base_layer/core/src/base_node/service/initializer.rs b/base_layer/core/src/base_node/service/initializer.rs index 643ac067de..d2aed71668 100644 --- a/base_layer/core/src/base_node/service/initializer.rs +++ b/base_layer/core/src/base_node/service/initializer.rs @@ -136,6 +136,7 @@ async fn extract_block(msg: Arc) -> Option> { Some(DomainMessage { source_peer: msg.source_peer.clone(), dht_header: msg.dht_header.clone(), + authenticated_origin: msg.authenticated_origin.clone(), inner: block, }) }, diff --git a/base_layer/core/src/mempool/service/initializer.rs b/base_layer/core/src/mempool/service/initializer.rs index e173a026e8..8d62cc2dfc 100644 --- a/base_layer/core/src/mempool/service/initializer.rs +++ b/base_layer/core/src/mempool/service/initializer.rs @@ -121,10 +121,9 @@ async fn extract_transaction(msg: Arc) -> Option { let tx = match Transaction::try_from(tx) { Err(e) => { - let origin = msg.origin_public_key(); warn!( target: LOG_TARGET, - "Inbound transaction message from {} was ill-formed. {}", origin, e + "Inbound transaction message from {} was ill-formed. {}", msg.source_peer.public_key, e ); return None; }, @@ -133,6 +132,7 @@ async fn extract_transaction(msg: Arc) -> Option InboundDomainConnector { let DecryptedDhtMessage { source_peer, dht_header, + authenticated_origin, .. } = inbound_message; let peer_message = PeerMessage { message_header: header, source_peer: Clone::clone(&*source_peer), + authenticated_origin, dht_header, body: msg_bytes, }; @@ -141,21 +143,17 @@ mod test { use crate::test_utils::{make_dht_inbound_message, make_node_identity}; use futures::{channel::mpsc, executor::block_on, StreamExt}; use tari_comms::{message::MessageExt, wrap_in_envelope_body}; - use tari_comms_dht::{domain_message::MessageHeader, envelope::DhtMessageFlags}; + use tari_comms_dht::domain_message::MessageHeader; use tower::ServiceExt; #[tokio_macros::test_basic] async fn handle_message() { let (tx, mut rx) = mpsc::channel(1); let header = MessageHeader::new(123); - let msg = wrap_in_envelope_body!(header, b"my message".to_vec()).unwrap(); - - let inbound_message = make_dht_inbound_message( - &make_node_identity(), - msg.to_encoded_bytes().unwrap(), - DhtMessageFlags::empty(), - ); - let decrypted = DecryptedDhtMessage::succeeded(msg, inbound_message); + let msg = wrap_in_envelope_body!(header, b"my message".to_vec()); + + let inbound_message = make_dht_inbound_message(&make_node_identity(), msg.to_encoded_bytes()); + let decrypted = DecryptedDhtMessage::succeeded(msg, None, inbound_message); InboundDomainConnector::new(tx).oneshot(decrypted).await.unwrap(); let peer_message = block_on(rx.next()).unwrap(); @@ -167,14 +165,10 @@ mod test { async fn send_on_sink() { let (tx, mut rx) = mpsc::channel(1); let header = MessageHeader::new(123); - let msg = wrap_in_envelope_body!(header, b"my message".to_vec()).unwrap(); + let msg = wrap_in_envelope_body!(header, b"my message".to_vec()); - let inbound_message = make_dht_inbound_message( - &make_node_identity(), - msg.to_encoded_bytes().unwrap(), - DhtMessageFlags::empty(), - ); - let decrypted = DecryptedDhtMessage::succeeded(msg, inbound_message); + let inbound_message = make_dht_inbound_message(&make_node_identity(), msg.to_encoded_bytes()); + let decrypted = DecryptedDhtMessage::succeeded(msg, None, inbound_message); InboundDomainConnector::new(tx).send(decrypted).await.unwrap(); @@ -187,14 +181,10 @@ mod test { async fn handle_message_fail_deserialize() { let (tx, mut rx) = mpsc::channel(1); let header = b"dodgy header".to_vec(); - let msg = wrap_in_envelope_body!(header, b"message".to_vec()).unwrap(); - - let inbound_message = make_dht_inbound_message( - &make_node_identity(), - msg.to_encoded_bytes().unwrap(), - DhtMessageFlags::empty(), - ); - let decrypted = DecryptedDhtMessage::succeeded(msg, inbound_message); + let msg = wrap_in_envelope_body!(header, b"message".to_vec()); + + let inbound_message = make_dht_inbound_message(&make_node_identity(), msg.to_encoded_bytes()); + let decrypted = DecryptedDhtMessage::succeeded(msg, None, inbound_message); InboundDomainConnector::new(tx).oneshot(decrypted).await.unwrap_err(); assert!(rx.try_next().unwrap().is_none()); @@ -206,13 +196,9 @@ mod test { // from it's call function let (tx, _) = mpsc::channel(1); let header = MessageHeader::new(123); - let msg = wrap_in_envelope_body!(header, b"my message".to_vec()).unwrap(); - let inbound_message = make_dht_inbound_message( - &make_node_identity(), - msg.to_encoded_bytes().unwrap(), - DhtMessageFlags::empty(), - ); - let decrypted = DecryptedDhtMessage::succeeded(msg, inbound_message); + let msg = wrap_in_envelope_body!(header, b"my message".to_vec()); + let inbound_message = make_dht_inbound_message(&make_node_identity(), msg.to_encoded_bytes()); + let decrypted = DecryptedDhtMessage::succeeded(msg, None, inbound_message); let result = InboundDomainConnector::new(tx).oneshot(decrypted).await; assert!(result.is_err()); } diff --git a/base_layer/p2p/src/comms_connector/peer_message.rs b/base_layer/p2p/src/comms_connector/peer_message.rs index f1d53c205f..4a62e393b6 100644 --- a/base_layer/p2p/src/comms_connector/peer_message.rs +++ b/base_layer/p2p/src/comms_connector/peer_message.rs @@ -34,28 +34,13 @@ pub struct PeerMessage { pub source_peer: Peer, /// Domain message header pub message_header: MessageHeader, + /// This messages authenticated origin, otherwise None + pub authenticated_origin: Option, /// Serialized message data pub body: Vec, } impl PeerMessage { - pub fn new(dht_header: DhtMessageHeader, source_peer: Peer, message_header: MessageHeader, body: Vec) -> Self { - Self { - body, - message_header, - dht_header, - source_peer, - } - } - - pub fn origin_public_key(&self) -> &CommsPublicKey { - self.dht_header - .origin - .as_ref() - .map(|o| &o.public_key) - .unwrap_or(&self.source_peer.public_key) - } - pub fn decode_message(&self) -> Result where T: prost::Message + Default { let msg = T::decode(self.body.as_slice())?; diff --git a/base_layer/p2p/src/domain_message.rs b/base_layer/p2p/src/domain_message.rs index 7aea2a50c8..0299a466fb 100644 --- a/base_layer/p2p/src/domain_message.rs +++ b/base_layer/p2p/src/domain_message.rs @@ -32,6 +32,8 @@ pub struct DomainMessage { /// This DHT header of this message. If `DhtMessageHeader::origin_public_key` is different from the /// `source_peer.public_key`, this message was forwarded. pub dht_header: DhtMessageHeader, + /// The authenticated origin public key of this message or None a message origin was not provided. + pub authenticated_origin: Option, /// The domain-level message pub inner: T, } @@ -48,32 +50,15 @@ impl DomainMessage { /// Consumes this object returning the public key of the original sender of this message and the message itself pub fn into_origin_and_inner(self) -> (CommsPublicKey, T) { let inner = self.inner; - let pk = self - .dht_header - .origin - .map(|o| o.public_key) - .unwrap_or(self.source_peer.public_key); + let pk = self.authenticated_origin.unwrap_or(self.source_peer.public_key); (pk, inner) } - /// Returns true of this message was forwarded from another peer, otherwise false - pub fn is_forwarded(&self) -> bool { - self.dht_header - .origin - .as_ref() - // If the source and origin are different, then the message was forwarded - .map(|o| o.public_key != self.source_peer.public_key) - // Otherwise, if no origin is specified, the message was sent directly from the peer - .unwrap_or(false) - } - /// Returns the public key that sent this message. If no origin is specified, then the source peer /// sent this message. pub fn origin_public_key(&self) -> &CommsPublicKey { - self.dht_header - .origin + self.authenticated_origin .as_ref() - .map(|o| &o.public_key) .unwrap_or(&self.source_peer.public_key) } @@ -88,6 +73,7 @@ impl DomainMessage { DomainMessage { source_peer: self.source_peer, dht_header: self.dht_header, + authenticated_origin: self.authenticated_origin, inner, } } @@ -103,6 +89,7 @@ impl DomainMessage { Ok(DomainMessage { source_peer: self.source_peer, dht_header: self.dht_header, + authenticated_origin: self.authenticated_origin, inner, }) } diff --git a/base_layer/p2p/src/services/liveness/service.rs b/base_layer/p2p/src/services/liveness/service.rs index 8d71de1c83..fe9ce3fe9d 100644 --- a/base_layer/p2p/src/services/liveness/service.rs +++ b/base_layer/p2p/src/services/liveness/service.rs @@ -513,13 +513,16 @@ mod test { &[], ); DomainMessage { - dht_header: DhtMessageHeader::new( - Default::default(), - DhtMessageType::None, - None, - Network::LocalTest, - Default::default(), - ), + dht_header: DhtMessageHeader { + version: 0, + destination: Default::default(), + origin_mac: Vec::new(), + ephemeral_public_key: None, + message_type: DhtMessageType::None, + network: Network::LocalTest, + flags: Default::default(), + }, + authenticated_origin: None, source_peer, inner, } diff --git a/base_layer/p2p/src/services/utils.rs b/base_layer/p2p/src/services/utils.rs index 111926ab58..70699d9f10 100644 --- a/base_layer/p2p/src/services/utils.rs +++ b/base_layer/p2p/src/services/utils.rs @@ -43,6 +43,7 @@ where T: prost::Message + Default { Ok(DomainMessage { source_peer: serialized.source_peer.clone(), dht_header: serialized.dht_header.clone(), + authenticated_origin: serialized.authenticated_origin.clone(), inner: serialized.decode_message()?, }) } diff --git a/base_layer/p2p/src/test_utils.rs b/base_layer/p2p/src/test_utils.rs index 1582449e63..0b90f38724 100644 --- a/base_layer/p2p/src/test_utils.rs +++ b/base_layer/p2p/src/test_utils.rs @@ -25,13 +25,11 @@ use std::sync::Arc; use tari_comms::{ multiaddr::Multiaddr, peer_manager::{NodeIdentity, Peer, PeerFeatures, PeerFlags}, - utils::signature, }; use tari_comms_dht::{ - envelope::{DhtMessageFlags, DhtMessageHeader, DhtMessageOrigin, DhtMessageType, Network, NodeDestination}, + envelope::{DhtMessageFlags, DhtMessageHeader, DhtMessageType, Network, NodeDestination}, inbound::DhtInboundMessage, }; -use tari_crypto::tari_utilities::message_format::MessageFormat; macro_rules! unwrap_oms_send_msg { ($var:expr, reply_value=$reply_value:expr) => { @@ -61,31 +59,21 @@ pub fn make_node_identity() -> Arc { ) } -pub fn make_dht_header(node_identity: &NodeIdentity, message: &Vec, flags: DhtMessageFlags) -> DhtMessageHeader { +pub fn make_dht_header() -> DhtMessageHeader { DhtMessageHeader { version: 0, destination: NodeDestination::Unknown, - origin: Some(DhtMessageOrigin { - public_key: node_identity.public_key().clone(), - signature: signature::sign(&mut OsRng, node_identity.secret_key().clone(), message) - .unwrap() - .to_binary() - .unwrap(), - }), + origin_mac: Vec::new(), + ephemeral_public_key: None, message_type: DhtMessageType::None, network: Network::LocalTest, - flags, + flags: DhtMessageFlags::NONE, } } -pub fn make_dht_inbound_message( - node_identity: &NodeIdentity, - message: Vec, - flags: DhtMessageFlags, -) -> DhtInboundMessage -{ +pub fn make_dht_inbound_message(node_identity: &NodeIdentity, message: Vec) -> DhtInboundMessage { DhtInboundMessage::new( - make_dht_header(node_identity, &message, flags), + make_dht_header(), Arc::new(Peer::new( node_identity.public_key().clone(), node_identity.node_id().clone(), diff --git a/base_layer/wallet/tests/output_manager_service/service.rs b/base_layer/wallet/tests/output_manager_service/service.rs index a91a1055d6..0660d4665e 100644 --- a/base_layer/wallet/tests/output_manager_service/service.rs +++ b/base_layer/wallet/tests/output_manager_service/service.rs @@ -680,8 +680,8 @@ fn test_startup_utxo_scan() { let output3 = UnblindedOutput::new(MicroTari::from(value3), key3, None); runtime.block_on(oms.add_output(output3.clone())).unwrap(); - let call = outbound_service.pop_call().unwrap(); - let envelope_body = EnvelopeBody::decode(&mut call.1.as_slice()).unwrap(); + let (_, body) = outbound_service.pop_call().unwrap(); + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); let bn_request: BaseNodeProto::BaseNodeServiceRequest = envelope_body .decode_part::(1) .unwrap() @@ -755,8 +755,8 @@ fn test_startup_utxo_scan() { runtime.block_on(oms.sync_with_base_node()).unwrap(); - let call = outbound_service.pop_call().unwrap(); - let envelope_body = EnvelopeBody::decode(&mut call.1.as_slice()).unwrap(); + let (_, body) = outbound_service.pop_call().unwrap(); + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); let bn_request: BaseNodeProto::BaseNodeServiceRequest = envelope_body .decode_part::(1) .unwrap() diff --git a/base_layer/wallet/tests/support/comms_and_services.rs b/base_layer/wallet/tests/support/comms_and_services.rs index a8d50ed86a..13a6afd2da 100644 --- a/base_layer/wallet/tests/support/comms_and_services.rs +++ b/base_layer/wallet/tests/support/comms_and_services.rs @@ -77,13 +77,15 @@ pub fn create_dummy_message(inner: T, public_key: &CommsPublicKey) -> DomainM ); DomainMessage { dht_header: DhtMessageHeader { - origin: None, + ephemeral_public_key: None, + origin_mac: Vec::new(), version: Default::default(), message_type: Default::default(), flags: Default::default(), network: Network::LocalTest, destination: Default::default(), }, + authenticated_origin: None, source_peer: peer_source, inner, } diff --git a/base_layer/wallet/tests/transaction_service/service.rs b/base_layer/wallet/tests/transaction_service/service.rs index 4bdf23d1ff..1ec4b344b2 100644 --- a/base_layer/wallet/tests/transaction_service/service.rs +++ b/base_layer/wallet/tests/transaction_service/service.rs @@ -791,7 +791,7 @@ fn test_accepting_unknown_tx_id_and_malformed_reply(1) .unwrap() @@ -954,7 +954,7 @@ fn finalize_tx_with_incorrect_pubkey(al .wait_call_count(1, Duration::from_secs(10)) .unwrap(); let (_, body) = alice_outbound_service.pop_call().unwrap(); - let envelope_body = EnvelopeBody::decode(body.as_slice()).unwrap(); + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); let recipient_reply: RecipientSignedMessage = envelope_body .decode_part::(1) .unwrap() @@ -1058,7 +1058,7 @@ fn finalize_tx_with_missing_output(alic .wait_call_count(1, Duration::from_secs(10)) .unwrap(); let (_, body) = alice_outbound_service.pop_call().unwrap(); - let envelope_body = EnvelopeBody::decode(body.as_slice()).unwrap(); + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); let recipient_reply: RecipientSignedMessage = envelope_body .decode_part::(1) .unwrap() @@ -1448,8 +1448,8 @@ fn transaction_mempool_broadcast() { )) .unwrap(); - let call = alice_outbound_service.pop_call().unwrap(); - let envelope_body = EnvelopeBody::decode(&mut call.1.as_slice()).unwrap(); + let (_, body) = alice_outbound_service.pop_call().unwrap(); + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); let tx_sender_msg: TransactionSenderMessage = envelope_body .decode_part::(1) .unwrap() @@ -1478,8 +1478,8 @@ fn transaction_mempool_broadcast() { timeout = Duration::from_secs(20) ) }); - let call = bob_outbound_service.pop_call().unwrap(); - let envelope_body = EnvelopeBody::decode(&mut call.1.as_slice()).unwrap(); + let (_, body) = bob_outbound_service.pop_call().unwrap(); + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); let tx_reply_msg: RecipientSignedMessage = envelope_body .decode_part::(1) .unwrap() @@ -1526,8 +1526,8 @@ fn transaction_mempool_broadcast() { assert_eq!(alice_completed_tx.status, TransactionStatus::Completed); - let call = alice_outbound_service.pop_call().unwrap(); - let envelope_body = EnvelopeBody::decode(&mut call.1.as_slice()).unwrap(); + let (_, body) = alice_outbound_service.pop_call().unwrap(); + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); let msr = envelope_body .decode_part::(1) .unwrap() @@ -1537,9 +1537,9 @@ fn transaction_mempool_broadcast() { let _ = alice_outbound_service.pop_call().unwrap(); // burn a mempool request let _ = alice_outbound_service.pop_call().unwrap(); // burn a mempool request - let call = alice_outbound_service.pop_call().unwrap(); // this should be the sending of the finalized tx to the receiver + let (_, body) = alice_outbound_service.pop_call().unwrap(); // this should be the sending of the finalized tx to the receiver - let envelope_body = EnvelopeBody::decode(&mut call.1.as_slice()).unwrap(); + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); let tx_finalized = envelope_body .decode_part::(1) .unwrap() @@ -1772,8 +1772,8 @@ fn transaction_base_node_monitoring() { )) .unwrap(); - let call = alice_outbound_service.pop_call().unwrap(); - let envelope_body = EnvelopeBody::decode(&mut call.1.as_slice()).unwrap(); + let (_, body) = alice_outbound_service.pop_call().unwrap(); + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); let tx_sender_msg: TransactionSenderMessage = envelope_body .decode_part::(1) .unwrap() @@ -1802,8 +1802,8 @@ fn transaction_base_node_monitoring() { timeout = Duration::from_secs(20) ) }); - let call = bob_outbound_service.pop_call().unwrap(); - let envelope_body = EnvelopeBody::decode(&mut call.1.as_slice()).unwrap(); + let (_, body) = bob_outbound_service.pop_call().unwrap(); + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); let tx_reply_msg: RecipientSignedMessage = envelope_body .decode_part::(1) .unwrap() @@ -1845,8 +1845,8 @@ fn transaction_base_node_monitoring() { )) .unwrap(); - let call = alice_outbound_service.pop_call().unwrap(); - let envelope_body = EnvelopeBody::decode(&mut call.1.as_slice()).unwrap(); + let (_, body) = alice_outbound_service.pop_call().unwrap(); + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); let tx_sender_msg: TransactionSenderMessage = envelope_body .decode_part::(1) .unwrap() @@ -1875,8 +1875,8 @@ fn transaction_base_node_monitoring() { timeout = Duration::from_secs(20) ) }); - let call = bob_outbound_service.pop_call().unwrap(); - let envelope_body = EnvelopeBody::decode(&mut call.1.as_slice()).unwrap(); + let (_, body) = bob_outbound_service.pop_call().unwrap(); + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); let tx_reply_msg: RecipientSignedMessage = envelope_body .decode_part::(1) .unwrap() @@ -1912,8 +1912,8 @@ fn transaction_base_node_monitoring() { let _ = alice_outbound_service.pop_call().unwrap(); // burn a base node request - let call = alice_outbound_service.pop_call().unwrap(); - let envelope_body = EnvelopeBody::decode(&mut call.1.as_slice()).unwrap(); + let (_, body) = alice_outbound_service.pop_call().unwrap(); + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); let msr = envelope_body .decode_part::(1) .unwrap() @@ -2362,8 +2362,8 @@ fn transaction_cancellation_when_not_in_mempool() { )) .unwrap(); - let call = alice_outbound_service.pop_call().unwrap(); - let envelope_body = EnvelopeBody::decode(&mut call.1.as_slice()).unwrap(); + let (_, body) = alice_outbound_service.pop_call().unwrap(); + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); let tx_sender_msg: TransactionSenderMessage = envelope_body .decode_part::(1) .unwrap() @@ -2392,8 +2392,8 @@ fn transaction_cancellation_when_not_in_mempool() { timeout = Duration::from_secs(30) ) }); - let call = bob_outbound_service.pop_call().unwrap(); - let envelope_body = EnvelopeBody::decode(&mut call.1.as_slice()).unwrap(); + let (_, body) = bob_outbound_service.pop_call().unwrap(); + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); let tx_reply_msg: RecipientSignedMessage = envelope_body .decode_part::(1) .unwrap() @@ -2429,8 +2429,8 @@ fn transaction_cancellation_when_not_in_mempool() { let _ = alice_outbound_service.pop_call().unwrap(); // burn a base node request - let call = alice_outbound_service.pop_call().unwrap(); - let envelope_body = EnvelopeBody::decode(&mut call.1.as_slice()).unwrap(); + let (_, body) = alice_outbound_service.pop_call().unwrap(); + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); let msr = envelope_body .decode_part::(1) .unwrap() diff --git a/comms/dht/examples/memorynet.rs b/comms/dht/examples/memorynet.rs index ef04d93a96..527019d4ab 100644 --- a/comms/dht/examples/memorynet.rs +++ b/comms/dht/examples/memorynet.rs @@ -68,8 +68,7 @@ use tari_comms::{ ConnectionManagerEvent, PeerConnection, }; -use tari_comms_dht::{envelope::NodeDestination, inbound::DecryptedDhtMessage, Dht, DhtBuilder}; -use tari_crypto::tari_utilities::ByteArray; +use tari_comms_dht::{inbound::DecryptedDhtMessage, Dht, DhtBuilder}; use tari_storage::{lmdb_store::LMDBBuilder, LMDBWrapper}; use tari_test_utils::{paths::create_temporary_data_path, random}; use tokio::{runtime, time}; @@ -216,17 +215,7 @@ async fn main() { peer_list_summary(&wallets).await; - total_messages += discovery(&wallets, &mut messaging_events_rx, false, true).await; - - take_a_break().await; - total_messages += drain_messaging_events(&mut messaging_events_rx, false).await; - - total_messages += discovery(&wallets, &mut messaging_events_rx, true, false).await; - - take_a_break().await; - total_messages += drain_messaging_events(&mut messaging_events_rx, false).await; - - total_messages += discovery(&wallets, &mut messaging_events_rx, false, false).await; + total_messages += discovery(&wallets, &mut messaging_events_rx).await; take_a_break().await; total_messages += drain_messaging_events(&mut messaging_events_rx, false).await; @@ -247,13 +236,7 @@ async fn shutdown_all(nodes: Vec) { future::join_all(tasks).await; } -async fn discovery( - wallets: &[TestNode], - messaging_events_rx: &mut MessagingEventRx, - use_network_region: bool, - use_destination_node_id: bool, -) -> usize -{ +async fn discovery(wallets: &[TestNode], messaging_events_rx: &mut MessagingEventRx) -> usize { let mut successes = 0; let mut total_messages = 0; let mut total_time = Duration::from_secs(0); @@ -265,30 +248,11 @@ async fn discovery( peer_list_summary(&[wallet1, wallet2]).await; - let mut destination = NodeDestination::Unknown; - if use_network_region { - let mut new_node_id = [0; 13]; - let node_id = wallet2.get_node_id(); - let buf = &mut new_node_id[..10]; - buf.copy_from_slice(&node_id.as_bytes()[..10]); - let regional_node_id = NodeId::from_bytes(&new_node_id).unwrap(); - destination = NodeDestination::NodeId(Box::new(regional_node_id)); - } - - let mut node_id_dest = None; - if use_destination_node_id { - node_id_dest = Some(wallet2.get_node_id()); - } - let start = Instant::now(); let discovery_result = wallet1 .dht .discovery_service_requester() - .discover_peer( - Box::new(wallet2.node_identity().public_key().clone()), - node_id_dest, - destination, - ) + .discover_peer(Box::new(wallet2.node_identity().public_key().clone())) .await; let end = Instant::now(); @@ -387,15 +351,13 @@ async fn do_store_and_forward_discovery( node_identity.public_key(), ); let mut first_wallet_discovery_req = wallets[0].dht.discovery_service_requester(); + + let start = Instant::now(); let discovery_task = runtime::Handle::current().spawn({ let node_identity = node_identity.clone(); async move { first_wallet_discovery_req - .discover_peer( - Box::new(node_identity.public_key().clone()), - Some(node_identity.node_id().clone()), - node_identity.public_key().clone().into(), - ) + .discover_peer(Box::new(node_identity.public_key().clone())) .await } }); @@ -410,7 +372,6 @@ async fn do_store_and_forward_discovery( let (comms, dht) = setup_comms_dht(node_identity, create_peer_storage(all_peers), tx).await; let wallet = TestNode::new(comms, dht, None, ims_rx, messaging_tx); - let start = Instant::now(); wallet.dht.dht_requester().send_request_stored_messages().await.unwrap(); total_messages += match discovery_task.await.unwrap() { @@ -497,8 +458,8 @@ fn connection_manager_logger( PeerConnectWillClose(_, node_id, direction) => { println!( "'{}' will disconnect {} connection to '{}'", - direction, get_name(node_id), + direction, node_name, ); }, diff --git a/comms/dht/migrations/2020-04-07-161148_remove_origin_signature/down.sql b/comms/dht/migrations/2020-04-07-161148_remove_origin_signature/down.sql new file mode 100644 index 0000000000..a4b1198712 --- /dev/null +++ b/comms/dht/migrations/2020-04-07-161148_remove_origin_signature/down.sql @@ -0,0 +1 @@ +-- No going back diff --git a/comms/dht/migrations/2020-04-07-161148_remove_origin_signature/up.sql b/comms/dht/migrations/2020-04-07-161148_remove_origin_signature/up.sql new file mode 100644 index 0000000000..d672ff7ac5 --- /dev/null +++ b/comms/dht/migrations/2020-04-07-161148_remove_origin_signature/up.sql @@ -0,0 +1,20 @@ +DROP TABLE stored_messages; + +CREATE TABLE stored_messages( + id INTEGER NOT NULL PRIMARY KEY, + version INT NOT NULL, + origin_pubkey TEXT, + message_type INT NOT NULL, + destination_pubkey TEXT, + destination_node_id TEXT, + header BLOB NOT NULL, + body BLOB NOT NULL, + is_encrypted BOOLEAN NOT NULL CHECK (is_encrypted IN (0,1)), + priority INT NOT NULL, + stored_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP +); + +CREATE INDEX idx_stored_messages_destination_pubkey ON stored_messages (destination_pubkey); +CREATE INDEX idx_stored_messages_destination_node_id ON stored_messages (destination_node_id); +CREATE INDEX idx_stored_messages_stored_at ON stored_messages (stored_at); +CREATE INDEX idx_stored_messages_priority ON stored_messages (priority); diff --git a/comms/dht/src/dht.rs b/comms/dht/src/dht.rs index 2e5d6c2182..4453d5fbbb 100644 --- a/comms/dht/src/dht.rs +++ b/comms/dht/src/dht.rs @@ -190,11 +190,8 @@ impl Dht { S::Future: Send, { let builder = ServiceBuilder::new() - .layer(inbound::DeserializeLayer::new()) - .layer(inbound::ValidateLayer::new( - self.config.network, - self.outbound_requester(), - )) + .layer(inbound::DeserializeLayer) + .layer(inbound::ValidateLayer::new(self.config.network)) .layer(inbound::DedupLayer::new(self.dht_requester())); // FIXME: There is an unresolved stack overflow issue on windows. Seems that we've reached the limit on stack @@ -262,7 +259,7 @@ impl Dht { )) .layer(MessageLoggingLayer::new("Outbound message: ")) .layer(outbound::EncryptionLayer::new(Arc::clone(&self.node_identity))) - .layer(outbound::SerializeLayer::new(Arc::clone(&self.node_identity))) + .layer(outbound::SerializeLayer) .into_inner() } @@ -348,14 +345,9 @@ mod test { let mut service = dht.inbound_middleware_layer().layer(SinkService::new(out_tx)); - let msg = wrap_in_envelope_body!(b"secret".to_vec()).unwrap(); - let dht_envelope = make_dht_envelope( - &node_identity, - msg.to_encoded_bytes().unwrap(), - DhtMessageFlags::empty(), - ); - let inbound_message = - make_comms_inbound_message(&node_identity, dht_envelope.to_encoded_bytes().unwrap().into()); + let msg = wrap_in_envelope_body!(b"secret".to_vec()); + let dht_envelope = make_dht_envelope(&node_identity, msg.to_encoded_bytes(), DhtMessageFlags::empty(), false); + let inbound_message = make_comms_inbound_message(&node_identity, dht_envelope.to_encoded_bytes().into()); let msg = { service.call(inbound_message).await.unwrap(); @@ -376,7 +368,7 @@ mod test { let (connection_manager, _) = create_connection_manager_mock(1); // Dummy out channel, we are not testing outbound here. - let (out_tx, _) = mpsc::channel(10); + let (out_tx, _out_rx) = mpsc::channel(10); let shutdown = Shutdown::new(); let dht = DhtBuilder::new( @@ -392,13 +384,10 @@ mod test { let mut service = dht.inbound_middleware_layer().layer(SinkService::new(out_tx)); - let msg = wrap_in_envelope_body!(b"secret".to_vec()).unwrap(); + let msg = wrap_in_envelope_body!(b"secret".to_vec()); // Encrypt for self - let ecdh_key = crypt::generate_ecdh_secret(node_identity.secret_key(), node_identity.public_key()); - let encrypted_bytes = crypt::encrypt(&ecdh_key, &msg.to_encoded_bytes().unwrap()).unwrap(); - let dht_envelope = make_dht_envelope(&node_identity, encrypted_bytes, DhtMessageFlags::ENCRYPTED); - let inbound_message = - make_comms_inbound_message(&node_identity, dht_envelope.to_encoded_bytes().unwrap().into()); + let dht_envelope = make_dht_envelope(&node_identity, msg.to_encoded_bytes(), DhtMessageFlags::ENCRYPTED, true); + let inbound_message = make_comms_inbound_message(&node_identity, dht_envelope.to_encoded_bytes().into()); let msg = { service.call(inbound_message).await.unwrap(); @@ -437,25 +426,17 @@ mod test { let mut service = dht.inbound_middleware_layer().layer(SinkService::new(next_service_tx)); - let msg = wrap_in_envelope_body!(b"unencrypteable".to_vec()).unwrap(); + let msg = wrap_in_envelope_body!(b"unencrypteable".to_vec()); // Encrypt for someone else let node_identity2 = make_node_identity(); let ecdh_key = crypt::generate_ecdh_secret(node_identity2.secret_key(), node_identity2.public_key()); - let encrypted_bytes = crypt::encrypt(&ecdh_key, &msg.to_encoded_bytes().unwrap()).unwrap(); - let dht_envelope = make_dht_envelope(&node_identity, encrypted_bytes, DhtMessageFlags::ENCRYPTED); - - let origin_sig = dht_envelope - .header - .as_ref() - .unwrap() - .origin - .as_ref() - .unwrap() - .signature - .clone(); - let inbound_message = - make_comms_inbound_message(&node_identity, dht_envelope.to_encoded_bytes().unwrap().into()); + let encrypted_bytes = crypt::encrypt(&ecdh_key, &msg.to_encoded_bytes()).unwrap(); + let dht_envelope = make_dht_envelope(&node_identity, encrypted_bytes, DhtMessageFlags::ENCRYPTED, true); + + let origin_mac = dht_envelope.header.as_ref().unwrap().origin_mac.clone(); + assert_eq!(origin_mac.is_empty(), false); + let inbound_message = make_comms_inbound_message(&node_identity, dht_envelope.to_encoded_bytes().into()); service.call(inbound_message).await.unwrap(); @@ -463,7 +444,7 @@ mod test { let (params, _) = oms_mock_state.pop_call().unwrap(); // Check that OMS got a request to forward with the original Dht Header - assert_eq!(params.dht_header.unwrap().origin.unwrap().signature, origin_sig); + assert_eq!(params.dht_header.unwrap().origin_mac, origin_mac); // Check the next service was not called assert!(next_service_rx.try_next().is_err()); @@ -494,18 +475,14 @@ mod test { let mut service = dht.inbound_middleware_layer().layer(SinkService::new(next_service_tx)); - let msg = wrap_in_envelope_body!(b"secret".to_vec()).unwrap(); - let mut dht_envelope = make_dht_envelope( - &node_identity, - msg.to_encoded_bytes().unwrap(), - DhtMessageFlags::empty(), - ); + let msg = wrap_in_envelope_body!(b"secret".to_vec()); + let mut dht_envelope = + make_dht_envelope(&node_identity, msg.to_encoded_bytes(), DhtMessageFlags::empty(), false); dht_envelope.header.as_mut().and_then(|header| { header.message_type = DhtMessageType::SafStoredMessages as i32; Some(header) }); - let inbound_message = - make_comms_inbound_message(&node_identity, dht_envelope.to_encoded_bytes().unwrap().into()); + let inbound_message = make_comms_inbound_message(&node_identity, dht_envelope.to_encoded_bytes().into()); service.call(inbound_message).await.unwrap_err(); // This seems like the best way to tell that an open channel is empty without the test blocking indefinitely diff --git a/comms/dht/src/discovery/requester.rs b/comms/dht/src/discovery/requester.rs index 4f97425eda..d43fb1dea3 100644 --- a/comms/dht/src/discovery/requester.rs +++ b/comms/dht/src/discovery/requester.rs @@ -20,7 +20,7 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::{discovery::DhtDiscoveryError, envelope::NodeDestination, proto::dht::DiscoveryResponseMessage}; +use crate::{discovery::DhtDiscoveryError, proto::dht::DiscoveryResponseMessage}; use futures::{ channel::{mpsc, oneshot}, SinkExt, @@ -29,51 +29,12 @@ use std::{ fmt::{Display, Error, Formatter}, time::Duration, }; -use tari_comms::{ - peer_manager::{NodeId, Peer}, - types::CommsPublicKey, -}; +use tari_comms::{peer_manager::Peer, types::CommsPublicKey}; use tokio::time; -#[derive(Debug)] -pub struct DiscoverPeerRequest { - /// The public key of the peer to be discovered. The message will be encrypted with a DH shared - /// secret using this public key. - pub dest_public_key: Box, - /// The node id of the peer to be discovered, if it is known. Providing the `NodeId` allows - /// discovery requests to reach their destination more quickly. - pub dest_node_id: Option, - /// The destination to include in the comms header. - /// `Undisclosed` will require nodes to propagate the message across the network, presumably eventually - /// reaching the destination node (the node which can decrypt the message). This will happen without - /// any intermediary nodes knowing who is being searched for. - /// `NodeId` will direct the discovery request closer to the destination or network region. - /// `PublicKey` will be propagated across the network. If any node knows the peer, the request can be - /// forwarded to them immediately. However, more nodes will know that this node is being searched for - /// which may slightly compromise privacy. - pub destination: NodeDestination, -} - -impl Display for DiscoverPeerRequest { - fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> { - f.debug_struct("DiscoverPeerRequest") - .field("dest_public_key", &format!("{}", self.dest_public_key)) - .field( - "dest_node_id", - &self - .dest_node_id - .as_ref() - .map(|node_id| format!("{}", node_id)) - .unwrap_or_else(|| "None".to_string()), - ) - .field("destination", &format!("{}", self.destination)) - .finish() - } -} - #[derive(Debug)] pub enum DhtDiscoveryRequest { - DiscoverPeer(Box<(DiscoverPeerRequest, oneshot::Sender>)>), + DiscoverPeer(Box, oneshot::Sender>), NotifyDiscoveryResponseReceived(Box), } @@ -81,7 +42,7 @@ impl Display for DhtDiscoveryRequest { fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> { use DhtDiscoveryRequest::*; match self { - DiscoverPeer(boxed) => write!(f, "DiscoverPeer({})", boxed.0), + DiscoverPeer(public_key, _) => write!(f, "DiscoverPeer({})", public_key), NotifyDiscoveryResponseReceived(discovery_resp) => { write!(f, "NotifyDiscoveryResponseReceived({:#?})", discovery_resp) }, @@ -103,21 +64,25 @@ impl DhtDiscoveryRequester { } } - pub async fn discover_peer( - &mut self, - dest_public_key: Box, - dest_node_id: Option, - destination: NodeDestination, - ) -> Result - { + /// Initiate a peer discovery + /// + /// ## Arguments + /// - `dest_public_key` - The public key of he recipient used to create a shared ECDH key which in turn is used to + /// encrypt the discovery message + /// - `destination` - The `NodeDestination` to use in the DhtHeader when sending a discovery message. + /// - `Unknown` destination will maintain complete privacy, the trade off is that discovery needs to propagate + /// the entire network to reach the destination and so may take longer + /// - `NodeId` Instruct propagation nodes to direct the message to peers closer to the given NodeId. The `NodeId` + /// may be directed to a region close to the real destination (somewhat private) or directed at a particular + /// node (not private) + /// - `PublicKey` if any node on the network knows this public key, the message will be directed to that node. + /// This sacrifices privacy for more efficient discovery in terms of network bandwidth and may result in + /// quicker discovery times. + pub async fn discover_peer(&mut self, dest_public_key: Box) -> Result { let (reply_tx, reply_rx) = oneshot::channel(); - let request = DiscoverPeerRequest { - dest_public_key, - dest_node_id, - destination, - }; + self.sender - .send(DhtDiscoveryRequest::DiscoverPeer(Box::new((request, reply_tx)))) + .send(DhtDiscoveryRequest::DiscoverPeer(dest_public_key, reply_tx)) .await?; time::timeout( diff --git a/comms/dht/src/discovery/service.rs b/comms/dht/src/discovery/service.rs index 6aea69ebce..3d0a1661d8 100644 --- a/comms/dht/src/discovery/service.rs +++ b/comms/dht/src/discovery/service.rs @@ -21,11 +21,8 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::{ - discovery::{ - requester::{DhtDiscoveryRequest, DiscoverPeerRequest}, - DhtDiscoveryError, - }, - envelope::{DhtMessageType, NodeDestination}, + discovery::{requester::DhtDiscoveryRequest, DhtDiscoveryError}, + envelope::DhtMessageType, outbound::{OutboundEncryption, OutboundMessageRequester, SendMessageParams}, proto::dht::{DiscoveryMessage, DiscoveryResponseMessage}, DhtConfig, @@ -146,11 +143,10 @@ impl DhtDiscoveryService { async fn handle_request(&mut self, request: DhtDiscoveryRequest) { use DhtDiscoveryRequest::*; match request { - DiscoverPeer(boxed) => { - let (request, reply_tx) = *boxed; + DiscoverPeer(dest_pubkey, reply_tx) => { log_if_error!( target: LOG_TARGET, - self.initiate_peer_discovery(request, reply_tx).await, + self.initiate_peer_discovery(dest_pubkey, reply_tx).await, "Failed to initiate a discovery request because '{error}'", ); }, @@ -183,16 +179,9 @@ impl DhtDiscoveryService { peer", peer.node_id.short_str() ); - // Attempt to discover them - let request = DiscoverPeerRequest { - dest_public_key: Box::new(peer.public_key), - // TODO: This should be the node region, not the node id - dest_node_id: Some(peer.node_id), - destination: Default::default(), - }; // Don't need to be notified for this discovery let (reply_tx, _) = oneshot::channel(); - if let Err(err) = self.initiate_peer_discovery(request, reply_tx).await { + if let Err(err) = self.initiate_peer_discovery(Box::new(peer.public_key), reply_tx).await { error!(target: LOG_TARGET, "Error sending discovery message: {:?}", err); } } @@ -381,13 +370,12 @@ impl DhtDiscoveryService { async fn initiate_peer_discovery( &mut self, - discovery_request: DiscoverPeerRequest, + dest_pubkey: Box, reply_tx: oneshot::Sender>, ) -> Result<(), DhtDiscoveryError> { let nonce = OsRng.next_u64(); - let public_key = discovery_request.dest_public_key.clone(); - self.send_discover(nonce, discovery_request).await?; + self.send_discover(nonce, dest_pubkey.clone()).await?; let inflight_count = self.inflight_discoveries.len(); @@ -406,7 +394,7 @@ impl DhtDiscoveryService { // Add the new inflight request. self.inflight_discoveries - .insert(nonce, DiscoveryRequestState::new(public_key, reply_tx)); + .insert(nonce, DiscoveryRequestState::new(dest_pubkey, reply_tx)); trace!( target: LOG_TARGET, @@ -420,25 +408,9 @@ impl DhtDiscoveryService { async fn send_discover( &mut self, nonce: u64, - discovery_request: DiscoverPeerRequest, + dest_public_key: Box, ) -> Result<(), DhtDiscoveryError> { - let DiscoverPeerRequest { - dest_node_id, - dest_public_key, - destination, - } = discovery_request; - - // If the destination node is is known, send to the closest peers we know. Otherwise... - let network_location_node_id = dest_node_id - .or_else(|| match &destination { - // ... if the destination is undisclosed or a public key, send discover to our closest peers - NodeDestination::Unknown | NodeDestination::PublicKey(_) => Some(self.node_identity.node_id().clone()), - // otherwise, send it to the closest peers to the given NodeId destination we know - NodeDestination::NodeId(node_id) => Some(*node_id.clone()), - }) - .expect("cannot fail"); - let discover_msg = DiscoveryMessage { node_id: self.node_identity.node_id().to_vec(), addresses: vec![self.node_identity.public_address().to_string()], @@ -447,19 +419,13 @@ impl DhtDiscoveryService { }; debug!( target: LOG_TARGET, - "Sending Discovery message for Node Id: {}", destination + "Sending Discovery message for peer public key '{}'", dest_public_key ); self.outbound_requester .send_message_no_header( SendMessageParams::new() - .closest( - network_location_node_id, - self.config.num_neighbouring_nodes, - Vec::new(), - PeerFeatures::empty(), - ) - .with_destination(destination) + .neighbours_include_clients(Vec::new()) .with_encryption(OutboundEncryption::EncryptFor(dest_public_key)) .with_dht_message_type(DhtMessageType::Discovery) .finish(), @@ -512,7 +478,7 @@ mod test { rt.spawn(service.run()); let dest_public_key = Box::new(CommsPublicKey::default()); - let result = rt.block_on(requester.discover_peer(dest_public_key.clone(), None, NodeDestination::Unknown)); + let result = rt.block_on(requester.discover_peer(dest_public_key.clone())); assert!(result.unwrap_err().is_timeout()); diff --git a/comms/dht/src/envelope.rs b/comms/dht/src/envelope.rs index ae1d0d6893..8e58296d59 100644 --- a/comms/dht/src/envelope.rs +++ b/comms/dht/src/envelope.rs @@ -20,7 +20,6 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::{consts::DHT_ENVELOPE_HEADER_VERSION, proto::envelope::DhtOrigin}; use bitflags::bitflags; use derive_error::Error; use serde::{Deserialize, Serialize}; @@ -29,11 +28,12 @@ use std::{ fmt, fmt::Display, }; -use tari_comms::{peer_manager::NodeId, types::CommsPublicKey, utils::signature}; -use tari_crypto::tari_utilities::{hex::Hex, ByteArray, ByteArrayError}; +use tari_comms::{peer_manager::NodeId, types::CommsPublicKey}; +use tari_crypto::tari_utilities::{ByteArray, ByteArrayError}; // Re-export applicable protos pub use crate::proto::envelope::{dht_header::Destination, DhtEnvelope, DhtHeader, DhtMessageType, Network}; +use bytes::Bytes; #[derive(Debug, Error)] pub enum DhtMessageError { @@ -47,6 +47,8 @@ pub enum DhtMessageError { InvalidNetwork, /// Invalid or unrecognised DHT message flags InvalidMessageFlags, + /// Invalid ephemeral public key + InvalidEphemeralPublicKey, /// Header was omitted from the message HeaderOmitted, } @@ -103,71 +105,27 @@ impl DhtMessageType { } } -#[derive(Clone, PartialEq, Eq)] -pub struct DhtMessageOrigin { - pub public_key: CommsPublicKey, - pub signature: Vec, -} - -impl fmt::Debug for DhtMessageOrigin { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.debug_struct("DhtMessageOrigin") - .field("public_key", &self.public_key.to_hex()) - .field("signature", &self.signature.to_hex()) - .finish() - } -} - -impl TryFrom for DhtMessageOrigin { - type Error = DhtMessageError; - - fn try_from(value: DhtOrigin) -> Result { - Ok(Self { - public_key: CommsPublicKey::from_bytes(&value.public_key).map_err(|_| DhtMessageError::InvalidOrigin)?, - signature: value.signature, - }) - } -} - -impl From for DhtOrigin { - fn from(value: DhtMessageOrigin) -> Self { - Self { - public_key: value.public_key.to_vec(), - signature: value.signature, - } - } -} - /// This struct mirrors the protobuf version of DhtHeader but is more ergonomic to work with. /// It is preferable to not to expose the generated prost structs publicly. #[derive(Clone, Debug, PartialEq, Eq)] pub struct DhtMessageHeader { pub version: u32, pub destination: NodeDestination, - /// Origin of the message. This can refer to the same peer that sent the message - /// or another peer if the message should be forwarded. - pub origin: Option, + /// Encoded DhtOrigin. This can refer to the same peer that sent the message + /// or another peer if the message is being propagated. + pub origin_mac: Vec, + pub ephemeral_public_key: Option, pub message_type: DhtMessageType, pub network: Network, pub flags: DhtMessageFlags, } impl DhtMessageHeader { - pub fn new( - destination: NodeDestination, - message_type: DhtMessageType, - origin: Option, - network: Network, - flags: DhtMessageFlags, - ) -> Self - { - Self { - version: DHT_ENVELOPE_HEADER_VERSION, - destination, - origin, - message_type, - network, - flags, + pub fn is_valid(&self) -> bool { + if self.flags.contains(DhtMessageFlags::ENCRYPTED) { + !self.origin_mac.is_empty() && self.ephemeral_public_key.is_some() + } else { + true } } } @@ -176,8 +134,8 @@ impl Display for DhtMessageHeader { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { write!( f, - "DhtMessageHeader (Dest:{}, Origin:{:?}, Type:{:?}, Network:{:?}, Flags:{:?})", - self.destination, self.origin, self.message_type, self.network, self.flags + "DhtMessageHeader (Dest:{}, Type:{:?}, Network:{:?}, Flags:{:?})", + self.destination, self.message_type, self.network, self.flags ) } } @@ -193,15 +151,20 @@ impl TryFrom for DhtMessageHeader { .map(Option::unwrap) .ok_or_else(|| DhtMessageError::InvalidDestination)?; - let origin = match header.origin { - Some(origin) => Some(origin.try_into()?), - None => None, + let ephemeral_public_key = if header.ephemeral_public_key.is_empty() { + None + } else { + Some( + CommsPublicKey::from_bytes(&header.ephemeral_public_key) + .map_err(|_| DhtMessageError::InvalidEphemeralPublicKey)?, + ) }; Ok(Self { version: header.version, destination, - origin, + origin_mac: header.origin_mac, + ephemeral_public_key, message_type: DhtMessageType::from_i32(header.message_type) .ok_or_else(|| DhtMessageError::InvalidMessageType)?, network: Network::from_i32(header.network).ok_or_else(|| DhtMessageError::InvalidNetwork)?, @@ -225,7 +188,12 @@ impl From for DhtHeader { fn from(header: DhtMessageHeader) -> Self { Self { version: header.version, - origin: header.origin.map(Into::into), + ephemeral_public_key: header + .ephemeral_public_key + .as_ref() + .map(ByteArray::to_vec) + .unwrap_or_else(Vec::new), + origin_mac: header.origin_mac, destination: Some(header.destination.into()), message_type: header.message_type as i32, network: header.network as i32, @@ -235,44 +203,12 @@ impl From for DhtHeader { } impl DhtEnvelope { - pub fn new(header: DhtHeader, body: Vec) -> Self { + pub fn new(header: DhtHeader, body: Bytes) -> Self { Self { header: Some(header), - body, + body: body.to_vec(), } } - - /// Returns true if the header and origin are present, otherwise false - pub fn has_origin(&self) -> bool { - self.header.as_ref().map(|h| h.origin.is_some()).unwrap_or(false) - } - - /// Verifies the origin signature and returns true if it is valid. - /// - /// This method panics if called on an envelope without an origin. This should be checked before calling this - /// function by using the `DhtEnvelope::has_origin` method - pub fn is_origin_signature_valid(&self) -> bool { - self.header - .as_ref() - .and_then(|header| { - let origin = header - .origin - .as_ref() - .expect("call is_origin_signature_valid on envelope without origin"); - - CommsPublicKey::from_bytes(&origin.public_key) - .map(|pk| (pk, &origin.signature)) - .ok() - }) - .map(|(origin_public_key, origin_signature)| { - match signature::verify(&origin_public_key, origin_signature, &self.body) { - Ok(is_valid) => is_valid, - // error means that the signature could not deserialize, so is invalid - Err(_) => false, - } - }) - .unwrap_or(false) - } } /// Represents the ways a destination node can be represented. diff --git a/comms/dht/src/inbound/decryption.rs b/comms/dht/src/inbound/decryption.rs index ade9cb1cc7..18797eb380 100644 --- a/comms/dht/src/inbound/decryption.rs +++ b/comms/dht/src/inbound/decryption.rs @@ -22,18 +22,43 @@ use crate::{ crypt, - envelope::DhtMessageFlags, + envelope::{DhtMessageFlags, DhtMessageHeader}, inbound::message::{DecryptedDhtMessage, DhtInboundMessage}, + proto::envelope::OriginMac, }; +use derive_error::Error; use futures::{task::Context, Future}; use log::*; use prost::Message; use std::{sync::Arc, task::Poll}; -use tari_comms::{message::EnvelopeBody, peer_manager::NodeIdentity, pipeline::PipelineError}; +use tari_comms::{ + message::EnvelopeBody, + peer_manager::NodeIdentity, + pipeline::PipelineError, + types::CommsPublicKey, + utils::signature, +}; +use tari_crypto::tari_utilities::ByteArray; use tower::{layer::Layer, Service, ServiceExt}; const LOG_TARGET: &str = "comms::middleware::decryption"; +#[derive(Error, Debug)] +enum DecryptionError { + /// Failed to validate origin MAC signature + OriginMacInvalidSignature, + /// Origin MAC contained an invalid public key + OriginMacInvalidPublicKey, + /// Origin MAC not provided for encrypted message + OriginMacNotProvided, + /// Failed to decrypt origin MAC + OriginMacDecryptedFailed, + /// Failed to decode clear-text origin MAC + OriginMacClearTextDecodeFailed, + /// Failed to decrypt message body + MessageBodyDecryptionFailed, +} + /// This layer is responsible for attempting to decrypt inbound messages. pub struct DecryptionLayer { node_identity: Arc, @@ -96,20 +121,42 @@ where S: Service ) -> Result<(), PipelineError> { let dht_header = &message.dht_header; + if !dht_header.flags.contains(DhtMessageFlags::ENCRYPTED) { return Self::success_not_encrypted(next_service, message).await; } - let origin = dht_header - .origin + let e_pk = dht_header + .ephemeral_public_key .as_ref() - // TODO: #banheuristics - this should not have been sent/propagated - .ok_or_else(|| "Message origin field is required for encrypted messages")?; + // TODO: #banheuristic - encrypted message sent without ephemeral public key + .ok_or("Ephemeral public key not provided for encrypted message")?; + + let shared_secret = crypt::generate_ecdh_secret(node_identity.secret_key(), e_pk); - debug!(target: LOG_TARGET, "Attempting to decrypt message"); - let shared_secret = crypt::generate_ecdh_secret(node_identity.secret_key(), &origin.public_key); - match crypt::decrypt(&shared_secret, &message.body) { - Ok(decrypted) => Self::decryption_succeeded(next_service, &node_identity, message, &decrypted).await, + // Decrypt and verify the origin + let authenticated_origin = match Self::attempt_decrypt_origin_mac(&shared_secret, dht_header) { + Ok(origin) => { + // If this fails, discard the message because we decrypted and deserialized the message with our shared + // ECDH secret but the message could not be authenticated + Self::authenticate_origin_mac(&origin, &message.body).map_err(PipelineError::from_debug)? + }, + Err(err) => { + debug!(target: LOG_TARGET, "Unable to decrypt message origin: {}", err); + return Self::decryption_failed(next_service, &node_identity, message).await; + }, + }; + + debug!( + target: LOG_TARGET, + "Attempting to decrypt message body from origin public key '{}'", authenticated_origin + ); + match Self::attempt_decrypt_message_body(&shared_secret, &message.body) { + Ok(message_body) => { + debug!(target: LOG_TARGET, "Message successfully decrypted"); + let msg = DecryptedDhtMessage::succeeded(message_body, Some(authenticated_origin), message); + next_service.oneshot(msg).await + }, Err(err) => { debug!(target: LOG_TARGET, "Unable to decrypt message: {}", err); Self::decryption_failed(next_service, &node_identity, message).await @@ -117,57 +164,81 @@ where S: Service } } - async fn decryption_succeeded( - next_service: S, - node_identity: &NodeIdentity, - message: DhtInboundMessage, - decrypted: &[u8], - ) -> Result<(), PipelineError> + fn attempt_decrypt_origin_mac( + shared_secret: &CommsPublicKey, + dht_header: &DhtMessageHeader, + ) -> Result { + let encrypted_origin_mac = Some(&dht_header.origin_mac) + .filter(|b| !b.is_empty()) + // TODO: #banheuristic - this should not have been sent/propagated + .ok_or_else(|| DecryptionError::OriginMacNotProvided)?; + let decrypted_bytes = crypt::decrypt(shared_secret, encrypted_origin_mac) + .map_err(|_| DecryptionError::OriginMacDecryptedFailed)?; + OriginMac::decode(decrypted_bytes.as_slice()).map_err(|_| DecryptionError::OriginMacDecryptedFailed) + } + + fn authenticate_origin_mac(origin: &OriginMac, body: &[u8]) -> Result { + let public_key = + CommsPublicKey::from_bytes(&origin.public_key).map_err(|_| DecryptionError::OriginMacInvalidPublicKey)?; + if signature::verify(&public_key, &origin.signature, body).unwrap_or(false) { + Ok(public_key) + } else { + Err(DecryptionError::OriginMacInvalidSignature) + } + } + + fn attempt_decrypt_message_body( + shared_secret: &CommsPublicKey, + message_body: &[u8], + ) -> Result + { + let decrypted = + crypt::decrypt(shared_secret, message_body).map_err(|_| DecryptionError::MessageBodyDecryptionFailed)?; // Deserialization into an EnvelopeBody is done here to determine if the // decryption produced valid bytes or not. - let result = EnvelopeBody::decode(decrypted).and_then(|body| { - // Check if we received a body length of zero - // - // In addition to a peer sending a zero-length EnvelopeBody, decoding can erroneously succeed - // if the decrypted bytes happen to be valid protobuf encoding. This is very possible and - // the decrypt_inbound_fail test below _will_ sporadically fail without the following check. - // This is because proto3 will set fields to their default value if they don't exist in a valid encoding. - // - // For the parts of EnvelopeBody to be erroneously populated with bytes, all of these - // conditions would have to be true: - // 1. field type == 2 (length-delimited) - // 2. field number == 1 - // 3. the subsequent byte(s) would have to be varint-encoded length which does not overflow - // 4. the rest of the bytes would have to be valid protobuf encoding - // - // The chance of this happening is extremely negligible. - if body.is_empty() { - return Err(prost::DecodeError::new("EnvelopeBody has no parts")); - } - Ok(body) - }); - match result { - Ok(deserialized) => { - debug!(target: LOG_TARGET, "Message successfully decrypted"); - let msg = DecryptedDhtMessage::succeeded(deserialized, message); - next_service.oneshot(msg).await - }, - Err(err) => { - debug!(target: LOG_TARGET, "Unable to deserialize message: {}", err); - Self::decryption_failed(next_service, &node_identity, message).await - }, - } + EnvelopeBody::decode(decrypted.as_slice()) + .and_then(|body| { + // Check if we received a body length of zero + // + // In addition to a peer sending a zero-length EnvelopeBody, decoding can erroneously succeed + // if the decrypted bytes happen to be valid protobuf encoding. This is very possible and + // the decrypt_inbound_fail test below _will_ sporadically fail without the following check. + // This is because proto3 will set fields to their default value if they don't exist in a valid + // encoding. + // + // For the parts of EnvelopeBody to be erroneously populated with bytes, all of these + // conditions would have to be true: + // 1. field type == 2 (length-delimited) + // 2. field number == 1 + // 3. the subsequent byte(s) would have to be varint-encoded length which does not overflow + // 4. the rest of the bytes would have to be valid protobuf encoding + // + // The chance of this happening is extremely negligible. + if body.is_empty() { + return Err(prost::DecodeError::new("EnvelopeBody has no parts")); + } + Ok(body) + }) + .map_err(|_| DecryptionError::MessageBodyDecryptionFailed) } async fn success_not_encrypted(next_service: S, message: DhtInboundMessage) -> Result<(), PipelineError> { + let authenticated_pk = if message.dht_header.origin_mac.is_empty() { + None + } else { + let origin_mac = OriginMac::decode(message.dht_header.origin_mac.as_slice()) + .map_err(|_| PipelineError::from_debug(DecryptionError::OriginMacClearTextDecodeFailed))?; + Some(Self::authenticate_origin_mac(&origin_mac, &message.body).map_err(PipelineError::from_debug)?) + }; + match EnvelopeBody::decode(message.body.as_slice()) { Ok(deserialized) => { debug!( target: LOG_TARGET, "Message is not encrypted. Passing onto next service" ); - let msg = DecryptedDhtMessage::succeeded(deserialized, message); + let msg = DecryptedDhtMessage::succeeded(deserialized, authenticated_pk, message); next_service.oneshot(msg).await }, Err(err) => { @@ -194,8 +265,9 @@ where S: Service // TODO: #banheuristic - the origin of this message sent this node a message we could not decrypt warn!( target: LOG_TARGET, - "Received message from peer '{}' that is destined for that peer. Discarding message", - message.dht_header.origin.as_ref().expect("already checked").public_key + "Received message from peer '{}' that is destined for this node that could not be decrypted. \ + Discarding message", + message.source_peer.node_id ); return Err( "Message rejected because this node could not decrypt a message that was addressed to it".into(), @@ -241,10 +313,13 @@ mod test { let node_identity = make_node_identity(); let mut service = DecryptionService::new(inner, Arc::clone(&node_identity)); - let plain_text_msg = wrap_in_envelope_body!(Vec::new()).unwrap(); - let secret_key = crypt::generate_ecdh_secret(node_identity.secret_key(), node_identity.public_key()); - let encrypted = crypt::encrypt(&secret_key, &plain_text_msg.to_encoded_bytes().unwrap()).unwrap(); - let inbound_msg = make_dht_inbound_message(&node_identity, encrypted, DhtMessageFlags::ENCRYPTED); + let plain_text_msg = wrap_in_envelope_body!(b"Secret plans".to_vec()); + let inbound_msg = make_dht_inbound_message( + &node_identity, + plain_text_msg.to_encoded_bytes(), + DhtMessageFlags::ENCRYPTED, + true, + ); block_on(service.call(inbound_msg)).unwrap(); let decrypted = result.lock().unwrap().take().unwrap(); @@ -262,14 +337,16 @@ mod test { let node_identity = make_node_identity(); let mut service = DecryptionService::new(inner, Arc::clone(&node_identity)); - let nonsense = "Cannot Decrypt this".as_bytes().to_vec(); - let inbound_msg = make_dht_inbound_message(&node_identity, nonsense.clone(), DhtMessageFlags::ENCRYPTED); + let some_secret = "Super secret message".as_bytes().to_vec(); + let some_other_node_identity = make_node_identity(); + let inbound_msg = + make_dht_inbound_message(&some_other_node_identity, some_secret, DhtMessageFlags::ENCRYPTED, true); - block_on(service.call(inbound_msg)).unwrap(); + block_on(service.call(inbound_msg.clone())).unwrap(); let decrypted = result.lock().unwrap().take().unwrap(); assert_eq!(decrypted.decryption_succeeded(), false); - assert_eq!(decrypted.decryption_result.unwrap_err(), nonsense); + assert_eq!(decrypted.decryption_result.unwrap_err(), inbound_msg.body); } #[test] @@ -283,7 +360,8 @@ mod test { let mut service = DecryptionService::new(inner, Arc::clone(&node_identity)); let nonsense = "Cannot Decrypt this".as_bytes().to_vec(); - let mut inbound_msg = make_dht_inbound_message(&node_identity, nonsense.clone(), DhtMessageFlags::ENCRYPTED); + let mut inbound_msg = + make_dht_inbound_message(&node_identity, nonsense.clone(), DhtMessageFlags::ENCRYPTED, true); inbound_msg.dht_header.destination = node_identity.public_key().clone().into(); let err = block_on(service.call(inbound_msg)).unwrap_err(); diff --git a/comms/dht/src/inbound/dedup.rs b/comms/dht/src/inbound/dedup.rs index 25a6253dc5..15f09646fc 100644 --- a/comms/dht/src/inbound/dedup.rs +++ b/comms/dht/src/inbound/dedup.rs @@ -26,7 +26,6 @@ use futures::{task::Context, Future}; use log::*; use std::task::Poll; use tari_comms::{pipeline::PipelineError, types::Challenge}; -use tari_crypto::tari_utilities::hex::Hex; use tower::{layer::Layer, Service, ServiceExt}; const LOG_TARGET: &str = "comms::dht::dedup"; @@ -85,13 +84,7 @@ where S: Service { warn!( target: LOG_TARGET, - "Received duplicate message from peer {} (origin={:?}). Message discarded.", - message.source_peer.node_id, - message - .dht_header - .origin - .map(|o| o.public_key.to_hex()) - .unwrap_or_else(|| "".to_string()), + "Received duplicate message from peer {}. Message discarded.", message.source_peer.node_id, ); return Ok(()); } @@ -148,7 +141,7 @@ mod test { assert!(dedup.poll_ready(&mut cx).is_ready()); let node_identity = make_node_identity(); - let msg = make_dht_inbound_message(&node_identity, Vec::new(), DhtMessageFlags::empty()); + let msg = make_dht_inbound_message(&node_identity, Vec::new(), DhtMessageFlags::empty(), false); rt.block_on(dedup.call(msg.clone())).unwrap(); assert_eq!(spy.call_count(), 1); diff --git a/comms/dht/src/inbound/deserialize.rs b/comms/dht/src/inbound/deserialize.rs index cf94c2d471..57ac7e39e2 100644 --- a/comms/dht/src/inbound/deserialize.rs +++ b/comms/dht/src/inbound/deserialize.rs @@ -58,51 +58,38 @@ where S: Service + Clon Poll::Ready(Ok(())) } - fn call(&mut self, msg: InboundMessage) -> Self::Future { - Self::deserialize(self.next_service.clone(), msg) - } -} - -impl DhtDeserializeMiddleware -where S: Service -{ - pub async fn deserialize(next_service: S, message: InboundMessage) -> Result<(), PipelineError> { - trace!(target: LOG_TARGET, "Deserializing InboundMessage"); - - let InboundMessage { - source_peer, mut body, .. - } = message; - - match DhtEnvelope::decode(&mut body) { - Ok(dht_envelope) => { - trace!(target: LOG_TARGET, "Deserialization succeeded. Checking signatures"); - if dht_envelope.has_origin() { - if dht_envelope.is_origin_signature_valid() { - trace!(target: LOG_TARGET, "Origin signature validation passed."); - } else { - // TODO: #banheuristic - // The origin signature is not valid, this message should never have been sent - warn!( - target: LOG_TARGET, - "SECURITY: Origin signature verification failed. Discarding message from NodeId {}", - source_peer.node_id - ); - return Ok(()); - } - } - - let inbound_msg = DhtInboundMessage::new( - dht_envelope.header.try_into().map_err(PipelineError::from_debug)?, - source_peer, - dht_envelope.body, - ); - - next_service.oneshot(inbound_msg).await - }, - Err(err) => { - error!(target: LOG_TARGET, "DHT deserialization failed: {}", err); - Err(PipelineError::from_debug(err)) - }, + fn call(&mut self, message: InboundMessage) -> Self::Future { + let next_service = self.next_service.clone(); + async move { + trace!(target: LOG_TARGET, "Deserializing InboundMessage"); + + let InboundMessage { + source_peer, mut body, .. + } = message; + + if body.is_empty() { + return Err(format!("Received empty message from peer '{}'", source_peer) + .as_str() + .into()); + } + + match DhtEnvelope::decode(&mut body) { + Ok(dht_envelope) => { + trace!(target: LOG_TARGET, "Deserialization succeeded."); + + let inbound_msg = DhtInboundMessage::new( + dht_envelope.header.try_into().map_err(PipelineError::from_debug)?, + source_peer, + dht_envelope.body, + ); + + next_service.oneshot(inbound_msg).await + }, + Err(err) => { + error!(target: LOG_TARGET, "DHT deserialization failed: {}", err); + Err(PipelineError::from_debug(err)) + }, + } } } } @@ -144,10 +131,10 @@ mod test { assert!(deserialize.poll_ready(&mut cx).is_ready()); let node_identity = make_node_identity(); - let dht_envelope = make_dht_envelope(&node_identity, b"A".to_vec(), DhtMessageFlags::empty()); + let dht_envelope = make_dht_envelope(&node_identity, b"A".to_vec(), DhtMessageFlags::empty(), false); block_on(deserialize.call(make_comms_inbound_message( &node_identity, - dht_envelope.to_encoded_bytes().unwrap().into(), + dht_envelope.to_encoded_bytes().into(), ))) .unwrap(); diff --git a/comms/dht/src/inbound/dht_handler/task.rs b/comms/dht/src/inbound/dht_handler/task.rs index 9a3dc372d1..43601edea2 100644 --- a/comms/dht/src/inbound/dht_handler/task.rs +++ b/comms/dht/src/inbound/dht_handler/task.rs @@ -175,20 +175,20 @@ where S: Service decryption_result, dht_header, source_peer, + authenticated_origin, .. } = message; - let origin = dht_header - .origin - .as_ref() - .ok_or_else(|| DhtInboundError::OriginRequired("Origin is required for this message type".to_string()))?; + let authenticated_pk = authenticated_origin.ok_or_else(|| { + DhtInboundError::OriginRequired("Authenticated origin is required for this message type".to_string()) + })?; - if &origin.public_key == self.node_identity.public_key() { + if &authenticated_pk == self.node_identity.public_key() { trace!(target: LOG_TARGET, "Received our own join message. Discarding it."); return Ok(()); } - trace!(target: LOG_TARGET, "Received Join Message from {}", origin.public_key); + trace!(target: LOG_TARGET, "Received Join Message from {}", authenticated_pk); let body = decryption_result.expect("already checked that this message decrypted successfully"); let join_msg = body @@ -205,11 +205,11 @@ where S: Service return Err(DhtInboundError::InvalidAddresses); } - let node_id = self.validate_raw_node_id(&origin.public_key, &join_msg.node_id)?; + let node_id = self.validate_raw_node_id(&authenticated_pk, &join_msg.node_id)?; let origin_peer = self .add_or_update_peer( - &origin.public_key, + &authenticated_pk, node_id, addresses, PeerFeatures::from_bits_truncate(join_msg.peer_features), @@ -262,12 +262,12 @@ where S: Service .closest( origin_peer.node_id, self.config.num_neighbouring_nodes, - vec![origin.public_key.clone(), source_peer.public_key.clone()], + vec![authenticated_pk, source_peer.public_key.clone()], PeerFeatures::MESSAGE_PROPAGATION, ) .with_dht_header(dht_header) .finish(), - body.to_encoded_bytes()?, + body.to_encoded_bytes(), ) .await?; @@ -300,10 +300,9 @@ where S: Service target: LOG_TARGET, "Received Discover Response Message from {}", message - .dht_header - .origin + .authenticated_origin .as_ref() - .map(|o| o.public_key.to_hex()) + .map(|pk| pk.to_hex()) .unwrap_or_else(|| "".to_string()) ); @@ -331,13 +330,13 @@ where S: Service .decode_part::(0)? .ok_or_else(|| DhtInboundError::InvalidMessageBody)?; - let origin = message.dht_header.origin.ok_or_else(|| { + let authenticated_pk = message.authenticated_origin.ok_or_else(|| { DhtInboundError::OriginRequired("Origin header required for Discovery message".to_string()) })?; info!( target: LOG_TARGET, - "Received discovery message from '{}'", origin.public_key, + "Received discovery message from '{}'", authenticated_pk ); let addresses = discover_msg @@ -350,10 +349,10 @@ where S: Service return Err(DhtInboundError::InvalidAddresses); } - let node_id = self.validate_raw_node_id(&origin.public_key, &discover_msg.node_id)?; + let node_id = self.validate_raw_node_id(&authenticated_pk, &discover_msg.node_id)?; let origin_peer = self .add_or_update_peer( - &origin.public_key, + &authenticated_pk, node_id, addresses, PeerFeatures::from_bits_truncate(discover_msg.peer_features), @@ -370,7 +369,7 @@ where S: Service } // Send the origin the current nodes latest contact info - self.send_discovery_response(origin.public_key, discover_msg.nonce) + self.send_discovery_response(origin_peer.public_key, discover_msg.nonce) .await?; Ok(()) diff --git a/comms/dht/src/inbound/message.rs b/comms/dht/src/inbound/message.rs index 6224bf340c..987c8cea76 100644 --- a/comms/dht/src/inbound/message.rs +++ b/comms/dht/src/inbound/message.rs @@ -68,17 +68,24 @@ pub struct DecryptedDhtMessage { /// The _connected_ peer which sent or forwarded this message. This may not be the peer /// which created this message. pub source_peer: Arc, + pub authenticated_origin: Option, pub dht_header: DhtMessageHeader, pub decryption_result: Result>, } impl DecryptedDhtMessage { - pub fn succeeded(decrypted_message: EnvelopeBody, message: DhtInboundMessage) -> Self { + pub fn succeeded( + message_body: EnvelopeBody, + authenticated_origin: Option, + message: DhtInboundMessage, + ) -> Self + { Self { version: message.version, source_peer: message.source_peer, + authenticated_origin, dht_header: message.dht_header, - decryption_result: Ok(decrypted_message), + decryption_result: Ok(message_body), } } @@ -86,6 +93,7 @@ impl DecryptedDhtMessage { Self { version: message.version, source_peer: message.source_peer, + authenticated_origin: None, dht_header: message.dht_header, decryption_result: Err(message.body), } @@ -115,12 +123,8 @@ impl DecryptedDhtMessage { self.decryption_result.is_err() } - pub fn origin_public_key(&self) -> &CommsPublicKey { - self.dht_header - .origin - .as_ref() - .map(|o| &o.public_key) - .unwrap_or(&self.source_peer.public_key) + pub fn authenticated_origin(&self) -> Option<&CommsPublicKey> { + self.authenticated_origin.as_ref() } /// Returns true if the message is or was encrypted by @@ -128,11 +132,11 @@ impl DecryptedDhtMessage { self.dht_header.flags.contains(DhtMessageFlags::ENCRYPTED) } - pub fn has_origin(&self) -> bool { - self.dht_header.origin.is_some() + pub fn has_origin_mac(&self) -> bool { + !self.dht_header.origin_mac.is_empty() } - pub fn body_size(&self) -> usize { + pub fn body_len(&self) -> usize { match self.decryption_result.as_ref() { Ok(b) => b.total_size(), Err(b) => b.len(), diff --git a/comms/dht/src/inbound/validate.rs b/comms/dht/src/inbound/validate.rs index c4fa1fce28..b6132af821 100644 --- a/comms/dht/src/inbound/validate.rs +++ b/comms/dht/src/inbound/validate.rs @@ -20,19 +20,11 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::{ - inbound::DhtInboundMessage, - outbound::{OutboundMessageRequester, SendMessageParams}, - proto::{ - dht::{RejectMessage, RejectMessageReason}, - envelope::{DhtMessageType, Network}, - }, -}; +use crate::{inbound::DhtInboundMessage, proto::envelope::Network}; use futures::{task::Context, Future}; use log::*; use std::task::Poll; -use tari_comms::{message::MessageExt, pipeline::PipelineError}; -use tari_crypto::tari_utilities::ByteArray; +use tari_comms::pipeline::PipelineError; use tower::{layer::Layer, Service, ServiceExt}; const LOG_TARGET: &str = "comms::dht::validate"; @@ -45,15 +37,13 @@ const LOG_TARGET: &str = "comms::dht::validate"; pub struct ValidateMiddleware { next_service: S, target_network: Network, - outbound_requester: OutboundMessageRequester, } impl ValidateMiddleware { - pub fn new(service: S, target_network: Network, outbound_requester: OutboundMessageRequester) -> Self { + pub fn new(service: S, target_network: Network) -> Self { Self { next_service: service, target_network, - outbound_requester, } } } @@ -70,76 +60,40 @@ where S: Service + Clon Poll::Ready(Ok(())) } - fn call(&mut self, msg: DhtInboundMessage) -> Self::Future { - Self::process_message( - self.next_service.clone(), - self.target_network, - self.outbound_requester.clone(), - msg, - ) - } -} - -impl ValidateMiddleware -where S: Service -{ - pub async fn process_message( - next_service: S, - target_network: Network, - mut outbound_requester: OutboundMessageRequester, - message: DhtInboundMessage, - ) -> Result<(), PipelineError> - { - trace!( - target: LOG_TARGET, - "Checking the message target network is '{:?}'", - target_network - ); - if message.dht_header.network == target_network { - next_service.oneshot(message).await?; - } else { - debug!( + fn call(&mut self, message: DhtInboundMessage) -> Self::Future { + let next_service = self.next_service.clone(); + let target_network = self.target_network; + async move { + trace!( target: LOG_TARGET, - "Message is for another network (want = {:?} got = {:?}). Explicitly rejecting the message.", - target_network, - message.dht_header.network + "Checking the message target network is '{:?}'", + target_network ); - outbound_requester - .send_raw( - SendMessageParams::new() - .direct_public_key(message.source_peer.public_key.clone()) - .with_dht_message_type(DhtMessageType::RejectMsg) - .finish(), - RejectMessage { - signature: message - .dht_header - .origin - .map(|o| o.public_key.to_vec()) - .unwrap_or_default(), - reason: RejectMessageReason::UnsupportedNetwork as i32, - } - .to_encoded_bytes() - .map_err(PipelineError::from_debug)?, - ) - .await - .map_err(PipelineError::from_debug)?; - } - Ok(()) + if message.dht_header.network == target_network && message.dht_header.is_valid() { + next_service.oneshot(message).await?; + } else { + debug!( + target: LOG_TARGET, + "Message is for another network (want = {:?} got = {:?}) or message header is invalid. Discarding \ + the message.", + target_network, + message.dht_header.network + ); + } + + Ok(()) + } } } pub struct ValidateLayer { target_network: Network, - outbound_requester: OutboundMessageRequester, } impl ValidateLayer { - pub fn new(target_network: Network, outbound_requester: OutboundMessageRequester) -> Self { - Self { - target_network, - outbound_requester, - } + pub fn new(target_network: Network) -> Self { + Self { target_network } } } @@ -147,7 +101,7 @@ impl Layer for ValidateLayer { type Service = ValidateMiddleware; fn layer(&self, service: S) -> Self::Service { - ValidateMiddleware::new(service, self.target_network, self.outbound_requester.clone()) + ValidateMiddleware::new(service, self.target_network) } } @@ -155,8 +109,7 @@ impl Layer for ValidateLayer { mod test { use super::*; use crate::{ - envelope::{DhtMessageFlags, DhtMessageType}, - outbound::mock::create_outbound_service_mock, + envelope::DhtMessageFlags, test_utils::{make_dht_inbound_message, make_node_identity, service_spy}, }; use tari_test_utils::panic_context; @@ -167,18 +120,13 @@ mod test { let mut rt = Runtime::new().unwrap(); let spy = service_spy(); - let (out_requester, mock) = create_outbound_service_mock(1); - let mock_state = mock.get_state(); - rt.spawn(mock.run()); - - let mut validate = - ValidateLayer::new(Network::LocalTest, out_requester).layer(spy.to_service::()); + let mut validate = ValidateLayer::new(Network::LocalTest).layer(spy.to_service::()); panic_context!(cx); assert!(validate.poll_ready(&mut cx).is_ready()); let node_identity = make_node_identity(); - let mut msg = make_dht_inbound_message(&node_identity, Vec::new(), DhtMessageFlags::empty()); + let mut msg = make_dht_inbound_message(&node_identity, Vec::new(), DhtMessageFlags::empty(), false); msg.dht_header.network = Network::MainNet; rt.block_on(validate.call(msg.clone())).unwrap(); @@ -188,17 +136,5 @@ mod test { rt.block_on(validate.call(msg.clone())).unwrap(); assert_eq!(spy.call_count(), 1); - - let calls = mock_state.take_calls(); - assert_eq!(calls.len(), 1); - let params = calls[0].0.clone(); - assert_eq!(params.dht_message_type, DhtMessageType::RejectMsg); - assert_eq!( - params.broadcast_strategy.direct_public_key().unwrap(), - node_identity.public_key() - ); - - // Drop validate so that the mock will stop running - drop(validate); } } diff --git a/comms/dht/src/outbound/broadcast.rs b/comms/dht/src/outbound/broadcast.rs index 33bc492f0b..c0b1672a07 100644 --- a/comms/dht/src/outbound/broadcast.rs +++ b/comms/dht/src/outbound/broadcast.rs @@ -25,7 +25,7 @@ use crate::{ actor::DhtRequester, broadcast_strategy::BroadcastStrategy, discovery::DhtDiscoveryRequester, - envelope::{DhtMessageFlags, DhtMessageHeader, DhtMessageOrigin, NodeDestination}, + envelope::{DhtMessageFlags, DhtMessageHeader, NodeDestination}, outbound::{ message::{DhtOutboundMessage, OutboundEncryption}, message_params::FinalSendMessageParams, @@ -33,6 +33,7 @@ use crate::{ }, proto::envelope::{DhtMessageType, Network}, }; +use bytes::Bytes; use futures::{ channel::oneshot, future, @@ -43,7 +44,8 @@ use futures::{ use log::*; use std::{sync::Arc, task::Poll}; use tari_comms::{ - peer_manager::{NodeId, NodeIdentity, Peer}, + message::MessageTag, + peer_manager::{NodeIdentity, Peer}, pipeline::PipelineError, types::CommsPublicKey, }; @@ -216,7 +218,7 @@ where S: Service async fn handle_send_message( &mut self, params: FinalSendMessageParams, - body: Vec, + body: Bytes, reply_tx: oneshot::Sender, ) -> Result, DhtOutboundError> { @@ -344,24 +346,8 @@ where S: Service dest_public_key ); - // TODO: This works because we know that all non-DAN node IDs are/should be derived from the public key. - // Once the DAN launches, this may not be the case and we'll need to query the blockchain for the node id - let derived_node_id = NodeId::from_key(&*dest_public_key).ok(); - - // TODO: Target a general region instead of the actual destination node id - let regional_destination = derived_node_id - .as_ref() - .map(Clone::clone) - .map(Box::new) - .map(NodeDestination::NodeId) - .unwrap_or_else(|| NodeDestination::Unknown); - // Peer not found, let's try and discover it - match self - .dht_discovery_requester - .discover_peer(dest_public_key, derived_node_id, regional_destination) - .await - { + match self.dht_discovery_requester.discover_peer(dest_public_key).await { // Peer found! Ok(peer) => { if peer.is_banned() { @@ -397,42 +383,28 @@ where S: Service custom_header: Option, extra_flags: DhtMessageFlags, force_origin: bool, - body: Vec, + body: Bytes, ) -> Result, DhtOutboundError> { let dht_flags = encryption.flags() | extra_flags; - // Create a DHT header - let dht_header = custom_header - .or_else(|| { - // The origin is specified if encryption is turned on, otherwise it is not - let origin = if force_origin || encryption.is_encrypt() { - Some(DhtMessageOrigin { - // Origin public key used to identify the origin and verify the signature - public_key: self.node_identity.public_key().clone(), - // Signing will happen later in the pipeline (SerializeMiddleware), left empty to prevent double - // work - signature: Vec::new(), - }) - } else { - None - }; - - Some(DhtMessageHeader::new( - // Final destination for this message - destination, - dht_message_type, - origin, - self.target_network, - dht_flags, - )) - }) - .expect("always Some"); - - // Construct a MessageEnvelope for each recipient + // Construct a DhtOutboundMessage for each recipient let messages = selected_peers .into_iter() - .map(|peer| DhtOutboundMessage::new(peer, dht_header.clone(), encryption.clone(), body.clone())) + .map(|peer| DhtOutboundMessage { + tag: MessageTag::new(), + destination_peer: peer, + destination: destination.clone(), + dht_message_type, + network: self.target_network, + dht_flags, + custom_header: custom_header.clone(), + include_origin: force_origin || encryption.is_encrypt(), + encryption: encryption.clone(), + body: body.clone(), + ephemeral_public_key: None, + origin_mac: None, + }) .collect::>(); Ok(messages) @@ -518,7 +490,7 @@ mod test { rt.block_on(service.call(DhtOutboundRequest::SendMessage( Box::new(SendMessageParams::new().flood().finish()), - "custom_msg".as_bytes().to_vec(), + "custom_msg".as_bytes().into(), reply_tx, ))) .unwrap(); @@ -568,7 +540,7 @@ mod test { .with_discovery(false) .finish(), ), - "custom_msg".as_bytes().to_vec(), + Bytes::from_static(b"custom_msg"), reply_tx, )), ) @@ -619,7 +591,7 @@ mod test { .direct_public_key(peer_to_discover.public_key.clone()) .finish(), ), - "custom_msg".as_bytes().to_vec(), + "custom_msg".as_bytes().into(), reply_tx, )), ) diff --git a/comms/dht/src/outbound/encryption.rs b/comms/dht/src/outbound/encryption.rs index 659b6ca067..9b3f1be881 100644 --- a/comms/dht/src/outbound/encryption.rs +++ b/comms/dht/src/outbound/encryption.rs @@ -23,11 +23,23 @@ use crate::{ crypt, outbound::message::{DhtOutboundMessage, OutboundEncryption}, + proto::envelope::OriginMac, }; use futures::{task::Context, Future}; use log::*; +use rand::rngs::OsRng; use std::{sync::Arc, task::Poll}; -use tari_comms::{peer_manager::NodeIdentity, pipeline::PipelineError}; +use tari_comms::{ + message::MessageExt, + peer_manager::NodeIdentity, + pipeline::PipelineError, + types::CommsPublicKey, + utils::signature, +}; +use tari_crypto::{ + keys::PublicKey, + tari_utilities::{message_format::MessageFormat, ByteArray}, +}; use tower::{layer::Layer, Service, ServiceExt}; const LOG_TARGET: &str = "comms::middleware::encryption"; @@ -79,34 +91,55 @@ where S: Service + Clo Poll::Ready(Ok(())) } - fn call(&mut self, msg: DhtOutboundMessage) -> Self::Future { - Self::handle_message(self.inner.clone(), Arc::clone(&self.node_identity), msg) + fn call(&mut self, mut message: DhtOutboundMessage) -> Self::Future { + let next_service = self.inner.clone(); + let node_identity = Arc::clone(&self.node_identity); + async move { + trace!(target: LOG_TARGET, "DHT Message flags: {:?}", message.dht_flags); + match &message.encryption { + OutboundEncryption::EncryptFor(public_key) => { + debug!(target: LOG_TARGET, "Encrypting message for {}", public_key); + // Generate ephemeral public/private key pair and ECDH shared secret + let (e_sk, e_pk) = CommsPublicKey::random_keypair(&mut OsRng); + let shared_ephemeral_secret = crypt::generate_ecdh_secret(&e_sk, &**public_key); + // Encrypt the message with the body + let encrypted_body = + crypt::encrypt(&shared_ephemeral_secret, &message.body).map_err(PipelineError::from_debug)?; + + // Sign the encrypted message + let origin_mac = create_origin_mac(&node_identity, &encrypted_body)?; + // Encrypt and set the origin field + let encrypted_origin_mac = + crypt::encrypt(&shared_ephemeral_secret, &origin_mac).map_err(PipelineError::from_debug)?; + message + .with_origin_mac(encrypted_origin_mac) + .with_ephemeral_public_key(e_pk) + .set_body(encrypted_body.into()); + }, + OutboundEncryption::None => { + debug!(target: LOG_TARGET, "Encryption not requested for message"); + + if message.include_origin && message.custom_header.is_none() { + let origin_mac = create_origin_mac(&node_identity, &message.body)?; + message.with_origin_mac(origin_mac); + } + }, + }; + + next_service.oneshot(message).await + } } } -impl EncryptionService -where S: Service -{ - async fn handle_message( - next_service: S, - node_identity: Arc, - mut message: DhtOutboundMessage, - ) -> Result<(), PipelineError> - { - trace!(target: LOG_TARGET, "DHT Message flags: {:?}", message.dht_header.flags); - match &message.encryption { - OutboundEncryption::EncryptFor(public_key) => { - debug!(target: LOG_TARGET, "Encrypting message for {}", public_key); - let shared_secret = crypt::generate_ecdh_secret(node_identity.secret_key(), &**public_key); - message.body = crypt::encrypt(&shared_secret, &message.body).map_err(PipelineError::from_debug)?; - }, - OutboundEncryption::None => { - debug!(target: LOG_TARGET, "Encryption not requested for message"); - }, - }; - - next_service.oneshot(message).await - } +fn create_origin_mac(node_identity: &NodeIdentity, body: &[u8]) -> Result, PipelineError> { + let signature = + signature::sign(&mut OsRng, node_identity.secret_key().clone(), body).map_err(PipelineError::from_debug)?; + + let mac = OriginMac { + public_key: node_identity.public_key().to_vec(), + signature: signature.to_binary().map_err(PipelineError::from_debug)?, + }; + Ok(mac.to_encoded_bytes()) } #[cfg(test)] @@ -114,14 +147,10 @@ mod test { use super::*; use crate::{ envelope::DhtMessageFlags, - test_utils::{make_dht_header, make_node_identity, service_spy}, + test_utils::{create_outbound_message, make_node_identity, service_spy}, }; use futures::executor::block_on; - use tari_comms::{ - net_address::MultiaddressesWithStats, - peer_manager::{NodeId, Peer, PeerFeatures, PeerFlags}, - types::CommsPublicKey, - }; + use tari_comms::{peer_manager::NodeId, types::CommsPublicKey}; use tari_test_utils::panic_context; #[test] @@ -133,25 +162,14 @@ mod test { panic_context!(cx); assert!(encryption.poll_ready(&mut cx).is_ready()); - let body = b"A".to_vec(); - let msg = DhtOutboundMessage::new( - Peer::new( - CommsPublicKey::default(), - NodeId::default(), - MultiaddressesWithStats::new(vec![]), - PeerFlags::empty(), - PeerFeatures::COMMUNICATION_NODE, - &[], - ), - make_dht_header(&node_identity, &body, DhtMessageFlags::empty()), - OutboundEncryption::None, - body.clone(), - ); + let body = b"A"; + let msg = create_outbound_message(body); block_on(encryption.call(msg)).unwrap(); let msg = spy.pop_request().unwrap(); - assert_eq!(msg.body, body); + assert_eq!(msg.body.to_vec(), body); assert_eq!(msg.destination_peer.node_id, NodeId::default()); + assert!(msg.ephemeral_public_key.is_none()) } #[test] @@ -163,24 +181,15 @@ mod test { panic_context!(cx); assert!(encryption.poll_ready(&mut cx).is_ready()); - let body = b"A".to_vec(); - let msg = DhtOutboundMessage::new( - Peer::new( - CommsPublicKey::default(), - NodeId::default(), - MultiaddressesWithStats::new(vec![]), - PeerFlags::empty(), - PeerFeatures::COMMUNICATION_NODE, - &[], - ), - make_dht_header(&node_identity, &body, DhtMessageFlags::ENCRYPTED), - OutboundEncryption::EncryptFor(Box::new(CommsPublicKey::default())), - body.clone(), - ); + let body = b"A"; + let mut msg = create_outbound_message(body); + msg.dht_flags = DhtMessageFlags::ENCRYPTED; + msg.encryption = OutboundEncryption::EncryptFor(Box::new(CommsPublicKey::default())); block_on(encryption.call(msg)).unwrap(); let msg = spy.pop_request().unwrap(); - assert_ne!(msg.body, body); + assert_ne!(msg.body.to_vec(), body); assert_eq!(msg.destination_peer.node_id, NodeId::default()); + assert!(msg.ephemeral_public_key.is_some()) } } diff --git a/comms/dht/src/outbound/message.rs b/comms/dht/src/outbound/message.rs index a3f4229793..b266d1c573 100644 --- a/comms/dht/src/outbound/message.rs +++ b/comms/dht/src/outbound/message.rs @@ -21,9 +21,10 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::{ - envelope::{DhtMessageFlags, DhtMessageHeader}, + envelope::{DhtMessageFlags, DhtMessageHeader, DhtMessageType, Network, NodeDestination}, outbound::message_params::FinalSendMessageParams, }; +use bytes::Bytes; use futures::channel::oneshot; use std::{fmt, fmt::Display}; use tari_comms::{message::MessageTag, peer_manager::Peer, types::CommsPublicKey}; @@ -114,11 +115,7 @@ impl SendMessageResponse { #[derive(Debug)] pub enum DhtOutboundRequest { /// Send a message using the given broadcast strategy - SendMessage( - Box, - Vec, - oneshot::Sender, - ), + SendMessage(Box, Bytes, oneshot::Sender), } impl fmt::Display for DhtOutboundRequest { @@ -137,41 +134,54 @@ impl fmt::Display for DhtOutboundRequest { pub struct DhtOutboundMessage { pub tag: MessageTag, pub destination_peer: Peer, - pub dht_header: DhtMessageHeader, + pub custom_header: Option, pub encryption: OutboundEncryption, - pub body: Vec, + pub body: Bytes, + pub ephemeral_public_key: Option, + pub origin_mac: Option>, + pub include_origin: bool, + pub destination: NodeDestination, + pub dht_message_type: DhtMessageType, + pub network: Network, + pub dht_flags: DhtMessageFlags, } impl DhtOutboundMessage { - /// Create a new DhtOutboundMessage - pub fn new( - destination_peer: Peer, - dht_header: DhtMessageHeader, - encryption: OutboundEncryption, - body: Vec, - ) -> Self - { - Self { - tag: MessageTag::new(), - destination_peer, - dht_header, - encryption, - body, - } + pub fn with_ephemeral_public_key(&mut self, ephemeral_public_key: CommsPublicKey) -> &mut Self { + self.ephemeral_public_key = Some(ephemeral_public_key); + self + } + + pub fn with_origin_mac(&mut self, origin_mac: Vec) -> &mut Self { + self.origin_mac = Some(origin_mac); + self + } + + pub fn set_body(&mut self, body: Bytes) -> &mut Self { + self.body = body; + self } } impl fmt::Display for DhtOutboundMessage { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + let header_str = self + .custom_header + .as_ref() + .and_then(|h| Some(format!("{} (Propagated)", h))) + .unwrap_or_else(|| { + format!( + "Network: {:?}, Flags: {:?}, Destination: {}", + self.network, self.dht_flags, self.destination + ) + }); write!( f, - "\n---- DhtOutboundMessage ---- \nSize: {} byte(s)\nType: {}\nPeer: {}\nHeader: {} \nFlags: \ - {:?}\nEncryption: {}\n{}\n----", + "\n---- Outgoing message ---- \nSize: {} byte(s)\nType: {}\nPeer: {}\nHeader: {}\nEncryption: {}\n{}\n----", self.body.len(), - self.dht_header.message_type, + self.dht_message_type, self.destination_peer, - self.dht_header, - self.dht_header.flags, + header_str, self.encryption, self.tag ) diff --git a/comms/dht/src/outbound/mock.rs b/comms/dht/src/outbound/mock.rs index 89cc7c8e65..1ca43c822c 100644 --- a/comms/dht/src/outbound/mock.rs +++ b/comms/dht/src/outbound/mock.rs @@ -26,6 +26,7 @@ use crate::outbound::{ DhtOutboundRequest, OutboundMessageRequester, }; +use bytes::Bytes; use futures::{channel::mpsc, stream::Fuse, StreamExt}; use std::{ sync::{Arc, Condvar, Mutex, RwLock}, @@ -44,7 +45,7 @@ pub fn create_outbound_service_mock(size: usize) -> (OutboundMessageRequester, O #[derive(Clone, Default)] pub struct OutboundServiceMockState { #[allow(clippy::type_complexity)] - calls: Arc)>>>, + calls: Arc>>, next_response: Arc>>, call_count_cond_var: Arc, } @@ -88,7 +89,7 @@ impl OutboundServiceMockState { /// Wait for a call to be added or timeout. /// /// An error will be returned if the timeout expires. - pub fn wait_pop_call(&self, timeout: Duration) -> Result<(FinalSendMessageParams, Vec), String> { + pub fn wait_pop_call(&self, timeout: Duration) -> Result<(FinalSendMessageParams, Bytes), String> { let call_guard = acquire_lock!(self.calls); let (mut call_guard, timeout) = self .call_count_cond_var @@ -106,16 +107,16 @@ impl OutboundServiceMockState { acquire_write_lock!(self.next_response).take() } - pub fn add_call(&self, req: (FinalSendMessageParams, Vec)) { + pub fn add_call(&self, req: (FinalSendMessageParams, Bytes)) { acquire_lock!(self.calls).push(req); self.call_count_cond_var.notify_all(); } - pub fn take_calls(&self) -> Vec<(FinalSendMessageParams, Vec)> { + pub fn take_calls(&self) -> Vec<(FinalSendMessageParams, Bytes)> { acquire_lock!(self.calls).drain(..).collect() } - pub fn pop_call(&self) -> Option<(FinalSendMessageParams, Vec)> { + pub fn pop_call(&self) -> Option<(FinalSendMessageParams, Bytes)> { acquire_lock!(self.calls).pop() } } diff --git a/comms/dht/src/outbound/mod.rs b/comms/dht/src/outbound/mod.rs index 354910eb81..3a5d58fd01 100644 --- a/comms/dht/src/outbound/mod.rs +++ b/comms/dht/src/outbound/mod.rs @@ -21,22 +21,25 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. mod broadcast; +pub use broadcast::BroadcastLayer; + mod encryption; +pub use encryption::EncryptionLayer; + mod error; +pub use error::DhtOutboundError; + pub(crate) mod message; +pub use message::{DhtOutboundRequest, OutboundEncryption, SendMessageResponse}; + mod message_params; +pub use message_params::SendMessageParams; + mod requester; +pub use requester::OutboundMessageRequester; + mod serialize; +pub use serialize::SerializeLayer; #[cfg(any(test, feature = "test-mocks"))] pub mod mock; - -pub use self::{ - broadcast::BroadcastLayer, - encryption::EncryptionLayer, - error::DhtOutboundError, - message::{DhtOutboundRequest, OutboundEncryption, SendMessageResponse}, - message_params::SendMessageParams, - requester::OutboundMessageRequester, - serialize::SerializeLayer, -}; diff --git a/comms/dht/src/outbound/requester.rs b/comms/dht/src/outbound/requester.rs index 9eb6eb0025..64e0793a01 100644 --- a/comms/dht/src/outbound/requester.rs +++ b/comms/dht/src/outbound/requester.rs @@ -195,7 +195,7 @@ impl OutboundMessageRequester { message ); } - let body = wrap_in_envelope_body!(message.to_header(), message.into_inner())?.to_encoded_bytes()?; + let body = wrap_in_envelope_body!(message.to_header(), message.into_inner()).to_encoded_bytes(); self.send_raw(params, body).await } @@ -211,7 +211,7 @@ impl OutboundMessageRequester { if cfg!(debug_assertions) { trace!(target: LOG_TARGET, "Send Message: {} {:?}", params, message); } - let body = wrap_in_envelope_body!(message)?.to_encoded_bytes()?; + let body = wrap_in_envelope_body!(message).to_encoded_bytes(); self.send_raw(params, body).await } @@ -224,7 +224,7 @@ impl OutboundMessageRequester { { let (reply_tx, reply_rx) = oneshot::channel(); self.sender - .send(DhtOutboundRequest::SendMessage(Box::new(params), body, reply_tx)) + .send(DhtOutboundRequest::SendMessage(Box::new(params), body.into(), reply_tx)) .await?; reply_rx diff --git a/comms/dht/src/outbound/serialize.rs b/comms/dht/src/outbound/serialize.rs index 1815756046..ce761d9464 100644 --- a/comms/dht/src/outbound/serialize.rs +++ b/comms/dht/src/outbound/serialize.rs @@ -20,19 +20,20 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::{outbound::message::DhtOutboundMessage, proto::envelope::DhtEnvelope}; +use crate::{ + consts::DHT_ENVELOPE_HEADER_VERSION, + outbound::message::DhtOutboundMessage, + proto::envelope::{DhtEnvelope, DhtHeader}, +}; use futures::{task::Context, Future}; use log::*; -use rand::rngs::OsRng; -use std::{sync::Arc, task::Poll}; +use std::task::Poll; use tari_comms::{ message::{MessageExt, OutboundMessage}, - peer_manager::NodeIdentity, pipeline::PipelineError, - utils::signature, Bytes, }; -use tari_crypto::tari_utilities::message_format::MessageFormat; +use tari_crypto::tari_utilities::ByteArray; use tower::{layer::Layer, Service, ServiceExt}; const LOG_TARGET: &str = "comms::dht::serialize"; @@ -40,15 +41,11 @@ const LOG_TARGET: &str = "comms::dht::serialize"; #[derive(Clone)] pub struct SerializeMiddleware { inner: S, - node_identity: Arc, } impl SerializeMiddleware { - pub fn new(service: S, node_identity: Arc) -> Self { - Self { - inner: service, - node_identity, - } + pub fn new(service: S) -> Self { + Self { inner: service } } } @@ -64,68 +61,51 @@ where S: Service + Clone Poll::Ready(Ok(())) } - fn call(&mut self, msg: DhtOutboundMessage) -> Self::Future { - Self::serialize(self.inner.clone(), Arc::clone(&self.node_identity), msg) - } -} - -impl SerializeMiddleware -where S: Service -{ - pub async fn serialize( - next_service: S, - node_identity: Arc, - message: DhtOutboundMessage, - ) -> Result<(), PipelineError> - { - debug!(target: LOG_TARGET, "Serializing outbound message {:?}", message.tag); - - let DhtOutboundMessage { - mut dht_header, - body, - destination_peer, - .. - } = message; - - // The message is being forwarded if the origin public_key is specified and it is not this node - let is_forwarded = dht_header - .origin - .as_ref() - .map(|o| &o.public_key != node_identity.public_key()) - .unwrap_or(false); - - // If forwarding the message, the DhtHeader already has a signature that should not change - if is_forwarded { - debug!( - target: LOG_TARGET, - "Message ({}) is being forwarded so this node will NOT signed it", message.tag - ); - } else { - // Sign the body if the origin public key was previously specified. - if let Some(origin) = dht_header.origin.as_mut() { - let signature = signature::sign(&mut OsRng, node_identity.secret_key().clone(), &body) - .map_err(PipelineError::from_debug)?; - origin.signature = signature.to_binary().map_err(PipelineError::from_debug)?; - } + fn call(&mut self, message: DhtOutboundMessage) -> Self::Future { + let next_service = self.inner.clone(); + async move { + debug!(target: LOG_TARGET, "Serializing outbound message {:?}", message.tag); + + let DhtOutboundMessage { + tag, + destination_peer, + custom_header, + body, + ephemeral_public_key, + destination, + dht_message_type, + network, + dht_flags, + origin_mac, + .. + } = message; + + let dht_header = custom_header.map(DhtHeader::from).unwrap_or_else(|| DhtHeader { + version: DHT_ENVELOPE_HEADER_VERSION, + origin_mac: origin_mac.unwrap_or_else(Vec::new), + ephemeral_public_key: ephemeral_public_key.map(|e| e.to_vec()).unwrap_or_else(Vec::new), + message_type: dht_message_type as i32, + network: network as i32, + flags: dht_flags.bits(), + destination: Some(destination.into()), + }); + + let envelope = DhtEnvelope::new(dht_header, body); + + let body = Bytes::from(envelope.to_encoded_bytes()); + + next_service + .oneshot(OutboundMessage::with_tag(tag, destination_peer.node_id, body)) + .await } - - let envelope = DhtEnvelope::new(dht_header.into(), body); - - let body = Bytes::from(envelope.to_encoded_bytes().map_err(PipelineError::from_debug)?); - - next_service - .oneshot(OutboundMessage::with_tag(message.tag, destination_peer.node_id, body)) - .await } } -pub struct SerializeLayer { - node_identity: Arc, -} +pub struct SerializeLayer; impl SerializeLayer { - pub fn new(node_identity: Arc) -> Self { - Self { node_identity } + pub fn new() -> Self { + Self } } @@ -133,50 +113,29 @@ impl Layer for SerializeLayer { type Service = SerializeMiddleware; fn layer(&self, service: S) -> Self::Service { - SerializeMiddleware::new(service, Arc::clone(&self.node_identity)) + SerializeMiddleware::new(service) } } #[cfg(test)] mod test { use super::*; - use crate::{ - envelope::DhtMessageFlags, - outbound::OutboundEncryption, - test_utils::{make_dht_header, make_node_identity, service_spy}, - }; + use crate::test_utils::{create_outbound_message, service_spy}; use futures::executor::block_on; use prost::Message; - use tari_comms::{ - net_address::MultiaddressesWithStats, - peer_manager::{NodeId, Peer, PeerFeatures, PeerFlags}, - types::CommsPublicKey, - }; + use tari_comms::peer_manager::NodeId; use tari_test_utils::panic_context; #[test] fn serialize() { let spy = service_spy(); - let node_identity = make_node_identity(); - let mut serialize = SerializeLayer::new(Arc::clone(&node_identity)).layer(spy.to_service::()); + let mut serialize = SerializeLayer.layer(spy.to_service::()); panic_context!(cx); assert!(serialize.poll_ready(&mut cx).is_ready()); - let body = b"A".to_vec(); - let msg = DhtOutboundMessage::new( - Peer::new( - CommsPublicKey::default(), - NodeId::default(), - MultiaddressesWithStats::new(vec![]), - PeerFlags::empty(), - PeerFeatures::COMMUNICATION_NODE, - &[], - ), - make_dht_header(&node_identity, &body, DhtMessageFlags::empty()), - OutboundEncryption::None, - body, - ); + let body = b"A"; + let msg = create_outbound_message(body); block_on(serialize.call(msg)).unwrap(); let mut msg = spy.pop_request().unwrap(); diff --git a/comms/dht/src/proto/envelope.proto b/comms/dht/src/proto/envelope.proto index 43f1815cee..5a3071dfa3 100644 --- a/comms/dht/src/proto/envelope.proto +++ b/comms/dht/src/proto/envelope.proto @@ -32,14 +32,17 @@ message DhtHeader { } // Origin public key of the message. This can be the same peer that sent the message - // or another peer if the message should be forwarded. This is optional but must be specified + // or another peer if the message should be forwarded. This is optional but MUST be specified // if the ENCRYPTED flag is set. - DhtOrigin origin = 5; + // If an ephemeral_public_key is specified, this MUST be encrypted using a derived ECDH shared key + bytes origin_mac = 5; + // Ephemeral public key component of the ECDH shared key. MUST be specified if the ENCRYPTED flag is set. + bytes ephemeral_public_key = 6; // The type of message - DhtMessageType message_type = 6; + DhtMessageType message_type = 7; // The network for which this message is intended (e.g. TestNet, MainNet etc.) - Network network = 7; - uint32 flags = 8; + Network network = 8; + uint32 flags = 9; } enum Network { @@ -56,7 +59,8 @@ message DhtEnvelope { bytes body = 2; } -message DhtOrigin { +// The Message Authentication Code (MAC) message format of the decrypted `DhtHeader::origin_mac` field +message OriginMac { bytes public_key = 1; bytes signature = 2; } \ No newline at end of file diff --git a/comms/dht/src/proto/tari.dht.envelope.rs b/comms/dht/src/proto/tari.dht.envelope.rs index f7e24f31b2..734bee953f 100644 --- a/comms/dht/src/proto/tari.dht.envelope.rs +++ b/comms/dht/src/proto/tari.dht.envelope.rs @@ -3,17 +3,21 @@ pub struct DhtHeader { #[prost(uint32, tag = "1")] pub version: u32, /// Origin public key of the message. This can be the same peer that sent the message - /// or another peer if the message should be forwarded. This is optional but must be specified + /// or another peer if the message should be forwarded. This is optional but MUST be specified /// if the ENCRYPTED flag is set. - #[prost(message, optional, tag = "5")] - pub origin: ::std::option::Option, + /// If an ephemeral_public_key is specified, this MUST be encrypted using a derived ECDH shared key + #[prost(bytes, tag = "5")] + pub origin_mac: std::vec::Vec, + /// Ephemeral public key component of the ECDH shared key. MUST be specified if the ENCRYPTED flag is set. + #[prost(bytes, tag = "6")] + pub ephemeral_public_key: std::vec::Vec, /// The type of message - #[prost(enumeration = "DhtMessageType", tag = "6")] + #[prost(enumeration = "DhtMessageType", tag = "7")] pub message_type: i32, /// The network for which this message is intended (e.g. TestNet, MainNet etc.) - #[prost(enumeration = "Network", tag = "7")] + #[prost(enumeration = "Network", tag = "8")] pub network: i32, - #[prost(uint32, tag = "8")] + #[prost(uint32, tag = "9")] pub flags: u32, #[prost(oneof = "dht_header::Destination", tags = "2, 3, 4")] pub destination: ::std::option::Option, @@ -40,8 +44,9 @@ pub struct DhtEnvelope { #[prost(bytes, tag = "2")] pub body: std::vec::Vec, } +/// The Message Authentication Code (MAC) message format of the decrypted `DhtHeader::origin_mac` field #[derive(Clone, PartialEq, ::prost::Message)] -pub struct DhtOrigin { +pub struct OriginMac { #[prost(bytes, tag = "1")] pub public_key: std::vec::Vec, #[prost(bytes, tag = "2")] diff --git a/comms/dht/src/schema.rs b/comms/dht/src/schema.rs index 2ebbfb28b0..9f2f9910b1 100644 --- a/comms/dht/src/schema.rs +++ b/comms/dht/src/schema.rs @@ -10,8 +10,7 @@ table! { stored_messages (id) { id -> Integer, version -> Integer, - origin_pubkey -> Text, - origin_signature -> Text, + origin_pubkey -> Nullable, message_type -> Integer, destination_pubkey -> Nullable, destination_node_id -> Nullable, diff --git a/comms/dht/src/store_forward/database/stored_message.rs b/comms/dht/src/store_forward/database/stored_message.rs index d3d9826674..0aa8ce3f92 100644 --- a/comms/dht/src/store_forward/database/stored_message.rs +++ b/comms/dht/src/store_forward/database/stored_message.rs @@ -21,7 +21,7 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::{ - envelope::DhtMessageHeader, + inbound::DecryptedDhtMessage, proto::envelope::DhtHeader, schema::stored_messages, store_forward::message::StoredMessagePriority, @@ -35,8 +35,7 @@ use tari_crypto::tari_utilities::hex::Hex; #[table_name = "stored_messages"] pub struct NewStoredMessage { pub version: i32, - pub origin_pubkey: String, - pub origin_signature: String, + pub origin_pubkey: Option, pub message_type: i32, pub destination_pubkey: Option, pub destination_node_id: Option, @@ -47,27 +46,33 @@ pub struct NewStoredMessage { } impl NewStoredMessage { - pub fn try_construct( - version: u32, - dht_header: DhtMessageHeader, - priority: StoredMessagePriority, - body: Vec, - ) -> Option - { + pub fn try_construct(message: DecryptedDhtMessage, priority: StoredMessagePriority) -> Option { + let DecryptedDhtMessage { + version, + authenticated_origin, + decryption_result, + dht_header, + .. + } = message; + + let body = match decryption_result { + Ok(envelope_body) => envelope_body.to_encoded_bytes(), + Err(encrypted_body) => encrypted_body, + }; + Some(Self { version: version.try_into().ok()?, - origin_pubkey: dht_header.origin.as_ref().map(|o| o.public_key.to_hex())?, - origin_signature: dht_header.origin.as_ref().map(|o| o.signature.to_hex())?, + origin_pubkey: authenticated_origin.as_ref().map(|pk| pk.to_hex()), message_type: dht_header.message_type as i32, destination_pubkey: dht_header.destination.public_key().map(|pk| pk.to_hex()), destination_node_id: dht_header.destination.node_id().map(|node_id| node_id.to_hex()), - body, is_encrypted: dht_header.flags.is_encrypted(), priority: priority as i32, header: { let dht_header: DhtHeader = dht_header.into(); - dht_header.to_encoded_bytes().ok()? + dht_header.to_encoded_bytes() }, + body, }) } } @@ -76,8 +81,7 @@ impl NewStoredMessage { pub struct StoredMessage { pub id: i32, pub version: i32, - pub origin_pubkey: String, - pub origin_signature: String, + pub origin_pubkey: Option, pub message_type: i32, pub destination_pubkey: Option, pub destination_node_id: Option, diff --git a/comms/dht/src/store_forward/error.rs b/comms/dht/src/store_forward/error.rs index a5573b11f6..e2c19ef830 100644 --- a/comms/dht/src/store_forward/error.rs +++ b/comms/dht/src/store_forward/error.rs @@ -36,9 +36,11 @@ pub enum StoreAndForwardError { /// Received stored message has an invalid destination InvalidDestination, /// Received stored message has an invalid origin signature - InvalidSignature, + InvalidOriginMac, /// Invalid envelope body InvalidEnvelopeBody, + /// DHT header is invalid + InvalidDhtHeader, /// Received stored message which is not encrypted StoredMessageNotEncrypted, /// Unable to decrypt received stored message diff --git a/comms/dht/src/store_forward/forward.rs b/comms/dht/src/store_forward/forward.rs index beeee69754..e6983d7ed5 100644 --- a/comms/dht/src/store_forward/forward.rs +++ b/comms/dht/src/store_forward/forward.rs @@ -147,6 +147,7 @@ where S: Service source_peer, decryption_result, dht_header, + authenticated_origin, .. } = message; @@ -171,8 +172,8 @@ where S: Service .expect("previous check that decryption failed"); let mut excluded_peers = vec![source_peer.public_key.clone()]; - if let Some(origin) = dht_header.origin.as_ref() { - excluded_peers.push(origin.public_key.clone()); + if let Some(pk) = authenticated_origin.as_ref() { + excluded_peers.push(pk.clone()); } let mut message_params = self.get_send_params(&dht_header, excluded_peers).await?; @@ -271,8 +272,13 @@ mod test { let oms = OutboundMessageRequester::new(oms_tx); let mut service = ForwardLayer::new(peer_manager, oms).layer(spy.to_service::()); - let inbound_msg = make_dht_inbound_message(&make_node_identity(), b"".to_vec(), DhtMessageFlags::empty()); - let msg = DecryptedDhtMessage::succeeded(wrap_in_envelope_body!(Vec::new()).unwrap(), inbound_msg); + let node_identity = make_node_identity(); + let inbound_msg = make_dht_inbound_message(&node_identity, b"".to_vec(), DhtMessageFlags::empty(), false); + let msg = DecryptedDhtMessage::succeeded( + wrap_in_envelope_body!(Vec::new()), + Some(node_identity.public_key().clone()), + inbound_msg, + ); block_on(service.call(msg)).unwrap(); assert!(spy.is_called()); assert!(oms_rx.try_next().is_err()); @@ -289,7 +295,8 @@ mod test { let mut service = ForwardLayer::new(peer_manager, oms_requester).layer(spy.to_service::()); - let inbound_msg = make_dht_inbound_message(&make_node_identity(), b"".to_vec(), DhtMessageFlags::empty()); + let inbound_msg = + make_dht_inbound_message(&make_node_identity(), b"".to_vec(), DhtMessageFlags::empty(), false); let msg = DecryptedDhtMessage::failed(inbound_msg); rt.block_on(service.call(msg)).unwrap(); assert!(spy.is_called()); diff --git a/comms/dht/src/store_forward/message.rs b/comms/dht/src/store_forward/message.rs index 5cdcc01c3b..41dc126022 100644 --- a/comms/dht/src/store_forward/message.rs +++ b/comms/dht/src/store_forward/message.rs @@ -68,11 +68,11 @@ impl StoredMessagesRequest { #[cfg(test)] impl StoredMessage { - pub fn new(version: u32, dht_header: crate::envelope::DhtMessageHeader, encrypted_body: Vec) -> Self { + pub fn new(version: u32, dht_header: crate::envelope::DhtMessageHeader, body: Vec) -> Self { Self { version, dht_header: Some(dht_header.into()), - body: encrypted_body, + body, stored_at: Some(datetime_to_timestamp(Utc::now())), } } diff --git a/comms/dht/src/store_forward/saf_handler/task.rs b/comms/dht/src/store_forward/saf_handler/task.rs index 9a0274037d..1e944b6050 100644 --- a/comms/dht/src/store_forward/saf_handler/task.rs +++ b/comms/dht/src/store_forward/saf_handler/task.rs @@ -24,11 +24,11 @@ use crate::{ actor::DhtRequester, config::DhtConfig, crypt, - envelope::{DhtMessageFlags, DhtMessageHeader, DhtMessageOrigin, NodeDestination}, + envelope::{DhtMessageFlags, DhtMessageHeader, NodeDestination}, inbound::{DecryptedDhtMessage, DhtInboundMessage}, outbound::{OutboundMessageRequester, SendMessageParams}, proto::{ - envelope::DhtMessageType, + envelope::{DhtMessageType, OriginMac}, store_forward::{ stored_messages_response::SafResponseType, StoredMessage as ProtoStoredMessage, @@ -55,9 +55,10 @@ use tari_comms::{ message::EnvelopeBody, peer_manager::{NodeIdentity, Peer, PeerManager, PeerManagerError}, pipeline::PipelineError, - types::Challenge, + types::{Challenge, CommsPublicKey}, utils::signature, }; +use tari_crypto::tari_utilities::ByteArray; use tower::{Service, ServiceExt}; const LOG_TARGET: &str = "comms::dht::store_forward"; @@ -373,6 +374,10 @@ where S: Service .try_into() .map_err(StoreAndForwardError::DhtMessageError)?; + if !dht_header.is_valid() { + return Err(StoreAndForwardError::InvalidDhtHeader); + } + if dht_header.message_type.is_dht_message() { trace!( target: LOG_TARGET, @@ -382,26 +387,22 @@ where S: Service ); } - let dht_flags = dht_header.flags; - - let origin = dht_header - .origin - .as_ref() - .ok_or_else(|| StoreAndForwardError::MessageOriginRequired)?; - - // Check that the destination is either undisclosed + // Check that the destination is either undisclosed, for us or for our network region Self::check_destination(&config, &peer_manager, &node_identity, &dht_header).await?; - // Verify the signature - Self::check_signature(origin, &message.body)?; // Check that the message has not already been received. Self::check_duplicate(&mut dht_requester, &message.body).await?; // Attempt to decrypt the message (if applicable), and deserialize it - let decrypted_body = Self::maybe_decrypt_and_deserialize(&node_identity, origin, dht_flags, &message.body)?; + let (authenticated_pk, decrypted_body) = + Self::authenticate_and_decrypt_if_required(&node_identity, &dht_header, &message.body)?; let inbound_msg = DhtInboundMessage::new(dht_header, Arc::clone(&source_peer), message.body); - Ok(DecryptedDhtMessage::succeeded(decrypted_body, inbound_msg)) + Ok(DecryptedDhtMessage::succeeded( + decrypted_body, + authenticated_pk, + inbound_msg, + )) } } @@ -439,34 +440,54 @@ where S: Service } } - fn check_signature(origin: &DhtMessageOrigin, body: &[u8]) -> Result<(), StoreAndForwardError> { - signature::verify(&origin.public_key, &origin.signature, body) - .map_err(|_| StoreAndForwardError::InvalidSignature) - .and_then(|is_valid| { - if is_valid { - Ok(()) - } else { - Err(StoreAndForwardError::InvalidSignature) - } - }) - } - - fn maybe_decrypt_and_deserialize( + fn authenticate_and_decrypt_if_required( node_identity: &NodeIdentity, - origin: &DhtMessageOrigin, - flags: DhtMessageFlags, + header: &DhtMessageHeader, body: &[u8], - ) -> Result + ) -> Result<(Option, EnvelopeBody), StoreAndForwardError> { - if flags.contains(DhtMessageFlags::ENCRYPTED) { - let shared_secret = crypt::generate_ecdh_secret(node_identity.secret_key(), &origin.public_key); + if header.flags.contains(DhtMessageFlags::ENCRYPTED) { + let ephemeral_public_key = header.ephemeral_public_key.as_ref().expect( + "[store and forward] DHT header is invalid after validity check because it did not contain an \ + ephemeral_public_key", + ); + let shared_secret = crypt::generate_ecdh_secret(node_identity.secret_key(), ephemeral_public_key); + let decrypted = crypt::decrypt(&shared_secret, &header.origin_mac)?; + let authenticated_pk = Self::authenticate_message(&decrypted, body)?; + let decrypted_bytes = crypt::decrypt(&shared_secret, body)?; - EnvelopeBody::decode(decrypted_bytes.as_slice()).map_err(|_| StoreAndForwardError::DecryptionFailed) + let envelope_body = + EnvelopeBody::decode(decrypted_bytes.as_slice()).map_err(|_| StoreAndForwardError::DecryptionFailed)?; + if envelope_body.is_empty() { + return Err(StoreAndForwardError::InvalidEnvelopeBody); + } + Ok((Some(authenticated_pk), envelope_body)) } else { - // Malformed cleartext messages should never have been forwarded by the peer - EnvelopeBody::decode(body).map_err(|_| StoreAndForwardError::MalformedMessage) + let authenticated_pk = if !header.origin_mac.is_empty() { + Some(Self::authenticate_message(&header.origin_mac, body)?) + } else { + None + }; + let envelope_body = EnvelopeBody::decode(body).map_err(|_| StoreAndForwardError::MalformedMessage)?; + Ok((authenticated_pk, envelope_body)) } } + + fn authenticate_message(origin_mac_body: &[u8], body: &[u8]) -> Result { + let origin_mac = OriginMac::decode(origin_mac_body)?; + let public_key = + CommsPublicKey::from_bytes(&origin_mac.public_key).map_err(|_| StoreAndForwardError::InvalidOriginMac)?; + signature::verify(&public_key, &origin_mac.signature, body) + .map_err(|_| StoreAndForwardError::InvalidOriginMac) + .and_then(|is_valid| { + if is_valid { + Ok(()) + } else { + Err(StoreAndForwardError::InvalidOriginMac) + } + })?; + Ok(public_key) + } } #[cfg(test)] @@ -481,6 +502,7 @@ mod test { create_store_and_forward_mock, make_dht_header, make_dht_inbound_message, + make_keypair, make_node_identity, make_peer_manager, service_spy, @@ -500,12 +522,11 @@ mod test { StoredMessage { id: 1, version: 0, - origin_pubkey: node_identity.public_key().to_hex(), - origin_signature: String::new(), + origin_pubkey: Some(node_identity.public_key().to_hex()), message_type: DhtMessageType::None as i32, destination_pubkey: None, destination_node_id: None, - header: DhtHeader::from(dht_header).to_encoded_bytes().unwrap(), + header: DhtHeader::from(dht_header).to_encoded_bytes(), body: b"A".to_vec(), is_encrypted: false, priority: StoredMessagePriority::High as i32, @@ -525,15 +546,22 @@ mod test { let node_identity = make_node_identity(); // Recent message - let dht_header = make_dht_header(&node_identity, &[], DhtMessageFlags::empty()); + let (e_sk, e_pk) = make_keypair(); + let dht_header = make_dht_header(&node_identity, &e_pk, &e_sk, &[], DhtMessageFlags::empty(), false); mock_state .add_message(make_stored_message(&node_identity, dht_header)) .await; let since = Utc::now().checked_sub_signed(chrono::Duration::seconds(60)).unwrap(); let mut message = DecryptedDhtMessage::succeeded( - wrap_in_envelope_body!(StoredMessagesRequest::since(since)).unwrap(), - make_dht_inbound_message(&node_identity, vec![], DhtMessageFlags::ENCRYPTED), + wrap_in_envelope_body!(StoredMessagesRequest::since(since)), + None, + make_dht_inbound_message( + &node_identity, + b"Keep this for others please".to_vec(), + DhtMessageFlags::ENCRYPTED, + true, + ), ); message.dht_header.message_type = DhtMessageType::SafRequestMessages; @@ -554,6 +582,7 @@ mod test { rt_handle.spawn(task.run()); let (_, body) = unwrap_oms_send_msg!(oms_rx.next().await.unwrap()); + let body = body.to_vec(); let body = EnvelopeBody::decode(body.as_slice()).unwrap(); let msg = body.decode_part::(0).unwrap().unwrap(); assert_eq!(msg.messages().len(), 1); @@ -578,56 +607,43 @@ mod test { let node_identity = make_node_identity(); - let shared_key = crypt::generate_ecdh_secret(node_identity.secret_key(), node_identity.public_key()); - let msg_a = crypt::encrypt( - &shared_key, - &wrap_in_envelope_body!(&b"A".to_vec()) - .unwrap() - .to_encoded_bytes() - .unwrap(), - ) - .unwrap(); - - let inbound_msg_a = make_dht_inbound_message(&node_identity, msg_a.clone(), DhtMessageFlags::ENCRYPTED); + let msg_a = wrap_in_envelope_body!(&b"A".to_vec()).to_encoded_bytes(); + + let inbound_msg_a = make_dht_inbound_message(&node_identity, msg_a.clone(), DhtMessageFlags::ENCRYPTED, true); // Need to know the peer to process a stored message peer_manager .add_peer(Clone::clone(&*inbound_msg_a.source_peer)) .await .unwrap(); - let msg_b = crypt::encrypt( - &shared_key, - &wrap_in_envelope_body!(b"B".to_vec()) - .unwrap() - .to_encoded_bytes() - .unwrap(), - ) - .unwrap(); - - let inbound_msg_b = make_dht_inbound_message(&node_identity, msg_b.clone(), DhtMessageFlags::ENCRYPTED); + + let msg_b = &wrap_in_envelope_body!(b"B".to_vec()).to_encoded_bytes(); + let inbound_msg_b = make_dht_inbound_message(&node_identity, msg_b.clone(), DhtMessageFlags::ENCRYPTED, true); // Need to know the peer to process a stored message peer_manager .add_peer(Clone::clone(&*inbound_msg_b.source_peer)) .await .unwrap(); - let msg1 = ProtoStoredMessage::new(0, inbound_msg_a.dht_header.clone(), msg_a); - let msg2 = ProtoStoredMessage::new(0, inbound_msg_b.dht_header, msg_b); + let msg1 = ProtoStoredMessage::new(0, inbound_msg_a.dht_header.clone(), inbound_msg_a.body); + let msg2 = ProtoStoredMessage::new(0, inbound_msg_b.dht_header, inbound_msg_b.body); // Cleartext message - let clear_msg = wrap_in_envelope_body!(b"Clear".to_vec()) - .unwrap() - .to_encoded_bytes() - .unwrap(); + let clear_msg = wrap_in_envelope_body!(b"Clear".to_vec()).to_encoded_bytes(); let clear_header = - make_dht_inbound_message(&node_identity, clear_msg.clone(), DhtMessageFlags::empty()).dht_header; + make_dht_inbound_message(&node_identity, clear_msg.clone(), DhtMessageFlags::empty(), false).dht_header; let msg_clear = ProtoStoredMessage::new(0, clear_header, clear_msg); let mut message = DecryptedDhtMessage::succeeded( wrap_in_envelope_body!(StoredMessagesResponse { messages: vec![msg1.clone(), msg2, msg_clear], request_id: 123, response_type: 0 - }) - .unwrap(), - make_dht_inbound_message(&node_identity, vec![], DhtMessageFlags::ENCRYPTED), + }), + None, + make_dht_inbound_message( + &node_identity, + b"Stored message".to_vec(), + DhtMessageFlags::ENCRYPTED, + true, + ), ); message.dht_header.message_type = DhtMessageType::SafStoredMessages; diff --git a/comms/dht/src/store_forward/store.rs b/comms/dht/src/store_forward/store.rs index dd99f2bb94..0baa43f6f4 100644 --- a/comms/dht/src/store_forward/store.rs +++ b/comms/dht/src/store_forward/store.rs @@ -37,7 +37,6 @@ use futures::{task::Context, Future}; use log::*; use std::{sync::Arc, task::Poll}; use tari_comms::{ - message::MessageExt, peer_manager::{NodeId, NodeIdentity, PeerManager}, pipeline::PipelineError, }; @@ -205,17 +204,21 @@ where S: Service ); }; - if message.body_size() > self.config.saf_max_message_size { + if message.body_len() > self.config.saf_max_message_size { log_not_eligible(&format!( "the message body exceeded the maximum storage size (body size={}, max={})", - message.body_size(), + message.body_len(), self.config.saf_max_message_size )); return Ok(None); } - if message.origin_public_key() == self.node_identity.public_key() { - log_not_eligible("not storing message from this node"); + if message + .authenticated_origin() + .map(|pk| pk == self.node_identity.public_key()) + .unwrap_or(false) + { + log_not_eligible("this message originates from this node"); return Ok(None); } @@ -223,12 +226,12 @@ where S: Service // The message decryption was successful, or the message was not encrypted Some(_) => { // If the message doesnt have an origin we wont store it - if !message.has_origin() { - log_not_eligible("it does not have an origin"); + if !message.has_origin_mac() { + log_not_eligible("it is encrypted and does not have an origin MAC"); return Ok(None); } - // If this node decrypted the message, no need to store it + // If this node decrypted the message (message.success() above), no need to store it if message.is_encrypted() { log_not_eligible("the message was encrypted for this node"); return Ok(None); @@ -251,12 +254,12 @@ where S: Service }, // This node could not decrypt the message None => { - if !message.has_origin() { - // TODO: #banheuristic - the source should not have propagated this message + if !message.has_origin_mac() { + // TODO: #banheuristic - the source peer should not have propagated this message warn!( target: LOG_TARGET, - "Store task received an encrypted message with no source. This message is invalid and should \ - not be stored or propagated. Dropping message. Sent by node '{}'", + "Store task received an encrypted message with no origin MAC. This message is invalid and \ + should not be stored or propagated. Dropping message. Sent by node '{}'", message.source_peer.node_id.short_str() ); return Ok(None); @@ -387,26 +390,14 @@ where S: Service } async fn store(&mut self, priority: StoredMessagePriority, message: DecryptedDhtMessage) -> SafResult<()> { - let DecryptedDhtMessage { - version, - decryption_result, - dht_header, - .. - } = message; - - let body = match decryption_result { - Ok(body) => body.to_encoded_bytes()?, - Err(encrypted_body) => encrypted_body, - }; - debug!( target: LOG_TARGET, "Storing message from peer '{}' ({} bytes)", message.source_peer.node_id.short_str(), - body.len() + message.body_len(), ); - let stored_message = NewStoredMessage::try_construct(version, dht_header, priority, body) + let stored_message = NewStoredMessage::try_construct(message, priority) .ok_or_else(|| StoreAndForwardError::InvalidStoreMessage)?; self.saf_requester.insert_message(stored_message).await?; @@ -430,7 +421,7 @@ mod test { }; use chrono::Utc; use std::time::Duration; - use tari_comms::wrap_in_envelope_body; + use tari_comms::{message::MessageExt, wrap_in_envelope_body}; use tari_crypto::tari_utilities::hex::Hex; use tari_test_utils::async_assert_eventually; @@ -444,9 +435,9 @@ mod test { let mut service = StoreLayer::new(Default::default(), peer_manager, node_identity, requester) .layer(spy.to_service::()); - let mut inbound_msg = make_dht_inbound_message(&make_node_identity(), b"".to_vec(), DhtMessageFlags::empty()); - inbound_msg.dht_header.origin = None; - let msg = DecryptedDhtMessage::succeeded(wrap_in_envelope_body!(Vec::new()).unwrap(), inbound_msg); + let inbound_msg = + make_dht_inbound_message(&make_node_identity(), b"".to_vec(), DhtMessageFlags::empty(), false); + let msg = DecryptedDhtMessage::succeeded(wrap_in_envelope_body!(Vec::new()), None, inbound_msg); service.call(msg).await.unwrap(); assert!(spy.is_called()); let messages = mock_state.get_messages().await; @@ -465,14 +456,18 @@ mod test { addresses: vec![], peer_features: 0, } - .to_encoded_bytes() - .unwrap(); + .to_encoded_bytes(); let mut service = StoreLayer::new(Default::default(), peer_manager, node_identity, requester) .layer(spy.to_service::()); - let inbound_msg = make_dht_inbound_message(&make_node_identity(), b"".to_vec(), DhtMessageFlags::empty()); + let sender_identity = make_node_identity(); + let inbound_msg = make_dht_inbound_message(&sender_identity, b"".to_vec(), DhtMessageFlags::empty(), true); - let mut msg = DecryptedDhtMessage::succeeded(wrap_in_envelope_body!(join_msg_bytes).unwrap(), inbound_msg); + let mut msg = DecryptedDhtMessage::succeeded( + wrap_in_envelope_body!(join_msg_bytes), + Some(sender_identity.public_key().clone()), + inbound_msg, + ); msg.dht_header.message_type = DhtMessageType::Join; service.call(msg).await.unwrap(); assert!(spy.is_called()); @@ -499,8 +494,18 @@ mod test { let mut service = StoreLayer::new(Default::default(), peer_manager, node_identity, requester) .layer(spy.to_service::()); - let inbound_msg = make_dht_inbound_message(&make_node_identity(), b"".to_vec(), DhtMessageFlags::ENCRYPTED); - let msg = DecryptedDhtMessage::succeeded(wrap_in_envelope_body!(b"secret".to_vec()).unwrap(), inbound_msg); + let msg_node_identity = make_node_identity(); + let inbound_msg = make_dht_inbound_message( + &msg_node_identity, + b"This shouldnt be stored".to_vec(), + DhtMessageFlags::ENCRYPTED, + true, + ); + let msg = DecryptedDhtMessage::succeeded( + wrap_in_envelope_body!(b"secret".to_vec()), + Some(msg_node_identity.public_key().clone()), + inbound_msg, + ); service.call(msg).await.unwrap(); assert!(spy.is_called()); @@ -518,7 +523,12 @@ mod test { let mut service = StoreLayer::new(Default::default(), peer_manager, node_identity, requester) .layer(spy.to_service::()); - let mut inbound_msg = make_dht_inbound_message(&origin_node_identity, b"".to_vec(), DhtMessageFlags::empty()); + let mut inbound_msg = make_dht_inbound_message( + &origin_node_identity, + b"Will you keep this for me?".to_vec(), + DhtMessageFlags::ENCRYPTED, + true, + ); inbound_msg.dht_header.destination = NodeDestination::PublicKey(Box::new(origin_node_identity.public_key().clone())); let msg = DecryptedDhtMessage::failed(inbound_msg.clone()); @@ -534,8 +544,8 @@ mod test { let message = mock_state.get_messages().await.remove(0); assert_eq!( - message.origin_signature, - inbound_msg.dht_header.origin.unwrap().signature.to_hex() + message.destination_pubkey.unwrap(), + origin_node_identity.public_key().to_hex() ); let duration = Utc::now().naive_utc().signed_duration_since(message.stored_at); assert!(duration.num_seconds() <= 5); diff --git a/comms/dht/src/test_utils/dht_discovery_mock.rs b/comms/dht/src/test_utils/dht_discovery_mock.rs index ae07a83cd0..89c5b6bd06 100644 --- a/comms/dht/src/test_utils/dht_discovery_mock.rs +++ b/comms/dht/src/test_utils/dht_discovery_mock.rs @@ -102,8 +102,7 @@ impl DhtDiscoveryMock { trace!(target: LOG_TARGET, "DhtDiscoveryMock received request {:?}", req); self.state.inc_call_count(); match req { - DiscoverPeer(boxed) => { - let (_, reply_tx) = *boxed; + DiscoverPeer(_, reply_tx) => { let lock = self.state.discover_peer.read().unwrap(); reply_tx.send(Ok(lock.clone())).unwrap(); }, diff --git a/comms/dht/src/test_utils/makers.rs b/comms/dht/src/test_utils/makers.rs index 862ecb474e..e56e1d76a3 100644 --- a/comms/dht/src/test_utils/makers.rs +++ b/comms/dht/src/test_utils/makers.rs @@ -20,21 +20,27 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::{ - envelope::{DhtMessageFlags, DhtMessageHeader, DhtMessageOrigin, NodeDestination}, + crypt, + envelope::{DhtMessageFlags, DhtMessageHeader, NodeDestination}, inbound::DhtInboundMessage, - proto::envelope::{DhtEnvelope, DhtMessageType, Network}, + outbound::message::DhtOutboundMessage, + proto::envelope::{DhtEnvelope, DhtMessageType, Network, OriginMac}, }; use rand::rngs::OsRng; -use std::sync::Arc; +use std::{convert::TryInto, sync::Arc}; use tari_comms::{ - message::InboundMessage, + message::{InboundMessage, MessageExt, MessageTag}, multiaddr::Multiaddr, - peer_manager::{NodeIdentity, Peer, PeerFeatures, PeerFlags, PeerManager}, - types::CommsDatabase, + net_address::MultiaddressesWithStats, + peer_manager::{NodeId, NodeIdentity, Peer, PeerFeatures, PeerFlags, PeerManager}, + types::{CommsDatabase, CommsPublicKey, CommsSecretKey}, utils::signature, Bytes, }; -use tari_crypto::tari_utilities::message_format::MessageFormat; +use tari_crypto::{ + keys::PublicKey, + tari_utilities::{message_format::MessageFormat, ByteArray}, +}; use tari_storage::lmdb_store::LMDBBuilder; use tari_test_utils::{paths::create_temporary_data_path, random}; @@ -86,31 +92,63 @@ pub fn make_comms_inbound_message(node_identity: &NodeIdentity, message: Bytes) ) } -pub fn make_dht_header(node_identity: &NodeIdentity, message: &[u8], flags: DhtMessageFlags) -> DhtMessageHeader { +pub fn make_dht_header( + node_identity: &NodeIdentity, + e_pk: &CommsPublicKey, + e_sk: &CommsSecretKey, + message: &[u8], + flags: DhtMessageFlags, + include_origin: bool, +) -> DhtMessageHeader +{ DhtMessageHeader { version: 0, destination: NodeDestination::Unknown, - origin: Some(DhtMessageOrigin { - public_key: node_identity.public_key().clone(), - signature: signature::sign(&mut OsRng, node_identity.secret_key().clone(), message) - .unwrap() - .to_binary() - .unwrap(), - }), + ephemeral_public_key: if flags.is_encrypted() { Some(e_pk.clone()) } else { None }, + origin_mac: if include_origin { + make_valid_origin_mac(node_identity, &e_sk, message, flags) + } else { + Vec::new() + }, message_type: DhtMessageType::None, network: Network::LocalTest, flags, } } +pub fn make_valid_origin_mac( + node_identity: &NodeIdentity, + e_sk: &CommsSecretKey, + body: &[u8], + flags: DhtMessageFlags, +) -> Vec +{ + let mac = OriginMac { + public_key: node_identity.public_key().to_vec(), + signature: signature::sign(&mut OsRng, node_identity.secret_key().clone(), body) + .unwrap() + .to_binary() + .unwrap(), + }; + let body = mac.to_encoded_bytes(); + if flags.is_encrypted() { + let shared_secret = crypt::generate_ecdh_secret(e_sk, node_identity.public_key()); + crypt::encrypt(&shared_secret, &body).unwrap() + } else { + body + } +} + pub fn make_dht_inbound_message( node_identity: &NodeIdentity, body: Vec, flags: DhtMessageFlags, + include_origin: bool, ) -> DhtInboundMessage { + let envelope = make_dht_envelope(node_identity, body, flags, include_origin); DhtInboundMessage::new( - make_dht_header(node_identity, &body, flags), + envelope.header.unwrap().try_into().unwrap(), Arc::new(Peer::new( node_identity.public_key().clone(), node_identity.node_id().clone(), @@ -119,12 +157,28 @@ pub fn make_dht_inbound_message( PeerFeatures::COMMUNICATION_NODE, &[], )), - body, + envelope.body, ) } -pub fn make_dht_envelope(node_identity: &NodeIdentity, message: Vec, flags: DhtMessageFlags) -> DhtEnvelope { - DhtEnvelope::new(make_dht_header(node_identity, &message, flags).into(), message) +pub fn make_keypair() -> (CommsSecretKey, CommsPublicKey) { + CommsPublicKey::random_keypair(&mut OsRng) +} + +pub fn make_dht_envelope( + node_identity: &NodeIdentity, + mut message: Vec, + flags: DhtMessageFlags, + include_origin: bool, +) -> DhtEnvelope +{ + let (e_sk, e_pk) = make_keypair(); + if flags.is_encrypted() { + let shared_secret = crypt::generate_ecdh_secret(&e_sk, node_identity.public_key()); + message = crypt::encrypt(&shared_secret, &message).unwrap(); + } + let header = make_dht_header(node_identity, &e_pk, &e_sk, &message, flags, include_origin).into(); + DhtEnvelope::new(header, message.into()) } pub fn make_peer_manager() -> Arc { @@ -144,3 +198,27 @@ pub fn make_peer_manager() -> Arc { .map(Arc::new) .unwrap() } + +pub fn create_outbound_message(body: &[u8]) -> DhtOutboundMessage { + DhtOutboundMessage { + tag: MessageTag::new(), + destination_peer: Peer::new( + CommsPublicKey::default(), + NodeId::default(), + MultiaddressesWithStats::new(vec![]), + PeerFlags::empty(), + PeerFeatures::COMMUNICATION_NODE, + &[], + ), + destination: Default::default(), + dht_message_type: Default::default(), + network: Network::LocalTest, + dht_flags: Default::default(), + custom_header: None, + include_origin: false, + encryption: Default::default(), + body: body.to_vec().into(), + ephemeral_public_key: None, + origin_mac: None, + } +} diff --git a/comms/dht/src/test_utils/store_and_forward_mock.rs b/comms/dht/src/test_utils/store_and_forward_mock.rs index e338b323c4..66fe181a0b 100644 --- a/comms/dht/src/test_utils/store_and_forward_mock.rs +++ b/comms/dht/src/test_utils/store_and_forward_mock.rs @@ -120,7 +120,6 @@ impl StoreAndForwardMock { id: OsRng.next_u32() as i32, version: msg.version, origin_pubkey: msg.origin_pubkey, - origin_signature: msg.origin_signature, message_type: msg.message_type, destination_pubkey: msg.destination_pubkey, destination_node_id: msg.destination_node_id, diff --git a/comms/dht/tests/dht.rs b/comms/dht/tests/dht.rs index 75960b9f9c..521a07f6f8 100644 --- a/comms/dht/tests/dht.rs +++ b/comms/dht/tests/dht.rs @@ -33,7 +33,7 @@ use tari_comms::{ CommsBuilder, CommsNode, }; -use tari_comms_dht::{envelope::NodeDestination, inbound::DecryptedDhtMessage, Dht, DhtBuilder}; +use tari_comms_dht::{inbound::DecryptedDhtMessage, Dht, DhtBuilder}; use tari_storage::{lmdb_store::LMDBBuilder, LMDBWrapper}; use tari_test_utils::{async_assert_eventually, paths::create_temporary_data_path, random}; use tower::ServiceBuilder; @@ -213,12 +213,7 @@ async fn dht_discover_propagation() { node_A .dht .discovery_service_requester() - .discover_peer( - Box::new(node_D.node_identity().public_key().clone()), - None, - // Sending to a nonsense NodeId, this should still propagate towards D in a network of 4 - NodeDestination::NodeId(Box::new(Default::default())), - ) + .discover_peer(Box::new(node_D.node_identity().public_key().clone())) .await .unwrap(); diff --git a/comms/src/connection_manager/manager.rs b/comms/src/connection_manager/manager.rs index dc5c54afe0..40b9bb934a 100644 --- a/comms/src/connection_manager/manager.rs +++ b/comms/src/connection_manager/manager.rs @@ -412,7 +412,7 @@ where self.send_dialer_request(DialerRequest::CancelPendingDial(node_id.clone())) .await; - match self.active_connections.remove(&node_id) { + match self.active_connections.get(&node_id) { Some(existing_conn) => { debug!( target: LOG_TARGET, @@ -421,7 +421,7 @@ where existing_conn.peer_node_id() ); - if self.tie_break_existing_connection(&existing_conn, &new_conn) { + if self.tie_break_existing_connection(existing_conn, &new_conn) { debug!( target: LOG_TARGET, "Disconnecting existing {} connection to peer '{}' because of simultaneous dial", @@ -434,8 +434,14 @@ where Box::new(existing_conn.peer_node_id().clone()), existing_conn.direction(), )); + + // Replace existing connection with new one + let existing_conn = self + .active_connections + .insert(node_id, new_conn.clone()) + .expect("Already checked"); + self.delayed_disconnect(existing_conn); - self.active_connections.insert(node_id, new_conn.clone()); self.publish_event(PeerConnected(new_conn)); } else { debug!( @@ -447,7 +453,6 @@ where ); self.delayed_disconnect(new_conn); - self.active_connections.insert(node_id, existing_conn); } }, None => { @@ -586,19 +591,8 @@ where return; } - if let Err(err) = self.dialer_tx.try_send(DialerRequest::Dial(Box::new(peer), reply_tx)) { + if let Err(err) = self.dialer_tx.send(DialerRequest::Dial(Box::new(peer), reply_tx)).await { error!(target: LOG_TARGET, "Failed to send request to dialer because '{}'", err); - // TODO: If the channel is full - we'll fail to dial. This function should block until the dial - // request channel has cleared - - if let DialerRequest::Dial(_, reply_tx) = err.into_inner() { - log_if_error_fmt!( - target: LOG_TARGET, - reply_tx.send(Err(ConnectionManagerError::EstablisherChannelError)), - "Failed to send dial peer result for peer '{}'", - node_id.short_str() - ); - } } }, Err(err) => { diff --git a/comms/src/message/envelope.rs b/comms/src/message/envelope.rs index ec1076eb3f..61a2c0d749 100644 --- a/comms/src/message/envelope.rs +++ b/comms/src/message/envelope.rs @@ -31,22 +31,10 @@ macro_rules! wrap_in_envelope_body { ($($e:expr),+) => {{ use $crate::message::MessageExt; let mut envelope_body = $crate::message::EnvelopeBody::new(); - let mut error = None; $( - match $e.to_encoded_bytes() { - Ok(bytes) => envelope_body.push_part(bytes), - Err(err) => { - if error.is_none() { - error = Some(err); - } - } - } + envelope_body.push_part($e.to_encoded_bytes()); )* - - match error { - Some(err) => Err(err), - None => Ok(envelope_body), - } + envelope_body }} } diff --git a/comms/src/message/error.rs b/comms/src/message/error.rs index 76b23f2761..f56c4cdb4d 100644 --- a/comms/src/message/error.rs +++ b/comms/src/message/error.rs @@ -22,7 +22,7 @@ use crate::peer_manager::node_id::NodeIdError; use derive_error::Error; -use prost::{DecodeError, EncodeError}; +use prost::DecodeError; use tari_crypto::{ signatures::SchnorrSignatureError, tari_utilities::{ciphers::cipher::CipherError, message_format::MessageFormatError}, @@ -53,8 +53,6 @@ pub enum MessageError { InvalidHeaderPublicKey, /// Failed to decode protobuf message DecodeError(DecodeError), - /// Failed to encode protobuf message - EncodeError(EncodeError), /// Failed to decode message part of envelope body EnvelopeBodyDecodeFailed, } diff --git a/comms/src/message/mod.rs b/comms/src/message/mod.rs index edeff8a439..3f05d33a0e 100644 --- a/comms/src/message/mod.rs +++ b/comms/src/message/mod.rs @@ -78,11 +78,14 @@ pub use tag::MessageTag; pub trait MessageExt: prost::Message { /// Encodes a message, allocating the buffer on the heap as necessary - fn to_encoded_bytes(&self) -> Result, MessageError> + fn to_encoded_bytes(&self) -> Vec where Self: Sized { let mut buf = Vec::with_capacity(self.encoded_len()); - self.encode(&mut buf)?; - Ok(buf) + self.encode(&mut buf).expect( + "prost::Message::encode documentation says it is infallible unless the buffer has insufficient capacity. \ + This buffer's capacity was set with encoded_len", + ); + buf } } impl MessageExt for T {} diff --git a/comms/src/peer_manager/peer.rs b/comms/src/peer_manager/peer.rs index d97fadbb4e..b3a89ad077 100644 --- a/comms/src/peer_manager/peer.rs +++ b/comms/src/peer_manager/peer.rs @@ -213,16 +213,22 @@ impl Peer { /// Display Peer as `[peer_id]: ` impl Display for Peer { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + let flags_str = if self.flags == PeerFlags::empty() { + "".to_string() + } else { + format!("{:?}", self.flags) + }; f.write_str(&format!( "{}[{}] PK={} {} {:?} {}", - if self.is_banned() { "BANNED " } else { "" }, + flags_str, self.node_id.short_str(), self.public_key, self.addresses - .address_iter() - .next() + .addresses + .iter() .map(ToString::to_string) - .unwrap_or_else(|| "".to_string()), + .collect::>() + .join(","), match self.features { PeerFeatures::COMMUNICATION_NODE => "BASE_NODE".to_string(), PeerFeatures::COMMUNICATION_CLIENT => "WALLET".to_string(), diff --git a/comms/src/pipeline/error.rs b/comms/src/pipeline/error.rs index ccb42e506b..f225b27760 100644 --- a/comms/src/pipeline/error.rs +++ b/comms/src/pipeline/error.rs @@ -22,7 +22,6 @@ use std::{error, fmt}; -#[derive(Debug)] pub struct PipelineError { err_string: String, } @@ -35,6 +34,13 @@ impl PipelineError { } } +impl fmt::Debug for PipelineError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("PipelineError: ")?; + f.write_str(&self.err_string) + } +} + impl From<&str> for PipelineError { fn from(s: &str) -> Self { Self { diff --git a/comms/src/protocol/identity.rs b/comms/src/protocol/identity.rs index 78f734a77c..bd52299620 100644 --- a/comms/src/protocol/identity.rs +++ b/comms/src/protocol/identity.rs @@ -88,8 +88,7 @@ where features: node_identity.features().bits(), supported_protocols, } - .to_encoded_bytes() - .map_err(|_| IdentityProtocolError::ProtobufEncodingError)?; + .to_encoded_bytes(); sink.send(msg_bytes.into()).await?; sink.close().await?;