Skip to content

Commit

Permalink
Also make TrainableProbabilisticModel pure
Browse files Browse the repository at this point in the history
  • Loading branch information
Uri Granta committed Nov 28, 2023
1 parent a7e41e4 commit 6a7ca8c
Show file tree
Hide file tree
Showing 8 changed files with 52 additions and 35 deletions.
8 changes: 5 additions & 3 deletions tests/unit/models/gpflow/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@
RandomFourierFeatureTrajectorySampler,
)
from trieste.models.optimizer import BatchOptimizer, DatasetTransformer, Optimizer
from trieste.models.utils import get_last_optimization_result, optimize_model_and_save_result
from trieste.space import Box
from trieste.types import TensorType
from trieste.utils import DEFAULTS
Expand Down Expand Up @@ -150,13 +151,14 @@ def test_gpflow_wrappers_default_optimize(
args = {}

loss = internal_model.training_loss(**args)
model.optimize_and_save_result(Dataset(*data))
optimize_model_and_save_result(model, Dataset(*data))

new_loss = internal_model.training_loss(**args)
assert new_loss < loss
if not isinstance(internal_model, SVGP):
assert model.last_optimization_result is not None
npt.assert_allclose(new_loss, model.last_optimization_result.fun)
optimization_result = get_last_optimization_result(model)
assert optimization_result is not None
npt.assert_allclose(new_loss, optimization_result.fun)


def test_gpflow_wrappers_ref_optimize(gpflow_interface_factory: ModelFactoryType) -> None:
Expand Down
13 changes: 9 additions & 4 deletions tests/unit/models/gpflux/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,11 @@
from trieste.models.gpflux import DeepGaussianProcess
from trieste.models.interfaces import HasTrajectorySampler
from trieste.models.optimizer import KerasOptimizer
from trieste.models.utils import get_module_with_variables
from trieste.models.utils import (
get_last_optimization_result,
get_module_with_variables,
optimize_model_and_save_result,
)
from trieste.types import TensorType


Expand Down Expand Up @@ -298,10 +302,11 @@ def test_deep_gaussian_process_with_lr_scheduler(
optimizer = KerasOptimizer(tf.optimizers.Adam(lr_schedule), fit_args)
model = DeepGaussianProcess(two_layer_model(x), optimizer)

model.optimize_and_save_result(Dataset(x, y))
optimize_model_and_save_result(model, Dataset(x, y))

assert model.last_optimization_result is not None
assert len(model.last_optimization_result.history["loss"]) == epochs
optimization_result = get_last_optimization_result(model)
assert optimization_result is not None
assert len(optimization_result.history["loss"]) == epochs


def test_deep_gaussian_process_default_optimizer_is_correct(
Expand Down
13 changes: 9 additions & 4 deletions tests/unit/models/keras/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,11 @@
sample_with_replacement,
)
from trieste.models.optimizer import KerasOptimizer, TrainingData
from trieste.models.utils import get_module_with_variables
from trieste.models.utils import (
get_last_optimization_result,
get_module_with_variables,
optimize_model_and_save_result,
)
from trieste.types import TensorType

_ENSEMBLE_SIZE = 3
Expand Down Expand Up @@ -216,10 +220,11 @@ def scheduler(epoch: int, lr: float) -> float:

npt.assert_allclose(model.model.optimizer.lr.numpy(), init_lr, rtol=1e-6)

model.optimize_and_save_result(example_data)
optimize_model_and_save_result(model, example_data)

assert model.last_optimization_result is not None
npt.assert_allclose(model.last_optimization_result.history["lr"], [0.5, 0.25])
optimization_result = get_last_optimization_result(model)
assert optimization_result is not None
npt.assert_allclose(optimization_result.history["lr"], [0.5, 0.25])
npt.assert_allclose(model.model.optimizer.lr.numpy(), init_lr, rtol=1e-6)


Expand Down
5 changes: 3 additions & 2 deletions tests/unit/models/test_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
TrainableSupportsPredictJoint,
TrainableSupportsPredictJointHasReparamSampler,
)
from trieste.models.utils import get_last_optimization_result, optimize_model_and_save_result
from trieste.types import TensorType


Expand Down Expand Up @@ -177,8 +178,8 @@ def _assert_data(self, dataset: Dataset) -> None:
stack = TrainableModelStack((model01, 2), (model2, 1), (model3, 1))
data = Dataset(tf.random.uniform([5, 7, 3]), tf.random.uniform([5, 7, 4]))
stack.update(data)
stack.optimize_and_save_result(data)
assert stack.last_optimization_result == [None] * 3
optimize_model_and_save_result(stack, data)
assert get_last_optimization_result(stack) == [None] * 3


def test_model_stack_reparam_sampler_raises_for_submodels_without_reparam_sampler() -> None:
Expand Down
6 changes: 4 additions & 2 deletions trieste/ask_tell_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from copy import deepcopy
from typing import Dict, Generic, Mapping, TypeVar, cast, overload

from .models.utils import optimize_model_and_save_result

try:
import pandas as pd
except ModuleNotFoundError:
Expand Down Expand Up @@ -233,7 +235,7 @@ def __init__(
for tag, model in self._models.items():
dataset = datasets[tag]
model.update(dataset)
model.optimize_and_save_result(dataset)
optimize_model_and_save_result(model, dataset)

summary_writer = logging.get_tensorboard_writer()
if summary_writer:
Expand Down Expand Up @@ -434,7 +436,7 @@ def tell(self, new_data: Mapping[Tag, Dataset] | Dataset) -> None:
for tag, model in self._models.items():
dataset = self._datasets[tag]
model.update(dataset)
model.optimize_and_save_result(dataset)
optimize_model_and_save_result(model, dataset)

summary_writer = logging.get_tensorboard_writer()
if summary_writer:
Expand Down
5 changes: 3 additions & 2 deletions trieste/bayesian_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from scipy.spatial.distance import pdist

from .acquisition.multi_objective import non_dominated
from .models.utils import optimize_model_and_save_result

try:
import pandas as pd
Expand Down Expand Up @@ -719,7 +720,7 @@ def optimize(
for tag, model in models.items():
dataset = datasets[tag]
model.update(dataset)
model.optimize_and_save_result(dataset)
optimize_model_and_save_result(model, dataset)
if summary_writer:
logging.set_step_number(0)
with summary_writer.as_default(step=0):
Expand Down Expand Up @@ -752,7 +753,7 @@ def optimize(
for tag, model in models.items():
dataset = datasets[tag]
model.update(dataset)
model.optimize_and_save_result(dataset)
optimize_model_and_save_result(model, dataset)

if summary_writer:
with summary_writer.as_default(step=step):
Expand Down
16 changes: 0 additions & 16 deletions trieste/models/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,22 +118,6 @@ def optimize(self, dataset: Dataset) -> Any:
"""
raise NotImplementedError

def optimize_and_save_result(self, dataset: Dataset) -> None:
    """
    Train the model on ``dataset`` and stash whatever the optimizer returns,
    so it can later be read back via ``last_optimization_result``.

    :param dataset: The data with which to train the model.
    """
    result = self.optimize(dataset)
    self._last_optimization_result = result

@property
def last_optimization_result(self) -> Optional[Any]:
    """
    The most recently saved (optimizer-specific) optimization result.
    """
    return self._last_optimization_result


@runtime_checkable
class SupportsPredictJoint(ProbabilisticModel, Protocol):
Expand Down
21 changes: 19 additions & 2 deletions trieste/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from __future__ import annotations

from typing import Any
from typing import Any, Optional

import gpflow
import tensorflow as tf
Expand All @@ -27,7 +27,7 @@
from .. import logging
from ..data import Dataset
from ..utils.misc import get_variables
from .interfaces import ProbabilisticModel
from .interfaces import ProbabilisticModel, TrainableProbabilisticModel


def write_summary_data_based_metrics(
Expand Down Expand Up @@ -120,3 +120,20 @@ def get_module_with_variables(model: ProbabilisticModel, *dependencies: Any) ->
for dependency in dependencies:
module.saved_variables += get_variables(dependency)
return module


def optimize_model_and_save_result(model: TrainableProbabilisticModel, dataset: Dataset) -> None:
    """
    Optimize the model objective and save the (optimizer-specific) optimization result
    in the model object. To access it, use ``get_last_optimization_result``.

    :param model: The model to train.
    :param dataset: The data with which to train the model.
    """
    # setattr (rather than direct assignment) since the attribute is deliberately
    # not part of the TrainableProbabilisticModel protocol.
    setattr(model, "_last_optimization_result", model.optimize(dataset))


def get_last_optimization_result(model: TrainableProbabilisticModel) -> Optional[Any]:
    """
    Return the last (optimizer-specific) optimization result saved on ``model`` by
    ``optimize_model_and_save_result``, or `None` if none has been saved.

    :param model: The model that was trained.
    :return: The last saved optimization result, if any, else `None`.
    """
    # Default to None so that a model never optimized via
    # optimize_model_and_save_result doesn't raise AttributeError here.
    return getattr(model, "_last_optimization_result", None)

0 comments on commit 6a7ca8c

Please sign in to comment.