objax.Jit reports error when StateVar is added to the vc argument #219
Description
Hello everyone, I am new to JAX and I am running into the following problem. Could someone please help me resolve it?
When I add StateVar variables to the vc argument of objax.Jit, it reports this error:
  File "/home/zebang/PycharmProjects/ode_diffusion_jax/main.py", line 232, in train_critic
    _loss = train_op(data, subkey)
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/objax/module.py", line 258, in __call__
    output, changes = self._call(self.vc.tensors(), kwargs, *args)
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/api.py", line 332, in cache_miss
    out_flat = xla.xla_call(
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/core.py", line 1402, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/core.py", line 1393, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/core.py", line 1405, in process
    return trace.process_call(self, fun, tracers, params)
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/core.py", line 600, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/interpreters/xla.py", line 576, in _xla_call_impl
    compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/linear_util.py", line 260, in memoized_fun
    ans = call(fun, *args)
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/interpreters/xla.py", line 652, in _xla_callable
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args, transform_name="jit")
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1209, in trace_to_jaxpr_final
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1188, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/linear_util.py", line 179, in call_wrapped
    ans = gen.send(ans)
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/core.py", line 1363, in process_env_traces
    outs = map(trace.full_raise, outs)
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/_src/util.py", line 40, in safe_map
    return list(map(f, *args))
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/core.py", line 381, in full_raise
    return self.lift(val)
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1014, in new_const
    self.frame.tracers.append(tracer)
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1002, in frame
    return self.main.jaxpr_stack[-1] # pytype: disable=attribute-error
IndexError: tuple index out of range
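To make the pattern easier to see, here is a stripped-down sketch of what I am doing (the names Toy, step, etc. are made up for illustration; I have not checked whether this tiny version reproduces the error on its own):

import jax.numpy as jnp
import objax
from objax import StateVar, TrainVar

class Toy(objax.Module):
    def __init__(self):
        self.w = TrainVar(jnp.zeros(3))  # trainable parameter
        self.s = StateVar(jnp.zeros(3))  # non-trainable state

m = Toy()

def step(x):
    # read some state and write it back through assign()
    m.vars().subset(is_a=StateVar).assign([m.s.value + x])
    return (m.w.value * x).sum()

# this variant works for me
step_jit_trainvars = objax.Jit(step, m.vars().subset(is_a=TrainVar))
# in my real code, including the StateVars like this is what triggers the IndexError
step_jit_all = objax.Jit(step, m.vars())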
Here is the relevant code:
def train_critic(key):
    key, subkey = random.split(key)
    net = UNet(marginal_prob_std_fn, subkey)
    critic = UNet(marginal_prob_std_fn, subkey)
    objax.io.load_var_collection(os.path.join(model_path, 'scorenet.npz'), net.vars())
    objax.io.load_var_collection(os.path.join(model_path, 'scorenet.npz'), critic.vars())
    def critic_gv_fn(x_init, key):
        # define v for the Hutchinson’s Estimator
        key, subkey = random.split(key)
        v = random.normal(subkey, tuple([20] + list(x_init.shape)))
        # define the initial states
        t_0 = jnp.zeros(x_init.shape[0])
        score_init = net(t_0, x_init, training=False)
        critic_loss_init = jnp.zeros(1)
        critic_grad_init = [jnp.zeros_like(_var) for _var in critic.vars().subset(is_a=TrainVar)]
        state_init = [x_init, score_init, critic_loss_init, critic_grad_init]
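        # note: the ODE state is the list
        #   [x, score, accumulated critic loss, accumulated critic gradients],
        # so odeint carries the critic loss/gradients along the trajectory
        # together with x and the score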
        def ode_func(states, t):
            x = states[0]
            score = states[1]
            _t = jnp.ones([x.shape[0]]) * t
            diffusion_weight = diffusion_coeff_fn(t)
            score_pred = net(_t, x, training=False)
            dx = -.5 * (diffusion_weight ** 2) * score_pred
            f = lambda x: net(_t, x, training=False)

            def divergence_fn(_x, _v):
                # Hutchinson’s Estimator
                # computes the divergence of net at x with random vector v
                _, u = jvp(f, (_x,), (_v,))
                return jnp.sum(jnp.dot(u, _v))
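            # (Hutchinson's trick: for a random v with identity covariance,
            #  E_v[v^T (df/dx) v] equals the trace of the Jacobian of f,
            #  i.e. the divergence of f at x)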
            batch_div_fn = jax.vmap(divergence_fn, in_axes=[None, 0])

            def batch_div(x):
                return batch_div_fn(x, v).mean(axis=0)

            grad_div_fn = grad(batch_div)
            dscore_1 = - grad_div_fn(x)
            dscore_2 = - jvp(f, (x,), (score,))[1]  # f(x), df/dx * v = jvp(f, x, v)
            dscore = dscore_1 + dscore_2

            def dcritic_loss_fn(_x):
                critic_pred = critic(_t, _x, training=True)
                loss = ((critic_pred - score_pred) ** 2).sum(axis=(1, 2, 3)).mean()
                return loss

            dc_gv = objax.GradValues(dcritic_loss_fn, critic.vars())
            dcritic_grad, dcritic_loss = dc_gv(x)
            dcritic_loss = dcritic_loss[0][None]
            dstates = [dx, dscore, dcritic_loss, dcritic_grad]
            return dstates
        tspace = np.array((0., 1.))
        result = odeint(ode_func, state_init, tspace, atol=tolerance, rtol=tolerance)
        _g = [_var[1] for _var in result[3]]
        return _g, result[2][1], critic.vars().subset(is_a=StateVar).tensors()
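    # (so critic_gv_fn returns the integrated critic gradients at t=1, the
    #  integrated critic loss at t=1, and the critic's current StateVar tensors,
    #  which train_op writes back with assign() below)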
    # define optimizer
    opt = objax.optimizer.Adam(critic.vars())

    # define train_op
    def train_op(x, key):
        g, v, svars_t = critic_gv_fn(x, key)
        critic.vars().subset(is_a=StateVar).assign(svars_t)
        opt(lr, g)
        return v

    train_op = objax.Jit(train_op, critic.vars().subset(is_a=TrainVar) + opt.vars())
    # reports the error above if I instead set "train_op = objax.Jit(train_op, critic.vars() + opt.vars())"
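For completeness, the call that actually raises the error is the "_loss = train_op(data, subkey)" line from the traceback; it sits a bit further down in train_critic, in a loop that looks roughly like this (simplified; n_steps and data_loader are placeholder names, not my exact code):

for _ in range(n_steps):
    for data in data_loader:
        key, subkey = random.split(key)
        _loss = train_op(data, subkey)  # main.py line 232 in the traceback above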