
Commit

initial changes to set seeds for nnunet data setup train/val split
brandon-edwards committed Nov 7, 2024
1 parent c413bf6 commit 369f79e
Showing 1 changed file with 20 additions and 12 deletions.
32 changes: 20 additions & 12 deletions examples/fl_post/fl/project/nnunet_data_setup.py
@@ -115,14 +115,17 @@ def doublecheck_postopp_pardir(postopp_pardir, verbose=False):
raise ValueError(f"'labels' must be a subdirectory of postopp_src_pardir:{postopp_pardir}, but it is not.")


def split_by_subject(subject_to_timestamps, percent_train, verbose=False):
def split_by_subject(subject_to_timestamps, percent_train, split_seed, verbose=False):
"""
NOTE: An attempt is made to put percent_train of the total subjects into train (as opposed to val) regardless of how many timestamps there are for each subject.
No subject is allowed to have samples in both train and val.
"""

subjects = list(subject_to_timestamps.keys())
np.random.shuffle(subjects)

# create a random number generator with our seed
rng = np.random.default_rng(split_seed)
rng.shuffle(subjects)

train_cutoff = int(len(subjects) * percent_train)

@@ -132,7 +135,7 @@ def split_by_subject(subject_to_timestamps, percent_train, verbose=False):
return train_subject_to_timestamps, val_subject_to_timestamps
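
For reference, a minimal sketch of what the newly seeded subject-level split does (the subject names, timestamps, and percentages below are invented for illustration and are not from the repository): seeding a dedicated generator with split_seed makes the shuffle, and therefore the train/val assignment, identical on every run.

    import numpy as np

    subject_to_timestamps = {'sub_A': ['2019', '2020'], 'sub_B': ['2021'], 'sub_C': ['2018']}
    split_seed = 2468
    percent_train = 0.8

    subjects = list(subject_to_timestamps.keys())
    rng = np.random.default_rng(split_seed)  # same seed -> same permutation every run
    rng.shuffle(subjects)

    train_cutoff = int(len(subjects) * percent_train)
    train = {s: subject_to_timestamps[s] for s in subjects[:train_cutoff]}
    val = {s: subject_to_timestamps[s] for s in subjects[train_cutoff:]}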


def split_by_timed_subjects(subject_to_timestamps, percent_train, random_tries=30, verbose=False):
def split_by_timed_subjects(subject_to_timestamps, percent_train, split_seed, random_tries=30, verbose=False):
"""
NOTE: An attempt is made to put percent_train of the subject timestamp combinations into train (as opposed to val) regardless of what that does to the subject ratios.
No subject is allowed to have samples in both train and val.
@@ -143,9 +146,11 @@ def percent_train_for_split(train_subjects, grand_total):
sub_total += subject_counts[subject]
return sub_total/grand_total

def shuffle_and_cut(subject_counts, grand_total, percent_train, verbose=False):
def shuffle_and_cut(subject_counts, grand_total, percent_train, seed, verbose=False):
subjects = list(subject_counts.keys())
np.random.shuffle(subjects)
# create a random number generator with our seed
rng = np.random.default_rng(seed)
rng.shuffle(subjects)
for idx in range(2,len(subjects)+1):
train_subjects = subjects[:idx-1]
val_subjects = subjects[idx-1:]
@@ -172,8 +177,9 @@ def shuffle_and_cut(subject_counts, grand_total, percent_train, verbose=False):
best_percent_train = percent_train_for_split(train_subjects=best_train_subjects, grand_total=grand_total)

# randomly shuffle <random_tries> times in order to find the closest we can get to honoring the percent_train requirement (train and val both need to be non-empty)
for _ in range(random_tries):
train_subjects, val_subjects, percent_train_estimate = shuffle_and_cut(subject_counts=subject_counts, grand_total=grand_total, percent_train=percent_train, verbose=verbose)
for _try in range(random_tries):
seed = split_seed + _try  # distinct but deterministic seed for this attempt
train_subjects, val_subjects, percent_train_estimate = shuffle_and_cut(subject_counts=subject_counts, grand_total=grand_total, percent_train=percent_train, seed=seed, verbose=verbose)
if abs(percent_train_estimate - percent_train) < abs(best_percent_train - percent_train):
best_train_subjects = train_subjects
best_val_subjects = val_subjects
@@ -185,16 +191,16 @@ def shuffle_and_cut(subject_counts, grand_total, percent_train, verbose=False):
return train_subject_to_timestamps, val_subject_to_timestamps
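
As a hedged illustration of the retry loop above (the counts and the simplified helper are hypothetical, not the repository's shuffle_and_cut): offsetting the base split_seed by the try index gives each attempt a distinct yet fully deterministic shuffle, so rerunning the search reproduces the same best split.

    import numpy as np

    def shuffled_subjects(subject_counts, seed):
        # Simplified stand-in for the shuffle step inside shuffle_and_cut.
        subjects = list(subject_counts.keys())
        rng = np.random.default_rng(seed)
        rng.shuffle(subjects)
        return subjects

    subject_counts = {'sub_A': 2, 'sub_B': 1, 'sub_C': 3}
    split_seed = 2468
    for _try in range(3):
        # split_seed + _try: a different order per attempt, identical across reruns.
        print(_try, shuffled_subjects(subject_counts, seed=split_seed + _try))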


def write_splits_file(subject_to_timestamps, percent_train, split_logic, fold, task, splits_fname='splits_final.pkl', verbose=False):
def write_splits_file(subject_to_timestamps, percent_train, split_logic, split_seed, fold, task, splits_fname='splits_final.pkl', verbose=False):
# double check we are in the right folder to modify the splits file
splits_fpath = os.path.join(os.environ['nnUNet_preprocessed'], f'{task}', splits_fname)
POSTOPP_splits_fpath = os.path.join(os.environ['nnUNet_preprocessed'], f'{task}', 'POSTOPP_BACKUP_' + splits_fname)

# now split
if split_logic == 'by_subject':
train_subject_to_timestamps, val_subject_to_timestamps = split_by_subject(subject_to_timestamps=subject_to_timestamps, percent_train=percent_train, verbose=verbose)
train_subject_to_timestamps, val_subject_to_timestamps = split_by_subject(subject_to_timestamps=subject_to_timestamps, percent_train=percent_train, split_seed=split_seed, verbose=verbose)
elif split_logic == 'by_subject_time_pair':
train_subject_to_timestamps, val_subject_to_timestamps = split_by_timed_subjects(subject_to_timestamps=subject_to_timestamps, percent_train=percent_train, verbose=verbose)
train_subject_to_timestamps, val_subject_to_timestamps = split_by_timed_subjects(subject_to_timestamps=subject_to_timestamps, percent_train=percent_train, split_seed=split_seed, verbose=verbose)
else:
raise ValueError(f"Only split_logic values of 'by_subject' and 'by_subject_time_pair' are supported, but {split_logic} was provided.")
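
A quick reproducibility check, assuming the module is importable as nnunet_data_setup from this project directory (the test data here is invented): two calls with the same split_seed must return identical train/val dictionaries.

    from nnunet_data_setup import split_by_subject

    data = {'sub_A': ['t0'], 'sub_B': ['t0', 't1'], 'sub_C': ['t0']}
    first = split_by_subject(subject_to_timestamps=data, percent_train=0.8, split_seed=2468)
    second = split_by_subject(subject_to_timestamps=data, percent_train=0.8, split_seed=2468)
    assert first == second  # same seed -> same split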

@@ -235,6 +241,7 @@ def setup_fl_data(postopp_pardir,
init_model_info_path,
cuda_device,
overwrite_nnunet_datadirs,
split_seed=2468,
plans_path=None,
verbose=False):
"""
@@ -301,7 +308,6 @@
three_digit_task_num(str): Should start with '5'. If num_institutions == N (see below), all N task numbers starting with this number will be used.
task_name(str) : Any string task name.
percent_train(float) : What percent of the data is put into the training split (the rest goes to val)
split_logic(str) : Determines how train/val split is performed
timestamp_selection(str) : Indicates how to determine the timestamp to pick
for each subject ID at the source: 'latest', 'earliest', and 'all' are the only ones supported so far
network(str) : Which network is being used for NNUnet
@@ -312,6 +318,7 @@
init_model_info_path(str) : Path to the initial model info (pkl) file
cuda_device(str) : Device to perform training ('cpu' or 'cuda')
overwrite_nnunet_datadirs(bool) : Allows overwriting past instances of NNUnet data directories that use the task numbers from first_three_digit_task_num up to that number plus one less than the number of institutions.
split_seed(int) : Seed for the random number generator used within the split logic
plans_path(str) : Path to the training plans (pkl)
percent_train(float) : What percentage of timestamped subjects to attempt to dedicate to train versus val. This will only be achieved approximately in general, since
all timestamps associated with the same subject need to land exclusively in either train or val.
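
A small worked example of why percent_train can only be honored approximately (the subject counts are made up): with three subjects holding 2, 1, and 3 timestamps, no subject-level cut lands exactly on a 0.8 train fraction, so the code settles for the closest achievable split.

    subject_counts = {'sub_A': 2, 'sub_B': 1, 'sub_C': 3}
    grand_total = sum(subject_counts.values())  # 6 subject/timestamp pairs
    for train_subjects in (['sub_A', 'sub_C'], ['sub_B', 'sub_C'], ['sub_A', 'sub_B']):
        realized = sum(subject_counts[s] for s in train_subjects) / grand_total
        print(train_subjects, round(realized, 3))  # 0.833, 0.667, 0.5 -- none is exactly 0.8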
@@ -367,7 +374,8 @@ def setup_fl_data(postopp_pardir,
# Now compute our own stratified splits file, keeping all timestamps for a given subject exclusively in either train or val
write_splits_file(subject_to_timestamps=subject_to_timestamps,
percent_train=percent_train,
split_logic=split_logic,
split_logic=split_logic,
split_seed=split_seed,
fold=fold,
task=task,
verbose=verbose)
