From 5517bfa808aff1a4d98481c1ec3fe33db1ab55ef Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 18 Aug 2025 09:27:10 -0700 Subject: [PATCH] [drjax] Avoid call to deprecated `batching.moveaxis` This is deprecated in JAX v0.7.1; `jnp.moveaxis` is a drop-in replacement. PiperOrigin-RevId: 796456784 --- drjax/_src/primitives.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/drjax/_src/primitives.py b/drjax/_src/primitives.py index ddcb66e..fdb01d8 100644 --- a/drjax/_src/primitives.py +++ b/drjax/_src/primitives.py @@ -203,7 +203,7 @@ def _batch_agg(xs, batched_shape): # Certain jax libs can silently insert the 'batching' dim 'all the way at # the front'; we are about to destroy the front axis by agging, so move # that puppy to the back. Tell the rest of JAX what happened here. - xs = batching.moveaxis(*xs, *batched_shape, -1) + xs = jnp.moveaxis(*xs, *batched_shape, -1) return agg_prim_fn(xs), len(xs.shape) - 2 # Make sure this can also be batched / mapped. This happens when dispatching