From 46d67d9a7fe61e44b1f6993806e2c3ecf24d99f6 Mon Sep 17 00:00:00 2001 From: Konrad Borowski Date: Sat, 15 Sep 2018 12:41:11 +0200 Subject: [PATCH] Prevent too deep recursion --- src/de.rs | 32 ++++++++++++++++++++++-------- src/error.rs | 10 ++++++++++ tests/test_error.rs | 48 +++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 82 insertions(+), 8 deletions(-) diff --git a/src/de.rs b/src/de.rs index 24249324..bd31fffb 100644 --- a/src/de.rs +++ b/src/de.rs @@ -79,6 +79,7 @@ struct Deserializer<'a> { aliases: &'a BTreeMap, pos: &'a mut usize, path: Path<'a>, + remaining_depth: u8, } impl<'a> Deserializer<'a> { @@ -109,6 +110,7 @@ impl<'a> Deserializer<'a> { aliases: self.aliases, pos: pos, path: Path::Alias { parent: &self.path }, + remaining_depth: self.remaining_depth, }) } None => panic!("unresolved alias: {}", *pos), @@ -161,11 +163,11 @@ impl<'a> Deserializer<'a> { where V: Visitor<'de>, { - let (value, len) = { - let mut seq = SeqAccess { de: self, len: 0 }; + let (value, len) = self.recursion_check(|de| { + let mut seq = SeqAccess { de: de, len: 0 }; let value = visitor.visit_seq(&mut seq)?; - (value, seq.len) - }; + Ok((value, seq.len)) + })?; self.end_sequence(len)?; Ok(value) } @@ -174,15 +176,15 @@ impl<'a> Deserializer<'a> { where V: Visitor<'de>, { - let (value, len) = { + let (value, len) = self.recursion_check(|de| { let mut map = MapAccess { - de: &mut *self, + de: de, len: 0, key: None, }; let value = visitor.visit_map(&mut map)?; - (value, map.len) - }; + Ok((value, map.len)) + })?; self.end_mapping(len)?; Ok(value) } @@ -238,6 +240,16 @@ impl<'a> Deserializer<'a> { Err(de::Error::invalid_length(total, &ExpectedMap(len))) } } + + fn recursion_check Result, T>(&mut self, f: F) -> Result { + let previous_depth = self.remaining_depth; + self.remaining_depth = previous_depth + .checked_sub(1) + .ok_or_else(Error::recursion_limit_exceeded)?; + let result = f(self); + self.remaining_depth = previous_depth; + result + } } fn visit_scalar<'de, V>( @@ -303,6 +315,7 @@ impl<'de, 'a, 'r> de::SeqAccess<'de> for SeqAccess<'a, 'r> { parent: &self.de.path, index: self.len, }, + remaining_depth: self.de.remaining_depth, }; self.len += 1; seed.deserialize(&mut element_de).map(Some) @@ -357,6 +370,7 @@ impl<'de, 'a, 'r> de::MapAccess<'de> for MapAccess<'a, 'r> { parent: &self.de.path, } }, + remaining_depth: self.de.remaining_depth, }; seed.deserialize(&mut value_de) } @@ -409,6 +423,7 @@ impl<'de, 'a, 'r> de::EnumAccess<'de> for EnumAccess<'a, 'r> { parent: &self.de.path, key: variant, }, + remaining_depth: self.de.remaining_depth, }; Ok((ret, variant_visitor)) } @@ -949,6 +964,7 @@ where aliases: &loader.aliases, pos: &mut pos, path: Path::Root, + remaining_depth: 128, })?; if pos == loader.events.len() { Ok(t) diff --git a/src/error.rs b/src/error.rs index 5555632e..b950201c 100644 --- a/src/error.rs +++ b/src/error.rs @@ -41,6 +41,7 @@ pub enum ErrorImpl { EndOfStream, MoreThanOneDocument, + RecursionLimitExceeded, } #[derive(Debug)] @@ -157,6 +158,10 @@ impl Error { Error(Box::new(ErrorImpl::FromUtf8(err))) } + pub(crate) fn recursion_limit_exceeded() -> Error { + Error(Box::new(ErrorImpl::RecursionLimitExceeded)) + } + // Not public API. Should be pub(crate). #[doc(hidden)] pub fn fix_marker(mut self, marker: Marker, path: Path) -> Self { @@ -183,6 +188,7 @@ impl error::Error for Error { ErrorImpl::MoreThanOneDocument => { "deserializing from YAML containing more than one document is not supported" } + ErrorImpl::RecursionLimitExceeded => "recursion limit exceeded", } } @@ -218,6 +224,7 @@ impl Display for Error { ErrorImpl::MoreThanOneDocument => f.write_str( "deserializing from YAML containing more than one document is not supported", ), + ErrorImpl::RecursionLimitExceeded => f.write_str("recursion limit exceeded"), } } } @@ -241,6 +248,9 @@ impl Debug for Error { } ErrorImpl::EndOfStream => formatter.debug_tuple("EndOfStream").finish(), ErrorImpl::MoreThanOneDocument => formatter.debug_tuple("MoreThanOneDocument").finish(), + ErrorImpl::RecursionLimitExceeded => { + formatter.debug_tuple("RecursionLimitExceeded").finish() + } } } } diff --git a/tests/test_error.rs b/tests/test_error.rs index 796d9ba1..a20f7037 100644 --- a/tests/test_error.rs +++ b/tests/test_error.rs @@ -257,3 +257,51 @@ fn test_invalid_scalar_type() { let expected = "x: invalid type: unit value, expected an array of length 1 at line 2 column 1"; test_error::(yaml, expected); } + +#[test] +fn test_infinite_recursion_objects() { + #[derive(Deserialize, Debug)] + struct S { + x: Option>, + } + + let yaml = "&a {x: *a}"; + let expected = "recursion limit exceeded"; + test_error::(yaml, expected); +} + +#[test] +fn test_infinite_recursion_arrays() { + #[derive(Deserialize, Debug)] + struct S { + x: Option>, + } + + let yaml = "&a [*a]"; + let expected = "recursion limit exceeded"; + test_error::(yaml, expected); +} + +#[test] +fn test_finite_recursion_objects() { + #[derive(Deserialize, Debug)] + struct S { + x: Option>, + } + + let yaml = "{x:".repeat(1_000) + &"}".repeat(1_000); + let expected = "recursion limit exceeded"; + test_error::(&yaml, expected); +} + +#[test] +fn test_finite_recursion_arrays() { + #[derive(Deserialize, Debug)] + struct S { + x: Option>, + } + + let yaml = "[".repeat(1_000) + &"]".repeat(1_000); + let expected = "recursion limit exceeded"; + test_error::(&yaml, expected); +}