Support specification of multiple scaling strategies (#566)
- Use a more extensible configuration parameter for the scaling strategy.
- Default to all-or-nothing in case of missing or invalid scaling strategy.

Signed-off-by: Eddy Mwiti <[email protected]>
EddyMM authored Oct 5, 2023
1 parent ea58da0 · commit 6acd753
Showing 6 changed files with 171 additions and 125 deletions.
18 changes: 18 additions & 0 deletions src/slurm_plugin/common.py
@@ -14,6 +14,7 @@
import logging
from concurrent.futures import Future
from datetime import datetime
from enum import Enum
from typing import Callable, Optional, Protocol, TypedDict

from common.utils import check_command_output, time_is_up, validate_absolute_path
@@ -34,6 +35,23 @@
)


class ScalingStrategy(Enum):
ALL_OR_NOTHING = "all-or-nothing"
BEST_EFFORT = "best-effort"

@classmethod
def _missing_(cls, strategy):
# Ref: https://docs.python.org/3/library/enum.html#enum.Enum._missing_
_strategy = str(strategy).lower()
for member in cls:
if member.value == _strategy:
return member
return cls.ALL_OR_NOTHING # Default to All-Or-Nothing

def __str__(self):
return str(self.value)


class TaskController(Protocol):
class TaskShutdownError(RuntimeError):
"""Exception raised if shutdown has been requested."""
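For context, a minimal usage sketch (illustration only, not part of the diff) of the ScalingStrategy enum introduced above; the fallback behaviour follows from the _missing_ hook:

    from slurm_plugin.common import ScalingStrategy

    # Known values map to their members, case-insensitively thanks to _missing_.
    assert ScalingStrategy("best-effort") is ScalingStrategy.BEST_EFFORT
    assert ScalingStrategy("ALL-OR-NOTHING") is ScalingStrategy.ALL_OR_NOTHING

    # Missing or invalid values fall back to the all-or-nothing default.
    assert ScalingStrategy("") is ScalingStrategy.ALL_OR_NOTHING
    assert ScalingStrategy("invalid-strategy") is ScalingStrategy.ALL_OR_NOTHING

    # __str__ returns the configuration value, which is what gets logged.
    assert str(ScalingStrategy.BEST_EFFORT) == "best-effort"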
57 changes: 33 additions & 24 deletions src/slurm_plugin/instance_manager.py
@@ -25,7 +25,7 @@
from common.ec2_utils import get_private_ip_address_and_dns_name
from common.schedulers.slurm_commands import get_nodes_info, update_nodes
from common.utils import grouper, setup_logging_filter
from slurm_plugin.common import ComputeInstanceDescriptor, log_exception, print_with_count
from slurm_plugin.common import ComputeInstanceDescriptor, ScalingStrategy, log_exception, print_with_count
from slurm_plugin.fleet_manager import EC2Instance, FleetManagerFactory
from slurm_plugin.slurm_resources import (
EC2_HEALTH_STATUS_UNHEALTHY_STATES,
@@ -165,7 +165,7 @@ def add_instances(
node_list: List[str],
launch_batch_size: int,
update_node_address: bool = True,
all_or_nothing_batch: bool = False,
scaling_strategy: ScalingStrategy = ScalingStrategy.BEST_EFFORT,
slurm_resume: Dict[str, any] = None,
assign_node_batch_size: int = None,
terminate_batch_size: int = None,
@@ -531,7 +531,7 @@ def add_instances(
node_list: List[str],
launch_batch_size: int,
update_node_address: bool = True,
all_or_nothing_batch: bool = False,
scaling_strategy: ScalingStrategy = ScalingStrategy.BEST_EFFORT,
slurm_resume: Dict[str, any] = None,
assign_node_batch_size: int = None,
terminate_batch_size: int = None,
@@ -550,7 +550,7 @@ def add_instances(
launch_batch_size=launch_batch_size,
assign_node_batch_size=assign_node_batch_size,
update_node_address=update_node_address,
all_or_nothing_batch=all_or_nothing_batch,
scaling_strategy=scaling_strategy,
)
else:
logger.error(
@@ -563,7 +563,7 @@ def add_instances(
launch_batch_size=launch_batch_size,
assign_node_batch_size=assign_node_batch_size,
update_node_address=update_node_address,
all_or_nothing_batch=all_or_nothing_batch,
scaling_strategy=scaling_strategy,
)

self._terminate_unassigned_launched_instances(terminate_batch_size)
@@ -574,7 +574,7 @@ def _scaling_for_jobs(
launch_batch_size: int,
assign_node_batch_size: int,
update_node_address: bool,
all_or_nothing_batch: bool,
scaling_strategy: ScalingStrategy,
) -> None:
"""Scaling for job list."""
# Setup custom logging filter
@@ -591,7 +591,7 @@ def _scaling_for_jobs(
launch_batch_size=launch_batch_size,
assign_node_batch_size=assign_node_batch_size,
update_node_address=update_node_address,
all_or_nothing_batch=all_or_nothing_batch,
scaling_strategy=scaling_strategy,
)

def _terminate_unassigned_launched_instances(self, terminate_batch_size):
@@ -616,7 +616,7 @@ def _scaling_for_jobs_single_node(
launch_batch_size: int,
assign_node_batch_size: int,
update_node_address: bool,
all_or_nothing_batch: bool,
scaling_strategy: ScalingStrategy,
) -> None:
"""Scaling for job single node list."""
if job_list:
@@ -627,7 +627,7 @@ def _scaling_for_jobs_single_node(
launch_batch_size=launch_batch_size,
assign_node_batch_size=assign_node_batch_size,
update_node_address=update_node_address,
all_or_nothing_batch=all_or_nothing_batch,
scaling_strategy=scaling_strategy,
)
else:
# Batch all single node jobs in a single best-effort EC2 launch request
@@ -639,7 +639,7 @@ def _scaling_for_jobs_single_node(
launch_batch_size=launch_batch_size,
assign_node_batch_size=assign_node_batch_size,
update_node_address=update_node_address,
all_or_nothing_batch=False,
scaling_strategy=ScalingStrategy.BEST_EFFORT,
)

def _add_instances_for_resume_file(
@@ -649,7 +649,7 @@ def _add_instances_for_resume_file(
launch_batch_size: int,
assign_node_batch_size: int,
update_node_address: bool = True,
all_or_nothing_batch: bool = False,
scaling_strategy: ScalingStrategy = ScalingStrategy.BEST_EFFORT,
):
"""Launch requested EC2 instances for resume file."""
slurm_resume_data = self._get_slurm_resume_data(slurm_resume=slurm_resume, node_list=node_list)
@@ -663,7 +663,7 @@ def _add_instances_for_resume_file(
launch_batch_size=launch_batch_size,
assign_node_batch_size=assign_node_batch_size,
update_node_address=update_node_address,
all_or_nothing_batch=all_or_nothing_batch,
scaling_strategy=scaling_strategy,
)

self._scaling_for_jobs_multi_node(
@@ -673,7 +673,7 @@ def _add_instances_for_resume_file(
launch_batch_size=launch_batch_size,
assign_node_batch_size=assign_node_batch_size,
update_node_address=update_node_address,
all_or_nothing_batch=all_or_nothing_batch,
scaling_strategy=scaling_strategy,
)

def _scaling_for_jobs_multi_node(
@@ -683,15 +683,15 @@ def _scaling_for_jobs_multi_node(
launch_batch_size,
assign_node_batch_size,
update_node_address,
all_or_nothing_batch: bool,
scaling_strategy: ScalingStrategy,
):
# Optimize job level scaling with preliminary scale-all nodes attempt
self._update_dict(
self.unused_launched_instances,
self._launch_instances(
nodes_to_launch=self._parse_nodes_resume_list(node_list),
launch_batch_size=launch_batch_size,
all_or_nothing_batch=all_or_nothing_batch,
scaling_strategy=scaling_strategy,
),
)

@@ -700,7 +700,7 @@ def _scaling_for_jobs_multi_node(
launch_batch_size=launch_batch_size,
assign_node_batch_size=assign_node_batch_size,
update_node_address=update_node_address,
all_or_nothing_batch=all_or_nothing_batch,
scaling_strategy=scaling_strategy,
)

def _get_slurm_resume_data(self, slurm_resume: Dict[str, any], node_list: List[str]) -> SlurmResumeData:
@@ -835,7 +835,7 @@ def _add_instances_for_nodes(
launch_batch_size: int,
assign_node_batch_size: int,
update_node_address: bool = True,
all_or_nothing_batch: bool = True,
scaling_strategy: ScalingStrategy = ScalingStrategy.ALL_OR_NOTHING,
node_list: List[str] = None,
job: SlurmResumeJob = None,
):
@@ -858,7 +858,7 @@ def _add_instances_for_nodes(
job=job if job else None,
nodes_to_launch=nodes_resume_mapping,
launch_batch_size=launch_batch_size,
all_or_nothing_batch=all_or_nothing_batch,
scaling_strategy=scaling_strategy,
)
# instances launched, e.g.
# {
@@ -873,7 +873,7 @@ def _add_instances_for_nodes(
q_cr_instances_launched_length = len(instances_launched.get(queue, {}).get(compute_resource, []))
successful_launched_nodes += slurm_node_list[:q_cr_instances_launched_length]
failed_launch_nodes += slurm_node_list[q_cr_instances_launched_length:]
if all_or_nothing_batch:
if scaling_strategy == ScalingStrategy.ALL_OR_NOTHING:
self.all_or_nothing_node_assignment(
assign_node_batch_size=assign_node_batch_size,
instances_launched=instances_launched,
@@ -992,7 +992,7 @@ def _launch_instances(
self,
nodes_to_launch: Dict[str, any],
launch_batch_size: int,
all_or_nothing_batch: bool,
scaling_strategy: ScalingStrategy,
job: SlurmResumeJob = None,
):
instances_launched = defaultdict(lambda: defaultdict(list))
@@ -1008,9 +1008,13 @@
if slurm_node_list:
logger.info(
"Launching %s instances for nodes %s",
"all-or-nothing" if all_or_nothing_batch else "best-effort",
scaling_strategy,
print_with_count(slurm_node_list),
)
# At instance launch level, the various scaling strategies can be grouped based on the actual
# launch behaviour i.e. all-or-nothing or best-effort
all_or_nothing_batch = scaling_strategy in [ScalingStrategy.ALL_OR_NOTHING]

fleet_manager = self._get_fleet_manager(all_or_nothing_batch, compute_resource, queue)

for batch_nodes in grouper(slurm_node_list, launch_batch_size):
@@ -1203,7 +1207,8 @@ def add_instances(
node_list: List[str],
launch_batch_size: int,
update_node_address: bool = True,
all_or_nothing_batch: bool = False,
# Default to BEST_EFFORT since clustermgtd is not yet adapted for Job Level Scaling
scaling_strategy: ScalingStrategy = ScalingStrategy.BEST_EFFORT,
slurm_resume: Dict[str, any] = None,
assign_node_batch_size: int = None,
terminate_batch_size: int = None,
@@ -1217,17 +1222,21 @@
node_list=node_list,
launch_batch_size=launch_batch_size,
update_node_address=update_node_address,
all_or_nothing_batch=all_or_nothing_batch,
scaling_strategy=scaling_strategy,
)

def _add_instances_for_nodes(
self,
node_list: List[str],
launch_batch_size: int,
update_node_address: bool = True,
all_or_nothing_batch: bool = False,
scaling_strategy: ScalingStrategy = ScalingStrategy.BEST_EFFORT,
):
"""Launch requested EC2 instances for nodes."""
# At fleet management level, the scaling strategies can be grouped based on the actual
# launch behaviour i.e. all-or-nothing or best-effort
all_or_nothing_batch = scaling_strategy in [ScalingStrategy.ALL_OR_NOTHING]

nodes_to_launch = self._parse_nodes_resume_list(node_list)
for queue, compute_resources in nodes_to_launch.items():
for compute_resource, slurm_node_list in compute_resources.items():
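As a side note, the pattern in _launch_instances and _add_instances_for_nodes above can be read as a small adapter: the strategy enum is collapsed to a boolean only at the point where the fleet manager needs it, so further strategies can be added without changing the launch call sites. A rough sketch under that reading (the helper name is illustrative, not from the diff):

    from slurm_plugin.common import ScalingStrategy

    def launches_all_or_nothing(scaling_strategy: ScalingStrategy) -> bool:
        # Hypothetical helper: group strategies by their EC2 launch behaviour.
        # Today only ALL_OR_NOTHING requires an all-or-nothing launch; a future
        # strategy can be appended to this list without touching callers.
        return scaling_strategy in [ScalingStrategy.ALL_OR_NOTHING]

    # e.g. fleet_manager = self._get_fleet_manager(
    #          launches_all_or_nothing(scaling_strategy), compute_resource, queue
    #      )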
9 changes: 6 additions & 3 deletions src/slurm_plugin/resume.py
@@ -21,7 +21,7 @@
from common.schedulers.slurm_commands import get_nodes_info, set_nodes_down
from common.utils import read_json
from slurm_plugin.cluster_event_publisher import ClusterEventPublisher
from slurm_plugin.common import is_clustermgtd_heartbeat_valid, print_with_count
from slurm_plugin.common import ScalingStrategy, is_clustermgtd_heartbeat_valid, print_with_count
from slurm_plugin.instance_manager import InstanceManagerFactory
from slurm_plugin.slurm_resources import CONFIG_FILE_DIR

@@ -45,8 +45,8 @@ class SlurmResumeConfig:
"run_instances_overrides": "/opt/slurm/etc/pcluster/run_instances_overrides.json",
"create_fleet_overrides": "/opt/slurm/etc/pcluster/create_fleet_overrides.json",
"fleet_config_file": "/etc/parallelcluster/slurm_plugin/fleet-config.json",
"all_or_nothing_batch": True,
"job_level_scaling": True,
"scaling_strategy": "all-or-nothing",
}

def __init__(self, config_file_path):
Expand Down Expand Up @@ -92,6 +92,9 @@ def _get_config(self, config_file_path):
self.all_or_nothing_batch = config.getboolean(
"slurm_resume", "all_or_nothing_batch", fallback=self.DEFAULTS.get("all_or_nothing_batch")
)
self.scaling_strategy = config.get(
"slurm_resume", "scaling_strategy", fallback=self.DEFAULTS.get("scaling_strategy")
) # TODO: Check if it's a valid scaling strategy before calling expensive downstream APIs
self.job_level_scaling = config.getboolean(
"slurm_resume", "job_level_scaling", fallback=self.DEFAULTS.get("job_level_scaling")
)
@@ -213,7 +216,7 @@ def _resume(arg_nodes, resume_config, slurm_resume):
assign_node_batch_size=resume_config.assign_node_max_batch_size,
terminate_batch_size=resume_config.terminate_max_batch_size,
update_node_address=resume_config.update_node_address,
all_or_nothing_batch=resume_config.all_or_nothing_batch,
scaling_strategy=ScalingStrategy(resume_config.scaling_strategy),
)
failed_nodes = set().union(*instance_manager.failed_nodes.values())
success_nodes = [node for node in node_list if node not in failed_nodes]
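To show how the resume-side pieces fit together, a hedged sketch of reading the new option and converting it to the enum, mirroring _get_config and _resume above (the config path and standalone parsing are assumptions for illustration only):

    import configparser

    from slurm_plugin.common import ScalingStrategy

    config = configparser.ConfigParser()
    # Hypothetical config path, for illustration only.
    config.read("/etc/parallelcluster/slurm_plugin/parallelcluster_slurm_resume.conf")

    # A missing option falls back to "all-or-nothing"; an invalid value is
    # normalized to ALL_OR_NOTHING by ScalingStrategy._missing_.
    raw_strategy = config.get("slurm_resume", "scaling_strategy", fallback="all-or-nothing")
    scaling_strategy = ScalingStrategy(raw_strategy)

    print(scaling_strategy)  # -> "best-effort" or "all-or-nothing"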
16 changes: 15 additions & 1 deletion tests/slurm_plugin/test_common.py
@@ -14,7 +14,7 @@
import pytest
from assertpy import assert_that
from common.utils import read_json, time_is_up
from slurm_plugin.common import TIMESTAMP_FORMAT, get_clustermgtd_heartbeat
from slurm_plugin.common import TIMESTAMP_FORMAT, ScalingStrategy, get_clustermgtd_heartbeat


@pytest.mark.parametrize(
@@ -106,3 +106,17 @@ def test_read_json(test_datadir, caplog, json_file, default_value, raises_except
assert_that(caplog.text).matches(message_in_log)
else:
assert_that(caplog.text).does_not_match("exception")


@pytest.mark.parametrize(
"strategy_as_value, expected_strategy_enum",
[
("best-effort", ScalingStrategy.BEST_EFFORT),
("all-or-nothing", ScalingStrategy.ALL_OR_NOTHING),
("", ScalingStrategy.ALL_OR_NOTHING),
("invalid-strategy", ScalingStrategy.ALL_OR_NOTHING),
],
)
def test_scaling_strategies_enum_from_value(strategy_as_value, expected_strategy_enum):
strategy_enum = ScalingStrategy(strategy_as_value)
assert_that(strategy_enum).is_equal_to(expected_strategy_enum)