Allow multiple models/datasets for base rule
khurram-ghani committed Oct 11, 2023
1 parent 40df585 commit ee1ee56
Showing 2 changed files with 112 additions and 81 deletions.
tests/unit/acquisition/test_rule.py (4 changes: 2 additions & 2 deletions)

@@ -1213,7 +1213,7 @@ def test_trust_region_box_get_dataset_min() -> None:
     trb = SingleObjectiveTrustRegionBox(search_space)
     trb._lower = tf.constant([0.2, 0.2], dtype=tf.float64)
     trb._upper = tf.constant([0.7, 0.7], dtype=tf.float64)
-    x_min, y_min = trb.get_dataset_min(dataset)
+    x_min, y_min = trb.get_dataset_min({"foo": dataset})
     npt.assert_array_equal(x_min, tf.constant([0.3, 0.4], dtype=tf.float64))
     npt.assert_array_equal(y_min, tf.constant([0.2], dtype=tf.float64))

@@ -1227,7 +1227,7 @@ def test_trust_region_box_get_dataset_min_outside_search_space() -> None:
         tf.constant([[0.7], [0.9]], dtype=tf.float64),
     )
     trb = SingleObjectiveTrustRegionBox(search_space)
-    x_min, y_min = trb.get_dataset_min(dataset)
+    x_min, y_min = trb.get_dataset_min({"foo": dataset})
     npt.assert_array_equal(x_min, tf.constant([1.2, 1.3], dtype=tf.float64))
     npt.assert_array_equal(y_min, tf.constant([np.inf], dtype=tf.float64))
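The two test updates above show the interface change in miniature: `SingleObjectiveTrustRegionBox.get_dataset_min` now takes a mapping from tags to datasets rather than a bare `Dataset`, and the tag name itself (here the arbitrary `"foo"`) is not interpreted by the method. A minimal sketch of the new calling convention, with toy data assumed for illustration:

import tensorflow as tf

from trieste.acquisition.rule import SingleObjectiveTrustRegionBox
from trieste.data import Dataset
from trieste.space import Box

search_space = Box([0.0, 0.0], [1.0, 1.0])
dataset = Dataset(
    tf.constant([[0.3, 0.4], [0.8, 0.8]], dtype=tf.float64),
    tf.constant([[0.2], [0.9]], dtype=tf.float64),
)
trb = SingleObjectiveTrustRegionBox(search_space)
# Datasets are now passed as a tagged mapping; a single entry is expected here.
x_min, y_min = trb.get_dataset_min({"foo": dataset})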
trieste/acquisition/rule.py (189 changes: 110 additions & 79 deletions)

@@ -20,9 +20,22 @@
 import copy
 import math
 from abc import ABC, abstractmethod
+from collections import defaultdict
 from collections.abc import Mapping
 from dataclasses import dataclass
-from typing import Any, Callable, Generic, Optional, Sequence, Tuple, TypeVar, Union, cast, overload
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generic,
+    Optional,
+    Sequence,
+    Tuple,
+    TypeVar,
+    Union,
+    cast,
+    overload,
+)

import numpy as np
from check_shapes import check_shapes, inherit_check_shapes
@@ -1005,51 +1018,53 @@ def update(
         """
         ...
 
-    def select_model(
+    def select_models(
         self, models: Optional[Mapping[Tag, ProbabilisticModelType]]
-    ) -> Tuple[Optional[Tag], Optional[ProbabilisticModelType]]:
+    ) -> Optional[Mapping[Tag, ProbabilisticModelType]]:
         """
-        Select a single model belonging to this region. This is an optional method that is
-        only required if the region is used with single model acquisition functions.
+        Select models belonging to this region for acquisition.
         :param models: The model for each tag.
-        :return: The model belonging to this region.
+        :return: The models belonging to this region.
         """
-        # By default return the OBJECTIVE model.
-        return get_value_for_tag(models, OBJECTIVE)
+        # By default return all the models.
+        return models
 
-    def select_dataset(
+    def select_datasets(
         self, datasets: Optional[Mapping[Tag, Dataset]]
-    ) -> Tuple[Optional[Tag], Optional[Dataset]]:
+    ) -> Optional[Mapping[Tag, Dataset]]:
         """
-        Select a single dataset belonging to this region. This is an optional method that is
-        only required if the region is used with single model acquisition functions.
+        Select datasets belonging to this region for acquisition.
         :param datasets: The dataset for each tag.
-        :return: The tag and associated dataset belonging to this region.
+        :return: The datasets belonging to this region.
         """
-        # By default return the OBJECTIVE dataset.
-        return get_value_for_tag(datasets, OBJECTIVE)
+        # By default return all the datasets.
+        return datasets
 
-    def get_dataset_filter_mask(
+    def get_datasets_filter_mask(
         self, datasets: Optional[Mapping[Tag, Dataset]]
-    ) -> Tuple[Optional[Tag], Optional[tf.Tensor]]:
+    ) -> Optional[Mapping[Tag, tf.Tensor]]:
         """
-        Return a boolean mask that can be used to filter out points from the dataset that
+        Return a boolean mask that can be used to filter out points from the datasets that
         belong to this region.
         :param datasets: The dataset for each tag.
-        :return: The tag for the selected dataset and a boolean mask that can be used to filter
-            that dataset. A value of `True` indicates that the corresponding point should be kept.
+        :return: A mapping for each tag belonging to this region, to a boolean mask that can be
+            used to filter out points from the datasets. A value of `True` indicates that the
+            corresponding point should be kept.
         """
-        # Always select the region dataset for filtering. Don't directly filter the global dataset.
+        # Only select the region datasets for filtering. Don't directly filter the global dataset.
         assert self.index is not None, "the index should be set for filtering local datasets"
-        tag, dataset = get_value_for_tag(datasets, LocalTag(OBJECTIVE, self.index))
-        if dataset is None:
-            return None, None
+        if datasets is None:
+            return None
         else:
             # By default return a mask that filters nothing.
-            return tag, tf.ones(tf.shape(dataset.query_points)[:-1], dtype=tf.bool)
+            return {
+                tag: tf.ones(tf.shape(dataset.query_points)[:-1], dtype=tf.bool)
+                for tag, dataset in datasets.items()
+                if LocalTag.from_tag(tag).local_index == self.index
+            }


UpdatableTrustRegionType = TypeVar("UpdatableTrustRegionType", bound=UpdatableTrustRegion)
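The default `get_datasets_filter_mask` above keeps every point of every dataset whose local tag index matches the region. A self-contained toy sketch of that convention, with plain `"<global_tag>__<local_index>"` string tags and a simplified parser standing in for trieste's `LocalTag` (both hypothetical here):

import tensorflow as tf

# Toy tagged datasets: the values are [N, D] tensors of query points.
query_points = {
    "OBJECTIVE__0": tf.random.uniform([5, 2], dtype=tf.float64),
    "OBJECTIVE__1": tf.random.uniform([3, 2], dtype=tf.float64),
}

def local_index(tag: str) -> int:
    # Simplified stand-in for LocalTag.from_tag(tag).local_index.
    return int(tag.split("__")[1])

region_index = 0
# Keep-everything masks, built only for the datasets belonging to region 0.
masks = {
    tag: tf.ones(tf.shape(points)[:-1], dtype=tf.bool)
    for tag, points in query_points.items()
    if local_index(tag) == region_index
}
assert set(masks) == {"OBJECTIVE__0"}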
@@ -1143,27 +1158,28 @@ def acquire(
         assert self._tags is not None
         assert self._init_subspaces is not None
 
-        num_subspaces = len(self._tags)
-        num_local_models = 0
+        num_local_models: Dict[Tag, int] = defaultdict(int)
         for tag in models:
             ltag = LocalTag.from_tag(tag)
-            if ltag.is_local and ltag.global_tag == OBJECTIVE:
-                num_local_models += 1
-        assert num_local_models in [0, num_subspaces], (
+            if ltag.is_local:
+                num_local_models[ltag.global_tag] += 1
+        num_local_models_vals = set(num_local_models.values())
+        assert (
+            len(num_local_models_vals) <= 1
+        ), f"The number of local models should be the same for all tags, got {num_local_models}"
+        _num_local_models = 0 if len(num_local_models_vals) == 0 else num_local_models_vals.pop()
+
+        num_subspaces = len(self._tags)
+        assert _num_local_models in [0, num_subspaces], (
             f"When using local models, the number of subspaces {num_subspaces} should be equal to "
-            f"the number of local objective models {num_local_models}"
+            f"the number of local objective models {_num_local_models}"
         )
 
-        # If the base rule is a single model acquisition rule, but we have local
-        # models, run the (deepcopied) base rule sequentially for each subspace.
+        # If we have local models, run the (deepcopied) base rule sequentially for each subspace.
         # Otherwise, run the base rule as is, once with all models and datasets.
         # Note: this should only trigger on the first call to `acquire`, as after that we will
         # have a list of rules in `self._rules`.
-        if (
-            isinstance(self._rule, EfficientGlobalOptimization)
-            and hasattr(self._rule._builder, "single_builder")
-            and (num_local_models > 0 or OBJECTIVE not in models)
-        ):
+        if _num_local_models > 0:
             self._rules = [copy.deepcopy(self._rule) for _ in range(num_subspaces)]
 
         def state_func(
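The counting logic in `acquire` enforces a simple invariant: either no tag has local models, or every global tag has the same number of local models, and that number must equal the number of subspaces. A standalone sketch of the same check, again with toy `"<global>__<i>"` string tags in place of `LocalTag`:

from collections import defaultdict
from typing import Dict

models = {
    "OBJECTIVE__0": "m0",
    "OBJECTIVE__1": "m1",
    "CONSTRAINT__0": "c0",
    "CONSTRAINT__1": "c1",
}

num_local_models: Dict[str, int] = defaultdict(int)
for tag in models:
    global_tag, sep, _ = tag.partition("__")
    if sep:  # the tag is local
        num_local_models[global_tag] += 1

vals = set(num_local_models.values())
assert len(vals) <= 1, f"local model counts differ between tags: {dict(num_local_models)}"
num = vals.pop() if vals else 0  # here 2: one local model per subspace and tag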
@@ -1209,16 +1225,20 @@ def state_func(
             if self._rules is not None:
                 _points = []
                 for subspace, rule in zip(subspaces, self._rules):
-                    _, _model = subspace.select_model(models)
-                    _, _dataset = subspace.select_dataset(datasets)
-                    assert _model is not None
-                    # Using default tag, as that is what single model acquisition builders expect.
-                    model = {OBJECTIVE: _model}
-                    if _dataset is None:
-                        dataset = None
-                    else:
-                        dataset = {OBJECTIVE: _dataset}
-                    _points.append(rule.acquire(subspace, model, dataset))
+                    _models = subspace.select_models(models)
+                    _datasets = subspace.select_datasets(datasets)
+                    assert _models is not None
+                    # Remap all local tags to global ones. One reason is that single model
+                    # acquisition builders expect OBJECTIVE to exist.
+                    _models = {
+                        LocalTag.from_tag(tag).global_tag: model for tag, model in _models.items()
+                    }
+                    if _datasets is not None:
+                        _datasets = {
+                            LocalTag.from_tag(tag).global_tag: dataset
+                            for tag, dataset in _datasets.items()
+                        }
+                    _points.append(rule.acquire(subspace, _models, _datasets))
                 points = tf.stack(_points, axis=1)
             else:
                 points = self._rule.acquire(acquisition_space, models, datasets)
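Each subspace now receives its own tagged models and datasets, with local tags remapped to their global names before the base rule runs, since single-model acquisition builders expect the `OBJECTIVE` tag to be present. A toy sketch of that remapping:

def to_global(tag: str) -> str:
    # Stand-in for LocalTag.from_tag(tag).global_tag with "<global>__<i>" strings.
    return tag.split("__")[0]

local_models = {"OBJECTIVE__1": "model_1"}
remapped = {to_global(tag): model for tag, model in local_models.items()}
assert remapped == {"OBJECTIVE": "model_1"}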
@@ -1289,11 +1309,12 @@ def update_datasets(
         # be to remove this assumption.
         assert self._init_subspaces is not None
         for subspace in self._init_subspaces:
-            tag, in_region = subspace.get_dataset_filter_mask(datasets)
-            assert tag is not None
-            ltag = LocalTag.from_tag(tag)
-            assert ltag.is_local, f"can only filter local tags, got {tag}"
-            used_masks[tag] = tf.logical_or(used_masks[tag], in_region)
+            in_region_masks = subspace.get_datasets_filter_mask(datasets)
+            if in_region_masks is not None:
+                for tag, in_region in in_region_masks.items():
+                    ltag = LocalTag.from_tag(tag)
+                    assert ltag.is_local, f"can only filter local tags, got {tag}"
+                    used_masks[tag] = tf.logical_or(used_masks[tag], in_region)
 
         filtered_datasets = {}
         global_tags = []  # Global datasets to re-generate.
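In `update_datasets`, the per-region masks are OR-ed into a single `used_masks` entry per tag, so a point survives filtering if at least one region claims it. A minimal TensorFlow sketch of the accumulation:

import tensorflow as tf

region_masks = [
    tf.constant([True, False, False]),
    tf.constant([False, True, False]),
]
used = tf.zeros([3], dtype=tf.bool)
for in_region in region_masks:
    used = tf.logical_or(used, in_region)
# used == [True, True, False]: the third point is filtered out.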
@@ -1381,7 +1402,7 @@ def initialize(
         Initialize the box by sampling a location from the global search space and setting the
         bounds.
         """
-        _, dataset = self.select_dataset(datasets)
+        dataset = self.select_datasets(datasets)
 
         self.location = tf.squeeze(self.global_search_space.sample(1), axis=0)
         self._step_is_success = False
@@ -1405,7 +1426,7 @@ def update(
         ``1 / beta``. Conversely, if it was unsuccessful, the size is reduced by the factor
         ``beta``.
         """
-        _, dataset = self.select_dataset(datasets)
+        dataset = self.select_datasets(datasets)
 
         if tf.reduce_any(self.eps < self._min_eps):
             self.initialize(models, datasets)
@@ -1420,51 +1441,57 @@ def update(
         self._update_bounds()
         self._y_min = y_min
 
-    def select_model(
+    def select_models(
         self, models: Optional[Mapping[Tag, ProbabilisticModelType]]
-    ) -> Tuple[Optional[Tag], Optional[ProbabilisticModelType]]:
-        # Select the model belonging to this box. Note there isn't necessarily a one-to-one
-        # mapping between regions and models.
+    ) -> Optional[Mapping[Tag, ProbabilisticModelType]]:
+        # Select the model belonging to this box.
         if self.index is None:
             tags = [OBJECTIVE]  # If no index, then pick the global model.
         else:
             tags = [LocalTag(OBJECTIVE, self.index), OBJECTIVE]  # Prefer local model if available.
-        return get_value_for_tag(models, tags)
+        tag, model = get_value_for_tag(models, tags)
+        return {tag: model} if model is not None else None
 
-    def select_dataset(
+    def select_datasets(
         self, datasets: Optional[Mapping[Tag, Dataset]]
-    ) -> Tuple[Optional[Tag], Optional[Dataset]]:
-        # Select the dataset belonging to this box. Note there isn't necessarily a one-to-one
-        # mapping between regions and datasets.
+    ) -> Optional[Mapping[Tag, Dataset]]:
+        # Select the dataset belonging to this box.
         if self.index is None:
             tags = [OBJECTIVE]  # If no index, then pick the global dataset.
         else:
             tags = [
                 LocalTag(OBJECTIVE, self.index),
                 OBJECTIVE,
             ]  # Prefer local dataset if available.
-        return get_value_for_tag(datasets, tags)
+        tag, dataset = get_value_for_tag(datasets, tags)
+        return {tag: dataset} if dataset is not None else None
 
-    def get_dataset_filter_mask(
+    def get_datasets_filter_mask(
         self, datasets: Optional[Mapping[Tag, Dataset]]
-    ) -> Tuple[Optional[Tag], Optional[tf.Tensor]]:
-        # Always select the region dataset for filtering. Don't directly filter the global dataset.
+    ) -> Optional[Mapping[Tag, tf.Tensor]]:
+        # Only select the region datasets for filtering. Don't directly filter the global dataset.
         assert self.index is not None, "the index should be set for filtering local datasets"
-        tag, dataset = get_value_for_tag(datasets, LocalTag(OBJECTIVE, self.index))
-        if dataset is None:
-            return None, None
+        if datasets is None:
+            return None
         else:
             # Only keep points that are in the box.
-            return tag, self.contains(dataset.query_points)
+            return {
+                tag: self.contains(dataset.query_points)
+                for tag, dataset in datasets.items()
+                if LocalTag.from_tag(tag).local_index == self.index
+            }
 
     @check_shapes(
         "return[0]: [D]",
         "return[1]: []",
     )
-    def get_dataset_min(self, dataset: Optional[Dataset]) -> Tuple[TensorType, TensorType]:
+    def get_dataset_min(
+        self, datasets: Optional[Mapping[Tag, Dataset]]
+    ) -> Tuple[TensorType, TensorType]:
         """Calculate the minimum of the box using the given dataset."""
-        if dataset is None:
+        if datasets is None:
             raise ValueError("""dataset must be provided""")
+        dataset = next(iter(datasets.values()))  # Expect only one dataset.
 
         in_tr = self.contains(dataset.query_points)
         in_tr_obs = tf.where(
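The `select_models`/`select_datasets` overrides above keep single-region behaviour by returning a one-entry mapping, preferring the region-local tag and falling back to the global one. A toy sketch of that first-match preference (a simplified stand-in for trieste's `get_value_for_tag`):

from typing import List, Mapping, Optional, Tuple

def first_match(
    mapping: Optional[Mapping[str, str]], tags: List[str]
) -> Tuple[Optional[str], Optional[str]]:
    # Simplified stand-in for get_value_for_tag: return the first tag present.
    if mapping is not None:
        for tag in tags:
            if tag in mapping:
                return tag, mapping[tag]
    return None, None

models = {"OBJECTIVE": "global_model", "OBJECTIVE__2": "local_model_2"}
tag, model = first_match(models, ["OBJECTIVE__2", "OBJECTIVE"])
assert (tag, model) == ("OBJECTIVE__2", "local_model_2")  # the local entry wins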
@@ -1603,17 +1630,21 @@ def initialize(
 
         super().initialize(models, datasets)
 
-    def get_dataset_filter_mask(
+    def get_datasets_filter_mask(
         self, datasets: Optional[Mapping[Tag, Dataset]]
-    ) -> Tuple[Optional[Tag], Optional[tf.Tensor]]:
+    ) -> Optional[Mapping[Tag, tf.Tensor]]:
         # Don't filter out any points from the dataset by bypassing the
         # SingleObjectiveTrustRegionBox method.
-        return super(SingleObjectiveTrustRegionBox, self).get_dataset_filter_mask(datasets)
+        return super(SingleObjectiveTrustRegionBox, self).get_datasets_filter_mask(datasets)
 
     @inherit_check_shapes
-    def get_dataset_min(self, dataset: Optional[Dataset]) -> Tuple[TensorType, TensorType]:
-        if dataset is None:
+    def get_dataset_min(
+        self, datasets: Optional[Mapping[Tag, Dataset]]
+    ) -> Tuple[TensorType, TensorType]:
+        """Calculate the minimum of the box using the given dataset."""
+        if datasets is None:
             raise ValueError("""dataset must be provided""")
+        dataset = next(iter(datasets.values()))  # Expect only one dataset.
 
         # Always return the global minimum.
         ix = tf.argmin(dataset.observations)

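Taken together, the change makes tag handling uniform: regions now receive fully tagged mappings and use `LocalTag` to decide which entries belong to them. A minimal stand-in mirroring only the `LocalTag` behaviour this diff relies on (`from_tag`, `global_tag`, `local_index`, `is_local`); the real class lives in trieste, so treat this as illustrative:

from dataclasses import dataclass
from typing import Optional, Union

@dataclass(frozen=True)
class ToyLocalTag:
    global_tag: str
    local_index: Optional[int] = None

    @property
    def is_local(self) -> bool:
        return self.local_index is not None

    @classmethod
    def from_tag(cls, tag: Union[str, "ToyLocalTag"]) -> "ToyLocalTag":
        # Global tags pass through as non-local; LocalTag instances are kept.
        return tag if isinstance(tag, ToyLocalTag) else cls(tag)

assert not ToyLocalTag.from_tag("OBJECTIVE").is_local
assert ToyLocalTag("OBJECTIVE", 1).is_local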