Feature Request
For some applications with GPs, like Bayesian Optimization, the dataset grows dynamically over time. Unfortunately, dynamic array sizes with JAX jit-compiled functions cause the computation to be re-compiled for every different buffer size. This means that the computation takes much longer than it needs to.
In my own code I was able to work around the recompilation caused by dynamic shapes by using a fixed-size buffer and modifying the Gaussian process logic with a dynamic mask that treats all data at index i > t as independent of the data at indices j <= t in the kernel computation. One downside is of course that every iteration t = 1, ..., n incurs time and memory complexity proportional to n. For most applications, however, the speed-up provided by jit makes this completely negligible.
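For concreteness, here is a minimal sketch of that workaround; the kernel, noise handling, and buffer size are my own simplifications rather than gpjax code. The buffer shape never changes, so jit compiles once, and the padded entries (index >= t) contribute nothing to the objective:

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve

def rbf(x1, x2, lengthscale=1.0):
    # Toy RBF kernel; a stand-in for whatever kernel the GP actually uses.
    return jnp.exp(-0.5 * ((x1[:, None] - x2[None, :]) / lengthscale) ** 2)

@jax.jit
def masked_mll(x_buf, y_buf, t, noise=0.1):
    # Marginal log-likelihood over a fixed-size buffer in which only the first
    # `t` entries are real data. Padded points are decoupled from the real data
    # and given unit variance and zero targets, so they add nothing.
    n = x_buf.shape[0]
    live = jnp.arange(n) < t
    both_live = live[:, None] & live[None, :]
    K = rbf(x_buf, x_buf) + noise ** 2 * jnp.eye(n)
    K = jnp.where(both_live, K, jnp.eye(n))        # identity block for padding
    y = jnp.where(live, y_buf, 0.0)                # zero out padded targets
    L = jnp.linalg.cholesky(K)
    alpha = cho_solve((L, True), y)
    return -0.5 * y @ alpha - jnp.sum(jnp.log(jnp.diag(L))) - 0.5 * t * jnp.log(2.0 * jnp.pi)

# The buffer shape is fixed, so growing t re-uses the same compiled program.
x_buf = jnp.zeros(128).at[:3].set(jnp.array([0.1, 0.4, 0.9]))
y_buf = jnp.zeros(128).at[:3].set(jnp.array([0.0, 0.5, 1.0]))
print(masked_mll(x_buf, y_buf, 3))
```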
I am not sure whether a solution already exists within gpjax as I'm still relatively new to this cool library :).
Describe Preferred Solution
I believe something like this could be implemented as follows, though I haven't tried it yet (a rough sketch is given after the list):
1. Inherit from gpx.Dataset to create a sub-class gpx.OnlineDataset(gpx.Dataset) with a new integer time_step variable, requiring the exact shapes of the data buffer at initialization.
2. Add a method that adds data to the buffer through jax.ops.
3. Make a DynamicKernel class that wraps around the standard kernel K computation, along the lines of K(a, b, a_idx, b_idx, t), which returns K(a, b) if a_idx <= b_idx <= t and int(a_idx == b_idx) otherwise.
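As a rough, untested sketch of what these pieces could look like (the class names, fields, and kernel signature below are hypothetical, taken from the items above rather than from the actual gpjax internals; the modern .at[...].set is used in place of the older jax.ops.index_update):

```python
from dataclasses import dataclass
from typing import Callable
import jax.numpy as jnp

@dataclass
class OnlineDataset:
    # Hypothetical sub-class of gpx.Dataset, written stand-alone here.
    X: jnp.ndarray          # fixed buffer, shape (n_max, d), given at initialization
    y: jnp.ndarray          # fixed buffer, shape (n_max, 1)
    time_step: int = 0      # number of buffer slots currently holding real data

    def add(self, x_new, y_new):
        # Functional update of the next free slot; .at[...].set is the current
        # replacement for the older jax.ops.index_update.
        t = self.time_step
        return OnlineDataset(
            X=self.X.at[t].set(x_new),
            y=self.y.at[t].set(y_new),
            time_step=t + 1,
        )

@dataclass
class DynamicKernel:
    base_kernel: Callable   # the ordinary kernel K(a, b)

    def __call__(self, a, b, a_idx, b_idx, t):
        # K(a, b) when both points are live data (indices below t); otherwise
        # 1 on the diagonal and 0 off it, i.e. int(a_idx == b_idx).
        both_live = (a_idx < t) & (b_idx < t)
        return jnp.where(both_live,
                         self.base_kernel(a, b),
                         jnp.where(a_idx == b_idx, 1.0, 0.0))
```

With a wrapper like this, the padded entries contribute an identity block to the Gram matrix, so the fixed-size computation stays well conditioned while only the live data points influence the fit.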
Describe Alternatives
NA
Related Code
Example of the jit recompilation based on the Documentation Regression notebook:
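As a minimal stand-alone illustration (plain JAX rather than the actual notebook code), any jitted objective is re-traced and re-compiled for every distinct input shape, so a dataset that grows by one point per step pays the compilation cost at every step:

```python
import jax
import jax.numpy as jnp

# Log a message every time XLA compiles, to make the re-compilation visible.
jax.config.update("jax_log_compiles", True)

@jax.jit
def objective(x, y):
    # Toy stand-in for a GP marginal log-likelihood.
    K = jnp.exp(-0.5 * (x[:, None] - x[None, :]) ** 2) + 1e-6 * jnp.eye(x.shape[0])
    return 0.5 * y @ jnp.linalg.solve(K, y)

for n in range(1, 6):
    x = jnp.linspace(0.0, 1.0, n)
    y = jnp.sin(x)
    objective(x, y)   # each new buffer size (n,) triggers a fresh compilation
```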
Additional Context
Example issue on the JAX repository: jax-ml/jax#2521
If the feature request is approved, would you be willing to submit a PR?
When I have time available I can try to port my solution to the gpjax API, though I am still quite new to the library.