
[Question] GP classification / discrete outputs #640

Closed
DavidWalz opened this issue Dec 22, 2020 · 10 comments
@DavidWalz commented Dec 22, 2020

I'm implementing a BO loop with feasibility constraints along the lines of https://botorch.org/tutorials/constrained_multi_objective_bo
However, in my case evaluations of the feasibility constraint are discrete (0, 1), for which a GP model with a Bernoulli likelihood seems to be a suitable approach.

import gpytorch
import torch

class GPClassificationModel(gpytorch.models.ApproximateGP):
    def __init__(self, train_x):
        variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(train_x.size(0))
        variational_strategy = gpytorch.variational.VariationalStrategy(
            self, train_x, variational_distribution, learn_inducing_locations=True
        )
        super(GPClassificationModel, self).__init__(variational_strategy)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

model = GPClassificationModel(train_x)
likelihood = gpytorch.likelihoods.BernoulliLikelihood()
mll = gpytorch.mlls.VariationalELBO(likelihood, model, len(train_y), combine_terms=False)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

model.train()
likelihood.train()

for i in range(400):
    optimizer.zero_grad()
    output = model(train_x)
    log_lik, kl_div, log_prior = mll(output, train_y)
    loss = -(log_lik - kl_div + log_prior)
    loss.backward()
    optimizer.step()
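
For reference, after training, predictions can be obtained in eval mode as in the standard GPyTorch classification setup (a minimal sketch; test_x is a hypothetical tensor of query points):

model.eval()
likelihood.eval()
with torch.no_grad():
    latent_pred = model(test_x)     # latent GP posterior at the query points
    pred = likelihood(latent_pred)  # Bernoulli predictive distribution
    probs = pred.probs              # predicted probability of feasibility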

Now I'm wondering how to feed this model together with a SingleTaskGP to the acquisition function.
Do I have to base my GPClassificationModel on ApproximateGP or can I simply combine it in ModelListGPyTorchModel?

@Balandat (Contributor)

Good question. The SingleTaskGP is what you fit on your (continuous) objective? I know @mshvartsman has used a probit setup like this before, though I believe not in this constrained setting.

Sticking your probit model into a ModelListGPyTorchModel together with a SingleTaskGP should work (no guarantees here, may need some debugging, happy to help with that), since both just return MVNs. The thing you'll need to do for optimization is figure out how to use this in the objective. One way would be to specify a probability-weighted GenericMCObjective of the form

def bco(samples):
    obj = samples[..., 0]
    con = samples[..., 1]
    return obj * likelihood.log_marginal(con).exp()

obj = GenericMCObjective(bco)

where samples are the samples produced from the posterior of the ModelListGP model and likelihood is the Bernoulli likelihood from the model. You should be able to use this in any MCAcquisitionFunction.

Let me know how this works out, this seems like an interesting use case, happy to help out with making this work.
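
For example, this could be plugged into qExpectedImprovement roughly as below (a sketch, not tested; model_list, bounds, and best_observed are hypothetical placeholders for the combined ModelListGP, the design bounds, and the incumbent value of the objective):

from botorch.acquisition.monte_carlo import qExpectedImprovement
from botorch.optim import optimize_acqf

# plug the probability-weighted objective into a Monte Carlo acquisition function
acqf = qExpectedImprovement(model=model_list, objective=obj, best_f=best_observed)
candidate, acq_value = optimize_acqf(
    acqf, bounds=bounds, q=1, num_restarts=10, raw_samples=256
)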

@DavidWalz (Author)

Thanks! Yes, the continuous objectives are modeled by a SingleTaskGP.
This is as far as I've gotten with a toy example: https://gist.github.com/DavidWalz/fc1d1fa2d68bf1fa20f0d7639581a21a

For the constrained ParEGO acquisition I should be able to simply use

ConstrainedMCObjective(
    objective=lambda Z: scalarization(Z[..., :M]),
    constraints=[lambda Z: constr_likelihood(Z[..., -1])],
)

For combining the two models I tried

models = ModelListGP(obj_model, constr_model)

     31         for m in models:
     32             if not hasattr(m, "likelihood"):
---> 33                 raise ValueError(
     34                     "IndependentModelList currently only supports models that have a likelihood (e.g. ExactGPs)"
     35                 )

The constructor works when I add the Bernoulli likelihood to the model for the constraint.
However, forward only returns the MVN of the SingleTaskGP, and posterior raises an error:

constr_model.likelihood = constr_likelihood

models = botorch.models.ModelListGP(obj_model, constr_model)
models(X)
>> [MultivariateNormal(loc: torch.Size([2, 10000]))]

models.posterior(X)
>> ...\site-packages\botorch\models\gpytorch.py in posterior(self, X, output_indices, observation_noise, **kwargs)
    549             try:
    550                 oct = self.models[i].outcome_transform
--> 551                 tf_mvn = oct.untransform_posterior(GPyTorchPosterior(mvn)).mvn
    552             except AttributeError:
    553                 tf_mvn = mvn

...\site-packages\botorch\models\transforms\outcome.py in untransform_posterior(self, posterior)
    258             )
    259         if not self._m == posterior.event_shape[-1]:
--> 260             raise RuntimeError(
    261                 "Incompatible output dimensions encountered for transform "
    262                 f"{self._m} and posterior {posterior.event_shape[-1]}"

RuntimeError: Incompatible output dimensions encountered for transform 2 and posterior 1

If I need to implement a custom class based off ModelListGPyTorchModel, how would I go about that?

@Balandat (Contributor)

So the issue is the OutcomeTransform here. Essentially the new posterior has two outcome dimensions, but the Standardize outcome transform is set up to only operate on the data for the objective. Let me think about what one could do here that would be reasonably elegant.

In the meantime, one way to circumvent this would be to avoid using the outcome transform and just pass in standardized data - that way you wouldn't hit that issue. The downside is that your objective is then on a standardized scale, so the probability-weighted objective doesn't really make a lot of sense (btw, the same issue occurs when your objective takes on negative values). To fix that, you could keep the mean/std from your manual standardization and then use them inside the objective to do the rescaling there:

lambda Z: scalarization(unstandardize(Z[..., :M], means, stds)),

where means and stds are the statistics computed on the raw training observations, and unstandardize is something like this:

def unstandardize(values: Tensor, mean: Tensor, std: Tensor) -> Tensor:
    return values * std + mean
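
Putting the pieces together, the idea would be roughly (a sketch, not tested; train_y denotes the raw objective observations, M the number of objectives, and scalarization whatever scalarization is used above):

from botorch.acquisition.objective import GenericMCObjective

# statistics computed once on the raw training observations (n x M)
means = train_y.mean(dim=0)
stds = train_y.std(dim=0)
train_yn = (train_y - means) / stds  # pass the standardized targets to the model

# rescale the samples back to the original outcome scale inside the objective
objective = GenericMCObjective(
    lambda Z: scalarization(unstandardize(Z[..., :M], means, stds))
)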

@mshvartsman (Contributor)

Sorry, I haven't used probit GPs for constraints, so haven't seen this issue.

As a minor point, unless I'm missing something, your literal snippet won't work (there's no optimizer.step() call, so no parameters will be updated). I'm assuming that's just a copy-paste error into GitHub rather than your real code, but just flagging for any visitors from the internet who try to copy-paste this code and run it :). Relatedly, why not call botorch.fit.fit_gpytorch_model(mll) to optimize using scipy's L-BFGS instead of Adam?

@DavidWalz (Author)

@Balandat Thanks, I'll try that.

@mshvartsman Thanks. Indeed the optimizer.step() got lost while pasting here. I updated the initial post to avoid confusion.

Using L-BFGS instead of Adam would be very nice indeed. However, fit_gpytorch_model expects the attributes train_inputs and train_targets, which ExactGP provides but ApproximateGP does not. When trying to add them to my derived class, I got stuck trying to figure out what those two attributes are supposed to contain. If you've done that already, I would appreciate a hint.

@Balandat (Contributor)

You should be able to set those: train_inputs is a tuple of tensor arguments with the training data (typically with just a single element, which is what BoTorch refers to as train_X); train_targets is the tensor of training observations, what BoTorch usually calls train_Y.

Looking through some of Michael's code, I've seen he has successfully done that on a GP model subclassing ApproximateGP; this is in the constructor:

class MichaelsCoolGP(ApproximateGP, GPyTorchModel):

    def __init__(
        self,
        train_x: torch.Tensor,
        train_y: torch.Tensor,
        inducing_points: torch.Tensor,
        ...,
    ):
        # do some stuff ...
        self.train_inputs = (train_x,)
        self.train_targets = train_y
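
With those two attributes set, the ELBO of the approximate model should then work with the standard fitting helper (a sketch, assuming the model exposes a BernoulliLikelihood as model.likelihood):

from botorch.fit import fit_gpytorch_model
from gpytorch.mlls import VariationalELBO

mll = VariationalELBO(model.likelihood, model, num_data=train_y.size(0))
fit_gpytorch_model(mll)  # optimizes with scipy's L-BFGS-B by default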

@DavidWalz (Author)

@mshvartsman Sorry, I had not seen #641.
@Balandat I think I got it working using manually standardized outcome values and a custom objective function that puts the outcomes on the unit scale and multiplies by the Bernoulli probabilities, as proposed.
Here is the full example:

import botorch
import gpytorch
import matplotlib.pyplot as plt
import torch
from botorch.acquisition.monte_carlo import qExpectedImprovement
from botorch.fit import fit_gpytorch_model
from botorch.models import ModelListGP, SingleTaskGP
from botorch.models.gpytorch import GPyTorchModel
from functools import partial
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
from gpytorch.models import ApproximateGP
from gpytorch.variational import CholeskyVariationalDistribution, VariationalStrategy


N = 10  # initial observations
problem = botorch.test_functions.multi_objective.BraninCurrin(noise_std=0)

def feasibility_constraint(X):
    return ((X ** 2).sum(axis=-1) < 0.5).to(dtype=torch.float32)

train_x = botorch.utils.sampling.draw_sobol_samples(
    problem.bounds, n=N, q=1, seed=42
).squeeze()

train_y = problem(train_x)
train_y_mean = train_y.mean(dim=0)
train_y_std = train_y.std(dim=0)
train_yn = (train_y - train_y_mean) / train_y_std

train_f = feasibility_constraint(train_x)


class GPClassificationModel(ApproximateGP, GPyTorchModel):
    def __init__(self, train_x, train_y):
        self.train_inputs = (train_x,)
        self.train_targets = train_y

        variational_distribution = CholeskyVariationalDistribution(train_x.size(0))
        variational_strategy = VariationalStrategy(
            self, train_x, variational_distribution
        )
        super(GPClassificationModel, self).__init__(variational_strategy)

        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())
        self.likelihood = gpytorch.likelihoods.BernoulliLikelihood()

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)


obj1_model = SingleTaskGP(train_x, train_yn[:, [0]])
obj1_mll = ExactMarginalLogLikelihood(obj1_model.likelihood, obj1_model)
fit_gpytorch_model(obj1_mll)

obj2_model = SingleTaskGP(train_x, train_yn[:, [1]])
obj2_mll = ExactMarginalLogLikelihood(obj2_model.likelihood, obj2_model)
fit_gpytorch_model(obj2_mll)

constr_model = GPClassificationModel(train_x, train_f)
constr_mll = gpytorch.mlls.VariationalELBO(constr_model.likelihood, constr_model, N)
fit_gpytorch_model(constr_mll)

models = ModelListGP(obj1_model, obj2_model, constr_model)

# propose next point
train_yn_min = train_yn.min(dim=0).values
train_yn_max = train_yn.max(dim=0).values
# map feasibility labels {0, 1} to extreme pseudo latent values {-100, +100},
# so that constr_model.likelihood(...).probs is ~0 or ~1 when computing best_f below
train_data = torch.cat([train_yn, (train_f.view(-1, 1) * 200 - 100)], dim=1)

def objective_func(samples, weights):
    Yn, F = samples[..., :-1], samples[..., -1]
    # place outcomes on a scale from 0 (worst) to 1 (best)
    Yu = 1 - (Yn - train_yn_min) / (train_yn_max - train_yn_min)
    scalarization = (weights * Yu).min(dim=-1).values
    p = constr_model.likelihood(F).probs
    return p * scalarization

objective = botorch.acquisition.GenericMCObjective(
    partial(objective_func, weights=torch.tensor([1, 1]))
)
acquisition = qExpectedImprovement(
    model=models, objective=objective, best_f=objective(train_data).max()
)
candidate, acq_value = botorch.optim.optimize_acqf(
    acquisition, bounds=problem.bounds, q=1, num_restarts=5, raw_samples=256
)
print(candidate, acq_value)

# plot
n_grid = 51
dx = torch.linspace(0, 1, n_grid)
X1_grid, X2_grid = torch.meshgrid(dx, dx)
X = torch.stack([X1_grid.reshape(-1), X2_grid.reshape(-1)], dim=1)

with torch.no_grad():
    posterior_mean = models.posterior(X).mean.reshape(n_grid, n_grid, 3)

weights = torch.tensor([1, 1])
fig, (ax1, ax2, ax3, ax4) = plt.subplots(
    ncols=4, figsize=(15, 4), sharex=True, sharey=True
)
ax1.contourf(X1_grid, X2_grid, posterior_mean[..., 0], levels=16)
ax2.contourf(X1_grid, X2_grid, posterior_mean[..., 1], levels=16)
ax3.contourf(X1_grid, X2_grid, constr_model.likelihood(posterior_mean[..., 2]).probs)
ax4.contourf(X1_grid, X2_grid, objective_func(posterior_mean, weights))
ax1.scatter(*train_x.T, c="r")
ax2.scatter(*train_x.T, c="r")
ax3.scatter(*train_x.T, c="r")
ax1.set(xlabel="x1", title="output 1 (minimize)", ylabel="x2")
ax2.set(xlabel="x1", title="output 2 (minimize)")
ax3.set(xlabel="x1", title="feasibility constraint (maximize)")
ax4.set(xlabel="x1", title=f"objective, w={weights.numpy()} (maximize)")
fig.suptitle("Model posterior")
fig.tight_layout()
plt.show()

@eytan (Contributor) commented Dec 26, 2020 via email

@Balandat (Contributor)

Glad to see you got it to work.

So the issue is the OutcomeTransform here. Essentially the new posterior has two outcome dimensions, but the Standardize outcome transform is set up to only operate on the data for the objective

I'm actually somewhat unsure about this statement now - the way the outcome transforms are being handled inside the posterior of ModelListGPyTorchModel should be appropriate. I'll have to look into this some more to understand what exactly is going wrong.

@Balandat (Contributor)

Actually, I think it's because SingleTaskGP is a BatchedMultiOutputGPyTorchModel, which does some handling in its posterior call that doesn't happen in that of ModelListGP. Tracking this in a new issue: #650
