generated from VectorInstitute/aieng-template
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #199 from VectorInstitute/nnunet_test
Nnunet test
- Loading branch information
Showing
21 changed files
with
819 additions
and
61 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
# NnUNetClient Example | ||
|
||
This example demonstrates how to use the NnUNetClient to train nnunet segmentation models in a federated setting. | ||
|
||
By default this example trains an nnunet model on the Task04_Hippocampus dataset from the Medical Segmentation Decathlon (MSD). However, any of the MSD datasets can be used by specifying them with the msd_dataset_name flag for the client. To run this example first create a config file for the server. An example config has been provided in this directory. The required keys for the config are: | ||
|
||
```yaml | ||
# Parameters that describe the server | ||
n_server_rounds: 1 | ||
|
||
# Parameters that describe the clients | ||
n_clients: 1 | ||
local_epochs: 1 # Or local_steps, one or the other must be chosen | ||
|
||
nnunet_config: 2d | ||
``` | ||
The only additional parameter required by nnunet is nnunet_config which is one of the official nnunet configurations (2d, 3d_fullres, 3d_lowres, 3d_cascade_fullres) | ||
One may also add the following optional keys to the config yaml file | ||
```yaml | ||
# Optional config parameters | ||
nnunet_plans: /Path/to/nnunet/plans.json | ||
starting_checkpoint: /Path/to/starting/checkpoint.pth | ||
``` | ||
To run a federated learning experiment with nnunet models, first ensure you are in the FL4Health directory and then start the nnunet server using the following command. To view a list of optional flags use the --help flag | ||
```bash | ||
python -m examples.nnunet_example.server --config_path examples/nnunet_example/config.yaml | ||
``` | ||
|
||
Once the server has started, start the necessary number of clients specified by the n_clients key in the config file. Each client can be started by running the following command in a seperate session. To view a list of optional flags use the --help flag. | ||
|
||
```bash | ||
python -m examples.nnunet_example.client --dataset_path examples/datasets/nnunet | ||
``` | ||
|
||
The MSD dataset will be downloaded and prepared automatically by the nnunet example script if it does not already exist. The dataset_path flag is used as more of a data working directory by the client. The client will create nnunet_raw, nnunet_preprocessed and nnunet_results sub directories if they do not already exist in the dataset_path folder. The dataset itself will be stored in a folder within nnunet_raw. Therefore when checking if the data already exists, the client will look for the following folder '{dataset_path}/nnunet_raw/{dataset_name}' |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
import argparse | ||
import os | ||
import warnings | ||
from logging import INFO | ||
from os.path import exists, join | ||
from pathlib import Path | ||
from typing import Union | ||
|
||
with warnings.catch_warnings(): | ||
# Need to import lightning utilities now in order to avoid deprecation | ||
# warnings. Ignore flake8 warning saying that it is unused | ||
# lightning utilities is imported by some of the dependencies | ||
# so by importing it now and filtering the warnings | ||
# https://github.com/Lightning-AI/utilities/issues/119 | ||
warnings.filterwarnings("ignore", category=DeprecationWarning) | ||
import lightning_utilities # noqa: F401 | ||
|
||
import torch | ||
from flwr.client import start_client | ||
from flwr.common.logger import log | ||
from torchmetrics.segmentation import GeneralizedDiceScore | ||
|
||
from fl4health.utils.load_data import load_msd_dataset | ||
from fl4health.utils.metrics import TorchMetric, TransformsMetric | ||
from fl4health.utils.msd_dataset_sources import get_msd_dataset_enum, msd_num_labels | ||
from research.picai.fl_nnunet.transforms import get_annotations_from_probs, get_probabilities_from_logits | ||
|
||
|
||
def main( | ||
dataset_path: Path, | ||
msd_dataset_name: str, | ||
server_address: str, | ||
fold: Union[int, str], | ||
always_preprocess: bool = False, | ||
) -> None: | ||
|
||
# Log device and server address | ||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||
log(INFO, f"Using device: {DEVICE}") | ||
log(INFO, f"Using server address: {server_address}") | ||
|
||
# Load the dataset if necessary | ||
msd_dataset_enum = get_msd_dataset_enum(msd_dataset_name) | ||
nnUNet_raw = join(dataset_path, "nnunet_raw") | ||
if not exists(join(nnUNet_raw, msd_dataset_enum.value)): | ||
log(INFO, f"Downloading and extracting {msd_dataset_enum.value} dataset") | ||
load_msd_dataset(nnUNet_raw, msd_dataset_name) | ||
|
||
# The dataset ID will be the same as the MSD Task number | ||
dataset_id = int(msd_dataset_enum.value[4:6]) | ||
nnunet_dataset_name = f"Dataset{dataset_id:03d}_{msd_dataset_enum.value.split('_')[1]}" | ||
|
||
# Convert the msd dataset if necessary | ||
if not exists(join(nnUNet_raw, nnunet_dataset_name)): | ||
log(INFO, f"Converting {msd_dataset_enum.value} into nnunet dataset") | ||
convert_msd_dataset(source_folder=join(nnUNet_raw, msd_dataset_enum.value)) | ||
|
||
# Create a metric | ||
dice = TransformsMetric( | ||
metric=TorchMetric( | ||
name="Pseudo DICE", | ||
metric=GeneralizedDiceScore( | ||
num_classes=msd_num_labels[msd_dataset_enum], weight_type="square", include_background=False | ||
).to(DEVICE), | ||
), | ||
transforms=[get_probabilities_from_logits, get_annotations_from_probs], | ||
) | ||
|
||
# Create client | ||
client = nnUNetClient( | ||
# Args specific to nnUNetClient | ||
dataset_id=dataset_id, | ||
fold=fold, | ||
always_preprocess=always_preprocess, | ||
# BaseClient Args | ||
device=DEVICE, | ||
metrics=[dice], | ||
data_path=dataset_path, # Argument not actually used by nnUNetClient | ||
) | ||
|
||
start_client(server_address=server_address, client=client.to_client()) | ||
|
||
# Shutdown the client | ||
client.shutdown() | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser( | ||
prog="nnunet_example/client.py", | ||
description="""An exampled of nnUNetClient on any of the Medical | ||
Segmentation Decathelon (MSD) datasets. Automatically generates a | ||
nnunet segmentation model and trains it in a federated setting""", | ||
) | ||
|
||
# I have to use underscores instead of dashes because thats how they | ||
# defined it in run smoke tests | ||
parser.add_argument( | ||
"--dataset_path", | ||
type=str, | ||
required=True, | ||
help="""Path to the folder in which data should be stored. This script | ||
will automatically create nnunet_raw, and nnunet_preprocessed | ||
subfolders if they don't already exist. This script will also | ||
attempt to download and prepare the MSD Dataset into the | ||
nnunet_raw folder if it does not already exist.""", | ||
) | ||
parser.add_argument( | ||
"--fold", | ||
type=str, | ||
required=False, | ||
default="0", | ||
help="""[OPTIONAL] Which fold of the local client dataset to use for | ||
validation. nnunet defaults to 5 folds (0 to 4). Can also be set | ||
to 'all' to use all the data for both training and validation. | ||
Defaults to fold 0""", | ||
) | ||
parser.add_argument( | ||
"--msd_dataset_name", | ||
type=str, | ||
required=False, | ||
default="Task04_Hippocampus", # The smallest dataset | ||
help="""[OPTIONAL] Name of the MSD dataset to use. The options are | ||
defined by the values of the MsdDataset enum as returned by the | ||
get_msd_dataset_enum function""", | ||
) | ||
parser.add_argument( | ||
"--always-preprocess", | ||
action="store_true", | ||
required=False, | ||
help="""[OPTIONAL] Use this to force preprocessing the nnunet data | ||
even if the preprocessed data is found to already exist""", | ||
) | ||
parser.add_argument( | ||
"--server_address", | ||
type=str, | ||
required=False, | ||
default="0.0.0.0:8080", | ||
help="""[OPTIONAL] The server address for the clients to communicate | ||
to the server through. Defaults to 0.0.0.0:8080""", | ||
) | ||
args = parser.parse_args() | ||
|
||
# Create nnunet directory structure and set environment variables | ||
nnUNet_raw = join(args.dataset_path, "nnunet_raw") | ||
nnUNet_preprocessed = join(args.dataset_path, "nnunet_preprocessed") | ||
if not exists(nnUNet_raw): | ||
os.makedirs(nnUNet_raw) | ||
if not exists(nnUNet_preprocessed): | ||
os.makedirs(nnUNet_preprocessed) | ||
os.environ["nnUNet_raw"] = nnUNet_raw | ||
os.environ["nnUNet_preprocessed"] = nnUNet_preprocessed | ||
os.environ["nnUNet_results"] = join(args.dataset_path, "nnunet_results") | ||
log(INFO, "Setting nnunet environment variables") | ||
log(INFO, f"\tnnUNet_raw: {nnUNet_raw}") | ||
log(INFO, f"\tnnUNet_preprocessed: {nnUNet_preprocessed}") | ||
log(INFO, f"\tnnUNet_results: {join(args.dataset_path, 'nnunet_results')}") | ||
|
||
# Everything that uses nnunetv2 module can only be imported after | ||
# environment variables are changed | ||
from nnunetv2.dataset_conversion.convert_MSD_dataset import convert_msd_dataset | ||
|
||
from research.picai.fl_nnunet.nnunet_client import nnUNetClient | ||
|
||
# Check fold argument and start main method | ||
fold: Union[int, str] = "all" if args.fold == "all" else int(args.fold) | ||
main( | ||
dataset_path=Path(args.dataset_path), | ||
msd_dataset_name=args.msd_dataset_name, | ||
server_address=args.server_address, | ||
fold=fold, | ||
always_preprocess=args.always_preprocess, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# Parameters that describe the server | ||
n_server_rounds: 1 | ||
|
||
# Parameters that describe the clients | ||
n_clients: 1 | ||
local_epochs: 1 | ||
|
||
nnunet_config: 2d |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
import argparse | ||
import json | ||
import pickle | ||
import warnings | ||
from functools import partial | ||
from typing import Optional | ||
|
||
import yaml | ||
|
||
with warnings.catch_warnings(): | ||
# Need to import lightning utilities now in order to avoid deprecation | ||
# warnings. Ignore flake8 warning saying that it is unused | ||
# lightning utilities is imported by some of the dependencies | ||
# so by importing it now and filtering the warnings | ||
# https://github.com/Lightning-AI/utilities/issues/119 | ||
warnings.filterwarnings("ignore", category=DeprecationWarning) | ||
import lightning_utilities # noqa: F401 | ||
|
||
import flwr as fl | ||
import torch | ||
from flwr.common.parameter import ndarrays_to_parameters | ||
from flwr.common.typing import Config | ||
from flwr.server.client_manager import SimpleClientManager | ||
from flwr.server.strategy import FedAvg | ||
|
||
from examples.utils.functions import make_dict_with_epochs_or_steps | ||
from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn | ||
from research.picai.fl_nnunet.nnunet_server import NnUNetServer | ||
|
||
|
||
def get_config( | ||
current_server_round: int, | ||
nnunet_config: str, | ||
n_server_rounds: int, | ||
batch_size: int, | ||
n_clients: int, | ||
nnunet_plans: Optional[str] = None, | ||
local_epochs: Optional[int] = None, | ||
local_steps: Optional[int] = None, | ||
) -> Config: | ||
# Create config | ||
config: Config = { | ||
"n_clients": n_clients, | ||
"nnunet_config": nnunet_config, | ||
"n_server_rounds": n_server_rounds, | ||
"batch_size": batch_size, | ||
**make_dict_with_epochs_or_steps(local_epochs, local_steps), | ||
"current_server_round": current_server_round, | ||
} | ||
|
||
# Check if plans were provided | ||
if nnunet_plans is not None: | ||
plans_bytes = pickle.dumps(json.load(open(nnunet_plans, "r"))) | ||
config["nnunet_plans"] = plans_bytes | ||
|
||
return config | ||
|
||
|
||
def main(config: dict, server_address: str) -> None: | ||
# Partial function with everything set except current server round | ||
fit_config_fn = partial( | ||
get_config, | ||
n_clients=config["n_clients"], | ||
nnunet_config=config["nnunet_config"], | ||
n_server_rounds=config["n_server_rounds"], | ||
batch_size=0, # Set this to 0 because we're not using it | ||
nnunet_plans=config.get("nnunet_plans"), | ||
local_epochs=config.get("local_epochs"), | ||
local_steps=config.get("local_steps"), | ||
) | ||
|
||
if config.get("starting_checkpoint"): | ||
model = torch.load(config["starting_checkpoint"]) | ||
# Of course nnunet stores their pytorch models differently. | ||
params = ndarrays_to_parameters([val.cpu().numpy() for _, val in model["network_weights"].items()]) | ||
else: | ||
params = None | ||
|
||
strategy = FedAvg( | ||
min_fit_clients=config["n_clients"], | ||
min_evaluate_clients=config["n_clients"], | ||
min_available_clients=config["n_clients"], | ||
on_fit_config_fn=fit_config_fn, | ||
on_evaluate_config_fn=fit_config_fn, # Nothing changes for eval | ||
fit_metrics_aggregation_fn=fit_metrics_aggregation_fn, | ||
evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn, | ||
initial_parameters=params, | ||
) | ||
|
||
server = NnUNetServer( | ||
client_manager=SimpleClientManager(), | ||
strategy=strategy, | ||
) | ||
|
||
fl.server.start_server( | ||
server=server, | ||
server_address=server_address, | ||
config=fl.server.ServerConfig(num_rounds=config["n_server_rounds"]), | ||
) | ||
|
||
# Shutdown server | ||
server.shutdown() | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--config_path", action="store", type=str, help="Path to the configuration file") | ||
parser.add_argument( | ||
"--server-address", | ||
type=str, | ||
required=False, | ||
default="0.0.0.0:8080", | ||
help="""[OPTIONAL] The address to use for the server. Defaults to | ||
0.0.0.0:8080""", | ||
) | ||
|
||
args = parser.parse_args() | ||
|
||
with open(args.config_path, "r") as f: | ||
config = yaml.safe_load(f) | ||
|
||
main(config, server_address=args.server_address) |
Oops, something went wrong.