Skip to content

Commit

Permalink
Add a backing_instance_max_count config option for clustermgtd to b…
Browse files Browse the repository at this point in the history
…e robust to eventual EC2 consistency (#613)

Adding a config option to clustermgtd, ec2_backing_instance_max_count, to allow more time for describe-instances to reach eventual consistency with run-instances data
Passes the max count and map to is_healthy() and is_bootstrap_failure() for static and dynamic nodes to evaluate the count for individual instances.
  • Loading branch information
dreambeyondorange authored Feb 20, 2024
1 parent 853f48d commit 1b4ba77
Show file tree
Hide file tree
Showing 7 changed files with 275 additions and 37 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ This file is used to list changes made in each version of the aws-parallelcluste
------

**ENHANCEMENTS**
- Add a clustermgtd config option `ec2_instance_missing_max_count` to allow a configurable amount of retries for eventual EC2
describe instances consistency with run instances

**CHANGES**

Expand Down
10 changes: 8 additions & 2 deletions src/slurm_plugin/cluster_event_publisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,12 +317,18 @@ def detail_supplier(node_names):
# }
# }
@log_exception(logger, "publish_unhealthy_node_events", catch_exception=Exception, raise_on_error=False)
def publish_unhealthy_node_events(self, unhealthy_nodes: List[SlurmNode]):
def publish_unhealthy_node_events(
self, unhealthy_nodes: List[SlurmNode], ec2_instance_missing_max_count, nodes_without_backing_instance_count_map
):
"""Publish events for unhealthy nodes without a backing instance and for nodes that are not responding."""
timestamp = ClusterEventPublisher.timestamp()

nodes_with_invalid_backing_instance = [
node for node in unhealthy_nodes if not node.is_backing_instance_valid(log_warn_if_unhealthy=False)
node
for node in unhealthy_nodes
if not node.is_backing_instance_valid(
ec2_instance_missing_max_count, nodes_without_backing_instance_count_map, log_warn_if_unhealthy=False
)
]
self.publish_event(
logging.WARNING if nodes_with_invalid_backing_instance else logging.DEBUG,
Expand Down
27 changes: 22 additions & 5 deletions src/slurm_plugin/clustermgtd.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ class ClustermgtdConfig:
"terminate_drain_nodes": True,
"terminate_down_nodes": True,
"orphaned_instance_timeout": 300,
"ec2_instance_missing_max_count": 2,
# Health check configs
"disable_ec2_health_check": False,
"disable_scheduled_event_health_check": False,
Expand Down Expand Up @@ -304,6 +305,11 @@ def _get_terminate_config(self, config):
self.insufficient_capacity_timeout = config.getfloat(
"clustermgtd", "insufficient_capacity_timeout", fallback=self.DEFAULTS.get("insufficient_capacity_timeout")
)
self.ec2_instance_missing_max_count = config.getint(
"clustermgtd",
"ec2_instance_missing_max_count",
fallback=self.DEFAULTS.get("ec2_instance_missing_max_count"),
)
self.disable_nodes_on_insufficient_capacity = self.insufficient_capacity_timeout > 0

def _get_dns_config(self, config):
Expand Down Expand Up @@ -384,6 +390,7 @@ def __init__(self, config):
self._insufficient_capacity_compute_resources = {}
self._static_nodes_in_replacement = set()
self._partitions_protected_failure_count_map = {}
self._nodes_without_backing_instance_count_map = {}
self._compute_fleet_status = ComputeFleetStatus.RUNNING
self._current_time = None
self._config = None
Expand Down Expand Up @@ -492,7 +499,10 @@ def _handle_successfully_launched_nodes(self, partitions_name_map):
partitions_protected_failure_count_map = self._partitions_protected_failure_count_map.copy()
for partition, failures_per_compute_resource in partitions_protected_failure_count_map.items():
partition_online_compute_resources = partitions_name_map[partition].get_online_node_by_type(
self._config.terminate_drain_nodes, self._config.terminate_down_nodes
self._config.terminate_drain_nodes,
self._config.terminate_down_nodes,
self._config.ec2_instance_missing_max_count,
self._nodes_without_backing_instance_count_map,
)
for compute_resource in failures_per_compute_resource.keys():
if compute_resource in partition_online_compute_resources:
Expand Down Expand Up @@ -762,6 +772,8 @@ def _find_unhealthy_slurm_nodes(self, slurm_nodes):
if not node.is_healthy(
consider_drain_as_unhealthy=self._config.terminate_drain_nodes,
consider_down_as_unhealthy=self._config.terminate_down_nodes,
ec2_instance_missing_max_count=self._config.ec2_instance_missing_max_count,
nodes_without_backing_instance_count_map=self._nodes_without_backing_instance_count_map,
log_warn_if_unhealthy=node.name not in reserved_nodenames,
):
if not self._config.disable_capacity_blocks_management and node.name in reserved_nodenames:
Expand All @@ -778,7 +790,11 @@ def _find_unhealthy_slurm_nodes(self, slurm_nodes):
).append(node)
else:
unhealthy_dynamic_nodes.append(node)
self._event_publisher.publish_unhealthy_node_events(all_unhealthy_nodes)
self._event_publisher.publish_unhealthy_node_events(
all_unhealthy_nodes,
self._config.ec2_instance_missing_max_count,
self._nodes_without_backing_instance_count_map,
)
return (
unhealthy_dynamic_nodes,
unhealthy_static_nodes,
Expand Down Expand Up @@ -1167,11 +1183,12 @@ def _is_node_replacement_timeout(self, node):
"""Check if a static node is in replacement but replacement time is expired."""
return self._is_node_in_replacement_valid(node, check_node_is_valid=False)

@staticmethod
def _find_bootstrap_failure_nodes(slurm_nodes):
def _find_bootstrap_failure_nodes(self, slurm_nodes):
bootstrap_failure_nodes = []
for node in slurm_nodes:
if node.is_bootstrap_failure():
if node.is_bootstrap_failure(
self._config.ec2_instance_missing_max_count, self._nodes_without_backing_instance_count_map
):
bootstrap_failure_nodes.append(node)
return bootstrap_failure_nodes

Expand Down
110 changes: 95 additions & 15 deletions src/slurm_plugin/slurm_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,24 @@ def is_inactive(self):
def has_running_job(self):
return any(node.is_running_job() for node in self.slurm_nodes)

def get_online_node_by_type(self, terminate_drain_nodes, terminate_down_nodes):
def get_online_node_by_type(
self,
terminate_drain_nodes,
terminate_down_nodes,
ec2_instance_missing_max_count,
nodes_without_backing_instance_count_map,
):
online_compute_resources = set()
if not self.state == "INACTIVE":
for node in self.slurm_nodes:
if (
node.is_healthy(terminate_drain_nodes, terminate_down_nodes, log_warn_if_unhealthy=False)
node.is_healthy(
terminate_drain_nodes,
terminate_down_nodes,
ec2_instance_missing_max_count,
nodes_without_backing_instance_count_map,
log_warn_if_unhealthy=False,
)
and node.is_online()
):
logger.debug("Currently online node: %s, node state: %s", node.name, node.state_string)
Expand Down Expand Up @@ -233,6 +245,7 @@ def __init__(
self.is_failing_health_check = False
self.error_code = self._parse_error_code()
self.queue_name, self._node_type, self.compute_resource_name = parse_nodename(name)
self.ec2_backing_instance_valid = None

def is_nodeaddr_set(self):
"""Check if nodeaddr(private ip) for the node is set."""
Expand Down Expand Up @@ -378,7 +391,7 @@ def is_state_healthy(self, consider_drain_as_unhealthy, consider_down_as_unhealt
pass

@abstractmethod
def is_bootstrap_failure(self):
def is_bootstrap_failure(self, ec2_instance_missing_max_count, nodes_without_backing_instance_count_map: dict):
"""
Check if a slurm node has boostrap failure.
Expand All @@ -394,7 +407,14 @@ def is_bootstrap_timeout(self):
pass

@abstractmethod
def is_healthy(self, consider_drain_as_unhealthy, consider_down_as_unhealthy, log_warn_if_unhealthy=True):
def is_healthy(
self,
consider_drain_as_unhealthy,
consider_down_as_unhealthy,
ec2_instance_missing_max_count,
nodes_without_backing_instance_count_map: dict,
log_warn_if_unhealthy=True,
):
"""Check if a slurm node is considered healthy."""
pass

Expand All @@ -404,8 +424,18 @@ def is_powering_down_with_nodeaddr(self):
# for example because of a short SuspendTimeout
return self.is_nodeaddr_set() and (self.is_power() or self.is_powering_down())

def is_backing_instance_valid(self, log_warn_if_unhealthy=True):
def is_backing_instance_valid(
self,
ec2_instance_missing_max_count,
nodes_without_backing_instance_count_map: dict,
log_warn_if_unhealthy=True,
):
"""Check if a slurm node's addr is set, it points to a valid instance in EC2."""
# Perform this logic only once and return the result thereafter
if self.ec2_backing_instance_valid is not None:
return self.ec2_backing_instance_valid
# Set ec2_backing_instance_valid to True since it will be the result most often
self.ec2_backing_instance_valid = True
if self.is_nodeaddr_set():
if not self.instance:
if log_warn_if_unhealthy:
Expand All @@ -414,8 +444,30 @@ def is_backing_instance_valid(self, log_warn_if_unhealthy=True):
self,
self.state_string,
)
return False
return True
# Allow a few iterations for the eventual consistency of EC2 data
logger.debug(f"Map of slurm nodes without backing instances {nodes_without_backing_instance_count_map}")
missing_instance_loop_count = nodes_without_backing_instance_count_map.get(self.name, 0)
# If the loop count has been reached, the instance is unhealthy and will be terminated
if missing_instance_loop_count >= ec2_instance_missing_max_count:
if log_warn_if_unhealthy:
logger.warning(f"EC2 instance availability for node {self.name} has timed out.")
# Remove the slurm node from the map since a new instance will be launched
nodes_without_backing_instance_count_map.pop(self.name, None)
self.ec2_backing_instance_valid = False
else:
nodes_without_backing_instance_count_map[self.name] = missing_instance_loop_count + 1
if log_warn_if_unhealthy:
logger.warning(
f"Incrementing missing EC2 instance count for node {self.name} to "
f"{nodes_without_backing_instance_count_map[self.name]}."
)
else:
# Remove the slurm node from the map since the instance is healthy
nodes_without_backing_instance_count_map.pop(self.name, None)
else:
# Remove the slurm node from the map since the instance is healthy
nodes_without_backing_instance_count_map.pop(self.name, None)
return self.ec2_backing_instance_valid

@abstractmethod
def needs_reset_when_inactive(self):
Expand Down Expand Up @@ -478,11 +530,22 @@ def __init__(
reservation_name=reservation_name,
)

def is_healthy(self, consider_drain_as_unhealthy, consider_down_as_unhealthy, log_warn_if_unhealthy=True):
def is_healthy(
self,
consider_drain_as_unhealthy,
consider_down_as_unhealthy,
ec2_instance_missing_max_count,
nodes_without_backing_instance_count_map: dict,
log_warn_if_unhealthy=True,
):
"""Check if a slurm node is considered healthy."""
return (
self._is_static_node_ip_configuration_valid(log_warn_if_unhealthy=log_warn_if_unhealthy)
and self.is_backing_instance_valid(log_warn_if_unhealthy=log_warn_if_unhealthy)
and self.is_backing_instance_valid(
ec2_instance_missing_max_count=ec2_instance_missing_max_count,
nodes_without_backing_instance_count_map=nodes_without_backing_instance_count_map,
log_warn_if_unhealthy=log_warn_if_unhealthy,
)
and self.is_state_healthy(
consider_drain_as_unhealthy, consider_down_as_unhealthy, log_warn_if_unhealthy=log_warn_if_unhealthy
)
Expand Down Expand Up @@ -533,9 +596,13 @@ def _is_static_node_ip_configuration_valid(self, log_warn_if_unhealthy=True):
return False
return True

def is_bootstrap_failure(self):
def is_bootstrap_failure(self, ec2_instance_missing_max_count, nodes_without_backing_instance_count_map: dict):
"""Check if a slurm node has boostrap failure."""
if self.is_static_nodes_in_replacement and not self.is_backing_instance_valid(log_warn_if_unhealthy=False):
if self.is_static_nodes_in_replacement and not self.is_backing_instance_valid(
ec2_instance_missing_max_count=ec2_instance_missing_max_count,
nodes_without_backing_instance_count_map=nodes_without_backing_instance_count_map,
log_warn_if_unhealthy=False,
):
# Node is currently in replacement and no backing instance
logger.warning(
"Node bootstrap error: Node %s is currently in replacement and no backing instance, node state: %s",
Expand Down Expand Up @@ -618,17 +685,30 @@ def is_state_healthy(self, consider_drain_as_unhealthy, consider_down_as_unhealt
return False
return True

def is_healthy(self, consider_drain_as_unhealthy, consider_down_as_unhealthy, log_warn_if_unhealthy=True):
def is_healthy(
self,
consider_drain_as_unhealthy,
consider_down_as_unhealthy,
ec2_instance_missing_max_count,
nodes_without_backing_instance_count_map,
log_warn_if_unhealthy=True,
):
"""Check if a slurm node is considered healthy."""
return self.is_backing_instance_valid(log_warn_if_unhealthy=log_warn_if_unhealthy) and self.is_state_healthy(
return self.is_backing_instance_valid(
ec2_instance_missing_max_count=ec2_instance_missing_max_count,
nodes_without_backing_instance_count_map=nodes_without_backing_instance_count_map,
log_warn_if_unhealthy=log_warn_if_unhealthy,
) and self.is_state_healthy(
consider_drain_as_unhealthy, consider_down_as_unhealthy, log_warn_if_unhealthy=log_warn_if_unhealthy
)

def is_bootstrap_failure(self):
def is_bootstrap_failure(self, ec2_instance_missing_max_count, nodes_without_backing_instance_count_map: dict):
"""Check if a slurm node has boostrap failure."""
# no backing instance + [working state]# in node state
if (self.is_configuring_job() or self.is_powering_up_idle()) and not self.is_backing_instance_valid(
log_warn_if_unhealthy=False
ec2_instance_missing_max_count=ec2_instance_missing_max_count,
nodes_without_backing_instance_count_map=nodes_without_backing_instance_count_map,
log_warn_if_unhealthy=False,
):
logger.warning(
"Node bootstrap error: Node %s is in power up state without valid backing instance, node state: %s",
Expand Down
Loading

0 comments on commit 1b4ba77

Please sign in to comment.