Skip to content

Commit

Permalink
Add a backing_instance_max_count config option for clustermgtd to b…
Browse files Browse the repository at this point in the history
…e robust to eventual EC2 consistency

Add a config option to clustermgtd, ec2_backing_instance_max_count, to allow more time for describe-instances to reach eventual consistency with run-instances data.
Pass the max count and the nodes_without_backing_instance_count_map to is_healthy() for static and dynamic nodes so that the missing-instance count is evaluated for each individual node.
  • Loading branch information
dreambeyondorange committed Feb 16, 2024
1 parent 799b4e1 commit ce0bb32
Show file tree
Hide file tree
Showing 5 changed files with 163 additions and 11 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ This file is used to list changes made in each version of the aws-parallelcluste
------

**ENHANCEMENTS**
- Add a clustermgtd config option `ec2_backing_instance_max_count` to allow a configurable amount of time for EC2
describe-instances results to become eventually consistent with run-instances.

**CHANGES**

Expand Down
9 changes: 9 additions & 0 deletions src/slurm_plugin/clustermgtd.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ class ClustermgtdConfig:
"terminate_drain_nodes": True,
"terminate_down_nodes": True,
"orphaned_instance_timeout": 300,
"ec2_backing_instance_max_count": 2,
# Health check configs
"disable_ec2_health_check": False,
"disable_scheduled_event_health_check": False,
Expand Down Expand Up @@ -304,6 +305,11 @@ def _get_terminate_config(self, config):
self.insufficient_capacity_timeout = config.getfloat(
"clustermgtd", "insufficient_capacity_timeout", fallback=self.DEFAULTS.get("insufficient_capacity_timeout")
)
self.ec2_backing_instance_max_count = config.getfloat(
"clustermgtd",
"ec2_backing_instance_max_count",
fallback=self.DEFAULTS.get("ec2_backing_instance_max_count"),
)
self.disable_nodes_on_insufficient_capacity = self.insufficient_capacity_timeout > 0

def _get_dns_config(self, config):
Expand Down Expand Up @@ -384,6 +390,7 @@ def __init__(self, config):
self._insufficient_capacity_compute_resources = {}
self._static_nodes_in_replacement = set()
self._partitions_protected_failure_count_map = {}
self._nodes_without_backing_instance_count_map = {}
self._compute_fleet_status = ComputeFleetStatus.RUNNING
self._current_time = None
self._config = None
Expand Down Expand Up @@ -763,6 +770,8 @@ def _find_unhealthy_slurm_nodes(self, slurm_nodes):
consider_drain_as_unhealthy=self._config.terminate_drain_nodes,
consider_down_as_unhealthy=self._config.terminate_down_nodes,
log_warn_if_unhealthy=node.name not in reserved_nodenames,
ec2_backing_instance_max_count=self._config.ec2_backing_instance_max_count,
nodes_without_backing_instance_count_map=self._nodes_without_backing_instance_count_map,
):
if not self._config.disable_capacity_blocks_management and node.name in reserved_nodenames:
# do not consider as unhealthy the nodes reserved for capacity blocks
Expand Down
73 changes: 66 additions & 7 deletions src/slurm_plugin/slurm_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ def __init__(
self.is_failing_health_check = False
self.error_code = self._parse_error_code()
self.queue_name, self._node_type, self.compute_resource_name = parse_nodename(name)
self.missing_count_incremented = False

def is_nodeaddr_set(self):
"""Check if nodeaddr(private ip) for the node is set."""
Expand Down Expand Up @@ -394,7 +395,14 @@ def is_bootstrap_timeout(self):
pass

@abstractmethod
def is_healthy(self, consider_drain_as_unhealthy, consider_down_as_unhealthy, log_warn_if_unhealthy=True):
def is_healthy(
self,
consider_drain_as_unhealthy,
consider_down_as_unhealthy,
log_warn_if_unhealthy=True,
ec2_backing_instance_max_count=None,
nodes_without_backing_instance_count_map: dict = None,
):
"""Check if a slurm node is considered healthy."""
pass

Expand All @@ -404,7 +412,12 @@ def is_powering_down_with_nodeaddr(self):
# for example because of a short SuspendTimeout
return self.is_nodeaddr_set() and (self.is_power() or self.is_powering_down())

def is_backing_instance_valid(self, log_warn_if_unhealthy=True):
def is_backing_instance_valid(
self,
log_warn_if_unhealthy=True,
ec2_backing_instance_max_count=None,
nodes_without_backing_instance_count_map: dict = None,
):
"""Check if a slurm node's addr is set, it points to a valid instance in EC2."""
if self.is_nodeaddr_set():
if not self.instance:
Expand All @@ -414,7 +427,31 @@ def is_backing_instance_valid(self, log_warn_if_unhealthy=True):
self,
self.state_string,
)
return False
# Allow a few iterations for the eventual consistency of EC2 data
logger.info(
f"ec2_backing_instance_max_count {ec2_backing_instance_max_count} "
f"nodes_without_backing_instance_count_map {nodes_without_backing_instance_count_map}"
)
if any(
args in [None]
for args in [ec2_backing_instance_max_count, nodes_without_backing_instance_count_map]
):
logger.info(f"No max count or map provided, ignoring backing instance timeout")
return False
elif nodes_without_backing_instance_count_map.get(self.name, 0) >= ec2_backing_instance_max_count:
logger.warning(f"Instance {self.name} availability has timed out.")
return False
else:
if not self.missing_count_incremented:
nodes_without_backing_instance_count_map[self.name] = (
nodes_without_backing_instance_count_map.get(self.name, 0) + 1
)
logger.warning(
f"Instance {self.name} is not yet available in EC2, "
f"incrementing missing count to {nodes_without_backing_instance_count_map[self.name]}."
)
self.missing_count_incremented = True
return True
return True

@abstractmethod
Expand Down Expand Up @@ -478,11 +515,22 @@ def __init__(
reservation_name=reservation_name,
)

def is_healthy(self, consider_drain_as_unhealthy, consider_down_as_unhealthy, log_warn_if_unhealthy=True):
def is_healthy(
self,
consider_drain_as_unhealthy,
consider_down_as_unhealthy,
log_warn_if_unhealthy=True,
ec2_backing_instance_max_count=None,
nodes_without_backing_instance_count_map: dict = None,
):
"""Check if a slurm node is considered healthy."""
return (
self._is_static_node_ip_configuration_valid(log_warn_if_unhealthy=log_warn_if_unhealthy)
and self.is_backing_instance_valid(log_warn_if_unhealthy=log_warn_if_unhealthy)
and self.is_backing_instance_valid(
log_warn_if_unhealthy=log_warn_if_unhealthy,
ec2_backing_instance_max_count=ec2_backing_instance_max_count,
nodes_without_backing_instance_count_map=nodes_without_backing_instance_count_map,
)
and self.is_state_healthy(
consider_drain_as_unhealthy, consider_down_as_unhealthy, log_warn_if_unhealthy=log_warn_if_unhealthy
)
Expand Down Expand Up @@ -618,9 +666,20 @@ def is_state_healthy(self, consider_drain_as_unhealthy, consider_down_as_unhealt
return False
return True

def is_healthy(self, consider_drain_as_unhealthy, consider_down_as_unhealthy, log_warn_if_unhealthy=True):
def is_healthy(
self,
consider_drain_as_unhealthy,
consider_down_as_unhealthy,
log_warn_if_unhealthy=True,
ec2_backing_instance_max_count=None,
nodes_without_backing_instance_count_map: dict = None,
):
"""Check if a slurm node is considered healthy."""
return self.is_backing_instance_valid(log_warn_if_unhealthy=log_warn_if_unhealthy) and self.is_state_healthy(
return self.is_backing_instance_valid(
log_warn_if_unhealthy=log_warn_if_unhealthy,
ec2_backing_instance_max_count=ec2_backing_instance_max_count,
nodes_without_backing_instance_count_map=nodes_without_backing_instance_count_map,
) and self.is_state_healthy(
consider_drain_as_unhealthy, consider_down_as_unhealthy, log_warn_if_unhealthy=log_warn_if_unhealthy
)

Expand Down
88 changes: 84 additions & 4 deletions tests/slurm_plugin/slurm_resources/test_slurm_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -1086,34 +1086,114 @@ def test_slurm_node_is_powering_down_with_nodeaddr(node, expected_result):


@pytest.mark.parametrize(
"node, instance, expected_result",
"node, instance, max_count, count_map, final_count, count_incremented, expected_result",
[
(
StaticNode("queue1-st-c5xlarge-1", "ip-1", "hostname", "IDLE+CLOUD", "queue1"),
None,
None,
None,
None,
False,
False,
),
(
DynamicNode("node-dy-c5xlarge-1", "node-dy-c5xlarge-1", "hostname", "IDLE+CLOUD+POWER", "node"),
None,
None,
None,
None,
False,
True,
),
(
DynamicNode("node-dy-c5xlarge-1", "ip-1", "hostname", "IDLE+CLOUD+POWER", "node"),
None,
None,
None,
None,
False,
False,
),
(
StaticNode("queue1-st-c5xlarge-1", "ip-1", "hostname", "IDLE+CLOUD+POWER", "queue1"),
EC2Instance("id-1", "ip-1", "hostname", datetime(2020, 1, 1, 0, 0, 0)),
None,
None,
None,
False,
True,
),
(
StaticNode("queue1-st-c5xlarge-1", "ip-1", "hostname", "IDLE+CLOUD+POWER", "queue1"),
None,
0,
None,
None,
False,
False,
),
(
StaticNode("queue1-st-c5xlarge-1", "ip-1", "hostname", "IDLE+CLOUD+POWER", "queue1"),
None,
None,
{"queue1-st-c5xlarge-1": 1},
1,
False,
False,
),
(
StaticNode("queue1-st-c5xlarge-1", "ip-1", "hostname", "IDLE+CLOUD+POWER", "queue1"),
None,
2,
{"queue1-st-c5xlarge-1": 1},
2,
True,
True,
),
(
StaticNode("queue1-st-c5xlarge-1", "ip-1", "hostname", "IDLE+CLOUD+POWER", "queue1"),
None,
2,
{"queue1-st-c5xlarge-1": 2},
2,
False,
False,
),
(
StaticNode("queue1-st-c5xlarge-1", "ip-1", "hostname", "IDLE+CLOUD+POWER", "queue1"),
None,
2,
{"queue1-st-c5xlarge-1": 3},
3,
False,
False,
),
],
ids=[
"static_no_backing",
"dynamic_power_save",
"dynamic_no_backing",
"static_valid",
"static_no_backing_with_count_no_map",
"static_no_backing_no_count_with_map",
"static_no_backing_with_count_not_exceeded_with_map",
"static_no_backing_with_count_exceeded_with_map",
"static_no_backing_with_count_exceeded_with_map_2",
],
ids=["static_no_backing", "dynamic_power_save", "dynamic_no_backing", "static_valid"],
)
def test_slurm_node_is_backing_instance_valid(node, instance, expected_result):
def test_slurm_node_is_backing_instance_valid(
node, instance, max_count, count_map, final_count, count_incremented, expected_result
):
node.instance = instance
assert_that(node.is_backing_instance_valid()).is_equal_to(expected_result)
assert_that(
node.is_backing_instance_valid(
ec2_backing_instance_max_count=max_count, nodes_without_backing_instance_count_map=count_map
)
).is_equal_to(expected_result)
assert_that(node.missing_count_incremented).is_equal_to(count_incremented)
if count_map:
assert_that(count_map[node.name]).is_equal_to(final_count)


@pytest.mark.parametrize(
Expand Down
2 changes: 2 additions & 0 deletions tests/slurm_plugin/test_clustermgtd.py
Original file line number Diff line number Diff line change
Expand Up @@ -1352,6 +1352,7 @@ def test_maintain_nodes(
region="region",
boto3_config=None,
fleet_config={},
ec2_backing_instance_max_count=0,
)
cluster_manager = ClusterManager(mock_sync_config)
cluster_manager._static_nodes_in_replacement = static_nodes_in_replacement
Expand Down Expand Up @@ -3877,6 +3878,7 @@ def test_find_unhealthy_slurm_nodes(
boto3_config=None,
fleet_config={},
disable_capacity_blocks_management=disable_capacity_blocks_management,
ec2_backing_instance_max_count=0,
)
cluster_manager = ClusterManager(mock_sync_config)
get_reserved_mock = mocker.patch.object(
Expand Down

0 comments on commit ce0bb32

Please sign in to comment.