diff --git a/src/slurm_plugin/common.py b/src/slurm_plugin/common.py index accffba63..cc9151802 100644 --- a/src/slurm_plugin/common.py +++ b/src/slurm_plugin/common.py @@ -36,23 +36,17 @@ class ScalingStrategy(Enum): - ALL_OR_NOTHING = "all_or_nothing" - BEST_EFFORT = "best_effort" + ALL_OR_NOTHING = "all-or-nothing" + BEST_EFFORT = "best-effort" DEFAULT = BEST_EFFORT @classmethod - def _missing_(cls, value): - return cls.DEFAULT - - @classmethod - def from_str(cls, strategy): + def _missing_(cls, strategy): _strategy = str(strategy).lower() - if _strategy == "all_or_nothing": - return cls.ALL_OR_NOTHING - elif _strategy == "best_effort": - return cls.BEST_EFFORT - else: + for member in cls: + if member.value == _strategy: + return member return cls.DEFAULT def __str__(self): diff --git a/src/slurm_plugin/resume.py b/src/slurm_plugin/resume.py index 0aeb53d3b..13647cac0 100644 --- a/src/slurm_plugin/resume.py +++ b/src/slurm_plugin/resume.py @@ -46,7 +46,7 @@ class SlurmResumeConfig: "create_fleet_overrides": "/opt/slurm/etc/pcluster/create_fleet_overrides.json", "fleet_config_file": "/etc/parallelcluster/slurm_plugin/fleet-config.json", "job_level_scaling": True, - "scaling_strategy": "all_or_nothing", + "scaling_strategy": "all-or-nothing", } def __init__(self, config_file_path): diff --git a/tests/slurm_plugin/test_common.py b/tests/slurm_plugin/test_common.py index 5dd1ea854..a0cf72281 100644 --- a/tests/slurm_plugin/test_common.py +++ b/tests/slurm_plugin/test_common.py @@ -14,7 +14,7 @@ import pytest from assertpy import assert_that from common.utils import read_json, time_is_up -from slurm_plugin.common import TIMESTAMP_FORMAT, get_clustermgtd_heartbeat +from slurm_plugin.common import TIMESTAMP_FORMAT, get_clustermgtd_heartbeat, ScalingStrategy @pytest.mark.parametrize( @@ -106,3 +106,29 @@ def test_read_json(test_datadir, caplog, json_file, default_value, raises_except assert_that(caplog.text).matches(message_in_log) else: assert_that(caplog.text).does_not_match("exception") + + +@pytest.mark.parametrize( + "strategy_as_value, expected_strategy_enum", + [ + ( + "best-effort", + ScalingStrategy.BEST_EFFORT + ), + ( + "all-or-nothing", + ScalingStrategy.ALL_OR_NOTHING + ), + ( + "", + ScalingStrategy.DEFAULT + ), + ( + "invalid-strategy", + ScalingStrategy.DEFAULT + ), + ] +) +def test_scaling_strategies_enum_from_value(strategy_as_value, expected_strategy_enum): + strategy_enum = ScalingStrategy(strategy_as_value) + assert_that(strategy_enum).is_equal_to(expected_strategy_enum)