Skip to content

Commit

Permalink
Add stats_sender to MonaiAlgo for FL stats
Browse files Browse the repository at this point in the history
  • Loading branch information
nvkevlu committed Sep 13, 2023
1 parent c22a2bd commit d9775b8
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
10 changes: 9 additions & 1 deletion monai/fl/client/monai_algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import os
import time
from collections.abc import Mapping, MutableMapping
from typing import Any, cast
from typing import Any, Callable, cast

import torch
import torch.distributed as dist
Expand Down Expand Up @@ -359,6 +359,7 @@ def __init__(
eval_workflow_name: str = "train",
train_workflow: BundleWorkflow | None = None,
eval_workflow: BundleWorkflow | None = None,
stats_sender: Callable | None = None,
):
self.logger = logger
self.bundle_root = bundle_root
Expand Down Expand Up @@ -390,6 +391,7 @@ def __init__(
if not isinstance(eval_workflow, BundleWorkflow) or eval_workflow.get_workflow_type() is None:
raise ValueError("train workflow must be BundleWorkflow and set type.")
self.eval_workflow = eval_workflow
self.stats_sender = stats_sender

self.app_root = ""
self.filter_parser: ConfigParser | None = None
Expand Down Expand Up @@ -478,6 +480,12 @@ def initialize(self, extra=None):
if len(config_filter_files) > 0:
self.filter_parser.read_config(config_filter_files)

# set stats sender for nvflare
self.stats_sender = extra.get(ExtraItems.STATS_SENDER, self.stats_sender)
if self.stats_sender is not None:
self.stats_sender.attach(self.trainer)
self.stats_sender.attach(self.evaluator)

# Get filters
self.pre_filters = self.filter_parser.get_parsed_content(
FiltersType.PRE_FILTERS, default=ConfigItem(None, FiltersType.PRE_FILTERS)
Expand Down
1 change: 1 addition & 0 deletions monai/fl/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class ExtraItems(StrEnum):
MODEL_TYPE = "fl_model_type"
CLIENT_NAME = "fl_client_name"
APP_ROOT = "fl_app_root"
STATS_SENDER = "fl_stats_sender"


class FlPhase(StrEnum):
Expand Down

0 comments on commit d9775b8

Please sign in to comment.