diff --git a/src/scripts/train.py b/src/scripts/train.py index 2ff5005..f18c18a 100644 --- a/src/scripts/train.py +++ b/src/scripts/train.py @@ -58,6 +58,12 @@ def train_DER( # loop over our epochs for e in range(0, EPOCHS): + if plot: + plt.clf() + fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 6), + gridspec_kw={'height_ratios': [3, 1]} + ) + epoch = int(start_epoch + e) # set the model in training mode @@ -66,10 +72,7 @@ def train_DER( # loop over the training set if verbose: print("epoch", epoch, round(e / EPOCHS, 2)) - loss_this_epoch = [] - - plt.clf() # randomly shuffles the training data (if shuffle = True) # and draws batches up to the total training size # (should be about 8 batches) @@ -81,34 +84,34 @@ def train_DER( pred = model(x) loss = lossFn(pred, y, COEFF) - if plot: - if e % 5 == 0: - if i == 0: - # if loss_type == 'no_var_loss': - plt.scatter( - y, - pred[:, 0].flatten().detach().numpy(), - color="#F45866", - edgecolor="black", - zorder=100, - ) - plt.errorbar( - y, - pred[:, 0].flatten().detach().numpy(), - yerr=loss[2], - color="#F45866", - zorder=100, - ls="None", - ) - plt.annotate( - r"med $u_{ep} = " + str(np.median(loss[2])), - xy=(0.03, 0.93), - xycoords="axes fraction", - color="#F45866", - ) - - else: - plt.scatter(y, pred[:, 0].flatten().detach().numpy()) + if plot and (e % 5 == 0): + if i == 0: + pred_loader_0 = pred[:,0].flatten().detach().numpy() + y_loader_0 = y.detach().numpy() + ax1.scatter( + y, + pred[:, 0].flatten().detach().numpy(), + color="black", + zorder=100, + ) + ax1.errorbar( + y, + pred[:, 0].flatten().detach().numpy(), + yerr=loss[2], + color="black", + zorder=100, + ls="None", + ) + ax1.annotate( + r"med $u_{ep} = $" + str(np.median(loss[2])), + xy=(0.03, 0.93), + xycoords="axes fraction", + color="black", + ) + else: + ax1.scatter(y, + pred[:, 0].flatten().detach().numpy(), + color='grey') loss_this_epoch.append(loss[0].item()) # zero out the gradients @@ -120,12 +123,19 @@ def train_DER( # optimizer takes a step based on the gradients of the parameters # here, its taking a step for every batch opt.step() - if plot: - if e % 5 == 0: - plt.ylabel("prediction") - plt.xlabel("true value") - plt.title("Epoch " + str(e)) - plt.show() + if plot and (e % 5 == 0): + ax1.set_ylabel("prediction") + ax1.set_title("Epoch " + str(e)) + + # Residuals plot + residuals = pred_loader_0 - y_loader_0 + ax2.scatter(y_loader_0, residuals, color='red') + ax2.axhline(0, color='black', linestyle='--', linewidth=1) + ax2.set_ylabel("Residuals") + ax2.set_xlabel("True Value") + + plt.show() + plt.close() model.eval() y_pred = model(torch.Tensor(x_val)) loss = lossFn(y_pred, torch.Tensor(y_val), COEFF) @@ -223,10 +233,12 @@ def train_DE( for m in range(n_models): print('model', m) # initialize the model again each time from scratch - model, lossFn, opt = models.model_setup_DE(loss_type, DEVICE, INIT_LR) + model, lossFn = models.model_setup_DE(loss_type, DEVICE) + opt = torch.optim.Adam(model.parameters(), lr=INIT_LR) # loop over our epochs for e in range(0, EPOCHS): + plt.close() epoch = int(start_epoch + e) # set the model in training mode @@ -237,8 +249,12 @@ def train_DE( print("epoch", epoch, round(e / EPOCHS, 2)) loss_this_epoch = [] - - plt.clf() + if plot is True: + plt.clf() + fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 6), + gridspec_kw={'height_ratios': [3, 1]} + ) + # randomly shuffles the training data (if shuffle = True) # and draws batches up to the total training size # (should be about 8 batches) @@ -267,54 +283,27 @@ def train_DE( # 1 - e / EPOCHS # this one doesn't work great ''' #beta_epoch = 1 - e / EPOCHS - beta_epoch = 1 + beta_epoch = 0.5 loss = lossFn(pred[:, 0].flatten(), pred[:, 1].flatten() ** 2, y, beta=beta_epoch) if plot is True: - if e % 5 == 0: - if i == 0: - plt.plot(range(0, 1000), - range(0, 1000), - color='black', ls='--') - if loss_type == "no_var_loss": - plt.scatter( - y, - pred.flatten().detach().numpy(), - color="#F45866", - edgecolor="black", - zorder=100, - ) - else: - plt.errorbar( - y, - pred[:, 0].flatten().detach().numpy(), - yerr=abs(pred[:, 1]. - flatten().detach().numpy()), - linestyle="None", - color="black", - capsize=2, - zorder=100, - ) - plt.scatter( - y, - pred[:, 0].flatten().detach().numpy(), - color="black", - zorder=100, - ) - if loss_type == "bnll_loss": - plt.annotate(r'$\beta = $' + - str(round(beta_epoch, 2)), - xy=(0.03, 0.9), - xycoords='axes fraction') + if (e % (EPOCHS-1) == 0) and (e != 0): + if loss_type == "no_var_loss": + ax1.scatter(y, pred.flatten().detach().numpy(), + color='grey', + alpha=0.5, + label='training data') else: - if loss_type == "no_var_loss": - plt.scatter(y, pred.flatten().detach().numpy(), + if i == 0: + ax1.scatter(y, pred[:, 0].flatten(). + detach().numpy(), color='grey', - alpha=0.5) + alpha=0.5, + label='training data') else: - plt.scatter(y, pred[:, 0].flatten(). + ax1.scatter(y, pred[:, 0].flatten(). detach().numpy(), color='grey', alpha=0.5) @@ -332,21 +321,7 @@ def train_DE( # of the parameters # here, its taking a step for every batch opt.step() - if plot is True: - if e % 5 == 0: - plt.ylabel("prediction") - plt.xlabel("true value") - plt.title("Epoch " + str(e)) - plt.xlim([0, 1000]) - plt.ylim([0, 1000]) - if savefig is True: - plt.errorbar(200, 600, yerr=5, - color='red', capsize=2) - plt.savefig("../images/animations/" + - str(model_name) + "_nmodel_" + - str(m) + "_beta_0.5" + "_epoch_" + - str(epoch) + ".png") - plt.show() + loss_all_epochs.append(loss_this_epoch) # print('training loss', np.mean(loss_this_epoch)) @@ -365,11 +340,11 @@ def train_DE( y_pred[:, 1].flatten() ** 2, ).item() if loss_type == "bnll_loss": - loss = lossFn( - y_pred[:, 0].flatten(), - y_pred[:, 1].flatten() ** 2, - torch.Tensor(y_val), - ).item() + loss = lossFn(y_pred[:, 0].flatten(), + y_pred[:, 1].flatten() ** 2, + torch.Tensor(y_val), + beta=beta_epoch + ).item() loss_validation.append(loss) mse_loss = torch.nn.MSELoss(reduction='mean') @@ -381,6 +356,78 @@ def train_DE( print("new best loss", loss, "in epoch", epoch) # best_weights = copy.deepcopy(model.state_dict()) # print('validation loss', mse) + if (plot is True) and (e % (EPOCHS-1) == 0) and (e != 0): + ax1.plot(range(0, 1000), + range(0, 1000), + color='black', + ls='--') + if loss_type == "no_var_loss": + ax1.scatter( + y_val, + y_pred.flatten().detach().numpy(), + color="#F45866", + edgecolor="black", + zorder=100, + label='validation dtata' + ) + else: + ax1.errorbar( + y_val, + y_pred[:, 0].flatten().detach().numpy(), + yerr=abs(y_pred[:, 1]. + flatten().detach().numpy()), + linestyle="None", + color="black", + capsize=2, + zorder=100, + ) + ax1.scatter( + y_val, + y_pred[:, 0].flatten().detach().numpy(), + color="#9CD08F", + s=5, + zorder=101, + label='validation data' + ) + if loss_type == "bnll_loss": + ax2.annotate(r'$\beta = $' + + str(round(beta_epoch, 2)), + xy=(0.03, 0.4), + xycoords='axes fraction') + + # add residual plot + residuals = y_pred[:, 0].flatten().detach().numpy() - y_val + ax2.scatter(y_val, residuals, + color='#9B287B', + s=5) + ax2.axhline(0, color='black', linestyle='--', linewidth=1) + ax2.set_ylabel("Residuals") + ax2.set_xlabel("True Value") + + + # add annotion for loss value + ax2.annotate(str(loss_type) + ' = ' + str(round(loss,2)), + xy=(0.03, 0.25), + xycoords='axes fraction') + # also add annotations for mse + ax2.annotate(r'MSE = ' + str(round(mse,2)), + xy=(0.03, 0.1), + xycoords='axes fraction') + + ax1.set_ylabel("Prediction") + ax1.set_title("Epoch " + str(e)) + ax1.set_xlim([0, 1000]) + ax1.set_ylim([0, 1000]) + ax1.legend() + if savefig is True: + ax1.errorbar(200, 600, yerr=5, + color='red', capsize=2) + plt.savefig("../images/animations/" + + str(model_name) + "_nmodel_" + + str(m) + "_beta_0.5" + "_epoch_" + + str(epoch) + ".png") + plt.show() + plt.close() if save_checkpoints is True: