From 491a4f57974066c40d986d4fd8876af17c2b4622 Mon Sep 17 00:00:00 2001 From: Nader Al Awar Date: Mon, 16 Mar 2026 18:22:04 -0500 Subject: [PATCH] Disable sass checks for float16 merge sort --- .../tests/compute/test_merge_sort.py | 21 ++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/python/cuda_cccl/tests/compute/test_merge_sort.py b/python/cuda_cccl/tests/compute/test_merge_sort.py index 5be3a314063..a54c92ca979 100644 --- a/python/cuda_cccl/tests/compute/test_merge_sort.py +++ b/python/cuda_cccl/tests/compute/test_merge_sort.py @@ -88,7 +88,12 @@ def test_merge_sort_keys(dtype, num_items, op): @pytest.mark.parametrize("dtype,num_items,op", merge_sort_params) -def test_merge_sort_pairs(dtype, num_items, op): +def test_merge_sort_pairs(dtype, num_items, op, monkeypatch): + if dtype == np.float16: + import cuda.compute._cccl_interop + + monkeypatch.setattr(cuda.compute._cccl_interop, "_check_sass", False) + h_in_keys = random_array(num_items, dtype) h_in_items = random_array(num_items, np.float32) @@ -125,7 +130,12 @@ def test_merge_sort_keys_copy(dtype, num_items, op): @pytest.mark.parametrize("dtype,num_items,op", merge_sort_params) -def test_merge_sort_pairs_copy(dtype, num_items, op): +def test_merge_sort_pairs_copy(dtype, num_items, op, monkeypatch): + if dtype == np.float16: + import cuda.compute._cccl_interop + + monkeypatch.setattr(cuda.compute._cccl_interop, "_check_sass", False) + h_in_keys = random_array(num_items, dtype) h_in_items = random_array(num_items, np.float32) h_out_keys = np.empty(num_items, dtype=dtype) @@ -239,7 +249,12 @@ def test_merge_sort_keys_copy_iterator_input(dtype, num_items, op): @pytest.mark.parametrize("dtype,num_items,op", merge_sort_params) -def test_merge_sort_pairs_copy_iterator_input(dtype, num_items, op): +def test_merge_sort_pairs_copy_iterator_input(dtype, num_items, op, monkeypatch): + if dtype == np.float16: + import cuda.compute._cccl_interop + + monkeypatch.setattr(cuda.compute._cccl_interop, "_check_sass", False) + h_in_keys = random_array(num_items, dtype) h_in_items = random_array(num_items, np.float32) h_out_keys = np.empty(num_items, dtype=dtype)