Commit 046cda9

running DE from command line dumping into config :)

beckynevin committed Apr 17, 2024
1 parent 59e7179 commit 046cda9
Showing 3 changed files with 274 additions and 37 deletions.
246 changes: 230 additions & 16 deletions src/models/models.py
@@ -8,18 +8,7 @@
from torch.utils.data import TensorDataset
import torch
import h5py


-def parse_args():
-    parser = argparse.ArgumentParser(description="data handling module")
-    parser.add_argument(
-        "--arg",
-        type=float,
-        required=False,
-        default=100,
-        help="Description",
-    )
-    return parser.parse_args()
import torch.nn as nn


class ModelLoader:
@@ -58,7 +47,232 @@ def predict(input, model):
        return 0


-# Example usage:
-if __name__ == "__main__":
-    namespace = parse_args()
-    arg = namespace.arg
class DERLayer(nn.Module):
    """Map raw network outputs to the DER parameters
    (gamma, nu, alpha, beta), with softplus enforcing
    nu > 0, alpha > 1, and beta > 0."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        gamma = x[:, 0]
        nu = nn.functional.softplus(x[:, 1])
        alpha = nn.functional.softplus(x[:, 2]) + 1.0
        beta = nn.functional.softplus(x[:, 3])
        return torch.stack((gamma, nu, alpha, beta), dim=1)


class SDERLayer(nn.Module):
    """Simplified DER layer: alpha is tied to nu (alpha = nu + 1),
    so the third network output is unused."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        gamma = x[:, 0]
        nu = nn.functional.softplus(x[:, 1])
        alpha = nu + 1.0
        beta = nn.functional.softplus(x[:, 3])
        return torch.stack((gamma, nu, alpha, beta), dim=1)
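
# As a minimal usage sketch (the tensors here are hypothetical), both
# layers take a (batch, 4) tensor of unconstrained network outputs and
# return the evidential parameters with their constraints applied:
#
#     raw = torch.randn(8, 4)
#     params = DERLayer()(raw)
#     gamma, nu, alpha, beta = params.unbind(dim=1)  # nu > 0, alpha > 1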


def model_setup_DER(loss_type, DEVICE):
    # initialize the model from scratch
    if loss_type == "SDER":
        Layer = SDERLayer
        # initialize our loss function
        lossFn = loss_sder
    elif loss_type == "DER":
        Layer = DERLayer
        # initialize our loss function
        lossFn = loss_der
    else:
        raise ValueError(f"unrecognized loss type: {loss_type}")

    # from https://github.com/pasteurlabs/unreasonable_effective_der
    # /blob/main/x3_indepth.ipynb
    model = torch.nn.Sequential(Model(4), Layer())
    model = model.to(DEVICE)
    return model, lossFn
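
# A minimal end-to-end sketch of the DER path, assuming 3 input features
# as in Model below (the data here is hypothetical). Note that loss_der
# takes the network output first and the target second:
#
#     model, lossFn = model_setup_DER("DER", torch.device("cpu"))
#     x, y_true = torch.randn(32, 3), torch.randn(32)
#     loss, u_al, u_ep = lossFn(model(x), y_true, 0.01)  # coeff = 0.01
#     loss.backward()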


class MuVarLayer(nn.Module):
    """Map raw network outputs to (mu, var); softplus enforces
    positivity of the variance."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        mu = x[:, 0]
        var = nn.functional.softplus(x[:, 1])
        return torch.stack((mu, var), dim=1)


def model_setup_DE(loss_type, DEVICE):
    # initialize the model from scratch
    if loss_type == "no_var_loss":
        # mean-only network; MSE needs no variance head
        model = Model(1)
        lossFn = torch.nn.MSELoss(reduction="mean")
    elif loss_type == "var_loss":
        model = torch.nn.Sequential(Model(2), MuVarLayer())
        lossFn = torch.nn.GaussianNLLLoss(full=False,
                                          eps=1e-06,
                                          reduction="mean")
    elif loss_type == "bnll_loss":
        model = torch.nn.Sequential(Model(2), MuVarLayer())
        lossFn = loss_bnll
    else:
        raise ValueError(f"unrecognized loss type: {loss_type}")
    model = model.to(DEVICE)
    return model, lossFn
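
# A similar sketch for an ensemble member with the Gaussian NLL loss
# (hypothetical data); GaussianNLLLoss is called as (input, target, var):
#
#     model, lossFn = model_setup_DE("var_loss", torch.device("cpu"))
#     out = model(torch.randn(32, 3))  # columns: mu, var > 0
#     loss = lossFn(out[:, 0], torch.randn(32), out[:, 1])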


class de_no_var(nn.Module):
    def __init__(self):
        super().__init__()
        drop_percent = 0.1
        self.ln_1 = nn.Linear(3, 100)
        self.act1 = nn.ReLU()
        self.drop1 = nn.Dropout(drop_percent)
        self.ln_2 = nn.Linear(100, 100)
        self.act2 = nn.ReLU()
        self.drop2 = nn.Dropout(drop_percent)
        self.ln_3 = nn.Linear(100, 100)
        self.act3 = nn.ReLU()
        self.drop3 = nn.Dropout(drop_percent)
        self.ln_4 = nn.Linear(100, 1)
        # the last dim is 1 here (mean only); it must be 2 when
        # using the GaussianNLLLoss, as in de_var below

    def forward(self, x):
        x = self.drop1(self.act1(self.ln_1(x)))
        x = self.drop2(self.act2(self.ln_2(x)))
        x = self.drop3(self.act3(self.ln_3(x)))
        x = self.ln_4(x)
        return x


class de_var(nn.Module):
    def __init__(self):
        super().__init__()
        drop_percent = 0.1
        self.ln_1 = nn.Linear(3, 100)
        self.act1 = nn.ReLU()
        self.drop1 = nn.Dropout(drop_percent)
        self.ln_2 = nn.Linear(100, 100)
        self.act2 = nn.ReLU()
        self.drop2 = nn.Dropout(drop_percent)
        self.ln_3 = nn.Linear(100, 100)
        self.act3 = nn.ReLU()
        self.drop3 = nn.Dropout(drop_percent)
        self.ln_4 = nn.Linear(100, 2)
        # this last dim needs to be 2 when using the GaussianNLLLoss

    def forward(self, x):
        x = self.drop1(self.act1(self.ln_1(x)))
        x = self.drop2(self.act2(self.ln_2(x)))
        x = self.drop3(self.act3(self.ln_3(x)))
        x = self.ln_4(x)
        return x


# The following is from PasteurLabs -
# https://github.com/pasteurlabs/unreasonable_effective_der/blob/main/models.py


class Model(nn.Module):
    def __init__(self, n_output, n_hidden=64):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(3, n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_output),
        )

    def forward(self, x):
        return self.model(x)


def loss_der(y, y_pred, coeff):
    # here y is the network output (gamma, nu, alpha, beta);
    # y_pred is the regression target
    gamma, nu, alpha, beta = y[:, 0], y[:, 1], y[:, 2], y[:, 3]
    error = gamma - y_pred
    omega = 2.0 * beta * (1.0 + nu)

    # define aleatoric and epistemic uncert
    # (.cpu() so this also works when the model lives on the GPU)
    u_al = np.sqrt(
        beta.detach().cpu().numpy()
        * (1 + nu.detach().cpu().numpy())
        / (alpha.detach().cpu().numpy() * nu.detach().cpu().numpy())
    )
    u_ep = 1 / np.sqrt(nu.detach().cpu().numpy())
    return (
        torch.mean(
            0.5 * torch.log(math.pi / nu)
            - alpha * torch.log(omega)
            + (alpha + 0.5) * torch.log(error**2 * nu + omega)
            + torch.lgamma(alpha)
            - torch.lgamma(alpha + 0.5)
            + coeff * torch.abs(error) * (2.0 * nu + alpha)
        ),
        u_al,
        u_ep,
    )


def loss_sder(y, y_pred, coeff):
    gamma, nu, alpha, beta = y[:, 0], y[:, 1], y[:, 2], y[:, 3]
    error = gamma - y_pred
    var = beta / nu

    # define aleatoric and epistemic uncert
    # (.cpu() so this also works when the model lives on the GPU)
    u_al = np.sqrt(
        beta.detach().cpu().numpy()
        * (1 + nu.detach().cpu().numpy())
        / (alpha.detach().cpu().numpy() * nu.detach().cpu().numpy())
    )
    u_ep = 1 / np.sqrt(nu.detach().cpu().numpy())

    return torch.mean(torch.log(var) + (1.0 + coeff * nu) * error**2 / var), \
        u_al, u_ep


# from the Martius lab
# https://github.com/martius-lab/beta-nll
# and Seitzer+2020


def loss_bnll(mean, variance, target, beta):
    """Compute beta-NLL loss
    :param mean: Predicted mean of shape B x D
    :param variance: Predicted variance of shape B x D
    :param target: Target of shape B x D
    :param beta: Parameter from range [0, 1] controlling relative
        weighting between data points, where `0` corresponds to
        high weight on low error points and `1` to an equal weighting.
    :returns: Loss per batch element of shape B
    """
    loss = 0.5 * ((target - mean) ** 2 / variance + variance.log())
    if beta > 0:
        loss = loss * (variance.detach() ** beta)
    return loss.sum(axis=-1)
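
# A small sketch of calling loss_bnll on the two-column output of the
# variance network (hypothetical tensors); beta=0.5 matches the BETA
# default this commit adds to src/utils/defaults.py:
#
#     model, _ = model_setup_DE("bnll_loss", torch.device("cpu"))
#     out = model(torch.randn(32, 3))
#     loss = loss_bnll(out[:, 0:1], out[:, 1:2],
#                      torch.randn(32, 1), beta=0.5).mean()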


'''
Legacy Keras/TensorFlow version kept for reference:

def get_loss(transform, beta=None):
    if beta:
        def beta_nll_loss(targets, outputs, beta=beta):
            """Compute beta-NLL loss"""
            mu = outputs[..., 0:1]
            var = transform(outputs[..., 1:2])
            loss = (K.square((targets - mu)) / var + K.log(var))
            loss = loss * K.stop_gradient(var) ** beta
            return loss
        return beta_nll_loss
    else:
        def negative_log_likelihood(targets, outputs):
            """Calculate the negative loglikelihood."""
            mu = outputs[..., 0:1]
            var = transform(outputs[..., 1:2])
            y = targets[..., 0:1]
            loglik = - K.log(var) - K.square((y - mu)) / var
            return - loglik
        return negative_log_likelihood
'''
54 changes: 36 additions & 18 deletions src/scripts/DeepEnsemble.py
@@ -196,15 +196,32 @@ def parse_args():
    os.makedirs(os.path.dirname(temp_config), exist_ok=True)

    input_yaml = {
-        "common": {"out_dir": args.out_dir},
-        #"model": {"model_path":args.model_path, "model_engine":args.model_engine},
        "common": {"out_dir": args.out_dir},
        "model": {"model_path": args.model_path,
                  "model_engine": args.model_engine,
                  "model_type": args.model_type,
                  "loss_type": args.loss_type,
                  "n_models": args.n_models,
                  "init_lr": args.init_lr,
                  "wd": args.wd,
                  "BETA": args.BETA,
                  "n_epochs": args.n_epochs,
                  "path_to_models": args.path_to_models,
                  "save_all_checkpoints": args.save_all_checkpoints,
                  "save_final_checkpoint": args.save_final_checkpoint,
                  "overwrite_final_checkpoint": args.overwrite_final_checkpoint,
                  "plot": args.plot,
                  "savefig": args.savefig,
                  "verbose": args.verbose,
                  },
        "data": {"data_path": args.data_path,
                 "data_engine": args.data_engine,
                 "size_df": args.size_df,
                 "noise_level": args.noise_level,
                 "val_proportion": args.val_proportion,
                 "randomseed": args.randomseed,
-                "batchsize": args.batchsize},
                 "batchsize": args.batchsize,
                 },
        #"plots": {key: {} for key in args.plots},
        #"metrics": {key: {} for key in args.metrics},
    }
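
    # For reference, the temp_config.yml dumped from this dict would look
    # roughly like the following YAML (values here assume the defaults added
    # in src/utils/defaults.py; keys without shown defaults are omitted):
    #
    #     common:
    #       out_dir: ./DeepUQResources/results/
    #     model:
    #       model_path: ./models/
    #       model_engine: DE
    #       model_type: DE
    #       loss_type: bnll
    #       n_models: 100
    #       init_lr: 0.001
    #       BETA: 0.5
    #       n_epochs: 100
    #     data:
    #       data_path: ./data/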
@@ -296,25 +313,26 @@ def beta_type(value):
    # set the device we will be using to train the model
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

-    model_name = namespace.model_type + "_noise_" + noise
-    model, lossFn = models.model_setup_DE(namespace.loss_type, DEVICE)
    model_name = config.get_item("model", "model_type") + "_noise_" + noise
    model, lossFn = models.model_setup_DE(config.get_item("model", "loss_type"),
                                          DEVICE)
    model_ensemble = train.train_DE(
        trainDataLoader,
        x_val,
        y_val,
-        namespace.init_lr,
        config.get_item("model", "init_lr"),
        DEVICE,
-        namespace.loss_type,
-        namespace.n_models,
-        namespace.wd,
        config.get_item("model", "loss_type"),
        config.get_item("model", "n_models"),
        config.get_item("model", "wd"),
        model_name,
-        BETA=namespace.BETA,
-        EPOCHS=namespace.n_epochs,
-        path_to_model=namespace.path_to_models,
-        save_all_checkpoints=namespace.save_all_checkpoints,
-        save_final_checkpoint=namespace.save_final_checkpoint,
-        overwrite_final_checkpoint=namespace.overwrite_final_checkpoint,
-        plot=namespace.plot,
-        savefig=namespace.savefig,
-        verbose=namespace.verbose,
        BETA=config.get_item("model", "BETA"),
        EPOCHS=config.get_item("model", "n_epochs"),
        path_to_model=config.get_item("model", "path_to_models"),
        save_all_checkpoints=config.get_item("model", "save_all_checkpoints"),
        save_final_checkpoint=config.get_item("model", "save_final_checkpoint"),
        overwrite_final_checkpoint=config.get_item("model",
                                                   "overwrite_final_checkpoint"),
        plot=config.get_item("model", "plot"),
        savefig=config.get_item("model", "savefig"),
        verbose=config.get_item("model", "verbose"),
    )
11 changes: 8 additions & 3 deletions src/utils/defaults.py
@@ -1,8 +1,7 @@
Defaults = {
    "common": {
-        "out_dir":"./DeepDiagnosticsResources/results/",
-        "temp_config": "./DeepDiagnosticsResources/temp/temp_config.yml",
-        "sim_location": "DeepDiagnosticsResources_Simulators"
        "out_dir": "./DeepUQResources/results/",
        "temp_config": "./DeepUQResources/temp/temp_config.yml",
    },
    "data": {
        "data_path": "./data/",
@@ -19,7 +18,13 @@
"model_path": "./models/",
# the engines are the classes, defined
"model_engine": "DE",
"model_type": "DE",
"loss_type": "bnll",
"n_models": 100,
"init_lr": 0.001,
"wd": "./",
"BETA": 0.5,
"n_epochs": 100,
},
"plots_common": {
"axis_spines": False,
