diff --git a/tests/integration/test_mixed_space_bayesian_optimization.py b/tests/integration/test_mixed_space_bayesian_optimization.py index bc4a7997c..16a9eb78f 100644 --- a/tests/integration/test_mixed_space_bayesian_optimization.py +++ b/tests/integration/test_mixed_space_bayesian_optimization.py @@ -15,7 +15,6 @@ from typing import cast -import gpflow import numpy as np import numpy.testing as npt import pytest @@ -50,6 +49,7 @@ CategoricalSearchSpace, DiscreteSearchSpace, TaggedProductSearchSpace, + one_hot_encoded_space, one_hot_encoder, ) from trieste.types import TensorType @@ -260,11 +260,9 @@ def test_optimizer_finds_minima_of_the_categorical_scaled_branin_function( # model uses one-hot encoding for the categorical inputs encoder = one_hot_encoder(problem.search_space) - kernel = gpflow.kernels.Matern52( - variance=tf.math.reduce_variance(initial_data.observations), lengthscales=0.1 - ) + encoded_space = one_hot_encoded_space(problem.search_space) model = GaussianProcessRegression( - build_gpr(encode_dataset(initial_data, encoder), kernel=kernel, likelihood_variance=1e-8), + build_gpr(encode_dataset(initial_data, encoder), encoded_space, likelihood_variance=1e-8), encoder=encoder, ) diff --git a/trieste/space.py b/trieste/space.py index 2ab3d1166..4a228460c 100644 --- a/trieste/space.py +++ b/trieste/space.py @@ -518,6 +518,20 @@ def one_hot_encoder(space: SearchSpace) -> EncoderFunction: return space.one_hot_encoder if isinstance(space, HasOneHotEncoder) else lambda x: x +def one_hot_encoded_space(space: SearchSpace) -> SearchSpace: + "A bounded search space corresponding to the one-hot encoding of the given space." + + if isinstance(space, GeneralDiscreteSearchSpace) and isinstance(space, HasOneHotEncoder): + return DiscreteSearchSpace(space.one_hot_encoder(space.points)) + elif isinstance(space, TaggedProductSearchSpace): + spaces = [one_hot_encoded_space(space.get_subspace(tag)) for tag in space.subspace_tags] + return TaggedProductSearchSpace(spaces=spaces, tags=space.subspace_tags) + elif isinstance(space, HasOneHotEncoder): + raise NotImplementedError(f"Unsupported one-hot-encoded space {type(space)}") + else: + return space + + class CategoricalSearchSpace(GeneralDiscreteSearchSpace, HasOneHotEncoder): r""" A categorical :class:`SearchSpace` representing a finite set :math:`\mathcal{C}` of categories,