Skip to content
This repository was archived by the owner on Mar 31, 2025. It is now read-only.

objax.Jit reports error when StateVar is added to the vc argument #219

@shenzebang

Description

@shenzebang

Hello everyone, I am new to JAX and I am encountering the following problem. Could someone please help me resolve it?

When I add the StateVar to the vc argument of objax.Jit, it reports

File "/home/zebang/PycharmProjects/ode_diffusion_jax/main.py", line 232, in train_critic
_loss = train_op(data, subkey)
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/objax/module.py", line 258, in call
output, changes = self._call(self.vc.tensors(), kwargs, *args)
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/api.py", line 332, in cache_miss
out_flat = xla.xla_call(
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/core.py", line 1402, in bind
return call_bind(self, fun, *args, **params)
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/core.py", line 1393, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/core.py", line 1405, in process
return trace.process_call(self, fun, tracers, params)
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/core.py", line 600, in process_call
return primitive.impl(f, *tracers, **params)
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/interpreters/xla.py", line 576, in _xla_call_impl
compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/linear_util.py", line 260, in memoized_fun
ans = call(fun, *args)
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/interpreters/xla.py", line 652, in _xla_callable
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args, transform_name="jit")
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1209, in trace_to_jaxpr_final
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1188, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/linear_util.py", line 179, in call_wrapped
ans = gen.send(ans)
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/core.py", line 1363, in process_env_traces
outs = map(trace.full_raise, outs)
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/_src/util.py", line 40, in safe_map
return list(map(f, *args))
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/core.py", line 381, in full_raise
return self.lift(val)
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1014, in new_const
self.frame.tracers.append(tracer)
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1002, in frame
return self.main.jaxpr_stack[-1] # pytype: disable=attribute-error
IndexError: tuple index out of range

Here is the relevant code:


def train_critic(key):
    """Train the critic network to match the score network along an ODE flow.

    Builds two UNets (the pretrained score model ``net`` and a ``critic``),
    restores both from the same checkpoint, and defines a jitted ``train_op``
    that integrates the critic loss and its gradients through ``odeint``.

    Args:
        key: JAX PRNG key used for model construction and per-step sampling.

    NOTE(review): ``train_op`` is defined but never returned within this
    snippet; presumably it is used further down, outside the excerpt.
    """
    key, subkey = random.split(key)
    net = UNet(marginal_prob_std_fn, subkey)
    critic = UNet(marginal_prob_std_fn, subkey)
    # Both networks start from the same pretrained score-model weights.
    objax.io.load_var_collection(os.path.join(model_path, 'scorenet.npz'), net.vars())
    objax.io.load_var_collection(os.path.join(model_path, 'scorenet.npz'), critic.vars())

    def critic_gv_fn(x_init, key):
        """Integrate the ODE from t=0 to t=1 and return
        (critic gradients at t=1, critic loss at t=1, critic StateVar tensors)."""
        # define v for the Hutchinson’s Estimator
        # (20 random probe vectors stacked on a new leading axis; assumes
        # x_init is a batched tensor — TODO confirm shape convention)
        key, subkey = random.split(key)
        v = random.normal(subkey, tuple([20] + list(x_init.shape)))
        # define the initial states
        t_0 = jnp.zeros(x_init.shape[0])
        score_init = net(t_0, x_init, training=False)
        critic_loss_init = jnp.zeros(1)
        # One zero tensor per trainable critic variable: gradients are
        # accumulated through the ODE integration as part of the state.
        critic_grad_init = [jnp.zeros_like(_var) for _var in critic.vars().subset(is_a=TrainVar)]
        state_init = [x_init, score_init, critic_loss_init, critic_grad_init]

        def ode_func(states, t):
            """Time derivative of [x, score, critic_loss, critic_grads] at time t."""
            x = states[0]
            score = states[1]
            # Broadcast the scalar integration time to a per-example vector.
            _t = jnp.ones([x.shape[0]]) * t
            diffusion_weight = diffusion_coeff_fn(t)
            score_pred = net(_t, x, training=False)
            # Drift term for x (probability-flow-style ODE using the score).
            dx = -.5 * (diffusion_weight ** 2) * score_pred

            # Score network as a function of x alone, at fixed time _t.
            f = lambda x: net(_t, x, training=False)

            def divergence_fn(_x, _v):
                # Hutchinson’s Estimator
                # computes the divergence of net at x with random vector v
                _, u = jvp(f, (_x,), (_v,))
                return jnp.sum(jnp.dot(u, _v))

            # Vectorize over the 20 probe vectors (axis 0 of _v only).
            batch_div_fn = jax.vmap(divergence_fn, in_axes=[None, 0])

            def batch_div(x):
                # Average the divergence estimate over the probe vectors.
                return batch_div_fn(x, v).mean(axis=0)

            grad_div_fn = grad(batch_div)

            # Score evolution: gradient-of-divergence term plus the JVP of
            # the score network along the current score direction.
            dscore_1 = - grad_div_fn(x)
            dscore_2 = - jvp(f, (x,), (score,))[1]  # f(x), df/dx * v = jvp(f, x, v)
            dscore = dscore_1 + dscore_2

            def dcritic_loss_fn(_x):
                # Squared error between critic and score predictions, summed
                # over feature axes (assumes 4-D NHWC/NCHW tensors — TODO
                # confirm) and averaged over the batch.
                critic_pred = critic(_t, _x, training=True)
                loss = ((critic_pred - score_pred) ** 2).sum(axis=(1, 2, 3)).mean()
                return loss

            # NOTE(review): GradValues is constructed over critic.vars()
            # (TrainVars AND StateVars) inside a function that is traced by
            # odeint under objax.Jit — this looks like where the reported
            # IndexError originates when StateVars are included in Jit's vc;
            # confirm against objax's Jit/GradValues variable-capture rules.
            dc_gv = objax.GradValues(dcritic_loss_fn, critic.vars())
            dcritic_grad, dcritic_loss = dc_gv(x)
            # GradValues returns the loss wrapped in a list; re-shape it to a
            # length-1 array so it matches critic_loss_init's shape.
            dcritic_loss = dcritic_loss[0][None]
            dstates = [dx, dscore, dcritic_loss, dcritic_grad]

            return dstates

        # Integration endpoints: solve from t=0 to t=1.
        tspace = np.array((0., 1.))

        result = odeint(ode_func, state_init, tspace, atol=tolerance, rtol=tolerance)

        # Index [1] selects the t=1 endpoint of each integrated quantity.
        _g = [_var[1] for _var in result[3]]
        return _g, result[2][1], critic.vars().subset(is_a=StateVar).tensors()

    # define optimizer
    opt = objax.optimizer.Adam(critic.vars())

    # define train_op
    def train_op(x, key):
        # Compute the integrated gradients, write the StateVar tensors
        # produced under the trace back into the collection, then apply Adam.
        g, v, svars_t = critic_gv_fn(x, key)
        critic.vars().subset(is_a=StateVar).assign(svars_t)
        opt(lr, g)
        return v

    # Workaround from the issue: passing only TrainVars (not StateVars) to
    # Jit's variable collection avoids the crash.
    train_op = objax.Jit(train_op, critic.vars().subset(is_a=TrainVar) + opt.vars())
    # reports error if I set "train_op = objax.Jit(train_op, critic.vars() + opt.vars())"

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions