diff --git a/Cargo.lock b/Cargo.lock index 061239f1..2ea4c2f2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -40,7 +40,7 @@ checksum = "6d3e73c93c3240c0bda063c239298e633114c69a888c3e37ca8bb33f343e9890" [[package]] name = "serde-json-wasm" -version = "1.0.0" +version = "1.0.1" dependencies = [ "serde", "serde_derive", diff --git a/Cargo.toml b/Cargo.toml index 8389629c..14daa54d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,7 +15,7 @@ license = "MIT OR Apache-2.0" name = "serde-json-wasm" readme = "README.md" repository = "https://github.com/CosmWasm/serde-json-wasm" -version = "1.0.0" +version = "1.0.1" exclude = [ ".cargo/", ".github/", diff --git a/src/de/errors.rs b/src/de/errors.rs index a167dc4d..d66d0101 100644 --- a/src/de/errors.rs +++ b/src/de/errors.rs @@ -75,6 +75,9 @@ pub enum Error { /// JSON has a comma after the last value in an array or map. TrailingComma, + /// JSON is nested too deeply, exceeded the recursion limit. + RecursionLimitExceeded, + /// Custom error message from serde Custom(String), } @@ -128,6 +131,7 @@ impl core::fmt::Display for Error { value." } Error::TrailingComma => "JSON has a comma after the last value in an array or map.", + Error::RecursionLimitExceeded => "JSON is nested too deeply, exceeded the recursion limit.", Error::Custom(msg) => msg, } ) diff --git a/src/de/mod.rs b/src/de/mod.rs index 1919cfcd..1eb599e2 100644 --- a/src/de/mod.rs +++ b/src/de/mod.rs @@ -21,6 +21,9 @@ use self::seq::SeqAccess; pub struct Deserializer<'b> { slice: &'b [u8], index: usize, + + /// Remaining depth until we hit the recursion limit + remaining_depth: u8, } enum StringLike<'a> { @@ -30,7 +33,11 @@ enum StringLike<'a> { impl<'a> Deserializer<'a> { fn new(slice: &'a [u8]) -> Deserializer<'_> { - Deserializer { slice, index: 0 } + Deserializer { + slice, + index: 0, + remaining_depth: 128, + } } fn eat_char(&mut self) { @@ -287,16 +294,22 @@ impl<'a, 'de> de::Deserializer<'de> for &'a mut Deserializer<'de> { } } b'[' => { - self.eat_char(); - let ret = visitor.visit_seq(SeqAccess::new(self))?; + check_recursion! { + self.eat_char(); + let ret = visitor.visit_seq(SeqAccess::new(self)); + } + let ret = ret?; self.end_seq()?; Ok(ret) } b'{' => { - self.eat_char(); - let ret = visitor.visit_map(MapAccess::new(self))?; + check_recursion! { + self.eat_char(); + let ret = visitor.visit_map(MapAccess::new(self)); + } + let ret = ret?; self.end_map()?; @@ -513,8 +526,11 @@ impl<'a, 'de> de::Deserializer<'de> for &'a mut Deserializer<'de> { { match self.parse_whitespace().ok_or(Error::EofWhileParsingValue)? { b'[' => { - self.eat_char(); - let ret = visitor.visit_seq(SeqAccess::new(self))?; + check_recursion! { + self.eat_char(); + let ret = visitor.visit_seq(SeqAccess::new(self)); + } + let ret = ret?; self.end_seq()?; @@ -550,9 +566,11 @@ impl<'a, 'de> de::Deserializer<'de> for &'a mut Deserializer<'de> { let peek = self.parse_whitespace().ok_or(Error::EofWhileParsingValue)?; if peek == b'{' { - self.eat_char(); - - let ret = visitor.visit_map(MapAccess::new(self))?; + check_recursion! { + self.eat_char(); + let ret = visitor.visit_map(MapAccess::new(self)); + } + let ret = ret?; self.end_map()?; @@ -588,8 +606,11 @@ impl<'a, 'de> de::Deserializer<'de> for &'a mut Deserializer<'de> { b'"' => visitor.visit_enum(UnitVariantAccess::new(self)), // if it is a struct enum b'{' => { - self.eat_char(); - visitor.visit_enum(StructVariantAccess::new(self)) + check_recursion! { + self.eat_char(); + let value = visitor.visit_enum(StructVariantAccess::new(self)); + } + value } _ => Err(Error::ExpectedSomeIdent), } @@ -649,6 +670,20 @@ where from_slice(s.as_bytes()) } +macro_rules! check_recursion { + ($this:ident $($body:tt)*) => { + $this.remaining_depth -= 1; + if $this.remaining_depth == 0 { + return Err($crate::de::Error::RecursionLimitExceeded); + } + + $this $($body)* + + $this.remaining_depth += 1; + }; +} +pub(crate) use check_recursion; + #[cfg(test)] mod tests { use super::from_str; diff --git a/src/lib.rs b/src/lib.rs index c63ceda6..ebb488ab 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -227,4 +227,29 @@ mod test { item ); } + + #[test] + fn no_stack_overflow() { + const AMOUNT: usize = 2000; + let mut json = String::from(r#"{"":"#); + + #[derive(Debug, Deserialize, Serialize)] + pub struct Person { + name: String, + age: u8, + phones: Vec, + } + + for _ in 0..AMOUNT { + json.push('['); + } + for _ in 0..AMOUNT { + json.push(']'); + } + + json.push_str(r#"] }[[[[[[[[[[[[[[[[[[[[[ ""","age":35,"phones":["#); + + let err = from_str::(&json).unwrap_err(); + assert_eq!(err, crate::de::Error::RecursionLimitExceeded); + } }