Skip to content

Commit

Permalink
Make log abstract
Browse files Browse the repository at this point in the history
  • Loading branch information
Uri Granta committed Nov 27, 2023
1 parent 1abeaa4 commit 967016d
Show file tree
Hide file tree
Showing 7 changed files with 27 additions and 2 deletions.
5 changes: 4 additions & 1 deletion tests/unit/models/gpflow/test_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from __future__ import annotations

from typing import Any
from typing import Any, Optional

import gpflow
import numpy.testing as npt
Expand All @@ -37,6 +37,9 @@ def optimize(self, dataset: Dataset) -> None:
def update(self, dataset: Dataset) -> None:
return

def log(self, dataset: Optional[Dataset] = None) -> None:
return


class _QuadraticGPModel(GPModel):
def __init__(self) -> None:
Expand Down
5 changes: 5 additions & 0 deletions tests/unit/models/gpflux/test_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

from __future__ import annotations

from typing import Optional

import gpflow
import numpy.testing as npt
import pytest
Expand Down Expand Up @@ -70,6 +72,9 @@ def sample(self, query_points: TensorType, num_samples: int) -> TensorType:
def update(self, dataset: Dataset) -> None:
return

def log(self, dataset: Optional[Dataset] = None) -> None:
return


class _QuadraticGPModel(DeepGP):
def __init__(
Expand Down
3 changes: 3 additions & 0 deletions tests/unit/test_bayesian_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,9 @@ def update(self, dataset: Dataset) -> NoReturn:
def optimize(self, dataset: Dataset) -> NoReturn:
assert False

def log(self, dataset: Optional[Dataset] = None) -> None:
return

class _UnusableRule(AcquisitionRule[NoReturn, Box, ProbabilisticModel]):
def acquire(
self,
Expand Down
3 changes: 3 additions & 0 deletions tests/util/models/gpflow/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,9 @@ def covariance_between_points(
]
return tf.concat(covs, axis=-3)

def log(self, dataset: Optional[Dataset] = None) -> None:
return


class GaussianProcessWithoutNoise(GaussianMarginal, SupportsPredictJoint, HasReparamSampler):
"""A (static) Gaussian process over a vector random variable with independent reparam sampler
Expand Down
6 changes: 6 additions & 0 deletions trieste/models/gpflow/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1659,6 +1659,9 @@ def covariance_with_top_fidelity(self, query_points: TensorType) -> TensorType:

return f_var

def log(self, dataset: Optional[Dataset] = None) -> None:
return


class MultifidelityNonlinearAutoregressive(
TrainableProbabilisticModel, SupportsPredictY, SupportsCovarianceWithTopFidelity
Expand Down Expand Up @@ -2033,3 +2036,6 @@ def covariance_with_top_fidelity(self, query_points: TensorType) -> TensorType:
cov = tfp.stats.covariance(signal_sample, max_fidelity_sample)[:, :, 0]

return cov

def log(self, dataset: Optional[Dataset] = None) -> None:
return
3 changes: 2 additions & 1 deletion trieste/models/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,14 @@ def sample(self, query_points: TensorType, num_samples: int) -> TensorType:
"""
raise NotImplementedError

@abstractmethod
def log(self, dataset: Optional[Dataset] = None) -> None:
"""
Log model-specific information at a given optimization step.
:param dataset: Optional data that can be used to log additional data-based model summaries.
"""
return
raise NotImplementedError


@runtime_checkable
Expand Down
4 changes: 4 additions & 0 deletions trieste/models/keras/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from check_shapes import inherit_check_shapes
from typing_extensions import Protocol, runtime_checkable

from ...data import Dataset
from ...types import TensorType
from ..interfaces import ProbabilisticModel
from ..optimizer import KerasOptimizer
Expand Down Expand Up @@ -76,6 +77,9 @@ def sample(self, query_points: TensorType, num_samples: int) -> TensorType:
"""
)

def log(self, dataset: Optional[Dataset] = None) -> None:
return


@runtime_checkable
class DeepEnsembleModel(ProbabilisticModel, Protocol):
Expand Down

0 comments on commit 967016d

Please sign in to comment.