Skip to content

Commit

Permalink
correcting flake8 making sure everything is defined in DER
Browse files Browse the repository at this point in the history
  • Loading branch information
beckynevin committed Apr 9, 2024
1 parent 83703cd commit c8e2783
Showing 1 changed file with 42 additions and 86 deletions.
128 changes: 42 additions & 86 deletions src/scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def train_DER(
overwrite_final_checkpoint=False,
plot=True,
savefig=True,
verbose=True
verbose=True,
):
# first determine if you even need to run anything
if not save_all_checkpoints and save_final_checkpoint:
Expand All @@ -38,7 +38,7 @@ def train_DER(
+ "_epoch_"
+ str(EPOCHS - 1)
+ ".pt"
)
)
if verbose:
print("final chk", final_chk)
# check if the final epoch checkpoint already exists
Expand Down Expand Up @@ -66,11 +66,11 @@ def train_DER(

startTime = time.time()
start_epoch = 0

best_loss = np.inf # init to infinity
model, lossFn = models.model_setup_DER(loss_type, DEVICE)
if verbose:
print('model is', model, 'lossfn', lossFn)
print("model is", model, "lossfn", lossFn)

opt = torch.optim.Adam(model.parameters(), lr=INIT_LR)

Expand Down Expand Up @@ -104,8 +104,6 @@ def train_DER(
loss = lossFn(pred, y, COEFF)
if plot or savefig:
if (e % (EPOCHS - 1) == 0) and (e != 0):
pred_loader_0 = pred[:, 0].flatten().detach().numpy()
y_loader_0 = y.detach().numpy()
ax1.scatter(
y,
pred[:, 0].flatten().detach().numpy(),
Expand All @@ -126,12 +124,12 @@ def train_DER(
xycoords="axes fraction",
color="black",
)
'''
"""
else:
ax1.scatter(y,
pred[:, 0].flatten().detach().numpy(),
color="grey")
'''
"""
loss_this_epoch.append(loss[0].item())

# zero out the gradients
Expand All @@ -143,11 +141,25 @@ def train_DER(
# optimizer takes a step based on the gradients of the parameters
# here, its taking a step for every batch
opt.step()
model.eval()
y_pred = model(torch.Tensor(x_val))
loss = lossFn(y_pred, torch.Tensor(y_val), COEFF)
NIGloss_val = loss[0].item()
med_u_al_val = np.median(loss[1])
med_u_ep_val = np.median(loss[2])
std_u_al_val = np.std(loss[1])
std_u_ep_val = np.std(loss[2])

# lets also grab mse loss
mse_loss = torch.nn.MSELoss(reduction="mean")
mse = mse_loss(y_pred[:, 0], torch.Tensor(y_val)).item()
if NIGloss_val < best_loss:
best_loss = NIGloss_val
if verbose:
print("new best loss", NIGloss_val, "in epoch", epoch)
# best_weights = copy.deepcopy(model.state_dict())
if (plot or savefig) and (e % (EPOCHS - 1) == 0) and (e != 0):
ax1.plot(range(0, 1000),
range(0, 1000),
color="black",
ls="--")
ax1.plot(range(0, 1000), range(0, 1000), color="black", ls="--")
if loss_type == "no_var_loss":
ax1.scatter(
y_val,
Expand Down Expand Up @@ -191,42 +203,21 @@ def train_DER(
ax2.set_ylabel("Residuals")
ax2.set_xlabel("True Value")
# add annotion for loss value
if loss_type == "bnll_loss":
ax1.annotate(
r"$\beta = $"
+ str(round(beta_epoch, 2))
+ "\n"
+ str(loss_type)
+ " = "
+ str(round(loss, 2))
+ "\n"
+ r"MSE = "
+ str(round(mse, 2)),
xy=(0.73, 0.1),
xycoords="axes fraction",
bbox=dict(
boxstyle="round,pad=0.5",
facecolor="lightgrey",
alpha=0.5
),
)

else:
ax1.annotate(
str(loss_type)
+ " = "
+ str(round(loss, 2))
+ "\n"
+ r"MSE = "
+ str(round(mse, 2)),
xy=(0.73, 0.1),
xycoords="axes fraction",
bbox=dict(
boxstyle="round,pad=0.5",
facecolor="lightgrey",
alpha=0.5
),
)
ax1.annotate(
str(loss_type)
+ " = "
+ str(round(loss, 2))
+ "\n"
+ r"MSE = "
+ str(round(mse, 2)),
xy=(0.73, 0.1),
xycoords="axes fraction",
bbox=dict(
boxstyle="round,pad=0.5",
facecolor="lightgrey",
alpha=0.5
),
)
ax1.set_ylabel("Prediction")
ax1.set_title("Epoch " + str(e))
ax1.set_xlim([0, 1000])
Expand All @@ -248,40 +239,6 @@ def train_DER(
if plot:
plt.show()
plt.close()
'''
if plot and (e % 5 == 0):
ax1.set_ylabel("prediction")
ax1.set_title("Epoch " + str(e))
# Residuals plot
residuals = pred_loader_0 - y_loader_0
ax2.scatter(y_loader_0, residuals, color="red")
ax2.axhline(0, color="black", linestyle="--", linewidth=1)
ax2.set_ylabel("Residuals")
ax2.set_xlabel("True Value")
plt.show()
plt.close()
'''
model.eval()
y_pred = model(torch.Tensor(x_val))
loss = lossFn(y_pred, torch.Tensor(y_val), COEFF)
NIGloss_val = loss[0].item()
med_u_al_val = np.median(loss[1])
med_u_ep_val = np.median(loss[2])
std_u_al_val = np.std(loss[1])
std_u_ep_val = np.std(loss[2])

# lets also grab mse loss
mse_loss = torch.nn.MSELoss(reduction="mean")
mse = mse_loss(y_pred[:, 0], torch.Tensor(y_val)).item()
if NIGloss_val < best_loss:
best_loss = NIGloss_val
if verbose:
print("new best loss", NIGloss_val, "in epoch", epoch)
# best_weights = copy.deepcopy(model.state_dict())
# print('validation loss', mse)

if save_all_checkpoints:

torch.save(
Expand All @@ -307,7 +264,7 @@ def train_DER(
+ ".pt",
)
if save_final_checkpoint and (e % (EPOCHS - 1) == 0) and (e != 0):
# option to just save final epoch
# option to just save final epoch
torch.save(
{
"epoch": epoch,
Expand All @@ -329,7 +286,7 @@ def train_DER(
+ "_epoch_"
+ str(epoch)
+ ".pt",
)
)
endTime = time.time()
if verbose:
print("start at", startTime, "end at", endTime)
Expand Down Expand Up @@ -371,8 +328,7 @@ def train_DE(

model_ensemble = []

print('this is the value of save_final_checkpoint',
save_final_checkpoint)
print("this is the value of save_final_checkpoint", save_final_checkpoint)

for m in range(n_models):
print("model", m)
Expand Down

0 comments on commit c8e2783

Please sign in to comment.