diff --git a/flydsl/src/flydsl/kernels b/flydsl/src/flydsl/kernels new file mode 120000 index 00000000..d48390c1 --- /dev/null +++ b/flydsl/src/flydsl/kernels @@ -0,0 +1 @@ +../../../kernels \ No newline at end of file diff --git a/kernels/mfma_preshuffle_pipeline.py b/kernels/mfma_preshuffle_pipeline.py index 1b8d02e6..59a05a1b 100644 --- a/kernels/mfma_preshuffle_pipeline.py +++ b/kernels/mfma_preshuffle_pipeline.py @@ -12,9 +12,297 @@ from __future__ import annotations from dataclasses import dataclass +import re +from flydsl.dialects.ext.python_control_flow import range_constexpr +from flydsl.dialects.ext import arith, gpu, buffer_ops, vector, rocdl +from flydsl.lang.ir.types import T, memref + from _mlir import ir -@dataclass(frozen=True) +from enum import Enum + +class MfmaPipeline(Enum): + F4F4_MXFP4_PIPELINE = "F4F4_MXFP4_PIPELINE" + F8F4_MXFP4_PIPELINE = "F8F4_MXFP4_PIPELINE" + F8F8_16x16_PIPELINE = "F8F8_16x16_PIPELINE" + F16F16_16x16_PIPELINE = "F16F16_16x16_PIPELINE" + BF16BF16_16x16_PIPELINE = "BF16BF16_16x16_PIPELINE" + I8I8_16x16_PIPELINE = "I8I8_16x16_PIPELINE" + I8I4_16x16_PIPELINE = "I8I4_16x16_PIPELINE" + +class EpilogPipeline(Enum): + CSHUFFLE_F16 = "CSHUFFLE_F16" + CSHUFFLE_BF16 = "CSHUFFLE_BF16" + CSHUFFLE_F32 = "CSHUFFLE_F32" + CSHUFFLE_F8 = "CSHUFFLE_F8" + DIRECT_F16 = "DIRECT_F16" + DIRECT_BF16 = "DIRECT_BF16" + DIRECT_F32 = "DIRECT_F32" + +a_elem_type_dict = { + MfmaPipeline.F4F4_MXFP4_PIPELINE: lambda: T.ui8, + MfmaPipeline.F8F4_MXFP4_PIPELINE: lambda: T.f8, + MfmaPipeline.F8F8_16x16_PIPELINE: lambda: T.f8, + MfmaPipeline.F16F16_16x16_PIPELINE: lambda: T.f16, + MfmaPipeline.BF16BF16_16x16_PIPELINE: lambda: T.bf16, + MfmaPipeline.I8I8_16x16_PIPELINE: lambda: T.i8, + MfmaPipeline.I8I4_16x16_PIPELINE: lambda: T.i8, +} + +b_elem_type_dict = { + MfmaPipeline.F4F4_MXFP4_PIPELINE: lambda: T.ui8, + MfmaPipeline.F8F4_MXFP4_PIPELINE: lambda: T.ui8, + MfmaPipeline.F8F8_16x16_PIPELINE: lambda: T.f8, + MfmaPipeline.F16F16_16x16_PIPELINE: lambda: T.f16, + MfmaPipeline.BF16BF16_16x16_PIPELINE: lambda: T.bf16, + MfmaPipeline.I8I8_16x16_PIPELINE: lambda: T.i8, + MfmaPipeline.I8I4_16x16_PIPELINE: lambda: T.i8, +} + +scale_elem_type_dict = { + MfmaPipeline.F4F4_MXFP4_PIPELINE: lambda: T.i32, + MfmaPipeline.F8F4_MXFP4_PIPELINE: lambda: T.i32, + MfmaPipeline.F8F8_16x16_PIPELINE: lambda: T.f32, + MfmaPipeline.I8I8_16x16_PIPELINE: lambda: T.f32, + MfmaPipeline.I8I4_16x16_PIPELINE: lambda: T.f32, + # bf16 scale placeholder + MfmaPipeline.F16F16_16x16_PIPELINE: lambda: T.f32, + MfmaPipeline.BF16BF16_16x16_PIPELINE: lambda: T.f32, +} + +out_elem_type_dict = { + EpilogPipeline.CSHUFFLE_F16: lambda: T.f16, + EpilogPipeline.CSHUFFLE_BF16: lambda: T.bf16, + EpilogPipeline.CSHUFFLE_F32: lambda: T.f32, + EpilogPipeline.CSHUFFLE_F8: lambda: T.f8, + EpilogPipeline.DIRECT_F16: lambda: T.f16, + EpilogPipeline.DIRECT_BF16: lambda: T.bf16, + EpilogPipeline.DIRECT_F32: lambda: T.f32, +} + +a_vec16_type_dict = { + MfmaPipeline.F4F4_MXFP4_PIPELINE: lambda: T.ui8x16, + MfmaPipeline.F8F4_MXFP4_PIPELINE: lambda: T.f8x16, + MfmaPipeline.F8F8_16x16_PIPELINE: lambda: T.f8x16, + MfmaPipeline.F16F16_16x16_PIPELINE: lambda: T.f16x8, + MfmaPipeline.BF16BF16_16x16_PIPELINE: lambda: T.bf16x8, + MfmaPipeline.I8I8_16x16_PIPELINE: lambda: T.i8x16, + MfmaPipeline.I8I4_16x16_PIPELINE: lambda: T.i8x16, +} + +b_vec16_type_dict = { + MfmaPipeline.F4F4_MXFP4_PIPELINE: lambda: T.ui8x16, + MfmaPipeline.F8F4_MXFP4_PIPELINE: lambda: T.ui8x16, + MfmaPipeline.F8F8_16x16_PIPELINE: lambda: T.f8x16, + MfmaPipeline.F16F16_16x16_PIPELINE: lambda: T.f16x8, + MfmaPipeline.BF16BF16_16x16_PIPELINE: lambda: T.bf16x8, + MfmaPipeline.I8I8_16x16_PIPELINE: lambda: T.i8x16, + MfmaPipeline.I8I4_16x16_PIPELINE: lambda: T.i8x16, +} + +mfma_input_pack_ty_dict = { + MfmaPipeline.F4F4_MXFP4_PIPELINE: lambda: T.i64, + MfmaPipeline.F8F4_MXFP4_PIPELINE: lambda: T.i64, + MfmaPipeline.F8F8_16x16_PIPELINE: lambda: T.i64, + MfmaPipeline.F16F16_16x16_PIPELINE: lambda: T.f16x4, + MfmaPipeline.BF16BF16_16x16_PIPELINE: lambda: T.i16x4, + MfmaPipeline.I8I8_16x16_PIPELINE: lambda: T.i32x4, + MfmaPipeline.I8I4_16x16_PIPELINE: lambda: T.i32x4, +} + +mfma_output_pack_ty_dict = { + MfmaPipeline.F4F4_MXFP4_PIPELINE: lambda: T.f32x4, + MfmaPipeline.F8F4_MXFP4_PIPELINE: lambda: T.f32x4, + MfmaPipeline.F8F8_16x16_PIPELINE: lambda: T.f32x4, + MfmaPipeline.F16F16_16x16_PIPELINE: lambda: T.f32x4, + MfmaPipeline.BF16BF16_16x16_PIPELINE: lambda: T.f32x4, + MfmaPipeline.I8I8_16x16_PIPELINE: lambda: T.i32x4, + MfmaPipeline.I8I4_16x16_PIPELINE: lambda: T.i32x4, +} + +mfma_pipeline_dicts = { + "a_elem_type": a_elem_type_dict, + "b_elem_type": b_elem_type_dict, + "scale_elem_type": scale_elem_type_dict, + "out_elem_type": out_elem_type_dict, + "a_vec16_type": a_vec16_type_dict, + "b_vec16_type": b_vec16_type_dict, + "mfma_input_pack_ty": mfma_input_pack_ty_dict, + "mfma_output_pack_ty": mfma_output_pack_ty_dict, +} + +def get_mfma_i32_k32(): + mfma_i32_k32 = getattr(rocdl, "mfma_i32_16x16x32i8", None) or getattr( + rocdl, "mfma_i32_16x16x32_i8", None + ) + if mfma_i32_k32 is None: + raise AttributeError( + "INT8 K32 MFMA op not found: expected `rocdl.mfma_i32_16x16x32i8` " + "(or `rocdl.mfma_i32_16x16x32_i8`)." + ) + return mfma_i32_k32 + +class PreshufflePipelineManager: + def __init__( + self, + a_dtype: str, + b_dtype: str, + out_dtype: str, + use_cshuffle_epilog: bool = False, + a_packed: bool = False, + b_packed: bool = False, + block_size: int = 256, + ): + self.a_dtype = a_dtype + self.b_dtype = b_dtype + self.out_dtype = out_dtype + self.refine_dtype() + self.check_type_valid() + + self.use_cshuffle_epilog = use_cshuffle_epilog + self.a_packed = self.a_dtype in ["fp4"] + self.b_packed = self.b_dtype in ["fp4", "int4"] + self.a_elem_pack = 2 if self.a_packed else 1 + self.b_elem_pack = 2 if self.b_packed else 1 + self.mfma_pipeline = self.get_mfma_pipeline() + self.epilog_pipeline = self.get_epilog_pipeline() + self.a_elem_bytes = self.get_a_elem_bytes() + self.b_elem_bytes = self.get_b_elem_bytes() + self.out_elem_bytes = self.get_out_elem_bytes() + self.block_size = block_size + + def refine_dtype(self): + def _normalize_dtype(value: str) -> str: + s = str(value).strip().lower() + s = re.sub(r"^(f16|float16|half)$", "fp16", s) + s = re.sub(r"^(bf16|bfloat16)$", "bf16", s) + s = re.sub(r"^(f32|fp32|float|float32)$", "f32", s) + s = re.sub(r"^(fp8|f8)$", "fp8", s) + s = re.sub(r"^(fp4|f4)$", "fp4", s) + s = re.sub(r"^(int8|i8)$", "int8", s) + s = re.sub(r"^(int4|i4)$", "int4", s) + return s + + self.a_dtype = _normalize_dtype(self.a_dtype) + self.b_dtype = _normalize_dtype(self.b_dtype) + self.out_dtype = _normalize_dtype(self.out_dtype) + + if self.out_dtype not in ("fp16", "bf16", "f32", "fp8"): + raise ValueError( + f"out_dtype must be 'f16', 'bf16', or 'f32', got {self.out_dtype!r}" + ) + + def check_type_valid(self): + if self.a_dtype not in ["fp8", "fp4", "int8", "fp16", "bf16"]: + raise ValueError(f"Invalid a_dtype: {self.a_dtype}") + if self.b_dtype not in ["fp8", "fp4", "int8", "int4", "fp16", "bf16"]: + raise ValueError(f"Invalid b_dtype: {self.b_dtype}") + if self.out_dtype not in ["fp16", "bf16", "f32", "fp8"]: + raise ValueError(f"Invalid out_dtype: {self.out_dtype}") + + def get_mfma_pipeline(self): + if self.a_dtype == "fp4" and self.b_dtype == "fp4": + return MfmaPipeline.F4F4_MXFP4_PIPELINE + elif self.a_dtype == "fp8" and self.b_dtype == "fp4": + return MfmaPipeline.F8F4_MXFP4_PIPELINE + elif self.a_dtype == "fp8" and self.b_dtype == "fp8": + return MfmaPipeline.F8F8_16x16_PIPELINE + elif self.a_dtype == "fp16" and self.b_dtype == "fp16": + return MfmaPipeline.F16F16_16x16_PIPELINE + elif self.a_dtype == "bf16" and self.b_dtype == "bf16": + return MfmaPipeline.BF16BF16_16x16_PIPELINE + elif self.a_dtype == "int8" and self.b_dtype == "int8": + return MfmaPipeline.I8I8_16x16_PIPELINE + elif self.a_dtype == "int8" and self.b_dtype == "int4": + return MfmaPipeline.I8I4_16x16_PIPELINE + else: + raise ValueError(f"Invalid preshuffle pipeline: {self.a_dtype}_{self.b_dtype}_{self.out_dtype}") + + def get_epilog_pipeline(self): + if self.use_cshuffle_epilog and self.out_dtype == "fp16": + return EpilogPipeline.CSHUFFLE_F16 + elif self.use_cshuffle_epilog and self.out_dtype == "bf16": + return EpilogPipeline.CSHUFFLE_BF16 + elif self.use_cshuffle_epilog and self.out_dtype == "f32": + return EpilogPipeline.CSHUFFLE_F32 + elif self.use_cshuffle_epilog and self.out_dtype == "fp8": + return EpilogPipeline.CSHUFFLE_F8 + elif not self.use_cshuffle_epilog and self.out_dtype == "f32": + return EpilogPipeline.DIRECT_F32 + elif not self.use_cshuffle_epilog and self.out_dtype == "fp16": + return EpilogPipeline.DIRECT_F16 + elif not self.use_cshuffle_epilog and self.out_dtype == "bf16": + return EpilogPipeline.DIRECT_BF16 + else: + raise ValueError(f"Invalid epilog pipeline: {self.out_dtype}") + + def get_b_elem_bytes(self): + if self.b_dtype in ["fp8", "int8", "int4", "fp4"]: + return 1 + elif self.b_dtype in ["fp16", "bf16"]: + return 2 + else: + raise ValueError(f"Invalid b_dtype: {self.b_dtype}") + + def get_a_elem_bytes(self): + if self.a_dtype in ["fp8", "int8", "int4", "fp4"]: + return 1 + elif self.a_dtype in ["fp16", "bf16"]: + return 2 + else: + raise ValueError(f"Invalid a_dtype: {self.a_dtype}") + + def get_out_elem_bytes(self): + if self.out_dtype in ["fp8", "int8", "int4", "fp4"]: + return 1 + if self.out_dtype in ["fp16", "bf16"]: + return 2 + elif self.out_dtype == "f32": + return 4 + else: + raise ValueError(f"Invalid out_dtype: {self.out_dtype}") + + def get_mfma_fn(self): + if self.mfma_pipeline == MfmaPipeline.F8F8_16x16_PIPELINE: + return rocdl.mfma_f32_16x16x32_fp8_fp8 + elif self.mfma_pipeline == MfmaPipeline.BF16BF16_16x16_PIPELINE: + return rocdl.mfma_f32_16x16x16bf16_1k + elif self.mfma_pipeline == MfmaPipeline.F16F16_16x16_PIPELINE: + return rocdl.mfma_f32_16x16x16f16 + elif self.mfma_pipeline == MfmaPipeline.I8I8_16x16_PIPELINE: + return get_mfma_i32_k32() + elif self.mfma_pipeline == MfmaPipeline.I8I4_16x16_PIPELINE: + return get_mfma_i32_k32() + elif self.mfma_pipeline in [MfmaPipeline.F4F4_MXFP4_PIPELINE, MfmaPipeline.F8F4_MXFP4_PIPELINE]: + return rocdl.mfma_scale_f32_16x16x128_f8f6f4 + else: + raise ValueError(f"Invalid mfma pipeline: {self.mfma_pipeline}") + + def get_a_bytes_per_thread( + self, + tile_m: int, + tile_k: int, + ): + a_bytes_per_tile = int(tile_m) * int(tile_k) * int(self.a_elem_bytes) // self.a_elem_pack + if a_bytes_per_tile % self.block_size != 0: + raise ValueError( + "tile_m*tile_k*elem_bytes/a_elem_pack must be divisible by " + f"{self.block_size}: tile_m={tile_m}, tile_k={tile_k}, a_elem_bytes={self.a_elem_bytes}, a_elem_pack={self.a_elem_pack}" + ) + a_bytes_per_thread = a_bytes_per_tile // self.block_size + + # Assume A loads are always 16B-aligned and use fixed dwordx4 (16B) buffer loads. + a_load_bytes = 16 + if a_bytes_per_thread % a_load_bytes != 0: + raise ValueError( + f"a_bytes_per_thread ({a_bytes_per_thread}) must be divisible by {a_load_bytes}" + ) + + return a_bytes_per_thread + + + +@dataclass class PreshuffleBLayout: """Container returned by `make_preshuffle_b_layout`.""" @@ -509,9 +797,232 @@ def lds_load_pack_k32( a_vec64 = vector.bitcast(vec1_i64_ty, loaded_a8) return vector.extract(a_vec64, static_position=[0], dynamic_position=[]) +def block_mfma_block_scale_f8f6f4( + accs_in, + b_tile_in, + a_scale, + b_scale, + lds_base, + lds_load_packs_k64, + col_offset_base_bytes, + row_a_lds, + *, + mfma_fn, + mfma_res_ty, + cbsz, + blgp, + a_elem_vec_pack, + k_unroll, + m_repeat, + num_acc_n, + pack_K, + pack_M, + pack_N, + a0_prefetch=None, +): + k_unroll_packed = k_unroll // pack_K + m_repeat_packed = m_repeat // pack_M + num_acc_n_packed = num_acc_n // pack_N + + mfma_res_ty = T.f32x4 + vec4_i64 = T.vec(4, T.i64) + vec8_i32 = T.vec(8, T.i32) + c0_i64 = arith.constant(0, type=T.i64) + + def pack_i64x4_to_i32x8(x0, x1, x2, x3): + v4 = vector.from_elements(vec4_i64, [x0, x1, x2, x3]) + return vector.bitcast(vec8_i32, v4) + + for ku128 in range_constexpr(k_unroll_packed): + for mi in range_constexpr(m_repeat_packed): + a_scale_i32 = a_scale[ku128 * m_repeat_packed + mi] + a_scale_val = vector.extract(a_scale_i32, static_position=[0], dynamic_position=[]) + for ni in range_constexpr(num_acc_n_packed): + b_scale_i32 = b_scale[ku128 * num_acc_n_packed + ni] + b_scale_val = vector.extract(b_scale_i32, static_position=[0], dynamic_position=[]) + for ikxdl in range_constexpr(pack_K): + k_idx = ku128 * pack_K + ikxdl + + b_packs0, b_packs1 = b_tile_in[k_idx] + + col_base = col_offset_base_bytes + (k_idx * 128) // a_elem_vec_pack + for imxdl in range_constexpr(pack_M): + col_base0 = col_base + mi_idx = mi * pack_M + imxdl + mi_val = arith.constant(mi_idx * 16, index=True) + curr_row_a_lds = row_a_lds + mi_val + + if (a0_prefetch is not None) and (k_idx == 0) and (mi_idx == 0): + a0, a1 = a0_prefetch + else: + a0, a1 = lds_load_packs_k64(curr_row_a_lds, col_base0, lds_base) + + if cbsz == 0: + col_base1 = col_base + 64 + a2, a3 = lds_load_packs_k64(curr_row_a_lds, col_base1, lds_base) + a128 = pack_i64x4_to_i32x8(a0, a1, a2, a3) + else: + a128 = pack_i64x4_to_i32x8(a0, a1, c0_i64, c0_i64) + + for inxdl in range_constexpr(pack_N): + ni_idx = ni * pack_N + inxdl + b0 = b_packs0[ni_idx] + b1 = b_packs1[ni_idx] + b128 = pack_i64x4_to_i32x8(b0, b1, c0_i64, c0_i64) + + acc_idx = mi_idx * num_acc_n + ni_idx + accs_in[acc_idx] = mfma_fn( + mfma_res_ty, + [ + a128, + b128, + accs_in[acc_idx], + # cbsz, abid, blgp + cbsz, + blgp, + # op_sel_a + scale_a (1.0f as i32 bits) + ikxdl * pack_M + imxdl, + a_scale_val, + # + # op_sel_b + scale_b (1.0f as i32 bits) + ikxdl * pack_N + inxdl, + b_scale_val, + ], + ) + return accs_in + +# ---------------- gfx95 fast path (K128 MFMA scale) ---------------- +# This is the key optimization from `zhimding/develop_0107` for FP8: +# use mfma.scale 16x16x128 to reduce instruction count in the hot loop. +# +# Notes: +# - Only valid for fp8 path (not int8/int4) and gfx95+ +# - Requires tile_k divisible by 128 +# - mfma.scale takes 9 operands: 3 vectors + 6 i32 flags/scales. +def block_mfma_PTPC_f8f6f4( + accs_in, + b_tile_in, + lds_base, + col_offset_base_bytes, + row_a_lds, + lds_load_packs_k64, + *, + mfma_res_ty, + mfma_fn, + k_unroll=16, + num_acc_n=16, + m_repeat=16, + a0_prefetch=None, +): + + vec4_i64 = T.vec(4, T.i64) + vec8_i32 = T.vec(8, T.i32) + + def pack_i64x4_to_i32x8(x0, x1, x2, x3): + v4 = vector.from_elements(vec4_i64, [x0, x1, x2, x3]) + return vector.bitcast(vec8_i32, v4) + + for ku128 in range_constexpr(k_unroll // 2): + ku0 = ku128 * 2 + ku1 = ku0 + 1 + + b0_packs0, b0_packs1 = b_tile_in[ku0] + b1_packs0, b1_packs1 = b_tile_in[ku1] + + col_base0 = col_offset_base_bytes + (ku0 * 64) + col_base1 = col_offset_base_bytes + (ku1 * 64) + + for mi in range_constexpr(m_repeat): + mi_val = arith.constant(mi * 16, index=True) + curr_row_a_lds = row_a_lds + mi_val + + if (a0_prefetch is not None) and (ku0 == 0) and (mi == 0): + a0, a1 = a0_prefetch + else: + a0, a1 = lds_load_packs_k64(curr_row_a_lds, col_base0, lds_base) + a2, a3 = lds_load_packs_k64(curr_row_a_lds, col_base1, lds_base) + a128 = pack_i64x4_to_i32x8(a0, a1, a2, a3) + + for ni in range_constexpr(num_acc_n): + b0 = b0_packs0[ni] + b1 = b0_packs1[ni] + b2 = b1_packs0[ni] + b3 = b1_packs1[ni] + b128 = pack_i64x4_to_i32x8(b0, b1, b2, b3) + + acc_idx = mi * num_acc_n + ni + accs_in[acc_idx] = rocdl.mfma_scale_f32_16x16x128_f8f6f4( + mfma_res_ty, + [ + a128, + b128, + accs_in[acc_idx], + # cbsz, abid, blgp: 0 + 0, + 0, + 0, + # op_sel_a + scale_a (1.0f as i32 bits) + 0x3F800000, + # op_sel_b + scale_b (1.0f as i32 bits) + 0, + 0x3F800000, + ], + ) + return accs_in + + +def block_mfma_16x16( + accs_in, + b_tile_in, + lds_base, + col_offset_base_bytes, + row_a_lds, + lds_load_packs_k64, + *, + mfma_fn, + mfma_res_ty, + k_unroll, + num_acc_n, + m_repeat, + a0_prefetch=None, +): + def mfma_step(acc_in, a, b): + return mfma_fn(mfma_res_ty, [a, b, acc_in, 0, 0, 0]) + + # "K64-byte wrapper": two back-to-back MFMA/WMMA ops using the two 8B halves. + def mfma_k64_bytes(acc_in, a0, a1, b0, b1): + acc_mid = mfma_step(acc_in, a0, b0) + return mfma_step(acc_mid, a1, b1) + + for ku in range_constexpr(k_unroll): + b_packs0, b_packs1 = b_tile_in[ku] + # Byte-addressed K stepping (64B per ku). + ki64 = ku * 64 + col_base = col_offset_base_bytes + ki64 + for mi in range_constexpr(m_repeat): + mi_val = arith.constant(mi * 16, index=True) + curr_row_a_lds = row_a_lds + mi_val + if (a0_prefetch is not None) and (ku == 0) and (mi == 0): + a0, a1 = a0_prefetch + else: + a0, a1 = lds_load_packs_k64(curr_row_a_lds, col_base, lds_base) + for ni in range_constexpr(num_acc_n): + acc_idx = mi * num_acc_n + ni + accs_in[acc_idx] = mfma_k64_bytes( + accs_in[acc_idx], + a0, + a1, + b_packs0[ni], + b_packs1[ni], + ) + return accs_in __all__ = [ "PreshuffleBLayout", + "MfmaPipeline", + "EpilogPipeline", + "mfma_pipeline_dicts", + "PreshufflePipelineManager", "buffer_copy_gmem16_dwordx4", "lds_load_pack_k32", "lds_store_4b_xor16", @@ -521,5 +1032,8 @@ def lds_load_pack_k32( "make_preshuffle_scale_layout", "load_b_pack_k32", "tile_chunk_coord_i32", + "block_mfma_block_scale_f8f6f4", + "block_mfma_PTPC_f8f6f4", + "block_mfma_16x16", ] diff --git a/kernels/moe_gemm_2stage.py b/kernels/moe_gemm_2stage.py index 43703400..682c6c6d 100644 --- a/kernels/moe_gemm_2stage.py +++ b/kernels/moe_gemm_2stage.py @@ -25,6 +25,8 @@ from flydsl.dialects.ext import arith, gpu, buffer_ops, llvm, vector, rocdl, scf, memref from kernels.mfma_preshuffle_pipeline import ( + MfmaPipeline, + EpilogPipeline, buffer_copy_gmem16_dwordx4, lds_load_pack_k32, lds_store_4b_xor16, @@ -33,11 +35,15 @@ make_preshuffle_b_layout, load_b_pack_k32, tile_chunk_coord_i32, + PreshufflePipelineManager, + mfma_pipeline_dicts, + make_preshuffle_scale_layout, ) from kernels.mfma_epilogues import c_shuffle_epilog, default_epilog, mfma_epilog from kernels.kernels_common import stream_ptr_to_async_token +####==================== gemm1 pipeline start =====================### @functools.lru_cache(maxsize=1024) def compile_moe_gemm1( *, @@ -50,9 +56,12 @@ def compile_moe_gemm1( tile_k: int, # NOTE: aiter swap passes these for API symmetry; stage1 uses dynamic memrefs so they are ignored. doweight_stage1: bool, - in_dtype: str = "fp8", + x_dtype: str = "fp8", + w_dtype: str = "fp8", out_dtype: str = "f16", use_cshuffle_epilog: bool | None = None, + act: str = "silu", + enable_bias: bool = False, ): """Compile stage1 kernel (`moe_gemm1`) and return the compiled executable. @@ -62,56 +71,99 @@ def compile_moe_gemm1( - "int8": X/W are int8 - "int4": W4A8 path: X is int8, W is packed int4 (2 values per byte) unpacked to int8 in-kernel """ - gpu_arch = get_hip_arch() allocator = SmemAllocator(None, arch=gpu_arch) _state = {} - if in_dtype not in ("fp8", "fp16", "int8", "int4"): - raise ValueError(f"in_dtype must be one of ('fp8','fp16','int8','int4'), got {in_dtype!r}") - is_f16 = in_dtype == "fp16" - elem_bytes = 2 if is_f16 else 1 - if out_dtype not in ("f16", "bf16"): - raise ValueError(f"out_dtype must be 'f16' or 'bf16', got {out_dtype!r}") - # NOTE: don't materialize MLIR types outside an active MLIR Context. - out_mlir = lambda: (T.f16() if out_dtype == "f16" else T.bf16()) - tile_k_bytes = int(tile_k) * int(elem_bytes) + total_threads = 256 + + pipeline_manager = PreshufflePipelineManager(x_dtype, w_dtype, out_dtype, use_cshuffle_epilog) + pipeline_manager.check_type_valid() + + epilog_pipeline = pipeline_manager.get_epilog_pipeline() + mfma_pipeline = pipeline_manager.get_mfma_pipeline() + mfma_fn = pipeline_manager.get_mfma_fn() + + x_elem_bytes = pipeline_manager.get_a_elem_bytes() + w_elem_bytes = pipeline_manager.get_b_elem_bytes() + out_elem_bytes = pipeline_manager.get_out_elem_bytes() + + x_elem_pack = pipeline_manager.a_elem_pack + w_elem_pack = pipeline_manager.b_elem_pack + + tile_k_bytes = int(tile_k) * int(x_elem_bytes) # K64-byte micro-step: always 64 bytes per `ku`. For fp16 this is 32 elements. if (tile_k_bytes % 64) != 0: raise ValueError( f"tile_k_bytes must be divisible by 64, got tile_k_bytes={tile_k_bytes} " - f"(tile_k={tile_k}, elem_bytes={elem_bytes})" + f"(tile_k={tile_k}, elem_bytes={x_elem_bytes})" ) - is_int4 = in_dtype == "int4" - # INT4 here means W4A8: X is int8, W is packed int4 and unpacked to int8 in-kernel. - is_int8 = (in_dtype == "int8") or is_int4 - - mfma_i32_k32 = None - if is_int8: - mfma_i32_k32 = getattr(rocdl, "mfma_i32_16x16x32i8", None) or getattr( - rocdl, "mfma_i32_16x16x32_i8", None - ) - if mfma_i32_k32 is None: - raise AttributeError( - "INT8 K32 MFMA op not found: expected `rocdl.mfma_i32_16x16x32i8` " - "(or `rocdl.mfma_i32_16x16x32_i8`)." - ) + + def _get_mfma_dict_value(key, pipeline): + value = mfma_pipeline_dicts[key][pipeline] + return value() if callable(value) else value + + def _x_elem_type(): + return _get_mfma_dict_value("a_elem_type", mfma_pipeline) + def _w_elem_type(): + return _get_mfma_dict_value("b_elem_type", mfma_pipeline) + def _scale_elem_type(): + return _get_mfma_dict_value("scale_elem_type", mfma_pipeline) + def _out_elem_type(): + return _get_mfma_dict_value("out_elem_type", epilog_pipeline) + def _x_vec16_type(): + return _get_mfma_dict_value("x_vec16_type", mfma_pipeline) + def _w_vec16_type(): + return _get_mfma_dict_value("w_vec16_type", mfma_pipeline) + def _mfma_input_pack_ty(): + return _get_mfma_dict_value("mfma_input_pack_ty", mfma_pipeline) + def _mfma_output_pack_ty(): + return _get_mfma_dict_value("mfma_output_pack_ty", mfma_pipeline) + + is_f16_or_bf16 = mfma_pipeline in [MfmaPipeline.F16F16_16x16_PIPELINE, + MfmaPipeline.BF16BF16_16x16_PIPELINE] + is_int4 = mfma_pipeline in [MfmaPipeline.I8I4_16x16_PIPELINE] + is_int8 = mfma_pipeline in [MfmaPipeline.I8I8_16x16_PIPELINE] + is_fp4 = mfma_pipeline in [MfmaPipeline.F4F4_MXFP4_PIPELINE, + MfmaPipeline.F8F4_MXFP4_PIPELINE] + is_fp8_a = mfma_pipeline == MfmaPipeline.F8F4_MXFP4_PIPELINE + + is_int_mode = is_int8 or is_int4 + + no_epilogue_dequant = is_fp4 or is_f16_or_bf16 + + pack_M = 2 if is_fp4 else 1 + pack_N = 2 if is_fp4 else 1 + pack_K = 2 if is_fp4 else 1 + + # FP4 specific parameters for mfma_scale_f32_16x16x128_f8f6f4 (gemm1) + # cbsz encodes A matrix type: 0 for fp8, 4 for fp4 + cbsz_g1 = 0 if mfma_pipeline == MfmaPipeline.F8F4_MXFP4_PIPELINE else 4 + blgp_g1 = 4 + + is_gate_up_inter = is_fp4 + + is_cast_out = out_dtype == "fp8" + if is_cast_out and not use_cshuffle_epilog: + raise ValueError("out_dtype='fp8' requires CShuffle epilogue (set use_cshuffle_epilog=True).") DYN = ir.ShapedType.get_dynamic_size() size_out = DYN size_x = DYN + # W is packed int4 for W4A8: 2 values per byte. - size_w = (experts * (2 * inter_dim) * model_dim) // 2 if is_int4 else (experts * (2 * inter_dim) * model_dim) + size_w = (experts * (2 * inter_dim) * model_dim) // w_elem_pack size_sorted = DYN size_expert_ids = DYN total_threads = 256 - bytes_x_per_tile = int(tile_m) * int(tile_k) * int(elem_bytes) + bytes_x_per_tile = int(tile_m) * int(tile_k) * int(x_elem_bytes) // x_elem_pack if bytes_x_per_tile % total_threads != 0: raise ValueError( "tile_m*tile_k*elem_bytes must be divisible by " - f"{total_threads}: tile_m={tile_m}, tile_k={tile_k}, elem_bytes={elem_bytes}" + f"{total_threads}: tile_m={tile_m}, tile_k={tile_k}, elem_bytes={x_elem_bytes}" ) + bytes_per_thread_x = bytes_x_per_tile // total_threads # Keep MoE stage1 X gmem->LDS pipeline consistent with the optimized GEMM kernel: # split into <=16B pieces and use `flir.copy(load-only)` for buffer_load_dwordx4. @@ -125,17 +177,17 @@ def compile_moe_gemm1( lds_stride = tile_k + pad_k if use_cshuffle_epilog is None: use_cshuffle_epilog = os.environ.get("FLIR_MOE_STAGE1_CSHUFFLE", "1") in ("1", "true", "True", "YES", "yes") + use_cshuffle_epilog = bool(use_cshuffle_epilog) - if out_dtype != "f16" and use_cshuffle_epilog: - raise ValueError("stage1 cshuffle epilog currently supports only f16 output (out_dtype='f16')") + # if out_dtype != "f16" and use_cshuffle_epilog: + # raise ValueError("stage1 cshuffle epilog currently supports only f16 output (out_dtype='f16')") epilog_tag = "cshuffle" if use_cshuffle_epilog else "direct" # IMPORTANT: module name participates in FlyDSL's compile cache key. # Keep an explicit ABI tag so signature changes can't accidentally reuse an old binary. module_name = ( - f"mfma_moe1_{in_dtype}_{out_dtype}_{epilog_tag}" + f"mfma_moe1_{x_dtype}_{w_dtype}_{out_dtype}_{epilog_tag}" f"_t{tile_m}x{tile_n}x{tile_k}" - f"_abi3" # also mask sentinel token ids on loads (X/scale_x) to avoid illegal address faults ).replace("-", "_") class _MOE1(flir.MlirModule): @@ -150,50 +202,56 @@ def init_gpu_module(self): # - ping-pong X tiles (2 * tile_m * lds_stride bytes; fp8/int8) # - epilogue CShuffle tile (tile_m * tile_n f16 -> 2 * tile_m * tile_n bytes) _use_cshuffle_epilog = bool(use_cshuffle_epilog) - lds_x_bytes = 2 * int(tile_m) * int(lds_stride) * int(elem_bytes) - lds_out_bytes = 2 * tile_m * tile_n if _use_cshuffle_epilog else 0 + lds_x_bytes = 2 * int(tile_m) * int(lds_stride) * int(x_elem_bytes) + lds_out_elem_bytes = 2 if out_dtype in ["fp16", "bf16"] else 4 + lds_out_bytes = lds_out_elem_bytes * tile_m * tile_n if _use_cshuffle_epilog else 0 lds_total_bytes = max(lds_x_bytes, lds_out_bytes) - lds_total_elems = lds_total_bytes if elem_bytes == 1 else (lds_total_bytes // 2) - x_lds_elem = I.f16 if is_f16 else (I.i8 if is_int8 else I.f8) + lds_total_elems = lds_total_bytes if x_elem_bytes == 1 else (lds_total_bytes // 2) + # x_lds_elem = I.f16 if is_f16 else (I.i8 if is_int8 else I.f8) + x_lds_elem = I.f16 if is_f16_or_bf16 else (I.i8 if is_int_mode else I.f8) _state["lds_x_decl"] = allocator.allocate_array(x_lds_elem, lds_total_elems) allocator.finalize() @flir.kernel def moe_gemm1( self: flir.T.i64, - arg_out: lambda: T.memref(DYN, out_mlir()), - arg_x: lambda: T.memref(DYN, I.f16 if is_f16 else (I.i8 if is_int8 else I.f8)), - arg_w: lambda: T.memref(DYN, I.f16 if is_f16 else (I.i8 if is_int8 else I.f8)), - arg_scale_x: lambda: T.memref(DYN, T.f32()), - arg_scale_w: lambda: T.memref(experts * (2 * inter_dim), T.f32()), + arg_out: lambda: T.memref(DYN, _out_elem_type()), + arg_x: lambda: T.memref(DYN, _x_elem_type()), + arg_w: lambda: T.memref(DYN, _w_elem_type()), + arg_scale_x: lambda: T.memref(DYN, _scale_elem_type()), + arg_scale_w: lambda: T.memref(DYN, _scale_elem_type()), arg_sorted_token_ids: lambda: T.memref(DYN, T.i32()), arg_expert_ids: lambda: T.memref(DYN, T.i32()), arg_sorted_weights: lambda: T.memref(DYN, T.f32()), arg_max_token_ids: lambda: T.memref(DYN, T.i32()), + arg_bias: lambda: T.memref(DYN, T.f32()), tokens_in: lambda: T.index(), inter_in: lambda: T.index(), k_in: lambda: T.index(), size_expert_ids_in: lambda: T.index(), ): - x_elem = I.f16 if is_f16 else (I.i8 if is_int8 else I.f8) + x_elem = _x_elem_type() # For int4, weights are stored as packed bytes (i8) and unpacked to i8 packs. - w_elem = I.f16 if is_f16 else (I.i8 if is_int8 else I.f8) + w_elem = _w_elem_type() f16 = I.f16 f32 = I.f32 i32 = I.i32 i64 = I.i64 vec4_f32 = I.vec(4, f32) vec4_i32 = I.vec(4, i32) + vec1_f32 = I.vec(1, f32) vec1_f16 = I.vec(1, f16) vec4_f16 = I.vec(4, f16) - vec16_elems = 16 if elem_bytes == 1 else 8 - vec8_elems = 8 if elem_bytes == 1 else 4 - vec4_elems = 4 if elem_bytes == 1 else 2 - vec8_x = I.vec(vec8_elems, x_elem) - vec16_x = I.vec(vec16_elems, x_elem) + vec16_x_elems = 16 // x_elem_bytes # if x_elem_bytes == 1 else 8 + vec8_x_elems = 8 // x_elem_bytes # if x_elem_bytes == 1 else 4 + vec4_x_elems = 4 // x_elem_bytes # if x_elem_bytes == 1 else 2 + vec8_x = I.vec(vec8_x_elems, x_elem) + vec16_x = I.vec(vec16_x_elems, x_elem) vec1_i64 = I.vec(1, i64) vec2_i64 = I.vec(2, i64) + placeholder = arith.constant(0, index=True) + def silu(x): # device fast path: # emu = exp(-x) ~= exp2(log2e * (-x)) -> v_exp_f32 @@ -207,12 +265,23 @@ def silu(x): den = 1.0 + emu sig = llvm.call_intrinsic(f32, "llvm.amdgcn.rcp.f32", [den], [], []) return x * sig + + def swiglu(gate, up, alpha=1.702, limit=7.0): + # Align with CK's device fast path + # + # Using llvm.amdgcn intrinsics prevents lowering to the div_scale/div_fixup + # sequences that introduce extra compares/cndmasks. + gate = arith.minimum(gate, limit) + up = arith.minimum(up, limit) + up = arith.maximum(up, -limit) - acc_init = ( - arith.constant_vector(0, vec4_i32) - if is_int8 - else arith.constant_vector(0.0, vec4_f32) - ) + t = gate * (alpha) * (-1.4426950408889634) # -log2(e) + emu = llvm.call_intrinsic(f32, "llvm.amdgcn.exp2.f32", [t], [], []) + den = 1.0 + emu + sig = llvm.call_intrinsic(f32, "llvm.amdgcn.rcp.f32", [den], [], []) + return gate * sig * (up + 1) + + acc_init = (arith.constant_vector(0, _mfma_output_pack_ty())) # Layouts layout_x = flir.make_layout((tokens_in, k_in), stride=(k_in, 1)) @@ -221,15 +290,22 @@ def silu(x): c_n_total = arith.constant(experts * (2 * inter_dim), index=True) kpack_bytes = 8 if is_int4 else 16 b_layout = make_preshuffle_b_layout( - flir, arith, c_n=c_n_total, c_k=k_in, kpack_bytes=kpack_bytes, elem_bytes=elem_bytes + flir, arith, c_n=c_n_total, c_k=k_in // pack_K, kpack_bytes=kpack_bytes, elem_bytes=w_elem_bytes ) layout_b = b_layout.layout_b - c_k0 = (k_in * arith.constant(int(elem_bytes), index=True)) / arith.index(64) shape_lds = flir.make_shape(tile_m, tile_k) stride_lds = flir.make_stride(lds_stride, 1) layout_lds = flir.make_layout(shape_lds, stride_lds) + # A&B's scale preshuffle layout + layout_a_scale = make_preshuffle_scale_layout( + flir, arith, c_mn=tokens_in, c_k=k_in, + ) + layout_b_scale = make_preshuffle_scale_layout( + flir, arith, c_mn=c_n_total, c_k=k_in, + ) + tx = gpu.thread_id("x") # Align with Aiter launch mapping (NSwizzle==false): # - blockIdx.x -> N dimension (tile along inter_dim) @@ -240,26 +316,30 @@ def silu(x): # Block validity: compute as early as possible so invalid blocks skip all buffer-resource # setup, LDS pointer math, and gmem prefetch work. bx_m = bx * arith.constant(tile_m, index=True) + maxids_rsrc = buffer_ops.create_buffer_resource( arg_max_token_ids, max_size=False, num_records_bytes=arith.i32(4) ) max_token_id_i32 = buffer_ops.buffer_load( maxids_rsrc, arith.constant(0, index=True), vec_width=1, dtype=i32 ) + bx_m_i32 = arith.index_cast(i32, bx_m) blk_valid = arith.cmpu(bx_m_i32, max_token_id_i32, "ult") # Common constants/atoms (hoisted): keep IR small like GEMM. # XOR16 swizzle parameter (in bytes; constant, power-of-two in our configs). k_blocks16 = arith.constant(tile_k_bytes // 16, index=True) - atom_x_s16 = flir.make_copy_atom(x_elem, vector_size=vec16_elems) - atom_x_s8 = flir.make_copy_atom(x_elem, vector_size=8) - atom_x_s4 = flir.make_copy_atom(x_elem, vector_size=4) - atom_x_g2r16 = flir.make_copy_atom(x_elem, vector_size=vec16_elems) - atom_x_g2r8 = flir.make_copy_atom(x_elem, vector_size=vec8_elems) - atom_x_g2r4 = flir.make_copy_atom(x_elem, vector_size=vec4_elems) + atom_x_s16 = flir.make_copy_atom(x_elem, vector_size=vec16_x_elems) + atom_x_s8 = flir.make_copy_atom(x_elem, vector_size=vec8_x_elems) + atom_x_s4 = flir.make_copy_atom(x_elem, vector_size=vec4_x_elems) + atom_x_g2r16 = flir.make_copy_atom(x_elem, vector_size=vec16_x_elems) + atom_x_g2r8 = flir.make_copy_atom(x_elem, vector_size=vec8_x_elems) + atom_x_g2r4 = flir.make_copy_atom(x_elem, vector_size=vec4_x_elems) layout_tx_wave_lane = flir.make_layout((4, 64), stride=(64, 1)) layout_lane16 = flir.make_layout((4, 16), stride=(16, 1)) + bias_rsrc = buffer_ops.create_buffer_resource(arg_bias, max_size=False) if enable_bias else None + # Everything below is gated by `blk_valid` to avoid doing buffer-resource setup and # gmem work for padding blocks. _if_blk = scf.IfOp(blk_valid) @@ -270,7 +350,7 @@ def silu(x): # Alias LDS bytes as fp16 for optional CShuffle epilogue. _use_cshuffle_epilog = bool(use_cshuffle_epilog) lds_out = ( - SmemPtr(base_ptr, lds_x_ptr.byte_offset, I.f16, shape=(tile_m * tile_n,)).get() + SmemPtr(base_ptr, lds_x_ptr.byte_offset, I.f16 if not is_cast_out else I.f32, shape=(tile_m * tile_n,)).get() if _use_cshuffle_epilog else None ) @@ -280,7 +360,7 @@ def silu(x): c_topk = arith.constant(topk, index=True) # X: [tokens, k] bytes = tokens*k*elem_bytes - x_nbytes_idx = tokens_in * k_in * arith.constant(int(elem_bytes), index=True) + x_nbytes_idx = tokens_in * k_in * arith.constant(int(x_elem_bytes), index=True) x_nbytes_i32 = arith.index_cast(i32, x_nbytes_idx) x_rsrc = buffer_ops.create_buffer_resource( arg_x, max_size=False, num_records_bytes=x_nbytes_i32 @@ -289,19 +369,23 @@ def silu(x): w_rsrc = buffer_ops.create_buffer_resource(arg_w, max_size=False) # OUT: [tokens, topk, inter] f16/bf16 -> bytes = tokens*topk*inter*out_elem_bytes - out_elem_bytes = 2 # f16/bf16 out_nbytes_idx = tokens_in * c_topk * inter_in * arith.constant(out_elem_bytes, index=True) out_nbytes_i32 = arith.index_cast(i32, out_nbytes_idx) out_rsrc = buffer_ops.create_buffer_resource( arg_out, max_size=False, num_records_bytes=out_nbytes_i32 ) - # fp16 path ignores scales completely (implicit scale=1.0). - if is_f16: + # fp16/bf16 path ignores scales completely (implicit scale=1.0). + # FP4 path uses scales in mfma_scale instruction, so we still need scale resources. + if is_f16_or_bf16: sx_rsrc = None sw_rsrc = None + elif is_fp4: + # FP4: scale is used in mfma_scale_f32_16x16x128_f8f6f4 instruction + sx_rsrc = buffer_ops.create_buffer_resource(arg_scale_x, max_size=False) + sw_rsrc = buffer_ops.create_buffer_resource(arg_scale_w, max_size=False) else: - # scale_x: [tokens] f32 -> bytes = tokens*4 + # fp8/int8 path: scale_x: [tokens] f32 -> bytes = tokens*4 sx_nbytes_idx = tokens_in * arith.constant(4, index=True) sx_nbytes_i32 = arith.index_cast(i32, sx_nbytes_idx) sx_rsrc = buffer_ops.create_buffer_resource( @@ -327,7 +411,7 @@ def silu(x): # ---- X gmem->reg prefetch (match preshuffle GEMM mapping) ---- # Prefer 16B buffer-load (dwordx4). If the per-thread byte count isn't divisible by # 16, fall back to 8B (dwordx2) or 4B (dword) loads. For fp16 we require 16B. - if is_f16: + if is_f16_or_bf16: if bytes_per_thread_x % 16 != 0: raise ValueError( f"[fp16] bytes_per_thread_x ({bytes_per_thread_x}) must be divisible by 16" @@ -347,9 +431,9 @@ def silu(x): num_x_loads = bytes_per_thread_x // x_load_bytes chunk_i32 = x_load_bytes // 4 # dwords per chunk (1/2/4) - c_k_div4 = (k_in * arith.constant(int(elem_bytes), index=True)) / arith.index(4) + c_k_div4 = (k_in * arith.constant(int(x_elem_bytes), index=True)) / arith.index(4) layout_x_div4 = flir.make_layout((tokens_in, c_k_div4), stride=(c_k_div4, 1)) - tile_k_dwords = (int(tile_k) * int(elem_bytes)) // 4 + tile_k_dwords = (int(tile_k) * int(x_elem_bytes)) // 4 layout_x_tile_div4 = flir.make_layout((tile_m, tile_k_dwords), stride=(tile_k_dwords, 1)) c_chunk_i32 = arith.constant(chunk_i32, index=True) tx_i32_base = tx * c_chunk_i32 @@ -384,11 +468,7 @@ def x_tile_chunk_coord_i32(i: int): fused_i = buffer_ops.buffer_load(sorted_rsrc, sorted_row_i, vec_width=1, dtype=i32) t_raw = arith.andi(fused_i, mask24) t_idx = arith.index_cast(ir.IndexType.get(), t_raw) - # NOTE: aiter moe_sorting uses sentinel token_id == tokens for padding. - # Do NOT rely on buffer OOB semantics for X loads; explicitly mask to a safe row. - t_valid_i32 = arith.cmpu(t_raw, tokens_i32, "ult") - t_safe = arith.select(t_valid_i32, t_idx, arith.index(0)) - x_row_base_div4.append(t_safe * c_k_div4) + x_row_base_div4.append(t_idx * c_k_div4) vec1_i32 = I.vec(1, i32) vec2_i32 = I.vec(2, i32) @@ -401,7 +481,7 @@ def load_x(idx_i32): For 16B, keep the fast dwordx4 path. For 8B/4B, use byte offsets. """ if x_load_bytes == 16: - idx_elem = idx_i32 if elem_bytes == 1 else (idx_i32 * arith.index(2)) + idx_elem = idx_i32 if x_elem_bytes == 1 else (idx_i32 * arith.index(2)) return buffer_copy_gmem16_dwordx4( flir, arg=arg_x, @@ -409,7 +489,7 @@ def load_x(idx_i32): idx_i32=idx_elem, atom_g2r16=atom_x_g2r16, rsrc=x_rsrc, - vec_elems=vec16_elems, + vec_elems=vec16_x_elems, ) idx_bytes = idx_i32 * arith.index(4) atom = atom_x_g2r8 if x_load_bytes == 8 else atom_x_g2r4 @@ -432,11 +512,12 @@ def load_x(idx_i32): def load_x_tile(base_k): """Prefetch the per-thread X tile portion (gmem -> regs) for a given K base (in elements).""" - base_k_div4 = (base_k * arith.constant(int(elem_bytes), index=True)) / arith.index(4) + base_k_div4 = (base_k * arith.constant(int(x_elem_bytes), index=True)) / arith.index(4) parts = [] for i in range_constexpr(num_x_loads): idx_i32 = x_row_base_div4[i] + base_k_div4 + x_col_local_i32[i] x_vec = load_x(idx_i32) + # x_vec = load_x(0) if x_load_bytes == 16: parts.append(vector.bitcast(vec4_i32, x_vec)) elif x_load_bytes == 8: @@ -458,8 +539,8 @@ def load_x_tile(base_k): col_offset_base = flir.crd2idx(flir.make_coord(lane_div_16, 0), layout_lane16) col_offset_base_bytes = ( col_offset_base - if elem_bytes == 1 - else (col_offset_base * arith.constant(int(elem_bytes), index=True)) + if x_elem_bytes == 1 + else (col_offset_base * arith.constant(int(x_elem_bytes), index=True)) ) # Dynamic N tiling within block (same as existing kernels) @@ -476,6 +557,9 @@ def load_x_tile(base_k): n_blk_gate = [] n_intra_up = [] n_blk_up = [] + # Single n_blk_list for interleaved mode (matching mixed_moe_gemm_2stage.py) + n_blk_list = [] + n_intra_list = [] col_g_list = [] inter_idx = arith.constant(inter_dim, index=True) # layout for (row -> (blk,intra)) where intra is 0..15 @@ -484,10 +568,23 @@ def load_x_tile(base_k): for ni in range_constexpr(num_acc_n): offset = arith.constant(ni * 16, index=True) col_g = by_n + n_tile_base - col_g = col_g + offset + col_g = col_g // 2 + offset if is_gate_up_inter else col_g + offset col_g = col_g + lane_mod_16 col_g_list.append(col_g) + # For both interleaved and non-interleaved layout, global_n maps to physical layout + global_n = by_n + n_tile_base + offset + lane_mod_16 + row_w = expert_off_idx + global_n + + coord_n = flir.idx2crd(row_w, layout_n_blk_intra) + n_blk = flir.get(coord_n, 0) + n_intra = flir.get(coord_n, 1) + + # For interleaved mode, use single list + n_blk_list.append(n_blk) + n_intra_list.append(n_intra) + + # For non-interleaved mode, keep separate gate/up lists row_gate = expert_off_idx + col_g row_up = row_gate + inter_idx @@ -501,7 +598,12 @@ def load_x_tile(base_k): m_repeat = tile_m // 16 - k_unroll = tile_k_bytes // 64 # K64-byte micro-step (2x MFMA) + k_unroll = tile_k_bytes // 64 // pack_K + + # FP4 packed parameters for mfma_scale (gemm1) + k_unroll_packed = k_unroll // pack_K + m_repeat_packed = m_repeat // pack_M + num_acc_n_packed = num_acc_n // pack_N # --- B Load Logic (K64) - shared layout with preshuffle GEMM --- def load_b_pack(base_k, ki_step, ni, blk_list, intra_list): @@ -520,7 +622,7 @@ def load_b_pack(base_k, ki_step, ni, blk_list, intra_list): lane_div_16=lane_div_16, # 0..3 elem_type=w_elem, kpack_bytes=kpack_bytes, - elem_bytes=elem_bytes, + elem_bytes=w_elem_bytes, unpack_int4=is_int4, ) @@ -543,9 +645,100 @@ def load_b_tile(base_k, blk_list, intra_list): packs1.append(b1) b_tile.append((packs0, packs1)) return b_tile + + # --- FP4 Scale Load Functions for gemm1 --- + def load_scale_inter(arg_scale, rsrc, layout, ku, mni): + """Load a single scale value for FP4 mfma_scale instruction (gemm1).""" + k_lane = lane_div_16 + n_lane = lane_mod_16 + coord_pack = flir.make_coord(mni, ku, k_lane, n_lane) + idx_pack = flir.crd2idx(coord_pack, layout) + scale_view = flir.TensorView( + arg_scale, + (1,), + strides=(1,), + base_indices=(idx_pack,), + element_type=_scale_elem_type(), + ) + scale = flir.copy( + flir.make_copy_atom(_scale_elem_type(), vector_size=1), + scale_view, + None, + alignment=8, + return_vector=True, + src_buffer_resource=rsrc, + src_buffer_offset_in_bytes=False, + ) + return scale + + def load_b_scale_tile_inter(base_k): + """Load B scale tile for FP4 pipeline (gemm1).""" + b_scale_tile = [] + for ku in range_constexpr(k_unroll_packed): + for ni in range_constexpr(num_acc_n_packed): + scale = load_scale_inter( + arg_scale_w, + sw_rsrc, + layout_b_scale, + ku + base_k, + ni + (expert_off_idx + by_n + n_tile_base) // pack_N // 16, + ) + b_scale_tile.append(scale) + return b_scale_tile + + def prefetch_ab_scale_tile_inter(base_k): + """Prefetch A and B scale tiles for FP4 pipeline (gemm1).""" + return placeholder, load_b_scale_tile_inter(base_k) + + # --- FP4 B Load Logic for gemm1 (interleaved gate+up) --- + def load_b_packs_k64_inter(base_k, ku: int, ni: int): + """Load B pack for FP4 pipeline with interleaved gate+up (gemm1).""" + base_k_bytes = base_k * arith.constant(int(w_elem_bytes), index=True) + k0_base = base_k_bytes / 64 + k0 = k0_base + ku + k1 = lane_div_16 + # Use n_blk_list directly for interleaved layout + coord_pack = flir.make_coord(n_blk_list[ni], k0, k1, n_intra_list[ni], 0) + idx_pack = flir.crd2idx(coord_pack, layout_b) + + vec_elems = 16 + b_view = flir.TensorView( + arg_w, + (vec_elems,), + strides=(1,), + base_indices=(idx_pack,), + element_type=_w_elem_type(), + ) + b16 = flir.copy( + flir.make_copy_atom(_w_elem_type(), vector_size=vec_elems), + b_view, + None, + alignment=8, + return_vector=True, + src_buffer_resource=(w_rsrc if w_elem_bytes == 1 else None), + src_buffer_offset_in_bytes=(w_elem_bytes == 1), + ) + # Split 16B pack into two 8B halves. + b_i64x2 = vector.bitcast(I.i64x2, b16) + b0_i64 = vector.extract(b_i64x2, static_position=[0], dynamic_position=[]) + b1_i64 = vector.extract(b_i64x2, static_position=[1], dynamic_position=[]) + return b0_i64, b1_i64 + + def load_b_tile_inter(base_k): + """Load B tile for FP4 pipeline with interleaved gate+up (gemm1).""" + b_tile = [] + for ku in range_constexpr(k_unroll): + packs0 = [] + packs1 = [] + for ni in range_constexpr(num_acc_n): # Match mixed_moe_gemm_2stage.py + b0, b1 = load_b_packs_k64_inter(base_k, ku, ni) + packs0.append(b0) + packs1.append(b1) + b_tile.append((packs0, packs1)) + return b_tile - acc_gate = [acc_init] * (num_acc_n * m_repeat) - acc_up = [acc_init] * (num_acc_n * m_repeat) + acc_gate = [acc_init] * (num_acc_n * m_repeat) if not is_gate_up_inter else [acc_init] * (num_acc_n // 2 * m_repeat) + acc_up = [acc_init] * (num_acc_n * m_repeat) if not is_gate_up_inter else [acc_init] * (num_acc_n // 2 * m_repeat) # ---- Pipeline helpers: store X tile to LDS with ping-pong base ---- def store_x_tile_to_lds(vec_x_in_parts, lds_base): @@ -568,7 +761,7 @@ def store_x_tile_to_lds(vec_x_in_parts, lds_base): k_blocks16=k_blocks16, lds_base=lds_base, vec_part_i32x4=vec_x_in_parts[i], - elem_bytes=elem_bytes, + elem_bytes=x_elem_bytes, ) elif x_load_bytes == 8: lds_store_8b_xor16( @@ -610,8 +803,8 @@ def lds_load_packs_k64(curr_row_a_lds, col_base_bytes, lds_base): col_base_swz_bytes = flir.swizzle_xor16(curr_row_a_lds, col_base_bytes, k_blocks16) col_base_swz = ( col_base_swz_bytes - if elem_bytes == 1 - else (col_base_swz_bytes / arith.constant(int(elem_bytes), index=True)) + if x_elem_bytes == 1 + else (col_base_swz_bytes / arith.constant(int(x_elem_bytes), index=True)) ) coord_a16 = flir.make_coord(curr_row_a_lds, col_base_swz) idx_a16 = flir.crd2idx(coord_a16, layout_lds) @@ -634,17 +827,12 @@ def compute_tile( ): gate_list = list(acc_gate_in) up_list = list(acc_up_in) - mfma_res_ty = vec4_i32 if is_int8 else vec4_f32 - mfma_fn = ( - mfma_i32_k32 - if is_int8 - else (rocdl.mfma_f32_16x16x16f16 if is_f16 else rocdl.mfma_f32_16x16x32_fp8_fp8) - ) + mfma_res_ty = _mfma_output_pack_ty() # Optional: prefetch epilogue scales while we are about to run the last MFMA tile, # matching the preshuffle GEMM pattern of overlapping scale loads with MFMA. epilogue_pf = None - if prefetch_epilogue: + if prefetch_epilogue and not no_epilogue_dequant: expert_off_pf = expert_off_idx sw_gate_pf = [] sw_up_pf = [] @@ -653,14 +841,10 @@ def compute_tile( row_gate_idx = expert_off_pf + col_g row_up_idx = row_gate_idx + inter_idx sw_gate_pf.append( - arith.f32(1.0) - if is_f16 - else buffer_ops.buffer_load(sw_rsrc, row_gate_idx, vec_width=1, dtype=f32) + buffer_ops.buffer_load(sw_rsrc, row_gate_idx, vec_width=1, dtype=f32) ) sw_up_pf.append( - arith.f32(1.0) - if is_f16 - else buffer_ops.buffer_load(sw_rsrc, row_up_idx, vec_width=1, dtype=f32) + buffer_ops.buffer_load(sw_rsrc, row_up_idx, vec_width=1, dtype=f32) ) epilogue_pf = (sw_gate_pf, sw_up_pf) @@ -669,7 +853,7 @@ def _i64_to_v4f16(x_i64): return vector.bitcast(vec4_f16, v1) def mfma_k64(acc_in, a0, a1, b0, b1): - if is_f16: + if is_f16_or_bf16: a0v = _i64_to_v4f16(a0) a1v = _i64_to_v4f16(a1) b0v = _i64_to_v4f16(b0) @@ -711,6 +895,108 @@ def mfma_k64(acc_in, a0, a1, b0, b1): b_up_packs1[ni], ) return gate_list, up_list, epilogue_pf + + # --- FP4 Compute Tile using mfma_scale_f32_16x16x128_f8f6f4 (gemm1) --- + def compute_tile_inter( + acc_gate_in, + acc_up_in, + b_tile_in, + lds_base, + *, + a_scale=None, + b_scale=None, + prefetch_epilogue: bool = False, + a0_prefetch=None, + ): + """FP4 compute tile using mfma_scale for interleaved gate+up (gemm1).""" + gate_list = list(acc_gate_in) + up_list = list(acc_up_in) + mfma_res_ty = _mfma_output_pack_ty() + + epilogue_pf = None + if enable_bias and prefetch_epilogue: + gate_bias = [] + up_bias = [] + for ni in range_constexpr(num_acc_n_packed): + global_n = (by_n + n_tile_base) // 2 + ni * 16 + lane_mod_16 + gate_offset = expert_off_idx + global_n + up_offset = expert_off_idx + global_n + inter_dim + gate_bias.append( + buffer_ops.buffer_load(bias_rsrc, gate_offset, vec_width=1, dtype=f32) + ) + up_bias.append( + buffer_ops.buffer_load(bias_rsrc, up_offset, vec_width=1, dtype=f32) + ) + epilogue_pf = (gate_bias, up_bias) + + c0_i64 = arith.constant(0, type=I.i64) + vec4_i64 = I.vec(4, I.i64) + vec8_i32 = I.vec(8, I.i32) + + def pack_i64x4_to_i32x8(x0, x1, x2, x3): + v4 = vector.from_elements(vec4_i64, [x0, x1, x2, x3]) + return vector.bitcast(vec8_i32, v4) + + # FP4 path using mfma_scale_f32_16x16x128_f8f6f4 + for ku128 in range_constexpr(k_unroll_packed): + for mi in range_constexpr(m_repeat_packed): + # a_scale_i32 = a_scale[ku128 * m_repeat_packed + mi] + # a_scale_val = vector.extract(a_scale_i32, static_position=[0], dynamic_position=[]) + for ni in range_constexpr(num_acc_n_packed): + b_scale_i32 = b_scale[ku128 * num_acc_n_packed + ni] if b_scale else None + b_scale_val = vector.extract(b_scale_i32, static_position=[0], dynamic_position=[]) if b_scale_i32 else arith.i32(0x3F800000) + for ikxdl in range_constexpr(pack_K): + k_idx = ku128 * pack_K + ikxdl + b_packs0, b_packs1 = b_tile_in[k_idx] + col_base = col_offset_base_bytes + (k_idx * 128) // x_elem_pack + + for imxdl in range_constexpr(pack_M): + mi_idx = mi * pack_M + imxdl + mi_val = arith.constant(mi_idx * 16, index=True) + curr_row_a_lds = row_a_lds + mi_val + + if (a0_prefetch is not None) and (k_idx == 0) and (mi_idx == 0): + a0, a1 = a0_prefetch + else: + a0, a1 = lds_load_packs_k64(curr_row_a_lds, col_base, lds_base) + + if is_fp8_a: + col_base1 = col_base + 64 + a2, a3 = lds_load_packs_k64(curr_row_a_lds, col_base1, lds_base) + a128 = pack_i64x4_to_i32x8(a0, a1, a2, a3) + else: + a128 = pack_i64x4_to_i32x8(a0, a1, c0_i64, c0_i64) + + for inxdl in range_constexpr(pack_N): + # Interleaved gate+up: even indices are gate, odd are up + if inxdl % 2 == 0: + current_list = gate_list + else: + current_list = up_list + ni_idx = ni * pack_N + inxdl + + b0 = b_packs0[ni_idx] + b1 = b_packs1[ni_idx] + b128 = pack_i64x4_to_i32x8(b0, b1, c0_i64, c0_i64) + + acc_idx = mi_idx * num_acc_n_packed + (ni_idx // 2) + + current_list[acc_idx] = rocdl.mfma_scale_f32_16x16x128_f8f6f4( + mfma_res_ty, + [ + a128, + b128, + current_list[acc_idx], + cbsz_g1, + blgp_g1, + 0, + 0x3F800000, + ikxdl * pack_N + inxdl, + b_scale_val, + ], + ) + + return gate_list, up_list, epilogue_pf # ---------------- 2-stage pipeline (ping-pong LDS + B tile prefetch) ---------------- lds_tile_elems = arith.constant(tile_m * lds_stride, index=True) @@ -721,38 +1007,81 @@ def mfma_k64(acc_in, a0, a1, b0, b1): rocdl.sched_barrier(0) def hot_loop_scheduler(): - mfma_group = num_acc_n * 2 - # K64 micro-step: 2x K32 MFMA per gemm. - mfma_total = (k_unroll * 2) * m_repeat * mfma_group - mfma_per_iter = 2 * mfma_group - sche_iters = 0 if mfma_per_iter == 0 else (mfma_total // mfma_per_iter) - + # mfma_group = num_acc_n * 2 + # # K64 micro-step: 2x K32 MFMA per gemm. + # mfma_total = (k_unroll * 2) * m_repeat * mfma_group + # mfma_per_iter = 2 * mfma_group + # sche_iters = 0 if mfma_per_iter == 0 else (mfma_total // mfma_per_iter) + + # rocdl.sched_dsrd(2) + # rocdl.sched_mfma(2) + # rocdl.sched_dsrd(1) + # rocdl.sched_mfma(1) + # rocdl.sched_dsrd(1) + # rocdl.sched_mfma(1) + + # # DS-write hints near the end: match total X LDS-store micro-ops per thread. + # dswr_tail = num_x_loads + # if dswr_tail > sche_iters: + # dswr_tail = sche_iters + # dswr_start = sche_iters - dswr_tail + # for sche_i in range_constexpr(sche_iters): + # rocdl.sched_vmem(1) + # rocdl.sched_mfma(mfma_group) + # rocdl.sched_dsrd(1) + # rocdl.sched_mfma(mfma_group) + # if sche_i >= dswr_start - 1: + # rocdl.sched_dswr(1) + # rocdl.sched_barrier(0) + rocdl.sched_dsrd(2) rocdl.sched_mfma(2) rocdl.sched_dsrd(1) + rocdl.sched_vmem(1) rocdl.sched_mfma(1) rocdl.sched_dsrd(1) + rocdl.sched_vmem(1) rocdl.sched_mfma(1) # DS-write hints near the end: match total X LDS-store micro-ops per thread. - dswr_tail = num_x_loads - if dswr_tail > sche_iters: - dswr_tail = sche_iters - dswr_start = sche_iters - dswr_tail - for sche_i in range_constexpr(sche_iters): - rocdl.sched_vmem(1) - rocdl.sched_mfma(mfma_group) - rocdl.sched_dsrd(1) - rocdl.sched_mfma(mfma_group) + dswr_start = 12 - 2 + for sche_i in range_constexpr(12): + if sche_i < 10: + rocdl.sched_vmem(1) + if sche_i <= 10: + rocdl.sched_dsrd(1) + + rocdl.sched_mfma(1) if sche_i >= dswr_start - 1: rocdl.sched_dswr(1) + rocdl.sched_barrier(0) + # ================ Unified Pipeline (FP4 / Standard) for gemm1 ================ + # Select pipeline functions based on is_fp4 + def _load_b_tiles(k_val): + if is_gate_up_inter: + return load_b_tile_inter(k_val // 2) + else: + return (load_b_tile(k_val, n_blk_gate, n_intra_gate), load_b_tile(k_val, n_blk_up, n_intra_up)) + + def _compute_tile(acc_g, acc_u, b_cur, lds_base, a_scale, b_scale, *, prefetch_epilogue=False, a0_prefetch=None): + if is_gate_up_inter: + return compute_tile_inter(acc_g, acc_u, b_cur, lds_base, a_scale=a_scale, b_scale=b_scale, prefetch_epilogue=prefetch_epilogue, a0_prefetch=a0_prefetch) + else: + b_g, b_u = b_cur + return compute_tile(acc_g, acc_u, b_g, b_u, lds_base, prefetch_epilogue=prefetch_epilogue, a0_prefetch=a0_prefetch) + + def _prefetch_scale(k_val): + if is_gate_up_inter: + return prefetch_ab_scale_tile_inter(k_val // pack_K // 128) + return placeholder, placeholder + # Prologue: prefetch tile0, store to LDS(cur), sync. k0 = arith.index(0) x_regs0 = load_x_tile(k0) - b_gate_cur = load_b_tile(k0, n_blk_gate, n_intra_gate) - b_up_cur = load_b_tile(k0, n_blk_up, n_intra_up) + b_cur = _load_b_tiles(k0) + a_scale_pong, b_scale_pong = _prefetch_scale(k0) store_x_tile_to_lds(x_regs0, lds_base_cur) gpu.barrier() @@ -767,23 +1096,25 @@ def hot_loop_scheduler(): # Unrolled ping-pong main loop (2 tiles per iteration), leaving 2 tail tiles. c2_tile_k = arith.constant(tile_k * 2, index=True) c_k_main2 = k_in - c2_tile_k + rocdl.sched_barrier(0) for k_iv in range(arith.index(0), c_k_main2, c2_tile_k): # ---- stage 0: prefetch+store ping, compute pong ---- next_k1 = k_iv + tile_k x_regs_ping = load_x_tile(next_k1) - b_gate_ping = load_b_tile(next_k1, n_blk_gate, n_intra_gate) - b_up_ping = load_b_tile(next_k1, n_blk_up, n_intra_up) + b_ping = _load_b_tiles(next_k1) + a_scale_ping, b_scale_ping = _prefetch_scale(next_k1) - acc_gate, acc_up, _ = compute_tile( + acc_gate, acc_up, _ = _compute_tile( acc_gate, acc_up, - b_gate_cur, - b_up_cur, + b_cur, lds_base_pong, + a_scale_pong, + b_scale_pong, a0_prefetch=a0_prefetch_pong, ) - a0_prefetch_pong = None + store_x_tile_to_lds(x_regs_ping, lds_base_ping) hot_loop_scheduler() gpu.barrier() @@ -794,18 +1125,18 @@ def hot_loop_scheduler(): # ---- stage 1: prefetch+store pong, compute ping ---- next_k2 = k_iv + c2_tile_k x_regs_pong = load_x_tile(next_k2) - b_gate_next = load_b_tile(next_k2, n_blk_gate, n_intra_gate) - b_up_next = load_b_tile(next_k2, n_blk_up, n_intra_up) + b_next = _load_b_tiles(next_k2) + a_scale_pong, b_scale_pong = _prefetch_scale(next_k2) - acc_gate, acc_up, _ = compute_tile( + acc_gate, acc_up, _ = _compute_tile( acc_gate, acc_up, - b_gate_ping, - b_up_ping, + b_ping, lds_base_ping, + a_scale_ping, + b_scale_ping, a0_prefetch=a0_prefetch_ping, ) - a0_prefetch_ping = None store_x_tile_to_lds(x_regs_pong, lds_base_pong) hot_loop_scheduler() gpu.barrier() @@ -814,24 +1145,23 @@ def hot_loop_scheduler(): a0_prefetch_pong = lds_load_packs_k64(row_a_lds, col_offset_base_bytes, lds_base_pong) # Advance pong state to next_k2 for next iteration. - b_gate_cur = b_gate_next - b_up_cur = b_up_next + b_cur = b_next # Tail: 2 remaining tiles at (k_in - 2*tile_k) and (k_in - tile_k). k_tail1 = k_in - tile_k x_regs_ping = load_x_tile(k_tail1) - b_gate_ping = load_b_tile(k_tail1, n_blk_gate, n_intra_gate) - b_up_ping = load_b_tile(k_tail1, n_blk_up, n_intra_up) + b_ping = _load_b_tiles(k_tail1) + a_scale_ping, b_scale_ping = _prefetch_scale(k_tail1) - acc_gate, acc_up, _ = compute_tile( + acc_gate, acc_up, _ = _compute_tile( acc_gate, acc_up, - b_gate_cur, - b_up_cur, + b_cur, lds_base_pong, + a_scale_pong, + b_scale_pong, a0_prefetch=a0_prefetch_pong, ) - a0_prefetch_pong = None store_x_tile_to_lds(x_regs_ping, lds_base_ping) hot_loop_scheduler() gpu.barrier() @@ -840,12 +1170,13 @@ def hot_loop_scheduler(): a0_prefetch_ping = lds_load_packs_k64(row_a_lds, col_offset_base_bytes, lds_base_ping) # Epilogue: compute last tile with epilogue scale prefetch to overlap loads with MFMA. - acc_gate, acc_up, epilogue_pf = compute_tile( + acc_gate, acc_up, epilogue_pf = _compute_tile( acc_gate, acc_up, - b_gate_ping, - b_up_ping, + b_ping, lds_base_ping, + a_scale_ping, + b_scale_ping, prefetch_epilogue=True, a0_prefetch=a0_prefetch_ping, ) @@ -860,7 +1191,7 @@ def hot_loop_scheduler(): if epilogue_pf is not None: sw_gate_vals, sw_up_vals = epilogue_pf - else: + elif not no_epilogue_dequant: sw_gate_vals = [] sw_up_vals = [] for ni in range_constexpr(num_acc_n): @@ -868,14 +1199,10 @@ def hot_loop_scheduler(): row_gate_idx = expert_off + col_g row_up_idx = row_gate_idx + inter_idx sw_gate_vals.append( - arith.f32(1.0) - if is_f16 - else buffer_ops.buffer_load(sw_rsrc, row_gate_idx, vec_width=1, dtype=f32) + buffer_ops.buffer_load(sw_rsrc, row_gate_idx, vec_width=1, dtype=f32) ) sw_up_vals.append( - arith.f32(1.0) - if is_f16 - else buffer_ops.buffer_load(sw_rsrc, row_up_idx, vec_width=1, dtype=f32) + buffer_ops.buffer_load(sw_rsrc, row_up_idx, vec_width=1, dtype=f32) ) # Epilogue hoists to keep IR + Python build time small: @@ -888,6 +1215,8 @@ def hot_loop_scheduler(): # Uses EVec=4 (buffer store "x4" of fp16 elements). _use_cshuffle_epilog = bool(use_cshuffle_epilog) + + num_acc_n_final = num_acc_n // 2 if is_gate_up_inter else num_acc_n if _use_cshuffle_epilog: if lds_out is None: @@ -907,50 +1236,59 @@ def write_row_to_lds( # `row` is the sorted-row index (bx_m + row_in_tile). fused2 = buffer_ops.buffer_load(sorted_rsrc, row, vec_width=1, dtype=i32) t2 = fused2 & mask24_i32 - # aiter moe_sorting uses sentinel token_id == tokens for padding. - # Do NOT rely on buffer OOB semantics for scale loads; explicitly mask. - t_valid = arith.cmpu(t2, tokens_i32_v, "ult") - sx = ( - arith.f32(1.0) - if is_f16 - else arith.select( - t_valid, - buffer_ops.buffer_load(sx_rsrc, t2, vec_width=1, dtype=f32), - arith.f32(0.0), - ) - ) + # No explicit mask: rely on buffer descriptor OOB to zero-fill when t2 is the + # sentinel (t2 == tokens) or otherwise out-of-range. + sx = arith.f32(1.0) if no_epilogue_dequant else buffer_ops.buffer_load(sx_rsrc, t2, vec_width=1, dtype=f32) # Sorted weight aligned with `row` (matches aiter moe_sorting output). if doweight_stage1: tw = buffer_ops.buffer_load(sorted_w_rsrc, row, vec_width=1, dtype=f32) - for ni in range_constexpr(num_acc_n): + for ni in range_constexpr(num_acc_n_final): col_local = col_base_local + (ni * 16) - sw_gate = sw_gate_vals[ni] - sw_up = sw_up_vals[ni] - - acc_idx = mi * num_acc_n + ni + + acc_idx = mi * num_acc_n_final + ni vg = vector.extract( acc_gate[acc_idx], static_position=[ii], dynamic_position=[] ) vu = vector.extract( acc_up[acc_idx], static_position=[ii], dynamic_position=[] ) + + if is_int_mode: + vg = arith.sitofp(T.f32, vg) + vu = arith.sitofp(T.f32, vu) - if is_int8: - vg = arith.sitofp(f32, vg) - vu = arith.sitofp(f32, vu) - vg = vg * sx * sw_gate - vu = vu * sx * sw_up + if not no_epilogue_dequant: + sw_gate = sw_gate_vals[ni] + sw_up = sw_up_vals[ni] + vg = vg * sx * sw_gate + vu = vu * sx * sw_up + + if enable_bias: + gate_bias_list, up_bias_list = epilogue_pf + vg = vg + gate_bias_list[ni] + vu = vu + up_bias_list[ni] + + if act == "swiglu": + y = swiglu(vg, vu) + else: + y = silu(vg) * vu - y = silu(vg) * vu if doweight_stage1: y = y * tw - y16 = arith.trunc_f(T.f16(), y) - + + if not is_cast_out: + y = arith.trunc_f(T.f16(), y) + vec1_out_elem = vec1_f16 + alignment = 2 + else: + vec1_out_elem = vec1_f32 + alignment = 1 + lds_idx = row_base_lds + col_local - v1 = vector.from_elements(vec1_f16, [y16]) - vector.store(v1, lds_out, [lds_idx], alignment=2) + v1 = vector.from_elements(vec1_out_elem, [y]) + vector.store(v1, lds_out, [lds_idx], alignment=alignment) def precompute_row(*, row_local, row): fused2 = buffer_ops.buffer_load(sorted_rsrc, row, vec_width=1, dtype=i32) @@ -959,18 +1297,33 @@ def precompute_row(*, row_local, row): return (t2 * topk_i32_v + s2) * inter_i32_local def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): - # Guard against sentinel token ids (t == tokens) produced by aiter moe_sorting padding. - # OOB buffer stores are not guaranteed to be safe on all paths, so predicate explicitly. - fused2 = buffer_ops.buffer_load(sorted_rsrc, row, vec_width=1, dtype=i32) - t2 = fused2 & mask24_i32 - t_valid = arith.cmpu(t2, tokens_i32_v, "ult") - _if_valid = scf.IfOp(t_valid) - with _if_valid.then(): + if not is_cast_out: idx0 = row_ctx col_i32 = arith.index_cast(i32, col_g0) idx_out = idx0 + col_i32 # Vectorized fp16 store (EVec=4). buffer_ops.buffer_store(frag, out_rsrc, idx_out) + else: + fused2 = buffer_ops.buffer_load(sorted_rsrc, row, vec_width=1, dtype=i32) + t2 = fused2 & mask24_i32 + t_valid = arith.cmpu(t2, tokens_i32, "ult") + _if_valid = scf.IfOp(t_valid) + with _if_valid.then(): + frag = vector.bitcast(vec4_f32, frag) + frag0 = vector.extract(frag, static_position=[0], dynamic_position=[]) + frag1 = vector.extract(frag, static_position=[1], dynamic_position=[]) + frag2 = vector.extract(frag, static_position=[2], dynamic_position=[]) + frag3 = vector.extract(frag, static_position=[3], dynamic_position=[]) + + out_fp8 = arith.i32(0) + out_fp8 = rocdl.cvt_pk_fp8_f32(src_a=arith._unwrap_value(frag0), src_b=arith._unwrap_value(frag1), old=arith._unwrap_value(out_fp8), word_sel=0, res=I.i32) + out_fp8 = rocdl.cvt_pk_fp8_f32(src_a=arith._unwrap_value(frag2), src_b=arith._unwrap_value(frag3), old=arith._unwrap_value(out_fp8), word_sel=1, res=I.i32) + + idx0 = row_ctx + col_i32 = arith.index_cast(i32, col_g0) + idx_out = idx0 + col_i32 + buffer_ops.buffer_store(out_fp8, out_rsrc, idx_out // 4) + mfma_epilog( use_cshuffle=True, arith=arith, @@ -979,17 +1332,18 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): scf=scf, range_constexpr=range_constexpr, tile_m=tile_m, - tile_n=tile_n, + tile_n=tile_n // 2 if is_gate_up_inter else tile_n, e_vec=4, m_repeat=m_repeat, - num_acc_n=num_acc_n, + num_acc_n=num_acc_n_packed if is_gate_up_inter else num_acc_n, tx=tx, lane_div_16=lane_div_16, lane_mod_16=lane_mod_16, bx_m=bx_m, - by_n=by_n, - n_tile_base=n_tile_base, + by_n=by_n // 2 if is_gate_up_inter else by_n, + n_tile_base=n_tile_base // 2 if is_gate_up_inter else n_tile_base, lds_out=lds_out, + frag_elem_type=I.f32 if is_cast_out else I.f16, write_row_to_lds=write_row_to_lds, precompute_row=precompute_row, store_pair=store_pair, @@ -1005,20 +1359,10 @@ def _stage1_store_row(*, mi: int, ii: int, row_in_tile, row): s2_raw = fused2 >> 24 t2 = t2_raw s2 = s2_raw - t_valid = arith.cmpu(t2, tokens_i32_v, "ult") - - # Do NOT rely on buffer OOB semantics for scale loads; explicitly mask. - sx0 = ( - arith.f32(1.0) - if is_f16 - else arith.select( - t_valid, - buffer_ops.buffer_load(sx_rsrc, t2, vec_width=1, dtype=f32), - arith.f32(0.0), - ) - ) + + sx0 = arith.f32(1.0) if no_epilogue_dequant else buffer_ops.buffer_load(sx_rsrc, t2, vec_width=1, dtype=f32) sx = sx0 - zero_out = arith.constant(0.0, type=out_mlir()) + zero_out = arith.constant(0.0, type=_out_elem_type()) # out linear index base = ((t*topk + s)*inter_dim) (invariant across ni) idx0 = (t2 * topk_i32_v + s2) * inter_i32_local @@ -1027,36 +1371,44 @@ def _stage1_store_row(*, mi: int, ii: int, row_in_tile, row): if doweight_stage1: tw = buffer_ops.buffer_load(sorted_w_rsrc, row, vec_width=1, dtype=f32) - _if_valid = scf.IfOp(t_valid) - with _if_valid.then(): - for ni in range_constexpr(num_acc_n): - col_i32 = col_i32_list[ni] + for ni in range_constexpr(num_acc_n_final): + col_i32 = col_i32_list[ni] + + acc_idx = mi * num_acc_n_final + ni + vg = vector.extract( + acc_gate[acc_idx], static_position=[ii], dynamic_position=[] + ) + vu = vector.extract( + acc_up[acc_idx], static_position=[ii], dynamic_position=[] + ) + + if is_int_mode: + vg = arith.sitofp(f32, vg) + vu = arith.sitofp(f32, vu) + + if not no_epilogue_dequant: sw_gate = sw_gate_vals[ni] sw_up = sw_up_vals[ni] - - acc_idx = mi * num_acc_n + ni - vg = vector.extract( - acc_gate[acc_idx], static_position=[ii], dynamic_position=[] - ) - vu = vector.extract( - acc_up[acc_idx], static_position=[ii], dynamic_position=[] - ) - - if is_int8: - vg = arith.sitofp(f32, vg) - vu = arith.sitofp(f32, vu) vg = vg * sx * sw_gate vu = vu * sx * sw_up - + + if enable_bias: + gate_bias_list, up_bias_list = epilogue_pf + vg = vg + gate_bias_list[ni] + vu = vu + up_bias_list[ni] + + if act == "swiglu": + y = swiglu(vg, vu) + else: y = silu(vg) * vu - if doweight_stage1: - y = y * tw - y = arith.trunc_f(out_mlir(), y) - idx_out0 = idx0 + col_i32 - buffer_ops.buffer_store(y, out_rsrc, idx_out0) + + if doweight_stage1: + y = y * tw + y = arith.trunc_f(_out_elem_type(), y) + idx_out0 = idx0 + col_i32 + buffer_ops.buffer_store(y, out_rsrc, idx_out0) mfma_epilog( - use_cshuffle=False, arith=arith, range_constexpr=range_constexpr, @@ -1069,15 +1421,16 @@ def _stage1_store_row(*, mi: int, ii: int, row_in_tile, row): @flir.jit def __call__( self: flir.T.i64, - arg_out: lambda: T.memref(DYN, out_mlir()), - arg_x: lambda: T.memref(DYN, I.f16 if is_f16 else (I.i8 if is_int8 else I.f8)), - arg_w: lambda: T.memref(DYN, I.f16 if is_f16 else (I.i8 if is_int8 else I.f8)), - arg_scale_x: lambda: T.memref(DYN, T.f32()), - arg_scale_w: lambda: T.memref(experts * (2 * inter_dim), T.f32()), + arg_out: lambda: T.memref(DYN, _out_elem_type()), + arg_x: lambda: T.memref(DYN, _x_elem_type()), + arg_w: lambda: T.memref(DYN, _w_elem_type()), + arg_scale_x: lambda: T.memref(DYN, _scale_elem_type()), + arg_scale_w: lambda: T.memref(DYN, _scale_elem_type()), arg_sorted_token_ids: lambda: T.memref(DYN, T.i32()), arg_expert_ids: lambda: T.memref(DYN, T.i32()), arg_sorted_weights: lambda: T.memref(DYN, T.f32()), arg_max_token_ids: lambda: T.memref(DYN, T.i32()), + arg_bias: lambda: T.memref(DYN, T.f32()), tokens_in: lambda: T.index(), inter_in: lambda: T.index(), k_in: lambda: T.index(), @@ -1085,11 +1438,10 @@ def __call__( stream_ptr: lambda: T.i64(), # PyTorch stream pointer ): bdx = 256 - gx = inter_in / arith.index(tile_n) + gx = (arith.index(2) * inter_in / arith.index(tile_n)) if is_gate_up_inter else (inter_in / arith.index(tile_n)) # Use host-provided upper bound for M blocks (same as aiter moe_sorting allocation). # This avoids device->host sync on num_valid_ids. gy = size_expert_ids_in - stream_token = stream_ptr_to_async_token(stream_ptr) flir.gpu_ext.LaunchFuncOp( [module_name, "moe_gemm1"], @@ -1105,6 +1457,7 @@ def __call__( arg_expert_ids, arg_sorted_weights, arg_max_token_ids, + arg_bias, tokens_in, inter_in, k_in, @@ -1117,6 +1470,10 @@ def __call__( exe = flydsl.compile(m) return exe +####==================== gemm1 pipeline end =====================### + + +####==================== gemm2 pipeline start =====================### @functools.lru_cache(maxsize=1024) def compile_moe_gemm2( @@ -1129,21 +1486,28 @@ def compile_moe_gemm2( tile_n: int, tile_k: int, doweight_stage2: bool, - in_dtype: str = "fp8", + x_dtype: str = "fp8", + w_dtype: str = "fp8", out_dtype: str = "f16", use_cshuffle_epilog: bool | None = None, # Optional experiment: write per-(token,slot) output (no atomics) into an output shaped # [tokens*topk, model_dim] (or [tokens, topk, model_dim] flattened), then reduce over topk outside. # This can reduce atomic contention for small tokens at the cost of extra bandwidth / reduction. accumulate: bool = True, + enable_bias: bool = False, ): """Compile stage2 kernel (`moe_gemm2`) and return the compiled executable. - in_dtype: - - "fp8": A2/W are fp8 - - "fp16": A2/W are fp16 - - "int8": A2/W are int8 - - "int4": W4A8 path: A2 is int8, W is packed int4 unpacked to int8 in-kernel + x_dtype: + - "fp8": X/W are fp8 + - "fp16": X/W are fp16 + - "int8": X/W are int8 + - "int4": W4A8 path: X is int8, W is packed int4 unpacked to int8 in-kernel + w_dtype: + - "fp8": W is fp8 + - "fp16": W is fp16 + - "int8": W is int8 + - "int4": W4A8 path: W is packed int4 unpacked to int8 in-kernel Stage2 output supports: - out_dtype="f16": fp16 half2 atomics (fast, can overflow to +/-inf for bf16 workloads) @@ -1156,31 +1520,76 @@ def compile_moe_gemm2( allocator = SmemAllocator(None, arch=gpu_arch) _state = {} - if in_dtype not in ("fp8", "fp16", "int8", "int4"): - raise ValueError(f"in_dtype must be one of ('fp8','fp16','int8','int4'), got {in_dtype!r}") - is_f16 = in_dtype == "fp16" - elem_bytes = 2 if is_f16 else 1 - out_s = str(out_dtype).strip().lower() + + pipeline_manager = PreshufflePipelineManager(x_dtype, w_dtype, out_dtype, use_cshuffle_epilog) + pipeline_manager.check_type_valid() + + epilog_pipeline = pipeline_manager.get_epilog_pipeline() + mfma_pipeline = pipeline_manager.get_mfma_pipeline() + mfma_fn = pipeline_manager.get_mfma_fn() + + x_elem_bytes = pipeline_manager.get_a_elem_bytes() + w_elem_bytes = pipeline_manager.get_b_elem_bytes() + out_elem_bytes = pipeline_manager.get_out_elem_bytes() + + x_elem_pack = pipeline_manager.a_elem_pack + w_elem_pack = pipeline_manager.b_elem_pack + + tile_k_bytes = int(tile_k) * int(x_elem_bytes) + if (tile_k_bytes % 64) != 0: + raise ValueError( + f"tile_k_bytes must be divisible by 64, got tile_k_bytes={tile_k_bytes} " + f"(tile_k={tile_k}, x_elem_bytes={x_elem_bytes})" + ) + + def _get_mfma_dict_value(key, pipeline): + value = mfma_pipeline_dicts[key][pipeline] + return value() if callable(value) else value + + def _x_elem_type(): + return _get_mfma_dict_value("a_elem_type", mfma_pipeline) + def _w_elem_type(): + return _get_mfma_dict_value("b_elem_type", mfma_pipeline) + def _scale_elem_type(): + return _get_mfma_dict_value("scale_elem_type", mfma_pipeline) + def _out_elem_type(): + return _get_mfma_dict_value("out_elem_type", epilog_pipeline) + def _x_vec16_type(): + return _get_mfma_dict_value("x_vec16_type", mfma_pipeline) + def _w_vec16_type(): + return _get_mfma_dict_value("w_vec16_type", mfma_pipeline) + def _mfma_input_pack_ty(): + return _get_mfma_dict_value("mfma_input_pack_ty", mfma_pipeline) + def _mfma_output_pack_ty(): + return _get_mfma_dict_value("mfma_output_pack_ty", mfma_pipeline) + + is_f16_or_bf16 = mfma_pipeline in [MfmaPipeline.F16F16_16x16_PIPELINE, + MfmaPipeline.BF16BF16_16x16_PIPELINE] + is_int4 = mfma_pipeline in [MfmaPipeline.I8I4_16x16_PIPELINE] + is_int8 = mfma_pipeline in [MfmaPipeline.I8I8_16x16_PIPELINE] + is_fp4 = mfma_pipeline in [MfmaPipeline.F4F4_MXFP4_PIPELINE, + MfmaPipeline.F8F4_MXFP4_PIPELINE] + + is_int_mode = is_int8 or is_int4 + + no_epilogue_dequant = is_fp4 or is_f16_or_bf16 + + # FP4 specific parameters for mfma_scale_f32_16x16x128_f8f6f4 + pack_M = 2 if is_fp4 else 1 + pack_N = 2 if is_fp4 else 1 + pack_K = 2 if is_fp4 else 1 + is_fp8_a = mfma_pipeline == MfmaPipeline.F8F4_MXFP4_PIPELINE + cbsz = 0 if is_fp8_a else 4 # fp8 a: cbsz=0, fp4 a: cbsz=4 + blgp = 4 + + out_s = pipeline_manager.out_dtype if out_s not in ("f16", "fp16", "half", "bf16", "bfloat16", "f32", "fp32", "float"): raise ValueError(f"out_dtype must be 'f16', 'bf16', or 'f32', got {out_dtype!r}") out_is_f32 = out_s in ("f32", "fp32", "float") out_is_bf16 = out_s in ("bf16", "bfloat16") + if (not bool(accumulate)) and out_is_f32: raise ValueError("compile_moe_gemm2(accumulate=False) only supports out_dtype in {'f16','bf16'}") - is_int4 = in_dtype == "int4" - # INT4 here means W4A8: A2 is int8, W is packed int4 and unpacked to int8 in-kernel. - is_int8 = (in_dtype == "int8") or is_int4 - - mfma_i32_k32 = None - if is_int8: - mfma_i32_k32 = getattr(rocdl, "mfma_i32_16x16x32i8", None) or getattr( - rocdl, "mfma_i32_16x16x32_i8", None - ) - if mfma_i32_k32 is None: - raise AttributeError( - "INT8 K32 MFMA op not found: expected `rocdl.mfma_i32_16x16x32i8` " - "(or `rocdl.mfma_i32_16x16x32_i8`)." - ) DYN = ir.ShapedType.get_dynamic_size() size_out = DYN @@ -1189,20 +1598,21 @@ def compile_moe_gemm2( size_expert_ids_shape = DYN size_scale_x = DYN # W is packed int4 for W4A8: 2 values per byte. - size_w = (experts * model_dim * inter_dim) // 2 if is_int4 else (experts * model_dim * inter_dim) + size_w = (experts * model_dim * inter_dim) // w_elem_pack total_threads = 256 - tile_k_bytes = int(tile_k) * int(elem_bytes) + tile_k_bytes = int(tile_k) * int(x_elem_bytes) // x_elem_pack if (tile_k_bytes % 64) != 0: raise ValueError( f"tile_k_bytes must be divisible by 64, got tile_k_bytes={tile_k_bytes} " - f"(tile_k={tile_k}, elem_bytes={elem_bytes})" + f"(tile_k={tile_k}, elem_bytes={x_elem_bytes})" ) - bytes_x_per_tile = int(tile_m) * int(tile_k) * int(elem_bytes) + + bytes_x_per_tile = int(tile_m) * int(tile_k) * int(x_elem_bytes) // x_elem_pack if bytes_x_per_tile % total_threads != 0: raise ValueError( - "tile_m*tile_k*elem_bytes must be divisible by " - f"{total_threads}: tile_m={tile_m}, tile_k={tile_k}, elem_bytes={elem_bytes}" + "tile_m*tile_k*x_elem_bytes must be divisible by " + f"{total_threads}: tile_m={tile_m}, tile_k={tile_k}, x_elem_bytes={x_elem_bytes}" ) bytes_per_thread_x = bytes_x_per_tile // total_threads @@ -1235,7 +1645,6 @@ def compile_moe_gemm2( ) # NOTE: Keep this as a callable so we don't require an MLIR Context at Python-time. - out_elem = (T.f32 if out_is_f32 else (T.bf16 if out_is_bf16 else T.f16)) epilog_tag = "cshuffle" # IMPORTANT: include tiling in the module name to avoid accidentally reusing a compiled # binary for a different (tile_m, tile_n, tile_k) configuration. @@ -1244,7 +1653,7 @@ def compile_moe_gemm2( # Dynamic-shape variant: safe to reuse across (tokens/sorted_size/size_expert_ids) at runtime. # Keep a distinct ABI tag so the compile cache never mixes with historical signatures. module_name = ( - f"mfma_moe2_{in_dtype}_{out_s}_{epilog_tag}" + f"mfma_moe2_{x_dtype}_{w_dtype}_{out_s}_{epilog_tag}" f"_t{tile_m}x{tile_n}x{tile_k}" f"_abi2" # mask sentinel token ids on loads/stores to avoid illegal address faults ).replace("-", "_") @@ -1261,34 +1670,34 @@ def init_gpu_module(self): # - epilogue CShuffle tile (tile_m * tile_n f16 -> 2 * tile_m * tile_n bytes) # # This reduces LDS usage from sum(...) to max(...). - lds_x_bytes = 2 * int(tile_m) * int(lds_stride) * int(elem_bytes) + lds_x_bytes = 2 * int(tile_m) * int(lds_stride) * int(x_elem_bytes) lds_out_bytes = 2 * tile_m * tile_n if _use_cshuffle_epilog else 0 # f16 bytes lds_total_bytes = max(lds_x_bytes, lds_out_bytes) - lds_total_elems = lds_total_bytes if elem_bytes == 1 else (lds_total_bytes // 2) - x_lds_elem = I.f16 if is_f16 else (I.i8 if is_int8 else I.f8) - _state["lds_x_decl"] = allocator.allocate_array(x_lds_elem, lds_total_elems) + lds_total_elems = lds_total_bytes // x_elem_bytes + _state["lds_x_decl"] = allocator.allocate_array(_x_elem_type(), lds_total_elems) allocator.finalize() @flir.kernel def moe_gemm2( self: flir.T.i64, - arg_out: lambda: T.memref(size_out, out_elem()), - arg_x: lambda: T.memref(size_x, I.f16 if is_f16 else (I.i8 if is_int8 else I.f8)), - arg_w: lambda: T.memref(size_w, I.f16 if is_f16 else (I.i8 if is_int8 else I.f8)), - arg_scale_x: lambda: T.memref(size_scale_x, T.f32()), - arg_scale_w: lambda: T.memref(experts * model_dim, T.f32()), + arg_out: lambda: T.memref(size_out, _out_elem_type()), + arg_x: lambda: T.memref(size_x, _x_elem_type()), + arg_w: lambda: T.memref(size_w, _w_elem_type()), + arg_scale_x: lambda: T.memref(size_scale_x, _scale_elem_type()), + arg_scale_w: lambda: T.memref(DYN, _scale_elem_type()), arg_sorted_token_ids: lambda: T.memref(size_sorted, T.i32()), arg_expert_ids: lambda: T.memref(size_expert_ids_shape, T.i32()), arg_sorted_weights: lambda: T.memref(size_sorted, T.f32()), arg_num_valid_ids: lambda: T.memref(DYN, T.i32()), + arg_bias: lambda: T.memref(DYN, T.f32()), tokens_in: lambda: T.index(), n_in: lambda: T.index(), k_in: lambda: T.index(), size_expert_ids_in: lambda: T.index(), ): - x_elem = I.f16 if is_f16 else (I.i8 if is_int8 else I.f8) + x_elem = _x_elem_type() # For int4, weights are stored as packed bytes (i8) and unpacked to i8 packs. - w_elem = I.f16 if is_f16 else (I.i8 if is_int8 else I.f8) + w_elem = _w_elem_type() f16 = I.f16 f32 = I.f32 i32 = I.i32 @@ -1298,18 +1707,18 @@ def moe_gemm2( vec1_f16 = I.vec(1, f16) vec2_f16 = I.vec(2, f16) vec4_f16 = I.vec(4, f16) - vec16_elems = 16 if elem_bytes == 1 else 8 - vec8_elems = 8 if elem_bytes == 1 else 4 - vec4_elems = 4 if elem_bytes == 1 else 2 - vec8_x = I.vec(vec8_elems, x_elem) - vec16_x = I.vec(vec16_elems, x_elem) + vec16_x_elems = 16 if x_elem_bytes == 1 else 8 + vec8_x_elems = 8 if x_elem_bytes == 1 else 4 + vec4_x_elems = 4 if x_elem_bytes == 1 else 2 + vec8_x = I.vec(vec8_x_elems, x_elem) + vec16_x = I.vec(vec16_x_elems, x_elem) vec1_i64 = I.vec(1, i64) vec2_i64 = I.vec(2, i64) + placeholder = arith.constant(0, type=I.i32) + acc_init = ( - arith.constant_vector(0, vec4_i32) - if is_int8 - else arith.constant_vector(0.0, vec4_f32) + arith.constant_vector(0, _mfma_output_pack_ty()) ) # A2 layout (flatten token-slot -> M). @@ -1321,10 +1730,12 @@ def moe_gemm2( c_n_total = arith.constant(experts * model_dim, index=True) kpack_bytes = 8 if is_int4 else 16 b_layout = make_preshuffle_b_layout( - flir, arith, c_n=c_n_total, c_k=k_in, kpack_bytes=kpack_bytes, elem_bytes=elem_bytes + flir, arith, c_n=c_n_total, c_k=k_in // pack_K, kpack_bytes=kpack_bytes, elem_bytes=w_elem_bytes ) layout_b = b_layout.layout_b - c_k0 = (k_in * arith.constant(int(elem_bytes), index=True)) / arith.index(64) + layout_b_scale = make_preshuffle_scale_layout( + flir, arith, c_mn=c_n_total, c_k=k_in, + ) shape_lds = flir.make_shape(tile_m, tile_k) stride_lds = flir.make_stride(lds_stride, 1) @@ -1339,12 +1750,12 @@ def moe_gemm2( # XOR16 swizzle parameter (in bytes; constant, power-of-two in our configs). k_blocks16 = arith.constant(tile_k_bytes // 16, index=True) - atom_x_s16 = flir.make_copy_atom(x_elem, vector_size=vec16_elems) + atom_x_s16 = flir.make_copy_atom(x_elem, vector_size=vec16_x_elems) atom_x_s8 = flir.make_copy_atom(x_elem, vector_size=8) atom_x_s4 = flir.make_copy_atom(x_elem, vector_size=4) - atom_x_g2r16 = flir.make_copy_atom(x_elem, vector_size=vec16_elems) - atom_x_g2r8 = flir.make_copy_atom(x_elem, vector_size=vec8_elems) - atom_x_g2r4 = flir.make_copy_atom(x_elem, vector_size=vec4_elems) + atom_x_g2r16 = flir.make_copy_atom(x_elem, vector_size=vec16_x_elems) + atom_x_g2r8 = flir.make_copy_atom(x_elem, vector_size=vec8_x_elems) + atom_x_g2r4 = flir.make_copy_atom(x_elem, vector_size=vec4_x_elems) layout_tx_wave_lane = flir.make_layout((4, 64), stride=(64, 1)) layout_lane16 = flir.make_layout((4, 16), stride=(16, 1)) layout_lin_rowcol = flir.make_layout((tile_m, tile_k), stride=(tile_k, 1)) @@ -1370,7 +1781,7 @@ def moe_gemm2( c_topk = arith.constant(topk, index=True) # X(A2): [tokens*topk, inter_dim] bytes = tokens*topk*k*elem_bytes - x_nbytes_idx = (tokens_in * c_topk) * k_in * arith.constant(int(elem_bytes), index=True) + x_nbytes_idx = (tokens_in * c_topk) * k_in * arith.constant(int(x_elem_bytes), index=True) x_nbytes_i32 = arith.index_cast(i32, x_nbytes_idx) x_rsrc = buffer_ops.create_buffer_resource( arg_x, max_size=False, num_records_bytes=x_nbytes_i32 @@ -1393,7 +1804,7 @@ def moe_gemm2( arg_out, max_size=False, num_records_bytes=out_nbytes_i32 ) # fp16 path ignores scales completely (implicit scale=1.0). - if is_f16: + if is_f16_or_bf16: sx_rsrc = None sw_rsrc = None else: @@ -1420,6 +1831,9 @@ def moe_gemm2( arg_sorted_weights, max_size=False, num_records_bytes=sorted_nbytes_i32 ) + # bias resource for optional bias addition + bias_rsrc = buffer_ops.create_buffer_resource(arg_bias, max_size=False) if enable_bias else None + # expert ids: [blocks] i32 -> bytes = size_expert_ids_in*4 eid_nbytes_idx = size_expert_ids_in * arith.constant(4, index=True) eid_nbytes_i32 = arith.index_cast(i32, eid_nbytes_idx) @@ -1449,7 +1863,7 @@ def _moe_gemm2_then_body(): # ---- X gmem->reg prefetch (match preshuffle GEMM mapping) ---- # Prefer 16B buffer-load (dwordx4). If the per-thread byte count isn't divisible by # 16, fall back to 8B (dwordx2) or 4B (dword) loads. For fp16 we require 16B. - if is_f16: + if is_f16_or_bf16: if bytes_per_thread_x % 16 != 0: raise ValueError( f"[fp16] bytes_per_thread_x ({bytes_per_thread_x}) must be divisible by 16" @@ -1470,9 +1884,9 @@ def _moe_gemm2_then_body(): chunk_i32 = x_load_bytes // 4 # dwords per chunk (1/2/4) vec4_i32 = I.vec(4, i32) - c_k_div4 = (k_in * arith.constant(int(elem_bytes), index=True)) / arith.index(4) + c_k_div4 = (k_in * arith.constant(int(x_elem_bytes), index=True)) / arith.index(4) layout_x_div4 = flir.make_layout((m_in, c_k_div4), stride=(c_k_div4, 1)) - tile_k_dwords = (int(tile_k) * int(elem_bytes)) // 4 + tile_k_dwords = (int(tile_k) * int(x_elem_bytes)) // 4 layout_x_tile_div4 = flir.make_layout((tile_m, tile_k_dwords), stride=(tile_k_dwords, 1)) c_chunk_i32 = arith.constant(chunk_i32, index=True) tx_i32_base = tx * c_chunk_i32 @@ -1499,7 +1913,7 @@ def x_tile_chunk_coord_i32(i: int): def load_x(idx_i32): if x_load_bytes == 16: - idx_elem = idx_i32 if elem_bytes == 1 else (idx_i32 * arith.index(2)) + idx_elem = idx_i32 if x_elem_bytes == 1 else (idx_i32 * arith.index(2)) return buffer_copy_gmem16_dwordx4( flir, arg=arg_x, @@ -1507,7 +1921,7 @@ def load_x(idx_i32): idx_i32=idx_elem, atom_g2r16=atom_x_g2r16, rsrc=x_rsrc, - vec_elems=vec16_elems, + vec_elems=vec16_x_elems, ) idx_bytes = idx_i32 * arith.index(4) atom = atom_x_g2r8 if x_load_bytes == 8 else atom_x_g2r4 @@ -1526,6 +1940,7 @@ def load_x(idx_i32): return_vector=True, src_buffer_resource=x_rsrc, src_buffer_offset_in_bytes=True, + nontemporal=True, ) # decode routed token once (per thread's M-slice) and build a base offset. @@ -1554,7 +1969,7 @@ def load_x(idx_i32): x_row_base_div4.append(row_ts_idx * c_k_div4) def load_x_tile(base_k): - base_k_div4 = (base_k * arith.constant(int(elem_bytes), index=True)) / arith.index(4) + base_k_div4 = (base_k * arith.constant(int(x_elem_bytes), index=True)) / arith.index(4) parts = [] for i in range_constexpr(num_x_loads): idx_i32 = x_row_base_div4[i] + base_k_div4 + x_col_local_i32[i] @@ -1579,8 +1994,8 @@ def load_x_tile(base_k): col_offset_base = flir.crd2idx(flir.make_coord(lane_div_16, 0), layout_lane16) col_offset_base_bytes = ( col_offset_base - if elem_bytes == 1 - else (col_offset_base * arith.constant(int(elem_bytes), index=True)) + if x_elem_bytes == 1 + else (col_offset_base * arith.constant(int(x_elem_bytes), index=True)) ) # Dynamic N tiling within block. @@ -1609,7 +2024,12 @@ def load_x_tile(base_k): n_intra_list.append(flir.get(coord_w, 1)) m_repeat = tile_m // 16 - k_unroll = tile_k_bytes // 64 # K64-byte micro-step (2x MFMA) + k_unroll = tile_k_bytes // 64 if not is_fp4 else tile_k_bytes // 128 # K64-byte micro-step (2x MFMA) + + # FP4 packed parameters for mfma_scale_f32_16x16x128_f8f6f4 + k_unroll_packed = k_unroll // pack_K + m_repeat_packed = m_repeat // pack_M + num_acc_n_packed = num_acc_n // pack_N # --- B Load Logic (K64) --- def load_b_pack(base_k, ki_step, ni): @@ -1628,7 +2048,7 @@ def load_b_pack(base_k, ki_step, ni): lane_div_16=lane_div_16, # 0..3 elem_type=w_elem, kpack_bytes=kpack_bytes, - elem_bytes=elem_bytes, + elem_bytes=w_elem_bytes, unpack_int4=is_int4, ) @@ -1651,7 +2071,66 @@ def load_b_tile(base_k): packs1.append(b1) b_tile.append((packs0, packs1)) return b_tile - + + # --- FP4 Scale Load Functions for gemm2 --- + def load_scale(arg_scale, rsrc, layout, ku, mni): + """Load a single scale value for FP4 mfma_scale instruction (gemm2).""" + k_lane = lane_div_16 + n_lane = lane_mod_16 + coord_pack = flir.make_coord(mni, ku, k_lane, n_lane) + idx_pack = flir.crd2idx(coord_pack, layout) + scale_view = flir.TensorView( + arg_scale, + (1,), + strides=(1,), + base_indices=(idx_pack,), + element_type=_scale_elem_type(), + ) + scale = flir.copy( + flir.make_copy_atom(_scale_elem_type(), vector_size=1), + scale_view, + None, + alignment=8, + return_vector=True, + src_buffer_resource=rsrc, + src_buffer_offset_in_bytes=False, + ) + return scale + + def load_b_scale_tile(base_k): + """Load B scale tile for FP4 pipeline (gemm2).""" + b_scale_tile = [] + for ku in range_constexpr(k_unroll_packed): + for ni in range_constexpr(num_acc_n_packed): + scale = load_scale( + arg_scale_w, + sw_rsrc, + layout_b_scale, + ku + base_k, + ni + (expert_off_idx + by_n + n_tile_base) // pack_N // 16, + ) + b_scale_tile.append(scale) + return b_scale_tile + + def load_a_scale_tile(base_k): + """Load A scale tile for FP4 pipeline (gemm2).""" + a_scale_tile = [] + for ku in range_constexpr(k_unroll_packed): + for mi in range_constexpr(m_repeat_packed): + scale = load_scale( + arg_scale_x, + sx_rsrc, + layout_a_scale, + ku + base_k, + mi + bx_m // pack_M // 16, + ) + a_scale_tile.append(scale) + return a_scale_tile + + def prefetch_ab_scale_tile(base_k): + """Prefetch A and B scale tiles for FP4 pipeline (gemm2).""" + return [None, load_b_scale_tile(base_k)] + # ---- Pipeline helpers: store X tile to LDS with ping-pong base ---- def store_x_tile_to_lds(vec_x_in_parts, lds_base): for i in range_constexpr(num_x_loads): @@ -1673,7 +2152,7 @@ def store_x_tile_to_lds(vec_x_in_parts, lds_base): k_blocks16=k_blocks16, lds_base=lds_base, vec_part_i32x4=vec_x_in_parts[i], - elem_bytes=elem_bytes, + elem_bytes=x_elem_bytes, ) elif x_load_bytes == 8: lds_store_8b_xor16( @@ -1715,8 +2194,8 @@ def lds_load_packs_k64(curr_row_a_lds, col_base_bytes, lds_base): col_base_swz_bytes = flir.swizzle_xor16(curr_row_a_lds, col_base_bytes, k_blocks16) col_base_swz = ( col_base_swz_bytes - if elem_bytes == 1 - else (col_base_swz_bytes / arith.constant(int(elem_bytes), index=True)) + if x_elem_bytes == 1 + else (col_base_swz_bytes / arith.constant(int(x_elem_bytes), index=True)) ) coord_a16 = flir.make_coord(curr_row_a_lds, col_base_swz) idx_a16 = flir.crd2idx(coord_a16, layout_lds) @@ -1727,27 +2206,20 @@ def lds_load_packs_k64(curr_row_a_lds, col_base_bytes, lds_base): a1 = vector.extract(a_i64x2, static_position=[1], dynamic_position=[]) return a0, a1 - def compute_tile(acc_in, b_tile_in, lds_base, *, prefetch_epilogue: bool = False, a0_prefetch=None): + def compute_tile(acc_in, b_tile_in, lds_base, *, a_scale=None, b_scale=None, prefetch_epilogue: bool = False, a0_prefetch=None): acc_list = list(acc_in) - mfma_res_ty = vec4_i32 if is_int8 else vec4_f32 - mfma_fn = ( - mfma_i32_k32 - if is_int8 - else (rocdl.mfma_f32_16x16x16f16 if is_f16 else rocdl.mfma_f32_16x16x32_fp8_fp8) - ) - + mfma_res_ty = _mfma_output_pack_ty() + epilogue_pf = None if prefetch_epilogue: expert_off_pf = expert_off_idx - sw_pf = [] - for ni in range_constexpr(num_acc_n): - col_g = col_g_list[ni] - row_w_idx = expert_off_pf + col_g - sw_pf.append( - arith.f32(1.0) - if is_f16 - else buffer_ops.buffer_load(sw_rsrc, row_w_idx, vec_width=1, dtype=f32) - ) + sw_pf = None + if not no_epilogue_dequant: + sw_pf = [] + for ni in range_constexpr(num_acc_n): + col_g = col_g_list[ni] + row_w_idx = expert_off_pf + col_g + sw_pf.append(buffer_ops.buffer_load(sw_rsrc, row_w_idx, vec_width=1, dtype=f32)) # Also prefetch per-row routed/topk weights (sorted_weights) when enabled. tw_pf = None if doweight_stage2: @@ -1765,14 +2237,87 @@ def compute_tile(acc_in, b_tile_in, lds_base, *, prefetch_epilogue: bool = False sorted_w_rsrc, sorted_row_pf, vec_width=1, dtype=f32 ) ) - epilogue_pf = (sw_pf, tw_pf) + # Prefetch bias values when enabled + bias_pf = None + if enable_bias: + bias_pf = [] + for ni in range_constexpr(num_acc_n): + global_n = by_n + n_tile_base + ni * 16 + lane_mod_16 + bias_offset = expert_off_pf + global_n + bias_pf.append( + buffer_ops.buffer_load(bias_rsrc, bias_offset, vec_width=1, dtype=f32) + ) + epilogue_pf = (sw_pf, tw_pf, bias_pf) + + if is_fp4: + c0_i64 = arith.constant(0, type=I.i64) + vec4_i64 = I.vec(4, I.i64) + vec8_i32 = I.vec(8, I.i32) + + def pack_i64x4_to_i32x8(x0, x1, x2, x3): + v4 = vector.from_elements(vec4_i64, [x0, x1, x2, x3]) + return vector.bitcast(vec8_i32, v4) + + # FP4 path using mfma_scale_f32_16x16x128_f8f6f4 + for ku128 in range_constexpr(k_unroll_packed): + for mi in range_constexpr(m_repeat_packed): + for ni in range_constexpr(num_acc_n_packed): + b_scale_i32 = b_scale[ku128 * num_acc_n_packed + ni] + b_scale_val = vector.extract(b_scale_i32, static_position=[0], dynamic_position=[]) + for ikxdl in range_constexpr(pack_K): + k_idx = ku128 * pack_K + ikxdl + b_packs0, b_packs1 = b_tile_in[k_idx] + col_base = col_offset_base + (k_idx * 128) // x_elem_pack + + for imxdl in range_constexpr(pack_M): + col_base0 = col_base + mi_idx = mi * pack_M + imxdl + mi_val = arith.constant(mi_idx * 16, index=True) + curr_row_a_lds = row_a_lds + mi_val + + if (a0_prefetch is not None) and (k_idx == 0) and (mi_idx == 0): + a0, a1 = a0_prefetch + else: + a0, a1 = lds_load_packs_k64(curr_row_a_lds, col_base0, lds_base) + + if is_fp8_a: + col_base1 = col_base + 64 + a2, a3 = lds_load_packs_k64(curr_row_a_lds, col_base1, lds_base) + a128 = pack_i64x4_to_i32x8(a0, a1, a2, a3) + else: + a128 = pack_i64x4_to_i32x8(a0, a1, c0_i64, c0_i64) + + for inxdl in range_constexpr(pack_N): + ni_idx = ni * pack_N + inxdl + + b0 = b_packs0[ni_idx] + b1 = b_packs1[ni_idx] + b128 = pack_i64x4_to_i32x8(b0, b1, c0_i64, c0_i64) + + acc_idx = mi_idx * num_acc_n + ni_idx + acc_list[acc_idx] = rocdl.mfma_scale_f32_16x16x128_f8f6f4( + mfma_res_ty, + [ + a128, + b128, + acc_list[acc_idx], + cbsz, + blgp, + # use per tensor quant a1 for now, + 0, + 0x3F800000, + ikxdl * pack_N + inxdl, + b_scale_val, + ], + ) + return acc_list, epilogue_pf def _i64_to_v4f16(x_i64): v1 = vector.from_elements(vec1_i64, [x_i64]) return vector.bitcast(vec4_f16, v1) def mfma_k64(acc0, a0, a1, b0, b1): - if is_f16: + if is_f16_or_bf16: a0v = _i64_to_v4f16(a0) a1v = _i64_to_v4f16(a1) b0v = _i64_to_v4f16(b0) @@ -1806,7 +2351,7 @@ def mfma_k64(acc0, a0, a1, b0, b1): b_packs1[ni], ) return acc_list, epilogue_pf - + # ---------------- 2-stage pipeline (ping-pong LDS + B tile prefetch) ---------------- lds_tile_elems = arith.constant(tile_m * lds_stride, index=True) lds_base_cur = arith.index(0) @@ -1854,59 +2399,113 @@ def hot_loop_scheduler(): # - MFMA group size per "slot": num_acc_n # - Total MFMA per tile: (2*K32 per K64) * k_unroll * m_repeat * num_acc_n # - We emit (mfma_group + dsrd + mfma_group) per scheduler iteration. - mfma_group = num_acc_n - mfma_total = (k_unroll * 2) * m_repeat * mfma_group - mfma_per_iter = 2 * mfma_group - sche_iters = 0 if mfma_per_iter == 0 else (mfma_total // mfma_per_iter) - + # mfma_group = num_acc_n + # mfma_total = (k_unroll * 2) * m_repeat * mfma_group + # mfma_per_iter = 2 * mfma_group + # sche_iters = 0 if mfma_per_iter == 0 else (mfma_total // mfma_per_iter) + + # rocdl.sched_dsrd(2) + # rocdl.sched_mfma(1) + # if tile_m == 16: + # rocdl.sched_vmem(1) + # rocdl.sched_mfma(1) + # if tile_m == 16: + # rocdl.sched_vmem(1) + # if num_acc_n < 4: + # rocdl.sched_dsrd(1) + # rocdl.sched_mfma(1) + # if tile_m == 16: + # rocdl.sched_vmem(1) + # rocdl.sched_dsrd(1) + # rocdl.sched_mfma(1) + # if tile_m == 16: + # rocdl.sched_vmem(1) + # rocdl.sched_mfma(1) + + # # DS-write hints near the end: match total A LDS-store micro-ops per thread. + # dswr_tail = num_x_loads + # if dswr_tail > sche_iters: + # dswr_tail = sche_iters + # dswr_start = sche_iters - dswr_tail + + # for sche_i in range_constexpr(sche_iters): + # rocdl.sched_vmem(1) + # rocdl.sched_mfma(mfma_group) + # rocdl.sched_dsrd(1) + # rocdl.sched_mfma(mfma_group) + # if sche_i >= dswr_start - 1: + # rocdl.sched_dswr(1) + + # rocdl.sched_barrier(0) rocdl.sched_dsrd(2) + rocdl.sched_vmem(1) rocdl.sched_mfma(1) - if tile_m == 16: - rocdl.sched_vmem(1) + + rocdl.sched_dsrd(1) + rocdl.sched_vmem(1) + rocdl.sched_mfma(1) + + rocdl.sched_dsrd(1) + rocdl.sched_vmem(1) + rocdl.sched_mfma(1) + + rocdl.sched_dsrd(1) + rocdl.sched_vmem(1) + rocdl.sched_mfma(1) + + rocdl.sched_dsrd(1) + rocdl.sched_vmem(1) + rocdl.sched_mfma(1) + + rocdl.sched_dsrd(1) + rocdl.sched_vmem(1) + rocdl.sched_mfma(1) + rocdl.sched_dswr(1) + rocdl.sched_vmem(1) + + rocdl.sched_mfma(1) + rocdl.sched_dswr(1) + rocdl.sched_mfma(1) - if tile_m == 16: - rocdl.sched_vmem(1) - if num_acc_n < 4: - rocdl.sched_dsrd(1) - rocdl.sched_mfma(1) - if tile_m == 16: - rocdl.sched_vmem(1) - rocdl.sched_dsrd(1) - rocdl.sched_mfma(1) - if tile_m == 16: - rocdl.sched_vmem(1) - rocdl.sched_mfma(1) - - # DS-write hints near the end: match total A LDS-store micro-ops per thread. - dswr_tail = num_x_loads - if dswr_tail > sche_iters: - dswr_tail = sche_iters - dswr_start = sche_iters - dswr_tail - - for sche_i in range_constexpr(sche_iters): - rocdl.sched_vmem(1) - rocdl.sched_mfma(mfma_group) - rocdl.sched_dsrd(1) - rocdl.sched_mfma(mfma_group) - if sche_i >= dswr_start - 1: - rocdl.sched_dswr(1) + # DS-write hints near the end: match total X LDS-store micro-ops per thread. + # dswr_start = 12 - 2 + # for sche_i in range_constexpr(12): + # if sche_i < 10: + # rocdl.sched_vmem(1) + # if sche_i <= 10: + # rocdl.sched_dsrd(1) + + # rocdl.sched_mfma(1) + # if sche_i >= dswr_start - 1: + # rocdl.sched_dswr(1) + rocdl.sched_barrier(0) + + # ================ Unified Pipeline (FP4 / Standard) ================ + # Select pipeline functions based on is_fp4 + def _prefetch_scale(k_val): + if is_fp4: + return prefetch_ab_scale_tile(k_val // pack_K // 128) + return placeholder, placeholder + # Prologue. k0 = arith.index(0) x_regs0 = load_x_tile(k0) b_cur = load_b_tile(k0) + a_scale_ping, a_scale_pong = placeholder, placeholder + _, b_scale_pong = _prefetch_scale(k0) store_x_tile_to_lds(x_regs0, lds_base_cur) gpu.barrier() - + acc = [acc_init] * (num_acc_n * m_repeat) lds_base_pong = lds_base_cur lds_base_ping = lds_base_nxt - + # Cross-tile A0 LDS prefetch (default-on): prefetch the first A-pack (K64) for the # tile we are about to compute from LDS, to overlap with upcoming VMEM. a0_prefetch_pong = lds_load_packs_k64(row_a_lds, col_offset_base_bytes, lds_base_pong) - + # Main loop: process K tiles in 2-tile ping-pong steps. # # IMPORTANT: for odd number of K tiles, leave **1** tail tile; for even, leave **2**. @@ -1918,63 +2517,65 @@ def hot_loop_scheduler(): k_main2_py = (num_k_tiles_py - tail_tiles) * int(tile_k) if k_main2_py < 0: k_main2_py = 0 - + c2_tile_k = arith.constant(tile_k * 2, index=True) c_k_main2 = arith.index(k_main2_py) for k_iv in range(arith.index(0), c_k_main2, arith.index(tile_k * 2)): next_k1 = k_iv + tile_k x_regs_ping = load_x_tile(next_k1) - b_ping = load_b_tile(next_k1) - - acc, _ = compute_tile(acc, b_cur, lds_base_pong, a0_prefetch=a0_prefetch_pong) - a0_prefetch_pong = None + b_ping = load_b_tile(next_k1 // pack_K) + _, b_scale_ping = _prefetch_scale(next_k1) + + acc, _ = compute_tile(acc, b_cur, lds_base_pong, a_scale=a_scale_pong, b_scale=b_scale_pong, a0_prefetch=a0_prefetch_pong) store_x_tile_to_lds(x_regs_ping, lds_base_ping) hot_loop_scheduler() gpu.barrier() - + # Cross-tile prefetch for the ping tile we are about to compute. a0_prefetch_ping = lds_load_packs_k64(row_a_lds, col_offset_base_bytes, lds_base_ping) - + next_k2 = k_iv + c2_tile_k x_regs_pong = load_x_tile(next_k2) - b_next = load_b_tile(next_k2) - - acc, _ = compute_tile(acc, b_ping, lds_base_ping, a0_prefetch=a0_prefetch_ping) - a0_prefetch_ping = None + b_next = load_b_tile(next_k2 // pack_K) + _, b_scale_pong = _prefetch_scale(next_k2) + + acc, _ = compute_tile(acc, b_ping, lds_base_ping, a_scale=a_scale_ping, b_scale=b_scale_ping, a0_prefetch=a0_prefetch_ping) store_x_tile_to_lds(x_regs_pong, lds_base_pong) hot_loop_scheduler() gpu.barrier() - + # Cross-tile prefetch for the next pong tile. a0_prefetch_pong = lds_load_packs_k64(row_a_lds, col_offset_base_bytes, lds_base_pong) - + b_cur = b_next - + if odd_k_tiles: # Tail: single remaining tile (already in `b_cur` / `lds_base_pong`). acc, epilogue_pf = compute_tile( acc, b_cur, lds_base_pong, + a_scale=a_scale_pong, + b_scale=b_scale_pong, prefetch_epilogue=True, a0_prefetch=a0_prefetch_pong, ) else: # Tail: 2 remaining tiles. - k_tail1 = k_in - tile_k + k_tail1 = (k_in + tile_k - 1) // tile_k * tile_k - tile_k if is_fp4 else k_in - tile_k x_regs_ping = load_x_tile(k_tail1) - b_ping = load_b_tile(k_tail1) - - acc, _ = compute_tile(acc, b_cur, lds_base_pong, a0_prefetch=a0_prefetch_pong) - a0_prefetch_pong = None + b_ping = load_b_tile(k_tail1 // pack_K) + _, b_scale_ping = _prefetch_scale(k_tail1) + + acc, _ = compute_tile(acc, b_cur, lds_base_pong, a_scale=a_scale_pong, b_scale=b_scale_pong, a0_prefetch=a0_prefetch_pong) store_x_tile_to_lds(x_regs_ping, lds_base_ping) hot_loop_scheduler() gpu.barrier() - + # Epilogue tile with sw prefetch. a0_prefetch_ping = lds_load_packs_k64(row_a_lds, col_offset_base_bytes, lds_base_ping) acc, epilogue_pf = compute_tile( - acc, b_ping, lds_base_ping, prefetch_epilogue=True, a0_prefetch=a0_prefetch_ping + acc, b_ping, lds_base_ping, a_scale=a_scale_ping, b_scale=b_scale_ping, prefetch_epilogue=True, a0_prefetch=a0_prefetch_ping ) # ---------------- Epilogue: LDS CShuffle + atomic half2 (x2) ---------------- @@ -2008,22 +2609,19 @@ def atomic_add_f16x2(val_f16x2, byte_off_i32): sw_pf = None tw_pf = None + bias_pf = None if epilogue_pf is not None: - sw_pf, tw_pf = epilogue_pf + sw_pf, tw_pf, bias_pf = epilogue_pf # Weight scales for the N tile (col_g depends on lane/wave/by but not on (t,s)). if sw_pf is not None: sw_vals = sw_pf - else: + elif not no_epilogue_dequant: sw_vals = [] for ni in range_constexpr(num_acc_n): col_g = col_g_list[ni] row_w_idx = expert_off + col_g - sw_vals.append( - arith.f32(1.0) - if is_f16 - else buffer_ops.buffer_load(sw_rsrc, row_w_idx, vec_width=1, dtype=f32) - ) + sw_vals.append(buffer_ops.buffer_load(sw_rsrc, row_w_idx, vec_width=1, dtype=f32)) if out_is_f32: # origin/dev_a16w4: f32 output uses scalar f32 atomics and skips CShuffle/LDS. @@ -2044,7 +2642,7 @@ def _stage2_row_atomic(*, mi: int, ii: int, row_in_tile, row): s2 = fused2 >> 24 ts2 = t2 * topk_i32_v + s2 - sx = arith.f32(1.0) if is_f16 else buffer_ops.buffer_load(sx_rsrc, ts2, vec_width=1, dtype=f32) + sx = arith.f32(1.0) if no_epilogue_dequant else buffer_ops.buffer_load(sx_rsrc, ts2, vec_width=1, dtype=f32) if doweight_stage2: tw_idx = (mi * 4) + ii @@ -2057,12 +2655,16 @@ def _stage2_row_atomic(*, mi: int, ii: int, row_in_tile, row): for ni in range_constexpr(num_acc_n): col_g = col_g_list[ni] - sw = sw_vals[ni] acc_idx = mi * num_acc_n + ni v = vector.extract(acc[acc_idx], static_position=[ii], dynamic_position=[]) - if is_int8: + if is_int_mode: v = arith.sitofp(f32, v) - v = v * sx * sw + if not no_epilogue_dequant: + sw = sw_vals[ni] + v = v * sx * sw + if enable_bias: + v = v + bias_pf[ni] + if doweight_stage2: v = v * tw col_i32 = arith.index_cast(i32, col_g) @@ -2114,7 +2716,7 @@ def write_row_to_lds( ts2 = t2_safe * topk_i32_v + s2_safe sx = ( arith.f32(1.0) - if is_f16 + if no_epilogue_dequant else arith.select( ts_ok, buffer_ops.buffer_load(sx_rsrc, ts2, vec_width=1, dtype=f32), @@ -2133,18 +2735,21 @@ def write_row_to_lds( for ni in range_constexpr(num_acc_n): col_local = col_base_local + (ni * 16) - sw = sw_vals[ni] acc_idx = mi * num_acc_n + ni v = vector.extract(acc[acc_idx], static_position=[ii], dynamic_position=[]) - if is_int8: + if is_int_mode: v = arith.sitofp(f32, v) - v = v * sx * sw + if not no_epilogue_dequant: + sw = sw_vals[ni] + v = v * sx * sw + if enable_bias: + v = v + bias_pf[ni] if doweight_stage2: v = v * tw - v_out = arith.trunc_f(out_elem(), v) + v_out = arith.trunc_f(_out_elem_type(), v) lds_idx = row_base_lds + col_local - vec1_out = I.vec(1, out_elem()) + vec1_out = I.vec(1, _out_elem_type()) v1 = vector.from_elements(vec1_out, [v_out]) vector.store(v1, lds_out, [lds_idx], alignment=2) @@ -2219,7 +2824,7 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): by_n=by_n, n_tile_base=n_tile_base, lds_out=lds_out, - frag_elem_type=(ir.BF16Type.get() if out_is_bf16 else ir.F16Type.get()), + frag_elem_type=_out_elem_type(), write_row_to_lds=write_row_to_lds, precompute_row=precompute_row, store_pair=store_pair, @@ -2231,15 +2836,16 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): @flir.jit def __call__( self: flir.T.i64, - arg_out: lambda: T.memref(size_out, out_elem()), - arg_x: lambda: T.memref(size_x, I.f16 if is_f16 else (I.i8 if is_int8 else I.f8)), - arg_w: lambda: T.memref(size_w, I.f16 if is_f16 else (I.i8 if is_int8 else I.f8)), - arg_scale_x: lambda: T.memref(size_scale_x, T.f32()), - arg_scale_w: lambda: T.memref(experts * model_dim, T.f32()), + arg_out: lambda: T.memref(size_out, _out_elem_type()), + arg_x: lambda: T.memref(size_x, _x_elem_type()), + arg_w: lambda: T.memref(size_w, _w_elem_type()), + arg_scale_x: lambda: T.memref(size_scale_x, _scale_elem_type()), + arg_scale_w: lambda: T.memref(DYN, _scale_elem_type()), arg_sorted_token_ids: lambda: T.memref(size_sorted, T.i32()), arg_expert_ids: lambda: T.memref(size_expert_ids_shape, T.i32()), arg_sorted_weights: lambda: T.memref(size_sorted, T.f32()), arg_num_valid_ids: lambda: T.memref(DYN, T.i32()), + arg_bias: lambda: T.memref(DYN, T.f32()), tokens_in: lambda: T.index(), n_in: lambda: T.index(), k_in: lambda: T.index(), @@ -2265,6 +2871,7 @@ def __call__( arg_expert_ids, arg_sorted_weights, arg_num_valid_ids, + arg_bias, tokens_in, n_in, k_in, @@ -2563,6 +3170,7 @@ def __call__( arg_expert_ids, arg_sorted_weights, arg_num_valid_ids, + arg_bias, tokens_in, n_in, k_in, @@ -2580,7 +3188,7 @@ def __call__( intermediate.view(-1), arg_x, arg_w, arg_scale_x, arg_scale_w, arg_sorted_token_ids, arg_expert_ids, arg_sorted_weights, - arg_num_valid_ids, tokens_in, n_in, k_in, size_expert_ids_in, + arg_num_valid_ids, arg_bias, tokens_in, n_in, k_in, size_expert_ids_in, stream_ptr, ) # Phase 2: Reduce over topk -> [tokens, model_dim] @@ -2604,9 +3212,11 @@ def compile_moe_gemm2_ex( tile_n: int, tile_k: int, doweight_stage2: bool, - in_dtype: str = "fp8", + x_dtype: str = "fp8", + w_dtype: str = "fp8", out_dtype: str = "f16", use_cshuffle_epilog: bool | None = None, + enable_bias: bool = False, # Extended parameters for mode control mode: str = MoeGemm2Mode.AUTO, tokens_hint: int | None = None, @@ -2659,9 +3269,11 @@ def compile_moe_gemm2_ex( tile_n=tile_n, tile_k=tile_k, doweight_stage2=doweight_stage2, - in_dtype=in_dtype, + x_dtype=x_dtype, + w_dtype=w_dtype, out_dtype=out_dtype, use_cshuffle_epilog=use_cshuffle_epilog, + enable_bias=enable_bias, accumulate=False, ) # Compile reduction kernel @@ -2695,8 +3307,12 @@ def compile_moe_gemm2_ex( tile_n=tile_n, tile_k=tile_k, doweight_stage2=doweight_stage2, - in_dtype=in_dtype, + x_dtype=x_dtype, + w_dtype=w_dtype, out_dtype=out_dtype, use_cshuffle_epilog=use_cshuffle_epilog, + enable_bias=enable_bias, accumulate=True, ) + +####==================== gemm2 pipeline end =====================### diff --git a/kernels/preshuffle_gemm.py b/kernels/preshuffle_gemm.py index 7cefacd0..17ccec87 100644 --- a/kernels/preshuffle_gemm.py +++ b/kernels/preshuffle_gemm.py @@ -11,8 +11,6 @@ - `ck_v1_single_lds`: CK-like Intrawave + bpreshuffle v1 spirit (single LDS buffer for A) """ -import os - import flydsl from flydsl.dialects.ext import flir from flydsl.dialects.ext.python_control_flow import range_constexpr @@ -28,16 +26,23 @@ from kernels.mfma_preshuffle_pipeline import ( buffer_copy_gmem16_dwordx4, - lds_load_pack_k32, lds_store_16b_xor16, make_preshuffle_b_layout, + make_preshuffle_scale_layout, load_b_pack_k32, tile_chunk_coord_i32, + MfmaPipeline, + EpilogPipeline, + mfma_pipeline_dicts, + PreshufflePipelineManager, + block_mfma_block_scale_f8f6f4, + block_mfma_PTPC_f8f6f4, + block_mfma_16x16, ) from kernels.mfma_epilogues import mfma_epilog -def compile_preshuffle_gemm_a8( +def compile_preshuffle_gemm( *, M: int, N: int, @@ -45,7 +50,9 @@ def compile_preshuffle_gemm_a8( tile_m: int, tile_n: int, tile_k: int, - in_dtype: str = "fp8", + a_dtype: str = "fp8", + b_dtype: str = "fp8", + out_dtype: str = "fp16", lds_stage: int = 2, # Epilogue options use_cshuffle_epilog: bool = False, @@ -59,10 +66,18 @@ def compile_preshuffle_gemm_a8( Args: M, N, K: GEMM sizes (A[M,K], B[N,K], C[M,N]). tile_m, tile_n, tile_k: block tile sizes. - in_dtype: + a_dtype: + - "fp8": A/B are fp8 (1B/elem) + - "int8": A/B are int8 (1B/elem) + - "a8w4": W4A8 path: A is int8 (1B/elem). + b_dtype: - "fp8": A/B are fp8 (1B/elem) - "int8": A/B are int8 (1B/elem) - - "int4": W4A8 path: A is int8, B is packed int4 (2 values per byte) and unpacked to int8 in-kernel. + - "int4": W4A8 path: B is packed int4 (2 values per byte) and unpacked to int8 in-kernel. + out_dtype: + - "fp16": Output is fp16 (2B/elem) + - "bf16": Output is bf16 (2B/elem) + - "f32": Output is f32 (4B/elem) lds_stage: - 2: ping-pong LDS for A (2 LDS buffers), tuned schedule (original). - 1: single LDS buffer for A . @@ -74,42 +89,37 @@ def compile_preshuffle_gemm_a8( Common values: 1 (max resources), 2 (balanced), 4 (max occupancy). use_async_copy: Use async DMA for A tile global-to-LDS transfer. """ - if in_dtype not in ("fp8", "int8", "int4", "fp16", "bf16"): - raise ValueError( - "in_dtype must be one of ('fp8','int8','int4','fp16','bf16'), " - f"got {in_dtype!r}" - ) - is_int4 = in_dtype == "int4" - is_int8 = (in_dtype == "int8") or is_int4 - is_f16 = in_dtype == "fp16" - is_bf16 = in_dtype == "bf16" - is_f16_or_bf16 = is_f16 or is_bf16 - elem_bytes = 1 if (in_dtype in ("fp8", "int8", "int4")) else 2 + + total_threads = 256 + + pipeline_manager = PreshufflePipelineManager(a_dtype, b_dtype, out_dtype) + pipeline_manager.check_type_valid() + + epilog_pipeline = pipeline_manager.get_epilog_pipeline() + mfma_pipeline = pipeline_manager.get_mfma_pipeline() + mfma_fn = pipeline_manager.get_mfma_fn() + + a_elem_bytes = pipeline_manager.get_a_elem_bytes() + b_elem_bytes = pipeline_manager.get_b_elem_bytes() + + a_elem_pack = pipeline_manager.a_elem_pack + b_elem_pack = pipeline_manager.b_elem_pack # Pipeline is byte-addressed along K (16B loads, XOR16 swizzle in bytes). - # For fp16/bf16 (2B/elem), user passes tile_k halved so tile_k_bytes stays constant. - tile_k_bytes = int(tile_k) * int(elem_bytes) + # For fp16/bf16 (2B/elem), user passes tile_k halved so b_tile_k_bytes stays constant. + b_tile_k_bytes = int(tile_k) * int(b_elem_bytes) # K64-byte micro-step wrapper uses 2x "half-pack" MFMA. # - fp8/i8: mfma K32, wrapper covers 64B (=64 elems) # - fp16/bf16: mfma K16, wrapper covers 64B (=32 elems) -> effective K halves - if (tile_k_bytes % 64) != 0: + if (b_tile_k_bytes % 64) != 0: raise ValueError( - f"tile_k_bytes must be divisible by 64, got tile_k_bytes={tile_k_bytes} " - f"(tile_k={tile_k}, elem_bytes={elem_bytes})" + f"b_tile_k_bytes must be divisible by 64, got b_tile_k_bytes={b_tile_k_bytes} " + f"(tile_k={tile_k}, b_elem_bytes={b_elem_bytes})" ) - # INT8 must use a K32 MFMA so the micro-step matches the FP8 path (strict alignment). - mfma_i32_k32 = None - if is_int8: - mfma_i32_k32 = getattr(rocdl, "mfma_i32_16x16x32i8", None) or getattr( - rocdl, "mfma_i32_16x16x32_i8", None - ) - if mfma_i32_k32 is None: - raise AttributeError( - "INT8 K32 MFMA op not found: expected `rocdl.mfma_i32_16x16x32i8` " - "(or `rocdl.mfma_i32_16x16x32_i8`)." - ) + # For FP4: A is packed (2 elems/byte), so actual bytes = tile_k * a_elem_bytes / a_elem_pack + a_tile_k_bytes = int(tile_k) * int(a_elem_bytes) // a_elem_pack gpu_arch = get_hip_arch() if use_async_copy and gpu_arch != "gfx950": @@ -119,62 +129,63 @@ def compile_preshuffle_gemm_a8( allocator_ping = SmemAllocator(None, arch=gpu_arch, global_sym_name = "smem_ping") _state = {} - # Default-on: cross-tile (tile_k) A0 LDS prefetch in the ping-pong pipeline (lds_stage=2). - # - # This issues the *first* A-pack LDS read for the next tile between barriers, to overlap - # with the VMEM prefetch of the following tile. - DYN = ir.ShapedType.get_dynamic_size() - # Vector width calc (assume full tiles / no tail guards). - total_threads = 256 - bytes_a_per_tile = int(tile_m) * int(tile_k) * int(elem_bytes) - if bytes_a_per_tile % total_threads != 0: - raise ValueError( - "tile_m*tile_k*elem_bytes must be divisible by " - f"{total_threads}: tile_m={tile_m}, tile_k={tile_k}, elem_bytes={elem_bytes}" - ) - bytes_per_thread_a = bytes_a_per_tile // total_threads - - # Assume A loads are always 16B-aligned and use fixed dwordx4 (16B) buffer loads. - a_load_bytes = 16 - if bytes_per_thread_a % a_load_bytes != 0: - raise ValueError( - f"bytes_per_thread_a ({bytes_per_thread_a}) must be divisible by {a_load_bytes}" - ) + # a is copied by the whole block, so we need to calculate the bytes per thread. + a_bytes_per_thread = pipeline_manager.get_a_bytes_per_thread(tile_m, tile_k) # CK-style LDS128: stride is in BYTES along K (for XOR16 swizzle). - lds_stride_bytes = tile_k_bytes - - def _elem_type(): - if is_f16: - return T.f16 - if is_bf16: - return T.bf16 - return T.i8 if is_int8 else T.f8 - - def _vec16_type(): - if is_f16: - return T.f16x8 # 16B - if is_bf16: - return T.bf16x8 # 16B - return T.i8x16 if is_int8 else T.f8x16 - - def _mfma_pack_ty(): - # ROCDL MFMA intrinsics expect specific operand vector types: - # - fp8/int8 paths use i64 packs (8 bytes) - # - fp16 uses v4f16 (8 bytes) - # - bf16 uses v4i16 (8 bytes) for *_bf16_1k variants - if is_f16: - return T.f16x4 - if is_bf16: - return T.i16x4 - return T.i64 + # a_tile_k_bytes already accounts for packing, so no further division needed + lds_stride_bytes = a_tile_k_bytes + + def _get_mfma_dict_value(key, pipeline): + value = mfma_pipeline_dicts[key][pipeline] + return value() if callable(value) else value + + def _a_elem_type(): + return _get_mfma_dict_value("a_elem_type", mfma_pipeline) + def _b_elem_type(): + return _get_mfma_dict_value("b_elem_type", mfma_pipeline) + def _scale_elem_type(): + return _get_mfma_dict_value("scale_elem_type", mfma_pipeline) + def _out_elem_type(): + return _get_mfma_dict_value("out_elem_type", epilog_pipeline) + def _a_vec16_type(): + return _get_mfma_dict_value("a_vec16_type", mfma_pipeline) + def _b_vec16_type(): + return _get_mfma_dict_value("b_vec16_type", mfma_pipeline) + def _mfma_input_pack_ty(): + return _get_mfma_dict_value("mfma_input_pack_ty", mfma_pipeline) + def _mfma_output_pack_ty(): + return _get_mfma_dict_value("mfma_output_pack_ty", mfma_pipeline) + + is_f16_or_bf16 = mfma_pipeline in [MfmaPipeline.F16F16_16x16_PIPELINE, + MfmaPipeline.BF16BF16_16x16_PIPELINE] + is_int4 = mfma_pipeline in [MfmaPipeline.I8I4_16x16_PIPELINE] + is_int8 = mfma_pipeline in [MfmaPipeline.I8I8_16x16_PIPELINE] + is_fp4 = mfma_pipeline in [MfmaPipeline.F4F4_MXFP4_PIPELINE, + MfmaPipeline.F8F4_MXFP4_PIPELINE] + + no_epilogue_dequant = is_fp4 or is_f16_or_bf16 + + # 350 16x16x128 adtype(cbsz) & bdtype(blgp) + cbsz = 4 if mfma_pipeline == MfmaPipeline.F4F4_MXFP4_PIPELINE else 0 + blgp = 4 if mfma_pipeline in [MfmaPipeline.F4F4_MXFP4_PIPELINE, MfmaPipeline.F8F4_MXFP4_PIPELINE] else 0 + + pack_M = 2 + pack_N = 2 + pack_K = 2 + + use_mfma_scale_128 = ( + str(gpu_arch).startswith("gfx95") + and (not is_int8) + and (not is_int4) + and (not is_f16_or_bf16) + ) # GEMM epilogue toggle: optional LDS CShuffle + vectorized stores. # Default: off (current measured cases show no benefit). - epilog_tag = "cshuffle" if use_cshuffle_epilog else "direct" - module_name = f"mfma_preshuffle_{lds_stage}stages_{in_dtype}_{epilog_tag}".replace("-", "_") + module_name = f"mfma_preshuffle_{lds_stage}stages_{mfma_pipeline}_{epilog_pipeline}".replace("-", "_") class _GEMM(flir.MlirModule): GPU_MODULE_NAME = module_name @@ -195,6 +206,14 @@ def init_gpu_module(self): lds_tile_bytes = int(tile_m) * int(lds_stride_bytes) lds_out_bytes = 2 * int(tile_m) * int(tile_n) if use_cshuffle_epilog else 0 +<<<<<<< HEAD + lds_total_bytes = max(lds_a_bytes, lds_out_bytes) + # Keep LDS allocation sized in bytes: element_size * num_elems. + # Allocate element type == _elem_type() and scale element count accordingly. + a_lds_total_elems = lds_total_bytes // a_elem_bytes + _state["lds_a_decl"] = allocator.allocate_array(_a_elem_type(), a_lds_total_elems) + allocator.finalize() +======= if int(lds_stage) == 2: # Separate ping/pong buffers for no-alias guarantee @@ -213,15 +232,16 @@ def init_gpu_module(self): allocator_pong.finalize() allocator_ping.finalize() +>>>>>>> main @flir.kernel def kernel_gemm( self: flir.T.i64, - arg_c: lambda: memref(DYN, T.f16), - arg_a: lambda: memref(DYN, _elem_type()), - arg_b: lambda: memref(DYN, _elem_type()), - arg_scale_a: lambda: memref(DYN, T.f32), - arg_scale_b: lambda: memref(DYN, T.f32), + arg_c: lambda: memref(DYN, _out_elem_type()), + arg_a: lambda: memref(DYN, _a_elem_type()), + arg_b: lambda: memref(DYN, _b_elem_type()), + arg_scale_a: lambda: memref(DYN, _scale_elem_type()), + arg_scale_b: lambda: memref(DYN, _scale_elem_type()), c_m: lambda: T.index, c_n: lambda: T.index, c_k: lambda: T.index, @@ -229,17 +249,16 @@ def kernel_gemm( # ---- Types ---- # NOTE: Some environments have multiple `flydsl` builds on PYTHONPATH. # Use explicit MLIR Values (not Python ints / wrapper objects) for ROCDL ops. - acc_init = arith.unwrap( - arith.constant_vector(0, T.i32x4) - if is_int8 - else arith.constant_vector(0.0, T.f32x4) - ) + acc_init = arith.unwrap(arith.constant_vector(0, _mfma_output_pack_ty())) # Layouts layout_c = flir.make_layout((c_m, c_n), stride=(c_n, 1)) # A uses dword indexing (buffer-load dwordx4). Convert element index -> dword index: # dword_index = (elem_index * elem_bytes) / 4 +<<<<<<< HEAD + c_k_div4bytes = c_k * a_elem_bytes / 4 / a_elem_pack +======= # Also create byte-indexed layout for DMA path. if (int(elem_bytes) == 2): c_k_div4bytes = (c_k * 2) / 4 @@ -248,27 +267,46 @@ def kernel_gemm( c_k_div4bytes = c_k / 4 c_k_bytes = c_k layout_a = flir.make_layout((c_m, c_k_bytes), stride=(c_k_bytes, 1)) +>>>>>>> main layout_a_div4 = flir.make_layout((c_m, c_k_div4bytes), stride=(c_k_div4bytes, 1)) # B preshuffle layout (shared with MoE kernels). - kpack_bytes = 8 if is_int4 else 16 + # For FP4: B is packed (2 elems/byte), so adjust c_k accordingly + b_kpack_bytes = 16 + c_k_b = c_k // b_elem_pack layout_b = make_preshuffle_b_layout( - flir, arith, c_n=c_n, c_k=c_k, kpack_bytes=kpack_bytes, elem_bytes=elem_bytes + flir, arith, c_n=c_n, c_k=c_k_b, kpack_bytes=b_kpack_bytes, elem_bytes=b_elem_bytes ).layout_b + # Scale layouts for FP4/MXFP4 (block-scale MFMA). + layout_a_scale = make_preshuffle_scale_layout(flir, arith, c_mn=c_m, c_k=c_k) if is_fp4 else None + layout_b_scale = make_preshuffle_scale_layout(flir, arith, c_mn=c_n, c_k=c_k) if is_fp4 else None + # LDS layout is element-indexed, but XOR16 swizzle is byte-based. # Represent LDS as (tile_m, tile_k) in elements and scale swizzle math by elem_bytes. - shape_lds = flir.make_shape(tile_m, tile_k) - stride_lds = flir.make_stride(tile_k, 1) + # For FP4: A is packed (2 elems/byte), so LDS K dimension is tile_k / a_elem_pack + lds_tile_k = tile_k // a_elem_pack if is_fp4 else tile_k + shape_lds = flir.make_shape(tile_m, lds_tile_k) + stride_lds = flir.make_stride(lds_tile_k, 1) layout_lds = flir.make_layout(shape_lds, stride_lds) # CK-style XOR16 swizzle parameter (const). - k_blocks16 = arith.index(tile_k_bytes // 16) + a_k_blocks16 = arith.index(a_tile_k_bytes // 16) tx = gpu.thread_id("x") bx = gpu.block_id("x") by = gpu.block_id("y") +<<<<<<< HEAD + base_ptr = allocator.get_base() + lds_a_ptr = _state["lds_a_decl"](base_ptr) + lds_a = lds_a_ptr.get() + lds_out = ( + SmemPtr(base_ptr, lds_a_ptr.byte_offset, _out_elem_type(), shape=(tile_m * tile_n,)).get() + if use_cshuffle_epilog + else None + ) +======= base_ptr, base_ptr1 = allocator_pong.get_base(), allocator_ping.get_base() # Get LDS pointers based on pipeline stage @@ -303,6 +341,7 @@ def kernel_gemm( if use_cshuffle_epilog else None ) +>>>>>>> main # Note: We assume N is aligned (no N-tail support in this kernel). a_rsrc = buffer_ops.create_buffer_resource(arg_a, max_size=False) @@ -335,19 +374,23 @@ def kernel_gemm( # - fp16/bf16: 8 elems (2B) # # We express `col_offset_base` in *elements*. - kpack_elems = 16 if elem_bytes == 1 else 8 - col_offset_base = lane_div_16 * arith.constant(int(kpack_elems), index=True) + kpack_elems = 16 + a_kpack_elems = kpack_elems / a_elem_bytes + col_offset_base = lane_div_16 * arith.constant(int(a_kpack_elems), index=True) # `col_offset_base` is in element units (multiples of 16). We do LDS swizzle/math # in bytes, so scale by element size for fp16/bf16. col_offset_base_bytes = ( - col_offset_base - if elem_bytes == 1 - else (col_offset_base * arith.constant(int(elem_bytes), index=True)) + col_offset_base * arith.constant(int(a_elem_bytes), index=True) ) m_repeat = tile_m // 16 - # K stepping is byte-addressed: one "micro-tile step" is 64 bytes. - k_unroll = tile_k_bytes // 64 + # K stepping is byte-addressed: + # - For FP4/MXFP4 (mfma 16x16x128): one micro-step is 128 bytes + # - For FP8/INT8 (mfma 16x16x32): one micro-step is 64 bytes (K32 x 2) + if is_fp4: + k_unroll = b_tile_k_bytes // 64 // b_elem_pack + else: + k_unroll = b_tile_k_bytes // 64 # --- Dynamic tiling along N (4 waves) --- num_waves = 4 @@ -370,11 +413,77 @@ def kernel_gemm( n_blk_list.append(flir.get(coord_n, 0)) n_intra_list.append(flir.get(coord_n, 1)) + # FP4/MXFP4 pack parameters for block-scale MFMA. + k_unroll_packed = k_unroll // pack_K if is_fp4 else 0 + m_repeat_packed = m_repeat // pack_M if is_fp4 else 0 + num_acc_n_packed = num_acc_n // pack_N if is_fp4 else 0 + + # --- Scale load logic for FP4/MXFP4 --- + def load_scale(arg_scale, rsrc, layout, ku, mni): + """Load a single scale value for FP4/MXFP4 block-scale MFMA.""" + k_lane = lane_div_16 + n_lane = lane_mod_16 + coord_pack = flir.make_coord(mni, ku, k_lane, n_lane) + idx_pack = flir.crd2idx(coord_pack, layout) + scale_view = flir.TensorView( + arg_scale, + (1,), + strides=(1,), + base_indices=(idx_pack,), + element_type=_scale_elem_type(), + ) + scale = flir.copy( + flir.make_copy_atom(_scale_elem_type(), vector_size=1), + scale_view, + None, + alignment=8, + return_vector=True, + src_buffer_resource=rsrc, + src_buffer_offset_in_bytes=False, + ) + return scale + + def load_b_scale_tile(base_k): + """Load B scale tile for FP4/MXFP4.""" + b_scale_tile = [] + for ku in range_constexpr(k_unroll_packed): + for ni in range_constexpr(num_acc_n_packed): + scale = load_scale( + arg_scale_b, + scale_b_rsrc, + layout_b_scale, + ku + base_k, + ni + (by_n + n_tile_base) // pack_N // 16, + ) + b_scale_tile.append(scale) + return b_scale_tile + + def load_a_scale_tile(base_k): + """Load A scale tile for FP4/MXFP4.""" + a_scale_tile = [] + for ku in range_constexpr(k_unroll_packed): + for mi in range_constexpr(m_repeat_packed): + scale = load_scale( + arg_scale_a, + scale_a_rsrc, + layout_a_scale, + ku + base_k, + mi + bx_m // pack_M // 16, + ) + a_scale_tile.append(scale) + return a_scale_tile + + def prefetch_ab_scale_tile(base_k): + """Prefetch A and B scale tiles for FP4/MXFP4.""" + if not is_fp4: + return None, None + # return load_a_scale_tile(base_k), load_b_scale_tile(base_k) + return load_a_scale_tile(0), load_b_scale_tile(0) + # --- B load logic --- # Shared loader supports: # - FP8/INT8: explicit 16B load (one full KPack) + extract 8B for this micro-step # - INT4 (W4A8): 4B load + 7-op unpack to 8B (no v_perm) - def load_b_pack(base_k, ki_step, ni): return load_b_pack_k32( buffer_ops, @@ -389,15 +498,15 @@ def load_b_pack(base_k, ki_step, ni): n_blk=n_blk_list[ni], n_intra=n_intra_list[ni], lane_div_16=lane_div_16, - elem_type=_elem_type(), - kpack_bytes=kpack_bytes, - elem_bytes=elem_bytes, + elem_type=_b_elem_type(), + kpack_bytes=b_kpack_bytes, + elem_bytes=b_elem_bytes, unpack_int4=is_int4, ) # For FP8/INT8 we can load one 16B pack and extract both 8B halves (K64 bytes). # For INT4 (packed), reuse the existing K32 loader twice (2x4B loads + unpack). - atom_b_g2r16 = flir.make_copy_atom(_elem_type(), vector_size=16) + atom_b_g2r16 = flir.make_copy_atom(_b_elem_type(), vector_size=16) c64_b = 64 c0_idx = 0 @@ -408,29 +517,31 @@ def load_b_packs_k64(base_k, ku: int, ni: int): return load_b_pack(base_k, ki0, ni), load_b_pack(base_k, ki1, ni) # FP8/INT8/FP16/BF16: load 16 bytes (one full KPack). - base_k_bytes = base_k * arith.constant(int(elem_bytes), index=True) + base_k_bytes = base_k * arith.constant(int(b_elem_bytes), index=True) k0_base = base_k_bytes / c64_b k0 = k0_base + ku k1 = lane_div_16 coord_pack = flir.make_coord(n_blk_list[ni], k0, k1, n_intra_list[ni], c0_idx) idx_pack = flir.crd2idx(coord_pack, layout_b) - vec_elems = 16 if elem_bytes == 1 else 8 + vec_elems = kpack_elems / b_elem_bytes + b_view = flir.TensorView( arg_b, (vec_elems,), strides=(1,), base_indices=(idx_pack,), - element_type=_elem_type(), + element_type=_b_elem_type(), ) b16 = flir.copy( - flir.make_copy_atom(_elem_type(), vector_size=vec_elems), + flir.make_copy_atom(_b_elem_type(), vector_size=vec_elems), b_view, None, alignment=8, return_vector=True, - src_buffer_resource=(b_rsrc if elem_bytes == 1 else None), - src_buffer_offset_in_bytes=(elem_bytes == 1), + src_buffer_resource=(b_rsrc if b_elem_bytes == 1 else None), + src_buffer_offset_in_bytes=(b_elem_bytes == 1), ) + # Split 16B pack into two 8B halves. b_i64x2 = vector.bitcast(T.i64x2, b16) b0_i64 = vector.extract(b_i64x2, static_position=[0], dynamic_position=[]) @@ -445,11 +556,7 @@ def load_b_packs_k64(base_k, ku: int, ni: int): vec1_i64_ty = ir.VectorType.get([1], ir.IntegerType.get_signless(64)) b0_v1 = vector.from_elements(vec1_i64_ty, [b0_i64]) b1_v1 = vector.from_elements(vec1_i64_ty, [b1_i64]) - if is_f16: - return vector.bitcast(T.f16x4, b0_v1), vector.bitcast(T.f16x4, b1_v1) - return vector.bitcast(T.i16x4, b0_v1), vector.bitcast(T.i16x4, b1_v1) - - # int8 path should not reach here (handled by the outer is_int8 branch). + return vector.bitcast(_mfma_input_pack_ty(), b0_v1), vector.bitcast(_mfma_input_pack_ty(), b1_v1) def load_b_tile(base_k): # b_tile[ku] = (packs_half0[ni], packs_half1[ni]) @@ -466,11 +573,16 @@ def load_b_tile(base_k): def lds_load_16b(curr_row_a_lds, col_base, lds_buffer): # Swizzle in bytes, then convert to element offset for memref indexing. - col_base_swz_bytes = flir.swizzle_xor16(curr_row_a_lds, col_base, k_blocks16) - col_base_swz = col_base_swz_bytes if elem_bytes == 1 else (col_base_swz_bytes / 2) + col_base_swz_bytes = flir.swizzle_xor16(curr_row_a_lds, col_base, a_k_blocks16) + col_base_swz = col_base_swz_bytes if a_elem_bytes == 1 else (col_base_swz_bytes / 2) coord_a16 = flir.make_coord(curr_row_a_lds, col_base_swz) idx_a16 = flir.crd2idx(coord_a16, layout_lds) +<<<<<<< HEAD + idx_a16 = idx_a16 + lds_base + return vector.load_op(_a_vec16_type(), lds_a, [idx_a16]) +======= return vector.load_op(_vec16_type(), lds_buffer, [idx_a16]) +>>>>>>> main # --- A LDS load helper for K64-bytes (load 16B once, extract 2x i64 halves) --- def lds_load_packs_k64(curr_row_a_lds, col_base, lds_buffer): @@ -485,33 +597,37 @@ def lds_load_packs_k64(curr_row_a_lds, col_base, lds_buffer): vec1_i64_ty = ir.VectorType.get([1], ir.IntegerType.get_signless(64)) a0_v1 = vector.from_elements(vec1_i64_ty, [a0_i64]) a1_v1 = vector.from_elements(vec1_i64_ty, [a1_i64]) - if is_f16: - return vector.bitcast(T.f16x4, a0_v1), vector.bitcast(T.f16x4, a1_v1) - return vector.bitcast(T.i16x4, a0_v1), vector.bitcast(T.i16x4, a1_v1) + return vector.bitcast(_mfma_input_pack_ty(), a0_v1), vector.bitcast(_mfma_input_pack_ty(), a1_v1) # --- A load/store (16B chunks), XOR16 swizzle --- +<<<<<<< HEAD + a_load_bytes = 16 + num_a_loads = a_bytes_per_thread // a_load_bytes +======= # Original register-based approach (commented out, kept for reference) num_a_loads = bytes_per_thread_a // a_load_bytes +>>>>>>> main # A tile mapping in dwords along K: # tile_k_dwords = (tile_k * elem_bytes) / 4 - if elem_bytes == 2: - tile_k_dwords = (tile_k * 2) // 4 + # For FP4: A is packed (2 elems/byte), so divide by a_elem_pack + if is_fp4: + a_tile_k_dwords = tile_k * a_elem_bytes // 4 // a_elem_pack else: - tile_k_dwords = tile_k // 4 - layout_a_tile_div4 = flir.make_layout((tile_m, tile_k_dwords), stride=(tile_k_dwords, 1)) + a_tile_k_dwords = tile_k * a_elem_bytes // 4 + layout_a_tile_div4 = flir.make_layout((tile_m, a_tile_k_dwords), stride=(a_tile_k_dwords, 1)) c4 = arith.constant(4, index=True) tx_i32_base = tx * c4 - atom_a_g2r16 = flir.make_copy_atom(_elem_type(), vector_size=(16 if elem_bytes == 1 else 8)) + atom_a_g2r16 = flir.make_copy_atom(_a_elem_type(), vector_size=a_kpack_elems) def load_a_16(idx_elem): return buffer_copy_gmem16_dwordx4( flir, arg=arg_a, - elem_type=_elem_type(), + elem_type=_a_elem_type(), idx_i32=idx_elem, atom_g2r16=atom_a_g2r16, rsrc=a_rsrc, - vec_elems=(16 if elem_bytes == 1 else 8), + vec_elems=a_kpack_elems, ) def a_tile_chunk_coord_i32(i: int): @@ -535,11 +651,8 @@ def load_a_tile(base_k_div4): # `idx_i32` is a dword offset. For 2B element types (fp16/bf16), # convert to element offset so the generic `vector.load` path reads # the right address (FLIR only specializes buffer_load_dwordx4 for 1B types). - idx_elem = ( - idx_i32 - if elem_bytes == 1 - else (idx_i32 * arith.constant(2, index=True)) - ) + idx_elem = idx_i32 * arith.constant(a_elem_bytes, index=True) + a_16B = load_a_16(idx_elem) parts.append(vector.bitcast(T.i32x4, a_16B)) return parts @@ -551,21 +664,47 @@ def store_a_tile_to_lds(vec_a_parts, lds_buffer): flir, arith, vector, +<<<<<<< HEAD + lds_memref=lds_a, + vec16_ty=_a_vec16_type(), + elem_type=_a_elem_type(), +======= lds_memref=lds_buffer, vec16_ty=_vec16_type(), elem_type=_elem_type(), +>>>>>>> main atom_s16=atom_a_g2r16, layout_lds=layout_lds, row_local=row_a_local, col_local_i32=col_a_local_i32, tx_c4=c4, +<<<<<<< HEAD + k_blocks16=a_k_blocks16, + lds_base=lds_base, +======= k_blocks16=k_blocks16, lds_base=arith.constant(0, index=True), +>>>>>>> main vec_part_i32x4=vec_a_parts[i], - elem_bytes=elem_bytes, + elem_bytes=a_elem_bytes, ) def prefetch_ab_tile(base_k): +<<<<<<< HEAD + # Convert element index to byte index, then to dword index. + base_k_bytes = base_k * arith.constant(int(a_elem_bytes), index=True) + base_k_div4 = base_k_bytes / 4 + if is_fp4: + # For FP4/MXFP4: A and B are packed (2 elems/byte), need to adjust offsets + a_regs = load_a_tile(base_k_div4 // a_elem_pack) + b_regs = load_b_tile(base_k // b_elem_pack) + else: + a_regs = load_a_tile(base_k_div4) + b_regs = load_b_tile(base_k) + return a_regs, b_regs + + def compute_tile(accs_in, b_tile_in, lds_base, *, is_last_tile=False, a0_prefetch=None, a_scale=None, b_scale=None): +======= a_regs = prefetch_a_tile(base_k) b_regs = prefetch_b_tile(base_k) return a_regs, b_regs @@ -633,9 +772,12 @@ def prefetch_a_to_lds(base_k, lds_buffer): dma_a_tile_to_lds(base_k_div4, lds_buffer) def compute_tile(accs_in, b_tile_in, lds_buffer, *, is_last_tile=False, a0_prefetch=None): +>>>>>>> main scales_pf = {} - if is_last_tile and (not is_f16_or_bf16): - # Prefetch scales (fp8/int8/int4 only). + + mfma_res_ty = _mfma_output_pack_ty() + if is_last_tile and (not no_epilogue_dequant): + # Prefetch scales for non-FP4 scaled paths (fp8/int8/int4 with per-tensor scale). s_b_vals = [] for ni in range_constexpr(num_acc_n): offset = ni * 16 @@ -656,6 +798,32 @@ def compute_tile(accs_in, b_tile_in, lds_buffer, *, is_last_tile=False, a0_prefe scales_pf["s_a_vecs"].append(vector.bitcast(T.f32x4, s_a_vec)) current_accs_list = list(accs_in) +<<<<<<< HEAD + if is_fp4: + # FP4/MXFP4 path: use block-scale MFMA with per-block scales. + current_accs_list = block_mfma_block_scale_f8f6f4( + current_accs_list, + b_tile_in, + a_scale, + b_scale, + lds_base, + lds_load_packs_k64=lds_load_packs_k64, + col_offset_base_bytes=col_offset_base_bytes, + row_a_lds=row_a_lds, + mfma_fn=rocdl.mfma_scale_f32_16x16x128_f8f6f4, + mfma_res_ty=mfma_res_ty, + cbsz=cbsz, + blgp=blgp, + a_elem_vec_pack=a_elem_pack, + k_unroll=k_unroll, + m_repeat=m_repeat, + num_acc_n=num_acc_n, + pack_K=pack_K, + pack_M=pack_M, + pack_N=pack_N, + a0_prefetch=a0_prefetch, + ) +======= # ---------------- gfx95 fast path (K128 MFMA scale) ---------------- # This is the key optimization from `zhimding/develop_0107` for FP8: @@ -734,10 +902,42 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): 0x3F800000, ], ) +>>>>>>> main return current_accs_list, scales_pf - mfma_res_ty = T.i32x4 if is_int8 else T.f32x4 + if use_mfma_scale_128: + current_accs_list = block_mfma_PTPC_f8f6f4( + current_accs_list, + b_tile_in, + lds_base, + col_offset_base_bytes=col_offset_base_bytes, + row_a_lds=row_a_lds, + lds_load_packs_k64=lds_load_packs_k64, + mfma_fn=rocdl.mfma_scale_f32_16x16x128_f8f6f4, + mfma_res_ty=mfma_res_ty, + k_unroll=k_unroll, + num_acc_n=num_acc_n, + m_repeat=m_repeat, + a0_prefetch=a0_prefetch, + ) + return current_accs_list, scales_pf +<<<<<<< HEAD + current_accs_list = block_mfma_16x16( + current_accs_list, + b_tile_in, + lds_base, + col_offset_base_bytes=col_offset_base_bytes, + row_a_lds=row_a_lds, + lds_load_packs_k64=lds_load_packs_k64, + mfma_fn=mfma_fn, + mfma_res_ty=mfma_res_ty, + k_unroll=k_unroll, + num_acc_n=num_acc_n, + m_repeat=m_repeat, + a0_prefetch=a0_prefetch, + ) +======= if is_int8: mfma_fn = mfma_i32_k32 elif is_f16: @@ -778,13 +978,14 @@ def mfma_k64_bytes(acc_in, a0, a1, b0, b1): b_packs0[ni], b_packs1[ni], ) +>>>>>>> main return current_accs_list, scales_pf vec1_f16 = ir.VectorType.get([1], ir.F16Type.get()) def store_output(final_accs, scales): # fp16/bf16: no scale fetch, no scale multiply in epilogue. - if is_f16_or_bf16: + if no_epilogue_dequant: s_b_vals = None s_a_vecs = None else: @@ -827,13 +1028,14 @@ def write_row_to_lds( acc_idx = mi * num_acc_n + ni acc = final_accs[acc_idx] val = vector.extract(acc, static_position=[ii], dynamic_position=[]) - if is_int8: + # if is_int8: + if is_int8 or is_int4: val = arith.sitofp(T.f32, val) - if is_f16_or_bf16: + if no_epilogue_dequant: val_s = val else: val_s = (val * s_a) * s_b_vals[ni] - v16 = arith.trunc_f(T.f16, val_s) + v16 = arith.trunc_f(_out_elem_type(), val_s) lds_idx = row_base_lds + col_local v1 = vector.from_elements(vec1_f16, [v16]) @@ -893,7 +1095,7 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): return def body_row(*, mi: int, ii: int, row_in_tile, row): - if not is_f16_or_bf16: + if not no_epilogue_dequant: s_a_vec4 = s_a_vecs[mi] s_a = vector.extract(s_a_vec4, static_position=[ii], dynamic_position=[]) col_base = by_n + n_tile_base + lane_mod_16 @@ -902,15 +1104,16 @@ def body_row(*, mi: int, ii: int, row_in_tile, row): acc_idx = mi * num_acc_n + ni acc = final_accs[acc_idx] val = vector.extract(acc, static_position=[ii], dynamic_position=[]) - if is_int8: + if is_int8 or is_int4: + # INT8/INT4 paths use i32 accumulators; convert to f32 for scaled epilogue. val = arith.sitofp(T.f32, val) - if is_f16_or_bf16: + if no_epilogue_dequant: val_s = val else: val_s = (val * s_a) * s_b_vals[ni] - val_f16 = arith.trunc_f(T.f16, val_s) + val_out = arith.trunc_f(_out_elem_type(), val_s) idx_out = idx_base + arith.constant(ni * 16, index=True) - buffer_ops.buffer_store(val_f16, c_rsrc, idx_out) + buffer_ops.buffer_store(val_out, c_rsrc, idx_out) mfma_epilog( use_cshuffle=False, @@ -971,6 +1174,16 @@ def hot_loop_scheduler(): rocdl.sched_barrier(0) # ---------------- Pipeline ---------------- +<<<<<<< HEAD + # LDS base offsets are in *elements* of `_elem_type()`. + # We keep LDS laid out as (tile_m, tile_k) in element units. + # For FP4: A is packed (2 elems/byte), so divide by a_elem_pack + lds_tile_elems_val = tile_m * tile_k // a_elem_pack if is_fp4 else tile_m * tile_k + lds_tile_elems = arith.constant(lds_tile_elems_val, index=True) + lds_base0 = arith.constant(0, index=True) + lds_base1 = lds_tile_elems + +======= # LDS tile size in elements (not bytes) for offset calculation lds_tile_elems = arith.constant(tile_m * tile_k, index=True) # Note: For 2-stage pipeline, we now use separate LDS buffers (lds_a_ping/pong) @@ -978,6 +1191,7 @@ def hot_loop_scheduler(): # no-alias guarantees to the compiler. b_tile_ping = None b_tile_pong = None +>>>>>>> main if lds_stage == 2: # ---------------- Ping-pong pipeline (2 separate LDS buffers) ---------------- # Cross-tile A0 LDS prefetch (default-on): @@ -990,11 +1204,19 @@ def prefetch_a0_pack(lds_buffer): # Prologue: tile-0 k0 = arith.constant(0, index=True) +<<<<<<< HEAD + a_regs0, b_tile0 = prefetch_ab_tile(k0) + # Prefetch scales for FP4/MXFP4 at tile-0. + a_scale_pong, b_scale_pong = prefetch_ab_scale_tile(k0 // 256) if is_fp4 else (k0, k0) + + store_a_tile_to_lds(a_regs0, lds_base0) +======= b_tile_pong = prefetch_b_tile(k0) if use_async_copy: prefetch_a_to_lds(k0, lds_a_pong) # Load into pong buffer else: store_a_tile_to_lds(prefetch_a_tile(k0), lds_a_pong) +>>>>>>> main gpu.barrier() accs = [acc_init] * (num_acc_n * m_repeat) @@ -1007,6 +1229,15 @@ def prefetch_a0_pack(lds_buffer): if (num_tiles % 2) == 1: for k_iv in range(0, c_k_main, tile_k * 2): next_k1 = k_iv + tile_k +<<<<<<< HEAD + a_regs_ping, b_tile_ping = prefetch_ab_tile(next_k1) + a_scale_ping, b_scale_ping = prefetch_ab_scale_tile(next_k1 // 256) if is_fp4 else (k0, k0) + + accs, _ = compute_tile( + accs, b_tile_pong, lds_base_pong, + a0_prefetch=a0_prefetch_pong, + a_scale=a_scale_pong, b_scale=b_scale_pong, +======= b_tile_ping = prefetch_b_tile(next_k1) if use_async_copy: prefetch_a_to_lds(next_k1, lds_a_ping) @@ -1015,6 +1246,7 @@ def prefetch_a0_pack(lds_buffer): accs, _ = compute_tile( accs, b_tile_pong, lds_a_pong, a0_prefetch=a0_prefetch_pong +>>>>>>> main ) a0_prefetch_pong = None @@ -1025,6 +1257,15 @@ def prefetch_a0_pack(lds_buffer): a0_prefetch_ping = prefetch_a0_pack(lds_a_ping) next_k2 = k_iv + tile_k * 2 +<<<<<<< HEAD + a_regs_pong, b_tile_pong = prefetch_ab_tile(next_k2) + a_scale_pong, b_scale_pong = prefetch_ab_scale_tile(next_k2 // 256) if is_fp4 else (k0, k0) + + accs, _ = compute_tile( + accs, b_tile_ping, lds_base_ping, + a0_prefetch=a0_prefetch_ping, + a_scale=a_scale_ping, b_scale=b_scale_ping, +======= b_tile_pong = prefetch_b_tile(next_k2) if use_async_copy: prefetch_a_to_lds(next_k2, lds_a_pong) @@ -1033,6 +1274,7 @@ def prefetch_a0_pack(lds_buffer): accs, _ = compute_tile( accs, b_tile_ping, lds_a_ping, a0_prefetch=a0_prefetch_ping +>>>>>>> main ) a0_prefetch_ping = None @@ -1048,11 +1290,21 @@ def prefetch_a0_pack(lds_buffer): lds_a_pong, is_last_tile=True, a0_prefetch=a0_prefetch_pong, + a_scale=a_scale_pong, b_scale=b_scale_pong, ) else: c_k_stop = c_k - (tile_k * 3) for k_iv in range(0, c_k_stop, tile_k * 2): next_k1 = k_iv + tile_k +<<<<<<< HEAD + a_regs_ping, b_tile_ping = prefetch_ab_tile(next_k1) + a_scale_ping, b_scale_ping = prefetch_ab_scale_tile(next_k1 // 256) if is_fp4 else (k0, k0) + + accs, _ = compute_tile( + accs, b_tile_pong, lds_base_pong, + a0_prefetch=a0_prefetch_pong, + a_scale=a_scale_pong, b_scale=b_scale_pong, +======= b_tile_ping = prefetch_b_tile(next_k1) if use_async_copy: prefetch_a_to_lds(next_k1, lds_a_ping) @@ -1060,6 +1312,7 @@ def prefetch_a0_pack(lds_buffer): store_a_tile_to_lds(prefetch_a_tile(next_k1), lds_a_ping) accs, _ = compute_tile( accs, b_tile_pong, lds_a_pong, a0_prefetch=a0_prefetch_pong +>>>>>>> main ) a0_prefetch_pong = None hot_loop_scheduler() @@ -1068,6 +1321,15 @@ def prefetch_a0_pack(lds_buffer): a0_prefetch_ping = prefetch_a0_pack(lds_a_ping) next_k2 = k_iv + tile_k * 2 +<<<<<<< HEAD + a_regs_pong, b_tile_pong = prefetch_ab_tile(next_k2) + a_scale_pong, b_scale_pong = prefetch_ab_scale_tile(next_k2 // 256) if is_fp4 else (k0, k0) + + accs, _ = compute_tile( + accs, b_tile_ping, lds_base_ping, + a0_prefetch=a0_prefetch_ping, + a_scale=a_scale_ping, b_scale=b_scale_ping, +======= b_tile_pong = prefetch_b_tile(next_k2) if use_async_copy: prefetch_a_to_lds(next_k2, lds_a_pong) @@ -1075,6 +1337,7 @@ def prefetch_a0_pack(lds_buffer): store_a_tile_to_lds(prefetch_a_tile(next_k2), lds_a_pong) accs, _ = compute_tile( accs, b_tile_ping, lds_a_ping, a0_prefetch=a0_prefetch_ping +>>>>>>> main ) a0_prefetch_ping = None @@ -1083,6 +1346,15 @@ def prefetch_a0_pack(lds_buffer): a0_prefetch_pong = prefetch_a0_pack(lds_a_pong) last_k = c_k - tile_k +<<<<<<< HEAD + a_regs_ping, b_tile_ping = prefetch_ab_tile(last_k) + a_scale_ping, b_scale_ping = prefetch_ab_scale_tile(last_k // 256) if is_fp4 else (k0, k0) + + accs, _ = compute_tile( + accs, b_tile_pong, lds_base_pong, + a0_prefetch=a0_prefetch_pong, + a_scale=a_scale_pong, b_scale=b_scale_pong, +======= b_tile_ping = prefetch_b_tile(last_k) if use_async_copy: prefetch_a_to_lds(last_k, lds_a_ping) @@ -1091,6 +1363,7 @@ def prefetch_a0_pack(lds_buffer): accs, _ = compute_tile( accs, b_tile_pong, lds_a_pong, a0_prefetch=a0_prefetch_pong +>>>>>>> main ) a0_prefetch_pong = None @@ -1105,9 +1378,44 @@ def prefetch_a0_pack(lds_buffer): lds_a_ping, is_last_tile=True, a0_prefetch=a0_prefetch_ping, + a_scale=a_scale_ping, b_scale=b_scale_ping, ) store_output(final_accs, scales) +<<<<<<< HEAD + # else: + # # CK-like bpreshuffle v1 spirit: + # # - Intrawave schedule + # # - Global prefetch 2 (regs double-buffer) + # # - Local shared memory buffer 1 (single LDS tile for A) + # # Prologue: tile-0 + # k0 = arith.constant(0, index=True) + # a_regs0, b_tile0 = prefetch_ab_tile(k0) + # store_a_tile_to_lds(a_regs0, lds_base0) + # gpu.barrier() + # accs = [acc_init] * (num_acc_n * m_repeat) + + # lds_base = lds_base0 + # b_tile_cur = b_tile0 + + # # For each tile except last: prefetch next tile, compute current, then overwrite LDS. + # for k_base in range(0, c_k - tile_k, tile_k): + # next_k = k_base + tile_k + # a_next, b_next = prefetch_ab_tile(next_k) + # accs, _ = compute_tile(accs, b_tile_cur, lds_base) + # # Single LDS buffer: ensure *all* waves are done reading A from LDS + # # before any wave overwrites it with the next tile. + # gpu.barrier() + # store_a_tile_to_lds(a_next, lds_base) + # hot_loop_scheduler() + # gpu.barrier() + # b_tile_cur = b_next + + # final_accs, scales = compute_tile( + # accs, b_tile_cur, lds_base, is_last_tile=True + # ) + # store_output(final_accs, scales) +======= else: # CK-like bpreshuffle v1 spirit: # - Intrawave schedule @@ -1139,15 +1447,16 @@ def prefetch_a0_pack(lds_buffer): accs, b_tile_cur, lds_a_pong, is_last_tile=True ) store_output(final_accs, scales) +>>>>>>> main @flir.jit def __call__( self: flir.T.i64, - arg_c: lambda: memref(DYN, T.f16), - arg_a: lambda: memref(DYN, _elem_type()), - arg_b: lambda: memref(DYN, _elem_type()), - arg_scale_a: lambda: memref(DYN, T.f32), - arg_scale_b: lambda: memref(DYN, T.f32), + arg_c: lambda: memref(DYN, _out_elem_type()), + arg_a: lambda: memref(DYN, _a_elem_type()), + arg_b: lambda: memref(DYN, _b_elem_type()), + arg_scale_a: lambda: memref(DYN, _scale_elem_type()), + arg_scale_b: lambda: memref(DYN, _scale_elem_type()), c_m: lambda: T.index, c_n: lambda: T.index, c_k: lambda: T.index, @@ -1193,5 +1502,4 @@ def __call__( ) -__all__ = ["compile_preshuffle_gemm_a8"] - +__all__ = ["compile_preshuffle_gemm"] diff --git a/tests/kernels/test_moe_gemm.py b/tests/kernels/test_moe_gemm.py index 0d6a7c1d..7243293b 100644 --- a/tests/kernels/test_moe_gemm.py +++ b/tests/kernels/test_moe_gemm.py @@ -286,7 +286,9 @@ def run_moe_stage1( tile_k: int, doweight_stage1: bool, *, - in_dtype: str = "fp8", + x_dtype: str = "fp8", + w_dtype: str = "fp8", + out_dtype: str = "f16", seed: int = 0, num_iters: int = 5, num_warmup: int = 2, @@ -301,7 +303,8 @@ def run_moe_stage1( routing_in: Optional[RoutingBuffers] = None, return_outputs: bool = False, skip_ref: bool = False, - w_fp4_kernel: bool = False, + enable_bias: bool = False, + bias_in: Optional[torch.Tensor] = None, test_graph: bool = False, ): assert model_dim % 64 == 0 @@ -360,63 +363,94 @@ def run_moe_stage1( blocks, ) = routing - if in_dtype not in ("fp8", "fp16", "int8", "int4"): - raise ValueError(f"in_dtype must be one of ('fp8','fp16','int8','int4'), got {in_dtype!r}") - is_int4 = in_dtype == "int4" - is_int8 = in_dtype in ("int8", "int4") + if x_dtype not in ("fp8", "fp16", "int8"): + raise ValueError(f"x_dtype must be one of ('fp8','fp16','int8'), got {x_dtype!r}") + if w_dtype not in ("fp8", "fp16", "int8", "int4", "fp4"): + raise ValueError(f"w_dtype must be one of ('fp8','fp16','int8','int4', 'fp4'), got {w_dtype!r}") + is_int4 = w_dtype == "int4" + is_int8 = x_dtype in ("int8", "int4") # Quantize inputs / weights. - if in_dtype == "fp8": + if x_dtype == "fp8" and w_dtype == "fp8": x_q, scale_x = pertoken_quant(x_fp32, quant_dtype=DTYPE_FP8) # [tokens,K], [tokens,1] w1_q, scale_w1 = pertoken_quant(w1_fp32, quant_dtype=DTYPE_FP8) # [E,2*inter,K], [E,2*inter,1] # w2 is not used by our kernel, but required by CK stage1 API w2_q, _scale_w2_unused = pertoken_quant(w2_fp32, quant_dtype=DTYPE_FP8) - elif in_dtype == "fp16": + elif x_dtype == "fp16" and w_dtype == "fp16": x_q = x_fp32.to(torch.float16) w1_q = w1_fp32.to(torch.float16) w2_q = w2_fp32.to(torch.float16) scale_x = None scale_w1 = None - elif in_dtype == "int8": + elif x_dtype == "int8" and w_dtype == "int8": x_q, scale_x = pertoken_quant(x_fp32, quant_dtype=torch.int8) w1_q, scale_w1 = pertoken_quant(w1_fp32, quant_dtype=torch.int8) w2_q, _scale_w2_unused = pertoken_quant(w2_fp32, quant_dtype=torch.int8) - else: + elif x_dtype == "int8" and w_dtype == "int4": # W4A8: X is int8, W is int4 packed (host packs from int8 values in [-8,7]). x_q, scale_x = pertoken_quant(x_fp32, quant_dtype=torch.int8) w1_q, scale_w1 = pertoken_quant(w1_fp32, quant_dtype=torch.int8, dtypeMax=7) w2_q, _scale_w2_unused = pertoken_quant(w2_fp32, quant_dtype=torch.int8, dtypeMax=7) + elif x_dtype == "fp8" and w_dtype == "fp4": + x_q = x_fp32.to(DTYPE_FP8) + scale_x = torch.ones([tokens, model_dim // 32], dtype=fp4_utils.fp8_e8m0, device=device) + w1_q, scale_w1, w1_convert = fp4_utils.per_1x32_f4_quant(w1_fp32) # (E, 2*inter, K) + w2_q, scale_w2, w2_convert = fp4_utils.per_1x32_f4_quant(w2_fp32) # (E, model_dim, inter) + else: + raise ValueError(f"Invalid combination of x_dtype and w_dtype: {x_dtype!r}, {w_dtype!r}") # Preshuffle weights (aiter/CK layout) on the *unpacked* tensor. - w1_shuffled = shuffle_weight(w1_q) - w2_shuffled = shuffle_weight(w2_q) if in_dtype == "fp8" else None + if w_dtype != "fp4": + w1_shuffled = shuffle_weight(w1_q) + w2_shuffled = shuffle_weight(w2_q) if w_dtype == "fp8" else None + # Flatten W1 for our flir kernel (treat expert dim as part of N). + w1_shuffled_flat = w1_shuffled.view(experts * (2 * inter_dim), model_dim) + w1_q_flat = w1_q.view(experts * (2 * inter_dim), model_dim) + scale_w1_flat = None if scale_w1 is None else scale_w1.view(experts * (2 * inter_dim), 1) + w_kernel = ( + _pack_shuffled_int8_to_packed_int4_no_perm(w1_shuffled_flat) if is_int4 else w1_shuffled_flat + ).contiguous() + if not is_int4: + w_kernel = w_kernel.view(experts * (2 * inter_dim), model_dim) + if scale_w1_flat is None: + scale_w1_1d = torch.empty((0,), device=device, dtype=torch.float32) + else: + scale_w1_1d = scale_w1_flat.view(-1).contiguous() # [rows] + + else: + w1_shuffled = fp4_utils.shuffle_weight_w4(w1_q, 16, True, True) + w2_shuffled = fp4_utils.shuffle_weight_w4(w2_q, 16, True, True) + scale_w1_shuffled = fp4_utils.shuffle_scale_w4(scale_w1, experts, True) + scale_w1_1d = scale_w1_shuffled + w_kernel = w1_shuffled - # Flatten W1 for our flir kernel (treat expert dim as part of N). - w1_shuffled_flat = w1_shuffled.view(experts * (2 * inter_dim), model_dim) - w1_q_flat = w1_q.view(experts * (2 * inter_dim), model_dim) - scale_w1_flat = None if scale_w1 is None else scale_w1.view(experts * (2 * inter_dim), 1) # No host-side padding: keep tensors contiguous and rely on kernel-side resource sizes / early-exit. x_q = x_q.contiguous().view(tokens, model_dim) - w_kernel = ( - _pack_shuffled_int8_to_packed_int4_no_perm(w1_shuffled_flat) if is_int4 else w1_shuffled_flat - ).contiguous() - if not is_int4: - w_kernel = w_kernel.view(experts * (2 * inter_dim), model_dim) - # Flatten scales to 1D memrefs (fp16 path uses 0-sized scale tensors; kernel ignores them). if scale_x is None: scale_x_1d = torch.empty((0,), device=device, dtype=torch.float32) else: scale_x_1d = scale_x.view(-1).contiguous() # [tokens] - if scale_w1_flat is None: - scale_w1_1d = torch.empty((0,), device=device, dtype=torch.float32) - else: - scale_w1_1d = scale_w1_flat.view(-1).contiguous() # [rows] + sorted_weights_1d = sorted_weights.contiguous().view(-1) # [sorted_size] # Output: [tokens, topk, inter_dim] fp16 - out = torch.empty((tokens, topk, inter_dim), device=device, dtype=torch.float16) + out_torch_dtype = DTYPE_FP8 if out_dtype == "fp8" else torch.float16 + out = torch.empty((tokens, topk, inter_dim), device=device, dtype=out_torch_dtype) + + # Bias: [experts, 2 * inter_dim] f32 (gate_bias, up_bias concatenated) + if enable_bias: + bias = ( + bias_in + if bias_in is not None + else torch.randn((experts, 2 * inter_dim), device=device, dtype=torch.float32) * 0.1 + ) + # Flatten bias for kernel: [experts * 2 * inter_dim] + bias_1d = bias.view(-1).contiguous() + else: + bias = None + bias_1d = torch.empty((0,), device=device, dtype=torch.float32) from kernels.moe_gemm_2stage import compile_moe_gemm1 exe = compile_moe_gemm1( @@ -424,15 +458,18 @@ def run_moe_stage1( inter_dim=inter_dim, experts=experts, topk=topk, - in_dtype=in_dtype, + x_dtype=x_dtype, + w_dtype=w_dtype, + out_dtype=out_dtype, tile_m=tile_m, tile_n=tile_n, tile_k=tile_k, doweight_stage1=bool(doweight_stage1), use_cshuffle_epilog=False, + enable_bias=enable_bias, ) - def launch(o, x, w, sx, sw, st, eids, sw_sorted): + def launch(o, x, w, sx, sw, st, eids, sw_sorted, bias_arg): stream_ptr = torch.cuda.current_stream().cuda_stream exe( o, @@ -444,6 +481,7 @@ def launch(o, x, w, sx, sw, st, eids, sw_sorted): eids, sw_sorted, num_valid_ids, + bias_arg, tokens, inter_dim, model_dim, @@ -461,6 +499,7 @@ def launch(o, x, w, sx, sw, st, eids, sw_sorted): sorted_token_ids, sorted_expert_ids, sorted_weights_1d, + bias_1d, num_iters=int(num_iters), num_warmup=int(num_warmup), testGraph=test_graph, @@ -468,20 +507,33 @@ def launch(o, x, w, sx, sw, st, eids, sw_sorted): torch.cuda.synchronize() if not bool(skip_ref): + if w_dtype != "fp4": + x_ref = x_q + w1_flat_ref = w1_q_flat + scale_x_ref = scale_x + scale_w1_flat_ref = scale_w1_flat + else: + x_ref = x_fp32_in + w1_flat_ref = w1_fp32_in + scale_x_ref = None + scale_w1_flat_ref = None ref = torch_moe_gemm1( - x_q, - w1_q_flat, - scale_x, - scale_w1_flat, + x_ref, + w1_flat_ref, + scale_x_ref, + scale_w1_flat_ref, topk_ids.to(torch.int64), topk_weights, inter_dim=inter_dim, doweight_stage1=doweight_stage1, + bias=bias, ) rtol = 0.5 if is_int4 else 0.25 atol = 0.5 if is_int4 else 0.25 - assert verify_output(out.to(torch.float32), ref, rtol=rtol, atol=atol) + print(out.to(torch.float32)) + print(ref) + assert verify_output(out.to(torch.float32), ref, rtol=1e-3, atol=1e-3) # Note: kernel launches full expert-block range; effective work is gated by num_valid_ids. flops = 2 * tokens * topk * (2 * inter_dim) * model_dim @@ -500,7 +552,7 @@ def launch(o, x, w, sx, sw, st, eids, sw_sorted): tbps = bytes_moved / 1e12 / (us / 1e6) print( - f"FLIR MoE stage1[{in_dtype}]: " + f"FLIR MoE stage1[{x_dtype} | {w_dtype} -> {out_dtype}]: " f"{us:.1f} us, " f"{tflops:.2f} TFLOPS(logical, M={tokens*topk}), " f"{tbps:.3f} TB/s (doweight_stage1={doweight_stage1})" @@ -511,7 +563,7 @@ def launch(o, x, w, sx, sw, st, eids, sw_sorted): else: compare_ck = bool(compare_aiter_ck) # aiter CK paths are fp8-only in our setup. - compare_ck = compare_ck and (in_dtype == "fp8") + compare_ck = compare_ck and (x_dtype == "fp8" and w_dtype == "fp8") if compare_ck: if not HAS_AITER: pytest.skip("aiter not available; cannot compare to CK moe stage1.", allow_module_level=False) @@ -594,7 +646,8 @@ def run_moe_stage2( tile_k: int, doweight_stage1: bool, *, - in_dtype: str = "fp8", + x_dtype: str = "fp8", + w_dtype: str = "fp8", # Stage2 output is fp16 (half2 atomics + CShuffle). The legacy f32-atomic path was removed. out_dtype: str = "f16", seed: int = 0, @@ -620,6 +673,8 @@ def run_moe_stage2( kernel_name: str = "moe_gemm2", # Use reduce mode (accumulate=False) instead of atomic mode. use_reduce: bool = False, + enable_bias: bool = False, + bias_in: Optional[torch.Tensor] = None, test_graph: bool = False, ): """MoE stage2 (gemm2): out2[t] = sum_{slot} ( out1[t,slot] @ W2[expert]^T ) with optional routed weight.""" @@ -668,22 +723,22 @@ def run_moe_stage2( x_fp32 = ( x_fp32_in if x_fp32_in is not None - else torch.rand((tokens, model_dim), device=device, dtype=torch.float32) * s + else torch.randn((tokens, model_dim), device=device, dtype=torch.float32) * s ) w1_fp32 = ( w1_fp32_in if w1_fp32_in is not None - else torch.rand((experts, 2 * inter_dim, model_dim), device=device, dtype=torch.float32) * (s / math.sqrt(model_dim)) + else torch.randn((experts, 2 * inter_dim, model_dim), device=device, dtype=torch.float32) * (s / math.sqrt(model_dim)) ) w2_fp32 = ( w2_fp32_in if w2_fp32_in is not None - else torch.rand((experts, model_dim, inter_dim), device=device, dtype=torch.float32) * (s / math.sqrt(inter_dim)) + else torch.randn((experts, model_dim, inter_dim), device=device, dtype=torch.float32) * (s / math.sqrt(inter_dim)) ) # Routing: deterministic torch topk + softmax. if topk_ids_in is None or topk_weights_in is None: - score = torch.rand((tokens, experts), device=device, dtype=torch.float32) + score = torch.randn((tokens, experts), device=device, dtype=torch.float32) topk_vals, topk_ids = torch.topk(score, k=topk, dim=1) topk_weights = torch.softmax(topk_vals, dim=1).to(torch.float32) else: @@ -713,39 +768,52 @@ def run_moe_stage2( # NOTE: routing uses `moe_sorting` output directly (no host trim/pad). Extra launched blocks # are gated by `num_valid_ids` inside the kernels. - if in_dtype not in ("fp8", "fp16", "int8", "int4"): - raise ValueError(f"in_dtype must be one of ('fp8','fp16','int8','int4'), got {in_dtype!r}") - is_int4 = in_dtype == "int4" - is_int8 = in_dtype in ("int8", "int4") + if x_dtype not in ("fp8", "fp16", "int8", "int4"): + raise ValueError(f"x_dtype must be one of ('fp8','fp16','int8','int4'), got {x_dtype!r}") + is_int4 = w_dtype == "int4" + is_int8 = x_dtype in ("int8", "int4") # Quantize inputs / weights. - if in_dtype == "fp8": + if x_dtype == "fp8" and w_dtype == "fp8": x_q, scale_x = pertoken_quant(x_fp32, quant_dtype=DTYPE_FP8) w1_q, scale_w1 = pertoken_quant(w1_fp32, quant_dtype=DTYPE_FP8) w2_q, scale_w2 = pertoken_quant(w2_fp32, quant_dtype=DTYPE_FP8) - elif in_dtype == "fp16": + elif x_dtype == "fp16" and w_dtype == "fp16": x_q = x_fp32.to(torch.float16) w1_q = w1_fp32.to(torch.float16) w2_q = w2_fp32.to(torch.float16) scale_x = None scale_w1 = None scale_w2 = None - elif in_dtype == "int8": + elif x_dtype == "int8" and w_dtype == "int8": x_q, scale_x = pertoken_quant(x_fp32, quant_dtype=torch.int8) w1_q, scale_w1 = pertoken_quant(w1_fp32, quant_dtype=torch.int8) w2_q, scale_w2 = pertoken_quant(w2_fp32, quant_dtype=torch.int8) - else: + elif x_dtype == "int8" and w_dtype == "int4": # W4A8: A2 is int8, W2 is int4 packed (host packs from int8 values in [-8,7]). x_q, scale_x = pertoken_quant(x_fp32, quant_dtype=torch.int8) w1_q, scale_w1 = pertoken_quant(w1_fp32, quant_dtype=torch.int8, dtypeMax=7) w2_q, scale_w2 = pertoken_quant(w2_fp32, quant_dtype=torch.int8, dtypeMax=7) + elif x_dtype == "fp8" and w_dtype == "fp4": + x_q = x_fp32.to(DTYPE_FP8) + scale_x = torch.ones([tokens, model_dim // 32], dtype=fp4_utils.fp8_e8m0, device=device) + w1_q, scale_w1, w1_convert = fp4_utils.per_1x32_f4_quant(w1_fp32) # (E, 2*inter, K) + w2_q, scale_w2, w2_convert = fp4_utils.per_1x32_f4_quant(w2_fp32) # (E, model_dim, inter) + else: + raise ValueError(f"Invalid combination of x_dtype and w_dtype: {x_dtype!r}, {w_dtype!r}") # Preshuffle weights (aiter/CK layout) on the *unpacked* tensor. - w1_shuffled = shuffle_weight(w1_q) - w2_shuffled = shuffle_weight(w2_q) + if w_dtype == "fp4": + w1_shuffled = fp4_utils.shuffle_weight_w4(w1_q, 16, True, True) + w2_shuffled = fp4_utils.shuffle_weight_w4(w2_q, 16, True, True) + scale_w1_shuffled = fp4_utils.shuffle_scale_w4(scale_w1, experts, True) + scale_w2_shuffled = fp4_utils.shuffle_scale_w4(scale_w2, experts, True) + else: + w1_shuffled = shuffle_weight(w1_q) + w2_shuffled = shuffle_weight(w2_q) # Stage2 input (A2): either provided (gemm1->quantize chaining) or built from stage1 reference. - if a2_fp8_in is not None and (a2_scale_in is not None or in_dtype == "fp16"): + if a2_fp8_in is not None and (a2_scale_in is not None or (x_dtype == "fp16" and w_dtype == "fp16")): a2_q = a2_fp8_in a2_scale = a2_scale_in else: @@ -758,46 +826,52 @@ def run_moe_stage2( "(so we don't have to run the huge torch reference stage1)." ) out1_ref = torch_moe_gemm1( - x_q, - w1_q_flat, - scale_x, - scale_w1_flat, + x_fp32_in, + w1_fp32_in, + None, + None, topk_ids.to(torch.int64), topk_weights, inter_dim=inter_dim, doweight_stage1=bool(doweight_stage1), ) # [tokens, topk, inter] fp32 - if in_dtype == "fp8": + if x_dtype == "fp8" and w_dtype == "fp8": a2_q, a2_scale = pertoken_quant(out1_ref, quant_dtype=DTYPE_FP8) - elif in_dtype == "fp16": + elif x_dtype == "fp16" and w_dtype == "fp16": a2_q = out1_ref.to(torch.float16) a2_scale = None else: a2_q, a2_scale = pertoken_quant(out1_ref, quant_dtype=torch.int8) # Flatten weights/scales for the kernel. - w2_shuffled_flat = w2_shuffled.view(experts * model_dim, inter_dim) - scale_w2_flat = None if scale_w2 is None else scale_w2.view(experts * model_dim, 1) + if w_dtype == "fp4": + # FP4 path: use shuffled weights directly (already packed) + w2_kernel = w2_shuffled + w2_scale_1d = scale_w2_shuffled + else: + w2_shuffled_flat = w2_shuffled.view(experts * model_dim, inter_dim) + scale_w2_flat = None if scale_w2 is None else scale_w2.view(experts * model_dim, 1) - # For W4A8, pack preshuffled int8 weights into packed int4 bytes. - w2_kernel = w2_shuffled_flat - if is_int4: - w2_kernel = _pack_shuffled_int8_to_packed_int4_no_perm(w2_shuffled_flat) + # For W4A8, pack preshuffled int8 weights into packed int4 bytes. + w2_kernel = w2_shuffled_flat + if is_int4: + w2_kernel = _pack_shuffled_int8_to_packed_int4_no_perm(w2_shuffled_flat) - w2_flat = w2_kernel.contiguous().view(-1) - w2_kernel = w2_flat - if not is_int4: - w2_kernel = w2_kernel.view(experts * model_dim, inter_dim) + w2_flat = w2_kernel.contiguous().view(-1) + w2_kernel = w2_flat + if not is_int4: + w2_kernel = w2_kernel.view(experts * model_dim, inter_dim) + + if scale_w2_flat is None: + w2_scale_1d = torch.empty((0,), device=device, dtype=torch.float32) + else: + w2_scale_1d = scale_w2_flat.view(-1).contiguous() # [experts*model_dim] # Flatten scales to 1D memrefs (fp16 path uses 0-sized scale tensors; kernel ignores them). if a2_scale is None: a2_scale_1d = torch.empty((0,), device=device, dtype=torch.float32) else: a2_scale_1d = a2_scale.view(-1).contiguous() # [tokens*topk] - if scale_w2_flat is None: - w2_scale_1d = torch.empty((0,), device=device, dtype=torch.float32) - else: - w2_scale_1d = scale_w2_flat.view(-1).contiguous() # [experts*model_dim] sorted_weights_1d = sorted_weights.contiguous().view(-1) # [sorted_size] out_s = str(out_dtype).strip().lower() @@ -808,20 +882,36 @@ def run_moe_stage2( out = torch.zeros((tokens, model_dim), device=device, dtype=out_torch_dtype) out_perf = torch.zeros_like(out) + # Bias: [experts, model_dim] f32 + if enable_bias: + bias = ( + bias_in + if bias_in is not None + else torch.randn((experts, model_dim), device=device, dtype=torch.float32) * 0.1 + ) + # Flatten bias for kernel: [experts * model_dim] + bias_1d = bias.view(-1).contiguous() + else: + bias = None + bias_1d = torch.empty((0,), device=device, dtype=torch.float32) + doweight_stage2 = not bool(doweight_stage1) exe = compile_fn( model_dim=model_dim, inter_dim=inter_dim, experts=experts, topk=topk, - in_dtype=in_dtype, + x_dtype=x_dtype, + w_dtype=w_dtype, + out_dtype=out_dtype, tile_m=tile_m, tile_n=tile_n, tile_k=tile_k, doweight_stage2=bool(doweight_stage2), + enable_bias=enable_bias, ) - def launch(o, x, w, sx, sw, st, eids, sw_sorted): + def launch(o, x, w, sx, sw, st, eids, sw_sorted, bias_arg): stream_ptr = torch.cuda.current_stream().cuda_stream exe( o, @@ -833,6 +923,7 @@ def launch(o, x, w, sx, sw, st, eids, sw_sorted): eids, sw_sorted, num_valid_ids, + bias_arg, tokens, model_dim, inter_dim, @@ -853,6 +944,7 @@ def launch(o, x, w, sx, sw, st, eids, sw_sorted): sorted_token_ids, sorted_expert_ids, sorted_weights_1d, + bias_1d, num_iters=int(num_iters), num_warmup=int(num_warmup), testGraph=test_graph, @@ -870,21 +962,33 @@ def launch(o, x, w, sx, sw, st, eids, sw_sorted): sorted_token_ids, sorted_expert_ids, sorted_weights_1d, + bias_1d, ) torch.cuda.synchronize() if not bool(skip_ref): + if w_dtype != "fp4": + a2_ref = a2_q + w2_ref = w2_q + a2_scale_ref = a2_scale + scale_w2_ref = scale_w2 + else: + a2_ref = a2_q.to(torch.float32) + w2_ref = w2_fp32_in + a2_scale_ref = None + scale_w2_ref = None ref2 = torch_moe_gemm2( - a2_q, - w2_q, - a2_scale, - scale_w2, + a2_ref, + w2_ref, + a2_scale_ref, + scale_w2_ref, topk_ids.to(torch.int64), topk_weights, model_dim=model_dim, doweight_stage2=doweight_stage2, + bias=bias, ) - assert verify_output(out.to(torch.float32), ref2, rtol=0.5, atol=0.5) + assert verify_output(out.to(torch.float32), ref2, rtol=1e-3, atol=1e-3) # Launches full expert-block range; effective work is gated by num_valid_ids. flops = 2 * tokens * topk * model_dim * inter_dim @@ -901,7 +1005,7 @@ def launch(o, x, w, sx, sw, st, eids, sw_sorted): bytes_moved += int(sorted_expert_ids.numel()) * 4 tbps = bytes_moved / 1e12 / (us / 1e6) print( - f"FLIR MoE stage2 [{kernel_name}] {in_dtype} | " + f"FLIR MoE stage2 [{kernel_name}] {x_dtype} | {w_dtype} -> {out_dtype} | " f"{model_dim}x{inter_dim}, E={experts}, K={topk}, M_eff={tokens*topk} | " f"{us:.1f} us, {tflops:.2f} TFLOPS, {tbps:.3f} TB/s" ) @@ -911,7 +1015,7 @@ def launch(o, x, w, sx, sw, st, eids, sw_sorted): else: compare_ck = bool(compare_aiter_ck) # aiter CK paths are fp8-only in our setup. - compare_ck = compare_ck and (in_dtype == "fp8") + compare_ck = compare_ck and (x_dtype == "fp8" and w_dtype == "fp8") if compare_ck: if not HAS_AITER: pytest.skip("aiter not available; cannot compare to CK moe stage2.", allow_module_level=False) @@ -1000,15 +1104,16 @@ def launch_ck(o, a2_, w1_, w2_, sorted_ids_, sorted_eids_, num_valid_, w2_scale_ @pytest.mark.parametrize( "tokens, model_dim, inter_dim, experts, topk, tile_m, tile_n1, tile_k1, tile_n2, tile_k2, doweight_stage1", [ - # Small smoke (fast compile + run) for all in_dtype. + # Small smoke (fast compile + run) for all x_dtype and w_dtype. pytest.param(64, 256, 128, 4, 2, 32, 64, 128, 64, 128, False, id="S"), - # Medium (more realistic) for all in_dtype (skip_ref will auto-enable). + # Medium (more realistic) for all x_dtype and w_dtype (skip_ref will auto-enable). pytest.param(128, 1024, 256, 8, 2, 64, 128, 128, 128, 128, False, id="M"), # Large (aiter-style) mainly for perf smoke; reference is too expensive here. pytest.param(256, 4096, 2048, 17, 9, 64, 128, 128, 256, 128, False, id="L"), ], ) -@pytest.mark.parametrize("in_dtype", ["fp8", "fp16", "int8", "int4"]) +@pytest.mark.parametrize("x_dtype", ["fp8", "fp16", "int8"]) +@pytest.mark.parametrize("w_dtype", ["fp8", "fp16", "int8", "int4"]) @pytest.mark.parametrize("use_reduce", [False, True], ids=["atomic", "reduce"]) def test_moe_gemm_2stage( tokens: int, @@ -1022,9 +1127,11 @@ def test_moe_gemm_2stage( tile_n2: int, tile_k2: int, doweight_stage1: bool, - in_dtype: str, + x_dtype: str, + w_dtype: str, use_reduce: bool, *, + out_dtype: str = "f16", seed: int = 0, num_iters: int = 5, num_warmup: int = 2, @@ -1032,7 +1139,6 @@ def test_moe_gemm_2stage( compare_aiter_ck: Optional[bool] = None, init_scale: float = 1.0, skip_ref: bool = False, - w_fp4_kernel: bool = False, test_graph: bool = False, ): """Single 2-stage test: gemm1 -> quantize -> gemm2, with routing built once.""" @@ -1045,15 +1151,12 @@ def test_moe_gemm_2stage( init_scale = 0.2 s = float(init_scale) x_fp32 = torch.randn((tokens, model_dim), device=device, dtype=torch.float32) * s - # x_fp32 = torch.ones((tokens, model_dim), device=device, dtype=torch.float32) * s # fan_in = model_dim for W1: [E, 2*inter, model] - w1_fp32 = torch.randn((experts, 2 * inter_dim, model_dim), device=device, dtype=torch.float32) * s #* (s / math.sqrt(model_dim)) - # w1_fp32 = torch.randn((experts, 2 * inter_dim, model_dim), device=device, dtype=torch.float32) * 0.2 - # w1_fp32 = torch.ones((experts, 2 * inter_dim, model_dim), device=device, dtype=torch.float32) * s + w1_fp32 = torch.randn((experts, 2 * inter_dim, model_dim), device=device, dtype=torch.float32) * (s / math.sqrt(model_dim)) # fan_in = inter_dim for W2: [E, model, inter] w2_fp32 = torch.randn((experts, model_dim, inter_dim), device=device, dtype=torch.float32) * (s / math.sqrt(inter_dim)) - score = torch.rand((tokens, experts), device=device, dtype=torch.float32) + score = torch.randn((tokens, experts), device=device, dtype=torch.float32) topk_vals, topk_ids = torch.topk(score, k=topk, dim=1) topk_weights = torch.softmax(topk_vals, dim=1).to(torch.float32) @@ -1080,7 +1183,9 @@ def test_moe_gemm_2stage( inter_dim=inter_dim, experts=experts, topk=topk, - in_dtype=in_dtype, + x_dtype=x_dtype, + w_dtype=w_dtype, + out_dtype=out_dtype, tile_m=tile_m, tile_n=tile_n1, tile_k=tile_k1, @@ -1088,7 +1193,7 @@ def test_moe_gemm_2stage( seed=seed, num_iters=num_iters, num_warmup=num_warmup, - compare_aiter_ck=bool(compare_aiter_ck) and (in_dtype == "fp8"), + compare_aiter_ck=bool(compare_aiter_ck) and (x_dtype == "fp8" and w_dtype == "fp8"), moe_sort_mode=moe_sort_mode, x_fp32_in=x_fp32, w1_fp32_in=w1_fp32, @@ -1098,20 +1203,20 @@ def test_moe_gemm_2stage( routing_in=routing, return_outputs=True, skip_ref=bool(skip_ref), - w_fp4_kernel=w_fp4_kernel, ) - if w_fp4_kernel: - a2_q = out1_fp16.to(torch.float32) - # a2_q = torch.ones_like(out1_fp16, dtype=torch.float32) / 5 - # w2_fp32 = torch.ones_like(w2_fp32, dtype=torch.float32) / 10 - a2_scale = None - elif in_dtype == "fp8": + if x_dtype == "fp8" and w_dtype == "fp8": out1_fp32 = out1_fp16.to(torch.float32) a2_q, a2_scale = pertoken_quant(out1_fp32, quant_dtype=DTYPE_FP8) - elif in_dtype == "fp16": + elif x_dtype == "fp16" and w_dtype == "fp16": a2_q = out1_fp16 a2_scale = None + elif x_dtype == "fp8" and w_dtype == "fp4": + if out1_fp16.dtype == torch.float16: + a2_q = out1_fp16.to(DTYPE_FP8) + else: + a2_q = out1_fp16 + a2_scale = torch.ones([tokens, topk, inter_dim // 32], dtype=fp4_utils.fp8_e8m0, device=device) else: out1_fp32 = out1_fp16.to(torch.float32) a2_q, a2_scale = pertoken_quant(out1_fp32, quant_dtype=torch.int8) @@ -1122,7 +1227,9 @@ def test_moe_gemm_2stage( inter_dim=inter_dim, experts=experts, topk=topk, - in_dtype=in_dtype, + x_dtype=x_dtype, + w_dtype=w_dtype, + out_dtype=out_dtype, tile_m=tile_m, tile_n=tile_n2, tile_k=tile_k2, @@ -1164,11 +1271,15 @@ def _compile( tile_n: int, tile_k: int, doweight_stage2: bool, - in_dtype: str = "fp8", + x_dtype: str, + w_dtype: str, out_dtype: str = "f16", + enable_bias: bool = False, ): if use_flydsl_reduce: # Use unified implementation with FlyDSL reduce kernel + # Note: compile_moe_gemm2_ex uses in_dtype instead of x_dtype/w_dtype + in_dtype = x_dtype # Assume x_dtype and w_dtype are the same for reduce mode return compile_moe_gemm2_ex( model_dim=model_dim, inter_dim=inter_dim, @@ -1193,9 +1304,11 @@ def _compile( tile_n=tile_n, tile_k=tile_k, doweight_stage2=doweight_stage2, - in_dtype=in_dtype, + x_dtype=x_dtype, + w_dtype=w_dtype, out_dtype=out_dtype, accumulate=False, + enable_bias=enable_bias, ) return _TorchReduceWrapper(gemm2_exe, topk, model_dim) return _compile @@ -1224,6 +1337,7 @@ def __call__( arg_expert_ids, arg_sorted_weights, arg_num_valid_ids, + arg_bias, tokens_in, n_in, k_in, @@ -1243,7 +1357,7 @@ def __call__( intermediate.view(-1), arg_x, arg_w, arg_scale_x, arg_scale_w, arg_sorted_token_ids, arg_expert_ids, arg_sorted_weights, - arg_num_valid_ids, tokens_in, n_in, k_in, size_expert_ids_in, + arg_num_valid_ids, arg_bias, tokens_in, n_in, k_in, size_expert_ids_in, stream_ptr, ) torch.sum(intermediate.view(tokens_in, self._topk, self._model_dim), dim=1, out=arg_out) @@ -1394,14 +1508,16 @@ def test_moe_reduce_kernel(tokens: int, topk: int, model_dim: int): pytest.param(8, 5120, 1536, 64, 6, id="EP-K6-decode-bs8"), ], ) -@pytest.mark.parametrize("in_dtype", ["fp8"]) +@pytest.mark.parametrize("x_dtype", ["fp8"]) +@pytest.mark.parametrize("w_dtype", ["fp8"]) def test_moe_stage2_standalone( tokens: int, model_dim: int, inter_dim: int, experts: int, topk: int, - in_dtype: str, + x_dtype: str, + w_dtype: str, *, tile_m: int = 64, # Common block size for M tile_n: int = 256, # Common block size for N2 @@ -1428,7 +1544,8 @@ def test_moe_stage2_standalone( tile_n=tile_n, tile_k=tile_k, doweight_stage1=False, # Apply weight in stage2 - in_dtype=in_dtype, + x_dtype=x_dtype, + w_dtype=w_dtype, seed=seed, num_iters=num_iters, num_warmup=num_warmup, @@ -1483,12 +1600,25 @@ def _str2tuple_dim(v: str) -> Tuple[int, int]: description="MoE 2-stage (FLIR MFMA FP8) test/benchmark (argparse subset aligned with aiter test_moe_2stage.py)", ) parser.add_argument( - "--in_dtype", + "--x_dtype", type=str, default="fp8", - choices=["fp8", "fp16", "int8", "int4", "all"], - help="Kernel input dtype: fp8 / int8 / int4 / all (default: all). " - "int4 means W4A8: A int8, W packed int4.", + choices=["fp8", "fp16", "int8", "int4"], + help="Kernel input dtype: fp8 / fp16 / int8 / int4.", + ) + parser.add_argument( + "--w_dtype", + type=str, + default="fp8", + choices=["fp8", "fp16", "int8", "int4", "fp4"], + help="Kernel weight dtype: fp8 / fp16 / int8 / int4.", + ) + parser.add_argument( + "--out_dtype", + type=str, + default="f16", + choices=["f16", "bf16", "f32"], + help="Stage2 output dtype: f16 / f32.", ) parser.add_argument("-d", "--dtype", type=str, default="fp32", choices=["fp32", "fp16", "bf16"], help="Input init dtype (currently data is quantized to FP8 per-token; init dtype mainly affects RNG range).") parser.add_argument("-dim", type=_str2tuple_dim, default=(6144, 4096), help="Model dimension: model_dim,inter_dim (e.g. -dim 6144,4096)") @@ -1540,26 +1670,28 @@ def _str2tuple_dim(v: str) -> Tuple[int, int]: tile_k2 = int(args.tile_k2) if args.tile_k2 is not None else args.tile_k # Run 2-stage (gemm1 -> quantize -> gemm2) aiter-style test/benchmark. - for dt in args.in_dtype.split(","): - test_moe_gemm_2stage( - tokens=int(args.tokenNum), - model_dim=int(model_dim), - inter_dim=int(inter_dim), - experts=int(args.expert), - topk=int(args.topk), - tile_m=int(args.tile_m), - tile_n1=int(args.tile_n), - tile_k1=int(args.tile_k), - tile_n2=tile_n2, - tile_k2=tile_k2, - doweight_stage1=bool(args.doweight_stage1), - in_dtype=dt, - seed=int(args.seed), - num_iters=int(args.num_iters), - num_warmup=int(args.num_warmup), - moe_sort_mode=args.moe_sort_mode, - compare_aiter_ck=args.compare_aiter_ck, - skip_ref=bool(args.skip_ref), - use_reduce=bool(args.reduce), - test_graph=bool(args.test_graph), - ) + for x_dtype in args.x_dtype.split(","): + for w_dtype in args.w_dtype.split(","): + test_moe_gemm_2stage( + tokens=int(args.tokenNum), + model_dim=int(model_dim), + inter_dim=int(inter_dim), + experts=int(args.expert), + topk=int(args.topk), + tile_m=int(args.tile_m), + tile_n1=int(args.tile_n), + tile_k1=int(args.tile_k), + tile_n2=tile_n2, + tile_k2=tile_k2, + doweight_stage1=bool(args.doweight_stage1), + x_dtype=x_dtype, + w_dtype=w_dtype, + out_dtype=args.out_dtype, + seed=int(args.seed), + num_iters=int(args.num_iters), + num_warmup=int(args.num_warmup), + moe_sort_mode=args.moe_sort_mode, + compare_aiter_ck=args.compare_aiter_ck, + skip_ref=bool(args.skip_ref), + use_reduce=bool(args.reduce), + ) diff --git a/tests/kernels/test_preshuffle_gemm.py b/tests/kernels/test_preshuffle_gemm.py index 28d922b1..2f53a3b7 100644 --- a/tests/kernels/test_preshuffle_gemm.py +++ b/tests/kernels/test_preshuffle_gemm.py @@ -29,8 +29,7 @@ if _PYFLIR_SRC not in sys.path: sys.path.insert(0, _PYFLIR_SRC) -from kernels.preshuffle_gemm import compile_preshuffle_gemm_a8 -from kernels.mixed_preshuffle_gemm import compile_mxfp4_preshuffle_gemm +from kernels.preshuffle_gemm import compile_preshuffle_gemm from tests.test_common import run_perftest, verify_output from tests.utils import pertoken_quant, shuffle_weight from flydsl.runtime.device import get_rocm_arch @@ -94,7 +93,9 @@ def run_torch(x, weight, x_scale, w_scale, bias=None, dtype=torch.bfloat16): out = out.to(bias.dtype) + bias return out.to(dtype) -@pytest.mark.parametrize("in_dtype", ["fp8", "int8", "bf16"]) + +@pytest.mark.parametrize("a_dtype", ["fp8", "int8", "bf16"]) +@pytest.mark.parametrize("b_dtype", ["fp8", "int8", "bf16"]) @pytest.mark.parametrize( "M, N, K, tile_m, tile_n, tile_k", [ @@ -106,7 +107,8 @@ def run_torch(x, weight, x_scale, w_scale, bias=None, dtype=torch.bfloat16): ) @pytest.mark.parametrize("use_async_copy", [False, True], ids=["sync_copy", "async_copy"]) def test_mfma_a8_flir_preshuffle( - in_dtype, + a_dtype, + b_dtype, M, N, K, @@ -126,7 +128,7 @@ def test_mfma_a8_flir_preshuffle( pytest.skip("async copy is only supported in gfx950") print("=" * 80) print( - f"MFMA {in_dtype.upper()} GEMM Test (Tile: {tile_m}x{tile_n}x{tile_k}) [Torch Optimized]" + f"MFMA {a_dtype.upper()}/{b_dtype.upper()} GEMM Test (Tile: {tile_m}x{tile_n}x{tile_k}) [Torch Optimized]" ) print("=" * 80) @@ -135,14 +137,16 @@ def test_mfma_a8_flir_preshuffle( raise ValueError( f"lds_stage must be 1 or 2, got {lds_stage!r}" ) - exe = compile_preshuffle_gemm_a8( + exe = compile_preshuffle_gemm( M=M, N=N, K=K, tile_m=tile_m, tile_n=tile_n, tile_k=tile_k, - in_dtype=in_dtype, + a_dtype=a_dtype, + b_dtype=b_dtype, + out_dtype="f16", lds_stage=lds_stage, use_cshuffle_epilog=bool(use_cshuffle_epilog), use_async_copy=bool(use_async_copy), @@ -152,10 +156,10 @@ def test_mfma_a8_flir_preshuffle( size_c = M * N size_a = M * K # B is packed int4 for W4A8: 2 values per byte. - if in_dtype == "int4": + if b_dtype == "int4": size_b = (N * K) // 2 elem_bytes = 1 - elif in_dtype in ("fp16", "bf16"): + elif b_dtype in ("fp16", "bf16"): size_b = (N * K) * 2 elem_bytes = 2 else: @@ -168,12 +172,12 @@ def test_mfma_a8_flir_preshuffle( a_fp32 = torch.rand(M, K, device=device, dtype=torch.float32) b_fp32_t = torch.rand(N, K, device=device, dtype=torch.float32) # (N, K) - is_int4 = in_dtype == "int4" + is_int4 = b_dtype == "int4" # INT4 here means W4A8: A is INT8, B is packed INT4 and unpacked to INT8 in-kernel. - is_int8 = (in_dtype == "int8") or is_int4 + is_int8 = (a_dtype == "int8") or is_int4 - if in_dtype in ("fp16", "bf16"): - torch_dtype = torch.float16 if in_dtype == "fp16" else torch.bfloat16 + if a_dtype in ("fp16", "bf16") or b_dtype in ("fp16", "bf16"): + torch_dtype = torch.float16 if a_dtype == "fp16" else torch.bfloat16 a_q = a_fp32.to(torch_dtype) b_q = b_fp32_t.to(torch_dtype) # Scale is semantically optional for fp16/bf16 (no dequant). Let callers pass None; @@ -260,7 +264,7 @@ def launch_kernel(c, a, b, sa, sb): tbps = bytes_moved / 1e12 / (us / 1e6) print(f"Throughput: {us:.1f} us, {tflops:.2f} TFLOPS, BW: {tbps:.3f} TB/s") - if HAS_AITER and bool(run_aiter_bench) and (not is_int4) and (in_dtype in ("fp8", "int8")): + if HAS_AITER and bool(run_aiter_bench) and (not is_int4) and (a_dtype in ("fp8", "int8") and a_dtype == b_dtype): print("-" * 40) print("Running Aiter Benchmark...") try: @@ -329,7 +333,7 @@ def test_mfma_w4_flir_preshuffle( raise ValueError( f"lds_stage must be 1 or 2, got {lds_stage!r}" ) - exe = compile_mxfp4_preshuffle_gemm( + exe = compile_preshuffle_gemm( M=M, N=N, K=K, @@ -338,6 +342,7 @@ def test_mfma_w4_flir_preshuffle( tile_k=tile_k, a_dtype=a_dtype, b_dtype=b_dtype, + out_dtype="f16", lds_stage=lds_stage, use_cshuffle_epilog=bool(use_cshuffle_epilog), ) @@ -439,11 +444,18 @@ def launch_kernel(c, a, b, sa, sb): description="Preshuffle GEMM benchmark" ) parser.add_argument( - "--in_dtype", + "--a_dtype", + type=str, + default="fp8", + choices=["fp8", "int8", "int4", "fp16", "bf16", "fp4"], + help="Input dtype" + ) + parser.add_argument( + "--b_dtype", type=str, default="fp8", choices=["fp8", "int8", "int4", "fp16", "bf16", "fp4"], - help="Input dtype", + help="Input dtype" ) parser.add_argument("-M", type=int, default=16, help="M dimension") parser.add_argument("-N", type=int, default=10240, help="N dimension") @@ -516,12 +528,13 @@ def launch_kernel(c, a, b, sa, sb): if args.in_dtype == "fp4": raise ValueError("--in_dtype fp4 requires --wfp4") test_mfma_a8_flir_preshuffle( - args.in_dtype, - M=args.M, - N=args.N, - K=args.K, - tile_m=args.tile_m, - tile_n=args.tile_n, + args.a_dtype, + args.b_dtype, + M=args.M, + N=args.N, + K=args.K, + tile_m=args.tile_m, + tile_n=args.tile_n, tile_k=args.tile_k, lds_stage=args.lds_stage, bench_iters=args.num_iters, @@ -534,13 +547,13 @@ def launch_kernel(c, a, b, sa, sb): else: pack_M = 2 test_mfma_w4_flir_preshuffle( - args.in_dtype if args.in_dtype == "fp8" else "fp4", - "fp4", - M=args.M, - N=args.N, - K=args.K, - tile_m=args.tile_m * pack_M, - tile_n=args.tile_n, + args.a_dtype if args.a_dtype == "fp8" else "fp4", + "fp4", + M=args.M, + N=args.N, + K=args.K, + tile_m=args.tile_m * pack_M, + tile_n=args.tile_n, tile_k=args.tile_k, lds_stage=args.lds_stage, bench_iters=args.num_iters, diff --git a/tests/kernels/test_ref.py b/tests/kernels/test_ref.py index a97b69eb..ed084202 100644 --- a/tests/kernels/test_ref.py +++ b/tests/kernels/test_ref.py @@ -10,8 +10,13 @@ def torch_moe_gemm1( topk_weights: torch.Tensor, inter_dim: int, doweight_stage1: bool, + bias: torch.Tensor | None = None, ) -> torch.Tensor: - """Return [tokens, topk, inter_dim] fp32.""" + """Return [tokens, topk, inter_dim] fp32. + + Args: + bias: Optional bias tensor of shape [experts, 2 * inter_dim] (gate_bias, up_bias concatenated). + """ tokens, model_dim = x_fp8.shape topk = topk_ids.shape[1] # Derive experts from weight shapes (topk_ids may not cover all experts when tokens are tiny). @@ -40,6 +45,10 @@ def torch_moe_gemm1( y2 = F.linear(x[t_idx, :], w1[e, :, :]) # [num, 2*inter_dim] gate = y2[:, :inter_dim] up = y2[:, inter_dim:] + # Apply bias if provided + if bias is not None: + gate = gate + bias[e, :inter_dim] + up = up + bias[e, inter_dim:] y = F.silu(gate) * up if doweight_stage1: y = y * topk_weights[t_idx, s_idx].unsqueeze(-1) @@ -56,14 +65,19 @@ def torch_moe_gemm2( topk_weights: torch.Tensor, model_dim: int, doweight_stage2: bool, + bias: torch.Tensor | None = None, ) -> torch.Tensor: """Return [tokens, model_dim] fp32. Semantics align with aiter `torch_moe_stage2`: - Dequantize `a2_fp8` and `w2_fp8` with per-token/row scales. - For each routed (token, slot) -> expert, compute y = a2 @ W2[expert]^T. + - Optionally add bias (if provided). - Optionally multiply routed weight in stage2 (when stage1 did *not*). - Reduce across topk by summing into [tokens, model_dim]. + + Args: + bias: Optional bias tensor of shape [experts, model_dim]. """ assert a2_fp8.is_cuda and w2_fp8.is_cuda tokens, topk, inter_dim = a2_fp8.shape @@ -87,6 +101,9 @@ def torch_moe_gemm2( t_idx = idx[:, 0] s_idx = idx[:, 1] y = F.linear(a2[t_idx, s_idx, :], w2[e, :, :]) # [num, model_dim] + # Apply bias if provided + if bias is not None: + y = y + bias[e, :] if doweight_stage2: y = y * topk_weights[t_idx, s_idx].unsqueeze(-1) out.index_add_(0, t_idx, y)