From 134bf28f95e0e12b62911f11f0ec0d1083dfb65e Mon Sep 17 00:00:00 2001
From: apinge
Date: Tue, 3 Mar 2026 18:32:40 +0800
Subject: [PATCH 1/2] fix reduce and add sigmoid_mul_add_gluon

Signed-off-by: apinge
---
 tests/gluon/fused_sigmoid_mul_add_gluon.py | 120 +++++++++++++++++++++
 tests/gluon/reduce.py                      |   6 +-
 tests/gluon/test_fused_sigmoid_mul_add.py  |  94 ++++++++++++++++
 3 files changed, 218 insertions(+), 2 deletions(-)
 create mode 100644 tests/gluon/fused_sigmoid_mul_add_gluon.py
 create mode 100644 tests/gluon/test_fused_sigmoid_mul_add.py

diff --git a/tests/gluon/fused_sigmoid_mul_add_gluon.py b/tests/gluon/fused_sigmoid_mul_add_gluon.py
new file mode 100644
index 0000000..50c1ed2
--- /dev/null
+++ b/tests/gluon/fused_sigmoid_mul_add_gluon.py
@@ -0,0 +1,120 @@
+"""
+Ref from python/sglang/jit_kernel/fused_sigmoid_mul_add_gluon.py
+"""
+
+import torch
+import triton
+from triton.experimental import gluon
+from triton.experimental.gluon import language as gl
+from functools import cache
+# CDNA wave=64 fixed. Layout config: BLOCK_SIZE = SIZE_PER_THREAD * WARP_SIZE * WARPS_PER_CTA.
+WARP_SIZE = 64
+SIZE_PER_THREAD = 2
+WARPS_PER_CTA = 2
+BLOCK_SIZE = SIZE_PER_THREAD * WARP_SIZE * WARPS_PER_CTA # 4096
+
+# WARP_SIZE = 64
+# SIZE_PER_THREAD = 4
+# WARPS_PER_CTA = 1
+# BLOCK_SIZE = SIZE_PER_THREAD * WARP_SIZE * WARPS_PER_CTA # 4096
+
+_layout = gl.BlockedLayout(
+    size_per_thread=[SIZE_PER_THREAD],
+    threads_per_warp=[WARP_SIZE],
+    warps_per_cta=[WARPS_PER_CTA],
+    order=[0],
+)
+
+"""
+@gluon.jit
+def _fused_sigmoid_mul_add_gluon_kernel_cdna3(
+    gate_ptr,
+    shared_ptr,
+    out_ptr,
+    hidden_size,
+    shared_stride_row,
+    out_stride_row,
+    BLOCK_SIZE: gl.constexpr,
+):
+    row = gl.program_id(0)
+    col_block = gl.program_id(1)
+
+    col_offsets = col_block * BLOCK_SIZE + gl.arange(0, BLOCK_SIZE, layout=_col_layout)
+    mask = col_offsets < hidden_size
+
+    zeros_offsets = gl.zeros([BLOCK_SIZE], dtype=gl.int32, layout=_col_layout)
+    gate_val = gl.amd.cdna3.buffer_load(gate_ptr + row, zeros_offsets).to(gl.float32)
+    sig = 1.0 / (1.0 + gl.exp(-gate_val))
+
+    shared_offsets = row * shared_stride_row + col_offsets
+    out_offsets = row * out_stride_row + col_offsets
+
+    shared_val = gl.amd.cdna3.buffer_load(shared_ptr, shared_offsets, mask=mask).to(gl.float32)
+    out_val = gl.amd.cdna3.buffer_load(out_ptr, out_offsets, mask=mask).to(gl.float32)
+
+    result = (out_val + sig * shared_val).to(gl.bfloat16)
+    gl.amd.cdna3.buffer_store(result, out_ptr, out_offsets, mask=mask)
+"""
+
+@gluon.jit
+def _fused_sigmoid_mul_add_gluon_kernel(
+    gate_ptr,
+    shared_ptr,
+    out_ptr,
+    hidden_size,
+    shared_stride_row,
+    out_stride_row,
+    BLOCK_SIZE: gl.constexpr,
+):
+    row = gl.program_id(0)
+    col_block = gl.program_id(1)
+
+    col_offsets = col_block * BLOCK_SIZE + gl.arange(0, BLOCK_SIZE, layout=_layout)
+    mask = col_offsets < hidden_size
+
+    #gate_val = gl.load(gate_ptr + row).to(gl.float32)
+    #zeros_offsets = gl.zeros([BLOCK_SIZE], dtype=gl.int32, layout=layout) cdna3 version
+    temp = gl.arange(0, BLOCK_SIZE, layout=_layout)
+    zeros_offsets = (temp * 0).to(gl.int32)
+    gate_val = gl.amd.cdna4.buffer_load(gate_ptr + row, zeros_offsets).to(gl.float32)
+    sig = 1.0 / (1.0 + gl.exp(-gate_val))
+
+    shared_offsets = row * shared_stride_row + col_offsets
+    out_offsets = row * out_stride_row + col_offsets
+
+    shared_val = gl.amd.cdna4.buffer_load(shared_ptr, shared_offsets, mask=mask).to(gl.float32)
+    out_val = gl.amd.cdna4.buffer_load(out_ptr, out_offsets, mask=mask).to(gl.float32)
+
+    result = (out_val + sig * shared_val).to(gl.bfloat16)
+    #gl.store(out_ptr + out_offsets, result, mask=mask)
+    gl.amd.cdna4.buffer_store(result, out_ptr, out_offsets, mask=mask)
+
+def fused_sigmoid_mul_add_gluon(
+    gate: torch.Tensor,
+    shared_output: torch.Tensor,
+    final_hidden_states: torch.Tensor,
+) -> None:
+    """Fused sigmoid-mul-add: final_hidden_states += sigmoid(gate) * shared_output.
+
+    Args:
+        gate: [num_tokens, 1] (or flattenable to 1D).
+        shared_output: [num_tokens, hidden_size].
+        final_hidden_states: [num_tokens, hidden_size], modified in-place.
+    """
+    # num_tokens, hidden_size = shared_output.shape
+    num_tokens,hidden_size = 4, 1024
+    gate_flat = gate.view(-1)
+
+    num_col_blocks = triton.cdiv(hidden_size, BLOCK_SIZE)
+    grid = (num_tokens, num_col_blocks)
+
+    _fused_sigmoid_mul_add_gluon_kernel[grid](
+        gate_flat,
+        shared_output,
+        final_hidden_states,
+        hidden_size,
+        shared_output.stride(0),
+        final_hidden_states.stride(0),
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=WARPS_PER_CTA,
+    )
diff --git a/tests/gluon/reduce.py b/tests/gluon/reduce.py
index ff33aaf..42d4b4e 100644
--- a/tests/gluon/reduce.py
+++ b/tests/gluon/reduce.py
@@ -142,9 +142,11 @@ def test():
             idx = torch.where(torch.abs(ref[i] - output[i]) > 0.05)
             print(f'{idx=}, {ref[i][idx]=}, {output[i][idx]=}')
             assert 0
-
+
+    # bf16: 2 bytes; read input [num_tokens_total, TOPK, OC], write output [num_tokens_total, OC]
+    rw_bytes = num_tokens_total * TOPK * OC * 2 + num_tokens_total * OC * 2
     for _ in range(10):
-        with pyhip.cuPerf(name="moe_gemm_final_reduce_bf16"):
+        with pyhip.cudaPerf(0, rw_bytes, name="moe_gemm_final_reduce_bf16"):
             moe_gemm_final_reduce_bf16[(num_WG,)](
                 input,
                 output,
diff --git a/tests/gluon/test_fused_sigmoid_mul_add.py b/tests/gluon/test_fused_sigmoid_mul_add.py
new file mode 100644
index 0000000..908c427
--- /dev/null
+++ b/tests/gluon/test_fused_sigmoid_mul_add.py
@@ -0,0 +1,94 @@
+"""
+Kernel: final_hidden_states += sigmoid(gate) * shared_output
+Shapes: gate [num_tokens, 1], shared_output / final_hidden_states [num_tokens, hidden_size]
+bench perf
+python -m pytest test_fused_sigmoid_mul_add.py -v -k benchmark -s
+"""
+
+import os
+import sys
+
+import pytest
+import torch
+
+_tests_dir = os.path.dirname(os.path.abspath(__file__))
+if _tests_dir not in sys.path:
+    sys.path.insert(0, _tests_dir)
+
+from fused_sigmoid_mul_add_gluon import fused_sigmoid_mul_add_gluon
+
+# Only these shapes: accuracy + benchmark
+CASES = [
+    pytest.param(1, 4096, torch.bfloat16, id="1x4096_bf16"),
+    pytest.param(2, 4096, torch.bfloat16, id="2x4096_bf16"),
+    pytest.param(8000, 4096, torch.bfloat16, id="8000x4096_bf16"),
+]
+
+
+def _ref(gate: torch.Tensor, shared: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
+    return out + torch.sigmoid(gate) * shared
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
+@pytest.mark.parametrize("num_tokens, hidden_size, dtype", CASES)
+def test_accuracy(num_tokens, hidden_size, dtype):
+    """Kernel output vs PyTorch reference."""
+    device = torch.device("cuda", 0)
+    gate = torch.randn(num_tokens, 1, device=device, dtype=dtype)
+    shared = torch.randn(num_tokens, hidden_size, device=device, dtype=dtype)
+    out = torch.randn(num_tokens, hidden_size, device=device, dtype=dtype)
+
+    ref = _ref(gate, shared, out)
+    out_kernel = out.clone()
+    fused_sigmoid_mul_add_gluon(gate, shared, out_kernel)
+
+    torch.testing.assert_close(out_kernel, ref, rtol=2e-2, atol=2e-2)
+
+
+def _bench(run_fn, num_tokens: int, hidden_size: int, dtype: torch.dtype, warmup: int = 10, repeat: int = 100):
+    for _ in range(warmup):
+        run_fn()
+    torch.cuda.synchronize()
+    start = torch.cuda.Event(enable_timing=True)
+    end = torch.cuda.Event(enable_timing=True)
+    start.record()
+    for _ in range(repeat):
+        run_fn()
+    end.record()
+    torch.cuda.synchronize()
+    mean_ms = start.elapsed_time(end) / repeat
+    es = torch.empty(0, dtype=dtype).element_size()
+    bytes_run = (num_tokens + num_tokens * hidden_size + 2 * num_tokens * hidden_size) * es
+    gbps = (bytes_run / 1e9) / (mean_ms / 1000) if mean_ms > 0 else 0
+    return mean_ms, gbps
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
+@pytest.mark.parametrize("num_tokens, hidden_size, dtype", CASES)
+def test_benchmark(num_tokens, hidden_size, dtype):
+    """Fused kernel vs PyTorch ref: time and GB/s."""
+    device = torch.device("cuda", 0)
+    gate = torch.randn(num_tokens, 1, device=device, dtype=dtype)
+    shared = torch.randn(num_tokens, hidden_size, device=device, dtype=dtype)
+    z = torch.randn(num_tokens, hidden_size, device=device, dtype=dtype)
+
+    def run_kernel():
+        out = z.clone()
+        fused_sigmoid_mul_add_gluon(gate, shared, out)
+
+    def run_ref():
+        _ref(gate, shared, z)
+
+    fused_ms, fused_gbps = _bench(run_kernel, num_tokens, hidden_size, dtype)
+    ref_ms, ref_gbps = _bench(run_ref, num_tokens, hidden_size, dtype)
+    speedup = ref_ms / fused_ms if fused_ms > 0 else 0.0
+
+    print(
+        f"\n[benchmark] {num_tokens}x{hidden_size} {dtype}\n"
+        f"  fused: {fused_ms*1000:.2f} us/iter, {fused_gbps:.2f} GB/s\n"
+        f"  ref:   {ref_ms*1000:.2f} us/iter, {ref_gbps:.2f} GB/s  speedup {speedup:.2f}x"
+    )
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v", "-s"])

From 5bd14212ca3535312ade967ae367cfa69297eb73 Mon Sep 17 00:00:00 2001
From: apinge
Date: Tue, 3 Mar 2026 19:08:52 +0800
Subject: [PATCH 2/2] correct grid

Signed-off-by: apinge
---
 tests/gluon/fused_sigmoid_mul_add_gluon.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/gluon/fused_sigmoid_mul_add_gluon.py b/tests/gluon/fused_sigmoid_mul_add_gluon.py
index 50c1ed2..f6b1016 100644
--- a/tests/gluon/fused_sigmoid_mul_add_gluon.py
+++ b/tests/gluon/fused_sigmoid_mul_add_gluon.py
@@ -101,8 +101,8 @@ def fused_sigmoid_mul_add_gluon(
         shared_output: [num_tokens, hidden_size].
         final_hidden_states: [num_tokens, hidden_size], modified in-place.
     """
-    # num_tokens, hidden_size = shared_output.shape
-    num_tokens,hidden_size = 4, 1024
+    num_tokens, hidden_size = shared_output.shape
+    #num_tokens,hidden_size = 4, 1024
     gate_flat = gate.view(-1)
 
     num_col_blocks = triton.cdiv(hidden_size, BLOCK_SIZE)