diff --git a/src/global.rs b/src/global.rs index 9be7127..9ee3b62 100644 --- a/src/global.rs +++ b/src/global.rs @@ -50,6 +50,10 @@ pub fn metric_namespace() -> String { std::env::var("metric_namespace").unwrap_or_else(|_| "LogRotation".to_string()) } +pub fn aws_partition() -> String { + std::env::var("aws_partition").unwrap_or_else(|_| "aws".to_string()) +} + pub fn initialize_logger() { env_logger::builder().format_timestamp(None).init(); } diff --git a/src/main.rs b/src/main.rs index be19a9d..8b27418 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,7 +7,7 @@ use terraform_aws_default_log_retention::{ cloudwatch_metrics_traits::PutMetricData, error::{Error, Severity}, event::CloudTrailEvent, - global::{cloudwatch_logs, cloudwatch_metrics, initialize_logger, log_group_tags, retention}, + global::{aws_partition, cloudwatch_logs, cloudwatch_metrics, initialize_logger, log_group_tags, retention}, metric_publisher::{self, Metric, MetricName}, retention_setter::get_existing_retention, }; @@ -97,8 +97,11 @@ async fn process_event( } let log_group_arn = format!( - "arn:aws:logs:{}:{}:log-group:{}", - event.detail.aws_region, event.detail.user_identity.account_id, log_group_name + "arn:{}:logs:{}:{}:log-group:{}", + aws_partition(), + event.detail.aws_region, + event.detail.user_identity.account_id, + log_group_name ); let tags = cloudwatch_logs.list_tags_for_resource(&log_group_arn).await?; if let Some(retention) = tags.tags().and_then(|tags| tags.get("retention")) { @@ -232,6 +235,57 @@ mod tests { insta::assert_debug_snapshot!(result); } + #[tokio::test] + // Testing for govcloud or China + async fn test_process_event_success_no_tags_different_aws_partition() { + std::env::set_var("aws_partition", "aws-cn"); + let event = CloudTrailEvent::new("123456789", "us-east-1", "MyLogGroupWasCreated"); + let log_group_arn = "arn:aws-cn:logs:us-east-1:123456789:log-group:MyLogGroupWasCreated"; + + let mut mock_cloud_watch_logs_client = MockCloudWatchLogs::new(); + mock_cloud_watch_logs_client + .expect_describe_log_groups() + .with(predicate::eq(Some("MyLogGroupWasCreated".to_string())), predicate::eq(None)) + .once() + .returning(|_, _| mock_describe_log_groups_response("MyLogGroupWasCreated", 0)); + + mock_cloud_watch_logs_client + .expect_list_tags_for_resource() + .with(predicate::eq(log_group_arn)) + .once() + .returning(|_| mock_list_tags_for_resource_response(None)); + + mock_cloud_watch_logs_client + .expect_put_retention_policy() + .with(predicate::eq("MyLogGroupWasCreated"), predicate::eq(30)) + .once() + .returning(|_, _| Ok(PutRetentionPolicyOutput::builder().build())); + + mock_cloud_watch_logs_client + .expect_tag_resource() + .with(predicate::eq(log_group_arn), predicate::eq(HashMap::new())) + .once() + .returning(|_, _| Ok(TagResourceOutput::builder().build())); + + let mut mock_cloud_watch_metrics_client = MockCloudWatchMetrics::new(); + mock_cloud_watch_metrics_client + .expect_put_metric_data() + .once() + .withf(|namespace, metrics| { + assert_eq!("LogRotation", namespace); + insta::assert_debug_snapshot!("CWMetricCall_process_event_success_no_tags", metrics); + true + }) + .returning(|_, _| Ok(PutMetricDataOutput::builder().build())); + + let result = process_event(event, mock_cloud_watch_logs_client, mock_cloud_watch_metrics_client) + .await + .expect("Should not fail"); + + std::env::remove_var("aws_partition"); + insta::assert_debug_snapshot!(result); + } + #[tokio::test] async fn test_process_event_fails_when_put_retention_policy_fails() { let event = CloudTrailEvent::new("123456789", "us-east-1", "MyLogGroupWasCreated"); @@ -487,12 +541,14 @@ mod tests { } } + #[allow(clippy::result_large_err)] // This is a test, don't care about large err type fn mock_describe_log_groups_response(log_group_name: &str, retention: i32) -> Result { let log_group = LogGroup::builder().log_group_name(log_group_name).retention_in_days(retention).build(); let response = DescribeLogGroupsOutput::builder().log_groups(log_group).build(); Ok(response) } + #[allow(clippy::result_large_err)] // This is a test, don't care about large err type fn mock_list_tags_for_resource_response(retention_tag_value: Option<&str>) -> Result { if let Some(retention_tag_value) = retention_tag_value { let mut tags: HashMap = HashMap::new(); diff --git a/src/retention_setter.rs b/src/retention_setter.rs index 4ed3b21..d5bfba9 100644 --- a/src/retention_setter.rs +++ b/src/retention_setter.rs @@ -91,6 +91,7 @@ mod tests { assert!(err.message.contains(group)); } + #[allow(clippy::result_large_err)] // This is a test, don't care about large err type fn mock_describe_log_groups_response(log_group_name: &str, retention: i32) -> Result { let log_group = LogGroup::builder().log_group_name(log_group_name).retention_in_days(retention).build(); let response = DescribeLogGroupsOutput::builder().log_groups(log_group).build(); diff --git a/src/snapshots/terraform_aws_default_log_retention__tests__process_event_success_no_tags_different_aws_partition.snap b/src/snapshots/terraform_aws_default_log_retention__tests__process_event_success_no_tags_different_aws_partition.snap new file mode 100644 index 0000000..0b38fd2 --- /dev/null +++ b/src/snapshots/terraform_aws_default_log_retention__tests__process_event_success_no_tags_different_aws_partition.snap @@ -0,0 +1,8 @@ +--- +source: src/main.rs +assertion_line: 286 +expression: result +--- +Object { + "message": String("Retention set successfully"), +} diff --git a/tf-lambda-iam-role.tf b/tf-lambda-iam-role.tf index d5f95b6..3035eee 100644 --- a/tf-lambda-iam-role.tf +++ b/tf-lambda-iam-role.tf @@ -22,7 +22,7 @@ data "aws_iam_policy_document" "log_retention" { "logs:PutRetentionPolicy", "logs:DescribeLogGroups" ] - resources = ["arn:aws:logs:*:*:*"] + resources = ["arn:${data.aws_partition.current.partition}:logs:*:*:*"] } statement { diff --git a/tf-log-retention-lambda.tf b/tf-log-retention-lambda.tf index 93cde3c..57dffb0 100644 --- a/tf-log-retention-lambda.tf +++ b/tf-log-retention-lambda.tf @@ -23,6 +23,7 @@ resource "aws_lambda_function" "log_retention" { log_retention_in_days = var.log_retention_in_days log_group_tags = local.log_group_tags_json metric_namespace = var.metric_namespace + aws_partition = data.aws_partition.current.partition RUST_BACKTRACE = 1 RUST_LOG = "warn,terraform_aws_default_log_retention=${var.log_level}" # https://docs.rs/env_logger/latest/env_logger/ } diff --git a/tf-lookups.tf b/tf-lookups.tf index 5f0b9ec..83eb2ed 100644 --- a/tf-lookups.tf +++ b/tf-lookups.tf @@ -1,6 +1,7 @@ data "aws_region" "current" {} data "aws_iam_account_alias" "current" {} data "aws_caller_identity" "current" {} +data "aws_partition" "current" {} # .issuer_arn grabs the underlying ARN (removes the assumed-role portion) data "aws_iam_session_context" "current" {