From c07935a87295bd3fd63bc5ffd2e13e3763f8d556 Mon Sep 17 00:00:00 2001
From: hmoss <32096840+henrymoss@users.noreply.github.com>
Date: Sun, 30 Apr 2023 13:39:31 +0100
Subject: [PATCH] Arccosine kernel (#245)

* WIP

* first go

* nice test

---------

Signed-off-by: Thomas Pinder
Co-authored-by: Thomas Pinder
---
 gpjax/kernels/__init__.py                |   2 +
 gpjax/kernels/nonstationary/__init__.py  |   3 +-
 gpjax/kernels/nonstationary/arccosine.py | 117 +++++++++++++++++++++++
 tests/test_kernels/test_nonstationary.py |  64 +++++++++++--
 4 files changed, 176 insertions(+), 10 deletions(-)
 create mode 100644 gpjax/kernels/nonstationary/arccosine.py

diff --git a/gpjax/kernels/__init__.py b/gpjax/kernels/__init__.py
index a72918d64..312f77481 100644
--- a/gpjax/kernels/__init__.py
+++ b/gpjax/kernels/__init__.py
@@ -29,6 +29,7 @@
 )
 from gpjax.kernels.non_euclidean import GraphKernel
 from gpjax.kernels.nonstationary import (
+    ArcCosine,
     Linear,
     Polynomial,
 )
@@ -45,6 +46,7 @@
 
 __all__ = [
     "AbstractKernel",
+    "ArcCosine",
     "RBF",
     "GraphKernel",
     "Matern12",
diff --git a/gpjax/kernels/nonstationary/__init__.py b/gpjax/kernels/nonstationary/__init__.py
index 3a4ae5b58..e772fa655 100644
--- a/gpjax/kernels/nonstationary/__init__.py
+++ b/gpjax/kernels/nonstationary/__init__.py
@@ -13,7 +13,8 @@
 # limitations under the License.
 # ==============================================================================
 
+from gpjax.kernels.nonstationary.arccosine import ArcCosine
 from gpjax.kernels.nonstationary.linear import Linear
 from gpjax.kernels.nonstationary.polynomial import Polynomial
 
-__all__ = ["Linear", "Polynomial"]
+__all__ = ["Linear", "Polynomial", "ArcCosine"]
diff --git a/gpjax/kernels/nonstationary/arccosine.py b/gpjax/kernels/nonstationary/arccosine.py
new file mode 100644
index 000000000..fefbc9969
--- /dev/null
+++ b/gpjax/kernels/nonstationary/arccosine.py
@@ -0,0 +1,117 @@
+# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from dataclasses import dataclass
+
+from beartype.typing import Union
+import jax.numpy as jnp
+from jaxtyping import Float
+from simple_pytree import static_field
+import tensorflow_probability.substrates.jax.bijectors as tfb
+
+from gpjax.base import param_field
+from gpjax.kernels.base import AbstractKernel
+from gpjax.typing import (
+    Array,
+    ScalarFloat,
+    ScalarInt,
+)
+
+
+@dataclass
+class ArcCosine(AbstractKernel):
+    r"""The ArcCosine kernel. This kernel is non-stationary and resembles the
+    behavior of neural networks. See Section 3.1 of
+    https://arxiv.org/pdf/1112.3712.pdf for additional details.
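+
+    For inputs :math:`x` and :math:`y`, let
+    :math:`\langle x, y \rangle = w^{\top}(x \odot y) + b` denote the weighted
+    inner product, where :math:`w` is the weight variance, :math:`b` is the
+    bias variance, and :math:`\odot` is the elementwise product, and let
+    :math:`\theta = \cos^{-1}\left(\langle x, y \rangle / \sqrt{\langle x, x \rangle \langle y, y \rangle}\right)`.
+    The kernel of order :math:`n \in \{0, 1, 2\}` is then
+
+    .. math::
+        k_n(x, y) = \frac{\sigma^2}{\pi} \langle x, x \rangle^{n / 2}
+        \langle y, y \rangle^{n / 2} J_n(\theta),
+
+    where :math:`\sigma^2` is the variance and the angular dependency function
+    :math:`J_n` is given by :math:`J_0(\theta) = \pi - \theta`,
+    :math:`J_1(\theta) = \sin\theta + (\pi - \theta)\cos\theta`, and
+    :math:`J_2(\theta) = 3\sin\theta\cos\theta + (\pi - \theta)(1 + 2\cos^2\theta)`.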
+    """
+
+    order: ScalarInt = static_field(0)
+    variance: ScalarFloat = param_field(jnp.array(1.0), bijector=tfb.Softplus())
+    weight_variance: Union[ScalarFloat, Float[Array, " D"]] = param_field(
+        jnp.array(1.0), bijector=tfb.Softplus()
+    )
+    bias_variance: ScalarFloat = param_field(jnp.array(1.0), bijector=tfb.Softplus())
+
+    def __post_init__(self):
+        if self.order not in [0, 1, 2]:
+            raise ValueError("ArcCosine kernel only implemented for orders 0, 1 and 2.")
+
+        self.name = f"ArcCosine (order {self.order})"
+
+    def __call__(self, x: Float[Array, " D"], y: Float[Array, " D"]) -> ScalarFloat:
+        """Evaluate the kernel on a pair of inputs :math:`(x, y)`.
+
+        Args:
+            x (Float[Array, "D"]): The left hand argument of the kernel function's
+                call.
+            y (Float[Array, "D"]): The right hand argument of the kernel function's
+                call.
+
+        Returns
+        -------
+            ScalarFloat: The value of :math:`k(x, y)`.
+        """
+
+        x = self.slice_input(x)
+        y = self.slice_input(y)
+
+        x_x = self._weighted_prod(x, x)
+        x_y = self._weighted_prod(x, y)
+        y_y = self._weighted_prod(y, y)
+
+        cos_theta = x_y / jnp.sqrt(x_x * y_y)
+        jitter = 1e-15  # improve numerical stability
+        theta = jnp.arccos(jitter + (1 - 2 * jitter) * cos_theta)
+
+        K = self._J(theta)
+        K *= jnp.sqrt(x_x) ** self.order
+        K *= jnp.sqrt(y_y) ** self.order
+        K *= self.variance / jnp.pi
+
+        return K.squeeze()
+
+    def _weighted_prod(
+        self, x: Float[Array, " D"], y: Float[Array, " D"]
+    ) -> ScalarFloat:
+        """Calculate the weighted product between two arguments.
+
+        Args:
+            x (Float[Array, "D"]): The left hand argument.
+            y (Float[Array, "D"]): The right hand argument.
+
+        Returns
+        -------
+            ScalarFloat: The value of the weighted product between the two arguments.
+        """
+        return jnp.inner(self.weight_variance * x, y) + self.bias_variance
+
+    def _J(self, theta: ScalarFloat) -> ScalarFloat:
+        """Evaluate the angular dependency function corresponding to the desired order.
+
+        Args:
+            theta (ScalarFloat): The angle between the weighted inputs.
+
+        Returns
+        -------
+            ScalarFloat: The value of the angular dependency function.
+        """
+
+        if self.order == 0:
+            return jnp.pi - theta
+        elif self.order == 1:
+            return jnp.sin(theta) + (jnp.pi - theta) * jnp.cos(theta)
+        else:
+            return 3.0 * jnp.sin(theta) * jnp.cos(theta) + (jnp.pi - theta) * (
+                1.0 + 2.0 * jnp.cos(theta) ** 2
+            )
diff --git a/tests/test_kernels/test_nonstationary.py b/tests/test_kernels/test_nonstationary.py
index 79a63fca3..d498f263c 100644
--- a/tests/test_kernels/test_nonstationary.py
+++ b/tests/test_kernels/test_nonstationary.py
@@ -14,10 +14,7 @@
 # ==============================================================================
 
 from dataclasses import is_dataclass
-from itertools import (
-    permutations,
-    product,
-)
+from itertools import product
 from typing import List
 
 import jax
@@ -31,13 +28,11 @@
 from gpjax.kernels.base import AbstractKernel
 from gpjax.kernels.computations import DenseKernelComputation
 from gpjax.kernels.nonstationary import (
+    ArcCosine,
     Linear,
     Polynomial,
 )
-from gpjax.linops import (
-    LinearOperator,
-    identity,
-)
+from gpjax.linops import LinearOperator
 
 # Enable Float64 for more stable matrix inversions.
 config.update("jax_enable_x64", True)
@@ -101,7 +96,9 @@ def test_initialization(self, fields: dict, dim: int) -> None:
         # Check meta leaves
         meta = kernel._pytree__meta
         assert not any(f in meta for f in self.static_fields)
-        assert list(meta.keys()) == sorted(set(fields) - set(self.static_fields))
+        assert sorted(list(meta.keys())) == sorted(
+            set(fields) - set(self.static_fields)
+        )
 
         for field in meta:
             # Bijectors
@@ -170,3 +167,52 @@ class TestPolynomial(BaseTestKernel):
     static_fields = ["degree"]
     params = {"test_initialization": fields}
     default_compute_engine = DenseKernelComputation
+
+
+class TestArcCosine(BaseTestKernel):
+    kernel = ArcCosine
+    fields = prod(
+        {
+            "variance": [0.1, 1.0],
+            "order": [0, 1, 2],
+            "weight_variance": [0.1, 1.0],
+            "bias_variance": [0.1, 1.0],
+        }
+    )
+    static_fields = ["order"]
+    params = {"test_initialization": fields}
+    default_compute_engine = DenseKernelComputation
+
+    @pytest.mark.parametrize("order", [-1, 3], ids=lambda x: f"order={x}")
+    def test_defaults(self, order: int) -> None:
+        with pytest.raises(ValueError):
+            self.kernel(order=order)
+
+    @pytest.mark.parametrize("order", [0, 1, 2], ids=lambda x: f"order={x}")
+    def test_values_by_monte_carlo_in_special_case(self, order: int) -> None:
+        """For particular values of the weight variance (1.0) and bias variance
+        (approximately 0.0), we can test our calculations against a Monte Carlo
+        estimate of the arc-cosine kernel; see, e.g., Eq. (1) of
+        https://cseweb.ucsd.edu/~saul/papers/nips09_kernel.pdf.
+        """
+        kernel: AbstractKernel = self.kernel(
+            weight_variance=jnp.array([1.0, 1.0]), bias_variance=1e-25, order=order
+        )
+        key = jr.PRNGKey(123)
+
+        # Inputs close(ish) together
+        a = jnp.array([[0.0, 0.0]])
+        b = jnp.array([[2.0, 2.0]])
+
+        # Calculate the cross-covariance exactly.
+        Kab_exact = kernel.cross_covariance(a, b)
+
+        # Estimate the cross-covariance with Monte Carlo samples.
+        weights = jax.random.normal(key, (10_000, 2))  # [S, D]
+        weights_a = jnp.matmul(weights, a.T)  # [S, 1]
+        weights_b = jnp.matmul(weights, b.T)  # [S, 1]
+        H_a = jnp.heaviside(weights_a, 0.5)
+        H_b = jnp.heaviside(weights_b, 0.5)
+        integrands = H_a * H_b * (weights_a**order) * (weights_b**order)
+        Kab_approx = 2.0 * jnp.mean(integrands)
+
+        assert jnp.max(Kab_approx - Kab_exact) < 1e-4
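
A minimal usage sketch of the new kernel, assuming only the API exercised in
this patch (the `ArcCosine` export from `gpjax.kernels`, pointwise evaluation
via `__call__`, and `cross_covariance` on batched inputs):

    import jax.numpy as jnp

    from gpjax.kernels import ArcCosine

    # An order-1 arc-cosine kernel on 2-dimensional inputs. `weight_variance`
    # may be a scalar or a length-D vector; `order` must be 0, 1, or 2.
    kernel = ArcCosine(order=1, weight_variance=jnp.array([1.0, 1.0]))

    # Pointwise evaluation k(x, y) on a single pair of D-dimensional inputs.
    k_xy = kernel(jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0]))

    # Cross-covariance between two batches of inputs, as in the tests above.
    a = jnp.array([[0.0, 0.0]])
    b = jnp.array([[2.0, 2.0]])
    Kab = kernel.cross_covariance(a, b)  # shape [1, 1]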