diff --git a/.gitignore b/.gitignore index e8947221..2a78a909 100644 --- a/.gitignore +++ b/.gitignore @@ -28,3 +28,8 @@ flydsl/src/*.egg-info/ # Wheels / packaging outputs dist/ + +# Profiling artifacts (rocprofv3) +.rocprofv3/ +R9700-Workstation-SH/ +*_results.db diff --git a/flir/include/flir/FlirRocmDialect.td b/flir/include/flir/FlirRocmDialect.td index 4c98078b..11f4d699 100644 --- a/flir/include/flir/FlirRocmDialect.td +++ b/flir/include/flir/FlirRocmDialect.td @@ -93,7 +93,8 @@ def GfxArchEnum : I32EnumAttr<"GfxArch", I32EnumAttrCase<"GFX940", 2, "gfx940">, I32EnumAttrCase<"GFX941", 3, "gfx941">, I32EnumAttrCase<"GFX942", 4, "gfx942">, - I32EnumAttrCase<"GFX950", 5, "gfx950"> + I32EnumAttrCase<"GFX950", 5, "gfx950">, + I32EnumAttrCase<"GFX1201", 6, "gfx1201"> ]> { let cppNamespace = "::mlir::flir::rocm"; } diff --git a/flir/python_bindings/runtime/FlirRocmRuntimeWrappers.cpp b/flir/python_bindings/runtime/FlirRocmRuntimeWrappers.cpp index 59ed39f3..ee85a0b5 100644 --- a/flir/python_bindings/runtime/FlirRocmRuntimeWrappers.cpp +++ b/flir/python_bindings/runtime/FlirRocmRuntimeWrappers.cpp @@ -15,6 +15,8 @@ #include #include +#include <mutex> +#include <unordered_map> #include "hip/hip_runtime.h" @@ -36,9 +38,65 @@ fprintf(stderr, "'%s' failed with '%s'\n", #expr, name); \ }(expr) -extern "C" FLIR_EXPORT hipModule_t mgpuModuleLoad(void *data, size_t /*gpuBlobSize*/) { +// --------------------------------------------------------------------------- +// Caching layer: avoids calling hipModuleLoadData / hipModuleGetFunction / +// hipStreamCreate on every kernel launch. The MLIR gpu-to-llvm lowering +// emits calls to mgpuModuleLoad + mgpuModuleGetFunction + mgpuStreamCreate +// on *every* invocation of the host wrapper, which adds ~0.8 ms of overhead +// per call. +// +// Cache key: we use a content hash of the blob data. The JIT may reuse +// the same address for different blobs (after freeing a previous module), +// so pointer-based keys are unsafe. 
We compute a fast hash +// content and use (hash, size) as the cache key. For repeated calls from +// the same JIT'd function, the blob content is identical → cache hit. +// --------------------------------------------------------------------------- + +static std::mutex g_cache_mutex; + +// Simple FNV-1a hash of blob content +static uint64_t hash_blob(const void *data, size_t size) { + const uint8_t *p = static_cast<const uint8_t *>(data); + uint64_t hash = 14695981039346656037ULL; // FNV offset basis + for (size_t i = 0; i < size; i++) { + hash ^= p[i]; + hash *= 1099511628211ULL; // FNV prime + } + return hash; +} + +struct BlobKey { + uint64_t content_hash; + size_t size; + bool operator==(const BlobKey &o) const { + return content_hash == o.content_hash && size == o.size; + } +}; + +struct BlobKeyHash { + size_t operator()(const BlobKey &k) const { + size_t h = std::hash<uint64_t>()(k.content_hash); + h ^= std::hash<size_t>()(k.size) + 0x9e3779b9 + (h << 6) + (h >> 2); + return h; + } +}; + +static std::unordered_map<BlobKey, hipModule_t, BlobKeyHash> g_module_cache; +static std::unordered_map<hipModule_t, std::unordered_map<std::string, hipFunction_t>> g_func_cache; +static hipStream_t g_cached_stream = nullptr; +static bool g_stream_initialized = false; + +extern "C" FLIR_EXPORT hipModule_t mgpuModuleLoad(void *data, size_t gpuBlobSize) { + std::lock_guard lock(g_cache_mutex); + uint64_t h = hash_blob(data, gpuBlobSize); + BlobKey key{h, gpuBlobSize}; + auto it = g_module_cache.find(key); + if (it != g_module_cache.end()) { + return it->second; + } hipModule_t module = nullptr; HIP_REPORT_IF_ERROR(hipModuleLoadData(&module, data)); + g_module_cache[key] = module; return module; } @@ -50,13 +108,26 @@ extern "C" FLIR_EXPORT hipModule_t mgpuModuleLoadJIT(void *data, int optLevel) { } extern "C" FLIR_EXPORT void mgpuModuleUnload(hipModule_t module) { - HIP_REPORT_IF_ERROR(hipModuleUnload(module)); + // Don't unload cached modules — they're reused across calls. 
+ (void)module; } extern "C" FLIR_EXPORT hipFunction_t mgpuModuleGetFunction(hipModule_t module, - const char *name) { + const char *name) { + std::lock_guard lock(g_cache_mutex); + auto &func_map = g_func_cache[module]; + std::string key(name); + auto it = func_map.find(key); + if (it != func_map.end()) { + return it->second; + } hipFunction_t function = nullptr; - HIP_REPORT_IF_ERROR(hipModuleGetFunction(&function, module, name)); + hipError_t err = hipModuleGetFunction(&function, module, name); + if (err != hipSuccess) { + fprintf(stderr, "mgpuModuleGetFunction: failed for name='%s' module=%p err=%d\n", + name, (void*)module, (int)err); + } + func_map[key] = function; return function; } @@ -74,13 +145,27 @@ extern "C" FLIR_EXPORT void mgpuLaunchKernel(hipFunction_t function, intptr_t gr } extern "C" FLIR_EXPORT hipStream_t mgpuStreamCreate() { - hipStream_t stream = nullptr; - HIP_REPORT_IF_ERROR(hipStreamCreate(&stream)); - return stream; + std::lock_guard lock(g_cache_mutex); + if (g_stream_initialized) { + return g_cached_stream; + } + g_stream_initialized = true; + HIP_REPORT_IF_ERROR(hipStreamCreate(&g_cached_stream)); + return g_cached_stream; +} + +// Allow Python to set the stream we use for kernel launches. +// Call this with torch.cuda.current_stream().cuda_stream to share PyTorch's stream. +extern "C" FLIR_EXPORT void mgpuSetStream(hipStream_t stream) { + std::lock_guard lock(g_cache_mutex); + g_cached_stream = stream; + g_stream_initialized = true; } extern "C" FLIR_EXPORT void mgpuStreamDestroy(hipStream_t stream) { - HIP_REPORT_IF_ERROR(hipStreamDestroy(stream)); + // Don't destroy the cached stream — it's reused across calls. + // The MLIR lowering emits a destroy after every launch_func. 
+ (void)stream; } extern "C" FLIR_EXPORT void mgpuStreamSynchronize(hipStream_t stream) { diff --git a/flydsl/src/flydsl/compiler/compiler.py b/flydsl/src/flydsl/compiler/compiler.py index 8ca13290..a0ef558e 100644 --- a/flydsl/src/flydsl/compiler/compiler.py +++ b/flydsl/src/flydsl/compiler/compiler.py @@ -16,7 +16,13 @@ from flydsl.runtime.device import get_rocm_arch from .executor import default_shared_libs -from .cache import FileCache, cache_enabled, cache_rebuild_requested, default_key_payload, make_cache_key +from .cache import ( + FileCache, + cache_enabled, + cache_rebuild_requested, + default_key_payload, + make_cache_key, +) if TYPE_CHECKING: from .executor import ExecutionEngineExecutor as Executor @@ -31,6 +37,15 @@ class CompileOptions: backend: Literal["execution_engine"] = "execution_engine" +def _is_wave32_arch(chip: str) -> bool: + """Return True if the GPU architecture uses wavefront size 32 (RDNA).""" + # RDNA architectures (gfx10xx, gfx11xx, gfx12xx) use wave32 natively. + # CDNA architectures (gfx9xx) use wave64. + return ( + chip.startswith("gfx10") or chip.startswith("gfx11") or chip.startswith("gfx12") + ) + + def _pipeline_fragments( *, chip: str, @@ -50,6 +65,7 @@ def _pipeline_fragments( rocdl_bare_ptr_opt = b2s(use_bare_ptr_memref_call_conv) llvm_bare_host_opt = b2s(use_bare_pointers_for_host) llvm_bare_kern_opt = b2s(use_bare_pointers_for_kernels) + wave64_opt = b2s(not _is_wave32_arch(chip)) return [ "flir-to-standard", "trivial-dce", @@ -63,13 +79,13 @@ def _pipeline_fragments( "gpu.module(reconcile-unrealized-casts)", # Keep this as a formatted string so the chip is visible in dumps and matches # the non-dump compilation pipeline. 
- f"rocdl-attach-target{{O=2 abi=600 chip={chip} correct-sqrt=true daz=false fast=false features= finite-only=false module= triple=amdgcn-amd-amdhsa unsafe-math=false wave64=true}}", + f"rocdl-attach-target{{O=2 abi=600 chip={chip} correct-sqrt=true daz=false fast=false features= finite-only=false module= triple=amdgcn-amd-amdhsa unsafe-math=false wave64={wave64_opt}}}", "gpu-to-llvm{intersperse-sizes-for-kernels=false " + f"use-bare-pointers-for-host={llvm_bare_host_opt} " + f"use-bare-pointers-for-kernels={llvm_bare_kern_opt}" + "}", "reconcile-unrealized-casts", - "gpu-module-to-binary{format=fatbin opts= section= toolkit=}", + f"gpu-module-to-binary{{format=fatbin opts={os.environ.get('FLYDSL_LLC_OPTS', '')} section= toolkit=}}", ] @@ -137,7 +153,9 @@ def _dump_ir(stage: str, *, dump_dir: Path, asm: str) -> Path: return out -def _dump_isa_from_rocdl_module_asm(*, dump_dir: Path, ctx: ir.Context, asm: str, verify: bool) -> Optional[Path]: +def _dump_isa_from_rocdl_module_asm( + *, dump_dir: Path, ctx: ir.Context, asm: str, verify: bool +) -> Optional[Path]: """Best-effort dump final ISA/assembly (.s) for the current GPU module. This is only used for debug dumps. It intentionally does not affect the main @@ -151,7 +169,10 @@ def _dump_isa_from_rocdl_module_asm(*, dump_dir: Path, ctx: ir.Context, asm: str try: # Parse a fresh clone so we don't mutate the main compilation module. 
mod = ir.Module.parse(asm, context=ctx) - pm = PassManager.parse("builtin.module(gpu-module-to-binary{format=isa opts= section= toolkit=})", context=ctx) + pm = PassManager.parse( + "builtin.module(gpu-module-to-binary{format=isa opts= section= toolkit=})", # ISA dump always without custom opts + context=ctx, + ) pm.enable_verifier(bool(verify)) pm.run(mod.operation) isa_bytes = get_compile_object_bytes(mod) @@ -239,8 +260,10 @@ def _apply_waves_per_eu_hint(mlir_module, waves_per_eu: int): # Best-effort: if attribute injection fails, log and continue # This prevents breaking existing functionality import warnings + warnings.warn(f"Failed to apply waves_per_eu hint: {e}", RuntimeWarning) + def compile( flir_module_or_ir: Union[object, ir.Module], *, @@ -272,12 +295,16 @@ def compile( if mlir_module is None: mlir_module = flir_module_or_ir if not isinstance(mlir_module, ir.Module): - raise TypeError(f"Expected an MLIR module or flir.lang.MlirModule; got {type(flir_module_or_ir)}") + raise TypeError( + f"Expected an MLIR module or flir.lang.MlirModule; got {type(flir_module_or_ir)}" + ) ctx = mlir_module.context ensure_flir_python_extensions(ctx) - compile_only = _env_truthy("FLYDSL_COMPILE_ONLY", "0") or _env_truthy("COMPILE_ONLY", "0") + compile_only = _env_truthy("FLYDSL_COMPILE_ONLY", "0") or _env_truthy( + "COMPILE_ONLY", "0" + ) dump_enabled = _env_truthy("FLIR_DUMP_IR", "0") dump_root_dir = Path(os.environ.get("FLIR_DUMP_DIR", "my_ir_dumps")).resolve() dump_prefix_base = ( @@ -311,7 +338,9 @@ def compile( module = mlir_module # Allow overriding target arch via env var (useful for cross-compilation or FLYDSL_COMPILE_ONLY mode) - chip = (os.environ.get("FLYDSL_TARGET_ARCH") or os.environ.get("ARCH") or "").strip() or get_rocm_arch() + chip = ( + os.environ.get("FLYDSL_TARGET_ARCH") or os.environ.get("ARCH") or "" + ).strip() or get_rocm_arch() pipeline = _build_pipeline_str( chip=chip, @@ -355,12 +384,17 @@ def compile( print(f"[flir.compile] cache hit 
key={cache_key}") if compile_only: if dump_enabled or print_final_module: - print(f"[flir.compile] FLYDSL_COMPILE_ONLY=1, skipping executor creation (arch={chip})") + print( + f"[flir.compile] FLYDSL_COMPILE_ONLY=1, skipping executor creation (arch={chip})" + ) return None from .executor import ExecutionEngineExecutor as Executor + if shared_libs is None: shared_libs = default_shared_libs().as_list() - return Executor(cached_mod, opt_level=opt_level, shared_libs=shared_libs) + return Executor( + cached_mod, opt_level=opt_level, shared_libs=shared_libs + ) except Exception: # Treat cache parse failures as misses. pass @@ -433,7 +467,9 @@ def compile( # In compile-only mode, skip executor creation and return None if compile_only: if dump_enabled or print_final_module: - print(f"[flir.compile] FLYDSL_COMPILE_ONLY=1, skipping executor creation (arch={chip})") + print( + f"[flir.compile] FLYDSL_COMPILE_ONLY=1, skipping executor creation (arch={chip})" + ) return None from .executor import ExecutionEngineExecutor as Executor @@ -441,5 +477,3 @@ def compile( if shared_libs is None: shared_libs = default_shared_libs().as_list() return Executor(module, opt_level=opt_level, shared_libs=shared_libs) - - diff --git a/flydsl/src/flydsl/dialects/ext/arith.py b/flydsl/src/flydsl/dialects/ext/arith.py index 7c6105ab..7250a559 100644 --- a/flydsl/src/flydsl/dialects/ext/arith.py +++ b/flydsl/src/flydsl/dialects/ext/arith.py @@ -7,36 +7,58 @@ import numpy as np from _mlir.ir import ( - Type, Value, IntegerType, IndexType, F32Type, F64Type, F16Type, - DenseElementsAttr, Location, InsertionPoint, ShapedType, VectorType + Type, + Value, + IntegerType, + IndexType, + F32Type, + F64Type, + F16Type, + BF16Type, + DenseElementsAttr, + Location, + InsertionPoint, + ShapedType, + VectorType, ) from _mlir.dialects import arith as _arith from _mlir.dialects._ods_common import get_op_result_or_op_results from ._loc import maybe_default_loc + def _is_integer_like_type(t: Type) -> bool: """Check 
if type is integer-like (including index).""" return IntegerType.isinstance(t) or IndexType.isinstance(t) + def _is_floating_point_type(t: Type) -> bool: """Check if type is floating point.""" - return F32Type.isinstance(t) or F64Type.isinstance(t) or F16Type.isinstance(t) + return ( + F32Type.isinstance(t) + or F64Type.isinstance(t) + or F16Type.isinstance(t) + or BF16Type.isinstance(t) + ) + def _is_index_type(t: Type) -> bool: """Check if type is index.""" return IndexType.isinstance(t) + def _is_vector_type(t: Type) -> bool: """Check if type is a vector.""" return VectorType.isinstance(t) + def _get_element_type(t: Type) -> Type: """Get element type from vector or return the type itself for scalars.""" if _is_vector_type(t): return VectorType(t).element_type return t + def _infer_mlir_type(value, vector=False): """Infer MLIR type from Python value.""" if isinstance(value, bool): @@ -52,6 +74,7 @@ def _infer_mlir_type(value, vector=False): else: raise ValueError(f"Cannot infer MLIR type from {type(value)}") + def constant( value: Union[int, float, bool], *, @@ -61,14 +84,14 @@ def constant( ip: InsertionPoint = None, ) -> "ArithValue": """Create a constant with type inference. 
- + Args: value: Python value (int, float, bool) type: Optional explicit MLIR type index: If True, create index type constant loc: Location for the operation ip: Insertion point - + Returns: ArithValue wrapping the constant """ @@ -78,7 +101,7 @@ def constant( mlir_type = type else: mlir_type = _infer_mlir_type(value) - + if _is_floating_point_type(mlir_type) and not isinstance(value, float): value = float(value) @@ -126,45 +149,68 @@ def constant( result = _arith.ConstantOp(mlir_type, value, loc=loc, ip=ip).result return ArithValue(result) -def index(value: int, *, loc: Location = None, ip: InsertionPoint = None) -> "ArithValue": + +def index( + value: int, *, loc: Location = None, ip: InsertionPoint = None +) -> "ArithValue": """Create an index constant.""" return constant(value, index=True, loc=loc, ip=ip) + def i32(value: int, *, loc: Location = None, ip: InsertionPoint = None) -> "ArithValue": """Create an i32 constant.""" return constant(value, type=IntegerType.get_signless(32), loc=loc, ip=ip) + def i64(value: int, *, loc: Location = None, ip: InsertionPoint = None) -> "ArithValue": """Create an i64 constant.""" return constant(value, type=IntegerType.get_signless(64), loc=loc, ip=ip) -def f16(value: float, *, loc: Location = None, ip: InsertionPoint = None) -> "ArithValue": + +def f16( + value: float, *, loc: Location = None, ip: InsertionPoint = None +) -> "ArithValue": """Create an f16 constant.""" return constant(value, type=F16Type.get(), loc=loc, ip=ip) -def f16(value: float, *, loc: Location = None, ip: InsertionPoint = None) -> "ArithValue": + +def f16( + value: float, *, loc: Location = None, ip: InsertionPoint = None +) -> "ArithValue": """Create an f16 constant.""" return constant(value, type=F16Type.get(), loc=loc, ip=ip) -def f32(value: float, *, loc: Location = None, ip: InsertionPoint = None) -> "ArithValue": + +def f32( + value: float, *, loc: Location = None, ip: InsertionPoint = None +) -> "ArithValue": """Create an f32 constant.""" return 
constant(value, type=F32Type.get(), loc=loc, ip=ip) -def f64(value: float, *, loc: Location = None, ip: InsertionPoint = None) -> "ArithValue": + +def f64( + value: float, *, loc: Location = None, ip: InsertionPoint = None +) -> "ArithValue": """Create an f64 constant.""" return constant(value, type=F64Type.get(), loc=loc, ip=ip) -def maximum(lhs: Union["ArithValue", Value], rhs: Union["ArithValue", Value], *, loc: Location = None) -> "ArithValue": + +def maximum( + lhs: Union["ArithValue", Value], + rhs: Union["ArithValue", Value], + *, + loc: Location = None, +) -> "ArithValue": """Compute maximum of two values (automatically handles float/int types). - + Args: lhs: Left operand (ArithValue, Value, or Python number) rhs: Right operand (ArithValue, Value, or Python number) loc: Optional source location - + Returns: ArithValue wrapping the maximum result - + Example: >>> a = arith.f32(1.5) >>> b = arith.f32(2.3) @@ -173,17 +219,23 @@ def maximum(lhs: Union["ArithValue", Value], rhs: Union["ArithValue", Value], *, """ return _minmax_op(lhs, rhs, op_type="max", loc=loc) -def minimum(lhs: Union["ArithValue", Value], rhs: Union["ArithValue", Value], *, loc: Location = None) -> "ArithValue": + +def minimum( + lhs: Union["ArithValue", Value], + rhs: Union["ArithValue", Value], + *, + loc: Location = None, +) -> "ArithValue": """Compute minimum of two values (automatically handles float/int types). 
- + Args: lhs: Left operand (ArithValue, Value, or Python number) rhs: Right operand (ArithValue, Value, or Python number) loc: Optional source location - + Returns: ArithValue wrapping the minimum result - + Example: >>> a = arith.f32(1.5) >>> b = arith.f32(2.3) @@ -192,41 +244,58 @@ def minimum(lhs: Union["ArithValue", Value], rhs: Union["ArithValue", Value], *, """ return _minmax_op(lhs, rhs, op_type="min", loc=loc) -def select(condition: Union["ArithValue", Value], true_value: Union["ArithValue", Value], - false_value: Union["ArithValue", Value], *, loc: Location = None) -> "ArithValue": + +def select( + condition: Union["ArithValue", Value], + true_value: Union["ArithValue", Value], + false_value: Union["ArithValue", Value], + *, + loc: Location = None, +) -> "ArithValue": """Select between two values based on a condition (ternary operator). - + Args: condition: Boolean condition (i1 type) true_value: Value to return if condition is true false_value: Value to return if condition is false loc: Optional source location - + Returns: ArithValue wrapping the selected value - + Example: >>> cond = a < b >>> result = arith.select(cond, a, b) # Equivalent to: a if cond else b """ - cond_val = _unwrap_value(condition) if isinstance(condition, ArithValue) else condition - true_val = _unwrap_value(true_value) if isinstance(true_value, ArithValue) else true_value - false_val = _unwrap_value(false_value) if isinstance(false_value, ArithValue) else false_value - + cond_val = ( + _unwrap_value(condition) if isinstance(condition, ArithValue) else condition + ) + true_val = ( + _unwrap_value(true_value) if isinstance(true_value, ArithValue) else true_value + ) + false_val = ( + _unwrap_value(false_value) + if isinstance(false_value, ArithValue) + else false_value + ) + result = _arith.SelectOp(cond_val, true_val, false_val, loc=loc).result return ArithValue(result) -def extf(result_type: Type, value: Union["ArithValue", Value], *, loc: Location = None) -> "ArithValue": + +def 
extf( + result_type: Type, value: Union["ArithValue", Value], *, loc: Location = None +) -> "ArithValue": """Extend floating point value to a wider type (e.g., f16 -> f32). - + Args: result_type: Target floating point type value: Value to extend loc: Optional source location - + Returns: ArithValue wrapping the extended value - + Example: >>> f16_val = ... # some f16 value >>> f32_val = arith.extf(T.vector(32, T.f32()), f16_val) @@ -235,17 +304,20 @@ def extf(result_type: Type, value: Union["ArithValue", Value], *, loc: Location result = _arith.ExtFOp(result_type, val, loc=loc).result return ArithValue(result) -def fptosi(result_type: Type, value: Union["ArithValue", Value], *, loc: Location = None) -> "ArithValue": + +def fptosi( + result_type: Type, value: Union["ArithValue", Value], *, loc: Location = None +) -> "ArithValue": """Convert floating point value to signed integer. - + Args: result_type: Target integer type value: Floating point value to convert loc: Optional source location - + Returns: ArithValue wrapping the integer result - + Example: >>> f32_val = arith.f32(3.7) >>> i32_val = arith.fptosi(T.i32(), f32_val) # Result: 3 @@ -255,7 +327,9 @@ def fptosi(result_type: Type, value: Union["ArithValue", Value], *, loc: Locatio return ArithValue(result) -def sitofp(result_type: Type, value: Union["ArithValue", Value], *, loc: Location = None) -> "ArithValue": +def sitofp( + result_type: Type, value: Union["ArithValue", Value], *, loc: Location = None +) -> "ArithValue": """Convert signed integer value to floating point. 
Args: @@ -274,63 +348,76 @@ def sitofp(result_type: Type, value: Union["ArithValue", Value], *, loc: Locatio result = _arith.SIToFPOp(result_type, val, loc=loc).result return ArithValue(result) -def constant_vector(element_value: Union[int, float], vector_type: Type, *, loc: Location = None) -> "ArithValue": + +def constant_vector( + element_value: Union[int, float], vector_type: Type, *, loc: Location = None +) -> "ArithValue": """Create a constant vector with all elements set to the same value. - + Args: element_value: Scalar value to splat across the vector vector_type: Vector type (e.g., T.vector(32, T.f32())) loc: Optional source location - + Returns: ArithValue wrapping the constant vector - + Example: >>> vec_zero = arith.constant_vector(0.0, T.vector(32, T.f16())) >>> vec_ones = arith.constant_vector(1.0, T.vector(16, T.f32())) """ from _mlir.ir import FloatAttr, IntegerAttr, DenseElementsAttr - + # Get element type from vector type element_type = VectorType(vector_type).element_type - + # Create attribute for the element value if _is_floating_point_type(element_type): elem_attr = FloatAttr.get(element_type, float(element_value)) elif _is_integer_like_type(element_type): elem_attr = IntegerAttr.get(element_type, int(element_value)) else: - raise ValueError(f"Unsupported element type for constant vector: {element_type}") - + raise ValueError( + f"Unsupported element type for constant vector: {element_type}" + ) + # Create dense elements attribute (splat) dense_attr = DenseElementsAttr.get_splat(vector_type, elem_attr) - + result = _arith.ConstantOp(vector_type, dense_attr, loc=loc).result return ArithValue(result) + def absf(value: Union["ArithValue", Value], *, loc: Location = None) -> "ArithValue": """Calculate absolute value (floating point). 
- + Args: value: Input value (float or vector of floats) loc: Optional source location - + Returns: Absolute value result wrapped in ArithValue """ from _mlir.dialects import math as _math + val = _unwrap_value(value) result = _math.AbsFOp(val, loc=loc).result return ArithValue(result) -def andi(lhs: Union["ArithValue", Value, int], rhs: Union["ArithValue", Value, int], *, loc: Location = None) -> "ArithValue": + +def andi( + lhs: Union["ArithValue", Value, int], + rhs: Union["ArithValue", Value, int], + *, + loc: Location = None, +) -> "ArithValue": """Bitwise AND operation on integers. - + Args: lhs: Left operand rhs: Right operand loc: Optional source location - + Returns: ArithValue wrapping the AND result """ @@ -344,14 +431,20 @@ def andi(lhs: Union["ArithValue", Value, int], rhs: Union["ArithValue", Value, i result = _arith.AndIOp(lhs_val, rhs_val, loc=loc).result return ArithValue(result) -def ori(lhs: Union["ArithValue", Value, int], rhs: Union["ArithValue", Value, int], *, loc: Location = None) -> "ArithValue": + +def ori( + lhs: Union["ArithValue", Value, int], + rhs: Union["ArithValue", Value, int], + *, + loc: Location = None, +) -> "ArithValue": """Bitwise OR operation on integers. - + Args: lhs: Left operand rhs: Right operand loc: Optional source location - + Returns: ArithValue wrapping the OR result """ @@ -365,14 +458,20 @@ def ori(lhs: Union["ArithValue", Value, int], rhs: Union["ArithValue", Value, in result = _arith.OrIOp(lhs_val, rhs_val, loc=loc).result return ArithValue(result) -def xori(lhs: Union["ArithValue", Value, int], rhs: Union["ArithValue", Value, int], *, loc: Location = None) -> "ArithValue": + +def xori( + lhs: Union["ArithValue", Value, int], + rhs: Union["ArithValue", Value, int], + *, + loc: Location = None, +) -> "ArithValue": """Bitwise XOR operation on integers. 
- + Args: lhs: Left operand rhs: Right operand loc: Optional source location - + Returns: ArithValue wrapping the XOR result """ @@ -386,14 +485,20 @@ def xori(lhs: Union["ArithValue", Value, int], rhs: Union["ArithValue", Value, i result = _arith.XOrIOp(lhs_val, rhs_val, loc=loc).result return ArithValue(result) -def shrui(lhs: Union["ArithValue", Value, int], rhs: Union["ArithValue", Value, int], *, loc: Location = None) -> "ArithValue": + +def shrui( + lhs: Union["ArithValue", Value, int], + rhs: Union["ArithValue", Value, int], + *, + loc: Location = None, +) -> "ArithValue": """Logical (unsigned) right shift operation on integers. - + Args: lhs: Value to shift rhs: Number of bits to shift loc: Optional source location - + Returns: ArithValue wrapping the shift result """ @@ -407,14 +512,20 @@ def shrui(lhs: Union["ArithValue", Value, int], rhs: Union["ArithValue", Value, result = _arith.ShRUIOp(lhs_val, rhs_val, loc=loc).result return ArithValue(result) -def shli(lhs: Union["ArithValue", Value, int], rhs: Union["ArithValue", Value, int], *, loc: Location = None) -> "ArithValue": + +def shli( + lhs: Union["ArithValue", Value, int], + rhs: Union["ArithValue", Value, int], + *, + loc: Location = None, +) -> "ArithValue": """Left shift operation on integers. - + Args: lhs: Value to shift rhs: Number of bits to shift loc: Optional source location - + Returns: ArithValue wrapping the shift result """ @@ -428,14 +539,17 @@ def shli(lhs: Union["ArithValue", Value, int], rhs: Union["ArithValue", Value, i result = _arith.ShLIOp(lhs_val, rhs_val, loc=loc).result return ArithValue(result) -def index_cast(target_type: Type, value: Union["ArithValue", Value, int], *, loc: Location = None) -> "ArithValue": + +def index_cast( + target_type: Type, value: Union["ArithValue", Value, int], *, loc: Location = None +) -> "ArithValue": """Cast between index and integer types. 
- + Args: target_type: Target type (index or integer type) value: Value to cast loc: Optional source location - + Returns: ArithValue wrapping the cast result """ @@ -446,7 +560,10 @@ def index_cast(target_type: Type, value: Union["ArithValue", Value, int], *, loc result = _arith.IndexCastOp(target_type, val, loc=loc).result return ArithValue(result) -def index_cast_ui(target_type: Type, value: Union["ArithValue", Value, int], *, loc: Location = None) -> "ArithValue": + +def index_cast_ui( + target_type: Type, value: Union["ArithValue", Value, int], *, loc: Location = None +) -> "ArithValue": """Cast between index and unsigned integer types. Args: @@ -465,21 +582,26 @@ def index_cast_ui(target_type: Type, value: Union["ArithValue", Value, int], *, return ArithValue(result) -def bitcast(result_type: Type, value: Union["ArithValue", Value], *, loc: Location = None) -> "ArithValue": +def bitcast( + result_type: Type, value: Union["ArithValue", Value], *, loc: Location = None +) -> "ArithValue": """Reinterpret-cast bits between types of the same width (e.g. f32 <-> i32).""" loc = maybe_default_loc(loc) val = _unwrap_value(value) result = _arith.BitcastOp(result_type, val, loc=loc).result return ArithValue(result) -def trunc_f(target_type: Type, value: Union["ArithValue", Value], *, loc: Location = None) -> "ArithValue": + +def trunc_f( + target_type: Type, value: Union["ArithValue", Value], *, loc: Location = None +) -> "ArithValue": """Truncate floating point value to narrower type (e.g., f32 -> f16). 
- + Args: target_type: Target floating point type value: Value to truncate loc: Optional source location - + Returns: ArithValue wrapping the truncated result """ @@ -488,29 +610,36 @@ def trunc_f(target_type: Type, value: Union["ArithValue", Value], *, loc: Locati result = _arith.TruncFOp(target_type, val, loc=loc).result return ArithValue(result) -def reduce(value: Union["ArithValue", Value], kind: str = "add", *, acc: Optional[Value] = None, loc: Location = None) -> "ArithValue": + +def reduce( + value: Union["ArithValue", Value], + kind: str = "add", + *, + acc: Optional[Value] = None, + loc: Location = None, +) -> "ArithValue": """Perform vector reduction. - + Args: value: Input vector value kind: Reduction kind ("add", "mul", "min", "max", "and", "or", "xor") acc: Optional accumulator loc: Optional source location - + Returns: Reduced scalar value wrapped in ArithValue """ from _mlir.dialects import vector as _vector - + val = _unwrap_value(value) - + # Map string kind to CombiningKind enum kind = kind.lower() val_type = val.type elem_type = _get_element_type(val_type) - + is_float = _is_floating_point_type(elem_type) - + kind_map = { "add": _vector.CombiningKind.ADD, "mul": _vector.CombiningKind.MUL, @@ -518,28 +647,29 @@ def reduce(value: Union["ArithValue", Value], kind: str = "add", *, acc: Optiona "or": _vector.CombiningKind.OR, "xor": _vector.CombiningKind.XOR, } - + if kind in ["min", "max"]: if is_float: kind_map["min"] = _vector.CombiningKind.MINIMUMF kind_map["max"] = _vector.CombiningKind.MAXIMUMF else: - kind_map["min"] = _vector.CombiningKind.MINSI # Default to signed + kind_map["min"] = _vector.CombiningKind.MINSI # Default to signed kind_map["max"] = _vector.CombiningKind.MAXSI - + if kind not in kind_map: raise ValueError(f"Unsupported reduction kind: {kind}") - + combining_kind = kind_map[kind] - + if acc is not None: acc = _unwrap_value(acc) op = _vector.ReductionOp(elem_type, combining_kind, val, acc=acc, loc=loc) else: op = 
_vector.ReductionOp(elem_type, combining_kind, val, loc=loc) - + return ArithValue(op.result) + def _unwrap_value(val): """递归unwrap ArithValue,获取底层的 ir.Value""" while isinstance(val, ArithValue): @@ -573,6 +703,7 @@ def as_value( """Alias for `unwrap`, intended for readability at MLIR builder boundaries.""" return unwrap(val, type=type, index=index, loc=loc) + def _binary_op( lhs: "ArithValue", rhs: "ArithValue", @@ -604,12 +735,12 @@ def _binary_op( rhs = constant(rhs, type=lhs._value.type, loc=loc) else: rhs = ArithValue(rhs) - + # Determine operation suffix based on type # For vectors, check the element type lhs_type = lhs._value.type if isinstance(lhs, ArithValue) else lhs.type element_type = _get_element_type(lhs_type) - + op_name = op.capitalize() if _is_floating_point_type(element_type): op_name += "F" @@ -630,8 +761,10 @@ def _binary_op( else: op_name += "I" else: - raise NotImplementedError(f"Unsupported operand types for {op}: {lhs_type} (element type: {element_type})") - + raise NotImplementedError( + f"Unsupported operand types for {op}: {lhs_type} (element type: {element_type})" + ) + # Get the operation class op_class = getattr(_arith, f"{op_name}Op") @@ -653,11 +786,13 @@ def _binary_op( pass result = op_class(lhs_val, rhs_val, loc=loc).result - + return ArithValue(result) -def _shift_op(lhs: "ArithValue", rhs: "ArithValue", op: str, *, loc: Location = None) -> "ArithValue": +def _shift_op( + lhs: "ArithValue", rhs: "ArithValue", op: str, *, loc: Location = None +) -> "ArithValue": """Shift operation for `ArithValue`. 
Notes: @@ -678,8 +813,14 @@ def _shift_op(lhs: "ArithValue", rhs: "ArithValue", op: str, *, loc: Location = else: lhs = ArithValue(lhs) elif not isinstance(lhs, ArithValue) and not isinstance(rhs, ArithValue): - lhs = constant(lhs, loc=loc) if isinstance(lhs, (int, float)) else ArithValue(lhs) - rhs = constant(rhs, type=lhs._value.type, loc=loc) if isinstance(rhs, (int, float)) else ArithValue(rhs) + lhs = ( + constant(lhs, loc=loc) if isinstance(lhs, (int, float)) else ArithValue(lhs) + ) + rhs = ( + constant(rhs, type=lhs._value.type, loc=loc) + if isinstance(rhs, (int, float)) + else ArithValue(rhs) + ) lhs_val = _unwrap_value(lhs) rhs_val = _unwrap_value(rhs) @@ -687,7 +828,9 @@ def _shift_op(lhs: "ArithValue", rhs: "ArithValue", op: str, *, loc: Location = lhs_type = lhs_val.type element_type = _get_element_type(lhs_type) if not _is_integer_like_type(element_type): - raise NotImplementedError(f"Shift not supported for type: {lhs_type} (element type: {element_type})") + raise NotImplementedError( + f"Shift not supported for type: {lhs_type} (element type: {element_type})" + ) if op == "shl": op_class = _arith.ShLIOp @@ -698,10 +841,14 @@ def _shift_op(lhs: "ArithValue", rhs: "ArithValue", op: str, *, loc: Location = return ArithValue(op_class(lhs_val, rhs_val, loc=loc).result) -def _rbinary_op(rhs: "ArithValue", lhs: "ArithValue", op: str, *, loc: Location = None) -> "ArithValue": + +def _rbinary_op( + rhs: "ArithValue", lhs: "ArithValue", op: str, *, loc: Location = None +) -> "ArithValue": """Reverse binary operation (for right-hand operations).""" return _binary_op(lhs, rhs, op, loc=loc) + def _comparison_op( lhs: "ArithValue", rhs: "ArithValue", @@ -714,11 +861,17 @@ def _comparison_op( # Coerce rhs to ArithValue if needed if not isinstance(rhs, ArithValue): if isinstance(rhs, (int, float)): - rhs = constant(rhs, type=lhs._value.type if isinstance(lhs, ArithValue) else lhs.type, loc=loc) + rhs = constant( + rhs, + type=lhs._value.type if isinstance(lhs, 
ArithValue) else lhs.type, + loc=loc, + ) else: rhs = ArithValue(rhs) - - if _is_floating_point_type(lhs._value.type if isinstance(lhs, ArithValue) else lhs.type): + + if _is_floating_point_type( + lhs._value.type if isinstance(lhs, ArithValue) else lhs.type + ): # Ordered float comparison if predicate in {"eq", "ne"}: pred_name = "O" + predicate.upper() # OEQ, ONE @@ -728,7 +881,9 @@ def _comparison_op( lhs_val = _unwrap_value(lhs) if isinstance(lhs, ArithValue) else lhs rhs_val = _unwrap_value(rhs) if isinstance(rhs, ArithValue) else rhs result = _arith.CmpFOp(pred_attr, lhs_val, rhs_val, loc=loc).result - elif _is_integer_like_type(lhs._value.type if isinstance(lhs, ArithValue) else lhs.type): + elif _is_integer_like_type( + lhs._value.type if isinstance(lhs, ArithValue) else lhs.type + ): # Signed integer comparison if predicate in {"eq", "ne"}: pred_name = predicate # eq, ne (lowercase) @@ -739,8 +894,10 @@ def _comparison_op( rhs_val = _unwrap_value(rhs) if isinstance(rhs, ArithValue) else rhs result = _arith.CmpIOp(pred_attr, lhs_val, rhs_val, loc=loc).result else: - raise NotImplementedError(f"Comparison not supported for type: {lhs._value.type if isinstance(lhs, ArithValue) else lhs.type}") - + raise NotImplementedError( + f"Comparison not supported for type: {lhs._value.type if isinstance(lhs, ArithValue) else lhs.type}" + ) + return ArithValue(result) @@ -762,15 +919,23 @@ def cmpu( """ # Coerce inputs similarly to `_comparison_op` if not isinstance(lhs, ArithValue): - lhs = constant(lhs, loc=loc) if isinstance(lhs, (int, float)) else ArithValue(lhs) + lhs = ( + constant(lhs, loc=loc) if isinstance(lhs, (int, float)) else ArithValue(lhs) + ) if not isinstance(rhs, ArithValue): - rhs = constant(rhs, type=lhs._value.type, loc=loc) if isinstance(rhs, (int, float)) else ArithValue(rhs) + rhs = ( + constant(rhs, type=lhs._value.type, loc=loc) + if isinstance(rhs, (int, float)) + else ArithValue(rhs) + ) lhs_val = _unwrap_value(lhs) rhs_val = 
_unwrap_value(rhs) if not _is_integer_like_type(_get_element_type(lhs_val.type)): - raise NotImplementedError(f"Unsigned compare not supported for type: {lhs_val.type}") + raise NotImplementedError( + f"Unsigned compare not supported for type: {lhs_val.type}" + ) pred_attr = getattr(_arith.CmpIPredicate, predicate) return ArithValue(_arith.CmpIOp(pred_attr, lhs_val, rhs_val, loc=loc).result) @@ -790,6 +955,8 @@ def ugt(lhs, rhs, *, loc: Location = None) -> "ArithValue": def uge(lhs, rhs, *, loc: Location = None) -> "ArithValue": return cmpu(lhs, rhs, "uge", loc=loc) + + def _minmax_op( lhs: "ArithValue", rhs: "ArithValue", @@ -802,14 +969,18 @@ def _minmax_op( # Coerce rhs to ArithValue if needed if not isinstance(rhs, ArithValue): if isinstance(rhs, (int, float)): - rhs = constant(rhs, type=lhs._value.type if isinstance(lhs, ArithValue) else lhs.type, loc=loc) + rhs = constant( + rhs, + type=lhs._value.type if isinstance(lhs, ArithValue) else lhs.type, + loc=loc, + ) else: rhs = ArithValue(rhs) - + # Unwrap values lhs_val = _unwrap_value(lhs) if isinstance(lhs, ArithValue) else lhs rhs_val = _unwrap_value(rhs) if isinstance(rhs, ArithValue) else rhs - + if _is_floating_point_type(lhs_val.type): # Float min/max if op_type == "max": @@ -827,18 +998,19 @@ def _minmax_op( result = op_class(lhs_val, rhs_val, loc=loc).result else: raise NotImplementedError(f"{op_type} not supported for type: {lhs_val.type}") - + return ArithValue(result) + class ArithValue: """Value wrapper with operator overloading for Pythonic arithmetic. - + Allows writing natural Python expressions like: c = a + b # instead of arith.AddIOp(a, b) d = a * 2 # instead of arith.MulIOp(a, constant(2)) e = a < b # instead of arith.CmpIOp(...) """ - + def __init__(self, value: Value): """Wrap an MLIR Value. @@ -848,14 +1020,14 @@ def __init__(self, value: Value): nested wrappers here. 
""" object.__setattr__(self, "_value", _unwrap_value(value)) - + def __getattr__(self, name): """Delegate attribute access to wrapped value.""" return getattr(self._value, name) - + def __repr__(self): return f"ArithValue({self._value})" - + # Arithmetic operators __add__ = partialmethod(_binary_op, op="add") __sub__ = partialmethod(_binary_op, op="sub") @@ -876,7 +1048,7 @@ def __repr__(self): # Min/Max methods max = partialmethod(_minmax_op, op_type="max") min = partialmethod(_minmax_op, op_type="min") - + # Reverse arithmetic operators (for when left operand is Python type) __radd__ = partialmethod(_rbinary_op, op="add") __rsub__ = partialmethod(_rbinary_op, op="sub") @@ -885,7 +1057,6 @@ def __repr__(self): __rfloordiv__ = partialmethod(_rbinary_op, op="div") __rmod__ = partialmethod(_rbinary_op, op="mod") - # Comparison operators __eq__ = partialmethod(_comparison_op, predicate="eq") __ne__ = partialmethod(_comparison_op, predicate="ne") @@ -893,32 +1064,92 @@ def __repr__(self): __le__ = partialmethod(_comparison_op, predicate="le") __gt__ = partialmethod(_comparison_op, predicate="gt") __ge__ = partialmethod(_comparison_op, predicate="ge") - + # Allow unwrapping for MLIR operations @property def value(self) -> Value: """Get the underlying MLIR Value (递归unwrap).""" return _unwrap_value(self) + # Re-export commonly used arith operations from _mlir.dialects.arith import ( - AddIOp, AddFOp, SubIOp, SubFOp, MulIOp, MulFOp, - DivSIOp, DivFOp, RemSIOp, RemFOp, - CmpIOp, CmpFOp, CmpIPredicate, CmpFPredicate, - IndexCastOp, ExtSIOp, TruncIOp, ExtFOp, TruncFOp, - SIToFPOp, FPToSIOp, SelectOp, + AddIOp, + AddFOp, + SubIOp, + SubFOp, + MulIOp, + MulFOp, + DivSIOp, + DivFOp, + RemSIOp, + RemFOp, + CmpIOp, + CmpFOp, + CmpIPredicate, + CmpFPredicate, + IndexCastOp, + ExtSIOp, + TruncIOp, + ExtFOp, + TruncFOp, + SIToFPOp, + FPToSIOp, + SelectOp, ) __all__ = [ - "constant", "unwrap", "as_value", "index", "i32", "i64", "f16", "f32", "f64", "Index", - "maximum", "minimum", 
"select", "extf", "fptosi", "sitofp", "absf", "reduce", "constant_vector", - "andi", "ori", "xori", "shrui", "shli", "index_cast", "index_cast_ui", "trunc_f", "bitcast", + "constant", + "unwrap", + "as_value", + "index", + "i32", + "i64", + "f16", + "f32", + "f64", + "Index", + "maximum", + "minimum", + "select", + "extf", + "fptosi", + "sitofp", + "absf", + "reduce", + "constant_vector", + "andi", + "ori", + "xori", + "shrui", + "shli", + "index_cast", + "index_cast_ui", + "trunc_f", + "bitcast", "ArithValue", - "AddIOp", "AddFOp", "SubIOp", "SubFOp", "MulIOp", "MulFOp", - "DivSIOp", "DivFOp", "RemSIOp", "RemFOp", - "CmpIOp", "CmpFOp", "CmpIPredicate", "CmpFPredicate", - "IndexCastOp", "ExtSIOp", "TruncIOp", "ExtFOp", "TruncFOp", - "SIToFPOp", "FPToSIOp", "SelectOp", + "AddIOp", + "AddFOp", + "SubIOp", + "SubFOp", + "MulIOp", + "MulFOp", + "DivSIOp", + "DivFOp", + "RemSIOp", + "RemFOp", + "CmpIOp", + "CmpFOp", + "CmpIPredicate", + "CmpFPredicate", + "IndexCastOp", + "ExtSIOp", + "TruncIOp", + "ExtFOp", + "TruncFOp", + "SIToFPOp", + "FPToSIOp", + "SelectOp", ] # Alias for convenience diff --git a/flydsl/src/flydsl/dialects/ext/buffer_ops.py b/flydsl/src/flydsl/dialects/ext/buffer_ops.py index 9e840667..4b701080 100644 --- a/flydsl/src/flydsl/dialects/ext/buffer_ops.py +++ b/flydsl/src/flydsl/dialects/ext/buffer_ops.py @@ -1,22 +1,22 @@ """AMD Buffer Load/Store Operations - High-level Python API -This module provides high-level Python wrappers for AMD CDNA3/CDNA4 buffer operations. +This module provides high-level Python wrappers for AMD GPU buffer operations (CDNA and RDNA). Buffer operations use a scalar base pointer and per-thread offsets for efficient memory access. 
Example: >>> from flydsl.dialects.ext import buffer_ops >>> from flydsl.dialects.ext import arith >>> import _mlir.extras.types as T - >>> + >>> >>> # Create buffer resource from memref >>> rsrc = buffer_ops.create_buffer_resource(A) - >>> + >>> >>> # Compute offset >>> offset = row * arith.index(4096) + col - >>> + >>> >>> # Buffer load (4xf32) >>> data = buffer_ops.buffer_load(rsrc, offset, vec_width=4) - >>> + >>> >>> # Buffer store >>> buffer_ops.buffer_store(data, rsrc, offset) """ @@ -27,50 +27,50 @@ from typing import Optional, Union __all__ = [ - 'create_llvm_ptr', - 'get_element_ptr', - 'create_buffer_resource', - 'buffer_load', - 'buffer_store', - 'buffer_load_2d', - 'buffer_store_2d', - 'BufferResourceDescriptor', - 'index_cast_to_i32', - 'i32_mul', - 'i32_add', - 'i32_select', + "create_llvm_ptr", + "get_element_ptr", + "create_buffer_resource", + "buffer_load", + "buffer_store", + "buffer_load_2d", + "buffer_store_2d", + "BufferResourceDescriptor", + "index_cast_to_i32", + "i32_mul", + "i32_add", + "i32_select", ] def create_llvm_ptr(value, address_space: int = 0) -> ir.Value: """Convert an index value to LLVM pointer. 
- + Args: value: Index value (typically from memref.extract_aligned_pointer_as_index) Can be ir.Value or ArithValue wrapper address_space: LLVM address space (0=generic, 3=LDS, 8=buffer descriptor) - + Returns: LLVM pointer value - + Example: >>> ptr_idx = memref.extract_aligned_pointer_as_index(A) >>> ptr = create_llvm_ptr(ptr_idx) """ # Extract actual MLIR value from wrapper value = _unwrap_value(value) - + # Convert index to i64 first (llvm.inttoptr requires signless integer, not index) if isinstance(value.type, ir.IndexType): i64_type = ir.IntegerType.get_signless(64) op = std_arith.IndexCastOp(i64_type, value) value = _unwrap_value(op.result) - + # Use opaque pointer syntax (LLVM 15+) if address_space == 0: - ptr_type = ir.Type.parse('!llvm.ptr') + ptr_type = ir.Type.parse("!llvm.ptr") else: - ptr_type = ir.Type.parse(f'!llvm.ptr<{address_space}>') + ptr_type = ir.Type.parse(f"!llvm.ptr<{address_space}>") return llvm.IntToPtrOp(ptr_type, value).result @@ -135,7 +135,9 @@ def get_element_ptr( offset_val = _unwrap_value(byte_offset) if isinstance(offset_val.type, ir.IndexType): i64_type = ir.IntegerType.get_signless(64) - offset_val = _unwrap_value(std_arith.IndexCastOp(i64_type, offset_val).result) + offset_val = _unwrap_value( + std_arith.IndexCastOp(i64_type, offset_val).result + ) elif isinstance(offset_val.type, ir.IntegerType): # LLVM GEP indices are integer-typed; keep as-is. 
pass @@ -147,11 +149,17 @@ def get_element_ptr( if static_byte_offset != 0: if not isinstance(offset_val.type, ir.IntegerType): - raise TypeError(f"dynamic byte_offset must be integer-typed, got {offset_val.type}") + raise TypeError( + f"dynamic byte_offset must be integer-typed, got {offset_val.type}" + ) static_type = offset_val.type static_attr = ir.IntegerAttr.get(static_type, int(static_byte_offset)) - static_const = _unwrap_value(std_arith.ConstantOp(static_type, static_attr).result) - offset_val = _unwrap_value(std_arith.AddIOp(offset_val, static_const).result) + static_const = _unwrap_value( + std_arith.ConstantOp(static_type, static_attr).result + ) + offset_val = _unwrap_value( + std_arith.AddIOp(offset_val, static_const).result + ) dynamic_indices = [offset_val] raw_constant_indices = [_gep_dynamic_index_sentinel] @@ -167,22 +175,22 @@ def get_element_ptr( def _unwrap_value(value): """Recursively unwrap ArithValue or similar wrappers to get the actual MLIR value. - + flir's ArithValue can be nested (double-wrapped), so we need to unwrap recursively until we get to a real ir.Value (OpResult or BlockArgument). 
""" max_depth = 10 # Safety limit depth = 0 - + while depth < max_depth and not isinstance(value, ir.Value): - if hasattr(value, '_value'): + if hasattr(value, "_value"): value = value._value - elif hasattr(value, 'value'): + elif hasattr(value, "value"): value = value.value else: break depth += 1 - + return value @@ -215,28 +223,30 @@ def _create_i64_constant(value: int) -> ir.Value: class BufferResourceDescriptor: """AMD Buffer Resource Descriptor - + A buffer resource descriptor contains: - base_pointer: Scalar base pointer (wave-uniform, stored in SGPRs) - stride: Stride for structured buffers (typically 0 for contiguous) - num_records: Buffer size in bytes - flags: Data format and access flags - + The descriptor is stored in a special LLVM pointer type (!llvm.ptr<8>) """ - + def __init__(self, rsrc: ir.Value): """Initialize with ROCDL resource descriptor value.""" self.rsrc = rsrc - + @staticmethod - def from_memref(memref_val: ir.Value, - stride: int = 0, - max_size: bool = True, - data_format: str = 'f32', - num_records_bytes: Optional[Union[int, ir.Value]] = None) -> 'BufferResourceDescriptor': + def from_memref( + memref_val: ir.Value, + stride: int = 0, + max_size: bool = True, + data_format: str = "f32", + num_records_bytes: Optional[Union[int, ir.Value]] = None, + ) -> "BufferResourceDescriptor": """Create buffer resource descriptor from memref. - + Args: memref_val: Memref value to create descriptor for stride: Stride in elements (0 for contiguous) @@ -244,25 +254,41 @@ def from_memref(memref_val: ir.Value, num_records_bytes: Override buffer size (in BYTES) used by hardware OOB checking. If provided, this takes precedence over `max_size`. data_format: Data format ('f32', 'f16', 'i32', etc.) 
- + Returns: BufferResourceDescriptor instance - + Example: >>> rsrc = BufferResourceDescriptor.from_memref(A) """ # Extract base pointer as index extract_op = memref.ExtractAlignedPointerAsIndexOp(memref_val) - ptr_idx = extract_op.result if hasattr(extract_op, 'result') else extract_op - + ptr_idx = extract_op.result if hasattr(extract_op, "result") else extract_op + # Convert to LLVM pointer base_ptr = create_llvm_ptr(ptr_idx, address_space=0) - + # Create buffer resource descriptor - flags_val = (7 << 12) | (4 << 15) # data_format=7 (float), num_format=4 (32bit) + # DWORD3 flags differ between GFX9 (CDNA) and GFX10+ (RDNA): + # GFX9: data_format=7 (float) at bits[14:12], num_format=4 (32-bit) at bits[17:15] + # GFX12: FORMAT at bits[21:14], OOB_SELECT at bits[29:28] + # OOB_SELECT must be >= 1 (0 = structured mode, requires stride != 0) + from flydsl.runtime.device import get_rocm_arch as _get_arch + + _arch = _get_arch() + if ( + _arch.startswith("gfx12") + or _arch.startswith("gfx11") + or _arch.startswith("gfx10") + ): + # RDNA: OOB_SELECT=2 (disable OOB check) | FORMAT bit 13 (raw 32-bit) + flags_val = (2 << 28) | (1 << 13) + else: + # CDNA (GFX9xx): data_format=7 (float), num_format=4 (32-bit) + flags_val = (7 << 12) | (4 << 15) flags = _create_i32_constant(flags_val) stride_val = _create_i16_constant(stride) - + def _num_records_from_memref_type() -> Optional[int]: """Best-effort: derive logical buffer size (in bytes) from static memref type.""" try: @@ -316,32 +342,36 @@ def _num_records_from_memref_type() -> Optional[int]: if nbytes > 0xFFFFFFFF: nbytes = 0xFFFFFFFF num_records = _create_i32_constant(int(nbytes)) - + # Create resource descriptor (returns !llvm.ptr<8>) - rsrc_type = ir.Type.parse('!llvm.ptr<8>') - rsrc = rocdl.MakeBufferRsrcOp(rsrc_type, base_ptr, stride_val, num_records, flags).result - + rsrc_type = ir.Type.parse("!llvm.ptr<8>") + rsrc = rocdl.MakeBufferRsrcOp( + rsrc_type, base_ptr, stride_val, num_records, flags + ).result + return 
BufferResourceDescriptor(rsrc) -def create_buffer_resource(memref_val: ir.Value, - stride: int = 0, - max_size: bool = True, - *, - num_records_bytes: Optional[Union[int, ir.Value]] = None) -> ir.Value: +def create_buffer_resource( + memref_val: ir.Value, + stride: int = 0, + max_size: bool = True, + *, + num_records_bytes: Optional[Union[int, ir.Value]] = None, +) -> ir.Value: """Create AMD buffer resource descriptor from memref. - + This is a simplified wrapper around BufferResourceDescriptor.from_memref() that returns the raw ROCDL resource value. - + Args: memref_val: Memref value stride: Buffer stride (0 for contiguous) max_size: Use maximum buffer size - + Returns: ROCDL buffer resource descriptor (!llvm.ptr<8>) - + Example: >>> rsrc = create_buffer_resource(A) >>> data = buffer_load(rsrc, offset) @@ -352,18 +382,20 @@ def create_buffer_resource(memref_val: ir.Value, return desc.rsrc -def buffer_load(rsrc: ir.Value, - offset: ir.Value, - vec_width: int = 4, - dtype = None, - mask: Optional[ir.Value] = None, - cache_modifier: int = 0, - soffset_bytes: Optional[Union[int, ir.Value]] = None) -> ir.Value: +def buffer_load( + rsrc: ir.Value, + offset: ir.Value, + vec_width: int = 4, + dtype=None, + mask: Optional[ir.Value] = None, + cache_modifier: int = 0, + soffset_bytes: Optional[Union[int, ir.Value]] = None, +) -> ir.Value: """AMD buffer load operation. - + Load data from global memory using buffer descriptor and offset. Uses hardware-level bounds checking and vectorization. - + Args: rsrc: Buffer resource descriptor (!llvm.ptr<8>) offset: Offset in elements (i32 type) @@ -374,49 +406,49 @@ def buffer_load(rsrc: ir.Value, soffset_bytes: Optional scalar offset (in BYTES) added by the buffer instruction (soffset). Use this to fold small constant deltas into the instruction instead of emitting extra VGPR address arithmetic. 
- + Returns: Loaded data (scalar or vector depending on vec_width) - + Example: >>> # Load 4xf32 >>> data = buffer_load(rsrc, offset, vec_width=4) - >>> + >>> >>> # Load with mask >>> data = buffer_load(rsrc, offset, vec_width=4, mask=valid) """ # Default dtype to f32 if dtype is None: dtype = ir.F32Type.get() - + # Unwrap offset first offset = _unwrap_value(offset) - + # Convert offset to i32 if needed if not isinstance(offset.type, ir.IntegerType) or offset.type.width != 32: op = std_arith.IndexCastOp(ir.IntegerType.get_signless(32), offset) offset = _unwrap_value(op.result) - + # IMPORTANT: Buffer load offset is in BYTES, not elements! # For vec4xf32, each element is 4 bytes, so multiply offset by 4 element_bytes = dtype.width // 8 bytes_const = _create_i32_constant(element_bytes) op = std_arith.MulIOp(offset, bytes_const) offset = _unwrap_value(op.result) - + # Apply mask by setting invalid offsets to max if mask is not None: mask = _unwrap_value(mask) max_offset = _create_i32_constant(0x7FFFFFFF) op = std_arith.SelectOp(mask, offset, max_offset) offset = _unwrap_value(op.result) - + # Create vector type if vec_width == 1: result_type = dtype else: result_type = ir.VectorType.get([vec_width], dtype) - + # Create instruction offset and aux flags if soffset_bytes is None: soffset = _create_i32_constant(0) @@ -429,41 +461,43 @@ def buffer_load(rsrc: ir.Value, op = std_arith.IndexCastOp(ir.IntegerType.get_signless(32), soffset) soffset = _unwrap_value(op.result) aux_flags = _create_i32_constant(cache_modifier) - + # Emit buffer load load_op = rocdl.RawPtrBufferLoadOp( - result_type, - rsrc, - offset, - soffset, # soffset (scalar byte offset) - aux_flags # aux (cache modifiers) + result_type, + rsrc, + offset, + soffset, # soffset (scalar byte offset) + aux_flags, # aux (cache modifiers) ) - + return load_op.result -def buffer_store(data: ir.Value, - rsrc: ir.Value, - offset: ir.Value, - mask: Optional[ir.Value] = None, - cache_modifier: int = 0, - *, - 
soffset_bytes: Optional[Union[int, ir.Value]] = None, - offset_is_bytes: bool = False): +def buffer_store( + data: ir.Value, + rsrc: ir.Value, + offset: ir.Value, + mask: Optional[ir.Value] = None, + cache_modifier: int = 0, + *, + soffset_bytes: Optional[Union[int, ir.Value]] = None, + offset_is_bytes: bool = False, +): """AMD buffer store operation. - + Store data to global memory using buffer descriptor and offset. - + Args: data: Data to store (scalar or vector) rsrc: Buffer resource descriptor (!llvm.ptr<8>) offset: Offset in elements (i32 type) mask: Optional mask for predicated store (i1 type) cache_modifier: Cache control flags (0 for default) - + Example: >>> buffer_store(data, rsrc, offset) - >>> + >>> >>> # Store with mask >>> buffer_store(data, rsrc, offset, mask=valid) """ @@ -471,19 +505,19 @@ def buffer_store(data: ir.Value, data = _unwrap_value(data) rsrc = _unwrap_value(rsrc) offset = _unwrap_value(offset) - + # Convert offset to i32 if needed if not isinstance(offset.type, ir.IntegerType) or offset.type.width != 32: op = std_arith.IndexCastOp(ir.IntegerType.get_signless(32), offset) offset = _unwrap_value(op.result) - + # IMPORTANT: RawPtrBufferStoreOp offset is in BYTES. # For backward compat, `buffer_store()` accepts element offsets by default # and scales them to bytes. Set `offset_is_bytes=True` to skip scaling. 
if not offset_is_bytes: # Get element size from data type data_type = data.type - if hasattr(data_type, 'element_type'): # Vector type + if hasattr(data_type, "element_type"): # Vector type element_type = data_type.element_type else: # Scalar type element_type = data_type @@ -491,14 +525,14 @@ def buffer_store(data: ir.Value, bytes_const = _create_i32_constant(element_bytes) op = std_arith.MulIOp(offset, bytes_const) offset = _unwrap_value(op.result) - + # Apply mask by setting invalid offsets to max if mask is not None: mask = _unwrap_value(mask) max_offset = _create_i32_constant(0x7FFFFFFF) op = std_arith.SelectOp(mask, offset, max_offset) offset = _unwrap_value(op.result) - + # Create instruction offset (soffset) and aux flags if soffset_bytes is None: soffset = _create_i32_constant(0) @@ -511,46 +545,55 @@ def buffer_store(data: ir.Value, op = std_arith.IndexCastOp(ir.IntegerType.get_signless(32), soffset) soffset = _unwrap_value(op.result) aux_flags = _create_i32_constant(cache_modifier) - + # Emit buffer store rocdl.RawPtrBufferStoreOp( data, rsrc, offset, - soffset, # soffset (scalar byte offset) - aux_flags # aux (cache modifiers) + soffset, # soffset (scalar byte offset) + aux_flags, # aux (cache modifiers) ) # Convenience functions for common patterns -def buffer_load_f32x4(rsrc: ir.Value, offset: ir.Value, mask: Optional[ir.Value] = None) -> ir.Value: + +def buffer_load_f32x4( + rsrc: ir.Value, offset: ir.Value, mask: Optional[ir.Value] = None +) -> ir.Value: """Load vector<4xf32> using buffer operation.""" return buffer_load(rsrc, offset, vec_width=4, dtype=ir.F32Type.get(), mask=mask) -def buffer_load_f16x4(rsrc: ir.Value, offset: ir.Value, mask: Optional[ir.Value] = None) -> ir.Value: +def buffer_load_f16x4( + rsrc: ir.Value, offset: ir.Value, mask: Optional[ir.Value] = None +) -> ir.Value: """Load vector<4xf16> using buffer operation (stored as 2xi32).""" # For f16, we load 4 elements but they're packed into 2xi32 - i32_data = buffer_load(rsrc, 
offset, vec_width=2, dtype=ir.IntegerType.get_signless(32), mask=mask) + i32_data = buffer_load( + rsrc, offset, vec_width=2, dtype=ir.IntegerType.get_signless(32), mask=mask + ) # TODO: Add bitcast to 4xf16 if needed return i32_data -def buffer_store_f32x4(data: ir.Value, rsrc: ir.Value, offset: ir.Value, mask: Optional[ir.Value] = None): +def buffer_store_f32x4( + data: ir.Value, rsrc: ir.Value, offset: ir.Value, mask: Optional[ir.Value] = None +): """Store vector<4xf32> using buffer operation.""" buffer_store(data, rsrc, offset, mask=mask) def index_cast_to_i32(value) -> ir.Value: """Cast index value to i32. - + Args: value: Index value (can be ArithValue or ir.Value) - + Returns: i32 value - + Example: >>> row_i32 = index_cast_to_i32(row_index) """ @@ -562,10 +605,10 @@ def index_cast_to_i32(value) -> ir.Value: def i32_mul(lhs, rhs) -> ir.Value: """Multiply two i32 values. - + Args: lhs, rhs: i32 values (will auto-unwrap if needed) - + Returns: i32 product """ @@ -577,10 +620,10 @@ def i32_mul(lhs, rhs) -> ir.Value: def i32_add(lhs, rhs) -> ir.Value: """Add two i32 values. - + Args: lhs, rhs: i32 values (will auto-unwrap if needed) - + Returns: i32 sum """ @@ -592,12 +635,12 @@ def i32_add(lhs, rhs) -> ir.Value: def i32_select(cond, true_val, false_val) -> ir.Value: """Select between two i32 values based on condition. - + Args: cond: i1 condition (will auto-unwrap if needed) true_val: Value if cond is true false_val: Value if cond is false - + Returns: Selected value """ @@ -608,9 +651,11 @@ def i32_select(cond, true_val, false_val) -> ir.Value: return _unwrap_value(op.result) -def buffer_load_2d(rsrc, row, col, stride, vec_width=4, dtype=None, mask=None) -> ir.Value: +def buffer_load_2d( + rsrc, row, col, stride, vec_width=4, dtype=None, mask=None +) -> ir.Value: """High-level 2D buffer load with automatic offset calculation. 
- + Args: rsrc: Buffer resource descriptor row: Row index (index or ArithValue) @@ -619,10 +664,10 @@ def buffer_load_2d(rsrc, row, col, stride, vec_width=4, dtype=None, mask=None) - vec_width: Vector width (1, 2, or 4) dtype: Element data type (defaults to f32) mask: Optional mask for predicated load - + Returns: Loaded data (scalar or vector) - + Example: >>> rsrc = create_buffer_resource(A) >>> data = buffer_load_2d(rsrc, row, col, N, vec_width=4) @@ -631,15 +676,15 @@ def buffer_load_2d(rsrc, row, col, stride, vec_width=4, dtype=None, mask=None) - row_i32 = index_cast_to_i32(row) col_i32 = index_cast_to_i32(col) stride_i32 = index_cast_to_i32(stride) - + offset = i32_add(i32_mul(row_i32, stride_i32), col_i32) - + return buffer_load(rsrc, offset, vec_width, dtype, mask) def buffer_store_2d(data, rsrc, row, col, stride, mask=None): """High-level 2D buffer store with automatic offset calculation. - + Args: data: Data to store (scalar or vector) rsrc: Buffer resource descriptor @@ -647,7 +692,7 @@ def buffer_store_2d(data, rsrc, row, col, stride, mask=None): col: Column index (index or ArithValue) stride: Row stride (index or ArithValue) mask: Optional mask for predicated store - + Example: >>> rsrc = create_buffer_resource(B) >>> buffer_store_2d(data, rsrc, row, col, M) @@ -656,8 +701,7 @@ def buffer_store_2d(data, rsrc, row, col, stride, mask=None): row_i32 = index_cast_to_i32(row) col_i32 = index_cast_to_i32(col) stride_i32 = index_cast_to_i32(stride) - + offset = i32_add(i32_mul(row_i32, stride_i32), col_i32) - - buffer_store(data, rsrc, offset, mask) + buffer_store(data, rsrc, offset, mask) diff --git a/flydsl/src/flydsl/dialects/ext/rocdl.py b/flydsl/src/flydsl/dialects/ext/rocdl.py index be82474d..8e3df74e 100644 --- a/flydsl/src/flydsl/dialects/ext/rocdl.py +++ b/flydsl/src/flydsl/dialects/ext/rocdl.py @@ -21,26 +21,45 @@ _ods_mfma_f32_16x16x16bf16_1k = globals().get("mfma_f32_16x16x16bf16_1k", None) _ods_mfma_f32_16x16x32_fp8_fp8 = 
mfma_f32_16x16x32_fp8_fp8 _ods_mfma_i32_16x16x32_i8 = mfma_i32_16x16x32_i8 -_ods_mfma_scale_f32_16x16x128_f8f6f4 = ( - globals().get("mfma_scale_f32_16x16x128_f8f6f4", None) - or globals().get("mfma_scale_f32_16x16x128_f8f6f4_", None) -) +_ods_mfma_scale_f32_16x16x128_f8f6f4 = globals().get( + "mfma_scale_f32_16x16x128_f8f6f4", None +) or globals().get("mfma_scale_f32_16x16x128_f8f6f4_", None) _ods_readlane = readlane _ods_readfirstlane = readfirstlane _ods_ds_swizzle = ds_swizzle _ods_raw_ptr_buffer_atomic_fadd = raw_ptr_buffer_atomic_fadd +# Keep ODS references for WMMA ops so we can wrap them. +_ods_wmma_f32_16x16x16_f16 = wmma_f32_16x16x16_f16 +_ods_wmma_f32_16x16x16_bf16 = wmma_f32_16x16x16_bf16 +_ods_wmma_f16_16x16x16_f16 = wmma_f16_16x16x16_f16 +_ods_wmma_bf16_16x16x16_bf16 = wmma_bf16_16x16x16_bf16 +_ods_wmma_i32_16x16x16_iu8 = wmma_i32_16x16x16_iu8 +_ods_wmma_i32_16x16x16_iu4 = wmma_i32_16x16x16_iu4 +_ods_wmma_f32_16x16x16_fp8_fp8 = globals().get("wmma_f32_16x16x16_fp8_fp8", None) +_ods_wmma_f32_16x16x16_fp8_bf8 = globals().get("wmma_f32_16x16x16_fp8_bf8", None) +_ods_wmma_f32_16x16x16_bf8_fp8 = globals().get("wmma_f32_16x16x16_bf8_fp8", None) +_ods_wmma_f32_16x16x16_bf8_bf8 = globals().get("wmma_f32_16x16x16_bf8_bf8", None) +_ods_wmma_i32_16x16x32_iu4 = globals().get("wmma_i32_16x16x32_iu4", None) + mask_mfma = 0x008 mask_vmem_rd = 0x020 mask_dsrd = 0x100 mask_dswr = 0x200 + def sched_mfma(cnt): sched_group_barrier(mask_mfma, cnt, 0) + + def sched_vmem(cnt): sched_group_barrier(mask_vmem_rd, cnt, 0) + + def sched_dsrd(cnt): sched_group_barrier(mask_dsrd, cnt, 0) + + def sched_dswr(cnt): sched_group_barrier(mask_dswr, cnt, 0) @@ -68,6 +87,7 @@ def mfma_f32_16x16x16f16(result_type, operands, *, loc=None, ip=None): """Return the op result directly (no `.result` needed at call sites).""" return mfma_f32_16x16x16f16_op(result_type, operands, loc=loc, ip=ip).result + # for bf16 version mfma def mfma_f32_16x16x16bf16_1k_op(result_type, operands, *, loc=None, 
ip=None): """Return the op view (original behavior).""" @@ -114,14 +134,173 @@ def mfma_scale_f32_16x16x128_f8f6f4_op(result_type, operands, *, loc=None, ip=No def mfma_scale_f32_16x16x128_f8f6f4(result_type, operands, *, loc=None, ip=None): """Return the op result directly (no `.result` needed at call sites).""" - return mfma_scale_f32_16x16x128_f8f6f4_op(result_type, operands, loc=loc, ip=ip).result + return mfma_scale_f32_16x16x128_f8f6f4_op( + result_type, operands, loc=loc, ip=ip + ).result + + +# --------------------------------------------------------------------------- +# WMMA wrappers (Wave Matrix Multiply-Accumulate – RDNA3/RDNA4) +# +# WMMA operands are [A, B, C] – all MLIR Values, no integer flags. +# For IU variants the operand list is [A_sign, A, B_sign, B, C, clamp]. +# For OPSEL variants (f16->f16, bf16->bf16) the list is [A, B, C, op_sel]. +# --------------------------------------------------------------------------- + + +def _unwrap_wmma_operand(v, *, loc=None): + """Accept Python ints (for flags like op_sel/clamp/signed) and ArithValue wrappers.""" + from _mlir.ir import IntegerType + from . import arith as _arith_ext + + if isinstance(v, bool): + return _arith_ext.constant( + int(v), type=IntegerType.get_signless(1), loc=loc + )._value + if isinstance(v, int): + return _arith_ext.constant(v, type=IntegerType.get_signless(32), loc=loc)._value + return _arith_ext.unwrap(v, loc=loc) + + +# --- f32 output variants --- + + +def wmma_f32_16x16x16_f16_op(result_type, operands, *, loc=None, ip=None): + ops = [_unwrap_wmma_operand(v, loc=loc) for v in operands] + return _ods_wmma_f32_16x16x16_f16(result_type, ops, loc=loc, ip=ip) + + +def wmma_f32_16x16x16_f16(result_type, operands, *, loc=None, ip=None): + """WMMA f16->f32, 16x16x16. Operands: [A, B, C]. 
Returns Value.""" + return wmma_f32_16x16x16_f16_op(result_type, operands, loc=loc, ip=ip).result + + +def wmma_f32_16x16x16_bf16_op(result_type, operands, *, loc=None, ip=None): + ops = [_unwrap_wmma_operand(v, loc=loc) for v in operands] + return _ods_wmma_f32_16x16x16_bf16(result_type, ops, loc=loc, ip=ip) + + +def wmma_f32_16x16x16_bf16(result_type, operands, *, loc=None, ip=None): + """WMMA bf16->f32, 16x16x16. Operands: [A, B, C]. Returns Value.""" + return wmma_f32_16x16x16_bf16_op(result_type, operands, loc=loc, ip=ip).result + + +# --- fp8 variants (gfx12 / RDNA4 only) --- + + +def wmma_f32_16x16x16_fp8_fp8_op(result_type, operands, *, loc=None, ip=None): + if _ods_wmma_f32_16x16x16_fp8_fp8 is None: + raise AttributeError("ROCDL op not found: wmma_f32_16x16x16_fp8_fp8") + ops = [_unwrap_wmma_operand(v, loc=loc) for v in operands] + return _ods_wmma_f32_16x16x16_fp8_fp8(result_type, ops, loc=loc, ip=ip) + + +def wmma_f32_16x16x16_fp8_fp8(result_type, operands, *, loc=None, ip=None): + """WMMA fp8->f32, 16x16x16 (gfx12). Operands: [A, B, C]. Returns Value.""" + return wmma_f32_16x16x16_fp8_fp8_op(result_type, operands, loc=loc, ip=ip).result + + +def wmma_f32_16x16x16_fp8_bf8_op(result_type, operands, *, loc=None, ip=None): + if _ods_wmma_f32_16x16x16_fp8_bf8 is None: + raise AttributeError("ROCDL op not found: wmma_f32_16x16x16_fp8_bf8") + ops = [_unwrap_wmma_operand(v, loc=loc) for v in operands] + return _ods_wmma_f32_16x16x16_fp8_bf8(result_type, ops, loc=loc, ip=ip) + + +def wmma_f32_16x16x16_fp8_bf8(result_type, operands, *, loc=None, ip=None): + """WMMA fp8+bf8->f32, 16x16x16 (gfx12). Operands: [A, B, C]. 
Returns Value.""" + return wmma_f32_16x16x16_fp8_bf8_op(result_type, operands, loc=loc, ip=ip).result + + +def wmma_f32_16x16x16_bf8_fp8_op(result_type, operands, *, loc=None, ip=None): + if _ods_wmma_f32_16x16x16_bf8_fp8 is None: + raise AttributeError("ROCDL op not found: wmma_f32_16x16x16_bf8_fp8") + ops = [_unwrap_wmma_operand(v, loc=loc) for v in operands] + return _ods_wmma_f32_16x16x16_bf8_fp8(result_type, ops, loc=loc, ip=ip) + + +def wmma_f32_16x16x16_bf8_fp8(result_type, operands, *, loc=None, ip=None): + """WMMA bf8+fp8->f32, 16x16x16 (gfx12). Operands: [A, B, C]. Returns Value.""" + return wmma_f32_16x16x16_bf8_fp8_op(result_type, operands, loc=loc, ip=ip).result + + +def wmma_f32_16x16x16_bf8_bf8_op(result_type, operands, *, loc=None, ip=None): + if _ods_wmma_f32_16x16x16_bf8_bf8 is None: + raise AttributeError("ROCDL op not found: wmma_f32_16x16x16_bf8_bf8") + ops = [_unwrap_wmma_operand(v, loc=loc) for v in operands] + return _ods_wmma_f32_16x16x16_bf8_bf8(result_type, ops, loc=loc, ip=ip) + + +def wmma_f32_16x16x16_bf8_bf8(result_type, operands, *, loc=None, ip=None): + """WMMA bf8->f32, 16x16x16 (gfx12). Operands: [A, B, C]. Returns Value.""" + return wmma_f32_16x16x16_bf8_bf8_op(result_type, operands, loc=loc, ip=ip).result + + +# --- f16/bf16 output variants (OPSEL: operands include op_sel flag) --- + + +def wmma_f16_16x16x16_f16_op(result_type, operands, *, loc=None, ip=None): + ops = [_unwrap_wmma_operand(v, loc=loc) for v in operands] + return _ods_wmma_f16_16x16x16_f16(result_type, ops, loc=loc, ip=ip) + + +def wmma_f16_16x16x16_f16(result_type, operands, *, loc=None, ip=None): + """WMMA f16->f16, 16x16x16. Operands: [A, B, C, op_sel]. 
Returns Value.""" + return wmma_f16_16x16x16_f16_op(result_type, operands, loc=loc, ip=ip).result + + +def wmma_bf16_16x16x16_bf16_op(result_type, operands, *, loc=None, ip=None): + ops = [_unwrap_wmma_operand(v, loc=loc) for v in operands] + return _ods_wmma_bf16_16x16x16_bf16(result_type, ops, loc=loc, ip=ip) + + +def wmma_bf16_16x16x16_bf16(result_type, operands, *, loc=None, ip=None): + """WMMA bf16->bf16, 16x16x16. Operands: [A, B, C, op_sel]. Returns Value.""" + return wmma_bf16_16x16x16_bf16_op(result_type, operands, loc=loc, ip=ip).result + + +# --- Integer variants (IU: operands include sign flags and clamp) --- + + +def wmma_i32_16x16x16_iu8_op(result_type, operands, *, loc=None, ip=None): + ops = [_unwrap_wmma_operand(v, loc=loc) for v in operands] + return _ods_wmma_i32_16x16x16_iu8(result_type, ops, loc=loc, ip=ip) + + +def wmma_i32_16x16x16_iu8(result_type, operands, *, loc=None, ip=None): + """WMMA int8->i32, 16x16x16. Operands: [A_sign, A, B_sign, B, C, clamp]. Returns Value.""" + return wmma_i32_16x16x16_iu8_op(result_type, operands, loc=loc, ip=ip).result + + +def wmma_i32_16x16x16_iu4_op(result_type, operands, *, loc=None, ip=None): + ops = [_unwrap_wmma_operand(v, loc=loc) for v in operands] + return _ods_wmma_i32_16x16x16_iu4(result_type, ops, loc=loc, ip=ip) + + +def wmma_i32_16x16x16_iu4(result_type, operands, *, loc=None, ip=None): + """WMMA int4->i32, 16x16x16. Operands: [A_sign, A, B_sign, B, C, clamp]. Returns Value.""" + return wmma_i32_16x16x16_iu4_op(result_type, operands, loc=loc, ip=ip).result + + +def wmma_i32_16x16x32_iu4_op(result_type, operands, *, loc=None, ip=None): + if _ods_wmma_i32_16x16x32_iu4 is None: + raise AttributeError("ROCDL op not found: wmma_i32_16x16x32_iu4") + ops = [_unwrap_wmma_operand(v, loc=loc) for v in operands] + return _ods_wmma_i32_16x16x32_iu4(result_type, ops, loc=loc, ip=ip) + + +def wmma_i32_16x16x32_iu4(result_type, operands, *, loc=None, ip=None): + """WMMA int4->i32, 16x16x32 (gfx12). 
Operands: [A_sign, A, B_sign, B, C, clamp]. Returns Value.""" + return wmma_i32_16x16x32_iu4_op(result_type, operands, loc=loc, ip=ip).result def readlane(result_type, src, lane_id, *, loc=None, ip=None): """Lane read that accepts ArithValue / wrappers.""" from . import arith as _arith_ext - return _ods_readlane(result_type, _arith_ext.unwrap(src), _arith_ext.unwrap(lane_id), loc=loc, ip=ip) + return _ods_readlane( + result_type, _arith_ext.unwrap(src), _arith_ext.unwrap(lane_id), loc=loc, ip=ip + ) def readfirstlane(result_type, src, *, loc=None, ip=None): @@ -135,10 +314,14 @@ def ds_swizzle(result_type, src, offset, *, loc=None, ip=None): """DS swizzle that accepts ArithValue / wrappers.""" from . import arith as _arith_ext - return _ods_ds_swizzle(result_type, _arith_ext.unwrap(src), _arith_ext.unwrap(offset), loc=loc, ip=ip) + return _ods_ds_swizzle( + result_type, _arith_ext.unwrap(src), _arith_ext.unwrap(offset), loc=loc, ip=ip + ) -def raw_ptr_buffer_atomic_fadd(val, rsrc, voffset, soffset, cache, *, loc=None, ip=None): +def raw_ptr_buffer_atomic_fadd( + val, rsrc, voffset, soffset, cache, *, loc=None, ip=None +): """Atomic fadd that accepts `ArithValue` / wrappers (no explicit `arith.unwrap(...)` needed). 
Signature intentionally matches the underlying ODS builder: @@ -163,65 +346,110 @@ def raw_ptr_buffer_atomic_fadd(val, rsrc, voffset, soffset, cache, *, loc=None, __all__ = [ # Thread/Block/Grid IDs and dimensions - 'workitem_id_x', 'workitem_id_y', 'workitem_id_z', - 'workgroup_id_x', 'workgroup_id_y', 'workgroup_id_z', - 'workgroup_dim_x', 'workgroup_dim_y', 'workgroup_dim_z', - 'grid_dim_x', 'grid_dim_y', 'grid_dim_z', - 'wavefrontsize', - + "workitem_id_x", + "workitem_id_y", + "workitem_id_z", + "workgroup_id_x", + "workgroup_id_y", + "workgroup_id_z", + "workgroup_dim_x", + "workgroup_dim_y", + "workgroup_dim_z", + "grid_dim_x", + "grid_dim_y", + "grid_dim_z", + "wavefrontsize", # Synchronization - 'barrier', 's_barrier', 's_barrier_signal', 's_barrier_wait', - 's_waitcnt', 's_wait_loadcnt', 's_wait_storecnt', - 's_wait_dscnt', 's_wait_expcnt', - + "barrier", + "s_barrier", + "s_barrier_signal", + "s_barrier_wait", + "s_waitcnt", + "s_wait_loadcnt", + "s_wait_storecnt", + "s_wait_dscnt", + "s_wait_expcnt", # Matrix operations - MFMA (Matrix Fused Multiply-Add) - 'mfma_f32_32x32x8f16', 'mfma_f32_16x16x16f16', - 'mfma_f32_16x16x16bf16_1k', - 'mfma_f32_32x32x4bf16', 'mfma_f32_16x16x8bf16', - 'mfma_i32_32x32x8i8', 'mfma_i32_16x16x16i8', - 'mfma_i32_16x16x32_i8', - 'mfma_scale_f32_16x16x128_f8f6f4', + "mfma_f32_32x32x8f16", + "mfma_f32_16x16x16f16", + "mfma_f32_16x16x16bf16_1k", + "mfma_f32_32x32x4bf16", + "mfma_f32_16x16x8bf16", + "mfma_i32_32x32x8i8", + "mfma_i32_16x16x16i8", + "mfma_i32_16x16x32_i8", + "mfma_scale_f32_16x16x128_f8f6f4", # Raw-op constructors (return op view) for the above - 'mfma_f32_16x16x16f16_op', 'mfma_f32_16x16x32_fp8_fp8_op', - 'mfma_f32_16x16x16bf16_1k_op', - 'mfma_i32_16x16x32_i8_op', - 'mfma_scale_f32_16x16x128_f8f6f4_op', - + "mfma_f32_16x16x16f16_op", + "mfma_f32_16x16x32_fp8_fp8_op", + "mfma_f32_16x16x16bf16_1k_op", + "mfma_i32_16x16x32_i8_op", + "mfma_scale_f32_16x16x128_f8f6f4_op", # Matrix operations - WMMA (Wave Matrix 
Multiply-Accumulate) - 'wmma_f32_16x16x16_f16', 'wmma_f32_16x16x16_bf16', - 'wmma_i32_16x16x16_iu8', - + "wmma_f32_16x16x16_f16", + "wmma_f32_16x16x16_bf16", + "wmma_f16_16x16x16_f16", + "wmma_bf16_16x16x16_bf16", + "wmma_i32_16x16x16_iu8", + "wmma_i32_16x16x16_iu4", + "wmma_f32_16x16x16_fp8_fp8", + "wmma_f32_16x16x16_fp8_bf8", + "wmma_f32_16x16x16_bf8_fp8", + "wmma_f32_16x16x16_bf8_bf8", + "wmma_i32_16x16x32_iu4", + # Raw-op constructors (return op view) for WMMA + "wmma_f32_16x16x16_f16_op", + "wmma_f32_16x16x16_bf16_op", + "wmma_f16_16x16x16_f16_op", + "wmma_bf16_16x16x16_bf16_op", + "wmma_i32_16x16x16_iu8_op", + "wmma_i32_16x16x16_iu4_op", + "wmma_f32_16x16x16_fp8_fp8_op", + "wmma_f32_16x16x16_fp8_bf8_op", + "wmma_f32_16x16x16_bf8_fp8_op", + "wmma_f32_16x16x16_bf8_bf8_op", + "wmma_i32_16x16x32_iu4_op", # Matrix operations - SMFMAC (Sparse Matrix FMA) - 'smfmac_f32_32x32x16_f16', 'smfmac_f32_32x32x16_bf16', - 'smfmac_i32_32x32x32_i8', - + "smfmac_f32_32x32x16_f16", + "smfmac_f32_32x32x16_bf16", + "smfmac_i32_32x32x32_i8", # Shuffle and permutation - 'ds_swizzle', 'ds_bpermute', - 'permlanex16', 'permlane16_swap', 'permlane32_swap', - 'readlane', 'readfirstlane', - 'update_dpp', - 'ballot', - + "ds_swizzle", + "ds_bpermute", + "permlanex16", + "permlane16_swap", + "permlane32_swap", + "readlane", + "readfirstlane", + "update_dpp", + "ballot", # Data movement - 'raw_buffer_load', 'raw_buffer_store', - 'raw_ptr_buffer_load', 'raw_ptr_buffer_store', - 'load_to_lds', 'global_load_lds', - 'make_buffer_rsrc', - + "raw_buffer_load", + "raw_buffer_store", + "raw_ptr_buffer_load", + "raw_ptr_buffer_store", + "load_to_lds", + "global_load_lds", + "make_buffer_rsrc", # Atomic operations - 'raw_buffer_atomic_fadd', 'raw_buffer_atomic_fmax', - 'raw_buffer_atomic_smax', 'raw_buffer_atomic_umin', - 'raw_ptr_buffer_atomic_fadd', 'raw_ptr_buffer_atomic_fmax', - + "raw_buffer_atomic_fadd", + "raw_buffer_atomic_fmax", + "raw_buffer_atomic_smax", + "raw_buffer_atomic_umin", + 
"raw_ptr_buffer_atomic_fadd", + "raw_ptr_buffer_atomic_fmax", # Bit manipulation - 'mbcnt_lo', 'mbcnt_hi', - + "mbcnt_lo", + "mbcnt_hi", # Scheduling and optimization - 's_setprio', 's_sleep', - 'sched_barrier', 'sched_group_barrier', - 'iglp_opt', - + "s_setprio", + "s_sleep", + "sched_barrier", + "sched_group_barrier", + "iglp_opt", # Type conversions - 'cvt_f32_bf8', 'cvt_f32_fp8', - 'cvt_pk_f32_bf8', 'cvt_pk_f32_fp8', + "cvt_f32_bf8", + "cvt_f32_fp8", + "cvt_pk_f32_bf8", + "cvt_pk_f32_fp8", ] diff --git a/flydsl/src/flydsl/lang/ir/types.py b/flydsl/src/flydsl/lang/ir/types.py index e07f4e8a..701c0f31 100644 --- a/flydsl/src/flydsl/lang/ir/types.py +++ b/flydsl/src/flydsl/lang/ir/types.py @@ -8,11 +8,11 @@ from flydsl.runtime.device import get_rocm_arch - def _flir_default_f8_type() -> ir.Type: """Select E4M3 f8 type compatible with the current GPU arch. - gfx95* (MI350): FP8 E4M3FN (OCP) + - gfx12* (RDNA4): FP8 E4M3FN (OCP) - gfx94* (MI300): FP8 E4M3FNUZ """ arch = "" @@ -20,7 +20,7 @@ def _flir_default_f8_type() -> ir.Type: arch = str(get_rocm_arch()) except Exception: arch = "" - if "gfx95" in arch: + if "gfx95" in arch or "gfx12" in arch: return ir.Float8E4M3FNType.get() return ir.Float8E4M3FNUZType.get() @@ -34,15 +34,19 @@ def index(self) -> ir.Type: @property def i8(self) -> ir.Type: return ir.IntegerType.get_signless(8) + @property def i8x2(self) -> ir.Type: return ir.VectorType.get([2], ir.IntegerType.get_signless(8)) + @property def i8x4(self) -> ir.Type: return ir.VectorType.get([4], ir.IntegerType.get_signless(8)) + @property def i8x8(self) -> ir.Type: return ir.VectorType.get([8], ir.IntegerType.get_signless(8)) + @property def i8x16(self) -> ir.Type: return ir.VectorType.get([16], ir.IntegerType.get_signless(8)) @@ -50,15 +54,19 @@ def i8x16(self) -> ir.Type: @property def ui8(self) -> ir.Type: return ir.IntegerType.get_unsigned(8) + @property def ui8x2(self) -> ir.Type: return ir.VectorType.get([2], ir.IntegerType.get_unsigned(8)) + @property def 
ui8x4(self) -> ir.Type: return ir.VectorType.get([4], ir.IntegerType.get_unsigned(8)) + @property def ui8x8(self) -> ir.Type: return ir.VectorType.get([8], ir.IntegerType.get_unsigned(8)) + @property def ui8x16(self) -> ir.Type: return ir.VectorType.get([16], ir.IntegerType.get_unsigned(8)) @@ -66,12 +74,15 @@ def ui8x16(self) -> ir.Type: @property def i16(self) -> ir.Type: return ir.IntegerType.get_signless(16) + @property def i16x2(self) -> ir.Type: return ir.VectorType.get([2], ir.IntegerType.get_signless(16)) + @property def i16x4(self) -> ir.Type: return ir.VectorType.get([4], ir.IntegerType.get_signless(16)) + @property def i16x8(self) -> ir.Type: return ir.VectorType.get([8], ir.IntegerType.get_signless(16)) @@ -79,9 +90,11 @@ def i16x8(self) -> ir.Type: @property def i32(self) -> ir.Type: return ir.IntegerType.get_signless(32) + @property def i32x2(self) -> ir.Type: return ir.VectorType.get([2], ir.IntegerType.get_signless(32)) + @property def i32x4(self) -> ir.Type: return ir.VectorType.get([4], ir.IntegerType.get_signless(32)) @@ -93,6 +106,7 @@ def ui32(self) -> ir.Type: @property def i64(self) -> ir.Type: return ir.IntegerType.get_signless(64) + @property def i64x2(self) -> ir.Type: return ir.VectorType.get([2], ir.IntegerType.get_signless(64)) @@ -100,15 +114,19 @@ def i64x2(self) -> ir.Type: @property def f16(self) -> ir.Type: return ir.F16Type.get() + @property def f16x1(self) -> ir.Type: return ir.VectorType.get([1], ir.F16Type.get()) + @property def f16x2(self) -> ir.Type: return ir.VectorType.get([2], ir.F16Type.get()) + @property def f16x4(self) -> ir.Type: return ir.VectorType.get([4], ir.F16Type.get()) + @property def f16x8(self) -> ir.Type: return ir.VectorType.get([8], ir.F16Type.get()) @@ -116,15 +134,19 @@ def f16x8(self) -> ir.Type: @property def bf16(self) -> ir.Type: return ir.BF16Type.get() + @property def bf16x2(self) -> ir.Type: return ir.VectorType.get([2], ir.BF16Type.get()) + @property def bf16x4(self) -> ir.Type: return 
ir.VectorType.get([4], ir.BF16Type.get()) + @property def bf16x8(self) -> ir.Type: return ir.VectorType.get([8], ir.BF16Type.get()) + @property def bf16x2(self) -> ir.Type: return ir.VectorType.get([2], ir.BF16Type.get()) @@ -149,18 +171,23 @@ def f64(self) -> ir.Type: @property def f8(self) -> ir.Type: return _flir_default_f8_type() + @property def f8x1(self) -> ir.Type: return ir.VectorType.get([1], _flir_default_f8_type()) + @property def f8x2(self) -> ir.Type: return ir.VectorType.get([2], _flir_default_f8_type()) + @property def f8x4(self) -> ir.Type: return ir.VectorType.get([4], _flir_default_f8_type()) + @property def f8x8(self) -> ir.Type: return ir.VectorType.get([8], _flir_default_f8_type()) + @property def f8x16(self) -> ir.Type: return ir.VectorType.get([16], _flir_default_f8_type()) @@ -168,15 +195,19 @@ def f8x16(self) -> ir.Type: @property def e8m0(self) -> ir.Type: return Float8E8M0FNUType.get() + @property def e8m0x2(self) -> ir.Type: return ir.VectorType.get([2], Float8E8M0FNUType.get()) + @property def e8m0x4(self) -> ir.Type: return ir.VectorType.get([4], Float8E8M0FNUType.get()) + @property def e8m0x8(self) -> ir.Type: return ir.VectorType.get([8], Float8E8M0FNUType.get()) + @property def e8m0x16(self) -> ir.Type: return ir.VectorType.get([16], Float8E8M0FNUType.get()) @@ -184,18 +215,23 @@ def e8m0x16(self) -> ir.Type: @property def f4(self) -> ir.Type: return ir.Float4E2M1FNType.get() + @property def f4x2(self) -> ir.Type: return ir.VectorType.get([2], ir.Float4E2M1FNType.get()) + @property def f4x4(self) -> ir.Type: return ir.VectorType.get([4], ir.Float4E2M1FNType.get()) + @property def f4x8(self) -> ir.Type: return ir.VectorType.get([8], ir.Float4E2M1FNType.get()) + @property def f4x16(self) -> ir.Type: return ir.VectorType.get([16], ir.Float4E2M1FNType.get()) + @property def f4x32(self) -> ir.Type: return ir.VectorType.get([32], ir.Float4E2M1FNType.get()) @@ -211,13 +247,31 @@ def vec(self, n: int, elem: ir.Type) -> ir.Type: T = Types() 
-__all__ = ["Types", "T", "vec", "i8x2", "i8x4", "i8x8", "i8x16", "i16x2", "i16x4", - "i16x8", "i32x2", "i32x4", "f16", "bf16", "bf16x2", "bf16x4", "bf16x8", "f32", - "f32x2", "f32x4", "f64", "f8", "f8x2", "f8x4", "f8x8", "f8x16"] - - - - - - - +__all__ = [ + "Types", + "T", + "vec", + "i8x2", + "i8x4", + "i8x8", + "i8x16", + "i16x2", + "i16x4", + "i16x8", + "i32x2", + "i32x4", + "f16", + "bf16", + "bf16x2", + "bf16x4", + "bf16x8", + "f32", + "f32x2", + "f32x4", + "f64", + "f8", + "f8x2", + "f8x4", + "f8x8", + "f8x16", +] diff --git a/flydsl/src/flydsl/runtime/device.py b/flydsl/src/flydsl/runtime/device.py index 969aab62..ca5ca776 100644 --- a/flydsl/src/flydsl/runtime/device.py +++ b/flydsl/src/flydsl/runtime/device.py @@ -1,4 +1,5 @@ import os +import subprocess from typing import Optional @@ -22,7 +23,9 @@ def get_rocm_arch() -> str: if torch.cuda.is_available(): props = torch.cuda.get_device_properties(torch.cuda.current_device()) - arch = getattr(props, "gcnArchName", None) or getattr(props, "gcn_arch_name", None) + arch = getattr(props, "gcnArchName", None) or getattr( + props, "gcn_arch_name", None + ) if arch: # MLIR/LLVM expects the processor name without feature suffixes. # Example: "gfx942:sramecc+:xnack-" -> "gfx942". @@ -30,6 +33,21 @@ def get_rocm_arch() -> str: except Exception: pass + # Try rocminfo as fallback to detect GPU arch. + try: + result = subprocess.run( + ["rocminfo"], + capture_output=True, + text=True, + timeout=10, + ) + for line in result.stdout.splitlines(): + line = line.strip() + if line.startswith("Name:") and "gfx" in line: + return line.split(":", 1)[1].strip().split()[0] + except Exception: + pass + # Conservative default for this repo's primary test environment. 
return "gfx942" @@ -47,5 +65,3 @@ def supports_bf16_global_atomics(arch: str) -> bool: def bf16_global_atomics_arch_description() -> str: """Human-readable list of archs that support bf16 global atomics (for error messages).""" return "/".join(_BF16_GLOBAL_ATOMICS_ARCH_PREFIXES) - - diff --git a/flydsl/src/flydsl/utils/smem_allocator.py b/flydsl/src/flydsl/utils/smem_allocator.py index 07c57cd8..4fafada1 100644 --- a/flydsl/src/flydsl/utils/smem_allocator.py +++ b/flydsl/src/flydsl/utils/smem_allocator.py @@ -12,14 +12,19 @@ # Type Utilities # ============================================================================== + def get_mlir_type_size(mlir_type: ir.Type) -> int: """Returns the size in bytes of an MLIR type.""" - if mlir_type == T.f32() or mlir_type == T.i32(): return 4 - if mlir_type == T.f16() or mlir_type == T.bf16() or mlir_type == T.i16(): return 2 + if mlir_type == T.f32() or mlir_type == T.i32(): + return 4 + if mlir_type == T.f16() or mlir_type == T.bf16() or mlir_type == T.i16(): + return 2 # 1 byte - if (mlir_type == T.i8() or - mlir_type == ir.IntegerType.get_unsigned(8) or - mlir_type == ir.IntegerType.get_signed(8)): + if ( + mlir_type == T.i8() + or mlir_type == ir.IntegerType.get_unsigned(8) + or mlir_type == ir.IntegerType.get_signed(8) + ): return 1 # FP8 types if isinstance( @@ -32,83 +37,101 @@ def get_mlir_type_size(mlir_type: ir.Type) -> int: ), ): return 1 - if mlir_type == T.f64() or mlir_type == T.i64(): return 8 + if mlir_type == T.f64() or mlir_type == T.i64(): + return 8 if isinstance(mlir_type, ir.VectorType): total = 1 - for s in mlir_type.shape: total *= s + for s in mlir_type.shape: + total *= s return total * get_mlir_type_size(mlir_type.element_type) # Fallback / Default return 4 + def get_mlir_type_align(mlir_type: ir.Type) -> int: """Returns the alignment requirement in bytes.""" # For Vector types, usually align to total size, capped at 16 bytes (float4) size = get_mlir_type_size(mlir_type) - return min(size, 16) + 
return min(size, 16) + def get_op_result_or_value(op_or_val): - if hasattr(op_or_val, 'value'): # ArithValue or similar wrapper + if hasattr(op_or_val, "value"): # ArithValue or similar wrapper return op_or_val.value if isinstance(op_or_val, ir.Value): return op_or_val - if hasattr(op_or_val, 'result'): + if hasattr(op_or_val, "result"): return op_or_val.result - if hasattr(op_or_val, 'results'): + if hasattr(op_or_val, "results"): return op_or_val.results[0] return op_or_val + # ============================================================================== # Pointer Abstraction # ============================================================================== + class SmemPtr: """ Represents a typed pointer into Shared Memory. Analogue to a typed pointer wrapper. """ - def __init__(self, base_memref: ir.Value, byte_offset: int, element_type: ir.Type, shape: Optional[Tuple[int, ...]] = None): - self.base_memref = base_memref # The raw i8 buffer - self.byte_offset = byte_offset # Static offset + + def __init__( + self, + base_memref: ir.Value, + byte_offset: int, + element_type: ir.Type, + shape: Optional[Tuple[int, ...]] = None, + ): + self.base_memref = base_memref # The raw i8 buffer + self.byte_offset = byte_offset # Static offset self.element_type = element_type self.shape = shape self._view_cache = None def _get_value(self, op_or_val): - if hasattr(op_or_val, 'value'): # ArithValue or similar wrapper + if hasattr(op_or_val, "value"): # ArithValue or similar wrapper return op_or_val.value if isinstance(op_or_val, ir.Value): return op_or_val - if hasattr(op_or_val, 'result'): + if hasattr(op_or_val, "result"): return op_or_val.result - if hasattr(op_or_val, 'results'): + if hasattr(op_or_val, "results"): return op_or_val.results[0] return op_or_val def get(self) -> ir.Value: """Dereference: Returns a memref view.""" - if self._view_cache: return self._view_cache + if self._view_cache: + return self._view_cache offset_op = arith.constant(T.index(), self.byte_offset) 
offset_val = get_op_result_or_value(offset_op) - + # Construct a structured memref view using the provided shape or default to scalar. - + if self.shape: target_shape = self.shape else: - target_shape = (1,) # Scalar treated as 1-element array for view simplicity - - target_type = T.memref(*target_shape, self.element_type, memory_space=lds_space()) - + target_shape = (1,) # Scalar treated as 1-element array for view simplicity + + target_type = T.memref( + *target_shape, self.element_type, memory_space=lds_space() + ) + # memref.view(source, byte_shift, sizes) # sizes are needed for dynamic dimensions. Since we use static shapes here, sizes=[] - self._view_cache = memref.view(target_type, self.base_memref, offset_val, sizes=[]) + self._view_cache = memref.view( + target_type, self.base_memref, offset_val, sizes=[] + ) return self._view_cache def load(self, idxs=None): """Helper to load value. If scalar, idxs defaults to [0].""" view = self.get() - if idxs is None: + if idxs is None: # If scalar (shape is None or (1,)), access index 0 idxs = [ get_op_result_or_value(arith.constant(T.index(), 0)) @@ -117,11 +140,11 @@ def load(self, idxs=None): else: idxs = [get_op_result_or_value(i) for i in idxs] return memref.load(get_op_result_or_value(view), idxs) - + def store(self, val, idxs=None): """Helper to store value. 
If scalar, idxs defaults to [0].""" view = self.get() - if idxs is None: + if idxs is None: idxs = [ get_op_result_or_value(arith.constant(T.index(), 0)) for _ in range(len(self.shape) if self.shape else 1) @@ -130,18 +153,23 @@ def store(self, val, idxs=None): idxs = [get_op_result_or_value(i) for i in idxs] memref.store(get_op_result_or_value(val), get_op_result_or_value(view), idxs) + # ============================================================================== # Struct Support # ============================================================================== + class SmemStructInstance: """ A proxy object that intercepts attribute access and maps them to SmemPtrs. """ - def __init__(self, base_memref, start_offset, field_layout: Dict[str, Tuple[int, ir.Type]]): + + def __init__( + self, base_memref, start_offset, field_layout: Dict[str, Tuple[int, ir.Type]] + ): self._base_memref = base_memref self._start_offset = start_offset - self._field_layout = field_layout # Dict[name, (offset, type)] + self._field_layout = field_layout # Dict[name, (offset, type)] def __getattr__(self, name): if name in self._field_layout: @@ -151,21 +179,25 @@ def __getattr__(self, name): return SmemPtr(self._base_memref, self._start_offset + rel_offset, dtype) raise AttributeError(f"Struct has no field '{name}'") + # ============================================================================== # Allocator # ============================================================================== + class SmemAllocator: - def __init__(self, ctx, arch: Optional[str] = None, global_sym_name: Optional[str] = None): + def __init__( + self, ctx, arch: Optional[str] = None, global_sym_name: Optional[str] = None + ): self.ctx = ctx self.ptr = 0 self.max_size = 0 - self.alignment = 128 # Base alignment for the whole buffer + self.alignment = 128 # Base alignment for the whole buffer self.finalized = False - self.base_buffer_val = None + self.base_buffer_val = None self.global_sym_name = global_sym_name if 
global_sym_name else "smem_storage" self.arch = arch - + def init_tracker(self): """ Call this at the start of compilation to reset tracking. @@ -173,16 +205,17 @@ def init_tracker(self): self.ptr = 0 self.max_size = 0 self.finalized = False - + def _align(self, ptr, align): - if ptr % align == 0: return ptr + if ptr % align == 0: + return ptr return (ptr + align - 1) // align * align def get_dyn_smem(self, dtype=None, alignment=1024): """ Analogue to get_dyn_smem. Returns the 'base' pointer generator for dynamic shared memory. - + Currently, MLIR GPU dialect usually handles dynamic smem via `gpu.dynamic_shared_memory`. """ if dtype is None: @@ -190,12 +223,14 @@ def get_dyn_smem(self, dtype=None, alignment=1024): # But usually you choose one. raise NotImplementedError("Dynamic SMEM support is not yet fully implemented.") - def allocate(self, size_or_type_or_struct: Union[int, ir.Type, Any], alignment=None): + def allocate( + self, size_or_type_or_struct: Union[int, ir.Type, Any], alignment=None + ): """ The master allocation function. Returns a generator function that accepts the base pointer. 
""" - + allocated_bytes = 0 generator = None @@ -203,11 +238,11 @@ def allocate(self, size_or_type_or_struct: Union[int, ir.Type, Any], alignment=N if isinstance(size_or_type_or_struct, int): size_bytes = size_or_type_or_struct align = alignment if alignment else 1 - + offset = self._align(self.ptr, align) self.ptr = offset + size_bytes allocated_bytes = size_bytes - + generator = lambda base: SmemPtr(base, offset, T.i8(), shape=(size_bytes,)) # Mode 2: MLIR Type (Scalar) @@ -215,67 +250,73 @@ def allocate(self, size_or_type_or_struct: Union[int, ir.Type, Any], alignment=N dtype = size_or_type_or_struct size_bytes = get_mlir_type_size(dtype) align = alignment if alignment else get_mlir_type_align(dtype) - + offset = self._align(self.ptr, align) self.ptr = offset + size_bytes allocated_bytes = size_bytes - + generator = lambda base: SmemPtr(base, offset, dtype, shape=None) # Mode 3: Struct (Python Class decorated with @dataclass or similar) - elif is_dataclass(size_or_type_or_struct) or hasattr(size_or_type_or_struct, '__annotations__'): + elif is_dataclass(size_or_type_or_struct) or hasattr( + size_or_type_or_struct, "__annotations__" + ): cls = size_or_type_or_struct - + # Calculate Layout current_struct_offset = 0 struct_align = 1 - field_layout = {} # name -> (offset, type) - + field_layout = {} # name -> (offset, type) + # Get type hints - hints = getattr(cls, '__annotations__', {}) - + hints = getattr(cls, "__annotations__", {}) + for name, type_hint in hints.items(): - field_dtype = type_hint + field_dtype = type_hint if not isinstance(field_dtype, ir.Type): if callable(field_dtype): try: field_dtype = field_dtype() except: pass - + if not isinstance(field_dtype, ir.Type): - raise ValueError(f"Field '{name}' in struct '{cls.__name__}' must be typed with an MLIR Type") - + raise ValueError( + f"Field '{name}' in struct '{cls.__name__}' must be typed with an MLIR Type" + ) + f_size = get_mlir_type_size(field_dtype) f_align = get_mlir_type_align(field_dtype) - + 
# Align field current_struct_offset = self._align(current_struct_offset, f_align) field_layout[name] = (current_struct_offset, field_dtype) - + current_struct_offset += f_size struct_align = max(struct_align, f_align) - + # Struct total size usually aligned to max align total_struct_size = self._align(current_struct_offset, struct_align) - + # Allocate in global buffer align = alignment if alignment else struct_align base_offset = self._align(self.ptr, align) self.ptr = base_offset + total_struct_size allocated_bytes = total_struct_size - + def struct_generator(base): return SmemStructInstance(base, base_offset, field_layout) - + generator = struct_generator else: - raise ValueError(f"Unsupported argument to allocate: {size_or_type_or_struct}") + raise ValueError( + f"Unsupported argument to allocate: {size_or_type_or_struct}" + ) # Check Capacity check_smem_capacity(self.ptr, self.arch) - + return generator def allocate_array(self, dtype: ir.Type, num_elems: int, alignment=None): @@ -285,77 +326,80 @@ def allocate_array(self, dtype: ir.Type, num_elems: int, alignment=None): elem_size = get_mlir_type_size(dtype) total_size = elem_size * num_elems align = alignment if alignment else get_mlir_type_align(dtype) - + offset = self._align(self.ptr, align) self.ptr = offset + total_size - + check_smem_capacity(self.ptr, self.arch) - + def array_generator(base): return SmemPtr(base, offset, dtype, shape=(num_elems,)) - + return array_generator def allocate_tensor(self, layout, element_type, swizzle=None): """ allocate_tensor(Layout, Type, Swizzle) -> Tensor Generator - + layout: Must be a tuple (shape) or object with .cosize() """ # 1. 
Calculate cosize (domain size) - if hasattr(layout, 'cosize'): + if hasattr(layout, "cosize"): # Assuming layout is a Flir layout or similar num_elements = layout.cosize() - shape = getattr(layout, 'shape', None) # Try to preserve shape info + shape = getattr(layout, "shape", None) # Try to preserve shape info elif isinstance(layout, tuple): # Simple shape tuple num_elements = 1 - for s in layout: num_elements *= s + for s in layout: + num_elements *= s shape = layout else: raise ValueError("Layout must be a tuple or have .cosize()") - + element_size = get_mlir_type_size(element_type) total_bytes = num_elements * element_size - + # 2. Allocate # Tensor allocations usually want high alignment for Vectorized Access - align = 16 + align = 16 offset = self._align(self.ptr, align) self.ptr = offset + total_bytes - + check_smem_capacity(self.ptr, self.arch) - + # 3. Return Tensor Generator def tensor_generator(base): # Returns a SmemPtr viewing memory as the specified tensor shape, # currently serving as a simplified placeholder for full layout logic. # It uses the provided shape tuple directly # or defaults to a flat 1D view for opaque layout objects. - + tensor_shape = shape if isinstance(shape, tuple) else (num_elements,) - + return SmemPtr(base, offset, element_type, shape=tensor_shape) - + return tensor_generator def finalize(self): """ - Generates the global buffer allocation. + Generates the global buffer allocation. Must be called inside the gpu.module body. 
""" - if self.finalized: return - + if self.finalized: + return + # Final padding to block alignment - total_size = self._align(self.ptr, 128) - if total_size == 0: total_size = 128 - + total_size = self._align(self.ptr, 128) + if total_size == 0: + total_size = 128 + # Create Global memref_type = T.memref(total_size, T.i8(), memory_space=lds_space()) self.global_op = memref.global_( sym_name=self.global_sym_name, type_=memref_type, - alignment=1024 # High alignment for base + alignment=1024, # High alignment for base ) self.finalized = True return self.global_op @@ -367,12 +411,14 @@ def get_base(self): # We need to recreate the memref type to access the global # The size must match what was allocated (or at least be large enough, but for global access exact match is good) total_size = self._align(self.ptr, 128) - if total_size == 0: total_size = 128 - + if total_size == 0: + total_size = 128 + memref_type = T.memref(total_size, T.i8(), memory_space=lds_space()) op = memref.get_global(memref_type, self.global_sym_name) return get_op_result_or_value(op) + # ============================================================================== # Shared Memory Capacity Check # ============================================================================== @@ -380,11 +426,15 @@ def get_base(self): SMEM_CAPACITY_MAP = { # ===================== AMD CDNA Architectures (Data Center Compute Cards) ===================== # CDNA 3 (MI300 Series) - 64KB LDS per CU - "gfx942": 65536, # MI300A / MI300X: 64KB LDS per CU + "gfx942": 65536, # MI300A / MI300X: 64KB LDS per CU # CDNA 4 (MI350 Series) - 160KB LDS per CU (key upgrade for CDNA4) - "gfx950": 163840, # MI300C / MI300X Enhanced Models: 64KB LDS per CU + "gfx950": 163840, # MI350: 160KB LDS per CU + # ===================== AMD RDNA Architectures (Consumer/Workstation GPUs) ===================== + # RDNA 4 (Radeon RX 9000 Series) - 64KB LDS per workgroup (from HSA GROUP segment) + "gfx1201": 65536, # Radeon RX 9070/9700: 64KB LDS per 
workgroup } + def check_smem_capacity(allocated_bytes: int, arch: str = None): """ Checks if the allocated shared memory fits within the device capacity. @@ -393,7 +443,7 @@ def check_smem_capacity(allocated_bytes: int, arch: str = None): # Try to detect arch from environment or flir context if possible # For now, default to a safe limit or skip check if unknown return - + if arch in SMEM_CAPACITY_MAP: limit = SMEM_CAPACITY_MAP[arch] if allocated_bytes > limit: @@ -404,5 +454,3 @@ def check_smem_capacity(allocated_bytes: int, arch: str = None): else: # Unknown arch, maybe warn or skip pass - - diff --git a/kernels/kernels_common.py b/kernels/kernels_common.py index 43e56299..23dc0a09 100644 --- a/kernels/kernels_common.py +++ b/kernels/kernels_common.py @@ -1,12 +1,25 @@ from _mlir.dialects import builtin, gpu as _gpu from flydsl.dialects.ext import buffer_ops +from flydsl.runtime.device import get_rocm_arch + + +def get_warp_size(arch: str = None) -> int: + """Return the wavefront/warp size for the given GPU architecture. + + CDNA (gfx9xx) uses wave64, RDNA (gfx10xx/gfx11xx/gfx12xx) uses wave32. 
+ """ + if arch is None: + arch = get_rocm_arch() + if arch.startswith("gfx10") or arch.startswith("gfx11") or arch.startswith("gfx12"): + return 32 + return 64 def stream_ptr_to_async_token(stream_ptr_value, loc=None, ip=None): stream_llvm_ptr = buffer_ops.create_llvm_ptr(stream_ptr_value) - + async_token_type = _gpu.AsyncTokenType.get() cast_op = builtin.UnrealizedConversionCastOp( [async_token_type], [stream_llvm_ptr], loc=loc, ip=ip ) - return cast_op.results[0] \ No newline at end of file + return cast_op.results[0] diff --git a/kernels/layernorm_kernel.py b/kernels/layernorm_kernel.py index 54186d8f..18650e33 100644 --- a/kernels/layernorm_kernel.py +++ b/kernels/layernorm_kernel.py @@ -31,7 +31,9 @@ def dtype_to_elem_type(dtype_str: str): BLOCK_THREADS = 256 -WARP_SIZE = 64 +from kernels.kernels_common import get_warp_size + +WARP_SIZE = get_warp_size() VEC_WIDTH = 8 USE_NONTEMPORAL = True VEC_ALIGN = 16 @@ -58,7 +60,9 @@ class _LayerNorm(flir.MlirModule): def init_gpu_module(self): elem_type = dtype_to_elem_type(dtype_str) - compute_type = T.f32() # compute in fp32 for stability (and to keep bf16 safe on backend) + compute_type = ( + T.f32() + ) # compute in fp32 for stability (and to keep bf16 safe on backend) _state["elem_type"] = elem_type _state["compute_type"] = compute_type _state["smem_red_sum"] = allocator.allocate_array(T.f32(), RED_SLOTS) @@ -103,8 +107,11 @@ def layernorm_kernel( val_layout = flir.make_ordered_layout((1, VEC_WIDTH), order=(1, 0)) copy_atom_e = flir.make_copy_atom(elem_type, vector_size=VEC_WIDTH) tiled_copy_e = flir.make_tiled_copy_tv( - copy_atom_e, thr_layout, val_layout, - thr_shape=(1, BLOCK_THREADS), val_shape=(1, VEC_WIDTH) + copy_atom_e, + thr_layout, + val_layout, + thr_shape=(1, BLOCK_THREADS), + val_shape=(1, VEC_WIDTH), ) thr_copy_e = tiled_copy_e.get_slice((tid)) block_reduce_add = reduce_utils.make_block_reduce_add( @@ -140,10 +147,10 @@ def layernorm_kernel( # Read Input once into registers (each thread holds 32 
fp32 values = 4 vectors), # then reuse those registers for reduction + normalize + writeback. c_zero = arith.constant(0.0, type=compute_type) - thread_sum = (c_zero) - thread_sumsq = (c_zero) + thread_sum = c_zero + thread_sumsq = c_zero # Reduce VGPR pressure by caching bf16/f16 payload vectors when possible. - cache_as_elem = (dtype_str != "f32") + cache_as_elem = dtype_str != "f32" in_local = [] # bf16/f16: list[vector]; f32: list[vector] vec_type_c = ir.VectorType.get([VEC_WIDTH], compute_type) @@ -183,7 +190,9 @@ def layernorm_kernel( thread_sum = thread_sum + red thread_sumsq = thread_sumsq + red2 - sum_val, sumsq_val = block_reduce_add2(thread_sum, thread_sumsq, s_sum, s_sumsq) + sum_val, sumsq_val = block_reduce_add2( + thread_sum, thread_sumsq, s_sum, s_sumsq + ) inv_n = arith.constant(1.0 / float(N), type=compute_type) sum_val = arith.ArithValue(arith.as_value(sum_val)) @@ -219,10 +228,22 @@ def layernorm_kernel( curr_idx0 = flir.arith.AddIOp( arith.as_value(c_base0), arith.as_value(thread_offset_base) ).result - g_e_cur = flir.vector.load(vec_type_e, Gamma, [arith.as_value(curr_idx0)], alignment=VEC_ALIGN) - b_e_cur = flir.vector.load(vec_type_e, Beta, [arith.as_value(curr_idx0)], alignment=VEC_ALIGN) - g_cur = g_e_cur if dtype_str == "f32" else flir.arith.extf(vec_type_c, arith.as_value(g_e_cur)) - b_cur = b_e_cur if dtype_str == "f32" else flir.arith.extf(vec_type_c, arith.as_value(b_e_cur)) + g_e_cur = flir.vector.load( + vec_type_e, Gamma, [arith.as_value(curr_idx0)], alignment=VEC_ALIGN + ) + b_e_cur = flir.vector.load( + vec_type_e, Beta, [arith.as_value(curr_idx0)], alignment=VEC_ALIGN + ) + g_cur = ( + g_e_cur + if dtype_str == "f32" + else flir.arith.extf(vec_type_c, arith.as_value(g_e_cur)) + ) + b_cur = ( + b_e_cur + if dtype_str == "f32" + else flir.arith.extf(vec_type_c, arith.as_value(b_e_cur)) + ) for tile_i in range_constexpr(num_tiles_py): base_idx_int = tile_i * tile_cols @@ -238,12 +259,31 @@ def layernorm_kernel( next_base_idx_int = 
(tile_i + 1) * tile_cols c_base_next = flir.const_index(next_base_idx_int) next_idx = flir.arith.AddIOp( - arith.as_value(c_base_next), arith.as_value(thread_offset_base) + arith.as_value(c_base_next), + arith.as_value(thread_offset_base), ).result - g_e_next = flir.vector.load(vec_type_e, Gamma, [arith.as_value(next_idx)], alignment=VEC_ALIGN) - b_e_next = flir.vector.load(vec_type_e, Beta, [arith.as_value(next_idx)], alignment=VEC_ALIGN) - g_next = g_e_next if dtype_str == "f32" else flir.arith.extf(vec_type_c, arith.as_value(g_e_next)) - b_next = b_e_next if dtype_str == "f32" else flir.arith.extf(vec_type_c, arith.as_value(b_e_next)) + g_e_next = flir.vector.load( + vec_type_e, + Gamma, + [arith.as_value(next_idx)], + alignment=VEC_ALIGN, + ) + b_e_next = flir.vector.load( + vec_type_e, + Beta, + [arith.as_value(next_idx)], + alignment=VEC_ALIGN, + ) + g_next = ( + g_e_next + if dtype_str == "f32" + else flir.arith.extf(vec_type_c, arith.as_value(g_e_next)) + ) + b_next = ( + b_e_next + if dtype_str == "f32" + else flir.arith.extf(vec_type_c, arith.as_value(b_e_next)) + ) else: g_next = g_cur b_next = b_cur @@ -269,30 +309,55 @@ def layernorm_kernel( vec4_i32_ty = ir.VectorType.get([VEC_WIDTH // 2], T.i32()) vec_bf16_ty = ir.VectorType.get([VEC_WIDTH], elem_type) c16_i32 = arith.constant(16, type=T.i32()) - c16_i32_v = flir.vector.splat(vec_i32_ty, arith.as_value(c16_i32)) + c16_i32_v = flir.vector.splat( + vec_i32_ty, arith.as_value(c16_i32) + ) u = flir.arith.bitcast(vec_i32_ty, (y)) u = arith.as_value(u) - upper = arith.shrui(u, c16_i32_v) # i32 vector (upper 16 bits in low bits) + upper = arith.shrui( + u, c16_i32_v + ) # i32 vector (upper 16 bits in low bits) c1_i32 = arith.constant(1, type=T.i32()) c1_v = flir.vector.splat(vec_i32_ty, arith.as_value(c1_i32)) lsb = arith.andi(upper, arith.as_value(c1_v)) c7fff_i32 = arith.constant(0x7FFF, type=T.i32()) - c7fff_v = flir.vector.splat(vec_i32_ty, arith.as_value(c7fff_i32)) - bias = 
arith.ArithValue(arith.as_value(c7fff_v)) + arith.ArithValue(arith.as_value(lsb)) + c7fff_v = flir.vector.splat( + vec_i32_ty, arith.as_value(c7fff_i32) + ) + bias = arith.ArithValue( + arith.as_value(c7fff_v) + ) + arith.ArithValue(arith.as_value(lsb)) u_round = arith.ArithValue(u) + bias bf16_bits = arith.as_value(arith.shrui(u_round, c16_i32_v)) - even = flir.vector.shuffle(bf16_bits, bf16_bits, mask=[0, 2, 4, 6]) - odd = flir.vector.shuffle(bf16_bits, bf16_bits, mask=[1, 3, 5, 7]) - odd_sh = arith.as_value(arith.shli(arith.as_value(odd), flir.vector.splat(vec4_i32_ty, arith.as_value(c16_i32)))) - packed = arith.as_value(arith.ori(arith.as_value(even), odd_sh)) + even = flir.vector.shuffle( + bf16_bits, bf16_bits, mask=[0, 2, 4, 6] + ) + odd = flir.vector.shuffle( + bf16_bits, bf16_bits, mask=[1, 3, 5, 7] + ) + odd_sh = arith.as_value( + arith.shli( + arith.as_value(odd), + flir.vector.splat( + vec4_i32_ty, arith.as_value(c16_i32) + ), + ) + ) + packed = arith.as_value( + arith.ori(arith.as_value(even), odd_sh) + ) out_e = flir.vector.bitcast(vec_bf16_ty, (packed)) else: - out_e = y if dtype_str == "f32" else flir.arith.truncf(vec_type_e, (y)) + out_e = ( + y + if dtype_str == "f32" + else flir.arith.truncf(vec_type_e, (y)) + ) blkOut = gOut[((row), tile_i)] thrOut = thr_copy_e.partition_S(blkOut) @@ -318,27 +383,35 @@ def layernorm_kernel( # Generic path: 2-pass global implementation supporting arbitrary N (incl. tail). # For these small/unaligned-N test cases, correctness & robustness matter more than peak perf. 
c_N = flir.const_index(N) - c_zero = (arith.constant(0.0, type=compute_type)) - thread_sum = (c_zero) - thread_sumsq = (c_zero) + c_zero = arith.constant(0.0, type=compute_type) + thread_sum = c_zero + thread_sumsq = c_zero # Pass1: sum + sumsq for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): c_base = flir.const_index(base_idx_int) - idx = flir.arith.AddIOp(arith.as_value(c_base), arith.as_value(tid)).result + idx = flir.arith.AddIOp( + arith.as_value(c_base), arith.as_value(tid) + ).result is_valid = arith.ult(idx, c_N) thread_sum_next = thread_sum thread_sumsq_next = thread_sumsq if is_valid: x_e = flir.memref.load(Input, [(row), arith.as_value(idx)]) - x = (x_e) if dtype_str == "f32" else flir.arith.extf(compute_type, arith.as_value(x_e)) + x = ( + (x_e) + if dtype_str == "f32" + else flir.arith.extf(compute_type, arith.as_value(x_e)) + ) x_av = arith.ArithValue(arith.as_value(x)) x2 = x_av * x_av thread_sum_next = thread_sum + x thread_sumsq_next = thread_sumsq + x2 thread_sum, thread_sumsq = thread_sum_next, thread_sumsq_next - sum_val, sumsq_val = block_reduce_add2(thread_sum, thread_sumsq, s_sum, s_sumsq) + sum_val, sumsq_val = block_reduce_add2( + thread_sum, thread_sumsq, s_sum, s_sumsq + ) inv_n = arith.constant(1.0 / float(N), type=compute_type) sum_val = arith.ArithValue(arith.as_value(sum_val)) @@ -357,15 +430,29 @@ def layernorm_kernel( # Pass2: normalize + affine + store for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): c_base = flir.const_index(base_idx_int) - idx = flir.arith.AddIOp(arith.as_value(c_base), arith.as_value(tid)).result + idx = flir.arith.AddIOp( + arith.as_value(c_base), arith.as_value(tid) + ).result is_valid = arith.ult(idx, c_N) if is_valid: x_e = flir.memref.load(Input, [(row), arith.as_value(idx)]) g_e = flir.memref.load(Gamma, [arith.as_value(idx)]) b_e = flir.memref.load(Beta, [arith.as_value(idx)]) - x = (x_e) if dtype_str == "f32" else flir.arith.extf(compute_type, arith.as_value(x_e)) - g = (g_e) if 
dtype_str == "f32" else flir.arith.extf(compute_type, arith.as_value(g_e)) - b = (b_e) if dtype_str == "f32" else flir.arith.extf(compute_type, arith.as_value(b_e)) + x = ( + (x_e) + if dtype_str == "f32" + else flir.arith.extf(compute_type, arith.as_value(x_e)) + ) + g = ( + (g_e) + if dtype_str == "f32" + else flir.arith.extf(compute_type, arith.as_value(g_e)) + ) + b = ( + (b_e) + if dtype_str == "f32" + else flir.arith.extf(compute_type, arith.as_value(b_e)) + ) diff = x - mean norm = diff * rstd scaled = norm * g @@ -373,7 +460,11 @@ def layernorm_kernel( if dtype_str == "bf16": y_e = flir.arith.truncf(elem_type, (y)) else: - y_e = y if dtype_str == "f32" else flir.arith.truncf(elem_type, (y)) + y_e = ( + y + if dtype_str == "f32" + else flir.arith.truncf(elem_type, (y)) + ) flir.memref.store((y_e), Output, [(row), (idx)]) @flir.jit @@ -385,9 +476,9 @@ def __call__( Output: lambda: T.memref(DYN, N, _state["elem_type"]), m_in: lambda: T.index(), ): - c1 = (flir.arith_ext.index(1)) - gx = (m_in) - bx = (flir.arith_ext.index(BLOCK_THREADS)) + c1 = flir.arith_ext.index(1) + gx = m_in + bx = flir.arith_ext.index(BLOCK_THREADS) flir.gpu_ext.LaunchFuncOp( ["layernorm_module", "layernorm_kernel"], grid_size=(gx, c1, c1), @@ -396,5 +487,3 @@ def __call__( ) return _LayerNorm() - - diff --git a/kernels/reduce.py b/kernels/reduce.py index d8313200..e2f6a318 100644 --- a/kernels/reduce.py +++ b/kernels/reduce.py @@ -3,11 +3,26 @@ These helpers build MLIR ops (flir/gpu/scf/vector/etc). They are extracted from softmax/layernorm/rmsnorm kernels to de-duplicate code without changing codegen. """ + from __future__ import annotations from flydsl.dialects.ext.python_control_flow import lower_range_for_loops +def _shuffle_offsets_for_warp(warp_size: int) -> list[int]: + """Return the xor-shuffle offsets for a tree reduction within a wavefront. 
+ + For wave64: [32, 16, 8, 4, 2, 1] + For wave32: [16, 8, 4, 2, 1] + """ + offsets = [] + s = warp_size // 2 + while s >= 1: + offsets.append(s) + s //= 2 + return offsets + + def reduce_vec_max(vec_val, *, VEC_WIDTH, compute_type, vector): if VEC_WIDTH == 1: return vector.extract(vec_val, static_position=[0], dynamic_position=[]) @@ -15,6 +30,7 @@ def reduce_vec_max(vec_val, *, VEC_WIDTH, compute_type, vector): # The vector dialect expects a raw MLIR Value, not wrapper objects. try: from flydsl.dialects.ext import arith as _arith + vec_val = _arith.as_value(vec_val) except Exception: pass @@ -26,18 +42,36 @@ def reduce_vec_sum(vec_val, *, VEC_WIDTH, compute_type, vector, fm_fast): return vector.extract(vec_val, static_position=[0], dynamic_position=[]) try: from flydsl.dialects.ext import arith as _arith + vec_val = _arith.as_value(vec_val) except Exception: pass return vector.reduction(compute_type, "add", vec_val, fastmath=fm_fast) -def make_block_reduce(*, tid, BLOCK_SIZE, compute_type, arith, gpu, flir, s_red_tv, T, ir, c_zero, c_neg_inf, c_zero_idx, fm_fast): +def make_block_reduce( + *, + tid, + BLOCK_SIZE, + compute_type, + arith, + gpu, + flir, + s_red_tv, + T, + ir, + c_zero, + c_neg_inf, + c_zero_idx, + fm_fast, +): """Return a `block_reduce(val, reduce_op_name)` function (softmax-style).""" def block_reduce(val, reduce_op_name): - # AMD wavefront size is 64 on gfx9+/gfx10+/gfx11. - WARP_SIZE = 64 + # Wavefront size: 64 for CDNA (gfx9xx), 32 for RDNA (gfx10xx/11xx/12xx). + from kernels.kernels_common import get_warp_size + + WARP_SIZE = get_warp_size() NUM_WAVES = (BLOCK_SIZE + WARP_SIZE - 1) // WARP_SIZE # python int # Use Flir layout algebra to compute LDS indices for the reduction scratch. 
c_num_waves = flir.const_index(NUM_WAVES) @@ -57,20 +91,26 @@ def block_reduce(val, reduce_op_name): w = arith.as_value(val) # Intra-wave reduction via xor shuffle - for sh in [32, 16, 8, 4, 2, 1]: + for sh in _shuffle_offsets_for_warp(WARP_SIZE): off = arith.as_value(arith.constant(sh, type=T.i32())) - peer = arith.as_value(gpu.ShuffleOp(arith.as_value(w), off, width_i32, mode="xor").shuffleResult) + peer = arith.as_value( + gpu.ShuffleOp( + arith.as_value(w), off, width_i32, mode="xor" + ).shuffleResult + ) if reduce_op_name == "max": w = flir.arith.MaximumFOp(arith.as_value(w), peer).result else: w = flir.arith.AddFOp(arith.as_value(w), peer, fastmath=fm_fast).result # lane0 writes per-wave partial into LDS s_red[wave_id] - is_lane0 = arith.as_value(flir.arith.CmpIOp( - flir.arith.CmpIPredicate.eq, - lane_i32, - arith.as_value(arith.constant(0, type=T.i32())), - ).result) + is_lane0 = arith.as_value( + flir.arith.CmpIOp( + flir.arith.CmpIPredicate.eq, + lane_i32, + arith.as_value(arith.constant(0, type=T.i32())), + ).result + ) if is_lane0: wave_idx = flir.arith.IndexCastOp(T.index(), wave_i32).result red_idx = flir.crd2idx(flir.make_coord(wave_idx), layout_red) @@ -78,41 +118,57 @@ def block_reduce(val, reduce_op_name): gpu.barrier() # wave0 reduces NUM_WAVES partials (still using shuffle) - is_wave0 = arith.as_value(flir.arith.CmpIOp( - flir.arith.CmpIPredicate.eq, - wave_i32, - arith.as_value(arith.constant(0, type=T.i32())), - ).result) + is_wave0 = arith.as_value( + flir.arith.CmpIOp( + flir.arith.CmpIPredicate.eq, + wave_i32, + arith.as_value(arith.constant(0, type=T.i32())), + ).result + ) if is_wave0: - in_range = arith.as_value(flir.arith.CmpIOp( - flir.arith.CmpIPredicate.ult, - lane_i32, - arith.as_value(arith.constant(NUM_WAVES, type=T.i32())), - ).result) + in_range = arith.as_value( + flir.arith.CmpIOp( + flir.arith.CmpIPredicate.ult, + lane_i32, + arith.as_value(arith.constant(NUM_WAVES, type=T.i32())), + ).result + ) # Predicated load: clamp lane 
index to 0 when out-of-range, then select. c0_i32 = arith.as_value(arith.constant(0, type=T.i32())) - lane_safe_i32 = arith.as_value(flir.arith.SelectOp(in_range, lane_i32, c0_i32).result) - lane_safe_idx = arith.as_value(flir.arith.IndexCastOp(T.index(), lane_safe_i32).result) + lane_safe_i32 = arith.as_value( + flir.arith.SelectOp(in_range, lane_i32, c0_i32).result + ) + lane_safe_idx = arith.as_value( + flir.arith.IndexCastOp(T.index(), lane_safe_i32).result + ) red_idx = flir.crd2idx(flir.make_coord(lane_safe_idx), layout_red) v = arith.as_value(s_red_tv[red_idx]) neutral = arith.as_value(c_neg_inf if reduce_op_name == "max" else c_zero) ww = arith.as_value(flir.arith.SelectOp(in_range, v, neutral).result) - for sh in [32, 16, 8, 4, 2, 1]: + for sh in _shuffle_offsets_for_warp(WARP_SIZE): off = arith.as_value(arith.constant(sh, type=T.i32())) - peer = arith.as_value(gpu.ShuffleOp(arith.as_value(ww), off, width_i32, mode="xor").shuffleResult) + peer = arith.as_value( + gpu.ShuffleOp( + arith.as_value(ww), off, width_i32, mode="xor" + ).shuffleResult + ) if reduce_op_name == "max": ww = flir.arith.MaximumFOp(arith.as_value(ww), peer).result else: - ww = flir.arith.AddFOp(arith.as_value(ww), peer, fastmath=fm_fast).result + ww = flir.arith.AddFOp( + arith.as_value(ww), peer, fastmath=fm_fast + ).result # lane0 writes final to s_red[0] - is_lane0_2 = arith.as_value(flir.arith.CmpIOp( - flir.arith.CmpIPredicate.eq, - lane_i32, - arith.as_value(arith.constant(0, type=T.i32())), - ).result) + is_lane0_2 = arith.as_value( + flir.arith.CmpIOp( + flir.arith.CmpIPredicate.eq, + lane_i32, + arith.as_value(arith.constant(0, type=T.i32())), + ).result + ) if is_lane0_2: red_idx0 = flir.crd2idx(flir.make_coord(c_zero_idx), layout_red) s_red_tv[red_idx0] = ww @@ -127,7 +183,21 @@ def block_reduce(val, reduce_op_name): return block_reduce -def make_block_reduce_add(*, tid, fm_fast, WARP_SIZE, RED_SLOTS, gpu, arith, arith_ops, flir, T, ir, zero_idx, 
scratch_tv_shape_stride=(None, None)): +def make_block_reduce_add( + *, + tid, + fm_fast, + WARP_SIZE, + RED_SLOTS, + gpu, + arith, + arith_ops, + flir, + T, + ir, + zero_idx, + scratch_tv_shape_stride=(None, None), +): """Return a `block_reduce_add(val_f32, scratch_memref)` function (norm-style).""" shape_unused, stride_unused = scratch_tv_shape_stride _ = shape_unused @@ -139,10 +209,16 @@ def block_reduce_add(val_f32, scratch_memref): if RED_SLOTS == 1: width_i32 = arith.as_value(arith.constant(WARP_SIZE, type=T.i32())) w = arith.as_value(val_f32) - for sh in [32, 16, 8, 4, 2, 1]: + for sh in _shuffle_offsets_for_warp(WARP_SIZE): off = arith.as_value(arith.constant(sh, type=T.i32())) - peer = arith.as_value(gpu.ShuffleOp(arith.as_value(w), off, width_i32, mode="xor").shuffleResult) - w = arith.as_value(arith_ops.AddFOp(arith.as_value(w), peer, fastmath=fm_fast).result) + peer = arith.as_value( + gpu.ShuffleOp( + arith.as_value(w), off, width_i32, mode="xor" + ).shuffleResult + ) + w = arith.as_value( + arith_ops.AddFOp(arith.as_value(w), peer, fastmath=fm_fast).result + ) return w scratch_tv = flir.make_tensor(scratch_memref, shape=(RED_SLOTS,), strides=(1,)) @@ -161,16 +237,24 @@ def block_reduce_add(val_f32, scratch_memref): layout_red = flir.make_layout(shape_red, stride_red) w = arith.as_value(val_f32) - for sh in [32, 16, 8, 4, 2, 1]: + for sh in _shuffle_offsets_for_warp(WARP_SIZE): off = arith.as_value(arith.constant(sh, type=T.i32())) - peer = arith.as_value(gpu.ShuffleOp(arith.as_value(w), off, width_i32, mode="xor").shuffleResult) - w = arith.as_value(arith_ops.AddFOp(arith.as_value(w), peer, fastmath=fm_fast).result) - - is_lane0 = arith.as_value(arith_ops.CmpIOp( - arith_ops.CmpIPredicate.eq, - lane_i32, - arith.as_value(arith.constant(0, type=T.i32())), - ).result) + peer = arith.as_value( + gpu.ShuffleOp( + arith.as_value(w), off, width_i32, mode="xor" + ).shuffleResult + ) + w = arith.as_value( + arith_ops.AddFOp(arith.as_value(w), peer, 
fastmath=fm_fast).result + ) + + is_lane0 = arith.as_value( + arith_ops.CmpIOp( + arith_ops.CmpIPredicate.eq, + lane_i32, + arith.as_value(arith.constant(0, type=T.i32())), + ).result + ) if is_lane0: wave_idx = arith_ops.IndexCastOp(T.index(), wave_i32).result red_idx = flir.crd2idx(flir.make_coord(wave_idx), layout_red) @@ -178,37 +262,53 @@ def block_reduce_add(val_f32, scratch_memref): gpu.barrier() NUM_WAVES = RED_SLOTS - is_wave0 = arith.as_value(arith_ops.CmpIOp( - arith_ops.CmpIPredicate.eq, - wave_i32, - arith.as_value(arith.constant(0, type=T.i32())), - ).result) + is_wave0 = arith.as_value( + arith_ops.CmpIOp( + arith_ops.CmpIPredicate.eq, + wave_i32, + arith.as_value(arith.constant(0, type=T.i32())), + ).result + ) # Only wave0 does final reduction and writes scratch[0]. if is_wave0: - in_range = arith.as_value(arith_ops.CmpIOp( - arith_ops.CmpIPredicate.ult, - lane_i32, - arith.as_value(arith.constant(NUM_WAVES, type=T.i32())), - ).result) + in_range = arith.as_value( + arith_ops.CmpIOp( + arith_ops.CmpIPredicate.ult, + lane_i32, + arith.as_value(arith.constant(NUM_WAVES, type=T.i32())), + ).result + ) c0_i32 = arith.as_value(arith.constant(0, type=T.i32())) - lane_safe_i32 = arith.as_value(flir.arith.SelectOp(in_range, lane_i32, c0_i32).result) - lane_safe_idx = arith.as_value(arith_ops.IndexCastOp(T.index(), lane_safe_i32).result) + lane_safe_i32 = arith.as_value( + flir.arith.SelectOp(in_range, lane_i32, c0_i32).result + ) + lane_safe_idx = arith.as_value( + arith_ops.IndexCastOp(T.index(), lane_safe_i32).result + ) red_idx = flir.crd2idx(flir.make_coord(lane_safe_idx), layout_red) v = scratch_tv[red_idx] z = arith.as_value(arith.constant(0.0, type=T.f32())) ww = arith.as_value(flir.arith.SelectOp(in_range, v, z).result) - for sh in [32, 16, 8, 4, 2, 1]: + for sh in _shuffle_offsets_for_warp(WARP_SIZE): off = arith.as_value(arith.constant(sh, type=T.i32())) - peer = arith.as_value(gpu.ShuffleOp(arith.as_value(ww), off, width_i32, 
mode="xor").shuffleResult) - ww = arith.as_value(arith_ops.AddFOp(arith.as_value(ww), peer, fastmath=fm_fast).result) + peer = arith.as_value( + gpu.ShuffleOp( + arith.as_value(ww), off, width_i32, mode="xor" + ).shuffleResult + ) + ww = arith.as_value( + arith_ops.AddFOp(arith.as_value(ww), peer, fastmath=fm_fast).result + ) - is_lane0_2 = arith.as_value(arith_ops.CmpIOp( - arith_ops.CmpIPredicate.eq, - lane_i32, - arith.as_value(arith.constant(0, type=T.i32())), - ).result) + is_lane0_2 = arith.as_value( + arith_ops.CmpIOp( + arith_ops.CmpIPredicate.eq, + lane_i32, + arith.as_value(arith.constant(0, type=T.i32())), + ).result + ) if is_lane0_2: red_idx0 = flir.crd2idx(flir.make_coord(zero_idx), layout_red) scratch_tv[red_idx0] = ww @@ -223,7 +323,9 @@ def block_reduce_add(val_f32, scratch_memref): return block_reduce_add -def make_block_reduce_add2(*, tid, fm_fast, WARP_SIZE, RED_SLOTS, gpu, arith, arith_ops, flir, T, ir, zero_idx): +def make_block_reduce_add2( + *, tid, fm_fast, WARP_SIZE, RED_SLOTS, gpu, arith, arith_ops, flir, T, ir, zero_idx +): """Return a `block_reduce_add2(a_f32, b_f32, scratch_a, scratch_b)` function. This is NOT pair-reduce: it reduces two independent scalars but shares the same @@ -234,12 +336,16 @@ def _wave_reduce_add(x): # Normalize operands to raw MLIR Values: Shuffle/AddFOp expect `Value`, not wrappers. 
width_i32 = arith.as_value(arith.constant(WARP_SIZE, type=T.i32())) w = arith.as_value(x) - for sh in [32, 16, 8, 4, 2, 1]: + for sh in _shuffle_offsets_for_warp(WARP_SIZE): off = arith.as_value(arith.constant(sh, type=T.i32())) peer = arith.as_value( - gpu.ShuffleOp(arith.as_value(w), off, width_i32, mode="xor").shuffleResult + gpu.ShuffleOp( + arith.as_value(w), off, width_i32, mode="xor" + ).shuffleResult + ) + w = arith.as_value( + arith_ops.AddFOp(arith.as_value(w), peer, fastmath=fm_fast).result ) - w = arith.as_value(arith_ops.AddFOp(arith.as_value(w), peer, fastmath=fm_fast).result) return w def block_reduce_add2(val0_f32, val1_f32, scratch0_memref, scratch1_memref): @@ -247,8 +353,12 @@ def block_reduce_add2(val0_f32, val1_f32, scratch0_memref, scratch1_memref): if RED_SLOTS == 1: return _wave_reduce_add(val0_f32), _wave_reduce_add(val1_f32) - scratch0_tv = flir.make_tensor(scratch0_memref, shape=(RED_SLOTS,), strides=(1,)) - scratch1_tv = flir.make_tensor(scratch1_memref, shape=(RED_SLOTS,), strides=(1,)) + scratch0_tv = flir.make_tensor( + scratch0_memref, shape=(RED_SLOTS,), strides=(1,) + ) + scratch1_tv = flir.make_tensor( + scratch1_memref, shape=(RED_SLOTS,), strides=(1,) + ) tid_v = tid.value if hasattr(tid, "value") else tid tid_v = arith.as_value(tid_v) @@ -301,8 +411,12 @@ def block_reduce_add2(val0_f32, val1_f32, scratch0_memref, scratch1_memref): ) c0_i32 = arith.as_value(arith.constant(0, type=T.i32())) - lane_safe_i32 = arith.as_value(flir.arith.SelectOp(in_range, lane_i32, c0_i32).result) - lane_safe_idx = arith.as_value(arith_ops.IndexCastOp(T.index(), lane_safe_i32).result) + lane_safe_i32 = arith.as_value( + flir.arith.SelectOp(in_range, lane_i32, c0_i32).result + ) + lane_safe_idx = arith.as_value( + arith_ops.IndexCastOp(T.index(), lane_safe_i32).result + ) red_idx = flir.crd2idx(flir.make_coord(lane_safe_idx), layout_red) v0 = scratch0_tv[red_idx] v1 = scratch1_tv[red_idx] diff --git a/kernels/rmsnorm_kernel.py 
b/kernels/rmsnorm_kernel.py index caba51b0..c336a08c 100644 --- a/kernels/rmsnorm_kernel.py +++ b/kernels/rmsnorm_kernel.py @@ -32,7 +32,9 @@ def dtype_to_elem_type(dtype_str: str): BLOCK_THREADS = 256 -WARP_SIZE = 64 +from kernels.kernels_common import get_warp_size + +WARP_SIZE = get_warp_size() VEC_WIDTH = 8 USE_NONTEMPORAL = True VEC_ALIGN = 16 @@ -112,7 +114,7 @@ def rmsnorm_kernel( if N < tile_cols: c_N = flir.const_index(N) c_zero = arith.constant(0.0, type=compute_type) - thread_sumsq = (c_zero) + thread_sumsq = c_zero # Pass1: sumsq for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): @@ -122,7 +124,11 @@ def rmsnorm_kernel( thread_sumsq_next = thread_sumsq if is_valid: x_e = flir.memref.load(Input, [(row), arith.as_value(idx)]) - x = (x_e) if dtype_str == "f32" else flir.arith.extf(compute_type, arith.as_value(x_e)) + x = ( + (x_e) + if dtype_str == "f32" + else flir.arith.extf(compute_type, arith.as_value(x_e)) + ) x2 = x * x thread_sumsq_next = thread_sumsq + x2 thread_sumsq = thread_sumsq_next @@ -140,11 +146,23 @@ def rmsnorm_kernel( if is_valid: x_e = flir.memref.load(Input, [(row), arith.as_value(idx)]) g_e = flir.memref.load(Gamma, [arith.as_value(idx)]) - x = (x_e) if dtype_str == "f32" else flir.arith.extf(compute_type, arith.as_value(x_e)) - g = (g_e) if dtype_str == "f32" else flir.arith.extf(compute_type, arith.as_value(g_e)) + x = ( + (x_e) + if dtype_str == "f32" + else flir.arith.extf(compute_type, arith.as_value(x_e)) + ) + g = ( + (g_e) + if dtype_str == "f32" + else flir.arith.extf(compute_type, arith.as_value(g_e)) + ) norm = (arith.ArithValue(x) * rrms).value y = (arith.ArithValue(norm) * g).value - y_e = y if dtype_str == "f32" else flir.arith.truncf(elem_type, arith.as_value(y)) + y_e = ( + y + if dtype_str == "f32" + else flir.arith.truncf(elem_type, arith.as_value(y)) + ) flir.memref.store((y_e), Output, [(row), arith.as_value(idx)]) return @@ -152,8 +170,11 @@ def rmsnorm_kernel( val_layout = flir.make_ordered_layout((1, 
VEC_WIDTH), order=(1, 0)) copy_atom_e = flir.make_copy_atom(elem_type, vector_size=VEC_WIDTH) tiled_copy_e = flir.make_tiled_copy_tv( - copy_atom_e, thr_layout, val_layout, - thr_shape=(1, BLOCK_THREADS), val_shape=(1, VEC_WIDTH) + copy_atom_e, + thr_layout, + val_layout, + thr_shape=(1, BLOCK_THREADS), + val_shape=(1, VEC_WIDTH), ) thr_copy_e = tiled_copy_e.get_slice((tid)) @@ -185,13 +206,13 @@ def rmsnorm_kernel( is_valid = arith.ult(idx_k, c_N) if is_valid: v_e = tensor_In[((row), arith.as_value(idx_k))] - tensor_S[((c0_idx), arith.as_value(idx_k))] = (v_e) + tensor_S[((c0_idx), arith.as_value(idx_k))] = v_e flir.gpu_ext.barrier() # Pass1: sumsq (from LDS row cache) c_zero = arith.constant(0.0, type=compute_type) - thread_sumsq = (c_zero) + thread_sumsq = c_zero for base_idx_int in range_constexpr(0, N, BLOCK_THREADS * VEC_WIDTH): c_base = flir.const_index(base_idx_int) @@ -205,9 +226,15 @@ def rmsnorm_kernel( vec_type_e, s_row, [(c0_idx), (curr_idx)], alignment=VEC_ALIGN ) vec_type_c = ir.VectorType.get([VEC_WIDTH], compute_type) - vec = vec_e if dtype_str == "f32" else flir.arith.extf(vec_type_c, arith.as_value(vec_e)) + vec = ( + vec_e + if dtype_str == "f32" + else flir.arith.extf(vec_type_c, arith.as_value(vec_e)) + ) vec2 = (arith.ArithValue(vec) * vec).value - red2 = flir.vector.reduction(compute_type, "add", (vec2), fastmath=fm_fast) + red2 = flir.vector.reduction( + compute_type, "add", (vec2), fastmath=fm_fast + ) thread_sumsq = thread_sumsq + red2 else: c_N = flir.const_index(N) @@ -219,7 +246,11 @@ def rmsnorm_kernel( v_e = tensor_S[((c0_idx), arith.as_value(idx_k))] else: v_e = arith.constant(0.0, type=elem_type) - v = (v_e) if dtype_str == "f32" else flir.arith.extf(compute_type, arith.as_value(v_e)) + v = ( + (v_e) + if dtype_str == "f32" + else flir.arith.extf(compute_type, arith.as_value(v_e)) + ) v2 = (arith.ArithValue(v) * v).value thread_sumsq = thread_sumsq + v2 @@ -241,7 +272,9 @@ def rmsnorm_kernel( thread_offset0 = (arith.ArithValue(tid) 
* VEC_WIDTH).value curr0 = (arith.ArithValue(c_base0) + thread_offset0).value vec_type_e0 = ir.VectorType.get([VEC_WIDTH], elem_type) - g_pref_e = flir.vector.load(vec_type_e0, Gamma, [arith.as_value(curr0)], alignment=VEC_ALIGN) + g_pref_e = flir.vector.load( + vec_type_e0, Gamma, [arith.as_value(curr0)], alignment=VEC_ALIGN + ) for base_idx_int in range_constexpr(0, N, BLOCK_THREADS * VEC_WIDTH): c_base = flir.const_index(base_idx_int) @@ -254,26 +287,62 @@ def rmsnorm_kernel( next_base_int = base_idx_int + (BLOCK_THREADS * VEC_WIDTH) if next_base_int < N: c_base_n = flir.const_index(next_base_int) - curr_idx_n = (arith.ArithValue(c_base_n) + thread_offset_base).value - g_next_e = flir.vector.load(vec_type_e, Gamma, [arith.as_value(curr_idx_n)], alignment=VEC_ALIGN) + curr_idx_n = ( + arith.ArithValue(c_base_n) + thread_offset_base + ).value + g_next_e = flir.vector.load( + vec_type_e, + Gamma, + [arith.as_value(curr_idx_n)], + alignment=VEC_ALIGN, + ) else: g_next_e = None x_e = flir.vector.load( - vec_type_e, s_row, [(c0_idx), arith.as_value(curr_idx)], alignment=VEC_ALIGN + vec_type_e, + s_row, + [(c0_idx), arith.as_value(curr_idx)], + alignment=VEC_ALIGN, ) # Gamma is reused across many blocks: do NOT use nontemporal here. 
- g_e = g_pref_e if g_pref_e is not None else flir.vector.load(vec_type_e, Gamma, [arith.as_value(curr_idx)], alignment=VEC_ALIGN) - x = x_e if dtype_str == "f32" else flir.arith.extf(vec_type_c, arith.as_value(x_e)) - g = g_e if dtype_str == "f32" else flir.arith.extf(vec_type_c, arith.as_value(g_e)) + g_e = ( + g_pref_e + if g_pref_e is not None + else flir.vector.load( + vec_type_e, + Gamma, + [arith.as_value(curr_idx)], + alignment=VEC_ALIGN, + ) + ) + x = ( + x_e + if dtype_str == "f32" + else flir.arith.extf(vec_type_c, arith.as_value(x_e)) + ) + g = ( + g_e + if dtype_str == "f32" + else flir.arith.extf(vec_type_c, arith.as_value(g_e)) + ) norm = (arith.ArithValue(x) * rrms_splat).value y = (arith.ArithValue(norm) * g).value - y_e = y if dtype_str == "f32" else flir.arith.truncf(vec_type_e, arith.as_value(y)) + y_e = ( + y + if dtype_str == "f32" + else flir.arith.truncf(vec_type_e, arith.as_value(y)) + ) tile_i = base_idx_int // tile_cols # python int blkOut = gOut[((row), tile_i)] thrOut = thr_copy_e.partition_S(blkOut) frgOut = flir.make_fragment_like(thrOut, elem_type) - flir.vector.store(arith.as_value(y_e), frgOut.memref, [c0_idx, c0_idx], alignment=VEC_ALIGN) + flir.vector.store( + arith.as_value(y_e), + frgOut.memref, + [c0_idx, c0_idx], + alignment=VEC_ALIGN, + ) flir.copy( tiled_copy_e, frgOut, @@ -291,12 +360,24 @@ def rmsnorm_kernel( if is_valid: x_e = tensor_S[((c0_idx), arith.as_value(idx_k))] g_e = tensor_Gamma[arith.as_value(idx_k)] - x = (x_e) if dtype_str == "f32" else flir.arith.extf(compute_type, arith.as_value(x_e)) - g = (g_e) if dtype_str == "f32" else flir.arith.extf(compute_type, arith.as_value(g_e)) + x = ( + (x_e) + if dtype_str == "f32" + else flir.arith.extf(compute_type, arith.as_value(x_e)) + ) + g = ( + (g_e) + if dtype_str == "f32" + else flir.arith.extf(compute_type, arith.as_value(g_e)) + ) norm = (arith.ArithValue(x) * rrms).value y = (arith.ArithValue(norm) * g).value - y_e = y if dtype_str == "f32" else 
flir.arith.truncf(elem_type, arith.as_value(y)) - tensor_Out[((row), arith.as_value(idx_k))] = (y_e) + y_e = ( + y + if dtype_str == "f32" + else flir.arith.truncf(elem_type, arith.as_value(y)) + ) + tensor_Out[((row), arith.as_value(idx_k))] = y_e @flir.jit def __call__( @@ -317,5 +398,3 @@ def __call__( ) return _RMSNorm() - - diff --git a/kernels/softmax_kernel.py b/kernels/softmax_kernel.py index bd831e78..3132a857 100644 --- a/kernels/softmax_kernel.py +++ b/kernels/softmax_kernel.py @@ -57,7 +57,9 @@ def build_softmax_module(M, N, dtype_str="f32"): # Allocator for Shared Memory (Warp Reductions) allocator = SmemAllocator(None, arch=gpu_arch) # Reduction scratch: one slot per wave (lane0 writes partials) + reuse slot 0 for broadcast. - WARP_SIZE = 64 + from kernels.kernels_common import get_warp_size + + WARP_SIZE = get_warp_size(gpu_arch) RED_SLOTS = max(1, (BLOCK_SIZE + WARP_SIZE - 1) // WARP_SIZE) _state = {} @@ -177,7 +179,9 @@ def softmax_kernel( # Check bounds # If fully within N, vector load We can check statically for the loop unroll? Since N is compile time constant, we check specific offsets. However, thread_id is dynamic. We rely on logic: If (base_idx_int + BLOCK_SIZE*VEC_WIDTH) <= N, then ALL threads are safe? No. tid=255 accesses last chunk. Safe logic: if (base_idx_int + (BLOCK_SIZE-1)*WIDTH + WIDTH) <= N. - is_safe_vector = (base_idx_int + (BLOCK_SIZE - 1) * VEC_WIDTH + VEC_WIDTH) <= N + is_safe_vector = ( + base_idx_int + (BLOCK_SIZE - 1) * VEC_WIDTH + VEC_WIDTH + ) <= N if is_safe_vector: # Flir tiled copy: global -> rmem fragment, then load vector from fragment. 
@@ -193,7 +197,9 @@ def softmax_kernel( alignment=VEC_ALIGN, ) vec_type_e = ir.VectorType.get([VEC_WIDTH], elem_type) - vec_val_e = vector.load(vec_type_e, frgA.memref, [c0_idx, c0_idx], alignment=VEC_ALIGN) + vec_val_e = vector.load( + vec_type_e, frgA.memref, [c0_idx, c0_idx], alignment=VEC_ALIGN + ) if dtype_str == "bf16": vec_type_c = ir.VectorType.get([VEC_WIDTH], compute_type) vec_val = flir.arith.extf(vec_type_c, arith.as_value(vec_val_e)) @@ -214,7 +220,11 @@ def softmax_kernel( # Use predicated load to avoid OOB memory access. idx_safe = arith.select(is_valid, idx_k, c0_idx) val_e = tensor_A[(row, arith.as_value(idx_safe))] - val_c = arith.extf(compute_type, val_e) if dtype_str == "bf16" else val_e + val_c = ( + arith.extf(compute_type, val_e) + if dtype_str == "bf16" + else val_e + ) val = arith.select(is_valid, val_c, c_neg_inf) row_buffer.append((val, is_valid)) @@ -226,7 +236,11 @@ def softmax_kernel( vec_val, VEC_WIDTH=VEC_WIDTH, compute_type=compute_type, vector=vector ) reduce_vec_sum = lambda vec_val: reduce_utils.reduce_vec_sum( - vec_val, VEC_WIDTH=VEC_WIDTH, compute_type=compute_type, vector=vector, fm_fast=fm_fast + vec_val, + VEC_WIDTH=VEC_WIDTH, + compute_type=compute_type, + vector=vector, + fm_fast=fm_fast, ) for item in row_buffer: @@ -268,8 +282,12 @@ def softmax_kernel( vec_val = item if g_max_splat_vec is None: vec_type = ir.VectorType.get([VEC_WIDTH], compute_type) - g_max_splat_vec = flir.vector.splat(vec_type, arith.as_value(global_max)) - log2e_splat = flir.vector.splat(vec_type, arith.as_value(c_log2e)) + g_max_splat_vec = flir.vector.splat( + vec_type, arith.as_value(global_max) + ) + log2e_splat = flir.vector.splat( + vec_type, arith.as_value(c_log2e) + ) sub = vec_val - g_max_splat_vec scaled = sub * log2e_splat @@ -299,7 +317,9 @@ def softmax_kernel( c_base = arith.index(base_idx_int) curr_idx = (c_base + thread_offset_base).value - is_safe_vector = (base_idx_int + (BLOCK_SIZE - 1) * VEC_WIDTH + VEC_WIDTH) <= N + is_safe_vector 
= ( + base_idx_int + (BLOCK_SIZE - 1) * VEC_WIDTH + VEC_WIDTH + ) <= N if is_safe_vector: vec_exp = row_buffer[buf_idx] @@ -307,7 +327,9 @@ def softmax_kernel( if inv_sum_splat_vec is None: vec_type = ir.VectorType.get([VEC_WIDTH], compute_type) - inv_sum_splat_vec = vector.splat(vec_type, arith.as_value(inv_sum)) + inv_sum_splat_vec = vector.splat( + vec_type, arith.as_value(inv_sum) + ) # Prefer fast-math for normalization multiply norm_vec = arith.as_value(vec_exp * inv_sum_splat_vec) @@ -328,9 +350,15 @@ def softmax_kernel( c7fff_i32 = arith.constant(0x7FFF, type=T.i32()) c1_i32 = arith.constant(1, type=T.i32()) - c16_i32_v = flir.vector.splat(vec_i32_ty, arith.as_value(c16_i32)) - c7fff_i32_v = flir.vector.splat(vec_i32_ty, arith.as_value(c7fff_i32)) - c1_i32_v = flir.vector.splat(vec_i32_ty, arith.as_value(c1_i32)) + c16_i32_v = flir.vector.splat( + vec_i32_ty, arith.as_value(c16_i32) + ) + c7fff_i32_v = flir.vector.splat( + vec_i32_ty, arith.as_value(c7fff_i32) + ) + c1_i32_v = flir.vector.splat( + vec_i32_ty, arith.as_value(c1_i32) + ) u = flir.arith.bitcast(vec_i32_ty, norm_vec) hi = arith.as_value(arith.shrui(u, c16_i32_v)) @@ -339,9 +367,20 @@ def softmax_kernel( u_round = arith.as_value(u + bias) bf16_bits = arith.as_value(arith.shrui(u_round, c16_i32_v)) - even = flir.vector.shuffle(bf16_bits, bf16_bits, mask=[0, 2, 4, 6]) - odd = flir.vector.shuffle(bf16_bits, bf16_bits, mask=[1, 3, 5, 7]) - odd_sh = arith.as_value(arith.shli(odd, flir.vector.splat(vec4_i32_ty, arith.as_value(c16_i32)))) + even = flir.vector.shuffle( + bf16_bits, bf16_bits, mask=[0, 2, 4, 6] + ) + odd = flir.vector.shuffle( + bf16_bits, bf16_bits, mask=[1, 3, 5, 7] + ) + odd_sh = arith.as_value( + arith.shli( + odd, + flir.vector.splat( + vec4_i32_ty, arith.as_value(c16_i32) + ), + ) + ) packed = arith.as_value(arith.ori(even, odd_sh)) out_bf16 = flir.vector.bitcast(vec_bf16_ty, packed) @@ -349,7 +388,12 @@ def softmax_kernel( blkC = gC[(row, tile_i)] thrC = 
thr_copy_C.partition_S(blkC) frgC = flir.make_fragment_like(thrC, elem_type) - vector.store(arith.as_value(out_bf16), frgC.memref, [c0_idx, c0_idx], alignment=VEC_ALIGN) + vector.store( + arith.as_value(out_bf16), + frgC.memref, + [c0_idx, c0_idx], + alignment=VEC_ALIGN, + ) flir.copy( tiled_copy_C, frgC, @@ -364,8 +408,17 @@ def softmax_kernel( thrC = thr_copy_C.partition_S(blkC) frgC = flir.make_fragment_like(thrC, elem_type) vec_type_e = ir.VectorType.get([VEC_WIDTH], elem_type) - norm_e = norm_vec if dtype_str != "bf16" else flir.arith.truncf(vec_type_e, norm_vec) - vector.store(arith.as_value(norm_e), frgC.memref, [c0_idx, c0_idx], alignment=VEC_ALIGN) + norm_e = ( + norm_vec + if dtype_str != "bf16" + else flir.arith.truncf(vec_type_e, norm_vec) + ) + vector.store( + arith.as_value(norm_e), + frgC.memref, + [c0_idx, c0_idx], + alignment=VEC_ALIGN, + ) flir.copy( tiled_copy_C, frgC, @@ -384,7 +437,9 @@ def softmax_kernel( if valid: norm_val = arith.as_value(val_exp * inv_sum) if dtype_str == "bf16": - norm_val = arith.as_value(arith.trunc_f(elem_type, norm_val)) + norm_val = arith.as_value( + arith.trunc_f(elem_type, norm_val) + ) c_k = arith.index(k) idx_k = (arith.ArithValue(curr_idx) + c_k).value @@ -408,5 +463,3 @@ def __call__( ) return _Softmax() - - diff --git a/kernels/wmma_gemm.py b/kernels/wmma_gemm.py new file mode 100644 index 00000000..c3ca3cd7 --- /dev/null +++ b/kernels/wmma_gemm.py @@ -0,0 +1,599 @@ +#!/usr/bin/env python3 +"""Optimized WMMA GEMM kernel for RDNA4 (gfx12xx, wave32). + +Computes C[M,N] = A[M,K] @ B[K,N] using v_wmma_f32_16x16x16_{f16,bf16}. +Supports f16/bf16 inputs with f32 accumulation, output in f32 or bf16. 
+ +Optimizations applied: + - Configurable block tile with K-unroll (REG_K WMMA-K steps per tile) + - Register tiling: each wave computes REG_M x REG_N WMMA tiles per K step + - 4 waves (128 threads) per workgroup + - L2 cache swizzle: grouped block scheduling for better L2 reuse + - Inline asm batched global loads (s_clause 7 + 8x global_load_b128) + to prevent LLVM VGPR reuse serialization + - SW pipelining: next tile global loads overlap with current tile WMMA compute + - Vectorized LDS loads for A operand (contiguous in K dim) + - LDS B padding to reduce bank conflicts on scalar B reads + - Dynamic K loop (scf.for via auto-lowered range()) for O(1) IR size + +WMMA data layout for wave32 (verified on gfx1201): + A operand: "row-of-cols" -- lane t loads A[t%16][(t/16)*8 + i], i=0..7 + B operand: "col-of-rows" -- lane t loads B[(t/16)*8 + i][t%16], i=0..7 + D result: "col-of-rows" -- lane t holds D[(t/16)*8 + i][t%16], i=0..7 + +LDS layout: + A: row-major [BLOCK_M][BLOCK_K] -- A[m][k], contiguous in K + B: row-major [BLOCK_K][B_LDS_STRIDE] -- B[k][n+pad], padded stride +""" + +from flydsl.dialects.ext import flir, arith, memref, vector, rocdl, gpu +from flydsl.dialects.ext.python_control_flow import range_constexpr +from flydsl.runtime.device import get_rocm_arch +from flydsl.lang.ir.types import T +from flydsl.utils import SmemAllocator +from _mlir import ir +from _mlir.dialects import llvm as _llvm +from _mlir.dialects import arith as _std_arith +from _mlir.dialects import memref as _std_memref +import _mlir.extras.types as Textra + + +# ============================================================================= +# Kernel configuration +# ============================================================================= + +WMMA_M = 16 +WMMA_N = 16 +WMMA_K = 16 + +# Register tiling: each wave handles REG_M x REG_N WMMA tiles +REG_M = 4 # 4 WMMA tiles vertically per wave +REG_N = 4 # 4 WMMA tiles horizontally per wave + +# K-unroll: multiple WMMA-K steps per tile 
load +REG_K = 2 # 2 WMMA-K steps per K tile + +# Waves per workgroup arranged as WAVES_M x WAVES_N +WAVES_M = 2 +WAVES_N = 2 +NUM_WAVES = WAVES_M * WAVES_N # 4 + +# Derived block tile dimensions +BLOCK_M = WMMA_M * REG_M * WAVES_M # 16*4*2 = 128 +BLOCK_N = WMMA_N * REG_N * WAVES_N # 16*4*2 = 128 +BLOCK_K = WMMA_K * REG_K # 16*2 = 32 + +THREADS_PER_BLOCK = NUM_WAVES * 32 # 128 + +# Elements per thread for cooperative global->LDS load +A_TILE_ELEMS = BLOCK_M * BLOCK_K +B_TILE_ELEMS = BLOCK_K * BLOCK_N +NUM_A_LOADS = A_TILE_ELEMS // (THREADS_PER_BLOCK * 8) +NUM_B_LOADS = B_TILE_ELEMS // (THREADS_PER_BLOCK * 8) + +# LDS padding for B to reduce bank conflicts +B_PAD = 8 +B_LDS_STRIDE = BLOCK_N + B_PAD + +# LDS sizes (single buffer) +LDS_A_ELEMS = BLOCK_M * BLOCK_K +LDS_B_ELEMS = BLOCK_K * B_LDS_STRIDE + +# L2 cache swizzle group size +GROUP_M = 8 + +# Number of inline-asm batched global loads (A loads + B loads) +TOTAL_GLOBAL_LOADS = NUM_A_LOADS + NUM_B_LOADS # 4 + 4 = 8 + + +def _unwrap(v): + """Unwrap ArithValue to raw MLIR Value.""" + while hasattr(v, "_value"): + v = v._value + return v + + +def create_wmma_gemm_module(M: int, N: int, K: int, in_dtype="bf16", out_dtype="f32"): + """Create an optimized WMMA GEMM module. 
+ + Args: + M, N, K: matrix dimensions (must be multiples of BLOCK_M/N/K) + in_dtype: "bf16" or "f16" + out_dtype: "f32" or "bf16" + """ + gpu_arch = get_rocm_arch() + S = ir.ShapedType.get_dynamic_size() + + assert M % BLOCK_M == 0, f"M={M} must be multiple of BLOCK_M={BLOCK_M}" + assert N % BLOCK_N == 0, f"N={N} must be multiple of BLOCK_N={BLOCK_N}" + assert K % BLOCK_K == 0, f"K={K} must be multiple of BLOCK_K={BLOCK_K}" + + num_k_tiles = K // BLOCK_K + grid_m = M // BLOCK_M + grid_n = N // BLOCK_N + is_bf16 = in_dtype == "bf16" + + def _in_elem_ty(): + return Textra.bf16() if is_bf16 else Textra.f16() + + def _out_elem_ty(): + return Textra.f32() if out_dtype == "f32" else Textra.bf16() + + def _wmma_op(result_type, a_vec, b_vec, acc, v8i16_ty): + """Execute the correct WMMA op based on dtype, with bf16->i16 bitcast.""" + if is_bf16: + a_i16 = vector.bitcast(v8i16_ty, a_vec) + b_i16 = vector.bitcast(v8i16_ty, b_vec) + return rocdl.wmma_f32_16x16x16_bf16( + result_type, + [a_i16, b_i16, arith.unwrap(acc)], + ) + else: + return rocdl.wmma_f32_16x16x16_f16( + result_type, + [arith.unwrap(a_vec), arith.unwrap(b_vec), arith.unwrap(acc)], + ) + + allocator = SmemAllocator(None, arch=gpu_arch) + _state = {} + + class _WmmaGemm(flir.MlirModule): + GPU_MODULE_NAME = "wmma_gemm" + GPU_MODULE_TARGETS = [f'#rocdl.target'] + + def init_gpu_module(self): + _state["s_a"] = allocator.allocate_array(_in_elem_ty(), LDS_A_ELEMS) + _state["s_b"] = allocator.allocate_array(_in_elem_ty(), LDS_B_ELEMS) + allocator.finalize() + + @flir.kernel + def wmma_gemm_kernel( + self: flir.T.i64, + A: lambda: Textra.memref(S, S, _in_elem_ty()), + B: lambda: Textra.memref(S, S, _in_elem_ty()), + C: lambda: Textra.memref(S, S, _out_elem_ty()), + ): + # ---- Types ---- + in_ir_ty = ir.BF16Type.get() if is_bf16 else ir.F16Type.get() + v8_in_ty = ir.VectorType.get([8], in_ir_ty) + v8f32_ty = T.vec(8, T.f32) + i16_ty = ir.IntegerType.get_signless(16) + v8i16_ty = ir.VectorType.get([8], i16_ty) + 
i64_ty = ir.IntegerType.get_signless(64) + ptr_ty = ir.Type.parse("!llvm.ptr") + v4i32_ty = ir.VectorType.get([4], ir.IntegerType.get_signless(32)) + struct_ty = _llvm.StructType.get_literal([v4i32_ty] * TOTAL_GLOBAL_LOADS) + + # ---- Thread / block IDs ---- + tid = flir.thread_idx("x") + pid = flir.block_idx("x") # linear program id + + c32 = arith.index(32) + c16 = arith.index(16) + c8 = arith.index(8) + wave_id = tid // c32 + lane = tid % c32 + lane16 = lane % c16 + base8 = (lane // c16) * c8 + + # ---- L2 cache swizzle (grouped block scheduling) ---- + effective_group_m = min(GROUP_M, grid_m) + c_grid_n = arith.index(grid_n) + c_group_m = arith.index(effective_group_m) + num_pid_in_group = c_group_m * c_grid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * c_group_m + group_size_m = c_group_m + + pid_in_group = pid % num_pid_in_group + bid_m = first_pid_m + (pid_in_group % group_size_m) + bid_n = pid_in_group // group_size_m + + # Wave position in the WAVES_M x WAVES_N grid + c_wn = arith.index(WAVES_N) + wave_m = wave_id // c_wn + wave_n = wave_id % c_wn + + # Global tile origins + tile_m0 = bid_m * arith.index(BLOCK_M) + tile_n0 = bid_n * arith.index(BLOCK_N) + + # ---- LDS setup ---- + lds_base = allocator.get_base() + As = _state["s_a"](lds_base) + Bs = _state["s_b"](lds_base) + lds_a_view = As.get() + lds_b_view = Bs.get() + + # ---- Extract base pointers for inline asm loads ---- + elem_bytes = 2 # bf16 or f16 = 2 bytes + a_base_i64 = _unwrap( + _std_arith.IndexCastOp( + i64_ty, + _unwrap( + _std_memref.ExtractAlignedPointerAsIndexOp(_unwrap(A)).result + ), + ).result + ) + b_base_i64 = _unwrap( + _std_arith.IndexCastOp( + i64_ty, + _unwrap( + _std_memref.ExtractAlignedPointerAsIndexOp(_unwrap(B)).result + ), + ).result + ) + + # ---- Pre-compute thread-local LDS store addresses (invariant) ---- + a_lds_addrs = [] + b_lds_addrs = [] + for al in range_constexpr(NUM_A_LOADS): + a_lin = tid * c8 + arith.index(al * THREADS_PER_BLOCK * 8) + 
a_load_row = a_lin // arith.index(BLOCK_K) + a_load_col = a_lin % arith.index(BLOCK_K) + a_lds_addrs.append(a_load_row * arith.index(BLOCK_K) + a_load_col) + + for bl in range_constexpr(NUM_B_LOADS): + b_lin = tid * c8 + arith.index(bl * THREADS_PER_BLOCK * 8) + b_load_row = b_lin // arith.index(BLOCK_N) + b_load_col = b_lin % arith.index(BLOCK_N) + b_lds_addrs.append(b_load_row * arith.index(B_LDS_STRIDE) + b_load_col) + + # ---- Build inline asm strings (compile-time constants) ---- + # Load asm: s_clause + 8x global_load_b128, NO wait + asm_load_lines = [f"s_clause {TOTAL_GLOBAL_LOADS - 1}"] + for i in range_constexpr(TOTAL_GLOBAL_LOADS): + asm_load_lines.append( + f"global_load_b128 ${i}, ${i + TOTAL_GLOBAL_LOADS}, off" + ) + asm_load_str = "\n".join(asm_load_lines) + + out_constraints = ",".join(["=&v"] * TOTAL_GLOBAL_LOADS) + in_constraints = ",".join(["v"] * TOTAL_GLOBAL_LOADS) + asm_constraints = f"{out_constraints},{in_constraints}" + + # ---- Helper: compute flat addresses for a given k_base ---- + def _compute_load_addrs(k_base): + """Compute 8 flat pointers for global loads at given k_base.""" + all_addrs = [] + for al in range_constexpr(NUM_A_LOADS): + a_lin = tid * c8 + arith.index(al * THREADS_PER_BLOCK * 8) + a_load_row = a_lin // arith.index(BLOCK_K) + a_load_col = a_lin % arith.index(BLOCK_K) + g_a_row = tile_m0 + a_load_row + g_a_col = k_base + a_load_col + byte_off = (g_a_row * arith.index(K) + g_a_col) * arith.index( + elem_bytes + ) + byte_off_i64 = _unwrap( + _std_arith.IndexCastOp( + i64_ty, _unwrap(arith.unwrap(byte_off)) + ).result + ) + addr_i64 = _unwrap( + _std_arith.AddIOp(a_base_i64, byte_off_i64).result + ) + addr_ptr = _unwrap(_llvm.IntToPtrOp(ptr_ty, addr_i64).result) + all_addrs.append(addr_ptr) + + for bl in range_constexpr(NUM_B_LOADS): + b_lin = tid * c8 + arith.index(bl * THREADS_PER_BLOCK * 8) + b_load_row = b_lin // arith.index(BLOCK_N) + b_load_col = b_lin % arith.index(BLOCK_N) + g_b_row = k_base + b_load_row + g_b_col = 
tile_n0 + b_load_col + byte_off = (g_b_row * arith.index(N) + g_b_col) * arith.index( + elem_bytes + ) + byte_off_i64 = _unwrap( + _std_arith.IndexCastOp( + i64_ty, _unwrap(arith.unwrap(byte_off)) + ).result + ) + addr_i64 = _unwrap( + _std_arith.AddIOp(b_base_i64, byte_off_i64).result + ) + addr_ptr = _unwrap(_llvm.IntToPtrOp(ptr_ty, addr_i64).result) + all_addrs.append(addr_ptr) + return all_addrs + + # ---- Helper: issue batched loads (no wait) ---- + def _issue_loads(all_addrs): + """Issue 8 batched global_load_b128 via inline asm. Returns struct.""" + return _llvm.inline_asm( + struct_ty, + all_addrs, + asm_load_str, + asm_constraints, + has_side_effects=True, + ) + + # ---- Helper: extract + bitcast load results ---- + def _extract_load_results(asm_result): + """Extract 8 load results from asm struct, bitcast to bf16/f16.""" + all_vecs = [] + for i in range_constexpr(TOTAL_GLOBAL_LOADS): + pos_attr = ir.DenseI64ArrayAttr.get([i]) + v4i32_val = _llvm.ExtractValueOp( + v4i32_ty, asm_result, pos_attr + ).result + bf16_vec = vector.bitcast(v8_in_ty, v4i32_val) + all_vecs.append(bf16_vec) + return all_vecs[:NUM_A_LOADS], all_vecs[NUM_A_LOADS:] + + # ---- Helper: store load results to LDS ---- + def _store_to_lds(a_vecs, b_vecs): + """Store A and B vectors to LDS.""" + for al in range_constexpr(NUM_A_LOADS): + vector.store(a_vecs[al], lds_a_view, [a_lds_addrs[al]]) + for bl in range_constexpr(NUM_B_LOADS): + vector.store(b_vecs[bl], lds_b_view, [b_lds_addrs[bl]]) + + # ---- Inline asm infrastructure for batched B + A LDS reads ---- + # B operand: 4 packed i32 (8 bf16) per WMMA tile, 16 i32 per K-step + # A operand: v4i32 (8 bf16) per WMMA tile, REG_M=4 tiles per K-step + NUM_B_VGPRS_PER_K = REG_N * 4 # 16 + b_lds_stride_bytes = B_LDS_STRIDE * 2 # 272 bytes + + i32_ty = ir.IntegerType.get_signless(32) + n = NUM_B_VGPRS_PER_K # 16 + + # B immediate offsets: even (u16) and odd (d16_hi) for each (rn, pi) + even_offsets = [] + odd_offsets = [] + for rn in 
range_constexpr(REG_N): + for pi in range_constexpr(4): + even_offsets.append((2 * pi * B_LDS_STRIDE + rn * WMMA_N) * 2) + odd_offsets.append(((2 * pi + 1) * B_LDS_STRIDE + rn * WMMA_N) * 2) + + # A immediate offsets for each rm + a_rk_offsets = [] + for rm in range_constexpr(REG_M): + a_rk_offsets.append(rm * WMMA_M * BLOCK_K * 2) + + # ---- Combined asm for rk=0: A + B loads with waits ---- + # Outputs: $0..$15 = 16 B i32, $16..$19 = 4 A v4i32 + # Inputs: $20 = B base (v), $21 = A base (v) + NUM_COMBINED_OUTPUTS = n + REG_M # 20 + combined_out_types = [i32_ty] * n + [v4i32_ty] * REG_M + combined_struct_ty = _llvm.StructType.get_literal(combined_out_types) + + combined_lines = [] + b_base_idx = NUM_COMBINED_OUTPUTS # $20 + a_base_idx = NUM_COMBINED_OUTPUTS + 1 # $21 + for i in range_constexpr(n): + combined_lines.append( + f"ds_load_u16 ${i}, ${b_base_idx} offset:{even_offsets[i]}" + ) + for i in range_constexpr(REG_M): + combined_lines.append( + f"ds_load_b128 ${n + i}, ${a_base_idx} offset:{a_rk_offsets[i]}" + ) + combined_lines.append("s_wait_dscnt 0x4") + for i in range_constexpr(n): + combined_lines.append( + f"ds_load_u16_d16_hi ${i}, ${b_base_idx} offset:{odd_offsets[i]}" + ) + combined_lines.append("s_wait_dscnt 0x0") + + asm_combined_str = "\n".join(combined_lines) + asm_combined_constraints = ",".join(["=&v"] * NUM_COMBINED_OUTPUTS) + ",v,v" + + # ---- Helpers: compute B/A base address and extract results ---- + def _compute_b_base(k_off): + """Compute single VGPR base address for B LDS reads.""" + b_alloc_byte_off = _unwrap( + _std_arith.IndexCastOp( + i32_ty, + _unwrap( + _std_memref.ExtractAlignedPointerAsIndexOp( + _unwrap(lds_b_view) + ).result + ), + ).result + ) + k_row = k_off + base8 + row_byte_off = k_row * arith.index(b_lds_stride_bytes) + col_idx = wave_n * arith.index(REG_N * WMMA_N) + lane16 + col_byte_off = col_idx * arith.index(2) + total_off = row_byte_off + col_byte_off + total_off_i32 = _unwrap( + _std_arith.IndexCastOp( + i32_ty, 
_unwrap(arith.unwrap(total_off)) + ).result + ) + return _unwrap( + _std_arith.AddIOp(b_alloc_byte_off, total_off_i32).result + ) + + def _compute_a_base(k_off): + """Compute single VGPR base address for A LDS reads.""" + a_alloc_byte_off = _unwrap( + _std_arith.IndexCastOp( + i32_ty, + _unwrap( + _std_memref.ExtractAlignedPointerAsIndexOp( + _unwrap(lds_a_view) + ).result + ), + ).result + ) + a_row = wave_m * arith.index(REG_M * WMMA_M) + lane16 + a_off = (a_row * arith.index(BLOCK_K) + k_off + base8) * arith.index(2) + a_off_i32 = _unwrap( + _std_arith.IndexCastOp(i32_ty, _unwrap(arith.unwrap(a_off))).result + ) + return _unwrap(_std_arith.AddIOp(a_alloc_byte_off, a_off_i32).result) + + def _do_wmma_block(a_vecs, b_vecs, accs): + """Execute REG_M * REG_N WMMAs for one K-step.""" + for rm in range_constexpr(REG_M): + for rn in range_constexpr(REG_N): + idx = rm * REG_N + rn + accs[idx] = _wmma_op( + v8f32_ty, + a_vecs[rm], + b_vecs[rn], + accs[idx], + v8i16_ty, + ) + + # ---- Helper: WMMA compute phase (reads from LDS) ---- + def _do_compute(accs_in): + """Execute WMMA compute for all K-steps. + + For each K-step, issues a single combined inline asm block: + 16x ds_load_u16 (B low halves) + 4x ds_load_b128 (A tiles) + s_wait_dscnt 0x4 (B_u16 done, A still in flight) + 16x ds_load_u16_d16_hi (B high halves) + s_wait_dscnt 0x0 (everything done) + Then executes REG_M * REG_N WMMAs. 
+ """ + new_accs = list(accs_in) + + for rk in range_constexpr(REG_K): + k_off = arith.index(rk * WMMA_K) + b_base = _compute_b_base(k_off) + a_base = _compute_a_base(k_off) + + result = _llvm.inline_asm( + combined_struct_ty, + [b_base, a_base], + asm_combined_str, + asm_combined_constraints, + has_side_effects=True, + ) + + # Extract B: 16 i32 -> pack into 4 v4i32 -> bitcast to 4 v8bf16 + b_vecs = [] + for rn in range_constexpr(REG_N): + v4 = _unwrap(_llvm.UndefOp(v4i32_ty).result) + for pi in range_constexpr(4): + idx = rn * 4 + pi + pos_attr = ir.DenseI64ArrayAttr.get([idx]) + val = _unwrap( + _llvm.ExtractValueOp(i32_ty, result, pos_attr).result + ) + pi_val = _unwrap( + _std_arith.ConstantOp( + i32_ty, ir.IntegerAttr.get(i32_ty, pi) + ).result + ) + v4 = _unwrap(_llvm.InsertElementOp(v4, val, pi_val).result) + b_vecs.append(vector.bitcast(v8_in_ty, v4)) + + # Extract A: 4 v4i32 -> bitcast to 4 v8bf16 + a_vecs = [] + for rm in range_constexpr(REG_M): + pos_attr = ir.DenseI64ArrayAttr.get([n + rm]) + v4 = _llvm.ExtractValueOp(v4i32_ty, result, pos_attr).result + a_vecs.append(vector.bitcast(v8_in_ty, v4)) + + _do_wmma_block(a_vecs, b_vecs, new_accs) + + return new_accs + + # ---- Initialize REG_M x REG_N accumulators (flat list) ---- + zero_acc = arith.constant_vector(0.0, v8f32_ty) + accs = [zero_acc for _ in range_constexpr(REG_M * REG_N)] + + # ======================================================== + # SW-PIPELINED K LOOP + # ======================================================== + + # ---- Prologue: load tile 0, wait, store to LDS, barrier ---- + prologue_addrs = _compute_load_addrs(arith.index(0)) + prologue_result = _issue_loads(prologue_addrs) + # Wait for prologue loads + _llvm.inline_asm( + res=None, + operands_=[], + asm_string="s_wait_loadcnt 0x0", + constraints="", + has_side_effects=True, + ) + a_vecs_p, b_vecs_p = _extract_load_results(prologue_result) + _store_to_lds(a_vecs_p, b_vecs_p) + gpu.barrier() + + if num_k_tiles == 1: + # Single 
tile: just compute + accs = _do_compute(accs) + else: + # ---- Main loop: tiles 0..num_k_tiles-2 ---- + for kt in range(num_k_tiles - 1): + # k_base for the NEXT tile's load + k_base_next = (kt + arith.index(1)) * arith.index(BLOCK_K) + + # Issue loads for next tile (NO WAIT - overlap with compute) + next_addrs = _compute_load_addrs(k_base_next) + next_result = _issue_loads(next_addrs) + + # Compute on current tile (from LDS) + accs = _do_compute(accs) + + # Barrier: ensure all waves done reading LDS + gpu.barrier() + + # Wait for next tile's loads to complete + _llvm.inline_asm( + res=None, + operands_=[], + asm_string="s_wait_loadcnt 0x0", + constraints="", + has_side_effects=True, + ) + + # Store next tile to LDS + a_vecs_n, b_vecs_n = _extract_load_results(next_result) + _store_to_lds(a_vecs_n, b_vecs_n) + + # Barrier: ensure all waves done writing LDS + gpu.barrier() + + # ---- Epilogue: compute on last tile ---- + accs = _do_compute(accs) + + # ========== Store results (scalar stores) ========== + for rm in range_constexpr(REG_M): + for rn in range_constexpr(REG_N): + idx = rm * REG_N + rn + wmma_m_off = wave_m * arith.index(REG_M * WMMA_M) + arith.index( + rm * WMMA_M + ) + wmma_n_off = wave_n * arith.index(REG_N * WMMA_N) + arith.index( + rn * WMMA_N + ) + + # D layout: col-of-rows => D[base8+i][lane16] + for si in range_constexpr(8): + g_row = tile_m0 + wmma_m_off + base8 + arith.index(si) + g_col = tile_n0 + wmma_n_off + lane16 + val = vector.extract( + accs[idx], + static_position=[si], + dynamic_position=[], + ) + if out_dtype == "bf16": + val = arith.trunc_f(ir.BF16Type.get(), val) + memref.store(val, C, [g_row, g_col]) + + @flir.jit + def __call__( + self: flir.T.i64, + A: lambda: Textra.memref(S, S, _in_elem_ty()), + B: lambda: Textra.memref(S, S, _in_elem_ty()), + C: lambda: Textra.memref(S, S, _out_elem_ty()), + ): + c1 = arith.index(1) + total_blocks = arith.index(grid_m * grid_n) + bk = arith.index(THREADS_PER_BLOCK) + flir.gpu_ext.LaunchFuncOp( 
+ ["wmma_gemm", "wmma_gemm_kernel"], + grid_size=(total_blocks, c1, c1), + block_size=(bk, c1, c1), + kernel_operands=[A, B, C], + ) + + return _WmmaGemm() diff --git a/tests/kernels/profile_gemm.py b/tests/kernels/profile_gemm.py new file mode 100644 index 00000000..e4571188 --- /dev/null +++ b/tests/kernels/profile_gemm.py @@ -0,0 +1,166 @@ +#!/usr/bin/env python3 +"""Profile script for rocprofv3 — runs GEMM kernels for hardware counter collection.""" + +import torch +import sys +import os + +os.environ.setdefault("HIP_VISIBLE_DEVICES", "0") + +MODE = sys.argv[1] if len(sys.argv) > 1 else "all" +SZ = int(sys.argv[2]) if len(sys.argv) > 2 else 4096 + + +def run_pytorch(sz): + """Run PyTorch (rocBLAS) GEMM.""" + A = torch.randn(sz, sz, device="cuda", dtype=torch.bfloat16) + B = torch.randn(sz, sz, device="cuda", dtype=torch.bfloat16) + # warmup + for _ in range(5): + C = torch.mm(A, B) + torch.cuda.synchronize() + # profiled run + for _ in range(10): + C = torch.mm(A, B) + torch.cuda.synchronize() + + +def run_triton(sz): + """Run Triton GEMM (128x128x32).""" + import triton + import triton.language as tl + + @triton.jit + def matmul_kernel( + a_ptr, + b_ptr, + c_ptr, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, + ): + pid = tl.program_id(0) + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + num_pid_in_group = GROUP_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M + offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N + offs_k = tl.arange(0, BLOCK_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + 
(offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0) + acc += tl.dot(a, b) + a_ptrs += BLOCK_K * stride_ak + b_ptrs += BLOCK_K * stride_bk + offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, acc, mask=c_mask) + + A = torch.randn(sz, sz, device="cuda", dtype=torch.bfloat16) + B = torch.randn(sz, sz, device="cuda", dtype=torch.bfloat16) + M, K = A.shape + K, N = B.shape + c = torch.empty((M, N), device="cuda", dtype=torch.float32) + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), + ) + + # warmup + for _ in range(5): + matmul_kernel[grid]( + A, + B, + c, + M, + N, + K, + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(1), + c.stride(0), + c.stride(1), + BLOCK_M=128, + BLOCK_N=128, + BLOCK_K=32, + GROUP_M=8, + ) + torch.cuda.synchronize() + # profiled run + for _ in range(10): + matmul_kernel[grid]( + A, + B, + c, + M, + N, + K, + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(1), + c.stride(0), + c.stride(1), + BLOCK_M=128, + BLOCK_N=128, + BLOCK_K=32, + GROUP_M=8, + ) + torch.cuda.synchronize() + + +def run_flydsl(sz): + """Run our FlyDSL WMMA GEMM.""" + # Must add repo root to path for kernels module + repo_root = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..") + if repo_root not in sys.path: + sys.path.insert(0, repo_root) + + from kernels.wmma_gemm import create_wmma_gemm_module + + A = torch.randn(sz, sz, device="cuda", dtype=torch.bfloat16) + B = torch.randn(sz, sz, device="cuda", dtype=torch.bfloat16) + C = torch.zeros(sz, sz, 
device="cuda", dtype=torch.float32) + + mod = create_wmma_gemm_module(sz, sz, sz, in_dtype="bf16", out_dtype="f32") + # warmup + for _ in range(5): + mod(A, B, C) + torch.cuda.synchronize() + # profiled run + for _ in range(10): + mod(A, B, C) + torch.cuda.synchronize() + + +if __name__ == "__main__": + if MODE in ("pytorch", "all"): + print(f"Running PyTorch GEMM {SZ}x{SZ}...") + run_pytorch(SZ) + if MODE in ("triton", "all"): + print(f"Running Triton GEMM {SZ}x{SZ}...") + run_triton(SZ) + if MODE in ("flydsl", "all"): + print(f"Running FlyDSL GEMM {SZ}x{SZ}...") + run_flydsl(SZ) + print("Done.") diff --git a/tests/kernels/test_inline_asm_load.py b/tests/kernels/test_inline_asm_load.py new file mode 100644 index 00000000..2c70d75c --- /dev/null +++ b/tests/kernels/test_inline_asm_load.py @@ -0,0 +1,284 @@ +#!/usr/bin/env python3 +"""Test llvm.inline_asm for batched global loads on gfx1201. + +Goal: Force the LLVM backend to use distinct VGPRs for each global_load_b128, +preventing the VGPR reuse that serializes loads in the WMMA GEMM kernel. + +Approach: + 1. Convert memref base address to flat pointer via extract_aligned_pointer_as_index + 2. Compute per-thread byte offset + 3. 
Use llvm.inline_asm to emit global_load_b128 with explicit VGPR constraints +""" + +import sys +import os +import numpy as np +import pytest +import torch + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../..")) + +import flydsl +from flydsl.dialects.ext import flir, arith, memref, vector, gpu +from flydsl.dialects.ext.python_control_flow import range_constexpr +from flydsl.runtime.device import get_rocm_arch +from flydsl.lang.ir.types import T +from _mlir import ir +from _mlir.dialects import llvm as _llvm +from _mlir.dialects import arith as _std_arith +from _mlir.dialects import memref as _std_memref +import _mlir.extras.types as Textra + +if not torch.cuda.is_available(): + pytest.skip("CUDA/ROCm not available", allow_module_level=True) + +gpu_arch = get_rocm_arch() +if not gpu_arch.startswith("gfx12"): + pytest.skip(f"Test requires gfx12xx, got {gpu_arch}", allow_module_level=True) + + +def _unwrap(v): + """Unwrap ArithValue to raw MLIR Value.""" + while hasattr(v, "_value"): + v = v._value + return v + + +def create_inline_asm_copy_kernel(): + """Minimal kernel: copy 8 bf16 values using inline asm global_load_b128. + + Each thread loads 8 bf16 (= 16 bytes = 128 bits) from src and stores to dst. + This tests the full pipeline: memref -> flat ptr -> inline asm load -> store. + """ + S = ir.ShapedType.get_dynamic_size() + + class _InlineAsmTest(flir.MlirModule): + GPU_MODULE_NAME = "inline_asm_test" + GPU_MODULE_TARGETS = [f'#rocdl.target'] + + @flir.kernel + def inline_asm_kernel( + self: flir.T.i64, + src: lambda: Textra.memref(S, Textra.bf16()), + dst: lambda: Textra.memref(S, Textra.bf16()), + ): + tid = flir.thread_idx("x") + + # Each thread handles 8 bf16 elements = 16 bytes + elem_offset = tid * arith.index(8) + byte_offset = tid * arith.index(16) + + # --- Get flat pointer for src --- + # 1. 
Extract base pointer as index + src_raw = _unwrap(src) + src_ptr_idx = _unwrap( + _std_memref.ExtractAlignedPointerAsIndexOp(src_raw).result + ) + + # 2. Convert index -> i64 + i64_ty = ir.IntegerType.get_signless(64) + src_base_i64 = _unwrap(_std_arith.IndexCastOp(i64_ty, src_ptr_idx).result) + + # 3. Convert byte offset to i64 + byte_off_raw = _unwrap(arith.unwrap(byte_offset)) + byte_off_i64 = _unwrap(_std_arith.IndexCastOp(i64_ty, byte_off_raw).result) + + # 4. Add base + offset + addr_i64 = _unwrap(_std_arith.AddIOp(src_base_i64, byte_off_i64).result) + + # 5. Convert to llvm.ptr + ptr_ty = ir.Type.parse("!llvm.ptr") + addr_ptr = _unwrap(_llvm.IntToPtrOp(ptr_ty, addr_i64).result) + + # --- Use inline asm to do global_load_b128 --- + # global_load_b128 loads 128 bits (16 bytes = 4xi32) from flat address + v4i32_ty = ir.VectorType.get([4], ir.IntegerType.get_signless(32)) + + loaded = _llvm.inline_asm( + v4i32_ty, # result: vector<4xi32> + [addr_ptr], # operands: the flat pointer + "global_load_b128 $0, $1, off\ns_wait_loadcnt 0x0", # asm + "=&v,v", # constraints: output=vgpr (early-clobber), input=vgpr + has_side_effects=True, + ) + + # --- Bitcast v4i32 -> v8bf16 and store --- + v8bf16_ty = ir.VectorType.get([8], ir.BF16Type.get()) + loaded_bf16 = vector.bitcast(v8bf16_ty, loaded) + + # Store to dst using regular vector.store + vector.store(loaded_bf16, dst, [elem_offset]) + + @flir.jit + def __call__( + self: flir.T.i64, + src: lambda: Textra.memref(S, Textra.bf16()), + dst: lambda: Textra.memref(S, Textra.bf16()), + ): + c1 = arith.index(1) + c32 = arith.index(32) + flir.gpu_ext.LaunchFuncOp( + ["inline_asm_test", "inline_asm_kernel"], + grid_size=(c1, c1, c1), + block_size=(c32, c1, c1), + kernel_operands=[src, dst], + ) + + return _InlineAsmTest() + + +def test_inline_asm_single_load(): + """Test that a single inline asm global_load_b128 works correctly.""" + print(f"\n{'=' * 60}") + print(f"Inline ASM single load test - {gpu_arch}") + print(f"{'=' * 60}") 
+ + N = 256 # 32 threads * 8 elements each + m = create_inline_asm_copy_kernel() + exe = flydsl.compile(m) + + src = torch.randn(N, device="cuda", dtype=torch.bfloat16) + dst = torch.zeros(N, device="cuda", dtype=torch.bfloat16) + + exe(src, dst) + torch.cuda.synchronize() + + # Verify copy + error = torch.max(torch.abs(src.float() - dst.float())).item() + print(f"Max error: {error:.2e}") + assert error < 1e-6, f"Inline asm copy failed: error={error:.2e}" + print("PASS - single inline asm load works!") + return True + + +def create_batched_load_kernel(num_loads=2): + """Kernel that uses inline asm for multiple batched global_load_b128. + + Tests that we can issue multiple loads in a single asm block with + s_clause prefix to force the hardware to batch them. + """ + S = ir.ShapedType.get_dynamic_size() + + class _BatchedLoadTest(flir.MlirModule): + GPU_MODULE_NAME = "batched_load_test" + GPU_MODULE_TARGETS = [f'#rocdl.target'] + + @flir.kernel + def batched_load_kernel( + self: flir.T.i64, + src: lambda: Textra.memref(S, Textra.bf16()), + dst: lambda: Textra.memref(S, Textra.bf16()), + ): + tid = flir.thread_idx("x") + i64_ty = ir.IntegerType.get_signless(64) + ptr_ty = ir.Type.parse("!llvm.ptr") + v4i32_ty = ir.VectorType.get([4], ir.IntegerType.get_signless(32)) + v8bf16_ty = ir.VectorType.get([8], ir.BF16Type.get()) + + # Get base pointer + src_raw = _unwrap(src) + src_ptr_idx = _unwrap( + _std_memref.ExtractAlignedPointerAsIndexOp(src_raw).result + ) + src_base_i64 = _unwrap(_std_arith.IndexCastOp(i64_ty, src_ptr_idx).result) + + # Compute addresses for num_loads chunks, each 16 bytes apart per thread + # Thread t loads from: base + t*16 + chunk*32*16 + # (32 threads, 16 bytes each = 512 bytes per chunk) + chunk_stride = 32 * 16 # bytes per chunk + + addrs = [] + for i in range_constexpr(num_loads): + byte_off_raw = _unwrap( + arith.unwrap(tid * arith.index(16) + arith.index(i * chunk_stride)) + ) + byte_off_i64 = _unwrap( + _std_arith.IndexCastOp(i64_ty, 
byte_off_raw).result + ) + addr_i64 = _unwrap(_std_arith.AddIOp(src_base_i64, byte_off_i64).result) + addr_ptr = _llvm.IntToPtrOp(ptr_ty, addr_i64).result + addrs.append(addr_ptr) + + if num_loads == 2: + # 2 loads with s_clause 1 + # Result type: struct of 2 x v4i32 + # Use llvm.StructType for the result + struct_ty = _llvm.StructType.get_literal([v4i32_ty, v4i32_ty]) + + result = _llvm.inline_asm( + struct_ty, + addrs, + "s_clause 1\n" + "global_load_b128 $0, $2, off\n" + "global_load_b128 $1, $3, off\n" + "s_wait_loadcnt 0x0", + "=&v,=&v,v,v", # early-clobber outputs to force distinct VGPRs + has_side_effects=True, + ) + + # Extract results from struct + loaded_0 = _llvm.ExtractValueOp(v4i32_ty, result, [0]).result + loaded_1 = _llvm.ExtractValueOp(v4i32_ty, result, [1]).result + loaded_list = [loaded_0, loaded_1] + else: + # Fallback: individual loads + loaded_list = [] + for i in range_constexpr(num_loads): + loaded = _llvm.inline_asm( + v4i32_ty, + [addrs[i]], + "global_load_b128 $0, $1, off", + "=v,v", + has_side_effects=True, + ) + loaded_list.append(loaded) + + # Store results + for i in range_constexpr(num_loads): + elem_offset = tid * arith.index(8) + arith.index(i * 32 * 8) + loaded_bf16 = vector.bitcast(v8bf16_ty, loaded_list[i]) + vector.store(loaded_bf16, dst, [elem_offset]) + + @flir.jit + def __call__( + self: flir.T.i64, + src: lambda: Textra.memref(S, Textra.bf16()), + dst: lambda: Textra.memref(S, Textra.bf16()), + ): + c1 = arith.index(1) + c32 = arith.index(32) + flir.gpu_ext.LaunchFuncOp( + ["batched_load_test", "batched_load_kernel"], + grid_size=(c1, c1, c1), + block_size=(c32, c1, c1), + kernel_operands=[src, dst], + ) + + return _BatchedLoadTest() + + +def test_batched_load_2(): + """Test 2 batched loads with s_clause 1.""" + print(f"\n{'=' * 60}") + print(f"Inline ASM batched load (2 loads) test - {gpu_arch}") + print(f"{'=' * 60}") + + N = 2 * 32 * 8 # 2 chunks * 32 threads * 8 elements + m = create_batched_load_kernel(num_loads=2) + 
exe = flydsl.compile(m) + + src = torch.randn(N, device="cuda", dtype=torch.bfloat16) + dst = torch.zeros(N, device="cuda", dtype=torch.bfloat16) + + exe(src, dst) + torch.cuda.synchronize() + + error = torch.max(torch.abs(src.float() - dst.float())).item() + print(f"Max error: {error:.2e}") + assert error < 1e-6, f"Batched load failed: error={error:.2e}" + print("PASS - batched 2-load inline asm works!") + + +if __name__ == "__main__": + test_inline_asm_single_load() + test_batched_load_2() diff --git a/tests/kernels/test_triton_gemm.py b/tests/kernels/test_triton_gemm.py new file mode 100644 index 00000000..6c631cc1 --- /dev/null +++ b/tests/kernels/test_triton_gemm.py @@ -0,0 +1,163 @@ +#!/usr/bin/env python3 +"""Triton GEMM benchmark for RDNA4 comparison.""" + +import torch +import triton +import triton.language as tl +import time + + +@triton.jit +def matmul_kernel( + a_ptr, + b_ptr, + c_ptr, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, +): + pid = tl.program_id(0) + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + num_pid_in_group = GROUP_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M + offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N + offs_k = tl.arange(0, BLOCK_K) + + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - 
k * BLOCK_K, other=0.0) + acc += tl.dot(a, b) + a_ptrs += BLOCK_K * stride_ak + b_ptrs += BLOCK_K * stride_bk + + offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, acc, mask=c_mask) + + +def triton_matmul(a, b, block_m=128, block_n=128, block_k=32, group_m=8): + assert a.shape[1] == b.shape[0] + M, K = a.shape + K, N = b.shape + c = torch.empty((M, N), device=a.device, dtype=torch.float32) + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), + ) + matmul_kernel[grid]( + a, + b, + c, + M, + N, + K, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + BLOCK_M=block_m, + BLOCK_N=block_n, + BLOCK_K=block_k, + GROUP_M=group_m, + ) + return c + + +def benchmark(fn, *args, warmup=10, iters=100, label=""): + for _ in range(warmup): + fn(*args) + torch.cuda.synchronize() + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + for _ in range(iters): + fn(*args) + end_event.record() + torch.cuda.synchronize() + + ms = start_event.elapsed_time(end_event) / iters + return ms + + +if __name__ == "__main__": + torch.manual_seed(42) + + configs = [ + # (BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) + (64, 64, 16, 8), + (64, 64, 32, 8), + (128, 128, 16, 8), + (128, 128, 32, 8), + (128, 128, 64, 8), + (128, 256, 32, 8), + (256, 128, 32, 8), + (256, 256, 32, 8), + ] + + for sz in [1024, 2048, 4096]: + print(f"\n{'=' * 70}") + print(f" GEMM {sz}x{sz}x{sz}, bf16 -> f32") + print(f"{'=' * 70}") + + A = torch.randn(sz, sz, device="cuda", dtype=torch.bfloat16) + B = torch.randn(sz, sz, device="cuda", dtype=torch.bfloat16) + flops = 2 * sz**3 + + # PyTorch reference + ms = benchmark(lambda a, b: torch.mm(a, b), A, B, label="pytorch") + 
print( + f" PyTorch torch.mm: {ms:.3f} ms {flops / (ms / 1000) / 1e12:.2f} TFLOPS" + ) + + # Triton configs + for bm, bn, bk, gm in configs: + try: + # Verify correctness first + C_tri = triton_matmul( + A, B, block_m=bm, block_n=bn, block_k=bk, group_m=gm + ) + C_ref = A.float() @ B.float() + err = (C_tri - C_ref).abs().max().item() / C_ref.abs().max().item() + if err > 0.1: + print(f" Triton {bm}x{bn}x{bk} g{gm}: INCORRECT err={err:.2e}") + continue + + ms = benchmark( + lambda a, b: triton_matmul( + a, b, block_m=bm, block_n=bn, block_k=bk, group_m=gm + ), + A, + B, + ) + tflops = flops / (ms / 1000) / 1e12 + print( + f" Triton {bm:3d}x{bn:3d}x{bk:2d} g{gm}: {ms:.3f} ms {tflops:.2f} TFLOPS err={err:.1e}" + ) + except Exception as e: + print(f" Triton {bm:3d}x{bn:3d}x{bk:2d} g{gm}: FAILED - {e}") diff --git a/tests/kernels/test_wmma_basic.py b/tests/kernels/test_wmma_basic.py new file mode 100644 index 00000000..9891f93d --- /dev/null +++ b/tests/kernels/test_wmma_basic.py @@ -0,0 +1,191 @@ +#!/usr/bin/env python3 +"""Basic WMMA test -- single 16x16x16 tile on gfx1201 (RDNA4, wave32). + +Validates that the v_wmma_f32_16x16x16_f16 instruction works correctly +on Radeon 9700 (gfx1201) by running a single-wave 16x16 matmul. + +WMMA data layout for wave32 (32 lanes, 8 elements per lane): + For lane t (0..31), let lane16 = t % 16, base8 = (t / 16) * 8. + + A (16x16 f16): "row-of-cols" layout + Lane t holds A[lane16][base8 + i] for i in 0..7 + (each lane loads 8 consecutive columns from one row) + + B (16x16 f16): "col-of-rows" layout + Lane t holds B[base8 + i][lane16] for i in 0..7 + (each lane loads 8 consecutive rows from one column) + + D (16x16 f32): "col-of-rows" layout (same as B) + Lane t holds D[base8 + i][lane16] for i in 0..7 + + Verified empirically on gfx1201 (Radeon 9700, RDNA4). 
+""" + +import sys +import os +import numpy as np +import pytest +import torch + +import flydsl +from flydsl.dialects.ext import flir, arith, memref, vector, rocdl +from flydsl.dialects.ext.python_control_flow import range_constexpr +from flydsl.runtime.device import get_rocm_arch +from flydsl.lang.ir.types import T +from _mlir import ir +import _mlir.extras.types as Textra + +if not torch.cuda.is_available(): + pytest.skip("CUDA/ROCm not available", allow_module_level=True) + + +def create_wmma_kernel(): + """Create a minimal WMMA 16x16x16 f16 matmul kernel. + + Layout: C[16,16] = A[16,16] @ B[16,16] (all f16, accumulate in f32) + + Uses row-major A and B in global memory. + Each thread (lane) loads its 8 elements of A and B based on the WMMA lane + mapping, executes one WMMA instruction, and stores the 8 result elements. + """ + gpu_arch = get_rocm_arch() + S = ir.ShapedType.get_dynamic_size() + + class _WmmaBasic(flir.MlirModule): + GPU_MODULE_NAME = "wmma_test" + GPU_MODULE_TARGETS = [f'#rocdl.target'] + + @flir.kernel + def wmma_kernel( + self: flir.T.i64, + A: lambda: Textra.memref(S, S, Textra.f16()), # [M, K] row-major f16 + B: lambda: Textra.memref(S, S, Textra.f16()), # [K, N] row-major f16 + C: lambda: Textra.memref(S, S, Textra.f32()), # [M, N] row-major f32 + ): + # WMMA vector types for wave32 + v8f16_ty = T.vec(8, T.f16) + v8f32_ty = T.vec(8, T.f32) + + # Thread ID within the wave (0..31) + tid = flir.thread_idx("x") + + # Decompose lane ID for WMMA mapping: + # lane16 = tid % 16 (0..15) + # base8 = (tid / 16) * 8 (0 or 8) + c16 = arith.index(16) + c8 = arith.index(8) + lane16 = tid % c16 # 0..15 + base8 = (tid // c16) * c8 # 0 or 8 + + # WMMA 16x16x16 wave32 data layout (empirically verified on gfx1201): + # A: "row-of-cols" -- lane t loads A[t%16][(t/16)*8 + i] for i in 0..7 + # B: "col-of-rows" -- lane t loads B[(t/16)*8 + i][t%16] for i in 0..7 + # D: "col-of-rows" -- lane t holds D[(t/16)*8 + i][t%16] for i in 0..7 + # + # lane16 = tid % 16, 
base8 = (tid / 16) * 8 + + # Load A: row-of-cols layout => A[lane16][base8 + i] + a_elems = [] + for i in range_constexpr(8): + ci = arith.index(i) + a_val = memref.load(A, [lane16, base8 + ci]) + a_elems.append(a_val) + a_vec = vector.from_elements(v8f16_ty, a_elems) + + # Load B: col-of-rows layout => B[base8 + i][lane16] + b_elems = [] + for i in range_constexpr(8): + ci = arith.index(i) + b_val = memref.load(B, [base8 + ci, lane16]) + b_elems.append(b_val) + b_vec = vector.from_elements(v8f16_ty, b_elems) + + # Initialize accumulator to zero + zero_acc = arith.constant_vector(0.0, v8f32_ty) + + # Execute WMMA: D = A * B + C + d_vec = rocdl.wmma_f32_16x16x16_f16( + v8f32_ty, + [arith.unwrap(a_vec), arith.unwrap(b_vec), arith.unwrap(zero_acc)], + ) + + # Store D: col-of-rows layout => C[base8 + i][lane16] + for i in range_constexpr(8): + ci = arith.index(i) + val = vector.extract(d_vec, static_position=[i], dynamic_position=[]) + memref.store(val, C, [base8 + ci, lane16]) + + @flir.jit + def __call__( + self: flir.T.i64, + A: lambda: Textra.memref(S, S, Textra.f16()), + B: lambda: Textra.memref(S, S, Textra.f16()), + C: lambda: Textra.memref(S, S, Textra.f32()), + ): + c1 = arith.index(1) + c32 = arith.index(32) # one wave = 32 threads + flir.gpu_ext.LaunchFuncOp( + ["wmma_test", "wmma_kernel"], + grid_size=(c1, c1, c1), + block_size=(c32, c1, c1), + kernel_operands=[A, B, C], + ) + + return _WmmaBasic() + + +def test_wmma_basic(): + """Test single WMMA 16x16x16 f16 matmul on gfx1201.""" + gpu_arch = get_rocm_arch() + print(f"\n{'=' * 60}") + print(f"WMMA Basic Test - {gpu_arch}") + print(f"{'=' * 60}") + + if not gpu_arch.startswith("gfx12"): + pytest.skip(f"WMMA test requires RDNA4 (gfx12xx), got {gpu_arch}") + + # Create and compile kernel + print("Creating WMMA kernel...") + m = create_wmma_kernel() + print("Compiling...") + exe = flydsl.compile(m) + + # Prepare data - use small values to keep f16 precision reasonable + np.random.seed(42) + a_np = 
np.random.randn(16, 16).astype(np.float16) * 0.1 + b_np = np.random.randn(16, 16).astype(np.float16) * 0.1 + + # Reference: f16 matmul accumulated in f32 + expected = a_np.astype(np.float32) @ b_np.astype(np.float32) + + A = torch.tensor(a_np, device="cuda", dtype=torch.float16) + B = torch.tensor(b_np, device="cuda", dtype=torch.float16) + C = torch.zeros(16, 16, device="cuda", dtype=torch.float32) + + # Run kernel + print("Launching kernel (1 workgroup, 32 threads)...") + exe(A, B, C) + torch.cuda.synchronize() + + c_host = C.cpu().numpy() + + # Verify + error = np.max(np.abs(c_host - expected)) + rel_error = error / (np.max(np.abs(expected)) + 1e-8) + + print(f"Max absolute error: {error:.2e}") + print(f"Max relative error: {rel_error:.2e}") + print(f"Output sample (top-left 4x4):") + print(c_host[:4, :4]) + print(f"Expected sample:") + print(expected[:4, :4]) + + # f16 matmul with 16 accumulation steps: expect ~1e-3 error + assert rel_error < 1e-2, ( + f"WMMA result too far from reference: rel_error={rel_error:.2e}" + ) + print(f"\nPASS - WMMA basic test succeeded!") + + +if __name__ == "__main__": + test_wmma_basic() diff --git a/tests/kernels/test_wmma_gemm.py b/tests/kernels/test_wmma_gemm.py new file mode 100644 index 00000000..a95b7d6e --- /dev/null +++ b/tests/kernels/test_wmma_gemm.py @@ -0,0 +1,186 @@ +#!/usr/bin/env python3 +"""Test for the optimized WMMA GEMM kernel on RDNA4 (gfx12xx). 
+ +Tests: + - Correctness at multiple sizes + - Performance benchmarks at 1024, 2048, 4096 vs PyTorch +""" + +import sys +import os +import numpy as np +import pytest +import torch + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../..")) + +import flydsl +from flydsl.runtime.device import get_rocm_arch + +if not torch.cuda.is_available(): + pytest.skip("CUDA/ROCm not available", allow_module_level=True) + +gpu_arch = get_rocm_arch() +if not gpu_arch.startswith("gfx12"): + pytest.skip( + f"WMMA GEMM test requires RDNA4 (gfx12xx), got {gpu_arch}", + allow_module_level=True, + ) + + +from kernels.wmma_gemm import create_wmma_gemm_module, BLOCK_M, BLOCK_N, BLOCK_K + + +# All shapes must be multiples of BLOCK_M=128, BLOCK_N=128, BLOCK_K=32 +TEST_SHAPES = [ + (128, 128, 32), # 1 workgroup, 1 K tile (smallest valid) + (128, 128, 128), # 1 workgroup, 4 K tiles + (256, 256, 256), # 4 workgroups + (512, 512, 512), # medium +] + + +@pytest.mark.parametrize("M,N,K", TEST_SHAPES) +def test_wmma_gemm_bf16_f32(M, N, K): + """Test WMMA GEMM with bf16 inputs, f32 output.""" + print(f"\n{'=' * 60}") + print(f"WMMA GEMM Test: M={M}, N={N}, K={K}, in=bf16, out=f32") + print(f"GPU: {gpu_arch}, BLOCK_M={BLOCK_M}, BLOCK_N={BLOCK_N}, BLOCK_K={BLOCK_K}") + print(f"{'=' * 60}") + + m = create_wmma_gemm_module(M, N, K, in_dtype="bf16", out_dtype="f32") + exe = flydsl.compile(m) + + np.random.seed(42) + A = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) * 0.1 + B = torch.randn(K, N, device="cuda", dtype=torch.bfloat16) * 0.1 + C = torch.zeros(M, N, device="cuda", dtype=torch.float32) + + expected = A.float() @ B.float() + + exe(A, B, C) + torch.cuda.synchronize() + + c_host = C.cpu() + e_host = expected.cpu() + error = torch.max(torch.abs(c_host - e_host)).item() + rel_error = error / (torch.max(torch.abs(e_host)).item() + 1e-8) + + print(f"Max absolute error: {error:.2e}") + print(f"Max relative error: {rel_error:.2e}") + + tol = 0.05 + assert rel_error < tol, f"WMMA 
GEMM error too high: rel_error={rel_error:.2e}" + print("PASS") + + +@pytest.mark.parametrize("M,N,K", [(128, 128, 32), (128, 128, 128)]) +def test_wmma_gemm_f16_f32(M, N, K): + """Test WMMA GEMM with f16 inputs, f32 output.""" + print(f"\n{'=' * 60}") + print(f"WMMA GEMM Test: M={M}, N={N}, K={K}, in=f16, out=f32") + print(f"{'=' * 60}") + + m = create_wmma_gemm_module(M, N, K, in_dtype="f16", out_dtype="f32") + exe = flydsl.compile(m) + + A = torch.randn(M, K, device="cuda", dtype=torch.float16) * 0.1 + B = torch.randn(K, N, device="cuda", dtype=torch.float16) * 0.1 + C = torch.zeros(M, N, device="cuda", dtype=torch.float32) + + expected = A.float() @ B.float() + + exe(A, B, C) + torch.cuda.synchronize() + + c_host = C.cpu() + e_host = expected.cpu() + error = torch.max(torch.abs(c_host - e_host)).item() + rel_error = error / (torch.max(torch.abs(e_host)).item() + 1e-8) + + print(f"Max absolute error: {error:.2e}") + print(f"Max relative error: {rel_error:.2e}") + + tol = 0.02 + assert rel_error < tol, f"WMMA GEMM error too high: rel_error={rel_error:.2e}" + print("PASS") + + +def _run_benchmark(M, N, K, in_dtype="bf16"): + """Run benchmark at given size and return (our_tflops, pt_tflops).""" + import time + + print(f"\n{'=' * 60}") + print(f"WMMA GEMM Benchmark: M={M}, N={N}, K={K}, in={in_dtype}, out=f32") + print(f"GPU: {gpu_arch}") + print(f"{'=' * 60}") + + torch_dtype = torch.bfloat16 if in_dtype == "bf16" else torch.float16 + m = create_wmma_gemm_module(M, N, K, in_dtype=in_dtype, out_dtype="f32") + exe = flydsl.compile(m) + + A = torch.randn(M, K, device="cuda", dtype=torch_dtype) * 0.01 + B = torch.randn(K, N, device="cuda", dtype=torch_dtype) * 0.01 + C = torch.zeros(M, N, device="cuda", dtype=torch.float32) + + # Warmup + for _ in range(5): + exe(A, B, C) + torch.cuda.synchronize() + + # Benchmark + num_iters = 50 + torch.cuda.synchronize() + start = time.perf_counter() + for _ in range(num_iters): + exe(A, B, C) + torch.cuda.synchronize() + elapsed = 
time.perf_counter() - start + + avg_ms = (elapsed / num_iters) * 1000 + flops = 2 * M * N * K + tflops = flops / (avg_ms / 1000) / 1e12 + + print(f"Average time: {avg_ms:.3f} ms") + print(f"Throughput: {tflops:.2f} TFLOPS") + + # Verify correctness + expected = A.float() @ B.float() + c_host = C.cpu() + e_host = expected.cpu() + error = torch.max(torch.abs(c_host - e_host)).item() + rel_error = error / (torch.max(torch.abs(e_host)).item() + 1e-8) + print(f"Max relative error: {rel_error:.2e}") + assert rel_error < 0.1, f"Benchmark result incorrect: rel_error={rel_error:.2e}" + + # PyTorch reference + torch.cuda.synchronize() + start = time.perf_counter() + for _ in range(num_iters): + _ = A @ B + torch.cuda.synchronize() + pt_elapsed = time.perf_counter() - start + pt_avg_ms = (pt_elapsed / num_iters) * 1000 + pt_tflops = flops / (pt_avg_ms / 1000) / 1e12 + print(f"PyTorch bf16 matmul: {pt_avg_ms:.3f} ms, {pt_tflops:.2f} TFLOPS") + print(f"Efficiency vs PyTorch: {tflops / pt_tflops * 100:.1f}%") + + return tflops, pt_tflops + + +def test_wmma_gemm_benchmark(): + """Benchmark WMMA GEMM at 1024.""" + _run_benchmark(1024, 1024, 1024) + + +def test_wmma_gemm_benchmark_large(): + """Benchmark WMMA GEMM at larger sizes.""" + for size in [2048, 4096]: + _run_benchmark(size, size, size) + + +if __name__ == "__main__": + test_wmma_gemm_bf16_f32(128, 128, 32) + test_wmma_gemm_bf16_f32(128, 128, 128) + test_wmma_gemm_benchmark() + test_wmma_gemm_benchmark_large()