diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 74411c0..769351e 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -38,6 +38,9 @@ jobs:
- name: Verify WMMA layouts
run: python verification/verify_wmma_layouts.py --max-diffs 5
+ - name: Verify MFMA layouts
+ run: python verification/verify_mfma_layouts.py --arch all --max-diffs 5
+
- name: Run tests
run: npm test -- --run
diff --git a/.gitignore b/.gitignore
index 879d08c..d5528b1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -48,7 +48,7 @@ vite.config.ts.timestamp-*
# Verification artifacts
verification/results/*.json
verification/__pycache__/
-verification/lib/verification-helper.mjs
+verification/lib/*.mjs
# Project-specific ignores
CLAUDE.md
diff --git a/index.html b/index.html
index 4e15dce..a8361e1 100644
--- a/index.html
+++ b/index.html
@@ -17,6 +17,7 @@
GPU Tensor Layout Visualizer
+
@@ -172,6 +173,64 @@ Operand
+
+
diff --git a/src/data/mfma-instruction-data.ts b/src/data/mfma-instruction-data.ts
new file mode 100644
index 0000000..bdc8e4c
--- /dev/null
+++ b/src/data/mfma-instruction-data.ts
@@ -0,0 +1,976 @@
+// Auto-generated from amd_matrix_instruction_calculator
+// Do not edit manually unless you intentionally update the source data.
+
+export interface MFMAInstructionEntry { // one row of the MFMA instruction table below
+ name: string // full instruction name, e.g. 'v_mfma_f32_16x16x16f16'
+ mnemonic: string // the name without its 'v_mfma_' prefix
+ m: number // M, N and K extents as encoded in the instruction name (MxNxK)
+ n: number
+ k: number
+ kBase: number // read by MFMALayout as the K width of the register part of operands A/B
+ inputTypeA: string // element type tags such as 'fp16', 'bf16', 'int8'
+ inputTypeB: string
+ outputType: string
+ laneGroupsShareDimK?: boolean // when true, MFMALayout wraps K coordinates modulo k
+ registerBlocks?: number // block count; consumers treat a missing value as 1
+}
+
+export const MFMA_INSTRUCTION_DATA: Record<"cdna1" | "cdna2" | "cdna3", MFMAInstructionEntry[]> = {
+ cdna1: [
+ {
+ name: 'v_mfma_f32_16x16x16f16',
+ mnemonic: 'f32_16x16x16f16',
+ m: 16,
+ n: 16,
+ k: 16,
+ kBase: 4,
+ inputTypeA: 'fp16',
+ inputTypeB: 'fp16',
+ outputType: 'fp32',
+ },
+ {
+ name: 'v_mfma_f32_16x16x1f32',
+ mnemonic: 'f32_16x16x1f32',
+ m: 16,
+ n: 16,
+ k: 1,
+ kBase: 1,
+ inputTypeA: 'fp32',
+ inputTypeB: 'fp32',
+ outputType: 'fp32',
+ laneGroupsShareDimK: true,
+ registerBlocks: 4,
+ },
+ {
+ name: 'v_mfma_f32_16x16x2bf16',
+ mnemonic: 'f32_16x16x2bf16',
+ m: 16,
+ n: 16,
+ k: 2,
+ kBase: 2,
+ inputTypeA: 'bf16',
+ inputTypeB: 'bf16',
+ outputType: 'fp32',
+ laneGroupsShareDimK: true,
+ registerBlocks: 4,
+ },
+ {
+ name: 'v_mfma_f32_16x16x4f16',
+ mnemonic: 'f32_16x16x4f16',
+ m: 16,
+ n: 16,
+ k: 4,
+ kBase: 4,
+ inputTypeA: 'fp16',
+ inputTypeB: 'fp16',
+ outputType: 'fp32',
+ laneGroupsShareDimK: true,
+ registerBlocks: 4,
+ },
+ {
+ name: 'v_mfma_f32_16x16x4f32',
+ mnemonic: 'f32_16x16x4f32',
+ m: 16,
+ n: 16,
+ k: 4,
+ kBase: 1,
+ inputTypeA: 'fp32',
+ inputTypeB: 'fp32',
+ outputType: 'fp32',
+ },
+ {
+ name: 'v_mfma_f32_16x16x8bf16',
+ mnemonic: 'f32_16x16x8bf16',
+ m: 16,
+ n: 16,
+ k: 8,
+ kBase: 2,
+ inputTypeA: 'bf16',
+ inputTypeB: 'bf16',
+ outputType: 'fp32',
+ },
+ {
+ name: 'v_mfma_f32_32x32x1f32',
+ mnemonic: 'f32_32x32x1f32',
+ m: 32,
+ n: 32,
+ k: 1,
+ kBase: 1,
+ inputTypeA: 'fp32',
+ inputTypeB: 'fp32',
+ outputType: 'fp32',
+ laneGroupsShareDimK: true,
+ registerBlocks: 2,
+ },
+ {
+ name: 'v_mfma_f32_32x32x2bf16',
+ mnemonic: 'f32_32x32x2bf16',
+ m: 32,
+ n: 32,
+ k: 2,
+ kBase: 2,
+ inputTypeA: 'bf16',
+ inputTypeB: 'bf16',
+ outputType: 'fp32',
+ laneGroupsShareDimK: true,
+ registerBlocks: 2,
+ },
+ {
+ name: 'v_mfma_f32_32x32x2f32',
+ mnemonic: 'f32_32x32x2f32',
+ m: 32,
+ n: 32,
+ k: 2,
+ kBase: 1,
+ inputTypeA: 'fp32',
+ inputTypeB: 'fp32',
+ outputType: 'fp32',
+ },
+ {
+ name: 'v_mfma_f32_32x32x4bf16',
+ mnemonic: 'f32_32x32x4bf16',
+ m: 32,
+ n: 32,
+ k: 4,
+ kBase: 2,
+ inputTypeA: 'bf16',
+ inputTypeB: 'bf16',
+ outputType: 'fp32',
+ },
+ {
+ name: 'v_mfma_f32_32x32x4f16',
+ mnemonic: 'f32_32x32x4f16',
+ m: 32,
+ n: 32,
+ k: 4,
+ kBase: 4,
+ inputTypeA: 'fp16',
+ inputTypeB: 'fp16',
+ outputType: 'fp32',
+ laneGroupsShareDimK: true,
+ registerBlocks: 2,
+ },
+ {
+ name: 'v_mfma_f32_32x32x8f16',
+ mnemonic: 'f32_32x32x8f16',
+ m: 32,
+ n: 32,
+ k: 8,
+ kBase: 4,
+ inputTypeA: 'fp16',
+ inputTypeB: 'fp16',
+ outputType: 'fp32',
+ },
+ {
+ name: 'v_mfma_f32_4x4x1f32',
+ mnemonic: 'f32_4x4x1f32',
+ m: 4,
+ n: 4,
+ k: 1,
+ kBase: 1,
+ inputTypeA: 'fp32',
+ inputTypeB: 'fp32',
+ outputType: 'fp32',
+ laneGroupsShareDimK: true,
+ registerBlocks: 16,
+ },
+ {
+ name: 'v_mfma_f32_4x4x2bf16',
+ mnemonic: 'f32_4x4x2bf16',
+ m: 4,
+ n: 4,
+ k: 2,
+ kBase: 2,
+ inputTypeA: 'bf16',
+ inputTypeB: 'bf16',
+ outputType: 'fp32',
+ laneGroupsShareDimK: true,
+ registerBlocks: 16,
+ },
+ {
+ name: 'v_mfma_f32_4x4x4f16',
+ mnemonic: 'f32_4x4x4f16',
+ m: 4,
+ n: 4,
+ k: 4,
+ kBase: 4,
+ inputTypeA: 'fp16',
+ inputTypeB: 'fp16',
+ outputType: 'fp32',
+ laneGroupsShareDimK: true,
+ registerBlocks: 16,
+ },
+ {
+ name: 'v_mfma_i32_16x16x16i8',
+ mnemonic: 'i32_16x16x16i8',
+ m: 16,
+ n: 16,
+ k: 16,
+ kBase: 4,
+ inputTypeA: 'int8',
+ inputTypeB: 'int8',
+ outputType: 'int32',
+ },
+ {
+ name: 'v_mfma_i32_16x16x4i8',
+ mnemonic: 'i32_16x16x4i8',
+ m: 16,
+ n: 16,
+ k: 4,
+ kBase: 4,
+ inputTypeA: 'int8',
+ inputTypeB: 'int8',
+ outputType: 'int32',
+ laneGroupsShareDimK: true,
+ registerBlocks: 4,
+ },
+ {
+ name: 'v_mfma_i32_32x32x4i8',
+ mnemonic: 'i32_32x32x4i8',
+ m: 32,
+ n: 32,
+ k: 4,
+ kBase: 4,
+ inputTypeA: 'int8',
+ inputTypeB: 'int8',
+ outputType: 'int32',
+ laneGroupsShareDimK: true,
+ registerBlocks: 2,
+ },
+ {
+ name: 'v_mfma_i32_32x32x8i8',
+ mnemonic: 'i32_32x32x8i8',
+ m: 32,
+ n: 32,
+ k: 8,
+ kBase: 4,
+ inputTypeA: 'int8',
+ inputTypeB: 'int8',
+ outputType: 'int32',
+ },
+ {
+ name: 'v_mfma_i32_4x4x4i8',
+ mnemonic: 'i32_4x4x4i8',
+ m: 4,
+ n: 4,
+ k: 4,
+ kBase: 4,
+ inputTypeA: 'int8',
+ inputTypeB: 'int8',
+ outputType: 'int32',
+ laneGroupsShareDimK: true,
+ registerBlocks: 16,
+ },
+ ],
+ cdna2: [
+ {
+ name: 'v_mfma_f32_16x16x16bf16_1k',
+ mnemonic: 'f32_16x16x16bf16_1k',
+ m: 16,
+ n: 16,
+ k: 16,
+ kBase: 4,
+ inputTypeA: 'bf16',
+ inputTypeB: 'bf16',
+ outputType: 'fp32',
+ },
+ {
+ name: 'v_mfma_f32_16x16x16f16',
+ mnemonic: 'f32_16x16x16f16',
+ m: 16,
+ n: 16,
+ k: 16,
+ kBase: 4,
+ inputTypeA: 'fp16',
+ inputTypeB: 'fp16',
+ outputType: 'fp32',
+ },
+ {
+ name: 'v_mfma_f32_16x16x1f32',
+ mnemonic: 'f32_16x16x1f32',
+ m: 16,
+ n: 16,
+ k: 1,
+ kBase: 1,
+ inputTypeA: 'fp32',
+ inputTypeB: 'fp32',
+ outputType: 'fp32',
+ laneGroupsShareDimK: true,
+ registerBlocks: 4,
+ },
+ {
+ name: 'v_mfma_f32_16x16x2bf16',
+ mnemonic: 'f32_16x16x2bf16',
+ m: 16,
+ n: 16,
+ k: 2,
+ kBase: 2,
+ inputTypeA: 'bf16',
+ inputTypeB: 'bf16',
+ outputType: 'fp32',
+ laneGroupsShareDimK: true,
+ registerBlocks: 4,
+ },
+ {
+ name: 'v_mfma_f32_16x16x4bf16_1k',
+ mnemonic: 'f32_16x16x4bf16_1k',
+ m: 16,
+ n: 16,
+ k: 4,
+ kBase: 4,
+ inputTypeA: 'bf16',
+ inputTypeB: 'bf16',
+ outputType: 'fp32',
+ laneGroupsShareDimK: true,
+ registerBlocks: 4,
+ },
+ {
+ name: 'v_mfma_f32_16x16x4f16',
+ mnemonic: 'f32_16x16x4f16',
+ m: 16,
+ n: 16,
+ k: 4,
+ kBase: 4,
+ inputTypeA: 'fp16',
+ inputTypeB: 'fp16',
+ outputType: 'fp32',
+ laneGroupsShareDimK: true,
+ registerBlocks: 4,
+ },
+ {
+ name: 'v_mfma_f32_16x16x4f32',
+ mnemonic: 'f32_16x16x4f32',
+ m: 16,
+ n: 16,
+ k: 4,
+ kBase: 1,
+ inputTypeA: 'fp32',
+ inputTypeB: 'fp32',
+ outputType: 'fp32',
+ },
+ {
+ name: 'v_mfma_f32_16x16x8bf16',
+ mnemonic: 'f32_16x16x8bf16',
+ m: 16,
+ n: 16,
+ k: 8,
+ kBase: 2,
+ inputTypeA: 'bf16',
+ inputTypeB: 'bf16',
+ outputType: 'fp32',
+ },
+ {
+ name: 'v_mfma_f32_32x32x1f32',
+ mnemonic: 'f32_32x32x1f32',
+ m: 32,
+ n: 32,
+ k: 1,
+ kBase: 1,
+ inputTypeA: 'fp32',
+ inputTypeB: 'fp32',
+ outputType: 'fp32',
+ laneGroupsShareDimK: true,
+ registerBlocks: 2,
+ },
+ {
+ name: 'v_mfma_f32_32x32x2bf16',
+ mnemonic: 'f32_32x32x2bf16',
+ m: 32,
+ n: 32,
+ k: 2,
+ kBase: 2,
+ inputTypeA: 'bf16',
+ inputTypeB: 'bf16',
+ outputType: 'fp32',
+ laneGroupsShareDimK: true,
+ registerBlocks: 2,
+ },
+ {
+ name: 'v_mfma_f32_32x32x2f32',
+ mnemonic: 'f32_32x32x2f32',
+ m: 32,
+ n: 32,
+ k: 2,
+ kBase: 1,
+ inputTypeA: 'fp32',
+ inputTypeB: 'fp32',
+ outputType: 'fp32',
+ },
+ {
+ name: 'v_mfma_f32_32x32x4bf16',
+ mnemonic: 'f32_32x32x4bf16',
+ m: 32,
+ n: 32,
+ k: 4,
+ kBase: 2,
+ inputTypeA: 'bf16',
+ inputTypeB: 'bf16',
+ outputType: 'fp32',
+ },
+ {
+ name: 'v_mfma_f32_32x32x4bf16_1k',
+ mnemonic: 'f32_32x32x4bf16_1k',
+ m: 32,
+ n: 32,
+ k: 4,
+ kBase: 4,
+ inputTypeA: 'bf16',
+ inputTypeB: 'bf16',
+ outputType: 'fp32',
+ laneGroupsShareDimK: true,
+ registerBlocks: 2,
+ },
+ {
+ name: 'v_mfma_f32_32x32x4f16',
+ mnemonic: 'f32_32x32x4f16',
+ m: 32,
+ n: 32,
+ k: 4,
+ kBase: 4,
+ inputTypeA: 'fp16',
+ inputTypeB: 'fp16',
+ outputType: 'fp32',
+ laneGroupsShareDimK: true,
+ registerBlocks: 2,
+ },
+ {
+ name: 'v_mfma_f32_32x32x8bf16_1k',
+ mnemonic: 'f32_32x32x8bf16_1k',
+ m: 32,
+ n: 32,
+ k: 8,
+ kBase: 4,
+ inputTypeA: 'bf16',
+ inputTypeB: 'bf16',
+ outputType: 'fp32',
+ },
+ {
+ name: 'v_mfma_f32_32x32x8f16',
+ mnemonic: 'f32_32x32x8f16',
+ m: 32,
+ n: 32,
+ k: 8,
+ kBase: 4,
+ inputTypeA: 'fp16',
+ inputTypeB: 'fp16',
+ outputType: 'fp32',
+ },
+ {
+ name: 'v_mfma_f32_4x4x1f32',
+ mnemonic: 'f32_4x4x1f32',
+ m: 4,
+ n: 4,
+ k: 1,
+ kBase: 1,
+ inputTypeA: 'fp32',
+ inputTypeB: 'fp32',
+ outputType: 'fp32',
+ laneGroupsShareDimK: true,
+ registerBlocks: 16,
+ },
+ {
+ name: 'v_mfma_f32_4x4x2bf16',
+ mnemonic: 'f32_4x4x2bf16',
+ m: 4,
+ n: 4,
+ k: 2,
+ kBase: 2,
+ inputTypeA: 'bf16',
+ inputTypeB: 'bf16',
+ outputType: 'fp32',
+ laneGroupsShareDimK: true,
+ registerBlocks: 16,
+ },
+ {
+ name: 'v_mfma_f32_4x4x4bf16_1k',
+ mnemonic: 'f32_4x4x4bf16_1k',
+ m: 4,
+ n: 4,
+ k: 4,
+ kBase: 4,
+ inputTypeA: 'bf16',
+ inputTypeB: 'bf16',
+ outputType: 'fp32',
+ laneGroupsShareDimK: true,
+ registerBlocks: 16,
+ },
+ {
+ name: 'v_mfma_f32_4x4x4f16',
+ mnemonic: 'f32_4x4x4f16',
+ m: 4,
+ n: 4,
+ k: 4,
+ kBase: 4,
+ inputTypeA: 'fp16',
+ inputTypeB: 'fp16',
+ outputType: 'fp32',
+ laneGroupsShareDimK: true,
+ registerBlocks: 16,
+ },
+ {
+ name: 'v_mfma_f64_16x16x4f64',
+ mnemonic: 'f64_16x16x4f64',
+ m: 16,
+ n: 16,
+ k: 4,
+ kBase: 1,
+ inputTypeA: 'fp64',
+ inputTypeB: 'fp64',
+ outputType: 'fp64',
+ },
+ {
+ name: 'v_mfma_f64_4x4x4f64',
+ mnemonic: 'f64_4x4x4f64',
+ m: 4,
+ n: 4,
+ k: 4,
+ kBase: 1,
+ inputTypeA: 'fp64',
+ inputTypeB: 'fp64',
+ outputType: 'fp64',
+ laneGroupsShareDimK: true,
+ registerBlocks: 4,
+ },
+ {
+ name: 'v_mfma_i32_16x16x16i8',
+ mnemonic: 'i32_16x16x16i8',
+ m: 16,
+ n: 16,
+ k: 16,
+ kBase: 4,
+ inputTypeA: 'int8',
+ inputTypeB: 'int8',
+ outputType: 'int32',
+ },
+ {
+ name: 'v_mfma_i32_16x16x4i8',
+ mnemonic: 'i32_16x16x4i8',
+ m: 16,
+ n: 16,
+ k: 4,
+ kBase: 4,
+ inputTypeA: 'int8',
+ inputTypeB: 'int8',
+ outputType: 'int32',
+ laneGroupsShareDimK: true,
+ registerBlocks: 4,
+ },
+ {
+ name: 'v_mfma_i32_32x32x4i8',
+ mnemonic: 'i32_32x32x4i8',
+ m: 32,
+ n: 32,
+ k: 4,
+ kBase: 4,
+ inputTypeA: 'int8',
+ inputTypeB: 'int8',
+ outputType: 'int32',
+ laneGroupsShareDimK: true,
+ registerBlocks: 2,
+ },
+ {
+ name: 'v_mfma_i32_32x32x8i8',
+ mnemonic: 'i32_32x32x8i8',
+ m: 32,
+ n: 32,
+ k: 8,
+ kBase: 4,
+ inputTypeA: 'int8',
+ inputTypeB: 'int8',
+ outputType: 'int32',
+ },
+ {
+ name: 'v_mfma_i32_4x4x4i8',
+ mnemonic: 'i32_4x4x4i8',
+ m: 4,
+ n: 4,
+ k: 4,
+ kBase: 4,
+ inputTypeA: 'int8',
+ inputTypeB: 'int8',
+ outputType: 'int32',
+ laneGroupsShareDimK: true,
+ registerBlocks: 16,
+ },
+ ],
+ cdna3: [
+ {
+ name: 'v_mfma_f32_16x16x16_bf16',
+ mnemonic: 'f32_16x16x16_bf16',
+ m: 16,
+ n: 16,
+ k: 16,
+ kBase: 4,
+ inputTypeA: 'bf16',
+ inputTypeB: 'bf16',
+ outputType: 'fp32',
+ },
+ {
+ name: 'v_mfma_f32_16x16x16_f16',
+ mnemonic: 'f32_16x16x16_f16',
+ m: 16,
+ n: 16,
+ k: 16,
+ kBase: 4,
+ inputTypeA: 'fp16',
+ inputTypeB: 'fp16',
+ outputType: 'fp32',
+ },
+ {
+ name: 'v_mfma_f32_16x16x1_4b_f32',
+ mnemonic: 'f32_16x16x1_4b_f32',
+ m: 16,
+ n: 16,
+ k: 1,
+ kBase: 1,
+ inputTypeA: 'fp32',
+ inputTypeB: 'fp32',
+ outputType: 'fp32',
+ laneGroupsShareDimK: true,
+ registerBlocks: 4,
+ },
+ {
+ name: 'v_mfma_f32_16x16x32_bf8_bf8',
+ mnemonic: 'f32_16x16x32_bf8_bf8',
+ m: 16,
+ n: 16,
+ k: 32,
+ kBase: 8,
+ inputTypeA: 'bf8',
+ inputTypeB: 'bf8',
+ outputType: 'fp32',
+ },
+ {
+ name: 'v_mfma_f32_16x16x32_bf8_fp8',
+ mnemonic: 'f32_16x16x32_bf8_fp8',
+ m: 16,
+ n: 16,
+ k: 32,
+ kBase: 8,
+ inputTypeA: 'bf8',
+ inputTypeB: 'fp8',
+ outputType: 'fp32',
+ },
+ {
+ name: 'v_mfma_f32_16x16x32_fp8_bf8',
+ mnemonic: 'f32_16x16x32_fp8_bf8',
+ m: 16,
+ n: 16,
+ k: 32,
+ kBase: 8,
+ inputTypeA: 'fp8',
+ inputTypeB: 'bf8',
+ outputType: 'fp32',
+ },
+ {
+ name: 'v_mfma_f32_16x16x32_fp8_fp8',
+ mnemonic: 'f32_16x16x32_fp8_fp8',
+ m: 16,
+ n: 16,
+ k: 32,
+ kBase: 8,
+ inputTypeA: 'fp8',
+ inputTypeB: 'fp8',
+ outputType: 'fp32',
+ },
+ {
+ name: 'v_mfma_f32_16x16x4_4b_bf16',
+ mnemonic: 'f32_16x16x4_4b_bf16',
+ m: 16,
+ n: 16,
+ k: 4,
+ kBase: 4,
+ inputTypeA: 'bf16',
+ inputTypeB: 'bf16',
+ outputType: 'fp32',
+ laneGroupsShareDimK: true,
+ registerBlocks: 4,
+ },
+ {
+ name: 'v_mfma_f32_16x16x4_4b_f16',
+ mnemonic: 'f32_16x16x4_4b_f16',
+ m: 16,
+ n: 16,
+ k: 4,
+ kBase: 4,
+ inputTypeA: 'fp16',
+ inputTypeB: 'fp16',
+ outputType: 'fp32',
+ laneGroupsShareDimK: true,
+ registerBlocks: 4,
+ },
+ {
+ name: 'v_mfma_f32_16x16x4_f32',
+ mnemonic: 'f32_16x16x4_f32',
+ m: 16,
+ n: 16,
+ k: 4,
+ kBase: 1,
+ inputTypeA: 'fp32',
+ inputTypeB: 'fp32',
+ outputType: 'fp32',
+ },
+ {
+ name: 'v_mfma_f32_16x16x8_xf32',
+ mnemonic: 'f32_16x16x8_xf32',
+ m: 16,
+ n: 16,
+ k: 8,
+ kBase: 2,
+ inputTypeA: 'tf32',
+ inputTypeB: 'tf32',
+ outputType: 'fp32',
+ },
+ {
+ name: 'v_mfma_f32_32x32x16_bf8_bf8',
+ mnemonic: 'f32_32x32x16_bf8_bf8',
+ m: 32,
+ n: 32,
+ k: 16,
+ kBase: 8,
+ inputTypeA: 'bf8',
+ inputTypeB: 'bf8',
+ outputType: 'fp32',
+ },
+ {
+ name: 'v_mfma_f32_32x32x16_bf8_fp8',
+ mnemonic: 'f32_32x32x16_bf8_fp8',
+ m: 32,
+ n: 32,
+ k: 16,
+ kBase: 8,
+ inputTypeA: 'bf8',
+ inputTypeB: 'fp8',
+ outputType: 'fp32',
+ },
+ {
+ name: 'v_mfma_f32_32x32x16_fp8_bf8',
+ mnemonic: 'f32_32x32x16_fp8_bf8',
+ m: 32,
+ n: 32,
+ k: 16,
+ kBase: 8,
+ inputTypeA: 'fp8',
+ inputTypeB: 'bf8',
+ outputType: 'fp32',
+ },
+ {
+ name: 'v_mfma_f32_32x32x16_fp8_fp8',
+ mnemonic: 'f32_32x32x16_fp8_fp8',
+ m: 32,
+ n: 32,
+ k: 16,
+ kBase: 8,
+ inputTypeA: 'fp8',
+ inputTypeB: 'fp8',
+ outputType: 'fp32',
+ },
+ {
+ name: 'v_mfma_f32_32x32x1_2b_f32',
+ mnemonic: 'f32_32x32x1_2b_f32',
+ m: 32,
+ n: 32,
+ k: 1,
+ kBase: 1,
+ inputTypeA: 'fp32',
+ inputTypeB: 'fp32',
+ outputType: 'fp32',
+ laneGroupsShareDimK: true,
+ registerBlocks: 2,
+ },
+ {
+ name: 'v_mfma_f32_32x32x2_f32',
+ mnemonic: 'f32_32x32x2_f32',
+ m: 32,
+ n: 32,
+ k: 2,
+ kBase: 1,
+ inputTypeA: 'fp32',
+ inputTypeB: 'fp32',
+ outputType: 'fp32',
+ },
+ {
+ name: 'v_mfma_f32_32x32x4_2b_bf16',
+ mnemonic: 'f32_32x32x4_2b_bf16',
+ m: 32,
+ n: 32,
+ k: 4,
+ kBase: 4,
+ inputTypeA: 'bf16',
+ inputTypeB: 'bf16',
+ outputType: 'fp32',
+ laneGroupsShareDimK: true,
+ registerBlocks: 2,
+ },
+ {
+ name: 'v_mfma_f32_32x32x4_2b_f16',
+ mnemonic: 'f32_32x32x4_2b_f16',
+ m: 32,
+ n: 32,
+ k: 4,
+ kBase: 4,
+ inputTypeA: 'fp16',
+ inputTypeB: 'fp16',
+ outputType: 'fp32',
+ laneGroupsShareDimK: true,
+ registerBlocks: 2,
+ },
+ {
+ name: 'v_mfma_f32_32x32x4_xf32',
+ mnemonic: 'f32_32x32x4_xf32',
+ m: 32,
+ n: 32,
+ k: 4,
+ kBase: 2,
+ inputTypeA: 'tf32',
+ inputTypeB: 'tf32',
+ outputType: 'fp32',
+ },
+ {
+ name: 'v_mfma_f32_32x32x8_bf16',
+ mnemonic: 'f32_32x32x8_bf16',
+ m: 32,
+ n: 32,
+ k: 8,
+ kBase: 4,
+ inputTypeA: 'bf16',
+ inputTypeB: 'bf16',
+ outputType: 'fp32',
+ },
+ {
+ name: 'v_mfma_f32_32x32x8_f16',
+ mnemonic: 'f32_32x32x8_f16',
+ m: 32,
+ n: 32,
+ k: 8,
+ kBase: 4,
+ inputTypeA: 'fp16',
+ inputTypeB: 'fp16',
+ outputType: 'fp32',
+ },
+ {
+ name: 'v_mfma_f32_4x4x1_16b_f32',
+ mnemonic: 'f32_4x4x1_16b_f32',
+ m: 4,
+ n: 4,
+ k: 1,
+ kBase: 1,
+ inputTypeA: 'fp32',
+ inputTypeB: 'fp32',
+ outputType: 'fp32',
+ laneGroupsShareDimK: true,
+ registerBlocks: 16,
+ },
+ {
+ name: 'v_mfma_f32_4x4x4_16b_bf16',
+ mnemonic: 'f32_4x4x4_16b_bf16',
+ m: 4,
+ n: 4,
+ k: 4,
+ kBase: 4,
+ inputTypeA: 'bf16',
+ inputTypeB: 'bf16',
+ outputType: 'fp32',
+ laneGroupsShareDimK: true,
+ registerBlocks: 16,
+ },
+ {
+ name: 'v_mfma_f32_4x4x4_16b_f16',
+ mnemonic: 'f32_4x4x4_16b_f16',
+ m: 4,
+ n: 4,
+ k: 4,
+ kBase: 4,
+ inputTypeA: 'fp16',
+ inputTypeB: 'fp16',
+ outputType: 'fp32',
+ laneGroupsShareDimK: true,
+ registerBlocks: 16,
+ },
+ {
+ name: 'v_mfma_f64_16x16x4_f64',
+ mnemonic: 'f64_16x16x4_f64',
+ m: 16,
+ n: 16,
+ k: 4,
+ kBase: 1,
+ inputTypeA: 'fp64',
+ inputTypeB: 'fp64',
+ outputType: 'fp64',
+ },
+ {
+ name: 'v_mfma_f64_4x4x4_4b_f64',
+ mnemonic: 'f64_4x4x4_4b_f64',
+ m: 4,
+ n: 4,
+ k: 4,
+ kBase: 1,
+ inputTypeA: 'fp64',
+ inputTypeB: 'fp64',
+ outputType: 'fp64',
+ laneGroupsShareDimK: true,
+ registerBlocks: 4,
+ },
+ {
+ name: 'v_mfma_i32_16x16x32_i8',
+ mnemonic: 'i32_16x16x32_i8',
+ m: 16,
+ n: 16,
+ k: 32,
+ kBase: 8,
+ inputTypeA: 'int8',
+ inputTypeB: 'int8',
+ outputType: 'int32',
+ },
+ {
+ name: 'v_mfma_i32_16x16x4_4b_i8',
+ mnemonic: 'i32_16x16x4_4b_i8',
+ m: 16,
+ n: 16,
+ k: 4,
+ kBase: 4,
+ inputTypeA: 'int8',
+ inputTypeB: 'int8',
+ outputType: 'int32',
+ laneGroupsShareDimK: true,
+ registerBlocks: 4,
+ },
+ {
+ name: 'v_mfma_i32_32x32x16_i8',
+ mnemonic: 'i32_32x32x16_i8',
+ m: 32,
+ n: 32,
+ k: 16,
+ kBase: 8,
+ inputTypeA: 'int8',
+ inputTypeB: 'int8',
+ outputType: 'int32',
+ },
+ {
+ name: 'v_mfma_i32_32x32x4_2b_i8',
+ mnemonic: 'i32_32x32x4_2b_i8',
+ m: 32,
+ n: 32,
+ k: 4,
+ kBase: 4,
+ inputTypeA: 'int8',
+ inputTypeB: 'int8',
+ outputType: 'int32',
+ laneGroupsShareDimK: true,
+ registerBlocks: 2,
+ },
+ {
+ name: 'v_mfma_i32_4x4x4_16b_i8',
+ mnemonic: 'i32_4x4x4_16b_i8',
+ m: 4,
+ n: 4,
+ k: 4,
+ kBase: 4,
+ inputTypeA: 'int8',
+ inputTypeB: 'int8',
+ outputType: 'int32',
+ laneGroupsShareDimK: true,
+ registerBlocks: 16,
+ },
+ ],
+}
diff --git a/src/data/wmma-instruction-data.ts b/src/data/wmma-instruction-data.ts
new file mode 100644
index 0000000..4d08364
--- /dev/null
+++ b/src/data/wmma-instruction-data.ts
@@ -0,0 +1,225 @@
+// WMMA instruction metadata shared across layouts and tooling.
+
+export interface WMMAInstructionEntry { // one row of the WMMA instruction table below
+ name: string // full instruction name, e.g. 'v_wmma_f32_16x16x16_f16'
+ mnemonic: string // the name without its 'v_wmma_' prefix
+ m: number // M, N and K extents as encoded in the instruction name (MxNxK)
+ n: number
+ k: number
+ inputTypeA: string // element type tags such as 'fp16', 'iu8', 'bf8'
+ inputTypeB: string
+ outputType: string
+ cycles: number // NOTE(review): presumably an issue-cycle count; its consumer is not visible here, confirm
+ kWidth: number // NOTE(review): presumably a per-lane K width; confirm against WMMALayout
+}
+
+export const WMMA_INSTRUCTION_DATA: Record<'rdna3' | 'rdna4', WMMAInstructionEntry[]> = {
+ rdna3: [
+ {
+ name: 'v_wmma_f32_16x16x16_f16',
+ mnemonic: 'f32_16x16x16_f16',
+ m: 16,
+ n: 16,
+ k: 16,
+ inputTypeA: 'fp16',
+ inputTypeB: 'fp16',
+ outputType: 'fp32',
+ cycles: 32,
+ kWidth: 16,
+ },
+ {
+ name: 'v_wmma_f32_16x16x16_bf16',
+ mnemonic: 'f32_16x16x16_bf16',
+ m: 16,
+ n: 16,
+ k: 16,
+ inputTypeA: 'bf16',
+ inputTypeB: 'bf16',
+ outputType: 'fp32',
+ cycles: 32,
+ kWidth: 16,
+ },
+ {
+ name: 'v_wmma_f16_16x16x16_f16',
+ mnemonic: 'f16_16x16x16_f16',
+ m: 16,
+ n: 16,
+ k: 16,
+ inputTypeA: 'fp16',
+ inputTypeB: 'fp16',
+ outputType: 'fp16',
+ cycles: 32,
+ kWidth: 16,
+ },
+ {
+ name: 'v_wmma_bf16_16x16x16_bf16',
+ mnemonic: 'bf16_16x16x16_bf16',
+ m: 16,
+ n: 16,
+ k: 16,
+ inputTypeA: 'bf16',
+ inputTypeB: 'bf16',
+ outputType: 'bf16',
+ cycles: 32,
+ kWidth: 16,
+ },
+ {
+ name: 'v_wmma_i32_16x16x16_iu8',
+ mnemonic: 'i32_16x16x16_iu8',
+ m: 16,
+ n: 16,
+ k: 16,
+ inputTypeA: 'iu8',
+ inputTypeB: 'iu8',
+ outputType: 'i32',
+ cycles: 32,
+ kWidth: 16,
+ },
+ {
+ name: 'v_wmma_i32_16x16x16_iu4',
+ mnemonic: 'i32_16x16x16_iu4',
+ m: 16,
+ n: 16,
+ k: 16,
+ inputTypeA: 'iu4',
+ inputTypeB: 'iu4',
+ outputType: 'i32',
+ cycles: 16,
+ kWidth: 16,
+ },
+ ],
+ rdna4: [
+ {
+ name: 'v_wmma_f32_16x16x16_f16',
+ mnemonic: 'f32_16x16x16_f16',
+ m: 16,
+ n: 16,
+ k: 16,
+ inputTypeA: 'fp16',
+ inputTypeB: 'fp16',
+ outputType: 'fp32',
+ cycles: 16,
+ kWidth: 8,
+ },
+ {
+ name: 'v_wmma_f32_16x16x16_bf16',
+ mnemonic: 'f32_16x16x16_bf16',
+ m: 16,
+ n: 16,
+ k: 16,
+ inputTypeA: 'bf16',
+ inputTypeB: 'bf16',
+ outputType: 'fp32',
+ cycles: 16,
+ kWidth: 8,
+ },
+ {
+ name: 'v_wmma_f16_16x16x16_f16',
+ mnemonic: 'f16_16x16x16_f16',
+ m: 16,
+ n: 16,
+ k: 16,
+ inputTypeA: 'fp16',
+ inputTypeB: 'fp16',
+ outputType: 'fp16',
+ cycles: 16,
+ kWidth: 8,
+ },
+ {
+ name: 'v_wmma_bf16_16x16x16_bf16',
+ mnemonic: 'bf16_16x16x16_bf16',
+ m: 16,
+ n: 16,
+ k: 16,
+ inputTypeA: 'bf16',
+ inputTypeB: 'bf16',
+ outputType: 'bf16',
+ cycles: 16,
+ kWidth: 8,
+ },
+ {
+ name: 'v_wmma_i32_16x16x16_iu8',
+ mnemonic: 'i32_16x16x16_iu8',
+ m: 16,
+ n: 16,
+ k: 16,
+ inputTypeA: 'iu8',
+ inputTypeB: 'iu8',
+ outputType: 'i32',
+ cycles: 8,
+ kWidth: 8,
+ },
+ {
+ name: 'v_wmma_i32_16x16x16_iu4',
+ mnemonic: 'i32_16x16x16_iu4',
+ m: 16,
+ n: 16,
+ k: 16,
+ inputTypeA: 'iu4',
+ inputTypeB: 'iu4',
+ outputType: 'i32',
+ cycles: 8,
+ kWidth: 8,
+ },
+ {
+ name: 'v_wmma_i32_16x16x32_iu4',
+ mnemonic: 'i32_16x16x32_iu4',
+ m: 16,
+ n: 16,
+ k: 32,
+ inputTypeA: 'iu4',
+ inputTypeB: 'iu4',
+ outputType: 'i32',
+ cycles: 8,
+ kWidth: 16,
+ },
+ {
+ name: 'v_wmma_f32_16x16x16_fp8_fp8',
+ mnemonic: 'f32_16x16x16_fp8_fp8',
+ m: 16,
+ n: 16,
+ k: 16,
+ inputTypeA: 'fp8',
+ inputTypeB: 'fp8',
+ outputType: 'fp32',
+ cycles: 8,
+ kWidth: 8,
+ },
+ {
+ name: 'v_wmma_f32_16x16x16_fp8_bf8',
+ mnemonic: 'f32_16x16x16_fp8_bf8',
+ m: 16,
+ n: 16,
+ k: 16,
+ inputTypeA: 'fp8',
+ inputTypeB: 'bf8',
+ outputType: 'fp32',
+ cycles: 8,
+ kWidth: 8,
+ },
+ {
+ name: 'v_wmma_f32_16x16x16_bf8_fp8',
+ mnemonic: 'f32_16x16x16_bf8_fp8',
+ m: 16,
+ n: 16,
+ k: 16,
+ inputTypeA: 'bf8',
+ inputTypeB: 'fp8',
+ outputType: 'fp32',
+ cycles: 8,
+ kWidth: 8,
+ },
+ {
+ name: 'v_wmma_f32_16x16x16_bf8_bf8',
+ mnemonic: 'f32_16x16x16_bf8_bf8',
+ m: 16,
+ n: 16,
+ k: 16,
+ inputTypeA: 'bf8',
+ inputTypeB: 'bf8',
+ outputType: 'fp32',
+ cycles: 8,
+ kWidth: 8,
+ },
+ ],
+}
diff --git a/src/layouts/MFMALayout.ts b/src/layouts/MFMALayout.ts
new file mode 100644
index 0000000..3b6d1ee
--- /dev/null
+++ b/src/layouts/MFMALayout.ts
@@ -0,0 +1,325 @@
+import { LinearLayout } from '../core/LinearLayout'
+import { MFMA_INSTRUCTION_DATA, type MFMAInstructionEntry } from '../data/mfma-instruction-data'
+import {
+ getElementSizeBits,
+ resolveOperandDataType,
+} from './shared/PackingUtils'
+
+/**
+ * MFMA instruction metadata used by the visualization.
+ *
+ * Mirrors Triton's MFMA database (see
+ * third_party/amd/lib/TritonAMDGPUTransforms/MfmaGroup.cpp) while adding an
+ * explicit `version` tag per architecture (1 = CDNA1, 2 = CDNA2, 3 = CDNA3).
+ */
+export interface MFMAInstruction extends MFMAInstructionEntry { // table entry plus an architecture tag
+ version: number // 1 = CDNA1, 2 = CDNA2, 3 = CDNA3
+}
+
+export { resolveOperandDataType }
+
+/**
+ * MFMA operand selector
+ */
+export type MFMAOperand = 'A' | 'B' | 'D' // A maps to M x K, B to K x N, D selects the output layout
+export type MFMAArchitecture = 'cdna1' | 'cdna2' | 'cdna3' // same keys as MFMA_INSTRUCTION_DATA
+
+const MFMA_WAVE_SIZE = 64 // lanes per wave; every layout in this file models a single wave64 tile
+const MFMA_REGISTER_WIDTH = 32 // bits per register, used to count how many registers one element occupies
+const MFMA_REGISTER_DIM = 'register' // layout input dimension: register index
+const MFMA_LANE_DIM = 'lane' // layout input dimension: lane index within the wave
+const MFMA_DIM_M = 'dimM' // layout output dimension names for the M, N and K axes
+const MFMA_DIM_N = 'dimN'
+const MFMA_DIM_K = 'dimK'
+
+function wrapToRange(value: number | undefined, limit: number): number { // non-negative remainder of value modulo limit
+ if (typeof value !== 'number' || limit <= 0) { // nothing to wrap, or no usable limit
+ return value ?? 0 // undefined becomes 0; with limit <= 0 a numeric value is passed through unchanged
+ }
+ const wrapped = ((value % limit) + limit) % limit // second modulo maps negative remainders into [0, limit)
+ return wrapped
+}
+
+function usesShared4x4Registers(instr: MFMAInstruction): boolean { // true only when both M and N are 4
+ return instr.m === 4 && instr.n === 4
+}
+
+function withVersion(version: number, defs: MFMAInstructionEntry[]): MFMAInstruction[] { // tags every entry with the given version
+ return defs.map(def => ({ ...def, version })) // shallow copies, so the source table entries are not mutated
+}
+
+export const CDNA1_MFMA_INSTRUCTIONS: MFMAInstruction[] = withVersion( // CDNA1 table tagged with version 1
+ 1,
+ MFMA_INSTRUCTION_DATA.cdna1
+)
+export const CDNA2_MFMA_INSTRUCTIONS: MFMAInstruction[] = withVersion( // CDNA2 table tagged with version 2
+ 2,
+ MFMA_INSTRUCTION_DATA.cdna2
+)
+export const CDNA3_MFMA_INSTRUCTIONS: MFMAInstruction[] = withVersion( // CDNA3 table tagged with version 3
+ 3,
+ MFMA_INSTRUCTION_DATA.cdna3
+)
+
+const MFMA_INSTRUCTIONS_BY_ARCH: Record<MFMAArchitecture, MFMAInstruction[]> = { // architecture tag -> versioned instruction list
+ cdna1: CDNA1_MFMA_INSTRUCTIONS,
+ cdna2: CDNA2_MFMA_INSTRUCTIONS,
+ cdna3: CDNA3_MFMA_INSTRUCTIONS,
+}
+
+function ensurePowerOfTwo(value: number, context: string): void { // throws unless value is a positive power of two
+ if (value <= 0 || (value & (value - 1)) !== 0) { // a power of two has a single set bit, so v & (v - 1) is 0
+ throw new Error(`${context} must be a positive power of two (got ${value})`) // context names the quantity being checked
+ } // NOTE(review): `&` coerces operands to 32-bit integers, so a fractional value such as 1.5 is not rejected
+}
+
+function getLaneHeight(instr: MFMAInstruction): number { // size of the first register factor used by the output layout
+ const bitWidth = getElementSizeBits(instr.outputType) // bit width of the output element type
+ return bitWidth === 64 ? 1 : 4 // 64-bit outputs use height 1, every other width uses 4
+}
+
+function getLayoutOutDimSize(layout: LinearLayout, dim: string, fallback: number): number { // output-dim size, or fallback when the layout lacks that dim
+ return layout.hasOutDim(dim) ? layout.getOutDimSize(dim) : fallback
+}
+
+function getLayoutInDimSize(layout: LinearLayout, dim: string): number { // input-dim size, or 1 when the layout lacks that dim
+ return layout.hasInDim(dim) ? layout.getInDimSize(dim) : 1
+}
+
+/**
+ * Create the MFMA output layout (operand D) by emulating
+ * AMDMfmaEncodingAttr::toLinearLayout for a single wave64 tile.
+ */
+function createMFMAOutputLayout(instr: MFMAInstruction): LinearLayout { // builds the (register, lane) -> (dimM, dimN) layout for operand D
+ const mDim = instr.m
+ const nDim = instr.n
+
+ ensurePowerOfTwo(mDim, 'MFMA output M dimension') // throws on invalid extents before any layout is built
+ ensurePowerOfTwo(nDim, 'MFMA output N dimension')
+
+ if (MFMA_WAVE_SIZE % nDim !== 0) { // the lane factor below divides the wave by nDim
+ throw new Error(`MFMA output requires N dimension dividing wave size (got N=${nDim})`)
+ }
+
+ const height = getLaneHeight(instr) // 1 for 64-bit outputs, otherwise 4
+ ensurePowerOfTwo(height, 'MFMA lane height')
+
+ const regs = LinearLayout.identity1D(height, MFMA_REGISTER_DIM, MFMA_DIM_M) // register index -> dimM, `height` entries
+ const lanes = // NOTE(review): assumes multiply() composes factors like Triton's LinearLayout product; confirm in core/LinearLayout
+ LinearLayout.identity1D(nDim, MFMA_LANE_DIM, MFMA_DIM_N).multiply( // lane factor of size nDim -> dimN
+ LinearLayout.identity1D(MFMA_WAVE_SIZE / nDim, MFMA_LANE_DIM, MFMA_DIM_M) // remaining lane factor -> dimM
+ )
+
+ let tileLayout = regs.multiply(lanes) // base tile: height registers x 64 lanes
+
+ const tiles = Math.max(1, Math.trunc((mDim * nDim) / (MFMA_WAVE_SIZE * height))) // repeats needed to cover M x N, at least 1
+ ensurePowerOfTwo(tiles, 'MFMA tile replication factor')
+
+ if (tiles > 1) {
+ tileLayout = tileLayout.multiply(LinearLayout.identity1D(tiles, MFMA_REGISTER_DIM, MFMA_DIM_M)) // extra register factor along dimM
+ }
+
+ return tileLayout
+}
+
+/**
+ * Create an MFMA operand layout.
+ *
+ * - Operand D reuses the output layout above.
+ * - Operand A maps to M × K, operand B maps to K × N.
+ *
+ * The implementation mirrors mfmaDotToLinearLayout from Triton with a focus
+ * on a single wave64 tile (warps-per-CTA == tiles-per-warp == 1).
+ */
+export function createMFMAOperandLayout(instr: MFMAInstruction, operand: MFMAOperand): LinearLayout {
+ if (operand === 'D') { // the output operand has its own builder
+ return createMFMAOutputLayout(instr)
+ }
+
+ const nonKDim = operand === 'A' ? instr.m : instr.n // A pairs K with M, B pairs K with N
+ ensurePowerOfTwo(nonKDim, 'MFMA operand non-K dimension')
+
+ if (MFMA_WAVE_SIZE % nonKDim !== 0) { // the lane factor below divides the wave by nonKDim
+ throw new Error(
+ `MFMA operand ${operand} expects non-K dimension dividing wave size (got ${nonKDim})`
+ )
+ }
+
+ const kWidth = instr.kBase ?? instr.k // kBase is required by the interface, so the fallback only covers untyped data
+ ensurePowerOfTwo(kWidth, 'MFMA operand K width')
+
+ const dimNonK = operand === 'A' ? MFMA_DIM_M : MFMA_DIM_N
+ const dimK = MFMA_DIM_K
+
+ const regs = LinearLayout.identity1D(kWidth, MFMA_REGISTER_DIM, dimK) // register index -> dimK, kWidth entries
+
+ const lanes = // NOTE(review): assumes multiply() composes factors like Triton's LinearLayout product; confirm in core/LinearLayout
+ LinearLayout.identity1D(nonKDim, MFMA_LANE_DIM, dimNonK).multiply( // lane factor of size nonKDim -> non-K dim
+ LinearLayout.identity1D(MFMA_WAVE_SIZE / nonKDim, MFMA_LANE_DIM, dimK) // remaining lane factor -> dimK
+ )
+
+ let tileLayout = regs.multiply(lanes)
+ let kTileSize = (MFMA_WAVE_SIZE / nonKDim) * kWidth // K extent covered so far: lane factor times register factor
+
+ const touches64x4Operand = // none of the table entries visible in this change has a 64x4 or 4x64 shape
+ (instr.m === 64 && instr.n === 4 && operand === 'A') ||
+ (instr.m === 4 && instr.n === 64 && operand === 'B')
+
+ if (touches64x4Operand) {
+ const replication = 16 // fixed extra register factor along K for this shape
+ tileLayout = tileLayout.multiply(LinearLayout.identity1D(replication, MFMA_REGISTER_DIM, dimK))
+ kTileSize *= replication
+ }
+
+ const kDimSize = instr.k
+ if (kDimSize > kTileSize) { // the tile does not yet span the instruction's full K
+ if (kDimSize % kTileSize !== 0) {
+ throw new Error(
+ `MFMA operand ${operand} requires K dimension (${kDimSize}) to be a multiple of tile size (${kTileSize})`
+ )
+ }
+ const replication = Math.trunc(kDimSize / kTileSize) // exact, because divisibility was checked above
+ ensurePowerOfTwo(replication, 'MFMA operand K replication factor')
+ tileLayout = tileLayout.multiply( // add a register factor so the layout covers all of K
+ LinearLayout.identity1D(replication, MFMA_REGISTER_DIM, dimK)
+ )
+ }
+
+ return tileLayout
+}
+
+/**
+ * Fetch MFMA instructions for a given architecture tag.
+ */
+export function getMFMAInstructions(architecture: MFMAArchitecture): MFMAInstruction[] {
+ const instructions = MFMA_INSTRUCTIONS_BY_ARCH[architecture]
+ if (!instructions) { // only reachable when a caller bypasses the MFMAArchitecture type
+ throw new Error(`Unsupported MFMA architecture: ${architecture}`)
+ }
+ return instructions // the shared module-level array, not a copy
+}
+
+export interface MFMALayoutParams { // parameter bag accepted by createMFMALayout
+ architecture: MFMAArchitecture // carried in the bag; createMFMALayout itself does not read it
+ instruction: MFMAInstruction // instruction whose operand layout is built
+ operand: MFMAOperand // which operand ('A', 'B' or 'D') to lay out
+}
+
+/**
+ * Build a LinearLayout for the requested MFMA operand. Acts as a small wrapper
+ * around createMFMAOperandLayout so callers can pass the full tab parameter bag.
+ */
+export function createMFMALayout(params: MFMALayoutParams): LinearLayout { // forwards to createMFMAOperandLayout
+ return createMFMAOperandLayout(params.instruction, params.operand) // params.architecture is not consulted
+}
+
+/**
+ * Enumerate all (row, column) pairs covered by a given thread.
+ *
+ * For operands A/B we present coordinates as (nonK, K) == (M, K) / (K, N)
+ * to match the Triton calculators and AMD tooling.
+ */
+export function getMFMAPositionsForThread(
+ layout: LinearLayout,
+ instruction: MFMAInstruction,
+ operand: MFMAOperand,
+ threadId: number
+): Array<{ pos: [number, number]; registerId: number }> {
+ const positions: Array<{ pos: [number, number]; registerId: number }> = []
+
+ const lane = ((threadId % MFMA_WAVE_SIZE) + MFMA_WAVE_SIZE) % MFMA_WAVE_SIZE // fold any thread id, negative included, into [0, 64)
+ const registerCount = getLayoutInDimSize(layout, MFMA_REGISTER_DIM) // register slots defined by the layout (1 if it has no register dim)
+ const operandType = resolveOperandDataType(instruction, operand)
+ const elementBits = getElementSizeBits(operandType)
+ const registersPerElement = Math.max(1, Math.ceil(elementBits / MFMA_REGISTER_WIDTH)) // elements wider than 32 bits span several registers
+
+ const blockCount = Math.max(1, instruction.registerBlocks ?? 1) // missing registerBlocks counts as a single block
+ const shouldWrapDimK = instruction.laneGroupsShareDimK || blockCount > 1 // K coordinates are reduced modulo k in either case
+ const wrapDimM = (operand === 'D' || operand === 'A') && instruction.m > 0 // M is only meaningful for D and A
+ const wrapDimN = (operand === 'D' || operand === 'B') && instruction.n > 0 // N is only meaningful for D and B
+ const shareRegisterBlocks = operand === 'D' && usesShared4x4Registers(instruction) // 4x4 outputs do not get per-block register ids
+ const isFp64Shared = registersPerElement > 1 && usesShared4x4Registers(instruction) // multi-register elements on a 4x4 instruction
+ const registerStride = shareRegisterBlocks ? 0 : registerCount * registersPerElement // register-id distance between consecutive blocks
+
+ const dimMSz = getLayoutOutDimSize(layout, MFMA_DIM_M, instruction.m) // bounds come from the layout, else from the instruction
+ const dimNSz = getLayoutOutDimSize(layout, MFMA_DIM_N, instruction.n)
+ const dimKSz = getLayoutOutDimSize(layout, MFMA_DIM_K, instruction.k)
+
+ const registerBlocks = operand === 'D' ? (shareRegisterBlocks ? 1 : blockCount) : 1 // only operand D is emitted once per block
+
+ for (let reg = 0; reg < registerCount; reg++) {
+ const result = layout.apply({ // coordinates the layout assigns to this (register, lane) pair
+ [MFMA_REGISTER_DIM]: reg,
+ [MFMA_LANE_DIM]: lane,
+ })
+
+ if (
+ shouldWrapDimK &&
+ typeof result[MFMA_DIM_K] === 'number' && // skipped when the layout has no dimK output (operand D)
+ instruction.k > 0
+ ) {
+ const currentK = result[MFMA_DIM_K]
+ const wrapped = ((currentK % instruction.k) + instruction.k) % instruction.k // same non-negative modulo as wrapToRange
+ result[MFMA_DIM_K] = wrapped
+ }
+
+ if (wrapDimM) {
+ result[MFMA_DIM_M] = wrapToRange(result[MFMA_DIM_M], instruction.m) // keep M inside [0, m)
+ }
+ if (wrapDimN) {
+ result[MFMA_DIM_N] = wrapToRange(result[MFMA_DIM_N], instruction.n) // keep N inside [0, n)
+ }
+
+ if (isFp64Shared) { // this case overrides the layout result with coordinates computed from the lane id
+ const laneGroup = Math.floor(lane / 16) // index of the 16-lane group, 0..3
+ const laneQuad = lane % 4 // position within a group of four lanes, 0..3
+ if (operand === 'A') {
+ result[MFMA_DIM_M] = wrapToRange(laneQuad, instruction.m)
+ result[MFMA_DIM_K] = wrapToRange(laneGroup, instruction.k)
+ } else if (operand === 'B') {
+ result[MFMA_DIM_K] = wrapToRange(laneGroup, instruction.k)
+ result[MFMA_DIM_N] = wrapToRange(laneQuad, instruction.n)
+ } else {
+ result[MFMA_DIM_M] = wrapToRange(laneGroup, instruction.m)
+ result[MFMA_DIM_N] = wrapToRange(laneQuad, instruction.n)
+ }
+ }
+
+ let dim0 = 0 // first reported coordinate and its exclusive upper bound
+ let dim1 = 0 // second reported coordinate and its exclusive upper bound
+ let dim0Limit = 0
+ let dim1Limit = 0
+
+ if (operand === 'D') { // D reports (M, N)
+ dim0 = result[MFMA_DIM_M] ?? 0
+ dim1 = result[MFMA_DIM_N] ?? 0
+ dim0Limit = dimMSz
+ dim1Limit = dimNSz
+ } else if (operand === 'A') { // A reports (M, K)
+ dim0 = result[MFMA_DIM_M] ?? 0
+ dim1 = result[MFMA_DIM_K] ?? 0
+ dim0Limit = dimMSz
+ dim1Limit = dimKSz
+ } else { // B reports (K, N)
+ dim0 = result[MFMA_DIM_K] ?? 0
+ dim1 = result[MFMA_DIM_N] ?? 0
+ dim0Limit = dimKSz
+ dim1Limit = dimNSz
+ }
+
+ if (dim0 >= 0 && dim0 < dim0Limit && dim1 >= 0 && dim1 < dim1Limit) { // out-of-range coordinates are dropped silently
+ for (let block = 0; block < registerBlocks; block++) {
+ const baseRegister = reg * registersPerElement // first register id occupied by this element
+ const blockOffset = block * registerStride // shift register ids for every block after the first
+ for (let chunk = 0; chunk < registersPerElement; chunk++) { // one entry per register the element occupies
+ positions.push({
+ pos: [dim0, dim1],
+ registerId: baseRegister + chunk + blockOffset,
+ })
+ }
+ }
+ }
+ }
+
+ return positions
+}
diff --git a/src/layouts/WMMALayout.ts b/src/layouts/WMMALayout.ts
index a434cb5..00b224b 100644
--- a/src/layouts/WMMALayout.ts
+++ b/src/layouts/WMMALayout.ts
@@ -1,281 +1,49 @@
import { LinearLayout } from '../core/LinearLayout'
+import {
+ WMMA_INSTRUCTION_DATA,
+ type WMMAInstructionEntry,
+} from '../data/wmma-instruction-data'
+import {
+ getElementSizeBits,
+ resolveOperandDataType,
+} from './shared/PackingUtils'
+
+export { resolveOperandDataType }
/**
* WMMA instruction definition
*/
-export interface WMMAInstruction {
- name: string
- mnemonic: string
- m: number
- n: number
- k: number
- inputTypeA: string
- inputTypeB: string
- outputType: string
- cycles: number
- kWidth: number
+export interface WMMAInstruction extends WMMAInstructionEntry {
version: number // 1 for RDNA3, 2 for RDNA4
}
+export type WMMAArchitecture = 'rdna3' | 'rdna4'
+
/**
* WMMA operand type
*/
export type WMMAOperand = 'A' | 'B' | 'D'
-const ELEMENT_TYPE_SIZES: Record = {
- fp16: 16,
- bf16: 16,
- fp8: 8,
- bf8: 8,
- iu8: 8,
- iu4: 4,
- fp32: 32,
- i32: 32,
+const WMMA_ARCH_VERSION: Record = {
+ rdna3: 1,
+ rdna4: 2,
}
-export function resolveOperandDataType(instr: WMMAInstruction, operand: WMMAOperand): string {
- if (operand === 'D') {
- return instr.outputType
- }
- return operand === 'A' ? instr.inputTypeA : instr.inputTypeB
+const WMMA_INSTRUCTION_CACHE: Record = {
+ rdna3: buildInstructionList('rdna3'),
+ rdna4: buildInstructionList('rdna4'),
}
-function getElementSizeBits(type: string): number {
- return ELEMENT_TYPE_SIZES[type] ?? 32
-}
+export const RDNA3_WMMA_INSTRUCTIONS: WMMAInstruction[] = WMMA_INSTRUCTION_CACHE.rdna3
+export const RDNA4_WMMA_INSTRUCTIONS: WMMAInstruction[] = WMMA_INSTRUCTION_CACHE.rdna4
-/**
- * RDNA3 (gfx11) WMMA instructions
- */
-export const RDNA3_WMMA_INSTRUCTIONS: WMMAInstruction[] = [
- {
- name: 'v_wmma_f32_16x16x16_f16',
- mnemonic: 'f32_16x16x16_f16',
- m: 16,
- n: 16,
- k: 16,
- inputTypeA: 'fp16',
- inputTypeB: 'fp16',
- outputType: 'fp32',
- cycles: 32,
- kWidth: 16,
- version: 1,
- },
- {
- name: 'v_wmma_f32_16x16x16_bf16',
- mnemonic: 'f32_16x16x16_bf16',
- m: 16,
- n: 16,
- k: 16,
- inputTypeA: 'bf16',
- inputTypeB: 'bf16',
- outputType: 'fp32',
- cycles: 32,
- kWidth: 16,
- version: 1,
- },
- {
- name: 'v_wmma_f16_16x16x16_f16',
- mnemonic: 'f16_16x16x16_f16',
- m: 16,
- n: 16,
- k: 16,
- inputTypeA: 'fp16',
- inputTypeB: 'fp16',
- outputType: 'fp16',
- cycles: 32,
- kWidth: 16,
- version: 1,
- },
- {
- name: 'v_wmma_bf16_16x16x16_bf16',
- mnemonic: 'bf16_16x16x16_bf16',
- m: 16,
- n: 16,
- k: 16,
- inputTypeA: 'bf16',
- inputTypeB: 'bf16',
- outputType: 'bf16',
- cycles: 32,
- kWidth: 16,
- version: 1,
- },
- {
- name: 'v_wmma_i32_16x16x16_iu8',
- mnemonic: 'i32_16x16x16_iu8',
- m: 16,
- n: 16,
- k: 16,
- inputTypeA: 'iu8',
- inputTypeB: 'iu8',
- outputType: 'i32',
- cycles: 32,
- kWidth: 16,
- version: 1,
- },
- {
- name: 'v_wmma_i32_16x16x16_iu4',
- mnemonic: 'i32_16x16x16_iu4',
- m: 16,
- n: 16,
- k: 16,
- inputTypeA: 'iu4',
- inputTypeB: 'iu4',
- outputType: 'i32',
- cycles: 16,
- kWidth: 16,
- version: 1,
- },
-]
-
-/**
- * RDNA4 (gfx12) WMMA instructions
- */
-export const RDNA4_WMMA_INSTRUCTIONS: WMMAInstruction[] = [
- {
- name: 'v_wmma_f32_16x16x16_f16',
- mnemonic: 'f32_16x16x16_f16',
- m: 16,
- n: 16,
- k: 16,
- inputTypeA: 'fp16',
- inputTypeB: 'fp16',
- outputType: 'fp32',
- cycles: 16,
- kWidth: 8,
- version: 2,
- },
- {
- name: 'v_wmma_f32_16x16x16_bf16',
- mnemonic: 'f32_16x16x16_bf16',
- m: 16,
- n: 16,
- k: 16,
- inputTypeA: 'bf16',
- inputTypeB: 'bf16',
- outputType: 'fp32',
- cycles: 16,
- kWidth: 8,
- version: 2,
- },
- {
- name: 'v_wmma_f16_16x16x16_f16',
- mnemonic: 'f16_16x16x16_f16',
- m: 16,
- n: 16,
- k: 16,
- inputTypeA: 'fp16',
- inputTypeB: 'fp16',
- outputType: 'fp16',
- cycles: 16,
- kWidth: 8,
- version: 2,
- },
- {
- name: 'v_wmma_bf16_16x16x16_bf16',
- mnemonic: 'bf16_16x16x16_bf16',
- m: 16,
- n: 16,
- k: 16,
- inputTypeA: 'bf16',
- inputTypeB: 'bf16',
- outputType: 'bf16',
- cycles: 16,
- kWidth: 8,
- version: 2,
- },
- {
- name: 'v_wmma_i32_16x16x16_iu8',
- mnemonic: 'i32_16x16x16_iu8',
- m: 16,
- n: 16,
- k: 16,
- inputTypeA: 'iu8',
- inputTypeB: 'iu8',
- outputType: 'i32',
- cycles: 8,
- kWidth: 8,
- version: 2,
- },
- {
- name: 'v_wmma_i32_16x16x16_iu4',
- mnemonic: 'i32_16x16x16_iu4',
- m: 16,
- n: 16,
- k: 16,
- inputTypeA: 'iu4',
- inputTypeB: 'iu4',
- outputType: 'i32',
- cycles: 8,
- kWidth: 8,
- version: 2,
- },
- {
- name: 'v_wmma_i32_16x16x32_iu4',
- mnemonic: 'i32_16x16x32_iu4',
- m: 16,
- n: 16,
- k: 32,
- inputTypeA: 'iu4',
- inputTypeB: 'iu4',
- outputType: 'i32',
- cycles: 8,
- kWidth: 16,
- version: 2,
- },
- {
- name: 'v_wmma_f32_16x16x16_fp8_fp8',
- mnemonic: 'f32_16x16x16_fp8_fp8',
- m: 16,
- n: 16,
- k: 16,
- inputTypeA: 'fp8',
- inputTypeB: 'fp8',
- outputType: 'fp32',
- cycles: 8,
- kWidth: 8,
- version: 2,
- },
- {
- name: 'v_wmma_f32_16x16x16_fp8_bf8',
- mnemonic: 'f32_16x16x16_fp8_bf8',
- m: 16,
- n: 16,
- k: 16,
- inputTypeA: 'fp8',
- inputTypeB: 'bf8',
- outputType: 'fp32',
- cycles: 8,
- kWidth: 8,
- version: 2,
- },
- {
- name: 'v_wmma_f32_16x16x16_bf8_fp8',
- mnemonic: 'f32_16x16x16_bf8_fp8',
- m: 16,
- n: 16,
- k: 16,
- inputTypeA: 'bf8',
- inputTypeB: 'fp8',
- outputType: 'fp32',
- cycles: 8,
- kWidth: 8,
- version: 2,
- },
- {
- name: 'v_wmma_f32_16x16x16_bf8_bf8',
- mnemonic: 'f32_16x16x16_bf8_bf8',
- m: 16,
- n: 16,
- k: 16,
- inputTypeA: 'bf8',
- inputTypeB: 'bf8',
- outputType: 'fp32',
- cycles: 8,
- kWidth: 8,
- version: 2,
- },
-]
+function buildInstructionList(architecture: WMMAArchitecture): WMMAInstruction[] { // function declaration (hoisted), so the WMMA_INSTRUCTION_CACHE initializer earlier in the file can call it
+ const version = WMMA_ARCH_VERSION[architecture] // 1 for rdna3, 2 for rdna4
+ return WMMA_INSTRUCTION_DATA[architecture].map((entry) => ({ // build a new object per generated data entry
+ ...entry,
+ version, // tag the copy with the architecture's version number
+ }))
+}
/**
* Create WMMA output (D) layout
@@ -586,15 +354,15 @@ export function createWMMAOperandLayout(
/**
* Get all available WMMA instructions for an architecture
+ *
+ * Every call returns a new array holding shallow copies of the cached
+ * entries, so assigning fields on a returned instruction does not change the
+ * objects stored in WMMA_INSTRUCTION_CACHE.
*/
-export function getWMMAInstructions(architecture: 'rdna3' | 'rdna4'): WMMAInstruction[] {
- return architecture === 'rdna3' ? RDNA3_WMMA_INSTRUCTIONS : RDNA4_WMMA_INSTRUCTIONS
+export function getWMMAInstructions(architecture: WMMAArchitecture): WMMAInstruction[] {
+ return WMMA_INSTRUCTION_CACHE[architecture].map((instr) => ({ ...instr }))
}
/**
* Parameters for WMMA visualization
*/
export interface WMMALayoutParams {
- architecture: 'rdna3' | 'rdna4'
+ architecture: WMMAArchitecture
instruction: WMMAInstruction
operand: WMMAOperand
}
diff --git a/src/layouts/WMMAPacking.ts b/src/layouts/WMMAPacking.ts
index 74ee7fa..6615ae8 100644
--- a/src/layouts/WMMAPacking.ts
+++ b/src/layouts/WMMAPacking.ts
@@ -1,29 +1,32 @@
import type { WMMAInstruction, WMMAOperand } from './WMMALayout'
-
-export const REGISTER_WIDTH_BITS = 32
-export const DEFAULT_ELEMENT_BITS = 32
-
-const ELEMENT_TYPE_BIT_SIZES: Record = {
- f16: 16,
- fp16: 16,
- bf16: 16,
- fp8: 8,
- bf8: 8,
- iu8: 8,
- iu4: 4,
- fp32: 32,
- i32: 32,
-}
+import {
+ DEFAULT_ELEMENT_BITS as SHARED_DEFAULT_ELEMENT_BITS,
+ VGPR_REGISTER_WIDTH_BITS,
+ computeElementsPerRegister as computeElementsPerRegisterBase,
+} from './shared/PackingUtils'
+
+export {
+ createPackingMetadata,
+ computePackingMetadata,
+ getLogicalElementMetadata,
+ decomposeLogicalIndex,
+ formatBitRange,
+ getElementSizeBits,
+} from './shared/RegisterMetadata'
+export type {
+ PackingComputationOptions,
+ PackingMetadata,
+ LogicalElementMetadata,
+} from './shared/RegisterMetadata'
+
+export const REGISTER_WIDTH_BITS = VGPR_REGISTER_WIDTH_BITS
+export const DEFAULT_ELEMENT_BITS = SHARED_DEFAULT_ELEMENT_BITS
/**
- * Look up the number of bits consumed by a logical element of the given type.
+ * WMMA-specific overrides for register packing behavior. Shared helpers live in
+ * src/layouts/shared/RegisterMetadata.ts so other architectures do not depend
+ * on WMMA implementation details.
*/
-export function getElementSizeBits(dataType: string | undefined): number {
- if (!dataType) {
- return DEFAULT_ELEMENT_BITS
- }
- return ELEMENT_TYPE_BIT_SIZES[dataType] ?? DEFAULT_ELEMENT_BITS
-}
/**
* Determine how many logical elements are packed into a single physical VGPR.
@@ -34,124 +37,7 @@ export function computeElementsPerRegister(
elementBits: number
): number {
const sanitizedBits = elementBits > 0 ? elementBits : DEFAULT_ELEMENT_BITS
- const packableElements = Math.max(1, Math.floor(REGISTER_WIDTH_BITS / sanitizedBits))
+ const packableElements = computeElementsPerRegisterBase(REGISTER_WIDTH_BITS, sanitizedBits)
const requiresSingleElement = operand === 'D' && instruction.version < 2
return requiresSingleElement ? 1 : packableElements
}
-
-/**
- * Convert a logical element index into the physical VGPR identifier and the
- * packed slot offset within that register.
- */
-export function decomposeLogicalIndex(
- logicalIndex: number,
- elementsPerRegister: number
-): { physicalRegister: number; elementOffset: number } {
- if (elementsPerRegister <= 0) {
- throw new Error('elementsPerRegister must be positive')
- }
- const elementOffset =
- ((logicalIndex % elementsPerRegister) + elementsPerRegister) % elementsPerRegister
- const physicalRegister = Math.floor(logicalIndex / elementsPerRegister)
- return { physicalRegister, elementOffset }
-}
-
-/**
- * Format a bit-range string describing where the packed element resides.
- */
-export function formatBitRange(
- elementOffset: number,
- elementBits: number,
- elementsPerRegister: number
-): string | null {
- const sanitizedBits = elementBits > 0 ? elementBits : DEFAULT_ELEMENT_BITS
- if (elementsPerRegister <= 1) {
- if (sanitizedBits >= REGISTER_WIDTH_BITS) {
- const fullWidthEndBit = REGISTER_WIDTH_BITS - 1
- return `[${fullWidthEndBit}:0]`
- }
- const endBit = sanitizedBits - 1
- if (endBit < 0) {
- return null
- }
- return `[${endBit}:0]`
- }
-
- const startBit = Math.max(0, elementOffset * sanitizedBits)
- const endBit = startBit + sanitizedBits - 1
- if (endBit < startBit) {
- return null
- }
- return `[${endBit}:${startBit}]`
-}
-
-/**
- * Convenience helper that returns both element bit width and packing factor.
- */
-export function computePackingMetadata(
- instruction: WMMAInstruction,
- operand: WMMAOperand,
- dataType: string | undefined
-): { elementBits: number; elementsPerRegister: number } {
- const elementBits = getElementSizeBits(dataType)
- const elementsPerRegister = computeElementsPerRegister(instruction, operand, elementBits)
- return { elementBits, elementsPerRegister }
-}
-
-export interface PackingMetadata {
- elementBits: number
- elementsPerRegister: number
-}
-
-export interface LogicalElementMetadata {
- physicalRegister: number
- elementOffset: number
- bitRange: string | null
- elementsPerRegister: number
-}
-
-/**
- * Create packing metadata for a given WMMA instruction/operand combination.
- */
-export function createPackingMetadata(
- instruction: WMMAInstruction,
- operand: WMMAOperand,
- dataType: string | undefined
-): PackingMetadata {
- return computePackingMetadata(instruction, operand, dataType)
-}
-
-/**
- * Translate a logical element index into physical packing metadata.
- */
-export function getLogicalElementMetadata(
- logicalIndex: number,
- metadata: PackingMetadata
-): LogicalElementMetadata | null {
- if (logicalIndex < 0) {
- return null
- }
-
- const { elementsPerRegister, elementBits } = metadata
- if (elementsPerRegister <= 0) {
- return null
- }
-
- let physicalRegister = 0
- let elementOffset = 0
- try {
- const decomposition = decomposeLogicalIndex(logicalIndex, elementsPerRegister)
- physicalRegister = decomposition.physicalRegister
- elementOffset = decomposition.elementOffset
- } catch {
- return null
- }
-
- const bitRange = formatBitRange(elementOffset, elementBits, elementsPerRegister)
- return {
- physicalRegister,
- elementOffset,
- bitRange,
- elementsPerRegister,
- }
-}
diff --git a/src/layouts/shared/PackingUtils.ts b/src/layouts/shared/PackingUtils.ts
new file mode 100644
index 0000000..4df89b9
--- /dev/null
+++ b/src/layouts/shared/PackingUtils.ts
@@ -0,0 +1,115 @@
+/**
+ * Shared packing helpers for WMMA and MFMA visualizations.
+ *
+ * Consolidates common logic such as operand data type resolution, element
+ * bit-width lookups, logical index decomposition, and bit-range formatting.
+ */
+
+export type MatmulOperand = 'A' | 'B' | 'D'
+
+export interface OperandTypeSources { // the three fields resolveOperandDataType reads
+ inputTypeA: string // returned for operand 'A'
+ inputTypeB: string // returned for operand 'B'
+ outputType: string // returned for operand 'D'
+}
+
+export const DEFAULT_ELEMENT_BITS = 32
+export const VGPR_REGISTER_WIDTH_BITS = 32
+
+const ELEMENT_BIT_WIDTHS: Record = { // element type name -> width in bits; names not listed here fall back to DEFAULT_ELEMENT_BITS in getElementSizeBits
+ f16: 16,
+ bf8: 8,
+ bf16: 16,
+ fp8: 8,
+ fp16: 16,
+ fp32: 32,
+ fp64: 64, // wider than VGPR_REGISTER_WIDTH_BITS (32); computeElementsPerRegister reports 1 element per register for such types
+ tf32: 32,
+ iu4: 4,
+ iu8: 8,
+ i32: 32,
+ int8: 8,
+ int32: 32,
+}
+
+/**
+ * Map an operand selector to its resolved data type.
+ *
+ * @param instruction - Object supplying `inputTypeA`, `inputTypeB` and
+ *   `outputType` (the fields declared on OperandTypeSources).
+ * @param operand - 'D' selects `outputType`, 'A' selects `inputTypeA`, and any
+ *   other value ('B') selects `inputTypeB`.
+ * @returns The data type string stored on the instruction for that operand.
+ */
+export function resolveOperandDataType(
+ instruction: T,
+ operand: MatmulOperand
+): string {
+ if (operand === 'D') {
+ return instruction.outputType
+ }
+ return operand === 'A' ? instruction.inputTypeA : instruction.inputTypeB
+}
+
+/**
+ * Return the bit width associated with a given element type.
+ *
+ * @param dataType - Element type name such as 'fp16' or 'iu4'.
+ * @returns The width listed in ELEMENT_BIT_WIDTHS, or DEFAULT_ELEMENT_BITS (32)
+ *   when `dataType` is undefined, empty, or not present in the table. Unknown
+ *   names therefore never throw; they are treated as 32-bit elements.
+ */
+export function getElementSizeBits(dataType: string | undefined): number {
+ if (!dataType) {
+ return DEFAULT_ELEMENT_BITS
+ }
+ return ELEMENT_BIT_WIDTHS[dataType] ?? DEFAULT_ELEMENT_BITS
+}
+
+/**
+ * Determine how many logical elements fit inside a physical register.
+ *
+ * @param registerWidthBits - Register width in bits; values that are not
+ *   greater than zero are replaced by VGPR_REGISTER_WIDTH_BITS.
+ * @param elementBits - Element width in bits; values that are not greater than
+ *   zero are replaced by DEFAULT_ELEMENT_BITS.
+ * @returns floor(registerWidth / elementBits), but never less than 1, so an
+ *   element wider than the register (e.g. 64 bits in a 32-bit register) still
+ *   reports one element per register.
+ */
+export function computeElementsPerRegister(
+ registerWidthBits: number,
+ elementBits: number
+): number {
+ const sanitizedBits = elementBits > 0 ? elementBits : DEFAULT_ELEMENT_BITS
+ const registerWidth = registerWidthBits > 0 ? registerWidthBits : VGPR_REGISTER_WIDTH_BITS
+ return Math.max(1, Math.floor(registerWidth / sanitizedBits))
+}
+
+/**
+ * Convert a logical index into a physical register identifier and slot offset.
+ *
+ * @param logicalIndex - Index of the logical element.
+ * @param elementsPerRegister - Packing factor; must be greater than zero.
+ * @returns `physicalRegister` = floor(logicalIndex / elementsPerRegister) and
+ *   `elementOffset` = logicalIndex modulo elementsPerRegister. The offset uses a
+ *   double-modulo so it stays non-negative even for a negative logicalIndex
+ *   (the register number is then negative, because of the floor).
+ * @throws Error when `elementsPerRegister` is zero or negative.
+ */
+export function decomposeLogicalIndex(
+ logicalIndex: number,
+ elementsPerRegister: number
+): { physicalRegister: number; elementOffset: number } {
+ if (elementsPerRegister <= 0) {
+ throw new Error('elementsPerRegister must be positive')
+ }
+ const elementOffset =
+ ((logicalIndex % elementsPerRegister) + elementsPerRegister) % elementsPerRegister
+ const physicalRegister = Math.floor(logicalIndex / elementsPerRegister)
+ return { physicalRegister, elementOffset }
+}
+
+/**
+ * Format the bit range representing a packed element slice within a register.
+ *
+ * The result is written high bit first, as `[end:start]`.
+ *
+ * @param elementOffset - Slot index of the element inside its register; only
+ *   used when more than one element is packed per register.
+ * @param elementBits - Element width in bits; values that are not greater than
+ *   zero are replaced by DEFAULT_ELEMENT_BITS.
+ * @param elementsPerRegister - Packing factor. When it is 1 or less the range
+ *   always starts at bit 0.
+ * @param registerWidthBits - Register width used for clamping; defaults to
+ *   VGPR_REGISTER_WIDTH_BITS.
+ * @returns The formatted range, or null when the computed end bit lies below
+ *   the start bit (which requires an element width smaller than one bit).
+ */
+export function formatBitRange(
+ elementOffset: number,
+ elementBits: number,
+ elementsPerRegister: number,
+ registerWidthBits: number = VGPR_REGISTER_WIDTH_BITS
+): string | null {
+ const sanitizedBits = elementBits > 0 ? elementBits : DEFAULT_ELEMENT_BITS
+ if (elementsPerRegister <= 1) { // unpacked: one element per register, elementOffset is ignored
+ if (sanitizedBits >= registerWidthBits) { // element fills (or exceeds) the register: report the whole register
+ return `[${Math.max(0, registerWidthBits - 1)}:0]`
+ }
+ const endBit = sanitizedBits - 1
+ if (endBit < 0) {
+ return null
+ }
+ return `[${endBit}:0]`
+ }
+
+ const startBit = Math.max(0, elementOffset * sanitizedBits) // negative offsets are clamped to bit 0
+ const endBit = startBit + sanitizedBits - 1
+ if (endBit < startBit) {
+ return null
+ }
+ const clampedEnd = Math.min(endBit, Math.max(0, registerWidthBits - 1)) // never report a bit beyond the top of the register
+ const clampedStart = Math.min(startBit, clampedEnd) // keep start <= end after clamping
+ return `[${clampedEnd}:${clampedStart}]`
+}
diff --git a/src/layouts/shared/RegisterMetadata.ts b/src/layouts/shared/RegisterMetadata.ts
new file mode 100644
index 0000000..c0c981f
--- /dev/null
+++ b/src/layouts/shared/RegisterMetadata.ts
@@ -0,0 +1,114 @@
+/**
+ * Shared helpers for computing register packing metadata that apply to both
+ * WMMA and MFMA visualizations.
+ *
+ * These utilities intentionally avoid architecture-specific quirks so that a
+ * single source of truth can power UI components, verification scripts, and
+ * any future layouts that rely on consistent register metadata.
+ */
+
+import {
+ DEFAULT_ELEMENT_BITS as PACKING_DEFAULT_ELEMENT_BITS,
+ VGPR_REGISTER_WIDTH_BITS,
+ computeElementsPerRegister as computeElementsPerRegisterBase,
+ decomposeLogicalIndex,
+ formatBitRange,
+ getElementSizeBits,
+ type MatmulOperand,
+ type OperandTypeSources,
+} from './PackingUtils'
+
+export { decomposeLogicalIndex, formatBitRange, getElementSizeBits }
+export const DEFAULT_ELEMENT_BITS = PACKING_DEFAULT_ELEMENT_BITS
+export const REGISTER_WIDTH_BITS = VGPR_REGISTER_WIDTH_BITS
+
+/**
+ * Controls how packing metadata should be computed for a given operand.
+ *
+ * Both fields are optional. When they are omitted, computePackingMetadata uses
+ * REGISTER_WIDTH_BITS as the register width and the packing factor derived
+ * from the element width.
+ */
+export interface PackingComputationOptions {
+ /** Override the physical register width in bits (defaults to VGPR width). */
+ registerWidthBits?: number
+ /** Force the packing factor to one element per register. */
+ forceSingleElement?: boolean
+}
+
+/**
+ * Describes how a logical operand packs elements into physical registers.
+ */
+export interface PackingMetadata {
+ elementBits: number // bit width of one logical element
+ elementsPerRegister: number // logical elements stored in one physical register; computePackingMetadata always yields at least 1
+}
+
+/**
+ * Resolved metadata for a single logical element inside a register file.
+ */
+export interface LogicalElementMetadata {
+ physicalRegister: number // floor(logicalIndex / elementsPerRegister)
+ elementOffset: number // slot of the element inside that register (logicalIndex modulo elementsPerRegister)
+ bitRange: string | null // "[end:start]" text from formatBitRange, or null when no range could be formatted
+ elementsPerRegister: number // packing factor copied from the PackingMetadata used for the lookup
+}
+
+/**
+ * Compute packing metadata for an operand independent of architecture quirks.
+ *
+ * `_instruction` and `_operand` are accepted but not read by this function;
+ * the result depends only on `dataType` and `options`. Architecture-specific
+ * rules therefore have to be passed in by the caller through `options` (the
+ * WMMA tab, for example, sets `forceSingleElement` for operand D on version-1
+ * instructions).
+ *
+ * @param _instruction - Unused.
+ * @param _operand - Unused.
+ * @param dataType - Element type name; resolved to a bit width with
+ *   getElementSizeBits (unknown or missing names count as DEFAULT_ELEMENT_BITS).
+ * @param options - Optional register width override and single-element switch.
+ * @returns The element width in bits and the number of elements per register.
+ */
+export function computePackingMetadata(
+ _instruction: T,
+ _operand: MatmulOperand,
+ dataType: string | undefined,
+ options?: PackingComputationOptions
+): PackingMetadata {
+ const elementBits = getElementSizeBits(dataType)
+ const registerWidthBits = options?.registerWidthBits ?? REGISTER_WIDTH_BITS
+ const baseElements = computeElementsPerRegisterBase(registerWidthBits, elementBits)
+ const elementsPerRegister = options?.forceSingleElement ? 1 : baseElements
+ return { elementBits, elementsPerRegister }
+}
+
+/**
+ * Convenience wrapper that mirrors previous helpers and returns packed metadata.
+ *
+ * Forwards all four arguments unchanged to computePackingMetadata and returns
+ * its result; it adds no behavior of its own.
+ */
+export function createPackingMetadata(
+ instruction: T,
+ operand: MatmulOperand,
+ dataType: string | undefined,
+ options?: PackingComputationOptions
+): PackingMetadata {
+ return computePackingMetadata(instruction, operand, dataType, options)
+}
+
+/**
+ * Translate a logical element index into physical register metadata.
+ *
+ * @param logicalIndex - Index of the logical element; negative values are
+ *   rejected.
+ * @param metadata - Packing description supplying `elementBits` and
+ *   `elementsPerRegister`.
+ * @returns The register number, slot offset and bit range of the element, or
+ *   null when `logicalIndex` is negative, `elementsPerRegister` is not
+ *   positive, or the index decomposition throws.
+ */
+export function getLogicalElementMetadata(
+ logicalIndex: number,
+ metadata: PackingMetadata
+): LogicalElementMetadata | null {
+ if (logicalIndex < 0) {
+ return null
+ }
+
+ const { elementsPerRegister, elementBits } = metadata
+ if (elementsPerRegister <= 0) {
+ return null
+ }
+
+ let physicalRegister = 0
+ let elementOffset = 0
+ try {
+ const decomposition = decomposeLogicalIndex(logicalIndex, elementsPerRegister)
+ physicalRegister = decomposition.physicalRegister
+ elementOffset = decomposition.elementOffset
+ } catch { // decomposeLogicalIndex throws for a non-positive packing factor; report "no metadata" instead
+ return null
+ }
+
+ const bitRange = formatBitRange(elementOffset, elementBits, elementsPerRegister) // no width argument: formatBitRange clamps to its default VGPR_REGISTER_WIDTH_BITS
+ return {
+ physicalRegister,
+ elementOffset,
+ bitRange,
+ elementsPerRegister,
+ }
+}
diff --git a/src/main.tabs.test.ts b/src/main.tabs.test.ts
index 5be77af..bd94933 100644
--- a/src/main.tabs.test.ts
+++ b/src/main.tabs.test.ts
@@ -74,6 +74,16 @@ vi.mock('./tabs/WMMALayoutTab', () => ({
WMMALayoutTab: WMMALayoutTabMock,
}))
+const MFMALayoutTabMock = vi.fn().mockImplementation(() => ({ // constructor stub: every instance exposes activate/deactivate/resize spies
+ activate: vi.fn(),
+ deactivate: vi.fn(),
+ resize: vi.fn(),
+}))
+
+vi.mock('./tabs/MFMALayoutTab', () => ({ // replace the real MFMALayoutTab export with the stub above
+ MFMALayoutTab: MFMALayoutTabMock,
+}))
+
const setupDom = () => {
document.body.innerHTML = ''
vi.stubGlobal('alert', vi.fn())
@@ -238,6 +248,56 @@ const setupDom = () => {
return tabContent
}
+ const createMFMALayoutContent = () => { // builds the #mfma-layout tab DOM: sidebar form with three selects, a [data-controls] box, and the canvas section
+ const tabContent = document.createElement('div')
+ tabContent.className = 'tab-content'
+ tabContent.id = 'mfma-layout' // same id that main.ts passes to the MFMALayoutTab constructor
+
+ const contentWrapper = document.createElement('div')
+ contentWrapper.className = 'content'
+
+ const sidebar = document.createElement('aside')
+ sidebar.className = 'sidebar'
+
+ const form = document.createElement('form')
+ form.id = 'mfmaForm'
+
+ const architectureSelect = document.createElement('select')
+ architectureSelect.id = 'mfma-architecture'
+ architectureSelect.innerHTML = `
+
+
+
+ `
+ form.appendChild(architectureSelect)
+
+ const instructionSelect = document.createElement('select')
+ instructionSelect.id = 'mfma-instruction' // left empty here; the tab controller fills in the options
+ form.appendChild(instructionSelect)
+
+ const operandSelect = document.createElement('select')
+ operandSelect.id = 'mfma-operand'
+ operandSelect.innerHTML = `
+
+
+
+ `
+ form.appendChild(operandSelect)
+
+ sidebar.appendChild(form)
+
+ const controls = document.createElement('div')
+ controls.className = 'controls'
+ controls.setAttribute('data-controls', '')
+ sidebar.appendChild(controls)
+
+ contentWrapper.appendChild(sidebar)
+ contentWrapper.appendChild(createVisualizationSection('mfma-canvas'))
+
+ tabContent.appendChild(contentWrapper)
+ return tabContent
+ }
+
const container = document.createElement('div')
container.className = 'container'
document.body.appendChild(container)
@@ -250,6 +310,7 @@ const setupDom = () => {
{ id: 'block-layout', label: 'Block Layout', active: true },
{ id: 'shared-layout', label: 'Shared Layout' },
{ id: 'wmma-layout', label: 'WMMA Layout' },
+ { id: 'mfma-layout', label: 'MFMA Layout' },
{ id: 'linear-layout', label: 'Linear Layout' },
{ id: 'ck-layout', label: 'CK Layout' },
]
@@ -261,6 +322,7 @@ const setupDom = () => {
container.appendChild(createBlockLayoutContent())
container.appendChild(createPlaceholderTabContent('shared-layout'))
container.appendChild(createWMMALayoutContent())
+ container.appendChild(createMFMALayoutContent())
container.appendChild(createPlaceholderTabContent('linear-layout'))
container.appendChild(createPlaceholderTabContent('ck-layout'))
diff --git a/src/main.ts b/src/main.ts
index 4f880fb..1221243 100644
--- a/src/main.ts
+++ b/src/main.ts
@@ -1,5 +1,6 @@
import { BlockLayoutTab } from './tabs/BlockLayoutTab'
import { WMMALayoutTab } from './tabs/WMMALayoutTab'
+import { MFMALayoutTab } from './tabs/MFMALayoutTab'
type TabController = {
activate(): void
@@ -20,6 +21,7 @@ if (tabContents.length === 0) {
const controllers: Map = new Map()
controllers.set('block-layout', new BlockLayoutTab('block-layout'))
controllers.set('wmma-layout', new WMMALayoutTab('wmma-layout'))
+controllers.set('mfma-layout', new MFMALayoutTab('mfma-layout'))
let currentTabId: string | null = null
diff --git a/src/tabs/MFMALayoutTab.ts b/src/tabs/MFMALayoutTab.ts
new file mode 100644
index 0000000..f850f5d
--- /dev/null
+++ b/src/tabs/MFMALayoutTab.ts
@@ -0,0 +1,268 @@
+import {
+ createMFMALayout,
+ getMFMAInstructions,
+ getMFMAPositionsForThread,
+ resolveOperandDataType,
+ type MFMAArchitecture,
+ type MFMAOperand,
+ type MFMALayoutParams,
+} from '../layouts/MFMALayout'
+import {
+ DEFAULT_ELEMENT_BITS,
+ createPackingMetadata,
+ getLogicalElementMetadata,
+ type PackingMetadata,
+} from '../layouts/shared/RegisterMetadata'
+import type { LinearLayout } from '../core/LinearLayout'
+import { CanvasRenderer } from '../visualization/CanvasRenderer'
+import { renderSharedControls } from '../ui/renderSharedControls'
+import { CanvasTab, type CanvasTabElements } from './CanvasTab'
+
+/**
+ * Controller for MFMA layout visualizations. Mirrors the WMMA tab flow while
+ * adapting architecture/instruction choices and wave64 semantics.
+ *
+ * Construction: the constructor looks up the tab root by id, then the
+ * `.visualization` container, its canvas, the `[data-controls]` container,
+ * the `#mfmaForm` form and the three selects inside it. It throws an Error
+ * naming the first element that is missing or of the wrong type. Afterwards it
+ * registers change listeners, fills the instruction list and renders once.
+ *
+ * Rendering: every change of architecture, instruction or operand rebuilds
+ * the layout and a new CanvasRenderer. If that fails, the error is logged,
+ * shown through alert(), and the packing metadata is reset to one
+ * DEFAULT_ELEMENT_BITS-wide element per register.
+ *
+ * Hover: the tooltip shows the logical position, thread, physical register,
+ * bit range and element index of the cell under the pointer, using the
+ * packing metadata computed during the last successful parameter read.
+ */
+export class MFMALayoutTab extends CanvasTab {
+ private readonly architectureSelect: HTMLSelectElement
+ private readonly instructionSelect: HTMLSelectElement
+ private readonly operandSelect: HTMLSelectElement
+
+ private packingMetadata: PackingMetadata = {
+ elementBits: DEFAULT_ELEMENT_BITS,
+ elementsPerRegister: 1,
+ }
+
+ private static readonly WAVE_SIZE = 64 // passed as threadsPerWarp[0] in updateVisualization
+
+ constructor(tabId: string) {
+ const tabContent = document.getElementById(tabId)
+ if (!tabContent) {
+ throw new Error(`MFMALayoutTab container not found: ${tabId}`)
+ }
+
+ const visualizationContainer = tabContent.querySelector('.visualization')
+ if (!(visualizationContainer instanceof HTMLElement)) {
+ throw new Error('MFMALayoutTab visualization container not found')
+ }
+
+ const canvas = visualizationContainer.querySelector('canvas')
+ if (!(canvas instanceof HTMLCanvasElement)) {
+ throw new Error('MFMALayoutTab canvas element not found')
+ }
+
+ const controlsContainer = tabContent.querySelector('[data-controls]')
+ if (!(controlsContainer instanceof HTMLElement)) {
+ throw new Error('MFMALayoutTab controls container not found')
+ }
+ const resetButton = renderSharedControls(controlsContainer, { resetButtonId: 'mfma-reset' })
+
+ const form = tabContent.querySelector('#mfmaForm')
+ if (!(form instanceof HTMLFormElement)) {
+ throw new Error('MFMALayoutTab form element not found')
+ }
+
+ const architectureSelect = form.querySelector('#mfma-architecture')
+ if (!(architectureSelect instanceof HTMLSelectElement)) {
+ throw new Error('MFMALayoutTab architecture select not found')
+ }
+
+ const instructionSelect = form.querySelector('#mfma-instruction')
+ if (!(instructionSelect instanceof HTMLSelectElement)) {
+ throw new Error('MFMALayoutTab instruction select not found')
+ }
+
+ const operandSelect = form.querySelector('#mfma-operand')
+ if (!(operandSelect instanceof HTMLSelectElement)) {
+ throw new Error('MFMALayoutTab operand select not found')
+ }
+
+ const elements: CanvasTabElements = {
+ root: tabContent,
+ canvas,
+ visualizationContainer,
+ resetButton,
+ }
+
+ super(elements)
+
+ this.architectureSelect = architectureSelect
+ this.instructionSelect = instructionSelect
+ this.operandSelect = operandSelect
+
+ this.setupEventListeners()
+ this.populateInstructionOptions()
+ this.updateVisualization()
+ }
+
+ protected resetHover(): void {
+ this.hideTooltip()
+ }
+
+ protected handleHover(event: MouseEvent): void {
+ const renderer = this.getRenderer()
+ if (!renderer) {
+ this.hideTooltip()
+ return
+ }
+
+ const rect = this.canvas.getBoundingClientRect()
+ const x = event.clientX - rect.left // pointer position relative to the canvas
+ const y = event.clientY - rect.top
+ const gridPos = renderer.screenToGrid(x, y)
+ const cellInfo = renderer.getCellInfo(gridPos.row, gridPos.col)
+
+ if (!cellInfo) {
+ this.hideTooltip()
+ return
+ }
+
+ const elementMetadata = this.getRegisterMetadata(cellInfo.registerId) // registerId is used as the logical element index for the packing lookup
+ const logicalIndex = `(${cellInfo.sourcePosition[0]}, ${cellInfo.sourcePosition[1]})`
+
+ let registerLabel = 'N/A'
+ let bitRangeLabel = 'N/A'
+ let elementLabel = 'N/A'
+
+ if (elementMetadata) {
+ registerLabel = `v${elementMetadata.physicalRegister}`
+ bitRangeLabel = elementMetadata.bitRange ?? 'N/A'
+ const elementIndex =
+ elementMetadata.physicalRegister * elementMetadata.elementsPerRegister +
+ elementMetadata.elementOffset
+ elementLabel = elementIndex.toString()
+ }
+
+ const tooltipLines = [
+ `Logical Index: ${logicalIndex}
`,
+ `Thread: ${cellInfo.threadId}
`,
+ `Register: ${registerLabel}
`,
+ `Bit Range: ${bitRangeLabel}
`,
+ `Element: ${elementLabel}
`,
+ ]
+
+ this.tooltip.show(tooltipLines.join(''), event.clientX, event.clientY)
+ }
+
+ private setupEventListeners(): void {
+ this.architectureSelect.addEventListener('change', () => {
+ this.populateInstructionOptions() // the instruction list depends on the architecture, so rebuild it first
+ this.updateVisualization()
+ })
+
+ this.instructionSelect.addEventListener('change', () => {
+ this.updateVisualization()
+ })
+
+ this.operandSelect.addEventListener('change', () => {
+ this.updateVisualization()
+ })
+ }
+
+ private populateInstructionOptions(): void {
+ const architecture = this.architectureSelect.value as MFMAArchitecture
+ const instructions = getMFMAInstructions(architecture)
+
+ this.instructionSelect.innerHTML = ''
+ instructions.forEach((instr, idx) => {
+ const option = document.createElement('option')
+ option.value = idx.toString() // index into getMFMAInstructions(architecture); parsed back in getParams
+ const inputTypeLabel =
+ instr.inputTypeA === instr.inputTypeB
+ ? instr.inputTypeA
+ : `${instr.inputTypeA}/${instr.inputTypeB}`
+ option.textContent = `${instr.mnemonic} (${instr.m}×${instr.n}×${instr.k}, ${inputTypeLabel} → ${instr.outputType})`
+ this.instructionSelect.appendChild(option)
+ })
+
+ this.instructionSelect.selectedIndex = 0
+ }
+
+ private getParams(): MFMALayoutParams {
+ const architecture = this.architectureSelect.value as MFMAArchitecture
+ const instructions = getMFMAInstructions(architecture)
+ const instructionIdx = parseInt(this.instructionSelect.value, 10) || 0 // an unparseable value (NaN) falls back to index 0
+ const instruction = instructions[instructionIdx]
+ const operand = this.operandSelect.value as MFMAOperand
+
+ if (!instruction) {
+ throw new Error('No MFMA instruction selected')
+ }
+
+ return {
+ architecture,
+ instruction,
+ operand,
+ }
+ }
+
+ private updateVisualization(): void {
+ try {
+ const params = this.getParams()
+ const dataType = resolveOperandDataType(params.instruction, params.operand)
+ this.updatePackingMetadata(params, dataType)
+
+ this.resizeCanvas()
+
+ const layout = createMFMALayout(params)
+ const tensorShape = this.getTensorShape(params)
+
+ const blockParams = {
+ sizePerThread: [1, 1] as [number, number],
+ threadsPerWarp: [MFMALayoutTab.WAVE_SIZE, 1] as [number, number],
+ warpsPerCTA: [1, 1] as [number, number],
+ order: [0, 1] as [number, number],
+ tensorShape,
+ }
+
+ const positionResolver = (layoutData: LinearLayout, threadId: number) => {
+ const positions = getMFMAPositionsForThread(
+ layoutData,
+ params.instruction,
+ params.operand,
+ threadId
+ )
+
+ return positions.map(({ pos, registerId }) => ({
+ pos: [pos[0], pos[1]] as [number, number],
+ registerId,
+ sourcePos: [pos[0], pos[1]] as [number, number], // same coordinates as pos
+ }))
+ }
+
+ const renderer = new CanvasRenderer(this.canvas, layout, blockParams, positionResolver, {
+ colorGrouping: 'thread',
+ })
+ this.setRenderer(renderer)
+ renderer.render()
+ } catch (error) {
+ this.packingMetadata = { // fall back to the unpacked default (one element per register)
+ elementBits: DEFAULT_ELEMENT_BITS,
+ elementsPerRegister: 1,
+ }
+ console.error('Failed to create MFMA visualization:', error)
+ alert(`Failed to create MFMA visualization: ${error}`)
+ }
+ }
+
+ private getTensorShape(params: MFMALayoutParams): [number, number] {
+ if (params.operand === 'D') {
+ return [params.instruction.m, params.instruction.n] // D is M x N
+ }
+ if (params.operand === 'A') {
+ return [params.instruction.m, params.instruction.k] // A is M x K
+ }
+ return [params.instruction.k, params.instruction.n] // B is K x N
+ }
+
+ private updatePackingMetadata(params: MFMALayoutParams, dataType: string | undefined): void {
+ this.packingMetadata = createPackingMetadata( // no options passed: default register width, no forced single-element packing
+ params.instruction,
+ params.operand,
+ dataType
+ )
+ }
+
+ private getRegisterMetadata(logicalIndex: number) {
+ return getLogicalElementMetadata(logicalIndex, this.packingMetadata)
+ }
+}
diff --git a/src/tabs/WMMALayoutTab.ts b/src/tabs/WMMALayoutTab.ts
index e533015..05e2c43 100644
--- a/src/tabs/WMMALayoutTab.ts
+++ b/src/tabs/WMMALayoutTab.ts
@@ -8,11 +8,11 @@ import {
type WMMALayoutParams,
} from '../layouts/WMMALayout'
import {
- DEFAULT_ELEMENT_BITS,
createPackingMetadata,
getLogicalElementMetadata,
type PackingMetadata,
-} from '../layouts/WMMAPacking'
+} from '../layouts/shared/RegisterMetadata'
+import { DEFAULT_ELEMENT_BITS } from '../layouts/WMMAPacking'
import type { LinearLayout } from '../core/LinearLayout'
import { CanvasRenderer } from '../visualization/CanvasRenderer'
import { renderSharedControls } from '../ui/renderSharedControls'
@@ -212,7 +212,14 @@ export class WMMALayoutTab extends CanvasTab {
try {
const params = this.getParams()
const dataType = resolveOperandDataType(params.instruction, params.operand)
- this.packingMetadata = createPackingMetadata(params.instruction, params.operand, dataType)
+ const forceSingleElement =
+ params.operand === 'D' && params.instruction.version < 2
+ this.packingMetadata = createPackingMetadata(
+ params.instruction,
+ params.operand,
+ dataType,
+ { forceSingleElement }
+ )
const replicationConfig = this.getReplicationConfig(params)
this.resizeCanvas()
diff --git a/verification/build-helper.sh b/verification/build-helper.sh
index 56723d8..3442aa3 100755
--- a/verification/build-helper.sh
+++ b/verification/build-helper.sh
@@ -7,20 +7,26 @@ set -e
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"
-echo "Building verification helper module..."
+echo "Building verification helper modules..."
cd "$PROJECT_ROOT"
-# Compile the helper module to a bundled .mjs file
-node_modules/.bin/esbuild \
- verification/lib/verification-helper.ts \
- --bundle \
- --platform=node \
- --format=esm \
- --target=es2020 \
- --outfile=verification/lib/verification-helper.mjs
+declare -a ENTRYPOINTS=(
+ "verification/lib/verification-helper.ts:verification/lib/verification-helper.mjs"
+ "verification/lib/mfma-verification-helper.ts:verification/lib/mfma-verification-helper.mjs"
+)
-# Make it executable
-chmod +x verification/lib/verification-helper.mjs
+for entry in "${ENTRYPOINTS[@]}"; do
+ IFS=":" read -r SOURCE TARGET <<< "$entry"
+ echo " - Bundling $SOURCE -> $TARGET"
+ node_modules/.bin/esbuild \
+ "$SOURCE" \
+ --bundle \
+ --platform=node \
+ --format=esm \
+ --target=es2020 \
+ --outfile="$TARGET"
+ chmod +x "$TARGET"
+done
-echo "✓ Verification helper built successfully: verification/lib/verification-helper.mjs"
+echo "✓ Verification helpers built successfully."
diff --git a/verification/lib/mfma-verification-helper.ts b/verification/lib/mfma-verification-helper.ts
new file mode 100644
index 0000000..f6c0165
--- /dev/null
+++ b/verification/lib/mfma-verification-helper.ts
@@ -0,0 +1,212 @@
+#!/usr/bin/env node
+/**
+ * Verification helper for MFMA layouts.
+ *
+ * Mirrors the WMMA verification helper structure but targets MFMA instructions
+ * defined in src/layouts/MFMALayout.ts. Generates physical lane/register/bit
+ * mappings for consumption by the Python verification script.
+ */
+
+import { pathToFileURL } from 'url'
+
+import {
+ CDNA1_MFMA_INSTRUCTIONS,
+ CDNA2_MFMA_INSTRUCTIONS,
+ CDNA3_MFMA_INSTRUCTIONS,
+ createMFMAOperandLayout,
+ resolveOperandDataType,
+ type MFMAArchitecture,
+ type MFMAInstruction,
+ type MFMAOperand,
+} from '../../src/layouts/MFMALayout'
+import {
+ REGISTER_WIDTH_BITS as SHARED_REGISTER_WIDTH_BITS,
+ computePackingMetadata,
+ decomposeLogicalIndex,
+ formatBitRange,
+} from '../../src/layouts/shared/RegisterMetadata'
+
+const MFMA_WAVE_SIZE = 64
+const REGISTER_WIDTH_BITS = SHARED_REGISTER_WIDTH_BITS
+const REGISTER_DIM = 'register'
+const LANE_DIM = 'lane'
+const DIM_M = 'dimM'
+const DIM_N = 'dimN'
+const DIM_K = 'dimK'
+
+/**
+ * Maps each architecture key to the instruction list imported from MFMALayout.
+ * The `readonly` element type accepts both mutable and readonly source arrays.
+ */
+const ARCH_TO_INSTRUCTIONS: Record<MFMAArchitecture, readonly MFMAInstruction[]> = {
+  cdna1: CDNA1_MFMA_INSTRUCTIONS,
+  cdna2: CDNA2_MFMA_INSTRUCTIONS,
+  cdna3: CDNA3_MFMA_INSTRUCTIONS,
+}
+
+/**
+ * One physical storage location reported for a matrix position.
+ */
+export interface LaneElementMapping {
+ // Logical element index; generateLayoutMappings offsets it by
+ // `block * registerCount` for each additional register block.
+ elementIndex: number
+ // Lane that holds the element.
+ lane: number
+ // Physical register index within that lane.
+ physicalRegister: number
+ // Bit range inside the register, formatted as a string (may be empty).
+ bitRange: string
+}
+
+/**
+ * Storage locations grouped by matrix position. generateLayoutMappings
+ * builds each key as `${row}_${col}`.
+ */
+export interface LayoutMappings {
+ [position: string]: LaneElementMapping[]
+}
+
+/**
+ * Selects the (row, col) coordinates of an operand from a layout result.
+ *
+ * D reads (dimM, dimN), A reads (dimM, dimK) and any other operand reads
+ * (dimK, dimN). A dimension that is absent from `result` counts as 0.
+ *
+ * @param operand - Operand whose matrix position is requested.
+ * @param result - Output-dimension values keyed by dimension name; values may
+ *   be undefined because wrapIndex can pass an undefined value through.
+ * @returns The row/column pair for the operand.
+ */
+function getMatrixPosition(
+  operand: MFMAOperand,
+  result: Record<string, number | undefined>
+): { row: number; col: number } {
+  if (operand === 'D') {
+    return { row: result[DIM_M] ?? 0, col: result[DIM_N] ?? 0 }
+  }
+  if (operand === 'A') {
+    return { row: result[DIM_M] ?? 0, col: result[DIM_K] ?? 0 }
+  }
+  return { row: result[DIM_K] ?? 0, col: result[DIM_N] ?? 0 }
+}
+
+/**
+ * Folds `value` into the range [0, size) with a non-negative modulo.
+ * A non-numeric `value` or a non-positive `size` is returned unchanged.
+ */
+function wrapIndex(value: number | undefined, size: number): number | undefined {
+  if (typeof value === 'number' && size > 0) {
+    // Adding `size` before the second modulo keeps negative inputs non-negative.
+    return ((value % size) + size) % size
+  }
+  return value
+}
+
+/** Reports whether the instruction's M and N dimensions are both 4. */
+function usesShared4x4Registers(instruction: MFMAInstruction): boolean {
+  const { m, n } = instruction
+  return m === 4 && n === 4
+}
+
+/**
+ * Enumerates every (lane, logical register index) pair of an operand layout
+ * and records where the resulting matrix element is stored.
+ *
+ * For each pair the layout is applied to obtain dimM/dimN/dimK values, the
+ * values are wrapped into the instruction's dimensions, and one or more
+ * LaneElementMapping entries are appended under the key `${row}_${col}`.
+ *
+ * @param instruction - Instruction whose layout is enumerated.
+ * @param operand - Operand passed to the layout and packing helpers.
+ * @returns Storage entries grouped by matrix position.
+ */
+export function generateLayoutMappings(
+ instruction: MFMAInstruction,
+ operand: MFMAOperand
+): LayoutMappings {
+ const layout = createMFMAOperandLayout(instruction, operand)
+
+ // Fall back to instruction.k / MFMA_WAVE_SIZE when the layout does not
+ // declare the corresponding input dimension.
+ const registerCount = layout.hasInDim(REGISTER_DIM)
+ ? layout.getInDimSize(REGISTER_DIM)
+ : instruction.k
+ const laneCount = layout.hasInDim(LANE_DIM) ? layout.getInDimSize(LANE_DIM) : MFMA_WAVE_SIZE
+
+ const dataType = resolveOperandDataType(instruction, operand)
+ const baseMetadata = computePackingMetadata(instruction, operand, dataType, {
+ registerWidthBits: REGISTER_WIDTH_BITS,
+ })
+ // Elements wider than one register occupy several consecutive registers.
+ const registersPerElement = Math.max(
+ 1,
+ Math.ceil(baseMetadata.elementBits / REGISTER_WIDTH_BITS)
+ )
+ // A multi-register element is never packed with others in one register.
+ const elementsPerRegister =
+ registersPerElement > 1 ? 1 : baseMetadata.elementsPerRegister
+ const elementBits = baseMetadata.elementBits
+
+ const mappings: LayoutMappings = {}
+
+ for (let lane = 0; lane < laneCount; lane++) {
+ for (let logicalIndex = 0; logicalIndex < registerCount; logicalIndex++) {
+ const result = layout.apply({
+ [REGISTER_DIM]: logicalIndex,
+ [LANE_DIM]: lane,
+ })
+
+ // When lane groups share K, fold dimK back into [0, instruction.k).
+ if (
+ instruction.laneGroupsShareDimK &&
+ typeof result[DIM_K] === 'number' &&
+ instruction.k > 0
+ ) {
+ const currentK = result[DIM_K]
+ const wrapped = ((currentK % instruction.k) + instruction.k) % instruction.k
+ result[DIM_K] = wrapped
+ }
+
+ // dimM is wrapped for operands D and A, dimN for operands D and B.
+ const shouldWrapM = (operand === 'D' || operand === 'A') && instruction.m > 0
+ const shouldWrapN = (operand === 'D' || operand === 'B') && instruction.n > 0
+ if (shouldWrapM) {
+ result[DIM_M] = wrapIndex(result[DIM_M], instruction.m)
+ }
+ if (shouldWrapN) {
+ result[DIM_N] = wrapIndex(result[DIM_N], instruction.n)
+ }
+
+ // For multi-register elements on 4x4 instructions the layout result is
+ // overridden with coordinates derived directly from the lane id.
+ if (registersPerElement > 1 && usesShared4x4Registers(instruction)) {
+ const laneGroup = Math.floor(lane / 16)
+ const laneQuad = lane % 4
+ if (operand === 'A') {
+ result[DIM_M] = wrapIndex(laneQuad, instruction.m)
+ result[DIM_K] = wrapIndex(laneGroup, instruction.k)
+ } else if (operand === 'B') {
+ result[DIM_K] = wrapIndex(laneGroup, instruction.k)
+ result[DIM_N] = wrapIndex(laneQuad, instruction.n)
+ } else {
+ result[DIM_M] = wrapIndex(laneGroup, instruction.m)
+ result[DIM_N] = wrapIndex(laneQuad, instruction.n)
+ }
+ }
+
+ const { row, col } = getMatrixPosition(operand, result)
+ const { physicalRegister, elementOffset } = decomposeLogicalIndex(
+ logicalIndex,
+ elementsPerRegister
+ )
+ const bitRange = formatBitRange(elementOffset, elementBits, elementsPerRegister) ?? ''
+
+ const key = `${row}_${col}`
+ if (!mappings[key]) {
+ mappings[key] = []
+ }
+
+ // Only operand D is replicated across register blocks; the D operand of
+ // a 4x4 instruction uses a single block with a register stride of 0.
+ const shareRegisters = operand === 'D' && usesShared4x4Registers(instruction)
+ const registerBlocks = operand === 'D'
+ ? shareRegisters
+ ? 1
+ : Math.max(1, instruction.registerBlocks ?? 1)
+ : 1
+ const registerStride = shareRegisters ? 0 : registerCount * registersPerElement
+ const baseRegister =
+ registersPerElement > 1
+ ? Math.floor(physicalRegister) * registersPerElement
+ : physicalRegister
+ const fullRange = `[${REGISTER_WIDTH_BITS - 1}:0]`
+ // A multi-register element yields one full-width entry per register;
+ // otherwise a single entry carries the packed bit range.
+ const chunks = registersPerElement > 1
+ ? Array.from({ length: registersPerElement }, (_, idx) => ({ offset: idx, range: fullRange }))
+ : [{ offset: 0, range: bitRange }]
+
+ for (let block = 0; block < registerBlocks; block++) {
+ const blockOffset = block * registerStride
+ for (const chunk of chunks) {
+ mappings[key].push({
+ elementIndex: logicalIndex + block * registerCount,
+ lane,
+ physicalRegister: baseRegister + chunk.offset + blockOffset,
+ bitRange: chunk.range,
+ })
+ }
+ }
+ }
+ }
+
+ return mappings
+}
+
+/**
+ * Returns the architecture-to-instruction table. The CLI serialises this
+ * value as JSON for the `listInstructions` action.
+ */
+export function getInstructionMetadata() {
+ return ARCH_TO_INSTRUCTIONS
+}
+
+/**
+ * CLI entry point.
+ *
+ * Expects exactly one argument: a JSON document. `{"action":
+ * "listInstructions"}` prints the instruction tables; otherwise the document
+ * must carry `instruction` and `operand`, and the generated layout mappings
+ * are printed. All results go to stdout as JSON; problems are reported on
+ * stderr with exit code 1.
+ */
+function main() {
+  const args = process.argv.slice(2)
+  if (args.length === 0) {
+    console.error('Usage: mfma-verification-helper.ts <json-config>')
+    process.exit(1)
+  }
+
+  let config
+  try {
+    config = JSON.parse(args[0])
+  } catch (error) {
+    // Report malformed input as a one-line message rather than a stack trace.
+    const reason = error instanceof Error ? error.message : String(error)
+    console.error(`Invalid JSON config: ${reason}`)
+    process.exit(1)
+  }
+
+  if (config?.action === 'listInstructions') {
+    console.log(JSON.stringify(getInstructionMetadata()))
+    return
+  }
+
+  const { instruction, operand } = config ?? {}
+  if (!instruction || !operand) {
+    console.error('Config must provide "instruction" and "operand".')
+    process.exit(1)
+  }
+
+  const mappings = generateLayoutMappings(instruction, operand)
+  console.log(JSON.stringify(mappings))
+}
+
+// Run main() only when this module is the script Node was started with.
+// pathToFileURL percent-encodes the path the same way import.meta.url is
+// encoded, so script paths containing spaces or other reserved characters
+// still compare equal.
+if (process.argv[1] && import.meta.url === pathToFileURL(process.argv[1]).href) {
+  main()
+}
diff --git a/verification/lib/verification-helper.ts b/verification/lib/verification-helper.ts
index b318c71..0cfdbbd 100644
--- a/verification/lib/verification-helper.ts
+++ b/verification/lib/verification-helper.ts
@@ -21,12 +21,12 @@ import {
createWMMAOperandLayout,
resolveOperandDataType,
} from '../../src/layouts/WMMALayout'
+import { computeElementsPerRegister } from '../../src/layouts/WMMAPacking'
import {
- computeElementsPerRegister,
decomposeLogicalIndex,
formatBitRange,
getElementSizeBits,
-} from '../../src/layouts/WMMAPacking'
+} from '../../src/layouts/shared/RegisterMetadata'
/**
* Represents a single lane/element mapping with associated metadata
diff --git a/verification/verify_mfma_layouts.py b/verification/verify_mfma_layouts.py
new file mode 100755
index 0000000..1d53d3f
--- /dev/null
+++ b/verification/verify_mfma_layouts.py
@@ -0,0 +1,827 @@
+#!/usr/bin/env python3
+"""
+Clean MFMA layout verification script.
+
+Compares our TypeScript linear layout implementation against AMD's official calculator
+by using a pre-built verification helper module (no dynamic code generation).
+
+The ground truth is based on ../amd_matrix_instruction_calculator
+The goal is to ensure the MFMA linear layout is correct compared to the ground truth.
+"""
+
+import argparse
+import subprocess
+import json
+import sys
+import re
+from typing import Any, Dict, List, Tuple
+from pathlib import Path
+
+OPERANDS = ['A', 'B', 'D']
+
+REGISTER_WIDTH_BITS = 32
+FULL_REGISTER_BIT_RANGE = f'[{REGISTER_WIDTH_BITS - 1}:0]'
+
+# Project paths
+SCRIPT_DIR = Path(__file__).parent
+PROJECT_ROOT = SCRIPT_DIR.parent
+HELPER_PATH = SCRIPT_DIR / 'lib' / 'mfma-verification-helper.mjs'
+
+def resolve_calculator_path() -> Path:
+    """
+    Find matrix_calculator.py inside an amd_matrix_instruction_calculator checkout.
+
+    The directory next to the project root is tried first, then the one inside
+    the project root. Raises FileNotFoundError when neither contains the script.
+    """
+    checkout_name = 'amd_matrix_instruction_calculator'
+    script_name = 'matrix_calculator.py'
+    search_roots = (PROJECT_ROOT.parent, PROJECT_ROOT)
+
+    scripts = (root / checkout_name / script_name for root in search_roots)
+    located = next((script for script in scripts if script.exists()), None)
+    if located is None:
+        raise FileNotFoundError(
+            "Unable to locate amd_matrix_instruction_calculator. "
+            "Run 'git submodule update --init --recursive' from the project root."
+        )
+    return located
+
+CALCULATOR_PATH = resolve_calculator_path()
+
+# Matches "N" or "N:M", optionally wrapped in square brackets. The named
+# groups 'first' and 'second' are the ones normalize_bit_range reads.
+BIT_RANGE_PATTERN = re.compile(r'^\[?(?P<first>\d+)(?::(?P<second>\d+))?\]?$', re.IGNORECASE)
+
+
+def normalize_bit_range(bit_range: str) -> str:
+    """
+    Rewrite a bit range string as "[high:low]".
+
+    Examples:
+        "15:0" -> "[15:0]"
+        "[0:15]" -> "[15:0]"
+        "[7]" -> "[7:7]"
+        "" -> ""
+
+    Text that does not match BIT_RANGE_PATTERN is returned as-is apart from
+    space removal, gaining surrounding brackets when it does not start with '['.
+    """
+    text = (bit_range or '').strip()
+    if not text:
+        return ''
+
+    text = text.replace(' ', '')
+    parsed = BIT_RANGE_PATTERN.match(text)
+    if parsed is None:
+        return text if text.startswith('[') else f'[{text}]'
+
+    # A single number describes a one-bit range, so it is both bounds.
+    bounds = [int(parsed.group('first'))]
+    if parsed.group('second') is not None:
+        bounds.append(int(parsed.group('second')))
+
+    return f'[{max(bounds)}:{min(bounds)}]'
+
+
+def extract_bit_bounds(bit_range: str) -> Tuple[int, int]:
+    """
+    Read the two integers of a "[high:low]" string, in the order written.
+
+    An empty or missing range gives (0, 0); a single number gives that number
+    twice. Surrounding square brackets are optional.
+    """
+    text = (bit_range or '').strip()
+    if not text:
+        return 0, 0
+
+    if text[0] == '[' and text[-1] == ']':
+        text = text[1:-1]
+
+    left, separator, right = text.partition(':')
+    if not separator:
+        single = int(text)
+        return single, single
+    return int(left), int(right)
+
+
+def bit_range_sort_key(bit_range: str) -> Tuple[int, int]:
+    """
+    Sort key for bit ranges: the second bound returned by extract_bit_bounds
+    comes first, so entries order by their low bit, then their high bit.
+    """
+    bounds = extract_bit_bounds(bit_range)
+    return bounds[1], bounds[0]
+
+
+def physical_key(entry: Dict[str, Any]) -> Tuple[int, int, str]:
+    """
+    Build the (lane, register, bitRange) tuple that identifies where an entry
+    is physically stored.
+    """
+    return tuple(entry[field] for field in ('lane', 'register', 'bitRange'))
+
+
+def parse_args() -> argparse.Namespace:
+    """Declare the command-line options and parse the process arguments."""
+    cli = argparse.ArgumentParser(
+        description="Verify MFMA operand layouts against AMD's matrix calculator."
+    )
+
+    # (flag, keyword arguments) pairs, registered in this order.
+    option_specs = [
+        ('--arch', dict(
+            choices=['cdna1', 'cdna2', 'cdna3', 'all'],
+            default='all',
+            help="Limit verification to a specific architecture.",
+        )),
+        ('--instruction', dict(
+            action='append',
+            default=None,
+            help="Only run specific instruction names (can be repeated or comma-separated).",
+        )),
+        ('--operand', dict(
+            choices=['A', 'B', 'D', 'all'],
+            default='all',
+            help="Limit verification to a single operand.",
+        )),
+        ('--verbose', dict(
+            action='store_true',
+            help="Print detailed mismatch information for failing cases.",
+        )),
+        ('--max-diffs', dict(
+            type=int,
+            default=10,
+            help="Maximum number of mismatches to record and display per failing case.",
+        )),
+    ]
+    for flag, settings in option_specs:
+        cli.add_argument(flag, **settings)
+
+    return cli.parse_args()
+
+
+def ensure_helper_built(prefix: str = '') -> bool:
+    """
+    Run build-helper.sh so the bundled helper modules exist.
+
+    The build runs at most once per process: after the first success later
+    calls return True immediately, so callers may invoke this before every
+    helper run without paying for a rebuild each time.
+
+    Args:
+        prefix: Text placed in front of each progress/error line.
+
+    Returns:
+        True when the helper is built, False when the build script could not
+        be started or exited with a non-zero status.
+    """
+    if getattr(ensure_helper_built, '_built', False):
+        return True
+
+    indent = prefix or ''
+    print(f"{indent}Building verification helper...")
+
+    try:
+        build_result = subprocess.run(
+            [str(SCRIPT_DIR / 'build-helper.sh')],
+            capture_output=True,
+            text=True,
+            cwd=str(PROJECT_ROOT)
+        )
+    except OSError as exc:
+        # Script missing or not executable: report it like any other build failure.
+        print(f"{indent}Build error: {exc}")
+        return False
+
+    if build_result.returncode != 0:
+        error_output = build_result.stderr.strip() or build_result.stdout.strip()
+        print(f"{indent}Build error: {error_output}")
+        return False
+
+    # Remember the success on the function object; no module-level state needed.
+    ensure_helper_built._built = True
+    return True
+
+
+def load_instruction_metadata() -> Dict[str, List[Dict]]:
+    """
+    Ask the helper module for its instruction tables.
+
+    Returns a dict with the keys 'cdna1', 'cdna2' and 'cdna3', each holding a
+    list, or an empty dict when the build, the helper run, the JSON parsing or
+    the shape check fails (a message is printed in each case).
+    """
+    if not ensure_helper_built():
+        return {}
+
+    request = json.dumps({'action': 'listInstructions'})
+    completed = subprocess.run(
+        ['node', str(HELPER_PATH), request],
+        capture_output=True,
+        text=True,
+        cwd=str(PROJECT_ROOT)
+    )
+    if completed.returncode != 0:
+        error_output = completed.stderr.strip() or completed.stdout.strip()
+        print(f"Failed to load instruction metadata: {error_output}")
+        return {}
+
+    try:
+        payload = json.loads(completed.stdout.strip() or '{}')
+    except json.JSONDecodeError as exc:
+        print(f"Failed to parse instruction metadata JSON: {exc}")
+        return {}
+
+    tables = {name: payload.get(name) for name in ('cdna1', 'cdna2', 'cdna3')}
+    if any(not isinstance(table, list) for table in tables.values()):
+        print("Instruction metadata missing expected CDNA1/CDNA2/CDNA3 arrays.")
+        return {}
+
+    return tables
+
+
+def parse_lane_entries(cell_text: str) -> List[Dict[str, Any]]:
+    """
+    Parse all (register, lane, bit range) entries from a calculator cell.
+
+    Handles formats like:
+    - "v0{0}" -> [{'register': 0, 'lane': 0, 'bit_range': ''}]
+    - "v0{0}.[15:0]" -> [{'register': 0, 'lane': 0, 'bit_range': '[15:0]'}]
+    - "v0{0}.[15:0]<br>v0{16}.[15:0]" -> one entry per <br>-separated part
+    - "v[1:0]{3}" -> one entry per register in the inclusive range
+
+    Parts that match neither form are ignored.
+
+    Returns: List of entry dictionaries with register, lane, and bit_range.
+    """
+    entries: List[Dict[str, Any]] = []
+
+    # Multi-lane cells separate their entries with <br>
+    for part in cell_text.split('<br>'):
+        part = part.strip()
+        if not part:
+            continue
+
+        # Register range form: v[hi:lo]{lane} with optional .[bits]
+        range_match = re.match(r'v\[(\d+):(\d+)\]\{(\d+)\}(?:\.\[([^\]]+)\])?', part)
+        if range_match:
+            first = int(range_match.group(1))
+            second = int(range_match.group(2))
+            registers = range(min(first, second), max(first, second) + 1)
+            lane = int(range_match.group(3))
+            raw_range = range_match.group(4)
+        else:
+            # Single register form: v{reg}{lane} with optional .[bits]
+            single_match = re.match(r'v(\d+)\{(\d+)\}(?:\.\[([^\]]+)\])?', part)
+            if not single_match:
+                continue
+            registers = [int(single_match.group(1))]
+            lane = int(single_match.group(2))
+            raw_range = single_match.group(3)
+
+        if raw_range:
+            bit_range = raw_range if raw_range.startswith('[') else f'[{raw_range}]'
+        else:
+            bit_range = ''
+
+        entries.extend(
+            {'register': register, 'lane': lane, 'bit_range': bit_range}
+            for register in registers
+        )
+
+    return entries
+
+
+def run_amd_calculator(
+ arch: str,
+ instruction: Dict,
+ operand: str
+) -> List[Dict[str, Any]]:
+ """
+ Run the AMD matrix calculator and parse lane/register/bitRange tuples.
+
+ An empty list is returned when the instruction has no name, the calculator
+ exits with an error, no table header is found, or a cell cannot be parsed.
+
+ Returns:
+ List of mappings with fields:
+ lane, register, bitRange, row, col
+ """
+ instruction_name = instruction.get('name')
+ if not instruction_name:
+ return []
+ # The calculator is invoked with a lower-case name carrying a "v_" prefix.
+ calc_instruction = instruction_name.lower()
+ if not calc_instruction.startswith('v_'):
+ calc_instruction = f"v_{calc_instruction}"
+
+ # NOTE(review): flag meanings (-R, -w 64, --markdown) are not visible here;
+ # confirm them against the calculator's own help output.
+ cmd = [
+ 'python3',
+ str(CALCULATOR_PATH),
+ '-a', arch.upper(),
+ '-i', calc_instruction,
+ f'-{operand}',
+ '-R',
+ '--markdown',
+ '-w', '64'
+ ]
+
+ result = subprocess.run(cmd, capture_output=True, text=True, cwd=str(PROJECT_ROOT))
+ if result.returncode != 0:
+ print(f" Calculator error: {result.stderr}")
+ return []
+
+ lines = result.stdout.strip().split('\n')
+ entries: List[Dict[str, Any]] = []
+ # Tracks every physical tuple already seen so duplicates can be rejected.
+ seen_keys: Dict[Tuple[int, int, str], Dict[str, Any]] = {}
+
+ # Find the table header
+ table_start = -1
+ table_header_pattern = re.compile(r'\[[MNK]\]\[[MNK]\]')
+ for i, line in enumerate(lines):
+ if '|' in line and table_header_pattern.search(line):
+ table_start = i
+ break
+
+ if table_start == -1:
+ return []
+
+ # Parse data rows (skip header and separator line)
+ data_lines = lines[table_start + 2:]
+
+ try:
+ for line in data_lines:
+ if '|' not in line or line.strip().startswith('|--'):
+ continue
+
+ parts = [p.strip() for p in line.split('|')]
+ if len(parts) < 3:
+ continue
+
+ row_str = parts[1].strip()
+ if not row_str or not row_str.lstrip('-').isdigit():
+ # Skip header rows such as "A[M][K]" that appear when the calculator prints
+ # multiple blocks in a single table.
+ continue
+ row = int(row_str)
+
+ # parts[0] and parts[-1] are the text outside the outer pipes and
+ # parts[1] is the row label, so data cells are parts[2:-1].
+ for col_idx, cell in enumerate(parts[2:-1]):
+ cell = cell.strip()
+ if not cell or 'v' not in cell:
+ continue
+
+ reg_lane_pairs = parse_lane_entries(cell)
+ if not reg_lane_pairs:
+ raise ValueError(
+ f"Unrecognized calculator cell '{cell}' at row {row}, column {col_idx}"
+ )
+
+ for parsed in reg_lane_pairs:
+ register = parsed.get('register')
+ lane = parsed.get('lane')
+ bit_range_raw = parsed.get('bit_range', '')
+
+ if register is None or lane is None:
+ raise ValueError(
+ f"Calculator cell missing register/lane info at row {row}, column {col_idx}: '{cell}'"
+ )
+
+ if not bit_range_raw:
+ # Older calculator builds omit bit ranges for 32-bit operands.
+ normalized_bit_range = FULL_REGISTER_BIT_RANGE
+ else:
+ normalized_bit_range = normalize_bit_range(bit_range_raw)
+
+ entry = {
+ 'lane': lane,
+ 'register': register,
+ 'bitRange': normalized_bit_range,
+ 'row': row,
+ 'col': col_idx
+ }
+ key = physical_key(entry)
+ if key in seen_keys:
+ existing = seen_keys[key]
+ raise ValueError(
+ "Duplicate physical storage tuple detected for "
+ f"lane {lane}, register v{register}, bitRange {normalized_bit_range}. "
+ f"Existing position {existing['row'], existing['col']}, "
+ f"new position {(row, col_idx)}"
+ )
+
+ seen_keys[key] = entry
+ entries.append(entry)
+
+ except ValueError as exc:
+ # Any parse problem discards the whole table rather than a partial one.
+ print(f" Calculator parse error: {exc}")
+ return []
+
+ return entries
+
+
+def get_our_layout(instruction: Dict, operand: str) -> List[Dict[str, Any]]:
+ """
+ Get lane mappings from our TypeScript implementation.
+
+ This uses the pre-built mfma verification helper module, which reuses
+ the layout logic from MFMALayout.ts without duplication.
+
+ An empty list is returned when the build or helper run fails, the output
+ is not valid JSON, or an entry is malformed or duplicated.
+
+ Returns:
+ List of mappings with fields:
+ lane, register, bitRange, row, col, elementIndex
+ """
+ # Make sure the helper bundle has been built before invoking it
+ if not ensure_helper_built(prefix=' '):
+ return []
+
+ # Prepare config for helper
+ config = {
+ 'instruction': instruction,
+ 'operand': operand
+ }
+
+ # Run the helper module
+ result = subprocess.run(
+ ['node', str(HELPER_PATH), json.dumps(config)],
+ capture_output=True,
+ text=True,
+ cwd=str(PROJECT_ROOT)
+ )
+
+ if result.returncode != 0:
+ print(f" Helper error: {result.stderr}")
+ return []
+
+ # Parse the JSON output
+ try:
+ data = json.loads(result.stdout.strip())
+ except json.JSONDecodeError as exc:
+ print(f" JSON parse error: {exc}")
+ print(f" Output: {result.stdout[:200]}")
+ return []
+
+ entries: List[Dict[str, Any]] = []
+ # Tracks every physical tuple already seen so duplicates can be rejected.
+ seen_keys: Dict[Tuple[int, int, str], Dict[str, Any]] = {}
+
+ try:
+ # Helper output is keyed by "<row>_<col>" with a list of entries per key.
+ for key, lane_entries in data.items():
+ row_str, col_str = key.split('_')
+ row = int(row_str)
+ col = int(col_str)
+
+ for lane_entry in lane_entries:
+ lane = lane_entry.get('lane')
+ register = lane_entry.get('physicalRegister')
+ bit_range_raw = lane_entry.get('bitRange', '')
+ element_index = lane_entry.get('elementIndex')
+
+ if lane is None or register is None:
+ raise ValueError(
+ f"Helper entry missing lane/register (row {row}, col {col}): {lane_entry}"
+ )
+
+ normalized_bit_range = normalize_bit_range(bit_range_raw)
+
+ entry = {
+ 'lane': int(lane),
+ 'register': int(register),
+ 'bitRange': normalized_bit_range,
+ 'row': row,
+ 'col': col,
+ 'elementIndex': int(element_index) if element_index is not None else None
+ }
+
+ key_tuple = physical_key(entry)
+ if key_tuple in seen_keys:
+ existing = seen_keys[key_tuple]
+ raise ValueError(
+ "Implementation produced duplicate physical storage tuple for "
+ f"lane {lane}, register v{register}, bitRange {normalized_bit_range}. "
+ f"Existing position {existing['row'], existing['col']}, "
+ f"new position {(row, col)}"
+ )
+
+ seen_keys[key_tuple] = entry
+ entries.append(entry)
+
+ except (ValueError, TypeError) as exc:
+ # Any malformed entry discards the whole layout rather than a partial one.
+ print(f" Helper parse error: {exc}")
+ return []
+
+ return entries
+
+
+def compare_layouts(
+    calc_entries: List[Dict[str, Any]],
+    our_entries: List[Dict[str, Any]]
+) -> Tuple[bool, List[Dict[str, Any]]]:
+    """
+    Diff two layouts by physical storage tuple (lane, register, bitRange).
+
+    Returns (is_match, mismatches). Each mismatch is tagged 'missing' (only the
+    calculator has the tuple), 'extra' (only our layout has it) or 'mismatch'
+    (both have it at different matrix positions). Mismatches are listed in
+    sorted tuple order.
+    """
+    expected = {physical_key(entry): entry for entry in calc_entries}
+    actual = {physical_key(entry): entry for entry in our_entries}
+
+    differences: List[Dict[str, Any]] = []
+    for storage in sorted(expected.keys() | actual.keys()):
+        lane, register, bit_range = storage
+        reference = expected.get(storage)
+        ours = actual.get(storage)
+
+        # 'status' is inserted first so the key order of each record is stable.
+        record: Dict[str, Any] = {
+            'status': '',
+            'lane': lane,
+            'register': register,
+            'bitRange': bit_range,
+        }
+
+        if reference and not ours:
+            record['status'] = 'missing'
+            record['calculator'] = [reference['row'], reference['col']]
+        elif ours and not reference:
+            record['status'] = 'extra'
+            record['ours'] = [ours['row'], ours['col']]
+        elif not reference or not ours:
+            # Not expected: every key comes from one of the two maps.
+            continue
+        else:
+            reference_position = (reference['row'], reference['col'])
+            our_position = (ours['row'], ours['col'])
+            if reference_position == our_position:
+                continue
+            record['status'] = 'mismatch'
+            record['calculator'] = list(reference_position)
+            record['ours'] = list(our_position)
+
+        differences.append(record)
+
+    return len(differences) == 0, differences
+
+
+def build_layout_index(entries: List[Dict[str, Any]]) -> Dict[str, List[Dict[str, Any]]]:
+    """
+    Bucket entries under "<row>_<col>" keys, keeping only lane, register,
+    bitRange and elementIndex, and sort each bucket by lane, register and bit
+    range.
+    """
+    def storage_order(item: Dict[str, Any]):
+        return item['lane'], item['register'], bit_range_sort_key(item['bitRange'])
+
+    index: Dict[str, List[Dict[str, Any]]] = {}
+    for entry in entries:
+        position = f"{entry['row']}_{entry['col']}"
+        bucket = index.setdefault(position, [])
+        bucket.append({
+            'lane': entry['lane'],
+            'register': entry['register'],
+            'bitRange': entry['bitRange'],
+            'elementIndex': entry.get('elementIndex')
+        })
+
+    for bucket in index.values():
+        bucket.sort(key=storage_order)
+
+    return index
+
+
+def perform_logical_sanity_checks(entries: List[Dict[str, Any]]) -> Dict[str, Any]:
+    """
+    Check, lane by lane, that logical element indices form 0..n-1 and follow
+    the physical (register, bit range) ordering.
+
+    Entries without an elementIndex are ignored. Returns a report with
+    'status' ('PASS' or 'WARN'), 'issueCount' and at most ten 'issues'.
+    """
+    lanes: Dict[int, List[Dict[str, Any]]] = {}
+    for entry in entries:
+        if entry.get('elementIndex') is not None:
+            lanes.setdefault(entry['lane'], []).append(entry)
+
+    problems: List[str] = []
+    for lane, members in lanes.items():
+        logical_indices = sorted(member['elementIndex'] for member in members)
+        expected = list(range(len(logical_indices)))
+        if logical_indices != expected:
+            problems.append(
+                f"Lane {lane}: expected logical indices {expected} but observed {logical_indices}"
+            )
+            continue  # The ordering check is meaningless once the range is wrong
+
+        ordered = sorted(
+            members,
+            key=lambda item: (item['register'], bit_range_sort_key(item['bitRange']))
+        )
+        # Report only the first out-of-order entry for each lane.
+        first_bad = next(
+            ((slot, member) for slot, member in enumerate(ordered)
+             if member.get('elementIndex') != slot),
+            None
+        )
+        if first_bad is not None:
+            slot, member = first_bad
+            problems.append(
+                f"Lane {lane}: logical index {member.get('elementIndex')} does not match physical ordering "
+                f"(expected {slot}) for v{member['register']} {member['bitRange']}"
+            )
+
+    return {
+        'status': 'WARN' if problems else 'PASS',
+        'issueCount': len(problems),
+        'issues': problems[:10]  # limit noise
+    }
+
+
+def verify_single_case(
+ arch: str,
+ instr: Dict,
+ operand: str,
+ case_id: int,
+ total: int,
+ verbose: bool,
+ max_diffs: int
+) -> Dict:
+ """
+ Verify a single test case and print a one-line progress report.
+
+ Args:
+ arch: Architecture name passed to the calculator.
+ instr: Instruction metadata dict; its 'name' is used for reporting.
+ operand: Operand to verify.
+ case_id: Position of this case in the run, used in the progress prefix.
+ total: Total number of cases, used in the progress prefix.
+ verbose: When True, print per-mismatch details for failing cases.
+ max_diffs: Cap on mismatches printed and stored in the result.
+
+ Returns:
+ Result dict whose 'status' is 'PASS', 'FAIL' or 'ERROR'.
+ """
+ print(f"[{case_id:03d}/{total:03d}] {arch.upper():5s} {instr['name']:30s} {operand} ... ", end='', flush=True)
+
+ # Get ground truth from AMD calculator
+ calc_entries = run_amd_calculator(arch, instr, operand)
+ if not calc_entries:
+ print("FAIL (calculator error)")
+ return {
+ 'id': case_id,
+ 'arch': arch,
+ 'instruction': instr['name'],
+ 'operand': operand,
+ 'status': 'ERROR',
+ 'error': 'Failed to run AMD calculator or parse output'
+ }
+
+ # Get our layout using the verification helper (reuses MFMALayout.ts)
+ our_entries = get_our_layout(instr, operand)
+ if not our_entries:
+ print("FAIL (implementation error)")
+ return {
+ 'id': case_id,
+ 'arch': arch,
+ 'instruction': instr['name'],
+ 'operand': operand,
+ 'status': 'ERROR',
+ 'error': 'Failed to generate our layout'
+ }
+
+ # Compare layouts
+ is_match, mismatches = compare_layouts(calc_entries, our_entries)
+ # The logical report is attached to the result but never changes PASS/FAIL.
+ logical_report = perform_logical_sanity_checks(our_entries)
+
+ if is_match:
+ total_mappings = len(calc_entries)
+ unique_positions = len({(entry['row'], entry['col']) for entry in calc_entries})
+ suffix = ''
+ if logical_report['status'] != 'PASS':
+ suffix = f", logical {logical_report['status'].lower()}"
+ print(f"PASS ({total_mappings} mappings verified{suffix})")
+ return {
+ 'id': case_id,
+ 'arch': arch,
+ 'instruction': instr['name'],
+ 'operand': operand,
+ 'status': 'PASS',
+ 'total_positions': unique_positions,
+ 'total_mappings': total_mappings,
+ 'physicalLayout': build_layout_index(our_entries),
+ 'logicalSanity': logical_report
+ }
+ else:
+ mismatch_count = len(mismatches)
+ total_missing = sum(1 for m in mismatches if m['status'] == 'missing')
+ total_extra = sum(1 for m in mismatches if m['status'] == 'extra')
+ print(f"FAIL ({mismatch_count} mapping mismatches)")
+
+ # Whatever is neither missing nor extra sits at a different position.
+ total_relocated = mismatch_count - total_missing - total_extra
+
+ if verbose:
+ limit = min(max_diffs, mismatch_count)
+ for idx, mismatch in enumerate(mismatches[:limit], 1):
+ lane = mismatch['lane']
+ register = mismatch['register']
+ bit_range = mismatch['bitRange'] or '[full]'
+ calc_pos = mismatch.get('calculator')
+ our_pos = mismatch.get('ours')
+ status = mismatch['status']
+ calc_display = f"({calc_pos[0]}, {calc_pos[1]})" if calc_pos else '∅'
+ our_display = f"({our_pos[0]}, {our_pos[1]})" if our_pos else '∅'
+ print(f" [{idx:02d}] lane {lane:2d}, v{register:02d} {bit_range}")
+ print(f" calculator: {calc_display}")
+ print(f" ours: {our_display}")
+ if status == 'missing':
+ print(" issue: missing in our layout")
+ elif status == 'extra':
+ print(" issue: extra in our layout")
+ else:
+ print(" issue: mismatched matrix position")
+ if mismatch_count > limit:
+ remaining = mismatch_count - limit
+ print(f" ... {remaining} additional mismatches omitted (increase --max-diffs to show more)")
+ print(f" totals -> missing: {total_missing}, extra: {total_extra}, mismatched: {total_relocated}")
+
+ return {
+ 'id': case_id,
+ 'arch': arch,
+ 'instruction': instr['name'],
+ 'operand': operand,
+ 'status': 'FAIL',
+ 'total_positions': len({(entry['row'], entry['col']) for entry in calc_entries}),
+ 'mismatch_count': mismatch_count,
+ 'missing_entries': total_missing,
+ 'extra_entries': total_extra,
+ 'mismatched_entries': total_relocated,
+ 'mismatches': mismatches[:max_diffs],
+ 'logicalSanity': logical_report
+ }
+
+
+def main():
+ """
+ Run the verification for every case selected by the command-line filters,
+ write the results to results/mfma_verification_results.json and print a
+ summary.
+
+ Returns 0 when every case passes; 1 when any case fails or errors, the
+ metadata cannot be loaded, or no case matches the filters.
+ """
+ args = parse_args()
+
+ print("=" * 100)
+ print("MFMA LAYOUT VERIFICATION")
+ print("=" * 100)
+ print("Ground truth: AMD Matrix Instruction Calculator")
+ print("Implementation: MFMALayout.ts (via mfma-verification-helper.mjs)")
+ print("=" * 100)
+ print()
+
+ # --instruction may be repeated and each value may be comma-separated;
+ # names are compared case-insensitively.
+ instruction_filter = None
+ if args.instruction:
+ filtered: List[str] = []
+ for value in args.instruction:
+ if not value:
+ continue
+ filtered.extend([item.strip().lower() for item in value.split(',') if item.strip()])
+ instruction_filter = set(filtered)
+
+ def include_case(arch_name: str, instruction: Dict, operand_name: str) -> bool:
+ if args.arch != 'all' and arch_name != args.arch:
+ return False
+ if args.operand != 'all' and operand_name != args.operand:
+ return False
+ if instruction_filter and instruction['name'].lower() not in instruction_filter:
+ return False
+ return True
+
+ metadata = load_instruction_metadata()
+ if not metadata:
+ print("Unable to load MFMA instruction metadata from verification helper.")
+ return 1
+
+ arch_order = ['cdna1', 'cdna2', 'cdna3']
+ test_cases = []
+ for arch_name in arch_order:
+ instructions = metadata.get(arch_name, [])
+ if not isinstance(instructions, list):
+ continue
+ for instr in instructions:
+ for operand in OPERANDS:
+ if include_case(arch_name, instr, operand):
+ test_cases.append((arch_name, instr, operand))
+
+ if not test_cases:
+ print("No test cases matched the provided filters.")
+ return 1
+
+ # Record at least one mismatch per failing case, whatever value was passed.
+ max_diffs = max(1, args.max_diffs)
+
+ # Run verification
+ results = []
+ for i, (arch, instr, operand) in enumerate(test_cases, 1):
+ result = verify_single_case(
+ arch,
+ instr,
+ operand,
+ i,
+ len(test_cases),
+ verbose=args.verbose,
+ max_diffs=max_diffs
+ )
+ results.append(result)
+
+ # Save results
+ results_dir = SCRIPT_DIR / 'results'
+ results_dir.mkdir(exist_ok=True)
+ results_file = results_dir / 'mfma_verification_results.json'
+ with open(results_file, 'w') as f:
+ json.dump(results, f, indent=2)
+
+ # Print summary
+ print()
+ print("=" * 100)
+ print("SUMMARY")
+ print("=" * 100)
+ passed = sum(1 for r in results if r['status'] == 'PASS')
+ failed = sum(1 for r in results if r['status'] == 'FAIL')
+ errors = sum(1 for r in results if r['status'] == 'ERROR')
+
+ # len(results) is non-zero here: an empty test_cases list returned above.
+ print(f"Total: {len(results)}")
+ print(f"PASS: {passed:3d} ({100 * passed // len(results):2d}%)")
+ print(f"FAIL: {failed:3d} ({100 * failed // len(results):2d}%)")
+ print(f"ERROR: {errors:3d} ({100 * errors // len(results):2d}%)")
+
+ if failed > 0 or errors > 0:
+ print("\nFailed/Error cases:")
+ for r in results:
+ if r['status'] in ['FAIL', 'ERROR']:
+ status_str = r['status']
+ error_detail = r.get('error', f"{r.get('mismatch_count', 0)} mismatches")
+ print(f" [{status_str:5s}] {r['arch'].upper():5s} {r['instruction']:30s} {r['operand']} - {error_detail}")
+
+ print(f"\nDetailed results: {results_file}")
+
+ return 0 if failed == 0 and errors == 0 else 1
+
+
+if __name__ == '__main__':
+ sys.exit(main())