From 232aae1e620e0ea23f2b0b1a0b950bc21a104de7 Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Thu, 7 Nov 2024 08:39:14 -0800 Subject: [PATCH] Remove HeteroskedasticSingleTaskGP (#2616) Summary: This model has been buggy for quite a long time and we still haven't fixed it. Removing it should be preferable to keeping around a known buggy model. Example bug reports: - https://github.com/pytorch/botorch/issues/861 - https://github.com/pytorch/botorch/issues/933 - https://github.com/pytorch/botorch/issues/2551 Reviewed By: esantorella Differential Revision: D65543676 --- botorch/acquisition/joint_entropy_search.py | 4 +- botorch/models/__init__.py | 3 +- botorch/models/converter.py | 19 +-- botorch/models/gp_regression.py | 145 +++----------------- botorch_community/acquisition/scorebo.py | 2 +- docs/models.md | 8 +- test/models/test_converter.py | 24 +--- test/models/test_gp_regression.py | 70 +--------- 8 files changed, 26 insertions(+), 249 deletions(-) diff --git a/botorch/acquisition/joint_entropy_search.py b/botorch/acquisition/joint_entropy_search.py index eed1828a1c..5db9d32019 100644 --- a/botorch/acquisition/joint_entropy_search.py +++ b/botorch/acquisition/joint_entropy_search.py @@ -29,12 +29,10 @@ from botorch import settings from botorch.acquisition.acquisition import AcquisitionFunction, MCSamplerMixin from botorch.acquisition.objective import PosteriorTransform - from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP -from botorch.models.gp_regression import MIN_INFERRED_NOISE_LEVEL from botorch.models.model import Model - from botorch.models.utils import check_no_nans, fantasize as fantasize_flag +from botorch.models.utils.gpytorch_modules import MIN_INFERRED_NOISE_LEVEL from botorch.sampling.normal import SobolQMCNormalSampler from botorch.utils.transforms import concatenate_pending_points, t_batch_mode_transform from torch import Tensor diff --git a/botorch/models/__init__.py b/botorch/models/__init__.py index f0c24a27ee..031ce83299 
100644 --- a/botorch/models/__init__.py +++ b/botorch/models/__init__.py @@ -17,7 +17,7 @@ from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP -from botorch.models.gp_regression import HeteroskedasticSingleTaskGP, SingleTaskGP +from botorch.models.gp_regression import SingleTaskGP from botorch.models.gp_regression_fidelity import SingleTaskMultiFidelityGP from botorch.models.gp_regression_mixed import MixedSingleTaskGP from botorch.models.higher_order_gp import HigherOrderGP @@ -33,7 +33,6 @@ "SaasFullyBayesianSingleTaskGP", "SaasFullyBayesianMultiTaskGP", "GenericDeterministicModel", - "HeteroskedasticSingleTaskGP", "HigherOrderGP", "KroneckerMultiTaskGP", "MixedSingleTaskGP", diff --git a/botorch/models/converter.py b/botorch/models/converter.py index 276a40623f..48dcaf84ad 100644 --- a/botorch/models/converter.py +++ b/botorch/models/converter.py @@ -17,7 +17,6 @@ from botorch.exceptions import UnsupportedError from botorch.exceptions.warnings import BotorchWarning from botorch.models import SingleTaskGP -from botorch.models.gp_regression import HeteroskedasticSingleTaskGP from botorch.models.gp_regression_fidelity import SingleTaskMultiFidelityGP from botorch.models.gp_regression_mixed import MixedSingleTaskGP from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel @@ -84,12 +83,6 @@ def _check_compatibility(models: ModuleList) -> None: "All models must be of type BatchedMultiOutputGPyTorchModel." ) - # TODO: Add support for HeteroskedasticSingleTaskGP. - if any(isinstance(m, HeteroskedasticSingleTaskGP) for m in models): - raise NotImplementedError( - "Conversion of HeteroskedasticSingleTaskGP is currently unsupported." - ) - # TODO: Add support for custom likelihoods. 
if any(getattr(m, "_is_custom_likelihood", False) for m in models): raise NotImplementedError( @@ -289,11 +282,6 @@ def batched_to_model_list(batch_model: BatchedMultiOutputGPyTorchModel) -> Model warnings.warn(DEPRECATION_MESSAGE, DeprecationWarning, stacklevel=2) was_training = batch_model.training batch_model.train() - # TODO: Add support for HeteroskedasticSingleTaskGP. - if isinstance(batch_model, HeteroskedasticSingleTaskGP): - raise NotImplementedError( - "Conversion of HeteroskedasticSingleTaskGP is currently not supported." - ) if isinstance(batch_model, MixedSingleTaskGP): raise NotImplementedError( "Conversion of MixedSingleTaskGP is currently not supported." @@ -393,12 +381,7 @@ def batched_multi_output_to_single_output( warnings.warn(DEPRECATION_MESSAGE, DeprecationWarning, stacklevel=2) was_training = batch_mo_model.training batch_mo_model.train() - # TODO: Add support for HeteroskedasticSingleTaskGP. - if isinstance(batch_mo_model, HeteroskedasticSingleTaskGP): - raise NotImplementedError( - "Conversion of HeteroskedasticSingleTaskGP currently not supported." - ) - elif not isinstance(batch_mo_model, BatchedMultiOutputGPyTorchModel): + if not isinstance(batch_mo_model, BatchedMultiOutputGPyTorchModel): raise UnsupportedError("Only BatchedMultiOutputGPyTorchModels are supported.") # TODO: Add support for custom likelihoods. elif getattr(batch_mo_model, "_is_custom_likelihood", False): diff --git a/botorch/models/gp_regression.py b/botorch/models/gp_regression.py index ee380f9e84..1e8ba53c3a 100644 --- a/botorch/models/gp_regression.py +++ b/botorch/models/gp_regression.py @@ -10,58 +10,47 @@ These models are often a good starting point and are further documented in the tutorials. -`SingleTaskGP` and `HeteroskedasticSingleTaskGP` are single-task exact GP models, -differing in how they treat noise. 
They use relatively strong priors on the Kernel -hyperparameters, which work best when covariates are normalized to the unit cube -and outcomes are standardized (zero mean, unit variance). By default, these models -use a `Standardize` outcome transform, which applies this standardization. However, -they do not (yet) use an input transform by default. - -These models all work in batch mode (each batch having its own hyperparameters). -When the training observations include multiple outputs, these models use +`SingleTaskGP` is a single-task exact GP model that uses relatively strong priors on +the Kernel hyperparameters, which work best when covariates are normalized to the unit +cube and outcomes are standardized (zero mean, unit variance). By default, this model +uses a `Standardize` outcome transform, which applies this standardization. However, +it does not (yet) use an input transform by default. + +`SingleTaskGP` works in batch mode (each batch having its own hyperparameters). +When the training observations include multiple outputs, `SingleTaskGP` uses batching to model outputs independently. -These models all support multiple outputs. However, as single-task models, -`SingleTaskGP` and `HeteroskedasticSingleTaskGP` should be used only when the -outputs are independent and all use the same training data. If outputs are -independent and outputs have different training data, use the `ModelListGP`. -When modeling correlations between outputs, use a multi-task model like `MultiTaskGP`. +`SingleTaskGP` supports multiple outputs. However, as a single-task model, +`SingleTaskGP` should be used only when the outputs are independent and all +use the same training inputs. If outputs are independent but they have different +training inputs, use the `ModelListGP`. When modeling correlations between outputs, +use a multi-task model like `MultiTaskGP`. 
""" from __future__ import annotations import warnings -from typing import NoReturn import torch from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel from botorch.models.model import FantasizeMixin from botorch.models.transforms.input import InputTransform -from botorch.models.transforms.outcome import Log, OutcomeTransform, Standardize +from botorch.models.transforms.outcome import OutcomeTransform, Standardize from botorch.models.utils import validate_input_scaling from botorch.models.utils.gpytorch_modules import ( get_covar_module_with_dim_scaled_prior, get_gaussian_likelihood_with_lognormal_prior, - MIN_INFERRED_NOISE_LEVEL, ) from botorch.utils.containers import BotorchContainer from botorch.utils.datasets import SupervisedDataset from botorch.utils.types import _DefaultType, DEFAULT -from gpytorch.constraints.constraints import GreaterThan from gpytorch.distributions.multivariate_normal import MultivariateNormal -from gpytorch.likelihoods.gaussian_likelihood import ( - _GaussianLikelihoodBase, - FixedNoiseGaussianLikelihood, - GaussianLikelihood, -) +from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood from gpytorch.likelihoods.likelihood import Likelihood -from gpytorch.likelihoods.noise_models import HeteroskedasticNoise from gpytorch.means.constant_mean import ConstantMean from gpytorch.means.mean import Mean -from gpytorch.mlls.noise_model_added_loss_term import NoiseModelAddedLossTerm from gpytorch.models.exact_gp import ExactGP from gpytorch.module import Module -from gpytorch.priors.smoothed_box_prior import SmoothedBoxPrior from torch import Tensor @@ -255,105 +244,5 @@ def forward(self, x: Tensor) -> MultivariateNormal: return MultivariateNormal(mean_x, covar_x) -class HeteroskedasticSingleTaskGP(BatchedMultiOutputGPyTorchModel, ExactGP): - r"""A single-task exact GP model using a heteroskedastic noise model. 
- - This model differs from `SingleTaskGP` with observed observation noise - variances (`train_Yvar`) in that it can predict noise levels out of sample. - This is achieved by internally wrapping another GP (a `SingleTaskGP`) to model - the (log of) the observation noise. Noise levels must be provided to - `HeteroskedasticSingleTaskGP` as `train_Yvar`. - - Examples of cases in which noise levels are known include online - experimentation and simulation optimization. - - Example: - >>> train_X = torch.rand(20, 2) - >>> train_Y = torch.sin(train_X).sum(dim=1, keepdim=True) - >>> se = torch.linalg.norm(train_X, dim=1, keepdim=True) - >>> train_Yvar = 0.1 + se * torch.rand_like(train_Y) - >>> model = HeteroskedasticSingleTaskGP(train_X, train_Y, train_Yvar) - """ - - def __init__( - self, - train_X: Tensor, - train_Y: Tensor, - train_Yvar: Tensor, - outcome_transform: OutcomeTransform | None = None, - input_transform: InputTransform | None = None, - ) -> None: - r""" - Args: - train_X: A `batch_shape x n x d` tensor of training features. - train_Y: A `batch_shape x n x m` tensor of training observations. - train_Yvar: A `batch_shape x n x m` tensor of observed measurement - noise. - outcome_transform: An outcome transform that is applied to the - training data during instantiation and to the posterior during - inference (that is, the `Posterior` obtained by calling - `.posterior` on the model will be on the original scale). - Note that the noise model internally log-transforms the - variances, which will happen after this transform is applied. - input_transform: An input transfrom that is applied in the model's - forward pass. 
- """ - if outcome_transform is not None: - train_Y, train_Yvar = outcome_transform(train_Y, train_Yvar) - self._validate_tensor_args(X=train_X, Y=train_Y, Yvar=train_Yvar) - validate_input_scaling(train_X=train_X, train_Y=train_Y, train_Yvar=train_Yvar) - self._set_dimensions(train_X=train_X, train_Y=train_Y) - noise_likelihood = GaussianLikelihood( - noise_prior=SmoothedBoxPrior(-3, 5, 0.5, transform=torch.log), - batch_shape=self._aug_batch_shape, - noise_constraint=GreaterThan( - MIN_INFERRED_NOISE_LEVEL, transform=None, initial_value=1.0 - ), - ) - # Likelihood will always get evaluated with transformed X, so we need to - # transform the training data before constructing the noise model. - with torch.no_grad(): - transformed_X = self.transform_inputs( - X=train_X, input_transform=input_transform - ) - noise_model = SingleTaskGP( - train_X=transformed_X, - train_Y=train_Yvar, - likelihood=noise_likelihood, - outcome_transform=Log(), - ) - likelihood = _GaussianLikelihoodBase(HeteroskedasticNoise(noise_model)) - # This is hacky -- this class used to inherit from SingleTaskGP, but it - # shouldn't so this is a quick fix to enable getting rid of that - # inheritance - SingleTaskGP.__init__( - # pyre-fixme[6]: Incompatible parameter type - self, - train_X=train_X, - train_Y=train_Y, - likelihood=likelihood, - outcome_transform=None, - input_transform=input_transform, - ) - self.register_added_loss_term("noise_added_loss") - self.update_added_loss_term( - "noise_added_loss", NoiseModelAddedLossTerm(noise_model) - ) - if outcome_transform is not None: - self.outcome_transform = outcome_transform - self.to(train_X) - - # pyre-fixme[15]: Inconsistent override - def condition_on_observations(self, *_, **__) -> NoReturn: - raise NotImplementedError - - # pyre-fixme[15]: Inconsistent override - def subset_output(self, idcs) -> NoReturn: - raise NotImplementedError - - def forward(self, x: Tensor) -> MultivariateNormal: - if self.training: - x = self.transform_inputs(x) - 
mean_x = self.mean_module(x) - covar_x = self.covar_module(x) - return MultivariateNormal(mean_x, covar_x) +# Note: There used to be `HeteroskedasticSingleTaskGP` here, +# but due to persistent bugs, it was removed in #2616. diff --git a/botorch_community/acquisition/scorebo.py b/botorch_community/acquisition/scorebo.py index 38453aeb0c..f0943efe3c 100644 --- a/botorch_community/acquisition/scorebo.py +++ b/botorch_community/acquisition/scorebo.py @@ -31,8 +31,8 @@ ) from botorch.acquisition.objective import ScalarizedPosteriorTransform from botorch.models.fully_bayesian import MCMC_DIM, SaasFullyBayesianSingleTaskGP -from botorch.models.gp_regression import MIN_INFERRED_NOISE_LEVEL from botorch.models.utils import fantasize as fantasize_flag +from botorch.models.utils.gpytorch_modules import MIN_INFERRED_NOISE_LEVEL from botorch.utils.transforms import concatenate_pending_points, t_batch_mode_transform from botorch_community.acquisition.bayesian_active_learning import DISTANCE_METRICS from torch import Tensor diff --git a/docs/models.md b/docs/models.md index 79d4067f7c..cec82f5c46 100644 --- a/docs/models.md +++ b/docs/models.md @@ -72,8 +72,8 @@ Noise can be treated in several different ways: if you know your observations are noiseless (by passing a zero noise level). - _Heteroskedastic_: Noise is provided as an input and is modeled to allow for - predicting noise out-of-sample. Models like `HeteroskedasticSingleTaskGP` take - this approach. + predicting noise out-of-sample. BoTorch does not implement a model that + supports this out of the box. ## Standard BoTorch Models @@ -90,10 +90,6 @@ instead. - [`SingleTaskGP`](../api/models.html#botorch.models.gp_regression.SingleTaskGP): a single-task exact GP that supports both inferred and observed noise. When noise observations are not provided, it infers a homoskedastic noise level. 
-- [`HeteroskedasticSingleTaskGP`](../api/models.html#botorch.models.gp_regression.HeteroskedasticSingleTaskGP): - a single-task exact GP that differs from `SingleTaskGP` with observed noise in - that it models heteroskedastic noise using an additional internal GP model. It - requires noise observations. - [`MixedSingleTaskGP`](../api/models.html#botorch.models.gp_regression_mixed.MixedSingleTaskGP): a single-task exact GP that supports mixed search spaces, which combine discrete and continuous features. diff --git a/test/models/test_converter.py b/test/models/test_converter.py index db23e06479..688c90e427 100644 --- a/test/models/test_converter.py +++ b/test/models/test_converter.py @@ -7,12 +7,7 @@ import torch from botorch.exceptions import UnsupportedError -from botorch.models import ( - HeteroskedasticSingleTaskGP, - ModelListGP, - SingleTaskGP, - SingleTaskMultiFidelityGP, -) +from botorch.models import ModelListGP, SingleTaskGP, SingleTaskMultiFidelityGP from botorch.models.converter import ( _batched_kernel, batched_multi_output_to_single_output, @@ -58,12 +53,6 @@ def test_batched_to_model_list(self): ) list_gp = batched_to_model_list(batch_gp) self.assertIsInstance(list_gp, ModelListGP) - # test HeteroskedasticSingleTaskGP - batch_gp = HeteroskedasticSingleTaskGP( - train_X, train_Y, torch.rand_like(train_Y) - ) - with self.assertRaises(NotImplementedError): - batched_to_model_list(batch_gp) # test with transforms input_tf = Normalize( d=2, @@ -161,12 +150,6 @@ def test_model_list_to_batched(self): ) with self.assertRaises(UnsupportedError): model_list_to_batched(ModelListGP(gp1, gp2)) - # test HeteroskedasticSingleTaskGP - gp2 = HeteroskedasticSingleTaskGP( - train_X, train_Y1, torch.ones_like(train_Y1) - ) - with self.assertRaises(NotImplementedError): - model_list_to_batched(ModelListGP(gp2)) # test custom likelihood gp2 = SingleTaskGP( train_X, @@ -419,11 +402,6 @@ def test_batched_multi_output_to_single_output(self): non_batch_model = 
SimpleGPyTorchModel(train_X, train_Y[:, :1]) with self.assertRaises(UnsupportedError): batched_multi_output_to_single_output(non_batch_model) - gp2 = HeteroskedasticSingleTaskGP( - train_X, train_Y, torch.ones_like(train_Y) - ) - with self.assertRaises(NotImplementedError): - batched_multi_output_to_single_output(gp2) # test custom likelihood gp2 = SingleTaskGP(train_X, train_Y, likelihood=GaussianLikelihood()) with self.assertRaises(NotImplementedError): diff --git a/test/models/test_gp_regression.py b/test/models/test_gp_regression.py index 32e2364fb2..643c834a82 100644 --- a/test/models/test_gp_regression.py +++ b/test/models/test_gp_regression.py @@ -11,26 +11,19 @@ import torch from botorch.exceptions.warnings import OptimizationWarning from botorch.fit import fit_gpytorch_mll -from botorch.models.gp_regression import HeteroskedasticSingleTaskGP, SingleTaskGP +from botorch.models.gp_regression import SingleTaskGP from botorch.models.transforms import Normalize, Standardize from botorch.models.transforms.input import InputStandardize from botorch.models.transforms.outcome import Log from botorch.posteriors import GPyTorchPosterior from botorch.sampling import SobolQMCNormalSampler from botorch.utils.datasets import SupervisedDataset -from botorch.utils.sampling import manual_seed from botorch.utils.test_helpers import get_pvar_expected from botorch.utils.testing import _get_random_data, BotorchTestCase from gpytorch.kernels import RBFKernel -from gpytorch.likelihoods import ( - _GaussianLikelihoodBase, - FixedNoiseGaussianLikelihood, - GaussianLikelihood, - HeteroskedasticNoise, -) +from gpytorch.likelihoods import FixedNoiseGaussianLikelihood, GaussianLikelihood from gpytorch.means import ConstantMean, ZeroMean from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood -from gpytorch.mlls.noise_model_added_loss_term import NoiseModelAddedLossTerm from gpytorch.priors import LogNormalPrior @@ -583,62 +576,3 @@ def 
test_fantasized_noise(self): == obs_noise.expand(X_f.shape[:-1] + torch.Size([m])) ).all() ) - - -class TestHeteroskedasticSingleTaskGP(TestGPRegressionBase): - def _get_model_and_data( - self, batch_shape, m, outcome_transform=None, input_transform=None, **tkwargs - ): - with manual_seed(0): - train_X, train_Y = _get_random_data(batch_shape=batch_shape, m=m, **tkwargs) - train_Yvar = (0.1 + 0.1 * torch.rand_like(train_Y)) ** 2 - model_kwargs = { - "train_X": train_X, - "train_Y": train_Y, - "train_Yvar": train_Yvar, - "input_transform": input_transform, - "outcome_transform": outcome_transform, - } - model = HeteroskedasticSingleTaskGP(**model_kwargs) - return model, model_kwargs - - def test_custom_init(self) -> None: - """ - This test exists because `TestHeteroskedasticSingleTaskGP` inherits from - `TestSingleTaskGP`, which has a `test_custom_init` method that isn't relevant - for `TestHeteroskedasticSingleTaskGP`. - """ - - def test_gp(self): - super().test_gp(double_only=True) - - def test_fantasize(self) -> None: - """ - This test exists because `TestHeteroskedasticSingleTaskGP` inherits from - `TestSingleTaskGP`, which has a `fantasize` method that isn't relevant - for `TestHeteroskedasticSingleTaskGP`. 
- """ - - def test_heteroskedastic_likelihood(self): - for batch_shape, m, dtype in itertools.product( - (torch.Size(), torch.Size([2])), (1, 2), (torch.float, torch.double) - ): - tkwargs = {"device": self.device, "dtype": dtype} - model, _ = self._get_model_and_data(batch_shape=batch_shape, m=m, **tkwargs) - self.assertIsInstance(model.likelihood, _GaussianLikelihoodBase) - self.assertFalse(isinstance(model.likelihood, GaussianLikelihood)) - self.assertIsInstance(model.likelihood.noise_covar, HeteroskedasticNoise) - self.assertIsInstance( - model.likelihood.noise_covar.noise_model, SingleTaskGP - ) - self.assertIsInstance( - model._added_loss_terms["noise_added_loss"], NoiseModelAddedLossTerm - ) - - def test_condition_on_observations(self): - with self.assertRaises(NotImplementedError): - super().test_condition_on_observations() - - def test_subset_model(self): - with self.assertRaises(NotImplementedError): - super().test_subset_model()