This repository has been archived by the owner on May 11, 2023. It is now read-only.

feat: Dynamic data-set sizes #22

Open
joeryjoery opened this issue Nov 21, 2022 · 0 comments
Labels
enhancement New feature or request

joeryjoery commented Nov 21, 2022

Feature Request

For some applications of GPs, such as Bayesian optimization, the dataset grows dynamically over time. Unfortunately, passing arrays of varying size to JAX jit-compiled functions causes the computation to be recompiled for every new buffer size. As a result, the computation takes much longer than necessary.
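To make the recompilation visible without any GP code, here is a minimal sketch in plain JAX. The function `total` and the counter `trace_count` are illustrative names, not part of gpjax. The counter is a Python side effect, so it only increments when JAX traces the function for a new input shape, not when a cached compilation is reused.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # incremented only while JAX traces the function


@jax.jit
def total(x):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time only
    return jnp.sum(x ** 2)


# A growing input: every new shape triggers a fresh trace and compile.
for n in (1, 2, 3):
    total(jnp.ones(n))
traces_after_growing = trace_count

# The same shapes again: the cached compilations are reused.
for n in (1, 2, 3):
    total(jnp.ones(n))
traces_after_repeat = trace_count

print(traces_after_growing, traces_after_repeat)
```

Both counters should be equal after the second loop, since no shape in it is new.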

In my own code I was able to work around the recompilation by using a fixed-size buffer and modifying the Gaussian process logic with a dynamic mask that treats all data at index i > t as independent of the data at index j <= t in the kernel computation. One downside is of course that every iteration t = 1, ..., n incurs the time and memory cost of the full buffer size n. For most applications, however, the speed-up provided by jit makes this overhead negligible.
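A minimal sketch of what such a masked computation could look like, written in plain JAX rather than against the gpjax API. This is not the code referred to above: the helper names (`rbf`, `posterior_mean`, `masked_posterior_mean`), the RBF kernel, the noise level, and the zero-based convention that slots `0 .. t-1` hold observed data are all assumptions made for illustration. Unobserved slots get an identity block in the Gram matrix and zeros in the cross-covariance and targets, so they drop out of the posterior mean.

```python
import jax
import jax.numpy as jnp


def rbf(a, b, lengthscale=0.5):
    # Pairwise RBF kernel between rows of a (n, d) and b (m, d).
    sq_dist = jnp.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-0.5 * sq_dist / lengthscale**2)


def posterior_mean(x_test, x_train, y_train, noise=1e-2):
    # Reference: ordinary GP posterior mean on a dataset of exact size.
    K = rbf(x_train, x_train) + noise * jnp.eye(x_train.shape[0])
    return rbf(x_test, x_train) @ jnp.linalg.solve(K, y_train)


@jax.jit
def masked_posterior_mean(x_test, x_buf, y_buf, t, noise=1e-2):
    # Fixed-size buffer; only the first t slots count as observed.
    n = x_buf.shape[0]
    valid = jnp.arange(n) < t
    both = valid[:, None] & valid[None, :]
    # Unobserved slots: identity rows/columns, i.e. independent of the rest.
    K = jnp.where(both, rbf(x_buf, x_buf), jnp.eye(n)) + noise * jnp.eye(n)
    k_star = jnp.where(valid[None, :], rbf(x_test, x_buf), 0.0)
    y = jnp.where(valid, y_buf, 0.0)
    return k_star @ jnp.linalg.solve(K, y)


x_buf = jnp.linspace(-1, 1, 5)[:, None]
y_buf = jnp.sin(x_buf[:, 0])
x_test = jnp.linspace(-2, 2, 100)[:, None]

# Compare against the unmasked computation on the truncated data.
max_abs_diffs = [
    float(jnp.max(jnp.abs(
        masked_posterior_mean(x_test, x_buf, y_buf, t)
        - posterior_mean(x_test, x_buf[:t], y_buf[:t])
    )))
    for t in range(1, 6)
]
print(max_abs_diffs)
```

Because `t` enters as a traced scalar and the buffer shape never changes, the jitted function is traced once for all values of `t`. The predictive variance would need the same masking and is left out of this sketch.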

I am not sure whether a solution already exists within gpjax as I'm still relatively new to this cool library :).

Describe Preferred Solution

I believe something like this can be implemented as follows, though I haven't tried it yet.

  1. Subclass gpx.Dataset as gpx.OnlineDataset(gpx.Dataset), adding a new integer time_step variable and requiring the exact shape of the data buffer at initialization.
  2. Add a method that writes new data into the buffer through jax.ops (in current JAX releases, the indexed-update syntax x.at[idx].set(value), which replaced jax.ops.index_update).
  3. Make a DynamicKernel class that wraps the standard kernel computation K along the lines of K(a, b, a_idx, b_idx, t), returning K(a, b) if both a_idx <= t and b_idx <= t, and int(a_idx == b_idx) otherwise.
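The three steps above could be sketched roughly as follows. This is a standalone illustration in plain JAX, not a gpjax implementation: `OnlineBuffer`, `init_buffer`, `add`, `make_dynamic_kernel`, `rbf_point`, and `gram` are hypothetical names, and the buffer is a NamedTuple rather than a real gpx.Dataset subclass. The sketch assumes that the kernel condition means both indices are already observed, and it uses zero-based slots, so a slot counts as observed when its index is smaller than `time_step`.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class OnlineBuffer(NamedTuple):
    # Hypothetical stand-in for the proposed OnlineDataset.
    X: jnp.ndarray          # (capacity, d) inputs
    y: jnp.ndarray          # (capacity, 1) targets
    time_step: jnp.ndarray  # scalar int32: number of observed points


def init_buffer(capacity, dim):
    return OnlineBuffer(
        X=jnp.zeros((capacity, dim)),
        y=jnp.zeros((capacity, 1)),
        time_step=jnp.asarray(0, dtype=jnp.int32),
    )


@jax.jit
def add(buffer, x_new, y_new):
    # Write into the next free slot; no capacity check in this sketch.
    t = buffer.time_step
    return OnlineBuffer(
        X=buffer.X.at[t].set(x_new),
        y=buffer.y.at[t].set(y_new),
        time_step=t + 1,
    )


def rbf_point(a, b, lengthscale=0.5):
    # RBF kernel between two single points.
    return jnp.exp(-0.5 * jnp.sum((a - b) ** 2) / lengthscale**2)


def make_dynamic_kernel(kernel):
    def dynamic(a, b, a_idx, b_idx, t):
        # Use the wrapped kernel only when both points are observed;
        # otherwise fall back to an identity entry.
        observed = (a_idx < t) & (b_idx < t)
        return jnp.where(observed, kernel(a, b),
                         (a_idx == b_idx).astype(a.dtype))
    return dynamic


@jax.jit
def gram(buffer):
    idx = jnp.arange(buffer.X.shape[0])
    k = make_dynamic_kernel(rbf_point)
    row = jax.vmap(k, in_axes=(None, 0, None, 0, None))     # over b
    full = jax.vmap(row, in_axes=(0, None, 0, None, None))  # over a
    return full(buffer.X, buffer.X, idx, idx, buffer.time_step)


buf = init_buffer(capacity=4, dim=1)
buf = add(buf, jnp.array([0.0]), jnp.array([0.0]))
buf = add(buf, jnp.array([0.5]), jnp.array([0.4794]))
K = gram(buf)  # top-left 2x2 block is the RBF Gram matrix, the rest identity
```

Since the buffer shapes are fixed at initialization, both `add` and `gram` are compiled once and reused as `time_step` grows.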

Describe Alternatives

NA

Related Code

Example of the jit recompilation, based on the Regression notebook in the documentation:

import gpjax as gpx
from jax import jit, random
from jax import numpy as jnp


n = 5

x = jnp.linspace(-1, 1, n)[..., None]
y = jnp.sin(x)  # x already has shape (n, 1), so y is (n, 1) without another axis

xtest = jnp.linspace(-2, 2, 100)[..., None]


@jit
def gp_predict(xs, x_train, y_train):
    posterior = gpx.Prior(kernel=gpx.RBF()) * gpx.Gaussian(num_datapoints=len(x_train))
    
    params, *_ = gpx.initialise(
        posterior, random.PRNGKey(0), kernel={"lengthscale": jnp.array([0.5])}
    ).unpack()
    
    post_predictive = posterior(params, gpx.Dataset(X=x_train, y=y_train))
    out_dist = post_predictive(xs)
    
    return out_dist.mean(), out_dist.stddev()


# First call - compile
print('compile')
for i in range(len(x)):
    %time gp_predict(xtest, x[:i+1], y[:i+1])
print()


# Second call - use cached
print('jitted')    
for i in range(len(x)):
    %time gp_predict(xtest, x[:i+1], y[:i+1])

# Output
compile
CPU times: user 519 ms, sys: 1.64 ms, total: 521 ms
Wall time: 293 ms
CPU times: user 1.06 s, sys: 0 ns, total: 1.06 s
Wall time: 316 ms
CPU times: user 956 ms, sys: 17.9 ms, total: 974 ms
Wall time: 219 ms

jitted
CPU times: user 3.66 ms, sys: 443 µs, total: 4.11 ms
Wall time: 2.46 ms
CPU times: user 2.89 ms, sys: 348 µs, total: 3.23 ms
Wall time: 1.84 ms
CPU times: user 894 µs, sys: 0 ns, total: 894 µs
Wall time: 568 µs

Additional Context

Related issue in the JAX repository: jax-ml/jax#2521

If the feature request is approved, would you be willing to submit a PR?

When I have time available, I can try to port my solution to the gpjax API, though I am still quite new to the library.

@joeryjoery joeryjoery added the enhancement New feature or request label Nov 21, 2022
@thomaspinder thomaspinder transferred this issue from JaxGaussianProcesses/GPJax Mar 19, 2023