
Commit 36edcce

Authored by littsk, Copilot, and Strivin0311
feat: multi-arch builds, env refactoring, new features, and expanded tests (#307)
* Support Ampere with cutlass-based FFA_FA4 (#287)
* Update v1.1.0 overview (#285)
* Update v1.1.0 public blogs (#281)
* [HotFix]: fix proxy (#284)
* Add DSA attention interface in extensions (#283)

  See merge request: !1

* support multi-arch CUDA builds (Hopper + Blackwell)

  Refactor the build system to accept comma-separated compute capabilities via `MAGI_ATTENTION_BUILD_COMPUTE_CAPABILITY` (e.g. "90,100"). Add helper functions `parse_compute_capabilities`, `get_gencode_flags`, and `resolve_build_capabilities` in `setup.py`. Update `CMakeLists.txt` to accept `MAGI_CUDA_ARCHITECTURES` and strip PyTorch-injected gencode flags that may reference unsupported architectures.

  Made-with: Cursor

* add global_window_size to infer_attn_mask_from_cu_seqlens

  Allow every query in a sample to always attend to the first `global_window_size` key tokens in addition to the sliding window; useful for architectures that require prefix tokens (e.g. sink tokens) to be globally visible. Update docs with the new parameter.

  Made-with: Cursor

* add scm install script
* skip fa4_ffa_precompile
* add distributed roll API for MTP support

  Introduce a P2P-based `roll` operation that cyclically shifts dispatched local tensors along the sequence dimension without materialising the full global tensor (O(N/P) memory instead of O(N)). Primarily designed for Multi-Token Prediction (MTP), where labels are shifted relative to inputs.
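As an editorial sketch of the multi-arch build entry above: parsing a comma-separated capability list into per-architecture gencode flags could look roughly like this. The helper names mirror those mentioned in the commit message, but the bodies here are assumptions, not the project's actual `setup.py` code.

```python
import os


def parse_compute_capabilities(spec: str) -> list[str]:
    """Parse a comma-separated capability list like "90,100"."""
    caps = [c.strip() for c in spec.split(",") if c.strip()]
    for cap in caps:
        if not cap.isdigit():
            raise ValueError(f"invalid compute capability: {cap!r}")
    return caps


def get_gencode_flags(caps: list[str]) -> list[str]:
    # One -gencode pair per requested architecture; the real build may use
    # arch suffixes such as sm_90a for Hopper-specific features
    return [f"-gencode=arch=compute_{c},code=sm_{c}" for c in caps]


spec = os.environ.get("MAGI_ATTENTION_BUILD_COMPUTE_CAPABILITY", "90,100")
flags = get_gencode_flags(parse_compute_capabilities(spec))
```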
  - New `functional/roll.py` with `roll_p2p` implementation and autograd support
  - Expose `roll` in the public API (`magi_attention.api`)
  - Clean up import paths: import `roll_func` directly from `functional.roll` instead of re-exporting through `functional.dispatch`
  - Add a `roll` section to the API reference and quickstart docs
  - Allow optional `num_heads_q/kv`, `head_dim` override in `make_flex_key_for_new_mask_after_dispatch`
  - Add comprehensive tests (`tests/test_functional/test_roll.py`)

  Made-with: Cursor

* polish tests for roll
* dynamic pad token
* ceil div cleanup
* polish code
* support uneven shard
* refactor: move `ceil_div` to `magi_attention/utils/_utils.py`

  Consolidate the `ceil_div` helper into the shared utils module instead of defining it locally in `api/functools.py`, so that meta/solver code can reuse it without circular imports.

  Made-with: Cursor

* simplify uneven_shard: remove virtual padding, use real chunk sizes

  Replace the previous "virtual metadata padding" approach with a simpler design where `total_seqlen` is used as-is (no padding at all):

  - Use `ceil_div` for `num_chunks` so the last chunk can be smaller
  - Remove `actual_total_seqlen_q/k` parameters from all interfaces
  - `MinHeapDispatchAlg` now reports `is_equal_num_workloads=False` and uses `ceil_div` for the per-bucket job limit
  - Simplify dispatch/undispatch: no zero-size virtual chunks, so `torch.split` works directly with `chunk_actual_sizes`
  - Remove virtual padding logic from `magi_attn_flex_key` and `make_flex_key_for_new_mask_after_dispatch`

  Made-with: Cursor

* support variable chunk sizes in roll P2P

  Extract `_compute_segments` to handle source-segment calculation for both uniform and variable (last-chunk-smaller) layouts. Refactor `_roll_p2p_impl` to iterate segments generically, replacing the previous special-case branches for `r == 0` and `r > 0`.
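The uneven-shard sizing described above reduces to `ceil_div` plus a smaller final chunk. A minimal sketch (the `chunk_actual_sizes` helper below is an illustration, not the project's exact code):

```python
def ceil_div(a: int, b: int) -> int:
    """Ceiling division, e.g. ceil_div(10, 4) == 3."""
    return (a + b - 1) // b


def chunk_actual_sizes(total_seqlen: int, chunk_size: int) -> list[int]:
    # ceil_div lets the last chunk be smaller instead of padding total_seqlen,
    # so there are no zero-size virtual chunks to split around
    num_chunks = ceil_div(total_seqlen, chunk_size)
    sizes = [chunk_size] * num_chunks
    sizes[-1] = total_seqlen - chunk_size * (num_chunks - 1)
    return sizes
```

With `total_seqlen=10` and `chunk_size=4`, this yields `[4, 4, 2]`, which is directly usable by something like `torch.split`.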
  Add comprehensive uneven-shard tests: aligned/non-aligned shifts, cross-last-chunk wrapping, negative/large shifts, edge cases (`last_chunk_size=1`), larger sequences, and backward correctness.

  Made-with: Cursor

* update pipeline tests for simplified uneven_shard

  Remove virtual metadata padding logic from test_pipeline and test_pipeline_sdpa: there is no longer any need for `compute_pad_size`/`apply_padding` imports or `actual_total_seqlen_q/k` variables, since the uneven_shard path now uses the original `total_seqlen` directly.

  Made-with: Cursor

* fix rank error in roll
* add caching for DistAttnRuntimeKey hash and infer_attn_mask_from_cu_seqlens

  Cache the hash of `DistAttnRuntimeKey` via a `__hash__` override to avoid repeated hashing of all fields on every dict lookup. Also add `lru_cache` to `infer_attn_mask_from_cu_seqlens` to skip redundant mask inference for repeated `cu_seqlens` patterns.

  Made-with: Cursor

* improve type annotations for DistAttnRuntimeDictManager

  Add precise type hints for return types, parameters, and internal data structures. Import `DistAttnRuntimeMgr` for proper typing and remove the resolved TODO comment.

  Made-with: Cursor

* refactor dispatch/undispatch with autograd Functions to reduce memory

  Replace the concat-all-then-scatter approach with custom autograd Functions (`_DispatchFunc` / `_UndispatchFunc`). Forward dispatch now selects local chunks directly (O(shard_seqlen) alloc) instead of building a full permuted tensor (O(total_seqlen)). Backward uses `all_gather_v` + unpermute, mirroring the inverse path.

  Made-with: Cursor

* fix partial_dsink contiguity before backward communication

  Ensure `partial_dsink` is contiguous before communication in the backward pass to avoid potential issues with non-contiguous tensor layouts.
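The hash-caching idea from the commit above can be sketched as follows. `RuntimeKey` and `infer_mask` here are illustrative stand-ins, not the actual `DistAttnRuntimeKey` or `infer_attn_mask_from_cu_seqlens` definitions:

```python
from functools import lru_cache


class RuntimeKey:
    """Illustrative stand-in for DistAttnRuntimeKey (the real class has many fields)."""

    def __init__(self, total_seqlen_q: int, total_seqlen_k: int, chunk_size: int):
        self.fields = (total_seqlen_q, total_seqlen_k, chunk_size)
        self._cached_hash: int | None = None

    def __eq__(self, other: object) -> bool:
        return isinstance(other, RuntimeKey) and self.fields == other.fields

    def __hash__(self) -> int:
        # Hash all fields once, then reuse the result for every dict lookup
        if self._cached_hash is None:
            self._cached_hash = hash(self.fields)
        return self._cached_hash


@lru_cache(maxsize=128)
def infer_mask(cu_seqlens: tuple[int, ...]) -> tuple[tuple[int, int], ...]:
    # lru_cache skips recomputation for repeated cu_seqlens patterns;
    # the argument must be hashable, hence a tuple rather than a tensor
    return tuple(zip(cu_seqlens[:-1], cu_seqlens[1:]))
```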
  Made-with: Cursor

* polish code
* fix last chunk_size for uneven_shard
* fast build
* fast build
* install ffa build
* enhance dist runtime dict
* mem save ag
* simple p2p
* simple p2p
* logging
* reduce log
* sequential dispatch uneven
* fix seq bug
* add tests
* support fa4
* install scm
* fix install on no-gpu env
* fix
* fix install scm
* fix ffa fa4 bug
* collect all wheels
* fix build
* fix build
* install wheel
* fix install order
* switch flash-attention submodule to magi-flash-attention and remove build patches

  - Point the submodule to git@code.byted.org:seed/magi-flash-attention.git (main), with the namespace renamed to `magi_flash_attn_3` to avoid a TORCH_LIBRARY conflict
  - Remove `hopper_makefile_wrapper.mk` (wheel support is now handled in the collect step)
  - Remove `patch_create_block_mask.py` (should be upstreamed into the submodule)
  - Simplify `install_flash_attn_cute.sh` accordingly

  Made-with: Cursor

* update flash-attention submodule and clean up build scripts

  - Update the submodule to latest main (2c1b058), which includes:
    - Rename the torch library namespace from `flash_attn_3` to `magi_flash_attn_3`
    - Support headless build in the `create_block_mask` `setup.py`
    - Support `MAGI_WHEEL_DIR` in the hopper Makefile
  - Remove redundant hopper wheel collection from `install_flash_attn_cute.sh` (now handled by the upstream Makefile when `MAGI_WHEEL_DIR` is set)

  Made-with: Cursor

* no build ffa
* increase max func
* increase max func
* add set -e to install_flash_attn_cute.sh to fail fast on errors

  Made-with: Cursor

* no overlapped impl
* improve no_overlap path: pre-build merged_attn_arg, enhance logging, and fix test filtering

  - Pre-build `merged_attn_arg` in `CalcMeta.__post_init__` instead of computing it on every forward/backward call in the no_overlap path
  - Fix the `seqlen_k_local` calculation in `DistAttnSolver` to use `host_k_ranges_global`
  - Add detailed logging for `OverlapConfig`, `CalcMeta`, and `DistAttnRuntime`; move verbose `remote_attn_args` logging to DEBUG level
  - Add a `skip_if_world_size_filtered` decorator for proper
    subprocess-level skip instead of an early return inside the test body
  - Change the num_heads test filter to underscore-separated format (e.g. 8_8) and support tuple values in `should_run_test_case`

  Made-with: Cursor

* no overlap support fa4
* add sdpa_online, consolidate pipeline tests, dist_attn and solver updates

  Made-with: Cursor

* add is_partial_grad option to undispatch: use reduce_scatter in backward

  When `is_partial_grad=True`, the backward of undispatch uses `dist.reduce_scatter` to sum partial gradients across ranks before scattering, instead of simply selecting local chunks. This supports scenarios where each rank holds a partial gradient contribution (e.g. partial attention output gradients) that must be aggregated.

  The parameter is threaded through the full API stack: `undispatch_func` -> `DistAttnRuntimeMgr.undispatch_qo/kv` -> `undispatch()`

  Also adds unit tests covering forward round-trip, default backward, and partial-grad backward (both random and uniform) with even/uneven shards.

  Made-with: Cursor

* install ffa
* refactor: merge csrc/utils into extensions, unify C++/Python backend switching, and align interfaces

  - Merge `csrc/utils/` into `csrc/extensions/`, eliminating the separate `flexible_flash_attention_utils_cuda` module. All FFA utils (`argsort_ranges`, `unique_consecutive_pairs`, `compute_sparse_load_metadata`, etc.) are now part of `magi_attn_ext`, built via CMake.
  - Centralize C++/Python backend switching in `common/__init__.py` instead of scattered if-blocks at the bottom of each implementation file. Add Protocol definitions (`protocols.py`) to enforce interface alignment between backends.
  - Align all pybind11 bindings with the Python ground truth: fix parameter names, remove C++-only methods (`to_string`, `sort_ranges`, `reserve`, `clear`, `get_q/k/d_range`), remove `.export_values()` on `AttnMaskType`, add the missing `__iter__` on `AttnRectangles`, and add Google-style docstrings to all public methods on both sides.
  - Replace the `ffa_utils` alias with direct `magi_attn_ext` imports across all files.
  - Convert `test_common/` tests from `unittest.TestCase` to plain pytest classes with a `conftest.py` backend fixture (`params=["python", "cpp"]`) so every test automatically runs against both backends.
  - Regenerate `magi_attn_ext.pyi` with updated signatures and docstrings.

  Made-with: Cursor

* add readme
* refactor: centralise env vars into magi_attention/env/ package

  Move all `MAGI_ATTENTION_*` environment variable accessors from scattered locations (`__init__.py`, `comm/__init__.py`, `common/__init__.py`, `functional/*.py`, `common/jit/*.py`) into a dedicated `magi_attention/env/` package with three submodules:

  - `env/general.py` — runtime toggles, kernel backend, precision, etc.
  - `env/comm.py` — communication flags (hierarchical, qo_comm, etc.)
  - `env/build.py` — JIT/build settings (cache, workspace, nvcc, etc.)

  All ~100 call sites across `magi_attention/`, `tests/`, and `exps/` are updated to use the new `env.general.xxx()` / `env.comm.xxx()` style. The old `_env.py` is removed. The top-level `__init__.py` and `comm/__init__.py` no longer re-export env-var functions.

  Also adds `!magi_attention/env/` to `.gitignore` so the package is not caught by the `env/` virtualenv exclusion rule.

  Made-with: Cursor

* clear api
* remove redundant __all__ from magi_attn_interface.py

  Public exports are managed solely by `api/__init__.py`; the per-module `__all__` duplicated that responsibility and added maintenance burden.
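The centralised env accessors described above follow a simple pattern: one small typed function per `MAGI_ATTENTION_*` variable. A sketch of what a submodule like `env/general.py` might contain (the accessor names and parsing rules below are invented for illustration):

```python
import os


def _env_flag(name: str, default: str = "0") -> bool:
    # "1"/"true"/"on" (case-insensitive) count as enabled; parsing rule assumed
    return os.environ.get(name, default).lower() in ("1", "true", "on")


def is_deterministic() -> bool:
    # hypothetical accessor; the real ones live in magi_attention/env/general.py
    return _env_flag("MAGI_ATTENTION_DETERMINISTIC")


def log_level() -> str:
    # DEBUG/INFO/WARN/ERROR/CRITICAL, defaulting to WARN per the commit below
    return os.environ.get("MAGI_ATTENTION_LOG_LEVEL", "WARN").upper()
```

The benefit over scattered `os.environ` reads is that every env var has exactly one jump-to-able accessor, which also matches the readability philosophy introduced later in this commit.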
  Made-with: Cursor

* feat: add MAGI_ATTENTION_LOG_LEVEL env var to control package-wide logging

  - Add a `log_level()` helper in `env/general.py` supporting DEBUG/INFO/WARN/ERROR/CRITICAL (default: WARN)
  - Configure the root `magi_attention` logger at import time based on the env var
  - Replace the custom `MagiAttentionJITLogger` with the standard `getLogger(__name__)` so the JIT logger participates in the `magi_attention` logger hierarchy
  - Add INFO-level logging throughout the JIT build pipeline (`core.py`, `_flex_flash_attn_jit.py`)

  Made-with: Cursor

* support no_chunk_size
* minor fix
* fix: resolve pre-commit lint errors for Python 3.10+ match syntax

  - Upgrade flake8 6.1.0 -> 7.3.0 (pyflakes 3.4+ with match support)
  - Upgrade ruff v0.1.5 -> v0.11.4, add `--target-version=py310`
  - Add `--python-version=3.10` to mypy args
  - Add `# noqa: F811` for intentional conditional re-imports in `common/__init__.py`
  - Add `# noqa: E402` for necessary non-top-level imports
  - Add `# noqa: F824` for read-only global/nonlocal declarations

  Made-with: Cursor

* update submodule
* fix tests
* patch fix
* fix cu131
* add chinese docs
* log env
* scm install fix
* patch fix
* fix scm
* merge main
* lint
* chore: point flash-attention submodule at littsk fork, drop install hotfixes

  Use https://github.com/littsk/flexible-flash-attention on branch magi_attn_blackwell_support; the multi-arch create_block_mask gencode and hopper Makefile build_ext now live in the submodule. Remove redundant runtime patches from `install_flash_attn_cute.sh`.

  Made-with: Cursor

* docs: update MAGI_ATTENTION_BUILD_COMPUTE_CAPABILITY description

  Reflect that this env variable now also affects create_block_mask builds, and document comma-separated multi-arch support (e.g. 90,100).

  Made-with: Cursor

* chore: bump flash-attention submodule (platform tag fix)

  Picks up ce387e5, which fixes `get_platform()` in `hopper/setup.py` to use `platform.machine()` instead of a hardcoded x86_64.
  Made-with: Cursor

* feat: support CUSTOM_ARCH for cross-platform wheel builds

  Detect the host CPU architecture from the `CUSTOM_ARCH` env var (defaults to `uname -m`) and derive `MAGI_WHEEL_PLAT_NAME` (e.g. linux_aarch64). Pass `--plat-name` to all bdist_wheel / pip wheel invocations so that sub-package wheels (create_block_mask_cuda, magi_to_hstu_cuda, ffa_fa3) and the main magi_attention wheel carry the correct platform tag.

  Made-with: Cursor

* fix: use setup.py bdist_wheel for --plat-name instead of pip --build-option

  pip wheel dropped `--build-option` support in newer versions, causing "no such option" errors on SCM builds. Switch all sub-package wheel builds to `python setup.py bdist_wheel --plat-name` plus a cp to the wheel dir.

  Made-with: Cursor

* fix: FA4 mask tile size resolution and sink+FA4 infinite loop in tests

  1. `_resolve_tile_sizes` / `_resolve_fa4_tile_sizes` now return (128, 128) on SM100+, since the `tile_m`/`tile_n` in `FA4AttnArg` represent the mask block tile (not the kernel tile). The kernel internally doubles `tile_m` via `sparse_tile_m` in `_make_fa4_args_dict`. Only SM80/SM90 need to query `get_tile_sizes_by_backend` for headdim-dependent tiles.
  2. Improved the SM100 tile size validation error message in `FA4AttnArg` to include actual vs expected values and the `sparse_tile_m` note.
  3. Added BACKENDS (excluding FA4) to all 6 sink-bearing attn_configs in test_pipeline.py. FA4 does not support sink, so these configs would cause `get_next_valid_comb` to loop forever (all flag combos rejected by `_is_valid_flag_comb` while the generator never exhausts due to `cycle_times=-1`).

  Made-with: Cursor

* chore: point flash-attention submodule back to demonatic upstream

  PR #9 merged all changes into demonatic/flash-attention; switch the submodule URL back from the littsk fork and update the pointer to the merge commit.
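The CUSTOM_ARCH-to-platform-tag derivation mentioned above can be sketched in a few lines. The `wheel_plat_name` helper is an assumption made for this example; only the env var name and the `uname -m` default come from the commit message:

```python
import os
import platform


def wheel_plat_name() -> str:
    # CUSTOM_ARCH overrides the detected machine (e.g. for cross builds);
    # the default mirrors `uname -m` via platform.machine()
    arch = os.environ.get("CUSTOM_ARCH", platform.machine())
    return f"linux_{arch}"
```

The resulting tag (e.g. `linux_aarch64`) would then be passed as `--plat-name` to each `bdist_wheel` invocation so all sub-package wheels agree on the platform.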
  Made-with: Cursor

* lint
* fix ut
* more clear code
* fix ci
* fix: resolve pre-commit lint failures (black, flake8, ruff)

  Made-with: Cursor

* fix: resolve CI test failures (dispatch, submask, pipeline alignment)

  - test_dispatch_solver: fix a wrong assertTrue -> assertFalse for MinHeap
  - test_gt_dispatcher: use Python `AttnRanges` for sub_mask comparisons to avoid a C++/Python cross-type equality failure
  - test_pipeline: add native_grpcoll invalidation rules for uneven_shard and small hidden_size_kv configs; pass num_heads/head_dim in test_config

  Made-with: Cursor

* fix: proof-reading corrections (typos, grammar, and phrasing)

  Agent-Logs-Url: https://github.com/SandAI-org/MagiAttention/sessions/32611ac4-4154-4596-b276-d3f6d07fdf05

  Co-authored-by: Strivin0311 <61719042+Strivin0311@users.noreply.github.com>

* increase timeout
* fix ci
* fix: replace einops.repeat with native ops in sink_bwd for torch.compile compatibility

  einops.repeat hashes its axes_lengths kwargs internally, which fails under `torch.compile(dynamic=True)` because SymInt is not hashable.

  Made-with: Cursor

* fix: replace all einops calls in sink_bwd with native PyTorch ops

  einops internally hashes tensor shapes for recipe caching, which is incompatible with SymInt under `torch.compile(dynamic=True)`. Replace `rearrange`, `reduce`, and `repeat` with equivalent `permute`/`sum`/`unsqueeze` calls.

  Made-with: Cursor

* fix max logits dtype error

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: Strivin0311 <61719042+Strivin0311@users.noreply.github.com>
1 parent 1816199 commit 36edcce

156 files changed

Lines changed: 20865 additions & 4198 deletions

Lines changed: 110 additions & 0 deletions
@@ -0,0 +1,110 @@
---
name: debug-test-failures
description: >-
  Systematic approach to debug failing tests. Use when a user reports test
  failures, regression bugs, or assertion mismatches. Core method: reproduce
  the failure, then use git history to determine whether the failure is caused
  by a recent commit or by the user's uncommitted changes.
---

# Debug Test Failures

## Principle

**Test failures always have a cause in code changes.** The fastest debugging
path is to figure out *which change* broke it, not to guess at the logic.

## Step 1: Reproduce and Extract Key Info

Run the failing test and capture the full output:

```bash
python -m pytest -sq <test_file> 2>&1 | head -80
python -m pytest -sq <test_file> 2>&1 | tail -50
```

Extract from the error:
- Which test case(s) failed
- The **actual vs expected** values
- The **file and line** of the failing assertion

## Step 2: Determine the Scope of Failure

Check if this test *ever* passed on the current branch:

```bash
# What files have uncommitted changes?
git status

# What committed changes touch files related to the failure?
git log --oneline -20 -- <relevant_source_files>
```

This splits into two cases:

### Case A: The user has uncommitted changes in related files

```bash
git diff -- <relevant_source_files>
```

Read the diff carefully. The bug is likely in the uncommitted changes.
Compare the diff against the assertion error to find the mismatch.

### Case B: No uncommitted changes in related files

The regression was introduced by a recent commit. Proceed to Step 3.

## Step 3: Walk Git History to Find the Offending Commit

```bash
# List recent commits touching the relevant files
git log --oneline --all -- <file_path>
```

Then inspect each suspect commit:

```bash
git show <commit_hash> -- <file_path>
```

Walk commits **chronologically** and identify:
1. **Last known good state** — what the logic looked like before
2. **Offending commit** — where behavior changed
3. **Intent** — was the change a refactor, feature, or bugfix?

Common regression patterns:
- **Refactoring that widens a condition** — e.g. merging two flags into one,
  where the new condition covers more cases than intended
- **Default value changes** — a dataclass/config default was changed, silently
  affecting callers that relied on the old default
- **Silent override in initialization** — `__init__` / `__post_init__` /
  constructor overwrites a user-provided value under a too-broad condition
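The third regression pattern can be reproduced in miniature. The `Config` class below is invented purely for illustration, showing how an over-broad `__post_init__` condition silently overrides a user-provided value:

```python
from dataclasses import dataclass


@dataclass
class Config:
    use_fast_path: bool = True
    batch_size: int = 8

    def __post_init__(self):
        # Offending change: the condition is too broad. It was meant to
        # disable the fast path only for batch_size == 1, but it now
        # overrides the user's explicit choice for every small batch.
        if self.batch_size < 4:
            self.use_fast_path = False
```

A test that asserts `Config(use_fast_path=True, batch_size=2).use_fast_path` is `True` would fail here, and `git show` on the commit that widened the condition would reveal exactly this kind of change.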
## Step 4: Confirm Root Cause by Comparing Before/After

Once you identify the suspect commit, compare the old and new logic side by
side. Verify that:
- The old logic would produce the **expected** test output
- The new logic produces the **actual** (wrong) test output
- The behavioral difference is **unintentional** (not a deliberate design change)

## Step 5: Apply Minimal Fix and Verify

1. Make the **smallest change** that restores correct behavior
2. Ensure the original intent of the offending commit is preserved
3. Re-run the failing test to confirm all cases pass

```bash
python -m pytest -sq <test_file> 2>&1 | tail -5
```

## Anti-Patterns

- **Don't update expected values** to match broken output without understanding
  why they differ — the tests encode domain knowledge.
- **Don't guess at the fix** without tracing the cause — always check git
  history first.
- **Don't ignore the "other case"** — if the fix narrows a condition, verify
  that the broader case (which the offending commit intended to handle) still
  works correctly.
Lines changed: 183 additions & 0 deletions
@@ -0,0 +1,183 @@
---
name: magi-code-philosophy
description: >-
  Core engineering philosophies for the Magi Attention codebase. Use when
  writing, reviewing, or modifying any code in this project. Covers config
  consistency, static-analysis-friendly readability, and test coverage
  requirements. Any deviation from these philosophies must be strictly
  commented with justification.
---

# Magi Code Philosophy

These are the non-negotiable engineering principles for this codebase.
Every contributor — human or AI — must follow them. When a principle
cannot be followed, a **DEVIATION comment** is required (see bottom).

---

## Philosophy 1: Config Consistency

> **A config field's value should mean what the user set it to.**

If internal logic must transform, override, or reinterpret a user-supplied
value, this is a deviation that requires explicit justification.

### Rules

1. **Preserve user intent** — When a user sets `field=X`, reading `obj.field`
   should return `X` or something recognizably equivalent. If the value must
   be normalized, store the original intent in a private field or property
   before overwriting.

2. **Validate early, normalize minimally** — `__post_init__` should primarily
   assert invariants. Normalization should be the smallest necessary
   adjustment. Never silently clamp or discard user input.

3. **Derived fields ≠ overwritten fields** — Values computed from other
   fields should be **new** fields, not overwrites of user input. Exception:
   sentinel values (e.g., `-1` = "auto-detect") are designed to be replaced,
   but still require a comment.

4. **In-place mutation must be documented** — If `__post_init__` or a helper
   mutates a nested structure (e.g., `num_tokens *= 2` for packed KV), the
   site must comment: what is mutated, why, and how to recover the original.

5. **Forced override needs justification** — When code forces a field value
   from another field (e.g., `deterministic |= reduce_op != "sum"`), explain
   why the user's choice is being overridden.

### Deviation Format (Config)

```python
# DEVIATION: <one-line summary>
# Reason: <why the user-facing value cannot be kept as-is>
# Recovery: <how to access original intent, or "none">
```
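A minimal sketch of Rules 1 through 3 together with the deviation format (the `TileConfig` class is hypothetical, not a class from this codebase):

```python
from dataclasses import dataclass


@dataclass
class TileConfig:
    # degree: user-facing; -1 is a documented sentinel meaning "auto-detect"
    degree: int = -1

    def __post_init__(self):
        # Validate early: assert invariants before any normalization
        if self.degree < -1 or self.degree == 0:
            raise ValueError(f"degree must be -1 (auto) or positive, got {self.degree}")
        # Preserve user intent in a private field before overwriting
        self._requested_degree = self.degree
        # DEVIATION: sentinel -1 is replaced by a detected value
        # Reason: downstream code needs a concrete positive degree
        # Recovery: original intent kept in self._requested_degree
        if self.degree == -1:
            self.degree = 4  # stand-in for real auto-detection logic
```

After construction, `obj.degree` holds the normalized value while `obj._requested_degree` still records what the user asked for, so the override is both justified and recoverable.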
---

## Philosophy 2: Readability via Static Navigability

> **Every symbol in the code must be statically resolvable and jump-to-able.**

Code is read far more often than written. The reader should be able to
Ctrl+Click (or equivalent) on any name and land on its definition. If the
IDE's static analysis cannot resolve a symbol, the code is not readable
enough.

### Rules

1. **Explicit imports over dynamic lookups** — Use direct imports, not
   `getattr(module, name)` or `globals()[name]`. If dynamic dispatch is
   truly needed, use a typed registry/dict with the concrete types visible
   at the registration site.

2. **Typed dicts and enums over magic strings** — Prefer `Enum` members
   and `TypedDict` keys over raw string literals. Strings are invisible to
   static analysis; enums and typed keys are jump-to-able.

3. **No untyped `**kwargs` pass-through in public APIs** — Public-facing
   functions should declare their parameters explicitly. `**kwargs` may be
   used internally (e.g., forwarding to a backend), but the public signature
   must be self-documenting.

4. **Avoid deep `Any` typing** — `Any` kills jump-to-definition. Use
   `Protocol`, generics, or union types. Reserve `Any` for truly
   polymorphic boundaries (e.g., serialization).

5. **String-based dispatch must have a central map** — If behavior branches
   on a string value, define the mapping in one place (a dict or match/case)
   so that all targets are visible together and searchable.

6. **Re-exports must be explicit** — When `__init__.py` re-exports symbols,
   use explicit `from .module import Name` rather than `import module` with
   `__all__`. This ensures the IDE can resolve the re-exported name.

### Deviation Format (Readability)

```python
# DEVIATION: <what is not statically resolvable>
# Reason: <why dynamic dispatch / Any / kwargs is unavoidable here>
# Mitigation: <how a reader can still find the target, e.g., "see registry at X">
```
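A minimal sketch of Rules 2 and 5 combined: an `Enum` replaces the magic string, and all dispatch targets live in one central, searchable map. The `Backend` names are invented for the example:

```python
from enum import Enum
from typing import Callable


class Backend(Enum):
    PYTHON = "python"
    CPP = "cpp"


def _run_python() -> str:
    return "python backend"


def _run_cpp() -> str:
    return "cpp backend"


# Central, visible map: every dispatch target is registered at one site,
# and each value is a plain function the IDE can jump to
_BACKEND_DISPATCH: dict[Backend, Callable[[], str]] = {
    Backend.PYTHON: _run_python,
    Backend.CPP: _run_cpp,
}


def run(backend: Backend) -> str:
    return _BACKEND_DISPATCH[backend]()
```

Compared with `getattr(module, f"_run_{name}")`, this keeps every target Ctrl+Click navigable and makes a missing backend a visible `KeyError` at the single registration site.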
---

## Philosophy 3: Test Completeness

> **Where there is code, there must be tests.**

No feature, bug fix, refactor, or config change is considered done until
it has corresponding test coverage. Untested code is assumed broken.

### Rules

1. **Every public function/class has a test** — If it's importable from
   outside its module, it needs at least one test exercising its primary
   path and one test for its most important edge case.

2. **Config normalization must be tested** — For every `__post_init__`
   normalization or deviation, there must be a test that:
   - Sets the user-facing value
   - Asserts the normalized internal value
   - Asserts the original intent is recoverable (if applicable)

3. **Bug fixes come with regression tests** — The test must reproduce the
   original failure first (red), then pass with the fix (green).

4. **Solver/algorithm changes need correctness tests** — Any change to a
   solver (`overlap_solver`, `dispatch_solver`, `dist_attn_solver`, etc.)
   must include tests that verify the output solution against known-good
   reference values.

5. **Test names describe the scenario** — Use descriptive names like
   `test_overlap_config_degree_zero_normalizes_to_one`, not `test_config_1`.
   The name should read as a specification.

6. **No test-only code in production modules** — Test helpers, fixtures, and
   mocks live in `tests/`. Production code should not contain `if TESTING:`
   branches or similar.

### Deviation Format (Test)

```
# DEVIATION: <what is not tested>
# Reason: <why testing is impractical, e.g., requires multi-GPU hardware>
# Tracking: <issue/TODO reference for future coverage>
```

---

## General Deviation Protocol

When **any** philosophy cannot be followed, add a structured comment at the
deviation site. The format depends on the philosophy (see each section
above). The key invariant is:

> **Silence is not acceptable. If the code deviates, the code says so.**

Reviewers (human or AI) should flag any deviation that lacks a comment as a
blocking issue.

---

## Quick Checklist

Before submitting code, verify:

**Config Consistency**
- [ ] Every `__post_init__` field overwrite has a DEVIATION comment
- [ ] User intent is recoverable via private field or property
- [ ] Sentinel values are documented in the class docstring

**Readability**
- [ ] All symbols are Ctrl+Click navigable (no unresolvable dynamic lookups)
- [ ] Public APIs have explicit typed signatures (no bare `**kwargs`)
- [ ] String-based dispatch has a central, visible mapping

**Tests**
- [ ] Every new/changed public API has corresponding tests
- [ ] Config normalizations have dedicated test cases
- [ ] Bug fixes include a regression test
- [ ] Test names describe the scenario, not just a number

.flake8

Lines changed: 3 additions & 1 deletion
@@ -8,4 +8,6 @@ ignore =
     E203
 exclude =
     # Exclude Python interface files
-    *.pyi
+    *.pyi
+    # Exclude translation script (contains long i18n strings)
+    scripts/translate_po.py

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -148,6 +148,7 @@ celerybeat.pid
 .envrc
 .venv
 env/
+!magi_attention/env/
 venv/
 ENV/
 env.bak/

.pre-commit-config.yaml

Lines changed: 5 additions & 4 deletions
@@ -19,6 +19,7 @@ repos:
       pass_filenames: true
       always_run: true
       files: \.(txt|md|yaml|c|cc|cxx|cpp|cu|cuh|h|hpp|hxx|proto|py|pyi|sh)$
+      exclude: '(_zh\.\w+$|scripts/translate_po\.py|docs/source/conf\.py)'
   - id: csrc_code_formatter
     name: check for csrc code format
     entry: bash scripts/run_csrc_code_formatter.sh
@@ -58,15 +59,15 @@ repos:
   - id: black
     files: (.*\.(py|pyi|bzl)|BUILD|.*\.BUILD|WORKSPACE)$
 - repo: https://github.com/PyCQA/flake8
-  rev: 6.1.0
+  rev: 7.3.0
   hooks:
   - id: flake8
     args: ["--config=.flake8"]
 - repo: https://github.com/astral-sh/ruff-pre-commit
-  rev: v0.1.5
+  rev: v0.11.4
   hooks:
   - id: ruff
-    args: [--fix, --exit-non-zero-on-fix, --no-cache]
+    args: [--fix, --exit-non-zero-on-fix, --no-cache, --target-version=py310]
 - repo: https://github.com/pre-commit/mirrors-isort
   rev: v5.10.1
   hooks:
@@ -77,4 +78,4 @@ repos:
   hooks:
   - id: mypy
     files: \.py$
-    args: [--config=mypy.ini, --ignore-missing-imports]
+    args: [--config=mypy.ini, --ignore-missing-imports, --python-version=3.10]

MANIFEST.in

Lines changed: 1 addition & 3 deletions
@@ -4,10 +4,8 @@ include LICENSE

 # Only include source code under csrc (runtime JIT/extension needs)
 recursive-include magi_attention/csrc/common *.h *.hpp
-recursive-include magi_attention/csrc/extensions *.hpp *.cpp
+recursive-include magi_attention/csrc/extensions *.hpp *.cpp *.cu *.cuh *.h
 recursive-include magi_attention/csrc/flexible_flash_attention *.h *.hpp *.cuh *.cu *.cpp *.jinja *.py
-recursive-include magi_attention/csrc/utils *.cpp *.cu
-
 # Cutlass: keep only headers under include/
 prune magi_attention/csrc/cutlass
 graft magi_attention/csrc/cutlass/include

README.md

Lines changed: 1 addition & 1 deletion
@@ -63,7 +63,7 @@ To achieve linear scalability in distributed attention, we implemented the follo

 - **Flexible Flash Attention Kernel**. We introduce a generalized attention mask formulation namely `AttnSlice` with a tailed kernel<em>Flex‑Flash‑Attention (FFA)</em>—natively designed to enable compact expression of diverse mask types and make distributed mask partitioning tractable, with performance comparable to [Flash-Attention 3](https://arxiv.org/abs/2407.08608) on Hopper GPUs, and preliminary support for Blackwell via a forked [Flash-Attention 4](https://github.com/demonatic/flash-attention/tree/magi_attn_blackwell_support).
 - **Computation Load Balancing**. With a fine-grained chunk‑level sharding strategy, we elaborate an efficient <em>dispatch solver</em> that ensures balanced computational workloads across each CP rank.
-- **Zero-Redundant Communication**. Instead of adopting the common Ring-style P2P communication pattern, we ropose two novel communication primitives, <em>GroupCast</em> and <em>GroupReduce</em>, realizing zero-redundant communication volume for both forward and backward passes.
+- **Zero-Redundant Communication**. Instead of adopting the common Ring-style P2P communication pattern, we propose two novel communication primitives, <em>GroupCast</em> and <em>GroupReduce</em>, realizing zero-redundant communication volume for both forward and backward passes.
 - **Adaptive Multi-Stage Overlap**. Leveraging the above enhancements, we further implement an adaptive multi-stage overlap strategy that schedules computation and communication to effectively hide latency and maximize utilization via either manual or automatic tuning.

 If you are interested in the detailed methodology and implementation, please check our [blog](https://SandAI-org.github.io/MagiAttention/docs/main/blog/magi_attn.html#methodology) for more information.

conftest.py

Lines changed: 11 additions & 0 deletions
@@ -12,18 +12,29 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import os
+
 import pytest


 def pytest_addoption(parser):
     parser.addoption(
         "--skip-slow", action="store_true", default=False, help="skip slow tests"
     )
+    parser.addoption(
+        "--test-attn-config",
+        default=None,
+        help="comma-separated attn_config names to run (supports fnmatch wildcards)",
+    )


 def pytest_configure(config):
     config.addinivalue_line("markers", "slow: marks a test as slow to run")

+    attn_config_filter = config.getoption("--test-attn-config", default=None)
+    if attn_config_filter is not None:
+        os.environ["MAGI_ATTENTION_TEST_ATTN_CONFIG"] = attn_config_filter
+

 def pytest_collection_modifyitems(config, items):
     if config.getoption("--skip-slow"):
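The tests that consume `MAGI_ATTENTION_TEST_ATTN_CONFIG` are not shown in this view. Under the assumption that the filter helper matches config names against the comma-separated patterns with `fnmatch` (as the option's help text suggests), the consuming side might look roughly like this:

```python
import fnmatch
import os


def should_run_test_case(attn_config_name: str) -> bool:
    # No filter set: run every attn_config
    spec = os.environ.get("MAGI_ATTENTION_TEST_ATTN_CONFIG")
    if spec is None:
        return True
    patterns = [p.strip() for p in spec.split(",") if p.strip()]
    # Any matching pattern (fnmatch wildcards allowed) selects the case
    return any(fnmatch.fnmatch(attn_config_name, p) for p in patterns)
```

For example, `pytest --test-attn-config "sliding_*"` would then run only the configs whose names start with `sliding_`.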
