Skip to content

Commit

Permalink
fix(dht/encryption): greatly reduce heap allocations for encrypted me…
Browse files Browse the repository at this point in the history
…ssaging (#4753)

Description
---
- Encrypt, decrypt and message padding mutate a single buffer for encrypted messages

Motivation and Context
---
Encrypted message handling should be as efficient as possible. The previous implementation performed allocations of the full padded message size twice for encryption and twice for decryption. Increasing memory usage, and negating the performance benefits of using an encryption keystream.

This PR allocates a single buffer for the message to be de/encrypted and de/encrypts the contents in-place using the BytesMut type from the `bytes` crate.

How Has This Been Tested?
---
This change is backwards compatible, tested on current esme network and updated existing tests as required.
Discovery: OK
Memorynet: OK
PingPong: OK
InteractiveTransactions: OK
SafTransactions: OK
  • Loading branch information
sdbondi authored Oct 3, 2022
1 parent 60c3df4 commit 195df85
Show file tree
Hide file tree
Showing 23 changed files with 383 additions and 272 deletions.
1 change: 0 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion comms/core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,6 @@ pub mod multiaddr {
}

pub use async_trait::async_trait;
pub use bytes::{Bytes, BytesMut};
pub use bytes::{Buf, BufMut, Bytes, BytesMut};
#[cfg(feature = "rpc")]
pub use tower::make::MakeService;
13 changes: 13 additions & 0 deletions comms/core/src/message/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
#[macro_use]
mod envelope;

use bytes::BytesMut;
pub use envelope::EnvelopeBody;

mod error;
Expand All @@ -52,5 +54,16 @@ pub trait MessageExt: prost::Message {
);
buf
}

/// Encodes a message into a BytesMut, allocating the buffer on the heap as necessary.
fn encode_into_bytes_mut(&self) -> BytesMut
where Self: Sized {
let mut buf = BytesMut::with_capacity(self.encoded_len());
self.encode(&mut buf).expect(
"prost::Message::encode documentation says it is infallible unless the buffer has insufficient capacity. \
This buffer's capacity was set with encoded_len",
);
buf
}
}
impl<T: prost::Message> MessageExt for T {}
2 changes: 1 addition & 1 deletion comms/core/src/protocol/rpc/body.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ impl BodyBytes {
}

pub fn into_vec(self) -> Vec<u8> {
self.0.map(|bytes| bytes.to_vec()).unwrap_or_else(Vec::new)
self.0.map(|bytes| bytes.into()).unwrap_or_else(Vec::new)
}

pub fn into_bytes(self) -> Option<Bytes> {
Expand Down
1 change: 0 additions & 1 deletion comms/dht/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ tari_common_sqlite = { path = "../../common_sqlite" }

anyhow = "1.0.53"
bitflags = "1.2.0"
bytes = "0.5"
chacha20 = "0.7.1"
chacha20poly1305 = "0.9.1"
chrono = { version = "0.4.19", default-features = false }
Expand Down
299 changes: 169 additions & 130 deletions comms/dht/src/crypt.rs

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions comms/dht/src/dedup/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ mod test {
assert!(dedup.poll_ready(&mut cx).is_ready());
let node_identity = make_node_identity();
let inbound_message =
make_dht_inbound_message(&node_identity, vec![], DhtMessageFlags::empty(), false, false).unwrap();
make_dht_inbound_message(&node_identity, &vec![], DhtMessageFlags::empty(), false, false).unwrap();
let decrypted_msg = DecryptedDhtMessage::succeeded(wrap_in_envelope_body!(vec![]), None, inbound_message);

rt.block_on(dedup.call(decrypted_msg.clone())).unwrap();
Expand All @@ -213,12 +213,12 @@ mod test {
#[test]
fn deterministic_hash() {
const TEST_MSG: &[u8] = b"test123";
const EXPECTED_HASH: &str = "d6333668f259f677703fbe4e89152ee41c7c01f6dec502befc63120246523ffe";
const EXPECTED_HASH: &str = "1c2bb1bcff443af4441b789bd1d6984bb8d7bed2c9f85e8cf4f45615fdd9e47d";

let node_identity = make_node_identity();
let dht_message = make_dht_inbound_message(
&node_identity,
TEST_MSG.to_vec(),
&TEST_MSG.to_vec(),
DhtMessageFlags::empty(),
false,
false,
Expand All @@ -229,7 +229,7 @@ mod test {
let node_identity = make_node_identity();
let dht_message = make_dht_inbound_message(
&node_identity,
TEST_MSG.to_vec(),
&TEST_MSG.to_vec(),
DhtMessageFlags::empty(),
false,
false,
Expand Down
11 changes: 6 additions & 5 deletions comms/dht/src/dht.rs
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,7 @@ mod test {
let msg = wrap_in_envelope_body!(b"secret".to_vec());
let dht_envelope = make_dht_envelope(
&node_identity,
msg.to_encoded_bytes(),
&msg,
DhtMessageFlags::empty(),
false,
MessageTag::new(),
Expand Down Expand Up @@ -546,7 +546,7 @@ mod test {
// Encrypt for self
let dht_envelope = make_dht_envelope(
&node_identity,
msg.to_encoded_bytes(),
&msg,
DhtMessageFlags::ENCRYPTED,
true,
MessageTag::new(),
Expand Down Expand Up @@ -602,10 +602,11 @@ mod test {
let node_identity2 = make_node_identity();
let ecdh_key = crypt::generate_ecdh_secret(node_identity2.secret_key(), node_identity2.public_key());
let key_message = crypt::generate_key_message(&ecdh_key);
let encrypted_bytes = crypt::encrypt(&key_message, &msg.to_encoded_bytes()).unwrap();
let mut encrypted_bytes = msg.encode_into_bytes_mut();
crypt::encrypt(&key_message, &mut encrypted_bytes).unwrap();
let dht_envelope = make_dht_envelope(
&node_identity2,
encrypted_bytes,
&encrypted_bytes.to_vec(),
DhtMessageFlags::ENCRYPTED,
true,
MessageTag::new(),
Expand Down Expand Up @@ -667,7 +668,7 @@ mod test {
let msg = wrap_in_envelope_body!(b"secret".to_vec());
let mut dht_envelope = make_dht_envelope(
&node_identity,
msg.to_encoded_bytes(),
&msg,
DhtMessageFlags::empty(),
false,
MessageTag::new(),
Expand Down
5 changes: 2 additions & 3 deletions comms/dht/src/envelope.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ use std::{
};

use bitflags::bitflags;
use bytes::Bytes;
use chrono::{DateTime, NaiveDateTime, Utc};
use prost_types::Timestamp;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -249,10 +248,10 @@ impl From<DhtMessageHeader> for DhtHeader {
}

impl DhtEnvelope {
pub fn new(header: DhtHeader, body: &Bytes) -> Self {
pub fn new(header: DhtHeader, body: Vec<u8>) -> Self {
Self {
header: Some(header),
body: body.to_vec(),
body,
}
}
}
Expand Down
51 changes: 23 additions & 28 deletions comms/dht/src/inbound/decryption.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ use tari_comms::{
message::EnvelopeBody,
peer_manager::NodeIdentity,
pipeline::PipelineError,
BytesMut,
};
use thiserror::Error;
use tower::{layer::Layer, Service, ServiceExt};
Expand Down Expand Up @@ -406,11 +407,11 @@ where S: Service<DecryptedDhtMessage, Response = (), Error = PipelineError>
message_body: &[u8],
) -> Result<EnvelopeBody, DecryptionError> {
let key_message = crypt::generate_key_message(shared_secret);
let decrypted =
crypt::decrypt(&key_message, message_body).map_err(DecryptionError::DecryptionFailedMalformedCipher)?;
let mut decrypted = BytesMut::from(message_body);
crypt::decrypt(&key_message, &mut decrypted).map_err(DecryptionError::DecryptionFailedMalformedCipher)?;
// Deserialization into an EnvelopeBody is done here to determine if the
// decryption produced valid bytes or not.
EnvelopeBody::decode(decrypted.as_slice())
EnvelopeBody::decode(decrypted.freeze())
.and_then(|body| {
// Check if we received a body length of zero
//
Expand Down Expand Up @@ -477,10 +478,11 @@ mod test {

use futures::{executor::block_on, future};
use tari_comms::{
message::{MessageExt, MessageTag},
message::MessageTag,
runtime,
test_utils::mocks::create_connectivity_mock,
wrap_in_envelope_body,
BytesMut,
};
use tari_test_utils::{counter_context, unpack_enum};
use tokio::time::sleep;
Expand All @@ -492,6 +494,7 @@ mod test {
test_utils::{
make_dht_header,
make_dht_inbound_message,
make_dht_inbound_message_raw,
make_keypair,
make_node_identity,
make_valid_message_signature,
Expand Down Expand Up @@ -527,14 +530,8 @@ mod test {
let mut service = DecryptionService::new(Default::default(), node_identity.clone(), connectivity, service);

let plain_text_msg = wrap_in_envelope_body!(b"Secret plans".to_vec());
let inbound_msg = make_dht_inbound_message(
&node_identity,
plain_text_msg.to_encoded_bytes(),
DhtMessageFlags::ENCRYPTED,
true,
true,
)
.unwrap();
let inbound_msg =
make_dht_inbound_message(&node_identity, &plain_text_msg, DhtMessageFlags::ENCRYPTED, true, true).unwrap();

block_on(service.call(inbound_msg)).unwrap();
let decrypted = result.lock().unwrap().take().unwrap();
Expand All @@ -560,7 +557,7 @@ mod test {
let some_other_node_identity = make_node_identity();
let inbound_msg = make_dht_inbound_message(
&some_other_node_identity,
some_secret,
&some_secret,
DhtMessageFlags::ENCRYPTED,
true,
true,
Expand Down Expand Up @@ -591,7 +588,7 @@ mod test {

let nonsense = b"Cannot Decrypt this".to_vec();
let inbound_msg =
make_dht_inbound_message(&node_identity, nonsense.clone(), DhtMessageFlags::ENCRYPTED, true, true).unwrap();
make_dht_inbound_message_raw(&node_identity, nonsense, DhtMessageFlags::ENCRYPTED, true, true).unwrap();

let err = service.call(inbound_msg).await.unwrap_err();
let err = err.downcast::<DecryptionError>().unwrap();
Expand All @@ -615,14 +612,8 @@ mod test {
let mut service = DecryptionService::new(Default::default(), node_identity.clone(), connectivity, service);

let plain_text_msg = b"Secret message to nowhere".to_vec();
let inbound_msg = make_dht_inbound_message(
&node_identity,
plain_text_msg.to_encoded_bytes(),
DhtMessageFlags::ENCRYPTED,
true,
false,
)
.unwrap();
let inbound_msg =
make_dht_inbound_message(&node_identity, &plain_text_msg, DhtMessageFlags::ENCRYPTED, true, false).unwrap();

let err = service.call(inbound_msg).await.unwrap_err();
let err = err.downcast::<DecryptionError>().unwrap();
Expand All @@ -645,13 +636,15 @@ mod test {
let node_identity = make_node_identity();
let mut service = DecryptionService::new(Default::default(), node_identity.clone(), connectivity, service);

let plain_text_msg = b"Secret message".to_vec();
let plain_text_msg = BytesMut::from(b"Secret message".as_slice());
let (e_secret_key, e_public_key) = make_keypair();
let shared_secret = crypt::generate_ecdh_secret(&e_secret_key, node_identity.public_key());
let key_message = crypt::generate_key_message(&shared_secret);
let msg_tag = MessageTag::new();

let message = crypt::encrypt(&key_message, &plain_text_msg).unwrap();
let mut message = plain_text_msg.clone();
crypt::encrypt(&key_message, &mut message).unwrap();
let message = message.freeze();
let header = make_dht_header(
&node_identity,
&e_public_key,
Expand All @@ -663,7 +656,7 @@ mod test {
true,
)
.unwrap();
let envelope = DhtEnvelope::new(header.into(), &message.into());
let envelope = DhtEnvelope::new(header.into(), message.into());
let msg_tag = MessageTag::new();
let mut inbound_msg = DhtInboundMessage::new(
msg_tag,
Expand Down Expand Up @@ -706,13 +699,15 @@ mod test {
let node_identity = make_node_identity();
let mut service = DecryptionService::new(Default::default(), node_identity.clone(), connectivity, service);

let plain_text_msg = b"Public message".to_vec();
let plain_text_msg = BytesMut::from(b"Public message".as_slice());
let (e_secret_key, e_public_key) = make_keypair();
let shared_secret = crypt::generate_ecdh_secret(&e_secret_key, node_identity.public_key());
let key_message = crypt::generate_key_message(&shared_secret);
let msg_tag = MessageTag::new();

let message = crypt::encrypt(&key_message, &plain_text_msg).unwrap();
let mut message = plain_text_msg.clone();
crypt::encrypt(&key_message, &mut message).unwrap();
let message = message.freeze();
let header = make_dht_header(
&node_identity,
&e_public_key,
Expand All @@ -724,7 +719,7 @@ mod test {
true,
)
.unwrap();
let envelope = DhtEnvelope::new(header.into(), &message.into());
let envelope = DhtEnvelope::new(header.into(), message.into());
let msg_tag = MessageTag::new();
let mut inbound_msg = DhtInboundMessage::new(
msg_tag,
Expand Down
4 changes: 2 additions & 2 deletions comms/dht/src/inbound/deserialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ mod test {

let dht_envelope = make_dht_envelope(
&node_identity,
b"A".to_vec(),
&b"A".to_vec(),
DhtMessageFlags::empty(),
false,
MessageTag::new(),
Expand All @@ -181,7 +181,7 @@ mod test {
.unwrap();

let msg = spy.pop_request().unwrap();
assert_eq!(msg.body, b"A".to_vec());
assert_eq!(msg.body, b"A".to_vec().to_encoded_bytes());
assert_eq!(msg.dht_header, dht_envelope.header.unwrap().try_into().unwrap());
}
}
2 changes: 1 addition & 1 deletion comms/dht/src/inbound/dht_handler/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ where S: Service<DecryptedDhtMessage, Response = (), Error = PipelineError>
.with_debug_info("Propagating join message".to_string())
.with_dht_header(dht_header)
.finish(),
body.to_encoded_bytes(),
body.encode_into_bytes_mut(),
)
.await?;
}
Expand Down
20 changes: 10 additions & 10 deletions comms/dht/src/inbound/forward.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ use std::task::Poll;

use futures::{future::BoxFuture, task::Context};
use log::*;
use tari_comms::{peer_manager::Peer, pipeline::PipelineError};
use prost::bytes::BufMut;
use tari_comms::{peer_manager::Peer, pipeline::PipelineError, BytesMut};
use tari_utilities::epoch_time::EpochTime;
use tower::{layer::Layer, Service, ServiceExt};

Expand Down Expand Up @@ -204,12 +205,11 @@ where S: Service<DecryptedDhtMessage, Response = (), Error = PipelineError>
return Ok(());
}
}

let body = decryption_result
let err_body = decryption_result
.as_ref()
.err()
.cloned()
.expect("previous check that decryption failed");
.expect_err("previous check that decryption failed");
let mut body = BytesMut::with_capacity(err_body.len());
body.put(err_body.as_slice());

let excluded_peers = vec![source_peer.node_id.clone()];
let dest_node_id = dht_header.destination.to_derived_node_id();
Expand Down Expand Up @@ -259,7 +259,7 @@ where S: Service<DecryptedDhtMessage, Response = (), Error = PipelineError>
mod test {
use std::time::Duration;

use tari_comms::{runtime, runtime::task, wrap_in_envelope_body};
use tari_comms::{message::MessageExt, runtime, runtime::task, wrap_in_envelope_body};
use tokio::sync::mpsc;

use super::*;
Expand All @@ -278,7 +278,7 @@ mod test {

let node_identity = make_node_identity();
let inbound_msg =
make_dht_inbound_message(&node_identity, b"".to_vec(), DhtMessageFlags::empty(), false, false).unwrap();
make_dht_inbound_message(&node_identity, &b"".to_vec(), DhtMessageFlags::empty(), false, false).unwrap();
let msg = DecryptedDhtMessage::succeeded(
wrap_in_envelope_body!(Vec::new()),
Some(node_identity.public_key().clone()),
Expand All @@ -300,7 +300,7 @@ mod test {
let sample_body = b"Lorem ipsum";
let inbound_msg = make_dht_inbound_message(
&make_node_identity(),
sample_body.to_vec(),
&sample_body.to_vec(),
DhtMessageFlags::empty(),
false,
false,
Expand All @@ -318,7 +318,7 @@ mod test {
let (params, body) = oms_mock_state.pop_call().await.unwrap();

// Header and body are preserved when forwarding
assert_eq!(&body.to_vec(), &sample_body);
assert_eq!(&body.to_vec(), &sample_body.to_vec().to_encoded_bytes());
assert_eq!(params.dht_header.unwrap(), header);
}
}
Loading

0 comments on commit 195df85

Please sign in to comment.