Skip to content

Commit

Permalink
flake8
Browse files Browse the repository at this point in the history
  • Loading branch information
beckynevin committed Nov 4, 2024
1 parent 3090df1 commit d86f3b7
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 20 deletions.
1 change: 0 additions & 1 deletion deepuq/scripts/DeepEnsemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from torch.utils.data import TensorDataset, DataLoader

from deepuq.train import train
from deepuq.models import models
from deepuq.data import DataModules
from deepuq.models import ModelModules
from deepuq.utils.config import Config
Expand Down
30 changes: 14 additions & 16 deletions deepuq/scripts/DeepEvidentialRegression.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def parse_args():
- Model-related arguments:
--model_engine, --init_lr, --loss_type, --COEFF, --model_type,
--n_epochs, --save_all_checkpoints, --save_final_checkpoint,
--overwrite_final_checkpoint, --plot, --savefig,
--overwrite_model, --plot_inline, --plot_savefig,
--save_chk_random_seed_init, --rs_list, --n_hidden, --save_n_hidden,
--save_data_size, --verbose
- General arguments:
Expand Down Expand Up @@ -199,21 +199,21 @@ def parse_args():
help="option to save the final epoch checkpoint for each ensemble",
)
parser.add_argument(
"--overwrite_final_checkpoint",
"--overwrite_model",
action="store_true",
default=DefaultsDER["model"]["overwrite_final_checkpoint"],
default=DefaultsDER["model"]["overwrite_model"],
help="option to overwite already saved checkpoints",
)
parser.add_argument(
"--plot",
"--plot_inline",
action="store_true",
default=DefaultsDER["model"]["plot"],
default=DefaultsDER["model"]["plot_inline"],
help="option to plot in notebook",
)
parser.add_argument(
"--savefig",
"--plot_savefig",
action="store_true",
default=DefaultsDER["model"]["savefig"],
default=DefaultsDER["model"]["plot_savefig"],
help="option to save a figure of the true and predicted values",
)
parser.add_argument(
Expand Down Expand Up @@ -286,9 +286,9 @@ def parse_args():
"n_epochs": args.n_epochs,
"save_all_checkpoints": args.save_all_checkpoints,
"save_final_checkpoint": args.save_final_checkpoint,
"overwrite_final_checkpoint": args.overwrite_final_checkpoint,
"plot": args.plot,
"savefig": args.savefig,
"overwrite_model": args.overwrite_model,
"plot_inline": args.plot_inline,
"plot_savefig": args.plot_savefig,
"save_chk_random_seed_init": args.save_chk_random_seed_init,
"rs": args.rs,
"save_n_hidden": args.save_n_hidden,
Expand All @@ -310,8 +310,6 @@ def parse_args():
"normalize": args.normalize,
"uniform": args.uniform,
},
# "plots": {key: {} for key in args.plots},
# "metrics": {key: {} for key in args.metrics},
}

yaml.dump(input_yaml, open(temp_config, "w"))
Expand Down Expand Up @@ -481,11 +479,11 @@ def main():
save_final_checkpoint=config.get_item(
"model", "save_final_checkpoint", "DER"
),
overwrite_final_checkpoint=config.get_item(
"model", "overwrite_final_checkpoint", "DER"
overwrite_model=config.get_item(
"model", "overwrite_model", "DER"
),
plot=config.get_item("model", "plot", "DER"),
savefig=config.get_item("model", "savefig", "DER"),
plot_inline=config.get_item("model", "plot_inline", "DER"),
plot_savefig=config.get_item("model", "plot_savefig", "DER"),
set_and_save_rs=config.get_item(
"model", "save_chk_random_seed_init", "DER"
),
Expand Down
3 changes: 2 additions & 1 deletion deepuq/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -771,7 +771,8 @@ def train_DE(
epoch,
)
# best_weights = copy.deepcopy(model.state_dict())
if (plot_inline or plot_savefig) and (e % (EPOCHS - 1) == 0) and (e != 0):
if ((plot_inline or plot_savefig) and
(e % (EPOCHS - 1) == 0) and (e != 0)):
ax1.plot(
range(0, 1000),
range(0, 1000),
Expand Down
1 change: 0 additions & 1 deletion deepuq/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,5 +213,4 @@ def get_section(self, section, defaulttype, raise_exception=True):
return {
"DER": DefaultsDER,
"DE": DefaultsDE,
"Analysis": DefaultsAnalysis,
}[defaulttype][section]
2 changes: 1 addition & 1 deletion test/test_DeepEnsemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ def test_DE_overwrite_true(
# Assert that the file was overwritten
# (modification time not be the same)
assert initial_mtime < final_mtime

def test_DE_from_saved_data(
self, temp_directory, temp_data, noise_level="low", size_df=100
):
Expand Down

0 comments on commit d86f3b7

Please sign in to comment.