Add changepoint kernel #357

Closed

Changes from all commits (26 commits)
6bc8889 Init commit: changepoint kernel for two RBFs with independent parameters (jakeyeung, Aug 10, 2023)
a94ac98 Rename changepoint_kernel to just changepoint (jakeyeung, Aug 10, 2023)
4e4fab4 Replace boilerplate with standard GPJax boilerplate for 2023 (jakeyeung, Aug 14, 2023)
6f3c01b Rename function GetFunctionIndex to get_function_index (jakeyeung, Aug 14, 2023)
2fc53e4 Add SwitchPoint kernel, where you can initialize with a list of two (jakeyeung, Aug 14, 2023)
34598df Remove commented out import (jakeyeung, Aug 14, 2023)
21521ad Document what is kernels and what is tswitch in SwitchKernel (jakeyeung, Aug 14, 2023)
46366ad Remove ChangePointRBF, since SwitchKernel covers the case of two RBFs (jakeyeung, Aug 14, 2023)
d779b93 Rename SwitchKernel to ChangePoint to align with literature (jakeyeung, Aug 14, 2023)
b71e5f2 Init commit: changepoint kernel for two RBFs with independent parameters (jakeyeung, Aug 10, 2023)
0d25c5f Rename changepoint_kernel to just changepoint (jakeyeung, Aug 10, 2023)
f424d6c Replace boilerplate with standard GPJax boilerplate for 2023 (jakeyeung, Aug 14, 2023)
6dede3c Rename function GetFunctionIndex to get_function_index (jakeyeung, Aug 14, 2023)
25bc0b7 Add SwitchPoint kernel, where you can initialize with a list of two (jakeyeung, Aug 14, 2023)
b961221 Remove commented out import (jakeyeung, Aug 14, 2023)
dd4feea Document what is kernels and what is tswitch in SwitchKernel (jakeyeung, Aug 14, 2023)
f1ee522 Remove ChangePointRBF, since SwitchKernel covers the case of two RBFs (jakeyeung, Aug 14, 2023)
db264be Rename SwitchKernel to ChangePoint to align with literature (jakeyeung, Aug 14, 2023)
efeb977 Add unit test for changepoint kernel: same as combination kernel. (jakeyeung, Oct 10, 2023)
19a58da Merge branch 'changepoint_kernel' of github.com:jakeyeung/GPJax into … (jakeyeung, Oct 10, 2023)
443269d Add citation for ChangePoint kernel. (jakeyeung, Oct 11, 2023)
07f25c3 Change inputs x and y to be 1-dimensional as intended. (jakeyeung, Oct 11, 2023)
309d51d Remove default argument (jakeyeung, Oct 11, 2023)
46c52f6 Document the purpose of get_function_index() (jakeyeung, Oct 11, 2023)
127bd6c Rename variables flst to kernel_options (jakeyeung, Oct 11, 2023)
ddcd673 Raise error if exactly two kernels are not supplied in current (jakeyeung, Oct 11, 2023)
12 changes: 12 additions & 0 deletions gpjax/citation.py
@@ -18,6 +18,7 @@
from gpjax.kernels import (
    RFF,
    ArcCosine,
    ChangePoint,
    GraphKernel,
    Matern12,
    Matern32,
@@ -137,6 +138,17 @@ def _(_) -> PaperCitation:
    )


@cite.register(ChangePoint)
def _(_) -> PaperCitation:
    return PaperCitation(
        citation_key="saatci2010gaussian",
        authors="Saatci, Yunus and Turner, Ryan D and Rasmussen, Carl E",
        title="Gaussian process change point models",
        year="2010",
        booktitle="Proceedings of the 27th International Conference on Machine Learning (ICML-10)",
    )


@cite.register(GraphKernel)
def _(tree) -> PaperCitation:
    return PaperCitation(
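
As a quick check of the new registration, one might query the citation dispatcher like so (a sketch, not part of the diff; it assumes cite is the public entry point of gpjax.citation and the ChangePoint constructor exercised in the tests below):

import jax.numpy as jnp

from gpjax.citation import cite
from gpjax.kernels import RBF, ChangePoint

# Build a changepoint kernel from two RBFs and look up its citation.
kernel = ChangePoint(kernels=[RBF(), RBF()], tswitch=jnp.array(0.5))
print(cite(kernel))  # PaperCitation for Saatci, Turner, and Rasmussen (ICML 2010)
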
2 changes: 2 additions & 0 deletions gpjax/kernels/stationary/__init__.py
@@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================

from gpjax.kernels.stationary.changepoint import ChangePoint
from gpjax.kernels.stationary.matern12 import Matern12
from gpjax.kernels.stationary.matern32 import Matern32
from gpjax.kernels.stationary.matern52 import Matern52
@@ -23,6 +24,7 @@
from gpjax.kernels.stationary.white import White

__all__ = [
"ChangePoint",
"Matern12",
"Matern32",
"Matern52",
136 changes: 136 additions & 0 deletions gpjax/kernels/stationary/changepoint.py
@@ -0,0 +1,136 @@
# Copyright 2023 The JaxGaussianProcesses Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================


from dataclasses import dataclass

from beartype.typing import (
    Callable,
    List,
)
import jax
import jax.numpy as jnp
from jaxtyping import Float
import tensorflow_probability.substrates.jax.bijectors as tfb

from gpjax.base import (
    param_field,
    static_field,
)
from gpjax.kernels.base import AbstractKernel
from gpjax.typing import (
    Array,
    ScalarFloat,
)


@dataclass
class ChangePoint(AbstractKernel):
r"""A change point kernel
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a citation for your implementation of the kernel? We could add this to our citation.py. Also probably good to give a paper in the docstring. :)

Copy link
Author

@jakeyeung jakeyeung Oct 11, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The idea started with the GPflow documentation, but I guess if you're thinking about a paper it would be Saatci et al 2010 : (https://icml.cc/Conferences/2010/papers/170.pdf).

I have now added it the Saatci et al 2010 reference to citation.py and the doc strings.


See Saatci, Turner, Rasmussen 2010 ICML paper for details.

self.kernels: A list of exactly two kernels that will be switched.
self.tswitch: The point at which to change to a different kernel.
for example: if x and y are both less than tswitch, then you would use kernels[0]
if x and y are both greater than or equal to tswitch, then you would use kernels[1]
otherwise return cross-covariance of 0

"""

    kernels: List[AbstractKernel] = None
    operator: Callable = static_field(None)
    tswitch: ScalarFloat = param_field(
        jnp.array(1.0), bijector=tfb.Identity(), trainable=False
    )
[Review comment (daniel-dodd, Oct 10, 2023) on lines +55 to +57]
Is the changepoint location itself not learnt here (trainable=False)?

Also, what happens if we have multiple changepoints? Does this kernel work recursively by passing a ChangePoint kernel to kernels?

[Reply (jakeyeung, Oct 11, 2023)]
The idea started with implementing a basic change-point kernel similar to what GPflow has (see the change-point section in https://gpflow.github.io/GPflow/develop/notebooks/getting_started/kernels.html).

When I first implemented it, I was thinking of a fixed parameter, but I don't see any reason why it has to be that way.

At the moment this piece of code only works with two kernels, so one changepoint, but one can also imagine extending this to a list of N kernels with (N-1) changepoint locations.

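For reference, making the changepoint location learnable would just flip the flag on the param_field above. A hypothetical sketch (TrainableChangePoint is illustrative and not part of this PR):

from dataclasses import dataclass

import jax.numpy as jnp
import tensorflow_probability.substrates.jax.bijectors as tfb

from gpjax.base import param_field
from gpjax.kernels.stationary import ChangePoint
from gpjax.typing import ScalarFloat


@dataclass
class TrainableChangePoint(ChangePoint):  # hypothetical variant
    # Same field as above, but trainable=True so optimisers update tswitch.
    tswitch: ScalarFloat = param_field(
        jnp.array(1.0), bijector=tfb.Identity(), trainable=True
    )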

    def __post_init__(self):
        if self.kernels is None or len(self.kernels) != 2:
            raise TypeError("Current implementation only accepts 2 kernels")

        # Add kernels to a list, flattening out instances of this class therein,
        # as in GPflow kernels.
        kernels_list: List[AbstractKernel] = []

        for kernel in self.kernels:
            if not isinstance(kernel, AbstractKernel):
                raise TypeError("can only combine Kernel instances")  # pragma: no cover

            if isinstance(kernel, self.__class__):
                kernels_list.extend(kernel.kernels)
            else:
                kernels_list.append(kernel)

        self.kernels = kernels_list

    def __call__(
        self,
        x: Float[Array, " 1"],
        y: Float[Array, " 1"],
    ) -> ScalarFloat:
        r"""Evaluate the kernel on a pair of inputs.

        Args:
            x (Float[Array, " 1"]): The left hand input of the kernel function.
            y (Float[Array, " 1"]): The right hand input of the kernel function.

        Returns
        -------
            ScalarFloat: The evaluated kernel function at the supplied inputs.
        """

        def get_function_index(x, y, tswitch):
            r"""
            Specify which kernel function to compute given x, y, and tswitch.

            There are four indices that correspond to the different functions used
            to calculate the kernel value for a pair of points.

            First scenario: x and y are both less than tswitch (left of the
            changepoint), so the covariance is evaluated with the first kernel.
            Second and third scenarios: the two points lie on opposite sides of
            tswitch, so the covariance is 0.
            Fourth scenario: both points are right of tswitch, so the covariance
            is evaluated with the second kernel.

            Args:
                x: Left hand argument of the kernel function's call
                y: Right hand argument of the kernel function's call
                tswitch: point at which to change to a different kernel
            """
            # Four possible indices: 0, 1, 2, 3
            indx = 3  # if indx doesn't get set to 0, 1, or 2, default to 3
            # jnp.less means that at tswitch itself, you have already switched
            cond1 = jnp.less(x, tswitch)
            cond2 = jnp.less(y, tswitch)
            indx = jnp.where(jnp.logical_and(cond1, cond2), 0, indx)
            indx = jnp.where(jnp.logical_and(jnp.invert(cond1), cond2), 1, indx)
            indx = jnp.where(jnp.logical_and(cond1, jnp.invert(cond2)), 2, indx)
            return indx.squeeze().astype("uint8")

        def k_zero(x, y):
            r"""Return 0 covariance"""
            out = jnp.float64(0)
            return out.squeeze()
[Review comment (Member) on lines +121 to +124]
This is nice. But the only concern I have is with jnp.float64 potentially clashing with JAX's default float32. Could this not work with jnp.array(0.0) (or even 0.0)? Or does this break the jax.lax.switch or something? :)

[Review comment (Member)]
If problems persist, you could always use a Constant(0.0) kernel here.

[Reply (Author)]
I don't know the details of JAX well enough to know how the different types play with other types and conditionals. What would be a good way to test which type is best to put in here?

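One way to probe the dtype question above (an illustrative sketch, not part of the PR): run jax.lax.switch with each candidate zero against a dummy branch and inspect the resulting dtype under the active precision config:

import jax
import jax.numpy as jnp


def k_zero_f64(x, y):
    # Explicit float64; silently downcast (with a warning) when x64 is disabled.
    return jnp.float64(0).squeeze()


def k_zero_default(x, y):
    # Weakly-typed zero that follows JAX's default dtype (float32 by default).
    return jnp.array(0.0)


def k_dummy(x, y):
    # Stand-in for a real kernel branch.
    return jnp.sum(x * y)


x = jnp.ones(1)
for zero in (k_zero_f64, k_zero_default):
    out = jax.lax.switch(1, [k_dummy, zero], x, x)
    print(zero.__name__, "->", out.dtype)
# Note: lax.switch requires every branch to return identical shapes and dtypes,
# so any mismatch between the zero branch and the kernel branch errors at trace time.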

        indx = get_function_index(x, y, tswitch=self.tswitch)

        kernel_options = [
            self.kernels[0].__call__,
            k_zero,
            k_zero,
            self.kernels[1].__call__,
        ]
        K = jax.lax.switch(indx, kernel_options, x, y)

        return K.squeeze()
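
To make the semantics above concrete, a minimal usage sketch (assuming only the constructor signature exercised in the tests below): the resulting Gram matrix is block-diagonal, with zero covariance across the changepoint.

import jax.numpy as jnp

from gpjax.kernels.stationary import RBF, ChangePoint

# Two RBFs with independent lengthscales, switching at t = 0.5.
k1 = RBF(lengthscale=jnp.array(0.1))
k2 = RBF(lengthscale=jnp.array(1.0))
kernel = ChangePoint(kernels=[k1, k2], tswitch=jnp.array(0.5))

x = jnp.linspace(0.0, 1.0, 20).reshape(-1, 1)
Kxx = kernel.gram(x).to_dense()  # entries pairing points across 0.5 are zero
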
39 changes: 39 additions & 0 deletions tests/test_kernels/test_stationary.py
@@ -34,6 +34,7 @@
)
from gpjax.kernels.stationary import (
    RBF,
    ChangePoint,
    Matern12,
    Matern32,
    Matern52,
@@ -241,6 +242,44 @@ class TestRationalQuadratic(BaseTestKernel):
    default_compute_engine = DenseKernelComputation()


@pytest.mark.parametrize(
    "kernel",
    [RBF, RationalQuadratic, Matern12, Matern32, Matern52],
)
@pytest.mark.parametrize("tswitch_param", [0.1, 0.25, 0.5, 0.75, 0.9])
def test_changepoint_kernel(kernel: AbstractKernel, tswitch_param: float) -> None:
    n_kerns = 2
[Review comment (daniel-dodd, Oct 10, 2023) on lines +249 to +251]
This looks good. The only thing is, I would like to see (a) an error is thrown if n_kerns is wrong and (b) this is tested.

[Reply (Author)]
Sorry to ask a basic question, but how does one write a test to make sure that a proper error is thrown?

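A minimal sketch of such a test using pytest.raises (hypothetical test name; it assumes ChangePoint raises TypeError for anything other than exactly two kernels, as in __post_init__ above):

import pytest

from gpjax.kernels.stationary import RBF, ChangePoint


@pytest.mark.parametrize("n_kerns", [1, 3])
def test_changepoint_wrong_number_of_kernels(n_kerns: int) -> None:
    # pytest.raises passes only if the block raises the expected error type.
    with pytest.raises(TypeError):
        ChangePoint(kernels=[RBF() for _ in range(n_kerns)], tswitch=0.5)
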
    # Create inputs
    n = 20
    x = jnp.linspace(0.0, 1.0, num=n).reshape(-1, 1)

    # Create list of kernels
    kernels = [kernel() for _ in range(n_kerns)]

    # Create combination kernel
    combination_kernel = ChangePoint(kernels=kernels, tswitch=tswitch_param)

    # Check the kernels attribute matches the supplied list
    assert combination_kernel.kernels == kernels

    # Check combination kernel set
    assert len(combination_kernel.kernels) == n_kerns
    assert isinstance(combination_kernel.kernels, list)
    assert isinstance(combination_kernel.kernels[0], AbstractKernel)

    # Compute gram matrix
    Kxx = combination_kernel.gram(x)

    # Check shapes
    assert Kxx.shape[0] == Kxx.shape[1]
    assert Kxx.shape[1] == n

    # Check positive definiteness
    jitter = 1e-6
    eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense() + jnp.eye(n) * jitter)
    assert (eigen_values > 0).all()


@pytest.mark.parametrize("smoothness", [1, 2, 3])
def test_build_studentt_dist(smoothness: int) -> None:
    dist = build_student_t_distribution(smoothness)