Skip to content

Commit

Permalink
Fix trajectory sampling to handle active dims (#790)
Browse files Browse the repository at this point in the history
* Fix trajectory sampling to handle active dims

* Add unit test

* Add docstring
  • Loading branch information
khurram-ghani authored Oct 16, 2023
1 parent 2807b76 commit aa94c86
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 2 deletions.
53 changes: 53 additions & 0 deletions tests/unit/models/gpflow/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import Any, Callable, List, Tuple, Type
from unittest.mock import MagicMock

import gpflow
import numpy.testing as npt
import pytest
import tensorflow as tf
Expand All @@ -30,6 +31,7 @@
GaussianProcess,
QuadraticMeanAndRBFKernel,
QuadraticMeanAndRBFKernelWithSamplers,
gpr_model,
quadratic_mean_rbf_kernel_model,
rbf,
svgp_model,
Expand All @@ -39,6 +41,7 @@
from trieste.models.gpflow import (
BatchReparametrizationSampler,
DecoupledTrajectorySampler,
GaussianProcessRegression,
IndependentReparametrizationSampler,
RandomFourierFeatureTrajectorySampler,
SparseVariational,
Expand Down Expand Up @@ -602,6 +605,56 @@ def test_rff_trajectory_update_trajectory_updates_and_doesnt_retrace(
assert trajectory.__call__._get_tracing_count() == 1 # type: ignore


@pytest.mark.parametrize(
"sampler_type", [RandomFourierFeatureTrajectorySampler, DecoupledTrajectorySampler]
)
@pytest.mark.parametrize(
"num_dimensions, active_dims",
[
(2, [0]),
(2, [1]),
(5, [1, 4]),
(5, [3, 2, 0]),
],
)
@random_seed
def test_trajectory_sampler_respects_active_dims(
sampler_type: Type[RandomFourierFeatureTrajectorySampler],
num_dimensions: int,
active_dims: List[int],
) -> None:
# Test that the trajectory sampler respects the active_dims setting in a GPflow model.
num_points = 10
query_points = tf.random.uniform((num_points, num_dimensions), dtype=tf.float64)
dataset = Dataset(query_points, quadratic(query_points))

model = GaussianProcessRegression(gpr_model(dataset.query_points, dataset.observations))
model.model.kernel = gpflow.kernels.Matern52(active_dims=active_dims)

num_active_dims = len(active_dims)
active_dims_mask = tf.scatter_nd(
tf.expand_dims(active_dims, -1), [True] * num_active_dims, (num_dimensions,)
)
x_rnd = tf.random.uniform((num_points, num_dimensions), dtype=tf.float64)
x_fix = tf.constant(0.5, shape=(num_points, num_dimensions), dtype=tf.float64)
# We vary values only on the irrelevant dimensions.
x_test = tf.where(active_dims_mask, x_fix, x_rnd)

batch_size = 2
x_test_with_batching = tf.expand_dims(x_test, -2)
x_test_with_batching = tf.tile(x_test_with_batching, [1, batch_size, 1]) # [N, B, D]
trajectory_sampler = sampler_type(model)
trajectory = trajectory_sampler.get_trajectory()

model_eval = trajectory(x_test_with_batching)
assert model_eval.shape == (num_points, batch_size, 1)
# The output should be constant since data only varies on irrelevant dimensions.
npt.assert_array_almost_equal(
tf.math.reduce_std(model_eval, axis=0),
tf.constant(0.0, shape=(batch_size, 1), dtype=tf.float64),
)


@pytest.mark.parametrize("num_features", [0, -2])
def test_decoupled_trajectory_sampler_raises_for_invalid_number_of_features(
num_features: int,
Expand Down
12 changes: 10 additions & 2 deletions trieste/models/gpflow/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,9 +328,9 @@ def _get_kernel_function(kernel: Kernel) -> Callable[[TensorType, TensorType], t
# Select between a multioutput kernel and a single-output kernel.
def K(X: TensorType, X2: Optional[TensorType] = None) -> tf.Tensor:
if _is_multioutput_kernel(kernel):
return kernel.K(X, X2, full_output_cov=False) # [L, M, M]
return kernel(X, X2, full_cov=True, full_output_cov=False) # [L, M, M]
else:
return tf.expand_dims(kernel.K(X, X2), axis=0) # [1, M, M]
return tf.expand_dims(kernel(X, X2), axis=0) # [1, M, M]

return K

Expand Down Expand Up @@ -781,6 +781,7 @@ def __init__(
dummy_X = model.get_inducing_variables()[0][0:1, :]
else:
dummy_X = model.get_internal_data().query_points[0:1, :]
dummy_X = self.kernel.slice(dummy_X, None)[0] # Keep only the active dims from the kernel.

# Always build the weights and biases. This is important for saving the trajectory (using
# tf.saved_model.save) before it has been used.
Expand All @@ -793,6 +794,13 @@ def resample(self) -> None:
self.b.assign(self._bias_init(tf.shape(self.b), dtype=self._dtype))
self.W.assign(self._weights_init(tf.shape(self.W), dtype=self._dtype))

def call(self, x: TensorType) -> TensorType: # [N, D] -> [N, F] or [L, N, F]
"""
Evaluate the basis functions at ``x``
"""
x = self.kernel.slice(x, None)[0] # Keep only the active dims from the kernel.
return super().call(x) # [N, F] or [L, N, F]


class ResampleableDecoupledFeatureFunctions(ResampleableRandomFourierFeatureFunctions):
"""
Expand Down

0 comments on commit aa94c86

Please sign in to comment.