forked from jax-ml/jax
Labels: bug, status: triage
Description
Background
While running MuJoCo MJX robotics simulation workloads on AMD RDNA 3 GPUs, the Newton solver (which internally uses jax.vmap over Cholesky decomposition) consistently crashes with:
```
INTERNAL: solver_kernels_ffi.cc:452: operation hipGetLastError() failed: out of memory
```
PR: #717
Minimal Reproducer
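```python
import jax
import jax.numpy as jnp

A = jnp.eye(4) * 5 + jnp.ones((4, 4))          # small symmetric positive definite matrix
Ab = jnp.broadcast_to(A, (2, 4, 4)).copy()     # batch of 2, so the batched solver path is taken
r = jax.jit(lambda A: jax.vmap(jnp.linalg.cholesky)(A))(Ab)
r.block_until_ready()                          # raises the out-of-memory error above on ROCm
```
Root Cause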
When the batch size is greater than 1, PotrfDispatch routes to PotrfBatchedImpl, which calls hipSOLVER's batched API. That API internally calls hipMalloc, bypassing XLA's BFC allocator. Since XLA preallocates ~75% of GPU VRAM by default, this external allocation fails, even for a tiny 2x4x4 batch.
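Why does vmap end up on the batched path at all? Inspecting the traced program shows it. A minimal sketch, reusing the reproducer's matrices: vmap does not loop over the batch; it hands a single cholesky primitive the whole (2, 4, 4) operand, which is what PotrfDispatch then sees as a batch larger than 1 (per the analysis above). Tracing works on any backend, so this does not exercise the ROCm dispatch itself.

```python
import jax
import jax.numpy as jnp

# Same 2x4x4 symmetric positive definite batch as the reproducer.
A = jnp.eye(4) * 5 + jnp.ones((4, 4))
Ab = jnp.broadcast_to(A, (2, 4, 4))

# Trace without executing: the jaxpr shows one cholesky call whose
# operand carries the leading batch dimension, not a Python-level loop.
jaxpr = jax.make_jaxpr(jax.vmap(jnp.linalg.cholesky))(Ab)
print(jaxpr)
```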
The non-batched path (PotrfImpl) correctly uses XLA's scratch allocator for all memory and works fine.
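If this analysis holds, a sequential map should stay on the working path. A minimal sketch, not verified on ROCm: `jax.lax.map` lowers to a scan, so each Cholesky call inside the loop body sees a single 4x4 matrix and should dispatch to PotrfImpl rather than PotrfBatchedImpl. The helper name `cholesky_per_matrix` is hypothetical, and the approach trades the batched solver's parallelism for keeping all allocations inside XLA.

```python
import jax
import jax.numpy as jnp

A = jnp.eye(4) * 5 + jnp.ones((4, 4))
Ab = jnp.broadcast_to(A, (2, 4, 4))

# Hypothetical helper: lax.map runs the body once per batch element
# (via scan), so each cholesky call receives an unbatched (4, 4) matrix.
@jax.jit
def cholesky_per_matrix(batch):
    return jax.lax.map(jnp.linalg.cholesky, batch)

L = cholesky_per_matrix(Ab).block_until_ready()
print(L.shape)
```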
Workaround
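The workaround is a single environment variable. The same setting can be applied from Python; a minimal sketch, assuming it runs before JAX initializes its GPU backend (in practice, before jax is first imported in the process):

```python
import os

# Read when JAX initializes its backend; setting it after the GPU
# client already exists in this process has no effect.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax  # imported after the variable is set, on purpose
import jax.numpy as jnp

# Same batched Cholesky as the reproducer.
A = jnp.eye(4) * 5 + jnp.ones((4, 4))
Ab = jnp.broadcast_to(A, (2, 4, 4))
r = jax.jit(jax.vmap(jnp.linalg.cholesky))(Ab).block_until_ready()
```

Disabling preallocation makes XLA allocate on demand, leaving VRAM free for hipSOLVER's own hipMalloc, at the cost of more allocation overhead and possible fragmentation. Lowering `XLA_PYTHON_CLIENT_MEM_FRACTION` below its 0.75 default is a less drastic option that should also leave headroom; it has not been tested for this issue.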
```shell
export XLA_PYTHON_CLIENT_PREALLOCATE=false
```
System info (python version, jaxlib version, accelerator, etc.)
- OS: Ubuntu 24.04.3 LTS, Kernel 6.14.0-37-generic x86_64
- CPU: AMD Ryzen 7 5800X 8-Core
- GPU: AMD Radeon RX 7900 XTX (gfx1100 / RDNA 3, 24GB VRAM)
- ROCm: 7.2.0
- Python: 3.12.12
- JAX / JAXlib: 0.8.0
- jax-rocm7-plugin: 0.8.0+rocm7.2.0
- jax-rocm7-pjrt: 0.8.0+rocm7.2.0
- rocSOLVER: 3.32.0.70200
- hipSOLVER: 3.2.0.70200