diff --git a/python/cuda_cccl/tests/compute/test_merge_sort.py b/python/cuda_cccl/tests/compute/test_merge_sort.py index 5be3a314063..a54c92ca979 100644 --- a/python/cuda_cccl/tests/compute/test_merge_sort.py +++ b/python/cuda_cccl/tests/compute/test_merge_sort.py @@ -88,7 +88,12 @@ def test_merge_sort_keys(dtype, num_items, op): @pytest.mark.parametrize("dtype,num_items,op", merge_sort_params) -def test_merge_sort_pairs(dtype, num_items, op): +def test_merge_sort_pairs(dtype, num_items, op, monkeypatch): + if dtype == np.float16: + import cuda.compute._cccl_interop + + monkeypatch.setattr(cuda.compute._cccl_interop, "_check_sass", False) + h_in_keys = random_array(num_items, dtype) h_in_items = random_array(num_items, np.float32) @@ -125,7 +130,12 @@ def test_merge_sort_keys_copy(dtype, num_items, op): @pytest.mark.parametrize("dtype,num_items,op", merge_sort_params) -def test_merge_sort_pairs_copy(dtype, num_items, op): +def test_merge_sort_pairs_copy(dtype, num_items, op, monkeypatch): + if dtype == np.float16: + import cuda.compute._cccl_interop + + monkeypatch.setattr(cuda.compute._cccl_interop, "_check_sass", False) + h_in_keys = random_array(num_items, dtype) h_in_items = random_array(num_items, np.float32) h_out_keys = np.empty(num_items, dtype=dtype) @@ -239,7 +249,12 @@ def test_merge_sort_keys_copy_iterator_input(dtype, num_items, op): @pytest.mark.parametrize("dtype,num_items,op", merge_sort_params) -def test_merge_sort_pairs_copy_iterator_input(dtype, num_items, op): +def test_merge_sort_pairs_copy_iterator_input(dtype, num_items, op, monkeypatch): + if dtype == np.float16: + import cuda.compute._cccl_interop + + monkeypatch.setattr(cuda.compute._cccl_interop, "_check_sass", False) + h_in_keys = random_array(num_items, dtype) h_in_items = random_array(num_items, np.float32) h_out_keys = np.empty(num_items, dtype=dtype)