Skip to content

Commit

Permalink
Merge pull request #105 from deepskies/issue/DER_loss
Browse files Browse the repository at this point in the history
modified to be DER_wst with a scaling by this term
  • Loading branch information
beckynevin authored May 17, 2024
2 parents 52a8436 + ec3c92c commit dae46d4
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions src/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def loss_der(y, y_pred, coeff):
gamma, nu, alpha, beta = y[:, 0], y[:, 1], y[:, 2], y[:, 3]
error = gamma - y_pred
omega = 2.0 * beta * (1.0 + nu)

w_st = torch.sqrt(beta * (1 + nu) / (alpha * nu))
# define aleatoric and epistemic uncert
u_al = np.sqrt(
beta.detach().numpy()
Expand All @@ -204,7 +204,7 @@ def loss_der(y, y_pred, coeff):
+ (alpha + 0.5) * torch.log(error**2 * nu + omega)
+ torch.lgamma(alpha)
- torch.lgamma(alpha + 0.5)
+ coeff * torch.abs(error) * (2.0 * nu + alpha)
+ (coeff * torch.abs(error / w_st) * (2.0 * nu + alpha))
),
u_al,
u_ep,
Expand Down
2 changes: 1 addition & 1 deletion src/scripts/Aleatoric.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ def beta_type(value):
ax.set_title("Deep Evidential Regression")
elif model[0:2] == "DE":
ax.set_title("Deep Ensemble (100 models)")
ax.set_ylim([0, 14])
ax.set_ylim([0, 11])
plt.legend()
if config.get_item("analysis", "savefig", "Analysis"):
plt.savefig(
Expand Down
2 changes: 1 addition & 1 deletion src/utils/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@
},
"analysis": {
"noise_level_list": ["low", "medium", "high"],
"model_names_list": ["DER", "DE_desiderata_2"],
"model_names_list": ["DER_wst", "DE_desiderata_2"],
# ["DER_desiderata_2", "DE_desiderata_2"]
"plot": True,
"savefig": False,
Expand Down

0 comments on commit dae46d4

Please sign in to comment.