diff --git a/pbjson-test/src/lib.rs b/pbjson-test/src/lib.rs index f4a22a8..f1e8d57 100644 --- a/pbjson-test/src/lib.rs +++ b/pbjson-test/src/lib.rs @@ -63,7 +63,16 @@ mod tests { } fn verify_decode(decoded: &KitchenSink, expected: &str) { + // Decode from a string first assert_eq!(decoded, &serde_json::from_str(expected).unwrap()); + + // Then, try decoding from a Reader: this can catch issues when trying to borrow data + // from the input, which is not possible when deserializing from a Reader (e.g. an opened + // file). + assert_eq!( + decoded, + &serde_json::from_reader(expected.as_bytes()).unwrap() + ); } fn verify(decoded: &KitchenSink, expected: &str) { diff --git a/pbjson-types/src/duration.rs b/pbjson-types/src/duration.rs index d89d84f..bc2801d 100644 --- a/pbjson-types/src/duration.rs +++ b/pbjson-types/src/duration.rs @@ -1,5 +1,6 @@ use crate::Duration; -use serde::{Deserialize, Serialize}; +use serde::de::Visitor; +use serde::Serialize; impl TryFrom for std::time::Duration { type Error = std::num::TryFromIntError; @@ -55,12 +56,19 @@ impl Serialize for Duration { } } -impl<'de> serde::Deserialize<'de> for Duration { - fn deserialize(deserializer: D) -> Result +struct DurationVisitor; + +impl<'de> Visitor<'de> for DurationVisitor { + type Value = Duration; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("a duration string") + } + + fn visit_str(self, s: &str) -> Result where - D: serde::Deserializer<'de>, + E: serde::de::Error, { - let s: &str = Deserialize::deserialize(deserializer)?; let s = s .strip_suffix('s') .ok_or_else(|| serde::de::Error::custom("missing 's' suffix"))?; @@ -70,7 +78,7 @@ impl<'de> serde::Deserialize<'de> for Duration { None => (false, s), }; - let duration: Self = match s.split_once('.') { + let duration = match s.split_once('.') { Some((seconds_str, decimal_str)) => { let exp = 9_u32 .checked_sub(decimal_str.len() as u32) @@ -80,19 +88,19 @@ impl<'de> serde::Deserialize<'de> for Duration { let seconds = seconds_str.parse().map_err(serde::de::Error::custom)?; let decimal: u32 = decimal_str.parse().map_err(serde::de::Error::custom)?; - Self { + Duration { seconds, nanos: (decimal * pow) as i32, } } - None => Self { + None => Duration { seconds: s.parse().map_err(serde::de::Error::custom)?, nanos: 0, }, }; Ok(match negative { - true => Self { + true => Duration { seconds: -duration.seconds, nanos: -duration.nanos, }, @@ -101,6 +109,15 @@ impl<'de> serde::Deserialize<'de> for Duration { } } +impl<'de> serde::Deserialize<'de> for Duration { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_str(DurationVisitor) + } +} + /// Splits nanoseconds into whole milliseconds, microseconds, and nanoseconds fn split_nanos(mut nanos: u32) -> (u32, u32, u32) { let millis = nanos / 1_000_000; diff --git a/pbjson-types/src/timestamp.rs b/pbjson-types/src/timestamp.rs index 7be523f..7f2d22f 100644 --- a/pbjson-types/src/timestamp.rs +++ b/pbjson-types/src/timestamp.rs @@ -1,6 +1,7 @@ use crate::Timestamp; use chrono::{DateTime, NaiveDateTime, Utc}; -use serde::{Deserialize, Serialize}; +use serde::de::Visitor; +use serde::Serialize; impl TryFrom for chrono::DateTime { type Error = std::num::TryFromIntError; @@ -31,23 +32,40 @@ impl Serialize for Timestamp { } } -impl<'de> serde::Deserialize<'de> for Timestamp { - fn deserialize(deserializer: D) -> Result +struct TimestampVisitor; + +impl<'de> Visitor<'de> for TimestampVisitor { + type Value = Timestamp; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("a date string") + } + + fn visit_str(self, s: &str) -> Result where - D: serde::Deserializer<'de>, + E: serde::de::Error, { - let s: &str = Deserialize::deserialize(deserializer)?; let d = DateTime::parse_from_rfc3339(s).map_err(serde::de::Error::custom)?; let d: DateTime = d.into(); Ok(d.into()) } } +impl<'de> serde::Deserialize<'de> for Timestamp { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_str(TimestampVisitor) + } +} + #[cfg(test)] mod tests { use super::*; use chrono::{FixedOffset, TimeZone}; use serde::de::value::{BorrowedStrDeserializer, Error}; + use serde::Deserialize; #[test] fn test_date() { diff --git a/pbjson/src/lib.rs b/pbjson/src/lib.rs index b2f7122..0e17df0 100644 --- a/pbjson/src/lib.rs +++ b/pbjson/src/lib.rs @@ -22,7 +22,9 @@ pub mod private { /// Re-export base64 pub use base64; + use serde::de::Visitor; use serde::Deserialize; + use std::borrow::Cow; use std::str::FromStr; /// Used to parse a number from either a string or its raw representation @@ -32,7 +34,8 @@ pub mod private { #[derive(Deserialize)] #[serde(untagged)] enum Content<'a, T> { - Str(&'a str), + #[serde(borrow)] + Str(Cow<'a, str>), Number(T), } @@ -53,19 +56,19 @@ pub mod private { } } - #[derive(Debug, Copy, Clone, PartialOrd, PartialEq, Hash, Ord, Eq)] - pub struct BytesDeserialize(pub T); + struct Base64Visitor; - impl<'de, T> Deserialize<'de> for BytesDeserialize - where - T: From>, - { - fn deserialize(deserializer: D) -> Result + impl<'de> Visitor<'de> for Base64Visitor { + type Value = Vec; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("a base64 string") + } + + fn visit_str(self, s: &str) -> Result where - D: serde::Deserializer<'de>, + E: serde::de::Error, { - let s: &str = Deserialize::deserialize(deserializer)?; - let decoded = base64::decode_config(s, base64::STANDARD) .or_else(|e| match e { // Either standard or URL-safe base64 encoding are accepted @@ -80,8 +83,22 @@ pub mod private { _ => Err(e), }) .map_err(serde::de::Error::custom)?; + Ok(decoded) + } + } + + #[derive(Debug, Copy, Clone, PartialOrd, PartialEq, Hash, Ord, Eq)] + pub struct BytesDeserialize(pub T); - Ok(Self(decoded.into())) + impl<'de, T> Deserialize<'de> for BytesDeserialize + where + T: From>, + { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + Ok(Self(deserializer.deserialize_str(Base64Visitor)?.into())) } }