diff --git a/gpjax/__init__.py b/gpjax/__init__.py
index 2d692cc74..0bb0a6bbb 100644
--- a/gpjax/__init__.py
+++ b/gpjax/__init__.py
@@ -40,7 +40,7 @@
 __description__ = "Didactic Gaussian processes in JAX"
 __url__ = "https://github.com/JaxGaussianProcesses/GPJax"
 __contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
-__version__ = "0.9.1"
+__version__ = "0.9.2"
 
 __all__ = [
     "base",
diff --git a/gpjax/gps.py b/gpjax/gps.py
index db11f8066..cca9ca365 100644
--- a/gpjax/gps.py
+++ b/gpjax/gps.py
@@ -17,6 +17,7 @@
 
 import beartype.typing as tp
 from cola.annotations import PSD
+from cola.linalg.algorithm_base import Algorithm
 from cola.linalg.decompositions.decompositions import Cholesky
 from cola.linalg.inverse.inv import solve
 from cola.ops.operators import I_like
@@ -530,6 +531,7 @@
         train_data: Dataset,
         key: KeyArray,
         num_features: int | None = 100,
+        solver_algorithm: tp.Optional[Algorithm] = Cholesky(),
     ) -> FunctionalSample:
         r"""Draw approximate samples from the Gaussian process posterior.
 
@@ -563,6 +565,11 @@
             key (KeyArray): The random seed used for the sample(s).
             num_features (int): The number of features used when approximating the
                 kernel.
+            solver_algorithm (Optional[Algorithm], optional): The algorithm used to solve
+                linear systems involving the covariance matrix. See the
+                [CoLA documentation](https://cola.readthedocs.io/en/latest/package/cola.linalg.html#algorithms)
+                for guidance on which solver to pick. For PSD matrices, CoLA currently recommends
+                Cholesky() for small matrices and CG() for larger ones; Auto() lets CoLA decide. Defaults to Cholesky().
 
         Returns:
             FunctionalSample: A function representing an approximate sample from the Gaussian
@@ -588,7 +595,7 @@
         canonical_weights = solve(
             Sigma,
             y + eps - jnp.inner(Phi, fourier_weights),
-            Cholesky(),
+            solver_algorithm,
         )  # [N, B]
 
         def sample_fn(test_inputs: Float[Array, "n D"]) -> Float[Array, "n B"]:
diff --git a/tests/test_gps.py b/tests/test_gps.py
index 708ae08c4..da85a6840 100644
--- a/tests/test_gps.py
+++ b/tests/test_gps.py
@@ -25,13 +25,15 @@
     Type,
 )
 
+from cola.linalg.algorithm_base import Auto
+from cola.linalg.decompositions.decompositions import Cholesky
+from cola.linalg.inverse.cg import CG
 from jax import config
 import jax.numpy as jnp
 import jax.random as jr
 import pytest
 import tensorflow_probability.substrates.jax.distributions as tfd
 
-# from gpjax.dataset import Dataset
 from gpjax.dataset import Dataset
 from gpjax.distributions import GaussianDistribution
 from gpjax.gps import (
@@ -283,7 +285,10 @@
 @pytest.mark.parametrize("num_datapoints", [1, 5])
 @pytest.mark.parametrize("kernel", [RBF, Matern52])
 @pytest.mark.parametrize("mean_function", [Zero, Constant])
-def test_conjugate_posterior_sample_approx(num_datapoints, kernel, mean_function):
+@pytest.mark.parametrize("solver_algorithm", [Cholesky(), CG(), Auto()])
+def test_conjugate_posterior_sample_approx(
+    num_datapoints, kernel, mean_function, solver_algorithm
+):
     kern = kernel(lengthscale=jnp.array([5.0, 1.0]), variance=0.1)
     p = Prior(kernel=kern, mean_function=mean_function()) * Gaussian(
         num_datapoints=num_datapoints
@@ -310,26 +315,28 @@
     # with pytest.raises(ValidationErrors):
     #     p.sample_approx(1, D, key, 0.5)
 
-    sampled_fn = p.sample_approx(1, D, key, 100)
+    sampled_fn = p.sample_approx(1, D, key, 100, solver_algorithm=solver_algorithm)
     assert isinstance(sampled_fn, Callable)  # check type
 
     x = jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(num_datapoints, 2))
     evals = sampled_fn(x)
     assert evals.shape == (num_datapoints, 1.0)  # check shape
 
-    sampled_fn_2 = p.sample_approx(1, D, key, 100)
+    sampled_fn_2 = p.sample_approx(1, D, key, 100, solver_algorithm=solver_algorithm)
     evals_2 = sampled_fn_2(x)
     max_delta = jnp.max(jnp.abs(evals - evals_2))
     assert max_delta == 0.0  # samples same for same seed
 
     new_key = jr.key(12345)
-    sampled_fn_3 = p.sample_approx(1, D, new_key, 100)
+    sampled_fn_3 = p.sample_approx(
+        1, D, new_key, 100, solver_algorithm=solver_algorithm
+    )
     evals_3 = sampled_fn_3(x)
     max_delta = jnp.max(jnp.abs(evals - evals_3))
     assert max_delta > 0.01  # samples different for different seed
 
     # Check validty of samples using Monte-Carlo
-    sampled_fn = p.sample_approx(10_000, D, key, 100)
+    sampled_fn = p.sample_approx(10_000, D, key, 100, solver_algorithm=solver_algorithm)
     sampled_evals = sampled_fn(x)
     approx_mean = jnp.mean(sampled_evals, -1)
     approx_var = jnp.var(sampled_evals, -1)
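
For reviewers, a minimal usage sketch of the new `solver_algorithm` argument; it is not part of the patch. It assumes the GPJax 0.9 public API (`gpx.Dataset`, `gpx.gps.Prior`, `gpx.likelihoods.Gaussian`, `gpx.kernels.RBF`), and the dataset and hyperparameters below are invented for illustration.

```python
# Illustrative sketch only: draw pathwise samples from a conjugate posterior,
# solving against the covariance matrix with CG rather than the default
# Cholesky. Data and hyperparameters are made up.
import jax.numpy as jnp
import jax.random as jr
from cola.linalg.inverse.cg import CG

import gpjax as gpx

key = jr.key(0)
key_x, key_y, key_s = jr.split(key, 3)

# Toy 1D regression dataset.
x = jr.uniform(key_x, minval=-3.0, maxval=3.0, shape=(50, 1))
y = jnp.sin(x) + 0.1 * jr.normal(key_y, shape=(50, 1))
D = gpx.Dataset(X=x, y=y)

# Conjugate posterior: GP prior multiplied by a Gaussian likelihood.
prior = gpx.gps.Prior(
    kernel=gpx.kernels.RBF(), mean_function=gpx.mean_functions.Zero()
)
posterior = prior * gpx.likelihoods.Gaussian(num_datapoints=D.n)

# Use CG, the iterative solver CoLA recommends for larger PSD matrices.
sample_fn = posterior.sample_approx(
    5, D, key_s, num_features=200, solver_algorithm=CG()
)
xtest = jnp.linspace(-3.0, 3.0, 100).reshape(-1, 1)
samples = sample_fn(xtest)  # shape [100, 5]: one column per sample
```

Because `solver_algorithm` defaults to `Cholesky()`, existing callers are unaffected; passing `CG()` or `Auto()` only changes how the solve against Sigma is performed.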