diff --git a/tests/sparse_test.py b/tests/sparse_test.py index 1086429a9c43..02d9cfdd2128 100644 --- a/tests/sparse_test.py +++ b/tests/sparse_test.py @@ -138,8 +138,6 @@ def test_csr_fromdense_ad(self, shape, dtype): ) @jax.default_matmul_precision("float32") def test_csr_matmul_ad(self, shape, dtype, bshape): - if jtu.is_device_rocm(): - self.skipTest("test_csr_matmul_ad not supported on ROCm due to hipSPARSE issue") csr_matmul = sparse_csr._csr_matvec if len(bshape) == 1 else sparse_csr._csr_matmat tol = {np.float32: 2E-5, np.float64: 1E-12, np.complex64: 1E-5, np.complex128: 1E-12} @@ -218,9 +216,6 @@ def test_csr_fromdense(self, shape, dtype): transpose=[True, False], ) def test_csr_matvec(self, shape, dtype, transpose): - if jtu.is_device_rocm(): - self.skipTest("test_csr_matvec not supported on ROCm due to hipSPARSE issue") - op = lambda M: M.T if transpose else M v_rng = jtu.rand_default(self.rng()) @@ -589,8 +584,6 @@ def test_coo_spmm(self, shape, dtype, transpose): ) @jtu.run_on_devices("gpu") def test_csr_spmv(self, shape, dtype, transpose): - if jtu.is_device_rocm(): - self.skipTest("test_csr_spmv not supported on ROCm due to hipSPARSE issue") rng_sparse = sptu.rand_sparse(self.rng()) rng_dense = jtu.rand_default(self.rng()) @@ -1040,8 +1033,6 @@ def test_transpose(self, shape, dtype, Obj): for Obj in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO])) @jax.default_matmul_precision("float32") def test_matmul(self, shape, dtype, Obj, bshape): - if jtu.is_device_rocm(): - self.skipTest("test_matmul not supported on ROCm due to hipSPARSE issue") rng = sptu.rand_sparse(self.rng(), post=jnp.array) rng_b = jtu.rand_default(self.rng()) M = rng(shape, dtype)