Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
1ec0a7b
removed useless argument from exps/grpcoll test
Strivin0311 Jan 26, 2026
ee86ea0
updated generate_inst script to use re to extract the kernel function…
Strivin0311 Jan 26, 2026
b0b835b
speed up magi_attn_comm building by skipping building when instantiat…
Strivin0311 Jan 26, 2026
cacf8de
added native_grpcoll_split_alignment envvar with checking
Strivin0311 Jan 26, 2026
87344e5
impl _preprocess_args_for_split_alignment and added pragma: no cover …
Strivin0311 Jan 26, 2026
8221344
minor fixed logging and repr
Strivin0311 Jan 26, 2026
fa5cc30
minor fixed comments
Strivin0311 Jan 26, 2026
e40c3cd
minor fixed comments
Strivin0311 Jan 26, 2026
dd5ba5c
minor fixed comments
Strivin0311 Jan 26, 2026
7fff225
refactored buffer to extract the common output view out
Strivin0311 Jan 26, 2026
93f8669
refactored buffer to add split alignment to view
Strivin0311 Jan 26, 2026
ab98805
refactored buffer to add split alignment to view for lse
Strivin0311 Jan 26, 2026
fdc01db
implemented test intranode with split alignment
Strivin0311 Jan 26, 2026
09317c2
minor updated test intranode
Strivin0311 Jan 26, 2026
7526bc8
updated test_intranode_grpcoll
Strivin0311 Jan 27, 2026
9e3ef2b
minor fixed comments
Strivin0311 Jan 27, 2026
78a7dc4
added temp debug code to let benchmark meet the split alignment
Strivin0311 Jan 27, 2026
0790e3c
raised up kNumTMABytesPerWarp to 216KB to support larger token
Strivin0311 Jan 27, 2026
3db712c
implemented split_alignment for internode
Strivin0311 Jan 27, 2026
539ca2e
fixed a bytes count bug for internode; forbid pass_padded_out_buffer …
Strivin0311 Jan 27, 2026
624d41e
updated benchmark settings
Strivin0311 Jan 28, 2026
7cf3a87
Support per split token in static solver (#228)
WT1W Jan 28, 2026
7a87590
Dyn solver split alignment (#230)
lijinnn Jan 29, 2026
9e43bb0
Relax INT_MAX buffer size limit for internode (#229)
Strivin0311 Jan 29, 2026
94fd9f6
add dynamic_solver_vis (#231)
lijinnn Jan 29, 2026
5c51d7e
Dynamic split alignment (#233)
Strivin0311 Jan 30, 2026
652bd55
updated the docs for MAGI_ATTENTION_AUTO_RANGE_MERGE
Strivin0311 Jan 30, 2026
4bc922b
build cp-bench docker image
Big-TRex Feb 3, 2026
a8e1a75
Update API for num_heads and head_dim (#236)
Strivin0311 Feb 3, 2026
366e79e
Support auto split alignment (#241)
Strivin0311 Feb 5, 2026
41d4e4c
hotfix switch envvars in bench
Big-TRex Feb 5, 2026
239970c
Update benchmark for blackwell (#243)
Strivin0311 Feb 6, 2026
a085d3b
support bwd save last stage overlap policy (#244)
WT1W Feb 6, 2026
b9a8223
[HotFix] Fix CI (#246)
Strivin0311 Feb 7, 2026
4a1f4f5
updated benchmark dockerfile to change base to ngc2510 and add magi_a…
Strivin0311 Feb 9, 2026
6440452
added b200_baseline config
Strivin0311 Feb 9, 2026
cd5535e
added b200_magi config
Strivin0311 Feb 9, 2026
ce64282
added b200_magi_native config
Strivin0311 Feb 9, 2026
0f49eeb
added b200_all config
Strivin0311 Feb 9, 2026
09c3ea0
fixed get_device_compute_capability with default_cap
Strivin0311 Jan 26, 2026
8eaded8
simplified the Dockerfile.benchmark
Strivin0311 Feb 9, 2026
2a38f5e
added a temp MAGI_ATTENTION_BUILD_COMPUTE_CAPABILITY
Strivin0311 Feb 10, 2026
44f1ec7
added a temp docker file
Strivin0311 Feb 10, 2026
b9128af
added a temp docker build script
Strivin0311 Feb 10, 2026
05c8e56
removed magi from baseline config
Strivin0311 Feb 10, 2026
0f4bb9f
fixed get_a2av_perm_idx kernel
Strivin0311 Feb 11, 2026
1771c19
Merge branch 'blackwell_benchmark' of https://github.com/SandAI-org/M…
Strivin0311 Feb 11, 2026
2a38a48
updated docker version to 25.10.6
Strivin0311 Feb 11, 2026
71edc34
supported get_a2av_perm_idx for 32 nodes
Strivin0311 Feb 11, 2026
37d02df
raised num_bytes from 5GB to 10GB
Strivin0311 Feb 11, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,4 +79,3 @@ pybind11-stubgen magi_attention.magi_attn_ext -o .

> [!IMPORTANT]
> Failure to update stubs after modifying C++ code may cause type checking errors during CI.
```
17 changes: 11 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ For more usage instructions, you can refer to our [docs](https://SandAI-org.gith

import magi_attention
from magi_attention.api import (
magi_attn_flex_dispatch, calc_attn, undispatch, # interface functions
magi_attn_flex_key, dispatch, calc_attn, undispatch, # interface functions
compute_pad_size, # helper functions
)
from magi_attention.common import AttnRanges
Expand Down Expand Up @@ -288,19 +288,23 @@ For more usage instructions, you can refer to our [docs](https://SandAI-org.gith
# 1. the dispatched local token embedding may be shuffled along seqlen dim,
# so it's safe for token-wise operations such as matmul, layer-norm, etc
# while for sample-wise operations like RoPE, you might need to be more careful
# 2. the `magi_attn_runtime_key` holds some inner meta data as one argument for many other magi_attention APIs,
# which users don’t have to bother with
local_x, magi_attn_runtime_key = magi_attn_flex_dispatch(
x,
# 2. the `magi_attn_runtime_key` holds some inner meta data,
# as a required argument for many APIs of ``magi_attention``,
# which users don't have to bother with
magi_attn_runtime_key = magi_attn_flex_key(
q_ranges=q_ranges,
k_ranges=k_ranges,
attn_mask_type=attn_mask_type,
total_seqlen_q=total_seqlen_q,
total_seqlen_k=total_seqlen_k,
num_heads_q=num_heads_q,
num_heads_kv=num_heads_kv,
head_dim=head_dim,
pad_size=pad_size,
chunk_size=chunk_size,
cp_group_or_mesh=world_group, # assuming we only have 1-dim context parallelism (cp)
)
local_x = dispatch(x, key=magi_attn_runtime_key)

# --- Simulate QKV projection --- #

Expand Down Expand Up @@ -405,11 +409,12 @@ We provide additional [magi_attn_extensions](https://github.com/SandAI-org/MagiA
- [ ] **[WIP]** Support Ampere, Blackwell as well as other GPU architectures.
- [ ] **[WIP]** Optimize `Flex-Flash-Attention` kernels to improve performance and better support sparse attention (*such as [NSA](https://arxiv.org/pdf/2502.11089)*).
- [ ] **[WIP]** Optimize `DistAttnSolver` to reduce CPU overhead for meta info calculation and support better comp-/comm- overlapping.
- [ ] **[WIP]** Optimize `DynamicAttnSolver` for hybrid attention model or dynamic mask scenarios like sparse attention.
- [ ] **[WIP]** Provide a more comprehensive documentation with tutorials, and a more detailed technical blog / paper.
- [ ] Support other attention patterns including cross-attention, and inference scenarios involving KV cache (*w.r.t. [Paged Attention](https://arxiv.org/abs/2309.06180)*).
- [ ] Provide more example codes and recipes for various training scenarios.
- [ ] Upgrade `MagiAttention` to a distributed native `Flex-Flash-Attention` kernel (*as a major version update*).
- [x] Support `Dynamic DistAttnSolver` with query/output communication pattern, one for either hybrid attention model or dynamic mask scenarios like sparse attention, the other for reducing communication overhead for many cases when only communicating key/value is not the best choice.
- [x] Support `DynamicAttnSolver` with query/output communication pattern, for reducing communication overhead for many cases when only communicating key/value is not the best choice.
- [x] Support native `GroupCast` and `GroupReduce` communication kernels with inter-/intra-node hierarchical optimization (*similar to [DeepEP](https://github.com/deepseek-ai/DeepEP)*).
- [x] Support learnable attention sink (*w.r.t. [StreamingLLM](https://arxiv.org/abs/2309.17453)*).
- [x] Refactor `Distributed Attention Solver` to support all mask types with all kinds of overlap.
Expand Down
10 changes: 10 additions & 0 deletions docs/source/env_variables.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,16 @@ This feature is experimental and under active development for now.
If the C++ extension is not found or this variable is set to `0`, it will fall back to the Python implementation.
```

**MAGI_ATTENTION_AUTO_RANGE_MERGE**

Toggle this env variable to ``1`` to enable automatic range merging for flex-flash-attention,
to improve performance by reducing the number of attention ranges. The default value is `0`.

```{note}
This feature is experimental and under active development for now,
thus please do NOT enable it unless you know exactly what you are doing.
```

**MAGI_ATTENTION_DIST_ATTN_RUNTIME_DICT_SIZE**

Set the value of this env variable to control the size of `dist_attn_runtime_dict`. The default value is `100`.
Expand Down
22 changes: 5 additions & 17 deletions docs/source/magi_api.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,12 @@ To support computing irregular-shaped masks, we implemented a `flexible_flash_at

### Varlen Dispatch

If you're using a mask defined by `cu_seqlens`, such as a varlen full or varlen causal mask, we've designed a similar interface inspired by FlashAttention's API, making it easy for you to get started quickly. In the function named `magi_attn_varlen_dispatch`, you can obtain the dispatched `x` and `key`.
If you're using a mask defined by `cu_seqlens`, such as a varlen full or varlen causal mask, we've designed a similar interface `magi_attn_varlen_key` inspired by FlashAttention's API as follows, making it easy for you to get started quickly.

```{eval-rst}
.. currentmodule:: magi_attention.api.magi_attn_interface
```

```{eval-rst}
.. autofunction:: magi_attn_varlen_dispatch
```

The logic of the `magi_attn_varlen_dispatch` function mainly consists of two parts: it first calls `magi_attn_varlen_key` to compute a key value, and then uses this key to dispatch the input x. The description of `magi_attn_varlen_key` is as follows.

```{eval-rst}
.. autofunction:: magi_attn_varlen_key
```
Expand All @@ -52,18 +46,12 @@ Then the new mask will reuse the same dispatch solution as the mask used for dis

### Flexible Dispatch

If the masks you're using are not limited to varlen full or varlen causal, but also include sliding window masks or other more diverse types, we recommend using the following API. By calling `magi_attn_flex_dispatch`, you can obtain the dispatched x and key.
If the masks you're using are not limited to varlen full or varlen causal, but also include sliding window masks or other more diverse types, we recommend using the `magi_attn_flex_key` as follows.

```{eval-rst}
.. currentmodule:: magi_attention.api.magi_attn_interface
```

```{eval-rst}
.. autofunction:: magi_attn_flex_dispatch
```

Similar to the logic of `magi_attn_varlen_dispatch`, `magi_attn_flex_dispatch` first calls `magi_attn_flex_key` to obtain a key, and then uses this key to dispatch x. The description of `magi_attn_flex_key` is as follows.

```{eval-rst}
.. autofunction:: magi_attn_flex_key
```
Expand All @@ -78,7 +66,7 @@ Then the new mask will reuse the same dispatch solution as the mask used for dis

### Dispatch Function

If you already have the key, you can call `dispatch` function to get the padded and dispatched local tensor.
When you get the dist attn runtime key, you can call `dispatch` function to dispatch the global input tensor(s) to get the padded local tensor(s) along the seqlen dim.

```{eval-rst}
.. currentmodule:: magi_attention.api.magi_attn_interface
Expand All @@ -90,7 +78,7 @@ If you already have the key, you can call `dispatch` function to get the padded

## Calculate Attention

After dispatch and projection, you should obtain the query, key, and value needed for computation. Using the key obtained from the dispatch function mentioned above, you can perform the computation by calling `calc_attn`, which returns the results out and meta (containing lse).
After dispatch and QKV projection, you should obtain the local query, key, and value. Then you can calculate the distributed attention by calling `calc_attn` with the dist attn runtime key to get the local attention output tensor.

```{eval-rst}
.. currentmodule:: magi_attention.api.magi_attn_interface
Expand All @@ -105,7 +93,7 @@ After dispatch and projection, you should obtain the query, key, and value neede

### Undispatch Function

When you need to recover the complete global tensor from the local tensor like computing the loss, you can call `undispatch` function to unpad and undispatch the local tensor along the seqlen dim.
When you need to recover the global output tensor(s) from the local one(s), e.g. to compute the loss or for some other reason, you can call `undispatch` function to undispatch the padded local output tensor(s) back to the unpadded global tensor(s) along the seqlen dim.

```{eval-rst}
.. currentmodule:: magi_attention.api.magi_attn_interface
Expand Down
14 changes: 9 additions & 5 deletions docs/source/quickstart.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ import torch.distributed as dist

import magi_attention
from magi_attention.api import (
magi_attn_flex_dispatch, calc_attn, undispatch, # interface functions
magi_attn_flex_key, dispatch, calc_attn, undispatch, # interface functions
compute_pad_size, # helper functions
)
from magi_attention.common import AttnRanges
Expand Down Expand Up @@ -164,19 +164,23 @@ pad_size = compute_pad_size( # pad embeds along seqlen dim for better performanc
# 1. the dispatched local token embedding may be shuffled along seqlen dim,
# so it's safe for token-wise operations such as matmul, layer-norm, etc
# while for sample-wise operations like RoPE, you might need to be more careful
# 2. the `magi_attn_runtime_key` holds some inner meta data as one argument for many other magi_attention APIs,
# which users don’t have to bother with
local_x, magi_attn_runtime_key = magi_attn_flex_dispatch(
x,
# 2. the `magi_attn_runtime_key` holds some inner meta data,
# as a required argument for many APIs of ``magi_attention``,
# which users don't have to bother with
magi_attn_runtime_key = magi_attn_flex_key(
q_ranges=q_ranges,
k_ranges=k_ranges,
attn_mask_type=attn_mask_type,
total_seqlen_q=total_seqlen_q,
total_seqlen_k=total_seqlen_k,
num_heads_q=num_heads_q,
num_heads_kv=num_heads_kv,
head_dim=head_dim,
pad_size=pad_size,
chunk_size=chunk_size,
cp_group_or_mesh=world_group, # assuming we only have 1-dim context parallelism (cp)
)
local_x = dispatch(x, key=magi_attn_runtime_key)

# --- Simulate QKV projection --- #

Expand Down
32 changes: 26 additions & 6 deletions examples/torch_native/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,14 @@ def train(model, optimizer, lr_scheduler, device_mesh, train_iter):
):
# dispatched input and prepare magi_attn key.
input, dist_attn_runtime_key = prepare_magi_attention(
input, cu_seqlens_q, cu_seqlens_k, pad_size, CHUNK_SIZE, device_mesh.get_group("cp")
input=input,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
num_heads_q=model.config.num_attention_heads,
num_heads_kv=model.config.num_key_value_heads,
head_dim=model.config.head_dim,
pad_size=pad_size,
cp_group=device_mesh.get_group("cp"),
)

output = model(input, dist_attn_runtime_key)
Expand Down Expand Up @@ -160,22 +167,35 @@ def prepare_data(device_mesh, train_iter):

**Prepare magi_attn_key:** Dispatch input data along cp dim and get dist_attn_runtime_key.
```python
def prepare_magi_attention(input, cu_seqlens_q, cu_seqlens_k, pad_size, cp_group):
# --- magi_attn_flex_dispatch --- #
def prepare_magi_attention(
input: torch.Tensor,
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
num_heads_q: int,
num_heads_kv: int,
head_dim: int,
pad_size: int,
cp_group: dist.ProcessGroup,
):
# --- magi_attn_varlen_dispatch --- #

dist_attn_config = DistAttnConfig()

# you can also use fa_varlen-like varlen dispatch interface directly
x_padded, dist_attn_runtime_key = magi_attn_varlen_dispatch(
input,
dist_attn_runtime_key = magi_attn_varlen_key(
cu_seqlens_q,
cu_seqlens_k,
num_heads_q=num_heads_q,
num_heads_kv=num_heads_kv,
head_dim=head_dim,
pad_size=pad_size,
chunk_size=CHUNK_SIZE,
cp_group_or_mesh=cp_group,
causal=LlamaConfig().is_causal,
dist_attn_config=dist_attn_config,
)

x_padded = dispatch(input, key=dist_attn_runtime_key)

return x_padded, dist_attn_runtime_key
```

Expand Down
42 changes: 31 additions & 11 deletions examples/torch_native/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import torch.nn.functional as F
from configuration_llama import LlamaConfig
from llama_pretrain_config import data_config, parallel_config, train_config
from modeling_llama import LlamaDecoderLayer, build_llama3_1b_model
from modeling_llama import LlamaDecoderLayer, LlamaModel, build_llama3_1b_model
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import DTensor, Partial, Shard, distribute_tensor
Expand All @@ -28,8 +28,9 @@
from magi_attention.api import (
DistAttnConfig,
compute_pad_size,
dispatch,
infer_varlen_mask_from_batch,
magi_attn_varlen_dispatch,
magi_attn_varlen_key,
squash_batch_dim,
undispatch,
)
Expand Down Expand Up @@ -238,23 +239,35 @@ def prepare_data(device_mesh, train_iter):
return local_input, local_label, cu_seqlens_q, cu_seqlens_k, pad_size


def prepare_magi_attention(input, cu_seqlens_q, cu_seqlens_k, pad_size, cp_group):
# --- magi_attn_flex_dispatch --- #
# an example of distattnconfig
def prepare_magi_attention(
input: torch.Tensor,
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
num_heads_q: int,
num_heads_kv: int,
head_dim: int,
pad_size: int,
cp_group: dist.ProcessGroup,
):
# --- magi_attn_varlen_dispatch --- #

dist_attn_config = DistAttnConfig()

# you can also use fa_varlen-like varlen dispatch interface directly
x_padded, dist_attn_runtime_key = magi_attn_varlen_dispatch(
input,
dist_attn_runtime_key = magi_attn_varlen_key(
cu_seqlens_q,
cu_seqlens_k,
num_heads_q=num_heads_q,
num_heads_kv=num_heads_kv,
head_dim=head_dim,
pad_size=pad_size,
chunk_size=CHUNK_SIZE,
cp_group_or_mesh=cp_group,
causal=LlamaConfig().is_causal,
dist_attn_config=dist_attn_config,
)

x_padded = dispatch(input, key=dist_attn_runtime_key)

return x_padded, dist_attn_runtime_key


Expand Down Expand Up @@ -292,7 +305,7 @@ def loss_func(
return loss


def train(model, optimizer, lr_scheduler, device_mesh, train_iter):
def train(model: LlamaModel, optimizer, lr_scheduler, device_mesh, train_iter):
"""main training loop"""
model.train()

Expand All @@ -306,12 +319,19 @@ def train(model, optimizer, lr_scheduler, device_mesh, train_iter):
dist_attn_runtime_key = None

if (
parallel_config["context_parallel_size"] > 1
parallel_config["context_parallel_size"] > 1 # type: ignore[operator]
and parallel_config["context_parallel_backend"] == "magi_attention"
):
# dispatched input
input, dist_attn_runtime_key = prepare_magi_attention(
input, cu_seqlens_q, cu_seqlens_k, pad_size, device_mesh.get_group("cp")
input=input,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
num_heads_q=model.config.num_attention_heads,
num_heads_kv=model.config.num_key_value_heads,
head_dim=model.config.head_dim,
pad_size=pad_size,
cp_group=device_mesh.get_group("cp"),
)

output = model(input, dist_attn_runtime_key)
Expand Down
Loading