Skip to content

Commit

Permalink
Make get_module_with_variables a utility function
Browse files Browse the repository at this point in the history
  • Loading branch information
Uri Granta committed Nov 27, 2023
1 parent 4463f43 commit 1abeaa4
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 18 deletions.
4 changes: 3 additions & 1 deletion docs/notebooks/expected_improvement.pct.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,9 @@ def build_model(data):

# %%
# save the model to a given path, exporting just the predict method
module = result.try_get_final_model().get_module_with_variables()
from trieste.models.utils import get_module_with_variables

module = get_module_with_variables(result.try_get_final_model())
module.predict = tf.function(
model.predict,
input_signature=[tf.TensorSpec(shape=[None, 2], dtype=tf.float64)],
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/models/gpflow/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
)
from trieste.models.interfaces import HasTrajectorySampler
from trieste.models.optimizer import BatchOptimizer, Optimizer
from trieste.models.utils import get_module_with_variables
from trieste.types import TensorType


Expand Down Expand Up @@ -92,7 +93,7 @@ def test_gaussian_process_tf_saved_model(gpflow_interface_factory: ModelFactoryT
trajectory = trajectory_sampler.get_trajectory()

# generate client model with predict and sample methods
module = model.get_module_with_variables(trajectory_sampler, trajectory)
module = get_module_with_variables(model, trajectory_sampler, trajectory)
module.predict = tf.function(
model.predict, input_signature=[tf.TensorSpec(shape=[None, 1], dtype=tf.float64)]
)
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/models/gpflux/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from trieste.models.gpflux import DeepGaussianProcess
from trieste.models.interfaces import HasTrajectorySampler
from trieste.models.optimizer import KerasOptimizer
from trieste.models.utils import get_module_with_variables
from trieste.types import TensorType


Expand Down Expand Up @@ -395,7 +396,7 @@ def test_deepgp_tf_saved_model() -> None:
trajectory = trajectory_sampler.get_trajectory()

# generate client model with predict and sample methods
module = model.get_module_with_variables(trajectory_sampler, trajectory)
module = get_module_with_variables(model, trajectory_sampler, trajectory)
module.predict = tf.function(
model.predict, input_signature=[tf.TensorSpec(shape=[None, 1], dtype=tf.float64)]
)
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/models/keras/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
sample_with_replacement,
)
from trieste.models.optimizer import KerasOptimizer, TrainingData
from trieste.models.utils import get_module_with_variables
from trieste.types import TensorType

_ENSEMBLE_SIZE = 3
Expand Down Expand Up @@ -568,7 +569,7 @@ def test_deep_ensemble_tf_saved_model() -> None:
trajectory = trajectory_sampler.get_trajectory()

# generate client model with predict and sample methods
module = model.get_module_with_variables(trajectory_sampler, trajectory)
module = get_module_with_variables(model, trajectory_sampler, trajectory)
module.predict = tf.function(
model.predict, input_signature=[tf.TensorSpec(shape=[None, 3], dtype=tf.float64)]
)
Expand Down
14 changes: 0 additions & 14 deletions trieste/models/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from ..data import Dataset
from ..types import TensorType
from ..utils import DEFAULTS
from ..utils.misc import get_variables

ProbabilisticModelType = TypeVar(
"ProbabilisticModelType", bound="ProbabilisticModel", contravariant=True
Expand Down Expand Up @@ -93,19 +92,6 @@ def log(self, dataset: Optional[Dataset] = None) -> None:
"""
return

def get_module_with_variables(self, *dependencies: Any) -> tf.Module:
"""
Return a fresh module with the model's variables attached, which can then be extended
with methods and saved using tf.saved_model.
:param dependencies: Dependent objects whose variables should also be included.
"""
module = tf.Module()
module.saved_variables = get_variables(self)
for dependency in dependencies:
module.saved_variables += get_variables(dependency)
return module


@runtime_checkable
class TrainableProbabilisticModel(ProbabilisticModel, Protocol):
Expand Down
18 changes: 18 additions & 0 deletions trieste/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,15 @@

from __future__ import annotations

from typing import Any

import gpflow
import tensorflow as tf
from gpflow.utilities.traversal import _merge_leaf_components, leaf_components

from .. import logging
from ..data import Dataset
from ..utils.misc import get_variables
from .interfaces import ProbabilisticModel


Expand Down Expand Up @@ -102,3 +105,18 @@ def write_summary_likelihood_parameters(
for k, v in likelihood_components.items():
if v.trainable:
logging.scalar(f"{prefix}likelihood.{k}", v)


def get_module_with_variables(model: ProbabilisticModel, *dependencies: Any) -> tf.Module:
"""
Return a fresh module with a model's variables attached, which can then be extended
with methods and saved using tf.saved_model.
:param model: Model to extract variables from.
:param dependencies: Dependent objects whose variables should also be included.
"""
module = tf.Module()
module.saved_variables = get_variables(model)
for dependency in dependencies:
module.saved_variables += get_variables(dependency)
return module

0 comments on commit 1abeaa4

Please sign in to comment.