From 21ff5213bbc363906e1a9286c7283e7cca43acfc Mon Sep 17 00:00:00 2001 From: beckynevin Date: Tue, 7 May 2024 17:08:56 -0600 Subject: [PATCH] figured out how to get checkpoint --- src/analyze/analyze.py | 17 +- src/scripts/AleatoricComparison.py | 295 ++++++----------------------- src/utils/config.py | 6 +- src/utils/defaults.py | 41 +++- 4 files changed, 116 insertions(+), 243 deletions(-) diff --git a/src/analyze/analyze.py b/src/analyze/analyze.py index 67d8cdc..c99c0e2 100644 --- a/src/analyze/analyze.py +++ b/src/analyze/analyze.py @@ -23,8 +23,14 @@ class AggregateCheckpoints: # def load_final_checkpoints(): # def load_all_checkpoints(): # functions for loading model checkpoints - def load_DE_checkpoint( - self, model_name, nmodel, epoch, beta, device, path="models/checkpoints/" + def load_checkpoint( + self, + model_name, + noise, + nmodel, + epoch, + beta, device, + path="models/" ): """ Load PyTorch model checkpoint from a .pt file. @@ -36,8 +42,11 @@ def load_DE_checkpoint( :param model: PyTorch model to load the checkpoint into :return: Loaded model """ - file_name = path + f"{model_name}_beta_{beta}_nmodel_{nmodel}_epoch_{epoch}.pt" - checkpoint = torch.load(file_name, map_location=device) + if model_name[0:2] == "DE": + file_name = str(path) + "checkpoints/" + f"{model_name}_noise_{noise}_beta_{beta}_nmodel_{nmodel}_epoch_{epoch}.pt" + checkpoint = torch.load(file_name, map_location=device) + else: + STOP return checkpoint def ep_al_checkpoint_DE(checkpoint): diff --git a/src/scripts/AleatoricComparison.py b/src/scripts/AleatoricComparison.py index b5058f5..66b37f5 100644 --- a/src/scripts/AleatoricComparison.py +++ b/src/scripts/AleatoricComparison.py @@ -5,15 +5,13 @@ import torch from torch.utils.data import TensorDataset, DataLoader -# from scripts import train, models, io -from train import train -from models import models + from data import DataModules from models import ModelModules from utils.config import Config -from utils.defaults import DefaultsDE +from utils.defaults import DefaultsAnalysis, DefaultsDE from data.data import DataPreparation, MyDataLoader -# from analyze.analyze import AggregateCheckpoints +from analyze.analyze import AggregateCheckpoints # from plots import Plots @@ -32,74 +30,20 @@ def parse_args(): parser.add_argument( "--data_path", "-d", - default=DefaultsDE["data"]["data_path"], + default=DefaultsAnalysis["data"]["data_path"], ) parser.add_argument( "--data_engine", "-dl", - default=DefaultsDE["data"]["data_engine"], + default=DefaultsAnalysis["data"]["data_engine"], choices=DataModules.keys(), ) # model # path to save the model results - parser.add_argument("--out_dir", default=DefaultsDE["common"]["out_dir"]) - parser.add_argument( - "--model_engine", - "-e", - default=DefaultsDE["model"]["model_engine"], - choices=ModelModules.keys(), - ) - parser.add_argument( - "--size_df", - type=float, - required=False, - default=DefaultsDE["data"]["size_df"], - help="Used to load the associated .h5 data file", - ) - parser.add_argument( - "--noise_level", - type=str, - default=DefaultsDE["data"]["noise_level"], - choices=["low", "medium", "high", "vhigh"], - help="low, medium, high or vhigh, \ - used to look up associated sigma value", - ) - parser.add_argument( - "--normalize", - required=False, - action="store_true", - default=DefaultsDE["data"]["normalize"], - help="If true theres an option to normalize the dataset", - ) - parser.add_argument( - "--val_proportion", - type=float, - required=False, - default=DefaultsDE["data"]["val_proportion"], - help="Proportion of the dataset to use as validation", - ) - parser.add_argument( - "--randomseed", - type=int, - required=False, - default=DefaultsDE["data"]["randomseed"], - help="Random seed used for shuffling the training and validation set", - ) - parser.add_argument( - "--generatedata", - action="store_true", - default=DefaultsDE["data"]["generatedata"], - help="option to generate df, if not specified \ - default behavior is to load from file", - ) - parser.add_argument( - "--batchsize", - type=int, - required=False, - default=DefaultsDE["data"]["batchsize"], - help="Size of batched used in the traindataloader", - ) + parser.add_argument("--out_dir", + default=DefaultsAnalysis["common"]["out_dir"]) + # now args for model parser.add_argument( "--n_models", @@ -107,21 +51,6 @@ def parse_args(): default=DefaultsDE["model"]["n_models"], help="Number of MVEs in the ensemble", ) - parser.add_argument( - "--init_lr", - type=float, - required=False, - default=DefaultsDE["model"]["init_lr"], - help="Learning rate", - ) - parser.add_argument( - "--loss_type", - type=str, - required=False, - default=DefaultsDE["model"]["loss_type"], - help="Loss types for MVE, options are no_var_loss, var_loss, \ - and bnn_loss", - ) parser.add_argument( "--BETA", type=beta_type, @@ -132,59 +61,42 @@ def parse_args(): step_decrease_to_0.5, and step_decrease_to_1.0", ) parser.add_argument( - "--model_type", + "--noise_level_list", + type=str, + required=False, + default=DefaultsAnalysis["analysis"]["noise_level_list"], + help="Noise levels to compare", + ) + parser.add_argument( + "--model_names_list", type=str, required=False, - default=DefaultsDE["model"]["model_type"], + default=DefaultsAnalysis["analysis"]["model_names_list"], help="Beginning of name for saved checkpoints and figures", ) parser.add_argument( "--n_epochs", type=int, required=False, - default=DefaultsDE["model"]["n_epochs"], - help="number of epochs for each MVE", - ) - parser.add_argument( - "--save_all_checkpoints", - action="store_true", - default=DefaultsDE["model"]["save_all_checkpoints"], - help="option to save all checkpoints", - ) - parser.add_argument( - "--save_final_checkpoint", - action="store_true", # Set to True if argument is present - default=DefaultsDE["model"]["save_final_checkpoint"], - help="option to save the final epoch checkpoint for each ensemble", - ) - parser.add_argument( - "--overwrite_final_checkpoint", - action="store_true", - default=DefaultsDE["model"]["overwrite_final_checkpoint"], - help="option to overwite already saved checkpoints", + default=DefaultsAnalysis["analysis"]["n_epochs"], + help="number of epochs", ) parser.add_argument( "--plot", action="store_true", - default=DefaultsDE["model"]["plot"], + default=DefaultsAnalysis["analysis"]["plot"], help="option to plot in notebook", ) parser.add_argument( "--savefig", action="store_true", - default=DefaultsDE["model"]["savefig"], + default=DefaultsAnalysis["analysis"]["savefig"], help="option to save a figure of the true and predicted values", ) - parser.add_argument( - "--run_analysis", - action="store_true", - default=DefaultsDE["analysis"]["run_analysis"], - help="option to run analysis on saved checkpoints", - ) parser.add_argument( "--verbose", action="store_true", - default=DefaultsDE["model"]["verbose"], + default=DefaultsAnalysis["analysis"]["verbose"], help="verbose option for train", ) args = parser.parse_args() @@ -194,7 +106,7 @@ def parse_args(): config = Config(args.config) else: - temp_config = DefaultsDE["common"]["temp_config"] + temp_config = DefaultsAnalysis["common"]["temp_config"] print( "Reading settings from cli and default, \ dumping to temp config: ", @@ -206,31 +118,18 @@ def parse_args(): # if not, default is from DefaultsDE dictionary input_yaml = { "common": {"out_dir": args.out_dir}, - "model": { - "model_engine": args.model_engine, - "model_type": args.model_type, - "loss_type": args.loss_type, - "n_models": args.n_models, - "init_lr": args.init_lr, - "BETA": args.BETA, - "n_epochs": args.n_epochs, - "save_all_checkpoints": args.save_all_checkpoints, - "save_final_checkpoint": args.save_final_checkpoint, - "overwrite_final_checkpoint": args.overwrite_final_checkpoint, - "plot": args.plot, - "savefig": args.savefig, - "verbose": args.verbose, - }, "data": { "data_path": args.data_path, "data_engine": args.data_engine, - "size_df": args.size_df, - "noise_level": args.noise_level, - "val_proportion": args.val_proportion, - "randomseed": args.randomseed, - "batchsize": args.batchsize, }, - "analysis": {"run_analysis": args.run_analysis} + "model": {"n_models": args.n_models, + "BETA": args.BETA}, + "analysis": {"noise_level_list": args.noise_level_list, + "model_names_list": args.model_names_list, + "n_epochs": args.n_epochs, + "plot": args.plot, + "savefig": args.savefig, + "verbose": args.verbose,} # "plots": {key: {} for key in args.plots}, # "metrics": {key: {} for key in args.metrics}, } @@ -241,7 +140,6 @@ def parse_args(): return config # return parser.parse_args() - def beta_type(value): if isinstance(value, float): return value @@ -258,108 +156,35 @@ def beta_type(value): ) + if __name__ == "__main__": config = parse_args() - size_df = int(config.get_item("data", "size_df", "DE")) - noise = config.get_item("data", "noise_level", "DE") - norm = config.get_item("data", "normalize", "DE", raise_exception=False) - val_prop = config.get_item("data", "val_proportion", "DE") - rs = config.get_item("data", "randomseed", "DE") - BATCH_SIZE = config.get_item("data", "batchsize", "DE") - sigma = DataPreparation.get_sigma(noise) - path_to_data = config.get_item("data", "data_path", "DE") - if config.get_item("data", "generatedata", "DE", raise_exception=False): - # generate the df - data = DataPreparation() - data.sample_params_from_prior(size_df) - data.simulate_data(data.params, sigma, "linear_homogeneous") - df_array = data.get_dict() - # Convert non-tensor entries to tensors - df = {} - for key, value in df_array.items(): - - if isinstance(value, TensorDataset): - # Keep tensors as they are - df[key] = value - else: - # Convert lists to tensors - df[key] = torch.tensor(value) - else: - loader = MyDataLoader() - df = loader.load_data_h5( - "linear_sigma_" + str(sigma) + "_size_" + str(size_df), - path=path_to_data, - ) - len_df = len(df["params"][:, 0].numpy()) - len_x = len(df["inputs"].numpy()) - ms_array = np.repeat(df["params"][:, 0].numpy(), len_x) - bs_array = np.repeat(df["params"][:, 1].numpy(), len_x) - xs_array = np.tile(df["inputs"].numpy(), len_df) - ys_array = np.reshape(df["output"].numpy(), (len_df * len_x)) - - inputs = np.array([xs_array, ms_array, bs_array]).T - model_inputs, model_outputs = DataPreparation.normalize(inputs, - ys_array, - norm) - x_train, x_val, y_train, y_val = DataPreparation.train_val_split( - model_inputs, model_outputs, val_proportion=val_prop, random_state=rs - ) - trainData = TensorDataset(torch.Tensor(x_train), torch.Tensor(y_train)) - trainDataLoader = DataLoader(trainData, - batch_size=BATCH_SIZE, - shuffle=True) - # set the device we will be using to train the model DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - model_name = config.get_item("model", - "model_type", - "DE") + "_noise_" + noise - model, lossFn = models.model_setup_DE( - config.get_item("model", "loss_type", "DE"), DEVICE - ) - print( - "save final checkpoint has this value", - config.get_item("model", "save_final_checkpoint", "DE"), - ) - model_ensemble = train.train_DE( - trainDataLoader, - x_val, - y_val, - config.get_item("model", "init_lr", "DE"), - DEVICE, - config.get_item("model", "loss_type", "DE"), - config.get_item("model", "n_models", "DE"), - model_name, - BETA=config.get_item("model", "BETA", "DE"), - EPOCHS=config.get_item("model", "n_epochs", "DE"), - path_to_model=config.get_item("common", "out_dir", "DE"), - save_all_checkpoints=config.get_item("model", - "save_all_checkpoints", - "DE"), - save_final_checkpoint=config.get_item("model", - "save_final_checkpoint", - "DE"), - overwrite_final_checkpoint=config.get_item( - "model", "overwrite_final_checkpoint", "DE" - ), - plot=config.get_item("model", "plot", "DE"), - savefig=config.get_item("model", "savefig", "DE"), - verbose=config.get_item("model", "verbose", "DE"), - ) - ''' - if config.get_item("analysis", "run_analysis", "DE"): - # now run the analysis on the resulting checkpoints - chk_module = AggregateCheckpoints() - print('n_models', config.get_item("model", "n_models", "DE")) - print('n_epochs', config.get_item("model", "n_epochs", "DE")) - for nmodel in range(config.get_item("model", "n_models", "DE")): - for epoch in range(config.get_item("model", "n_epochs", "DE")): - chk = chk_module.load_DE_checkpoint( - model_name, - nmodel, - epoch, - config.get_item("model", "BETA", "DE"), - DEVICE) - # things to grab: 'valid_mse' and 'valid_bnll' - print(chk) - ''' + noise_list = config.get_item("analysis", "noise_level_list", "Analysis") + sigma_list = [] + for noise in noise_list: + sigma_list.append(DataPreparation.get_sigma(noise)) + print('noise list', noise_list) + print('sigma list', sigma_list) + path_to_chk = config.get_item("common", "out_dir", "Analysis") + model_name_list = config.get_item("analysis", "model_names_list", "Analysis") + for noise in noise_list: + for model in model_name_list: + # now run the analysis on the resulting checkpoints + chk_module = AggregateCheckpoints() + print('n_models', config.get_item("model", "n_models", "DE")) + print('n_epochs', config.get_item("analysis", "n_epochs", "Analysis")) + for nmodel in range(config.get_item("model", "n_models", "DE")): + for epoch in range(config.get_item("analysis", "n_epochs", "Analysis")): + chk = chk_module.load_checkpoint( + model, + noise, + nmodel, + epoch, + config.get_item("model", "BETA", "DE"), + DEVICE, + ) + #path=path_to_chk) + # things to grab: 'valid_mse' and 'valid_bnll' + print(chk) + diff --git a/src/utils/config.py b/src/utils/config.py index fc45114..43c450c 100644 --- a/src/utils/config.py +++ b/src/utils/config.py @@ -57,7 +57,8 @@ def get_item(self, section, item, defaulttype, raise_exception=True): else: return { "DER": DefaultsDER, - "DE": DefaultsDE + "DE": DefaultsDE, + "Analysis": DefaultsAnalysis }[defaulttype][section][item] def get_section(self, section, defaulttype, raise_exception=True): @@ -69,5 +70,6 @@ def get_section(self, section, defaulttype, raise_exception=True): else: return { "DER": DefaultsDER, - "DE": DefaultsDE + "DE": DefaultsDE, + "Analysis": DefaultsAnalysis }[defaulttype][section] diff --git a/src/utils/defaults.py b/src/utils/defaults.py index 67a9b6a..49a7707 100644 --- a/src/utils/defaults.py +++ b/src/utils/defaults.py @@ -39,7 +39,6 @@ "line_style_cycle": ["-", "-."], "figure_size": [6, 6], }, - "analysis": {"run_analysis": False}, "plots": {"CDFRanks": {}, "Ranks": {"num_bins": None}, "CoverageFraction": {}}, @@ -84,7 +83,6 @@ "savefig": False, "verbose": False, }, - "analysis": {"run_analysis": False}, "plots_common": { "axis_spines": False, "tight_layout": True, @@ -108,3 +106,42 @@ "CoverageFraction": {}, }, } +DefaultsAnalysis = { + "common": { + "out_dir": "./DeepUQResources/results/", + "temp_config": "./DeepUQResources/temp/temp_config_analysis.yml", + }, + "data": { + "data_path": "./data/", + "data_engine": "DataLoader", + }, + "plots_common": { + "axis_spines": False, + "tight_layout": True, + "default_colorway": "viridis", + "plot_style": "fast", + "parameter_labels": ["$m$", "$b$"], + "parameter_colors": ["#9C92A3", "#0F5257"], + "line_style_cycle": ["-", "-."], + "figure_size": [6, 6], + }, + "analysis": {"run_analysis": False, + "noise_level_list": ["low", "medium", "high"], + "model_names_list": ["DE_desiderata_2", "DER"], + "n_epochs": 100, + "plot": False, + "savefig": False, + "verbose": False}, + "plots": {"CDFRanks": {}, + "Ranks": {"num_bins": None}, + "CoverageFraction": {}}, + "metrics_common": { + "use_progress_bar": False, + "samples_per_inference": 1000, + "percentiles": [75, 85, 95], + }, + "metrics": { + "AllSBC": {}, + "CoverageFraction": {}, + }, +}