Commit 0a8ffac

v0.8.0

frazane committed Nov 30, 2023
1 parent a6da548 commit 0a8ffac

Showing 5 changed files with 24 additions and 21 deletions.
gpjax/__init__.py (2 changes: 1 addition & 1 deletion)

@@ -81,7 +81,7 @@
 __description__ = "Didactic Gaussian processes in JAX"
 __url__ = "https://github.com/JaxGaussianProcesses/GPJax"
 __contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
-__version__ = "0.7.4"
+__version__ = "0.8.0"

 __all__ = [
     "Module",
gpjax/gps.py (34 changes: 18 additions & 16 deletions)

@@ -28,7 +28,9 @@
     Optional,
     Union,
 )
-import cola
+from cola.linalg.inverse.inv import solve
+from cola.annotations import PSD
+from cola.ops.operators import I_like
 from cola.linalg.decompositions.decompositions import Cholesky
 import jax.numpy as jnp
 from jax.random import (
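
For orientation, a minimal sketch (not part of the diff) of how the cola utilities imported above compose; the toy 2x2 matrix and jitter value are illustrative assumptions, and the call pattern mirrors the hunks below:

    import jax.numpy as jnp
    from cola.annotations import PSD
    from cola.linalg.decompositions.decompositions import Cholesky
    from cola.linalg.inverse.inv import solve
    from cola.ops.operators import Dense, I_like

    K = Dense(jnp.array([[1.0, 0.5], [0.5, 1.0]]))  # toy 2x2 Gram matrix
    K = PSD(K + I_like(K) * 1e-6)  # add diagonal jitter, annotate as PSD
    b = jnp.ones((2, 1))
    v = solve(K, b, Cholesky())  # solve K v = b via a Cholesky-based algorithm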
@@ -275,8 +277,8 @@ def predict(self, test_inputs: Num[Array, "N D"]) -> GaussianDistribution:
         x = test_inputs
         mx = self.mean_function(x)
         Kxx = self.kernel.gram(x)
-        Kxx += cola.ops.I_like(Kxx) * self.jitter
-        Kxx = cola.PSD(Kxx)
+        Kxx += I_like(Kxx) * self.jitter
+        Kxx = PSD(Kxx)

         return GaussianDistribution(jnp.atleast_1d(mx.squeeze()), Kxx)

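In math, the prior predictive assembled here is (with jitter ε added to the Gram matrix for numerical stability):

    f(x_\star) \sim \mathcal{N}\big( m(x_\star),\; K(x_\star, x_\star) + \epsilon I \big)
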
@@ -522,24 +524,24 @@ def predict(

         # Precompute Gram matrix, Kxx, at training inputs, x
         Kxx = self.prior.kernel.gram(x)
-        Kxx += cola.ops.I_like(Kxx) * self.jitter
+        Kxx += I_like(Kxx) * self.jitter

         # Σ = Kxx + Io²
-        Sigma = Kxx + cola.ops.I_like(Kxx) * obs_noise
-        Sigma = cola.PSD(Sigma)
+        Sigma = Kxx + I_like(Kxx) * obs_noise
+        Sigma = PSD(Sigma)

         mean_t = self.prior.mean_function(t)
         Ktt = self.prior.kernel.gram(t)
         Kxt = self.prior.kernel.cross_covariance(x, t)
-        Sigma_inv_Kxt = cola.solve(Sigma, Kxt)
+        Sigma_inv_Kxt = solve(Sigma, Kxt)

         # μt + Ktx (Kxx + Io²)⁻¹ (y - μx)
         mean = mean_t + jnp.matmul(Sigma_inv_Kxt.T, y - mx)

         # Ktt - Ktx (Kxx + Io²)⁻¹ Kxt, TODO: Take advantage of covariance structure to compute Schur complement more efficiently.
         covariance = Ktt - jnp.matmul(Kxt.T, Sigma_inv_Kxt)
-        covariance += cola.ops.I_like(covariance) * self.prior.jitter
-        covariance = cola.PSD(covariance)
+        covariance += I_like(covariance) * self.prior.jitter
+        covariance = PSD(covariance)

         return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), covariance)

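Written out, the comments above are the standard conjugate GP posterior predictive at test inputs t:

    \mu_\star = \mu_t + K_{tx} (K_{xx} + \sigma^2 I)^{-1} (y - \mu_x)
    \Sigma_\star = K_{tt} - K_{tx} (K_{xx} + \sigma^2 I)^{-1} K_{xt}
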
@@ -601,11 +603,11 @@ def sample_approx(
         # v = Σ⁻¹ (y + ε - Φω) for Σ = Kxx + Io² and ε ~ N(0, o²)
         obs_var = self.likelihood.obs_stddev**2
         Kxx = self.prior.kernel.gram(train_data.X)  # [N, N]
-        Sigma = Kxx + cola.ops.I_like(Kxx) * (obs_var + self.jitter)  # [N, N]
+        Sigma = Kxx + I_like(Kxx) * (obs_var + self.jitter)  # [N, N]
         eps = jnp.sqrt(obs_var) * normal(key, [train_data.n, num_samples])  # [N, B]
         y = train_data.y - self.prior.mean_function(train_data.X)  # account for mean
         Phi = fourier_feature_fn(train_data.X)
-        canonical_weights = cola.solve(
+        canonical_weights = solve(
             Sigma,
             y + eps - jnp.inner(Phi, fourier_weights),
             Cholesky(),
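
This is decoupled (pathwise) sampling: a random-Fourier-feature draw Φ(·)ω supplies an approximate prior sample, and the canonical weights v correct it through the data. Schematically, in the notation of the comment above:

    f(\cdot) \approx \Phi(\cdot)\,\omega + K(\cdot, X)\, v, \qquad
    v = \Sigma^{-1} (y + \varepsilon - \Phi \omega), \quad \Sigma = K_{xx} + \sigma^2 I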
@@ -684,8 +686,8 @@ def predict(

         # Precompute lower triangular of Gram matrix, Lx, at training inputs, x
         Kxx = kernel.gram(x)
-        Kxx += cola.ops.I_like(Kxx) * self.prior.jitter
-        Kxx = cola.PSD(Kxx)
+        Kxx += I_like(Kxx) * self.prior.jitter
+        Kxx = PSD(Kxx)
         Lx = lower_cholesky(Kxx)

         # Unpack test inputs
@@ -697,7 +699,7 @@
         mean_t = mean_function(t)

         # Lx⁻¹ Kxt
-        Lx_inv_Kxt = cola.solve(Lx, Ktx.T, Cholesky())
+        Lx_inv_Kxt = solve(Lx, Ktx.T, Cholesky())

         # Whitened function values, wx, corresponding to the inputs, x
         wx = self.latent
@@ -707,8 +709,8 @@
         # Ktt - Ktx Kxx⁻¹ Kxt, TODO: Take advantage of covariance structure to compute Schur complement more efficiently.
         covariance = Ktt - jnp.matmul(Lx_inv_Kxt.T, Lx_inv_Kxt)
-        covariance += cola.ops.I_like(covariance) * self.prior.jitter
-        covariance = cola.PSD(covariance)
+        covariance += I_like(covariance) * self.prior.jitter
+        covariance = PSD(covariance)

         return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), covariance)

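These hunks use the whitened parameterization: with K_{xx} = L_x L_x^T and whitened latent values w_x, the comments above suggest the predictive (the mean line itself is collapsed out of the diff, so its exact form is inferred):

    \mu_\star = \mu_t + (L_x^{-1} K_{xt})^\top w_x, \qquad
    \Sigma_\star = K_{tt} - (L_x^{-1} K_{xt})^\top (L_x^{-1} K_{xt})
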
gpjax/kernels/base.py (3 changes: 2 additions & 1 deletion)

@@ -29,6 +29,7 @@
     Num,
 )
 import tensorflow_probability.substrates.jax.distributions as tfd
+from cola.ops.operators import LinearOperator

 from gpjax.base import (
     Module,
@@ -60,7 +61,7 @@ def ndims(self):
     def cross_covariance(self, x: Num[Array, "N D"], y: Num[Array, "M D"]):
         return self.compute_engine.cross_covariance(self, x, y)

-    def gram(self, x: Num[Array, "N D"]):
+    def gram(self, x: Num[Array, "N D"]) -> LinearOperator:
         return self.compute_engine.gram(self, x)

     def slice_input(self, x: Float[Array, "... D"]) -> Float[Array, "... Q"]:
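
The new annotation documents that gram returns a lazy cola LinearOperator rather than a raw array. A minimal usage sketch; the RBF kernel and the to_dense call are assumptions about the surrounding API, not shown in this diff:

    import jax.numpy as jnp
    import gpjax as gpx

    kernel = gpx.kernels.RBF()  # assumed kernel class from GPJax
    x = jnp.linspace(0.0, 1.0, 5).reshape(-1, 1)
    Kxx = kernel.gram(x)  # a structured cola LinearOperator, not an ndarray
    K_dense = Kxx.to_dense()  # materialize as a [5, 5] array when needed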
gpjax/kernels/computations/base.py (4 changes: 2 additions & 2 deletions)

@@ -17,8 +17,8 @@
 from dataclasses import dataclass
 import typing as tp

-from cola import PSD
-from cola.ops import (
+from cola.annotations import PSD
+from cola.ops.operators import (
     Dense,
     Diagonal,
     LinearOperator,
pyproject.toml (2 changes: 1 addition & 1 deletion)

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "gpjax"
-version = "0.7.4"
+version = "0.8.0"
 description = "Gaussian processes in JAX."
 authors = [
     "Thomas Pinder <[email protected]>",