diff --git a/deepuq/scripts/DeepEnsemble.py b/deepuq/scripts/DeepEnsemble.py index 02006d0..f187845 100644 --- a/deepuq/scripts/DeepEnsemble.py +++ b/deepuq/scripts/DeepEnsemble.py @@ -8,7 +8,6 @@ from torch.utils.data import TensorDataset, DataLoader from deepuq.train import train -from deepuq.models import models from deepuq.data import DataModules from deepuq.models import ModelModules from deepuq.utils.config import Config diff --git a/deepuq/scripts/DeepEvidentialRegression.py b/deepuq/scripts/DeepEvidentialRegression.py index 867d0f5..51bfcca 100644 --- a/deepuq/scripts/DeepEvidentialRegression.py +++ b/deepuq/scripts/DeepEvidentialRegression.py @@ -40,7 +40,7 @@ def parse_args(): - Model-related arguments: --model_engine, --init_lr, --loss_type, --COEFF, --model_type, --n_epochs, --save_all_checkpoints, --save_final_checkpoint, - --overwrite_final_checkpoint, --plot, --savefig, + --overwrite_model, --plot_inline, --plot_savefig, --save_chk_random_seed_init, --rs_list, --n_hidden, --save_n_hidden, --save_data_size, --verbose - General arguments: @@ -199,21 +199,21 @@ def parse_args(): help="option to save the final epoch checkpoint for each ensemble", ) parser.add_argument( - "--overwrite_final_checkpoint", + "--overwrite_model", action="store_true", - default=DefaultsDER["model"]["overwrite_final_checkpoint"], + default=DefaultsDER["model"]["overwrite_model"], help="option to overwite already saved checkpoints", ) parser.add_argument( - "--plot", + "--plot_inline", action="store_true", - default=DefaultsDER["model"]["plot"], + default=DefaultsDER["model"]["plot_inline"], help="option to plot in notebook", ) parser.add_argument( - "--savefig", + "--plot_savefig", action="store_true", - default=DefaultsDER["model"]["savefig"], + default=DefaultsDER["model"]["plot_savefig"], help="option to save a figure of the true and predicted values", ) parser.add_argument( @@ -286,9 +286,9 @@ def parse_args(): "n_epochs": args.n_epochs, "save_all_checkpoints": args.save_all_checkpoints, "save_final_checkpoint": args.save_final_checkpoint, - "overwrite_final_checkpoint": args.overwrite_final_checkpoint, - "plot": args.plot, - "savefig": args.savefig, + "overwrite_model": args.overwrite_model, + "plot_inline": args.plot_inline, + "plot_savefig": args.plot_savefig, "save_chk_random_seed_init": args.save_chk_random_seed_init, "rs": args.rs, "save_n_hidden": args.save_n_hidden, @@ -310,8 +310,6 @@ def parse_args(): "normalize": args.normalize, "uniform": args.uniform, }, - # "plots": {key: {} for key in args.plots}, - # "metrics": {key: {} for key in args.metrics}, } yaml.dump(input_yaml, open(temp_config, "w")) @@ -481,11 +479,11 @@ def main(): save_final_checkpoint=config.get_item( "model", "save_final_checkpoint", "DER" ), - overwrite_final_checkpoint=config.get_item( - "model", "overwrite_final_checkpoint", "DER" + overwrite_model=config.get_item( + "model", "overwrite_model", "DER" ), - plot=config.get_item("model", "plot", "DER"), - savefig=config.get_item("model", "savefig", "DER"), + plot_inline=config.get_item("model", "plot_inline", "DER"), + plot_savefig=config.get_item("model", "plot_savefig", "DER"), set_and_save_rs=config.get_item( "model", "save_chk_random_seed_init", "DER" ), diff --git a/deepuq/train/train.py b/deepuq/train/train.py index db0b8c5..0304ee2 100644 --- a/deepuq/train/train.py +++ b/deepuq/train/train.py @@ -771,7 +771,8 @@ def train_DE( epoch, ) # best_weights = copy.deepcopy(model.state_dict()) - if (plot_inline or plot_savefig) and (e % (EPOCHS - 1) == 0) and (e != 0): + if ((plot_inline or plot_savefig) and + (e % (EPOCHS - 1) == 0) and (e != 0)): ax1.plot( range(0, 1000), range(0, 1000), diff --git a/deepuq/utils/config.py b/deepuq/utils/config.py index d59f7e4..c2c9424 100644 --- a/deepuq/utils/config.py +++ b/deepuq/utils/config.py @@ -213,5 +213,4 @@ def get_section(self, section, defaulttype, raise_exception=True): return { "DER": DefaultsDER, "DE": DefaultsDE, - "Analysis": DefaultsAnalysis, }[defaulttype][section] diff --git a/test/test_DeepEnsemble.py b/test/test_DeepEnsemble.py index 805bb69..f5eda6c 100644 --- a/test/test_DeepEnsemble.py +++ b/test/test_DeepEnsemble.py @@ -375,7 +375,7 @@ def test_DE_overwrite_true( # Assert that the file was overwritten # (modification time not be the same) assert initial_mtime < final_mtime - + def test_DE_from_saved_data( self, temp_directory, temp_data, noise_level="low", size_df=100 ):