
Batch trust regions implementation of TURBO #791

Merged
merged 71 commits on Dec 18, 2023
c89583e
Add support for local models and datasets (WIP)
khurram-ghani Sep 21, 2023
c8aebec
Add unit test for local models (WIP)
khurram-ghani Sep 25, 2023
b361e6d
Merge remote-tracking branch 'origin/develop' into khurram/batch_turbo
khurram-ghani Sep 27, 2023
4b3c2de
Update multi model/dataset test (WIP)
khurram-ghani Sep 29, 2023
28803aa
Add unit test for keep datasets in regions
khurram-ghani Oct 3, 2023
e99df96
Add more tests and move to local tags class
khurram-ghani Oct 6, 2023
2de6271
Always include global dataset in mapping
khurram-ghani Oct 6, 2023
6236c95
Add filter_mask method to trust region
khurram-ghani Oct 6, 2023
e96ebe8
Add more testing
khurram-ghani Oct 9, 2023
9a7dd23
Fix mypy model type issues
khurram-ghani Oct 10, 2023
4622a98
Add ask_tell testing
khurram-ghani Oct 10, 2023
af44e48
Fix summary init when only global dataset
khurram-ghani Oct 10, 2023
ef9e455
Remove walrus operator
khurram-ghani Oct 11, 2023
59e0cab
Update test, ask_tell data not changed in-place
khurram-ghani Oct 11, 2023
e4a5342
Add some test comments
khurram-ghani Oct 11, 2023
69a8590
Add some rule comments
khurram-ghani Oct 11, 2023
40df585
Allow input-multi-observers for batch observer
khurram-ghani Oct 11, 2023
ee1ee56
Allow multiple models/datasets for base rule
khurram-ghani Oct 11, 2023
7d03a04
Support multiple models/datasets in region selects
khurram-ghani Oct 12, 2023
8551d72
Fix TR plotting history colors
khurram-ghani Oct 12, 2023
b571733
Add notebook init points explanation
khurram-ghani Oct 12, 2023
2a934d1
Rename region index and add init param
khurram-ghani Oct 16, 2023
bea54ea
Merge branch 'develop' into khurram/local_models
khurram-ghani Oct 16, 2023
c8fb1d1
WIP
khurram-ghani Oct 13, 2023
c48e6d5
Add more TURBO implementation + attempt type fixes
khurram-ghani Oct 17, 2023
52f7975
Remove old comment
khurram-ghani Oct 17, 2023
b2eb662
Tidy-up redundant expression
khurram-ghani Oct 17, 2023
69e86cf
Fix TURBOBox and temp changes in TURBO for match
khurram-ghani Oct 17, 2023
08bccf6
Merge remote-tracking branch 'origin/khurram/local_models' into khurr…
khurram-ghani Oct 17, 2023
4a36cf9
Add intermediate box region class
khurram-ghani Oct 17, 2023
523aaab
Keep full datasets along with filtered ones
khurram-ghani Oct 18, 2023
8716143
Merge remote-tracking branch 'origin/khurram/local_models' into khurr…
khurram-ghani Oct 18, 2023
5f66432
Move subspace update to a new rule method
khurram-ghani Oct 19, 2023
7b7bf33
Add temp notebook for TURBO comparisons
khurram-ghani Oct 19, 2023
3789af2
Save TR subspaces in acquire to re-use later
khurram-ghani Oct 19, 2023
661c3a2
Update notebook to use TURBOBOx
khurram-ghani Oct 19, 2023
5b3ec0f
Make changes from PR feedback
khurram-ghani Nov 16, 2023
cb3ca46
Merge branch 'khurram/local_models' into khurram/batch_turbo
khurram-ghani Nov 16, 2023
6767f10
Fix rename after merge
khurram-ghani Nov 16, 2023
ce90838
Fix compare notebook after merge
khurram-ghani Nov 17, 2023
31b96d0
Move rule create later and tidy filtering dataset
khurram-ghani Nov 17, 2023
662947d
More testing in notebook
khurram-ghani Nov 17, 2023
8e26fba
Remove redundant dataset filtering
khurram-ghani Nov 17, 2023
2c21f85
Merge remote-tracking branch 'origin/develop' into khurram/local_models
khurram-ghani Nov 22, 2023
e5cccc8
Address some of the recent feedback
khurram-ghani Nov 23, 2023
25da01b
Fix dataset mypy error
khurram-ghani Nov 23, 2023
292faaa
Copy dataset in optimizers to avoid changing it
khurram-ghani Nov 24, 2023
9170c64
Share DatasetChecker and tidy-up exp values in tests
khurram-ghani Nov 24, 2023
efd2fc0
Address more feedback
khurram-ghani Nov 24, 2023
cdd4cbc
Merge branch 'khurram/local_models' into khurram/batch_turbo
khurram-ghani Nov 27, 2023
8d497aa
Remove prev TURBO and update tests to use new class
khurram-ghani Nov 27, 2023
586f9a8
Remove notebook for testing
khurram-ghani Nov 27, 2023
d8442a8
Create dataset and update at start of optim
khurram-ghani Nov 29, 2023
ff84690
Merge remote-tracking branch 'origin/develop' into khurram/local_models
khurram-ghani Nov 29, 2023
505631a
Avoid default num_models in integ tests
khurram-ghani Nov 29, 2023
b90d3a1
Fix old python typing issue
khurram-ghani Nov 29, 2023
250c647
Merge branch 'khurram/local_models' into khurram/batch_turbo
khurram-ghani Nov 29, 2023
7aec057
Address feedback
khurram-ghani Nov 30, 2023
6a42f05
Address more comments
khurram-ghani Dec 5, 2023
4e5aad6
Only copy state with track_state==True
khurram-ghani Dec 5, 2023
c01432c
Add comment explaining copy
khurram-ghani Dec 5, 2023
13b313d
Deepcopy subspace internal to rule
khurram-ghani Dec 5, 2023
a8ed8ce
Keep global datasets unfiltered
khurram-ghani Dec 11, 2023
9ba931d
Add notebook intro and improve TREGO text
khurram-ghani Dec 13, 2023
1de7e20
Merge remote-tracking branch 'origin/develop' into khurram/local_models
khurram-ghani Dec 13, 2023
a75e618
Use flatten_... func and add comment
khurram-ghani Dec 13, 2023
50afc55
Merge remote-tracking branch 'origin/khurram/local_models' into khurr…
khurram-ghani Dec 13, 2023
99aa9b2
Improve TR explanations
khurram-ghani Dec 14, 2023
30fe281
Merge remote-tracking branch 'origin/develop' into khurram/batch_turbo
khurram-ghani Dec 14, 2023
4f51581
Fix merge issues
khurram-ghani Dec 14, 2023
82919f2
Clarify parallel acq comments
khurram-ghani Dec 15, 2023
47 changes: 30 additions & 17 deletions docs/notebooks/trust_region.pct.py
@@ -125,11 +125,10 @@ def plot_final_result(_dataset: trieste.data.Dataset) -> None:
# %% [markdown]
# We can also visualize the progress of the optimization by plotting the acquisition space at each
# step. This space is either the full search space or the trust region, depending on the step, and
# is shown as a translucent box; with the current optimum point in a region shown in matching
# color.
# is shown as a translucent box. The new query points per region are plotted in matching color.
#
# Note there is only one trust region in this plot, however the rule in the next section will show
# multiple trust regions.
# Note there is only one trust region in this plot, however the rules in the following sections will
# show multiple trust regions.

# %%
import base64
@@ -144,7 +143,10 @@ def plot_final_result(_dataset: trieste.data.Dataset) -> None:
)


def plot_history(result: trieste.bayesian_optimizer.OptimizationResult) -> None:
def plot_history(
result: trieste.bayesian_optimizer.OptimizationResult,
num_query_points: int | None = None,
) -> None:
frames = []
for step, hist in enumerate(
result.history + [result.final_result.unwrap()]
@@ -154,6 +156,7 @@ def plot_history(result: trieste.bayesian_optimizer.OptimizationResult) -> None:
search_space.lower,
search_space.upper,
hist,
num_query_points=num_query_points,
num_init=num_initial_data_points,
)

@@ -250,32 +253,42 @@ def plot_history(result: trieste.bayesian_optimizer.OptimizationResult) -> None:
#
# Finally, we show how to run Bayesian optimization with the `TurBO` algorithm. This is a
# trust region algorithm that uses local models and datasets to approximate the objective function
# within one trust region.
# within their respective trust regions.
#
# ### Create `TurBO` rule and run optimization loop
#
# As before, this meta-rule requires the specification of an acquisition base-rule for performing
# optimization within the trust region; for our example we use the `DiscreteThompsonSampling` rule.
# optimization within the trust regions; for our example we use the `DiscreteThompsonSampling` rule.
#
# Note that trieste maintains a global model that is, by default, automatically trained on each
# iteration. However, this global model is unused for `TurBO`; which uses a local model instead.
# As fitting the global model would be redundant and wasteful, we switch its training off by
# setting `fit_model=False` in the `optimize` method.
# We create 2 `TurBO` trust regions and associated local models by initially copying the global model
# (using `copy_to_local_models`).
#
# The optimizer will return `num_query_points` new query points for each region in every step of the
# loop. With 5 steps and 2 regions, that's 30 points in total.

# %%
turbo_acq_rule = trieste.acquisition.rule.TURBO(
search_space, rule=trieste.acquisition.rule.DiscreteThompsonSampling(500, 3)
num_regions = 2
num_query_points = 3

turbo_subspaces = [
trieste.acquisition.rule.TURBOBox(search_space) for _ in range(num_regions)
]
dts_rule = trieste.acquisition.rule.DiscreteThompsonSampling(
500, num_query_points
)
turbo_acq_rule = trieste.acquisition.rule.BatchTrustRegionBox(
turbo_subspaces, dts_rule
)

bo = trieste.bayesian_optimizer.BayesianOptimizer(observer, search_space)

num_steps = 5
result = bo.optimize(
num_steps,
initial_data,
build_model(),
{trieste.observer.OBJECTIVE: initial_data},
trieste.acquisition.utils.copy_to_local_models(build_model(), num_regions),
turbo_acq_rule,
track_state=True,
fit_model=False,
)
dataset = result.try_get_final_dataset()

@@ -288,7 +301,7 @@ def plot_history(result: trieste.bayesian_optimizer.OptimizationResult) -> None:
plot_final_result(dataset)

# %%
plot_history(result)
plot_history(result, num_regions * num_query_points)

# %% [markdown]
# ## LICENSE
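The notebook diff above replaces the single-region `TURBO` rule with a `BatchTrustRegionBox` wrapping several `TURBOBox` subspaces, each backed by its own local model and dataset. The "local dataset" idea — each region only sees the query points inside its bounds — can be sketched in plain NumPy (a hypothetical helper, not trieste's API):

```python
import numpy as np


def filter_to_region(
    points: np.ndarray, lower: np.ndarray, upper: np.ndarray
) -> np.ndarray:
    """Return the subset of `points` lying inside the box [lower, upper]."""
    # A point belongs to the region only if every coordinate is in bounds.
    mask = np.all((points >= lower) & (points <= upper), axis=-1)
    return points[mask]


# Three observed points; the region covers the lower-left corner of [0, 1]^2.
points = np.array([[0.1, 0.2], [0.5, 0.5], [0.9, 0.9]])
local = filter_to_region(points, np.array([0.0, 0.0]), np.array([0.6, 0.6]))
# `local` keeps the first two points; [0.9, 0.9] falls outside the region.
```

Each region's local model is then trained only on its filtered points, which is why the notebook seeds the local models by copying the global one with `copy_to_local_models`.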
107 changes: 68 additions & 39 deletions tests/integration/test_ask_tell_optimization.py
@@ -16,29 +16,31 @@
import copy
import pickle
import tempfile
from typing import Callable
from typing import Callable, Tuple, Union

import numpy.testing as npt
import pytest
import tensorflow as tf

from tests.util.misc import random_seed
from trieste.acquisition import LocalPenalization
from trieste.acquisition import LocalPenalization, ParallelContinuousThompsonSampling
from trieste.acquisition.rule import (
AcquisitionRule,
AsynchronousGreedy,
AsynchronousRuleState,
BatchTrustRegionBox,
EfficientGlobalOptimization,
SingleObjectiveTrustRegionBox,
TREGOBox,
)
from trieste.acquisition.utils import copy_to_local_models
from trieste.ask_tell_optimization import AskTellOptimizer
from trieste.bayesian_optimizer import OptimizationResult, Record
from trieste.logging import set_step_number, tensorboard_writer
from trieste.models import TrainableProbabilisticModel
from trieste.models.gpflow import GaussianProcessRegression, build_gpr
from trieste.objectives import ScaledBranin, SimpleQuadratic
from trieste.objectives.utils import mk_observer
from trieste.objectives.utils import mk_batch_observer, mk_observer
from trieste.observer import OBJECTIVE
from trieste.space import Box, SearchSpace
from trieste.types import State, TensorType
@@ -59,14 +61,44 @@
id="EfficientGlobalOptimization/reload_state",
),
pytest.param(
15, False, lambda: BatchTrustRegionBox(TREGOBox(ScaledBranin.search_space)), id="TREGO"
15,
False,
lambda: BatchTrustRegionBox(TREGOBox(ScaledBranin.search_space)),
id="TREGO",
),
pytest.param(
16,
True,
lambda: BatchTrustRegionBox(TREGOBox(ScaledBranin.search_space)),
id="TREGO/reload_state",
),
pytest.param(
10,
False,
lambda: BatchTrustRegionBox(
[SingleObjectiveTrustRegionBox(ScaledBranin.search_space) for _ in range(3)],
EfficientGlobalOptimization(
ParallelContinuousThompsonSampling(),
num_query_points=3,
),
),
id="BatchTrustRegionBox",
),
pytest.param(
10,
False,
(
lambda: BatchTrustRegionBox(
[SingleObjectiveTrustRegionBox(ScaledBranin.search_space) for _ in range(3)],
EfficientGlobalOptimization(
ParallelContinuousThompsonSampling(),
num_query_points=2,
),
),
3,
),
id="BatchTrustRegionBox/LocalModels",
),
pytest.param(
10,
False,
@@ -92,23 +124,26 @@
)


@random_seed
@pytest.mark.slow # to run this, add --runslow yes to the pytest command
@pytest.mark.parametrize(*OPTIMIZER_PARAMS)
def test_ask_tell_optimizer_finds_minima_of_the_scaled_branin_function(
num_steps: int,
reload_state: bool,
acquisition_rule_fn: Callable[
[], AcquisitionRule[TensorType, SearchSpace, TrainableProbabilisticModel]
]
| Callable[
AcquisitionRuleFunction = Union[
Callable[[], AcquisitionRule[TensorType, SearchSpace, TrainableProbabilisticModel]],
Callable[
[],
AcquisitionRule[
State[TensorType, AsynchronousRuleState | BatchTrustRegionBox.State],
State[TensorType, Union[AsynchronousRuleState, BatchTrustRegionBox.State]],
Box,
TrainableProbabilisticModel,
],
],
]


@random_seed
@pytest.mark.slow # to run this, add --runslow yes to the pytest command
@pytest.mark.parametrize(*OPTIMIZER_PARAMS)
def test_ask_tell_optimizer_finds_minima_of_the_scaled_branin_function(
num_steps: int,
reload_state: bool,
acquisition_rule_fn: AcquisitionRuleFunction | Tuple[AcquisitionRuleFunction, int],
) -> None:
_test_ask_tell_optimization_finds_minima(True, num_steps, reload_state, acquisition_rule_fn)

@@ -118,17 +153,7 @@ def test_ask_tell_optimizer_finds_minima_of_the_scaled_branin_function(
def test_ask_tell_optimizer_finds_minima_of_simple_quadratic(
num_steps: int,
reload_state: bool,
acquisition_rule_fn: Callable[
[], AcquisitionRule[TensorType, SearchSpace, TrainableProbabilisticModel]
]
| Callable[
[],
AcquisitionRule[
State[TensorType, AsynchronousRuleState | BatchTrustRegionBox.State],
Box,
TrainableProbabilisticModel,
],
],
acquisition_rule_fn: AcquisitionRuleFunction | Tuple[AcquisitionRuleFunction, int],
) -> None:
# for speed reasons we sometimes test with a simple quadratic defined on the same search space
# branin; currently assume that every rule should be able to solve this in 5 steps
@@ -141,17 +166,7 @@ def _test_ask_tell_optimization_finds_minima(
optimize_branin: bool,
num_steps: int,
reload_state: bool,
acquisition_rule_fn: Callable[
[], AcquisitionRule[TensorType, SearchSpace, TrainableProbabilisticModel]
]
| Callable[
[],
AcquisitionRule[
State[TensorType, AsynchronousRuleState | BatchTrustRegionBox.State],
Box,
TrainableProbabilisticModel,
],
],
acquisition_rule_fn: AcquisitionRuleFunction | Tuple[AcquisitionRuleFunction, int],
) -> None:
# For the case when optimization state is saved and reload on each iteration
# we need to use new acquisition function object to imitate real life usage
@@ -160,17 +175,27 @@
search_space = ScaledBranin.search_space
initial_query_points = search_space.sample(5)
observer = mk_observer(ScaledBranin.objective if optimize_branin else SimpleQuadratic.objective)
batch_observer = mk_batch_observer(observer)
initial_data = observer(initial_query_points)

if isinstance(acquisition_rule_fn, tuple):
acquisition_rule_fn, num_models = acquisition_rule_fn
else:
num_models = 1

model = GaussianProcessRegression(
build_gpr(initial_data, search_space, likelihood_variance=1e-7)
)
models = copy_to_local_models(model, num_models) if num_models > 1 else {OBJECTIVE: model}
initial_dataset = {OBJECTIVE: initial_data}

with tempfile.TemporaryDirectory() as tmpdirname:
summary_writer = tf.summary.create_file_writer(tmpdirname)
with tensorboard_writer(summary_writer):
set_step_number(0)
ask_tell = AskTellOptimizer(search_space, initial_data, model, acquisition_rule_fn())
ask_tell = AskTellOptimizer(
search_space, initial_dataset, models, acquisition_rule_fn()
)

for i in range(1, num_steps + 1):
# two scenarios are tested here, depending on `reload_state` parameter
Expand All @@ -185,7 +210,11 @@ def _test_ask_tell_optimization_finds_minima(
] = ask_tell.to_record()
written_state = pickle.dumps(state)

new_data_point = observer(new_point)
# If query points are rank 3, then use a batched observer.
if tf.rank(new_point) == 3:
new_data_point = batch_observer(new_point)
else:
new_data_point = observer(new_point)

if reload_state:
state = pickle.loads(written_state)
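The test diff above switches to a batched observer when the rule returns rank-3 query points, i.e. points shaped `[num_regions, num_query_points, dim]`. What such an observer has to do can be sketched with a hypothetical NumPy helper (trieste's `mk_batch_observer` additionally tags observations per region; this sketch only shows the reshaping):

```python
import numpy as np


def make_batch_observer(observer):
    """Wrap a rank-2 observer so it accepts a [regions, points, dim] batch."""

    def batch_observer(points: np.ndarray) -> np.ndarray:
        num_regions, num_points, dim = points.shape
        # Flatten the region axis, observe once, then restore the batch shape.
        flat = points.reshape(num_regions * num_points, dim)
        observations = observer(flat)  # rank-2 in, rank-2 out
        return observations.reshape(num_regions, num_points, -1)

    return batch_observer


# Toy objective: squared norm of each point, one observation per point.
objective = lambda x: np.sum(x**2, axis=-1, keepdims=True)
batched = make_batch_observer(objective)
out = batched(np.zeros((2, 3, 4)))  # 2 regions, 3 points each, 4 dims
```

This mirrors the test's branching on `tf.rank(new_point) == 3`: single-region rules still emit rank-2 points and use the plain observer, so both code paths stay exercised.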