Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 120 additions & 0 deletions tests/gluon/fused_sigmoid_mul_add_gluon.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
"""
Ref from python/sglang/jit_kernel/fused_sigmoid_mul_add_gluon.py
"""

import torch
import triton
from triton.experimental import gluon
from triton.experimental.gluon import language as gl
from functools import cache
# CDNA wave=64 fixed. Layout config: BLOCK_SIZE = SIZE_PER_THREAD * WARP_SIZE * WARPS_PER_CTA.
WARP_SIZE = 64  # AMD CDNA wavefront width (lanes per wave)
SIZE_PER_THREAD = 2  # contiguous elements owned by each lane
WARPS_PER_CTA = 2  # waves per workgroup; must match num_warps at kernel launch
BLOCK_SIZE = SIZE_PER_THREAD * WARP_SIZE * WARPS_PER_CTA  # 2 * 64 * 2 = 256 elements per program

# Alternative config (same 256-element block, one wave of 4 elements/lane):
# WARP_SIZE = 64
# SIZE_PER_THREAD = 4
# WARPS_PER_CTA = 1
# BLOCK_SIZE = SIZE_PER_THREAD * WARP_SIZE * WARPS_PER_CTA  # 4 * 64 * 1 = 256

# 1-D blocked layout describing how a BLOCK_SIZE vector is distributed
# across lanes and waves; shared by every distributed tensor in the kernel.
_layout = gl.BlockedLayout(
    size_per_thread=[SIZE_PER_THREAD],
    threads_per_warp=[WARP_SIZE],
    warps_per_cta=[WARPS_PER_CTA],
    order=[0],
)

"""
@gluon.jit
def _fused_sigmoid_mul_add_gluon_kernel_cdna3(
gate_ptr,
shared_ptr,
out_ptr,
hidden_size,
shared_stride_row,
out_stride_row,
BLOCK_SIZE: gl.constexpr,
):
row = gl.program_id(0)
col_block = gl.program_id(1)

col_offsets = col_block * BLOCK_SIZE + gl.arange(0, BLOCK_SIZE, layout=_col_layout)
mask = col_offsets < hidden_size

zeros_offsets = gl.zeros([BLOCK_SIZE], dtype=gl.int32, layout=_col_layout)
gate_val = gl.amd.cdna3.buffer_load(gate_ptr + row, zeros_offsets).to(gl.float32)
sig = 1.0 / (1.0 + gl.exp(-gate_val))

shared_offsets = row * shared_stride_row + col_offsets
out_offsets = row * out_stride_row + col_offsets

shared_val = gl.amd.cdna3.buffer_load(shared_ptr, shared_offsets, mask=mask).to(gl.float32)
out_val = gl.amd.cdna3.buffer_load(out_ptr, out_offsets, mask=mask).to(gl.float32)

result = (out_val + sig * shared_val).to(gl.bfloat16)
gl.amd.cdna3.buffer_store(result, out_ptr, out_offsets, mask=mask)
"""

@gluon.jit
def _fused_sigmoid_mul_add_gluon_kernel(
    gate_ptr,           # [num_tokens] gate logits, one scalar per row
    shared_ptr,         # [num_tokens, hidden_size] shared-expert output
    out_ptr,            # [num_tokens, hidden_size] accumulator, updated in place
    hidden_size,        # number of columns per row
    shared_stride_row,  # row stride (elements) of shared_ptr
    out_stride_row,     # row stride (elements) of out_ptr
    BLOCK_SIZE: gl.constexpr,  # columns handled per program
):
    # Grid: (num_tokens, cdiv(hidden_size, BLOCK_SIZE)).
    # Each program computes out[row, cols] += sigmoid(gate[row]) * shared[row, cols]
    # for one BLOCK_SIZE-wide column slice of one row (CDNA4 buffer ops).
    row = gl.program_id(0)
    col_block = gl.program_id(1)

    col_offsets = col_block * BLOCK_SIZE + gl.arange(0, BLOCK_SIZE, layout=_layout)
    # Tail mask for the last column block when hidden_size % BLOCK_SIZE != 0.
    mask = col_offsets < hidden_size

    #gate_val = gl.load(gate_ptr + row).to(gl.float32)
    #zeros_offsets = gl.zeros([BLOCK_SIZE], dtype=gl.int32, layout=layout) cdna3 version
    # All-zero offset tensor in the block layout: every lane loads the same
    # scalar gate[row] via buffer_load. Built as (arange * 0) — presumably
    # because gl.zeros with an explicit layout didn't work here; see the
    # disabled cdna3 variant above. TODO confirm and simplify if possible.
    temp = gl.arange(0, BLOCK_SIZE, layout=_layout)
    zeros_offsets = (temp * 0).to(gl.int32)
    gate_val = gl.amd.cdna4.buffer_load(gate_ptr + row, zeros_offsets).to(gl.float32)
    # Sigmoid in fp32 for accuracy before the final bf16 downcast.
    sig = 1.0 / (1.0 + gl.exp(-gate_val))

    shared_offsets = row * shared_stride_row + col_offsets
    out_offsets = row * out_stride_row + col_offsets

    shared_val = gl.amd.cdna4.buffer_load(shared_ptr, shared_offsets, mask=mask).to(gl.float32)
    out_val = gl.amd.cdna4.buffer_load(out_ptr, out_offsets, mask=mask).to(gl.float32)

    # NOTE(review): result is always cast to bfloat16, so this assumes
    # out_ptr points at a bf16 tensor — confirm callers never pass fp16/fp32.
    result = (out_val + sig * shared_val).to(gl.bfloat16)
    #gl.store(out_ptr + out_offsets, result, mask=mask)
    gl.amd.cdna4.buffer_store(result, out_ptr, out_offsets, mask=mask)

def fused_sigmoid_mul_add_gluon(
    gate: torch.Tensor,
    shared_output: torch.Tensor,
    final_hidden_states: torch.Tensor,
) -> None:
    """Fused sigmoid-mul-add: final_hidden_states += sigmoid(gate) * shared_output.

    Args:
        gate: [num_tokens, 1] (or flattenable to 1D), one gate logit per row.
        shared_output: [num_tokens, hidden_size].
        final_hidden_states: [num_tokens, hidden_size], modified in-place.

    Raises:
        ValueError: if the tensor shapes are mutually inconsistent.
    """
    num_tokens, hidden_size = shared_output.shape
    if final_hidden_states.shape != shared_output.shape:
        raise ValueError(
            f"final_hidden_states shape {tuple(final_hidden_states.shape)} "
            f"!= shared_output shape {tuple(shared_output.shape)}"
        )
    gate_flat = gate.view(-1)
    if gate_flat.numel() != num_tokens:
        raise ValueError(
            f"gate has {gate_flat.numel()} elements, expected {num_tokens}"
        )
    if num_tokens == 0:
        # Empty batch: nothing to update; skip the (pointless) kernel launch.
        return

    # One program per (row, BLOCK_SIZE-wide column slice).
    num_col_blocks = triton.cdiv(hidden_size, BLOCK_SIZE)
    grid = (num_tokens, num_col_blocks)

    _fused_sigmoid_mul_add_gluon_kernel[grid](
        gate_flat,
        shared_output,
        final_hidden_states,
        hidden_size,
        shared_output.stride(0),
        final_hidden_states.stride(0),
        BLOCK_SIZE=BLOCK_SIZE,
        # Must match WARPS_PER_CTA baked into _layout, or the layout is invalid.
        num_warps=WARPS_PER_CTA,
    )
6 changes: 4 additions & 2 deletions tests/gluon/reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,11 @@ def test():
idx = torch.where(torch.abs(ref[i] - output[i]) > 0.05)
print(f'{idx=}, {ref[i][idx]=}, {output[i][idx]=}')
assert 0


# bf16: 2 bytes; read input [num_tokens_total, TOPK, OC], write output [num_tokens_total, OC]
rw_bytes = num_tokens_total * TOPK * OC * 2 + num_tokens_total * OC * 2
for _ in range(10):
with pyhip.cuPerf(name="moe_gemm_final_reduce_bf16"):
with pyhip.cudaPerf(0, rw_bytes, name="moe_gemm_final_reduce_bf16"):
moe_gemm_final_reduce_bf16[(num_WG,)](
input,
output,
Expand Down
94 changes: 94 additions & 0 deletions tests/gluon/test_fused_sigmoid_mul_add.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
"""
Kernel: final_hidden_states += sigmoid(gate) * shared_output
Shapes: gate [num_tokens, 1], shared_output / final_hidden_states [num_tokens, hidden_size]
bench perf
python -m pytest test_fused_sigmoid_mul_add.py -v -k benchmark -s
"""

import os
import sys

import pytest
import torch

_tests_dir = os.path.dirname(os.path.abspath(__file__))
if _tests_dir not in sys.path:
sys.path.insert(0, _tests_dir)

from fused_sigmoid_mul_add_gluon import fused_sigmoid_mul_add_gluon

# Only these shapes: accuracy + benchmark.
# Each param is (num_tokens, hidden_size, dtype).
CASES = [
    pytest.param(1, 4096, torch.bfloat16, id="1x4096_bf16"),
    pytest.param(2, 4096, torch.bfloat16, id="2x4096_bf16"),
    pytest.param(8000, 4096, torch.bfloat16, id="8000x4096_bf16"),
]


def _ref(gate: torch.Tensor, shared: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
return out + torch.sigmoid(gate) * shared


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
@pytest.mark.parametrize("num_tokens, hidden_size, dtype", CASES)
def test_accuracy(num_tokens, hidden_size, dtype):
    """Kernel output vs PyTorch reference."""
    dev = torch.device("cuda", 0)
    gate_in = torch.randn((num_tokens, 1), dtype=dtype, device=dev)
    shared_in = torch.randn((num_tokens, hidden_size), dtype=dtype, device=dev)
    base = torch.randn((num_tokens, hidden_size), dtype=dtype, device=dev)

    expected = _ref(gate_in, shared_in, base)
    actual = base.clone()
    fused_sigmoid_mul_add_gluon(gate_in, shared_in, actual)

    torch.testing.assert_close(actual, expected, rtol=2e-2, atol=2e-2)


def _bench(run_fn, num_tokens: int, hidden_size: int, dtype: torch.dtype, warmup: int = 10, repeat: int = 100):
    """Time run_fn with CUDA events; return (mean ms per call, effective GB/s)."""
    for _ in range(warmup):
        run_fn()
    torch.cuda.synchronize()

    tic = torch.cuda.Event(enable_timing=True)
    toc = torch.cuda.Event(enable_timing=True)
    tic.record()
    for _ in range(repeat):
        run_fn()
    toc.record()
    torch.cuda.synchronize()

    mean_ms = tic.elapsed_time(toc) / repeat
    elem_bytes = torch.empty(0, dtype=dtype).element_size()
    # Traffic model: gate read (num_tokens) + shared read + out read + out write.
    total_bytes = (num_tokens + 3 * num_tokens * hidden_size) * elem_bytes
    gbps = 0 if mean_ms <= 0 else (total_bytes / 1e9) / (mean_ms / 1000)
    return mean_ms, gbps


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
@pytest.mark.parametrize("num_tokens, hidden_size, dtype", CASES)
def test_benchmark(num_tokens, hidden_size, dtype):
    """Fused kernel vs PyTorch ref: time and GB/s."""
    device = torch.device("cuda", 0)
    gate = torch.randn(num_tokens, 1, device=device, dtype=dtype)
    shared = torch.randn(num_tokens, hidden_size, device=device, dtype=dtype)
    z = torch.randn(num_tokens, hidden_size, device=device, dtype=dtype)

    # Allocate the in-place buffer ONCE, outside the timed closure. The old
    # version called z.clone() on every timed iteration, so the fused timing
    # (and its GB/s) included an extra full-tensor copy the reference did not
    # pay. The kernel's runtime does not depend on the buffer's values, so
    # reusing one buffer across iterations is safe for timing purposes.
    out = z.clone()

    def run_kernel():
        fused_sigmoid_mul_add_gluon(gate, shared, out)

    def run_ref():
        _ref(gate, shared, z)

    fused_ms, fused_gbps = _bench(run_kernel, num_tokens, hidden_size, dtype)
    ref_ms, ref_gbps = _bench(run_ref, num_tokens, hidden_size, dtype)
    speedup = ref_ms / fused_ms if fused_ms > 0 else 0.0

    print(
        f"\n[benchmark] {num_tokens}x{hidden_size} {dtype}\n"
        f"  fused: {fused_ms*1000:.2f} us/iter, {fused_gbps:.2f} GB/s\n"
        f"  ref:   {ref_ms*1000:.2f} us/iter, {ref_gbps:.2f} GB/s  speedup {speedup:.2f}x"
    )


if __name__ == "__main__":
pytest.main([__file__, "-v", "-s"])