Skip to content

Commit

Permalink
refactor: CopyJobSpec stores user_at_host
Browse files Browse the repository at this point in the history
  • Loading branch information
crazyscot committed Jan 9, 2025
1 parent d2918f0 commit eacac23
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 49 deletions.
57 changes: 13 additions & 44 deletions src/client/job.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,13 @@ use std::str::FromStr;

use crate::transport::ThroughputMode;

/// Strips the optional user@ part off a hostname
fn hostname_of(user_at_host: &str) -> &str {
user_at_host.split_once('@').unwrap_or(("", user_at_host)).1
}

/// A file source or destination specified by the user
#[derive(Debug, Clone, Default, PartialEq, Eq)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FileSpec {
/// The remote [user@]host for the file. This may be a hostname or an IP address.
/// It may also be a _hostname alias_ that matches a Host section in the user's ssh config file.
Expand All @@ -23,12 +28,7 @@ pub struct FileSpec {
impl FileSpec {
/// Returns only the hostname part of the file, if any; the username is stripped.
pub(crate) fn hostname(&self) -> Option<&str> {
if let Some(user_host) = &self.user_at_host {
let (_user, host) = user_host.split_once('@').unwrap_or(("", user_host));
Some(host)
} else {
None
}
self.user_at_host.as_ref().map(|s| hostname_of(s))
}
}

Expand Down Expand Up @@ -76,10 +76,13 @@ impl std::fmt::Display for FileSpec {
}

/// Details of a file copy job.
#[derive(Debug, Clone, Default)]
#[derive(Debug, Clone)]
pub struct CopyJobSpec {
pub(crate) source: FileSpec,
pub(crate) destination: FileSpec,
/// The `[user@]host` part of whichever of the source or destination contained one.
/// (There can be only one.)
pub(crate) user_at_host: String,
}

impl CopyJobSpec {
Expand All @@ -92,20 +95,9 @@ impl CopyJobSpec {
}
}

/// The `[user@]hostname` portion of whichever of the arguments contained a hostname.
fn remote_user_host(&self) -> &str {
self.source
.user_at_host
.as_ref()
.unwrap_or_else(|| self.destination.user_at_host.as_ref().unwrap())
}

/// The hostname portion of whichever of the arguments contained one.
pub(crate) fn remote_host(&self) -> &str {
let user_host = self.remote_user_host();
// It might be user@host, or it might be just the hostname or IP.
let (_, host) = user_host.split_once('@').unwrap_or(("", user_host));
host
hostname_of(&self.user_at_host)
}
}

Expand All @@ -114,7 +106,7 @@ mod test {
type Res = anyhow::Result<()>;
use engineering_repr::EngineeringQuantity;

use super::{CopyJobSpec, FileSpec};
use super::FileSpec;
use std::str::FromStr;

#[test]
Expand Down Expand Up @@ -176,27 +168,4 @@ mod test {
let q = "1k".parse::<EngineeringQuantity<u64>>().unwrap();
assert_eq!(u64::from(q), 1000);
}
#[test]
fn throughput_mode() {
let job = CopyJobSpec {
destination: FileSpec::from_str("host:file").unwrap(),
..Default::default()
};
assert_eq!(job.throughput_mode(), crate::transport::ThroughputMode::Tx);

let job2 = CopyJobSpec {
source: FileSpec::from_str("host:file").unwrap(),
..Default::default()
};
assert_eq!(job2.throughput_mode(), crate::transport::ThroughputMode::Rx);
}
#[test]
fn remote_user_host() {
let job = CopyJobSpec {
source: FileSpec::from_str("user@host:file").unwrap(),
..Default::default()
};
assert_eq!(job.remote_host(), "host");
assert_eq!(job.remote_user_host(), "user@host");
}
}
6 changes: 3 additions & 3 deletions src/client/main_loop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ pub async fn client_main(
spinner.set_message("Preparing");
let job_spec = crate::client::CopyJobSpec::try_from(&parameters)?;
let credentials = Credentials::generate()?;
let user_hostname = job_spec.remote_host();
let remote_host = super::ssh::resolve_host_alias(user_hostname, &config.ssh_config)
.unwrap_or_else(|| user_hostname.into());
let hostname = job_spec.remote_host();
let remote_host = super::ssh::resolve_host_alias(hostname, &config.ssh_config)
.unwrap_or_else(|| hostname.into());

// If the user didn't specify the address family: we do the DNS lookup, figure it out and tell ssh to use that.
// (Otherwise if we resolved a v4 and ssh a v6 - as might happen with round-robin DNS - that could be surprising.)
Expand Down
19 changes: 17 additions & 2 deletions src/client/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,15 @@ impl TryFrom<&Parameters> for CopyJobSpec {
if !(source.user_at_host.is_none() ^ destination.user_at_host.is_none()) {
anyhow::bail!("One file argument must be remote");
}
let user_at_host = source
.user_at_host
.clone()
.unwrap_or_else(|| destination.user_at_host.clone().unwrap_or_default());

Ok(Self {
source,
destination,
user_at_host,
})
}
}
Expand Down Expand Up @@ -188,9 +193,19 @@ mod tests {

#[test]
fn test_copy_job_spec_conversion() {
let params = Parameters::parse_from(["test", "host:source.txt", "destination.txt"]);
let params = Parameters::parse_from(["test", "user@host:source.txt", "destination.txt"]);
let copy_job_spec = CopyJobSpec::try_from(&params).unwrap();
assert_eq!(copy_job_spec.source.to_string(), "host:source.txt");
assert_eq!(copy_job_spec.source.to_string(), "user@host:source.txt");
assert_eq!(copy_job_spec.destination.to_string(), "destination.txt");
assert_eq!(copy_job_spec.remote_host(), "host");
assert_eq!(copy_job_spec.user_at_host, "user@host");
}

#[test]
fn there_can_be_only_one_remote() {
let params =
Parameters::parse_from(["test", "user@host:source.txt", "user@host:destination.txt"]);
let _ = CopyJobSpec::try_from(&params).expect_err("but there can be only one!");
assert!(params.remote_host_lossy().is_err());
}
}

0 comments on commit eacac23

Please sign in to comment.