From 135b47f44193b6efb49405412101b11bb5ae9965 Mon Sep 17 00:00:00 2001
From: Thomas Pinder
Date: Tue, 28 Nov 2023 21:47:26 +0100
Subject: [PATCH] Drop CatKernel

---
 gpjax/__init__.py                          |   2 -
 gpjax/dataset.py                           |  12 +-
 gpjax/distributions.py                     |   4 +-
 gpjax/gps.py                               |   8 +-
 gpjax/kernels/__init__.py                  |   6 +-
 gpjax/kernels/non_euclidean/__init__.py    |   3 +-
 gpjax/kernels/non_euclidean/categorical.py | 139 ---------------------
 tests/test_gps.py                          |   4 +-
 tests/test_kernels/test_non_euclidean.py   |  79 +-----------
 tests/test_kernels/test_nonstationary.py   |   4 +-
 tests/test_kernels/test_stationary.py      |   4 +-
 tests/test_likelihoods.py                  |   4 +-
 12 files changed, 21 insertions(+), 248 deletions(-)
 delete mode 100644 gpjax/kernels/non_euclidean/categorical.py

diff --git a/gpjax/__init__.py b/gpjax/__init__.py
index 27cbf986b..6e45ae72d 100644
--- a/gpjax/__init__.py
+++ b/gpjax/__init__.py
@@ -35,7 +35,6 @@
     RFF,
     AbstractKernel,
     BasisFunctionComputation,
-    CatKernel,
     ConstantDiagonalKernelComputation,
     DenseKernelComputation,
     DiagonalKernelComputation,
@@ -123,7 +122,6 @@
     "CollapsedELBO",
     "ELBO",
     "AbstractKernel",
-    "CatKernel",
     "Linear",
     "DenseKernelComputation",
     "DiagonalKernelComputation",
diff --git a/gpjax/dataset.py b/gpjax/dataset.py
index b812a56f0..5fcc71baf 100644
--- a/gpjax/dataset.py
+++ b/gpjax/dataset.py
@@ -16,13 +16,9 @@
 from dataclasses import dataclass
 import warnings
 
-from beartype.typing import (
-    Optional,
-)
+from beartype.typing import Optional
 import jax.numpy as jnp
-from jaxtyping import (
-    Num,
-)
+from jaxtyping import Num
 from simple_pytree import Pytree
 
 from gpjax.typing import Array
@@ -49,9 +45,7 @@ def __post_init__(self) -> None:
 
     def __repr__(self) -> str:
         r"""Returns a string representation of the dataset."""
-        repr = (
-            f"- Number of observations: {self.n}\n- Input dimension: {self.in_dim}"
-        )
+        repr = f"- Number of observations: {self.n}\n- Input dimension: {self.in_dim}"
         return repr
 
     def is_supervised(self) -> bool:
diff --git a/gpjax/distributions.py b/gpjax/distributions.py
index 19306a87d..24bf0a601 100644
--- a/gpjax/distributions.py
+++ b/gpjax/distributions.py
@@ -28,9 +28,7 @@
 from jax import vmap
 import jax.numpy as jnp
 import jax.random as jr
-from jaxtyping import (
-    Float,
-)
+from jaxtyping import Float
 import tensorflow_probability.substrates.jax as tfp
 
 from gpjax.lower_cholesky import lower_cholesky
diff --git a/gpjax/gps.py b/gpjax/gps.py
index 0759ee79a..5928ef491 100644
--- a/gpjax/gps.py
+++ b/gpjax/gps.py
@@ -14,9 +14,7 @@
 # ==============================================================================
 
 from abc import abstractmethod
-from dataclasses import (
-    dataclass,
-)
+from dataclasses import dataclass
 from typing import overload
 
 from beartype.typing import (
@@ -44,9 +42,7 @@
 )
 from gpjax.dataset import Dataset
 from gpjax.distributions import GaussianDistribution
-from gpjax.kernels import (
-    RFF,
-)
+from gpjax.kernels import RFF
 from gpjax.kernels.base import AbstractKernel
 from gpjax.likelihoods import (
     AbstractLikelihood,
diff --git a/gpjax/kernels/__init__.py b/gpjax/kernels/__init__.py
index 3e01404e9..a3f86352f 100644
--- a/gpjax/kernels/__init__.py
+++ b/gpjax/kernels/__init__.py
@@ -28,10 +28,7 @@
     DiagonalKernelComputation,
     EigenKernelComputation,
 )
-from gpjax.kernels.non_euclidean import (
-    CatKernel,
-    GraphKernel,
-)
+from gpjax.kernels.non_euclidean import GraphKernel
 from gpjax.kernels.nonstationary import (
     ArcCosine,
     Linear,
@@ -54,7 +51,6 @@
     "Constant",
     "RBF",
     "GraphKernel",
-    "CatKernel",
     "Matern12",
     "Matern32",
     "Matern52",
diff --git a/gpjax/kernels/non_euclidean/__init__.py b/gpjax/kernels/non_euclidean/__init__.py
index ee45287b0..d364bc71b 100644
--- a/gpjax/kernels/non_euclidean/__init__.py
+++ b/gpjax/kernels/non_euclidean/__init__.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 # ==============================================================================
 
-from gpjax.kernels.non_euclidean.categorical import CatKernel
 from gpjax.kernels.non_euclidean.graph import GraphKernel
 
-__all__ = ["GraphKernel", "CatKernel"]
+__all__ = ["GraphKernel"]
diff --git a/gpjax/kernels/non_euclidean/categorical.py b/gpjax/kernels/non_euclidean/categorical.py
deleted file mode 100644
index 1b0c9ead2..000000000
--- a/gpjax/kernels/non_euclidean/categorical.py
+++ /dev/null
@@ -1,139 +0,0 @@
-# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-
-from dataclasses import dataclass
-from typing import (
-    NamedTuple,
-    Union,
-)
-
-import jax.numpy as jnp
-from jaxtyping import (
-    Float,
-    Int,
-)
-import tensorflow_probability.substrates.jax as tfp
-
-from gpjax.base import (
-    param_field,
-    static_field,
-)
-from gpjax.kernels.base import AbstractKernel
-from gpjax.typing import (
-    Array,
-    ScalarInt,
-)
-
-tfb = tfp.bijectors
-
-CatKernelParams = NamedTuple(
-    "CatKernelParams",
-    [("stddev", Float[Array, "N 1"]), ("cholesky_lower", Float[Array, " N*(N-1)//2"])],
-)
-
-
-@dataclass
-class CatKernel(AbstractKernel):
-    r"""The categorical kernel is defined for a fixed number of categorical input values.
-
-    It stores a standard deviation for each input value (i.e. the square root of the gram matrix's diagonal), and a lower Cholesky factor for the correlations.
-    When called, it returns the corresponding values from the gram matrix.
-
-    Args:
-        stddev (Float[Array, "N"]): The standard deviation parameters, one for each input space value.
-        cholesky_lower (Float[Array, "N*(N-1)//2 N"]): The parameters for the Cholesky factor of the gram matrix.
-        inspace_vals (list): The values in the input space this CatKernel works for. Stored for order reference, making clear the indices used for each input space value.
-        name (str): The name of the kernel.
-        input_1hot (bool): If True, the kernel expects to be called with a one-hot encoding of the input space values. If False, it expects the indices of the input space values.
-
-    Raises:
-        ValueError: If the number of diagonal variance parameters does not match the number of input space values.
- """ - - stddev: Float[Array, " N"] = param_field(jnp.ones((2,)), bijector=tfb.Softplus()) - cholesky_lower: Float[Array, "N N"] = param_field( - jnp.eye(2), bijector=tfb.CorrelationCholesky() - ) - inspace_vals: Union[list, None] = static_field(None) - name: str = "Categorical Kernel" - input_1hot: bool = static_field(False) - - def __post_init__(self): - if self.inspace_vals is not None and len(self.inspace_vals) != len(self.stddev): - raise ValueError( - f"The number of stddev parameters ({len(self.stddev)}) has to match the number of input space values ({len(self.inspace_vals)}), unless inspace_vals is None." - ) - - @property - def explicit_gram(self) -> Float[Array, "N N"]: - """Access the PSD gram matrix resulting from the parameters. - - Returns: - Float[Array, "N N"]: The gram matrix. - """ - L = self.stddev.reshape(-1, 1) * self.cholesky_lower - return L @ L.T - - def __call__( # TODO not consistent with general kernel interface - self, - x: Union[ScalarInt, Int[Array, " N"]], - y: Union[ScalarInt, Int[Array, " N"]], - ): - r"""Compute the (co)variance between a pair of dictionary indices. - - Args: - x (Union[ScalarInt, Int[Array, "N"]]): The index of the first dictionary entry, or its one-hot encoding. - y (Union[ScalarInt, Int[Array, "N"]]): The index of the second dictionary entry, or its one-hot encoding. - - Returns - ------- - ScalarFloat: The value of $k(v_i, v_j)$. - """ - try: - x = x.squeeze() - y = y.squeeze() - except AttributeError: - pass - if self.input_1hot: - return self.explicit_gram[jnp.outer(x, y) == 1] - else: - return self.explicit_gram[x, y] - - @staticmethod - def num_cholesky_lower_params(num_inspace_vals: ScalarInt) -> ScalarInt: - """Compute the number of parameters required to store the lower triangular Cholesky factor of the gram matrix. - - Args: - num_inspace_vals (ScalarInt): The number of values in the input space. - - Returns: - ScalarInt: The number of parameters required to store the lower triangle of the Cholesky factor of the gram matrix. - """ - return num_inspace_vals * (num_inspace_vals - 1) // 2 - - @staticmethod - def gram_to_stddev_cholesky_lower(gram: Float[Array, "N N"]) -> CatKernelParams: - """Compute the standard deviation and lower triangular Cholesky factor of the gram matrix. - - Args: - gram (Float[Array, "N N"]): The gram matrix. - - Returns: - tuple[Float[Array, "N"], Float[Array, "N N"]]: The standard deviation and lower triangular Cholesky factor of the gram matrix, where the latter is scaled to result in unit variances. 
- """ - stddev = jnp.sqrt(jnp.diag(gram)) - L = jnp.linalg.cholesky(gram) / stddev.reshape(-1, 1) - return CatKernelParams(stddev, L) diff --git a/tests/test_gps.py b/tests/test_gps.py index 12766772e..59af2bb91 100644 --- a/tests/test_gps.py +++ b/tests/test_gps.py @@ -252,7 +252,9 @@ def test_posterior_construct( leaves_rmul = jtu.tree_leaves(posterior_rmul) leaves_manual = jtu.tree_leaves(posterior_manual) - for leaf_mul, leaf_rmul, leaf_man in zip(leaves_mul, leaves_rmul, leaves_manual, strict=True): + for leaf_mul, leaf_rmul, leaf_man in zip( + leaves_mul, leaves_rmul, leaves_manual, strict=True + ): assert (leaf_mul == leaf_rmul).all() assert (leaf_rmul == leaf_man).all() diff --git a/tests/test_kernels/test_non_euclidean.py b/tests/test_kernels/test_non_euclidean.py index f8b997d8c..4ed6d68a6 100644 --- a/tests/test_kernels/test_non_euclidean.py +++ b/tests/test_kernels/test_non_euclidean.py @@ -13,13 +13,9 @@ from cola.ops import I_like from jax import config import jax.numpy as jnp -import jax.random as jr import networkx as nx -from gpjax.kernels.non_euclidean import ( - CatKernel, - GraphKernel, -) +from gpjax.kernels.non_euclidean import GraphKernel # # Enable Float64 for more stable matrix inversions. config.update("jax_enable_x64", True) @@ -50,76 +46,3 @@ def test_graph_kernel(): Kxx += I_like(Kxx) * 1e-6 eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) assert all(eigen_values > 0) - - -def test_cat_kernel(): - x = jr.normal(jr.PRNGKey(123), (5000, 3)) - gram = jnp.cov(x.T) - params = CatKernel.gram_to_stddev_cholesky_lower(gram) - dk = CatKernel( - inspace_vals=list(range(len(gram))), - stddev=params.stddev, - cholesky_lower=params.cholesky_lower, - ) - assert jnp.allclose(dk.explicit_gram, gram) - - sdev = jnp.ones((2,)) - cholesky_lower = jnp.eye(2) - inspace_vals = [0.0, 1.0] - - # Initialize CatKernel object - dict_kernel = CatKernel( - stddev=sdev, cholesky_lower=cholesky_lower, inspace_vals=inspace_vals - ) - - assert dict_kernel.stddev.shape == sdev.shape - assert jnp.allclose(dict_kernel.stddev, sdev) - assert jnp.allclose(dict_kernel.cholesky_lower, cholesky_lower) - assert dict_kernel.inspace_vals == inspace_vals - - -def test_cat_kernel_gram_to_stddev_cholesky_lower(): - gram = jnp.array([[1.0, 0.5], [0.5, 1.0]]) - sdev_expected = jnp.array([1.0, 1.0]) - cholesky_lower_expected = jnp.array([[1.0, 0.0], [0.5, 0.8660254]]) - - # Compute sdev and cholesky_lower from gram - sdev, cholesky_lower = CatKernel.gram_to_stddev_cholesky_lower(gram) - - assert jnp.allclose(sdev, sdev_expected) - assert jnp.allclose(cholesky_lower, cholesky_lower_expected) - - -def test_cat_kernel_call(): - sdev = jnp.ones((2,)) - cholesky_lower = jnp.eye(2) - inspace_vals = [0.0, 1.0] - - # Initialize CatKernel object - dict_kernel = CatKernel( - stddev=sdev, cholesky_lower=cholesky_lower, inspace_vals=inspace_vals - ) - - # Compute kernel value for pair of inputs - kernel_value = dict_kernel.__call__(0, 1) - - assert jnp.allclose(kernel_value, 0.0) # since cholesky_lower is identity matrix - - -def test_cat_kernel_explicit_gram(): - sdev = jnp.ones((2,)) - cholesky_lower = jnp.eye(2) - inspace_vals = [0.0, 1.0] - - # Initialize CatKernel object - dict_kernel = CatKernel( - stddev=sdev, cholesky_lower=cholesky_lower, inspace_vals=inspace_vals - ) - - # Compute explicit gram matrix - explicit_gram = dict_kernel.explicit_gram - - assert explicit_gram.shape == (2, 2) - assert jnp.allclose( - explicit_gram, jnp.eye(2) - ) # since sdev are ones and cholesky_lower is identity matrix diff --git 
index 4e2f5e813..803e44d23 100644
--- a/tests/test_kernels/test_nonstationary.py
+++ b/tests/test_kernels/test_nonstationary.py
@@ -147,7 +147,9 @@ def test_cross_covariance(self, n_a: int, n_b: int, dim: int) -> None:
 
 
 def prod(inp):
-    return [dict(zip(inp.keys(), values, strict=True)) for values in product(*inp.values())]
+    return [
+        dict(zip(inp.keys(), values, strict=True)) for values in product(*inp.values())
+    ]
 
 
 class TestLinear(BaseTestKernel):
diff --git a/tests/test_kernels/test_stationary.py b/tests/test_kernels/test_stationary.py
index cb93abe4d..3e214b45b 100644
--- a/tests/test_kernels/test_stationary.py
+++ b/tests/test_kernels/test_stationary.py
@@ -172,7 +172,9 @@ def test_isotropic(self, dim: int):
 
 
 def prod(inp):
-    return [dict(zip(inp.keys(), values, strict=True)) for values in product(*inp.values())]
+    return [
+        dict(zip(inp.keys(), values, strict=True)) for values in product(*inp.values())
+    ]
 
 
 class TestRBF(BaseTestKernel):
diff --git a/tests/test_likelihoods.py b/tests/test_likelihoods.py
index 21df80e26..62e829712 100644
--- a/tests/test_likelihoods.py
+++ b/tests/test_likelihoods.py
@@ -141,7 +141,9 @@ def _test_call_check(likelihood, latent_mean, latent_cov, latent_dist):
 
 
 def prod(inp):
-    return [dict(zip(inp.keys(), values, strict=True)) for values in product(*inp.values())]
+    return [
+        dict(zip(inp.keys(), values, strict=True)) for values in product(*inp.values())
+    ]
 
 
 class TestGaussian(BaseTestLikelihood):
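
Postscript (not part of the commit): downstream code that relied on CatKernel can keep its parameterisation as a few standalone lines. The sketch below mirrors the deleted explicit_gram property and gram_to_stddev_cholesky_lower static method; the free-function names are illustrative, not GPJax API.

import jax.numpy as jnp

def gram_to_stddev_cholesky_lower(gram):
    # Per-category standard deviations are the square roots of the gram diagonal.
    stddev = jnp.sqrt(jnp.diag(gram))
    # Row-scale the Cholesky factor so the remaining factor encodes unit-variance
    # correlations, matching the deleted gram_to_stddev_cholesky_lower.
    cholesky_lower = jnp.linalg.cholesky(gram) / stddev.reshape(-1, 1)
    return stddev, cholesky_lower

def explicit_gram(stddev, cholesky_lower):
    # Invert the factorisation, as the deleted explicit_gram property did:
    # rescale the correlation factor by the stddevs, then form L @ L.T.
    L = stddev.reshape(-1, 1) * cholesky_lower
    return L @ L.T

gram = jnp.array([[1.0, 0.5], [0.5, 1.0]])
stddev, chol = gram_to_stddev_cholesky_lower(gram)
assert jnp.allclose(explicit_gram(stddev, chol), gram)  # round trip recovers the gram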
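
A second standalone sketch, again outside the commit, of the lookup semantics the deleted __call__ implemented: integer indices index the gram matrix directly (input_1hot=False), while one-hot inputs select the entry through an outer-product mask (input_1hot=True). It assumes eager (non-jit) execution, since boolean masking has a data-dependent shape.

import jax.numpy as jnp

gram = jnp.array([[1.0, 0.3], [0.3, 2.0]])

# Index form: k(i, j) is a direct gram lookup.
assert gram[0, 1] == 0.3

# One-hot form: jnp.outer(x, y) == 1 is True exactly at position (i, j),
# so boolean masking selects the single matching entry.
x = jnp.array([1, 0])  # one-hot encoding of index 0
y = jnp.array([0, 1])  # one-hot encoding of index 1
assert gram[jnp.outer(x, y) == 1] == 0.3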