Problems when running GCIVL #42
Description
Hi Park, I get the following error when running GCIVL:
```
Traceback (most recent call last):
  File "/mnt/Data_1/gaokai/lbx/OG/impls/main1.py", line 168, in <module>
    app.run(main)
  File "/mnt/Data_1/gaokai/miniconda3/envs/hiql/lib/python3.9/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/mnt/Data_1/gaokai/miniconda3/envs/hiql/lib/python3.9/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/mnt/Data_1/gaokai/lbx/OG/impls/main1.py", line 102, in main
    agent, update_info = agent.update(batch, i)
  File "/mnt/Data_1/gaokai/lbx/OG/impls/agents/gcivl0.py", line 147, in update
    new_network, info = self.network.apply_loss_fn(loss_fn=loss_fn)
  File "/mnt/Data_1/gaokai/lbx/OG/impls/utils/flax_utils.py", line 137, in apply_loss_fn
    grads, info = jax.grad(loss_fn, has_aux=True)(self.params)
  File "/mnt/Data_1/gaokai/lbx/OG/impls/agents/gcivl0.py", line 145, in loss_fn
    return self.total_loss(batch, t, grad_params, rng=rng)
  File "/mnt/Data_1/gaokai/lbx/OG/impls/agents/gcivl0.py", line 118, in total_loss
    value_loss, value_info = self.value_loss(batch, grad_params)
  File "/mnt/Data_1/gaokai/lbx/OG/impls/agents/gcivl0.py", line 42, in value_loss
    (next_v1, next_v2) = self.network.select('target_value')(batch['next_observations'], batch['value_goals'])
  File "/mnt/Data_1/gaokai/miniconda3/envs/hiql/lib/python3.9/site-packages/jax/_src/numpy/array_methods.py", line 739, in op
    return getattr(self.aval, f"{name}")(self, *args)
  File "/mnt/Data_1/gaokai/miniconda3/envs/hiql/lib/python3.9/site-packages/jax/_src/numpy/array_methods.py", line 352, in _getitem
    return lax_numpy._rewriting_take(self, item)
  File "/mnt/Data_1/gaokai/miniconda3/envs/hiql/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 6593, in _rewriting_take
    treedef, static_idx, dynamic_idx = _split_index_for_jit(idx, arr.shape)
  File "/mnt/Data_1/gaokai/miniconda3/envs/hiql/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 6674, in _split_index_for_jit
    raise TypeError(f"JAX does not support string indexing; got {idx=}")
TypeError: JAX does not support string indexing; got idx=('next_observations',)
```
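For reference, the last frames show that the object indexed with `'next_observations'` inside `value_loss` is a JAX array, not the dict of arrays that `batch['next_observations']` and `batch['value_goals']` assume. A minimal sketch reproducing the same `TypeError` outside the agent; `batch_dict` and `batch_array` are hypothetical stand-ins, the suggested cause in the comments is only an assumption, and the exact message wording may differ between JAX versions:

```python
import jax.numpy as jnp

# Hypothetical stand-in: a dict of arrays supports string keys,
# which is what the agent code in value_loss assumes.
batch_dict = {"next_observations": jnp.zeros((4, 3))}
ok_shape = batch_dict["next_observations"].shape  # (4, 3)

# A bare JAX array does not: indexing it with a string raises TypeError.
# This matches the traceback and would happen if a JAX array reached
# value_loss in place of the dict (for example through a positional
# argument mismatch somewhere up the call chain; this cause is an assumption).
batch_array = jnp.zeros((4, 3))
try:
    batch_array["next_observations"]
    error_message = None
except TypeError as err:
    error_message = str(err)

print(ok_shape)
print(error_message)
```

Printing `type(batch)` at the top of `value_loss` in `gcivl0.py` should show which of the two cases is actually being hit.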