
Commit

formatting and clean up
Wrede committed Jun 10, 2024
1 parent ff559a6 commit a13f1fc
Showing 4 changed files with 71 additions and 144 deletions.
106 changes: 36 additions & 70 deletions examples/monai-2D-mednist/client/data.py
@@ -1,101 +1,77 @@
import os
from math import floor
import random
-import PIL

+import numpy as np
+import PIL
import torch
import torchvision
+import yaml
from monai.apps import download_and_extract
from torch.utils.data import Dataset as Dataset_

dir_path = os.path.dirname(os.path.realpath(__file__))
abs_path = os.path.abspath(dir_path)

-import yaml
-
-import numpy as np
-
-DATA_CLASSES = {'AbdomenCT': 0, 'BreastMRI': 1, 'CXR': 2, 'ChestCT': 3, 'Hand': 4, 'HeadCT': 5}
+DATA_CLASSES = {"AbdomenCT": 0, "BreastMRI": 1, "CXR": 2, "ChestCT": 3, "Hand": 4, "HeadCT": 5}


-def split_data(data_path='data/MedNIST', splits=100, validation_split=0.9):
+def split_data(data_path="data/MedNIST", splits=100, validation_split=0.9):
    # create clients
-    clients = {'client ' + str(i): {"train": [], "validation": []} for i in range(splits)}
-    # print("clients: ", clients)
+    clients = {"client " + str(i): {"train": [], "validation": []} for i in range(splits)}

    for class_ in os.listdir(data_path):
        if os.path.isdir(os.path.join(data_path, class_)):
-            # print("class_: ", class_)
-            patients_in_class = [os.path.join(class_, patient) for patient in
-                                 os.listdir(os.path.join(data_path, class_))]
-            # print(patients_in_class)
+            patients_in_class = [os.path.join(class_, patient) for patient in os.listdir(os.path.join(data_path, class_))]
            np.random.shuffle(patients_in_class)
            chops = np.int32(np.linspace(0, len(patients_in_class), splits + 1))
            for split in range(splits):
-                # print("split ", split)
-                p = patients_in_class[chops[split]:chops[split + 1]]
+                p = patients_in_class[chops[split] : chops[split + 1]]
                valsplit = np.int32(len(p) * validation_split)
-                # print("'client ' + str(split): " 'client ' + str(split))
-                # print("p[:valsplit]: ", p[:valsplit])
-                # clients['client' + str(split)]["train"] = 3

-                clients['client ' + str(split)]["train"] += p[:valsplit]
-                clients['client ' + str(split)]["validation"] += p[valsplit:]
+                clients["client " + str(split)]["train"] += p[:valsplit]
+                clients["client " + str(split)]["validation"] += p[valsplit:]

-    with open(os.path.join(os.path.dirname(data_path), "data_splits.yaml"), 'w') as file:
+    with open(os.path.join(os.path.dirname(data_path), "data_splits.yaml"), "w") as file:
        yaml.dump(clients, file, default_flow_style=False)
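Side note: split_data builds a per-client index and writes it to data_splits.yaml one directory above data_path, so each of the (by default 100) clients gets a class-stratified train/validation file list. A minimal sketch of reading the index back (with the defaults above the file lands at data/data_splits.yaml; the paths inside the lists are illustrative):

    import yaml

    with open("data/data_splits.yaml", "r") as fh:
        splits = yaml.safe_load(fh)

    # Entries are paths relative to the data directory, e.g. "AbdomenCT/001234.jpeg".
    print(splits["client 0"]["train"][:3])
    print(len(splits["client 0"]["validation"]))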





def get_data(out_dir="data"):
    """Get data from the external repository.
-    :param out_dir: Path to data directory. If doesn't
+    :param out_dir: Path to data directory. If doesn't
    :type data_dir: str
    """

    # Make dir if necessary
    if not os.path.exists(out_dir):
        os.mkdir(out_dir)

    resource = "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/MedNIST.tar.gz"
    md5 = "0bc7306e7427e00ad1c5526a6677552d"

    compressed_file = os.path.join(out_dir, "MedNIST.tar.gz")
-    data_dir = os.path.abspath(out_dir)
-    print('data_dir:', data_dir)
-
+    data_dir = os.path.abspath(out_dir)
+    print("data_dir:", data_dir)
    if os.path.exists(data_dir):
-        print('path exist.')
+        print("path exist.")
        if not os.path.exists(compressed_file):
-            print('compressed file does not exist, downloading and extracting data.')
+            print("compressed file does not exist, downloading and extracting data.")
            download_and_extract(resource, compressed_file, data_dir, md5)
        else:
-            print('files already exist.')
+            print("files already exist.")

    split_data()


def get_classes(data_path):
    """Get a list of classes from the dataset
    :param data_path: Path to data directory.
    :type data_path: str
    """

    if data_path is None:
-        data_path = os.environ.get("FEDN_DATA_PATH", abs_path + "/data/MedNIST")
+        data_path = os.environ.get("FEDN_DATA_PATH", abs_path + "/data/MedNIST")

    class_names = sorted(x for x in os.listdir(data_path) if os.path.isdir(os.path.join(data_path, x)))
-    return(class_names)
+    return class_names


def load_data(data_path, sample_size=None, is_train=True):
@@ -106,28 +82,23 @@ def load_data(data_path, sample_size=None, is_train=True):
    :param is_train: Whether to load training or test data.
    :type is_train: bool
    :return: Tuple of data and labels.
-    :rtype: tuple"""
-
+    :rtype: tuple
+    """
    if data_path is None:
-        data_path = os.environ.get("FEDN_DATA_PATH", abs_path + "/data/MedNIST")
-    class_names = get_classes(data_path)
+        data_path = os.environ.get("FEDN_DATA_PATH", abs_path + "/data/MedNIST")
+
+    class_names = get_classes(data_path)
    num_class = len(class_names)

-    image_files_all = [
-        [os.path.join(data_path, class_names[i], x) for x in os.listdir(os.path.join(data_path, class_names[i]))]
-        for i in range(num_class)
-    ]
-
+    image_files_all = [[os.path.join(data_path, class_names[i], x) for x in os.listdir(os.path.join(data_path, class_names[i]))] for i in range(num_class)]

    # To make the dataset small, we are using sample_size=100 images of each class.
-
    if sample_size is None:
-        image_files = image_files_all
-
-    else:
-        image_files = [random.sample(inner_list, sample_size) for inner_list in image_files_all]
+        image_files = image_files_all
+
+    else:
+        image_files = [random.sample(inner_list, sample_size) for inner_list in image_files_all]

    num_each = [len(image_files[i]) for i in range(num_class)]
    image_files_list = []
    image_class = []
@@ -141,9 +112,8 @@ def load_data(data_path, sample_size=None, is_train=True):
    print(f"Image dimensions: {image_width} x {image_height}")
    print(f"Label names: {class_names}")
    print(f"Label counts: {num_each}")

    val_frac = 0.1
-    #test_frac = 0.1
    length = len(image_files_list)
    indices = np.arange(length)
    np.random.shuffle(indices)
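The lines folded out of this hunk presumably use val_frac to carve the shuffled indices into training and validation partitions before val_x and val_y are returned below. A sketch of that common pattern (names other than val_frac, length, indices, image_files_list, and image_class are illustrative, not from the diff):

    # Illustrative only: reserve the first val_frac of the shuffled indices for validation.
    val_split = int(length * val_frac)
    val_indices = indices[:val_split]
    val_x = [image_files_list[i] for i in val_indices]
    val_y = [image_class[i] for i in val_indices]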
@@ -165,7 +135,6 @@ def load_data(data_path, sample_size=None, is_train=True):
    return val_x, val_y, class_names


-
class MedNISTDataset(torch.utils.data.Dataset):
    def __init__(self, data_path, image_files, transforms):
        self.data_path = data_path
@@ -176,12 +145,9 @@ def __len__(self):
        return len(self.image_files)

    def __getitem__(self, index):
-        return (self.transforms(os.path.join(self.data_path,self.image_files[index])),
-                DATA_CLASSES[os.path.dirname(self.image_files[index])])
+        return (self.transforms(os.path.join(self.data_path, self.image_files[index])), DATA_CLASSES[os.path.dirname(self.image_files[index])])


if __name__ == "__main__":
    # Prepare data if not already done
-    #if not os.path.exists(abs_path + "/data"):
    get_data()
-    #load_data('./data')
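Taken together, MedNISTDataset resolves each relative path against data_path and derives the integer label from the file's class directory via DATA_CLASSES. A minimal usage sketch (assuming the splits index from data_splits.yaml above and a path-consuming pipeline such as train_transforms from train.py):

    from monai.data import DataLoader

    ds = MedNISTDataset(data_path="data/MedNIST", image_files=splits["client 0"]["train"], transforms=train_transforms)
    loader = DataLoader(ds, batch_size=32, shuffle=True)
    images, labels = next(iter(loader))  # batch of transformed images plus integer labels 0-5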
4 changes: 0 additions & 4 deletions examples/monai-2D-mednist/client/model.py
@@ -3,8 +3,6 @@
import torch
from monai.networks.nets import DenseNet121

-
-
from fedn.utils.helpers.helpers import get_helper

HELPER_MODULE = "numpyhelper"
@@ -17,9 +15,7 @@ def compile_model():
    :return: The compiled model.
    :rtype: torch.nn.Module
    """
-
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    #model = DenseNet121(spatial_dims=2, in_channels=1, out_channels=num_classes).to(device)
    model = DenseNet121(spatial_dims=2, in_channels=1, out_channels=6).to(device)
    return model

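For orientation, compile_model returns a two-dimensional, single-channel DenseNet121 with six outputs, one per MedNIST class, placed on GPU when one is available. A quick smoke test of the factory (the random batch is illustrative; MedNIST images are 64x64 grayscale):

    import torch

    model = compile_model()
    device = next(model.parameters()).device
    dummy = torch.randn(4, 1, 64, 64, device=device)  # four fake single-channel 64x64 images
    print(model(dummy).shape)  # expected: torch.Size([4, 6])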
46 changes: 17 additions & 29 deletions examples/monai-2D-mednist/client/train.py
@@ -1,33 +1,27 @@
import math
import os
import sys

-import yaml
-
import torch
+import yaml
+from data import MedNISTDataset
from model import load_parameters, save_parameters
-from data import load_data, get_classes, MedNISTDataset

from fedn.utils.helpers.helpers import save_metadata

dir_path = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.abspath(dir_path))

-from monai.data import decollate_batch, DataLoader
-from monai.networks.nets import DenseNet121
+import numpy as np
+from monai.data import DataLoader
from monai.transforms import (
-    Activations,
-    EnsureChannelFirst,
-    AsDiscrete,
    Compose,
+    EnsureChannelFirst,
    LoadImage,
    RandFlip,
    RandRotate,
    RandZoom,
    ScaleIntensity,
)
-from monai.utils import set_determinism
-import numpy as np
-from monai.data import decollate_batch, DataLoader

train_transforms = Compose(
    [
@@ -36,7 +30,7 @@
        ScaleIntensity(),
        RandRotate(range_x=np.pi / 12, prob=0.5, keep_size=True),
        RandFlip(spatial_axis=0, prob=0.5),
-        RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5)
+        RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5),
    ]
)
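Because MedNISTDataset passes a raw file path into self.transforms, the entries folded out of the top of this Compose are presumably LoadImage and EnsureChannelFirst (both imported above). Under that assumption the pipeline can be exercised on a single file (the path is illustrative):

    img = train_transforms("data/MedNIST/Hand/000000.jpeg")
    print(img.shape)  # roughly (1, 64, 64): channel-first, intensity-scaled, randomly augmented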

@@ -57,44 +51,38 @@ def train(in_model_path, out_model_path, data_path=None, client_settings_path=None):
    :param client_settings_path: path to a local client settings file.
    :type client_settings_path: str
    """
-
    if client_settings_path is None:
        client_settings_path = os.environ.get("FEDN_CLIENT_SETTINGS_PATH", dir_path + "/client_settings.yaml")

    print("client_settings_path: ", client_settings_path)
-    with open(client_settings_path, 'r') as fh:  # Used by CJG for local training
-
+    with open(client_settings_path, "r") as fh:  # Used by CJG for local training
        try:
            client_settings = dict(yaml.safe_load(fh))
        except yaml.YAMLError as e:
            raise

-
    print("client settings: ", client_settings)
-    batch_size = client_settings['batch_size']
-    max_epochs = client_settings['local_epochs']
-    num_workers = client_settings['num_workers']
-    split_index = client_settings['split_index']
-    lr = client_settings['lr']
+    batch_size = client_settings["batch_size"]
+    max_epochs = client_settings["local_epochs"]
+    num_workers = client_settings["num_workers"]
+    split_index = client_settings["split_index"]
+    lr = client_settings["lr"]

    if data_path is None:
        data_path = os.environ.get("FEDN_DATA_PATH")

-
-    with open(os.path.join(os.path.dirname(data_path), "data_splits.yaml"), 'r') as file:
+    with open(os.path.join(os.path.dirname(data_path), "data_splits.yaml"), "r") as file:
        clients = yaml.safe_load(file)

-    image_list = clients['client ' + str(split_index)]['train']
+    image_list = clients["client " + str(split_index)]["train"]

-    train_ds = MedNISTDataset(data_path='data/MedNIST', transforms=train_transforms,
-                              image_files=image_list)
+    train_ds = MedNISTDataset(data_path="data/MedNIST", transforms=train_transforms, image_files=image_list)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers)

-
    # Load parmeters and initialize model
    model = load_parameters(in_model_path)
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    optimizer = torch.optim.Adam(model.parameters(), 1e-5)
    loss_function = torch.nn.CrossEntropyLoss()

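For completeness, train() expects client_settings.yaml to supply the five keys read above. A plausible settings file written with the same yaml module the client uses (the values are illustrative, not taken from this commit):

    import yaml

    example_settings = {
        "batch_size": 32,
        "local_epochs": 1,
        "num_workers": 0,
        "split_index": 0,
        "lr": 0.01,
    }
    with open("client_settings.yaml", "w") as fh:
        yaml.safe_dump(example_settings, fh)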
