Skip to content

Commit

Permalink
temp data tests running with analysis keyword added
Browse files Browse the repository at this point in the history
  • Loading branch information
beckynevin committed May 6, 2024
1 parent 5a4b137 commit 614f6fe
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 18 deletions.
40 changes: 24 additions & 16 deletions src/scripts/DeepEnsemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,12 @@ def parse_args():
default=DefaultsDE["model"]["savefig"],
help="option to save a figure of the true and predicted values",
)
parser.add_argument(
"--run_analysis",
action="store_true",
default=DefaultsDE["analysis"]["run_analysis"],
help="option to run analysis on saved checkpoints",
)
parser.add_argument(
"--verbose",
action="store_true",
Expand Down Expand Up @@ -224,6 +230,7 @@ def parse_args():
"randomseed": args.randomseed,
"batchsize": args.batchsize,
},
"analysis": {"run_analysis": args.run_analysis}
# "plots": {key: {} for key in args.plots},
# "metrics": {key: {} for key in args.metrics},
}
Expand Down Expand Up @@ -310,7 +317,7 @@ def beta_type(value):
model, lossFn = models.model_setup_DE(
config.get_item("model", "loss_type", "DE"), DEVICE
)
'''

print(
"save final checkpoint has this value",
config.get_item("model", "save_final_checkpoint", "DE"),
Expand Down Expand Up @@ -340,18 +347,19 @@ def beta_type(value):
savefig=config.get_item("model", "savefig", "DE"),
verbose=config.get_item("model", "verbose", "DE"),
)
'''
# now run the analysis on the resulting checkpoints
chk_module = AggregateCheckpoints()
print('n_models', config.get_item("model", "n_models", "DE"))
print('n_epochs', config.get_item("model", "n_epochs", "DE"))
for nmodel in range(config.get_item("model", "n_models", "DE")):
for epoch in range(config.get_item("model", "n_epochs", "DE")):
chk = chk_module.load_DE_checkpoint(model,
model_name,
nmodel,
epoch,
config.get_item("model", "BETA", "DE"),
DEVICE)
print(chk)
STOP
if config.get_item("analysis", "run_analysis", "DE"):
# now run the analysis on the resulting checkpoints
chk_module = AggregateCheckpoints()
print('n_models', config.get_item("model", "n_models", "DE"))
print('n_epochs', config.get_item("model", "n_epochs", "DE"))
for nmodel in range(config.get_item("model", "n_models", "DE")):
for epoch in range(config.get_item("model", "n_epochs", "DE")):
chk = chk_module.load_DE_checkpoint(
model_name,
nmodel,
epoch,
config.get_item("model", "BETA", "DE"),
DEVICE)
# things to grab: 'valid_mse' and 'valid_bnll'
print(chk)
STOP
2 changes: 2 additions & 0 deletions src/utils/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
"line_style_cycle": ["-", "-."],
"figure_size": [6, 6],
},
"analysis": {"run_analysis": False},
"plots": {"CDFRanks": {},
"Ranks": {"num_bins": None},
"CoverageFraction": {}},
Expand Down Expand Up @@ -83,6 +84,7 @@
"savefig": False,
"verbose": False,
},
"analysis": {"run_analysis": False},
"plots_common": {
"axis_spines": False,
"tight_layout": True,
Expand Down
5 changes: 4 additions & 1 deletion test/test_DeepEnsemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ def create_test_config(
"randomseed": 42,
"batchsize": 100,
},
"analysis": {
"run_analysis": False
}
}
print("theoretically dumping here", str(temp_directory) + "yamls/DE.yaml")
yaml.dump(input_yaml, open(str(temp_directory) + "yamls/DE.yaml", "w"))
Expand Down Expand Up @@ -152,7 +155,7 @@ def test_DE_from_config(
assert (
expected_substring in file_name
), f"File '{file_name}' does not contain the expected substring"

def test_DE_chkpt_saved(
self, temp_directory, temp_data, noise_level="low", size_df=10
):
Expand Down
2 changes: 1 addition & 1 deletion test/test_DeepEvidentialRegression.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def create_test_config(
"randomseed": 42,
"batchsize": 100,
},
"analysis": {"run_analysis": False}
}
print("theoretically dumping here", str(temp_directory) + "yamls/DER.yaml")
print('this is the yaml', input_yaml)
Expand Down Expand Up @@ -156,7 +157,6 @@ def test_DER_chkpt_saved(self,
assert (
expected_substring in file_name
), f"File '{file_name}' does not contain the expected substring"

def test_DER_from_config(self,
temp_directory,
temp_data,
Expand Down

0 comments on commit 614f6fe

Please sign in to comment.