From f65908b83527a6f15d5c502d3f63d41efee8cc67 Mon Sep 17 00:00:00 2001 From: Yan Cheng Date: Sat, 2 Sep 2023 15:14:29 -0400 Subject: [PATCH] address review issues --- .../app_common/abstract/metric_comparator.py | 32 +++++++ nvflare/app_common/ccwf/client_ctl.py | 73 +++++++++------- nvflare/app_common/ccwf/common.py | 15 ++++ nvflare/app_common/ccwf/cse_client_ctl.py | 56 ++++-------- nvflare/app_common/ccwf/cse_server_ctl.py | 8 +- nvflare/app_common/ccwf/cyclic_client_ctl.py | 5 ++ nvflare/app_common/ccwf/cyclic_server_ctl.py | 3 - nvflare/app_common/ccwf/server_ctl.py | 11 ++- nvflare/app_common/ccwf/swarm_client_ctl.py | 87 ++++++++++++++----- nvflare/app_common/ccwf/swarm_server_ctl.py | 6 ++ .../fuel/utils/validation_utils_test.py | 87 +++++++++++++++++++ 11 files changed, 282 insertions(+), 101 deletions(-) create mode 100644 nvflare/app_common/abstract/metric_comparator.py create mode 100644 tests/unit_test/fuel/utils/validation_utils_test.py diff --git a/nvflare/app_common/abstract/metric_comparator.py b/nvflare/app_common/abstract/metric_comparator.py new file mode 100644 index 0000000000..231ffecb08 --- /dev/null +++ b/nvflare/app_common/abstract/metric_comparator.py @@ -0,0 +1,32 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import abstractmethod +from typing import Union + + +class MetricComparator: + @abstractmethod + def compare(self, a, b) -> Union[int, float]: + """Compare two metric values. + Metric values do not have to be numbers. + + Args: + a: first metric value + b: second metric value + + Returns: negative number if a < b; 0 if a == b; positive number if a > b. + + """ + pass diff --git a/nvflare/app_common/ccwf/client_ctl.py b/nvflare/app_common/ccwf/client_ctl.py index 4298b41cfc..bd74d9ff2a 100644 --- a/nvflare/app_common/ccwf/client_ctl.py +++ b/nvflare/app_common/ccwf/client_ctl.py @@ -29,7 +29,7 @@ from nvflare.app_common.app_constant import AppConstants from nvflare.app_common.app_event_type import AppEventType from nvflare.app_common.ccwf.common import Constant, ResultType, StatusReport, make_task_name, topic_for_end_workflow -from nvflare.fuel.utils.validation_utils import check_positive_number, check_str +from nvflare.fuel.utils.validation_utils import check_non_empty_str, check_positive_number from nvflare.security.logging import secure_format_traceback @@ -56,23 +56,21 @@ def __init__( allow_busy_task: bool = False, ): """ - Constructor of a CWE object. + Constructor of a ClientSideController object. Args: - task_name_prefix: prefix of task names + task_name_prefix: prefix of task names. All CCWF task names are prefixed with this. 
learn_task_name: name for the Learning Task (LT) + persistor_id: ID of the persistor component + shareable_generator_id: ID of the shareable generator component max_status_report_interval: max interval between status reports to the server learn_task_check_interval: interval for checking incoming Learning Task (LT) learn_task_send_timeout: timeout for sending the LT to other client(s) final_result_send_timeout: timeout for sending final result to participating clients learn_task_abort_timeout: time to wait for the LT to become stopped after aborting it - allow_busy_task: + allow_busy_task: whether a new learn task is allowed when working on current learn task """ - check_str("task_name_prefix", task_name_prefix) - check_str("learn_task_name", learn_task_name) - check_str("persistor_id", persistor_id) - check_str("shareable_generator_id", shareable_generator_id) - + check_non_empty_str("task_name_prefix", task_name_prefix) check_positive_number("max_status_report_interval", max_status_report_interval) check_positive_number("learn_task_check_interval", learn_task_check_interval) check_positive_number("learn_task_send_timeout", learn_task_send_timeout) @@ -138,34 +136,34 @@ def get_config_prop(self, name: str, default=None): return default return self.config.get(name, default) - def handle_event(self, event_type: str, fl_ctx: FLContext): - if event_type == EventType.START_RUN: - self.engine = fl_ctx.get_engine() - if not self.engine: - self.system_panic("no engine", fl_ctx) - return + def start_run(self, fl_ctx: FLContext): + self.engine = fl_ctx.get_engine() + if not self.engine: + self.system_panic("no engine", fl_ctx) + return - runner = fl_ctx.get_prop(FLContextKey.RUNNER) - if not runner: - self.system_panic("no client runner", fl_ctx) - return + runner = fl_ctx.get_prop(FLContextKey.RUNNER) + if not runner: + self.system_panic("no client runner", fl_ctx) + return - self.me = fl_ctx.get_identity_name() + self.me = fl_ctx.get_identity_name() + if self.learn_task_name: self.learn_executor = runner.find_executor(self.learn_task_name) if not self.learn_executor: self.system_panic(f"no executor for task {self.learn_task_name}", fl_ctx) return - engine = fl_ctx.get_engine() - self.persistor = engine.get_component(self.persistor_id) - if not isinstance(self.persistor, LearnablePersistor): - self.system_panic( - f"Persistor {self.persistor_id} must be a Persistor instance, but got {type(self.persistor)}", - fl_ctx, - ) - return + self.persistor = self.engine.get_component(self.persistor_id) + if not isinstance(self.persistor, LearnablePersistor): + self.system_panic( + f"Persistor {self.persistor_id} must be a Persistor instance, but got {type(self.persistor)}", + fl_ctx, + ) + return - self.shareable_generator = engine.get_component(self.shareable_generator_id) + if self.shareable_generator_id: + self.shareable_generator = self.engine.get_component(self.shareable_generator_id) if not isinstance(self.shareable_generator, ShareableGenerator): self.system_panic( f"Shareable generator {self.shareable_generator_id} must be a Shareable Generator instance, " @@ -174,9 +172,16 @@ def handle_event(self, event_type: str, fl_ctx: FLContext): ) return - self.initialize(fl_ctx) + self.initialize(fl_ctx) + + if self.learn_task_name: self.log_info(fl_ctx, "Started learn thread") self.learn_thread.start() + + def handle_event(self, event_type: str, fl_ctx: FLContext): + if event_type == EventType.START_RUN: + self.start_run(fl_ctx) + elif event_type == EventType.BEFORE_PULL_TASK: # add my status to fl_ctx if 
not self.workflow_id: @@ -320,8 +325,11 @@ def _try_broadcast_final_result( for t in targets: reply = resp.get(t) if not isinstance(reply, Shareable): - self.log_error(fl_ctx, f"failed to send {result_type} result to client {t}") - self.log_error(fl_ctx, f"reply must be Shareable but got {type(reply)}") + self.log_error( + fl_ctx, + f"bad response for {result_type} result from client {t}: " + f"reply must be Shareable but got {type(reply)}", + ) num_errors += 1 continue @@ -329,6 +337,7 @@ def _try_broadcast_final_result( if rc != ReturnCode.OK: self.log_error(fl_ctx, f"bad response for {result_type} result from client {t}: {rc}") num_errors += 1 + if num_errors == 0: self.log_info(fl_ctx, f"successfully broadcast {result_type} result to {targets}") return num_errors diff --git a/nvflare/app_common/ccwf/common.py b/nvflare/app_common/ccwf/common.py index 85815c6b07..5d884c5800 100644 --- a/nvflare/app_common/ccwf/common.py +++ b/nvflare/app_common/ccwf/common.py @@ -12,6 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Union + +from nvflare.app_common.abstract.metric_comparator import MetricComparator + class Constant: @@ -183,3 +187,14 @@ def topic_for_end_workflow(wf_id): def make_task_name(prefix: str, base_name: str) -> str: return f"{prefix}_{base_name}" + + +class NumberMetricComparator(MetricComparator): + def compare(self, a, b) -> Union[int, float]: + if not isinstance(a, (int, float)): + raise ValueError(f"metric value must be a number but got {type(a)}") + + if not isinstance(b, (int, float)): + raise ValueError(f"metric value must be a number but got {type(b)}") + + return a - b diff --git a/nvflare/app_common/ccwf/cse_client_ctl.py b/nvflare/app_common/ccwf/cse_client_ctl.py index 6afc415df9..186d417a04 100644 --- a/nvflare/app_common/ccwf/cse_client_ctl.py +++ b/nvflare/app_common/ccwf/cse_client_ctl.py @@ -14,7 +14,6 @@ import threading from nvflare.apis.controller_spec import Task -from nvflare.apis.event_type import EventType from nvflare.apis.fl_constant import FLContextKey, ReturnCode from nvflare.apis.fl_context import FLContext from nvflare.apis.shareable import Shareable, make_reply @@ -24,7 +23,7 @@ from nvflare.app_common.app_constant import AppConstants, ValidateType from nvflare.app_common.ccwf.client_ctl import ClientSideController from nvflare.app_common.ccwf.common import Constant, ModelType, make_task_name -from nvflare.fuel.utils.validation_utils import check_positive_number, check_str +from nvflare.fuel.utils.validation_utils import check_non_empty_str, check_positive_number from nvflare.security.logging import secure_format_traceback @@ -39,11 +38,14 @@ def __init__( get_model_timeout=Constant.GET_MODEL_TIMEOUT, ): check_positive_number("get_model_timeout", get_model_timeout) - check_str("submit_model_task_name", submit_model_task_name) - check_str("validation_task_name", validation_task_name) + check_non_empty_str("submit_model_task_name", submit_model_task_name) + check_non_empty_str("validation_task_name", validation_task_name) + check_non_empty_str("persistor_id", persistor_id) super().__init__( task_name_prefix=task_name_prefix, + learn_task_name="", + shareable_generator_id="", persistor_id=persistor_id, max_status_report_interval=max_status_report_interval, ) @@ -60,45 +62,21 @@ def __init__( self.local_model = None self.model_lock = threading.Lock() - def handle_event(self, event_type: str, fl_ctx: FLContext): - if event_type == EventType.START_RUN: - self.engine = 
fl_ctx.get_engine() - if not self.engine: - self.system_panic("no engine", fl_ctx) - return - - runner = fl_ctx.get_prop(FLContextKey.RUNNER) - if not runner: - self.system_panic("no client runner", fl_ctx) + def start_run(self, fl_ctx: FLContext): + super().start_run(fl_ctx) + runner = fl_ctx.get_prop(FLContextKey.RUNNER) + if self.submit_model_task_name: + self.submit_model_executor = runner.find_executor(self.submit_model_task_name) + if not self.submit_model_executor: + self.system_panic(f"no executor for task {self.submit_model_task_name}", fl_ctx) return - self.me = fl_ctx.get_identity_name() - - if self.submit_model_task_name: - self.submit_model_executor = runner.find_executor(self.submit_model_task_name) - if not self.submit_model_executor: - self.system_panic(f"no executor for task {self.submit_model_task_name}", fl_ctx) - return - - if self.validation_task_name: - self.validate_executor = runner.find_executor(self.validation_task_name) - if not self.validate_executor: - self.system_panic(f"no executor for task {self.validation_task_name}", fl_ctx) - return - - engine = fl_ctx.get_engine() - self.persistor = engine.get_component(self.persistor_id) - if not isinstance(self.persistor, ModelPersistor): - self.system_panic( - f"Persistor {self.persistor_id} must be a ModelPersistor instance, but got {type(self.persistor)}", - fl_ctx, - ) + if self.validation_task_name: + self.validate_executor = runner.find_executor(self.validation_task_name) + if not self.validate_executor: + self.system_panic(f"no executor for task {self.validation_task_name}", fl_ctx) return - self.initialize(fl_ctx) - else: - super().handle_event(event_type, fl_ctx) - def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable: if task_name == self.start_task_name: self.is_starting_client = True diff --git a/nvflare/app_common/ccwf/cse_server_ctl.py b/nvflare/app_common/ccwf/cse_server_ctl.py index cfef37ebab..e46711e00e 100644 --- a/nvflare/app_common/ccwf/cse_server_ctl.py +++ b/nvflare/app_common/ccwf/cse_server_ctl.py @@ -144,12 +144,12 @@ def prepare_config(self): } def process_config_reply(self, client_name: str, reply: Shareable, fl_ctx: FLContext) -> bool: - global_models = reply.get(Constant.GLOBAL_NAMES) - if global_models: - for m in global_models: + global_names = reply.get(Constant.GLOBAL_NAMES) + if global_names: + for m in global_names: if m not in self.global_names: self.global_names[m] = client_name - self.log_info(fl_ctx, f"got global model {m} from {client_name}") + self.log_info(fl_ctx, f"got global model name {m} from {client_name}") return True def _ask_to_evaluate( diff --git a/nvflare/app_common/ccwf/cyclic_client_ctl.py b/nvflare/app_common/ccwf/cyclic_client_ctl.py index db9822016b..404892f1da 100644 --- a/nvflare/app_common/ccwf/cyclic_client_ctl.py +++ b/nvflare/app_common/ccwf/cyclic_client_ctl.py @@ -20,6 +20,7 @@ from nvflare.app_common.app_constant import AppConstants from nvflare.app_common.ccwf.client_ctl import ClientSideController from nvflare.app_common.ccwf.common import Constant, ResultType, RROrder, rotate_to_front +from nvflare.fuel.utils.validation_utils import check_non_empty_str class CyclicClientController(ClientSideController): @@ -35,6 +36,10 @@ def __init__( learn_task_send_timeout=Constant.LEARN_TASK_SEND_TIMEOUT, final_result_send_timeout=Constant.FINAL_RESULT_SEND_TIMEOUT, ): + check_non_empty_str("learn_task_name", learn_task_name) + check_non_empty_str("persistor_id", persistor_id) + 
check_non_empty_str("shareable_generator_id", shareable_generator_id) + super().__init__( task_name_prefix=task_name_prefix, learn_task_name=learn_task_name, diff --git a/nvflare/app_common/ccwf/cyclic_server_ctl.py b/nvflare/app_common/ccwf/cyclic_server_ctl.py index 92adac1f6f..ab588a0bc7 100644 --- a/nvflare/app_common/ccwf/cyclic_server_ctl.py +++ b/nvflare/app_common/ccwf/cyclic_server_ctl.py @@ -33,9 +33,6 @@ def __init__( progress_timeout: float = Constant.WORKFLOW_PROGRESS_TIMEOUT, rr_order: str = RROrder.FIXED, ): - if not result_clients: - result_clients = [] - super().__init__( num_rounds=num_rounds, task_name_prefix=task_name_prefix, diff --git a/nvflare/app_common/ccwf/server_ctl.py b/nvflare/app_common/ccwf/server_ctl.py index 5e4bbfe85e..bcf4c53746 100644 --- a/nvflare/app_common/ccwf/server_ctl.py +++ b/nvflare/app_common/ccwf/server_ctl.py @@ -60,7 +60,7 @@ def __init__( start_task_timeout=Constant.START_TASK_TIMEOUT, task_check_period: float = Constant.TASK_CHECK_INTERVAL, job_status_check_interval: float = Constant.JOB_STATUS_CHECK_INTERVAL, - starting_client: str = None, + starting_client=None, starting_client_policy: str = DefaultPolicy.ANY, starting_client_allow_none=False, participating_clients=None, @@ -89,6 +89,15 @@ def __init__( progress_timeout: """ Controller.__init__(self, task_check_period) + if not participating_clients: + participating_clients = [] + + if not result_clients: + result_clients = [] + + if not starting_client: + starting_client = "" + self.task_name_prefix = task_name_prefix self.configure_task_name = make_task_name(task_name_prefix, Constant.BASENAME_CONFIG) self.configure_task_timeout = configure_task_timeout diff --git a/nvflare/app_common/ccwf/swarm_client_ctl.py b/nvflare/app_common/ccwf/swarm_client_ctl.py index 8e0d698737..7dccb96ec3 100644 --- a/nvflare/app_common/ccwf/swarm_client_ctl.py +++ b/nvflare/app_common/ccwf/swarm_client_ctl.py @@ -17,7 +17,6 @@ import time from nvflare.apis.controller_spec import Task -from nvflare.apis.event_type import EventType from nvflare.apis.fl_component import FLComponent from nvflare.apis.fl_constant import FLContextKey, ReturnCode from nvflare.apis.fl_context import FLContext @@ -25,11 +24,12 @@ from nvflare.apis.signal import Signal from nvflare.app_common.abstract.aggregator import Aggregator from nvflare.app_common.abstract.learnable import Learnable +from nvflare.app_common.abstract.metric_comparator import MetricComparator from nvflare.app_common.app_constant import AppConstants from nvflare.app_common.app_event_type import AppEventType from nvflare.app_common.ccwf.client_ctl import ClientSideController -from nvflare.app_common.ccwf.common import Constant, ResultType, make_task_name -from nvflare.fuel.utils.validation_utils import check_positive_int, check_positive_number, check_str +from nvflare.app_common.ccwf.common import Constant, NumberMetricComparator, ResultType, make_task_name +from nvflare.fuel.utils.validation_utils import check_non_empty_str, check_positive_int, check_positive_number from nvflare.security.logging import secure_format_traceback @@ -47,6 +47,7 @@ def __init__( for_round: int, executor: ClientSideController, aggregator: Aggregator, + metric_comparator: MetricComparator, all_clients: list, trainers: list, min_responses_required: int, @@ -57,6 +58,7 @@ def __init__( self.fl_ctx = fl_ctx self.executor = executor self.aggregator = aggregator + self.metric_comparator = metric_comparator self.all_clients = all_clients self.trainers = trainers self.for_round = for_round @@ 
-75,12 +77,15 @@ def __init__( self.current_best_client = task_data.get_header(Constant.CLIENT) self.current_best_global_metric = task_data.get_header(Constant.METRIC) self.current_best_round = task_data.get_header(Constant.ROUND) - self.log_info( - fl_ctx, - f"gatherer starting with best client {self.current_best_client} " - f"with metric {self.current_best_global_metric} " - f"at round {self.current_best_round}", - ) + if not self.current_best_client: + self.log_info(fl_ctx, "gatherer starting from scratch") + else: + self.log_info( + fl_ctx, + f"gatherer starting with previous best result from client {self.current_best_client} " + f"with metric {self.current_best_global_metric} " + f"at round {self.current_best_round}", + ) def gather(self, client_name: str, result: Shareable, fl_ctx: FLContext) -> Shareable: with self.lock: @@ -162,7 +167,10 @@ def aggregate(self): mine_is_better = False if self.current_best_global_metric is not None: - if self.executor.best_metric is not None and self.executor.best_metric > self.current_best_global_metric: + if ( + self.executor.best_metric is not None + and self.metric_comparator.compare(self.executor.best_metric, self.current_best_global_metric) > 0 + ): mine_is_better = True elif self.executor.best_metric is not None: mine_is_better = True @@ -198,6 +206,7 @@ def is_done(self): # timeout? now = time.time() if self.timeout and now - self.start_time > self.timeout: + self.log_warning(self.fl_ctx, f"gatherer for round {self.for_round} timed out after {self.timeout} seconds") return True if ( @@ -205,6 +214,11 @@ def is_done(self): and now - self.min_resps_received_time > self.wait_time_after_min_resps_received ): # received min responses required and waited for long time + self.log_info( + self.fl_ctx, + f"gatherer for round {self.for_round} exit after {self.wait_time_after_min_resps_received} seconds " + f"since received minimum responses", + ) return True @@ -216,6 +230,7 @@ def __init__( persistor_id=AppConstants.DEFAULT_PERSISTOR_ID, shareable_generator_id=AppConstants.DEFAULT_SHAREABLE_GENERATOR_ID, aggregator_id=AppConstants.DEFAULT_AGGREGATOR_ID, + metric_comparator_id=None, max_status_report_interval=Constant.MAX_STATUS_REPORT_INTERVAL, learn_task_check_interval=Constant.LEARN_TASK_CHECK_INTERVAL, learn_task_abort_timeout=Constant.LEARN_TASK_ABORT_TIMEOUT, @@ -225,7 +240,13 @@ def __init__( min_responses_required: int = 1, wait_time_after_min_resps_received: float = 10.0, ): - check_str("aggregator_id", aggregator_id) + check_non_empty_str("learn_task_name", learn_task_name) + check_non_empty_str("persistor_id", persistor_id) + check_non_empty_str("shareable_generator_id", shareable_generator_id) + check_non_empty_str("aggregator_id", aggregator_id) + + if metric_comparator_id: + check_non_empty_str("metric_comparator_id", metric_comparator_id) if learn_task_timeout: check_positive_number("learn_task_timeout", learn_task_timeout) @@ -245,6 +266,8 @@ def __init__( final_result_send_timeout=final_result_send_timeout, allow_busy_task=True, ) + self.metric_comparator_id = metric_comparator_id + self.metric_comparator = None self.rcv_learn_result_task_name = make_task_name(task_name_prefix, Constant.BASENAME_RCV_LEARN_RESULT) self.learn_task_timeout = learn_task_timeout self.min_responses_required = min_responses_required @@ -282,33 +305,52 @@ def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort return self._process_learn_result(shareable, fl_ctx, abort_signal) return super().execute(task_name, shareable, fl_ctx, 
abort_signal) - def handle_event(self, event_type: str, fl_ctx: FLContext): - super().handle_event(event_type, fl_ctx) - if event_type == EventType.START_RUN: - self.aggregator = self.engine.get_component(self.aggregator_id) - if not isinstance(self.aggregator, Aggregator): + def start_run(self, fl_ctx: FLContext): + super().start_run(fl_ctx) + self.aggregator = self.engine.get_component(self.aggregator_id) + if not isinstance(self.aggregator, Aggregator): + self.system_panic( + f"aggregator {self.aggregator_id} must be an Aggregator but got {type(self.aggregator)}", + fl_ctx, + ) + return + + if self.metric_comparator_id: + self.metric_comparator = self.engine.get_component(self.metric_comparator_id) + if not isinstance(self.metric_comparator, MetricComparator): self.system_panic( - f"aggregator {self.aggregator_id} must be an Aggregator but got {type(self.aggregator)}", + f"metric comparator {self.metric_comparator_id} must be a MetricComparator " + f"but got {type(self.metric_comparator)}", fl_ctx, ) return + else: + # use default comparator + self.metric_comparator = NumberMetricComparator() + + aggr_thread = threading.Thread(target=self._monitor_gather) + aggr_thread.daemon = True + aggr_thread.start() + self.log_info(fl_ctx, "started aggregator thread") - aggr_thread = threading.Thread(target=self._monitor_gather) - aggr_thread.daemon = True - aggr_thread.start() - self.log_info(fl_ctx, "started aggregator thread") - elif event_type == AppEventType.GLOBAL_BEST_MODEL_AVAILABLE: + def handle_event(self, event_type: str, fl_ctx: FLContext): + if event_type == AppEventType.GLOBAL_BEST_MODEL_AVAILABLE: client = fl_ctx.get_prop(Constant.CLIENT) if client and client != self.me: # this global best model is from other client + # we got here because this event is fired when I receive the best model shared from another + # client at the end of the workflow. 
return + # we got here because the best model selector fired this event: it found the "local best global" self.best_metric = fl_ctx.get_prop(AppConstants.VALIDATION_RESULT) self.best_result = copy.deepcopy(fl_ctx.get_prop(AppConstants.GLOBAL_MODEL)) self.log_info(fl_ctx, f"got GLOBAL_BEST_MODEL_AVAILABLE: best metric={self.best_metric}") current_round = fl_ctx.get_prop(AppConstants.CURRENT_ROUND) self.best_round = current_round self.update_status(last_round=current_round, action="better_aggregation") + else: + super().handle_event(event_type, fl_ctx) def start_workflow(self, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable: clients = self.get_config_prop(Constant.CLIENTS) @@ -541,6 +583,7 @@ def do_learn_task(self, name: str, task_data: Shareable, fl_ctx: FLContext, abor self.gatherer = Gatherer( fl_ctx=fl_ctx, all_clients=self.get_config_prop(Constant.CLIENTS), + metric_comparator=self.metric_comparator, trainers=self.trainers, for_round=current_round, timeout=self.learn_task_timeout, diff --git a/nvflare/app_common/ccwf/swarm_server_ctl.py b/nvflare/app_common/ccwf/swarm_server_ctl.py index 4cd8e94a8b..06ded491c3 100644 --- a/nvflare/app_common/ccwf/swarm_server_ctl.py +++ b/nvflare/app_common/ccwf/swarm_server_ctl.py @@ -57,6 +57,12 @@ def __init__( max_status_report_interval=max_status_report_interval, progress_timeout=progress_timeout, ) + if not train_clients: + train_clients = [] + + if not aggr_clients: + aggr_clients = [] + self.aggr_clients = aggr_clients self.train_clients = train_clients diff --git a/tests/unit_test/fuel/utils/validation_utils_test.py b/tests/unit_test/fuel/utils/validation_utils_test.py new file mode 100644 index 0000000000..54d95229bd --- /dev/null +++ b/tests/unit_test/fuel/utils/validation_utils_test.py @@ -0,0 +1,87 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest + +from nvflare.fuel.utils.validation_utils import validate_candidate, validate_candidates + + +class TestValidationUtils: + + @pytest.mark.parametrize( + "var_name, candidate, base, default_policy, allow_none, output", + [ + ("x", "red", ["red", "blue"], "any", True, "red"), + ("x", " red ", ["red", "blue"], "any", True, "red"), + ("x", "", ["red", "blue"], "any", True, "red"), + ("x", "", ["red", "blue"], "empty", True, ""), + ("x", None, ["red", "blue"], "any", True, ""), + ("x", "@none", ["red", "blue"], "any", True, ""), + ], + ) + def test_validate_candidate(self, var_name, candidate, base, default_policy, allow_none, output): + assert validate_candidate(var_name, candidate, base, default_policy, allow_none) == output + + @pytest.mark.parametrize( + "var_name, candidate, base, default_policy, allow_none", + [ + ("x", "red", ["red", "blue"], "bad", True), + ("x", 2, ["red", "blue"], "any", True), + ("x", "", ["red", "blue"], "disallow", True), + ("x", "", ["red", "blue"], "all", True), + ("x", "yellow", ["red", "blue"], "any", True), + ("x", None, ["red", "blue"], "any", False), + ("x", "@none", ["red", "blue"], "any", False), + ("x", "@all", ["red", "blue"], "any", False), + ], + ) + def test_validate_candidate_error(self, var_name, candidate, base, default_policy, allow_none): + with pytest.raises(ValueError): + validate_candidate(var_name, candidate, base, default_policy, allow_none) + + @pytest.mark.parametrize( + "var_name, candidates, base, default_policy, allow_none, output", + [ + ("x", "red", ["red", "blue"], "any", True, ["red"]), + ("x", [" red ", "blue", "red"], ["red", "blue", "green"], "any", True, ["red", "blue"]), + ("x", "", ["red", "blue"], "any", True, ["red"]), + ("x", "", ["red", "blue"], "all", True, ["red", "blue"]), + ("x", "", ["red", "blue"], "empty", True, []), + ("x", "red", ["red", "blue"], "any", True, ["red"]), + ("x", [], ["red", "blue"], "any", True, ["red"]), + ("x", [], ["red", "blue"], "empty", True, []), + ("x", [], ["red", "blue"], "all", True, ["red", "blue"]), + ("x", None, ["red", "blue"], "any", True, []), + ("x", "@all", ["red", "blue"], "any", True, ["red", "blue"]), + ("x", "@none", ["red", "blue"], "any", True, []), + ], + ) + def test_validate_candidates(self, var_name, candidates, base, default_policy, allow_none, output): + assert validate_candidates(var_name, candidates, base, default_policy, allow_none) == output + + @pytest.mark.parametrize( + "var_name, candidate, base, default_policy, allow_none", + [ + ("x", "red", ["red", "blue"], "bad", True), + ("x", 2, ["red", "blue"], "any", True), + ("x", "", ["red", "blue"], "disallow", True), + ("x", [], ["red", "blue"], "disallow", True), + ("x", "yellow", ["red", "blue"], "any", True), + ("x", None, ["red", "blue"], "any", False), + ("x", "@none", ["red", "blue"], "any", False), + ], + ) + def test_validate_candidates_error(self, var_name, candidate, base, default_policy, allow_none): + with pytest.raises(ValueError): + validate_candidates(var_name, candidate, base, default_policy, allow_none) \ No newline at end of file
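
A note for reviewers on the new MetricComparator abstraction introduced above: compare(a, b) follows the usual comparator contract (negative if a < b, zero if a == b, positive if a > b), and the default NumberMetricComparator returns a - b, so larger metric values always win. The Gatherer decides mine_is_better via compare(local_best, global_best) > 0, which means loss-style metrics (where smaller is better) need a custom comparator. A minimal sketch, assuming a custom component — the class name LossMetricComparator and the component ID "loss_comparator" are illustrative, not part of this patch:

    from typing import Union

    from nvflare.app_common.abstract.metric_comparator import MetricComparator


    class LossMetricComparator(MetricComparator):
        """Hypothetical comparator for metrics where smaller is better (e.g. loss)."""

        def compare(self, a, b) -> Union[int, float]:
            if not isinstance(a, (int, float)):
                raise ValueError(f"metric value must be a number but got {type(a)}")
            if not isinstance(b, (int, float)):
                raise ValueError(f"metric value must be a number but got {type(b)}")
            # Invert the natural numeric order: a smaller loss compares as "greater",
            # so the Gatherer's compare(local_best, global_best) > 0 check
            # treats the lower loss as the better result.
            return b - a

Such a comparator would then be selected by registering it as a component and passing metric_comparator_id="loss_comparator" to SwarmClientController; when metric_comparator_id is left as None, start_run() falls back to NumberMetricComparator as shown in the diff.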