Batched Cholesky (potrf) OOM crash due to hipSolver allocating outside XLA memory pool #718

@FlemingH

Description

Background

While running MuJoCo MJX robotics simulation workloads on AMD RDNA 3 GPUs, the Newton solver (which internally uses jax.vmap over Cholesky decomposition) consistently crashes with:

INTERNAL: solver_kernels_ffi.cc:452: operation hipGetLastError() failed: out of memory

PR: #717

Minimal Reproducer

import jax, jax.numpy as jnp

# Symmetric positive-definite matrix, broadcast to a batch of two.
A = jnp.eye(4) * 5 + jnp.ones((4, 4))
Ab = jnp.broadcast_to(A, (2, 4, 4)).copy()

# vmapped Cholesky on a batch > 1 routes to the batched hipSOLVER path.
r = jax.jit(lambda A: jax.vmap(jnp.linalg.cholesky)(A))(Ab)
r.block_until_ready()

Root Cause

When batch > 1, PotrfDispatch routes to PotrfBatchedImpl, which calls hipSOLVER's batched API. This API internally calls hipMalloc outside of XLA's BFC allocator. Since XLA preallocates ~75% of GPU VRAM by default, the external allocation fails — even for a tiny 2x4x4 matrix.

The non-batched path (PotrfImpl) correctly uses XLA's scratch allocator for all memory and works fine.
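Since the dispatch is on batch size, one possible user-side mitigation is to keep the decomposition on the non-batched path. This is a hedged sketch, not a verified fix: it assumes that `jax.lax.map`, which applies the function one element at a time instead of vectorizing it, lowers each Cholesky call through the unbatched `PotrfImpl` path on the ROCm plugin.

```python
import jax
import jax.numpy as jnp

A = jnp.eye(4) * 5 + jnp.ones((4, 4))
Ab = jnp.broadcast_to(A, (2, 4, 4)).copy()

# Sketch: lax.map applies cholesky per matrix rather than emitting one
# batched potrf call. Whether this avoids the batched hipSOLVER API on
# ROCm is an assumption and should be confirmed against the lowered HLO.
r_map = jax.jit(lambda a: jax.lax.map(jnp.linalg.cholesky, a))(Ab)
r_map.block_until_ready()

# Sanity check: L @ L.T reconstructs the input batch.
recon = jnp.einsum('bij,bkj->bik', r_map, r_map)
```

The trade-off is that `lax.map` serializes the batch, so this is only reasonable for small batch counts until the allocator issue is fixed in the plugin.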

Workaround

export XLA_PYTHON_CLIENT_PREALLOCATE=false
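The same setting can be applied from Python, as long as it happens before the first `import jax` (the XLA client reads these variables once at backend initialization). The `XLA_PYTHON_CLIENT_MEM_FRACTION` line is a hedged alternative: instead of disabling preallocation entirely, it shrinks XLA's pool to leave headroom for hipSOLVER's internal hipMalloc calls.

```python
import os

# Must run before the first `import jax`; read once at backend init.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# Alternative (untested against this bug): keep preallocation but cap it,
# leaving VRAM free for allocations made outside XLA's BFC pool.
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.6"

import jax  # backend initialization picks up the settings here
```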

System info (python version, jaxlib version, accelerator, etc.)

  • OS: Ubuntu 24.04.3 LTS, Kernel 6.14.0-37-generic x86_64
  • CPU: AMD Ryzen 7 5800X 8-Core
  • GPU: AMD Radeon RX 7900 XTX (gfx1100 / RDNA 3, 24GB VRAM)
  • ROCm: 7.2.0
  • Python: 3.12.12
  • JAX / JAXlib: 0.8.0
  • jax-rocm7-plugin: 0.8.0+rocm7.2.0
  • jax-rocm7-pjrt: 0.8.0+rocm7.2.0
  • rocSOLVER: 3.32.0.70200
  • hipSOLVER: 3.2.0.70200

Metadata

Labels

bug (Something isn't working), status: triage (Indicates an issue has been assigned for investigation.)
