From ce0bb327097e5348f6bd34c1e615a4a0fb20af75 Mon Sep 17 00:00:00 2001
From: Ryan Anderson
Date: Tue, 13 Feb 2024 15:27:27 -0500
Subject: [PATCH] Add an `ec2_backing_instance_max_count` config option for
 clustermgtd to be robust to EC2 eventual consistency

Add a config option, `ec2_backing_instance_max_count`, to allow clustermgtd
more time for describe-instances to reach eventual consistency with
run-instances data.

Pass the max count and the per-node count map to is_healthy() for static and
dynamic nodes so the count can be evaluated for individual instances.
---
 CHANGELOG.md                                 |  2 +
 src/slurm_plugin/clustermgtd.py              |  9 ++
 src/slurm_plugin/slurm_resources.py          | 73 +++++++++++++--
 .../slurm_resources/test_slurm_resources.py  | 88 ++++++++++++++++++-
 tests/slurm_plugin/test_clustermgtd.py       |  2 +
 5 files changed, 163 insertions(+), 11 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1e43d4e6d..ab9fead1f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,8 @@ This file is used to list changes made in each version of the aws-parallelcluste
 ------
 
 **ENHANCEMENTS**
+- Add a clustermgtd config option `ec2_backing_instance_max_count` to allow a configurable amount of time for EC2
+  describe-instances data to reach eventual consistency with run-instances
 
 **CHANGES**
 
diff --git a/src/slurm_plugin/clustermgtd.py b/src/slurm_plugin/clustermgtd.py
index 12870bd6b..bf638d667 100644
--- a/src/slurm_plugin/clustermgtd.py
+++ b/src/slurm_plugin/clustermgtd.py
@@ -150,6 +150,7 @@ class ClustermgtdConfig:
         "terminate_drain_nodes": True,
         "terminate_down_nodes": True,
         "orphaned_instance_timeout": 300,
+        "ec2_backing_instance_max_count": 2,
         # Health check configs
         "disable_ec2_health_check": False,
         "disable_scheduled_event_health_check": False,
@@ -304,6 +305,11 @@ def _get_terminate_config(self, config):
         self.insufficient_capacity_timeout = config.getfloat(
             "clustermgtd", "insufficient_capacity_timeout", fallback=self.DEFAULTS.get("insufficient_capacity_timeout")
         )
+        self.ec2_backing_instance_max_count = config.getfloat(
+            "clustermgtd",
+            "ec2_backing_instance_max_count",
+            fallback=self.DEFAULTS.get("ec2_backing_instance_max_count"),
+        )
         self.disable_nodes_on_insufficient_capacity = self.insufficient_capacity_timeout > 0
 
     def _get_dns_config(self, config):
@@ -384,6 +390,7 @@ def __init__(self, config):
         self._insufficient_capacity_compute_resources = {}
         self._static_nodes_in_replacement = set()
         self._partitions_protected_failure_count_map = {}
+        self._nodes_without_backing_instance_count_map = {}
         self._compute_fleet_status = ComputeFleetStatus.RUNNING
         self._current_time = None
         self._config = None
@@ -763,6 +770,8 @@ def _find_unhealthy_slurm_nodes(self, slurm_nodes):
                 consider_drain_as_unhealthy=self._config.terminate_drain_nodes,
                 consider_down_as_unhealthy=self._config.terminate_down_nodes,
                 log_warn_if_unhealthy=node.name not in reserved_nodenames,
+                ec2_backing_instance_max_count=self._config.ec2_backing_instance_max_count,
+                nodes_without_backing_instance_count_map=self._nodes_without_backing_instance_count_map,
             ):
                 if not self._config.disable_capacity_blocks_management and node.name in reserved_nodenames:
                     # do not consider as unhealthy the nodes reserved for capacity blocks
diff --git a/src/slurm_plugin/slurm_resources.py b/src/slurm_plugin/slurm_resources.py
index 05ab7a41c..2364bf5a5 100644
--- a/src/slurm_plugin/slurm_resources.py
+++ b/src/slurm_plugin/slurm_resources.py
@@ -233,6 +233,7 @@ def __init__(
         self.is_failing_health_check = False
         self.error_code = self._parse_error_code()
         self.queue_name, self._node_type, self.compute_resource_name = parse_nodename(name)
+        self.missing_count_incremented = False
 
     def is_nodeaddr_set(self):
         """Check if nodeaddr(private ip) for the node is set."""
@@ -394,7 +395,14 @@ def is_bootstrap_timeout(self):
         pass
 
     @abstractmethod
-    def is_healthy(self, consider_drain_as_unhealthy, consider_down_as_unhealthy, log_warn_if_unhealthy=True):
+    def is_healthy(
+        self,
+        consider_drain_as_unhealthy,
+        consider_down_as_unhealthy,
+        log_warn_if_unhealthy=True,
+        ec2_backing_instance_max_count=None,
+        nodes_without_backing_instance_count_map: dict = None,
+    ):
         """Check if a slurm node is considered healthy."""
         pass
 
@@ -404,7 +412,12 @@ def is_powering_down_with_nodeaddr(self):
         # for example because of a short SuspendTimeout
         return self.is_nodeaddr_set() and (self.is_power() or self.is_powering_down())
 
-    def is_backing_instance_valid(self, log_warn_if_unhealthy=True):
+    def is_backing_instance_valid(
+        self,
+        log_warn_if_unhealthy=True,
+        ec2_backing_instance_max_count=None,
+        nodes_without_backing_instance_count_map: dict = None,
+    ):
         """Check if a slurm node's addr is set, it points to a valid instance in EC2."""
         if self.is_nodeaddr_set():
             if not self.instance:
@@ -414,7 +427,31 @@
                         self,
                         self.state_string,
                     )
-                return False
+                # Allow a few iterations for the eventual consistency of EC2 data
+                logger.info(
+                    f"ec2_backing_instance_max_count {ec2_backing_instance_max_count} "
+                    f"nodes_without_backing_instance_count_map {nodes_without_backing_instance_count_map}"
+                )
+                if any(
+                    args in [None]
+                    for args in [ec2_backing_instance_max_count, nodes_without_backing_instance_count_map]
+                ):
+                    logger.info("No max count or map provided, ignoring backing instance timeout")
+                    return False
+                elif nodes_without_backing_instance_count_map.get(self.name, 0) >= ec2_backing_instance_max_count:
+                    logger.warning(f"Instance {self.name} availability has timed out.")
+                    return False
+                else:
+                    if not self.missing_count_incremented:
+                        nodes_without_backing_instance_count_map[self.name] = (
+                            nodes_without_backing_instance_count_map.get(self.name, 0) + 1
+                        )
+                        logger.warning(
+                            f"Instance {self.name} is not yet available in EC2, "
+                            f"incrementing missing count to {nodes_without_backing_instance_count_map[self.name]}."
+                        )
+                        self.missing_count_incremented = True
+                    return True
         return True
 
     @abstractmethod
@@ -478,11 +515,22 @@ def __init__(
             reservation_name=reservation_name,
         )
 
-    def is_healthy(self, consider_drain_as_unhealthy, consider_down_as_unhealthy, log_warn_if_unhealthy=True):
+    def is_healthy(
+        self,
+        consider_drain_as_unhealthy,
+        consider_down_as_unhealthy,
+        log_warn_if_unhealthy=True,
+        ec2_backing_instance_max_count=None,
+        nodes_without_backing_instance_count_map: dict = None,
+    ):
         """Check if a slurm node is considered healthy."""
         return (
             self._is_static_node_ip_configuration_valid(log_warn_if_unhealthy=log_warn_if_unhealthy)
-            and self.is_backing_instance_valid(log_warn_if_unhealthy=log_warn_if_unhealthy)
+            and self.is_backing_instance_valid(
+                log_warn_if_unhealthy=log_warn_if_unhealthy,
+                ec2_backing_instance_max_count=ec2_backing_instance_max_count,
+                nodes_without_backing_instance_count_map=nodes_without_backing_instance_count_map,
+            )
             and self.is_state_healthy(
                 consider_drain_as_unhealthy, consider_down_as_unhealthy, log_warn_if_unhealthy=log_warn_if_unhealthy
             )
@@ -618,9 +666,20 @@ def is_state_healthy(self, consider_drain_as_unhealthy, consider_down_as_unhealt
                 return False
         return True
 
-    def is_healthy(self, consider_drain_as_unhealthy, consider_down_as_unhealthy, log_warn_if_unhealthy=True):
+    def is_healthy(
+        self,
+        consider_drain_as_unhealthy,
+        consider_down_as_unhealthy,
+        log_warn_if_unhealthy=True,
+        ec2_backing_instance_max_count=None,
+        nodes_without_backing_instance_count_map: dict = None,
+    ):
         """Check if a slurm node is considered healthy."""
-        return self.is_backing_instance_valid(log_warn_if_unhealthy=log_warn_if_unhealthy) and self.is_state_healthy(
+        return self.is_backing_instance_valid(
+            log_warn_if_unhealthy=log_warn_if_unhealthy,
+            ec2_backing_instance_max_count=ec2_backing_instance_max_count,
+            nodes_without_backing_instance_count_map=nodes_without_backing_instance_count_map,
+        ) and self.is_state_healthy(
             consider_drain_as_unhealthy, consider_down_as_unhealthy, log_warn_if_unhealthy=log_warn_if_unhealthy
         )
 
diff --git a/tests/slurm_plugin/slurm_resources/test_slurm_resources.py b/tests/slurm_plugin/slurm_resources/test_slurm_resources.py
index 0d1291e3b..72fca0dad 100644
--- a/tests/slurm_plugin/slurm_resources/test_slurm_resources.py
+++ b/tests/slurm_plugin/slurm_resources/test_slurm_resources.py
@@ -1086,34 +1086,114 @@ def test_slurm_node_is_powering_down_with_nodeaddr(node, expected_result):
 
 
 @pytest.mark.parametrize(
-    "node, instance, expected_result",
+    "node, instance, max_count, count_map, final_count, count_incremented, expected_result",
     [
         (
             StaticNode("queue1-st-c5xlarge-1", "ip-1", "hostname", "IDLE+CLOUD", "queue1"),
             None,
+            None,
+            None,
+            None,
+            False,
             False,
         ),
         (
             DynamicNode("node-dy-c5xlarge-1", "node-dy-c5xlarge-1", "hostname", "IDLE+CLOUD+POWER", "node"),
             None,
+            None,
+            None,
+            None,
+            False,
             True,
         ),
         (
             DynamicNode("node-dy-c5xlarge-1", "ip-1", "hostname", "IDLE+CLOUD+POWER", "node"),
             None,
+            None,
+            None,
+            None,
+            False,
             False,
         ),
         (
             StaticNode("queue1-st-c5xlarge-1", "ip-1", "hostname", "IDLE+CLOUD+POWER", "queue1"),
             EC2Instance("id-1", "ip-1", "hostname", datetime(2020, 1, 1, 0, 0, 0)),
+            None,
+            None,
+            None,
+            False,
             True,
         ),
+        (
+            StaticNode("queue1-st-c5xlarge-1", "ip-1", "hostname", "IDLE+CLOUD+POWER", "queue1"),
+            None,
+            0,
+            None,
+            None,
+            False,
+            False,
+        ),
+        (
+            StaticNode("queue1-st-c5xlarge-1", "ip-1", "hostname", "IDLE+CLOUD+POWER", "queue1"),
+            None,
+            None,
+            {"queue1-st-c5xlarge-1": 1},
+            1,
+            False,
+            False,
+        ),
+        (
+            StaticNode("queue1-st-c5xlarge-1", "ip-1", "hostname", "IDLE+CLOUD+POWER", "queue1"),
+            None,
+            2,
+            {"queue1-st-c5xlarge-1": 1},
+            2,
+            True,
+            True,
+        ),
+        (
+            StaticNode("queue1-st-c5xlarge-1", "ip-1", "hostname", "IDLE+CLOUD+POWER", "queue1"),
+            None,
+            2,
+            {"queue1-st-c5xlarge-1": 2},
+            2,
+            False,
+            False,
+        ),
+        (
+            StaticNode("queue1-st-c5xlarge-1", "ip-1", "hostname", "IDLE+CLOUD+POWER", "queue1"),
+            None,
+            2,
+            {"queue1-st-c5xlarge-1": 3},
+            3,
+            False,
+            False,
+        ),
+    ],
+    ids=[
+        "static_no_backing",
+        "dynamic_power_save",
+        "dynamic_no_backing",
+        "static_valid",
+        "static_no_backing_with_count_no_map",
+        "static_no_backing_no_count_with_map",
+        "static_no_backing_with_count_not_exceeded_with_map",
+        "static_no_backing_with_count_exceeded_with_map",
+        "static_no_backing_with_count_exceeded_with_map_2",
     ],
-    ids=["static_no_backing", "dynamic_power_save", "dynamic_no_backing", "static_valid"],
 )
-def test_slurm_node_is_backing_instance_valid(node, instance, expected_result):
+def test_slurm_node_is_backing_instance_valid(
+    node, instance, max_count, count_map, final_count, count_incremented, expected_result
+):
     node.instance = instance
-    assert_that(node.is_backing_instance_valid()).is_equal_to(expected_result)
+    assert_that(
+        node.is_backing_instance_valid(
+            ec2_backing_instance_max_count=max_count, nodes_without_backing_instance_count_map=count_map
+        )
+    ).is_equal_to(expected_result)
+    assert_that(node.missing_count_incremented).is_equal_to(count_incremented)
+    if count_map:
+        assert_that(count_map[node.name]).is_equal_to(final_count)
 
 
 @pytest.mark.parametrize(
diff --git a/tests/slurm_plugin/test_clustermgtd.py b/tests/slurm_plugin/test_clustermgtd.py
index 8ad53bfbe..0f05b615e 100644
--- a/tests/slurm_plugin/test_clustermgtd.py
+++ b/tests/slurm_plugin/test_clustermgtd.py
@@ -1352,6 +1352,7 @@ def test_maintain_nodes(
         region="region",
         boto3_config=None,
         fleet_config={},
+        ec2_backing_instance_max_count=0,
     )
     cluster_manager = ClusterManager(mock_sync_config)
     cluster_manager._static_nodes_in_replacement = static_nodes_in_replacement
@@ -3877,6 +3878,7 @@ def test_find_unhealthy_slurm_nodes(
         boto3_config=None,
         fleet_config={},
         disable_capacity_blocks_management=disable_capacity_blocks_management,
+        ec2_backing_instance_max_count=0,
     )
     cluster_manager = ClusterManager(mock_sync_config)
     get_reserved_mock = mocker.patch.object(
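
Usage note (illustrative sketch, not part of the diff above): the new option is
read from the `clustermgtd` section of the clustermgtd config file by
`_get_terminate_config`, with a default of 2. An operator who needs a longer
grace period could raise it, for example (the value 5 below is arbitrary):

    [clustermgtd]
    # Maximum number of clustermgtd iterations a node may remain without a
    # backing instance in describe-instances before it is treated as unhealthy
    # (illustrative value; the default is 2).
    ec2_backing_instance_max_count = 5

Because the per-node counter in _nodes_without_backing_instance_count_map is
incremented at most once per clustermgtd iteration, the effective grace period
is roughly ec2_backing_instance_max_count multiplied by the daemon's loop
period.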