Skip to content

Commit

Permalink
simplify server side controller init args
Browse files Browse the repository at this point in the history
  • Loading branch information
yanchengnv committed Sep 7, 2023
1 parent b272b4b commit ecbc479
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 29 deletions.
4 changes: 1 addition & 3 deletions nvflare/app_common/ccwf/cse_server_ctl.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,9 @@ def __init__(
participating_clients=participating_clients,
starting_client="",
starting_client_policy=DefaultValuePolicy.EMPTY,
starting_client_allow_none=True,
max_status_report_interval=max_status_report_interval,
result_clients=None,
result_clients="",
result_clients_policy=DefaultValuePolicy.EMPTY,
result_clients_allow_none=True,
progress_timeout=progress_timeout,
)

Expand Down
9 changes: 6 additions & 3 deletions nvflare/app_common/ccwf/cyclic_server_ctl.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from nvflare.app_common.ccwf.common import Constant, RROrder
from nvflare.app_common.ccwf.server_ctl import ServerSideController
from nvflare.fuel.utils.validation_utils import DefaultValuePolicy, check_str
from nvflare.fuel.utils.validation_utils import DefaultValuePolicy, check_str, normalize_config_arg


class CyclicServerController(ServerSideController):
Expand All @@ -33,6 +33,11 @@ def __init__(
progress_timeout: float = Constant.WORKFLOW_PROGRESS_TIMEOUT,
rr_order: str = RROrder.FIXED,
):
result_clients = normalize_config_arg(result_clients)
starting_client = normalize_config_arg(starting_client)
if starting_client is None:
raise ValueError("starting_client must be specified")

super().__init__(
num_rounds=num_rounds,
task_name_prefix=task_name_prefix,
Expand All @@ -43,10 +48,8 @@ def __init__(
participating_clients=participating_clients,
result_clients=result_clients,
result_clients_policy=DefaultValuePolicy.ALL,
result_clients_allow_none=True,
starting_client=starting_client,
starting_client_policy=DefaultValuePolicy.ANY,
starting_client_allow_none=False,
max_status_report_interval=max_status_report_interval,
progress_timeout=progress_timeout,
)
Expand Down
28 changes: 10 additions & 18 deletions nvflare/app_common/ccwf/server_ctl.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
check_positive_int,
check_positive_number,
check_str,
normalize_config_arg,
validate_candidate,
validate_candidates,
)
Expand Down Expand Up @@ -63,11 +64,9 @@ def __init__(
job_status_check_interval: float = Constant.JOB_STATUS_CHECK_INTERVAL,
starting_client=None,
starting_client_policy: str = DefaultValuePolicy.ANY,
starting_client_allow_none=False,
participating_clients=None,
result_clients=None,
result_clients_policy: str = DefaultValuePolicy.ALL,
result_clients_allow_none=True,
max_status_report_interval: float = Constant.PER_CLIENT_STATUS_REPORT_TIMEOUT,
progress_timeout: float = Constant.WORKFLOW_PROGRESS_TIMEOUT,
):
Expand All @@ -90,14 +89,9 @@ def __init__(
progress_timeout:
"""
Controller.__init__(self, task_check_period)
if not participating_clients:
participating_clients = []

if not result_clients:
result_clients = []

if not starting_client:
starting_client = ""
participating_clients = normalize_config_arg(participating_clients)
if participating_clients is None:
raise ValueError("participating_clients must not be empty")

self.task_name_prefix = task_name_prefix
self.configure_task_name = make_task_name(task_name_prefix, Constant.BASENAME_CONFIG)
Expand All @@ -112,11 +106,9 @@ def __init__(
self.job_status_check_interval = job_status_check_interval
self.starting_client = starting_client
self.starting_client_policy = starting_client_policy
self.starting_client_allow_none = starting_client_allow_none
self.participating_clients = participating_clients
self.result_clients = result_clients
self.result_clients_policy = result_clients_policy
self.result_clients_allow_none = result_clients_allow_none
self.client_statuses = {} # client name => ClientStatus
self.cw_started = False
self.asked_to_stop = False
Expand All @@ -141,8 +133,8 @@ def start_controller(self, fl_ctx: FLContext):
self.workflow_id = wf_id

all_clients = self._engine.get_clients()
if len(all_clients) <= 1:
raise RuntimeError("Not enough client sites.")
if len(all_clients) < 2:
raise RuntimeError(f"this workflow requires at least 2 clients, but only got {all_clients}")

all_client_names = [t.name for t in all_clients]
self.participating_clients = validate_candidates(
Expand All @@ -158,19 +150,17 @@ def start_controller(self, fl_ctx: FLContext):
candidate=self.starting_client,
base=self.participating_clients,
default_policy=self.starting_client_policy,
allow_none=self.starting_client_allow_none,
allow_none=True,
)

self.result_clients = validate_candidates(
var_name="result_clients",
candidates=self.result_clients,
base=self.participating_clients,
default_policy=self.result_clients_policy,
allow_none=self.result_clients_allow_none,
allow_none=True,
)

self.log_info(fl_ctx, f"result clients: {self.result_clients}")

for c in self.participating_clients:
self.client_statuses[c] = ClientStatus()

Expand Down Expand Up @@ -201,6 +191,8 @@ def control_flow(self, abort_signal: Signal, fl_ctx: FLContext):
if extra_config:
learn_config.update(extra_config)

self.log_info(fl_ctx, f"Workflow Config: {learn_config}")

# configure all clients
shareable = Shareable()
shareable[Constant.CONFIG] = learn_config
Expand Down
10 changes: 5 additions & 5 deletions nvflare/app_common/ccwf/swarm_server_ctl.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from nvflare.apis.fl_context import FLContext
from nvflare.app_common.ccwf.common import Constant
from nvflare.app_common.ccwf.server_ctl import ServerSideController
from nvflare.fuel.utils.validation_utils import DefaultValuePolicy, validate_candidates
from nvflare.fuel.utils.validation_utils import DefaultValuePolicy, normalize_config_arg, validate_candidates


class SwarmServerController(ServerSideController):
Expand All @@ -36,8 +36,10 @@ def __init__(
aggr_clients=None,
train_clients=None,
):
if not result_clients:
result_clients = []
result_clients = normalize_config_arg(result_clients)
starting_client = normalize_config_arg(starting_client)
if starting_client is None:
raise ValueError("starting_client must be specified")

super().__init__(
num_rounds=num_rounds,
Expand All @@ -50,10 +52,8 @@ def __init__(
participating_clients=participating_clients,
result_clients=result_clients,
result_clients_policy=DefaultValuePolicy.ALL,
result_clients_allow_none=True,
starting_client=starting_client,
starting_client_policy=DefaultValuePolicy.ANY,
starting_client_allow_none=False,
max_status_report_interval=max_status_report_interval,
progress_timeout=progress_timeout,
)
Expand Down
11 changes: 11 additions & 0 deletions nvflare/fuel/utils/validation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,3 +263,14 @@ def validate_candidate(var_name: str, candidate, base: list, default_policy: str
raise ValueError(f"invalid value '{candidate}' in '{var_name}': it must be one of {base}")
else:
return c


def normalize_config_arg(value):
if value is False:
return None # specified to be "empty"
if isinstance(value, str):
if value.strip().lower() == SYMBOL_NONE:
return None
if not value:
return "" # meaning to take default
return value
20 changes: 20 additions & 0 deletions tests/unit_test/fuel/utils/validation_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,32 @@
check_number_range,
check_positive_int,
check_positive_number,
normalize_config_arg,
validate_candidate,
validate_candidates,
)


class TestValidationUtils:
@pytest.mark.parametrize(
"value, result",
[
("x", "x"),
(123, 123),
("", ""),
(False, None),
("@None", None),
(None, ""),
(0, ""),
([], ""),
({}, ""),
((), ""),
([1, 2, 3], [1, 2, 3]),
],
)
def test_normalize_config_arg(self, value, result):
assert normalize_config_arg(value) == result

@pytest.mark.parametrize(
"name, num, min_value, max_value",
[
Expand Down

0 comments on commit ecbc479

Please sign in to comment.