Skip to content

Commit

Permalink
Remove HeteroskedasticSingleTaskGP (#2616)
Browse files Browse the repository at this point in the history
Summary:

This model has been buggy for quite a long time and we still haven't fixed it. Removing it should be preferable to keeping around a known buggy model. Example bug reports:
- #861
- #933
- #2551

Reviewed By: esantorella

Differential Revision: D65543676
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Nov 7, 2024
1 parent cd657d9 commit 232aae1
Show file tree
Hide file tree
Showing 8 changed files with 26 additions and 249 deletions.
4 changes: 1 addition & 3 deletions botorch/acquisition/joint_entropy_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,10 @@
from botorch import settings
from botorch.acquisition.acquisition import AcquisitionFunction, MCSamplerMixin
from botorch.acquisition.objective import PosteriorTransform

from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
from botorch.models.gp_regression import MIN_INFERRED_NOISE_LEVEL
from botorch.models.model import Model

from botorch.models.utils import check_no_nans, fantasize as fantasize_flag
from botorch.models.utils.gpytorch_modules import MIN_INFERRED_NOISE_LEVEL
from botorch.sampling.normal import SobolQMCNormalSampler
from botorch.utils.transforms import concatenate_pending_points, t_batch_mode_transform
from torch import Tensor
Expand Down
3 changes: 1 addition & 2 deletions botorch/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP

from botorch.models.gp_regression import HeteroskedasticSingleTaskGP, SingleTaskGP
from botorch.models.gp_regression import SingleTaskGP
from botorch.models.gp_regression_fidelity import SingleTaskMultiFidelityGP
from botorch.models.gp_regression_mixed import MixedSingleTaskGP
from botorch.models.higher_order_gp import HigherOrderGP
Expand All @@ -33,7 +33,6 @@
"SaasFullyBayesianSingleTaskGP",
"SaasFullyBayesianMultiTaskGP",
"GenericDeterministicModel",
"HeteroskedasticSingleTaskGP",
"HigherOrderGP",
"KroneckerMultiTaskGP",
"MixedSingleTaskGP",
Expand Down
19 changes: 1 addition & 18 deletions botorch/models/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from botorch.exceptions import UnsupportedError
from botorch.exceptions.warnings import BotorchWarning
from botorch.models import SingleTaskGP
from botorch.models.gp_regression import HeteroskedasticSingleTaskGP
from botorch.models.gp_regression_fidelity import SingleTaskMultiFidelityGP
from botorch.models.gp_regression_mixed import MixedSingleTaskGP
from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel
Expand Down Expand Up @@ -84,12 +83,6 @@ def _check_compatibility(models: ModuleList) -> None:
"All models must be of type BatchedMultiOutputGPyTorchModel."
)

# TODO: Add support for HeteroskedasticSingleTaskGP.
if any(isinstance(m, HeteroskedasticSingleTaskGP) for m in models):
raise NotImplementedError(
"Conversion of HeteroskedasticSingleTaskGP is currently unsupported."
)

# TODO: Add support for custom likelihoods.
if any(getattr(m, "_is_custom_likelihood", False) for m in models):
raise NotImplementedError(
Expand Down Expand Up @@ -289,11 +282,6 @@ def batched_to_model_list(batch_model: BatchedMultiOutputGPyTorchModel) -> Model
warnings.warn(DEPRECATION_MESSAGE, DeprecationWarning, stacklevel=2)
was_training = batch_model.training
batch_model.train()
# TODO: Add support for HeteroskedasticSingleTaskGP.
if isinstance(batch_model, HeteroskedasticSingleTaskGP):
raise NotImplementedError(
"Conversion of HeteroskedasticSingleTaskGP is currently not supported."
)
if isinstance(batch_model, MixedSingleTaskGP):
raise NotImplementedError(
"Conversion of MixedSingleTaskGP is currently not supported."
Expand Down Expand Up @@ -393,12 +381,7 @@ def batched_multi_output_to_single_output(
warnings.warn(DEPRECATION_MESSAGE, DeprecationWarning, stacklevel=2)
was_training = batch_mo_model.training
batch_mo_model.train()
# TODO: Add support for HeteroskedasticSingleTaskGP.
if isinstance(batch_mo_model, HeteroskedasticSingleTaskGP):
raise NotImplementedError(
"Conversion of HeteroskedasticSingleTaskGP currently not supported."
)
elif not isinstance(batch_mo_model, BatchedMultiOutputGPyTorchModel):
if not isinstance(batch_mo_model, BatchedMultiOutputGPyTorchModel):
raise UnsupportedError("Only BatchedMultiOutputGPyTorchModels are supported.")
# TODO: Add support for custom likelihoods.
elif getattr(batch_mo_model, "_is_custom_likelihood", False):
Expand Down
145 changes: 17 additions & 128 deletions botorch/models/gp_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,58 +10,47 @@
These models are often a good starting point and are further documented in the
tutorials.
`SingleTaskGP` and `HeteroskedasticSingleTaskGP` are single-task exact GP models,
differing in how they treat noise. They use relatively strong priors on the Kernel
hyperparameters, which work best when covariates are normalized to the unit cube
and outcomes are standardized (zero mean, unit variance). By default, these models
use a `Standardize` outcome transform, which applies this standardization. However,
they do not (yet) use an input transform by default.
These models all work in batch mode (each batch having its own hyperparameters).
When the training observations include multiple outputs, these models use
`SingleTaskGP` is a single-task exact GP model that uses relatively strong priors on
the Kernel hyperparameters, which work best when covariates are normalized to the unit
cube and outcomes are standardized (zero mean, unit variance). By default, this model
uses a `Standardize` outcome transform, which applies this standardization. However,
it does not (yet) use an input transform by default.
The `SingleTaskGP` model works in batch mode (each batch having its own hyperparameters).
When the training observations include multiple outputs, `SingleTaskGP` uses
batching to model outputs independently.
These models all support multiple outputs. However, as single-task models,
`SingleTaskGP` and `HeteroskedasticSingleTaskGP` should be used only when the
outputs are independent and all use the same training data. If outputs are
independent and outputs have different training data, use the `ModelListGP`.
When modeling correlations between outputs, use a multi-task model like `MultiTaskGP`.
`SingleTaskGP` supports multiple outputs. However, as a single-task model,
`SingleTaskGP` should be used only when the outputs are independent and all
use the same training inputs. If outputs are independent but they have different
training inputs, use the `ModelListGP`. When modeling correlations between outputs,
use a multi-task model like `MultiTaskGP`.
"""

from __future__ import annotations

import warnings
from typing import NoReturn

import torch
from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel
from botorch.models.model import FantasizeMixin
from botorch.models.transforms.input import InputTransform
from botorch.models.transforms.outcome import Log, OutcomeTransform, Standardize
from botorch.models.transforms.outcome import OutcomeTransform, Standardize
from botorch.models.utils import validate_input_scaling
from botorch.models.utils.gpytorch_modules import (
get_covar_module_with_dim_scaled_prior,
get_gaussian_likelihood_with_lognormal_prior,
MIN_INFERRED_NOISE_LEVEL,
)
from botorch.utils.containers import BotorchContainer
from botorch.utils.datasets import SupervisedDataset
from botorch.utils.types import _DefaultType, DEFAULT
from gpytorch.constraints.constraints import GreaterThan
from gpytorch.distributions.multivariate_normal import MultivariateNormal
from gpytorch.likelihoods.gaussian_likelihood import (
_GaussianLikelihoodBase,
FixedNoiseGaussianLikelihood,
GaussianLikelihood,
)
from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood
from gpytorch.likelihoods.likelihood import Likelihood
from gpytorch.likelihoods.noise_models import HeteroskedasticNoise
from gpytorch.means.constant_mean import ConstantMean
from gpytorch.means.mean import Mean
from gpytorch.mlls.noise_model_added_loss_term import NoiseModelAddedLossTerm
from gpytorch.models.exact_gp import ExactGP
from gpytorch.module import Module
from gpytorch.priors.smoothed_box_prior import SmoothedBoxPrior
from torch import Tensor


Expand Down Expand Up @@ -255,105 +244,5 @@ def forward(self, x: Tensor) -> MultivariateNormal:
return MultivariateNormal(mean_x, covar_x)


class HeteroskedasticSingleTaskGP(BatchedMultiOutputGPyTorchModel, ExactGP):
r"""A single-task exact GP model using a heteroskedastic noise model.
This model differs from `SingleTaskGP` with observed observation noise
variances (`train_Yvar`) in that it can predict noise levels out of sample.
This is achieved by internally wrapping another GP (a `SingleTaskGP`) to model
(the log of) the observation noise. Noise levels must be provided to
`HeteroskedasticSingleTaskGP` as `train_Yvar`.
Examples of cases in which noise levels are known include online
experimentation and simulation optimization.
Example:
>>> train_X = torch.rand(20, 2)
>>> train_Y = torch.sin(train_X).sum(dim=1, keepdim=True)
>>> se = torch.linalg.norm(train_X, dim=1, keepdim=True)
>>> train_Yvar = 0.1 + se * torch.rand_like(train_Y)
>>> model = HeteroskedasticSingleTaskGP(train_X, train_Y, train_Yvar)
"""

def __init__(
self,
train_X: Tensor,
train_Y: Tensor,
train_Yvar: Tensor,
outcome_transform: OutcomeTransform | None = None,
input_transform: InputTransform | None = None,
) -> None:
r"""
Args:
train_X: A `batch_shape x n x d` tensor of training features.
train_Y: A `batch_shape x n x m` tensor of training observations.
train_Yvar: A `batch_shape x n x m` tensor of observed measurement
noise.
outcome_transform: An outcome transform that is applied to the
training data during instantiation and to the posterior during
inference (that is, the `Posterior` obtained by calling
`.posterior` on the model will be on the original scale).
Note that the noise model internally log-transforms the
variances, which will happen after this transform is applied.
input_transform: An input transform that is applied in the model's
forward pass.
"""
if outcome_transform is not None:
train_Y, train_Yvar = outcome_transform(train_Y, train_Yvar)
self._validate_tensor_args(X=train_X, Y=train_Y, Yvar=train_Yvar)
validate_input_scaling(train_X=train_X, train_Y=train_Y, train_Yvar=train_Yvar)
self._set_dimensions(train_X=train_X, train_Y=train_Y)
noise_likelihood = GaussianLikelihood(
noise_prior=SmoothedBoxPrior(-3, 5, 0.5, transform=torch.log),
batch_shape=self._aug_batch_shape,
noise_constraint=GreaterThan(
MIN_INFERRED_NOISE_LEVEL, transform=None, initial_value=1.0
),
)
# Likelihood will always get evaluated with transformed X, so we need to
# transform the training data before constructing the noise model.
with torch.no_grad():
transformed_X = self.transform_inputs(
X=train_X, input_transform=input_transform
)
noise_model = SingleTaskGP(
train_X=transformed_X,
train_Y=train_Yvar,
likelihood=noise_likelihood,
outcome_transform=Log(),
)
likelihood = _GaussianLikelihoodBase(HeteroskedasticNoise(noise_model))
# This is hacky -- this class used to inherit from SingleTaskGP, but it
# shouldn't so this is a quick fix to enable getting rid of that
# inheritance
SingleTaskGP.__init__(
# pyre-fixme[6]: Incompatible parameter type
self,
train_X=train_X,
train_Y=train_Y,
likelihood=likelihood,
outcome_transform=None,
input_transform=input_transform,
)
self.register_added_loss_term("noise_added_loss")
self.update_added_loss_term(
"noise_added_loss", NoiseModelAddedLossTerm(noise_model)
)
if outcome_transform is not None:
self.outcome_transform = outcome_transform
self.to(train_X)

# pyre-fixme[15]: Inconsistent override
def condition_on_observations(self, *_, **__) -> NoReturn:
raise NotImplementedError

# pyre-fixme[15]: Inconsistent override
def subset_output(self, idcs) -> NoReturn:
raise NotImplementedError

def forward(self, x: Tensor) -> MultivariateNormal:
if self.training:
x = self.transform_inputs(x)
mean_x = self.mean_module(x)
covar_x = self.covar_module(x)
return MultivariateNormal(mean_x, covar_x)
# Note: There used to be `HeteroskedasticSingleTaskGP` here,
# but due to persistent bugs, it was removed in #2616.
2 changes: 1 addition & 1 deletion botorch_community/acquisition/scorebo.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@
)
from botorch.acquisition.objective import ScalarizedPosteriorTransform
from botorch.models.fully_bayesian import MCMC_DIM, SaasFullyBayesianSingleTaskGP
from botorch.models.gp_regression import MIN_INFERRED_NOISE_LEVEL
from botorch.models.utils import fantasize as fantasize_flag
from botorch.models.utils.gpytorch_modules import MIN_INFERRED_NOISE_LEVEL
from botorch.utils.transforms import concatenate_pending_points, t_batch_mode_transform
from botorch_community.acquisition.bayesian_active_learning import DISTANCE_METRICS
from torch import Tensor
Expand Down
8 changes: 2 additions & 6 deletions docs/models.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ Noise can be treated in several different ways:
if you know your observations are noiseless (by passing a zero noise level).

- _Heteroskedastic_: Noise is provided as an input and is modeled to allow for
predicting noise out-of-sample. Models like `HeteroskedasticSingleTaskGP` take
this approach.
predicting noise out-of-sample. BoTorch does not implement a model that
supports this out of the box.

## Standard BoTorch Models

Expand All @@ -90,10 +90,6 @@ instead.
- [`SingleTaskGP`](../api/models.html#botorch.models.gp_regression.SingleTaskGP):
a single-task exact GP that supports both inferred and observed noise. When
noise observations are not provided, it infers a homoskedastic noise level.
- [`HeteroskedasticSingleTaskGP`](../api/models.html#botorch.models.gp_regression.HeteroskedasticSingleTaskGP):
a single-task exact GP that differs from `SingleTaskGP` with observed noise in
that it models heteroskedastic noise using an additional internal GP model. It
requires noise observations.
- [`MixedSingleTaskGP`](../api/models.html#botorch.models.gp_regression_mixed.MixedSingleTaskGP):
a single-task exact GP that supports mixed search spaces, which combine
discrete and continuous features.
Expand Down
24 changes: 1 addition & 23 deletions test/models/test_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,7 @@

import torch
from botorch.exceptions import UnsupportedError
from botorch.models import (
HeteroskedasticSingleTaskGP,
ModelListGP,
SingleTaskGP,
SingleTaskMultiFidelityGP,
)
from botorch.models import ModelListGP, SingleTaskGP, SingleTaskMultiFidelityGP
from botorch.models.converter import (
_batched_kernel,
batched_multi_output_to_single_output,
Expand Down Expand Up @@ -58,12 +53,6 @@ def test_batched_to_model_list(self):
)
list_gp = batched_to_model_list(batch_gp)
self.assertIsInstance(list_gp, ModelListGP)
# test HeteroskedasticSingleTaskGP
batch_gp = HeteroskedasticSingleTaskGP(
train_X, train_Y, torch.rand_like(train_Y)
)
with self.assertRaises(NotImplementedError):
batched_to_model_list(batch_gp)
# test with transforms
input_tf = Normalize(
d=2,
Expand Down Expand Up @@ -161,12 +150,6 @@ def test_model_list_to_batched(self):
)
with self.assertRaises(UnsupportedError):
model_list_to_batched(ModelListGP(gp1, gp2))
# test HeteroskedasticSingleTaskGP
gp2 = HeteroskedasticSingleTaskGP(
train_X, train_Y1, torch.ones_like(train_Y1)
)
with self.assertRaises(NotImplementedError):
model_list_to_batched(ModelListGP(gp2))
# test custom likelihood
gp2 = SingleTaskGP(
train_X,
Expand Down Expand Up @@ -419,11 +402,6 @@ def test_batched_multi_output_to_single_output(self):
non_batch_model = SimpleGPyTorchModel(train_X, train_Y[:, :1])
with self.assertRaises(UnsupportedError):
batched_multi_output_to_single_output(non_batch_model)
gp2 = HeteroskedasticSingleTaskGP(
train_X, train_Y, torch.ones_like(train_Y)
)
with self.assertRaises(NotImplementedError):
batched_multi_output_to_single_output(gp2)
# test custom likelihood
gp2 = SingleTaskGP(train_X, train_Y, likelihood=GaussianLikelihood())
with self.assertRaises(NotImplementedError):
Expand Down
Loading

0 comments on commit 232aae1

Please sign in to comment.