
Add gfx1201 (RDNA4) support with WMMA GEMM kernel#172

Draft
vivienfanghuagood wants to merge 5 commits into ROCm:main from vivienfanghuagood:gfx1201-wmma-support

Conversation


@vivienfanghuagood commented Mar 4, 2026

Port FlyDSL to Radeon 9700 (gfx1201, RDNA4, wave32):

Infrastructure changes:

  • Add GFX1201 to GfxArchEnum in FlirRocmDialect.td
  • Add rocminfo fallback in device.py for arch auto-detection
  • Add gfx1201 LDS capacity (64KB) to smem_allocator.py
  • Make wave64 flag conditional in compiler.py (false for RDNA wave32)
  • Add gfx12 to FP8 E4M3FN type selection in types.py
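
The arch-conditional settings above can be sketched on the host side. This is an illustrative sketch only: the function and table names below are hypothetical stand-ins for the real logic in smem_allocator.py and compiler.py.

```python
# Hypothetical sketch of the arch-conditional settings described above.
# Names are illustrative; the real logic lives in smem_allocator.py / compiler.py.

# LDS (shared memory) capacity per workgroup, in bytes.
LDS_CAPACITY = {
    "gfx942": 65536,   # CDNA3 entry, assumed here for illustration
    "gfx1201": 65536,  # RDNA4: 64KB, as added in smem_allocator.py
}

def is_rdna(arch: str) -> bool:
    # gfx10/11/12 are RDNA generations; gfx9xx is CDNA/GCN.
    return arch.startswith(("gfx10", "gfx11", "gfx12"))

def wave64_flag(arch: str) -> bool:
    # RDNA runs wave32, so the wave64 backend flag is disabled there.
    return not is_rdna(arch)
```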

Wave32 kernel adaptations:

  • Add get_warp_size() helper in kernels_common.py (32 for RDNA, 64 for CDNA)
  • Replace hardcoded WARP_SIZE=64 with dynamic detection in softmax, layernorm, rmsnorm kernels
  • Add wave32 shuffle offsets [16,8,4,2,1] in reduce.py
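
A minimal sketch of the wave-size detection and the shuffle ladder it implies. The name mirrors get_warp_size() in kernels_common.py, but the bodies here are illustrative, not the actual implementation:

```python
def get_warp_size(arch: str) -> int:
    # 32 lanes per wave on RDNA (gfx10/11/12), 64 on CDNA (gfx9xx).
    return 32 if arch.startswith(("gfx10", "gfx11", "gfx12")) else 64

def shuffle_offsets(warp_size: int) -> list[int]:
    # Butterfly-reduction offsets: halve the stride until it reaches 1.
    offsets = []
    step = warp_size // 2
    while step:
        offsets.append(step)
        step //= 2
    return offsets
```

For wave32 this yields [16, 8, 4, 2, 1], matching the offsets added to reduce.py; for wave64 it yields the existing [32, 16, 8, 4, 2, 1] ladder.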

WMMA support (new):

  • Add 11 WMMA convenience wrappers to rocdl.py matching the MFMA pattern (f16, bf16, fp8 variants, int8/int4, with _op() and value-return forms)
  • New kernels/wmma_gemm.py: tiled WMMA GEMM kernel using 2x2 WMMA tiles per workgroup (32x32 block, 128 threads, 4 waves) with LDS staging
  • Empirically verified WMMA wave32 data layout on gfx1201:
      A: row-of-cols (lane t reads A[t%16][(t/16)*8+i])
      B: col-of-rows (lane t reads B[(t/16)*8+i][t%16])
      D: col-of-rows (same as B)
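
The lane mapping can be sanity-checked on the host: with 32 lanes holding 8 elements each, the mapping must cover every entry of a 16x16 fragment exactly once. A self-contained check (pure Python, no GPU; helper names are invented for this sketch):

```python
def a_index(lane: int, i: int) -> tuple[int, int]:
    # A: row-of-cols -- lane t, element i reads A[t % 16][(t // 16) * 8 + i]
    return lane % 16, (lane // 16) * 8 + i

def b_index(lane: int, i: int) -> tuple[int, int]:
    # B (and D): col-of-rows -- lane t, element i reads B[(t // 16) * 8 + i][t % 16]
    return (lane // 16) * 8 + i, lane % 16

def covers_16x16(index_fn) -> bool:
    # 32 lanes x 8 elements = 256 slots must cover all 256 entries of a 16x16 tile.
    seen = {index_fn(lane, i) for lane in range(32) for i in range(8)}
    return seen == {(r, c) for r in range(16) for c in range(16)}
```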

Test results on Radeon 9700 (gfx1201):

  • All existing kernels pass (softmax, layernorm, rmsnorm, vec_add, etc.)
  • WMMA basic test: max relative error 1.15e-07
  • WMMA GEMM: all shapes 32x32x16 through 512x512x512 pass
  • WMMA GEMM benchmark: 1.75 TFLOPS at 1024x1024x1024 (unoptimized)


Replace scalar B LDS reads (ds_load_u16 + ds_load_u16_d16_hi) and
individual A vector loads with a single combined inline asm block per
K-step. This eliminates the s_wait_dscnt 0xf serialization that LLVM
inserts between every d16_hi pair (WAW hazard workaround).

Combined asm block per K-step:
  16x ds_load_u16 (B low halves)
  4x ds_load_b128 (A tiles)
  s_wait_dscnt 0x4 (B done, A in flight)
  16x ds_load_u16_d16_hi (B high halves)
  s_wait_dscnt 0x0 (all done)
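
As a host-side illustration of the schedule above, the mnemonic sequence and its wait counters can be generated programmatically. This sketches the instruction ordering only, not the real operand-carrying inline asm:

```python
def k_step_schedule(num_b: int = 16, num_a: int = 4) -> list[str]:
    # Issue all B low-half loads and A vector loads first, wait until only
    # the A loads remain outstanding, then issue the B high-half loads and
    # drain everything. This avoids a full s_wait_dscnt 0xf between every
    # d16_hi pair (the WAW-hazard workaround LLVM would otherwise emit).
    seq = ["ds_load_u16"] * num_b            # B low halves
    seq += ["ds_load_b128"] * num_a          # A tiles
    seq.append(f"s_wait_dscnt 0x{num_a:x}")  # B done, A still in flight
    seq += ["ds_load_u16_d16_hi"] * num_b    # B high halves
    seq.append("s_wait_dscnt 0x0")           # all LDS loads complete
    return seq
```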

Also adds test utilities: inline asm load tests, Triton comparison
benchmark, and profiling helper.

Performance: ~57 TFLOPS at 4096x4096 bf16 on gfx1201 (Radeon 9700 XT).

The buffer resource descriptor DWORD3 flags were hardcoded to the GFX9 (CDNA)
format: data_format=7 at bits[14:12], num_format=4 at bits[17:15].

On RDNA (GFX10/11/12), the DWORD3 layout is different:
- OOB_SELECT at bits[29:28] must be >= 1 (0 = structured mode, fails)
- FORMAT at bits[21:14]

This caused buffer_load to return all zeros on gfx1201 (RDNA4).

Fix: detect GPU arch and use appropriate flags:
- RDNA: OOB_SELECT=2 | FORMAT bit 13 = 0x20002000
- CDNA: existing 0x27000 (unchanged)

Tested: scalar f32 and vector 4xf32 buffer loads now work on gfx1201.
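
The flag arithmetic above can be written out explicitly. A sketch, with bit positions taken from this description (the function name is invented for illustration):

```python
def buffer_dword3_flags(arch: str) -> int:
    # DWORD3 of the buffer resource descriptor differs between RDNA and CDNA.
    if arch.startswith(("gfx10", "gfx11", "gfx12")):
        # RDNA: OOB_SELECT=2 at bits[29:28], FORMAT bit 13 set.
        return (2 << 28) | (1 << 13)   # 0x20002000
    # CDNA/GFX9: data_format=7 at bits[14:12], num_format=4 at bits[17:15].
    return (7 << 12) | (4 << 15)       # 0x27000
```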

Eliminate ~0.8 ms of per-call overhead from repeated hipModuleLoadData calls
by caching loaded modules and functions using FNV-1a content hashing.

- Cache hipModule_t by (content_hash, blob_size) key to handle JIT
  address reuse safely
- Cache hipFunction_t per module+name pair
- Cache hipStream_t (single stream, created once)
- Make mgpuModuleUnload/mgpuStreamDestroy no-ops for cached resources
- Add mgpuSetStream API for external stream injection (e.g. PyTorch)
- Add .gitignore entries for rocprofv3 profiling artifacts

- arith.py: Add BF16Type import and support throughout arithmetic
  operations, improve _is_floating_point_type() to include BF16,
  code formatting cleanup
- compiler.py: Read FLYDSL_LLC_OPTS environment variable to allow
  passing custom LLVM LLC options through the compilation pipeline
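
A sketch of how such an environment override can be read. shlex.split is used here so quoted options survive; whether compiler.py parses the variable exactly this way is an assumption.

```python
import os
import shlex

def extra_llc_opts(env=None):
    # Extra flags appended to the llc invocation, e.g.
    # FLYDSL_LLC_OPTS='-O3 -mattr=+wavefrontsize32'
    env = os.environ if env is None else env
    return shlex.split(env.get("FLYDSL_LLC_OPTS", ""))
```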
@coderfeli (Collaborator) commented

@vivienfanghuagood we will add a CI device for gfx1201 next week; we can enable this after that. Besides, we have a pr/v0.1 branch that will be merged within days. It is a big reconstruction of the API and internal IR, so the code needs some changes. Could you please switch to that branch as the base?

@vivienfanghuagood vivienfanghuagood marked this pull request as draft March 7, 2026 11:38
@vivienfanghuagood (Author) commented

@vivienfanghuagood we will add a CI device for gfx1201 next week; we can enable this after that. Besides, we have a pr/v0.1 branch that will be merged within days. It is a big reconstruction of the API and internal IR, so the code needs some changes. Could you please switch to that branch as the base?

Thanks Felix, I will update my PR after your branch is updated! Until then, I have marked this PR as a draft.

