From 3dcdf001d1ae75aab4f8ac06ebcc15f124b71273 Mon Sep 17 00:00:00 2001 From: Zekun Li Date: Tue, 6 Dec 2022 10:07:31 +0800 Subject: [PATCH] support customized max container depth --- src/de.rs | 28 ++++++++++++++++++++++++++++ src/lib.rs | 7 +++++-- src/ser.rs | 42 ++++++++++++++++++++++++++++++++++++++++++ tests/serde.rs | 28 +++++++++++++++++++++++++++- 4 files changed, 102 insertions(+), 3 deletions(-) diff --git a/src/de.rs b/src/de.rs index 997900d..14d5fa5 100644 --- a/src/de.rs +++ b/src/de.rs @@ -43,6 +43,20 @@ where deserializer.end().map(move |_| t) } +/// Same as `from_bytes` but use `limit` as max container depth instead of MAX_CONTAINER_DEPTH` +/// Note that `limit` has to be lower than MAX_CONTAINER_DEPTH +pub fn from_bytes_with_limit<'a, T>(bytes: &'a [u8], limit: usize) -> Result +where + T: Deserialize<'a>, +{ + if limit > crate::MAX_CONTAINER_DEPTH { + return Err(Error::NotSupported("limit exceeds the max allowed depth")); + } + let mut deserializer = Deserializer::new(bytes, limit); + let t = T::deserialize(&mut deserializer)?; + deserializer.end().map(move |_| t) +} + /// Perform a stateful deserialization from a `&[u8]` using the provided `seed`. pub fn from_bytes_seed<'a, T>(seed: T, bytes: &'a [u8]) -> Result where @@ -53,6 +67,20 @@ where deserializer.end().map(move |_| t) } +/// Same as `from_bytes_seed` but use `limit` as max container depth instead of MAX_CONTAINER_DEPTH` +/// Note that `limit` has to be lower than MAX_CONTAINER_DEPTH +pub fn from_bytes_seed_with_limit<'a, T>(seed: T, bytes: &'a [u8], limit: usize) -> Result +where + T: DeserializeSeed<'a>, +{ + if limit > crate::MAX_CONTAINER_DEPTH { + return Err(Error::NotSupported("limit exceeds the max allowed depth")); + } + let mut deserializer = Deserializer::new(bytes, limit); + let t = seed.deserialize(&mut deserializer)?; + deserializer.end().map(move |_| t) +} + /// Deserialization implementation for BCS struct Deserializer<'de> { input: &'de [u8], diff --git a/src/lib.rs b/src/lib.rs index c8ee6a4..96a12ee 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -314,6 +314,9 @@ pub const MAX_SEQUENCE_LENGTH: usize = (1 << 31) - 1; /// Maximal allowed depth of BCS data, counting only structs and enums. pub const MAX_CONTAINER_DEPTH: usize = 500; -pub use de::{from_bytes, from_bytes_seed}; +pub use de::{from_bytes, from_bytes_seed, from_bytes_seed_with_limit, from_bytes_with_limit}; pub use error::{Error, Result}; -pub use ser::{is_human_readable, serialize_into, serialized_size, to_bytes}; +pub use ser::{ + is_human_readable, serialize_into, serialize_into_with_limit, serialized_size, + serialized_size_with_limit, to_bytes, to_bytes_with_limit, +}; diff --git a/src/ser.rs b/src/ser.rs index b119926..9bd42ce 100644 --- a/src/ser.rs +++ b/src/ser.rs @@ -55,6 +55,20 @@ where Ok(output) } +/// Same as `to_bytes` but use `limit` as max container depth instead of MAX_CONTAINER_DEPTH +/// Note that `limit` has to be lower than MAX_CONTAINER_DEPTH +pub fn to_bytes_with_limit(value: &T, limit: usize) -> Result> +where + T: ?Sized + Serialize, +{ + if limit > crate::MAX_CONTAINER_DEPTH { + return Err(Error::NotSupported("limit exceeds the max allowed depth")); + } + let mut output = Vec::new(); + serialize_into_with_limit(&mut output, value, limit)?; + Ok(output) +} + /// Same as `to_bytes` but write directly into an `std::io::Write` object. pub fn serialize_into(write: &mut W, value: &T) -> Result<()> where @@ -65,6 +79,20 @@ where value.serialize(serializer) } +/// Same as `serialize_into` but use `limit` as max container depth instead of MAX_CONTAINER_DEPTH +/// Note that `limit` has to be lower than MAX_CONTAINER_DEPTH +pub fn serialize_into_with_limit(write: &mut W, value: &T, limit: usize) -> Result<()> +where + W: ?Sized + std::io::Write, + T: ?Sized + Serialize, +{ + if limit > crate::MAX_CONTAINER_DEPTH { + return Err(Error::NotSupported("limit exceeds the max allowed depth")); + } + let serializer = Serializer::new(write, limit); + value.serialize(serializer) +} + struct WriteCounter(usize); impl std::io::Write for WriteCounter { @@ -91,6 +119,20 @@ where Ok(counter.0) } +/// Same as `serialized_size` but use `limit` as max container depth instead of MAX_CONTAINER_DEPTH +/// Note that `limit` has to be lower than MAX_CONTAINER_DEPTH +pub fn serialized_size_with_limit(value: &T, limit: usize) -> Result +where + T: ?Sized + Serialize, +{ + if limit > crate::MAX_CONTAINER_DEPTH { + return Err(Error::NotSupported("limit exceeds the max allowed depth")); + } + let mut counter = WriteCounter(0); + serialize_into_with_limit(&mut counter, value, limit)?; + Ok(counter.0) +} + pub fn is_human_readable() -> bool { let mut output = Vec::new(); let serializer = Serializer::new(&mut output, crate::MAX_CONTAINER_DEPTH); diff --git a/tests/serde.rs b/tests/serde.rs index d3d5f0c..73d78a4 100644 --- a/tests/serde.rs +++ b/tests/serde.rs @@ -13,7 +13,10 @@ use proptest::prelude::*; use proptest_derive::Arbitrary; use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use bcs::{from_bytes, serialized_size, to_bytes, Error, MAX_CONTAINER_DEPTH, MAX_SEQUENCE_LENGTH}; +use bcs::{ + from_bytes, from_bytes_with_limit, serialized_size, to_bytes, to_bytes_with_limit, Error, + MAX_CONTAINER_DEPTH, MAX_SEQUENCE_LENGTH, +}; fn is_same(t: T) where @@ -654,6 +657,29 @@ fn test_recursion_limit() { to_bytes(&(&l3, &l3)), Err(Error::ExceededContainerDepthLimit("List")) ); + + // test customized limit + let limit = 100; + let not_supported_err = Error::NotSupported("limit exceeds the max allowed depth"); + let l4 = List::integers(limit); + assert_eq!( + to_bytes_with_limit(&l4, limit), + Err(Error::ExceededContainerDepthLimit("List")) + ); + assert_eq!( + to_bytes_with_limit(&l4, MAX_CONTAINER_DEPTH + 1), + Err(not_supported_err.clone()), + ); + let bytes = to_bytes_with_limit(&l4, limit + 1).unwrap(); + assert_eq!( + from_bytes_with_limit::>(&bytes, limit), + Err(Error::ExceededContainerDepthLimit("List")) + ); + assert_eq!(from_bytes_with_limit(&bytes, limit + 1), Ok(l4)); + assert_eq!( + from_bytes_with_limit::>(&bytes, MAX_CONTAINER_DEPTH + 1), + Err(not_supported_err) + ); } #[derive(Deserialize, Serialize, Clone, PartialEq, Eq, Debug)]