diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 74411c0..769351e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -38,6 +38,9 @@ jobs: - name: Verify WMMA layouts run: python verification/verify_wmma_layouts.py --max-diffs 5 + - name: Verify MFMA layouts + run: python verification/verify_mfma_layouts.py --arch all --max-diffs 5 + - name: Run tests run: npm test -- --run diff --git a/.gitignore b/.gitignore index 879d08c..d5528b1 100644 --- a/.gitignore +++ b/.gitignore @@ -48,7 +48,7 @@ vite.config.ts.timestamp-* # Verification artifacts verification/results/*.json verification/__pycache__/ -verification/lib/verification-helper.mjs +verification/lib/*.mjs # Project-specific ignores CLAUDE.md diff --git a/index.html b/index.html index 4e15dce..a8361e1 100644 --- a/index.html +++ b/index.html @@ -17,6 +17,7 @@

GPU Tensor Layout Visualizer

+ @@ -172,6 +173,64 @@

Operand

+
+
+ + +
+ +
+
+
+
diff --git a/src/data/mfma-instruction-data.ts b/src/data/mfma-instruction-data.ts new file mode 100644 index 0000000..bdc8e4c --- /dev/null +++ b/src/data/mfma-instruction-data.ts @@ -0,0 +1,976 @@ +// Auto-generated from amd_matrix_instruction_calculator +// Do not edit manually unless you intentionally update the source data. + +export interface MFMAInstructionEntry { + name: string + mnemonic: string + m: number + n: number + k: number + kBase: number + inputTypeA: string + inputTypeB: string + outputType: string + laneGroupsShareDimK?: boolean + registerBlocks?: number +} + +export const MFMA_INSTRUCTION_DATA: Record<"cdna1" | "cdna2" | "cdna3", MFMAInstructionEntry[]> = { + cdna1: [ + { + name: 'v_mfma_f32_16x16x16f16', + mnemonic: 'f32_16x16x16f16', + m: 16, + n: 16, + k: 16, + kBase: 4, + inputTypeA: 'fp16', + inputTypeB: 'fp16', + outputType: 'fp32', + }, + { + name: 'v_mfma_f32_16x16x1f32', + mnemonic: 'f32_16x16x1f32', + m: 16, + n: 16, + k: 1, + kBase: 1, + inputTypeA: 'fp32', + inputTypeB: 'fp32', + outputType: 'fp32', + laneGroupsShareDimK: true, + registerBlocks: 4, + }, + { + name: 'v_mfma_f32_16x16x2bf16', + mnemonic: 'f32_16x16x2bf16', + m: 16, + n: 16, + k: 2, + kBase: 2, + inputTypeA: 'bf16', + inputTypeB: 'bf16', + outputType: 'fp32', + laneGroupsShareDimK: true, + registerBlocks: 4, + }, + { + name: 'v_mfma_f32_16x16x4f16', + mnemonic: 'f32_16x16x4f16', + m: 16, + n: 16, + k: 4, + kBase: 4, + inputTypeA: 'fp16', + inputTypeB: 'fp16', + outputType: 'fp32', + laneGroupsShareDimK: true, + registerBlocks: 4, + }, + { + name: 'v_mfma_f32_16x16x4f32', + mnemonic: 'f32_16x16x4f32', + m: 16, + n: 16, + k: 4, + kBase: 1, + inputTypeA: 'fp32', + inputTypeB: 'fp32', + outputType: 'fp32', + }, + { + name: 'v_mfma_f32_16x16x8bf16', + mnemonic: 'f32_16x16x8bf16', + m: 16, + n: 16, + k: 8, + kBase: 2, + inputTypeA: 'bf16', + inputTypeB: 'bf16', + outputType: 'fp32', + }, + { + name: 'v_mfma_f32_32x32x1f32', + mnemonic: 'f32_32x32x1f32', + m: 32, + n: 32, + 
k: 1, + kBase: 1, + inputTypeA: 'fp32', + inputTypeB: 'fp32', + outputType: 'fp32', + laneGroupsShareDimK: true, + registerBlocks: 2, + }, + { + name: 'v_mfma_f32_32x32x2bf16', + mnemonic: 'f32_32x32x2bf16', + m: 32, + n: 32, + k: 2, + kBase: 2, + inputTypeA: 'bf16', + inputTypeB: 'bf16', + outputType: 'fp32', + laneGroupsShareDimK: true, + registerBlocks: 2, + }, + { + name: 'v_mfma_f32_32x32x2f32', + mnemonic: 'f32_32x32x2f32', + m: 32, + n: 32, + k: 2, + kBase: 1, + inputTypeA: 'fp32', + inputTypeB: 'fp32', + outputType: 'fp32', + }, + { + name: 'v_mfma_f32_32x32x4bf16', + mnemonic: 'f32_32x32x4bf16', + m: 32, + n: 32, + k: 4, + kBase: 2, + inputTypeA: 'bf16', + inputTypeB: 'bf16', + outputType: 'fp32', + }, + { + name: 'v_mfma_f32_32x32x4f16', + mnemonic: 'f32_32x32x4f16', + m: 32, + n: 32, + k: 4, + kBase: 4, + inputTypeA: 'fp16', + inputTypeB: 'fp16', + outputType: 'fp32', + laneGroupsShareDimK: true, + registerBlocks: 2, + }, + { + name: 'v_mfma_f32_32x32x8f16', + mnemonic: 'f32_32x32x8f16', + m: 32, + n: 32, + k: 8, + kBase: 4, + inputTypeA: 'fp16', + inputTypeB: 'fp16', + outputType: 'fp32', + }, + { + name: 'v_mfma_f32_4x4x1f32', + mnemonic: 'f32_4x4x1f32', + m: 4, + n: 4, + k: 1, + kBase: 1, + inputTypeA: 'fp32', + inputTypeB: 'fp32', + outputType: 'fp32', + laneGroupsShareDimK: true, + registerBlocks: 16, + }, + { + name: 'v_mfma_f32_4x4x2bf16', + mnemonic: 'f32_4x4x2bf16', + m: 4, + n: 4, + k: 2, + kBase: 2, + inputTypeA: 'bf16', + inputTypeB: 'bf16', + outputType: 'fp32', + laneGroupsShareDimK: true, + registerBlocks: 16, + }, + { + name: 'v_mfma_f32_4x4x4f16', + mnemonic: 'f32_4x4x4f16', + m: 4, + n: 4, + k: 4, + kBase: 4, + inputTypeA: 'fp16', + inputTypeB: 'fp16', + outputType: 'fp32', + laneGroupsShareDimK: true, + registerBlocks: 16, + }, + { + name: 'v_mfma_i32_16x16x16i8', + mnemonic: 'i32_16x16x16i8', + m: 16, + n: 16, + k: 16, + kBase: 4, + inputTypeA: 'int8', + inputTypeB: 'int8', + outputType: 'int32', + }, + { + name: 
'v_mfma_i32_16x16x4i8', + mnemonic: 'i32_16x16x4i8', + m: 16, + n: 16, + k: 4, + kBase: 4, + inputTypeA: 'int8', + inputTypeB: 'int8', + outputType: 'int32', + laneGroupsShareDimK: true, + registerBlocks: 4, + }, + { + name: 'v_mfma_i32_32x32x4i8', + mnemonic: 'i32_32x32x4i8', + m: 32, + n: 32, + k: 4, + kBase: 4, + inputTypeA: 'int8', + inputTypeB: 'int8', + outputType: 'int32', + laneGroupsShareDimK: true, + registerBlocks: 2, + }, + { + name: 'v_mfma_i32_32x32x8i8', + mnemonic: 'i32_32x32x8i8', + m: 32, + n: 32, + k: 8, + kBase: 4, + inputTypeA: 'int8', + inputTypeB: 'int8', + outputType: 'int32', + }, + { + name: 'v_mfma_i32_4x4x4i8', + mnemonic: 'i32_4x4x4i8', + m: 4, + n: 4, + k: 4, + kBase: 4, + inputTypeA: 'int8', + inputTypeB: 'int8', + outputType: 'int32', + laneGroupsShareDimK: true, + registerBlocks: 16, + }, + ], + cdna2: [ + { + name: 'v_mfma_f32_16x16x16bf16_1k', + mnemonic: 'f32_16x16x16bf16_1k', + m: 16, + n: 16, + k: 16, + kBase: 4, + inputTypeA: 'bf16', + inputTypeB: 'bf16', + outputType: 'fp32', + }, + { + name: 'v_mfma_f32_16x16x16f16', + mnemonic: 'f32_16x16x16f16', + m: 16, + n: 16, + k: 16, + kBase: 4, + inputTypeA: 'fp16', + inputTypeB: 'fp16', + outputType: 'fp32', + }, + { + name: 'v_mfma_f32_16x16x1f32', + mnemonic: 'f32_16x16x1f32', + m: 16, + n: 16, + k: 1, + kBase: 1, + inputTypeA: 'fp32', + inputTypeB: 'fp32', + outputType: 'fp32', + laneGroupsShareDimK: true, + registerBlocks: 4, + }, + { + name: 'v_mfma_f32_16x16x2bf16', + mnemonic: 'f32_16x16x2bf16', + m: 16, + n: 16, + k: 2, + kBase: 2, + inputTypeA: 'bf16', + inputTypeB: 'bf16', + outputType: 'fp32', + laneGroupsShareDimK: true, + registerBlocks: 4, + }, + { + name: 'v_mfma_f32_16x16x4bf16_1k', + mnemonic: 'f32_16x16x4bf16_1k', + m: 16, + n: 16, + k: 4, + kBase: 4, + inputTypeA: 'bf16', + inputTypeB: 'bf16', + outputType: 'fp32', + laneGroupsShareDimK: true, + registerBlocks: 4, + }, + { + name: 'v_mfma_f32_16x16x4f16', + mnemonic: 'f32_16x16x4f16', + m: 16, + n: 16, + k: 4, + 
kBase: 4, + inputTypeA: 'fp16', + inputTypeB: 'fp16', + outputType: 'fp32', + laneGroupsShareDimK: true, + registerBlocks: 4, + }, + { + name: 'v_mfma_f32_16x16x4f32', + mnemonic: 'f32_16x16x4f32', + m: 16, + n: 16, + k: 4, + kBase: 1, + inputTypeA: 'fp32', + inputTypeB: 'fp32', + outputType: 'fp32', + }, + { + name: 'v_mfma_f32_16x16x8bf16', + mnemonic: 'f32_16x16x8bf16', + m: 16, + n: 16, + k: 8, + kBase: 2, + inputTypeA: 'bf16', + inputTypeB: 'bf16', + outputType: 'fp32', + }, + { + name: 'v_mfma_f32_32x32x1f32', + mnemonic: 'f32_32x32x1f32', + m: 32, + n: 32, + k: 1, + kBase: 1, + inputTypeA: 'fp32', + inputTypeB: 'fp32', + outputType: 'fp32', + laneGroupsShareDimK: true, + registerBlocks: 2, + }, + { + name: 'v_mfma_f32_32x32x2bf16', + mnemonic: 'f32_32x32x2bf16', + m: 32, + n: 32, + k: 2, + kBase: 2, + inputTypeA: 'bf16', + inputTypeB: 'bf16', + outputType: 'fp32', + laneGroupsShareDimK: true, + registerBlocks: 2, + }, + { + name: 'v_mfma_f32_32x32x2f32', + mnemonic: 'f32_32x32x2f32', + m: 32, + n: 32, + k: 2, + kBase: 1, + inputTypeA: 'fp32', + inputTypeB: 'fp32', + outputType: 'fp32', + }, + { + name: 'v_mfma_f32_32x32x4bf16', + mnemonic: 'f32_32x32x4bf16', + m: 32, + n: 32, + k: 4, + kBase: 2, + inputTypeA: 'bf16', + inputTypeB: 'bf16', + outputType: 'fp32', + }, + { + name: 'v_mfma_f32_32x32x4bf16_1k', + mnemonic: 'f32_32x32x4bf16_1k', + m: 32, + n: 32, + k: 4, + kBase: 4, + inputTypeA: 'bf16', + inputTypeB: 'bf16', + outputType: 'fp32', + laneGroupsShareDimK: true, + registerBlocks: 2, + }, + { + name: 'v_mfma_f32_32x32x4f16', + mnemonic: 'f32_32x32x4f16', + m: 32, + n: 32, + k: 4, + kBase: 4, + inputTypeA: 'fp16', + inputTypeB: 'fp16', + outputType: 'fp32', + laneGroupsShareDimK: true, + registerBlocks: 2, + }, + { + name: 'v_mfma_f32_32x32x8bf16_1k', + mnemonic: 'f32_32x32x8bf16_1k', + m: 32, + n: 32, + k: 8, + kBase: 4, + inputTypeA: 'bf16', + inputTypeB: 'bf16', + outputType: 'fp32', + }, + { + name: 'v_mfma_f32_32x32x8f16', + mnemonic: 
'f32_32x32x8f16', + m: 32, + n: 32, + k: 8, + kBase: 4, + inputTypeA: 'fp16', + inputTypeB: 'fp16', + outputType: 'fp32', + }, + { + name: 'v_mfma_f32_4x4x1f32', + mnemonic: 'f32_4x4x1f32', + m: 4, + n: 4, + k: 1, + kBase: 1, + inputTypeA: 'fp32', + inputTypeB: 'fp32', + outputType: 'fp32', + laneGroupsShareDimK: true, + registerBlocks: 16, + }, + { + name: 'v_mfma_f32_4x4x2bf16', + mnemonic: 'f32_4x4x2bf16', + m: 4, + n: 4, + k: 2, + kBase: 2, + inputTypeA: 'bf16', + inputTypeB: 'bf16', + outputType: 'fp32', + laneGroupsShareDimK: true, + registerBlocks: 16, + }, + { + name: 'v_mfma_f32_4x4x4bf16_1k', + mnemonic: 'f32_4x4x4bf16_1k', + m: 4, + n: 4, + k: 4, + kBase: 4, + inputTypeA: 'bf16', + inputTypeB: 'bf16', + outputType: 'fp32', + laneGroupsShareDimK: true, + registerBlocks: 16, + }, + { + name: 'v_mfma_f32_4x4x4f16', + mnemonic: 'f32_4x4x4f16', + m: 4, + n: 4, + k: 4, + kBase: 4, + inputTypeA: 'fp16', + inputTypeB: 'fp16', + outputType: 'fp32', + laneGroupsShareDimK: true, + registerBlocks: 16, + }, + { + name: 'v_mfma_f64_16x16x4f64', + mnemonic: 'f64_16x16x4f64', + m: 16, + n: 16, + k: 4, + kBase: 1, + inputTypeA: 'fp64', + inputTypeB: 'fp64', + outputType: 'fp64', + }, + { + name: 'v_mfma_f64_4x4x4f64', + mnemonic: 'f64_4x4x4f64', + m: 4, + n: 4, + k: 4, + kBase: 1, + inputTypeA: 'fp64', + inputTypeB: 'fp64', + outputType: 'fp64', + laneGroupsShareDimK: true, + registerBlocks: 4, + }, + { + name: 'v_mfma_i32_16x16x16i8', + mnemonic: 'i32_16x16x16i8', + m: 16, + n: 16, + k: 16, + kBase: 4, + inputTypeA: 'int8', + inputTypeB: 'int8', + outputType: 'int32', + }, + { + name: 'v_mfma_i32_16x16x4i8', + mnemonic: 'i32_16x16x4i8', + m: 16, + n: 16, + k: 4, + kBase: 4, + inputTypeA: 'int8', + inputTypeB: 'int8', + outputType: 'int32', + laneGroupsShareDimK: true, + registerBlocks: 4, + }, + { + name: 'v_mfma_i32_32x32x4i8', + mnemonic: 'i32_32x32x4i8', + m: 32, + n: 32, + k: 4, + kBase: 4, + inputTypeA: 'int8', + inputTypeB: 'int8', + outputType: 'int32', + 
laneGroupsShareDimK: true, + registerBlocks: 2, + }, + { + name: 'v_mfma_i32_32x32x8i8', + mnemonic: 'i32_32x32x8i8', + m: 32, + n: 32, + k: 8, + kBase: 4, + inputTypeA: 'int8', + inputTypeB: 'int8', + outputType: 'int32', + }, + { + name: 'v_mfma_i32_4x4x4i8', + mnemonic: 'i32_4x4x4i8', + m: 4, + n: 4, + k: 4, + kBase: 4, + inputTypeA: 'int8', + inputTypeB: 'int8', + outputType: 'int32', + laneGroupsShareDimK: true, + registerBlocks: 16, + }, + ], + cdna3: [ + { + name: 'v_mfma_f32_16x16x16_bf16', + mnemonic: 'f32_16x16x16_bf16', + m: 16, + n: 16, + k: 16, + kBase: 4, + inputTypeA: 'bf16', + inputTypeB: 'bf16', + outputType: 'fp32', + }, + { + name: 'v_mfma_f32_16x16x16_f16', + mnemonic: 'f32_16x16x16_f16', + m: 16, + n: 16, + k: 16, + kBase: 4, + inputTypeA: 'fp16', + inputTypeB: 'fp16', + outputType: 'fp32', + }, + { + name: 'v_mfma_f32_16x16x1_4b_f32', + mnemonic: 'f32_16x16x1_4b_f32', + m: 16, + n: 16, + k: 1, + kBase: 1, + inputTypeA: 'fp32', + inputTypeB: 'fp32', + outputType: 'fp32', + laneGroupsShareDimK: true, + registerBlocks: 4, + }, + { + name: 'v_mfma_f32_16x16x32_bf8_bf8', + mnemonic: 'f32_16x16x32_bf8_bf8', + m: 16, + n: 16, + k: 32, + kBase: 8, + inputTypeA: 'bf8', + inputTypeB: 'bf8', + outputType: 'fp32', + }, + { + name: 'v_mfma_f32_16x16x32_bf8_fp8', + mnemonic: 'f32_16x16x32_bf8_fp8', + m: 16, + n: 16, + k: 32, + kBase: 8, + inputTypeA: 'bf8', + inputTypeB: 'fp8', + outputType: 'fp32', + }, + { + name: 'v_mfma_f32_16x16x32_fp8_bf8', + mnemonic: 'f32_16x16x32_fp8_bf8', + m: 16, + n: 16, + k: 32, + kBase: 8, + inputTypeA: 'fp8', + inputTypeB: 'bf8', + outputType: 'fp32', + }, + { + name: 'v_mfma_f32_16x16x32_fp8_fp8', + mnemonic: 'f32_16x16x32_fp8_fp8', + m: 16, + n: 16, + k: 32, + kBase: 8, + inputTypeA: 'fp8', + inputTypeB: 'fp8', + outputType: 'fp32', + }, + { + name: 'v_mfma_f32_16x16x4_4b_bf16', + mnemonic: 'f32_16x16x4_4b_bf16', + m: 16, + n: 16, + k: 4, + kBase: 4, + inputTypeA: 'bf16', + inputTypeB: 'bf16', + outputType: 'fp32', + 
laneGroupsShareDimK: true, + registerBlocks: 4, + }, + { + name: 'v_mfma_f32_16x16x4_4b_f16', + mnemonic: 'f32_16x16x4_4b_f16', + m: 16, + n: 16, + k: 4, + kBase: 4, + inputTypeA: 'fp16', + inputTypeB: 'fp16', + outputType: 'fp32', + laneGroupsShareDimK: true, + registerBlocks: 4, + }, + { + name: 'v_mfma_f32_16x16x4_f32', + mnemonic: 'f32_16x16x4_f32', + m: 16, + n: 16, + k: 4, + kBase: 1, + inputTypeA: 'fp32', + inputTypeB: 'fp32', + outputType: 'fp32', + }, + { + name: 'v_mfma_f32_16x16x8_xf32', + mnemonic: 'f32_16x16x8_xf32', + m: 16, + n: 16, + k: 8, + kBase: 2, + inputTypeA: 'tf32', + inputTypeB: 'tf32', + outputType: 'fp32', + }, + { + name: 'v_mfma_f32_32x32x16_bf8_bf8', + mnemonic: 'f32_32x32x16_bf8_bf8', + m: 32, + n: 32, + k: 16, + kBase: 8, + inputTypeA: 'bf8', + inputTypeB: 'bf8', + outputType: 'fp32', + }, + { + name: 'v_mfma_f32_32x32x16_bf8_fp8', + mnemonic: 'f32_32x32x16_bf8_fp8', + m: 32, + n: 32, + k: 16, + kBase: 8, + inputTypeA: 'bf8', + inputTypeB: 'fp8', + outputType: 'fp32', + }, + { + name: 'v_mfma_f32_32x32x16_fp8_bf8', + mnemonic: 'f32_32x32x16_fp8_bf8', + m: 32, + n: 32, + k: 16, + kBase: 8, + inputTypeA: 'fp8', + inputTypeB: 'bf8', + outputType: 'fp32', + }, + { + name: 'v_mfma_f32_32x32x16_fp8_fp8', + mnemonic: 'f32_32x32x16_fp8_fp8', + m: 32, + n: 32, + k: 16, + kBase: 8, + inputTypeA: 'fp8', + inputTypeB: 'fp8', + outputType: 'fp32', + }, + { + name: 'v_mfma_f32_32x32x1_2b_f32', + mnemonic: 'f32_32x32x1_2b_f32', + m: 32, + n: 32, + k: 1, + kBase: 1, + inputTypeA: 'fp32', + inputTypeB: 'fp32', + outputType: 'fp32', + laneGroupsShareDimK: true, + registerBlocks: 2, + }, + { + name: 'v_mfma_f32_32x32x2_f32', + mnemonic: 'f32_32x32x2_f32', + m: 32, + n: 32, + k: 2, + kBase: 1, + inputTypeA: 'fp32', + inputTypeB: 'fp32', + outputType: 'fp32', + }, + { + name: 'v_mfma_f32_32x32x4_2b_bf16', + mnemonic: 'f32_32x32x4_2b_bf16', + m: 32, + n: 32, + k: 4, + kBase: 4, + inputTypeA: 'bf16', + inputTypeB: 'bf16', + outputType: 'fp32', + 
laneGroupsShareDimK: true, + registerBlocks: 2, + }, + { + name: 'v_mfma_f32_32x32x4_2b_f16', + mnemonic: 'f32_32x32x4_2b_f16', + m: 32, + n: 32, + k: 4, + kBase: 4, + inputTypeA: 'fp16', + inputTypeB: 'fp16', + outputType: 'fp32', + laneGroupsShareDimK: true, + registerBlocks: 2, + }, + { + name: 'v_mfma_f32_32x32x4_xf32', + mnemonic: 'f32_32x32x4_xf32', + m: 32, + n: 32, + k: 4, + kBase: 2, + inputTypeA: 'tf32', + inputTypeB: 'tf32', + outputType: 'fp32', + }, + { + name: 'v_mfma_f32_32x32x8_bf16', + mnemonic: 'f32_32x32x8_bf16', + m: 32, + n: 32, + k: 8, + kBase: 4, + inputTypeA: 'bf16', + inputTypeB: 'bf16', + outputType: 'fp32', + }, + { + name: 'v_mfma_f32_32x32x8_f16', + mnemonic: 'f32_32x32x8_f16', + m: 32, + n: 32, + k: 8, + kBase: 4, + inputTypeA: 'fp16', + inputTypeB: 'fp16', + outputType: 'fp32', + }, + { + name: 'v_mfma_f32_4x4x1_16b_f32', + mnemonic: 'f32_4x4x1_16b_f32', + m: 4, + n: 4, + k: 1, + kBase: 1, + inputTypeA: 'fp32', + inputTypeB: 'fp32', + outputType: 'fp32', + laneGroupsShareDimK: true, + registerBlocks: 16, + }, + { + name: 'v_mfma_f32_4x4x4_16b_bf16', + mnemonic: 'f32_4x4x4_16b_bf16', + m: 4, + n: 4, + k: 4, + kBase: 4, + inputTypeA: 'bf16', + inputTypeB: 'bf16', + outputType: 'fp32', + laneGroupsShareDimK: true, + registerBlocks: 16, + }, + { + name: 'v_mfma_f32_4x4x4_16b_f16', + mnemonic: 'f32_4x4x4_16b_f16', + m: 4, + n: 4, + k: 4, + kBase: 4, + inputTypeA: 'fp16', + inputTypeB: 'fp16', + outputType: 'fp32', + laneGroupsShareDimK: true, + registerBlocks: 16, + }, + { + name: 'v_mfma_f64_16x16x4_f64', + mnemonic: 'f64_16x16x4_f64', + m: 16, + n: 16, + k: 4, + kBase: 1, + inputTypeA: 'fp64', + inputTypeB: 'fp64', + outputType: 'fp64', + }, + { + name: 'v_mfma_f64_4x4x4_4b_f64', + mnemonic: 'f64_4x4x4_4b_f64', + m: 4, + n: 4, + k: 4, + kBase: 1, + inputTypeA: 'fp64', + inputTypeB: 'fp64', + outputType: 'fp64', + laneGroupsShareDimK: true, + registerBlocks: 4, + }, + { + name: 'v_mfma_i32_16x16x32_i8', + mnemonic: 'i32_16x16x32_i8', + m: 
16, + n: 16, + k: 32, + kBase: 8, + inputTypeA: 'int8', + inputTypeB: 'int8', + outputType: 'int32', + }, + { + name: 'v_mfma_i32_16x16x4_4b_i8', + mnemonic: 'i32_16x16x4_4b_i8', + m: 16, + n: 16, + k: 4, + kBase: 4, + inputTypeA: 'int8', + inputTypeB: 'int8', + outputType: 'int32', + laneGroupsShareDimK: true, + registerBlocks: 4, + }, + { + name: 'v_mfma_i32_32x32x16_i8', + mnemonic: 'i32_32x32x16_i8', + m: 32, + n: 32, + k: 16, + kBase: 8, + inputTypeA: 'int8', + inputTypeB: 'int8', + outputType: 'int32', + }, + { + name: 'v_mfma_i32_32x32x4_2b_i8', + mnemonic: 'i32_32x32x4_2b_i8', + m: 32, + n: 32, + k: 4, + kBase: 4, + inputTypeA: 'int8', + inputTypeB: 'int8', + outputType: 'int32', + laneGroupsShareDimK: true, + registerBlocks: 2, + }, + { + name: 'v_mfma_i32_4x4x4_16b_i8', + mnemonic: 'i32_4x4x4_16b_i8', + m: 4, + n: 4, + k: 4, + kBase: 4, + inputTypeA: 'int8', + inputTypeB: 'int8', + outputType: 'int32', + laneGroupsShareDimK: true, + registerBlocks: 16, + }, + ], +} diff --git a/src/data/wmma-instruction-data.ts b/src/data/wmma-instruction-data.ts new file mode 100644 index 0000000..4d08364 --- /dev/null +++ b/src/data/wmma-instruction-data.ts @@ -0,0 +1,225 @@ +// WMMA instruction metadata shared across layouts and tooling. 
+ +export interface WMMAInstructionEntry { + name: string + mnemonic: string + m: number + n: number + k: number + inputTypeA: string + inputTypeB: string + outputType: string + cycles: number + kWidth: number +} + +export const WMMA_INSTRUCTION_DATA: Record<'rdna3' | 'rdna4', WMMAInstructionEntry[]> = { + rdna3: [ + { + name: 'v_wmma_f32_16x16x16_f16', + mnemonic: 'f32_16x16x16_f16', + m: 16, + n: 16, + k: 16, + inputTypeA: 'fp16', + inputTypeB: 'fp16', + outputType: 'fp32', + cycles: 32, + kWidth: 16, + }, + { + name: 'v_wmma_f32_16x16x16_bf16', + mnemonic: 'f32_16x16x16_bf16', + m: 16, + n: 16, + k: 16, + inputTypeA: 'bf16', + inputTypeB: 'bf16', + outputType: 'fp32', + cycles: 32, + kWidth: 16, + }, + { + name: 'v_wmma_f16_16x16x16_f16', + mnemonic: 'f16_16x16x16_f16', + m: 16, + n: 16, + k: 16, + inputTypeA: 'fp16', + inputTypeB: 'fp16', + outputType: 'fp16', + cycles: 32, + kWidth: 16, + }, + { + name: 'v_wmma_bf16_16x16x16_bf16', + mnemonic: 'bf16_16x16x16_bf16', + m: 16, + n: 16, + k: 16, + inputTypeA: 'bf16', + inputTypeB: 'bf16', + outputType: 'bf16', + cycles: 32, + kWidth: 16, + }, + { + name: 'v_wmma_i32_16x16x16_iu8', + mnemonic: 'i32_16x16x16_iu8', + m: 16, + n: 16, + k: 16, + inputTypeA: 'iu8', + inputTypeB: 'iu8', + outputType: 'i32', + cycles: 32, + kWidth: 16, + }, + { + name: 'v_wmma_i32_16x16x16_iu4', + mnemonic: 'i32_16x16x16_iu4', + m: 16, + n: 16, + k: 16, + inputTypeA: 'iu4', + inputTypeB: 'iu4', + outputType: 'i32', + cycles: 16, + kWidth: 16, + }, + ], + rdna4: [ + { + name: 'v_wmma_f32_16x16x16_f16', + mnemonic: 'f32_16x16x16_f16', + m: 16, + n: 16, + k: 16, + inputTypeA: 'fp16', + inputTypeB: 'fp16', + outputType: 'fp32', + cycles: 16, + kWidth: 8, + }, + { + name: 'v_wmma_f32_16x16x16_bf16', + mnemonic: 'f32_16x16x16_bf16', + m: 16, + n: 16, + k: 16, + inputTypeA: 'bf16', + inputTypeB: 'bf16', + outputType: 'fp32', + cycles: 16, + kWidth: 8, + }, + { + name: 'v_wmma_f16_16x16x16_f16', + mnemonic: 'f16_16x16x16_f16', + m: 16, + n: 16, + 
k: 16, + inputTypeA: 'fp16', + inputTypeB: 'fp16', + outputType: 'fp16', + cycles: 16, + kWidth: 8, + }, + { + name: 'v_wmma_bf16_16x16x16_bf16', + mnemonic: 'bf16_16x16x16_bf16', + m: 16, + n: 16, + k: 16, + inputTypeA: 'bf16', + inputTypeB: 'bf16', + outputType: 'bf16', + cycles: 16, + kWidth: 8, + }, + { + name: 'v_wmma_i32_16x16x16_iu8', + mnemonic: 'i32_16x16x16_iu8', + m: 16, + n: 16, + k: 16, + inputTypeA: 'iu8', + inputTypeB: 'iu8', + outputType: 'i32', + cycles: 8, + kWidth: 8, + }, + { + name: 'v_wmma_i32_16x16x16_iu4', + mnemonic: 'i32_16x16x16_iu4', + m: 16, + n: 16, + k: 16, + inputTypeA: 'iu4', + inputTypeB: 'iu4', + outputType: 'i32', + cycles: 8, + kWidth: 8, + }, + { + name: 'v_wmma_i32_16x16x32_iu4', + mnemonic: 'i32_16x16x32_iu4', + m: 16, + n: 16, + k: 32, + inputTypeA: 'iu4', + inputTypeB: 'iu4', + outputType: 'i32', + cycles: 8, + kWidth: 16, + }, + { + name: 'v_wmma_f32_16x16x16_fp8_fp8', + mnemonic: 'f32_16x16x16_fp8_fp8', + m: 16, + n: 16, + k: 16, + inputTypeA: 'fp8', + inputTypeB: 'fp8', + outputType: 'fp32', + cycles: 8, + kWidth: 8, + }, + { + name: 'v_wmma_f32_16x16x16_fp8_bf8', + mnemonic: 'f32_16x16x16_fp8_bf8', + m: 16, + n: 16, + k: 16, + inputTypeA: 'fp8', + inputTypeB: 'bf8', + outputType: 'fp32', + cycles: 8, + kWidth: 8, + }, + { + name: 'v_wmma_f32_16x16x16_bf8_fp8', + mnemonic: 'f32_16x16x16_bf8_fp8', + m: 16, + n: 16, + k: 16, + inputTypeA: 'bf8', + inputTypeB: 'fp8', + outputType: 'fp32', + cycles: 8, + kWidth: 8, + }, + { + name: 'v_wmma_f32_16x16x16_bf8_bf8', + mnemonic: 'f32_16x16x16_bf8_bf8', + m: 16, + n: 16, + k: 16, + inputTypeA: 'bf8', + inputTypeB: 'bf8', + outputType: 'fp32', + cycles: 8, + kWidth: 8, + }, + ], +} diff --git a/src/layouts/MFMALayout.ts b/src/layouts/MFMALayout.ts new file mode 100644 index 0000000..3b6d1ee --- /dev/null +++ b/src/layouts/MFMALayout.ts @@ -0,0 +1,325 @@ +import { LinearLayout } from '../core/LinearLayout' +import { MFMA_INSTRUCTION_DATA, type MFMAInstructionEntry } from 
'../data/mfma-instruction-data'
+import {
+  getElementSizeBits,
+  resolveOperandDataType,
+} from './shared/PackingUtils'
+
+/**
+ * MFMA instruction metadata used by the visualization.
+ *
+ * Mirrors Triton's MFMA database (see
+ * third_party/amd/lib/TritonAMDGPUTransforms/MfmaGroup.cpp) while adding an
+ * explicit `version` tag per architecture (1 = CDNA1, 2 = CDNA2, 3 = CDNA3).
+ */
+export interface MFMAInstruction extends MFMAInstructionEntry {
+  version: number
+}
+
+export { resolveOperandDataType }
+
+/**
+ * MFMA operand selector
+ */
+export type MFMAOperand = 'A' | 'B' | 'D'
+/** Architecture tags that have an MFMA instruction table in this module. */
+export type MFMAArchitecture = 'cdna1' | 'cdna2' | 'cdna3'
+
+// Number of lanes in the single wave every layout in this module is built for.
+const MFMA_WAVE_SIZE = 64
+// Register width in bits; element bit widths are divided by it to count registers.
+const MFMA_REGISTER_WIDTH = 32
+// LinearLayout input dimension names.
+const MFMA_REGISTER_DIM = 'register'
+const MFMA_LANE_DIM = 'lane'
+// LinearLayout output dimension names.
+const MFMA_DIM_M = 'dimM'
+const MFMA_DIM_N = 'dimN'
+const MFMA_DIM_K = 'dimK'
+
+/**
+ * Wrap `value` into the range [0, limit) with a non-negative modulo.
+ *
+ * When `value` is not a number or `limit` is not positive, no wrapping is
+ * done: `value` is returned as is, or 0 if it is undefined.
+ */
+function wrapToRange(value: number | undefined, limit: number): number {
+  if (typeof value !== 'number' || limit <= 0) {
+    return value ?? 0
+  }
+  // Double modulo keeps the result non-negative for negative inputs.
+  const wrapped = ((value % limit) + limit) % limit
+  return wrapped
+}
+
+/** True for instructions whose M and N sizes are both 4. */
+function usesShared4x4Registers(instr: MFMAInstruction): boolean {
+  return instr.m === 4 && instr.n === 4
+}
+
+/** Return copies of `defs` with the given `version` tag attached to each entry. */
+function withVersion(version: number, defs: MFMAInstructionEntry[]): MFMAInstruction[] {
+  return defs.map(def => ({ ...def, version }))
+}
+
+export const CDNA1_MFMA_INSTRUCTIONS: MFMAInstruction[] = withVersion(
+  1,
+  MFMA_INSTRUCTION_DATA.cdna1
+)
+export const CDNA2_MFMA_INSTRUCTIONS: MFMAInstruction[] = withVersion(
+  2,
+  MFMA_INSTRUCTION_DATA.cdna2
+)
+export const CDNA3_MFMA_INSTRUCTIONS: MFMAInstruction[] = withVersion(
+  3,
+  MFMA_INSTRUCTION_DATA.cdna3
+)
+
+// Lookup table behind getMFMAInstructions. The explicit type arguments are
+// required: a bare `Record` is not a valid TypeScript type.
+const MFMA_INSTRUCTIONS_BY_ARCH: Record<MFMAArchitecture, MFMAInstruction[]> = {
+  cdna1: CDNA1_MFMA_INSTRUCTIONS,
+  cdna2: CDNA2_MFMA_INSTRUCTIONS,
+  cdna3: CDNA3_MFMA_INSTRUCTIONS,
+}
+
+/**
+ * Throw unless `value` is a positive integer power of two.
+ *
+ * `context` is used as the subject of the error message.
+ */
+function ensurePowerOfTwo(value: number, context: string): void {
+  // The bitwise test truncates its operands, so fractional values such as 2.5
+  // would pass it; reject non-integers explicitly first.
+  if (!Number.isInteger(value) || value <= 0 || (value & (value - 1)) !== 0) {
+    throw new Error(`${context} must be a positive power of two (got ${value})`)
+  }
+}
+
+/**
+ * Size of the leading register -> dimM factor of the output layout:
+ * 1 when the output element is 64 bits wide, otherwise 4.
+ */
+function getLaneHeight(instr: MFMAInstruction): number {
+  const bitWidth = getElementSizeBits(instr.outputType)
+  return bitWidth === 64 ? 1 : 4
+}
+
+/** Size of output dimension `dim` of `layout`, or `fallback` if the layout lacks it. */
+function getLayoutOutDimSize(layout: LinearLayout, dim: string, fallback: number): number {
+  return layout.hasOutDim(dim) ? layout.getOutDimSize(dim) : fallback
+}
+
+/** Size of input dimension `dim` of `layout`, or 1 if the layout lacks it. */
+function getLayoutInDimSize(layout: LinearLayout, dim: string): number {
+  return layout.hasInDim(dim) ? layout.getInDimSize(dim) : 1
+}
+
+/**
+ * Create the MFMA output layout (operand D) by emulating
+ * AMDMfmaEncodingAttr::toLinearLayout for a single wave64 tile.
+ */
+function createMFMAOutputLayout(instr: MFMAInstruction): LinearLayout {
+  const { m: rows, n: cols } = instr
+
+  ensurePowerOfTwo(rows, 'MFMA output M dimension')
+  ensurePowerOfTwo(cols, 'MFMA output N dimension')
+
+  // The lanes are split into `cols` columns along N; the remainder goes along M.
+  const lanesDivideEvenly = MFMA_WAVE_SIZE % cols === 0
+  if (!lanesDivideEvenly) {
+    throw new Error(`MFMA output requires N dimension dividing wave size (got N=${cols})`)
+  }
+
+  const laneHeight = getLaneHeight(instr)
+  ensurePowerOfTwo(laneHeight, 'MFMA lane height')
+
+  // Base tile: registers walk M first, then lanes walk N, then lanes walk M.
+  const registerPart = LinearLayout.identity1D(laneHeight, MFMA_REGISTER_DIM, MFMA_DIM_M)
+  const laneColumns = LinearLayout.identity1D(cols, MFMA_LANE_DIM, MFMA_DIM_N)
+  const laneRows = LinearLayout.identity1D(MFMA_WAVE_SIZE / cols, MFMA_LANE_DIM, MFMA_DIM_M)
+  const baseTile = registerPart.multiply(laneColumns.multiply(laneRows))
+
+  // How many base tiles are needed to cover the full M x N output (at least one).
+  const repeats = Math.max(1, Math.trunc((rows * cols) / (MFMA_WAVE_SIZE * laneHeight)))
+  ensurePowerOfTwo(repeats, 'MFMA tile replication factor')
+
+  if (repeats <= 1) {
+    return baseTile
+  }
+  // Extra tiles are stacked along M using additional registers.
+  return baseTile.multiply(LinearLayout.identity1D(repeats, MFMA_REGISTER_DIM, MFMA_DIM_M))
+}
+
+/**
+ * Build the layout for one MFMA operand.
+ *
+ * - 'D' delegates to the output layout builder above.
+ * - 'A' is laid out over M x K, 'B' over K x N.
+ *
+ * Follows mfmaDotToLinearLayout from Triton, restricted to a single wave64
+ * tile (warps-per-CTA == tiles-per-warp == 1).
+ *
+ * Throws an Error when a dimension is not a positive power of two, when the
+ * non-K dimension does not divide the wave size, or when K is larger than but
+ * not a multiple of the K extent covered by one tile.
+ */
+export function createMFMAOperandLayout(instr: MFMAInstruction, operand: MFMAOperand): LinearLayout {
+  if (operand === 'D') {
+    return createMFMAOutputLayout(instr)
+  }
+
+  const isOperandA = operand === 'A'
+  const nonKDim = isOperandA ? instr.m : instr.n
+  ensurePowerOfTwo(nonKDim, 'MFMA operand non-K dimension')
+
+  if (MFMA_WAVE_SIZE % nonKDim !== 0) {
+    throw new Error(
+      `MFMA operand ${operand} expects non-K dimension dividing wave size (got ${nonKDim})`
+    )
+  }
+
+  const kWidth = instr.kBase ?? instr.k
+  ensurePowerOfTwo(kWidth, 'MFMA operand K width')
+
+  const nonKName = isOperandA ? MFMA_DIM_M : MFMA_DIM_N
+  const laneGroupsAlongK = MFMA_WAVE_SIZE / nonKDim
+
+  // Base tile: registers walk K, lanes walk the non-K dimension, then K again.
+  const registerPart = LinearLayout.identity1D(kWidth, MFMA_REGISTER_DIM, MFMA_DIM_K)
+  const lanePart = LinearLayout.identity1D(nonKDim, MFMA_LANE_DIM, nonKName).multiply(
+    LinearLayout.identity1D(laneGroupsAlongK, MFMA_LANE_DIM, MFMA_DIM_K)
+  )
+
+  let layout = registerPart.multiply(lanePart)
+  let coveredK = laneGroupsAlongK * kWidth
+
+  // The long side of a 64x4 / 4x64 instruction gets a fixed 16x register replication.
+  const isLongSideOf64x4 =
+    (operand === 'A' && instr.m === 64 && instr.n === 4) ||
+    (operand === 'B' && instr.m === 4 && instr.n === 64)
+
+  if (isLongSideOf64x4) {
+    const fixedReplication = 16
+    layout = layout.multiply(
+      LinearLayout.identity1D(fixedReplication, MFMA_REGISTER_DIM, MFMA_DIM_K)
+    )
+    coveredK *= fixedReplication
+  }
+
+  // Replicate over registers until the whole K extent of the instruction is covered.
+  if (instr.k > coveredK) {
+    if (instr.k % coveredK !== 0) {
+      throw new Error(
+        `MFMA operand ${operand} requires K dimension (${instr.k}) to be a multiple of tile size (${coveredK})`
+      )
+    }
+    const extraTiles = Math.trunc(instr.k / coveredK)
+    ensurePowerOfTwo(extraTiles, 'MFMA operand K replication factor')
+    layout = layout.multiply(LinearLayout.identity1D(extraTiles, MFMA_REGISTER_DIM, MFMA_DIM_K))
+  }
+
+  return layout
+}
+
+/**
+ * Look up the MFMA instruction list for an architecture tag.
+ *
+ * Throws an Error when no list is registered for `architecture`.
+ */
+export function getMFMAInstructions(architecture: MFMAArchitecture): MFMAInstruction[] {
+  const known = MFMA_INSTRUCTIONS_BY_ARCH[architecture]
+  if (known) {
+    return known
+  }
+  throw new Error(`Unsupported MFMA architecture: ${architecture}`)
+}
+
+// Parameter bag accepted by createMFMALayout.
+export interface MFMALayoutParams {
+  architecture: MFMAArchitecture
+  instruction: MFMAInstruction
+  operand: MFMAOperand
+}
+
+/**
+ * Build a LinearLayout for the requested MFMA operand. Acts as a small wrapper
+ * around createMFMAOperandLayout so callers can pass the full tab parameter bag.
+ */
+export function createMFMALayout(params: MFMALayoutParams): LinearLayout {
+  // params.architecture is not consulted; the instruction and operand suffice.
+  return createMFMAOperandLayout(params.instruction, params.operand)
+}
+
+/**
+ * Enumerate all (row, column) pairs covered by a given thread.
+ *
+ * Coordinates are reported as (M, K) for operand A, (K, N) for operand B and
+ * (M, N) for operand D.
+ *
+ * @param layout      Layout to evaluate, queried with `register` and `lane` inputs.
+ * @param instruction Instruction whose m/n/k sizes and register-sharing flags
+ *                    drive the wrapping and register numbering.
+ * @param operand     Which operand the layout describes.
+ * @param threadId    Thread index; reduced modulo the wave size to a lane id.
+ * @returns One entry per (position, register) pair held by the lane. Positions
+ *          outside the layout's output extents are dropped.
+ */
+export function getMFMAPositionsForThread(
+  layout: LinearLayout,
+  instruction: MFMAInstruction,
+  operand: MFMAOperand,
+  threadId: number
+): Array<{ pos: [number, number]; registerId: number }> {
+  const positions: Array<{ pos: [number, number]; registerId: number }> = []
+
+  // Non-negative modulo so negative thread ids still map to a valid lane.
+  const lane = ((threadId % MFMA_WAVE_SIZE) + MFMA_WAVE_SIZE) % MFMA_WAVE_SIZE
+  const registerCount = getLayoutInDimSize(layout, MFMA_REGISTER_DIM)
+  const operandType = resolveOperandDataType(instruction, operand)
+  const elementBits = getElementSizeBits(operandType)
+  // Elements wider than one register occupy several consecutive register ids.
+  const registersPerElement = Math.max(1, Math.ceil(elementBits / MFMA_REGISTER_WIDTH))
+
+  const blockCount = Math.max(1, instruction.registerBlocks ?? 1)
+  const shouldWrapDimK = instruction.laneGroupsShareDimK || blockCount > 1
+  const wrapDimM = (operand === 'D' || operand === 'A') && instruction.m > 0
+  const wrapDimN = (operand === 'D' || operand === 'B') && instruction.n > 0
+  const shareRegisterBlocks = operand === 'D' && usesShared4x4Registers(instruction)
+  const isFp64Shared = registersPerElement > 1 && usesShared4x4Registers(instruction)
+  // A stride of 0 makes every block reuse the same register ids.
+  const registerStride = shareRegisterBlocks ? 0 : registerCount * registersPerElement
+
+  const dimMSz = getLayoutOutDimSize(layout, MFMA_DIM_M, instruction.m)
+  const dimNSz = getLayoutOutDimSize(layout, MFMA_DIM_N, instruction.n)
+  const dimKSz = getLayoutOutDimSize(layout, MFMA_DIM_K, instruction.k)
+
+  // Only operand D is repeated across register blocks.
+  const registerBlocks = operand === 'D' ? (shareRegisterBlocks ? 1 : blockCount) : 1
+
+  for (let reg = 0; reg < registerCount; reg++) {
+    const result = layout.apply({
+      [MFMA_REGISTER_DIM]: reg,
+      [MFMA_LANE_DIM]: lane,
+    })
+
+    // Fold K back into [0, k) with the same helper used for M and N below.
+    // wrapToRange leaves the value untouched when instruction.k is not positive.
+    if (shouldWrapDimK && typeof result[MFMA_DIM_K] === 'number') {
+      result[MFMA_DIM_K] = wrapToRange(result[MFMA_DIM_K], instruction.k)
+    }
+
+    if (wrapDimM) {
+      result[MFMA_DIM_M] = wrapToRange(result[MFMA_DIM_M], instruction.m)
+    }
+    if (wrapDimN) {
+      result[MFMA_DIM_N] = wrapToRange(result[MFMA_DIM_N], instruction.n)
+    }
+
+    // 4x4 instructions with multi-register elements: the coordinates are taken
+    // directly from the lane id and override what the layout produced.
+    if (isFp64Shared) {
+      const laneGroup = Math.floor(lane / 16)
+      const laneQuad = lane % 4
+      if (operand === 'A') {
+        result[MFMA_DIM_M] = wrapToRange(laneQuad, instruction.m)
+        result[MFMA_DIM_K] = wrapToRange(laneGroup, instruction.k)
+      } else if (operand === 'B') {
+        result[MFMA_DIM_K] = wrapToRange(laneGroup, instruction.k)
+        result[MFMA_DIM_N] = wrapToRange(laneQuad, instruction.n)
+      } else {
+        result[MFMA_DIM_M] = wrapToRange(laneGroup, instruction.m)
+        result[MFMA_DIM_N] = wrapToRange(laneQuad, instruction.n)
+      }
+    }
+
+    let dim0 = 0
+    let dim1 = 0
+    let dim0Limit = 0
+    let dim1Limit = 0
+
+    // Pick the (row, column) pair and its bounds for the requested operand.
+    if (operand === 'D') {
+      dim0 = result[MFMA_DIM_M] ?? 0
+      dim1 = result[MFMA_DIM_N] ?? 0
+      dim0Limit = dimMSz
+      dim1Limit = dimNSz
+    } else if (operand === 'A') {
+      dim0 = result[MFMA_DIM_M] ?? 0
+      dim1 = result[MFMA_DIM_K] ?? 0
+      dim0Limit = dimMSz
+      dim1Limit = dimKSz
+    } else {
+      dim0 = result[MFMA_DIM_K] ?? 0
+      dim1 = result[MFMA_DIM_N] ?? 0
+      dim0Limit = dimKSz
+      dim1Limit = dimNSz
+    }
+
+    if (dim0 >= 0 && dim0 < dim0Limit && dim1 >= 0 && dim1 < dim1Limit) {
+      for (let block = 0; block < registerBlocks; block++) {
+        const baseRegister = reg * registersPerElement
+        const blockOffset = block * registerStride
+        // Emit one entry per register the element occupies.
+        for (let chunk = 0; chunk < registersPerElement; chunk++) {
+          positions.push({
+            pos: [dim0, dim1],
+            registerId: baseRegister + chunk + blockOffset,
+          })
+        }
+      }
+    }
+  }
+
+  return positions
+}
diff --git a/src/layouts/WMMALayout.ts b/src/layouts/WMMALayout.ts
index a434cb5..00b224b 100644
--- a/src/layouts/WMMALayout.ts
+++ b/src/layouts/WMMALayout.ts
@@ -1,281 +1,49 @@
 import { LinearLayout } from '../core/LinearLayout'
+import {
+  WMMA_INSTRUCTION_DATA,
+  type WMMAInstructionEntry,
+} from '../data/wmma-instruction-data'
+import {
+  getElementSizeBits,
+  resolveOperandDataType,
+} from './shared/PackingUtils'
+
+export { resolveOperandDataType }
 
 /**
  * WMMA instruction definition
  */
-export interface WMMAInstruction {
-  name: string
-  mnemonic: string
-  m: number
-  n: number
-  k: number
-  inputTypeA: string
-  inputTypeB: string
-  outputType: string
-  cycles: number
-  kWidth: number
+export interface WMMAInstruction extends WMMAInstructionEntry {
   version: number // 1 for RDNA3, 2 for RDNA4
 }
 
+export type WMMAArchitecture = 'rdna3' | 'rdna4'
+
 /**
  * WMMA operand type
  */
 export type WMMAOperand = 'A' | 'B' | 'D'
 
-const ELEMENT_TYPE_SIZES: Record<string, number> = {
-  fp16: 16,
-  bf16: 16,
-  fp8: 8,
-  bf8: 8,
-  iu8: 8,
-  iu4: 4,
-  fp32: 32,
-  i32: 32,
+const WMMA_ARCH_VERSION: Record<WMMAArchitecture, number> = {
+  rdna3: 1,
+  rdna4: 2,
 }
 
-export function resolveOperandDataType(instr: WMMAInstruction, operand: WMMAOperand): string {
-  if (operand === 'D') {
-    return instr.outputType
-  }
-  return operand === 'A' ?
instr.inputTypeA : instr.inputTypeB +const WMMA_INSTRUCTION_CACHE: Record = { + rdna3: buildInstructionList('rdna3'), + rdna4: buildInstructionList('rdna4'), } -function getElementSizeBits(type: string): number { - return ELEMENT_TYPE_SIZES[type] ?? 32 -} +export const RDNA3_WMMA_INSTRUCTIONS: WMMAInstruction[] = WMMA_INSTRUCTION_CACHE.rdna3 +export const RDNA4_WMMA_INSTRUCTIONS: WMMAInstruction[] = WMMA_INSTRUCTION_CACHE.rdna4 -/** - * RDNA3 (gfx11) WMMA instructions - */ -export const RDNA3_WMMA_INSTRUCTIONS: WMMAInstruction[] = [ - { - name: 'v_wmma_f32_16x16x16_f16', - mnemonic: 'f32_16x16x16_f16', - m: 16, - n: 16, - k: 16, - inputTypeA: 'fp16', - inputTypeB: 'fp16', - outputType: 'fp32', - cycles: 32, - kWidth: 16, - version: 1, - }, - { - name: 'v_wmma_f32_16x16x16_bf16', - mnemonic: 'f32_16x16x16_bf16', - m: 16, - n: 16, - k: 16, - inputTypeA: 'bf16', - inputTypeB: 'bf16', - outputType: 'fp32', - cycles: 32, - kWidth: 16, - version: 1, - }, - { - name: 'v_wmma_f16_16x16x16_f16', - mnemonic: 'f16_16x16x16_f16', - m: 16, - n: 16, - k: 16, - inputTypeA: 'fp16', - inputTypeB: 'fp16', - outputType: 'fp16', - cycles: 32, - kWidth: 16, - version: 1, - }, - { - name: 'v_wmma_bf16_16x16x16_bf16', - mnemonic: 'bf16_16x16x16_bf16', - m: 16, - n: 16, - k: 16, - inputTypeA: 'bf16', - inputTypeB: 'bf16', - outputType: 'bf16', - cycles: 32, - kWidth: 16, - version: 1, - }, - { - name: 'v_wmma_i32_16x16x16_iu8', - mnemonic: 'i32_16x16x16_iu8', - m: 16, - n: 16, - k: 16, - inputTypeA: 'iu8', - inputTypeB: 'iu8', - outputType: 'i32', - cycles: 32, - kWidth: 16, - version: 1, - }, - { - name: 'v_wmma_i32_16x16x16_iu4', - mnemonic: 'i32_16x16x16_iu4', - m: 16, - n: 16, - k: 16, - inputTypeA: 'iu4', - inputTypeB: 'iu4', - outputType: 'i32', - cycles: 16, - kWidth: 16, - version: 1, - }, -] - -/** - * RDNA4 (gfx12) WMMA instructions - */ -export const RDNA4_WMMA_INSTRUCTIONS: WMMAInstruction[] = [ - { - name: 'v_wmma_f32_16x16x16_f16', - mnemonic: 'f32_16x16x16_f16', - m: 16, - 
n: 16, - k: 16, - inputTypeA: 'fp16', - inputTypeB: 'fp16', - outputType: 'fp32', - cycles: 16, - kWidth: 8, - version: 2, - }, - { - name: 'v_wmma_f32_16x16x16_bf16', - mnemonic: 'f32_16x16x16_bf16', - m: 16, - n: 16, - k: 16, - inputTypeA: 'bf16', - inputTypeB: 'bf16', - outputType: 'fp32', - cycles: 16, - kWidth: 8, - version: 2, - }, - { - name: 'v_wmma_f16_16x16x16_f16', - mnemonic: 'f16_16x16x16_f16', - m: 16, - n: 16, - k: 16, - inputTypeA: 'fp16', - inputTypeB: 'fp16', - outputType: 'fp16', - cycles: 16, - kWidth: 8, - version: 2, - }, - { - name: 'v_wmma_bf16_16x16x16_bf16', - mnemonic: 'bf16_16x16x16_bf16', - m: 16, - n: 16, - k: 16, - inputTypeA: 'bf16', - inputTypeB: 'bf16', - outputType: 'bf16', - cycles: 16, - kWidth: 8, - version: 2, - }, - { - name: 'v_wmma_i32_16x16x16_iu8', - mnemonic: 'i32_16x16x16_iu8', - m: 16, - n: 16, - k: 16, - inputTypeA: 'iu8', - inputTypeB: 'iu8', - outputType: 'i32', - cycles: 8, - kWidth: 8, - version: 2, - }, - { - name: 'v_wmma_i32_16x16x16_iu4', - mnemonic: 'i32_16x16x16_iu4', - m: 16, - n: 16, - k: 16, - inputTypeA: 'iu4', - inputTypeB: 'iu4', - outputType: 'i32', - cycles: 8, - kWidth: 8, - version: 2, - }, - { - name: 'v_wmma_i32_16x16x32_iu4', - mnemonic: 'i32_16x16x32_iu4', - m: 16, - n: 16, - k: 32, - inputTypeA: 'iu4', - inputTypeB: 'iu4', - outputType: 'i32', - cycles: 8, - kWidth: 16, - version: 2, - }, - { - name: 'v_wmma_f32_16x16x16_fp8_fp8', - mnemonic: 'f32_16x16x16_fp8_fp8', - m: 16, - n: 16, - k: 16, - inputTypeA: 'fp8', - inputTypeB: 'fp8', - outputType: 'fp32', - cycles: 8, - kWidth: 8, - version: 2, - }, - { - name: 'v_wmma_f32_16x16x16_fp8_bf8', - mnemonic: 'f32_16x16x16_fp8_bf8', - m: 16, - n: 16, - k: 16, - inputTypeA: 'fp8', - inputTypeB: 'bf8', - outputType: 'fp32', - cycles: 8, - kWidth: 8, - version: 2, - }, - { - name: 'v_wmma_f32_16x16x16_bf8_fp8', - mnemonic: 'f32_16x16x16_bf8_fp8', - m: 16, - n: 16, - k: 16, - inputTypeA: 'bf8', - inputTypeB: 'fp8', - outputType: 'fp32', - cycles: 8, - 
kWidth: 8, - version: 2, - }, - { - name: 'v_wmma_f32_16x16x16_bf8_bf8', - mnemonic: 'f32_16x16x16_bf8_bf8', - m: 16, - n: 16, - k: 16, - inputTypeA: 'bf8', - inputTypeB: 'bf8', - outputType: 'fp32', - cycles: 8, - kWidth: 8, - version: 2, - }, -] +function buildInstructionList(architecture: WMMAArchitecture): WMMAInstruction[] { + const version = WMMA_ARCH_VERSION[architecture] + return WMMA_INSTRUCTION_DATA[architecture].map((entry) => ({ + ...entry, + version, + })) +} /** * Create WMMA output (D) layout @@ -586,15 +354,15 @@ export function createWMMAOperandLayout( /** * Get all available WMMA instructions for an architecture */ -export function getWMMAInstructions(architecture: 'rdna3' | 'rdna4'): WMMAInstruction[] { - return architecture === 'rdna3' ? RDNA3_WMMA_INSTRUCTIONS : RDNA4_WMMA_INSTRUCTIONS +export function getWMMAInstructions(architecture: WMMAArchitecture): WMMAInstruction[] { + return WMMA_INSTRUCTION_CACHE[architecture].map((instr) => ({ ...instr })) } /** * Parameters for WMMA visualization */ export interface WMMALayoutParams { - architecture: 'rdna3' | 'rdna4' + architecture: WMMAArchitecture instruction: WMMAInstruction operand: WMMAOperand } diff --git a/src/layouts/WMMAPacking.ts b/src/layouts/WMMAPacking.ts index 74ee7fa..6615ae8 100644 --- a/src/layouts/WMMAPacking.ts +++ b/src/layouts/WMMAPacking.ts @@ -1,29 +1,32 @@ import type { WMMAInstruction, WMMAOperand } from './WMMALayout' - -export const REGISTER_WIDTH_BITS = 32 -export const DEFAULT_ELEMENT_BITS = 32 - -const ELEMENT_TYPE_BIT_SIZES: Record = { - f16: 16, - fp16: 16, - bf16: 16, - fp8: 8, - bf8: 8, - iu8: 8, - iu4: 4, - fp32: 32, - i32: 32, -} +import { + DEFAULT_ELEMENT_BITS as SHARED_DEFAULT_ELEMENT_BITS, + VGPR_REGISTER_WIDTH_BITS, + computeElementsPerRegister as computeElementsPerRegisterBase, +} from './shared/PackingUtils' + +export { + createPackingMetadata, + computePackingMetadata, + getLogicalElementMetadata, + decomposeLogicalIndex, + formatBitRange, + 
getElementSizeBits, +} from './shared/RegisterMetadata' +export type { + PackingComputationOptions, + PackingMetadata, + LogicalElementMetadata, +} from './shared/RegisterMetadata' + +export const REGISTER_WIDTH_BITS = VGPR_REGISTER_WIDTH_BITS +export const DEFAULT_ELEMENT_BITS = SHARED_DEFAULT_ELEMENT_BITS /** - * Look up the number of bits consumed by a logical element of the given type. + * WMMA-specific overrides for register packing behavior. Shared helpers live in + * src/layouts/shared/RegisterMetadata.ts so other architectures do not depend + * on WMMA implementation details. */ -export function getElementSizeBits(dataType: string | undefined): number { - if (!dataType) { - return DEFAULT_ELEMENT_BITS - } - return ELEMENT_TYPE_BIT_SIZES[dataType] ?? DEFAULT_ELEMENT_BITS -} /** * Determine how many logical elements are packed into a single physical VGPR. @@ -34,124 +37,7 @@ export function computeElementsPerRegister( elementBits: number ): number { const sanitizedBits = elementBits > 0 ? elementBits : DEFAULT_ELEMENT_BITS - const packableElements = Math.max(1, Math.floor(REGISTER_WIDTH_BITS / sanitizedBits)) + const packableElements = computeElementsPerRegisterBase(REGISTER_WIDTH_BITS, sanitizedBits) const requiresSingleElement = operand === 'D' && instruction.version < 2 return requiresSingleElement ? 1 : packableElements } - -/** - * Convert a logical element index into the physical VGPR identifier and the - * packed slot offset within that register. 
- */ -export function decomposeLogicalIndex( - logicalIndex: number, - elementsPerRegister: number -): { physicalRegister: number; elementOffset: number } { - if (elementsPerRegister <= 0) { - throw new Error('elementsPerRegister must be positive') - } - const elementOffset = - ((logicalIndex % elementsPerRegister) + elementsPerRegister) % elementsPerRegister - const physicalRegister = Math.floor(logicalIndex / elementsPerRegister) - return { physicalRegister, elementOffset } -} - -/** - * Format a bit-range string describing where the packed element resides. - */ -export function formatBitRange( - elementOffset: number, - elementBits: number, - elementsPerRegister: number -): string | null { - const sanitizedBits = elementBits > 0 ? elementBits : DEFAULT_ELEMENT_BITS - if (elementsPerRegister <= 1) { - if (sanitizedBits >= REGISTER_WIDTH_BITS) { - const fullWidthEndBit = REGISTER_WIDTH_BITS - 1 - return `[${fullWidthEndBit}:0]` - } - const endBit = sanitizedBits - 1 - if (endBit < 0) { - return null - } - return `[${endBit}:0]` - } - - const startBit = Math.max(0, elementOffset * sanitizedBits) - const endBit = startBit + sanitizedBits - 1 - if (endBit < startBit) { - return null - } - return `[${endBit}:${startBit}]` -} - -/** - * Convenience helper that returns both element bit width and packing factor. 
- */ -export function computePackingMetadata( - instruction: WMMAInstruction, - operand: WMMAOperand, - dataType: string | undefined -): { elementBits: number; elementsPerRegister: number } { - const elementBits = getElementSizeBits(dataType) - const elementsPerRegister = computeElementsPerRegister(instruction, operand, elementBits) - return { elementBits, elementsPerRegister } -} - -export interface PackingMetadata { - elementBits: number - elementsPerRegister: number -} - -export interface LogicalElementMetadata { - physicalRegister: number - elementOffset: number - bitRange: string | null - elementsPerRegister: number -} - -/** - * Create packing metadata for a given WMMA instruction/operand combination. - */ -export function createPackingMetadata( - instruction: WMMAInstruction, - operand: WMMAOperand, - dataType: string | undefined -): PackingMetadata { - return computePackingMetadata(instruction, operand, dataType) -} - -/** - * Translate a logical element index into physical packing metadata. 
- */ -export function getLogicalElementMetadata( - logicalIndex: number, - metadata: PackingMetadata -): LogicalElementMetadata | null { - if (logicalIndex < 0) { - return null - } - - const { elementsPerRegister, elementBits } = metadata - if (elementsPerRegister <= 0) { - return null - } - - let physicalRegister = 0 - let elementOffset = 0 - try { - const decomposition = decomposeLogicalIndex(logicalIndex, elementsPerRegister) - physicalRegister = decomposition.physicalRegister - elementOffset = decomposition.elementOffset - } catch { - return null - } - - const bitRange = formatBitRange(elementOffset, elementBits, elementsPerRegister) - return { - physicalRegister, - elementOffset, - bitRange, - elementsPerRegister, - } -} diff --git a/src/layouts/shared/PackingUtils.ts b/src/layouts/shared/PackingUtils.ts new file mode 100644 index 0000000..4df89b9 --- /dev/null +++ b/src/layouts/shared/PackingUtils.ts @@ -0,0 +1,115 @@ +/** + * Shared packing helpers for WMMA and MFMA visualizations. + * + * Consolidates common logic such as operand data type resolution, element + * bit-width lookups, logical index decomposition, and bit-range formatting. + */ + +export type MatmulOperand = 'A' | 'B' | 'D' + +export interface OperandTypeSources { + inputTypeA: string + inputTypeB: string + outputType: string +} + +export const DEFAULT_ELEMENT_BITS = 32 +export const VGPR_REGISTER_WIDTH_BITS = 32 + +const ELEMENT_BIT_WIDTHS: Record = { + f16: 16, + bf8: 8, + bf16: 16, + fp8: 8, + fp16: 16, + fp32: 32, + fp64: 64, + tf32: 32, + iu4: 4, + iu8: 8, + i32: 32, + int8: 8, + int32: 32, +} + +/** + * Map an operand selector to its resolved data type. + */ +export function resolveOperandDataType( + instruction: T, + operand: MatmulOperand +): string { + if (operand === 'D') { + return instruction.outputType + } + return operand === 'A' ? instruction.inputTypeA : instruction.inputTypeB +} + +/** + * Return the bit width associated with a given element type. 
+ */ +export function getElementSizeBits(dataType: string | undefined): number { + if (!dataType) { + return DEFAULT_ELEMENT_BITS + } + return ELEMENT_BIT_WIDTHS[dataType] ?? DEFAULT_ELEMENT_BITS +} + +/** + * Determine how many logical elements fit inside a physical register. + */ +export function computeElementsPerRegister( + registerWidthBits: number, + elementBits: number +): number { + const sanitizedBits = elementBits > 0 ? elementBits : DEFAULT_ELEMENT_BITS + const registerWidth = registerWidthBits > 0 ? registerWidthBits : VGPR_REGISTER_WIDTH_BITS + return Math.max(1, Math.floor(registerWidth / sanitizedBits)) +} + +/** + * Convert a logical index into a physical register identifier and slot offset. + */ +export function decomposeLogicalIndex( + logicalIndex: number, + elementsPerRegister: number +): { physicalRegister: number; elementOffset: number } { + if (elementsPerRegister <= 0) { + throw new Error('elementsPerRegister must be positive') + } + const elementOffset = + ((logicalIndex % elementsPerRegister) + elementsPerRegister) % elementsPerRegister + const physicalRegister = Math.floor(logicalIndex / elementsPerRegister) + return { physicalRegister, elementOffset } +} + +/** + * Format the bit range representing a packed element slice within a register. + */ +export function formatBitRange( + elementOffset: number, + elementBits: number, + elementsPerRegister: number, + registerWidthBits: number = VGPR_REGISTER_WIDTH_BITS +): string | null { + const sanitizedBits = elementBits > 0 ? 
elementBits : DEFAULT_ELEMENT_BITS + if (elementsPerRegister <= 1) { + if (sanitizedBits >= registerWidthBits) { + return `[${Math.max(0, registerWidthBits - 1)}:0]` + } + const endBit = sanitizedBits - 1 + if (endBit < 0) { + return null + } + return `[${endBit}:0]` + } + + const startBit = Math.max(0, elementOffset * sanitizedBits) + const endBit = startBit + sanitizedBits - 1 + if (endBit < startBit) { + return null + } + const clampedEnd = Math.min(endBit, Math.max(0, registerWidthBits - 1)) + const clampedStart = Math.min(startBit, clampedEnd) + return `[${clampedEnd}:${clampedStart}]` +} diff --git a/src/layouts/shared/RegisterMetadata.ts b/src/layouts/shared/RegisterMetadata.ts new file mode 100644 index 0000000..c0c981f --- /dev/null +++ b/src/layouts/shared/RegisterMetadata.ts @@ -0,0 +1,114 @@ +/** + * Shared helpers for computing register packing metadata that apply to both + * WMMA and MFMA visualizations. + * + * These utilities intentionally avoid architecture-specific quirks so that a + * single source of truth can power UI components, verification scripts, and + * any future layouts that rely on consistent register metadata. + */ + +import { + DEFAULT_ELEMENT_BITS as PACKING_DEFAULT_ELEMENT_BITS, + VGPR_REGISTER_WIDTH_BITS, + computeElementsPerRegister as computeElementsPerRegisterBase, + decomposeLogicalIndex, + formatBitRange, + getElementSizeBits, + type MatmulOperand, + type OperandTypeSources, +} from './PackingUtils' + +export { decomposeLogicalIndex, formatBitRange, getElementSizeBits } +export const DEFAULT_ELEMENT_BITS = PACKING_DEFAULT_ELEMENT_BITS +export const REGISTER_WIDTH_BITS = VGPR_REGISTER_WIDTH_BITS + +/** + * Controls how packing metadata should be computed for a given operand. + */ +export interface PackingComputationOptions { + /** Override the physical register width in bits (defaults to VGPR width). */ + registerWidthBits?: number + /** Force the packing factor to one element per register. 
*/ + forceSingleElement?: boolean +} + +/** + * Describes how a logical operand packs elements into physical registers. + */ +export interface PackingMetadata { + elementBits: number + elementsPerRegister: number +} + +/** + * Resolved metadata for a single logical element inside a register file. + */ +export interface LogicalElementMetadata { + physicalRegister: number + elementOffset: number + bitRange: string | null + elementsPerRegister: number +} + +/** + * Compute packing metadata for an operand independent of architecture quirks. + */ +export function computePackingMetadata( + _instruction: T, + _operand: MatmulOperand, + dataType: string | undefined, + options?: PackingComputationOptions +): PackingMetadata { + const elementBits = getElementSizeBits(dataType) + const registerWidthBits = options?.registerWidthBits ?? REGISTER_WIDTH_BITS + const baseElements = computeElementsPerRegisterBase(registerWidthBits, elementBits) + const elementsPerRegister = options?.forceSingleElement ? 1 : baseElements + return { elementBits, elementsPerRegister } +} + +/** + * Convenience wrapper that mirrors previous helpers and returns packed metadata. + */ +export function createPackingMetadata( + instruction: T, + operand: MatmulOperand, + dataType: string | undefined, + options?: PackingComputationOptions +): PackingMetadata { + return computePackingMetadata(instruction, operand, dataType, options) +} + +/** + * Translate a logical element index into physical register metadata. 
+ */ +export function getLogicalElementMetadata( + logicalIndex: number, + metadata: PackingMetadata +): LogicalElementMetadata | null { + if (logicalIndex < 0) { + return null + } + + const { elementsPerRegister, elementBits } = metadata + if (elementsPerRegister <= 0) { + return null + } + + let physicalRegister = 0 + let elementOffset = 0 + try { + const decomposition = decomposeLogicalIndex(logicalIndex, elementsPerRegister) + physicalRegister = decomposition.physicalRegister + elementOffset = decomposition.elementOffset + } catch { + return null + } + + const bitRange = formatBitRange(elementOffset, elementBits, elementsPerRegister) + return { + physicalRegister, + elementOffset, + bitRange, + elementsPerRegister, + } +} diff --git a/src/main.tabs.test.ts b/src/main.tabs.test.ts index 5be77af..bd94933 100644 --- a/src/main.tabs.test.ts +++ b/src/main.tabs.test.ts @@ -74,6 +74,16 @@ vi.mock('./tabs/WMMALayoutTab', () => ({ WMMALayoutTab: WMMALayoutTabMock, })) +const MFMALayoutTabMock = vi.fn().mockImplementation(() => ({ + activate: vi.fn(), + deactivate: vi.fn(), + resize: vi.fn(), +})) + +vi.mock('./tabs/MFMALayoutTab', () => ({ + MFMALayoutTab: MFMALayoutTabMock, +})) + const setupDom = () => { document.body.innerHTML = '' vi.stubGlobal('alert', vi.fn()) @@ -238,6 +248,56 @@ const setupDom = () => { return tabContent } + const createMFMALayoutContent = () => { + const tabContent = document.createElement('div') + tabContent.className = 'tab-content' + tabContent.id = 'mfma-layout' + + const contentWrapper = document.createElement('div') + contentWrapper.className = 'content' + + const sidebar = document.createElement('aside') + sidebar.className = 'sidebar' + + const form = document.createElement('form') + form.id = 'mfmaForm' + + const architectureSelect = document.createElement('select') + architectureSelect.id = 'mfma-architecture' + architectureSelect.innerHTML = ` + + + + ` + form.appendChild(architectureSelect) + + const instructionSelect = 
document.createElement('select') + instructionSelect.id = 'mfma-instruction' + form.appendChild(instructionSelect) + + const operandSelect = document.createElement('select') + operandSelect.id = 'mfma-operand' + operandSelect.innerHTML = ` + + + + ` + form.appendChild(operandSelect) + + sidebar.appendChild(form) + + const controls = document.createElement('div') + controls.className = 'controls' + controls.setAttribute('data-controls', '') + sidebar.appendChild(controls) + + contentWrapper.appendChild(sidebar) + contentWrapper.appendChild(createVisualizationSection('mfma-canvas')) + + tabContent.appendChild(contentWrapper) + return tabContent + } + const container = document.createElement('div') container.className = 'container' document.body.appendChild(container) @@ -250,6 +310,7 @@ const setupDom = () => { { id: 'block-layout', label: 'Block Layout', active: true }, { id: 'shared-layout', label: 'Shared Layout' }, { id: 'wmma-layout', label: 'WMMA Layout' }, + { id: 'mfma-layout', label: 'MFMA Layout' }, { id: 'linear-layout', label: 'Linear Layout' }, { id: 'ck-layout', label: 'CK Layout' }, ] @@ -261,6 +322,7 @@ const setupDom = () => { container.appendChild(createBlockLayoutContent()) container.appendChild(createPlaceholderTabContent('shared-layout')) container.appendChild(createWMMALayoutContent()) + container.appendChild(createMFMALayoutContent()) container.appendChild(createPlaceholderTabContent('linear-layout')) container.appendChild(createPlaceholderTabContent('ck-layout')) diff --git a/src/main.ts b/src/main.ts index 4f880fb..1221243 100644 --- a/src/main.ts +++ b/src/main.ts @@ -1,5 +1,6 @@ import { BlockLayoutTab } from './tabs/BlockLayoutTab' import { WMMALayoutTab } from './tabs/WMMALayoutTab' +import { MFMALayoutTab } from './tabs/MFMALayoutTab' type TabController = { activate(): void @@ -20,6 +21,7 @@ if (tabContents.length === 0) { const controllers: Map = new Map() controllers.set('block-layout', new BlockLayoutTab('block-layout')) 
controllers.set('wmma-layout', new WMMALayoutTab('wmma-layout')) +controllers.set('mfma-layout', new MFMALayoutTab('mfma-layout')) let currentTabId: string | null = null diff --git a/src/tabs/MFMALayoutTab.ts b/src/tabs/MFMALayoutTab.ts new file mode 100644 index 0000000..f850f5d --- /dev/null +++ b/src/tabs/MFMALayoutTab.ts @@ -0,0 +1,268 @@ +import { + createMFMALayout, + getMFMAInstructions, + getMFMAPositionsForThread, + resolveOperandDataType, + type MFMAArchitecture, + type MFMAOperand, + type MFMALayoutParams, +} from '../layouts/MFMALayout' +import { + DEFAULT_ELEMENT_BITS, + createPackingMetadata, + getLogicalElementMetadata, + type PackingMetadata, +} from '../layouts/shared/RegisterMetadata' +import type { LinearLayout } from '../core/LinearLayout' +import { CanvasRenderer } from '../visualization/CanvasRenderer' +import { renderSharedControls } from '../ui/renderSharedControls' +import { CanvasTab, type CanvasTabElements } from './CanvasTab' + +/** + * Controller for MFMA layout visualizations. Mirrors the WMMA tab flow while + * adapting architecture/instruction choices and wave64 semantics. 
+ */ +export class MFMALayoutTab extends CanvasTab { + private readonly architectureSelect: HTMLSelectElement + private readonly instructionSelect: HTMLSelectElement + private readonly operandSelect: HTMLSelectElement + + private packingMetadata: PackingMetadata = { + elementBits: DEFAULT_ELEMENT_BITS, + elementsPerRegister: 1, + } + + private static readonly WAVE_SIZE = 64 + + constructor(tabId: string) { + const tabContent = document.getElementById(tabId) + if (!tabContent) { + throw new Error(`MFMALayoutTab container not found: ${tabId}`) + } + + const visualizationContainer = tabContent.querySelector('.visualization') + if (!(visualizationContainer instanceof HTMLElement)) { + throw new Error('MFMALayoutTab visualization container not found') + } + + const canvas = visualizationContainer.querySelector('canvas') + if (!(canvas instanceof HTMLCanvasElement)) { + throw new Error('MFMALayoutTab canvas element not found') + } + + const controlsContainer = tabContent.querySelector('[data-controls]') + if (!(controlsContainer instanceof HTMLElement)) { + throw new Error('MFMALayoutTab controls container not found') + } + const resetButton = renderSharedControls(controlsContainer, { resetButtonId: 'mfma-reset' }) + + const form = tabContent.querySelector('#mfmaForm') + if (!(form instanceof HTMLFormElement)) { + throw new Error('MFMALayoutTab form element not found') + } + + const architectureSelect = form.querySelector('#mfma-architecture') + if (!(architectureSelect instanceof HTMLSelectElement)) { + throw new Error('MFMALayoutTab architecture select not found') + } + + const instructionSelect = form.querySelector('#mfma-instruction') + if (!(instructionSelect instanceof HTMLSelectElement)) { + throw new Error('MFMALayoutTab instruction select not found') + } + + const operandSelect = form.querySelector('#mfma-operand') + if (!(operandSelect instanceof HTMLSelectElement)) { + throw new Error('MFMALayoutTab operand select not found') + } + + const elements: 
CanvasTabElements = { + root: tabContent, + canvas, + visualizationContainer, + resetButton, + } + + super(elements) + + this.architectureSelect = architectureSelect + this.instructionSelect = instructionSelect + this.operandSelect = operandSelect + + this.setupEventListeners() + this.populateInstructionOptions() + this.updateVisualization() + } + + protected resetHover(): void { + this.hideTooltip() + } + + protected handleHover(event: MouseEvent): void { + const renderer = this.getRenderer() + if (!renderer) { + this.hideTooltip() + return + } + + const rect = this.canvas.getBoundingClientRect() + const x = event.clientX - rect.left + const y = event.clientY - rect.top + const gridPos = renderer.screenToGrid(x, y) + const cellInfo = renderer.getCellInfo(gridPos.row, gridPos.col) + + if (!cellInfo) { + this.hideTooltip() + return + } + + const elementMetadata = this.getRegisterMetadata(cellInfo.registerId) + const logicalIndex = `(${cellInfo.sourcePosition[0]}, ${cellInfo.sourcePosition[1]})` + + let registerLabel = 'N/A' + let bitRangeLabel = 'N/A' + let elementLabel = 'N/A' + + if (elementMetadata) { + registerLabel = `v${elementMetadata.physicalRegister}` + bitRangeLabel = elementMetadata.bitRange ?? 'N/A' + const elementIndex = + elementMetadata.physicalRegister * elementMetadata.elementsPerRegister + + elementMetadata.elementOffset + elementLabel = elementIndex.toString() + } + + const tooltipLines = [ + `
Logical Index: ${logicalIndex}
`, + `
Thread: ${cellInfo.threadId}
`, + `
Register: ${registerLabel}
`, + `
Bit Range: ${bitRangeLabel}
`, + `
Element: ${elementLabel}
`, + ] + + this.tooltip.show(tooltipLines.join(''), event.clientX, event.clientY) + } + + private setupEventListeners(): void { + this.architectureSelect.addEventListener('change', () => { + this.populateInstructionOptions() + this.updateVisualization() + }) + + this.instructionSelect.addEventListener('change', () => { + this.updateVisualization() + }) + + this.operandSelect.addEventListener('change', () => { + this.updateVisualization() + }) + } + + private populateInstructionOptions(): void { + const architecture = this.architectureSelect.value as MFMAArchitecture + const instructions = getMFMAInstructions(architecture) + + this.instructionSelect.innerHTML = '' + instructions.forEach((instr, idx) => { + const option = document.createElement('option') + option.value = idx.toString() + const inputTypeLabel = + instr.inputTypeA === instr.inputTypeB + ? instr.inputTypeA + : `${instr.inputTypeA}/${instr.inputTypeB}` + option.textContent = `${instr.mnemonic} (${instr.m}×${instr.n}×${instr.k}, ${inputTypeLabel} → ${instr.outputType})` + this.instructionSelect.appendChild(option) + }) + + this.instructionSelect.selectedIndex = 0 + } + + private getParams(): MFMALayoutParams { + const architecture = this.architectureSelect.value as MFMAArchitecture + const instructions = getMFMAInstructions(architecture) + const instructionIdx = parseInt(this.instructionSelect.value, 10) || 0 + const instruction = instructions[instructionIdx] + const operand = this.operandSelect.value as MFMAOperand + + if (!instruction) { + throw new Error('No MFMA instruction selected') + } + + return { + architecture, + instruction, + operand, + } + } + + private updateVisualization(): void { + try { + const params = this.getParams() + const dataType = resolveOperandDataType(params.instruction, params.operand) + this.updatePackingMetadata(params, dataType) + + this.resizeCanvas() + + const layout = createMFMALayout(params) + const tensorShape = this.getTensorShape(params) + + const blockParams = { + 
sizePerThread: [1, 1] as [number, number], + threadsPerWarp: [MFMALayoutTab.WAVE_SIZE, 1] as [number, number], + warpsPerCTA: [1, 1] as [number, number], + order: [0, 1] as [number, number], + tensorShape, + } + + const positionResolver = (layoutData: LinearLayout, threadId: number) => { + const positions = getMFMAPositionsForThread( + layoutData, + params.instruction, + params.operand, + threadId + ) + + return positions.map(({ pos, registerId }) => ({ + pos: [pos[0], pos[1]] as [number, number], + registerId, + sourcePos: [pos[0], pos[1]] as [number, number], + })) + } + + const renderer = new CanvasRenderer(this.canvas, layout, blockParams, positionResolver, { + colorGrouping: 'thread', + }) + this.setRenderer(renderer) + renderer.render() + } catch (error) { + this.packingMetadata = { + elementBits: DEFAULT_ELEMENT_BITS, + elementsPerRegister: 1, + } + console.error('Failed to create MFMA visualization:', error) + alert(`Failed to create MFMA visualization: ${error}`) + } + } + + private getTensorShape(params: MFMALayoutParams): [number, number] { + if (params.operand === 'D') { + return [params.instruction.m, params.instruction.n] + } + if (params.operand === 'A') { + return [params.instruction.m, params.instruction.k] + } + return [params.instruction.k, params.instruction.n] + } + + private updatePackingMetadata(params: MFMALayoutParams, dataType: string | undefined): void { + this.packingMetadata = createPackingMetadata( + params.instruction, + params.operand, + dataType + ) + } + + private getRegisterMetadata(logicalIndex: number) { + return getLogicalElementMetadata(logicalIndex, this.packingMetadata) + } +} diff --git a/src/tabs/WMMALayoutTab.ts b/src/tabs/WMMALayoutTab.ts index e533015..05e2c43 100644 --- a/src/tabs/WMMALayoutTab.ts +++ b/src/tabs/WMMALayoutTab.ts @@ -8,11 +8,11 @@ import { type WMMALayoutParams, } from '../layouts/WMMALayout' import { - DEFAULT_ELEMENT_BITS, createPackingMetadata, getLogicalElementMetadata, type PackingMetadata, -} from 
'../layouts/WMMAPacking' +} from '../layouts/shared/RegisterMetadata' +import { DEFAULT_ELEMENT_BITS } from '../layouts/WMMAPacking' import type { LinearLayout } from '../core/LinearLayout' import { CanvasRenderer } from '../visualization/CanvasRenderer' import { renderSharedControls } from '../ui/renderSharedControls' @@ -212,7 +212,14 @@ export class WMMALayoutTab extends CanvasTab { try { const params = this.getParams() const dataType = resolveOperandDataType(params.instruction, params.operand) - this.packingMetadata = createPackingMetadata(params.instruction, params.operand, dataType) + const forceSingleElement = + params.operand === 'D' && params.instruction.version < 2 + this.packingMetadata = createPackingMetadata( + params.instruction, + params.operand, + dataType, + { forceSingleElement } + ) const replicationConfig = this.getReplicationConfig(params) this.resizeCanvas() diff --git a/verification/build-helper.sh b/verification/build-helper.sh index 56723d8..3442aa3 100755 --- a/verification/build-helper.sh +++ b/verification/build-helper.sh @@ -7,20 +7,26 @@ set -e SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)" -echo "Building verification helper module..." +echo "Building verification helper modules..." 
cd "$PROJECT_ROOT" -# Compile the helper module to a bundled .mjs file -node_modules/.bin/esbuild \ - verification/lib/verification-helper.ts \ - --bundle \ - --platform=node \ - --format=esm \ - --target=es2020 \ - --outfile=verification/lib/verification-helper.mjs +declare -a ENTRYPOINTS=( + "verification/lib/verification-helper.ts:verification/lib/verification-helper.mjs" + "verification/lib/mfma-verification-helper.ts:verification/lib/mfma-verification-helper.mjs" +) -# Make it executable -chmod +x verification/lib/verification-helper.mjs +for entry in "${ENTRYPOINTS[@]}"; do + IFS=":" read -r SOURCE TARGET <<< "$entry" + echo " - Bundling $SOURCE -> $TARGET" + node_modules/.bin/esbuild \ + "$SOURCE" \ + --bundle \ + --platform=node \ + --format=esm \ + --target=es2020 \ + --outfile="$TARGET" + chmod +x "$TARGET" +done -echo "✓ Verification helper built successfully: verification/lib/verification-helper.mjs" +echo "✓ Verification helpers built successfully." diff --git a/verification/lib/mfma-verification-helper.ts b/verification/lib/mfma-verification-helper.ts new file mode 100644 index 0000000..f6c0165 --- /dev/null +++ b/verification/lib/mfma-verification-helper.ts @@ -0,0 +1,212 @@ +#!/usr/bin/env node +/** + * Verification helper for MFMA layouts. + * + * Mirrors the WMMA verification helper structure but targets MFMA instructions + * defined in src/layouts/MFMALayout.ts. Generates physical lane/register/bit + * mappings for consumption by the Python verification script. 
+ */ + +import { + CDNA1_MFMA_INSTRUCTIONS, + CDNA2_MFMA_INSTRUCTIONS, + CDNA3_MFMA_INSTRUCTIONS, + createMFMAOperandLayout, + resolveOperandDataType, + type MFMAArchitecture, + type MFMAInstruction, + type MFMAOperand, +} from '../../src/layouts/MFMALayout' +import { + REGISTER_WIDTH_BITS as SHARED_REGISTER_WIDTH_BITS, + computePackingMetadata, + decomposeLogicalIndex, + formatBitRange, +} from '../../src/layouts/shared/RegisterMetadata' + +const MFMA_WAVE_SIZE = 64 +const REGISTER_WIDTH_BITS = SHARED_REGISTER_WIDTH_BITS +const REGISTER_DIM = 'register' +const LANE_DIM = 'lane' +const DIM_M = 'dimM' +const DIM_N = 'dimN' +const DIM_K = 'dimK' + +const ARCH_TO_INSTRUCTIONS: Record = { + cdna1: CDNA1_MFMA_INSTRUCTIONS, + cdna2: CDNA2_MFMA_INSTRUCTIONS, + cdna3: CDNA3_MFMA_INSTRUCTIONS, +} + +export interface LaneElementMapping { + elementIndex: number + lane: number + physicalRegister: number + bitRange: string +} + +export interface LayoutMappings { + [position: string]: LaneElementMapping[] +} + +function getMatrixPosition( + operand: MFMAOperand, + result: Record +): { row: number; col: number } { + if (operand === 'D') { + return { row: result[DIM_M] ?? 0, col: result[DIM_N] ?? 0 } + } + if (operand === 'A') { + return { row: result[DIM_M] ?? 0, col: result[DIM_K] ?? 0 } + } + return { row: result[DIM_K] ?? 0, col: result[DIM_N] ?? 0 } +} + +function wrapIndex(value: number | undefined, size: number): number | undefined { + if (typeof value !== 'number' || size <= 0) { + return value + } + const wrapped = ((value % size) + size) % size + return wrapped +} + +function usesShared4x4Registers(instruction: MFMAInstruction): boolean { + return instruction.m === 4 && instruction.n === 4 +} + +export function generateLayoutMappings( + instruction: MFMAInstruction, + operand: MFMAOperand +): LayoutMappings { + const layout = createMFMAOperandLayout(instruction, operand) + + const registerCount = layout.hasInDim(REGISTER_DIM) + ? 
layout.getInDimSize(REGISTER_DIM) + : instruction.k + const laneCount = layout.hasInDim(LANE_DIM) ? layout.getInDimSize(LANE_DIM) : MFMA_WAVE_SIZE + + const dataType = resolveOperandDataType(instruction, operand) + const baseMetadata = computePackingMetadata(instruction, operand, dataType, { + registerWidthBits: REGISTER_WIDTH_BITS, + }) + const registersPerElement = Math.max( + 1, + Math.ceil(baseMetadata.elementBits / REGISTER_WIDTH_BITS) + ) + const elementsPerRegister = + registersPerElement > 1 ? 1 : baseMetadata.elementsPerRegister + const elementBits = baseMetadata.elementBits + + const mappings: LayoutMappings = {} + + for (let lane = 0; lane < laneCount; lane++) { + for (let logicalIndex = 0; logicalIndex < registerCount; logicalIndex++) { + const result = layout.apply({ + [REGISTER_DIM]: logicalIndex, + [LANE_DIM]: lane, + }) + + if ( + instruction.laneGroupsShareDimK && + typeof result[DIM_K] === 'number' && + instruction.k > 0 + ) { + const currentK = result[DIM_K] + const wrapped = ((currentK % instruction.k) + instruction.k) % instruction.k + result[DIM_K] = wrapped + } + + const shouldWrapM = (operand === 'D' || operand === 'A') && instruction.m > 0 + const shouldWrapN = (operand === 'D' || operand === 'B') && instruction.n > 0 + if (shouldWrapM) { + result[DIM_M] = wrapIndex(result[DIM_M], instruction.m) + } + if (shouldWrapN) { + result[DIM_N] = wrapIndex(result[DIM_N], instruction.n) + } + + if (registersPerElement > 1 && usesShared4x4Registers(instruction)) { + const laneGroup = Math.floor(lane / 16) + const laneQuad = lane % 4 + if (operand === 'A') { + result[DIM_M] = wrapIndex(laneQuad, instruction.m) + result[DIM_K] = wrapIndex(laneGroup, instruction.k) + } else if (operand === 'B') { + result[DIM_K] = wrapIndex(laneGroup, instruction.k) + result[DIM_N] = wrapIndex(laneQuad, instruction.n) + } else { + result[DIM_M] = wrapIndex(laneGroup, instruction.m) + result[DIM_N] = wrapIndex(laneQuad, instruction.n) + } + } + + const { row, col } = 
getMatrixPosition(operand, result) + const { physicalRegister, elementOffset } = decomposeLogicalIndex( + logicalIndex, + elementsPerRegister + ) + const bitRange = formatBitRange(elementOffset, elementBits, elementsPerRegister) ?? '' + + const key = `${row}_${col}` + if (!mappings[key]) { + mappings[key] = [] + } + + const shareRegisters = operand === 'D' && usesShared4x4Registers(instruction) + const registerBlocks = operand === 'D' + ? shareRegisters + ? 1 + : Math.max(1, instruction.registerBlocks ?? 1) + : 1 + const registerStride = shareRegisters ? 0 : registerCount * registersPerElement + const baseRegister = + registersPerElement > 1 + ? Math.floor(physicalRegister) * registersPerElement + : physicalRegister + const fullRange = `[${REGISTER_WIDTH_BITS - 1}:0]` + const chunks = registersPerElement > 1 + ? Array.from({ length: registersPerElement }, (_, idx) => ({ offset: idx, range: fullRange })) + : [{ offset: 0, range: bitRange }] + + for (let block = 0; block < registerBlocks; block++) { + const blockOffset = block * registerStride + for (const chunk of chunks) { + mappings[key].push({ + elementIndex: logicalIndex + block * registerCount, + lane, + physicalRegister: baseRegister + chunk.offset + blockOffset, + bitRange: chunk.range, + }) + } + } + } + } + + return mappings +} + +export function getInstructionMetadata() { + return ARCH_TO_INSTRUCTIONS +} + +function main() { + const args = process.argv.slice(2) + if (args.length === 0) { + console.error('Usage: mfma-verification-helper.ts ') + process.exit(1) + } + + const config = JSON.parse(args[0]) + if (config.action === 'listInstructions') { + console.log(JSON.stringify(getInstructionMetadata())) + return + } + + const { instruction, operand } = config + const mappings = generateLayoutMappings(instruction, operand) + console.log(JSON.stringify(mappings)) +} + +if (import.meta.url === `file://${process.argv[1]}`) { + main() +} diff --git a/verification/lib/verification-helper.ts 
b/verification/lib/verification-helper.ts index b318c71..0cfdbbd 100644 --- a/verification/lib/verification-helper.ts +++ b/verification/lib/verification-helper.ts @@ -21,12 +21,12 @@ import { createWMMAOperandLayout, resolveOperandDataType, } from '../../src/layouts/WMMALayout' +import { computeElementsPerRegister } from '../../src/layouts/WMMAPacking' import { - computeElementsPerRegister, decomposeLogicalIndex, formatBitRange, getElementSizeBits, -} from '../../src/layouts/WMMAPacking' +} from '../../src/layouts/shared/RegisterMetadata' /** * Represents a single lane/element mapping with associated metadata diff --git a/verification/verify_mfma_layouts.py b/verification/verify_mfma_layouts.py new file mode 100755 index 0000000..1d53d3f --- /dev/null +++ b/verification/verify_mfma_layouts.py @@ -0,0 +1,827 @@ +#!/usr/bin/env python3 +""" +Clean MFMA layout verification script. + +Compares our TypeScript linear layout implementation against AMD's official calculator +by using a pre-built verification helper module (no dynamic code generation). + +The ground truth is based on ../amd_matrix_instruction_calculator +The goal is to ensure the MFMA linear layout is correct compared to the ground truth. +""" + +import argparse +import subprocess +import json +import sys +import re +from typing import Any, Dict, List, Tuple +from pathlib import Path + +OPERANDS = ['A', 'B', 'D'] + +REGISTER_WIDTH_BITS = 32 +FULL_REGISTER_BIT_RANGE = f'[{REGISTER_WIDTH_BITS - 1}:0]' + +# Project paths +SCRIPT_DIR = Path(__file__).parent +PROJECT_ROOT = SCRIPT_DIR.parent +HELPER_PATH = SCRIPT_DIR / 'lib' / 'mfma-verification-helper.mjs' + +def resolve_calculator_path() -> Path: + """ + Locate the AMD matrix calculator script. Prefer a sibling checkout (legacy layout) + but fall back to a submodule inside the project root for CI and developer workflows. 
+ """ + candidate_dirs = [ + PROJECT_ROOT.parent / 'amd_matrix_instruction_calculator', + PROJECT_ROOT / 'amd_matrix_instruction_calculator', + ] + + for directory in candidate_dirs: + candidate = directory / 'matrix_calculator.py' + if candidate.exists(): + return candidate + + raise FileNotFoundError( + "Unable to locate amd_matrix_instruction_calculator. " + "Run 'git submodule update --init --recursive' from the project root." + ) + +CALCULATOR_PATH = resolve_calculator_path() + +BIT_RANGE_PATTERN = re.compile(r'^\[?(?P<first>\d+)(?::(?P<second>\d+))?\]?$', re.IGNORECASE) + + +def normalize_bit_range(bit_range: str) -> str: + """ + Normalize AMD bit range strings into a canonical format. + + Examples: + "15:0" -> "[15:0]" + "[0:15]" -> "[15:0]" + "[7]" -> "[7:7]" + "" -> "" + """ + cleaned = (bit_range or '').strip() + if not cleaned: + return '' + + cleaned = cleaned.replace(' ', '') + match = BIT_RANGE_PATTERN.match(cleaned) + if not match: + # Leave unknown formats untouched but ensure brackets for consistency + return cleaned if cleaned.startswith('[') else f'[{cleaned}]' + + first = int(match.group('first')) + second_group = match.group('second') + if second_group is None: + high = first + low = first + else: + second = int(second_group) + high = max(first, second) + low = min(first, second) + + return f'[{high}:{low}]' + + +def extract_bit_bounds(bit_range: str) -> Tuple[int, int]: + """ + Extract (high, low) bounds from a normalized bit range. + Returns (0, 0) for empty ranges. + """ + cleaned = (bit_range or '').strip() + if not cleaned: + return 0, 0 + + if cleaned.startswith('[') and cleaned.endswith(']'): + cleaned = cleaned[1:-1] + + if ':' not in cleaned: + value = int(cleaned) + return value, value + + high_str, low_str = cleaned.split(':', 1) + return int(high_str), int(low_str) + + +def bit_range_sort_key(bit_range: str) -> Tuple[int, int]: + """ + Sorting key that orders entries by their starting bit (low -> high) within a register. 
+ """ + high, low = extract_bit_bounds(bit_range) + return low, high + + +def physical_key(entry: Dict[str, Any]) -> Tuple[int, int, str]: + """ + Create a comparable key for physical storage tuples. + """ + return entry['lane'], entry['register'], entry['bitRange'] + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Verify MFMA operand layouts against AMD's matrix calculator." + ) + parser.add_argument( + '--arch', + choices=['cdna1', 'cdna2', 'cdna3', 'all'], + default='all', + help="Limit verification to a specific architecture." + ) + parser.add_argument( + '--instruction', + action='append', + default=None, + help="Only run specific instruction names (can be repeated or comma-separated)." + ) + parser.add_argument( + '--operand', + choices=['A', 'B', 'D', 'all'], + default='all', + help="Limit verification to a single operand." + ) + parser.add_argument( + '--verbose', + action='store_true', + help="Print detailed mismatch information for failing cases." + ) + parser.add_argument( + '--max-diffs', + type=int, + default=10, + help="Maximum number of mismatches to record and display per failing case." 
+ ) + return parser.parse_args() + + +def ensure_helper_built(prefix: str = '') -> bool: + indent = prefix or '' + + if indent: + print(f"{indent}Building verification helper...") + else: + print("Building verification helper...") + + build_result = subprocess.run( + [str(SCRIPT_DIR / 'build-helper.sh')], + capture_output=True, + text=True, + cwd=str(PROJECT_ROOT) + ) + + if build_result.returncode != 0: + error_output = build_result.stderr.strip() or build_result.stdout.strip() + if indent: + print(f"{indent}Build error: {error_output}") + else: + print(f"Build error: {error_output}") + return False + + return True + + +def load_instruction_metadata() -> Dict[str, List[Dict]]: + if not ensure_helper_built(): + return {} + + result = subprocess.run( + ['node', str(HELPER_PATH), json.dumps({'action': 'listInstructions'})], + capture_output=True, + text=True, + cwd=str(PROJECT_ROOT) + ) + + if result.returncode != 0: + error_output = result.stderr.strip() or result.stdout.strip() + print(f"Failed to load instruction metadata: {error_output}") + return {} + + try: + data = json.loads(result.stdout.strip() or '{}') + except json.JSONDecodeError as exc: + print(f"Failed to parse instruction metadata JSON: {exc}") + return {} + + cdna1 = data.get('cdna1') + cdna2 = data.get('cdna2') + cdna3 = data.get('cdna3') + + if not all(isinstance(arr, list) for arr in (cdna1, cdna2, cdna3)): + print("Instruction metadata missing expected CDNA1/CDNA2/CDNA3 arrays.") + return {} + + return {'cdna1': cdna1, 'cdna2': cdna2, 'cdna3': cdna3} + + +def parse_lane_entries(cell_text: str) -> List[Dict[str, Any]]: + """ + Parse all (register, lane, bit range) entries from a calculator cell. + + Handles formats like: + - "v0{0}" -> [{'register': 0, 'lane': 0, 'bit_range': ''}] + - "v0{0}.[15:0]" -> [{'register': 0, 'lane': 0, 'bit_range': '[15:0]'}] + - "v0{0}.[15:0]<br>v0{16}.[15:0]" -> [{'register': 0, 'lane': 0, 'bit_range': '[15:0]'}, {'register': 0, 'lane': 16, 'bit_range': '[15:0]'}] + + Returns: List of entry dictionaries with register, lane, and bit_range. + """ + entries: List[Dict[str, Any]] = [] + + # Split on <br> to handle multi-lane cells + parts = cell_text.split('<br>') + + for part in parts: + part = part.strip() + if not part: + continue + + # Match register ranges like v[hi:lo]{lane} + range_match = re.match(r'v\[(\d+):(\d+)\]\{(\d+)\}(?:\.\[([^\]]+)\])?', part) + if range_match: + reg_hi = int(range_match.group(1)) + reg_lo = int(range_match.group(2)) + lane = int(range_match.group(3)) + raw_range = range_match.group(4) + if raw_range: + normalized_bit_range = raw_range if raw_range.startswith('[') else f'[{raw_range}]' + else: + normalized_bit_range = '' + start = min(reg_lo, reg_hi) + end = max(reg_lo, reg_hi) + for register in range(start, end + 1): + entries.append({ + 'register': register, + 'lane': lane, + 'bit_range': normalized_bit_range, + }) + continue + + # Match: v{reg}{lane} with optional .[bits] + match = re.match(r'v(\d+)\{(\d+)\}(?:\.\[([^\]]+)\])?', part) + if match: + register = int(match.group(1)) + lane = int(match.group(2)) + raw_range = match.group(3) + if raw_range: + normalized_bit_range = raw_range if raw_range.startswith('[') else f'[{raw_range}]' + else: + normalized_bit_range = '' + entries.append({ + 'register': register, + 'lane': lane, + 'bit_range': normalized_bit_range, + }) + + return entries + + +def run_amd_calculator( + arch: str, + instruction: Dict, + operand: str +) -> List[Dict[str, Any]]: + """ + Run the AMD matrix calculator and parse lane/register/bitRange tuples. 
+ + Returns: + List of mappings with fields: + lane, register, bitRange, row, col + """ + instruction_name = instruction.get('name') + if not instruction_name: + return [] + calc_instruction = instruction_name.lower() + if not calc_instruction.startswith('v_'): + calc_instruction = f"v_{calc_instruction}" + + cmd = [ + 'python3', + str(CALCULATOR_PATH), + '-a', arch.upper(), + '-i', calc_instruction, + f'-{operand}', + '-R', + '--markdown', + '-w', '64' + ] + + result = subprocess.run(cmd, capture_output=True, text=True, cwd=str(PROJECT_ROOT)) + if result.returncode != 0: + print(f" Calculator error: {result.stderr}") + return [] + + lines = result.stdout.strip().split('\n') + entries: List[Dict[str, Any]] = [] + seen_keys: Dict[Tuple[int, int, str], Dict[str, Any]] = {} + + # Find the table header + table_start = -1 + table_header_pattern = re.compile(r'\[[MNK]\]\[[MNK]\]') + for i, line in enumerate(lines): + if '|' in line and table_header_pattern.search(line): + table_start = i + break + + if table_start == -1: + return [] + + # Parse data rows (skip header and separator line) + data_lines = lines[table_start + 2:] + + try: + for line in data_lines: + if '|' not in line or line.strip().startswith('|--'): + continue + + parts = [p.strip() for p in line.split('|')] + if len(parts) < 3: + continue + + row_str = parts[1].strip() + if not row_str or not row_str.lstrip('-').isdigit(): + # Skip header rows such as "A[M][K]" that appear when the calculator prints + # multiple blocks in a single table. 
+ continue + row = int(row_str) + + for col_idx, cell in enumerate(parts[2:-1]): + cell = cell.strip() + if not cell or 'v' not in cell: + continue + + reg_lane_pairs = parse_lane_entries(cell) + if not reg_lane_pairs: + raise ValueError( + f"Unrecognized calculator cell '{cell}' at row {row}, column {col_idx}" + ) + + for parsed in reg_lane_pairs: + register = parsed.get('register') + lane = parsed.get('lane') + bit_range_raw = parsed.get('bit_range', '') + + if register is None or lane is None: + raise ValueError( + f"Calculator cell missing register/lane info at row {row}, column {col_idx}: '{cell}'" + ) + + if not bit_range_raw: + # Older calculator builds omit bit ranges for 32-bit operands. + normalized_bit_range = FULL_REGISTER_BIT_RANGE + else: + normalized_bit_range = normalize_bit_range(bit_range_raw) + + entry = { + 'lane': lane, + 'register': register, + 'bitRange': normalized_bit_range, + 'row': row, + 'col': col_idx + } + key = physical_key(entry) + if key in seen_keys: + existing = seen_keys[key] + raise ValueError( + "Duplicate physical storage tuple detected for " + f"lane {lane}, register v{register}, bitRange {normalized_bit_range}. " + f"Existing position {existing['row'], existing['col']}, " + f"new position {(row, col_idx)}" + ) + + seen_keys[key] = entry + entries.append(entry) + + except ValueError as exc: + print(f" Calculator parse error: {exc}") + return [] + + return entries + + +def get_our_layout(instruction: Dict, operand: str) -> List[Dict[str, Any]]: + """ + Get lane mappings from our TypeScript implementation. + + This uses the pre-built mfma verification helper module, which reuses + the layout logic from MFMALayout.ts without duplication. 
+ + Returns: + List of mappings with fields: + lane, register, bitRange, row, col, elementIndex + """ + # Build helper if it doesn't exist + if not ensure_helper_built(prefix=' '): + return [] + + # Prepare config for helper + config = { + 'instruction': instruction, + 'operand': operand + } + + # Run the helper module + result = subprocess.run( + ['node', str(HELPER_PATH), json.dumps(config)], + capture_output=True, + text=True, + cwd=str(PROJECT_ROOT) + ) + + if result.returncode != 0: + print(f" Helper error: {result.stderr}") + return [] + + # Parse the JSON output + try: + data = json.loads(result.stdout.strip()) + except json.JSONDecodeError as exc: + print(f" JSON parse error: {exc}") + print(f" Output: {result.stdout[:200]}") + return [] + + entries: List[Dict[str, Any]] = [] + seen_keys: Dict[Tuple[int, int, str], Dict[str, Any]] = {} + + try: + for key, lane_entries in data.items(): + row_str, col_str = key.split('_') + row = int(row_str) + col = int(col_str) + + for lane_entry in lane_entries: + lane = lane_entry.get('lane') + register = lane_entry.get('physicalRegister') + bit_range_raw = lane_entry.get('bitRange', '') + element_index = lane_entry.get('elementIndex') + + if lane is None or register is None: + raise ValueError( + f"Helper entry missing lane/register (row {row}, col {col}): {lane_entry}" + ) + + normalized_bit_range = normalize_bit_range(bit_range_raw) + + entry = { + 'lane': int(lane), + 'register': int(register), + 'bitRange': normalized_bit_range, + 'row': row, + 'col': col, + 'elementIndex': int(element_index) if element_index is not None else None + } + + key_tuple = physical_key(entry) + if key_tuple in seen_keys: + existing = seen_keys[key_tuple] + raise ValueError( + "Implementation produced duplicate physical storage tuple for " + f"lane {lane}, register v{register}, bitRange {normalized_bit_range}. 
" + f"Existing position {existing['row'], existing['col']}, " + f"new position {(row, col)}" + ) + + seen_keys[key_tuple] = entry + entries.append(entry) + + except (ValueError, TypeError) as exc: + print(f" Helper parse error: {exc}") + return [] + + return entries + + +def compare_layouts( + calc_entries: List[Dict[str, Any]], + our_entries: List[Dict[str, Any]] +) -> Tuple[bool, List[Dict[str, Any]]]: + """ + Compare two layouts and return (is_match, mismatches). + + Each entry is compared using its physical tuple (lane, register, bitRange). + """ + calc_map = {physical_key(entry): entry for entry in calc_entries} + our_map = {physical_key(entry): entry for entry in our_entries} + + mismatches: List[Dict[str, Any]] = [] + + all_keys = set(calc_map.keys()) | set(our_map.keys()) + for key in sorted(all_keys): + calc_entry = calc_map.get(key) + our_entry = our_map.get(key) + + lane, register, bit_range = key + + if calc_entry and not our_entry: + mismatches.append({ + 'status': 'missing', + 'lane': lane, + 'register': register, + 'bitRange': bit_range, + 'calculator': [calc_entry['row'], calc_entry['col']] + }) + continue + + if our_entry and not calc_entry: + mismatches.append({ + 'status': 'extra', + 'lane': lane, + 'register': register, + 'bitRange': bit_range, + 'ours': [our_entry['row'], our_entry['col']] + }) + continue + + if not calc_entry or not our_entry: + # Should not reach here because of previous checks + continue + + calc_position = (calc_entry['row'], calc_entry['col']) + our_position = (our_entry['row'], our_entry['col']) + + if calc_position != our_position: + mismatches.append({ + 'status': 'mismatch', + 'lane': lane, + 'register': register, + 'bitRange': bit_range, + 'calculator': list(calc_position), + 'ours': list(our_position) + }) + + return len(mismatches) == 0, mismatches + + +def build_layout_index(entries: List[Dict[str, Any]]) -> Dict[str, List[Dict[str, Any]]]: + """ + Group verified entries by matrix position for downstream tooling 
(e.g., tooltips). + """ + grouped: Dict[str, List[Dict[str, Any]]] = {} + for entry in entries: + key = f"{entry['row']}_{entry['col']}" + grouped.setdefault(key, []).append({ + 'lane': entry['lane'], + 'register': entry['register'], + 'bitRange': entry['bitRange'], + 'elementIndex': entry.get('elementIndex') + }) + + for value in grouped.values(): + value.sort(key=lambda item: (item['lane'], item['register'], bit_range_sort_key(item['bitRange']))) + + return grouped + + +def perform_logical_sanity_checks(entries: List[Dict[str, Any]]) -> Dict[str, Any]: + """ + Perform lightweight logical index checks to ensure our layout enumerations remain sane. + """ + issues: List[str] = [] + per_lane: Dict[int, List[Dict[str, Any]]] = {} + + for entry in entries: + element_index = entry.get('elementIndex') + if element_index is None: + continue + per_lane.setdefault(entry['lane'], []).append(entry) + + for lane, lane_entries in per_lane.items(): + logical_indices = sorted(entry['elementIndex'] for entry in lane_entries if entry.get('elementIndex') is not None) + expected = list(range(len(logical_indices))) + if logical_indices != expected: + issues.append( + f"Lane {lane}: expected logical indices {expected} but observed {logical_indices}" + ) + continue # Skip detailed checks if the basic range already failed + + sorted_by_physical = sorted( + lane_entries, + key=lambda item: (item['register'], bit_range_sort_key(item['bitRange'])) + ) + for idx, entry in enumerate(sorted_by_physical): + logical = entry.get('elementIndex') + if logical != idx: + issues.append( + f"Lane {lane}: logical index {logical} does not match physical ordering " + f"(expected {idx}) for v{entry['register']} {entry['bitRange']}" + ) + break + + status = 'PASS' if not issues else 'WARN' + return { + 'status': status, + 'issueCount': len(issues), + 'issues': issues[:10] # limit noise + } + + +def verify_single_case( + arch: str, + instr: Dict, + operand: str, + case_id: int, + total: int, + verbose: bool, 
+ max_diffs: int +) -> Dict: + """Verify a single test case""" + print(f"[{case_id:03d}/{total:03d}] {arch.upper():5s} {instr['name']:30s} {operand} ... ", end='', flush=True) + + # Get ground truth from AMD calculator + calc_entries = run_amd_calculator(arch, instr, operand) + if not calc_entries: + print("FAIL (calculator error)") + return { + 'id': case_id, + 'arch': arch, + 'instruction': instr['name'], + 'operand': operand, + 'status': 'ERROR', + 'error': 'Failed to run AMD calculator or parse output' + } + + # Get our layout using the verification helper (reuses MFMALayout.ts) + our_entries = get_our_layout(instr, operand) + if not our_entries: + print("FAIL (implementation error)") + return { + 'id': case_id, + 'arch': arch, + 'instruction': instr['name'], + 'operand': operand, + 'status': 'ERROR', + 'error': 'Failed to generate our layout' + } + + # Compare layouts + is_match, mismatches = compare_layouts(calc_entries, our_entries) + logical_report = perform_logical_sanity_checks(our_entries) + + if is_match: + total_mappings = len(calc_entries) + unique_positions = len({(entry['row'], entry['col']) for entry in calc_entries}) + suffix = '' + if logical_report['status'] != 'PASS': + suffix = f", logical {logical_report['status'].lower()}" + print(f"PASS ({total_mappings} mappings verified{suffix})") + return { + 'id': case_id, + 'arch': arch, + 'instruction': instr['name'], + 'operand': operand, + 'status': 'PASS', + 'total_positions': unique_positions, + 'total_mappings': total_mappings, + 'physicalLayout': build_layout_index(our_entries), + 'logicalSanity': logical_report + } + else: + mismatch_count = len(mismatches) + total_missing = sum(1 for m in mismatches if m['status'] == 'missing') + total_extra = sum(1 for m in mismatches if m['status'] == 'extra') + print(f"FAIL ({mismatch_count} mapping mismatches)") + + total_relocated = mismatch_count - total_missing - total_extra + + if verbose: + limit = min(max_diffs, mismatch_count) + for idx, mismatch in 
enumerate(mismatches[:limit], 1): + lane = mismatch['lane'] + register = mismatch['register'] + bit_range = mismatch['bitRange'] or '[full]' + calc_pos = mismatch.get('calculator') + our_pos = mismatch.get('ours') + status = mismatch['status'] + calc_display = f"({calc_pos[0]}, {calc_pos[1]})" if calc_pos else '∅' + our_display = f"({our_pos[0]}, {our_pos[1]})" if our_pos else '∅' + print(f" [{idx:02d}] lane {lane:2d}, v{register:02d} {bit_range}") + print(f" calculator: {calc_display}") + print(f" ours: {our_display}") + if status == 'missing': + print(" issue: missing in our layout") + elif status == 'extra': + print(" issue: extra in our layout") + else: + print(" issue: mismatched matrix position") + if mismatch_count > limit: + remaining = mismatch_count - limit + print(f" ... {remaining} additional mismatches omitted (increase --max-diffs to show more)") + print(f" totals -> missing: {total_missing}, extra: {total_extra}, mismatched: {total_relocated}") + + return { + 'id': case_id, + 'arch': arch, + 'instruction': instr['name'], + 'operand': operand, + 'status': 'FAIL', + 'total_positions': len({(entry['row'], entry['col']) for entry in calc_entries}), + 'mismatch_count': mismatch_count, + 'missing_entries': total_missing, + 'extra_entries': total_extra, + 'mismatched_entries': total_relocated, + 'mismatches': mismatches[:max_diffs], + 'logicalSanity': logical_report + } + + +def main(): + args = parse_args() + + print("=" * 100) + print("MFMA LAYOUT VERIFICATION") + print("=" * 100) + print("Ground truth: AMD Matrix Instruction Calculator") + print("Implementation: MFMALayout.ts (via mfma-verification-helper.mjs)") + print("=" * 100) + print() + + instruction_filter = None + if args.instruction: + filtered: List[str] = [] + for value in args.instruction: + if not value: + continue + filtered.extend([item.strip().lower() for item in value.split(',') if item.strip()]) + instruction_filter = set(filtered) + + def include_case(arch_name: str, instruction: Dict, 
operand_name: str) -> bool: + if args.arch != 'all' and arch_name != args.arch: + return False + if args.operand != 'all' and operand_name != args.operand: + return False + if instruction_filter and instruction['name'].lower() not in instruction_filter: + return False + return True + + metadata = load_instruction_metadata() + if not metadata: + print("Unable to load MFMA instruction metadata from verification helper.") + return 1 + + arch_order = ['cdna1', 'cdna2', 'cdna3'] + test_cases = [] + for arch_name in arch_order: + instructions = metadata.get(arch_name, []) + if not isinstance(instructions, list): + continue + for instr in instructions: + for operand in OPERANDS: + if include_case(arch_name, instr, operand): + test_cases.append((arch_name, instr, operand)) + + if not test_cases: + print("No test cases matched the provided filters.") + return 1 + + max_diffs = max(1, args.max_diffs) + + # Run verification + results = [] + for i, (arch, instr, operand) in enumerate(test_cases, 1): + result = verify_single_case( + arch, + instr, + operand, + i, + len(test_cases), + verbose=args.verbose, + max_diffs=max_diffs + ) + results.append(result) + + # Save results + results_dir = SCRIPT_DIR / 'results' + results_dir.mkdir(exist_ok=True) + results_file = results_dir / 'mfma_verification_results.json' + with open(results_file, 'w') as f: + json.dump(results, f, indent=2) + + # Print summary + print() + print("=" * 100) + print("SUMMARY") + print("=" * 100) + passed = sum(1 for r in results if r['status'] == 'PASS') + failed = sum(1 for r in results if r['status'] == 'FAIL') + errors = sum(1 for r in results if r['status'] == 'ERROR') + + print(f"Total: {len(results)}") + print(f"PASS: {passed:3d} ({100 * passed // len(results):2d}%)") + print(f"FAIL: {failed:3d} ({100 * failed // len(results):2d}%)") + print(f"ERROR: {errors:3d} ({100 * errors // len(results):2d}%)") + + if failed > 0 or errors > 0: + print("\nFailed/Error cases:") + for r in results: + if r['status'] 
in ['FAIL', 'ERROR']: + status_str = r['status'] + error_detail = r.get('error', f"{r.get('mismatch_count', 0)} mismatches") + print(f" [{status_str:5s}] {r['arch'].upper():5s} {r['instruction']:30s} {r['operand']} - {error_detail}") + + print(f"\nDetailed results: {results_file}") + + return 0 if failed == 0 and errors == 0 else 1 + + +if __name__ == '__main__': + sys.exit(main())