From 7c8b2e590b0737af1d5ea1d145f21590ff4cb56c Mon Sep 17 00:00:00 2001 From: Arati Ganesh Date: Wed, 4 Mar 2026 14:37:15 -0600 Subject: [PATCH] Skip test_pmap on ROCm due to IndivisibleError with new pmap SPMD tiling --- tests/ann_test.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/ann_test.py b/tests/ann_test.py index 18bb51bec93b..c4bc6afe273f 100644 --- a/tests/ann_test.py +++ b/tests/ann_test.py @@ -119,6 +119,10 @@ def test_autodiff(self, shape, dtype, k, is_max_k): ) def test_pmap(self, qy_shape, db_shape, dtype, k, recall): num_devices = jax.device_count() + # TODO(araganes): Re-enable once upstream HloShardingV3 lands (JAX 0.9.2+). + # New pmap's SPMD tiling can't convert back to 1D pmap mesh on ROCm. + if jtu.is_device_rocm(): + self.skipTest("IndivisibleError - SPMD tiling incompatible with 1D pmap mesh on ROCm") rng = jtu.rand_default(self.rng()) qy = rng(qy_shape, dtype) db = rng(db_shape, dtype)