Commit: Allow input-multi-observers for batch observer

khurram-ghani committed Oct 11, 2023
1 parent 69a8590 commit 40df585
Showing 7 changed files with 70 additions and 55 deletions.
2 changes: 1 addition & 1 deletion tests/integration/test_ask_tell_optimization.py
@@ -201,7 +201,7 @@ def _test_ask_tell_optimization_finds_minima(
     search_space = ScaledBranin.search_space
     initial_query_points = search_space.sample(5)
     observer = mk_observer(ScaledBranin.objective if optimize_branin else SimpleQuadratic.objective)
-    batch_observer = mk_batch_observer(observer, OBJECTIVE)
+    batch_observer = mk_batch_observer(observer)
     initial_data = observer(initial_query_points)

     model = GaussianProcessRegression(
2 changes: 1 addition & 1 deletion tests/unit/acquisition/test_rule.py
@@ -1699,7 +1699,7 @@ def test_multi_trust_region_box_updated_datasets_are_in_regions(
     )
     rule = BatchTrustRegionBox(subspaces, base_rule)
     _, points = rule.acquire(search_space, models, datasets)(None)
-    observer = mk_batch_observer(quadratic, OBJECTIVE)
+    observer = mk_batch_observer(quadratic)
    new_data = observer(points)
    assert not isinstance(new_data, Dataset)
    datasets = rule.update_datasets(datasets, new_data)
57 changes: 33 additions & 24 deletions tests/unit/objectives/test_utils.py
@@ -11,15 +11,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Callable, Union
+from typing import Callable, Sequence, Union

 import numpy.testing as npt
 import pytest
 import tensorflow as tf

 from trieste.data import Dataset
 from trieste.objectives.utils import mk_batch_observer, mk_multi_observer, mk_observer
-from trieste.observer import SingleObserver
+from trieste.observer import Observer
 from trieste.types import Tag, TensorType
 from trieste.utils.misc import LocalTag

@@ -58,40 +58,49 @@ def test_mk_multi_observer() -> None:
     npt.assert_array_equal(ys["bar"].observations, x_ - 1)


-def test_mk_batch_observer_raises_on_multi_observer() -> None:
-    observer = mk_batch_observer(mk_multi_observer(foo=lambda x: x + 1, bar=lambda x: x - 1))
-    with pytest.raises(ValueError, match="mk_batch_observer does not support multi-observers"):
-        observer(tf.constant([[[3.0]]]))
-
-
-@pytest.mark.parametrize("input_objective", [lambda x: x, lambda x: Dataset(x, x)])
+@pytest.mark.parametrize(
+    "input_objective, keys",
+    [
+        (lambda x: x, ["baz"]),
+        (lambda x: Dataset(x, x), ["baz"]),
+        (mk_multi_observer(foo=lambda x: x + 1, bar=lambda x: x - 1), ["foo", "bar"]),
+    ],
+)
 @pytest.mark.parametrize("batch_size", [1, 2, 3])
 @pytest.mark.parametrize("num_query_points_per_batch", [1, 2])
-@pytest.mark.parametrize("key", [None, "bar"])
 def test_mk_batch_observer(
-    input_objective: Union[Callable[[TensorType], TensorType], SingleObserver],
+    input_objective: Union[Callable[[TensorType], TensorType], Observer],
+    keys: Sequence[Tag],
     batch_size: int,
     num_query_points_per_batch: int,
-    key: Tag,
 ) -> None:
     x_ = tf.reshape(
         tf.constant(range(batch_size * num_query_points_per_batch), tf.float64),
         (num_query_points_per_batch, batch_size, 1),
     )
-    ys = mk_batch_observer(input_objective, key)(x_)
-
-    if key is None:
-        assert isinstance(ys, Dataset)
-        npt.assert_array_equal(ys.query_points, tf.reshape(x_, [-1, 1]))
-        npt.assert_array_equal(ys.observations, tf.reshape(x_, [-1, 1]))
-    else:
-        assert isinstance(ys, dict)
-        exp_keys = {LocalTag(key, i).tag for i in range(batch_size)}
+    ys = mk_batch_observer(input_objective, "baz")(x_)
+
+    assert isinstance(ys, dict)
+
+    # Check keys.
+    exp_keys = set()
+    for key in keys:
+        exp_keys.update({LocalTag(key, i).tag for i in range(batch_size)})
+        exp_keys.add(key)
+    assert ys.keys() == exp_keys
+
+    # Check datasets.
+    for key in keys:
+        # Different observers (parametrized above) return different observation values.
+        if key == "foo":
+            exp_o = x_ + 1
+        elif key == "bar":
+            exp_o = x_ - 1
+        else:
+            exp_o = x_

-        assert ys.keys() == exp_keys
         npt.assert_array_equal(ys[key].query_points, tf.reshape(x_, [-1, 1]))
-        npt.assert_array_equal(ys[key].observations, tf.reshape(x_, [-1, 1]))
+        npt.assert_array_equal(ys[key].observations, tf.reshape(exp_o, [-1, 1]))
         for i in range(batch_size):
             npt.assert_array_equal(ys[LocalTag(key, i)].query_points, x_[:, i])
-            npt.assert_array_equal(ys[LocalTag(key, i)].observations, x_[:, i])
+            npt.assert_array_equal(ys[LocalTag(key, i)].observations, exp_o[:, i])
2 changes: 1 addition & 1 deletion tests/unit/test_ask_tell_optimization.py
@@ -500,7 +500,7 @@ def update(self, dataset: Dataset) -> None:
     for tag, model in models.items():
         model._tag = tag

-    observer = mk_batch_observer(lambda x: Dataset(x, x), OBJECTIVE)
+    observer = mk_batch_observer(lambda x: Dataset(x, x))
     rule = FixedAcquisitionRule(query_points)
     ask_tell = AskTellOptimizer(search_space, init_data, models, rule)
18 changes: 12 additions & 6 deletions trieste/acquisition/rule.py
@@ -158,13 +158,19 @@ def update_datasets(
         :param new_datasets: The new datasets.
         :return: The updated datasets.
         """
-        # Account for the case where there may be an initial dataset that is not tagged
-        # per region. In this case, only the global dataset will exist in datasets. We
-        # want to copy this initial dataset to all the regions.
+        # In order to support local datasets, account for the case where there may be an initial
+        # dataset that is not tagged per region. In this case, only the global dataset will exist
+        # in datasets. We want to copy this initial dataset to all the regions.
         # If a tag from tagged_output does not exist in datasets, then add it to
-        # datasets by copying the dataset from datasets with the same tag-prefix.
-        # Otherwise keep the existing dataset from datasets.
+        # datasets by copying the data from datasets with the same global tag. Otherwise keep the
+        # existing data from datasets.
+        #
+        # Note: this replication of initial data can potentially cause an issue when a global model
+        # is being used with local datasets, as the points may be repeated. This will only be an
+        # issue if two regions overlap and both contain that initial data-point -- as filtering
+        # (in BatchTrustRegion) would otherwise remove duplicates. The main way to avoid the issue
+        # in this scenario is to provide local initial datasets, instead of a global initial
+        # dataset.
         updated_datasets = {}
         for tag in new_datasets:
             _, dataset = get_value_for_tag(datasets, [tag, LocalTag.from_tag(tag).global_tag])
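The comment above captures the intent in prose. As a rough illustration only, here is a minimal sketch (a hypothetical helper named update_datasets_sketch, assuming plain dict inputs and Dataset concatenation via "+"; the actual logic lives in the update_datasets method shown in this diff):

from typing import Dict, Mapping

from trieste.data import Dataset
from trieste.types import Tag
from trieste.utils.misc import LocalTag


def update_datasets_sketch(
    datasets: Mapping[Tag, Dataset], new_datasets: Mapping[Tag, Dataset]
) -> Dict[Tag, Dataset]:
    updated = {}
    for tag, new_data in new_datasets.items():
        # Prefer an existing dataset under the same local tag; otherwise fall
        # back to the global dataset, so each region inherits the initial data.
        existing = datasets.get(tag)
        if existing is None:
            existing = datasets.get(LocalTag.from_tag(tag).global_tag)
        updated[tag] = new_data if existing is None else existing + new_data
    return updated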
2 changes: 1 addition & 1 deletion trieste/bayesian_optimizer.py
@@ -764,7 +764,7 @@ def optimize(
                 observer = self._observer
                 # If query_points are rank 3, then use a batched observer.
                 if tf.rank(query_points) == 3:
-                    observer = mk_batch_observer(observer, OBJECTIVE)
+                    observer = mk_batch_observer(observer)
                 observer_output = observer(query_points)

                 tagged_output = (
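For orientation, a minimal sketch (toy observer and shapes, not part of this commit) of when this branch fires: batch acquisition rules produce rank-3 query points of shape [num_query_points, batch_size, dims], which are wrapped so each batch slice is observed and tagged separately:

import tensorflow as tf

from trieste.data import Dataset
from trieste.objectives.utils import mk_batch_observer


def toy_observer(x: tf.Tensor) -> Dataset:
    return Dataset(x, x ** 2)  # stand-in for self._observer


# Rank-3 query points from a batch rule: 2 points, batch size 2, 1 dimension.
query_points = tf.constant([[[0.1], [0.2]], [[0.3], [0.4]]], tf.float64)

observer = toy_observer
if tf.rank(query_points) == 3:
    observer = mk_batch_observer(observer)

# The result is a mapping: the default OBJECTIVE key holds all points in rank-2
# form, plus one local tag per batch index holding that slice of the batch.
observer_output = observer(query_points)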
42 changes: 21 additions & 21 deletions trieste/objectives/utils.py
@@ -26,7 +26,7 @@
 from check_shapes import check_shapes

 from ..data import Dataset
-from ..observer import MultiObserver, Observer, SingleObserver
+from ..observer import OBJECTIVE, MultiObserver, Observer, SingleObserver
 from ..types import Tag, TensorType
 from ..utils.misc import LocalTag

@@ -64,21 +64,18 @@ def mk_multi_observer(**kwargs: Callable[[TensorType], TensorType]) -> MultiObserver:


 def mk_batch_observer(
-    objective_or_observer: Union[Callable[[TensorType], TensorType], SingleObserver],
-    key: Optional[Tag] = None,
+    objective_or_observer: Union[Callable[[TensorType], TensorType], Observer],
+    default_key: Tag = OBJECTIVE,
 ) -> Observer:
     """
     Create an observer that returns the data from ``objective`` or an existing ``observer``
     separately for each query point in a batch.

-    :param objective_or_observer: An objective or an existing observer designed to be used with a
-        single data set and model.
-    :param key: An optional key to use to access the data from the observer result.
+    :param objective_or_observer: An objective or an existing observer.
+    :param default_key: The default key to use if ``objective_or_observer`` is an objective or
+        does not return a mapping.
     :return: A multi-observer across the batch dimension of query points, returning the data from
-        ``objective``. If ``key`` is provided, the observer will be a mapping. Otherwise, it will
-        return a single dataset.
-    :raise ValueError (or tf.errors.InvalidArgumentError): If ``objective_or_observer`` is a
-        multi-observer.
+        ``objective_or_observer``.
     """

     @check_shapes("qps: [n_points, batch_size, n_dims]")
@@ -92,23 +89,26 @@ def _observer(qps: TensorType) -> Mapping[Tag, Dataset]:
         qps = tf.reshape(qps, [-1, qps.shape[-1]])
         obs_or_dataset = objective_or_observer(qps)

-        if isinstance(obs_or_dataset, Mapping):
-            raise ValueError("mk_batch_observer does not support multi-observers")
-        elif not isinstance(obs_or_dataset, Dataset):
+        if not isinstance(obs_or_dataset, (Mapping, Dataset)):
             # Just a single observation, so wrap in a dataset.
             obs_or_dataset = Dataset(qps, obs_or_dataset)

-        if key is None:
-            # Always use rank 2 shape as models (e.g. GPR) expect this, so return as is.
-            return obs_or_dataset
-        else:
+        if isinstance(obs_or_dataset, Dataset):
+            # Convert to a mapping with a default key.
+            obs_or_dataset = {default_key: obs_or_dataset}
+
+        datasets = {}
+        for key, dataset in obs_or_dataset.items():
             # Include overall dataset and per batch dataset.
-            obs = obs_or_dataset.observations
+            obs = dataset.observations
             qps = tf.reshape(qps, [-1, batch_size, qps.shape[-1]])
             obs = tf.reshape(obs, [-1, batch_size, obs.shape[-1]])
-            datasets: Mapping[Tag, Dataset] = {
-                key: obs_or_dataset,
+            _datasets = {
+                key: dataset,
                 **{LocalTag(key, i): Dataset(qps[:, i], obs[:, i]) for i in range(batch_size)},
             }
-            return datasets
+            datasets.update(_datasets)
+
+        return datasets

     return _observer
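Taken together, a short usage sketch of the new behaviour (toy objectives and a "foo"/"bar" multi-observer as in the tests above; not part of the commit):

import tensorflow as tf

from trieste.objectives.utils import mk_batch_observer, mk_multi_observer

# Query points with shape [n_points, batch_size, n_dims]: 3 points, batch of 2.
qps = tf.reshape(tf.range(6, dtype=tf.float64), (3, 2, 1))

# A plain objective (or single observer) is keyed under default_key (OBJECTIVE),
# with one additional LocalTag(OBJECTIVE, i) dataset per batch index i.
ys_single = mk_batch_observer(lambda x: x + 1.0)(qps)

# A multi-observer, which previously raised ValueError, now passes through:
# each of its keys ("foo", "bar") is expanded the same way.
multi = mk_multi_observer(foo=lambda x: x + 1, bar=lambda x: x - 1)
ys_multi = mk_batch_observer(multi)(qps)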
