Skip to content

Commit

Permalink
Don't take TR ymin from init dataset (#843)
Browse files Browse the repository at this point in the history
  • Loading branch information
khurram-ghani authored May 28, 2024
1 parent 452b7e2 commit 6eebbcd
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 9 deletions.
15 changes: 8 additions & 7 deletions tests/unit/acquisition/test_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,7 @@ def test_trego_for_default_state(
npt.assert_array_almost_equal(ret_subspace.lower, lower_bound)
npt.assert_array_almost_equal(ret_subspace.upper, upper_bound)
npt.assert_array_almost_equal(query_point, [expected_query_point], 5)
npt.assert_array_almost_equal(ret_subspace._y_min, [0.012])
npt.assert_array_almost_equal(ret_subspace._y_min, [np.inf])
assert ret_subspace._is_global


Expand Down Expand Up @@ -1684,9 +1684,9 @@ def test_multi_trust_region_box_inits_regions_that_need_it() -> None:

# Change all eps values, with the second region have a lower eps than the min. This region
# should be re-initialized.
subspaces[0].eps = 0.35
subspaces[0].eps = 0.45
subspaces[1].eps = 0.25
subspaces[2].eps = 0.32
subspaces[2].eps = 0.42

# Check the property values.
assert bool(subspaces[0].requires_initialization) is False
Expand All @@ -1698,9 +1698,9 @@ def test_multi_trust_region_box_inits_regions_that_need_it() -> None:

# Check that the second region was re-initialized.
assert state is not None
assert cast(TestTrustRegionBox, state.subspaces[0]).eps < 0.35 # Expect reduction.
assert cast(TestTrustRegionBox, state.subspaces[0]).eps > 0.45 # Expect increase, step success.
assert cast(TestTrustRegionBox, state.subspaces[1]).eps == 0.4 # Expect re-initialized value.
assert cast(TestTrustRegionBox, state.subspaces[2]).eps < 0.32 # Expect reduction.
assert cast(TestTrustRegionBox, state.subspaces[0]).eps > 0.42 # Expect increase, step success.


def test_multi_trust_region_box_acquire_with_state() -> None:
Expand Down Expand Up @@ -1742,14 +1742,15 @@ def test_multi_trust_region_box_acquire_with_state() -> None:

assert next_state is not None
assert points.shape == [1, 3, 2]
# The regions correspond to first, third and first points in the dataset.
# The regions correspond to first, third and first points in the dataset. However, for the
# region that is initialized, the point is not used and value is set to infinity.
# First two regions should be updated.
# The third region should be initialized and not updated, as it is too close to the first
# subspace.
for point, subspace, exp_obs, exp_eps in zip(
points[0],
cast(Sequence[TestTrustRegionBox], next_state.subspaces),
[dataset.observations[0], dataset.observations[2], dataset.observations[0]],
[dataset.observations[0], dataset.observations[2], np.inf],
[0.1, 0.1, 0.07], # First two regions updated, third region initialized.
):
assert point in subspace
Expand Down
6 changes: 4 additions & 2 deletions trieste/acquisition/rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -1621,7 +1621,7 @@ def __init__(
self._update_domain()
# Initial value of the region minimum is set to infinity as we have not yet observed any
# data.
self._y_min = np.inf
self._y_min = tf.constant(np.inf, dtype=self.location.dtype)

def _init_eps(self) -> None:
self.eps = self._zeta * (self.global_search_space.upper - self.global_search_space.lower)
Expand Down Expand Up @@ -1665,7 +1665,9 @@ def initialize(
self._step_is_success = False
self._init_eps()
self._update_domain()
_, self._y_min = self.get_dataset_min(datasets)
# We haven't necessarily observed any data yet for this region; force first step to always
# be successful by setting the minimum to infinity.
self._y_min = tf.constant(np.inf, dtype=self.location.dtype)
self._initialized = True

def update(
Expand Down

0 comments on commit 6eebbcd

Please sign in to comment.