diff --git a/examples/fl_post/fl/mlcube/workspace/training_config.yaml b/examples/fl_post/fl/mlcube/workspace/training_config.yaml index d8478a921..d12643dd6 100644 --- a/examples/fl_post/fl/mlcube/workspace/training_config.yaml +++ b/examples/fl_post/fl/mlcube/workspace/training_config.yaml @@ -5,11 +5,39 @@ aggregator : init_state_path : save/fl_post_two_init.pbuf best_state_path : save/fl_post_two_best.pbuf last_state_path : save/fl_post_two_last.pbuf - rounds_to_train : 2 + rounds_to_train : &rounds_to_train 2 admins_endpoints_mapping: testfladmin@example.com: - GetExperimentStatus - SetStragglerCuttoffTime + - SetDynamicTaskArg + - GetDynamicTaskArg + + dynamictaskargs: &dynamictaskargs + train: + train_cutoff_time: + admin_settable: True + min: 10 # 10 seconds + max: 86400 # one day + value: 86400 # one day + val_cutoff_time: + admin_settable: True + min: 10 # 10 seconds + max: 86400 # one day + value: 86400 # one day + train_completion_dampener: # train_completed -> (train_completed)**(train_completion_dampener) NOTE: A value close to zero shifts non-zero completion rates much closer to 1.0 + admin_settable: True + min: -1.0 # inverts train_completed, so this would be a way to have checkpoint_weighting = train_completed * data_size (as opposed to data_size / train_completed) + max: 1.0 # leaves completion rates as is + value: 1.0 + + aggregated_model_validation: + val_cutoff_time: + admin_settable: True + min: 10 # 10 seconds + max: 86400 # one day + value: 86400 # one day + collaborator : defaults : plan/defaults/collaborator.yaml @@ -17,6 +45,7 @@ collaborator : settings : delta_updates : false opt_treatment : CONTINUE_LOCAL + dynamictaskargs: *dynamictaskargs data_loader : defaults : plan/defaults/data_loader.yaml @@ -29,9 +58,10 @@ task_runner : defaults : plan/defaults/task_runner.yaml template : src.runner_nnunetv1.PyTorchNNUNetCheckpointTaskRunner settings : - device : cuda - gpu_num_string : '0' - nnunet_task : Task537_FLPost + device : cuda + 
gpu_num_string : '0' + nnunet_task : Task537_FLPost + actual_max_num_epochs : *rounds_to_train network : defaults : plan/defaults/network.yaml @@ -45,19 +75,39 @@ assigner : - name : train_and_validate percentage : 1.0 tasks : - # - aggregated_model_validation + - aggregated_model_validation - train - # - locally_tuned_model_validation + - locally_tuned_model_validation tasks : defaults : plan/defaults/tasks_torch.yaml + aggregated_model_validation: + function : validate + kwargs : + metrics : + - val_eval + - val_eval_C1 + - val_eval_C2 + - val_eval_C3 + - val_eval_C4 + apply : global train: function : train kwargs : metrics : - - train_loss - - val_eval + - train_loss epochs : 1 + locally_tuned_model_validation: + function : validate + kwargs : + metrics : + - val_eval + - val_eval_C1 + - val_eval_C2 + - val_eval_C3 + - val_eval_C4 + apply : local + from_checkpoint: true compression_pipeline : defaults : plan/defaults/compression_pipeline.yaml @@ -66,4 +116,4 @@ straggler_handling_policy : template : openfl.component.straggler_handling_functions.CutoffTimeBasedStragglerHandling settings : straggler_cutoff_time : 600 - minimum_reporting : 2 \ No newline at end of file + minimum_reporting : 5 \ No newline at end of file diff --git a/examples/fl_post/fl/project/nnunet_data_setup.py b/examples/fl_post/fl/project/nnunet_data_setup.py index 3f24c9515..db3894e9d 100644 --- a/examples/fl_post/fl/project/nnunet_data_setup.py +++ b/examples/fl_post/fl/project/nnunet_data_setup.py @@ -246,11 +246,8 @@ def setup_fl_data(postopp_pardir, should be run using a virtual environment that has nnunet version 1 installed. args: - postopp_pardirs(list of str) : Parent directories for postopp data. The length of the list should either be - equal to num_insitutions, or one. If the length of the list is one and num_insitutions is not one, - the samples within that single directory will be used to create num_insititutions shards. 
- If the length of this list is equal to num_insitutions, the shards are defined by the samples within each string path. - Either way, all string paths within this list should piont to folders that have 'data' and 'labels' subdirectories with structure: + postopp_pardir(str) : Parent directory for postopp data. + This directory should have 'data' and 'labels' subdirectories, with structure: ├── data │ ├── AAAC_0 │ │ ├── 2008.03.30 @@ -298,7 +295,7 @@ def setup_fl_data(postopp_pardir, │ └── AAAC_extra_2008.12.10_final_seg.nii.gz └── report.yaml - three_digit_task_num(str): Should start with '5'. If num_institutions == N (see below), all N task numbers starting with this number will be used. + three_digit_task_num(str): Should start with '5'. task_name(str) : Any string task name. percent_train(float) : what percent of data is put into the training data split (rest to val) split_logic(str) : Determines how train/val split is performed @@ -336,7 +333,6 @@ def setup_fl_data(postopp_pardir, # Track the subjects and timestamps for each shard subject_to_timestamps = {} - print(f"\n######### CREATING SYMLINKS TO POSTOPP DATA #########\n") for postopp_subject_dir in all_subjects: subject_to_timestamps[postopp_subject_dir] = symlink_one_subject(postopp_subject_dir=postopp_subject_dir, postopp_data_dirpath=postopp_data_dirpath, @@ -356,12 +352,12 @@ def setup_fl_data(postopp_pardir, # Now call the os process to preprocess the data print(f"\n######### OS CALL TO PREPROCESS DATA #########\n") if plans_path: - subprocess.run(["nnUNet_plan_and_preprocess", "-t", f"{three_digit_task_num}", "--verify_dataset_integrity"]) - subprocess.run(["nnUNet_plan_and_preprocess", "-t", f"{three_digit_task_num}", "-pl3d", "ExperimentPlanner3D_v21_Pretrained", "-overwrite_plans", f"{plans_path}", "-overwrite_plans_identifier", "POSTOPP", "-no_pp"]) + subprocess.run(["nnUNet_plan_and_preprocess", "-t", f"{three_digit_task_num}", "-pl2d", "None", "--verify_dataset_integrity"]) + 
subprocess.run(["nnUNet_plan_and_preprocess", "-t", f"{three_digit_task_num}", "-pl3d", "ExperimentPlanner3D_v21_Pretrained", "-pl2d", "None", "-overwrite_plans", f"{plans_path}", "-overwrite_plans_identifier", "POSTOPP", "-no_pp"]) plans_identifier_for_model_writing = shared_plans_identifier else: # this is a preliminary data setup, which will be passed over to the pretrained plan similar to above after we perform training on this plan - subprocess.run(["nnUNet_plan_and_preprocess", "-t", f"{three_digit_task_num}", "--verify_dataset_integrity"]) + subprocess.run(["nnUNet_plan_and_preprocess", "-t", f"{three_digit_task_num}", "--verify_dataset_integrity", "-pl2d", "None"]) plans_identifier_for_model_writing = local_plans_identifier # Now compute our own stratified splits file, keeping all timestampts for a given subject exclusively in either train or val diff --git a/examples/fl_post/fl/project/nnunet_model_setup.py b/examples/fl_post/fl/project/nnunet_model_setup.py index 4ebd1f9e7..a647a2f44 100644 --- a/examples/fl_post/fl/project/nnunet_model_setup.py +++ b/examples/fl_post/fl/project/nnunet_model_setup.py @@ -7,12 +7,13 @@ def train_on_task(task, network, network_trainer, fold, cuda_device, plans_identifier, continue_training=False, current_epoch=0): os.environ['CUDA_VISIBLE_DEVICES']=cuda_device - print(f"###########\nStarting training for task: {task}\n") - train_nnunet(epochs=1, - current_epoch = current_epoch, - network = network, + print(f"###########\nStarting training a single epoch for task: {task}\n") + # Function below is now hard coded for a single epoch of training. 
+ train_nnunet(actual_max_num_epochs=1000, + fl_round=current_epoch, + network=network, task=task, - network_trainer = network_trainer, + network_trainer=network_trainer, fold=fold, continue_training=continue_training, p=plans_identifier) @@ -60,61 +61,13 @@ def delete_2d_data(network, task, plans_identifier): print(f"\n###########\nDeleting 2D data directory at: {data_dir_2d} \n##############\n") shutil.rmtree(data_dir_2d) -""" -def normalize_architecture(reference_plan_path, target_plan_path): - - # Take the plan file from reference_plan_path and use its contents to copy architecture into target_plan_path - # NOTE: Here we perform some checks and protection steps so that our assumptions if not correct will more - likely leed to an exception. - - - - assert_same_keys = ['num_stages', 'num_modalities', 'modalities', 'normalization_schemes', 'num_classes', 'all_classes', 'base_num_features', - 'use_mask_for_norm', 'keep_only_largest_region', 'min_region_size_per_class', 'min_size_per_class', 'transpose_forward', - 'transpose_backward', 'preprocessor_name', 'conv_per_stage', 'data_identifier'] - copy_over_keys = ['plans_per_stage'] - nullify_keys = ['original_spacings', 'original_sizes'] - leave_alone_keys = ['list_of_npz_files', 'preprocessed_data_folder', 'dataset_properties'] - - - # check I got all keys here - assert set(copy_over_keys).union(set(assert_same_keys)).union(set(nullify_keys)).union(set(leave_alone_keys)) == set(['num_stages', 'num_modalities', 'modalities', 'normalization_schemes', 'dataset_properties', 'list_of_npz_files', 'original_spacings', 'original_sizes', 'preprocessed_data_folder', 'num_classes', 'all_classes', 'base_num_features', 'use_mask_for_norm', 'keep_only_largest_region', 'min_region_size_per_class', 'min_size_per_class', 'transpose_forward', 'transpose_backward', 'data_identifier', 'plans_per_stage', 'preprocessor_name', 'conv_per_stage']) - - def get_pickle_obj(path): - with open(path, 'rb') as _file: - plan= pkl.load(_file) - 
return plan - - def write_pickled_obj(obj, path): - with open(path, 'wb') as _file: - pkl.dump(obj, _file) - - reference_plan = get_pickle_obj(path=reference_plan_path) - target_plan = get_pickle_obj(path=target_plan_path) - - for key in assert_same_keys: - if reference_plan[key] != target_plan[key]: - raise ValueError(f"normalize architecture failed since the reference and target plans differed in at least key: {key}") - for key in copy_over_keys: - target_plan[key] = reference_plan[key] - for key in nullify_keys: - target_plan[key] = None - # leave alone keys are left alone :) - - # write back to target plan - write_pickled_obj(obj=target_plan, path=target_plan_path) -""" def trim_data_and_setup_model(task, network, network_trainer, plans_identifier, fold, init_model_path, init_model_info_path, plans_path, cuda_device='0'): """ Note that plans_identifier here is designated from fl_setup.py and is an alternative to the default one due to overwriting of the local plans by a globally distributed one """ - # Remove 2D data and 2D data info if appropriate - if network != '2d': - delete_2d_data(network=network, task=task, plans_identifier=plans_identifier) - # get or create architecture info model_folder = get_model_folder(network=network, @@ -141,7 +94,7 @@ def trim_data_and_setup_model(task, network, network_trainer, plans_identifier, shutil.copyfile(src=col_paths['final_model_path'],dst=col_paths['initial_model_path']) shutil.copyfile(src=col_paths['final_model_info_path'],dst=col_paths['initial_model_info_path']) else: - print(f"\n######### WRITING MODEL, MODEL INFO, and PLANS #########\ncol_paths were: {col_paths}\n\n") + print(f"\n######### WRITING MODEL, MODEL INFO, and PLANS #########\n") shutil.copy(src=plans_path,dst=col_paths['plans_path']) shutil.copyfile(src=init_model_path,dst=col_paths['initial_model_path']) shutil.copyfile(src=init_model_info_path,dst=col_paths['initial_model_info_path']) diff --git a/examples/fl_post/fl/project/nnunet_setup.py 
b/examples/fl_post/fl/project/nnunet_setup.py index 0106b8f95..81f6787d8 100644 --- a/examples/fl_post/fl/project/nnunet_setup.py +++ b/examples/fl_post/fl/project/nnunet_setup.py @@ -28,7 +28,7 @@ def main(postopp_pardir, plans_path=None, local_plans_identifier=local_plans_identifier, shared_plans_identifier=shared_plans_identifier, - overwrite_nnunet_datadirs=False, + overwrite_nnunet_datadirs=True, timestamp_selection='all', cuda_device='0', verbose=False): @@ -105,7 +105,7 @@ def main(postopp_pardir, fold(str) : Fold to train on, can be a sting indicating an int, or can be 'all' local_plans_identifier(str) : Used in the plans file naming for collaborators that will be performing local training to produce a pretrained model. shared_plans_identifier(str) : Used in the plans file naming for the shared plan distributed across the federation. - overwrite_nnunet_datadirs(str) : Allows overwriting NNUnet directories with task numbers from first_three_digit_task_num to that plus one les than number of insitutions. + overwrite_nnunet_datadirs(bool) : Allows overwriting NNUnet directories for the given task number and name. task_name(str) : Any string task name. timestamp_selection(str) : Indicates how to determine the timestamp to pick. Only 'earliest', 'latest', or 'all' are supported. 
for each subject ID at the source: 'latest' and 'earliest' are the only ones supported so far @@ -126,6 +126,7 @@ def main(postopp_pardir, # task_folder_info is a zipped lists indexed over tasks (collaborators) # zip(task_nums, tasks, nnunet_dst_pardirs, nnunet_images_train_pardirs, nnunet_labels_train_pardirs) + col_paths = setup_fl_data(postopp_pardir=postopp_pardir, three_digit_task_num=three_digit_task_num, task_name=task_name, diff --git a/examples/fl_post/fl/project/src/nnunet_dummy_dataloader.py b/examples/fl_post/fl/project/src/nnunet_dummy_dataloader.py index 1fe83a4f5..68cbbbc40 100644 --- a/examples/fl_post/fl/project/src/nnunet_dummy_dataloader.py +++ b/examples/fl_post/fl/project/src/nnunet_dummy_dataloader.py @@ -33,4 +33,4 @@ def get_valid_data_size(self): return self.valid_data_size def get_task_name(self): - return self.task_name + return self.task_name \ No newline at end of file diff --git a/examples/fl_post/fl/project/src/nnunet_v1.py b/examples/fl_post/fl/project/src/nnunet_v1.py index 78869d4c1..2e5df028b 100644 --- a/examples/fl_post/fl/project/src/nnunet_v1.py +++ b/examples/fl_post/fl/project/src/nnunet_v1.py @@ -54,14 +54,17 @@ def seed_everything(seed=1234): torch.backends.cudnn.deterministic = True -def train_nnunet(epochs, - current_epoch, +def train_nnunet(actual_max_num_epochs, + fl_round, + val_epoch=True, + train_epoch=True, + train_cutoff=np.inf, + val_cutoff=np.inf, network='3d_fullres', network_trainer='nnUNetTrainerV2', task='Task543_FakePostOpp_More', fold='0', continue_training=True, - validation_only=False, c=False, p=plans_param, use_compressed_data=False, @@ -78,9 +81,13 @@ def train_nnunet(epochs, pretrained_weights=None): """ + actual_max_num_epochs (int): Provides the number of epochs intended to be trained over the course of the whole federation (for lr scheduling) + (this needs to be held constant outside of individual calls to this function so that the lr is consistently scheduled) + fl_round (int): Federated round, 
equal to the epoch used for the model (in lr scheduling) + val_epoch (bool) : Will validation be performed + train_epoch (bool) : Will training run (rather than val only) task (int): can be task name or task id fold: "0, 1, ..., 5 or 'all'" - validation_only: use this if you want to only run the validation c: use this if you want to continue a training p: plans identifier. Only change this if you created a custom experiment planner use_compressed_data: "If you set use_compressed_data, the training cases will not be decompressed. Reading compressed data " @@ -132,7 +139,6 @@ def __init__(self, **kwargs): fold = args.fold network = args.network network_trainer = args.network_trainer - validation_only = args.validation_only plans_identifier = args.p find_lr = args.find_lr disable_postprocessing_on_folds = args.disable_postprocessing_on_folds @@ -198,6 +204,7 @@ def __init__(self, **kwargs): trainer = trainer_class( plans_file, fold, + actual_max_num_epochs=actual_max_num_epochs, output_folder=output_folder_name, dataset_directory=dataset_directory, batch_dice=batch_dice, @@ -206,6 +213,30 @@ def __init__(self, **kwargs): deterministic=deterministic, fp16=run_mixed_precision, ) + + + trainer.initialize(True) + + if os.getenv("PREP_INCREMENT_STEP", None) == "from_dataset_properties": + trainer.save_checkpoint( + join(trainer.output_folder, "model_final_checkpoint.model") + ) + print("Preparation round: Model-averaging") + return + + if find_lr: + trainer.find_lr(num_iters=self.actual_max_num_epochs) + else: + if args.continue_training: + # -c was set, continue a previous training and ignore pretrained weights + trainer.load_latest_checkpoint() + elif (not args.continue_training) and (args.pretrained_weights is not None): + # we start a new training. 
If pretrained_weights are set, use them + load_pretrained_weights(trainer.network, args.pretrained_weights) + else: + # new training without pretraine weights, do nothing + pass + # we want latest checkoint only (not best or any intermediate) trainer.save_final_checkpoint = ( True # whether or not to save the final checkpoint @@ -221,61 +252,44 @@ def __init__(self, **kwargs): trainer.save_latest_only = ( True # if false it will not store/overwrite _latest but separate files each ) - trainer.max_num_epochs = current_epoch + epochs - trainer.epoch = current_epoch - - # TODO: call validation separately - trainer.initialize(not validation_only) - - if os.getenv("PREP_INCREMENT_STEP", None) == "from_dataset_properties": - trainer.save_checkpoint( - join(trainer.output_folder, "model_final_checkpoint.model") - ) - print("Preparation round: Model-averaging") - return - - if find_lr: - trainer.find_lr() - else: - if not validation_only: - if args.continue_training: - # -c was set, continue a previous training and ignore pretrained weights - trainer.load_latest_checkpoint() - elif (not args.continue_training) and (args.pretrained_weights is not None): - # we start a new training. 
If pretrained_weights are set, use them - load_pretrained_weights(trainer.network, args.pretrained_weights) - else: - # new training without pretraine weights, do nothing - pass - - trainer.run_training() - else: - # if valbest: - # trainer.load_best_checkpoint(train=False) - # else: - # trainer.load_final_checkpoint(train=False) - trainer.load_latest_checkpoint() - trainer.network.eval() - - # if fold == "all": - # print("--> fold == 'all'") - # print("--> DONE") - # else: - # # predict validation - # trainer.validate( - # save_softmax=args.npz, - # validation_folder_name=val_folder, - # run_postprocessing_on_folds=not disable_postprocessing_on_folds, - # overwrite=args.val_disable_overwrite, - # ) - - # if network == "3d_lowres" and not args.disable_next_stage_pred: - # print("predicting segmentations for the next stage of the cascade") - # predict_next_stage( - # trainer, - # join( - # dataset_directory, - # trainer.plans["data_identifier"] + "_stage%d" % 1, - # ), - # ) + trainer.max_num_epochs = fl_round + 1 + trainer.epoch = fl_round + + # infer total data size and batch size in order to get how many batches to apply so that over many epochs, each data + # point is expected to be seen epochs number of times + + num_val_batches_per_epoch = int(np.ceil(len(trainer.dataset_val)/trainer.batch_size)) + num_train_batches_per_epoch = int(np.ceil(len(trainer.dataset_tr)/trainer.batch_size)) + + # the nnunet trainer attributes have a different naming convention than I am using + trainer.num_batches_per_epoch = num_train_batches_per_epoch + trainer.num_val_batches_per_epoch = num_val_batches_per_epoch + + batches_applied_train, \ + batches_applied_val, \ + this_ave_train_loss, \ + this_ave_val_loss, \ + this_val_eval_metrics, \ + this_val_eval_metrics_C1, \ + this_val_eval_metrics_C2, \ + this_val_eval_metrics_C3, \ + this_val_eval_metrics_C4 = trainer.run_training(train_cutoff=train_cutoff, + val_cutoff=val_cutoff, + val_epoch=val_epoch, + train_epoch=train_epoch) + + 
train_completed = batches_applied_train / float(num_train_batches_per_epoch) + val_completed = batches_applied_val / float(num_val_batches_per_epoch) + + return train_completed, \ + val_completed, \ + this_ave_train_loss, \ + this_ave_val_loss, \ + this_val_eval_metrics, \ + this_val_eval_metrics_C1, \ + this_val_eval_metrics_C2, \ + this_val_eval_metrics_C3, \ + this_val_eval_metrics_C4 + + diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index 26f184522..3191d9a57 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -2,11 +2,10 @@ # SPDX-License-Identifier: Apache-2.0 """ -Contributors: Micah Sheller, Patrick Foley, Brandon Edwards - DELETEME? +Contributors: Micah Sheller, Patrick Foley, Brandon Edwards """ # TODO: Clean up imports -# TODO: ask Micah if this has to be changed (most probably no) import os import subprocess @@ -36,12 +35,15 @@ class PyTorchNNUNetCheckpointTaskRunner(PyTorchCheckpointTaskRunner): def __init__(self, nnunet_task=None, config_path=None, + actual_max_num_epochs=1000, **kwargs): """Initialize. Args: - config_path(str) : Path to the configuration file used by the training and validation script. - kwargs : Additional key work arguments (will be passed to rebuild_model, initialize_tensor_key_functions, TODO: ). + nnunet_task (str) : Task string used to identify the data and model folders + config_path(str) : Path to the configuration file used by the training and validation script. + actual_max_num_epochs (int) : Number of epochs for which this collaborator's model will be trained, should match the total rounds of federation in which this runner is participating + kwargs : Additional key work arguments (will be passed to rebuild_model, initialize_tensor_key_functions, TODO: ). 
TODO: """ @@ -72,6 +74,14 @@ def __init__(self, ) self.config_path = config_path + self.actual_max_num_epochs=actual_max_num_epochs + + # self.task_completed is a dictionary of task to amount completed as a float in [0,1] + # Values will be dynamically updated + # TODO: Tasks are hard coded for now + self.task_completed = {'aggregated_model_validation': 1.0, + 'train': 1.0, + 'locally_tuned_model_validation': 1.0} def write_tensors_into_checkpoint(self, tensor_dict, with_opt_vars): @@ -108,11 +118,10 @@ def write_tensors_into_checkpoint(self, tensor_dict, with_opt_vars): # get device for correct placement of tensors device = self.device - checkpoint_dict = self.load_checkpoint(map_location=device) + checkpoint_dict = self.load_checkpoint(checkpoint_path=self.checkpoint_path_load, map_location=device) epoch = checkpoint_dict['epoch'] new_state = {} # grabbing keys from the checkpoint state dict, poping from the tensor_dict - # Brandon DEBUGGING seen_keys = [] for k in checkpoint_dict['state_dict']: if k not in seen_keys: @@ -134,91 +143,113 @@ def write_tensors_into_checkpoint(self, tensor_dict, with_opt_vars): return epoch - def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs): + def train(self, col_name, round_num, input_tensor_dict, epochs, val_cutoff_time=np.inf, train_cutoff_time=np.inf, train_completion_dampener=0.0, **kwargs): # TODO: Figure out the right name to use for this method and the default assigner """Perform training for a specified number of epochs.""" self.rebuild_model(input_tensor_dict=input_tensor_dict, **kwargs) # 1. Insert tensor_dict info into checkpoint - current_epoch = self.set_tensor_dict(tensor_dict=input_tensor_dict, with_opt_vars=False) - # 2. 
Train function existing externally - # Some todo inside function below - # TODO: test for off-by-one error - # TODO: we need to disable validation if possible, and separately call validation - - # FIXME: we need to understand how to use round_num instead of current_epoch - # this will matter in straggler handling cases - # TODO: Should we put this in a separate process? - train_nnunet(epochs=epochs, current_epoch=current_epoch, task=self.data_loader.get_task_name()) - - # 3. Load metrics from checkpoint - (all_tr_losses, all_val_losses, all_val_losses_tr_mode, all_val_eval_metrics) = self.load_checkpoint()['plot_stuff'] - # these metrics are appended to the checkopint each epoch, so we select the most recent epoch - metrics = {'train_loss': all_tr_losses[-1], - 'val_eval': all_val_eval_metrics[-1]} - - return self.convert_results_to_tensorkeys(col_name, round_num, metrics) - - - - def validate(self, col_name, round_num, input_tensor_dict, **kwargs): + self.set_tensor_dict(tensor_dict=input_tensor_dict, with_opt_vars=False) + self.logger.info(f"Training for round:{round_num}") + train_completed, \ + val_completed, \ + this_ave_train_loss, \ + this_ave_val_loss, \ + this_val_eval_metrics, \ + this_val_eval_metrics_C1, \ + this_val_eval_metrics_C2, \ + this_val_eval_metrics_C3, \ + this_val_eval_metrics_C4 = train_nnunet(actual_max_num_epochs=self.actual_max_num_epochs, + fl_round=round_num, + train_cutoff=train_cutoff_time, + val_cutoff = val_cutoff_time, + task=self.data_loader.get_task_name(), + val_epoch=True, + train_epoch=True) + + # dampen the train_completion """ - Run the trained model on validation data; report results. - - Parameters - ---------- - input_tensor_dict : either the last aggregated or locally trained model - - Returns - ------- - output_tensor_dict : {TensorKey: nparray} (these correspond to acc, - precision, f1_score, etc.) 
+ values in range: (0, 1] with values near 0.0 making all train_completion rates shift nearer to 1.0, thus making the + trained model update weighting during aggregation stay closer to the plain data size weighting + specifically, update_weight = train_data_size / train_completed**train_completion_dampener """ + train_completed = train_completed**train_completion_dampener - raise NotImplementedError() - - """ - TBD - for now commenting out + # update amount of task completed + self.task_completed['train'] = train_completed + self.task_completed['locally_tuned_model_validation'] = val_completed - self.rebuild_model(round_num, input_tensor_dict, validation=True) + # 3. Prepare metrics + metrics = {'train_loss': this_ave_train_loss} - # 1. Save model in native format - self.save_native(self.mlcube_model_in_path) + global_tensor_dict, local_tensor_dict = self.convert_results_to_tensorkeys(col_name, round_num, metrics, insert_model=True) - # 2. Call MLCube validate task - platform_yaml = os.path.join(self.mlcube_dir, 'platforms', '{}.yaml'.format(self.mlcube_runner_type)) - task_yaml = os.path.join(self.mlcube_dir, 'run', 'evaluate.yaml') - proc = subprocess.run(["mlcube_docker", - "run", - "--mlcube={}".format(self.mlcube_dir), - "--platform={}".format(platform_yaml), - "--task={}".format(task_yaml)]) - - # 3. Load any metrics - metrics = self.load_metrics(os.path.join(self.mlcube_dir, 'workspace', 'metrics', 'evaluate_metrics.json')) - - # set the validation data size - sample_count = int(metrics.pop(self.evaluation_sample_count_key)) - self.data_loader.set_valid_data_size(sample_count) + self.logger.info(f"Completed train/val with {int(train_completed*100)}% of the train work and {int(val_completed*100)}% of the val work. Exact rates are: {train_completed} and {val_completed}") + + return global_tensor_dict, local_tensor_dict + - # 4. 
Convert to tensorkeys - - origin = col_name - suffix = 'validate' - if kwargs['apply'] == 'local': - suffix += '_local' + def validate(self, col_name, round_num, input_tensor_dict, val_cutoff_time=np.inf, from_checkpoint=False, **kwargs): + # TODO: Figure out the right name to use for this method and the default assigner + """Perform validation.""" + + if not from_checkpoint: + self.rebuild_model(input_tensor_dict=input_tensor_dict, **kwargs) + # 1. Insert tensor_dict info into checkpoint + self.set_tensor_dict(tensor_dict=input_tensor_dict, with_opt_vars=False) + self.logger.info(f"Validating for round:{round_num}") + # 2. Train/val function existing externally + # Some todo inside function below + train_completed, \ + val_completed, \ + this_ave_train_loss, \ + this_ave_val_loss, \ + this_val_eval_metrics, \ + this_val_eval_metrics_C1, \ + this_val_eval_metrics_C2, \ + this_val_eval_metrics_C3, \ + this_val_eval_metrics_C4 = train_nnunet(actual_max_num_epochs=self.actual_max_num_epochs, + fl_round=round_num, + train_cutoff=0, + val_cutoff = val_cutoff_time, + task=self.data_loader.get_task_name(), + val_epoch=True, + train_epoch=False) + # double check + if train_completed != 0.0: + raise ValueError(f"Tried to validate only, but got a non-zero amount ({train_completed}) of training done.") + + # update amount of task completed + self.task_completed['aggregated_model_validation'] = val_completed + + self.logger.info(f"Completed train/val with {int(train_completed*100)}% of the train work and {int(val_completed*100)}% of the val work. Exact rates are: {train_completed} and {val_completed}") + + + # 3. 
Prepare metrics + metrics = {'val_eval': this_val_eval_metrics, + 'val_eval_C1': this_val_eval_metrics_C1, + 'val_eval_C2': this_val_eval_metrics_C2, + 'val_eval_C3': this_val_eval_metrics_C3, + 'val_eval_C4': this_val_eval_metrics_C4} else: - suffix += '_agg' - tags = ('metric', suffix) - output_tensor_dict = { - TensorKey( - metric_name, origin, round_num, True, tags - ): np.array(metrics[metric_name]) - for metric_name in metrics - } - - return output_tensor_dict, {} - - """ + checkpoint_dict = self.load_checkpoint(checkpoint_path=self.checkpoint_path_load) + + all_tr_losses, \ + all_val_losses, \ + all_val_losses_tr_mode, \ + all_val_eval_metrics, \ + all_val_eval_metrics_C1, \ + all_val_eval_metrics_C2, \ + all_val_eval_metrics_C3, \ + all_val_eval_metrics_C4 = checkpoint_dict['plot_stuff'] + # these metrics are appended to the checkpoint each call to train, so it is critical that we are grabbing this right after + metrics = {'val_eval': all_val_eval_metrics[-1], + 'val_eval_C1': all_val_eval_metrics_C1[-1], + 'val_eval_C2': all_val_eval_metrics_C2[-1], + 'val_eval_C3': all_val_eval_metrics_C3[-1], + 'val_eval_C4': all_val_eval_metrics_C4[-1]} + + return self.convert_results_to_tensorkeys(col_name, round_num, metrics, insert_model=False) def load_metrics(self, filepath): @@ -230,4 +261,38 @@ def load_metrics(self, filepath): with open(filepath) as json_file: metrics = json.load(json_file) return metrics - """ \ No newline at end of file + """ + + + def get_train_data_size(self, task_name=None): + """Get the number of training examples. + + It will be used for weighted averaging in aggregation. + This overrides the parent class method, + allowing dynamic weighting by storing recent appropriate weights in class attributes. 
+ + Returns: + int: The number of training examples, weighted by how much of the task got completed, then cast to int to satisfy proto schema + """ + if not task_name: + return self.data_loader.get_train_data_size() + else: + # self.task_completed is a dictionary of task_name to amount completed as a float in [0,1] + return int(np.ceil(self.task_completed[task_name]**(-1) * self.data_loader.get_train_data_size())) + + + def get_valid_data_size(self, task_name=None): + """Get the number of validation examples. + + It will be used for weighted averaging in aggregation. + This overrides the parent class method, + allowing dynamic weighting by storing recent appropriate weights in class attributes. + + Returns: + int: The number of validation examples, weighted by how much of the task got completed, then cast to int to satisfy proto schema + """ + if not task_name: + return self.data_loader.get_valid_data_size() + else: + # self.task_completed is a dictionary of task_name to amount completed as a float in [0,1] + return int(np.ceil(self.task_completed[task_name]**(-1) * self.data_loader.get_valid_data_size())) diff --git a/examples/fl_post/fl/project/src/runner_pt_chkpt.py b/examples/fl_post/fl/project/src/runner_pt_chkpt.py index 6ab7851b9..a7fbd2056 100644 --- a/examples/fl_post/fl/project/src/runner_pt_chkpt.py +++ b/examples/fl_post/fl/project/src/runner_pt_chkpt.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 """ -Contributors: Micah Sheller, Patrick Foley, Brandon Edwards - DELETEME? +Contributors: Micah Sheller, Patrick Foley, Brandon Edwards """ # TODO: Clean up imports @@ -82,12 +82,10 @@ def __init__(self, self.replace_checkpoint(self.checkpoint_path_initial) - def load_checkpoint(self, checkpoint_path=None, map_location=None): + def load_checkpoint(self, checkpoint_path, map_location=None): """ Function used to load checkpoint from disk. 
""" - if not checkpoint_path: - checkpoint_path = self.checkpoint_path_load checkpoint_dict = torch.load(checkpoint_path, map_location=map_location) return checkpoint_dict @@ -124,7 +122,7 @@ def get_required_tensorkeys_for_function(self, func_name, **kwargs): return self.required_tensorkeys_for_function[func_name] def reset_opt_vars(self): - current_checkpoint_dict = self.load_checkpoint() + current_checkpoint_dict = self.load_checkpoint(checkpoint_path=self.checkpoint_path_load) initial_checkpoint_dict = self.load_checkpoint(checkpoint_path=self.checkpoint_path_initial) derived_opt_state_dict = self._get_optimizer_state(checkpoint_dict=initial_checkpoint_dict) self._set_optimizer_state(derived_opt_state_dict=derived_opt_state_dict, @@ -172,7 +170,7 @@ def read_tensors_from_checkpoint(self, with_opt_vars): dict: Tensor dictionary {**dict, **optimizer_dict} """ - checkpoint_dict = self.load_checkpoint() + checkpoint_dict = self.load_checkpoint(checkpoint_path=self.checkpoint_path_load) state = to_cpu_numpy(checkpoint_dict['state_dict']) if with_opt_vars: opt_state = self._get_optimizer_state(checkpoint_dict=checkpoint_dict) @@ -255,7 +253,9 @@ def _read_opt_state_from_checkpoint(self, checkpoint_dict): return derived_opt_state_dict - def convert_results_to_tensorkeys(self, col_name, round_num, metrics): + def convert_results_to_tensorkeys(self, col_name, round_num, metrics, insert_model): + # insert_model determined whether or not to include the model in the return dictionaries + # 5. 
Convert to tensorkeys # output metric tensors (scalar) @@ -268,11 +268,14 @@ def convert_results_to_tensorkeys(self, col_name, round_num, metrics): metrics[metric_name] ) for metric_name in metrics} - # output model tensors (Doesn't include TensorKey) - output_model_dict = self.get_tensor_dict(with_opt_vars=True) - global_model_dict, local_model_dict = split_tensor_dict_for_holdouts(logger=self.logger, - tensor_dict=output_model_dict, - **self.tensor_dict_split_fn_kwargs) + if insert_model: + # output model tensors (Doesn't include TensorKey) + output_model_dict = self.get_tensor_dict(with_opt_vars=True) + global_model_dict, local_model_dict = split_tensor_dict_for_holdouts(logger=self.logger, + tensor_dict=output_model_dict, + **self.tensor_dict_split_fn_kwargs) + else: + global_model_dict, local_model_dict = {}, {} # create global tensorkeys global_tensorkey_model_dict = {