Support specification of multiple scaling strategies #566

Merged: 1 commit, Oct 5, 2023

18 changes: 18 additions & 0 deletions src/slurm_plugin/common.py
@@ -14,6 +14,7 @@
import logging
from concurrent.futures import Future
from datetime import datetime
from enum import Enum
from typing import Callable, Optional, Protocol, TypedDict

from common.utils import check_command_output, time_is_up, validate_absolute_path
@@ -34,6 +35,23 @@
)


class ScalingStrategy(Enum):
ALL_OR_NOTHING = "all-or-nothing"
BEST_EFFORT = "best-effort"

@classmethod
def _missing_(cls, strategy):
# Ref: https://docs.python.org/3/library/enum.html#enum.Enum._missing_
_strategy = str(strategy).lower()
for member in cls:
if member.value == _strategy:
return member
return cls.ALL_OR_NOTHING # Default to All-Or-Nothing

def __str__(self):
return str(self.value)


class TaskController(Protocol):
class TaskShutdownError(RuntimeError):
"""Exception raised if shutdown has been requested."""
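As a quick aside (not part of the diff), here is a minimal sketch of how the new enum resolves values, assuming only the behaviour shown above: exact values map directly to members, the _missing_ hook lower-cases the input before matching, and anything unrecognised falls back to ALL_OR_NOTHING, which is also what the tests added in test_common.py below verify.

from slurm_plugin.common import ScalingStrategy

# Exact value lookups resolve directly to the corresponding member
assert ScalingStrategy("best-effort") is ScalingStrategy.BEST_EFFORT

# _missing_ normalises the input to lower case before matching
assert ScalingStrategy("ALL-OR-NOTHING") is ScalingStrategy.ALL_OR_NOTHING

# Unknown or empty values fall back to the all-or-nothing default
assert ScalingStrategy("invalid-strategy") is ScalingStrategy.ALL_OR_NOTHING
assert ScalingStrategy("") is ScalingStrategy.ALL_OR_NOTHING

# __str__ returns the raw value, which is what the launch log line prints
assert str(ScalingStrategy.BEST_EFFORT) == "best-effort"
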
57 changes: 33 additions & 24 deletions src/slurm_plugin/instance_manager.py
@@ -25,7 +25,7 @@
from common.ec2_utils import get_private_ip_address_and_dns_name
from common.schedulers.slurm_commands import get_nodes_info, update_nodes
from common.utils import grouper, setup_logging_filter
from slurm_plugin.common import ComputeInstanceDescriptor, log_exception, print_with_count
from slurm_plugin.common import ComputeInstanceDescriptor, ScalingStrategy, log_exception, print_with_count
from slurm_plugin.fleet_manager import EC2Instance, FleetManagerFactory
from slurm_plugin.slurm_resources import (
EC2_HEALTH_STATUS_UNHEALTHY_STATES,
@@ -165,7 +165,7 @@ def add_instances(
node_list: List[str],
launch_batch_size: int,
update_node_address: bool = True,
all_or_nothing_batch: bool = False,
scaling_strategy: ScalingStrategy = ScalingStrategy.BEST_EFFORT,
slurm_resume: Dict[str, any] = None,
assign_node_batch_size: int = None,
terminate_batch_size: int = None,
@@ -531,7 +531,7 @@ def add_instances(
node_list: List[str],
launch_batch_size: int,
update_node_address: bool = True,
all_or_nothing_batch: bool = False,
scaling_strategy: ScalingStrategy = ScalingStrategy.BEST_EFFORT,
slurm_resume: Dict[str, any] = None,
assign_node_batch_size: int = None,
terminate_batch_size: int = None,
@@ -550,7 +550,7 @@ def add_instances(
launch_batch_size=launch_batch_size,
assign_node_batch_size=assign_node_batch_size,
update_node_address=update_node_address,
all_or_nothing_batch=all_or_nothing_batch,
scaling_strategy=scaling_strategy,
)
else:
logger.error(
@@ -563,7 +563,7 @@
launch_batch_size=launch_batch_size,
assign_node_batch_size=assign_node_batch_size,
update_node_address=update_node_address,
all_or_nothing_batch=all_or_nothing_batch,
scaling_strategy=scaling_strategy,
)

self._terminate_unassigned_launched_instances(terminate_batch_size)
@@ -574,7 +574,7 @@ def _scaling_for_jobs(
launch_batch_size: int,
assign_node_batch_size: int,
update_node_address: bool,
all_or_nothing_batch: bool,
scaling_strategy: ScalingStrategy,
) -> None:
"""Scaling for job list."""
# Setup custom logging filter
@@ -591,7 +591,7 @@
launch_batch_size=launch_batch_size,
assign_node_batch_size=assign_node_batch_size,
update_node_address=update_node_address,
all_or_nothing_batch=all_or_nothing_batch,
scaling_strategy=scaling_strategy,
)

def _terminate_unassigned_launched_instances(self, terminate_batch_size):
@@ -616,7 +616,7 @@ def _scaling_for_jobs_single_node(
launch_batch_size: int,
assign_node_batch_size: int,
update_node_address: bool,
all_or_nothing_batch: bool,
scaling_strategy: ScalingStrategy,
) -> None:
"""Scaling for job single node list."""
if job_list:
@@ -627,7 +627,7 @@
launch_batch_size=launch_batch_size,
assign_node_batch_size=assign_node_batch_size,
update_node_address=update_node_address,
all_or_nothing_batch=all_or_nothing_batch,
scaling_strategy=scaling_strategy,
)
else:
# Batch all single node jobs in a single best-effort EC2 launch request
@@ -639,7 +639,7 @@
launch_batch_size=launch_batch_size,
assign_node_batch_size=assign_node_batch_size,
update_node_address=update_node_address,
all_or_nothing_batch=False,
scaling_strategy=ScalingStrategy.BEST_EFFORT,
)

def _add_instances_for_resume_file(
@@ -649,7 +649,7 @@ def _add_instances_for_resume_file(
launch_batch_size: int,
assign_node_batch_size: int,
update_node_address: bool = True,
all_or_nothing_batch: bool = False,
scaling_strategy: ScalingStrategy = ScalingStrategy.BEST_EFFORT,
):
"""Launch requested EC2 instances for resume file."""
slurm_resume_data = self._get_slurm_resume_data(slurm_resume=slurm_resume, node_list=node_list)
@@ -663,7 +663,7 @@
launch_batch_size=launch_batch_size,
assign_node_batch_size=assign_node_batch_size,
update_node_address=update_node_address,
all_or_nothing_batch=all_or_nothing_batch,
scaling_strategy=scaling_strategy,
)

self._scaling_for_jobs_multi_node(
@@ -673,7 +673,7 @@
launch_batch_size=launch_batch_size,
assign_node_batch_size=assign_node_batch_size,
update_node_address=update_node_address,
all_or_nothing_batch=all_or_nothing_batch,
scaling_strategy=scaling_strategy,
)

def _scaling_for_jobs_multi_node(
@@ -683,15 +683,15 @@ def _scaling_for_jobs_multi_node(
launch_batch_size,
assign_node_batch_size,
update_node_address,
all_or_nothing_batch: bool,
scaling_strategy: ScalingStrategy,
):
# Optimize job level scaling with preliminary scale-all nodes attempt
self._update_dict(
self.unused_launched_instances,
self._launch_instances(
nodes_to_launch=self._parse_nodes_resume_list(node_list),
launch_batch_size=launch_batch_size,
all_or_nothing_batch=all_or_nothing_batch,
scaling_strategy=scaling_strategy,
),
)

@@ -700,7 +700,7 @@ def _scaling_for_jobs_multi_node(
launch_batch_size=launch_batch_size,
assign_node_batch_size=assign_node_batch_size,
update_node_address=update_node_address,
all_or_nothing_batch=all_or_nothing_batch,
scaling_strategy=scaling_strategy,
)

def _get_slurm_resume_data(self, slurm_resume: Dict[str, any], node_list: List[str]) -> SlurmResumeData:
@@ -835,7 +835,7 @@ def _add_instances_for_nodes(
launch_batch_size: int,
assign_node_batch_size: int,
update_node_address: bool = True,
all_or_nothing_batch: bool = True,
scaling_strategy: ScalingStrategy = ScalingStrategy.ALL_OR_NOTHING,
node_list: List[str] = None,
job: SlurmResumeJob = None,
):
@@ -858,7 +858,7 @@
job=job if job else None,
nodes_to_launch=nodes_resume_mapping,
launch_batch_size=launch_batch_size,
all_or_nothing_batch=all_or_nothing_batch,
scaling_strategy=scaling_strategy,
)
# instances launched, e.g.
# {
@@ -873,7 +873,7 @@
q_cr_instances_launched_length = len(instances_launched.get(queue, {}).get(compute_resource, []))
successful_launched_nodes += slurm_node_list[:q_cr_instances_launched_length]
failed_launch_nodes += slurm_node_list[q_cr_instances_launched_length:]
if all_or_nothing_batch:
if scaling_strategy == ScalingStrategy.ALL_OR_NOTHING:
self.all_or_nothing_node_assignment(
assign_node_batch_size=assign_node_batch_size,
instances_launched=instances_launched,
@@ -992,7 +992,7 @@ def _launch_instances(
self,
nodes_to_launch: Dict[str, any],
launch_batch_size: int,
all_or_nothing_batch: bool,
scaling_strategy: ScalingStrategy,
job: SlurmResumeJob = None,
):
instances_launched = defaultdict(lambda: defaultdict(list))
@@ -1008,9 +1008,13 @@
if slurm_node_list:
logger.info(
"Launching %s instances for nodes %s",
"all-or-nothing" if all_or_nothing_batch else "best-effort",
scaling_strategy,
print_with_count(slurm_node_list),
)
# At instance launch level, the various scaling strategies can be grouped based on the actual
# launch behaviour i.e. all-or-nothing or best-effort
all_or_nothing_batch = scaling_strategy in [ScalingStrategy.ALL_OR_NOTHING]

Contributor comment on lines +1014 to +1017:

I'm not sure I understand this comment.

all_or_nothing_batch is actually a boolean that is passed downstream to FleetManagerFactory as all_or_nothing, so maybe it would be better to rename this parameter accordingly, since it is not clear what, if anything, the _batch suffix adds.

Regarding the use of scaling_strategy in [ScalingStrategy.ALL_OR_NOTHING] instead of scaling_strategy == ScalingStrategy.ALL_OR_NOTHING, as done a few lines above, I assume this is because we are planning to add another strategy that falls in the "all or nothing" category.

Is this the meaning of the comment above?

Reply from @EddyMM (Contributor Author), Oct 4, 2023:

In all_or_nothing_batch the _batch alludes to the way we chunk the launch attempts in "batches" of a certain size/count of instances. So all_or_nothing_batch implies that each batch will be done in an all-or-nothing fashion.

as done a few lines above, I assume that this is because we are planning to add another strategy that falls in the "all or nothing" category.

Exactly, it was done with this in mind. That expression will be reused in the upcoming PR.

fleet_manager = self._get_fleet_manager(all_or_nothing_batch, compute_resource, queue)

for batch_nodes in grouper(slurm_node_list, launch_batch_size):
@@ -1203,7 +1207,8 @@ def add_instances(
node_list: List[str],
launch_batch_size: int,
update_node_address: bool = True,
all_or_nothing_batch: bool = False,
# Default to BEST_EFFORT since clustermgtd is not yet adapted for Job Level Scaling
scaling_strategy: ScalingStrategy = ScalingStrategy.BEST_EFFORT,
slurm_resume: Dict[str, any] = None,
assign_node_batch_size: int = None,
terminate_batch_size: int = None,
@@ -1217,17 +1222,21 @@
node_list=node_list,
launch_batch_size=launch_batch_size,
update_node_address=update_node_address,
all_or_nothing_batch=all_or_nothing_batch,
scaling_strategy=scaling_strategy,
)

def _add_instances_for_nodes(
self,
node_list: List[str],
launch_batch_size: int,
update_node_address: bool = True,
all_or_nothing_batch: bool = False,
scaling_strategy: ScalingStrategy = ScalingStrategy.BEST_EFFORT,
):
"""Launch requested EC2 instances for nodes."""
# At fleet management level, the scaling strategies can be grouped based on the actual
# launch behaviour i.e. all-or-nothing or best-effort
all_or_nothing_batch = scaling_strategy in [ScalingStrategy.ALL_OR_NOTHING]

Contributor comment on lines +1236 to +1239:

See comment above.

nodes_to_launch = self._parse_nodes_resume_list(node_list)
for queue, compute_resources in nodes_to_launch.items():
for compute_resource, slurm_node_list in compute_resources.items():
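Following up on the review thread above about scaling_strategy in [ScalingStrategy.ALL_OR_NOTHING]: the sketch below is purely illustrative and not part of this PR. It assumes a hypothetical future strategy member to show why a membership check only needs its list extended when a new strategy should reuse the atomic, all-or-nothing launch path.

from enum import Enum

class FutureScalingStrategy(Enum):
    # The first two members mirror the enum added in this PR; the third is
    # hypothetical and only illustrates the extension point.
    ALL_OR_NOTHING = "all-or-nothing"
    BEST_EFFORT = "best-effort"
    GREEDY_ALL_OR_NOTHING = "greedy-all-or-nothing"  # hypothetical

def is_all_or_nothing_launch(strategy: FutureScalingStrategy) -> bool:
    # Extending this list is the only change needed for a new strategy that
    # should behave atomically at the fleet-manager level.
    return strategy in [
        FutureScalingStrategy.ALL_OR_NOTHING,
        FutureScalingStrategy.GREEDY_ALL_OR_NOTHING,
    ]

assert is_all_or_nothing_launch(FutureScalingStrategy.GREEDY_ALL_OR_NOTHING)
assert not is_all_or_nothing_launch(FutureScalingStrategy.BEST_EFFORT)
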
9 changes: 6 additions & 3 deletions src/slurm_plugin/resume.py
@@ -21,7 +21,7 @@
from common.schedulers.slurm_commands import get_nodes_info, set_nodes_down
from common.utils import read_json
from slurm_plugin.cluster_event_publisher import ClusterEventPublisher
from slurm_plugin.common import is_clustermgtd_heartbeat_valid, print_with_count
from slurm_plugin.common import ScalingStrategy, is_clustermgtd_heartbeat_valid, print_with_count
from slurm_plugin.instance_manager import InstanceManagerFactory
from slurm_plugin.slurm_resources import CONFIG_FILE_DIR

@@ -45,8 +45,8 @@ class SlurmResumeConfig:
"run_instances_overrides": "/opt/slurm/etc/pcluster/run_instances_overrides.json",
"create_fleet_overrides": "/opt/slurm/etc/pcluster/create_fleet_overrides.json",
"fleet_config_file": "/etc/parallelcluster/slurm_plugin/fleet-config.json",
"all_or_nothing_batch": True,
"job_level_scaling": True,
"scaling_strategy": "all-or-nothing",
}

def __init__(self, config_file_path):
@@ -92,6 +92,9 @@ def _get_config(self, config_file_path):
self.all_or_nothing_batch = config.getboolean(
"slurm_resume", "all_or_nothing_batch", fallback=self.DEFAULTS.get("all_or_nothing_batch")
)
self.scaling_strategy = config.get(
"slurm_resume", "scaling_strategy", fallback=self.DEFAULTS.get("scaling_strategy")
) # TODO: Check if it's a valid scaling strategy before calling expensive downstream APIs
self.job_level_scaling = config.getboolean(
"slurm_resume", "job_level_scaling", fallback=self.DEFAULTS.get("job_level_scaling")
)
@@ -213,7 +216,7 @@ def _resume(arg_nodes, resume_config, slurm_resume):
assign_node_batch_size=resume_config.assign_node_max_batch_size,
terminate_batch_size=resume_config.terminate_max_batch_size,
update_node_address=resume_config.update_node_address,
all_or_nothing_batch=resume_config.all_or_nothing_batch,
scaling_strategy=ScalingStrategy(resume_config.scaling_strategy),
)
failed_nodes = set().union(*instance_manager.failed_nodes.values())
success_nodes = [node for node in node_list if node not in failed_nodes]
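For context on how the new option flows end to end, here is a minimal sketch (assuming the slurm_plugin package is importable) of the path from the INI-style slurm_resume config into the enum. The section, option name, and default come from the diff above; the sample file content is illustrative.

from configparser import ConfigParser

from slurm_plugin.common import ScalingStrategy

# Illustrative config content only; the real file is the slurm_resume config
# parsed by SlurmResumeConfig.
sample_config = """
[slurm_resume]
scaling_strategy = best-effort
job_level_scaling = true
"""

config = ConfigParser()
config.read_string(sample_config)

# Mirrors _get_config: read the raw string with the all-or-nothing default...
raw_strategy = config.get("slurm_resume", "scaling_strategy", fallback="all-or-nothing")
# ...then convert it where it is consumed (see _resume above); invalid values
# fall back to ALL_OR_NOTHING via ScalingStrategy._missing_.
strategy = ScalingStrategy(raw_strategy)
assert strategy is ScalingStrategy.BEST_EFFORT
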
16 changes: 15 additions & 1 deletion tests/slurm_plugin/test_common.py
@@ -14,7 +14,7 @@
import pytest
from assertpy import assert_that
from common.utils import read_json, time_is_up
from slurm_plugin.common import TIMESTAMP_FORMAT, get_clustermgtd_heartbeat
from slurm_plugin.common import TIMESTAMP_FORMAT, ScalingStrategy, get_clustermgtd_heartbeat


@pytest.mark.parametrize(
@@ -106,3 +106,17 @@ def test_read_json(test_datadir, caplog, json_file, default_value, raises_except
assert_that(caplog.text).matches(message_in_log)
else:
assert_that(caplog.text).does_not_match("exception")


@pytest.mark.parametrize(
"strategy_as_value, expected_strategy_enum",
[
("best-effort", ScalingStrategy.BEST_EFFORT),
("all-or-nothing", ScalingStrategy.ALL_OR_NOTHING),
("", ScalingStrategy.ALL_OR_NOTHING),
("invalid-strategy", ScalingStrategy.ALL_OR_NOTHING),
],
)
def test_scaling_strategies_enum_from_value(strategy_as_value, expected_strategy_enum):
strategy_enum = ScalingStrategy(strategy_as_value)
assert_that(strategy_enum).is_equal_to(expected_strategy_enum)