Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use ephemeral key for private messages (e.g Discovery) #1686

Merged
merged 1 commit into from
Apr 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions applications/tari_base_node/src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ use tari_comms::{
types::CommsPublicKey,
NodeIdentity,
};
use tari_comms_dht::{envelope::NodeDestination, DhtDiscoveryRequester};
use tari_comms_dht::DhtDiscoveryRequester;
use tari_core::{
base_node::LocalNodeCommsInterface,
blocks::BlockHeader,
Expand Down Expand Up @@ -486,7 +486,7 @@ impl Parser {
self.executor.spawn(async move {
let start = Instant::now();
println!("🌎 Peer discovery started.");
match dht.discover_peer(dest_pubkey, None, NodeDestination::Unknown).await {
match dht.discover_peer(dest_pubkey).await {
Ok(p) => {
let end = Instant::now();
println!("⚡️ Discovery succeeded in {}ms!", (end - start).as_millis());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ impl ChainMetadataService {
/// Send this node's metadata to
async fn update_liveness_chain_metadata(&mut self) -> Result<(), ChainMetadataSyncError> {
let chain_metadata = self.base_node.get_metadata().await?;
let bytes = proto::ChainMetadata::from(chain_metadata).to_encoded_bytes()?;
let bytes = proto::ChainMetadata::from(chain_metadata).to_encoded_bytes();
self.liveness
.set_pong_metadata_entry(MetadataKey::ChainMetadata, bytes)
.await?;
Expand Down Expand Up @@ -300,10 +300,7 @@ mod test {
let (liveness_handle, _) = create_p2p_liveness_mock(1);
let mut metadata = Metadata::new();
let proto_chain_metadata = create_sample_proto_chain_metadata();
metadata.insert(
MetadataKey::ChainMetadata,
proto_chain_metadata.to_encoded_bytes().unwrap(),
);
metadata.insert(MetadataKey::ChainMetadata, proto_chain_metadata.to_encoded_bytes());

let node_id = NodeId::new();
let pong_event = PongEvent {
Expand Down
1 change: 1 addition & 0 deletions base_layer/core/src/base_node/service/initializer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ async fn extract_block(msg: Arc<PeerMessage>) -> Option<DomainMessage<Block>> {
Some(DomainMessage {
source_peer: msg.source_peer.clone(),
dht_header: msg.dht_header.clone(),
authenticated_origin: msg.authenticated_origin.clone(),
inner: block,
})
},
Expand Down
4 changes: 2 additions & 2 deletions base_layer/core/src/mempool/service/initializer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,9 @@ async fn extract_transaction(msg: Arc<PeerMessage>) -> Option<DomainMessage<Tran
Ok(tx) => {
let tx = match Transaction::try_from(tx) {
Err(e) => {
let origin = msg.origin_public_key();
warn!(
target: LOG_TARGET,
"Inbound transaction message from {} was ill-formed. {}", origin, e
"Inbound transaction message from {} was ill-formed. {}", msg.source_peer.public_key, e
);
return None;
},
Expand All @@ -133,6 +132,7 @@ async fn extract_transaction(msg: Arc<PeerMessage>) -> Option<DomainMessage<Tran
Some(DomainMessage {
source_peer: msg.source_peer.clone(),
dht_header: msg.dht_header.clone(),
authenticated_origin: msg.authenticated_origin.clone(),
inner: tx,
})
},
Expand Down
48 changes: 17 additions & 31 deletions base_layer/p2p/src/comms_connector/inbound_connector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,14 @@ impl<TSink> InboundDomainConnector<TSink> {
let DecryptedDhtMessage {
source_peer,
dht_header,
authenticated_origin,
..
} = inbound_message;

let peer_message = PeerMessage {
message_header: header,
source_peer: Clone::clone(&*source_peer),
authenticated_origin,
dht_header,
body: msg_bytes,
};
Expand Down Expand Up @@ -141,21 +143,17 @@ mod test {
use crate::test_utils::{make_dht_inbound_message, make_node_identity};
use futures::{channel::mpsc, executor::block_on, StreamExt};
use tari_comms::{message::MessageExt, wrap_in_envelope_body};
use tari_comms_dht::{domain_message::MessageHeader, envelope::DhtMessageFlags};
use tari_comms_dht::domain_message::MessageHeader;
use tower::ServiceExt;

#[tokio_macros::test_basic]
async fn handle_message() {
let (tx, mut rx) = mpsc::channel(1);
let header = MessageHeader::new(123);
let msg = wrap_in_envelope_body!(header, b"my message".to_vec()).unwrap();

let inbound_message = make_dht_inbound_message(
&make_node_identity(),
msg.to_encoded_bytes().unwrap(),
DhtMessageFlags::empty(),
);
let decrypted = DecryptedDhtMessage::succeeded(msg, inbound_message);
let msg = wrap_in_envelope_body!(header, b"my message".to_vec());

let inbound_message = make_dht_inbound_message(&make_node_identity(), msg.to_encoded_bytes());
let decrypted = DecryptedDhtMessage::succeeded(msg, None, inbound_message);
InboundDomainConnector::new(tx).oneshot(decrypted).await.unwrap();

let peer_message = block_on(rx.next()).unwrap();
Expand All @@ -167,14 +165,10 @@ mod test {
async fn send_on_sink() {
let (tx, mut rx) = mpsc::channel(1);
let header = MessageHeader::new(123);
let msg = wrap_in_envelope_body!(header, b"my message".to_vec()).unwrap();
let msg = wrap_in_envelope_body!(header, b"my message".to_vec());

let inbound_message = make_dht_inbound_message(
&make_node_identity(),
msg.to_encoded_bytes().unwrap(),
DhtMessageFlags::empty(),
);
let decrypted = DecryptedDhtMessage::succeeded(msg, inbound_message);
let inbound_message = make_dht_inbound_message(&make_node_identity(), msg.to_encoded_bytes());
let decrypted = DecryptedDhtMessage::succeeded(msg, None, inbound_message);

InboundDomainConnector::new(tx).send(decrypted).await.unwrap();

Expand All @@ -187,14 +181,10 @@ mod test {
async fn handle_message_fail_deserialize() {
let (tx, mut rx) = mpsc::channel(1);
let header = b"dodgy header".to_vec();
let msg = wrap_in_envelope_body!(header, b"message".to_vec()).unwrap();

let inbound_message = make_dht_inbound_message(
&make_node_identity(),
msg.to_encoded_bytes().unwrap(),
DhtMessageFlags::empty(),
);
let decrypted = DecryptedDhtMessage::succeeded(msg, inbound_message);
let msg = wrap_in_envelope_body!(header, b"message".to_vec());

let inbound_message = make_dht_inbound_message(&make_node_identity(), msg.to_encoded_bytes());
let decrypted = DecryptedDhtMessage::succeeded(msg, None, inbound_message);
InboundDomainConnector::new(tx).oneshot(decrypted).await.unwrap_err();

assert!(rx.try_next().unwrap().is_none());
Expand All @@ -206,13 +196,9 @@ mod test {
// from it's call function
let (tx, _) = mpsc::channel(1);
let header = MessageHeader::new(123);
let msg = wrap_in_envelope_body!(header, b"my message".to_vec()).unwrap();
let inbound_message = make_dht_inbound_message(
&make_node_identity(),
msg.to_encoded_bytes().unwrap(),
DhtMessageFlags::empty(),
);
let decrypted = DecryptedDhtMessage::succeeded(msg, inbound_message);
let msg = wrap_in_envelope_body!(header, b"my message".to_vec());
let inbound_message = make_dht_inbound_message(&make_node_identity(), msg.to_encoded_bytes());
let decrypted = DecryptedDhtMessage::succeeded(msg, None, inbound_message);
let result = InboundDomainConnector::new(tx).oneshot(decrypted).await;
assert!(result.is_err());
}
Expand Down
19 changes: 2 additions & 17 deletions base_layer/p2p/src/comms_connector/peer_message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,28 +34,13 @@ pub struct PeerMessage {
pub source_peer: Peer,
/// Domain message header
pub message_header: MessageHeader,
/// This messages authenticated origin, otherwise None
pub authenticated_origin: Option<CommsPublicKey>,
/// Serialized message data
pub body: Vec<u8>,
}

impl PeerMessage {
pub fn new(dht_header: DhtMessageHeader, source_peer: Peer, message_header: MessageHeader, body: Vec<u8>) -> Self {
Self {
body,
message_header,
dht_header,
source_peer,
}
}

pub fn origin_public_key(&self) -> &CommsPublicKey {
self.dht_header
.origin
.as_ref()
.map(|o| &o.public_key)
.unwrap_or(&self.source_peer.public_key)
}

pub fn decode_message<T>(&self) -> Result<T, prost::DecodeError>
where T: prost::Message + Default {
let msg = T::decode(self.body.as_slice())?;
Expand Down
25 changes: 6 additions & 19 deletions base_layer/p2p/src/domain_message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ pub struct DomainMessage<T> {
/// This DHT header of this message. If `DhtMessageHeader::origin_public_key` is different from the
/// `source_peer.public_key`, this message was forwarded.
pub dht_header: DhtMessageHeader,
/// The authenticated origin public key of this message or None a message origin was not provided.
pub authenticated_origin: Option<CommsPublicKey>,
/// The domain-level message
pub inner: T,
}
Expand All @@ -48,32 +50,15 @@ impl<T> DomainMessage<T> {
/// Consumes this object returning the public key of the original sender of this message and the message itself
pub fn into_origin_and_inner(self) -> (CommsPublicKey, T) {
let inner = self.inner;
let pk = self
.dht_header
.origin
.map(|o| o.public_key)
.unwrap_or(self.source_peer.public_key);
let pk = self.authenticated_origin.unwrap_or(self.source_peer.public_key);
(pk, inner)
}

/// Returns true of this message was forwarded from another peer, otherwise false
pub fn is_forwarded(&self) -> bool {
self.dht_header
.origin
.as_ref()
// If the source and origin are different, then the message was forwarded
.map(|o| o.public_key != self.source_peer.public_key)
// Otherwise, if no origin is specified, the message was sent directly from the peer
.unwrap_or(false)
}

/// Returns the public key that sent this message. If no origin is specified, then the source peer
/// sent this message.
pub fn origin_public_key(&self) -> &CommsPublicKey {
self.dht_header
.origin
self.authenticated_origin
.as_ref()
.map(|o| &o.public_key)
.unwrap_or(&self.source_peer.public_key)
}

Expand All @@ -88,6 +73,7 @@ impl<T> DomainMessage<T> {
DomainMessage {
source_peer: self.source_peer,
dht_header: self.dht_header,
authenticated_origin: self.authenticated_origin,
inner,
}
}
Expand All @@ -103,6 +89,7 @@ impl<T> DomainMessage<T> {
Ok(DomainMessage {
source_peer: self.source_peer,
dht_header: self.dht_header,
authenticated_origin: self.authenticated_origin,
inner,
})
}
Expand Down
17 changes: 10 additions & 7 deletions base_layer/p2p/src/services/liveness/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -513,13 +513,16 @@ mod test {
&[],
);
DomainMessage {
dht_header: DhtMessageHeader::new(
Default::default(),
DhtMessageType::None,
None,
Network::LocalTest,
Default::default(),
),
dht_header: DhtMessageHeader {
version: 0,
destination: Default::default(),
origin_mac: Vec::new(),
ephemeral_public_key: None,
message_type: DhtMessageType::None,
network: Network::LocalTest,
flags: Default::default(),
},
authenticated_origin: None,
source_peer,
inner,
}
Expand Down
1 change: 1 addition & 0 deletions base_layer/p2p/src/services/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ where T: prost::Message + Default {
Ok(DomainMessage {
source_peer: serialized.source_peer.clone(),
dht_header: serialized.dht_header.clone(),
authenticated_origin: serialized.authenticated_origin.clone(),
inner: serialized.decode_message()?,
})
}
Expand Down
26 changes: 7 additions & 19 deletions base_layer/p2p/src/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,11 @@ use std::sync::Arc;
use tari_comms::{
multiaddr::Multiaddr,
peer_manager::{NodeIdentity, Peer, PeerFeatures, PeerFlags},
utils::signature,
};
use tari_comms_dht::{
envelope::{DhtMessageFlags, DhtMessageHeader, DhtMessageOrigin, DhtMessageType, Network, NodeDestination},
envelope::{DhtMessageFlags, DhtMessageHeader, DhtMessageType, Network, NodeDestination},
inbound::DhtInboundMessage,
};
use tari_crypto::tari_utilities::message_format::MessageFormat;

macro_rules! unwrap_oms_send_msg {
($var:expr, reply_value=$reply_value:expr) => {
Expand Down Expand Up @@ -61,31 +59,21 @@ pub fn make_node_identity() -> Arc<NodeIdentity> {
)
}

pub fn make_dht_header(node_identity: &NodeIdentity, message: &Vec<u8>, flags: DhtMessageFlags) -> DhtMessageHeader {
pub fn make_dht_header() -> DhtMessageHeader {
DhtMessageHeader {
version: 0,
destination: NodeDestination::Unknown,
origin: Some(DhtMessageOrigin {
public_key: node_identity.public_key().clone(),
signature: signature::sign(&mut OsRng, node_identity.secret_key().clone(), message)
.unwrap()
.to_binary()
.unwrap(),
}),
origin_mac: Vec::new(),
ephemeral_public_key: None,
message_type: DhtMessageType::None,
network: Network::LocalTest,
flags,
flags: DhtMessageFlags::NONE,
}
}

pub fn make_dht_inbound_message(
node_identity: &NodeIdentity,
message: Vec<u8>,
flags: DhtMessageFlags,
) -> DhtInboundMessage
{
pub fn make_dht_inbound_message(node_identity: &NodeIdentity, message: Vec<u8>) -> DhtInboundMessage {
DhtInboundMessage::new(
make_dht_header(node_identity, &message, flags),
make_dht_header(),
Arc::new(Peer::new(
node_identity.public_key().clone(),
node_identity.node_id().clone(),
Expand Down
8 changes: 4 additions & 4 deletions base_layer/wallet/tests/output_manager_service/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -680,8 +680,8 @@ fn test_startup_utxo_scan() {
let output3 = UnblindedOutput::new(MicroTari::from(value3), key3, None);
runtime.block_on(oms.add_output(output3.clone())).unwrap();

let call = outbound_service.pop_call().unwrap();
let envelope_body = EnvelopeBody::decode(&mut call.1.as_slice()).unwrap();
let (_, body) = outbound_service.pop_call().unwrap();
let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap();
let bn_request: BaseNodeProto::BaseNodeServiceRequest = envelope_body
.decode_part::<BaseNodeProto::BaseNodeServiceRequest>(1)
.unwrap()
Expand Down Expand Up @@ -755,8 +755,8 @@ fn test_startup_utxo_scan() {

runtime.block_on(oms.sync_with_base_node()).unwrap();

let call = outbound_service.pop_call().unwrap();
let envelope_body = EnvelopeBody::decode(&mut call.1.as_slice()).unwrap();
let (_, body) = outbound_service.pop_call().unwrap();
let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap();
let bn_request: BaseNodeProto::BaseNodeServiceRequest = envelope_body
.decode_part::<BaseNodeProto::BaseNodeServiceRequest>(1)
.unwrap()
Expand Down
4 changes: 3 additions & 1 deletion base_layer/wallet/tests/support/comms_and_services.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,15 @@ pub fn create_dummy_message<T>(inner: T, public_key: &CommsPublicKey) -> DomainM
);
DomainMessage {
dht_header: DhtMessageHeader {
origin: None,
ephemeral_public_key: None,
origin_mac: Vec::new(),
version: Default::default(),
message_type: Default::default(),
flags: Default::default(),
network: Network::LocalTest,
destination: Default::default(),
},
authenticated_origin: None,
source_peer: peer_source,
inner,
}
Expand Down
Loading