diff --git a/keylime-agent/src/config.rs b/keylime-agent/src/config.rs index aa37473c..21c05169 100644 --- a/keylime-agent/src/config.rs +++ b/keylime-agent/src/config.rs @@ -12,6 +12,7 @@ use keylime::{ hostname_parser::{parse_hostname, HostnameParsingError}, ip_parser::{parse_ip, IpParsingError}, list_parser::{parse_list, ListParsingError}, + version::{self, GetErrorInput}, }; use log::*; use serde::{Deserialize, Serialize, Serializer}; @@ -20,6 +21,7 @@ use std::{ collections::HashMap, env, path::{Path, PathBuf}, + str::FromStr, }; use thiserror::Error; use uuid::Uuid; @@ -621,15 +623,34 @@ fn config_translate_keywords( } } versions => { - let parsed: Vec:: = match parse_list(&config.agent.api_versions) { - Ok(list) => list + let parsed: Vec = match parse_list( + &config.agent.api_versions, + ) { + Ok(list) => { + let mut filtered_versions = list .iter() .inspect(|e| { if !SUPPORTED_API_VERSIONS.contains(e) { warn!("Skipping API version \"{e}\" obtained from 'api_versions' configuration option") }}) .filter(|e| SUPPORTED_API_VERSIONS.contains(e)) - .map(|&s| s.into()) - .collect(), + .map(|&s| version::Version::from_str(s)) + .inspect(|err| if let Err(e) = err { + warn!("Skipping API version \"{}\" obtained from 'api_versions' configuration option", e.input()); + }) + .filter(|e| e.is_ok()) + .map(|v| { + let Ok(ver) = v else {unreachable!();}; + ver + }) + .collect::>(); + + // Sort the versions from the configuration from the oldest to the newest + filtered_versions.sort(); + filtered_versions + .iter() + .map(|v| v.to_string()) + .collect::>() + } Err(e) => { warn!("Failed to parse list from 'api_versions' configuration option; using default supported versions"); SUPPORTED_API_VERSIONS.iter().map(|&s| s.into()).collect() @@ -996,6 +1017,63 @@ mod tests { assert_eq!(version, old); } + #[test] + fn test_translate_invalid_api_versions_filtered() { + let old = SUPPORTED_API_VERSIONS[0]; + + let mut test_config = KeylimeConfig { + agent: AgentConfig { + api_versions: format!("a.b, {old}, c.d"), + ..Default::default() + }, + }; + let result = config_translate_keywords(&test_config); + assert!(result.is_ok()); + let config = result.unwrap(); //#[allow_ci] + let version = config.agent.api_versions; + assert_eq!(version, old); + } + + #[test] + fn test_translate_invalid_api_versions_fallback_default() { + let old = SUPPORTED_API_VERSIONS; + + let mut test_config = KeylimeConfig { + agent: AgentConfig { + api_versions: "a.b, c.d".to_string(), + ..Default::default() + }, + }; + let result = config_translate_keywords(&test_config); + assert!(result.is_ok()); + let config = result.unwrap(); //#[allow_ci] + let version = config.agent.api_versions; + assert_eq!(version, old.join(", ")); + } + + #[test] + fn test_translate_api_versions_sort() { + let old = SUPPORTED_API_VERSIONS; + let reversed = SUPPORTED_API_VERSIONS + .iter() + .rev() + .copied() + .collect::>() + .join(", "); + + let mut test_config = KeylimeConfig { + agent: AgentConfig { + api_versions: reversed, + ..Default::default() + }, + }; + let result = config_translate_keywords(&test_config); + assert!(result.is_ok()); + let config = result.unwrap(); //#[allow_ci] + let version = config.agent.api_versions; + assert_eq!(version, old.join(", ")); + } + #[test] fn test_get_uuid() { assert_eq!(get_uuid("hash_ek"), "hash_ek"); diff --git a/keylime/src/registrar_client.rs b/keylime/src/registrar_client.rs index 4c408fd0..dcfcd22d 100644 --- a/keylime/src/registrar_client.rs +++ b/keylime/src/registrar_client.rs @@ -341,7 +341,8 @@ impl<'a> RegistrarClientBuilder<'a> { Ok(registrar_api_version.to_string()) } else { // Check if one of the API versions that the registrar supports is enabled - // from the latest to the oldest + // from the latest to the oldest, assuming the reported versions are ordered from the + // oldest to the newest for reg_supported_version in resp.results.supported_versions.iter().rev() { @@ -632,7 +633,8 @@ impl RegistrarClient<'_> { // In case the registrar does not support the '/version' endpoint, try the enabled API // versions if self.api_version == UNKNOWN_API_VERSION { - for api_version in &self.enabled_api_versions { + // Assume the list of enabled versions is ordered from the oldest to the newest + for api_version in self.enabled_api_versions.iter().rev() { info!("Trying to register agent using API version {api_version}"); let r = self.try_register_agent(api_version).await; diff --git a/keylime/src/version.rs b/keylime/src/version.rs index 662c1dd3..17abfe71 100644 --- a/keylime/src/version.rs +++ b/keylime/src/version.rs @@ -1,4 +1,6 @@ use serde::{Deserialize, Serialize}; +use std::{fmt, str::FromStr}; +use thiserror::Error; #[derive(Serialize, Deserialize, Debug)] pub struct KeylimeVersion { @@ -10,3 +12,148 @@ pub struct KeylimeRegistrarVersion { pub current_version: String, pub supported_versions: Vec, } + +pub trait GetErrorInput { + fn input(&self) -> String; +} + +#[derive(Error, Debug)] +pub enum VersionParsingError { + /// The version input was malformed + #[error("input '{input}' malformed as a version")] + MalformedVersion { input: String }, + + /// The parts of the version were not numbers + #[error("parts of version '{input}' were not numbers")] + ParseError { + input: String, + source: std::num::ParseIntError, + }, +} + +impl GetErrorInput for VersionParsingError { + fn input(&self) -> String { + match self { + VersionParsingError::MalformedVersion { input } => input.into(), + VersionParsingError::ParseError { input, source: _ } => { + input.into() + } + } + } +} + +// Implement the trait for all the references +impl GetErrorInput for &T +where + T: GetErrorInput, +{ + fn input(&self) -> String { + (**self).input() + } +} + +#[derive(Clone, Debug, Eq, PartialEq, PartialOrd, Ord)] +pub struct Version { + major: u32, + minor: u32, +} + +impl fmt::Display for Version { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}.{}", self.major, self.minor) + } +} + +impl FromStr for Version { + type Err = VersionParsingError; + + fn from_str(input: &str) -> Result { + let mut parts = input.split('.'); + match (parts.next(), parts.next()) { + (Some(major), Some(minor)) => Ok(Version { + major: major.parse().map_err(|e| { + VersionParsingError::ParseError { + input: input.to_string(), + source: e, + } + })?, + minor: minor.parse().map_err(|e| { + VersionParsingError::ParseError { + input: input.to_string(), + source: e, + } + })?, + }), + _ => Err(VersionParsingError::MalformedVersion { + input: input.to_string(), + }), + } + } +} + +impl TryFrom<&str> for Version { + type Error = VersionParsingError; + + fn try_from(input: &str) -> Result { + Version::from_str(input) + } +} + +impl TryFrom for Version { + type Error = VersionParsingError; + + fn try_from(input: String) -> Result { + Version::from_str(input.as_str()) + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_from_str() { + let v = Version::from_str("1.2").unwrap(); //#[allow_ci] + assert_eq!(v, Version { major: 1, minor: 2 }); + let v2: Version = "3.4".try_into().unwrap(); //#[allow_ci] + assert_eq!(v2, Version { major: 3, minor: 4 }); + let v3: Version = "5.6".to_string().try_into().unwrap(); //#[allow_ci] + assert_eq!(v3, Version { major: 5, minor: 6 }); + } + + #[test] + fn test_display() { + let s = format!("{}", Version { major: 1, minor: 2 }); + assert_eq!(s, "1.2".to_string()); + } + + #[test] + fn test_ord() { + let v11: Version = "1.1".try_into().unwrap(); //#[allow_ci] + let v12: Version = "1.2".try_into().unwrap(); //#[allow_ci] + let v21: Version = "2.1".try_into().unwrap(); //#[allow_ci] + let v110: Version = "1.10".try_into().unwrap(); //#[allow_ci] + assert!(v11 < v12); + assert!(v12 < v110); + assert!(v110 < v21); + + let mut v = vec![v12.clone(), v110.clone(), v11.clone()]; + v.sort(); + let expected = vec![v11, v12, v110]; + assert_eq!(v, expected); + } + + #[test] + fn test_invalid() { + let result = Version::from_str("a.b"); + assert!(result.is_err()); + let result = Version::from_str("1.b"); + assert!(result.is_err()); + let result = Version::from_str("a.2"); + assert!(result.is_err()); + let result = Version::from_str("22"); + assert!(result.is_err()); + let result = Version::from_str(".12"); + assert!(result.is_err()); + } +}