Skip to content

Commit

Permalink
Merge pull request #353 from kesyog/kyogeswaran/deserialize-uints
Browse files Browse the repository at this point in the history
Use TryInto for more permissive deserialization for integers
  • Loading branch information
matthiasbeyer authored Jun 29, 2022
2 parents 8b41015 + 7db2e8b commit 2d74d06
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 49 deletions.
8 changes: 4 additions & 4 deletions src/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,25 +62,25 @@ impl<'de> de::Deserializer<'de> for Value {
#[inline]
fn deserialize_u8<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
// FIXME: This should *fail* if the value does not fit in the requets integer type
visitor.visit_u8(self.into_int()? as u8)
visitor.visit_u8(self.into_uint()? as u8)
}

#[inline]
fn deserialize_u16<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
// FIXME: This should *fail* if the value does not fit in the requets integer type
visitor.visit_u16(self.into_int()? as u16)
visitor.visit_u16(self.into_uint()? as u16)
}

#[inline]
fn deserialize_u32<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
// FIXME: This should *fail* if the value does not fit in the requets integer type
visitor.visit_u32(self.into_int()? as u32)
visitor.visit_u32(self.into_uint()? as u32)
}

#[inline]
fn deserialize_u64<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
// FIXME: This should *fail* if the value does not fit in the requets integer type
visitor.visit_u64(self.into_int()? as u64)
visitor.visit_u64(self.into_uint()? as u64)
}

#[inline]
Expand Down
109 changes: 64 additions & 45 deletions src/value.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::convert::TryInto;
use std::fmt;
use std::fmt::Display;

Expand Down Expand Up @@ -269,21 +270,27 @@ impl Value {
pub fn into_int(self) -> Result<i64> {
match self.kind {
ValueKind::I64(value) => Ok(value),
ValueKind::I128(value) => Err(ConfigError::invalid_type(
self.origin,
Unexpected::I128(value),
"an signed 64 bit or less integer",
)),
ValueKind::U64(value) => Err(ConfigError::invalid_type(
self.origin,
Unexpected::U64(value),
"an signed 64 bit or less integer",
)),
ValueKind::U128(value) => Err(ConfigError::invalid_type(
self.origin,
Unexpected::U128(value),
"an signed 64 bit or less integer",
)),
ValueKind::I128(value) => value.try_into().map_err(|_| {
ConfigError::invalid_type(
self.origin,
Unexpected::I128(value),
"an signed 64 bit or less integer",
)
}),
ValueKind::U64(value) => value.try_into().map_err(|_| {
ConfigError::invalid_type(
self.origin,
Unexpected::U64(value),
"an signed 64 bit or less integer",
)
}),
ValueKind::U128(value) => value.try_into().map_err(|_| {
ConfigError::invalid_type(
self.origin,
Unexpected::U128(value),
"an signed 64 bit or less integer",
)
}),

ValueKind::String(ref s) => {
match s.to_lowercase().as_ref() {
Expand Down Expand Up @@ -330,11 +337,13 @@ impl Value {
ValueKind::I64(value) => Ok(value.into()),
ValueKind::I128(value) => Ok(value),
ValueKind::U64(value) => Ok(value.into()),
ValueKind::U128(value) => Err(ConfigError::invalid_type(
self.origin,
Unexpected::U128(value),
"an signed 128 bit integer",
)),
ValueKind::U128(value) => value.try_into().map_err(|_| {
ConfigError::invalid_type(
self.origin,
Unexpected::U128(value),
"an signed 128 bit integer",
)
}),

ValueKind::String(ref s) => {
match s.to_lowercase().as_ref() {
Expand Down Expand Up @@ -380,21 +389,27 @@ impl Value {
pub fn into_uint(self) -> Result<u64> {
match self.kind {
ValueKind::U64(value) => Ok(value),
ValueKind::U128(value) => Err(ConfigError::invalid_type(
self.origin,
Unexpected::U128(value),
"an unsigned 64 bit or less integer",
)),
ValueKind::I64(value) => Err(ConfigError::invalid_type(
self.origin,
Unexpected::I64(value),
"an unsigned 64 bit or less integer",
)),
ValueKind::I128(value) => Err(ConfigError::invalid_type(
self.origin,
Unexpected::I128(value),
"an unsigned 64 bit or less integer",
)),
ValueKind::U128(value) => value.try_into().map_err(|_| {
ConfigError::invalid_type(
self.origin,
Unexpected::U128(value),
"an unsigned 64 bit or less integer",
)
}),
ValueKind::I64(value) => value.try_into().map_err(|_| {
ConfigError::invalid_type(
self.origin,
Unexpected::I64(value),
"an unsigned 64 bit or less integer",
)
}),
ValueKind::I128(value) => value.try_into().map_err(|_| {
ConfigError::invalid_type(
self.origin,
Unexpected::I128(value),
"an unsigned 64 bit or less integer",
)
}),

ValueKind::String(ref s) => {
match s.to_lowercase().as_ref() {
Expand Down Expand Up @@ -440,16 +455,20 @@ impl Value {
match self.kind {
ValueKind::U64(value) => Ok(value.into()),
ValueKind::U128(value) => Ok(value),
ValueKind::I64(value) => Err(ConfigError::invalid_type(
self.origin,
Unexpected::I64(value),
"an unsigned 128 bit or less integer",
)),
ValueKind::I128(value) => Err(ConfigError::invalid_type(
self.origin,
Unexpected::I128(value),
"an unsigned 128 bit or less integer",
)),
ValueKind::I64(value) => value.try_into().map_err(|_| {
ConfigError::invalid_type(
self.origin,
Unexpected::I64(value),
"an unsigned 128 bit or less integer",
)
}),
ValueKind::I128(value) => value.try_into().map_err(|_| {
ConfigError::invalid_type(
self.origin,
Unexpected::I128(value),
"an unsigned 128 bit or less integer",
)
}),

ValueKind::String(ref s) => {
match s.to_lowercase().as_ref() {
Expand Down
73 changes: 73 additions & 0 deletions tests/env.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,39 @@ fn test_parse_int() {
})
}

#[test]
fn test_parse_uint() {
// using a struct in an enum here to make serde use `deserialize_any`
#[derive(Deserialize, Debug)]
#[serde(tag = "tag")]
enum TestUintEnum {
Uint(TestUint),
}

#[derive(Deserialize, Debug)]
struct TestUint {
int_val: u32,
}

temp_env::with_var("INT_VAL", Some("42"), || {
let environment = Environment::default().try_parsing(true);

let config = Config::builder()
.set_default("tag", "Uint")
.unwrap()
.add_source(environment)
.build()
.unwrap();

let config: TestUintEnum = config.try_deserialize().unwrap();

assert!(matches!(
config,
TestUintEnum::Uint(TestUint { int_val: 42 })
));
})
}

#[test]
fn test_parse_float() {
// using a struct in an enum here to make serde use `deserialize_any`
Expand Down Expand Up @@ -535,3 +568,43 @@ fn test_parse_off_string() {
}
})
}

#[test]
fn test_parse_int_default() {
#[derive(Deserialize, Debug)]
struct TestInt {
int_val: i32,
}

let environment = Environment::default().try_parsing(true);

let config = Config::builder()
.set_default("int_val", 42_i32)
.unwrap()
.add_source(environment)
.build()
.unwrap();

let config: TestInt = config.try_deserialize().unwrap();
assert_eq!(config.int_val, 42);
}

#[test]
fn test_parse_uint_default() {
#[derive(Deserialize, Debug)]
struct TestUint {
int_val: u32,
}

let environment = Environment::default().try_parsing(true);

let config = Config::builder()
.set_default("int_val", 42_u32)
.unwrap()
.add_source(environment)
.build()
.unwrap();

let config: TestUint = config.try_deserialize().unwrap();
assert_eq!(config.int_val, 42);
}
17 changes: 17 additions & 0 deletions tests/integer_range.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,20 @@ fn nonwrapping_u32() {
let port: u32 = c.get("settings.port").unwrap();
assert_eq!(port, 66000);
}

#[test]
#[should_panic]
fn invalid_signedness() {
let c = Config::builder()
.add_source(config::File::from_str(
r#"
[settings]
port = -1
"#,
config::FileFormat::Toml,
))
.build()
.unwrap();

let _: u32 = c.get("settings.port").unwrap();
}

0 comments on commit 2d74d06

Please sign in to comment.