From 967016d1d8593babb3779f65f0f2e3d9b24f67e0 Mon Sep 17 00:00:00 2001 From: Uri Granta Date: Mon, 27 Nov 2023 13:39:10 +0000 Subject: [PATCH] Make log abstract --- tests/unit/models/gpflow/test_interface.py | 5 ++++- tests/unit/models/gpflux/test_interface.py | 5 +++++ tests/unit/test_bayesian_optimizer.py | 3 +++ tests/util/models/gpflow/models.py | 3 +++ trieste/models/gpflow/models.py | 6 ++++++ trieste/models/interfaces.py | 3 ++- trieste/models/keras/interface.py | 4 ++++ 7 files changed, 27 insertions(+), 2 deletions(-) diff --git a/tests/unit/models/gpflow/test_interface.py b/tests/unit/models/gpflow/test_interface.py index f81c6c2a8e..d1f117d790 100644 --- a/tests/unit/models/gpflow/test_interface.py +++ b/tests/unit/models/gpflow/test_interface.py @@ -14,7 +14,7 @@ from __future__ import annotations -from typing import Any +from typing import Any, Optional import gpflow import numpy.testing as npt @@ -37,6 +37,9 @@ def optimize(self, dataset: Dataset) -> None: def update(self, dataset: Dataset) -> None: return + def log(self, dataset: Optional[Dataset] = None) -> None: + return + class _QuadraticGPModel(GPModel): def __init__(self) -> None: diff --git a/tests/unit/models/gpflux/test_interface.py b/tests/unit/models/gpflux/test_interface.py index 910a208033..b7b5fb860f 100644 --- a/tests/unit/models/gpflux/test_interface.py +++ b/tests/unit/models/gpflux/test_interface.py @@ -14,6 +14,8 @@ from __future__ import annotations +from typing import Optional + import gpflow import numpy.testing as npt import pytest @@ -70,6 +72,9 @@ def sample(self, query_points: TensorType, num_samples: int) -> TensorType: def update(self, dataset: Dataset) -> None: return + def log(self, dataset: Optional[Dataset] = None) -> None: + return + class _QuadraticGPModel(DeepGP): def __init__( diff --git a/tests/unit/test_bayesian_optimizer.py b/tests/unit/test_bayesian_optimizer.py index 9e15032d6b..99017dff90 100644 --- a/tests/unit/test_bayesian_optimizer.py +++ b/tests/unit/test_bayesian_optimizer.py @@ -491,6 +491,9 @@ def update(self, dataset: Dataset) -> NoReturn: def optimize(self, dataset: Dataset) -> NoReturn: assert False + def log(self, dataset: Optional[Dataset] = None) -> None: + return + class _UnusableRule(AcquisitionRule[NoReturn, Box, ProbabilisticModel]): def acquire( self, diff --git a/tests/util/models/gpflow/models.py b/tests/util/models/gpflow/models.py index 1d25fdcdfe..0cd1937f0a 100644 --- a/tests/util/models/gpflow/models.py +++ b/tests/util/models/gpflow/models.py @@ -121,6 +121,9 @@ def covariance_between_points( ] return tf.concat(covs, axis=-3) + def log(self, dataset: Optional[Dataset] = None) -> None: + return + class GaussianProcessWithoutNoise(GaussianMarginal, SupportsPredictJoint, HasReparamSampler): """A (static) Gaussian process over a vector random variable with independent reparam sampler diff --git a/trieste/models/gpflow/models.py b/trieste/models/gpflow/models.py index fedc993c40..e748d048d2 100644 --- a/trieste/models/gpflow/models.py +++ b/trieste/models/gpflow/models.py @@ -1659,6 +1659,9 @@ def covariance_with_top_fidelity(self, query_points: TensorType) -> TensorType: return f_var + def log(self, dataset: Optional[Dataset] = None) -> None: + return + class MultifidelityNonlinearAutoregressive( TrainableProbabilisticModel, SupportsPredictY, SupportsCovarianceWithTopFidelity @@ -2033,3 +2036,6 @@ def covariance_with_top_fidelity(self, query_points: TensorType) -> TensorType: cov = tfp.stats.covariance(signal_sample, max_fidelity_sample)[:, :, 0] return cov + + def log(self, dataset: Optional[Dataset] = None) -> None: + return diff --git a/trieste/models/interfaces.py b/trieste/models/interfaces.py index 78c6e62e5e..99827903e0 100644 --- a/trieste/models/interfaces.py +++ b/trieste/models/interfaces.py @@ -84,13 +84,14 @@ def sample(self, query_points: TensorType, num_samples: int) -> TensorType: """ raise NotImplementedError + @abstractmethod def log(self, dataset: Optional[Dataset] = None) -> None: """ Log model-specific information at a given optimization step. :param dataset: Optional data that can be used to log additional data-based model summaries. """ - return + raise NotImplementedError @runtime_checkable diff --git a/trieste/models/keras/interface.py b/trieste/models/keras/interface.py index 8f3c28c927..96b222cccd 100644 --- a/trieste/models/keras/interface.py +++ b/trieste/models/keras/interface.py @@ -22,6 +22,7 @@ from check_shapes import inherit_check_shapes from typing_extensions import Protocol, runtime_checkable +from ...data import Dataset from ...types import TensorType from ..interfaces import ProbabilisticModel from ..optimizer import KerasOptimizer @@ -76,6 +77,9 @@ def sample(self, query_points: TensorType, num_samples: int) -> TensorType: """ ) + def log(self, dataset: Optional[Dataset] = None) -> None: + return + @runtime_checkable class DeepEnsembleModel(ProbabilisticModel, Protocol):