Nnunet test #199

Merged 13 commits on Aug 8, 2024
3 changes: 3 additions & 0 deletions .github/workflows/smoke_tests.yaml
@@ -22,6 +22,9 @@ jobs:
      # Display the Python version being used
      - name: Display Python version
        run: python -c "import sys; print(sys.version)"
      - name: Set up file descriptor limit
        run: |
          ulimit -n 4096
      - name: Install and configure Poetry
        uses: snok/install-poetry@v1
        with:
1 change: 1 addition & 0 deletions .github/workflows/static_code_checks.yaml
@@ -48,6 +48,7 @@ jobs:
          # pip-audit to ignore these warnings
          # Vulnerabilities from GHSA-x38x-g6gr-jqff to GHSA-7p8j-qv6x-f4g4
          # originate from mlflow
          # Ignore pytorch vulnerability GHSA-pg7h-5qx3-wjr3
          ignore-vulns: |
            GHSA-x38x-g6gr-jqff
            GHSA-j8mg-pqc5-x9gj
1 change: 1 addition & 0 deletions .gitignore
@@ -150,6 +150,7 @@ settings.json
**/datasets/skin_cancer/PAD-UFES-20/**
**/datasets/skin_cancer/ISIC_2019/**
**/datasets/skin_cancer/Derm7pt/**
**/datasets/nnunet/**

# logs

40 changes: 40 additions & 0 deletions examples/nnunet_example/README.md
@@ -0,0 +1,40 @@
# NnUNetClient Example

This example demonstrates how to use the NnUNetClient to train nnunet segmentation models in a federated setting.

By default, this example trains an nnunet model on the Task04_Hippocampus dataset from the Medical Segmentation Decathlon (MSD). However, any of the MSD datasets can be used by specifying them with the client's msd_dataset_name flag. To run this example, first create a config file for the server; an example config has been provided in this directory. The required keys for the config are:

```yaml
# Parameters that describe the server
n_server_rounds: 1

# Parameters that describe the clients
n_clients: 1
local_epochs: 1 # Or local_steps, one or the other must be chosen

nnunet_config: 2d
```
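
If step-based local training is preferred, local_epochs can be replaced by local_steps (only one of the two may be set). A minimal sketch of the alternative, where the step count of 100 is a hypothetical value rather than one taken from this example:

```yaml
# Parameters that describe the server
n_server_rounds: 1

# Parameters that describe the clients
n_clients: 1
local_steps: 100 # Hypothetical value; replaces local_epochs

nnunet_config: 2d
```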

The only additional parameter required by nnunet is nnunet_config, which must be one of the official nnunet configurations (2d, 3d_fullres, 3d_lowres or 3d_cascade_fullres).

One may also add the following optional keys to the config yaml file:

```yaml
# Optional config parameters
nnunet_plans: /Path/to/nnunet/plans.json
starting_checkpoint: /Path/to/starting/checkpoint.pth
```

To run a federated learning experiment with nnunet models, first ensure you are in the FL4Health directory, then start the nnunet server using the following command. To view a list of optional flags, use the --help flag.

```bash
python -m examples.nnunet_example.server --config_path examples/nnunet_example/config.yaml
```

Once the server has started, start the number of clients specified by the n_clients key in the config file. Each client can be started by running the following command in a separate session. To view a list of optional flags, use the --help flag.

```bash
python -m examples.nnunet_example.client --dataset_path examples/datasets/nnunet
```
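
To train on one of the other MSD datasets, pass the msd_dataset_name flag. For example, to use the MSD prostate task instead of the default hippocampus task (assuming the corresponding MSD download is available):

```bash
python -m examples.nnunet_example.client --dataset_path examples/datasets/nnunet --msd_dataset_name Task05_Prostate
```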

The MSD dataset will be downloaded and prepared automatically by the nnunet example script if it does not already exist. The dataset_path flag is used more as a data working directory by the client. The client will create nnunet_raw, nnunet_preprocessed and nnunet_results subdirectories if they do not already exist in the dataset_path folder. The dataset itself will be stored in a folder within nnunet_raw. Therefore, when checking whether the data already exists, the client will look for the folder '{dataset_path}/nnunet_raw/{dataset_name}'.
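
For illustration, with the default Task04_Hippocampus dataset and --dataset_path examples/datasets/nnunet, the resulting working directory would look roughly like the sketch below (derived from the client script in this example, not an exhaustive listing):

```
examples/datasets/nnunet/
├── nnunet_raw/
│   ├── Task04_Hippocampus/       # Extracted MSD download
│   └── Dataset004_Hippocampus/   # Converted nnunet-format dataset
├── nnunet_preprocessed/
└── nnunet_results/
```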
Empty file.
172 changes: 172 additions & 0 deletions examples/nnunet_example/client.py
@@ -0,0 +1,172 @@
import argparse
import os
import warnings
from logging import INFO
from os.path import exists, join
from pathlib import Path
from typing import Union

with warnings.catch_warnings():
    # Import lightning_utilities now, with a filter active, to suppress the
    # deprecation warnings it would otherwise emit when imported later by
    # some of the dependencies. The flake8 unused-import warning on the
    # import is silenced with noqa.
    # https://github.com/Lightning-AI/utilities/issues/119
    warnings.filterwarnings("ignore", category=DeprecationWarning)
    import lightning_utilities  # noqa: F401

import torch
from flwr.client import start_client
from flwr.common.logger import log
from torchmetrics.segmentation import GeneralizedDiceScore

from fl4health.utils.load_data import load_msd_dataset
from fl4health.utils.metrics import TorchMetric, TransformsMetric
from fl4health.utils.msd_dataset_sources import get_msd_dataset_enum, msd_num_labels
from research.picai.fl_nnunet.transforms import get_annotations_from_probs, get_probabilities_from_logits


def main(
    dataset_path: Path,
    msd_dataset_name: str,
    server_address: str,
    fold: Union[int, str],
    always_preprocess: bool = False,
) -> None:
    # Log device and server address
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    log(INFO, f"Using device: {DEVICE}")
    log(INFO, f"Using server address: {server_address}")

    # Load the dataset if necessary
    msd_dataset_enum = get_msd_dataset_enum(msd_dataset_name)
    nnUNet_raw = join(dataset_path, "nnunet_raw")
    if not exists(join(nnUNet_raw, msd_dataset_enum.value)):
        log(INFO, f"Downloading and extracting {msd_dataset_enum.value} dataset")
        load_msd_dataset(nnUNet_raw, msd_dataset_name)

    # The dataset ID will be the same as the MSD Task number
    dataset_id = int(msd_dataset_enum.value[4:6])
    nnunet_dataset_name = f"Dataset{dataset_id:03d}_{msd_dataset_enum.value.split('_')[1]}"

    # Convert the msd dataset if necessary
    if not exists(join(nnUNet_raw, nnunet_dataset_name)):
        log(INFO, f"Converting {msd_dataset_enum.value} into nnunet dataset")
        convert_msd_dataset(source_folder=join(nnUNet_raw, msd_dataset_enum.value))

    # Create a metric
    dice = TransformsMetric(
        metric=TorchMetric(
            name="Pseudo DICE",
            metric=GeneralizedDiceScore(
                num_classes=msd_num_labels[msd_dataset_enum], weight_type="square", include_background=False
            ).to(DEVICE),
        ),
        transforms=[get_probabilities_from_logits, get_annotations_from_probs],
    )

    # Create client
    client = nnUNetClient(
        # Args specific to nnUNetClient
        dataset_id=dataset_id,
        fold=fold,
        always_preprocess=always_preprocess,
        # BaseClient Args
        device=DEVICE,
        metrics=[dice],
        data_path=dataset_path,  # Argument not actually used by nnUNetClient
    )

    start_client(server_address=server_address, client=client.to_client())

    # Shutdown the client
    client.shutdown()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog="nnunet_example/client.py",
        description="""An example of nnUNetClient on any of the Medical
            Segmentation Decathlon (MSD) datasets. Automatically generates an
            nnunet segmentation model and trains it in a federated setting""",
    )

    # Underscores are used instead of dashes in the flag names because that
    # is how they are defined in the smoke tests
    parser.add_argument(
        "--dataset_path",
        type=str,
        required=True,
        help="""Path to the folder in which data should be stored. This script
            will automatically create nnunet_raw, and nnunet_preprocessed
            subfolders if they don't already exist. This script will also
            attempt to download and prepare the MSD Dataset into the
            nnunet_raw folder if it does not already exist.""",
    )
    parser.add_argument(
        "--fold",
        type=str,
        required=False,
        default="0",
        help="""[OPTIONAL] Which fold of the local client dataset to use for
            validation. nnunet defaults to 5 folds (0 to 4). Can also be set
            to 'all' to use all the data for both training and validation.
            Defaults to fold 0""",
    )
    parser.add_argument(
        "--msd_dataset_name",
        type=str,
        required=False,
        default="Task04_Hippocampus",  # The smallest dataset
        help="""[OPTIONAL] Name of the MSD dataset to use. The options are
            defined by the values of the MsdDataset enum as returned by the
            get_msd_dataset_enum function""",
    )
    parser.add_argument(
        "--always-preprocess",
        action="store_true",
        required=False,
        help="""[OPTIONAL] Use this to force preprocessing of the nnunet data
            even if the preprocessed data is found to already exist""",
    )
    parser.add_argument(
        "--server_address",
        type=str,
        required=False,
        default="0.0.0.0:8080",
        help="""[OPTIONAL] The server address through which the clients
            communicate with the server. Defaults to 0.0.0.0:8080""",
    )
    args = parser.parse_args()

    # Create nnunet directory structure and set environment variables
    nnUNet_raw = join(args.dataset_path, "nnunet_raw")
    nnUNet_preprocessed = join(args.dataset_path, "nnunet_preprocessed")
    if not exists(nnUNet_raw):
        os.makedirs(nnUNet_raw)
    if not exists(nnUNet_preprocessed):
        os.makedirs(nnUNet_preprocessed)
    os.environ["nnUNet_raw"] = nnUNet_raw
    os.environ["nnUNet_preprocessed"] = nnUNet_preprocessed
    os.environ["nnUNet_results"] = join(args.dataset_path, "nnunet_results")
    log(INFO, "Setting nnunet environment variables")
    log(INFO, f"\tnnUNet_raw: {nnUNet_raw}")
    log(INFO, f"\tnnUNet_preprocessed: {nnUNet_preprocessed}")
    log(INFO, f"\tnnUNet_results: {join(args.dataset_path, 'nnunet_results')}")

    # Anything that uses the nnunetv2 module can only be imported after the
    # environment variables above have been set
    from nnunetv2.dataset_conversion.convert_MSD_dataset import convert_msd_dataset

    from research.picai.fl_nnunet.nnunet_client import nnUNetClient

    # Check fold argument and start main method
    fold: Union[int, str] = "all" if args.fold == "all" else int(args.fold)
    main(
        dataset_path=Path(args.dataset_path),
        msd_dataset_name=args.msd_dataset_name,
        server_address=args.server_address,
        fold=fold,
        always_preprocess=args.always_preprocess,
    )
8 changes: 8 additions & 0 deletions examples/nnunet_example/config.yaml
@@ -0,0 +1,8 @@
# Parameters that describe the server
n_server_rounds: 1

# Parameters that describe the clients
n_clients: 1
local_epochs: 1

nnunet_config: 2d
122 changes: 122 additions & 0 deletions examples/nnunet_example/server.py
@@ -0,0 +1,122 @@
import argparse
import json
import pickle
import warnings
from functools import partial
from typing import Optional

import yaml

with warnings.catch_warnings():
    # Import lightning_utilities now, with a filter active, to suppress the
    # deprecation warnings it would otherwise emit when imported later by
    # some of the dependencies. The flake8 unused-import warning on the
    # import is silenced with noqa.
    # https://github.com/Lightning-AI/utilities/issues/119
    warnings.filterwarnings("ignore", category=DeprecationWarning)
    import lightning_utilities  # noqa: F401

import flwr as fl
import torch
from flwr.common.parameter import ndarrays_to_parameters
from flwr.common.typing import Config
from flwr.server.client_manager import SimpleClientManager
from flwr.server.strategy import FedAvg

from examples.utils.functions import make_dict_with_epochs_or_steps
from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn
from research.picai.fl_nnunet.nnunet_server import NnUNetServer


def get_config(
    current_server_round: int,
    nnunet_config: str,
    n_server_rounds: int,
    batch_size: int,
    n_clients: int,
    nnunet_plans: Optional[str] = None,
    local_epochs: Optional[int] = None,
    local_steps: Optional[int] = None,
) -> Config:
    # Create config
    config: Config = {
        "n_clients": n_clients,
        "nnunet_config": nnunet_config,
        "n_server_rounds": n_server_rounds,
        "batch_size": batch_size,
        **make_dict_with_epochs_or_steps(local_epochs, local_steps),
        "current_server_round": current_server_round,
    }

    # Check if plans were provided
    if nnunet_plans is not None:
        with open(nnunet_plans, "r") as plans_file:
            plans_bytes = pickle.dumps(json.load(plans_file))
        config["nnunet_plans"] = plans_bytes

    return config


def main(config: dict, server_address: str) -> None:
    # Partial function with everything set except current server round
    fit_config_fn = partial(
        get_config,
        n_clients=config["n_clients"],
        nnunet_config=config["nnunet_config"],
        n_server_rounds=config["n_server_rounds"],
        batch_size=0,  # Set this to 0 because we're not using it
        nnunet_plans=config.get("nnunet_plans"),
        local_epochs=config.get("local_epochs"),
        local_steps=config.get("local_steps"),
    )

    if config.get("starting_checkpoint"):
        model = torch.load(config["starting_checkpoint"])
        # nnunet checkpoints store the model weights under the
        # "network_weights" key rather than as a plain state dict
        params = ndarrays_to_parameters([val.cpu().numpy() for _, val in model["network_weights"].items()])
    else:
        params = None

    strategy = FedAvg(
        min_fit_clients=config["n_clients"],
        min_evaluate_clients=config["n_clients"],
        min_available_clients=config["n_clients"],
        on_fit_config_fn=fit_config_fn,
        on_evaluate_config_fn=fit_config_fn,  # Nothing changes for eval
        fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
        evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
        initial_parameters=params,
    )

    server = NnUNetServer(
        client_manager=SimpleClientManager(),
        strategy=strategy,
    )

    fl.server.start_server(
        server=server,
        server_address=server_address,
        config=fl.server.ServerConfig(num_rounds=config["n_server_rounds"]),
    )

    # Shutdown the server
    server.shutdown()


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--config_path", action="store", type=str, help="Path to the configuration file")
parser.add_argument(
"--server-address",
type=str,
required=False,
default="0.0.0.0:8080",
help="""[OPTIONAL] The address to use for the server. Defaults to
0.0.0.0:8080""",
)

args = parser.parse_args()

with open(args.config_path, "r") as f:
config = yaml.safe_load(f)

main(config, server_address=args.server_address)