This repository was archived by the owner on Mar 31, 2025. It is now read-only.
Model compiling twice when using jax==0.2.10 or later #212
Description
Hi,
I recently updated JAX and noticed that my runtime increased. I have managed to isolate the issue: my objax model is compiling itself twice, i.e., on the second training iteration the model seems to recompile for some reason. This only happens with JAX versions 0.2.10 or later.
Any idea what the cause of this may be?
I hope the toy example below is clear enough. I am using objax==1.3.1 and jaxlib==0.1.60.
import objax
import jax.numpy as np
from jax import vmap
import time


class GaussianLikelihood(objax.Module):
    """
    The Gaussian likelihood
    """

    def __init__(self,
                 variance=0.1):
        """
        :param variance: The observation noise variance
        """
        self.variance = objax.TrainVar(np.array(variance))

    def expected_log_lik(self, y, post_mean, post_cov):
        """
        Expected log-likelihood of observation y under a Gaussian posterior
        with mean post_mean and variance post_cov.
        """
        exp_log_lik = (
            -0.5 * np.log(2 * np.pi)
            - 0.5 * np.log(self.variance.value)
            - 0.5 * ((y - post_mean) ** 2 + post_cov) / self.variance.value
        )
        return exp_log_lik


class GP(objax.Module):
    """
    A GP model
    """

    def __init__(self,
                 likelihood,
                 X,
                 Y):
        self.X = np.array(X)
        self.Y = np.array(Y)
        self.likelihood = likelihood
        self.posterior_mean = objax.StateVar(np.zeros([X.shape[0], 1, 1]))
        self.posterior_variance = objax.StateVar(np.ones([X.shape[0], 1, 1]))

    def energy(self):
        """
        Sum of the expected log-likelihoods over all data points.
        """
        mean_f, cov_f = self.posterior_mean.value, self.posterior_variance.value
        E = vmap(self.likelihood.expected_log_lik)(
            self.Y,
            mean_f,
            cov_f
        )
        return np.sum(E)


# generate some data
N = 1000000
x = np.linspace(-10, 100, num=N)
y = np.sin(x)

# set up the model
lik = GaussianLikelihood(variance=1.0)
gp_model = GP(likelihood=lik, X=x, Y=y)
energy = objax.GradValues(gp_model.energy, gp_model.vars())

lr_adam = 0.1
iters = 10
opt = objax.optimizer.Adam(gp_model.vars())


def train_op():
    dE, E = energy()  # compute energy and its gradients w.r.t. hypers
    return dE, E


train_op = objax.Jit(train_op, gp_model.vars())

t0 = time.time()
for i in range(1, iters + 1):
    t2 = time.time()
    grad, loss = train_op()
    opt(lr_adam, grad)
    t3 = time.time()
    # print('iter %2d, energy: %1.4f' % (i, loss[0]))
    print('iter time: %2.2f secs' % (t3 - t2))
t1 = time.time()
print('optimisation time: %2.2f secs' % (t1 - t0))

Running this script with jax==0.2.9 gives:
iter time: 0.12 secs
iter time: 0.01 secs
iter time: 0.02 secs
iter time: 0.02 secs
iter time: 0.02 secs
iter time: 0.02 secs
iter time: 0.02 secs
iter time: 0.02 secs
iter time: 0.01 secs
iter time: 0.01 secs
optimisation time: 0.28 secs
Running the script with jax==0.2.10 gives:
iter time: 0.14 secs
iter time: 0.08 secs
iter time: 0.02 secs
iter time: 0.02 secs
iter time: 0.02 secs
iter time: 0.02 secs
iter time: 0.02 secs
iter time: 0.02 secs
iter time: 0.02 secs
iter time: 0.02 secs
optimisation time: 0.38 secs
As you can see, there is a significant difference in the second iteration, as if the model is recompiling itself.
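In case it helps with reproducing: below is a minimal sketch of how I would confirm that the slow second call really is a recompile, and check whether the traced input signatures change between calls. It assumes the jax_log_compiles config flag is available in this JAX version (setting the JAX_LOG_COMPILES=1 environment variable should be equivalent), and print_var_signatures is just a hypothetical helper name; the snippet would go right after the objax.Jit(...) line in the script above, in place of the timing loop.

import logging
import jax

# Assumption: the jax_log_compiles flag exists in this JAX version; if not,
# setting the JAX_LOG_COMPILES=1 environment variable should do the same thing.
jax.config.update('jax_log_compiles', True)
logging.basicConfig(level=logging.WARNING)  # make sure the compile log messages are visible


def print_var_signatures(varcollection, label):
    # hypothetical helper: dump shape/dtype of every variable the jitted function closes over
    print(label)
    for name, var in varcollection.items():
        print('  %s: shape=%s dtype=%s' % (name, var.value.shape, var.value.dtype))


print_var_signatures(gp_model.vars(), 'before first step')
grad, loss = train_op()   # first call: a "Compiling ..." message is expected here
opt(lr_adam, grad)
print_var_signatures(gp_model.vars(), 'after first step')
grad, loss = train_op()   # second call: another "Compiling ..." message would confirm the recompile

If a second compile message shows up even though the printed shapes and dtypes are identical, the retrace is being triggered inside JAX rather than by the model's variables changing shape or dtype.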