
Skip test_pmap on ROCm due to IndivisibleError with new pmap SPMD tiling#725

Open
AratiGanesh wants to merge 1 commit into rocm-jaxlib-v0.9.0 from araganes/skip-pmap-sharding-error

AratiGanesh commented Mar 4, 2026

Motivation

The test_pmap tests in tests/ann_test.py fail on ROCm devices with IndivisibleError on JAX 0.8.0 and later. The failure comes from a sharding conversion mismatch between the tiling chosen by XLA's SPMD partitioner and JAX's 1D pmap mesh.

Technical Details

Since JAX 0.8.0, pmap is internally implemented as jit(shard_map(...)), which lets XLA's SPMD partitioner decide how to tile tensors across devices. For the approx_min_k operation in this test, XLA picks a 3D tiling (e.g. [1,2,4]) that cannot be converted back to JAX's 1D pmap mesh of 8 devices.
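To make the mismatch concrete, here is a minimal sketch of why a tiling like [1,2,4] does not fit a 1D mesh of 8 devices. It is a simplified model, not JAX's actual conversion logic: the helper `fits_1d_mesh` is hypothetical, and it assumes a single mesh axis can be assigned to at most one array dimension.

```python
from math import prod


def fits_1d_mesh(tile_dims, num_devices):
    """Simplified check (hypothetical helper, not a JAX API).

    A 1D mesh has one axis, so under this model it can split at most
    one array dimension, and that split must use every device.
    """
    split_dims = [d for d in tile_dims if d > 1]
    return prod(tile_dims) == num_devices and len(split_dims) <= 1


# The tiling from the description splits two dimensions (2 and 4),
# so it cannot be expressed with a single mesh axis of size 8.
print(fits_1d_mesh([1, 2, 4], 8))  # False

# A tiling that splits only the leading dimension fits the 1D mesh.
print(fits_1d_mesh([8, 1, 1], 8))  # True
```

Both tilings use all 8 devices; only the second keeps the split on a single dimension, which is the shape a pmap-style mesh expects.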

Added a conditional skip to test_pmap in tests/ann_test.py. The test is skipped only when:

  • Running on ROCm devices (jtu.is_device_rocm())
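The shape of such a conditional skip can be sketched as follows. This is not the diff from this PR: `is_device_rocm` below is a hard-coded, hypothetical stand-in for `jtu.is_device_rocm()` so the snippet runs without JAX or a ROCm device, and the test body is a placeholder.

```python
import unittest


def is_device_rocm() -> bool:
    """Hypothetical stand-in for jtu.is_device_rocm(); hard-coded for the sketch."""
    return True


class AnnTest(unittest.TestCase):
    def test_pmap(self):
        # Skip only on ROCm; other backends still run the test.
        if is_device_rocm():
            self.skipTest("IndivisibleError with new pmap SPMD tiling on ROCm")
        # Placeholder: the real test exercises approx_min_k under pmap.
        self.assertTrue(True)


def run_sketch() -> unittest.TestResult:
    """Run the single test and return the result for inspection."""
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(AnnTest)
    result = unittest.TestResult()
    suite.run(result)
    return result


if __name__ == "__main__":
    result = run_sketch()
    print(f"run={result.testsRun} skipped={len(result.skipped)}")
```

Calling `self.skipTest` inside the test body keeps the condition evaluated at run time, which matters when the device type is only known after the backend has initialized.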

Test Result

image

Upstream PR - jax-ml#35611
