From 369f79e019eea365e043c853b4e3ae3a17d87cf1 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Thu, 7 Nov 2024 11:45:44 -0800 Subject: [PATCH] initial changes to set seeds for nnunet data setup train/val split --- .../fl_post/fl/project/nnunet_data_setup.py | 32 ++++++++++++------- 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/examples/fl_post/fl/project/nnunet_data_setup.py b/examples/fl_post/fl/project/nnunet_data_setup.py index d36e18683..edc163cac 100644 --- a/examples/fl_post/fl/project/nnunet_data_setup.py +++ b/examples/fl_post/fl/project/nnunet_data_setup.py @@ -115,14 +115,17 @@ def doublecheck_postopp_pardir(postopp_pardir, verbose=False): raise ValueError(f"'labels' must be a subdirectory of postopp_src_pardir:{postopp_pardir}, but it is not.") -def split_by_subject(subject_to_timestamps, percent_train, verbose=False): +def split_by_subject(subject_to_timestamps, percent_train,split_seed, verbose=False): """ NOTE: An attempt is made to put percent_train of the total subjects into train (as opposed to val) regardless of how many timestamps there are for each subject. No subject is allowed to have samples in both train and val. """ subjects = list(subject_to_timestamps.keys()) - np.random.shuffle(subjects) + + # create a random number generator with our seed + rng = np.random.default_rng(split_seed) + rng.shuffle(subjects) train_cutoff = int(len(subjects) * percent_train) @@ -132,7 +135,7 @@ def split_by_subject(subject_to_timestamps, percent_train, verbose=False): return train_subject_to_timestamps, val_subject_to_timestamps -def split_by_timed_subjects(subject_to_timestamps, percent_train, random_tries=30, verbose=False): +def split_by_timed_subjects(subject_to_timestamps, percent_train, random_tries=30,split_seed, verbose=False): """ NOTE: An attempt is made to put percent_train of the subject timestamp combinations into train (as opposed to val) regardless of what that does to the subject ratios. No subject is allowed to have samples in both train and val. @@ -143,9 +146,11 @@ def percent_train_for_split(train_subjects, grand_total): sub_total += subject_counts[subject] return sub_total/grand_total - def shuffle_and_cut(subject_counts, grand_total, percent_train, verbose=False): + def shuffle_and_cut(subject_counts, grand_total, percent_train, seed, verbose=False): subjects = list(subject_counts.keys()) - np.random.shuffle(subjects) + # create a random number generator with our seed + rng = np.random.default_rng(seed) + rng.shuffle(subjects) for idx in range(2,len(subjects)+1): train_subjects = subjects[:idx-1] val_subjects = subjects[idx-1:] @@ -172,8 +177,9 @@ def shuffle_and_cut(subject_counts, grand_total, percent_train, verbose=False): best_percent_train = percent_train_for_split(train_subjects=best_train_subjects, grand_total=grand_total) # random shuffle times in order to find the closest we can get to honoring the percent_train requirement (train and val both need to be non-empty) - for _ in range(random_tries): - train_subjects, val_subjects, percent_train_estimate = shuffle_and_cut(subject_counts=subject_counts, grand_total=grand_total, percent_train=percent_train, verbose=verbose) + for _try in range(random_tries): + seed = split_seed + _try + train_subjects, val_subjects, percent_train_estimate = shuffle_and_cut(subject_counts=subject_counts, grand_total=grand_total, percent_train=percent_train, seed=seed, verbose=verbose) if abs(percent_train_estimate - percent_train) < abs(best_percent_train - percent_train): best_train_subjects = train_subjects best_val_subjects = val_subjects @@ -185,16 +191,16 @@ def shuffle_and_cut(subject_counts, grand_total, percent_train, verbose=False): return train_subject_to_timestamps, val_subject_to_timestamps -def write_splits_file(subject_to_timestamps, percent_train, split_logic, fold, task, splits_fname='splits_final.pkl', verbose=False): +def write_splits_file(subject_to_timestamps, percent_train, split_logic, split_seed, fold, task, splits_fname='splits_final.pkl', verbose=False): # double check we are in the right folder to modify the splits file splits_fpath = os.path.join(os.environ['nnUNet_preprocessed'], f'{task}', splits_fname) POSTOPP_splits_fpath = os.path.join(os.environ['nnUNet_preprocessed'], f'{task}', 'POSTOPP_BACKUP_' + splits_fname) # now split if split_logic == 'by_subject': - train_subject_to_timestamps, val_subject_to_timestamps = split_by_subject(subject_to_timestamps=subject_to_timestamps, percent_train=percent_train, verbose=verbose) + train_subject_to_timestamps, val_subject_to_timestamps = split_by_subject(subject_to_timestamps=subject_to_timestamps, percent_train=percent_train, split_seed=split_seed, verbose=verbose) elif split_logic == 'by_subject_time_pair': - train_subject_to_timestamps, val_subject_to_timestamps = split_by_timed_subjects(subject_to_timestamps=subject_to_timestamps, percent_train=percent_train, verbose=verbose) + train_subject_to_timestamps, val_subject_to_timestamps = split_by_timed_subjects(subject_to_timestamps=subject_to_timestamps, percent_train=percent_train, split_seed=split_seed, verbose=verbose) else: raise ValueError(f"Split logic of 'by_subject' and 'by_subject_time_pair' are the only ones supported, whereas a split_logic value of {split_logic} was provided.") @@ -235,6 +241,7 @@ def setup_fl_data(postopp_pardir, init_model_info_path, cuda_device, overwrite_nnunet_datadirs, + split_seed=2468, plans_path=None, verbose=False): """ @@ -301,7 +308,6 @@ def setup_fl_data(postopp_pardir, three_digit_task_num(str): Should start with '5'. If num_institutions == N (see below), all N task numbers starting with this number will be used. task_name(str) : Any string task name. percent_train(float) : what percent of data is put into the training data split (rest to val) - split_logic(str) : Determines how train/val split is performed timestamp_selection(str) : Indicates how to determine the timestamp to pick for each subject ID at the source: 'latest', 'earliest', and 'all' are the only ones supported so far network(str) : Which network is being used for NNUnet @@ -312,6 +318,7 @@ def setup_fl_data(postopp_pardir, init_model_info_path(str) : Path to the initial model info (pkl) file cuda_device(str) : Device to perform training ('cpu' or 'cuda') overwrite_nnunet_datadirs(bool) : Allows for overwriting past instances of NNUnet data directories using the task numbers from first_three_digit_task_num to that plus one less than number of insitutions. + split_seed (int) : Seed used for the random number generator used within the split logic plans_path(str) : Path to the training plans (pkl) percent_train(float) : What percentage of timestamped subjects to attempt dedicate to train versus val. Will be only approximately acheived in general since all timestamps associated with the same subject need to land exclusively in either train or val. @@ -367,7 +374,8 @@ def setup_fl_data(postopp_pardir, # Now compute our own stratified splits file, keeping all timestampts for a given subject exclusively in either train or val write_splits_file(subject_to_timestamps=subject_to_timestamps, percent_train=percent_train, - split_logic=split_logic, + split_logic=split_logic, + split_seed=split_seed, fold=fold, task=task, verbose=verbose)