diff --git a/src/scripts/DeepEvidentialRegression.py b/src/scripts/DeepEvidentialRegression.py index 605724c..73e31c5 100644 --- a/src/scripts/DeepEvidentialRegression.py +++ b/src/scripts/DeepEvidentialRegression.py @@ -19,7 +19,7 @@ def parse_args(): - parser = argparse.ArgumentParser(description="data handling module") + parser = argparse.ArgumentParser(description="Runs DER") # there are three options with the parser: # 1) Read from a yaml # 2) Reads from the command line and default file @@ -175,7 +175,20 @@ def parse_args(): "--rs", type=int, default=DefaultsDER["model"]["rs"], - help="define a random seed to save", + help="random seed for the pytorch model initialization", + ) + parser.add_argument( + "--save_n_hidden", + action="store_true", + default=DefaultsDER["model"]["save_n_hidden"], + help="save chk with the number of neurons in the hidden layer", + ) + parser.add_argument( + "--n_hidden", + type=int, + required=False, + default=DefaultsDER["model"]["n_hidden"], + help="Number of hidden neurons in the hidden layer, default 64", ) parser.add_argument( "--verbose", @@ -215,6 +228,8 @@ def parse_args(): "savefig": args.savefig, "save_chk_random_seed_init": args.save_chk_random_seed_init, "rs": args.rs, + "save_n_hidden": args.save_n_hidden, + "n_hidden": args.n_hidden, "verbose": args.verbose, }, "data": { @@ -295,7 +310,9 @@ def parse_args(): "model_type", "DER") + "_noise_" + noise model, lossFn = models.model_setup_DER( - config.get_item("model", "loss_type", "DER"), DEVICE + config.get_item("model", "loss_type", "DER"), + DEVICE, + n_hidden=config.get_item("model", "n_hidden", "DER") ) model_ensemble = train.train_DER( trainDataLoader, @@ -325,5 +342,9 @@ def parse_args(): "save_chk_random_seed_init", "DER"), rs=config.get_item("model", "rs", "DER"), + save_n_hidden=config.get_item("model", + "save_n_hidden", + "DER"), + n_hidden=config.get_item("model", "n_hidden", "DER"), verbose=config.get_item("model", "verbose", "DER"), ) diff --git a/src/train/train.py b/src/train/train.py index e579267..1aea423 100644 --- a/src/train/train.py +++ b/src/train/train.py @@ -32,7 +32,10 @@ def train_DER( savefig=True, set_and_save_rs=False, rs=42, + save_n_hidden=False, + n_hidden=64, verbose=True, + ): # first determine if you even need to run anything if not save_all_checkpoints and save_final_checkpoint: @@ -82,7 +85,9 @@ def train_DER( set_random_seeds(seed_value=rs) best_loss = np.inf # init to infinity - model, lossFn = models.model_setup_DER(loss_type, DEVICE) + model, lossFn = models.model_setup_DER(loss_type, + DEVICE, + n_hidden=n_hidden) if verbose: print("model is", model, "lossfn", lossFn) @@ -252,62 +257,64 @@ def train_DER( plt.show() plt.close() if save_all_checkpoints: + + filename = ( + str(path_to_model) + + "checkpoints/" + + str(model_name) + + "_loss_" + + str(loss_type) + + "_COEFF_" + + str(COEFF) + + "_epoch_" + + str(epoch) + ) + if set_and_save_rs: - torch.save( - { - "epoch": epoch, - "model_state_dict": model.state_dict(), - "optimizer_state_dict": opt.state_dict(), - "train_loss": np.mean(loss_this_epoch), - "valid_loss": NIGloss_val, - "valid_mse": mse, - "med_u_al_validation": med_u_al_val, - "med_u_ep_validation": med_u_ep_val, - "std_u_al_validation": std_u_al_val, - "std_u_ep_validation": std_u_ep_val, - }, - str(path_to_model) - + "checkpoints/" - + str(model_name) - + "_loss_" - + str(loss_type) - + "_COEFF_" - + str(COEFF) - + "_epoch_" - + str(epoch) - + "_rs_" - + str(rs) - + ".pt", - ) - else: - torch.save( - { - "epoch": epoch, - "model_state_dict": model.state_dict(), - "optimizer_state_dict": opt.state_dict(), - "train_loss": np.mean(loss_this_epoch), - "valid_loss": NIGloss_val, - "valid_mse": mse, - "med_u_al_validation": med_u_al_val, - "med_u_ep_validation": med_u_ep_val, - "std_u_al_validation": std_u_al_val, - "std_u_ep_validation": std_u_ep_val, - }, - str(path_to_model) - + "checkpoints/" - + str(model_name) - + "_loss_" - + str(loss_type) - + "_COEFF_" - + str(COEFF) - + "_epoch_" - + str(epoch) - + ".pt", - ) + filename += "_rs_" + str(rs) + + if save_n_hidden: + filename += "_n_hidden_" + str(n_hidden) + + filename += ".pt" + torch.save( + { + "epoch": epoch, + "model_state_dict": model.state_dict(), + "optimizer_state_dict": opt.state_dict(), + "train_loss": np.mean(loss_this_epoch), + "valid_loss": NIGloss_val, + "valid_mse": mse, + "med_u_al_validation": med_u_al_val, + "med_u_ep_validation": med_u_ep_val, + "std_u_al_validation": std_u_al_val, + "std_u_ep_validation": std_u_ep_val, + }, + filename + ) + if save_final_checkpoint and (e % (EPOCHS - 1) == 0) and (e != 0): - # option to just save final epoch + filename = ( + str(path_to_model) + + "checkpoints/" + + str(model_name) + + "_loss_" + + str(loss_type) + + "_COEFF_" + + str(COEFF) + + "_epoch_" + + str(epoch) + ) + if set_and_save_rs: - torch.save( + filename += "_rs_" + str(rs) + + if save_n_hidden: + filename += "_n_hidden_" + str(n_hidden) + + filename += ".pt" + # option to just save final epoch + torch.save( { "epoch": epoch, "model_state_dict": model.state_dict(), @@ -320,43 +327,8 @@ def train_DER( "std_u_al_validation": std_u_al_val, "std_u_ep_validation": std_u_ep_val, }, - str(path_to_model) - + "checkpoints/" - + str(model_name) - + "_loss_" - + str(loss_type) - + "_COEFF_" - + str(COEFF) - + "_epoch_" - + str(epoch) - + "_rs_" - + str(rs) - + ".pt", + filename ) - else: - torch.save( - { - "epoch": epoch, - "model_state_dict": model.state_dict(), - "optimizer_state_dict": opt.state_dict(), - "train_loss": np.mean(loss_this_epoch), - "valid_loss": NIGloss_val, - "valid_mse": mse, - "med_u_al_validation": med_u_al_val, - "med_u_ep_validation": med_u_ep_val, - "std_u_al_validation": std_u_al_val, - "std_u_ep_validation": std_u_ep_val, - }, - str(path_to_model) - + "checkpoints/" - + str(model_name) - + "_loss_" - + str(loss_type) - + "_COEFF_" - + str(COEFF) - + "_epoch_" - + str(epoch) - + ".pt") endTime = time.time() if verbose: print("start at", startTime, "end at", endTime) diff --git a/src/utils/defaults.py b/src/utils/defaults.py index fe52a4e..4ada583 100644 --- a/src/utils/defaults.py +++ b/src/utils/defaults.py @@ -71,6 +71,8 @@ "savefig": False, "save_chk_random_seed_init": False, "rs": 42, + "save_n_hidden": False, + "n_hidden": 64, "verbose": False, }, "plots_common": { @@ -99,8 +101,9 @@ "loss_type": "DER" }, "analysis": { - "noise_level_list": ["low", "medium", "high"], - "model_names_list": ["DER_wst", "DE_desiderata_2"], + "noise_level_list": ["low"],#, "medium", "high"], + "model_names_list": ["DER"], #, "DE_desiderata_2"], + # for the inits changed to "DER_wst" # ["DER_desiderata_2", "DE_desiderata_2"] "plot": True, "savefig": False,