Skip to content

Commit

Permalink
refactor: overhaul & simplify internal types
Browse files Browse the repository at this point in the history
- AddressFamily
- - rename enum variants to Inet / Inet6 / Any
- - serialize as "inet4" / "inet6" to align with ssh config
- - accept aliases "4" / "6"
- - no need to use IntOrString now we've changed config file format
- PortRange
- - no need to deserialise via IntOrString
- - drop From<u64>
- - derive Default
- HumanU64
- - no need to deserialize from IntOrString
- - ser/de as string for now
- drop IntOrString as no longer used
  • Loading branch information
crazyscot committed Dec 24, 2024
1 parent c0fcef0 commit 77e057e
Show file tree
Hide file tree
Showing 7 changed files with 69 additions and 190 deletions.
4 changes: 2 additions & 2 deletions src/cli/args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,9 @@ impl CliArgs {
CliArgs::from_arg_matches(&cli.get_matches_from(std::env::args_os())).unwrap();
// Custom logic: '-4' and '-6' convenience aliases
if args.ipv4_alias__ {
args.config.address_family = Some(AddressFamily::V4);
args.config.address_family = Some(AddressFamily::Inet);
} else if args.ipv6_alias__ {
args.config.address_family = Some(AddressFamily::V6);
args.config.address_family = Some(AddressFamily::Inet6);
}
args
}
Expand Down
139 changes: 47 additions & 92 deletions src/util/address_family.rs
Original file line number Diff line number Diff line change
@@ -1,142 +1,97 @@
//! CLI helper - Address family
// (c) 2024 Ross Younger

use std::fmt::Display;
use std::marker::PhantomData;
use std::str::FromStr;

use figment::error::Actual;
use serde::Serialize;
use figment::error::{Actual, OneOf};
use serde::{de, Deserialize, Serialize};

use crate::util::cli::IntOrString;

/// Representation an IP address family
/// Representation of an IP address family
///
/// This is a local type with special parsing semantics to take part in the config/CLI system.
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
/// This is a local type with special parsing semantics and aliasing to take part in the config/CLI system.
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum, Serialize)]
#[serde(rename_all = "kebab-case")] // to match clap::ValueEnum
pub enum AddressFamily {
/// IPv4
#[value(name = "4")]
V4,
#[value(alias("4"), alias("inet4"))]
Inet,
/// IPv6
#[value(name = "6")]
V6,
#[value(alias("6"))]
Inet6,
/// We don't mind what type of IP address
Any,
}

impl From<AddressFamily> for u8 {
fn from(value: AddressFamily) -> Self {
match value {
AddressFamily::V4 => 4,
AddressFamily::V6 => 6,
AddressFamily::Any => 0,
}
}
}

impl Serialize for AddressFamily {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
match *self {
AddressFamily::Any => serializer.serialize_str("any"),
t => serializer.serialize_u8(u8::from(t)),
}
}
}

impl Display for AddressFamily {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if *self == AddressFamily::Any {
write!(f, "any")
} else {
write!(f, "{}", u8::from(*self))
}
}
}

impl FromStr for AddressFamily {
type Err = figment::Error;

fn from_str(s: &str) -> Result<Self, Self::Err> {
if s == "4" {
Ok(AddressFamily::V4)
} else if s == "6" {
Ok(AddressFamily::V6)
} else if s == "0" || s == "any" {
Ok(AddressFamily::Any)
} else {
Err(figment::error::Kind::InvalidType(Actual::Str(s.into()), "4 or 6".into()).into())
}
}
}

impl TryFrom<u64> for AddressFamily {
type Error = figment::Error;

fn try_from(value: u64) -> Result<Self, Self::Error> {
match value {
4 => Ok(AddressFamily::V4),
6 => Ok(AddressFamily::V6),
0 => Ok(AddressFamily::Any),
_ => Err(figment::error::Kind::InvalidValue(
Actual::Unsigned(value.into()),
"4 or 6".into(),
match s {
"4" | "inet" | "inet4" => Ok(AddressFamily::Inet),
"6" | "inet6" => Ok(AddressFamily::Inet6),
"any" => Ok(AddressFamily::Any),
_ => Err(figment::error::Kind::InvalidType(
Actual::Str(s.into()),
OneOf(&["inet", "4", "inet6", "6"]).to_string(),
)
.into()),
}
}
}

impl<'de> serde::Deserialize<'de> for AddressFamily {
impl<'de> Deserialize<'de> for AddressFamily {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
deserializer.deserialize_any(IntOrString(PhantomData))
let s = String::deserialize(deserializer)?;
FromStr::from_str(&s).map_err(de::Error::custom)
}
}

#[cfg(test)]
mod test {
use std::str::FromStr;

use super::AddressFamily;

#[test]
fn serialize() {
let a = AddressFamily::V4;
let b = AddressFamily::V6;
let a = AddressFamily::Inet;
let b = AddressFamily::Inet6;
let c = AddressFamily::Any;

let aa = serde_json::to_string(&a);
let bb = serde_json::to_string(&b);
assert_eq!(aa.unwrap(), "4");
assert_eq!(bb.unwrap(), "6");
let cc = serde_json::to_string(&c);
assert_eq!(aa.unwrap(), "\"inet\"");
assert_eq!(bb.unwrap(), "\"inet6\"");
assert_eq!(cc.unwrap(), "\"any\"");
}

#[test]
fn deser_str() {
let a: AddressFamily = serde_json::from_str(r#" "4" "#).unwrap();
assert_eq!(a, AddressFamily::V4);
let a: AddressFamily = serde_json::from_str(r#" "6" "#).unwrap();
assert_eq!(a, AddressFamily::V6);
}

#[test]
fn deser_int() {
let a: AddressFamily = serde_json::from_str("4").unwrap();
assert_eq!(a, AddressFamily::V4);
let a: AddressFamily = serde_json::from_str("6").unwrap();
assert_eq!(a, AddressFamily::V6);
use AddressFamily::*;
for (str, expected) in &[
("4", Inet),
("inet", Inet),
("inet4", Inet),
("6", Inet6),
("inet6", Inet6),
("any", Any),
] {
let raw = AddressFamily::from_str(str).expect(str);
let json = format!(r#""{str}""#);
let output = serde_json::from_str::<AddressFamily>(&json).expect(str);
assert_eq!(raw, *expected);
assert_eq!(output, *expected);
}
}

#[test]
fn deser_invalid() {
let _ = serde_json::from_str::<AddressFamily>("true").unwrap_err();
let _ = serde_json::from_str::<AddressFamily>("5").unwrap_err();
let _ = serde_json::from_str::<AddressFamily>(r#" "5" "#).unwrap_err();
let _ = serde_json::from_str::<AddressFamily>("-1").unwrap_err();
let _ = serde_json::from_str::<AddressFamily>(r#" "42" "#).unwrap_err();
let _ = serde_json::from_str::<AddressFamily>(r#" "string" "#).unwrap_err();
for s in &["true", "5", r#""5""#, "-1", r#""42"#, r#""string"#] {
let _ = serde_json::from_str::<AddressFamily>(s).expect_err(s);
}
}
}
51 changes: 0 additions & 51 deletions src/util/cli.rs

This file was deleted.

4 changes: 2 additions & 2 deletions src/util/dns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ pub fn lookup_host_by_family(host: &str, desired: AddressFamily) -> anyhow::Resu

let found = match desired {
AddressFamily::Any => it.next(),
AddressFamily::V4 => it.find(|addr| addr.is_ipv4()),
AddressFamily::V6 => it.find(|addr| addr.is_ipv6()),
AddressFamily::Inet => it.find(|addr| addr.is_ipv4()),
AddressFamily::Inet6 => it.find(|addr| addr.is_ipv6()),
};
found
.map(std::borrow::ToOwned::to_owned)
Expand Down
34 changes: 13 additions & 21 deletions src/util/humanu64.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,21 @@
//! Serialization helper type - u64 parseable by humanize_rs
// (c) 2024 Ross Younger

use std::{marker::PhantomData, ops::Deref, str::FromStr};
use std::{ops::Deref, str::FromStr};

use humanize_rs::bytes::Bytes;
use serde::{
de::{self, Error as _},
Serialize,
};

use super::cli::IntOrString;

/// An integer field that may also be expressed using engineering prefixes (k, M, G, etc).
/// An integer field that may also be expressed using SI notation (k, M, G, etc).
/// For example, `1k` and `1000` are the same.
///
/// (Nerdy description: This is a newtype wrapper to `u64` that adds a flexible deserializer via `humanize_rs::bytes::Bytes<u64>`.)
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize)]
#[serde(from = "IntOrString<HumanU64>", into = "u64")]
#[serde(from = "String", into = "String")]
pub struct HumanU64(pub u64);

impl HumanU64 {
Expand All @@ -42,6 +40,12 @@ impl From<HumanU64> for u64 {
}
}

impl From<HumanU64> for String {
fn from(value: HumanU64) -> Self {
format!("{}", *value)
}
}

impl FromStr for HumanU64 {
type Err = figment::Error;

Expand Down Expand Up @@ -71,14 +75,8 @@ impl<'de> serde::Deserialize<'de> for HumanU64 {
where
D: serde::Deserializer<'de>,
{
deserializer.deserialize_any(IntOrString(PhantomData))
}
}

#[cfg(test)]
impl rand::prelude::Distribution<HumanU64> for rand::distributions::Standard {
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> HumanU64 {
rng.gen::<u64>().into()
let s = String::deserialize(deserializer)?;
FromStr::from_str(&s).map_err(de::Error::custom)
}
}

Expand All @@ -105,21 +103,15 @@ mod test {
test_deser_str("\"100k\"", 100_000);
}

#[test]
fn deser_raw_int() {
let foo: HumanU64 = serde_json::from_str("12345").unwrap();
assert_eq!(*foo, 12345);
}

#[test]
fn serde_test() {
let bw = HumanU64::new(42);
assert_tokens(&bw, &[Token::U64(42)]);
assert_tokens(&bw, &[Token::Str("42")]);
}

#[test]
fn from_int() {
let result = HumanU64::from(12345);
let result = HumanU64::new(12345);
assert_eq!(*result, 12345);
}
#[test]
Expand Down
2 changes: 0 additions & 2 deletions src/util/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ pub use dns::lookup_host_by_family;
mod cert;
pub use cert::Credentials;

pub mod cli;

pub mod humanu64;
pub mod io;
pub mod socket;
Expand Down
Loading

0 comments on commit 77e057e

Please sign in to comment.