From 80f6c4b3605cf67033f96980aff6b78000669b65 Mon Sep 17 00:00:00 2001 From: Zzz9990 Date: Tue, 3 Feb 2026 15:08:25 +0800 Subject: [PATCH 01/11] reconstruction start --- flydsl/src/flydsl/kernels | 1 + kernels/mfma_preshuffle_pipeline.py | 463 +++++++++++++++++++++++++- kernels/mixed_preshuffle_gemm.py | 95 ------ kernels/preshuffle_gemm.py | 426 ++++++++++-------------- tests/kernels/test_preshuffle_gemm.py | 8 +- 5 files changed, 642 insertions(+), 351 deletions(-) create mode 120000 flydsl/src/flydsl/kernels diff --git a/flydsl/src/flydsl/kernels b/flydsl/src/flydsl/kernels new file mode 120000 index 00000000..d48390c1 --- /dev/null +++ b/flydsl/src/flydsl/kernels @@ -0,0 +1 @@ +../../../kernels \ No newline at end of file diff --git a/kernels/mfma_preshuffle_pipeline.py b/kernels/mfma_preshuffle_pipeline.py index 1b8d02e6..0c35b5ce 100644 --- a/kernels/mfma_preshuffle_pipeline.py +++ b/kernels/mfma_preshuffle_pipeline.py @@ -12,9 +12,246 @@ from __future__ import annotations from dataclasses import dataclass +from flydsl.dialects.ext import arith, gpu, buffer_ops, vector, rocdl +from flydsl.lang.ir.types import T, memref + from _mlir import ir -@dataclass(frozen=True) +from enum import Enum + +class MfmaPipeline(Enum): + F8F4_MXFP4_PIPELINE = "F8F4_MXFP4_PIPELINE" + F8F8_MXFP4_PIPELINE = "F8F8_MXFP4_PIPELINE" + F16F16_16x16_PIPELINE = "F16F16_16x16_PIPELINE" + BF16BF16_16x16_PIPELINE = "BF16BF16_16x16_PIPELINE" + I8I8_16x16_PIPELINE = "I8I8_16x16_PIPELINE" + I8I4_16x16_PIPELINE = "I8I4_16x16_PIPELINE" + +class EpilogPipeline(Enum): + CSHUFFLE_F16 = "CSHUFFLE_F16" + CSHUFFLE_BF16 = "CSHUFFLE_BF16" + CSHUFFLE_F32 = "CSHUFFLE_F32" + DIRECT_F16 = "DIRECT_F16" + DIRECT_BF16 = "DIRECT_BF16" + DIRECT_F32 = "DIRECT_F32" + +a_elem_type_dict = { + MfmaPipeline.F8F4_MXFP4_PIPELINE: lambda: T.f8, + MfmaPipeline.F8F8_MXFP4_PIPELINE: lambda: T.f8, + MfmaPipeline.F16F16_16x16_PIPELINE: lambda: T.f16, + MfmaPipeline.BF16BF16_16x16_PIPELINE: lambda: T.bf16, + MfmaPipeline.I8I8_16x16_PIPELINE: lambda: T.i8, + MfmaPipeline.I8I4_16x16_PIPELINE: lambda: T.i8, +} + +b_elem_type_dict = { + MfmaPipeline.F8F4_MXFP4_PIPELINE: lambda: T.ui8, + MfmaPipeline.F8F8_MXFP4_PIPELINE: lambda: T.f8, + MfmaPipeline.F16F16_16x16_PIPELINE: lambda: T.f16, + MfmaPipeline.BF16BF16_16x16_PIPELINE: lambda: T.bf16, + MfmaPipeline.I8I8_16x16_PIPELINE: lambda: T.i8, + MfmaPipeline.I8I4_16x16_PIPELINE: lambda: T.f8, +} + +scale_elem_type_dict = { + MfmaPipeline.F8F4_MXFP4_PIPELINE: lambda: T.i32, + MfmaPipeline.F8F8_MXFP4_PIPELINE: lambda: T.i32, + MfmaPipeline.I8I8_16x16_PIPELINE: lambda: T.f32, + MfmaPipeline.I8I4_16x16_PIPELINE: lambda: T.f32, + # bf16 scale placeholder + MfmaPipeline.F16F16_16x16_PIPELINE: lambda: T.f32, + MfmaPipeline.BF16BF16_16x16_PIPELINE: lambda: T.f32, +} + +out_elem_type_dict = { + EpilogPipeline.CSHUFFLE_F16: lambda: T.f16, + EpilogPipeline.CSHUFFLE_BF16: lambda: T.bf16, + EpilogPipeline.CSHUFFLE_F32: lambda: T.f32, + EpilogPipeline.DIRECT_F16: lambda: T.f16, + EpilogPipeline.DIRECT_BF16: lambda: T.bf16, + EpilogPipeline.DIRECT_F32: lambda: T.f32, +} + +a_vec16_type_dict = { + MfmaPipeline.F8F4_MXFP4_PIPELINE: lambda: T.f8x16, + MfmaPipeline.F8F8_MXFP4_PIPELINE: lambda:T.f8x16, + MfmaPipeline.F16F16_16x16_PIPELINE: lambda:T.f16x8, + MfmaPipeline.BF16BF16_16x16_PIPELINE: lambda:T.bf16x8, + MfmaPipeline.I8I8_16x16_PIPELINE: lambda:T.i8x16, + MfmaPipeline.I8I4_16x16_PIPELINE: lambda:T.i8x16, +} + +b_vec16_type_dict = { + MfmaPipeline.F8F4_MXFP4_PIPELINE: lambda: T.ui8x16, + MfmaPipeline.F8F8_MXFP4_PIPELINE: lambda: T.f8x16, + MfmaPipeline.F16F16_16x16_PIPELINE: lambda: T.f16x8, + MfmaPipeline.BF16BF16_16x16_PIPELINE: lambda: T.bf16x8, + MfmaPipeline.I8I8_16x16_PIPELINE: lambda: T.i8x16, + MfmaPipeline.I8I4_16x16_PIPELINE: lambda: T.f8x16, +} + +mfma_input_pack_ty_dict = { + MfmaPipeline.F8F4_MXFP4_PIPELINE: lambda: T.i64, + MfmaPipeline.F8F8_MXFP4_PIPELINE: lambda: T.i64, + MfmaPipeline.F16F16_16x16_PIPELINE: lambda: T.f16x4, + MfmaPipeline.BF16BF16_16x16_PIPELINE: lambda: T.i16x4, + MfmaPipeline.I8I8_16x16_PIPELINE: lambda: T.i32x4, + MfmaPipeline.I8I4_16x16_PIPELINE: lambda: T.i32x4, +} + +mfma_output_pack_ty_dict = { + MfmaPipeline.F8F4_MXFP4_PIPELINE: lambda: T.f32x4, + MfmaPipeline.F8F8_MXFP4_PIPELINE: lambda: T.f32x4, + MfmaPipeline.F16F16_16x16_PIPELINE: lambda: T.f32x4, + MfmaPipeline.BF16BF16_16x16_PIPELINE: lambda: T.f32x4, + MfmaPipeline.I8I8_16x16_PIPELINE: lambda: T.i32x4, + MfmaPipeline.I8I4_16x16_PIPELINE: lambda: T.i32x4, +} + +def get_mfma_i32_k32(): + mfma_i32_k32 = getattr(rocdl, "mfma_i32_16x16x32i8", None) or getattr( + rocdl, "mfma_i32_16x16x32_i8", None + ) + if mfma_i32_k32 is None: + raise AttributeError( + "INT8 K32 MFMA op not found: expected `rocdl.mfma_i32_16x16x32i8` " + "(or `rocdl.mfma_i32_16x16x32_i8`)." + ) + return mfma_i32_k32 + +class PreshufflePipelineManager: + def __init__( + self, + a_dtype: str, + b_dtype: str, + out_dtype: str, + use_cshuffle_epilog: bool = False, + a_packed: bool = False, + b_packed: bool = False, + block_size: int = 256, + ): + self.a_dtype = a_dtype + self.b_dtype = b_dtype + self.out_dtype = out_dtype + self.use_cshuffle_epilog = use_cshuffle_epilog + self.a_packed = self.a_dtype in ["fp4"] + self.b_packed = self.b_dtype in ["fp4", "int4"] + self.a_elem_pack = 2 if self.a_packed else 1 + self.b_elem_pack = 2 if self.b_packed else 1 + self.mfma_pipeline = self.get_mfma_pipeline() + self.epilog_pipeline = self.get_epilog_pipeline() + self.a_elem_bytes = self.get_a_elem_bytes() + self.b_elem_bytes = self.get_b_elem_bytes() + self.out_elem_bytes = self.get_out_elem_bytes() + self.block_size = block_size + + def refine_dtype(self): + + + def check_type_valid(self): + if self.a_dtype not in ["fp8", "int8", "int4", "fp16", "bf16"]: + raise ValueError(f"Invalid a_dtype: {self.a_dtype}") + if self.b_dtype not in ["fp8", "int8", "int4", "fp16", "bf16"]: + raise ValueError(f"Invalid b_dtype: {self.b_dtype}") + if self.out_dtype not in ["fp16", "bf16", "f32"]: + raise ValueError(f"Invalid out_dtype: {self.out_dtype}") + + def get_mfma_pipeline(self): + if self.a_dtype == "fp4" and self.b_dtype == "fp4": + return MfmaPipeline.F4F4_MXFP4_PIPELINE + elif self.a_dtype == "fp8" and self.b_dtype == "fp4": + return MfmaPipeline.F8F4_MXFP4_PIPELINE + elif self.a_dtype == "fp8" and self.b_dtype == "fp8": + return MfmaPipeline.F8F8_MXFP4_PIPELINE + elif self.a_dtype == "fp16" and self.b_dtype == "fp16": + return MfmaPipeline.F16F16_16x16_PIPELINE + elif self.a_dtype == "bf16" and self.b_dtype == "bf16": + return MfmaPipeline.BF16BF16_16x16_PIPELINE + elif self.a_dtype == "int8" and self.b_dtype == "int8": + return MfmaPipeline.I8I8_16x16_PIPELINE + elif self.a_dtype == "int8" and self.b_dtype == "int4": + return MfmaPipeline.I8I4_16x16_PIPELINE + else: + raise ValueError(f"Invalid preshuffle pipeline: {self.a_dtype}_{self.b_dtype}_{self.out_dtype}") + + def get_epilog_pipeline(self): + if self.use_cshuffle_epilog and self.out_dtype == "fp16": + return EpilogPipeline.CSHUFFLE_F16 + elif self.use_cshuffle_epilog and self.out_dtype == "bf16": + return EpilogPipeline.CSHUFFLE_BF16 + elif self.use_cshuffle_epilog and self.out_dtype == "f32": + return EpilogPipeline.CSHUFFLE_F32 + elif not self.use_cshuffle_epilog and self.out_dtype == "f32": + return EpilogPipeline.DIRECT_F32 + elif not self.use_cshuffle_epilog and self.out_dtype == "f16": + return EpilogPipeline.DIRECT_F16 + elif not self.use_cshuffle_epilog and self.out_dtype == "bf16": + return EpilogPipeline.DIRECT_BF16 + else: + raise ValueError(f"Invalid epilog pipeline: {self.out_dtype}") + + def get_b_elem_bytes(self): + if self.b_dtype in ["fp8", "int8", "int4"]: + return 1 + elif self.b_dtype in ["fp16", "bf16"]: + return 2 + else: + raise ValueError(f"Invalid b_dtype: {self.b_dtype}") + + def get_a_elem_bytes(self): + if self.a_dtype in ["fp8", "int8", "int4"]: + return 1 + elif self.a_dtype in ["fp16", "bf16"]: + return 2 + else: + raise ValueError(f"Invalid a_dtype: {self.a_dtype}") + + def get_out_elem_bytes(self): + if self.out_dtype in ["fp16", "bf16"]: + return 2 + elif self.out_dtype == "f32": + return 4 + else: + raise ValueError(f"Invalid out_dtype: {self.out_dtype}") + + def get_mfma_fn(self): + if self.mfma_pipeline == MfmaPipeline.F8F6F4_PIPELINE: + return rocdl.mfma_f32_16x16x16f16 + elif self.mfma_pipeline == MfmaPipeline.BF16BF16_16x16_PIPELINE: + return rocdl.mfma_f32_16x16x16bf16_1k + elif self.mfma_pipeline == MfmaPipeline.F16F16_16x16_PIPELINE: + return rocdl.mfma_f32_16x16x16f16 + elif self.mfma_pipeline == MfmaPipeline.I8I8_16x16_PIPELINE: + return get_mfma_i32_k32() + elif self.mfma_pipeline == MfmaPipeline.I8I4_16x16_PIPELINE: + return get_mfma_i32_k32() + else: + raise ValueError(f"Invalid mfma pipeline: {self.mfma_pipeline}") + + def get_a_bytes_per_thread( + self, + tile_m: int, + tile_k: int, + ): + a_bytes_per_tile = int(tile_m) * int(tile_k) * int(self.a_elem_bytes) + if a_bytes_per_tile % self.block_size != 0: + raise ValueError( + "tile_m*tile_k*elem_bytes must be divisible by " + f"{self.block_size}: tile_m={tile_m}, tile_k={tile_k}, a_elem_bytes={self.a_elem_bytes}" + ) + a_bytes_per_thread = a_bytes_per_tile // self.block_size + + # Assume A loads are always 16B-aligned and use fixed dwordx4 (16B) buffer loads. + a_load_bytes = 16 + if a_bytes_per_thread % a_load_bytes != 0: + raise ValueError( + f"a_bytes_per_thread ({a_bytes_per_thread}) must be divisible by {a_load_bytes}" + ) + + return a_bytes_per_thread + + + class PreshuffleBLayout: """Container returned by `make_preshuffle_b_layout`.""" @@ -509,9 +746,230 @@ def lds_load_pack_k32( a_vec64 = vector.bitcast(vec1_i64_ty, loaded_a8) return vector.extract(a_vec64, static_position=[0], dynamic_position=[]) +def block_mfma_block_scale_f8f6f4( + accs_in, + b_tile_in, + a_scale, + b_scale, + lds_base, + lds_load_packs_k64, + col_offset_base_bytes, + row_a_lds, + *, + mfma_fn, + mfma_res_ty, + cbsz, + blgp, + a_elem_vec_pack, + k_unroll, + m_repeat, + num_acc, + pack_K, + pack_M, + pack_N, + a0_prefetch=None, +): + current_accs_list = list(accs_in) + + k_unroll_packed = k_unroll // pack_K + m_repeat_packed = m_repeat // pack_M + num_acc_n_packed = num_acc // pack_N + + mfma_res_ty = T.f32x4 + vec4_i64 = T.vec(4, T.i64) + vec8_i32 = T.vec(8, T.i32) + c0_i64 = arith.constant(0, type=T.i64) + + def pack_i64x4_to_i32x8(x0, x1, x2, x3): + v4 = vector.from_elements(vec4_i64, [x0, x1, x2, x3]) + return vector.bitcast(vec8_i32, v4) + + for ku128 in range_constexpr(k_unroll_packed): + for mi in range_constexpr(m_repeat_packed): + a_scale_i32 = a_scale[ku128 * m_repeat_packed + mi] + a_scale_val = vector.extract(a_scale_i32, static_position=[0], dynamic_position=[]) + for ni in range_constexpr(num_acc_n_packed): + b_scale_i32 = b_scale[ku128 * num_acc_n_packed + ni] + b_scale_val = vector.extract(b_scale_i32, static_position=[0], dynamic_position=[]) + for ikxdl in range_constexpr(pack_K): + k_idx = ku128 * pack_K + ikxdl + + b_packs0, b_packs1 = b_tile_in[k_idx] + + col_base = col_offset_base_bytes + (k_idx * 128) // a_elem_vec_pack + for imxdl in range_constexpr(pack_M): + col_base0 = col_base + mi_idx = mi * pack_M + imxdl + mi_val = arith.constant(mi_idx * 16, index=True) + curr_row_a_lds = row_a_lds + mi_val + + if (a0_prefetch is not None) and (k_idx == 0) and (mi_idx == 0): + a0, a1 = a0_prefetch + else: + a0, a1 = lds_load_packs_k64(curr_row_a_lds, col_base0, lds_base) + + col_base1 = col_base + 64 + a2, a3 = lds_load_packs_k64(curr_row_a_lds, col_base1, lds_base) + a128 = pack_i64x4_to_i32x8(a0, a1, a2, a3) + + for inxdl in range_constexpr(pack_N): + ni_idx = ni * pack_N + inxdl + + b0 = b_packs0[ni_idx] + b1 = b_packs1[ni_idx] + b128 = pack_i64x4_to_i32x8(b0, b1, c0_i64, c0_i64) + + acc_idx = mi_idx * num_acc_n + ni_idx + current_accs_list[acc_idx] = mfma_fn( + mfma_res_ty, + [ + a128, + b128, + current_accs_list[acc_idx], + # cbsz, abid, blgp + cbsz, + blgp, + # op_sel_a + scale_a (1.0f as i32 bits) + ikxdl * pack_M + imxdl, + a_scale_val, + # + # op_sel_b + scale_b (1.0f as i32 bits) + ikxdl * pack_N + inxdl, + b_scale_val, + ], + ) + return current_accs_list, None + +# ---------------- gfx95 fast path (K128 MFMA scale) ---------------- +# This is the key optimization from `zhimding/develop_0107` for FP8: +# use mfma.scale 16x16x128 to reduce instruction count in the hot loop. +# +# Notes: +# - Only valid for fp8 path (not int8/int4) and gfx95+ +# - Requires tile_k divisible by 128 +# - mfma.scale takes 9 operands: 3 vectors + 6 i32 flags/scales. +def block_mfma_PTPC_f8f6f4( + accs_in, + b_tile_in, + lds_base, + col_offset_base_bytes, + row_a_lds, + lds_load_packs_k64, + *, + mfma_res_ty, + mfma_fn, + k_unroll=16, + num_acc_n=16, + m_repeat=16, + a0_prefetch=None, +): + + vec4_i64 = T.vec(4, T.i64) + vec8_i32 = T.vec(8, T.i32) + + def pack_i64x4_to_i32x8(x0, x1, x2, x3): + v4 = vector.from_elements(vec4_i64, [x0, x1, x2, x3]) + return vector.bitcast(vec8_i32, v4) + + for ku128 in range_constexpr(k_unroll // 2): + ku0 = ku128 * 2 + ku1 = ku0 + 1 + + b0_packs0, b0_packs1 = b_tile_in[ku0] + b1_packs0, b1_packs1 = b_tile_in[ku1] + + col_base0 = col_offset_base_bytes + (ku0 * 64) + col_base1 = col_offset_base_bytes + (ku1 * 64) + + for mi in range_constexpr(m_repeat): + mi_val = arith.constant(mi * 16, index=True) + curr_row_a_lds = row_a_lds + mi_val + + if (a0_prefetch is not None) and (ku0 == 0) and (mi == 0): + a0, a1 = a0_prefetch + else: + a0, a1 = lds_load_packs_k64(curr_row_a_lds, col_base0, lds_base) + a2, a3 = lds_load_packs_k64(curr_row_a_lds, col_base1, lds_base) + a128 = pack_i64x4_to_i32x8(a0, a1, a2, a3) + + for ni in range_constexpr(num_acc_n): + b0 = b0_packs0[ni] + b1 = b0_packs1[ni] + b2 = b1_packs0[ni] + b3 = b1_packs1[ni] + b128 = pack_i64x4_to_i32x8(b0, b1, b2, b3) + + acc_idx = mi * num_acc_n + ni + current_accs_list[acc_idx] = rocdl.mfma_scale_f32_16x16x128_f8f6f4( + mfma_res_ty, + [ + a128, + b128, + current_accs_list[acc_idx], + # cbsz, abid, blgp: 0 + 0, + 0, + 0, + # op_sel_a + scale_a (1.0f as i32 bits) + 0x3F800000, + # op_sel_b + scale_b (1.0f as i32 bits) + 0, + 0x3F800000, + ], + ) + return current_accs_list, scales_pf + + +def block_mfma_16x16( + accs_in, + b_tile_in, + lds_base, + col_offset_base_bytes, + row_a_lds, + lds_load_packs_k64, + mfma_fn, + *, + mfma_res_ty, + k_unroll, + num_acc_n, + m_repeat, + a0_prefetch=None, +): + def mfma_step(acc_in, a, b): + return mfma_fn(mfma_res_ty, [a, b, acc_in, 0, 0, 0]) + + # "K64-byte wrapper": two back-to-back MFMA/WMMA ops using the two 8B halves. + def mfma_k64_bytes(acc_in, a0, a1, b0, b1): + acc_mid = mfma_step(acc_in, a0, b0) + return mfma_step(acc_mid, a1, b1) + + for ku in range_constexpr(k_unroll): + b_packs0, b_packs1 = b_tile_in[ku] + # Byte-addressed K stepping (64B per ku). + ki64 = ku * 64 + col_base = col_offset_base_bytes + ki64 + for mi in range_constexpr(m_repeat): + mi_val = arith.constant(mi * 16, index=True) + curr_row_a_lds = row_a_lds + mi_val + if (a0_prefetch is not None) and (ku == 0) and (mi == 0): + a0, a1 = a0_prefetch + else: + a0, a1 = lds_load_packs_k64(curr_row_a_lds, col_base, lds_base) + for ni in range_constexpr(num_acc_n): + acc_idx = mi * num_acc_n + ni + current_accs_list[acc_idx] = mfma_k64_bytes( + current_accs_list[acc_idx], + a0, + a1, + b_packs0[ni], + b_packs1[ni], + ) __all__ = [ "PreshuffleBLayout", + "MfmaPipeline", + "EpilogPipeline", + "PreshufflePipelineManager", "buffer_copy_gmem16_dwordx4", "lds_load_pack_k32", "lds_store_4b_xor16", @@ -521,5 +979,8 @@ def lds_load_pack_k32( "make_preshuffle_scale_layout", "load_b_pack_k32", "tile_chunk_coord_i32", + "block_mfma_block_scale_f8f6f4", + "block_mfma_PTPC_f8f6f4", + "block_mfma_16x16", ] diff --git a/kernels/mixed_preshuffle_gemm.py b/kernels/mixed_preshuffle_gemm.py index b28f0f75..74b45801 100644 --- a/kernels/mixed_preshuffle_gemm.py +++ b/kernels/mixed_preshuffle_gemm.py @@ -82,11 +82,6 @@ def compile_mxfp4_preshuffle_gemm( pack_N = 2 pack_K = 2 - quant_block_size_a = 32 - quant_block_size_b = 32 - - - cbsz = 0 if is_fp8_a else 4 blgp = 4 @@ -528,96 +523,6 @@ def prefetch_ab_tile(base_k): b_regs = load_b_tile(base_k // 2) return a_regs, b_regs - def compute_tile( - accs_in, - b_tile_in, - lds_base, - *, - a0_prefetch=None, - a_scale=None, - b_scale=None, - ): - current_accs_list = list(accs_in) - - # ---------------- gfx95 fast path (K128 MFMA scale) ---------------- - # This is the key optimization from `zhimding/develop_0107` for FP8: - # use mfma.scale 16x16x128 to reduce instruction count in the hot loop. - # - # Notes: - # - Only valid for fp8 path (not int8/int4) and gfx95+ - # - Requires tile_k divisible by 128 - # - mfma.scale takes 9 operands: 3 vectors + 6 i32 flags/scales. - if (int(tile_k) % 128) != 0: - raise ValueError( - f"tile_k must be divisible by 128 for mfma_scale_x128, got tile_k={tile_k}" - ) - - mfma_res_ty = T.f32x4 - vec4_i64 = T.vec(4, T.i64) - vec8_i32 = T.vec(8, T.i32) - c0_i64 = arith.constant(0, type=T.i64) - - def pack_i64x4_to_i32x8(x0, x1, x2, x3): - v4 = vector.from_elements(vec4_i64, [x0, x1, x2, x3]) - return vector.bitcast(vec8_i32, v4) - - for ku128 in range_constexpr(k_unroll_packed): - for mi in range_constexpr(m_repeat_packed): - - a_scale_i32 = a_scale[ku128 * m_repeat_packed + mi] - a_scale_val = vector.extract(a_scale_i32, static_position=[0], dynamic_position=[]) - for ni in range_constexpr(num_acc_n_packed): - b_scale_i32 = b_scale[ku128 * num_acc_n_packed + ni] - b_scale_val = vector.extract(b_scale_i32, static_position=[0], dynamic_position=[]) - for ikxdl in range_constexpr(pack_K): - k_idx = ku128 * pack_K + ikxdl - - b_packs0, b_packs1 = b_tile_in[k_idx] - - col_base = col_offset_base_bytes + (k_idx * 128) // a_elem_vec_pack - for imxdl in range_constexpr(pack_M): - col_base0 = col_base - mi_idx = mi * pack_M + imxdl - mi_val = arith.constant(mi_idx * 16, index=True) - curr_row_a_lds = row_a_lds + mi_val - - if (a0_prefetch is not None) and (k_idx == 0) and (mi_idx == 0): - a0, a1 = a0_prefetch - else: - a0, a1 = lds_load_packs_k64(curr_row_a_lds, col_base0, lds_base) - - col_base1 = col_base + 64 - a2, a3 = lds_load_packs_k64(curr_row_a_lds, col_base1, lds_base) - a128 = pack_i64x4_to_i32x8(a0, a1, a2, a3) - - for inxdl in range_constexpr(pack_N): - ni_idx = ni * pack_N + inxdl - - b0 = b_packs0[ni_idx] - b1 = b_packs1[ni_idx] - b128 = pack_i64x4_to_i32x8(b0, b1, c0_i64, c0_i64) - - acc_idx = mi_idx * num_acc_n + ni_idx - rocdl.sched_barrier(0) - current_accs_list[acc_idx] = rocdl.mfma_scale_f32_16x16x128_f8f6f4( - mfma_res_ty, - [ - a128, - b128, - current_accs_list[acc_idx], - # cbsz, abid, blgp - cbsz, - blgp, - # op_sel_a + scale_a (1.0f as i32 bits) - ikxdl * pack_M + imxdl, - a_scale_val, - # - # op_sel_b + scale_b (1.0f as i32 bits) - ikxdl * pack_N + inxdl, - b_scale_val, - ], - ) - return current_accs_list, None vec1_f16 = ir.VectorType.get([1], ir.F16Type.get()) vec2_f16 = ir.VectorType.get([2], ir.F16Type.get()) diff --git a/kernels/preshuffle_gemm.py b/kernels/preshuffle_gemm.py index ea4cd4bb..a52c0dd9 100644 --- a/kernels/preshuffle_gemm.py +++ b/kernels/preshuffle_gemm.py @@ -31,11 +31,15 @@ make_preshuffle_b_layout, load_b_pack_k32, tile_chunk_coord_i32, + PreshufflePipelineManager, + block_mfma_block_scale_f8f6f4, + block_mfma_PTPC_f8f6f4, + block_mfma_16x16, ) from kernels.mfma_epilogues import mfma_epilog -def compile_preshuffle_gemm_a8( +def compile_preshuffle_gemm( *, M: int, N: int, @@ -43,7 +47,9 @@ def compile_preshuffle_gemm_a8( tile_m: int, tile_n: int, tile_k: int, - in_dtype: str = "fp8", + a_dtype: str = "fp8", + b_dtype: str = "fp8", + out_dtype: str = "f16", lds_stage: int = 2, # Epilogue options use_cshuffle_epilog: bool = False, @@ -53,111 +59,103 @@ def compile_preshuffle_gemm_a8( Args: M, N, K: GEMM sizes (A[M,K], B[N,K], C[M,N]). tile_m, tile_n, tile_k: block tile sizes. - in_dtype: + a_dtype: - "fp8": A/B are fp8 (1B/elem) - "int8": A/B are int8 (1B/elem) - - "int4": W4A8 path: A is int8, B is packed int4 (2 values per byte) and unpacked to int8 in-kernel. + - "int4": W4A8 path: A is int8 (1B/elem). + b_dtype: + - "fp8": A/B are fp8 (1B/elem) + - "int8": A/B are int8 (1B/elem) + - "int4": W4A8 path: B is packed int4 (2 values per byte) and unpacked to int8 in-kernel. + out_dtype: + - "fp16": Output is fp16 (2B/elem) + - "bf16": Output is bf16 (2B/elem) + - "f32": Output is f32 (4B/elem) lds_stage: - 2: ping-pong LDS for A (2 LDS buffers), tuned schedule (original). - 1: single LDS buffer for A . """ - if in_dtype not in ("fp8", "int8", "int4", "fp16", "bf16"): - raise ValueError( - "in_dtype must be one of ('fp8','int8','int4','fp16','bf16'), " - f"got {in_dtype!r}" - ) - is_int4 = in_dtype == "int4" - is_int8 = (in_dtype == "int8") or is_int4 - is_f16 = in_dtype == "fp16" - is_bf16 = in_dtype == "bf16" - is_f16_or_bf16 = is_f16 or is_bf16 - elem_bytes = 1 if (in_dtype in ("fp8", "int8", "int4")) else 2 + pipeline_manager = PreshufflePipelineManager(a_dtype, b_dtype, out_dtype) + pipeline_manager.check_type_valid() + + epilog_pipeline = pipeline_manager.get_epilog_pipeline() + mfma_pipeline = pipeline_manager.get_mfma_pipeline() + mfma_fn = pipeline_manager.get_mfma_fn() + + a_elem_bytes = pipeline_manager.get_a_elem_bytes() + b_elem_bytes = pipeline_manager.get_b_elem_bytes() + + a_elem_pack = pipeline_manager.a_elem_pack + b_elem_pack = pipeline_manager.b_elem_pack # Pipeline is byte-addressed along K (16B loads, XOR16 swizzle in bytes). - # For fp16/bf16 (2B/elem), user passes tile_k halved so tile_k_bytes stays constant. - tile_k_bytes = int(tile_k) * int(elem_bytes) + # For fp16/bf16 (2B/elem), user passes tile_k halved so b_tile_k_bytes stays constant. + b_tile_k_bytes = int(tile_k) * int(b_elem_bytes) # K64-byte micro-step wrapper uses 2x "half-pack" MFMA. # - fp8/i8: mfma K32, wrapper covers 64B (=64 elems) # - fp16/bf16: mfma K16, wrapper covers 64B (=32 elems) -> effective K halves - if (tile_k_bytes % 64) != 0: + if (b_tile_k_bytes % 64) != 0: raise ValueError( - f"tile_k_bytes must be divisible by 64, got tile_k_bytes={tile_k_bytes} " - f"(tile_k={tile_k}, elem_bytes={elem_bytes})" + f"b_tile_k_bytes must be divisible by 64, got b_tile_k_bytes={b_tile_k_bytes} " + f"(tile_k={tile_k}, b_elem_bytes={b_elem_bytes})" ) - # INT8 must use a K32 MFMA so the micro-step matches the FP8 path (strict alignment). - mfma_i32_k32 = None - if is_int8: - mfma_i32_k32 = getattr(rocdl, "mfma_i32_16x16x32i8", None) or getattr( - rocdl, "mfma_i32_16x16x32_i8", None - ) - if mfma_i32_k32 is None: - raise AttributeError( - "INT8 K32 MFMA op not found: expected `rocdl.mfma_i32_16x16x32i8` " - "(or `rocdl.mfma_i32_16x16x32_i8`)." - ) + a_tile_k_bytes = int(tile_k) * int(a_elem_bytes) gpu_arch = get_hip_arch() allocator = SmemAllocator(None, arch=gpu_arch) _state = {} - # Default-on: cross-tile (tile_k) A0 LDS prefetch in the ping-pong pipeline (lds_stage=2). - # - # This issues the *first* A-pack LDS read for the next tile between barriers, to overlap - # with the VMEM prefetch of the following tile. - DYN = ir.ShapedType.get_dynamic_size() - # Vector width calc (assume full tiles / no tail guards). - total_threads = 256 - bytes_a_per_tile = int(tile_m) * int(tile_k) * int(elem_bytes) - if bytes_a_per_tile % total_threads != 0: - raise ValueError( - "tile_m*tile_k*elem_bytes must be divisible by " - f"{total_threads}: tile_m={tile_m}, tile_k={tile_k}, elem_bytes={elem_bytes}" - ) - bytes_per_thread_a = bytes_a_per_tile // total_threads - - # Assume A loads are always 16B-aligned and use fixed dwordx4 (16B) buffer loads. - a_load_bytes = 16 - if bytes_per_thread_a % a_load_bytes != 0: - raise ValueError( - f"bytes_per_thread_a ({bytes_per_thread_a}) must be divisible by {a_load_bytes}" - ) + # a is copied by the whole block, so we need to calculate the bytes per thread. + a_bytes_per_thread = pipeline_manager.get_a_bytes_per_thread(tile_m, tile_k) # CK-style LDS128: stride is in BYTES along K (for XOR16 swizzle). - lds_stride_bytes = tile_k_bytes - - def _elem_type(): - if is_f16: - return T.f16 - if is_bf16: - return T.bf16 - return T.i8 if is_int8 else T.f8 - - def _vec16_type(): - if is_f16: - return T.f16x8 # 16B - if is_bf16: - return T.bf16x8 # 16B - return T.i8x16 if is_int8 else T.f8x16 - - def _mfma_pack_ty(): - # ROCDL MFMA intrinsics expect specific operand vector types: - # - fp8/int8 paths use i64 packs (8 bytes) - # - fp16 uses v4f16 (8 bytes) - # - bf16 uses v4i16 (8 bytes) for *_bf16_1k variants - if is_f16: - return T.f16x4 - if is_bf16: - return T.i16x4 - return T.i64 + lds_stride_bytes = a_tile_k_bytes // a_elem_pack + + def _a_elem_type(): + return a_elem_type_dict[mfma_pipeline] + def _b_elem_type(): + return b_elem_type_dict[mfma_pipeline] + def _scale_elem_type(): + return scale_elem_type_dict[mfma_pipeline] + def _out_elem_type(): + return out_elem_type_dict[epilog_pipeline] + def _a_vec16_type(): + return a_vec16_type_dict[mfma_pipeline] + def _b_vec16_type(): + return b_vec16_type_dict[mfma_pipeline] + def _mfma_input_pack_ty(): + return mfma_input_pack_ty_dict[mfma_pipeline] + def _mfma_output_pack_ty(): + return mfma_output_pack_ty_dict[mfma_pipeline] + + is_f16_or_bf16 = mfma_pipeline in [MfmaPipeline.F16F16_16x16_PIPELINE, + MfmaPipeline.BF16BF16_16x16_PIPELINE] + is_int4 = mfma_pipeline in [MfmaPipeline.I8I4_16x16_PIPELINE] + is_int8 = mfma_pipeline in [MfmaPipeline.I8I8_16x16_PIPELINE] + is_fp4 = mfma_pipeline in [MfmaPipeline.FP4FP4_16x16_PIPELINE, + MfmaPipeline.FP8FP4_16x16_PIPELINE] + + cbsz = 0 if mfma_pipeline == MfmaPipeline.FP4FP4_16x16_PIPELINE else 4 + blgp = 4 if mfma_pipeline in [MfmaPipeline.FP8FP4_16x16_PIPELINE, MfmaPipeline.FP8FP8_16x16_PIPELINE] else 0 + + pack_M = 2 + pack_N = 2 + pack_K = 2 + + use_mfma_scale_128 = ( + str(gpu_arch).startswith("gfx95") + and (not is_int8) + and (not is_int4) + and (not is_f16_or_bf16) + ) # GEMM epilogue toggle: optional LDS CShuffle + vectorized stores. # Default: off (current measured cases show no benefit). - epilog_tag = "cshuffle" if use_cshuffle_epilog else "direct" - module_name = f"mfma_preshuffle_{lds_stage}stages_{in_dtype}_{epilog_tag}".replace("-", "_") + module_name = f"mfma_preshuffle_{lds_stage}stages_{mfma_pipeline}_{epilog_pipeline}".replace("-", "_") class _GEMM(flir.MlirModule): GPU_MODULE_NAME = module_name @@ -184,11 +182,11 @@ def init_gpu_module(self): @flir.kernel def kernel_gemm( self: flir.T.i64, - arg_c: lambda: memref(DYN, T.f16), - arg_a: lambda: memref(DYN, _elem_type()), - arg_b: lambda: memref(DYN, _elem_type()), - arg_scale_a: lambda: memref(DYN, T.f32), - arg_scale_b: lambda: memref(DYN, T.f32), + arg_c: lambda: memref(DYN, _out_elem_type()), + arg_a: lambda: memref(DYN, _a_elem_type()), + arg_b: lambda: memref(DYN, _b_elem_type()), + arg_scale_a: lambda: memref(DYN, _scale_elem_type()), + arg_scale_b: lambda: memref(DYN, _scale_elem_type()), c_m: lambda: T.index, c_n: lambda: T.index, c_k: lambda: T.index, @@ -197,9 +195,7 @@ def kernel_gemm( # NOTE: Some environments have multiple `flydsl` builds on PYTHONPATH. # Use explicit MLIR Values (not Python ints / wrapper objects) for ROCDL ops. acc_init = arith.unwrap( - arith.constant_vector(0, T.i32x4) - if is_int8 - else arith.constant_vector(0.0, T.f32x4) + arith.constant_vector(0, _mfma_output_pack_ty()) ) # Layouts @@ -207,16 +203,13 @@ def kernel_gemm( # A uses dword indexing (buffer-load dwordx4). Convert element index -> dword index: # dword_index = (elem_index * elem_bytes) / 4 - if (int(elem_bytes) == 2): - c_k_div4bytes = (c_k * 2) / 4 - else: - c_k_div4bytes = c_k / 4 + c_k_div4bytes = c_k * a_elem_bytes / 4 / a_elem_pack layout_a_div4 = flir.make_layout((c_m, c_k_div4bytes), stride=(c_k_div4bytes, 1)) # B preshuffle layout (shared with MoE kernels). - kpack_bytes = 8 if is_int4 else 16 + b_kpack_bytes = 16 / b_elem_pack layout_b = make_preshuffle_b_layout( - flir, arith, c_n=c_n, c_k=c_k, kpack_bytes=kpack_bytes, elem_bytes=elem_bytes + flir, arith, c_n=c_n, c_k=c_k, kpack_bytes=b_kpack_bytes, elem_bytes=b_elem_bytes ).layout_b # LDS layout is element-indexed, but XOR16 swizzle is byte-based. @@ -226,7 +219,7 @@ def kernel_gemm( layout_lds = flir.make_layout(shape_lds, stride_lds) # CK-style XOR16 swizzle parameter (const). - k_blocks16 = arith.index(tile_k_bytes // 16) + a_k_blocks16 = arith.index(a_tile_k_bytes // 16) tx = gpu.thread_id("x") bx = gpu.block_id("x") @@ -236,7 +229,7 @@ def kernel_gemm( lds_a_ptr = _state["lds_a_decl"](base_ptr) lds_a = lds_a_ptr.get() lds_out = ( - SmemPtr(base_ptr, lds_a_ptr.byte_offset, T.f16, shape=(tile_m * tile_n,)).get() + SmemPtr(base_ptr, lds_a_ptr.byte_offset, _out_elem_type(), shape=(tile_m * tile_n,)).get() if use_cshuffle_epilog else None ) @@ -272,19 +265,18 @@ def kernel_gemm( # - fp16/bf16: 8 elems (2B) # # We express `col_offset_base` in *elements*. - kpack_elems = 16 if elem_bytes == 1 else 8 - col_offset_base = lane_div_16 * arith.constant(int(kpack_elems), index=True) + kpack_elems = 16 + a_kpack_elems = kpack_elems / a_elem_bytes + col_offset_base = lane_div_16 * arith.constant(int(a_kpack_elems), index=True) # `col_offset_base` is in element units (multiples of 16). We do LDS swizzle/math # in bytes, so scale by element size for fp16/bf16. col_offset_base_bytes = ( - col_offset_base - if elem_bytes == 1 - else (col_offset_base * arith.constant(int(elem_bytes), index=True)) + col_offset_base * arith.constant(int(a_elem_bytes), index=True) ) m_repeat = tile_m // 16 # K stepping is byte-addressed: one "micro-tile step" is 64 bytes. - k_unroll = tile_k_bytes // 64 + k_unroll = b_tile_k_bytes // 64 # --- Dynamic tiling along N (4 waves) --- num_waves = 4 @@ -311,7 +303,6 @@ def kernel_gemm( # Shared loader supports: # - FP8/INT8: explicit 16B load (one full KPack) + extract 8B for this micro-step # - INT4 (W4A8): 4B load + 7-op unpack to 8B (no v_perm) - def load_b_pack(base_k, ki_step, ni): return load_b_pack_k32( buffer_ops, @@ -327,7 +318,7 @@ def load_b_pack(base_k, ki_step, ni): n_intra=n_intra_list[ni], lane_div_16=lane_div_16, elem_type=_elem_type(), - kpack_bytes=kpack_bytes, + kpack_bytes=b_kpack_bytes, elem_bytes=elem_bytes, unpack_int4=is_int4, ) @@ -345,13 +336,14 @@ def load_b_packs_k64(base_k, ku: int, ni: int): return load_b_pack(base_k, ki0, ni), load_b_pack(base_k, ki1, ni) # FP8/INT8/FP16/BF16: load 16 bytes (one full KPack). - base_k_bytes = base_k * arith.constant(int(elem_bytes), index=True) + base_k_bytes = base_k * arith.constant(int(b_elem_bytes), index=True) k0_base = base_k_bytes / c64_b k0 = k0_base + ku k1 = lane_div_16 coord_pack = flir.make_coord(n_blk_list[ni], k0, k1, n_intra_list[ni], c0_idx) idx_pack = flir.crd2idx(coord_pack, layout_b) - vec_elems = 16 if elem_bytes == 1 else 8 + vec_elems = kpack_elems / b_elem_bytes + b_view = flir.TensorView( arg_b, (vec_elems,), @@ -368,6 +360,7 @@ def load_b_packs_k64(base_k, ku: int, ni: int): src_buffer_resource=(b_rsrc if elem_bytes == 1 else None), src_buffer_offset_in_bytes=(elem_bytes == 1), ) + # Split 16B pack into two 8B halves. b_i64x2 = vector.bitcast(T.i64x2, b16) b0_i64 = vector.extract(b_i64x2, static_position=[0], dynamic_position=[]) @@ -382,11 +375,7 @@ def load_b_packs_k64(base_k, ku: int, ni: int): vec1_i64_ty = ir.VectorType.get([1], ir.IntegerType.get_signless(64)) b0_v1 = vector.from_elements(vec1_i64_ty, [b0_i64]) b1_v1 = vector.from_elements(vec1_i64_ty, [b1_i64]) - if is_f16: - return vector.bitcast(T.f16x4, b0_v1), vector.bitcast(T.f16x4, b1_v1) - return vector.bitcast(T.i16x4, b0_v1), vector.bitcast(T.i16x4, b1_v1) - - # int8 path should not reach here (handled by the outer is_int8 branch). + return vector.bitcast(_mfma_input_pack_ty(), b0_v1), vector.bitcast(_mfma_input_pack_ty(), b1_v1) def load_b_tile(base_k): # b_tile[ku] = (packs_half0[ni], packs_half1[ni]) @@ -403,7 +392,7 @@ def load_b_tile(base_k): def lds_load_16b(curr_row_a_lds, col_base, lds_base): # Swizzle in bytes, then convert to element offset for memref indexing. - col_base_swz_bytes = flir.swizzle_xor16(curr_row_a_lds, col_base, k_blocks16) + col_base_swz_bytes = flir.swizzle_xor16(curr_row_a_lds, col_base, a_k_blocks16) col_base_swz = col_base_swz_bytes if elem_bytes == 1 else (col_base_swz_bytes / 2) coord_a16 = flir.make_coord(curr_row_a_lds, col_base_swz) idx_a16 = flir.crd2idx(coord_a16, layout_lds) @@ -423,9 +412,7 @@ def lds_load_packs_k64(curr_row_a_lds, col_base, lds_base): vec1_i64_ty = ir.VectorType.get([1], ir.IntegerType.get_signless(64)) a0_v1 = vector.from_elements(vec1_i64_ty, [a0_i64]) a1_v1 = vector.from_elements(vec1_i64_ty, [a1_i64]) - if is_f16: - return vector.bitcast(T.f16x4, a0_v1), vector.bitcast(T.f16x4, a1_v1) - return vector.bitcast(T.i16x4, a0_v1), vector.bitcast(T.i16x4, a1_v1) + return vector.bitcast(_mfma_input_pack_ty(), a0_v1), vector.bitcast(_mfma_input_pack_ty(), a1_v1) # --- A load/store (16B chunks), XOR16 swizzle --- num_a_loads = bytes_per_thread_a // a_load_bytes @@ -438,17 +425,17 @@ def lds_load_packs_k64(curr_row_a_lds, col_base, lds_base): layout_a_tile_div4 = flir.make_layout((tile_m, tile_k_dwords), stride=(tile_k_dwords, 1)) c4 = arith.constant(4, index=True) tx_i32_base = tx * c4 - atom_a_g2r16 = flir.make_copy_atom(_elem_type(), vector_size=(16 if elem_bytes == 1 else 8)) + atom_a_g2r16 = flir.make_copy_atom(_a_elem_type(), vector_size=a_kpack_elems) def load_a_16(idx_elem): return buffer_copy_gmem16_dwordx4( flir, arg=arg_a, - elem_type=_elem_type(), + elem_type=_a_elem_type(), idx_i32=idx_elem, atom_g2r16=atom_a_g2r16, rsrc=a_rsrc, - vec_elems=(16 if elem_bytes == 1 else 8), + vec_elems=a_kpack_elems, ) def a_tile_chunk_coord_i32(i: int): @@ -471,11 +458,8 @@ def load_a_tile(base_k_div4): # `idx_i32` is a dword offset. For 2B element types (fp16/bf16), # convert to element offset so the generic `vector.load` path reads # the right address (FLIR only specializes buffer_load_dwordx4 for 1B types). - idx_elem = ( - idx_i32 - if elem_bytes == 1 - else (idx_i32 * arith.constant(2, index=True)) - ) + idx_elem = idx_i32 * arith.constant(a_elem_bytes, index=True) + a_16B = load_a_16(idx_elem) parts.append(vector.bitcast(T.i32x4, a_16B)) return parts @@ -488,28 +472,32 @@ def store_a_tile_to_lds(vec_a_parts, lds_base): arith, vector, lds_memref=lds_a, - vec16_ty=_vec16_type(), - elem_type=_elem_type(), + vec16_ty=_a_vec16_type(), + elem_type=_a_elem_type(), atom_s16=atom_a_g2r16, layout_lds=layout_lds, row_local=row_a_local, col_local_i32=col_a_local_i32, tx_c4=c4, - k_blocks16=k_blocks16, + k_blocks16=a_k_blocks16, lds_base=lds_base, vec_part_i32x4=vec_a_parts[i], - elem_bytes=elem_bytes, + elem_bytes=a_elem_bytes, ) def prefetch_ab_tile(base_k): - base_k_bytes = base_k * arith.constant(int(elem_bytes), index=True) + base_k_bytes = base_k + + # div4 for Double word is 4 times of a byte base_k_div4 = base_k_bytes / 4 - a_regs = load_a_tile(base_k_div4) - b_regs = load_b_tile(base_k) + a_regs = load_a_tile(base_k_div4 * arith.constant(int(a_elem_bytes), index=True)) + b_regs = load_b_tile(base_k * arith.constant(int(b_elem_bytes), index=True)) return a_regs, b_regs def compute_tile(accs_in, b_tile_in, lds_base, *, is_last_tile=False, a0_prefetch=None): scales_pf = {} + + mfma_res_ty = _mfma_output_pack_ty() if is_last_tile and (not is_f16_or_bf16): # Prefetch scales (fp8/int8/int4 only). s_b_vals = [] @@ -532,125 +520,60 @@ def compute_tile(accs_in, b_tile_in, lds_base, *, is_last_tile=False, a0_prefetc scales_pf["s_a_vecs"].append(vector.bitcast(T.f32x4, s_a_vec)) current_accs_list = list(accs_in) + if is_fp4: + current_accs_list = block_mfma_block_scale_f8f6f4( + current_accs_list, + b_tile_in, + a_scale, + b_scale, + lds_base, + lds_load_packs_k64=lds_load_packs_k64, + col_offset_base_bytes=col_offset_base_bytes, + row_a_lds=row_a_lds, + mfma_fn=rocdl.mfma_scale_f32_16x16x128_f8f6f4, + mfma_res_ty=mfma_res_ty, + cbsz=0, + blgp=0, + k_unroll=k_unroll, + num_acc_n=num_acc_n, + m_repeat=m_repeat, + pack_K=2, + pack_M=2, + pack_N=2, + a0_prefetch=a0_prefetch, + ) - # ---------------- gfx95 fast path (K128 MFMA scale) ---------------- - # This is the key optimization from `zhimding/develop_0107` for FP8: - # use mfma.scale 16x16x128 to reduce instruction count in the hot loop. - # - # Notes: - # - Only valid for fp8 path (not int8/int4) and gfx95+ - # - Requires tile_k divisible by 128 - # - mfma.scale takes 9 operands: 3 vectors + 6 i32 flags/scales. - use_mfma_scale_128 = ( - str(gpu_arch).startswith("gfx95") - and (not is_int8) - and (not is_int4) - and (not is_f16_or_bf16) - ) if use_mfma_scale_128: - if (int(tile_k) % 128) != 0: - raise ValueError( - f"tile_k must be divisible by 128 for mfma_scale_x128, got tile_k={tile_k}" - ) - - mfma_res_ty = T.f32x4 - vec4_i64 = T.vec(4, T.i64) - vec8_i32 = T.vec(8, T.i32) - - def pack_i64x4_to_i32x8(x0, x1, x2, x3): - v4 = vector.from_elements(vec4_i64, [x0, x1, x2, x3]) - return vector.bitcast(vec8_i32, v4) - - for ku128 in range_constexpr(k_unroll // 2): - ku0 = ku128 * 2 - ku1 = ku0 + 1 - - b0_packs0, b0_packs1 = b_tile_in[ku0] - b1_packs0, b1_packs1 = b_tile_in[ku1] - - col_base0 = col_offset_base_bytes + (ku0 * 64) - col_base1 = col_offset_base_bytes + (ku1 * 64) - - for mi in range_constexpr(m_repeat): - mi_val = arith.constant(mi * 16, index=True) - curr_row_a_lds = row_a_lds + mi_val - - if (a0_prefetch is not None) and (ku0 == 0) and (mi == 0): - a0, a1 = a0_prefetch - else: - a0, a1 = lds_load_packs_k64(curr_row_a_lds, col_base0, lds_base) - a2, a3 = lds_load_packs_k64(curr_row_a_lds, col_base1, lds_base) - a128 = pack_i64x4_to_i32x8(a0, a1, a2, a3) - - for ni in range_constexpr(num_acc_n): - b0 = b0_packs0[ni] - b1 = b0_packs1[ni] - b2 = b1_packs0[ni] - b3 = b1_packs1[ni] - b128 = pack_i64x4_to_i32x8(b0, b1, b2, b3) - - acc_idx = mi * num_acc_n + ni - current_accs_list[acc_idx] = rocdl.mfma_scale_f32_16x16x128_f8f6f4( - mfma_res_ty, - [ - a128, - b128, - current_accs_list[acc_idx], - # cbsz, abid, blgp: 0 - 0, - 0, - 0, - # op_sel_a + scale_a (1.0f as i32 bits) - 0x3F800000, - # op_sel_b + scale_b (1.0f as i32 bits) - 0, - 0x3F800000, - ], - ) + current_accs_list = block_mfma_PTPC_f8f6f4( + current_accs_list, + b_tile_in, + lds_base, + col_offset_base_bytes=col_offset_base_bytes, + row_a_lds=row_a_lds, + lds_load_packs_k64=lds_load_packs_k64, + mfma_fn=rocdl.mfma_scale_f32_16x16x128_f8f6f4, + mfma_res_ty=mfma_res_ty, + k_unroll=k_unroll, + num_acc_n=num_acc_n, + m_repeat=m_repeat, + a0_prefetch=a0_prefetch, + ) return current_accs_list, scales_pf - mfma_res_ty = T.i32x4 if is_int8 else T.f32x4 - - if is_int8: - mfma_fn = mfma_i32_k32 - elif is_f16: - # gfx942 fp16 MFMA: 16x16x16 f16 (operands are v4f16, 8B packs) - mfma_fn = rocdl.mfma_f32_16x16x16f16 - elif is_bf16: - # bf16 MFMA K16 variant uses i16 bit-pattern packs (v4i16) - mfma_fn = rocdl.mfma_f32_16x16x16bf16_1k - else: - mfma_fn = rocdl.mfma_f32_16x16x32_fp8_fp8 - - def mfma_step(acc_in, a, b): - return mfma_fn(mfma_res_ty, [a, b, acc_in, 0, 0, 0]) - - # "K64-byte wrapper": two back-to-back MFMA/WMMA ops using the two 8B halves. - def mfma_k64_bytes(acc_in, a0, a1, b0, b1): - acc_mid = mfma_step(acc_in, a0, b0) - return mfma_step(acc_mid, a1, b1) - - for ku in range_constexpr(k_unroll): - b_packs0, b_packs1 = b_tile_in[ku] - # Byte-addressed K stepping (64B per ku). - ki64 = ku * 64 - col_base = col_offset_base_bytes + ki64 - for mi in range_constexpr(m_repeat): - mi_val = arith.constant(mi * 16, index=True) - curr_row_a_lds = row_a_lds + mi_val - if (a0_prefetch is not None) and (ku == 0) and (mi == 0): - a0, a1 = a0_prefetch - else: - a0, a1 = lds_load_packs_k64(curr_row_a_lds, col_base, lds_base) - for ni in range_constexpr(num_acc_n): - acc_idx = mi * num_acc_n + ni - current_accs_list[acc_idx] = mfma_k64_bytes( - current_accs_list[acc_idx], - a0, - a1, - b_packs0[ni], - b_packs1[ni], - ) + current_accs_list = block_mfma_16x16( + current_accs_list, + b_tile_in, + lds_base, + col_offset_base_bytes=col_offset_base_bytes, + row_a_lds=row_a_lds, + lds_load_packs_k64=lds_load_packs_k64, + mfma_fn=mfma_fn, + mfma_res_ty=mfma_res_ty, + k_unroll=k_unroll, + num_acc_n=num_acc_n, + m_repeat=m_repeat, + a0_prefetch=a0_prefetch, + ) return current_accs_list, scales_pf vec1_f16 = ir.VectorType.get([1], ir.F16Type.get()) @@ -1062,11 +985,11 @@ def prefetch_a0_pack(lds_base): @flir.jit def __call__( self: flir.T.i64, - arg_c: lambda: memref(DYN, T.f16), - arg_a: lambda: memref(DYN, _elem_type()), - arg_b: lambda: memref(DYN, _elem_type()), - arg_scale_a: lambda: memref(DYN, T.f32), - arg_scale_b: lambda: memref(DYN, T.f32), + arg_c: lambda: memref(DYN, _out_elem_type()), + arg_a: lambda: memref(DYN, _a_elem_type()), + arg_b: lambda: memref(DYN, _b_elem_type()), + arg_scale_a: lambda: memref(DYN, _scale_elem_type()), + arg_scale_b: lambda: memref(DYN, _scale_elem_type()), c_m: lambda: T.index, c_n: lambda: T.index, c_k: lambda: T.index, @@ -1103,5 +1026,4 @@ def __call__( ) -__all__ = ["compile_preshuffle_gemm_a8"] - +__all__ = ["compile_preshuffle_gemm"] \ No newline at end of file diff --git a/tests/kernels/test_preshuffle_gemm.py b/tests/kernels/test_preshuffle_gemm.py index 5e0587d0..256d7d1b 100644 --- a/tests/kernels/test_preshuffle_gemm.py +++ b/tests/kernels/test_preshuffle_gemm.py @@ -28,7 +28,7 @@ if _PYFLIR_SRC not in sys.path: sys.path.insert(0, _PYFLIR_SRC) -from kernels.preshuffle_gemm import compile_preshuffle_gemm_a8 +from kernels.preshuffle_gemm import compile_preshuffle_gemm from kernels.mixed_preshuffle_gemm import compile_mxfp4_preshuffle_gemm from tests.test_common import run_perftest, verify_output from tests.utils import pertoken_quant, shuffle_weight @@ -130,14 +130,16 @@ def test_mfma_a8_flir_preshuffle( raise ValueError( f"lds_stage must be 1 or 2, got {lds_stage!r}" ) - exe = compile_preshuffle_gemm_a8( + exe = compile_preshuffle_gemm( M=M, N=N, K=K, tile_m=tile_m, tile_n=tile_n, tile_k=tile_k, - in_dtype=in_dtype, + a_dtype=in_dtype, + b_dtype=in_dtype, + out_dtype="f16", lds_stage=lds_stage, use_cshuffle_epilog=bool(use_cshuffle_epilog), ) From 0e5891225010f9b954674454ad917d91f943ddd7 Mon Sep 17 00:00:00 2001 From: zanzhang Date: Wed, 4 Feb 2026 01:26:24 +0800 Subject: [PATCH 02/11] 308 int8 a8w4 bf16 fp8 ready --- kernels/mfma_preshuffle_pipeline.py | 105 +++++++--- kernels/moe_gemm_2stage.py | 2 + kernels/preshuffle_gemm.py | 287 +++++++++++++++++--------- tests/kernels/test_preshuffle_gemm.py | 45 ++-- 4 files changed, 298 insertions(+), 141 deletions(-) diff --git a/kernels/mfma_preshuffle_pipeline.py b/kernels/mfma_preshuffle_pipeline.py index 0c35b5ce..d11b767d 100644 --- a/kernels/mfma_preshuffle_pipeline.py +++ b/kernels/mfma_preshuffle_pipeline.py @@ -12,6 +12,8 @@ from __future__ import annotations from dataclasses import dataclass +import re +from flydsl.dialects.ext.python_control_flow import range_constexpr from flydsl.dialects.ext import arith, gpu, buffer_ops, vector, rocdl from flydsl.lang.ir.types import T, memref @@ -20,8 +22,9 @@ from enum import Enum class MfmaPipeline(Enum): + F4F4_MXFP4_PIPELINE = "F4F4_MXFP4_PIPELINE" F8F4_MXFP4_PIPELINE = "F8F4_MXFP4_PIPELINE" - F8F8_MXFP4_PIPELINE = "F8F8_MXFP4_PIPELINE" + F8F8_16x16_PIPELINE = "F8F8_16x16_PIPELINE" F16F16_16x16_PIPELINE = "F16F16_16x16_PIPELINE" BF16BF16_16x16_PIPELINE = "BF16BF16_16x16_PIPELINE" I8I8_16x16_PIPELINE = "I8I8_16x16_PIPELINE" @@ -36,8 +39,9 @@ class EpilogPipeline(Enum): DIRECT_F32 = "DIRECT_F32" a_elem_type_dict = { - MfmaPipeline.F8F4_MXFP4_PIPELINE: lambda: T.f8, - MfmaPipeline.F8F8_MXFP4_PIPELINE: lambda: T.f8, + MfmaPipeline.F4F4_MXFP4_PIPELINE: lambda: T.ui8, + MfmaPipeline.F8F4_MXFP4_PIPELINE: lambda: T.ui8, + MfmaPipeline.F8F8_16x16_PIPELINE: lambda: T.f8, MfmaPipeline.F16F16_16x16_PIPELINE: lambda: T.f16, MfmaPipeline.BF16BF16_16x16_PIPELINE: lambda: T.bf16, MfmaPipeline.I8I8_16x16_PIPELINE: lambda: T.i8, @@ -45,17 +49,19 @@ class EpilogPipeline(Enum): } b_elem_type_dict = { + MfmaPipeline.F4F4_MXFP4_PIPELINE: lambda: T.ui8, MfmaPipeline.F8F4_MXFP4_PIPELINE: lambda: T.ui8, - MfmaPipeline.F8F8_MXFP4_PIPELINE: lambda: T.f8, + MfmaPipeline.F8F8_16x16_PIPELINE: lambda: T.f8, MfmaPipeline.F16F16_16x16_PIPELINE: lambda: T.f16, MfmaPipeline.BF16BF16_16x16_PIPELINE: lambda: T.bf16, MfmaPipeline.I8I8_16x16_PIPELINE: lambda: T.i8, - MfmaPipeline.I8I4_16x16_PIPELINE: lambda: T.f8, + MfmaPipeline.I8I4_16x16_PIPELINE: lambda: T.i8, } scale_elem_type_dict = { + MfmaPipeline.F4F4_MXFP4_PIPELINE: lambda: T.i32, MfmaPipeline.F8F4_MXFP4_PIPELINE: lambda: T.i32, - MfmaPipeline.F8F8_MXFP4_PIPELINE: lambda: T.i32, + MfmaPipeline.F8F8_16x16_PIPELINE: lambda: T.f32, MfmaPipeline.I8I8_16x16_PIPELINE: lambda: T.f32, MfmaPipeline.I8I4_16x16_PIPELINE: lambda: T.f32, # bf16 scale placeholder @@ -73,26 +79,29 @@ class EpilogPipeline(Enum): } a_vec16_type_dict = { + MfmaPipeline.F4F4_MXFP4_PIPELINE: lambda: T.ui8x16, MfmaPipeline.F8F4_MXFP4_PIPELINE: lambda: T.f8x16, - MfmaPipeline.F8F8_MXFP4_PIPELINE: lambda:T.f8x16, - MfmaPipeline.F16F16_16x16_PIPELINE: lambda:T.f16x8, - MfmaPipeline.BF16BF16_16x16_PIPELINE: lambda:T.bf16x8, - MfmaPipeline.I8I8_16x16_PIPELINE: lambda:T.i8x16, - MfmaPipeline.I8I4_16x16_PIPELINE: lambda:T.i8x16, + MfmaPipeline.F8F8_16x16_PIPELINE: lambda: T.f8x16, + MfmaPipeline.F16F16_16x16_PIPELINE: lambda: T.f16x8, + MfmaPipeline.BF16BF16_16x16_PIPELINE: lambda: T.bf16x8, + MfmaPipeline.I8I8_16x16_PIPELINE: lambda: T.i8x16, + MfmaPipeline.I8I4_16x16_PIPELINE: lambda: T.i8x16, } b_vec16_type_dict = { + MfmaPipeline.F4F4_MXFP4_PIPELINE: lambda: T.ui8x16, MfmaPipeline.F8F4_MXFP4_PIPELINE: lambda: T.ui8x16, - MfmaPipeline.F8F8_MXFP4_PIPELINE: lambda: T.f8x16, + MfmaPipeline.F8F8_16x16_PIPELINE: lambda: T.f8x16, MfmaPipeline.F16F16_16x16_PIPELINE: lambda: T.f16x8, MfmaPipeline.BF16BF16_16x16_PIPELINE: lambda: T.bf16x8, MfmaPipeline.I8I8_16x16_PIPELINE: lambda: T.i8x16, - MfmaPipeline.I8I4_16x16_PIPELINE: lambda: T.f8x16, + MfmaPipeline.I8I4_16x16_PIPELINE: lambda: T.i8x16, } mfma_input_pack_ty_dict = { + MfmaPipeline.F4F4_MXFP4_PIPELINE: lambda: T.i64, MfmaPipeline.F8F4_MXFP4_PIPELINE: lambda: T.i64, - MfmaPipeline.F8F8_MXFP4_PIPELINE: lambda: T.i64, + MfmaPipeline.F8F8_16x16_PIPELINE: lambda: T.i64, MfmaPipeline.F16F16_16x16_PIPELINE: lambda: T.f16x4, MfmaPipeline.BF16BF16_16x16_PIPELINE: lambda: T.i16x4, MfmaPipeline.I8I8_16x16_PIPELINE: lambda: T.i32x4, @@ -100,14 +109,26 @@ class EpilogPipeline(Enum): } mfma_output_pack_ty_dict = { + MfmaPipeline.F4F4_MXFP4_PIPELINE: lambda: T.f32x4, MfmaPipeline.F8F4_MXFP4_PIPELINE: lambda: T.f32x4, - MfmaPipeline.F8F8_MXFP4_PIPELINE: lambda: T.f32x4, + MfmaPipeline.F8F8_16x16_PIPELINE: lambda: T.f32x4, MfmaPipeline.F16F16_16x16_PIPELINE: lambda: T.f32x4, MfmaPipeline.BF16BF16_16x16_PIPELINE: lambda: T.f32x4, MfmaPipeline.I8I8_16x16_PIPELINE: lambda: T.i32x4, MfmaPipeline.I8I4_16x16_PIPELINE: lambda: T.i32x4, } +mfma_pipeline_dicts = { + "a_elem_type": a_elem_type_dict, + "b_elem_type": b_elem_type_dict, + "scale_elem_type": scale_elem_type_dict, + "out_elem_type": out_elem_type_dict, + "a_vec16_type": a_vec16_type_dict, + "b_vec16_type": b_vec16_type_dict, + "mfma_input_pack_ty": mfma_input_pack_ty_dict, + "mfma_output_pack_ty": mfma_output_pack_ty_dict, +} + def get_mfma_i32_k32(): mfma_i32_k32 = getattr(rocdl, "mfma_i32_16x16x32i8", None) or getattr( rocdl, "mfma_i32_16x16x32_i8", None @@ -133,6 +154,9 @@ def __init__( self.a_dtype = a_dtype self.b_dtype = b_dtype self.out_dtype = out_dtype + self.refine_dtype() + self.check_type_valid() + self.use_cshuffle_epilog = use_cshuffle_epilog self.a_packed = self.a_dtype in ["fp4"] self.b_packed = self.b_dtype in ["fp4", "int4"] @@ -146,12 +170,30 @@ def __init__( self.block_size = block_size def refine_dtype(self): - + def _normalize_dtype(value: str) -> str: + s = str(value).strip().lower() + s = re.sub(r"^(f16|float16|half)$", "fp16", s) + s = re.sub(r"^(bf16|bfloat16)$", "bf16", s) + s = re.sub(r"^(f32|fp32|float|float32)$", "f32", s) + s = re.sub(r"^(fp8|f8)$", "fp8", s) + s = re.sub(r"^(fp4|f4)$", "fp4", s) + s = re.sub(r"^(int8|i8)$", "int8", s) + s = re.sub(r"^(int4|i4)$", "int4", s) + return s + + self.a_dtype = _normalize_dtype(self.a_dtype) + self.b_dtype = _normalize_dtype(self.b_dtype) + self.out_dtype = _normalize_dtype(self.out_dtype) + + if self.out_dtype not in ("fp16", "bf16", "f32"): + raise ValueError( + f"out_dtype must be 'f16', 'bf16', or 'f32', got {self.out_dtype!r}" + ) def check_type_valid(self): - if self.a_dtype not in ["fp8", "int8", "int4", "fp16", "bf16"]: + if self.a_dtype not in ["fp8", "fp4", "int8", "fp16", "bf16"]: raise ValueError(f"Invalid a_dtype: {self.a_dtype}") - if self.b_dtype not in ["fp8", "int8", "int4", "fp16", "bf16"]: + if self.b_dtype not in ["fp8", "fp4", "int8", "int4", "fp16", "bf16"]: raise ValueError(f"Invalid b_dtype: {self.b_dtype}") if self.out_dtype not in ["fp16", "bf16", "f32"]: raise ValueError(f"Invalid out_dtype: {self.out_dtype}") @@ -162,7 +204,7 @@ def get_mfma_pipeline(self): elif self.a_dtype == "fp8" and self.b_dtype == "fp4": return MfmaPipeline.F8F4_MXFP4_PIPELINE elif self.a_dtype == "fp8" and self.b_dtype == "fp8": - return MfmaPipeline.F8F8_MXFP4_PIPELINE + return MfmaPipeline.F8F8_16x16_PIPELINE elif self.a_dtype == "fp16" and self.b_dtype == "fp16": return MfmaPipeline.F16F16_16x16_PIPELINE elif self.a_dtype == "bf16" and self.b_dtype == "bf16": @@ -183,7 +225,7 @@ def get_epilog_pipeline(self): return EpilogPipeline.CSHUFFLE_F32 elif not self.use_cshuffle_epilog and self.out_dtype == "f32": return EpilogPipeline.DIRECT_F32 - elif not self.use_cshuffle_epilog and self.out_dtype == "f16": + elif not self.use_cshuffle_epilog and self.out_dtype == "fp16": return EpilogPipeline.DIRECT_F16 elif not self.use_cshuffle_epilog and self.out_dtype == "bf16": return EpilogPipeline.DIRECT_BF16 @@ -191,7 +233,7 @@ def get_epilog_pipeline(self): raise ValueError(f"Invalid epilog pipeline: {self.out_dtype}") def get_b_elem_bytes(self): - if self.b_dtype in ["fp8", "int8", "int4"]: + if self.b_dtype in ["fp8", "int8", "int4", "fp4"]: return 1 elif self.b_dtype in ["fp16", "bf16"]: return 2 @@ -199,7 +241,7 @@ def get_b_elem_bytes(self): raise ValueError(f"Invalid b_dtype: {self.b_dtype}") def get_a_elem_bytes(self): - if self.a_dtype in ["fp8", "int8", "int4"]: + if self.a_dtype in ["fp8", "int8", "int4", "fp4"]: return 1 elif self.a_dtype in ["fp16", "bf16"]: return 2 @@ -215,8 +257,8 @@ def get_out_elem_bytes(self): raise ValueError(f"Invalid out_dtype: {self.out_dtype}") def get_mfma_fn(self): - if self.mfma_pipeline == MfmaPipeline.F8F6F4_PIPELINE: - return rocdl.mfma_f32_16x16x16f16 + if self.mfma_pipeline == MfmaPipeline.F8F8_16x16_PIPELINE: + return rocdl.mfma_f32_16x16x32_fp8_fp8 elif self.mfma_pipeline == MfmaPipeline.BF16BF16_16x16_PIPELINE: return rocdl.mfma_f32_16x16x16bf16_1k elif self.mfma_pipeline == MfmaPipeline.F16F16_16x16_PIPELINE: @@ -252,6 +294,7 @@ def get_a_bytes_per_thread( +@dataclass class PreshuffleBLayout: """Container returned by `make_preshuffle_b_layout`.""" @@ -900,12 +943,12 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): b128 = pack_i64x4_to_i32x8(b0, b1, b2, b3) acc_idx = mi * num_acc_n + ni - current_accs_list[acc_idx] = rocdl.mfma_scale_f32_16x16x128_f8f6f4( + accs_in[acc_idx] = rocdl.mfma_scale_f32_16x16x128_f8f6f4( mfma_res_ty, [ a128, b128, - current_accs_list[acc_idx], + accs_in[acc_idx], # cbsz, abid, blgp: 0 0, 0, @@ -917,7 +960,7 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): 0x3F800000, ], ) - return current_accs_list, scales_pf + return accs_in def block_mfma_16x16( @@ -927,8 +970,8 @@ def block_mfma_16x16( col_offset_base_bytes, row_a_lds, lds_load_packs_k64, - mfma_fn, *, + mfma_fn, mfma_res_ty, k_unroll, num_acc_n, @@ -957,18 +1000,20 @@ def mfma_k64_bytes(acc_in, a0, a1, b0, b1): a0, a1 = lds_load_packs_k64(curr_row_a_lds, col_base, lds_base) for ni in range_constexpr(num_acc_n): acc_idx = mi * num_acc_n + ni - current_accs_list[acc_idx] = mfma_k64_bytes( - current_accs_list[acc_idx], + accs_in[acc_idx] = mfma_k64_bytes( + accs_in[acc_idx], a0, a1, b_packs0[ni], b_packs1[ni], ) + return accs_in __all__ = [ "PreshuffleBLayout", "MfmaPipeline", "EpilogPipeline", + "mfma_pipeline_dicts", "PreshufflePipelineManager", "buffer_copy_gmem16_dwordx4", "lds_load_pack_k32", diff --git a/kernels/moe_gemm_2stage.py b/kernels/moe_gemm_2stage.py index 25e16448..9f7f0ba0 100644 --- a/kernels/moe_gemm_2stage.py +++ b/kernels/moe_gemm_2stage.py @@ -25,6 +25,8 @@ from flydsl.dialects.ext import arith, gpu, buffer_ops, llvm, vector, rocdl, scf, memref from kernels.mfma_preshuffle_pipeline import ( + MfmaPipeline, + EpilogPipeline, buffer_copy_gmem16_dwordx4, lds_load_pack_k32, lds_store_4b_xor16, diff --git a/kernels/preshuffle_gemm.py b/kernels/preshuffle_gemm.py index a52c0dd9..6280dbca 100644 --- a/kernels/preshuffle_gemm.py +++ b/kernels/preshuffle_gemm.py @@ -11,8 +11,6 @@ - `ck_v1_single_lds`: CK-like Intrawave + bpreshuffle v1 spirit (single LDS buffer for A) """ -import os - import flydsl from flydsl.dialects.ext import flir from flydsl.dialects.ext.python_control_flow import range_constexpr @@ -26,11 +24,14 @@ from kernels.mfma_preshuffle_pipeline import ( buffer_copy_gmem16_dwordx4, - lds_load_pack_k32, lds_store_16b_xor16, make_preshuffle_b_layout, + make_preshuffle_scale_layout, load_b_pack_k32, tile_chunk_coord_i32, + MfmaPipeline, + EpilogPipeline, + mfma_pipeline_dicts, PreshufflePipelineManager, block_mfma_block_scale_f8f6f4, block_mfma_PTPC_f8f6f4, @@ -49,7 +50,7 @@ def compile_preshuffle_gemm( tile_k: int, a_dtype: str = "fp8", b_dtype: str = "fp8", - out_dtype: str = "f16", + out_dtype: str = "fp16", lds_stage: int = 2, # Epilogue options use_cshuffle_epilog: bool = False, @@ -62,7 +63,7 @@ def compile_preshuffle_gemm( a_dtype: - "fp8": A/B are fp8 (1B/elem) - "int8": A/B are int8 (1B/elem) - - "int4": W4A8 path: A is int8 (1B/elem). + - "a8w4": W4A8 path: A is int8 (1B/elem). b_dtype: - "fp8": A/B are fp8 (1B/elem) - "int8": A/B are int8 (1B/elem) @@ -75,6 +76,9 @@ def compile_preshuffle_gemm( - 2: ping-pong LDS for A (2 LDS buffers), tuned schedule (original). - 1: single LDS buffer for A . """ + + total_threads = 256 + pipeline_manager = PreshufflePipelineManager(a_dtype, b_dtype, out_dtype) pipeline_manager.check_type_valid() @@ -115,32 +119,37 @@ def compile_preshuffle_gemm( # CK-style LDS128: stride is in BYTES along K (for XOR16 swizzle). lds_stride_bytes = a_tile_k_bytes // a_elem_pack + def _get_mfma_dict_value(key, pipeline): + value = mfma_pipeline_dicts[key][pipeline] + return value() if callable(value) else value + def _a_elem_type(): - return a_elem_type_dict[mfma_pipeline] + return _get_mfma_dict_value("a_elem_type", mfma_pipeline) def _b_elem_type(): - return b_elem_type_dict[mfma_pipeline] + return _get_mfma_dict_value("b_elem_type", mfma_pipeline) def _scale_elem_type(): - return scale_elem_type_dict[mfma_pipeline] + return _get_mfma_dict_value("scale_elem_type", mfma_pipeline) def _out_elem_type(): - return out_elem_type_dict[epilog_pipeline] + return _get_mfma_dict_value("out_elem_type", epilog_pipeline) def _a_vec16_type(): - return a_vec16_type_dict[mfma_pipeline] + return _get_mfma_dict_value("a_vec16_type", mfma_pipeline) def _b_vec16_type(): - return b_vec16_type_dict[mfma_pipeline] + return _get_mfma_dict_value("b_vec16_type", mfma_pipeline) def _mfma_input_pack_ty(): - return mfma_input_pack_ty_dict[mfma_pipeline] + return _get_mfma_dict_value("mfma_input_pack_ty", mfma_pipeline) def _mfma_output_pack_ty(): - return mfma_output_pack_ty_dict[mfma_pipeline] + return _get_mfma_dict_value("mfma_output_pack_ty", mfma_pipeline) is_f16_or_bf16 = mfma_pipeline in [MfmaPipeline.F16F16_16x16_PIPELINE, MfmaPipeline.BF16BF16_16x16_PIPELINE] is_int4 = mfma_pipeline in [MfmaPipeline.I8I4_16x16_PIPELINE] is_int8 = mfma_pipeline in [MfmaPipeline.I8I8_16x16_PIPELINE] - is_fp4 = mfma_pipeline in [MfmaPipeline.FP4FP4_16x16_PIPELINE, - MfmaPipeline.FP8FP4_16x16_PIPELINE] + is_fp4 = mfma_pipeline in [MfmaPipeline.F4F4_MXFP4_PIPELINE, + MfmaPipeline.F8F4_MXFP4_PIPELINE] - cbsz = 0 if mfma_pipeline == MfmaPipeline.FP4FP4_16x16_PIPELINE else 4 - blgp = 4 if mfma_pipeline in [MfmaPipeline.FP8FP4_16x16_PIPELINE, MfmaPipeline.FP8FP8_16x16_PIPELINE] else 0 + # 350 16x16x128 adtype(cbsz) & bdtype(blgp) + cbsz = 4 if mfma_pipeline == MfmaPipeline.F4F4_MXFP4_PIPELINE else 0 + blgp = 4 if mfma_pipeline in [MfmaPipeline.F4F4_MXFP4_PIPELINE, MfmaPipeline.F8F4_MXFP4_PIPELINE] else 0 pack_M = 2 pack_N = 2 @@ -175,8 +184,8 @@ def init_gpu_module(self): lds_total_bytes = max(lds_a_bytes, lds_out_bytes) # Keep LDS allocation sized in bytes: element_size * num_elems. # Allocate element type == _elem_type() and scale element count accordingly. - lds_total_elems = lds_total_bytes if elem_bytes == 1 else (lds_total_bytes // 2) - _state["lds_a_decl"] = allocator.allocate_array(_elem_type(), lds_total_elems) + a_lds_total_elems = lds_total_bytes // a_elem_bytes + _state["lds_a_decl"] = allocator.allocate_array(_a_elem_type(), a_lds_total_elems) allocator.finalize() @flir.kernel @@ -194,9 +203,7 @@ def kernel_gemm( # ---- Types ---- # NOTE: Some environments have multiple `flydsl` builds on PYTHONPATH. # Use explicit MLIR Values (not Python ints / wrapper objects) for ROCDL ops. - acc_init = arith.unwrap( - arith.constant_vector(0, _mfma_output_pack_ty()) - ) + acc_init = arith.unwrap(arith.constant_vector(0, _mfma_output_pack_ty())) # Layouts layout_c = flir.make_layout((c_m, c_n), stride=(c_n, 1)) @@ -207,11 +214,15 @@ def kernel_gemm( layout_a_div4 = flir.make_layout((c_m, c_k_div4bytes), stride=(c_k_div4bytes, 1)) # B preshuffle layout (shared with MoE kernels). - b_kpack_bytes = 16 / b_elem_pack + b_kpack_bytes = 16 // b_elem_pack layout_b = make_preshuffle_b_layout( flir, arith, c_n=c_n, c_k=c_k, kpack_bytes=b_kpack_bytes, elem_bytes=b_elem_bytes ).layout_b + # Scale layouts for FP4/MXFP4 (block-scale MFMA). + layout_a_scale = make_preshuffle_scale_layout(flir, arith, c_mn=c_m, c_k=c_k) if is_fp4 else None + layout_b_scale = make_preshuffle_scale_layout(flir, arith, c_mn=c_n, c_k=c_k) if is_fp4 else None + # LDS layout is element-indexed, but XOR16 swizzle is byte-based. # Represent LDS as (tile_m, tile_k) in elements and scale swizzle math by elem_bytes. shape_lds = flir.make_shape(tile_m, tile_k) @@ -299,6 +310,72 @@ def kernel_gemm( n_blk_list.append(flir.get(coord_n, 0)) n_intra_list.append(flir.get(coord_n, 1)) + # FP4/MXFP4 pack parameters for block-scale MFMA. + k_unroll_packed = k_unroll // pack_K if is_fp4 else 0 + m_repeat_packed = m_repeat // pack_M if is_fp4 else 0 + num_acc_n_packed = num_acc_n // pack_N if is_fp4 else 0 + + # --- Scale load logic for FP4/MXFP4 --- + def load_scale(arg_scale, rsrc, layout, ku, mni): + """Load a single scale value for FP4/MXFP4 block-scale MFMA.""" + k_lane = lane_div_16 + n_lane = lane_mod_16 + coord_pack = flir.make_coord(mni, ku, k_lane, n_lane) + idx_pack = flir.crd2idx(coord_pack, layout) + scale_view = flir.TensorView( + arg_scale, + (1,), + strides=(1,), + base_indices=(idx_pack,), + element_type=_scale_elem_type(), + ) + scale = flir.copy( + flir.make_copy_atom(_scale_elem_type(), vector_size=1), + scale_view, + None, + alignment=8, + return_vector=True, + src_buffer_resource=rsrc, + src_buffer_offset_in_bytes=False, + ) + return scale + + def load_b_scale_tile(base_k): + """Load B scale tile for FP4/MXFP4.""" + b_scale_tile = [] + for ku in range_constexpr(k_unroll_packed): + for ni in range_constexpr(num_acc_n_packed): + scale = load_scale( + arg_scale_b, + scale_b_rsrc, + layout_b_scale, + ku + base_k, + ni + (by_n + n_tile_base) // pack_N // 16, + ) + b_scale_tile.append(scale) + return b_scale_tile + + def load_a_scale_tile(base_k): + """Load A scale tile for FP4/MXFP4.""" + a_scale_tile = [] + for ku in range_constexpr(k_unroll_packed): + for mi in range_constexpr(m_repeat_packed): + scale = load_scale( + arg_scale_a, + scale_a_rsrc, + layout_a_scale, + ku + base_k, + mi + bx_m // pack_M // 16, + ) + a_scale_tile.append(scale) + return a_scale_tile + + def prefetch_ab_scale_tile(base_k): + """Prefetch A and B scale tiles for FP4/MXFP4.""" + if not is_fp4: + return None, None + return load_a_scale_tile(base_k), load_b_scale_tile(base_k) + # --- B load logic --- # Shared loader supports: # - FP8/INT8: explicit 16B load (one full KPack) + extract 8B for this micro-step @@ -317,15 +394,15 @@ def load_b_pack(base_k, ki_step, ni): n_blk=n_blk_list[ni], n_intra=n_intra_list[ni], lane_div_16=lane_div_16, - elem_type=_elem_type(), + elem_type=_b_elem_type(), kpack_bytes=b_kpack_bytes, - elem_bytes=elem_bytes, + elem_bytes=b_elem_bytes, unpack_int4=is_int4, ) # For FP8/INT8 we can load one 16B pack and extract both 8B halves (K64 bytes). # For INT4 (packed), reuse the existing K32 loader twice (2x4B loads + unpack). - atom_b_g2r16 = flir.make_copy_atom(_elem_type(), vector_size=16) + atom_b_g2r16 = flir.make_copy_atom(_b_elem_type(), vector_size=16) c64_b = 64 c0_idx = 0 @@ -349,16 +426,16 @@ def load_b_packs_k64(base_k, ku: int, ni: int): (vec_elems,), strides=(1,), base_indices=(idx_pack,), - element_type=_elem_type(), + element_type=_b_elem_type(), ) b16 = flir.copy( - flir.make_copy_atom(_elem_type(), vector_size=vec_elems), + flir.make_copy_atom(_b_elem_type(), vector_size=vec_elems), b_view, None, alignment=8, return_vector=True, - src_buffer_resource=(b_rsrc if elem_bytes == 1 else None), - src_buffer_offset_in_bytes=(elem_bytes == 1), + src_buffer_resource=(b_rsrc if b_elem_bytes == 1 else None), + src_buffer_offset_in_bytes=(b_elem_bytes == 1), ) # Split 16B pack into two 8B halves. @@ -393,11 +470,11 @@ def load_b_tile(base_k): def lds_load_16b(curr_row_a_lds, col_base, lds_base): # Swizzle in bytes, then convert to element offset for memref indexing. col_base_swz_bytes = flir.swizzle_xor16(curr_row_a_lds, col_base, a_k_blocks16) - col_base_swz = col_base_swz_bytes if elem_bytes == 1 else (col_base_swz_bytes / 2) + col_base_swz = col_base_swz_bytes if a_elem_bytes == 1 else (col_base_swz_bytes / 2) coord_a16 = flir.make_coord(curr_row_a_lds, col_base_swz) idx_a16 = flir.crd2idx(coord_a16, layout_lds) idx_a16 = idx_a16 + lds_base - return vector.load_op(_vec16_type(), lds_a, [idx_a16]) + return vector.load_op(_a_vec16_type(), lds_a, [idx_a16]) # --- A LDS load helper for K64-bytes (load 16B once, extract 2x i64 halves) --- def lds_load_packs_k64(curr_row_a_lds, col_base, lds_base): @@ -415,14 +492,12 @@ def lds_load_packs_k64(curr_row_a_lds, col_base, lds_base): return vector.bitcast(_mfma_input_pack_ty(), a0_v1), vector.bitcast(_mfma_input_pack_ty(), a1_v1) # --- A load/store (16B chunks), XOR16 swizzle --- - num_a_loads = bytes_per_thread_a // a_load_bytes + a_load_bytes = 16 + num_a_loads = a_bytes_per_thread // a_load_bytes # A tile mapping in dwords along K: # tile_k_dwords = (tile_k * elem_bytes) / 4 - if elem_bytes == 2: - tile_k_dwords = (tile_k * 2) // 4 - else: - tile_k_dwords = tile_k // 4 - layout_a_tile_div4 = flir.make_layout((tile_m, tile_k_dwords), stride=(tile_k_dwords, 1)) + a_tile_k_dwords = tile_k * a_elem_bytes // 4 + layout_a_tile_div4 = flir.make_layout((tile_m, a_tile_k_dwords), stride=(a_tile_k_dwords, 1)) c4 = arith.constant(4, index=True) tx_i32_base = tx * c4 atom_a_g2r16 = flir.make_copy_atom(_a_elem_type(), vector_size=a_kpack_elems) @@ -486,20 +561,19 @@ def store_a_tile_to_lds(vec_a_parts, lds_base): ) def prefetch_ab_tile(base_k): - base_k_bytes = base_k - - # div4 for Double word is 4 times of a byte + # Convert element index to byte index, then to dword index. + base_k_bytes = base_k * arith.constant(int(a_elem_bytes), index=True) base_k_div4 = base_k_bytes / 4 - a_regs = load_a_tile(base_k_div4 * arith.constant(int(a_elem_bytes), index=True)) - b_regs = load_b_tile(base_k * arith.constant(int(b_elem_bytes), index=True)) + a_regs = load_a_tile(base_k_div4) + b_regs = load_b_tile(base_k) return a_regs, b_regs - def compute_tile(accs_in, b_tile_in, lds_base, *, is_last_tile=False, a0_prefetch=None): + def compute_tile(accs_in, b_tile_in, lds_base, *, is_last_tile=False, a0_prefetch=None, a_scale=None, b_scale=None): scales_pf = {} mfma_res_ty = _mfma_output_pack_ty() - if is_last_tile and (not is_f16_or_bf16): - # Prefetch scales (fp8/int8/int4 only). + if is_last_tile and (not is_f16_or_bf16) and (not is_fp4): + # Prefetch scales for non-FP4 scaled paths (fp8/int8/int4 with per-tensor scale). s_b_vals = [] for ni in range_constexpr(num_acc_n): offset = ni * 16 @@ -521,6 +595,7 @@ def compute_tile(accs_in, b_tile_in, lds_base, *, is_last_tile=False, a0_prefetc current_accs_list = list(accs_in) if is_fp4: + # FP4/MXFP4 path: use block-scale MFMA with per-block scales. current_accs_list = block_mfma_block_scale_f8f6f4( current_accs_list, b_tile_in, @@ -532,16 +607,18 @@ def compute_tile(accs_in, b_tile_in, lds_base, *, is_last_tile=False, a0_prefetc row_a_lds=row_a_lds, mfma_fn=rocdl.mfma_scale_f32_16x16x128_f8f6f4, mfma_res_ty=mfma_res_ty, - cbsz=0, - blgp=0, + cbsz=cbsz, + blgp=blgp, + a_elem_vec_pack=a_elem_pack, k_unroll=k_unroll, - num_acc_n=num_acc_n, m_repeat=m_repeat, - pack_K=2, - pack_M=2, - pack_N=2, + num_acc=num_acc_n, + pack_K=pack_K, + pack_M=pack_M, + pack_N=pack_N, a0_prefetch=a0_prefetch, ) + return current_accs_list, scales_pf if use_mfma_scale_128: current_accs_list = block_mfma_PTPC_f8f6f4( @@ -636,13 +713,14 @@ def write_row_to_lds( acc_idx = mi * num_acc_n + ni acc = final_accs[acc_idx] val = vector.extract(acc, static_position=[ii], dynamic_position=[]) - if is_int8: + # if is_int8: + if is_int8 or is_int4: val = arith.sitofp(T.f32, val) if is_f16_or_bf16: val_s = val else: val_s = (val * s_a) * s_b_vals[ni] - v16 = arith.trunc_f(T.f16, val_s) + v16 = arith.trunc_f(_out_elem_type(), val_s) # v16 (f16) -> bits in i32 low16 v1_f16 = vector.from_elements(vec1_f16, [v16]) @@ -754,15 +832,16 @@ def body_row(*, mi: int, ii: int, row_in_tile, row): acc_idx = mi * num_acc_n + ni acc = final_accs[acc_idx] val = vector.extract(acc, static_position=[ii], dynamic_position=[]) - if is_int8: + if is_int8 or is_int4: + # INT8/INT4 paths use i32 accumulators; convert to f32 for scaled epilogue. val = arith.sitofp(T.f32, val) if is_f16_or_bf16: val_s = val else: val_s = (val * s_a) * s_b_vals[ni] - val_f16 = arith.trunc_f(T.f16, val_s) + val_out = arith.trunc_f(_out_elem_type(), val_s) idx_out = idx_base + arith.constant(ni * 16, index=True) - buffer_ops.buffer_store(val_f16, c_rsrc, idx_out) + buffer_ops.buffer_store(val_out, c_rsrc, idx_out) mfma_epilog( use_cshuffle=False, @@ -843,6 +922,9 @@ def prefetch_a0_pack(lds_base): # Prologue: tile-0 k0 = arith.constant(0, index=True) a_regs0, b_tile0 = prefetch_ab_tile(k0) + # Prefetch scales for FP4/MXFP4 at tile-0. + a_scale_pong, b_scale_pong = prefetch_ab_scale_tile(k0 // 2) if is_fp4 else (k0, k0) + store_a_tile_to_lds(a_regs0, lds_base0) gpu.barrier() accs = [acc_init] * (num_acc_n * m_repeat) @@ -860,9 +942,12 @@ def prefetch_a0_pack(lds_base): for k_iv in range(0, c_k_main, tile_k * 2): next_k1 = k_iv + tile_k a_regs_ping, b_tile_ping = prefetch_ab_tile(next_k1) + a_scale_ping, b_scale_ping = prefetch_ab_scale_tile(next_k1 // 256) if is_fp4 else (k0, k0) accs, _ = compute_tile( - accs, b_tile_pong, lds_base_pong, a0_prefetch=a0_prefetch_pong + accs, b_tile_pong, lds_base_pong, + a0_prefetch=a0_prefetch_pong, + a_scale=a_scale_pong, b_scale=b_scale_pong, ) a0_prefetch_pong = None @@ -875,9 +960,12 @@ def prefetch_a0_pack(lds_base): next_k2 = k_iv + tile_k * 2 a_regs_pong, b_tile_pong = prefetch_ab_tile(next_k2) + a_scale_pong, b_scale_pong = prefetch_ab_scale_tile(next_k2 // 256) if is_fp4 else (k0, k0) accs, _ = compute_tile( - accs, b_tile_ping, lds_base_ping, a0_prefetch=a0_prefetch_ping + accs, b_tile_ping, lds_base_ping, + a0_prefetch=a0_prefetch_ping, + a_scale=a_scale_ping, b_scale=b_scale_ping, ) a0_prefetch_ping = None @@ -894,15 +982,19 @@ def prefetch_a0_pack(lds_base): lds_base_pong, is_last_tile=True, a0_prefetch=a0_prefetch_pong, + a_scale=a_scale_pong, b_scale=b_scale_pong, ) else: c_k_stop = c_k - (tile_k * 3) for k_iv in range(0, c_k_stop, tile_k * 2): next_k1 = k_iv + tile_k a_regs_ping, b_tile_ping = prefetch_ab_tile(next_k1) + a_scale_ping, b_scale_ping = prefetch_ab_scale_tile(next_k1 // 256) if is_fp4 else (k0, k0) accs, _ = compute_tile( - accs, b_tile_pong, lds_base_pong, a0_prefetch=a0_prefetch_pong + accs, b_tile_pong, lds_base_pong, + a0_prefetch=a0_prefetch_pong, + a_scale=a_scale_pong, b_scale=b_scale_pong, ) a0_prefetch_pong = None @@ -914,9 +1006,12 @@ def prefetch_a0_pack(lds_base): next_k2 = k_iv + tile_k * 2 a_regs_pong, b_tile_pong = prefetch_ab_tile(next_k2) + a_scale_pong, b_scale_pong = prefetch_ab_scale_tile(next_k2 // 256) if is_fp4 else (k0, k0) accs, _ = compute_tile( - accs, b_tile_ping, lds_base_ping, a0_prefetch=a0_prefetch_ping + accs, b_tile_ping, lds_base_ping, + a0_prefetch=a0_prefetch_ping, + a_scale=a_scale_ping, b_scale=b_scale_ping, ) a0_prefetch_ping = None @@ -928,9 +1023,12 @@ def prefetch_a0_pack(lds_base): last_k = c_k - tile_k a_regs_ping, b_tile_ping = prefetch_ab_tile(last_k) + a_scale_ping, b_scale_ping = prefetch_ab_scale_tile(last_k // 256) if is_fp4 else (k0, k0) accs, _ = compute_tile( - accs, b_tile_pong, lds_base_pong, a0_prefetch=a0_prefetch_pong + accs, b_tile_pong, lds_base_pong, + a0_prefetch=a0_prefetch_pong, + a_scale=a_scale_pong, b_scale=b_scale_pong, ) a0_prefetch_pong = None @@ -946,41 +1044,42 @@ def prefetch_a0_pack(lds_base): lds_base_ping, is_last_tile=True, a0_prefetch=a0_prefetch_ping, + a_scale=a_scale_ping, b_scale=b_scale_ping, ) store_output(final_accs, scales) - else: - # CK-like bpreshuffle v1 spirit: - # - Intrawave schedule - # - Global prefetch 2 (regs double-buffer) - # - Local shared memory buffer 1 (single LDS tile for A) - # Prologue: tile-0 - k0 = arith.constant(0, index=True) - a_regs0, b_tile0 = prefetch_ab_tile(k0) - store_a_tile_to_lds(a_regs0, lds_base0) - gpu.barrier() - accs = [acc_init] * (num_acc_n * m_repeat) - - lds_base = lds_base0 - b_tile_cur = b_tile0 - - # For each tile except last: prefetch next tile, compute current, then overwrite LDS. - for k_base in range(0, c_k - tile_k, tile_k): - next_k = k_base + tile_k - a_next, b_next = prefetch_ab_tile(next_k) - accs, _ = compute_tile(accs, b_tile_cur, lds_base) - # Single LDS buffer: ensure *all* waves are done reading A from LDS - # before any wave overwrites it with the next tile. - gpu.barrier() - store_a_tile_to_lds(a_next, lds_base) - hot_loop_scheduler() - gpu.barrier() - b_tile_cur = b_next - - final_accs, scales = compute_tile( - accs, b_tile_cur, lds_base, is_last_tile=True - ) - store_output(final_accs, scales) + # else: + # # CK-like bpreshuffle v1 spirit: + # # - Intrawave schedule + # # - Global prefetch 2 (regs double-buffer) + # # - Local shared memory buffer 1 (single LDS tile for A) + # # Prologue: tile-0 + # k0 = arith.constant(0, index=True) + # a_regs0, b_tile0 = prefetch_ab_tile(k0) + # store_a_tile_to_lds(a_regs0, lds_base0) + # gpu.barrier() + # accs = [acc_init] * (num_acc_n * m_repeat) + + # lds_base = lds_base0 + # b_tile_cur = b_tile0 + + # # For each tile except last: prefetch next tile, compute current, then overwrite LDS. + # for k_base in range(0, c_k - tile_k, tile_k): + # next_k = k_base + tile_k + # a_next, b_next = prefetch_ab_tile(next_k) + # accs, _ = compute_tile(accs, b_tile_cur, lds_base) + # # Single LDS buffer: ensure *all* waves are done reading A from LDS + # # before any wave overwrites it with the next tile. + # gpu.barrier() + # store_a_tile_to_lds(a_next, lds_base) + # hot_loop_scheduler() + # gpu.barrier() + # b_tile_cur = b_next + + # final_accs, scales = compute_tile( + # accs, b_tile_cur, lds_base, is_last_tile=True + # ) + # store_output(final_accs, scales) @flir.jit def __call__( diff --git a/tests/kernels/test_preshuffle_gemm.py b/tests/kernels/test_preshuffle_gemm.py index 256d7d1b..33976829 100644 --- a/tests/kernels/test_preshuffle_gemm.py +++ b/tests/kernels/test_preshuffle_gemm.py @@ -29,7 +29,6 @@ sys.path.insert(0, _PYFLIR_SRC) from kernels.preshuffle_gemm import compile_preshuffle_gemm -from kernels.mixed_preshuffle_gemm import compile_mxfp4_preshuffle_gemm from tests.test_common import run_perftest, verify_output from tests.utils import pertoken_quant, shuffle_weight from flydsl.runtime.device import get_rocm_arch @@ -94,7 +93,8 @@ def run_torch(x, weight, x_scale, w_scale, bias=None, dtype=torch.bfloat16): return out.to(dtype) -@pytest.mark.parametrize("in_dtype", ["fp8", "int8", "bf16"]) +@pytest.mark.parametrize("a_dtype", ["fp8", "int8", "bf16"]) +@pytest.mark.parametrize("b_dtype", ["fp8", "int8", "bf16"]) @pytest.mark.parametrize( "M, N, K, tile_m, tile_n, tile_k", [ @@ -105,7 +105,8 @@ def run_torch(x, weight, x_scale, w_scale, bias=None, dtype=torch.bfloat16): ] ) def test_mfma_a8_flir_preshuffle( - in_dtype, + a_dtype, + b_dtype, M, N, K, @@ -121,7 +122,7 @@ def test_mfma_a8_flir_preshuffle( ): print("=" * 80) print( - f"MFMA {in_dtype.upper()} GEMM Test (Tile: {tile_m}x{tile_n}x{tile_k}) [Torch Optimized]" + f"MFMA {a_dtype.upper()}/{b_dtype.upper()} GEMM Test (Tile: {tile_m}x{tile_n}x{tile_k}) [Torch Optimized]" ) print("=" * 80) @@ -137,8 +138,8 @@ def test_mfma_a8_flir_preshuffle( tile_m=tile_m, tile_n=tile_n, tile_k=tile_k, - a_dtype=in_dtype, - b_dtype=in_dtype, + a_dtype=a_dtype, + b_dtype=b_dtype, out_dtype="f16", lds_stage=lds_stage, use_cshuffle_epilog=bool(use_cshuffle_epilog), @@ -148,10 +149,10 @@ def test_mfma_a8_flir_preshuffle( size_c = M * N size_a = M * K # B is packed int4 for W4A8: 2 values per byte. - if in_dtype == "int4": + if b_dtype == "int4": size_b = (N * K) // 2 elem_bytes = 1 - elif in_dtype in ("fp16", "bf16"): + elif b_dtype in ("fp16", "bf16"): size_b = (N * K) * 2 elem_bytes = 2 else: @@ -164,12 +165,12 @@ def test_mfma_a8_flir_preshuffle( a_fp32 = torch.rand(M, K, device=device, dtype=torch.float32) b_fp32_t = torch.rand(N, K, device=device, dtype=torch.float32) # (N, K) - is_int4 = in_dtype == "int4" + is_int4 = b_dtype == "int4" # INT4 here means W4A8: A is INT8, B is packed INT4 and unpacked to INT8 in-kernel. - is_int8 = (in_dtype == "int8") or is_int4 + is_int8 = (a_dtype == "int8") or is_int4 - if in_dtype in ("fp16", "bf16"): - torch_dtype = torch.float16 if in_dtype == "fp16" else torch.bfloat16 + if a_dtype in ("fp16", "bf16") or b_dtype in ("fp16", "bf16"): + torch_dtype = torch.float16 if a_dtype == "fp16" else torch.bfloat16 a_q = a_fp32.to(torch_dtype) b_q = b_fp32_t.to(torch_dtype) # Scale is semantically optional for fp16/bf16 (no dequant). Let callers pass None; @@ -253,7 +254,7 @@ def launch_kernel(c, a, b, sa, sb): tbps = bytes_moved / 1e12 / (us / 1e6) print(f"Throughput: {us:.1f} us, {tflops:.2f} TFLOPS, BW: {tbps:.3f} TB/s") - if HAS_AITER and bool(run_aiter_bench) and (not is_int4) and (in_dtype in ("fp8", "int8")): + if HAS_AITER and bool(run_aiter_bench) and (not is_int4) and (a_dtype in ("fp8", "int8") and a_dtype == b_dtype): print("-" * 40) print("Running Aiter Benchmark...") try: @@ -321,7 +322,7 @@ def test_mfma_w4_flir_preshuffle( raise ValueError( f"lds_stage must be 1 or 2, got {lds_stage!r}" ) - exe = compile_mxfp4_preshuffle_gemm( + exe = compile_preshuffle_gemm( M=M, N=N, K=K, @@ -330,6 +331,7 @@ def test_mfma_w4_flir_preshuffle( tile_k=tile_k, a_dtype=a_dtype, b_dtype=b_dtype, + out_dtype="f16", lds_stage=lds_stage, use_cshuffle_epilog=bool(use_cshuffle_epilog), ) @@ -428,11 +430,19 @@ def launch_kernel(c, a, b, sa, sb): description="Preshuffle GEMM benchmark" ) parser.add_argument( - "--in_dtype", + "--a_dtype", type=str, default="fp8", choices=["fp8", "int8", "int4", "fp16", "bf16", "fp4"], - help="Input dtype") + help="Input dtype" + ) + parser.add_argument( + "--b_dtype", + type=str, + default="fp8", + choices=["fp8", "int8", "int4", "fp16", "bf16", "fp4"], + help="Input dtype" + ) parser.add_argument("-M", type=int, default=16, help="M dimension") parser.add_argument("-N", type=int, default=10240, help="N dimension") parser.add_argument("-K", type=int, default=8192, help="K dimension") @@ -489,7 +499,8 @@ def launch_kernel(c, a, b, sa, sb): torch.set_default_device("cuda") if not args.wfp4: test_mfma_a8_flir_preshuffle( - args.in_dtype, + args.a_dtype, + args.b_dtype, M=args.M, N=args.N, K=args.K, From edac33c8290b38490d8160f063f06b0a4ca8a3d1 Mon Sep 17 00:00:00 2001 From: Zzz9990 Date: Tue, 3 Feb 2026 19:28:33 -0600 Subject: [PATCH 03/11] preshuffle_gemm.py ready --- kernels/mfma_preshuffle_pipeline.py | 32 +- kernels/mixed_preshuffle_gemm.py | 993 -------------------------- kernels/preshuffle_gemm.py | 68 +- tests/kernels/test_preshuffle_gemm.py | 2 +- 4 files changed, 65 insertions(+), 1030 deletions(-) delete mode 100644 kernels/mixed_preshuffle_gemm.py diff --git a/kernels/mfma_preshuffle_pipeline.py b/kernels/mfma_preshuffle_pipeline.py index d11b767d..a4b9cb22 100644 --- a/kernels/mfma_preshuffle_pipeline.py +++ b/kernels/mfma_preshuffle_pipeline.py @@ -40,7 +40,7 @@ class EpilogPipeline(Enum): a_elem_type_dict = { MfmaPipeline.F4F4_MXFP4_PIPELINE: lambda: T.ui8, - MfmaPipeline.F8F4_MXFP4_PIPELINE: lambda: T.ui8, + MfmaPipeline.F8F4_MXFP4_PIPELINE: lambda: T.f8, MfmaPipeline.F8F8_16x16_PIPELINE: lambda: T.f8, MfmaPipeline.F16F16_16x16_PIPELINE: lambda: T.f16, MfmaPipeline.BF16BF16_16x16_PIPELINE: lambda: T.bf16, @@ -267,6 +267,8 @@ def get_mfma_fn(self): return get_mfma_i32_k32() elif self.mfma_pipeline == MfmaPipeline.I8I4_16x16_PIPELINE: return get_mfma_i32_k32() + elif self.mfma_pipeline in [MfmaPipeline.F4F4_MXFP4_PIPELINE, MfmaPipeline.F8F4_MXFP4_PIPELINE]: + return rocdl.mfma_scale_f32_16x16x128_f8f6f4 else: raise ValueError(f"Invalid mfma pipeline: {self.mfma_pipeline}") @@ -275,11 +277,11 @@ def get_a_bytes_per_thread( tile_m: int, tile_k: int, ): - a_bytes_per_tile = int(tile_m) * int(tile_k) * int(self.a_elem_bytes) + a_bytes_per_tile = int(tile_m) * int(tile_k) * int(self.a_elem_bytes) // self.a_elem_pack if a_bytes_per_tile % self.block_size != 0: raise ValueError( - "tile_m*tile_k*elem_bytes must be divisible by " - f"{self.block_size}: tile_m={tile_m}, tile_k={tile_k}, a_elem_bytes={self.a_elem_bytes}" + "tile_m*tile_k*elem_bytes/a_elem_pack must be divisible by " + f"{self.block_size}: tile_m={tile_m}, tile_k={tile_k}, a_elem_bytes={self.a_elem_bytes}, a_elem_pack={self.a_elem_pack}" ) a_bytes_per_thread = a_bytes_per_tile // self.block_size @@ -806,17 +808,15 @@ def block_mfma_block_scale_f8f6f4( a_elem_vec_pack, k_unroll, m_repeat, - num_acc, + num_acc_n, pack_K, pack_M, pack_N, a0_prefetch=None, ): - current_accs_list = list(accs_in) - k_unroll_packed = k_unroll // pack_K m_repeat_packed = m_repeat // pack_M - num_acc_n_packed = num_acc // pack_N + num_acc_n_packed = num_acc_n // pack_N mfma_res_ty = T.f32x4 vec4_i64 = T.vec(4, T.i64) @@ -851,24 +851,26 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): else: a0, a1 = lds_load_packs_k64(curr_row_a_lds, col_base0, lds_base) - col_base1 = col_base + 64 - a2, a3 = lds_load_packs_k64(curr_row_a_lds, col_base1, lds_base) - a128 = pack_i64x4_to_i32x8(a0, a1, a2, a3) + if cbsz == 0: + col_base1 = col_base + 64 + a2, a3 = lds_load_packs_k64(curr_row_a_lds, col_base1, lds_base) + a128 = pack_i64x4_to_i32x8(a0, a1, a2, a3) + else: + a128 = pack_i64x4_to_i32x8(a0, a1, c0_i64, c0_i64) for inxdl in range_constexpr(pack_N): ni_idx = ni * pack_N + inxdl - b0 = b_packs0[ni_idx] b1 = b_packs1[ni_idx] b128 = pack_i64x4_to_i32x8(b0, b1, c0_i64, c0_i64) acc_idx = mi_idx * num_acc_n + ni_idx - current_accs_list[acc_idx] = mfma_fn( + accs_in[acc_idx] = mfma_fn( mfma_res_ty, [ a128, b128, - current_accs_list[acc_idx], + accs_in[acc_idx], # cbsz, abid, blgp cbsz, blgp, @@ -881,7 +883,7 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): b_scale_val, ], ) - return current_accs_list, None + return accs_in # ---------------- gfx95 fast path (K128 MFMA scale) ---------------- # This is the key optimization from `zhimding/develop_0107` for FP8: diff --git a/kernels/mixed_preshuffle_gemm.py b/kernels/mixed_preshuffle_gemm.py deleted file mode 100644 index 74b45801..00000000 --- a/kernels/mixed_preshuffle_gemm.py +++ /dev/null @@ -1,993 +0,0 @@ -"""Preshuffle GEMM kernel implementations (FLIR MFMA FP8/INT8). - -This module intentionally contains the **kernel builder code** for the preshuffle GEMM, -extracted from `tests/kernels/test_preshuffle_gemm.py` in the same style as -`kernels/moe_gemm_2stage.py`: -- `kernels/` holds the implementation (compile functions) -- `tests/` holds correctness/perf harnesses - -Pipelines: -- `pingpong`: tuned 2-stage pipeline with ping-pong LDS for A (2 LDS buffers) -- `ck_v1_single_lds`: CK-like Intrawave + bpreshuffle v1 spirit (single LDS buffer for A) -""" - -import os - -import flydsl -from flydsl.dialects.ext import flir -from flydsl.dialects.ext.python_control_flow import range_constexpr -from flydsl.runtime.device import get_rocm_arch as get_hip_arch -from flydsl.utils import SmemAllocator, SmemPtr - -from _mlir import ir - -from flydsl.dialects.ext import arith, gpu, buffer_ops, vector, rocdl -from flydsl.lang.ir.types import T, memref - -from kernels.mfma_preshuffle_pipeline import ( - buffer_copy_gmem16_dwordx4, - lds_load_pack_k32, - lds_store_16b_xor16, - make_preshuffle_b_layout, - make_preshuffle_scale_layout, - load_b_pack_k32, - tile_chunk_coord_i32, -) -from kernels.mfma_epilogues import mfma_epilog - - -def compile_mxfp4_preshuffle_gemm( - *, - M: int, - N: int, - K: int, - tile_m: int, - tile_n: int, - tile_k: int, - a_dtype: str = "fp8", - b_dtype: str = "fp4", - lds_stage: int = 2, - # Epilogue options - use_cshuffle_epilog: bool = True, -): - """Compile the preshuffle GEMM kernel and return the compiled executable. - - Args: - M, N, K: GEMM sizes (A[M,K], B[N,K], C[M,N]). - tile_m, tile_n, tile_k: block tile sizes. - in_dtype: - - "fp8": A/B are fp8 (1B/elem) - - "int8": A/B are int8 (1B/elem) - - "int4": W4A8 path: A is int8, B is packed int4 (2 values per byte) and unpacked to int8 in-kernel. - lds_stage: - - 2: ping-pong LDS for A (2 LDS buffers), tuned schedule (original). - - 1: single LDS buffer for A . - """ - # if in_dtype not in ("fp8", "int8", "int4", "fp16", "bf16"): - # raise ValueError( - # "in_dtype must be one of ('fp8','int8','int4','fp16','bf16'), " - # f"got {in_dtype!r}" - # ) - - is_fp4_a = a_dtype == "fp4" - is_fp8_a = a_dtype == "fp8" - is_fp4_b = b_dtype == "fp4" - - a_elem_vec_pack = 2 if is_fp4_a else 1 - b_elem_vec_pack = 2 - - elem_bytes = 1 - - pack_M = 2 - pack_N = 2 - pack_K = 2 - - cbsz = 0 if is_fp8_a else 4 - blgp = 4 - - # Pipeline is byte-addressed along K (16B loads, XOR16 swizzle in bytes). - # For fp16/bf16 (2B/elem), user passes tile_k halved so tile_k_bytes stays constant. - tile_k_bytes = int(tile_k) * int(elem_bytes) - - # K64-byte micro-step wrapper uses 2x "half-pack" MFMA. - # - fp8/i8: mfma K32, wrapper covers 64B (=64 elems) - # - fp16/bf16: mfma K16, wrapper covers 64B (=32 elems) -> effective K halves - if (tile_k_bytes % 64) != 0: - raise ValueError( - f"tile_k_bytes must be divisible by 64, got tile_k_bytes={tile_k_bytes} " - f"(tile_k={tile_k}, elem_bytes={elem_bytes})" - ) - - gpu_arch = get_hip_arch() - allocator = SmemAllocator(None, arch=gpu_arch) - _state = {} - - # Default-on: cross-tile (tile_k) A0 LDS prefetch in the ping-pong pipeline (lds_stage=2). - # - # This issues the *first* A-pack LDS read for the next tile between barriers, to overlap - # with the VMEM prefetch of the following tile. - - DYN = ir.ShapedType.get_dynamic_size() - - # Vector width calc (assume full tiles / no tail guards). - total_threads = 256 - bytes_a_per_tile = int(tile_m) * int(tile_k) * int(elem_bytes) // a_elem_vec_pack - if bytes_a_per_tile % total_threads != 0: - raise ValueError( - "tile_m*tile_k*elem_bytes must be divisible by " - f"{total_threads}: tile_m={tile_m}, tile_k={tile_k}, elem_bytes={elem_bytes}" - ) - bytes_per_thread_a = bytes_a_per_tile // total_threads - - # Assume A loads are always 16B-aligned and use fixed dwordx4 (16B) buffer loads. - a_load_bytes = 16 - if bytes_per_thread_a % a_load_bytes != 0: - raise ValueError( - f"bytes_per_thread_a ({bytes_per_thread_a}) must be divisible by {a_load_bytes}" - ) - - # CK-style LDS128: stride is in BYTES along K (for XOR16 swizzle). - lds_stride_bytes = tile_k_bytes - - #TODO: use f4? - def _a_elem_type(): - # return T.f8 if is_fp8_a else T.ui8 - return T.f8 if is_fp8_a else T.ui8 - - def _b_elem_type(): - return T.ui8 - - def _scale_elem_type(): - return T.i32 - - #TODO: use f4 pack? - def _a_vec16_type(): - return T.f8x16 if is_fp8_a else T.ui8x16 - - def _b_vec16_type(): - return T.ui8x16 - - def _mfma_pack_ty(): - # ROCDL MFMA intrinsics expect specific operand vector types: - # - fp8/int8 paths use i64 packs (8 bytes) - # - fp16 uses v4f16 (8 bytes) - # - bf16 uses v4i16 (8 bytes) for *_bf16_1k variants - if is_f16: - return T.f16x4 - if is_bf16: - return T.i16x4 - return T.i64 - - # GEMM epilogue toggle: optional LDS CShuffle + vectorized stores. - # Default: off (current measured cases show no benefit). - epilog_tag = "cshuffle" if use_cshuffle_epilog else "direct" - module_name = f"mfma_preshuffle_{lds_stage}stages_a{a_dtype}_b{b_dtype}_{epilog_tag}".replace("-", "_") - - class _GEMM(flir.MlirModule): - GPU_MODULE_NAME = module_name - GPU_MODULE_TARGETS = [ - f'#rocdl.target' - ] - - def init_gpu_module(self): - # LDS scratch: - # - A tiles (fp8/int8): lds_stage * tile_m * lds_stride bytes - # - optional CShuffle output tile (fp16): tile_m * tile_n * 2 bytes - # - # When CShuffle is enabled, we reuse the same LDS allocation by aliasing it as fp16 - # in the epilogue (the A LDS is dead after the mainloop). - lds_a_bytes = int(lds_stage) * int(tile_m) * int(lds_stride_bytes) // int(a_elem_vec_pack) - lds_out_bytes = 2 * int(tile_m) * int(tile_n) if use_cshuffle_epilog else 0 - lds_total_bytes = max(lds_a_bytes, lds_out_bytes) - # Keep LDS allocation sized in bytes: element_size * num_elems. - # Allocate element type == _elem_type() and scale element count accordingly. - lds_total_elems = lds_total_bytes if elem_bytes == 1 else (lds_total_bytes // 2) - _state["lds_a_decl"] = allocator.allocate_array(_a_elem_type(), lds_total_elems) - allocator.finalize() - - @flir.kernel - def kernel_gemm( - self: flir.T.i64, - arg_c: lambda: memref(DYN, T.f16), - arg_a: lambda: memref(DYN, _a_elem_type()), - arg_b: lambda: memref(DYN, _b_elem_type()), - arg_scale_a: lambda: memref(DYN, _scale_elem_type()), - arg_scale_b: lambda: memref(DYN, _scale_elem_type()), - c_m: lambda: T.index, - c_n: lambda: T.index, - c_k: lambda: T.index, - ): - # ---- Types ---- - # NOTE: Some environments have multiple `flydsl` builds on PYTHONPATH. - # Use explicit MLIR Values (not Python ints / wrapper objects) for ROCDL ops. - acc_init = arith.unwrap(arith.constant_vector(0.0, T.f32x4)) - - # Layouts - layout_c = flir.make_layout((c_m, c_n), stride=(c_n, 1)) - - # A uses dword indexing (buffer-load dwordx4). Convert element index -> dword index: - # dword_index = (elem_index * elem_bytes) / 4 - if (int(elem_bytes) == 2): - c_k_div4bytes = (c_k * 2) / 4 - else: - c_k_div4bytes = c_k / 4 / a_elem_vec_pack - layout_a_div4 = flir.make_layout((c_m, c_k_div4bytes), stride=(c_k_div4bytes, 1)) - - c_k_b = c_k // b_elem_vec_pack - # B preshuffle layout (shared with MoE kernels). - kpack_bytes = 16 - layout_b = make_preshuffle_b_layout( - flir, arith, c_n=c_n, c_k=c_k_b, kpack_bytes=kpack_bytes, elem_bytes=elem_bytes - ).layout_b - - layout_a_scale = make_preshuffle_scale_layout( - flir, arith, c_mn=c_m, c_k=c_k, - ) - layout_b_scale = make_preshuffle_scale_layout( - flir, arith, c_mn=c_n, c_k=c_k, - ) - - # LDS layout is element-indexed, but XOR16 swizzle is byte-based. - # Represent LDS as (tile_m, tile_k) in elements and scale swizzle math by elem_bytes. - shape_lds = flir.make_shape(tile_m, tile_k // a_elem_vec_pack) - stride_lds = flir.make_stride(tile_k // a_elem_vec_pack, 1) - layout_lds = flir.make_layout(shape_lds, stride_lds) - - # CK-style XOR16 swizzle parameter (const). - k_blocks16 = arith.index(tile_k_bytes // 16) - - tx = gpu.thread_id("x") - bx = gpu.block_id("x") - by = gpu.block_id("y") - - base_ptr = allocator.get_base() - lds_a_ptr = _state["lds_a_decl"](base_ptr) - lds_a = lds_a_ptr.get() - lds_out = ( - SmemPtr(base_ptr, lds_a_ptr.byte_offset, T.f16, shape=(tile_m * tile_n,)).get() - if use_cshuffle_epilog - else None - ) - - # Note: We assume N is aligned (no N-tail support in this kernel). - a_rsrc = buffer_ops.create_buffer_resource(arg_a, max_size=False) - c_rsrc = buffer_ops.create_buffer_resource(arg_c, max_size=False) - scale_a_rsrc = buffer_ops.create_buffer_resource(arg_scale_a, max_size=False) - - b_rsrc = buffer_ops.create_buffer_resource(arg_b, max_size=True) - scale_b_rsrc = buffer_ops.create_buffer_resource(arg_scale_b, max_size=True) - - bx_m = bx * tile_m - by_n = by * tile_n - - # (thread_id.x) -> (wave_id, lane_id) via FLIR. - layout_wave_lane = flir.make_layout((4, 64), stride=(64, 1)) - coord_wave_lane = flir.idx2crd(tx, layout_wave_lane) - wave_id = flir.get(coord_wave_lane, 0) - lane_id = flir.get(coord_wave_lane, 1) - - # lane_id -> (lane_div_16, lane_mod_16) via FLIR. - layout_lane16 = flir.make_layout((4, 16), stride=(16, 1)) - coord_lane16 = flir.idx2crd(lane_id, layout_lane16) - lane_div_16 = flir.get(coord_lane16, 0) - lane_mod_16 = flir.get(coord_lane16, 1) - - row_a_lds = lane_mod_16 - # Per-`k1` (KLane) base offset along K inside a 64B K0 block. - # - # CK preshuffle uses KPackBytes=16 across dtypes, but KPackElems differs: - # - fp8/int8: 16 elems (1B) - # - fp16/bf16: 8 elems (2B) - # - # We express `col_offset_base` in *elements*. - kpack_elems = 16 if elem_bytes == 1 else 8 - col_offset_base = lane_div_16 * arith.constant(int(kpack_elems), index=True) - # `col_offset_base` is in element units (multiples of 16). We do LDS swizzle/math - # in bytes, so scale by element size for fp16/bf16. - col_offset_base_bytes = ( - col_offset_base - if elem_bytes == 1 - else (col_offset_base * arith.constant(int(elem_bytes), index=True)) - ) - - m_repeat = tile_m // 16 - # K stepping is byte-addressed: one "micro-tile step" is 64 bytes. - k_unroll = tile_k_bytes // 128 - - # --- Dynamic tiling along N (4 waves) --- - num_waves = 4 - n_per_wave = tile_n // num_waves - num_acc_n = n_per_wave // 16 - - c_n_per_wave = arith.constant(n_per_wave, index=True) - n_tile_base = wave_id * c_n_per_wave - - # fp4 pack - k_unroll_packed = k_unroll // pack_K - m_repeat_packed = m_repeat // pack_M - num_acc_n_packed = num_acc_n // pack_N - - # Decompose global_n -> (n_blk, n_intra) once per ni. - c_n0 = c_n / 16 - layout_n_blk_intra = flir.make_layout((c_n0, 16), stride=(16, 1)) - n_intra_list = [] - n_blk_list = [] - for i in range_constexpr(num_acc_n): - offset = i * 16 - c_offset = arith.constant(offset, index=True) - global_n = by_n + n_tile_base + c_offset + lane_mod_16 - coord_n = flir.idx2crd(global_n, layout_n_blk_intra) - n_blk_list.append(flir.get(coord_n, 0)) - n_intra_list.append(flir.get(coord_n, 1)) - - # For FP8/INT8 we can load one 16B pack and extract both 8B halves (K64 bytes). - # For INT4 (packed), reuse the existing K32 loader twice (2x4B loads + unpack). - atom_b_g2r16 = flir.make_copy_atom(_b_elem_type(), vector_size=16) - c64_b = 64 - c0_idx = 0 - - def load_b_packs_k64(base_k, ku: int, ni: int): - base_k_bytes = base_k * arith.constant(int(elem_bytes), index=True) - k0_base = base_k_bytes / c64_b - k0 = k0_base + ku - k1 = lane_div_16 - coord_pack = flir.make_coord(n_blk_list[ni], k0, k1, n_intra_list[ni], c0_idx) - idx_pack = flir.crd2idx(coord_pack, layout_b) - vec_elems = 16 - b_view = flir.TensorView( - arg_b, - (vec_elems,), - strides=(1,), - base_indices=(idx_pack,), - element_type=_b_elem_type(), - ) - b16 = flir.copy( - flir.make_copy_atom(_b_elem_type(), vector_size=vec_elems), - b_view, - None, - alignment=8, - return_vector=True, - src_buffer_resource=(b_rsrc if elem_bytes == 1 else None), - src_buffer_offset_in_bytes=(elem_bytes == 1), - ) - # Split 16B pack into two 8B halves. - b_i64x2 = vector.bitcast(T.i64x2, b16) - b0_i64 = vector.extract(b_i64x2, static_position=[0], dynamic_position=[]) - b1_i64 = vector.extract(b_i64x2, static_position=[1], dynamic_position=[]) - return b0_i64, b1_i64 - - def load_b_tile(base_k): - b_tile = [] - for ku in range_constexpr(k_unroll): - packs0 = [] - packs1 = [] - for ni in range_constexpr(num_acc_n): - b0, b1 = load_b_packs_k64(base_k, ku, ni) - packs0.append(b0) - packs1.append(b1) - b_tile.append((packs0, packs1)) - return b_tile - - def load_scale(arg_scale, rsrc, layout, ku, mni): - k_lane = lane_div_16 - n_lane = lane_mod_16 - coord_pack = flir.make_coord(mni, ku, k_lane, n_lane) - idx_pack = flir.crd2idx(coord_pack, layout) - vec_elems = 1 - scale_view = flir.TensorView( - arg_scale, - (1,), - strides=(1,), - base_indices=(idx_pack,), - element_type=_scale_elem_type(), - ) - scale = flir.copy( - flir.make_copy_atom(_scale_elem_type(), vector_size=1), - scale_view, - None, - alignment=8, - return_vector=True, - src_buffer_resource=rsrc, - src_buffer_offset_in_bytes=False, - ) - # Split 16B pack into two 8B halves. - return scale - - def load_b_scale_tile(base_k): - b_scale_tile = [] - for ku in range_constexpr(k_unroll_packed): - for ni in range_constexpr(num_acc_n_packed): - scale = load_scale( - arg_scale_b, - scale_b_rsrc, - layout_b_scale, - ku + base_k, - ni + (by_n + n_tile_base) // pack_N // 16, - ) - b_scale_tile.append(scale) - return b_scale_tile - - def load_a_scale_tile(base_k): - a_scale_tile = [] - for ku in range_constexpr(k_unroll_packed): - for mi in range_constexpr(m_repeat_packed): - scale = load_scale( - arg_scale_a, - scale_a_rsrc, - layout_a_scale, - ku + base_k, - mi + bx_m // pack_M // 16, - ) - a_scale_tile.append(scale) - return a_scale_tile - - def prefetch_ab_scale_tile(base_k): - return [load_a_scale_tile(base_k), load_b_scale_tile(base_k)] - - def lds_load_16b(curr_row_a_lds, col_base, lds_base): - # Swizzle in bytes, then convert to element offset for memref indexing. - col_base_swz_bytes = flir.swizzle_xor16(curr_row_a_lds, col_base, k_blocks16) - col_base_swz = col_base_swz_bytes if elem_bytes == 1 else (col_base_swz_bytes / 2) - coord_a16 = flir.make_coord(curr_row_a_lds, col_base_swz) - idx_a16 = flir.crd2idx(coord_a16, layout_lds) - idx_a16 = idx_a16 + lds_base - return vector.load_op(_a_vec16_type(), lds_a, [idx_a16]) - - # --- A LDS load helper for K64-bytes (load 16B once, extract 2x i64 halves) --- - def lds_load_packs_k64(curr_row_a_lds, col_base, lds_base): - loaded_a16 = lds_load_16b(curr_row_a_lds, col_base, lds_base) - a_i64x2 = vector.bitcast(T.i64x2, loaded_a16) - a0_i64 = vector.extract(a_i64x2, static_position=[0], dynamic_position=[]) - a1_i64 = vector.extract(a_i64x2, static_position=[1], dynamic_position=[]) - - return a0_i64, a1_i64 - - # --- A load/store (16B chunks), XOR16 swizzle --- - num_a_loads = bytes_per_thread_a // a_load_bytes - # A tile mapping in dwords along K: - if elem_bytes == 2: - tile_k_dwords = (tile_k * 2) // 4 - else: - tile_k_dwords = tile_k // 4 // a_elem_vec_pack - layout_a_tile_div4 = flir.make_layout((tile_m, tile_k_dwords), stride=(tile_k_dwords, 1)) - c4 = arith.constant(4, index=True) - tx_i32_base = tx * c4 - atom_a_g2r16 = flir.make_copy_atom(_a_elem_type(), vector_size=(16 if elem_bytes == 1 else 8)) - - def load_a_16(idx_elem): - return buffer_copy_gmem16_dwordx4( - flir, - arg=arg_a, - elem_type=_a_elem_type(), - idx_i32=idx_elem, - atom_g2r16=atom_a_g2r16, - rsrc=a_rsrc, - vec_elems=(16 if elem_bytes == 1 else 8), - ) - - def a_tile_chunk_coord_i32(i: int): - return tile_chunk_coord_i32( - flir, - arith, - tx_i32_base=tx_i32_base, - i=i, - total_threads=total_threads, - layout_tile_div4=layout_a_tile_div4, - ) - - def load_a_tile(base_k_div4): - parts = [] - for i in range_constexpr(num_a_loads): - row_a_local, col_a_local_i32 = a_tile_chunk_coord_i32(i) - row_a_global = bx_m + row_a_local - coord_a_g = flir.make_coord(row_a_global, base_k_div4 + col_a_local_i32) - idx_i32 = flir.crd2idx(coord_a_g, layout_a_div4) - # `idx_i32` is a dword offset. For 2B element types (fp16/bf16), - # convert to element offset so the generic `vector.load` path reads - # the right address (FLIR only specializes buffer_load_dwordx4 for 1B types). - idx_elem = ( - idx_i32 - if elem_bytes == 1 - else (idx_i32 * arith.constant(2, index=True)) - ) - a_16B = load_a_16(idx_elem) - - parts.append(vector.bitcast(T.i32x4, a_16B)) - return parts - - def store_a_tile_to_lds(vec_a_parts, lds_base): - for i in range_constexpr(num_a_loads): - row_a_local, col_a_local_i32 = a_tile_chunk_coord_i32(i) - lds_store_16b_xor16( - flir, - arith, - vector, - lds_memref=lds_a, - vec16_ty=_a_vec16_type(), - elem_type=_a_elem_type(), - atom_s16=atom_a_g2r16, - layout_lds=layout_lds, - row_local=row_a_local, - col_local_i32=col_a_local_i32, - tx_c4=c4, - k_blocks16=k_blocks16, - lds_base=lds_base, - vec_part_i32x4=vec_a_parts[i], - elem_bytes=elem_bytes, - ) - - def prefetch_ab_tile(base_k): - base_k_bytes = base_k * arith.constant(int(elem_bytes), index=True) - base_k_div4 = base_k_bytes / 4 - a_regs = load_a_tile(base_k_div4 // a_elem_vec_pack) - b_regs = load_b_tile(base_k // 2) - return a_regs, b_regs - - - vec1_f16 = ir.VectorType.get([1], ir.F16Type.get()) - vec2_f16 = ir.VectorType.get([2], ir.F16Type.get()) - vec1_i16 = ir.VectorType.get([1], ir.IntegerType.get_signless(16)) - vec2_i16 = ir.VectorType.get([2], ir.IntegerType.get_signless(16)) - vec1_i32 = ir.VectorType.get([1], ir.IntegerType.get_signless(32)) - vec4_i32 = ir.VectorType.get([4], ir.IntegerType.get_signless(32)) - - def store_output(final_accs): - if use_cshuffle_epilog: - if lds_out is None: - raise RuntimeError( - "use_cshuffle_epilog=True but lds_out is not allocated/aliased." - ) - # We reuse the A LDS allocation as `lds_out` for the cshuffle epilogue. - # Add a block-wide barrier before starting to write into LDS to avoid - # racing with the tail of the mainloop's LDS reads (different waves can - # reach the epilogue at slightly different times). - gpu.barrier() - - def write_row_to_lds( - *, - mi: int, - ii: int, - row_in_tile, - row, - row_base_lds, - col_base_local, - num_acc_n: int, - lds_out, - ): - # Store packed half2 to LDS as i32: - # - Each lane computes one f16 (for its lane_mod_16 column) - # - Use ds_bpermute to grab the neighbor lane's f16 bits and pack (even, odd) - # - Store to the even column address / 2 in the i32 alias view - c0_i32 = arith.constant(0, type=T.i32) - c1_i32 = arith.constant(1, type=T.i32) - cFE_i32 = arith.constant(0xFFFFFFFE, type=T.i32) - c2_i32 = arith.constant(2, type=T.i32) - - lane_id_i32 = arith.index_cast(T.i32, lane_id) - lane_lsb = arith.andi(lane_id_i32, c1_i32) - is_odd = lane_lsb != c0_i32 - nbr_lane = arith.xori(lane_id_i32, c1_i32) - nbr_lane_bytes = arith.shli(nbr_lane, c2_i32) # lane_id * 4 (bytes) - - for ni in range_constexpr(num_acc_n): - col_local = col_base_local + (ni * 16) - acc_idx = mi * num_acc_n + ni - acc = final_accs[acc_idx] - val = vector.extract(acc, static_position=[ii], dynamic_position=[]) - v16 = arith.trunc_f(T.f16, val) - - # v16 (f16) -> bits in i32 low16 - v1_f16 = vector.from_elements(vec1_f16, [v16]) - v1_i16 = vector.bitcast(vec1_i16, v1_f16) - v16_i16 = vector.extract( - v1_i16, static_position=[0], dynamic_position=[] - ) - # Zero-extend i16 bits to i32: - # Build a 2xi16 vector (low16=v16 bits, high16=0) then bitcast to 1xi32. - z16 = arith.constant(0, type=T.i16) - v2_i16 = vector.from_elements(vec2_i16, [v16_i16, z16]) - v16_i32 = vector.extract( - vector.bitcast(vec1_i32, v2_i16), - static_position=[0], - dynamic_position=[], - ) - - # Neighbor's bits (per-lane): ds_bpermute uses a byte index. - nbr_i32 = rocdl.ds_bpermute( - T.i32, - arith.unwrap(nbr_lane_bytes), - arith.unwrap(v16_i32), - ) - - # Convert neighbor bits back to f16 so we can store vec2. - nbr_v1_i32 = vector.from_elements(vec1_i32, [nbr_i32]) - nbr_v2_i16 = vector.bitcast(vec2_i16, nbr_v1_i32) - nbr_i16 = vector.extract( - nbr_v2_i16, static_position=[0], dynamic_position=[] - ) - nbr_v1_i16 = vector.from_elements(vec1_i16, [nbr_i16]) - nbr_v1_f16 = vector.bitcast(vec1_f16, nbr_v1_i16) - nbr_f16 = vector.extract( - nbr_v1_f16, static_position=[0], dynamic_position=[] - ) - - even_f16 = arith.select(is_odd, nbr_f16, v16) - odd_f16 = arith.select(is_odd, v16, nbr_f16) - - # Store [even, odd] as a single 32-bit LDS write (2xf16). - col_local_i32 = arith.index_cast(T.i32, col_local) - col_even_i32 = arith.andi(col_local_i32, cFE_i32) - col_even = arith.index_cast(T.index, col_even_i32) - - lds_idx = row_base_lds + col_even - v2 = vector.from_elements(vec2_f16, [even_f16, odd_f16]) - - def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): - # Store vector to C at (row, col_g0). - # - # IMPORTANT: - # RawPtrBufferStoreOp offsets are in BYTES. `buffer_ops.buffer_store()` - # will scale by element bytes based on the *data type*. For f16 vectors, - # some backends/paths can be fragile. We explicitly bitcast to i32 - # and pass a byte offset to keep the store well-defined. - idx_out = flir.crd2idx(flir.make_coord(row, col_g0), layout_c) # f16 element offset - byte_off = idx_out * arith.constant(2, index=True) # bytes - - if e_vec == 4: - frag_i32x2 = vector.bitcast(T.vec(2, T.i32), frag) - buffer_ops.buffer_store( - frag_i32x2, c_rsrc, byte_off, offset_is_bytes=True - ) - else: - # e_vec == 2: pack 2xf16 -> 1xi32 - frag_i32x1 = vector.bitcast(T.vec(1, T.i32), frag) - frag_i32 = vector.extract( - frag_i32x1, static_position=[0], dynamic_position=[] - ) - buffer_ops.buffer_store( - frag_i32, c_rsrc, byte_off, offset_is_bytes=True - ) - - # Prefer 16B stores when possible: - # - EVec=4 => 4xf16 (8B) per store (and matches tile_n multiples of 128) - # - EVec=2 => 2xf16 (4B) per store (tile_n multiples of 64) - e_vec = 4 if (int(tile_n) % (32 * 4)) == 0 else 2 - mfma_epilog( - use_cshuffle=True, - arith=arith, - vector=vector, - gpu=gpu, - range_constexpr=range_constexpr, - tile_m=tile_m, - tile_n=tile_n, - e_vec=e_vec, - m_repeat=m_repeat, - num_acc_n=num_acc_n, - tx=tx, - lane_div_16=lane_div_16, - lane_mod_16=lane_mod_16, - bx_m=bx_m, - by_n=by_n, - n_tile_base=n_tile_base, - lds_out=lds_out, - write_row_to_lds=write_row_to_lds, - store_pair=store_pair, - ) - return - - def body_row(*, mi: int, ii: int, row_in_tile, row): - col_base = by_n + n_tile_base + lane_mod_16 - idx_base = flir.crd2idx(flir.make_coord(row, col_base), layout_c) - for ni in range_constexpr(num_acc_n): - acc_idx = mi * num_acc_n + ni - acc = final_accs[acc_idx] - val = vector.extract(acc, static_position=[ii], dynamic_position=[]) - # if is_int8: - # val = arith.sitofp(T.f32, val) - - # val_s = (val * s_a) * s_b_vals[ni] - val_f16 = arith.trunc_f(T.f16, val) - idx_out = idx_base + arith.constant(ni * 16, index=True) - buffer_ops.buffer_store(val_f16, c_rsrc, idx_out) - - mfma_epilog( - use_cshuffle=False, - arith=arith, - range_constexpr=range_constexpr, - m_repeat=m_repeat, - lane_div_16=lane_div_16, - bx_m=bx_m, - body_row=body_row, - ) - - # ---------------- Scheduling hints (match CK-style) ---------------- - # These sched_group_barrier hints help the backend interleave VMEM/DS/MFMA - # similarly to CK's tuned pipelines. - rocdl.sched_barrier(0) - - # def hot_loop_scheduler(): - # # - MFMA group size per "slot": num_acc_n - # # - Total MFMA per tile: (2*K32 per K64) * k_unroll * m_repeat * num_acc_n - # # - We emit (mfma_group + dsrd + mfma_group) per scheduler iteration. - # mfma_group = num_acc_n - # mfma_total = (k_unroll * 2) * m_repeat * mfma_group - # mfma_per_iter = 2 * mfma_group - # sche_iters = 0 if mfma_per_iter == 0 else (mfma_total // mfma_per_iter) - # - # # DS-read preload (CK default is 2). - # rocdl.sched_dsrd(2) - # rocdl.sched_mfma(1) - # if tile_m == 16: - # rocdl.sched_vmem(1) - # rocdl.sched_mfma(1) - # if tile_m == 16: - # rocdl.sched_vmem(1) - # if num_acc_n < 4: - # rocdl.sched_dsrd(1) - # rocdl.sched_mfma(1) - # if tile_m == 16: - # rocdl.sched_vmem(1) - # rocdl.sched_dsrd(1) - # rocdl.sched_mfma(1) - # if tile_m == 16: - # rocdl.sched_vmem(1) - # rocdl.sched_mfma(1) - # - # # DS-write hints near the end: match total A LDS-store micro-ops per thread. - # dswr_tail = num_a_loads - # if dswr_tail > sche_iters: - # dswr_tail = sche_iters - # dswr_start = sche_iters - dswr_tail - # - # for sche_i in range_constexpr(sche_iters): - # rocdl.sched_vmem(1) - # rocdl.sched_mfma(mfma_group) - # rocdl.sched_dsrd(1) - # rocdl.sched_mfma(mfma_group) - # if sche_i >= dswr_start - 1: - # rocdl.sched_dswr(1) - # - # rocdl.sched_barrier(0) - - # ---------------- Pipeline ---------------- - # LDS base offsets are in *elements* of `_elem_type()`. - # We keep LDS laid out as (tile_m, tile_k) in element units. - lds_tile_elems = arith.constant(tile_m * tile_k // a_elem_vec_pack, index=True) - lds_base0 = arith.constant(0, index=True) - lds_base1 = lds_tile_elems - - if lds_stage == 2: - # ---------------- Ping-pong pipeline (2 LDS buffers) ---------------- - # Cross-tile A0 LDS prefetch (default-on): - # issue the first A-pack DS read for the next tile *between* barriers, - # so it can overlap with the VMEM prefetch of the following tile. - - def prefetch_a0_pack(lds_base): - # (mi=0, ku=0): prefetch both K32 halves (K64) for the first A-pack. - return lds_load_packs_k64(row_a_lds, col_offset_base_bytes, lds_base) - - # Prologue: tile-0 - k0 = arith.constant(0, index=True) - a_regs0, b_tile0 = prefetch_ab_tile(k0) - a_scale_pong, b_scale_pong = prefetch_ab_scale_tile(k0 // 2) - - store_a_tile_to_lds(a_regs0, lds_base0) - gpu.barrier() - accs = [acc_init] * (num_acc_n * m_repeat) - - lds_base_pong = lds_base0 - lds_base_ping = lds_base1 - b_tile_pong = b_tile0 - c_k_main = c_k - tile_k - - # Prefetch A0 for the first compute tile (overlap with the next VMEM prefetch). - a0_prefetch_pong = prefetch_a0_pack(lds_base_pong) - - - num_tiles = K // tile_k - if (num_tiles % 2) == 1: - for k_iv in range(0, c_k_main, tile_k * 2): - next_k1 = k_iv + tile_k - a_regs_ping, b_tile_ping = prefetch_ab_tile(next_k1) - a_scale_ping, b_scale_ping = prefetch_ab_scale_tile(next_k1 // 256) - - accs, _ = compute_tile( - accs, - b_tile_pong, - lds_base_pong, - a0_prefetch=a0_prefetch_pong, - a_scale=a_scale_pong, - b_scale=b_scale_pong, - ) - a0_prefetch_pong = None - - store_a_tile_to_lds(a_regs_ping, lds_base_ping) - # hot_loop_scheduler() - gpu.barrier() - - # Cross-tile prefetch for the ping tile we are about to compute. - a0_prefetch_ping = prefetch_a0_pack(lds_base_ping) - - next_k2 = k_iv + tile_k * 2 - a_regs_pong, b_tile_pong = prefetch_ab_tile(next_k2) - a_scale_pong, b_scale_pong= prefetch_ab_scale_tile(next_k2 // 256) - - accs, _ = compute_tile( - accs, - b_tile_ping, - lds_base_ping, - a0_prefetch=a0_prefetch_ping, - a_scale=a_scale_ping, - b_scale=b_scale_ping, - ) - a0_prefetch_ping = None - - store_a_tile_to_lds(a_regs_pong, lds_base_pong) - # hot_loop_scheduler() - gpu.barrier() - - # Cross-tile prefetch for the next pong tile. - a0_prefetch_pong = prefetch_a0_pack(lds_base_pong) - - final_accs, _ = compute_tile( - accs, - b_tile_pong, - lds_base_pong, - a0_prefetch=a0_prefetch_pong, - a_scale=a_scale_pong, - b_scale=b_scale_pong, - ) - else: - c_k_stop = c_k - (tile_k * 3) - for k_iv in range(0, c_k_stop, tile_k * 2): - next_k1 = k_iv + tile_k - a_regs_ping, b_tile_ping = prefetch_ab_tile(next_k1) - a_scale_ping, b_scale_ping= prefetch_ab_scale_tile(next_k1 // 256) - - accs, _ = compute_tile( - accs, - b_tile_pong, - lds_base_pong, - a0_prefetch=a0_prefetch_pong, - a_scale=a_scale_pong, - b_scale=b_scale_pong, - ) - a0_prefetch_pong = None - - store_a_tile_to_lds(a_regs_ping, lds_base_ping) - # hot_loop_scheduler() - gpu.barrier() - - a0_prefetch_ping = prefetch_a0_pack(lds_base_ping) - - next_k2 = k_iv + tile_k * 2 - a_regs_pong, b_tile_pong = prefetch_ab_tile(next_k2) - a_scale_pong, b_scale_pong= prefetch_ab_scale_tile(next_k2 // 256) - - accs, _ = compute_tile( - accs, - b_tile_ping, - lds_base_ping, - a0_prefetch=a0_prefetch_ping, - a_scale=a_scale_ping, - b_scale=b_scale_ping, - ) - a0_prefetch_ping = None - - store_a_tile_to_lds(a_regs_pong, lds_base_pong) - # hot_loop_scheduler() - gpu.barrier() - - a0_prefetch_pong = prefetch_a0_pack(lds_base_pong) - - last_k = c_k - tile_k - a_regs_ping, b_tile_ping = prefetch_ab_tile(last_k) - a_scale_ping, b_scale_ping= prefetch_ab_scale_tile(last_k // 256) - - accs, _ = compute_tile( - accs, - b_tile_pong, - lds_base_pong, - a0_prefetch=a0_prefetch_pong, - a_scale=a_scale_pong, - b_scale=b_scale_pong, - ) - - a0_prefetch_pong = None - - store_a_tile_to_lds(a_regs_ping, lds_base_ping) - # hot_loop_scheduler() - gpu.barrier() - - a0_prefetch_ping = prefetch_a0_pack(lds_base_ping) - - final_accs, _ = compute_tile( - accs, - b_tile_ping, - lds_base_ping, - a0_prefetch=a0_prefetch_ping, - a_scale=a_scale_ping, - b_scale=b_scale_ping, - ) - - store_output(final_accs) - # else: - # # CK-like bpreshuffle v1 spirit: - # # - Intrawave schedule - # # - Global prefetch 2 (regs double-buffer) - # # - Local shared memory buffer 1 (single LDS tile for A) - # # Prologue: tile-0 - # k0 = arith.constant(0, index=True) - # a_regs0, b_tile0 = prefetch_ab_tile(k0) - # store_a_tile_to_lds(a_regs0, lds_base0) - # gpu.barrier() - # accs = [acc_init] * (num_acc_n * m_repeat) - # - # lds_base = lds_base0 - # b_tile_cur = b_tile0 - # - # # For each tile except last: prefetch next tile, compute current, then overwrite LDS. - # for k_base in range(0, c_k - tile_k, tile_k): - # next_k = k_base + tile_k - # a_next, b_next = prefetch_ab_tile(next_k) - # accs, _ = compute_tile(accs, b_tile_cur, lds_base) - # # Single LDS buffer: ensure *all* waves are done reading A from LDS - # # before any wave overwrites it with the next tile. - # gpu.barrier() - # store_a_tile_to_lds(a_next, lds_base) - # # hot_loop_scheduler() - # gpu.barrier() - # b_tile_cur = b_next - # - # final_accs, scales = compute_tile( - # accs, b_tile_cur, lds_base, is_last_tile=True - # ) - # store_output(final_accs, scales) - - @flir.jit - def __call__( - self: flir.T.i64, - arg_c: lambda: memref(DYN, T.f16), - arg_a: lambda: memref(DYN, _a_elem_type()), - arg_b: lambda: memref(DYN, _b_elem_type()), - arg_scale_a: lambda: memref(DYN, _scale_elem_type()), - arg_scale_b: lambda: memref(DYN, _scale_elem_type()), - c_m: lambda: T.index, - c_n: lambda: T.index, - c_k: lambda: T.index, - ): - c1 = arith.constant(1, index=True) - bdx = arith.constant(256, index=True) - tm = arith.constant(tile_m, index=True) - tn = arith.constant(tile_n, index=True) - one = arith.constant(1, index=True) - gx = (c_m + tm - one) / tm - gy = c_n / tn - - flir.gpu_ext.LaunchFuncOp( - [module_name, "kernel_gemm"], - grid_size=(gx, gy, c1), - block_size=(bdx, c1, c1), - kernel_operands=[ - arg_c, - arg_a, - arg_b, - arg_scale_a, - arg_scale_b, - c_m, - c_n, - c_k, - ], - ) - - m = _GEMM() - return flydsl.compile( - m, - use_bare_ptr_memref_call_conv=False, - use_bare_pointers_for_host=False, - use_bare_pointers_for_kernels=False, - ) - - -__all__ = ["compile_mxfp4_preshuffle_gemm"] - diff --git a/kernels/preshuffle_gemm.py b/kernels/preshuffle_gemm.py index 6280dbca..df715c5a 100644 --- a/kernels/preshuffle_gemm.py +++ b/kernels/preshuffle_gemm.py @@ -105,7 +105,8 @@ def compile_preshuffle_gemm( f"(tile_k={tile_k}, b_elem_bytes={b_elem_bytes})" ) - a_tile_k_bytes = int(tile_k) * int(a_elem_bytes) + # For FP4: A is packed (2 elems/byte), so actual bytes = tile_k * a_elem_bytes / a_elem_pack + a_tile_k_bytes = int(tile_k) * int(a_elem_bytes) // a_elem_pack gpu_arch = get_hip_arch() allocator = SmemAllocator(None, arch=gpu_arch) @@ -117,7 +118,8 @@ def compile_preshuffle_gemm( a_bytes_per_thread = pipeline_manager.get_a_bytes_per_thread(tile_m, tile_k) # CK-style LDS128: stride is in BYTES along K (for XOR16 swizzle). - lds_stride_bytes = a_tile_k_bytes // a_elem_pack + # a_tile_k_bytes already accounts for packing, so no further division needed + lds_stride_bytes = a_tile_k_bytes def _get_mfma_dict_value(key, pipeline): value = mfma_pipeline_dicts[key][pipeline] @@ -146,6 +148,9 @@ def _mfma_output_pack_ty(): is_int8 = mfma_pipeline in [MfmaPipeline.I8I8_16x16_PIPELINE] is_fp4 = mfma_pipeline in [MfmaPipeline.F4F4_MXFP4_PIPELINE, MfmaPipeline.F8F4_MXFP4_PIPELINE] + print(is_fp4) + + no_epilogue_dequant = is_fp4 or is_f16_or_bf16 # 350 16x16x128 adtype(cbsz) & bdtype(blgp) cbsz = 4 if mfma_pipeline == MfmaPipeline.F4F4_MXFP4_PIPELINE else 0 @@ -214,9 +219,11 @@ def kernel_gemm( layout_a_div4 = flir.make_layout((c_m, c_k_div4bytes), stride=(c_k_div4bytes, 1)) # B preshuffle layout (shared with MoE kernels). - b_kpack_bytes = 16 // b_elem_pack + # For FP4: B is packed (2 elems/byte), so adjust c_k accordingly + b_kpack_bytes = 16 + c_k_b = c_k // b_elem_pack layout_b = make_preshuffle_b_layout( - flir, arith, c_n=c_n, c_k=c_k, kpack_bytes=b_kpack_bytes, elem_bytes=b_elem_bytes + flir, arith, c_n=c_n, c_k=c_k_b, kpack_bytes=b_kpack_bytes, elem_bytes=b_elem_bytes ).layout_b # Scale layouts for FP4/MXFP4 (block-scale MFMA). @@ -225,8 +232,10 @@ def kernel_gemm( # LDS layout is element-indexed, but XOR16 swizzle is byte-based. # Represent LDS as (tile_m, tile_k) in elements and scale swizzle math by elem_bytes. - shape_lds = flir.make_shape(tile_m, tile_k) - stride_lds = flir.make_stride(tile_k, 1) + # For FP4: A is packed (2 elems/byte), so LDS K dimension is tile_k / a_elem_pack + lds_tile_k = tile_k // a_elem_pack if is_fp4 else tile_k + shape_lds = flir.make_shape(tile_m, lds_tile_k) + stride_lds = flir.make_stride(lds_tile_k, 1) layout_lds = flir.make_layout(shape_lds, stride_lds) # CK-style XOR16 swizzle parameter (const). @@ -286,8 +295,13 @@ def kernel_gemm( ) m_repeat = tile_m // 16 - # K stepping is byte-addressed: one "micro-tile step" is 64 bytes. - k_unroll = b_tile_k_bytes // 64 + # K stepping is byte-addressed: + # - For FP4/MXFP4 (mfma 16x16x128): one micro-step is 128 bytes + # - For FP8/INT8 (mfma 16x16x32): one micro-step is 64 bytes (K32 x 2) + if is_fp4: + k_unroll = b_tile_k_bytes // 64 // b_elem_pack + else: + k_unroll = b_tile_k_bytes // 64 # --- Dynamic tiling along N (4 waves) --- num_waves = 4 @@ -374,7 +388,8 @@ def prefetch_ab_scale_tile(base_k): """Prefetch A and B scale tiles for FP4/MXFP4.""" if not is_fp4: return None, None - return load_a_scale_tile(base_k), load_b_scale_tile(base_k) + # return load_a_scale_tile(base_k), load_b_scale_tile(base_k) + return load_a_scale_tile(0), load_b_scale_tile(0) # --- B load logic --- # Shared loader supports: @@ -496,7 +511,11 @@ def lds_load_packs_k64(curr_row_a_lds, col_base, lds_base): num_a_loads = a_bytes_per_thread // a_load_bytes # A tile mapping in dwords along K: # tile_k_dwords = (tile_k * elem_bytes) / 4 - a_tile_k_dwords = tile_k * a_elem_bytes // 4 + # For FP4: A is packed (2 elems/byte), so divide by a_elem_pack + if is_fp4: + a_tile_k_dwords = tile_k * a_elem_bytes // 4 // a_elem_pack + else: + a_tile_k_dwords = tile_k * a_elem_bytes // 4 layout_a_tile_div4 = flir.make_layout((tile_m, a_tile_k_dwords), stride=(a_tile_k_dwords, 1)) c4 = arith.constant(4, index=True) tx_i32_base = tx * c4 @@ -564,15 +583,20 @@ def prefetch_ab_tile(base_k): # Convert element index to byte index, then to dword index. base_k_bytes = base_k * arith.constant(int(a_elem_bytes), index=True) base_k_div4 = base_k_bytes / 4 - a_regs = load_a_tile(base_k_div4) - b_regs = load_b_tile(base_k) + if is_fp4: + # For FP4/MXFP4: A and B are packed (2 elems/byte), need to adjust offsets + a_regs = load_a_tile(base_k_div4 // a_elem_pack) + b_regs = load_b_tile(base_k // b_elem_pack) + else: + a_regs = load_a_tile(base_k_div4) + b_regs = load_b_tile(base_k) return a_regs, b_regs def compute_tile(accs_in, b_tile_in, lds_base, *, is_last_tile=False, a0_prefetch=None, a_scale=None, b_scale=None): scales_pf = {} mfma_res_ty = _mfma_output_pack_ty() - if is_last_tile and (not is_f16_or_bf16) and (not is_fp4): + if is_last_tile and (not no_epilogue_dequant): # Prefetch scales for non-FP4 scaled paths (fp8/int8/int4 with per-tensor scale). s_b_vals = [] for ni in range_constexpr(num_acc_n): @@ -612,7 +636,7 @@ def compute_tile(accs_in, b_tile_in, lds_base, *, is_last_tile=False, a0_prefetc a_elem_vec_pack=a_elem_pack, k_unroll=k_unroll, m_repeat=m_repeat, - num_acc=num_acc_n, + num_acc_n=num_acc_n, pack_K=pack_K, pack_M=pack_M, pack_N=pack_N, @@ -661,7 +685,7 @@ def compute_tile(accs_in, b_tile_in, lds_base, *, is_last_tile=False, a0_prefetc def store_output(final_accs, scales): # fp16/bf16: no scale fetch, no scale multiply in epilogue. - if is_f16_or_bf16: + if no_epilogue_dequant: s_b_vals = None s_a_vecs = None else: @@ -716,7 +740,7 @@ def write_row_to_lds( # if is_int8: if is_int8 or is_int4: val = arith.sitofp(T.f32, val) - if is_f16_or_bf16: + if no_epilogue_dequant: val_s = val else: val_s = (val * s_a) * s_b_vals[ni] @@ -823,7 +847,7 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): return def body_row(*, mi: int, ii: int, row_in_tile, row): - if not is_f16_or_bf16: + if not no_epilogue_dequant: s_a_vec4 = s_a_vecs[mi] s_a = vector.extract(s_a_vec4, static_position=[ii], dynamic_position=[]) col_base = by_n + n_tile_base + lane_mod_16 @@ -835,7 +859,7 @@ def body_row(*, mi: int, ii: int, row_in_tile, row): if is_int8 or is_int4: # INT8/INT4 paths use i32 accumulators; convert to f32 for scaled epilogue. val = arith.sitofp(T.f32, val) - if is_f16_or_bf16: + if no_epilogue_dequant: val_s = val else: val_s = (val * s_a) * s_b_vals[ni] @@ -905,7 +929,9 @@ def hot_loop_scheduler(): # ---------------- Pipeline ---------------- # LDS base offsets are in *elements* of `_elem_type()`. # We keep LDS laid out as (tile_m, tile_k) in element units. - lds_tile_elems = arith.constant(tile_m * tile_k, index=True) + # For FP4: A is packed (2 elems/byte), so divide by a_elem_pack + lds_tile_elems_val = tile_m * tile_k // a_elem_pack if is_fp4 else tile_m * tile_k + lds_tile_elems = arith.constant(lds_tile_elems_val, index=True) lds_base0 = arith.constant(0, index=True) lds_base1 = lds_tile_elems @@ -923,7 +949,7 @@ def prefetch_a0_pack(lds_base): k0 = arith.constant(0, index=True) a_regs0, b_tile0 = prefetch_ab_tile(k0) # Prefetch scales for FP4/MXFP4 at tile-0. - a_scale_pong, b_scale_pong = prefetch_ab_scale_tile(k0 // 2) if is_fp4 else (k0, k0) + a_scale_pong, b_scale_pong = prefetch_ab_scale_tile(k0 // 256) if is_fp4 else (k0, k0) store_a_tile_to_lds(a_regs0, lds_base0) gpu.barrier() @@ -1125,4 +1151,4 @@ def __call__( ) -__all__ = ["compile_preshuffle_gemm"] \ No newline at end of file +__all__ = ["compile_preshuffle_gemm"] diff --git a/tests/kernels/test_preshuffle_gemm.py b/tests/kernels/test_preshuffle_gemm.py index 33976829..f6d31606 100644 --- a/tests/kernels/test_preshuffle_gemm.py +++ b/tests/kernels/test_preshuffle_gemm.py @@ -516,7 +516,7 @@ def launch_kernel(c, a, b, sa, sb): else: pack_M = 2 test_mfma_w4_flir_preshuffle( - args.in_dtype if args.in_dtype == "fp8" else "fp4", + args.a_dtype if args.a_dtype == "fp8" else "fp4", "fp4", M=args.M, N=args.N, From d5241c3ee565ca56c6a46795c9432a899ed07a22 Mon Sep 17 00:00:00 2001 From: zanzhang Date: Wed, 4 Feb 2026 15:54:56 +0800 Subject: [PATCH 04/11] update --- kernels/moe_gemm_2stage.py | 522 +++++++++++++++++++-------------- kernels/preshuffle_gemm.py | 1 - tests/kernels/test_moe_gemm.py | 183 +++++++----- 3 files changed, 412 insertions(+), 294 deletions(-) diff --git a/kernels/moe_gemm_2stage.py b/kernels/moe_gemm_2stage.py index 9f7f0ba0..21921e8a 100644 --- a/kernels/moe_gemm_2stage.py +++ b/kernels/moe_gemm_2stage.py @@ -35,6 +35,9 @@ make_preshuffle_b_layout, load_b_pack_k32, tile_chunk_coord_i32, + PreshufflePipelineManager, + mfma_pipeline_dicts, + make_preshuffle_scale_layout, ) from kernels.mfma_epilogues import c_shuffle_epilog, default_epilog, mfma_epilog @@ -51,7 +54,8 @@ def compile_moe_gemm1( tile_k: int, # NOTE: aiter swap passes these for API symmetry; stage1 uses dynamic memrefs so they are ignored. doweight_stage1: bool, - in_dtype: str = "fp8", + x_dtype: str = "fp8", + w_dtype: str = "fp8", out_dtype: str = "f16", use_cshuffle_epilog: bool | None = None, ): @@ -67,51 +71,84 @@ def compile_moe_gemm1( allocator = SmemAllocator(None, arch=gpu_arch) _state = {} - if in_dtype not in ("fp8", "fp16", "int8", "int4"): - raise ValueError(f"in_dtype must be one of ('fp8','fp16','int8','int4'), got {in_dtype!r}") - is_f16 = in_dtype == "fp16" - elem_bytes = 2 if is_f16 else 1 - if out_dtype not in ("f16", "bf16"): - raise ValueError(f"out_dtype must be 'f16' or 'bf16', got {out_dtype!r}") - # NOTE: don't materialize MLIR types outside an active MLIR Context. - out_mlir = lambda: (T.f16() if out_dtype == "f16" else T.bf16()) - tile_k_bytes = int(tile_k) * int(elem_bytes) + total_threads = 256 + + pipeline_manager = PreshufflePipelineManager(x_dtype, w_dtype, out_dtype) + pipeline_manager.check_type_valid() + + epilog_pipeline = pipeline_manager.get_epilog_pipeline() + mfma_pipeline = pipeline_manager.get_mfma_pipeline() + mfma_fn = pipeline_manager.get_mfma_fn() + + x_elem_bytes = pipeline_manager.get_a_elem_bytes() + w_elem_bytes = pipeline_manager.get_b_elem_bytes() + out_elem_bytes = pipeline_manager.get_out_elem_bytes() + + x_elem_pack = pipeline_manager.a_elem_pack + w_elem_pack = pipeline_manager.b_elem_pack + + # pack_K is only used for FP4 modes which need special packing + is_fp4_mode = mfma_pipeline in [MfmaPipeline.F4F4_MXFP4_PIPELINE, + MfmaPipeline.F8F4_MXFP4_PIPELINE] + pack_M = 2 if is_fp4_mode else 1 + pack_N = 2 if is_fp4_mode else 1 + pack_K = 2 if is_fp4_mode else 1 + + tile_k_bytes = int(tile_k) * int(x_elem_bytes) # K64-byte micro-step: always 64 bytes per `ku`. For fp16 this is 32 elements. if (tile_k_bytes % 64) != 0: raise ValueError( f"tile_k_bytes must be divisible by 64, got tile_k_bytes={tile_k_bytes} " - f"(tile_k={tile_k}, elem_bytes={elem_bytes})" + f"(tile_k={tile_k}, elem_bytes={x_elem_bytes})" ) - is_int4 = in_dtype == "int4" - # INT4 here means W4A8: X is int8, W is packed int4 and unpacked to int8 in-kernel. - is_int8 = (in_dtype == "int8") or is_int4 - - mfma_i32_k32 = None - if is_int8: - mfma_i32_k32 = getattr(rocdl, "mfma_i32_16x16x32i8", None) or getattr( - rocdl, "mfma_i32_16x16x32_i8", None - ) - if mfma_i32_k32 is None: - raise AttributeError( - "INT8 K32 MFMA op not found: expected `rocdl.mfma_i32_16x16x32i8` " - "(or `rocdl.mfma_i32_16x16x32_i8`)." - ) + + def _get_mfma_dict_value(key, pipeline): + value = mfma_pipeline_dicts[key][pipeline] + return value() if callable(value) else value + + def _x_elem_type(): + return _get_mfma_dict_value("a_elem_type", mfma_pipeline) + def _w_elem_type(): + return _get_mfma_dict_value("b_elem_type", mfma_pipeline) + def _scale_elem_type(): + return _get_mfma_dict_value("scale_elem_type", mfma_pipeline) + def _out_elem_type(): + return _get_mfma_dict_value("out_elem_type", epilog_pipeline) + def _x_vec16_type(): + return _get_mfma_dict_value("x_vec16_type", mfma_pipeline) + def _w_vec16_type(): + return _get_mfma_dict_value("w_vec16_type", mfma_pipeline) + def _mfma_input_pack_ty(): + return _get_mfma_dict_value("mfma_input_pack_ty", mfma_pipeline) + def _mfma_output_pack_ty(): + return _get_mfma_dict_value("mfma_output_pack_ty", mfma_pipeline) + + is_f16_or_bf16 = mfma_pipeline in [MfmaPipeline.F16F16_16x16_PIPELINE, + MfmaPipeline.BF16BF16_16x16_PIPELINE] + is_int4 = mfma_pipeline in [MfmaPipeline.I8I4_16x16_PIPELINE] + is_int8 = mfma_pipeline in [MfmaPipeline.I8I8_16x16_PIPELINE] + is_fp4 = mfma_pipeline in [MfmaPipeline.F4F4_MXFP4_PIPELINE, + MfmaPipeline.F8F4_MXFP4_PIPELINE] + + no_epilogue_dequant = is_fp4 or is_f16_or_bf16 DYN = ir.ShapedType.get_dynamic_size() size_out = DYN size_x = DYN + # W is packed int4 for W4A8: 2 values per byte. - size_w = (experts * (2 * inter_dim) * model_dim) // 2 if is_int4 else (experts * (2 * inter_dim) * model_dim) + size_w = (experts * (2 * inter_dim) * model_dim) // w_elem_pack size_sorted = DYN size_expert_ids = DYN total_threads = 256 - bytes_x_per_tile = int(tile_m) * int(tile_k) * int(elem_bytes) + bytes_x_per_tile = int(tile_m) * int(tile_k) * int(x_elem_bytes) // x_elem_pack if bytes_x_per_tile % total_threads != 0: raise ValueError( "tile_m*tile_k*elem_bytes must be divisible by " - f"{total_threads}: tile_m={tile_m}, tile_k={tile_k}, elem_bytes={elem_bytes}" + f"{total_threads}: tile_m={tile_m}, tile_k={tile_k}, elem_bytes={x_elem_bytes}" ) + bytes_per_thread_x = bytes_x_per_tile // total_threads # Keep MoE stage1 X gmem->LDS pipeline consistent with the optimized GEMM kernel: # split into <=16B pieces and use `flir.copy(load-only)` for buffer_load_dwordx4. @@ -125,6 +162,7 @@ def compile_moe_gemm1( lds_stride = tile_k + pad_k if use_cshuffle_epilog is None: use_cshuffle_epilog = os.environ.get("FLIR_MOE_STAGE1_CSHUFFLE", "1") in ("1", "true", "True", "YES", "yes") + use_cshuffle_epilog = bool(use_cshuffle_epilog) if out_dtype != "f16" and use_cshuffle_epilog: raise ValueError("stage1 cshuffle epilog currently supports only f16 output (out_dtype='f16')") @@ -133,7 +171,7 @@ def compile_moe_gemm1( # IMPORTANT: module name participates in FlyDSL's compile cache key. # Keep an explicit ABI tag so signature changes can't accidentally reuse an old binary. module_name = ( - f"mfma_moe1_{in_dtype}_{out_dtype}_{epilog_tag}" + f"mfma_moe1_{x_dtype}_{w_dtype}_{out_dtype}_{epilog_tag}" f"_t{tile_m}x{tile_n}x{tile_k}" ).replace("-", "_") @@ -149,22 +187,23 @@ def init_gpu_module(self): # - ping-pong X tiles (2 * tile_m * lds_stride bytes; fp8/int8) # - epilogue CShuffle tile (tile_m * tile_n f16 -> 2 * tile_m * tile_n bytes) _use_cshuffle_epilog = bool(use_cshuffle_epilog) - lds_x_bytes = 2 * int(tile_m) * int(lds_stride) * int(elem_bytes) + lds_x_bytes = 2 * int(tile_m) * int(lds_stride) * int(x_elem_bytes) lds_out_bytes = 2 * tile_m * tile_n if _use_cshuffle_epilog else 0 lds_total_bytes = max(lds_x_bytes, lds_out_bytes) - lds_total_elems = lds_total_bytes if elem_bytes == 1 else (lds_total_bytes // 2) - x_lds_elem = I.f16 if is_f16 else (I.i8 if is_int8 else I.f8) + lds_total_elems = lds_total_bytes if x_elem_bytes == 1 else (lds_total_bytes // 2) + # x_lds_elem = I.f16 if is_f16 else (I.i8 if is_int8 else I.f8) + x_lds_elem = I.f16 if is_f16_or_bf16 else (I.i8 if is_int8 else I.f8) _state["lds_x_decl"] = allocator.allocate_array(x_lds_elem, lds_total_elems) allocator.finalize() @flir.kernel def moe_gemm1( self: flir.T.i64, - arg_out: lambda: T.memref(DYN, out_mlir()), - arg_x: lambda: T.memref(DYN, I.f16 if is_f16 else (I.i8 if is_int8 else I.f8)), - arg_w: lambda: T.memref(DYN, I.f16 if is_f16 else (I.i8 if is_int8 else I.f8)), - arg_scale_x: lambda: T.memref(DYN, T.f32()), - arg_scale_w: lambda: T.memref(experts * (2 * inter_dim), T.f32()), + arg_out: lambda: T.memref(DYN, _out_elem_type()), + arg_x: lambda: T.memref(DYN, _x_elem_type()), + arg_w: lambda: T.memref(DYN, _w_elem_type()), + arg_scale_x: lambda: T.memref(DYN, _scale_elem_type()), + arg_scale_w: lambda: T.memref(DYN, _scale_elem_type()), arg_sorted_token_ids: lambda: T.memref(DYN, T.i32()), arg_expert_ids: lambda: T.memref(DYN, T.i32()), arg_sorted_weights: lambda: T.memref(DYN, T.f32()), @@ -174,9 +213,9 @@ def moe_gemm1( k_in: lambda: T.index(), size_expert_ids_in: lambda: T.index(), ): - x_elem = I.f16 if is_f16 else (I.i8 if is_int8 else I.f8) + x_elem = _x_elem_type() # For int4, weights are stored as packed bytes (i8) and unpacked to i8 packs. - w_elem = I.f16 if is_f16 else (I.i8 if is_int8 else I.f8) + w_elem = _w_elem_type() f16 = I.f16 f32 = I.f32 i32 = I.i32 @@ -185,11 +224,11 @@ def moe_gemm1( vec4_i32 = I.vec(4, i32) vec1_f16 = I.vec(1, f16) vec4_f16 = I.vec(4, f16) - vec16_elems = 16 if elem_bytes == 1 else 8 - vec8_elems = 8 if elem_bytes == 1 else 4 - vec4_elems = 4 if elem_bytes == 1 else 2 - vec8_x = I.vec(vec8_elems, x_elem) - vec16_x = I.vec(vec16_elems, x_elem) + vec16_x_elems = 16 // x_elem_bytes # if x_elem_bytes == 1 else 8 + vec8_x_elems = 8 // x_elem_bytes # if x_elem_bytes == 1 else 4 + vec4_x_elems = 4 // x_elem_bytes # if x_elem_bytes == 1 else 2 + vec8_x = I.vec(vec8_x_elems, x_elem) + vec16_x = I.vec(vec16_x_elems, x_elem) vec1_i64 = I.vec(1, i64) vec2_i64 = I.vec(2, i64) @@ -206,12 +245,23 @@ def silu(x): den = 1.0 + emu sig = llvm.call_intrinsic(f32, "llvm.amdgcn.rcp.f32", [den], [], []) return x * sig + + def swiglu(gate, up, alpha=1.702, limit=7.0): + # Align with CK's device fast path + # + # Using llvm.amdgcn intrinsics prevents lowering to the div_scale/div_fixup + # sequences that introduce extra compares/cndmasks. + gate = arith.minimum(gate, limit) + up = arith.minimum(up, limit) + up = arith.maximum(up, -limit) - acc_init = ( - arith.constant_vector(0, vec4_i32) - if is_int8 - else arith.constant_vector(0.0, vec4_f32) - ) + t = gate * (alpha) * (-1.4426950408889634) # -log2(e) + emu = llvm.call_intrinsic(f32, "llvm.amdgcn.exp2.f32", [t], [], []) + den = 1.0 + emu + sig = llvm.call_intrinsic(f32, "llvm.amdgcn.rcp.f32", [den], [], []) + return gate * sig * (up + 1) + + acc_init = (arith.constant_vector(0, _mfma_output_pack_ty())) # Layouts layout_x = flir.make_layout((tokens_in, k_in), stride=(k_in, 1)) @@ -220,15 +270,22 @@ def silu(x): c_n_total = arith.constant(experts * (2 * inter_dim), index=True) kpack_bytes = 8 if is_int4 else 16 b_layout = make_preshuffle_b_layout( - flir, arith, c_n=c_n_total, c_k=k_in, kpack_bytes=kpack_bytes, elem_bytes=elem_bytes + flir, arith, c_n=c_n_total, c_k=k_in // pack_K, kpack_bytes=kpack_bytes, elem_bytes=w_elem_bytes ) layout_b = b_layout.layout_b - c_k0 = (k_in * arith.constant(int(elem_bytes), index=True)) / arith.index(64) shape_lds = flir.make_shape(tile_m, tile_k) stride_lds = flir.make_stride(lds_stride, 1) layout_lds = flir.make_layout(shape_lds, stride_lds) + # A&B's scale preshuffle layout + layout_a_scale = make_preshuffle_scale_layout( + flir, arith, c_mn=tokens_in, c_k=k_in, + ) + layout_b_scale = make_preshuffle_scale_layout( + flir, arith, c_mn=c_n_total, c_k=k_in, + ) + tx = gpu.thread_id("x") # Align with Aiter launch mapping (NSwizzle==false): # - blockIdx.x -> N dimension (tile along inter_dim) @@ -239,23 +296,25 @@ def silu(x): # Block validity: compute as early as possible so invalid blocks skip all buffer-resource # setup, LDS pointer math, and gmem prefetch work. bx_m = bx * arith.constant(tile_m, index=True) + maxids_rsrc = buffer_ops.create_buffer_resource( arg_max_token_ids, max_size=False, num_records_bytes=arith.i32(4) ) max_token_id_i32 = buffer_ops.buffer_load( maxids_rsrc, arith.constant(0, index=True), vec_width=1, dtype=i32 ) + bx_m_i32 = arith.index_cast(i32, bx_m) blk_valid = arith.cmpu(bx_m_i32, max_token_id_i32, "ult") # Common constants/atoms (hoisted): keep IR small like GEMM. # XOR16 swizzle parameter (in bytes; constant, power-of-two in our configs). k_blocks16 = arith.constant(tile_k_bytes // 16, index=True) - atom_x_s16 = flir.make_copy_atom(x_elem, vector_size=vec16_elems) - atom_x_s8 = flir.make_copy_atom(x_elem, vector_size=8) - atom_x_s4 = flir.make_copy_atom(x_elem, vector_size=4) - atom_x_g2r16 = flir.make_copy_atom(x_elem, vector_size=vec16_elems) - atom_x_g2r8 = flir.make_copy_atom(x_elem, vector_size=vec8_elems) - atom_x_g2r4 = flir.make_copy_atom(x_elem, vector_size=vec4_elems) + atom_x_s16 = flir.make_copy_atom(x_elem, vector_size=vec16_x_elems) + atom_x_s8 = flir.make_copy_atom(x_elem, vector_size=vec8_x_elems) + atom_x_s4 = flir.make_copy_atom(x_elem, vector_size=vec4_x_elems) + atom_x_g2r16 = flir.make_copy_atom(x_elem, vector_size=vec16_x_elems) + atom_x_g2r8 = flir.make_copy_atom(x_elem, vector_size=vec8_x_elems) + atom_x_g2r4 = flir.make_copy_atom(x_elem, vector_size=vec4_x_elems) layout_tx_wave_lane = flir.make_layout((4, 64), stride=(64, 1)) layout_lane16 = flir.make_layout((4, 16), stride=(16, 1)) @@ -279,7 +338,7 @@ def silu(x): c_topk = arith.constant(topk, index=True) # X: [tokens, k] bytes = tokens*k*elem_bytes - x_nbytes_idx = tokens_in * k_in * arith.constant(int(elem_bytes), index=True) + x_nbytes_idx = tokens_in * k_in * arith.constant(int(x_elem_bytes), index=True) x_nbytes_i32 = arith.index_cast(i32, x_nbytes_idx) x_rsrc = buffer_ops.create_buffer_resource( arg_x, max_size=False, num_records_bytes=x_nbytes_i32 @@ -288,7 +347,6 @@ def silu(x): w_rsrc = buffer_ops.create_buffer_resource(arg_w, max_size=False) # OUT: [tokens, topk, inter] f16/bf16 -> bytes = tokens*topk*inter*out_elem_bytes - out_elem_bytes = 2 # f16/bf16 out_nbytes_idx = tokens_in * c_topk * inter_in * arith.constant(out_elem_bytes, index=True) out_nbytes_i32 = arith.index_cast(i32, out_nbytes_idx) out_rsrc = buffer_ops.create_buffer_resource( @@ -296,7 +354,7 @@ def silu(x): ) # fp16 path ignores scales completely (implicit scale=1.0). - if is_f16: + if no_epilogue_dequant: sx_rsrc = None sw_rsrc = None else: @@ -326,7 +384,7 @@ def silu(x): # ---- X gmem->reg prefetch (match preshuffle GEMM mapping) ---- # Prefer 16B buffer-load (dwordx4). If the per-thread byte count isn't divisible by # 16, fall back to 8B (dwordx2) or 4B (dword) loads. For fp16 we require 16B. - if is_f16: + if is_f16_or_bf16: if bytes_per_thread_x % 16 != 0: raise ValueError( f"[fp16] bytes_per_thread_x ({bytes_per_thread_x}) must be divisible by 16" @@ -346,9 +404,9 @@ def silu(x): num_x_loads = bytes_per_thread_x // x_load_bytes chunk_i32 = x_load_bytes // 4 # dwords per chunk (1/2/4) - c_k_div4 = (k_in * arith.constant(int(elem_bytes), index=True)) / arith.index(4) + c_k_div4 = (k_in * arith.constant(int(x_elem_bytes), index=True)) / arith.index(4) layout_x_div4 = flir.make_layout((tokens_in, c_k_div4), stride=(c_k_div4, 1)) - tile_k_dwords = (int(tile_k) * int(elem_bytes)) // 4 + tile_k_dwords = (int(tile_k) * int(x_elem_bytes)) // 4 layout_x_tile_div4 = flir.make_layout((tile_m, tile_k_dwords), stride=(tile_k_dwords, 1)) c_chunk_i32 = arith.constant(chunk_i32, index=True) tx_i32_base = tx * c_chunk_i32 @@ -396,7 +454,7 @@ def load_x(idx_i32): For 16B, keep the fast dwordx4 path. For 8B/4B, use byte offsets. """ if x_load_bytes == 16: - idx_elem = idx_i32 if elem_bytes == 1 else (idx_i32 * arith.index(2)) + idx_elem = idx_i32 if x_elem_bytes == 1 else (idx_i32 * arith.index(2)) return buffer_copy_gmem16_dwordx4( flir, arg=arg_x, @@ -404,7 +462,7 @@ def load_x(idx_i32): idx_i32=idx_elem, atom_g2r16=atom_x_g2r16, rsrc=x_rsrc, - vec_elems=vec16_elems, + vec_elems=vec16_x_elems, ) idx_bytes = idx_i32 * arith.index(4) atom = atom_x_g2r8 if x_load_bytes == 8 else atom_x_g2r4 @@ -427,7 +485,7 @@ def load_x(idx_i32): def load_x_tile(base_k): """Prefetch the per-thread X tile portion (gmem -> regs) for a given K base (in elements).""" - base_k_div4 = (base_k * arith.constant(int(elem_bytes), index=True)) / arith.index(4) + base_k_div4 = (base_k * arith.constant(int(x_elem_bytes), index=True)) / arith.index(4) parts = [] for i in range_constexpr(num_x_loads): idx_i32 = x_row_base_div4[i] + base_k_div4 + x_col_local_i32[i] @@ -453,8 +511,8 @@ def load_x_tile(base_k): col_offset_base = flir.crd2idx(flir.make_coord(lane_div_16, 0), layout_lane16) col_offset_base_bytes = ( col_offset_base - if elem_bytes == 1 - else (col_offset_base * arith.constant(int(elem_bytes), index=True)) + if x_elem_bytes == 1 + else (col_offset_base * arith.constant(int(x_elem_bytes), index=True)) ) # Dynamic N tiling within block (same as existing kernels) @@ -496,7 +554,7 @@ def load_x_tile(base_k): m_repeat = tile_m // 16 - k_unroll = tile_k_bytes // 64 # K64-byte micro-step (2x MFMA) + k_unroll = tile_k_bytes // 64 // pack_K # --- B Load Logic (K64) - shared layout with preshuffle GEMM --- def load_b_pack(base_k, ki_step, ni, blk_list, intra_list): @@ -515,7 +573,7 @@ def load_b_pack(base_k, ki_step, ni, blk_list, intra_list): lane_div_16=lane_div_16, # 0..3 elem_type=w_elem, kpack_bytes=kpack_bytes, - elem_bytes=elem_bytes, + elem_bytes=w_elem_bytes, unpack_int4=is_int4, ) @@ -563,7 +621,7 @@ def store_x_tile_to_lds(vec_x_in_parts, lds_base): k_blocks16=k_blocks16, lds_base=lds_base, vec_part_i32x4=vec_x_in_parts[i], - elem_bytes=elem_bytes, + elem_bytes=x_elem_bytes, ) elif x_load_bytes == 8: lds_store_8b_xor16( @@ -605,8 +663,8 @@ def lds_load_packs_k64(curr_row_a_lds, col_base_bytes, lds_base): col_base_swz_bytes = flir.swizzle_xor16(curr_row_a_lds, col_base_bytes, k_blocks16) col_base_swz = ( col_base_swz_bytes - if elem_bytes == 1 - else (col_base_swz_bytes / arith.constant(int(elem_bytes), index=True)) + if x_elem_bytes == 1 + else (col_base_swz_bytes / arith.constant(int(x_elem_bytes), index=True)) ) coord_a16 = flir.make_coord(curr_row_a_lds, col_base_swz) idx_a16 = flir.crd2idx(coord_a16, layout_lds) @@ -629,17 +687,12 @@ def compute_tile( ): gate_list = list(acc_gate_in) up_list = list(acc_up_in) - mfma_res_ty = vec4_i32 if is_int8 else vec4_f32 - mfma_fn = ( - mfma_i32_k32 - if is_int8 - else (rocdl.mfma_f32_16x16x16f16 if is_f16 else rocdl.mfma_f32_16x16x32_fp8_fp8) - ) + mfma_res_ty = _mfma_output_pack_ty() # Optional: prefetch epilogue scales while we are about to run the last MFMA tile, # matching the preshuffle GEMM pattern of overlapping scale loads with MFMA. epilogue_pf = None - if prefetch_epilogue: + if prefetch_epilogue and not no_epilogue_dequant: expert_off_pf = expert_off_idx sw_gate_pf = [] sw_up_pf = [] @@ -648,14 +701,10 @@ def compute_tile( row_gate_idx = expert_off_pf + col_g row_up_idx = row_gate_idx + inter_idx sw_gate_pf.append( - arith.f32(1.0) - if is_f16 - else buffer_ops.buffer_load(sw_rsrc, row_gate_idx, vec_width=1, dtype=f32) + buffer_ops.buffer_load(sw_rsrc, row_gate_idx, vec_width=1, dtype=f32) ) sw_up_pf.append( - arith.f32(1.0) - if is_f16 - else buffer_ops.buffer_load(sw_rsrc, row_up_idx, vec_width=1, dtype=f32) + buffer_ops.buffer_load(sw_rsrc, row_up_idx, vec_width=1, dtype=f32) ) epilogue_pf = (sw_gate_pf, sw_up_pf) @@ -664,7 +713,7 @@ def _i64_to_v4f16(x_i64): return vector.bitcast(vec4_f16, v1) def mfma_k64(acc_in, a0, a1, b0, b1): - if is_f16: + if is_f16_or_bf16: a0v = _i64_to_v4f16(a0) a1v = _i64_to_v4f16(a1) b0v = _i64_to_v4f16(b0) @@ -855,7 +904,7 @@ def hot_loop_scheduler(): if epilogue_pf is not None: sw_gate_vals, sw_up_vals = epilogue_pf - else: + elif not no_epilogue_dequant: sw_gate_vals = [] sw_up_vals = [] for ni in range_constexpr(num_acc_n): @@ -863,14 +912,10 @@ def hot_loop_scheduler(): row_gate_idx = expert_off + col_g row_up_idx = row_gate_idx + inter_idx sw_gate_vals.append( - arith.f32(1.0) - if is_f16 - else buffer_ops.buffer_load(sw_rsrc, row_gate_idx, vec_width=1, dtype=f32) + buffer_ops.buffer_load(sw_rsrc, row_gate_idx, vec_width=1, dtype=f32) ) sw_up_vals.append( - arith.f32(1.0) - if is_f16 - else buffer_ops.buffer_load(sw_rsrc, row_up_idx, vec_width=1, dtype=f32) + buffer_ops.buffer_load(sw_rsrc, row_up_idx, vec_width=1, dtype=f32) ) # Epilogue hoists to keep IR + Python build time small: @@ -904,7 +949,7 @@ def write_row_to_lds( t2 = fused2 & mask24_i32 # No explicit mask: rely on buffer descriptor OOB to zero-fill when t2 is the # sentinel (t2 == tokens) or otherwise out-of-range. - sx = arith.f32(1.0) if is_f16 else buffer_ops.buffer_load(sx_rsrc, t2, vec_width=1, dtype=f32) + sx = arith.f32(1.0) if no_epilogue_dequant else buffer_ops.buffer_load(sx_rsrc, t2, vec_width=1, dtype=f32) # Sorted weight aligned with `row` (matches aiter moe_sorting output). if doweight_stage1: @@ -912,8 +957,6 @@ def write_row_to_lds( for ni in range_constexpr(num_acc_n): col_local = col_base_local + (ni * 16) - sw_gate = sw_gate_vals[ni] - sw_up = sw_up_vals[ni] acc_idx = mi * num_acc_n + ni vg = vector.extract( @@ -923,11 +966,15 @@ def write_row_to_lds( acc_up[acc_idx], static_position=[ii], dynamic_position=[] ) - if is_int8: - vg = arith.sitofp(f32, vg) - vu = arith.sitofp(f32, vu) - vg = vg * sx * sw_gate - vu = vu * sx * sw_up + if is_int8 or is_int4: + vg = arith.sitofp(T.f32, vg) + vu = arith.sitofp(T.f32, vu) + + if not no_epilogue_dequant: + sw_gate = sw_gate_vals[ni] + sw_up = sw_up_vals[ni] + vg = vg * sx * sw_gate + vu = vu * sx * sw_up y = silu(vg) * vu if doweight_stage1: @@ -950,6 +997,7 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): idx_out = idx0 + col_i32 # Vectorized fp16 store (EVec=4). buffer_ops.buffer_store(frag, out_rsrc, idx_out) + mfma_epilog( use_cshuffle=True, arith=arith, @@ -985,9 +1033,9 @@ def _stage1_store_row(*, mi: int, ii: int, row_in_tile, row): t2 = t2_raw s2 = s2_raw - sx0 = arith.f32(1.0) if is_f16 else buffer_ops.buffer_load(sx_rsrc, t2, vec_width=1, dtype=f32) + sx0 = arith.f32(1.0) if no_epilogue_dequant else buffer_ops.buffer_load(sx_rsrc, t2, vec_width=1, dtype=f32) sx = sx0 - zero_out = arith.constant(0.0, type=out_mlir()) + zero_out = arith.constant(0.0, type=_out_elem_type()) # out linear index base = ((t*topk + s)*inter_dim) (invariant across ni) idx0 = (t2 * topk_i32_v + s2) * inter_i32_local @@ -998,8 +1046,6 @@ def _stage1_store_row(*, mi: int, ii: int, row_in_tile, row): for ni in range_constexpr(num_acc_n): col_i32 = col_i32_list[ni] - sw_gate = sw_gate_vals[ni] - sw_up = sw_up_vals[ni] acc_idx = mi * num_acc_n + ni vg = vector.extract( @@ -1009,21 +1055,24 @@ def _stage1_store_row(*, mi: int, ii: int, row_in_tile, row): acc_up[acc_idx], static_position=[ii], dynamic_position=[] ) - if is_int8: + if is_int8 or is_int4: vg = arith.sitofp(f32, vg) vu = arith.sitofp(f32, vu) - vg = vg * sx * sw_gate - vu = vu * sx * sw_up + + if not no_epilogue_dequant: + sw_gate = sw_gate_vals[ni] + sw_up = sw_up_vals[ni] + vg = vg * sx * sw_gate + vu = vu * sx * sw_up y = silu(vg) * vu if doweight_stage1: y = y * tw - y = arith.trunc_f(out_mlir(), y) + y = arith.trunc_f(_out_elem_type(), y) idx_out0 = idx0 + col_i32 buffer_ops.buffer_store(y, out_rsrc, idx_out0) mfma_epilog( - use_cshuffle=False, arith=arith, range_constexpr=range_constexpr, @@ -1036,11 +1085,11 @@ def _stage1_store_row(*, mi: int, ii: int, row_in_tile, row): @flir.jit def __call__( self: flir.T.i64, - arg_out: lambda: T.memref(DYN, out_mlir()), - arg_x: lambda: T.memref(DYN, I.f16 if is_f16 else (I.i8 if is_int8 else I.f8)), - arg_w: lambda: T.memref(DYN, I.f16 if is_f16 else (I.i8 if is_int8 else I.f8)), - arg_scale_x: lambda: T.memref(DYN, T.f32()), - arg_scale_w: lambda: T.memref(experts * (2 * inter_dim), T.f32()), + arg_out: lambda: T.memref(DYN, _out_elem_type()), + arg_x: lambda: T.memref(DYN, _x_elem_type()), + arg_w: lambda: T.memref(DYN, _w_elem_type()), + arg_scale_x: lambda: T.memref(DYN, _scale_elem_type()), + arg_scale_w: lambda: T.memref(DYN, _scale_elem_type()), arg_sorted_token_ids: lambda: T.memref(DYN, T.i32()), arg_expert_ids: lambda: T.memref(DYN, T.i32()), arg_sorted_weights: lambda: T.memref(DYN, T.f32()), @@ -1092,7 +1141,8 @@ def compile_moe_gemm2( tile_n: int, tile_k: int, doweight_stage2: bool, - in_dtype: str = "fp8", + x_dtype: str = "fp8", + w_dtype: str = "fp8", out_dtype: str = "f16", use_cshuffle_epilog: bool | None = None, # Optional experiment: write per-(token,slot) output (no atomics) into an output shaped @@ -1102,11 +1152,16 @@ def compile_moe_gemm2( ): """Compile stage2 kernel (`moe_gemm2`) and return the compiled executable. - in_dtype: - - "fp8": A2/W are fp8 - - "fp16": A2/W are fp16 - - "int8": A2/W are int8 - - "int4": W4A8 path: A2 is int8, W is packed int4 unpacked to int8 in-kernel + x_dtype: + - "fp8": X/W are fp8 + - "fp16": X/W are fp16 + - "int8": X/W are int8 + - "int4": W4A8 path: X is int8, W is packed int4 unpacked to int8 in-kernel + w_dtype: + - "fp8": W is fp8 + - "fp16": W is fp16 + - "int8": W is int8 + - "int4": W4A8 path: W is packed int4 unpacked to int8 in-kernel Stage2 output supports: - out_dtype="f16": fp16 half2 atomics (fast, can overflow to +/-inf for bf16 workloads) @@ -1119,31 +1174,66 @@ def compile_moe_gemm2( allocator = SmemAllocator(None, arch=gpu_arch) _state = {} - if in_dtype not in ("fp8", "fp16", "int8", "int4"): - raise ValueError(f"in_dtype must be one of ('fp8','fp16','int8','int4'), got {in_dtype!r}") - is_f16 = in_dtype == "fp16" - elem_bytes = 2 if is_f16 else 1 - out_s = str(out_dtype).strip().lower() + + pipeline_manager = PreshufflePipelineManager(x_dtype, w_dtype, out_dtype) + pipeline_manager.check_type_valid() + + epilog_pipeline = pipeline_manager.get_epilog_pipeline() + mfma_pipeline = pipeline_manager.get_mfma_pipeline() + mfma_fn = pipeline_manager.get_mfma_fn() + + x_elem_bytes = pipeline_manager.get_a_elem_bytes() + w_elem_bytes = pipeline_manager.get_b_elem_bytes() + out_elem_bytes = pipeline_manager.get_out_elem_bytes() + + x_elem_pack = pipeline_manager.a_elem_pack + w_elem_pack = pipeline_manager.b_elem_pack + + tile_k_bytes = int(tile_k) * int(x_elem_bytes) + if (tile_k_bytes % 64) != 0: + raise ValueError( + f"tile_k_bytes must be divisible by 64, got tile_k_bytes={tile_k_bytes} " + f"(tile_k={tile_k}, x_elem_bytes={x_elem_bytes})" + ) + + def _get_mfma_dict_value(key, pipeline): + value = mfma_pipeline_dicts[key][pipeline] + return value() if callable(value) else value + + def _x_elem_type(): + return _get_mfma_dict_value("a_elem_type", mfma_pipeline) + def _w_elem_type(): + return _get_mfma_dict_value("b_elem_type", mfma_pipeline) + def _scale_elem_type(): + return _get_mfma_dict_value("scale_elem_type", mfma_pipeline) + def _out_elem_type(): + return _get_mfma_dict_value("out_elem_type", epilog_pipeline) + def _x_vec16_type(): + return _get_mfma_dict_value("x_vec16_type", mfma_pipeline) + def _w_vec16_type(): + return _get_mfma_dict_value("w_vec16_type", mfma_pipeline) + def _mfma_input_pack_ty(): + return _get_mfma_dict_value("mfma_input_pack_ty", mfma_pipeline) + def _mfma_output_pack_ty(): + return _get_mfma_dict_value("mfma_output_pack_ty", mfma_pipeline) + + is_f16_or_bf16 = mfma_pipeline in [MfmaPipeline.F16F16_16x16_PIPELINE, + MfmaPipeline.BF16BF16_16x16_PIPELINE] + is_int4 = mfma_pipeline in [MfmaPipeline.I8I4_16x16_PIPELINE] + is_int8 = mfma_pipeline in [MfmaPipeline.I8I8_16x16_PIPELINE] + is_fp4 = mfma_pipeline in [MfmaPipeline.F4F4_MXFP4_PIPELINE, + MfmaPipeline.F8F4_MXFP4_PIPELINE] + + no_epilogue_dequant = is_fp4 or is_f16_or_bf16 + + out_s = pipeline_manager.out_dtype if out_s not in ("f16", "fp16", "half", "bf16", "bfloat16", "f32", "fp32", "float"): raise ValueError(f"out_dtype must be 'f16', 'bf16', or 'f32', got {out_dtype!r}") out_is_f32 = out_s in ("f32", "fp32", "float") out_is_bf16 = out_s in ("bf16", "bfloat16") + if (not bool(accumulate)) and out_is_f32: raise ValueError("compile_moe_gemm2(accumulate=False) only supports out_dtype in {'f16','bf16'}") - is_int4 = in_dtype == "int4" - # INT4 here means W4A8: A2 is int8, W is packed int4 and unpacked to int8 in-kernel. - is_int8 = (in_dtype == "int8") or is_int4 - - mfma_i32_k32 = None - if is_int8: - mfma_i32_k32 = getattr(rocdl, "mfma_i32_16x16x32i8", None) or getattr( - rocdl, "mfma_i32_16x16x32_i8", None - ) - if mfma_i32_k32 is None: - raise AttributeError( - "INT8 K32 MFMA op not found: expected `rocdl.mfma_i32_16x16x32i8` " - "(or `rocdl.mfma_i32_16x16x32_i8`)." - ) DYN = ir.ShapedType.get_dynamic_size() size_out = DYN @@ -1152,20 +1242,21 @@ def compile_moe_gemm2( size_expert_ids_shape = DYN size_scale_x = DYN # W is packed int4 for W4A8: 2 values per byte. - size_w = (experts * model_dim * inter_dim) // 2 if is_int4 else (experts * model_dim * inter_dim) + size_w = (experts * model_dim * inter_dim) // w_elem_pack total_threads = 256 - tile_k_bytes = int(tile_k) * int(elem_bytes) + tile_k_bytes = int(tile_k) * int(x_elem_bytes) // x_elem_pack if (tile_k_bytes % 64) != 0: raise ValueError( f"tile_k_bytes must be divisible by 64, got tile_k_bytes={tile_k_bytes} " - f"(tile_k={tile_k}, elem_bytes={elem_bytes})" + f"(tile_k={tile_k}, elem_bytes={x_elem_bytes})" ) - bytes_x_per_tile = int(tile_m) * int(tile_k) * int(elem_bytes) + + bytes_x_per_tile = int(tile_m) * int(tile_k) * int(x_elem_bytes) // x_elem_pack if bytes_x_per_tile % total_threads != 0: raise ValueError( - "tile_m*tile_k*elem_bytes must be divisible by " - f"{total_threads}: tile_m={tile_m}, tile_k={tile_k}, elem_bytes={elem_bytes}" + "tile_m*tile_k*x_elem_bytes must be divisible by " + f"{total_threads}: tile_m={tile_m}, tile_k={tile_k}, x_elem_bytes={x_elem_bytes}" ) bytes_per_thread_x = bytes_x_per_tile // total_threads @@ -1198,7 +1289,6 @@ def compile_moe_gemm2( ) # NOTE: Keep this as a callable so we don't require an MLIR Context at Python-time. - out_elem = (T.f32 if out_is_f32 else (T.bf16 if out_is_bf16 else T.f16)) epilog_tag = "cshuffle" # IMPORTANT: include tiling in the module name to avoid accidentally reusing a compiled # binary for a different (tile_m, tile_n, tile_k) configuration. @@ -1207,7 +1297,7 @@ def compile_moe_gemm2( # Dynamic-shape variant: safe to reuse across (tokens/sorted_size/size_expert_ids) at runtime. # Keep a distinct ABI tag so the compile cache never mixes with historical signatures. module_name = ( - f"mfma_moe2_{in_dtype}_{out_s}_{epilog_tag}" + f"mfma_moe2_{x_dtype}_{w_dtype}_{out_s}_{epilog_tag}" f"_t{tile_m}x{tile_n}x{tile_k}" ).replace("-", "_") @@ -1223,22 +1313,22 @@ def init_gpu_module(self): # - epilogue CShuffle tile (tile_m * tile_n f16 -> 2 * tile_m * tile_n bytes) # # This reduces LDS usage from sum(...) to max(...). - lds_x_bytes = 2 * int(tile_m) * int(lds_stride) * int(elem_bytes) + lds_x_bytes = 2 * int(tile_m) * int(lds_stride) * int(x_elem_bytes) lds_out_bytes = 2 * tile_m * tile_n if _use_cshuffle_epilog else 0 # f16 bytes lds_total_bytes = max(lds_x_bytes, lds_out_bytes) - lds_total_elems = lds_total_bytes if elem_bytes == 1 else (lds_total_bytes // 2) - x_lds_elem = I.f16 if is_f16 else (I.i8 if is_int8 else I.f8) + lds_total_elems = lds_total_bytes if x_elem_bytes == 1 else (lds_total_bytes // 2) + x_lds_elem = I.f16 if is_f16_or_bf16 else (I.i8 if is_int8 else I.f8) _state["lds_x_decl"] = allocator.allocate_array(x_lds_elem, lds_total_elems) allocator.finalize() @flir.kernel def moe_gemm2( self: flir.T.i64, - arg_out: lambda: T.memref(size_out, out_elem()), - arg_x: lambda: T.memref(size_x, I.f16 if is_f16 else (I.i8 if is_int8 else I.f8)), - arg_w: lambda: T.memref(size_w, I.f16 if is_f16 else (I.i8 if is_int8 else I.f8)), - arg_scale_x: lambda: T.memref(size_scale_x, T.f32()), - arg_scale_w: lambda: T.memref(experts * model_dim, T.f32()), + arg_out: lambda: T.memref(size_out, _out_elem_type()), + arg_x: lambda: T.memref(size_x, _x_elem_type()), + arg_w: lambda: T.memref(size_w, _w_elem_type()), + arg_scale_x: lambda: T.memref(size_scale_x, _scale_elem_type()), + arg_scale_w: lambda: T.memref(DYN, _scale_elem_type()), arg_sorted_token_ids: lambda: T.memref(size_sorted, T.i32()), arg_expert_ids: lambda: T.memref(size_expert_ids_shape, T.i32()), arg_sorted_weights: lambda: T.memref(size_sorted, T.f32()), @@ -1248,9 +1338,9 @@ def moe_gemm2( k_in: lambda: T.index(), size_expert_ids_in: lambda: T.index(), ): - x_elem = I.f16 if is_f16 else (I.i8 if is_int8 else I.f8) + x_elem = _x_elem_type() # For int4, weights are stored as packed bytes (i8) and unpacked to i8 packs. - w_elem = I.f16 if is_f16 else (I.i8 if is_int8 else I.f8) + w_elem = _w_elem_type() f16 = I.f16 f32 = I.f32 i32 = I.i32 @@ -1260,11 +1350,11 @@ def moe_gemm2( vec1_f16 = I.vec(1, f16) vec2_f16 = I.vec(2, f16) vec4_f16 = I.vec(4, f16) - vec16_elems = 16 if elem_bytes == 1 else 8 - vec8_elems = 8 if elem_bytes == 1 else 4 - vec4_elems = 4 if elem_bytes == 1 else 2 - vec8_x = I.vec(vec8_elems, x_elem) - vec16_x = I.vec(vec16_elems, x_elem) + vec16_x_elems = 16 if x_elem_bytes == 1 else 8 + vec8_x_elems = 8 if x_elem_bytes == 1 else 4 + vec4_x_elems = 4 if x_elem_bytes == 1 else 2 + vec8_x = I.vec(vec8_x_elems, x_elem) + vec16_x = I.vec(vec16_x_elems, x_elem) vec1_i64 = I.vec(1, i64) vec2_i64 = I.vec(2, i64) @@ -1283,10 +1373,9 @@ def moe_gemm2( c_n_total = arith.constant(experts * model_dim, index=True) kpack_bytes = 8 if is_int4 else 16 b_layout = make_preshuffle_b_layout( - flir, arith, c_n=c_n_total, c_k=k_in, kpack_bytes=kpack_bytes, elem_bytes=elem_bytes + flir, arith, c_n=c_n_total, c_k=k_in, kpack_bytes=kpack_bytes, elem_bytes=w_elem_bytes ) layout_b = b_layout.layout_b - c_k0 = (k_in * arith.constant(int(elem_bytes), index=True)) / arith.index(64) shape_lds = flir.make_shape(tile_m, tile_k) stride_lds = flir.make_stride(lds_stride, 1) @@ -1301,12 +1390,12 @@ def moe_gemm2( # XOR16 swizzle parameter (in bytes; constant, power-of-two in our configs). k_blocks16 = arith.constant(tile_k_bytes // 16, index=True) - atom_x_s16 = flir.make_copy_atom(x_elem, vector_size=vec16_elems) + atom_x_s16 = flir.make_copy_atom(x_elem, vector_size=vec16_x_elems) atom_x_s8 = flir.make_copy_atom(x_elem, vector_size=8) atom_x_s4 = flir.make_copy_atom(x_elem, vector_size=4) - atom_x_g2r16 = flir.make_copy_atom(x_elem, vector_size=vec16_elems) - atom_x_g2r8 = flir.make_copy_atom(x_elem, vector_size=vec8_elems) - atom_x_g2r4 = flir.make_copy_atom(x_elem, vector_size=vec4_elems) + atom_x_g2r16 = flir.make_copy_atom(x_elem, vector_size=vec16_x_elems) + atom_x_g2r8 = flir.make_copy_atom(x_elem, vector_size=vec8_x_elems) + atom_x_g2r4 = flir.make_copy_atom(x_elem, vector_size=vec4_x_elems) layout_tx_wave_lane = flir.make_layout((4, 64), stride=(64, 1)) layout_lane16 = flir.make_layout((4, 16), stride=(16, 1)) layout_lin_rowcol = flir.make_layout((tile_m, tile_k), stride=(tile_k, 1)) @@ -1332,7 +1421,7 @@ def moe_gemm2( c_topk = arith.constant(topk, index=True) # X(A2): [tokens*topk, inter_dim] bytes = tokens*topk*k*elem_bytes - x_nbytes_idx = (tokens_in * c_topk) * k_in * arith.constant(int(elem_bytes), index=True) + x_nbytes_idx = (tokens_in * c_topk) * k_in * arith.constant(int(x_elem_bytes), index=True) x_nbytes_i32 = arith.index_cast(i32, x_nbytes_idx) x_rsrc = buffer_ops.create_buffer_resource( arg_x, max_size=False, num_records_bytes=x_nbytes_i32 @@ -1355,7 +1444,7 @@ def moe_gemm2( arg_out, max_size=False, num_records_bytes=out_nbytes_i32 ) # fp16 path ignores scales completely (implicit scale=1.0). - if is_f16: + if is_f16_or_bf16: sx_rsrc = None sw_rsrc = None else: @@ -1411,7 +1500,7 @@ def _moe_gemm2_then_body(): # ---- X gmem->reg prefetch (match preshuffle GEMM mapping) ---- # Prefer 16B buffer-load (dwordx4). If the per-thread byte count isn't divisible by # 16, fall back to 8B (dwordx2) or 4B (dword) loads. For fp16 we require 16B. - if is_f16: + if is_f16_or_bf16: if bytes_per_thread_x % 16 != 0: raise ValueError( f"[fp16] bytes_per_thread_x ({bytes_per_thread_x}) must be divisible by 16" @@ -1432,9 +1521,9 @@ def _moe_gemm2_then_body(): chunk_i32 = x_load_bytes // 4 # dwords per chunk (1/2/4) vec4_i32 = I.vec(4, i32) - c_k_div4 = (k_in * arith.constant(int(elem_bytes), index=True)) / arith.index(4) + c_k_div4 = (k_in * arith.constant(int(x_elem_bytes), index=True)) / arith.index(4) layout_x_div4 = flir.make_layout((m_in, c_k_div4), stride=(c_k_div4, 1)) - tile_k_dwords = (int(tile_k) * int(elem_bytes)) // 4 + tile_k_dwords = (int(tile_k) * int(x_elem_bytes)) // 4 layout_x_tile_div4 = flir.make_layout((tile_m, tile_k_dwords), stride=(tile_k_dwords, 1)) c_chunk_i32 = arith.constant(chunk_i32, index=True) tx_i32_base = tx * c_chunk_i32 @@ -1461,7 +1550,7 @@ def x_tile_chunk_coord_i32(i: int): def load_x(idx_i32): if x_load_bytes == 16: - idx_elem = idx_i32 if elem_bytes == 1 else (idx_i32 * arith.index(2)) + idx_elem = idx_i32 if x_elem_bytes == 1 else (idx_i32 * arith.index(2)) return buffer_copy_gmem16_dwordx4( flir, arg=arg_x, @@ -1469,7 +1558,7 @@ def load_x(idx_i32): idx_i32=idx_elem, atom_g2r16=atom_x_g2r16, rsrc=x_rsrc, - vec_elems=vec16_elems, + vec_elems=vec16_x_elems, ) idx_bytes = idx_i32 * arith.index(4) atom = atom_x_g2r8 if x_load_bytes == 8 else atom_x_g2r4 @@ -1510,7 +1599,7 @@ def load_x(idx_i32): x_row_base_div4.append(row_ts_idx * c_k_div4) def load_x_tile(base_k): - base_k_div4 = (base_k * arith.constant(int(elem_bytes), index=True)) / arith.index(4) + base_k_div4 = (base_k * arith.constant(int(x_elem_bytes), index=True)) / arith.index(4) parts = [] for i in range_constexpr(num_x_loads): idx_i32 = x_row_base_div4[i] + base_k_div4 + x_col_local_i32[i] @@ -1535,8 +1624,8 @@ def load_x_tile(base_k): col_offset_base = flir.crd2idx(flir.make_coord(lane_div_16, 0), layout_lane16) col_offset_base_bytes = ( col_offset_base - if elem_bytes == 1 - else (col_offset_base * arith.constant(int(elem_bytes), index=True)) + if x_elem_bytes == 1 + else (col_offset_base * arith.constant(int(x_elem_bytes), index=True)) ) # Dynamic N tiling within block. @@ -1584,7 +1673,7 @@ def load_b_pack(base_k, ki_step, ni): lane_div_16=lane_div_16, # 0..3 elem_type=w_elem, kpack_bytes=kpack_bytes, - elem_bytes=elem_bytes, + elem_bytes=w_elem_bytes, unpack_int4=is_int4, ) @@ -1629,7 +1718,7 @@ def store_x_tile_to_lds(vec_x_in_parts, lds_base): k_blocks16=k_blocks16, lds_base=lds_base, vec_part_i32x4=vec_x_in_parts[i], - elem_bytes=elem_bytes, + elem_bytes=x_elem_bytes, ) elif x_load_bytes == 8: lds_store_8b_xor16( @@ -1671,8 +1760,8 @@ def lds_load_packs_k64(curr_row_a_lds, col_base_bytes, lds_base): col_base_swz_bytes = flir.swizzle_xor16(curr_row_a_lds, col_base_bytes, k_blocks16) col_base_swz = ( col_base_swz_bytes - if elem_bytes == 1 - else (col_base_swz_bytes / arith.constant(int(elem_bytes), index=True)) + if x_elem_bytes == 1 + else (col_base_swz_bytes / arith.constant(int(x_elem_bytes), index=True)) ) coord_a16 = flir.make_coord(curr_row_a_lds, col_base_swz) idx_a16 = flir.crd2idx(coord_a16, layout_lds) @@ -1685,25 +1774,17 @@ def lds_load_packs_k64(curr_row_a_lds, col_base_bytes, lds_base): def compute_tile(acc_in, b_tile_in, lds_base, *, prefetch_epilogue: bool = False, a0_prefetch=None): acc_list = list(acc_in) - mfma_res_ty = vec4_i32 if is_int8 else vec4_f32 - mfma_fn = ( - mfma_i32_k32 - if is_int8 - else (rocdl.mfma_f32_16x16x16f16 if is_f16 else rocdl.mfma_f32_16x16x32_fp8_fp8) - ) - + mfma_res_ty = _mfma_output_pack_ty() + epilogue_pf = None if prefetch_epilogue: expert_off_pf = expert_off_idx sw_pf = [] - for ni in range_constexpr(num_acc_n): - col_g = col_g_list[ni] - row_w_idx = expert_off_pf + col_g - sw_pf.append( - arith.f32(1.0) - if is_f16 - else buffer_ops.buffer_load(sw_rsrc, row_w_idx, vec_width=1, dtype=f32) - ) + if not no_epilogue_dequant: + for ni in range_constexpr(num_acc_n): + col_g = col_g_list[ni] + row_w_idx = expert_off_pf + col_g + sw_pf.append(buffer_ops.buffer_load(sw_rsrc, row_w_idx, vec_width=1, dtype=f32)) # Also prefetch per-row routed/topk weights (sorted_weights) when enabled. tw_pf = None if doweight_stage2: @@ -1728,7 +1809,7 @@ def _i64_to_v4f16(x_i64): return vector.bitcast(vec4_f16, v1) def mfma_k64(acc0, a0, a1, b0, b1): - if is_f16: + if is_f16_or_bf16: a0v = _i64_to_v4f16(a0) a1v = _i64_to_v4f16(a1) b0v = _i64_to_v4f16(b0) @@ -1970,16 +2051,12 @@ def atomic_add_f16x2(val_f16x2, byte_off_i32): # Weight scales for the N tile (col_g depends on lane/wave/by but not on (t,s)). if sw_pf is not None: sw_vals = sw_pf - else: + elif not no_epilogue_dequant: sw_vals = [] for ni in range_constexpr(num_acc_n): col_g = col_g_list[ni] row_w_idx = expert_off + col_g - sw_vals.append( - arith.f32(1.0) - if is_f16 - else buffer_ops.buffer_load(sw_rsrc, row_w_idx, vec_width=1, dtype=f32) - ) + sw_vals.append(buffer_ops.buffer_load(sw_rsrc, row_w_idx, vec_width=1, dtype=f32)) if out_is_f32: # origin/dev_a16w4: f32 output uses scalar f32 atomics and skips CShuffle/LDS. @@ -2000,7 +2077,7 @@ def _stage2_row_atomic(*, mi: int, ii: int, row_in_tile, row): s2 = fused2 >> 24 ts2 = t2 * topk_i32_v + s2 - sx = arith.f32(1.0) if is_f16 else buffer_ops.buffer_load(sx_rsrc, ts2, vec_width=1, dtype=f32) + sx = arith.f32(1.0) if no_epilogue_dequant else buffer_ops.buffer_load(sx_rsrc, ts2, vec_width=1, dtype=f32) if doweight_stage2: tw_idx = (mi * 4) + ii @@ -2013,12 +2090,14 @@ def _stage2_row_atomic(*, mi: int, ii: int, row_in_tile, row): for ni in range_constexpr(num_acc_n): col_g = col_g_list[ni] - sw = sw_vals[ni] acc_idx = mi * num_acc_n + ni v = vector.extract(acc[acc_idx], static_position=[ii], dynamic_position=[]) - if is_int8: + if is_int8 or is_int4: v = arith.sitofp(f32, v) - v = v * sx * sw + if not no_epilogue_dequant: + sw = sw_vals[ni] + v = v * sx * sw + if doweight_stage2: v = v * tw col_i32 = arith.index_cast(i32, col_g) @@ -2063,7 +2142,7 @@ def write_row_to_lds( t2 = fused2 & mask24_i32 s2 = fused2 >> 24 ts2 = t2 * topk_i32_v + s2 - sx = arith.f32(1.0) if is_f16 else buffer_ops.buffer_load(sx_rsrc, ts2, vec_width=1, dtype=f32) + sx = arith.f32(1.0) if no_epilogue_dequant else buffer_ops.buffer_load(sx_rsrc, ts2, vec_width=1, dtype=f32) if doweight_stage2: tw_idx = (mi * 4) + ii @@ -2076,18 +2155,19 @@ def write_row_to_lds( for ni in range_constexpr(num_acc_n): col_local = col_base_local + (ni * 16) - sw = sw_vals[ni] acc_idx = mi * num_acc_n + ni v = vector.extract(acc[acc_idx], static_position=[ii], dynamic_position=[]) if is_int8: v = arith.sitofp(f32, v) - v = v * sx * sw + if not no_epilogue_dequant: + sw = sw_vals[ni] + v = v * sx * sw if doweight_stage2: v = v * tw - v_out = arith.trunc_f(out_elem(), v) + v_out = arith.trunc_f(_out_elem_type(), v) lds_idx = row_base_lds + col_local - vec1_out = I.vec(1, out_elem()) + vec1_out = I.vec(1, _out_elem_type()) v1 = vector.from_elements(vec1_out, [v_out]) vector.store(v1, lds_out, [lds_idx], alignment=2) @@ -2157,7 +2237,7 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): by_n=by_n, n_tile_base=n_tile_base, lds_out=lds_out, - frag_elem_type=(ir.BF16Type.get() if out_is_bf16 else ir.F16Type.get()), + frag_elem_type=_out_elem_type(), write_row_to_lds=write_row_to_lds, precompute_row=precompute_row, store_pair=store_pair, @@ -2169,11 +2249,11 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): @flir.jit def __call__( self: flir.T.i64, - arg_out: lambda: T.memref(size_out, out_elem()), - arg_x: lambda: T.memref(size_x, I.f16 if is_f16 else (I.i8 if is_int8 else I.f8)), - arg_w: lambda: T.memref(size_w, I.f16 if is_f16 else (I.i8 if is_int8 else I.f8)), - arg_scale_x: lambda: T.memref(size_scale_x, T.f32()), - arg_scale_w: lambda: T.memref(experts * model_dim, T.f32()), + arg_out: lambda: T.memref(size_out, _out_elem_type()), + arg_x: lambda: T.memref(size_x, _x_elem_type()), + arg_w: lambda: T.memref(size_w, _w_elem_type()), + arg_scale_x: lambda: T.memref(size_scale_x, _scale_elem_type()), + arg_scale_w: lambda: T.memref(DYN, _scale_elem_type()), arg_sorted_token_ids: lambda: T.memref(size_sorted, T.i32()), arg_expert_ids: lambda: T.memref(size_expert_ids_shape, T.i32()), arg_sorted_weights: lambda: T.memref(size_sorted, T.f32()), diff --git a/kernels/preshuffle_gemm.py b/kernels/preshuffle_gemm.py index df715c5a..182c6694 100644 --- a/kernels/preshuffle_gemm.py +++ b/kernels/preshuffle_gemm.py @@ -148,7 +148,6 @@ def _mfma_output_pack_ty(): is_int8 = mfma_pipeline in [MfmaPipeline.I8I8_16x16_PIPELINE] is_fp4 = mfma_pipeline in [MfmaPipeline.F4F4_MXFP4_PIPELINE, MfmaPipeline.F8F4_MXFP4_PIPELINE] - print(is_fp4) no_epilogue_dequant = is_fp4 or is_f16_or_bf16 diff --git a/tests/kernels/test_moe_gemm.py b/tests/kernels/test_moe_gemm.py index 0b3f6dfd..b0bafe12 100644 --- a/tests/kernels/test_moe_gemm.py +++ b/tests/kernels/test_moe_gemm.py @@ -285,7 +285,9 @@ def run_moe_stage1( tile_k: int, doweight_stage1: bool, *, - in_dtype: str = "fp8", + x_dtype: str = "fp8", + w_dtype: str = "fp8", + out_dtype: str = "f16", seed: int = 0, num_iters: int = 5, num_warmup: int = 2, @@ -357,36 +359,40 @@ def run_moe_stage1( blocks, ) = routing - if in_dtype not in ("fp8", "fp16", "int8", "int4"): - raise ValueError(f"in_dtype must be one of ('fp8','fp16','int8','int4'), got {in_dtype!r}") - is_int4 = in_dtype == "int4" - is_int8 = in_dtype in ("int8", "int4") + if x_dtype not in ("fp8", "fp16", "int8"): + raise ValueError(f"x_dtype must be one of ('fp8','fp16','int8','int4'), got {x_dtype!r}") + if w_dtype not in ("fp8", "fp16", "int8", "int4"): + raise ValueError(f"w_dtype must be one of ('fp8','fp16','int8','int4'), got {w_dtype!r}") + is_int4 = w_dtype == "int4" + is_int8 = x_dtype in ("int8", "int4") # Quantize inputs / weights. - if in_dtype == "fp8": + if x_dtype == "fp8" and w_dtype == "fp8": x_q, scale_x = pertoken_quant(x_fp32, quant_dtype=DTYPE_FP8) # [tokens,K], [tokens,1] w1_q, scale_w1 = pertoken_quant(w1_fp32, quant_dtype=DTYPE_FP8) # [E,2*inter,K], [E,2*inter,1] # w2 is not used by our kernel, but required by CK stage1 API w2_q, _scale_w2_unused = pertoken_quant(w2_fp32, quant_dtype=DTYPE_FP8) - elif in_dtype == "fp16": + elif x_dtype == "fp16" and w_dtype == "fp16": x_q = x_fp32.to(torch.float16) w1_q = w1_fp32.to(torch.float16) w2_q = w2_fp32.to(torch.float16) scale_x = None scale_w1 = None - elif in_dtype == "int8": + elif x_dtype == "int8" and w_dtype == "int8": x_q, scale_x = pertoken_quant(x_fp32, quant_dtype=torch.int8) w1_q, scale_w1 = pertoken_quant(w1_fp32, quant_dtype=torch.int8) w2_q, _scale_w2_unused = pertoken_quant(w2_fp32, quant_dtype=torch.int8) - else: + elif x_dtype == "int8" and w_dtype == "int4": # W4A8: X is int8, W is int4 packed (host packs from int8 values in [-8,7]). x_q, scale_x = pertoken_quant(x_fp32, quant_dtype=torch.int8) w1_q, scale_w1 = pertoken_quant(w1_fp32, quant_dtype=torch.int8, dtypeMax=7) w2_q, _scale_w2_unused = pertoken_quant(w2_fp32, quant_dtype=torch.int8, dtypeMax=7) + else: + raise ValueError(f"Invalid combination of x_dtype and w_dtype: {x_dtype!r}, {w_dtype!r}") # Preshuffle weights (aiter/CK layout) on the *unpacked* tensor. w1_shuffled = shuffle_weight(w1_q) - w2_shuffled = shuffle_weight(w2_q) if in_dtype == "fp8" else None + w2_shuffled = shuffle_weight(w2_q) if w_dtype == "fp8" else None # Flatten W1 for our flir kernel (treat expert dim as part of N). w1_shuffled_flat = w1_shuffled.view(experts * (2 * inter_dim), model_dim) @@ -420,7 +426,9 @@ def run_moe_stage1( inter_dim=inter_dim, experts=experts, topk=topk, - in_dtype=in_dtype, + x_dtype=x_dtype, + w_dtype=w_dtype, + out_dtype=out_dtype, tile_m=tile_m, tile_n=tile_n, tile_k=tile_k, @@ -493,7 +501,7 @@ def launch(o, x, w, sx, sw, st, eids, sw_sorted): tbps = bytes_moved / 1e12 / (us / 1e6) print( - f"FLIR MoE stage1[{in_dtype}]: " + f"FLIR MoE stage1[{x_dtype} | {w_dtype} -> {out_dtype}]: " f"{us:.1f} us, " f"{tflops:.2f} TFLOPS(logical, M={tokens*topk}), " f"{tbps:.3f} TB/s (doweight_stage1={doweight_stage1})" @@ -504,7 +512,7 @@ def launch(o, x, w, sx, sw, st, eids, sw_sorted): else: compare_ck = bool(compare_aiter_ck) # aiter CK paths are fp8-only in our setup. - compare_ck = compare_ck and (in_dtype == "fp8") + compare_ck = compare_ck and (x_dtype == "fp8" and w_dtype == "fp8") if compare_ck: if not HAS_AITER: pytest.skip("aiter not available; cannot compare to CK moe stage1.", allow_module_level=False) @@ -586,7 +594,8 @@ def run_moe_stage2( tile_k: int, doweight_stage1: bool, *, - in_dtype: str = "fp8", + x_dtype: str = "fp8", + w_dtype: str = "fp8", # Stage2 output is fp16 (half2 atomics + CShuffle). The legacy f32-atomic path was removed. out_dtype: str = "f16", seed: int = 0, @@ -704,39 +713,41 @@ def run_moe_stage2( # NOTE: routing uses `moe_sorting` output directly (no host trim/pad). Extra launched blocks # are gated by `num_valid_ids` inside the kernels. - if in_dtype not in ("fp8", "fp16", "int8", "int4"): - raise ValueError(f"in_dtype must be one of ('fp8','fp16','int8','int4'), got {in_dtype!r}") - is_int4 = in_dtype == "int4" - is_int8 = in_dtype in ("int8", "int4") + if x_dtype not in ("fp8", "fp16", "int8", "int4"): + raise ValueError(f"x_dtype must be one of ('fp8','fp16','int8','int4'), got {x_dtype!r}") + is_int4 = w_dtype == "int4" + is_int8 = x_dtype in ("int8", "int4") # Quantize inputs / weights. - if in_dtype == "fp8": + if x_dtype == "fp8" and w_dtype == "fp8": x_q, scale_x = pertoken_quant(x_fp32, quant_dtype=DTYPE_FP8) w1_q, scale_w1 = pertoken_quant(w1_fp32, quant_dtype=DTYPE_FP8) w2_q, scale_w2 = pertoken_quant(w2_fp32, quant_dtype=DTYPE_FP8) - elif in_dtype == "fp16": + elif x_dtype == "fp16" and w_dtype == "fp16": x_q = x_fp32.to(torch.float16) w1_q = w1_fp32.to(torch.float16) w2_q = w2_fp32.to(torch.float16) scale_x = None scale_w1 = None scale_w2 = None - elif in_dtype == "int8": + elif x_dtype == "int8" and w_dtype == "int8": x_q, scale_x = pertoken_quant(x_fp32, quant_dtype=torch.int8) w1_q, scale_w1 = pertoken_quant(w1_fp32, quant_dtype=torch.int8) w2_q, scale_w2 = pertoken_quant(w2_fp32, quant_dtype=torch.int8) - else: + elif x_dtype == "int8" and w_dtype == "int4": # W4A8: A2 is int8, W2 is int4 packed (host packs from int8 values in [-8,7]). x_q, scale_x = pertoken_quant(x_fp32, quant_dtype=torch.int8) w1_q, scale_w1 = pertoken_quant(w1_fp32, quant_dtype=torch.int8, dtypeMax=7) w2_q, scale_w2 = pertoken_quant(w2_fp32, quant_dtype=torch.int8, dtypeMax=7) + else: + raise ValueError(f"Invalid combination of x_dtype and w_dtype: {x_dtype!r}, {w_dtype!r}") # Preshuffle weights (aiter/CK layout) on the *unpacked* tensor. w1_shuffled = shuffle_weight(w1_q) w2_shuffled = shuffle_weight(w2_q) # Stage2 input (A2): either provided (gemm1->quantize chaining) or built from stage1 reference. - if a2_fp8_in is not None and (a2_scale_in is not None or in_dtype == "fp16"): + if a2_fp8_in is not None and (a2_scale_in is not None or (x_dtype == "fp16" and w_dtype == "fp16")): a2_q = a2_fp8_in a2_scale = a2_scale_in else: @@ -758,9 +769,9 @@ def run_moe_stage2( inter_dim=inter_dim, doweight_stage1=bool(doweight_stage1), ) # [tokens, topk, inter] fp32 - if in_dtype == "fp8": + if x_dtype == "fp8" and w_dtype == "fp8": a2_q, a2_scale = pertoken_quant(out1_ref, quant_dtype=DTYPE_FP8) - elif in_dtype == "fp16": + elif x_dtype == "fp16" and w_dtype == "fp16": a2_q = out1_ref.to(torch.float16) a2_scale = None else: @@ -805,7 +816,9 @@ def run_moe_stage2( inter_dim=inter_dim, experts=experts, topk=topk, - in_dtype=in_dtype, + x_dtype=x_dtype, + w_dtype=w_dtype, + out_dtype=out_dtype, tile_m=tile_m, tile_n=tile_n, tile_k=tile_k, @@ -889,7 +902,7 @@ def launch(o, x, w, sx, sw, st, eids, sw_sorted): bytes_moved += int(sorted_expert_ids.numel()) * 4 tbps = bytes_moved / 1e12 / (us / 1e6) print( - f"FLIR MoE stage2 [{kernel_name}] {in_dtype} | " + f"FLIR MoE stage2 [{kernel_name}] {x_dtype} | {w_dtype} -> {out_dtype} | " f"{model_dim}x{inter_dim}, E={experts}, K={topk}, M_eff={tokens*topk} | " f"{us:.1f} us, {tflops:.2f} TFLOPS, {tbps:.3f} TB/s" ) @@ -899,7 +912,7 @@ def launch(o, x, w, sx, sw, st, eids, sw_sorted): else: compare_ck = bool(compare_aiter_ck) # aiter CK paths are fp8-only in our setup. - compare_ck = compare_ck and (in_dtype == "fp8") + compare_ck = compare_ck and (x_dtype == "fp8" and w_dtype == "fp8") if compare_ck: if not HAS_AITER: pytest.skip("aiter not available; cannot compare to CK moe stage2.", allow_module_level=False) @@ -987,15 +1000,16 @@ def launch_ck(o, a2_, w1_, w2_, sorted_ids_, sorted_eids_, num_valid_, w2_scale_ @pytest.mark.parametrize( "tokens, model_dim, inter_dim, experts, topk, tile_m, tile_n1, tile_k1, tile_n2, tile_k2, doweight_stage1", [ - # Small smoke (fast compile + run) for all in_dtype. + # Small smoke (fast compile + run) for all x_dtype and w_dtype. pytest.param(64, 256, 128, 4, 2, 32, 64, 128, 64, 128, False, id="S"), - # Medium (more realistic) for all in_dtype (skip_ref will auto-enable). + # Medium (more realistic) for all x_dtype and w_dtype (skip_ref will auto-enable). pytest.param(128, 1024, 256, 8, 2, 64, 128, 128, 128, 128, False, id="M"), # Large (aiter-style) mainly for perf smoke; reference is too expensive here. pytest.param(256, 4096, 2048, 17, 9, 64, 128, 128, 256, 128, False, id="L"), ], ) -@pytest.mark.parametrize("in_dtype", ["fp8", "fp16", "int8", "int4"]) +@pytest.mark.parametrize("x_dtype", ["fp8", "fp16", "int8"]) +@pytest.mark.parametrize("w_dtype", ["fp8", "fp16", "int8", "int4"]) @pytest.mark.parametrize("use_reduce", [False, True], ids=["atomic", "reduce"]) def test_moe_gemm_2stage( tokens: int, @@ -1009,7 +1023,9 @@ def test_moe_gemm_2stage( tile_n2: int, tile_k2: int, doweight_stage1: bool, - in_dtype: str, + x_dtype: str, + w_dtype: str, + out_dtype: str, use_reduce: bool, *, seed: int = 0, @@ -1062,7 +1078,9 @@ def test_moe_gemm_2stage( inter_dim=inter_dim, experts=experts, topk=topk, - in_dtype=in_dtype, + x_dtype=x_dtype, + w_dtype=w_dtype, + out_dtype=out_dtype, tile_m=tile_m, tile_n=tile_n1, tile_k=tile_k1, @@ -1070,7 +1088,7 @@ def test_moe_gemm_2stage( seed=seed, num_iters=num_iters, num_warmup=num_warmup, - compare_aiter_ck=bool(compare_aiter_ck) and (in_dtype == "fp8"), + compare_aiter_ck=bool(compare_aiter_ck) and (x_dtype == "fp8" and w_dtype == "fp8"), moe_sort_mode=moe_sort_mode, x_fp32_in=x_fp32, w1_fp32_in=w1_fp32, @@ -1082,10 +1100,10 @@ def test_moe_gemm_2stage( skip_ref=bool(skip_ref), ) - if in_dtype == "fp8": + if x_dtype == "fp8" and w_dtype == "fp8": out1_fp32 = out1_fp16.to(torch.float32) a2_q, a2_scale = pertoken_quant(out1_fp32, quant_dtype=DTYPE_FP8) - elif in_dtype == "fp16": + elif x_dtype == "fp16" and w_dtype == "fp16": a2_q = out1_fp16 a2_scale = None else: @@ -1098,7 +1116,9 @@ def test_moe_gemm_2stage( inter_dim=inter_dim, experts=experts, topk=topk, - in_dtype=in_dtype, + x_dtype=x_dtype, + w_dtype=w_dtype, + out_dtype=out_dtype, tile_m=tile_m, tile_n=tile_n2, tile_k=tile_k2, @@ -1141,7 +1161,8 @@ def _compile( tile_n: int, tile_k: int, doweight_stage2: bool, - in_dtype: str = "fp8", + x_dtype: str, + w_dtype: str, out_dtype: str = "f16", ): if use_flydsl_reduce: @@ -1155,7 +1176,8 @@ def _compile( tile_n=tile_n, tile_k=tile_k, doweight_stage2=doweight_stage2, - in_dtype=in_dtype, + x_dtype=x_dtype, + w_dtype=w_dtype, out_dtype=out_dtype, mode=MoeGemm2Mode.REDUCE, ) @@ -1170,7 +1192,8 @@ def _compile( tile_n=tile_n, tile_k=tile_k, doweight_stage2=doweight_stage2, - in_dtype=in_dtype, + x_dtype=x_dtype, + w_dtype=w_dtype, out_dtype=out_dtype, accumulate=False, ) @@ -1367,14 +1390,16 @@ def test_moe_reduce_kernel(tokens: int, topk: int, model_dim: int): pytest.param(256, 5120, 1536, 64, 6, id="EP-K6-decode-L"), ], ) -@pytest.mark.parametrize("in_dtype", ["fp8"]) +@pytest.mark.parametrize("x_dtype", ["fp8"]) +@pytest.mark.parametrize("w_dtype", ["fp8"]) def test_moe_stage2_standalone( tokens: int, model_dim: int, inter_dim: int, experts: int, topk: int, - in_dtype: str, + x_dtype: str, + w_dtype: str, *, tile_m: int = 64, # Common block size for M tile_n: int = 256, # Common block size for N2 @@ -1401,7 +1426,8 @@ def test_moe_stage2_standalone( tile_n=tile_n, tile_k=tile_k, doweight_stage1=False, # Apply weight in stage2 - in_dtype=in_dtype, + x_dtype=x_dtype, + w_dtype=w_dtype, seed=seed, num_iters=num_iters, num_warmup=num_warmup, @@ -1456,12 +1482,25 @@ def _str2tuple_dim(v: str) -> Tuple[int, int]: description="MoE 2-stage (FLIR MFMA FP8) test/benchmark (argparse subset aligned with aiter test_moe_2stage.py)", ) parser.add_argument( - "--in_dtype", + "--x_dtype", + type=str, + default="fp8", + choices=["fp8", "fp16", "int8", "int4"], + help="Kernel input dtype: fp8 / fp16 / int8 / int4.", + ) + parser.add_argument( + "--w_dtype", type=str, default="fp8", - choices=["fp8", "fp16", "int8", "int4", "all"], - help="Kernel input dtype: fp8 / int8 / int4 / all (default: all). " - "int4 means W4A8: A int8, W packed int4.", + choices=["fp8", "fp16", "int8", "int4"], + help="Kernel weight dtype: fp8 / fp16 / int8 / int4.", + ) + parser.add_argument( + "--out_dtype", + type=str, + default="f16", + choices=["f16", "bf16", "f32"], + help="Stage2 output dtype: f16 / f32.", ) parser.add_argument("-d", "--dtype", type=str, default="fp32", choices=["fp32", "fp16", "bf16"], help="Input init dtype (currently data is quantized to FP8 per-token; init dtype mainly affects RNG range).") parser.add_argument("-dim", type=_str2tuple_dim, default=(6144, 4096), help="Model dimension: model_dim,inter_dim (e.g. -dim 6144,4096)") @@ -1496,28 +1535,28 @@ def _str2tuple_dim(v: str) -> Tuple[int, int]: tile_k2 = int(args.tile_k2) if args.tile_k2 is not None else args.tile_k # Run 2-stage (gemm1 -> quantize -> gemm2) aiter-style test/benchmark. - for dt in args.in_dtype.split(","): - test_moe_gemm_2stage( - tokens=int(args.tokenNum), - model_dim=int(model_dim), - inter_dim=int(inter_dim), - experts=int(args.expert), - topk=int(args.topk), - tile_m=int(args.tile_m), - tile_n1=int(args.tile_n), - tile_k1=int(args.tile_k), - tile_n2=tile_n2, - tile_k2=tile_k2, - doweight_stage1=bool(args.doweight_stage1), - in_dtype=dt, - seed=int(args.seed), - num_iters=int(args.num_iters), - num_warmup=int(args.num_warmup), - moe_sort_mode=args.moe_sort_mode, - compare_aiter_ck=args.compare_aiter_ck, - skip_ref=bool(args.skip_ref), - use_reduce=bool(args.reduce), - ) - - - + for x_dtype in args.x_dtype.split(","): + for w_dtype in args.w_dtype.split(","): + test_moe_gemm_2stage( + tokens=int(args.tokenNum), + model_dim=int(model_dim), + inter_dim=int(inter_dim), + experts=int(args.expert), + topk=int(args.topk), + tile_m=int(args.tile_m), + tile_n1=int(args.tile_n), + tile_k1=int(args.tile_k), + tile_n2=tile_n2, + tile_k2=tile_k2, + doweight_stage1=bool(args.doweight_stage1), + x_dtype=x_dtype, + w_dtype=w_dtype, + out_dtype=args.out_dtype, + seed=int(args.seed), + num_iters=int(args.num_iters), + num_warmup=int(args.num_warmup), + moe_sort_mode=args.moe_sort_mode, + compare_aiter_ck=args.compare_aiter_ck, + skip_ref=bool(args.skip_ref), + use_reduce=bool(args.reduce), + ) From 8486680aa53581adee091a1f7dba018ec3c8863b Mon Sep 17 00:00:00 2001 From: zanzhang Date: Wed, 4 Feb 2026 19:36:15 +0800 Subject: [PATCH 05/11] update --- kernels/moe_gemm_2stage.py | 1452 +++++++++++++++++++++++++++++++- tests/kernels/test_moe_gemm.py | 22 +- 2 files changed, 1418 insertions(+), 56 deletions(-) diff --git a/kernels/moe_gemm_2stage.py b/kernels/moe_gemm_2stage.py index 21921e8a..103b59d9 100644 --- a/kernels/moe_gemm_2stage.py +++ b/kernels/moe_gemm_2stage.py @@ -42,6 +42,8 @@ from kernels.mfma_epilogues import c_shuffle_epilog, default_epilog, mfma_epilog +####==================== gemm1 pipeline start =====================### + @functools.lru_cache(maxsize=1024) def compile_moe_gemm1( *, @@ -87,13 +89,6 @@ def compile_moe_gemm1( x_elem_pack = pipeline_manager.a_elem_pack w_elem_pack = pipeline_manager.b_elem_pack - # pack_K is only used for FP4 modes which need special packing - is_fp4_mode = mfma_pipeline in [MfmaPipeline.F4F4_MXFP4_PIPELINE, - MfmaPipeline.F8F4_MXFP4_PIPELINE] - pack_M = 2 if is_fp4_mode else 1 - pack_N = 2 if is_fp4_mode else 1 - pack_K = 2 if is_fp4_mode else 1 - tile_k_bytes = int(tile_k) * int(x_elem_bytes) # K64-byte micro-step: always 64 bytes per `ku`. For fp16 this is 32 elements. if (tile_k_bytes % 64) != 0: @@ -130,8 +125,14 @@ def _mfma_output_pack_ty(): is_fp4 = mfma_pipeline in [MfmaPipeline.F4F4_MXFP4_PIPELINE, MfmaPipeline.F8F4_MXFP4_PIPELINE] + is_int_mode = is_int8 or is_int4 + no_epilogue_dequant = is_fp4 or is_f16_or_bf16 + pack_M = 2 if is_fp4 else 1 + pack_N = 2 if is_fp4 else 1 + pack_K = 2 if is_fp4 else 1 + DYN = ir.ShapedType.get_dynamic_size() size_out = DYN size_x = DYN @@ -192,7 +193,7 @@ def init_gpu_module(self): lds_total_bytes = max(lds_x_bytes, lds_out_bytes) lds_total_elems = lds_total_bytes if x_elem_bytes == 1 else (lds_total_bytes // 2) # x_lds_elem = I.f16 if is_f16 else (I.i8 if is_int8 else I.f8) - x_lds_elem = I.f16 if is_f16_or_bf16 else (I.i8 if is_int8 else I.f8) + x_lds_elem = I.f16 if is_f16_or_bf16 else (I.i8 if is_int_mode else I.f8) _state["lds_x_decl"] = allocator.allocate_array(x_lds_elem, lds_total_elems) allocator.finalize() @@ -490,6 +491,7 @@ def load_x_tile(base_k): for i in range_constexpr(num_x_loads): idx_i32 = x_row_base_div4[i] + base_k_div4 + x_col_local_i32[i] x_vec = load_x(idx_i32) + # x_vec = load_x(0) if x_load_bytes == 16: parts.append(vector.bitcast(vec4_i32, x_vec)) elif x_load_bytes == 8: @@ -682,6 +684,7 @@ def compute_tile( b_up_tile_in, lds_base, *, + prefetch_epilogue: bool = False, a0_prefetch=None, ): @@ -966,7 +969,7 @@ def write_row_to_lds( acc_up[acc_idx], static_position=[ii], dynamic_position=[] ) - if is_int8 or is_int4: + if is_int_mode: vg = arith.sitofp(T.f32, vg) vu = arith.sitofp(T.f32, vu) @@ -1055,7 +1058,7 @@ def _stage1_store_row(*, mi: int, ii: int, row_in_tile, row): acc_up[acc_idx], static_position=[ii], dynamic_position=[] ) - if is_int8 or is_int4: + if is_int_mode: vg = arith.sitofp(f32, vg) vu = arith.sitofp(f32, vu) @@ -1130,6 +1133,1212 @@ def __call__( return exe +# This gemm1 pipeline used interleaved scale shuffle +@functools.lru_cache(maxsize=None) +def compile_gate_up_moe_gemm1( + *, + model_dim: int, + inter_dim: int, + experts: int, + topk: int, + tile_m: int, + tile_n: int, + tile_k: int, + # NOTE: aiter swap passes these for API symmetry; stage1 uses dynamic memrefs so they are ignored. + doweight_stage1: bool, + a_dtype: str = "fp8", + b_dtype: str = "fp4", + out_dtype: str = "f16", + act: str = "swiglu", + use_cshuffle_epilog: bool | None = None, + enable_bias: bool = False, + model_dim_pad: int = 0, + inter_dim_pad: int = 0, +): + """Compile stage1 kernel (`moe_gemm1`) and return the compiled executable. + + a_dtype: + - "fp8": X is fp8 + - "fp16": X is fp16 (caller uses tile_k halved vs fp8 to match MFMA K halving) + - "int8": X is int8 + - "fp4": X is fp4 + + b_dtype: + - "fp8": W is fp8 + - "fp16": W is fp16 (caller uses tile_k halved vs fp8 to match MFMA K halving) + - "int8": W is int8 + - "int4": W4A8 path: X is int8, W is packed int4 (2 values per byte) unpacked to int8 in-kernel + - "fp4": W is fp4 + """ + gpu_arch = get_hip_arch() + allocator = SmemAllocator(None, arch=gpu_arch) + _state = {} + + if a_dtype not in ("fp8", "fp16", "int8", "fp4"): + raise ValueError(f"a_dtype must be one of ('fp8','fp16','int8','fp4'), got {in_dtype!r}") + if b_dtype not in ("fp8", "fp16", "int8", "int4", "fp4"): + raise ValueError(f"in_dtype must be one of ('fp8','fp16','int8','int4', 'fp4'), got {in_dtype!r}") + + is_f16_a = a_dtype == "fp16" + is_f16_b = b_dtype == "fp16" + is_f16 = is_f16_a or is_f16_b + + is_f8_a = a_dtype == "fp8" + is_f4_a = a_dtype == "fp4" + is_f4_b = b_dtype == "fp4" + + pack_M = 2 + pack_N = 2 + pack_K = 2 + + elem_bytes = 1 + + a_elem_bytes = 2 if is_f16_a else 1 + b_elem_bytes = 1 + tile_k_bytes = int(tile_k) * int(a_elem_bytes) + + a_elem_vec_pack = 2 if is_f4_a else 1 + cbsz = 0 if is_f8_a else 4 + blgp = 4 + + # enable_bias = False + + # K64-byte micro-step: always 64 bytes per `ku`. For fp16, this is 32 elements (2xK16 MFMA). + if (tile_k_bytes % 64) != 0: + raise ValueError( + f"tile_k_bytes must be divisible by 64, got tile_k_bytes={tile_k_bytes} " + f"(tile_k={tile_k}, elem_bytes={a_elem_bytes})" + ) + is_int4 = b_dtype == "int4" + # INT4 here means W4A8: X is int8, W is packed int4 and unpacked to int8 in-kernel. + # is_int8 = (in_dtype == "int8") or is_int4 + is_int8 = False + + mfma_i32_k32 = None + if is_int8: + mfma_i32_k32 = getattr(rocdl, "mfma_i32_16x16x32i8", None) or getattr( + rocdl, "mfma_i32_16x16x32_i8", None + ) + if mfma_i32_k32 is None: + raise AttributeError( + "INT8 K32 MFMA op not found: expected `rocdl.mfma_i32_16x16x32i8` " + "(or `rocdl.mfma_i32_16x16x32_i8`)." + ) + + def _x_elem_type(): + if is_f4_b: + return I.f8 if is_f8_a else I.ui8 + return I.f16 if is_f16 else (I.i8 if is_int8 else I.f8) + + def _w_elem_type(): + if is_f4_b: + return I.ui8 + return I.f16 if is_f16 else (I.i8 if is_int8 else I.f8) + + def _scale_elem_type(): + return I.i32 + + def _out_elem_type(): + return I.bf16 if out_dtype == "bf16" else I.i32 + + def _out_lds_elem_type(): + return I.f32 + + def _out_vec_type(): + return I.bf16x1 if out_dtype == "bf16" else I.f8x1 + + # size_out = tokens * topk * inter_dim + # size_x = tokens * model_dim + # # W is packed int4 for W4A8: 2 values per byte. + # size_w = (experts * (2 * inter_dim) * model_dim) // 2 if is_int4 else (experts * (2 * inter_dim) * model_dim) + + DYN = ir.ShapedType.get_dynamic_size() + size_out = DYN + size_x = DYN + # W is packed int4 for W4A8: 2 values per byte. + size_w = (experts * (2 * inter_dim) * model_dim) // 2 if is_int4 else (experts * (2 * inter_dim) * model_dim) + size_sorted = DYN + size_expert_ids = DYN + + total_threads = 256 + bytes_x_per_tile = int(tile_m) * int(tile_k) * int(a_elem_bytes) + if bytes_x_per_tile % total_threads != 0: + raise ValueError( + "tile_m*tile_k*elem_bytes must be divisible by " + f"{total_threads}: tile_m={tile_m}, tile_k={tile_k}, elem_bytes={a_elem_bytes}" + ) + bytes_per_thread_x = bytes_x_per_tile // total_threads + # Keep MoE stage1 X gmem->LDS pipeline consistent with the optimized GEMM kernel: + # split into <=16B pieces and use `flir.copy(load-only)` for buffer_load_dwordx4. + # (Compute the split lens inside the kernel so the code matches GEMM structure.) + + # CK-style LDS128 mode (same idea as test_preshuffle_gemm.py): + # - LDS stride == tile_k (no extra padding) + XOR16 swizzle + # - Use ds_{read,write}_b128 (16B) and extract 8B halves for MFMA steps + _ck_lds128 = os.environ.get("FLIR_CK_LDS128", "1") in ("1", "true", "True", "YES", "yes") + pad_k = 0 if _ck_lds128 else 8 + lds_stride = tile_k + pad_k + if use_cshuffle_epilog is None: + use_cshuffle_epilog = os.environ.get("FLIR_MOE_STAGE1_CSHUFFLE", "1") in ("1", "true", "True", "YES", "yes") + use_cshuffle_epilog = bool(use_cshuffle_epilog) + + epilog_tag = "cshuffle" if use_cshuffle_epilog else "direct" + module_name = f"mfma_moe1_a{a_dtype}_w{b_dtype}_{epilog_tag}".replace("-", "_") + + class _MOE1(flir.MlirModule): + GPU_MODULE_NAME = module_name + GPU_MODULE_TARGETS = [ + f'#rocdl.target' + ] + + def init_gpu_module(self): + # Optional epilogue CShuffle (LDS + vectorized buffer stores). + # Reuse the same LDS bytes for both: + # - ping-pong X tiles (2 * tile_m * lds_stride bytes; fp8/int8) + # - epilogue CShuffle tile (tile_m * tile_n f16 -> 2 * tile_m * tile_n bytes) + _use_cshuffle_epilog = bool(use_cshuffle_epilog) + lds_x_bytes = 2 * int(tile_m) * int(lds_stride) * int(a_elem_bytes) + lds_out_bytes = 4 * tile_m * (tile_n // 2) if _use_cshuffle_epilog else 0 + lds_total_bytes = max(lds_x_bytes, lds_out_bytes) + lds_total_elems = lds_total_bytes if a_elem_bytes == 1 else (lds_total_bytes // 2) + x_lds_elem = I.f16 if is_f16_a else (I.i8 if is_int8 else I.f8) + _state["lds_x_decl"] = allocator.allocate_array(x_lds_elem, lds_total_elems) + allocator.finalize() + + @flir.kernel + def moe_gemm1( + self: flir.T.i64, + arg_out: lambda: T.memref(DYN, _out_elem_type()), + arg_x: lambda: T.memref(DYN, _x_elem_type()), + arg_w: lambda: T.memref(DYN, _w_elem_type()), + arg_scale_x: lambda: T.memref(DYN, _scale_elem_type()), + arg_scale_w: lambda: T.memref(experts * (2 * inter_dim), _scale_elem_type()), + arg_sorted_token_ids: lambda: T.memref(size_sorted, T.i32()), + arg_expert_ids: lambda: T.memref(DYN, T.i32()), + arg_sorted_weights: lambda: T.memref(DYN, T.f32()), + arg_max_token_ids: lambda: T.memref(DYN, T.i32()), + arg_bias: lambda: T.memref(DYN, T.f32()), + tokens_in: lambda: T.index(), + inter_in: lambda: T.index(), + k_in: lambda: T.index(), + size_expert_ids_in: lambda: T.index(), + ): + x_elem = I.f16 if is_f16_a else (I.i8 if is_int8 else I.f8) + # For int4, weights are stored as packed bytes (i8) and unpacked to i8 packs. + w_elem = I.f16 if is_f16_b else (I.i8 if is_int8 else I.f8) + f16 = I.f16 + f32 = I.f32 + i32 = I.i32 + i64 = I.i64 + vec4_f32 = I.vec(4, f32) + vec4_i32 = I.vec(4, i32) + vec4_f16 = I.vec(4, f16) + vec4_f8 = I.vec(4, I.f8) + vec1_f16 = I.vec(1, f16) + vec1_f32 = I.vec(1, f32) + vec16_elems = 16 if a_elem_bytes == 1 else 8 + vec8_elems = 8 if a_elem_bytes == 1 else 4 + vec4_elems = 4 if a_elem_bytes == 1 else 2 + vec8_x = I.vec(vec8_elems, x_elem) + vec16_x = I.vec(vec16_elems, x_elem) + vec1_i64 = I.vec(1, i64) + vec2_i64 = I.vec(2, i64) + + def silu(x): + # Align with CK's device fast path: + # emu = exp(-x) ~= exp2(log2e * (-x)) -> v_exp_f32 + # sig = rcp(1 + emu) -> v_rcp_f32 + # y = x * sig + # + # Using llvm.amdgcn intrinsics prevents lowering to the div_scale/div_fixup + # sequences that introduce extra compares/cndmasks. + t = x * (-1.4426950408889634) # -log2(e) + emu = llvm.call_intrinsic(f32, "llvm.amdgcn.exp2.f32", [t], [], []) + den = 1.0 + emu + sig = llvm.call_intrinsic(f32, "llvm.amdgcn.rcp.f32", [den], [], []) + return x * sig + + def swiglu(gate, up, alpha=1.702, limit=7.0): + # Align with CK's device fast path + # + # Using llvm.amdgcn intrinsics prevents lowering to the div_scale/div_fixup + # sequences that introduce extra compares/cndmasks. + gate = arith.minimum(gate, limit) + up = arith.minimum(up, limit) + up = arith.maximum(up, -limit) + + t = gate * (alpha) * (-1.4426950408889634) # -log2(e) + emu = llvm.call_intrinsic(f32, "llvm.amdgcn.exp2.f32", [t], [], []) + den = 1.0 + emu + sig = llvm.call_intrinsic(f32, "llvm.amdgcn.rcp.f32", [den], [], []) + return gate * sig * (up + 1) + + acc_init = ( + arith.constant_vector(0, vec4_i32) + if is_int8 + else arith.constant_vector(0.0, vec4_f32) + ) + + # Lccouts + layout_x = flir.make_layout((tokens_in, k_in), stride=(k_in, 1)) + + # B preshuffle layout: match GEMM test helper exactly. + c_n_total = arith.constant(experts * (2 * inter_dim), index=True) + kpack_bytes = 8 if is_int4 else 16 + b_layout = make_preshuffle_b_layout( + flir, arith, c_n=c_n_total, c_k=k_in // pack_K, kpack_bytes=kpack_bytes, elem_bytes=b_elem_bytes + ) + layout_b = b_layout.layout_b + + m_repeat = tile_m // 16 + k_unroll = tile_k_bytes // 128 # K64-byte micro-step + + # A&B's scale preshuffle layout + layout_a_scale = make_preshuffle_scale_layout( + flir, arith, c_mn=tokens_in, c_k=k_in, + ) + layout_b_scale = make_preshuffle_scale_layout( + flir, arith, c_mn=c_n_total, c_k=k_in, + ) + + # Only used by fp8/int8 path (16B gmem -> regs). Kept for backwards compat. + atom_w_g2r16 = flir.make_copy_atom(w_elem, vector_size=16) + + shape_lds = flir.make_shape(tile_m, tile_k) + stride_lds = flir.make_stride(lds_stride, 1) + layout_lds = flir.make_layout(shape_lds, stride_lds) + + tx = gpu.thread_id("x") + # Align with Aiter launch mapping (NSwizzle==false): + # - blockIdx.x -> N dimension (tile along inter_dim) + # - blockIdx.y -> expert-block id / M dimension (tile along sorted M) + by = gpu.block_id("x") # tile along inter_dim + bx = gpu.block_id("y") # tile along sorted M + + # Block validity: compute as early as possible so invalid blocks skip all buffer-resource + # setup, LDS pointer math, and gmem prefetch work. + bx_m = bx * arith.constant(tile_m, index=True) + by_n = by * arith.constant(tile_n, index=True) + + maxids_rsrc = buffer_ops.create_buffer_resource( + arg_max_token_ids, max_size=False, num_records_bytes=arith.i32(4) + ) + max_token_id_i32 = buffer_ops.buffer_load( + maxids_rsrc, arith.constant(0, index=True), vec_width=1, dtype=i32 + ) + + bias_rsrc = buffer_ops.create_buffer_resource(arg_bias, max_size=False) if enable_bias else None + + bx_m_i32 = arith.index_cast(i32, bx_m) + by_n_i32 = arith.index_cast(i32, by_n) + blk_valid = arith.cmpu(bx_m_i32, max_token_id_i32, "ult") + # Common constants/atoms (hoisted): keep IR small like GEMM. + # CK-style XOR16 swizzle parameter (constant, power-of-two in our configs). + k_blocks16 = arith.constant(tile_k_bytes // 16, index=True) + atom_x_s16 = flir.make_copy_atom(x_elem, vector_size=16) + atom_x_s8 = flir.make_copy_atom(x_elem, vector_size=8) + atom_x_s4 = flir.make_copy_atom(x_elem, vector_size=4) + atom_x_g2r16 = flir.make_copy_atom(x_elem, vector_size=vec16_elems) + atom_x_g2r8 = flir.make_copy_atom(x_elem, vector_size=vec8_elems) + atom_x_g2r4 = flir.make_copy_atom(x_elem, vector_size=vec4_elems) + layout_tx_wave_lane = flir.make_layout((4, 64), stride=(64, 1)) + layout_lane16 = flir.make_layout((4, 16), stride=(16, 1)) + + _if_blk = scf.IfOp(blk_valid) + with _if_blk.then(): + base_ptr = allocator.get_base() + lds_x_ptr = _state["lds_x_decl"](base_ptr) + lds_x = lds_x_ptr.get() + # Alias LDS bytes as fp16 for optional CShuffle epilogue. + _use_cshuffle_epilog = bool(use_cshuffle_epilog) + + lds_out = ( + SmemPtr(base_ptr, lds_x_ptr.byte_offset, _out_lds_elem_type(), shape=(tile_m * tile_n,)).get() + if _use_cshuffle_epilog + else None + ) + + # Use logical buffer sizes (descriptor num_records) so hardware OOB checking can be + # used directly (CK-style). This allows us to avoid `select`-based masking for + # invalid lanes and rely on the buffer instruction's built-in bounds behavior. + x_rsrc = buffer_ops.create_buffer_resource(arg_x, max_size=False, num_records_bytes=tokens_in*model_dim) + w_rsrc = buffer_ops.create_buffer_resource(arg_w, max_size=False) + out_rsrc = buffer_ops.create_buffer_resource(arg_out, max_size=False) + + # fp16 path ignores scales completely (implicit scale=1.0). + sx_rsrc = None if is_f16_a else buffer_ops.create_buffer_resource(arg_scale_x, max_size=False) + sw_rsrc = None if is_f16_b else buffer_ops.create_buffer_resource(arg_scale_w, max_size=False) + sorted_rsrc = buffer_ops.create_buffer_resource(arg_sorted_token_ids, max_size=False) + sorted_w_rsrc = buffer_ops.create_buffer_resource(arg_sorted_weights, max_size=False) + + # expert ids: [blocks] i32 -> bytes = size_expert_ids_in*4 + eid_nbytes_i32 = arith.index_cast(i32, size_expert_ids_in * arith.constant(4, index=True)) + expert_rsrc = buffer_ops.create_buffer_resource( + arg_expert_ids, max_size=False, num_records_bytes=eid_nbytes_i32 + ) + + # Expert id for this M tile (keep address math in `index`) + expert_i32 = buffer_ops.buffer_load(expert_rsrc, bx, vec_width=1, dtype=i32) + exp_valid = arith.cmpu(expert_i32, experts, "ult") # todo fix + _ifexpert_of = scf.IfOp(exp_valid) + with _ifexpert_of.then(): + expert_idx = arith.index_cast(ir.IndexType.get(), expert_i32) + inter2_idx = arith.constant(2 * inter_dim, index=True) + expert_off_idx = expert_idx * inter2_idx # index + + bx_m = bx * arith.constant(tile_m, index=True) + + # ---- X gmem->reg prefetch (match preshuffle GEMM mapping) ---- + # Prefer 16B buffer-load (dwordx4). If the per-thread byte count isn't divisible by + # 16, fall back to 8B (dwordx2) or 4B (dword) loads. This broadens supported tilings + # (e.g. tile_m=16, tile_k=192 -> 12B/thread) at some performance cost. + if is_f16_a: + # fp16 path keeps the same fixed 16B gmem->reg schedule. + if bytes_per_thread_x % 16 != 0: + raise ValueError( + f"[fp16] bytes_per_thread_x ({bytes_per_thread_x}) must be divisible by 16" + ) + x_load_bytes = 16 + else: + if bytes_per_thread_x % 16 == 0: + x_load_bytes = 16 + elif bytes_per_thread_x % 8 == 0: + x_load_bytes = 8 + elif bytes_per_thread_x % 4 == 0: + x_load_bytes = 4 + else: + raise ValueError( + f"bytes_per_thread_x ({bytes_per_thread_x}) must be divisible by 4 to use the dword-indexed load mapping." + ) + num_x_loads = bytes_per_thread_x // x_load_bytes + chunk_i32 = x_load_bytes // 4 # dwords per chunk (1/2/4) + + # Work in dword units along K: K_dwords = (K_bytes)/4. + c_k_div4 = (k_in * arith.constant(int(elem_bytes), index=True)) / arith.index(4) + layout_x_div4 = flir.make_layout((tokens_in, c_k_div4), stride=(c_k_div4, 1)) + tile_k_dwords = (int(tile_k) * int(elem_bytes)) // 4 + layout_x_tile_div4 = flir.make_layout((tile_m, tile_k_dwords), stride=(tile_k_dwords, 1)) + c_chunk_i32 = arith.constant(chunk_i32, index=True) + tx_i32_base = tx * c_chunk_i32 + mask24 = arith.i32(0xFFFFFF) + # Keep i32 constants available for epilogue index math. + topk_i32 = arith.i32(topk) + + tokens_i32 = arith.index_cast(i32, tokens_in) + + def x_tile_chunk_coord_i32(i: int): + return tile_chunk_coord_i32( + flir, + arith, + tx_i32_base=tx_i32_base, + i=i, + total_threads=total_threads, + layout_tile_div4=layout_x_tile_div4, + chunk_i32=chunk_i32, + ) + + # CK-aligned: decode token once (per thread's M-slice) and build a base row offset. + x_row_base_div4 = [] + x_col_local_i32 = [] + x_row_local = [] + for i in range_constexpr(num_x_loads): + row_local, col_local_i32 = x_tile_chunk_coord_i32(i) + x_row_local.append(row_local) + x_col_local_i32.append(col_local_i32) + + sorted_row_i = bx_m + row_local + fused_i = buffer_ops.buffer_load(sorted_rsrc, sorted_row_i, vec_width=1, dtype=i32) + t_i32 = arith.andi(fused_i, mask24) + t_idx = arith.index_cast(ir.IndexType.get(), t_i32) + x_row_base_div4.append(t_idx * c_k_div4) + + vec1_i32 = I.vec(1, i32) + vec2_i32 = I.vec(2, i32) + vec4_i32 = I.vec(4, i32) + vec4_x = I.vec(4, x_elem) + + def load_x(idx_i32): + """Load `x_load_bytes` bytes from X (gmem) into regs. + + For 16B, keep the fast dwordx4 path. For 8B/4B, use byte offsets. + """ + if x_load_bytes == 16: + idx_elem = idx_i32 if elem_bytes == 1 else (idx_i32 * arith.index(2)) + return buffer_copy_gmem16_dwordx4( + flir, + arg=arg_x, + elem_type=x_elem, + # idx_i32=idx_elem + 0x80000000, + idx_i32=idx_elem, + atom_g2r16=atom_x_g2r16, + rsrc=x_rsrc, + vec_elems=vec16_elems, + ) + idx_bytes = idx_i32 * arith.index(4) + atom = atom_x_g2r8 if x_load_bytes == 8 else atom_x_g2r4 + view = flir.TensorView( + arg_x, + (x_load_bytes,), + strides=(1,), + base_indices=(idx_bytes,), + element_type=x_elem, + ) + return flir.copy( + atom, + view, + None, + alignment=x_load_bytes, + return_vector=True, + src_buffer_resource=x_rsrc, + src_buffer_offset_in_bytes=True, + ) + + def load_x_tile(base_k): + """Prefetch the per-thread X tile portion (gmem -> regs) for a given K base (in elements).""" + base_k_div4 = (base_k * arith.constant(int(elem_bytes), index=True)) / arith.index(4) + parts = [] + for i in range_constexpr(num_x_loads): + idx_i32 = x_row_base_div4[i] + base_k_div4 + x_col_local_i32[i] + x_vec = load_x(idx_i32) + parts.append(vector.bitcast(vec4_i32, x_vec)) + return parts + + # tx -> wave/lane (GEMM-style decomposition). + coord_wl = flir.idx2crd(tx, layout_tx_wave_lane) + wave_id = flir.get(coord_wl, 0) + lane_id = flir.get(coord_wl, 1) + coord_l16 = flir.idx2crd(lane_id, layout_lane16) + lane_div_16 = flir.get(coord_l16, 0) + lane_mod_16 = flir.get(coord_l16, 1) + + # Match GEMM naming/pattern: row in LDS is lane_mod_16, and col base is lane_div_16*16B (KPackBytes=16). + row_a_lds = lane_mod_16 + # col_offset_base = lane_div_16 * arith.constant(32, index=True) + col_offset_base = lane_div_16 * arith.constant(16, index=True) + + # Dynamic N tiling within block (same as existing kernels) + num_waves = 4 + n_per_wave = tile_n // num_waves + num_acc_n = n_per_wave // 16 + c_n_per_wave = arith.constant(n_per_wave, index=True) + wave_mod_4 = wave_id % arith.index(4) + n_tile_base = wave_mod_4 * c_n_per_wave + + # fp4 pack + k_unroll_packed = k_unroll // pack_K + m_repeat_packed = m_repeat // pack_M + num_acc_n_packed = num_acc_n // pack_N + + # Precompute n_blk/n_intra for gate and up rows (GEMM-style: idx2crd/get) + col_g_list = [] + valid_col_list = [] + inter_idx = arith.constant(inter_dim, index=True) + # layout for (row -> (blk,intra)) where intra is 0..15 + c_n0 = c_n_total / arith.index(16) + layout_n_blk_intra = flir.make_layout((c_n0, 16), stride=(16, 1)) + n_intra_list = [] + n_blk_list = [] + for i in range_constexpr(num_acc_n): + offset = i * 16 + + col_g = by_n + n_tile_base + col_g = col_g // 2 + offset + col_g = col_g + lane_mod_16 + col_g_list.append(col_g) + + c_offset = arith.constant(offset, index=True) + global_n = by_n + n_tile_base + c_offset + lane_mod_16 + row_w = expert_off_idx + global_n + coord_n = flir.idx2crd(row_w, layout_n_blk_intra) + n_blk_list.append(flir.get(coord_n, 0)) + n_intra_list.append(flir.get(coord_n, 1)) + + # --- B Load Logic (K64) - shared layout with preshuffle GEMM --- + def load_b_packs_k64(base_k, ku: int, ni: int): + base_k_bytes = base_k * arith.constant(int(elem_bytes), index=True) + k0_base = base_k_bytes / 64 + k0 = k0_base + ku + k1 = lane_div_16 + coord_pack = flir.make_coord(n_blk_list[ni], k0, k1, n_intra_list[ni], 0) + idx_pack = flir.crd2idx(coord_pack, layout_b) + + # Calculate mask for boundary check + c_offset = arith.constant(ni * 16, index=True) + global_n = by_n + n_tile_base + c_offset + lane_mod_16 + + vec_elems = 16 + b_view = flir.TensorView( + arg_w, + (vec_elems,), + strides=(1,), + base_indices=(idx_pack,), + element_type=_w_elem_type(), + ) + b16 = flir.copy( + flir.make_copy_atom(_w_elem_type(), vector_size=vec_elems), + b_view, + None, + alignment=8, + return_vector=True, + src_buffer_resource=(w_rsrc if elem_bytes == 1 else None), + src_buffer_offset_in_bytes=(elem_bytes == 1), + ) + # Split 16B pack into two 8B halves. + b_i64x2 = vector.bitcast(I.i64x2, b16) + b0_i64 = vector.extract(b_i64x2, static_position=[0], dynamic_position=[]) + b1_i64 = vector.extract(b_i64x2, static_position=[1], dynamic_position=[]) + return b0_i64, b1_i64 + + def load_b_tile(base_k): + b_tile = [] + for ku in range_constexpr(k_unroll): + packs0 = [] + packs1 = [] + for ni in range_constexpr(num_acc_n): + b0, b1 = load_b_packs_k64(base_k, ku, ni) + packs0.append(b0) + packs1.append(b1) + b_tile.append((packs0, packs1)) + return b_tile + + def load_scale(arg_scale, rsrc, layout, ku, mni): + k_lane = lane_div_16 + n_lane = lane_mod_16 + coord_pack = flir.make_coord(mni, ku, k_lane, n_lane) + idx_pack = flir.crd2idx(coord_pack, layout) + scale_view = flir.TensorView( + arg_scale, + (1,), + strides=(1,), + base_indices=(idx_pack,), + element_type=_scale_elem_type(), + ) + scale = flir.copy( + flir.make_copy_atom(_scale_elem_type(), vector_size=1), + scale_view, + None, + alignment=8, + return_vector=True, + src_buffer_resource=rsrc, + src_buffer_offset_in_bytes=False, + ) + # Split 16B pack into two 8B halves. + return scale + + def load_b_scale_tile(base_k): + b_scale_tile = [] + for ku in range_constexpr(k_unroll_packed): + for ni in range_constexpr(num_acc_n_packed): + scale = load_scale( + arg_scale_w, + sw_rsrc, + layout_b_scale, + ku + base_k, + ni + (expert_off_idx + by_n + n_tile_base) // pack_N // 16, + ) + b_scale_tile.append(scale) + return b_scale_tile + + def load_a_scale_tile(base_k): + a_scale_tile = [] + for ku in range_constexpr(k_unroll_packed): + for mi in range_constexpr(m_repeat_packed): + scale = load_scale( + arg_scale_x, + sx_rsrc, + layout_a_scale, + ku + base_k, + mi + bx_m // pack_M // 16, + ) + a_scale_tile.append(scale) + return a_scale_tile + + def prefetch_ab_scale_tile(base_k): + return [None, load_b_scale_tile(base_k)] + + acc_gate = [acc_init] * (num_acc_n // 2 * m_repeat) + acc_up = [acc_init] * (num_acc_n // 2 * m_repeat) + + # ---- Pipeline helpers: store X tile to LDS with ping-pong base ---- + def store_x_tile_to_lds(vec_x_in_parts, lds_base): + for i in range_constexpr(num_x_loads): + row_local = x_row_local[i] + col_local_i32 = x_col_local_i32[i] + if x_load_bytes == 16: + lds_store_16b_xor16( + flir, + arith, + vector, + lds_memref=lds_x, + vec16_ty=vec16_x, + elem_type=x_elem, + atom_s16=atom_x_s16, + layout_lds=layout_lds, + row_local=row_local, + col_local_i32=col_local_i32, + tx_c4=arith.index(4), + k_blocks16=k_blocks16, + lds_base=lds_base, + vec_part_i32x4=vec_x_in_parts[i], + elem_bytes=elem_bytes, + ) + + # --- A LDS load helper for K64 (load 16B once, extract 2x i64 halves) --- + def lds_load_packs_k64(curr_row_a_lds, col_base, lds_base): + # Swizzle in bytes, then convert to element offset for memref indexing. + col_base_swz_bytes = flir.swizzle_xor16(curr_row_a_lds, col_base, k_blocks16) + col_base_swz = col_base_swz_bytes if elem_bytes == 1 else (col_base_swz_bytes / arith.index(2)) + coord_a16 = flir.make_coord(curr_row_a_lds, col_base_swz) + idx_a16 = flir.crd2idx(coord_a16, layout_lds) + idx_a16 = idx_a16 + lds_base + loaded_a16 = vector.load_op(vec16_x, lds_x, [idx_a16]) + a_i64x2 = vector.bitcast(vec2_i64, loaded_a16) + a0 = vector.extract(a_i64x2, static_position=[0], dynamic_position=[]) + a1 = vector.extract(a_i64x2, static_position=[1], dynamic_position=[]) + return a0, a1 + + def compute_f8f6f4_tile( + acc_gate_in, + acc_up_in, + b_tile_in, + lds_base, + *, + a0_prefetch=None, + a_scale=None, + b_scale=None, + prefetch_epilogue: bool = False + ): + gate_list = list(acc_gate_in) + up_list = list(acc_up_in) + + epilogue_pf = None + if enable_bias and prefetch_epilogue: + expert_off_idx + by_n + n_tile_base + + gate_bias = [] + up_bias = [] + for ni in range_constexpr(num_acc_n_packed): + global_n = (by_n + n_tile_base) // 2 + ni * 16 + lane_mod_16 + gate_offset = expert_off_idx + global_n + up_offset = expert_off_idx + global_n + inter_dim + gate_bias.append( + buffer_ops.buffer_load(bias_rsrc, gate_offset, vec_width=1, dtype=f32) + ) + up_bias.append( + buffer_ops.buffer_load(bias_rsrc, up_offset, vec_width=1, dtype=f32) + ) + epilogue_pf = (gate_bias, up_bias) + + # ---------------- gfx95 fast path (K128 MFMA scale) ---------------- + # This is the key optimization from `zhimding/develop_0107` for FP8: + # use mfma.scale 16x16x128 to reduce instruction count in the hot loop. + # + # Notes: + # - Only valid for fp8 path (not int8/int4) and gfx95+ + # - Requires tile_k divisible by 128 + # - mfma.scale takes 9 operands: 3 vectors + 6 i32 flags/scales. + if (int(tile_k) % 128) != 0: + raise ValueError( + f"tile_k must be divisible by 128 for mfma_scale_x128, got tile_k={tile_k}" + ) + + mfma_res_ty = I.f32x4 + vec4_i64 = I.vec(4, I.i64) + vec8_i32 = I.vec(8, I.i32) + c0_i64 = arith.constant(0, type=I.i64) + + def pack_i64x4_to_i32x8(x0, x1, x2, x3): + v4 = vector.from_elements(vec4_i64, [x0, x1, x2, x3]) + return vector.bitcast(vec8_i32, v4) + + for ku128 in range_constexpr(k_unroll_packed): + for mi in range_constexpr(m_repeat_packed): + # a_scale_i32 = a_scale[ku128 * m_repeat_packed + mi] + # a_scale_val = vector.extract(a_scale_i32, static_position=[0], dynamic_position=[]) + for ni in range_constexpr(num_acc_n_packed): + b_scale_i32 = b_scale[ku128 * num_acc_n_packed + ni] + b_scale_val = vector.extract(b_scale_i32, static_position=[0], dynamic_position=[]) + for ikxdl in range_constexpr(pack_K): + k_idx = ku128 * pack_K + ikxdl + b_packs0, b_packs1 = b_tile_in[k_idx] + col_base = col_offset_base + (k_idx * 128) // a_elem_vec_pack + for imxdl in range_constexpr(pack_M): + col_base0 = col_base + mi_idx = mi * pack_M + imxdl + mi_val = arith.constant(mi_idx * 16, index=True) + curr_row_a_lds = row_a_lds + mi_val + + if (a0_prefetch is not None) and (k_idx == 0) and (mi_idx == 0): + a0, a1 = a0_prefetch + else: + a0, a1 = lds_load_packs_k64(curr_row_a_lds, col_base0, lds_base) + + if is_f8_a: + col_base1 = col_base + 64 + a2, a3 = lds_load_packs_k64(curr_row_a_lds, col_base1, lds_base) + a128 = pack_i64x4_to_i32x8(a0, a1, a2, a3) + else: + a128 = pack_i64x4_to_i32x8(a0, a1, c0_i64, c0_i64) + + for inxdl in range_constexpr(pack_N): + if inxdl % 2 == 0: + current_accs_list = gate_list + else: + current_accs_list = up_list + ni_idx = ni * pack_N + inxdl + + b0 = b_packs0[ni_idx] + b1 = b_packs1[ni_idx] + b128 = pack_i64x4_to_i32x8(b0, b1, c0_i64, c0_i64) + + acc_idx = mi_idx * num_acc_n_packed + ni + rocdl.sched_barrier(0) + current_accs_list[acc_idx] = rocdl.mfma_scale_f32_16x16x128_f8f6f4( + mfma_res_ty, + [ + a128, + b128, + current_accs_list[acc_idx], + cbsz, + blgp, + # use per tensor quant a1 for now, + 0, + 0x3F800000, + ikxdl * pack_N + inxdl, + b_scale_val, + ], + ) + + return gate_list, up_list, epilogue_pf + + # ---------------- 2-stage pipeline (ping-pong LDS + B tile prefetch) ---------------- + lds_tile_elems = arith.constant(tile_m * lds_stride, index=True) + lds_base_cur = arith.index(0) + lds_base_nxt = lds_tile_elems + + # Optional scheduler hints (copied from tuned GEMM); can be disabled via env. + rocdl.sched_barrier(0) + + def hot_loop_scheduler(): + mfma_group = num_acc_n * 2 + # K64 micro-step: 2x K32 MFMA per gemm. + mfma_total = (k_unroll * 2) * m_repeat * mfma_group + mfma_per_iter = 2 * mfma_group + sche_iters = 0 if mfma_per_iter == 0 else (mfma_total // mfma_per_iter) + + # DS-read preload (CK default is 2); clamp to non-negative. + rocdl.sched_dsrd(2) + rocdl.sched_mfma(2) + rocdl.sched_dsrd(1) + rocdl.sched_mfma(1) + rocdl.sched_dsrd(1) + rocdl.sched_mfma(1) + + # DS-write hints near the end: match total X LDS-store micro-ops per thread. + dswr_tail = num_x_loads + if dswr_tail > sche_iters: + dswr_tail = sche_iters + dswr_start = sche_iters - dswr_tail + for sche_i in range_constexpr(sche_iters): + rocdl.sched_vmem(1) + rocdl.sched_mfma(mfma_group) + rocdl.sched_dsrd(1) + rocdl.sched_mfma(mfma_group) + if sche_i >= dswr_start - 1: + rocdl.sched_dswr(1) + rocdl.sched_barrier(0) + + # Prologue: prefetch tile0, store to LDS(cur), sync. + k0 = arith.index(0) + x_regs0 = load_x_tile(k0) + w_regs0 = load_b_tile(k0) + + a_scale_pong = None + a_scale_ping = None + # a_scale_pong, b_scale_pong = prefetch_ab_scale_tile(k0 // 2) + _, b_scale_pong = prefetch_ab_scale_tile(k0 // 2) + store_x_tile_to_lds(x_regs0, lds_base_cur) + gpu.barrier() + + # Loop-carried ping/pong state. + lds_base_pong = lds_base_cur # current/compute + lds_base_ping = lds_base_nxt # next/load+store + w_regs_pong = w_regs0 + + # Cross-tile A0 LDS prefetch (default-on): prefetch the first A-pack (K64) for the + # tile we are about to compute from LDS, to overlap with upcoming VMEM. + a0_prefetch_pong = lds_load_packs_k64(row_a_lds, col_offset_base, lds_base_pong) + + # Unrolled ping-pong main loop (2 tiles per iteration), leaving 2 tail tiles. + c2_tile_k = arith.constant(tile_k * 2, index=True) + c_k_main2 = k_in - c2_tile_k + + for k_iv in range(arith.index(0), c_k_main2, c2_tile_k): + # ---- stage 0: prefetch+store ping, compute pong ---- + next_k1 = k_iv + tile_k + x_regs_ping = load_x_tile(next_k1) + w_regs_ping = load_b_tile(next_k1 // 2) + _, b_scale_ping = prefetch_ab_scale_tile(next_k1 // pack_K // 128) + + acc_gate, acc_up, _ = compute_f8f6f4_tile( + acc_gate, + acc_up, + w_regs_pong, + lds_base_pong, + a0_prefetch=a0_prefetch_pong, + a_scale=a_scale_pong, + b_scale=b_scale_pong, + ) + a0_prefetch_pong = None + store_x_tile_to_lds(x_regs_ping, lds_base_ping) + # hot_loop_scheduler() + gpu.barrier() + + # Cross-tile prefetch for the ping tile we are about to compute. + a0_prefetch_ping = lds_load_packs_k64(row_a_lds, col_offset_base, lds_base_ping) + + # ---- stage 1: prefetch+store pong, compute ping ---- + next_k2 = k_iv + c2_tile_k + x_regs_pong = load_x_tile(next_k2) + w_regs_pong = load_b_tile(next_k2 // 2) + _, b_scale_pong = prefetch_ab_scale_tile(next_k2 // pack_K // 128) + + acc_gate, acc_up, _ = compute_f8f6f4_tile( + acc_gate, + acc_up, + w_regs_ping, + lds_base_ping, + a0_prefetch=a0_prefetch_ping, + a_scale=a_scale_ping, + b_scale=b_scale_ping, + ) + a0_prefetch_ping = None + store_x_tile_to_lds(x_regs_pong, lds_base_pong) + # hot_loop_scheduler() + gpu.barrier() + + # Cross-tile prefetch for the next pong tile. + a0_prefetch_pong = lds_load_packs_k64(row_a_lds, col_offset_base, lds_base_pong) + + # Tail: 2 remaining tiles at (k_in - 2*tile_k) and (k_in - tile_k). + k_tail1 = k_in - tile_k + x_regs_ping = load_x_tile(k_tail1) + w_regs_ping = load_b_tile(k_tail1 // 2) + _, b_scale_ping = prefetch_ab_scale_tile(k_tail1 // pack_K // 128) + + acc_gate, acc_up, _ = compute_f8f6f4_tile( + acc_gate, + acc_up, + w_regs_pong, + lds_base_pong, + a0_prefetch=a0_prefetch_pong, + a_scale=a_scale_pong, + b_scale=b_scale_pong, + ) + a0_prefetch_pong = None + store_x_tile_to_lds(x_regs_ping, lds_base_ping) + # hot_loop_scheduler() + gpu.barrier() + + # Cross-tile prefetch for the final ping tile. + a0_prefetch_ping = lds_load_packs_k64(row_a_lds, col_offset_base, lds_base_ping) + + # Epilogue: compute last tile with epilogue scale prefetch to overlap loads with MFMA. + acc_gate, acc_up, epilogue_pf = compute_f8f6f4_tile( + acc_gate, + acc_up, + w_regs_ping, + lds_base_ping, + a0_prefetch=a0_prefetch_ping, + a_scale=a_scale_ping, + b_scale=b_scale_ping, + prefetch_epilogue=True, + ) + + # Store epilogue to out[t, slot, inter] + expert_off = expert_off_idx + bx_m0 = bx_m + topk_i32_v = topk_i32 + inter_i32_v = arith.i32(inter_dim) + mask24_i32 = arith.i32(0xFFFFFF) + + # Epilogue hoists to keep IR + Python build time small: + col_i32_list = [] + for ni in range_constexpr(num_acc_n): + col_i32_list.append(arith.index_cast(i32, col_g_list[ni])) + + lane_div_16_mul4 = lane_div_16 * arith.index(4) + inter_i32_local = inter_i32_v + + # Optional: CK-style CShuffle epilogue for better global store coalescing. + # Uses EVec=4 (buffer store "x4" of fp16 elements). + _use_cshuffle_epilog = (out_dtype == "fp8") or bool(use_cshuffle_epilog) + + mask_even_i32 = arith.i32(0xFFFFFFFE) + + if _use_cshuffle_epilog: + if lds_out is None: + raise RuntimeError("CShuffle epilogue enabled but lds_out is not allocated/aliased.") + + def write_row_to_lds( + *, + mi: int, + ii: int, + row_in_tile, + row, + row_base_lds, + col_base_local, + num_acc_n: int, + lds_out, + ): + # `row` is the sorted-row index (bx_m + row_in_tile). + fused2 = buffer_ops.buffer_load(sorted_rsrc, row, vec_width=1, dtype=i32) + t2 = fused2 & mask24_i32 + + # Sorted weight aligned with `row` (matches aiter moe_sorting output). + if doweight_stage1: + tw = buffer_ops.buffer_load(sorted_w_rsrc, row, vec_width=1, dtype=f32) + + for ni in range_constexpr(num_acc_n_packed): + col_local = col_base_local + (ni * 16) + + acc_idx = mi * num_acc_n + ni + vg = vector.extract( + acc_gate[acc_idx], static_position=[ii], dynamic_position=[] + ) + vu = vector.extract( + acc_up[acc_idx], static_position=[ii], dynamic_position=[] + ) + + if enable_bias: + gate_bias_list, up_bias_list = epilogue_pf + vg = vg + gate_bias_list[ni] + vu = vu + up_bias_list[ni] + + if act == "swiglu": + y = swiglu(vg, vu) + else: + y = silu(vg) * vu + + if doweight_stage1: + y = y * tw + + lds_idx = row_base_lds + col_local + v1 = vector.from_elements(vec1_f32, [y]) + vector.store(v1, lds_out, [lds_idx], alignment=1) + + def precompute_row(*, row_local, row): + fused2 = buffer_ops.buffer_load(sorted_rsrc, row, vec_width=1, dtype=i32) + t2 = fused2 & mask24_i32 + s2 = fused2 >> 24 + return (t2 * topk_i32_v + s2) * inter_i32_local + + def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): + # Guard against sentinel token ids (t == tokens) produced by aiter moe_sorting padding. + # OOB buffer stores are not guaranteed to be safe on all paths, so predicate explicitly. + fused2 = buffer_ops.buffer_load(sorted_rsrc, row, vec_width=1, dtype=i32) + t2 = fused2 & mask24_i32 + t_valid = arith.cmpu(t2, tokens_i32, "ult") + _if_valid = scf.IfOp(t_valid) + with _if_valid.then(): + frag = vector.bitcast(vec4_f32, frag) + frag0 = vector.extract(frag, static_position=[0], dynamic_position=[]) + frag1 = vector.extract(frag, static_position=[1], dynamic_position=[]) + frag2 = vector.extract(frag, static_position=[2], dynamic_position=[]) + frag3 = vector.extract(frag, static_position=[3], dynamic_position=[]) + + out_fp8 = arith.i32(0) + out_fp8 = rocdl.cvt_pk_fp8_f32(src_a=arith._unwrap_value(frag0), src_b=arith._unwrap_value(frag1), old=arith._unwrap_value(out_fp8), word_sel=0, res=I.i32) + out_fp8 = rocdl.cvt_pk_fp8_f32(src_a=arith._unwrap_value(frag2), src_b=arith._unwrap_value(frag3), old=arith._unwrap_value(out_fp8), word_sel=1, res=I.i32) + + idx0 = row_ctx + col_i32 = arith.index_cast(i32, col_g0) + idx_out = idx0 + col_i32 + buffer_ops.buffer_store(out_fp8, out_rsrc, idx_out // 4) + + mfma_epilog( + use_cshuffle=True, + arith=arith, + vector=vector, + gpu=gpu, + range_constexpr=range_constexpr, + tile_m=tile_m, + tile_n=tile_n // 2, + e_vec=4, + m_repeat=m_repeat, + num_acc_n=num_acc_n_packed, + tx=tx, + lane_div_16=lane_div_16, + lane_mod_16=lane_mod_16, + bx_m=bx_m, + by_n=by_n // 2, + n_tile_base=n_tile_base // 2, + lds_out=lds_out, + frag_elem_type=I.f32, + write_row_to_lds=write_row_to_lds, + precompute_row=precompute_row, + store_pair=store_pair, + ) + return + + def _stage1_store_row(*, mi: int, ii: int, row_in_tile, row): + # `row` is the sorted-row index (bx_m + row_in_tile). + fused2 = buffer_ops.buffer_load(sorted_rsrc, row, vec_width=1, dtype=i32) + t2 = fused2 & mask24_i32 + s2 = fused2 >> 24 + + t_valid = arith.cmpu(t2, tokens_i32, "ult") + # No explicit mask: rely on buffer descriptor OOB to zero-fill when t2 is the + # sentinel (t2 == tokens) or otherwise out-of-range. + + # out linear index base = ((t*topk + s)*inter_dim) (invariant across ni) + idx0 = (t2 * topk_i32_v + s2) * inter_i32_local + + # Sorted weight aligned with `row` (matches aiter moe_sorting output). + if doweight_stage1: + tw = buffer_ops.buffer_load(sorted_w_rsrc, row, vec_width=1, dtype=f32) + + _if_valid = scf.IfOp(t_valid) + with _if_valid.then(): + for ni in range_constexpr(num_acc_n_packed): + col_i32 = col_i32_list[ni] + acc_idx = mi * num_acc_n_packed + ni + vg = vector.extract( + acc_gate[acc_idx], static_position=[ii], dynamic_position=[] + ) + vu = vector.extract( + acc_up[acc_idx], static_position=[ii], dynamic_position=[] + ) + if enable_bias: + gate_bias_list, up_bias_list = epilogue_pf + vg = vg + gate_bias_list[ni] + vu = vu + up_bias_list[ni] + + if act == "swiglu:": + y = swiglu(vg, vu) + else: + y = silu(vg) * vu + + if doweight_stage1: + y = y * tw + + y = arith.trunc_f(_out_elem_type(), y) + idx_out = idx0 + col_i32 + buffer_ops.buffer_store(y, out_rsrc, idx_out) + + mfma_epilog( + use_cshuffle=False, + arith=arith, + range_constexpr=range_constexpr, + m_repeat=m_repeat, + lane_div_16=lane_div_16, + bx_m=bx_m, + body_row=_stage1_store_row, + ) + + @flir.jit + def __call__( + self: flir.T.i64, + arg_out: lambda: T.memref(DYN, _out_elem_type()), + arg_x: lambda: T.memref(DYN, _x_elem_type()), + arg_w: lambda: T.memref(DYN, _w_elem_type()), + arg_scale_x: lambda: T.memref(DYN, _scale_elem_type()), + arg_scale_w: lambda: T.memref(experts * (2 * inter_dim), _scale_elem_type()), + arg_sorted_token_ids: lambda: T.memref(size_sorted, T.i32()), + arg_expert_ids: lambda: T.memref(DYN, T.i32()), + arg_sorted_weights: lambda: T.memref(DYN, T.f32()), + arg_max_token_ids: lambda: T.memref(DYN, T.i32()), + arg_bias: lambda: T.memref(DYN, T.f32()), + tokens_in: lambda: T.index(), + inter_in: lambda: T.index(), + k_in: lambda: T.index(), + size_expert_ids_in: lambda: T.index(), + ): + bdx = 256 + gx = 2 * inter_in / arith.index(tile_n) + # Use host-provided upper bound for M blocks (same as aiter moe_sorting allocation). + # This avoids device->host sync on num_valid_ids. + gy = size_expert_ids_in + flir.gpu_ext.LaunchFuncOp( + [module_name, "moe_gemm1"], + grid_size=(gx, gy, 1), + block_size=(bdx, 1, 1), + kernel_operands=[ + arg_out, + arg_x, + arg_w, + arg_scale_x, + arg_scale_w, + arg_sorted_token_ids, + arg_expert_ids, + arg_sorted_weights, + arg_max_token_ids, + arg_bias, + tokens_in, + inter_in, + k_in, + size_expert_ids_in, + ], + ) + + m = _MOE1() + exe = flydsl.compile(m) + return exe + + +def compile_moe_gemm1_dispatch( + *, + model_dim: int, + inter_dim: int, + experts: int, + topk: int, + tile_m: int, + tile_n: int, + tile_k: int, + doweight_stage2: bool, + x_dtype: str = "fp8", + w_dtype: str = "fp8", + out_dtype: str = "f16", + use_cshuffle_epilog: bool | None = None, + gate_up_interleave: bool = False, +): + # Compile based on mode + if not gate_up_interleave: + return compile_moe_gemm1( + model_dim=model_dim, + inter_dim=inter_dim, + experts=experts, + topk=topk, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + doweight_stage2=doweight_stage2, + x_dtype=x_dtype, + w_dtype=w_dtype, + out_dtype=out_dtype, + use_cshuffle_epilog=use_cshuffle_epilog, + ) + else: + return compile_gate_up_moe_gemm1( + model_dim=model_dim, + inter_dim=inter_dim, + experts=experts, + topk=topk, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + doweight_stage2=doweight_stage2, + x_dtype=x_dtype, + w_dtype=w_dtype, + out_dtype=out_dtype, + use_cshuffle_epilog=use_cshuffle_epilog, + ) + +####==================== gemm1 pipeline end =====================### + + +####==================== gemm2 pipeline start =====================### + @functools.lru_cache(maxsize=1024) def compile_moe_gemm2( *, @@ -1224,8 +2433,17 @@ def _mfma_output_pack_ty(): is_fp4 = mfma_pipeline in [MfmaPipeline.F4F4_MXFP4_PIPELINE, MfmaPipeline.F8F4_MXFP4_PIPELINE] + is_int_mode = is_int8 or is_int4 + no_epilogue_dequant = is_fp4 or is_f16_or_bf16 + # FP4 specific parameters for mfma_scale_f32_16x16x128_f8f6f4 + pack_M = 2 if is_fp4 else 1 + pack_N = 2 if is_fp4 else 1 + pack_K = 2 if is_fp4 else 1 + cbsz = 0 if mfma_pipeline == MfmaPipeline.F8F4_MXFP4_PIPELINE else 4 # fp8 a: cbsz=0, fp4 a: cbsz=4 + blgp = 4 + out_s = pipeline_manager.out_dtype if out_s not in ("f16", "fp16", "half", "bf16", "bfloat16", "f32", "fp32", "float"): raise ValueError(f"out_dtype must be 'f16', 'bf16', or 'f32', got {out_dtype!r}") @@ -1316,9 +2534,8 @@ def init_gpu_module(self): lds_x_bytes = 2 * int(tile_m) * int(lds_stride) * int(x_elem_bytes) lds_out_bytes = 2 * tile_m * tile_n if _use_cshuffle_epilog else 0 # f16 bytes lds_total_bytes = max(lds_x_bytes, lds_out_bytes) - lds_total_elems = lds_total_bytes if x_elem_bytes == 1 else (lds_total_bytes // 2) - x_lds_elem = I.f16 if is_f16_or_bf16 else (I.i8 if is_int8 else I.f8) - _state["lds_x_decl"] = allocator.allocate_array(x_lds_elem, lds_total_elems) + lds_total_elems = lds_total_bytes // x_elem_bytes + _state["lds_x_decl"] = allocator.allocate_array(_x_elem_type(), lds_total_elems) allocator.finalize() @flir.kernel @@ -1358,10 +2575,10 @@ def moe_gemm2( vec1_i64 = I.vec(1, i64) vec2_i64 = I.vec(2, i64) + placeholder = arith.constant(0, type=I.i32) + acc_init = ( - arith.constant_vector(0, vec4_i32) - if is_int8 - else arith.constant_vector(0.0, vec4_f32) + arith.constant_vector(0, _mfma_output_pack_ty()) ) # A2 layout (flatten token-slot -> M). @@ -1655,6 +2872,11 @@ def load_x_tile(base_k): m_repeat = tile_m // 16 k_unroll = tile_k_bytes // 64 # K64-byte micro-step (2x MFMA) + + # FP4 packed parameters for mfma_scale_f32_16x16x128_f8f6f4 + k_unroll_packed = k_unroll // pack_K + m_repeat_packed = m_repeat // pack_M + num_acc_n_packed = num_acc_n // pack_N # --- B Load Logic (K64) --- def load_b_pack(base_k, ki_step, ni): @@ -1696,7 +2918,66 @@ def load_b_tile(base_k): packs1.append(b1) b_tile.append((packs0, packs1)) return b_tile - + + # --- FP4 Scale Load Functions for gemm2 --- + def load_scale(arg_scale, rsrc, layout, ku, mni): + """Load a single scale value for FP4 mfma_scale instruction (gemm2).""" + k_lane = lane_div_16 + n_lane = lane_mod_16 + coord_pack = flir.make_coord(mni, ku, k_lane, n_lane) + idx_pack = flir.crd2idx(coord_pack, layout) + scale_view = flir.TensorView( + arg_scale, + (1,), + strides=(1,), + base_indices=(idx_pack,), + element_type=_scale_elem_type(), + ) + scale = flir.copy( + flir.make_copy_atom(_scale_elem_type(), vector_size=1), + scale_view, + None, + alignment=8, + return_vector=True, + src_buffer_resource=rsrc, + src_buffer_offset_in_bytes=False, + ) + return scale + + def load_b_scale_tile(base_k): + """Load B scale tile for FP4 pipeline (gemm2).""" + b_scale_tile = [] + for ku in range_constexpr(k_unroll_packed_g2): + for ni in range_constexpr(num_acc_n_packed_g2): + scale = load_scale( + arg_scale_w, + sw_rsrc, + layout_b_scale, + ku + base_k, + ni + (expert_off_idx + by_n + n_tile_base) // pack_N_gemm2 // 16, + ) + b_scale_tile.append(scale) + return b_scale_tile + + def load_a_scale_tile(base_k): + """Load A scale tile for FP4 pipeline (gemm2).""" + a_scale_tile = [] + for ku in range_constexpr(k_unroll_packed_g2): + for mi in range_constexpr(m_repeat_packed_g2): + scale = load_scale_fp4_g2( + arg_scale_x, + sx_rsrc, + layout_a_scale, + ku + base_k, + mi + bx_m // pack_M_gemm2 // 16, + ) + a_scale_tile.append(scale) + return a_scale_tile + + def prefetch_ab_scale_tile(base_k): + """Prefetch A and B scale tiles for FP4 pipeline (gemm2).""" + return [None, load_b_scale_tile(base_k)] + # ---- Pipeline helpers: store X tile to LDS with ping-pong base ---- def store_x_tile_to_lds(vec_x_in_parts, lds_base): for i in range_constexpr(num_x_loads): @@ -1772,15 +3053,16 @@ def lds_load_packs_k64(curr_row_a_lds, col_base_bytes, lds_base): a1 = vector.extract(a_i64x2, static_position=[1], dynamic_position=[]) return a0, a1 - def compute_tile(acc_in, b_tile_in, lds_base, *, prefetch_epilogue: bool = False, a0_prefetch=None): + def compute_tile(acc_in, b_tile_in, lds_base, *, a_scale=None, b_scale=None, prefetch_epilogue: bool = False, a0_prefetch=None): acc_list = list(acc_in) mfma_res_ty = _mfma_output_pack_ty() epilogue_pf = None if prefetch_epilogue: expert_off_pf = expert_off_idx - sw_pf = [] + sw_pf = None if not no_epilogue_dequant: + sw_pf = [] for ni in range_constexpr(num_acc_n): col_g = col_g_list[ni] row_w_idx = expert_off_pf + col_g @@ -1803,6 +3085,70 @@ def compute_tile(acc_in, b_tile_in, lds_base, *, prefetch_epilogue: bool = False ) ) epilogue_pf = (sw_pf, tw_pf) + + if is_fp4: + c0_i64 = arith.constant(0, type=I.i64) + vec4_i64 = I.vec(4, I.i64) + vec8_i32 = I.vec(8, I.i32) + + def pack_i64x4_to_i32x8(x0, x1, x2, x3): + v4 = vector.from_elements(vec4_i64, [x0, x1, x2, x3]) + return vector.bitcast(vec8_i32, v4) + + # FP4 path using mfma_scale_f32_16x16x128_f8f6f4 + for ku128 in range_constexpr(k_unroll_packed_g2): + for mi in range_constexpr(m_repeat_packed_g2): + for ni in range_constexpr(num_acc_n_packed_g2): + b_scale_i32 = b_scale[ku128 * num_acc_n_packed_g2 + ni] + b_scale_val = vector.extract(b_scale_i32, static_position=[0], dynamic_position=[]) + for ikxdl in range_constexpr(pack_K_gemm2): + k_idx = ku128 * pack_K_gemm2 + ikxdl + b_packs0, b_packs1 = b_tile_in[k_idx] + col_base = col_offset_base + (k_idx * 128) // a_elem_vec_pack_gemm2 + + for imxdl in range_constexpr(pack_M_gemm2): + col_base0 = col_base + mi_idx = mi * pack_M_gemm2 + imxdl + mi_val = arith.constant(mi_idx * 16, index=True) + curr_row_a_lds = row_a_lds + mi_val + + if (a0_prefetch is not None) and (k_idx == 0) and (mi_idx == 0): + a0, a1 = a0_prefetch + else: + a0, a1 = lds_load_packs_k64(curr_row_a_lds, col_base0, lds_base) + + if is_f8_a_gemm2: + col_base1 = col_base + 64 + a2, a3 = lds_load_packs_k64(curr_row_a_lds, col_base1, lds_base) + a128 = pack_i64x4_to_i32x8(a0, a1, a2, a3) + else: + a128 = pack_i64x4_to_i32x8(a0, a1, c0_i64, c0_i64) + + for inxdl in range_constexpr(pack_N_gemm2): + ni_idx = ni * pack_N_gemm2 + inxdl + + b0 = b_packs0[ni_idx] + b1 = b_packs1[ni_idx] + b128 = pack_i64x4_to_i32x8(b0, b1, c0_i64, c0_i64) + + acc_idx = mi_idx * num_acc_n + ni_idx + rocdl.sched_barrier(0) + acc_list[acc_idx] = rocdl.mfma_scale_f32_16x16x128_f8f6f4( + mfma_res_ty, + [ + a128, + b128, + acc_list[acc_idx], + cbsz_gemm2, + blgp_gemm2, + # use per tensor quant a1 for now, + 0, + 0x3F800000, + ikxdl * pack_N_gemm2 + inxdl, + b_scale_val, + ], + ) + return acc_list, epilogue_pf def _i64_to_v4f16(x_i64): v1 = vector.from_elements(vec1_i64, [x_i64]) @@ -1843,7 +3189,7 @@ def mfma_k64(acc0, a0, a1, b0, b1): b_packs1[ni], ) return acc_list, epilogue_pf - + # ---------------- 2-stage pipeline (ping-pong LDS + B tile prefetch) ---------------- lds_tile_elems = arith.constant(tile_m * lds_stride, index=True) lds_base_cur = arith.index(0) @@ -1929,21 +3275,31 @@ def hot_loop_scheduler(): rocdl.sched_dswr(1) rocdl.sched_barrier(0) + + # ================ Unified Pipeline (FP4 / Standard) ================ + # Select pipeline functions based on is_fp4 + def _prefetch_scale(k_val): + if is_fp4: + return prefetch_ab_scale_tile_fp4_g2(k_val // pack_K // 128) + return placeholder, placeholder + # Prologue. k0 = arith.index(0) x_regs0 = load_x_tile(k0) b_cur = load_b_tile(k0) + a_scale_ping, a_scale_pong = placeholder, placeholder + _, b_scale_pong = _prefetch_scale(k0) store_x_tile_to_lds(x_regs0, lds_base_cur) gpu.barrier() - + acc = [acc_init] * (num_acc_n * m_repeat) lds_base_pong = lds_base_cur lds_base_ping = lds_base_nxt - + # Cross-tile A0 LDS prefetch (default-on): prefetch the first A-pack (K64) for the # tile we are about to compute from LDS, to overlap with upcoming VMEM. a0_prefetch_pong = lds_load_packs_k64(row_a_lds, col_offset_base_bytes, lds_base_pong) - + # Main loop: process K tiles in 2-tile ping-pong steps. # # IMPORTANT: for odd number of K tiles, leave **1** tail tile; for even, leave **2**. @@ -1955,63 +3311,65 @@ def hot_loop_scheduler(): k_main2_py = (num_k_tiles_py - tail_tiles) * int(tile_k) if k_main2_py < 0: k_main2_py = 0 - + c2_tile_k = arith.constant(tile_k * 2, index=True) c_k_main2 = arith.index(k_main2_py) for k_iv in range(arith.index(0), c_k_main2, arith.index(tile_k * 2)): next_k1 = k_iv + tile_k x_regs_ping = load_x_tile(next_k1) - b_ping = load_b_tile(next_k1) - - acc, _ = compute_tile(acc, b_cur, lds_base_pong, a0_prefetch=a0_prefetch_pong) - a0_prefetch_pong = None + b_ping = load_b_tile(next_k1 // pack_K) + _, b_scale_ping = _prefetch_scale(next_k1) + + acc, _ = compute_tile(acc, b_cur, lds_base_pong, a_scale=a_scale_pong, b_scale=b_scale_pong, a0_prefetch=a0_prefetch_pong) store_x_tile_to_lds(x_regs_ping, lds_base_ping) hot_loop_scheduler() gpu.barrier() - + # Cross-tile prefetch for the ping tile we are about to compute. a0_prefetch_ping = lds_load_packs_k64(row_a_lds, col_offset_base_bytes, lds_base_ping) - + next_k2 = k_iv + c2_tile_k x_regs_pong = load_x_tile(next_k2) - b_next = load_b_tile(next_k2) - - acc, _ = compute_tile(acc, b_ping, lds_base_ping, a0_prefetch=a0_prefetch_ping) - a0_prefetch_ping = None + b_next = load_b_tile(next_k2 // pack_K) + _, b_scale_pong = _prefetch_scale(next_k2) + + acc, _ = compute_tile(acc, b_ping, lds_base_ping, a_scale=a_scale_ping, b_scale=b_scale_ping, a0_prefetch=a0_prefetch_ping) store_x_tile_to_lds(x_regs_pong, lds_base_pong) hot_loop_scheduler() gpu.barrier() - + # Cross-tile prefetch for the next pong tile. a0_prefetch_pong = lds_load_packs_k64(row_a_lds, col_offset_base_bytes, lds_base_pong) - + b_cur = b_next - + if odd_k_tiles: # Tail: single remaining tile (already in `b_cur` / `lds_base_pong`). acc, epilogue_pf = compute_tile( acc, b_cur, lds_base_pong, + a_scale=a_scale_pong, + b_scale=b_scale_pong, prefetch_epilogue=True, a0_prefetch=a0_prefetch_pong, ) else: # Tail: 2 remaining tiles. - k_tail1 = k_in - tile_k + k_tail1 = (k_in + tile_k - 1) // tile_k * tile_k - tile_k if is_fp4 else k_in - tile_k x_regs_ping = load_x_tile(k_tail1) - b_ping = load_b_tile(k_tail1) - - acc, _ = compute_tile(acc, b_cur, lds_base_pong, a0_prefetch=a0_prefetch_pong) - a0_prefetch_pong = None + b_ping = load_b_tile(k_tail1 // pack_K) + _, b_scale_ping = _prefetch_scale(k_tail1) + + acc, _ = compute_tile(acc, b_cur, lds_base_pong, a_scale=a_scale_pong, b_scale=b_scale_pong, a0_prefetch=a0_prefetch_pong) store_x_tile_to_lds(x_regs_ping, lds_base_ping) hot_loop_scheduler() gpu.barrier() - + # Epilogue tile with sw prefetch. a0_prefetch_ping = lds_load_packs_k64(row_a_lds, col_offset_base_bytes, lds_base_ping) acc, epilogue_pf = compute_tile( - acc, b_ping, lds_base_ping, prefetch_epilogue=True, a0_prefetch=a0_prefetch_ping + acc, b_ping, lds_base_ping, a_scale=a_scale_ping, b_scale=b_scale_ping, prefetch_epilogue=True, a0_prefetch=a0_prefetch_ping ) # ---------------- Epilogue: LDS CShuffle + atomic half2 (x2) ---------------- @@ -2092,7 +3450,7 @@ def _stage2_row_atomic(*, mi: int, ii: int, row_in_tile, row): col_g = col_g_list[ni] acc_idx = mi * num_acc_n + ni v = vector.extract(acc[acc_idx], static_position=[ii], dynamic_position=[]) - if is_int8 or is_int4: + if is_int_mode: v = arith.sitofp(f32, v) if not no_epilogue_dequant: sw = sw_vals[ni] @@ -2157,7 +3515,7 @@ def write_row_to_lds( col_local = col_base_local + (ni * 16) acc_idx = mi * num_acc_n + ni v = vector.extract(acc[acc_idx], static_position=[ii], dynamic_position=[]) - if is_int8: + if is_int_mode: v = arith.sitofp(f32, v) if not no_epilogue_dequant: sw = sw_vals[ni] @@ -2709,3 +4067,5 @@ def compile_moe_gemm2_ex( use_cshuffle_epilog=use_cshuffle_epilog, accumulate=True, ) + +####==================== gemm2 pipeline end =====================### \ No newline at end of file diff --git a/tests/kernels/test_moe_gemm.py b/tests/kernels/test_moe_gemm.py index b0bafe12..555e5be9 100644 --- a/tests/kernels/test_moe_gemm.py +++ b/tests/kernels/test_moe_gemm.py @@ -482,7 +482,9 @@ def launch(o, x, w, sx, sw, st, eids, sw_sorted): rtol = 0.5 if is_int4 else 0.25 atol = 0.5 if is_int4 else 0.25 - assert verify_output(out.to(torch.float32), ref, rtol=rtol, atol=atol) + print(out.to(torch.float32)) + print(ref) + assert verify_output(out.to(torch.float32), ref, rtol=1e-4, atol=1e-4) # Note: kernel launches full expert-block range; effective work is gated by num_valid_ids. flops = 2 * tokens * topk * (2 * inter_dim) * model_dim @@ -668,22 +670,22 @@ def run_moe_stage2( x_fp32 = ( x_fp32_in if x_fp32_in is not None - else torch.rand((tokens, model_dim), device=device, dtype=torch.float32) * s + else torch.randn((tokens, model_dim), device=device, dtype=torch.float32) * s ) w1_fp32 = ( w1_fp32_in if w1_fp32_in is not None - else torch.rand((experts, 2 * inter_dim, model_dim), device=device, dtype=torch.float32) * (s / math.sqrt(model_dim)) + else torch.randn((experts, 2 * inter_dim, model_dim), device=device, dtype=torch.float32) * (s / math.sqrt(model_dim)) ) w2_fp32 = ( w2_fp32_in if w2_fp32_in is not None - else torch.rand((experts, model_dim, inter_dim), device=device, dtype=torch.float32) * (s / math.sqrt(inter_dim)) + else torch.randn((experts, model_dim, inter_dim), device=device, dtype=torch.float32) * (s / math.sqrt(inter_dim)) ) # Routing: deterministic torch topk + softmax. if topk_ids_in is None or topk_weights_in is None: - score = torch.rand((tokens, experts), device=device, dtype=torch.float32) + score = torch.randn((tokens, experts), device=device, dtype=torch.float32) topk_vals, topk_ids = torch.topk(score, k=topk, dim=1) topk_weights = torch.softmax(topk_vals, dim=1).to(torch.float32) else: @@ -885,7 +887,7 @@ def launch(o, x, w, sx, sw, st, eids, sw_sorted): model_dim=model_dim, doweight_stage2=doweight_stage2, ) - assert verify_output(out.to(torch.float32), ref2, rtol=0.5, atol=0.5) + assert verify_output(out.to(torch.float32), ref2, rtol=0.000005, atol=0.000005) # Launches full expert-block range; effective work is gated by num_valid_ids. flops = 2 * tokens * topk * model_dim * inter_dim @@ -1045,13 +1047,13 @@ def test_moe_gemm_2stage( if init_scale == 1.0: init_scale = 0.2 s = float(init_scale) - x_fp32 = torch.rand((tokens, model_dim), device=device, dtype=torch.float32) * s + x_fp32 = torch.randn((tokens, model_dim), device=device, dtype=torch.float32) * s # fan_in = model_dim for W1: [E, 2*inter, model] - w1_fp32 = torch.rand((experts, 2 * inter_dim, model_dim), device=device, dtype=torch.float32) * (s / math.sqrt(model_dim)) + w1_fp32 = torch.randn((experts, 2 * inter_dim, model_dim), device=device, dtype=torch.float32) * (s / math.sqrt(model_dim)) # fan_in = inter_dim for W2: [E, model, inter] - w2_fp32 = torch.rand((experts, model_dim, inter_dim), device=device, dtype=torch.float32) * (s / math.sqrt(inter_dim)) + w2_fp32 = torch.randn((experts, model_dim, inter_dim), device=device, dtype=torch.float32) * (s / math.sqrt(inter_dim)) - score = torch.rand((tokens, experts), device=device, dtype=torch.float32) + score = torch.randn((tokens, experts), device=device, dtype=torch.float32) topk_vals, topk_ids = torch.topk(score, k=topk, dim=1) topk_weights = torch.softmax(topk_vals, dim=1).to(torch.float32) From 037c6d198e2750b3430eed7feda54a149f074eab Mon Sep 17 00:00:00 2001 From: zanzhang Date: Thu, 5 Feb 2026 16:16:59 +0800 Subject: [PATCH 06/11] fused update --- kernels/moe_gemm_2stage.py | 1601 +++++++------------------------- tests/kernels/test_moe_gemm.py | 114 ++- 2 files changed, 425 insertions(+), 1290 deletions(-) diff --git a/kernels/moe_gemm_2stage.py b/kernels/moe_gemm_2stage.py index 103b59d9..573cc116 100644 --- a/kernels/moe_gemm_2stage.py +++ b/kernels/moe_gemm_2stage.py @@ -60,6 +60,8 @@ def compile_moe_gemm1( w_dtype: str = "fp8", out_dtype: str = "f16", use_cshuffle_epilog: bool | None = None, + act: str = "silu", + enable_bias: bool = False, ): """Compile stage1 kernel (`moe_gemm1`) and return the compiled executable. @@ -133,6 +135,16 @@ def _mfma_output_pack_ty(): pack_N = 2 if is_fp4 else 1 pack_K = 2 if is_fp4 else 1 + # FP4 specific parameters for mfma_scale_f32_16x16x128_f8f6f4 (gemm1) + cbsz_g1 = 0 if mfma_pipeline == MfmaPipeline.F4F4_MXFP4_PIPELINE else 4 # fp8 a: cbsz=0, fp4 a: cbsz=4 + blgp_g1 = 4 + + is_gate_up_inter = is_fp4 + + is_cast_out = out_dtype == "f8" + if is_cast_out and not use_cshuffle_epilog: + raise ValueError("out_dtype='f' requires CShuffle epilogue (set use_cshuffle_epilog=True).") + DYN = ir.ShapedType.get_dynamic_size() size_out = DYN size_x = DYN @@ -223,6 +235,7 @@ def moe_gemm1( i64 = I.i64 vec4_f32 = I.vec(4, f32) vec4_i32 = I.vec(4, i32) + vec1_f32 = I.vec(1, f32) vec1_f16 = I.vec(1, f16) vec4_f16 = I.vec(4, f16) vec16_x_elems = 16 // x_elem_bytes # if x_elem_bytes == 1 else 8 @@ -233,6 +246,8 @@ def moe_gemm1( vec1_i64 = I.vec(1, i64) vec2_i64 = I.vec(2, i64) + placeholder = arith.constant(0, index=True) + def silu(x): # device fast path: # emu = exp(-x) ~= exp2(log2e * (-x)) -> v_exp_f32 @@ -319,6 +334,8 @@ def swiglu(gate, up, alpha=1.702, limit=7.0): layout_tx_wave_lane = flir.make_layout((4, 64), stride=(64, 1)) layout_lane16 = flir.make_layout((4, 16), stride=(16, 1)) + # bias_rsrc = buffer_ops.create_buffer_resource(arg_bias, max_size=False) if enable_bias else None + # Everything below is gated by `blk_valid` to avoid doing buffer-resource setup and # gmem work for padding blocks. _if_blk = scf.IfOp(blk_valid) @@ -557,6 +574,11 @@ def load_x_tile(base_k): m_repeat = tile_m // 16 k_unroll = tile_k_bytes // 64 // pack_K + + # FP4 packed parameters for mfma_scale (gemm1) + k_unroll_packed_g1 = k_unroll // pack_K + m_repeat_packed_g1 = m_repeat // pack_M + num_acc_n_packed_g1 = num_acc_n // pack_N # --- B Load Logic (K64) - shared layout with preshuffle GEMM --- def load_b_pack(base_k, ki_step, ni, blk_list, intra_list): @@ -598,9 +620,100 @@ def load_b_tile(base_k, blk_list, intra_list): packs1.append(b1) b_tile.append((packs0, packs1)) return b_tile + + # --- FP4 Scale Load Functions for gemm1 --- + def load_scale_inter(arg_scale, rsrc, layout, ku, mni): + """Load a single scale value for FP4 mfma_scale instruction (gemm1).""" + k_lane = lane_div_16 + n_lane = lane_mod_16 + coord_pack = flir.make_coord(mni, ku, k_lane, n_lane) + idx_pack = flir.crd2idx(coord_pack, layout) + scale_view = flir.TensorView( + arg_scale_w, + (1,), + strides=(1,), + base_indices=(idx_pack,), + element_type=_scale_elem_type(), + ) + scale = flir.copy( + flir.make_copy_atom(_scale_elem_type(), vector_size=1), + scale_view, + None, + alignment=8, + return_vector=True, + src_buffer_resource=rsrc, + src_buffer_offset_in_bytes=False, + ) + return scale + + def load_b_scale_tile_inter(base_k): + """Load B scale tile for FP4 pipeline (gemm1).""" + b_scale_tile = [] + for ku in range_constexpr(k_unroll_packed_g1): + for ni in range_constexpr(num_acc_n_packed_g1): + scale = load_scale_fp4_g1( + arg_scale_w, + sw_rsrc, + layout_b_scale, + ku + base_k, + ni + (expert_off_idx + by_n + n_tile_base) // pack_N // 16, + ) + b_scale_tile.append(scale) + return b_scale_tile + + def prefetch_ab_scale_tile_inter(base_k): + """Prefetch A and B scale tiles for FP4 pipeline (gemm1).""" + return placeholder, load_b_scale_tile_fp4_g1(base_k) + + # --- FP4 B Load Logic for gemm1 (interleaved gate+up) --- + def load_b_packs_k64_inter(base_k, ku: int, ni: int): + """Load B pack for FP4 pipeline with interleaved gate+up (gemm1).""" + base_k_bytes = base_k * arith.constant(int(w_elem_bytes), index=True) + k0_base = base_k_bytes / 64 + k0 = k0_base + ku + k1 = lane_div_16 + # For interleaved gate+up, use col_g_list which contains both gate and up columns + coord_pack = flir.make_coord(n_blk_gate[ni // 2] if ni % 2 == 0 else n_blk_up[ni // 2], k0, k1, n_intra_gate[ni // 2] if ni % 2 == 0 else n_intra_up[ni // 2], 0) + idx_pack = flir.crd2idx(coord_pack, layout_b) + + vec_elems = 16 + b_view = flir.TensorView( + arg_w, + (vec_elems,), + strides=(1,), + base_indices=(idx_pack,), + element_type=_w_elem_type(), + ) + b16 = flir.copy( + flir.make_copy_atom(_w_elem_type(), vector_size=vec_elems), + b_view, + None, + alignment=8, + return_vector=True, + src_buffer_resource=(w_rsrc if w_elem_bytes == 1 else None), + src_buffer_offset_in_bytes=(w_elem_bytes == 1), + ) + # Split 16B pack into two 8B halves. + b_i64x2 = vector.bitcast(I.i64x2, b16) + b0_i64 = vector.extract(b_i64x2, static_position=[0], dynamic_position=[]) + b1_i64 = vector.extract(b_i64x2, static_position=[1], dynamic_position=[]) + return b0_i64, b1_i64 + + def load_b_tile_inter(base_k): + """Load B tile for FP4 pipeline with interleaved gate+up (gemm1).""" + b_tile = [] + for ku in range_constexpr(k_unroll): + packs0 = [] + packs1 = [] + for ni in range_constexpr(num_acc_n * 2): # gate+up interleaved + b0, b1 = load_b_packs_k64_inter(base_k, ku, ni) + packs0.append(b0) + packs1.append(b1) + b_tile.append((packs0, packs1)) + return b_tile - acc_gate = [acc_init] * (num_acc_n * m_repeat) - acc_up = [acc_init] * (num_acc_n * m_repeat) + acc_gate = [acc_init] * (num_acc_n * m_repeat) if not is_gate_up_inter else [acc_init] * (num_acc_n // 2 * m_repeat) + acc_up = [acc_init] * (num_acc_n * m_repeat) if not is_gate_up_inter else [acc_init] * (num_acc_n // 2 * m_repeat) # ---- Pipeline helpers: store X tile to LDS with ping-pong base ---- def store_x_tile_to_lds(vec_x_in_parts, lds_base): @@ -684,7 +797,6 @@ def compute_tile( b_up_tile_in, lds_base, *, - prefetch_epilogue: bool = False, a0_prefetch=None, ): @@ -758,6 +870,111 @@ def mfma_k64(acc_in, a0, a1, b0, b1): b_up_packs1[ni], ) return gate_list, up_list, epilogue_pf + + # --- FP4 Compute Tile using mfma_scale_f32_16x16x128_f8f6f4 (gemm1) --- + def compute_tile_inter( + acc_gate_in, + acc_up_in, + b_tile_in, + lds_base, + *, + a_scale=None, + b_scale=None, + prefetch_epilogue: bool = False, + a0_prefetch=None, + ): + """FP4 compute tile using mfma_scale for interleaved gate+up (gemm1).""" + gate_list = list(acc_gate_in) + up_list = list(acc_up_in) + mfma_res_ty = _mfma_output_pack_ty() + + epilogue_pf = None + # if enable_bias and prefetch_epilogue: + # expert_off_idx + by_n + n_tile_base + # gate_bias = [] + # up_bias = [] + # for ni in range_constexpr(num_acc_n_packed): + # global_n = (by_n + n_tile_base) // 2 + ni * 16 + lane_mod_16 + # gate_offset = expert_off_idx + global_n + # up_offset = expert_off_idx + global_n + inter_dim + # gate_bias.append( + # buffer_ops.buffer_load(bias_rsrc, gate_offset, vec_width=1, dtype=f32) + # ) + # up_bias.append( + # buffer_ops.buffer_load(bias_rsrc, up_offset, vec_width=1, dtype=f32) + # ) + # epilogue_pf = (gate_bias, up_bias) + + c0_i64 = arith.constant(0, type=I.i64) + vec4_i64 = I.vec(4, I.i64) + vec8_i32 = I.vec(8, I.i32) + + def pack_i64x4_to_i32x8(x0, x1, x2, x3): + v4 = vector.from_elements(vec4_i64, [x0, x1, x2, x3]) + return vector.bitcast(vec8_i32, v4) + + col_offset_base_fp4 = flir.crd2idx(flir.make_coord(lane_div_16, 0), layout_lane16) + + # FP4 path using mfma_scale_f32_16x16x128_f8f6f4 + for ku128 in range_constexpr(k_unroll_packed): + for mi in range_constexpr(m_repeat_packed): + # a_scale_i32 = a_scale[ku128 * m_repeat_packed + mi] + # a_scale_val = vector.extract(a_scale_i32, static_position=[0], dynamic_position=[]) + for ni in range_constexpr(num_acc_n_packed): + b_scale_i32 = b_scale[ku128 * num_acc_n_packed + ni] if b_scale else None + b_scale_val = vector.extract(b_scale_i32, static_position=[0], dynamic_position=[]) if b_scale_i32 else arith.i32(0x3F800000) + for ikxdl in range_constexpr(pack_K): + k_idx = ku128 * pack_K + ikxdl + b_packs0, b_packs1 = b_tile_in[k_idx] + col_base = col_offset_base_fp4 + (k_idx * 128) // a_elem_vec_pack_g1 + + for imxdl in range_constexpr(pack_M): + mi_idx = mi * pack_M + imxdl + mi_val = arith.constant(mi_idx * 16, index=True) + curr_row_a_lds = row_a_lds + mi_val + + if (a0_prefetch is not None) and (k_idx == 0) and (mi_idx == 0): + a0, a1 = a0_prefetch + else: + a0, a1 = lds_load_packs_k64(curr_row_a_lds, col_base, lds_base) + + if is_f8_a: + col_base1 = col_base + 64 + a2, a3 = lds_load_packs_k64(curr_row_a_lds, col_base1, lds_base) + a128 = pack_i64x4_to_i32x8(a0, a1, a2, a3) + else: + a128 = pack_i64x4_to_i32x8(a0, a1, c0_i64, c0_i64) + + for inxdl in range_constexpr(pack_N): + # Interleaved gate+up: even indices are gate, odd are up + if inxdl % 2 == 0: + current_list = gate_list + else: + current_list = up_list + ni_idx = ni * pack_N + inxdl + + b0 = b_packs0[ni_idx] + b1 = b_packs1[ni_idx] + b128 = pack_i64x4_to_i32x8(b0, b1, c0_i64, c0_i64) + + acc_idx = mi_idx * num_acc_n_packed + (ni_idx // 2) + rocdl.sched_barrier(0) + current_list[acc_idx] = rocdl.mfma_scale_f32_16x16x128_f8f6f4( + mfma_res_ty, + [ + a128, + b128, + current_list[acc_idx], + cbsz_g1, + blgp_g1, + 0, + 0x3F800000, + ikxdl * pack_N + inxdl, + b_scale_val, + ], + ) + + return gate_list, up_list, epilogue_pf # ---------------- 2-stage pipeline (ping-pong LDS + B tile prefetch) ---------------- lds_tile_elems = arith.constant(tile_m * lds_stride, index=True) @@ -795,11 +1012,31 @@ def hot_loop_scheduler(): rocdl.sched_dswr(1) rocdl.sched_barrier(0) + # ================ Unified Pipeline (FP4 / Standard) for gemm1 ================ + # Select pipeline functions based on is_fp4 + def _load_b_tiles(k_val): + if is_gate_up_inter: + return load_b_tile_inter(k_val // 2) + else: + return (load_b_tile(k_val, n_blk_gate, n_intra_gate), load_b_tile(k_val, n_blk_up, n_intra_up)) + + def _compute_tile(acc_g, acc_u, b_cur, lds_base, a_scale, b_scale, *, prefetch_epilogue=False, a0_prefetch=None): + if is_gate_up_inter: + return compute_tile_inter(acc_g, acc_u, b_cur, lds_base, a_scale=a_scale, b_scale=b_scale, prefetch_epilogue=prefetch_epilogue, a0_prefetch=a0_prefetch) + else: + b_g, b_u = b_cur + return compute_tile(acc_g, acc_u, b_g, b_u, lds_base, prefetch_epilogue=prefetch_epilogue, a0_prefetch=a0_prefetch) + + def _prefetch_scale(k_val): + if is_gate_up_inter: + return prefetch_ab_scale_tile_inter(k_val // pack_K // 128) + return placeholder, placeholder + # Prologue: prefetch tile0, store to LDS(cur), sync. k0 = arith.index(0) x_regs0 = load_x_tile(k0) - b_gate_cur = load_b_tile(k0, n_blk_gate, n_intra_gate) - b_up_cur = load_b_tile(k0, n_blk_up, n_intra_up) + b_cur = _load_b_tiles(k0) + a_scale_pong, b_scale_pong = _prefetch_scale(k0) store_x_tile_to_lds(x_regs0, lds_base_cur) gpu.barrier() @@ -819,18 +1056,19 @@ def hot_loop_scheduler(): # ---- stage 0: prefetch+store ping, compute pong ---- next_k1 = k_iv + tile_k x_regs_ping = load_x_tile(next_k1) - b_gate_ping = load_b_tile(next_k1, n_blk_gate, n_intra_gate) - b_up_ping = load_b_tile(next_k1, n_blk_up, n_intra_up) + b_ping = _load_b_tiles(next_k1) + a_scale_ping, b_scale_ping = _prefetch_scale(next_k1) - acc_gate, acc_up, _ = compute_tile( + acc_gate, acc_up, _ = _compute_tile( acc_gate, acc_up, - b_gate_cur, - b_up_cur, + b_cur, lds_base_pong, + a_scale_pong, + b_scale_pong, a0_prefetch=a0_prefetch_pong, ) - a0_prefetch_pong = None + store_x_tile_to_lds(x_regs_ping, lds_base_ping) hot_loop_scheduler() gpu.barrier() @@ -841,18 +1079,18 @@ def hot_loop_scheduler(): # ---- stage 1: prefetch+store pong, compute ping ---- next_k2 = k_iv + c2_tile_k x_regs_pong = load_x_tile(next_k2) - b_gate_next = load_b_tile(next_k2, n_blk_gate, n_intra_gate) - b_up_next = load_b_tile(next_k2, n_blk_up, n_intra_up) + b_next = _load_b_tiles(next_k2) + a_scale_pong, b_scale_pong = _prefetch_scale(next_k2) - acc_gate, acc_up, _ = compute_tile( + acc_gate, acc_up, _ = _compute_tile( acc_gate, acc_up, - b_gate_ping, - b_up_ping, + b_ping, lds_base_ping, + a_scale_ping, + b_scale_ping, a0_prefetch=a0_prefetch_ping, ) - a0_prefetch_ping = None store_x_tile_to_lds(x_regs_pong, lds_base_pong) hot_loop_scheduler() gpu.barrier() @@ -861,24 +1099,23 @@ def hot_loop_scheduler(): a0_prefetch_pong = lds_load_packs_k64(row_a_lds, col_offset_base_bytes, lds_base_pong) # Advance pong state to next_k2 for next iteration. - b_gate_cur = b_gate_next - b_up_cur = b_up_next + b_cur = b_next # Tail: 2 remaining tiles at (k_in - 2*tile_k) and (k_in - tile_k). k_tail1 = k_in - tile_k x_regs_ping = load_x_tile(k_tail1) - b_gate_ping = load_b_tile(k_tail1, n_blk_gate, n_intra_gate) - b_up_ping = load_b_tile(k_tail1, n_blk_up, n_intra_up) + b_ping = _load_b_tiles(k_tail1) + a_scale_ping, b_scale_ping = _prefetch_scale(k_tail1) - acc_gate, acc_up, _ = compute_tile( + acc_gate, acc_up, _ = _compute_tile( acc_gate, acc_up, - b_gate_cur, - b_up_cur, + b_cur, lds_base_pong, + a_scale_pong, + b_scale_pong, a0_prefetch=a0_prefetch_pong, ) - a0_prefetch_pong = None store_x_tile_to_lds(x_regs_ping, lds_base_ping) hot_loop_scheduler() gpu.barrier() @@ -887,12 +1124,13 @@ def hot_loop_scheduler(): a0_prefetch_ping = lds_load_packs_k64(row_a_lds, col_offset_base_bytes, lds_base_ping) # Epilogue: compute last tile with epilogue scale prefetch to overlap loads with MFMA. - acc_gate, acc_up, epilogue_pf = compute_tile( + acc_gate, acc_up, epilogue_pf = _compute_tile( acc_gate, acc_up, - b_gate_ping, - b_up_ping, + b_ping, lds_base_ping, + a_scale_ping, + b_scale_ping, prefetch_epilogue=True, a0_prefetch=a0_prefetch_ping, ) @@ -958,7 +1196,7 @@ def write_row_to_lds( if doweight_stage1: tw = buffer_ops.buffer_load(sorted_w_rsrc, row, vec_width=1, dtype=f32) - for ni in range_constexpr(num_acc_n): + for ni in range_constexpr(num_acc_n // 2 if is_gate_up_inter else num_acc_n): col_local = col_base_local + (ni * 16) acc_idx = mi * num_acc_n + ni @@ -978,15 +1216,31 @@ def write_row_to_lds( sw_up = sw_up_vals[ni] vg = vg * sx * sw_gate vu = vu * sx * sw_up + + # if enable_bias: + # gate_bias_list, up_bias_list = epilogue_pf + # vg = vg + gate_bias_list[ni] + # vu = vu + up_bias_list[ni] - y = silu(vg) * vu + if act == "swiglu": + y = swiglu(vg, vu) + else: + y = silu(vg) * vu + if doweight_stage1: y = y * tw - y16 = arith.trunc_f(T.f16(), y) - + + if not is_cast_out: + y = arith.trunc_f(T.f16(), y) + vec1_out_elem = vec1_f16 + alignment = 2 + else: + vec1_out_elem = vec1_f32 + alignment = 1 + lds_idx = row_base_lds + col_local - v1 = vector.from_elements(vec1_f16, [y16]) - vector.store(v1, lds_out, [lds_idx], alignment=2) + v1 = vector.from_elements(vec1_out_elem, [y]) + vector.store(v1, lds_out, [lds_idx], alignment=alignment) def precompute_row(*, row_local, row): fused2 = buffer_ops.buffer_load(sorted_rsrc, row, vec_width=1, dtype=i32) @@ -995,11 +1249,32 @@ def precompute_row(*, row_local, row): return (t2 * topk_i32_v + s2) * inter_i32_local def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): - idx0 = row_ctx - col_i32 = arith.index_cast(i32, col_g0) - idx_out = idx0 + col_i32 - # Vectorized fp16 store (EVec=4). - buffer_ops.buffer_store(frag, out_rsrc, idx_out) + if not is_cast_out: + idx0 = row_ctx + col_i32 = arith.index_cast(i32, col_g0) + idx_out = idx0 + col_i32 + # Vectorized fp16 store (EVec=4). + buffer_ops.buffer_store(frag, out_rsrc, idx_out) + else: + fused2 = buffer_ops.buffer_load(sorted_rsrc, row, vec_width=1, dtype=i32) + t2 = fused2 & mask24_i32 + t_valid = arith.cmpu(t2, tokens_i32, "ult") + _if_valid = scf.IfOp(t_valid) + with _if_valid.then(): + frag = vector.bitcast(vec4_f32, frag) + frag0 = vector.extract(frag, static_position=[0], dynamic_position=[]) + frag1 = vector.extract(frag, static_position=[1], dynamic_position=[]) + frag2 = vector.extract(frag, static_position=[2], dynamic_position=[]) + frag3 = vector.extract(frag, static_position=[3], dynamic_position=[]) + + out_fp8 = arith.i32(0) + out_fp8 = rocdl.cvt_pk_fp8_f32(src_a=arith._unwrap_value(frag0), src_b=arith._unwrap_value(frag1), old=arith._unwrap_value(out_fp8), word_sel=0, res=I.i32) + out_fp8 = rocdl.cvt_pk_fp8_f32(src_a=arith._unwrap_value(frag2), src_b=arith._unwrap_value(frag3), old=arith._unwrap_value(out_fp8), word_sel=1, res=I.i32) + + idx0 = row_ctx + col_i32 = arith.index_cast(i32, col_g0) + idx_out = idx0 + col_i32 + buffer_ops.buffer_store(out_fp8, out_rsrc, idx_out // 4) mfma_epilog( use_cshuffle=True, @@ -1009,7 +1284,7 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): scf=scf, range_constexpr=range_constexpr, tile_m=tile_m, - tile_n=tile_n, + tile_n=tile_n // 2 if is_gate_up_inter else tile_n, e_vec=4, m_repeat=m_repeat, num_acc_n=num_acc_n, @@ -1017,8 +1292,8 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): lane_div_16=lane_div_16, lane_mod_16=lane_mod_16, bx_m=bx_m, - by_n=by_n, - n_tile_base=n_tile_base, + by_n=by_n // 2 if is_gate_up_inter else by_n, + n_tile_base=n_tile_base // 2 if is_gate_up_inter else n_tile_base, lds_out=lds_out, write_row_to_lds=write_row_to_lds, precompute_row=precompute_row, @@ -1047,7 +1322,7 @@ def _stage1_store_row(*, mi: int, ii: int, row_in_tile, row): if doweight_stage1: tw = buffer_ops.buffer_load(sorted_w_rsrc, row, vec_width=1, dtype=f32) - for ni in range_constexpr(num_acc_n): + for ni in range_constexpr(num_acc_n // 2 if is_gate_up_inter else num_acc_n): col_i32 = col_i32_list[ni] acc_idx = mi * num_acc_n + ni @@ -1067,8 +1342,17 @@ def _stage1_store_row(*, mi: int, ii: int, row_in_tile, row): sw_up = sw_up_vals[ni] vg = vg * sx * sw_gate vu = vu * sx * sw_up + + # if enable_bias: + # gate_bias_list, up_bias_list = epilogue_pf + # vg = vg + gate_bias_list[ni] + # vu = vu + up_bias_list[ni] - y = silu(vg) * vu + if act == "swiglu": + y = swiglu(vg, vu) + else: + y = silu(vg) * vu + if doweight_stage1: y = y * tw y = arith.trunc_f(_out_elem_type(), y) @@ -1106,7 +1390,7 @@ def __call__( gx = inter_in / arith.index(tile_n) # Use host-provided upper bound for M blocks (same as aiter moe_sorting allocation). # This avoids device->host sync on num_valid_ids. - gy = size_expert_ids_in + gy = size_expert_ids_in * 2 if is_gate_up_inter else size_expert_ids_in flir.gpu_ext.LaunchFuncOp( [module_name, "moe_gemm1"], grid_size=(gx, gy, 1), @@ -1132,10 +1416,13 @@ def __call__( exe = flydsl.compile(m) return exe +####==================== gemm1 pipeline end =====================### + + +####==================== gemm2 pipeline start =====================### -# This gemm1 pipeline used interleaved scale shuffle -@functools.lru_cache(maxsize=None) -def compile_gate_up_moe_gemm1( +@functools.lru_cache(maxsize=1024) +def compile_moe_gemm2( *, model_dim: int, inter_dim: int, @@ -1144,1220 +1431,16 @@ def compile_gate_up_moe_gemm1( tile_m: int, tile_n: int, tile_k: int, - # NOTE: aiter swap passes these for API symmetry; stage1 uses dynamic memrefs so they are ignored. - doweight_stage1: bool, - a_dtype: str = "fp8", - b_dtype: str = "fp4", + doweight_stage2: bool, + x_dtype: str = "fp8", + w_dtype: str = "fp8", out_dtype: str = "f16", - act: str = "swiglu", use_cshuffle_epilog: bool | None = None, + # Optional experiment: write per-(token,slot) output (no atomics) into an output shaped + # [tokens*topk, model_dim] (or [tokens, topk, model_dim] flattened), then reduce over topk outside. + # This can reduce atomic contention for small tokens at the cost of extra bandwidth / reduction. + accumulate: bool = True, enable_bias: bool = False, - model_dim_pad: int = 0, - inter_dim_pad: int = 0, -): - """Compile stage1 kernel (`moe_gemm1`) and return the compiled executable. - - a_dtype: - - "fp8": X is fp8 - - "fp16": X is fp16 (caller uses tile_k halved vs fp8 to match MFMA K halving) - - "int8": X is int8 - - "fp4": X is fp4 - - b_dtype: - - "fp8": W is fp8 - - "fp16": W is fp16 (caller uses tile_k halved vs fp8 to match MFMA K halving) - - "int8": W is int8 - - "int4": W4A8 path: X is int8, W is packed int4 (2 values per byte) unpacked to int8 in-kernel - - "fp4": W is fp4 - """ - gpu_arch = get_hip_arch() - allocator = SmemAllocator(None, arch=gpu_arch) - _state = {} - - if a_dtype not in ("fp8", "fp16", "int8", "fp4"): - raise ValueError(f"a_dtype must be one of ('fp8','fp16','int8','fp4'), got {in_dtype!r}") - if b_dtype not in ("fp8", "fp16", "int8", "int4", "fp4"): - raise ValueError(f"in_dtype must be one of ('fp8','fp16','int8','int4', 'fp4'), got {in_dtype!r}") - - is_f16_a = a_dtype == "fp16" - is_f16_b = b_dtype == "fp16" - is_f16 = is_f16_a or is_f16_b - - is_f8_a = a_dtype == "fp8" - is_f4_a = a_dtype == "fp4" - is_f4_b = b_dtype == "fp4" - - pack_M = 2 - pack_N = 2 - pack_K = 2 - - elem_bytes = 1 - - a_elem_bytes = 2 if is_f16_a else 1 - b_elem_bytes = 1 - tile_k_bytes = int(tile_k) * int(a_elem_bytes) - - a_elem_vec_pack = 2 if is_f4_a else 1 - cbsz = 0 if is_f8_a else 4 - blgp = 4 - - # enable_bias = False - - # K64-byte micro-step: always 64 bytes per `ku`. For fp16, this is 32 elements (2xK16 MFMA). - if (tile_k_bytes % 64) != 0: - raise ValueError( - f"tile_k_bytes must be divisible by 64, got tile_k_bytes={tile_k_bytes} " - f"(tile_k={tile_k}, elem_bytes={a_elem_bytes})" - ) - is_int4 = b_dtype == "int4" - # INT4 here means W4A8: X is int8, W is packed int4 and unpacked to int8 in-kernel. - # is_int8 = (in_dtype == "int8") or is_int4 - is_int8 = False - - mfma_i32_k32 = None - if is_int8: - mfma_i32_k32 = getattr(rocdl, "mfma_i32_16x16x32i8", None) or getattr( - rocdl, "mfma_i32_16x16x32_i8", None - ) - if mfma_i32_k32 is None: - raise AttributeError( - "INT8 K32 MFMA op not found: expected `rocdl.mfma_i32_16x16x32i8` " - "(or `rocdl.mfma_i32_16x16x32_i8`)." - ) - - def _x_elem_type(): - if is_f4_b: - return I.f8 if is_f8_a else I.ui8 - return I.f16 if is_f16 else (I.i8 if is_int8 else I.f8) - - def _w_elem_type(): - if is_f4_b: - return I.ui8 - return I.f16 if is_f16 else (I.i8 if is_int8 else I.f8) - - def _scale_elem_type(): - return I.i32 - - def _out_elem_type(): - return I.bf16 if out_dtype == "bf16" else I.i32 - - def _out_lds_elem_type(): - return I.f32 - - def _out_vec_type(): - return I.bf16x1 if out_dtype == "bf16" else I.f8x1 - - # size_out = tokens * topk * inter_dim - # size_x = tokens * model_dim - # # W is packed int4 for W4A8: 2 values per byte. - # size_w = (experts * (2 * inter_dim) * model_dim) // 2 if is_int4 else (experts * (2 * inter_dim) * model_dim) - - DYN = ir.ShapedType.get_dynamic_size() - size_out = DYN - size_x = DYN - # W is packed int4 for W4A8: 2 values per byte. - size_w = (experts * (2 * inter_dim) * model_dim) // 2 if is_int4 else (experts * (2 * inter_dim) * model_dim) - size_sorted = DYN - size_expert_ids = DYN - - total_threads = 256 - bytes_x_per_tile = int(tile_m) * int(tile_k) * int(a_elem_bytes) - if bytes_x_per_tile % total_threads != 0: - raise ValueError( - "tile_m*tile_k*elem_bytes must be divisible by " - f"{total_threads}: tile_m={tile_m}, tile_k={tile_k}, elem_bytes={a_elem_bytes}" - ) - bytes_per_thread_x = bytes_x_per_tile // total_threads - # Keep MoE stage1 X gmem->LDS pipeline consistent with the optimized GEMM kernel: - # split into <=16B pieces and use `flir.copy(load-only)` for buffer_load_dwordx4. - # (Compute the split lens inside the kernel so the code matches GEMM structure.) - - # CK-style LDS128 mode (same idea as test_preshuffle_gemm.py): - # - LDS stride == tile_k (no extra padding) + XOR16 swizzle - # - Use ds_{read,write}_b128 (16B) and extract 8B halves for MFMA steps - _ck_lds128 = os.environ.get("FLIR_CK_LDS128", "1") in ("1", "true", "True", "YES", "yes") - pad_k = 0 if _ck_lds128 else 8 - lds_stride = tile_k + pad_k - if use_cshuffle_epilog is None: - use_cshuffle_epilog = os.environ.get("FLIR_MOE_STAGE1_CSHUFFLE", "1") in ("1", "true", "True", "YES", "yes") - use_cshuffle_epilog = bool(use_cshuffle_epilog) - - epilog_tag = "cshuffle" if use_cshuffle_epilog else "direct" - module_name = f"mfma_moe1_a{a_dtype}_w{b_dtype}_{epilog_tag}".replace("-", "_") - - class _MOE1(flir.MlirModule): - GPU_MODULE_NAME = module_name - GPU_MODULE_TARGETS = [ - f'#rocdl.target' - ] - - def init_gpu_module(self): - # Optional epilogue CShuffle (LDS + vectorized buffer stores). - # Reuse the same LDS bytes for both: - # - ping-pong X tiles (2 * tile_m * lds_stride bytes; fp8/int8) - # - epilogue CShuffle tile (tile_m * tile_n f16 -> 2 * tile_m * tile_n bytes) - _use_cshuffle_epilog = bool(use_cshuffle_epilog) - lds_x_bytes = 2 * int(tile_m) * int(lds_stride) * int(a_elem_bytes) - lds_out_bytes = 4 * tile_m * (tile_n // 2) if _use_cshuffle_epilog else 0 - lds_total_bytes = max(lds_x_bytes, lds_out_bytes) - lds_total_elems = lds_total_bytes if a_elem_bytes == 1 else (lds_total_bytes // 2) - x_lds_elem = I.f16 if is_f16_a else (I.i8 if is_int8 else I.f8) - _state["lds_x_decl"] = allocator.allocate_array(x_lds_elem, lds_total_elems) - allocator.finalize() - - @flir.kernel - def moe_gemm1( - self: flir.T.i64, - arg_out: lambda: T.memref(DYN, _out_elem_type()), - arg_x: lambda: T.memref(DYN, _x_elem_type()), - arg_w: lambda: T.memref(DYN, _w_elem_type()), - arg_scale_x: lambda: T.memref(DYN, _scale_elem_type()), - arg_scale_w: lambda: T.memref(experts * (2 * inter_dim), _scale_elem_type()), - arg_sorted_token_ids: lambda: T.memref(size_sorted, T.i32()), - arg_expert_ids: lambda: T.memref(DYN, T.i32()), - arg_sorted_weights: lambda: T.memref(DYN, T.f32()), - arg_max_token_ids: lambda: T.memref(DYN, T.i32()), - arg_bias: lambda: T.memref(DYN, T.f32()), - tokens_in: lambda: T.index(), - inter_in: lambda: T.index(), - k_in: lambda: T.index(), - size_expert_ids_in: lambda: T.index(), - ): - x_elem = I.f16 if is_f16_a else (I.i8 if is_int8 else I.f8) - # For int4, weights are stored as packed bytes (i8) and unpacked to i8 packs. - w_elem = I.f16 if is_f16_b else (I.i8 if is_int8 else I.f8) - f16 = I.f16 - f32 = I.f32 - i32 = I.i32 - i64 = I.i64 - vec4_f32 = I.vec(4, f32) - vec4_i32 = I.vec(4, i32) - vec4_f16 = I.vec(4, f16) - vec4_f8 = I.vec(4, I.f8) - vec1_f16 = I.vec(1, f16) - vec1_f32 = I.vec(1, f32) - vec16_elems = 16 if a_elem_bytes == 1 else 8 - vec8_elems = 8 if a_elem_bytes == 1 else 4 - vec4_elems = 4 if a_elem_bytes == 1 else 2 - vec8_x = I.vec(vec8_elems, x_elem) - vec16_x = I.vec(vec16_elems, x_elem) - vec1_i64 = I.vec(1, i64) - vec2_i64 = I.vec(2, i64) - - def silu(x): - # Align with CK's device fast path: - # emu = exp(-x) ~= exp2(log2e * (-x)) -> v_exp_f32 - # sig = rcp(1 + emu) -> v_rcp_f32 - # y = x * sig - # - # Using llvm.amdgcn intrinsics prevents lowering to the div_scale/div_fixup - # sequences that introduce extra compares/cndmasks. - t = x * (-1.4426950408889634) # -log2(e) - emu = llvm.call_intrinsic(f32, "llvm.amdgcn.exp2.f32", [t], [], []) - den = 1.0 + emu - sig = llvm.call_intrinsic(f32, "llvm.amdgcn.rcp.f32", [den], [], []) - return x * sig - - def swiglu(gate, up, alpha=1.702, limit=7.0): - # Align with CK's device fast path - # - # Using llvm.amdgcn intrinsics prevents lowering to the div_scale/div_fixup - # sequences that introduce extra compares/cndmasks. - gate = arith.minimum(gate, limit) - up = arith.minimum(up, limit) - up = arith.maximum(up, -limit) - - t = gate * (alpha) * (-1.4426950408889634) # -log2(e) - emu = llvm.call_intrinsic(f32, "llvm.amdgcn.exp2.f32", [t], [], []) - den = 1.0 + emu - sig = llvm.call_intrinsic(f32, "llvm.amdgcn.rcp.f32", [den], [], []) - return gate * sig * (up + 1) - - acc_init = ( - arith.constant_vector(0, vec4_i32) - if is_int8 - else arith.constant_vector(0.0, vec4_f32) - ) - - # Lccouts - layout_x = flir.make_layout((tokens_in, k_in), stride=(k_in, 1)) - - # B preshuffle layout: match GEMM test helper exactly. - c_n_total = arith.constant(experts * (2 * inter_dim), index=True) - kpack_bytes = 8 if is_int4 else 16 - b_layout = make_preshuffle_b_layout( - flir, arith, c_n=c_n_total, c_k=k_in // pack_K, kpack_bytes=kpack_bytes, elem_bytes=b_elem_bytes - ) - layout_b = b_layout.layout_b - - m_repeat = tile_m // 16 - k_unroll = tile_k_bytes // 128 # K64-byte micro-step - - # A&B's scale preshuffle layout - layout_a_scale = make_preshuffle_scale_layout( - flir, arith, c_mn=tokens_in, c_k=k_in, - ) - layout_b_scale = make_preshuffle_scale_layout( - flir, arith, c_mn=c_n_total, c_k=k_in, - ) - - # Only used by fp8/int8 path (16B gmem -> regs). Kept for backwards compat. - atom_w_g2r16 = flir.make_copy_atom(w_elem, vector_size=16) - - shape_lds = flir.make_shape(tile_m, tile_k) - stride_lds = flir.make_stride(lds_stride, 1) - layout_lds = flir.make_layout(shape_lds, stride_lds) - - tx = gpu.thread_id("x") - # Align with Aiter launch mapping (NSwizzle==false): - # - blockIdx.x -> N dimension (tile along inter_dim) - # - blockIdx.y -> expert-block id / M dimension (tile along sorted M) - by = gpu.block_id("x") # tile along inter_dim - bx = gpu.block_id("y") # tile along sorted M - - # Block validity: compute as early as possible so invalid blocks skip all buffer-resource - # setup, LDS pointer math, and gmem prefetch work. - bx_m = bx * arith.constant(tile_m, index=True) - by_n = by * arith.constant(tile_n, index=True) - - maxids_rsrc = buffer_ops.create_buffer_resource( - arg_max_token_ids, max_size=False, num_records_bytes=arith.i32(4) - ) - max_token_id_i32 = buffer_ops.buffer_load( - maxids_rsrc, arith.constant(0, index=True), vec_width=1, dtype=i32 - ) - - bias_rsrc = buffer_ops.create_buffer_resource(arg_bias, max_size=False) if enable_bias else None - - bx_m_i32 = arith.index_cast(i32, bx_m) - by_n_i32 = arith.index_cast(i32, by_n) - blk_valid = arith.cmpu(bx_m_i32, max_token_id_i32, "ult") - # Common constants/atoms (hoisted): keep IR small like GEMM. - # CK-style XOR16 swizzle parameter (constant, power-of-two in our configs). - k_blocks16 = arith.constant(tile_k_bytes // 16, index=True) - atom_x_s16 = flir.make_copy_atom(x_elem, vector_size=16) - atom_x_s8 = flir.make_copy_atom(x_elem, vector_size=8) - atom_x_s4 = flir.make_copy_atom(x_elem, vector_size=4) - atom_x_g2r16 = flir.make_copy_atom(x_elem, vector_size=vec16_elems) - atom_x_g2r8 = flir.make_copy_atom(x_elem, vector_size=vec8_elems) - atom_x_g2r4 = flir.make_copy_atom(x_elem, vector_size=vec4_elems) - layout_tx_wave_lane = flir.make_layout((4, 64), stride=(64, 1)) - layout_lane16 = flir.make_layout((4, 16), stride=(16, 1)) - - _if_blk = scf.IfOp(blk_valid) - with _if_blk.then(): - base_ptr = allocator.get_base() - lds_x_ptr = _state["lds_x_decl"](base_ptr) - lds_x = lds_x_ptr.get() - # Alias LDS bytes as fp16 for optional CShuffle epilogue. - _use_cshuffle_epilog = bool(use_cshuffle_epilog) - - lds_out = ( - SmemPtr(base_ptr, lds_x_ptr.byte_offset, _out_lds_elem_type(), shape=(tile_m * tile_n,)).get() - if _use_cshuffle_epilog - else None - ) - - # Use logical buffer sizes (descriptor num_records) so hardware OOB checking can be - # used directly (CK-style). This allows us to avoid `select`-based masking for - # invalid lanes and rely on the buffer instruction's built-in bounds behavior. - x_rsrc = buffer_ops.create_buffer_resource(arg_x, max_size=False, num_records_bytes=tokens_in*model_dim) - w_rsrc = buffer_ops.create_buffer_resource(arg_w, max_size=False) - out_rsrc = buffer_ops.create_buffer_resource(arg_out, max_size=False) - - # fp16 path ignores scales completely (implicit scale=1.0). - sx_rsrc = None if is_f16_a else buffer_ops.create_buffer_resource(arg_scale_x, max_size=False) - sw_rsrc = None if is_f16_b else buffer_ops.create_buffer_resource(arg_scale_w, max_size=False) - sorted_rsrc = buffer_ops.create_buffer_resource(arg_sorted_token_ids, max_size=False) - sorted_w_rsrc = buffer_ops.create_buffer_resource(arg_sorted_weights, max_size=False) - - # expert ids: [blocks] i32 -> bytes = size_expert_ids_in*4 - eid_nbytes_i32 = arith.index_cast(i32, size_expert_ids_in * arith.constant(4, index=True)) - expert_rsrc = buffer_ops.create_buffer_resource( - arg_expert_ids, max_size=False, num_records_bytes=eid_nbytes_i32 - ) - - # Expert id for this M tile (keep address math in `index`) - expert_i32 = buffer_ops.buffer_load(expert_rsrc, bx, vec_width=1, dtype=i32) - exp_valid = arith.cmpu(expert_i32, experts, "ult") # todo fix - _ifexpert_of = scf.IfOp(exp_valid) - with _ifexpert_of.then(): - expert_idx = arith.index_cast(ir.IndexType.get(), expert_i32) - inter2_idx = arith.constant(2 * inter_dim, index=True) - expert_off_idx = expert_idx * inter2_idx # index - - bx_m = bx * arith.constant(tile_m, index=True) - - # ---- X gmem->reg prefetch (match preshuffle GEMM mapping) ---- - # Prefer 16B buffer-load (dwordx4). If the per-thread byte count isn't divisible by - # 16, fall back to 8B (dwordx2) or 4B (dword) loads. This broadens supported tilings - # (e.g. tile_m=16, tile_k=192 -> 12B/thread) at some performance cost. - if is_f16_a: - # fp16 path keeps the same fixed 16B gmem->reg schedule. - if bytes_per_thread_x % 16 != 0: - raise ValueError( - f"[fp16] bytes_per_thread_x ({bytes_per_thread_x}) must be divisible by 16" - ) - x_load_bytes = 16 - else: - if bytes_per_thread_x % 16 == 0: - x_load_bytes = 16 - elif bytes_per_thread_x % 8 == 0: - x_load_bytes = 8 - elif bytes_per_thread_x % 4 == 0: - x_load_bytes = 4 - else: - raise ValueError( - f"bytes_per_thread_x ({bytes_per_thread_x}) must be divisible by 4 to use the dword-indexed load mapping." - ) - num_x_loads = bytes_per_thread_x // x_load_bytes - chunk_i32 = x_load_bytes // 4 # dwords per chunk (1/2/4) - - # Work in dword units along K: K_dwords = (K_bytes)/4. - c_k_div4 = (k_in * arith.constant(int(elem_bytes), index=True)) / arith.index(4) - layout_x_div4 = flir.make_layout((tokens_in, c_k_div4), stride=(c_k_div4, 1)) - tile_k_dwords = (int(tile_k) * int(elem_bytes)) // 4 - layout_x_tile_div4 = flir.make_layout((tile_m, tile_k_dwords), stride=(tile_k_dwords, 1)) - c_chunk_i32 = arith.constant(chunk_i32, index=True) - tx_i32_base = tx * c_chunk_i32 - mask24 = arith.i32(0xFFFFFF) - # Keep i32 constants available for epilogue index math. - topk_i32 = arith.i32(topk) - - tokens_i32 = arith.index_cast(i32, tokens_in) - - def x_tile_chunk_coord_i32(i: int): - return tile_chunk_coord_i32( - flir, - arith, - tx_i32_base=tx_i32_base, - i=i, - total_threads=total_threads, - layout_tile_div4=layout_x_tile_div4, - chunk_i32=chunk_i32, - ) - - # CK-aligned: decode token once (per thread's M-slice) and build a base row offset. - x_row_base_div4 = [] - x_col_local_i32 = [] - x_row_local = [] - for i in range_constexpr(num_x_loads): - row_local, col_local_i32 = x_tile_chunk_coord_i32(i) - x_row_local.append(row_local) - x_col_local_i32.append(col_local_i32) - - sorted_row_i = bx_m + row_local - fused_i = buffer_ops.buffer_load(sorted_rsrc, sorted_row_i, vec_width=1, dtype=i32) - t_i32 = arith.andi(fused_i, mask24) - t_idx = arith.index_cast(ir.IndexType.get(), t_i32) - x_row_base_div4.append(t_idx * c_k_div4) - - vec1_i32 = I.vec(1, i32) - vec2_i32 = I.vec(2, i32) - vec4_i32 = I.vec(4, i32) - vec4_x = I.vec(4, x_elem) - - def load_x(idx_i32): - """Load `x_load_bytes` bytes from X (gmem) into regs. - - For 16B, keep the fast dwordx4 path. For 8B/4B, use byte offsets. - """ - if x_load_bytes == 16: - idx_elem = idx_i32 if elem_bytes == 1 else (idx_i32 * arith.index(2)) - return buffer_copy_gmem16_dwordx4( - flir, - arg=arg_x, - elem_type=x_elem, - # idx_i32=idx_elem + 0x80000000, - idx_i32=idx_elem, - atom_g2r16=atom_x_g2r16, - rsrc=x_rsrc, - vec_elems=vec16_elems, - ) - idx_bytes = idx_i32 * arith.index(4) - atom = atom_x_g2r8 if x_load_bytes == 8 else atom_x_g2r4 - view = flir.TensorView( - arg_x, - (x_load_bytes,), - strides=(1,), - base_indices=(idx_bytes,), - element_type=x_elem, - ) - return flir.copy( - atom, - view, - None, - alignment=x_load_bytes, - return_vector=True, - src_buffer_resource=x_rsrc, - src_buffer_offset_in_bytes=True, - ) - - def load_x_tile(base_k): - """Prefetch the per-thread X tile portion (gmem -> regs) for a given K base (in elements).""" - base_k_div4 = (base_k * arith.constant(int(elem_bytes), index=True)) / arith.index(4) - parts = [] - for i in range_constexpr(num_x_loads): - idx_i32 = x_row_base_div4[i] + base_k_div4 + x_col_local_i32[i] - x_vec = load_x(idx_i32) - parts.append(vector.bitcast(vec4_i32, x_vec)) - return parts - - # tx -> wave/lane (GEMM-style decomposition). - coord_wl = flir.idx2crd(tx, layout_tx_wave_lane) - wave_id = flir.get(coord_wl, 0) - lane_id = flir.get(coord_wl, 1) - coord_l16 = flir.idx2crd(lane_id, layout_lane16) - lane_div_16 = flir.get(coord_l16, 0) - lane_mod_16 = flir.get(coord_l16, 1) - - # Match GEMM naming/pattern: row in LDS is lane_mod_16, and col base is lane_div_16*16B (KPackBytes=16). - row_a_lds = lane_mod_16 - # col_offset_base = lane_div_16 * arith.constant(32, index=True) - col_offset_base = lane_div_16 * arith.constant(16, index=True) - - # Dynamic N tiling within block (same as existing kernels) - num_waves = 4 - n_per_wave = tile_n // num_waves - num_acc_n = n_per_wave // 16 - c_n_per_wave = arith.constant(n_per_wave, index=True) - wave_mod_4 = wave_id % arith.index(4) - n_tile_base = wave_mod_4 * c_n_per_wave - - # fp4 pack - k_unroll_packed = k_unroll // pack_K - m_repeat_packed = m_repeat // pack_M - num_acc_n_packed = num_acc_n // pack_N - - # Precompute n_blk/n_intra for gate and up rows (GEMM-style: idx2crd/get) - col_g_list = [] - valid_col_list = [] - inter_idx = arith.constant(inter_dim, index=True) - # layout for (row -> (blk,intra)) where intra is 0..15 - c_n0 = c_n_total / arith.index(16) - layout_n_blk_intra = flir.make_layout((c_n0, 16), stride=(16, 1)) - n_intra_list = [] - n_blk_list = [] - for i in range_constexpr(num_acc_n): - offset = i * 16 - - col_g = by_n + n_tile_base - col_g = col_g // 2 + offset - col_g = col_g + lane_mod_16 - col_g_list.append(col_g) - - c_offset = arith.constant(offset, index=True) - global_n = by_n + n_tile_base + c_offset + lane_mod_16 - row_w = expert_off_idx + global_n - coord_n = flir.idx2crd(row_w, layout_n_blk_intra) - n_blk_list.append(flir.get(coord_n, 0)) - n_intra_list.append(flir.get(coord_n, 1)) - - # --- B Load Logic (K64) - shared layout with preshuffle GEMM --- - def load_b_packs_k64(base_k, ku: int, ni: int): - base_k_bytes = base_k * arith.constant(int(elem_bytes), index=True) - k0_base = base_k_bytes / 64 - k0 = k0_base + ku - k1 = lane_div_16 - coord_pack = flir.make_coord(n_blk_list[ni], k0, k1, n_intra_list[ni], 0) - idx_pack = flir.crd2idx(coord_pack, layout_b) - - # Calculate mask for boundary check - c_offset = arith.constant(ni * 16, index=True) - global_n = by_n + n_tile_base + c_offset + lane_mod_16 - - vec_elems = 16 - b_view = flir.TensorView( - arg_w, - (vec_elems,), - strides=(1,), - base_indices=(idx_pack,), - element_type=_w_elem_type(), - ) - b16 = flir.copy( - flir.make_copy_atom(_w_elem_type(), vector_size=vec_elems), - b_view, - None, - alignment=8, - return_vector=True, - src_buffer_resource=(w_rsrc if elem_bytes == 1 else None), - src_buffer_offset_in_bytes=(elem_bytes == 1), - ) - # Split 16B pack into two 8B halves. - b_i64x2 = vector.bitcast(I.i64x2, b16) - b0_i64 = vector.extract(b_i64x2, static_position=[0], dynamic_position=[]) - b1_i64 = vector.extract(b_i64x2, static_position=[1], dynamic_position=[]) - return b0_i64, b1_i64 - - def load_b_tile(base_k): - b_tile = [] - for ku in range_constexpr(k_unroll): - packs0 = [] - packs1 = [] - for ni in range_constexpr(num_acc_n): - b0, b1 = load_b_packs_k64(base_k, ku, ni) - packs0.append(b0) - packs1.append(b1) - b_tile.append((packs0, packs1)) - return b_tile - - def load_scale(arg_scale, rsrc, layout, ku, mni): - k_lane = lane_div_16 - n_lane = lane_mod_16 - coord_pack = flir.make_coord(mni, ku, k_lane, n_lane) - idx_pack = flir.crd2idx(coord_pack, layout) - scale_view = flir.TensorView( - arg_scale, - (1,), - strides=(1,), - base_indices=(idx_pack,), - element_type=_scale_elem_type(), - ) - scale = flir.copy( - flir.make_copy_atom(_scale_elem_type(), vector_size=1), - scale_view, - None, - alignment=8, - return_vector=True, - src_buffer_resource=rsrc, - src_buffer_offset_in_bytes=False, - ) - # Split 16B pack into two 8B halves. - return scale - - def load_b_scale_tile(base_k): - b_scale_tile = [] - for ku in range_constexpr(k_unroll_packed): - for ni in range_constexpr(num_acc_n_packed): - scale = load_scale( - arg_scale_w, - sw_rsrc, - layout_b_scale, - ku + base_k, - ni + (expert_off_idx + by_n + n_tile_base) // pack_N // 16, - ) - b_scale_tile.append(scale) - return b_scale_tile - - def load_a_scale_tile(base_k): - a_scale_tile = [] - for ku in range_constexpr(k_unroll_packed): - for mi in range_constexpr(m_repeat_packed): - scale = load_scale( - arg_scale_x, - sx_rsrc, - layout_a_scale, - ku + base_k, - mi + bx_m // pack_M // 16, - ) - a_scale_tile.append(scale) - return a_scale_tile - - def prefetch_ab_scale_tile(base_k): - return [None, load_b_scale_tile(base_k)] - - acc_gate = [acc_init] * (num_acc_n // 2 * m_repeat) - acc_up = [acc_init] * (num_acc_n // 2 * m_repeat) - - # ---- Pipeline helpers: store X tile to LDS with ping-pong base ---- - def store_x_tile_to_lds(vec_x_in_parts, lds_base): - for i in range_constexpr(num_x_loads): - row_local = x_row_local[i] - col_local_i32 = x_col_local_i32[i] - if x_load_bytes == 16: - lds_store_16b_xor16( - flir, - arith, - vector, - lds_memref=lds_x, - vec16_ty=vec16_x, - elem_type=x_elem, - atom_s16=atom_x_s16, - layout_lds=layout_lds, - row_local=row_local, - col_local_i32=col_local_i32, - tx_c4=arith.index(4), - k_blocks16=k_blocks16, - lds_base=lds_base, - vec_part_i32x4=vec_x_in_parts[i], - elem_bytes=elem_bytes, - ) - - # --- A LDS load helper for K64 (load 16B once, extract 2x i64 halves) --- - def lds_load_packs_k64(curr_row_a_lds, col_base, lds_base): - # Swizzle in bytes, then convert to element offset for memref indexing. - col_base_swz_bytes = flir.swizzle_xor16(curr_row_a_lds, col_base, k_blocks16) - col_base_swz = col_base_swz_bytes if elem_bytes == 1 else (col_base_swz_bytes / arith.index(2)) - coord_a16 = flir.make_coord(curr_row_a_lds, col_base_swz) - idx_a16 = flir.crd2idx(coord_a16, layout_lds) - idx_a16 = idx_a16 + lds_base - loaded_a16 = vector.load_op(vec16_x, lds_x, [idx_a16]) - a_i64x2 = vector.bitcast(vec2_i64, loaded_a16) - a0 = vector.extract(a_i64x2, static_position=[0], dynamic_position=[]) - a1 = vector.extract(a_i64x2, static_position=[1], dynamic_position=[]) - return a0, a1 - - def compute_f8f6f4_tile( - acc_gate_in, - acc_up_in, - b_tile_in, - lds_base, - *, - a0_prefetch=None, - a_scale=None, - b_scale=None, - prefetch_epilogue: bool = False - ): - gate_list = list(acc_gate_in) - up_list = list(acc_up_in) - - epilogue_pf = None - if enable_bias and prefetch_epilogue: - expert_off_idx + by_n + n_tile_base - - gate_bias = [] - up_bias = [] - for ni in range_constexpr(num_acc_n_packed): - global_n = (by_n + n_tile_base) // 2 + ni * 16 + lane_mod_16 - gate_offset = expert_off_idx + global_n - up_offset = expert_off_idx + global_n + inter_dim - gate_bias.append( - buffer_ops.buffer_load(bias_rsrc, gate_offset, vec_width=1, dtype=f32) - ) - up_bias.append( - buffer_ops.buffer_load(bias_rsrc, up_offset, vec_width=1, dtype=f32) - ) - epilogue_pf = (gate_bias, up_bias) - - # ---------------- gfx95 fast path (K128 MFMA scale) ---------------- - # This is the key optimization from `zhimding/develop_0107` for FP8: - # use mfma.scale 16x16x128 to reduce instruction count in the hot loop. - # - # Notes: - # - Only valid for fp8 path (not int8/int4) and gfx95+ - # - Requires tile_k divisible by 128 - # - mfma.scale takes 9 operands: 3 vectors + 6 i32 flags/scales. - if (int(tile_k) % 128) != 0: - raise ValueError( - f"tile_k must be divisible by 128 for mfma_scale_x128, got tile_k={tile_k}" - ) - - mfma_res_ty = I.f32x4 - vec4_i64 = I.vec(4, I.i64) - vec8_i32 = I.vec(8, I.i32) - c0_i64 = arith.constant(0, type=I.i64) - - def pack_i64x4_to_i32x8(x0, x1, x2, x3): - v4 = vector.from_elements(vec4_i64, [x0, x1, x2, x3]) - return vector.bitcast(vec8_i32, v4) - - for ku128 in range_constexpr(k_unroll_packed): - for mi in range_constexpr(m_repeat_packed): - # a_scale_i32 = a_scale[ku128 * m_repeat_packed + mi] - # a_scale_val = vector.extract(a_scale_i32, static_position=[0], dynamic_position=[]) - for ni in range_constexpr(num_acc_n_packed): - b_scale_i32 = b_scale[ku128 * num_acc_n_packed + ni] - b_scale_val = vector.extract(b_scale_i32, static_position=[0], dynamic_position=[]) - for ikxdl in range_constexpr(pack_K): - k_idx = ku128 * pack_K + ikxdl - b_packs0, b_packs1 = b_tile_in[k_idx] - col_base = col_offset_base + (k_idx * 128) // a_elem_vec_pack - for imxdl in range_constexpr(pack_M): - col_base0 = col_base - mi_idx = mi * pack_M + imxdl - mi_val = arith.constant(mi_idx * 16, index=True) - curr_row_a_lds = row_a_lds + mi_val - - if (a0_prefetch is not None) and (k_idx == 0) and (mi_idx == 0): - a0, a1 = a0_prefetch - else: - a0, a1 = lds_load_packs_k64(curr_row_a_lds, col_base0, lds_base) - - if is_f8_a: - col_base1 = col_base + 64 - a2, a3 = lds_load_packs_k64(curr_row_a_lds, col_base1, lds_base) - a128 = pack_i64x4_to_i32x8(a0, a1, a2, a3) - else: - a128 = pack_i64x4_to_i32x8(a0, a1, c0_i64, c0_i64) - - for inxdl in range_constexpr(pack_N): - if inxdl % 2 == 0: - current_accs_list = gate_list - else: - current_accs_list = up_list - ni_idx = ni * pack_N + inxdl - - b0 = b_packs0[ni_idx] - b1 = b_packs1[ni_idx] - b128 = pack_i64x4_to_i32x8(b0, b1, c0_i64, c0_i64) - - acc_idx = mi_idx * num_acc_n_packed + ni - rocdl.sched_barrier(0) - current_accs_list[acc_idx] = rocdl.mfma_scale_f32_16x16x128_f8f6f4( - mfma_res_ty, - [ - a128, - b128, - current_accs_list[acc_idx], - cbsz, - blgp, - # use per tensor quant a1 for now, - 0, - 0x3F800000, - ikxdl * pack_N + inxdl, - b_scale_val, - ], - ) - - return gate_list, up_list, epilogue_pf - - # ---------------- 2-stage pipeline (ping-pong LDS + B tile prefetch) ---------------- - lds_tile_elems = arith.constant(tile_m * lds_stride, index=True) - lds_base_cur = arith.index(0) - lds_base_nxt = lds_tile_elems - - # Optional scheduler hints (copied from tuned GEMM); can be disabled via env. - rocdl.sched_barrier(0) - - def hot_loop_scheduler(): - mfma_group = num_acc_n * 2 - # K64 micro-step: 2x K32 MFMA per gemm. - mfma_total = (k_unroll * 2) * m_repeat * mfma_group - mfma_per_iter = 2 * mfma_group - sche_iters = 0 if mfma_per_iter == 0 else (mfma_total // mfma_per_iter) - - # DS-read preload (CK default is 2); clamp to non-negative. - rocdl.sched_dsrd(2) - rocdl.sched_mfma(2) - rocdl.sched_dsrd(1) - rocdl.sched_mfma(1) - rocdl.sched_dsrd(1) - rocdl.sched_mfma(1) - - # DS-write hints near the end: match total X LDS-store micro-ops per thread. - dswr_tail = num_x_loads - if dswr_tail > sche_iters: - dswr_tail = sche_iters - dswr_start = sche_iters - dswr_tail - for sche_i in range_constexpr(sche_iters): - rocdl.sched_vmem(1) - rocdl.sched_mfma(mfma_group) - rocdl.sched_dsrd(1) - rocdl.sched_mfma(mfma_group) - if sche_i >= dswr_start - 1: - rocdl.sched_dswr(1) - rocdl.sched_barrier(0) - - # Prologue: prefetch tile0, store to LDS(cur), sync. - k0 = arith.index(0) - x_regs0 = load_x_tile(k0) - w_regs0 = load_b_tile(k0) - - a_scale_pong = None - a_scale_ping = None - # a_scale_pong, b_scale_pong = prefetch_ab_scale_tile(k0 // 2) - _, b_scale_pong = prefetch_ab_scale_tile(k0 // 2) - store_x_tile_to_lds(x_regs0, lds_base_cur) - gpu.barrier() - - # Loop-carried ping/pong state. - lds_base_pong = lds_base_cur # current/compute - lds_base_ping = lds_base_nxt # next/load+store - w_regs_pong = w_regs0 - - # Cross-tile A0 LDS prefetch (default-on): prefetch the first A-pack (K64) for the - # tile we are about to compute from LDS, to overlap with upcoming VMEM. - a0_prefetch_pong = lds_load_packs_k64(row_a_lds, col_offset_base, lds_base_pong) - - # Unrolled ping-pong main loop (2 tiles per iteration), leaving 2 tail tiles. - c2_tile_k = arith.constant(tile_k * 2, index=True) - c_k_main2 = k_in - c2_tile_k - - for k_iv in range(arith.index(0), c_k_main2, c2_tile_k): - # ---- stage 0: prefetch+store ping, compute pong ---- - next_k1 = k_iv + tile_k - x_regs_ping = load_x_tile(next_k1) - w_regs_ping = load_b_tile(next_k1 // 2) - _, b_scale_ping = prefetch_ab_scale_tile(next_k1 // pack_K // 128) - - acc_gate, acc_up, _ = compute_f8f6f4_tile( - acc_gate, - acc_up, - w_regs_pong, - lds_base_pong, - a0_prefetch=a0_prefetch_pong, - a_scale=a_scale_pong, - b_scale=b_scale_pong, - ) - a0_prefetch_pong = None - store_x_tile_to_lds(x_regs_ping, lds_base_ping) - # hot_loop_scheduler() - gpu.barrier() - - # Cross-tile prefetch for the ping tile we are about to compute. - a0_prefetch_ping = lds_load_packs_k64(row_a_lds, col_offset_base, lds_base_ping) - - # ---- stage 1: prefetch+store pong, compute ping ---- - next_k2 = k_iv + c2_tile_k - x_regs_pong = load_x_tile(next_k2) - w_regs_pong = load_b_tile(next_k2 // 2) - _, b_scale_pong = prefetch_ab_scale_tile(next_k2 // pack_K // 128) - - acc_gate, acc_up, _ = compute_f8f6f4_tile( - acc_gate, - acc_up, - w_regs_ping, - lds_base_ping, - a0_prefetch=a0_prefetch_ping, - a_scale=a_scale_ping, - b_scale=b_scale_ping, - ) - a0_prefetch_ping = None - store_x_tile_to_lds(x_regs_pong, lds_base_pong) - # hot_loop_scheduler() - gpu.barrier() - - # Cross-tile prefetch for the next pong tile. - a0_prefetch_pong = lds_load_packs_k64(row_a_lds, col_offset_base, lds_base_pong) - - # Tail: 2 remaining tiles at (k_in - 2*tile_k) and (k_in - tile_k). - k_tail1 = k_in - tile_k - x_regs_ping = load_x_tile(k_tail1) - w_regs_ping = load_b_tile(k_tail1 // 2) - _, b_scale_ping = prefetch_ab_scale_tile(k_tail1 // pack_K // 128) - - acc_gate, acc_up, _ = compute_f8f6f4_tile( - acc_gate, - acc_up, - w_regs_pong, - lds_base_pong, - a0_prefetch=a0_prefetch_pong, - a_scale=a_scale_pong, - b_scale=b_scale_pong, - ) - a0_prefetch_pong = None - store_x_tile_to_lds(x_regs_ping, lds_base_ping) - # hot_loop_scheduler() - gpu.barrier() - - # Cross-tile prefetch for the final ping tile. - a0_prefetch_ping = lds_load_packs_k64(row_a_lds, col_offset_base, lds_base_ping) - - # Epilogue: compute last tile with epilogue scale prefetch to overlap loads with MFMA. - acc_gate, acc_up, epilogue_pf = compute_f8f6f4_tile( - acc_gate, - acc_up, - w_regs_ping, - lds_base_ping, - a0_prefetch=a0_prefetch_ping, - a_scale=a_scale_ping, - b_scale=b_scale_ping, - prefetch_epilogue=True, - ) - - # Store epilogue to out[t, slot, inter] - expert_off = expert_off_idx - bx_m0 = bx_m - topk_i32_v = topk_i32 - inter_i32_v = arith.i32(inter_dim) - mask24_i32 = arith.i32(0xFFFFFF) - - # Epilogue hoists to keep IR + Python build time small: - col_i32_list = [] - for ni in range_constexpr(num_acc_n): - col_i32_list.append(arith.index_cast(i32, col_g_list[ni])) - - lane_div_16_mul4 = lane_div_16 * arith.index(4) - inter_i32_local = inter_i32_v - - # Optional: CK-style CShuffle epilogue for better global store coalescing. - # Uses EVec=4 (buffer store "x4" of fp16 elements). - _use_cshuffle_epilog = (out_dtype == "fp8") or bool(use_cshuffle_epilog) - - mask_even_i32 = arith.i32(0xFFFFFFFE) - - if _use_cshuffle_epilog: - if lds_out is None: - raise RuntimeError("CShuffle epilogue enabled but lds_out is not allocated/aliased.") - - def write_row_to_lds( - *, - mi: int, - ii: int, - row_in_tile, - row, - row_base_lds, - col_base_local, - num_acc_n: int, - lds_out, - ): - # `row` is the sorted-row index (bx_m + row_in_tile). - fused2 = buffer_ops.buffer_load(sorted_rsrc, row, vec_width=1, dtype=i32) - t2 = fused2 & mask24_i32 - - # Sorted weight aligned with `row` (matches aiter moe_sorting output). - if doweight_stage1: - tw = buffer_ops.buffer_load(sorted_w_rsrc, row, vec_width=1, dtype=f32) - - for ni in range_constexpr(num_acc_n_packed): - col_local = col_base_local + (ni * 16) - - acc_idx = mi * num_acc_n + ni - vg = vector.extract( - acc_gate[acc_idx], static_position=[ii], dynamic_position=[] - ) - vu = vector.extract( - acc_up[acc_idx], static_position=[ii], dynamic_position=[] - ) - - if enable_bias: - gate_bias_list, up_bias_list = epilogue_pf - vg = vg + gate_bias_list[ni] - vu = vu + up_bias_list[ni] - - if act == "swiglu": - y = swiglu(vg, vu) - else: - y = silu(vg) * vu - - if doweight_stage1: - y = y * tw - - lds_idx = row_base_lds + col_local - v1 = vector.from_elements(vec1_f32, [y]) - vector.store(v1, lds_out, [lds_idx], alignment=1) - - def precompute_row(*, row_local, row): - fused2 = buffer_ops.buffer_load(sorted_rsrc, row, vec_width=1, dtype=i32) - t2 = fused2 & mask24_i32 - s2 = fused2 >> 24 - return (t2 * topk_i32_v + s2) * inter_i32_local - - def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): - # Guard against sentinel token ids (t == tokens) produced by aiter moe_sorting padding. - # OOB buffer stores are not guaranteed to be safe on all paths, so predicate explicitly. - fused2 = buffer_ops.buffer_load(sorted_rsrc, row, vec_width=1, dtype=i32) - t2 = fused2 & mask24_i32 - t_valid = arith.cmpu(t2, tokens_i32, "ult") - _if_valid = scf.IfOp(t_valid) - with _if_valid.then(): - frag = vector.bitcast(vec4_f32, frag) - frag0 = vector.extract(frag, static_position=[0], dynamic_position=[]) - frag1 = vector.extract(frag, static_position=[1], dynamic_position=[]) - frag2 = vector.extract(frag, static_position=[2], dynamic_position=[]) - frag3 = vector.extract(frag, static_position=[3], dynamic_position=[]) - - out_fp8 = arith.i32(0) - out_fp8 = rocdl.cvt_pk_fp8_f32(src_a=arith._unwrap_value(frag0), src_b=arith._unwrap_value(frag1), old=arith._unwrap_value(out_fp8), word_sel=0, res=I.i32) - out_fp8 = rocdl.cvt_pk_fp8_f32(src_a=arith._unwrap_value(frag2), src_b=arith._unwrap_value(frag3), old=arith._unwrap_value(out_fp8), word_sel=1, res=I.i32) - - idx0 = row_ctx - col_i32 = arith.index_cast(i32, col_g0) - idx_out = idx0 + col_i32 - buffer_ops.buffer_store(out_fp8, out_rsrc, idx_out // 4) - - mfma_epilog( - use_cshuffle=True, - arith=arith, - vector=vector, - gpu=gpu, - range_constexpr=range_constexpr, - tile_m=tile_m, - tile_n=tile_n // 2, - e_vec=4, - m_repeat=m_repeat, - num_acc_n=num_acc_n_packed, - tx=tx, - lane_div_16=lane_div_16, - lane_mod_16=lane_mod_16, - bx_m=bx_m, - by_n=by_n // 2, - n_tile_base=n_tile_base // 2, - lds_out=lds_out, - frag_elem_type=I.f32, - write_row_to_lds=write_row_to_lds, - precompute_row=precompute_row, - store_pair=store_pair, - ) - return - - def _stage1_store_row(*, mi: int, ii: int, row_in_tile, row): - # `row` is the sorted-row index (bx_m + row_in_tile). - fused2 = buffer_ops.buffer_load(sorted_rsrc, row, vec_width=1, dtype=i32) - t2 = fused2 & mask24_i32 - s2 = fused2 >> 24 - - t_valid = arith.cmpu(t2, tokens_i32, "ult") - # No explicit mask: rely on buffer descriptor OOB to zero-fill when t2 is the - # sentinel (t2 == tokens) or otherwise out-of-range. - - # out linear index base = ((t*topk + s)*inter_dim) (invariant across ni) - idx0 = (t2 * topk_i32_v + s2) * inter_i32_local - - # Sorted weight aligned with `row` (matches aiter moe_sorting output). - if doweight_stage1: - tw = buffer_ops.buffer_load(sorted_w_rsrc, row, vec_width=1, dtype=f32) - - _if_valid = scf.IfOp(t_valid) - with _if_valid.then(): - for ni in range_constexpr(num_acc_n_packed): - col_i32 = col_i32_list[ni] - acc_idx = mi * num_acc_n_packed + ni - vg = vector.extract( - acc_gate[acc_idx], static_position=[ii], dynamic_position=[] - ) - vu = vector.extract( - acc_up[acc_idx], static_position=[ii], dynamic_position=[] - ) - if enable_bias: - gate_bias_list, up_bias_list = epilogue_pf - vg = vg + gate_bias_list[ni] - vu = vu + up_bias_list[ni] - - if act == "swiglu:": - y = swiglu(vg, vu) - else: - y = silu(vg) * vu - - if doweight_stage1: - y = y * tw - - y = arith.trunc_f(_out_elem_type(), y) - idx_out = idx0 + col_i32 - buffer_ops.buffer_store(y, out_rsrc, idx_out) - - mfma_epilog( - use_cshuffle=False, - arith=arith, - range_constexpr=range_constexpr, - m_repeat=m_repeat, - lane_div_16=lane_div_16, - bx_m=bx_m, - body_row=_stage1_store_row, - ) - - @flir.jit - def __call__( - self: flir.T.i64, - arg_out: lambda: T.memref(DYN, _out_elem_type()), - arg_x: lambda: T.memref(DYN, _x_elem_type()), - arg_w: lambda: T.memref(DYN, _w_elem_type()), - arg_scale_x: lambda: T.memref(DYN, _scale_elem_type()), - arg_scale_w: lambda: T.memref(experts * (2 * inter_dim), _scale_elem_type()), - arg_sorted_token_ids: lambda: T.memref(size_sorted, T.i32()), - arg_expert_ids: lambda: T.memref(DYN, T.i32()), - arg_sorted_weights: lambda: T.memref(DYN, T.f32()), - arg_max_token_ids: lambda: T.memref(DYN, T.i32()), - arg_bias: lambda: T.memref(DYN, T.f32()), - tokens_in: lambda: T.index(), - inter_in: lambda: T.index(), - k_in: lambda: T.index(), - size_expert_ids_in: lambda: T.index(), - ): - bdx = 256 - gx = 2 * inter_in / arith.index(tile_n) - # Use host-provided upper bound for M blocks (same as aiter moe_sorting allocation). - # This avoids device->host sync on num_valid_ids. - gy = size_expert_ids_in - flir.gpu_ext.LaunchFuncOp( - [module_name, "moe_gemm1"], - grid_size=(gx, gy, 1), - block_size=(bdx, 1, 1), - kernel_operands=[ - arg_out, - arg_x, - arg_w, - arg_scale_x, - arg_scale_w, - arg_sorted_token_ids, - arg_expert_ids, - arg_sorted_weights, - arg_max_token_ids, - arg_bias, - tokens_in, - inter_in, - k_in, - size_expert_ids_in, - ], - ) - - m = _MOE1() - exe = flydsl.compile(m) - return exe - - -def compile_moe_gemm1_dispatch( - *, - model_dim: int, - inter_dim: int, - experts: int, - topk: int, - tile_m: int, - tile_n: int, - tile_k: int, - doweight_stage2: bool, - x_dtype: str = "fp8", - w_dtype: str = "fp8", - out_dtype: str = "f16", - use_cshuffle_epilog: bool | None = None, - gate_up_interleave: bool = False, -): - # Compile based on mode - if not gate_up_interleave: - return compile_moe_gemm1( - model_dim=model_dim, - inter_dim=inter_dim, - experts=experts, - topk=topk, - tile_m=tile_m, - tile_n=tile_n, - tile_k=tile_k, - doweight_stage2=doweight_stage2, - x_dtype=x_dtype, - w_dtype=w_dtype, - out_dtype=out_dtype, - use_cshuffle_epilog=use_cshuffle_epilog, - ) - else: - return compile_gate_up_moe_gemm1( - model_dim=model_dim, - inter_dim=inter_dim, - experts=experts, - topk=topk, - tile_m=tile_m, - tile_n=tile_n, - tile_k=tile_k, - doweight_stage2=doweight_stage2, - x_dtype=x_dtype, - w_dtype=w_dtype, - out_dtype=out_dtype, - use_cshuffle_epilog=use_cshuffle_epilog, - ) - -####==================== gemm1 pipeline end =====================### - - -####==================== gemm2 pipeline start =====================### - -@functools.lru_cache(maxsize=1024) -def compile_moe_gemm2( - *, - model_dim: int, - inter_dim: int, - experts: int, - topk: int, - tile_m: int, - tile_n: int, - tile_k: int, - doweight_stage2: bool, - x_dtype: str = "fp8", - w_dtype: str = "fp8", - out_dtype: str = "f16", - use_cshuffle_epilog: bool | None = None, - # Optional experiment: write per-(token,slot) output (no atomics) into an output shaped - # [tokens*topk, model_dim] (or [tokens, topk, model_dim] flattened), then reduce over topk outside. - # This can reduce atomic contention for small tokens at the cost of extra bandwidth / reduction. - accumulate: bool = True, ): """Compile stage2 kernel (`moe_gemm2`) and return the compiled executable. diff --git a/tests/kernels/test_moe_gemm.py b/tests/kernels/test_moe_gemm.py index 555e5be9..36e27474 100644 --- a/tests/kernels/test_moe_gemm.py +++ b/tests/kernels/test_moe_gemm.py @@ -23,6 +23,7 @@ from tests.utils import pertoken_quant, shuffle_weight from tests.test_common import verify_output, run_perftest from flydsl.runtime.device import get_rocm_arch +from tests.kernels.utils import fp4_utils ARCH = get_rocm_arch() # GFX950 (MI350) and newer typically use OCP standard float8_e4m3fn @@ -387,39 +388,53 @@ def run_moe_stage1( x_q, scale_x = pertoken_quant(x_fp32, quant_dtype=torch.int8) w1_q, scale_w1 = pertoken_quant(w1_fp32, quant_dtype=torch.int8, dtypeMax=7) w2_q, _scale_w2_unused = pertoken_quant(w2_fp32, quant_dtype=torch.int8, dtypeMax=7) + elif x_dtype == "fp8" and w_dtype == "fp4": + x_q = x_fp32.to(DTYPE_FP8) + scale_x = torch.ones([tokens, model_dim // 32], dtype=fp4_utils.fp8_e8m0, device=device) + w1_q, scale_w1, w1_convert = fp4_utils.per_1x32_f4_quant(w1_fp32) # (E, 2*inter, K) + w2_q, scale_w2, w2_convert = fp4_utils.per_1x32_f4_quant(w2_fp32) # (E, model_dim, inter) else: raise ValueError(f"Invalid combination of x_dtype and w_dtype: {x_dtype!r}, {w_dtype!r}") # Preshuffle weights (aiter/CK layout) on the *unpacked* tensor. - w1_shuffled = shuffle_weight(w1_q) - w2_shuffled = shuffle_weight(w2_q) if w_dtype == "fp8" else None + if w_dtype != "fp4": + w1_shuffled = shuffle_weight(w1_q) + w2_shuffled = shuffle_weight(w2_q) if w_dtype == "fp8" else None + # Flatten W1 for our flir kernel (treat expert dim as part of N). + w1_shuffled_flat = w1_shuffled.view(experts * (2 * inter_dim), model_dim) + w1_q_flat = w1_q.view(experts * (2 * inter_dim), model_dim) + scale_w1_flat = None if scale_w1 is None else scale_w1.view(experts * (2 * inter_dim), 1) + w_kernel = ( + _pack_shuffled_int8_to_packed_int4_no_perm(w1_shuffled_flat) if is_int4 else w1_shuffled_flat + ).contiguous() + if not is_int4: + w_kernel = w_kernel.view(experts * (2 * inter_dim), model_dim) + if scale_w1_flat is None: + scale_w1_1d = torch.empty((0,), device=device, dtype=torch.float32) + else: + scale_w1_1d = scale_w1_flat.view(-1).contiguous() # [rows] + + else: + w1_shuffled = fp4_utils.shuffle_weight_w4(w1_q, 16, True, True) + w2_shuffled = fp4_utils.shuffle_weight_w4(w2_q, 16, True, True) + scale_w1_shuffled = fp4_utils.shuffle_scale_w4(scale_w1, experts, True) + scale_w1_1d = scale_w1_shuffled + w_kernel = w1_shuffled - # Flatten W1 for our flir kernel (treat expert dim as part of N). - w1_shuffled_flat = w1_shuffled.view(experts * (2 * inter_dim), model_dim) - w1_q_flat = w1_q.view(experts * (2 * inter_dim), model_dim) - scale_w1_flat = None if scale_w1 is None else scale_w1.view(experts * (2 * inter_dim), 1) # No host-side padding: keep tensors contiguous and rely on kernel-side resource sizes / early-exit. x_q = x_q.contiguous().view(tokens, model_dim) - w_kernel = ( - _pack_shuffled_int8_to_packed_int4_no_perm(w1_shuffled_flat) if is_int4 else w1_shuffled_flat - ).contiguous() - if not is_int4: - w_kernel = w_kernel.view(experts * (2 * inter_dim), model_dim) - # Flatten scales to 1D memrefs (fp16 path uses 0-sized scale tensors; kernel ignores them). if scale_x is None: scale_x_1d = torch.empty((0,), device=device, dtype=torch.float32) else: scale_x_1d = scale_x.view(-1).contiguous() # [tokens] - if scale_w1_flat is None: - scale_w1_1d = torch.empty((0,), device=device, dtype=torch.float32) - else: - scale_w1_1d = scale_w1_flat.view(-1).contiguous() # [rows] + sorted_weights_1d = sorted_weights.contiguous().view(-1) # [sorted_size] # Output: [tokens, topk, inter_dim] fp16 - out = torch.empty((tokens, topk, inter_dim), device=device, dtype=torch.float16) + out_torch_dtype = DTYPE_FP8 if out_dtype == "fp8" else torch.float16 + out = torch.empty((tokens, topk, inter_dim), device=device, dtype=out_torch_dtype) exe = compile_moe_gemm1( model_dim=model_dim, @@ -469,11 +484,21 @@ def launch(o, x, w, sx, sw, st, eids, sw_sorted): torch.cuda.synchronize() if not bool(skip_ref): + if w_dtype != "fp4": + x_ref = x_q + w1_flat_ref = w1_q_flat + scale_x_ref = scale_x + scale_w1_flat_ref = scale_w1_flat + else: + x_ref = x_f32 + w1_flat_ref = w1_f32 + scale_x_ref = None + scale_w1_flat_ref = None ref = torch_moe_gemm1( - x_q, - w1_q_flat, - scale_x, - scale_w1_flat, + x_ref, + w1_flat_ref, + scale_x_ref, + scale_w1_flat_ref, topk_ids.to(torch.int64), topk_weights, inter_dim=inter_dim, @@ -741,12 +766,23 @@ def run_moe_stage2( x_q, scale_x = pertoken_quant(x_fp32, quant_dtype=torch.int8) w1_q, scale_w1 = pertoken_quant(w1_fp32, quant_dtype=torch.int8, dtypeMax=7) w2_q, scale_w2 = pertoken_quant(w2_fp32, quant_dtype=torch.int8, dtypeMax=7) + elif x_dtype == "fp8" and w_dtype == "fp4": + x_q = x_fp32.to(DTYPE_FP8) + scale_x = torch.ones([tokens, model_dim // 32], dtype=fp4_utils.fp8_e8m0, device=device) + w1_q, scale_w1, w1_convert = fp4_utils.per_1x32_f4_quant(w1_fp32) # (E, 2*inter, K) + w2_q, scale_w2, w2_convert = fp4_utils.per_1x32_f4_quant(w2_fp32) # (E, model_dim, inter) else: raise ValueError(f"Invalid combination of x_dtype and w_dtype: {x_dtype!r}, {w_dtype!r}") # Preshuffle weights (aiter/CK layout) on the *unpacked* tensor. - w1_shuffled = shuffle_weight(w1_q) - w2_shuffled = shuffle_weight(w2_q) + if w_dtype == "fp4": + w1_shuffled = fp4_utils.shuffle_weight_w4(w1_q, 16, True, True) + w2_shuffled = fp4_utils.shuffle_weight_w4(w2_q, 16, True, True) + scale_w1_shuffled = fp4_utils.shuffle_scale_w4(scale_w1, experts, True) + scale_w2_shuffled = fp4_utils.shuffle_scale_w4(scale_w2, experts, True) + else: + w1_shuffled = shuffle_weight(w1_q) + w2_shuffled = shuffle_weight(w2_q) # Stage2 input (A2): either provided (gemm1->quantize chaining) or built from stage1 reference. if a2_fp8_in is not None and (a2_scale_in is not None or (x_dtype == "fp16" and w_dtype == "fp16")): @@ -762,10 +798,10 @@ def run_moe_stage2( "(so we don't have to run the huge torch reference stage1)." ) out1_ref = torch_moe_gemm1( - x_q, - w1_q_flat, - scale_x, - scale_w1_flat, + x_fp32_in, + w1_fp32_in, + None, + None, topk_ids.to(torch.int64), topk_weights, inter_dim=inter_dim, @@ -877,11 +913,21 @@ def launch(o, x, w, sx, sw, st, eids, sw_sorted): torch.cuda.synchronize() if not bool(skip_ref): + if w_dtype != "fp4": + a2_ref = a2_q + w2_ref = w2_q + a2_scale_ref = a2_scale + scale_w2_ref = scale_w2 + else: + a2_ref = a2_q.to(torch.float32) + w2_ref = w2_fp32_in + a2_scale_ref = None + scale_w2_ref = None ref2 = torch_moe_gemm2( - a2_q, - w2_q, - a2_scale, - scale_w2, + a2_ref, + w2_ref, + a2_scale_ref, + scale_w2_ref, topk_ids.to(torch.int64), topk_weights, model_dim=model_dim, @@ -1108,6 +1154,12 @@ def test_moe_gemm_2stage( elif x_dtype == "fp16" and w_dtype == "fp16": a2_q = out1_fp16 a2_scale = None + elif x_dtype == "fp8" and w_dtype == "fp4": + if out1_fp16.dtype == torch.float16: + a2_q = out1_fp16.to(DTYPE_FP8) + else: + a2_q = out1_fp16 + a2_scale = torch.ones([tokens, topk, inter_dim // 32], dtype=fp4_utils.fp8_e8m0, device=device) else: out1_fp32 = out1_fp16.to(torch.float32) a2_q, a2_scale = pertoken_quant(out1_fp32, quant_dtype=torch.int8) From 4f7ae348aee3d7231156ae9ace0951c3d02378dc Mon Sep 17 00:00:00 2001 From: Zzz9990 Date: Thu, 5 Feb 2026 20:36:47 -0600 Subject: [PATCH 07/11] combine finished --- kernels/mfma_preshuffle_pipeline.py | 10 +- kernels/moe_gemm_2stage.py | 228 +++++++++++++++++----------- tests/kernels/test_moe_gemm.py | 105 +++++++++---- tests/kernels/test_ref.py | 19 ++- 4 files changed, 243 insertions(+), 119 deletions(-) diff --git a/kernels/mfma_preshuffle_pipeline.py b/kernels/mfma_preshuffle_pipeline.py index a4b9cb22..59a05a1b 100644 --- a/kernels/mfma_preshuffle_pipeline.py +++ b/kernels/mfma_preshuffle_pipeline.py @@ -34,6 +34,7 @@ class EpilogPipeline(Enum): CSHUFFLE_F16 = "CSHUFFLE_F16" CSHUFFLE_BF16 = "CSHUFFLE_BF16" CSHUFFLE_F32 = "CSHUFFLE_F32" + CSHUFFLE_F8 = "CSHUFFLE_F8" DIRECT_F16 = "DIRECT_F16" DIRECT_BF16 = "DIRECT_BF16" DIRECT_F32 = "DIRECT_F32" @@ -73,6 +74,7 @@ class EpilogPipeline(Enum): EpilogPipeline.CSHUFFLE_F16: lambda: T.f16, EpilogPipeline.CSHUFFLE_BF16: lambda: T.bf16, EpilogPipeline.CSHUFFLE_F32: lambda: T.f32, + EpilogPipeline.CSHUFFLE_F8: lambda: T.f8, EpilogPipeline.DIRECT_F16: lambda: T.f16, EpilogPipeline.DIRECT_BF16: lambda: T.bf16, EpilogPipeline.DIRECT_F32: lambda: T.f32, @@ -185,7 +187,7 @@ def _normalize_dtype(value: str) -> str: self.b_dtype = _normalize_dtype(self.b_dtype) self.out_dtype = _normalize_dtype(self.out_dtype) - if self.out_dtype not in ("fp16", "bf16", "f32"): + if self.out_dtype not in ("fp16", "bf16", "f32", "fp8"): raise ValueError( f"out_dtype must be 'f16', 'bf16', or 'f32', got {self.out_dtype!r}" ) @@ -195,7 +197,7 @@ def check_type_valid(self): raise ValueError(f"Invalid a_dtype: {self.a_dtype}") if self.b_dtype not in ["fp8", "fp4", "int8", "int4", "fp16", "bf16"]: raise ValueError(f"Invalid b_dtype: {self.b_dtype}") - if self.out_dtype not in ["fp16", "bf16", "f32"]: + if self.out_dtype not in ["fp16", "bf16", "f32", "fp8"]: raise ValueError(f"Invalid out_dtype: {self.out_dtype}") def get_mfma_pipeline(self): @@ -223,6 +225,8 @@ def get_epilog_pipeline(self): return EpilogPipeline.CSHUFFLE_BF16 elif self.use_cshuffle_epilog and self.out_dtype == "f32": return EpilogPipeline.CSHUFFLE_F32 + elif self.use_cshuffle_epilog and self.out_dtype == "fp8": + return EpilogPipeline.CSHUFFLE_F8 elif not self.use_cshuffle_epilog and self.out_dtype == "f32": return EpilogPipeline.DIRECT_F32 elif not self.use_cshuffle_epilog and self.out_dtype == "fp16": @@ -249,6 +253,8 @@ def get_a_elem_bytes(self): raise ValueError(f"Invalid a_dtype: {self.a_dtype}") def get_out_elem_bytes(self): + if self.out_dtype in ["fp8", "int8", "int4", "fp4"]: + return 1 if self.out_dtype in ["fp16", "bf16"]: return 2 elif self.out_dtype == "f32": diff --git a/kernels/moe_gemm_2stage.py b/kernels/moe_gemm_2stage.py index 573cc116..98d71ca0 100644 --- a/kernels/moe_gemm_2stage.py +++ b/kernels/moe_gemm_2stage.py @@ -77,7 +77,7 @@ def compile_moe_gemm1( total_threads = 256 - pipeline_manager = PreshufflePipelineManager(x_dtype, w_dtype, out_dtype) + pipeline_manager = PreshufflePipelineManager(x_dtype, w_dtype, out_dtype, use_cshuffle_epilog) pipeline_manager.check_type_valid() epilog_pipeline = pipeline_manager.get_epilog_pipeline() @@ -126,6 +126,7 @@ def _mfma_output_pack_ty(): is_int8 = mfma_pipeline in [MfmaPipeline.I8I8_16x16_PIPELINE] is_fp4 = mfma_pipeline in [MfmaPipeline.F4F4_MXFP4_PIPELINE, MfmaPipeline.F8F4_MXFP4_PIPELINE] + is_fp8_a = mfma_pipeline == MfmaPipeline.F8F4_MXFP4_PIPELINE is_int_mode = is_int8 or is_int4 @@ -136,14 +137,15 @@ def _mfma_output_pack_ty(): pack_K = 2 if is_fp4 else 1 # FP4 specific parameters for mfma_scale_f32_16x16x128_f8f6f4 (gemm1) - cbsz_g1 = 0 if mfma_pipeline == MfmaPipeline.F4F4_MXFP4_PIPELINE else 4 # fp8 a: cbsz=0, fp4 a: cbsz=4 + # cbsz encodes A matrix type: 0 for fp8, 4 for fp4 + cbsz_g1 = 0 if mfma_pipeline == MfmaPipeline.F8F4_MXFP4_PIPELINE else 4 blgp_g1 = 4 is_gate_up_inter = is_fp4 - is_cast_out = out_dtype == "f8" + is_cast_out = out_dtype == "fp8" if is_cast_out and not use_cshuffle_epilog: - raise ValueError("out_dtype='f' requires CShuffle epilogue (set use_cshuffle_epilog=True).") + raise ValueError("out_dtype='fp8' requires CShuffle epilogue (set use_cshuffle_epilog=True).") DYN = ir.ShapedType.get_dynamic_size() size_out = DYN @@ -177,8 +179,8 @@ def _mfma_output_pack_ty(): use_cshuffle_epilog = os.environ.get("FLIR_MOE_STAGE1_CSHUFFLE", "1") in ("1", "true", "True", "YES", "yes") use_cshuffle_epilog = bool(use_cshuffle_epilog) - if out_dtype != "f16" and use_cshuffle_epilog: - raise ValueError("stage1 cshuffle epilog currently supports only f16 output (out_dtype='f16')") + # if out_dtype != "f16" and use_cshuffle_epilog: + # raise ValueError("stage1 cshuffle epilog currently supports only f16 output (out_dtype='f16')") epilog_tag = "cshuffle" if use_cshuffle_epilog else "direct" # IMPORTANT: module name participates in FlyDSL's compile cache key. @@ -201,7 +203,8 @@ def init_gpu_module(self): # - epilogue CShuffle tile (tile_m * tile_n f16 -> 2 * tile_m * tile_n bytes) _use_cshuffle_epilog = bool(use_cshuffle_epilog) lds_x_bytes = 2 * int(tile_m) * int(lds_stride) * int(x_elem_bytes) - lds_out_bytes = 2 * tile_m * tile_n if _use_cshuffle_epilog else 0 + lds_out_elem_bytes = 2 if out_dtype in ["fp16", "bf16"] else 4 + lds_out_bytes = lds_out_elem_bytes * tile_m * tile_n if _use_cshuffle_epilog else 0 lds_total_bytes = max(lds_x_bytes, lds_out_bytes) lds_total_elems = lds_total_bytes if x_elem_bytes == 1 else (lds_total_bytes // 2) # x_lds_elem = I.f16 if is_f16 else (I.i8 if is_int8 else I.f8) @@ -221,6 +224,7 @@ def moe_gemm1( arg_expert_ids: lambda: T.memref(DYN, T.i32()), arg_sorted_weights: lambda: T.memref(DYN, T.f32()), arg_max_token_ids: lambda: T.memref(DYN, T.i32()), + arg_bias: lambda: T.memref(DYN, T.f32()), tokens_in: lambda: T.index(), inter_in: lambda: T.index(), k_in: lambda: T.index(), @@ -334,7 +338,7 @@ def swiglu(gate, up, alpha=1.702, limit=7.0): layout_tx_wave_lane = flir.make_layout((4, 64), stride=(64, 1)) layout_lane16 = flir.make_layout((4, 16), stride=(16, 1)) - # bias_rsrc = buffer_ops.create_buffer_resource(arg_bias, max_size=False) if enable_bias else None + bias_rsrc = buffer_ops.create_buffer_resource(arg_bias, max_size=False) if enable_bias else None # Everything below is gated by `blk_valid` to avoid doing buffer-resource setup and # gmem work for padding blocks. @@ -346,7 +350,7 @@ def swiglu(gate, up, alpha=1.702, limit=7.0): # Alias LDS bytes as fp16 for optional CShuffle epilogue. _use_cshuffle_epilog = bool(use_cshuffle_epilog) lds_out = ( - SmemPtr(base_ptr, lds_x_ptr.byte_offset, I.f16, shape=(tile_m * tile_n,)).get() + SmemPtr(base_ptr, lds_x_ptr.byte_offset, I.f16 if not is_cast_out else I.f32, shape=(tile_m * tile_n,)).get() if _use_cshuffle_epilog else None ) @@ -371,12 +375,17 @@ def swiglu(gate, up, alpha=1.702, limit=7.0): arg_out, max_size=False, num_records_bytes=out_nbytes_i32 ) - # fp16 path ignores scales completely (implicit scale=1.0). - if no_epilogue_dequant: + # fp16/bf16 path ignores scales completely (implicit scale=1.0). + # FP4 path uses scales in mfma_scale instruction, so we still need scale resources. + if is_f16_or_bf16: sx_rsrc = None sw_rsrc = None + elif is_fp4: + # FP4: scale is used in mfma_scale_f32_16x16x128_f8f6f4 instruction + sx_rsrc = buffer_ops.create_buffer_resource(arg_scale_x, max_size=False) + sw_rsrc = buffer_ops.create_buffer_resource(arg_scale_w, max_size=False) else: - # scale_x: [tokens] f32 -> bytes = tokens*4 + # fp8/int8 path: scale_x: [tokens] f32 -> bytes = tokens*4 sx_nbytes_idx = tokens_in * arith.constant(4, index=True) sx_nbytes_i32 = arith.index_cast(i32, sx_nbytes_idx) sx_rsrc = buffer_ops.create_buffer_resource( @@ -548,6 +557,9 @@ def load_x_tile(base_k): n_blk_gate = [] n_intra_up = [] n_blk_up = [] + # Single n_blk_list for interleaved mode (matching mixed_moe_gemm_2stage.py) + n_blk_list = [] + n_intra_list = [] col_g_list = [] inter_idx = arith.constant(inter_dim, index=True) # layout for (row -> (blk,intra)) where intra is 0..15 @@ -556,10 +568,23 @@ def load_x_tile(base_k): for ni in range_constexpr(num_acc_n): offset = arith.constant(ni * 16, index=True) col_g = by_n + n_tile_base - col_g = col_g + offset + col_g = col_g // 2 + offset if is_gate_up_inter else col_g + offset col_g = col_g + lane_mod_16 col_g_list.append(col_g) + # For both interleaved and non-interleaved layout, global_n maps to physical layout + global_n = by_n + n_tile_base + offset + lane_mod_16 + row_w = expert_off_idx + global_n + + coord_n = flir.idx2crd(row_w, layout_n_blk_intra) + n_blk = flir.get(coord_n, 0) + n_intra = flir.get(coord_n, 1) + + # For interleaved mode, use single list + n_blk_list.append(n_blk) + n_intra_list.append(n_intra) + + # For non-interleaved mode, keep separate gate/up lists row_gate = expert_off_idx + col_g row_up = row_gate + inter_idx @@ -576,9 +601,9 @@ def load_x_tile(base_k): k_unroll = tile_k_bytes // 64 // pack_K # FP4 packed parameters for mfma_scale (gemm1) - k_unroll_packed_g1 = k_unroll // pack_K - m_repeat_packed_g1 = m_repeat // pack_M - num_acc_n_packed_g1 = num_acc_n // pack_N + k_unroll_packed = k_unroll // pack_K + m_repeat_packed = m_repeat // pack_M + num_acc_n_packed = num_acc_n // pack_N # --- B Load Logic (K64) - shared layout with preshuffle GEMM --- def load_b_pack(base_k, ki_step, ni, blk_list, intra_list): @@ -629,7 +654,7 @@ def load_scale_inter(arg_scale, rsrc, layout, ku, mni): coord_pack = flir.make_coord(mni, ku, k_lane, n_lane) idx_pack = flir.crd2idx(coord_pack, layout) scale_view = flir.TensorView( - arg_scale_w, + arg_scale, (1,), strides=(1,), base_indices=(idx_pack,), @@ -649,9 +674,9 @@ def load_scale_inter(arg_scale, rsrc, layout, ku, mni): def load_b_scale_tile_inter(base_k): """Load B scale tile for FP4 pipeline (gemm1).""" b_scale_tile = [] - for ku in range_constexpr(k_unroll_packed_g1): - for ni in range_constexpr(num_acc_n_packed_g1): - scale = load_scale_fp4_g1( + for ku in range_constexpr(k_unroll_packed): + for ni in range_constexpr(num_acc_n_packed): + scale = load_scale_inter( arg_scale_w, sw_rsrc, layout_b_scale, @@ -663,7 +688,7 @@ def load_b_scale_tile_inter(base_k): def prefetch_ab_scale_tile_inter(base_k): """Prefetch A and B scale tiles for FP4 pipeline (gemm1).""" - return placeholder, load_b_scale_tile_fp4_g1(base_k) + return placeholder, load_b_scale_tile_inter(base_k) # --- FP4 B Load Logic for gemm1 (interleaved gate+up) --- def load_b_packs_k64_inter(base_k, ku: int, ni: int): @@ -672,8 +697,8 @@ def load_b_packs_k64_inter(base_k, ku: int, ni: int): k0_base = base_k_bytes / 64 k0 = k0_base + ku k1 = lane_div_16 - # For interleaved gate+up, use col_g_list which contains both gate and up columns - coord_pack = flir.make_coord(n_blk_gate[ni // 2] if ni % 2 == 0 else n_blk_up[ni // 2], k0, k1, n_intra_gate[ni // 2] if ni % 2 == 0 else n_intra_up[ni // 2], 0) + # Use n_blk_list directly for interleaved layout + coord_pack = flir.make_coord(n_blk_list[ni], k0, k1, n_intra_list[ni], 0) idx_pack = flir.crd2idx(coord_pack, layout_b) vec_elems = 16 @@ -705,7 +730,7 @@ def load_b_tile_inter(base_k): for ku in range_constexpr(k_unroll): packs0 = [] packs1 = [] - for ni in range_constexpr(num_acc_n * 2): # gate+up interleaved + for ni in range_constexpr(num_acc_n): # Match mixed_moe_gemm_2stage.py b0, b1 = load_b_packs_k64_inter(base_k, ku, ni) packs0.append(b0) packs1.append(b1) @@ -889,21 +914,20 @@ def compute_tile_inter( mfma_res_ty = _mfma_output_pack_ty() epilogue_pf = None - # if enable_bias and prefetch_epilogue: - # expert_off_idx + by_n + n_tile_base - # gate_bias = [] - # up_bias = [] - # for ni in range_constexpr(num_acc_n_packed): - # global_n = (by_n + n_tile_base) // 2 + ni * 16 + lane_mod_16 - # gate_offset = expert_off_idx + global_n - # up_offset = expert_off_idx + global_n + inter_dim - # gate_bias.append( - # buffer_ops.buffer_load(bias_rsrc, gate_offset, vec_width=1, dtype=f32) - # ) - # up_bias.append( - # buffer_ops.buffer_load(bias_rsrc, up_offset, vec_width=1, dtype=f32) - # ) - # epilogue_pf = (gate_bias, up_bias) + if enable_bias and prefetch_epilogue: + gate_bias = [] + up_bias = [] + for ni in range_constexpr(num_acc_n_packed): + global_n = (by_n + n_tile_base) // 2 + ni * 16 + lane_mod_16 + gate_offset = expert_off_idx + global_n + up_offset = expert_off_idx + global_n + inter_dim + gate_bias.append( + buffer_ops.buffer_load(bias_rsrc, gate_offset, vec_width=1, dtype=f32) + ) + up_bias.append( + buffer_ops.buffer_load(bias_rsrc, up_offset, vec_width=1, dtype=f32) + ) + epilogue_pf = (gate_bias, up_bias) c0_i64 = arith.constant(0, type=I.i64) vec4_i64 = I.vec(4, I.i64) @@ -913,8 +937,6 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): v4 = vector.from_elements(vec4_i64, [x0, x1, x2, x3]) return vector.bitcast(vec8_i32, v4) - col_offset_base_fp4 = flir.crd2idx(flir.make_coord(lane_div_16, 0), layout_lane16) - # FP4 path using mfma_scale_f32_16x16x128_f8f6f4 for ku128 in range_constexpr(k_unroll_packed): for mi in range_constexpr(m_repeat_packed): @@ -926,7 +948,7 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): for ikxdl in range_constexpr(pack_K): k_idx = ku128 * pack_K + ikxdl b_packs0, b_packs1 = b_tile_in[k_idx] - col_base = col_offset_base_fp4 + (k_idx * 128) // a_elem_vec_pack_g1 + col_base = col_offset_base_bytes + (k_idx * 128) // x_elem_pack for imxdl in range_constexpr(pack_M): mi_idx = mi * pack_M + imxdl @@ -938,7 +960,7 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): else: a0, a1 = lds_load_packs_k64(curr_row_a_lds, col_base, lds_base) - if is_f8_a: + if is_fp8_a: col_base1 = col_base + 64 a2, a3 = lds_load_packs_k64(curr_row_a_lds, col_base1, lds_base) a128 = pack_i64x4_to_i32x8(a0, a1, a2, a3) @@ -1169,6 +1191,8 @@ def _prefetch_scale(k_val): # Uses EVec=4 (buffer store "x4" of fp16 elements). _use_cshuffle_epilog = bool(use_cshuffle_epilog) + + num_acc_n_final = num_acc_n // 2 if is_gate_up_inter else num_acc_n if _use_cshuffle_epilog: if lds_out is None: @@ -1196,10 +1220,10 @@ def write_row_to_lds( if doweight_stage1: tw = buffer_ops.buffer_load(sorted_w_rsrc, row, vec_width=1, dtype=f32) - for ni in range_constexpr(num_acc_n // 2 if is_gate_up_inter else num_acc_n): + for ni in range_constexpr(num_acc_n_final): col_local = col_base_local + (ni * 16) - acc_idx = mi * num_acc_n + ni + acc_idx = mi * num_acc_n_final + ni vg = vector.extract( acc_gate[acc_idx], static_position=[ii], dynamic_position=[] ) @@ -1217,10 +1241,10 @@ def write_row_to_lds( vg = vg * sx * sw_gate vu = vu * sx * sw_up - # if enable_bias: - # gate_bias_list, up_bias_list = epilogue_pf - # vg = vg + gate_bias_list[ni] - # vu = vu + up_bias_list[ni] + if enable_bias: + gate_bias_list, up_bias_list = epilogue_pf + vg = vg + gate_bias_list[ni] + vu = vu + up_bias_list[ni] if act == "swiglu": y = swiglu(vg, vu) @@ -1287,7 +1311,7 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): tile_n=tile_n // 2 if is_gate_up_inter else tile_n, e_vec=4, m_repeat=m_repeat, - num_acc_n=num_acc_n, + num_acc_n=num_acc_n_packed if is_gate_up_inter else num_acc_n, tx=tx, lane_div_16=lane_div_16, lane_mod_16=lane_mod_16, @@ -1295,6 +1319,7 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): by_n=by_n // 2 if is_gate_up_inter else by_n, n_tile_base=n_tile_base // 2 if is_gate_up_inter else n_tile_base, lds_out=lds_out, + frag_elem_type=I.f32 if is_cast_out else I.f16, write_row_to_lds=write_row_to_lds, precompute_row=precompute_row, store_pair=store_pair, @@ -1322,10 +1347,10 @@ def _stage1_store_row(*, mi: int, ii: int, row_in_tile, row): if doweight_stage1: tw = buffer_ops.buffer_load(sorted_w_rsrc, row, vec_width=1, dtype=f32) - for ni in range_constexpr(num_acc_n // 2 if is_gate_up_inter else num_acc_n): + for ni in range_constexpr(num_acc_n_final): col_i32 = col_i32_list[ni] - acc_idx = mi * num_acc_n + ni + acc_idx = mi * num_acc_n_final + ni vg = vector.extract( acc_gate[acc_idx], static_position=[ii], dynamic_position=[] ) @@ -1343,10 +1368,10 @@ def _stage1_store_row(*, mi: int, ii: int, row_in_tile, row): vg = vg * sx * sw_gate vu = vu * sx * sw_up - # if enable_bias: - # gate_bias_list, up_bias_list = epilogue_pf - # vg = vg + gate_bias_list[ni] - # vu = vu + up_bias_list[ni] + if enable_bias: + gate_bias_list, up_bias_list = epilogue_pf + vg = vg + gate_bias_list[ni] + vu = vu + up_bias_list[ni] if act == "swiglu": y = swiglu(vg, vu) @@ -1381,16 +1406,17 @@ def __call__( arg_expert_ids: lambda: T.memref(DYN, T.i32()), arg_sorted_weights: lambda: T.memref(DYN, T.f32()), arg_max_token_ids: lambda: T.memref(DYN, T.i32()), + arg_bias: lambda: T.memref(DYN, T.f32()), tokens_in: lambda: T.index(), inter_in: lambda: T.index(), k_in: lambda: T.index(), size_expert_ids_in: lambda: T.index(), ): bdx = 256 - gx = inter_in / arith.index(tile_n) + gx = (arith.index(2) * inter_in / arith.index(tile_n)) if is_gate_up_inter else (inter_in / arith.index(tile_n)) # Use host-provided upper bound for M blocks (same as aiter moe_sorting allocation). # This avoids device->host sync on num_valid_ids. - gy = size_expert_ids_in * 2 if is_gate_up_inter else size_expert_ids_in + gy = size_expert_ids_in flir.gpu_ext.LaunchFuncOp( [module_name, "moe_gemm1"], grid_size=(gx, gy, 1), @@ -1405,6 +1431,7 @@ def __call__( arg_expert_ids, arg_sorted_weights, arg_max_token_ids, + arg_bias, tokens_in, inter_in, k_in, @@ -1467,7 +1494,7 @@ def compile_moe_gemm2( _state = {} - pipeline_manager = PreshufflePipelineManager(x_dtype, w_dtype, out_dtype) + pipeline_manager = PreshufflePipelineManager(x_dtype, w_dtype, out_dtype, use_cshuffle_epilog) pipeline_manager.check_type_valid() epilog_pipeline = pipeline_manager.get_epilog_pipeline() @@ -1524,7 +1551,8 @@ def _mfma_output_pack_ty(): pack_M = 2 if is_fp4 else 1 pack_N = 2 if is_fp4 else 1 pack_K = 2 if is_fp4 else 1 - cbsz = 0 if mfma_pipeline == MfmaPipeline.F8F4_MXFP4_PIPELINE else 4 # fp8 a: cbsz=0, fp4 a: cbsz=4 + is_fp8_a = mfma_pipeline == MfmaPipeline.F8F4_MXFP4_PIPELINE + cbsz = 0 if is_fp8_a else 4 # fp8 a: cbsz=0, fp4 a: cbsz=4 blgp = 4 out_s = pipeline_manager.out_dtype @@ -1633,6 +1661,7 @@ def moe_gemm2( arg_expert_ids: lambda: T.memref(size_expert_ids_shape, T.i32()), arg_sorted_weights: lambda: T.memref(size_sorted, T.f32()), arg_num_valid_ids: lambda: T.memref(DYN, T.i32()), + arg_bias: lambda: T.memref(DYN, T.f32()), tokens_in: lambda: T.index(), n_in: lambda: T.index(), k_in: lambda: T.index(), @@ -1673,9 +1702,12 @@ def moe_gemm2( c_n_total = arith.constant(experts * model_dim, index=True) kpack_bytes = 8 if is_int4 else 16 b_layout = make_preshuffle_b_layout( - flir, arith, c_n=c_n_total, c_k=k_in, kpack_bytes=kpack_bytes, elem_bytes=w_elem_bytes + flir, arith, c_n=c_n_total, c_k=k_in // pack_K, kpack_bytes=kpack_bytes, elem_bytes=w_elem_bytes ) layout_b = b_layout.layout_b + layout_b_scale = make_preshuffle_scale_layout( + flir, arith, c_mn=c_n_total, c_k=k_in, + ) shape_lds = flir.make_shape(tile_m, tile_k) stride_lds = flir.make_stride(lds_stride, 1) @@ -1771,6 +1803,9 @@ def moe_gemm2( arg_sorted_weights, max_size=False, num_records_bytes=sorted_nbytes_i32 ) + # bias resource for optional bias addition + bias_rsrc = buffer_ops.create_buffer_resource(arg_bias, max_size=False) if enable_bias else None + # expert ids: [blocks] i32 -> bytes = size_expert_ids_in*4 eid_nbytes_idx = size_expert_ids_in * arith.constant(4, index=True) eid_nbytes_i32 = arith.index_cast(i32, eid_nbytes_idx) @@ -1954,7 +1989,7 @@ def load_x_tile(base_k): n_intra_list.append(flir.get(coord_w, 1)) m_repeat = tile_m // 16 - k_unroll = tile_k_bytes // 64 # K64-byte micro-step (2x MFMA) + k_unroll = tile_k_bytes // 64 if not is_fp4 else tile_k_bytes // 128 # K64-byte micro-step (2x MFMA) # FP4 packed parameters for mfma_scale_f32_16x16x128_f8f6f4 k_unroll_packed = k_unroll // pack_K @@ -2030,14 +2065,14 @@ def load_scale(arg_scale, rsrc, layout, ku, mni): def load_b_scale_tile(base_k): """Load B scale tile for FP4 pipeline (gemm2).""" b_scale_tile = [] - for ku in range_constexpr(k_unroll_packed_g2): - for ni in range_constexpr(num_acc_n_packed_g2): + for ku in range_constexpr(k_unroll_packed): + for ni in range_constexpr(num_acc_n_packed): scale = load_scale( arg_scale_w, sw_rsrc, layout_b_scale, ku + base_k, - ni + (expert_off_idx + by_n + n_tile_base) // pack_N_gemm2 // 16, + ni + (expert_off_idx + by_n + n_tile_base) // pack_N // 16, ) b_scale_tile.append(scale) return b_scale_tile @@ -2045,14 +2080,14 @@ def load_b_scale_tile(base_k): def load_a_scale_tile(base_k): """Load A scale tile for FP4 pipeline (gemm2).""" a_scale_tile = [] - for ku in range_constexpr(k_unroll_packed_g2): - for mi in range_constexpr(m_repeat_packed_g2): - scale = load_scale_fp4_g2( + for ku in range_constexpr(k_unroll_packed): + for mi in range_constexpr(m_repeat_packed): + scale = load_scale( arg_scale_x, sx_rsrc, layout_a_scale, ku + base_k, - mi + bx_m // pack_M_gemm2 // 16, + mi + bx_m // pack_M // 16, ) a_scale_tile.append(scale) return a_scale_tile @@ -2167,7 +2202,17 @@ def compute_tile(acc_in, b_tile_in, lds_base, *, a_scale=None, b_scale=None, pre sorted_w_rsrc, sorted_row_pf, vec_width=1, dtype=f32 ) ) - epilogue_pf = (sw_pf, tw_pf) + # Prefetch bias values when enabled + bias_pf = None + if enable_bias: + bias_pf = [] + for ni in range_constexpr(num_acc_n): + global_n = by_n + n_tile_base + ni * 16 + lane_mod_16 + bias_offset = expert_off_pf + global_n + bias_pf.append( + buffer_ops.buffer_load(bias_rsrc, bias_offset, vec_width=1, dtype=f32) + ) + epilogue_pf = (sw_pf, tw_pf, bias_pf) if is_fp4: c0_i64 = arith.constant(0, type=I.i64) @@ -2179,19 +2224,19 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): return vector.bitcast(vec8_i32, v4) # FP4 path using mfma_scale_f32_16x16x128_f8f6f4 - for ku128 in range_constexpr(k_unroll_packed_g2): - for mi in range_constexpr(m_repeat_packed_g2): - for ni in range_constexpr(num_acc_n_packed_g2): - b_scale_i32 = b_scale[ku128 * num_acc_n_packed_g2 + ni] + for ku128 in range_constexpr(k_unroll_packed): + for mi in range_constexpr(m_repeat_packed): + for ni in range_constexpr(num_acc_n_packed): + b_scale_i32 = b_scale[ku128 * num_acc_n_packed + ni] b_scale_val = vector.extract(b_scale_i32, static_position=[0], dynamic_position=[]) - for ikxdl in range_constexpr(pack_K_gemm2): - k_idx = ku128 * pack_K_gemm2 + ikxdl + for ikxdl in range_constexpr(pack_K): + k_idx = ku128 * pack_K + ikxdl b_packs0, b_packs1 = b_tile_in[k_idx] - col_base = col_offset_base + (k_idx * 128) // a_elem_vec_pack_gemm2 + col_base = col_offset_base + (k_idx * 128) // x_elem_pack - for imxdl in range_constexpr(pack_M_gemm2): + for imxdl in range_constexpr(pack_M): col_base0 = col_base - mi_idx = mi * pack_M_gemm2 + imxdl + mi_idx = mi * pack_M + imxdl mi_val = arith.constant(mi_idx * 16, index=True) curr_row_a_lds = row_a_lds + mi_val @@ -2200,15 +2245,15 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): else: a0, a1 = lds_load_packs_k64(curr_row_a_lds, col_base0, lds_base) - if is_f8_a_gemm2: + if is_fp8_a: col_base1 = col_base + 64 a2, a3 = lds_load_packs_k64(curr_row_a_lds, col_base1, lds_base) a128 = pack_i64x4_to_i32x8(a0, a1, a2, a3) else: a128 = pack_i64x4_to_i32x8(a0, a1, c0_i64, c0_i64) - for inxdl in range_constexpr(pack_N_gemm2): - ni_idx = ni * pack_N_gemm2 + inxdl + for inxdl in range_constexpr(pack_N): + ni_idx = ni * pack_N + inxdl b0 = b_packs0[ni_idx] b1 = b_packs1[ni_idx] @@ -2222,12 +2267,12 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): a128, b128, acc_list[acc_idx], - cbsz_gemm2, - blgp_gemm2, + cbsz, + blgp, # use per tensor quant a1 for now, 0, 0x3F800000, - ikxdl * pack_N_gemm2 + inxdl, + ikxdl * pack_N + inxdl, b_scale_val, ], ) @@ -2363,7 +2408,7 @@ def hot_loop_scheduler(): # Select pipeline functions based on is_fp4 def _prefetch_scale(k_val): if is_fp4: - return prefetch_ab_scale_tile_fp4_g2(k_val // pack_K // 128) + return prefetch_ab_scale_tile(k_val // pack_K // 128) return placeholder, placeholder # Prologue. @@ -2486,8 +2531,9 @@ def atomic_add_f16x2(val_f16x2, byte_off_i32): sw_pf = None tw_pf = None + bias_pf = None if epilogue_pf is not None: - sw_pf, tw_pf = epilogue_pf + sw_pf, tw_pf, bias_pf = epilogue_pf # Weight scales for the N tile (col_g depends on lane/wave/by but not on (t,s)). if sw_pf is not None: @@ -2538,6 +2584,8 @@ def _stage2_row_atomic(*, mi: int, ii: int, row_in_tile, row): if not no_epilogue_dequant: sw = sw_vals[ni] v = v * sx * sw + if enable_bias: + v = v + bias_pf[ni] if doweight_stage2: v = v * tw @@ -2603,6 +2651,8 @@ def write_row_to_lds( if not no_epilogue_dequant: sw = sw_vals[ni] v = v * sx * sw + if enable_bias: + v = v + bias_pf[ni] if doweight_stage2: v = v * tw v_out = arith.trunc_f(_out_elem_type(), v) @@ -2699,6 +2749,7 @@ def __call__( arg_expert_ids: lambda: T.memref(size_expert_ids_shape, T.i32()), arg_sorted_weights: lambda: T.memref(size_sorted, T.f32()), arg_num_valid_ids: lambda: T.memref(DYN, T.i32()), + arg_bias: lambda: T.memref(DYN, T.f32()), tokens_in: lambda: T.index(), n_in: lambda: T.index(), k_in: lambda: T.index(), @@ -2721,6 +2772,7 @@ def __call__( arg_expert_ids, arg_sorted_weights, arg_num_valid_ids, + arg_bias, tokens_in, n_in, k_in, @@ -3151,4 +3203,4 @@ def compile_moe_gemm2_ex( accumulate=True, ) -####==================== gemm2 pipeline end =====================### \ No newline at end of file +####==================== gemm2 pipeline end =====================### diff --git a/tests/kernels/test_moe_gemm.py b/tests/kernels/test_moe_gemm.py index 36e27474..cd2ccd3a 100644 --- a/tests/kernels/test_moe_gemm.py +++ b/tests/kernels/test_moe_gemm.py @@ -303,6 +303,8 @@ def run_moe_stage1( routing_in: Optional[RoutingBuffers] = None, return_outputs: bool = False, skip_ref: bool = False, + enable_bias: bool = False, + bias_in: Optional[torch.Tensor] = None, ): assert model_dim % 64 == 0 assert model_dim % tile_k == 0 @@ -361,9 +363,9 @@ def run_moe_stage1( ) = routing if x_dtype not in ("fp8", "fp16", "int8"): - raise ValueError(f"x_dtype must be one of ('fp8','fp16','int8','int4'), got {x_dtype!r}") - if w_dtype not in ("fp8", "fp16", "int8", "int4"): - raise ValueError(f"w_dtype must be one of ('fp8','fp16','int8','int4'), got {w_dtype!r}") + raise ValueError(f"x_dtype must be one of ('fp8','fp16','int8'), got {x_dtype!r}") + if w_dtype not in ("fp8", "fp16", "int8", "int4", "fp4"): + raise ValueError(f"w_dtype must be one of ('fp8','fp16','int8','int4', 'fp4'), got {w_dtype!r}") is_int4 = w_dtype == "int4" is_int8 = x_dtype in ("int8", "int4") @@ -436,6 +438,19 @@ def run_moe_stage1( out_torch_dtype = DTYPE_FP8 if out_dtype == "fp8" else torch.float16 out = torch.empty((tokens, topk, inter_dim), device=device, dtype=out_torch_dtype) + # Bias: [experts, 2 * inter_dim] f32 (gate_bias, up_bias concatenated) + if enable_bias: + bias = ( + bias_in + if bias_in is not None + else torch.randn((experts, 2 * inter_dim), device=device, dtype=torch.float32) * 0.1 + ) + # Flatten bias for kernel: [experts * 2 * inter_dim] + bias_1d = bias.view(-1).contiguous() + else: + bias = None + bias_1d = torch.empty((0,), device=device, dtype=torch.float32) + exe = compile_moe_gemm1( model_dim=model_dim, inter_dim=inter_dim, @@ -449,9 +464,10 @@ def run_moe_stage1( tile_k=tile_k, doweight_stage1=bool(doweight_stage1), use_cshuffle_epilog=False, + enable_bias=enable_bias, ) - def launch(o, x, w, sx, sw, st, eids, sw_sorted): + def launch(o, x, w, sx, sw, st, eids, sw_sorted, bias_arg): exe( o, x, @@ -462,6 +478,7 @@ def launch(o, x, w, sx, sw, st, eids, sw_sorted): eids, sw_sorted, num_valid_ids, + bias_arg, tokens, inter_dim, model_dim, @@ -478,6 +495,7 @@ def launch(o, x, w, sx, sw, st, eids, sw_sorted): sorted_token_ids, sorted_expert_ids, sorted_weights_1d, + bias_1d, num_iters=int(num_iters), num_warmup=int(num_warmup), ) @@ -490,8 +508,8 @@ def launch(o, x, w, sx, sw, st, eids, sw_sorted): scale_x_ref = scale_x scale_w1_flat_ref = scale_w1_flat else: - x_ref = x_f32 - w1_flat_ref = w1_f32 + x_ref = x_fp32_in + w1_flat_ref = w1_fp32_in scale_x_ref = None scale_w1_flat_ref = None ref = torch_moe_gemm1( @@ -503,13 +521,14 @@ def launch(o, x, w, sx, sw, st, eids, sw_sorted): topk_weights, inter_dim=inter_dim, doweight_stage1=doweight_stage1, + bias=bias, ) rtol = 0.5 if is_int4 else 0.25 atol = 0.5 if is_int4 else 0.25 print(out.to(torch.float32)) print(ref) - assert verify_output(out.to(torch.float32), ref, rtol=1e-4, atol=1e-4) + assert verify_output(out.to(torch.float32), ref, rtol=1e-3, atol=1e-3) # Note: kernel launches full expert-block range; effective work is gated by num_valid_ids. flops = 2 * tokens * topk * (2 * inter_dim) * model_dim @@ -648,6 +667,8 @@ def run_moe_stage2( kernel_name: str = "moe_gemm2", # Use reduce mode (accumulate=False) instead of atomic mode. use_reduce: bool = False, + enable_bias: bool = False, + bias_in: Optional[torch.Tensor] = None, ): """MoE stage2 (gemm2): out2[t] = sum_{slot} ( out1[t,slot] @ W2[expert]^T ) with optional routed weight.""" @@ -816,28 +837,34 @@ def run_moe_stage2( a2_q, a2_scale = pertoken_quant(out1_ref, quant_dtype=torch.int8) # Flatten weights/scales for the kernel. - w2_shuffled_flat = w2_shuffled.view(experts * model_dim, inter_dim) - scale_w2_flat = None if scale_w2 is None else scale_w2.view(experts * model_dim, 1) + if w_dtype == "fp4": + # FP4 path: use shuffled weights directly (already packed) + w2_kernel = w2_shuffled + w2_scale_1d = scale_w2_shuffled + else: + w2_shuffled_flat = w2_shuffled.view(experts * model_dim, inter_dim) + scale_w2_flat = None if scale_w2 is None else scale_w2.view(experts * model_dim, 1) + + # For W4A8, pack preshuffled int8 weights into packed int4 bytes. + w2_kernel = w2_shuffled_flat + if is_int4: + w2_kernel = _pack_shuffled_int8_to_packed_int4_no_perm(w2_shuffled_flat) - # For W4A8, pack preshuffled int8 weights into packed int4 bytes. - w2_kernel = w2_shuffled_flat - if is_int4: - w2_kernel = _pack_shuffled_int8_to_packed_int4_no_perm(w2_shuffled_flat) + w2_flat = w2_kernel.contiguous().view(-1) + w2_kernel = w2_flat + if not is_int4: + w2_kernel = w2_kernel.view(experts * model_dim, inter_dim) - w2_flat = w2_kernel.contiguous().view(-1) - w2_kernel = w2_flat - if not is_int4: - w2_kernel = w2_kernel.view(experts * model_dim, inter_dim) + if scale_w2_flat is None: + w2_scale_1d = torch.empty((0,), device=device, dtype=torch.float32) + else: + w2_scale_1d = scale_w2_flat.view(-1).contiguous() # [experts*model_dim] # Flatten scales to 1D memrefs (fp16 path uses 0-sized scale tensors; kernel ignores them). if a2_scale is None: a2_scale_1d = torch.empty((0,), device=device, dtype=torch.float32) else: a2_scale_1d = a2_scale.view(-1).contiguous() # [tokens*topk] - if scale_w2_flat is None: - w2_scale_1d = torch.empty((0,), device=device, dtype=torch.float32) - else: - w2_scale_1d = scale_w2_flat.view(-1).contiguous() # [experts*model_dim] sorted_weights_1d = sorted_weights.contiguous().view(-1) # [sorted_size] out_s = str(out_dtype).strip().lower() @@ -848,6 +875,19 @@ def run_moe_stage2( out = torch.zeros((tokens, model_dim), device=device, dtype=out_torch_dtype) out_perf = torch.zeros_like(out) + # Bias: [experts, model_dim] f32 + if enable_bias: + bias = ( + bias_in + if bias_in is not None + else torch.randn((experts, model_dim), device=device, dtype=torch.float32) * 0.1 + ) + # Flatten bias for kernel: [experts * model_dim] + bias_1d = bias.view(-1).contiguous() + else: + bias = None + bias_1d = torch.empty((0,), device=device, dtype=torch.float32) + doweight_stage2 = not bool(doweight_stage1) exe = compile_fn( model_dim=model_dim, @@ -861,9 +901,10 @@ def run_moe_stage2( tile_n=tile_n, tile_k=tile_k, doweight_stage2=bool(doweight_stage2), + enable_bias=enable_bias, ) - def launch(o, x, w, sx, sw, st, eids, sw_sorted): + def launch(o, x, w, sx, sw, st, eids, sw_sorted, bias_arg): exe( o, x, @@ -874,6 +915,7 @@ def launch(o, x, w, sx, sw, st, eids, sw_sorted): eids, sw_sorted, num_valid_ids, + bias_arg, tokens, model_dim, inter_dim, @@ -893,6 +935,7 @@ def launch(o, x, w, sx, sw, st, eids, sw_sorted): sorted_token_ids, sorted_expert_ids, sorted_weights_1d, + bias_1d, num_iters=int(num_iters), num_warmup=int(num_warmup), ) @@ -909,6 +952,7 @@ def launch(o, x, w, sx, sw, st, eids, sw_sorted): sorted_token_ids, sorted_expert_ids, sorted_weights_1d, + bias_1d, ) torch.cuda.synchronize() @@ -932,8 +976,9 @@ def launch(o, x, w, sx, sw, st, eids, sw_sorted): topk_weights, model_dim=model_dim, doweight_stage2=doweight_stage2, + bias=bias, ) - assert verify_output(out.to(torch.float32), ref2, rtol=0.000005, atol=0.000005) + assert verify_output(out.to(torch.float32), ref2, rtol=1e-3, atol=1e-3) # Launches full expert-block range; effective work is gated by num_valid_ids. flops = 2 * tokens * topk * model_dim * inter_dim @@ -1073,9 +1118,9 @@ def test_moe_gemm_2stage( doweight_stage1: bool, x_dtype: str, w_dtype: str, - out_dtype: str, use_reduce: bool, *, + out_dtype: str = "f16", seed: int = 0, num_iters: int = 5, num_warmup: int = 2, @@ -1218,9 +1263,12 @@ def _compile( x_dtype: str, w_dtype: str, out_dtype: str = "f16", + enable_bias: bool = False, ): if use_flydsl_reduce: # Use unified implementation with FlyDSL reduce kernel + # Note: compile_moe_gemm2_ex uses in_dtype instead of x_dtype/w_dtype + in_dtype = x_dtype # Assume x_dtype and w_dtype are the same for reduce mode return compile_moe_gemm2_ex( model_dim=model_dim, inter_dim=inter_dim, @@ -1230,8 +1278,7 @@ def _compile( tile_n=tile_n, tile_k=tile_k, doweight_stage2=doweight_stage2, - x_dtype=x_dtype, - w_dtype=w_dtype, + in_dtype=in_dtype, out_dtype=out_dtype, mode=MoeGemm2Mode.REDUCE, ) @@ -1250,6 +1297,7 @@ def _compile( w_dtype=w_dtype, out_dtype=out_dtype, accumulate=False, + enable_bias=enable_bias, ) return _TorchReduceWrapper(gemm2_exe, topk, model_dim) return _compile @@ -1278,6 +1326,7 @@ def __call__( arg_expert_ids, arg_sorted_weights, arg_num_valid_ids, + arg_bias, tokens_in, n_in, k_in, @@ -1296,7 +1345,7 @@ def __call__( intermediate.view(-1), arg_x, arg_w, arg_scale_x, arg_scale_w, arg_sorted_token_ids, arg_expert_ids, arg_sorted_weights, - arg_num_valid_ids, tokens_in, n_in, k_in, size_expert_ids_in, + arg_num_valid_ids, arg_bias, tokens_in, n_in, k_in, size_expert_ids_in, ) torch.sum(intermediate.view(tokens_in, self._topk, self._model_dim), dim=1, out=arg_out) @@ -1546,7 +1595,7 @@ def _str2tuple_dim(v: str) -> Tuple[int, int]: "--w_dtype", type=str, default="fp8", - choices=["fp8", "fp16", "int8", "int4"], + choices=["fp8", "fp16", "int8", "int4", "fp4"], help="Kernel weight dtype: fp8 / fp16 / int8 / int4.", ) parser.add_argument( diff --git a/tests/kernels/test_ref.py b/tests/kernels/test_ref.py index a97b69eb..ed084202 100644 --- a/tests/kernels/test_ref.py +++ b/tests/kernels/test_ref.py @@ -10,8 +10,13 @@ def torch_moe_gemm1( topk_weights: torch.Tensor, inter_dim: int, doweight_stage1: bool, + bias: torch.Tensor | None = None, ) -> torch.Tensor: - """Return [tokens, topk, inter_dim] fp32.""" + """Return [tokens, topk, inter_dim] fp32. + + Args: + bias: Optional bias tensor of shape [experts, 2 * inter_dim] (gate_bias, up_bias concatenated). + """ tokens, model_dim = x_fp8.shape topk = topk_ids.shape[1] # Derive experts from weight shapes (topk_ids may not cover all experts when tokens are tiny). @@ -40,6 +45,10 @@ def torch_moe_gemm1( y2 = F.linear(x[t_idx, :], w1[e, :, :]) # [num, 2*inter_dim] gate = y2[:, :inter_dim] up = y2[:, inter_dim:] + # Apply bias if provided + if bias is not None: + gate = gate + bias[e, :inter_dim] + up = up + bias[e, inter_dim:] y = F.silu(gate) * up if doweight_stage1: y = y * topk_weights[t_idx, s_idx].unsqueeze(-1) @@ -56,14 +65,19 @@ def torch_moe_gemm2( topk_weights: torch.Tensor, model_dim: int, doweight_stage2: bool, + bias: torch.Tensor | None = None, ) -> torch.Tensor: """Return [tokens, model_dim] fp32. Semantics align with aiter `torch_moe_stage2`: - Dequantize `a2_fp8` and `w2_fp8` with per-token/row scales. - For each routed (token, slot) -> expert, compute y = a2 @ W2[expert]^T. + - Optionally add bias (if provided). - Optionally multiply routed weight in stage2 (when stage1 did *not*). - Reduce across topk by summing into [tokens, model_dim]. + + Args: + bias: Optional bias tensor of shape [experts, model_dim]. """ assert a2_fp8.is_cuda and w2_fp8.is_cuda tokens, topk, inter_dim = a2_fp8.shape @@ -87,6 +101,9 @@ def torch_moe_gemm2( t_idx = idx[:, 0] s_idx = idx[:, 1] y = F.linear(a2[t_idx, s_idx, :], w2[e, :, :]) # [num, model_dim] + # Apply bias if provided + if bias is not None: + y = y + bias[e, :] if doweight_stage2: y = y * topk_weights[t_idx, s_idx].unsqueeze(-1) out.index_add_(0, t_idx, y) From a808e9d6046adbb4e99c90dac2022d32aab0952b Mon Sep 17 00:00:00 2001 From: Zzz9990 Date: Tue, 10 Feb 2026 20:52:36 -0600 Subject: [PATCH 08/11] update --- kernels/moe_gemm_2stage.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/kernels/moe_gemm_2stage.py b/kernels/moe_gemm_2stage.py index 051a3248..8f1c5794 100644 --- a/kernels/moe_gemm_2stage.py +++ b/kernels/moe_gemm_2stage.py @@ -3133,6 +3133,7 @@ def __call__( arg_expert_ids, arg_sorted_weights, arg_num_valid_ids, + arg_bias, tokens_in, n_in, k_in, @@ -3150,7 +3151,7 @@ def __call__( intermediate.view(-1), arg_x, arg_w, arg_scale_x, arg_scale_w, arg_sorted_token_ids, arg_expert_ids, arg_sorted_weights, - arg_num_valid_ids, tokens_in, n_in, k_in, size_expert_ids_in, + arg_num_valid_ids, arg_bias, tokens_in, n_in, k_in, size_expert_ids_in, stream_ptr, ) # Phase 2: Reduce over topk -> [tokens, model_dim] @@ -3174,9 +3175,11 @@ def compile_moe_gemm2_ex( tile_n: int, tile_k: int, doweight_stage2: bool, - in_dtype: str = "fp8", + x_dtype: str = "fp8", + w_dtype: str = "fp8", out_dtype: str = "f16", use_cshuffle_epilog: bool | None = None, + enable_bias: bool = False, # Extended parameters for mode control mode: str = MoeGemm2Mode.AUTO, tokens_hint: int | None = None, @@ -3229,9 +3232,11 @@ def compile_moe_gemm2_ex( tile_n=tile_n, tile_k=tile_k, doweight_stage2=doweight_stage2, - in_dtype=in_dtype, + x_dtype=x_dtype, + w_dtype=w_dtype, out_dtype=out_dtype, use_cshuffle_epilog=use_cshuffle_epilog, + enable_bias=enable_bias, accumulate=False, ) # Compile reduction kernel @@ -3268,6 +3273,7 @@ def compile_moe_gemm2_ex( in_dtype=in_dtype, out_dtype=out_dtype, use_cshuffle_epilog=use_cshuffle_epilog, + enable_bias=enable_bias, accumulate=True, ) From c65d87ae8049ab803ff423468eba29054df6cc6a Mon Sep 17 00:00:00 2001 From: Zzz9990 Date: Wed, 11 Feb 2026 03:42:31 -0600 Subject: [PATCH 09/11] update for prefill --- kernels/moe_gemm_2stage.py | 137 ++++++++++++++----------------------- 1 file changed, 53 insertions(+), 84 deletions(-) diff --git a/kernels/moe_gemm_2stage.py b/kernels/moe_gemm_2stage.py index 8f1c5794..5c870b60 100644 --- a/kernels/moe_gemm_2stage.py +++ b/kernels/moe_gemm_2stage.py @@ -44,7 +44,6 @@ ####==================== gemm1 pipeline start =====================### - @functools.lru_cache(maxsize=1024) def compile_moe_gemm1( *, @@ -72,7 +71,6 @@ def compile_moe_gemm1( - "int8": X/W are int8 - "int4": W4A8 path: X is int8, W is packed int4 (2 values per byte) unpacked to int8 in-kernel """ - gpu_arch = get_hip_arch() allocator = SmemAllocator(None, arch=gpu_arch) _state = {} @@ -190,7 +188,6 @@ def _mfma_output_pack_ty(): module_name = ( f"mfma_moe1_{x_dtype}_{w_dtype}_{out_dtype}_{epilog_tag}" f"_t{tile_m}x{tile_n}x{tile_k}" - f"_abi3" # also mask sentinel token ids on loads (X/scale_x) to avoid illegal address faults ).replace("-", "_") class _MOE1(flir.MlirModule): @@ -471,11 +468,7 @@ def x_tile_chunk_coord_i32(i: int): fused_i = buffer_ops.buffer_load(sorted_rsrc, sorted_row_i, vec_width=1, dtype=i32) t_raw = arith.andi(fused_i, mask24) t_idx = arith.index_cast(ir.IndexType.get(), t_raw) - # NOTE: aiter moe_sorting uses sentinel token_id == tokens for padding. - # Do NOT rely on buffer OOB semantics for X loads; explicitly mask to a safe row. - t_valid_i32 = arith.cmpu(t_raw, tokens_i32, "ult") - t_safe = arith.select(t_valid_i32, t_idx, arith.index(0)) - x_row_base_div4.append(t_safe * c_k_div4) + x_row_base_div4.append(t_idx * c_k_div4) vec1_i32 = I.vec(1, i32) vec2_i32 = I.vec(2, i32) @@ -1219,18 +1212,9 @@ def write_row_to_lds( # `row` is the sorted-row index (bx_m + row_in_tile). fused2 = buffer_ops.buffer_load(sorted_rsrc, row, vec_width=1, dtype=i32) t2 = fused2 & mask24_i32 - # aiter moe_sorting uses sentinel token_id == tokens for padding. - # Do NOT rely on buffer OOB semantics for scale loads; explicitly mask. - t_valid = arith.cmpu(t2, tokens_i32_v, "ult") - sx = ( - arith.f32(1.0) - if no_epilogue_dequant - else arith.select( - t_valid, - buffer_ops.buffer_load(sx_rsrc, t2, vec_width=1, dtype=f32), - arith.f32(0.0), - ) - ) + # No explicit mask: rely on buffer descriptor OOB to zero-fill when t2 is the + # sentinel (t2 == tokens) or otherwise out-of-range. + sx = arith.f32(1.0) if no_epilogue_dequant else buffer_ops.buffer_load(sx_rsrc, t2, vec_width=1, dtype=f32) # Sorted weight aligned with `row` (matches aiter moe_sorting output). if doweight_stage1: @@ -1289,20 +1273,18 @@ def precompute_row(*, row_local, row): return (t2 * topk_i32_v + s2) * inter_i32_local def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): - # Guard against sentinel token ids (t == tokens) produced by aiter moe_sorting padding. - # OOB buffer stores are not guaranteed to be safe on all paths, so predicate explicitly. - fused2 = buffer_ops.buffer_load(sorted_rsrc, row, vec_width=1, dtype=i32) - t2 = fused2 & mask24_i32 - t_valid = arith.cmpu(t2, tokens_i32_v, "ult") - _if_valid = scf.IfOp(t_valid) - with _if_valid.then(): - if not is_cast_out: - idx0 = row_ctx - col_i32 = arith.index_cast(i32, col_g0) - idx_out = idx0 + col_i32 - # Vectorized fp16 store (EVec=4). - buffer_ops.buffer_store(frag, out_rsrc, idx_out) - else: + if not is_cast_out: + idx0 = row_ctx + col_i32 = arith.index_cast(i32, col_g0) + idx_out = idx0 + col_i32 + # Vectorized fp16 store (EVec=4). + buffer_ops.buffer_store(frag, out_rsrc, idx_out) + else: + fused2 = buffer_ops.buffer_load(sorted_rsrc, row, vec_width=1, dtype=i32) + t2 = fused2 & mask24_i32 + t_valid = arith.cmpu(t2, tokens_i32, "ult") + _if_valid = scf.IfOp(t_valid) + with _if_valid.then(): frag = vector.bitcast(vec4_f32, frag) frag0 = vector.extract(frag, static_position=[0], dynamic_position=[]) frag1 = vector.extract(frag, static_position=[1], dynamic_position=[]) @@ -1353,18 +1335,8 @@ def _stage1_store_row(*, mi: int, ii: int, row_in_tile, row): s2_raw = fused2 >> 24 t2 = t2_raw s2 = s2_raw - t_valid = arith.cmpu(t2, tokens_i32_v, "ult") - - # Do NOT rely on buffer OOB semantics for scale loads; explicitly mask. - sx0 = ( - arith.f32(1.0) - if no_epilogue_dequant - else arith.select( - t_valid, - buffer_ops.buffer_load(sx_rsrc, t2, vec_width=1, dtype=f32), - arith.f32(0.0), - ) - ) + + sx0 = arith.f32(1.0) if no_epilogue_dequant else buffer_ops.buffer_load(sx_rsrc, t2, vec_width=1, dtype=f32) sx = sx0 zero_out = arith.constant(0.0, type=_out_elem_type()) @@ -1375,44 +1347,42 @@ def _stage1_store_row(*, mi: int, ii: int, row_in_tile, row): if doweight_stage1: tw = buffer_ops.buffer_load(sorted_w_rsrc, row, vec_width=1, dtype=f32) - _if_valid = scf.IfOp(t_valid) - with _if_valid.then(): - for ni in range_constexpr(num_acc_n_final): - col_i32 = col_i32_list[ni] - - acc_idx = mi * num_acc_n_final + ni - vg = vector.extract( - acc_gate[acc_idx], static_position=[ii], dynamic_position=[] - ) - vu = vector.extract( - acc_up[acc_idx], static_position=[ii], dynamic_position=[] - ) - - if is_int_mode: - vg = arith.sitofp(f32, vg) - vu = arith.sitofp(f32, vu) - - if not no_epilogue_dequant: - sw_gate = sw_gate_vals[ni] - sw_up = sw_up_vals[ni] - vg = vg * sx * sw_gate - vu = vu * sx * sw_up + for ni in range_constexpr(num_acc_n_final): + col_i32 = col_i32_list[ni] + + acc_idx = mi * num_acc_n_final + ni + vg = vector.extract( + acc_gate[acc_idx], static_position=[ii], dynamic_position=[] + ) + vu = vector.extract( + acc_up[acc_idx], static_position=[ii], dynamic_position=[] + ) + + if is_int_mode: + vg = arith.sitofp(f32, vg) + vu = arith.sitofp(f32, vu) + + if not no_epilogue_dequant: + sw_gate = sw_gate_vals[ni] + sw_up = sw_up_vals[ni] + vg = vg * sx * sw_gate + vu = vu * sx * sw_up - if enable_bias: - gate_bias_list, up_bias_list = epilogue_pf - vg = vg + gate_bias_list[ni] - vu = vu + up_bias_list[ni] - - if act == "swiglu": - y = swiglu(vg, vu) - else: - y = silu(vg) * vu + if enable_bias: + gate_bias_list, up_bias_list = epilogue_pf + vg = vg + gate_bias_list[ni] + vu = vu + up_bias_list[ni] + + if act == "swiglu": + y = swiglu(vg, vu) + else: + y = silu(vg) * vu - if doweight_stage1: - y = y * tw - y = arith.trunc_f(_out_elem_type(), y) - idx_out0 = idx0 + col_i32 - buffer_ops.buffer_store(y, out_rsrc, idx_out0) + if doweight_stage1: + y = y * tw + y = arith.trunc_f(_out_elem_type(), y) + idx_out0 = idx0 + col_i32 + buffer_ops.buffer_store(y, out_rsrc, idx_out0) mfma_epilog( use_cshuffle=False, @@ -1448,7 +1418,6 @@ def __call__( # Use host-provided upper bound for M blocks (same as aiter moe_sorting allocation). # This avoids device->host sync on num_valid_ids. gy = size_expert_ids_in - stream_token = stream_ptr_to_async_token(stream_ptr) flir.gpu_ext.LaunchFuncOp( [module_name, "moe_gemm1"], @@ -1470,7 +1439,7 @@ def __call__( k_in, size_expert_ids_in, ], - async_dependencies=[stream_token], + # async_dependencies=[stream_token], ) m = _MOE1() From a1ce89311eb0a5215f3417325f0f6bdfd9e3a654 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 11 Feb 2026 12:20:31 +0000 Subject: [PATCH 10/11] update --- kernels/moe_gemm_2stage.py | 57 ++++++++++++++++++++++++++------------ 1 file changed, 40 insertions(+), 17 deletions(-) diff --git a/kernels/moe_gemm_2stage.py b/kernels/moe_gemm_2stage.py index 5c870b60..670c50c0 100644 --- a/kernels/moe_gemm_2stage.py +++ b/kernels/moe_gemm_2stage.py @@ -1007,12 +1007,33 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): rocdl.sched_barrier(0) def hot_loop_scheduler(): - mfma_group = num_acc_n * 2 - # K64 micro-step: 2x K32 MFMA per gemm. - mfma_total = (k_unroll * 2) * m_repeat * mfma_group - mfma_per_iter = 2 * mfma_group - sche_iters = 0 if mfma_per_iter == 0 else (mfma_total // mfma_per_iter) - + # mfma_group = num_acc_n * 2 + # # K64 micro-step: 2x K32 MFMA per gemm. + # mfma_total = (k_unroll * 2) * m_repeat * mfma_group + # mfma_per_iter = 2 * mfma_group + # sche_iters = 0 if mfma_per_iter == 0 else (mfma_total // mfma_per_iter) + + # rocdl.sched_dsrd(2) + # rocdl.sched_mfma(2) + # rocdl.sched_dsrd(1) + # rocdl.sched_mfma(1) + # rocdl.sched_dsrd(1) + # rocdl.sched_mfma(1) + + # # DS-write hints near the end: match total X LDS-store micro-ops per thread. + # dswr_tail = num_x_loads + # if dswr_tail > sche_iters: + # dswr_tail = sche_iters + # dswr_start = sche_iters - dswr_tail + # for sche_i in range_constexpr(sche_iters): + # rocdl.sched_vmem(1) + # rocdl.sched_mfma(mfma_group) + # rocdl.sched_dsrd(1) + # rocdl.sched_mfma(mfma_group) + # if sche_i >= dswr_start - 1: + # rocdl.sched_dswr(1) + # rocdl.sched_barrier(0) + rocdl.sched_dsrd(2) rocdl.sched_mfma(2) rocdl.sched_dsrd(1) @@ -1021,17 +1042,17 @@ def hot_loop_scheduler(): rocdl.sched_mfma(1) # DS-write hints near the end: match total X LDS-store micro-ops per thread. - dswr_tail = num_x_loads - if dswr_tail > sche_iters: - dswr_tail = sche_iters - dswr_start = sche_iters - dswr_tail - for sche_i in range_constexpr(sche_iters): - rocdl.sched_vmem(1) - rocdl.sched_mfma(mfma_group) - rocdl.sched_dsrd(1) - rocdl.sched_mfma(mfma_group) + dswr_start = 12 - 2 + for sche_i in range_constexpr(12): + if sche_i < 10: + rocdl.sched_vmem(1) + if sche_i <= 10: + rocdl.sched_dsrd(1) + + rocdl.sched_mfma(1) if sche_i >= dswr_start - 1: rocdl.sched_dswr(1) + rocdl.sched_barrier(0) # ================ Unified Pipeline (FP4 / Standard) for gemm1 ================ @@ -1073,6 +1094,7 @@ def _prefetch_scale(k_val): # Unrolled ping-pong main loop (2 tiles per iteration), leaving 2 tail tiles. c2_tile_k = arith.constant(tile_k * 2, index=True) c_k_main2 = k_in - c2_tile_k + rocdl.sched_barrier(0) for k_iv in range(arith.index(0), c_k_main2, c2_tile_k): # ---- stage 0: prefetch+store ping, compute pong ---- @@ -1439,7 +1461,7 @@ def __call__( k_in, size_expert_ids_in, ], - # async_dependencies=[stream_token], + async_dependencies=[stream_token], ) m = _MOE1() @@ -3239,7 +3261,8 @@ def compile_moe_gemm2_ex( tile_n=tile_n, tile_k=tile_k, doweight_stage2=doweight_stage2, - in_dtype=in_dtype, + x_dtype=x_dtype, + w_dtype=w_dtype, out_dtype=out_dtype, use_cshuffle_epilog=use_cshuffle_epilog, enable_bias=enable_bias, From 9bcf6789c0b63e2b6d2473df28834fe06d3752f4 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 11 Feb 2026 14:48:34 +0000 Subject: [PATCH 11/11] update --- kernels/moe_gemm_2stage.py | 114 ++++++++++++++++++++++++++----------- 1 file changed, 80 insertions(+), 34 deletions(-) diff --git a/kernels/moe_gemm_2stage.py b/kernels/moe_gemm_2stage.py index 670c50c0..682c6c6d 100644 --- a/kernels/moe_gemm_2stage.py +++ b/kernels/moe_gemm_2stage.py @@ -980,7 +980,7 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): b128 = pack_i64x4_to_i32x8(b0, b1, c0_i64, c0_i64) acc_idx = mi_idx * num_acc_n_packed + (ni_idx // 2) - rocdl.sched_barrier(0) + current_list[acc_idx] = rocdl.mfma_scale_f32_16x16x128_f8f6f4( mfma_res_ty, [ @@ -1037,8 +1037,10 @@ def hot_loop_scheduler(): rocdl.sched_dsrd(2) rocdl.sched_mfma(2) rocdl.sched_dsrd(1) + rocdl.sched_vmem(1) rocdl.sched_mfma(1) rocdl.sched_dsrd(1) + rocdl.sched_vmem(1) rocdl.sched_mfma(1) # DS-write hints near the end: match total X LDS-store micro-ops per thread. @@ -1938,6 +1940,7 @@ def load_x(idx_i32): return_vector=True, src_buffer_resource=x_rsrc, src_buffer_offset_in_bytes=True, + nontemporal=True, ) # decode routed token once (per thread's M-slice) and build a base offset. @@ -2292,7 +2295,6 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): b128 = pack_i64x4_to_i32x8(b0, b1, c0_i64, c0_i64) acc_idx = mi_idx * num_acc_n + ni_idx - rocdl.sched_barrier(0) acc_list[acc_idx] = rocdl.mfma_scale_f32_16x16x128_f8f6f4( mfma_res_ty, [ @@ -2397,43 +2399,87 @@ def hot_loop_scheduler(): # - MFMA group size per "slot": num_acc_n # - Total MFMA per tile: (2*K32 per K64) * k_unroll * m_repeat * num_acc_n # - We emit (mfma_group + dsrd + mfma_group) per scheduler iteration. - mfma_group = num_acc_n - mfma_total = (k_unroll * 2) * m_repeat * mfma_group - mfma_per_iter = 2 * mfma_group - sche_iters = 0 if mfma_per_iter == 0 else (mfma_total // mfma_per_iter) + # mfma_group = num_acc_n + # mfma_total = (k_unroll * 2) * m_repeat * mfma_group + # mfma_per_iter = 2 * mfma_group + # sche_iters = 0 if mfma_per_iter == 0 else (mfma_total // mfma_per_iter) + + # rocdl.sched_dsrd(2) + # rocdl.sched_mfma(1) + # if tile_m == 16: + # rocdl.sched_vmem(1) + # rocdl.sched_mfma(1) + # if tile_m == 16: + # rocdl.sched_vmem(1) + # if num_acc_n < 4: + # rocdl.sched_dsrd(1) + # rocdl.sched_mfma(1) + # if tile_m == 16: + # rocdl.sched_vmem(1) + # rocdl.sched_dsrd(1) + # rocdl.sched_mfma(1) + # if tile_m == 16: + # rocdl.sched_vmem(1) + # rocdl.sched_mfma(1) + + # # DS-write hints near the end: match total A LDS-store micro-ops per thread. + # dswr_tail = num_x_loads + # if dswr_tail > sche_iters: + # dswr_tail = sche_iters + # dswr_start = sche_iters - dswr_tail + + # for sche_i in range_constexpr(sche_iters): + # rocdl.sched_vmem(1) + # rocdl.sched_mfma(mfma_group) + # rocdl.sched_dsrd(1) + # rocdl.sched_mfma(mfma_group) + # if sche_i >= dswr_start - 1: + # rocdl.sched_dswr(1) + # rocdl.sched_barrier(0) rocdl.sched_dsrd(2) + rocdl.sched_vmem(1) + rocdl.sched_mfma(1) + + rocdl.sched_dsrd(1) + rocdl.sched_vmem(1) rocdl.sched_mfma(1) - if tile_m == 16: - rocdl.sched_vmem(1) + + rocdl.sched_dsrd(1) + rocdl.sched_vmem(1) + rocdl.sched_mfma(1) + + rocdl.sched_dsrd(1) + rocdl.sched_vmem(1) + rocdl.sched_mfma(1) + + rocdl.sched_dsrd(1) + rocdl.sched_vmem(1) + rocdl.sched_mfma(1) + + rocdl.sched_dsrd(1) + rocdl.sched_vmem(1) + rocdl.sched_mfma(1) + rocdl.sched_dswr(1) + rocdl.sched_vmem(1) + + rocdl.sched_mfma(1) + rocdl.sched_dswr(1) + rocdl.sched_mfma(1) - if tile_m == 16: - rocdl.sched_vmem(1) - if num_acc_n < 4: - rocdl.sched_dsrd(1) - rocdl.sched_mfma(1) - if tile_m == 16: - rocdl.sched_vmem(1) - rocdl.sched_dsrd(1) - rocdl.sched_mfma(1) - if tile_m == 16: - rocdl.sched_vmem(1) - rocdl.sched_mfma(1) - - # DS-write hints near the end: match total A LDS-store micro-ops per thread. - dswr_tail = num_x_loads - if dswr_tail > sche_iters: - dswr_tail = sche_iters - dswr_start = sche_iters - dswr_tail - - for sche_i in range_constexpr(sche_iters): - rocdl.sched_vmem(1) - rocdl.sched_mfma(mfma_group) - rocdl.sched_dsrd(1) - rocdl.sched_mfma(mfma_group) - if sche_i >= dswr_start - 1: - rocdl.sched_dswr(1) + # DS-write hints near the end: match total X LDS-store micro-ops per thread. + # dswr_start = 12 - 2 + # for sche_i in range_constexpr(12): + # if sche_i < 10: + # rocdl.sched_vmem(1) + # if sche_i <= 10: + # rocdl.sched_dsrd(1) + + # rocdl.sched_mfma(1) + # if sche_i >= dswr_start - 1: + # rocdl.sched_dswr(1) + rocdl.sched_barrier(0) # ================ Unified Pipeline (FP4 / Standard) ================