CoLA integration #370

Merged: 16 commits, Sep 6, 2023
2 changes: 1 addition & 1 deletion .github/workflows/build_docs.yml
@@ -19,7 +19,7 @@ jobs:
strategy:
matrix:
os: ["ubuntu-latest"]
python-version: ["3.8"]
python-version: ["3.10"]

steps:
      # Grab the latest commit from the branch
2 changes: 1 addition & 1 deletion .github/workflows/integration.yml
@@ -13,7 +13,7 @@ jobs:
matrix:
# Select the Python versions to test against
os: ["ubuntu-latest", "macos-latest"]
python-version: ["3.8", "3.9", "3.10", "3.11"]
python-version: ["3.10", "3.11"]
fail-fast: true
steps:
- name: Check out the code
2 changes: 1 addition & 1 deletion .github/workflows/test_docs.yml
@@ -17,7 +17,7 @@ jobs:
strategy:
matrix:
os: ["ubuntu-latest"]
python-version: ["3.8"]
python-version: ["3.10"]

steps:
      # Grab the latest commit from the branch
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
@@ -13,7 +13,7 @@ jobs:
matrix:
# Select the Python versions to test against
os: ["ubuntu-latest", "macos-latest"]
python-version: ["3.8", "3.9", "3.10", "3.11"]
python-version: ["3.10", "3.11"]
fail-fast: true
steps:
- name: Check out the code
779 changes: 779 additions & 0 deletions docs/examples/classification.ipynb

Large diffs are not rendered by default.

15 changes: 8 additions & 7 deletions docs/examples/classification.py
@@ -207,13 +207,17 @@
# datapoints below.

# %%
import cola
from gpjax.lower_cholesky import lower_cholesky

gram, cross_covariance = (kernel.gram, kernel.cross_covariance)
jitter = 1e-6

# Compute (latent) function value map estimates at training points:
Kxx = opt_posterior.prior.kernel.gram(x)
Kxx += identity_matrix(D.n) * jitter
Lx = Kxx.to_root()
Kxx = cola.PSD(Kxx)
Lx = lower_cholesky(Kxx)
f_hat = Lx @ opt_posterior.latent

# Negative Hessian, H = -∇²p_tilde(y|f):
@@ -250,16 +254,13 @@ def construct_laplace(test_inputs: Float[Array, "N D"]) -> tfd.MultivariateNorma
Kxt = opt_posterior.prior.kernel.cross_covariance(x, test_inputs)
Kxx = opt_posterior.prior.kernel.gram(x)
Kxx += identity_matrix(D.n) * jitter
Lx = Kxx.to_root()

# Lx⁻¹ Kxt
Lx_inv_Ktx = Lx.solve(Kxt)
Kxx = cola.PSD(Kxx)

# Kxx⁻¹ Kxt
Kxx_inv_Ktx = Lx.T.solve(Lx_inv_Ktx)
Kxx_inv_Kxt = cola.solve(Kxx, Kxt)

# Ktx Kxx⁻¹[ H⁻¹ ] Kxx⁻¹ Kxt
laplace_cov_term = jnp.matmul(jnp.matmul(Kxx_inv_Ktx.T, H_inv), Kxx_inv_Ktx)
laplace_cov_term = jnp.matmul(jnp.matmul(Kxx_inv_Kxt.T, H_inv), Kxx_inv_Kxt)

mean = map_latent_dist.mean()
covariance = map_latent_dist.covariance() + laplace_cov_term
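
Note (illustrative, not part of the diff): a minimal sketch of the CoLA pattern the example now follows, with a small hypothetical gram matrix standing in for kernel.gram(x):

import cola
import jax.numpy as jnp
from cola.ops import Dense
from gpjax.lower_cholesky import lower_cholesky

jitter = 1e-6
# Hypothetical 5x5 gram matrix and 5x3 cross-covariance, purely for illustration.
Kxx = cola.PSD(Dense(jnp.eye(5) + jitter * jnp.eye(5)))  # tag the operator as PSD
Kxt = jnp.ones((5, 3))

Lx = lower_cholesky(Kxx)            # lower Cholesky root of the PSD operator
Kxx_inv_Kxt = cola.solve(Kxx, Kxt)  # Kxx⁻¹ Kxt in one call, replacing the two triangular solves
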
2 changes: 1 addition & 1 deletion docs/examples/constructing_new_kernels.py
@@ -108,7 +108,7 @@
# like our RBF kernel to act on the first, second and fourth dimensions.

# %%
slice_kernel = gpx.kernels.RBF(active_dims=[0, 1, 3], lengthscale = jnp.ones((3,)))
slice_kernel = gpx.kernels.RBF(active_dims=[0, 1, 3], lengthscale=jnp.ones((3,)))

# %% [markdown]
#
2 changes: 1 addition & 1 deletion gpjax/__init__.py
@@ -29,8 +29,8 @@
RFF,
AbstractKernel,
BasisFunctionComputation,
ConstantDiagonalKernelComputation,
CatKernel,
ConstantDiagonalKernelComputation,
DenseKernelComputation,
DiagonalKernelComputation,
EigenKernelComputation,
44 changes: 23 additions & 21 deletions gpjax/citation.py
@@ -2,13 +2,13 @@
dataclass,
fields,
)
from functools import singledispatch

from beartype.typing import (
Dict,
Union,
)
from jaxlib.xla_extension import PjitFunction
from plum import dispatch

from gpjax.kernels import (
RFF,
@@ -26,8 +26,6 @@
NonConjugateMLL,
)

MaternKernels = Union[Matern12, Matern32, Matern52]
MLLs = Union[ConjugateMLL, NonConjugateMLL, LogPosteriorDensity]
CitationType = Union[str, Dict[str, str]]


@@ -89,24 +87,26 @@ class BookCitation(AbstractCitation):
####################
# Default citation
####################
@dispatch
def cite(tree) -> NullCitation:
@singledispatch
def cite(tree) -> AbstractCitation:
return NullCitation()


####################
# Default citation
####################
@dispatch
def cite(tree: PjitFunction) -> JittedFnCitation:
@cite.register(PjitFunction)
def _(tree):
return JittedFnCitation()


####################
# Kernel citations
####################
@dispatch
def cite(tree: MaternKernels) -> PhDThesisCitation:
@cite.register(Matern12)
@cite.register(Matern32)
@cite.register(Matern52)
def _(tree) -> PhDThesisCitation:
citation = PhDThesisCitation(
citation_key="matern1960SpatialV",
authors="Bertil Matérn",
@@ -121,8 +121,8 @@ def cite(tree: MaternKernels) -> PhDThesisCitation:
return citation


@dispatch
def cite(tree: ArcCosine) -> PaperCitation:
@cite.register(ArcCosine)
def _(_) -> PaperCitation:
return PaperCitation(
citation_key="cho2009kernel",
authors="Cho, Youngmin and Saul, Lawrence",
@@ -132,8 +132,8 @@ def cite(tree: ArcCosine) -> PaperCitation:
)


@dispatch
def cite(tree: GraphKernel) -> PaperCitation:
@cite.register(GraphKernel)
def _(tree) -> PaperCitation:
return PaperCitation(
citation_key="borovitskiy2021matern",
title="Matérn Gaussian Processes on Graphs",
@@ -146,8 +146,8 @@ def cite(tree: GraphKernel) -> PaperCitation:
)


@dispatch
def cite(tree: RFF) -> PaperCitation:
@cite.register(RFF)
def _(tree) -> PaperCitation:
return PaperCitation(
citation_key="rahimi2007random",
authors="Rahimi, Ali and Recht, Benjamin",
@@ -161,8 +161,8 @@ def cite(tree: RFF) -> PaperCitation:
####################
# Objective citations
####################
@dispatch
def cite(tree: MLLs) -> BookCitation:
@cite.register(ConjugateMLL)
@cite.register(NonConjugateMLL)
@cite.register(LogPosteriorDensity)
def _(tree) -> BookCitation:
return BookCitation(
citation_key="rasmussen2006gaussian",
title="Gaussian Processes for Machine Learning",
@@ -173,8 +175,8 @@ def cite(tree: MLLs) -> BookCitation:
)


@dispatch
def cite(tree: CollapsedELBO) -> PaperCitation:
@cite.register(CollapsedELBO)
def _(tree) -> PaperCitation:
return PaperCitation(
citation_key="titsias2009variational",
title="Variational learning of inducing variables in sparse Gaussian processes",
@@ -184,8 +186,8 @@ def cite(tree: CollapsedELBO) -> PaperCitation:
)


@dispatch
def cite(tree: ELBO) -> PaperCitation:
@cite.register(ELBO)
def _(tree) -> PaperCitation:
return PaperCitation(
citation_key="hensman2013gaussian",
title="Gaussian Processes for Big Data",
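
Note (illustrative, not part of the diff): the singledispatch pattern adopted above, shown in isolation with a hypothetical type registered against the fallback:

from functools import singledispatch


@singledispatch
def cite(tree) -> str:
    # Fallback for unregistered types, analogous to NullCitation.
    return "No citation found."


class MyKernel:
    """Hypothetical user-defined kernel; not part of GPJax."""


@cite.register(MyKernel)
def _(tree) -> str:
    return "mykernel2023"


print(cite(MyKernel()))  # "mykernel2023"
print(cite(object()))    # falls back to the default: "No citation found."
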
58 changes: 34 additions & 24 deletions gpjax/gaussian_distribution.py
@@ -19,6 +19,11 @@
Optional,
Tuple,
)
import cola
from cola.ops import (
Dense,
Identity,
)
from jax import vmap
import jax.numpy as jnp
import jax.random as jr
@@ -28,10 +33,7 @@
)
import tensorflow_probability.substrates.jax as tfp

from gpjax.linops import (
IdentityLinearOperator,
LinearOperator,
)
from gpjax.lower_cholesky import lower_cholesky
from gpjax.typing import (
Array,
KeyArray,
@@ -49,15 +51,15 @@ def _check_loc_scale(loc: Optional[Any], scale: Optional[Any]) -> None:
if loc is not None and loc.ndim < 1:
raise ValueError("The parameter `loc` must have at least one dimension.")

if scale is not None and scale.ndim < 2:
if scale is not None and len(scale.shape) < 2: # scale.ndim < 2:
raise ValueError(
"The `scale` must have at least two dimensions, but "
f"`scale.shape = {scale.shape}`."
)

if scale is not None and not isinstance(scale, LinearOperator):
if scale is not None and not isinstance(scale, cola.LinearOperator):
raise ValueError(
f"scale must be a LinearOperator or a JAX array, but got {type(scale)}"
f"The `scale` must be a cola.LinearOperator but got {type(scale)}"
)

if scale is not None and (scale.shape[-1] != scale.shape[-2]):
@@ -79,7 +81,7 @@ class GaussianDistribution(tfd.Distribution):

Args:
loc (Optional[Float[Array, " N"]]): The mean of the distribution. Defaults to None.
scale (Optional[LinearOperator]): The scale matrix of the distribution. Defaults to None.
scale (Optional[cola.LinearOperator]): The scale matrix of the distribution. Defaults to None.

Returns
-------
@@ -94,7 +96,7 @@ class GaussianDistribution(tfd.Distribution):
def __init__(
self,
loc: Optional[Float[Array, " N"]] = None,
scale: Optional[LinearOperator] = None,
scale: Optional[cola.LinearOperator] = None,
) -> None:
r"""Initialises the distribution."""
_check_loc_scale(loc, scale)
@@ -112,10 +114,10 @@ def __init__(

# If not specified, set the scale to the identity matrix.
if scale is None:
scale = IdentityLinearOperator(num_dims)
scale = Identity(shape=(num_dims, num_dims), dtype=loc.dtype)

self.loc = loc
self.scale = scale
self.scale = cola.PSD(scale)

def mean(self) -> Float[Array, " N"]:
r"""Calculates the mean."""
@@ -135,11 +137,11 @@ def covariance(self) -> Float[Array, "N N"]:

def variance(self) -> Float[Array, " N"]:
r"""Calculates the variance."""
return self.scale.diagonal()
return cola.diag(self.scale)
Comment on lines -138 to +140 (Member Author):

Check: this probably has edge-case behaviour, as cola.diag switches between a diagonal (vector) and a (dense) diagonal matrix depending on its input, while the old .diagonal() was strictly a diagonal array.

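A minimal sketch of the distinction being flagged (illustrative only; assumes cola.diag follows the numpy-style convention of extracting the diagonal from an operator but constructing a diagonal matrix from a vector):

import cola
import jax.numpy as jnp
from cola.ops import Dense

scale = cola.PSD(Dense(2.0 * jnp.eye(3)))
d = cola.diag(scale)  # diagonal of an operator -> array of shape (3,), matching the old .diagonal()
# cola.diag(d)        # on a vector, diag would instead build a diagonal matrix/operator
#                     # (the edge case noted above)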

def stddev(self) -> Float[Array, " N"]:
r"""Calculates the standard deviation."""
return jnp.sqrt(self.scale.diagonal())
return jnp.sqrt(cola.diag(self.scale))

@property
def event_shape(self) -> Tuple:
@@ -149,7 +151,10 @@ def event_shape(self) -> Tuple:
def entropy(self) -> ScalarFloat:
r"""Calculates the entropy of the distribution."""
return 0.5 * (
self.event_shape[0] * (1.0 + jnp.log(2.0 * jnp.pi)) + self.scale.log_det()
self.event_shape[0] * (1.0 + jnp.log(2.0 * jnp.pi))
+ cola.logdet(
self.scale, method="dense"
) # <--- Seems to be an issue with CoLA!
)

def log_prob(
@@ -168,20 +173,23 @@ def log_prob(
mu = self.loc
sigma = self.scale
n = mu.shape[-1]

if mask is not None:
y = jnp.where(mask, 0.0, y)
mu = jnp.where(mask, 0.0, mu)
sigma_masked = jnp.where(mask[None] + mask[:, None], 0.0, sigma.matrix)
sigma = sigma.replace(
matrix=jnp.where(jnp.diag(mask), 1 / (2 * jnp.pi), sigma_masked)
sigma_masked = jnp.where(mask[None] + mask[:, None], 0.0, sigma.to_dense())
sigma = cola.PSD(
Dense(jnp.where(jnp.diag(mask), 1 / (2 * jnp.pi), sigma_masked))
)

# diff, y - µ
diff = y - mu

# compute the pdf, -1/2[ n log(2π) + log|Σ| + (y - µ)ᵀΣ⁻¹(y - µ) ]
return -0.5 * (
n * jnp.log(2.0 * jnp.pi) + sigma.log_det() + diff.T @ sigma.solve(diff)
n * jnp.log(2.0 * jnp.pi)
+ cola.logdet(sigma, method="dense") # <--- Seems to be an issue with CoLA!
+ diff.T @ cola.solve(sigma, diff)
)

def _sample_n(self, key: KeyArray, n: int) -> Float[Array, "n N"]:
@@ -195,7 +203,7 @@ def _sample_n(self, key: KeyArray, n: int) -> Float[Array, "n N"]:
Float[Array, "n N"]: The samples.
"""
# Obtain covariance root.
sqrt = self.scale.to_root()
sqrt = lower_cholesky(self.scale)

# Gather n samples from standard normal distribution Z = [z₁, ..., zₙ]ᵀ.
Z = jr.normal(key, shape=(n, *self.event_shape))
@@ -263,24 +271,26 @@ def _kl_divergence(q: GaussianDistribution, p: GaussianDistribution) -> ScalarFl
sigma_p = p.scale

# Find covariance roots.
sqrt_p = sigma_p.to_root()
sqrt_q = sigma_q.to_root()
sqrt_p = lower_cholesky(sigma_p)
sqrt_q = lower_cholesky(sigma_q)

# diff, μp - μq
diff = mu_p - mu_q

# trace term, tr[Σp⁻¹ Σq] = tr[(LpLpᵀ)⁻¹(LqLqᵀ)] = tr[(Lp⁻¹Lq)(Lp⁻¹Lq)ᵀ] = (fr[LqLp⁻¹])²
trace = _frobenius_norm_squared(
sqrt_p.solve(sqrt_q.to_dense())
cola.solve(sqrt_p, sqrt_q.to_dense())
) # TODO: Not most efficient, given the `to_dense()` call (e.g., consider diagonal p and q). Need to abstract solving linear operator against another linear operator.

# Mahalanobis term, (μp - μq)ᵀ Σp⁻¹ (μp - μq) = tr [(μp - μq)ᵀ [LpLpᵀ]⁻¹ (μp - μq)] = (fr[Lp⁻¹(μp - μq)])²
mahalanobis = jnp.sum(
jnp.square(sqrt_p.solve(diff))
jnp.square(cola.solve(sqrt_p, diff))
) # TODO: Need to improve this. Perhaps add a Mahalanobis method to ``LinearOperator``s.

# KL[q(x)||p(x)] = [ [(μp - μq)ᵀ Σp⁻¹ (μp - μq)] - n - log|Σq| + log|Σp| + tr[Σp⁻¹ Σq] ] / 2
return (mahalanobis - n_dim - sigma_q.log_det() + sigma_p.log_det() + trace) / 2.0
return (
mahalanobis - n_dim - cola.logdet(sigma_q) + cola.logdet(sigma_p) + trace
) / 2.0


__all__ = [
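
Note (illustrative, not part of the diff): a minimal usage sketch of the CoLA-backed GaussianDistribution defined in this file, using a hypothetical three-dimensional distribution:

import cola
import jax.numpy as jnp
from cola.ops import Dense

from gpjax.gaussian_distribution import GaussianDistribution

loc = jnp.array([0.0, 1.0, 2.0])
scale = cola.PSD(Dense(jnp.eye(3)))  # covariance supplied as a cola.LinearOperator tagged PSD

dist = GaussianDistribution(loc=loc, scale=scale)
print(dist.mean())                   # [0. 1. 2.]
print(dist.variance())               # diagonal of the scale operator: [1. 1. 1.]
print(dist.log_prob(jnp.zeros(3)))   # log density at the origin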