Skip to content

Commit

Permalink
adding a separate get model thing for DER
Browse files Browse the repository at this point in the history
  • Loading branch information
beckynevin committed Feb 7, 2024
1 parent 219e006 commit acda364
Showing 1 changed file with 44 additions and 41 deletions.
85 changes: 44 additions & 41 deletions src/scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,25 @@ def train_model(data_source, n_epochs):

return 0

def model_setup_DER(DER_type, DEVICE):
# initialize the model from scratch
if DER_type == 'SDER':
#model = models.de_no_var().to(device)
DERLayer = models.SDERLayer

# initialize our loss function
lossFn = models.loss_sder
else:
#model = models.de_var().to(device)
DERLayer = models.DERLayer
# initialize our loss function
lossFn = models.loss_der

# from https://github.com/pasteurlabs/unreasonable_effective_der/blob/main/x3_indepth.ipynb
model = torch.nn.Sequential(models.Model(4), DERLayer())
model = model.to(DEVICE)
return model, lossFn

def train_DER(trainDataLoader,
x_val,
y_val,
Expand All @@ -46,6 +65,7 @@ def train_DER(trainDataLoader,
model_name='DER',
EPOCHS=40,
save_checkpoints=False,
path_to_model='models/',
plot=False):
# measure how long training is going to take
print("[INFO] training the network...")
Expand Down Expand Up @@ -77,22 +97,8 @@ def train_DER(trainDataLoader,
best_loss = np.inf # init to infinity


# initialize the model from scratch
if DER_type == 'SDER':
#model = models.de_no_var().to(device)
DERLayer = models.SDERLayer

# initialize our loss function
lossFn = models.loss_sder
else:
#model = models.de_var().to(device)
DERLayer = models.DERLayer
# initialize our loss function
lossFn = models.loss_der

# from https://github.com/pasteurlabs/unreasonable_effective_der/blob/main/x3_indepth.ipynb
model = torch.nn.Sequential(models.Model(4), DERLayer())
model = model.to(DEVICE)
model, lossFn = model_setup_DER(DER_type, DEVICE)

loss_fct = functools.partial(lossFn, coeff=COEFF)
opt = torch.optim.Adam(model.parameters(), lr=INIT_LR)

Expand Down Expand Up @@ -121,10 +127,6 @@ def train_DER(trainDataLoader,
# perform a forward pass and calculate the training loss

pred = model(x)
#print('shapes train', np.shape(pred), np.shape(y))
#print('x', x)
#print('y', y)

loss = lossFn(pred, y, COEFF)
if plot == True:
if e % 5 == 0:
Expand All @@ -134,30 +136,25 @@ def train_DER(trainDataLoader,
color='#F45866',
edgecolor='black',
zorder=100)
'''
else:
plt.errorbar(y,
pred[:, 0].flatten().detach().numpy(),
yerr=abs(pred[:, 1].flatten().detach().numpy()),
linestyle='None',
color='#F45866',
zorder=100)
plt.scatter(y,
pred[:, 0].flatten().detach().numpy(),
color='#F45866',
edgecolor='black',
zorder=100)
'''
plt.errorbar(y, pred[:,0].flatten().detach().numpy(),
yerr = loss[2],
color='#F45866',
zorder=100,
ls='None')
plt.annotate(r'med $u_{ep} = '+str(np.median(loss[2])),
xy = (0.03, 0.93),
xycoords='axes fraction',
color='#F45866')

else:
plt.scatter(y, pred[:, 0].flatten().detach().numpy())

loss_this_epoch.append(loss.item())
loss_this_epoch.append(loss[0].item())

# zero out the gradients
opt.zero_grad()
# perform the backpropagation step
# computes the derivative of loss with respect to the parameters
loss.backward()
loss[0].backward()
# update the weights
# optimizer takes a step based on the gradients of the parameters
# here, its taking a step for every batch
Expand All @@ -177,24 +174,30 @@ def train_DER(trainDataLoader,
#print('x val', x_val)
#print('y val', y_val)
y_pred = model(torch.Tensor(x_val))
NIGloss_val = lossFn(y_pred, torch.Tensor(y_val), COEFF).item()
loss = lossFn(y_pred, torch.Tensor(y_val), COEFF)
NIGloss_val = loss[0].item()
med_u_al_val = np.median(loss[1])
med_u_ep_val = np.median(loss[2])

loss_validation.append(NIGloss_val)
if NIGloss_val < best_loss:
best_loss = NIGloss_val
print('new best loss', NIGloss_val, 'in epoch', epoch)
#best_weights = copy.deepcopy(model.state_dict())
#print('validation loss', mse)



if save_checkpoints:

torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': opt.state_dict(),
'train_loss': np.mean(loss_this_epoch),
'valid_loss': NIGloss_val
}, "/home/rnevin/deepskieslab/rnevin/TinyCNN/models/TinyCNN_MSE_"+str(epoch)+".pt")
'valid_loss': NIGloss_val,
'med_u_al_validation': med_u_al_val,
'med_u_ep_validation': med_u_ep_val,
}, path_to_model + "/" + str(DER_type)+"_"+str(epoch)+".pt")
endTime = time.time()
print('start at', startTime, 'end at', endTime)
print(endTime - startTime)
Expand Down

0 comments on commit acda364

Please sign in to comment.