From 610d7b290d90bfadbc3e7c32df94cc238574afd7 Mon Sep 17 00:00:00 2001 From: David Tolnay Date: Mon, 17 Jul 2023 16:11:25 -0700 Subject: [PATCH 1/2] Implement FromStr for serde_yaml::Number --- src/de.rs | 6 ++--- src/error.rs | 2 ++ src/number.rs | 70 ++++++++++++++++++++++++++++++++------------------- 3 files changed, 49 insertions(+), 29 deletions(-) diff --git a/src/de.rs b/src/de.rs index e73de650..43f46781 100644 --- a/src/de.rs +++ b/src/de.rs @@ -1063,7 +1063,7 @@ fn parse_negative_int( from_str_radix(scalar, 10).ok() } -fn parse_f64(scalar: &str) -> Option { +pub(crate) fn parse_f64(scalar: &str) -> Option { let unpositive = if let Some(unpositive) = scalar.strip_prefix('+') { if unpositive.starts_with(['+', '-']) { return None; @@ -1089,14 +1089,14 @@ fn parse_f64(scalar: &str) -> Option { None } -fn digits_but_not_number(scalar: &str) -> bool { +pub(crate) fn digits_but_not_number(scalar: &str) -> bool { // Leading zero(s) followed by numeric characters is a string according to // the YAML 1.2 spec. https://yaml.org/spec/1.2/spec.html#id2761292 let scalar = scalar.strip_prefix(['-', '+']).unwrap_or(scalar); scalar.len() > 1 && scalar.starts_with('0') && scalar[1..].bytes().all(|b| b.is_ascii_digit()) } -fn visit_int<'de, V>(visitor: V, v: &str) -> Result, V> +pub(crate) fn visit_int<'de, V>(visitor: V, v: &str) -> Result, V> where V: Visitor<'de>, { diff --git a/src/error.rs b/src/error.rs index 92c71252..01f8ed12 100644 --- a/src/error.rs +++ b/src/error.rs @@ -34,6 +34,7 @@ pub(crate) enum ErrorImpl { ScalarInMergeElement, SequenceInMergeElement, EmptyTag, + FailedToParseNumber, Shared(Arc), } @@ -239,6 +240,7 @@ impl ErrorImpl { f.write_str("expected a mapping for merging, but found sequence") } ErrorImpl::EmptyTag => f.write_str("empty YAML tag is not allowed"), + ErrorImpl::FailedToParseNumber => f.write_str("failed to parse YAML number"), ErrorImpl::Shared(_) => unreachable!(), } } diff --git a/src/number.rs b/src/number.rs index 6a5ea263..3b0293b6 100644 --- a/src/number.rs +++ b/src/number.rs @@ -1,9 +1,11 @@ -use crate::Error; +use crate::de; +use crate::error::{self, Error, ErrorImpl}; use serde::de::{Unexpected, Visitor}; use serde::{forward_to_deserialize_any, Deserialize, Deserializer, Serialize, Serializer}; use std::cmp::Ordering; use std::fmt::{self, Display}; use std::hash::{Hash, Hasher}; +use std::str::FromStr; /// Represents a YAML number, whether integer or floating point. #[derive(Clone, PartialEq, PartialOrd)] @@ -308,6 +310,22 @@ impl Display for Number { } } +impl FromStr for Number { + type Err = Error; + + fn from_str(repr: &str) -> Result { + if let Ok(result) = de::visit_int(NumberVisitor, repr) { + return result; + } + if !de::digits_but_not_number(repr) { + if let Some(float) = de::parse_f64(repr) { + return Ok(float.into()); + } + } + Err(error::new(ErrorImpl::FailedToParseNumber)) + } +} + impl PartialEq for N { fn eq(&self, other: &N) -> bool { match (*self, *other) { @@ -389,37 +407,37 @@ impl Serialize for Number { } } -impl<'de> Deserialize<'de> for Number { - #[inline] - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - struct NumberVisitor; +struct NumberVisitor; - impl<'de> Visitor<'de> for NumberVisitor { - type Value = Number; +impl<'de> Visitor<'de> for NumberVisitor { + type Value = Number; - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("a number") - } + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a number") + } - #[inline] - fn visit_i64(self, value: i64) -> Result { - Ok(value.into()) - } + #[inline] + fn visit_i64(self, value: i64) -> Result { + Ok(value.into()) + } - #[inline] - fn visit_u64(self, value: u64) -> Result { - Ok(value.into()) - } + #[inline] + fn visit_u64(self, value: u64) -> Result { + Ok(value.into()) + } - #[inline] - fn visit_f64(self, value: f64) -> Result { - Ok(value.into()) - } - } + #[inline] + fn visit_f64(self, value: f64) -> Result { + Ok(value.into()) + } +} +impl<'de> Deserialize<'de> for Number { + #[inline] + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { deserializer.deserialize_any(NumberVisitor) } } From 3c681651e10c33f3a8065b4e3b64c903c73f5252 Mon Sep 17 00:00:00 2001 From: David Tolnay Date: Mon, 17 Jul 2023 16:15:04 -0700 Subject: [PATCH 2/2] Add test of Number parsing --- tests/test_de.rs | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/tests/test_de.rs b/tests/test_de.rs index 24845d32..840188fb 100644 --- a/tests/test_de.rs +++ b/tests/test_de.rs @@ -8,7 +8,7 @@ use indoc::indoc; use serde_derive::Deserialize; -use serde_yaml::{Deserializer, Value}; +use serde_yaml::{Deserializer, Number, Value}; use std::collections::BTreeMap; use std::fmt::Debug; @@ -676,3 +676,30 @@ fn test_tag_resolution() { test_de(yaml, &expected); } + +#[test] +fn test_parse_number() { + let n = "111".parse::().unwrap(); + assert_eq!(n, Number::from(111)); + + let n = "-111".parse::().unwrap(); + assert_eq!(n, Number::from(-111)); + + let n = "-1.1".parse::().unwrap(); + assert_eq!(n, Number::from(-1.1)); + + let n = ".nan".parse::().unwrap(); + assert_eq!(n, Number::from(f64::NAN)); + + let n = ".inf".parse::().unwrap(); + assert_eq!(n, Number::from(f64::INFINITY)); + + let n = "-.inf".parse::().unwrap(); + assert_eq!(n, Number::from(f64::NEG_INFINITY)); + + let err = "null".parse::().unwrap_err(); + assert_eq!(err.to_string(), "failed to parse YAML number"); + + let err = " 1 ".parse::().unwrap_err(); + assert_eq!(err.to_string(), "failed to parse YAML number"); +}