From c05aac26db3c8f1cc4a14d9a53fb7b1904916d6e Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Tue, 26 Nov 2024 08:13:56 -0800 Subject: [PATCH 1/2] Add LogIntToFloat transform (#3091) Summary: This is a simple subclass of `IntToFloat` that only transforms log-scale parameters. Replacing `IntToFloat` with `LogIntToFloat` avoids unnecessary use of continuous relaxation across the board and allows us to use the various optimizers available in `Acquisition.optimize`. Additional context: With log-scale parameters, we have two options: transform them in Ax or transform them in BoTorch. Transforming them in Ax leads to both modeling and optimizing the parameter in the log scale (good), while transforming them in BoTorch leads to modeling in the log scale but optimizing in the raw scale (not ideal), and also introduces `TransformedPosterior` and some incompatibilities it brings. So, we want to transform log-scale parameters in Ax. Since the log of an int parameter is no longer an int, we have to relax such parameters. But we don't want to relax any other int parameters, so we don't want to use `IntToFloat`. `LogIntToFloat` makes it possible to use continuous relaxation only for the log-scale parameters, which is a step in the right direction. Differential Revision: D66244582 --- ax/core/search_space.py | 2 +- ax/modelbridge/transforms/int_to_float.py | 55 +++++++++++++++++-- .../tests/test_int_to_float_transform.py | 34 +++++++++++- ax/storage/transform_registry.py | 3 +- 4 files changed, 84 insertions(+), 10 deletions(-) diff --git a/ax/core/search_space.py b/ax/core/search_space.py index 446e9212a20..ecdbfe38eb9 100644 --- a/ax/core/search_space.py +++ b/ax/core/search_space.py @@ -67,7 +67,7 @@ class SearchSpace(Base): def __init__( self, - parameters: list[Parameter], + parameters: Sequence[Parameter], parameter_constraints: list[ParameterConstraint] | None = None, ) -> None: """Initialize SearchSpace diff --git a/ax/modelbridge/transforms/int_to_float.py b/ax/modelbridge/transforms/int_to_float.py index 5f161f50a13..74171a2eac2 100644 --- a/ax/modelbridge/transforms/int_to_float.py +++ b/ax/modelbridge/transforms/int_to_float.py @@ -12,6 +12,7 @@ from ax.core.observation import Observation, ObservationFeatures from ax.core.parameter import Parameter, ParameterType, RangeParameter from ax.core.search_space import SearchSpace +from ax.exceptions.core import UserInputError from ax.modelbridge.transforms.base import Transform from ax.modelbridge.transforms.rounding import ( contains_constrained_integer, @@ -65,18 +66,22 @@ def __init__( self.min_choices: int = checked_cast(int, config.get("min_choices", 0)) # Identify parameters that should be transformed - self.transform_parameters: set[str] = { + self.transform_parameters: set[str] = self._get_transform_parameters() + if contains_constrained := contains_constrained_integer( + self.search_space, self.transform_parameters + ): + self.rounding = "randomized" + self.contains_constrained_integer: bool = contains_constrained + + def _get_transform_parameters(self) -> set[str]: + """Identify parameters that should be transformed.""" + return { p_name for p_name, p in self.search_space.parameters.items() if isinstance(p, RangeParameter) and p.parameter_type == ParameterType.INT and ((p.cardinality() >= self.min_choices) or p.log_scale) } - if contains_constrained_integer(self.search_space, self.transform_parameters): - self.rounding = "randomized" - self.contains_constrained_integer: bool = True - else: - self.contains_constrained_integer: bool = False def 
transform_observation_features( self, observation_features: list[ObservationFeatures] @@ -183,3 +188,41 @@ def untransform_observation_features( obsf.parameters[p_name] = rounded_parameters[p_name] return observation_features + + +class LogIntToFloat(IntToFloat): + """Convert a log-scale RangeParameter of type int to type float. + + The behavior of this transform mirrors ``IntToFloat`` with the key difference + being that it only operates on log-scale parameters. + """ + + def __init__( + self, + search_space: SearchSpace | None = None, + observations: list[Observation] | None = None, + modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, + config: TConfig | None = None, + ) -> None: + if config is not None and "min_choices" in config: + raise UserInputError( + "`min_choices` cannot be specified for `LogIntToFloat` transform. " + ) + super().__init__( + search_space=search_space, + observations=observations, + modelbridge=modelbridge, + config=config, + ) + # Delete the attribute to avoid it presenting a misleading value. + del self.min_choices + + def _get_transform_parameters(self) -> set[str]: + """Identify parameters that should be transformed.""" + return { + p_name + for p_name, p in self.search_space.parameters.items() + if isinstance(p, RangeParameter) + and p.parameter_type == ParameterType.INT + and p.log_scale + } diff --git a/ax/modelbridge/transforms/tests/test_int_to_float_transform.py b/ax/modelbridge/transforms/tests/test_int_to_float_transform.py index 2f021c2cb49..f08b4ff9953 100644 --- a/ax/modelbridge/transforms/tests/test_int_to_float_transform.py +++ b/ax/modelbridge/transforms/tests/test_int_to_float_transform.py @@ -13,8 +13,8 @@ from ax.core.parameter import ChoiceParameter, Parameter, ParameterType, RangeParameter from ax.core.parameter_constraint import OrderConstraint, SumConstraint from ax.core.search_space import RobustSearchSpace, SearchSpace -from ax.exceptions.core import UnsupportedError -from ax.modelbridge.transforms.int_to_float import IntToFloat +from ax.exceptions.core import UnsupportedError, UserInputError +from ax.modelbridge.transforms.int_to_float import IntToFloat, LogIntToFloat from ax.utils.common.testutils import TestCase from ax.utils.common.typeutils import checked_cast from ax.utils.testing.core_stubs import get_robust_search_space @@ -324,3 +324,33 @@ def test_w_parameter_distributions(self) -> None: ) with self.assertRaisesRegex(UnsupportedError, "transform is not supported"): t.transform_search_space(rss) + + +class LogIntToFloatTransformTest(TestCase): + def test_log_int_to_float(self) -> None: + parameters = [ + RangeParameter("x", lower=1, upper=3, parameter_type=ParameterType.INT), + RangeParameter("y", lower=1, upper=50, parameter_type=ParameterType.INT), + RangeParameter( + "z", lower=1, upper=50, parameter_type=ParameterType.INT, log_scale=True + ), + ] + search_space = SearchSpace(parameters=parameters) + with self.assertRaisesRegex(UserInputError, "min_choices"): + LogIntToFloat(search_space=search_space, config={"min_choices": 5}) + t = LogIntToFloat(search_space=search_space) + self.assertFalse(hasattr(t, "min_choices")) + self.assertEqual(t.transform_parameters, {"z"}) + t_ss = t.transform_search_space(search_space) + self.assertEqual(t_ss.parameters["x"], parameters[0]) + self.assertEqual(t_ss.parameters["y"], parameters[1]) + self.assertEqual( + t_ss.parameters["z"], + RangeParameter( + name="z", + lower=0.50001, + upper=50.49999, + parameter_type=ParameterType.FLOAT, + log_scale=True, + ), + ) diff --git 
a/ax/storage/transform_registry.py b/ax/storage/transform_registry.py index 0c9518e96db..bac0e2fca26 100644 --- a/ax/storage/transform_registry.py +++ b/ax/storage/transform_registry.py @@ -18,7 +18,7 @@ from ax.modelbridge.transforms.derelativize import Derelativize from ax.modelbridge.transforms.fill_missing_parameters import FillMissingParameters from ax.modelbridge.transforms.int_range_to_choice import IntRangeToChoice -from ax.modelbridge.transforms.int_to_float import IntToFloat +from ax.modelbridge.transforms.int_to_float import IntToFloat, LogIntToFloat from ax.modelbridge.transforms.ivw import IVW from ax.modelbridge.transforms.log import Log from ax.modelbridge.transforms.log_y import LogY @@ -95,6 +95,7 @@ TimeAsFeature: 27, TransformToNewSQ: 28, FillMissingParameters: 29, + LogIntToFloat: 30, } """ From 78ac44189293dd84bda64dfe40e5516e7928b72f Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Tue, 26 Nov 2024 08:13:56 -0800 Subject: [PATCH 2/2] Support for default input transforms in MBM (#3102) Summary: Input normalization is important for getting the best performance out of the BoTorch models we use. The current setup relies on either using the `UnitX` transform from Ax or manually adding `Normalize` to `ModelConfig.input_transform_classes` to achieve input normalization. - `UnitX` is not ideal since it only applies to float-valued `RangeParameters`. If we make everything into floats to use `UnitX`, we're locked into using continuous relaxation for acquisition optimization, which is something we want to move away from. - `Normalize` works well, particularly when the `bounds` argument is provided (it's applied at each pass through the model rather than once to the training data, but that's a separate discussion). However, requiring it as an explicit user input is a bit cumbersome. This diff adds the machinery for constructing a default set of input transforms. This implementation retains the previous `InputPerturbation` transform for robust optimization, and adds a `Normalize` transform if the non-task features of the search space are not normalized. With this change, we should be able to remove the `UnitX` transform from an MBM model (spec) without losing input normalization. A minimal sketch of the default construction logic is shown below. Other considerations: - This setup only adds the default transforms if the `input_transform_classes` argument is left as `DEFAULT`. If the user supplies `input_transform_classes` or sets it to `None`, no defaults will be used. Would we want to add defaults even when the user supplies some transforms? If so, how would we decide whether to append or prepend the defaults? - As mentioned above, applying `Normalize` at each pass through the model is not super efficient. A vectorized application of an Ax transform should generally be more efficient. A longer-term alternative would be to expand the Ax-side `UnitX` to support more general parameter classes and types without losing information in the process. This would require additional changes, such as support for non-integer-valued discrete `RangeParameters` and support for non-integer discrete values in the mixed optimizer. 
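For illustration, here is a minimal standalone sketch of the default `Normalize` construction logic (it mirrors `_construct_default_input_transforms` in the diff below; the two-feature digest, its bounds, and the task feature index are hypothetical):

import torch
from botorch.models.transforms.input import Normalize

# Hypothetical digest: two features with bounds (1, 5) and (2, 10); feature 1 is a task feature.
bounds = torch.tensor([(1.0, 5.0), (2.0, 10.0)]).T  # shape 2 x d; row 0 = lower, row 1 = upper
task_features = [1]
indices = [i for i in range(bounds.shape[-1]) if i not in task_features]

# Only add Normalize if the non-task bounds are not already [0, 1].
if not (
    torch.allclose(bounds[0, indices], torch.zeros(len(indices)))
    and torch.allclose(bounds[1, indices], torch.ones(len(indices)))
):
    transform = Normalize(d=bounds.shape[-1], indices=indices, bounds=bounds)

Here only feature 0 gets rescaled to [0, 1], while the task feature is left untouched, matching the behavior exercised in the new `test__make_botorch_input_transform` test below.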
Differential Revision: D65622788 --- .../input_constructors/input_transforms.py | 47 +++--- ax/models/torch/botorch_modular/surrogate.py | 112 ++++++++++++--- ax/models/torch/botorch_modular/utils.py | 6 +- ax/models/torch/tests/test_model.py | 10 +- ax/models/torch/tests/test_surrogate.py | 136 ++++++++++++++---- 5 files changed, 232 insertions(+), 79 deletions(-) diff --git a/ax/models/torch/botorch_modular/input_constructors/input_transforms.py b/ax/models/torch/botorch_modular/input_constructors/input_transforms.py index d6afa6ad0b8..54fe9808a47 100644 --- a/ax/models/torch/botorch_modular/input_constructors/input_transforms.py +++ b/ax/models/torch/botorch_modular/input_constructors/input_transforms.py @@ -113,38 +113,31 @@ def _input_transform_argparse_normalize( A dictionary with input transform kwargs. """ input_transform_options = input_transform_options or {} - d = input_transform_options.get("d", len(dataset.feature_names)) - bounds = torch.as_tensor( - search_space_digest.bounds, - dtype=torch_dtype, - device=torch_device, - ).T + d = input_transform_options.get("d", len(dataset.feature_names)) if isinstance(dataset, RankingDataset) and isinstance(dataset.X, SliceContainer): d = dataset.X.values.shape[-1] - - indices = list(range(d)) - task_features = normalize_indices(search_space_digest.task_features, d=d) - - for task_feature in sorted(task_features, reverse=True): - del indices[task_feature] - input_transform_options.setdefault("d", d) - if ("indices" in input_transform_options) or (len(indices) < d): - input_transform_options.setdefault("indices", indices) - - if ( - ("bounds" not in input_transform_options) - and (bounds.shape[-1] < d) - and (len(search_space_digest.task_features) == 0) - ): - raise NotImplementedError( - "Normalize transform bounds should be specified explicitly if there" - " are task features outside the search space." - ) - - input_transform_options.setdefault("bounds", bounds) + if "indices" not in input_transform_options: + indices = list(range(d)) + task_features = normalize_indices(search_space_digest.task_features, d=d) + for task_feature in task_features: + indices.remove(task_feature) + input_transform_options["indices"] = indices + + if "bounds" not in input_transform_options: + bounds = torch.as_tensor( + search_space_digest.bounds, + dtype=torch_dtype, + device=torch_device, + ).T + if (bounds.shape[-1] < d) and (len(search_space_digest.task_features) == 0): + raise NotImplementedError( + "Normalize transform bounds should be specified explicitly if there " + "are task features outside the search space." 
+ ) + input_transform_options["bounds"] = bounds return input_transform_options diff --git a/ax/models/torch/botorch_modular/surrogate.py b/ax/models/torch/botorch_modular/surrogate.py index 63cade29e01..ccf2cc0e930 100644 --- a/ax/models/torch/botorch_modular/surrogate.py +++ b/ax/models/torch/botorch_modular/surrogate.py @@ -18,7 +18,6 @@ from typing import Any import numpy as np - import torch from ax.core.search_space import SearchSpaceDigest from ax.core.types import TCandidateMetadata @@ -43,6 +42,7 @@ ) from ax.models.torch.utils import ( _to_inequality_constraints, + normalize_indices, pick_best_out_of_sample_point_acqf_class, predict_from_model, ) @@ -70,12 +70,14 @@ ChainedInputTransform, InputPerturbation, InputTransform, + Normalize, ) from botorch.models.transforms.outcome import ChainedOutcomeTransform, OutcomeTransform from botorch.posteriors.gpytorch import GPyTorchPosterior from botorch.utils.containers import SliceContainer from botorch.utils.datasets import MultiTaskDataset, RankingDataset, SupervisedDataset from botorch.utils.dispatcher import Dispatcher +from botorch.utils.types import _DefaultType, DEFAULT from gpytorch.kernels import Kernel from gpytorch.likelihoods.likelihood import Likelihood from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood @@ -127,7 +129,7 @@ def _extract_model_kwargs( def _make_botorch_input_transform( - input_transform_classes: list[type[InputTransform]], + input_transform_classes: list[type[InputTransform]] | _DefaultType, input_transform_options: dict[str, dict[str, Any]], dataset: SupervisedDataset, search_space_digest: SearchSpaceDigest, @@ -135,15 +137,46 @@ def _make_botorch_input_transform( """ Makes a BoTorch input transform from the provided input classes and options. """ + if isinstance(input_transform_classes, _DefaultType): + transforms = _construct_default_input_transforms( + search_space_digest=search_space_digest, dataset=dataset + ) + else: + transforms = _construct_specified_input_transforms( + input_transform_classes=input_transform_classes, + dataset=dataset, + search_space_digest=search_space_digest, + input_transform_options=input_transform_options, + ) + if len(transforms) == 0: + return None + elif len(transforms) > 1: + return ChainedInputTransform( + **{f"tf{i}": transforms[i] for i in range(len(transforms))} + ) + else: + return transforms[0] + + +def _construct_specified_input_transforms( + input_transform_classes: list[type[InputTransform]], + input_transform_options: dict[str, dict[str, Any]], + dataset: SupervisedDataset, + search_space_digest: SearchSpaceDigest, +) -> list[InputTransform]: + """Constructs a list of input transforms from input transform classes and + options provided in ``ModelConfig``. + """ if not ( isinstance(input_transform_classes, list) and all(issubclass(c, InputTransform) for c in input_transform_classes) ): - raise UserInputError("Expected a list of input transforms.") + raise UserInputError( + "Expected a list of input transform classes. " + f"Got {input_transform_classes=}." + ) if search_space_digest.robust_digest is not None: input_transform_classes = [InputPerturbation] + input_transform_classes - if len(input_transform_classes) == 0: - return None input_transform_kwargs = [ input_transform_argparse( @@ -157,7 +190,7 @@ def _make_botorch_input_transform( for transform_class in input_transform_classes ] - input_transforms = [ + return [ # pyre-fixme[45]: Cannot instantiate abstract class `InputTransform`. 
transform_class(**single_input_transform_kwargs) for transform_class, single_input_transform_kwargs in zip( @@ -165,15 +198,47 @@ def _make_botorch_input_transform( ) ] - input_transform_instance = ( - ChainedInputTransform( - **{f"tf{i}": input_transforms[i] for i in range(len(input_transforms))} + +def _construct_default_input_transforms( + search_space_digest: SearchSpaceDigest, + dataset: SupervisedDataset, +) -> list[InputTransform]: + """Construct the default input transforms for the given search space digest. + + The default transforms are added in this order: + - If the search space digest has a robust digest, an ``InputPerturbation`` transform + is used. + - If the bounds for the non-task features are not [0, 1], a ``Normalize`` transform + is used. The transform only applies to the non-task features. + """ + transforms = [] + # Add InputPerturbation if there is a robust digest. + if search_space_digest.robust_digest is not None: + transforms.append( + InputPerturbation( + **input_transform_argparse( + InputPerturbation, + dataset=dataset, + search_space_digest=search_space_digest, + ) + ) ) - if len(input_transforms) > 1 - else input_transforms[0] - ) + # Processing for Normalize. + bounds = torch.tensor(search_space_digest.bounds).T + indices = list(range(bounds.shape[-1])) + # Remove task features. + for task_feature in normalize_indices( + search_space_digest.task_features, d=bounds.shape[-1] + ): + indices.remove(task_feature) + # Skip the Normalize transform if the bounds are [0, 1]. + if not ( + torch.allclose(bounds[0, indices], torch.zeros(len(indices))) + and torch.allclose(bounds[1, indices], torch.ones(len(indices))) + ): + transforms.append(Normalize(d=bounds.shape[-1], indices=indices, bounds=bounds)) - return input_transform_instance + return transforms def _make_botorch_outcome_transform( @@ -315,11 +380,15 @@ def _raise_deprecation_warning( msg += "Please specify {k} via `model_configs`." warnings_raised = False default_is_dict = {"botorch_model_kwargs", "mll_kwargs"} + default_is_default = {"input_transform_classes"} for k, v in kwargs.items(): should_raise = False if k in default_is_dict: if v != {}: should_raise = True + elif k in default_is_default: + if v != DEFAULT: + should_raise = True elif (v is not None and k != "mll_class") or ( k == "mll_class" and v is not ExactMarginalLogLikelihood ): @@ -341,7 +410,7 @@ def get_model_config_from_deprecated_args( mll_options: dict[str, Any] | None, outcome_transform_classes: list[type[OutcomeTransform]] | None, outcome_transform_options: dict[str, dict[str, Any]] | None, - input_transform_classes: list[type[InputTransform]] | None, + input_transform_classes: list[type[InputTransform]] | _DefaultType | None, input_transform_options: dict[str, dict[str, Any]] | None, covar_module_class: type[Kernel] | None, covar_module_options: dict[str, Any] | None, @@ -417,6 +486,9 @@ class string names and the values are dictionaries of outcome transform input_transform_classes: List of BoTorch input transforms classes. Passed down to the BoTorch ``Model``. Multiple input transforms will be chained together using ``ChainedInputTransform``. + If `DEFAULT`, a default set of input transforms may be constructed + based on the search space digest. To disable this behavior, pass + in `input_transform_classes=None`. This argument is deprecated in favor of model_configs. input_transform_options: Input transform classes kwargs. 
The keys are class string names and the values are dictionaries of input transform @@ -464,7 +536,7 @@ class string names and the values are dictionaries of input transform likelihood_class: type[Likelihood] | None = None likelihood_kwargs: dict[str, Any] | None = None - input_transform_classes: list[type[InputTransform]] | None = None + input_transform_classes: list[type[InputTransform]] | _DefaultType | None = DEFAULT input_transform_options: dict[str, dict[str, Any]] | None = None outcome_transform_classes: list[type[OutcomeTransform]] | None = None @@ -566,6 +638,9 @@ class string names and the values are dictionaries of outcome transform input_transform_classes: List of BoTorch input transforms classes. Passed down to the BoTorch ``Model``. Multiple input transforms will be chained together using ``ChainedInputTransform``. + If `DEFAULT`, a default set of input transforms may be constructed + based on the search space digest. To disable this behavior, pass + in `input_transform_classes=None`. This argument is deprecated in favor of model_configs. input_transform_options: Input transform classes kwargs. The keys are class string names and the values are dictionaries of input transform @@ -608,7 +683,9 @@ def __init__( mll_options: dict[str, Any] | None = None, outcome_transform_classes: list[type[OutcomeTransform]] | None = None, outcome_transform_options: dict[str, dict[str, Any]] | None = None, - input_transform_classes: list[type[InputTransform]] | None = None, + input_transform_classes: list[type[InputTransform]] + | _DefaultType + | None = DEFAULT, input_transform_options: dict[str, dict[str, Any]] | None = None, covar_module_class: type[Kernel] | None = None, covar_module_options: dict[str, Any] | None = None, @@ -1053,8 +1130,6 @@ def cross_validate( search_space_digest=search_space_digest, model_config=model_config, default_botorch_model_class=none_throws(default_botorch_model_class), - # pyre-fixme [6]: state_dict() has a generic dict[str, Any] return type - # but it is actually an OrderedDict[str, Tensor]. state_dict=state_dict, refit=self.refit_on_cv, ) @@ -1073,7 +1148,6 @@ def cross_validate( train_mask[i] = 1 # evaluate model fit metric diag_fn = DIAGNOSTIC_FNS[none_throws(self.surrogate_spec.eval_criterion)] - # pyre-ignore [28]: Unexpected keyword argument `y_obs` to anonymous call. return diag_fn( y_obs=Y.view(-1).numpy(), y_pred=pred_Y, diff --git a/ax/models/torch/botorch_modular/utils.py b/ax/models/torch/botorch_modular/utils.py index d40bbec9825..dbb9d32d247 100644 --- a/ax/models/torch/botorch_modular/utils.py +++ b/ax/models/torch/botorch_modular/utils.py @@ -39,6 +39,7 @@ from botorch.models.transforms.outcome import OutcomeTransform from botorch.utils.datasets import SupervisedDataset from botorch.utils.transforms import is_fully_bayesian +from botorch.utils.types import _DefaultType, DEFAULT from gpytorch.kernels.kernel import Kernel from gpytorch.likelihoods import Likelihood from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood @@ -81,6 +82,9 @@ class string names and the values are dictionaries of outcome transform input_transform_classes: List of BoTorch input transforms classes. Passed down to the BoTorch ``Model``. Multiple input transforms will be chained together using ``ChainedInputTransform``. + If `DEFAULT`, a default set of input transforms may be constructed + based on the search space digest. To disable this behavior, pass + in `input_transform_classes=None`. input_transform_options: Input transform classes kwargs. 
The keys are class string names and the values are dictionaries of input transform kwargs. For example, @@ -108,7 +112,7 @@ class string names and the values are dictionaries of input transform model_options: dict[str, Any] = field(default_factory=dict) mll_class: type[MarginalLogLikelihood] = ExactMarginalLogLikelihood mll_options: dict[str, Any] = field(default_factory=dict) - input_transform_classes: list[type[InputTransform]] | None = None + input_transform_classes: list[type[InputTransform]] | _DefaultType | None = DEFAULT input_transform_options: dict[str, dict[str, Any]] | None = field( default_factory=dict ) diff --git a/ax/models/torch/tests/test_model.py b/ax/models/torch/tests/test_model.py index bbea4c38c64..43e05439b75 100644 --- a/ax/models/torch/tests/test_model.py +++ b/ax/models/torch/tests/test_model.py @@ -54,6 +54,7 @@ from botorch.sampling.normal import SobolQMCNormalSampler from botorch.utils.constraints import get_outcome_constraint_transforms from botorch.utils.datasets import SupervisedDataset +from botorch.utils.types import DEFAULT from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood from pyre_extensions import none_throws @@ -335,7 +336,7 @@ def test_fit(self, mock_fit: Mock) -> None: model_options={}, mll_class=ExactMarginalLogLikelihood, mll_options={}, - input_transform_classes=None, + input_transform_classes=DEFAULT, input_transform_options={}, outcome_transform_classes=None, outcome_transform_options={}, @@ -635,7 +636,7 @@ def test_feature_importances(self) -> None: ) model.surrogate.fit( datasets=self.block_design_training_data, - search_space_digest=SearchSpaceDigest(feature_names=[], bounds=[]), + search_space_digest=self.search_space_digest, ) if botorch_model_class == SaasFullyBayesianSingleTaskGP: mcmc_samples = { @@ -741,10 +742,7 @@ def test_evaluate_acquisition_function( ) model.surrogate.fit( datasets=self.block_design_training_data, - search_space_digest=SearchSpaceDigest( - feature_names=[], - bounds=[], - ), + search_space_digest=self.search_space_digest, ) model.evaluate_acquisition_function( X=self.X_test, diff --git a/ax/models/torch/tests/test_surrogate.py b/ax/models/torch/tests/test_surrogate.py index 9050fbfb444..70d6e4f3cc6 100644 --- a/ax/models/torch/tests/test_surrogate.py +++ b/ax/models/torch/tests/test_surrogate.py @@ -21,7 +21,11 @@ from ax.models.torch.botorch_modular.acquisition import Acquisition from ax.models.torch.botorch_modular.kernels import ScaleMaternKernel from ax.models.torch.botorch_modular.surrogate import ( + _construct_default_input_transforms, + _construct_specified_input_transforms, _extract_model_kwargs, + _make_botorch_input_transform, + submodel_input_constructor, Surrogate, SurrogateSpec, ) @@ -46,10 +50,12 @@ from botorch.models.transforms.input import ( ChainedInputTransform, InputPerturbation, + Log10, Normalize, ) from botorch.models.transforms.outcome import OutcomeTransform, Standardize from botorch.utils.datasets import MultiTaskDataset, SupervisedDataset +from botorch.utils.types import DEFAULT from gpytorch.constraints import GreaterThan, Interval from gpytorch.kernels import Kernel, LinearKernel, MaternKernel, RBFKernel, ScaleKernel from gpytorch.likelihoods import FixedNoiseGaussianLikelihood, GaussianLikelihood @@ -72,7 +78,7 @@ def __init__(self, train_X: Tensor, train_Y: Tensor) -> None: super().__init__(train_X=train_X, train_Y=train_Y) -class ExtractModelKwargsTest(TestCase): 
+class SurrogateInputConstructorsTest(TestCase): def test__extract_model_kwargs(self) -> None: feature_names = ["a", "b"] bounds = [(0.0, 1.0), (0.0, 1.0)] @@ -142,6 +148,87 @@ def test__extract_model_kwargs(self) -> None: self.assertEqual(model_kwargs["fidelity_features"], [0]) self.assertEqual(model_kwargs["categorical_features"], [1]) + def test__make_botorch_input_transform(self) -> None: + feature_names = ["a", "b"] + bounds = [(0.0, 1.0), (0.0, 1.0)] + search_space_digest = SearchSpaceDigest( + feature_names=feature_names, + bounds=bounds, + ) + dataset = SupervisedDataset( + X=torch.tensor([[1.0, 2.0], [3.0, 4.0]]), + Y=torch.tensor([[1.0], [2.0]]), + feature_names=feature_names, + outcome_names=["metric"], + ) + + with self.subTest("Empty list of specified input transforms"): + with patch( + f"{SURROGATE_PATH}._construct_specified_input_transforms", + wraps=_construct_specified_input_transforms, + ) as mock_construct_specified_input_transforms: + transform = _make_botorch_input_transform( + input_transform_classes=[], + input_transform_options={}, + search_space_digest=search_space_digest, + dataset=dataset, + ) + mock_construct_specified_input_transforms.assert_called_once_with( + input_transform_classes=[], + input_transform_options={}, + search_space_digest=search_space_digest, + dataset=dataset, + ) + self.assertIsNone(transform) + + with self.subTest("Empty set of default transforms"): + with patch( + f"{SURROGATE_PATH}._construct_default_input_transforms", + wraps=_construct_default_input_transforms, + ) as mock_construct_default_input_transforms: + transform = _make_botorch_input_transform( + input_transform_classes=DEFAULT, + input_transform_options={}, + search_space_digest=search_space_digest, + dataset=dataset, + ) + mock_construct_default_input_transforms.assert_called_once_with( + search_space_digest=search_space_digest, + dataset=dataset, + ) + self.assertIsNone(transform) + + with self.subTest("Multiple specified transforms"): + transform = _make_botorch_input_transform( + input_transform_classes=[Normalize, Log10], + input_transform_options={"Log10": {"indices": [0]}}, + search_space_digest=search_space_digest, + dataset=dataset, + ) + transform = assert_is_instance(transform, ChainedInputTransform) + tf_values = list(transform.values()) + self.assertEqual(len(tf_values), 2) + self.assertIsInstance(tf_values[0], Normalize) + self.assertIsInstance(tf_values[1], Log10) + self.assertEqual(tf_values[1].indices.tolist(), [0]) + + bounds = [(1.0, 5.0), (2.0, 10.0)] + search_space_digest = SearchSpaceDigest( + feature_names=feature_names, + bounds=bounds, + task_features=[1], + ) + with self.subTest("Default Normalize transform"): + transform = _make_botorch_input_transform( + input_transform_classes=DEFAULT, + input_transform_options={}, + search_space_digest=search_space_digest, + dataset=dataset, + ) + transform = assert_is_instance(transform, Normalize) + self.assertEqual(transform.indices.tolist(), [0]) + self.assertEqual(transform.bounds.tolist(), [[1.0], [5.0]]) + class SurrogateTest(TestCase): def setUp(self) -> None: @@ -389,9 +476,12 @@ def test_dtype_and_device_properties(self) -> None: self.assertEqual(self.dtype, surrogate.dtype) self.assertEqual(self.device, surrogate.device) - @patch.object(SingleTaskGP, "__init__", return_value=None) + @patch( + f"{SURROGATE_PATH}.submodel_input_constructor", + wraps=submodel_input_constructor, + ) @patch(f"{SURROGATE_PATH}.fit_botorch_model") - def test_fit_model_reuse(self, mock_fit: Mock, mock_init: Mock) -> None: + def 
test_fit_model_reuse(self, mock_fit: Mock, mock_constructor: Mock) -> None: surrogate, _ = self._get_surrogate( botorch_model_class=SingleTaskGP, use_outcome_transform=False ) @@ -404,7 +494,7 @@ def test_fit_model_reuse(self, mock_fit: Mock, mock_init: Mock) -> None: search_space_digest=search_space_digest, ) mock_fit.assert_called_once() - mock_init.assert_called_once() + mock_constructor.assert_called_once() key = tuple(self.training_data[0].outcome_names) submodel = surrogate._submodels[key] self.assertIs(surrogate._last_datasets[key], self.training_data[0]) @@ -417,7 +507,7 @@ def test_fit_model_reuse(self, mock_fit: Mock, mock_init: Mock) -> None: ) # Still only called once -- i.e. not fitted again: mock_fit.assert_called_once() - mock_init.assert_called_once() + mock_constructor.assert_called_once() # Model is still the same object. self.assertIs(submodel, surrogate._submodels[key]) @@ -448,23 +538,12 @@ def test_construct_model(self) -> None: surrogate, _ = self._get_surrogate( botorch_model_class=botorch_model_class, use_outcome_transform=False ) - with self.assertRaisesRegex(TypeError, "posterior"): - # Base `Model` does not implement `posterior`, so instantiating it here - # will fail. - Surrogate()._construct_model( - dataset=self.training_data[0], - search_space_digest=self.search_space_digest, - model_config=ModelConfig(), - default_botorch_model_class=Model, - state_dict=None, - refit=True, - ) with patch.object( botorch_model_class, "construct_inputs", wraps=botorch_model_class.construct_inputs, ) as mock_construct_inputs, patch.object( - botorch_model_class, "__init__", return_value=None + botorch_model_class, "__init__", return_value=None, autospec=True ) as mock_init, patch(f"{SURROGATE_PATH}.fit_botorch_model") as mock_fit: model = surrogate._construct_model( dataset=self.training_data[0], @@ -479,7 +558,9 @@ def test_construct_model(self) -> None: call_kwargs = mock_init.call_args.kwargs self.assertTrue(torch.equal(call_kwargs["train_X"], self.Xs[0])) self.assertTrue(torch.equal(call_kwargs["train_Y"], self.Ys[0])) - self.assertEqual(len(call_kwargs), 2) + self.assertIsInstance(call_kwargs["input_transform"], Normalize) + self.assertIsNone(call_kwargs["outcome_transform"]) + self.assertEqual(len(call_kwargs), 4) mock_construct_inputs.assert_called_with( training_data=self.training_data[0], @@ -908,7 +989,6 @@ def test_fit_model_selection_metric_to_model_configs_multiple_metrics( call_kwargs = mock_model_selection.mock_calls[i].kwargs for k, v in expected_model_selection_kwargs.items(): self.assertEqual(call_kwargs[k], v) - # pyre-ignore[6] expected_cross_validate_kwargs["dataset"] = training_data[i] # check that each call to cross_validate uses the correct # model config. @@ -1122,9 +1202,13 @@ def test_w_robust_digest(self) -> None: robust_digest=robust_digest, ), ) - intf = checked_cast(InputPerturbation, surrogate.model.input_transform) - self.assertIsInstance(intf, InputPerturbation) - self.assertTrue(torch.equal(intf.perturbation_set, torch.zeros(2, 2))) + intf = assert_is_instance( + surrogate.model.input_transform, ChainedInputTransform + ) + intf_values = list(intf.values()) + self.assertIsInstance(intf_values[0], InputPerturbation) + self.assertIsInstance(intf_values[1], Normalize) + self.assertTrue(torch.equal(intf_values[0].perturbation_set, torch.zeros(2, 2))) def test_fit_mixed(self) -> None: # Test model construction with categorical variables. 
@@ -1593,7 +1677,10 @@ def test_construct_custom_model(self) -> None: @mock_botorch_optimize def test_w_robust_digest(self) -> None: surrogate = Surrogate( - botorch_model_class=SingleTaskGP, + surrogate_spec=SurrogateSpec( + botorch_model_class=SingleTaskGP, + input_transform_classes=[], + ) ) # Error handling. with self.assertRaisesRegex(NotImplementedError, "Environmental variable"): @@ -1615,9 +1702,6 @@ def test_w_robust_digest(self) -> None: multiplicative=False, ) # Input perturbation is constructed. - surrogate = Surrogate( - botorch_model_class=SingleTaskGP, - ) surrogate.fit( datasets=self.supervised_training_data, search_space_digest=SearchSpaceDigest(