Add changepoint kernel #357

Status: Closed

Commits (26 total; the diff below shows changes from 20 commits)
6bc8889
Init commit: changepoint kernel for two RBFs with independent parameters
jakeyeung Aug 10, 2023
a94ac98
Rename changepoint_kernel to just changepoint
jakeyeung Aug 10, 2023
4e4fab4
Replace boilerplate with standard GPJax boilerplate for 2023
jakeyeung Aug 14, 2023
6f3c01b
Rename function GetFunctionIndex to get_function_index
jakeyeung Aug 14, 2023
2fc53e4
Add SwitchPoint kernel, where you can initialize with a list of two
jakeyeung Aug 14, 2023
34598df
Remove commented out import
jakeyeung Aug 14, 2023
21521ad
Document what is kernels and what is tswitch in SwitchKernel
jakeyeung Aug 14, 2023
46366ad
Remove ChangePointRBF, since SwitchKernel covers the case of two RBFs
jakeyeung Aug 14, 2023
d779b93
Rename SwitchKernel to ChangePoint to align with literature
jakeyeung Aug 14, 2023
b71e5f2
Init commit: changepoint kernel for two RBFs with independent parameters
jakeyeung Aug 10, 2023
0d25c5f
Rename changepoint_kernel to just changepoint
jakeyeung Aug 10, 2023
f424d6c
Replace boilerplate with standard GPJax boilerplate for 2023
jakeyeung Aug 14, 2023
6dede3c
Rename function GetFunctionIndex to get_function_index
jakeyeung Aug 14, 2023
25bc0b7
Add SwitchPoint kernel, where you can initialize with a list of two
jakeyeung Aug 14, 2023
b961221
Remove commented out import
jakeyeung Aug 14, 2023
dd4feea
Document what is kernels and what is tswitch in SwitchKernel
jakeyeung Aug 14, 2023
f1ee522
Remove ChangePointRBF, since SwitchKernel covers the case of two RBFs
jakeyeung Aug 14, 2023
db264be
Rename SwitchKernel to ChangePoint to align with literature
jakeyeung Aug 14, 2023
efeb977
Add unit test for changepoint kernel: same as combination kernel.
jakeyeung Oct 10, 2023
19a58da
Merge branch 'changepoint_kernel' of github.com:jakeyeung/GPJax into …
jakeyeung Oct 10, 2023
443269d
Add citation for ChangePoint kernel.
jakeyeung Oct 11, 2023
07f25c3
Change inputs x and y to be 1-dimensional as intended.
jakeyeung Oct 11, 2023
309d51d
Remove default argument
jakeyeung Oct 11, 2023
46c52f6
Document the purpose of get_function_index()
jakeyeung Oct 11, 2023
127bd6c
Rename variables flst to kernel_options
jakeyeung Oct 11, 2023
ddcd673
Raise error if exactly two kernels are not supplied in current
jakeyeung Oct 11, 2023
2 changes: 2 additions & 0 deletions gpjax/kernels/stationary/__init__.py
@@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================

from gpjax.kernels.stationary.changepoint import ChangePoint
from gpjax.kernels.stationary.matern12 import Matern12
from gpjax.kernels.stationary.matern32 import Matern32
from gpjax.kernels.stationary.matern52 import Matern52
@@ -23,6 +24,7 @@
from gpjax.kernels.stationary.white import White

__all__ = [
    "ChangePoint",
    "Matern12",
    "Matern32",
    "Matern52",
115 changes: 115 additions & 0 deletions gpjax/kernels/stationary/changepoint.py
@@ -0,0 +1,115 @@
# Copyright 2023 The JaxGaussianProcesses Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================


from dataclasses import dataclass

from beartype.typing import (
    Callable,
    List,
)
import jax
import jax.numpy as jnp
from jaxtyping import Float
import tensorflow_probability.substrates.jax.bijectors as tfb

from gpjax.base import (
    param_field,
    static_field,
)
from gpjax.kernels.base import AbstractKernel
from gpjax.typing import (
    Array,
    ScalarFloat,
)


@dataclass
class ChangePoint(AbstractKernel):
r"""A change point kernel
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a citation for your implementation of the kernel? We could add this to our citation.py. Also probably good to give a paper in the docstring. :)

Copy link
Author

@jakeyeung jakeyeung Oct 11, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The idea started with the GPflow documentation, but I guess if you're thinking about a paper it would be Saatci et al 2010 : (https://icml.cc/Conferences/2010/papers/170.pdf).

I have now added it the Saatci et al 2010 reference to citation.py and the doc strings.

self.kernels: A list of exactly two kernels that will be switched.
self.tswitch: The point at which to change to a different kernel.
for example: if x and y are both less than tswitch, then you would use kernels[0]
if x and y are both greater than or equal to tswitch, then you would use kernels[1]
otherwise return cross-covariance of 0
"""

    kernels: List[AbstractKernel] = None
    operator: Callable = static_field(None)
    tswitch: ScalarFloat = param_field(
        jnp.array(1.0), bijector=tfb.Identity(), trainable=False
    )
Comment on lines +55 to +57

@daniel-dodd (Member, Oct 10, 2023):
Is the changepoint location itself not learnt here (trainable=False)?

Also, what happens if we have multiple changepoints? (Does this kernel work recursively by passing a ChangePoint kernel to kernels?)

@jakeyeung (Author, Oct 11, 2023):
The idea started with implementing a basic change-point kernel similar to what GPflow has (see the change-point section in https://gpflow.github.io/GPflow/develop/notebooks/getting_started/kernels.html).

When I first implemented it, I was thinking of a fixed parameter, but I don't see any reason why it has to be that way.

At the moment this piece of code only works with two kernels, so one changepoint, but one can also imagine extending this to a list of N kernels with (N-1) changepoint locations.
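
For illustration, a minimal usage sketch of the two-kernel case described above (based on the docstring and the unit test in this PR; the exact kernel choices here are illustrative assumptions):

```python
import jax.numpy as jnp

from gpjax.kernels.stationary import ChangePoint, Matern32, RBF

# Behave like an RBF before tswitch and a Matern32 at/after it;
# the cross-covariance between the two regimes is zero.
kernel = ChangePoint(kernels=[RBF(), Matern32()], tswitch=0.5)

x = jnp.linspace(0.0, 1.0, num=20).reshape(-1, 1)
Kxx = kernel.gram(x)  # block-diagonal Gram matrix with two blocks
```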


    def __post_init__(self):
        # Add kernels to a list, flattening out instances of this class therein, as in GPflow kernels.
        kernels_list: List[AbstractKernel] = []

        for kernel in self.kernels:
            if not isinstance(kernel, AbstractKernel):
                raise TypeError("can only combine Kernel instances")  # pragma: no cover

            if isinstance(kernel, self.__class__):
                kernels_list.extend(kernel.kernels)
            else:
                kernels_list.append(kernel)

        self.kernels = kernels_list
Member:
This is great. But we should probably throw an error if we specify more than two kernels (with the current implementation)?


    def __call__(
        self,
        x: Float[Array, " D"],
        y: Float[Array, " D"],
    ) -> ScalarFloat:
Member:
Maybe I'd make clear that the intended dimension is one here (correct me if I'm wrong), i.e., Float[Array, " 1"] instead of Float[Array, " D"].

@jakeyeung (Author):
They should be 1-dimensional; I've fixed this now.

r"""Evaluate the kernel on a pair of inputs.

Args:
x (Float[Array, " D"]): The left hand input of the kernel function.
y (Float[Array, " D"]): The right hand input of the kernel function.

Returns
-------
ScalarFloat: The evaluated kernel function at the supplied inputs.
"""

        def get_function_index(x, y, tswitch=1):
Member:
I would probably drop the default argument here.

Suggested change:
-        def get_function_index(x, y, tswitch=1):
+        def get_function_index(x, y, tswitch):

r"""
Specify four possible indices given x, y, and tswitch.

Args:
x: Left hand argument of kernel function's call
y: Right hand argument of kernel function's call
tswitch: point at which to change to a different kernel
"""
# Four possible indexes: 0, 1, 2, 3
indx = 3 # if indx doesn't get set to 0, 1, or 2, then default 3
# lessthan means that at tswitch, you are already switched
cond1 = jnp.less(x, tswitch)
cond2 = jnp.less(y, tswitch)
indx = jnp.where(jnp.logical_and(cond1, cond2), 0, indx)
indx = jnp.where(jnp.logical_and(jnp.invert(cond1), cond2), 1, indx)
indx = jnp.where(jnp.logical_and(cond1, jnp.invert(cond2)), 2, indx)
return indx.squeeze().astype("uint8")
Member:
I would state somewhere here (or in the docstring) the purpose - i.e., the index is going to select which kernel function to compute.

But I am a bit confused as to why we need 4 indexes. Why don't we shorten it to three? The flst could be shortened to three given that k_zero is repeated, e.g., flst = [self.kernels[0].__call__, k_zero, self.kernels[1].__call__]. Just curious to know. :)

@jakeyeung (Author, Oct 11, 2023):
I have now added in words the purpose of this function. I have also renamed flst to kernel_options to be more explicit.

The four indices can be reduced to three. I found the code easier to read with four indices corresponding to the four quadrants of a 2D covariance matrix. But returning just three indices might make more sense if this function accepts more than two kernels.

E.g., visualizing the Gram matrix of the ChangePoint kernel you can see the four quadrants. You can imagine that with more than two kernels it would be a bunch of blocks along the diagonal of the matrix, with 0 everywhere else.

[image: Gram matrix of the ChangePoint kernel, showing the four quadrants]
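
A quick way to see the quadrant-to-index mapping (a sketch that restates the PR's get_function_index logic for scalar inputs; not part of the diff):

```python
import jax
import jax.numpy as jnp

def get_function_index(x, y, tswitch):
    cond1, cond2 = jnp.less(x, tswitch), jnp.less(y, tswitch)
    indx = jnp.where(jnp.logical_and(cond1, cond2), 0, 3)
    indx = jnp.where(jnp.logical_and(~cond1, cond2), 1, indx)
    indx = jnp.where(jnp.logical_and(cond1, ~cond2), 2, indx)
    return indx

xs = jnp.array([0.25, 0.75])  # one point before tswitch=0.5, one after
grid = jax.vmap(lambda x: jax.vmap(lambda y: get_function_index(x, y, 0.5))(xs))(xs)
# grid == [[0, 2],
#          [1, 3]]  -> the four quadrants of the Gram matrix
```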


        def k_zero(x, y):
            r"""Return 0 covariance."""
            out = jnp.float64(0)
            return out.squeeze()
Comment on lines +121 to +124

Member:
This is nice. But the only concern I have is with jnp.float64 potentially clashing with JAX's default float32. Could this not work with jnp.array(0.0) (or even 0.0)? Or does this break the jax.lax.switch or something? :)

Member:
If problems persist, you could always use a Constant(0.0) kernel here.

@jakeyeung (Author):
I don't know the details of JAX well enough to know how the different types play with other types and conditionals. What would be a good way to test which type is best to put in here?
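
One way to check (a sketch, not from the PR): under JAX's default configuration jnp.float64(0) is downcast to float32 (with a warning), while jnp.array(0.0) simply follows the active default dtype, and jax.lax.switch only requires that all branches return the same shape and dtype:

```python
import jax
import jax.numpy as jnp

# Under default settings (x64 disabled) this prints float32;
# with jax.config.update("jax_enable_x64", True) it prints float64.
print(jnp.array(0.0).dtype)

# All branches of jax.lax.switch must return the same shape/dtype,
# so the zero branch should match the kernels' output:
branches = [lambda x, y: jnp.array(0.0), lambda x, y: (x * y).sum()]
out = jax.lax.switch(0, branches, jnp.ones(1), jnp.ones(1))
assert out.dtype == jnp.array(0.0).dtype
```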


        indx = get_function_index(x, y, tswitch=self.tswitch)

        flst = [self.kernels[0].__call__, k_zero, k_zero, self.kernels[1].__call__]
        K = jax.lax.switch(indx, flst, x, y)
Member:
A note on naming: I don't know what flst means. Maybe something like kernel_options.

Also, why do we need self.kernels[0].__call__ and self.kernels[1].__call__? Does this not work?

kernel_options = self.kernels + [k_zero]

Would this help with extending this kernel to more than two kernels? Or is there some issue with computing the indexing?

@jakeyeung (Author, Oct 11, 2023):
Yeah, this would work. I just need to rewrite get_function_index() so that it returns an indx that selects a kernel, and otherwise sets indx past the last kernel to select k_zero.

In order to return the right index, this new get_function_index() would check whether both x and y are greater than or equal to a vector of locations (returning a boolean vector of the same size as locations), then return the greatest index for which this is True.

I can try out an initial implementation, but if you have code suggestions I would be happy to hear a clean way to code this.
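
A rough sketch of that idea (an illustration under stated assumptions, not code from this PR: `locations` is a sorted vector of N-1 changepoints for N kernels, and the zero-covariance branch sits at the end of kernel_options = self.kernels + [k_zero], i.e. at index len(locations) + 1):

```python
import jax.numpy as jnp

def get_function_index(x, y, locations):
    """Select which kernel to evaluate for the pair (x, y).

    x and y each fall into one of N blocks delimited by the sorted
    changepoint `locations`; matching blocks select that block's
    kernel, while mixed pairs select the zero-covariance branch.
    """
    ix = jnp.sum(x >= locations)  # number of changepoints x has passed
    iy = jnp.sum(y >= locations)
    zero_idx = locations.shape[0] + 1  # position of k_zero in kernel_options
    return jnp.where(ix == iy, ix, zero_idx).squeeze().astype("uint8")
```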


        return K.squeeze()
39 changes: 39 additions & 0 deletions tests/test_kernels/test_stationary.py
@@ -34,6 +34,7 @@
)
from gpjax.kernels.stationary import (
    RBF,
    ChangePoint,
    Matern12,
    Matern32,
    Matern52,
@@ -241,6 +242,44 @@ class TestRationalQuadratic(BaseTestKernel):
    default_compute_engine = DenseKernelComputation()


@pytest.mark.parametrize(
    "kernel",
    [RBF, RationalQuadratic, Matern12, Matern32, Matern52],
)
@pytest.mark.parametrize("tswitch_param", [0.1, 0.25, 0.5, 0.75, 0.9])
def test_changepoint_kernel(kernel: AbstractKernel, tswitch_param: float) -> None:
    n_kerns = 2
Comment on lines +249 to +251

@daniel-dodd (Member, Oct 10, 2023):
This looks good. The only thing is, I would like to see that (a) an error is thrown if n_kerns is wrong and (b) this is tested.

@jakeyeung (Author):
Sorry to ask a basic question, but how does one write a test to make sure that a proper error is thrown?
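
For reference, the usual pattern is pytest.raises (a sketch; the exact exception type ChangePoint raises for a wrong number of kernels is an assumption):

```python
import pytest

from gpjax.kernels.stationary import RBF, ChangePoint

def test_changepoint_requires_two_kernels() -> None:
    # The test passes iff the block below raises the named exception.
    with pytest.raises(ValueError):
        ChangePoint(kernels=[RBF(), RBF(), RBF()], tswitch=0.5)
```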

    # Create inputs
    n = 20
    x = jnp.linspace(0.0, 1.0, num=n).reshape(-1, 1)

    # Create list of kernels
    kernels = [kernel() for _ in range(n_kerns)]

    # Create combination kernel
    combination_kernel = ChangePoint(kernels=kernels, tswitch=tswitch_param)

    # Check the kernels are stored as given
    assert combination_kernel.kernels == kernels

    # Check combination kernel set
    assert len(combination_kernel.kernels) == n_kerns
    assert isinstance(combination_kernel.kernels, list)
    assert isinstance(combination_kernel.kernels[0], AbstractKernel)

    # Compute gram matrix
    Kxx = combination_kernel.gram(x)

    # Check shapes
    assert Kxx.shape[0] == Kxx.shape[1]
    assert Kxx.shape[1] == n

    # Check positive definiteness
    jitter = 1e-6
    eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense() + jnp.eye(n) * jitter)
    assert (eigen_values > 0).all()


@pytest.mark.parametrize("smoothness", [1, 2, 3])
def test_build_studentt_dist(smoothness: int) -> None:
    dist = build_student_t_distribution(smoothness)