Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(dht): updates to message padding #4594

Merged
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 82 additions & 47 deletions comms/dht/src/crypt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ use crate::{
pub struct CipherKey(chacha20::Key);
pub struct AuthenticatedCipherKey(chacha20poly1305::Key);

const LITTLE_ENDIAN_U32_SIZE_REPRESENTATION: usize = 4;
const MESSAGE_BASE_LENGTH: usize = 6000;

/// Generates a Diffie-Hellman secret `kx.G` as a `chacha20::Key` given secret scalar `k` and public key `P = x.G`.
Expand All @@ -70,45 +69,57 @@ pub fn generate_ecdh_secret(secret_key: &CommsSecretKey, public_key: &CommsPubli
output
}

fn pad_message_to_base_length_multiple(message: &[u8]) -> Vec<u8> {
let n = message.len();
// little endian representation of message length, to be appended to padded message,
// assuming our code runs on 64-bits system
let prepend_to_message = (n as u32).to_le_bytes();

let k = prepend_to_message.len();

let div_n_base_len = (n + k) / MESSAGE_BASE_LENGTH;
let output_size = (div_n_base_len + 1) * MESSAGE_BASE_LENGTH;

// join prepend_message_len | message | zero_padding
let mut output = Vec::with_capacity(output_size);
output.extend_from_slice(&prepend_to_message);
output.extend_from_slice(message);
output.extend(std::iter::repeat(0u8).take(output_size - n - k));

output
fn pad_message_to_base_length_multiple(message: &[u8]) -> Result<Vec<u8>, DhtOutboundError> {
// We require a 32-bit length representation
if message.len() > (u32::max_value() as usize) {
AaronFeickert marked this conversation as resolved.
Show resolved Hide resolved
return Err(DhtOutboundError::CipherError("Message is too long".to_string()));
AaronFeickert marked this conversation as resolved.
Show resolved Hide resolved
}
let message_length = message.len();
let encoded_length = (message_length as u32).to_le_bytes();

// Pad the message (if needed) to the next multiple of the base length
let padding_length = if ((message_length + size_of::<u32>()) % MESSAGE_BASE_LENGTH) == 0 {
0
} else {
MESSAGE_BASE_LENGTH - ((message_length + size_of::<u32>()) % MESSAGE_BASE_LENGTH)
};

// The padded message is the encoded length, message, and zero padding
let mut padded_message = Vec::with_capacity(size_of::<u32>() + message_length + padding_length);
padded_message.extend_from_slice(&encoded_length);
padded_message.extend_from_slice(message);
padded_message.extend(std::iter::repeat(0u8).take(padding_length));

Ok(padded_message)
}

fn get_original_message_from_padded_text(message: &[u8]) -> Result<Vec<u8>, DhtOutboundError> {
let mut le_bytes = [0u8; 4];
le_bytes.copy_from_slice(&message[..LITTLE_ENDIAN_U32_SIZE_REPRESENTATION]);
// NOTE: This function can return errors relating to message length
// It is important not to leak error types to an adversary, or to have timing differences

// Assert that the padded message is a multiple of the base length
if message.is_empty() || (message.len() % MESSAGE_BASE_LENGTH) != 0 {
return Err(DhtOutboundError::CipherError("Bad padded message length".to_string()));
AaronFeickert marked this conversation as resolved.
Show resolved Hide resolved
}

// obtain length of original message, assuming our code runs on 64-bits system
let original_message_len = u32::from_le_bytes(le_bytes) as usize;
// Decode the message length
let mut encoded_length = [0u8; size_of::<u32>()];
encoded_length.copy_from_slice(&message[0..size_of::<u32>()]);
let message_length = u32::from_le_bytes(encoded_length) as usize;

if original_message_len > message.len() {
// The message is too short for the decoded length
if message_length + size_of::<u32>() > message.len() {
sdbondi marked this conversation as resolved.
Show resolved Hide resolved
return Err(DhtOutboundError::CipherError(
"Original length message is invalid".to_string(),
"Message is too short to be unpadded".to_string(),
));
}

// obtain original message
let start = LITTLE_ENDIAN_U32_SIZE_REPRESENTATION;
let end = LITTLE_ENDIAN_U32_SIZE_REPRESENTATION + original_message_len;
let original_message = &message[start..end];
// Remove the padding
let start = size_of::<u32>();
let end = start + message_length;
let unpadded_message = &message[start..end];
sdbondi marked this conversation as resolved.
Show resolved Hide resolved

Ok(original_message.to_vec())
Ok(unpadded_message.to_vec())
}

pub fn generate_key_message(data: &[u8]) -> CipherKey {
Expand Down Expand Up @@ -164,9 +175,9 @@ pub fn decrypt_with_chacha20_poly1305(
}

/// Encrypt the plain text using the ChaCha20 stream cipher
pub fn encrypt(cipher_key: &CipherKey, plain_text: &[u8]) -> Vec<u8> {
pub fn encrypt(cipher_key: &CipherKey, plain_text: &[u8]) -> Result<Vec<u8>, DhtOutboundError> {
// pad plain_text to avoid message length leaks
let plain_text = pad_message_to_base_length_multiple(plain_text);
let plain_text = pad_message_to_base_length_multiple(plain_text)?;

let mut nonce = [0u8; size_of::<Nonce>()];
OsRng.fill_bytes(&mut nonce);
Expand All @@ -179,7 +190,7 @@ pub fn encrypt(cipher_key: &CipherKey, plain_text: &[u8]) -> Vec<u8> {

buf[nonce.len()..].copy_from_slice(plain_text.as_slice());
cipher.apply_keystream(&mut buf[nonce.len()..]);
buf
Ok(buf)
}

/// Produces authenticated encryption of the signature using the ChaCha20-Poly1305 stream cipher,
Expand Down Expand Up @@ -266,7 +277,7 @@ mod test {
let pk = CommsPublicKey::default();
let key = CipherKey(*chacha20::Key::from_slice(pk.as_bytes()));
let plain_text = "Last enemy position 0830h AJ 9863".as_bytes().to_vec();
let encrypted = encrypt(&key, &plain_text);
let encrypted = encrypt(&key, &plain_text).unwrap();
let decrypted = decrypt(&key, &encrypted).unwrap();
assert_eq!(decrypted, plain_text);
}
Expand Down Expand Up @@ -385,7 +396,7 @@ mod test {
.take(MESSAGE_BASE_LENGTH - message.len() - prepend_message.len())
.collect::<Vec<_>>();

let pad_message = pad_message_to_base_length_multiple(message);
let pad_message = pad_message_to_base_length_multiple(message).unwrap();

// padded message is of correct length
assert_eq!(pad_message.len(), MESSAGE_BASE_LENGTH);
Expand All @@ -402,7 +413,7 @@ mod test {
// test for large message
let message = &[100u8; MESSAGE_BASE_LENGTH * 8 - 100];
let prepend_message = (message.len() as u32).to_le_bytes();
let pad_message = pad_message_to_base_length_multiple(message);
let pad_message = pad_message_to_base_length_multiple(message).unwrap();
let pad = std::iter::repeat(0u8)
.take((8 * MESSAGE_BASE_LENGTH) - message.len() - prepend_message.len())
.collect::<Vec<_>>();
Expand All @@ -426,7 +437,7 @@ mod test {
.take((9 * MESSAGE_BASE_LENGTH) - message.len() - prepend_message.len())
.collect::<Vec<_>>();

let pad_message = pad_message_to_base_length_multiple(message);
let pad_message = pad_message_to_base_length_multiple(message).unwrap();

// padded message is of correct length
assert_eq!(pad_message.len(), 9 * MESSAGE_BASE_LENGTH);
Expand All @@ -443,7 +454,7 @@ mod test {
// test for empty message
let message: [u8; 0] = [];
let prepend_message = (message.len() as u32).to_le_bytes();
let pad_message = pad_message_to_base_length_multiple(&message);
let pad_message = pad_message_to_base_length_multiple(&message).unwrap();
let pad = [0u8; MESSAGE_BASE_LENGTH - 4];

// padded message is of correct length
Expand All @@ -460,32 +471,56 @@ mod test {
assert_eq!(pad, pad_message[prepend_message.len() + message.len()..]);
}

#[test]
fn unpadding_failure_modes() {
// The padded message is empty
let message: [u8; 0] = [];
assert!(get_original_message_from_padded_text(&message)
.unwrap_err()
.to_string()
.contains("Bad padded message length"));

// We cannot extract the message length
let message = [0u8; size_of::<u32>() - 1];
assert!(get_original_message_from_padded_text(&message)
.unwrap_err()
.to_string()
.contains("Bad padded message length"));

// The padded message is not a multiple of the base length
let message = [0u8; 2 * MESSAGE_BASE_LENGTH + 1];
assert!(get_original_message_from_padded_text(&message)
.unwrap_err()
.to_string()
.contains("Bad padded message length"));
}

#[test]
fn get_original_message_from_padded_text_successful() {
// test for short message
let message = vec![0u8, 10, 22, 11, 38, 74, 59, 91, 73, 82, 75, 23, 59];
let pad_message = pad_message_to_base_length_multiple(message.as_slice());
let pad_message = pad_message_to_base_length_multiple(message.as_slice()).unwrap();

let output_message = get_original_message_from_padded_text(pad_message.as_slice()).unwrap();
assert_eq!(message, output_message);

// test for large message
let message = vec![100u8; 1024];
let pad_message = pad_message_to_base_length_multiple(message.as_slice());
let pad_message = pad_message_to_base_length_multiple(message.as_slice()).unwrap();

let output_message = get_original_message_from_padded_text(pad_message.as_slice()).unwrap();
assert_eq!(message, output_message);

// test for base message of base length
let message = vec![100u8; 984];
let pad_message = pad_message_to_base_length_multiple(message.as_slice());
let pad_message = pad_message_to_base_length_multiple(message.as_slice()).unwrap();

let output_message = get_original_message_from_padded_text(pad_message.as_slice()).unwrap();
assert_eq!(message, output_message);

// test for empty message
let message: Vec<u8> = vec![];
let pad_message = pad_message_to_base_length_multiple(message.as_slice());
let pad_message = pad_message_to_base_length_multiple(message.as_slice()).unwrap();

let output_message = get_original_message_from_padded_text(pad_message.as_slice()).unwrap();
assert_eq!(message, output_message);
Expand All @@ -494,7 +529,7 @@ mod test {
#[test]
fn padding_fails_if_pad_message_prepend_length_is_bigger_than_plaintext_length() {
let message = "This is my secret message, keep it secret !".as_bytes();
let mut pad_message = pad_message_to_base_length_multiple(message);
let mut pad_message = pad_message_to_base_length_multiple(message).unwrap();

// we modify the prepend length, in order to assert that the get original message
// method will output a different length message
Expand All @@ -512,7 +547,7 @@ mod test {
assert!(get_original_message_from_padded_text(pad_message.as_slice())
.unwrap_err()
.to_string()
.contains("Original length message is invalid"));
.contains("Message is too short to be unpadded"));
}

#[test]
Expand All @@ -522,7 +557,7 @@ mod test {
let pk = CommsPublicKey::default();
let key = CipherKey(*chacha20::Key::from_slice(pk.as_bytes()));
let message = "My secret message, keep it secret !".as_bytes().to_vec();
let mut encrypted = encrypt(&key, &message);
let mut encrypted = encrypt(&key, &message).unwrap();

let n = encrypted.len();
encrypted[n - 1] += 1;
Expand All @@ -535,9 +570,9 @@ mod test {
let pk = CommsPublicKey::default();
let key = CipherKey(*chacha20::Key::from_slice(pk.as_bytes()));
let message = "My secret message, keep it secret !".as_bytes().to_vec();
let mut encrypted = encrypt(&key, &message);
let mut encrypted = encrypt(&key, &message).unwrap();

encrypted[size_of::<Nonce>() + LITTLE_ENDIAN_U32_SIZE_REPRESENTATION + 1] += 1;
encrypted[size_of::<Nonce>() + size_of::<u32>() + 1] += 1;

assert!(decrypt(&key, &encrypted).unwrap() != message);
}
Expand Down
2 changes: 1 addition & 1 deletion comms/dht/src/dht.rs
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,7 @@ mod test {
let node_identity2 = make_node_identity();
let ecdh_key = crypt::generate_ecdh_secret(node_identity2.secret_key(), node_identity2.public_key());
let key_message = crypt::generate_key_message(&ecdh_key);
let encrypted_bytes = crypt::encrypt(&key_message, &msg.to_encoded_bytes());
let encrypted_bytes = crypt::encrypt(&key_message, &msg.to_encoded_bytes()).unwrap();
let dht_envelope = make_dht_envelope(
&node_identity2,
encrypted_bytes,
Expand Down
4 changes: 2 additions & 2 deletions comms/dht/src/inbound/decryption.rs
Original file line number Diff line number Diff line change
Expand Up @@ -650,7 +650,7 @@ mod test {
let key_message = crypt::generate_key_message(&shared_secret);
let msg_tag = MessageTag::new();

let message = crypt::encrypt(&key_message, &plain_text_msg);
let message = crypt::encrypt(&key_message, &plain_text_msg).unwrap();
let header = make_dht_header(
&node_identity,
&e_public_key,
Expand Down Expand Up @@ -711,7 +711,7 @@ mod test {
let key_message = crypt::generate_key_message(&shared_secret);
let msg_tag = MessageTag::new();

let message = crypt::encrypt(&key_message, &plain_text_msg);
let message = crypt::encrypt(&key_message, &plain_text_msg).unwrap();
let header = make_dht_header(
&node_identity,
&e_public_key,
Expand Down
2 changes: 1 addition & 1 deletion comms/dht/src/outbound/broadcast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,7 @@ where S: Service<DhtOutboundMessage, Response = (), Error = PipelineError>
// Generate key message for encryption of message
let key_message = crypt::generate_key_message(&shared_ephemeral_secret);
// Encrypt the message with the body with key message above
let encrypted_body = crypt::encrypt(&key_message, &body);
let encrypted_body = crypt::encrypt(&key_message, &body)?;

// Produce domain separated signature signature
let mac_signature = crypt::create_message_domain_separated_hash_parts(
Expand Down
2 changes: 1 addition & 1 deletion comms/dht/src/test_utils/makers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ pub fn make_dht_envelope(
if flags.is_encrypted() {
let shared_secret = crypt::generate_ecdh_secret(&e_secret_key, node_identity.public_key());
let key_message = crypt::generate_key_message(&shared_secret);
message = crypt::encrypt(&key_message, &message);
message = crypt::encrypt(&key_message, &message).unwrap();
}
let header = make_dht_header(
node_identity,
Expand Down