
feat: multi-arch builds, env refactoring, new features, and expanded tests#307

Merged
Strivin0311 merged 99 commits into main from big-pr on Apr 8, 2026

Conversation

@littsk (Collaborator) commented Apr 1, 2026

Summary

Comprehensive update: multi-arch CUDA builds, env var centralisation, new features, API refactoring, and expanded tests.

Build System & Installation

  • Refactor setup.py for comma-separated compute capabilities (MAGI_ATTENTION_BUILD_COMPUTE_CAPABILITY=90,100)
  • Add scripts/install_on_scm.sh with ARCH env var for cross-platform wheel builds (aarch64/x86_64)
  • Add scm_setup.py, scripts/install_skip_all.sh
  • Update flash-attention submodule, remove runtime hotfixes from install_flash_attn_cute.sh

Env & API Refactoring

  • Centralise all MAGI_ATTENTION_* env var accessors into magi_attention/env/ package
  • Add MAGI_ATTENTION_LOG_LEVEL, common/protocols.py
  • Clean up public API exports, bump flake8/ruff/mypy
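The centralised env package described above can be pictured as a thin layer of typed accessors over os.environ; this is a minimal sketch, and the accessor name, validation rules, and default below are assumptions (the PR only names the MAGI_ATTENTION_LOG_LEVEL variable itself):

```python
import os

# Hypothetical accessor in the spirit of magi_attention/env/ -- the real
# module's function names are not shown in this PR.
_VALID_LEVELS = {"DEBUG", "INFO", "WARNING", "ERROR"}


def get_log_level(default: str = "INFO") -> str:
    """Read MAGI_ATTENTION_LOG_LEVEL with a fallback default (assumed behavior)."""
    level = os.environ.get("MAGI_ATTENTION_LOG_LEVEL", default).upper()
    return level if level in _VALID_LEVELS else default
```

Centralising lookups like this keeps parsing and validation in one place instead of scattering raw os.environ reads across the codebase.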

New Features

  • FA4 Backend: Blackwell GPU support via FFA_FA4
  • No-Overlap Path: Non-overlapped distributed attention execution
  • SDPA Online: Online softmax-based SDPA fallback kernel
  • Uneven Shard: Last chunk can be smaller, no virtual padding
  • Dispatch Refactor: Custom autograd Functions for lower memory
  • Distributed Roll: P2P cyclic shift for MTP, O(N/P) memory
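The "SDPA Online" item above refers to an online-softmax formulation; as a rough illustration of that technique (plain-Python lists, not the PR's actual GPU kernel), a blockwise softmax-weighted sum can be computed in one pass with O(block) memory by rescaling running statistics whenever a new maximum appears:

```python
import math


def online_softmax_sum(scores, values, block_size=2):
    """Blockwise softmax-weighted sum via the online-softmax rescaling trick.

    Illustrative sketch only: m is the running max, l the running softmax
    denominator, acc the running weighted numerator.
    """
    m, l, acc = float("-inf"), 0.0, 0.0
    for start in range(0, len(scores), block_size):
        s = scores[start:start + block_size]
        v = values[start:start + block_size]
        m_new = max(m, max(s))
        scale = math.exp(m - m_new)  # rescale running stats to the new max
        l = l * scale + sum(math.exp(si - m_new) for si in s)
        acc = acc * scale + sum(math.exp(si - m_new) * vi for si, vi in zip(s, v))
        m = m_new
    return acc / l
```

The result matches the direct (all-at-once) softmax-weighted sum, which is what makes the blockwise formulation usable as a fallback attention kernel.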

Refactoring

  • Merge csrc/utils/ into csrc/extensions/
  • Consolidate test_pipeline_sdpa.py into test_pipeline.py

Tests

  • New: test_dispatch.py, test_roll.py, test_protocol_conformance.py
  • Add dist_common.py testing utilities

Test Plan

  • pytest tests/ — all unit tests
  • pytest tests/test_pipeline.py — distributed tests
  • pytest tests/test_functional/ — dispatch and roll tests
  • Verify install_on_scm.sh runs correctly on x86_64/aarch64 clusters
  • Verify the FA4 backend (if Blackwell hardware is available)

littsk added 30 commits March 27, 2026 04:15
* Support Ampere with cutlass-based FFA_FA4 (#287)
* Update v1.1.0 overview (#285)
* Update v1.1.0 public blogs (#281)
* [HotFix]: fix proxy (#284)
* Add DSA attention interface in extensions (#283)

See merge request: !1
Support Ampere with cutlass-based FFA_FA4 (#287)

See merge request: !1
Refactor build system to accept comma-separated compute capabilities
via MAGI_ATTENTION_BUILD_COMPUTE_CAPABILITY (e.g. "90,100"). Add
helper functions parse_compute_capabilities, get_gencode_flags, and
resolve_build_capabilities in setup.py. Update CMakeLists.txt to
accept MAGI_CUDA_ARCHITECTURES and strip PyTorch-injected gencode
flags that may reference unsupported architectures.

Made-with: Cursor
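The helper names parse_compute_capabilities and get_gencode_flags come from the commit message above; the bodies below are an assumed sketch of how such helpers typically look (the validation details and exact flag format in the real setup.py may differ, though the -gencode syntax shown is standard nvcc):

```python
def parse_compute_capabilities(raw: str) -> list[str]:
    """Split a comma-separated spec like "90,100" into capability strings.

    Sketch only: real validation in setup.py may differ.
    """
    caps = [c.strip() for c in raw.split(",") if c.strip()]
    if not all(c.isdigit() for c in caps):
        raise ValueError(f"invalid compute capability spec: {raw!r}")
    return caps


def get_gencode_flags(caps: list[str]) -> list[str]:
    """Map capabilities to standard nvcc -gencode flags."""
    return [f"-gencode=arch=compute_{c},code=sm_{c}" for c in caps]
```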
Allow every query in a sample to always attend to the first
global_window_size key tokens in addition to the sliding window,
useful for architectures that require prefix tokens (e.g. sink tokens)
to be globally visible. Update docs with the new parameter.

Made-with: Cursor
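The visibility rule described above (sliding window plus an always-visible prefix of global_window_size key tokens) can be expressed as a simple predicate; this is a hypothetical illustration of the semantics, not code from the PR:

```python
def can_attend(q_idx: int, k_idx: int, window_size: int,
               global_window_size: int) -> bool:
    """Causal sliding-window visibility with an always-visible prefix.

    Hypothetical predicate: key k is visible to query q if it lies in the
    first global_window_size positions (e.g. sink tokens) or inside the
    trailing causal window of size window_size.
    """
    if k_idx < global_window_size:
        return True  # prefix/sink tokens are globally visible
    return q_idx - window_size < k_idx <= q_idx
```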
Introduce a P2P-based `roll` operation that cyclically shifts dispatched
local tensors along the sequence dimension without materialising the full
global tensor (O(N/P) memory instead of O(N)). Primarily designed for
Multi-Token Prediction (MTP) where labels are shifted relative to inputs.

- New `functional/roll.py` with `roll_p2p` implementation and autograd support
- Expose `roll` in public API (`magi_attention.api`)
- Clean up import paths: import `roll_func` directly from `functional.roll`
  instead of re-exporting through `functional.dispatch`
- Add `roll` section to API reference and quickstart docs
- Allow optional `num_heads_q/kv`, `head_dim` override in
  `make_flex_key_for_new_mask_after_dispatch`
- Add comprehensive tests (`tests/test_functional/test_roll.py`)

Made-with: Cursor
Consolidate the `ceil_div` helper into the shared utils module instead
of defining it locally in `api/functools.py`, so that meta/solver code
can reuse it without circular imports.

Made-with: Cursor
Replace the previous "virtual metadata padding" approach with a simpler
design where `total_seqlen` is used as-is (no padding at all):

- Use `ceil_div` for num_chunks so the last chunk can be smaller
- Remove `actual_total_seqlen_q/k` parameters from all interfaces
- MinHeapDispatchAlg now reports `is_equal_num_workloads=False` and
  uses `ceil_div` for the per-bucket job limit
- Simplify dispatch/undispatch: no zero-size virtual chunks, so
  `torch.split` works directly with `chunk_actual_sizes`
- Remove virtual padding logic from `magi_attn_flex_key` and
  `make_flex_key_for_new_mask_after_dispatch`

Made-with: Cursor
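The no-padding chunk layout above can be sketched in a few lines: ceil_div gives the chunk count, and the last chunk simply absorbs the remainder. The helper name chunk_actual_sizes is borrowed from the commit text; the real signature may differ:

```python
def ceil_div(a: int, b: int) -> int:
    """Ceiling division without floats."""
    return -(-a // b)


def chunk_actual_sizes(total_seqlen: int, chunk_size: int) -> list[int]:
    """Chunk layout with no virtual padding: the last chunk may be smaller.

    Sketch of the scheme described above -- torch.split can consume these
    sizes directly, with no zero-size virtual chunks.
    """
    num_chunks = ceil_div(total_seqlen, chunk_size)
    sizes = [chunk_size] * (num_chunks - 1)
    sizes.append(total_seqlen - chunk_size * (num_chunks - 1))
    return sizes
```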
Extract `_compute_segments` to handle source-segment calculation for
both uniform and variable (last-chunk-smaller) layouts. Refactor
`_roll_p2p_impl` to iterate segments generically, replacing the
previous special-case branches for r==0 and r>0.

Add comprehensive uneven-shard tests: aligned/non-aligned shifts,
cross-last-chunk wrapping, negative/large shifts, edge cases
(last_chunk_size=1), larger sequences, and backward correctness.

Made-with: Cursor
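The idea behind _compute_segments can be sketched as follows: for roll semantics dst[i] = src[(i - shift) % total], a rank's destination range is split into contiguous (src_rank, src_offset, length) triples that never cross a chunk boundary, so each triple can be served by a single P2P transfer even when the last chunk is smaller. The exact signature and return shape below are assumptions:

```python
def compute_segments(dst_lo, dst_hi, shift, total, chunk_size):
    """Map a destination index range to contiguous source segments.

    Sketch of the segment calculation for a cyclic roll over an
    unevenly-sharded sequence: yields (src_rank, src_offset, length).
    """
    segments = []
    i = dst_lo
    while i < dst_hi:
        src = (i - shift) % total
        rank = src // chunk_size
        rank_end = min((rank + 1) * chunk_size, total)  # last chunk may be short
        length = min(dst_hi - i, rank_end - src)
        segments.append((rank, src - rank * chunk_size, length))
        i += length
    return segments
```

For example, with total=10, chunk_size=3 (chunks of 3, 3, 3, 1) and shift=2, the destination range [3, 6) pulls two elements from rank 0 and one from rank 1.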
Remove virtual metadata padding logic from test_pipeline and
test_pipeline_sdpa: no longer need `compute_pad_size`/`apply_padding`
imports or `actual_total_seqlen_q/k` variables, since the uneven_shard
path now uses original total_seqlen directly.

Made-with: Cursor
…eqlens

Cache the hash of DistAttnRuntimeKey via __hash__ override to avoid
repeated hashing of all fields on every dict lookup. Also add lru_cache
to infer_attn_mask_from_cu_seqlens to skip redundant mask inference for
repeated cu_seqlens patterns.

Made-with: Cursor
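Both optimisations above follow standard patterns; here is a minimal sketch with made-up fields (the real DistAttnRuntimeKey and infer_attn_mask_from_cu_seqlens have different signatures):

```python
from functools import lru_cache


class RuntimeKey:
    """Sketch of the cached-__hash__ pattern; fields are hypothetical."""

    def __init__(self, cu_seqlens, num_heads):
        self.cu_seqlens = tuple(cu_seqlens)
        self.num_heads = num_heads
        # Hash all fields once at construction, not on every dict lookup.
        self._hash = hash((self.cu_seqlens, self.num_heads))

    def __hash__(self):
        return self._hash

    def __eq__(self, other):
        return (self.cu_seqlens, self.num_heads) == \
               (other.cu_seqlens, other.num_heads)


@lru_cache(maxsize=None)
def infer_mask_kind(cu_seqlens: tuple) -> str:
    """Stand-in for infer_attn_mask_from_cu_seqlens: memoised per pattern."""
    return "varlen" if len(cu_seqlens) > 2 else "dense"
```

Caching the hash is safe only because the fields are effectively immutable after construction, which also makes the lru_cache on the inference function valid.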
Add precise type hints for return types, parameters, and internal data
structures. Import DistAttnRuntimeMgr for proper typing and remove the
resolved TODO comment.

Made-with: Cursor
Replace the concat-all-then-scatter approach with custom autograd
Functions (_DispatchFunc / _UndispatchFunc). Forward dispatch now
selects local chunks directly (O(shard_seqlen) alloc) instead of
building a full permuted tensor (O(total_seqlen)). Backward uses
all_gather_v + unpermute, mirroring the inverse path.

Made-with: Cursor
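The forward-path change above can be illustrated with a toy list-based sketch: instead of materialising the fully permuted sequence and scattering it, each rank slices out only its own chunks. Names and shapes here are illustrative, not the PR's actual tensor code:

```python
def dispatch_forward(x, chunk_order, rank, chunks_per_rank, chunk_size):
    """Select only this rank's chunks: O(shard_seqlen) allocation.

    Toy sketch of the _DispatchFunc forward idea -- no full O(total_seqlen)
    permuted copy of x is ever built.
    """
    mine = chunk_order[rank * chunks_per_rank:(rank + 1) * chunks_per_rank]
    out = []
    for c in mine:
        out.extend(x[c * chunk_size:(c + 1) * chunk_size])
    return out
```

The backward pass would then gather each rank's gradient shard and invert the chunk permutation, mirroring this selection.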
Ensure partial_dsink is contiguous before communication in the backward
pass to avoid potential issues with non-contiguous tensor layouts.

Made-with: Cursor
- test_dispatch_solver: fix wrong assertTrue -> assertFalse for MinHeap
- test_gt_dispatcher: use Python AttnRanges for sub_mask comparisons to
  avoid C++/Python cross-type equality failure
- test_pipeline: add native_grpcoll invalidation rules for uneven_shard
  and small hidden_size_kv configs; pass num_heads/head_dim in test_config

Made-with: Cursor
@Strivin0311 (Contributor) commented:

@copilot review like proof-reading

Copilot AI (Contributor) commented Apr 3, 2026

@copilot review like proof-reading

Done — addressed in commit cfd7e00. Here's a summary of what was corrected across 13 files:

Typos in source code:

  • fa4.py: eanble → enable (×2), leanable → learnable

Grammar / phrasing ("by now" → "currently"):

  • sdpa_online.py (×2), env/general.py, dist_attn_runtime_mgr.py, comm/…/utils.py, tests/test_pipeline.py

Other grammar fixes:

  • overlap_solver.py: serves as as a → serves as a
  • dist_attn.py: can not be both True → cannot be both True
  • README.md: we ropose → we propose
  • docs/…/blackwell_ffa_fa4.md: To early support … NVIDIA team and us introduce → To provide early support for … NVIDIA team and we introduce
  • docs/…/env_variables.md: MagiAttention are … will to be released → MagiAttention is … will be released

The corresponding .po locale files were also updated to keep msgid strings in sync with the corrected source text.

Copilot AI requested a review from Strivin0311 April 3, 2026 03:57
@Strivin0311 (Contributor) left a comment


TOO much to review, forced to LGTM

…ile compatibility

einops.repeat hashes its axes_lengths kwargs internally, which fails
under torch.compile(dynamic=True) because SymInt is not hashable.

Made-with: Cursor
einops internally hashes tensor shapes for recipe caching, which is
incompatible with SymInt under torch.compile(dynamic=True). Replace
rearrange, reduce, and repeat with equivalent permute/sum/unsqueeze.

Made-with: Cursor
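The substitution described above can be pictured with NumPy for brevity (the PR applies the analogous torch ops): einops.repeat hashes its axes_lengths kwargs for recipe caching, while an explicit unsqueeze-and-broadcast uses no hashing at all, so it stays compatible with symbolic shapes:

```python
import numpy as np

# einops.repeat(x, "b n -> b n r", r=4) hashes its kwargs internally,
# which breaks when shapes are SymInt under torch.compile(dynamic=True).
# The same result comes from hash-free unsqueeze + broadcast primitives.
x = np.arange(6).reshape(2, 3)
y = np.broadcast_to(x[:, :, None], (2, 3, 4))  # each element repeated 4x
```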
@Strivin0311 (Contributor) left a comment


LGTM

@Strivin0311 Strivin0311 merged commit 36edcce into main Apr 8, 2026
8 of 9 checks passed
@Strivin0311 Strivin0311 deleted the big-pr branch April 8, 2026 07:28
