Skip to content

Commit

Permalink
Merge pull request #204 from deepskies/feature/brians_first_corrections
Browse files Browse the repository at this point in the history
Feature/brians first corrections
  • Loading branch information
beckynevin authored Nov 5, 2024
2 parents 078c124 + 516965c commit dbc6081
Show file tree
Hide file tree
Showing 14 changed files with 595 additions and 2,080 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
# Changelog

## [0.1.6] - 2024-11-05
### Fixed
- train.py now identifies if a previous version of the model checkpoints has been saved for both DE and DER
### Added
- default.py matches defaults used in experiments
- train.py modified with names of flags made more explicit

## [0.1.5] - 2024-11-01
### Fixed
- verbosity in data.py and analyze.py
Expand Down
22 changes: 15 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# DeepUQ
DeepUQ is a package for injecting and measuring different types of uncertainty in ML models.

[![PyPi](https://img.shields.io/badge/PyPi-0.1.5-blue)](https://pypi.org/project/deepuq/)
[![PyPi](https://img.shields.io/badge/PyPi-0.1.6-blue)](https://pypi.org/project/deepuq/)
[![License](https://img.shields.io/badge/License-MIT-lightgrey)](https://opensource.org/licenses/MIT)
[![Downloads](https://static.pepy.tech/personalized-badge/deepuq?period=month&units=international_system&left_color=black&right_color=brightgreen&left_text=Total%20Downloads)](https://pepy.tech/project/deepuq)

Expand All @@ -17,9 +17,17 @@ DeepUQ is a package for injecting and measuring different types of uncertainty i
> pip install deepuq
Now you can run some of the scripts!
> UQensemble --generatedata
> UQensemble --generatedata --save_final_checkpoint --save_all_checkpoints --plot_savefig --overwrite_model
^`generatedata` is required if you don't have any saved data. You can set other keywords like so.
^`--generatedata` is required if you don't have any saved data.

The default behavior is to train the model without saving any checkpoints. By specifying the `--save_final_checkpoint` flag, the script will save a pytorch checkpoint for the final epoch with the model weights as well as diagnostics like the MSE metric and the model loss. This checkpoint will be stored in a folder at the path specified by `--out_dir` flag, the default location is `./DeepUQResources/checkpoints/`.

To additionally save all checkpoints, use the `--save_all_checkpoints` flag.

To save diagnostic plots of the true and predicted model outputs as well as the model residuals, specify `--plot_inline` and `--plot_savefig` (to plot inline and save as a png, respectively).

The `--overwrite_model` flag will retrain and overwrite a previously existing version of the model.

It's also possible to verify the install works by running:
> pytest
Expand Down Expand Up @@ -94,10 +102,10 @@ The `deepuq/` folder contains the relevant modules for config settings, data gen

Example notebooks for how to train and analyze the results of the models can be found in the `notebooks/` folder.

The `DeepUQResources/` folder is the default location for saving checkpoints from the trained model and the `data/` folder is where the training and validation set are saved.
The `DeepUQResources/` folder is the default location for saving checkpoints and diagnostic plots from the trained model and the `data/` folder is where the training and validation set are saved.

## How to run the workflow
The scripts can be accessed via the ipython example notebooks or via the model modules (ie `DeepEnsemble.py`). For example, to ingest data and train a Deep Ensemble from the DeepUQ/ directory:
The scripts can be accessed via the ipython example notebooks in the `notebooks/` folder or via the model modules (ie `deepuq/scripts/DeepEnsemble.py`). For example, to ingest data and train a Deep Ensemble from the DeepUQ/ directory:

> python deepuq/scripts/DeepEnsemble.py
Expand All @@ -114,9 +122,9 @@ Where you would modify the "path/to/config/myconfig.yaml" to specify where your

The third option is to input settings on the command line. These choices are then combined with the default settings and output in a temporary yaml.

> python deepuq/scripts/DeepEnsemble.py --noise_level "low" --n_models 10 --out_dir ./DeepUQResources/results/ --save_final_checkpoint True --savefig True --n_epochs 10
> python deepuq/scripts/DeepEnsemble.py --noise_level "low" --n_models 10 --out_dir ./DeepUQResources/ --save_final_checkpoint --save_all_checkpoints --plot_savefig --n_epochs 10
This command will train a 10 network, 10 epoch ensemble on the low noise data and will save figures and final checkpoints to the specified directory. Required arguments are the noise setting (low/medium/high), the number of ensembles, and the working directory.
This command will train a 10 network, 10 epoch ensemble on the low noise data and will save figures and all checkpoints to the specified directory.

For more information on the arguments:
> python deepuq/scripts/DeepEnsemble.py --help
Expand Down
1 change: 0 additions & 1 deletion deepuq/analyze/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ def load_checkpoint(
dict: The loaded checkpoint containing model weights and
additional data.
"""
print(model_name)
if model_name[0:3] == "DER":
file_name = (
str(path)
Expand Down
43 changes: 19 additions & 24 deletions deepuq/scripts/DeepEnsemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from torch.utils.data import TensorDataset, DataLoader

from deepuq.train import train
from deepuq.models import models
from deepuq.data import DataModules
from deepuq.models import ModelModules
from deepuq.utils.config import Config
Expand Down Expand Up @@ -36,7 +35,7 @@ def parse_args():
- Model-related arguments:
--model_engine, --n_models, --init_lr, --loss_type, --BETA,
--model_type, --n_epochs, --save_all_checkpoints,
--save_final_checkpoint, --overwrite_final_checkpoint, --plot,
--save_final_checkpoint, --overwrite_model, --plot,
--savefig, --save_chk_random_seed_init, --rs_list, --n_hidden,
--save_n_hidden, --save_data_size, --verbose
- General arguments:
Expand Down Expand Up @@ -205,21 +204,21 @@ def parse_args():
help="option to save the final epoch checkpoint for each ensemble",
)
parser.add_argument(
"--overwrite_final_checkpoint",
"--overwrite_model",
action="store_true",
default=DefaultsDE["model"]["overwrite_final_checkpoint"],
default=DefaultsDE["model"]["overwrite_model"],
help="option to overwite already saved checkpoints",
)
parser.add_argument(
"--plot",
"--plot_inline",
action="store_true",
default=DefaultsDE["model"]["plot"],
default=DefaultsDE["model"]["plot_inline"],
help="option to plot in notebook",
)
parser.add_argument(
"--savefig",
"--plot_savefig",
action="store_true",
default=DefaultsDE["model"]["savefig"],
default=DefaultsDE["model"]["plot_savefig"],
help="option to save a figure of the true and predicted values",
)
parser.add_argument(
Expand Down Expand Up @@ -296,9 +295,9 @@ def parse_args():
"n_epochs": args.n_epochs,
"save_all_checkpoints": args.save_all_checkpoints,
"save_final_checkpoint": args.save_final_checkpoint,
"overwrite_final_checkpoint": args.overwrite_final_checkpoint,
"plot": args.plot,
"savefig": args.savefig,
"overwrite_model": args.overwrite_model,
"plot_inline": args.plot_inline,
"plot_savefig": args.plot_savefig,
"save_chk_random_seed_init": args.save_chk_random_seed_init,
"rs_list": args.rs_list,
"save_n_hidden": args.save_n_hidden,
Expand All @@ -320,8 +319,6 @@ def parse_args():
"normalize": args.normalize,
"uniform": args.uniform,
},
# "plots": {key: {} for key in args.plots},
# "metrics": {key: {} for key in args.metrics},
}

yaml.dump(input_yaml, open(temp_config, "w"))
Expand Down Expand Up @@ -500,12 +497,6 @@ def main():
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_name = config.get_item("model", "model_type", "DE")
model, lossFn = models.model_setup_DE(
config.get_item("model", "loss_type", "DE"),
DEVICE,
n_hidden=config.get_item("model", "n_hidden", "DE"),
data_type=dim,
)
print(
"save final checkpoint has this value",
config.get_item("model", "save_final_checkpoint", "DE"),
Expand All @@ -525,7 +516,7 @@ def main():
model_name=model_name,
BETA=config.get_item("model", "BETA", "DE"),
EPOCHS=config.get_item("model", "n_epochs", "DE"),
path_to_model=config.get_item("common", "out_dir", "DE"),
out_dir=config.get_item("common", "out_dir", "DE"),
inject_type=injection,
data_dim=dim,
noise_level=noise,
Expand All @@ -535,11 +526,15 @@ def main():
save_final_checkpoint=config.get_item(
"model", "save_final_checkpoint", "DE"
),
overwrite_final_checkpoint=config.get_item(
"model", "overwrite_final_checkpoint", "DE"
overwrite_model=config.get_item(
"model", "overwrite_model", "DE"
),
plot_inline=config.get_item(
"model", "plot_inline", "DE"
),
plot_savefig=config.get_item(
"model", "plot_savefig", "DE"
),
plot=config.get_item("model", "plot", "DE"),
savefig=config.get_item("model", "savefig", "DE"),
set_and_save_rs=config.get_item(
"model", "save_chk_random_seed_init", "DE"
),
Expand Down
32 changes: 15 additions & 17 deletions deepuq/scripts/DeepEvidentialRegression.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def parse_args():
- Model-related arguments:
--model_engine, --init_lr, --loss_type, --COEFF, --model_type,
--n_epochs, --save_all_checkpoints, --save_final_checkpoint,
--overwrite_final_checkpoint, --plot, --savefig,
--overwrite_model, --plot_inline, --plot_savefig,
--save_chk_random_seed_init, --rs_list, --n_hidden, --save_n_hidden,
--save_data_size, --verbose
- General arguments:
Expand Down Expand Up @@ -199,21 +199,21 @@ def parse_args():
help="option to save the final epoch checkpoint for each ensemble",
)
parser.add_argument(
"--overwrite_final_checkpoint",
"--overwrite_model",
action="store_true",
default=DefaultsDER["model"]["overwrite_final_checkpoint"],
default=DefaultsDER["model"]["overwrite_model"],
help="option to overwite already saved checkpoints",
)
parser.add_argument(
"--plot",
"--plot_inline",
action="store_true",
default=DefaultsDER["model"]["plot"],
default=DefaultsDER["model"]["plot_inline"],
help="option to plot in notebook",
)
parser.add_argument(
"--savefig",
"--plot_savefig",
action="store_true",
default=DefaultsDER["model"]["savefig"],
default=DefaultsDER["model"]["plot_savefig"],
help="option to save a figure of the true and predicted values",
)
parser.add_argument(
Expand Down Expand Up @@ -286,9 +286,9 @@ def parse_args():
"n_epochs": args.n_epochs,
"save_all_checkpoints": args.save_all_checkpoints,
"save_final_checkpoint": args.save_final_checkpoint,
"overwrite_final_checkpoint": args.overwrite_final_checkpoint,
"plot": args.plot,
"savefig": args.savefig,
"overwrite_model": args.overwrite_model,
"plot_inline": args.plot_inline,
"plot_savefig": args.plot_savefig,
"save_chk_random_seed_init": args.save_chk_random_seed_init,
"rs": args.rs,
"save_n_hidden": args.save_n_hidden,
Expand All @@ -310,8 +310,6 @@ def parse_args():
"normalize": args.normalize,
"uniform": args.uniform,
},
# "plots": {key: {} for key in args.plots},
# "metrics": {key: {} for key in args.metrics},
}

yaml.dump(input_yaml, open(temp_config, "w"))
Expand Down Expand Up @@ -471,7 +469,7 @@ def main():
norm_params,
model_name=model_name,
EPOCHS=config.get_item("model", "n_epochs", "DER"),
path_to_model=config.get_item("common", "out_dir", "DER"),
out_dir=config.get_item("common", "out_dir", "DER"),
inject_type=injection,
data_dim=dim,
noise_level=noise,
Expand All @@ -481,11 +479,11 @@ def main():
save_final_checkpoint=config.get_item(
"model", "save_final_checkpoint", "DER"
),
overwrite_final_checkpoint=config.get_item(
"model", "overwrite_final_checkpoint", "DER"
overwrite_model=config.get_item(
"model", "overwrite_model", "DER"
),
plot=config.get_item("model", "plot", "DER"),
savefig=config.get_item("model", "savefig", "DER"),
plot_inline=config.get_item("model", "plot_inline", "DER"),
plot_savefig=config.get_item("model", "plot_savefig", "DER"),
set_and_save_rs=config.get_item(
"model", "save_chk_random_seed_init", "DER"
),
Expand Down
Loading

0 comments on commit dbc6081

Please sign in to comment.