diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 000000000..c69aefffd --- /dev/null +++ b/.dockerignore @@ -0,0 +1,6 @@ +.git +__pycache__ +*.pyc +*.egg-info +build/ +dist/ diff --git a/.github/workflows/sync-upstream.yaml b/.github/workflows/sync-upstream.yaml new file mode 100644 index 000000000..2afda6a17 --- /dev/null +++ b/.github/workflows/sync-upstream.yaml @@ -0,0 +1,42 @@ +name: Sync upstream main + +on: + schedule: + # Run nightly at 06:00 UTC (midnight CST) + - cron: '0 6 * * *' + workflow_dispatch: # Allow manual trigger + +jobs: + sync: + runs-on: ubuntu-latest + steps: + - name: Checkout fork + uses: actions/checkout@v4 + with: + ref: main + fetch-depth: 0 + token: ${{ secrets.GITHUB_TOKEN }} + + - name: Add upstream remote + run: git remote add upstream https://github.com/ROCm/ATOM.git + + - name: Fetch upstream + run: git fetch upstream main + + - name: Check for new commits + id: check + run: | + BEHIND=$(git rev-list --count HEAD..upstream/main) + echo "behind=$BEHIND" >> "$GITHUB_OUTPUT" + echo "Fork is $BEHIND commit(s) behind upstream" + + - name: Merge upstream + if: steps.check.outputs.behind != '0' + run: | + git config user.name "github-actions[bot]" + git config user.email "github-actions[bot]@users.noreply.github.com" + git merge upstream/main --no-edit + + - name: Push + if: steps.check.outputs.behind != '0' + run: git push origin main diff --git a/atom/model_engine/scheduler.py b/atom/model_engine/scheduler.py index e27d3af61..5e169a1b5 100644 --- a/atom/model_engine/scheduler.py +++ b/atom/model_engine/scheduler.py @@ -297,6 +297,7 @@ def postprocess( continue token_ids = prev_token_ids[seq.id] num_new_token = len(token_ids) + num_rejected = 0 self.update_spec_stats(num_new_token) idx = fwd_output.req_ids.index(seq.id) if is_deferred_out or self.use_spec: diff --git a/atom/model_ops/attention_mha.py b/atom/model_ops/attention_mha.py index b9e0a286e..74622651b 100644 --- a/atom/model_ops/attention_mha.py +++ 
b/atom/model_ops/attention_mha.py @@ -370,7 +370,19 @@ def prefill_attention_triton( if ctx.is_prefill: k_cache = k.unsqueeze(1) v_cache = v.unsqueeze(1) - block_tables = attn_metadata.fake_block_tables + # Create fake block_tables for prefill: each token is its own + # "block" (block_size=1). Shape [num_seqs, max_seqlen_k]. + batch_size = attn_metadata.cu_seqlens_k.shape[0] - 1 + max_len = attn_metadata.max_seqlen_k + block_tables = torch.zeros( + batch_size, max_len, dtype=torch.int32, device=q.device + ) + for i in range(batch_size): + s = attn_metadata.cu_seqlens_k[i].item() + e = attn_metadata.cu_seqlens_k[i + 1].item() + block_tables[i, : e - s] = torch.arange( + s, e, dtype=torch.int32, device=q.device + ) o = torch.empty_like(q) descale_shape = (attn_metadata.cu_seqlens_q.shape[0] - 1, k.shape[1]) @@ -407,7 +419,8 @@ def dispatch_backend(self, fwd_ctx: ForwardContext): ctx = fwd_ctx.context if ctx.is_prefill: - return self.prefill_attention + # Always use Triton prefill (no CK/flash_attn_varlen_func dependency) + return self.prefill_attention_triton else: if self.use_triton_attn: return self.paged_attention_triton diff --git a/atom/model_ops/attention_mla.py b/atom/model_ops/attention_mla.py index 6b2452cd1..4f0caf33a 100644 --- a/atom/model_ops/attention_mla.py +++ b/atom/model_ops/attention_mla.py @@ -396,19 +396,22 @@ def _forward_prefill_mha( k = torch.cat((k_nope, k_rope.expand((*k_nope.shape[:-1], -1))), dim=-1) - output = flash_attn_varlen_func( - q=q, - k=k, - v=v, - cu_seqlens_q=attn_metadata.cu_seqlens_q, - cu_seqlens_k=attn_metadata.cu_seqlens_k, - max_seqlen_q=attn_metadata.max_seqlen_q, - max_seqlen_k=attn_metadata.max_seqlen_k, - min_seqlen_q=attn_metadata.min_seqlen_q, - dropout_p=attn_metadata.dropout_p, - softmax_scale=self.scale, - causal=True, - ) + # Use PyTorch SDPA for MLA prefill attention (no CK dependency) + import torch.nn.functional as F + + cu_q = attn_metadata.cu_seqlens_q + cu_k = attn_metadata.cu_seqlens_k + num_seqs = 
cu_q.shape[0] - 1 + outputs = [] + for i in range(num_seqs): + qi = q[cu_q[i] : cu_q[i + 1]].transpose(0, 1).unsqueeze(0) + ki = k[cu_k[i] : cu_k[i + 1]].transpose(0, 1).unsqueeze(0) + vi = v[cu_k[i] : cu_k[i + 1]].transpose(0, 1).unsqueeze(0) + oi = F.scaled_dot_product_attention( + qi, ki, vi, is_causal=True, scale=self.scale + ) + outputs.append(oi.squeeze(0).transpose(0, 1)) + output = torch.cat(outputs, dim=0) return self.o_proj(output.flatten(start_dim=-2)) @@ -446,7 +449,8 @@ def _forward_prefill_mla( max_q_len = 1 if kv_c_and_k_pe_cache.numel() > 0: - if self.kv_cache_dtype.startswith("fp8"): + if self.kv_cache_dtype.startswith("fp8") and max_q_len == 1: + # mla_decode_fwd supports fp8 scales but only max_seqlen_q=1 mla_decode_fwd( q, kv_c_and_k_pe_cache.view(-1, 1, 1, q.shape[-1]), @@ -463,9 +467,16 @@ def _forward_prefill_mla( kv_scale=self._k_scale, ) else: + # mla_prefill_fwd supports arbitrary max_seqlen_q but no fp8 scales + q_for_prefill = q.to(self.dtype) if q.dtype != self.dtype else q + kv_for_prefill = ( + kv_c_and_k_pe_cache.to(self.dtype) + if kv_c_and_k_pe_cache.dtype != self.dtype + else kv_c_and_k_pe_cache + ) mla_prefill_fwd( - q, - kv_c_and_k_pe_cache.view(-1, 1, 1, q.shape[-1]), + q_for_prefill, + kv_for_prefill.view(-1, 1, 1, q.shape[-1]), o, paged_cu_seqlens_q, paged_kv_indptr, diff --git a/atom/model_ops/attentions/aiter_mla.py b/atom/model_ops/attentions/aiter_mla.py index 7fd33253b..89bbdf65e 100644 --- a/atom/model_ops/attentions/aiter_mla.py +++ b/atom/model_ops/attentions/aiter_mla.py @@ -180,6 +180,56 @@ def prepare_prefill(self, batch: ScheduledBatch): bs = batch.total_seqs_num_prefill sum_scheduled_tokens = batch.total_tokens_num_prefill var = self.model_runner.forward_vars + + # Prepare paged KV metadata for MLA prefill paths + # (needed by mla_prefill_fwd for bf16, unified_attention for fp8) + if batch.block_tables: + context_lens = np.asarray(batch.context_lens[:bs], dtype=np.int32) + num_blocks_per_seq = cdiv(context_lens, 
self.block_size) + kv_indptr = np.cumsum(num_blocks_per_seq) + sum_blocks = kv_indptr[-1] + + dst = var["kv_indices"].np + offset = 0 + for i in range(bs): + bt = batch.block_tables[i] + n = len(bt) + dst[offset : offset + n] = bt + offset += n + sum_blocks_before_converted = offset + + var["kv_indptr"].np[0] = 0 + var["kv_indptr"].np[1 : bs + 1] = kv_indptr + + attn_metadata.kv_indptr = var["kv_indptr"].copy_to_gpu(bs + 1) + attn_metadata.kv_indices = var["kv_indices"].copy_to_gpu( + sum_blocks_before_converted + ) + attn_metadata.kv_last_page_lens = var["kv_last_page_lens"].gpu[:bs] + + if self.block_ratio > 1: + kv_indices_convert_triton( + var["kv_indices"].gpu[:sum_blocks_before_converted], + var["kv_indices_converted"].gpu[:sum_blocks], + var["kv_indptr"].gpu[: bs + 1], + self.block_ratio, + self.block_size, + ) + attn_metadata.kv_indices = var["kv_indices_converted"].gpu[:sum_blocks] + + # Prepare block_tables for unified_attention (fp8 prefill) + if attn_metadata.block_tables is None: + self.prepare_block_tables(batch) + attn_metadata.block_tables = var["block_tables"].copy_to_gpu(bs) + if self.block_ratio > 1: + block_table_convert_triton( + var["block_tables"].gpu[:bs], + var["block_tables_converted"].gpu[:bs], + var["context_lens"].gpu[:bs], + self.block_ratio, + ) + attn_metadata.block_tables = var["block_tables_converted"].gpu[:bs] + if self.is_sparse and attn_metadata.max_seqlen_k > self.index_topk: if attn_metadata.block_tables is None: self.prepare_block_tables(batch) diff --git a/atom/model_ops/flydsl_moe.py b/atom/model_ops/flydsl_moe.py new file mode 100644 index 000000000..c8b01c56e --- /dev/null +++ b/atom/model_ops/flydsl_moe.py @@ -0,0 +1,330 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +"""FlyDSL MOE backend for ATOM. + +When CK MOE sorting is unavailable (e.g. 
AITER built with ENABLE_CK=0) and +ATOM_USE_FLYDSL_MOE=1, this module routes FP8 MOE through FlyDSL's MLIR-compiled +2-stage kernels instead of the Triton fallback. + +Priority: CK/ASM > FlyDSL > Triton. +""" + +import logging +import os +from typing import Optional, Tuple + +import torch + +from atom.utils import envs + +_logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Detection +# --------------------------------------------------------------------------- +_flydsl_moe_available: Optional[bool] = None + + +def _has_flydsl_moe() -> bool: + """Check if FlyDSL MOE backend is enabled and importable.""" + global _flydsl_moe_available + if _flydsl_moe_available is not None: + return _flydsl_moe_available + + if not envs.ATOM_USE_FLYDSL_MOE: + _flydsl_moe_available = False + return False + + try: + from kernels.moe_gemm_2stage import ( + compile_moe_gemm1, + compile_moe_gemm2, + ) # noqa: F401 + + _flydsl_moe_available = True + _logger.info("FlyDSL MOE kernels detected and available") + except Exception as e: + _flydsl_moe_available = False + _logger.warning( + "ATOM_USE_FLYDSL_MOE=1 but FlyDSL MOE kernels not importable: %s. " + "Ensure FlyDSL repo is on PYTHONPATH.", + e, + ) + return _flydsl_moe_available + + +# --------------------------------------------------------------------------- +# Torch-native MOE sorting (no CK dependency) +# --------------------------------------------------------------------------- +def moe_sorting_torch_native( + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + num_experts: int, + block_size: int, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Pure-PyTorch MOE sorting matching CK/FlyDSL kernel expectations. 
+ + Returns: + sorted_ids: int32 with packed (topk_slot<<24 | token_id) encoding + sorted_weights: fp32 aligned with sorted_ids + sorted_expert_ids: int32, one expert id per M-block + num_tokens_post_pad: int32 [2], [0]=total padded tokens, [1]=num_tokens + """ + device = topk_ids.device + M = topk_ids.shape[0] + topk = topk_ids.shape[1] + + max_num_tokens_padded = topk_ids.numel() + num_experts * block_size - topk + max_num_m_blocks = (max_num_tokens_padded + block_size - 1) // block_size + + init_val = (topk << 24) | M + sorted_ids = torch.full( + (max_num_tokens_padded,), init_val, dtype=torch.int32, device=device + ) + sorted_weights = torch.empty( + (max_num_tokens_padded,), dtype=torch.float32, device=device + ) + sorted_expert_ids = torch.full( + (max_num_m_blocks,), -1, dtype=torch.int32, device=device + ) + num_tokens_post_pad = torch.empty((2,), dtype=torch.int32, device=device) + + sorted_ids_begin = 0 + sorted_expert_ids_begin = 0 + skip_expert_num = 0 + for expert_id in range(num_experts): + token_id, topk_id = torch.where(topk_ids == expert_id) + tokens_num = token_id.numel() + sorted_expert_ids_num = (tokens_num + block_size - 1) // block_size + tokens_num_pad = sorted_expert_ids_num * block_size + + sorted_ids[sorted_ids_begin : sorted_ids_begin + tokens_num] = ( + topk_id.to(torch.int32) << 24 + ) | token_id.to(torch.int32) + sorted_weights[sorted_ids_begin : sorted_ids_begin + tokens_num] = topk_weights[ + token_id, topk_id + ].to(torch.float32) + + sorted_ids_begin += tokens_num_pad + sorted_expert_ids[ + sorted_expert_ids_begin : sorted_expert_ids_begin + sorted_expert_ids_num + ] = (expert_id - skip_expert_num) + sorted_expert_ids_begin += sorted_expert_ids_num + + num_tokens_post_pad[0] = sorted_ids_begin + num_tokens_post_pad[1] = M + + return sorted_ids, sorted_weights, sorted_expert_ids, num_tokens_post_pad + + +# --------------------------------------------------------------------------- +# Per-token FP8 quantization +# 
--------------------------------------------------------------------------- +def _pertoken_quant_fp8( + x: torch.Tensor, fp8_dtype: torch.dtype +) -> Tuple[torch.Tensor, torch.Tensor]: + """Per-row dynamic FP8 quantization. + + Args: + x: Input tensor, any shape with last dim as the quant dimension. + fp8_dtype: Target FP8 dtype. + + Returns: + x_fp8: Quantized tensor (same shape as x). + scale_1d: 1D scale tensor [num_rows]. + """ + orig_shape = x.shape + x_2d = x.reshape(-1, orig_shape[-1]).float() + fp8_max = torch.finfo(fp8_dtype).max + amax = x_2d.abs().amax(dim=-1, keepdim=True) # [rows, 1] + scale = (amax / fp8_max).clamp(min=1e-12) + x_scaled = (x_2d / scale).clamp(-fp8_max, fp8_max) + x_fp8 = x_scaled.to(fp8_dtype).reshape(orig_shape) + scale_1d = scale.view(-1).to(torch.float32) + return x_fp8, scale_1d + + +# --------------------------------------------------------------------------- +# FlyDSL FP8 MOE dispatch +# --------------------------------------------------------------------------- +def flydsl_fp8_moe( + x: torch.Tensor, + w13: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + w13_scale: torch.Tensor, + w2_scale: torch.Tensor, + top_k: int, + block_quant: bool, + quant_type, +) -> torch.Tensor: + """Execute FP8 MOE using FlyDSL MLIR-compiled 2-stage kernels. + + Drop-in replacement for _triton_fp8_moe(). Expects preshuffled weights + (same CK layout as ASM/CK path). + + Pipeline: + 1. Sort tokens via torch-native MOE sorting + 2. Quantize activations to FP8 + 3. Stage 1: GEMM1 + SiLU gating (x @ w13^T) + 4. Quantize intermediate to FP8 + 5. Stage 2: GEMM2 with atomic accumulation (intermediate @ w2^T) + """ + from kernels.moe_gemm_2stage import compile_moe_gemm1, compile_moe_gemm2 + + if block_quant: + raise NotImplementedError( + "FlyDSL MOE does not support block quantization yet. " + "Set ATOM_USE_FLYDSL_MOE=0 or use Triton fallback." 
+ ) + + M, model_dim = x.shape + E = w13.shape[0] + inter_dim_2 = w13.shape[1] # 2 * inter_dim + inter_dim = inter_dim_2 // 2 + actual_top_k = topk_ids.numel() // M + device = x.device + + # Detect FP8 dtype from weight dtype + fp8_dtype = w13.dtype + + # Tile sizes (configurable via env vars) + tile_m = int(os.environ.get("ATOM_FLYDSL_MOE_TILE_M", "64")) + tile_n = int(os.environ.get("ATOM_FLYDSL_MOE_TILE_N", "128")) + tile_k = int(os.environ.get("ATOM_FLYDSL_MOE_TILE_K", "64")) + + # --- Step 1: Sort tokens --- + sorted_ids, sorted_weights, sorted_expert_ids, num_tokens_post_pad = ( + moe_sorting_torch_native( + topk_ids.to(torch.int32), + topk_weights.to(torch.float32), + num_experts=E, + block_size=tile_m, + ) + ) + num_valid_ids = num_tokens_post_pad[:1].contiguous() + blocks = sorted_expert_ids.numel() + + # --- Step 2: Compile kernels (cached via @lru_cache) --- + exe1 = compile_moe_gemm1( + model_dim=model_dim, + inter_dim=inter_dim, + experts=E, + topk=actual_top_k, + in_dtype="fp8", + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + doweight_stage1=False, + ) + exe2 = compile_moe_gemm2( + model_dim=model_dim, + inter_dim=inter_dim, + experts=E, + topk=actual_top_k, + in_dtype="fp8", + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + doweight_stage2=True, + ) + + # --- Step 3: Quantize activations --- + x_fp8, scale_x = _pertoken_quant_fp8(x, fp8_dtype) + x_fp8 = x_fp8.contiguous().view(M, model_dim) + scale_x_1d = scale_x.view(-1).contiguous() + + # --- Step 4: Flatten weights for FlyDSL (expert dim merged into N) --- + # w13: [E, 2*inter_dim, model_dim] -> [E*2*inter_dim, model_dim] + w13_flat = w13.contiguous().view(E * inter_dim_2, model_dim) + + # --- Step 5: Flatten weight scales --- + # Handle different scale shapes: + # per-tensor: [E] -> broadcast to [E*2*inter_dim] + # per-channel: [E, 2*inter_dim, 1] -> [E*2*inter_dim] + # per-tensor (after max reduction): [E] -> broadcast + if w13_scale.dim() == 1: + # per-tensor [E]: broadcast each 
expert's scale to all its rows + scale_w13_1d = ( + w13_scale.unsqueeze(1).expand(E, inter_dim_2).contiguous().view(-1) + ) + elif w13_scale.dim() == 3: + # per-channel [E, 2*inter_dim, 1] + scale_w13_1d = w13_scale.contiguous().view(-1) + elif w13_scale.dim() == 2: + # [E, 2*inter_dim] or [E, 2] -> handle both + if w13_scale.shape[1] == inter_dim_2: + scale_w13_1d = w13_scale.contiguous().view(-1) + else: + # [E, 2] per-tensor per-shard -> broadcast + scale_w13_1d = ( + w13_scale.unsqueeze(2).expand(E, 2, inter_dim).contiguous().view(-1) + ) + else: + scale_w13_1d = w13_scale.contiguous().view(-1) + + # --- Step 6: Stage 1 (GEMM1 + SiLU) --- + out1 = torch.empty((M, actual_top_k, inter_dim), device=device, dtype=torch.float16) + stream_ptr = torch.cuda.current_stream().cuda_stream + + exe1( + out1, + x_fp8, + w13_flat, + scale_x_1d, + scale_w13_1d, + sorted_ids, + sorted_expert_ids, + sorted_weights.view(-1), + num_valid_ids, + M, + inter_dim, + model_dim, + blocks, + stream_ptr, + ) + + # --- Step 7: Quantize intermediate for Stage 2 --- + out1_fp32 = out1.to(torch.float32) + a2_fp8, scale_a2 = _pertoken_quant_fp8(out1_fp32, fp8_dtype) + a2_flat = a2_fp8.contiguous().view(-1) + scale_a2_1d = scale_a2.view(-1).contiguous() + + # --- Step 8: Flatten w2 weights and scales --- + # w2: [E, model_dim, inter_dim] -> [E*model_dim, inter_dim] -> flat 1D + w2_flat = w2.contiguous().view(-1) + + if w2_scale.dim() == 1: + # per-tensor [E]: broadcast to [E*model_dim] + scale_w2_1d = w2_scale.unsqueeze(1).expand(E, w2.shape[1]).contiguous().view(-1) + elif w2_scale.dim() == 3: + # per-channel [E, model_dim, 1] + scale_w2_1d = w2_scale.contiguous().view(-1) + elif w2_scale.dim() == 2: + scale_w2_1d = w2_scale.contiguous().view(-1) + else: + scale_w2_1d = w2_scale.contiguous().view(-1) + + # --- Step 9: Stage 2 (GEMM2 with atomic accumulation) --- + out2 = torch.zeros((M, model_dim), device=device, dtype=torch.float16) + + exe2( + out2.view(-1), + a2_flat, + w2_flat, + 
scale_a2_1d, + scale_w2_1d, + sorted_ids, + sorted_expert_ids, + sorted_weights.view(-1), + num_valid_ids, + M, + model_dim, + inter_dim, + blocks, + stream_ptr, + ) + + return out2.to(x.dtype) diff --git a/atom/model_ops/linear.py b/atom/model_ops/linear.py index 1bd3538a3..83b492a99 100644 --- a/atom/model_ops/linear.py +++ b/atom/model_ops/linear.py @@ -700,6 +700,10 @@ def weight_loader( elif self.quant_type == QuantType.per_Tensor: shard_offset = loaded_shard_id shard_size = 1 + else: + # per_Token and per_1x32: scale dim 0 matches output_size + shard_offset = sum(self.output_sizes[:loaded_shard_id]) + shard_size = self.output_sizes[loaded_shard_id] else: shard_offset = sum(self.output_sizes[:loaded_shard_id]) shard_size = self.output_sizes[loaded_shard_id] diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index d9dc1e34c..b5e195d29 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -48,6 +48,209 @@ from torch import nn from transformers import PretrainedConfig +import logging + +_moe_logger = logging.getLogger(__name__) + + +def _has_ck_moe_sorting() -> bool: + """Check if CK MOE sorting kernel is available.""" + try: + import importlib + + return importlib.util.find_spec("aiter.jit.module_moe_sorting") is not None + except Exception: + return False + + +from atom.model_ops.flydsl_moe import _has_flydsl_moe + + +def _per_token_group_quant_fp8(x, group_size, fp8_dtype): + """Quantize input tensor to FP8 with per-token-group scaling. + + Args: + x: Input tensor of shape (M, K) in bf16/fp16. + group_size: Number of elements per quantization group. + fp8_dtype: Target FP8 dtype (e.g. torch.float8_e4m3fnuz). + + Returns: + x_fp8: Quantized tensor of shape (M, K). + scale: Dequantization scale of shape (M, K // group_size). 
+ """ + M, K = x.shape + assert K % group_size == 0 + num_groups = K // group_size + x_float = x.float() + x_grouped = x_float.view(M, num_groups, group_size) + max_abs = x_grouped.abs().amax(dim=-1) # (M, num_groups) + fp8_max = torch.finfo(fp8_dtype).max + scale = (max_abs / fp8_max).clamp(min=1e-12) + x_scaled = x_grouped / scale.unsqueeze(-1) + x_fp8 = x_scaled.clamp(-fp8_max, fp8_max).to(fp8_dtype) + x_fp8 = x_fp8.view(M, K) + return x_fp8, scale + + +def _triton_fp8_moe( + x, + w13, + w2, + topk_weights, + topk_ids, + w13_scale, + w2_scale, + top_k, + block_quant, + quant_type, +): + """Execute FP8 MOE using AITER Triton kernels (no CK dependency). + + Two-stage pipeline: + Stage 1 (GEMM1+SiLU): x @ w13^T with SiLU gating + Stage 2 (GEMM2): intermediate @ w2^T with routing weight accumulation + + For GEMM2, we reshape the intermediate so each (token, expert_k) pair is + treated as an independent token with top_k=1, allowing correct A indexing. + """ + import triton.language as tl + from aiter.ops.triton.moe.moe_align_block_size import moe_align_block_size_triton + from aiter.ops.triton.moe.moe_op_silu_fused import fused_moe_silu + from aiter.ops.triton.moe.moe_op import fused_moe as triton_fused_moe + from aiter.ops.triton.utils.moe_config_utils import get_optimal_moe_config + + M, hidden_dim = x.shape + E = w13.shape[0] + inter_dim_2 = w13.shape[1] # 2 * inter_dim + inter_dim = inter_dim_2 // 2 + # actual_top_k may differ from top_k when shared experts are fused + actual_top_k = topk_ids.numel() // M + + if block_quant: + if quant_type == QuantType.per_1x128: + block_shape = [128, 128] + elif quant_type == QuantType.per_1x32: + block_shape = [1, 32] + else: + block_shape = None + else: + block_shape = None + + config = get_optimal_moe_config(dtype=x.dtype, use_fp8_w8a8=True, M=M) + block_size_m = config["BLOCK_SIZE_M"] + compute_type = tl.bfloat16 if x.dtype == torch.bfloat16 else tl.float16 + + # --- Stage 1: Sorting --- + max_num_tokens_padded = 
topk_ids.numel() + E * (block_size_m - 1) + sorted_token_ids = torch.empty( + max_num_tokens_padded, dtype=torch.int32, device=x.device + ) + sorted_token_ids.fill_(topk_ids.numel()) + max_num_m_blocks = (max_num_tokens_padded + block_size_m - 1) // block_size_m + expert_ids = torch.empty(max_num_m_blocks, dtype=torch.int32, device=x.device) + num_tokens_post_pad = torch.empty(1, dtype=torch.int32, device=x.device) + + moe_align_block_size_triton( + topk_ids, E, block_size_m, sorted_token_ids, expert_ids, num_tokens_post_pad + ) + + # --- Stage 2: GEMM1 with SiLU (x @ w13^T) --- + if block_quant and block_shape is not None: + block_k = block_shape[1] + a_fp8, a_scale = _per_token_group_quant_fp8(x, block_k, w13.dtype) + else: + a_fp8 = x + a_scale = None + + intermediate = torch.zeros( + M * actual_top_k, inter_dim, dtype=x.dtype, device=x.device + ) + + fused_moe_silu( + A=a_fp8, + B=w13, + C=intermediate, + A_scale=a_scale, + B_scale=w13_scale, + B_zp=None, + topk_weights=topk_weights, + topk_ids=topk_ids, + sorted_token_ids=sorted_token_ids, + expert_ids=expert_ids, + num_tokens_post_padded=num_tokens_post_pad, + mul_routed_weight=False, + top_k=actual_top_k, + compute_type=compute_type, + use_fp8_w8a8=True, + use_int8_w8a16=False, + use_int4_w4a16=False, + block_shape=block_shape, + config=config, + ) + + # --- Stage 3: GEMM2 (intermediate @ w2^T) --- + # Reshape for GEMM2: treat each (token, expert_k) as independent token + # with top_k=1 so the kernel indexes A correctly (A // top_k = A // 1 = A) + gemm2_topk_ids = topk_ids.reshape(M * actual_top_k, 1) + gemm2_topk_weights = topk_weights.reshape(M * actual_top_k, 1) + + # Re-sort for GEMM2 with the reshaped topk_ids + gemm2_max_padded = gemm2_topk_ids.numel() + E * (block_size_m - 1) + gemm2_sorted_ids = torch.empty(gemm2_max_padded, dtype=torch.int32, device=x.device) + gemm2_sorted_ids.fill_(gemm2_topk_ids.numel()) + gemm2_max_blocks = (gemm2_max_padded + block_size_m - 1) // block_size_m + gemm2_expert_ids 
= torch.empty(gemm2_max_blocks, dtype=torch.int32, device=x.device) + gemm2_num_pad = torch.empty(1, dtype=torch.int32, device=x.device) + + moe_align_block_size_triton( + gemm2_topk_ids, + E, + block_size_m, + gemm2_sorted_ids, + gemm2_expert_ids, + gemm2_num_pad, + ) + + # Quantize intermediate for FP8 GEMM2 + if block_quant and block_shape is not None: + block_k2 = block_shape[1] + inter_fp8, inter_scale = _per_token_group_quant_fp8( + intermediate, block_k2, w2.dtype + ) + else: + inter_fp8 = intermediate + inter_scale = None + + output = torch.zeros( + M * actual_top_k, 1, hidden_dim, dtype=x.dtype, device=x.device + ) + + triton_fused_moe( + A=inter_fp8, + B=w2, + C=output, + A_scale=inter_scale, + B_scale=w2_scale, + B_zp=None, + topk_weights=gemm2_topk_weights, + topk_ids=gemm2_topk_ids, + sorted_token_ids=gemm2_sorted_ids, + expert_ids=gemm2_expert_ids, + num_tokens_post_padded=gemm2_num_pad, + mul_routed_weight=True, + top_k=1, + compute_type=compute_type, + use_fp8_w8a8=True, + use_int8_w8a16=False, + use_int4_w4a16=False, + block_shape=block_shape, + config=config, + ) + + # Reduce: sum across top_k experts per token + result = output.squeeze(1).view(M, actual_top_k, hidden_dim).sum(dim=1) + return result + class FusedMoeWeightScaleSupported(Enum): """Supported quantization strategies for MoE weight scales.""" @@ -980,6 +1183,21 @@ def __init__(self, quant_config: QuantizationConfig, moe: FusedMoEConfig): self.block_n = 1 self.block_k = 32 + # Detect CK MOE availability; fall back to FlyDSL or Triton + self.use_flydsl_moe = False + self.use_triton_moe = False + if not _has_ck_moe_sorting(): + if not self.block_quant and _has_flydsl_moe(): + self.use_flydsl_moe = True + _moe_logger.info( + "CK unavailable, using FlyDSL MOE for CompressedTensors FP8" + ) + else: + self.use_triton_moe = True + _moe_logger.info( + "CK unavailable, using Triton MOE for CompressedTensors FP8" + ) + def create_weights( self, layer: torch.nn.Module, @@ -1219,17 +1437,19 @@ def 
process_weights_after_loading(self, layer: torch.nn.Module) -> None: max_w13_scales, requires_grad=False ) - # Shuffle weights for asm moe (moved from inference to load time for better performance) - if w13.dtype in [ - torch.int8, - torch.uint8, - torch.float8_e4m3fnuz, - torch.float8_e4m3fn, - ]: - from aiter.ops.shuffle import shuffle_weight + # Shuffle weights for asm/FlyDSL moe (moved from inference to load time) + # Skip shuffle when using Triton path (Triton expects standard row-major) + if not self.use_triton_moe: + if w13.dtype in [ + torch.int8, + torch.uint8, + torch.float8_e4m3fnuz, + torch.float8_e4m3fn, + ]: + from aiter.ops.shuffle import shuffle_weight - w13.data = shuffle_weight(w13.data) - w2.data = shuffle_weight(w2.data) + w13.data = shuffle_weight(w13.data) + w2.data = shuffle_weight(w2.data) # Call parent class for any additional processing super().process_weights_after_loading(layer) @@ -1298,6 +1518,38 @@ def apply( a1_scale = getattr(layer, "w13_input_scale", None) a2_scale = getattr(layer, "w2_input_scale", None) + # FlyDSL MOE fallback when CK is not available + if self.use_flydsl_moe: + from atom.model_ops.flydsl_moe import flydsl_fp8_moe + + return flydsl_fp8_moe( + x=x, + w13=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + w13_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + top_k=top_k, + block_quant=self.block_quant, + quant_type=self.quant_type, + ) + + # Triton MOE fallback when CK is not available + if self.use_triton_moe: + return _triton_fp8_moe( + x=x, + w13=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + w13_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + top_k=top_k, + block_quant=self.block_quant, + quant_type=self.quant_type, + ) + # Use modular kernel if available (for EP/DP setups) # Otherwise fall back to direct kernel call if self.fused_experts is not None: @@ -1362,6 +1614,16 @@ def __init__(self, 
quant_config: QuantizationConfig, moe: FusedMoEConfig): self.need_normalize_e4m3fn_to_e4m3fnuz = ( self.quant_dtype == torch.float8_e4m3fnuz ) + # Detect CK MOE availability; fall back to FlyDSL or Triton + self.use_flydsl_moe = False + self.use_triton_moe = False + if not _has_ck_moe_sorting(): + if not self.block_quant and _has_flydsl_moe(): + self.use_flydsl_moe = True + _moe_logger.info("CK unavailable, using FlyDSL MOE for FP8") + else: + self.use_triton_moe = True + _moe_logger.info("CK unavailable, using Triton MOE for FP8") def create_weights( self, @@ -1525,7 +1787,8 @@ def process_weights_after_loading(self, layer: nn.Module) -> None: layer.w2_weight = nn.Parameter(w2_weight, requires_grad=False) layer.w2_weight_scale = nn.Parameter(w2_weight_scale, requires_grad=False) - shuffle_weights(layer.w13_weight, layer.w2_weight) + if not self.use_triton_moe: + shuffle_weights(layer.w13_weight, layer.w2_weight) return else: @@ -1597,7 +1860,8 @@ def process_weights_after_loading(self, layer: nn.Module) -> None: ) start += shard_size - shuffle_weights(layer.w13_weight, layer.w2_weight) + if not self.use_triton_moe: + shuffle_weights(layer.w13_weight, layer.w2_weight) layer.w13_weight_scale = torch.nn.Parameter( max_w13_scales, requires_grad=False @@ -1647,6 +1911,38 @@ def apply( num_fused_shared_experts=layer.num_fused_shared_experts, routed_scaling_factor=layer.routed_scaling_factor, ) + # FlyDSL MOE fallback when CK is not available + if self.use_flydsl_moe: + from atom.model_ops.flydsl_moe import flydsl_fp8_moe + + return flydsl_fp8_moe( + x=x, + w13=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + w13_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + top_k=top_k, + block_quant=self.block_quant, + quant_type=self.quant_type, + ) + + # Triton MOE fallback when CK is not available + if self.use_triton_moe: + return _triton_fp8_moe( + x=x, + w13=layer.w13_weight, + w2=layer.w2_weight, + 
topk_weights=topk_weights, + topk_ids=topk_ids, + w13_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + top_k=top_k, + block_quant=self.block_quant, + quant_type=self.quant_type, + ) + # per_Tensor not support num_local_tokens so not use mori if self.quant_type == QuantType.per_Tensor or self.fused_experts is None: return torch.ops.aiter.rocm_aiter_fused_moe( diff --git a/atom/utils/envs.py b/atom/utils/envs.py index 62ce11bb5..eb77ccbc9 100644 --- a/atom/utils/envs.py +++ b/atom/utils/envs.py @@ -42,6 +42,7 @@ "ATOM_LLAMA_ENABLE_AITER_TRITON_FUSED_SILU_MUL_QUANT", "1" ) == "1", + "ATOM_USE_FLYDSL_MOE": lambda: os.getenv("ATOM_USE_FLYDSL_MOE", "0") == "1", } diff --git a/docker/Dockerfile b/docker/Dockerfile index 85c99daac..7d6b9d837 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -14,6 +14,8 @@ ARG AITER_COMMIT="HEAD" ARG MORI_COMMIT="b0dce4beebeb1f26c784eee17d5fd9785ee9447f" ARG PREBUILD_KERNELS=1 ARG MAX_JOBS +# Set ENABLE_CK=0 to skip CK/ASM modules for a fast Triton-only AITER build +ARG ENABLE_CK=1 RUN pip install --upgrade pip RUN pip install lm-eval[api] @@ -63,6 +65,10 @@ RUN git clone --depth=1 --branch release/internal/3.5.x https://github.com/ROCm/ MAX_JOBS=64 pip --retries=10 --default-timeout=60 install . 
RUN pip show triton || true +# Install triton_kernels (required for MXFP4 MoE on gfx94x) +RUN pip install --no-deps -e /triton-test/python/triton_kernels/ +RUN pip show triton-kernels || true + # Install Aiter RUN mkdir -p /app RUN pip uninstall -y aiter || true @@ -70,8 +76,11 @@ RUN git clone $AITER_REPO /app/aiter-test && \ cd /app/aiter-test && \ pip install -r requirements.txt && \ git checkout $AITER_COMMIT && \ - git submodule sync && git submodule update --init --recursive && \ - MAX_JOBS=$MAX_JOBS PREBUILD_KERNELS=$PREBUILD_KERNELS GPU_ARCHS=$GPU_ARCH_LIST python3 setup.py develop + if [ "$ENABLE_CK" != "0" ]; then \ + git submodule sync && git submodule update --init --recursive; \ + fi && \ + ENABLE_CK=$ENABLE_CK MAX_JOBS=$MAX_JOBS PREBUILD_KERNELS=$PREBUILD_KERNELS \ + GPU_ARCHS=$GPU_ARCH_LIST python3 setup.py develop RUN pip show amd-aiter || true diff --git a/docker/Dockerfile.clean b/docker/Dockerfile.clean new file mode 100644 index 000000000..4a2d4dbc6 --- /dev/null +++ b/docker/Dockerfile.clean @@ -0,0 +1,69 @@ +# Dockerfile.clean — Wheel-only ATOM/AITER build (zero source compilation) +# +# Base: rocm/dev-ubuntu-24.04:7.2-complete (Python 3.12, full ROCm runtime) +# All packages installed from pre-built wheels — no git clones, no compiles. +# +# Option A — from pre-built wheels directory: +# cd /home/pensun/ATOM +# DOCKER_BUILDKIT=1 docker build \ +# --build-context wheels=/home/pensun/dist \ +# -f docker/Dockerfile.clean -t atom:clean . +# +# Option B — multi-stage from Dockerfile.wheels builder image: +# docker build -f docker/Dockerfile.wheels -t atom:wheels . +# DOCKER_BUILDKIT=1 docker build \ +# --build-context wheels=docker-image://atom:wheels \ +# -f docker/Dockerfile.clean -t atom:clean . +# +# Run: +# docker run --rm --device=/dev/kfd --device=/dev/dri \ +# -v /data2/models:/models atom:clean bash + +ARG BASE_IMAGE="rocm/dev-ubuntu-24.04:7.2-complete" +FROM ${BASE_IMAGE} + +ENV DEBIAN_FRONTEND=noninteractive + +# ── 1. 
System packages (minimal — no build tools needed) ───────────── +RUN apt-get update && apt-get install -y --no-install-recommends \ + git python3-pip python3-dev \ + ibverbs-utils libpci-dev locales \ + openmpi-bin libopenmpi-dev libdw1 \ + && rm -rf /var/lib/apt/lists/* + +RUN pip3 install --break-system-packages --ignore-installed pip setuptools wheel + +# ── 2. Install all pre-built wheels ──────────────────────────────────── +# Uses bind-mount to avoid a 60+ GB COPY layer from the wheels image. +# Works with both Option A (flat directory) and Option B (docker-image://). +RUN --mount=type=bind,from=wheels,source=/,target=/mnt/wheels \ + mkdir -p /tmp/wheels \ + && find /mnt/wheels -name '*.whl' -exec cp {} /tmp/wheels/ \; \ + && ls -lhS /tmp/wheels/*.whl \ + && pip3 install --break-system-packages --no-deps \ + /tmp/wheels/torch-*.whl \ + /tmp/wheels/torchvision-*.whl \ + /tmp/wheels/torchaudio-*.whl \ + /tmp/wheels/triton-*.whl \ + /tmp/wheels/triton_kernels-*.whl \ + && pip3 install --break-system-packages \ + filelock typing-extensions sympy networkx jinja2 fsspec numpy pillow \ + && pip3 install --break-system-packages \ + /tmp/wheels/mori-*.whl \ + /tmp/wheels/flydsl-*.whl \ + && pip3 install --break-system-packages \ + /tmp/wheels/amd_aiter-*.whl \ + && rm -rf /tmp/wheels \ + && python3 -c "import torch; print(f'PyTorch {torch.__version__}, ROCm: {torch.version.hip}')" \ + && python3 -c "import triton; print(f'Triton {triton.__version__}')" \ + && python3 -c "import aiter; print('AITER OK')" \ + && python3 -c "import flydsl; print('FlyDSL OK')" \ + && pip3 show mori && echo "MORI wheel installed OK" + +# ── 3. ATOM (from build context — pure Python, instant install) ────── +COPY . /app/ATOM +RUN cd /app/ATOM && pip3 install --break-system-packages -e . 
\ + && python3 -c "import atom; print('ATOM OK')" + +WORKDIR /app/ATOM +CMD ["/bin/bash"] diff --git a/docker/Dockerfile.wheels b/docker/Dockerfile.wheels new file mode 100644 index 000000000..e9da1a648 --- /dev/null +++ b/docker/Dockerfile.wheels @@ -0,0 +1,159 @@ +# Dockerfile.wheels — Build/download all wheels for ATOM clean install +# +# Produces /wheels/ containing: +# torch, torchvision, torchaudio (pulled from PyTorch nightly) +# triton 3.5.x (built from ROCm/triton source) +# triton_kernels (built from ROCm/triton source) +# flydsl (built from FlyDSL source + embedded MLIR runtime) +# mori (built from MORI source) +# amd_aiter (built with ENABLE_CK=0 + pre-compiled Triton kernels) +# +# Usage (standalone — extract wheels to host): +# docker build -f docker/Dockerfile.wheels -t atom:wheels . +# docker run --rm atom:wheels tar cf - /wheels | tar xf - -C /home/pensun/dist --strip-components=1 +# +# Usage (multi-stage — pipe directly into Dockerfile.clean): +# DOCKER_BUILDKIT=1 docker build \ +# --build-context wheels=docker-image://atom:wheels \ +# -f docker/Dockerfile.clean -t atom:clean . + +ARG BASE_IMAGE="rocm/dev-ubuntu-24.04:7.2-complete" +FROM ${BASE_IMAGE} + +ARG GPU_ARCH="gfx942;gfx950" +ARG AITER_REPO="https://github.com/sunway513/aiter.git" +ARG AITER_BRANCH="feat/prebuild-triton" +ARG FLYDSL_REPO="https://github.com/ROCm/FlyDSL.git" +ARG FLYDSL_BRANCH="main" +ARG LLVM_COMMIT="04f968b02917" +ARG MORI_REPO="https://github.com/ROCm/mori.git" +ARG MORI_COMMIT="b0dce4beebeb1f26c784eee17d5fd9785ee9447f" +ARG MAX_JOBS="" +ARG PREBUILD_TRITON=1 + +ENV GPU_ARCH_LIST=${GPU_ARCH} +ENV PYTORCH_ROCM_ARCH=${GPU_ARCH} +ENV DEBIAN_FRONTEND=noninteractive + +# ── 1. 
System packages + build tools ──────────────────────────────── +RUN apt-get update && apt-get install -y --no-install-recommends \ + git cmake ninja-build \ + python3-pip python3-dev python3-venv \ + ibverbs-utils libpci-dev locales \ + && rm -rf /var/lib/apt/lists/* + +RUN pip3 install --break-system-packages --ignore-installed \ + pip setuptools wheel build + +RUN mkdir -p /wheels + +# ── 2. Pull PyTorch ROCm 7.2 nightly wheels ───────────────────────── +RUN pip3 download --no-deps --dest /wheels \ + torch torchvision torchaudio \ + --index-url https://download.pytorch.org/whl/nightly/rocm7.2 + +# ── 3. Build Triton 3.5.x from ROCm fork ──────────────────────────── +RUN git clone --depth=1 --branch release/internal/3.5.x \ + https://github.com/ROCm/triton.git /build/triton + +RUN cd /build/triton \ + && pip3 install --break-system-packages -r python/requirements.txt \ + && pip3 install --break-system-packages filecheck \ + && MAX_JOBS=${MAX_JOBS:-64} pip3 wheel \ + --no-build-isolation --no-deps -w /wheels . \ + && ls -lh /wheels/triton-*.whl + +# Build triton_kernels wheel +RUN cd /build/triton/python/triton_kernels \ + && pip3 wheel --no-deps -w /wheels . \ + && ls -lh /wheels/triton_kernels-*.whl + +# ── 4. Build LLVM/MLIR for FlyDSL ─────────────────────────────────── +# Blobless clone (~6 min vs ~30 min full clone). LLVM_COMMIT rarely +# changes, so this layer stays cached across most rebuilds. 
+RUN pip3 install --break-system-packages nanobind numpy pybind11 + +RUN git clone --filter=blob:none --no-checkout \ + https://github.com/ROCm/llvm-project.git /build/llvm-project \ + && cd /build/llvm-project \ + && git fetch origin amd-staging \ + && git checkout ${LLVM_COMMIT} + +RUN mkdir -p /build/llvm-project/buildmlir \ + && cd /build/llvm-project/buildmlir \ + && NANOBIND_DIR=$(python3 -c "import nanobind; import os; print(os.path.dirname(nanobind.__file__) + '/cmake')") \ + && cmake -G Ninja \ + -S /build/llvm-project/llvm \ + -DLLVM_ENABLE_PROJECTS="mlir;clang" \ + -DLLVM_TARGETS_TO_BUILD="X86;NVPTX;AMDGPU" \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_CXX_STANDARD=17 \ + -DLLVM_ENABLE_ASSERTIONS=ON \ + -DLLVM_INSTALL_UTILS=ON \ + -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ + -DPython3_EXECUTABLE=$(which python3) \ + -Dnanobind_DIR="$NANOBIND_DIR" \ + -DBUILD_SHARED_LIBS=OFF \ + && cmake --build . -j$(nproc) \ + && cmake --install . --prefix /build/llvm-project/mlir_install + +# ── 5. Install torch + triton (needed for AITER/MORI builds) ──────── +RUN pip3 install --break-system-packages --no-deps \ + /wheels/torch-*.whl /wheels/triton-3.5*.whl \ + && pip3 install --break-system-packages \ + filelock typing-extensions sympy networkx jinja2 fsspec numpy + +# ── 6. Build FlyDSL wheel ─────────────────────────────────────────── +RUN git clone --depth=1 --branch ${FLYDSL_BRANCH} ${FLYDSL_REPO} /build/FlyDSL + +RUN cd /build/FlyDSL \ + && export MLIR_PATH=/build/llvm-project/mlir_install \ + && bash flir/build.sh \ + && export FLIR_IN_BUILD_SH=1 \ + && pip3 install --break-system-packages auditwheel patchelf \ + && python3 setup.py bdist_wheel \ + && cp dist/flydsl-*.whl /wheels/ \ + && ls -lh /wheels/flydsl-*.whl + +# ── 7. 
Build MORI wheel ───────────────────────────────────────────── +RUN apt-get update && apt-get install -y --no-install-recommends \ + openmpi-bin libopenmpi-dev cython3 libdw1 \ + && rm -rf /var/lib/apt/lists/* + +# Patch PyTorch's Caffe2Config.cmake: the ROCm nightly wheel's config +# hard-errors when CUDA toolkit is not found, even though we only need ROCm. +# Convert the fatal error to a warning so MORI (and other torch-cmake users) +# can build against the ROCm PyTorch wheel without CUDA installed. +RUN CAFFE2_CFG=$(python3 -c "import torch, pathlib; print(pathlib.Path(torch.__file__).parent / 'share/cmake/Caffe2/Caffe2Config.cmake')") \ + && sed -i 's/message(FATAL_ERROR "Your installed Caffe2 version uses CUDA/message(WARNING "Skipped: Your installed Caffe2 version uses CUDA/' "$CAFFE2_CFG" + +RUN git clone ${MORI_REPO} /build/mori \ + && cd /build/mori \ + && git checkout ${MORI_COMMIT} \ + && grep -iv '^torch\|^triton' requirements-build.txt \ + | pip3 install --break-system-packages -r /dev/stdin \ + && git submodule update --init --recursive \ + && pip3 wheel --no-build-isolation --no-deps -w /wheels . \ + && ls -lh /wheels/mori-*.whl + +# ── 8. Build AITER wheel (ENABLE_CK=0, pre-compiled Triton kernels) ── +RUN git clone --depth=1 --branch ${AITER_BRANCH} ${AITER_REPO} /build/aiter + +# Set AITER build env for this layer; prebuild archs follow GPU_ARCH (';' -> ',') +RUN cd /build/aiter \ + && pip3 install --break-system-packages -r requirements.txt \ + && export ENABLE_CK=0 PREBUILD_TRITON=${PREBUILD_TRITON} \ + PREBUILD_TRITON_ARCHS="$(echo "${GPU_ARCH_LIST}" | tr ';' ',')" \ + MAX_JOBS=${MAX_JOBS} GPU_ARCHS=${GPU_ARCH_LIST} \ + && pip3 install --break-system-packages --no-build-isolation -e . \ + && python3 -c "import aiter; print('editable install OK')" \ + && echo "install" > aiter/install_mode \ + && python3 setup.py bdist_wheel \ + && cp dist/amd_aiter-*.whl /wheels/ \ + && ls -lh /wheels/amd_aiter-*.whl + +# ── 9.
Summary ────────────────────────────────────────────────────── +RUN echo "=== Wheel inventory ===" && ls -lhS /wheels/*.whl && echo "=== Done ===" + +WORKDIR /wheels +CMD ["ls", "-lhS", "/wheels/"] diff --git a/tests/test_flydsl_moe.py b/tests/test_flydsl_moe.py new file mode 100644 index 000000000..6765a41a8 --- /dev/null +++ b/tests/test_flydsl_moe.py @@ -0,0 +1,285 @@ +#!/usr/bin/env python3 +"""Tests for FlyDSL MOE backend (atom/model_ops/flydsl_moe.py). + +Test tiers: + 1. Unit tests (CPU or GPU, no FlyDSL/AITER deps): sorting, quantization, detection + 2. Integration tests (GPU + FlyDSL): full flydsl_fp8_moe pipeline + +Run: + # Unit tests only (no GPU needed): + python3 tests/test_flydsl_moe.py --unit + + # Full GPU test (needs FlyDSL on PYTHONPATH + GPU): + ATOM_USE_FLYDSL_MOE=1 PYTHONPATH=/path/to/FlyDSL:$PYTHONPATH \ + python3 tests/test_flydsl_moe.py --gpu +""" + +import os +import sys +import argparse +import torch + +# Ensure ATOM root is on path +_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +if _REPO_ROOT not in sys.path: + sys.path.insert(0, _REPO_ROOT) + + +# --------------------------------------------------------------------------- +# Unit tests (CPU, or CUDA when available) +# --------------------------------------------------------------------------- +def test_detection_disabled(): + """_has_flydsl_moe returns False when env var is disabled.""" + import atom.model_ops.flydsl_moe as mod + + # Reset cached detection result (cleared again in finally) + mod._flydsl_moe_available = None + old = os.environ.get("ATOM_USE_FLYDSL_MOE") + os.environ["ATOM_USE_FLYDSL_MOE"] = "0" + try: + assert mod._has_flydsl_moe() is False, "Should be False when disabled" + finally: + mod._flydsl_moe_available = None + if old is None: + os.environ.pop("ATOM_USE_FLYDSL_MOE", None) + else: + os.environ["ATOM_USE_FLYDSL_MOE"] = old + print(" PASS: test_detection_disabled") + + +def test_sorting_basic(): + """Test moe_sorting_torch_native produces correct dtypes, shapes and encoding.""" + from
atom.model_ops.flydsl_moe import moe_sorting_torch_native + + M, topk, E, block_size = 8, 2, 4, 4 + device = "cuda" if torch.cuda.is_available() else "cpu" + + # Deterministic routing (seeded RNG) + torch.manual_seed(42) + topk_ids = torch.randint(0, E, (M, topk), device=device, dtype=torch.int32) + topk_weights = torch.rand(M, topk, device=device, dtype=torch.float32) + + sorted_ids, sorted_weights, sorted_expert_ids, num_tokens_post_pad = ( + moe_sorting_torch_native(topk_ids, topk_weights, E, block_size) + ) + + # Dtype and shape checks + assert sorted_ids.dtype == torch.int32 + assert sorted_weights.dtype == torch.float32 + assert sorted_expert_ids.dtype == torch.int32 + assert num_tokens_post_pad.shape == (2,) + assert num_tokens_post_pad[1].item() == M, "num_tokens should be M" + + total_padded = num_tokens_post_pad[0].item() + assert total_padded > 0 + assert total_padded % block_size == 0, "total should be block-aligned" + + # Verify packed encoding: (topk_slot << 24 | token_id) + init_val = (topk << 24) | M + for i in range(total_padded): + val = sorted_ids[i].item() + if val == init_val: + continue # padding sentinel + token_id = val & 0xFFFFFF + topk_slot = (val >> 24) & 0xFF + assert 0 <= token_id < M, f"token_id={token_id} out of range" + assert 0 <= topk_slot < topk, f"topk_slot={topk_slot} out of range" + + print(" PASS: test_sorting_basic") + + +def test_sorting_all_same_expert(): + """Edge case: all tokens routed to same expert.""" + from atom.model_ops.flydsl_moe import moe_sorting_torch_native + + M, topk, E, block_size = 16, 1, 8, 4 + device = "cuda" if torch.cuda.is_available() else "cpu" + + topk_ids = torch.zeros(M, topk, device=device, dtype=torch.int32) # all expert 0 + topk_weights = torch.ones(M, topk, device=device, dtype=torch.float32) + + sorted_ids, sorted_weights, sorted_expert_ids, num_tokens_post_pad = ( + moe_sorting_torch_native(topk_ids, topk_weights, E, block_size) + ) + + total_padded = num_tokens_post_pad[0].item() + expected_blocks = (M +
block_size - 1) // block_size + expected_padded = expected_blocks * block_size + assert ( + total_padded == expected_padded + ), f"Expected {expected_padded} padded tokens, got {total_padded}" + + # All expert_ids should be 0 for the used blocks + for i in range(expected_blocks): + assert sorted_expert_ids[i].item() == 0 + + print(" PASS: test_sorting_all_same_expert") + + +def test_pertoken_quant_fp8(): + """Test per-token FP8 quantization correctness.""" + from atom.model_ops.flydsl_moe import _pertoken_quant_fp8 + + if not hasattr(torch, "float8_e4m3fnuz"): + print(" SKIP: test_pertoken_quant_fp8 (torch version lacks fp8 support)") + return + + device = "cuda" if torch.cuda.is_available() else "cpu" + fp8_dtype = torch.float8_e4m3fnuz + + x = torch.randn(32, 128, device=device, dtype=torch.float32) + x_fp8, scale_1d = _pertoken_quant_fp8(x, fp8_dtype) + + assert x_fp8.shape == x.shape + assert x_fp8.dtype == fp8_dtype + assert scale_1d.shape == (32,), f"Expected [32], got {scale_1d.shape}" + assert scale_1d.dtype == torch.float32 + + # Dequantize and check approximate reconstruction + x_deq = x_fp8.to(torch.float32) * scale_1d.unsqueeze(1) + rel_err = (x_deq - x).abs() / (x.abs() + 1e-6) + mean_err = rel_err.mean().item() + assert mean_err < 0.2, f"Mean relative error too high: {mean_err}" + + print(f" PASS: test_pertoken_quant_fp8 (mean_rel_err={mean_err:.4f})") + + +def test_pertoken_quant_3d(): + """Test per-token FP8 quantization with 3D input.""" + from atom.model_ops.flydsl_moe import _pertoken_quant_fp8 + + if not hasattr(torch, "float8_e4m3fnuz"): + print(" SKIP: test_pertoken_quant_3d (torch version lacks fp8 support)") + return + + device = "cuda" if torch.cuda.is_available() else "cpu" + fp8_dtype = torch.float8_e4m3fnuz + + x = torch.randn(8, 2, 64, device=device, dtype=torch.float32) + x_fp8, scale_1d = _pertoken_quant_fp8(x, fp8_dtype) + + assert x_fp8.shape == x.shape + assert scale_1d.shape == (16,), f"Expected [16], got {scale_1d.shape}" + + 
print(" PASS: test_pertoken_quant_3d") + + +def run_unit_tests(): + print("=" * 60) + print("Unit tests (CPU/GPU, no FlyDSL dependency)") + print("=" * 60) + test_detection_disabled() + test_sorting_basic() + test_sorting_all_same_expert() + test_pertoken_quant_fp8() + test_pertoken_quant_3d() + print("\nAll unit tests passed!") + + +# --------------------------------------------------------------------------- +# GPU integration tests (requires FlyDSL + GPU) +# --------------------------------------------------------------------------- +def test_flydsl_fp8_moe_gpu(): + """Full end-to-end test: flydsl_fp8_moe on random FP8 weights.""" + from atom.model_ops.flydsl_moe import _has_flydsl_moe, flydsl_fp8_moe + + if not _has_flydsl_moe(): + print(" SKIP: FlyDSL MOE not available") + return + + device = torch.device("cuda") + torch.manual_seed(42) + + # Model dimensions (small for testing) + M = 64 # tokens + model_dim = 256 + inter_dim = 128 + E = 4 # experts + topk = 2 + + fp8_dtype = torch.float8_e4m3fnuz + + # Random float weights; quantized to FP8 and shuffled (if aiter imports) below + w13_fp32 = torch.randn(E, 2 * inter_dim, model_dim, device=device) * 0.1 + w2_fp32 = torch.randn(E, model_dim, inter_dim, device=device) * 0.1 + + # Quantize to FP8 with one scale per row (amax over the last dim) + fp8_max = torch.finfo(fp8_dtype).max + w13_amax = w13_fp32.abs().amax(dim=-1, keepdim=True) + w13_scale_full = (w13_amax / fp8_max).clamp(min=1e-12) + w13 = (w13_fp32 / w13_scale_full).clamp(-fp8_max, fp8_max).to(fp8_dtype) + + w2_amax = w2_fp32.abs().amax(dim=-1, keepdim=True) + w2_scale_full = (w2_amax / fp8_max).clamp(min=1e-12) + w2 = (w2_fp32 / w2_scale_full).clamp(-fp8_max, fp8_max).to(fp8_dtype) + + # Per-tensor scales [E]: max of the per-row scales (approximate dequant) + w13_scale = w13_scale_full.squeeze(-1).amax(dim=-1) # [E] + w2_scale = w2_scale_full.squeeze(-1).amax(dim=-1) # [E] + + # Shuffle weights (FlyDSL expects preshuffled) + try: + from aiter.ops.shuffle import shuffle_weight + + w13 = shuffle_weight(w13) + w2 = shuffle_weight(w2) + except ImportError: + print("
WARNING: aiter shuffle not available, using unshuffled weights") + + # Input and routing + x = torch.randn(M, model_dim, device=device, dtype=torch.bfloat16) + scores = torch.randn(M, E, device=device) + topk_vals, topk_ids = torch.topk(scores, k=topk, dim=1) + topk_weights = torch.softmax(topk_vals, dim=1).to(torch.float32) + + # Run FlyDSL MOE + out = flydsl_fp8_moe( + x=x, + w13=w13, + w2=w2, + topk_weights=topk_weights, + topk_ids=topk_ids, + w13_scale=w13_scale, + w2_scale=w2_scale, + top_k=topk, + block_quant=False, + quant_type=None, + ) + + assert out.shape == (M, model_dim), f"Expected ({M}, {model_dim}), got {out.shape}" + assert out.dtype == x.dtype + assert torch.isfinite(out).all(), "Output contains NaN/Inf" + + print( + f" PASS: test_flydsl_fp8_moe_gpu (output shape={out.shape}, " + f"mean={out.float().mean():.4f}, std={out.float().std():.4f})" + ) + + +def run_gpu_tests(): + print("=" * 60) + print("GPU integration tests (requires FlyDSL + GPU)") + print("=" * 60) + if not torch.cuda.is_available(): + print("SKIP: No GPU available") + return + test_flydsl_fp8_moe_gpu() + print("\nAll GPU tests passed!") + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--unit", action="store_true", help="Run unit tests only") + parser.add_argument("--gpu", action="store_true", help="Run GPU integration tests") + args = parser.parse_args() + + if not args.unit and not args.gpu: + args.unit = True # default to unit tests + + if args.unit: + run_unit_tests() + if args.gpu: + run_gpu_tests() diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 324c10a9c..48155c046 100644 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: MIT # Tests for atom/model_engine/scheduler.py — public API only +import numpy 
as np + from atom.model_engine.scheduler import Scheduler, ScheduledBatchOutput from atom.model_engine.sequence import SequenceStatus, SequenceType from atom.sampling_params import SamplingParams @@ -121,7 +123,9 @@ def _prefill(self, scheduler, seq): def _output(self, seq_id, tokens): return ScheduledBatchOutput( - token_ids={seq_id: tuple(tokens)}, draft_token_ids=None + token_ids={seq_id: tuple(tokens)}, + num_rejected=np.zeros(0, dtype=np.int32), + draft_token_ids=None, ) def test_appends_token(self, scheduler, seq_factory): @@ -166,7 +170,11 @@ def test_stop_token_ids(self, seq_factory): sched.schedule() finished = sched.postprocess( list(sched.running), - ScheduledBatchOutput(token_ids={seq.id: (99,)}, draft_token_ids=None), + ScheduledBatchOutput( + token_ids={seq.id: (99,)}, + num_rejected=np.zeros(0, dtype=np.int32), + draft_token_ids=None, + ), ) assert len(finished) == 1 assert "stop_99" in finished[0].leave_reason