Skip to content

Commit

Permalink
Add more unit test coverage and remove unused functions
Browse files Browse the repository at this point in the history
  • Loading branch information
hgreebe committed Nov 28, 2023
1 parent 85cb2c0 commit 4a1c52e
Show file tree
Hide file tree
Showing 6 changed files with 132 additions and 86 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -136,12 +136,12 @@ def _write_log_configs(log_configs):
def write_validated_json(input_json):
    """Merge validated JSON with the CloudWatch log configs file and write it back.

    The current configuration is read with ``_read_log_configs`` and folded
    into ``input_json``, which is modified in place: the existing
    ``log_configs`` entries are appended after the new ones, and the existing
    ``timestamp_formats`` entries are applied on top of the new ones. The
    merged dict is then written exactly once with ``_write_log_configs``.

    :param input_json: validated dict holding a ``log_configs`` list and a
        ``timestamp_formats`` dict
    """
    log_configs = _read_log_configs()
    # Fall back to an empty list so a stored config without the key does not
    # make list.extend() fail on None.
    input_json["log_configs"].extend(log_configs.get("log_configs") or [])

    # NOTICE: the input JSON's timestamp_formats dict is the one that is
    # updated, so that those defined in the original config aren't clobbered.
    # dict.update() works in place and returns None, so its result must not
    # be assigned anywhere.
    input_json["timestamp_formats"].update(log_configs.get("timestamp_formats") or {})
    _write_log_configs(input_json)


def create_backup():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,11 +125,6 @@ def select_logs(configs, args):
return selected_configs


def get_node_roles(scheudler_plugin_node_roles):
    """Translate a node-type keyword into the list of node roles it covers.

    "ALL" maps to both roles, "HEAD" to the head node only and "COMPUTE" to
    the compute fleet only. Any other keyword yields None.
    """
    head, compute = "HeadNode", "ComputeFleet"
    roles_by_node_type = dict(ALL=[compute, head], HEAD=[head], COMPUTE=[compute])
    try:
        return roles_by_node_type[scheudler_plugin_node_roles]
    except KeyError:
        return None


def add_timestamps(configs, timestamps_dict):
"""For each config, set its timestamp_format field based on its timestamp_format_key field."""
for config in configs:
Expand All @@ -156,8 +151,6 @@ def _collect_metric_properties(metric_config):
# initial dict with default key-value pairs
collected = {"metrics_collection_interval": DEFAULT_METRICS_COLLECTION_INTERVAL}
collected.update({key: metric_config[key] for key in desired_keys if key in metric_config})
if "append_dimensions" in metric_config and "ClusterName" in metric_config["append_dimensions"]:
collected.update({"append_dimensions": {"ClusterName": get_node_info().get("stack_name")}})
return collected

return {
Expand Down Expand Up @@ -211,15 +204,6 @@ def create_config(log_configs, metric_configs):
return cw_agent_config


def get_dict_value(value, attributes, default=None):
    """Get a nested value from a dictionary by following a dotted key path.

    :param value: dictionary (or any object with a ``get`` method) to search
    :param attributes: dot-separated key path, e.g. ``"a.b.c"``
    :param default: returned when the path cannot be resolved
    :return: the value found at the end of the path, or ``default`` when a key
        is missing, a value along the path is None, or an intermediate value
        has no ``get`` method and so cannot be descended into
    """
    for key in attributes.split("."):
        try:
            value = value.get(key)
        except AttributeError:
            # An intermediate value is not dict-like, so the rest of the path
            # cannot exist; treat it like a missing key instead of crashing.
            return default
        if value is None:
            return default
    return value


def main():
"""Create cloudwatch agent config file."""
args = parse_args()
Expand Down
64 changes: 56 additions & 8 deletions test/unit/cloudwatch_agent/test_cloudwatch_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,38 +8,47 @@
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions and
# limitations under the License.
import os

import pytest
from assertpy import assert_that
from cloudwatch_agent_config_util import validate_json
from cloudwatch_agent_config_util import validate_json, write_validated_json


@pytest.mark.parametrize(
"error_type",
[None, "Duplicates", "Schema", "Timestamp"],
[None, "Duplicates", "Schema", "Timestamp", "FileNotFound", "InvalidJSON"],
)
def test_validate_json(mocker, error_type):
def test_validate_json(mocker, test_datadir, error_type):
input_json = {
"timestamp_formats": {
"month_first": "%b %-d %H:%M:%S",
},
"log_configs": [{"timestamp_format_key": "month_first", "log_stream_name": "test"}],
}
input_file_path = os.path.join(test_datadir, "config.json")
input_file = open(input_file_path, encoding="utf-8")
if error_type == "Schema":
input_json = "ERROR"
elif error_type == "Duplicates":
input_json["log_configs"].append({"timestamp_format_key": "month_first", "log_stream_name": "test"})
elif error_type == "Timestamp":
input_json["log_configs"].append({"timestamp_format_key": "default", "log_stream_name": "test2"})
print(input_json)
schema = {"type": "object", "properties": {"timestamp_formats": {"type": "object"}}}
mocker.patch(
"cloudwatch_agent_config_util._read_json_at",
return_value=input_json,
)
if error_type != "FileNotFound" and error_type != "InvalidJSON":
mocker.patch(
"builtins.open",
return_value=input_file,
)
mocker.patch(
"cloudwatch_agent_config_util._read_schema",
return_value=schema,
)
if error_type == "InvalidJSON":
mocker.patch(
"builtins.open",
side_effect=ValueError,
)
try:
validate_json(input_json)
validate_json()
Expand All @@ -51,3 +60,42 @@ def test_validate_json(mocker, error_type):
assert_that(e.args[0]).contains("The following log_stream_name values are used multiple times: test")
elif error_type == "Timestamp":
assert_that(e.args[0]).contains("contains an invalid timestamp_format_key")
elif error_type == "FileNotFound":
assert_that(e.args[0]).contains("No file exists")
elif error_type == "InvalidJSON":
assert_that(e.args[0]).contains("contains invalid JSON")
finally:
input_file.close()


def test_write_validated_json(mocker, test_datadir, tmpdir):
    """Compare the file written by write_validated_json with the expected output.json fixture."""
    existing_config = {
        "timestamp_formats": {
            "month_first": "%b %-d %H:%M:%S",
        },
        "log_configs": [{"timestamp_format_key": "month_first", "log_stream_name": "test"}],
    }
    new_config = {
        "timestamp_formats": {
            "default": "%Y-%m-%d %H:%M:%S,%f",
        },
        "log_configs": [{"timestamp_format_key": "month_first", "log_stream_name": "test2"}],
    }

    output_file = f"{tmpdir}/output.json"
    expected_file = os.path.join(test_datadir, "output.json")

    # _read_json_at is stubbed so that it hands back the first config.
    mocker.patch("cloudwatch_agent_config_util._read_json_at", return_value=existing_config)
    # Every os.environ.get lookup resolves to the temporary output path.
    mocker.patch("os.environ.get", return_value=output_file)

    write_validated_json(new_config)

    with open(output_file, encoding="utf-8") as actual, open(expected_file, encoding="utf-8") as expected:
        assert_that(actual.read()).is_equal_to(expected.read())
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
{
"timestamp_formats": {
"month_first": "%b %-d %H:%M:%S"
},
"log_configs": [
{
"timestamp_format_key": "month_first",
"log_stream_name": "test"
}
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
{
"timestamp_formats": {
"default": "%Y-%m-%d %H:%M:%S,%f",
"month_first": "%b %-d %H:%M:%S"
},
"log_configs": [
{
"timestamp_format_key": "month_first",
"log_stream_name": "test2"
},
{
"timestamp_format_key": "month_first",
"log_stream_name": "test"
}
]
}
105 changes: 46 additions & 59 deletions test/unit/cloudwatch_agent/test_write_cloudwatch_agent_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions and
# limitations under the License.
import os

import pytest
from assertpy import assert_that
from write_cloudwatch_agent_json import (
Expand All @@ -18,6 +20,7 @@
add_timestamps,
create_config,
filter_output_fields,
gethostname,
select_configs_for_feature,
select_configs_for_node_role,
select_configs_for_platform,
Expand Down Expand Up @@ -77,49 +80,35 @@
}


def test_add_log_group_name_params():
    """Check that every config returned by add_log_group_name_params has log_group_name set to "test".

    The test is synchronous, so it carries no asyncio marker.
    """
    configs = add_log_group_name_params("test", CONFIGS)
    for config in configs:
        assert_that(config).contains("log_group_name")
        assert_that(config["log_group_name"]).is_equal_to("test")


@pytest.mark.asyncio
def test_add_instance_log_stream_prefixes(mocker):
instance_id = "i-0096test"
mocker.patch(
"write_cloudwatch_agent_json.gethostname",
return_value=instance_id,
)

def test_add_instance_log_stream_prefixes():
configs = add_instance_log_stream_prefixes(CONFIGS)
for config in configs:
assert_that(config["log_stream_name"]).contains(instance_id)


@pytest.mark.asyncio
def test_select_configs_for_scheduler():
configs = select_configs_for_scheduler(CONFIGS, "slurm")
assert_that(len(configs)).is_equal_to(2)
configs = select_configs_for_scheduler(CONFIGS, "awsbatch")
assert_that(len(configs)).is_equal_to(3)
assert_that(config["log_stream_name"]).contains(gethostname())


@pytest.mark.asyncio
def test_select_configs_for_node_role():
configs = select_configs_for_node_role(CONFIGS, "ComputeFleet")
assert_that(len(configs)).is_equal_to(2)
configs = select_configs_for_node_role(CONFIGS, "HeadNode")
assert_that(len(configs)).is_equal_to(3)


@pytest.mark.asyncio
def test_select_configs_for_platform():
configs = select_configs_for_platform(CONFIGS, "amazon")
assert_that(len(configs)).is_equal_to(2)
configs = select_configs_for_platform(CONFIGS, "ubuntu")
assert_that(len(configs)).is_equal_to(2)
@pytest.mark.parametrize(
    "dimensions",
    [
        {"platform": "amazon", "length": 2},
        {"platform": "ubuntu", "length": 2},
        {"scheduler": "slurm", "length": 2},
        {"scheduler": "awsbatch", "length": 3},
    ],
)
def test_select_configs_for_dimesion(dimensions):
    """Check how many configs are selected for a given platform or scheduler.

    Cases with a "platform" key go through select_configs_for_platform; all
    other cases go through select_configs_for_scheduler.
    """
    if "platform" in dimensions:
        selected = select_configs_for_platform(CONFIGS, dimensions["platform"])
    else:
        selected = select_configs_for_scheduler(CONFIGS, dimensions["scheduler"])
    assert_that(len(selected)).is_equal_to(dimensions["length"])


@pytest.mark.parametrize(
Expand All @@ -140,7 +129,6 @@ def test_select_configs_for_feature(mocker, info):
assert_that(len(selected_configs)).is_equal_to(info["length"])


@pytest.mark.asyncio
def test_add_timestamps():
timestamp_formats = {"month_first": "%b %-d %H:%M:%S", "default": "%Y-%m-%d %H:%M:%S,%f"}
configs = add_timestamps(CONFIGS, timestamp_formats)
Expand All @@ -149,7 +137,6 @@ def test_add_timestamps():
assert_that(config["timestamp_format"]).is_equal_to(timestamp_format)


@pytest.mark.asyncio
def test_filter_output_fields():
desired_keys = ["log_stream_name", "file_path", "timestamp_format", "log_group_name"]
configs = filter_output_fields(CONFIGS)
Expand All @@ -158,22 +145,14 @@ def test_filter_output_fields():
assert_that(desired_keys).contains(key)


@pytest.mark.asyncio
def test_create_config(mocker):
instance_id = "i-0096test"
mocker.patch(
"write_cloudwatch_agent_json.gethostname",
return_value=instance_id,
)

def test_create_config():
cw_agent_config = create_config(CONFIGS, METRIC_CONFIGS)

assert_that(len(cw_agent_config)).is_equal_to(2)
assert_that(len(cw_agent_config["logs"]["logs_collected"]["files"]["collect_list"])).is_equal_to(3)
assert_that(cw_agent_config["logs"]["log_stream_name"]).contains(instance_id)
assert_that(cw_agent_config["logs"]["log_stream_name"]).contains(gethostname())


@pytest.mark.asyncio
def test_select_metrics(mocker):
mocker.patch(
"write_cloudwatch_agent_json.select_configs_for_node_role",
Expand All @@ -187,23 +166,31 @@ def test_select_metrics(mocker):
assert_that(metric_configs["metrics_collected"][key]).does_not_contain_key("node_roles")


@pytest.mark.asyncio
def test_add_append_dimensions():
metrics = {"metrics_collected": METRIC_CONFIGS["metrics_collected"]}
metrics = add_append_dimensions(metrics, METRIC_CONFIGS)

assert_that(len(metrics)).is_equal_to(2)
assert_that(metrics["append_dimensions"]).is_type_of(dict)
assert_that(metrics["append_dimensions"]).contains_key("InstanceId")
@pytest.mark.parametrize(
    "node",
    [
        {"role": "ComputeFleet", "length": 2},
        {"role": "HeadNode", "length": 3},
    ],
)
def test_select_configs_for_node_role(node):
    """Check how many configs select_configs_for_node_role returns for each node role."""
    selected = select_configs_for_node_role(CONFIGS, node["role"])
    expected_count = node["length"]
    assert_that(len(selected)).is_equal_to(expected_count)


@pytest.mark.asyncio
def test_add_aggregation_dimensions():
@pytest.mark.parametrize(
"dimension",
[{"name": "append", "type": dict}, {"name": "aggregation", "type": list}],
)
def test_add_dimensions(dimension):
metrics = {"metrics_collected": METRIC_CONFIGS["metrics_collected"]}
metrics = add_aggregation_dimensions(metrics, METRIC_CONFIGS)
if dimension["name"] == "append":
metrics = add_append_dimensions(metrics, METRIC_CONFIGS)
output = metrics["append_dimensions"]
else:
metrics = add_aggregation_dimensions(metrics, METRIC_CONFIGS)
output = metrics["aggregation_dimensions"][0]

assert_that(len(metrics)).is_equal_to(2)
assert_that(len(metrics["aggregation_dimensions"])).is_equal_to(1)
assert_that(metrics["aggregation_dimensions"][0]).is_type_of(list)
assert_that(metrics["aggregation_dimensions"][0]).contains("InstanceId")
assert_that(metrics["aggregation_dimensions"][0]).contains("path")
assert_that(output).is_type_of(dimension["type"])
assert_that(output).contains("InstanceId")

0 comments on commit 4a1c52e

Please sign in to comment.