Commit
Merge pull request #7 from hasan7n/be_enable_partial_epochs
Enabling partial epochs
brandon-edwards authored Oct 30, 2024
2 parents 26b4337 + c396483 commit 128e28b
Showing 8 changed files with 311 additions and 229 deletions.
68 changes: 59 additions & 9 deletions examples/fl_post/fl/mlcube/workspace/training_config.yaml
@@ -5,18 +5,47 @@ aggregator :
init_state_path : save/fl_post_two_init.pbuf
best_state_path : save/fl_post_two_best.pbuf
last_state_path : save/fl_post_two_last.pbuf
rounds_to_train : 2
rounds_to_train : &rounds_to_train 2
admins_endpoints_mapping:
[email protected]:
- GetExperimentStatus
- SetStragglerCuttoffTime
- SetDynamicTaskArg
- GetDynamicTaskArg

dynamictaskargs: &dynamictaskargs
train:
train_cutoff_time:
admin_settable: True
min: 10 # 10 seconds
max: 86400 # one day
value: 86400 # one day
val_cutoff_time:
admin_settable: True
min: 10 # 10 seconds
max: 86400 # one day
value: 86400 # one day
train_completion_dampener: # train_completed -> (train_completed)**(train_completion_dampener) NOTE: A value close to zero shifts non-0.0 completion rates much closer to 1.0
admin_settable: True
min: -1.0 # inverts train_completed, so this would be a way to have checkpoint_weighting = train_completed * data_size (as opposed to data_size / train_completed)
max: 1.0 # leaves completion rates as is
value: 1.0

aggregated_model_validation:
val_cutoff_time:
admin_settable: True
min: 10 # 10 seconds
max: 86400 # one day
value: 86400 # one day


collaborator :
defaults : plan/defaults/collaborator.yaml
template : openfl.component.Collaborator
settings :
delta_updates : false
opt_treatment : CONTINUE_LOCAL
dynamictaskargs: *dynamictaskargs

data_loader :
defaults : plan/defaults/data_loader.yaml
@@ -29,9 +58,10 @@ task_runner :
defaults : plan/defaults/task_runner.yaml
template : src.runner_nnunetv1.PyTorchNNUNetCheckpointTaskRunner
settings :
device : cuda
gpu_num_string : '0'
nnunet_task : Task537_FLPost
device : cuda
gpu_num_string : '0'
nnunet_task : Task537_FLPost
actual_max_num_epochs : *rounds_to_train

network :
defaults : plan/defaults/network.yaml
@@ -45,19 +75,39 @@ assigner :
- name : train_and_validate
percentage : 1.0
tasks :
# - aggregated_model_validation
- aggregated_model_validation
- train
# - locally_tuned_model_validation
- locally_tuned_model_validation

tasks :
defaults : plan/defaults/tasks_torch.yaml
aggregated_model_validation:
function : validate
kwargs :
metrics :
- val_eval
- val_eval_C1
- val_eval_C2
- val_eval_C3
- val_eval_C4
apply : global
train:
function : train
kwargs :
metrics :
- train_loss
- val_eval
- train_loss
epochs : 1
locally_tuned_model_validation:
function : validate
kwargs :
metrics :
- val_eval
- val_eval_C1
- val_eval_C2
- val_eval_C3
- val_eval_C4
apply : local
from_checkpoint: true

compression_pipeline :
defaults : plan/defaults/compression_pipeline.yaml
@@ -66,4 +116,4 @@ straggler_handling_policy :
template : openfl.component.straggler_handling_functions.CutoffTimeBasedStragglerHandling
settings :
straggler_cutoff_time : 600
minimum_reporting : 2
minimum_reporting : 5
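
The train_completion_dampener entry above reshapes the training-completion fraction a collaborator reports before it feeds into checkpoint weighting. Below is a minimal sketch of that transform in line with the config comments; the helper names are illustrative, and checkpoint_weight is an inference from the comments (data_size divided by the dampened rate), not code taken from this PR.

def dampen_completion(train_completed: float, dampener: float) -> float:
    # dampener = 1.0  -> completion rate left as is
    # dampener -> 0.0 -> any non-zero completion rate is pushed toward 1.0
    # dampener = -1.0 -> completion rate is inverted (1 / train_completed)
    return train_completed ** dampener

def checkpoint_weight(data_size: int, train_completed: float, dampener: float) -> float:
    # Assumed weighting implied by the comments: data_size / train_completed**dampener,
    # so dampener = -1.0 yields data_size * train_completed.
    return data_size / dampen_completion(train_completed, dampener)

# Illustrative only: a collaborator that finished 25% of its partial epoch.
for d in (1.0, 0.1, -1.0):
    print(d, round(dampen_completion(0.25, d), 3))  # 0.25, 0.871, 4.0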
16 changes: 6 additions & 10 deletions examples/fl_post/fl/project/nnunet_data_setup.py
@@ -246,11 +246,8 @@ def setup_fl_data(postopp_pardir,
should be run using a virtual environment that has nnunet version 1 installed.
args:
postopp_pardirs(list of str) : Parent directories for postopp data. The length of the list should either be
equal to num_institutions, or one. If the length of the list is one and num_institutions is not one,
the samples within that single directory will be used to create num_institutions shards.
If the length of this list is equal to num_institutions, the shards are defined by the samples within each string path.
Either way, all string paths within this list should point to folders that have 'data' and 'labels' subdirectories with structure:
postopp_pardir(str) : Parent directory for postopp data.
This directory should have 'data' and 'labels' subdirectories, with structure:
├── data
│ ├── AAAC_0
│ │ ├── 2008.03.30
@@ -298,7 +295,7 @@ def setup_fl_data(postopp_pardir,
│ └── AAAC_extra_2008.12.10_final_seg.nii.gz
└── report.yaml
three_digit_task_num(str): Should start with '5'. If num_institutions == N (see below), all N task numbers starting with this number will be used.
three_digit_task_num(str): Should start with '5'.
task_name(str) : Any string task name.
percent_train(float) : what percent of data is put into the training data split (rest to val)
split_logic(str) : Determines how train/val split is performed
@@ -336,7 +333,6 @@ def setup_fl_data(postopp_pardir,
# Track the subjects and timestamps for each shard
subject_to_timestamps = {}

print(f"\n######### CREATING SYMLINKS TO POSTOPP DATA #########\n")
for postopp_subject_dir in all_subjects:
subject_to_timestamps[postopp_subject_dir] = symlink_one_subject(postopp_subject_dir=postopp_subject_dir,
postopp_data_dirpath=postopp_data_dirpath,
@@ -356,12 +352,12 @@
# Now call the os process to preprocess the data
print(f"\n######### OS CALL TO PREPROCESS DATA #########\n")
if plans_path:
subprocess.run(["nnUNet_plan_and_preprocess", "-t", f"{three_digit_task_num}", "--verify_dataset_integrity"])
subprocess.run(["nnUNet_plan_and_preprocess", "-t", f"{three_digit_task_num}", "-pl3d", "ExperimentPlanner3D_v21_Pretrained", "-overwrite_plans", f"{plans_path}", "-overwrite_plans_identifier", "POSTOPP", "-no_pp"])
subprocess.run(["nnUNet_plan_and_preprocess", "-t", f"{three_digit_task_num}", "-pl2d", "None", "--verify_dataset_integrity"])
subprocess.run(["nnUNet_plan_and_preprocess", "-t", f"{three_digit_task_num}", "-pl3d", "ExperimentPlanner3D_v21_Pretrained", "-pl2d", "None", "-overwrite_plans", f"{plans_path}", "-overwrite_plans_identifier", "POSTOPP", "-no_pp"])
plans_identifier_for_model_writing = shared_plans_identifier
else:
# this is a preliminary data setup, which will be passed over to the pretrained plan similar to above after we perform training on this plan
subprocess.run(["nnUNet_plan_and_preprocess", "-t", f"{three_digit_task_num}", "--verify_dataset_integrity"])
subprocess.run(["nnUNet_plan_and_preprocess", "-t", f"{three_digit_task_num}", "--verify_dataset_integrity", "-pl2d", "None"])
plans_identifier_for_model_writing = local_plans_identifier

# Now compute our own stratified splits file, keeping all timestamps for a given subject exclusively in either train or val
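
The comment above calls for a subject-level split: every timestamp belonging to a subject must land entirely in train or entirely in val. A minimal sketch of that grouping logic, assuming a subject_to_timestamps mapping like the one built earlier in this file; the helper name and the plain random shuffle are illustrative, and the PR's actual splits-file format and any stratification logic are not reproduced here.

import random

def split_by_subject(subject_to_timestamps, percent_train=0.8, seed=0):
    # Shuffle subjects, not individual timestamps, so every timestamp for a
    # given subject stays together in a single split.
    subjects = sorted(subject_to_timestamps)
    random.Random(seed).shuffle(subjects)
    n_train = int(round(percent_train * len(subjects)))
    train_subjects, val_subjects = subjects[:n_train], subjects[n_train:]
    train = [(s, t) for s in train_subjects for t in subject_to_timestamps[s]]
    val = [(s, t) for s in val_subjects for t in subject_to_timestamps[s]]
    return train, val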
61 changes: 7 additions & 54 deletions examples/fl_post/fl/project/nnunet_model_setup.py
@@ -7,12 +7,13 @@

def train_on_task(task, network, network_trainer, fold, cuda_device, plans_identifier, continue_training=False, current_epoch=0):
os.environ['CUDA_VISIBLE_DEVICES']=cuda_device
print(f"###########\nStarting training for task: {task}\n")
train_nnunet(epochs=1,
current_epoch = current_epoch,
network = network,
print(f"###########\nStarting training a single epoch for task: {task}\n")
# Function below is now hard coded for a single epoch of training.
train_nnunet(actual_max_num_epochs=1000,
fl_round=current_epoch,
network=network,
task=task,
network_trainer = network_trainer,
network_trainer=network_trainer,
fold=fold,
continue_training=continue_training,
p=plans_identifier)
@@ -60,61 +61,13 @@ def delete_2d_data(network, task, plans_identifier):
print(f"\n###########\nDeleting 2D data directory at: {data_dir_2d} \n##############\n")
shutil.rmtree(data_dir_2d)

"""
def normalize_architecture(reference_plan_path, target_plan_path):
# Take the plan file from reference_plan_path and use its contents to copy architecture into target_plan_path

# NOTE: Here we perform some checks and protection steps so that our assumptions, if not correct, will more
likely lead to an exception.
assert_same_keys = ['num_stages', 'num_modalities', 'modalities', 'normalization_schemes', 'num_classes', 'all_classes', 'base_num_features',
'use_mask_for_norm', 'keep_only_largest_region', 'min_region_size_per_class', 'min_size_per_class', 'transpose_forward',
'transpose_backward', 'preprocessor_name', 'conv_per_stage', 'data_identifier']
copy_over_keys = ['plans_per_stage']
nullify_keys = ['original_spacings', 'original_sizes']
leave_alone_keys = ['list_of_npz_files', 'preprocessed_data_folder', 'dataset_properties']
# check I got all keys here
assert set(copy_over_keys).union(set(assert_same_keys)).union(set(nullify_keys)).union(set(leave_alone_keys)) == set(['num_stages', 'num_modalities', 'modalities', 'normalization_schemes', 'dataset_properties', 'list_of_npz_files', 'original_spacings', 'original_sizes', 'preprocessed_data_folder', 'num_classes', 'all_classes', 'base_num_features', 'use_mask_for_norm', 'keep_only_largest_region', 'min_region_size_per_class', 'min_size_per_class', 'transpose_forward', 'transpose_backward', 'data_identifier', 'plans_per_stage', 'preprocessor_name', 'conv_per_stage'])
def get_pickle_obj(path):
with open(path, 'rb') as _file:
plan= pkl.load(_file)
return plan
def write_pickled_obj(obj, path):
with open(path, 'wb') as _file:
pkl.dump(obj, _file)
reference_plan = get_pickle_obj(path=reference_plan_path)
target_plan = get_pickle_obj(path=target_plan_path)
for key in assert_same_keys:
if reference_plan[key] != target_plan[key]:
raise ValueError(f"normalize architecture failed since the reference and target plans differed in at least key: {key}")
for key in copy_over_keys:
target_plan[key] = reference_plan[key]
for key in nullify_keys:
target_plan[key] = None
# leave alone keys are left alone :)
# write back to target plan
write_pickled_obj(obj=target_plan, path=target_plan_path)
"""

def trim_data_and_setup_model(task, network, network_trainer, plans_identifier, fold, init_model_path, init_model_info_path, plans_path, cuda_device='0'):
"""
Note that plans_identifier here is designated from fl_setup.py and is an alternative to the default one due to overwriting of the local plans by a globally distributed one
"""

# Remove 2D data and 2D data info if appropriate
if network != '2d':
delete_2d_data(network=network, task=task, plans_identifier=plans_identifier)

# get or create architecture info

model_folder = get_model_folder(network=network,
@@ -141,7 +94,7 @@ def trim_data_and_setup_model(task, network, network_trainer, plans_identifier,
shutil.copyfile(src=col_paths['final_model_path'],dst=col_paths['initial_model_path'])
shutil.copyfile(src=col_paths['final_model_info_path'],dst=col_paths['initial_model_info_path'])
else:
print(f"\n######### WRITING MODEL, MODEL INFO, and PLANS #########\ncol_paths were: {col_paths}\n\n")
print(f"\n######### WRITING MODEL, MODEL INFO, and PLANS #########\n")
shutil.copy(src=plans_path,dst=col_paths['plans_path'])
shutil.copyfile(src=init_model_path,dst=col_paths['initial_model_path'])
shutil.copyfile(src=init_model_info_path,dst=col_paths['initial_model_info_path'])
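
With the new signature, train_nnunet is driven one epoch at a time: the caller passes the federated-learning round as fl_round and an actual_max_num_epochs bound, which presumably keeps nnU-Net's epoch-based learning-rate schedule spread over the full planned run rather than a single epoch. A hypothetical driver loop for train_on_task is sketched below; the real round loop is managed by the OpenFL task runner, and all argument values shown are illustrative, not taken from this PR.

# Hypothetical illustration only; not part of this PR.
rounds_to_train = 2  # matches rounds_to_train in training_config.yaml

for fl_round in range(rounds_to_train):
    train_on_task(task="Task537_FLPost",
                  network="3d_fullres",
                  network_trainer="nnUNetTrainerV2",
                  fold="0",
                  cuda_device="0",
                  plans_identifier="nnUNetPlans_POSTOPP",
                  continue_training=(fl_round > 0),
                  current_epoch=fl_round)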
5 changes: 3 additions & 2 deletions examples/fl_post/fl/project/nnunet_setup.py
@@ -28,7 +28,7 @@ def main(postopp_pardir,
plans_path=None,
local_plans_identifier=local_plans_identifier,
shared_plans_identifier=shared_plans_identifier,
overwrite_nnunet_datadirs=False,
overwrite_nnunet_datadirs=True,
timestamp_selection='all',
cuda_device='0',
verbose=False):
@@ -105,7 +105,7 @@ def main(postopp_pardir,
fold(str) : Fold to train on, can be a string indicating an int, or can be 'all'
local_plans_identifier(str) : Used in the plans file naming for collaborators that will be performing local training to produce a pretrained model.
shared_plans_identifier(str) : Used in the plans file naming for the shared plan distributed across the federation.
overwrite_nnunet_datadirs(str) : Allows overwriting NNUnet directories with task numbers from first_three_digit_task_num to that plus one less than the number of institutions.
overwrite_nnunet_datadirs(str) : Allows overwriting NNUnet directories for given task number and name.
task_name(str) : Any string task name.
timestamp_selection(str) : Indicates how to determine the timestamp to pick. Only 'earliest', 'latest', or 'all' are supported.
for each subject ID at the source: 'latest' and 'earliest' are the only ones supported so far
Expand All @@ -126,6 +126,7 @@ def main(postopp_pardir,

# task_folder_info is a zipped list indexed over tasks (collaborators)
# zip(task_nums, tasks, nnunet_dst_pardirs, nnunet_images_train_pardirs, nnunet_labels_train_pardirs)

col_paths = setup_fl_data(postopp_pardir=postopp_pardir,
three_digit_task_num=three_digit_task_num,
task_name=task_name,
2 changes: 1 addition & 1 deletion examples/fl_post/fl/project/src/nnunet_dummy_dataloader.py
@@ -33,4 +33,4 @@ def get_valid_data_size(self):
return self.valid_data_size

def get_task_name(self):
return self.task_name
return self.task_name
