Skip to content

Commit

Permalink
fix: properly handle deserialization from Reader
Browse files Browse the repository at this point in the history
The pattern "let s: &str = Deserialize::deserialize(deser)" used
in multiple places is not ideal, as it generates errors when the
deserializer uses an IO Reader and not a string. To fix this,
implementing a visitor is preferred, as it gives the deserializer the
choice of allocating or not a string.

To ensure the issue is fixed, the tests now deserialize both from a
string and from a reader.

Closes influxdata#55
  • Loading branch information
vthib committed Sep 27, 2022
1 parent e2e073e commit 71fe149
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 26 deletions.
9 changes: 9 additions & 0 deletions pbjson-test/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,16 @@ mod tests {
}

fn verify_decode(decoded: &KitchenSink, expected: &str) {
// Decode from a string first
assert_eq!(decoded, &serde_json::from_str(expected).unwrap());

// Then, try decoding from a Reader: this can catch issues when trying to borrow data
// from the input, which is not possible when deserializing from a Reader (e.g. an opened
// file).
assert_eq!(
decoded,
&serde_json::from_reader(expected.as_bytes()).unwrap()
);
}

fn verify(decoded: &KitchenSink, expected: &str) {
Expand Down
35 changes: 26 additions & 9 deletions pbjson-types/src/duration.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::Duration;
use serde::{Deserialize, Serialize};
use serde::de::Visitor;
use serde::Serialize;

impl TryFrom<Duration> for std::time::Duration {
type Error = std::num::TryFromIntError;
Expand Down Expand Up @@ -55,12 +56,19 @@ impl Serialize for Duration {
}
}

impl<'de> serde::Deserialize<'de> for Duration {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
struct DurationVisitor;

impl<'de> Visitor<'de> for DurationVisitor {
type Value = Duration;

fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
formatter.write_str("a duration string")
}

fn visit_str<E>(self, s: &str) -> Result<Self::Value, E>
where
D: serde::Deserializer<'de>,
E: serde::de::Error,
{
let s: &str = Deserialize::deserialize(deserializer)?;
let s = s
.strip_suffix('s')
.ok_or_else(|| serde::de::Error::custom("missing 's' suffix"))?;
Expand All @@ -70,7 +78,7 @@ impl<'de> serde::Deserialize<'de> for Duration {
None => (false, s),
};

let duration: Self = match s.split_once('.') {
let duration = match s.split_once('.') {
Some((seconds_str, decimal_str)) => {
let exp = 9_u32
.checked_sub(decimal_str.len() as u32)
Expand All @@ -80,19 +88,19 @@ impl<'de> serde::Deserialize<'de> for Duration {
let seconds = seconds_str.parse().map_err(serde::de::Error::custom)?;
let decimal: u32 = decimal_str.parse().map_err(serde::de::Error::custom)?;

Self {
Duration {
seconds,
nanos: (decimal * pow) as i32,
}
}
None => Self {
None => Duration {
seconds: s.parse().map_err(serde::de::Error::custom)?,
nanos: 0,
},
};

Ok(match negative {
true => Self {
true => Duration {
seconds: -duration.seconds,
nanos: -duration.nanos,
},
Expand All @@ -101,6 +109,15 @@ impl<'de> serde::Deserialize<'de> for Duration {
}
}

impl<'de> serde::Deserialize<'de> for Duration {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
deserializer.deserialize_str(DurationVisitor)
}
}

/// Splits nanoseconds into whole milliseconds, microseconds, and nanoseconds
fn split_nanos(mut nanos: u32) -> (u32, u32, u32) {
let millis = nanos / 1_000_000;
Expand Down
28 changes: 23 additions & 5 deletions pbjson-types/src/timestamp.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::Timestamp;
use chrono::{DateTime, NaiveDateTime, Utc};
use serde::{Deserialize, Serialize};
use serde::de::Visitor;
use serde::Serialize;

impl TryFrom<Timestamp> for chrono::DateTime<Utc> {
type Error = std::num::TryFromIntError;
Expand Down Expand Up @@ -31,23 +32,40 @@ impl Serialize for Timestamp {
}
}

impl<'de> serde::Deserialize<'de> for Timestamp {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
struct TimestampVisitor;

impl<'de> Visitor<'de> for TimestampVisitor {
type Value = Timestamp;

fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
formatter.write_str("a date string")
}

fn visit_str<E>(self, s: &str) -> Result<Self::Value, E>
where
D: serde::Deserializer<'de>,
E: serde::de::Error,
{
let s: &str = Deserialize::deserialize(deserializer)?;
let d = DateTime::parse_from_rfc3339(s).map_err(serde::de::Error::custom)?;
let d: DateTime<Utc> = d.into();
Ok(d.into())
}
}

impl<'de> serde::Deserialize<'de> for Timestamp {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
deserializer.deserialize_str(TimestampVisitor)
}
}

#[cfg(test)]
mod tests {
use super::*;
use chrono::{FixedOffset, TimeZone};
use serde::de::value::{BorrowedStrDeserializer, Error};
use serde::Deserialize;

#[test]
fn test_date() {
Expand Down
41 changes: 29 additions & 12 deletions pbjson/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ pub mod private {
/// Re-export base64
pub use base64;

use serde::de::Visitor;
use serde::Deserialize;
use std::borrow::Cow;
use std::str::FromStr;

/// Used to parse a number from either a string or its raw representation
Expand All @@ -32,7 +34,8 @@ pub mod private {
#[derive(Deserialize)]
#[serde(untagged)]
enum Content<'a, T> {
Str(&'a str),
#[serde(borrow)]
Str(Cow<'a, str>),
Number(T),
}

Expand All @@ -53,19 +56,19 @@ pub mod private {
}
}

#[derive(Debug, Copy, Clone, PartialOrd, PartialEq, Hash, Ord, Eq)]
pub struct BytesDeserialize<T>(pub T);
struct Base64Visitor;

impl<'de, T> Deserialize<'de> for BytesDeserialize<T>
where
T: From<Vec<u8>>,
{
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
impl<'de> Visitor<'de> for Base64Visitor {
type Value = Vec<u8>;

fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
formatter.write_str("a base64 string")
}

fn visit_str<E>(self, s: &str) -> Result<Self::Value, E>
where
D: serde::Deserializer<'de>,
E: serde::de::Error,
{
let s: &str = Deserialize::deserialize(deserializer)?;

let decoded = base64::decode_config(s, base64::STANDARD)
.or_else(|e| match e {
// Either standard or URL-safe base64 encoding are accepted
Expand All @@ -80,8 +83,22 @@ pub mod private {
_ => Err(e),
})
.map_err(serde::de::Error::custom)?;
Ok(decoded)
}
}

#[derive(Debug, Copy, Clone, PartialOrd, PartialEq, Hash, Ord, Eq)]
pub struct BytesDeserialize<T>(pub T);

Ok(Self(decoded.into()))
impl<'de, T> Deserialize<'de> for BytesDeserialize<T>
where
T: From<Vec<u8>>,
{
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
Ok(Self(deserializer.deserialize_str(Base64Visitor)?.into()))
}
}

Expand Down

0 comments on commit 71fe149

Please sign in to comment.