[Feature Request] Support head_dim=256 on SM100 (Blackwell) #309

@LucienXian

Description

Summary

With the upstream Flash-Attention 4 PR Dao-AILab/flash-attention#2412 adding head_dim=256 support for SM100 (Blackwell), I'd like to ask whether MagiAttention has plans to support head_dim=256 on SM100 as well.

Context

PR #2412 introduces dedicated 2-CTA kernels for head_dim=256 on SM100:

  • Forward: tile size (128, 128), 2-CTA cooperative design
  • Backward (dKdV): tile size (128, 64), 2-CTA design

MagiAttention currently depends on a fork of FA4 (demonatic/flash-attention, branch magi_attn_blackwell_support). After reviewing the codebase, I identified several areas that would need adaptation:

1. Fork FA4 side

  • _validate_head_dims currently restricts SM100 head_dim to ≤ 128. This check needs to be relaxed to allow 256.
  • flash_bwd_sm100.py has assert self.tile_hdim <= 128, blocking head_dim=256 backward kernels.
  • get_tile_sizes_by_backend needs to return the correct tile sizes for head_dim=256 (fwd: (128, 128), bwd: (128, 64)).
  • Key architectural concern: The fork currently hardcodes cluster_size=1 for SM100 backward (no 2-CTA), while upstream PR #2412's head_dim=256 backward relies on 2-CTA (cluster_size=2). This is a significant divergence that needs a design decision.
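The fork-side changes above amount to extending a per-head_dim configuration table. Below is a minimal sketch of what that might look like; the function and table names are illustrative assumptions, not the fork's actual API, and the tile sizes/cluster sizes follow upstream PR #2412 as described above.

```python
# Hypothetical sketch: per-head_dim SM100 kernel config the fork would need.
# Names here are assumptions for illustration, not real fork identifiers.
SM100_TILE_SIZES = {
    # head_dim: (fwd tile, bwd dKdV tile, cluster_size)
    64:  ((128, 128), (128, 128), 1),
    128: ((128, 128), (128, 128), 1),
    256: ((128, 128), (128, 64), 2),  # upstream PR #2412: 2-CTA kernels
}

def get_sm100_tile_sizes(head_dim: int):
    """Return (fwd_tile, bwd_tile, cluster_size) for a supported head_dim."""
    try:
        return SM100_TILE_SIZES[head_dim]
    except KeyError:
        raise ValueError(f"unsupported head_dim={head_dim} on SM100")
```

A table like this would replace the current hardcoded `cluster_size=1` and the `tile_hdim <= 128` assert with a single lookup.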

2. MagiAttention side

  • calc_meta.py: The SM100 tile size hard constraint (128×128 only) needs to be relaxed to allow bwd tile_n=64.
  • precompile_ffa_fa4.py: head_dims list needs to include 256.
  • Block sparse mask generation: The existing parameterized code (_make_fa4_args_dict) already handles different fwd/bwd tile sizes correctly, so K2Q/Q2K mask BLOCK_SIZE should adapt automatically.
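To illustrate the calc_meta.py change, here is a sketch of how the hard 128×128 constraint could be relaxed to admit bwd tile_n=64 only for head_dim=256. The function name and signature are assumptions for illustration, not MagiAttention's real code.

```python
# Hypothetical sketch of a relaxed SM100 tile-size check (names are
# illustrative, not MagiAttention's actual calc_meta.py API).
def check_sm100_tile(tile_m: int, tile_n: int, head_dim: int, is_bwd: bool):
    allowed_n = {128}
    if is_bwd and head_dim == 256:
        # upstream PR #2412 bwd dKdV tile is (128, 64)
        allowed_n.add(64)
    if tile_m != 128 or tile_n not in allowed_n:
        raise ValueError(
            f"unsupported SM100 tile ({tile_m}, {tile_n}) "
            f"for head_dim={head_dim}"
        )
```

Since _make_fa4_args_dict already parameterizes fwd/bwd tile sizes, the K2Q/Q2K mask BLOCK_SIZE should follow this check without further changes.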

3. Potential challenge: 2-CTA backward + block sparsity

In upstream FA4, the SM100 backward kernel asserts blocksparse_tensors is None when use_2cta_instrs=True (i.e., 2-CTA and block sparsity are mutually exclusive). Since head_dim=256 backward almost certainly requires 2-CTA for performance, this means block sparsity may not be available for head_dim=256 backward on SM100 unless the upstream resolves this limitation.

This could impact MagiAttention's sparse attention functionality on head_dim=256.
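The trade-off above can be made concrete with a small dispatch sketch that mirrors the upstream assertion (2-CTA backward and block sparsity mutually exclusive). The function and its return shape are assumptions for illustration only.

```python
# Hypothetical dispatch sketch mirroring upstream FA4's SM100 constraint:
# the backward kernel asserts blocksparse_tensors is None when
# use_2cta_instrs=True. Names here are illustrative assumptions.
def select_bwd_config(head_dim: int, blocksparse_tensors):
    # head_dim=256 backward only ships as a 2-CTA kernel upstream
    use_2cta = head_dim == 256
    if use_2cta and blocksparse_tensors is not None:
        raise NotImplementedError(
            "SM100 backward: block sparsity is unsupported with 2-CTA kernels"
        )
    return {
        "cluster_size": 2 if use_2cta else 1,
        "use_2cta_instrs": use_2cta,
    }
```

Under this constraint, a head_dim=256 backward call with a sparse mask would have to either fail loudly (as sketched) or fall back to a dense backward, which is exactly the design decision raised in the questions below.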

Questions

  1. Are there plans to support head_dim=256 on SM100 in MagiAttention?
  2. Is there a timeline for syncing the fork (magi_attn_blackwell_support) with upstream FA4 PR #2412?
  3. Regarding the 2-CTA vs. block sparsity trade-off in backward, is there a preferred strategy (e.g., keeping cluster_size=1 for sparsity support, or adopting 2-CTA and handling sparsity differently)?

Thanks for the great work on MagiAttention! Happy to help with testing or contributing if there's a plan to move forward.
