diff --git a/drjax/_src/primitives.py b/drjax/_src/primitives.py index ddcb66e..fdb01d8 100644 --- a/drjax/_src/primitives.py +++ b/drjax/_src/primitives.py @@ -203,7 +203,7 @@ def _batch_agg(xs, batched_shape): # Certain jax libs can silently insert the 'batching' dim 'all the way at # the front'; we are about to destroy the front axis by agging, so move # that puppy to the back. Tell the rest of JAX what happened here. - xs = batching.moveaxis(*xs, *batched_shape, -1) + xs = jnp.moveaxis(*xs, *batched_shape, -1) return agg_prim_fn(xs), len(xs.shape) - 2 # Make sure this can also be batched / mapped. This happens when dispatching