optional solver_algorithm parameter to sample_approx #478

Merged

Changes from 1 commit
9 changes: 8 additions & 1 deletion gpjax/gps.py
@@ -17,6 +17,7 @@

 import beartype.typing as tp
 from cola.annotations import PSD
+from cola.linalg.algorithm_base import Algorithm
 from cola.linalg.decompositions.decompositions import Cholesky
 from cola.linalg.inverse.inv import solve
 from cola.ops.operators import I_like
@@ -530,6 +531,7 @@ def sample_approx(
         train_data: Dataset,
         key: KeyArray,
         num_features: int | None = 100,
+        solver_algorithm: tp.Optional[Algorithm] = Cholesky(),
     ) -> FunctionalSample:
         r"""Draw approximate samples from the Gaussian process posterior.

@@ -563,6 +565,11 @@ def sample_approx(
            key (KeyArray): The random seed used for the sample(s).
            num_features (int): The number of features used when approximating the
                kernel.
+            solver_algorithm (Optional[Algorithm], optional): The algorithm used to solve
+                linear systems against the covariance matrix. See the
+                [CoLA documentation](https://cola.readthedocs.io/en/latest/package/cola.linalg.html#algorithms)
+                for guidance: for PSD matrices, CoLA currently recommends Cholesky() for
+                small matrices and CG() for larger ones; select Auto() to let CoLA decide. Defaults to Cholesky().

        Returns:
            FunctionalSample: A function representing an approximate sample from the Gaussian
@@ -588,7 +595,7 @@ def sample_approx(
         canonical_weights = solve(
             Sigma,
             y + eps - jnp.inner(Phi, fourier_weights),
-            Cholesky(),
+            solver_algorithm,
         )  # [N, B]

         def sample_fn(test_inputs: Float[Array, "n D"]) -> Float[Array, "n B"]:
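For orientation, here is a minimal sketch of what the new argument controls: CoLA's `solve` dispatches on the `Algorithm` passed as its third argument, so `Cholesky()`, `CG()`, and `Auto()` are interchangeable ways to solve against a PSD operator like `Sigma` above. The toy matrix, jitter, and tolerance below are illustrative assumptions, not part of this PR.

```python
import jax.numpy as jnp
from cola.annotations import PSD
from cola.linalg.algorithm_base import Auto
from cola.linalg.decompositions.decompositions import Cholesky
from cola.linalg.inverse.cg import CG
from cola.linalg.inverse.inv import solve
from cola.ops.operators import Dense

# Build a small PSD operator: A = M Mᵀ + jitter·I (toy values for illustration).
M = jnp.array([[2.0, 0.0], [1.0, 3.0]])
A = PSD(Dense(M @ M.T + 1e-6 * jnp.eye(2)))
b = jnp.array([1.0, -1.0])

# The same solve, routed through three different algorithms.
x_chol = solve(A, b, Cholesky())  # direct: suited to small matrices
x_cg = solve(A, b, CG())          # iterative: scales to larger matrices
x_auto = solve(A, b, Auto())      # let CoLA pick based on the operator

assert jnp.allclose(x_chol, x_cg, atol=1e-4)
```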
20 changes: 14 additions & 6 deletions tests/test_gps.py
@@ -25,13 +25,16 @@
     Type,
 )

+# from gpjax.dataset import Dataset
+from cola.linalg.algorithm_base import Auto
+from cola.linalg.decompositions.decompositions import Cholesky
+from cola.linalg.inverse.cg import CG
 from jax import config
 import jax.numpy as jnp
 import jax.random as jr
 import pytest
 import tensorflow_probability.substrates.jax.distributions as tfd

-# from gpjax.dataset import Dataset
 from gpjax.dataset import Dataset
 from gpjax.distributions import GaussianDistribution
 from gpjax.gps import (
@@ -283,7 +286,10 @@ def test_prior_sample_approx(num_datapoints, kernel, mean_function):
 @pytest.mark.parametrize("num_datapoints", [1, 5])
 @pytest.mark.parametrize("kernel", [RBF, Matern52])
 @pytest.mark.parametrize("mean_function", [Zero, Constant])
-def test_conjugate_posterior_sample_approx(num_datapoints, kernel, mean_function):
+@pytest.mark.parametrize("solver_algorithm", [Cholesky(), CG(), Auto()])
+def test_conjugate_posterior_sample_approx(
+    num_datapoints, kernel, mean_function, solver_algorithm
+):
     kern = kernel(lengthscale=jnp.array([5.0, 1.0]), variance=0.1)
     p = Prior(kernel=kern, mean_function=mean_function()) * Gaussian(
         num_datapoints=num_datapoints
@@ -310,26 +316,28 @@ def test_conjugate_posterior_sample_approx(
     # with pytest.raises(ValidationErrors):
     #     p.sample_approx(1, D, key, 0.5)

-    sampled_fn = p.sample_approx(1, D, key, 100)
+    sampled_fn = p.sample_approx(1, D, key, 100, solver_algorithm=solver_algorithm)
     assert isinstance(sampled_fn, Callable)  # check type

     x = jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(num_datapoints, 2))
     evals = sampled_fn(x)
     assert evals.shape == (num_datapoints, 1)  # check shape

-    sampled_fn_2 = p.sample_approx(1, D, key, 100)
+    sampled_fn_2 = p.sample_approx(1, D, key, 100, solver_algorithm=solver_algorithm)
     evals_2 = sampled_fn_2(x)
     max_delta = jnp.max(jnp.abs(evals - evals_2))
     assert max_delta == 0.0  # samples same for same seed

     new_key = jr.key(12345)
-    sampled_fn_3 = p.sample_approx(1, D, new_key, 100)
+    sampled_fn_3 = p.sample_approx(
+        1, D, new_key, 100, solver_algorithm=solver_algorithm
+    )
     evals_3 = sampled_fn_3(x)
     max_delta = jnp.max(jnp.abs(evals - evals_3))
     assert max_delta > 0.01  # samples different for different seed

     # Check validity of samples using Monte-Carlo
-    sampled_fn = p.sample_approx(10_000, D, key, 100)
+    sampled_fn = p.sample_approx(10_000, D, key, 100, solver_algorithm=solver_algorithm)
     sampled_evals = sampled_fn(x)
     approx_mean = jnp.mean(sampled_evals, -1)
     approx_var = jnp.var(sampled_evals, -1)
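Taken together, the change lets a user thread a CoLA solver through approximate posterior sampling. A hedged end-to-end sketch follows: the dataset, kernel choice, and shapes are assumptions for illustration, and only `sample_approx` and its new `solver_algorithm` keyword come from this PR.

```python
import gpjax as gpx
import jax.numpy as jnp
import jax.random as jr
from cola.linalg.inverse.cg import CG

key = jr.key(0)

# Toy 1-D training data (illustrative only).
x = jnp.linspace(-3.0, 3.0, 50).reshape(-1, 1)
y = jnp.sin(x)
D = gpx.Dataset(X=x, y=y)

prior = gpx.gps.Prior(
    mean_function=gpx.mean_functions.Zero(),
    kernel=gpx.kernels.RBF(),
)
posterior = prior * gpx.likelihoods.Gaussian(num_datapoints=D.n)

# Draw 5 approximate posterior samples, routing the covariance solve through CG.
sample_fn = posterior.sample_approx(5, D, key, num_features=100, solver_algorithm=CG())
evals = sample_fn(x)  # shape (50, 5): one column per sample
```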