diff --git a/nvflare/app_common/ccwf/cse_server_ctl.py b/nvflare/app_common/ccwf/cse_server_ctl.py index ab1751d1e8..d1c49feb34 100644 --- a/nvflare/app_common/ccwf/cse_server_ctl.py +++ b/nvflare/app_common/ccwf/cse_server_ctl.py @@ -67,11 +67,9 @@ def __init__( participating_clients=participating_clients, starting_client="", starting_client_policy=DefaultValuePolicy.EMPTY, - starting_client_allow_none=True, max_status_report_interval=max_status_report_interval, - result_clients=None, + result_clients="", result_clients_policy=DefaultValuePolicy.EMPTY, - result_clients_allow_none=True, progress_timeout=progress_timeout, ) diff --git a/nvflare/app_common/ccwf/cyclic_server_ctl.py b/nvflare/app_common/ccwf/cyclic_server_ctl.py index 10eb020769..0a7670fe1a 100644 --- a/nvflare/app_common/ccwf/cyclic_server_ctl.py +++ b/nvflare/app_common/ccwf/cyclic_server_ctl.py @@ -14,7 +14,7 @@ from nvflare.app_common.ccwf.common import Constant, RROrder from nvflare.app_common.ccwf.server_ctl import ServerSideController -from nvflare.fuel.utils.validation_utils import DefaultValuePolicy, check_str +from nvflare.fuel.utils.validation_utils import DefaultValuePolicy, check_str, normalize_config_arg class CyclicServerController(ServerSideController): @@ -33,6 +33,11 @@ def __init__( progress_timeout: float = Constant.WORKFLOW_PROGRESS_TIMEOUT, rr_order: str = RROrder.FIXED, ): + result_clients = normalize_config_arg(result_clients) + starting_client = normalize_config_arg(starting_client) + if starting_client is None: + raise ValueError("starting_client must be specified") + super().__init__( num_rounds=num_rounds, task_name_prefix=task_name_prefix, @@ -43,10 +48,8 @@ def __init__( participating_clients=participating_clients, result_clients=result_clients, result_clients_policy=DefaultValuePolicy.ALL, - result_clients_allow_none=True, starting_client=starting_client, starting_client_policy=DefaultValuePolicy.ANY, - starting_client_allow_none=False, max_status_report_interval=max_status_report_interval, progress_timeout=progress_timeout, ) diff --git a/nvflare/app_common/ccwf/server_ctl.py b/nvflare/app_common/ccwf/server_ctl.py index 8274b007bd..f7f10c1f9d 100644 --- a/nvflare/app_common/ccwf/server_ctl.py +++ b/nvflare/app_common/ccwf/server_ctl.py @@ -35,6 +35,7 @@ check_positive_int, check_positive_number, check_str, + normalize_config_arg, validate_candidate, validate_candidates, ) @@ -63,11 +64,9 @@ def __init__( job_status_check_interval: float = Constant.JOB_STATUS_CHECK_INTERVAL, starting_client=None, starting_client_policy: str = DefaultValuePolicy.ANY, - starting_client_allow_none=False, participating_clients=None, result_clients=None, result_clients_policy: str = DefaultValuePolicy.ALL, - result_clients_allow_none=True, max_status_report_interval: float = Constant.PER_CLIENT_STATUS_REPORT_TIMEOUT, progress_timeout: float = Constant.WORKFLOW_PROGRESS_TIMEOUT, ): @@ -90,14 +89,9 @@ def __init__( progress_timeout: """ Controller.__init__(self, task_check_period) - if not participating_clients: - participating_clients = [] - - if not result_clients: - result_clients = [] - - if not starting_client: - starting_client = "" + participating_clients = normalize_config_arg(participating_clients) + if participating_clients is None: + raise ValueError("participating_clients must not be empty") self.task_name_prefix = task_name_prefix self.configure_task_name = make_task_name(task_name_prefix, Constant.BASENAME_CONFIG) @@ -112,11 +106,9 @@ def __init__( self.job_status_check_interval = job_status_check_interval self.starting_client = starting_client self.starting_client_policy = starting_client_policy - self.starting_client_allow_none = starting_client_allow_none self.participating_clients = participating_clients self.result_clients = result_clients self.result_clients_policy = result_clients_policy - self.result_clients_allow_none = result_clients_allow_none self.client_statuses = {} # client name => ClientStatus self.cw_started = False self.asked_to_stop = False @@ -141,8 +133,8 @@ def start_controller(self, fl_ctx: FLContext): self.workflow_id = wf_id all_clients = self._engine.get_clients() - if len(all_clients) <= 1: - raise RuntimeError("Not enough client sites.") + if len(all_clients) < 2: + raise RuntimeError(f"this workflow requires at least 2 clients, but only got {all_clients}") all_client_names = [t.name for t in all_clients] self.participating_clients = validate_candidates( @@ -158,7 +150,7 @@ def start_controller(self, fl_ctx: FLContext): candidate=self.starting_client, base=self.participating_clients, default_policy=self.starting_client_policy, - allow_none=self.starting_client_allow_none, + allow_none=True, ) self.result_clients = validate_candidates( @@ -166,11 +158,9 @@ def start_controller(self, fl_ctx: FLContext): candidates=self.result_clients, base=self.participating_clients, default_policy=self.result_clients_policy, - allow_none=self.result_clients_allow_none, + allow_none=True, ) - self.log_info(fl_ctx, f"result clients: {self.result_clients}") - for c in self.participating_clients: self.client_statuses[c] = ClientStatus() @@ -201,6 +191,8 @@ def control_flow(self, abort_signal: Signal, fl_ctx: FLContext): if extra_config: learn_config.update(extra_config) + self.log_info(fl_ctx, f"Workflow Config: {learn_config}") + # configure all clients shareable = Shareable() shareable[Constant.CONFIG] = learn_config diff --git a/nvflare/app_common/ccwf/swarm_server_ctl.py b/nvflare/app_common/ccwf/swarm_server_ctl.py index 980b67645b..60fdf2e3c5 100644 --- a/nvflare/app_common/ccwf/swarm_server_ctl.py +++ b/nvflare/app_common/ccwf/swarm_server_ctl.py @@ -15,7 +15,7 @@ from nvflare.apis.fl_context import FLContext from nvflare.app_common.ccwf.common import Constant from nvflare.app_common.ccwf.server_ctl import ServerSideController -from nvflare.fuel.utils.validation_utils import DefaultValuePolicy, validate_candidates +from nvflare.fuel.utils.validation_utils import DefaultValuePolicy, normalize_config_arg, validate_candidates class SwarmServerController(ServerSideController): @@ -36,8 +36,10 @@ def __init__( aggr_clients=None, train_clients=None, ): - if not result_clients: - result_clients = [] + result_clients = normalize_config_arg(result_clients) + starting_client = normalize_config_arg(starting_client) + if starting_client is None: + raise ValueError("starting_client must be specified") super().__init__( num_rounds=num_rounds, @@ -50,10 +52,8 @@ def __init__( participating_clients=participating_clients, result_clients=result_clients, result_clients_policy=DefaultValuePolicy.ALL, - result_clients_allow_none=True, starting_client=starting_client, starting_client_policy=DefaultValuePolicy.ANY, - starting_client_allow_none=False, max_status_report_interval=max_status_report_interval, progress_timeout=progress_timeout, ) diff --git a/nvflare/fuel/utils/validation_utils.py b/nvflare/fuel/utils/validation_utils.py index 868ba94bef..626787ce04 100644 --- a/nvflare/fuel/utils/validation_utils.py +++ b/nvflare/fuel/utils/validation_utils.py @@ -263,3 +263,14 @@ def validate_candidate(var_name: str, candidate, base: list, default_policy: str raise ValueError(f"invalid value '{candidate}' in '{var_name}': it must be one of {base}") else: return c + + +def normalize_config_arg(value): + if value is False: + return None # specified to be "empty" + if isinstance(value, str): + if value.strip().lower() == SYMBOL_NONE: + return None + if not value: + return "" # meaning to take default + return value diff --git a/tests/unit_test/fuel/utils/validation_utils_test.py b/tests/unit_test/fuel/utils/validation_utils_test.py index 83548ff812..f4d7195621 100644 --- a/tests/unit_test/fuel/utils/validation_utils_test.py +++ b/tests/unit_test/fuel/utils/validation_utils_test.py @@ -20,12 +20,32 @@ check_number_range, check_positive_int, check_positive_number, + normalize_config_arg, validate_candidate, validate_candidates, ) class TestValidationUtils: + @pytest.mark.parametrize( + "value, result", + [ + ("x", "x"), + (123, 123), + ("", ""), + (False, None), + ("@None", None), + (None, ""), + (0, ""), + ([], ""), + ({}, ""), + ((), ""), + ([1, 2, 3], [1, 2, 3]), + ], + ) + def test_normalize_config_arg(self, value, result): + assert normalize_config_arg(value) == result + @pytest.mark.parametrize( "name, num, min_value, max_value", [