From 4a1c52e51ae67208dcac9bd1212d363e6ae12e80 Mon Sep 17 00:00:00 2001 From: Helena Greebe Date: Mon, 27 Nov 2023 20:23:13 -0500 Subject: [PATCH] Add more unit test coverage and remove unused functions --- .../cloudwatch_agent_config_util.py | 6 +- .../cloudwatch/write_cloudwatch_agent_json.py | 16 --- .../cloudwatch_agent/test_cloudwatch_agent.py | 64 +++++++++-- .../test_validate_json/config.json | 11 ++ .../test_write_validated_json/output.json | 16 +++ .../test_write_cloudwatch_agent_json.py | 105 ++++++++---------- 6 files changed, 132 insertions(+), 86 deletions(-) create mode 100644 test/unit/cloudwatch_agent/test_cloudwatch_agent/test_validate_json/config.json create mode 100644 test/unit/cloudwatch_agent/test_cloudwatch_agent/test_write_validated_json/output.json diff --git a/cookbooks/aws-parallelcluster-environment/files/cloudwatch/cloudwatch_agent_config_util.py b/cookbooks/aws-parallelcluster-environment/files/cloudwatch/cloudwatch_agent_config_util.py index 4598e64a0c..a13f704b78 100644 --- a/cookbooks/aws-parallelcluster-environment/files/cloudwatch/cloudwatch_agent_config_util.py +++ b/cookbooks/aws-parallelcluster-environment/files/cloudwatch/cloudwatch_agent_config_util.py @@ -136,12 +136,12 @@ def _write_log_configs(log_configs): def write_validated_json(input_json): """Write validated JSON back to the CloudWatch log configs file.""" log_configs = _read_log_configs() - log_configs["log_configs"].extend(input_json.get("log_configs")) + input_json["log_configs"].extend(log_configs.get("log_configs")) # NOTICE: the input JSON's timestamp_formats dict is the one that is # updated, so that those defined in the original config aren't clobbered. - log_configs["timestamp_formats"] = input_json["timestamp_formats"].update(log_configs.get("timestamp_formats")) - _write_log_configs(log_configs) + input_json["timestamp_formats"].update(log_configs.get("timestamp_formats")) + _write_log_configs(input_json) def create_backup(): diff --git a/cookbooks/aws-parallelcluster-environment/files/cloudwatch/write_cloudwatch_agent_json.py b/cookbooks/aws-parallelcluster-environment/files/cloudwatch/write_cloudwatch_agent_json.py index caaaef089d..e2628d4534 100644 --- a/cookbooks/aws-parallelcluster-environment/files/cloudwatch/write_cloudwatch_agent_json.py +++ b/cookbooks/aws-parallelcluster-environment/files/cloudwatch/write_cloudwatch_agent_json.py @@ -125,11 +125,6 @@ def select_logs(configs, args): return selected_configs -def get_node_roles(scheudler_plugin_node_roles): - node_type_roles_map = {"ALL": ["ComputeFleet", "HeadNode"], "HEAD": ["HeadNode"], "COMPUTE": ["ComputeFleet"]} - return node_type_roles_map.get(scheudler_plugin_node_roles) - - def add_timestamps(configs, timestamps_dict): """For each config, set its timestamp_format field based on its timestamp_format_key field.""" for config in configs: @@ -156,8 +151,6 @@ def _collect_metric_properties(metric_config): # initial dict with default key-value pairs collected = {"metrics_collection_interval": DEFAULT_METRICS_COLLECTION_INTERVAL} collected.update({key: metric_config[key] for key in desired_keys if key in metric_config}) - if "append_dimensions" in metric_config and "ClusterName" in metric_config["append_dimensions"]: - collected.update({"append_dimensions": {"ClusterName": get_node_info().get("stack_name")}}) return collected return { @@ -211,15 +204,6 @@ def create_config(log_configs, metric_configs): return cw_agent_config -def get_dict_value(value, attributes, default=None): - """Get key value from dictionary and return default if the key does not exist.""" - for key in attributes.split("."): - value = value.get(key, None) - if value is None: - return default - return value - - def main(): """Create cloudwatch agent config file.""" args = parse_args() diff --git a/test/unit/cloudwatch_agent/test_cloudwatch_agent.py b/test/unit/cloudwatch_agent/test_cloudwatch_agent.py index 44fcdb028f..30c4c407c6 100644 --- a/test/unit/cloudwatch_agent/test_cloudwatch_agent.py +++ b/test/unit/cloudwatch_agent/test_cloudwatch_agent.py @@ -8,38 +8,47 @@ # or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES # OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions and # limitations under the License. +import os + import pytest from assertpy import assert_that -from cloudwatch_agent_config_util import validate_json +from cloudwatch_agent_config_util import validate_json, write_validated_json @pytest.mark.parametrize( "error_type", - [None, "Duplicates", "Schema", "Timestamp"], + [None, "Duplicates", "Schema", "Timestamp", "FileNotFound", "InvalidJSON"], ) -def test_validate_json(mocker, error_type): +def test_validate_json(mocker, test_datadir, error_type): input_json = { "timestamp_formats": { "month_first": "%b %-d %H:%M:%S", }, "log_configs": [{"timestamp_format_key": "month_first", "log_stream_name": "test"}], } + input_file_path = os.path.join(test_datadir, "config.json") + input_file = open(input_file_path, encoding="utf-8") if error_type == "Schema": input_json = "ERROR" elif error_type == "Duplicates": input_json["log_configs"].append({"timestamp_format_key": "month_first", "log_stream_name": "test"}) elif error_type == "Timestamp": input_json["log_configs"].append({"timestamp_format_key": "default", "log_stream_name": "test2"}) - print(input_json) schema = {"type": "object", "properties": {"timestamp_formats": {"type": "object"}}} - mocker.patch( - "cloudwatch_agent_config_util._read_json_at", - return_value=input_json, - ) + if error_type != "FileNotFound" and error_type != "InvalidJSON": + mocker.patch( + "builtins.open", + return_value=input_file, + ) mocker.patch( "cloudwatch_agent_config_util._read_schema", return_value=schema, ) + if error_type == "InvalidJSON": + mocker.patch( + "builtins.open", + side_effect=ValueError, + ) try: validate_json(input_json) validate_json() @@ -51,3 +60,42 @@ def test_validate_json(mocker, error_type): assert_that(e.args[0]).contains("The following log_stream_name values are used multiple times: test") elif error_type == "Timestamp": assert_that(e.args[0]).contains("contains an invalid timestamp_format_key") + elif error_type == "FileNotFound": + assert_that(e.args[0]).contains("No file exists") + elif error_type == "InvalidJSON": + assert_that(e.args[0]).contains("contains invalid JSON") + finally: + input_file.close() + + +def test_write_validated_json(mocker, test_datadir, tmpdir): + input_json = { + "timestamp_formats": { + "month_first": "%b %-d %H:%M:%S", + }, + "log_configs": [{"timestamp_format_key": "month_first", "log_stream_name": "test"}], + } + + input_json2 = { + "timestamp_formats": { + "default": "%Y-%m-%d %H:%M:%S,%f", + }, + "log_configs": [{"timestamp_format_key": "month_first", "log_stream_name": "test2"}], + } + + output_file = f"{tmpdir}/output.json" + + mocker.patch( + "cloudwatch_agent_config_util._read_json_at", + return_value=input_json, + ) + + mocker.patch("os.environ.get", return_value=output_file) + + write_validated_json(input_json2) + + with ( + open(output_file, encoding="utf-8") as f, + open(os.path.join(test_datadir, "output.json"), encoding="utf-8") as exp_f, + ): + assert_that(f.read()).is_equal_to(exp_f.read()) diff --git a/test/unit/cloudwatch_agent/test_cloudwatch_agent/test_validate_json/config.json b/test/unit/cloudwatch_agent/test_cloudwatch_agent/test_validate_json/config.json new file mode 100644 index 0000000000..c8958b5474 --- /dev/null +++ b/test/unit/cloudwatch_agent/test_cloudwatch_agent/test_validate_json/config.json @@ -0,0 +1,11 @@ +{ + "timestamp_formats": { + "month_first": "%b %-d %H:%M:%S" + }, + "log_configs": [ + { + "timestamp_format_key": "month_first", + "log_stream_name": "test" + } + ] +} \ No newline at end of file diff --git a/test/unit/cloudwatch_agent/test_cloudwatch_agent/test_write_validated_json/output.json b/test/unit/cloudwatch_agent/test_cloudwatch_agent/test_write_validated_json/output.json new file mode 100644 index 0000000000..87b00f6ac0 --- /dev/null +++ b/test/unit/cloudwatch_agent/test_cloudwatch_agent/test_write_validated_json/output.json @@ -0,0 +1,16 @@ +{ + "timestamp_formats": { + "default": "%Y-%m-%d %H:%M:%S,%f", + "month_first": "%b %-d %H:%M:%S" + }, + "log_configs": [ + { + "timestamp_format_key": "month_first", + "log_stream_name": "test2" + }, + { + "timestamp_format_key": "month_first", + "log_stream_name": "test" + } + ] +} \ No newline at end of file diff --git a/test/unit/cloudwatch_agent/test_write_cloudwatch_agent_json.py b/test/unit/cloudwatch_agent/test_write_cloudwatch_agent_json.py index fac1780f45..a71b7a32d0 100644 --- a/test/unit/cloudwatch_agent/test_write_cloudwatch_agent_json.py +++ b/test/unit/cloudwatch_agent/test_write_cloudwatch_agent_json.py @@ -8,6 +8,8 @@ # or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES # OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions and # limitations under the License. +import os + import pytest from assertpy import assert_that from write_cloudwatch_agent_json import ( @@ -18,6 +20,7 @@ add_timestamps, create_config, filter_output_fields, + gethostname, select_configs_for_feature, select_configs_for_node_role, select_configs_for_platform, @@ -77,7 +80,6 @@ } -@pytest.mark.asyncio def test_add_log_group_name_params(): configs = add_log_group_name_params("test", CONFIGS) for config in configs: @@ -85,41 +87,28 @@ def test_add_log_group_name_params(): assert_that(config["log_group_name"]).is_equal_to("test") -@pytest.mark.asyncio -def test_add_instance_log_stream_prefixes(mocker): - instance_id = "i-0096test" - mocker.patch( - "write_cloudwatch_agent_json.gethostname", - return_value=instance_id, - ) - +def test_add_instance_log_stream_prefixes(): configs = add_instance_log_stream_prefixes(CONFIGS) for config in configs: - assert_that(config["log_stream_name"]).contains(instance_id) - - -@pytest.mark.asyncio -def test_select_configs_for_scheduler(): - configs = select_configs_for_scheduler(CONFIGS, "slurm") - assert_that(len(configs)).is_equal_to(2) - configs = select_configs_for_scheduler(CONFIGS, "awsbatch") - assert_that(len(configs)).is_equal_to(3) + assert_that(config["log_stream_name"]).contains(gethostname()) -@pytest.mark.asyncio -def test_select_configs_for_node_role(): - configs = select_configs_for_node_role(CONFIGS, "ComputeFleet") - assert_that(len(configs)).is_equal_to(2) - configs = select_configs_for_node_role(CONFIGS, "HeadNode") - assert_that(len(configs)).is_equal_to(3) - - -@pytest.mark.asyncio -def test_select_configs_for_platform(): - configs = select_configs_for_platform(CONFIGS, "amazon") - assert_that(len(configs)).is_equal_to(2) - configs = select_configs_for_platform(CONFIGS, "ubuntu") - assert_that(len(configs)).is_equal_to(2) +@pytest.mark.parametrize( + "dimensions", + [ + {"platform": "amazon", "length": 2}, + {"platform": "ubuntu", "length": 2}, + {"scheduler": "slurm", "length": 2}, + {"scheduler": "awsbatch", "length": 3}, + ], +) +def test_select_configs_for_dimesion(dimensions): + if "platform" in dimensions.keys(): + configs = select_configs_for_platform(CONFIGS, dimensions["platform"]) + assert_that(len(configs)).is_equal_to(dimensions["length"]) + else: + configs = select_configs_for_scheduler(CONFIGS, dimensions["scheduler"]) + assert_that(len(configs)).is_equal_to(dimensions["length"]) @pytest.mark.parametrize( @@ -140,7 +129,6 @@ def test_select_configs_for_feature(mocker, info): assert_that(len(selected_configs)).is_equal_to(info["length"]) -@pytest.mark.asyncio def test_add_timestamps(): timestamp_formats = {"month_first": "%b %-d %H:%M:%S", "default": "%Y-%m-%d %H:%M:%S,%f"} configs = add_timestamps(CONFIGS, timestamp_formats) @@ -149,7 +137,6 @@ def test_add_timestamps(): assert_that(config["timestamp_format"]).is_equal_to(timestamp_format) -@pytest.mark.asyncio def test_filter_output_fields(): desired_keys = ["log_stream_name", "file_path", "timestamp_format", "log_group_name"] configs = filter_output_fields(CONFIGS) @@ -158,22 +145,14 @@ def test_filter_output_fields(): assert_that(desired_keys).contains(key) -@pytest.mark.asyncio -def test_create_config(mocker): - instance_id = "i-0096test" - mocker.patch( - "write_cloudwatch_agent_json.gethostname", - return_value=instance_id, - ) - +def test_create_config(): cw_agent_config = create_config(CONFIGS, METRIC_CONFIGS) assert_that(len(cw_agent_config)).is_equal_to(2) assert_that(len(cw_agent_config["logs"]["logs_collected"]["files"]["collect_list"])).is_equal_to(3) - assert_that(cw_agent_config["logs"]["log_stream_name"]).contains(instance_id) + assert_that(cw_agent_config["logs"]["log_stream_name"]).contains(gethostname()) -@pytest.mark.asyncio def test_select_metrics(mocker): mocker.patch( "write_cloudwatch_agent_json.select_configs_for_node_role", @@ -187,23 +166,31 @@ def test_select_metrics(mocker): assert_that(metric_configs["metrics_collected"][key]).does_not_contain_key("node_roles") -@pytest.mark.asyncio -def test_add_append_dimensions(): - metrics = {"metrics_collected": METRIC_CONFIGS["metrics_collected"]} - metrics = add_append_dimensions(metrics, METRIC_CONFIGS) - - assert_that(len(metrics)).is_equal_to(2) - assert_that(metrics["append_dimensions"]).is_type_of(dict) - assert_that(metrics["append_dimensions"]).contains_key("InstanceId") +@pytest.mark.parametrize( + "node", + [ + {"role": "ComputeFleet", "length": 2}, + {"role": "HeadNode", "length": 3}, + ], +) +def test_select_configs_for_node_role(node): + configs = select_configs_for_node_role(CONFIGS, node["role"]) + assert_that(len(configs)).is_equal_to(node["length"]) -@pytest.mark.asyncio -def test_add_aggregation_dimensions(): +@pytest.mark.parametrize( + "dimension", + [{"name": "append", "type": dict}, {"name": "aggregation", "type": list}], +) +def test_add_dimensions(dimension): metrics = {"metrics_collected": METRIC_CONFIGS["metrics_collected"]} - metrics = add_aggregation_dimensions(metrics, METRIC_CONFIGS) + if dimension["name"] == "append": + metrics = add_append_dimensions(metrics, METRIC_CONFIGS) + output = metrics["append_dimensions"] + else: + metrics = add_aggregation_dimensions(metrics, METRIC_CONFIGS) + output = metrics["aggregation_dimensions"][0] assert_that(len(metrics)).is_equal_to(2) - assert_that(len(metrics["aggregation_dimensions"])).is_equal_to(1) - assert_that(metrics["aggregation_dimensions"][0]).is_type_of(list) - assert_that(metrics["aggregation_dimensions"][0]).contains("InstanceId") - assert_that(metrics["aggregation_dimensions"][0]).contains("path") + assert_that(output).is_type_of(dimension["type"]) + assert_that(output).contains("InstanceId")