diff --git a/comms/dht/src/error.rs b/comms/dht/src/error.rs index a0d8718cee..a0e2090030 100644 --- a/comms/dht/src/error.rs +++ b/comms/dht/src/error.rs @@ -22,7 +22,7 @@ use thiserror::Error; -#[derive(Debug, Error)] +#[derive(Debug, Error, PartialEq)] pub enum DhtEncryptError { #[error("Message body invalid")] InvalidMessageBody, diff --git a/comms/dht/src/inbound/decryption.rs b/comms/dht/src/inbound/decryption.rs index ef6d10323c..9800fc1c0b 100644 --- a/comms/dht/src/inbound/decryption.rs +++ b/comms/dht/src/inbound/decryption.rs @@ -47,7 +47,7 @@ use crate::{ const LOG_TARGET: &str = "comms::middleware::decryption"; -#[derive(Error, Debug)] +#[derive(Error, Debug, PartialEq)] enum DecryptionError { #[error("Failed to validate ENCRYPTED message signature")] MessageSignatureInvalidEncryptedSignature, @@ -75,8 +75,8 @@ enum DecryptionError { EncryptedMessageNoDestination, #[error("Decryption failed: {0}")] DecryptionFailedMalformedCipher(#[from] DhtEncryptError), - #[error("Encrypted message must have a non-empty body")] - EncryptedMessageEmptyBody, + #[error("Message must have a non-empty body")] + MessageEmptyBody, } /// This layer is responsible for attempting to decrypt inbound messages. @@ -193,6 +193,7 @@ where S: Service Err(err @ MessageSignatureClearTextDecodeFailed) | Err(err @ MessageSignatureInvalidClearTextSignature) | Err(err @ EncryptedMessageNoDestination) | + Err(err @ MessageEmptyBody) | Err(err @ MessageSignatureErrorClearText(_)) => { warn!( target: LOG_TARGET, @@ -354,7 +355,7 @@ where S: Service fn initial_validation(message: DhtInboundMessage) -> Result { // Messages must not be empty if message.body.is_empty() { - return Err(DecryptionError::EncryptedMessageEmptyBody); + return Err(DecryptionError::MessageEmptyBody); } if message.dht_header.flags.is_encrypted() { @@ -509,7 +510,7 @@ mod test { wrap_in_envelope_body, BytesMut, }; - use tari_test_utils::{counter_context, unpack_enum}; + use tari_test_utils::counter_context; use tokio::time::sleep; use tower::service_fn; @@ -526,6 +527,43 @@ mod test { }, }; + /// Receive a message, assert a specific error is raised, and ban the peer if needed + async fn expect_error( + node_identity: Arc, + message: DhtInboundMessage, + error: DecryptionError, + ban: bool, + ) { + // Set up messaging + let (connectivity, mock) = create_connectivity_mock(); + let mock_state = mock.spawn(); + let result = Arc::new(Mutex::new(None)); + let service = service_fn({ + let result = result.clone(); + move |msg: DecryptedDhtMessage| { + *result.lock().unwrap() = Some(msg); + future::ready(Result::<(), PipelineError>::Ok(())) + } + }); + let mut service = DecryptionService::new(Default::default(), node_identity, connectivity, service); + + // Receive the message and check for the expected error + let err = service.call(message).await.unwrap_err(); + let err = err.downcast::().unwrap(); + assert_eq!(error, err); + assert!(result.lock().unwrap().is_none()); + + // Assert the expected ban status + if ban { + mock_state.await_call_count(1).await; + assert_eq!(mock_state.count_calls_containing("BanPeer").await, 1); + } else { + // Waiting like this isn't a guarantee that the peer won't be banned + sleep(Duration::from_secs(1)).await; + assert_eq!(mock_state.count_calls_containing("BanPeer").await, 0); + } + } + #[test] fn poll_ready() { let service = service_fn(|_: DecryptedDhtMessage| future::ready(Result::<(), PipelineError>::Ok(()))); @@ -540,8 +578,11 @@ mod test { assert_eq!(counter.get(), 0); } - #[test] - fn decrypt_inbound_success() { + #[runtime::test] + /// We can decrypt valid encrypted messages destined for us + async fn decrypt_inbound_success() { + let (connectivity, mock) = create_connectivity_mock(); + let mock_state = mock.spawn(); let result = Arc::new(Mutex::new(None)); let service = service_fn({ let result = result.clone(); @@ -551,21 +592,30 @@ mod test { } }); let node_identity = make_node_identity(); - let (connectivity, _) = create_connectivity_mock(); let mut service = DecryptionService::new(Default::default(), node_identity.clone(), connectivity, service); + // Encrypt a message for us let plain_text_msg = wrap_in_envelope_body!(b"Secret plans".to_vec()); let inbound_msg = make_dht_inbound_message(&node_identity, &plain_text_msg, DhtMessageFlags::ENCRYPTED, true, true).unwrap(); + // Check that decryption yields the original message block_on(service.call(inbound_msg)).unwrap(); let decrypted = result.lock().unwrap().take().unwrap(); assert!(decrypted.decryption_succeeded()); assert_eq!(decrypted.decryption_result.unwrap(), plain_text_msg); + + // Don't ban the peer + // Waiting like this isn't a guarantee that the peer won't be banned + sleep(Duration::from_secs(1)).await; + assert_eq!(mock_state.count_calls_containing("BanPeer").await, 0); } - #[test] - fn decrypt_inbound_fail() { + #[runtime::test] + /// An encrypted message is not destined for us + async fn decrypt_inbound_not_for_us() { + let (connectivity, mock) = create_connectivity_mock(); + let mock_state = mock.spawn(); let result = Arc::new(Mutex::new(None)); let service = service_fn({ let result = result.clone(); @@ -575,9 +625,9 @@ mod test { } }); let node_identity = make_node_identity(); - let (connectivity, _) = create_connectivity_mock(); let mut service = DecryptionService::new(Default::default(), node_identity, connectivity, service); + // Encrypt a message for someone else let some_secret = b"Super secret message".to_vec(); let some_other_node_identity = make_node_identity(); let inbound_msg = make_dht_inbound_message( @@ -589,205 +639,175 @@ mod test { ) .unwrap(); + // Decryption fails, but it's not an error block_on(service.call(inbound_msg.clone())).unwrap(); let decrypted = result.lock().unwrap().take().unwrap(); - assert!(!decrypted.decryption_succeeded()); assert_eq!(decrypted.decryption_result.unwrap_err(), inbound_msg.body); + + // Don't ban the peer + // Waiting like this isn't a guarantee that the peer won't be banned + sleep(Duration::from_secs(1)).await; + assert_eq!(mock_state.count_calls_containing("BanPeer").await, 0); } - #[test] - fn decrypt_inbound_fail_empty_contents() { - let service = service_fn( - move |_msg: DecryptedDhtMessage| -> future::Ready> { - panic!("Should not be called") - }, - ); + #[runtime::test] + /// A message is empty + async fn empty_message() { let node_identity = make_node_identity(); - let (connectivity, _) = create_connectivity_mock(); - let mut service = DecryptionService::new(Default::default(), node_identity, connectivity, service); + let other_identity = make_node_identity(); - let some_other_node_identity = make_node_identity(); - let mut inbound_msg = make_dht_inbound_message( - &some_other_node_identity, - &Vec::new(), - DhtMessageFlags::ENCRYPTED, - true, - true, - ) - .unwrap(); - inbound_msg.body = Vec::new(); + // Encrypt an empty message + for identity in [&node_identity, &other_identity] { + for encrypted_flag in [DhtMessageFlags::NONE, DhtMessageFlags::ENCRYPTED] { + let mut message = make_dht_inbound_message(identity, &Vec::new(), encrypted_flag, true, true).unwrap(); + message.body = Vec::new(); - let err = block_on(service.call(inbound_msg)).unwrap_err(); - let err = err.downcast::().unwrap(); - unpack_enum!(DecryptionError::EncryptedMessageEmptyBody = err); + // Ban the peer + expect_error(node_identity.clone(), message, DecryptionError::MessageEmptyBody, true).await; + } + } } #[runtime::test] - async fn decrypt_inbound_fail_destination() { - let (connectivity, mock) = create_connectivity_mock(); - mock.spawn(); - let result = Arc::new(Mutex::new(None)); - let service = service_fn({ - let result = result.clone(); - move |msg: DecryptedDhtMessage| { - *result.lock().unwrap() = Some(msg); - future::ready(Result::<(), PipelineError>::Ok(())) - } - }); + /// An encrypted message is destined for us but can't be decrypted + async fn decrypt_inbound_fail_for_us() { let node_identity = make_node_identity(); - let mut service = DecryptionService::new(Default::default(), node_identity.clone(), connectivity, service); + // Encrypt an invalid message destined for us let nonsense = b"Cannot Decrypt this".to_vec(); - let inbound_msg = + let message = make_dht_inbound_message_raw(&node_identity, nonsense, DhtMessageFlags::ENCRYPTED, true, true).unwrap(); - let err = service.call(inbound_msg).await.unwrap_err(); - let err = err.downcast::().unwrap(); - unpack_enum!(DecryptionError::MessageRejectDecryptionFailed = err); - assert!(result.lock().unwrap().is_none()); + // Don't ban the peer + expect_error( + node_identity, + message, + DecryptionError::MessageRejectDecryptionFailed, + false, + ) + .await; } #[runtime::test] + /// An encrypted message has no destination async fn decrypt_inbound_fail_no_destination() { - let (connectivity, mock) = create_connectivity_mock(); - mock.spawn(); - let result = Arc::new(Mutex::new(None)); - let service = service_fn({ - let result = result.clone(); - move |msg: DecryptedDhtMessage| { - *result.lock().unwrap() = Some(msg); - future::ready(Result::<(), PipelineError>::Ok(())) - } - }); let node_identity = make_node_identity(); - let mut service = DecryptionService::new(Default::default(), node_identity.clone(), connectivity, service); + // Encrypt a message with no destination let plain_text_msg = b"Secret message to nowhere".to_vec(); - let inbound_msg = + let message = make_dht_inbound_message(&node_identity, &plain_text_msg, DhtMessageFlags::ENCRYPTED, true, false).unwrap(); - let err = service.call(inbound_msg).await.unwrap_err(); - let err = err.downcast::().unwrap(); - unpack_enum!(DecryptionError::EncryptedMessageNoDestination = err); - assert!(result.lock().unwrap().is_none()); + // Ban the peer + expect_error( + node_identity, + message, + DecryptionError::EncryptedMessageNoDestination, + true, + ) + .await; } #[runtime::test] + /// An encrypted message destined for us has an invalid signature async fn decrypt_inbound_fail_invalid_signature_encrypted() { - let (connectivity, mock) = create_connectivity_mock(); - let mock_state = mock.spawn(); - let result = Arc::new(Mutex::new(None)); - let service = service_fn({ - let result = result.clone(); - move |msg: DecryptedDhtMessage| { - *result.lock().unwrap() = Some(msg); - future::ready(Result::<(), PipelineError>::Ok(())) - } - }); let node_identity = make_node_identity(); - let mut service = DecryptionService::new(Default::default(), node_identity.clone(), connectivity, service); + // Encrypt a message destined for us let plain_text_msg = BytesMut::from(b"Secret message".as_slice()); let (e_secret_key, e_public_key) = make_keypair(); let shared_secret = CommsDHKE::new(&e_secret_key, node_identity.public_key()); let key_message = crypt::generate_key_message(&shared_secret); let msg_tag = MessageTag::new(); - let mut message = plain_text_msg.clone(); - crypt::encrypt_message(&key_message, &mut message).unwrap(); - let message = message.freeze(); + let mut message_bytes = plain_text_msg.clone(); + crypt::encrypt_message(&key_message, &mut message_bytes).unwrap(); + let message_bytes = message_bytes.freeze(); let header = make_dht_header( &node_identity, &e_public_key, &e_secret_key, - &message, + &message_bytes, DhtMessageFlags::ENCRYPTED, true, msg_tag, true, ) .unwrap(); - let envelope = DhtEnvelope::new(header.into(), message.into()); + let envelope = DhtEnvelope::new(header.into(), message_bytes.into()); let msg_tag = MessageTag::new(); - let mut inbound_msg = DhtInboundMessage::new( + let mut message = DhtInboundMessage::new( msg_tag, envelope.header.unwrap().try_into().unwrap(), Arc::new(node_identity.to_peer()), envelope.body, ); - // Sign invalid data. Other peers cannot validate this while propagating, but this should not cause them to be - // banned. + // Manipulate the signature; we can decrypt it, but it's not valid for this message let signature = make_valid_message_signature(&node_identity, b"sign invalid data"); let key_signature = crypt::generate_key_signature(&shared_secret); + message.dht_header.message_signature = crypt::encrypt_signature(&key_signature, &signature).unwrap(); - inbound_msg.dht_header.message_signature = crypt::encrypt_signature(&key_signature, &signature).unwrap(); - - let err = service.call(inbound_msg).await.unwrap_err(); - let err = err.downcast::().unwrap(); - unpack_enum!(DecryptionError::MessageSignatureInvalidEncryptedSignature = err); - assert!(result.lock().unwrap().is_none()); - - // Proving a negative i.e. ban is not called, we have no choice but to sleep to wait for any potential calls to - // be registered. This should ensure that if this bug re-occurs that this test is flaky. - sleep(Duration::from_secs(1)).await; - assert_eq!(mock_state.count_calls_containing("BanPeer").await, 0); + // Don't ban the peer + expect_error( + node_identity, + message, + DecryptionError::MessageSignatureInvalidEncryptedSignature, + false, + ) + .await; } #[runtime::test] + /// An unencrypted message has an invalid signature async fn decrypt_inbound_fail_invalid_signature_cleartext() { - let (connectivity, mock) = create_connectivity_mock(); - let mock_state = mock.spawn(); - let result = Arc::new(Mutex::new(None)); - let service = service_fn({ - let result = result.clone(); - move |msg: DecryptedDhtMessage| { - *result.lock().unwrap() = Some(msg); - future::ready(Result::<(), PipelineError>::Ok(())) - } - }); let node_identity = make_node_identity(); - let mut service = DecryptionService::new(Default::default(), node_identity.clone(), connectivity, service); + let other_identity = make_node_identity(); + let plain_text_msg = b"a message".to_vec(); - let plain_text_msg = BytesMut::from(b"Public message".as_slice()); - let (e_secret_key, e_public_key) = make_keypair(); - let shared_secret = CommsDHKE::new(&e_secret_key, node_identity.public_key()); - let key_message = crypt::generate_key_message(&shared_secret); - let msg_tag = MessageTag::new(); + // Handle the cases where we are and aren't the recipient + for identity in [&node_identity, &other_identity] { + let mut message = + make_dht_inbound_message(identity, &plain_text_msg, DhtMessageFlags::NONE, true, true).unwrap(); - let mut message = plain_text_msg.clone(); - crypt::encrypt_message(&key_message, &mut message).unwrap(); - let message = message.freeze(); - let header = make_dht_header( - &node_identity, - &e_public_key, - &e_secret_key, - &message, - DhtMessageFlags::NONE, - true, - msg_tag, - true, - ) - .unwrap(); - let envelope = DhtEnvelope::new(header.into(), message.into()); - let msg_tag = MessageTag::new(); - let mut inbound_msg = DhtInboundMessage::new( - msg_tag, - envelope.header.unwrap().try_into().unwrap(), - Arc::new(node_identity.to_peer()), - envelope.body, - ); + // Manipulate the signature so it's invalid + message.dht_header.message_signature = make_valid_message_signature(identity, b"a different message"); + + // Ban the peer + expect_error( + node_identity.clone(), + message, + DecryptionError::MessageSignatureInvalidClearTextSignature, + true, + ) + .await; + } + } - inbound_msg.dht_header.ephemeral_public_key = Some(e_public_key.clone()); - inbound_msg.dht_header.message_signature = make_valid_message_signature(&node_identity, b"sign invalid data"); + #[runtime::test] + /// An encrypted message has no signature + async fn decrypt_inbound_fail_missing_signature_encrypted() { + let node_identity = make_node_identity(); + let other_identity = make_node_identity(); + let plain_text_msg = b"a secret message".to_vec(); - let err = service.call(inbound_msg).await.unwrap_err(); - let err = err.downcast::().unwrap(); - unpack_enum!(DecryptionError::MessageSignatureInvalidClearTextSignature = err); - assert!(result.lock().unwrap().is_none()); + // Handle the cases where we are and aren't the recipient + for identity in [&node_identity, &other_identity] { + let mut message = + make_dht_inbound_message(identity, &plain_text_msg, DhtMessageFlags::ENCRYPTED, true, true).unwrap(); + + // Remove the signature + message.dht_header.message_signature = Vec::new(); - mock_state.await_call_count(1).await; - assert_eq!(mock_state.count_calls_containing("BanPeer").await, 1); + // Ban the peer + expect_error( + node_identity.clone(), + message, + DecryptionError::MessageSignatureNotProvidedForEncryptedMessage, + true, + ) + .await; + } } } diff --git a/comms/dht/src/message_signature.rs b/comms/dht/src/message_signature.rs index 7d22e8a0e4..975b5d9208 100644 --- a/comms/dht/src/message_signature.rs +++ b/comms/dht/src/message_signature.rs @@ -121,7 +121,7 @@ pub struct ProtoMessageSignature { pub signature: Vec, } -#[derive(Debug, thiserror::Error)] +#[derive(Debug, thiserror::Error, PartialEq)] pub enum MessageSignatureError { #[error("Failed to validate message signature")] InvalidSignatureBytes,