Skip to content

Commit

Permalink
address review issues
Browse files Browse the repository at this point in the history
  • Loading branch information
yanchengnv committed Sep 2, 2023
1 parent 3ee9133 commit f65908b
Show file tree
Hide file tree
Showing 11 changed files with 282 additions and 101 deletions.
32 changes: 32 additions & 0 deletions nvflare/app_common/abstract/metric_comparator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import abstractmethod
from typing import Union


class MetricComparator:
@abstractmethod
def compare(self, a, b) -> Union[int, float]:
"""Compare two metric values.
Metric values do not have to be numbers.
Args:
a: first metric value
b: second metric value
Returns: negative number if a < b; 0 if a == b; positive number if a > b.
"""
pass
73 changes: 41 additions & 32 deletions nvflare/app_common/ccwf/client_ctl.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from nvflare.app_common.app_constant import AppConstants
from nvflare.app_common.app_event_type import AppEventType
from nvflare.app_common.ccwf.common import Constant, ResultType, StatusReport, make_task_name, topic_for_end_workflow
from nvflare.fuel.utils.validation_utils import check_positive_number, check_str
from nvflare.fuel.utils.validation_utils import check_non_empty_str, check_positive_number
from nvflare.security.logging import secure_format_traceback


Expand All @@ -56,23 +56,21 @@ def __init__(
allow_busy_task: bool = False,
):
"""
Constructor of a CWE object.
Constructor of a ClientSideController object.
Args:
task_name_prefix: prefix of task names
task_name_prefix: prefix of task names. All CCWF task names are prefixed with this.
learn_task_name: name for the Learning Task (LT)
persistor_id: ID of the persistor component
shareable_generator_id: ID of the shareable generator component
max_status_report_interval: max interval between status reports to the server
learn_task_check_interval: interval for checking incoming Learning Task (LT)
learn_task_send_timeout: timeout for sending the LT to other client(s)
final_result_send_timeout: timeout for sending final result to participating clients
learn_task_abort_timeout: time to wait for the LT to become stopped after aborting it
allow_busy_task:
allow_busy_task: whether a new learn task is allowed when working on current learn task
"""
check_str("task_name_prefix", task_name_prefix)
check_str("learn_task_name", learn_task_name)
check_str("persistor_id", persistor_id)
check_str("shareable_generator_id", shareable_generator_id)

check_non_empty_str("task_name_prefix", task_name_prefix)
check_positive_number("max_status_report_interval", max_status_report_interval)
check_positive_number("learn_task_check_interval", learn_task_check_interval)
check_positive_number("learn_task_send_timeout", learn_task_send_timeout)
Expand Down Expand Up @@ -138,34 +136,34 @@ def get_config_prop(self, name: str, default=None):
return default
return self.config.get(name, default)

def handle_event(self, event_type: str, fl_ctx: FLContext):
if event_type == EventType.START_RUN:
self.engine = fl_ctx.get_engine()
if not self.engine:
self.system_panic("no engine", fl_ctx)
return
def start_run(self, fl_ctx: FLContext):
self.engine = fl_ctx.get_engine()
if not self.engine:
self.system_panic("no engine", fl_ctx)
return

runner = fl_ctx.get_prop(FLContextKey.RUNNER)
if not runner:
self.system_panic("no client runner", fl_ctx)
return
runner = fl_ctx.get_prop(FLContextKey.RUNNER)
if not runner:
self.system_panic("no client runner", fl_ctx)
return

self.me = fl_ctx.get_identity_name()
self.me = fl_ctx.get_identity_name()
if self.learn_task_name:
self.learn_executor = runner.find_executor(self.learn_task_name)
if not self.learn_executor:
self.system_panic(f"no executor for task {self.learn_task_name}", fl_ctx)
return

engine = fl_ctx.get_engine()
self.persistor = engine.get_component(self.persistor_id)
if not isinstance(self.persistor, LearnablePersistor):
self.system_panic(
f"Persistor {self.persistor_id} must be a Persistor instance, but got {type(self.persistor)}",
fl_ctx,
)
return
self.persistor = self.engine.get_component(self.persistor_id)
if not isinstance(self.persistor, LearnablePersistor):
self.system_panic(
f"Persistor {self.persistor_id} must be a Persistor instance, but got {type(self.persistor)}",
fl_ctx,
)
return

self.shareable_generator = engine.get_component(self.shareable_generator_id)
if self.shareable_generator_id:
self.shareable_generator = self.engine.get_component(self.shareable_generator_id)
if not isinstance(self.shareable_generator, ShareableGenerator):
self.system_panic(
f"Shareable generator {self.shareable_generator_id} must be a Shareable Generator instance, "
Expand All @@ -174,9 +172,16 @@ def handle_event(self, event_type: str, fl_ctx: FLContext):
)
return

self.initialize(fl_ctx)
self.initialize(fl_ctx)

if self.learn_task_name:
self.log_info(fl_ctx, "Started learn thread")
self.learn_thread.start()

def handle_event(self, event_type: str, fl_ctx: FLContext):
if event_type == EventType.START_RUN:
self.start_run(fl_ctx)

elif event_type == EventType.BEFORE_PULL_TASK:
# add my status to fl_ctx
if not self.workflow_id:
Expand Down Expand Up @@ -320,15 +325,19 @@ def _try_broadcast_final_result(
for t in targets:
reply = resp.get(t)
if not isinstance(reply, Shareable):
self.log_error(fl_ctx, f"failed to send {result_type} result to client {t}")
self.log_error(fl_ctx, f"reply must be Shareable but got {type(reply)}")
self.log_error(
fl_ctx,
f"bad response for {result_type} result from client {t}: "
f"reply must be Shareable but got {type(reply)}",
)
num_errors += 1
continue

rc = reply.get_return_code(ReturnCode.OK)
if rc != ReturnCode.OK:
self.log_error(fl_ctx, f"bad response for {result_type} result from client {t}: {rc}")
num_errors += 1

if num_errors == 0:
self.log_info(fl_ctx, f"successfully broadcast {result_type} result to {targets}")
return num_errors
Expand Down
15 changes: 15 additions & 0 deletions nvflare/app_common/ccwf/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union

from nvflare.app_common.abstract.metric_comparator import MetricComparator


class Constant:

Expand Down Expand Up @@ -183,3 +187,14 @@ def topic_for_end_workflow(wf_id):

def make_task_name(prefix: str, base_name: str) -> str:
return f"{prefix}_{base_name}"


class NumberMetricComparator(MetricComparator):
def compare(self, a, b) -> Union[int, float]:
if not isinstance(a, (int, float)):
raise ValueError(f"metric value must be a number but got {type(a)}")

if not isinstance(b, (int, float)):
raise ValueError(f"metric value must be a number but got {type(b)}")

return a - b
56 changes: 17 additions & 39 deletions nvflare/app_common/ccwf/cse_client_ctl.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import threading

from nvflare.apis.controller_spec import Task
from nvflare.apis.event_type import EventType
from nvflare.apis.fl_constant import FLContextKey, ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable, make_reply
Expand All @@ -24,7 +23,7 @@
from nvflare.app_common.app_constant import AppConstants, ValidateType
from nvflare.app_common.ccwf.client_ctl import ClientSideController
from nvflare.app_common.ccwf.common import Constant, ModelType, make_task_name
from nvflare.fuel.utils.validation_utils import check_positive_number, check_str
from nvflare.fuel.utils.validation_utils import check_non_empty_str, check_positive_number
from nvflare.security.logging import secure_format_traceback


Expand All @@ -39,11 +38,14 @@ def __init__(
get_model_timeout=Constant.GET_MODEL_TIMEOUT,
):
check_positive_number("get_model_timeout", get_model_timeout)
check_str("submit_model_task_name", submit_model_task_name)
check_str("validation_task_name", validation_task_name)
check_non_empty_str("submit_model_task_name", submit_model_task_name)
check_non_empty_str("validation_task_name", validation_task_name)
check_non_empty_str("persistor_id", persistor_id)

super().__init__(
task_name_prefix=task_name_prefix,
learn_task_name="",
shareable_generator_id="",
persistor_id=persistor_id,
max_status_report_interval=max_status_report_interval,
)
Expand All @@ -60,45 +62,21 @@ def __init__(
self.local_model = None
self.model_lock = threading.Lock()

def handle_event(self, event_type: str, fl_ctx: FLContext):
if event_type == EventType.START_RUN:
self.engine = fl_ctx.get_engine()
if not self.engine:
self.system_panic("no engine", fl_ctx)
return

runner = fl_ctx.get_prop(FLContextKey.RUNNER)
if not runner:
self.system_panic("no client runner", fl_ctx)
def start_run(self, fl_ctx: FLContext):
super().start_run(fl_ctx)
runner = fl_ctx.get_prop(FLContextKey.RUNNER)
if self.submit_model_task_name:
self.submit_model_executor = runner.find_executor(self.submit_model_task_name)
if not self.submit_model_executor:
self.system_panic(f"no executor for task {self.submit_model_task_name}", fl_ctx)
return

self.me = fl_ctx.get_identity_name()

if self.submit_model_task_name:
self.submit_model_executor = runner.find_executor(self.submit_model_task_name)
if not self.submit_model_executor:
self.system_panic(f"no executor for task {self.submit_model_task_name}", fl_ctx)
return

if self.validation_task_name:
self.validate_executor = runner.find_executor(self.validation_task_name)
if not self.validate_executor:
self.system_panic(f"no executor for task {self.validation_task_name}", fl_ctx)
return

engine = fl_ctx.get_engine()
self.persistor = engine.get_component(self.persistor_id)
if not isinstance(self.persistor, ModelPersistor):
self.system_panic(
f"Persistor {self.persistor_id} must be a ModelPersistor instance, but got {type(self.persistor)}",
fl_ctx,
)
if self.validation_task_name:
self.validate_executor = runner.find_executor(self.validation_task_name)
if not self.validate_executor:
self.system_panic(f"no executor for task {self.validation_task_name}", fl_ctx)
return

self.initialize(fl_ctx)
else:
super().handle_event(event_type, fl_ctx)

def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable:
if task_name == self.start_task_name:
self.is_starting_client = True
Expand Down
8 changes: 4 additions & 4 deletions nvflare/app_common/ccwf/cse_server_ctl.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,12 +144,12 @@ def prepare_config(self):
}

def process_config_reply(self, client_name: str, reply: Shareable, fl_ctx: FLContext) -> bool:
global_models = reply.get(Constant.GLOBAL_NAMES)
if global_models:
for m in global_models:
global_names = reply.get(Constant.GLOBAL_NAMES)
if global_names:
for m in global_names:
if m not in self.global_names:
self.global_names[m] = client_name
self.log_info(fl_ctx, f"got global model {m} from {client_name}")
self.log_info(fl_ctx, f"got global model name {m} from {client_name}")
return True

def _ask_to_evaluate(
Expand Down
5 changes: 5 additions & 0 deletions nvflare/app_common/ccwf/cyclic_client_ctl.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from nvflare.app_common.app_constant import AppConstants
from nvflare.app_common.ccwf.client_ctl import ClientSideController
from nvflare.app_common.ccwf.common import Constant, ResultType, RROrder, rotate_to_front
from nvflare.fuel.utils.validation_utils import check_non_empty_str


class CyclicClientController(ClientSideController):
Expand All @@ -35,6 +36,10 @@ def __init__(
learn_task_send_timeout=Constant.LEARN_TASK_SEND_TIMEOUT,
final_result_send_timeout=Constant.FINAL_RESULT_SEND_TIMEOUT,
):
check_non_empty_str("learn_task_name", learn_task_name)
check_non_empty_str("persistor_id", persistor_id)
check_non_empty_str("shareable_generator_id", shareable_generator_id)

super().__init__(
task_name_prefix=task_name_prefix,
learn_task_name=learn_task_name,
Expand Down
3 changes: 0 additions & 3 deletions nvflare/app_common/ccwf/cyclic_server_ctl.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,6 @@ def __init__(
progress_timeout: float = Constant.WORKFLOW_PROGRESS_TIMEOUT,
rr_order: str = RROrder.FIXED,
):
if not result_clients:
result_clients = []

super().__init__(
num_rounds=num_rounds,
task_name_prefix=task_name_prefix,
Expand Down
11 changes: 10 additions & 1 deletion nvflare/app_common/ccwf/server_ctl.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __init__(
start_task_timeout=Constant.START_TASK_TIMEOUT,
task_check_period: float = Constant.TASK_CHECK_INTERVAL,
job_status_check_interval: float = Constant.JOB_STATUS_CHECK_INTERVAL,
starting_client: str = None,
starting_client=None,
starting_client_policy: str = DefaultPolicy.ANY,
starting_client_allow_none=False,
participating_clients=None,
Expand Down Expand Up @@ -89,6 +89,15 @@ def __init__(
progress_timeout:
"""
Controller.__init__(self, task_check_period)
if not participating_clients:
participating_clients = []

if not result_clients:
result_clients = []

if not starting_client:
starting_client = ""

self.task_name_prefix = task_name_prefix
self.configure_task_name = make_task_name(task_name_prefix, Constant.BASENAME_CONFIG)
self.configure_task_timeout = configure_task_timeout
Expand Down
Loading

0 comments on commit f65908b

Please sign in to comment.