[RFC] CK-Free AITER Compatibility: Attention, MOE, and Docker Changes #229

@sunway513

Description

Motivation

AITER is adding a CK-free build mode (ENABLE_CK=0) that removes Composable Kernel dependencies, reducing build time from ~35min to ~8min. ATOM needs corresponding changes to:

  • Route around CK-dependent code paths when CK kernels are unavailable
  • Maintain performance by preferring ASM PA (paged attention) and Triton kernels where possible
  • Support clean Docker builds from pre-built wheels (zero compilation)
  • Fix bugs discovered during CK-free validation

All changes are backward compatible — when ATOM_CK_FREE=0 (default), ATOM uses CK/ASM kernels as before.

Validated Locally

  • MI300X (gfx942): Llama-3.1-8B, DeepSeek-R1-671B inference
  • MI355X (gfx950): MXFP4 MOE with Swiglu paths
  • Clean Docker (zero-compile from wheels): Llama-3.1-8B decode 80-82% of public image perf

Proposed PR Sequence

PR 1: Bug Fixes (Standalone, No CK-Free Dependency)

Files: atom/model_ops/linear.py, atom/model_ops/fused_moe_triton.py, atom/model_engine/scheduler.py

Individual fixes:

  1. linear.py: UnboundLocalError in MergedReplicatedLinear.weight_loader() — missing else clause for per_Token and per_1x32 quant types. Compute shard_offset from self.output_sizes.
  2. fused_moe_triton.py: Use CDNA4MXScaleLayout instead of GFX950MXScaleLayout for gfx950 arch detection (more accurate naming).
  3. fused_moe_triton.py: Add update_opt_flags_constraints({"block_m": 128}) for MI355X — default CDNA4 block_m=256 exceeds 160KB LDS limit.
  4. scheduler.py: Initialize num_rejected=0 to prevent UnboundLocalError in non-speculative path.

Dependency: None — these fix real bugs independent of CK-free work.
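Fix 1 can be sketched as follows. The helper below is a hypothetical simplification of the offset computation in MergedReplicatedLinear.weight_loader() — the names, branch condition, and block_n parameter are illustrative assumptions, not the actual ATOM code:

```python
def weight_loader_offset(output_sizes, shard_id, quant_type, block_n=1):
    """Sketch of the shard-offset computation in
    MergedReplicatedLinear.weight_loader() (simplified; names are illustrative).

    Before the fix, only the block-quantized branch assigned shard_offset,
    so the per_Token / per_1x32 quant types fell through and raised
    UnboundLocalError on return.
    """
    if quant_type == "block":
        # Block-quantized weights: offset measured in quantization blocks.
        shard_offset = sum(output_sizes[:shard_id]) // block_n
    else:
        # The missing else clause: per_Token / per_1x32 derive the offset
        # directly from the cumulative output sizes of preceding shards.
        shard_offset = sum(output_sizes[:shard_id])
    return shard_offset
```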


PR 2: MHA Attention Dispatch Decoupling

Files: atom/model_ops/attention_mha.py, atom/utils/envs.py

Key changes:

  • Add ATOM_CK_FREE env var (default=0) as master switch
  • Decouple cache update from paged attention backend selection:
    • Cache update: Always use Triton fused rope+cache (fast, no module_cache JIT dependency)
    • Paged attention: Independently select ASM PA for decode when head_dim=128 and no sliding window
  • AITER_FORCE_TRITON_ATTN env override for forcing Triton PA
  • FP8 KV cache: fill per-token scale buffers with uniform per-tensor scale so ASM PA can dequantize correctly
  • Prefill: always use prefill_attention_triton (no CK/flash_attn_varlen_func dependency)
  • Move kv_scale tensor to CUDA at init for graph capture compatibility
  • Create proper block_tables for prefill with block_size=1

Dependency: AITER PR 1 (CK-free build gating)
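The decoupled backend selection can be sketched as below. This is a minimal sketch of the dispatch described in the bullets above — the function name, return labels, and exact condition set are assumptions; cache update is not shown because it always takes the Triton fused rope+cache path regardless of the PA backend:

```python
import os

def select_paged_attention_backend(head_dim, sliding_window, is_decode):
    """Sketch (hypothetical names) of the PA backend selection described in
    this RFC. Cache update is handled separately and always uses the Triton
    fused rope+cache kernel."""
    if os.environ.get("AITER_FORCE_TRITON_ATTN", "0") == "1":
        return "triton_pa"  # explicit override
    if os.environ.get("ATOM_CK_FREE", "0") == "1":
        # CK-free mode: prefer ASM PA for decode when supported.
        if is_decode and head_dim == 128 and sliding_window is None:
            return "asm_pa"
        return "triton_pa"
    return "ck_or_asm"  # default: unchanged pre-existing behavior
```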


PR 3: MLA CK-Free Paths

Files: atom/model_ops/attention_mla.py, atom/model_ops/attentions/aiter_mla.py

Key changes:

  • ATOM_USE_TRITON_MLA_DECODE or ATOM_CK_FREE → force Triton MLA decode using decode_attention_fwd_grouped_rope (AITER Triton kernel)
  • MLA prefill fallback: Replace flash_attn_varlen_func with PyTorch F.scaled_dot_product_attention (loops over sequences)
  • FP8 constraint: only use fp8 scales when max_seqlen_q == 1 (mla_decode_fwd limitation)
  • Type casting: convert Q/KV to model dtype before mla_prefill_fwd if dtype mismatch
  • aiter_mla.py: Build paged KV metadata (kv_indptr, kv_indices, block_tables) for MLA prefill paths

Dependency: PR 2
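The paged-KV metadata construction in aiter_mla.py amounts to a CSR-style flattening of the block tables. The sketch below is a hypothetical simplification (plain Python lists instead of device tensors; names mirror the bullets above):

```python
def build_paged_kv_metadata(seq_lens, block_tables, block_size):
    """Sketch of the MLA-prefill paged-KV metadata (simplified).

    kv_indptr is a CSR-style prefix sum over per-sequence block counts;
    kv_indices flattens the used entries of each row of block_tables.
    """
    kv_indptr = [0]
    kv_indices = []
    for i, seq_len in enumerate(seq_lens):
        # Number of KV-cache blocks this sequence occupies (ceil division).
        n_blocks = (seq_len + block_size - 1) // block_size
        kv_indices.extend(block_tables[i][:n_blocks])
        kv_indptr.append(kv_indptr[-1] + n_blocks)
    return kv_indptr, kv_indices
```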


PR 4: MOE Cascade Routing (CK → FlyDSL → Triton)

Files: atom/model_ops/moe.py, atom/model_ops/flydsl_moe.py (NEW)

Key changes:

  • Detection functions: _has_ck_moe_sorting(), _has_flydsl_moe() with caching
  • Cascade logic for Fp8MoEMethod / CompressedTensorsFp8MoEMethod:
    1. Check CK MOE sorting availability
    2. If unavailable or ATOM_CK_FREE=1: try FlyDSL MOE (ATOM_USE_FLYDSL_MOE=1)
    3. Else: fall back to Triton MOE (_triton_fp8_moe())
  • Weight shuffle skip: Triton expects standard row-major weights (not CK's shuffled layout)
  • flydsl_moe.py (NEW): FlyDSL MOE backend with torch-native sorting, per-token FP8 quant, 5-stage GEMM pipeline
  • _triton_fp8_moe(): Complete Triton MOE pipeline (sort → GEMM1+SiLU → GEMM2)
  • _per_token_group_quant_fp8(): Per-token-group FP8 quantization helper

Dependency: PR 2 (for ATOM_CK_FREE env var)
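The cascade can be sketched as below. Availability is passed in as plain flags for clarity; in the real code these come from the cached detection helpers _has_ck_moe_sorting() / _has_flydsl_moe(), and the function name and return labels here are illustrative assumptions:

```python
import os

def select_moe_backend(has_ck_moe_sorting, has_flydsl_moe):
    """Sketch of the CK -> FlyDSL -> Triton MOE cascade described above."""
    ck_free = os.environ.get("ATOM_CK_FREE", "0") == "1"
    use_flydsl = os.environ.get("ATOM_USE_FLYDSL_MOE", "0") == "1"
    # 1. CK MOE sorting, unless CK-free mode is forced.
    if has_ck_moe_sorting and not ck_free:
        return "ck_moe"
    # 2. FlyDSL MOE when available and opted in.
    if has_flydsl_moe and use_flydsl:
        return "flydsl_moe"
    # 3. Triton MOE: always-available fallback (expects standard
    #    row-major weights, so the CK weight shuffle is skipped).
    return "triton_moe"
```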


PR 5: Docker Infrastructure

Files: docker/Dockerfile, docker/Dockerfile.clean (NEW), docker/Dockerfile.wheels (NEW), .dockerignore (NEW)

Key changes:

  • Dockerfile: ARG ENABLE_CK=1 parameter, conditional git submodule, pass to AITER setup.py, install triton_kernels wheel
  • Dockerfile.wheels (NEW, ~160 lines): Multi-stage builder — PyTorch ROCm 7.2, Triton 3.5.x, FlyDSL, MORI, AITER (ENABLE_CK=0)
  • Dockerfile.clean (NEW, ~70 lines): Zero-compilation runtime from pre-built wheels via bind-mount
  • .dockerignore (NEW): Exclude .git/, build/, dist/ — reduces context from 67.9GB to 37.9GB

Build time comparison:

| Image | Build Time | Size |
| --- | --- | --- |
| Current (full CK) | ~60 min | Large |
| Dockerfile.wheels | ~60 min (one-time) | Wheels only |
| Dockerfile.clean | ~10 min | Minimal runtime |

Dependency: AITER PR 1 (ENABLE_CK support)


PR 6: Test Suite & CI (Nice-to-have)

Files: tests/test_ck_free_mode.py, tests/test_flydsl_moe.py, tests/test_attention_dispatch.py, tests/test_mla_prefill_routing.py, tests/test_aiter_mla_metadata.py, tests/test_moe_shapes.py (all NEW), pyproject.toml, .github/workflows/pre-checks.yaml, .github/workflows/atom-test.yaml

Key changes:

  • 6 new test files (~970 lines) covering: env var detection, MOE routing, MHA dispatch, MLA prefill/decode, metadata construction, MOE shapes
  • pyproject.toml: gpu pytest marker for tagging GPU-requiring tests
  • pre-checks.yaml: Add unit-tests job running CPU-only pytest (-m "not gpu") — works on forks without GPU runners
  • atom-test.yaml: Guard to only run on ROCm/ATOM (not forks), re-enable golden output tests

Dependency: PR 2-4


PR 7: CI Nightly Sync (Nice-to-have)

Files: .github/workflows/sync-upstream.yaml (NEW), scripts/test_golden_output.sh (NEW)

  • Nightly scheduled sync of fork main with upstream
  • Golden output comparison script for CI regression testing

Dependency: None


Execution Path Summary

When ATOM_CK_FREE=1:

  • MHA: Triton fused rope+cache → ASM PA (head_dim=128) or Triton PA
  • MLA: PyTorch SDPA prefill → Triton decode (decode_attention_fwd_grouped_rope)
  • MOE: torch-native sorting → FlyDSL GEMM (if available) or Triton GEMM

When ATOM_CK_FREE=0 (default): unchanged behavior using CK/ASM kernels.

Performance Summary (MI300X, Llama-3.1-8B)

| Configuration | Decode tok/s | vs Public |
| --- | --- | --- |
| Public image (full CK+ASM) | ~8,200 | 100% |
| CK-free, ASM PA, fp8 KV | ~6,830 | ~83% |
| CK-free, Triton PA, bf16 KV | ~6,255 | ~76% |

The remaining gap is primarily due to: (1) no ASM GEMM for decode (M=1), and (2) tuned-GEMM CSV coverage limited to M≤256.

Open Questions

  1. Should FlyDSL MOE be the default fallback, or should we wait for more validation?
  2. Should Dockerfile.clean/Dockerfile.wheels live in ATOM or a separate build-infra repo?
  3. Priority of ASM GEMM re-enablement (v2 clean Docker) vs other optimizations?

Related

  • AITER RFC: (will link after creation)
