Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Local models and datasets #788

Merged
merged 38 commits into from
Dec 14, 2023
Merged
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
c89583e
Add support for local models and datasets (WIP)
khurram-ghani Sep 21, 2023
c8aebec
Add unit test for local models (WIP)
khurram-ghani Sep 25, 2023
b361e6d
Merge remote-tracking branch 'origin/develop' into khurram/batch_turbo
khurram-ghani Sep 27, 2023
4b3c2de
Update multi model/dataset test (WIP)
khurram-ghani Sep 29, 2023
28803aa
Add unit test for keep datasets in regions
khurram-ghani Oct 3, 2023
e99df96
Add more tests and move to local tags class
khurram-ghani Oct 6, 2023
2de6271
Always include global dataset in mapping
khurram-ghani Oct 6, 2023
6236c95
Add filter_mask method to trust region
khurram-ghani Oct 6, 2023
e96ebe8
Add more testing
khurram-ghani Oct 9, 2023
9a7dd23
Fix mypy model type issues
khurram-ghani Oct 10, 2023
4622a98
Add ask_tell testing
khurram-ghani Oct 10, 2023
af44e48
Fix summary init when only global dataset
khurram-ghani Oct 10, 2023
ef9e455
Remove walrus operator
khurram-ghani Oct 11, 2023
59e0cab
Update test, ask_tell data not changed in-place
khurram-ghani Oct 11, 2023
e4a5342
Add some test comments
khurram-ghani Oct 11, 2023
69a8590
Add some rule comments
khurram-ghani Oct 11, 2023
40df585
Allow input-multi-observers for batch observer
khurram-ghani Oct 11, 2023
ee1ee56
Allow multiple models/datasets for base rule
khurram-ghani Oct 11, 2023
7d03a04
Support multiple models/datasets in region selects
khurram-ghani Oct 12, 2023
8551d72
Fix TR plotting history colors
khurram-ghani Oct 12, 2023
b571733
Add notebook init points explanation
khurram-ghani Oct 12, 2023
2a934d1
Rename region index and add init param
khurram-ghani Oct 16, 2023
bea54ea
Merge branch 'develop' into khurram/local_models
khurram-ghani Oct 16, 2023
52f7975
Remove old comment
khurram-ghani Oct 17, 2023
b2eb662
Tidy-up redundant expression
khurram-ghani Oct 17, 2023
523aaab
Keep full datasets along with filtered ones
khurram-ghani Oct 18, 2023
5b3ec0f
Make changes from PR feedback
khurram-ghani Nov 16, 2023
2c21f85
Merge remote-tracking branch 'origin/develop' into khurram/local_models
khurram-ghani Nov 22, 2023
e5cccc8
Address some of the recent feedback
khurram-ghani Nov 23, 2023
25da01b
Fix dataset mypy error
khurram-ghani Nov 23, 2023
292faaa
Copy dataset in optimizers to avoid changing it
khurram-ghani Nov 24, 2023
9170c64
Share DatasetChecker and tidy-up exp values in tests
khurram-ghani Nov 24, 2023
efd2fc0
Address more feedback
khurram-ghani Nov 24, 2023
ff84690
Merge remote-tracking branch 'origin/develop' into khurram/local_models
khurram-ghani Nov 29, 2023
505631a
Avoid default num_models in integ tests
khurram-ghani Nov 29, 2023
b90d3a1
Fix old python typing issue
khurram-ghani Nov 29, 2023
1de7e20
Merge remote-tracking branch 'origin/develop' into khurram/local_models
khurram-ghani Dec 13, 2023
a75e618
Use flatten_... func and add comment
khurram-ghani Dec 13, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions docs/notebooks/trust_region.pct.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,44 @@ def plot_history(result: trieste.bayesian_optimizer.OptimizationResult) -> None:
# %%
plot_history(result)

# %% [markdown]
# ## TEST
khurram-ghani marked this conversation as resolved.
Show resolved Hide resolved

# %%
num_query_points = 5

init_subspaces = [
trieste.acquisition.rule.SingleObjectiveTrustRegionBox(search_space)
for _ in range(num_query_points)
]
base_rule = trieste.acquisition.rule.EfficientGlobalOptimization(
builder=trieste.acquisition.ParallelContinuousThompsonSampling(),
num_query_points=1,
)
batch_acq_rule = trieste.acquisition.rule.BatchTrustRegionBox(
init_subspaces, base_rule
)

bo = trieste.bayesian_optimizer.BayesianOptimizer(observer, search_space)

num_steps = 5
result = bo.optimize(
num_steps,
{trieste.observer.OBJECTIVE: initial_data},
trieste.acquisition.utils.copy_to_local_models(
build_model(), num_query_points
),
batch_acq_rule,
track_state=True,
)
dataset = result.try_get_final_dataset()

# %%
plot_final_result(dataset)

# %%
plot_history(result)

# %% [markdown]
# ## Trust region `TurBO` acquisition rule
#
Expand Down
68 changes: 59 additions & 9 deletions tests/integration/test_ask_tell_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,22 +23,24 @@
import tensorflow as tf

from tests.util.misc import random_seed
from trieste.acquisition import LocalPenalization
from trieste.acquisition import LocalPenalization, ParallelContinuousThompsonSampling
from trieste.acquisition.rule import (
AcquisitionRule,
AsynchronousGreedy,
AsynchronousRuleState,
BatchTrustRegionBox,
EfficientGlobalOptimization,
SingleObjectiveTrustRegionBox,
TREGOBox,
)
from trieste.acquisition.utils import copy_to_local_models
from trieste.ask_tell_optimization import AskTellOptimizer
from trieste.bayesian_optimizer import OptimizationResult, Record
from trieste.logging import set_step_number, tensorboard_writer
from trieste.models import TrainableProbabilisticModel
from trieste.models.gpflow import GaussianProcessRegression, build_gpr
from trieste.objectives import ScaledBranin, SimpleQuadratic
from trieste.objectives.utils import mk_observer
from trieste.objectives.utils import mk_batch_observer, mk_observer
from trieste.observer import OBJECTIVE
from trieste.space import Box, SearchSpace
from trieste.types import State, TensorType
Expand All @@ -47,26 +49,58 @@
# We use a copy of these for a quicker test against a simple quadratic function
# (copying is necessary as some of the acquisition rules are stateful).
OPTIMIZER_PARAMS = (
"num_steps, reload_state, acquisition_rule_fn",
"num_steps, reload_state, acquisition_rule_fn, num_models",
[
pytest.param(
20, False, lambda: EfficientGlobalOptimization(), id="EfficientGlobalOptimization"
20, False, lambda: EfficientGlobalOptimization(), 1, id="EfficientGlobalOptimization"
),
pytest.param(
20,
True,
lambda: EfficientGlobalOptimization(),
1,
khurram-ghani marked this conversation as resolved.
Show resolved Hide resolved
id="EfficientGlobalOptimization/reload_state",
),
pytest.param(
15, False, lambda: BatchTrustRegionBox(TREGOBox(ScaledBranin.search_space)), id="TREGO"
15,
False,
lambda: BatchTrustRegionBox(TREGOBox(ScaledBranin.search_space)),
1,
id="TREGO",
),
pytest.param(
16,
True,
lambda: BatchTrustRegionBox(TREGOBox(ScaledBranin.search_space)),
1,
id="TREGO/reload_state",
),
pytest.param(
10,
False,
lambda: BatchTrustRegionBox(
[SingleObjectiveTrustRegionBox(ScaledBranin.search_space) for _ in range(3)],
EfficientGlobalOptimization(
ParallelContinuousThompsonSampling(),
num_query_points=3,
),
),
1,
id="BatchTrustRegionBox",
),
pytest.param(
10,
False,
lambda: BatchTrustRegionBox(
[SingleObjectiveTrustRegionBox(ScaledBranin.search_space) for _ in range(3)],
EfficientGlobalOptimization(
ParallelContinuousThompsonSampling(),
num_query_points=2,
),
),
3,
id="BatchTrustRegionBox/LocalModels",
),
pytest.param(
10,
False,
Expand All @@ -76,6 +110,7 @@
).using(OBJECTIVE),
num_query_points=3,
),
1,
id="LocalPenalization",
),
pytest.param(
Expand All @@ -86,6 +121,7 @@
ScaledBranin.search_space,
).using(OBJECTIVE),
),
1,
id="LocalPenalization/AsynchronousGreedy",
),
],
Expand All @@ -109,8 +145,11 @@ def test_ask_tell_optimizer_finds_minima_of_the_scaled_branin_function(
TrainableProbabilisticModel,
],
],
num_models: int,
) -> None:
_test_ask_tell_optimization_finds_minima(True, num_steps, reload_state, acquisition_rule_fn)
_test_ask_tell_optimization_finds_minima(
True, num_steps, reload_state, acquisition_rule_fn, num_models
)


@random_seed
Expand All @@ -129,11 +168,12 @@ def test_ask_tell_optimizer_finds_minima_of_simple_quadratic(
TrainableProbabilisticModel,
],
],
num_models: int,
) -> None:
# for speed reasons we sometimes test with a simple quadratic defined on the same search space
# branin; currently assume that every rule should be able to solve this in 5 steps
_test_ask_tell_optimization_finds_minima(
False, min(num_steps, 5), reload_state, acquisition_rule_fn
False, min(num_steps, 5), reload_state, acquisition_rule_fn, num_models
)


Expand All @@ -152,6 +192,7 @@ def _test_ask_tell_optimization_finds_minima(
TrainableProbabilisticModel,
],
],
num_models: int,
) -> None:
# For the case when optimization state is saved and reload on each iteration
# we need to use new acquisition function object to imitate real life usage
Expand All @@ -160,17 +201,22 @@ def _test_ask_tell_optimization_finds_minima(
search_space = ScaledBranin.search_space
initial_query_points = search_space.sample(5)
observer = mk_observer(ScaledBranin.objective if optimize_branin else SimpleQuadratic.objective)
batch_observer = mk_batch_observer(observer)
initial_data = observer(initial_query_points)

model = GaussianProcessRegression(
build_gpr(initial_data, search_space, likelihood_variance=1e-7)
)
models = copy_to_local_models(model, num_models) if num_models > 1 else {OBJECTIVE: model}
initial_dataset = {OBJECTIVE: initial_data}

with tempfile.TemporaryDirectory() as tmpdirname:
summary_writer = tf.summary.create_file_writer(tmpdirname)
with tensorboard_writer(summary_writer):
set_step_number(0)
ask_tell = AskTellOptimizer(search_space, initial_data, model, acquisition_rule_fn())
ask_tell = AskTellOptimizer(
search_space, initial_dataset, models, acquisition_rule_fn()
)

for i in range(1, num_steps + 1):
# two scenarios are tested here, depending on `reload_state` parameter
Expand All @@ -185,7 +231,11 @@ def _test_ask_tell_optimization_finds_minima(
] = ask_tell.to_record()
written_state = pickle.dumps(state)

new_data_point = observer(new_point)
# If query points are rank 3, then use a batched observer.
if tf.rank(new_point) == 3:
new_data_point = batch_observer(new_point)
else:
new_data_point = observer(new_point)

if reload_state:
state = pickle.loads(written_state)
Expand Down
Loading
Loading