Skip to content

Commit

Permalink
Check whether register_keras_serializable is actually necessary
Browse files Browse the repository at this point in the history
  • Loading branch information
Uri Granta committed Sep 9, 2024
1 parent acfcf92 commit 59b4c47
Showing 1 changed file with 1 addition and 11 deletions.
12 changes: 1 addition & 11 deletions trieste/models/keras/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,14 @@

from __future__ import annotations

from typing import Callable, Optional
from typing import Optional

import tensorflow as tf
import tensorflow_probability as tfp
from gpflow.keras import tf_keras

from ...data import Dataset
from ...types import TensorType

try:
register_keras_serializable = tf_keras.saving.register_keras_serializable()
except AttributeError: # pragma: no cover (tested but not by coverage)

# not required in earlier version of TF
def register_keras_serializable(func: Callable[..., object]) -> Callable[..., object]:
return func


def get_tensor_spec_from_data(dataset: Dataset) -> tuple[tf.TensorSpec, tf.TensorSpec]:
r"""
Expand Down Expand Up @@ -131,7 +122,6 @@ def sample_model_index(
return indices


@register_keras_serializable
def negative_log_likelihood(
y_true: TensorType, y_pred: tfp.distributions.Distribution
) -> TensorType:
Expand Down

0 comments on commit 59b4c47

Please sign in to comment.