From 1a78f275c3eceee12e6621a51996fab99e8fe253 Mon Sep 17 00:00:00 2001 From: David Eriksson Date: Tue, 14 Nov 2023 16:21:07 -0800 Subject: [PATCH] Clean up is_fully_bayesian Summary: X-link: https://github.com/pytorch/botorch/pull/2108 This attempts to clean up the usage of `is_fully_bayesian` and also separately treat fully Bayesian models from ensemble models. The main changes in diff are to: - Add an `_is_fully_bayesian` attribute to `Model`. This is `True` for fully Bayesian models that rely on Pyro/NUTS to be fitted (they need some special handling for fitting and `state_dict` loading/saving. - Add an `_is_ensemble` attribute to `Model`. This indicates whether the model is a collection of multiple models that are stored in an additional batch dimension. This is hopefully a better classification, but I'm open to a different name here. - Rename `FullyBayesianPosterior` to `GaussianMixturePosterior` since that is more descriptive and plays better with the other changes. Reviewed By: esantorella Differential Revision: D50884342 --- ax/models/torch/botorch.py | 4 ++-- ax/models/torch/utils.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/ax/models/torch/botorch.py b/ax/models/torch/botorch.py index 054c8a7abe8..7cc8327cb1a 100644 --- a/ax/models/torch/botorch.py +++ b/ax/models/torch/botorch.py @@ -41,7 +41,7 @@ from botorch.models import ModelList from botorch.models.model import Model from botorch.utils.datasets import SupervisedDataset -from botorch.utils.transforms import is_fully_bayesian +from botorch.utils.transforms import is_ensemble from torch import Tensor from torch.nn import ModuleList # @manual @@ -572,7 +572,7 @@ def get_feature_importances_from_botorch_model( ) if ls.ndim == 2: ls = ls.unsqueeze(0) - if is_fully_bayesian(m): # Take the median over the MCMC samples + if is_ensemble(m): # Take the median over the model batch dimension ls = torch.quantile(ls, q=0.5, dim=0, keepdim=True) lengthscales.append(ls) lengthscales = torch.cat(lengthscales, dim=0) diff --git a/ax/models/torch/utils.py b/ax/models/torch/utils.py index 5e0af0038e3..a8a58f38e46 100644 --- a/ax/models/torch/utils.py +++ b/ax/models/torch/utils.py @@ -51,7 +51,7 @@ from botorch.acquisition.utils import get_infeasible_cost from botorch.models import ModelListGP, SingleTaskGP from botorch.models.model import Model -from botorch.posteriors.fully_bayesian import FullyBayesianPosterior +from botorch.posteriors.fully_bayesian import GaussianMixturePosterior from botorch.posteriors.gpytorch import GPyTorchPosterior from botorch.posteriors.posterior_list import PosteriorList from botorch.sampling.normal import IIDNormalSampler, SobolQMCNormalSampler @@ -627,7 +627,7 @@ def predict_from_model(model: Model, X: Tensor) -> Tuple[Tensor, Tensor]: with torch.no_grad(): # TODO: Allow Posterior to (optionally) return the full covariance matrix posterior = model.posterior(X) - if isinstance(posterior, FullyBayesianPosterior): + if isinstance(posterior, GaussianMixturePosterior): mean = posterior.mixture_mean.cpu().detach() var = posterior.mixture_variance.cpu().detach().clamp_min(0) elif isinstance(posterior, (GPyTorchPosterior, PosteriorList)):