Description
Motivation
AITER is adding a CK-free build mode (ENABLE_CK=0) that removes Composable Kernel dependencies, reducing build time from ~35min to ~8min. ATOM needs corresponding changes to:
- Route around CK-dependent code paths when CK kernels are unavailable
- Maintain performance by preferring ASM PA (paged attention) and Triton kernels where possible
- Support clean Docker builds from pre-built wheels (zero compilation)
- Fix bugs discovered during CK-free validation
All changes are backward compatible — when ATOM_CK_FREE=0 (default), ATOM uses CK/ASM kernels as before.
Validated Locally
- MI300X (gfx942): Llama-3.1-8B, DeepSeek-R1-671B inference
- MI355X (gfx950): MXFP4 MOE with Swiglu paths
- Clean Docker (zero-compile from wheels): Llama-3.1-8B decode 80-82% of public image perf
Proposed PR Sequence
PR 1: Bug Fixes (Standalone, No CK-Free Dependency)
Files: atom/model_ops/linear.py, atom/model_ops/fused_moe_triton.py, atom/model_engine/scheduler.py
Individual fixes:
- `linear.py`: Fix `UnboundLocalError` in `MergedReplicatedLinear.weight_loader()` — missing else clause for `per_Token` and `per_1x32` quant types. Compute `shard_offset` from `self.output_sizes`.
- `fused_moe_triton.py`: Use `CDNA4MXScaleLayout` instead of `GFX950MXScaleLayout` for gfx950 arch detection (more accurate naming).
- `fused_moe_triton.py`: Add `update_opt_flags_constraints({"block_m": 128})` for MI355X — the default CDNA4 `block_m=256` exceeds the 160 KB LDS limit.
- `scheduler.py`: Initialize `num_rejected = 0` to prevent `UnboundLocalError` in the non-speculative path.
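The scheduler fix is essentially a one-line guard. A minimal sketch — the function name and arguments here are hypothetical, only the `num_rejected` initialization mirrors the actual fix:

```python
# Hypothetical sketch of the scheduler.py fix: num_rejected must be
# defined before the speculative-decoding branch, otherwise the
# non-speculative path raises UnboundLocalError when it is read later.
def count_accepted(sampled_tokens, draft_tokens=None):
    num_rejected = 0  # the fix: always bound, even without speculation
    if draft_tokens is not None:
        # speculative path: count drafts the target model did not accept
        num_rejected = sum(
            1 for s, d in zip(sampled_tokens, draft_tokens) if s != d
        )
    return len(sampled_tokens) - num_rejected
```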
Dependency: None — these fix real bugs independent of CK-free work.
PR 2: MHA Attention Dispatch Decoupling
Files: atom/model_ops/attention_mha.py, atom/utils/envs.py
Key changes:
- Add `ATOM_CK_FREE` env var (default=0) as master switch
- Decouple cache update from paged attention backend selection:
  - Cache update: always use Triton fused rope+cache (fast, no module_cache JIT dependency)
  - Paged attention: independently select ASM PA for decode when `head_dim=128` and no sliding window
- `AITER_FORCE_TRITON_ATTN` env override for forcing Triton PA
- FP8 KV cache: fill per-token scale buffers with the uniform per-tensor scale so ASM PA can dequantize correctly
- Prefill: always use `prefill_attention_triton` (no CK / `flash_attn_varlen_func` dependency)
- Move `kv_scale` tensor to CUDA at init for graph-capture compatibility
- Create proper `block_tables` for prefill with `block_size=1`
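The decoupled decode dispatch above can be sketched as a small selection function. This is a sketch under stated assumptions, not ATOM's actual code — the function name and return labels are illustrative:

```python
import os

def select_decode_pa_backend(head_dim, sliding_window):
    """Hypothetical dispatch sketch: the KV-cache update always runs the
    Triton fused rope+cache kernel; only the decode paged-attention
    backend is chosen here, independently of the cache-update path."""
    if os.environ.get("AITER_FORCE_TRITON_ATTN", "0") == "1":
        return "triton_pa"  # explicit override
    # ASM PA stays usable even in CK-free mode, but only when its
    # constraints hold: head_dim == 128 and no sliding window.
    if head_dim == 128 and sliding_window is None:
        return "asm_pa"
    return "triton_pa"
```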
Dependency: AITER PR 1 (CK-free build gating)
PR 3: MLA CK-Free Paths
Files: atom/model_ops/attention_mla.py, atom/model_ops/attentions/aiter_mla.py
Key changes:
- `ATOM_USE_TRITON_MLA_DECODE` or `ATOM_CK_FREE` → force Triton MLA decode using `decode_attention_fwd_grouped_rope` (AITER Triton kernel)
- MLA prefill fallback: replace `flash_attn_varlen_func` with PyTorch `F.scaled_dot_product_attention` (loops over sequences)
- FP8 constraint: only use fp8 scales when `max_seqlen_q == 1` (`mla_decode_fwd` limitation)
- Type casting: convert Q/KV to model dtype before `mla_prefill_fwd` on dtype mismatch
- `aiter_mla.py`: build paged KV metadata (`kv_indptr`, `kv_indices`, `block_tables`) for MLA prefill paths
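The SDPA-based prefill fallback can be sketched as follows. This is a hedged sketch, not ATOM's actual signature: the packed `[total_tokens, heads, head_dim]` layout and the function name are assumptions:

```python
import torch
import torch.nn.functional as F

def sdpa_varlen_prefill(q, k, v, cu_seqlens):
    """Hypothetical MLA prefill fallback: replaces flash_attn_varlen_func
    with F.scaled_dot_product_attention by looping over each sequence in
    the packed batch. q/k/v: [total_tokens, heads, head_dim];
    cu_seqlens: [batch + 1] cumulative sequence lengths."""
    out = torch.empty_like(q)
    for i in range(len(cu_seqlens) - 1):
        s, e = cu_seqlens[i].item(), cu_seqlens[i + 1].item()
        # SDPA expects [batch, heads, seq, dim]; is_causal matches
        # the causal masking used during prefill.
        o = F.scaled_dot_product_attention(
            q[s:e].transpose(0, 1).unsqueeze(0),
            k[s:e].transpose(0, 1).unsqueeze(0),
            v[s:e].transpose(0, 1).unsqueeze(0),
            is_causal=True,
        )
        out[s:e] = o.squeeze(0).transpose(0, 1)
    return out
```

Looping per sequence trades throughput for simplicity, which is acceptable for a prefill fallback path.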
Dependency: PR 2
PR 4: MOE Cascade Routing (CK → FlyDSL → Triton)
Files: atom/model_ops/moe.py, atom/model_ops/flydsl_moe.py (NEW)
Key changes:
- Detection functions: `_has_ck_moe_sorting()`, `_has_flydsl_moe()` with caching
- Cascade logic for `Fp8MoEMethod`/`CompressedTensorsFp8MoEMethod`:
  - Check CK MOE sorting availability
  - If unavailable or `ATOM_CK_FREE=1`: try FlyDSL MOE (`ATOM_USE_FLYDSL_MOE=1`)
  - Else: fall back to Triton MOE (`_triton_fp8_moe()`)
- Weight shuffle skip: Triton expects standard row-major weights (not CK's shuffled layout)
- `flydsl_moe.py` (NEW): FlyDSL MOE backend with torch-native sorting, per-token FP8 quant, 5-stage GEMM pipeline
- `_triton_fp8_moe()`: complete Triton MOE pipeline (sort → GEMM1+SiLU → GEMM2)
- `_per_token_group_quant_fp8()`: per-token-group FP8 quantization helper
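The cascade above can be sketched as a selection function. Backend availability is passed in here for illustration; the real code probes it with the cached `_has_ck_moe_sorting()`/`_has_flydsl_moe()` helpers:

```python
import os

def select_moe_backend(ck_moe_available, flydsl_available):
    """Hypothetical sketch of the CK -> FlyDSL -> Triton MOE cascade."""
    ck_free = os.environ.get("ATOM_CK_FREE", "0") == "1"
    use_flydsl = os.environ.get("ATOM_USE_FLYDSL_MOE", "0") == "1"
    if ck_moe_available and not ck_free:
        return "ck_moe"      # CK MOE sorting + shuffled-layout weights, as before
    if flydsl_available and use_flydsl:
        return "flydsl_moe"  # torch-native sorting + FlyDSL GEMM pipeline
    return "triton_moe"      # _triton_fp8_moe(): sort -> GEMM1+SiLU -> GEMM2
```

Note that the Triton and FlyDSL paths also require skipping the CK weight shuffle, since they expect standard row-major weights.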
Dependency: PR 2 (for ATOM_CK_FREE env var)
PR 5: Docker Infrastructure
Files: docker/Dockerfile, docker/Dockerfile.clean (NEW), docker/Dockerfile.wheels (NEW), .dockerignore (NEW)
Key changes:
- `Dockerfile`: `ARG ENABLE_CK=1` parameter, conditional `git submodule`, pass to AITER `setup.py`, install `triton_kernels` wheel
- `Dockerfile.wheels` (NEW, ~160 lines): multi-stage builder — PyTorch ROCm 7.2, Triton 3.5.x, FlyDSL, MORI, AITER (`ENABLE_CK=0`)
- `Dockerfile.clean` (NEW, ~70 lines): zero-compilation runtime from pre-built wheels via bind-mount
- `.dockerignore` (NEW): exclude `.git/`, `build/`, `dist/` — reduces build context from 67.9 GB to 37.9 GB
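Illustrative build invocations for the three flows (the Dockerfile names are from this RFC; the image tags are placeholders):

```shell
# CK-free variant of the existing image via the new build arg
docker build --build-arg ENABLE_CK=0 -f docker/Dockerfile -t atom:ck-free .

# One-time wheel builder (PyTorch ROCm, Triton, FlyDSL, MORI, AITER ENABLE_CK=0)
docker build -f docker/Dockerfile.wheels -t atom-wheels:latest .

# Zero-compilation runtime assembled from the pre-built wheels
docker build -f docker/Dockerfile.clean -t atom:clean .
```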
Build time comparison:
| Image | Build Time | Size |
|---|---|---|
| Current (full CK) | ~60 min | Large |
| Dockerfile.wheels | ~60 min (one-time) | Wheels only |
| Dockerfile.clean | ~10 min | Minimal runtime |
Dependency: AITER PR 1 (ENABLE_CK support)
PR 6: Test Suite & CI (Nice-to-have)
Files: tests/test_ck_free_mode.py, tests/test_flydsl_moe.py, tests/test_attention_dispatch.py, tests/test_mla_prefill_routing.py, tests/test_aiter_mla_metadata.py, tests/test_moe_shapes.py (all NEW), pyproject.toml, .github/workflows/pre-checks.yaml, .github/workflows/atom-test.yaml
Key changes:
- 6 new test files (~970 lines) covering: env var detection, MOE routing, MHA dispatch, MLA prefill/decode, metadata construction, MOE shapes
- `pyproject.toml`: `gpu` pytest marker for tagging GPU-requiring tests
- `pre-checks.yaml`: add `unit-tests` job running CPU-only pytest (`-m "not gpu"`) — works on forks without GPU runners
- `atom-test.yaml`: guard to only run on ROCm/ATOM (not forks), re-enable golden output tests
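The marker pattern in the new test files might look like this (the test name is illustrative, not from the actual suite):

```python
import pytest

# Illustrative use of the `gpu` marker registered in pyproject.toml:
# fork CI, which has no GPU runners, deselects these tests with
# `pytest -m "not gpu"`.
@pytest.mark.gpu
def test_asm_pa_decode_matches_triton():
    pass  # would compare ASM PA vs Triton PA outputs on a ROCm GPU
```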
Dependency: PR 2-4
PR 7: CI Nightly Sync (Nice-to-have)
Files: .github/workflows/sync-upstream.yaml (NEW), scripts/test_golden_output.sh (NEW)
- Nightly scheduled sync of fork main with upstream
- Golden output comparison script for CI regression testing
Dependency: None
Execution Path Summary
When ATOM_CK_FREE=1:
MHA: Triton fused rope+cache → ASM PA (head_dim=128) or Triton PA
MLA: PyTorch SDPA prefill → Triton decode (decode_attention_fwd_grouped_rope)
MOE: torch-native sorting → FlyDSL GEMM (if available) or Triton GEMM
When ATOM_CK_FREE=0 (default): unchanged behavior using CK/ASM kernels.
Performance Summary (MI300X, Llama-3.1-8B)
| Configuration | Decode tok/s | vs Public |
|---|---|---|
| Public image (full CK+ASM) | ~8,200 | 100% |
| CK-free, ASM PA, fp8 KV | ~6,830 | ~83% |
| CK-free, Triton PA, bf16 KV | ~6,255 | ~76% |
The remaining gap is primarily due to: (1) no ASM GEMM for decode (M=1), and (2) tuned-GEMM CSV coverage limited to M≤256.
Open Questions
- Should FlyDSL MOE be the default fallback, or should we wait for more validation?
- Should Dockerfile.clean/Dockerfile.wheels live in ATOM or a separate build-infra repo?
- Priority of ASM GEMM re-enablement (v2 clean Docker) vs other optimizations?
Related
- AITER RFC: (will link after creation)