Skip to content

Commit

Permalink
one_hot_encoded_space
Browse files Browse the repository at this point in the history
  • Loading branch information
Uri Granta committed Aug 16, 2024
1 parent 09c727e commit d39679b
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 5 deletions.
8 changes: 3 additions & 5 deletions tests/integration/test_mixed_space_bayesian_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

from typing import cast

import gpflow
import numpy as np
import numpy.testing as npt
import pytest
Expand Down Expand Up @@ -50,6 +49,7 @@
CategoricalSearchSpace,
DiscreteSearchSpace,
TaggedProductSearchSpace,
one_hot_encoded_space,
one_hot_encoder,
)
from trieste.types import TensorType
Expand Down Expand Up @@ -260,11 +260,9 @@ def test_optimizer_finds_minima_of_the_categorical_scaled_branin_function(

# model uses one-hot encoding for the categorical inputs
encoder = one_hot_encoder(problem.search_space)
kernel = gpflow.kernels.Matern52(
variance=tf.math.reduce_variance(initial_data.observations), lengthscales=0.1
)
encoded_space = one_hot_encoded_space(problem.search_space)
model = GaussianProcessRegression(
build_gpr(encode_dataset(initial_data, encoder), kernel=kernel, likelihood_variance=1e-8),
build_gpr(encode_dataset(initial_data, encoder), encoded_space, likelihood_variance=1e-8),
encoder=encoder,
)

Expand Down
14 changes: 14 additions & 0 deletions trieste/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,20 @@ def one_hot_encoder(space: SearchSpace) -> EncoderFunction:
return space.one_hot_encoder if isinstance(space, HasOneHotEncoder) else lambda x: x


def one_hot_encoded_space(space: SearchSpace) -> SearchSpace:
"A bounded search space corresponding to the one-hot encoding of the given space."

if isinstance(space, GeneralDiscreteSearchSpace) and isinstance(space, HasOneHotEncoder):
return DiscreteSearchSpace(space.one_hot_encoder(space.points))
elif isinstance(space, TaggedProductSearchSpace):
spaces = [one_hot_encoded_space(space.get_subspace(tag)) for tag in space.subspace_tags]
return TaggedProductSearchSpace(spaces=spaces, tags=space.subspace_tags)
elif isinstance(space, HasOneHotEncoder):
raise NotImplementedError(f"Unsupported one-hot-encoded space {type(space)}")
else:
return space


class CategoricalSearchSpace(GeneralDiscreteSearchSpace, HasOneHotEncoder):
r"""
A categorical :class:`SearchSpace` representing a finite set :math:`\mathcal{C}` of categories,
Expand Down

0 comments on commit d39679b

Please sign in to comment.