This repository was archived by the owner on Mar 31, 2025. It is now read-only.

Model compiling twice when using jax==0.2.10 or later #212

@wil-j-wil

Description


Hi,

I recently updated JAX and noticed that my runtime increased. I have managed to isolate the issue: my objax model is compiling itself twice, i.e., on the second training iteration the model seems to recompile for some reason. This only happens with JAX versions 0.2.10 and later.

Any idea what the cause of this may be?

I hope this toy example is clear enough. I am using objax==1.3.1 and jaxlib==0.1.60.

import objax
import jax.numpy as np
from jax import vmap
import time


class GaussianLikelihood(objax.Module):
    """
    The Gaussian likelihood
    """
    def __init__(self,
                 variance=0.1):
        """
        :param variance: The observation noise variance
        """
        self.variance = objax.TrainVar(np.array(variance))

    def expected_log_lik(self, y, post_mean, post_cov):
        """
        """
        exp_log_lik = (
            -0.5 * np.log(2 * np.pi)
            - 0.5 * np.log(self.variance.value)
            - 0.5 * ((y - post_mean) ** 2 + post_cov) / self.variance.value
        )
        return exp_log_lik


class GP(objax.Module):
    """
    A GP model
    """
    def __init__(self,
                 likelihood,
                 X,
                 Y):
        self.X = np.array(X)
        self.Y = np.array(Y)
        self.likelihood = likelihood
        self.posterior_mean = objax.StateVar(np.zeros([X.shape[0], 1, 1]))
        self.posterior_variance = objax.StateVar(np.ones([X.shape[0], 1, 1]))

    def energy(self):
        """
        """
        mean_f, cov_f = self.posterior_mean.value, self.posterior_variance.value

        E = vmap(self.likelihood.expected_log_lik)(
            self.Y,
            mean_f,
            cov_f
        )

        return np.sum(E)


# generate some data
N = 1000000
x = np.linspace(-10, 100, num=N)
y = np.sin(x)

# set up the model
lik = GaussianLikelihood(variance=1.0)
gp_model = GP(likelihood=lik, X=x, Y=y)

energy = objax.GradValues(gp_model.energy, gp_model.vars())

lr_adam = 0.1
iters = 10
opt = objax.optimizer.Adam(gp_model.vars())


def train_op():
    dE, E = energy()  # compute energy and its gradients w.r.t. hypers
    return dE, E


train_op = objax.Jit(train_op, gp_model.vars())

t0 = time.time()
for i in range(1, iters + 1):
    t2 = time.time()
    grad, loss = train_op()
    opt(lr_adam, grad)
    t3 = time.time()
    # print('iter %2d, energy: %1.4f' % (i, loss[0]))
    print('iter time: %2.2f secs' % (t3-t2))
t1 = time.time()
print('optimisation time: %2.2f secs' % (t1-t0))

Running this script with jax==0.2.9 gives

iter time: 0.12 secs
iter time: 0.01 secs
iter time: 0.02 secs
iter time: 0.02 secs
iter time: 0.02 secs
iter time: 0.02 secs
iter time: 0.02 secs
iter time: 0.02 secs
iter time: 0.01 secs
iter time: 0.01 secs
optimisation time: 0.28 secs

Running the script with jax==0.2.10 gives

iter time: 0.14 secs
iter time: 0.08 secs
iter time: 0.02 secs
iter time: 0.02 secs
iter time: 0.02 secs
iter time: 0.02 secs
iter time: 0.02 secs
iter time: 0.02 secs
iter time: 0.02 secs
iter time: 0.02 secs
optimisation time: 0.38 secs

As you can see, there is a significant difference on the second iteration (0.08 secs vs. 0.01 secs), as if the model is re-compiling itself.
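To check that it really is a recompilation (rather than some other per-iteration overhead), one option is to turn on JAX's compile logging before running the training loop above. This is a minimal sketch, assuming the jax_log_compiles config option is available in this JAX version; the log level of the compile messages may differ between versions.

import logging
from jax.config import config

# Assumption: jax_log_compiles is available in this JAX version. When enabled,
# JAX emits a "Compiling ..." log message every time a jitted function is
# (re)compiled; depending on the version the message may be logged at WARNING
# or INFO level, so enable INFO logging to be safe.
logging.basicConfig(level=logging.INFO)
config.update('jax_log_compiles', True)

# Re-run a few iterations of the training loop from the script above:
for i in range(3):
    grad, loss = train_op()  # a second "Compiling ..." message on i == 1
    opt(lr_adam, grad)       # would confirm the recompilation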
