From 1605e5d5cb6fd10552065f30769d011a4457118b Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Wed, 22 Mar 2023 15:43:40 +0800 Subject: [PATCH 1/7] add `FLStatsHandler` Signed-off-by: KumoLiu --- docs/source/handlers.rst | 6 + monai/handlers/__init__.py | 1 + monai/handlers/nvflare_stats_handler.py | 224 ++++++++++++++++++++++++ 3 files changed, 231 insertions(+) create mode 100644 monai/handlers/nvflare_stats_handler.py diff --git a/docs/source/handlers.rst b/docs/source/handlers.rst index 7da7f7f50d..763f2a1f13 100644 --- a/docs/source/handlers.rst +++ b/docs/source/handlers.rst @@ -143,6 +143,12 @@ Tensorboard handlers :members: +NVFlare stats handlers +-------------------- +.. autoclass:: FLStatsHandler + :members: + + LR Schedule handler ------------------- .. autoclass:: LrScheduleHandler diff --git a/monai/handlers/__init__.py b/monai/handlers/__init__.py index f032191043..397bf56279 100644 --- a/monai/handlers/__init__.py +++ b/monai/handlers/__init__.py @@ -40,5 +40,6 @@ from .stats_handler import StatsHandler from .surface_distance import SurfaceDistance from .tensorboard_handlers import TensorBoardHandler, TensorBoardImageHandler, TensorBoardStatsHandler +from .nvflare_stats_handler import FLStatsHandler from .utils import from_engine, ignore_data, stopping_fn_from_loss, stopping_fn_from_metric, write_metrics_reports from .validation_handler import ValidationHandler diff --git a/monai/handlers/nvflare_stats_handler.py b/monai/handlers/nvflare_stats_handler.py new file mode 100644 index 0000000000..296fa007e6 --- /dev/null +++ b/monai/handlers/nvflare_stats_handler.py @@ -0,0 +1,224 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import warnings +from collections.abc import Callable, Sequence +from typing import TYPE_CHECKING, Any + +import torch + +from monai.fl.utils.constants import ExtraItems +from monai.config import IgniteInfo +from monai.utils import is_scalar, min_version, optional_import + +Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events") +AnalyticsDataType, _ = optional_import("nvflare.apis.analytix", name="AnalyticsDataType") +Widget, _ = optional_import("nvflare.widgets.widget", name="Widget") + +if TYPE_CHECKING: + from ignite.engine import Engine + from tensorboardX import SummaryWriter as SummaryWriterX + from torch.utils.tensorboard import SummaryWriter +else: + Engine, _ = optional_import( + "ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine", as_type="decorator" + ) + +DEFAULT_TAG = "Loss" + + +class FLStatsHandler: + """ + FLStatsHandler defines a set of Ignite Event-handlers for all the NVFlare ``AnalyticsSender`` logics. + It can be used for any Ignite Engine(trainer, validator and evaluator). + And it can support both epoch level and iteration level with pre-defined AnalyticsSender event sender. + The expected data source is Ignite ``engine.state.output`` and ``engine.state.metrics``. + + Default behaviors: + - When EPOCH_COMPLETED, write each dictionary item in + ``engine.state.metrics`` to TensorBoard. + - When ITERATION_COMPLETED, write each dictionary item in + ``self.output_transform(engine.state.output)`` to TensorBoard. + + Usage example is available in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/unet_segmentation_3d_ignite.ipynb. + + """ + + def __init__( + self, + stats_sender: Widget | None = None, + iteration_log: bool | Callable[[Engine, int], bool] = True, + epoch_log: bool | Callable[[Engine, int], bool] = True, + output_transform: Callable = lambda x: x[0], + global_epoch_transform: Callable = lambda x: x, + state_attributes: Sequence[str] | None = None, + state_attributes_type: AnalyticsDataType | None = None, + tag_name: str = DEFAULT_TAG, + ) -> None: + """ + Args: + stats_sender: user can specify AnalyticsSender. + iteration_log: whether to send data when iteration completed, default to `True`. + ``iteration_log`` can be also a function and it will be interpreted as an event filter + (see https://pytorch.org/ignite/generated/ignite.engine.events.Events.html for details). + Event filter function accepts as input engine and event value (iteration) and should return True/False. + epoch_log: whether to send data when epoch completed, default to `True`. + ``epoch_log`` can be also a function and it will be interpreted as an event filter. + See ``iteration_log`` argument for more details. + output_transform: a callable that is used to transform the + ``ignite.engine.state.output`` into a scalar to plot, or a dictionary of {key: scalar}. + In the latter case, the output string will be formatted as key: value. + By default this value plotting happens when every iteration completed. + The default behavior is to print loss from output[0] as output is a decollated list + and we replicated loss value for every item of the decollated list. + `engine.state` and `output_transform` inherit from the ignite concept: + https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. + global_epoch_transform: a callable that is used to customize global epoch number. + For example, in evaluation, the evaluator engine might want to use trainer engines epoch number + when plotting epoch vs metric curves. + state_attributes: expected attributes from `engine.state`, if provided, will extract them + when epoch completed. + state_attributes_type: the type of the expected attributes from `engine.state`. + Only required when `state_attributes` is not None. + tag_name: when iteration output is a scalar, tag_name is used to plot, defaults to ``'Loss'``. + """ + + super().__init__() + self._sender = stats_sender + self.iteration_log = iteration_log + self.epoch_log = epoch_log + self.output_transform = output_transform + self.global_epoch_transform = global_epoch_transform + self.state_attributes = state_attributes + self.state_attributes_type = state_attributes_type + self.tag_name = tag_name + + def attach(self, engine: Engine) -> None: + """ + Register a set of Ignite Event-Handlers to a specified Ignite engine. + + Args: + engine: Ignite Engine, it can be a trainer, validator or evaluator. + + """ + if self.iteration_log and not engine.has_event_handler(self.iteration_completed, Events.ITERATION_COMPLETED): + event = Events.ITERATION_COMPLETED + if callable(self.iteration_log): # substitute event with new one using filter callable + event = event(event_filter=self.iteration_log) + engine.add_event_handler(event, self.iteration_completed) + if self.epoch_log and not engine.has_event_handler(self.epoch_completed, Events.EPOCH_COMPLETED): + event = Events.EPOCH_COMPLETED + if callable(self.epoch_log): # substitute event with new one using filter callable + event = event(event_filter=self.epoch_log) + engine.add_event_handler(event, self.epoch_completed) + + def epoch_completed(self, engine: Engine) -> None: + """ + Handler for train or validation/evaluation epoch completed Event. + Write epoch level events, default values are from Ignite `engine.state.metrics` dict. + + Args: + engine: Ignite Engine, it can be a trainer, validator or evaluator. + + """ + self._sender = engine.state.extra.get(ExtraItems.STATS_SENDER, self._sender) + self._default_epoch_sender(engine, self._sender) + + def iteration_completed(self, engine: Engine) -> None: + """ + Handler for train or validation/evaluation iteration completed Event. + Write iteration level events, default values are from Ignite `engine.state.output`. + + Args: + engine: Ignite Engine, it can be a trainer, validator or evaluator. + + """ + self._sender = engine.state.extra.get(ExtraItems.STATS_SENDER, self._sender) + self._default_iteration_sender(engine, self._sender) + + def _send_stats( + self, _engine: Engine, sender, tag: str, value: Any, data_type: AnalyticsDataType, step: int + ) -> None: + """ + Write scale value into TensorBoard. + Default to call `Summarysender.add_scalar()`. + + Args: + _engine: Ignite Engine, unused argument. + sender: AnalyticsSender. + tag: tag name in the TensorBoard. + value: value of the scalar data for current step. + step: index of current step. + + """ + sender._add(tag, value, data_type, step) + + def _default_epoch_sender(self, engine: Engine, sender: Widget) -> None: + """ + Execute epoch level event write operation. + Default to write the values from Ignite `engine.state.metrics` dict and + write the values of specified attributes of `engine.state`. + + Args: + engine: Ignite Engine, it can be a trainer, validator or evaluator. + sender: AnalyticsSender. + + """ + current_epoch = self.global_epoch_transform(engine.state.epoch) + summary_dict = engine.state.metrics + for name, value in summary_dict.items(): + self._send_stats(engine, sender, name, value, AnalyticsDataType.SCALAR, current_epoch) + + if self.state_attributes is not None: + for attr in self.state_attributes: + self._send_stats(engine, sender, attr, getattr(engine.state, attr, None), self.state_attributes_type, current_epoch) + sender.flush() + + def _default_iteration_sender(self, engine: Engine, sender: Widget) -> None: + """ + Execute iteration level event write operation based on Ignite `engine.state.output` data. + Extract the values from `self.output_transform(engine.state.output)`. + Since `engine.state.output` is a decollated list and we replicated the loss value for every item + of the decollated list, the default behavior is to track the loss from `output[0]`. + + Args: + engine: Ignite Engine, it can be a trainer, validator or evaluator. + sender: AnalyticsSender. + + """ + loss = self.output_transform(engine.state.output) + if loss is None: + return # do nothing if output is empty + if isinstance(loss, dict): + data_type = AnalyticsDataType.SCALARS + elif is_scalar(loss): # not printing multi dimensional output + data_type = AnalyticsDataType.SCALAR + else: + warnings.warn( + "ignoring non-scalar output in FLStatsHandler," + " make sure `output_transform(engine.state.output)` returns" + " a scalar or a dictionary of key and scalar pairs to avoid this warning." + " {}".format(type(loss)) + ) + + self._send_stats( + _engine=engine, + sender=sender, + tag=self.tag_name, + value=loss.item() if isinstance(loss, torch.Tensor) else loss, + data_type=data_type, + step=engine.state.iteration, + ) + sender.flush() From 19070ef88f3e3f7689bbaf4fd710b0e6ba92d38d Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Wed, 22 Mar 2023 15:44:09 +0800 Subject: [PATCH 2/7] add extra in engine.state Signed-off-by: KumoLiu --- monai/engines/workflow.py | 1 + 1 file changed, 1 insertion(+) diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index 30622c2b93..5d9b1b27c1 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -158,6 +158,7 @@ def set_sampler_epoch(engine: Engine) -> None: key_metric_name=None, # we can set many metrics, only use key_metric to compare and save the best model best_metric=-1, best_metric_epoch=-1, + extra={}, # extra sharable data for the workflow based on Ignite engine.state ) self.data_loader = data_loader self.non_blocking = non_blocking From aa0d91e046582f578e8a04a10afe2cd76a41a8ac Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Wed, 22 Mar 2023 15:45:12 +0800 Subject: [PATCH 3/7] update `MonaiAlgo` Signed-off-by: KumoLiu --- monai/fl/client/monai_algo.py | 10 +++++++++- monai/fl/utils/constants.py | 1 + 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/monai/fl/client/monai_algo.py b/monai/fl/client/monai_algo.py index 031143c69b..63b46a0846 100644 --- a/monai/fl/client/monai_algo.py +++ b/monai/fl/client/monai_algo.py @@ -14,7 +14,7 @@ import logging import os from collections.abc import Mapping, MutableMapping -from typing import Any, cast +from typing import Any, cast, Callable import torch import torch.distributed as dist @@ -381,6 +381,7 @@ def __init__( eval_data_key: str | None = BundleKeys.VALID_DATA, data_stats_transform_list: list | None = None, tracking: str | dict | None = None, + stats_sender: Callable | None = None ): self.logger = logger if config_evaluate_filename == "default": @@ -404,6 +405,7 @@ def __init__( self.eval_data_key = eval_data_key self.data_stats_transform_list = data_stats_transform_list self.tracking = tracking + self.stats_sender = stats_sender self.app_root = "" self.train_parser: ConfigParser | None = None @@ -501,6 +503,12 @@ def initialize(self, extra=None): BundleKeys.EVALUATOR, default=ConfigItem(None, BundleKeys.EVALUATOR) ) + # set stats sender for nvflare + self.stats_sender = extra.get(ExtraItems.STATS_SENDER, self.stats_sender) + if self.stats_sender is not None: + self.trainer.state.extra[ExtraItems.STATS_SENDER] = self.stats_sender + self.evaluator.state.extra[ExtraItems.STATS_SENDER] = self.stats_sender + # Get filters self.pre_filters = self.filter_parser.get_parsed_content( FiltersType.PRE_FILTERS, default=ConfigItem(None, FiltersType.PRE_FILTERS) diff --git a/monai/fl/utils/constants.py b/monai/fl/utils/constants.py index fbd18b364c..d95c2ee71f 100644 --- a/monai/fl/utils/constants.py +++ b/monai/fl/utils/constants.py @@ -29,6 +29,7 @@ class ExtraItems(StrEnum): MODEL_TYPE = "fl_model_type" CLIENT_NAME = "fl_client_name" APP_ROOT = "fl_app_root" + STATS_SENDER = "fl_stats_sender" class FlPhase(StrEnum): From 3783448f8ca2a6a37e159e20ed932e8da3557e71 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 22 Mar 2023 07:54:59 +0000 Subject: [PATCH 4/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/handlers/nvflare_stats_handler.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/monai/handlers/nvflare_stats_handler.py b/monai/handlers/nvflare_stats_handler.py index 296fa007e6..e529951455 100644 --- a/monai/handlers/nvflare_stats_handler.py +++ b/monai/handlers/nvflare_stats_handler.py @@ -27,8 +27,6 @@ if TYPE_CHECKING: from ignite.engine import Engine - from tensorboardX import SummaryWriter as SummaryWriterX - from torch.utils.tensorboard import SummaryWriter else: Engine, _ = optional_import( "ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine", as_type="decorator" From 66b5c023cb78e5cd045429ec3b9250382b1d1d69 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Wed, 22 Mar 2023 16:08:04 +0800 Subject: [PATCH 5/7] fix flake8 Signed-off-by: KumoLiu --- docs/source/handlers.rst | 2 +- monai/fl/client/monai_algo.py | 4 ++-- monai/handlers/__init__.py | 2 +- monai/handlers/nvflare_stats_handler.py | 6 ++++-- 4 files changed, 8 insertions(+), 6 deletions(-) diff --git a/docs/source/handlers.rst b/docs/source/handlers.rst index 763f2a1f13..99213efb3e 100644 --- a/docs/source/handlers.rst +++ b/docs/source/handlers.rst @@ -144,7 +144,7 @@ Tensorboard handlers NVFlare stats handlers --------------------- +---------------------- .. autoclass:: FLStatsHandler :members: diff --git a/monai/fl/client/monai_algo.py b/monai/fl/client/monai_algo.py index 63b46a0846..6b54709b5c 100644 --- a/monai/fl/client/monai_algo.py +++ b/monai/fl/client/monai_algo.py @@ -14,7 +14,7 @@ import logging import os from collections.abc import Mapping, MutableMapping -from typing import Any, cast, Callable +from typing import Any, Callable, cast import torch import torch.distributed as dist @@ -381,7 +381,7 @@ def __init__( eval_data_key: str | None = BundleKeys.VALID_DATA, data_stats_transform_list: list | None = None, tracking: str | dict | None = None, - stats_sender: Callable | None = None + stats_sender: Callable | None = None, ): self.logger = logger if config_evaluate_filename == "default": diff --git a/monai/handlers/__init__.py b/monai/handlers/__init__.py index 397bf56279..219a1caaa6 100644 --- a/monai/handlers/__init__.py +++ b/monai/handlers/__init__.py @@ -29,6 +29,7 @@ from .metrics_reloaded_handler import MetricsReloadedBinaryHandler, MetricsReloadedCategoricalHandler from .metrics_saver import MetricsSaver from .mlflow_handler import MLFlowHandler +from .nvflare_stats_handler import FLStatsHandler from .nvtx_handlers import MarkHandler, RangeHandler, RangePopHandler, RangePushHandler from .panoptic_quality import PanopticQuality from .parameter_scheduler import ParamSchedulerHandler @@ -40,6 +41,5 @@ from .stats_handler import StatsHandler from .surface_distance import SurfaceDistance from .tensorboard_handlers import TensorBoardHandler, TensorBoardImageHandler, TensorBoardStatsHandler -from .nvflare_stats_handler import FLStatsHandler from .utils import from_engine, ignore_data, stopping_fn_from_loss, stopping_fn_from_metric, write_metrics_reports from .validation_handler import ValidationHandler diff --git a/monai/handlers/nvflare_stats_handler.py b/monai/handlers/nvflare_stats_handler.py index 296fa007e6..c9e8930aac 100644 --- a/monai/handlers/nvflare_stats_handler.py +++ b/monai/handlers/nvflare_stats_handler.py @@ -17,8 +17,8 @@ import torch -from monai.fl.utils.constants import ExtraItems from monai.config import IgniteInfo +from monai.fl.utils.constants import ExtraItems from monai.utils import is_scalar, min_version, optional_import Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events") @@ -183,7 +183,9 @@ def _default_epoch_sender(self, engine: Engine, sender: Widget) -> None: if self.state_attributes is not None: for attr in self.state_attributes: - self._send_stats(engine, sender, attr, getattr(engine.state, attr, None), self.state_attributes_type, current_epoch) + self._send_stats( + engine, sender, attr, getattr(engine.state, attr, None), self.state_attributes_type, current_epoch + ) sender.flush() def _default_iteration_sender(self, engine: Engine, sender: Widget) -> None: From 1e8e2465dcc430e50a565c745d3e8dd3c04be2f7 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Wed, 22 Mar 2023 16:15:04 +0800 Subject: [PATCH 6/7] fix rst Signed-off-by: KumoLiu --- docs/source/handlers.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/handlers.rst b/docs/source/handlers.rst index 99213efb3e..31ef763a3a 100644 --- a/docs/source/handlers.rst +++ b/docs/source/handlers.rst @@ -143,8 +143,8 @@ Tensorboard handlers :members: -NVFlare stats handlers ----------------------- +NVFlare stats handler +--------------------- .. autoclass:: FLStatsHandler :members: From 1f4eeb522501d9ceaf5eeb2b5635fecf6d6f05dd Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Mon, 27 Mar 2023 17:20:46 +0800 Subject: [PATCH 7/7] using attach to add stats sender handler Signed-off-by: KumoLiu --- monai/fl/client/monai_algo.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/fl/client/monai_algo.py b/monai/fl/client/monai_algo.py index 6b54709b5c..344e84e193 100644 --- a/monai/fl/client/monai_algo.py +++ b/monai/fl/client/monai_algo.py @@ -506,8 +506,8 @@ def initialize(self, extra=None): # set stats sender for nvflare self.stats_sender = extra.get(ExtraItems.STATS_SENDER, self.stats_sender) if self.stats_sender is not None: - self.trainer.state.extra[ExtraItems.STATS_SENDER] = self.stats_sender - self.evaluator.state.extra[ExtraItems.STATS_SENDER] = self.stats_sender + self.stats_sender.attach(self.trainer) + self.stats_sender.attach(self.evaluator) # Get filters self.pre_filters = self.filter_parser.get_parsed_content(