From c8e2783131a94c85d3942b9fb745723bf48039c8 Mon Sep 17 00:00:00 2001
From: beckynevin
Date: Tue, 9 Apr 2024 15:02:06 -0600
Subject: [PATCH] correcting flake8 making sure everything is defined in DER

---
 src/scripts/train.py | 128 ++++++++++++++----------------------------
 1 file changed, 42 insertions(+), 86 deletions(-)

diff --git a/src/scripts/train.py b/src/scripts/train.py
index 56120bd..ce73e68 100644
--- a/src/scripts/train.py
+++ b/src/scripts/train.py
@@ -24,7 +24,7 @@ def train_DER(
     overwrite_final_checkpoint=False,
     plot=True,
     savefig=True,
-    verbose=True
+    verbose=True,
 ):
     # first determine if you even need to run anything
     if not save_all_checkpoints and save_final_checkpoint:
@@ -38,7 +38,7 @@ def train_DER(
             + "_epoch_"
            + str(EPOCHS - 1)
             + ".pt"
-            )
+        )
         if verbose:
             print("final chk", final_chk)
         # check if the final epoch checkpoint already exists
@@ -66,11 +66,11 @@ def train_DER(
 
     startTime = time.time()
 
     start_epoch = 0
-    
+
     best_loss = np.inf  # init to infinity
     model, lossFn = models.model_setup_DER(loss_type, DEVICE)
     if verbose:
-        print('model is', model, 'lossfn', lossFn)
+        print("model is", model, "lossfn", lossFn)
 
     opt = torch.optim.Adam(model.parameters(), lr=INIT_LR)
@@ -104,8 +104,6 @@ def train_DER(
             loss = lossFn(pred, y, COEFF)
             if plot or savefig:
                 if (e % (EPOCHS - 1) == 0) and (e != 0):
-                    pred_loader_0 = pred[:, 0].flatten().detach().numpy()
-                    y_loader_0 = y.detach().numpy()
                     ax1.scatter(
                         y,
                         pred[:, 0].flatten().detach().numpy(),
@@ -126,12 +124,12 @@ def train_DER(
                         xycoords="axes fraction",
                         color="black",
                     )
-                    '''
+                    """
                     else:
                         ax1.scatter(y,
                                     pred[:, 0].flatten().detach().numpy(),
                                     color="grey")
-                    '''
+                    """
             loss_this_epoch.append(loss[0].item())
 
             # zero out the gradients
@@ -143,11 +141,25 @@ def train_DER(
             # optimizer takes a step based on the gradients of the parameters
             # here, its taking a step for every batch
             opt.step()
+        model.eval()
+        y_pred = model(torch.Tensor(x_val))
+        loss = lossFn(y_pred, torch.Tensor(y_val), COEFF)
+        NIGloss_val = loss[0].item()
+        med_u_al_val = np.median(loss[1])
+        med_u_ep_val = np.median(loss[2])
+        std_u_al_val = np.std(loss[1])
+        std_u_ep_val = np.std(loss[2])
+
+        # lets also grab mse loss
+        mse_loss = torch.nn.MSELoss(reduction="mean")
+        mse = mse_loss(y_pred[:, 0], torch.Tensor(y_val)).item()
+        if NIGloss_val < best_loss:
+            best_loss = NIGloss_val
+            if verbose:
+                print("new best loss", NIGloss_val, "in epoch", epoch)
+            # best_weights = copy.deepcopy(model.state_dict())
 
         if (plot or savefig) and (e % (EPOCHS - 1) == 0) and (e != 0):
-            ax1.plot(range(0, 1000),
-                     range(0, 1000),
-                     color="black",
-                     ls="--")
+            ax1.plot(range(0, 1000), range(0, 1000), color="black", ls="--")
             if loss_type == "no_var_loss":
                 ax1.scatter(
@@ -191,42 +203,21 @@ def train_DER(
             ax2.set_ylabel("Residuals")
             ax2.set_xlabel("True Value")
             # add annotion for loss value
-            if loss_type == "bnll_loss":
-                ax1.annotate(
-                    r"$\beta = $"
-                    + str(round(beta_epoch, 2))
-                    + "\n"
-                    + str(loss_type)
-                    + " = "
-                    + str(round(loss, 2))
-                    + "\n"
-                    + r"MSE = "
-                    + str(round(mse, 2)),
-                    xy=(0.73, 0.1),
-                    xycoords="axes fraction",
-                    bbox=dict(
-                        boxstyle="round,pad=0.5",
-                        facecolor="lightgrey",
-                        alpha=0.5
-                    ),
-                )
-
-            else:
-                ax1.annotate(
-                    str(loss_type)
-                    + " = "
-                    + str(round(loss, 2))
-                    + "\n"
-                    + r"MSE = "
-                    + str(round(mse, 2)),
-                    xy=(0.73, 0.1),
-                    xycoords="axes fraction",
-                    bbox=dict(
-                        boxstyle="round,pad=0.5",
-                        facecolor="lightgrey",
-                        alpha=0.5
-                    ),
-                )
+            ax1.annotate(
+                str(loss_type)
+                + " = "
+                + str(round(loss, 2))
+                + "\n"
+                + r"MSE = "
+                + str(round(mse, 2)),
+                xy=(0.73, 0.1),
+                xycoords="axes fraction",
+                bbox=dict(
+                    boxstyle="round,pad=0.5",
+                    facecolor="lightgrey",
+                    alpha=0.5
+                ),
+            )
             ax1.set_ylabel("Prediction")
             ax1.set_title("Epoch " + str(e))
             ax1.set_xlim([0, 1000])
@@ -248,40 +239,6 @@ def train_DER(
             if plot:
                 plt.show()
             plt.close()
-        '''
-        if plot and (e % 5 == 0):
-            ax1.set_ylabel("prediction")
-            ax1.set_title("Epoch " + str(e))
-
-            # Residuals plot
-            residuals = pred_loader_0 - y_loader_0
-            ax2.scatter(y_loader_0, residuals, color="red")
-            ax2.axhline(0, color="black", linestyle="--", linewidth=1)
-            ax2.set_ylabel("Residuals")
-            ax2.set_xlabel("True Value")
-
-            plt.show()
-            plt.close()
-        '''
-        model.eval()
-        y_pred = model(torch.Tensor(x_val))
-        loss = lossFn(y_pred, torch.Tensor(y_val), COEFF)
-        NIGloss_val = loss[0].item()
-        med_u_al_val = np.median(loss[1])
-        med_u_ep_val = np.median(loss[2])
-        std_u_al_val = np.std(loss[1])
-        std_u_ep_val = np.std(loss[2])
-
-        # lets also grab mse loss
-        mse_loss = torch.nn.MSELoss(reduction="mean")
-        mse = mse_loss(y_pred[:, 0], torch.Tensor(y_val)).item()
-        if NIGloss_val < best_loss:
-            best_loss = NIGloss_val
-            if verbose:
-                print("new best loss", NIGloss_val, "in epoch", epoch)
-            # best_weights = copy.deepcopy(model.state_dict())
-        # print('validation loss', mse)
-
 
         if save_all_checkpoints:
             torch.save(
@@ -307,7 +264,7 @@ def train_DER(
                 + ".pt",
             )
         if save_final_checkpoint and (e % (EPOCHS - 1) == 0) and (e != 0):
-                # option to just save final epoch
+            # option to just save final epoch
             torch.save(
                 {
                     "epoch": epoch,
@@ -329,7 +286,7 @@ def train_DER(
                 + "_epoch_"
                 + str(epoch)
                 + ".pt",
-                )
+            )
     endTime = time.time()
     if verbose:
         print("start at", startTime, "end at", endTime)
@@ -371,8 +328,7 @@ def train_DE(
 
     model_ensemble = []
 
-    print('this is the value of save_final_checkpoint',
-          save_final_checkpoint)
+    print("this is the value of save_final_checkpoint", save_final_checkpoint)
 
     for m in range(n_models):
         print("model", m)
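
Note: the main behavioral change in this patch moves the DER validation pass
out of the per-batch loop so it runs once per epoch, tracking the best NIG
validation loss. Below is a minimal runnable sketch of that pattern, not the
repository's actual code: x_val, y_val, the linear model, and the MSE loss are
hypothetical stand-ins for the DER network and NIG loss that the script builds
via models.model_setup_DER.

    import numpy as np
    import torch

    # Stand-in validation data (hypothetical shapes, not from the repo)
    x_val = torch.randn(64, 3)
    y_val = torch.randn(64)

    model = torch.nn.Linear(3, 1)  # stand-in for the DER network
    lossFn = torch.nn.MSELoss()    # stand-in for the NIG/DER loss
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)

    best_loss = np.inf  # init to infinity, as in the patch
    EPOCHS = 5
    for epoch in range(EPOCHS):
        model.train()
        # ... the per-batch steps (opt.zero_grad / loss.backward / opt.step)
        # would run here ...

        # per-epoch validation, outside the batch loop
        model.eval()
        with torch.no_grad():  # no gradients needed for validation
            y_pred = model(x_val).squeeze()
            val_loss = lossFn(y_pred, y_val).item()
        if val_loss < best_loss:  # track the best epoch, as the patch does
            best_loss = val_loss
            print("new best loss", val_loss, "in epoch", epoch)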