Skip to content

Commit

Permalink
overwrites stuff on command line
Browse files Browse the repository at this point in the history
  • Loading branch information
beckynevin committed Apr 22, 2024
1 parent a3ca63c commit eae7046
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 109 deletions.
Binary file added images/DeepUQWorkflow_Maggie.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
87 changes: 35 additions & 52 deletions src/scripts/DeepEnsemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,32 +36,23 @@ def parse_args():
choices=DataModules.keys())

# model
parser.add_argument("--model_path", '-m', default=None)
# path to save the model results
parser.add_argument("--out_dir",
default=DefaultsDE['common']['out_dir'])
parser.add_argument("--model_engine", '-e',
default=DefaultsDE['model']['model_engine'],
choices=ModelModules.keys())

# path to save the yaml if thats what you'd like
parser.add_argument("--out_dir",
default=DefaultsDE['common']['out_dir'])

# List of metrics (cannot supply specific kwargs)
# parser.add_argument("--metrics", nargs='+', default=list(Defaults['metrics'].keys()), choices=Metrics.keys())

# List of plots
#parser.add_argument("--plots", nargs='+', default=list(Defaults['plots'].keys()), choices=Plots.keys())

parser.add_argument(
"--size_df",
type=float,
required=False,
default=1000,
default=DefaultsDE['data']['size_df'],
help="Used to load the associated .h5 data file",
)
parser.add_argument(
"--noise_level",
type=str,
default="low",
default=DefaultsDE['data']['noise_level'],
choices=["low", "medium", "high", "vhigh"],
help="low, medium, high or vhigh, \
used to look up associated sigma value",
Expand All @@ -70,152 +61,149 @@ def parse_args():
"--normalize",
required=False,
action="store_true",
default=DefaultsDE['data']['normalize'],
help="If true theres an option to normalize the dataset",
)
parser.add_argument(
"--val_proportion",
type=float,
required=False,
default=0.1,
default=DefaultsDE['data']['val_proportion'],
help="Proportion of the dataset to use as validation",
)
parser.add_argument(
"--randomseed",
type=int,
required=False,
default=42,
default=DefaultsDE['data']['randomseed'],
help="Random seed used for shuffling the training and validation set",
)
parser.add_argument(
"--generatedata",
action="store_true",
default=False,
default=DefaultsDE['data']['generatedata'],
help="option to generate df, if not specified \
default behavior is to load from file",
)
parser.add_argument(
"--batchsize",
type=int,
required=False,
default=100,
default=DefaultsDE['data']['batchsize'],
help="Size of batched used in the traindataloader",
)
# now args for model
parser.add_argument(
"--n_models",
type=int,
default=100,
default=DefaultsDE['model']['n_models'],
help="Number of MVEs in the ensemble",
)
parser.add_argument(
"--init_lr",
type=float,
required=False,
default=0.001,
default=DefaultsDE['model']['init_lr'],
help="Learning rate",
)
parser.add_argument(
"--loss_type",
type=str,
required=False,
default="bnll_loss",
default=DefaultsDE['model']['loss_type'],
help="Loss types for MVE, options are no_var_loss, var_loss, \
and bnn_loss",
and bnn_loss",
)
parser.add_argument(
"--BETA",
type=beta_type,
required=False,
default=0.5,
default=DefaultsDE['model']['BETA'],
help="If loss_type is bnn_loss, specify a beta as a float or \
there are string options: linear_decrease, \
step_decrease_to_0.5, and step_decrease_to_1.0",
there are string options: linear_decrease, \
step_decrease_to_0.5, and step_decrease_to_1.0",
)
parser.add_argument(
"--wd",
type=str,
default="./DeepUQResources/",
default=DefaultsDE['model']['wd'],
help="Top level of directory, required arg",
)
parser.add_argument(
"--model_type",
type=str,
required=False,
default="DE",
default=DefaultsDE['model']['model_type'],
help="Beginning of name for saved checkpoints and figures",
)
parser.add_argument(
"--n_epochs",
type=int,
required=False,
default=100,
default=DefaultsDE['model']['n_epochs'],
help="number of epochs for each MVE",
)
parser.add_argument(
"--path_to_models",
type=str,
required=False,
default="models/",
help="path to where the checkpoints are saved",
)
parser.add_argument(
"--save_all_checkpoints",
action="store_true",
default=False,
default=DefaultsDE['model']['save_all_checkpoints'],
help="option to save all checkpoints",
)
parser.add_argument(
"--save_final_checkpoint",
action="store_true", # Set to True if argument is present
default=False, # Set default value to False if argument is not present
default=DefaultsDE['model']['save_final_checkpoint'],
help="option to save the final epoch checkpoint for each ensemble",
)
parser.add_argument(
"--overwrite_final_checkpoint",
action="store_true",
default=False,
default=DefaultsDE['model']['overwrite_final_checkpoint'],
help="option to overwite already saved checkpoints",
)
parser.add_argument(
"--plot",
action="store_true",
default=False,
default=DefaultsDE['model']['plot'],
help="option to plot in notebook",
)
parser.add_argument(
"--savefig",
action="store_true",
default=False,
default=DefaultsDE['model']['savefig'],
help="option to save a figure of the true and predicted values",
)
parser.add_argument(
"--verbose",
action="store_true",
default=False,
default=DefaultsDE['model']['verbose'],
help="verbose option for train",
)
#return parser.parse_args()

args = parser.parse_args()
args = parser.parse_args()
if args.config is not None:
print('Reading settings from config file', args.config)
config = Config(args.config)

else:
temp_config = DefaultsDE['common']['temp_config']
print('Reading settings from cli and default, \
dumping to temp config: ',
temp_config)
os.makedirs(os.path.dirname(temp_config), exist_ok=True)

# check if args were specified in cli
# if not, default is from DefaultsDE dictionary
input_yaml = {
"common": {"out_dir": args.out_dir},
"model": {"model_path": args.model_path,
"model_engine": args.model_engine,
"model": {"model_engine": args.model_engine,
"model_type": args.model_type,
"loss_type": args.loss_type,
"n_models": args.n_models,
"init_lr": args.init_lr,
"wd": args.wd,
"BETA": args.BETA,
"n_epochs": args.n_epochs,
"path_to_models": args.path_to_models,
"save_all_checkpoints": args.save_all_checkpoints,
"save_final_checkpoint": args.save_final_checkpoint,
"overwrite_final_checkpoint": args.overwrite_final_checkpoint,
Expand Down Expand Up @@ -267,10 +255,6 @@ def beta_type(value):
rs = config.get_item("data", "randomseed", "DE")
BATCH_SIZE = config.get_item("data", "batchsize", "DE")
sigma = DataPreparation.get_sigma(noise)
print("generated data", config.get_item("data",
"generatedata",
"DE",
raise_exception=False))
if config.get_item("data", "generatedata", "DE", raise_exception=False):
# generate the df
data = DataPreparation()
Expand Down Expand Up @@ -311,7 +295,6 @@ def beta_type(value):
trainDataLoader = DataLoader(trainData,
batch_size=BATCH_SIZE,
shuffle=True)
print("[INFO] initializing the gal model...")
# set the device we will be using to train the model
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Expand All @@ -332,7 +315,7 @@ def beta_type(value):
model_name,
BETA=config.get_item("model", "BETA", "DE"),
EPOCHS=config.get_item("model", "n_epochs", "DE"),
path_to_model=config.get_item("model", "path_to_models", "DE"),
path_to_model=config.get_item("common", "out_dir", "DE"),
save_all_checkpoints=config.get_item("model", "save_all_checkpoints",
"DE"),
save_final_checkpoint=config.get_item("model",
Expand Down
Loading

0 comments on commit eae7046

Please sign in to comment.