From 77e057e00a55107b1287c6802106c79a41551d18 Mon Sep 17 00:00:00 2001 From: Ross Younger Date: Tue, 24 Dec 2024 20:43:58 +1300 Subject: [PATCH] refactor: overhaul & simplify internal types - AddressFamily - - rename enum variants to Inet / Inet6 / Any - - serialize as "inet4" / "inet6" to align with ssh config - - accept aliases "4" / "6" - - no need to use IntOrString now we've changed config file format - PortRange - - no need to deserialise via IntOrString - - drop From - - derive Default - HumanU64 - - no need to deserialize from IntOrString - - ser/de as string for now - drop IntOrString as no longer used --- src/cli/args.rs | 4 +- src/util/address_family.rs | 139 +++++++++++++------------------------ src/util/cli.rs | 51 -------------- src/util/dns.rs | 4 +- src/util/humanu64.rs | 34 ++++----- src/util/mod.rs | 2 - src/util/port_range.rs | 25 ++----- 7 files changed, 69 insertions(+), 190 deletions(-) delete mode 100644 src/util/cli.rs diff --git a/src/cli/args.rs b/src/cli/args.rs index d6be003..5589a7d 100644 --- a/src/cli/args.rs +++ b/src/cli/args.rs @@ -100,9 +100,9 @@ impl CliArgs { CliArgs::from_arg_matches(&cli.get_matches_from(std::env::args_os())).unwrap(); // Custom logic: '-4' and '-6' convenience aliases if args.ipv4_alias__ { - args.config.address_family = Some(AddressFamily::V4); + args.config.address_family = Some(AddressFamily::Inet); } else if args.ipv6_alias__ { - args.config.address_family = Some(AddressFamily::V6); + args.config.address_family = Some(AddressFamily::Inet6); } args } diff --git a/src/util/address_family.rs b/src/util/address_family.rs index e342b4f..2b340c0 100644 --- a/src/util/address_family.rs +++ b/src/util/address_family.rs @@ -1,142 +1,97 @@ //! CLI helper - Address family // (c) 2024 Ross Younger -use std::fmt::Display; -use std::marker::PhantomData; use std::str::FromStr; -use figment::error::Actual; -use serde::Serialize; +use figment::error::{Actual, OneOf}; +use serde::{de, Deserialize, Serialize}; -use crate::util::cli::IntOrString; - -/// Representation an IP address family +/// Representation of an IP address family /// -/// This is a local type with special parsing semantics to take part in the config/CLI system. -#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)] +/// This is a local type with special parsing semantics and aliasing to take part in the config/CLI system. +#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum, Serialize)] +#[serde(rename_all = "kebab-case")] // to match clap::ValueEnum pub enum AddressFamily { /// IPv4 - #[value(name = "4")] - V4, + #[value(alias("4"), alias("inet4"))] + Inet, /// IPv6 - #[value(name = "6")] - V6, + #[value(alias("6"))] + Inet6, /// We don't mind what type of IP address Any, } -impl From for u8 { - fn from(value: AddressFamily) -> Self { - match value { - AddressFamily::V4 => 4, - AddressFamily::V6 => 6, - AddressFamily::Any => 0, - } - } -} - -impl Serialize for AddressFamily { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - match *self { - AddressFamily::Any => serializer.serialize_str("any"), - t => serializer.serialize_u8(u8::from(t)), - } - } -} - -impl Display for AddressFamily { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - if *self == AddressFamily::Any { - write!(f, "any") - } else { - write!(f, "{}", u8::from(*self)) - } - } -} - impl FromStr for AddressFamily { type Err = figment::Error; fn from_str(s: &str) -> Result { - if s == "4" { - Ok(AddressFamily::V4) - } else if s == "6" { - Ok(AddressFamily::V6) - } else if s == "0" || s == "any" { - Ok(AddressFamily::Any) - } else { - Err(figment::error::Kind::InvalidType(Actual::Str(s.into()), "4 or 6".into()).into()) - } - } -} - -impl TryFrom for AddressFamily { - type Error = figment::Error; - - fn try_from(value: u64) -> Result { - match value { - 4 => Ok(AddressFamily::V4), - 6 => Ok(AddressFamily::V6), - 0 => Ok(AddressFamily::Any), - _ => Err(figment::error::Kind::InvalidValue( - Actual::Unsigned(value.into()), - "4 or 6".into(), + match s { + "4" | "inet" | "inet4" => Ok(AddressFamily::Inet), + "6" | "inet6" => Ok(AddressFamily::Inet6), + "any" => Ok(AddressFamily::Any), + _ => Err(figment::error::Kind::InvalidType( + Actual::Str(s.into()), + OneOf(&["inet", "4", "inet6", "6"]).to_string(), ) .into()), } } } -impl<'de> serde::Deserialize<'de> for AddressFamily { +impl<'de> Deserialize<'de> for AddressFamily { fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de>, { - deserializer.deserialize_any(IntOrString(PhantomData)) + let s = String::deserialize(deserializer)?; + FromStr::from_str(&s).map_err(de::Error::custom) } } #[cfg(test)] mod test { + use std::str::FromStr; + use super::AddressFamily; #[test] fn serialize() { - let a = AddressFamily::V4; - let b = AddressFamily::V6; + let a = AddressFamily::Inet; + let b = AddressFamily::Inet6; + let c = AddressFamily::Any; let aa = serde_json::to_string(&a); let bb = serde_json::to_string(&b); - assert_eq!(aa.unwrap(), "4"); - assert_eq!(bb.unwrap(), "6"); + let cc = serde_json::to_string(&c); + assert_eq!(aa.unwrap(), "\"inet\""); + assert_eq!(bb.unwrap(), "\"inet6\""); + assert_eq!(cc.unwrap(), "\"any\""); } #[test] fn deser_str() { - let a: AddressFamily = serde_json::from_str(r#" "4" "#).unwrap(); - assert_eq!(a, AddressFamily::V4); - let a: AddressFamily = serde_json::from_str(r#" "6" "#).unwrap(); - assert_eq!(a, AddressFamily::V6); - } - - #[test] - fn deser_int() { - let a: AddressFamily = serde_json::from_str("4").unwrap(); - assert_eq!(a, AddressFamily::V4); - let a: AddressFamily = serde_json::from_str("6").unwrap(); - assert_eq!(a, AddressFamily::V6); + use AddressFamily::*; + for (str, expected) in &[ + ("4", Inet), + ("inet", Inet), + ("inet4", Inet), + ("6", Inet6), + ("inet6", Inet6), + ("any", Any), + ] { + let raw = AddressFamily::from_str(str).expect(str); + let json = format!(r#""{str}""#); + let output = serde_json::from_str::(&json).expect(str); + assert_eq!(raw, *expected); + assert_eq!(output, *expected); + } } #[test] fn deser_invalid() { - let _ = serde_json::from_str::("true").unwrap_err(); - let _ = serde_json::from_str::("5").unwrap_err(); - let _ = serde_json::from_str::(r#" "5" "#).unwrap_err(); - let _ = serde_json::from_str::("-1").unwrap_err(); - let _ = serde_json::from_str::(r#" "42" "#).unwrap_err(); - let _ = serde_json::from_str::(r#" "string" "#).unwrap_err(); + for s in &["true", "5", r#""5""#, "-1", r#""42"#, r#""string"#] { + let _ = serde_json::from_str::(s).expect_err(s); + } } } diff --git a/src/util/cli.rs b/src/util/cli.rs deleted file mode 100644 index d34a38b..0000000 --- a/src/util/cli.rs +++ /dev/null @@ -1,51 +0,0 @@ -//! CLI generic serialization helpers -// (c) 2024 Ross Younger - -use std::{fmt, marker::PhantomData, str::FromStr}; - -use serde::{de, de::Visitor, Deserialize}; - -/// Deserialization helper for types which might reasonably be expressed as an -/// integer or a string. -/// -/// This is a Visitor that forwards string types to T's `FromStr` impl and -/// forwards int types to T's `From` or `From` impls. The `PhantomData` is to -/// keep the compiler from complaining about T being an unused generic type -/// parameter. We need T in order to know the Value type for the Visitor -/// impl. -#[allow(missing_debug_implementations)] -pub struct IntOrString(pub PhantomData T>); - -impl<'de, T> Visitor<'de> for IntOrString -where - T: Deserialize<'de> + TryFrom + FromStr, - ::Err: std::fmt::Display, - >::Error: std::fmt::Display, -{ - type Value = T; - fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { - formatter.write_str("int or string") - } - - fn visit_str(self, value: &str) -> Result - where - E: de::Error, - { - T::from_str(value).map_err(de::Error::custom) - } - - fn visit_u64(self, value: u64) -> Result - where - E: de::Error, - { - T::try_from(value).map_err(de::Error::custom) - } - - fn visit_i64(self, value: i64) -> Result - where - E: de::Error, - { - let u = u64::try_from(value).map_err(de::Error::custom)?; - T::try_from(u).map_err(de::Error::custom) - } -} diff --git a/src/util/dns.rs b/src/util/dns.rs index 875a34d..70a7e12 100644 --- a/src/util/dns.rs +++ b/src/util/dns.rs @@ -19,8 +19,8 @@ pub fn lookup_host_by_family(host: &str, desired: AddressFamily) -> anyhow::Resu let found = match desired { AddressFamily::Any => it.next(), - AddressFamily::V4 => it.find(|addr| addr.is_ipv4()), - AddressFamily::V6 => it.find(|addr| addr.is_ipv6()), + AddressFamily::Inet => it.find(|addr| addr.is_ipv4()), + AddressFamily::Inet6 => it.find(|addr| addr.is_ipv6()), }; found .map(std::borrow::ToOwned::to_owned) diff --git a/src/util/humanu64.rs b/src/util/humanu64.rs index b9e97df..ad0445c 100644 --- a/src/util/humanu64.rs +++ b/src/util/humanu64.rs @@ -1,7 +1,7 @@ //! Serialization helper type - u64 parseable by humanize_rs // (c) 2024 Ross Younger -use std::{marker::PhantomData, ops::Deref, str::FromStr}; +use std::{ops::Deref, str::FromStr}; use humanize_rs::bytes::Bytes; use serde::{ @@ -9,15 +9,13 @@ use serde::{ Serialize, }; -use super::cli::IntOrString; - -/// An integer field that may also be expressed using engineering prefixes (k, M, G, etc). +/// An integer field that may also be expressed using SI notation (k, M, G, etc). /// For example, `1k` and `1000` are the same. /// /// (Nerdy description: This is a newtype wrapper to `u64` that adds a flexible deserializer via `humanize_rs::bytes::Bytes`.) #[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize)] -#[serde(from = "IntOrString", into = "u64")] +#[serde(from = "String", into = "String")] pub struct HumanU64(pub u64); impl HumanU64 { @@ -42,6 +40,12 @@ impl From for u64 { } } +impl From for String { + fn from(value: HumanU64) -> Self { + format!("{}", *value) + } +} + impl FromStr for HumanU64 { type Err = figment::Error; @@ -71,14 +75,8 @@ impl<'de> serde::Deserialize<'de> for HumanU64 { where D: serde::Deserializer<'de>, { - deserializer.deserialize_any(IntOrString(PhantomData)) - } -} - -#[cfg(test)] -impl rand::prelude::Distribution for rand::distributions::Standard { - fn sample(&self, rng: &mut R) -> HumanU64 { - rng.gen::().into() + let s = String::deserialize(deserializer)?; + FromStr::from_str(&s).map_err(de::Error::custom) } } @@ -105,21 +103,15 @@ mod test { test_deser_str("\"100k\"", 100_000); } - #[test] - fn deser_raw_int() { - let foo: HumanU64 = serde_json::from_str("12345").unwrap(); - assert_eq!(*foo, 12345); - } - #[test] fn serde_test() { let bw = HumanU64::new(42); - assert_tokens(&bw, &[Token::U64(42)]); + assert_tokens(&bw, &[Token::Str("42")]); } #[test] fn from_int() { - let result = HumanU64::from(12345); + let result = HumanU64::new(12345); assert_eq!(*result, 12345); } #[test] diff --git a/src/util/mod.rs b/src/util/mod.rs index fce6cc0..a4d9b53 100644 --- a/src/util/mod.rs +++ b/src/util/mod.rs @@ -10,8 +10,6 @@ pub use dns::lookup_host_by_family; mod cert; pub use cert::Credentials; -pub mod cli; - pub mod humanu64; pub mod io; pub mod socket; diff --git a/src/util/port_range.rs b/src/util/port_range.rs index 7e9b356..ddb1a48 100644 --- a/src/util/port_range.rs +++ b/src/util/port_range.rs @@ -1,13 +1,11 @@ /// CLI argument helper - PortRange // (c) 2024 Ross Younger use serde::{ - de::{Error, Unexpected}, + de::{self, Error, Unexpected}, Serialize, }; use std::{fmt::Display, str::FromStr}; -use super::cli::IntOrString; - /// A range of UDP port numbers. /// /// Port 0 is allowed with the usual meaning ("any available port"), but 0 may not form part of a range. @@ -17,8 +15,8 @@ use super::cli::IntOrString; /// remote_port 60000 # a single port /// remote_port 60000-60010 # a range /// ``` -#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize)] -#[serde(from = "IntOrString", into = "String")] +#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Serialize)] +#[serde(from = "String", into = "String")] pub struct PortRange { /// First number in the range pub begin: u16, @@ -83,20 +81,6 @@ impl FromStr for PortRange { } } -impl From for PortRange { - fn from(value: u64) -> Self { - #[allow(clippy::cast_possible_truncation)] - let v = value as u16; - PortRange { begin: v, end: v } - } -} - -impl Default for PortRange { - fn default() -> Self { - Self::from(0) - } -} - impl PortRange { pub(crate) fn is_default(self) -> bool { self.begin == 0 && self.begin == self.end @@ -108,7 +92,8 @@ impl<'de> serde::Deserialize<'de> for PortRange { where D: serde::Deserializer<'de>, { - deserializer.deserialize_any(IntOrString(std::marker::PhantomData)) + let s = String::deserialize(deserializer)?; + FromStr::from_str(&s).map_err(de::Error::custom) } }