Fix Flash Attention 3 API compatibility for window size parameters #2704

Open
jhvmhg wants to merge 9 commits into NVIDIA:main from jhvmhg:fix/flash_attn3_support_CP
Conversation

@jhvmhg jhvmhg commented Feb 25, 2026

Replace single window_size parameter with window_size_left and window_size_right in flash_attn_fwd function to align with flash-attn v2.7.0+ API changes.

  • Update function signature in flash_attn_interface
  • Maintain backward compatibility where possible
  • Ensure consistency with Flash Attention v2 implementation

Description


  1. Fix window size parameters in flash_attn_fwd - Replaces the single window_size parameter with separate window_size_left and window_size_right parameters to match the updated flash-attn v2.7.0+ API.
  2. Fix causal parameter naming in flash_attn_bwd - Renames causal to is_causal in the backward function signature for consistency with the latest flash-attn interface.

Motivation:

The flash-attn library v2.7.0+ introduced breaking API changes that cause compatibility issues with TransformerEngine's Flash Attention 3 integration. These updates ensure seamless operation with newer versions of the flash-attn library while maintaining correctness of both forward and backward attention computations.

Related API Changes:

  • flash-attn v2.7.0+ split window_size into window_size_left and window_size_right
  • flash-attn v3+ renamed the causal parameter to is_causal in the backward pass
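The version gate these changes imply can be sketched as follows. This is an illustrative helper, not TransformerEngine code; the flag names mirror the `fa_utils.v2_7_0_plus` / `fa_utils.v2_3_plus` convention used in the repository.

```python
def build_window_kwargs(window_size, use_flash_attn_3, v2_7_0_plus, v2_3_plus):
    """Map a (left, right) window tuple onto the kwargs each flash-attn
    version expects. Purely illustrative of the API split above."""
    left, right = window_size
    if use_flash_attn_3 or v2_7_0_plus:
        # FA3 and flash-attn v2.7.0+: separate left/right parameters
        return {"window_size_left": left, "window_size_right": right}
    if v2_3_plus:
        # flash-attn v2.3 through v2.6: single tuple parameter
        return {"window_size": (left, right)}
    # Older versions: no sliding-window support at all
    return {}
```

The returned dict would then be splatted into the `flash_attn_fwd` call, keeping the version logic in one place.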

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Replace single window_size parameter with window_size_left and window_size_right
    in flash_attn_fwd function to align with flash-attn v2.7.0+ API changes.
  • Rename causal parameter to is_causal in flash_attn_bwd function to align
    with flash-attn v3

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Replace single window_size parameter with window_size_left and window_size_right
in flash_attn_fwd function to align with flash-attn v2.7.0+ API changes.

- Update function signature in flash_attn_interface
- Maintain backward compatibility where possible
- Ensure consistency with Flash Attention v2 implementation

Signed-off-by: Chaoyang Mei <1192554423@qq.com>
Signed-off-by: meichaoyang001 <meichaoyang001@ke.com>
greptile-apps bot commented Feb 25, 2026

Greptile Summary

This PR correctly fixes Flash Attention 3 API compatibility across the context-parallel attention implementations in TransformerEngine. It replaces the legacy window_size=(left, right) tuple parameter with the newer window_size_left / window_size_right split parameters for both FA3 and FA2 v2.7.0+, and renames the causal parameter to is_causal in the FA3 backward-pass call sites.

Key changes:

  • cp_p2p_fwd_flash_attn: condition restructured so FA3 uses window_size_left/window_size_right instead of window_size
  • cp_p2p_bwd_flash_attn: same window-size restructuring plus causal_ conditionally written as is_causal (FA3) or causal (FA2) into fa_backward_kwargs before each flash_attn_bwd call
  • AttnFuncWithCPAndKVP2P.forward(): initial FA3 setup now uses split window-size keys
  • AttnFuncWithCPAndKVAllGather.backward() and AttnFuncWithCPAndQKVOA2A.backward(): causal parameter migrated into fa_backward_kwargs under the appropriate key

The changes are applied consistently across all three CP attention classes and the two standalone per-tile helpers. The fa_backward_kwargs shared mutable dict is safely used: the causal/is_causal key is written unconditionally before each flash_attn_bwd call, eliminating stale-value risk across iterations.
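The shared-dict pattern described above can be sketched like this. `fa_backward_kwargs` and `fake_flash_attn_bwd` are stand-ins for the real objects in context_parallel.py, not actual TransformerEngine or flash-attn code.

```python
fa_backward_kwargs = {"deterministic": False}  # shared dict reused across iterations

def run_backward_step(flash_attn_bwd, causal_, use_flash_attn_3):
    # The flag is (over)written unconditionally before every call, so a
    # stale value from a previous iteration can never be consumed.
    fa_backward_kwargs.pop("causal", None)
    fa_backward_kwargs.pop("is_causal", None)
    key = "is_causal" if use_flash_attn_3 else "causal"
    fa_backward_kwargs[key] = causal_
    return flash_attn_bwd(**fa_backward_kwargs)

def fake_flash_attn_bwd(**kwargs):
    return kwargs  # echo kwargs so the pattern can be inspected
```

Because the key is rewritten before every call, alternating FA2/FA3 steps over the same dict stays correct.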

Confidence Score: 4/5

  • This PR is safe to merge; it correctly adapts the FA3 API usage without regressing FA2 paths.
  • The restructuring is logically consistent across all three CP attention classes. The window_size_left/window_size_right split and the is_causal rename are applied in every affected call site. The shared-dict mutation pattern for fa_backward_kwargs was pre-existing; the new causal/is_causal key is always overwritten before it is consumed, so no stale-value risk is introduced. The only minor concern is the lack of tests that cover the FA3 code paths, but this is a broader testing gap rather than a bug in the change itself.
  • No files require special attention beyond the single changed file.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[Flash Attention Call] --> B{use_flash_attn_3?}
    B -->|Yes FA3| C[window_size_left / window_size_right]
    B -->|No| D{fa_utils.v2_7_0_plus?}
    D -->|Yes| C
    D -->|No| E{fa_utils.v2_3_plus?}
    E -->|Yes| F[window_size tuple]
    E -->|No| G[No window_size param]

    C --> H[Backward Pass]
    F --> H
    G --> H

    H --> I{use_flash_attn_3?}
    I -->|Yes FA3| J[is_causal = bool]
    I -->|No FA2| K[causal = bool]

    J --> L[flash_attn_bwd called]
    K --> L

Last reviewed commit: 0e5c14d


@greptile-apps greptile-apps bot left a comment


1 file reviewed, 1 comment


greptile-apps bot commented Feb 25, 2026

Additional Comments (1)

transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
The causal parameter was removed here, but other flash_attn_bwd calls in this file (lines 3222, 3832) still pass it; verify this inconsistency is intentional.

Rename causal parameter to is_causal in flash_attn_bwd function to align
with flash-attn v2.7.0+ API changes. This ensures consistency with the
updated flash-attn library interface for backward pass operations.

Signed-off-by: meichaoyang001 <meichaoyang001@ke.com>
@jhvmhg jhvmhg force-pushed the fix/flash_attn3_support_CP branch from a245229 to f9752ca on February 25, 2026 07:54
@jhvmhg jhvmhg left a comment


Fix Flash Attention 3 backward API parameter naming

Rename causal parameter to is_causal in flash_attn_bwd function to align
with flash-attn v2.7.0+ API changes. This ensures consistency with the
updated flash-attn library interface for backward pass operations.

@greptile-apps greptile-apps bot left a comment


1 file reviewed, no comments


jhvmhg and others added 2 commits February 25, 2026 15:56
Rename causal parameter to is_causal in flash_attn_bwd function to align
with flash-attn v2.7.0+ API changes. This ensures consistency with the
updated flash-attn library interface for backward pass operations.

Signed-off-by: meichaoyang001 <meichaoyang001@ke.com>
@greptile-apps greptile-apps bot left a comment


1 file reviewed, no comments


Replace keyword arguments with positional arguments in flash_attn_fwd and
flash_attn_bwd to abstract away parameter naming differences (causal vs
is_causal) between flash-attn versions. This provides a more robust
interface that is resilient to future API changes in the flash-attn library.

- Convert window_size_left, window_size_right, and causal parameters to
  positional args in both forward and backward functions
- Eliminate version-specific parameter naming dependencies
- Simplify compatibility handling across flash-attn v2.7.0+ variants

Signed-off-by: meichaoyang001 <meichaoyang001@ke.com>
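The positional-argument strategy in this commit can be illustrated with hypothetical signatures; neither function below is real flash-attn API, they only mimic the causal/is_causal naming split between versions.

```python
def fa2_fwd(q, k, v, causal=False, window_size_left=-1, window_size_right=-1):
    # FA2-style signature: the flag is named `causal`
    return {"causal": causal, "left": window_size_left, "right": window_size_right}

def fa3_fwd(q, k, v, is_causal=False, window_size_left=-1, window_size_right=-1):
    # FA3-style signature: the same flag is named `is_causal`
    return {"causal": is_causal, "left": window_size_left, "right": window_size_right}

# The identical positional call site works against either signature,
# so the caller never has to know which keyword name this version uses:
out_v2 = fa2_fwd(None, None, None, True, 128, 0)
out_v3 = fa3_fwd(None, None, None, True, 128, 0)
```

The trade-off is that positional calls silently break if a future release reorders parameters, which keyword arguments would catch; the commit accepts that risk in exchange for name-independence.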
@greptile-apps greptile-apps bot left a comment


1 file reviewed, 1 comment


softmax_lse_per_step[i],
*fa_backward_args_thd,
causal="causal" in ctx.attn_mask_type,
ctx.attn_mask_type,

ctx.attn_mask_type is a string (e.g., "causal", "no_mask"), not a boolean. Should be "causal" in ctx.attn_mask_type to convert to boolean.

Suggested change
ctx.attn_mask_type,
"causal" in ctx.attn_mask_type,

@jhvmhg jhvmhg closed this Feb 25, 2026
Rename causal parameter to is_causal in flash_attn_bwd function to align
with flash-attn v3 API changes. This ensures consistency with the
updated flash-attn library interface for backward pass operations.

Signed-off-by: meichaoyang001 <meichaoyang001@ke.com>
@jhvmhg jhvmhg reopened this Feb 25, 2026
@jhvmhg jhvmhg closed this Feb 25, 2026
@jhvmhg jhvmhg reopened this Feb 25, 2026
@greptile-apps greptile-apps bot left a comment


1 file reviewed, no comments


@ptrendx ptrendx requested a review from cyanguwa February 25, 2026 19:57
@cyanguwa cyanguwa requested a review from mk-61 February 26, 2026 00:04
@cyanguwa

@mk-61 I think the changes look good, but could you please follow through with the CI, especially the L3_FA_version tests, to make sure the new changes pass the SWA tests for FA3? Thanks!

@vcherepanov-nv

LGTM
