Skip to content

Commit

Permalink
Refactor header.rs
Browse files Browse the repository at this point in the history
1. Removed `NoiseHeader` struct in favor of three constants defined at
the top of the file.
2. Added documentation and changed visibility to `pub(crate)` where
needed.
3. Removed `Header::Default` and `Sv2Frame::Default` impls as they are
unused.
4. Removed `unwrap()`s
  • Loading branch information
jbesraa authored and plebhash committed Jul 3, 2024
1 parent c959f42 commit e4112fb
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 67 deletions.
18 changes: 9 additions & 9 deletions protocols/v2/codec-sv2/src/decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use core::marker::PhantomData;
#[cfg(feature = "noise_sv2")]
use framing_sv2::framing2::HandShakeFrame;
#[cfg(feature = "noise_sv2")]
use framing_sv2::header::NoiseHeader;
use framing_sv2::header::{NOISE_HEADER_ENCRYPTED_SIZE, NOISE_HEADER_SIZE};
use framing_sv2::{
framing2::{EitherFrame, Frame as F_, Sv2Frame},
header::Header,
Expand Down Expand Up @@ -58,7 +58,7 @@ impl<'a, T: Serialize + GetSize + Deserialize<'a>, B: IsBuffer + AeadBuffer> Wit
let hint = *msg_len - self.noise_buffer.as_ref().len();
match hint {
0 => {
self.missing_noise_b = NoiseHeader::HEADER_SIZE;
self.missing_noise_b = NOISE_HEADER_SIZE;
Ok(self.while_handshaking())
}
_ => {
Expand All @@ -71,20 +71,20 @@ impl<'a, T: Serialize + GetSize + Deserialize<'a>, B: IsBuffer + AeadBuffer> Wit
let hint = if IsBuffer::len(&self.sv2_buffer) < SV2_FRAME_HEADER_SIZE {
let len = IsBuffer::len(&self.noise_buffer);
let src = self.noise_buffer.get_data_by_ref(len);
if src.len() < NoiseHeader::SIZE {
NoiseHeader::SIZE - src.len()
if src.len() < NOISE_HEADER_ENCRYPTED_SIZE {
NOISE_HEADER_ENCRYPTED_SIZE - src.len()
} else {
0
}
} else {
let src = self.sv2_buffer.get_data_by_ref_(SV2_FRAME_HEADER_SIZE);
let src = self.sv2_buffer.get_data_by_ref(SV2_FRAME_HEADER_SIZE);
let header = Header::from_bytes(src)?;
header.encrypted_len() - IsBuffer::len(&self.noise_buffer)
};

match hint {
0 => {
self.missing_noise_b = NoiseHeader::SIZE;
self.missing_noise_b = NOISE_HEADER_ENCRYPTED_SIZE;
self.decode_noise_frame(noise_codec)
}
_ => {
Expand All @@ -106,14 +106,14 @@ impl<'a, T: Serialize + GetSize + Deserialize<'a>, B: IsBuffer + AeadBuffer> Wit
IsBuffer::len(&self.sv2_buffer),
) {
// HERE THE SV2 HEADER IS READY TO BE DECRYPTED
(NoiseHeader::SIZE, 0) => {
(NOISE_HEADER_ENCRYPTED_SIZE, 0) => {
let src = self.noise_buffer.get_data_owned();
let decrypted_header = self.sv2_buffer.get_writable(NoiseHeader::SIZE);
let decrypted_header = self.sv2_buffer.get_writable(NOISE_HEADER_ENCRYPTED_SIZE);
decrypted_header.copy_from_slice(src.as_ref());
self.sv2_buffer.as_ref();
noise_codec.decrypt(&mut self.sv2_buffer)?;
let header =
Header::from_bytes(self.sv2_buffer.get_data_by_ref_(SV2_FRAME_HEADER_SIZE))?;
Header::from_bytes(self.sv2_buffer.get_data_by_ref(SV2_FRAME_HEADER_SIZE))?;
self.missing_noise_b = header.encrypted_len();
Err(Error::MissingBytes(header.encrypted_len()))
}
Expand Down
4 changes: 2 additions & 2 deletions protocols/v2/codec-sv2/src/encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use core::marker::PhantomData;
use framing_sv2::framing2::{EitherFrame, HandShakeFrame};
use framing_sv2::framing2::{Frame as F_, Sv2Frame};
#[allow(unused_imports)]
pub use framing_sv2::header::NoiseHeader;
pub use framing_sv2::header::NOISE_HEADER_ENCRYPTED_SIZE;

#[cfg(feature = "noise_sv2")]
use tracing::error;
Expand Down Expand Up @@ -76,7 +76,7 @@ impl<T: Serialize + GetSize> NoiseEncoder<T> {
} else {
SV2_FRAME_CHUNK_SIZE + start - AEAD_MAC_LEN
};
let mut encrypted_len = NoiseHeader::SIZE;
let mut encrypted_len = NOISE_HEADER_ENCRYPTED_SIZE;

while start < sv2.len() {
let to_encrypt = self.noise_buffer.get_writable(end - start);
Expand Down
9 changes: 7 additions & 2 deletions protocols/v2/framing-sv2/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,13 @@ impl fmt::Display for Error {
ExpectedSv2Frame => {
write!(f, "Expected `Sv2Frame`, received `HandshakeFrame`")
}
UnexpectedHeaderLength(i) => {
write!(f, "Unexpected `Header` length: `{}`", i)
UnexpectedHeaderLength(actual_size) => {
write!(
f,
"Unexpected `Header` length: `{}`, should be equal or more to {}",
actual_size,
const_sv2::SV2_FRAME_HEADER_SIZE
)
}
}
}
Expand Down
24 changes: 7 additions & 17 deletions protocols/v2/framing-sv2/src/framing2.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::{
header::{Header, NoiseHeader},
header::{Header, NOISE_HEADER_LEN_OFFSET, NOISE_HEADER_SIZE},
Error,
};
use alloc::vec::Vec;
Expand Down Expand Up @@ -83,16 +83,6 @@ pub struct Sv2Frame<T, B> {
serialized: Option<B>,
}

impl<T, B> Default for Sv2Frame<T, B> {
fn default() -> Self {
Sv2Frame {
header: Header::default(),
payload: None,
serialized: None,
}
}
}

/// Abstraction for a Noise Handshake Frame
/// Contains only a `Slice` payload with a fixed length
/// Only used during Noise Handshake process
Expand Down Expand Up @@ -253,7 +243,7 @@ impl<'a> Frame<'a, Slice> for HandShakeFrame {
/// Get the Noise Frame payload
#[inline]
fn payload(&'a mut self) -> &'a mut [u8] {
&mut self.payload[NoiseHeader::HEADER_SIZE..]
&mut self.payload[NOISE_HEADER_SIZE..]
}

/// `HandShakeFrame` always returns `None`.
Expand All @@ -280,17 +270,17 @@ impl<'a> Frame<'a, Slice> for HandShakeFrame {
/// indicates the surplus of bytes beyond the expected size.
#[inline]
fn size_hint(bytes: &[u8]) -> isize {
if bytes.len() < NoiseHeader::HEADER_SIZE {
return (NoiseHeader::HEADER_SIZE - bytes.len()) as isize;
if bytes.len() < NOISE_HEADER_SIZE {
return (NOISE_HEADER_SIZE - bytes.len()) as isize;
};

let len_b = &bytes[NoiseHeader::LEN_OFFSET..NoiseHeader::HEADER_SIZE];
let len_b = &bytes[NOISE_HEADER_LEN_OFFSET..NOISE_HEADER_SIZE];
let expected_len = u16::from_le_bytes([len_b[0], len_b[1]]) as usize;

if bytes.len() - NoiseHeader::HEADER_SIZE == expected_len {
if bytes.len() - NOISE_HEADER_SIZE == expected_len {
0
} else {
expected_len as isize - (bytes.len() - NoiseHeader::HEADER_SIZE) as isize
expected_len as isize - (bytes.len() - NOISE_HEADER_SIZE) as isize
}
}

Expand Down
96 changes: 59 additions & 37 deletions protocols/v2/framing-sv2/src/header.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,68 +7,65 @@ use binary_sv2::{Deserialize, Serialize, U24};
use const_sv2::{AEAD_MAC_LEN, SV2_FRAME_CHUNK_SIZE};
use core::convert::TryInto;

// Previously `NoiseHeader::SIZE`
pub const NOISE_HEADER_ENCRYPTED_SIZE: usize = const_sv2::ENCRYPTED_SV2_FRAME_HEADER_SIZE;
// Previously `NoiseHeader::LEN_OFFSET`
pub const NOISE_HEADER_LEN_OFFSET: usize = const_sv2::NOISE_FRAME_HEADER_LEN_OFFSET;
// Previously `NoiseHeader::HEADER_SIZE`
pub const NOISE_HEADER_SIZE: usize = const_sv2::NOISE_FRAME_HEADER_SIZE;

/// Abstraction for a SV2 Frame Header.
#[derive(Debug, Serialize, Deserialize, Copy, Clone)]
pub struct Header {
extension_type: u16, // TODO use specific type?
msg_type: u8, // TODO use specific type?
/// Unique identifier of the extension describing this protocol message. Most significant bit
/// (i.e.bit 15, 0-indexed, aka channel_msg) indicates a message which is specific to a channel,
/// whereas if the most significant bit is unset, the message is to be interpreted by the
/// immediate receiving device. Note that the channel_msg bit is ignored in the extension
/// lookup, i.e.an extension_type of 0x8ABC is for the same "extension" as 0x0ABC. If the
/// channel_msg bit is set, the first four bytes of the payload field is a U32 representing the
/// channel_id this message is destined for. Note that for the Job Declaration and Template
/// Distribution Protocols the channel_msg bit is always unset.
extension_type: u16, // fix: use U16 type
/// Unique identifier of the extension describing this protocol message
msg_type: u8, // fix: use specific type?
/// Length of the protocol message, not including this header
msg_length: U24,
}

impl Default for Header {
fn default() -> Self {
Header {
extension_type: 0,
msg_type: 0,
// converting 0_32 into a U24 never panic
msg_length: 0_u32.try_into().unwrap(),
}
}
}

impl Header {
pub const LEN_OFFSET: usize = const_sv2::SV2_FRAME_HEADER_LEN_OFFSET;
pub const LEN_SIZE: usize = const_sv2::SV2_FRAME_HEADER_LEN_END;
pub const LEN_END: usize = Self::LEN_OFFSET + Self::LEN_SIZE;

pub const SIZE: usize = const_sv2::SV2_FRAME_HEADER_SIZE;

/// Construct a `Header` from ray bytes
/// Construct a `Header` from raw bytes
#[inline]
pub fn from_bytes(bytes: &[u8]) -> Result<Self, Error> {
if bytes.len() < Self::SIZE {
return Err(Error::UnexpectedHeaderLength(
(Self::SIZE - bytes.len()) as isize,
));
return Err(Error::UnexpectedHeaderLength(bytes.len() as isize));
};

let extension_type = u16::from_le_bytes([bytes[0], bytes[1]]);
let msg_type = bytes[2];
let msg_length = u32::from_le_bytes([bytes[3], bytes[4], bytes[5], 0]);

let msg_length: U24 = u32::from_le_bytes([bytes[3], bytes[4], bytes[5], 0]).try_into()?;
Ok(Self {
extension_type,
msg_type,
// Converting and u32 with the most significant byte set to 0 to and U24 never panic
msg_length: msg_length.try_into().unwrap(),
msg_length,
})
}

/// Get the payload length
#[allow(clippy::len_without_is_empty)]
#[inline]
pub fn len(&self) -> usize {
pub(crate) fn len(&self) -> usize {
let inner: u32 = self.msg_length.into();
inner as usize
}

/// Construct a `Header` from payload length, type and extension type.
#[inline]
pub fn from_len(len: u32, message_type: u8, extension_type: u16) -> Option<Header> {
pub(crate) fn from_len(msg_length: u32, msg_type: u8, extension_type: u16) -> Option<Header> {
Some(Self {
extension_type,
msg_type: message_type,
msg_length: len.try_into().ok()?,
msg_type,
msg_length: msg_length.try_into().ok()?,
})
}

Expand All @@ -83,9 +80,11 @@ impl Header {
}

/// Check if `Header` represents a channel message
///
/// A header can represent a channel message if the MSB(Most Significant Bit) is set.
pub fn channel_msg(&self) -> bool {
let mask = 0b0000_0000_0000_0001;
self.extension_type & mask == self.extension_type
const CHANNEL_MSG_MASK: u16 = 0b0000_0000_0000_0001;
self.extension_type & CHANNEL_MSG_MASK == self.extension_type
}

/// Calculate the length of the encrypted `Header`
Expand All @@ -100,10 +99,33 @@ impl Header {
}
}

pub struct NoiseHeader {}
#[cfg(test)]
mod tests {
use super::*;
use alloc::vec;

#[test]
fn test_header_from_bytes() {
let bytes = vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06];
let header = Header::from_bytes(&bytes).unwrap();
assert_eq!(header.extension_type, 0x0201);
assert_eq!(header.msg_type, 0x03);
assert_eq!(header.msg_length, 0x060504_u32.try_into().unwrap());
}

#[test]
fn test_header_from_len() {
let header = Header::from_len(0x1234, 0x56, 0x789a).unwrap();
assert_eq!(header.extension_type, 0x789a);
assert_eq!(header.msg_type, 0x56);
assert_eq!(header.msg_length, 0x1234_u32.try_into().unwrap());

impl NoiseHeader {
pub const SIZE: usize = const_sv2::ENCRYPTED_SV2_FRAME_HEADER_SIZE;
pub const LEN_OFFSET: usize = const_sv2::NOISE_FRAME_HEADER_LEN_OFFSET;
pub const HEADER_SIZE: usize = const_sv2::NOISE_FRAME_HEADER_SIZE;
let extension_type = 0;
let msg_type = 0x1;
let msg_length = 0x1234_u32;
let header = Header::from_len(msg_length, msg_type, extension_type).unwrap();
assert_eq!(header.extension_type, 0);
assert_eq!(header.msg_type, 0x1);
assert_eq!(header.msg_length, 0x1234_u32.try_into().unwrap());
}
}

0 comments on commit e4112fb

Please sign in to comment.