From 614f6fe7be8242317424dcba2de890ca374cc81b Mon Sep 17 00:00:00 2001 From: beckynevin Date: Mon, 6 May 2024 09:32:28 -0600 Subject: [PATCH] temp data tests running with analysis keyword added --- src/scripts/DeepEnsemble.py | 40 ++++++++++++++++----------- src/utils/defaults.py | 2 ++ test/test_DeepEnsemble.py | 5 +++- test/test_DeepEvidentialRegression.py | 2 +- 4 files changed, 31 insertions(+), 18 deletions(-) diff --git a/src/scripts/DeepEnsemble.py b/src/scripts/DeepEnsemble.py index 43ee9fa..c8c876f 100644 --- a/src/scripts/DeepEnsemble.py +++ b/src/scripts/DeepEnsemble.py @@ -175,6 +175,12 @@ def parse_args(): default=DefaultsDE["model"]["savefig"], help="option to save a figure of the true and predicted values", ) + parser.add_argument( + "--run_analysis", + action="store_true", + default=DefaultsDE["analysis"]["run_analysis"], + help="option to run analysis on saved checkpoints", + ) parser.add_argument( "--verbose", action="store_true", @@ -224,6 +230,7 @@ def parse_args(): "randomseed": args.randomseed, "batchsize": args.batchsize, }, + "analysis": {"run_analysis": args.run_analysis} # "plots": {key: {} for key in args.plots}, # "metrics": {key: {} for key in args.metrics}, } @@ -310,7 +317,7 @@ def beta_type(value): model, lossFn = models.model_setup_DE( config.get_item("model", "loss_type", "DE"), DEVICE ) - ''' + print( "save final checkpoint has this value", config.get_item("model", "save_final_checkpoint", "DE"), @@ -340,18 +347,19 @@ def beta_type(value): savefig=config.get_item("model", "savefig", "DE"), verbose=config.get_item("model", "verbose", "DE"), ) - ''' - # now run the analysis on the resulting checkpoints - chk_module = AggregateCheckpoints() - print('n_models', config.get_item("model", "n_models", "DE")) - print('n_epochs', config.get_item("model", "n_epochs", "DE")) - for nmodel in range(config.get_item("model", "n_models", "DE")): - for epoch in range(config.get_item("model", "n_epochs", "DE")): - chk = chk_module.load_DE_checkpoint(model, - model_name, - nmodel, - epoch, - config.get_item("model", "BETA", "DE"), - DEVICE) - print(chk) - STOP + if config.get_item("analysis", "run_analysis", "DE"): + # now run the analysis on the resulting checkpoints + chk_module = AggregateCheckpoints() + print('n_models', config.get_item("model", "n_models", "DE")) + print('n_epochs', config.get_item("model", "n_epochs", "DE")) + for nmodel in range(config.get_item("model", "n_models", "DE")): + for epoch in range(config.get_item("model", "n_epochs", "DE")): + chk = chk_module.load_DE_checkpoint( + model_name, + nmodel, + epoch, + config.get_item("model", "BETA", "DE"), + DEVICE) + # things to grab: 'valid_mse' and 'valid_bnll' + print(chk) + STOP diff --git a/src/utils/defaults.py b/src/utils/defaults.py index a8135b9..67a9b6a 100644 --- a/src/utils/defaults.py +++ b/src/utils/defaults.py @@ -39,6 +39,7 @@ "line_style_cycle": ["-", "-."], "figure_size": [6, 6], }, + "analysis": {"run_analysis": False}, "plots": {"CDFRanks": {}, "Ranks": {"num_bins": None}, "CoverageFraction": {}}, @@ -83,6 +84,7 @@ "savefig": False, "verbose": False, }, + "analysis": {"run_analysis": False}, "plots_common": { "axis_spines": False, "tight_layout": True, diff --git a/test/test_DeepEnsemble.py b/test/test_DeepEnsemble.py index a120d4a..4233967 100644 --- a/test/test_DeepEnsemble.py +++ b/test/test_DeepEnsemble.py @@ -95,6 +95,9 @@ def create_test_config( "randomseed": 42, "batchsize": 100, }, + "analysis": { + "run_analysis": False + } } print("theoretically dumping here", str(temp_directory) + "yamls/DE.yaml") yaml.dump(input_yaml, open(str(temp_directory) + "yamls/DE.yaml", "w")) @@ -152,7 +155,7 @@ def test_DE_from_config( assert ( expected_substring in file_name ), f"File '{file_name}' does not contain the expected substring" - + def test_DE_chkpt_saved( self, temp_directory, temp_data, noise_level="low", size_df=10 ): diff --git a/test/test_DeepEvidentialRegression.py b/test/test_DeepEvidentialRegression.py index a8746ff..1449577 100644 --- a/test/test_DeepEvidentialRegression.py +++ b/test/test_DeepEvidentialRegression.py @@ -94,6 +94,7 @@ def create_test_config( "randomseed": 42, "batchsize": 100, }, + "analysis": {"run_analysis": False} } print("theoretically dumping here", str(temp_directory) + "yamls/DER.yaml") print('this is the yaml', input_yaml) @@ -156,7 +157,6 @@ def test_DER_chkpt_saved(self, assert ( expected_substring in file_name ), f"File '{file_name}' does not contain the expected substring" - def test_DER_from_config(self, temp_directory, temp_data,