Skip to content

forward_laplacian inside of shard_map vma errors when using function with integer power 1 #38

@inailuig

Description

@inailuig

I recently came across this somewhat obscure error, which is caused by a missing pvary.
It is the same missing pvary that I mentioned at the end of #36 (comment), although there I had a different reason in mind for why it would be needed.

Reproducer:

# Minimal reproducer: forward_laplacian applied to a function that uses
# integer_pow with exponent 1, vmapped and run inside jax.smap over axis 'i'.
import jax
import jax.numpy as jnp
from folx import forward_laplacian
from functools import partial

def f(w, x):
    # integer_pow is the operation folx reports the error in (see log below).
    return jax.lax.integer_pow(x @ w, 1)

# in_axes=(None, 0) on both wrappers: w is passed unmapped, x is mapped
# along its leading axis.
@jax.smap(out_axes=0,in_axes=(None, 0), axis_name='i')
@partial(jax.vmap, in_axes=(None, 0))
def test(w, x):
    # w is bound via partial, so forward_laplacian only sees x as an argument.
    return forward_laplacian(partial(f, w))(x)

x = jnp.ones((1,16))
w = jnp.ones((16,16))

# Mesh built from jax.devices() with the single axis name 'i', matching the
# axis_name used by jax.smap above.
with jax.set_mesh(jax.sharding.Mesh(jax.devices(), 'i')):
    test(w,x)
ERROR:[folx](/private/tmp/bug.py:7:11 (f)) - Error in operation integer_pow.
Traceback (most recent call last):
  File "/private/tmp/bug.py", line 18, in <module>
    test(w,x)
  File "/private/tmp/bug.py", line 12, in test
    return forward_laplacian(partial(f, w))(x)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/clemens/folx/folx/interpreter.py", line 309, in wrapped
    out = eval_jaxpr_with_forward_laplacian(
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/clemens/folx/folx/interpreter.py", line 226, in eval_jaxpr_with_forward_laplacian
    raise e
  File "/Users/clemens/folx/folx/interpreter.py", line 222, in eval_jaxpr_with_forward_laplacian
    outvals = eval_laplacian(eqn, invals)
^^^^^^^
  File "/Users/clemens/folx/folx/interpreter.py", line 183, in eval_laplacian
    return fn(
           ^^^
  File "/Users/clemens/folx/folx/wrapper.py", line 124, in new_fn
    lapl_y, lapl_fns.jac_hessian_jac_trace(laplace_args, sparsity_threshold)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/clemens/folx/folx/hessian.py", line 394, in hessian_transform
    return vmapped_jac_hessian_jac(
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/clemens/folx/folx/hessian.py", line 367, in vmapped_jac_hessian_jac
    result = hess_transform(lapl_args, extra_args, out_idx)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/clemens/folx/folx/hessian.py", line 348, in hess_transform
    result = general_jac_hessian_jac(merged_fn, args, out_idx)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/clemens/folx/folx/hessian.py", line 99, in general_jac_hessian_jac
    flat_out = JHJ_via_hessian(flat_fn, flat_x, grad_2d)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/clemens/folx/folx/hessian.py", line 41, in JHJ_via_hessian
    flat_hessian = hessian(flat_fn)(flat_x)
                   ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/clemens/folx/folx/ad.py", line 97, in jacfun
    J = jax.vmap(jvp_fun, out_axes=-1)(eye)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/clemens/folx/folx/ad.py", line 94, in jvp_fun
    return jax.jvp(f, primals, unravel(s))[1]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/clemens/folx/folx/ad.py", line 77, in jacfun
    result = jax.vmap(vjp(flat_f, flat_primals))(eye)[0]
                      ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/clemens/folx/folx/ad.py", line 45, in vjp
    out, vjp = jax.vjp(fun, *primals)
               ^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/clemens/folx/folx/ad.py", line 70, in flat_f
    return jfu.ravel_pytree(f(*unravel(x)))[0]
                            ^^^^^^^^^^^^^^
  File "/Users/clemens/folx/folx/utils.py", line 99, in new_fn
    return jfu.ravel_pytree(fn(*x))[0]  # type: ignore
                            ^^^^^^
  File "/Users/clemens/folx/folx/hessian.py", line 334, in merged_fn
    return fwd(*merge(x, extra_args))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: Primitive mul requires varying manual axes to match, but got [frozenset({'i'}), frozenset()]. Please open an issue at https://github.com/jax-ml/jax/issues and as a temporary workaround pass the check_vma=False argument to `jax.shard_map`

A PR with a fix is on its way.

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions