Skip to content

Commit

Permalink
Merge pull request #96 from deepskies/issue/al_and_ep
Browse files Browse the repository at this point in the history
Branch that calculates total uncertainty
  • Loading branch information
beckynevin authored May 15, 2024
2 parents 82a0262 + 1f9e3da commit d4ed220
Show file tree
Hide file tree
Showing 8 changed files with 473 additions and 169 deletions.
12 changes: 7 additions & 5 deletions src/analyze/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@ def load_checkpoint(
model_name,
noise,
epoch,
beta,
device,
loss="SDER",
nmodel=None,
path="models/",
BETA=0.5,
nmodel=None,
COEFF=0.5,
loss="SDER",
):
"""
Load PyTorch model checkpoint from a .pt file.
Expand All @@ -32,13 +33,14 @@ def load_checkpoint(
if model_name[0:3] == "DER":
file_name = (
str(path)
+ f"{model_name}_noise_{noise}_loss_{loss}_epoch_{epoch}.pt"
+ f"{model_name}_noise_{noise}_loss_{loss}"
+ f"_COEFF_{COEFF}_epoch_{epoch}.pt"
)
checkpoint = torch.load(file_name, map_location=device)
elif model_name[0:2] == "DE":
file_name = (
str(path)
+ f"{model_name}_noise_{noise}_beta_{beta}_"
+ f"{model_name}_noise_{noise}_beta_{BETA}_"
f"nmodel_{nmodel}_epoch_{epoch}.pt"
)
checkpoint = torch.load(file_name, map_location=device)
Expand Down
5 changes: 3 additions & 2 deletions src/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,15 +211,16 @@ def loss_der(y, y_pred, coeff):
)


# simplified DER loss (from Meinert)
def loss_sder(y, y_pred, coeff):
gamma, nu, alpha, beta = y[:, 0], y[:, 1], y[:, 2], y[:, 3]
error = gamma - y_pred
var = beta / nu

# define aleatoric and epistemic uncert
u_al = np.sqrt(
beta.detach().numpy()
* (1 + nu.detach().numpy())
(beta.detach().numpy()
* (1 + nu.detach().numpy()))
/ (alpha.detach().numpy() * nu.detach().numpy())
)
u_ep = 1 / np.sqrt(nu.detach().numpy())
Expand Down
53 changes: 36 additions & 17 deletions src/scripts/Aleatoric.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,20 @@ def parse_args():
there are string options: linear_decrease, \
step_decrease_to_0.5, and step_decrease_to_1.0",
)
parser.add_argument(
"--COEFF",
type=float,
required=False,
default=DefaultsAnalysis["model"]["COEFF"],
help="COEFF for DER",
)
parser.add_argument(
"--loss_type",
type=str,
required=False,
default=DefaultsAnalysis["model"]["loss_type"],
help="loss_type for DER, either SDER or DER",
)
parser.add_argument(
"--noise_level_list",
type=list,
Expand Down Expand Up @@ -108,7 +122,9 @@ def parse_args():
"common": {"dir": args.dir},
"model": {"n_models": args.n_models,
"n_epochs": args.n_epochs,
"BETA": args.BETA},
"BETA": args.BETA,
"COEFF": args.COEFF,
"loss_type": args.loss_type},
"analysis": {
"noise_level_list": args.noise_level_list,
"model_names_list": args.model_names_list,
Expand Down Expand Up @@ -149,6 +165,8 @@ def beta_type(value):
noise_list = config.get_item("analysis", "noise_level_list", "Analysis")
color_list = config.get_item("plots", "color_list", "Analysis")
BETA = config.get_item("model", "BETA", "Analysis")
COEFF = config.get_item("model", "COEFF", "Analysis")
loss_type = config.get_item("model", "loss_type", "Analysis")
sigma_list = []
for noise in noise_list:
sigma_list.append(DataPreparation.get_sigma(noise))
Expand Down Expand Up @@ -197,9 +215,10 @@ def beta_type(value):
model,
noise,
epoch,
BETA,
DEVICE,
path=path_to_chk,
COEFF=COEFF,
loss=loss_type
)
# path=path_to_chk)
# things to grab: 'valid_mse' and 'valid_bnll'
Expand All @@ -221,10 +240,10 @@ def beta_type(value):
model,
noise,
epoch,
BETA,
DEVICE,
nmodel=nmodels,
path=path_to_chk,
BETA=BETA,
nmodel=nmodels,
)
mu_vals, sig_vals = chk_module.ep_al_checkpoint_DE(chk)
list_mus.append(mu_vals)
Expand Down Expand Up @@ -255,32 +274,32 @@ def beta_type(value):
al - al_std,
al + al_std,
color=color_list[i],
alpha=0.5,
alpha=0.25,
edgecolor=None
)
ax.scatter(
ax.plot(
range(n_epochs),
np.sqrt(al_dict[model][noise]),
color=color_list[i],
edgecolors="black",
label=r"$\sigma = " + str(sigma_list[i]),
label=r"$\sigma = $" + str(sigma_list[i]),
)
ax.axhline(y=sigma_list[i], color=color_list[i])
ax.axhline(y=sigma_list[i], color=color_list[i], ls='--')
ax.set_ylabel("Aleatoric Uncertainty")
ax.set_xlabel("Epoch")
if model[0:3] == "DER":
ax.set_title("Deep Evidential Regression")
elif model[0:2] == "DE":
ax.set_title("Deep Ensemble (100 models)")
ax.set_ylim([-1, 15])
ax.set_ylim([0, 14])
plt.legend()
if config.get_item("analysis", "savefig", "Analysis"):
plt.savefig(
str(path_to_out)
+ "aleatoric_uncertainty_n_epochs_"
+ str(n_epochs)
+ "_n_models_DE_"
+ str(n_models)
+ ".png"
)
str(path_to_out)
+ "aleatoric_uncertainty_n_epochs_"
+ str(n_epochs)
+ "_n_models_DE_"
+ str(n_models)
+ ".png"
)
if config.get_item("analysis", "plot", "Analysis"):
plt.show()
Loading

0 comments on commit d4ed220

Please sign in to comment.