Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for default input transforms in MBM #3102

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -113,38 +113,31 @@ def _input_transform_argparse_normalize(
A dictionary with input transform kwargs.
"""
input_transform_options = input_transform_options or {}
d = input_transform_options.get("d", len(dataset.feature_names))
bounds = torch.as_tensor(
search_space_digest.bounds,
dtype=torch_dtype,
device=torch_device,
).T

d = input_transform_options.get("d", len(dataset.feature_names))
if isinstance(dataset, RankingDataset) and isinstance(dataset.X, SliceContainer):
d = dataset.X.values.shape[-1]

indices = list(range(d))
task_features = normalize_indices(search_space_digest.task_features, d=d)

for task_feature in sorted(task_features, reverse=True):
del indices[task_feature]

input_transform_options.setdefault("d", d)

if ("indices" in input_transform_options) or (len(indices) < d):
input_transform_options.setdefault("indices", indices)

if (
("bounds" not in input_transform_options)
and (bounds.shape[-1] < d)
and (len(search_space_digest.task_features) == 0)
):
raise NotImplementedError(
"Normalize transform bounds should be specified explicitly if there"
" are task features outside the search space."
)

input_transform_options.setdefault("bounds", bounds)
if "indices" not in input_transform_options:
indices = list(range(d))
task_features = normalize_indices(search_space_digest.task_features, d=d)
for task_feature in task_features:
indices.remove(task_feature)
input_transform_options["indices"] = indices

if "bounds" not in input_transform_options:
bounds = torch.as_tensor(
search_space_digest.bounds,
dtype=torch_dtype,
device=torch_device,
).T
if (bounds.shape[-1] < d) and (len(search_space_digest.task_features) == 0):
raise NotImplementedError(
"Normalize transform bounds should be specified explicitly if there "
"are task features outside the search space."
)
input_transform_options["bounds"] = bounds

return input_transform_options

Expand Down
91 changes: 78 additions & 13 deletions ax/models/torch/botorch_modular/surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from typing import Any

import numpy as np

import torch
from ax.core.search_space import SearchSpaceDigest
from ax.core.types import TCandidateMetadata
Expand All @@ -43,6 +42,7 @@
)
from ax.models.torch.utils import (
_to_inequality_constraints,
normalize_indices,
pick_best_out_of_sample_point_acqf_class,
predict_from_model,
)
Expand Down Expand Up @@ -70,12 +70,14 @@
ChainedInputTransform,
InputPerturbation,
InputTransform,
Normalize,
)
from botorch.models.transforms.outcome import ChainedOutcomeTransform, OutcomeTransform
from botorch.posteriors.gpytorch import GPyTorchPosterior
from botorch.utils.containers import SliceContainer
from botorch.utils.datasets import MultiTaskDataset, RankingDataset, SupervisedDataset
from botorch.utils.dispatcher import Dispatcher
from botorch.utils.types import _DefaultType
from gpytorch.kernels import Kernel
from gpytorch.likelihoods.likelihood import Likelihood
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
Expand Down Expand Up @@ -127,23 +129,54 @@ def _extract_model_kwargs(


def _make_botorch_input_transform(
input_transform_classes: list[type[InputTransform]],
input_transform_classes: list[type[InputTransform]] | _DefaultType,
input_transform_options: dict[str, dict[str, Any]],
dataset: SupervisedDataset,
search_space_digest: SearchSpaceDigest,
) -> InputTransform | None:
"""
Makes a BoTorch input transform from the provided input classes and options.
"""
if isinstance(input_transform_classes, _DefaultType):
transforms = _construct_default_input_transforms(
search_space_digest=search_space_digest, dataset=dataset
)
else:
transforms = _construct_specified_input_transforms(
input_transform_classes=input_transform_classes,
dataset=dataset,
search_space_digest=search_space_digest,
input_transform_options=input_transform_options,
)
if len(transforms) == 0:
return None
elif len(transforms) > 1:
return ChainedInputTransform(
**{f"tf{i}": transforms[i] for i in range(len(transforms))}
)
else:
return transforms[0]


def _construct_specified_input_transforms(
input_transform_classes: list[type[InputTransform]],
input_transform_options: dict[str, dict[str, Any]],
dataset: SupervisedDataset,
search_space_digest: SearchSpaceDigest,
) -> list[InputTransform]:
"""Constructs a list of input transforms from input transform classes and
options provided in ``ModelConfig``.
"""
if not (
isinstance(input_transform_classes, list)
and all(issubclass(c, InputTransform) for c in input_transform_classes)
):
raise UserInputError("Expected a list of input transforms.")
raise UserInputError(
"Expected a list of input transform classes. "
f"Got {input_transform_classes=}."
)
if search_space_digest.robust_digest is not None:
input_transform_classes = [InputPerturbation] + input_transform_classes
if len(input_transform_classes) == 0:
return None

input_transform_kwargs = [
input_transform_argparse(
Expand All @@ -157,23 +190,55 @@ def _make_botorch_input_transform(
for transform_class in input_transform_classes
]

input_transforms = [
return [
# pyre-fixme[45]: Cannot instantiate abstract class `InputTransform`.
transform_class(**single_input_transform_kwargs)
for transform_class, single_input_transform_kwargs in zip(
input_transform_classes, input_transform_kwargs
)
]

input_transform_instance = (
ChainedInputTransform(
**{f"tf{i}": input_transforms[i] for i in range(len(input_transforms))}

def _construct_default_input_transforms(
search_space_digest: SearchSpaceDigest,
dataset: SupervisedDataset,
) -> list[InputTransform]:
"""Construct the default input transforms for the given search space digest.

The default transforms are added in this order:
- If the search space digest has a robust digest, an ``InputPerturbation`` transform
is used.
- If the bounds for the non-task features are not [0, 1], a ``Normalize`` transform
is used. The transform only applies to the non-task features.
"""
transforms = []
# Add InputPerturbation if there is a robust digest.
if search_space_digest.robust_digest is not None:
transforms.append(
InputPerturbation(
**input_transform_argparse(
InputPerturbation,
dataset=dataset,
search_space_digest=search_space_digest,
)
)
)
if len(input_transforms) > 1
else input_transforms[0]
)
# Processing for Normalize.
bounds = torch.tensor(search_space_digest.bounds).T
indices = list(range(bounds.shape[-1]))
# Remove task features.
for task_feature in normalize_indices(
search_space_digest.task_features, d=bounds.shape[-1]
):
indices.remove(task_feature)
# Skip the Normalize transform if the bounds are [0, 1].
if not (
torch.allclose(bounds[0, indices], torch.zeros(len(indices)))
and torch.allclose(bounds[1, indices], torch.ones(len(indices)))
):
transforms.append(Normalize(d=bounds.shape[-1], indices=indices, bounds=bounds))

return input_transform_instance
return transforms


def _make_botorch_outcome_transform(
Expand Down
3 changes: 2 additions & 1 deletion ax/models/torch/botorch_modular/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from botorch.models.transforms.outcome import OutcomeTransform
from botorch.utils.datasets import SupervisedDataset
from botorch.utils.transforms import is_fully_bayesian
from botorch.utils.types import _DefaultType, DEFAULT
from gpytorch.kernels.kernel import Kernel
from gpytorch.likelihoods import Likelihood
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
Expand Down Expand Up @@ -108,7 +109,7 @@ class string names and the values are dictionaries of input transform
model_options: dict[str, Any] = field(default_factory=dict)
mll_class: type[MarginalLogLikelihood] = ExactMarginalLogLikelihood
mll_options: dict[str, Any] = field(default_factory=dict)
input_transform_classes: list[type[InputTransform]] | None = None
input_transform_classes: list[type[InputTransform]] | _DefaultType | None = DEFAULT
input_transform_options: dict[str, dict[str, Any]] | None = field(
default_factory=dict
)
Expand Down
10 changes: 4 additions & 6 deletions ax/models/torch/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from botorch.sampling.normal import SobolQMCNormalSampler
from botorch.utils.constraints import get_outcome_constraint_transforms
from botorch.utils.datasets import SupervisedDataset
from botorch.utils.types import DEFAULT
from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
from pyre_extensions import none_throws
Expand Down Expand Up @@ -335,7 +336,7 @@ def test_fit(self, mock_fit: Mock) -> None:
model_options={},
mll_class=ExactMarginalLogLikelihood,
mll_options={},
input_transform_classes=None,
input_transform_classes=DEFAULT,
input_transform_options={},
outcome_transform_classes=None,
outcome_transform_options={},
Expand Down Expand Up @@ -635,7 +636,7 @@ def test_feature_importances(self) -> None:
)
model.surrogate.fit(
datasets=self.block_design_training_data,
search_space_digest=SearchSpaceDigest(feature_names=[], bounds=[]),
search_space_digest=self.search_space_digest,
)
if botorch_model_class == SaasFullyBayesianSingleTaskGP:
mcmc_samples = {
Expand Down Expand Up @@ -741,10 +742,7 @@ def test_evaluate_acquisition_function(
)
model.surrogate.fit(
datasets=self.block_design_training_data,
search_space_digest=SearchSpaceDigest(
feature_names=[],
bounds=[],
),
search_space_digest=self.search_space_digest,
)
model.evaluate_acquisition_function(
X=self.X_test,
Expand Down
46 changes: 22 additions & 24 deletions ax/models/torch/tests/test_surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from ax.models.torch.botorch_modular.kernels import ScaleMaternKernel
from ax.models.torch.botorch_modular.surrogate import (
_extract_model_kwargs,
submodel_input_constructor,
Surrogate,
SurrogateSpec,
)
Expand Down Expand Up @@ -389,9 +390,12 @@ def test_dtype_and_device_properties(self) -> None:
self.assertEqual(self.dtype, surrogate.dtype)
self.assertEqual(self.device, surrogate.device)

@patch.object(SingleTaskGP, "__init__", return_value=None)
@patch(
f"{SURROGATE_PATH}.submodel_input_constructor",
wraps=submodel_input_constructor,
)
@patch(f"{SURROGATE_PATH}.fit_botorch_model")
def test_fit_model_reuse(self, mock_fit: Mock, mock_init: Mock) -> None:
def test_fit_model_reuse(self, mock_fit: Mock, mock_constructor: Mock) -> None:
surrogate, _ = self._get_surrogate(
botorch_model_class=SingleTaskGP, use_outcome_transform=False
)
Expand All @@ -404,7 +408,7 @@ def test_fit_model_reuse(self, mock_fit: Mock, mock_init: Mock) -> None:
search_space_digest=search_space_digest,
)
mock_fit.assert_called_once()
mock_init.assert_called_once()
mock_constructor.assert_called_once()
key = tuple(self.training_data[0].outcome_names)
submodel = surrogate._submodels[key]
self.assertIs(surrogate._last_datasets[key], self.training_data[0])
Expand All @@ -417,7 +421,7 @@ def test_fit_model_reuse(self, mock_fit: Mock, mock_init: Mock) -> None:
)
# Still only called once -- i.e. not fitted again:
mock_fit.assert_called_once()
mock_init.assert_called_once()
mock_constructor.assert_called_once()
# Model is still the same object.
self.assertIs(submodel, surrogate._submodels[key])

Expand Down Expand Up @@ -448,23 +452,12 @@ def test_construct_model(self) -> None:
surrogate, _ = self._get_surrogate(
botorch_model_class=botorch_model_class, use_outcome_transform=False
)
with self.assertRaisesRegex(TypeError, "posterior"):
# Base `Model` does not implement `posterior`, so instantiating it here
# will fail.
Surrogate()._construct_model(
dataset=self.training_data[0],
search_space_digest=self.search_space_digest,
model_config=ModelConfig(),
default_botorch_model_class=Model,
state_dict=None,
refit=True,
)
with patch.object(
botorch_model_class,
"construct_inputs",
wraps=botorch_model_class.construct_inputs,
) as mock_construct_inputs, patch.object(
botorch_model_class, "__init__", return_value=None
botorch_model_class, "__init__", return_value=None, autospec=True
) as mock_init, patch(f"{SURROGATE_PATH}.fit_botorch_model") as mock_fit:
model = surrogate._construct_model(
dataset=self.training_data[0],
Expand All @@ -479,7 +472,9 @@ def test_construct_model(self) -> None:
call_kwargs = mock_init.call_args.kwargs
self.assertTrue(torch.equal(call_kwargs["train_X"], self.Xs[0]))
self.assertTrue(torch.equal(call_kwargs["train_Y"], self.Ys[0]))
self.assertEqual(len(call_kwargs), 2)
self.assertIsInstance(call_kwargs["input_transform"], Normalize)
self.assertIsNone(call_kwargs["outcome_transform"])
self.assertEqual(len(call_kwargs), 4)

mock_construct_inputs.assert_called_with(
training_data=self.training_data[0],
Expand Down Expand Up @@ -1120,9 +1115,12 @@ def test_w_robust_digest(self) -> None:
robust_digest=robust_digest,
),
)
intf = checked_cast(InputPerturbation, surrogate.model.input_transform)
self.assertIsInstance(intf, InputPerturbation)
self.assertTrue(torch.equal(intf.perturbation_set, torch.zeros(2, 2)))
intf = surrogate.model.input_transform
self.assertIsInstance(intf, ChainedInputTransform)
intf_values = list(intf.values())
self.assertIsInstance(intf_values[0], InputPerturbation)
self.assertIsInstance(intf_values[1], Normalize)
self.assertTrue(torch.equal(intf_values[0].perturbation_set, torch.zeros(2, 2)))

def test_fit_mixed(self) -> None:
# Test model construction with categorical variables.
Expand Down Expand Up @@ -1591,7 +1589,10 @@ def test_construct_custom_model(self) -> None:
@mock_botorch_optimize
def test_w_robust_digest(self) -> None:
surrogate = Surrogate(
botorch_model_class=SingleTaskGP,
surrogate_spec=SurrogateSpec(
botorch_model_class=SingleTaskGP,
input_transform_classes=[],
)
)
# Error handling.
with self.assertRaisesRegex(NotImplementedError, "Environmental variable"):
Expand All @@ -1613,9 +1614,6 @@ def test_w_robust_digest(self) -> None:
multiplicative=False,
)
# Input perturbation is constructed.
surrogate = Surrogate(
botorch_model_class=SingleTaskGP,
)
surrogate.fit(
datasets=self.supervised_training_data,
search_space_digest=SearchSpaceDigest(
Expand Down