Skip to content

Commit

Permalink
Clean up is_fully_bayesian (#2108)
Browse files Browse the repository at this point in the history
Summary:
X-link: facebook/Ax#1992

Pull Request resolved: #2108

This attempts to clean up the usage of `is_fully_bayesian` and also separately treat fully Bayesian models from ensemble models.

The main changes in this diff are to:
- Add an `_is_fully_bayesian` attribute to `Model`. This is `True` for fully Bayesian models that rely on Pyro/NUTS to be fitted (they need some special handling for fitting and `state_dict` loading/saving).
- Add an `_is_ensemble` attribute to `Model`. This indicates whether the model is a collection of multiple models that are stored in an additional batch dimension. This is hopefully a better classification, but I'm open to a different name here.
- Rename `FullyBayesianPosterior` to `GaussianMixturePosterior` since that is more descriptive and plays better with the other changes.

Reviewed By: esantorella

Differential Revision: D50884342

fbshipit-source-id: 0ba603416c1823026c4fdf2e445cefdf8036cda8
  • Loading branch information
David Eriksson authored and facebook-github-bot committed Nov 16, 2023
1 parent f426dab commit 4a4a5bd
Show file tree
Hide file tree
Showing 17 changed files with 191 additions and 98 deletions.
8 changes: 4 additions & 4 deletions botorch/acquisition/multi_objective/logei.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
)
from botorch.utils.transforms import (
concatenate_pending_points,
is_fully_bayesian,
is_ensemble,
match_batch_shape,
t_batch_mode_transform,
)
Expand Down Expand Up @@ -454,9 +454,9 @@ def forward(self, X: Tensor) -> Tensor:
# 1) X and X, and
# 2) X and X_baseline.
posterior = self.model.posterior(X_full)
# Account for possible one-to-many transform and the MCMC batch dimension in
# `SaasFullyBayesianSingleTaskGP`
event_shape_lag = 1 if is_fully_bayesian(self.model) else 2
# Account for possible one-to-many transform and the model batch dimensions in
# ensemble models.
event_shape_lag = 1 if is_ensemble(self.model) else 2
n_w = (
posterior._extended_shape()[X_full.dim() - event_shape_lag]
// X_full.shape[-2]
Expand Down
4 changes: 2 additions & 2 deletions botorch/acquisition/multi_objective/monte_carlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
from botorch.utils.objective import compute_smoothed_feasibility_indicator
from botorch.utils.transforms import (
concatenate_pending_points,
is_fully_bayesian,
is_ensemble,
match_batch_shape,
t_batch_mode_transform,
)
Expand Down Expand Up @@ -453,7 +453,7 @@ def forward(self, X: Tensor) -> Tensor:
posterior = self.model.posterior(X_full)
# Account for possible one-to-many transform and the model batch dimension in
# ensemble models (e.g., the MCMC dimension in `SaasFullyBayesianSingleTaskGP`).
event_shape_lag = 1 if is_fully_bayesian(self.model) else 2
event_shape_lag = 1 if is_ensemble(self.model) else 2
n_w = (
posterior._extended_shape()[X_full.dim() - event_shape_lag]
// X_full.shape[-2]
Expand Down
4 changes: 2 additions & 2 deletions botorch/acquisition/multi_objective/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from botorch.utils.multi_objective.pareto import is_non_dominated
from botorch.utils.objective import compute_feasibility_indicator
from botorch.utils.sampling import draw_sobol_samples
from botorch.utils.transforms import is_fully_bayesian
from botorch.utils.transforms import is_ensemble
from torch import Tensor


Expand Down Expand Up @@ -110,7 +110,7 @@ def prune_inferior_points_multi_objective(
with `N_nz` the number of points in `X` that have non-zero (empirical,
under `num_samples` samples) probability of being pareto optimal.
"""
if marginalize_dim is None and is_fully_bayesian(model):
if marginalize_dim is None and is_ensemble(model):
# TODO: Properly deal with marginalizing fully Bayesian models
marginalize_dim = MCMC_DIM

Expand Down
4 changes: 2 additions & 2 deletions botorch/acquisition/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from botorch.sampling.pathwise import draw_matheron_paths
from botorch.utils.objective import compute_feasibility_indicator
from botorch.utils.sampling import optimize_posterior_samples
from botorch.utils.transforms import is_fully_bayesian, normalize_indices
from botorch.utils.transforms import is_ensemble, normalize_indices
from torch import Tensor


Expand Down Expand Up @@ -263,7 +263,7 @@ def prune_inferior_points(
with `N_nz` the number of points in `X` that have non-zero (empirical,
under `num_samples` samples) probability of being the best point.
"""
if marginalize_dim is None and is_fully_bayesian(model):
if marginalize_dim is None and is_ensemble(model):
# TODO: Properly deal with marginalizing fully Bayesian models
marginalize_dim = MCMC_DIM

Expand Down
12 changes: 8 additions & 4 deletions botorch/models/fully_bayesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
from botorch.models.transforms.outcome import OutcomeTransform
from botorch.models.utils import validate_input_scaling
from botorch.models.utils.gpytorch_modules import MIN_INFERRED_NOISE_LEVEL
from botorch.posteriors.fully_bayesian import FullyBayesianPosterior, MCMC_DIM
from botorch.posteriors.fully_bayesian import GaussianMixturePosterior, MCMC_DIM
from gpytorch.constraints import GreaterThan
from gpytorch.distributions.multivariate_normal import MultivariateNormal
from gpytorch.kernels import MaternKernel, ScaleKernel
Expand Down Expand Up @@ -327,6 +327,9 @@ class SaasFullyBayesianSingleTaskGP(ExactGP, BatchedMultiOutputGPyTorchModel):
>>> posterior = saas_gp.posterior(test_X)
"""

_is_fully_bayesian = True
_is_ensemble = True

def __init__(
self,
train_X: Tensor,
Expand Down Expand Up @@ -508,7 +511,7 @@ def posterior(
observation_noise: bool = False,
posterior_transform: Optional[PosteriorTransform] = None,
**kwargs: Any,
) -> FullyBayesianPosterior:
) -> GaussianMixturePosterior:
r"""Computes the posterior over model outputs at the provided points.
Args:
Expand All @@ -526,7 +529,8 @@ def posterior(
posterior_transform: An optional PosteriorTransform.
Returns:
A `FullyBayesianPosterior` object. Includes observation noise if specified.
A `GaussianMixturePosterior` object. Includes observation noise
if specified.
"""
self._check_if_fitted()
posterior = super().posterior(
Expand All @@ -536,5 +540,5 @@ def posterior(
posterior_transform=posterior_transform,
**kwargs,
)
posterior = FullyBayesianPosterior(distribution=posterior.distribution)
posterior = GaussianMixturePosterior(distribution=posterior.distribution)
return posterior
12 changes: 8 additions & 4 deletions botorch/models/fully_bayesian_multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from botorch.models.multitask import MultiTaskGP
from botorch.models.transforms.input import InputTransform
from botorch.models.transforms.outcome import OutcomeTransform
from botorch.posteriors.fully_bayesian import FullyBayesianPosterior, MCMC_DIM
from botorch.posteriors.fully_bayesian import GaussianMixturePosterior, MCMC_DIM
from botorch.utils.datasets import MultiTaskDataset, SupervisedDataset
from gpytorch.distributions.multivariate_normal import MultivariateNormal
from gpytorch.kernels import MaternKernel
Expand Down Expand Up @@ -189,6 +189,9 @@ class SaasFullyBayesianMultiTaskGP(MultiTaskGP):
>>> posterior = mtsaas_gp.posterior(test_X)
"""

_is_fully_bayesian = True
_is_ensemble = True

def __init__(
self,
train_X: Tensor,
Expand Down Expand Up @@ -335,11 +338,12 @@ def posterior(
observation_noise: bool = False,
posterior_transform: Optional[PosteriorTransform] = None,
**kwargs: Any,
) -> FullyBayesianPosterior:
) -> GaussianMixturePosterior:
r"""Computes the posterior over model outputs at the provided points.
Returns:
A `FullyBayesianPosterior` object. Includes observation noise if specified.
A `GaussianMixturePosterior` object. Includes observation noise
if specified.
"""
self._check_if_fitted()
posterior = super().posterior(
Expand All @@ -349,7 +353,7 @@ def posterior(
posterior_transform=posterior_transform,
**kwargs,
)
posterior = FullyBayesianPosterior(distribution=posterior.distribution)
posterior = GaussianMixturePosterior(distribution=posterior.distribution)
return posterior

def forward(self, X: Tensor) -> MultivariateNormal:
Expand Down
12 changes: 6 additions & 6 deletions botorch/models/gpytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@
mod_batch_shape,
multioutput_to_batch_mode_transform,
)
from botorch.posteriors.fully_bayesian import FullyBayesianPosterior
from botorch.posteriors.fully_bayesian import GaussianMixturePosterior
from botorch.posteriors.gpytorch import GPyTorchPosterior
from botorch.utils.transforms import is_fully_bayesian
from botorch.utils.transforms import is_ensemble
from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal
from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood
from torch import Tensor
Expand Down Expand Up @@ -619,7 +619,7 @@ def posterior(
- If no `posterior_transform` is provided and the component models have no
`outcome_transform`, or if the component models only use linear outcome
transforms like `Standardize` (i.e. not `Log`), returns a
`GPyTorchPosterior` or `FullyBayesianPosterior` object,
`GPyTorchPosterior` or `GaussianMixturePosterior` object,
representing `batch_shape` joint distributions over `q` points
and the outputs selected by `output_indices` each. Includes
measurement noise if `observation_noise` is specified.
Expand Down Expand Up @@ -650,16 +650,16 @@ def posterior(
mvns = [p.distribution for p in posterior.posteriors]
# Combining MTMVNs into a single MTMVN is currently not supported.
if not any(isinstance(m, MultitaskMultivariateNormal) for m in mvns):
# Return the result as a GPyTorchPosterior/FullyBayesianPosterior.
# Return the result as a GPyTorchPosterior/GaussianMixturePosterior.
mvn = (
mvns[0]
if len(mvns) == 1
else MultitaskMultivariateNormal.from_independent_mvns(mvns=mvns)
)
if any(is_fully_bayesian(m) for m in self.models):
if any(is_ensemble(m) for m in self.models):
# Mixing fully Bayesian and other GP models is currently
# not supported.
posterior = FullyBayesianPosterior(distribution=mvn)
posterior = GaussianMixturePosterior(distribution=mvn)
else:
posterior = GPyTorchPosterior(distribution=mvn)
if posterior_transform is not None:
Expand Down
8 changes: 7 additions & 1 deletion botorch/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,17 +70,23 @@ class Model(Module, ABC):
`Tensor` or `Module` type are automatically registered so they can be moved and/or
cast with the `to` method, automatically differentiated, and used with CUDA.
Args:
Attributes:
_has_transformed_inputs: A boolean denoting whether `train_inputs` are currently
stored as transformed or not.
_original_train_inputs: A Tensor storing the original train inputs for use in
`_revert_to_original_inputs`. Note that this is necessary since
transform / untransform cycle introduces numerical errors which lead
to upstream errors during training.
_is_fully_bayesian: Returns `True` if this is a fully Bayesian model.
_is_ensemble: Returns `True` if this model consists of multiple models
that are stored in an additional batch dimension. This is true for the fully
Bayesian models.
""" # noqa: E501

_has_transformed_inputs: bool = False
_original_train_inputs: Optional[Tensor] = None
_is_fully_bayesian = False
_is_ensemble = False

@abstractmethod
def posterior(
Expand Down
6 changes: 5 additions & 1 deletion botorch/posteriors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
# LICENSE file in the root directory of this source tree.

from botorch.posteriors.deterministic import DeterministicPosterior
from botorch.posteriors.fully_bayesian import FullyBayesianPosterior
from botorch.posteriors.fully_bayesian import (
FullyBayesianPosterior,
GaussianMixturePosterior,
)
from botorch.posteriors.gpytorch import GPyTorchPosterior
from botorch.posteriors.higher_order import HigherOrderGPPosterior
from botorch.posteriors.multitask import MultitaskGPPosterior
Expand All @@ -16,6 +19,7 @@

__all__ = [
"DeterministicPosterior",
"GaussianMixturePosterior",
"FullyBayesianPosterior",
"GPyTorchPosterior",
"HigherOrderGPPosterior",
Expand Down
26 changes: 17 additions & 9 deletions botorch/posteriors/fully_bayesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from __future__ import annotations

from typing import Callable, Optional, Tuple
from warnings import warn

import torch
from botorch.posteriors.gpytorch import GPyTorchPosterior
Expand Down Expand Up @@ -54,7 +55,7 @@ def batched_bisect(
return center


def _quantile(posterior: FullyBayesianPosterior, value: Tensor) -> Tensor:
def _quantile(posterior: GaussianMixturePosterior, value: Tensor) -> Tensor:
r"""Compute the posterior quantiles for the mixture of models."""
if value.numel() > 1:
return torch.stack(
Expand All @@ -78,13 +79,13 @@ def _quantile(posterior: FullyBayesianPosterior, value: Tensor) -> Tensor:
)


class FullyBayesianPosterior(GPyTorchPosterior):
r"""A posterior for a fully Bayesian model.
class GaussianMixturePosterior(GPyTorchPosterior):
r"""A Gaussian mixture posterior.
The MCMC batch dimension that corresponds to the models in the mixture is located
at `MCMC_DIM` (defined at the top of this file). Note that while each MCMC sample
corresponds to a Gaussian posterior, the fully Bayesian posterior is rather a
mixture of Gaussian distributions.
corresponds to a Gaussian posterior, the posterior is rather a mixture of Gaussian
distributions.
"""

def __init__(self, distribution: MultivariateNormal) -> None:
Expand Down Expand Up @@ -137,7 +138,14 @@ def batch_range(self) -> Tuple[int, int]:
provide consistency in the acquisition values, i.e., to ensure that a
candidate produces same value regardless of its position on the t-batch.
"""
if self._is_mt:
return (0, -2)
else:
return (0, -1)
return (0, -2) if self._is_mt else (0, -1)


class FullyBayesianPosterior(GaussianMixturePosterior):
    r"""Deprecated alias of `GaussianMixturePosterior`.

    Kept for backwards compatibility only; new code should construct a
    `GaussianMixturePosterior` directly.
    """

    def __init__(self, distribution: MultivariateNormal) -> None:
        r"""Emit a deprecation warning, then construct the posterior.

        Args:
            distribution: Forwarded unchanged to `GaussianMixturePosterior`.
        """
        # Warn on instantiation rather than in the class body: a class-level
        # `warn` call runs once when the module is imported, regardless of
        # whether the deprecated class is ever used, and never again after.
        warn(
            "`FullyBayesianPosterior` is marked for deprecation, consider using "
            "`GaussianMixturePosterior` instead.",
            DeprecationWarning,
            stacklevel=2,  # attribute the warning to the caller's line
        )
        super().__init__(distribution=distribution)
27 changes: 15 additions & 12 deletions botorch/posteriors/posterior_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,17 @@
from typing import Any, List, Optional

import torch
from botorch.posteriors.fully_bayesian import FullyBayesianPosterior, MCMC_DIM
from botorch.posteriors import FullyBayesianPosterior
from botorch.posteriors.fully_bayesian import GaussianMixturePosterior, MCMC_DIM
from botorch.posteriors.posterior import Posterior
from torch import Tensor


class PosteriorList(Posterior):
r"""A Posterior represented by a list of independent Posteriors.
When at least one of the posteriors is a `FullyBayesianPosterior`, the other
posteriors are expanded to match the size of the `FullyBayesianPosterior`.
When at least one of the posteriors is a `GaussianMixturePosterior`, the other
posteriors are expanded to match the size of the `GaussianMixturePosterior`.
"""

def __init__(self, *posteriors: Posterior) -> None:
Expand All @@ -44,16 +45,16 @@ def __init__(self, *posteriors: Posterior) -> None:
self.posteriors = list(posteriors)

@cached_property
def _is_fully_bayesian(self) -> bool:
r"""Check if any of the posteriors is a `FullyBayesianPosterior`."""
return any(isinstance(p, FullyBayesianPosterior) for p in self.posteriors)
def _is_gaussian_mixture(self) -> bool:
r"""Check if any of the posteriors is a `GaussianMixturePosterior`."""
return any(isinstance(p, GaussianMixturePosterior) for p in self.posteriors)

def _get_mcmc_batch_dimension(self) -> int:
"""Return the number of MCMC samples in the corresponding batch dimension."""
mcmc_samples = [
p.mean.shape[MCMC_DIM]
for p in self.posteriors
if isinstance(p, FullyBayesianPosterior)
if isinstance(p, (GaussianMixturePosterior, FullyBayesianPosterior))
]
if len(set(mcmc_samples)) > 1:
raise NotImplementedError(
Expand All @@ -70,12 +71,12 @@ def _reshape_tensor(X: Tensor, mcmc_samples: int) -> Tensor:

def _reshape_and_cat(self, tensors: List[Tensor]):
r"""Reshape, if needed, and concatenate (across dim=-1) a list of tensors."""
if self._is_fully_bayesian:
if self._is_gaussian_mixture:
mcmc_samples = self._get_mcmc_batch_dimension()
return torch.cat(
[
x
if isinstance(p, FullyBayesianPosterior)
if isinstance(p, GaussianMixturePosterior)
else self._reshape_tensor(x, mcmc_samples=mcmc_samples)
for x, p in zip(tensors, self.posteriors)
],
Expand Down Expand Up @@ -112,16 +113,18 @@ def _extended_shape(
r"""Returns the shape of the samples produced by the posterior with
the given `sample_shape`.
If there's at least one `FullyBayesianPosterior`, the MCMC dimension
If there's at least one `GaussianMixturePosterior`, the MCMC dimension
is included in the `_extended_shape`.
"""
if self._is_fully_bayesian:
if self._is_gaussian_mixture:
mcmc_shape = torch.Size([self._get_mcmc_batch_dimension()])
extend_dim = MCMC_DIM + 1 # The dimension to inject MCMC shape.
extended_shapes = []
for p in self.posteriors:
es = p._extended_shape(sample_shape=sample_shape)
if self._is_fully_bayesian and not isinstance(p, FullyBayesianPosterior):
if self._is_gaussian_mixture and not isinstance(
p, GaussianMixturePosterior
):
# Extend the shapes of non-fully Bayesian ones to match.
extended_shapes.append(es[:extend_dim] + mcmc_shape + es[extend_dim:])
else:
Expand Down
Loading

0 comments on commit 4a4a5bd

Please sign in to comment.