Drop CatKernel
thomaspinder committed Nov 28, 2023
1 parent cf300fd commit 135b47f
Showing 12 changed files with 21 additions and 248 deletions.
2 changes: 0 additions & 2 deletions gpjax/__init__.py
@@ -35,7 +35,6 @@
RFF,
AbstractKernel,
BasisFunctionComputation,
-CatKernel,
ConstantDiagonalKernelComputation,
DenseKernelComputation,
DiagonalKernelComputation,
@@ -123,7 +122,6 @@
"CollapsedELBO",
"ELBO",
"AbstractKernel",
"CatKernel",
"Linear",
"DenseKernelComputation",
"DiagonalKernelComputation",
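
For downstream code, the effect of the export change above is that `CatKernel` can no longer be imported from the package root or from `gpjax.kernels`, while the remaining kernel exports are untouched. A minimal sketch of the adjustment, assuming a GPJax build that includes this commit:

```python
# After this commit, only the remaining kernels resolve from gpjax.kernels.
from gpjax.kernels import RBF, GraphKernel  # still exported

try:
    from gpjax.kernels import CatKernel  # removed by this commit
except ImportError:
    CatKernel = None  # callers must drop the categorical kernel or vendor their own
```
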
12 changes: 3 additions & 9 deletions gpjax/dataset.py
@@ -16,13 +16,9 @@
from dataclasses import dataclass
import warnings

-from beartype.typing import (
-Optional,
-)
+from beartype.typing import Optional
import jax.numpy as jnp
-from jaxtyping import (
-Num,
-)
+from jaxtyping import Num
from simple_pytree import Pytree

from gpjax.typing import Array
@@ -49,9 +45,7 @@ def __post_init__(self) -> None:

def __repr__(self) -> str:
r"""Returns a string representation of the dataset."""
-repr = (
-f"- Number of observations: {self.n}\n- Input dimension: {self.in_dim}"
-)
+repr = f"- Number of observations: {self.n}\n- Input dimension: {self.in_dim}"
return repr

def is_supervised(self) -> bool:
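
The dataset change is cosmetic: the wrapped imports and the multi-line string assignment are collapsed onto single lines, and the printed representation itself is unchanged. A small usage sketch of that representation, assuming the usual `Dataset(X=..., y=...)` constructor:

```python
import jax.numpy as jnp

from gpjax.dataset import Dataset

# 20 observations of a 2-dimensional input with scalar outputs.
D = Dataset(X=jnp.ones((20, 2)), y=jnp.ones((20, 1)))

# The __repr__ reports the number of observations and the input dimension:
#   - Number of observations: 20
#   - Input dimension: 2
print(repr(D))
```
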
4 changes: 1 addition & 3 deletions gpjax/distributions.py
@@ -28,9 +28,7 @@
from jax import vmap
import jax.numpy as jnp
import jax.random as jr
-from jaxtyping import (
-Float,
-)
+from jaxtyping import Float
import tensorflow_probability.substrates.jax as tfp

from gpjax.lower_cholesky import lower_cholesky
8 changes: 2 additions & 6 deletions gpjax/gps.py
@@ -14,9 +14,7 @@
# ==============================================================================

from abc import abstractmethod
-from dataclasses import (
-dataclass,
-)
+from dataclasses import dataclass
from typing import overload

from beartype.typing import (
@@ -44,9 +42,7 @@
)
from gpjax.dataset import Dataset
from gpjax.distributions import GaussianDistribution
-from gpjax.kernels import (
-RFF,
-)
+from gpjax.kernels import RFF
from gpjax.kernels.base import AbstractKernel
from gpjax.likelihoods import (
AbstractLikelihood,
6 changes: 1 addition & 5 deletions gpjax/kernels/__init__.py
@@ -28,10 +28,7 @@
DiagonalKernelComputation,
EigenKernelComputation,
)
-from gpjax.kernels.non_euclidean import (
-CatKernel,
-GraphKernel,
-)
+from gpjax.kernels.non_euclidean import GraphKernel
from gpjax.kernels.nonstationary import (
ArcCosine,
Linear,
@@ -54,7 +51,6 @@
"Constant",
"RBF",
"GraphKernel",
"CatKernel",
"Matern12",
"Matern32",
"Matern52",
3 changes: 1 addition & 2 deletions gpjax/kernels/non_euclidean/__init__.py
@@ -13,7 +13,6 @@
# limitations under the License.
# ==============================================================================

-from gpjax.kernels.non_euclidean.categorical import CatKernel
from gpjax.kernels.non_euclidean.graph import GraphKernel

__all__ = ["GraphKernel", "CatKernel"]
__all__ = ["GraphKernel"]
139 changes: 0 additions & 139 deletions gpjax/kernels/non_euclidean/categorical.py

This file was deleted.

4 changes: 3 additions & 1 deletion tests/test_gps.py
@@ -252,7 +252,9 @@ def test_posterior_construct(
leaves_rmul = jtu.tree_leaves(posterior_rmul)
leaves_manual = jtu.tree_leaves(posterior_manual)

-for leaf_mul, leaf_rmul, leaf_man in zip(leaves_mul, leaves_rmul, leaves_manual, strict=True):
+for leaf_mul, leaf_rmul, leaf_man in zip(
+leaves_mul, leaves_rmul, leaves_manual, strict=True
+):
assert (leaf_mul == leaf_rmul).all()
assert (leaf_rmul == leaf_man).all()

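
The only change to this test is a line-length reflow of the comparison loop; the `strict=True` flag it already used is what guarantees the three leaf lists being compared have equal length. A short illustration of that behaviour (Python 3.10+):

```python
# zip(..., strict=True) raises instead of silently truncating, so a mismatch
# in pytree leaf counts would fail the test loudly rather than pass by accident.
list(zip([1, 2, 3], ["a", "b", "c"], strict=True))  # fine

try:
    list(zip([1, 2, 3], ["a", "b"], strict=True))
except ValueError as err:
    print(err)  # argument 2 is shorter than argument 1
```
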
79 changes: 1 addition & 78 deletions tests/test_kernels/test_non_euclidean.py
@@ -13,13 +13,9 @@
from cola.ops import I_like
from jax import config
import jax.numpy as jnp
-import jax.random as jr
import networkx as nx

-from gpjax.kernels.non_euclidean import (
-CatKernel,
-GraphKernel,
-)
+from gpjax.kernels.non_euclidean import GraphKernel

# # Enable Float64 for more stable matrix inversions.
config.update("jax_enable_x64", True)
@@ -50,76 +46,3 @@ def test_graph_kernel():
Kxx += I_like(Kxx) * 1e-6
eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense())
assert all(eigen_values > 0)


-def test_cat_kernel():
-x = jr.normal(jr.PRNGKey(123), (5000, 3))
-gram = jnp.cov(x.T)
-params = CatKernel.gram_to_stddev_cholesky_lower(gram)
-dk = CatKernel(
-inspace_vals=list(range(len(gram))),
-stddev=params.stddev,
-cholesky_lower=params.cholesky_lower,
-)
-assert jnp.allclose(dk.explicit_gram, gram)
-
-sdev = jnp.ones((2,))
-cholesky_lower = jnp.eye(2)
-inspace_vals = [0.0, 1.0]
-
-# Initialize CatKernel object
-dict_kernel = CatKernel(
-stddev=sdev, cholesky_lower=cholesky_lower, inspace_vals=inspace_vals
-)
-
-assert dict_kernel.stddev.shape == sdev.shape
-assert jnp.allclose(dict_kernel.stddev, sdev)
-assert jnp.allclose(dict_kernel.cholesky_lower, cholesky_lower)
-assert dict_kernel.inspace_vals == inspace_vals
-
-
-def test_cat_kernel_gram_to_stddev_cholesky_lower():
-gram = jnp.array([[1.0, 0.5], [0.5, 1.0]])
-sdev_expected = jnp.array([1.0, 1.0])
-cholesky_lower_expected = jnp.array([[1.0, 0.0], [0.5, 0.8660254]])
-
-# Compute sdev and cholesky_lower from gram
-sdev, cholesky_lower = CatKernel.gram_to_stddev_cholesky_lower(gram)
-
-assert jnp.allclose(sdev, sdev_expected)
-assert jnp.allclose(cholesky_lower, cholesky_lower_expected)
-
-
-def test_cat_kernel_call():
-sdev = jnp.ones((2,))
-cholesky_lower = jnp.eye(2)
-inspace_vals = [0.0, 1.0]
-
-# Initialize CatKernel object
-dict_kernel = CatKernel(
-stddev=sdev, cholesky_lower=cholesky_lower, inspace_vals=inspace_vals
-)
-
-# Compute kernel value for pair of inputs
-kernel_value = dict_kernel.__call__(0, 1)
-
-assert jnp.allclose(kernel_value, 0.0)  # since cholesky_lower is identity matrix
-
-
-def test_cat_kernel_explicit_gram():
-sdev = jnp.ones((2,))
-cholesky_lower = jnp.eye(2)
-inspace_vals = [0.0, 1.0]
-
-# Initialize CatKernel object
-dict_kernel = CatKernel(
-stddev=sdev, cholesky_lower=cholesky_lower, inspace_vals=inspace_vals
-)
-
-# Compute explicit gram matrix
-explicit_gram = dict_kernel.explicit_gram
-
-assert explicit_gram.shape == (2, 2)
-assert jnp.allclose(
-explicit_gram, jnp.eye(2)
-) # since sdev are ones and cholesky_lower is identity matrix
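
The deleted tests above document the behaviour that is removed along with `CatKernel`: `gram_to_stddev_cholesky_lower` factors a Gram matrix into per-category standard deviations and the lower Cholesky factor of the implied correlation matrix, and `explicit_gram` reassembles the original matrix from those two pieces. For anyone who still needs that behaviour after this commit, here is a standalone sketch reconstructed from those test expectations; the helper names mirror the removed API, but the functions themselves are hypothetical and not part of GPJax:

```python
import jax.numpy as jnp


def gram_to_stddev_cholesky_lower(gram):
    """Split a Gram matrix into per-category standard deviations and the
    lower Cholesky factor of its correlation matrix."""
    stddev = jnp.sqrt(jnp.diag(gram))
    corr = gram / jnp.outer(stddev, stddev)
    return stddev, jnp.linalg.cholesky(corr)


def explicit_gram(stddev, cholesky_lower):
    """Reassemble the Gram matrix from the two factors."""
    return jnp.outer(stddev, stddev) * (cholesky_lower @ cholesky_lower.T)


# Matches the expectations encoded in the deleted tests.
gram = jnp.array([[1.0, 0.5], [0.5, 1.0]])
stddev, chol = gram_to_stddev_cholesky_lower(gram)
assert jnp.allclose(stddev, jnp.array([1.0, 1.0]))
assert jnp.allclose(chol, jnp.array([[1.0, 0.0], [0.5, 0.8660254]]))
assert jnp.allclose(explicit_gram(stddev, chol), gram)
```
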
4 changes: 3 additions & 1 deletion tests/test_kernels/test_nonstationary.py
@@ -147,7 +147,9 @@ def test_cross_covariance(self, n_a: int, n_b: int, dim: int) -> None:


def prod(inp):
-return [dict(zip(inp.keys(), values, strict=True)) for values in product(*inp.values())]
+return [
+dict(zip(inp.keys(), values, strict=True)) for values in product(*inp.values())
+]


class TestLinear(BaseTestKernel):
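
The change to `prod` here (and the identical copies in the stationary-kernel and likelihood test modules below) is again only a line-length reflow. The helper expands a dict of parameter lists into the Cartesian product of concrete parameter dicts used to parametrise the tests; a quick illustration with hypothetical parameter names:

```python
from itertools import product


def prod(inp):
    return [
        dict(zip(inp.keys(), values, strict=True)) for values in product(*inp.values())
    ]


# {"dim": [1, 2], "variance": [0.5]} expands to
# [{"dim": 1, "variance": 0.5}, {"dim": 2, "variance": 0.5}]
print(prod({"dim": [1, 2], "variance": [0.5]}))
```
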
4 changes: 3 additions & 1 deletion tests/test_kernels/test_stationary.py
@@ -172,7 +172,9 @@ def test_isotropic(self, dim: int):


def prod(inp):
-return [dict(zip(inp.keys(), values, strict=True)) for values in product(*inp.values())]
+return [
+dict(zip(inp.keys(), values, strict=True)) for values in product(*inp.values())
+]


class TestRBF(BaseTestKernel):
4 changes: 3 additions & 1 deletion tests/test_likelihoods.py
@@ -141,7 +141,9 @@ def _test_call_check(likelihood, latent_mean, latent_cov, latent_dist):


def prod(inp):
-return [dict(zip(inp.keys(), values, strict=True)) for values in product(*inp.values())]
+return [
+dict(zip(inp.keys(), values, strict=True)) for values in product(*inp.values())
+]


class TestGaussian(BaseTestLikelihood):
