Add changepoint kernel #357
New file:

@@ -0,0 +1,136 @@

```python
# Copyright 2023 The JaxGaussianProcesses Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from dataclasses import dataclass

from beartype.typing import (
    Callable,
    List,
)
import jax
import jax.numpy as jnp
from jaxtyping import Float
import tensorflow_probability.substrates.jax.bijectors as tfb

from gpjax.base import (
    param_field,
    static_field,
)
from gpjax.kernels.base import AbstractKernel
from gpjax.typing import (
    Array,
    ScalarFloat,
)


@dataclass
class ChangePoint(AbstractKernel):
    r"""A change-point kernel.

    See the Saatci, Turner, and Rasmussen (ICML 2010) paper for details.

    self.kernels: A list of exactly two kernels to switch between.
    self.tswitch: The point at which to change to a different kernel.
        For example: if x and y are both less than tswitch, kernels[0] is used;
        if x and y are both greater than or equal to tswitch, kernels[1] is
        used; otherwise the cross-covariance is 0.
    """

    kernels: List[AbstractKernel] = None
    operator: Callable = static_field(None)
    tswitch: ScalarFloat = param_field(
        jnp.array(1.0), bijector=tfb.Identity(), trainable=False
    )
```
Comment on lines +55 to +57

Reviewer: Is the changepoint location itself not learnt here (it is declared with `trainable=False`)? Also, what happens if we have multiple changepoints? Does this kernel work recursively by passing a ChangePoint kernel in the `kernels` list?

Author: The idea started with implementing a basic change-point kernel similar to what GPflow has (see the change-point section in https://gpflow.github.io/GPflow/develop/notebooks/getting_started/kernels.html). When I first implemented it, I was thinking of a fixed parameter, but I don't see any reason why it has to be that way. At the moment this piece of code only works with two kernels, so one changepoint, but one could also imagine extending this to a list of N kernels with (N - 1) changepoint locations.
```python
    def __post_init__(self):
        # Add kernels to a list, flattening out instances of this class
        # therein, as in GPflow kernels.
        kernels_list: List[AbstractKernel] = []

        if len(self.kernels) != 2:
            raise TypeError("Current implementation only accepts 2 kernels")

        for kernel in self.kernels:
            if not isinstance(kernel, AbstractKernel):
                raise TypeError("can only combine Kernel instances")  # pragma: no cover

            if isinstance(kernel, self.__class__):
                kernels_list.extend(kernel.kernels)
            else:
                kernels_list.append(kernel)

        self.kernels = kernels_list

    def __call__(
        self,
        x: Float[Array, " 1"],
        y: Float[Array, " 1"],
    ) -> ScalarFloat:
        r"""Evaluate the kernel on a pair of inputs.

        Args:
            x (Float[Array, " 1"]): The left-hand input of the kernel function.
            y (Float[Array, " 1"]): The right-hand input of the kernel function.

        Returns
        -------
            ScalarFloat: The evaluated kernel function at the supplied inputs.
        """

        def get_function_index(x, y, tswitch):
            r"""Specify which kernel function to compute given x, y, and tswitch.

            There are four indices, corresponding to the different functions
            used to compute the kernel value for a pair of points.

            First scenario: both points are less than tswitch (left of the
            change point), so the first kernel is evaluated.
            Second and third scenarios: the two points lie on opposite sides
            of tswitch, so the covariance is 0.
            Fourth scenario: both points are right of tswitch, so the second
            kernel is evaluated.

            Args:
                x: Left-hand argument of the kernel function's call.
                y: Right-hand argument of the kernel function's call.
                tswitch: Point at which to change to a different kernel.
            """
            # Four possible indices: 0, 1, 2, 3
            indx = 3  # if indx doesn't get set to 0, 1, or 2, default to 3
            # jnp.less means that at tswitch you have already switched
            cond1 = jnp.less(x, tswitch)
            cond2 = jnp.less(y, tswitch)
            indx = jnp.where(jnp.logical_and(cond1, cond2), 0, indx)
            indx = jnp.where(jnp.logical_and(jnp.invert(cond1), cond2), 1, indx)
            indx = jnp.where(jnp.logical_and(cond1, jnp.invert(cond2)), 2, indx)
            return indx.squeeze().astype("uint8")

        def k_zero(x, y):
            r"""Return 0 covariance."""
            out = jnp.float64(0)
            return out.squeeze()
```
Comment on lines +121 to +124

Reviewer: This is nice. But the only concern I have is with …

Reviewer: If problems persist, you could always use a …

Author: I don't know the details of …
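One subtlety worth noting in `k_zero` above (an observation, not something confirmed by the thread) is its use of `jnp.float64`: under JAX's default configuration, 64-bit mode is disabled, so an explicit float64 request is silently narrowed. A quick check, assuming `jax_enable_x64` has not been turned on:

```python
import jax.numpy as jnp

# Under default settings (jax_enable_x64 off), JAX has no 64-bit floats
# available, so jnp.float64(0) emits a UserWarning and yields a float32
# scalar; with jax.config.update("jax_enable_x64", True) it would stay
# float64.
zero = jnp.float64(0)
print(zero.dtype)
```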
```python
        indx = get_function_index(x, y, tswitch=self.tswitch)

        kernel_options = [
            self.kernels[0].__call__,
            k_zero,
            k_zero,
            self.kernels[1].__call__,
        ]
        K = jax.lax.switch(indx, kernel_options, x, y)

        return K.squeeze()
```
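The index-then-`jax.lax.switch` dispatch used in `__call__` can be exercised in isolation. A self-contained sketch with stand-in kernels (`k_left` and `k_right` are placeholders, not GPJax kernels):

```python
import jax
import jax.numpy as jnp


def k_left(x, y):
    # stand-in for kernels[0]: a squared-exponential shape
    return jnp.exp(-0.5 * (x - y) ** 2)


def k_right(x, y):
    # stand-in for kernels[1]: an exponential shape
    return jnp.exp(-jnp.abs(x - y))


def k_zero(x, y):
    # zero cross-covariance for points on opposite sides of tswitch
    return jnp.zeros(())


def changepoint_eval(x, y, tswitch):
    # replicate get_function_index: 0 = both left, 1/2 = mixed, 3 = both right
    cond1 = jnp.less(x, tswitch)
    cond2 = jnp.less(y, tswitch)
    indx = 3
    indx = jnp.where(jnp.logical_and(cond1, cond2), 0, indx)
    indx = jnp.where(jnp.logical_and(jnp.invert(cond1), cond2), 1, indx)
    indx = jnp.where(jnp.logical_and(cond1, jnp.invert(cond2)), 2, indx)
    branches = [k_left, k_zero, k_zero, k_right]
    return jax.lax.switch(indx.astype(jnp.int32), branches, x, y)
```

Points on the same side of `tswitch` get their region's kernel; mixed pairs get zero, which is what produces the block-diagonal Gram matrix the test below checks for positive definiteness.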
Changes to the test module:

@@ -34,6 +34,7 @@

```python
)
from gpjax.kernels.stationary import (
    RBF,
    ChangePoint,
    Matern12,
    Matern32,
    Matern52,
```
@@ -241,6 +242,44 @@ class TestRationalQuadratic(BaseTestKernel):

```python
    default_compute_engine = DenseKernelComputation()


@pytest.mark.parametrize(
    "kernel",
    [RBF, RationalQuadratic, Matern12, Matern32, Matern52],
)
@pytest.mark.parametrize("tswitch_param", [0.1, 0.25, 0.5, 0.75, 0.9])
def test_changepoint_kernel(kernel: AbstractKernel, tswitch_param: float) -> None:
    n_kerns = 2
```
Comment on lines +249 to +251

Reviewer: This looks good. The only thing is, I would like to see (a) an error is thrown if …

Author: Sorry to ask a basic question, but how does one write a test to make sure that a proper error is thrown?
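On the question above: pytest provides `pytest.raises` as a context manager for exactly this. A minimal sketch, using a hypothetical `make_changepoint` stand-in that mirrors the kernel-count check in `__post_init__` rather than the real `ChangePoint` class:

```python
import pytest


def make_changepoint(kernels):
    # hypothetical stand-in for ChangePoint's length validation
    if len(kernels) != 2:
        raise TypeError("Current implementation only accepts 2 kernels")
    return list(kernels)


def test_rejects_wrong_number_of_kernels():
    # the `with` block succeeds only if the expected exception is raised
    # and its message matches the given pattern
    with pytest.raises(TypeError, match="only accepts 2 kernels"):
        make_changepoint([object()])
```

Run under pytest, the test fails if no `TypeError` is raised, or if one is raised with a non-matching message.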
```python
    # Create inputs
    n = 20
    x = jnp.linspace(0.0, 1.0, num=n).reshape(-1, 1)

    # Create list of kernels
    kernels = [kernel() for _ in range(n_kerns)]

    # Create combination kernel
    combination_kernel = ChangePoint(kernels=kernels, tswitch=tswitch_param)

    # Check the kernels are stored as given
    assert combination_kernel.kernels == kernels

    # Check combination kernel set
    assert len(combination_kernel.kernels) == n_kerns
    assert isinstance(combination_kernel.kernels, list)
    assert isinstance(combination_kernel.kernels[0], AbstractKernel)

    # Compute gram matrix
    Kxx = combination_kernel.gram(x)

    # Check shapes
    assert Kxx.shape[0] == Kxx.shape[1]
    assert Kxx.shape[1] == n

    # Check positive definiteness
    jitter = 1e-6
    eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense() + jnp.eye(n) * jitter)
    assert (eigen_values > 0).all()


@pytest.mark.parametrize("smoothness", [1, 2, 3])
def test_build_studentt_dist(smoothness: int) -> None:
    dist = build_student_t_distribution(smoothness)
```
Reviewer: Is there a citation for your implementation of the kernel? We could add this to our `citation.py`. It is also probably good to give a paper in the docstring. :)

Author: The idea started with the GPflow documentation, but I guess if you're thinking about a paper it would be Saatci et al. 2010 (https://icml.cc/Conferences/2010/papers/170.pdf). I have now added the Saatci et al. 2010 reference to `citation.py` and the docstrings.