feat: Dynamic data-set sizes #22
Description
Feature Request
For some applications of GPs, like Bayesian optimization, the dataset grows dynamically over time. Unfortunately, dynamic array sizes cause Jax jit-compiled functions to be re-compiled for every new buffer size, which means the computation takes much longer than necessary.
In my own code I was able to work around the recompilation caused by dynamic shapes by using a fixed-size buffer and modifying the Gaussian process logic with a dynamic mask that treats all data at index i > t as independent of indices j <= t in the kernel computation. One downside is that every iteration t = 1, ..., n then incurs time and memory cost proportional to n. For most applications, however, the speed-up provided by jit makes this completely negligible.
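As a minimal sketch of the masking idea described above (plain jax.numpy, not gpjax API; the function name is illustrative):

```python
import jax.numpy as jnp


def masked_cov(K, t):
    """Mask a full-buffer covariance matrix so that padded points
    (index >= t) act as independent unit-variance observations.

    Illustrative sketch only, not part of the gpjax API.
    """
    n = K.shape[0]
    idx = jnp.arange(n)
    # An entry (i, j) is valid only when both points are observed.
    valid = (idx[:, None] < t) & (idx[None, :] < t)
    # Keep K where valid; fall back to the identity elsewhere so the
    # padded entries stay independent of the observed data.
    return jnp.where(valid, K, jnp.eye(n))
```

Because the shape of K stays fixed and only the integer t varies, a jitted function built on a mask like this compiles once rather than once per dataset size.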
I am not sure whether a solution already exists within gpjax as I'm still relatively new to this cool library :).
Describe Preferred Solution
I believe this could be implemented as follows, though I haven't tried it yet:
- Inherit from `gpx.Dataset` and create a sub-class `gpx.OnlineDataset(gpx.Dataset)` with a new integer `time_step` variable, requiring the exact shapes of the data buffer at initialization.
- Add a method to append data to the buffer through `jax.ops`.
- Make a `DynamicKernel` class that wraps the standard kernel computation `K` along the lines of `K(a, b, a_idx, b_idx, t)`, which returns `K(a, b)` if `a_idx <= b_idx <= t` and otherwise `int(a_idx == b_idx)`.
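A rough sketch of the last two points (buffer append plus the masked kernel wrapper). Everything here is hypothetical, not existing gpjax API; the append uses `Array.at[...]` updates, which have superseded `jax.ops` in recent Jax versions:

```python
import jax.numpy as jnp


def append_to_buffer(X_buf, y_buf, t, x_new, y_new):
    """Write a new observation into fixed-size buffers at position t.

    Hypothetical helper sketching the proposed OnlineDataset method;
    returns the updated buffers and the incremented time step.
    """
    return X_buf.at[t].set(x_new), y_buf.at[t].set(y_new), t + 1


def dynamic_kernel(base_kernel):
    """Wrap a kernel k(a, b) so that indices beyond t are treated as
    independent unit-variance points (hypothetical DynamicKernel)."""

    def k(a, b, a_idx, b_idx, t):
        # Symmetric version of the a_idx <= b_idx <= t condition.
        observed = (a_idx <= t) & (b_idx <= t)
        # Observed pairs get the real kernel value; unobserved pairs
        # contribute 1 on the diagonal and 0 off it.
        return jnp.where(observed, base_kernel(a, b), 1.0 * (a_idx == b_idx))

    return k
```

The condition is applied to both indices so the resulting covariance matrix stays symmetric, which the `a_idx <= b_idx <= t` form implies for a symmetric base kernel.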
Describe Alternatives
NA
Related Code
Example of the jit recompilation based on the Documentation Regression notebook:
import gpjax as gpx
from jax import jit, random
from jax import numpy as jnp
n = 5
x = jnp.linspace(-1, 1, n)[..., None]
y = jnp.sin(x)  # x is already (n, 1), so no extra axis is needed
xtest = jnp.linspace(-2, 2, 100)[..., None]
@jit
def gp_predict(xs, x_train, y_train):
posterior = gpx.Prior(kernel=gpx.RBF()) * gpx.Gaussian(num_datapoints=len(x_train))
params, *_ = gpx.initialise(
posterior, random.PRNGKey(0), kernel={"lengthscale": jnp.array([0.5])}
).unpack()
post_predictive = posterior(params, gpx.Dataset(X=x_train, y=y_train))
out_dist = post_predictive(xs)
return out_dist.mean(), out_dist.stddev()
# First call - compile
print('compile')
for i in range(len(x)):
%time gp_predict(xtest, x[:i+1], y[:i+1])
print()
# Second call - use cached
print('jitted')
for i in range(len(x)):
%time gp_predict(xtest, x[:i+1], y[:i+1])
# Output
compile
CPU times: user 519 ms, sys: 1.64 ms, total: 521 ms
Wall time: 293 ms
CPU times: user 1.06 s, sys: 0 ns, total: 1.06 s
Wall time: 316 ms
CPU times: user 956 ms, sys: 17.9 ms, total: 974 ms
Wall time: 219 ms
jitted
CPU times: user 3.66 ms, sys: 443 µs, total: 4.11 ms
Wall time: 2.46 ms
CPU times: user 2.89 ms, sys: 348 µs, total: 3.23 ms
Wall time: 1.84 ms
CPU times: user 894 µs, sys: 0 ns, total: 894 µs
Wall time: 568 µs
Additional Context
Example issue on the Jax: jax-ml/jax#2521
If the feature request is approved, would you be willing to submit a PR?
When I have time available I can try to port my solution to the gpjax API, though I am still quite new to the library.