Add changepoint kernel #357
Thank you for opening your first PR into GPJax!
If you have not heard from us in a while, please feel free to ping @gpjax/developers or anyone who has commented on the PR. Most of our reviewers are volunteers and sometimes things fall through the cracks.
You can also join us on Slack for real-time discussion.
For details on testing, writing docs, and our review process, please see the developer guide.
We strive to be a welcoming and open project. Please follow our Code of Conduct.
Nice work @jakeyeung ! I've left some comments that we should resolve before merging. It will also need some unit tests before we're able to merge.
#!/usr/bin/env python
"""
AUTHOR: Jake Yeung ([email protected])
CREATED ON: 2023-07-28
LAST CHANGE: see git log
LICENSE: Apache License
"""
Could you change this to the standard GPJax boilerplate please?
Should be done now.
kernels. Update new imports, check they are sorted with isort
Replaced ChangePointRBF with SwitchKernel, which accepts a list of two kernels.
Nice work @jakeyeung - this PR is looking good. If we can add some unit tests to this, then I think we'll be good to merge.
Hi @thomaspinder, glad to know it's progressing. Any ideas for unit tests that would be good here, both the standard ones for all kernels and ones specific to the changepoint?
…changepoint_kernel
Hi @thomaspinder, sorry for the delay. I have now added a
Nice work @jakeyeung! Just to check: have you rebased from the main branch - the change log seems huge here? Would just make it easier to see what you have added. :)
Hi Daniel, yeah I think what happened was that I reverted one step back before this change to a bunch of files not added by me due to
533347f to 19a58da
tswitch: ScalarFloat = param_field(
    jnp.array(1.0), bijector=tfb.Identity(), trainable=False
)
Is the changepoint location itself not learnt here (trainable=False)?
Also what happens if we have multiple changepoints (does this kernel work recursively by passing a ChangePoint kernel to kernels?)
The idea started with implementing a basic change-point kernel similar to what GPflow has (See Change-point section in https://gpflow.github.io/GPflow/develop/notebooks/getting_started/kernels.html).
When I first implemented it, I was thinking of a fixed parameter, but I don't see any reason why it has to be that way.
At the moment this piece of code only works with two kernels, so one changepoint, but one can also imagine extending this to a list of N kernels with (N-1) changepoint locations.
@dataclass
class ChangePoint(AbstractKernel):
    r"""A change point kernel
Is there a citation for your implementation of the kernel? We could add this to our citation.py. Also probably good to give a paper in the docstring. :)
The idea started with the GPflow documentation, but if you're after a paper, it would be Saatci et al. 2010 (https://icml.cc/Conferences/2010/papers/170.pdf).
I have now added the Saatci et al. 2010 reference to citation.py and the docstrings.
def __call__(
    self,
    x: Float[Array, " D"],
    y: Float[Array, " D"],
) -> ScalarFloat:
Maybe we should make it clear that the intended dimension is one here (correct me if I'm wrong), i.e., Float[Array, " 1"] instead of Float[Array, " D"].
They should be 1-dimensional; I've fixed this now.
def get_function_index(x, y, tswitch=1):
I would probably drop the default argument here.
- def get_function_index(x, y, tswitch=1):
+ def get_function_index(x, y, tswitch):
    r"""
    Specify four possible indices given x, y, and tswitch.

    Args:
        x: Left hand argument of kernel function's call
        y: Right hand argument of kernel function's call
        tswitch: point at which to change to a different kernel
    """
    # Four possible indexes: 0, 1, 2, 3
    indx = 3  # if indx doesn't get set to 0, 1, or 2, then default 3
    # Strict less-than: a point exactly at tswitch already counts as switched
    cond1 = jnp.less(x, tswitch)
    cond2 = jnp.less(y, tswitch)
    indx = jnp.where(jnp.logical_and(cond1, cond2), 0, indx)
    indx = jnp.where(jnp.logical_and(jnp.invert(cond1), cond2), 1, indx)
    indx = jnp.where(jnp.logical_and(cond1, jnp.invert(cond2)), 2, indx)
    return indx.squeeze().astype("uint8")
I would comment somewhere here (or use the docstring) the purpose - i.e., the index is going to select which kernel function to compute.
But I am a bit confused as to why we need 4 indexes? Why don't we shorten to three? As the flst could be shortened to three given k_zero is repeated? So e.g., we could have flst = [self.kernels[0].__call__, k_zero, self.kernels[1].__call__]. Just curious to know. :)
I have now described the purpose of this function in words. I have also renamed flst to kernel_options to be more explicit.
The four indices can be reduced to three. I found the code easier to read with four indices corresponding to the four quadrants of a 2D covariance matrix. But returning just three indices might make more sense if this function accepts more than two kernels.
For example, visualizing the Gram matrix of the ChangePoint kernel shows the four quadrants. With more than two kernels, it would be a series of blocks along the diagonal of the matrix, with 0 everywhere else.
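To make the four-quadrant picture concrete, here is a minimal sketch that evaluates the index function from this PR on every pair of points and prints the resulting index matrix. The grid `xs`, the value of `tswitch`, and the nested `jax.vmap` call are illustrative assumptions, not code from the PR.

```python
import jax
import jax.numpy as jnp


def get_function_index(x, y, tswitch):
    # 0: both points before tswitch, 1: only y before,
    # 2: only x before, 3: both at or after tswitch.
    cond1 = jnp.less(x, tswitch)
    cond2 = jnp.less(y, tswitch)
    indx = jnp.where(cond1 & cond2, 0, 3)
    indx = jnp.where(~cond1 & cond2, 1, indx)
    indx = jnp.where(cond1 & ~cond2, 2, indx)
    return indx


xs = jnp.linspace(0.0, 1.0, 6)  # illustrative 1D inputs: 0.0, 0.2, ..., 1.0
tswitch = 0.5  # illustrative changepoint location

# Pairwise index matrix: rows follow x, columns follow y.
index_matrix = jax.vmap(
    lambda x: jax.vmap(lambda y: get_function_index(x, y, tswitch))(xs)
)(xs)
print(index_matrix)
```

The top-left block (index 0) is where the first kernel applies, the bottom-right block (index 3) is where the second applies, and the two off-diagonal blocks (indices 1 and 2) both map to zero covariance.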
def k_zero(x, y):
    r"""Return 0 covariance"""
    out = jnp.float64(0)
    return out.squeeze()
This is nice. But the only concern I have is with jnp.float64 potentially clashing with JAX's default float32. Could this not work with jnp.array(0.0) (or even 0.0)? Or does this break the jax.lax.switch or something? :)
If problems persist, you could always use a Constant(0.0) kernel here.
I don't know JAX well enough to know how the different types interact with each other and with conditionals. What would be a good way to test which type is best to use here?
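One possible way to check this, sketched below under some assumptions: `jax.lax.switch` requires every branch to return the same shape and dtype, so `jax.eval_shape` can be used to compare what each branch would return without running it. The `rbf` function is a stand-in for a real kernel, and deriving the zero's dtype from the inputs via `jnp.result_type` is a suggestion, not what the PR does.

```python
import jax
import jax.numpy as jnp


def rbf(x, y):
    # Stand-in squared-exponential kernel on scalars (hypothetical helper).
    return jnp.exp(-0.5 * (x - y) ** 2)


def k_zero(x, y):
    # Zero covariance whose dtype follows the inputs instead of a fixed float64.
    return jnp.zeros((), dtype=jnp.result_type(x, y))


x = jnp.array(0.3)
y = jnp.array(0.7)

# eval_shape traces each function abstractly and reports its output shape and dtype.
rbf_spec = jax.eval_shape(rbf, x, y)
zero_spec = jax.eval_shape(k_zero, x, y)
print(rbf_spec.dtype, zero_spec.dtype)

# With matching output types, the branches can be combined in a switch.
out = jax.lax.switch(1, [rbf, k_zero], x, y)
```

Because the zero's dtype is taken from the inputs, the two branches should agree whether or not 64-bit mode is enabled.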
flst = [self.kernels[0].__call__, k_zero, k_zero, self.kernels[1].__call__]
K = jax.lax.switch(indx, flst, x, y)
A note on naming: I don't know what flst means. Maybe something like kernel_options.
Also, why do we need self.kernels[0].__call__ and self.kernels[1].__call__? Does this not work?
kernel_options = self.kernels + [k_zero]
Would this help with extending this kernel to more than two kernels? Or is there some issue with computing this indexing?
Yeah this would work. I just need to rewrite the get_function_index() so it returns an indx that selects a kernel, otherwise set indx = len(kernels) + 1 to select k_zero.
In order to return the right index, this new get_function_index() would check whether both x and y are greater than or equal to a vector of locations (returning a boolean vector of same size as locations), then return the greatest index for which this is True.
I can try out an initial implementation, but if you have code suggestions would be happy to hear a clean way to code this.
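One way this could look, as a rough sketch rather than the PR's implementation: count how many changepoint locations each input has passed to get its segment, select that segment's kernel when both inputs share a segment, and otherwise select the trailing `k_zero` slot. With `kernel_options = kernels + [k_zero]`, that slot sits at zero-based position `len(kernels)`. The names `rbf`, `linear`, `locations`, and `changepoint_cov` are hypothetical, and `locations` is assumed to be sorted in increasing order.

```python
import jax
import jax.numpy as jnp


def rbf(x, y):
    # Stand-in squared-exponential kernel on scalars (hypothetical helper).
    return jnp.exp(-0.5 * (x - y) ** 2)


def linear(x, y):
    # Stand-in linear kernel on scalars (hypothetical helper).
    return x * y


def k_zero(x, y):
    # Zero covariance with the same dtype as the inputs.
    return jnp.zeros((), dtype=jnp.result_type(x, y))


def get_function_index(x, y, locations, n_kernels):
    # Segment of a point = number of changepoints at or below it.
    seg_x = jnp.sum(x >= locations)
    seg_y = jnp.sum(y >= locations)
    # Same segment selects that kernel; otherwise select the k_zero slot.
    return jnp.where(seg_x == seg_y, seg_x, n_kernels).astype(jnp.int32)


def changepoint_cov(x, y, kernels, locations):
    # N kernels need N - 1 changepoint locations.
    kernel_options = list(kernels) + [k_zero]
    indx = get_function_index(x, y, locations, len(kernels))
    return jax.lax.switch(indx, kernel_options, x, y)


kernels = [rbf, linear, rbf]
locations = jnp.array([1.0, 2.0])
same_segment = changepoint_cov(jnp.array(1.5), jnp.array(1.2), kernels, locations)
across_segments = changepoint_cov(jnp.array(0.5), jnp.array(1.5), kernels, locations)
```

Here `same_segment` is evaluated with the middle (linear) kernel, while `across_segments` falls through to `k_zero`. Comparing segments, rather than only checking that both inputs exceed a location, keeps points from different segments uncorrelated.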
def __post_init__(self):
    # Add kernels to a list, flattening out instances of this class therein, as in GPFlow kernels.
    kernels_list: List[AbstractKernel] = []

    for kernel in self.kernels:
        if not isinstance(kernel, AbstractKernel):
            raise TypeError("can only combine Kernel instances")  # pragma: no cover

        if isinstance(kernel, self.__class__):
            kernels_list.extend(kernel.kernels)
        else:
            kernels_list.append(kernel)

    self.kernels = kernels_list
This is great. But we should probably throw an error if more than two kernels are specified (with the current implementation)?
@pytest.mark.parametrize("tswitch_param", [0.1, 0.25, 0.5, 0.75, 0.9])
def test_changepoint_kernel(kernel: AbstractKernel, tswitch_param: float) -> None:
    n_kerns = 2
This looks good. The only thing is, I would like to see (a) an error is thrown if n_kerns is wrong and (b) this is tested.
Sorry to ask a basic question: how does one write a test to make sure that the proper error is thrown?
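A common pattern for this is the `pytest.raises` context manager, which passes only if the enclosed code raises the expected exception. The sketch below uses a hypothetical stand-in class, `ChangePointSketch`, that models only the length check. The choice of `ValueError` and the wording of the message are assumptions, not what the PR settled on.

```python
import pytest


class ChangePointSketch:
    # Hypothetical stand-in for the kernel; only the length check is modelled.
    def __init__(self, kernels):
        if len(kernels) != 2:
            raise ValueError(f"Expected exactly 2 kernels, got {len(kernels)}.")
        self.kernels = list(kernels)


@pytest.mark.parametrize("n_kerns", [1, 3, 4])
def test_wrong_number_of_kernels_raises(n_kerns):
    # The test fails if no ValueError is raised or if the message does not match.
    with pytest.raises(ValueError, match="exactly 2 kernels"):
        ChangePointSketch(kernels=[object()] * n_kerns)
```

The `match` argument is a regular expression searched for in the error message, which guards against the test passing on an unrelated `ValueError`.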
Thanks @jakeyeung. Overall this looks really good. Quite excited to give this a go on some data soon. I am mainly asking clarification questions in the above review. The main points are as follows. I would suggest making it clear that this is for 1D data. This needs to be in the docstrings and in the typing if possible. I would like to see a check for the number of kernels. If this kernel is just for two kernels (and not more), then maybe instead of
Type of changes
Checklist
poetry run pre-commit run --all-files --show-diff-on-failure before committing.
Description
Add change point kernel for two RBFs.
Issue Number: N/A
Related to Q&A in Discussion #337