Summary
With the upstream Flash-Attention 4 PR Dao-AILab/flash-attention#2412 adding `head_dim=256` support for SM100 (Blackwell), I'd like to ask whether MagiAttention has plans to support `head_dim=256` on SM100 as well.
Context
PR #2412 introduces dedicated 2-CTA kernels for `head_dim=256` on SM100:
- Forward: tile size `(128, 128)`, 2-CTA cooperative design
- Backward (dKdV): tile size `(128, 64)`, 2-CTA design
MagiAttention currently depends on a fork of FA4 (`demonatic/flash-attention`, branch `magi_attn_blackwell_support`). After reviewing the codebase, I identified several areas that would need adaptation:
1. Fork FA4 side
- `_validate_head_dims` currently restricts SM100 `head_dim` to ≤ 128. This needs to be relaxed to 256.
- `flash_bwd_sm100.py` has `assert self.tile_hdim <= 128`, blocking `head_dim=256` backward kernels.
- `get_tile_sizes_by_backend` needs to return the correct tile sizes for `head_dim=256` (fwd: `(128, 128)`, bwd: `(128, 64)`).
- Key architectural concern: the fork currently hardcodes `cluster_size=1` for SM100 backward (no 2-CTA), while upstream PR #2412's `head_dim=256` backward relies on 2-CTA (`cluster_size=2`). This is a significant divergence that needs a design decision.
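To make the needed tile-size change concrete, here is a minimal sketch of how SM100 tile selection could branch on `head_dim`. The function name and signature are illustrative only, not the fork's actual `get_tile_sizes_by_backend` API:

```python
# Illustrative sketch only: how SM100 tile-size selection could be extended
# for head_dim=256, using the tile sizes from upstream PR #2412.
# The real get_tile_sizes_by_backend in the fork has a different signature.

def sm100_tile_sizes(head_dim: int, is_backward: bool) -> tuple[int, int]:
    """Return (tile_m, tile_n) for an SM100 kernel; purely illustrative."""
    if head_dim <= 128:
        # Existing path: 128x128 tiles for both forward and backward.
        return (128, 128)
    if head_dim == 256:
        # New 2-CTA kernels per PR #2412:
        # forward uses (128, 128), backward (dKdV) uses (128, 64).
        return (128, 64) if is_backward else (128, 128)
    raise ValueError(f"unsupported head_dim={head_dim} on SM100")
```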
2. MagiAttention side
- `calc_meta.py`: the SM100 tile-size hard constraint (128×128 only) needs to be relaxed to allow bwd `tile_n=64`.
- `precompile_ffa_fa4.py`: the `head_dims` list needs to include `256`.
- Block-sparse mask generation: the existing parameterized code (`_make_fa4_args_dict`) already handles different fwd/bwd tile sizes correctly, so the K2Q/Q2K mask `BLOCK_SIZE` should adapt automatically.
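As a sketch of the relaxation needed in `calc_meta.py`, the check might look something like the following (hypothetical helper, not the actual MagiAttention code):

```python
# Hypothetical sketch of the relaxed SM100 tile-size constraint in
# MagiAttention's calc_meta.py: today only 128x128 is accepted; with
# head_dim=256 support, the backward pass would also need tile_n=64.

def is_valid_sm100_tile(head_dim: int, tile_m: int, tile_n: int,
                        is_backward: bool) -> bool:
    allowed = {(128, 128)}  # current hard constraint
    if is_backward and head_dim == 256:
        allowed.add((128, 64))  # bwd tile size from upstream PR #2412
    return (tile_m, tile_n) in allowed
```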
3. Potential challenge: 2-CTA backward + block sparsity
In upstream FA4, the SM100 backward kernel asserts `blocksparse_tensors is None` when `use_2cta_instrs=True` (i.e., 2-CTA and block sparsity are mutually exclusive). Since `head_dim=256` backward almost certainly requires 2-CTA for performance, block sparsity may not be available for `head_dim=256` backward on SM100 unless upstream resolves this limitation.
This could impact MagiAttention's sparse attention functionality at `head_dim=256`.
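To make the trade-off concrete, a dispatch guard along these lines (purely hypothetical names, not upstream's actual code) would have to either reject block-sparse inputs or fall back to a single CTA:

```python
# Hypothetical sketch of the 2-CTA vs. block-sparsity conflict on SM100
# backward: upstream FA4 asserts blocksparse_tensors is None when
# use_2cta_instrs=True, so a head_dim=256 backward that needs 2-CTA
# cannot currently accept block-sparse inputs.

def pick_bwd_cluster_size(head_dim: int, has_blocksparse: bool) -> int:
    if head_dim == 256:
        if has_blocksparse:
            # Mirrors the upstream assertion: 2-CTA + block sparsity is
            # unsupported. A fork could instead fall back to
            # cluster_size=1 here, at a likely performance cost.
            raise NotImplementedError(
                "block sparsity unavailable for head_dim=256 SM100 backward"
            )
        return 2  # 2-CTA kernel required for head_dim=256 performance
    return 1  # current fork behavior: no 2-CTA on SM100 backward
```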
Questions
- Are there plans to support `head_dim=256` on SM100 in MagiAttention?
- Is there a timeline for syncing the fork (`magi_attn_blackwell_support`) with upstream FA4 PR #2412?
- Regarding the 2-CTA vs. block sparsity trade-off in backward, is there a preferred strategy (e.g., keeping `cluster_size=1` for sparsity support, or adopting 2-CTA and handling sparsity differently)?
Related
Thanks for the great work on MagiAttention! Happy to help with testing or contributing if there's a plan to move forward.