diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 6a8c99e410ec..9457717d29f3 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -2330,6 +2330,10 @@ def test_tridiagonal_solve_endpoints(self): @jtu.sample_product(shape=[(3,), (3, 4)], dtype=float_types + complex_types) def test_tridiagonal_solve_grad(self, shape, dtype): + if jtu.is_device_rocm() and shape == (3, 4) and dtype == np.float32: + # numerical errors seen as of ROCm 7.2 due to rocSparse issue for grad0 variant + # TODO: re-enable the test once the rocSparse issue is fixed + self.skipTest("test_tridiagonal_solve_grad0 not supported on ROCm due to rocSparse issue") if dtype not in float_types and jtu.test_device_matches(["gpu"]): self.skipTest("Data type not supported on GPU") rng = self.rng()