Commit

Merge pull request #199 from VectorInstitute/nnunet_test

Nnunet test

scarere authored Aug 8, 2024
2 parents d6159a9 + 308e9a2 commit a5c1b1b
Showing 21 changed files with 819 additions and 61 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/smoke_tests.yaml
@@ -22,6 +22,9 @@ jobs:
      # Display the Python version being used
      - name: Display Python version
        run: python -c "import sys; print(sys.version)"
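      # Raise the open file limit, likely because nnunet's multiprocess data
      # loading can open many files at once (an inference; the diff does not
      # state the reason)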
      - name: Set up file descriptor limit
        run: |
          ulimit -n 4096
      - name: Install and configure Poetry
        uses: snok/install-poetry@v1
        with:
1 change: 1 addition & 0 deletions .github/workflows/static_code_checks.yaml
@@ -48,6 +48,7 @@ jobs:
          # pip-audit to ignore these warnings
          # Vulnerabilities from GHSA-x38x-g6gr-jqff to GHSA-7p8j-qv6x-f4g4
          # originate from mlflow
          # Ignore pytorch vulnerability GHSA-pg7h-5qx3-wjr3
          ignore-vulns: |
            GHSA-x38x-g6gr-jqff
            GHSA-j8mg-pqc5-x9gj
1 change: 1 addition & 0 deletions .gitignore
@@ -150,6 +150,7 @@ settings.json
**/datasets/skin_cancer/PAD-UFES-20/**
**/datasets/skin_cancer/ISIC_2019/**
**/datasets/skin_cancer/Derm7pt/**
**/datasets/nnunet/**

# logs

40 changes: 40 additions & 0 deletions examples/nnunet_example/README.md
@@ -0,0 +1,40 @@
# NnUNetClient Example

This example demonstrates how to use the NnUNetClient to train nnunet segmentation models in a federated setting.

By default, this example trains an nnunet model on the Task04_Hippocampus dataset from the Medical Segmentation Decathlon (MSD). However, any of the MSD datasets can be used by specifying it with the client's msd_dataset_name flag. To run this example, first create a config file for the server. An example config is provided in this directory. The required keys for the config are:

```yaml
# Parameters that describe the server
n_server_rounds: 1

# Parameters that describe the clients
n_clients: 1
local_epochs: 1 # Or local_steps, one or the other must be chosen

nnunet_config: 2d
```
The only additional parameter required by nnunet is nnunet_config, which must be one of the official nnunet configurations (2d, 3d_fullres, 3d_lowres, 3d_cascade_fullres).

One may also add the following optional keys to the config yaml file:
```yaml
# Optional config parameters
nnunet_plans: /Path/to/nnunet/plans.json
starting_checkpoint: /Path/to/starting/checkpoint.pth
```
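In this example, nnunet_plans is a path to a plans JSON file that the server reads and forwards to the clients through the config, and starting_checkpoint is a path to an existing nnunet checkpoint whose weights are used to initialize the server's model parameters.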
To run a federated learning experiment with nnunet models, first ensure you are in the FL4Health directory and then start the nnunet server using the following command. To view a list of optional flags, use the --help flag:
```bash
python -m examples.nnunet_example.server --config_path examples/nnunet_example/config.yaml
```

Once the server has started, start the number of clients specified by the n_clients key in the config file. Each client can be started by running the following command in a separate session. To view a list of optional flags, use the --help flag:

```bash
python -m examples.nnunet_example.client --dataset_path examples/datasets/nnunet
```

The MSD dataset will be downloaded and prepared automatically by the nnunet example script if it does not already exist. The dataset_path flag is used more as a working data directory by the client: the client will create nnunet_raw, nnunet_preprocessed and nnunet_results subdirectories in the dataset_path folder if they do not already exist, and the dataset itself will be stored in a folder within nnunet_raw. Therefore, when checking whether the data already exists, the client looks for the folder '{dataset_path}/nnunet_raw/{dataset_name}'.
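For example, with the default flags shown above, the working directory would end up with a layout along the following lines (a sketch; the exact dataset folder names depend on the MSD task chosen):

```
examples/datasets/nnunet/
├── nnunet_raw/
│   ├── Task04_Hippocampus/        # raw MSD download
│   └── Dataset004_Hippocampus/    # after conversion to the nnunet format
├── nnunet_preprocessed/
└── nnunet_results/
```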
Empty file.
172 changes: 172 additions & 0 deletions examples/nnunet_example/client.py
@@ -0,0 +1,172 @@
import argparse
import os
import warnings
from logging import INFO
from os.path import exists, join
from pathlib import Path
from typing import Union

with warnings.catch_warnings():
    # Need to import lightning utilities now in order to avoid deprecation
    # warnings. lightning_utilities is imported by some of the dependencies,
    # so by importing it now and filtering the warnings here, the deprecation
    # warnings are suppressed. Ignore the flake8 warning that it is unused.
    # https://github.com/Lightning-AI/utilities/issues/119
    warnings.filterwarnings("ignore", category=DeprecationWarning)
    import lightning_utilities  # noqa: F401

import torch
from flwr.client import start_client
from flwr.common.logger import log
from torchmetrics.segmentation import GeneralizedDiceScore

from fl4health.utils.load_data import load_msd_dataset
from fl4health.utils.metrics import TorchMetric, TransformsMetric
from fl4health.utils.msd_dataset_sources import get_msd_dataset_enum, msd_num_labels
from research.picai.fl_nnunet.transforms import get_annotations_from_probs, get_probabilities_from_logits


def main(
    dataset_path: Path,
    msd_dataset_name: str,
    server_address: str,
    fold: Union[int, str],
    always_preprocess: bool = False,
) -> None:

    # Log device and server address
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    log(INFO, f"Using device: {DEVICE}")
    log(INFO, f"Using server address: {server_address}")

    # Load the dataset if necessary
    msd_dataset_enum = get_msd_dataset_enum(msd_dataset_name)
    nnUNet_raw = join(dataset_path, "nnunet_raw")
    if not exists(join(nnUNet_raw, msd_dataset_enum.value)):
        log(INFO, f"Downloading and extracting {msd_dataset_enum.value} dataset")
        load_msd_dataset(nnUNet_raw, msd_dataset_name)

    # The dataset ID will be the same as the MSD Task number
    dataset_id = int(msd_dataset_enum.value[4:6])
    nnunet_dataset_name = f"Dataset{dataset_id:03d}_{msd_dataset_enum.value.split('_')[1]}"
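    # e.g. for "Task04_Hippocampus", dataset_id is 4 and the converted
    # dataset folder will be named "Dataset004_Hippocampus"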

    # Convert the msd dataset if necessary
    if not exists(join(nnUNet_raw, nnunet_dataset_name)):
        log(INFO, f"Converting {msd_dataset_enum.value} into nnunet dataset")
        convert_msd_dataset(source_folder=join(nnUNet_raw, msd_dataset_enum.value))

    # Create a metric
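    # The transforms convert the model's raw logits first into probabilities
    # and then into discrete one-hot annotations before the Dice score is computed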
    dice = TransformsMetric(
        metric=TorchMetric(
            name="Pseudo DICE",
            metric=GeneralizedDiceScore(
                num_classes=msd_num_labels[msd_dataset_enum], weight_type="square", include_background=False
            ).to(DEVICE),
        ),
        transforms=[get_probabilities_from_logits, get_annotations_from_probs],
    )

    # Create client
    client = nnUNetClient(
        # Args specific to nnUNetClient
        dataset_id=dataset_id,
        fold=fold,
        always_preprocess=always_preprocess,
        # BaseClient Args
        device=DEVICE,
        metrics=[dice],
        data_path=dataset_path,  # Argument not actually used by nnUNetClient
    )

    start_client(server_address=server_address, client=client.to_client())

    # Shutdown the client
    client.shutdown()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog="nnunet_example/client.py",
        description="""An example of using nnUNetClient on any of the Medical
            Segmentation Decathlon (MSD) datasets. Automatically generates an
            nnunet segmentation model and trains it in a federated setting""",
    )

    # Underscores are used instead of dashes because that's how the flags
    # are defined in the smoke tests
    parser.add_argument(
        "--dataset_path",
        type=str,
        required=True,
        help="""Path to the folder in which data should be stored. This script
            will automatically create nnunet_raw, and nnunet_preprocessed
            subfolders if they don't already exist. This script will also
            attempt to download and prepare the MSD Dataset into the
            nnunet_raw folder if it does not already exist.""",
    )
    parser.add_argument(
        "--fold",
        type=str,
        required=False,
        default="0",
        help="""[OPTIONAL] Which fold of the local client dataset to use for
            validation. nnunet defaults to 5 folds (0 to 4). Can also be set
            to 'all' to use all the data for both training and validation.
            Defaults to fold 0""",
    )
    parser.add_argument(
        "--msd_dataset_name",
        type=str,
        required=False,
        default="Task04_Hippocampus",  # The smallest dataset
        help="""[OPTIONAL] Name of the MSD dataset to use. The options are
            defined by the values of the MsdDataset enum as returned by the
            get_msd_dataset_enum function""",
    )
    parser.add_argument(
        "--always-preprocess",
        action="store_true",
        required=False,
        help="""[OPTIONAL] Use this to force preprocessing the nnunet data
            even if the preprocessed data is found to already exist""",
    )
    parser.add_argument(
        "--server_address",
        type=str,
        required=False,
        default="0.0.0.0:8080",
        help="""[OPTIONAL] The server address for the clients to communicate
            to the server through. Defaults to 0.0.0.0:8080""",
    )
    args = parser.parse_args()

    # Create nnunet directory structure and set environment variables
    nnUNet_raw = join(args.dataset_path, "nnunet_raw")
    nnUNet_preprocessed = join(args.dataset_path, "nnunet_preprocessed")
    if not exists(nnUNet_raw):
        os.makedirs(nnUNet_raw)
    if not exists(nnUNet_preprocessed):
        os.makedirs(nnUNet_preprocessed)
    os.environ["nnUNet_raw"] = nnUNet_raw
    os.environ["nnUNet_preprocessed"] = nnUNet_preprocessed
    os.environ["nnUNet_results"] = join(args.dataset_path, "nnunet_results")
    log(INFO, "Setting nnunet environment variables")
    log(INFO, f"\tnnUNet_raw: {nnUNet_raw}")
    log(INFO, f"\tnnUNet_preprocessed: {nnUNet_preprocessed}")
    log(INFO, f"\tnnUNet_results: {join(args.dataset_path, 'nnunet_results')}")

    # Everything that uses the nnunetv2 module can only be imported after
    # the environment variables have been set
    from nnunetv2.dataset_conversion.convert_MSD_dataset import convert_msd_dataset

    from research.picai.fl_nnunet.nnunet_client import nnUNetClient

    # Check fold argument and start main method
    fold: Union[int, str] = "all" if args.fold == "all" else int(args.fold)
    main(
        dataset_path=Path(args.dataset_path),
        msd_dataset_name=args.msd_dataset_name,
        server_address=args.server_address,
        fold=fold,
        always_preprocess=args.always_preprocess,
    )
8 changes: 8 additions & 0 deletions examples/nnunet_example/config.yaml
@@ -0,0 +1,8 @@
# Parameters that describe the server
n_server_rounds: 1

# Parameters that describe the clients
n_clients: 1
local_epochs: 1

nnunet_config: 2d
122 changes: 122 additions & 0 deletions examples/nnunet_example/server.py
@@ -0,0 +1,122 @@
import argparse
import json
import pickle
import warnings
from functools import partial
from typing import Optional

import yaml

with warnings.catch_warnings():
    # Need to import lightning utilities now in order to avoid deprecation
    # warnings. lightning_utilities is imported by some of the dependencies,
    # so by importing it now and filtering the warnings here, the deprecation
    # warnings are suppressed. Ignore the flake8 warning that it is unused.
    # https://github.com/Lightning-AI/utilities/issues/119
    warnings.filterwarnings("ignore", category=DeprecationWarning)
    import lightning_utilities  # noqa: F401

import flwr as fl
import torch
from flwr.common.parameter import ndarrays_to_parameters
from flwr.common.typing import Config
from flwr.server.client_manager import SimpleClientManager
from flwr.server.strategy import FedAvg

from examples.utils.functions import make_dict_with_epochs_or_steps
from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn
from research.picai.fl_nnunet.nnunet_server import NnUNetServer


def get_config(
    current_server_round: int,
    nnunet_config: str,
    n_server_rounds: int,
    batch_size: int,
    n_clients: int,
    nnunet_plans: Optional[str] = None,
    local_epochs: Optional[int] = None,
    local_steps: Optional[int] = None,
) -> Config:
    # Create config
    config: Config = {
        "n_clients": n_clients,
        "nnunet_config": nnunet_config,
        "n_server_rounds": n_server_rounds,
        "batch_size": batch_size,
        **make_dict_with_epochs_or_steps(local_epochs, local_steps),
        "current_server_round": current_server_round,
    }

    # Check if plans were provided
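    # The plans file is a JSON dict; it is pickled to bytes because flwr
    # Config values may be bytes but not nested dictionaries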
    if nnunet_plans is not None:
        plans_bytes = pickle.dumps(json.load(open(nnunet_plans, "r")))
        config["nnunet_plans"] = plans_bytes

    return config


def main(config: dict, server_address: str) -> None:
    # Partial function with everything set except current server round
    fit_config_fn = partial(
        get_config,
        n_clients=config["n_clients"],
        nnunet_config=config["nnunet_config"],
        n_server_rounds=config["n_server_rounds"],
        batch_size=0,  # Set this to 0 because we're not using it
        nnunet_plans=config.get("nnunet_plans"),
        local_epochs=config.get("local_epochs"),
        local_steps=config.get("local_steps"),
    )

    if config.get("starting_checkpoint"):
        model = torch.load(config["starting_checkpoint"])
        # nnunet checkpoints nest the model weights under the "network_weights" key
        params = ndarrays_to_parameters([val.cpu().numpy() for _, val in model["network_weights"].items()])
    else:
        params = None

    strategy = FedAvg(
        min_fit_clients=config["n_clients"],
        min_evaluate_clients=config["n_clients"],
        min_available_clients=config["n_clients"],
        on_fit_config_fn=fit_config_fn,
        on_evaluate_config_fn=fit_config_fn,  # Nothing changes for eval
        fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
        evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
        initial_parameters=params,
    )

    server = NnUNetServer(
        client_manager=SimpleClientManager(),
        strategy=strategy,
    )

    fl.server.start_server(
        server=server,
        server_address=server_address,
        config=fl.server.ServerConfig(num_rounds=config["n_server_rounds"]),
    )

    # Shutdown server
    server.shutdown()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config_path", action="store", type=str, help="Path to the configuration file")
    parser.add_argument(
        "--server-address",
        type=str,
        required=False,
        default="0.0.0.0:8080",
        help="""[OPTIONAL] The address to use for the server. Defaults to
            0.0.0.0:8080""",
    )

    args = parser.parse_args()

    with open(args.config_path, "r") as f:
        config = yaml.safe_load(f)

    main(config, server_address=args.server_address)