From 6fa8eb4a7a654ffd0779199faa2346d4828ec6d3 Mon Sep 17 00:00:00 2001
From: binyang <1575938147@qq.com>
Date: Wed, 24 Dec 2025 21:31:27 +0800
Subject: [PATCH 1/5] support trtllm_batch_context_with_kv_cache and
 trtllm_fp8_block_scale_moe

---
 tests/attention/test_attention_sink_blackwell.py |  5 +++--
 tests/moe/test_trtllm_gen_fused_moe.py           |  1 -
 tests/test_helpers/sink_attention_reference.py   | 14 ++++++++------
 3 files changed, 11 insertions(+), 9 deletions(-)

diff --git a/tests/attention/test_attention_sink_blackwell.py b/tests/attention/test_attention_sink_blackwell.py
index 104ca9cf8e..b9ad71d9a0 100644
--- a/tests/attention/test_attention_sink_blackwell.py
+++ b/tests/attention/test_attention_sink_blackwell.py
@@ -250,6 +250,7 @@ def test_blackwell_trtllm_gen_context_attention_sink(
         atol, rtol = 1e-2, 1e-2
     else:
         raise ValueError(f"Unsupported dtype: {dtype}")
-    ref_o = o_ref.float().numpy()
-    np.testing.assert_allclose(ref_o, paddle_o, atol=atol, rtol=rtol)
+    output_o = output.float().numpy()
+    np.testing.assert_allclose(ref_o, output_o, atol=atol, rtol=rtol)
+
diff --git a/tests/moe/test_trtllm_gen_fused_moe.py b/tests/moe/test_trtllm_gen_fused_moe.py
index 190cddc92d..aa03f06ab8 100644
--- a/tests/moe/test_trtllm_gen_fused_moe.py
+++ b/tests/moe/test_trtllm_gen_fused_moe.py
@@ -2411,7 +2411,6 @@ def test_moe_quantization_classes(
     """
     device = paddle.get_device()
     compute_capability = cur_get_compute_capability(device)
-    print("Compute Capability: ", compute_capability)
     if compute_capability[0] in [11, 12]:
         pytest.skip("trtllm-gen does not support SM110/SM120/SM121 GPUs.")
     # Skip incompatible combinations
diff --git a/tests/test_helpers/sink_attention_reference.py b/tests/test_helpers/sink_attention_reference.py
index 1f8d1a13c4..3d5c56b7f4 100644
--- a/tests/test_helpers/sink_attention_reference.py
+++ b/tests/test_helpers/sink_attention_reference.py
@@ -302,9 +302,10 @@ def sink_attention_unified(
     if causal:
         if mode == "incremental":
             # For incremental: new token can attend to all previous tokens
-            mask = torch.ones(1, kv_len, device=q.device, dtype=torch.bool)
+            mask = torch.ones(1, kv_len, device=q.place.device, dtype=torch.bool)
             if window_left >= 0:
-                col_idx = torch.arange(kv_len, dtype=torch.int32, device=q.device)
+                col_idx = torch.arange(kv_len, dtype=torch.int32,
+                                       device=q.place.device)
                 mask = (kv_len - 1 - window_left) <= col_idx
         elif mode == "prefill":
             # For regular prefill: standard causal mask
             #     1
             # ) >= torch.arange(0, kv_len, device=q.device).unsqueeze(0)
             mask = torch.arange(kv_len - qo_len, kv_len).unsqueeze(
-                1
-            ) >= torch.arange(0, kv_len).unsqueeze(0)
+                1) >= torch.arange(0, kv_len).unsqueeze(0)
             if window_left >= 0:
-                row_idx = torch.arange(qo_len, dtype=torch.int32, device=q.device)[
+                row_idx = torch.arange(qo_len, dtype=torch.int32,
+                                       device=q.place.device)[
                     :, None
                 ]
-                col_idx = torch.arange(kv_len, dtype=torch.int32, device=q.device)[
+                col_idx = torch.arange(kv_len, dtype=torch.int32,
+                                       device=q.place.device)[
                     None, :
                 ]
                 mask &= row_idx - window_left <= col_idx
From adf96061c01b431838280eef974ce9580b3f0302 Mon Sep 17 00:00:00 2001
From: binyang <1575938147@qq.com>
Date: Fri, 26 Dec 2025 14:31:19 +0800
Subject: [PATCH 2/5] adaptive attention and moe test

---
 tests/attention/test_attention_sink_blackwell.py |  1 -
 tests/moe/test_trtllm_gen_fused_moe.py           |  1 +
 tests/test_helpers/sink_attention_reference.py   | 14 ++++++--------
 3 files changed, 7 insertions(+), 9 deletions(-)

diff --git a/tests/attention/test_attention_sink_blackwell.py b/tests/attention/test_attention_sink_blackwell.py
index b9ad71d9a0..09d666653c 100644
--- a/tests/attention/test_attention_sink_blackwell.py
+++ b/tests/attention/test_attention_sink_blackwell.py
@@ -253,4 +253,3 @@ def test_blackwell_trtllm_gen_context_attention_sink(
     ref_o = o_ref.float().numpy()
     output_o = output.float().numpy()
     np.testing.assert_allclose(ref_o, output_o, atol=atol, rtol=rtol)
-
diff --git a/tests/moe/test_trtllm_gen_fused_moe.py b/tests/moe/test_trtllm_gen_fused_moe.py
index aa03f06ab8..190cddc92d 100644
--- a/tests/moe/test_trtllm_gen_fused_moe.py
+++ b/tests/moe/test_trtllm_gen_fused_moe.py
@@ -2411,6 +2411,7 @@ def test_moe_quantization_classes(
     """
     device = paddle.get_device()
     compute_capability = cur_get_compute_capability(device)
+    print("Compute Capability: ", compute_capability)
     if compute_capability[0] in [11, 12]:
         pytest.skip("trtllm-gen does not support SM110/SM120/SM121 GPUs.")
     # Skip incompatible combinations
diff --git a/tests/test_helpers/sink_attention_reference.py b/tests/test_helpers/sink_attention_reference.py
index 3d5c56b7f4..61ef36a20e 100644
--- a/tests/test_helpers/sink_attention_reference.py
+++ b/tests/test_helpers/sink_attention_reference.py
@@ -345,12 +345,10 @@ def sink_attention_unified(
             mask &= abs_row_positions - window_left <= col_idx
     else:
         # Non-causal mask
-        q_device = q.place
-        print("+++",q_device)
         if mode == "incremental":
-            mask = torch.ones(1, kv_len, device=q_device, dtype=torch.bool)
+            mask = torch.ones(1, kv_len, dtype=torch.bool)
             if window_left >= 0:
-                col_idx = torch.arange(kv_len, dtype=torch.int32, device=q.device)
+                col_idx = torch.arange(kv_len, dtype=torch.int32, device=q.place)
                 mask = (kv_len - 1 - window_left) <= col_idx
         else:  # prefill or chunk
             mask = torch.ones(qo_len, kv_len, dtype=torch.bool)
@@ -358,19 +356,19 @@ def sink_attention_unified(
             if mode == "chunk":
                 # For chunk mode, apply window relative to absolute positions
                 current_chunk_start = kv_len - qo_len
-                row_idx = torch.arange(qo_len, dtype=torch.int32, device=q.device)[
+                row_idx = torch.arange(qo_len, dtype=torch.int32, device=q.place)[
                     :, None
                 ]
-                col_idx = torch.arange(kv_len, dtype=torch.int32, device=q.device)[
+                col_idx = torch.arange(kv_len, dtype=torch.int32, device=q.place)[
                     None, :
                 ]
                 abs_row_positions = current_chunk_start + row_idx
                 mask = abs_row_positions - window_left <= col_idx
             else:  # prefill
-                row_idx = torch.arange(qo_len, dtype=torch.int32, device=q.device)[
+                row_idx = torch.arange(qo_len, dtype=torch.int32, device=q.place)[
                     :, None
                 ]
-                col_idx = torch.arange(kv_len, dtype=torch.int32, device=q.device)[
+                col_idx = torch.arange(kv_len, dtype=torch.int32, device=q.place)[
                     None, :
                 ]
                 mask = row_idx - window_left <= col_idx
From 9a8f8c494c1d2df85946d500b5e9918bee25a905 Mon Sep 17 00:00:00 2001
From: binyang <1575938147@qq.com>
Date: Fri, 26 Dec 2025 15:11:11 +0800
Subject: [PATCH 3/5] support test

---
 .../test_helpers/sink_attention_reference.py | 32 ++++++++++++-------
 1 file changed, 20 insertions(+), 12 deletions(-)

diff --git a/tests/test_helpers/sink_attention_reference.py b/tests/test_helpers/sink_attention_reference.py
index 61ef36a20e..aaa2367228 100644
--- a/tests/test_helpers/sink_attention_reference.py
+++ b/tests/test_helpers/sink_attention_reference.py
@@ -302,10 +302,9 @@ def sink_attention_unified(
     if causal:
         if mode == "incremental":
             # For incremental: new token can attend to all previous tokens
-            mask = torch.ones(1, kv_len, device=q.place.device, dtype=torch.bool)
+            mask = torch.ones(1, kv_len, device=q.device, dtype=torch.bool)
             if window_left >= 0:
-                col_idx = torch.arange(kv_len, dtype=torch.int32,
-                                       device=q.place.device)
+                col_idx = torch.arange(kv_len, dtype=torch.int32, device=q.device)
                 mask = (kv_len - 1 - window_left) <= col_idx
         elif mode == "prefill":
             # For regular prefill: standard causal mask
             #     1
             # ) >= torch.arange(0, kv_len, device=q.device).unsqueeze(0)
             mask = torch.arange(kv_len - qo_len, kv_len).unsqueeze(
+<<<<<<< HEAD
                 1) >= torch.arange(0, kv_len).unsqueeze(0)
+=======
+                1
+            ) >= torch.arange(0, kv_len).unsqueeze(0)
+>>>>>>> 10a7f09 (support test)
             if window_left >= 0:
-                row_idx = torch.arange(qo_len, dtype=torch.int32,
-                                       device=q.place.device)[
+                row_idx = torch.arange(qo_len, dtype=torch.int32, device=q.device)[
                     :, None
                 ]
-                col_idx = torch.arange(kv_len, dtype=torch.int32,
-                                       device=q.place.device)[
+                col_idx = torch.arange(kv_len, dtype=torch.int32, device=q.device)[
                     None, :
                 ]
                 mask &= row_idx - window_left <= col_idx
@@ -345,10 +347,16 @@ def sink_attention_unified(
             mask &= abs_row_positions - window_left <= col_idx
     else:
         # Non-causal mask
+        q_device = q.place
+        print("+++",q_device)
         if mode == "incremental":
+<<<<<<< HEAD
             mask = torch.ones(1, kv_len, dtype=torch.bool)
+=======
+            mask = torch.ones(1, kv_len, device=q_device, dtype=torch.bool)
+>>>>>>> 10a7f09 (support test)
             if window_left >= 0:
-                col_idx = torch.arange(kv_len, dtype=torch.int32, device=q.place)
+                col_idx = torch.arange(kv_len, dtype=torch.int32, device=q.device)
                 mask = (kv_len - 1 - window_left) <= col_idx
         else:  # prefill or chunk
             mask = torch.ones(qo_len, kv_len, dtype=torch.bool)
@@ -356,19 +364,19 @@ def sink_attention_unified(
            if mode == "chunk":
                 # For chunk mode, apply window relative to absolute positions
                 current_chunk_start = kv_len - qo_len
-                row_idx = torch.arange(qo_len, dtype=torch.int32, device=q.place)[
+                row_idx = torch.arange(qo_len, dtype=torch.int32, device=q.device)[
                     :, None
                 ]
-                col_idx = torch.arange(kv_len, dtype=torch.int32, device=q.place)[
+                col_idx = torch.arange(kv_len, dtype=torch.int32, device=q.device)[
                     None, :
                 ]
                 abs_row_positions = current_chunk_start + row_idx
                 mask = abs_row_positions - window_left <= col_idx
             else:  # prefill
-                row_idx = torch.arange(qo_len, dtype=torch.int32, device=q.place)[
+                row_idx = torch.arange(qo_len, dtype=torch.int32, device=q.device)[
                     :, None
                 ]
-                col_idx = torch.arange(kv_len, dtype=torch.int32, device=q.place)[
+                col_idx = torch.arange(kv_len, dtype=torch.int32, device=q.device)[
                     None, :
                 ]
                 mask = row_idx - window_left <= col_idx

From f8d129f7f7c662e9756fa8a2459dc570bccbd2cb Mon Sep 17 00:00:00 2001
From: binyang <1575938147@qq.com>
Date: Tue, 6 Jan 2026 18:49:01 +0800
Subject: [PATCH 4/5] fix conflict issues

---
 tests/test_helpers/sink_attention_reference.py | 11 -----------
 1 file changed, 11 deletions(-)

diff --git a/tests/test_helpers/sink_attention_reference.py b/tests/test_helpers/sink_attention_reference.py
index aaa2367228..5527b45373 100644
--- a/tests/test_helpers/sink_attention_reference.py
+++ b/tests/test_helpers/sink_attention_reference.py
@@ -312,12 +312,7 @@ def sink_attention_unified(
             #     1
             # ) >= torch.arange(0, kv_len, device=q.device).unsqueeze(0)
             mask = torch.arange(kv_len - qo_len, kv_len).unsqueeze(
-<<<<<<< HEAD
                 1) >= torch.arange(0, kv_len).unsqueeze(0)
-=======
-                1
-            ) >= torch.arange(0, kv_len).unsqueeze(0)
->>>>>>> 10a7f09 (support test)
             if window_left >= 0:
                 row_idx = torch.arange(qo_len, dtype=torch.int32, device=q.device)[
                     :, None
@@ -347,14 +342,8 @@ def sink_attention_unified(
             mask &= abs_row_positions - window_left <= col_idx
     else:
         # Non-causal mask
-        q_device = q.place
-        print("+++",q_device)
         if mode == "incremental":
-<<<<<<< HEAD
             mask = torch.ones(1, kv_len, dtype=torch.bool)
-=======
-            mask = torch.ones(1, kv_len, device=q_device, dtype=torch.bool)
->>>>>>> 10a7f09 (support test)
             if window_left >= 0:
                 col_idx = torch.arange(kv_len, dtype=torch.int32, device=q.device)
                 mask = (kv_len - 1 - window_left) <= col_idx
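
Note on the device-attribute churn in PATCH 1 through PATCH 4: torch tensors expose .device, while native Paddle tensors expose .place, and the intermediate spellings q.place.device and device=q.place are both reverted to plain q.device above. If the reference helper ever needs to accept both tensor types, a small adapter is one option. A minimal sketch, assuming nothing beyond those two attribute names; tensor_device is a hypothetical helper that appears nowhere in this series:

    def tensor_device(t):
        # torch.Tensor carries .device; a native paddle.Tensor carries .place.
        return t.device if hasattr(t, "device") else t.place

With it, the incremental-mode mask above would read:
mask = torch.ones(1, kv_len, device=tensor_device(q), dtype=torch.bool)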
From 9fb038ef83fa92b8c9295fcf153afd68469c9b14 Mon Sep 17 00:00:00 2001
From: binyang <1575938147@qq.com>
Date: Tue, 6 Jan 2026 19:28:25 +0800
Subject: [PATCH 5/5] fix pytest issues

---
 tests/conftest.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/conftest.py b/tests/conftest.py
index 1db8d85973..c1e050d0ea 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -145,7 +145,7 @@ def pytest_runtest_call(item):
     try:
         item.runtest()
     except:
-        # assert(False)
+        assert(False)
         # try:
        #     item.runtest()
         # except (torch.cuda.OutOfMemoryError, RuntimeError) as e:
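
On PATCH 5: a bare except: whose body is only comments is a syntax error, so re-enabling assert(False) is what makes conftest.py importable again; any exception raised by item.runtest() then surfaces as an AssertionError with the original error chained as its context. The commented-out lines hint at the likely end goal, an OOM-tolerant hook. A minimal sketch of that shape, where the skip condition and message are assumptions rather than code from this series:

    import pytest
    import torch

    def pytest_runtest_call(item):
        try:
            item.runtest()
        except (torch.cuda.OutOfMemoryError, RuntimeError) as e:
            # Assumed policy: skip on out-of-memory, re-raise everything else
            # so pytest reports the original exception type and traceback.
            if "out of memory" in str(e).lower():
                pytest.skip(f"skipped due to OOM: {e}")
            raise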