diff --git a/tests/unit/acquisition/test_rule.py b/tests/unit/acquisition/test_rule.py index 730fa3c096..5d17465fe1 100644 --- a/tests/unit/acquisition/test_rule.py +++ b/tests/unit/acquisition/test_rule.py @@ -1213,7 +1213,7 @@ def test_trust_region_box_get_dataset_min() -> None: trb = SingleObjectiveTrustRegionBox(search_space) trb._lower = tf.constant([0.2, 0.2], dtype=tf.float64) trb._upper = tf.constant([0.7, 0.7], dtype=tf.float64) - x_min, y_min = trb.get_dataset_min(dataset) + x_min, y_min = trb.get_dataset_min({"foo": dataset}) npt.assert_array_equal(x_min, tf.constant([0.3, 0.4], dtype=tf.float64)) npt.assert_array_equal(y_min, tf.constant([0.2], dtype=tf.float64)) @@ -1227,7 +1227,7 @@ def test_trust_region_box_get_dataset_min_outside_search_space() -> None: tf.constant([[0.7], [0.9]], dtype=tf.float64), ) trb = SingleObjectiveTrustRegionBox(search_space) - x_min, y_min = trb.get_dataset_min(dataset) + x_min, y_min = trb.get_dataset_min({"foo": dataset}) npt.assert_array_equal(x_min, tf.constant([1.2, 1.3], dtype=tf.float64)) npt.assert_array_equal(y_min, tf.constant([np.inf], dtype=tf.float64)) diff --git a/trieste/acquisition/rule.py b/trieste/acquisition/rule.py index d7a2592dad..af108735df 100644 --- a/trieste/acquisition/rule.py +++ b/trieste/acquisition/rule.py @@ -20,9 +20,22 @@ import copy import math from abc import ABC, abstractmethod +from collections import defaultdict from collections.abc import Mapping from dataclasses import dataclass -from typing import Any, Callable, Generic, Optional, Sequence, Tuple, TypeVar, Union, cast, overload +from typing import ( + Any, + Callable, + Dict, + Generic, + Optional, + Sequence, + Tuple, + TypeVar, + Union, + cast, + overload, +) import numpy as np from check_shapes import check_shapes, inherit_check_shapes @@ -1005,51 +1018,53 @@ def update( """ ... - def select_model( + def select_models( self, models: Optional[Mapping[Tag, ProbabilisticModelType]] - ) -> Tuple[Optional[Tag], Optional[ProbabilisticModelType]]: + ) -> Optional[Mapping[Tag, ProbabilisticModelType]]: """ - Select a single model belonging to this region. This is an optional method that is - only required if the region is used with single model acquisition functions. + Select models belonging to this region for acquisition. :param models: The model for each tag. - :return: The model belonging to this region. + :return: The models belonging to this region. """ - # By default return the OBJECTIVE model. - return get_value_for_tag(models, OBJECTIVE) + # By default return all the models. + return models - def select_dataset( + def select_datasets( self, datasets: Optional[Mapping[Tag, Dataset]] - ) -> Tuple[Optional[Tag], Optional[Dataset]]: + ) -> Optional[Mapping[Tag, Dataset]]: """ - Select a single dataset belonging to this region. This is an optional method that is - only required if the region is used with single model acquisition functions. + Select datasets belonging to this region for acquisition. :param datasets: The dataset for each tag. - :return: The tag and associated dataset belonging to this region. + :return: The datasets belonging to this region. """ - # By default return the OBJECTIVE dataset. - return get_value_for_tag(datasets, OBJECTIVE) + # By default return all the datasets. + return datasets - def get_dataset_filter_mask( + def get_datasets_filter_mask( self, datasets: Optional[Mapping[Tag, Dataset]] - ) -> Tuple[Optional[Tag], Optional[tf.Tensor]]: + ) -> Optional[Mapping[Tag, tf.Tensor]]: """ - Return a boolean mask that can be used to filter out points from the dataset that + Return a boolean mask that can be used to filter out points from the datasets that belong to this region. :param datasets: The dataset for each tag. - :return: The tag for the selected dataset and a boolean mask that can be used to filter - that dataset. A value of `True` indicates that the corresponding point should be kept. + :return: A mapping for each tag belonging to this region, to a boolean mask that can be + used to filter out points from the datasets. A value of `True` indicates that the + corresponding point should be kept. """ - # Always select the region dataset for filtering. Don't directly filter the global dataset. + # Only select the region datasets for filtering. Don't directly filter the global dataset. assert self.index is not None, "the index should be set for filtering local datasets" - tag, dataset = get_value_for_tag(datasets, LocalTag(OBJECTIVE, self.index)) - if dataset is None: - return None, None + if datasets is None: + return None else: # By default return a mask that filters nothing. - return tag, tf.ones(tf.shape(dataset.query_points)[:-1], dtype=tf.bool) + return { + tag: tf.ones(tf.shape(dataset.query_points)[:-1], dtype=tf.bool) + for tag, dataset in datasets.items() + if LocalTag.from_tag(tag).local_index == self.index + } UpdatableTrustRegionType = TypeVar("UpdatableTrustRegionType", bound=UpdatableTrustRegion) @@ -1143,27 +1158,28 @@ def acquire( assert self._tags is not None assert self._init_subspaces is not None - num_subspaces = len(self._tags) - num_local_models = 0 + num_local_models: Dict[Tag, int] = defaultdict(int) for tag in models: ltag = LocalTag.from_tag(tag) - if ltag.is_local and ltag.global_tag == OBJECTIVE: - num_local_models += 1 - assert num_local_models in [0, num_subspaces], ( + if ltag.is_local: + num_local_models[ltag.global_tag] += 1 + num_local_models_vals = set(num_local_models.values()) + assert ( + len(num_local_models_vals) <= 1 + ), f"The number of local models should be the same for all tags, got {num_local_models}" + _num_local_models = 0 if len(num_local_models_vals) == 0 else num_local_models_vals.pop() + + num_subspaces = len(self._tags) + assert _num_local_models in [0, num_subspaces], ( f"When using local models, the number of subspaces {num_subspaces} should be equal to " - f"the number of local objective models {num_local_models}" + f"the number of local objective models {_num_local_models}" ) - # If the base rule is a single model acquisition rule, but we have local - # models, run the (deepcopied) base rule sequentially for each subspace. + # If we have local models, run the (deepcopied) base rule sequentially for each subspace. # Otherwise, run the base rule as is, once with all models and datasets. # Note: this should only trigger on the first call to `acquire`, as after that we will # have a list of rules in `self._rules`. - if ( - isinstance(self._rule, EfficientGlobalOptimization) - and hasattr(self._rule._builder, "single_builder") - and (num_local_models > 0 or OBJECTIVE not in models) - ): + if _num_local_models > 0: self._rules = [copy.deepcopy(self._rule) for _ in range(num_subspaces)] def state_func( @@ -1209,16 +1225,20 @@ def state_func( if self._rules is not None: _points = [] for subspace, rule in zip(subspaces, self._rules): - _, _model = subspace.select_model(models) - _, _dataset = subspace.select_dataset(datasets) - assert _model is not None - # Using default tag, as that is what single model acquisition builders expect. - model = {OBJECTIVE: _model} - if _dataset is None: - dataset = None - else: - dataset = {OBJECTIVE: _dataset} - _points.append(rule.acquire(subspace, model, dataset)) + _models = subspace.select_models(models) + _datasets = subspace.select_datasets(datasets) + assert _models is not None + # Remap all local tags to global ones. One reason is that single model + # acquisition builders expect OBJECTIVE to exist. + _models = { + LocalTag.from_tag(tag).global_tag: model for tag, model in _models.items() + } + if _datasets is not None: + _datasets = { + LocalTag.from_tag(tag).global_tag: dataset + for tag, dataset in _datasets.items() + } + _points.append(rule.acquire(subspace, _models, _datasets)) points = tf.stack(_points, axis=1) else: points = self._rule.acquire(acquisition_space, models, datasets) @@ -1289,11 +1309,12 @@ def update_datasets( # be to remove this assumption. assert self._init_subspaces is not None for subspace in self._init_subspaces: - tag, in_region = subspace.get_dataset_filter_mask(datasets) - assert tag is not None - ltag = LocalTag.from_tag(tag) - assert ltag.is_local, f"can only filter local tags, got {tag}" - used_masks[tag] = tf.logical_or(used_masks[tag], in_region) + in_region_masks = subspace.get_datasets_filter_mask(datasets) + if in_region_masks is not None: + for tag, in_region in in_region_masks.items(): + ltag = LocalTag.from_tag(tag) + assert ltag.is_local, f"can only filter local tags, got {tag}" + used_masks[tag] = tf.logical_or(used_masks[tag], in_region) filtered_datasets = {} global_tags = [] # Global datasets to re-generate. @@ -1381,7 +1402,7 @@ def initialize( Initialize the box by sampling a location from the global search space and setting the bounds. """ - _, dataset = self.select_dataset(datasets) + dataset = self.select_datasets(datasets) self.location = tf.squeeze(self.global_search_space.sample(1), axis=0) self._step_is_success = False @@ -1405,7 +1426,7 @@ def update( ``1 / beta``. Conversely, if it was unsuccessful, the size is reduced by the factor ``beta``. """ - _, dataset = self.select_dataset(datasets) + dataset = self.select_datasets(datasets) if tf.reduce_any(self.eps < self._min_eps): self.initialize(models, datasets) @@ -1420,22 +1441,21 @@ def update( self._update_bounds() self._y_min = y_min - def select_model( + def select_models( self, models: Optional[Mapping[Tag, ProbabilisticModelType]] - ) -> Tuple[Optional[Tag], Optional[ProbabilisticModelType]]: - # Select the model belonging to this box. Note there isn't necessarily a one-to-one - # mapping between regions and models. + ) -> Optional[Mapping[Tag, ProbabilisticModelType]]: + # Select the model belonging to this box. if self.index is None: tags = [OBJECTIVE] # If no index, then pick the global model. else: tags = [LocalTag(OBJECTIVE, self.index), OBJECTIVE] # Prefer local model if available. - return get_value_for_tag(models, tags) + tag, model = get_value_for_tag(models, tags) + return {tag: model} if model is not None else None - def select_dataset( + def select_datasets( self, datasets: Optional[Mapping[Tag, Dataset]] - ) -> Tuple[Optional[Tag], Optional[Dataset]]: - # Select the dataset belonging to this box. Note there isn't necessarily a one-to-one - # mapping between regions and datasets. + ) -> Optional[Mapping[Tag, Dataset]]: + # Select the dataset belonging to this box. if self.index is None: tags = [OBJECTIVE] # If no index, then pick the global dataset. else: @@ -1443,28 +1463,35 @@ def select_dataset( LocalTag(OBJECTIVE, self.index), OBJECTIVE, ] # Prefer local dataset if available. - return get_value_for_tag(datasets, tags) + tag, dataset = get_value_for_tag(datasets, tags) + return {tag: dataset} if dataset is not None else None - def get_dataset_filter_mask( + def get_datasets_filter_mask( self, datasets: Optional[Mapping[Tag, Dataset]] - ) -> Tuple[Optional[Tag], Optional[tf.Tensor]]: - # Always select the region dataset for filtering. Don't directly filter the global dataset. + ) -> Optional[Mapping[Tag, tf.Tensor]]: + # Only select the region datasets for filtering. Don't directly filter the global dataset. assert self.index is not None, "the index should be set for filtering local datasets" - tag, dataset = get_value_for_tag(datasets, LocalTag(OBJECTIVE, self.index)) - if dataset is None: - return None, None + if datasets is None: + return None else: # Only keep points that are in the box. - return tag, self.contains(dataset.query_points) + return { + tag: self.contains(dataset.query_points) + for tag, dataset in datasets.items() + if LocalTag.from_tag(tag).local_index == self.index + } @check_shapes( "return[0]: [D]", "return[1]: []", ) - def get_dataset_min(self, dataset: Optional[Dataset]) -> Tuple[TensorType, TensorType]: + def get_dataset_min( + self, datasets: Optional[Mapping[Tag, Dataset]] + ) -> Tuple[TensorType, TensorType]: """Calculate the minimum of the box using the given dataset.""" - if dataset is None: + if datasets is None: raise ValueError("""dataset must be provided""") + dataset = next(iter(datasets.values())) # Expect only one dataset. in_tr = self.contains(dataset.query_points) in_tr_obs = tf.where( @@ -1603,17 +1630,21 @@ def initialize( super().initialize(models, datasets) - def get_dataset_filter_mask( + def get_datasets_filter_mask( self, datasets: Optional[Mapping[Tag, Dataset]] - ) -> Tuple[Optional[Tag], Optional[tf.Tensor]]: + ) -> Optional[Mapping[Tag, tf.Tensor]]: # Don't filter out any points from the dataset by bypassing the # SingleObjectiveTrustRegionBox method. - return super(SingleObjectiveTrustRegionBox, self).get_dataset_filter_mask(datasets) + return super(SingleObjectiveTrustRegionBox, self).get_datasets_filter_mask(datasets) @inherit_check_shapes - def get_dataset_min(self, dataset: Optional[Dataset]) -> Tuple[TensorType, TensorType]: - if dataset is None: + def get_dataset_min( + self, datasets: Optional[Mapping[Tag, Dataset]] + ) -> Tuple[TensorType, TensorType]: + """Calculate the minimum of the box using the given dataset.""" + if datasets is None: raise ValueError("""dataset must be provided""") + dataset = next(iter(datasets.values())) # Expect only one dataset. # Always return the global minimum. ix = tf.argmin(dataset.observations)