Skip to content

Commit

Permalink
Fix param randomization, generating repeating values (#785)
Browse files Browse the repository at this point in the history
* Fix param randomization, generating repeating values

* Add comment explaining change
  • Loading branch information
khurram-ghani authored Sep 20, 2023
1 parent 963f1dc commit 3279d31
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 3 deletions.
32 changes: 30 additions & 2 deletions tests/unit/models/gpflow/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,14 +218,42 @@ def test_randomize_hyperparameters_samples_from_constraints_when_given_prior_and


@random_seed
def test_randomize_hyperparameters_samples_different_values_for_multi_dimensional_params() -> None:
@pytest.mark.parametrize("compile", [False, True])
def test_randomize_hyperparameters_generates_different_values_for_params(compile: bool) -> None:
dim = 2
kernel = gpflow.kernels.RBF(variance=1.0, lengthscales=[1.0] * dim)
model = gpflow.models.GPR(data=(tf.zeros((1, dim)), tf.zeros((1, 1))), kernel=kernel)
kernel.variance.prior = tfp.distributions.LogNormal(0.0, 2.0)
model.likelihood.variance.prior = tfp.distributions.LogNormal(0.0, 2.0)
model.kernel.lengthscales.prior = tfp.distributions.LogNormal(0.0, 2.0)

compiler = tf.function if compile else lambda x: x
compiler(randomize_hyperparameters)(model)

npt.assert_raises(AssertionError, npt.assert_allclose, kernel.lengthscales[0], kernel.variance)
npt.assert_raises(
AssertionError, npt.assert_allclose, kernel.lengthscales[0], model.likelihood.variance
)
npt.assert_raises(
AssertionError, npt.assert_allclose, kernel.variance, model.likelihood.variance
)


@random_seed
@pytest.mark.parametrize("compile", [False, True])
def test_randomize_hyperparameters_samples_different_values_for_multi_dimensional_params(
compile: bool,
) -> None:
kernel = gpflow.kernels.RBF(variance=1.0, lengthscales=[0.2, 0.2])
upper = tf.cast([10.0] * 2, dtype=tf.float64)
lower = upper / 100
kernel.lengthscales = gpflow.Parameter(
kernel.lengthscales, transform=tfp.bijectors.Sigmoid(low=lower, high=upper)
)
randomize_hyperparameters(kernel)

compiler = tf.function if compile else lambda x: x
compiler(randomize_hyperparameters)(kernel)

npt.assert_raises(
AssertionError, npt.assert_allclose, kernel.lengthscales[0], kernel.lengthscales[1]
)
Expand Down
6 changes: 5 additions & 1 deletion trieste/models/gpflow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,11 @@ def randomize_hyperparameters(object: gpflow.Module) -> None:
param.assign(sample)
elif param.prior is not None:
# handle constant priors for multi-dimensional parameters
if param.prior.batch_shape == param.prior.event_shape == [] and tf.rank(param) == 1:
# Use python conditionals here to avoid creating tensorflow `tf.cond` ops,
# i.e. using `len(param.shape)` instead of `tf.rank(param)`.
# Otherwise, tensorflow generates repeating random sequences for hyperparameters, see
# https://github.com/tensorflow/tensorflow/issues/61912.
if param.prior.batch_shape == param.prior.event_shape == [] and len(param.shape) == 1:
sample = param.prior.sample(tf.shape(param))
else:
sample = param.prior.sample()
Expand Down

0 comments on commit 3279d31

Please sign in to comment.