Commit 7da4a5d: making the optimizer a separate thing

beckynevin committed Mar 18, 2024
1 parent d2f1756 commit 7da4a5d
Showing 1 changed file with 147 additions and 100 deletions.

247 changes: 147 additions & 100 deletions src/scripts/train.py
@@ -58,6 +58,12 @@ def train_DER(

# loop over our epochs
for e in range(0, EPOCHS):
if plot:
plt.clf()
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 6),
gridspec_kw={'height_ratios': [3, 1]}
)
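# height_ratios=[3, 1] gives the prediction panel (ax1) three
# times the height of the residuals panel (ax2)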

epoch = int(start_epoch + e)

# set the model in training mode
@@ -66,10 +72,7 @@
# loop over the training set
if verbose:
print("epoch", epoch, round(e / EPOCHS, 2))

loss_this_epoch = []

plt.clf()
# randomly shuffles the training data (if shuffle = True)
# and draws batches up to the total training size
# (should be about 8 batches)
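# a minimal sketch of the loader these comments assume; the actual
# construction lives outside this hunk, and trainData / BATCH_SIZE
# are placeholder names rather than ones confirmed by this diff:
#   trainDataLoader = torch.utils.data.DataLoader(
#       trainData, batch_size=BATCH_SIZE, shuffle=True)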
@@ -81,34 +84,34 @@

pred = model(x)
loss = lossFn(pred, y, COEFF)
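# judging from the uses below (loss[0].item() as the running
# scalar, loss[2] as per-point error bars labeled u_ep), lossFn
# appears to return a tuple of (scalar loss, ..., epistemic
# uncertainty); this is inferred from the diff, not a confirmed
# signature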
if plot:
if e % 5 == 0:
if i == 0:
# if loss_type == 'no_var_loss':
plt.scatter(
y,
pred[:, 0].flatten().detach().numpy(),
color="#F45866",
edgecolor="black",
zorder=100,
)
plt.errorbar(
y,
pred[:, 0].flatten().detach().numpy(),
yerr=loss[2],
color="#F45866",
zorder=100,
ls="None",
)
plt.annotate(
r"med $u_{ep} = " + str(np.median(loss[2])),
xy=(0.03, 0.93),
xycoords="axes fraction",
color="#F45866",
)

else:
plt.scatter(y, pred[:, 0].flatten().detach().numpy())
if plot and (e % 5 == 0):
if i == 0:
pred_loader_0 = pred[:, 0].flatten().detach().numpy()
y_loader_0 = y.detach().numpy()
ax1.scatter(
y,
pred[:, 0].flatten().detach().numpy(),
color="black",
zorder=100,
)
ax1.errorbar(
y,
pred[:, 0].flatten().detach().numpy(),
yerr=loss[2],
color="black",
zorder=100,
ls="None",
)
ax1.annotate(
r"med $u_{ep} = $" + str(np.median(loss[2])),
xy=(0.03, 0.93),
xycoords="axes fraction",
color="black",
)
else:
ax1.scatter(y,
pred[:, 0].flatten().detach().numpy(),
color='grey')
loss_this_epoch.append(loss[0].item())

# zero out the gradients
@@ -120,12 +123,19 @@
# optimizer takes a step based on the gradients of the parameters
# here, it's taking a step for every batch
opt.step()
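# a sketch of the canonical per-batch update this block performs,
# assuming the elided lines do the usual zeroing and backward pass:
#   opt.zero_grad()
#   loss[0].backward()
#   opt.step()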
if plot:
if e % 5 == 0:
plt.ylabel("prediction")
plt.xlabel("true value")
plt.title("Epoch " + str(e))
plt.show()
if plot and (e % 5 == 0):
ax1.set_ylabel("prediction")
ax1.set_title("Epoch " + str(e))

# Residuals plot
residuals = pred_loader_0 - y_loader_0
ax2.scatter(y_loader_0, residuals, color='red')
ax2.axhline(0, color='black', linestyle='--', linewidth=1)
ax2.set_ylabel("Residuals")
ax2.set_xlabel("True Value")

plt.show()
plt.close()
model.eval()
y_pred = model(torch.Tensor(x_val))
loss = lossFn(y_pred, torch.Tensor(y_val), COEFF)
@@ -223,10 +233,12 @@ def train_DE(
for m in range(n_models):
print('model', m)
# initialize the model again each time from scratch
model, lossFn, opt = models.model_setup_DE(loss_type, DEVICE, INIT_LR)
model, lossFn = models.model_setup_DE(loss_type, DEVICE)
opt = torch.optim.Adam(model.parameters(), lr=INIT_LR)
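# with the optimizer now decoupled from model_setup_DE, swapping it
# is a one-line change, e.g. (illustrative, not part of this repo):
#   opt = torch.optim.SGD(model.parameters(), lr=INIT_LR,
#                         momentum=0.9)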

# loop over our epochs
for e in range(0, EPOCHS):
plt.close()
epoch = int(start_epoch + e)

# set the model in training mode
@@ -237,8 +249,12 @@
print("epoch", epoch, round(e / EPOCHS, 2))

loss_this_epoch = []

plt.clf()
if plot is True:
plt.clf()
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 6),
gridspec_kw={'height_ratios': [3, 1]}
)

# randomly shuffles the training data (if shuffle = True)
# and draws batches up to the total training size
# (should be about 8 batches)
@@ -267,54 +283,27 @@
# 1 - e / EPOCHS # this one doesn't work great
'''
#beta_epoch = 1 - e / EPOCHS
beta_epoch = 1
beta_epoch = 0.5
loss = lossFn(pred[:, 0].flatten(),
pred[:, 1].flatten() ** 2,
y,
beta=beta_epoch)
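# a hedged sketch of the beta-weighted Gaussian NLL this call
# suggests (beta-NLL, Seitzer et al. 2022); whether lossFn matches
# it exactly is an assumption, not something this diff confirms:
#   def bnll_loss(mu, var, y, beta=0.5):
#       nll = 0.5 * (torch.log(var) + (y - mu) ** 2 / var)
#       weight = var.detach() ** beta  # stop-gradient on the weight
#       return (weight * nll).mean()
# beta = 0 recovers the ordinary NLL, while beta = 1 reweights each
# sample's gradient to behave like MSE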
if plot is True:
if e % 5 == 0:
if i == 0:
plt.plot(range(0, 1000),
range(0, 1000),
color='black', ls='--')
if loss_type == "no_var_loss":
plt.scatter(
y,
pred.flatten().detach().numpy(),
color="#F45866",
edgecolor="black",
zorder=100,
)
else:
plt.errorbar(
y,
pred[:, 0].flatten().detach().numpy(),
yerr=abs(pred[:, 1].
flatten().detach().numpy()),
linestyle="None",
color="black",
capsize=2,
zorder=100,
)
plt.scatter(
y,
pred[:, 0].flatten().detach().numpy(),
color="black",
zorder=100,
)
if loss_type == "bnll_loss":
plt.annotate(r'$\beta = $' +
str(round(beta_epoch, 2)),
xy=(0.03, 0.9),
xycoords='axes fraction')
if (e % (EPOCHS-1) == 0) and (e != 0):
if loss_type == "no_var_loss":
ax1.scatter(y, pred.flatten().detach().numpy(),
color='grey',
alpha=0.5,
label='training data')
else:
if loss_type == "no_var_loss":
plt.scatter(y, pred.flatten().detach().numpy(),
if i == 0:
ax1.scatter(y, pred[:, 0].flatten().
detach().numpy(),
color='grey',
alpha=0.5)
alpha=0.5,
label='training data')
else:
plt.scatter(y, pred[:, 0].flatten().
ax1.scatter(y, pred[:, 0].flatten().
detach().numpy(),
color='grey',
alpha=0.5)
@@ -332,21 +321,7 @@
# of the parameters
# here, it's taking a step for every batch
opt.step()
if plot is True:
if e % 5 == 0:
plt.ylabel("prediction")
plt.xlabel("true value")
plt.title("Epoch " + str(e))
plt.xlim([0, 1000])
plt.ylim([0, 1000])
if savefig is True:
plt.errorbar(200, 600, yerr=5,
color='red', capsize=2)
plt.savefig("../images/animations/" +
str(model_name) + "_nmodel_" +
str(m) + "_beta_0.5" + "_epoch_" +
str(epoch) + ".png")
plt.show()


loss_all_epochs.append(loss_this_epoch)
# print('training loss', np.mean(loss_this_epoch))
@@ -365,11 +340,11 @@
y_pred[:, 1].flatten() ** 2,
).item()
if loss_type == "bnll_loss":
loss = lossFn(
y_pred[:, 0].flatten(),
y_pred[:, 1].flatten() ** 2,
torch.Tensor(y_val),
).item()
loss = lossFn(y_pred[:, 0].flatten(),
y_pred[:, 1].flatten() ** 2,
torch.Tensor(y_val),
beta=beta_epoch
).item()

loss_validation.append(loss)
mse_loss = torch.nn.MSELoss(reduction='mean')
@@ -381,6 +356,78 @@
print("new best loss", loss, "in epoch", epoch)
# best_weights = copy.deepcopy(model.state_dict())
# print('validation loss', mse)
if (plot is True) and (e % (EPOCHS-1) == 0) and (e != 0):
ax1.plot(range(0, 1000),
range(0, 1000),
color='black',
ls='--')
if loss_type == "no_var_loss":
ax1.scatter(
y_val,
y_pred.flatten().detach().numpy(),
color="#F45866",
edgecolor="black",
zorder=100,
label='validation data'
)
else:
ax1.errorbar(
y_val,
y_pred[:, 0].flatten().detach().numpy(),
yerr=abs(y_pred[:, 1].
flatten().detach().numpy()),
linestyle="None",
color="black",
capsize=2,
zorder=100,
)
ax1.scatter(
y_val,
y_pred[:, 0].flatten().detach().numpy(),
color="#9CD08F",
s=5,
zorder=101,
label='validation data'
)
if loss_type == "bnll_loss":
ax2.annotate(r'$\beta = $' +
str(round(beta_epoch, 2)),
xy=(0.03, 0.4),
xycoords='axes fraction')

# add residual plot
residuals = y_pred[:, 0].flatten().detach().numpy() - y_val
ax2.scatter(y_val, residuals,
color='#9B287B',
s=5)
ax2.axhline(0, color='black', linestyle='--', linewidth=1)
ax2.set_ylabel("Residuals")
ax2.set_xlabel("True Value")


# add annotation for the loss value
ax2.annotate(str(loss_type) + ' = ' + str(round(loss, 2)),
xy=(0.03, 0.25),
xycoords='axes fraction')
# also add an annotation for MSE
ax2.annotate(r'MSE = ' + str(round(mse, 2)),
xy=(0.03, 0.1),
xycoords='axes fraction')

ax1.set_ylabel("Prediction")
ax1.set_title("Epoch " + str(e))
ax1.set_xlim([0, 1000])
ax1.set_ylim([0, 1000])
ax1.legend()
if savefig is True:
ax1.errorbar(200, 600, yerr=5,
color='red', capsize=2)
plt.savefig("../images/animations/" +
str(model_name) + "_nmodel_" +
str(m) + "_beta_0.5" + "_epoch_" +
str(epoch) + ".png")
plt.show()
plt.close()

if save_checkpoints is True:
