Add changepoint kernel #357

Closed

@jakeyeung jakeyeung commented Aug 10, 2023

Type of changes

  • Bug fix
  • New feature
  • Documentation / docstrings
  • Tests
  • Other

Checklist

  • I've formatted the new code by running poetry run pre-commit run --all-files --show-diff-on-failure before committing.
  • I've added tests for new code.
  • I've added docstrings for the new code.

Description

Add change point kernel for two RBFs.

Issue Number: N/A
Related to Q&A in Discussion #337

@github-actions github-actions bot left a comment

Thank you for opening your first PR into GPJax!

If you have not heard from us in a while, please feel free to ping @gpjax/developers or anyone who has commented on the PR. Most of our reviewers are volunteers and sometimes things fall through the cracks.

You can also join us on Slack for real-time discussion.

For details on testing, writing docs, and our review process, please see the developer guide

We strive to be a welcoming and open project. Please follow our Code of Conduct.

Collaborator

@thomaspinder thomaspinder left a comment

Nice work @jakeyeung ! I've left some comments that we should resolve before merging. It will also need some unit tests before we're able to merge.

Comment on lines 1 to 8
#!/usr/bin/env python
"""
AUTHOR: Jake Yeung ([email protected])
CREATED ON: 2023-07-28
LAST CHANGE: see git log
LICENSE: Apache License
"""

Collaborator

Could you change this to the standard GPJax boilerplate please?

Author

Should be done now.

gpjax/kernels/stationary/changepoint.py (8 outdated review threads, resolved)
Author

@jakeyeung jakeyeung left a comment

Replaced ChangePointRBF with SwitchKernel, which accepts a list of two kernels.

@thomaspinder
Collaborator

Nice work @jakeyeung - this PR is looking good. If we can add some unit tests to this, then I think we'll be good to merge.

@jakeyeung
Author

> Nice work @jakeyeung - this PR is looking good. If we can add some unit tests to this, then I think we'll be good to merge.

Hi @thomaspinder, glad to know it's progressing. Any ideas for unit tests that would be good, both standard ones for all kernels and ones specific to the changepoint kernel?
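
As a rough illustration of what such tests might check, here is a minimal sketch in plain JAX. The `rbf`, `changepoint`, `gram`, and `check_changepoint_gram` names are hypothetical stand-ins and not GPJax's API. The generic checks are that the Gram matrix has the right shape, is symmetric, and is positive semi-definite; the changepoint-specific check is that covariance across the switch location is exactly zero.

```python
import jax
import jax.numpy as jnp


def rbf(x, y, lengthscale=1.0):
    """Squared-exponential kernel on scalar inputs (stand-in for a real kernel)."""
    return jnp.exp(-0.5 * ((x - y) / lengthscale) ** 2)


def changepoint(x, y, tswitch):
    """Hypothetical two-kernel changepoint: one RBF per side, zero across sides."""
    before = rbf(x, y, lengthscale=1.0)
    after = rbf(x, y, lengthscale=0.2)
    same_side = (x < tswitch) == (y < tswitch)
    return jnp.where(same_side, jnp.where(x < tswitch, before, after), 0.0)


def gram(kernel_fn, xs):
    """Build the full Gram matrix by vmapping over both arguments."""
    return jax.vmap(lambda a: jax.vmap(lambda b: kernel_fn(a, b))(xs))(xs)


def check_changepoint_gram(tswitch):
    xs = jnp.linspace(0.0, 1.0, 20)
    K = gram(lambda a, b: changepoint(a, b, tswitch), xs)
    # Standard checks for any kernel: shape, symmetry, positive semi-definiteness.
    assert K.shape == (20, 20)
    assert jnp.allclose(K, K.T)
    # Loose tolerance, since float32 round-off can push eigenvalues slightly below 0.
    assert jnp.linalg.eigvalsh(K).min() > -1e-3
    # Changepoint-specific check: no covariance across the switch location.
    below = xs < tswitch
    cross_mask = below[:, None] & ~below[None, :]
    assert jnp.all(jnp.where(cross_mask, K, 0.0) == 0.0)
    return K
```

In a test suite, `check_changepoint_gram` could be wrapped in a parametrized test over several switch locations.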

@jakeyeung
Author

Hi @thomaspinder, sorry for the delay. I have now added a test_changepoint_kernel() function under test_stationary.py. Hope everything is alright now to merge.

@daniel-dodd
Member

Nice work @jakeyeung! Just to check: have you rebased from the main branch - the change log seems huge here? Would just make it easier to see what you have added. :)

@jakeyeung
Author

jakeyeung commented Oct 10, 2023

> Nice work @jakeyeung! Just to check: have you rebased from the main branch - the change log seems huge here? Would just make it easier to see what you have added. :)

Hi Daniel, yeah, I think what happened was that git fetch upstream and git rebase upstream/main updated a bunch of files, and then running poetry run pre-commit run --all-files changed a lot of files that weren't added by me.

I reverted to the step just before poetry run pre-commit run --all-files changed those files, so the PR should now contain changes to just three files.

Comment on lines +51 to +53
tswitch: ScalarFloat = param_field(
    jnp.array(1.0), bijector=tfb.Identity(), trainable=False
)
Member

@daniel-dodd daniel-dodd Oct 10, 2023

Is the changepoint location itself not learnt here (trainable=False)?

Also, what happens if we have multiple changepoints? Does this kernel work recursively by passing a ChangePoint kernel to kernels?

Author

@jakeyeung jakeyeung Oct 11, 2023

The idea started with implementing a basic change-point kernel similar to what GPflow has (See Change-point section in https://gpflow.github.io/GPflow/develop/notebooks/getting_started/kernels.html).

When I first implemented it, I was thinking of a fixed parameter, but I don't see any reason why it has to be that way.

At the moment this piece of code only works with two kernels, so one changepoint, but one can also imagine extending this to a list of N kernels with (N-1) changepoint locations.


@dataclass
class ChangePoint(AbstractKernel):
    r"""A change point kernel
Member

Is there a citation for your implementation of the kernel? We could add this to our citation.py. Also probably good to give a paper in the docstring. :)

Author

@jakeyeung jakeyeung Oct 11, 2023

The idea started with the GPflow documentation, but I guess if you're thinking about a paper it would be Saatci et al. 2010 (https://icml.cc/Conferences/2010/papers/170.pdf).

I have now added the Saatci et al. 2010 reference to citation.py and the docstrings.

Comment on lines 70 to 74
def __call__(
    self,
    x: Float[Array, " D"],
    y: Float[Array, " D"],
) -> ScalarFloat:
Member

Maybe I'd make it clear that the intended dimension is one here (correct me if I'm wrong), i.e., Float[Array, " 1"] instead of Float[Array, " D"].

Author

They should be 1-dimensional; I've fixed this now.

Comment on lines 85 to 86

def get_function_index(x, y, tswitch=1):
Member

I would probably drop the default argument here.

Suggested change
- def get_function_index(x, y, tswitch=1):
+ def get_function_index(x, y, tswitch):

Comment on lines 87 to 103
    r"""
    Specify four possible indices given x, y, and tswitch.

    Args:
        x: Left hand argument of kernel function's call
        y: Right hand argument of kernel function's call
        tswitch: point at which to change to a different kernel
    """
    # Four possible indexes: 0, 1, 2, 3
    indx = 3  # if indx doesn't get set to 0, 1, or 2, then default 3
    # lessthan means that at tswitch, you are already switched
    cond1 = jnp.less(x, tswitch)
    cond2 = jnp.less(y, tswitch)
    indx = jnp.where(jnp.logical_and(cond1, cond2), 0, indx)
    indx = jnp.where(jnp.logical_and(jnp.invert(cond1), cond2), 1, indx)
    indx = jnp.where(jnp.logical_and(cond1, jnp.invert(cond2)), 2, indx)
    return indx.squeeze().astype("uint8")
Member

I would comment somewhere here (or use the docstring) the purpose - i.e., the index is going to select which kernel function to compute.

But I am a bit confused as to why we need four indices. Why don't we shorten to three, since flst could be shortened to three entries given that k_zero is repeated? E.g., we could have flst = [self.kernels[0].__call__, k_zero, self.kernels[1].__call__]. Just curious to know. :)
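
A minimal sketch of how the three-index variant could look, assuming two stand-in RBF kernels with hypothetical lengthscales rather than the PR's actual kernel objects. Index 0 selects the first kernel, index 2 the second, and index 1 the single zero-covariance branch.

```python
import jax
import jax.numpy as jnp


def get_function_index(x, y, tswitch):
    """Return 0 if both inputs are before tswitch, 2 if both are at or after it,
    and 1 (the zero-covariance branch) if they fall on opposite sides."""
    x_before = jnp.less(x, tswitch)
    y_before = jnp.less(y, tswitch)
    both_before = jnp.logical_and(x_before, y_before)
    both_after = jnp.logical_and(~x_before, ~y_before)
    indx = jnp.where(both_before, 0, jnp.where(both_after, 2, 1))
    return indx.squeeze().astype(jnp.int32)


def rbf(x, y, lengthscale):
    """Stand-in kernel; the real code would call self.kernels[i] instead."""
    return jnp.exp(-0.5 * jnp.sum((x - y) ** 2) / lengthscale**2)


def k_zero(x, y):
    # Follow the input dtype so every branch returns the same type.
    return jnp.zeros((), dtype=jnp.result_type(x, y))


def changepoint(x, y, tswitch):
    kernel_options = [
        lambda a, b: rbf(a, b, 1.0),  # both before the changepoint
        k_zero,                       # opposite sides
        lambda a, b: rbf(a, b, 0.2),  # both at or after the changepoint
    ]
    return jax.lax.switch(get_function_index(x, y, tswitch), kernel_options, x, y)
```

The mixed cases collapse into one branch because the covariance is zero regardless of which input lies before the changepoint.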

Author

@jakeyeung jakeyeung Oct 11, 2023

I have now added in words the purpose of this function. I have also renamed flst to kernel_options to be more explicit.

The four indices can be reduced to three. I found the code easier to read to have four indices corresponding to four quadrants in a 2D covariance matrix. But returning just three indices might make more sense if this function accepts more than two kernels.

E.g., visualizing the Gram matrix of the ChangePoint kernel, you can see the four quadrants. With more than two kernels, it would be a series of blocks along the diagonal of the matrix, with 0 everywhere else.

[Image: Gram matrix of the ChangePoint kernel]
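
The quadrant structure can also be seen without a plot. The sketch below reuses the four-way index logic quoted earlier in this review (minus the squeeze, since the inputs here are scalars) and evaluates it on a small hypothetical grid with the switch at 0.5; the grid values and the `index_map` name are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def get_function_index(x, y, tswitch):
    """Four-way index as in the PR: 0 and 3 select kernels, 1 and 2 select zero."""
    cond1 = jnp.less(x, tswitch)
    cond2 = jnp.less(y, tswitch)
    indx = 3
    indx = jnp.where(jnp.logical_and(cond1, cond2), 0, indx)
    indx = jnp.where(jnp.logical_and(jnp.invert(cond1), cond2), 1, indx)
    indx = jnp.where(jnp.logical_and(cond1, jnp.invert(cond2)), 2, indx)
    return indx.astype("uint8")


# Two points before the switch and two after it.
xs = jnp.array([0.1, 0.3, 0.6, 0.9])

# Rows correspond to x, columns to y.
index_map = jax.vmap(
    lambda a: jax.vmap(lambda b: get_function_index(a, b, 0.5))(xs)
)(xs)
print(index_map)
```

The top-left 2x2 block of `index_map` is all 0 (first kernel), the bottom-right block is all 3 (second kernel), and the off-diagonal blocks hold 2 and 1, both of which map to zero covariance.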

Comment on lines +105 to +108
def k_zero(x, y):
    r"""Return 0 covariance"""
    out = jnp.float64(0)
    return out.squeeze()
Member

This is nice. But the only concern I have is with jnp.float64 potentially clashing with JAX's default float32. Could this not work with jnp.array(0.0) (or even 0.0)? Or does this break the jax.lax.switch or something? :)

Member

If problems persist, you could always use a Constant(0.0) kernel here.

Author

I don't know the details of jax enough to know how the different types play with other types and conditionals. What would be a good way to test which type is best to put in here?
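
One possible way to check this, sketched under the assumption that the kernel branches return scalars in the dtype of their inputs: build the zero from the input dtype, then compare the output shape and dtype of each branch with `jax.eval_shape` before handing them to `jax.lax.switch`. The `rbf` stand-in and the `branch_signature` helper are hypothetical names introduced for illustration.

```python
import jax
import jax.numpy as jnp


def rbf(x, y):
    """Stand-in for a real kernel call; returns a scalar in the input dtype."""
    return jnp.exp(-0.5 * jnp.sum((x - y) ** 2))


def k_zero(x, y):
    """Zero covariance whose dtype follows the inputs instead of a fixed float64."""
    return jnp.zeros((), dtype=jnp.result_type(x, y))


def branch_signature(fn, x, y):
    """Shape and dtype a branch would return, computed without running it."""
    out = jax.eval_shape(fn, x, y)
    return out.shape, out.dtype


x = jnp.array([0.3], dtype=jnp.float32)
y = jnp.array([0.7], dtype=jnp.float32)

# jax.lax.switch needs every branch to agree on output shape and dtype.
assert branch_signature(rbf, x, y) == branch_signature(k_zero, x, y)

# With matching signatures the switch traces cleanly, including under jit.
value = jax.jit(lambda i, a, b: jax.lax.switch(i, [rbf, k_zero], a, b))(1, x, y)
```

Running the same comparison with float64 inputs (with 64-bit mode enabled) would show whether a hard-coded dtype in `k_zero` causes a mismatch.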

Comment on lines 112 to 113
flst = [self.kernels[0].__call__, k_zero, k_zero, self.kernels[1].__call__]
K = jax.lax.switch(indx, flst, x, y)
Member

A point on naming: I don't know what flst means. Maybe something like kernel_options.

Also, why do we need self.kernels[0].__call__ and self.kernels[1].__call__? Does this not work?

kernel_options = self.kernels + [k_zero]

Would this help with extending this kernel to more than two kernels? Or is there some issue with computing this indexing?

Author

@jakeyeung jakeyeung Oct 11, 2023

Yeah this would work. I just need to rewrite the get_function_index() so it returns an indx that selects a kernel, otherwise set indx = len(kernels) + 1 to select k_zero.

In order to return the right index, this new get_function_index() would check whether both x and y are greater than or equal to a vector of locations (returning a boolean vector of same size as locations), then return the greatest index for which this is True.

I can try out an initial implementation, but if you have code suggestions would be happy to hear a clean way to code this.
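
A minimal sketch of one way this could be coded, offered as a variation on the idea above rather than the PR's implementation. Instead of taking the greatest index for which both inputs exceed a location, it counts how many locations each input has passed (its segment number) and selects a kernel only when the two segments match. With kernel_options = kernels + [k_zero] and zero-based indexing, k_zero sits at index len(kernels). The `make_rbf` helper, lengthscales, and locations are hypothetical.

```python
import jax
import jax.numpy as jnp


def get_function_index(x, y, locations, n_kernels):
    """Pick the kernel for the segment shared by x and y, else the zero branch.

    locations is a sorted vector of n_kernels - 1 changepoint locations.
    """
    # Counting how many locations an input has passed gives its segment number.
    seg_x = jnp.sum(jnp.greater_equal(x, locations))
    seg_y = jnp.sum(jnp.greater_equal(y, locations))
    # Index n_kernels is the last entry of kernel_options, i.e. k_zero.
    return jnp.where(seg_x == seg_y, seg_x, n_kernels).astype(jnp.int32)


def make_rbf(lengthscale):
    """Stand-in for a kernel object; returns a callable k(x, y)."""
    def rbf(x, y):
        return jnp.exp(-0.5 * jnp.sum((x - y) ** 2) / lengthscale**2)
    return rbf


def k_zero(x, y):
    return jnp.zeros((), dtype=jnp.result_type(x, y))


def changepoint(x, y, kernels, locations):
    kernel_options = list(kernels) + [k_zero]
    indx = get_function_index(x, y, locations, len(kernels))
    return jax.lax.switch(indx, kernel_options, x, y)


kernels = [make_rbf(1.0), make_rbf(0.3), make_rbf(0.1)]
locations = jnp.array([0.3, 0.7])
xs = jnp.linspace(0.0, 1.0, 9).reshape(-1, 1)

# Gram matrix: three blocks along the diagonal, zeros everywhere else.
K = jax.vmap(
    lambda a: jax.vmap(lambda b: changepoint(a, b, kernels, locations))(xs)
)(xs)
```

With two kernels and a single location this reduces to the three-branch case, so the same function would cover both.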

Comment on lines 55 to 68
def __post_init__(self):
    # Add kernels to a list, flattening out instances of this class therein, as in GPFlow kernels.
    kernels_list: List[AbstractKernel] = []

    for kernel in self.kernels:
        if not isinstance(kernel, AbstractKernel):
            raise TypeError("can only combine Kernel instances")  # pragma: no cover

        if isinstance(kernel, self.__class__):
            kernels_list.extend(kernel.kernels)
        else:
            kernels_list.append(kernel)

    self.kernels = kernels_list
Member

This is great. But we should probably throw an error if we specify more than two kernels (with the current implementation)?

Comment on lines +249 to +251
@pytest.mark.parametrize("tswitch_param", [0.1, 0.25, 0.5, 0.75, 0.9])
def test_changepoint_kernel(kernel: AbstractKernel, tswitch_param: float) -> None:
    n_kerns = 2
Member

@daniel-dodd daniel-dodd Oct 10, 2023

This looks good. The only thing is, I would like to see (a) an error is thrown if n_kerns is wrong and (b) this is tested.

Author

Sorry to ask a basic question: how does one write a test to make sure that a proper error is thrown?
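
A common pattern for this is `pytest.raises`, sketched below. The `TwoKernelChangePoint` class, its error type, and its message are hypothetical stand-ins that only show the validation; the real kernel would perform the check in its own `__post_init__`.

```python
from dataclasses import dataclass, field
from typing import Callable, List

import pytest


@dataclass
class TwoKernelChangePoint:
    """Hypothetical stand-in for the PR's kernel; only the validation is shown."""

    kernels: List[Callable] = field(default_factory=list)

    def __post_init__(self):
        if len(self.kernels) != 2:
            raise ValueError(
                f"Expected exactly two kernels, got {len(self.kernels)}."
            )


@pytest.mark.parametrize("n_kerns", [0, 1, 3])
def test_wrong_number_of_kernels_raises(n_kerns):
    kernels = [lambda x, y: 0.0] * n_kerns
    # The test passes only if the constructor raises a matching ValueError.
    with pytest.raises(ValueError, match="exactly two kernels"):
        TwoKernelChangePoint(kernels=kernels)
```

If the constructor does not raise, or raises a different exception type, the `with` block fails the test.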

@daniel-dodd
Member

daniel-dodd commented Oct 10, 2023

Thanks @jakeyeung. Overall this looks really good. Quite excited to give this a go on some data soon. Mainly asking some clarification questions in the above review. The main points are as follows. I would suggest making it clear that this is for 1D data; this needs to be in the docstrings and in the typing if possible. I would also like to see a check for the number of kernels. If this kernel is just for two kernels (and not more), then instead of kernels as a list, it might be more sensible to have, e.g., before_kernel and after_kernel for the kernels before and after the changepoint. Maybe I'd also rename tswitch to location of the changepoint (or locations if this supports multiple kernels as an input).

@github-actions github-actions bot added the stale label Sep 2, 2024
@github-actions github-actions bot closed this Sep 9, 2024