diff --git a/src/scripts/DeepEvidentialRegression.py b/src/scripts/DeepEvidentialRegression.py index 1e4bfe8..3a943cf 100644 --- a/src/scripts/DeepEvidentialRegression.py +++ b/src/scripts/DeepEvidentialRegression.py @@ -362,10 +362,20 @@ def parse_args(): xycoords='axes fraction', color='white', size=10) + plt.colorbar() plt.show() - model_inputs, model_outputs = DataPreparation.normalize( + model_inputs, model_outputs, norm_params = DataPreparation.normalize( model_inputs, model_outputs, norm ) + plt.clf() + plt.imshow(model_inputs[0]) + plt.annotate('Pixel sum = ' + str(round(model_outputs[0], 2)), + xy=(0.02, 0.9), + xycoords='axes fraction', + color='white', + size=10) + plt.colorbar() + plt.show() x_train, x_val, y_train, y_val = DataPreparation.train_val_split( model_inputs, model_outputs, val_proportion=val_prop, random_state=rs ) @@ -399,7 +409,8 @@ def parse_args(): DEVICE, config.get_item("model", "COEFF", "DER"), config.get_item("model", "loss_type", "DER"), - model_name, + norm_params, + model_name=model_name, EPOCHS=config.get_item("model", "n_epochs", "DER"), path_to_model=config.get_item("common", "out_dir", "DER"), data_prescription=prescription, diff --git a/src/train/train.py b/src/train/train.py index 037d0d5..891d3fa 100644 --- a/src/train/train.py +++ b/src/train/train.py @@ -23,6 +23,7 @@ def train_DER( DEVICE, COEFF, loss_type, + norm_params, model_name="DER", EPOCHS=100, path_to_model="models/", @@ -313,6 +314,7 @@ def train_DER( "mean_u_ep_validation": mean_u_ep_val, "std_u_al_validation": std_u_al_val, "std_u_ep_validation": std_u_ep_val, + "norm_params": norm_params, }, filename, ) @@ -361,6 +363,7 @@ def train_DER( "mean_u_ep_validation": mean_u_ep_val, "std_u_al_validation": std_u_al_val, "std_u_ep_validation": std_u_ep_val, + "norm_params": norm_params, }, filename, ) @@ -484,6 +487,8 @@ def train_DE( model, lossFn = models.model_setup_DE( loss_type, DEVICE, n_hidden=n_hidden, data_type=data_dim ) + if verbose: + print("model is", model, "lossfn", lossFn) opt = torch.optim.Adam(model.parameters(), lr=INIT_LR) mse_loss = torch.nn.MSELoss(reduction="mean")