-
Notifications
You must be signed in to change notification settings - Fork 19
Closed
netket/folx
#1
Description
I recently came across this somewhat obscure error, which is caused by a missing pvary.
The missing pvary is the one I mentioned at the end of #36 (comment), although there I had a different reason in mind for why it would be needed.
Reproducer:
import jax
import jax.numpy as jnp
from folx import forward_laplacian
from functools import partial
def f(w, x):
    return jax.lax.integer_pow(x @ w, 1)
@jax.smap(out_axes=0,in_axes=(None, 0), axis_name='i')
@partial(jax.vmap, in_axes=(None, 0))
def test(w, x):
    return forward_laplacian(partial(f, w))(x)
x = jnp.ones((1,16))
w = jnp.ones((16,16))
with jax.set_mesh(jax.sharding.Mesh(jax.devices(), 'i')):
    test(w,x)
ERROR:[folx](/private/tmp/bug.py:7:11 (f)) - Error in operation integer_pow.
Traceback (most recent call last):
File "/private/tmp/bug.py", line 18, in <module>
test(w,x)
File "/private/tmp/bug.py", line 12, in test
return forward_laplacian(partial(f, w))(x)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/clemens/folx/folx/interpreter.py", line 309, in wrapped
out = eval_jaxpr_with_forward_laplacian(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/clemens/folx/folx/interpreter.py", line 226, in eval_jaxpr_with_forward_laplacian
raise e
File "/Users/clemens/folx/folx/interpreter.py", line 222, in eval_jaxpr_with_forward_laplacian
outvals = eval_laplacian(eqn, invals)
^^^^^^^
File "/Users/clemens/folx/folx/interpreter.py", line 183, in eval_laplacian
return fn(
^^^
File "/Users/clemens/folx/folx/wrapper.py", line 124, in new_fn
lapl_y, lapl_fns.jac_hessian_jac_trace(laplace_args, sparsity_threshold)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/clemens/folx/folx/hessian.py", line 394, in hessian_transform
return vmapped_jac_hessian_jac(
^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/clemens/folx/folx/hessian.py", line 367, in vmapped_jac_hessian_jac
result = hess_transform(lapl_args, extra_args, out_idx)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/clemens/folx/folx/hessian.py", line 348, in hess_transform
result = general_jac_hessian_jac(merged_fn, args, out_idx)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/clemens/folx/folx/hessian.py", line 99, in general_jac_hessian_jac
flat_out = JHJ_via_hessian(flat_fn, flat_x, grad_2d)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/clemens/folx/folx/hessian.py", line 41, in JHJ_via_hessian
flat_hessian = hessian(flat_fn)(flat_x)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/clemens/folx/folx/ad.py", line 97, in jacfun
J = jax.vmap(jvp_fun, out_axes=-1)(eye)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/clemens/folx/folx/ad.py", line 94, in jvp_fun
return jax.jvp(f, primals, unravel(s))[1]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/clemens/folx/folx/ad.py", line 77, in jacfun
result = jax.vmap(vjp(flat_f, flat_primals))(eye)[0]
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/clemens/folx/folx/ad.py", line 45, in vjp
out, vjp = jax.vjp(fun, *primals)
^^^^^^^^^^^^^^^^^^^^^^
File "/Users/clemens/folx/folx/ad.py", line 70, in flat_f
return jfu.ravel_pytree(f(*unravel(x)))[0]
^^^^^^^^^^^^^^
File "/Users/clemens/folx/folx/utils.py", line 99, in new_fn
return jfu.ravel_pytree(fn(*x))[0] # type: ignore
^^^^^^
File "/Users/clemens/folx/folx/hessian.py", line 334, in merged_fn
return fwd(*merge(x, extra_args))
^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: Primitive mul requires varying manual axes to match, but got [frozenset({'i'}), frozenset()]. Please open an issue at https://github.com/jax-ml/jax/issues and as a temporary workaround pass the check_vma=False argument to `jax.shard_map`
A PR with a fix is on its way.
Metadata
Metadata
Assignees
Labels
No labels