Add changepoint kernel #357
Changes from 20 commits
@@ -0,0 +1,115 @@
# Copyright 2023 The JaxGaussianProcesses Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================


from dataclasses import dataclass

from beartype.typing import (
    Callable,
    List,
)
import jax
import jax.numpy as jnp
from jaxtyping import Float
import tensorflow_probability.substrates.jax.bijectors as tfb

from gpjax.base import (
    param_field,
    static_field,
)
from gpjax.kernels.base import AbstractKernel
from gpjax.typing import (
    Array,
    ScalarFloat,
)

@dataclass
class ChangePoint(AbstractKernel):
    r"""A change point kernel.

    Attributes:
        kernels: A list of exactly two kernels to switch between.
        tswitch: The point at which to change to a different kernel.

    For example, if x and y are both less than tswitch, kernels[0] is used;
    if x and y are both greater than or equal to tswitch, kernels[1] is used;
    otherwise the cross-covariance is 0.
    """

    kernels: List[AbstractKernel] = None
    operator: Callable = static_field(None)
    tswitch: ScalarFloat = param_field(
        jnp.array(1.0), bijector=tfb.Identity(), trainable=False
    )
Comment on lines +55 to +57:

Is the changepoint location itself not learnt here (`trainable=False`)? Also, what happens if we have multiple changepoints (does this kernel work recursively by passing a ChangePoint kernel to `kernels`)?

Reply: The idea started with implementing a basic change-point kernel similar to what GPflow has (see the Change-point section in https://gpflow.github.io/GPflow/develop/notebooks/getting_started/kernels.html). When I first implemented it, I was thinking of a fixed parameter, but I don't see any reason why it has to be that way. At the moment this piece of code only works with two kernels, so one changepoint, but one can also imagine extending this to a list of N kernels with (N-1) changepoint locations.
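On the first point, a minimal sketch of a learnable changepoint location, assuming the `param_field` API shown in the diff (this flag flip is not part of the submitted change):

```python
# Hedged sketch, not in the diff: let the optimiser move the changepoint.
# A constraining bijector (e.g. tfb.Softplus) could replace tfb.Identity
# if the location must stay positive.
tswitch: ScalarFloat = param_field(
    jnp.array(1.0), bijector=tfb.Identity(), trainable=True
)
```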
    def __post_init__(self):
        # Add kernels to a list, flattening out instances of this class
        # therein, as in GPflow kernels.
        kernels_list: List[AbstractKernel] = []

        for kernel in self.kernels:
            if not isinstance(kernel, AbstractKernel):
                raise TypeError("can only combine Kernel instances")  # pragma: no cover

            if isinstance(kernel, self.__class__):
                kernels_list.extend(kernel.kernels)
            else:
                kernels_list.append(kernel)

        self.kernels = kernels_list
Comment:

This is great. But we should probably throw an error if we specify more than two kernels (with the current implementation)?
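A minimal sketch of such a guard, placed at the end of `__post_init__` (the error type and message are hypothetical, not part of the diff):

```python
# Hypothetical guard, not in the submitted diff: __call__ below indexes
# only kernels[0] and kernels[1], so anything else should fail loudly.
if len(kernels_list) != 2:
    raise ValueError(
        f"ChangePoint expects exactly two kernels, got {len(kernels_list)}."
    )
```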
    def __call__(
        self,
        x: Float[Array, " D"],
        y: Float[Array, " D"],
    ) -> ScalarFloat:
Comment:

Maybe I'd make clear that the intended dimension is one here (correct me if I'm wrong)? i.e., that the inputs are 1-dimensional.

Reply: They should be 1-dimensional; I've fixed that now.
r"""Evaluate the kernel on a pair of inputs. | ||||||||||
|
||||||||||
Args: | ||||||||||
x (Float[Array, " D"]): The left hand input of the kernel function. | ||||||||||
y (Float[Array, " D"]): The right hand input of the kernel function. | ||||||||||
|
||||||||||
Returns | ||||||||||
------- | ||||||||||
ScalarFloat: The evaluated kernel function at the supplied inputs. | ||||||||||
""" | ||||||||||
|
||||||||||
        def get_function_index(x, y, tswitch=1):
Comment:

I would probably drop the default argument here.

Suggested change:
-        def get_function_index(x, y, tswitch=1):
+        def get_function_index(x, y, tswitch):
r""" | ||||||||||
Specify four possible indices given x, y, and tswitch. | ||||||||||
|
||||||||||
Args: | ||||||||||
x: Left hand argument of kernel function's call | ||||||||||
y: Right hand argument of kernel function's call | ||||||||||
tswitch: point at which to change to a different kernel | ||||||||||
""" | ||||||||||
# Four possible indexes: 0, 1, 2, 3 | ||||||||||
indx = 3 # if indx doesn't get set to 0, 1, or 2, then default 3 | ||||||||||
# lessthan means that at tswitch, you are already switched | ||||||||||
cond1 = jnp.less(x, tswitch) | ||||||||||
cond2 = jnp.less(y, tswitch) | ||||||||||
indx = jnp.where(jnp.logical_and(cond1, cond2), 0, indx) | ||||||||||
indx = jnp.where(jnp.logical_and(jnp.invert(cond1), cond2), 1, indx) | ||||||||||
indx = jnp.where(jnp.logical_and(cond1, jnp.invert(cond2)), 2, indx) | ||||||||||
return indx.squeeze().astype("uint8") | ||||||||||
Comment:

I would comment somewhere here (or use the docstring) the purpose, i.e. that the index is going to select which kernel function to compute. But I am a bit confused as to why we need 4 indices. Why don't we shorten it to three, given that the two cross terms map to the same zero-covariance function?

Reply: I have now added the purpose of this function in words, and renamed […] accordingly. The four indices can be reduced to three. I found the code easier to read with four indices corresponding to the four quadrants of a 2D covariance matrix, but returning just three indices might make more sense if this function accepts more than two kernels. E.g., visualizing the Gram matrix of the ChangePoint kernel you can see the four quadrants. You can imagine that with more than two kernels it would be a bunch of blocks along the diagonal of the matrix, with 0 everywhere else.
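To make the four-quadrant picture concrete, here is a small hedged sketch that vectorises the same index logic over a grid (the names are illustrative only, not part of the diff):

```python
import jax.numpy as jnp

tswitch = 0.5
x = jnp.linspace(0.0, 1.0, num=6)

# Pairwise version of get_function_index over all (x_i, x_j) pairs.
below_row = x[:, None] < tswitch
below_col = x[None, :] < tswitch
indx = jnp.where(below_row & below_col, 0, 3)
indx = jnp.where(~below_row & below_col, 1, indx)
indx = jnp.where(below_row & ~below_col, 2, indx)
print(indx)
# [[0 0 0 2 2 2]
#  [0 0 0 2 2 2]
#  [0 0 0 2 2 2]
#  [1 1 1 3 3 3]
#  [1 1 1 3 3 3]
#  [1 1 1 3 3 3]]
```

The 0-block (kernels[0]) and 3-block (kernels[1]) sit on the diagonal; quadrants 1 and 2 are the zero cross-covariance terms.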
        def k_zero(x, y):
            r"""Return 0 covariance."""
            out = jnp.float64(0)
            return out.squeeze()
Comment on lines +121 to +124:

This is nice. But the only concern I have is with `jnp.float64`.

Reply: If problems persist, you could always use a […].

Reply: I don't know the details of […].
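If the concern is indeed `jnp.float64`, which warns and falls back to float32 unless `jax_enable_x64` is set, a hedged dtype-agnostic sketch might be:

```python
def k_zero(x, y):
    r"""Return 0 covariance in JAX's default dtype (sketch only)."""
    # jnp.zeros(()) yields a scalar without explicitly requesting
    # float64, so it behaves the same with or without x64 mode.
    return jnp.zeros(())
```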
        indx = get_function_index(x, y, tswitch=self.tswitch)

        flst = [self.kernels[0].__call__, k_zero, k_zero, self.kernels[1].__call__]
        K = jax.lax.switch(indx, flst, x, y)
Comment:

Maybe a note on naming: I don't know what […] means. Also, why do we need […]?

Would this help with extending this kernel to more than two kernels? Or is there some issue with computing this indexing?

Reply: Yeah, this would work. I just need to rewrite the […]. In order to return the right index, this new […]. I can try out an initial implementation, but if you have code suggestions I would be happy to hear a clean way to code this.
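For what it's worth, here is a hedged sketch of one way the index could generalise to N kernels with N-1 sorted changepoints (an illustration under stated assumptions, not the reviewer's elided suggestion; `get_block_index` and `tlocations` are hypothetical names):

```python
import jax.numpy as jnp

def get_block_index(x, y, tlocations):
    """Return the shared segment index of x and y, or the sentinel
    n_kernels when they straddle a changepoint (zero covariance)."""
    n_kernels = tlocations.shape[0] + 1
    # side="right" keeps the convention above: at a changepoint you
    # are already switched to the next kernel.
    ix = jnp.searchsorted(tlocations, x.squeeze(), side="right")
    iy = jnp.searchsorted(tlocations, y.squeeze(), side="right")
    return jnp.where(ix == iy, ix, n_kernels).astype("uint8")
```

The branch list for `jax.lax.switch` would then become `[k.__call__ for k in self.kernels] + [k_zero]`, with `k_zero` as the sentinel branch.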
        return K.squeeze()
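For orientation, a hedged usage sketch in the spirit of the GPflow example linked above (construction details assumed from this diff and the tests below):

```python
import jax.numpy as jnp
# Import path follows the test diff below; ChangePoint's location is
# this PR's placement, not released gpjax.
from gpjax.kernels.stationary import RBF, Matern32, ChangePoint

# RBF covariance before the changepoint at 0.5, Matern32 at and after
# it, with zero cross-covariance between the two regimes.
k = ChangePoint(kernels=[RBF(), Matern32()], tswitch=jnp.array(0.5))
x = jnp.linspace(0.0, 1.0, num=20).reshape(-1, 1)
Kxx = k.gram(x)  # block-diagonal Gram matrix
```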
@@ -34,6 +34,7 @@
)
from gpjax.kernels.stationary import (
    RBF,
    ChangePoint,
    Matern12,
    Matern32,
    Matern52,
@@ -241,6 +242,44 @@ class TestRationalQuadratic(BaseTestKernel):
    default_compute_engine = DenseKernelComputation()


@pytest.mark.parametrize(
    "kernel",
    [RBF, RationalQuadratic, Matern12, Matern32, Matern52],
)
@pytest.mark.parametrize("tswitch_param", [0.1, 0.25, 0.5, 0.75, 0.9])
def test_changepoint_kernel(kernel: AbstractKernel, tswitch_param: float) -> None:
    n_kerns = 2
Comment on lines +249 to +251:

This looks good. The only thing is, I would like to see (a) an error thrown if […].

Reply: Sorry to ask a basic question, but how does one write a test to make sure that a proper error is thrown?
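A common pattern is `pytest.raises`; here is a minimal sketch, assuming the hypothetical two-kernel `ValueError` guard discussed earlier:

```python
def test_changepoint_requires_two_kernels() -> None:
    # Hypothetical: assumes ChangePoint raises ValueError for anything
    # other than exactly two kernels.
    with pytest.raises(ValueError):
        ChangePoint(kernels=[RBF(), RBF(), RBF()], tswitch=0.5)
```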
    # Create inputs
    n = 20
    x = jnp.linspace(0.0, 1.0, num=n).reshape(-1, 1)

    # Create list of kernels
    kernels = [kernel() for _ in range(n_kerns)]

    # Create combination kernel
    combination_kernel = ChangePoint(kernels=kernels, tswitch=tswitch_param)

    # Check the kernel list is stored as given
    assert combination_kernel.kernels == kernels

    # Check combination kernel set
    assert len(combination_kernel.kernels) == n_kerns
    assert isinstance(combination_kernel.kernels, list)
    assert isinstance(combination_kernel.kernels[0], AbstractKernel)

    # Compute gram matrix
    Kxx = combination_kernel.gram(x)

    # Check shapes
    assert Kxx.shape[0] == Kxx.shape[1]
    assert Kxx.shape[1] == n

    # Check positive definiteness
    jitter = 1e-6
    eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense() + jnp.eye(n) * jitter)
    assert (eigen_values > 0).all()


@pytest.mark.parametrize("smoothness", [1, 2, 3])
def test_build_studentt_dist(smoothness: int) -> None:
    dist = build_student_t_distribution(smoothness)
Comment:

Is there a citation for your implementation of the kernel? We could add this to our citation.py. Also, it's probably good to give a paper in the docstring. :)

Reply: The idea started with the GPflow documentation, but I guess if you're thinking about a paper it would be Saatci et al. 2010 (https://icml.cc/Conferences/2010/papers/170.pdf). I have now added the Saatci et al. 2010 reference to citation.py and the docstrings.