diff --git a/tests/ann_test.py b/tests/ann_test.py index 18bb51bec93b..c4bc6afe273f 100644 --- a/tests/ann_test.py +++ b/tests/ann_test.py @@ -119,6 +119,10 @@ def test_autodiff(self, shape, dtype, k, is_max_k): ) def test_pmap(self, qy_shape, db_shape, dtype, k, recall): num_devices = jax.device_count() + # TODO(araganes): Re-enable once upstream HloShardingV3 lands (JAX 0.9.2+). + # New pmap's SPMD tiling can't convert back to 1D pmap mesh on ROCm. + if jtu.is_device_rocm(): + self.skipTest("IndivisibleError - SPMD tiling incompatible with 1D pmap mesh on ROCm") rng = jtu.rand_default(self.rng()) qy = rng(qy_shape, dtype) db = rng(db_shape, dtype)