diff --git a/CHANGELOG.md b/CHANGELOG.md index bd64635..d29a715 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added + +- New `decode_raw` static methods to `Header` + ## [0.3.5] - 2024-05-22 ### Changed diff --git a/crates/rlp/src/header.rs b/crates/rlp/src/header.rs index 4938ad3..db62878 100644 --- a/crates/rlp/src/header.rs +++ b/crates/rlp/src/header.rs @@ -89,11 +89,7 @@ impl Header { } // SAFETY: this is already checked in `decode` - if buf.remaining() < payload_length { - unsafe { unreachable_unchecked() } - } - let bytes = unsafe { buf.get_unchecked(..payload_length) }; - buf.advance(payload_length); + let bytes = unsafe { advance_unchecked(buf, payload_length) }; Ok(bytes) } @@ -108,6 +104,40 @@ impl Header { core::str::from_utf8(bytes).map_err(|_| Error::Custom("invalid string")) } + /// Extracts the next payload from the given buffer, advancing it. + /// + /// # Errors + /// + /// Returns an error if the buffer is too short, the header is invalid or one of the headers one + /// level deeper is invalid. + #[inline] + pub fn decode_raw<'a>(buf: &mut &'a [u8]) -> Result> { + let Self { list, payload_length } = Self::decode(buf)?; + // SAFETY: this is already checked in `decode` + let mut payload = unsafe { advance_unchecked(buf, payload_length) }; + + if !list { + return Ok(PayloadView::String(payload)); + } + + let mut items = alloc::vec::Vec::new(); + while !payload.is_empty() { + // decode the next header without advancing in the payload + let Self { payload_length, .. } = Self::decode(&mut &payload[..])?; + // the length of the RLP encoding is the length of the header plus its payload length + // if payload length is 1 and the first byte is in [0x00, 0x7F], then there is no header + let rlp_length = if payload_length == 1 && payload[0] <= 0x7F { + 1 + } else { + payload_length + crate::length_of_length(payload_length) + }; + items.push(&payload[..rlp_length]); + payload.advance(rlp_length); + } + + return Ok(PayloadView::List(items)); + } + /// Encodes the header into the `out` buffer. #[inline] pub fn encode(&self, out: &mut dyn BufMut) { @@ -130,6 +160,12 @@ impl Header { } } +/// Structured representation of an RLP payload. +pub enum PayloadView<'a> { + String(&'a [u8]), + List(alloc::vec::Vec<&'a [u8]>), +} + /// Same as `buf.first().ok_or(Error::InputTooShort)`. #[inline(always)] fn get_next_byte(buf: &[u8]) -> Result { @@ -139,3 +175,71 @@ fn get_next_byte(buf: &[u8]) -> Result { // SAFETY: length checked above Ok(*unsafe { buf.get_unchecked(0) }) } + +/// Same as `let (bytes, rest) = buf.split_at(cnt); *buf = rest; bytes`. +#[inline(always)] +unsafe fn advance_unchecked<'a>(buf: &mut &'a [u8], cnt: usize) -> &'a [u8] { + if buf.remaining() < cnt { + unreachable_unchecked() + } + let bytes = &buf[..cnt]; + buf.advance(cnt); + bytes +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::Encodable; + use alloc::vec::Vec; + use core::fmt::Debug; + + fn check_decode_raw_list(input: Vec) { + let encoded = crate::encode(&input); + let expected: Vec<_> = input.iter().map(crate::encode).collect(); + let mut buf = encoded.as_slice(); + assert!( + matches!(Header::decode_raw(&mut buf), Ok(PayloadView::List(v)) if v == expected), + "input: {:?}, expected list: {:?}", + input, + expected + ); + assert!(buf.is_empty(), "buffer was not advanced"); + } + + fn check_decode_raw_string(input: &str) { + let encoded = crate::encode(input); + let expected = Header::decode_bytes(&mut &encoded[..], false).unwrap(); + let mut buf = encoded.as_slice(); + assert!( + matches!(Header::decode_raw(&mut buf), Ok(PayloadView::String(v)) if v == expected), + "input: {}, expected string: {:?}", + input, + expected + ); + assert!(buf.is_empty(), "buffer was not advanced"); + } + + #[test] + fn decode_raw() { + // empty list + check_decode_raw_list(Vec::::new()); + // list of an empty RLP list + check_decode_raw_list(vec![Vec::::new()]); + // list of an empty RLP string + check_decode_raw_list(vec![""]); + // list of two RLP strings + check_decode_raw_list(vec![0xBBCCB5_u64, 0xFFC0B5_u64]); + // list of three RLP lists of various lengths + check_decode_raw_list(vec![vec![0u64], vec![1u64, 2u64], vec![3u64, 4u64, 5u64]]); + // list of four empty RLP strings + check_decode_raw_list(vec![0u64; 4]); + // list of all one-byte strings, some will have an RLP header and some won't + check_decode_raw_list((0u64..0xFF).collect()); + + // strings of various lengths + check_decode_raw_string(""); + check_decode_raw_string(" "); + check_decode_raw_string("test1234"); + } +}