Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Filter out local datasets when calling base rule #805

Merged
merged 3 commits into from
Jan 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions tests/integration/test_ask_tell_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import copy
import pickle
import tempfile
from typing import Callable, Tuple, Union
from typing import Callable, Mapping, Tuple, Union

import numpy.testing as npt
import pytest
Expand All @@ -36,14 +36,15 @@
from trieste.acquisition.utils import copy_to_local_models
from trieste.ask_tell_optimization import AskTellOptimizer
from trieste.bayesian_optimizer import OptimizationResult, Record
from trieste.data import Dataset
from trieste.logging import set_step_number, tensorboard_writer
from trieste.models import TrainableProbabilisticModel
from trieste.models.gpflow import GaussianProcessRegression, build_gpr
from trieste.objectives import ScaledBranin, SimpleQuadratic
from trieste.objectives.utils import mk_batch_observer, mk_observer
from trieste.observer import OBJECTIVE
from trieste.space import Box, SearchSpace
from trieste.types import State, TensorType
from trieste.types import State, Tag, TensorType

# Optimizer parameters for testing against the branin function.
# We use a copy of these for a quicker test against a simple quadratic function
Expand Down Expand Up @@ -212,7 +213,9 @@ def _test_ask_tell_optimization_finds_minima(

# If query points are rank 3, then use a batched observer.
if tf.rank(new_point) == 3:
new_data_point = batch_observer(new_point)
new_data_point: Union[Mapping[Tag, Dataset], Dataset] = batch_observer(
new_point
)
else:
new_data_point = observer(new_point)

Expand Down
28 changes: 28 additions & 0 deletions tests/unit/acquisition/test_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import copy
from collections.abc import Mapping
from typing import Callable, Optional
from unittest.mock import ANY, MagicMock

import gpflow
import numpy as np
Expand Down Expand Up @@ -1798,6 +1799,33 @@ def test_multi_trust_region_box_updated_datasets_are_in_regions(
)


def test_multi_trust_region_box_acquire_filters() -> None:
    """Check that local (per-region) datasets are filtered out before the base
    rule's ``acquire`` is invoked, so the base rule only ever sees global tags.
    """
    # Dummy models and datasets — only their identities matter here, so plain
    # mocks suffice. Annotate with MagicMock (a real type); unittest.mock.ANY is
    # a comparison sentinel *value* and is not valid as a type parameter.
    models: Mapping[Tag, MagicMock] = {"global_tag": MagicMock()}
    datasets: Mapping[Tag, MagicMock] = {
        LocalizedTag("tag1", 1): MagicMock(),
        LocalizedTag("tag1", 2): MagicMock(),
        LocalizedTag("tag2", 1): MagicMock(),
        LocalizedTag("tag2", 2): MagicMock(),
        "global_tag": MagicMock(),
    }

    search_space = Box([0.0], [1.0])

    # Mock the base rule so we can inspect the arguments it receives. The
    # return value is a rank-3 tensor: one batch of two query points, matching
    # the two subspaces below.
    mock_base_rule = MagicMock(spec=EfficientGlobalOptimization)
    mock_base_rule.acquire.return_value = tf.constant([[[0.0], [0.0]]], dtype=tf.float64)

    # Create a BatchTrustRegionBox instance with the mock base rule.
    subspaces = [SingleObjectiveTrustRegionBox(search_space) for _ in range(2)]
    rule: BatchTrustRegionBox[ProbabilisticModel] = BatchTrustRegionBox(subspaces, mock_base_rule)

    rule.acquire(search_space, models, datasets)(None)

    # Only the global tags should be passed to the base rule's acquire call;
    # all LocalizedTag entries must have been filtered out.
    mock_base_rule.acquire.assert_called_once_with(
        ANY, models, {"global_tag": datasets["global_tag"]}
    )


def test_multi_trust_region_box_state_deepcopy() -> None:
search_space = Box([0.0, 0.0], [1.0, 1.0])
dataset = Dataset(
Expand Down
18 changes: 15 additions & 3 deletions trieste/acquisition/rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -1234,8 +1234,8 @@ def acquire(
# Otherwise, run the base rule as is (i.e as a batch), once with all models and datasets.
# Note: this should only trigger on the first call to `acquire`, as after that we will
# have a list of rules in `self._rules`.
if self._rules is None and (
_num_local_models > 0 or not isinstance(self._rule, EfficientGlobalOptimization)
if self._rules is None and not (
_num_local_models == 0 and isinstance(self._rule, EfficientGlobalOptimization)
):
self._rules = [copy.deepcopy(self._rule) for _ in range(num_subspaces)]

Expand Down Expand Up @@ -1282,7 +1282,19 @@ def state_func(
_points.append(rule.acquire(subspace, _models, _datasets))
points = tf.stack(_points, axis=1)
else:
points = self._rule.acquire(acquisition_space, models, datasets)
# Filter out local datasets as this is a rule (currently only EGO) with normal
# acquisition functions that don't expect local datasets.
# Note: no need to filter out local models, as setups with local models
# are handled above (i.e. we run the base rule sequentially for each subspace).
if datasets is not None:
_datasets = {
tag: dataset
for tag, dataset in datasets.items()
if not LocalizedTag.from_tag(tag).is_local
}
else:
_datasets = None
points = self._rule.acquire(acquisition_space, models, _datasets)

# We may modify the regions in filter_datasets later, so return a copy.
state_ = BatchTrustRegion.State(copy.deepcopy(acquisition_space))
Expand Down
2 changes: 1 addition & 1 deletion trieste/objectives/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def mk_multi_observer(**kwargs: Callable[[TensorType], TensorType]) -> MultiObse
def mk_batch_observer(
objective_or_observer: Union[Callable[[TensorType], TensorType], Observer],
default_key: Tag = OBJECTIVE,
) -> Observer:
) -> MultiObserver:
"""
Create an observer that returns the data from ``objective`` or an existing ``observer``
separately for each query point in a batch.
Expand Down
Loading