From 5498bcd25891f175d93f65bcd560fce64d0393d7 Mon Sep 17 00:00:00 2001 From: shw Date: Wed, 28 Jan 2026 09:10:45 +0000 Subject: [PATCH 01/14] add pack_gqa template for bwd --- .../bwd_inst_template.jinja | 9 +++++++-- .../csrc/flexible_flash_attention/flash.h | 2 ++ .../flash_bwd_launch_template.h | 2 ++ .../flexible_flash_attention/flex_flash_bwd.hpp | 2 +- magi_attention/functional/flex_flash_attn.py | 13 +++++++++++-- tests/test_attn/test_flex_flash_attn.py | 4 ++++ 6 files changed, 27 insertions(+), 5 deletions(-) diff --git a/magi_attention/csrc/flexible_flash_attention/bwd_inst_template.jinja b/magi_attention/csrc/flexible_flash_attention/bwd_inst_template.jinja index d6706a0a2..0ae5901f9 100644 --- a/magi_attention/csrc/flexible_flash_attention/bwd_inst_template.jinja +++ b/magi_attention/csrc/flexible_flash_attention/bwd_inst_template.jinja @@ -19,10 +19,15 @@ static constexpr bool kDeterministic = {{ deterministic }}; static constexpr bool kProfileMode = {{ profile_mode }}; static constexpr bool kRangeMerge = {{ auto_range_merge }}; static constexpr bool kSwapBwdQKLoop = {{ swap_bwd_qk_loop }}; +static constexpr bool kPackGQA = {{ pack_gqa }}; +static constexpr int kQheadPerKhead = {{ qhead_per_khead }}; // TODO: add support for RangeMerge and Deterministic mode when SwapBwdQKLoop is enabled static_assert(!kSwapBwdQKLoop || (!kRangeMerge && !kDeterministic), "Neither RangeMerge nor Deterministic mode is supported by now when SwapBwdQKLoop is enabled."); +// PackGQA is only supported when SwapBwdQKLoop is enabled +static_assert(!kPackGQA || kSwapBwdQKLoop, "PackGQA is only supported when SwapBwdQKLoop is enabled."); + // Runtime contract checks to ensure consistency with compile-time constraints static inline void _check_runtime_contract_bwd( const at::Tensor& q, @@ -82,7 +87,7 @@ std::vector mha_bwd( auto stream = at::cuda::getCurrentCUDAStream().stream(); // Parameter preparation (including output tensor allocation) - auto [params, dq, dk, dv, dsink] = 
prepare_mha_bwd( + auto [params, dq, dk, dv, dsink] = prepare_mha_bwd( dout, q, k, v, sink_, out, dq_, dk_, dv_, dsink_, softmax_lse, q_ranges, k_ranges, attn_type_map_, merge_k_ranges_, bwd_kq_map_, bwd_unique_count_, @@ -94,7 +99,7 @@ std::vector mha_bwd( MagiEvents::stop("bwd_prepare"); // Kernel launch (single variant) - run_mha_bwd_(params, stream); + run_mha_bwd_(params, stream); return {dq, dk, dv, dsink}; } diff --git a/magi_attention/csrc/flexible_flash_attention/flash.h b/magi_attention/csrc/flexible_flash_attention/flash.h index fc260f05a..7e7bc9086 100644 --- a/magi_attention/csrc/flexible_flash_attention/flash.h +++ b/magi_attention/csrc/flexible_flash_attention/flash.h @@ -210,6 +210,8 @@ template < bool Deterministic, bool RangeMerge, bool SwapBwdQKLoop, + bool PackGQA, + int QheadPerKhead, bool ProfileMode> void run_mha_bwd_(Flash_bwd_params& params, cudaStream_t stream); diff --git a/magi_attention/csrc/flexible_flash_attention/flash_bwd_launch_template.h b/magi_attention/csrc/flexible_flash_attention/flash_bwd_launch_template.h index 594a5814c..ca8e8fda7 100644 --- a/magi_attention/csrc/flexible_flash_attention/flash_bwd_launch_template.h +++ b/magi_attention/csrc/flexible_flash_attention/flash_bwd_launch_template.h @@ -314,6 +314,8 @@ template < bool Deterministic, bool RangeMerge, bool SwapBwdQKLoop, + bool PackGQA, + int QheadPerKhead, bool ProfileMode> void run_mha_bwd_(Flash_bwd_params& params, cudaStream_t stream) { static_assert(sizeof(T) == 2, "Only 16bit computation are supported"); diff --git a/magi_attention/csrc/flexible_flash_attention/flex_flash_bwd.hpp b/magi_attention/csrc/flexible_flash_attention/flex_flash_bwd.hpp index 7304ec54d..080d30381 100644 --- a/magi_attention/csrc/flexible_flash_attention/flex_flash_bwd.hpp +++ b/magi_attention/csrc/flexible_flash_attention/flex_flash_bwd.hpp @@ -80,7 +80,7 @@ struct type_caster { // }); // } -template +template std::tuple prepare_mha_bwd( const at::Tensor& dout, const at::Tensor& q, 
diff --git a/magi_attention/functional/flex_flash_attn.py b/magi_attention/functional/flex_flash_attn.py index bd60b4c6b..c7cbcb465 100644 --- a/magi_attention/functional/flex_flash_attn.py +++ b/magi_attention/functional/flex_flash_attn.py @@ -457,6 +457,7 @@ def _flex_flash_attn_backward_compilable( bwd_kq_map: torch.Tensor | None, bwd_unique_count: torch.Tensor | None, swap_bwd_qk_loop: bool, + pack_gqa: bool, ) -> None: """torch.ops.flex_flash_attn._flex_flash_attn_backward_compilable""" mod = get_ffa_jit_mod( @@ -467,8 +468,8 @@ def _flex_flash_attn_backward_compilable( or (k.dtype if disable_bwd_dkv_atomic_reduction else torch.float32), softcap=softcap > 0.0, disable_atomic_reduction=disable_bwd_dkv_atomic_reduction, - pack_gqa=False, - qhead_per_khead=q.size(1) / k.size(1), + pack_gqa=pack_gqa, + qhead_per_khead=q.size(1) // k.size(1), deterministic=deterministic, auto_range_merge=auto_range_merge, swap_bwd_qk_loop=swap_bwd_qk_loop, @@ -543,6 +544,7 @@ def _flex_flash_attn_backward_compilable_fake( bwd_kq_map: torch.Tensor | None, bwd_unique_count: torch.Tensor | None, swap_bwd_qk_loop: bool, + pack_gqa: bool, ) -> None: pass @@ -577,6 +579,7 @@ def _flex_flash_attn_backward( bwd_kq_map: torch.Tensor | None = None, bwd_unique_count: torch.Tensor | None = None, swap_bwd_qk_loop: bool = False, + pack_gqa: bool = False, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None]: if profile_mode: # NOTE: stop_event is called inside the kernel ffa_utils.start_event("bwd_prepare") @@ -629,6 +632,7 @@ def _flex_flash_attn_backward( bwd_kq_map=bwd_kq_map, bwd_unique_count=bwd_unique_count, swap_bwd_qk_loop=swap_bwd_qk_loop, + pack_gqa=pack_gqa, ) return dq, dk, dv, dsink @@ -765,6 +769,7 @@ def forward( ctx.auto_range_merge = auto_range_merge ctx.swap_ab = swap_ab ctx.swap_bwd_qk_loop = swap_bwd_qk_loop + ctx.pack_gqa = pack_gqa return out, lse @@ -790,6 +795,9 @@ def backward(ctx, dout: torch.Tensor, *args): # pragma: no cover ) merge_k_ranges, 
bwd_kq_map, bwd_unique_count = None, None, None + # pack_gqa in backward is only enabled when both pack_gqa and swap_bwd_qk_loop are True + bwd_pack_gqa = ctx.pack_gqa and ctx.swap_bwd_qk_loop + dq, dk, dv, dsink = _flex_flash_attn_backward( dout=dout, q=q, @@ -820,6 +828,7 @@ def backward(ctx, dout: torch.Tensor, *args): # pragma: no cover bwd_kq_map=bwd_kq_map, bwd_unique_count=bwd_unique_count, swap_bwd_qk_loop=ctx.swap_bwd_qk_loop, + pack_gqa=bwd_pack_gqa, ) # Cast gradients to the same dtype as inputs diff --git a/tests/test_attn/test_flex_flash_attn.py b/tests/test_attn/test_flex_flash_attn.py index ba9e7f117..e5e2d3c6e 100644 --- a/tests/test_attn/test_flex_flash_attn.py +++ b/tests/test_attn/test_flex_flash_attn.py @@ -1058,6 +1058,7 @@ def run_test_case( ref_block_size: tuple[int, int] | None, pack_gqa: bool, test_case: str, + swap_bwd_qk_loop: bool = False, err_ratio_dict: dict[str, float] = {}, max_seqlen_q: int | None = None, ) -> None: @@ -1166,6 +1167,7 @@ def run_test_case( ref_block_size=ref_block_size, pack_gqa=pack_gqa, sparse_load=sparse_load, + swap_bwd_qk_loop=swap_bwd_qk_loop, ) # run ffa backward @@ -1654,6 +1656,7 @@ def test_ffa_simple( swap_ab=swap_ab, ref_block_size=ref_block_size, pack_gqa=pack_gqa, + swap_bwd_qk_loop=swap_bwd_qk_loop, max_seqlen_q=max_seqlen_q, test_case=test_case, err_ratio_dict={ @@ -1848,6 +1851,7 @@ def test_ffa_random( swap_ab=swap_ab, ref_block_size=ref_block_size, pack_gqa=pack_gqa, + swap_bwd_qk_loop=swap_bwd_qk_loop, test_case=test_case, sink_layout="sh", max_seqlen_q=max_seqlen_q, From 62cd71d3ebfb13311579099ae40363c63bd07407 Mon Sep 17 00:00:00 2001 From: shw Date: Sun, 1 Feb 2026 11:43:35 +0000 Subject: [PATCH 02/14] support pack_gqa for tile_scheduler --- .../csrc/flexible_flash_attention/block.h | 27 +++++++++++++---- .../bwd_tile_scheduler.hpp | 30 +++++++++++++++---- .../flash_bwd_launch_template.h | 6 +++- 3 files changed, 51 insertions(+), 12 deletions(-) diff --git 
a/magi_attention/csrc/flexible_flash_attention/block.h b/magi_attention/csrc/flexible_flash_attention/block.h index 733312228..8f9040215 100644 --- a/magi_attention/csrc/flexible_flash_attention/block.h +++ b/magi_attention/csrc/flexible_flash_attention/block.h @@ -57,24 +57,39 @@ struct BlockMN { return {n_block_min, n_block_max}; } - // TODO: For backward with packgqa, we need to modify this function + // For backward with packgqa, m_block is in packed space, n_block is in logical space static CUTLASS_DEVICE cute::tuple get_m_block_min_max(SeqlenInfo_t const& seqlen_info, int const n_block, int const bidb, flash::AttnType const attn_type) { int const seqlen_q = seqlen_info.seqlen_q; int const seqlen_k = seqlen_info.seqlen_k; - int m_block_max = cute::ceil_div(seqlen_q, kBlockM); + + // For PackGQA, the packed seqlen_q is seqlen_q * Qhead_per_khead + int const seqlen_q_packed = !PackGQA ? seqlen_q : seqlen_q * Qhead_per_khead; + int m_block_max = cute::ceil_div(seqlen_q_packed, kBlockM); + if (attn_type == flash::AttnType::Full || attn_type == flash::AttnType::Causal) { // do nothing } else if (attn_type == flash::AttnType::InvCausal || attn_type == flash::AttnType::BiCausal) { - // TODO: Need better way to compute this - int m_idx_max = std::min(seqlen_k, (n_block + 1) * kBlockN); - m_block_max = std::min(m_block_max, cute::ceil_div(m_idx_max, kBlockM)); + // n_idx_max in logical space (max n_idx for this n_block) + int n_idx_max = std::min(seqlen_k, (n_block + 1) * kBlockN); + // For PackGQA, convert to packed m space: m_idx_packed = m_idx_logical * Qhead_per_khead + // For InvCausal (m >= n), m_idx must be >= n_idx, so in packed space: m_idx_packed >= n_idx * Qhead_per_khead + int m_idx_max_packed = !PackGQA ? 
n_idx_max : n_idx_max * Qhead_per_khead; + m_block_max = std::min(m_block_max, cute::ceil_div(m_idx_max_packed, kBlockM)); } + int m_block_min = 0; if (attn_type == flash::AttnType::Causal || attn_type == flash::AttnType::BiCausal) { - m_block_min = std::max(m_block_min, (n_block * kBlockN + seqlen_q - seqlen_k) / kBlockM); + // For Causal (m <= n + offset), where offset = seqlen_q - seqlen_k + // Given n_block, n_idx_min = n_block * kBlockN + // m_idx_min (logical) = n_idx_min + seqlen_q - seqlen_k + int m_idx_min_logical = n_block * kBlockN + seqlen_q - seqlen_k; + // For PackGQA, convert to packed space + int m_idx_min_packed = !PackGQA ? m_idx_min_logical : m_idx_min_logical * Qhead_per_khead; + m_block_min = std::max(m_block_min, m_idx_min_packed / kBlockM); } else if (attn_type == flash::AttnType::InvCausal || attn_type == flash::AttnType::Full) { // do nothing } + return {m_block_min, m_block_max}; } }; diff --git a/magi_attention/csrc/flexible_flash_attention/bwd_tile_scheduler.hpp b/magi_attention/csrc/flexible_flash_attention/bwd_tile_scheduler.hpp index c46235f3f..1b83d08e7 100644 --- a/magi_attention/csrc/flexible_flash_attention/bwd_tile_scheduler.hpp +++ b/magi_attention/csrc/flexible_flash_attention/bwd_tile_scheduler.hpp @@ -34,7 +34,8 @@ namespace flash { // Host side kernel arguments struct TileSchedulerArguments { - int const num_heads; + int const num_heads_q; // Number of Q heads + int const num_heads_kv; // Number of KV heads (for GQA) int const num_batches; int* const tile_count_semaphore = nullptr; int2* const ranges = nullptr; @@ -51,6 +52,7 @@ template < int NumMmaThreads = 2 * cutlass::NumThreadsPerWarpGroup, int NumProducerThreads = cutlass::NumThreadsPerWarp, bool WarpSpecialized = true, + bool PackGQA = false, bool Deterministic = false> class DynamicPersistentTileSchedulerBwd { using resv_barrier = cutlass::arch::ReservedNamedBarriers; @@ -68,6 +70,7 @@ class DynamicPersistentTileSchedulerBwd { // Device side kernel params struct 
Params { int num_heads; + int seqlen_scale_factor; // PackGQA: num_heads_q / num_heads_kv, otherwise 1 int num_batches; int* const tile_count_semaphore; int2* const ranges; @@ -78,10 +81,24 @@ class DynamicPersistentTileSchedulerBwd { }; static Params to_underlying_arguments(TileSchedulerArguments const& args) { + // PackGQA: seqlen_scale_factor = num_heads_q / num_heads_kv, otherwise 1 + int seqlen_scale_factor = !PackGQA ? 1 : (args.num_heads_q / args.num_heads_kv); + // PackGQA: num_heads = num_heads_kv, otherwise num_heads_q + int num_heads = !PackGQA ? args.num_heads_q : args.num_heads_kv; + assert(args.tile_count_semaphore != nullptr); - assert(args.num_heads < (1 << 16)); + assert(num_heads < (1 << 16)); int2* const ranges = args.merge_ranges ? args.merge_ranges : args.ranges; - return {args.num_heads, args.num_batches, args.tile_count_semaphore, ranges, args.merge_ranges, args.range_map, args.determin_conflict_state, args.unique_count}; + return { + num_heads, + seqlen_scale_factor, + args.num_batches, + args.tile_count_semaphore, + ranges, + args.merge_ranges, + args.range_map, + args.determin_conflict_state, + args.unique_count}; } static dim3 get_grid_shape(Params const& params, int num_sm) { @@ -135,7 +152,8 @@ class DynamicPersistentTileSchedulerBwd { return 0; int2 range = params.ranges[batch_idx]; int seqlen = batch_idx < actual_num_batches ? range.y - range.x : 0; - return batch_idx < actual_num_batches && lane < cutlass::NumThreadsPerWarp - 1 ? cute::ceil_div(seqlen, kBlock) : 0; + // PackGQA: seqlen needs to be multiplied by seqlen_scale_factor + return batch_idx < actual_num_batches && lane < cutlass::NumThreadsPerWarp - 1 ? 
cute::ceil_div(seqlen * params.seqlen_scale_factor, kBlock) : 0; }; int num_m_blocks = get_num_m_blocks(current_work.bidb); // Different for each lane @@ -207,7 +225,9 @@ class DynamicPersistentTileSchedulerBwd { // %d, group_end_tile = %d, m_blocks_in_group = %d, mh_block = %d, bidh = %d, block = %d\n", blockIdx.x, threadIdx.x, group_start_tile, batch_idx_in_group, // bidb, num_m_blocks, next_tile_idx, group_end_tile, m_blocks_in_group, mh_block, bidh, block); // } - + if (threadIdx.x == 0) { + printf("blockIdx.x = %d, threadIdx.x = %d, bidb = %d, bidh = %d, block = %d\n", blockIdx.x, threadIdx.x, bidb, bidh, block); + } if constexpr (!Deterministic) { return {next_tile_idx, block, bidh, bidb}; } else { diff --git a/magi_attention/csrc/flexible_flash_attention/flash_bwd_launch_template.h b/magi_attention/csrc/flexible_flash_attention/flash_bwd_launch_template.h index ca8e8fda7..793f97e8b 100644 --- a/magi_attention/csrc/flexible_flash_attention/flash_bwd_launch_template.h +++ b/magi_attention/csrc/flexible_flash_attention/flash_bwd_launch_template.h @@ -112,6 +112,7 @@ template < typename ElementDkv, bool Deterministic, bool SwapBwdQKLoop, + bool PackGQA = false, int Stages = 2, int Stages_dO = 2, int Stages_dS = 2, @@ -178,7 +179,8 @@ void run_flash_bwd(Flash_bwd_params& params, cudaStream_t stream) { CollectiveMainloop::NumMmaThreads, CollectiveMainloop::NumProducerThreads, /*WarpSpecialized=*/Arch >= 90, - Deterministic>; + /*PackGQA=*/PackGQA, + /*Deterministic=*/Deterministic>; using CollectiveEpilogue = flash::CollectiveEpilogueBwd< TileShape_MNK, ElementDkv, @@ -250,6 +252,7 @@ void run_flash_bwd(Flash_bwd_params& params, cudaStream_t stream) { }; typename flash::TileSchedulerArguments scheduler_args{/*num_heads_q=*/params.h_qo, + /*num_heads_kv=*/params.h_kv, /*num_batches=*/params.merge_batch_size, /*tile_count_semaphore=*/params.tile_count_semaphore, /*ranges=*/SwapBwdQKLoop ? 
a/magi_attention/csrc/flexible_flash_attention/flash_bwd_launch_template.h b/magi_attention/csrc/flexible_flash_attention/flash_bwd_launch_template.h index 793f97e8b..1bfe022a5 100644 --- a/magi_attention/csrc/flexible_flash_attention/flash_bwd_launch_template.h +++ b/magi_attention/csrc/flexible_flash_attention/flash_bwd_launch_template.h @@ -112,6 +112,7 @@ template < typename ElementDkv, bool Deterministic, bool SwapBwdQKLoop, + bool PackGQA = false, int Stages = 2, int Stages_dO = 2, int Stages_dS = 2, @@ -178,7 +179,8 @@ void run_flash_bwd(Flash_bwd_params& params, cudaStream_t stream) { CollectiveMainloop::NumMmaThreads, CollectiveMainloop::NumProducerThreads, /*WarpSpecialized=*/Arch >= 90, - Deterministic>; + /*PackGQA=*/PackGQA, + /*Deterministic=*/Deterministic>; using CollectiveEpilogue = flash::CollectiveEpilogueBwd< TileShape_MNK, ElementDkv, @@ -250,6 +252,7 @@ void run_flash_bwd(Flash_bwd_params& params, cudaStream_t stream) { }; typename flash::TileSchedulerArguments scheduler_args{/*num_heads_q=*/params.h_qo, + /*num_heads_kv=*/params.h_kv, /*num_batches=*/params.merge_batch_size, /*tile_count_semaphore=*/params.tile_count_semaphore, /*ranges=*/SwapBwdQKLoop ? 
-59,6 +59,8 @@ template < bool SdP_swapAB_, bool dKV_swapAB_, bool dQ_swapAB_, + bool PackGQA_, + int Qhead_per_khead_, int NumMmaWarpGroups = 2, int AtomLayoutMSdP = 1, int AtomLayoutNdKV = 2, @@ -86,6 +88,8 @@ struct CollectiveMainloopBwdSm90 { static constexpr bool dKV_swapAB = dKV_swapAB_; static constexpr bool dQ_swapAB = dQ_swapAB_; static constexpr bool SwapBwdQKLoop = SwapBwdQKLoop_; + static constexpr bool PackGQA = PackGQA_; + static constexpr int Qhead_per_khead = Qhead_per_khead_; static constexpr bool Q_dO_same_stages = kStages == kStages_dO; using MainloopPipeline = typename cutlass::PipelineTmaAsync; @@ -101,7 +105,7 @@ struct CollectiveMainloopBwdSm90 { static constexpr int kHeadDim = get<2>(TileShape_MNK{}); using SeqlenInfo_t = flash::DistributedSeqlenInfo; - using BlockMN_t = flash::BlockMN; + using BlockMN_t = flash::BlockMN; static_assert(NumMmaWarpGroups % AtomLayoutMSdP == 0); static_assert(NumMmaWarpGroups % AtomLayoutNdKV == 0); @@ -300,6 +304,18 @@ struct CollectiveMainloopBwdSm90 { using ShapeLSE = cute::Shape<_4, int32_t, int32_t>; // (4, seqlen_q, num_heads_q) using StrideLSE = cute::Stride<_1, _4, int64_t>; + // Packed shape/stride for Q and dO when PackGQA is enabled + using ShapeQPackedTMA = std::conditional_t< + !PackGQA, + ShapeQKV, + cute::Shape, int32_t>, int32_t, int32_t> // ((qhead_per_khead, seqlen), headdim, khead) + >; + using StrideQPackedTMA = std::conditional_t< + !PackGQA, + StrideQKV, + cute::Shape, _1, int64_t> // ((qhead_per_khead, seqlen), headdim, khead) + >; + using TMA_QdO = decltype(make_tma_copy_A_sm90( GmemTiledCopyQdO{}, make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, StrideQKV{}), @@ -307,6 +323,22 @@ struct CollectiveMainloopBwdSm90 { TileShape_MNK{}, ClusterShape{})); // mcast along N mode for this M load, if any + // Packed TMA for Q when PackGQA is enabled (only used when SwapBwdQKLoop, i.e. 
load Q once per m_block) + using TMA_Q_Packed = decltype(make_tma_copy( + cute::SM90_TMA_LOAD{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQPackedTMA{}, StrideQPackedTMA{}), + take<0, 2>(SmemLayoutQ{}), + select<0, 2>(TileShape_MNK{}), + _1{})); // no mcast for packed Q + + // Packed TMA for dO when PackGQA is enabled + using TMA_dO_Packed = decltype(make_tma_copy( + cute::SM90_TMA_LOAD{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQPackedTMA{}, StrideQPackedTMA{}), + take<0, 2>(SmemLayoutdO{}), + select<0, 2>(TileShape_MNK{}), + _1{})); // no mcast for packed dO + using TMA_K = decltype(make_tma_copy_B_sm90( GmemTiledCopyKV{}, make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, StrideQKV{}), @@ -446,6 +478,7 @@ struct CollectiveMainloopBwdSm90 { // Device side kernel params struct Params { ShapeQKV const shape_Q; + ShapeQPackedTMA const shape_Q_packed; // For PackGQA ShapeQKV const shape_K; ElementAccum* const ptr_dQ; // k for outer-loop and q for inner-loop ShapeQKV const shape_dQ; @@ -458,6 +491,8 @@ struct CollectiveMainloopBwdSm90 { StrideQKV const stride_dV; cutlass::FastDivmod qhead_per_khead_divmod; TMA_QdO tma_load_Q, tma_load_dO; + TMA_Q_Packed tma_load_Q_packed; // For PackGQA + TMA_dO_Packed tma_load_dO_packed; // For PackGQA TMA_K tma_load_K; TMA_V tma_load_V; TMA_add_dQ tma_add_dQ; // k for outer-loop and q for inner-loop @@ -494,6 +529,52 @@ struct CollectiveMainloopBwdSm90 { Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), args.shape_K, args.stride_V); TMA_V tma_load_V = make_tma_copy_B_sm90(GmemTiledCopyKV{}, mV, take<0, 2>(SmemLayoutV{}), TileShape_MNK{}, ClusterShape{}); + // Create packed shape/stride and TMA for PackGQA + auto const shape_Q_packed = cute::conditional_return( + args.shape_Q, + make_shape( + make_shape(cute::Int{}, get<0>(args.shape_Q)), // (qhead_per_khead, seqlen) + get<1>(args.shape_Q), // headdim + get<2>(args.shape_K) // numhead_k + )); + + auto const stride_Q_packed = 
cute::conditional_return( + args.stride_Q, + make_stride( + make_stride(get<2>(args.stride_Q), get<0>(args.stride_Q)), // (qhead_per_khead, seqlen) + get<1>(args.stride_Q), // headdim + get<2>(args.stride_Q) * Qhead_per_khead)); + + auto mQPacked = [&]() { + if constexpr (!PackGQA) { + return mQ; + } else { + return make_tensor(make_gmem_ptr(args.ptr_Q), make_layout(shape_Q_packed, stride_Q_packed)); + } + }(); + + auto mdOPacked = [&]() { + if constexpr (!PackGQA) { + return mdO; + } else { + return make_tensor( + make_gmem_ptr(args.ptr_dO), + make_layout( + make_shape( + make_shape(cute::Int{}, get<0>(args.shape_Q)), // (qhead_per_khead, seqlen) + get<1>(args.shape_Q), // headdim + get<2>(args.shape_K) // numhead_k + ), + make_stride( + make_stride(get<2>(args.stride_dO), get<0>(args.stride_dO)), // (qhead_per_khead, seqlen) + get<1>(args.stride_dO), // headdim + get<2>(args.stride_dO) * Qhead_per_khead))); + } + }(); + + TMA_Q_Packed tma_load_Q_packed = make_tma_copy(cute::SM90_TMA_LOAD{}, mQPacked, take<0, 2>(SmemLayoutQ{}), select<0, 2>(TileShape_MNK{}), _1{}); + TMA_dO_Packed tma_load_dO_packed = make_tma_copy(cute::SM90_TMA_LOAD{}, mdOPacked, take<0, 2>(SmemLayoutdO{}), select<0, 2>(TileShape_MNK{}), _1{}); + Tensor mdQ = make_tensor(make_gmem_ptr(args.ptr_dQ), args.shape_dQ, args.stride_dQ); TMA_add_dQ tma_add_dQ = make_tma_copy(GmemTiledCopydQaccum{}, mdQ, SmemLayoutdQaccumTMA{}, TileShape_dQaccum{}, _1{}); Tensor mdK = make_tensor(make_gmem_ptr(args.ptr_dK), args.shape_dK, args.stride_dK); @@ -512,6 +593,7 @@ struct CollectiveMainloopBwdSm90 { // (the original softmax_scale) at the end. 
return { args.shape_Q, + shape_Q_packed, args.shape_K, args.ptr_dQ, args.shape_dQ, @@ -525,6 +607,8 @@ struct CollectiveMainloopBwdSm90 { /*qhead_per_khead_divmod=*/cutlass::FastDivmod(cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K))), tma_load_Q, tma_load_dO, + tma_load_Q_packed, + tma_load_dO_packed, tma_load_K, tma_load_V, tma_add_dQ, @@ -549,8 +633,13 @@ struct CollectiveMainloopBwdSm90 { // Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance CUTLASS_DEVICE static void prefetch_tma_descriptors(Params const& params) { - cute::prefetch_tma_descriptor(params.tma_load_Q.get_tma_descriptor()); - cute::prefetch_tma_descriptor(params.tma_load_dO.get_tma_descriptor()); + if constexpr (!PackGQA) { + cute::prefetch_tma_descriptor(params.tma_load_Q.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_dO.get_tma_descriptor()); + } else { + cute::prefetch_tma_descriptor(params.tma_load_Q_packed.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_dO_packed.get_tma_descriptor()); + } cute::prefetch_tma_descriptor(params.tma_load_K.get_tma_descriptor()); cute::prefetch_tma_descriptor(params.tma_load_V.get_tma_descriptor()); } @@ -728,7 +817,8 @@ struct CollectiveMainloopBwdSm90 { static_assert(SwapBwdQKLoop, "load_with_loop_k() must be called when SwapBwdQKLoop is true"); int m_block = get<0>(block_coord), bidh = get<1>(block_coord), bidb = get<2>(block_coord); - int bidh_kv = params.qhead_per_khead_divmod.divide(bidh); + // For PackGQA, bidh is already the KV head index + int bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh; SeqlenInfo_t seqlen_info{bidb, params.q_ranges, params.k_ranges}; flash::AttnType attn_type = static_cast(params.attn_type_map ? 
params.attn_type_map[bidb] : 0); @@ -751,15 +841,32 @@ struct CollectiveMainloopBwdSm90 { auto [mcast_mask_kv, cluster_block_id_kv] = get_tma_multi_cast_meta(); // Prepare the TMA loads - Tensor mQ = params.tma_load_Q.get_tma_tensor(params.shape_Q)(_, _, bidh); // (seqlen_q, head_dim) - Tensor mdO = params.tma_load_dO.get_tma_tensor(params.shape_Q)(_, _, bidh); // (seqlen_q, head_dim) + // For PackGQA, use packed TMA and shape_Q_packed + Tensor mQ = [&]() { + if constexpr (!PackGQA) { + return params.tma_load_Q.get_tma_tensor(params.shape_Q)(_, _, bidh); // (seqlen_q, head_dim) + } else { + return params.tma_load_Q_packed.get_tma_tensor(params.shape_Q_packed)(_, _, bidh); // ((qhead_per_khead, seqlen_q), head_dim) + } + }(); + Tensor mdO = [&]() { + if constexpr (!PackGQA) { + return params.tma_load_dO.get_tma_tensor(params.shape_Q)(_, _, bidh); // (seqlen_q, head_dim) + } else { + return params.tma_load_dO_packed.get_tma_tensor(params.shape_Q_packed)(_, _, bidh); // ((qhead_per_khead, seqlen_q), head_dim) + } + }(); Tensor mK = params.tma_load_K.get_tma_tensor(params.shape_K)(_, _, bidh_kv); // (seqlen_kv, head_dim) Tensor mV = params.tma_load_V.get_tma_tensor(params.shape_K)(_, _, bidh_kv); // (seqlen_kv, head_dim) - Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE_log2), params.shape_LSE, params.stride_LSE_log2)(_, _, bidh); // (4, seqlen_q) - Tensor mdPsum = make_tensor(make_gmem_ptr(params.ptr_dPsum), params.shape_LSE, params.stride_dPsum)(_, _, bidh); // (4, seqlen_q) - - Tensor gQ = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mQ), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) - Tensor gdO = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mdO), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) + // For PackGQA, LSE/dPsum still use the original Q head index (need to multiply bidh by Qhead_per_khead to get the first Q head) + int bidh_q = !PackGQA ? 
bidh : bidh * Qhead_per_khead; + Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE_log2), params.shape_LSE, params.stride_LSE_log2)(_, _, bidh_q); // (4, seqlen_q) + Tensor mdPsum = make_tensor(make_gmem_ptr(params.ptr_dPsum), params.shape_LSE, params.stride_dPsum)(_, _, bidh_q); // (4, seqlen_q) + + // For PackGQA, offset needs to be multiplied by Qhead_per_khead + int offset_q_packed = !PackGQA ? seqlen_info.offset_q : seqlen_info.offset_q * Qhead_per_khead; + Tensor gQ = local_tile(domain_offset(make_coord(offset_q_packed, _0{}), mQ), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) + Tensor gdO = local_tile(domain_offset(make_coord(offset_q_packed, _0{}), mdO), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) Tensor gK = local_tile(domain_offset(make_coord(seqlen_info.offset_k, _0{}), mK), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) Tensor gV = local_tile(domain_offset(make_coord(seqlen_info.offset_k, _0{}), mV), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) @@ -769,12 +876,24 @@ struct CollectiveMainloopBwdSm90 { local_tile(cute::domain_offset(make_coord(_0{}, seqlen_info.offset_q), mdPsum), make_shape(_4{}, Int{}), make_coord(_0{}, m_block)); // (4, M) // NOTE: tma_partition doesn't handle position_independent_swizzle_tensor correctly, so we need to do it manually - auto block_tma_Q = params.tma_load_Q.get_slice(_0{}); + auto block_tma_Q = [&]() { + if constexpr (!PackGQA) { + return params.tma_load_Q.get_slice(_0{}); + } else { + return params.tma_load_Q_packed.get_slice(_0{}); + } + }(); Tensor tQgQ = group_modes<0, 3>(block_tma_Q.partition_S(gQ)); // (TMA) Tensor tQsQ = group_modes<0, 3>(block_tma_Q.partition_D(sQ)); // (TMA) // NOTE: tma_partition doesn't handle position_independent_swizzle_tensor correctly, so we need to do it manually - auto block_tma_dO = params.tma_load_dO.get_slice(_0{}); + auto block_tma_dO = [&]() { + if constexpr (!PackGQA) { + return 
params.tma_load_dO.get_slice(_0{}); + } else { + return params.tma_load_dO_packed.get_slice(_0{}); + } + }(); Tensor tdOgdO = group_modes<0, 3>(block_tma_dO.partition_S(gdO)); // (TMA) Tensor tdOsdO = group_modes<0, 3>(block_tma_dO.partition_D(sdO)); // (TMA) @@ -820,8 +939,13 @@ struct CollectiveMainloopBwdSm90 { auto& barrier_QdO = reinterpret_cast(shared_storage.pipelines.barrier_QdO); shared_storage.pipelines.barrier_QdO.arrive_and_expect_tx(TmaTransactionBytesQ + TmaTransactionBytesdO + TmaTransactionBytesLSE + TmaTransactionBytesdPsum); // REVIEW: why not add `TMA::CacheHintSm90::EVICT_FIRST` hint here ? - copy(params.tma_load_Q.with(barrier_QdO, /*mcast_mask=*/0), tQgQ, tQsQ); - copy(params.tma_load_dO.with(barrier_QdO, /*mcast_mask=*/0), tdOgdO, tdOsdO); + if constexpr (!PackGQA) { + copy(params.tma_load_Q.with(barrier_QdO, /*mcast_mask=*/0), tQgQ, tQsQ); + copy(params.tma_load_dO.with(barrier_QdO, /*mcast_mask=*/0), tdOgdO, tdOsdO); + } else { + copy(params.tma_load_Q_packed.with(barrier_QdO, /*mcast_mask=*/0), tQgQ, tQsQ); + copy(params.tma_load_dO_packed.with(barrier_QdO, /*mcast_mask=*/0), tdOgdO, tdOsdO); + } copy(bulk_copy.with(barrier_QdO), gLSE, sLSE); copy(bulk_copy.with(barrier_QdO), gdPsum, sdPsum); } @@ -1120,7 +1244,8 @@ struct CollectiveMainloopBwdSm90 { } int m_block = get<0>(block_coord), bidh = get<1>(block_coord), bidb = get<2>(block_coord); - int bidh_kv = params.qhead_per_khead_divmod.divide(bidh); + // For PackGQA, bidh is already the KV head index + int bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh; SeqlenInfo_t seqlen_info{bidb, params.q_ranges, params.k_ranges}; flash::AttnType attn_type = static_cast(params.attn_type_map ? 
params.attn_type_map[bidb] : 0); auto [n_block_min, n_block_max] = BlockMN_t::get_n_block_min_max(seqlen_info, m_block, bidb, attn_type); @@ -1804,7 +1929,9 @@ struct CollectiveMainloopBwdSm90 { } // Define mask lambda func - auto mask_fn = [&](auto& tSrS, int m_block) { mask.template apply(tSrS, m_block, n_block, attn_type, thread_idx, seqlen_q, seqlen_k); }; + auto mask_fn = [&](auto& tSrS, int m_block) { + mask.template apply(tSrS, m_block, n_block, attn_type, thread_idx, seqlen_q, seqlen_k); + }; // Apply backward steps CUTLASS_PRAGMA_NO_UNROLL @@ -2030,6 +2157,16 @@ struct CollectiveMainloopBwdSm90 { shared_storage.pipelines.barrier_QdO.wait(work_idx % 2); } + // DEBUG: + if (m_block == 0 && bidh == 1 && bidb == 1 && thread_idx == 0) { + printf("\n[DEBUG CUDA] sQ (PackGQA=%d, m_block=%d, bidh=%d, bidb=%d):\n", PackGQA, m_block, bidh, bidb); + cute::print_tensor(sQ); + printf("\n[DEBUG CUDA] sdO:\n"); + cute::print_tensor(sdO); + printf("\n"); + } + // } + // Copy LSE from shared memory to registers if constexpr (!ShuffleLSE) { cute::copy(tLSEsLSE, tLSErLSE); @@ -2487,7 +2624,9 @@ struct CollectiveMainloopBwdSm90 { } // Define mask lambda func - auto mask_fn = [&](auto& tSrS, int n_block) { mask.template apply(tSrS, m_block, n_block, attn_type, thread_idx, seqlen_q, seqlen_k); }; + auto mask_fn = [&](auto& tSrS, int n_block) { + mask.template apply(tSrS, m_block, n_block, attn_type, thread_idx, seqlen_q, seqlen_k); + }; // Apply backward steps // NOTE: only the last m block for the same batch needs to mask_lse From c07aff6fdb5d9aff62aee92fb4ae20ab432d5e56 Mon Sep 17 00:00:00 2001 From: shw Date: Tue, 3 Feb 2026 06:56:15 +0000 Subject: [PATCH 04/14] support bwd_epilogue for pack_gqa --- .../bwd_tile_scheduler.hpp | 6 +- .../flexible_flash_attention/epilogue_bwd.hpp | 121 ++++++++++++-- .../flash_bwd_launch_template.h | 4 +- .../mainloop_bwd_sm90_tma_gmma_ws.hpp | 85 ++++++++-- tests/test_attn/test_flex_flash_attn.py | 158 +++++++++++++++--- 5 files changed, 315 
insertions(+), 59 deletions(-) diff --git a/magi_attention/csrc/flexible_flash_attention/bwd_tile_scheduler.hpp b/magi_attention/csrc/flexible_flash_attention/bwd_tile_scheduler.hpp index 1b83d08e7..65183acd3 100644 --- a/magi_attention/csrc/flexible_flash_attention/bwd_tile_scheduler.hpp +++ b/magi_attention/csrc/flexible_flash_attention/bwd_tile_scheduler.hpp @@ -225,9 +225,9 @@ class DynamicPersistentTileSchedulerBwd { // %d, group_end_tile = %d, m_blocks_in_group = %d, mh_block = %d, bidh = %d, block = %d\n", blockIdx.x, threadIdx.x, group_start_tile, batch_idx_in_group, // bidb, num_m_blocks, next_tile_idx, group_end_tile, m_blocks_in_group, mh_block, bidh, block); // } - if (threadIdx.x == 0) { - printf("blockIdx.x = %d, threadIdx.x = %d, bidb = %d, bidh = %d, block = %d\n", blockIdx.x, threadIdx.x, bidb, bidh, block); - } + // if (threadIdx.x == 0) { + // printf("blockIdx.x = %d, threadIdx.x = %d, bidb = %d, bidh = %d, block = %d\n", blockIdx.x, threadIdx.x, bidb, bidh, block); + // } if constexpr (!Deterministic) { return {next_tile_idx, block, bidh, bidb}; } else { diff --git a/magi_attention/csrc/flexible_flash_attention/epilogue_bwd.hpp b/magi_attention/csrc/flexible_flash_attention/epilogue_bwd.hpp index 458446aa3..de3d139d2 100644 --- a/magi_attention/csrc/flexible_flash_attention/epilogue_bwd.hpp +++ b/magi_attention/csrc/flexible_flash_attention/epilogue_bwd.hpp @@ -48,7 +48,9 @@ template < int AtomLayoutNdKV = 2, bool DisableBwdDkvAtomicReduction_ = false, bool Deterministic_ = false, - bool SwapBwdQKLoop_ = false> + bool SwapBwdQKLoop_ = false, + bool PackGQA_ = false, + int Qhead_per_khead_ = 1> struct CollectiveEpilogueBwd { using TileShape_MNK = TileShape_MNK_; using Element = Element_; @@ -69,6 +71,8 @@ struct CollectiveEpilogueBwd { static constexpr bool Use_TMA = ArchTag::kMinComputeCapability >= 90; static constexpr bool Deterministic = Deterministic_; static constexpr bool SwapBwdQKLoop = SwapBwdQKLoop_; + static constexpr bool PackGQA = 
PackGQA_; + static constexpr int Qhead_per_khead = Qhead_per_khead_; // for non packgqa, Qhead_per_khead is always 1. static constexpr int NumEpilogueThreads = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup; static constexpr int AtomLayoutMdKV = NumMmaWarpGroups * (Use_TMA ? 1 : cutlass::NumWarpsPerWarpGroup) / AtomLayoutNdKV; @@ -150,6 +154,11 @@ struct CollectiveEpilogueBwd { using ShapedQKV = cute::Shape; // (seqlen, head_dim, num_heads) using StridedQKV = cute::Stride; + // Packed shape/stride for dQ when PackGQA is enabled + // ((Qhead_per_khead, seqlen_q), head_dim, nheads_kv) + using ShapedQPacked = std::conditional_t, int32_t>, int32_t, int32_t>>; + using StridedQPacked = std::conditional_t, _1, int64_t>>; + using TMA_dQ = std::conditional_t< Use_TMA, decltype(make_tma_copy( @@ -160,6 +169,17 @@ struct CollectiveEpilogueBwd { _1{})), // no mcast for dQ std::nullptr_t>; + // Packed TMA for dQ when PackGQA is enabled + using TMA_dQ_Packed = std::conditional_t< + Use_TMA && PackGQA, + decltype(make_tma_copy( + GmemTiledCopydQTMA{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapedQPacked{}, StridedQPacked{}), + SmemLayoutdQTMA{}, + select<0, 2>(TileShape_MNK{}), + _1{})), // no mcast for packed dQ + std::nullptr_t>; + using TMA_dKV = std::conditional_t< Use_TMA, decltype(make_tma_copy( @@ -192,7 +212,9 @@ struct CollectiveEpilogueBwd { struct Params { Element* ptr_dQ; // q for outer-loop and k for inner-loop ShapedQKV const shape_dQ; + ShapedQPacked const shape_dQ_packed; StridedQKV const stride_dQ; + StridedQPacked const stride_dQ_packed; Element* ptr_dK; // k for outer-loop and q for inner-loop ShapedQKV const shape_dK; StridedQKV const stride_dK; @@ -200,6 +222,7 @@ struct CollectiveEpilogueBwd { ShapedQKV const shape_dV; StridedQKV const stride_dV; TMA_dQ tma_store_dQ; // q for outer-loop and k for inner-loop + TMA_dQ_Packed tma_store_dQ_packed; // For PackGQA TMA_dKV tma_store_dK; // k for outer-loop and q for inner-loop TMA_dKV tma_store_dV; 
// k for outer-loop and q for inner-loop int2 const* q_ranges; @@ -214,6 +237,14 @@ struct CollectiveEpilogueBwd { Tensor mdK = make_tensor(make_gmem_ptr(args.ptr_dK), args.shape_dK, args.stride_dK); Tensor mdV = make_tensor(make_gmem_ptr(args.ptr_dV), args.shape_dV, args.stride_dV); + // Compute packed shape/stride for dQ when PackGQA is enabled + // shape_dQ_packed: ((Qhead_per_khead, seqlen_q), head_dim, nheads_kv) + // stride_dQ_packed: ((head_stride, seq_stride), 1, head_stride * Qhead_per_khead) + auto const shape_dQ_packed = cute::conditional_return( + args.shape_dQ, make_shape(make_shape(cute::Int{}, get<0>(args.shape_dQ)), get<1>(args.shape_dQ), args.num_heads_kv)); + auto const stride_dQ_packed = cute::conditional_return( + args.stride_dQ, make_stride(make_stride(get<2>(args.stride_dQ), get<0>(args.stride_dQ)), get<1>(args.stride_dQ), get<2>(args.stride_dQ) * Qhead_per_khead)); + TMA_dQ tma_store_dQ = [&] { if constexpr (Use_TMA) { return make_tma_copy(GmemTiledCopydQTMA{}, mdQ, SmemLayoutdQTMA{}, select<0, 2>(TileShape_MNK{}), _1{}); @@ -221,6 +252,16 @@ struct CollectiveEpilogueBwd { return nullptr; } }(); + + // Create packed TMA descriptor for dQ when PackGQA is enabled + TMA_dQ_Packed tma_store_dQ_packed = [&] { + if constexpr (Use_TMA && PackGQA) { + Tensor mdQ_packed = make_tensor(make_gmem_ptr(args.ptr_dQ), shape_dQ_packed, stride_dQ_packed); + return make_tma_copy(GmemTiledCopydQTMA{}, mdQ_packed, SmemLayoutdQTMA{}, select<0, 2>(TileShape_MNK{}), _1{}); + } else { + return nullptr; + } + }(); TMA_dKV tma_store_dK = [&] { if constexpr (Use_TMA) { return make_tma_copy(GmemTiledCopydKVTMA{}, mdK, SmemLayoutdKVTMA{}, select<1, 2>(TileShape_MNK{}), _1{}); @@ -235,10 +276,13 @@ struct CollectiveEpilogueBwd { return nullptr; } }(); + return { args.ptr_dQ, args.shape_dQ, + shape_dQ_packed, args.stride_dQ, + stride_dQ_packed, args.ptr_dK, args.shape_dK, args.stride_dK, @@ -246,6 +290,7 @@ struct CollectiveEpilogueBwd { args.shape_dV, args.stride_dV, 
tma_store_dQ, + tma_store_dQ_packed, tma_store_dK, tma_store_dV, args.q_ranges, @@ -260,7 +305,11 @@ struct CollectiveEpilogueBwd { static void prefetch_tma_descriptors(Params const& params) { if constexpr (Use_TMA) { if constexpr (SwapBwdQKLoop) { - cute::prefetch_tma_descriptor(params.tma_store_dQ.get_tma_descriptor()); + if constexpr (PackGQA) { + cute::prefetch_tma_descriptor(params.tma_store_dQ_packed.get_tma_descriptor()); + } else { + cute::prefetch_tma_descriptor(params.tma_store_dQ.get_tma_descriptor()); + } } else { cute::prefetch_tma_descriptor(params.tma_store_dK.get_tma_descriptor()); cute::prefetch_tma_descriptor(params.tma_store_dV.get_tma_descriptor()); @@ -469,8 +518,29 @@ struct CollectiveEpilogueBwd { // Get block coordinates for current job (tile) int m_block = get<0>(block_coord), bidh = get<1>(block_coord), bidb = get<2>(block_coord); + // #region agent log - DEBUG: print for all m_blocks when PackGQA=1 + // if (PackGQA && bidh == 0 && bidb == 0 && thread_idx == 0) { + // // Check first element of tdQrdQ to see if it's zero + // float first_val = static_cast(tdQrdQ(0)); + // printf("[DEBUG store_dq] PackGQA=%d, m_block=%d, bidh=%d, bidb=%d, tdQrdQ[0]=%.4f\n", + // PackGQA, m_block, bidh, bidb, first_val); + // } + // #endregion + + // if (m_block == 0 && bidh == 0 && bidb == 0 && thread_idx == 0) { + // printf("\n[DEBUG CUDA] store_dq(PackGQA=%d, m_block=%d, bidh=%d, bidb=%d):\n", PackGQA, m_block, bidh, bidb); + // printf("\n[DEBUG CUDA] tdQrdQ:\n"); + // cute::print_tensor(tdQrdQ); + // // printf("\n[DEBUG CUDA] sdQ:\n"); + // // cute::print_tensor(sdQ); + // // printf("\n[DEBUG CUDA] sdQt:\n"); + // // cute::print_tensor(sdQt); + // printf("\n"); + // } + // For PackGQA, bidh is already KV head index (scheduler uses num_heads_kv) + // For non-PackGQA, bidh is Q head index int bidh_idx_in_group; - int bidh_kv = params.qhead_per_khead_divmod.divmod(bidh_idx_in_group, bidh); + int bidh_kv = !PackGQA ? 
params.qhead_per_khead_divmod.divmod(bidh_idx_in_group, bidh) : bidh; Tensor sdQ = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dq.data()), SmemLayoutdQ{})); Tensor sdQt = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dq.data()), SmemLayoutdQt{})); @@ -503,19 +573,42 @@ struct CollectiveEpilogueBwd { BarrierManager::arrive(resv_barrier::EpilogueBarrier); SeqlenInfo_t seqlen_info{bidb, params.q_ranges, params.k_ranges}; - Tensor mdQ = params.tma_store_dQ.get_tma_tensor(params.shape_dK)(_, _, bidh); // (seqlen_q, head_dim) - Tensor gdQ = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mdQ), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) + int warp_idx_sync = warp_uniform(thread_idx / cutlass::NumThreadsPerWarp); - auto block_tma_dQ = params.tma_store_dQ.get_slice(_0{}); - Tensor tdQgdQ = block_tma_dQ.partition_D(gdQ); // (TMA, TMA_M, TMA_K) - Tensor tdQsdQ = block_tma_dQ.partition_S(sdQ); // (TMA, TMA_M, TMA_K) + // For PackGQA, use packed TMA descriptor and packed offset + // For non-PackGQA, use original TMA descriptor and offset + if constexpr (!PackGQA) { + Tensor mdQ = params.tma_store_dQ.get_tma_tensor(params.shape_dQ)(_, _, bidh); // (seqlen_q, head_dim) + Tensor gdQ = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mdQ), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) - int warp_idx_sync = warp_uniform(thread_idx / cutlass::NumThreadsPerWarp); - if (warp_idx_sync == NumEpilogueThreads / cutlass::NumThreadsPerWarp - 1) { - BarrierManager::sync(resv_barrier::EpilogueBarrier); - if (cute::elect_one_sync()) { - cute::copy(params.tma_store_dQ, tdQsdQ, tdQgdQ); - tma_store_arrive(); + auto block_tma_dQ = params.tma_store_dQ.get_slice(_0{}); + Tensor tdQgdQ = block_tma_dQ.partition_D(gdQ); // (TMA, TMA_M, TMA_K) + Tensor tdQsdQ = block_tma_dQ.partition_S(sdQ); // 
(TMA, TMA_M, TMA_K) + + if (warp_idx_sync == NumEpilogueThreads / cutlass::NumThreadsPerWarp - 1) { + BarrierManager::sync(resv_barrier::EpilogueBarrier); + if (cute::elect_one_sync()) { + cute::copy(params.tma_store_dQ, tdQsdQ, tdQgdQ); + tma_store_arrive(); + } + } + } else { + // For PackGQA: use packed TMA descriptor + // bidh is KV head index, offset_q needs to be scaled by Qhead_per_khead + Tensor mdQ_packed = params.tma_store_dQ_packed.get_tma_tensor(params.shape_dQ_packed)(_, _, bidh); // (seqlen_q * Qhead_per_khead, head_dim) + Tensor gdQ_packed = local_tile( + domain_offset(make_coord(seqlen_info.offset_q * Qhead_per_khead, _0{}), mdQ_packed), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) + + auto block_tma_dQ_packed = params.tma_store_dQ_packed.get_slice(_0{}); + Tensor tdQgdQ_packed = block_tma_dQ_packed.partition_D(gdQ_packed); // (TMA, TMA_M, TMA_K) + Tensor tdQsdQ_packed = block_tma_dQ_packed.partition_S(sdQ); // (TMA, TMA_M, TMA_K) + + if (warp_idx_sync == NumEpilogueThreads / cutlass::NumThreadsPerWarp - 1) { + BarrierManager::sync(resv_barrier::EpilogueBarrier); + if (cute::elect_one_sync()) { + cute::copy(params.tma_store_dQ_packed, tdQsdQ_packed, tdQgdQ_packed); + tma_store_arrive(); + } } } diff --git a/magi_attention/csrc/flexible_flash_attention/flash_bwd_launch_template.h b/magi_attention/csrc/flexible_flash_attention/flash_bwd_launch_template.h index 1bfe022a5..442eec9d6 100644 --- a/magi_attention/csrc/flexible_flash_attention/flash_bwd_launch_template.h +++ b/magi_attention/csrc/flexible_flash_attention/flash_bwd_launch_template.h @@ -197,7 +197,9 @@ void run_flash_bwd(Flash_bwd_params& params, cudaStream_t stream) { AtomLayoutNdKV, DisableBwdDkvAtomicReduction, Deterministic, - SwapBwdQKLoop>; + SwapBwdQKLoop, + /*PackGQA=*/PackGQA, + /*Qhead_per_khead=*/QheadPerKhead>; using AttnKernel = flash::enable_sm90_or_later>; typename CollectiveMainloop::Arguments mainloop_args{ diff --git 
a/magi_attention/csrc/flexible_flash_attention/mainloop_bwd_sm90_tma_gmma_ws.hpp b/magi_attention/csrc/flexible_flash_attention/mainloop_bwd_sm90_tma_gmma_ws.hpp index 1e5b5f0ca..deb3df021 100644 --- a/magi_attention/csrc/flexible_flash_attention/mainloop_bwd_sm90_tma_gmma_ws.hpp +++ b/magi_attention/csrc/flexible_flash_attention/mainloop_bwd_sm90_tma_gmma_ws.hpp @@ -304,6 +304,16 @@ struct CollectiveMainloopBwdSm90 { using ShapeLSE = cute::Shape<_4, int32_t, int32_t>; // (4, seqlen_q, num_heads_q) using StrideLSE = cute::Stride<_1, _4, int64_t>; + // Packed shape/stride for LSE and dPsum when PackGQA is enabled (only used when SwapBwdQKLoop) + using ShapeLSEPacked = std::conditional_t< + !PackGQA, + ShapeLSE, // (4, seqlen_q, num_heads_q) + cute::Shape<_4, cute::Shape, int32_t>, int32_t>>; // (4, (qhead_per_khead, seqlen_q), nheads_kv) + using StrideLSEPacked = std::conditional_t< + !PackGQA, + StrideLSE, // (1, 4, head_stride) + cute::Stride<_1, cute::Stride, int64_t>>; // (1, (head_stride, 4), head_stride * qhead_per_khead) + // Packed shape/stride for Q and dO when PackGQA is enabled using ShapeQPackedTMA = std::conditional_t< !PackGQA, @@ -503,6 +513,9 @@ struct CollectiveMainloopBwdSm90 { StrideLSE const stride_LSE_log2; float const* const ptr_dPsum; StrideLSE const stride_dPsum; + ShapeLSEPacked const shape_LSE_packed; // For PackGQA + StrideLSEPacked const stride_LSE_packed; // For PackGQA + StrideLSEPacked const stride_dPsum_packed; // For PackGQA float const softmax_scale; float const softmax_scale_log2; float const softcap_val; @@ -582,6 +595,27 @@ struct CollectiveMainloopBwdSm90 { Tensor mdV = make_tensor(make_gmem_ptr(args.ptr_dV), args.shape_dV, args.stride_dV); TMA_add_dKV tma_add_dV = make_tma_copy(GmemTiledCopydKVaccum{}, mdV, SmemLayoutdKVaccumTMA{}, TileShape_dKVaccum{}, _1{}); + // Create packed shape/stride for LSE and dPsum when PackGQA is enabled + auto const shape_LSE_packed = cute::conditional_return( + args.shape_LSE, + make_shape( + 
_4{}, + make_shape(cute::Int{}, get<1>(args.shape_LSE)), // (qhead_per_khead, seqlen_q) + get<2>(args.shape_K) // nheads_kv + )); + auto const stride_LSE_packed = cute::conditional_return( + args.stride_LSE_log2, + make_stride( + _1{}, + make_stride(get<2>(args.stride_LSE_log2), _4{}), // (head_stride, 4) + get<2>(args.stride_LSE_log2) * Qhead_per_khead)); + auto const stride_dPsum_packed = cute::conditional_return( + args.stride_dPsum, + make_stride( + _1{}, + make_stride(get<2>(args.stride_dPsum), _4{}), // (head_stride, 4) + get<2>(args.stride_dPsum) * Qhead_per_khead)); + // If there's tanh softcapping, we do tanh(scores * softmax_scale / softcap_val) * softcap_val. // Right after this, we multiply by log2(e) before applying exp2. // To reduce the number of instructions, we instead pre-multiply softmax_scale / softcap_val @@ -619,6 +653,9 @@ struct CollectiveMainloopBwdSm90 { args.stride_LSE_log2, args.ptr_dPsum, args.stride_dPsum, + shape_LSE_packed, + stride_LSE_packed, + stride_dPsum_packed, args.softmax_scale, /*softmax_scale_log2=*/!Has_softcap ? float(args.softmax_scale * M_LOG2E) : float(args.softcap_val * M_LOG2E), /*softcap_val=*/!Has_softcap ? 0.f : args.softmax_scale / args.softcap_val, @@ -858,10 +895,21 @@ struct CollectiveMainloopBwdSm90 { }(); Tensor mK = params.tma_load_K.get_tma_tensor(params.shape_K)(_, _, bidh_kv); // (seqlen_kv, head_dim) Tensor mV = params.tma_load_V.get_tma_tensor(params.shape_K)(_, _, bidh_kv); // (seqlen_kv, head_dim) - // For PackGQA, LSE/dPsum still use the original Q head index (need to multiply bidh by Qhead_per_khead to get the first Q head) - int bidh_q = !PackGQA ? 
bidh : bidh * Qhead_per_khead; - Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE_log2), params.shape_LSE, params.stride_LSE_log2)(_, _, bidh_q); // (4, seqlen_q) - Tensor mdPsum = make_tensor(make_gmem_ptr(params.ptr_dPsum), params.shape_LSE, params.stride_dPsum)(_, _, bidh_q); // (4, seqlen_q) + // For PackGQA, LSE/dPsum use packed shape/stride to correctly read data from multiple Q heads + auto mLSE = [&]() { + if constexpr (!PackGQA) { + return make_tensor(make_gmem_ptr(params.ptr_LSE_log2), params.shape_LSE, params.stride_LSE_log2)(_, _, bidh); // (4, seqlen_q) + } else { + return make_tensor(make_gmem_ptr(params.ptr_LSE_log2), params.shape_LSE_packed, params.stride_LSE_packed)(_, _, bidh); // (4, (qhead_per_khead, seqlen_q)) + } + }(); + auto mdPsum = [&]() { + if constexpr (!PackGQA) { + return make_tensor(make_gmem_ptr(params.ptr_dPsum), params.shape_LSE, params.stride_dPsum)(_, _, bidh); // (4, seqlen_q) + } else { + return make_tensor(make_gmem_ptr(params.ptr_dPsum), params.shape_LSE_packed, params.stride_dPsum_packed)(_, _, bidh); // (4, (qhead_per_khead, seqlen_q)) + } + }(); // For PackGQA, offset needs to be multiplied by Qhead_per_khead int offset_q_packed = !PackGQA ? 
seqlen_info.offset_q : seqlen_info.offset_q * Qhead_per_khead; @@ -870,10 +918,10 @@ struct CollectiveMainloopBwdSm90 { Tensor gK = local_tile(domain_offset(make_coord(seqlen_info.offset_k, _0{}), mK), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) Tensor gV = local_tile(domain_offset(make_coord(seqlen_info.offset_k, _0{}), mV), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) + // For PackGQA, LSE/dPsum also use packed offset to match Q/dO's packed access pattern auto bulk_copy = Copy_Traits{}; - Tensor gLSE = local_tile(cute::domain_offset(make_coord(_0{}, seqlen_info.offset_q), mLSE), make_shape(_4{}, Int{}), make_coord(_0{}, m_block)); // (4, M) - Tensor gdPsum = - local_tile(cute::domain_offset(make_coord(_0{}, seqlen_info.offset_q), mdPsum), make_shape(_4{}, Int{}), make_coord(_0{}, m_block)); // (4, M) + Tensor gLSE = local_tile(cute::domain_offset(make_coord(_0{}, offset_q_packed), mLSE), make_shape(_4{}, Int{}), make_coord(_0{}, m_block)); // (4, M) + Tensor gdPsum = local_tile(cute::domain_offset(make_coord(_0{}, offset_q_packed), mdPsum), make_shape(_4{}, Int{}), make_coord(_0{}, m_block)); // (4, M) // NOTE: tma_partition doesn't handle position_independent_swizzle_tensor correctly, so we need to do it manually auto block_tma_Q = [&]() { @@ -1978,7 +2026,10 @@ struct CollectiveMainloopBwdSm90 { int m_block = get<0>(block_coord), bidh = get<1>(block_coord), bidb = get<2>(block_coord); SeqlenInfo_t seqlen_info{bidb, params.q_ranges, params.k_ranges}; int const seqlen_q = seqlen_info.seqlen_q, seqlen_k = seqlen_info.seqlen_k; - bool const is_last_m_block_this_batch = seqlen_q - m_block * kBlockM <= kBlockM; + // For PackGQA, the packed seqlen_q is seqlen_q * Qhead_per_khead + int const seqlen_q_packed = !PackGQA ? 
seqlen_q : seqlen_q * Qhead_per_khead; + bool const is_last_m_block_this_batch = seqlen_q_packed - m_block * kBlockM <= kBlockM; + // bool const is_last_m_block_this_batch = seqlen_q - m_block * kBlockM <= kBlockM; flash::AttnType attn_type = static_cast(params.attn_type_map ? params.attn_type_map[bidb] : 0); auto [n_block_min, n_block_max] = BlockMN_t::get_n_block_min_max(seqlen_info, m_block, bidb, attn_type); @@ -2158,13 +2209,13 @@ struct CollectiveMainloopBwdSm90 { } // DEBUG: - if (m_block == 0 && bidh == 1 && bidb == 1 && thread_idx == 0) { - printf("\n[DEBUG CUDA] sQ (PackGQA=%d, m_block=%d, bidh=%d, bidb=%d):\n", PackGQA, m_block, bidh, bidb); - cute::print_tensor(sQ); - printf("\n[DEBUG CUDA] sdO:\n"); - cute::print_tensor(sdO); - printf("\n"); - } + // if (m_block == 0 && bidh == 1 && bidb == 1 && thread_idx == 0) { + // printf("\n[DEBUG CUDA] sQ (PackGQA=%d, m_block=%d, bidh=%d, bidb=%d):\n", PackGQA, m_block, bidh, bidb); + // cute::print_tensor(sQ); + // printf("\n[DEBUG CUDA] sdO:\n"); + // cute::print_tensor(sdO); + // printf("\n"); + // } // } // Copy LSE from shared memory to registers @@ -2258,7 +2309,9 @@ struct CollectiveMainloopBwdSm90 { Tensor t0ScS = thread0_mma.partition_C(cS); Tensor t0ScS_rowcol = make_tensor(t0ScS.data(), flash::convert_layout_acc_rowcol(t0ScS.layout())); int const thread_row_offset = get(tScS_rowcol(_0{}, _0{})); - int const seqlenq_row_limit = seqlen_q - m_block * kBlockM - thread_row_offset; + // For PackGQA, need to use seqlen_q_packed for physical row limit + int const seqlenq_row_limit = seqlen_q_packed - m_block * kBlockM - thread_row_offset; + // int const seqlenq_row_limit = seqlen_q - m_block * kBlockM - thread_row_offset; #pragma unroll for (int mi = 0; mi < size<0>(scores); ++mi) { diff --git a/tests/test_attn/test_flex_flash_attn.py b/tests/test_attn/test_flex_flash_attn.py index e5e2d3c6e..dd15252c3 100644 --- a/tests/test_attn/test_flex_flash_attn.py +++ b/tests/test_attn/test_flex_flash_attn.py @@ 
-1240,18 +1240,18 @@ def run_test_case( "num_heads_kv": 4, "head_dim": 128, }, - { - "name": "mha_nh1_hd64", - "num_heads_q": 1, - "num_heads_kv": 1, - "head_dim": 64, - }, - { - "name": "gqa_nhq4_nhkv2_hd64", - "num_heads_q": 4, - "num_heads_kv": 2, - "head_dim": 64, - }, + # { + # "name": "mha_nh1_hd64", + # "num_heads_q": 1, + # "num_heads_kv": 1, + # "head_dim": 64, + # }, + # { + # "name": "gqa_nhq4_nhkv2_hd64", + # "num_heads_q": 4, + # "num_heads_kv": 2, + # "head_dim": 64, + # }, ] @with_run_in_mp @@ -1504,17 +1504,17 @@ def run_test_case( ), "attn_type_map": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], }, - { - "name": "deterministic_sample", - "seqlen": 2500, - "q_ranges": AttnRanges.from_ranges( - [[i * 50, (i + 1) * 50] for i in range(50) for j in range(50)] - ), - "k_ranges": AttnRanges.from_ranges( - [[i * 50, (i + 1) * 50] for i in range(50)] * 50 - ), - "attn_type_map": [0, 1] * 1250, - }, + # { + # "name": "deterministic_sample", + # "seqlen": 2500, + # "q_ranges": AttnRanges.from_ranges( + # [[i * 50, (i + 1) * 50] for i in range(50) for j in range(50)] + # ), + # "k_ranges": AttnRanges.from_ranges( + # [[i * 50, (i + 1) * 50] for i in range(50)] * 50 + # ), + # "attn_type_map": [0, 1] * 1250, + # }, { "name": "sparse_attn_2k_with_same_k_ranges", "seqlen": 2048, @@ -1557,7 +1557,7 @@ def run_test_case( ], ) @parameterize("model_config", MODEL_CONFIGS) - @parameterize("dtype", [torch.float16, torch.bfloat16]) + @parameterize("dtype", [torch.bfloat16]) def test_ffa_simple( self, attn_mask_config: dict[str, Any], @@ -1602,6 +1602,11 @@ def test_ffa_simple( pack_gqa = ref_block_config["pack_gqa"] sparse_load = ref_block_config["sparse_load"] + pack_gqa = True + swap_bwd_qk_loop = True + auto_range_merge = False + deterministic = False + # skip invalid flag combinations if swap_bwd_qk_loop: # TODO: support auto_range_merge mode with swap_bwd_qk_loop @@ -1673,6 +1678,104 @@ def test_ffa_simple( }, ) + # seqlen = 128 + # num_heads_q = 8 + # num_heads_kv = 2 
+ # head_dim = 64 + # dtype = torch.float16 + + # q_ranges = AttnRanges.from_ranges([[0, seqlen]]) + # k_ranges = AttnRanges.from_ranges([[0, seqlen]]) + # attn_type_map = [0] + # q_ranges_tensor = q_ranges.to_tensor(device=self.device) + # k_ranges_tensor = k_ranges.to_tensor(device=self.device) + # attn_type_map_tensor = torch.tensor( + # attn_type_map, dtype=torch.int32, device=self.device + # ) + + # mask = make_attn_mask_from_ffa_args( + # q_ranges=q_ranges, + # k_ranges=k_ranges, + # attn_type_map=attn_type_map, + # total_seqlen_q=seqlen, + # total_seqlen_k=seqlen, + # device=self.device, + # ) + + # q = torch.randn( + # seqlen, num_heads_q, head_dim, + # dtype=dtype, device=self.device, requires_grad=True, + # ) + # k = torch.randn( + # seqlen, num_heads_kv, head_dim, + # dtype=dtype, device=self.device, requires_grad=True, + # ) + # v = torch.randn( + # seqlen, num_heads_kv, head_dim, + # dtype=dtype, device=self.device, requires_grad=True, + # ) + # do = torch.randn_like(q) + + # q_ref = q.clone().detach().requires_grad_(True) + # k_ref = k.clone().detach().requires_grad_(True) + # v_ref = v.clone().detach().requires_grad_(True) + # out_ref, lse_ref = ref_attn_func( + # q=q_ref, + # k=k_ref, + # v=v_ref, + # mask=mask, + # layout="thd", + # backend="sdpa", + # high_precision=True, + # return_lse=True, + # ) + # out_ref.backward(do) + # dq_ref = q_ref.grad.clone() + # dk_ref = k_ref.grad.clone() + # dv_ref = v_ref.grad.clone() + + # for pack_gqa in [False, True]: + # q2 = q.clone().detach().requires_grad_(True) + # k2 = k.clone().detach().requires_grad_(True) + # v2 = v.clone().detach().requires_grad_(True) + # o2, lse2 = flex_flash_attn_func( + # q=q2, + # k=k2, + # v=v2, + # q_ranges=q_ranges_tensor, + # k_ranges=k_ranges_tensor, + # attn_type_map=attn_type_map_tensor, + # pack_gqa=pack_gqa, + # swap_bwd_qk_loop=True, + # ) + # o2.backward(do) + + # case = f"swap_bwd_qk_loop=True, pack_gqa={pack_gqa}" + # assert_close( + # q2.grad, + # dq_ref, + # 
atol=EPSILON, + # rtol=0.2, + # test_case=f"{case} => dq", + # print_rank=-1, + # ) + # assert_close( + # k2.grad, + # dk_ref, + # atol=EPSILON, + # rtol=0.08, + # test_case=f"{case} => dk", + # print_rank=-1, + # ) + # assert_close( + # v2.grad, + # dv_ref, + # atol=EPSILON, + # rtol=0.05, + # test_case=f"{case} => dv", + # print_rank=-1, + # ) + @with_run_in_mp @parameterize("model_config", MODEL_CONFIGS) @parameterize( @@ -1738,7 +1841,7 @@ def test_ffa_simple( @parameterize( "num_pairs", [10, 100, 1000] ) # the max num of qk range pairs to generate - @parameterize("dtype", [torch.float16, torch.bfloat16]) + @parameterize("dtype", [torch.bfloat16]) @parameterize( "attn_type", [0, 1, 2, 3, 4] ) # 0 - 3 means attn type are all 0/1/2/3, 4 means random attn type. @@ -1802,6 +1905,11 @@ def test_ffa_random( pack_gqa = ref_block_config["pack_gqa"] sparse_load = ref_block_config["sparse_load"] + pack_gqa = True + swap_bwd_qk_loop = True + auto_range_merge = False + deterministic = False + # skip invalid flag combinations if swap_bwd_qk_loop: # TODO: support auto_range_merge mode with swap_bwd_qk_loop From 6a16e66160862c7467756689a340933a5886d276 Mon Sep 17 00:00:00 2001 From: shw Date: Tue, 3 Feb 2026 07:57:21 +0000 Subject: [PATCH 05/14] fix test_flex_flash_attn --- tests/test_attn/test_flex_flash_attn.py | 157 ++++-------------------- 1 file changed, 27 insertions(+), 130 deletions(-) diff --git a/tests/test_attn/test_flex_flash_attn.py b/tests/test_attn/test_flex_flash_attn.py index dd15252c3..efbb9b673 100644 --- a/tests/test_attn/test_flex_flash_attn.py +++ b/tests/test_attn/test_flex_flash_attn.py @@ -1231,8 +1231,8 @@ def run_test_case( { "name": "mha_nh8_hd128", "num_heads_q": 8, - "num_heads_kv": 8, - "head_dim": 128, + "num_heads_kv": 2, + "head_dim": 64, }, { "name": "gqa_nhq32_nhkv4_hd128", @@ -1240,18 +1240,18 @@ def run_test_case( "num_heads_kv": 4, "head_dim": 128, }, - # { - # "name": "mha_nh1_hd64", - # "num_heads_q": 1, - # "num_heads_kv": 1, - # 
"head_dim": 64, - # }, - # { - # "name": "gqa_nhq4_nhkv2_hd64", - # "num_heads_q": 4, - # "num_heads_kv": 2, - # "head_dim": 64, - # }, + { + "name": "mha_nh1_hd64", + "num_heads_q": 1, + "num_heads_kv": 1, + "head_dim": 64, + }, + { + "name": "gqa_nhq4_nhkv2_hd64", + "num_heads_q": 4, + "num_heads_kv": 2, + "head_dim": 64, + }, ] @with_run_in_mp @@ -1504,17 +1504,17 @@ def run_test_case( ), "attn_type_map": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], }, - # { - # "name": "deterministic_sample", - # "seqlen": 2500, - # "q_ranges": AttnRanges.from_ranges( - # [[i * 50, (i + 1) * 50] for i in range(50) for j in range(50)] - # ), - # "k_ranges": AttnRanges.from_ranges( - # [[i * 50, (i + 1) * 50] for i in range(50)] * 50 - # ), - # "attn_type_map": [0, 1] * 1250, - # }, + { + "name": "deterministic_sample", + "seqlen": 2500, + "q_ranges": AttnRanges.from_ranges( + [[i * 50, (i + 1) * 50] for i in range(50) for j in range(50)] + ), + "k_ranges": AttnRanges.from_ranges( + [[i * 50, (i + 1) * 50] for i in range(50)] * 50 + ), + "attn_type_map": [0, 1] * 1250, + }, { "name": "sparse_attn_2k_with_same_k_ranges", "seqlen": 2048, @@ -1557,7 +1557,7 @@ def run_test_case( ], ) @parameterize("model_config", MODEL_CONFIGS) - @parameterize("dtype", [torch.bfloat16]) + @parameterize("dtype", [torch.float16, torch.bfloat16]) def test_ffa_simple( self, attn_mask_config: dict[str, Any], @@ -1602,11 +1602,6 @@ def test_ffa_simple( pack_gqa = ref_block_config["pack_gqa"] sparse_load = ref_block_config["sparse_load"] - pack_gqa = True - swap_bwd_qk_loop = True - auto_range_merge = False - deterministic = False - # skip invalid flag combinations if swap_bwd_qk_loop: # TODO: support auto_range_merge mode with swap_bwd_qk_loop @@ -1678,104 +1673,6 @@ def test_ffa_simple( }, ) - # seqlen = 128 - # num_heads_q = 8 - # num_heads_kv = 2 - # head_dim = 64 - # dtype = torch.float16 - - # q_ranges = AttnRanges.from_ranges([[0, seqlen]]) - # k_ranges = AttnRanges.from_ranges([[0, seqlen]]) - # 
attn_type_map = [0] - # q_ranges_tensor = q_ranges.to_tensor(device=self.device) - # k_ranges_tensor = k_ranges.to_tensor(device=self.device) - # attn_type_map_tensor = torch.tensor( - # attn_type_map, dtype=torch.int32, device=self.device - # ) - - # mask = make_attn_mask_from_ffa_args( - # q_ranges=q_ranges, - # k_ranges=k_ranges, - # attn_type_map=attn_type_map, - # total_seqlen_q=seqlen, - # total_seqlen_k=seqlen, - # device=self.device, - # ) - - # q = torch.randn( - # seqlen, num_heads_q, head_dim, - # dtype=dtype, device=self.device, requires_grad=True, - # ) - # k = torch.randn( - # seqlen, num_heads_kv, head_dim, - # dtype=dtype, device=self.device, requires_grad=True, - # ) - # v = torch.randn( - # seqlen, num_heads_kv, head_dim, - # dtype=dtype, device=self.device, requires_grad=True, - # ) - # do = torch.randn_like(q) - - # q_ref = q.clone().detach().requires_grad_(True) - # k_ref = k.clone().detach().requires_grad_(True) - # v_ref = v.clone().detach().requires_grad_(True) - # out_ref, lse_ref = ref_attn_func( - # q=q_ref, - # k=k_ref, - # v=v_ref, - # mask=mask, - # layout="thd", - # backend="sdpa", - # high_precision=True, - # return_lse=True, - # ) - # out_ref.backward(do) - # dq_ref = q_ref.grad.clone() - # dk_ref = k_ref.grad.clone() - # dv_ref = v_ref.grad.clone() - - # for pack_gqa in [False, True]: - # q2 = q.clone().detach().requires_grad_(True) - # k2 = k.clone().detach().requires_grad_(True) - # v2 = v.clone().detach().requires_grad_(True) - # o2, lse2 = flex_flash_attn_func( - # q=q2, - # k=k2, - # v=v2, - # q_ranges=q_ranges_tensor, - # k_ranges=k_ranges_tensor, - # attn_type_map=attn_type_map_tensor, - # pack_gqa=pack_gqa, - # swap_bwd_qk_loop=True, - # ) - # o2.backward(do) - - # case = f"swap_bwd_qk_loop=True, pack_gqa={pack_gqa}" - # assert_close( - # q2.grad, - # dq_ref, - # atol=EPSILON, - # rtol=0.2, - # test_case=f"{case} => dq", - # print_rank=-1, - # ) - # assert_close( - # k2.grad, - # dk_ref, - # atol=EPSILON, - # rtol=0.08, - # 
test_case=f"{case} => dk", - # print_rank=-1, - # ) - # assert_close( - # v2.grad, - # dv_ref, - # atol=EPSILON, - # rtol=0.05, - # test_case=f"{case} => dv", - # print_rank=-1, - # ) - @with_run_in_mp @parameterize("model_config", MODEL_CONFIGS) @parameterize( @@ -1841,7 +1738,7 @@ def test_ffa_simple( @parameterize( "num_pairs", [10, 100, 1000] ) # the max num of qk range pairs to generate - @parameterize("dtype", [torch.bfloat16]) + @parameterize("dtype", [torch.float16, torch.bfloat16]) @parameterize( "attn_type", [0, 1, 2, 3, 4] ) # 0 - 3 means attn type are all 0/1/2/3, 4 means random attn type. From 9bae6ba6be8ea5a3c26cf8a85d739c2ce4bfe2af Mon Sep 17 00:00:00 2001 From: shw Date: Tue, 3 Feb 2026 11:58:53 +0000 Subject: [PATCH 06/14] format packgqa for ffa bwd --- .../bwd_tile_scheduler.hpp | 9 ------ .../flexible_flash_attention/epilogue_bwd.hpp | 19 ------------ .../mainloop_bwd_sm90_tma_gmma_ws.hpp | 31 +++---------------- tests/test_attn/test_flex_flash_attn.py | 9 ++---- 4 files changed, 6 insertions(+), 62 deletions(-) diff --git a/magi_attention/csrc/flexible_flash_attention/bwd_tile_scheduler.hpp b/magi_attention/csrc/flexible_flash_attention/bwd_tile_scheduler.hpp index 65183acd3..d907b6c1f 100644 --- a/magi_attention/csrc/flexible_flash_attention/bwd_tile_scheduler.hpp +++ b/magi_attention/csrc/flexible_flash_attention/bwd_tile_scheduler.hpp @@ -219,15 +219,6 @@ class DynamicPersistentTileSchedulerBwd { int bidh = mh_block / num_m_blocks; int block = mh_block - bidh * num_m_blocks; - /* DEBUG */ - // if (blockIdx.x <= 9 && threadIdx.x == 0) { - // printf("Before returning, blockIdx.x = %d, threadIdx.x = %d, group_start_tile = %d, batch_idx_in_group = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = - // %d, group_end_tile = %d, m_blocks_in_group = %d, mh_block = %d, bidh = %d, block = %d\n", blockIdx.x, threadIdx.x, group_start_tile, batch_idx_in_group, - // bidb, num_m_blocks, next_tile_idx, group_end_tile, m_blocks_in_group, mh_block, bidh, 
block); - // } - // if (threadIdx.x == 0) { - // printf("blockIdx.x = %d, threadIdx.x = %d, bidb = %d, bidh = %d, block = %d\n", blockIdx.x, threadIdx.x, bidb, bidh, block); - // } if constexpr (!Deterministic) { return {next_tile_idx, block, bidh, bidb}; } else { diff --git a/magi_attention/csrc/flexible_flash_attention/epilogue_bwd.hpp b/magi_attention/csrc/flexible_flash_attention/epilogue_bwd.hpp index de3d139d2..cd066fe84 100644 --- a/magi_attention/csrc/flexible_flash_attention/epilogue_bwd.hpp +++ b/magi_attention/csrc/flexible_flash_attention/epilogue_bwd.hpp @@ -518,25 +518,6 @@ struct CollectiveEpilogueBwd { // Get block coordinates for current job (tile) int m_block = get<0>(block_coord), bidh = get<1>(block_coord), bidb = get<2>(block_coord); - // #region agent log - DEBUG: print for all m_blocks when PackGQA=1 - // if (PackGQA && bidh == 0 && bidb == 0 && thread_idx == 0) { - // // Check first element of tdQrdQ to see if it's zero - // float first_val = static_cast(tdQrdQ(0)); - // printf("[DEBUG store_dq] PackGQA=%d, m_block=%d, bidh=%d, bidb=%d, tdQrdQ[0]=%.4f\n", - // PackGQA, m_block, bidh, bidb, first_val); - // } - // #endregion - - // if (m_block == 0 && bidh == 0 && bidb == 0 && thread_idx == 0) { - // printf("\n[DEBUG CUDA] store_dq(PackGQA=%d, m_block=%d, bidh=%d, bidb=%d):\n", PackGQA, m_block, bidh, bidb); - // printf("\n[DEBUG CUDA] tdQrdQ:\n"); - // cute::print_tensor(tdQrdQ); - // // printf("\n[DEBUG CUDA] sdQ:\n"); - // // cute::print_tensor(sdQ); - // // printf("\n[DEBUG CUDA] sdQt:\n"); - // // cute::print_tensor(sdQt); - // printf("\n"); - // } // For PackGQA, bidh is already KV head index (scheduler uses num_heads_kv) // For non-PackGQA, bidh is Q head index int bidh_idx_in_group; diff --git a/magi_attention/csrc/flexible_flash_attention/mainloop_bwd_sm90_tma_gmma_ws.hpp b/magi_attention/csrc/flexible_flash_attention/mainloop_bwd_sm90_tma_gmma_ws.hpp index deb3df021..127b0cded 100644 --- 
a/magi_attention/csrc/flexible_flash_attention/mainloop_bwd_sm90_tma_gmma_ws.hpp +++ b/magi_attention/csrc/flexible_flash_attention/mainloop_bwd_sm90_tma_gmma_ws.hpp @@ -2024,6 +2024,8 @@ struct CollectiveMainloopBwdSm90 { // Get block coordinates and seqlen info int m_block = get<0>(block_coord), bidh = get<1>(block_coord), bidb = get<2>(block_coord); + // For PackGQA, bidh is already the KV head index + int bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh; SeqlenInfo_t seqlen_info{bidb, params.q_ranges, params.k_ranges}; int const seqlen_q = seqlen_info.seqlen_q, seqlen_k = seqlen_info.seqlen_k; // For PackGQA, the packed seqlen_q is seqlen_q * Qhead_per_khead @@ -2153,12 +2155,12 @@ struct CollectiveMainloopBwdSm90 { // For the case where we do atomicAdd directly to gdKaccum,gdVaccum instead of using TMA Tensor mdKaccum = - make_tensor(make_gmem_ptr(reinterpret_cast(params.ptr_dK)), params.shape_dK, params.stride_dK)(_, _, bidh); // (seqlen_kv, head_dim) + make_tensor(make_gmem_ptr(reinterpret_cast(params.ptr_dK)), params.shape_dK, params.stride_dK)(_, _, bidh_kv); // (seqlen_kv, head_dim) Tensor gdKaccum_ = local_tile(domain_offset(make_coord(seqlen_info.offset_k, _0{}), mdKaccum), TileShape_dKVaccum{}, make_coord(_, _0{})); // (N, K, _) Tensor gdKaccum = cute::flat_divide(gdKaccum_, make_shape(Int{}, Int{})); // (N / WG, K, WG, 1, _) Tensor mdVaccum = - make_tensor(make_gmem_ptr(reinterpret_cast(params.ptr_dV)), params.shape_dV, params.stride_dV)(_, _, bidh); // (seqlen_kv, head_dim) + make_tensor(make_gmem_ptr(reinterpret_cast(params.ptr_dV)), params.shape_dV, params.stride_dV)(_, _, bidh_kv); // (seqlen_kv, head_dim) Tensor gdVaccum_ = local_tile(domain_offset(make_coord(seqlen_info.offset_k, _0{}), mdVaccum), TileShape_dKVaccum{}, make_coord(_, _0{})); // (N, K, _) Tensor gdVaccum = cute::flat_divide(gdVaccum_, make_shape(Int{}, Int{})); // (N / WG, K, WG, 1, _) @@ -2170,21 +2172,6 @@ struct CollectiveMainloopBwdSm90 { Tensor tdVgdV = 
block_tma_dV.partition_D(gdVaccum); // (TMA, TMA_N, TMA_K) Tensor tdVsdV = block_tma_dV.partition_S(sdV); // (TMA, TMA_N, TMA_K) - /* DEBUG */ - // if (thread_idx == 0 && bidh == 0 && m_block == 0){ - // printf("bidb: %d, offset_k: %d\n", bidb, seqlen_info.offset_k); - // printf("mdKaccum: "); print(mdKaccum); printf("\n"); - // printf("gdKaccum_: "); print(gdKaccum_); printf("\n"); - // printf("gdKaccum: "); print(gdKaccum); printf("\n"); - // printf("tdKgdK: "); print(tdKgdK); printf("\n"); - // printf("tdKsdK: "); print(tdKsdK); printf("\n"); - // printf("mdVaccum: "); print(mdVaccum); printf("\n"); - // printf("gdVaccum_: "); print(gdVaccum_); printf("\n"); - // printf("gdVaccum: "); print(gdVaccum); printf("\n"); - // printf("tdVgdV: "); print(tdVgdV); printf("\n"); - // printf("tdVsdV: "); print(tdVsdV); printf("\n"); - // } - // We can reuse r2s_thr_copy_dKVaccum for this partitioning Tensor tdKgdKaccum = r2s_thr_copy_dKVaccum.partition_D(gdKaccum); Tensor tdVgdVaccum = r2s_thr_copy_dKVaccum.partition_D(gdVaccum); @@ -2208,16 +2195,6 @@ struct CollectiveMainloopBwdSm90 { shared_storage.pipelines.barrier_QdO.wait(work_idx % 2); } - // DEBUG: - // if (m_block == 0 && bidh == 1 && bidb == 1 && thread_idx == 0) { - // printf("\n[DEBUG CUDA] sQ (PackGQA=%d, m_block=%d, bidh=%d, bidb=%d):\n", PackGQA, m_block, bidh, bidb); - // cute::print_tensor(sQ); - // printf("\n[DEBUG CUDA] sdO:\n"); - // cute::print_tensor(sdO); - // printf("\n"); - // } - // } - // Copy LSE from shared memory to registers if constexpr (!ShuffleLSE) { cute::copy(tLSEsLSE, tLSErLSE); diff --git a/tests/test_attn/test_flex_flash_attn.py b/tests/test_attn/test_flex_flash_attn.py index efbb9b673..e5e2d3c6e 100644 --- a/tests/test_attn/test_flex_flash_attn.py +++ b/tests/test_attn/test_flex_flash_attn.py @@ -1231,8 +1231,8 @@ def run_test_case( { "name": "mha_nh8_hd128", "num_heads_q": 8, - "num_heads_kv": 2, - "head_dim": 64, + "num_heads_kv": 8, + "head_dim": 128, }, { "name": 
"gqa_nhq32_nhkv4_hd128", @@ -1802,11 +1802,6 @@ def test_ffa_random( pack_gqa = ref_block_config["pack_gqa"] sparse_load = ref_block_config["sparse_load"] - pack_gqa = True - swap_bwd_qk_loop = True - auto_range_merge = False - deterministic = False - # skip invalid flag combinations if swap_bwd_qk_loop: # TODO: support auto_range_merge mode with swap_bwd_qk_loop From b880a06227b3950588e430e0223bc1c5142f17ce Mon Sep 17 00:00:00 2001 From: shw Date: Wed, 4 Feb 2026 03:17:22 +0000 Subject: [PATCH 07/14] add packgqa_swapab bench --- exps/attn/run_packgqa_swapab_bench.py | 359 +++++++++++++++++++ magi_attention/functional/flex_flash_attn.py | 1 + 2 files changed, 360 insertions(+) create mode 100644 exps/attn/run_packgqa_swapab_bench.py diff --git a/exps/attn/run_packgqa_swapab_bench.py b/exps/attn/run_packgqa_swapab_bench.py new file mode 100644 index 000000000..49750044e --- /dev/null +++ b/exps/attn/run_packgqa_swapab_bench.py @@ -0,0 +1,359 @@ +# Copyright (c) 2025-2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +from datetime import datetime + +import torch +from baselines.attn_impl import ffa_func +from baselines.utils import seed_everything +from einops import rearrange + +from magi_attention.benchmarking import Benchmark, do_bench_flops, perf_report +from magi_attention.utils.sparse_utils import ( + flatten_block_mask_to_kv_shape, + generate_block_sparse_pattern, + generate_ranges_from_block_mask, +) + +impls = ["ffa_packgqa_swapab", "ffa"] + +# actual seqlen +seqlens = [8192 * 8] + +# current block sparse attention always has low sparsity +sparsity_ratio = [0.05] +# ss = [k * 1024 for k in [4, 96, 128]] +ds = [128] +wds = ["fwd"] +attn_modes = ["GQA"] # MHA, GQA +nhqs = [64] +num_groups = [8, 16] +# small K block +# q_block_sizes = [64, 64, 64, 64, 64] +# k_block_sizes = [64, 32, 16, 8, 1] +# small Q block +# q_block_sizes = [64, 32, 16, 8] +# k_block_sizes = [64, 64, 64, 64] +q_block_sizes = [1] +k_block_sizes = [64] +# large Q block and K block +# q_block_sizes = [64, 128] +# k_block_sizes = [64, 128] + +assert len(q_block_sizes) == len(k_block_sizes) + +b = 1 + +dtype = torch.bfloat16 + +bias = None +softmax_scale = None +dropout_p = 0.0 +return_attn_probs = False + +quantiles = [0.5, 0.2, 0.8] + + +attn_flops_configs = [ + Benchmark( + x_names=["sparsity_ratio"], # Argument names to use as an x-axis for the plot. + x_vals=sparsity_ratio, # Different possible values for `x_name`. + x_log=False, # x axis is logarithmic. + line_arg="attn_impl", # Argument name whose value corresponds to a different line in the plot. + line_vals=impls, # Possible values for `line_arg`. + line_names=impls, # Label name for the lines. + styles=[ # Line styles. + ("green", "--"), + ("orange", "--"), + ("steelblue", "--"), + ("red", "-"), + ], + ylabel={ # Label name for the y-axis. 
+ "flops": "Throughout (TFLOPs/s)", + "mem": "Peak Memory (GB)", + }, + plot_name=( + f"block sparse attn-{wd} attn_mode-{attn_mode} " + f"{'n_head-' + str(nhq) if attn_mode == 'MHA' else f'n_head-{nhq}:{nhq // num_group}'}\n" + f"block_size-{q_block_size}:{k_block_size} seq_len {seqlen}" + ), + # Name for the plot. Used also as a file name for saving the plot. + args={ # Values for function arguments not in `x_names` and `y_name`. + "hd": hd, + "wd": wd, + "q_block_size": q_block_size, + "k_block_size": k_block_size, + "seqlen": seqlen, + "num_group": num_group, + "attn_mode": attn_mode, + "nhq": nhq, + }, + ) + for hd in ds + for wd in wds + for q_block_size, k_block_size in zip(q_block_sizes, k_block_sizes) + for seqlen in seqlens + for num_group in num_groups + for attn_mode in attn_modes + for nhq in nhqs +] + +seed_everything() + + +@perf_report(attn_flops_configs) +def sparse_attn_benchmark( + sparsity_ratio, + hd, + wd, + q_block_size, + k_block_size, + seqlen, + num_group, + attn_mode, + nhq, + attn_impl, +): + assert b == 1, "for now, we only supports b=1 for ffa" + is_attn_impl_support_this_mask = True + already_known_oom_before_run = False + + # --------- prepare arguments --------- # + + device = torch.cuda.current_device() + orig_seq_len_q = orig_seq_len_k = seqlen # fi square mask where sq == sk + block_m = q_block_size + block_n = k_block_size + + num_q_blocks_orig = orig_seq_len_q // block_m + num_kv_blocks_orig = orig_seq_len_k // block_n + orig_head = nhq + if attn_mode == "MHA": + nhk = nhq + elif attn_mode == "GQA": + nhk = nhq // num_group + + # prepare q, k ranges and calculate attn_flops + # for now, we only do bench for block sparse mask. 
+ # block_mask, scores = generate_global_block_sparse_pattern( + # orig_head, num_q_blocks_orig, num_kv_blocks_orig, sparsity_ratio, device="cuda" + # ) + + block_mask, scores = generate_block_sparse_pattern( + num_q_heads=nhq, + num_kv_heads=nhk, + num_q_blocks=num_q_blocks_orig, + num_kv_blocks=num_kv_blocks_orig, + sparsity=sparsity_ratio, + device="cuda", + ) + + attn_flops = 4 * orig_seq_len_q * orig_seq_len_k * orig_head * hd * sparsity_ratio + # --------- prepare data --------- # + # flash style shape: (b,s,h,d) + q = torch.randn( + b, orig_seq_len_q, nhq, hd, device=device, dtype=dtype, requires_grad=False + ) + k = torch.randn( + b, orig_seq_len_k, nhk, hd, device=device, dtype=dtype, requires_grad=False + ) + v = torch.randn( + b, orig_seq_len_k, nhk, hd, device=device, dtype=dtype, requires_grad=False + ) + + # ffa style shape: (t,h,d) + if attn_impl in ("ffa_packgqa_swapab", "ffa"): + h1 = nhk + q = rearrange(q, "b s (h1 h2) d -> (b h1 s) h2 d", h1=h1) + k = rearrange(k, "b s h d -> (b h s) 1 d") + v = rearrange(v, "b s h d -> (b h s) 1 d") + + if attn_impl in ("sdpa", "vsa", "vsa_triton", "flashinfer", "flex"): + q = rearrange(q, "b s h d -> b h s d") + k = rearrange(k, "b s h d -> b h s d") + v = rearrange(v, "b s h d -> b h s d") + + # --------- prepare grads --------- # + + if wd == "bwd": + attn_flops = attn_flops * 2.5 + do = torch.randn_like(q) + # require grads + [x.requires_grad_(True) for x in [q, k, v, do]] + + # --------- prepare func --------- # + # is_attn_impl_support_this_mask = block_sparse_available( + # attn_impl, nhq, nhk, q_block_size, k_block_size, wd + # ) + is_attn_impl_support_this_mask = True + if is_attn_impl_support_this_mask: + if attn_impl == "ffa_packgqa_swapab": + # flatten headdim for ffa cause + flat_block_sparse_mask = flatten_block_mask_to_kv_shape(block_mask) + q_ranges, k_ranges = generate_ranges_from_block_mask( + flat_block_sparse_mask, block_m, block_n + ) + + attn_type_map = torch.zeros(len(q_ranges), 
dtype=torch.int32, device="cuda") + + # TODO: we need to optimize choose_ref_block. + # You'd better set ref_blocks manually now + # ref_block_size = choose_ref_block((q_block_size, k_block_size)) + ref_block_size = (64, 64) + + def fn(): + return ffa_func( + q, + k, + v, + q_ranges=q_ranges, + k_ranges=k_ranges, + attn_type_map=attn_type_map, + auto_range_merge=True, # we should enable auto_range_merge for block sparse mask. + ref_block_size=ref_block_size, + pack_gqa=True, + swap_ab=True, + disable_fwd_atomic_reduction=True, + ) + + if wd == "bwd": + try: + o, *rest = fn() + except Exception as e: + if "CUDA out of memory" not in str(e): + print( + f"Error occured before running {attn_impl} with " + f"{q_block_size=}, {k_block_size=} " + f"when {seqlen=}, {hd=} during {wd}: {e=}" + ) + raise e + already_known_oom_before_run = True + + def fn(): + o.backward(do, retain_graph=True) + + elif attn_impl == "ffa": + # flatten headdim for ffa cause + flat_block_sparse_mask = flatten_block_mask_to_kv_shape(block_mask) + q_ranges, k_ranges = generate_ranges_from_block_mask( + flat_block_sparse_mask, block_m, block_n + ) + + attn_type_map = torch.zeros(len(q_ranges), dtype=torch.int32, device="cuda") + + # ref_block_size = choose_ref_block((q_block_size, k_block_size)) + ref_block_size = (64, 64) + + def fn(): + return ffa_func( + q, + k, + v, + q_ranges=q_ranges, + k_ranges=k_ranges, + attn_type_map=attn_type_map, + auto_range_merge=True, # we should enable auto_range_merge for block sparse mask. 
+ ref_block_size=ref_block_size, + pack_gqa=False, + disable_fwd_atomic_reduction=True, + ) + + if wd == "bwd": + try: + o, *rest = fn() + except Exception as e: + if "CUDA out of memory" not in str(e): + print( + f"Error occured before running {attn_impl} with " + f"{q_block_size=}, {k_block_size=} " + f"when {seqlen=}, {hd=} during {wd}: {e=}" + ) + raise e + already_known_oom_before_run = True + + def fn(): + o.backward(do, retain_graph=True) + + # --------- try do the bench --------- # + if is_attn_impl_support_this_mask: + if already_known_oom_before_run: + # -1 indicates oom + perf_dict = { + "flops": [-1, -1, -1], + "mem": [-1, -1, -1], + } + else: + try: + # disable mem test to only test flops for now + perf_dict = do_bench_flops( + fn, + quantiles=quantiles, + mem_record_mode="peak", + ) + + # --------- process report --------- # + + # post process the perf_dict + def ms_to_tflops(ms: float) -> float: + return attn_flops / ms * 1e-9 + + perf_dict["flops"] = list(map(ms_to_tflops, perf_dict["flops"])) + + # disable mem test + def gb(m): + return m / 1024**3 + + # perf_dict["mem"] = list(map(gb, perf_dict["mem"])) + except Exception as e: + if "CUDA out of memory" not in str(e): + print( + f"Error occured before running {attn_impl} with " + f"{q_block_size=}, {k_block_size=} " + f"when {seqlen=}, {hd=} during {wd}: {e=}" + ) + perf_dict = { + "flops": [-2, -2, -2], + "mem": [-2, -2, -2], + } + # raise e + # -1 indicates oom + perf_dict = { + "flops": [-1, -1, -1], + "mem": [-1, -1, -1], + } + print( + f"Error occured before running {attn_impl} with {q_block_size=}, {k_block_size=} " + f"when {seqlen=}, {hd=} during {wd}: {e=}" + ) + else: + # -2 indicates not support + perf_dict = { + "flops": [-2, -2, -2], + "mem": [-2, -2, -2], + } + + return perf_dict + + +if __name__ == "__main__": + script_dir = os.path.dirname(os.path.abspath(__file__)) + current_time = datetime.strftime(datetime.now(), "%Y-%m-%d_%H-%M-%S") + out_root = os.path.join( + script_dir, 
os.path.join("outs", f"bench_attn_{current_time}") + ) + + sparse_attn_benchmark.run( + print_data=True, print_value_on_bar=False, save_path=out_root + ) diff --git a/magi_attention/functional/flex_flash_attn.py b/magi_attention/functional/flex_flash_attn.py index c7cbcb465..d4677db4a 100644 --- a/magi_attention/functional/flex_flash_attn.py +++ b/magi_attention/functional/flex_flash_attn.py @@ -965,6 +965,7 @@ def flex_flash_attn_func( seqlen_q scenarios. This method significantly improves the computational efficiency of block sparse attention when seqlen_q is small. **Note:** kblockm must be divisible by qhead_per_khead(num_qhead // num_khead). + For backward pass, this flag is only enabled when swap_bwd_qk_loop is True. sparse_load (bool, optional): Whether to enable sparse load mode for optimizing performance when k_range size is small (< 64). From b3d89585ba88413ff5ece0f1c627e6c573ff35f2 Mon Sep 17 00:00:00 2001 From: shw Date: Wed, 4 Feb 2026 03:28:13 +0000 Subject: [PATCH 08/14] format --- magi_attention/csrc/flexible_flash_attention/block.h | 1 - 1 file changed, 1 deletion(-) diff --git a/magi_attention/csrc/flexible_flash_attention/block.h b/magi_attention/csrc/flexible_flash_attention/block.h index 8f9040215..ae8195e70 100644 --- a/magi_attention/csrc/flexible_flash_attention/block.h +++ b/magi_attention/csrc/flexible_flash_attention/block.h @@ -89,7 +89,6 @@ struct BlockMN { } else if (attn_type == flash::AttnType::InvCausal || attn_type == flash::AttnType::Full) { // do nothing } - return {m_block_min, m_block_max}; } }; From a7c21e85984d25545a6163f85b754714fba59a71 Mon Sep 17 00:00:00 2001 From: shw Date: Mon, 9 Feb 2026 09:43:18 +0000 Subject: [PATCH 09/14] add write back for log2lse and dpsum --- .../flash_bwd_preprocess_kernel.h | 40 +++++++++++++++---- 1 file changed, 33 insertions(+), 7 deletions(-) diff --git a/magi_attention/csrc/flexible_flash_attention/flash_bwd_preprocess_kernel.h 
b/magi_attention/csrc/flexible_flash_attention/flash_bwd_preprocess_kernel.h index 2c935f84b..f06dd098b 100644 --- a/magi_attention/csrc/flexible_flash_attention/flash_bwd_preprocess_kernel.h +++ b/magi_attention/csrc/flexible_flash_attention/flash_bwd_preprocess_kernel.h @@ -355,8 +355,11 @@ class FlashAttnBwdPreprocess { } // Initialize the output tensor for dPsum - Tensor mdPsum = make_tensor(make_gmem_ptr(params.ptr_dPsum), params.shape_dPsum, params.stride_dPsum)(0, _, bidh); // [sq,] - Tensor gdPsum = local_tile(cute::domain_offset(make_coord(0), mdPsum), Shape>{}, make_coord(m_block)); // (M,) + // Tensor mdPsum = make_tensor(make_gmem_ptr(params.ptr_dPsum), params.shape_dPsum, params.stride_dPsum)(0, _, bidh); // [sq,] + Tensor mdPsum = make_tensor(make_gmem_ptr(params.ptr_dPsum), params.shape_dPsum, params.stride_dPsum)(_, _, bidh); + + // Tensor gdPsum = local_tile(cute::domain_offset(make_coord(0), mdPsum), Shape>{}, make_coord(m_block)); // (M,) + Tensor gdPsum = local_tile(mdPsum, Shape, Int>{}, make_coord(_0{}, m_block)); // (4, M) // Store the reduced dPsum to output tensor // and also compute partial dsink if `Has_sink` @@ -365,8 +368,21 @@ class FlashAttnBwdPreprocess { #pragma unroll for (int mi = 0; mi < size(dP_sum); ++mi) { int const row_idx = get<0>(tOcO(_0{}, mi, _0{})); // row_idx - float dPsum_mi = row_idx < remain_valid_seqlen_q ? dP_sum(mi) : 0; // NOTE: we make OOB dPsum as 0 - gdPsum(row_idx) = dPsum_mi; // NOTE: the OOB part had better be set to 0 + float dPsum_mi = 0.0f; + if (row_idx < remain_valid_seqlen_q) { + dPsum_mi = dP_sum(mi); + } + + if (row_idx < kBlockM) { +#pragma unroll + for (int i = 0; i < 4; ++i) { + // write 0 to index 1 - 4 at dim 0 + gdPsum(i, row_idx) = (i == 0) ? dPsum_mi : 0.0f; + } + } + + // float dPsum_mi = row_idx < remain_valid_seqlen_q ? 
dP_sum(mi) : 0; // NOTE: we make OOB dPsum as 0 + // gdPsum(row_idx) = dPsum_mi; // NOTE: the OOB part had better be set to 0 // Compute `dsink = p_sink * -dPsum` if constexpr (Has_sink) { @@ -392,16 +408,26 @@ class FlashAttnBwdPreprocess { } // Initialize the output tensor for LSE_log2 - Tensor mLSElog2 = make_tensor(make_gmem_ptr(params.ptr_LSE_log2), params.shape_dPsum, params.stride_LSE_log2)(0, _, bidh); // [sq,] - Tensor gLSElog2 = local_tile(cute::domain_offset(make_coord(0), mLSElog2), Shape>{}, make_coord(m_block)); // (M,) + // Tensor mLSElog2 = make_tensor(make_gmem_ptr(params.ptr_LSE_log2), params.shape_dPsum, params.stride_LSE_log2)(0, _, bidh); // [sq,] + // Tensor gLSElog2 = local_tile(cute::domain_offset(make_coord(0), mLSElog2), Shape>{}, make_coord(m_block)); // (M,) + + Tensor mLSElog2 = make_tensor(make_gmem_ptr(params.ptr_LSE_log2), params.shape_dPsum, params.stride_LSE_log2)(_, _, bidh); + Tensor gLSElog2 = local_tile(mLSElog2, Shape, Int>{}, make_coord(_0{}, m_block)); // (4, M) // Scale and store the LSE to LSE_log2 // NOTE: we reset the valid `-inf` to 0 // to make the subsequent calculation of scores (exp(x - lse)) always correct // since when x = lse = `-inf`, the results would be NaN, but the expected result is `-inf`. // So instead, we reset `-inf` lse to 0 to make `-inf` - (`-inf`) become `-inf` - 0 = `-inf` if (is_valid_row) { - gLSElog2(thread_idx) = lse == -INFINITY ? 0.f : lse * float(M_LOG2E); + float lse_val = (lse == -INFINITY) ? 0.f : lse * float(M_LOG2E); + +#pragma unroll + for (int i = 0; i < 4; ++i) { + // write 0 to index 1 - 4 at dim 0 + gLSElog2(i, thread_idx) = (i == 0) ? lse_val : 0.0f; + } + // gLSElog2(thread_idx) = lse == -INFINITY ? 
0.f : lse * float(M_LOG2E); } // Reduce partial dsink along the seqlen_q dim From e63a96398ff9fc003508a8a0c754283212069b6e Mon Sep 17 00:00:00 2001 From: shw Date: Mon, 9 Feb 2026 11:17:38 +0000 Subject: [PATCH 10/14] add bwd packgqa with k outer loop template --- .../csrc/flexible_flash_attention/bwd_inst_template.jinja | 3 +-- .../csrc/flexible_flash_attention/bwd_tile_scheduler.hpp | 8 +++++--- .../flexible_flash_attention/flash_bwd_launch_template.h | 3 ++- .../mainloop_bwd_sm90_tma_gmma_ws.hpp | 7 ++++++- magi_attention/functional/flex_flash_attn.py | 4 ++-- 5 files changed, 16 insertions(+), 9 deletions(-) diff --git a/magi_attention/csrc/flexible_flash_attention/bwd_inst_template.jinja b/magi_attention/csrc/flexible_flash_attention/bwd_inst_template.jinja index 0ae5901f9..924bd7e63 100644 --- a/magi_attention/csrc/flexible_flash_attention/bwd_inst_template.jinja +++ b/magi_attention/csrc/flexible_flash_attention/bwd_inst_template.jinja @@ -25,8 +25,7 @@ static constexpr int kQheadPerKhead = {{ qhead_per_khead }}; // TODO: add support for RangeMerge and Deterministic mode when SwapBwdQKLoop is enabled static_assert(!kSwapBwdQKLoop || (!kRangeMerge && !kDeterministic), "Neither RangeMerge nor Deterministic mode is supported by now when SwapBwdQKLoop is enabled."); -// PackGQA is only supported when SwapBwdQKLoop is enabled -static_assert(!kPackGQA || kSwapBwdQKLoop, "PackGQA is only supported when SwapBwdQKLoop is enabled."); +// PackGQA is supported for both SwapBwdQKLoop=true and SwapBwdQKLoop=false // Runtime contract checks to ensure consistency with compile-time constraints static inline void _check_runtime_contract_bwd( diff --git a/magi_attention/csrc/flexible_flash_attention/bwd_tile_scheduler.hpp b/magi_attention/csrc/flexible_flash_attention/bwd_tile_scheduler.hpp index d907b6c1f..095158765 100644 --- a/magi_attention/csrc/flexible_flash_attention/bwd_tile_scheduler.hpp +++ b/magi_attention/csrc/flexible_flash_attention/bwd_tile_scheduler.hpp @@ 
-53,7 +53,8 @@ template < int NumProducerThreads = cutlass::NumThreadsPerWarp, bool WarpSpecialized = true, bool PackGQA = false, - bool Deterministic = false> + bool Deterministic = false, + bool SwapBwdQKLoop = false> class DynamicPersistentTileSchedulerBwd { using resv_barrier = cutlass::arch::ReservedNamedBarriers; static_assert(WarpSpecialized || NumProducerThreads == NumMmaThreads); @@ -81,8 +82,9 @@ class DynamicPersistentTileSchedulerBwd { }; static Params to_underlying_arguments(TileSchedulerArguments const& args) { - // PackGQA: seqlen_scale_factor = num_heads_q / num_heads_kv, otherwise 1 - int seqlen_scale_factor = !PackGQA ? 1 : (args.num_heads_q / args.num_heads_kv); + // Only scale when PackGQA && SwapBwdQKLoop (Q is outer loop, scaled by Qhead_per_khead) + // When !SwapBwdQKLoop, K is outer loop, no seqlen scaling needed + int seqlen_scale_factor = (PackGQA && SwapBwdQKLoop) ? (args.num_heads_q / args.num_heads_kv) : 1; // PackGQA: num_heads = num_heads_kv, otherwise num_heads_q int num_heads = !PackGQA ? 
args.num_heads_q : args.num_heads_kv; diff --git a/magi_attention/csrc/flexible_flash_attention/flash_bwd_launch_template.h b/magi_attention/csrc/flexible_flash_attention/flash_bwd_launch_template.h index 442eec9d6..260b00891 100644 --- a/magi_attention/csrc/flexible_flash_attention/flash_bwd_launch_template.h +++ b/magi_attention/csrc/flexible_flash_attention/flash_bwd_launch_template.h @@ -183,7 +183,8 @@ void run_flash_bwd(Flash_bwd_params& params, cudaStream_t stream) { CollectiveMainloop::NumProducerThreads, /*WarpSpecialized=*/Arch >= 90, /*PackGQA=*/PackGQA, - /*Deterministic=*/Deterministic>; + /*Deterministic=*/Deterministic, + /*SwapBwdQKLoop=*/SwapBwdQKLoop>; using CollectiveEpilogue = flash::CollectiveEpilogueBwd< TileShape_MNK, ElementDkv, diff --git a/magi_attention/csrc/flexible_flash_attention/mainloop_bwd_sm90_tma_gmma_ws.hpp b/magi_attention/csrc/flexible_flash_attention/mainloop_bwd_sm90_tma_gmma_ws.hpp index 127b0cded..598673bb0 100644 --- a/magi_attention/csrc/flexible_flash_attention/mainloop_bwd_sm90_tma_gmma_ws.hpp +++ b/magi_attention/csrc/flexible_flash_attention/mainloop_bwd_sm90_tma_gmma_ws.hpp @@ -696,12 +696,17 @@ struct CollectiveMainloopBwdSm90 { static_assert(!SwapBwdQKLoop, "load_with_loop_q() must be called when SwapBwdQKLoop is false"); int n_block = get<0>(block_coord), bidh = get<1>(block_coord), bidb = get<2>(block_coord); - int bidh_kv = params.qhead_per_khead_divmod.divide(bidh); + // int bidh_kv = params.qhead_per_khead_divmod.divide(bidh); + int bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh; SeqlenInfo_t seqlen_info{bidb, params.q_ranges, params.k_ranges}; flash::AttnType attn_type = static_cast(params.attn_type_map ? 
params.attn_type_map[bidb] : 0); auto [m_block_min, m_block_max] = BlockMN_t::get_m_block_min_max(seqlen_info, n_block, bidb, attn_type); + if (threadIdx.x == 0) { + printf( + "load_with_loop_q: n_block: %d, m_block_min: %d, m_block_max: %d, bidh: %d, bidh_kv: %d bidb: %d\n", n_block, m_block_min, m_block_max, bidh, bidh_kv, bidb); + } // It's possible to have m_block_max <= m_block_min, // where loading Q,dO might cause illegal memory access if (m_block_max <= m_block_min) { diff --git a/magi_attention/functional/flex_flash_attn.py b/magi_attention/functional/flex_flash_attn.py index 947da5227..5f62620dc 100644 --- a/magi_attention/functional/flex_flash_attn.py +++ b/magi_attention/functional/flex_flash_attn.py @@ -825,8 +825,8 @@ def backward(ctx, dout: torch.Tensor, *args): # pragma: no cover ) merge_k_ranges, bwd_kq_map, bwd_unique_count = None, None, None - # pack_gqa in backward is only enabled when both pack_gqa and swap_bwd_qk_loop are True - bwd_pack_gqa = ctx.pack_gqa and ctx.swap_bwd_qk_loop + # pack_gqa in backward + bwd_pack_gqa = ctx.pack_gqa dq, dk, dv, dsink = _flex_flash_attn_backward( dout=dout, From bdd5207d6c066f439cdb0dc52c49cfc986fe387c Mon Sep 17 00:00:00 2001 From: shw Date: Mon, 9 Feb 2026 12:20:21 +0000 Subject: [PATCH 11/14] complete load and dk dv store --- .../flexible_flash_attention/epilogue_bwd.hpp | 6 +- .../mainloop_bwd_sm90_tma_gmma_ws.hpp | 99 ++++++++++++++----- 2 files changed, 80 insertions(+), 25 deletions(-) diff --git a/magi_attention/csrc/flexible_flash_attention/epilogue_bwd.hpp b/magi_attention/csrc/flexible_flash_attention/epilogue_bwd.hpp index cd066fe84..ff44117de 100644 --- a/magi_attention/csrc/flexible_flash_attention/epilogue_bwd.hpp +++ b/magi_attention/csrc/flexible_flash_attention/epilogue_bwd.hpp @@ -394,7 +394,8 @@ struct CollectiveEpilogueBwd { int n_block = get<0>(block_coord), bidh = get<1>(block_coord), bidb = get<2>(block_coord); int bidh_idx_in_group; - int bidh_kv = 
params.qhead_per_khead_divmod.divmod(bidh_idx_in_group, bidh); + // When PackGQA, bidh from scheduler is already bidh_kv (scheduler uses num_heads_kv) + int bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divmod(bidh_idx_in_group, bidh) : bidh; Tensor sdK = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dk.data()), SmemLayoutdKV{})); Tensor sdV = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dv.data()), SmemLayoutdKV{})); Tensor sdKt = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dk.data()), SmemLayoutdKVt{})); @@ -614,7 +615,8 @@ struct CollectiveEpilogueBwd { int right_range_conflict_msg = get<4>(block_coord); int arrive_num = get<5>(block_coord); int bidh_idx_in_group; - int bidh_kv = params.qhead_per_khead_divmod.divmod(bidh_idx_in_group, bidh); + // When PackGQA, bidh from scheduler is already bidh_kv + int bidh_kv = !PackGQA ? 
params.qhead_per_khead_divmod.divmod(bidh_idx_in_group, bidh) : bidh; SeqlenInfo_t seqlen_info{bidb, params.q_ranges, params.k_ranges}; int offset_k = seqlen_info.offset_k; int qheads_per_kheads = params.qhead_per_khead_divmod; diff --git a/magi_attention/csrc/flexible_flash_attention/mainloop_bwd_sm90_tma_gmma_ws.hpp b/magi_attention/csrc/flexible_flash_attention/mainloop_bwd_sm90_tma_gmma_ws.hpp index 598673bb0..574474cc8 100644 --- a/magi_attention/csrc/flexible_flash_attention/mainloop_bwd_sm90_tma_gmma_ws.hpp +++ b/magi_attention/csrc/flexible_flash_attention/mainloop_bwd_sm90_tma_gmma_ws.hpp @@ -724,35 +724,72 @@ struct CollectiveMainloopBwdSm90 { auto [mcast_mask_qdo, cluster_block_id_qdo] = get_tma_multi_cast_meta(); // Prepare the TMA loads - Tensor mQ = params.tma_load_Q.get_tma_tensor(params.shape_Q)(_, _, bidh); // (seqlen_q, head_dim) - Tensor mdO = params.tma_load_dO.get_tma_tensor(params.shape_Q)(_, _, bidh); // (seqlen_q, head_dim) + // For PackGQA, use packed TMA and shape_Q_packed + Tensor mQ = [&]() { + if constexpr (!PackGQA) { + return params.tma_load_Q.get_tma_tensor(params.shape_Q)(_, _, bidh); // (seqlen_q, head_dim) + } else { + return params.tma_load_Q_packed.get_tma_tensor(params.shape_Q_packed)(_, _, bidh); // ((qhead_per_khead, seqlen_q), head_dim) + } + }(); + Tensor mdO = [&]() { + if constexpr (!PackGQA) { + return params.tma_load_dO.get_tma_tensor(params.shape_Q)(_, _, bidh); // (seqlen_q, head_dim) + } else { + return params.tma_load_dO_packed.get_tma_tensor(params.shape_Q_packed)(_, _, bidh); // ((qhead_per_khead, seqlen_q), head_dim) + } + }(); Tensor mK = params.tma_load_K.get_tma_tensor(params.shape_K)(_, _, bidh_kv); // (seqlen_kv, head_dim) Tensor mV = params.tma_load_V.get_tma_tensor(params.shape_K)(_, _, bidh_kv); // (seqlen_kv, head_dim) - Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE_log2), params.shape_LSE, params.stride_LSE_log2)(_, _, bidh); // (4, seqlen_q) - Tensor mdPsum = 
make_tensor(make_gmem_ptr(params.ptr_dPsum), params.shape_LSE, params.stride_dPsum)(_, _, bidh); // (4, seqlen_q) + // For PackGQA, LSE/dPsum use packed shape/stride to correctly read data from multiple Q heads + auto mLSE = [&]() { + if constexpr (!PackGQA) { + return make_tensor(make_gmem_ptr(params.ptr_LSE_log2), params.shape_LSE, params.stride_LSE_log2)(_, _, bidh); // (4, seqlen_q) + } else { + return make_tensor(make_gmem_ptr(params.ptr_LSE_log2), params.shape_LSE_packed, params.stride_LSE_packed)(_, _, bidh); // (4, (qhead_per_khead, seqlen_q)) + } + }(); + auto mdPsum = [&]() { + if constexpr (!PackGQA) { + return make_tensor(make_gmem_ptr(params.ptr_dPsum), params.shape_LSE, params.stride_dPsum)(_, _, bidh); // (4, seqlen_q) + } else { + return make_tensor(make_gmem_ptr(params.ptr_dPsum), params.shape_LSE_packed, params.stride_dPsum_packed)(_, _, bidh); // (4, (qhead_per_khead, seqlen_q)) + } + }(); - Tensor gQ = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mQ), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (M, K, _) - Tensor gdO = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mdO), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (M, K, _) + // For PackGQA, offset needs to be multiplied by Qhead_per_khead + int offset_q_packed = !PackGQA ? 
seqlen_info.offset_q : seqlen_info.offset_q * Qhead_per_khead; + Tensor gQ = local_tile(domain_offset(make_coord(offset_q_packed, _0{}), mQ), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (M, K, _) + Tensor gdO = local_tile(domain_offset(make_coord(offset_q_packed, _0{}), mdO), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (M, K, _) Tensor gK = local_tile(domain_offset(make_coord(seqlen_info.offset_k, _0{}), mK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (N, K) Tensor gV = local_tile(domain_offset(make_coord(seqlen_info.offset_k, _0{}), mV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (N, K) + // For PackGQA, LSE/dPsum also use packed offset to match Q/dO's packed access pattern auto bulk_copy = Copy_Traits{}; - Tensor gLSE = local_tile(cute::domain_offset(make_coord(_0{}, seqlen_info.offset_q), mLSE), make_shape(_4{}, Int{}), make_coord(_0{}, _)); // (4, M, _) - Tensor gdPsum = local_tile(cute::domain_offset(make_coord(_0{}, seqlen_info.offset_q), mdPsum), make_shape(_4{}, Int{}), make_coord(_0{}, _)); // (4, M, _) + Tensor gLSE = local_tile(cute::domain_offset(make_coord(_0{}, offset_q_packed), mLSE), make_shape(_4{}, Int{}), make_coord(_0{}, _)); // (4, M, _) + Tensor gdPsum = local_tile(cute::domain_offset(make_coord(_0{}, offset_q_packed), mdPsum), make_shape(_4{}, Int{}), make_coord(_0{}, _)); // (4, M, _) // NOTE: tma_partition doesn't handle position_independent_swizzle_tensor correctly, so we need to do it manually - auto block_tma_Q = params.tma_load_Q.get_slice(cluster_block_id_qdo); + auto block_tma_Q = [&]() { + if constexpr (!PackGQA) { + return params.tma_load_Q.get_slice(cluster_block_id_qdo); + } else { + return params.tma_load_Q_packed.get_slice(cluster_block_id_qdo); + } + }(); Tensor tQgQ = group_modes<0, 3>(block_tma_Q.partition_S(gQ)); Tensor tQsQ = group_modes<0, 3>(block_tma_Q.partition_D(sQ)); - // auto [tQgQ, tQsQ] = tma_partition(params.tma_load_Q, block_rank_in_cluster, Layout{}, 
- // group_modes<0, 2>(sQ), group_modes<0, 2>(gQ)); // (TMA, k), (TMA, PIPE) // NOTE: tma_partition doesn't handle position_independent_swizzle_tensor correctly, so we need to do it manually - auto block_tma_dO = params.tma_load_dO.get_slice(cluster_block_id_qdo); + auto block_tma_dO = [&]() { + if constexpr (!PackGQA) { + return params.tma_load_dO.get_slice(cluster_block_id_qdo); + } else { + return params.tma_load_dO_packed.get_slice(cluster_block_id_qdo); + } + }(); Tensor tdOgdO = group_modes<0, 3>(block_tma_dO.partition_S(gdO)); Tensor tdOsdO = group_modes<0, 3>(block_tma_dO.partition_D(sdO)); - // auto [tdOgdO, tdOsdO] = tma_partition(params.tma_load_dO, block_rank_in_cluster, Layout{}, - // group_modes<0, 2>(sdO), group_modes<0, 2>(gdO)); // (TMA, k), (TMA, PIPE) Tensor sK_x = make_tensor(sK.data(), make_layout(sK.layout(), Layout<_1>{})); Tensor gK_x = make_tensor(gK.data(), make_layout(gK.layout(), Layout<_1>{})); @@ -776,10 +813,17 @@ struct CollectiveMainloopBwdSm90 { // Define lambda funcs to load Q,dO,K,V,LSE,dPsum auto load_Q_LSE = [&, mcast_mask_qdo = mcast_mask_qdo](int const m_block_idx) { pipeline_q.producer_acquire(smem_pipe_write_q); - copy( - params.tma_load_Q.with(*pipeline_q.producer_get_barrier(smem_pipe_write_q), mcast_mask_qdo, TMA::CacheHintSm90::EVICT_LAST), - tQgQ(_, m_block_idx), - tQsQ(_, smem_pipe_write_q.index())); + if constexpr (!PackGQA) { + copy( + params.tma_load_Q.with(*pipeline_q.producer_get_barrier(smem_pipe_write_q), mcast_mask_qdo, TMA::CacheHintSm90::EVICT_LAST), + tQgQ(_, m_block_idx), + tQsQ(_, smem_pipe_write_q.index())); + } else { + copy( + params.tma_load_Q_packed.with(*pipeline_q.producer_get_barrier(smem_pipe_write_q), /*mcast_mask=*/0, TMA::CacheHintSm90::EVICT_LAST), + tQgQ(_, m_block_idx), + tQsQ(_, smem_pipe_write_q.index())); + } copy(bulk_copy.with(*pipeline_q.producer_get_barrier(smem_pipe_write_q)), gLSE(_, _, m_block_idx), sLSE(_, _, smem_pipe_write_q.index())); }; @@ -788,10 +832,17 @@ struct 
CollectiveMainloopBwdSm90 { // we can use the same pipeline state variable to reduce registers PipelineState_dO smem_pipe_write_do_cur = cute::conditional_return(smem_pipe_write_q, smem_pipe_write_do); pipeline_do.producer_acquire(smem_pipe_write_do_cur); - copy( - params.tma_load_dO.with(*pipeline_do.producer_get_barrier(smem_pipe_write_do_cur), mcast_mask_qdo, TMA::CacheHintSm90::EVICT_LAST), - tdOgdO(_, m_block_idx), - tdOsdO(_, smem_pipe_write_do_cur.index())); + if constexpr (!PackGQA) { + copy( + params.tma_load_dO.with(*pipeline_do.producer_get_barrier(smem_pipe_write_do_cur), mcast_mask_qdo, TMA::CacheHintSm90::EVICT_LAST), + tdOgdO(_, m_block_idx), + tdOsdO(_, smem_pipe_write_do_cur.index())); + } else { + copy( + params.tma_load_dO_packed.with(*pipeline_do.producer_get_barrier(smem_pipe_write_do_cur), /*mcast_mask=*/0, TMA::CacheHintSm90::EVICT_LAST), + tdOgdO(_, m_block_idx), + tdOsdO(_, smem_pipe_write_do_cur.index())); + } copy(bulk_copy.with(*pipeline_do.producer_get_barrier(smem_pipe_write_do_cur)), gdPsum(_, _, m_block_idx), sdPsum(_, _, smem_pipe_write_do_cur.index())); }; @@ -1443,6 +1494,8 @@ struct CollectiveMainloopBwdSm90 { int n_block = get<0>(block_coord), bidh = get<1>(block_coord), bidb = get<2>(block_coord); SeqlenInfo_t seqlen_info{bidb, params.q_ranges, params.k_ranges}; int const seqlen_q = seqlen_info.seqlen_q, seqlen_k = seqlen_info.seqlen_k; + // For PackGQA, the packed seqlen_q is seqlen_q * Qhead_per_khead + int const seqlen_q_packed = !PackGQA ? seqlen_q : seqlen_q * Qhead_per_khead; flash::AttnType attn_type = static_cast(params.attn_type_map ? 
params.attn_type_map[bidb] : 0); auto [m_block_min, m_block_max] = BlockMN_t::get_m_block_min_max(seqlen_info, n_block, bidb, attn_type); @@ -1680,7 +1733,7 @@ struct CollectiveMainloopBwdSm90 { Tensor t0ScS = thread0_mma.partition_C(cS); Tensor t0ScS_rowcol = make_tensor(t0ScS.data(), flash::convert_layout_acc_rowcol(t0ScS.layout())); int const thread_row_offset = get(tScS_rowcol(_0{}, _0{})); - int const seqlenq_row_limit = seqlen_q - m_block * kBlockM - thread_row_offset; + int const seqlenq_row_limit = seqlen_q_packed - m_block * kBlockM - thread_row_offset; #pragma unroll for (int mi = 0; mi < size<0>(scores); ++mi) { From 54b58634252a0e4e25e3fd878a7c884975f3e2e1 Mon Sep 17 00:00:00 2001 From: shw Date: Mon, 9 Feb 2026 13:09:59 +0000 Subject: [PATCH 12/14] support dq --- .../mainloop_bwd_sm90_tma_gmma_ws.hpp | 91 ++++++++++++++++--- 1 file changed, 79 insertions(+), 12 deletions(-) diff --git a/magi_attention/csrc/flexible_flash_attention/mainloop_bwd_sm90_tma_gmma_ws.hpp b/magi_attention/csrc/flexible_flash_attention/mainloop_bwd_sm90_tma_gmma_ws.hpp index 574474cc8..0e8bfa9c3 100644 --- a/magi_attention/csrc/flexible_flash_attention/mainloop_bwd_sm90_tma_gmma_ws.hpp +++ b/magi_attention/csrc/flexible_flash_attention/mainloop_bwd_sm90_tma_gmma_ws.hpp @@ -371,6 +371,14 @@ struct CollectiveMainloopBwdSm90 { TileShape_dQaccum{}, _1{})); // no mcast for partial dQ + // Packed TMA for dQ reduce-add when PackGQA is enabled (k for outer-loop and q for inner-loop) + using TMA_add_dQ_Packed = decltype(make_tma_copy( + GmemTiledCopydQaccum{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQPackedTMA{}, StrideQPackedTMA{}), + SmemLayoutdQaccumTMA{}, + TileShape_dQaccum{}, + _1{})); // no mcast for packed partial dQ + // q for outer-loop and k for inner-loop using TMA_add_dKV = decltype(make_tma_copy( GmemTiledCopydKVaccum{}, @@ -493,6 +501,8 @@ struct CollectiveMainloopBwdSm90 { ElementAccum* const ptr_dQ; // k for outer-loop and q for inner-loop ShapeQKV const 
shape_dQ; StrideQKV const stride_dQ; + ShapeQPackedTMA const shape_dQ_packed; // For PackGQA + StrideQPackedTMA const stride_dQ_packed; // For PackGQA ElementAccum* const ptr_dK; // q for outer-loop and k for inner-loop ShapeQKV const shape_dK; StrideQKV const stride_dK; @@ -506,6 +516,7 @@ struct CollectiveMainloopBwdSm90 { TMA_K tma_load_K; TMA_V tma_load_V; TMA_add_dQ tma_add_dQ; // k for outer-loop and q for inner-loop + TMA_add_dQ_Packed tma_add_dQ_packed; // For PackGQA, k for outer-loop and q for inner-loop TMA_add_dKV tma_add_dK; // q for outer-loop and k for inner-loop TMA_add_dKV tma_add_dV; // q for outer-loop and k for inner-loop float const* const ptr_LSE_log2; @@ -590,6 +601,30 @@ struct CollectiveMainloopBwdSm90 { Tensor mdQ = make_tensor(make_gmem_ptr(args.ptr_dQ), args.shape_dQ, args.stride_dQ); TMA_add_dQ tma_add_dQ = make_tma_copy(GmemTiledCopydQaccum{}, mdQ, SmemLayoutdQaccumTMA{}, TileShape_dQaccum{}, _1{}); + + // Create packed dQ shape/stride and TMA for PackGQA + auto const shape_dQ_packed = cute::conditional_return( + args.shape_dQ, + make_shape( + make_shape(cute::Int{}, get<0>(args.shape_dQ)), // (qhead_per_khead, seqlen) + get<1>(args.shape_dQ), // headdim + get<2>(args.shape_K) // numhead_kv + )); + auto const stride_dQ_packed = cute::conditional_return( + args.stride_dQ, + make_stride( + make_stride(get<2>(args.stride_dQ), get<0>(args.stride_dQ)), // (head_stride, seq_stride) + get<1>(args.stride_dQ), // headdim + get<2>(args.stride_dQ) * Qhead_per_khead)); + auto mdQPacked = [&]() { + if constexpr (!PackGQA) { + return mdQ; + } else { + return make_tensor(make_gmem_ptr(args.ptr_dQ), make_layout(shape_dQ_packed, stride_dQ_packed)); + } + }(); + TMA_add_dQ_Packed tma_add_dQ_packed = make_tma_copy(GmemTiledCopydQaccum{}, mdQPacked, SmemLayoutdQaccumTMA{}, TileShape_dQaccum{}, _1{}); + Tensor mdK = make_tensor(make_gmem_ptr(args.ptr_dK), args.shape_dK, args.stride_dK); TMA_add_dKV tma_add_dK = make_tma_copy(GmemTiledCopydKVaccum{}, mdK, 
SmemLayoutdKVaccumTMA{}, TileShape_dKVaccum{}, _1{}); Tensor mdV = make_tensor(make_gmem_ptr(args.ptr_dV), args.shape_dV, args.stride_dV); @@ -632,6 +667,8 @@ struct CollectiveMainloopBwdSm90 { args.ptr_dQ, args.shape_dQ, args.stride_dQ, + shape_dQ_packed, + stride_dQ_packed, args.ptr_dK, args.shape_dK, args.stride_dK, @@ -646,6 +683,7 @@ struct CollectiveMainloopBwdSm90 { tma_load_K, tma_load_V, tma_add_dQ, + tma_add_dQ_packed, tma_add_dK, tma_add_dV, args.ptr_LSE_log2, @@ -703,10 +741,6 @@ struct CollectiveMainloopBwdSm90 { flash::AttnType attn_type = static_cast(params.attn_type_map ? params.attn_type_map[bidb] : 0); auto [m_block_min, m_block_max] = BlockMN_t::get_m_block_min_max(seqlen_info, n_block, bidb, attn_type); - if (threadIdx.x == 0) { - printf( - "load_with_loop_q: n_block: %d, m_block_min: %d, m_block_max: %d, bidh: %d, bidh_kv: %d bidb: %d\n", n_block, m_block_min, m_block_max, bidh, bidh_kv, bidb); - } // It's possible to have m_block_max <= m_block_min, // where loading Q,dO might cause illegal memory access if (m_block_max <= m_block_min) { @@ -1228,7 +1262,8 @@ struct CollectiveMainloopBwdSm90 { } int const last_n_block = cute::ceil_div(seqlen_info.seqlen_k, kBlockN) - 1; - int const m_block_num = cute::ceil_div(seqlen_info.seqlen_q, kBlockM); + int const seqlen_q_packed = !PackGQA ? 
seqlen_info.seqlen_q : seqlen_info.seqlen_q * Qhead_per_khead; + int const m_block_num = cute::ceil_div(seqlen_q_packed, kBlockM); bool const lane_predicate = cute::elect_one_sync(); int const num_heads = get<2>(params.shape_Q); @@ -1275,10 +1310,24 @@ struct CollectiveMainloopBwdSm90 { } Tensor sdQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_dqacc.data()), SmemLayoutdQaccumTMA{}); - Tensor mdQaccum = params.tma_add_dQ.get_tma_tensor(params.shape_dQ)(_, _, bidh); // (seqlen_q, head_dim) - Tensor gdQaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mdQaccum), TileShape_dQaccum{}, make_coord(_, _0{})); // (M, K, _) + // When PackGQA, use packed TMA descriptor and packed offset (bidh is already bidh_kv) + auto mdQaccum = [&]() { + if constexpr (!PackGQA) { + return params.tma_add_dQ.get_tma_tensor(params.shape_dQ)(_, _, bidh); // (seqlen_q, head_dim) + } else { + return params.tma_add_dQ_packed.get_tma_tensor(params.shape_dQ_packed)(_, _, bidh); // ((qhead_per_khead, seqlen_q), head_dim) + } + }(); + int const offset_q_dQ = !PackGQA ? 
seqlen_info.offset_q : seqlen_info.offset_q * Qhead_per_khead; + Tensor gdQaccum = local_tile(domain_offset(make_coord(offset_q_dQ, _0{}), mdQaccum), TileShape_dQaccum{}, make_coord(_, _0{})); // (M, K, _) - auto block_tma_dQ = params.tma_add_dQ.get_slice(_0{}); + auto block_tma_dQ = [&]() { + if constexpr (!PackGQA) { + return params.tma_add_dQ.get_slice(_0{}); + } else { + return params.tma_add_dQ_packed.get_slice(_0{}); + } + }(); Tensor tdQgdQ = block_tma_dQ.partition_D(gdQaccum); // (TMA, TMA_M, TMA_K) Tensor tdQsdQ = block_tma_dQ.partition_S(sdQ); // (TMA, TMA_M, TMA_K) @@ -1305,7 +1354,11 @@ struct CollectiveMainloopBwdSm90 { if constexpr (Deterministic) { m_block_sync(m_block); } - cute::copy(params.tma_add_dQ, tdQsdQ, tdQgdQ(_, _, _, m_block)); + if constexpr (!PackGQA) { + cute::copy(params.tma_add_dQ, tdQsdQ, tdQgdQ(_, _, _, m_block)); + } else { + cute::copy(params.tma_add_dQ_packed, tdQsdQ, tdQgdQ(_, _, _, m_block)); + } tma_store_arrive(); tma_store_wait<0>(); if constexpr (Deterministic) { @@ -1605,11 +1658,25 @@ struct CollectiveMainloopBwdSm90 { }; // For the case where we do atomicAdd directly to gdQaccum instead of using TMA - Tensor mdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.ptr_dQ)), params.shape_dQ, params.stride_dQ)(_, _, bidh); // (seqlen_q, head_dim) - Tensor gdQaccum_ = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mdQaccum), TileShape_dQaccum{}, make_coord(_, _0{})); // (M, K, _) + // When PackGQA, use packed shape/stride and packed offset (bidh is already bidh_kv) + auto mdQaccum = [&]() { + if constexpr (!PackGQA) { + return make_tensor(make_gmem_ptr(reinterpret_cast(params.ptr_dQ)), params.shape_dQ, params.stride_dQ)(_, _, bidh); + } else { + return make_tensor(make_gmem_ptr(reinterpret_cast(params.ptr_dQ)), params.shape_dQ_packed, params.stride_dQ_packed)(_, _, bidh); + } + }(); + int const offset_q_dQ = !PackGQA ? 
seqlen_info.offset_q : seqlen_info.offset_q * Qhead_per_khead; + Tensor gdQaccum_ = local_tile(domain_offset(make_coord(offset_q_dQ, _0{}), mdQaccum), TileShape_dQaccum{}, make_coord(_, _0{})); // (M, K, _) Tensor gdQaccum = cute::flat_divide(gdQaccum_, make_shape(Int{}, Int{})); // (M / WG, K, WG, 1, _) - auto block_tma_dQ = params.tma_add_dQ.get_slice(_0{}); + auto block_tma_dQ = [&]() { + if constexpr (!PackGQA) { + return params.tma_add_dQ.get_slice(_0{}); + } else { + return params.tma_add_dQ_packed.get_slice(_0{}); + } + }(); Tensor tdQgdQ = block_tma_dQ.partition_D(gdQaccum); // (TMA, TMA_M, TMA_K) Tensor tdQsdQ = block_tma_dQ.partition_S(sdQ); // (TMA, TMA_M, TMA_K) From 10ab6e455b356ba30d2c54326e074e68c9c4fd2a Mon Sep 17 00:00:00 2001 From: shw Date: Mon, 9 Feb 2026 13:17:51 +0000 Subject: [PATCH 13/14] assert bwd packgqa and deterministic --- .../csrc/flexible_flash_attention/bwd_inst_template.jinja | 2 ++ magi_attention/functional/flex_flash_attn.py | 3 ++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/magi_attention/csrc/flexible_flash_attention/bwd_inst_template.jinja b/magi_attention/csrc/flexible_flash_attention/bwd_inst_template.jinja index 924bd7e63..361ec3ed6 100644 --- a/magi_attention/csrc/flexible_flash_attention/bwd_inst_template.jinja +++ b/magi_attention/csrc/flexible_flash_attention/bwd_inst_template.jinja @@ -26,6 +26,8 @@ static constexpr int kQheadPerKhead = {{ qhead_per_khead }}; static_assert(!kSwapBwdQKLoop || (!kRangeMerge && !kDeterministic), "Neither RangeMerge nor Deterministic mode is supported by now when SwapBwdQKLoop is enabled."); // PackGQA is supported for both SwapBwdQKLoop=true and SwapBwdQKLoop=false +// TODO: add support for Deterministic mode with PackGQA +static_assert(!kPackGQA || !kDeterministic, "Deterministic mode is not supported with PackGQA in backward pass."); // Runtime contract checks to ensure consistency with compile-time constraints static inline void _check_runtime_contract_bwd( diff --git 
a/magi_attention/functional/flex_flash_attn.py b/magi_attention/functional/flex_flash_attn.py index 5f62620dc..8b3ad3288 100644 --- a/magi_attention/functional/flex_flash_attn.py +++ b/magi_attention/functional/flex_flash_attn.py @@ -826,7 +826,8 @@ def backward(ctx, dout: torch.Tensor, *args): # pragma: no cover merge_k_ranges, bwd_kq_map, bwd_unique_count = None, None, None # pack_gqa in backward - bwd_pack_gqa = ctx.pack_gqa + # Deterministic mode is not yet supported with PackGQA + bwd_pack_gqa = ctx.pack_gqa and not ctx.deterministic dq, dk, dv, dsink = _flex_flash_attn_backward( dout=dout, From dd32e179085ed77820131bacad4e315b4db3a50f Mon Sep 17 00:00:00 2001 From: shw Date: Mon, 9 Feb 2026 13:42:59 +0000 Subject: [PATCH 14/14] add todo for bwd packgqa and deterministic --- tests/test_attn/test_flex_flash_attn.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_attn/test_flex_flash_attn.py b/tests/test_attn/test_flex_flash_attn.py index fba30f857..024fcbbd5 100644 --- a/tests/test_attn/test_flex_flash_attn.py +++ b/tests/test_attn/test_flex_flash_attn.py @@ -1661,6 +1661,8 @@ def test_ffa_simple( sparse_load = ref_block_config["sparse_load"] return_max_logits = bool(flag_comb.get("return_max_logits", False)) + # TODO: bwd pack_gqa combined with deterministic is not supported yet. + # skip invalid flag combinations if swap_bwd_qk_loop: # TODO: support auto_range_merge mode with swap_bwd_qk_loop @@ -1863,6 +1865,7 @@ def test_ffa_random( sparse_load = ref_block_config["sparse_load"] return_max_logits = bool(flag_comb.get("return_max_logits", False)) + # TODO: bwd pack_gqa combined with deterministic is not supported yet. # skip invalid flag combinations if swap_bwd_qk_loop: # TODO: support auto_range_merge mode with swap_bwd_qk_loop