Add gfx1201 (RDNA4) support with WMMA GEMM kernel#172
Draft
vivienfanghuagood wants to merge 5 commits intoROCm:mainfrom
Draft
Add gfx1201 (RDNA4) support with WMMA GEMM kernel#172vivienfanghuagood wants to merge 5 commits intoROCm:mainfrom
vivienfanghuagood wants to merge 5 commits intoROCm:mainfrom
Conversation
Port FlyDSL to Radeon 9700 (gfx1201, RDNA4, wave32):
Infrastructure changes:
- Add GFX1201 to GfxArchEnum in FlirRocmDialect.td
- Add rocminfo fallback in device.py for arch auto-detection
- Add gfx1201 LDS capacity (64KB) to smem_allocator.py
- Make wave64 flag conditional in compiler.py (false for RDNA wave32)
- Add gfx12 to FP8 E4M3FN type selection in types.py
Wave32 kernel adaptations:
- Add get_warp_size() helper in kernels_common.py (32 for RDNA, 64 for CDNA)
- Replace hardcoded WARP_SIZE=64 with dynamic detection in softmax,
layernorm, rmsnorm kernels
- Add wave32 shuffle offsets [16,8,4,2,1] in reduce.py
WMMA support (new):
- Add 11 WMMA convenience wrappers to rocdl.py matching the MFMA pattern
(f16, bf16, fp8 variants, int8/int4, with _op() and value-return forms)
- New kernels/wmma_gemm.py: tiled WMMA GEMM kernel using 2x2 WMMA tiles
per workgroup (32x32 block, 128 threads, 4 waves) with LDS staging
- Empirically verified WMMA wave32 data layout on gfx1201:
A: row-of-cols (lane t reads A[t%16][(t/16)*8+i])
B: col-of-rows (lane t reads B[(t/16)*8+i][t%16])
D: col-of-rows (same as B)
Test results on Radeon 9700 (gfx1201):
- All existing kernels pass (softmax, layernorm, rmsnorm, vec_add, etc.)
- WMMA basic test: max relative error 1.15e-07
- WMMA GEMM: all shapes 32x32x16 through 512x512x512 pass
- WMMA GEMM benchmark: 1.75 TFLOPS at 1024x1024x1024 (unoptimized)
Replace scalar B LDS reads (ds_load_u16 + ds_load_u16_d16_hi) and individual A vector loads with a single combined inline asm block per K-step. This eliminates the s_wait_dscnt 0xf serialization that LLVM inserts between every d16_hi pair (WAW hazard workaround). Combined asm block per K-step: 16x ds_load_u16 (B low halves) 4x ds_load_b128 (A tiles) s_wait_dscnt 0x4 (B done, A in flight) 16x ds_load_u16_d16_hi (B high halves) s_wait_dscnt 0x0 (all done) Also adds test utilities: inline asm load tests, Triton comparison benchmark, and profiling helper. Performance: ~57 TFLOPS at 4096x4096 bf16 on gfx1201 (Radeon 9700 XT).
The buffer resource descriptor DWORD3 flags were hardcoded to GFX9 (CDNA) format: data_format=7 at bits[14:12], num_format=4 at bits[17:15]. On RDNA (GFX10/11/12), the DWORD3 layout is different: - OOB_SELECT at bits[29:28] must be >= 1 (0 = structured mode, fails) - FORMAT at bits[21:14] This caused buffer_load to return all zeros on gfx1201 (RDNA4). Fix: detect GPU arch and use appropriate flags: - RDNA: OOB_SELECT=2 | FORMAT bit 13 = 0x20002000 - CDNA: existing 0x27000 (unchanged) Tested: scalar f32 and vector 4xf32 buffer loads now work on gfx1201.
Eliminate ~0.8ms per-call overhead from repeated hipModuleLoadData calls by caching loaded modules and functions using FNV-1a content hashing. - Cache hipModule_t by (content_hash, blob_size) key to handle JIT address reuse safely - Cache hipFunction_t per module+name pair - Cache hipStream_t (single stream, created once) - Make mgpuModuleUnload/mgpuStreamDestroy no-ops for cached resources - Add mgpuSetStream API for external stream injection (e.g. PyTorch) - Add .gitignore entries for rocprofv3 profiling artifacts
- arith.py: Add BF16Type import and support throughout arithmetic operations, improve _is_floating_point_type() to include BF16, code formatting cleanup - compiler.py: Read FLYDSL_LLC_OPTS environment variable to allow passing custom LLVM LLC options through the compilation pipeline
Collaborator
|
@vivienfanghuagood we will add a CI device for gfx1201 next week. We can have this enabled after that. Besides, we have pr/v0.1 branch and will get merged in days. it's a big reconstruction of the API and internal IR. Codes need some changes. Could you plz switch to that branch as base? |
Author
Thanks Felix, I will update my PR after your branch updated! Before this, I mark this PR as draft. |
c257731 to
d8842cc
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Port FlyDSL to Radeon 9700 (gfx1201, RDNA4, wave32):
Infrastructure changes:
Wave32 kernel adaptations:
WMMA support (new):
Test results on Radeon 9700 (gfx1201):
Motivation
Technical Details
Test Plan
Test Result
Submission Checklist