Skip to content

Commit

Permalink
fix(dht): saf storage uses constructs correct msg hash (#4003)
Browse files Browse the repository at this point in the history
Description
---
- construct the dedup hash correctly in SAF messages
- consolidate dedup hashing
- move dedup to after decryption/validation step in saf processor

Motivation and Context
---
SAF db is also used for dedup, so the hash must match. 
Closes #3419

How Has This Been Tested?
---
Existing unit tests, memorynet and manually
  • Loading branch information
sdbondi authored Apr 7, 2022
1 parent 6c5471e commit e1e7669
Show file tree
Hide file tree
Showing 8 changed files with 64 additions and 47 deletions.
8 changes: 4 additions & 4 deletions comms/dht/src/actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ impl DhtActor {
} => {
let msg_hash_cache = self.msg_hash_dedup_cache.clone();
Box::pin(async move {
match msg_hash_cache.add_body_hash(message_hash, &received_from) {
match msg_hash_cache.add_msg_hash(&message_hash, &received_from) {
Ok(hit_count) => {
let _ = reply_tx.send(hit_count);
},
Expand All @@ -366,7 +366,7 @@ impl DhtActor {
GetMsgHashHitCount(hash, reply_tx) => {
let msg_hash_cache = self.msg_hash_dedup_cache.clone();
Box::pin(async move {
let hit_count = msg_hash_cache.get_hit_count(hash)?;
let hit_count = msg_hash_cache.get_hit_count(&hash)?;
let _ = reply_tx.send(hit_count);
Ok(())
})
Expand Down Expand Up @@ -1043,15 +1043,15 @@ mod test {
for key in &signatures {
let num_hits = actor
.msg_hash_dedup_cache
.add_body_hash(key.clone(), &CommsPublicKey::default())
.add_msg_hash(key, &CommsPublicKey::default())
.unwrap();
assert_eq!(num_hits, 1);
}
// Try to re-insert all; all hashes should have incremented their hit count
for key in &signatures {
let num_hits = actor
.msg_hash_dedup_cache
.add_body_hash(key.clone(), &CommsPublicKey::default())
.add_msg_hash(key, &CommsPublicKey::default())
.unwrap();
assert_eq!(num_hits, 2);
}
Expand Down
13 changes: 6 additions & 7 deletions comms/dht/src/dedup/dedup_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ use chrono::{NaiveDateTime, Utc};
use diesel::{dsl, result::DatabaseErrorKind, sql_types, ExpressionMethods, OptionalExtension, QueryDsl, RunQueryDsl};
use log::*;
use tari_comms::types::CommsPublicKey;
use tari_crypto::tari_utilities::hex::Hex;
use tari_crypto::tari_utilities::hex::to_hex;
use tari_utilities::hex::Hex;

use crate::{
schema::dedup_cache,
Expand Down Expand Up @@ -59,9 +60,8 @@ impl DedupCacheDatabase {

/// Adds the body hash to the cache, returning the number of hits (inclusive) that have been recorded for this body
/// hash
#[allow(clippy::needless_pass_by_value)]
pub fn add_body_hash(&self, body_hash: Vec<u8>, public_key: &CommsPublicKey) -> Result<u32, StorageError> {
let hit_count = self.insert_body_hash_or_update_stats(&body_hash.to_hex(), &public_key.to_hex())?;
pub fn add_msg_hash(&self, msg_hash: &[u8], public_key: &CommsPublicKey) -> Result<u32, StorageError> {
let hit_count = self.insert_body_hash_or_update_stats(&to_hex(msg_hash), &public_key.to_hex())?;

if hit_count == 0 {
warn!(
Expand All @@ -72,12 +72,11 @@ impl DedupCacheDatabase {
Ok(hit_count)
}

#[allow(clippy::needless_pass_by_value)]
pub fn get_hit_count(&self, body_hash: Vec<u8>) -> Result<u32, StorageError> {
pub fn get_hit_count(&self, body_hash: &[u8]) -> Result<u32, StorageError> {
let conn = self.connection.get_pooled_connection()?;
let hit_count = dedup_cache::table
.select(dedup_cache::number_of_hits)
.filter(dedup_cache::body_hash.eq(&body_hash.to_hex()))
.filter(dedup_cache::body_hash.eq(&to_hex(body_hash)))
.get_result::<i32>(&conn)
.optional()?;

Expand Down
16 changes: 14 additions & 2 deletions comms/dht/src/dedup/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,28 @@ mod dedup_cache;
use std::task::Poll;

pub use dedup_cache::DedupCacheDatabase;
use digest::Digest;
use futures::{future::BoxFuture, task::Context};
use log::*;
use tari_comms::pipeline::PipelineError;
use tari_comms::{pipeline::PipelineError, types::Challenge};
use tari_utilities::hex::Hex;
use tower::{layer::Layer, Service, ServiceExt};

use crate::{actor::DhtRequester, inbound::DecryptedDhtMessage};
use crate::{
actor::DhtRequester,
inbound::{DecryptedDhtMessage, DhtInboundMessage},
};

const LOG_TARGET: &str = "comms::dht::dedup";

pub fn hash_inbound_message(msg: &DhtInboundMessage) -> [u8; 32] {
create_message_hash(&msg.dht_header.origin_mac, &msg.body)
}

pub fn create_message_hash(origin_mac: &[u8], body: &[u8]) -> [u8; 32] {
Challenge::new().chain(origin_mac).chain(&body).finalize().into()
}

/// # DHT Deduplication middleware
///
/// Takes in a `DecryptedDhtMessage` and checks the message signature cache for duplicates.
Expand Down
21 changes: 8 additions & 13 deletions comms/dht/src/inbound/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,22 +26,17 @@ use std::{
sync::Arc,
};

use digest::Digest;
use tari_comms::{
message::{EnvelopeBody, MessageTag},
peer_manager::Peer,
types::{Challenge, CommsPublicKey},
types::CommsPublicKey,
};
use tari_utilities::ByteArray;

use crate::envelope::{DhtMessageFlags, DhtMessageHeader};

fn hash_inbound_message(message: &DhtInboundMessage) -> Vec<u8> {
Challenge::new()
.chain(&message.dht_header.origin_mac)
.chain(&message.body)
.finalize()
.to_vec()
}
use crate::{
dedup,
envelope::{DhtMessageFlags, DhtMessageHeader},
};

#[derive(Debug, Clone)]
pub struct DhtInboundMessage {
Expand Down Expand Up @@ -116,7 +111,7 @@ impl DecryptedDhtMessage {
message: DhtInboundMessage,
) -> Self {
Self {
dedup_hash: hash_inbound_message(&message),
dedup_hash: dedup::hash_inbound_message(&message).to_vec(),
tag: message.tag,
source_peer: message.source_peer,
authenticated_origin,
Expand All @@ -131,7 +126,7 @@ impl DecryptedDhtMessage {

pub fn failed(message: DhtInboundMessage) -> Self {
Self {
dedup_hash: hash_inbound_message(&message),
dedup_hash: dedup::hash_inbound_message(&message).to_vec(),
tag: message.tag,
source_peer: message.source_peer,
authenticated_origin: None,
Expand Down
13 changes: 6 additions & 7 deletions comms/dht/src/outbound/broadcast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ use std::{sync::Arc, task::Poll};

use bytes::Bytes;
use chrono::{DateTime, Utc};
use digest::Digest;
use futures::{
future,
future::BoxFuture,
Expand Down Expand Up @@ -53,6 +52,7 @@ use crate::{
actor::DhtRequester,
broadcast_strategy::BroadcastStrategy,
crypt,
dedup,
discovery::DhtDiscoveryRequester,
envelope::{datetime_to_epochtime, datetime_to_timestamp, DhtMessageFlags, DhtMessageHeader, NodeDestination},
outbound::{
Expand Down Expand Up @@ -429,8 +429,8 @@ where S: Service<DhtOutboundMessage, Response = (), Error = PipelineError>
)?;

if is_broadcast {
self.add_to_dedup_cache(&body, self.node_identity.public_key().clone())
.await?;
let hash = dedup::create_message_hash(origin_mac.as_deref().unwrap_or(&[]), &body);
self.add_to_dedup_cache(hash).await?;
}

// Construct a DhtOutboundMessage for each recipient
Expand Down Expand Up @@ -461,8 +461,7 @@ where S: Service<DhtOutboundMessage, Response = (), Error = PipelineError>
Ok(messages.unzip())
}

async fn add_to_dedup_cache(&mut self, body: &[u8], public_key: CommsPublicKey) -> Result<(), DhtOutboundError> {
let hash = Challenge::new().chain(&body).finalize().to_vec();
async fn add_to_dedup_cache(&mut self, hash: [u8; 32]) -> Result<(), DhtOutboundError> {
trace!(
target: LOG_TARGET,
"Dedup added message hash {} to cache for message",
Expand All @@ -472,12 +471,12 @@ where S: Service<DhtOutboundMessage, Response = (), Error = PipelineError>
// Do not count messages we've broadcast towards the total hit count
let hit_count = self
.dht_requester
.get_message_cache_hit_count(hash.clone())
.get_message_cache_hit_count(hash.to_vec())
.await
.map_err(|err| DhtOutboundError::FailedToInsertMessageHash(err.to_string()))?;
if hit_count == 0 {
self.dht_requester
.add_message_to_dedup_cache(hash, public_key)
.add_message_to_dedup_cache(hash.to_vec(), self.node_identity.public_key().clone())
.await
.map_err(|err| DhtOutboundError::FailedToInsertMessageHash(err.to_string()))?;
}
Expand Down
7 changes: 4 additions & 3 deletions comms/dht/src/store_forward/database/stored_message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@
use std::convert::TryInto;

use chrono::NaiveDateTime;
use digest::Digest;
use tari_comms::{message::MessageExt, types::Challenge};
use tari_comms::message::MessageExt;
use tari_utilities::{hex, hex::Hex};

use crate::{
dedup,
inbound::DecryptedDhtMessage,
proto::envelope::DhtHeader,
schema::stored_messages,
Expand Down Expand Up @@ -62,6 +62,7 @@ impl NewStoredMessage {
Ok(envelope_body) => envelope_body.to_encoded_bytes(),
Err(encrypted_body) => encrypted_body,
};
let body_hash = hex::to_hex(&dedup::create_message_hash(&dht_header.origin_mac, &body));

Some(Self {
version: dht_header.version.as_major().try_into().ok()?,
Expand All @@ -75,7 +76,7 @@ impl NewStoredMessage {
let dht_header: DhtHeader = dht_header.into();
dht_header.to_encoded_bytes()
},
body_hash: hex::to_hex(&Challenge::new().chain(body.clone()).finalize()),
body_hash,
body,
})
}
Expand Down
28 changes: 21 additions & 7 deletions comms/dht/src/store_forward/saf_handler/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
use std::{convert::TryInto, sync::Arc};

use chrono::{DateTime, NaiveDateTime, Utc};
use digest::Digest;
use futures::{future, stream, StreamExt};
use log::*;
use prost::Message;
Expand All @@ -41,6 +40,7 @@ use tower::{Service, ServiceExt};
use crate::{
actor::DhtRequester,
crypt,
dedup,
envelope::{timestamp_to_datetime, DhtMessageFlags, DhtMessageHeader, NodeDestination},
inbound::{DecryptedDhtMessage, DhtInboundMessage},
outbound::{OutboundMessageRequester, SendMessageParams},
Expand Down Expand Up @@ -445,6 +445,15 @@ where S: Service<DecryptedDhtMessage, Response = (), Error = PipelineError>
return Err(StoreAndForwardError::StoredAtWasInFuture);
}

let msg_hash = dedup::create_message_hash(
message
.dht_header
.as_ref()
.map(|h| h.origin_mac.as_slice())
.unwrap_or(&[]),
&message.body,
);

let dht_header: DhtMessageHeader = message
.dht_header
.expect("previously checked")
Expand Down Expand Up @@ -478,13 +487,19 @@ where S: Service<DecryptedDhtMessage, Response = (), Error = PipelineError>

// Check that the destination is either undisclosed, for us or for our network region
Self::check_destination(config, peer_manager, node_identity, &dht_header).await?;
// Check that the message has not already been received.
Self::check_duplicate(&mut self.dht_requester, &message.body, source_peer.public_key.clone()).await?;

// Attempt to decrypt the message (if applicable), and deserialize it
let (authenticated_pk, decrypted_body) =
Self::authenticate_and_decrypt_if_required(node_identity, &dht_header, &message.body)?;

// Check that the message has not already been received.
Self::check_duplicate(
&mut self.dht_requester,
msg_hash.to_vec(),
source_peer.public_key.clone(),
)
.await?;

let mut inbound_msg =
DhtInboundMessage::new(MessageTag::new(), dht_header, Arc::clone(&source_peer), message.body);
inbound_msg.is_saf_message = true;
Expand All @@ -497,10 +512,9 @@ where S: Service<DecryptedDhtMessage, Response = (), Error = PipelineError>

async fn check_duplicate(
dht_requester: &mut DhtRequester,
body: &[u8],
msg_hash: Vec<u8>,
public_key: CommsPublicKey,
) -> Result<(), StoreAndForwardError> {
let msg_hash = Challenge::new().chain(body).finalize().to_vec();
let hit_count = dht_requester.add_message_to_dedup_cache(msg_hash, public_key).await?;
if hit_count > 1 {
Err(StoreAndForwardError::DuplicateMessage)
Expand Down Expand Up @@ -642,8 +656,8 @@ mod test {
dht_header: DhtMessageHeader,
stored_at: NaiveDateTime,
) -> StoredMessage {
let msg_hash = hex::to_hex(&dedup::create_message_hash(&dht_header.origin_mac, message.as_bytes()));
let body = message.into_bytes();
let body_hash = hex::to_hex(&Challenge::new().chain(&body).finalize());
StoredMessage {
id: 1,
version: 0,
Expand All @@ -656,7 +670,7 @@ mod test {
is_encrypted: false,
priority: StoredMessagePriority::High as i32,
stored_at,
body_hash,
body_hash: msg_hash,
}
}

Expand Down
5 changes: 1 addition & 4 deletions comms/dht/src/test_utils/store_and_forward_mock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,8 @@ use std::{
};

use chrono::Utc;
use digest::Digest;
use log::*;
use rand::{rngs::OsRng, RngCore};
use tari_comms::types::Challenge;
use tari_utilities::hex;
use tokio::{
runtime,
sync::{mpsc, RwLock},
Expand Down Expand Up @@ -150,7 +147,7 @@ impl StoreAndForwardMock {
is_encrypted: msg.is_encrypted,
priority: msg.priority,
stored_at: Utc::now().naive_utc(),
body_hash: hex::to_hex(&Challenge::new().chain(msg.body).finalize()),
body_hash: msg.body_hash,
});
reply_tx.send(Ok(false)).unwrap();
},
Expand Down

0 comments on commit e1e7669

Please sign in to comment.