Skip to content

Commit

Permalink
fix: all configuration members non-optional
Browse files Browse the repository at this point in the history
This fixes some sunspots introduced by configuration
  • Loading branch information
crazyscot committed Dec 10, 2024
1 parent a7696ec commit 4ee2dc8
Show file tree
Hide file tree
Showing 9 changed files with 76 additions and 47 deletions.
4 changes: 2 additions & 2 deletions src/cli/args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,9 @@ impl CliArgs {
CliArgs::from_arg_matches(&cli.get_matches_from(std::env::args_os())).unwrap();
// Custom logic: '-4' and '-6' convenience aliases
if args.ipv4_alias__ {
args.config.address_family = Some(Some(AddressFamily::V4));
args.config.address_family = Some(AddressFamily::V4);
} else if args.ipv6_alias__ {
args.config.address_family = Some(Some(AddressFamily::V6));
args.config.address_family = Some(AddressFamily::V6);
}
args
}
Expand Down
11 changes: 7 additions & 4 deletions src/client/control.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,14 @@ impl Channel {
if parameters.remote_debug {
let _ = server.arg("--debug");
}
if let Some(w) = config.initial_congestion_window {
let _ = server.args(["--initial-congestion-window", &w.to_string()]);
match config.initial_congestion_window {
0 => (),
w => {
let _ = server.args(["--initial-congestion-window", &w.to_string()]);
}
}
if let Some(pr) = config.remote_port {
let _ = server.args(["--port", &pr.to_string()]);
if !config.remote_port.is_default() {
let _ = server.args(["--port", &config.remote_port.to_string()]);
}
let _ = server
.stdin(Stdio::piped())
Expand Down
2 changes: 1 addition & 1 deletion src/config/manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,7 @@ mod test {
);
let fake_cli = Configuration_Optional {
rtt: Some(999),
initial_congestion_window: Some(Some(67890)), // yeah the double-Some is a bit of a wart
initial_congestion_window: Some(67890),
..Default::default()
};
let mut mgr = Manager::without_files();
Expand Down
34 changes: 17 additions & 17 deletions src/config/structure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use derive_deftly::Deftly;
/// The set of configurable options supported by qcp.
///
/// **Note:** The implementation of `default()` for this struct returns qcp's hard-wired configuration defaults.
// Maintainer note: None of the members of this struct should be Option<anything>. That leads to sunspots in the CLI and strange warts (Some(Some(foo))).
#[derive(Deftly)]
#[derive_deftly(Optionalify)]
#[deftly(visibility = "pub(crate)")]
Expand All @@ -42,9 +43,9 @@ pub struct Configuration {
///
/// (For example, when you are connected via an asymmetric last-mile DSL or fibre profile.)
///
/// If not specified, uses the value of `rx`.
/// If not specified or 0, uses the value of `rx`.
#[arg(short('B'), long, alias("tx-bw"), help_heading("Network tuning"), display_order(10), value_name="bytes", value_parser=clap::value_parser!(HumanU64))]
pub tx: Option<HumanU64>,
pub tx: HumanU64,

/// The expected network Round Trip time to the target system, in milliseconds.
/// [default: 300]
Expand Down Expand Up @@ -74,7 +75,7 @@ pub struct Configuration {
///
/// _Setting this value too high reduces performance!_
#[arg(long, help_heading("Advanced network tuning"), value_name = "bytes")]
pub initial_congestion_window: Option<u64>,
pub initial_congestion_window: u64,

/// Uses the given UDP port or range on the local endpoint.
/// This can be useful when there is a firewall between the endpoints.
Expand All @@ -84,7 +85,7 @@ pub struct Configuration {
///
/// If unspecified, uses any available UDP port.
#[arg(short = 'p', long, value_name("M-N"), help_heading("Connection"))]
pub port: Option<PortRange>,
pub port: PortRange,

/// Connection timeout for the QUIC endpoints [seconds; default 5]
///
Expand All @@ -99,7 +100,7 @@ pub struct Configuration {
/// If unspecified, uses whatever seems suitable given the target address or the result of DNS lookup.
// (see also [CliArgs::ipv4_alias__] and [CliArgs::ipv6_alias__])
#[arg(long, alias("ipv"), help_heading("Connection"), group("ip address"))]
pub address_family: Option<AddressFamily>,
pub address_family: AddressFamily,

/// Specifies the ssh client program to use [default: `ssh`]
#[arg(long, help_heading("Connection"))]
Expand Down Expand Up @@ -129,7 +130,7 @@ pub struct Configuration {
///
/// If unspecified, uses any available UDP port.
#[arg(short = 'P', long, value_name("M-N"), help_heading("Connection"))]
pub remote_port: Option<PortRange>,
pub remote_port: PortRange,

/// Specifies the time format to use when printing messages to the console or to file
#[arg(short = 'T', long, value_name("FORMAT"), help_heading("Output"))]
Expand Down Expand Up @@ -157,10 +158,9 @@ impl Configuration {
#[must_use]
/// Transmit bandwidth (accessor)
pub fn tx(&self) -> u64 {
if let Some(tx) = self.tx {
*tx
} else {
self.rx()
match *self.tx {
0 => self.rx(),
tx => tx,
}
}
/// RTT accessor as Duration
Expand Down Expand Up @@ -206,8 +206,8 @@ impl Configuration {
#[must_use]
pub fn format_transport_config(&self) -> String {
let iwind = match self.initial_congestion_window {
None => "<default>".to_string(),
Some(s) => s.human_count_bytes().to_string(),
0 => "<default>".to_string(),
s => s.human_count_bytes().to_string(),
};
let (tx, rx) = (self.tx(), self.rx());
format!(
Expand All @@ -229,18 +229,18 @@ impl Default for Configuration {
Self {
// Transport
rx: 12_500_000.into(),
tx: None,
tx: 0.into(),
rtt: 300,
congestion: CongestionControllerType::Cubic,
initial_congestion_window: None,
port: None,
initial_congestion_window: 0,
port: PortRange::default(),
timeout: 5,

// Client
address_family: None,
address_family: AddressFamily::Any,
ssh: "ssh".into(),
ssh_opt: vec![],
remote_port: None,
remote_port: PortRange::default(),
time_format: TimeFormat::Local,
}
}
Expand Down
9 changes: 5 additions & 4 deletions src/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,18 +85,19 @@ pub fn create_config(params: &Configuration, mode: ThroughputMode) -> Result<Arc
ThroughputMode::Tx => (),
}

let window = params.initial_congestion_window;
match params.congestion {
CongestionControllerType::Cubic => {
let mut cubic = CubicConfig::default();
if let Some(w) = params.initial_congestion_window {
let _ = cubic.initial_window(w);
if window != 0 {
let _ = cubic.initial_window(window);
}
let _ = config.congestion_controller_factory(Arc::new(cubic));
}
CongestionControllerType::Bbr => {
let mut bbr = BbrConfig::default();
if let Some(w) = params.initial_congestion_window {
let _ = bbr.initial_window(w);
if window != 0 {
let _ = bbr.initial_window(window);
}
let _ = config.congestion_controller_factory(Arc::new(bbr));
}
Expand Down
33 changes: 25 additions & 8 deletions src/util/address_family.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,33 +13,47 @@ use crate::util::cli::IntOrString;
/// Representation an IP address family
///
/// This is a local type with special parsing semantics to take part in the config/CLI system.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, clap::ValueEnum)]
#[serde(from = "IntOrString<AddressFamily>", into = "u64")]
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
pub enum AddressFamily {
/// IPv4
#[value(name = "4")]
V4,
/// IPv6
#[value(name = "6")]
V6,
/// We don't mind what type of IP address
Any,
}

impl From<AddressFamily> for u64 {
impl From<AddressFamily> for u8 {
fn from(value: AddressFamily) -> Self {
match value {
AddressFamily::V4 => 4,
AddressFamily::V6 => 6,
AddressFamily::Any => 0,
}
}
}

impl Serialize for AddressFamily {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
match *self {
AddressFamily::Any => serializer.serialize_str("any"),
t => serializer.serialize_u8(u8::from(t)),
}
}
}

impl Display for AddressFamily {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let u: u8 = match self {
AddressFamily::V4 => 4,
AddressFamily::V6 => 6,
};
write!(f, "{u}")
if *self == AddressFamily::Any {
write!(f, "any")
} else {
write!(f, "{}", u8::from(*self))
}
}
}

Expand All @@ -51,6 +65,8 @@ impl FromStr for AddressFamily {
Ok(AddressFamily::V4)
} else if s == "6" {
Ok(AddressFamily::V6)
} else if s == "0" || s == "any" {
Ok(AddressFamily::Any)
} else {
Err(figment::error::Kind::InvalidType(Actual::Str(s.into()), "4 or 6".into()).into())
}
Expand All @@ -64,6 +80,7 @@ impl TryFrom<u64> for AddressFamily {
match value {
4 => Ok(AddressFamily::V4),
6 => Ok(AddressFamily::V6),
0 => Ok(AddressFamily::Any),
_ => Err(figment::error::Kind::InvalidValue(
Actual::Unsigned(value.into()),
"4 or 6".into(),
Expand Down
8 changes: 4 additions & 4 deletions src/util/dns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@ use super::AddressFamily;
/// Results can be restricted to a given address family.
/// Only the first matching result is returned.
/// If there are no matching records of the required type, returns an error.
pub fn lookup_host_by_family(host: &str, desired: Option<AddressFamily>) -> anyhow::Result<IpAddr> {
pub fn lookup_host_by_family(host: &str, desired: AddressFamily) -> anyhow::Result<IpAddr> {
let candidates = dns_lookup::lookup_host(host)
.with_context(|| format!("host name lookup for {host} failed"))?;
let mut it = candidates.iter();

let found = match desired {
None => it.next(),
Some(AddressFamily::V4) => it.find(|addr| addr.is_ipv4()),
Some(AddressFamily::V6) => it.find(|addr| addr.is_ipv6()),
AddressFamily::Any => it.next(),
AddressFamily::V4 => it.find(|addr| addr.is_ipv4()),
AddressFamily::V6 => it.find(|addr| addr.is_ipv6()),
};
found
.map(std::borrow::ToOwned::to_owned)
Expand Down
12 changes: 12 additions & 0 deletions src/util/port_range.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,18 @@ impl From<u64> for PortRange {
}
}

impl Default for PortRange {
fn default() -> Self {
Self::from(0)
}
}

impl PortRange {
pub(crate) fn is_default(self) -> bool {
self.begin == 0 && self.begin == self.end
}
}

impl<'de> serde::Deserialize<'de> for PortRange {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
Expand Down
10 changes: 3 additions & 7 deletions src/util/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ pub fn bind_unspecified_for(peer: &SocketAddr) -> anyhow::Result<std::net::UdpSo
/// Creates and binds a UDP socket from a restricted range of local ports, using the address family necessary to reach the given peer address
pub fn bind_range_for_peer(
peer: &SocketAddr,
range: Option<PortRange>,
range: PortRange,
) -> anyhow::Result<std::net::UdpSocket> {
let addr: IpAddr = match peer {
SocketAddr::V4(_) => IpAddr::V4(Ipv4Addr::UNSPECIFIED),
Expand All @@ -98,12 +98,8 @@ pub fn bind_range_for_peer(
/// Creates and binds a UDP socket from a restricted range of local ports, for a given local address
pub fn bind_range_for_address(
addr: IpAddr,
range: Option<PortRange>,
range: PortRange,
) -> anyhow::Result<std::net::UdpSocket> {
let range = match range {
None => PortRange { begin: 0, end: 0 },
Some(r) => r,
};
if range.begin == range.end {
return Ok(UdpSocket::bind(SocketAddr::new(addr, range.begin))?);
}
Expand All @@ -119,7 +115,7 @@ pub fn bind_range_for_address(
/// Creates and binds a UDP socket from a restricted range of local ports, for the unspecified address of the given address family
pub fn bind_range_for_family(
family: ConnectionType,
range: Option<PortRange>,
range: PortRange,
) -> anyhow::Result<std::net::UdpSocket> {
let addr = match family {
ConnectionType::Ipv4 => IpAddr::V4(Ipv4Addr::UNSPECIFIED),
Expand Down

0 comments on commit 4ee2dc8

Please sign in to comment.