From 3dcdf001d1ae75aab4f8ac06ebcc15f124b71273 Mon Sep 17 00:00:00 2001
From: Zekun Li
Date: Tue, 6 Dec 2022 10:07:31 +0800
Subject: [PATCH] support customized max container depth
---
src/de.rs | 28 ++++++++++++++++++++++++++++
src/lib.rs | 7 +++++--
src/ser.rs | 42 ++++++++++++++++++++++++++++++++++++++++++
tests/serde.rs | 28 +++++++++++++++++++++++++++-
4 files changed, 102 insertions(+), 3 deletions(-)
diff --git a/src/de.rs b/src/de.rs
index 997900d..14d5fa5 100644
--- a/src/de.rs
+++ b/src/de.rs
@@ -43,6 +43,20 @@ where
deserializer.end().map(move |_| t)
}
+/// Same as `from_bytes` but use `limit` as max container depth instead of MAX_CONTAINER_DEPTH`
+/// Note that `limit` has to be lower than MAX_CONTAINER_DEPTH
+pub fn from_bytes_with_limit<'a, T>(bytes: &'a [u8], limit: usize) -> Result
+where
+ T: Deserialize<'a>,
+{
+ if limit > crate::MAX_CONTAINER_DEPTH {
+ return Err(Error::NotSupported("limit exceeds the max allowed depth"));
+ }
+ let mut deserializer = Deserializer::new(bytes, limit);
+ let t = T::deserialize(&mut deserializer)?;
+ deserializer.end().map(move |_| t)
+}
+
/// Perform a stateful deserialization from a `&[u8]` using the provided `seed`.
pub fn from_bytes_seed<'a, T>(seed: T, bytes: &'a [u8]) -> Result
where
@@ -53,6 +67,20 @@ where
deserializer.end().map(move |_| t)
}
+/// Same as `from_bytes_seed` but use `limit` as max container depth instead of MAX_CONTAINER_DEPTH`
+/// Note that `limit` has to be lower than MAX_CONTAINER_DEPTH
+pub fn from_bytes_seed_with_limit<'a, T>(seed: T, bytes: &'a [u8], limit: usize) -> Result
+where
+ T: DeserializeSeed<'a>,
+{
+ if limit > crate::MAX_CONTAINER_DEPTH {
+ return Err(Error::NotSupported("limit exceeds the max allowed depth"));
+ }
+ let mut deserializer = Deserializer::new(bytes, limit);
+ let t = seed.deserialize(&mut deserializer)?;
+ deserializer.end().map(move |_| t)
+}
+
/// Deserialization implementation for BCS
struct Deserializer<'de> {
input: &'de [u8],
diff --git a/src/lib.rs b/src/lib.rs
index c8ee6a4..96a12ee 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -314,6 +314,9 @@ pub const MAX_SEQUENCE_LENGTH: usize = (1 << 31) - 1;
/// Maximal allowed depth of BCS data, counting only structs and enums.
pub const MAX_CONTAINER_DEPTH: usize = 500;
-pub use de::{from_bytes, from_bytes_seed};
+pub use de::{from_bytes, from_bytes_seed, from_bytes_seed_with_limit, from_bytes_with_limit};
pub use error::{Error, Result};
-pub use ser::{is_human_readable, serialize_into, serialized_size, to_bytes};
+pub use ser::{
+ is_human_readable, serialize_into, serialize_into_with_limit, serialized_size,
+ serialized_size_with_limit, to_bytes, to_bytes_with_limit,
+};
diff --git a/src/ser.rs b/src/ser.rs
index b119926..9bd42ce 100644
--- a/src/ser.rs
+++ b/src/ser.rs
@@ -55,6 +55,20 @@ where
Ok(output)
}
+/// Same as `to_bytes` but use `limit` as max container depth instead of MAX_CONTAINER_DEPTH
+/// Note that `limit` has to be lower than MAX_CONTAINER_DEPTH
+pub fn to_bytes_with_limit(value: &T, limit: usize) -> Result>
+where
+ T: ?Sized + Serialize,
+{
+ if limit > crate::MAX_CONTAINER_DEPTH {
+ return Err(Error::NotSupported("limit exceeds the max allowed depth"));
+ }
+ let mut output = Vec::new();
+ serialize_into_with_limit(&mut output, value, limit)?;
+ Ok(output)
+}
+
/// Same as `to_bytes` but write directly into an `std::io::Write` object.
pub fn serialize_into(write: &mut W, value: &T) -> Result<()>
where
@@ -65,6 +79,20 @@ where
value.serialize(serializer)
}
+/// Same as `serialize_into` but use `limit` as max container depth instead of MAX_CONTAINER_DEPTH
+/// Note that `limit` has to be lower than MAX_CONTAINER_DEPTH
+pub fn serialize_into_with_limit(write: &mut W, value: &T, limit: usize) -> Result<()>
+where
+ W: ?Sized + std::io::Write,
+ T: ?Sized + Serialize,
+{
+ if limit > crate::MAX_CONTAINER_DEPTH {
+ return Err(Error::NotSupported("limit exceeds the max allowed depth"));
+ }
+ let serializer = Serializer::new(write, limit);
+ value.serialize(serializer)
+}
+
struct WriteCounter(usize);
impl std::io::Write for WriteCounter {
@@ -91,6 +119,20 @@ where
Ok(counter.0)
}
+/// Same as `serialized_size` but use `limit` as max container depth instead of MAX_CONTAINER_DEPTH
+/// Note that `limit` has to be lower than MAX_CONTAINER_DEPTH
+pub fn serialized_size_with_limit(value: &T, limit: usize) -> Result
+where
+ T: ?Sized + Serialize,
+{
+ if limit > crate::MAX_CONTAINER_DEPTH {
+ return Err(Error::NotSupported("limit exceeds the max allowed depth"));
+ }
+ let mut counter = WriteCounter(0);
+ serialize_into_with_limit(&mut counter, value, limit)?;
+ Ok(counter.0)
+}
+
pub fn is_human_readable() -> bool {
let mut output = Vec::new();
let serializer = Serializer::new(&mut output, crate::MAX_CONTAINER_DEPTH);
diff --git a/tests/serde.rs b/tests/serde.rs
index d3d5f0c..73d78a4 100644
--- a/tests/serde.rs
+++ b/tests/serde.rs
@@ -13,7 +13,10 @@ use proptest::prelude::*;
use proptest_derive::Arbitrary;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
-use bcs::{from_bytes, serialized_size, to_bytes, Error, MAX_CONTAINER_DEPTH, MAX_SEQUENCE_LENGTH};
+use bcs::{
+ from_bytes, from_bytes_with_limit, serialized_size, to_bytes, to_bytes_with_limit, Error,
+ MAX_CONTAINER_DEPTH, MAX_SEQUENCE_LENGTH,
+};
fn is_same(t: T)
where
@@ -654,6 +657,29 @@ fn test_recursion_limit() {
to_bytes(&(&l3, &l3)),
Err(Error::ExceededContainerDepthLimit("List"))
);
+
+ // test customized limit
+ let limit = 100;
+ let not_supported_err = Error::NotSupported("limit exceeds the max allowed depth");
+ let l4 = List::integers(limit);
+ assert_eq!(
+ to_bytes_with_limit(&l4, limit),
+ Err(Error::ExceededContainerDepthLimit("List"))
+ );
+ assert_eq!(
+ to_bytes_with_limit(&l4, MAX_CONTAINER_DEPTH + 1),
+ Err(not_supported_err.clone()),
+ );
+ let bytes = to_bytes_with_limit(&l4, limit + 1).unwrap();
+ assert_eq!(
+ from_bytes_with_limit::>(&bytes, limit),
+ Err(Error::ExceededContainerDepthLimit("List"))
+ );
+ assert_eq!(from_bytes_with_limit(&bytes, limit + 1), Ok(l4));
+ assert_eq!(
+ from_bytes_with_limit::>(&bytes, MAX_CONTAINER_DEPTH + 1),
+ Err(not_supported_err)
+ );
}
#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, Debug)]