diff --git a/ax/models/torch/botorch_modular/input_constructors/input_transforms.py b/ax/models/torch/botorch_modular/input_constructors/input_transforms.py index d6afa6ad0b8..54fe9808a47 100644 --- a/ax/models/torch/botorch_modular/input_constructors/input_transforms.py +++ b/ax/models/torch/botorch_modular/input_constructors/input_transforms.py @@ -113,38 +113,31 @@ def _input_transform_argparse_normalize( A dictionary with input transform kwargs. """ input_transform_options = input_transform_options or {} - d = input_transform_options.get("d", len(dataset.feature_names)) - bounds = torch.as_tensor( - search_space_digest.bounds, - dtype=torch_dtype, - device=torch_device, - ).T + d = input_transform_options.get("d", len(dataset.feature_names)) if isinstance(dataset, RankingDataset) and isinstance(dataset.X, SliceContainer): d = dataset.X.values.shape[-1] - - indices = list(range(d)) - task_features = normalize_indices(search_space_digest.task_features, d=d) - - for task_feature in sorted(task_features, reverse=True): - del indices[task_feature] - input_transform_options.setdefault("d", d) - if ("indices" in input_transform_options) or (len(indices) < d): - input_transform_options.setdefault("indices", indices) - - if ( - ("bounds" not in input_transform_options) - and (bounds.shape[-1] < d) - and (len(search_space_digest.task_features) == 0) - ): - raise NotImplementedError( - "Normalize transform bounds should be specified explicitly if there" - " are task features outside the search space." - ) - - input_transform_options.setdefault("bounds", bounds) + if "indices" not in input_transform_options: + indices = list(range(d)) + task_features = normalize_indices(search_space_digest.task_features, d=d) + for task_feature in task_features: + indices.remove(task_feature) + input_transform_options["indices"] = indices + + if "bounds" not in input_transform_options: + bounds = torch.as_tensor( + search_space_digest.bounds, + dtype=torch_dtype, + device=torch_device, + ).T + if (bounds.shape[-1] < d) and (len(search_space_digest.task_features) == 0): + raise NotImplementedError( + "Normalize transform bounds should be specified explicitly if there " + "are task features outside the search space." + ) + input_transform_options["bounds"] = bounds return input_transform_options diff --git a/ax/models/torch/botorch_modular/surrogate.py b/ax/models/torch/botorch_modular/surrogate.py index 63cade29e01..76fae90dcff 100644 --- a/ax/models/torch/botorch_modular/surrogate.py +++ b/ax/models/torch/botorch_modular/surrogate.py @@ -18,7 +18,6 @@ from typing import Any import numpy as np - import torch from ax.core.search_space import SearchSpaceDigest from ax.core.types import TCandidateMetadata @@ -43,6 +42,7 @@ ) from ax.models.torch.utils import ( _to_inequality_constraints, + normalize_indices, pick_best_out_of_sample_point_acqf_class, predict_from_model, ) @@ -70,12 +70,14 @@ ChainedInputTransform, InputPerturbation, InputTransform, + Normalize, ) from botorch.models.transforms.outcome import ChainedOutcomeTransform, OutcomeTransform from botorch.posteriors.gpytorch import GPyTorchPosterior from botorch.utils.containers import SliceContainer from botorch.utils.datasets import MultiTaskDataset, RankingDataset, SupervisedDataset from botorch.utils.dispatcher import Dispatcher +from botorch.utils.types import _DefaultType from gpytorch.kernels import Kernel from gpytorch.likelihoods.likelihood import Likelihood from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood @@ -127,7 +129,7 @@ def _extract_model_kwargs( def _make_botorch_input_transform( - input_transform_classes: list[type[InputTransform]], + input_transform_classes: list[type[InputTransform]] | _DefaultType, input_transform_options: dict[str, dict[str, Any]], dataset: SupervisedDataset, search_space_digest: SearchSpaceDigest, @@ -135,15 +137,46 @@ def _make_botorch_input_transform( """ Makes a BoTorch input transform from the provided input classes and options. """ + if isinstance(input_transform_classes, _DefaultType): + transforms = _construct_default_input_transforms( + search_space_digest=search_space_digest, dataset=dataset + ) + else: + transforms = _construct_specified_input_transforms( + input_transform_classes=input_transform_classes, + dataset=dataset, + search_space_digest=search_space_digest, + input_transform_options=input_transform_options, + ) + if len(transforms) == 0: + return None + elif len(transforms) > 1: + return ChainedInputTransform( + **{f"tf{i}": transforms[i] for i in range(len(transforms))} + ) + else: + return transforms[0] + + +def _construct_specified_input_transforms( + input_transform_classes: list[type[InputTransform]], + input_transform_options: dict[str, dict[str, Any]], + dataset: SupervisedDataset, + search_space_digest: SearchSpaceDigest, +) -> list[InputTransform]: + """Constructs a list of input transforms from input transform classes and + options provided in ``ModelConfig``. + """ if not ( isinstance(input_transform_classes, list) and all(issubclass(c, InputTransform) for c in input_transform_classes) ): - raise UserInputError("Expected a list of input transforms.") + raise UserInputError( + "Expected a list of input transform classes. " + f"Got {input_transform_classes=}." + ) if search_space_digest.robust_digest is not None: input_transform_classes = [InputPerturbation] + input_transform_classes - if len(input_transform_classes) == 0: - return None input_transform_kwargs = [ input_transform_argparse( @@ -157,7 +190,7 @@ def _make_botorch_input_transform( for transform_class in input_transform_classes ] - input_transforms = [ + return [ # pyre-fixme[45]: Cannot instantiate abstract class `InputTransform`. transform_class(**single_input_transform_kwargs) for transform_class, single_input_transform_kwargs in zip( @@ -165,15 +198,47 @@ def _make_botorch_input_transform( ) ] - input_transform_instance = ( - ChainedInputTransform( - **{f"tf{i}": input_transforms[i] for i in range(len(input_transforms))} + +def _construct_default_input_transforms( + search_space_digest: SearchSpaceDigest, + dataset: SupervisedDataset, +) -> list[InputTransform]: + """Construct the default input transforms for the given search space digest. + + The default transforms are added in this order: + - If the search space digest has a robust digest, an ``InputPerturbation`` transform + is used. + - If the bounds for the non-task features are not [0, 1], a ``Normalize`` transform + is used. The transfrom only applies to the non-task features. + """ + transforms = [] + # Add InputPerturbation if there is a robust digest. + if search_space_digest.robust_digest is not None: + transforms.append( + InputPerturbation( + **input_transform_argparse( + InputPerturbation, + dataset=dataset, + search_space_digest=search_space_digest, + ) + ) ) - if len(input_transforms) > 1 - else input_transforms[0] - ) + # Processing for Normalize. + bounds = torch.tensor(search_space_digest.bounds).T + indices = list(range(bounds.shape[-1])) + # Remove task features. + for task_feature in normalize_indices( + search_space_digest.task_features, d=bounds.shape[-1] + ): + indices.remove(task_feature) + # Skip the Normalize transform if the bounds are [0, 1]. + if not ( + torch.allclose(bounds[0, indices], torch.zeros(len(indices))) + and torch.allclose(bounds[1, indices], torch.ones(len(indices))) + ): + transforms.append(Normalize(d=bounds.shape[-1], indices=indices, bounds=bounds)) - return input_transform_instance + return transforms def _make_botorch_outcome_transform( diff --git a/ax/models/torch/botorch_modular/utils.py b/ax/models/torch/botorch_modular/utils.py index d40bbec9825..bc2aec504a5 100644 --- a/ax/models/torch/botorch_modular/utils.py +++ b/ax/models/torch/botorch_modular/utils.py @@ -39,6 +39,7 @@ from botorch.models.transforms.outcome import OutcomeTransform from botorch.utils.datasets import SupervisedDataset from botorch.utils.transforms import is_fully_bayesian +from botorch.utils.types import _DefaultType, DEFAULT from gpytorch.kernels.kernel import Kernel from gpytorch.likelihoods import Likelihood from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood @@ -108,7 +109,7 @@ class string names and the values are dictionaries of input transform model_options: dict[str, Any] = field(default_factory=dict) mll_class: type[MarginalLogLikelihood] = ExactMarginalLogLikelihood mll_options: dict[str, Any] = field(default_factory=dict) - input_transform_classes: list[type[InputTransform]] | None = None + input_transform_classes: list[type[InputTransform]] | _DefaultType | None = DEFAULT input_transform_options: dict[str, dict[str, Any]] | None = field( default_factory=dict ) diff --git a/ax/models/torch/tests/test_model.py b/ax/models/torch/tests/test_model.py index bbea4c38c64..43e05439b75 100644 --- a/ax/models/torch/tests/test_model.py +++ b/ax/models/torch/tests/test_model.py @@ -54,6 +54,7 @@ from botorch.sampling.normal import SobolQMCNormalSampler from botorch.utils.constraints import get_outcome_constraint_transforms from botorch.utils.datasets import SupervisedDataset +from botorch.utils.types import DEFAULT from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood from pyre_extensions import none_throws @@ -335,7 +336,7 @@ def test_fit(self, mock_fit: Mock) -> None: model_options={}, mll_class=ExactMarginalLogLikelihood, mll_options={}, - input_transform_classes=None, + input_transform_classes=DEFAULT, input_transform_options={}, outcome_transform_classes=None, outcome_transform_options={}, @@ -635,7 +636,7 @@ def test_feature_importances(self) -> None: ) model.surrogate.fit( datasets=self.block_design_training_data, - search_space_digest=SearchSpaceDigest(feature_names=[], bounds=[]), + search_space_digest=self.search_space_digest, ) if botorch_model_class == SaasFullyBayesianSingleTaskGP: mcmc_samples = { @@ -741,10 +742,7 @@ def test_evaluate_acquisition_function( ) model.surrogate.fit( datasets=self.block_design_training_data, - search_space_digest=SearchSpaceDigest( - feature_names=[], - bounds=[], - ), + search_space_digest=self.search_space_digest, ) model.evaluate_acquisition_function( X=self.X_test, diff --git a/ax/models/torch/tests/test_surrogate.py b/ax/models/torch/tests/test_surrogate.py index ae6d741533c..8d0a6cf4392 100644 --- a/ax/models/torch/tests/test_surrogate.py +++ b/ax/models/torch/tests/test_surrogate.py @@ -22,6 +22,7 @@ from ax.models.torch.botorch_modular.kernels import ScaleMaternKernel from ax.models.torch.botorch_modular.surrogate import ( _extract_model_kwargs, + submodel_input_constructor, Surrogate, SurrogateSpec, ) @@ -389,9 +390,12 @@ def test_dtype_and_device_properties(self) -> None: self.assertEqual(self.dtype, surrogate.dtype) self.assertEqual(self.device, surrogate.device) - @patch.object(SingleTaskGP, "__init__", return_value=None) + @patch( + f"{SURROGATE_PATH}.submodel_input_constructor", + wraps=submodel_input_constructor, + ) @patch(f"{SURROGATE_PATH}.fit_botorch_model") - def test_fit_model_reuse(self, mock_fit: Mock, mock_init: Mock) -> None: + def test_fit_model_reuse(self, mock_fit: Mock, mock_constructor: Mock) -> None: surrogate, _ = self._get_surrogate( botorch_model_class=SingleTaskGP, use_outcome_transform=False ) @@ -404,7 +408,7 @@ def test_fit_model_reuse(self, mock_fit: Mock, mock_init: Mock) -> None: search_space_digest=search_space_digest, ) mock_fit.assert_called_once() - mock_init.assert_called_once() + mock_constructor.assert_called_once() key = tuple(self.training_data[0].outcome_names) submodel = surrogate._submodels[key] self.assertIs(surrogate._last_datasets[key], self.training_data[0]) @@ -417,7 +421,7 @@ def test_fit_model_reuse(self, mock_fit: Mock, mock_init: Mock) -> None: ) # Still only called once -- i.e. not fitted again: mock_fit.assert_called_once() - mock_init.assert_called_once() + mock_constructor.assert_called_once() # Model is still the same object. self.assertIs(submodel, surrogate._submodels[key]) @@ -448,23 +452,12 @@ def test_construct_model(self) -> None: surrogate, _ = self._get_surrogate( botorch_model_class=botorch_model_class, use_outcome_transform=False ) - with self.assertRaisesRegex(TypeError, "posterior"): - # Base `Model` does not implement `posterior`, so instantiating it here - # will fail. - Surrogate()._construct_model( - dataset=self.training_data[0], - search_space_digest=self.search_space_digest, - model_config=ModelConfig(), - default_botorch_model_class=Model, - state_dict=None, - refit=True, - ) with patch.object( botorch_model_class, "construct_inputs", wraps=botorch_model_class.construct_inputs, ) as mock_construct_inputs, patch.object( - botorch_model_class, "__init__", return_value=None + botorch_model_class, "__init__", return_value=None, autospec=True ) as mock_init, patch(f"{SURROGATE_PATH}.fit_botorch_model") as mock_fit: model = surrogate._construct_model( dataset=self.training_data[0], @@ -479,7 +472,9 @@ def test_construct_model(self) -> None: call_kwargs = mock_init.call_args.kwargs self.assertTrue(torch.equal(call_kwargs["train_X"], self.Xs[0])) self.assertTrue(torch.equal(call_kwargs["train_Y"], self.Ys[0])) - self.assertEqual(len(call_kwargs), 2) + self.assertIsInstance(call_kwargs["input_transform"], Normalize) + self.assertIsNone(call_kwargs["outcome_transform"]) + self.assertEqual(len(call_kwargs), 4) mock_construct_inputs.assert_called_with( training_data=self.training_data[0], @@ -1120,9 +1115,12 @@ def test_w_robust_digest(self) -> None: robust_digest=robust_digest, ), ) - intf = checked_cast(InputPerturbation, surrogate.model.input_transform) - self.assertIsInstance(intf, InputPerturbation) - self.assertTrue(torch.equal(intf.perturbation_set, torch.zeros(2, 2))) + intf = surrogate.model.input_transform + self.assertIsInstance(intf, ChainedInputTransform) + intf_values = list(intf.values()) + self.assertIsInstance(intf_values[0], InputPerturbation) + self.assertIsInstance(intf_values[1], Normalize) + self.assertTrue(torch.equal(intf_values[0].perturbation_set, torch.zeros(2, 2))) def test_fit_mixed(self) -> None: # Test model construction with categorical variables. @@ -1591,7 +1589,10 @@ def test_construct_custom_model(self) -> None: @mock_botorch_optimize def test_w_robust_digest(self) -> None: surrogate = Surrogate( - botorch_model_class=SingleTaskGP, + surrogate_spec=SurrogateSpec( + botorch_model_class=SingleTaskGP, + input_transform_classes=[], + ) ) # Error handling. with self.assertRaisesRegex(NotImplementedError, "Environmental variable"): @@ -1613,9 +1614,6 @@ def test_w_robust_digest(self) -> None: multiplicative=False, ) # Input perturbation is constructed. - surrogate = Surrogate( - botorch_model_class=SingleTaskGP, - ) surrogate.fit( datasets=self.supervised_training_data, search_space_digest=SearchSpaceDigest(