Support multiple models/datasets in region selects
khurram-ghani committed Oct 12, 2023
1 parent ee1ee56 commit 7d03a04
Showing 2 changed files with 115 additions and 69 deletions.
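For readers skimming the diff: the core of this change is that each trust region now prefers a model or dataset tagged for its own region index and falls back to the corresponding global entry, while a region with no index keeps only the global entries. The following standalone sketch uses simplified stand-ins (a minimal `LocalTag` dataclass and string tags, not trieste's actual types); only the selection behaviour mirrors the diff below.

```python
# Simplified sketch of the tag-preference rule in this commit.
# `LocalTag`, `split_tags` and `select_for_region` are hypothetical stand-ins.
from dataclasses import dataclass
from typing import Dict, Optional, Set, Tuple, Union


@dataclass(frozen=True)
class LocalTag:
    global_tag: str
    local_index: int


Tag = Union[str, LocalTag]


def split_tags(tags: Set[Tag], index: int) -> Tuple[Set[str], Set[str]]:
    # Global tags that have a local counterpart for this region are dropped,
    # so the local entry takes precedence.
    local_gtags = {
        t.global_tag for t in tags if isinstance(t, LocalTag) and t.local_index == index
    }
    global_tags = {t for t in tags if isinstance(t, str)}
    return local_gtags, global_tags - local_gtags


def select_for_region(mapping: Dict[Tag, str], index: Optional[int]) -> Dict[Tag, str]:
    if index is None:
        # No region index: keep only the global entries.
        return {t: v for t, v in mapping.items() if isinstance(t, str)}
    local_gtags, global_tags = split_tags(set(mapping), index)
    selected: Dict[Tag, str] = {
        LocalTag(t, index): mapping[LocalTag(t, index)] for t in local_gtags
    }
    selected.update({t: mapping[t] for t in global_tags})
    return selected


models = {"OBJECTIVE": "global model", LocalTag("OBJECTIVE", 0): "model for region 0"}
print(select_for_region(models, index=0))     # local model wins for region 0
print(select_for_region(models, index=1))     # region 1 falls back to the global model
print(select_for_region(models, index=None))  # no index: global entries only
```

Dropping a global tag whenever a matching local tag exists is what lets the same rule run either with one shared global model or with one model per region.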
34 changes: 25 additions & 9 deletions tests/unit/acquisition/test_rule.py
@@ -545,16 +545,23 @@ def test_async_keeps_track_of_pending_points(
npt.assert_allclose(state.pending_points, tf.concat([point2, point3], axis=0))


@pytest.mark.parametrize("datasets", [{}, {"foo": empty_dataset([1], [1])}])
@pytest.mark.parametrize(
"datasets",
[
{},
{"foo": empty_dataset([1], [1])},
{OBJECTIVE: empty_dataset([1], [1]), "foo": empty_dataset([1], [1])},
],
)
@pytest.mark.parametrize(
"models", [{}, {"foo": QuadraticMeanAndRBFKernel()}, {OBJECTIVE: QuadraticMeanAndRBFKernel()}]
)
def test_trego_raises_for_missing_datasets_key(
datasets: dict[Tag, Dataset], models: dict[Tag, ProbabilisticModel]
datasets: Mapping[Tag, Dataset], models: dict[Tag, ProbabilisticModel]
) -> None:
search_space = Box([-1], [1])
rule = BatchTrustRegionBox(TREGOBox(search_space)) # type: ignore[var-annotated]
with pytest.raises(ValueError, match="none of the tags '.LocalTag.OBJECTIVE, 0., "):
with pytest.raises(ValueError, match="a single OBJECTIVE dataset must be provided"):
rule.acquire(search_space, models, datasets=datasets)(None)


@@ -1195,12 +1202,21 @@ def test_turbo_state_deepcopy() -> None:
npt.assert_allclose(tr_state_copy.y_min, tr_state.y_min)


# get_dataset_min raises if dataset is None.
def test_trust_region_box_get_dataset_min_raises_if_dataset_is_none() -> None:
@pytest.mark.parametrize(
"datasets",
[
{},
{"foo": empty_dataset([1], [1])},
{OBJECTIVE: empty_dataset([1], [1]), "foo": empty_dataset([1], [1])},
],
)
def test_trust_region_box_get_dataset_min_raises_if_dataset_is_faulty(
datasets: Mapping[Tag, Dataset]
) -> None:
search_space = Box([0.0, 0.0], [1.0, 1.0])
trb = SingleObjectiveTrustRegionBox(search_space)
with pytest.raises(ValueError, match="dataset must be provided"):
trb.get_dataset_min(None)
with pytest.raises(ValueError, match="a single OBJECTIVE dataset must be provided"):
trb.get_dataset_min(datasets)


# get_dataset_min picks the minimum x and y values from the dataset.
@@ -1213,7 +1229,7 @@ def test_trust_region_box_get_dataset_min() -> None:
trb = SingleObjectiveTrustRegionBox(search_space)
trb._lower = tf.constant([0.2, 0.2], dtype=tf.float64)
trb._upper = tf.constant([0.7, 0.7], dtype=tf.float64)
x_min, y_min = trb.get_dataset_min({"foo": dataset})
x_min, y_min = trb.get_dataset_min({OBJECTIVE: dataset})
npt.assert_array_equal(x_min, tf.constant([0.3, 0.4], dtype=tf.float64))
npt.assert_array_equal(y_min, tf.constant([0.2], dtype=tf.float64))

@@ -1227,7 +1243,7 @@ def test_trust_region_box_get_dataset_min_outside_search_space() -> None:
tf.constant([[0.7], [0.9]], dtype=tf.float64),
)
trb = SingleObjectiveTrustRegionBox(search_space)
x_min, y_min = trb.get_dataset_min({"foo": dataset})
x_min, y_min = trb.get_dataset_min({OBJECTIVE: dataset})
npt.assert_array_equal(x_min, tf.constant([1.2, 1.3], dtype=tf.float64))
npt.assert_array_equal(y_min, tf.constant([np.inf], dtype=tf.float64))

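The new parametrised failure cases above exercise the stricter validation added to `get_dataset_min` in `rule.py` below: the mapping must contain exactly one dataset whose tag resolves to the OBJECTIVE tag (either the global tag or a local variant of it). A minimal stand-alone predicate capturing that rule, with hypothetical names and a string-based stand-in for local tags, might look like this:

```python
# Hypothetical predicate mirroring the new get_dataset_min check; strings of
# the form "OBJECTIVE__0" stand in for trieste's LocalTag instances.
from typing import Mapping, Optional

OBJECTIVE = "OBJECTIVE"


def global_part(tag: str) -> str:
    # A local tag such as "OBJECTIVE__0" resolves to its global part.
    return tag.split("__", 1)[0]


def is_single_objective_mapping(datasets: Optional[Mapping[str, object]]) -> bool:
    return (
        datasets is not None
        and len(datasets) == 1
        and global_part(next(iter(datasets))) == OBJECTIVE
    )


assert not is_single_objective_mapping(None)
assert not is_single_objective_mapping({})
assert not is_single_objective_mapping({"foo": object()})
assert not is_single_objective_mapping({OBJECTIVE: object(), "foo": object()})
assert is_single_objective_mapping({OBJECTIVE: object()})
assert is_single_objective_mapping({"OBJECTIVE__3": object()})
```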
150 changes: 90 additions & 60 deletions trieste/acquisition/rule.py
@@ -30,6 +30,7 @@
Generic,
Optional,
Sequence,
Set,
Tuple,
TypeVar,
Union,
@@ -1018,6 +1019,23 @@ def update(
"""
...

def _get_tags(self, tags: Set[Tag]) -> Tuple[Set[Tag], Set[Tag]]:
# Separate tags into local (matching index) and global tags (without matching
# local tag).
local_gtags = set()
global_tags = set()
for tag in tags:
ltag = LocalTag.from_tag(tag)
if not ltag.is_local:
global_tags.add(tag)
elif ltag.local_index == self.index:
local_gtags.add(ltag.global_tag)

# Only keep global tags that don't have a matching local tag.
global_tags = global_tags.difference(local_gtags)

return local_gtags, global_tags

def select_models(
self, models: Optional[Mapping[Tag, ProbabilisticModelType]]
) -> Optional[Mapping[Tag, ProbabilisticModelType]]:
@@ -1027,8 +1045,25 @@ def select_models(
:param models: The model for each tag.
:return: The models belonging to this region.
"""
# By default return all the models.
return models
if models is None:
_models = {}
elif self.index is None:
# If no index, then return the global models.
_models = {
tag: model for tag, model in models.items() if not LocalTag.from_tag(tag).is_local
}
else:
# Prefer matching local model for each tag, otherwise select the global model.
local_gtags, global_tags = self._get_tags(set(models))

_models = {}
for tag in local_gtags:
ltag = LocalTag(tag, self.index)
_models[ltag] = models[ltag]
for tag in global_tags:
_models[tag] = models[tag]

return _models if _models else None

def select_datasets(
self, datasets: Optional[Mapping[Tag, Dataset]]
@@ -1039,8 +1074,27 @@ def select_datasets(
:param datasets: The dataset for each tag.
:return: The datasets belonging to this region.
"""
# By default return all the datasets.
return datasets
if datasets is None:
_datasets = {}
elif self.index is None:
# If no index, then return the global datasets.
_datasets = {
tag: dataset
for tag, dataset in datasets.items()
if not LocalTag.from_tag(tag).is_local
}
else:
# Prefer matching local dataset for each tag, otherwise select the global dataset.
local_gtags, global_tags = self._get_tags(set(datasets))

_datasets = {}
for tag in local_gtags:
ltag = LocalTag(tag, self.index)
_datasets[ltag] = datasets[ltag]
for tag in global_tags:
_datasets[tag] = datasets[tag]

return _datasets if _datasets else None

def get_datasets_filter_mask(
self, datasets: Optional[Mapping[Tag, Dataset]]
@@ -1059,9 +1113,9 @@ def get_datasets_filter_mask(
if datasets is None:
return None
else:
# By default return a mask that filters nothing.
# Only keep points that are in the box.
return {
tag: tf.ones(tf.shape(dataset.query_points)[:-1], dtype=tf.bool)
tag: self.contains(dataset.query_points)
for tag, dataset in datasets.items()
if LocalTag.from_tag(tag).local_index == self.index
}
@@ -1172,7 +1226,7 @@ def acquire(
num_subspaces = len(self._tags)
assert _num_local_models in [0, num_subspaces], (
f"When using local models, the number of subspaces {num_subspaces} should be equal to "
f"the number of local objective models {_num_local_models}"
f"the number of local models {_num_local_models}"
)

# If we have local models, run the (deepcopied) base rule sequentially for each subspace.
@@ -1402,13 +1456,13 @@ def initialize(
Initialize the box by sampling a location from the global search space and setting the
bounds.
"""
dataset = self.select_datasets(datasets)
datasets = self.select_datasets(datasets)

self.location = tf.squeeze(self.global_search_space.sample(1), axis=0)
self._step_is_success = False
self._init_eps()
self._update_bounds()
_, self._y_min = self.get_dataset_min(dataset)
_, self._y_min = self.get_dataset_min(datasets)

def update(
self,
@@ -1426,13 +1480,13 @@ def update(
``1 / beta``. Conversely, if it was unsuccessful, the size is reduced by the factor
``beta``.
"""
dataset = self.select_datasets(datasets)
datasets = self.select_datasets(datasets)

if tf.reduce_any(self.eps < self._min_eps):
self.initialize(models, datasets)
return

x_min, y_min = self.get_dataset_min(dataset)
x_min, y_min = self.get_dataset_min(datasets)
self.location = x_min

tr_volume = tf.reduce_prod(self.upper - self.lower)
@@ -1441,46 +1495,6 @@ def update(
self._update_bounds()
self._y_min = y_min

def select_models(
self, models: Optional[Mapping[Tag, ProbabilisticModelType]]
) -> Optional[Mapping[Tag, ProbabilisticModelType]]:
# Select the model belonging to this box.
if self.index is None:
tags = [OBJECTIVE] # If no index, then pick the global model.
else:
tags = [LocalTag(OBJECTIVE, self.index), OBJECTIVE] # Prefer local model if available.
tag, model = get_value_for_tag(models, tags)
return {tag: model} if model is not None else None

def select_datasets(
self, datasets: Optional[Mapping[Tag, Dataset]]
) -> Optional[Mapping[Tag, Dataset]]:
# Select the dataset belonging to this box.
if self.index is None:
tags = [OBJECTIVE] # If no index, then pick the global dataset.
else:
tags = [
LocalTag(OBJECTIVE, self.index),
OBJECTIVE,
] # Prefer local dataset if available.
tag, dataset = get_value_for_tag(datasets, tags)
return {tag: dataset} if dataset is not None else None

def get_datasets_filter_mask(
self, datasets: Optional[Mapping[Tag, Dataset]]
) -> Optional[Mapping[Tag, tf.Tensor]]:
# Only select the region datasets for filtering. Don't directly filter the global dataset.
assert self.index is not None, "the index should be set for filtering local datasets"
if datasets is None:
return None
else:
# Only keep points that are in the box.
return {
tag: self.contains(dataset.query_points)
for tag, dataset in datasets.items()
if LocalTag.from_tag(tag).local_index == self.index
}

@check_shapes(
"return[0]: [D]",
"return[1]: []",
@@ -1489,9 +1503,13 @@ def get_dataset_min(
self, datasets: Optional[Mapping[Tag, Dataset]]
) -> Tuple[TensorType, TensorType]:
"""Calculate the minimum of the box using the given dataset."""
if datasets is None:
raise ValueError("""dataset must be provided""")
dataset = next(iter(datasets.values())) # Expect only one dataset.
if (
datasets is None
or len(datasets) != 1
or LocalTag.from_tag(next(iter(datasets))).global_tag != OBJECTIVE
):
raise ValueError("""a single OBJECTIVE dataset must be provided""")
dataset = next(iter(datasets.values()))

in_tr = self.contains(dataset.query_points)
in_tr_obs = tf.where(
@@ -1633,18 +1651,30 @@ def initialize(
def get_datasets_filter_mask(
self, datasets: Optional[Mapping[Tag, Dataset]]
) -> Optional[Mapping[Tag, tf.Tensor]]:
# Don't filter out any points from the dataset by bypassing the
# SingleObjectiveTrustRegionBox method.
return super(SingleObjectiveTrustRegionBox, self).get_datasets_filter_mask(datasets)
# Only select the region datasets for filtering. Don't directly filter the global dataset.
assert self.index is not None, "the index should be set for filtering local datasets"
if datasets is None:
return None
else:
# Don't filter out any points from the dataset. Always keep the entire dataset.
return {
tag: tf.ones(tf.shape(dataset.query_points)[:-1], dtype=tf.bool)
for tag, dataset in datasets.items()
if LocalTag.from_tag(tag).local_index == self.index
}

@inherit_check_shapes
def get_dataset_min(
self, datasets: Optional[Mapping[Tag, Dataset]]
) -> Tuple[TensorType, TensorType]:
"""Calculate the minimum of the box using the given dataset."""
if datasets is None:
raise ValueError("""dataset must be provided""")
dataset = next(iter(datasets.values())) # Expect only one dataset.
if (
datasets is None
or len(datasets) != 1
or LocalTag.from_tag(next(iter(datasets))).global_tag != OBJECTIVE
):
raise ValueError("""a single OBJECTIVE dataset must be provided""")
dataset = next(iter(datasets.values()))

# Always return the global minimum.
ix = tf.argmin(dataset.observations)
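To summarise the filter-mask change in the hunks above: the base `SingleObjectiveTrustRegionBox` now keeps only the points that fall inside its region, while the subclass in the final hunk (the TREGO-style box used in the tests) keeps every point. The snippet below sketches the two masking strategies on plain tensors; the bounds and data are hypothetical and this is not trieste's API.

```python
# Sketch of the two masking strategies, assuming 2-D query points and an
# axis-aligned region; not the actual trieste implementation.
import tensorflow as tf

lower = tf.constant([0.2, 0.2], dtype=tf.float64)
upper = tf.constant([0.7, 0.7], dtype=tf.float64)
query_points = tf.constant([[0.3, 0.4], [0.9, 0.1]], dtype=tf.float64)

# Base region box: keep only the points inside the region bounds.
in_region_mask = tf.reduce_all((query_points >= lower) & (query_points <= upper), axis=-1)

# Subclass in the final hunk: a mask that keeps every point.
keep_all_mask = tf.ones(tf.shape(query_points)[:-1], dtype=tf.bool)

print(in_region_mask.numpy())  # [ True False]
print(keep_all_mask.numpy())   # [ True  True]
```

The returned mask presumably controls which points each region retains in its local dataset after filtering.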
