From b1e2e07dad01232284f6535c540eea1faf2b42ad Mon Sep 17 00:00:00 2001 From: Shawn Carere Date: Tue, 30 Jul 2024 12:08:23 -0400 Subject: [PATCH 01/11] Updates to nnunet client, created nnunet server --- .gitignore | 1 + examples/nnunet_example/README.md | 3 + examples/nnunet_example/client.py | 177 ++++++++++++++++++++++ examples/nnunet_example/config.yaml | 8 + examples/nnunet_example/server.py | 121 +++++++++++++++ fl4health/server/base_server.py | 70 ++++++++- fl4health/utils/load_data.py | 26 ++++ fl4health/utils/msd_dataset_sources.py | 63 ++++++++ research/picai/fl_nnunet/config.yaml | 8 +- research/picai/fl_nnunet/nnunet_client.py | 80 ++++++++-- research/picai/fl_nnunet/nnunet_server.py | 86 +++++++++++ research/picai/fl_nnunet/start_client.py | 20 +-- research/picai/fl_nnunet/start_server.py | 50 +++--- tests/smoke_tests/nnunet_config.yaml | 8 + 14 files changed, 669 insertions(+), 52 deletions(-) create mode 100644 examples/nnunet_example/README.md create mode 100644 examples/nnunet_example/client.py create mode 100644 examples/nnunet_example/config.yaml create mode 100644 examples/nnunet_example/server.py create mode 100644 fl4health/utils/msd_dataset_sources.py create mode 100644 research/picai/fl_nnunet/nnunet_server.py create mode 100644 tests/smoke_tests/nnunet_config.yaml diff --git a/.gitignore b/.gitignore index c207f71ca..059adf536 100644 --- a/.gitignore +++ b/.gitignore @@ -150,6 +150,7 @@ settings.json **/datasets/skin_cancer/PAD-UFES-20/** **/datasets/skin_cancer/ISIC_2019/** **/datasets/skin_cancer/Derm7pt/** +**/datasets/nnunet/** # logs diff --git a/examples/nnunet_example/README.md b/examples/nnunet_example/README.md new file mode 100644 index 000000000..2c72decda --- /dev/null +++ b/examples/nnunet_example/README.md @@ -0,0 +1,3 @@ +# NnUNetClient Example + +This example demonstrates how to use the NnUNetClient to train nnunet segmentation models in a federated setting. diff --git a/examples/nnunet_example/client.py b/examples/nnunet_example/client.py new file mode 100644 index 000000000..6e3fea145 --- /dev/null +++ b/examples/nnunet_example/client.py @@ -0,0 +1,177 @@ +import argparse +import os +import warnings +from logging import INFO +from os.path import exists, join +from pathlib import Path +from typing import Union + +with warnings.catch_warnings(): + # Need to import lightning utilities now in order to avoid deprecation + # warnings. 
Ignore flake8 warning saying that it is unused
+    # lightning utilities is imported by some of the dependencies
+    # so by importing it now and filtering the warnings
+    # https://github.com/Lightning-AI/utilities/issues/119
+    warnings.filterwarnings("ignore", category=DeprecationWarning)
+    import lightning_utilities  # noqa: F401
+
+    # Some finicky import stuff, if i don't silence deprecation warnings when
+    # importing flower then i get unsilenceable deprecation warning from a
+    # different api (batch generators)
+    # Issue: https://github.com/MIC-DKFZ/nnUNet/issues/2370
+    from flwr.client import start_client
+
+import torch
+from flwr.common.logger import log
+from torchmetrics.segmentation import GeneralizedDiceScore
+
+from fl4health.utils.load_data import load_msd_dataset
+from fl4health.utils.metrics import TorchMetric, TransformsMetric
+from fl4health.utils.msd_dataset_sources import get_msd_dataset_enum, msd_num_labels
+from research.picai.fl_nnunet.transforms import get_annotations_from_probs, get_probabilities_from_logits
+
+
+def main(
+    dataset_path: Path,
+    msd_dataset_name: str,
+    server_address: str,
+    fold: Union[int, str],
+    always_preprocess: bool = False,
+) -> None:
+
+    # Log device and server address
+    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    log(INFO, f"Using device: {DEVICE}")
+    log(INFO, f"Using server address: {server_address}")
+
+    # Load the dataset if necessary
+    msd_dataset_enum = get_msd_dataset_enum(msd_dataset_name)
+    nnUNet_raw = join(dataset_path, "nnunet_raw")
+    if not exists(join(nnUNet_raw, msd_dataset_enum.value)):
+        log(INFO, f"Downloading and extracting {msd_dataset_enum.value} dataset")
+        load_msd_dataset(nnUNet_raw, msd_dataset_name)
+
+    # The dataset ID will be the same as the MSD Task number
+    dataset_id = int(msd_dataset_enum.value[4:6])
+    nnunet_dataset_name = f"Dataset{dataset_id:03d}_{msd_dataset_enum.value.split('_')[1]}"
+
+    # Convert the msd dataset if necessary
+    if not exists(join(nnUNet_raw, nnunet_dataset_name)):
+        log(INFO, f"Converting {msd_dataset_enum.value} into nnunet dataset")
+        convert_msd_dataset(source_folder=join(nnUNet_raw, msd_dataset_enum.value))
+
+    # Create a metric
+    dice = TransformsMetric(
+        metric=TorchMetric(
+            name="Pseudo DICE",
+            metric=GeneralizedDiceScore(
+                num_classes=msd_num_labels[msd_dataset_enum], weight_type="square", include_background=False
+            ).to(DEVICE),
+        ),
+        transforms=[get_probabilities_from_logits, get_annotations_from_probs],
+    )
+
+    # Create client
+    client = nnUNetClient(
+        # Args specific to nnUNetClient
+        dataset_id=dataset_id,
+        fold=fold,
+        always_preprocess=always_preprocess,
+        # BaseClient Args
+        device=DEVICE,
+        metrics=[dice],
+        data_path=dataset_path,  # Argument not actually used by nnUNetClient
+    )
+
+    start_client(server_address=server_address, client=client.to_client())
+
+    # Shutdown the client
+    client.shutdown()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        prog="nnunet_example/client.py",
+        description="""An example of nnUNetClient on any of the Medical
+            Segmentation Decathlon (MSD) datasets. Automatically generates an
+            nnunet segmentation model and trains it in a federated setting""",
+    )
+
+    # I have to use underscores instead of dashes because that's how they are
+    # defined in run_smoke_test.py
+    parser.add_argument(
+        "--dataset_path",
+        type=str,
+        required=True,
+        help="""Path to the folder in which data should be stored. 
This script + will automatically create nnunet_raw, and nnunet_preprocessed + subfolders if they don't already exist. This script will also + attempt to download and prepare the MSD Dataset into the + nnunet_raw folder if it does not already exist.""", + ) + parser.add_argument( + "--fold", + type=str, + required=False, + default="0", + help="""[OPTIONAL] Which fold of the local client dataset to use for + validation. nnunet defaults to 5 folds (0 to 4). Can also be set + to 'all' to use all the data for both training and validation. + Defaults to fold 0""", + ) + parser.add_argument( + "--msd_dataset_name", + type=str, + required=False, + default="Task04_Hippocampus", # The smallest dataset + help="""[OPTIONAL] Name of the MSD dataset to use. The options are + defined by the values of the MsdDataset enum as returned by the + get_msd_dataset_enum function""", + ) + parser.add_argument( + "--always-preprocess", + action="store_true", + required=False, + help="""[OPTIONAL] Use this to force preprocessing the nnunet data + even if the preprocessed data is found to already exist""", + ) + parser.add_argument( + "--server_address", + type=str, + required=False, + default="0.0.0.0:8080", + help="""[OPTIONAL] The server address for the clients to communicate + to the server through. Defaults to 0.0.0.0:8080""", + ) + args = parser.parse_args() + + # Create nnunet directory structure and set environment variables + nnUNet_raw = join(args.dataset_path, "nnunet_raw") + nnUNet_preprocessed = join(args.dataset_path, "nnunet_preprocessed") + if not exists(nnUNet_raw): + os.makedirs(nnUNet_raw) + if not exists(nnUNet_preprocessed): + os.makedirs(nnUNet_preprocessed) + os.environ["nnUNet_raw"] = nnUNet_raw + os.environ["nnUNet_preprocessed"] = nnUNet_preprocessed + os.environ["nnUNet_results"] = join(args.dataset_path, "nnunet_results") + log(INFO, "Setting nnunet environment variables") + log(INFO, f"\tnnUNet_raw: {nnUNet_raw}") + log(INFO, f"\tnnUNet_preprocessed: {nnUNet_preprocessed}") + log(INFO, f"\tnnUNet_results: {join(args.dataset_path, 'nnunet_results')}") + + # Everything that uses nnunetv2 module can only be imported after + # environment variables are changed + from nnunetv2.dataset_conversion.convert_MSD_dataset import convert_msd_dataset + + from research.picai.fl_nnunet.nnunet_client import nnUNetClient + + # Check fold argument and start main method + fold: Union[int, str] = "all" if args.fold == "all" else int(args.fold) + main( + dataset_path=Path(args.dataset_path), + msd_dataset_name=args.msd_dataset_name, + server_address=args.server_address, + fold=fold, + always_preprocess=args.always_preprocess, + ) diff --git a/examples/nnunet_example/config.yaml b/examples/nnunet_example/config.yaml new file mode 100644 index 000000000..d2371d0d6 --- /dev/null +++ b/examples/nnunet_example/config.yaml @@ -0,0 +1,8 @@ +# Parameters that describe the server +n_server_rounds: 1 + +# Parameters that describe the clients +n_clients: 1 +local_epochs: 1 + +nnunet_config: 2d diff --git a/examples/nnunet_example/server.py b/examples/nnunet_example/server.py new file mode 100644 index 000000000..b4b6ae26c --- /dev/null +++ b/examples/nnunet_example/server.py @@ -0,0 +1,121 @@ +import argparse +import json +import pickle +import warnings +from functools import partial +from typing import Optional + +with warnings.catch_warnings(): + # Need to import lightning utilities now in order to avoid deprecation + # warnings. 
Ignore flake8 warning saying that it is unused + # lightning utilities is imported by some of the dependencies + # so by importing it now and filtering the warnings + # https://github.com/Lightning-AI/utilities/issues/119 + warnings.filterwarnings("ignore", category=DeprecationWarning) + import lightning_utilities # noqa: F401 + +import flwr as fl +import torch +import yaml +from flwr.common.parameter import ndarrays_to_parameters +from flwr.common.typing import Config +from flwr.server.client_manager import SimpleClientManager +from flwr.server.strategy import FedAvg + +from examples.utils.functions import make_dict_with_epochs_or_steps +from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn +from research.picai.fl_nnunet.nnunet_server import NnUNetServer + + +def get_config( + current_server_round: int, + nnunet_config: str, + n_server_rounds: int, + batch_size: int, + n_clients: int, + nnunet_plans: Optional[str] = None, + local_epochs: Optional[int] = None, + local_steps: Optional[int] = None, +) -> Config: + # Create config + config: Config = { + "n_clients": n_clients, + "nnunet_config": nnunet_config, + "n_server_rounds": n_server_rounds, + "batch_size": batch_size, + **make_dict_with_epochs_or_steps(local_epochs, local_steps), + "current_server_round": current_server_round, + } + + # Check if plans were provided + if nnunet_plans is not None: + plans_bytes = pickle.dumps(json.load(open(nnunet_plans, "r"))) + config["nnunet_plans"] = plans_bytes + + return config + + +def main(config: dict, server_address: str) -> None: + # Partial function with everything set except current server round + fit_config_fn = partial( + get_config, + n_clients=config["n_clients"], + nnunet_config=config["nnunet_config"], + n_server_rounds=config["n_server_rounds"], + batch_size=0, # Set this to 0 because we're not using it + nnunet_plans=config.get("nnunet_plans"), + local_epochs=config.get("local_epochs"), + local_steps=config.get("local_steps"), + ) + + if config.get("starting_checkpoint"): + model = torch.load(config["starting_checkpoint"]) + # Of course nnunet stores their pytorch models differently. + params = ndarrays_to_parameters([val.cpu().numpy() for _, val in model["network_weights"].items()]) + else: + params = None + + strategy = FedAvg( + min_fit_clients=config["n_clients"], + min_evaluate_clients=config["n_clients"], + min_available_clients=config["n_clients"], + on_fit_config_fn=fit_config_fn, + on_evaluate_config_fn=fit_config_fn, # Nothing changes for eval + fit_metrics_aggregation_fn=fit_metrics_aggregation_fn, + evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn, + initial_parameters=params, + ) + + server = NnUNetServer( + client_manager=SimpleClientManager(), + strategy=strategy, + ) + + fl.server.start_server( + server=server, + server_address=server_address, + config=fl.server.ServerConfig(num_rounds=config["n_server_rounds"]), + ) + + # Shutdown server + server.shutdown() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--config_path", action="store", type=str, help="Path to the configuration file") + parser.add_argument( + "--server-address", + type=str, + required=False, + default="0.0.0.0:8080", + help="""[OPTIONAL] The address to use for the server. 
Defaults to + 0.0.0.0:8080""", + ) + + args = parser.parse_args() + + with open(args.config_path, "r") as f: + config = yaml.safe_load(f) + + main(config, server_address=args.server_address) diff --git a/fl4health/server/base_server.py b/fl4health/server/base_server.py index 01db446bf..3af034b6e 100644 --- a/fl4health/server/base_server.py +++ b/fl4health/server/base_server.py @@ -1,12 +1,12 @@ import datetime -from logging import DEBUG, INFO, WARNING +from logging import DEBUG, INFO, WARN, WARNING from typing import Dict, Generic, List, Optional, Tuple, TypeVar, Union import torch.nn as nn from flwr.common import EvaluateRes, Parameters from flwr.common.logger import log from flwr.common.parameter import parameters_to_ndarrays -from flwr.common.typing import Scalar +from flwr.common.typing import Code, GetParametersIns, Scalar from flwr.server.client_manager import ClientManager from flwr.server.client_proxy import ClientProxy from flwr.server.history import History @@ -350,3 +350,69 @@ def _hydrate_model_for_checkpointing(self) -> nn.Module: model_ndarrays = parameters_to_ndarrays(self.parameters) self.parameter_exchanger.pull_parameters(model_ndarrays, self.server_model) return self.server_model + + +class FlServerWithInitializer(FlServer): + initialized = False # Add attribute + + def _get_initial_parameters(self, server_round: int, timeout: Optional[float]) -> Parameters: + """ + Get initial parameters from one of the available clients. Same as + parent function except we provide a config to the client when + requesting initial parameters + https://github.com/adap/flower/issues/3770 + + Note: + I have to use configure_fit to bypass mypy errors since + on_fit_config is not defined in the Strategy base class. The + downside is that configure fit will wait until enough clients for + training are present instead of just sampling one client. I + thought about defining a new init_config attribute but this + """ + # Server-side parameter initialization + parameters: Optional[Parameters] = self.strategy.initialize_parameters(client_manager=self._client_manager) + if parameters is not None: + log(INFO, "Using initial global parameters provided by strategy") + return parameters + + # Get initial parameters from one of the clients + log(INFO, "Requesting initial parameters from one random client") + random_client = self._client_manager.sample(1)[0] + dummy_params = Parameters([], "None") + config = self.strategy.configure_fit(server_round, dummy_params, self._client_manager)[0][1].config + ins = GetParametersIns(config=config) + get_parameters_res = random_client.get_parameters(ins=ins, timeout=timeout, group_id=server_round) + if get_parameters_res.status.code == Code.OK: + log(INFO, "Received initial parameters from one random client") + else: + log( + WARN, + "Failed to receive initial parameters from the client." " Empty initial parameters will be used.", + ) + return get_parameters_res.parameters + + def initialize(self, server_round: int, timeout: Optional[float] = None) -> None: + """ + Hook method to allow the server to do some additional initialization + prior to training. For example, NnUNetServer uses this method to ask a + client to initialize the global nnunet plans if one is not provided in + in the config + + Args: + server_round (int): The current server round. This hook method is + only called with a server_round=0 at the beginning of self.fit + timeout (Optional[float], optional): The server's timeout + parameter. 
+
+    def fit(self, num_rounds: int, timeout: Optional[float]) -> Tuple[History, float]:
+        """
+        Same as the parent method except the initialize hook method is called first
+        """
+        # Initialize the server
+        if not self.initialized:
+            self.initialize(server_round=0, timeout=timeout)
+
+        return super().fit(num_rounds, timeout)
diff --git a/fl4health/utils/load_data.py b/fl4health/utils/load_data.py
index 4bb54b67f..920240698 100644
--- a/fl4health/utils/load_data.py
+++ b/fl4health/utils/load_data.py
@@ -1,4 +1,5 @@
 import random
+import warnings
 from logging import INFO
 from pathlib import Path
 from typing import Callable, Dict, Optional, Tuple
@@ -12,8 +13,14 @@
 
 from fl4health.utils.dataset import TensorDataset
 from fl4health.utils.dataset_converter import DatasetConverter
+from fl4health.utils.msd_dataset_sources import get_msd_dataset_enum, msd_md5_hashes, msd_urls
 from fl4health.utils.sampler import LabelBasedSampler
 
+with warnings.catch_warnings():
+    # ignoring some annoying scipy deprecation warnings
+    warnings.simplefilter("ignore", category=DeprecationWarning)
+    from monai.apps.utils import download_and_extract
+
 
 class ToNumpy:
     def __call__(self, tensor: torch.Tensor) -> np.ndarray:
@@ -169,3 +176,22 @@ def load_cifar10_test_data(
     evaluation_loader = DataLoader(evaluation_set, batch_size=batch_size, shuffle=False)
     num_examples = {"eval_set": len(evaluation_set)}
     return evaluation_loader, num_examples
+
+
+def load_msd_dataset(data_path: str, msd_dataset_name: str) -> None:
+    """
+    Downloads and extracts one of the 10 Medical Segmentation Decathlon (MSD)
+    datasets.
+
+    Args:
+        data_path (str): Path to the folder in which to extract the
+            dataset. The data itself will be in a subfolder named after the
+            dataset, not in the data_path directory itself. The name of the
+            folder will be the name of the dataset as defined by the values of
+            the MsdDataset enum returned by get_msd_dataset_enum
+        msd_dataset_name (str): One of the 10 msd datasets
+    """
+    msd_enum = get_msd_dataset_enum(msd_dataset_name)
+    msd_hash = msd_md5_hashes[msd_enum]
+    url = msd_urls[msd_enum]
+    download_and_extract(url=url, output_dir=data_path, hash_val=msd_hash, hash_type="md5", progress=True)
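+
+
+# A minimal usage sketch (the path below is illustrative, not required):
+#
+#     load_msd_dataset("examples/datasets/nnunet/nnunet_raw", "Task04_Hippocampus")
+#
+# would leave the extracted data in examples/datasets/nnunet/nnunet_raw/Task04_Hippocampus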
diff --git a/fl4health/utils/msd_dataset_sources.py b/fl4health/utils/msd_dataset_sources.py
new file mode 100644
index 000000000..ba07cc808
--- /dev/null
+++ b/fl4health/utils/msd_dataset_sources.py
@@ -0,0 +1,63 @@
+from enum import Enum
+
+
+class MsdDataset(Enum):
+    TASK01_BRAINTUMOUR = "Task01_BrainTumour"
+    TASK02_HEART = "Task02_Heart"
+    TASK03_LIVER = "Task03_Liver"
+    TASK04_HIPPOCAMPUS = "Task04_Hippocampus"
+    TASK05_PROSTATE = "Task05_Prostate"
+    TASK06_LUNG = "Task06_Lung"
+    TASK07_PANCREAS = "Task07_Pancreas"
+    TASK08_HEPATICVESSEL = "Task08_HepaticVessel"
+    TASK09_SPLEEN = "Task09_Spleen"
+    TASK10_COLON = "Task10_Colon"
+
+
+def get_msd_dataset_enum(dataset_name: str) -> MsdDataset:
+    try:
+        return MsdDataset(dataset_name)
+    except ValueError:
+        raise ValueError(f"Invalid MSD dataset name: {dataset_name}")
+
+
+msd_urls = {
+    MsdDataset.TASK01_BRAINTUMOUR: "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task01_BrainTumour.tar",
+    MsdDataset.TASK02_HEART: "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task02_Heart.tar",
+    MsdDataset.TASK03_LIVER: "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task03_Liver.tar",
+    MsdDataset.TASK04_HIPPOCAMPUS: "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task04_Hippocampus.tar",
+    MsdDataset.TASK05_PROSTATE: "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task05_Prostate.tar",
+    MsdDataset.TASK06_LUNG: "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task06_Lung.tar",
+    MsdDataset.TASK07_PANCREAS: "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task07_Pancreas.tar",
+    MsdDataset.TASK08_HEPATICVESSEL: "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task08_HepaticVessel.tar",
+    MsdDataset.TASK09_SPLEEN: "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task09_Spleen.tar",
+    MsdDataset.TASK10_COLON: "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task10_Colon.tar",
+}
+
+msd_md5_hashes = {
+    MsdDataset.TASK01_BRAINTUMOUR: "240a19d752f0d9e9101544901065d872",
+    MsdDataset.TASK02_HEART: "06ee59366e1e5124267b774dbd654057",
+    MsdDataset.TASK03_LIVER: "a90ec6c4aa7f6a3d087205e23d4e6397",
+    MsdDataset.TASK04_HIPPOCAMPUS: "9d24dba78a72977dbd1d2e110310f31b",
+    MsdDataset.TASK05_PROSTATE: "35138f08b1efaef89d7424d2bcc928db",
+    MsdDataset.TASK06_LUNG: "8afd997733c7fc0432f71255ba4e52dc",
+    MsdDataset.TASK07_PANCREAS: "4f7080cfca169fa8066d17ce6eb061e4",
+    MsdDataset.TASK08_HEPATICVESSEL: "641d79e80ec66453921d997fbf12a29c",
+    MsdDataset.TASK09_SPLEEN: "410d4a301da4e5b2f6f86ec3ddba524e",
+    MsdDataset.TASK10_COLON: "bad7a188931dc2f6acf72b08eb6202d0",
+}
+
+# The number of classes for each MSD Dataset (including background)
+# I got these from the paper, didn't download all the datasets to double check
+msd_num_labels = {
+    MsdDataset.TASK01_BRAINTUMOUR: 4,
+    MsdDataset.TASK02_HEART: 2,
+    MsdDataset.TASK03_LIVER: 3,
+    MsdDataset.TASK04_HIPPOCAMPUS: 3,
+    MsdDataset.TASK05_PROSTATE: 3,
+    MsdDataset.TASK06_LUNG: 2,
+    MsdDataset.TASK07_PANCREAS: 3,
+    MsdDataset.TASK08_HEPATICVESSEL: 3,
+    MsdDataset.TASK09_SPLEEN: 2,
+    MsdDataset.TASK10_COLON: 2,
+}
diff --git a/research/picai/fl_nnunet/config.yaml b/research/picai/fl_nnunet/config.yaml
index 850008a28..c43bd2e42 100644
--- 
a/research/picai/fl_nnunet/config.yaml +++ b/research/picai/fl_nnunet/config.yaml @@ -1,9 +1,7 @@ # You should set these yourself n_clients: 1 nnunet_config: 2d -nnunet_plans: /home/shawn/Code/nnunet_storage/nnUNet_preprocessed/Dataset012_PICAI-debug/nnUNetPlans.json -fold: 0 +# nnunet_plans: /home/shawn/Code/nnunet_storage/nnUNet_preprocessed/Dataset012_PICAI-debug/nnUNetPlans.json n_server_rounds: 1 -local_epochs: 3 -server_address: '0.0.0.0:8080' -starting_checkpoint: /home/shawn/Code/nnunet_storage/nnUNet_results/Dataset012_PICAI-debug/nnUNetTrainer_1epoch__nnUNetPlans__2d/fold_0/checkpoint_best.pth +local_epochs: 2 +# starting_checkpoint: /home/shawn/Code/nnunet_storage/nnUNet_results/Dataset012_PICAI-debug/nnUNetTrainer_1epoch__nnUNetPlans__2d/fold_0/checkpoint_best.pth diff --git a/research/picai/fl_nnunet/nnunet_client.py b/research/picai/fl_nnunet/nnunet_client.py index 1c45ff239..3ee29b643 100644 --- a/research/picai/fl_nnunet/nnunet_client.py +++ b/research/picai/fl_nnunet/nnunet_client.py @@ -3,14 +3,14 @@ import signal import warnings from logging import INFO -from os import makedirs +from os import makedirs, remove from os.path import exists, join from pathlib import Path from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import torch from flwr.common.logger import log -from flwr.common.typing import Config +from flwr.common.typing import Config, Scalar from torch import nn from torch.nn.modules.loss import _Loss from torch.optim import Optimizer @@ -37,12 +37,14 @@ # Raised an issue with nnunet. https://github.com/MIC-DKFZ/nnUNet/issues/2370 warnings.filterwarnings("ignore", category=DeprecationWarning) from batchgenerators.utilities.file_and_folder_operations import load_json, save_json + from nnunetv2.experiment_planning.experiment_planners.default_experiment_planner import ExperimentPlanner from nnunetv2.experiment_planning.plan_and_preprocess_api import extract_fingerprints, preprocess_dataset from nnunetv2.paths import nnUNet_preprocessed, nnUNet_raw from nnunetv2.training.dataloading.utils import unpack_dataset from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer from nnunetv2.utilities.dataset_name_id_conversion import convert_id_to_dataset_name + # Get the default signal handlers used by python before flwr overrides them # We need these because the nnunet dataloaders spawn child processes # and flwr throws errors when those processes end. So we set the signal handlers @@ -136,6 +138,7 @@ def __init__( self.data_identifier = data_identifier self.always_preprocess: bool = always_preprocess self.plans_name = plans_identifier + self.fingerprint_extracted = False # nnunet specific attributes to be initialized in setup_client self.nnunet_trainer: nnUNetTrainer @@ -268,7 +271,22 @@ def maybe_preprocess(self, nnunet_config: NnUNetConfig) -> None: configurations=[nnunet_config.value], ) else: - log(INFO, "nnunet preprocessed data seems to already exist. Skipping preprocessing") + log(INFO, "\tnnunet preprocessed data seems to already exist. 
Skipping preprocessing") + + def maybe_extract_fingerprint(self) -> None: + """ + Checks if nnunet dataset fingerprint already exists and if not extracts one from the dataset + """ + fp_path = join(nnUNet_preprocessed, self.dataset_name, "dataset_fingerprint.json") + if self.always_preprocess or not exists(fp_path): + log(INFO, "\tExtracting nnunet dataset fingerprint") + with nostdout(): # prevent print statements from nnunet method + extract_fingerprints(dataset_ids=[self.dataset_id]) + else: + log(INFO, "\tnnunet dataset fingerprint already exists. Skipping fingerprint extraction") + + # Avoid extracting fingerprint multiple times when always_preprocess is true + self.fingerprint_extracted = True def setup_client(self, config: Config) -> None: """ @@ -290,15 +308,11 @@ def setup_client(self, config: Config) -> None: # Get nnunet config self.nnunet_config = get_valid_nnunet_config(self.narrow_config_type(config, "nnunet_config", str)) - # Check if dataset fingerprint has been extracted - if self.always_preprocess or not exists( - join(nnUNet_preprocessed, self.dataset_name, "dataset_fingerprint.json") - ): - log(INFO, "Extracting nnunet dataset fingerprint") - with nostdout(): # prevent print statements from nnunet method - extract_fingerprints(dataset_ids=[self.dataset_id]) + # Check if dataset fingerprint has already been extracted + if not self.fingerprint_extracted: + self.maybe_extract_fingerprint() else: - log(INFO, "nnunet dataset fingerprint already exists. Skipping fingerprint extraction") + log(INFO, "\tDataset fingerprint has already been extracted. Skipping.") # Create the nnunet plans for the local client self.plans = self.create_plans(config=config) @@ -521,3 +535,47 @@ def update_before_epoch(self, epoch: int) -> None: def get_client_specific_logs(self) -> Tuple[str, List[Tuple[LogLevel, str]]]: lr = self.optimizers["global"].param_groups[0]["lr"] return f" Current LR: {lr}", [] + + def get_properties(self, config: Config) -> Dict[str, Scalar]: + """ + Return properties (sample counts and nnunet plans) of client. + + If nnunet plans are not provided by the server, creates a new set of + nnunet plans from the local client dataset. These plans are intended + to be used for initializing global nnunet plans when they are not + provided. + + Args: + config (Config): The config from the server + + Returns: + Dict[str, Scalar]: A dictionary containing the train and + validation sample counts as well as the serialized nnunet plans + """ + # Check if nnunet plans have already been initialized + if "nnunet_plans" in config.keys(): + properties = super().get_properties(config) + properties["nnunet_plans"] = config["nnunet_plans"] + return properties + + # Check if local nnunet dataset fingerprint needs to be extracted + if not self.fingerprint_extracted: + self.maybe_extract_fingerprint() + + # Create experiment planner and plans + planner = ExperimentPlanner(dataset_name_or_id=self.dataset_id) + with nostdout(): # Prevent print statements from experiment planner + plans = planner.plan_experiment() + plans_bytes = pickle.dumps(plans) + + # Remove plans file that was created by planner + plans_path = join(nnUNet_preprocessed, self.dataset_name, planner.plans_identifier + ".json") + if exists(plans_path): + remove(plans_path) + + # return properties with initialized nnunet plans. 
The plans
+        # must be added to the config because super().get_properties runs
+        # client setup, which requires them
+        config["nnunet_plans"] = plans_bytes
+        properties = super().get_properties(config)
+        properties["nnunet_plans"] = plans_bytes
+        return properties
diff --git a/research/picai/fl_nnunet/nnunet_server.py b/research/picai/fl_nnunet/nnunet_server.py
new file mode 100644
index 000000000..9f4a9a0c0
--- /dev/null
+++ b/research/picai/fl_nnunet/nnunet_server.py
@@ -0,0 +1,86 @@
+from logging import INFO, WARN
+from typing import Any, Callable, List, Optional, Tuple, Union
+
+from flwr.common import Parameters
+from flwr.common.logger import log
+from flwr.common.typing import Code, Config, EvaluateIns, FitIns, GetPropertiesIns
+from flwr.server.client_manager import ClientManager
+from flwr.server.client_proxy import ClientProxy
+
+from fl4health.server.base_server import FlServerWithInitializer
+
+FIT_CFG_FN = Callable[[int, Parameters, ClientManager], List[Tuple[ClientProxy, FitIns]]]
+EVAL_CFG_FN = Callable[[int, Parameters, ClientManager], List[Tuple[ClientProxy, EvaluateIns]]]
+CFG_FN = Union[FIT_CFG_FN, EVAL_CFG_FN]
+
+
+def add_items_to_config_fn(fn: CFG_FN, items: Config) -> CFG_FN:
+    """
+    Accepts a flwr Strategy configure function (either configure_fit or
+    configure_evaluate) and returns a wrapped version of that function in
+    which the items argument has been added to the config returned by the
+    original function
+
+    Args:
+        fn (CFG_FN): The Strategy configure function to wrap
+        items (Config): A Config containing additional items to update the
+            original config with
+
+    Returns:
+        CFG_FN: The wrapped function. Argument and return types are the same
+    """
+
+    def new_fn(*args: Any, **kwargs: Any) -> Any:
+        cfg_ins = fn(*args, **kwargs)
+        for _, ins in cfg_ins:
+            ins.config.update(items)
+        return cfg_ins
+
+    return new_fn
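+
+
+# A minimal usage sketch of the wrapper above ("strategy" stands for any flwr
+# Strategy instance and the "foo" item is purely illustrative):
+#
+#     wrapped_fn = add_items_to_config_fn(strategy.configure_fit, {"foo": 123})
+#     setattr(strategy, "configure_fit", wrapped_fn)
+#     # every config returned by configure_fit now also contains "foo"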
Intended + for use with NnUNetClient + """ + + def initialize(self, server_round: int, timeout: Optional[float] = None) -> None: + # Get fit config + dummy_params = Parameters([], "None") + config = self.strategy.configure_fit(server_round, dummy_params, self._client_manager)[0][1].config + + # Check if plans need to be initialized + if config.get("nnunet_plans") is not None: + self.initialized = True + return + + # Sample properties from a random client to initialize plans + log(INFO, "") + log(INFO, "[PRE-INIT]") + log(INFO, "Requesting initialization of global nnunet plans from one random client via get_properties") + random_client = self._client_manager.sample(1)[0] + ins = GetPropertiesIns(config=config) + properties_res = random_client.get_properties(ins=ins, timeout=timeout, group_id=server_round) + + if properties_res.status.code == Code.OK: + log(INFO, "Recieved global nnunet plans from one random client") + else: + log(WARN, "Failed to receive properties from client to initialize nnnunet plans") + + properties = properties_res.properties + + # NnUNetClient has serialized nnunet_plans as a property + plans_bytes = properties["nnunet_plans"] + + # Wrap config functions so that nnunet_plans is included + log(INFO, "Wrapping strategy config functions to return nnunet_plans") + new_fit_cfg_fn = add_items_to_config_fn(self.strategy.configure_fit, {"nnunet_plans": plans_bytes}) + new_eval_cfg_fn = add_items_to_config_fn(self.strategy.configure_evaluate, {"nnunet_plans": plans_bytes}) + setattr(self.strategy, "configure_fit", new_fit_cfg_fn) + setattr(self.strategy, "configure_evaluate", new_eval_cfg_fn) + + # Finish + self.initialized = True + log(INFO, "") diff --git a/research/picai/fl_nnunet/start_client.py b/research/picai/fl_nnunet/start_client.py index 4b56885a6..f269fbe83 100644 --- a/research/picai/fl_nnunet/start_client.py +++ b/research/picai/fl_nnunet/start_client.py @@ -1,7 +1,7 @@ import argparse import warnings from logging import INFO -from os.path import join +from pathlib import Path from typing import Optional, Union with warnings.catch_warnings(): @@ -16,8 +16,6 @@ import flwr as fl import torch from flwr.common.logger import log -from nnunetv2.paths import nnUNet_preprocessed -from nnunetv2.utilities.dataset_name_id_conversion import convert_id_to_dataset_name from torchmetrics.classification import Dice from torchmetrics.segmentation import GeneralizedDiceScore @@ -61,7 +59,6 @@ def main( metrics = [dice1, dice2] # Oddly each of these dice metrics is drastically different. # Create and start client - dataset_name = convert_id_to_dataset_name(dataset_id) client = nnUNetClient( # Args specific to nnUNetClient dataset_id=dataset_id, @@ -72,9 +69,7 @@ def main( # BaseClient Args device=DEVICE, metrics=metrics, - data_path=join( - nnUNet_preprocessed, dataset_name - ), # data_path is not actually used but is required by BaseClient + data_path=Path("dummy/path"), # Argument not used by nnUNetClient ) fl.client.start_client(server_address=server_address, client=client.to_client()) @@ -135,16 +130,7 @@ def main( args = parser.parse_args() # Convert fold to an integer if it is not 'all' - if args.fold != "all": - try: - fold = int(args.fold) - except ValueError as e: - print( - f"Unable to convert given value for fold to int: {args.fold}. 
Fold must be either 'all' or an integer" - ) - raise e - else: - fold = args.fold + fold: Union[int, str] = "all" if args.fold == "all" else int(args.fold) main( dataset_id=args.dataset_id, diff --git a/research/picai/fl_nnunet/start_server.py b/research/picai/fl_nnunet/start_server.py index dbda59298..890366b53 100644 --- a/research/picai/fl_nnunet/start_server.py +++ b/research/picai/fl_nnunet/start_server.py @@ -23,41 +23,47 @@ from flwr.server.strategy import FedAvg from examples.utils.functions import make_dict_with_epochs_or_steps -from fl4health.server.base_server import FlServer # This is the lightning utils deprecation warning culprit from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn +from research.picai.fl_nnunet.nnunet_server import NnUNetServer def get_config( current_server_round: int, nnunet_config: str, - nnunet_plans: str, n_server_rounds: int, batch_size: int, n_clients: int, + nnunet_plans: Optional[str] = None, local_epochs: Optional[int] = None, local_steps: Optional[int] = None, ) -> Config: - nnunet_plans_dict = pickle.dumps(json.load(open(nnunet_plans, "r"))) - return { + # Create config + config: Config = { "n_clients": n_clients, "nnunet_config": nnunet_config, - "nnunet_plans": nnunet_plans_dict, "n_server_rounds": n_server_rounds, "batch_size": batch_size, **make_dict_with_epochs_or_steps(local_epochs, local_steps), "current_server_round": current_server_round, } + # Check if plans were provided + if nnunet_plans is not None: + plans_bytes = pickle.dumps(json.load(open(nnunet_plans, "r"))) + config["nnunet_plans"] = plans_bytes -def main(config: dict) -> None: + return config + + +def main(config: dict, server_address: str) -> None: # Partial function with everything set except current server round fit_config_fn = partial( get_config, n_clients=config["n_clients"], nnunet_config=config["nnunet_config"], - nnunet_plans=config["nnunet_plans"], n_server_rounds=config["n_server_rounds"], batch_size=0, # Set this to 0 because we're not using it + nnunet_plans=config.get("nnunet_plans"), local_epochs=config.get("local_epochs"), local_steps=config.get("local_steps"), ) @@ -67,12 +73,12 @@ def main(config: dict) -> None: # Of course nnunet stores their pytorch models differently. params = ndarrays_to_parameters([val.cpu().numpy() for _, val in model["network_weights"].items()]) else: - raise Exception( - "There is a bug right now where params can not be None. \ - Therefore a starting checkpoint must be provided because I don't \ - want to mess up my code. I hav raised an issue with flwr" - ) - # params = None + # raise Exception( + # "There is a bug right now where params can not be None. \ + # Therefore a starting checkpoint must be provided because I don't \ + # want to mess up my code. 
I have raised an issue with flwr"
+        # )
+        params = None
 
     strategy = FedAvg(
         min_fit_clients=config["n_clients"],
@@ -85,25 +91,35 @@
     )
 
-    server = FlServer(client_manager=SimpleClientManager(), strategy=strategy)
+    # server = FlServer(client_manager=SimpleClientManager(), strategy=strategy)
+    server = NnUNetServer(client_manager=SimpleClientManager(), strategy=strategy)
 
     fl.server.start_server(
         server=server,
-        server_address=config["server_address"],
+        server_address=server_address,
        config=fl.server.ServerConfig(num_rounds=config["n_server_rounds"]),
     )
 
     # Shutdown server
-    # server.shutdown()
+    server.shutdown()
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--config-path", action="store", type=str, help="Path to the configuration file")
+    parser.add_argument(
+        "--server-address",
+        type=str,
+        required=False,
+        default="0.0.0.0:8080",
+        help="""[OPTIONAL] The address to use for the server. Defaults to
+            0.0.0.0:8080""",
+    )
+
     args = parser.parse_args()
 
     with open(args.config_path, "r") as f:
         config = yaml.safe_load(f)
 
-    main(config)
+    main(config, args.server_address)
diff --git a/tests/smoke_tests/nnunet_config.yaml b/tests/smoke_tests/nnunet_config.yaml
new file mode 100644
index 000000000..1b76a61dc
--- /dev/null
+++ b/tests/smoke_tests/nnunet_config.yaml
@@ -0,0 +1,8 @@
+# Parameters that describe the server
+n_server_rounds: 1
+
+# Parameters that describe the clients
+n_clients: 1
+local_epochs: 2
+
+nnunet_config: 2d

From 9d4616b06aac1c305ba2c2c355a372ed3d8fce1d Mon Sep 17 00:00:00 2001
From: Shawn Carere
Date: Tue, 30 Jul 2024 12:49:55 -0400
Subject: [PATCH 02/11] Updated readme

---
 examples/nnunet_example/README.md | 37 +++++++++++++++++++++++++++++++
 1 file changed, 37 insertions(+)

diff --git a/examples/nnunet_example/README.md b/examples/nnunet_example/README.md
index 2c72decda..b19fd8192 100644
--- a/examples/nnunet_example/README.md
+++ b/examples/nnunet_example/README.md
@@ -1,3 +1,40 @@
 # NnUNetClient Example
 
 This example demonstrates how to use the NnUNetClient to train nnunet segmentation models in a federated setting.
+
+By default this example trains an nnunet model on the Task04_Hippocampus dataset from the Medical Segmentation Decathelon (MSD). However, any of the MSD datasets can be used by specifying them with the msd_dataset_name flag for the client. To run this example first create a config file for the server. An example config has been provided in this directory. The required keys for the config are:
+
+```yaml
+# Parameters that describe the server
+n_server_rounds: 1
+
+# Parameters that describe the clients
+n_clients: 1
+local_epochs: 1 # Or local_steps, one or the other must be chosen
+
+nnunet_config: 2d
+```
+
+The only additional parameter required by nnunet is nnunet_config which is one of the official nnunet configurations (2d, 3d_fullres, 3d_lowres, 3d_cascade_fullres)
+
+One may also add the following optional keys to the config yaml file
+
+```yaml
+# Optional config parameters
+nnunet_plans: /Path/to/nnunet/plans/json
+starting_checkpoint: /Path/to/starting/checkpoint.pth
+```
+
+To run a federated learning experiment with nnunet models, first ensure you are in the FL4Health directory and then start the nnunet server using the following command. To view a list of optional flags use the --help flag
+
+```bash
+python -m examples.nnunet_example.server --config_path examples/nnunet_example/config.yaml
+```
+
+Once the server has started, start the necessary number of clients specified by the n_clients key in the config file. Each client can be started by running the following command in a separate session. To view a list of optional flags use the --help flag.
+
+```bash
+python -m examples.nnunet_example.client --dataset_path examples/datasets/nnunet
+```
+
+The MSD dataset will be downloaded and prepared automatically by the nnunet client if it does not already exist. The dataset_path flag is used as more of a data working directory by the client. The client will create nnunet_raw, nnunet_preprocessed and nnunet_results sub directories. The dataset itself will be stored in a folder within nnunet_raw. Therefore when checking if the data already exists, the client will look for the following folder '{dataset_path}/nnunet_raw/{dataset_name}'
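+
+As a rough sketch (assuming the default Task04_Hippocampus dataset and the dataset_path used above), the working directory should end up looking something like this after a run:
+
+```
+examples/datasets/nnunet/
+├── nnunet_raw/
+│   ├── Task04_Hippocampus/      # downloaded MSD data
+│   └── Dataset004_Hippocampus/  # the converted nnunet dataset
+├── nnunet_preprocessed/
+└── nnunet_results/
+```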

From f2600b1b435188f420299acfc4ff20538f60e6fc Mon Sep 17 00:00:00 2001
From: Shawn Carere
Date: Tue, 30 Jul 2024 14:51:07 -0400
Subject: [PATCH 03/11] Added smoke tests.

---
 research/picai/fl_nnunet/nnunet_client.py     | 27 +++++++++++++++++++
 ...unet_config.yaml => nnunet_config_2d.yaml} |  2 +-
 tests/smoke_tests/nnunet_config_3d.yaml       |  8 ++++++
 tests/smoke_tests/run_smoke_test.py           | 21 +++++++++++++++
 4 files changed, 57 insertions(+), 1 deletion(-)
 rename tests/smoke_tests/{nnunet_config.yaml => nnunet_config_2d.yaml} (88%)
 create mode 100644 tests/smoke_tests/nnunet_config_3d.yaml

diff --git a/research/picai/fl_nnunet/nnunet_client.py b/research/picai/fl_nnunet/nnunet_client.py
index 3ee29b643..88e88d847 100644
--- a/research/picai/fl_nnunet/nnunet_client.py
+++ b/research/picai/fl_nnunet/nnunet_client.py
@@ -36,6 +36,8 @@
     # silences a bunch of deprecation warnings related to scipy.ndimage
     # Raised an issue with nnunet. https://github.com/MIC-DKFZ/nnUNet/issues/2370
     warnings.filterwarnings("ignore", category=DeprecationWarning)
+    from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter
+    from batchgenerators.dataloading.nondet_multi_threaded_augmenter import NonDetMultiThreadedAugmenter
     from batchgenerators.utilities.file_and_folder_operations import load_json, save_json
     from nnunetv2.experiment_planning.experiment_planners.default_experiment_planner import ExperimentPlanner
     from nnunetv2.experiment_planning.plan_and_preprocess_api import extract_fingerprints, preprocess_dataset
@@ -579,3 +581,28 @@ def get_properties(self, config: Config) -> Dict[str, Scalar]:
         properties = super().get_properties(config)
         properties["nnunet_plans"] = plans_bytes
         return properties
+
+    def shutdown_dataloader(self, dataloader: Optional[DataLoader], dl_name: Optional[str] = None) -> None:
+        """
+        Checks the dataloader's type; if it is a MultiThreadedAugmenter or
+        NonDetMultiThreadedAugmenter, calls its _finish method to ensure it
+        is properly shut down
+
+        Args:
+            dataloader (Optional[DataLoader]): The dataloader to shutdown
+            dl_name (Optional[str]): A string that identifies the dataloader
+                to shutdown. Used for logging purposes. 
Defaults to None + """ + if dataloader is not None and isinstance(dataloader, nnUNetDataLoaderWrapper): + if isinstance(dataloader.nnunet_dataloader, (NonDetMultiThreadedAugmenter, MultiThreadedAugmenter)): + if dl_name is not None: + log(INFO, f"\tShutting down nnunet dataloader: {dl_name}") + dataloader.nnunet_dataloader._finish() + + def shutdown(self) -> None: + # Not entirely sure if processes potentially opened by nnunet + # dataloaders were being ended so ensure that they are ended here + self.shutdown_dataloader(self.train_loader, "train_loader") + self.shutdown_dataloader(self.val_loader, "val_loader") + self.shutdown_dataloader(self.test_loader, "test_loader") + return super().shutdown() diff --git a/tests/smoke_tests/nnunet_config.yaml b/tests/smoke_tests/nnunet_config_2d.yaml similarity index 88% rename from tests/smoke_tests/nnunet_config.yaml rename to tests/smoke_tests/nnunet_config_2d.yaml index 1b76a61dc..d2371d0d6 100644 --- a/tests/smoke_tests/nnunet_config.yaml +++ b/tests/smoke_tests/nnunet_config_2d.yaml @@ -3,6 +3,6 @@ n_server_rounds: 1 # Parameters that describe the clients n_clients: 1 -local_epochs: 2 +local_epochs: 1 nnunet_config: 2d diff --git a/tests/smoke_tests/nnunet_config_3d.yaml b/tests/smoke_tests/nnunet_config_3d.yaml new file mode 100644 index 000000000..ff35a18fd --- /dev/null +++ b/tests/smoke_tests/nnunet_config_3d.yaml @@ -0,0 +1,8 @@ +# Parameters that describe the server +n_server_rounds: 1 + +# Parameters that describe the clients +n_clients: 1 +local_epochs: 1 + +nnunet_config: 3d_fullres diff --git a/tests/smoke_tests/run_smoke_test.py b/tests/smoke_tests/run_smoke_test.py index 841ba9da4..7900a2192 100644 --- a/tests/smoke_tests/run_smoke_test.py +++ b/tests/smoke_tests/run_smoke_test.py @@ -153,6 +153,7 @@ async def run_smoke_test( full_server_output = "" startup_messages = [ # printed by fedprox, apfl, basic_example, fedbn, fedper, fedrep, and ditto, FENDA, fl_plus_local_ft and moon + # Update, this is no longer in output, examples are actually being triggered by the [ROUND 1] startup message "FL starting", # printed by scaffold "Using Warm Start Strategy. 
Waiting for clients to be available for polling", @@ -161,6 +162,10 @@ async def run_smoke_test( # printed by federated_eval "Federated Evaluation Starting", "[ROUND 1]", + # As far as I can tell this is printed by most servers that inherit from FlServer + "Flower ECE: gRPC server running ", + "gRPC server running", + "server running", ] output_found = False @@ -454,6 +459,22 @@ def load_metrics_from_file(file_path: str) -> Dict[str, Any]: if __name__ == "__main__": loop = asyncio.get_event_loop() + loop.run_until_complete( + run_smoke_test( # By default will use Task04_Hippocampus Dataset + server_python_path="examples.nnunet_example.server", + client_python_path="examples.nnunet_example.client", + config_path="tests/smoke_tests/nnunet_config_2d.yaml", + dataset_path="examples/datasets/nnunet", + ) + ) + loop.run_until_complete( + run_smoke_test( # By default will use Task04_Hippocampus Dataset + server_python_path="examples.nnunet_example.server", + client_python_path="examples.nnunet_example.client", + config_path="tests/smoke_tests/nnunet_config_3d.yaml", + dataset_path="examples/datasets/nnunet", + ) + ) loop.run_until_complete( run_smoke_test( server_python_path="examples.fedprox_example.server", From 437863789707547839a5826c8063203c8c6f0213 Mon Sep 17 00:00:00 2001 From: Shawn Carere Date: Tue, 30 Jul 2024 15:24:19 -0400 Subject: [PATCH 04/11] ignored a pytorch vulnerability for torch versions below 2.2.0 --- .github/workflows/static_code_checks.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/static_code_checks.yaml b/.github/workflows/static_code_checks.yaml index 61c2e333b..48ebebf4b 100644 --- a/.github/workflows/static_code_checks.yaml +++ b/.github/workflows/static_code_checks.yaml @@ -48,6 +48,7 @@ jobs: # pip-audit to ignore these warnings # Vulnerabilities from GHSA-x38x-g6gr-jqff to GHSA-7p8j-qv6x-f4g4 # originate from mlflow + # Ignore pytorch vulnerability GHSA-pg7h-5qx3-wjr3 ignore-vulns: | GHSA-x38x-g6gr-jqff GHSA-j8mg-pqc5-x9gj @@ -58,3 +59,4 @@ jobs: GHSA-cwgg-w6mp-w9hg GHSA-43c4-9qgj-x742 GHSA-7p8j-qv6x-f4g4 + GHSA-pg7h-5qx3-wjr3 From 92a64b34bff6ae2a4465cf37405b84590394f8e0 Mon Sep 17 00:00:00 2001 From: Shawn Carere Date: Tue, 30 Jul 2024 15:38:12 -0400 Subject: [PATCH 05/11] added __init__.py --- examples/nnunet_example/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 examples/nnunet_example/__init__.py diff --git a/examples/nnunet_example/__init__.py b/examples/nnunet_example/__init__.py new file mode 100644 index 000000000..e69de29bb From 335f69a5ad1311c691490bd2736beb7c38980b8e Mon Sep 17 00:00:00 2001 From: Shawn Carere Date: Tue, 30 Jul 2024 16:33:01 -0400 Subject: [PATCH 06/11] made the nnunet smoke tests last to see if the others work --- tests/smoke_tests/run_smoke_test.py | 32 ++++++++++++++--------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/tests/smoke_tests/run_smoke_test.py b/tests/smoke_tests/run_smoke_test.py index 7900a2192..0b332f4ac 100644 --- a/tests/smoke_tests/run_smoke_test.py +++ b/tests/smoke_tests/run_smoke_test.py @@ -459,22 +459,6 @@ def load_metrics_from_file(file_path: str) -> Dict[str, Any]: if __name__ == "__main__": loop = asyncio.get_event_loop() - loop.run_until_complete( - run_smoke_test( # By default will use Task04_Hippocampus Dataset - server_python_path="examples.nnunet_example.server", - client_python_path="examples.nnunet_example.client", - config_path="tests/smoke_tests/nnunet_config_2d.yaml", - dataset_path="examples/datasets/nnunet", - 
) - ) - loop.run_until_complete( - run_smoke_test( # By default will use Task04_Hippocampus Dataset - server_python_path="examples.nnunet_example.server", - client_python_path="examples.nnunet_example.client", - config_path="tests/smoke_tests/nnunet_config_3d.yaml", - dataset_path="examples/datasets/nnunet", - ) - ) loop.run_until_complete( run_smoke_test( server_python_path="examples.fedprox_example.server", @@ -661,4 +645,20 @@ def load_metrics_from_file(file_path: str) -> Dict[str, Any]: dataset_path="examples/datasets/cifar_data/", ) ) + loop.run_until_complete( + run_smoke_test( # By default will use Task04_Hippocampus Dataset + server_python_path="examples.nnunet_example.server", + client_python_path="examples.nnunet_example.client", + config_path="tests/smoke_tests/nnunet_config_2d.yaml", + dataset_path="examples/datasets/nnunet", + ) + ) + loop.run_until_complete( + run_smoke_test( # By default will use Task04_Hippocampus Dataset + server_python_path="examples.nnunet_example.server", + client_python_path="examples.nnunet_example.client", + config_path="tests/smoke_tests/nnunet_config_3d.yaml", + dataset_path="examples/datasets/nnunet", + ) + ) loop.close() From 7df7f04452deb80faf2048e6d49e39d67d90e9c5 Mon Sep 17 00:00:00 2001 From: jewelltaylor Date: Wed, 31 Jul 2024 14:19:22 -0400 Subject: [PATCH 07/11] Pin setuptools, increase file descriptor limit in smoke test workflow and use local_steps=5 instead of local_epochs=1 for 3d nnunet smoke test --- .github/workflows/smoke_tests.yaml | 3 +++ poetry.lock | 16 ++++++++-------- pyproject.toml | 5 +++++ tests/smoke_tests/nnunet_config_3d.yaml | 2 +- 4 files changed, 17 insertions(+), 9 deletions(-) diff --git a/.github/workflows/smoke_tests.yaml b/.github/workflows/smoke_tests.yaml index 775cce716..e832efd02 100644 --- a/.github/workflows/smoke_tests.yaml +++ b/.github/workflows/smoke_tests.yaml @@ -22,6 +22,9 @@ jobs: # Display the Python version being used - name: Display Python version run: python -c "import sys; print(sys.version)" + - name: Set up file descriptor limit + run: | + ulimit -n 4096 - name: Install and configure Poetry uses: snok/install-poetry@v1 with: diff --git a/poetry.lock b/poetry.lock index 24ed51840..a3dcab397 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. 
[[package]] name = "absl-py" @@ -2611,7 +2611,6 @@ description = "Clang Python Bindings, mirrored from the official LLVM repo: http optional = false python-versions = "*" files = [ - {file = "libclang-18.1.1-1-py2.py3-none-macosx_11_0_arm64.whl", hash = "sha256:0b2e143f0fac830156feb56f9231ff8338c20aecfe72b4ffe96f19e5a1dbb69a"}, {file = "libclang-18.1.1-py2.py3-none-macosx_10_9_x86_64.whl", hash = "sha256:6f14c3f194704e5d09769108f03185fce7acaf1d1ae4bbb2f30a72c2400cb7c5"}, {file = "libclang-18.1.1-py2.py3-none-macosx_11_0_arm64.whl", hash = "sha256:83ce5045d101b669ac38e6da8e58765f12da2d3aafb3b9b98d88b286a60964d8"}, {file = "libclang-18.1.1-py2.py3-none-manylinux2010_x86_64.whl", hash = "sha256:c533091d8a3bbf7460a00cb6c1a71da93bffe148f172c7d03b1c31fbf8aa2a0b"}, @@ -5656,18 +5655,19 @@ test = ["pytest"] [[package]] name = "setuptools" -version = "70.0.0" +version = "69.5.1" description = "Easily download, build, install, upgrade, and uninstall Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "setuptools-70.0.0-py3-none-any.whl", hash = "sha256:54faa7f2e8d2d11bcd2c07bed282eef1046b5c080d1c32add737d7b5817b1ad4"}, - {file = "setuptools-70.0.0.tar.gz", hash = "sha256:f211a66637b8fa059bb28183da127d4e86396c991a942b028c6650d4319c3fd0"}, + {file = "setuptools-69.5.1-py3-none-any.whl", hash = "sha256:c636ac361bc47580504644275c9ad802c50415c7522212252c033bd15f301f32"}, + {file = "setuptools-69.5.1.tar.gz", hash = "sha256:6c1fccdac05a97e598fb0ae3bbed5904ccb317337a51139dcd51453611bbb987"}, ] [package.extras] -docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"] -testing = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "importlib-metadata", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "mypy (==1.9)", "packaging (>=23.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.1)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy", "pytest-perf", "pytest-ruff (>=0.2.1)", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] +docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"] +testing = ["build[virtualenv]", "filelock (>=3.4.0)", "importlib-metadata", "ini2toml[lite] (>=0.9)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "mypy (==1.9)", "packaging (>=23.2)", "pip (>=19.1)", "pytest (>=6,!=8.1.1)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy", "pytest-perf", "pytest-ruff (>=0.2.1)", "pytest-timeout", "pytest-xdist (>=3)", "tomli", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] +testing-integration = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "packaging (>=23.2)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"] [[package]] name = "shellingham" @@ -7312,4 +7312,4 @@ test = ["big-O", "importlib-resources", 
"jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = ">=3.9.0,<3.11" -content-hash = "3e99bd34f28aaf5adc97f53a9e991a3085749c6e203eefd546aaa0951a6ecbd3" +content-hash = "a98dd7469d0a2d873b1f4dc04ab31a89fba8642221d8d97c0fbfaf6ad2bfbf62" diff --git a/pyproject.toml b/pyproject.toml index d067fc4b7..4041a079b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,11 @@ dp-accounting = "^0.4.3" torchmetrics = "^1.3.0" aiohttp = "^3.9.3" urllib3 = "^2.2.2" +setuptools = "69.5.1" +# Documented issues with setuptools 70.0.0 +# https://stackoverflow.com/a/78606253/24046590 +# https://github.com/vllm-project/vllm/issues/4961 +# Temporary solution is to pin to 69.5.1 [tool.poetry.group.dev.dependencies] # locked the 2.13 version because of restrictions with tensorflow-io diff --git a/tests/smoke_tests/nnunet_config_3d.yaml b/tests/smoke_tests/nnunet_config_3d.yaml index ff35a18fd..5244fdacc 100644 --- a/tests/smoke_tests/nnunet_config_3d.yaml +++ b/tests/smoke_tests/nnunet_config_3d.yaml @@ -3,6 +3,6 @@ n_server_rounds: 1 # Parameters that describe the clients n_clients: 1 -local_epochs: 1 +local_steps: 5 nnunet_config: 3d_fullres From 5b21fb6a8e154707290805f89398a287bd32518e Mon Sep 17 00:00:00 2001 From: Shawn Carere Date: Tue, 6 Aug 2024 11:21:10 -0400 Subject: [PATCH 08/11] Addressed Johns comments --- examples/nnunet_example/README.md | 6 +++--- examples/nnunet_example/client.py | 7 +------ examples/nnunet_example/server.py | 3 ++- fl4health/server/base_server.py | 34 ++++++++++++++++++++++++++++++- 4 files changed, 39 insertions(+), 11 deletions(-) diff --git a/examples/nnunet_example/README.md b/examples/nnunet_example/README.md index b19fd8192..2955b5ad9 100644 --- a/examples/nnunet_example/README.md +++ b/examples/nnunet_example/README.md @@ -2,7 +2,7 @@ This example demonstrates how to use the NnUNetClient to train nnunet segmentation models in a federated setting. -By default this example trains an nnunet model on the Task04_Hippocampus dataset from the Medical Segmentation Decathelon (MSD). However, any of the MSD datasets can be used by specifying them with the msd_dataset_name flag for the client. To run this example first create a config file for the server. An example config has been provided in this directory. The required keys for the config are: +By default this example trains an nnunet model on the Task04_Hippocampus dataset from the Medical Segmentation Decathlon (MSD). However, any of the MSD datasets can be used by specifying them with the msd_dataset_name flag for the client. To run this example first create a config file for the server. An example config has been provided in this directory. The required keys for the config are: ```yaml # Parameters that describe the server @@ -21,7 +21,7 @@ One may also add the following optional keys to the config yaml file ```yaml # Optional config parameters -nnunet_plans: /Path/to/nnunet/plans/json +nnunet_plans: /Path/to/nnunet/plans.json starting_checkpoint: /Path/to/starting/checkpoint.pth ``` @@ -37,4 +37,4 @@ Once the server has started, start the necessary number of clients specified by python -m examples.nnunet_example.client --dataset_path examples/datasets/nnunet ``` -The MSD dataset will be downloaded and prepared automatically by the nnunet client if it does not already exist. The dataset_path flag is used as more of a data working directory by the client. The client will create nnunet_raw, nnunet_preprocessed and nnunet_results sub directories. 
+The MSD dataset will be downloaded and prepared automatically by the nnunet example script if it does not already exist. The dataset_path flag is used as more of a data working directory by the client. The client will create nnunet_raw, nnunet_preprocessed and nnunet_results subdirectories if they do not already exist in the dataset_path folder. The dataset itself will be stored in a folder within nnunet_raw. Therefore, when checking if the data already exists, the client will look for the following folder: '{dataset_path}/nnunet_raw/{dataset_name}'
diff --git a/examples/nnunet_example/client.py b/examples/nnunet_example/client.py
index 6e3fea145..4b66c8b1f 100644
--- a/examples/nnunet_example/client.py
+++ b/examples/nnunet_example/client.py
@@ -15,13 +15,8 @@
     warnings.filterwarnings("ignore", category=DeprecationWarning)
     import lightning_utilities  # noqa: F401
 
-    # Some finicky import stuff, if i don't silence deprecation warnings when
-    # importing flower then i get unsilenceable deprecation warning from a
-    # different api (batch generators)
-    # Issue: https://github.com/MIC-DKFZ/nnUNet/issues/2370
-    from flwr.client import start_client
-
 import torch
+from flwr.client import start_client
 from flwr.common.logger import log
 from torchmetrics.segmentation import GeneralizedDiceScore
diff --git a/examples/nnunet_example/server.py b/examples/nnunet_example/server.py
index b4b6ae26c..56a23a717 100644
--- a/examples/nnunet_example/server.py
+++ b/examples/nnunet_example/server.py
@@ -5,6 +5,8 @@
 from functools import partial
 from typing import Optional
 
+import yaml
+
 with warnings.catch_warnings():
     # Need to import lightning utilities now in order to avoid deprecation
     # warnings. Ignore flake8 warning saying that it is unused
@@ -16,7 +18,6 @@
 
 import flwr as fl
 import torch
-import yaml
 from flwr.common.parameter import ndarrays_to_parameters
 from flwr.common.typing import Config
 from flwr.server.client_manager import SimpleClientManager
diff --git a/fl4health/server/base_server.py b/fl4health/server/base_server.py
index 3af034b6e..80188860a 100644
--- a/fl4health/server/base_server.py
+++ b/fl4health/server/base_server.py
@@ -353,7 +353,39 @@ def _hydrate_model_for_checkpointing(self) -> nn.Module:
 
 
 class FlServerWithInitializer(FlServer):
-    initialized = False  # Add attribute
+    def __init__(
+        self,
+        client_manager: ClientManager,
+        strategy: Optional[Strategy] = None,
+        wandb_reporter: Optional[ServerWandBReporter] = None,
+        checkpointer: Optional[TorchCheckpointer] = None,
+        metrics_reporter: Optional[MetricsReporter] = None,
+    ) -> None:
+        """
+        Server with an initialize hook method that is called prior to fit.
+        Override the self.initialize method to do server initialization prior
+        to training but after the clients have been created. Can be useful if
+        the state of the server depends on the properties of the clients. E.g.
+        The nnunet server requests an nnunet plans dict to be generated bu a
+        client if one was not provided.
+
+        Args:
+            client_manager (ClientManager): Determines the mechanism by which clients are sampled by the server, if
+                they are to be sampled at all.
+            strategy (Optional[Strategy], optional): The aggregation strategy to be used by the server to handle
+                client updates and other information potentially sent by the participating clients. If None, the
+                strategy is FedAvg as set by the flwr Server.
+            wandb_reporter (Optional[ServerWandBReporter], optional): To be provided if the server is to log
+                information and results to a Weights and Biases account. If None is provided, no logging occurs.
+                Defaults to None.
+            checkpointer (Optional[TorchCheckpointer], optional): To be provided if the server should perform
+                server-side checkpointing based on some criteria. If None, then no server-side checkpointing is
+                performed. Defaults to None.
+            metrics_reporter (Optional[MetricsReporter], optional): A metrics reporter instance to record the metrics
+                during the execution. Defaults to an instance of MetricsReporter with default init parameters.
+        """
+        super().__init__(client_manager, strategy, wandb_reporter, checkpointer, metrics_reporter)
+        self.initialized = False
 
     def _get_initial_parameters(self, server_round: int, timeout: Optional[float]) -> Parameters:
         """

From 1088ff9d71b402d54c826bcd427f5eed7da77b0d Mon Sep 17 00:00:00 2001
From: Shawn Carere
Date: Tue, 6 Aug 2024 12:39:43 -0400
Subject: [PATCH 09/11] Prevent nnunet client from generating useless log files

---
 research/picai/fl_nnunet/nnunet_client.py | 17 ++++++++++++++---
 1 file changed, 14 insertions(+), 3 deletions(-)

diff --git a/research/picai/fl_nnunet/nnunet_client.py b/research/picai/fl_nnunet/nnunet_client.py
index 9e752a9d7..91da53535 100644
--- a/research/picai/fl_nnunet/nnunet_client.py
+++ b/research/picai/fl_nnunet/nnunet_client.py
@@ -1,9 +1,9 @@
 import logging
+import os
 import pickle
 import signal
 import warnings
 from logging import INFO
-from os import makedirs, remove
 from os.path import exists, join
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
@@ -244,7 +244,7 @@ def create_plans(self, config: Config) -> Dict[str, Any]:
 
         # Can't run nnunet preprocessing without saving plans file
         if not exists(join(nnUNet_preprocessed, self.dataset_name)):
-            makedirs(join(nnUNet_preprocessed, self.dataset_name))
+            os.makedirs(join(nnUNet_preprocessed, self.dataset_name))
         plans_save_path = join(nnUNet_preprocessed, self.dataset_name, self.plans_name + ".json")
         save_json(plans, plans_save_path, sort_keys=False)
         return plans
@@ -335,6 +335,17 @@ def setup_client(self, config: Config) -> None:
         # do it manually since nnunet_trainer not being used for training
         self.nnunet_trainer.set_deep_supervision_enabled(self.nnunet_trainer.enable_deep_supervision)
 
+        # Prevent nnunet from generating log files and delete empty output directories
+        os.remove(self.nnunet_trainer.log_file)
+        self.nnunet_trainer.log_file = os.devnull
+        output_folder = Path(self.nnunet_trainer.output_folder)
+        while True:
+            if len(os.listdir(output_folder)) == 0:
+                os.rmdir(output_folder)
+                output_folder = output_folder.parent
+            else:
+                break
+
         # Preprocess nnunet_raw data if needed
         self.maybe_preprocess(self.nnunet_config)
         unpack_dataset(  # Reduces load on CPU and RAM during training
@@ -574,7 +585,7 @@ def get_properties(self, config: Config) -> Dict[str, Scalar]:
         # Remove plans file that was created by planner
         plans_path = join(nnUNet_preprocessed, self.dataset_name, planner.plans_identifier + ".json")
         if exists(plans_path):
-            remove(plans_path)
+            os.remove(plans_path)
 
         # return properties with initialized nnunet plans. Need to provide
         # plans since client needs to be initialized to get properties

From b33b461ec75ad66604e78cb6bda720b260113202 Mon Sep 17 00:00:00 2001
From: Shawn Carere
Date: Tue, 6 Aug 2024 14:23:23 -0400
Subject: [PATCH 10/11] Fixed bug where client might accidentally delete or
 overwrite an existing plans file

---
 research/picai/fl_nnunet/nnunet_client.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/research/picai/fl_nnunet/nnunet_client.py b/research/picai/fl_nnunet/nnunet_client.py
index 91da53535..a5cc19696 100644
--- a/research/picai/fl_nnunet/nnunet_client.py
+++ b/research/picai/fl_nnunet/nnunet_client.py
@@ -577,7 +577,7 @@ def get_properties(self, config: Config) -> Dict[str, Scalar]:
         self.maybe_extract_fingerprint()
 
         # Create experiment planner and plans
-        planner = ExperimentPlanner(dataset_name_or_id=self.dataset_id)
+        planner = ExperimentPlanner(dataset_name_or_id=self.dataset_id, plans_name="temp_plans")
         with nostdout():  # Prevent print statements from experiment planner
             plans = planner.plan_experiment()
         plans_bytes = pickle.dumps(plans)

From 308e9a2fafd76b13f8f42d99f5824983cd280d12 Mon Sep 17 00:00:00 2001
From: Shawn Carere
Date: Thu, 8 Aug 2024 12:27:26 -0400
Subject: [PATCH 11/11] Typo fix

---
 fl4health/server/base_server.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/fl4health/server/base_server.py b/fl4health/server/base_server.py
index 80188860a..307d5e5b7 100644
--- a/fl4health/server/base_server.py
+++ b/fl4health/server/base_server.py
@@ -366,7 +366,7 @@ Server with an initialize hook method that is called prior to fit.
         Override the self.initialize method to do server initialization prior
         to training but after the clients have been created. Can be useful if
         the state of the server depends on the properties of the clients. E.g.
-        The nnunet server requests an nnunet plans dict to be generated bu a
+        The nnunet server requests an nnunet plans dict to be generated by a
         client if one was not provided.
 
         Args:
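
A note on the initialize hook introduced in PATCH 08/11: the hunks above show only the FlServerWithInitializer constructor and its initialized flag, not the hook's definition. The sketch below is therefore a minimal illustration rather than part of the patch series; the initialize signature is an assumption inferred from the docstring, and PlansAwareServer is a hypothetical subclass.

```python
# Minimal sketch, not part of the patches above. Assumes fl4health is
# installed; the initialize() signature is inferred from the docstring in
# PATCH 08/11 and may differ from the actual hook.
from typing import Optional

from fl4health.server.base_server import FlServerWithInitializer


class PlansAwareServer(FlServerWithInitializer):
    """Hypothetical subclass showing where per-run server state could be set."""

    def initialize(self, server_round: int, timeout: Optional[float] = None) -> None:
        # Runs once before fit, after clients can be sampled, so server state
        # may be derived from client properties (e.g. an nnunet plans dict
        # generated by a client when none was provided in the config).
        if getattr(self, "plans", None) is None:
            # Placeholder: a real server would sample a client through
            # self.client_manager() and request its properties here.
            self.plans = {}
        self.initialized = True
```

The point of the hook over plain __init__ work is ordering: it defers state setup until clients exist, which is why the nnunet server can ask a client to generate plans when none were supplied in the config.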