Skip to content

Commit

Permalink
fix ucb lcb sign error
Browse files Browse the repository at this point in the history
  • Loading branch information
rileyhickman committed Aug 2, 2023
1 parent be04ae8 commit 27f3e03
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions src/atlas/acquisition_functions/acqfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ def __init__(self, reg_model, cla_model, cla_likelihood=None, **acqf_args):
def evaluate(self, X: torch.Tensor):
posterior = self.reg_model.posterior(X=X)
mean, sigma = self.compute_mean_sigma(posterior)
acqf_val = mean - self.beta.sqrt() * sigma
acqf_val = -mean - self.beta.sqrt() * sigma
return acqf_val


Expand All @@ -376,7 +376,7 @@ def __init__(self, reg_model, cla_model, cla_likelihood, **acqf_args):
def evaluate(self, X: torch.Tensor):
posterior = self.reg_model.posterior(X=X)
mean, sigma = self.compute_mean_sigma(posterior)
acqf_val = mean + self.beta.sqrt() * sigma
acqf_val = -mean + self.beta.sqrt() * sigma
return acqf_val


Expand Down
2 changes: 1 addition & 1 deletion src/atlas/base/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,7 +616,7 @@ def cla_surrogate(
sample_x.append(float(element))
X_proc.append(sample_x)

X_proc = torch.tensor(np.array(X_proc, **tkwargs))
X_proc = torch.tensor(np.array(X_proc), **tkwargs)

if (
self.problem_type == "fully_categorical"
Expand Down

0 comments on commit 27f3e03

Please sign in to comment.