
Commit

Merge pull request #65 from deepskies/issue/sDER_epistemic
Issue/s der epistemic
beckynevin authored Feb 13, 2024
2 parents c204c51 + 627833a commit 5cf6ff6
Showing 5 changed files with 2,522 additions and 238 deletions.
258 changes: 258 additions & 0 deletions notebooks/epistemic_by_epoch.ipynb

Large diffs are not rendered by default.

153 changes: 83 additions & 70 deletions notebooks/save_dataframe_linefit.ipynb

Large diffs are not rendered by default.

2,237 changes: 2,123 additions & 114 deletions notebooks/unreasonable_DER_linefit.ipynb

Large diffs are not rendered by default.

19 changes: 10 additions & 9 deletions src/scripts/models.py
@@ -52,14 +52,8 @@ def forward(self, x):
         x = self.ln_4(x)
         return x
 
-## in numpyro, you must specify number of sampling chains you will use upfront
-
-# words of wisdom from Tian Li and crew:
-# on gpu, don't use conda, use pip install
-# HMC after SBI to look at degeneracies between params
-# different guides (some are slower but better at showing degeneracies)
-
-# This is from PasteurLabs -
+# The following is from PasteurLabs -
 # https://github.com/pasteurlabs/unreasonable_effective_der/blob/main/models.py
 
 class Model(nn.Module):
@@ -117,8 +111,15 @@ def loss_der(y, y_pred, coeff):


 def loss_sder(y, y_pred, coeff):
-    gamma, nu, _, beta = y[:, 0], y[:, 1], y[:, 2], y[:, 3]
+    gamma, nu, alpha, beta = y[:, 0], y[:, 1], y[:, 2], y[:, 3]
     error = gamma - y_pred
     var = beta / nu
 
-    return torch.mean(torch.log(var) + (1. + coeff * nu) * error**2 / var)
+    # define the aleatoric and epistemic uncertainties
+    u_al = np.sqrt(beta.detach().numpy() * (1 + nu.detach().numpy()) /
+                   (alpha.detach().numpy() * nu.detach().numpy()))
+    u_ep = 1 / np.sqrt(nu.detach().numpy())
+
+    return torch.mean(torch.log(var) + (1. + coeff * nu) * error**2 / var), \
+        u_al, \
+        u_ep
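
Aside (not part of the diff): loss_sder now returns a tuple rather than a bare loss tensor, so every caller has to unpack it. The extra outputs are the per-sample aleatoric uncertainty u_al = sqrt(beta * (1 + nu) / (alpha * nu)) and epistemic uncertainty u_ep = 1 / sqrt(nu), both returned as numpy arrays. A minimal usage sketch, assuming src/scripts is importable as a package and using made-up toy tensors:

import torch
from src.scripts import models

# Toy network output: columns are (gamma, nu, alpha, beta); nu, alpha, beta > 0.
pred = torch.tensor([[0.5, 2.0, 3.0, 1.0],
                     [1.0, 1.5, 2.5, 0.8]])
target = torch.tensor([0.4, 1.1])

loss, u_al, u_ep = models.loss_sder(pred, target, coeff=0.01)
print(loss.item())  # scalar torch loss, as before
print(u_al, u_ep)   # per-sample numpy arrays of the two uncertainties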
93 changes: 48 additions & 45 deletions src/scripts/train.py
@@ -36,16 +36,36 @@ def train_model(data_source, n_epochs):

     return 0
 
+def model_setup_DER(DER_type, DEVICE):
+    # initialize the model from scratch
+    if DER_type == 'SDER':
+        #model = models.de_no_var().to(device)
+        DERLayer = models.SDERLayer
+
+        # initialize our loss function
+        lossFn = models.loss_sder
+    else:
+        #model = models.de_var().to(device)
+        DERLayer = models.DERLayer
+        # initialize our loss function
+        lossFn = models.loss_der
+
+    # from https://github.com/pasteurlabs/unreasonable_effective_der/blob/main/x3_indepth.ipynb
+    model = torch.nn.Sequential(models.Model(4), DERLayer())
+    model = model.to(DEVICE)
+    return model, lossFn
+
 def train_DER(trainDataLoader,
               x_val,
               y_val,
               INIT_LR,
               DEVICE,
               COEFF,
               DER_type,
-              model_name='DER',
+              DER_name,
               EPOCHS=40,
               save_checkpoints=False,
+              path_to_model='models/',
               plot=False):
     # measure how long training is going to take
     print("[INFO] training the network...")
@@ -58,10 +78,10 @@ def train_DER(trainDataLoader,
     # Find last epoch saved
     if save_checkpoints:
 
-        print(glob.glob('models/*'+model_name+'*'))
+        print(glob.glob(path_to_model+"/"+str(DER_name)+'*'))
         list_models_run = []
-        for file in glob.glob('models/*'+model_name+'*'):
-            list_models_run.append(float(str.split(str(str.split(file, model_name+'_')[1]),'.')[0]))
+        for file in glob.glob(path_to_model+"/"+str(DER_name)+'*'):
+            list_models_run.append(float(str.split(str(str.split(file, DER_name+'_')[1]),'.')[0]))
         if list_models_run:
             start_epoch = max(list_models_run) + 1
         else:
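
Aside: the nested str.split expression above recovers the epoch number from a checkpoint filename. An equivalent, more readable sketch of just the parsing, assuming checkpoints are named like models/DER_12.pt (with DER_name = 'DER'):

import glob
import os

def last_saved_epoch(path_to_model, DER_name):
    """Return the highest epoch number found in saved checkpoints, or None."""
    epochs = []
    for f in glob.glob(os.path.join(path_to_model, DER_name + '*')):
        # e.g. "models/DER_12.pt" -> "12.pt" -> "12" -> 12.0
        epochs.append(float(f.split(DER_name + '_')[1].split('.')[0]))
    return max(epochs) if epochs else None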
@@ -77,22 +97,8 @@ def train_DER(trainDataLoader,
     best_loss = np.inf  # init to infinity
 
 
-    # initialize the model from scratch
-    if DER_type == 'SDER':
-        #model = models.de_no_var().to(device)
-        DERLayer = models.SDERLayer
-
-        # initialize our loss function
-        lossFn = models.loss_sder
-    else:
-        #model = models.de_var().to(device)
-        DERLayer = models.DERLayer
-        # initialize our loss function
-        lossFn = models.loss_der
-
-    # from https://github.com/pasteurlabs/unreasonable_effective_der/blob/main/x3_indepth.ipynb
-    model = torch.nn.Sequential(models.Model(4), DERLayer())
-    model = model.to(DEVICE)
+    model, lossFn = model_setup_DER(DER_type, DEVICE)
 
     loss_fct = functools.partial(lossFn, coeff=COEFF)
     opt = torch.optim.Adam(model.parameters(), lr=INIT_LR)
 
@@ -121,10 +127,6 @@ def train_DER(trainDataLoader,
             # perform a forward pass and calculate the training loss
 
             pred = model(x)
-            #print('shapes train', np.shape(pred), np.shape(y))
-            #print('x', x)
-            #print('y', y)
-
             loss = lossFn(pred, y, COEFF)
             if plot == True:
                 if e % 5 == 0:
@@ -134,30 +136,25 @@ def train_DER(trainDataLoader,
                                 color='#F45866',
                                 edgecolor='black',
                                 zorder=100)
-                    '''
-                    else:
-                        plt.errorbar(y,
-                                     pred[:, 0].flatten().detach().numpy(),
-                                     yerr=abs(pred[:, 1].flatten().detach().numpy()),
-                                     linestyle='None',
-                                     color='#F45866',
-                                     zorder=100)
-                        plt.scatter(y,
-                                    pred[:, 0].flatten().detach().numpy(),
-                                    color='#F45866',
-                                    edgecolor='black',
-                                    zorder=100)
-                    '''
+                    plt.errorbar(y, pred[:, 0].flatten().detach().numpy(),
+                                 yerr=loss[2],
+                                 color='#F45866',
+                                 zorder=100,
+                                 ls='None')
+                    plt.annotate(r'med $u_{ep}$ = '+str(np.median(loss[2])),
+                                 xy=(0.03, 0.93),
+                                 xycoords='axes fraction',
+                                 color='#F45866')
 
                 else:
                     plt.scatter(y, pred[:, 0].flatten().detach().numpy())
 
-            loss_this_epoch.append(loss.item())
+            loss_this_epoch.append(loss[0].item())
 
             # zero out the gradients
             opt.zero_grad()
             # perform the backpropagation step
             # computes the derivative of loss with respect to the parameters
-            loss.backward()
+            loss[0].backward()
             # update the weights
             # optimizer takes a step based on the gradients of the parameters
             # here, it's taking a step for every batch
@@ -177,24 +174,30 @@ def train_DER(trainDataLoader,
         #print('x val', x_val)
         #print('y val', y_val)
         y_pred = model(torch.Tensor(x_val))
-        NIGloss_val = lossFn(y_pred, torch.Tensor(y_val), COEFF).item()
+        loss = lossFn(y_pred, torch.Tensor(y_val), COEFF)
+        NIGloss_val = loss[0].item()
+        med_u_al_val = np.median(loss[1])
+        med_u_ep_val = np.median(loss[2])
 
         loss_validation.append(NIGloss_val)
         if NIGloss_val < best_loss:
             best_loss = NIGloss_val
             print('new best loss', NIGloss_val, 'in epoch', epoch)
             #best_weights = copy.deepcopy(model.state_dict())
         #print('validation loss', mse)
 
 
 
         if save_checkpoints:
 
             torch.save({
                 'epoch': epoch,
                 'model_state_dict': model.state_dict(),
                 'optimizer_state_dict': opt.state_dict(),
                 'train_loss': np.mean(loss_this_epoch),
-                'valid_loss': NIGloss_val
-            }, "/home/rnevin/deepskieslab/rnevin/TinyCNN/models/TinyCNN_MSE_"+str(epoch)+".pt")
+                'valid_loss': NIGloss_val,
+                'med_u_al_validation': med_u_al_val,
+                'med_u_ep_validation': med_u_ep_val,
+            }, path_to_model + "/" + str(DER_name)+"_"+str(epoch)+".pt")
     endTime = time.time()
     print('start at', startTime, 'end at', endTime)
     print(endTime - startTime)
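
Aside: with the checkpoint dict extended above, the per-epoch median uncertainties can be read back out after training. A sketch, assuming an SDER run saved a checkpoint as models/DER_39.pt (the path and epoch number are hypothetical; the field names match the torch.save call in the diff):

import torch
from src.scripts.train import model_setup_DER

ckpt = torch.load("models/DER_39.pt", map_location="cpu")
model, lossFn = model_setup_DER("SDER", torch.device("cpu"))
model.load_state_dict(ckpt["model_state_dict"])  # same Sequential architecture

print(ckpt["epoch"], ckpt["train_loss"], ckpt["valid_loss"])
print(ckpt["med_u_al_validation"], ckpt["med_u_ep_validation"])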

0 comments on commit 5cf6ff6
