diff --git a/attestation-agent/kbs_protocol/src/client/rcar_client.rs b/attestation-agent/kbs_protocol/src/client/rcar_client.rs index e82196bc0..18f3a054f 100644 --- a/attestation-agent/kbs_protocol/src/client/rcar_client.rs +++ b/attestation-agent/kbs_protocol/src/client/rcar_client.rs @@ -56,6 +56,26 @@ async fn get_request_extra_params() -> serde_json::Value { extra_params } +async fn get_hash_algorithm(extra_params: serde_json::Value) -> Result { + let algorithm = match extra_params.get(SELECTED_HASH_ALGORITHM_JSON_KEY) { + Some(selected_hash_algorithm) => { + let name = selected_hash_algorithm + .as_str() + .ok_or(Error::UnexpectedJSONDataType( + "string".into(), + selected_hash_algorithm.to_string(), + ))? + .to_lowercase(); + + name.parse::() + .map_err(|_| Error::InvalidHashAlgorithm(name))? + } + None => DEFAULT_HASH_ALGORITHM, + }; + + Ok(algorithm) +} + async fn build_request(tee: Tee) -> Request { let extra_params = get_request_extra_params().await; @@ -168,19 +188,7 @@ impl KbsClient> { let extra_params = challenge.extra_params; - let algorithm = match extra_params.get(SELECTED_HASH_ALGORITHM_JSON_KEY) { - Some(selected_hash_algorithm) => { - // Note the blank string which will be handled as an error when parsed. - let name = selected_hash_algorithm - .as_str() - .unwrap_or("") - .to_lowercase(); - - name.parse::() - .map_err(|_| Error::InvalidHashAlgorithm(name))? - } - None => DEFAULT_HASH_ALGORITHM, - }; + let algorithm = get_hash_algorithm(extra_params).await?; let tee_pubkey = self.tee_key.export_pubkey()?; let runtime_data = json!({ @@ -350,16 +358,19 @@ impl KbsClientCapabilities for KbsClient> { #[cfg(test)] mod test { use crypto::HashAlgorithm; + use rstest::rstest; + use serde_json::{json, Value}; use std::{env, path::PathBuf, time::Duration}; use testcontainers::{clients, images::generic::GenericImage}; use tokio::fs; use crate::{ - evidence_provider::NativeEvidenceProvider, KbsClientBuilder, KbsClientCapabilities, + evidence_provider::NativeEvidenceProvider, Error, KbsClientBuilder, KbsClientCapabilities, }; use crate::client::rcar_client::{ - build_request, get_request_extra_params, KBS_PROTOCOL_VERSION, + build_request, get_hash_algorithm, get_request_extra_params, Result, + DEFAULT_HASH_ALGORITHM, KBS_PROTOCOL_VERSION, SELECTED_HASH_ALGORITHM_JSON_KEY, SUPPORTED_HASH_ALGORITHMS_JSON_KEY, }; use kbs_types::Tee; @@ -502,4 +513,63 @@ mod test { assert_eq!(request.extra_params, expected_extra_params); } } + + #[tokio::test] + #[rstest] + #[case(json!({}), Ok(DEFAULT_HASH_ALGORITHM))] + #[rstest] + #[case(json!({SELECTED_HASH_ALGORITHM_JSON_KEY: ""}), Err(Error::InvalidHashAlgorithm("".into())))] + #[rstest] + #[case(json!({SELECTED_HASH_ALGORITHM_JSON_KEY: "foo"}), Err(Error::InvalidHashAlgorithm("foo".into())))] + #[rstest] + #[case(json!({SELECTED_HASH_ALGORITHM_JSON_KEY: "sha256"}), Ok(HashAlgorithm::Sha256))] + #[rstest] + #[case(json!({SELECTED_HASH_ALGORITHM_JSON_KEY: "SHA256"}), Ok(HashAlgorithm::Sha256))] + #[rstest] + #[case(json!({SELECTED_HASH_ALGORITHM_JSON_KEY: "sha384"}), Ok(HashAlgorithm::Sha384))] + #[rstest] + #[case(json!({SELECTED_HASH_ALGORITHM_JSON_KEY: "SHA384"}), Ok(HashAlgorithm::Sha384))] + #[rstest] + #[case(json!({SELECTED_HASH_ALGORITHM_JSON_KEY: "sha512"}), Ok(HashAlgorithm::Sha512))] + #[rstest] + #[case(json!({SELECTED_HASH_ALGORITHM_JSON_KEY: "SHA512"}), Ok(HashAlgorithm::Sha512))] + #[rstest] + #[case(json!({SELECTED_HASH_ALGORITHM_JSON_KEY: []}), Err(Error::UnexpectedJSONDataType("string".into(), "[]".into())))] + #[rstest] + #[case(json!({SELECTED_HASH_ALGORITHM_JSON_KEY: {}}), Err(Error::UnexpectedJSONDataType("string".into(), "{}".into())))] + #[rstest] + #[case(json!({SELECTED_HASH_ALGORITHM_JSON_KEY: true}), Err(Error::UnexpectedJSONDataType("string".into(), "true".into())))] + #[rstest] + #[case(json!({SELECTED_HASH_ALGORITHM_JSON_KEY: 99999}), Err(Error::UnexpectedJSONDataType("string".into(), "99999".into())))] + #[rstest] + #[case(json!({SELECTED_HASH_ALGORITHM_JSON_KEY: 3.141}), Err(Error::UnexpectedJSONDataType("string".into(), "3.141".into())))] + async fn test_get_hash_algorithm( + #[case] extra_params: Value, + #[case] expected_result: Result, + ) { + let msg = + format!("test: extra_params: {extra_params:?}, expected result: {expected_result:?}"); + + let actual_result = get_hash_algorithm(extra_params).await; + + let msg = format!("{msg}, actual result: {actual_result:?}"); + + if std::env::var("DEBUG").is_ok() { + println!("DEBUG: {}", msg); + } + + if expected_result.is_err() { + let expected_result_msg = format!("{expected_result:?}"); + let actual_result_msg = format!("{actual_result:?}"); + + assert_eq!(expected_result_msg, actual_result_msg, "{msg:?}"); + + return; + } + + let expected_hash_algorithm = expected_result.unwrap(); + let actual_hash_algorithm = actual_result.unwrap(); + + assert_eq!(expected_hash_algorithm, actual_hash_algorithm, "{msg:?}"); + } } diff --git a/attestation-agent/kbs_protocol/src/error.rs b/attestation-agent/kbs_protocol/src/error.rs index 97d95ccb6..5a75d59cb 100644 --- a/attestation-agent/kbs_protocol/src/error.rs +++ b/attestation-agent/kbs_protocol/src/error.rs @@ -47,4 +47,7 @@ pub enum Error { #[error("invalid hash algorithm: {0}")] InvalidHashAlgorithm(String), + + #[error("unexpected JSON data type: expected {0}, got {1}")] + UnexpectedJSONDataType(String, String), }