Skip to content

Commit

Permalink
Fix resampling test and mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
Uri Granta committed Oct 15, 2024
1 parent 7ddfb00 commit cf57ba4
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 13 deletions.
23 changes: 14 additions & 9 deletions tests/unit/models/gpflux/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from __future__ import annotations

from typing import Callable, Tuple
from typing import Callable, Sequence, Tuple
from unittest.mock import patch

import gpflow.kernels
Expand Down Expand Up @@ -508,8 +508,15 @@ def test_dgp_decoupled_layer_update_updates(

evals_1 = decoupled_layer(xs)

original_W = decoupled_layer._feature_functions.W.value().numpy()
original_b = decoupled_layer._feature_functions.b.value().numpy()
def get_values(x: tf.Variable | Sequence[tf.Variable]) -> Sequence[tf.Tensor]:
# weights and biases are either a single variable or a list of variables
if isinstance(x, tf.Variable):
x = [x]
return [x.value().numpy() for x in x]

original_W = get_values(decoupled_layer._feature_functions.W)
original_b = get_values(decoupled_layer._feature_functions.b)

for _ in range(5):
x_train = tf.random.uniform([20, 2], minval=-10.0, maxval=10.0, dtype=tf.float64)
y_train = tf.random.normal([20, 1], dtype=tf.float64)
Expand All @@ -522,9 +529,7 @@ def test_dgp_decoupled_layer_update_updates(
npt.assert_array_less(1e-2, tf.reduce_sum(tf.abs(evals_1 - evals_new)))

# Check that RFF weights change
npt.assert_array_less(
1e-2, tf.reduce_sum(tf.abs(original_b - decoupled_layer._feature_functions.b))
)
npt.assert_array_less(
1e-2, tf.reduce_sum(tf.abs(original_W - decoupled_layer._feature_functions.W))
)
for old_b, new_b in zip(original_b, get_values(decoupled_layer._feature_functions.b)):
npt.assert_array_less(1e-2, tf.reduce_sum(tf.abs(old_b - new_b)))
for old_W, new_W in zip(original_W, get_values(decoupled_layer._feature_functions.W)):
npt.assert_array_less(1e-2, tf.reduce_sum(tf.abs(old_W - new_W)))
4 changes: 2 additions & 2 deletions trieste/models/gpflow/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,11 +811,11 @@ def resample(self) -> None:
b.assign(self._bias_init(tf.shape(b), dtype=self._dtype))

if isinstance(self.W, tf.Variable):
self.W.assign(self._weights_init(self.kernel)(tf.shape(self.W), dtype=self._dtype))
self.W.assign(self._weights_init(self.kernel)(tf.shape(self.W), self._dtype))
else:
tf.debugging.Assert(isinstance(self.W, list), [])
for W, k in zip(self.W, cycle(self.sub_kernels)):
W.assign(self._weights_init(k)(tf.shape(W), dtype=self._dtype))
W.assign(self._weights_init(k)(tf.shape(W), self._dtype))


class ResampleableDecoupledFeatureFunctions(ResampleableRandomFourierFeatureFunctions):
Expand Down
4 changes: 2 additions & 2 deletions trieste/models/gpflux/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,11 +453,11 @@ def resample(self) -> None:
b.assign(self._bias_init(tf.shape(b), dtype=self._dtype))

if isinstance(self.W, tf.Variable):
self.W.assign(self._weights_init(self.kernel)(tf.shape(self.W), dtype=self._dtype))
self.W.assign(self._weights_init(self.kernel)(tf.shape(self.W), self._dtype))
else:
tf.debugging.Assert(isinstance(self.W, list), [])
for W, k in zip(self.W, cycle(self.sub_kernels)):
W.assign(self._weights_init(k)(tf.shape(W), dtype=self._dtype))
W.assign(self._weights_init(k)(tf.shape(W), self._dtype))

def __call__(self, x: TensorType) -> TensorType: # [N, D] -> [N, L + M] or [P, N, L + M]
"""
Expand Down

0 comments on commit cf57ba4

Please sign in to comment.