Implement periodic kernel for HSGP #6877

Merged: 26 commits into pymc-devs:main, Dec 10, 2023

Conversation

@theorashid (Contributor) commented Aug 24, 2023

What is this PR about?

This is for #6868. We're adding functionality for periodic kernels in the HSGP approximation.

I'm opening a draft PR where I have:

  • Implemented the spectral density for the periodic kernel
  • Added a test for the spectral density for the periodic kernel
  • Simplified the HSGP test and set up @pytest.mark.parametrize so we can add cov_func=Periodic once the implementation is complete. I also removed the unused gp pytest fixture

Design questions

Before implementing, there are a few design questions I need help with:

  1. How do we want to implement this within the HSGP class? Do we want if/else logic within the HSGP class, or shall we create a separate HSGPPeriodic class that we can try and call from HSGP if cov_func=Periodic?
  2. The approximation is only for 1D so we need to assert that len(self._m) == 1 and equally cov_func.n_dims == 1. Where do we do this? (Note, this also makes self._m_star pointless.)
  3. Is it possible to have a centered/noncentered parametrisation in this case? Probably not because of the sine/cosine transform.
  4. How many betas do we need? The first sine term is 0 (sum of cosine terms from $j=0$ to $J$, sum of sine terms from $j=1$ to $J$; see the count sketch below). How does _drop_first fit in here to ensure identifiability?
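
A sketch of the coefficient count in question 4 (my reading, consistent with the expansion discussed later in this thread):

$f(x) \approx \sum_{j=0}^{J} \beta_j^{(c)}\, \tilde{q}_j \cos(j \omega_0 x) + \sum_{j=1}^{J} \beta_j^{(s)}\, \tilde{q}_j \sin(j \omega_0 x)$

so there are $(J + 1) + J = 2J + 1$ coefficients in total; the $j = 0$ sine term vanishes identically since $\sin(0) = 0$.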

Useful code snippets for HSGP class

Here are some rough snippets of code that will go in the HSGP class, but we should discuss where they belong before implementing. It is all untested.

Calculate the eigenvectors

from types import ModuleType
from typing import Optional, Union

import numpy as np
import pymc as pm
import pytensor.tensor as pt

TensorLike = Union[np.ndarray, pt.TensorVariable]

def calc_eigenvectors_periodic(
    Xs: TensorLike,
    period: TensorLike,
    m: int,
    tl: ModuleType = np,
):
    """
    The eigenvectors do not depend on the covariance hyperparameters.
    m is just an int here because we are not in multiple dimensions.
    If I'm right, I don't think we need the eigenvalues.
    """
    w0 = (2 * np.pi) / period  # angular frequency defining the periodicity
    m1 = tl.tile(w0 * Xs[:, None], m)  # (n, m) copies of the column w0 * Xs
    m2 = tl.diag(tl.arange(m))  # scales column j by j, giving j * w0 * x
    mw0x = m1 @ m2
    phi_cos = tl.cos(mw0x)
    phi_sin = tl.sin(mw0x)
    return phi_cos, phi_sin
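
A quick shape check of the snippet above with plain NumPy (untested like the rest; my own example):

x = np.linspace(0, 1, 50)
phi_cos, phi_sin = calc_eigenvectors_periodic(x, period=1.0, m=10)
print(phi_cos.shape, phi_sin.shape)  # (50, 10) (50, 10)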

Attempt at prior_linearized

def prior_linearized(self, Xs: TensorLike):
    # Index Xs using input_dim and active_dims of the covariance function
    Xs, _ = self.cov_func._slice(Xs)

    phi_cos, phi_sin = calc_eigenvectors_periodic(Xs, self.cov_func.period, self._m[0], tl=pt)

    J = pt.arange(0, self._m[0], 1)
    psd = self.cov_func.power_spectral_density(J)

    return (phi_cos, phi_sin), psd

Attempt at prior (ignoring centered/noncentered)

def prior(self, name: str, X: TensorLike, dims: Optional[str] = None):
    self._X_mean = pt.mean(X, axis=0)
    (phi_cos, phi_sin), psd = self.prior_linearized(X - self._X_mean)

    m0 = self._m[0]
    self._beta = pm.Normal(f"{name}_hsgp_coeffs_", size=(m0, 2))
    # the first sine eigenfunction is identically zero, so drop it
    f = phi_cos @ (psd * self._beta[:, 0]) + phi_sin[:, 1:] @ (psd[1:] * self._beta[1:, 1])

    self.f = pm.Deterministic(name, f, dims=dims)
    return self.f

Resources

TODO

General TODO that I will update:

  • Add documentation for HSGP class with periodic
  • Add test for HSGP periodic
  • Add assertions so everything is in 1D

Checklist

New features

  • Periodic kernel for HSGP implementation

Bugfixes

  • Removed unnecessary GP fixture in HSGP test

Documentation

  • Added docs for PSD of Periodic

Maintenance

N/A


📚 Documentation preview 📚: https://pymc--6877.org.readthedocs.build/en/6877/

@theorashid (Contributor, Author):

@bwengals if you have some time, please could I have your help with this so I don't mess up the current pymc.gp API?

@codecov (bot) commented Aug 24, 2023

Codecov Report

Merging #6877 (56963e0) into main (827918b) will increase coverage by 3.18%.
Report is 31 commits behind head on main.
The diff coverage is 95.38%.

Additional details and impacted files:

@@            Coverage Diff             @@
##             main    #6877      +/-   ##
==========================================
+ Coverage   89.01%   92.19%   +3.18%     
==========================================
  Files         100      101       +1     
  Lines       16859    16909      +50     
==========================================
+ Hits        15007    15590     +583     
+ Misses       1852     1319     -533     
Files Coverage Δ
pymc/gp/__init__.py 100.00% <100.00%> (ø)
pymc/gp/cov.py 98.02% <100.00%> (+0.02%) ⬆️
pymc/gp/hsgp_approx.py 95.00% <95.00%> (+2.61%) ⬆️

... and 39 files with indirect coverage changes

@bwengals (Contributor):

Do we want an if/else logic within HSGP class or shall we create a separate HSGPPeriodic class that we can try and call from HSGP if cov_func=Periodic

I think this would be ideal! But it's hard to picture how tricky it'd make the code without actually trying to write it. I think you'd just need to check if cov_func is Periodic and if cov_func.n_dims == 1. Might have to rename HSGP to be a function that dispatches to each class? A longer-term thought (tagging @ricardoV94) is to port the GP stuff into pytensor to take advantage of all the rewrites and optimizations it can do.
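
A rough sketch of that dispatch idea (hypothetical; not what was eventually merged):

def hsgp_dispatch(m, cov_func, **kwargs):
    # route Periodic covariances to a dedicated approximation class
    if isinstance(cov_func, pm.gp.cov.Periodic) and cov_func.n_dims == 1:
        return HSGPPeriodic(m=m, cov_func=cov_func, **kwargs)
    return HSGP(m=m, cov_func=cov_func, **kwargs)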

@theorashid (Contributor, Author) commented Aug 25, 2023

@bwengals I've put a basic API there which uses a single class and if/else logic. Let me know what you think. I'm a bit concerned that multiple classes would bloat the file with repeated documentation. However, with a single class, the arguments L, c, parametrization (and maybe drop_first, but not sure) are redundant for the Periodic case.

If you're okay with it, please could you explain how I would go about testing the code. I need to adapt the two-sample test, but I'm unsure what this line is doing and how I would generalise for the Periodic case/other cov_func. And is there anything else I need to change?

Aside: There are a few issues with the type hints that aren't passing mypy, mostly to do with m: Sequence[int]. I'll fix them later down the line (either by removing the type hint for m in Periodic.power_spectral_density(self, m), or by dealing with the 1-D and many-D cases in HSGP for tuple(m); I also need to update calc_basis_periodic to take m as an int, or to slice m[-1]). I'm not sure how to fix Value of type "Normal" is not indexable; I thought you could index in pymc.

I think this now works for both 1-D cases, X and X[:, None]. I tried using an outer product in calc_basis_periodic, but it collapses the extended dim; I'll have to work out which is preferable during testing.

@bwengals (Contributor) commented Sep 8, 2023

I've put a basic API there which uses a single class and if-else logic. Let me know what you think. I'm a bit concerned that multiple classes would bloat the file with repeat documentation. However, with a single class, the arguments L, c, parametrization (and maybe drop_first, but not sure) are redundant for the Periodic case.

I see what you're saying now. I like how you've got it here better than what I was picturing with multiple classes. It came out really simple. I think it'd be good to ignore drop_first for the Periodic case; I kind of want to remove it since it only makes sense in 1D. The input-argument checking logic will have to be updated to handle the Periodic case, where you don't need L or c. I think it makes sense at that point to just start with if cov == Periodic: and handle the two cases at the top level, wdyt?

If you're okay with it, please could you explain how I would go about testing the code. I need to adapt the two-sample test, but I'm unsure what this line is doing and how I would generalise for the Periodic case/other cov_func. And is there anything else I need to change?

That MMD statistic thing is OK as is. It compares two sets of multivariate samples to see if they're from the same distribution. In the tests I'm drawing samples from the HSGP prior and the regular GP prior and doing a hypothesis test to see if they're the same.

I think this case works for both 1-D cases X and X[:, None]. I tried using an outer product in calc_basis_periodic but it collapses the extended dim, but I'll have to work out which is preferable during testing.

All the other GP stuff in PyMC requires X to have the column vector shape with the [:, None], so I think it's def OK to expect that and not have to have it work for both shapes (n, ) and (n, 1)

@theorashid (Contributor, Author) commented Sep 15, 2023

Sorry I'm being slow on this. I'm having some paper review + thesis nightmares atm.

While I'm waiting for some free time to do this, a few questions @bwengals:

I kind of want to remove it since it only makes sense in 1D.

If you let me know how, I'm happy to implement that in this PR

That MMD statistic thing is OK as is. It compares two sets of multivariate samples to see if they're from the same distribution. In the tests I'm drawing samples from the HSGP prior and the regular GP prior and doing a hypothesis test to see if they're the same.

I'm a bit confused how to generalise this to the Periodic kernel though (so we can test that the HSGP and GP priors with a Periodic kernel match), because the ExpQuad is hardcoded in there. In particular, I am not sure how you chose the test_ell, and whether for Periodic I'd need a test_w0 for example. IIUC, the test_ell can be anything, because you're just checking the draws are from the same distribution, not checking whether the draws fit any data well (i.e. not testing whether this is a good ell for the data).

All the other GP stuff in PyMC requires X to have the column vector shape with the [:, None]

Do we force this using np.atleast_2d or something? Or check shapes?

@bwengals (Contributor):

If you let me know how, I'm happy to implement that in this PR

Sure! Though maybe it's best to deprecate it. Maybe just put a deprecation warning in if drop_first is True?

So the ExpQuad there for the MMD statistic is used to compare samples to each other. It's just measuring distances from one sample to the next. Those samples could be from a GP with ExpQuad kernel or Periodic kernel or something else, doesn't matter. The number of input dimensions there is the number of dimensions in the sample.

The intuition behind the test (at least as I understand it) is: say you have two sets of samples from two multivariate distributions, represented by random variables X and Y, and you want to test whether the two distributions are the same using the samples. Say the dimension of the multivariate distribution(s) is 10, s_X are the samples from the distribution of RV X, and s_Y are the samples from RV Y. You calculate mean(K(s_X, s_X)) + mean(K(s_Y, s_Y)) - 2 * mean(K(s_X, s_Y)) as your statistic. K can be any kernel function, and it has input_dim = 10, though some will be more discerning than others. So if you think about it, this number will be near zero if the two sets of samples are from the same distribution: the "distances" between the two sets of samples (the K(s_X, s_Y) term) should be similarly scaled to the distances you'd get within just s_X (K(s_X, s_X)) and within just s_Y (K(s_Y, s_Y)). LMK if that's not really making sense.
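
A minimal NumPy sketch of that statistic (my own illustration with an ExpQuad/RBF kernel; the actual test uses the two_sample_test helper in the pymc test suite):

import numpy as np

def expquad(A, B, ell=1.0):
    # ExpQuad (RBF) kernel between sample sets of shape (n, d) and (m, d)
    sq_dists = np.sum(A**2, axis=1)[:, None] + np.sum(B**2, axis=1)[None, :] - 2 * A @ B.T
    return np.exp(-0.5 * sq_dists / ell**2)

def mmd_statistic(s_X, s_Y, ell=1.0):
    # near zero when both sample sets come from the same distribution
    return (
        expquad(s_X, s_X, ell).mean()
        + expquad(s_Y, s_Y, ell).mean()
        - 2 * expquad(s_X, s_Y, ell).mean()
    )

rng = np.random.default_rng(0)
s_X = rng.normal(size=(1000, 10))  # 1000 draws from a 10-dimensional distribution
s_Y = rng.normal(size=(1000, 10))  # same distribution, so the statistic should be ~0
print(mmd_statistic(s_X, s_Y))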

Do we force this using np.atleast_2d or something? Or check shapes?

No... but that's probably a good idea? Feel free to ignore for this PR!
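
If we did check, something like this sketch could work (not part of this PR; NumPy-only for illustration):

import numpy as np

def ensure_column_vector(X):
    # promote (n,) to (n, 1); the GP code expects 2-D column input
    X = np.asarray(X)
    if X.ndim == 1:
        X = X[:, None]
    if X.ndim != 2:
        raise ValueError(f"X must have shape (n, 1), got {X.shape}")
    return X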

@theorashid (Contributor, Author):

Thesis is handed in btw so I'll be back on this shortly

@bwengals (Contributor) commented Oct 6, 2023

Congratulations!! Take your time

@theorashid (Contributor, Author) commented Oct 9, 2023

Parameter checking is in place. After fixing a few type errors caused by forcing m to be a sequence so it's compatible with other GPs, I've run into a few issues. So if I understand what you're saying about MMD, I can just leave it all as it is and run

@pytest.mark.parametrize("cov_func,parameterization", [
    (pm.gp.cov.ExpQuad(1, ls=1), "centered"),
    (pm.gp.cov.ExpQuad(1, ls=1), "noncentered"),
    (pm.gp.cov.Periodic(1, period=1, ls=1), None),
])
test_prior/test_conditional

The tests are failing. I think there are a few things to address:

  1. test_prior is failing because of the maths: AssertionError: H0 was rejected, even though HSGP and GP priors should match. I'm not sure how to check this. Do I need to change something with the test, or shall I check visually whether I've done the maths right?
  2. test_conditional is failing because ValueError: Prior is not set, can't create a conditional. Call .prior(name, X) first. This refers to the lines setting self._beta and self._X_mean. I'm not sure what the logic is here, or how to adapt it for the periodic case. self._sqrt_psd does not need to exist, because all the relevant things are derived as in the "noncentered" case: (phi_cos, phi_sin), psd = self.prior_linearized(X - self._X_mean)

UPDATE: ignore this. I just needed to set self._parameterization for the periodic case. This is fixed.

All the other GP stuff in PyMC requires X to have the column vector shape with the [:, None], so I think it's def OK to expect that and not have to have it work for both shapes (n, ) and (n, 1)

  3. I looked into dimensionality for test_prior. samples1 has the shape (1000, 100), i.e. (nsamp, len(X1)). I had to edit calc_basis_periodic to change X[:, None] to X. I think it's robust, but keeping this here in case something breaks down the line.

Hopefully all straightforward fixes...

Update:

I have fixed the conditional, and the mmd test passes for that. I printed the mmd statistic and critical value for the prior_test, and we have mmd=0.04685569443160442 critical_value=0.003483518867031471, so quite a way off, but in the right direction. I tried increasing m but that didn't improve things.

Here are some samples:

[image: prior samples from the HSGP periodic approximation]

@bwengals (Contributor):

Looks like it's working nicely! Just gotta make the tests happy. I'll check out your branch today and see if I can help out. def gotta dig into it a bit

@bwengals (Contributor):

Hey, so I think I maybe found the issue. It looks like the period is very slightly off in the approximation. Try running this:

import arviz as az
import numpy as np
import pymc as pm

with pm.Model() as model:
    X1 = np.linspace(-5, 5, 100)[:, None]
    cov_func = pm.gp.cov.Periodic(1, ls=1, period=5)  # <---- changed period to 5 here
    # cov_func = pm.gp.cov.Matern52(1, ls=1)

    hsgp = pm.gp.HSGP(m=[200], c=2.0, parameterization=None, cov_func=cov_func)
    f1 = hsgp.prior("f1", X=X1)

    gp = pm.gp.Latent(cov_func=cov_func)
    f2 = gp.prior("f2", X=X1)

    idata = pm.sample_prior_predictive(samples=1000)

samples1 = az.extract(idata.prior["f1"])["f1"].values.T
samples2 = az.extract(idata.prior["f2"])["f2"].values.T

# also added this below; subtracting the mean of the samples made the mmd stat much happier,
# because the samples from periodic seem to often center far from zero
samples1 = samples1 - np.mean(samples1, axis=1)[:, None]
samples2 = samples2 - np.mean(samples2, axis=1)[:, None]

I changed the period to 5, and X1 goes from -5 to 5, so now there should be exactly two repetitions of the cycle in all the samples.

But if you look at the results of

samples1[:, 0] - samples1[:, -1]
# vs
samples2[:, 0] - samples2[:, -1]

the differences in samples2 are a bit bigger. I think the MMD stat is picking up on this. Maybe there is something very slightly off with the formula for the approx PSD? If not, maybe there is a bit more numerical error calculating the approx psd than for k(x, x') of Periodic.

If there's nothing wrong with your approx PSD calc, this may not be a big deal in practice because the differences are pretty tiny.

@theorashid (Contributor, Author) commented Oct 12, 2023

Okay, so the only way I know how to do this is to explain, step by step, the maths I followed, which also makes it easier for someone (you, sorry) to check. If you're reading this, I haven't found any obvious mistake in the maths to correct.

I'm following the paper appendix B. I also checked the code alongside the numpyro example.

In the paper, they state the expansion

$k(\tau) \approx \alpha \sum_{j=0}^{J} \tilde{q}_j^2 \cos(j \omega_0 \tau)$

Comparing pymc/gp and their notation: $\alpha$ is the variance (factored out in gp until modelling), $\omega_0$ is the angular frequency, which is related to the period as w0 = (2 * np.pi) / period, $j$ is effectively m (I have defined J = arange(0, m, 1) in the code to be consistent for m with other hsgp kernels), $\tau$ is calculated as differences in X, and the expansion coefficients $\tilde{q}_j^2$ are a function of $\ell$ / ell.

A few lines later, they give the expansion coefficients ($\mathrm{I}_j$ being the modified Bessel function of the first kind)

$\tilde{q}_j^2 = (2 - \delta_{j0}) \, \frac{\mathrm{I}_j(\ell^{-2})}{\exp(\ell^{-2})}$
which matches the code for the power_spectral_density()

def power_spectral_density(self, J: TensorLike) -> TensorVariable:
    a = 1 / pt.square(self.ls)
    c = pt.where(J > 0, 2, 1)  # the J = 0 term does not have the factor of 2
    # this is the right Bessel function: https://github.com/pymc-devs/pytensor/blob/main/pytensor/tensor/math.py#L1427-L1429
    # the numpyro example used tfp.math.bessel_ive, which is exponentially scaled, so they had to unscale it
    q2 = c * pt.iv(J, a) / pt.exp(a)

    return q2
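
As a sanity check (my own snippet, using scipy rather than anything in the PR), the truncated cosine series with these coefficients reproduces the exact periodic kernel $k(\tau) = \exp(-2 \sin^2(\omega_0 \tau / 2) / \ell^2)$ essentially to machine precision:

import numpy as np
from scipy.special import iv  # modified Bessel function of the first kind

ell, period = 1.0, 1.0
a = 1 / ell**2
w0 = 2 * np.pi / period
tau = np.linspace(0, 1, 200)

J = np.arange(0, 50)
c = np.where(J > 0, 2, 1)
q2 = c * iv(J, a) / np.exp(a)  # the expansion coefficients above
k_series = np.cos(np.outer(tau, J * w0)) @ q2

# exact periodic kernel, using 2 * sin^2(x / 2) = 1 - cos(x)
k_exact = np.exp(-2 * np.sin(w0 * tau / 2) ** 2 / ell**2)

print(np.max(np.abs(k_series - k_exact)))  # ~1e-16, i.e. machine precision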

Some more lines and they give the expansion

$f(x) \approx \sqrt{\alpha} \left[ \sum_{j=0}^{J} \tilde{q}_j \beta_j^{(c)} \cos(j \omega_0 x) + \sum_{j=1}^{J} \tilde{q}_j \beta_j^{(s)} \sin(j \omega_0 x) \right], \quad \beta^{(c)}, \beta^{(s)} \sim \text{Normal}(0, 1)$

Steps:

  1. Ingredient 1: The cosine and sine basis terms, which do not depend on $\ell$
def calc_basis_periodic(
    Xs: TensorLike,
    period: TensorLike,
    m: Sequence[int],
    tl: ModuleType = np,
):
    m0 = m[0]  # for compatibility with other kernels, m must be a sequence
    w0 = (2 * np.pi) / period  # angular frequency defining the periodicity
    m1 = tl.tile(w0 * Xs, m0)
    m2 = tl.diag(tl.arange(0, m0, 1))
    mw0x = m1 @ m2
    # the above lines basically multiply a column of w0 * X by an increasing integer until we get to m0.
    # It is the same as doing np.outer(), but I think I thought it was safer shape-wise at the time
    phi_cos = tl.cos(mw0x)
    phi_sin = tl.sin(mw0x)
    return phi_cos, phi_sin
  2. Ingredient 2: Getting the expansion coefficients and the basis

    # in `prior_linearized()`
    phi_cos, phi_sin = calc_basis_periodic(Xs, self.cov_func.period, self._m, tl=pt)
    J = pt.arange(0, self._m[0], 1)
    psd = self.cov_func.power_spectral_density(J)

    ...

    # in `prior()`
    (phi_cos, phi_sin), psd = self.prior_linearized(X - self._X_mean)

    m0 = self._m[0]
  3. Ingredient 3: The $\beta_j$ terms ($2J + 1$ of them), taken from a $\text{Normal}(0, 1)$

    # in `prior()`
    self._beta = pm.Normal(f"{name}_hsgp_coeffs_", size=(m0 * 2 - 1))
  4. Adding it all up

    # in `prior()`
    # The first eigenfunction for the sine component is zero and so does not contribute to the approximation.
    f = (
        self.mean_func(X)
        + phi_cos @ (psd * self._beta[:m0])
        + phi_sin[..., 1:] @ (psd[1:] * self._beta[m0:])
    )

I'm a bit concerned the test isn't passing, for 2 reasons:

  1. The MMD thing works fine for the conditional test
  2. The paper says the approximation is good for $J \geq \frac{3.72}{\ell}$, and with m=200 we're well above that.

What do you think?

@bwengals (Contributor) commented Oct 12, 2023

Thanks so so much for laying everything out, makes things super easy for me, though time consuming for you. But thanks! Reading through, everything you're doing looks 100% right.

The MMD thing works fine for the conditional test

I think that makes sense here: it's checking whether samples from the conditional, having seen no data, are from the same distribution as samples from the prior. Both are from the HSGP, so it's self-consistent.

One thing (and this is a reach) is zeroing out the first eigenfunction on the sine side. I couldn't find a note in the paper, might have missed it. Just in the numpyro implementation where they say

The first eigenfunction for the sine component
is zero, so the first parameter wouldn't contribute to the approximation.
We set it to zero to identify the model and avoid divergences.

Maybe it's not zero but it's just super tiny...?
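
A quick check (my own snippet) suggests it really is exactly zero rather than tiny, since the j = 0 column is sin(0 * w0 * x):

import numpy as np

x = np.linspace(-5, 5, 11)
print(np.sin(0 * x))  # identically zero, so that basis column contributes nothing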

I think the best path forward is just to make a note of it and call it good. Agree that you implemented it right, even if MMD is a little skeptical. And actually, when I mean-subtract the samples before passing them to the two-sample test, it passes quite a bit for Periodic. I was going to put it as one of those suggestion things (never mind, couldn't figure out how to do that). For the current random seed, 10 tests pass on my machine when I do this:

        # in def test_prior

        samples1 = az.extract(idata.prior["f1"])["f1"].values.T
        samples2 = az.extract(idata.prior["f2"])["f2"].values.T
        
        # added these two lines
        samples1 = samples1 - np.mean(samples1, axis=1)[:, None]
        samples2 = samples2 - np.mean(samples2, axis=1)[:, None]

        h0, mmd, critical_value, reject = two_sample_test(
            samples1, samples2, n_sims=500, alpha=0.01
        )

@theorashid (Contributor, Author) commented Oct 13, 2023

though time consuming for you

Yeah lol I accidentally left the page whilst writing it too, so had to write it all again. Bit of a 'mare

I'll get to this at some point this weekend, or more likely next week. A couple of questions:

One thing (and this is a reach) is zeroing out the first eigenfunction on the sine side. I couldn't find a note in the paper, might have missed it. Just in the numpyro implementation where they say

IIUC, in the paper it's because the sum for cosine terms goes from 0 and the sum for sine terms goes from 1, which I've implemented with phi_sin[..., 1:]. Is that what you meant or something else?

I think best path forward is just to make a note of it and call it good.

Do you mean make a note on the tests explaining what we tried? Will it reduce coverage? I don't think so, because the code is called in test_conditional()

I'll tidy things up, freshen up the docs, fix the mypy issue, and fix the fresh conflicts. Then I'll take it out of draft status so it's ready for review. Nearly there...

@bwengals (Contributor):

Yeah lol I accidentally left the page whilst writing it too, so had to write it all again. Bit of a 'mare

Oh god, well thanks for redoing it! And no rush at all ofc.

Just make a note that the period might be ever so slightly off, either in the class or in the test; maybe note what you tried and where you left off. If it actually causes someone problems, which I really doubt, then at least they have a lead. Looks to me that your implementation matches the paper and numpyro, but I do think the tests actually did their job here and flagged this.

For me, all the tests pass now when I add the bit subtracting the mean from the samples.

Again, great PR, looking forward to using this!

@theorashid (Contributor, Author) commented Oct 19, 2023

Yeah this happened for both numpyro and blackjax samplers. AttributeError: partially initialized module 'jax' has no attribute 'version' (most likely due to a circular import)?

Yeah I got this too. I didn't have any reference to tfp though. I don't know enough about pytensor to know whether they call tfp internally there. @ricardoV94 ?

with pm.Model():
    hsgp = pm.gp.HSGP(m=[200], cov_func=pm.gp.cov.Periodic(1, period=1, ls=1))
    f1 = hsgp.prior("f1", X=X1)

    idata = pm.sample_prior_predictive(samples=1000, random_seed=1)
    pm.sample(nuts_sampler='numpyro')



    pm.sample(nuts_sampler='numpyro')
  File "/Users/theorashid/Library/Mobile Documents/com~apple~CloudDocs/Documents/dev/pymc/pymc/sampling/mcmc.py", line 660, in sample
  File "/Users/theorashid/Library/Mobile Documents/com~apple~CloudDocs/Documents/dev/pymc/pymc/sampling/mcmc.py", line 313, in _sample_external_nuts
    )
  File "/Users/theorashid/Library/Mobile Documents/com~apple~CloudDocs/Documents/dev/pymc/pymc/sampling/jax.py", line 23, in <module>
    import jax
  File "/opt/homebrew/Caskroom/miniforge/base/envs/pymc-dev/lib/python3.11/site-packages/jax/__init__.py", line 39, in <module>
    from jax import config as _config_module
  File "/opt/homebrew/Caskroom/miniforge/base/envs/pymc-dev/lib/python3.11/site-packages/jax/config.py", line 17, in <module>
    from jax._src.config import config  # noqa: F401
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/miniforge/base/envs/pymc-dev/lib/python3.11/site-packages/jax/_src/config.py", line 27, in <module>
    from jax._src import lib
  File "/opt/homebrew/Caskroom/miniforge/base/envs/pymc-dev/lib/python3.11/site-packages/jax/_src/lib/__init__.py", line 75, in <module>
    jax_version=jax.version.__version__,
                ^^^^^^^^^^^
AttributeError: partially initialized module 'jax' has no attribute 'version' (most likely due to a circular import)

@ricardoV94 (Member):

@theorashid can you provide a reproducible snippet?

@theorashid (Contributor, Author) commented Oct 19, 2023

I tried with some different models to get a minimum example and they threw the error that I need to install numpyro (I was told first to install jax, then jaxlib, but not numpyro). I installed numpyro and it all works now – so @bwengals see if that fixes it for you, unless we had different errors.

The model above should probably throw an error saying to install numpyro rather than the circular jax import error, but that's just being picky. All seems to work for me locally. Again, not really sure where tfp comes into it

@bwengals (Contributor):

Still no luck; looks like it's a versioning issue between tensorflow, tfp, and jax. What versions of these things do you have?

@theorashid (Contributor, Author) commented Oct 20, 2023

No tfp in my env

(pymc-dev) dev/pymc [hsgp-periodic●] » pip freeze | grep -e numpyro -e jax -e jaxlib -e pytensor -e tensorflow -e pymc
jax==0.4.18
jaxlib==0.4.18
numpyro==0.13.2

EDIT: I just installed tensorflow-probability==0.22.0 into my env and it worked fine

@bwengals (Contributor):

Awesome, thanks @theorashid. Got my env sorted and it looks like the numpyro sampler does a nice job on it.

Review thread on the parametrize list in the tests:

[
    (pm.gp.cov.ExpQuad(1, ls=1), "centered"),
    (pm.gp.cov.ExpQuad(1, ls=1), "noncentered"),
    # (pm.gp.cov.Periodic(1, period=1, ls=1), None),
]
@bwengals (Contributor):

I get passing tests when I uncomment this line, then zero the mean on samples1 and samples2

        samples1 = az.extract(idata.prior["f1"])["f1"].values.T
        samples2 = az.extract(idata.prior["f2"])["f2"].values.T

        samples1 = samples1 - np.mean(samples1, axis=1)[:, None]
        samples2 = samples2 - np.mean(samples2, axis=1)[:, None]

Does that work on your end?

@theorashid (Contributor, Author) commented Oct 20, 2023:

Yeah, that passes for me. Is the logic sound for doing this? As in, is there a mathematical reason that we shouldn't do this? If not, what comment can I put in to explain why we do this, if I ever come back to this?

Just double-checking, but is it completely unrelated that we take the mean of the domain with self.prior_linearized(X - self._X_mean)?

@bwengals (Contributor):

all completely unrelated that we take the mean of the domain with self.prior_linearized(X - self._X_mean)

Yep unrelated to that. Here's some explanation on the MMD statistic, I don't think I linked that already.

You can use this MMD statistic thing to tell if two sets of samples come from the same distribution. It's a kernel test, can be applied to samples from any multivariate situation, doesn't have to be applied to samples drawn from other kernel methods (sorry for the inception-ness we have here in this case). So here, sets of samples from the periodic HSGP prior and samples from the periodic GP prior should look like they came from the same distribution. The kernel the MMD stat uses (expquad) is unrelated to the kernels from the GPs, it's its own thing.

Actually I just wrote a long reply and deleted it cuz I found this. Try running this to see why the MMD statistic started passing after removing the mean:

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pymc as pm

x = np.linspace(0, 10, 100)  # x wasn't shown in the original comment; any 1-D grid works

with pm.Model() as model:

    eta = pm.Exponential("eta", lam=1.0)
    ell = pm.InverseGamma("ell", mu=2, sigma=1)
    per = 1.0
    cov_func = pm.gp.cov.Periodic(1, ls=ell, period=per)

    hsgp = pm.gp.HSGP(m=[500], cov_func=cov_func)
    f_hsgp = hsgp.prior("f_hsgp", X=x[:, None])

    gp = pm.gp.Latent(cov_func=cov_func)
    f_gp = gp.prior("f_gp", X=x[:, None])

    idata = pm.sample_prior_predictive(1000)

fig, axs = plt.subplots(1, 2, figsize=(15, 5))
f_hsgp = az.extract(idata.prior, var_names="f_hsgp")
axs[0].plot(f_hsgp.data)
axs[0].set_ylim([-3, 3])

f_gp = az.extract(idata.prior, var_names="f_gp")
axs[1].plot(f_gp.data)
axs[1].set_ylim([-3, 3])

The HSGP prior samples don't have as much level shift up and down, or maybe have a lower overall variance. Maybe something is still off in the implementation, or further down in some special function like iv? I also wonder now whether the Stan or numpyro implementations show this behavior if you take draws from their priors. I'll probably leave this over the weekend and come back to it next week, but I can help you run this stuff down.

I think if Stan/numpyro show the same behavior, we can just mark the test xfail with a (fairly long probably) explanation. I don't think the mean subtraction thing I was doing is legit anymore... Maybe worth reaching out to the paper authors in that case?

@theorashid (Contributor, Author):

Running that code, I got the pytensor warning UserWarning: Optimization Warning: The Op iv does not provide a C implementation. As well as being potentially slow, this also disables loop fusion. Could that be related? I don't know if there's a way to run sample_prior_predictive using jax. Given the coefficients are c * pt.iv(J, a) / pt.exp(a), maybe all the samples are out by pt.iv(J, a)? Wild guess though.

I'll write you a snippet for the numpyro now, Stan would take a bit longer for me

@theorashid (Contributor, Author) commented Oct 22, 2023:

Here's the numpyro model:

import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import Predictive

def modified_bessel_first_kind(v, z):
    v = jnp.asarray(v, dtype=float)
    return jnp.exp(jnp.abs(z)) * tfp.math.bessel_ive(v, z)

def diag_spectral_density_periodic(alpha, length, M):
    a = length ** (-2)
    J = jnp.arange(0, M)
    c = jnp.where(J > 0, 2, 1)
    q2 = (c * alpha**2 / jnp.exp(a)) * modified_bessel_first_kind(J, a)
    return q2

def eigenfunctions_periodic(x, w0, M):
    m1 = jnp.tile(w0 * x[:, None], M)
    m2 = jnp.diag(jnp.arange(M, dtype=jnp.float32))
    mw0x = m1 @ m2
    cosines = jnp.cos(mw0x)
    sines = jnp.sin(mw0x)
    return cosines, sines

def model(X, y=None):
    per = 1.0
    w0 = (jnp.pi * 2 / per)
    M = 500

    # alpha = numpyro.sample("alpha", dist.HalfNormal(1.0)) # this is the variance, setting to 1
    alpha = 1.0
    # numpyro does not have mu, sigma parametrisation, so change the pymc prior to compare
    ell = numpyro.sample("ell", dist.InverseGamma(2, 1))

    q2 = diag_spectral_density_periodic(alpha, ell, M)
    cosines, sines = eigenfunctions_periodic(X, w0, M)

    with numpyro.plate("cos_basis", M):
        beta_cos = numpyro.sample("beta_cos", dist.Normal(0, 1))
    
    with numpyro.plate("sin_basis", M - 1):
        beta_sin = numpyro.sample("beta_sin", dist.Normal(0, 1))

    zero = jnp.array([0.0])
    beta_sin = jnp.concatenate((zero, beta_sin))

    f = numpyro.deterministic("f_hsgp", cosines @ (q2 * beta_cos) + sines @ (q2 * beta_sin))

    # nominal likelihood, ignored for prior prediction
    sigma = numpyro.sample("sigma", dist.HalfNormal(1.0))
    with numpyro.plate("obs", X.shape[0]):
        numpyro.sample("y", dist.Normal(f, sigma), obs=y)

prior_predictive = Predictive(model, num_samples=1000)
prior_predictions = prior_predictive(jax.random.PRNGKey(0), X=jnp.squeeze(X1))

Visually, this has the same variance as our pymc implementation

idata_numpyro = az.from_numpyro(prior=prior_predictions)

fig, axs = plt.subplots(1, 2, figsize=(15, 5))
f_hsgp_numpyro = az.extract(idata_numpyro.prior, var_names="f_hsgp")
axs[0].plot(f_hsgp_numpyro.data);
axs[0].set_ylim([-3, 3])

# looks the same as our pymc hsgp
f_hsgp = az.extract(idata.prior, var_names="f_hsgp")
axs[1].plot(f_hsgp.data);
axs[1].set_ylim([-3, 3]);

This rules out my wild guess above that iv isn't being evaluated properly by PyTensor. So either the variance of the base periodic GP is wrong, or that of the HSGP. I can't see where I've got the maths wrong. What's controlling the up-and-down shift? The fact that centring the samples makes the test pass suggests the GP variance of the samples (i.e. alpha in the paper) is correct.

Aside: any idea why the numpyro/Stan models would define the period using the standard deviation of the domain: period = 365.25, w0 = x.std() * (jnp.pi * 2 / period) (also here in Stan)? Just checked, and this does not affect our problem because we're looking at the overall variance.

@ricardoV94 (Member):

UserWarning: Optimization Warning: The Op iv does not provide a C implementation. As well as being potentially slow, this a ...

That's a benign warning, you can ignore it. It shouldn't affect results (only speed).

@bwengals (Contributor):

I think I found it finally! So, in your numpyro code I changed alpha from 1 to sqrt(2), and now the samples match those from pm.gp.cov.Periodic. I think they must have missed a factor of two somewhere in the paper..? Another way to see it is to flatten the prior samples together into a pile and plot something like:

# gp samples flattened should match standard normal draws below
plt.hist(np.random.randn(20000), 100, alpha=0.5, density=True, color="k");

# f_pymc samples from pm.gp.cov.Periodic
plt.hist(f_pymc.data.flatten(), 100, alpha=0.5, density=True);

# f_hsgp samples from your numpyro code
plt.hist(f_hsgp.data.flatten(), 100, alpha=0.5, density=True);

Aside: any idea why the numpyro/Stan models would define the period using the standard deviation of the domain: period = 365.25, w0 = x.std() * (jnp.pi * 2 / period) (also here in Stan)? Just checked, and this does not affect our problem because we're looking at the overall variance.

No... not sure. The period you have in your code and your numpyro code is correct though.

@theorashid (Contributor, Author) commented Oct 26, 2023:

Interesting, are we sure on this? Like, is it a hack, or do we back the maths of the factor of 2? Are we sure it isn't a coincidence? If so, it's an easy fix:

f = (
    self.mean_func(X)
    + 2 * (
        phi_cos @ (psd * self._beta[:m0])  # type: ignore
        + phi_sin[..., 1:] @ (psd[1:] * self._beta[m0:])  # type: ignore
    )
)

If we're happy with it, I'll implement it and put a note, and probably shoot something off to the paper author. What do you think?

UPDATE: test does not pass after doing this, even though the histograms from the flattened samples visually look okay.

Another thing:

Is the alpha in the paper a variance or a sigma? If we compare it to the kernel cookbook, I think alpha is sigma**2.

I think the numpyro code has an error. They have the line q2 = (c * alpha**2 / jnp.exp(a)) * modified_bessel_first_kind(J, a). Instead, I follow the paper more closely, and do not include alpha in the q terms (just q2 = c * pt.iv(J, a) / pt.exp(a)). If we look at equation B.8 of the paper, it actually has alpha ** 0.5 as a coefficient, i.e. sigma. So I think the numpyro code should say alpha ** 0.5, or just alpha if they're looking at a standard deviation prior.

For reference, the corresponding two lines of Stan code are this and this – it looks like exp(log(alpha) + ...), so their sigma_f2 is actually a variance too.

I'm not sure where this leaves us, because we specify cov_funcs like sigma ** 2 * pm.gp.cov.Periodic(1, ls=ell, period=per). It would be annoying to have to explain that specifically for the periodic HSGP case, we should do sigma * pm.gp.cov.Periodic(1, ls=ell, period=per). What do you think?
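
As a side note on the scaling: with $f(x) = \sqrt{\alpha} \sum_{j} \tilde{q}_j \left( \beta_j^{(c)} \cos(j \omega_0 x) + \beta_j^{(s)} \sin(j \omega_0 x) \right)$ and $\beta \sim \text{Normal}(0, 1)$, we get $\text{Cov}[f(x), f(x')] = \alpha \sum_j \tilde{q}_j^2 \cos(j \omega_0 (x - x'))$, so $\alpha$ multiplies the covariance (it is a variance) while the draws themselves only scale with $\sqrt{\alpha}$.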

@bwengals (Contributor):

UPDATE: test does not pass after doing this

Dang I didn't check that. Will take another look again tomorrow and get back to you. I'm wondering if the test is in "no way" territory, or if it's saying "kind of close". Was hoping we got lucky with the factor of 2.

It would be annoying to have to explain that specifically for the periodic HSGP case, we should do sigma * pm.gp.cov.Periodic(1, ls=ell, period=per)

For sure, I would hope that HSGP periodic can be a drop in replacement for regular periodic, without having to rethink your priors.

@bwengals (Contributor):

One final thing... so it looks like we've been running this assuming the user passes in

cov_func = pm.gp.cov.Periodic(...)

But really they're probably going to pass in

eta = pm.HalfNormal(...)
cov_func = eta**2 * pm.gp.cov.Periodic(...)

In a few places in your code you've got to check if the incoming cov_func is periodic or not. Those checks should be replaced by something a bit more sophisticated, something like this function should do it:

def is_cov_func_periodic(cov_func):
    if isinstance(cov_func, pm.gp.cov.Periodic):
        return True

    elif isinstance(cov_func, pm.gp.cov.Prod) and len(cov_func._factor_list) == 2:
        # check that one of the factors is a Periodic cov_func
        # and that the other is a scaling factor, so TensorLike
        factor_set = set(cov_func._factor_list)
        total = 0
        for factor in factor_set:
            if isinstance(factor, pm.gp.cov.Periodic):
                total += 1

            if isinstance(factor, pt.TensorLike):
                total += 1

        if total == 2:
            return True

    return False

In the cov library, if you multiply a covariance function by a scalar it becomes a Prod covariance with a _factor_list attribute that keeps track of what objects were multiplied. This function checks that it's length 2, and that one element is pm.gp.cov.Periodic and the other is a TensorLike that should be a scalar. I'm using a for loop to check each of the two elements, even though that's probably not very clever...
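
For example (a quick illustration of the Prod / _factor_list behavior described above; my own snippet):

import pymc as pm

cov_func = 2.0 * pm.gp.cov.Periodic(1, period=1.0, ls=1.0)
print(type(cov_func))         # <class 'pymc.gp.cov.Prod'>
print(cov_func._factor_list)  # contains the scalar 2.0 and the Periodic instance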

@theorashid (Contributor, Author) commented Oct 26, 2023

So TensorLike = Union[np.ndarray, TensorVariable], meaning isinstance(1.1, TensorLike) returns False. But a scalar multiplier would be a valid covariance function, I think?

Also, type("a" * cov_func) is pymc.gp.cov.Prod, which is of course trash.

Also, we could conceivably have len(cov_func._factor_list) > 2, e.g. some sort of horseshoe prior thrown in, or even a few more scalars.

How about I just check whether pm.gp.cov.Periodic is anywhere in cov_func._factor_list? And I'll assume either the gp module or pytensor gets angry enough with "a" * cov_func that I don't need to implement any logic. That excludes the case where someone might do cov_func1 * cov_func2 though... but again, might there be an error somewhere else?

Aside: is sampling from hsgp periodic really slow for you? Or is that just my machine? It's nearly instant for pm.Latent. Actually, I think it's to do with iv not having a C implementation, and pm.sample_prior_predictive needs the C backend. Sorry it's been a while.

@bwengals (Contributor):

So TensorLike = Union[np.ndarray, TensorVariable], meaning isinstance(1.1, TensorLike) returns False. [...] How about I just check whether pm.gp.cov.Periodic is anywhere in cov_func._factor_list?

True. All good points... Another thought is changing the API a bit and asking the user to pass in the scale parameter to HSGP just for the periodic case? Maybe to do this you'd have to first check if anything in _factor_list is Periodic, and if so throw an error telling the user to pass in a plain pm.gp.cov.Periodic and the scale parameter. Would be nice ofc not to do this but HSGP periodic is so special-casey already.

Otherwise yeah, the way it's designed there really isn't a super clean way to check and grab the scale parameter off of pm.gp.cov.Periodic. With the regular HSGP, the scale gets worked out by Covariance._merge_factors_psd (which is pretty ugly on its own).

Aside: is sampling from hsgp periodic really slow for you? Or is that just my machine. It's nearly instant for pm.Latent. Actually I think it's to do with iv not having a C implementation, and pm.sample_prior_predictive needs C. Sorry it's been a while

Yeah, pretty slow with the base pymc sampler. I figured it was the same thing with iv. It sped up a ton with the numpyro and blackjax samplers, since Jax can call iv via tensorflow. Also tried nutpie without much luck, but I think the same issue with no C implementation is present there.

@bwengals (Contributor) left a review:

Looking good, just some minor stuff. Where'd you land on the scale parameter / alpha? I might have missed it in the code.

@theorashid (Contributor, Author) commented Nov 22, 2023

Took your advice and simplified the HSGPPeriodic case so users only have to pass an integer for m rather than a sequence.

I have added a scale parameter. The user has to pass cov_func as Periodic only (no eta ** 2 * Periodic()). The rest follows the maths of eq. B.8 in the paper. The scale parameter is the square root of the variance, and it has been incorporated by multiplying all of the coefficients $\tilde{q}_j$: psd = self.scale * self.cov_func.power_spectral_density_approx(J).

[screenshot: equation B.8 from the paper]
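
For reference, usage of the resulting class looks something like this (a sketch based on this thread: an integer m, a separate scale, and a plain Periodic cov_func):

import numpy as np
import pymc as pm

X = np.linspace(0, 10, 100)[:, None]
with pm.Model():
    ell = pm.InverseGamma("ell", mu=2, sigma=1)
    scale = pm.HalfNormal("scale", 1.0)
    cov_func = pm.gp.cov.Periodic(1, period=1.0, ls=ell)

    # no eta**2 * Periodic() here; the scale parameter is passed separately
    gp = pm.gp.HSGPPeriodic(m=200, scale=scale, cov_func=cov_func)
    f = gp.prior("f", X=X)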

@bwengals (Contributor) approved these changes:

This was a lot of work, thank you for sticking with it!

@bwengals (Contributor) commented Nov 23, 2023

Red x from unrelated failing docs build only (looks related to myst). Might be fixed @theorashid by merging main into here and updating deps, if something changed with myst. If that's not the issue, is this safe to merge @ricardoV94?

Edit: crossed out bc no

@bwengals merged commit 01ddcb8 into pymc-devs:main on Dec 10, 2023 (21 of 22 checks passed).
@ricardoV94 added the enhancements and GP (Gaussian Process) labels and changed the title from "HSGP periodic" to "Implement periodic kernel for HSGP" on Dec 10, 2023.