From da5032eda3e81f7198fddfc0332682e30cd54326 Mon Sep 17 00:00:00 2001 From: Ying Zhang Date: Mon, 20 Jan 2025 00:46:51 -0800 Subject: [PATCH 1/2] draft --- hopper/epilogue_fwd.hpp | 40 ++++---- hopper/flash.h | 6 +- hopper/flash_api.cpp | 63 ++++++++---- hopper/flash_bwd_launch_template.h | 17 ++-- hopper/flash_bwd_preprocess_kernel.h | 2 +- hopper/flash_fwd_kernel_sm90.h | 13 ++- hopper/flash_fwd_launch_template.h | 15 +-- .../flash_bwd_hdim192_128_bf16_sm90.cu | 12 +++ .../flash_fwd_hdim192_128_bf16_sm90.cu | 9 ++ .../flash_fwd_hdimall_bf16_sm90.cu | 14 +-- hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp | 41 ++++---- hopper/setup.py | 11 ++- hopper/test_flash_attn.py | 8 +- hopper/test_util.py | 97 ++++++++++++++----- 14 files changed, 234 insertions(+), 114 deletions(-) create mode 100644 hopper/instantiations/flash_bwd_hdim192_128_bf16_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_bf16_sm90.cu diff --git a/hopper/epilogue_fwd.hpp b/hopper/epilogue_fwd.hpp index 0f916060260..598d14578eb 100644 --- a/hopper/epilogue_fwd.hpp +++ b/hopper/epilogue_fwd.hpp @@ -20,11 +20,11 @@ namespace flash { using namespace cute; -template struct CollectiveEpilogueFwd { - using TileShape_MNK = TileShape_MNK_; + using TileShape_MNK_VO = TileShape_MNK_VO_; using ClusterShape = ClusterShape_; using Element = Element_; using ArchTag = ArchTag_; @@ -37,18 +37,18 @@ struct CollectiveEpilogueFwd { static_assert(ArchTag::kMinComputeCapability >= 80); static_assert(ArchTag::kMinComputeCapability >= 90 || CUTE_STATIC_V(size(ClusterShape{})) == 1); - static constexpr int kBlockM = get<0>(TileShape_MNK{}); - static constexpr int kHeadDim = get<2>(TileShape_MNK{}); + static constexpr int kBlockM = get<0>(TileShape_MNK_VO{}); + static constexpr int kHeadDim_VO = get<2>(TileShape_MNK_VO{}); using GmemTiledCopyOTMA = cute::SM90_TMA_STORE; // These are for storing the output tensor without TMA (e.g., for setting output to zero) static constexpr int kGmemElemsPerStore = sizeof(cute::uint128_t) / sizeof(Element); - static_assert(kHeadDim % kGmemElemsPerStore == 0, "Headdim must be a multiple of kGmemElemsPerStore"); + static_assert(kHeadDim_VO % kGmemElemsPerStore == 0, "Headdim must be a multiple of kGmemElemsPerStore"); // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). We want each thread to have 4 elements // in the M direction and 2 elements in the K direction. In the case of PackGQA, this reduces the number of times // we need to call divmod. - static constexpr int kBytePerRow = kHeadDim * sizeof(Element); + static constexpr int kBytePerRow = kHeadDim_VO * sizeof(Element); static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element); // static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); // static constexpr int kGmemThreadsPerRow = cutlass::gcd(kHeadDim / kGmemElemsPerStore, NumEpilogueThreads); @@ -65,15 +65,15 @@ struct CollectiveEpilogueFwd { Layout>>{})); // Val layout, 8 or 16 vals per store using SmemLayoutAtomOTMA = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - using SmemLayoutOTMA = decltype(tile_to_shape(SmemLayoutAtomOTMA{}, select<0, 2>(TileShape_MNK{}))); + decltype(cute::get<0>(TileShape_MNK_VO{})), decltype(cute::get<2>(TileShape_MNK_VO{}))>()); + using SmemLayoutOTMA = decltype(tile_to_shape(SmemLayoutAtomOTMA{}, select<0, 2>(TileShape_MNK_VO{}))); static constexpr int kSwizzle = kBlockKGmem == 128 ? 4 : (kBlockKGmem == 64 ? 3 : (kBlockKGmem == 32 ? 2 : 1)); static constexpr int kSwizzleBase = sizeof(Element) == 4 ? 2 : (sizeof(Element) == 2 ? 3 : 4); using SmemLayoutAtomO = decltype( composition(Swizzle{}, Layout>, Stride, _1>>{})); - using SmemLayoutOSTS = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{}))); + using SmemLayoutOSTS = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK_VO{}))); using SmemLayoutO = std::conditional_t= 90, SmemLayoutOTMA, SmemLayoutOSTS>; using ShapeO = cute::Shape; // (seqlen_q, d, head, batch, num_splits) @@ -109,7 +109,7 @@ struct CollectiveEpilogueFwd { GmemTiledCopyOTMA{}, make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeO{}, StrideO{}), SmemLayoutOTMA{}, - select<0, 2>(TileShape_MNK{}), + select<0, 2>(TileShape_MNK_VO{}), _1{})), // no mcast for O std::nullptr_t >; @@ -148,7 +148,7 @@ struct CollectiveEpilogueFwd { Tensor mO = make_tensor(make_gmem_ptr(args.ptr_O), args.shape_O, args.stride_O); TMA_O tma_store_O = [&]{ if constexpr (Use_TMA_O) { - return make_tma_copy(GmemTiledCopyOTMA{}, mO, SmemLayoutO{}, select<0, 2>(TileShape_MNK{}), _1{}); // no mcast + return make_tma_copy(GmemTiledCopyOTMA{}, mO, SmemLayoutO{}, select<0, 2>(TileShape_MNK_VO{}), _1{}); // no mcast } else { return nullptr; } @@ -243,14 +243,14 @@ struct CollectiveEpilogueFwd { // Step 2: Write LSE from rmem -> gmem auto thread_mma = tiled_mma.get_thread_slice(thread_idx); // (MMA,MMA_M,MMA_K) - Tensor taccOcO = thread_mma.partition_C(cute::make_identity_tensor(select<0, 2>(TileShape_MNK{}))); + Tensor taccOcO = thread_mma.partition_C(cute::make_identity_tensor(select<0, 2>(TileShape_MNK_VO{}))); static_assert(decltype(size<0, 0>(taccOcO))::value == 2); static_assert(decltype(size<0, 1>(taccOcO))::value == 2); Tensor taccOcO_rowcol = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout())); Tensor taccOcO_row = taccOcO_rowcol(_, _0{}); CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M - using PackGQAt = flash::PackGQAManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumEpilogueThreads, Element>; + using PackGQAt = flash::PackGQAManager(TileShape_MNK_VO{}), get<2>(TileShape_MNK_VO{}), NumEpilogueThreads, Element>; Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE + offset_o * get<0>(params.stride_LSE)), params.shape_LSE_packed, params.stride_LSE_packed)(_, bidh, !is_varlen ? bidb : 0, split_idx); // if (thread_idx == 0) { printf("Before LSE write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d\n", m_block, bidh, bidb, split_idx, offset_o, seqlen_o); print(mLSE); printf("\n"); } @@ -267,7 +267,7 @@ struct CollectiveEpilogueFwd { // Step 3: Write O from smem -> gmem if constexpr (Use_TMA_O) { Tensor mO = params.tma_store_O.get_tma_tensor(params.shape_O)(_, _, bidh, bidb, split_idx); - Tensor gO = local_tile(mO, select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) + Tensor gO = local_tile(mO, select<0, 2>(TileShape_MNK_VO{}), make_coord(m_block, _0{})); // (M, K) auto block_tma_O = params.tma_store_O.get_slice(_0{}); Tensor tOgO = block_tma_O.partition_D(gO); // (TMA, TMA_M, TMA_K) Tensor tOsO = block_tma_O.partition_S(sO); // (TMA, TMA_M, TMA_K) @@ -287,7 +287,7 @@ struct CollectiveEpilogueFwd { } } else { // Don't use TMA in Varlen case since we don't want to overwrite the output of another sequence Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset_o * get<0>(params.stride_O)), params.shape_O_packed, params.stride_O_packed)(_, _, bidh, !is_varlen ? bidb : 0, split_idx); - Tensor gO = local_tile(mO, select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) + Tensor gO = local_tile(mO, select<0, 2>(TileShape_MNK_VO{}), make_coord(m_block, _0{})); // (M, K) // if (thread_idx == 0) { printf("Before O write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d, mO_addr = %p, addr diff = %d\n", m_block, bidh, bidb, split_idx, offset_o, seqlen_o, mO.data(), reinterpret_cast(&mO(0)) - reinterpret_cast(params.ptr_O)); } if constexpr (Use_smem) { GmemTiledCopyO gmem_tiled_copy_O; @@ -305,7 +305,7 @@ struct CollectiveEpilogueFwd { } if constexpr (!PackGQA) { // (BLK_M,BLK_K) -> (blk_m,blk_k) - Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 2>(TileShape_MNK{}))); + Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 2>(TileShape_MNK_VO{}))); Tensor tOpO = make_tensor(make_shape(size<2>(tOsO))); #pragma unroll for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); } @@ -361,7 +361,7 @@ struct CollectiveEpilogueFwd { int thread_idx, cute::tuple const& block_coord ) { - static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr int kBlockM = get<0>(TileShape_MNK_VO{}); auto [m_block, bidh, bidb, split_idx] = block_coord; flash::SeqlenInfo seqlen_info{bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused}; bool const is_varlen = Varlen && params.cu_seqlens; @@ -391,12 +391,12 @@ struct CollectiveEpilogueFwd { GmemTiledCopyO gmem_tiled_copy_O; auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); - Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 2>(TileShape_MNK{}))); + Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 2>(TileShape_MNK_VO{}))); if constexpr (!PackGQA) { Tensor tOpO = make_tensor(make_shape(size<2>(tOcO))); #pragma unroll for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); } - Tensor gO = local_tile(mO, select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) + Tensor gO = local_tile(mO, select<0, 2>(TileShape_MNK_VO{}), make_coord(m_block, _0{})); // (M, K) Tensor tOgO = gmem_thr_copy_O.partition_D(gO); Tensor tOrO = make_fragment_like(tOgO); cute::clear(tOrO); @@ -406,7 +406,7 @@ struct CollectiveEpilogueFwd { ); } else { // If PackGQA, we split the work of compute O_ptr among threads in the same row - using PackGQAt = flash::PackGQAManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumEpilogueThreads, Element>; + using PackGQAt = flash::PackGQAManager(TileShape_MNK_VO{}), get<2>(TileShape_MNK_VO{}), NumEpilogueThreads, Element>; Tensor tOrO = make_tensor(make_shape(Shape<_1, Int>{}, size<1>(tOcO), size<2>(tOcO))); cute::clear(tOrO); PackGQAt::store_O(mO, tOrO, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block); diff --git a/hopper/flash.h b/hopper/flash.h index 4559a1352e4..1c27f2d833f 100644 --- a/hopper/flash.h +++ b/hopper/flash.h @@ -62,7 +62,7 @@ struct Flash_fwd_params : public Qkv_params { index_t v_descale_head_stride; // The dimensions. - int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim; + int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim, d_vo, d_vo_rounded; int total_q, total_k, total_knew; int b_k; // When having KV cache and with cache_batch_idx, K & V might have larger batch size than Q @@ -197,9 +197,9 @@ struct Flash_bwd_params : public Flash_fwd_params { //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream); -template +template void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 82643d9fff4..b5c148d1cbd 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -70,6 +70,8 @@ void set_params_fprop(Flash_fwd_params ¶ms, const size_t h_k, const size_t d, const size_t d_rounded, + const size_t d_vo, + const size_t d_vo_rounded, // device pointers const at::Tensor q, const at::Tensor k, @@ -136,6 +138,8 @@ void set_params_fprop(Flash_fwd_params ¶ms, params.seqlen_k_rounded = seqlen_k_rounded; params.d = d; params.d_rounded = d_rounded; + params.d_vo = d_vo; + params.d_vo_rounded = d_vo_rounded; // Set the different scale values. params.scale_softmax = softmax_scale; @@ -184,6 +188,8 @@ void set_params_dgrad(Flash_bwd_params ¶ms, const size_t h_k, const size_t d, const size_t d_rounded, + const size_t d_vo, + const size_t d_vo_rounded, // device pointers const at::Tensor q, const at::Tensor k, @@ -211,7 +217,7 @@ void set_params_dgrad(Flash_bwd_params ¶ms, int const sm_margin=0) { set_params_fprop(params, - b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded, + b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded, d_vo, d_vo_rounded, q, k, v, out, cu_seqlens_q_d, cu_seqlens_k_d, @@ -280,7 +286,13 @@ void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { if (params.d <= 128) { return run_mha_fwd_(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM192 - if (params.d <= 192) { return run_mha_fwd_(params, stream); } + if (params.d <= 192) { + if (params.d_vo == 128) { + return run_mha_fwd_(params, stream); + } else { + return run_mha_fwd_(params, stream); + } + } #endif #ifndef FLASHATTENTION_DISABLE_HDIM256 if (params.d <= 256) { return run_mha_fwd_(params, stream); } @@ -551,6 +563,7 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq int total_q = !is_varlen_q ? batch_size * sizes[1] : sizes[0]; int num_heads = q.size(-2); int const head_size = q.size(-1); + int const head_size_vo = v.size(-1); int const max_num_pages_per_seq = !paged_KV ? 0 : page_table.size(1); int const num_pages = !paged_KV ? 0 : k.size(0); int const page_size = !paged_KV ? 1 : k.size(1); @@ -583,15 +596,15 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq if (!paged_KV) { if (!is_varlen_k) { CHECK_SHAPE(k, batch_size_k, seqlen_k, num_heads_k, head_size); - CHECK_SHAPE(v, batch_size_k, seqlen_k, num_heads_k, head_size); + CHECK_SHAPE(v, batch_size_k, seqlen_k, num_heads_k, head_size_vo); } else { CHECK_SHAPE(k, total_k, num_heads_k, head_size); - CHECK_SHAPE(v, total_k, num_heads_k, head_size); + CHECK_SHAPE(v, total_k, num_heads_k, head_size_vo); CHECK_SHAPE(cu_seqlens_k, batch_size + 1); } } else { CHECK_SHAPE(k, num_pages, page_size, num_heads_k, head_size); - CHECK_SHAPE(v, num_pages, page_size, num_heads_k, head_size); + CHECK_SHAPE(v, num_pages, page_size, num_heads_k, head_size_vo); CHECK_SHAPE(page_table, batch_size_k, max_num_pages_per_seq); } @@ -620,16 +633,17 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq CHECK_DEVICE(out); TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); if (!is_varlen_q) { - CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size); + CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_vo); } else { - CHECK_SHAPE(out, total_q, num_heads, head_size); + CHECK_SHAPE(out, total_q, num_heads, head_size_vo); } } else { - out = torch::empty_like(q, opts.dtype(out_type)); + out = torch::empty({total_q, num_heads, head_size_vo}, q.options()); } auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; int const head_size_rounded = round_up_headdim(head_size); + int const head_size_vo_rounded = round_up_headdim(head_size_vo); int const seqlen_q_rounded = round_multiple(seqlen_q, 128); int const seqlen_k_rounded = round_multiple(seqlen_k, 128); @@ -651,6 +665,7 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq seqlen_q_rounded, seqlen_k_rounded, num_heads, num_heads_k, head_size, head_size_rounded, + head_size_vo, head_size_vo_rounded, q, k, v, out, !is_varlen_q ? nullptr : cu_seqlens_q.data_ptr(), !is_varlen_k ? nullptr : cu_seqlens_k.data_ptr(), @@ -772,12 +787,12 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq if (params.num_splits > 1) { TORCH_CHECK(params.num_splits <= 256, "num_splits > 256 not supported"); if (!is_varlen_q) { - out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size}, opts.dtype(outaccum_type)); + out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size_vo}, opts.dtype(outaccum_type)); softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); params.oaccum_batch_stride = out_accum.stride(1); params.lseaccum_batch_stride = softmax_lse_accum.stride(1); } else { - out_accum = torch::empty({params.num_splits, num_heads, total_q, head_size}, opts.dtype(outaccum_type)); + out_accum = torch::empty({params.num_splits, num_heads, total_q, head_size_vo}, opts.dtype(outaccum_type)); softmax_lse_accum = torch::empty({params.num_splits, num_heads, total_q}, opts.dtype(at::kFloat)); } params.is_fp32 = false; @@ -921,7 +936,13 @@ void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { if (params.d <= 128) { return run_mha_bwd_(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM192 - if (params.d <= 192) { return run_mha_bwd_(params, stream); } + if (params.d <= 192) { + if (params.d_vo == 128) { + return run_mha_bwd_(params, stream); + } else { + return run_mha_bwd_(params, stream); + } + } #endif #ifndef FLASHATTENTION_DISABLE_HDIM256 if (params.d <= 256) { return run_mha_bwd_(params, stream); } @@ -1017,10 +1038,12 @@ std::vector mha_bwd( int const total_q = !is_varlen_q ? batch_size * sizes[1] : sizes[0]; int const num_heads = q.size(-2); int const head_size = q.size(-1); + int const head_size_vo = v.size(-1); int const seqlen_k = !is_varlen_k ? k.size(1) : max_seqlen_k_.value(); int const total_k = !is_varlen_k ? batch_size * k.size(1) : k.size(0); int const num_heads_k = k.size(-2); TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); + TORCH_CHECK(head_size_vo % 8 == 0, "head_size should be a multiple of 8"); int const max_headdim = get_max_headdim(); TORCH_CHECK(head_size <= max_headdim, "FlashAttention forward only supports head dimension at most " + std::to_string(max_headdim)); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); @@ -1035,6 +1058,7 @@ std::vector mha_bwd( int const arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor; int const head_size_rounded = round_up_headdim(head_size); + int const head_size_vo_rounded = round_up_headdim(head_size_vo); // Very important that these match the kernel configs bool const is_local = (window_size_left >= 0 || window_size_right >= 0) && !is_causal; int const kBlockM_sm90 = head_size_rounded <= 64 ? (is_causal && softcap > 0.0 ? 96 : 128) @@ -1063,20 +1087,20 @@ std::vector mha_bwd( if (!is_varlen_q) { CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); - CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size); - CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size); + CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_vo); + CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_vo); } else { CHECK_SHAPE(q, total_q, num_heads, head_size); - CHECK_SHAPE(out, total_q, num_heads, head_size); - CHECK_SHAPE(dout, total_q, num_heads, head_size); + CHECK_SHAPE(out, total_q, num_heads, head_size_vo); + CHECK_SHAPE(dout, total_q, num_heads, head_size_vo); CHECK_SHAPE(cu_seqlens_q, batch_size + 1); } if (!is_varlen_k) { CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size); - CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size); + CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_vo); } else { CHECK_SHAPE(k, total_k, num_heads_k, head_size); - CHECK_SHAPE(v, total_k, num_heads_k, head_size); + CHECK_SHAPE(v, total_k, num_heads_k, head_size_vo); CHECK_SHAPE(cu_seqlens_k, batch_size + 1); } @@ -1126,9 +1150,9 @@ std::vector mha_bwd( CHECK_DEVICE(dv); TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); if (!is_varlen_k) { - CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size); + CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size_vo); } else { - CHECK_SHAPE(dv, total_k, num_heads_k, head_size); + CHECK_SHAPE(dv, total_k, num_heads_k, head_size_vo); } } else { dv = torch::empty_like(v); @@ -1172,6 +1196,7 @@ std::vector mha_bwd( seqlen_q_rounded, seqlen_k_rounded, num_heads, num_heads_k, head_size, head_size_rounded, + head_size_vo, head_size_vo_rounded, q, k, v, out, dout, dq, dk, dv, !is_varlen_q ? nullptr : cu_seqlens_q.data_ptr(), diff --git a/hopper/flash_bwd_launch_template.h b/hopper/flash_bwd_launch_template.h index 635228eebcf..bead0a790c0 100644 --- a/hopper/flash_bwd_launch_template.h +++ b/hopper/flash_bwd_launch_template.h @@ -28,7 +28,7 @@ template + bool V_in_regs=false, int kHeadDim_VO = kHeadDim> void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { static_assert(!(Is_causal && Is_local), "Is_causal and Is_local cannot be true at the same time."); using ElementAccum = float; @@ -46,7 +46,8 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { int batch_k = !is_varlen_k ? params.b : 1; using TileShape_MK = cute::Shape, Int>; - using PreprocessKernel = flash::FlashAttnBwdPreprocess; + using TileShape_MK_VO = cute::Shape, Int>; + using PreprocessKernel = flash::FlashAttnBwdPreprocess; typename PreprocessKernel::Arguments preprocess_args { static_cast(params.o_ptr), {seqlen_q, params.d, params.h, batch_q}, // shape_O @@ -284,13 +285,13 @@ template + bool V_in_regs=false, int kHeadDim_VO=kHeadDim> void run_mha_bwd_dispatch(Flash_bwd_params ¶ms, cudaStream_t stream) { VARLEN_SWITCH(params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr, Varlen, [&] { BOOL_SWITCH(params.h != params.h_k, GQA, [&] { // BOOL_SWITCH(params.deterministic, Deterministic, [&] { // run_flash_bwd(params, stream); - run_flash_bwd(params, stream); + run_flash_bwd(params, stream); // }); }); }); @@ -349,15 +350,15 @@ void run_mha_bwd_hdim128(Flash_bwd_params ¶ms, cudaStream_t stream) { }); } -template +template void run_mha_bwd_hdim192(Flash_bwd_params ¶ms, cudaStream_t stream) { CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] { if constexpr (Arch >= 90) { - run_mha_bwd_dispatch(params, stream); + run_mha_bwd_dispatch(params, stream); } else if constexpr (Arch == 86 || Arch == 89) { - run_mha_bwd_dispatch(params, stream); + run_mha_bwd_dispatch(params, stream); } else { - run_mha_bwd_dispatch(params, stream); + run_mha_bwd_dispatch(params, stream); } }); } diff --git a/hopper/flash_bwd_preprocess_kernel.h b/hopper/flash_bwd_preprocess_kernel.h index 85e877f9d4f..5b3d5606c83 100644 --- a/hopper/flash_bwd_preprocess_kernel.h +++ b/hopper/flash_bwd_preprocess_kernel.h @@ -18,7 +18,7 @@ namespace flash { using namespace cute; -template +template class FlashAttnBwdPreprocess { public: diff --git a/hopper/flash_fwd_kernel_sm90.h b/hopper/flash_fwd_kernel_sm90.h index e5411042dc9..1ebcfb79210 100644 --- a/hopper/flash_fwd_kernel_sm90.h +++ b/hopper/flash_fwd_kernel_sm90.h @@ -49,6 +49,7 @@ class FlashAttnFwdSm90 { // Mainloop derived types using TileShape_MNK = typename CollectiveMainloop::TileShape_MNK; + using TileShape_MNK_VO = typename CollectiveMainloop::TileShape_MNK_VO; using TiledMma0 = typename CollectiveMainloop::TiledMma0; using TiledMma1 = typename CollectiveMainloop::TiledMma1; using ArchTag = typename CollectiveMainloop::ArchTag; @@ -210,6 +211,7 @@ class FlashAttnFwdSm90 { // We're counting on pipeline_k to call cutlass::arch::fence_barrier_init(); PipelineParamsK pipeline_params_k; + PipelineParamsV pipeline_params_v; pipeline_params_k.role = warp_group_idx == 0 ? MainloopPipelineK::ThreadCategory::Producer : MainloopPipelineK::ThreadCategory::Consumer; @@ -217,9 +219,12 @@ class FlashAttnFwdSm90 { pipeline_params_k.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK; pipeline_params_k.is_leader = warp_group_thread_idx == 0; pipeline_params_k.num_consumers = NumMmaThreads; + pipeline_params_v = pipeline_params_k; + pipeline_params_v.transaction_bytes = CollectiveMainloop::TmaTransactionBytesV; } else { pipeline_params_k.consumer_arv_count = NumMmaThreads; pipeline_params_k.producer_arv_count = NumProducerThreads; + pipeline_params_v = pipeline_params_k; } MainloopPipelineK pipeline_k = [&] { @@ -232,11 +237,11 @@ class FlashAttnFwdSm90 { // MainloopPipelineV pipeline_v(shared_storage.pipelines.pipeline_v, pipeline_params_v, ClusterShape{}); MainloopPipelineV pipeline_v = [&] { if constexpr (!Transpose_V) { - static_assert(is_same_v); + // static_assert(is_same_v); if constexpr (Use_TMA_KV) { - return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_k, ClusterShape{}); + return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_v, ClusterShape{}); } else { - return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_k); + return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_v); } } else { PipelineParamsV pipeline_params_v; @@ -357,7 +362,7 @@ class FlashAttnFwdSm90 { work_tile_info.is_valid(params.scheduler); work_tile_info = scheduler.template get_next_work(params.scheduler, work_tile_info)) { // Attention output (GEMM-II) accumulator. - Tensor tOrO = partition_fragment_C(tiled_mma1, select<0, 2>(TileShape_MNK{})); + Tensor tOrO = partition_fragment_C(tiled_mma1, select<0, 2>(TileShape_MNK_VO{})); float softmax_scale_log2 = params.mainloop.softmax_scale_log2; // If there's tanh softcap, the scaling will be done before tanh. auto block_coord = work_tile_info.get_block_coord(params.scheduler); diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index 16701f160d2..ca28b9d2d1f 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -25,7 +25,7 @@ using namespace cute; template + bool PackGQA, bool Split, bool V_colmajor, int kHeadDim_VO> void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { static_assert(!(Is_causal && Is_local), "Causal and Local cannot be enabled at the same time"); static_assert(!(AppendKV && V_colmajor), "AppendKV and V_colmajor cannot be enabled at the same time"); @@ -46,13 +46,14 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { static constexpr bool Q_in_regs = Arch >= 90 ? false : std::get<4>(kBlockMN_kNWarps_Stages_RS); using TileShape_MNK = cute::Shape, Int, Int>; + using TileShape_MNK_VO = cute::Shape, Int, Int>; using ClusterShape = cute::Shape, _1, _1>; using CollectiveMainloop = std::conditional_t< Arch >= 90, - flash::CollectiveMainloopFwdSm90, + flash::CollectiveMainloopFwdSm90, flash::CollectiveMainloopFwdSm80 >; - using CollectiveEpilogue = flash::CollectiveEpilogueFwd; + using CollectiveEpilogue = flash::CollectiveEpilogueFwd; static constexpr int NumProducerThreads = Arch >= 90 ? CollectiveMainloop::NumProducerThreads : CollectiveMainloop::NumMmaThreads; using SchedulerPersistent = std::conditional_t(params.v_ptr), + {!PagedKV ? (!is_varlen_k ? params.seqlen_k : params.total_k) : params.page_size, + params.d_vo, params.h_k, !PagedKV ? batch_k : params.num_pages}, // shape_V v_strides, // stride_V static_cast(params.knew_ptr), {!is_varlen_k_new ? params.seqlen_knew : params.total_knew, params.d, params.h_k, !is_varlen_k_new ? params.b : 1}, // shape_K_new @@ -124,7 +127,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { }; typename CollectiveEpilogue::Arguments epilogue_args { static_cast(!Split ? params.o_ptr : params.oaccum_ptr), - {seqlen_q, params.d, params.h, batch_q, params.num_splits}, // shape_O + {seqlen_q, params.d_vo, params.h, batch_q, params.num_splits}, // shape_O {!Split ? params.o_row_stride : params.oaccum_row_stride, _1{}, !Split ? params.o_head_stride : params.oaccum_head_stride, @@ -179,7 +182,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { CHECK_CUDA_KERNEL_LAUNCH(); } -template +template void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { static_assert(sizeof(T) == 2 || sizeof(T) == 1, "Only 16bit and 8bit are supported"); static constexpr bool Is_FP8 = cute::is_same_v || cute::is_same_v; @@ -196,7 +199,7 @@ void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { // Only use Cluster if number of tiles along seqlen_q is even and not varlen CLUSTER_SWITCH(cutlass::ceil_div(params.seqlen_q * (!PackGQA ? 1 : params.h / params.h_k), kBlockM) % 2 == 0, Use_cluster, [&] { static constexpr int ClusterM = Enable_cluster && Use_cluster ? 2 : 1; - run_flash_fwd(params, stream); + run_flash_fwd(params, stream); }); }); }); diff --git a/hopper/instantiations/flash_bwd_hdim192_128_bf16_sm90.cu b/hopper/instantiations/flash_bwd_hdim192_128_bf16_sm90.cu new file mode 100644 index 00000000000..7c1df91ff82 --- /dev/null +++ b/hopper/instantiations/flash_bwd_hdim192_128_bf16_sm90.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_bwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template<> +void run_mha_bwd_<90, cutlass::bfloat16_t, 192, false, 128>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim192<90, cutlass::bfloat16_t, false, 128>(params, stream); +} +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_bf16_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_bf16_sm90.cu new file mode 100644 index 00000000000..64994a4179c --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_bf16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, false, false, false, false, 128>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdimall_bf16_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_bf16_sm90.cu index 2aac1970b1b..39bf7dff8e7 100644 --- a/hopper/instantiations/flash_fwd_hdimall_bf16_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_bf16_sm90.cu @@ -1,9 +1,11 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, +// Pradeep Ramani, Tri Dao. Splitting the different template instantiations to +// different files to speed up compilation. This file is auto-generated. See +// "generate_kernels.py" -#include "flash_fwd_hdim64_bf16_sm90.cu" -#include "flash_fwd_hdim96_bf16_sm90.cu" #include "flash_fwd_hdim128_bf16_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_sm90.cu" #include "flash_fwd_hdim192_bf16_sm90.cu" -#include "flash_fwd_hdim256_bf16_sm90.cu" \ No newline at end of file +#include "flash_fwd_hdim256_bf16_sm90.cu" +#include "flash_fwd_hdim64_bf16_sm90.cu" +#include "flash_fwd_hdim96_bf16_sm90.cu" diff --git a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp index dbbf2f8f821..d2839660d12 100644 --- a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp @@ -27,7 +27,7 @@ namespace flash { using namespace cute; -template struct CollectiveMainloopFwdSm90 { @@ -35,6 +35,7 @@ struct CollectiveMainloopFwdSm90 { static constexpr int kStages = Stages; using ClusterShape = ClusterShape_; using TileShape_MNK = TileShape_MNK_; + using TileShape_MNK_VO = TileShape_MNK_VO_; using Element = Element_; using ElementAccum = ElementAccum_; using ArchTag = ArchTag_; @@ -63,6 +64,7 @@ struct CollectiveMainloopFwdSm90 { static constexpr int kBlockM = get<0>(TileShape_MNK{}); static constexpr int kBlockN = get<1>(TileShape_MNK{}); static constexpr int kHeadDim = get<2>(TileShape_MNK{}); + static constexpr int kHeadDim_VO = get<2>(TileShape_MNK_VO{}); // Register bandwidth is actually a bottleneck so we don't want Q to be in registers. // Leaving this option here for reference. @@ -84,9 +86,9 @@ struct CollectiveMainloopFwdSm90 { std::conditional_t< !Mma1_is_RS, decltype(cute::GMMA::ss_op_selector(TileShape_MNK{})), GMMA::Major::K, MmaMajorV>()), + decltype(select<0, 2, 1>(TileShape_MNK_VO{})), GMMA::Major::K, MmaMajorV>()), decltype(cute::GMMA::rs_op_selector(TileShape_MNK{})), GMMA::Major::K, MmaMajorV>()) + decltype(select<0, 2, 1>(TileShape_MNK_VO{})), GMMA::Major::K, MmaMajorV>()) >{}, AtomLayoutMNK{})); @@ -107,25 +109,25 @@ struct CollectiveMainloopFwdSm90 { make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); using SmemLayoutAtomVt = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>()); + decltype(cute::get<2>(TileShape_MNK_VO{})), decltype(cute::get<1>(TileShape_MNK_VO{}))>()); using SmemLayoutVt = decltype(tile_to_shape( SmemLayoutAtomVt{}, - make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{}), Int{}), + make_shape(shape<2>(TileShape_MNK_VO{}), shape<1>(TileShape_MNK_VO{}), Int{}), std::conditional_t, cute::Step<_2, _1, _3>>{})); using SmemLayoutAtomVtMma = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>()); + decltype(cute::get<2>(TileShape_MNK_VO{})), decltype(cute::get<1>(TileShape_MNK_VO{}))>()); using SmemLayoutVtMma = decltype(tile_to_shape( SmemLayoutAtomVtMma{}, - make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{}), Int{}), + make_shape(shape<2>(TileShape_MNK_VO{}), shape<1>(TileShape_MNK_VO{}), Int{}), std::conditional_t, cute::Step<_2, _1, _3>>{})); // Only used if we're using cp.async to load V using SmemLayoutAtomVCpAsync = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + decltype(cute::get<1>(TileShape_MNK_VO{})), decltype(cute::get<2>(TileShape_MNK_VO{}))>()); using SmemLayoutVCpAsync = decltype(tile_to_shape( SmemLayoutAtomVCpAsync{}, - make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); + make_shape(shape<1>(TileShape_MNK_VO{}), shape<2>(TileShape_MNK_VO{}), Int{}))); using SmemLayoutAtomP = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>()); @@ -221,14 +223,14 @@ struct CollectiveMainloopFwdSm90 { GmemTiledCopyKV{}, make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, select<1, 0, 2, 3>(StrideV{})), take<0, 2>(SmemLayoutVt{}), - select<2, 1>(TileShape_MNK{}), + select<2, 1>(TileShape_MNK_VO{}), size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any // Set the bytes transferred in this TMA transaction (may involve multiple issues) static constexpr uint32_t TmaTransactionBytesQ = static_cast(size(SmemLayoutQ{}) * cutlass::sizeof_bits_v / 8); static constexpr uint32_t TmaTransactionBytesK = static_cast(size(take<0, 2>(SmemLayoutK{})) * cutlass::sizeof_bits_v / 8); static constexpr uint32_t TmaTransactionBytesV = static_cast(size(take<0, 2>(SmemLayoutVt{})) * cutlass::sizeof_bits_v / 8); - static_assert(TmaTransactionBytesK == TmaTransactionBytesV); + // static_assert(TmaTransactionBytesK == TmaTransactionBytesV); using PipelineTmaAsync = std::conditional_t, typename cutlass::PipelineTmaAsync>; using MainloopPipelineK = std::conditional_t>; @@ -294,6 +296,7 @@ struct CollectiveMainloopFwdSm90 { ShapeQKV const shape_K; StrideQK const stride_K; Element* const ptr_V; + ShapeQKV const shape_V; StrideV const stride_V; Element const* const ptr_K_new; ShapeQKV const shape_K_new; @@ -335,6 +338,7 @@ struct CollectiveMainloopFwdSm90 { ShapeQKV const shape_K; StrideQK const stride_K; Element* const ptr_V; + ShapeQKV const shape_V; StrideV const stride_V; Element const* const ptr_K_new; ShapeQKV const shape_K_new; @@ -388,12 +392,12 @@ struct CollectiveMainloopFwdSm90 { take<0, 2>(SmemLayoutK{}), TileShape_MNK{}, ClusterShape{}); // mcast along M mode for this N load, if any - Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), select<1, 0, 2, 3>(args.shape_K), select<1, 0, 2, 3>(args.stride_V)); + Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), select<1, 0, 2, 3>(args.shape_V), select<1, 0, 2, 3>(args.stride_V)); TMA_V tma_load_V = make_tma_copy( GmemTiledCopyKV{}, mV, take<0, 2>(SmemLayoutVt{}), - select<2, 1>(TileShape_MNK{}), + select<2, 1>(TileShape_MNK_VO{}), size<0>(ClusterShape{})); // mcast along M mode for this N load, if any Tensor mKnew = make_tensor(make_gmem_ptr(args.ptr_K_new), args.shape_K_new, args.stride_K_new); TMA_K tma_load_K_new = make_tma_copy_B_sm90( @@ -407,7 +411,7 @@ struct CollectiveMainloopFwdSm90 { GmemTiledCopyKV{}, cute::conditional_return(mVnew, mV), take<0, 2>(SmemLayoutVt{}), - select<2, 1>(TileShape_MNK{}), + select<2, 1>(TileShape_MNK_VO{}), size<0>(ClusterShape{})); // mcast along M mode for this N load, if any // If PackGQA, reshape Q to be ((qhead_per_khead, seqlen_q), head_size, nhead_k, batch_size) int const qhead_per_khead = !PackGQA ? 1 : cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K)); @@ -429,7 +433,7 @@ struct CollectiveMainloopFwdSm90 { // (assigning it to params.softcap_val) and pre-multiply softcap_val * log2(e) // (assigning it to params.softmax_scale_log2). return {args.ptr_Q, args.shape_Q, args.stride_Q, shape_Q_packed, stride_Q_packed, - args.ptr_K, args.shape_K, args.stride_K, args.ptr_V, args.stride_V, + args.ptr_K, args.shape_K, args.stride_K, args.ptr_V, args.shape_V, args.stride_V, args.ptr_K_new, args.shape_K_new, args.stride_K_new, args.ptr_V_new, args.stride_V_new, args.ptr_rotary_cos, args.shape_rotary, args.stride_rotary_cos, args.ptr_rotary_sin, args.stride_rotary_sin, args.is_rotary_interleaved, @@ -560,7 +564,7 @@ struct CollectiveMainloopFwdSm90 { Tensor gQ = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mQ), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) // if (cute::thread0()) { printf("Varlen = %d, params.leftpad_k = %p, leftpad_k = %d\n", Varlen, params.leftpad_k, leftpad_k); } Tensor gK_TMA = local_tile(domain_offset(make_coord(seqlen_info.offset_k, _0{}), mK_TMA), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) - Tensor gVt_TMA = local_tile(domain_offset(make_coord(_0{}, seqlen_info.offset_k), mVt_TMA), select<2, 1>(TileShape_MNK{}), make_coord(_0{}, _)); // (K, N, _) + Tensor gVt_TMA = local_tile(domain_offset(make_coord(_0{}, seqlen_info.offset_k), mVt_TMA), select<2, 1>(TileShape_MNK_VO{}), make_coord(_0{}, _)); // (K, N, _) auto block_tma_Q = params.tma_load_Q.get_slice(_0{}); Tensor tQgQ = group_modes<0, 3>(block_tma_Q.partition_S(gQ)); // (TMA) @@ -1210,7 +1214,7 @@ struct CollectiveMainloopFwdSm90 { Tensor mVnewt_TMA = params.tma_load_V_new.get_tma_tensor(select<1, 0, 2, 3>(params.shape_K_new))(_, _, bidh_kv, !is_varlen_k_new ? bidb : 0); Tensor gKnew_TMA = local_tile(domain_offset(make_coord(seqlen_info.offset_k_new, _0{}), mKnew_TMA), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) - Tensor gVnewt_TMA = local_tile(domain_offset(make_coord(_0{}, seqlen_info.offset_k_new), mVnewt_TMA), select<2, 1>(TileShape_MNK{}), make_coord(_0{}, _)); // (K, N, _) + Tensor gVnewt_TMA = local_tile(domain_offset(make_coord(_0{}, seqlen_info.offset_k_new), mVnewt_TMA), select<2, 1>(TileShape_MNK_VO{}), make_coord(_0{}, _)); // (K, N, _) auto block_tma_K_new = params.tma_load_K_new.get_slice(cluster_local_block_id.x); Tensor tKgKnew_TMA = group_modes<0, 3>(block_tma_K_new.partition_S(gKnew_TMA)); // (TMA, k) @@ -1306,10 +1310,11 @@ struct CollectiveMainloopFwdSm90 { int const offset_k = seqlen_info.offset_k + seqlen_info.seqlen_k_og; Tensor gK = local_tile(domain_offset(make_coord(offset_k, _0{}), mK), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) - Tensor gV = local_tile(domain_offset(make_coord(offset_k, _0{}), mV), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) + Tensor gV = local_tile(domain_offset(make_coord(offset_k, _0{}), mV), select<1, 2>(TileShape_MNK_VO{}), make_coord(_, _0{})); // (N, K, _) static constexpr int kBlockN = get<1>(TileShape_MNK{}); static constexpr int kHeadDim = get<2>(TileShape_MNK{}); + static constexpr int kHeadDim_VO = get<2>(TileShape_MNK_VO{}); int const offset_rotary = seqlen_info.seqlen_k_og + seqlen_info.leftpad_k; int const seqlen_k_new = seqlen_info.seqlen_k_new; using Rotary_t = Rotary; diff --git a/hopper/setup.py b/hopper/setup.py index d95be9ad409..a1daa1e6667 100644 --- a/hopper/setup.py +++ b/hopper/setup.py @@ -450,11 +450,12 @@ def nvcc_threads_args(): DTYPE_BWD = ["bf16"] + (["fp16"] if not DISABLE_FP16 else []) HEAD_DIMENSIONS_BWD = ( [] - + ([64] if not DISABLE_HDIM64 else []) - + ([96] if not DISABLE_HDIM96 else []) - + ([128] if not DISABLE_HDIM128 else []) - + ([192] if not DISABLE_HDIM192 else []) - + ([256] if not DISABLE_HDIM256 else []) + + (["64"] if not DISABLE_HDIM64 else []) + + (["96"] if not DISABLE_HDIM96 else []) + + (["128"] if not DISABLE_HDIM128 else []) + + (["192"] if not DISABLE_HDIM192 else []) + + (["256"] if not DISABLE_HDIM256 else []) + + (["192_128"] if not DISABLE_HDIM256 else []) ) HEAD_DIMENSIONS_FWD = ["all"] HEAD_DIMENSIONS_FWD_SM80 = HEAD_DIMENSIONS_BWD diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py index 1fe43e21fa2..3aedeb684c1 100644 --- a/hopper/test_flash_attn.py +++ b/hopper/test_flash_attn.py @@ -1,6 +1,7 @@ import os import math import itertools +from typing import Tuple import pytest import torch @@ -41,6 +42,7 @@ + ([128] if not DISABLE_HDIM128 else []) + ([192] if not DISABLE_HDIM192 else []) + ([256] if not DISABLE_HDIM256 else []) + + ([(192, 128)] if not DISABLE_HDIM192 else []) ) @@ -309,6 +311,10 @@ def test_flash_attn_output( def test_flash_attn_varlen_output( seqlen_q, seqlen_k, d, add_unused_qkv, causal, local, softcap, deterministic, mha_type, dtype ): + if isinstance(d, Tuple): + d, d_vo = d + else: + d_vo = d device = "cuda" # set seed torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal) * 2 + int(local)) @@ -326,7 +332,7 @@ def test_flash_attn_varlen_output( q_ref = (q_ref * softcap / 4).detach().requires_grad_() q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() - v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d_vo, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() # Put window_size after QKV randn so that window_size changes from test to test window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) if dtype == torch.float8_e4m3fn: diff --git a/hopper/test_util.py b/hopper/test_util.py index 54eb195eb36..a11f82ca419 100644 --- a/hopper/test_util.py +++ b/hopper/test_util.py @@ -6,16 +6,25 @@ from padding import pad_input, unpad_input -def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random", zero_lengths=False): +def generate_random_padding_mask( + max_seqlen, batch_size, device, mode="random", zero_lengths=False +): assert mode in ["full", "random", "third"] if mode == "full": - lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32) + lengths = torch.full( + (batch_size, 1), max_seqlen, device=device, dtype=torch.int32 + ) elif mode == "random": lengths = torch.randint( - max(0 if zero_lengths else 1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device + max(0 if zero_lengths else 1, max_seqlen - 20), + max_seqlen + 1, + (batch_size, 1), + device=device, ) elif mode == "third": - lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device) + lengths = torch.randint( + max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device + ) if zero_lengths: # Generate zero-lengths every 5 batches and the last batch. @@ -24,28 +33,37 @@ def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random", lengths[i] = 0 lengths[-1] = 0 padding_mask = ( - repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths + repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) + < lengths ) return padding_mask def generate_qkv( - q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False, - query_unused_mask=None, key_unused_mask=None, + q, + k, + v, + query_padding_mask=None, + key_padding_mask=None, + kvpacked=False, + qkvpacked=False, + query_unused_mask=None, + key_unused_mask=None, ): """ Arguments: q: (batch_size, seqlen_q, nheads, d) k: (batch_size, seqlen_k, nheads_k, d) - v: (batch_size, seqlen_k, nheads_k, d) + v: (batch_size, seqlen_k, nheads_k, d_v) query_padding_mask: (batch_size, seqlen), bool key_padding_mask: (batch_size, seqlen), bool """ assert not (kvpacked and qkvpacked) batch_size, seqlen_q, nheads, d = q.shape _, seqlen_k, nheads_k, _ = k.shape + _, _, _, d_v = v.shape assert k.shape == (batch_size, seqlen_k, nheads_k, d) - assert v.shape == (batch_size, seqlen_k, nheads_k, d) + assert v.shape == (batch_size, seqlen_k, nheads_k, d_v) if query_unused_mask is not None or key_unused_mask is not None: assert not kvpacked assert not qkvpacked @@ -60,7 +78,11 @@ def generate_qkv( else: q_unpad = rearrange(q, "b s h d -> (b s) h d") cu_seqlens_q = torch.arange( - 0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device + 0, + (batch_size + 1) * seqlen_q, + step=seqlen_q, + dtype=torch.int32, + device=q_unpad.device, ) seqused_q = None max_seqlen_q = seqlen_q @@ -77,18 +99,25 @@ def generate_qkv( k_unpad = rearrange(k, "b s h d -> (b s) h d") v_unpad = rearrange(v, "b s h d -> (b s) h d") cu_seqlens_k = torch.arange( - 0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device + 0, + (batch_size + 1) * seqlen_k, + step=seqlen_k, + dtype=torch.int32, + device=k_unpad.device, ) seqused_k = None max_seqlen_k = seqlen_k if qkvpacked: + assert d != d_v assert (query_padding_mask == key_padding_mask).all() assert nheads == nheads_k qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1) qkv = torch.stack([q, k, v], dim=2) if query_padding_mask is not None: - dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q) + dqkv_pad_fn = lambda dqkv_unpad: pad_input( + dqkv_unpad, indices_q, batch_size, seqlen_q + ) else: dqkv_pad_fn = lambda dqkv_unpad: rearrange( dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size @@ -102,11 +131,14 @@ def generate_qkv( dqkv_pad_fn, ) elif kvpacked: + assert d != d_v kv_unpad = torch.stack([k_unpad, v_unpad], dim=1) kv = torch.stack([k, v], dim=2) dq_pad_fn = output_pad_fn if key_padding_mask is not None: - dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k) + dkv_pad_fn = lambda dkv_unpad: pad_input( + dkv_unpad, indices_k, batch_size, seqlen_k + ) else: dkv_pad_fn = lambda dkv_unpad: rearrange( dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size @@ -127,9 +159,13 @@ def generate_qkv( else: dq_pad_fn = output_pad_fn if key_padding_mask is not None: - dk_pad_fn = lambda dk_unpad: pad_input(dk_unpad, indices_k, batch_size, seqlen_k) + dk_pad_fn = lambda dk_unpad: pad_input( + dk_unpad, indices_k, batch_size, seqlen_k + ) else: - dk_pad_fn = lambda dk_unpad: rearrange(dk_unpad, "(b s) h d -> b s h d", b=batch_size) + dk_pad_fn = lambda dk_unpad: rearrange( + dk_unpad, "(b s) h d -> b s h d", b=batch_size + ) return ( q_unpad.detach().requires_grad_(), k_unpad.detach().requires_grad_(), @@ -159,7 +195,9 @@ def construct_local_mask( key_leftpad=None, device=None, ): - row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") + row_idx = rearrange( + torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1" + ) col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) if key_leftpad is not None: key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1") @@ -181,7 +219,10 @@ def construct_local_mask( sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk return torch.logical_or( col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk), - torch.logical_and(col_idx < row_idx + sk - sq - window_size[0], col_idx >= sink_token_length), + torch.logical_and( + col_idx < row_idx + sk - sq - window_size[0], + col_idx >= sink_token_length, + ), ) @@ -196,7 +237,9 @@ def attention_ref( dropout_p=0.0, dropout_mask=None, causal=False, - q_descale=None, k_descale=None, v_descale=None, + q_descale=None, + k_descale=None, + v_descale=None, window_size=(-1, -1), # -1 means infinite window size sink_token_length=0, softcap=0.0, @@ -230,7 +273,7 @@ def attention_ref( if upcast: q, k, v = q.float(), k.float(), v.float() if q_descale is not None: - q_descale = repeat(q_descale, "b h -> b (h g)", g = q.shape[2] // k.shape[2]) + q_descale = repeat(q_descale, "b h -> b (h g)", g=q.shape[2] // k.shape[2]) q = (q.float() * rearrange(q_descale, "b h -> b 1 h 1")).to(dtype=q.dtype) if k_descale is not None: k = (k.float() * rearrange(k_descale, "b h -> b 1 h 1")).to(dtype=k.dtype) @@ -247,7 +290,9 @@ def attention_ref( if softcap > 0: scores = torch.tanh(scores / softcap) * softcap if key_padding_mask is not None: - scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) + scores.masked_fill_( + rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf") + ) if window_size[0] >= 0 or window_size[1] >= 0: local_mask = construct_local_mask( seqlen_q, @@ -266,13 +311,19 @@ def attention_ref( # We want to mask here so that the attention matrix doesn't have any NaNs # Otherwise we'll get NaN in dV if query_padding_mask is not None: - attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) + attention = attention.masked_fill( + rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0 + ) # Without this we might get NaN in dv if key_padding_mask is not None: - attention = attention.masked_fill(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0) + attention = attention.masked_fill( + rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0 + ) # Some rows might be completely masked out so we fill them with zero instead of NaN if window_size[0] >= 0 or window_size[1] >= 0: - attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0) + attention = attention.masked_fill( + torch.all(local_mask, dim=-1, keepdim=True), 0.0 + ) dropout_scaling = 1.0 / (1 - dropout_p) # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling # output = torch.einsum('bhts,bshd->bthd', attention_drop , v) From 4ee927ed210daba68770a613b12e55630fe151ed Mon Sep 17 00:00:00 2001 From: Ying Zhang Date: Mon, 20 Jan 2025 22:32:38 -0800 Subject: [PATCH 2/2] bwd --- hopper/epilogue_bwd.hpp | 229 ++++++++++++++--------- hopper/flash_bwd_kernel_sm90.h | 20 +- hopper/flash_bwd_launch_template.h | 68 ++++--- hopper/flash_bwd_postprocess_kernel.h | 10 + hopper/flash_bwd_preprocess_kernel.h | 7 +- hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp | 202 ++++++++++++-------- 6 files changed, 339 insertions(+), 197 deletions(-) diff --git a/hopper/epilogue_bwd.hpp b/hopper/epilogue_bwd.hpp index f99dfe918e8..8f39e1296ec 100644 --- a/hopper/epilogue_bwd.hpp +++ b/hopper/epilogue_bwd.hpp @@ -19,10 +19,12 @@ namespace flash { using namespace cute; template + int NumEpilogueThreads_, bool Varlen_, bool dKV_swapAB_, + int AtomLayoutKdKV=1, class TileShape_MNK_VO_=TileShape_MNK_> struct CollectiveEpilogueBwd { using TileShape_MNK = TileShape_MNK_; + using TileShape_MNK_VO = TileShape_MNK_VO_; using Element = Element_; using ArchTag = ArchTag_; static constexpr int NumEpilogueThreads = NumEpilogueThreads_; @@ -38,6 +40,7 @@ struct CollectiveEpilogueBwd { static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); static_assert(get<2>(TileShape_MNK{}) % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad"); static constexpr int kHeadDim = get<2>(TileShape_MNK{}); + static constexpr int kHeadDim_VO = get<2>(TileShape_MNK_VO{}); static constexpr int kGmemThreadsPerRow = cutlass::gcd(kHeadDim / kGmemElemsPerLoad, NumEpilogueThreads); static_assert(NumEpilogueThreads % kGmemThreadsPerRow == 0, "NumEpilogueThreads must be a multiple of kGmemThreadsPerRow"); using GmemLayoutAtom = Layout, Int>, @@ -47,14 +50,18 @@ struct CollectiveEpilogueBwd { GmemLayoutAtom{}, Layout>>{})); // Val layout, 8 or 16 vals per store - using SmemLayoutAtomdKVTMA = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), Int(TileShape_MNK{})) / AtomLayoutKdKV>>()); - using SmemLayoutdKVTMA = decltype(tile_to_shape(SmemLayoutAtomdKVTMA{}, select<1, 2>(TileShape_MNK{}))); - using SmemLayoutdKVtTMA = - decltype(cute::composition(SmemLayoutdKVTMA{}, - make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})), - make_stride(decltype(get<1>(TileShape_MNK{})){}, _1{})))); + using SmemLayoutdKTMA = decltype(tile_to_shape(SmemLayoutAtomdKTMA{}, select<1, 2>(TileShape_MNK{}))); + using SmemLayoutAtomdVTMA = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK_VO{})), Int(TileShape_MNK_VO{})) / AtomLayoutKdKV>>()); + using SmemLayoutdVTMA = decltype(tile_to_shape(SmemLayoutAtomdVTMA{}, select<1, 2>(TileShape_MNK_VO{}))); + // using SmemLayoutdKVtTMA = + // decltype(cute::composition(SmemLayoutdKVTMA{}, + // make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})), + // make_stride(decltype(get<1>(TileShape_MNK{})){}, _1{})))); // If we don't use TMA static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : (kHeadDim % 32 == 0 ? 32 : 16); @@ -64,12 +71,18 @@ struct CollectiveEpilogueBwd { Layout, Int>, Stride, _1>>{})); - using SmemLayoutAtomdKV = std::conditional_t; - using SmemLayoutdKV = decltype(tile_to_shape(SmemLayoutAtomdKV{}, select<1, 2>(TileShape_MNK{}))); - using SmemLayoutdKVt = - decltype(cute::composition(SmemLayoutdKV{}, + using SmemLayoutAtomdK = std::conditional_t; + using SmemLayoutAtomdV = std::conditional_t; + using SmemLayoutdK = decltype(tile_to_shape(SmemLayoutAtomdK{}, select<1, 2>(TileShape_MNK{}))); + using SmemLayoutdV = decltype(tile_to_shape(SmemLayoutAtomdV{}, select<1, 2>(TileShape_MNK_VO{}))); + using SmemLayoutdKt = + decltype(cute::composition(SmemLayoutdK{}, make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})), make_stride(decltype(get<1>(TileShape_MNK{})){}, _1{})))); + using SmemLayoutdVt = + decltype(cute::composition(SmemLayoutdV{}, + make_layout(make_shape(get<2>(TileShape_MNK_VO{}), get<1>(TileShape_MNK_VO{})), + make_stride(decltype(get<1>(TileShape_MNK_VO{})){}, _1{})))); using SmemCopyAtomdKV = Copy_Atom< std::conditional_t< @@ -79,27 +92,38 @@ struct CollectiveEpilogueBwd { >, Element>; - static constexpr size_t SmemAlignmentdKV = ArchTag::kMinComputeCapability >= 90 ? cutlass::detail::alignment_for_swizzle(SmemLayoutdKV{}) : 128; - static_assert(SmemAlignmentdKV >= 128, "Require at least 128B alignment"); + static constexpr size_t SmemAlignmentdK = ArchTag::kMinComputeCapability >= 90 ? cutlass::detail::alignment_for_swizzle(SmemLayoutdK{}) : 128; + static constexpr size_t SmemAlignmentdV = ArchTag::kMinComputeCapability >= 90 ? cutlass::detail::alignment_for_swizzle(SmemLayoutdV{}) : 128; + static_assert(SmemAlignmentdK >= 128, "Require at least 128B alignment"); - struct TensorStorage : cute::aligned_struct { - cute::array_aligned, SmemAlignmentdKV> smem_dk; - cute::array_aligned, SmemAlignmentdKV> smem_dv; + struct TensorStorage : cute::aligned_struct { + cute::array_aligned, SmemAlignmentdK> smem_dk; + cute::array_aligned, SmemAlignmentdV> smem_dv; }; using ShapedKV = cute::Shape; // (seqlen_k, d, head, batch) using StridedKV = cute::Stride; - using TMA_dKV = std::conditional_t< + using TMA_dK = std::conditional_t< Use_TMA, decltype(make_tma_copy( GmemTiledCopydKVTMA{}, make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapedKV{}, StridedKV{}), - SmemLayoutdKVTMA{}, + SmemLayoutdKTMA{}, select<1, 2>(TileShape_MNK{}), _1{})), // no mcast for dKV std::nullptr_t >; + using TMA_dV = std::conditional_t< + Use_TMA, + decltype(make_tma_copy( + GmemTiledCopydKVTMA{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapedKV{}, StridedKV{}), + SmemLayoutdVTMA{}, + select<1, 2>(TileShape_MNK_VO{}), + _1{})), // no mcast for dKV + std::nullptr_t + >; // Host side kernel arguments struct Arguments { @@ -107,6 +131,7 @@ struct CollectiveEpilogueBwd { ShapedKV const shape_dK; StridedKV const stride_dK; Element* ptr_dV; + ShapedKV const shape_dV; StridedKV const stride_dV; int const num_heads_q; int* dk_semaphore; @@ -121,8 +146,10 @@ struct CollectiveEpilogueBwd { ShapedKV const shape_dK; StridedKV const stride_dK; Element* ptr_dV; + ShapedKV const shape_dV; StridedKV const stride_dV; - TMA_dKV tma_store_dK, tma_store_dV; + TMA_dK tma_store_dK; + TMA_dV tma_store_dV; int const* cu_seqlens = nullptr; int const* seqused = nullptr; }; @@ -130,22 +157,22 @@ struct CollectiveEpilogueBwd { static Params to_underlying_arguments(Arguments const& args) { Tensor mdK = make_tensor(make_gmem_ptr(args.ptr_dK), args.shape_dK, args.stride_dK); - Tensor mdV = make_tensor(make_gmem_ptr(args.ptr_dV), args.shape_dK, args.stride_dV); - TMA_dKV tma_store_dK = [&] { + Tensor mdV = make_tensor(make_gmem_ptr(args.ptr_dV), args.shape_dV, args.stride_dV); + TMA_dK tma_store_dK = [&] { if constexpr (Use_TMA) { - return make_tma_copy(GmemTiledCopydKVTMA{}, mdK, SmemLayoutdKVTMA{}, select<1, 2>(TileShape_MNK{}), _1{}); // no mcast for dKV + return make_tma_copy(GmemTiledCopydKVTMA{}, mdK, SmemLayoutdKTMA{}, select<1, 2>(TileShape_MNK{}), _1{}); // no mcast for dKV } else { return nullptr; } }(); - TMA_dKV tma_store_dV = [&] { + TMA_dV tma_store_dV = [&] { if constexpr (Use_TMA) { - return make_tma_copy(GmemTiledCopydKVTMA{}, mdV, SmemLayoutdKVTMA{}, select<1, 2>(TileShape_MNK{}), _1{}); // no mcast for dKV + return make_tma_copy(GmemTiledCopydKVTMA{}, mdV, SmemLayoutdVTMA{}, select<1, 2>(TileShape_MNK_VO{}), _1{}); // no mcast for dKV } else { return nullptr; } }(); - return {args.ptr_dK, args.shape_dK, args.stride_dK, args.ptr_dV, args.stride_dV, + return {args.ptr_dK, args.shape_dK, args.stride_dK, args.ptr_dV, args.shape_dV, args.stride_dV, tma_store_dK, tma_store_dV, args.cu_seqlens, args.seqused}; } @@ -158,48 +185,51 @@ struct CollectiveEpilogueBwd { } } - template + template CUTLASS_DEVICE void store(Params const& params, - FrgTensorO const& tdKrdK, - FrgTensorO const& tdVrdV, + FrgTensor_dK const& tdKrdK, + FrgTensor_dV const& tdVrdV, SharedStorage& shared_storage, - TiledMma tiled_mma, + TiledMma_dK tiled_mma_dk, + TiledMma_dV tiled_mma_dv, int thread_idx, cute::tuple const& block_coord ) { auto [n_block, bidh, bidb] = block_coord; - Tensor sdK = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dk.data()), SmemLayoutdKV{})); - Tensor sdV = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dv.data()), SmemLayoutdKV{})); - Tensor sdKt = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dk.data()), SmemLayoutdKVt{})); - Tensor sdVt = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dv.data()), SmemLayoutdKVt{})); - auto smem_tiled_copy_dKV = make_tiled_copy_C(SmemCopyAtomdKV{}, tiled_mma); - auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(thread_idx); + Tensor sdK = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dk.data()), SmemLayoutdK{})); + Tensor sdV = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dv.data()), SmemLayoutdV{})); + Tensor sdKt = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dk.data()), SmemLayoutdKt{})); + Tensor sdVt = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dv.data()), SmemLayoutdVt{})); + auto smem_tiled_copy_dK = make_tiled_copy_C(SmemCopyAtomdKV{}, tiled_mma_dk); + auto smem_tiled_copy_dV = make_tiled_copy_C(SmemCopyAtomdKV{}, tiled_mma_dv); + auto smem_thr_copy_dK = smem_tiled_copy_dK.get_thread_slice(thread_idx); + auto smem_thr_copy_dV = smem_tiled_copy_dV.get_thread_slice(thread_idx); Tensor tdVrdV_out = make_tensor_like(tdVrdV); flash::convert_type_out(tdVrdV, tdVrdV_out); Tensor tdKrdK_out = make_tensor_like(tdKrdK); flash::convert_type_out(tdKrdK, tdKrdK_out); - Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(tdKrdK_out); // ((Atom,AtomNum), MMA_M, MMA_N) - Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(tdVrdV_out); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccdKrdK = smem_thr_copy_dK.retile_S(tdKrdK_out); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccdVrdV = smem_thr_copy_dV.retile_S(tdVrdV_out); // ((Atom,AtomNum), MMA_M, MMA_N) // if (blockIdx.x == 0 && threadIdx.x == 128) { print(smem_thr_copy_dKV); print(sdK); printf("\n"); print(sdKt); printf("\n"); } - Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(cute::conditional_return(sdK, sdKt)); // ((Atom,AtomNum),PIPE_M,PIPE_N) - Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(cute::conditional_return(sdV, sdVt)); // ((Atom,AtomNum),PIPE_M,PIPE_N) + Tensor taccdKsdK = smem_thr_copy_dK.partition_D(cute::conditional_return(sdK, sdKt)); // ((Atom,AtomNum),PIPE_M,PIPE_N) + Tensor taccdVsdV = smem_thr_copy_dV.partition_D(cute::conditional_return(sdV, sdVt)); // ((Atom,AtomNum),PIPE_M,PIPE_N) // Make sure all WGs have finished reading K and V flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); - cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV); - cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK); + cute::copy(smem_tiled_copy_dV, taccdVrdV, taccdVsdV); + cute::copy(smem_tiled_copy_dK, taccdKrdK, taccdKsdK); if constexpr (Use_TMA) { cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA cutlass::arch::NamedBarrier::arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); Tensor mdK = params.tma_store_dK.get_tma_tensor(params.shape_dK); - Tensor mdV = params.tma_store_dV.get_tma_tensor(params.shape_dK); + Tensor mdV = params.tma_store_dV.get_tma_tensor(params.shape_dV); Tensor gdK = local_tile(mdK(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K) - Tensor gdV = local_tile(mdV(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K) + Tensor gdV = local_tile(mdV(_, _, bidh, bidb), select<1, 2>(TileShape_MNK_VO{}), make_coord(n_block, _0{})); // (M, K) auto block_tma_dK = params.tma_store_dK.get_slice(_0{}); auto block_tma_dV = params.tma_store_dV.get_slice(_0{}); Tensor tdKgdK = block_tma_dK.partition_D(gdK); // (TMA, TMA_M, TMA_K) @@ -227,39 +257,45 @@ struct CollectiveEpilogueBwd { bool const is_varlen = Varlen && params.cu_seqlens; Tensor mdK = make_tensor(make_gmem_ptr(params.ptr_dK), params.shape_dK, params.stride_dK)(_, _, bidh, !is_varlen ? bidb : 0); Tensor gdK = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K) - Tensor mdV = make_tensor(make_gmem_ptr(params.ptr_dV), params.shape_dK, params.stride_dV)(_, _, bidh, !is_varlen ? bidb : 0); - Tensor gdV = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K) + Tensor mdV = make_tensor(make_gmem_ptr(params.ptr_dV), params.shape_dV, params.stride_dV)(_, _, bidh, !is_varlen ? bidb : 0); + Tensor gdV = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdV), select<1, 2>(TileShape_MNK_VO{}), make_coord(n_block, _0{})); // (M, K) GmemTiledCopydKV gmem_tiled_copy_dKV; - auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(thread_idx); - Tensor tdKVgdV = gmem_thr_copy_dKV.partition_D(gdV); - Tensor tdKVsdV = gmem_thr_copy_dKV.partition_S(sdV); // (TMA, TMA_M, TMA_K) - Tensor tdKVgdK = gmem_thr_copy_dKV.partition_D(gdK); - Tensor tdKVsdK = gmem_thr_copy_dKV.partition_S(sdK); // (TMA, TMA_M, TMA_K) + auto gmem_thr_copy_dK = gmem_tiled_copy_dKV.get_thread_slice(thread_idx); + auto gmem_thr_copy_dV = gmem_tiled_copy_dKV.get_thread_slice(thread_idx); + Tensor tdKVgdV = gmem_thr_copy_dV.partition_D(gdV); + Tensor tdKVsdV = gmem_thr_copy_dV.partition_S(sdV); // (TMA, TMA_M, TMA_K) + Tensor tdKVgdK = gmem_thr_copy_dK.partition_D(gdK); + Tensor tdKVsdK = gmem_thr_copy_dK.partition_S(sdK); // (TMA, TMA_M, TMA_K) Tensor tdKVrdV = make_fragment_like(tdKVgdV); Tensor tdKVrdK = make_fragment_like(tdKVgdK); - Tensor cdKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cdK = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cdV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK_VO{})); // (BLK_M,BLK_K) -> (blk_m,blk_k) // Repeat the partitioning with identity layouts - Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV); - Tensor tdKVpdKV = make_tensor(make_shape(size<2>(tdKVgdV))); + Tensor tdKVcdK = gmem_thr_copy_dK.partition_D(cdK); + Tensor tdKVcdV = gmem_thr_copy_dV.partition_D(cdV); + Tensor tdKVpdK = make_tensor(make_shape(size<2>(tdKVgdK))); + Tensor tdKVpdV = make_tensor(make_shape(size<2>(tdKVgdV))); + #pragma unroll + for (int k = 0; k < size(tdKVpdK); ++k) { tdKVpdK(k) = get<1>(tdKVcdK(_0{}, _0{}, k)) < get<1>(params.shape_dK); } #pragma unroll - for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dK); } + for (int k = 0; k < size(tdKVpdV); ++k) { tdKVpdV(k) = get<1>(tdKVcdV(_0{}, _0{}, k)) < get<1>(params.shape_dV); } // Need to check OOB when reading from smem if kBlockN isn't evenly tiled static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(size<0>(GmemLayoutAtom{})) == 0; flash::copy( - gmem_tiled_copy_dKV, tdKVsdV, tdKVrdV, tdKVcdKV, tdKVpdKV, kBlockN); + gmem_tiled_copy_dKV, tdKVsdV, tdKVrdV, tdKVcdV, tdKVpdV, kBlockN); flash::copy( - gmem_tiled_copy_dKV, tdKVsdK, tdKVrdK, tdKVcdKV, tdKVpdKV, kBlockN); + gmem_tiled_copy_dKV, tdKVsdK, tdKVrdK, tdKVcdK, tdKVpdK, kBlockN); // // Tell warp 0 that smem_k and smem_v are ready // cutlass::arch::fence_view_async_shared(); // ensure smem reads are done before next TMA to smem_k/v // flash::named_barrier_arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::KVEmpty) /*id*/); // Construct identity layout for gdKV // Clear_OOB_K must be false since we don't want to write zeros to gmem flash::copy( - gmem_tiled_copy_dKV, tdKVrdV, tdKVgdV, tdKVcdKV, tdKVpdKV, std::min(seqlen_info.seqlen - n_block * kBlockN, kBlockN) + gmem_tiled_copy_dKV, tdKVrdV, tdKVgdV, tdKVcdV, tdKVpdV, std::min(seqlen_info.seqlen - n_block * kBlockN, kBlockN) ); flash::copy( - gmem_tiled_copy_dKV, tdKVrdK, tdKVgdK, tdKVcdKV, tdKVpdKV, std::min(seqlen_info.seqlen - n_block * kBlockN, kBlockN) + gmem_tiled_copy_dKV, tdKVrdK, tdKVgdK, tdKVcdK, tdKVpdK, std::min(seqlen_info.seqlen - n_block * kBlockN, kBlockN) ); } } @@ -282,38 +318,46 @@ struct CollectiveEpilogueBwd { bool const is_varlen = Varlen && params.cu_seqlens; Tensor mdK = make_tensor(make_gmem_ptr(params.ptr_dK), params.shape_dK, params.stride_dK)(_, _, bidh, !is_varlen ? bidb : 0); Tensor gdK = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K) - Tensor mdV = make_tensor(make_gmem_ptr(params.ptr_dV), params.shape_dK, params.stride_dV)(_, _, bidh, !is_varlen ? bidb : 0); - Tensor gdV = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K) + Tensor mdV = make_tensor(make_gmem_ptr(params.ptr_dV), params.shape_dV, params.stride_dV)(_, _, bidh, !is_varlen ? bidb : 0); + Tensor gdV = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdV), select<1, 2>(TileShape_MNK_VO{}), make_coord(n_block, _0{})); // (M, K) GmemTiledCopydKV gmem_tiled_copy_dKV; auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(thread_idx); Tensor tdKVgdK = gmem_thr_copy_dKV.partition_D(gdK); Tensor tdKVgdV = gmem_thr_copy_dKV.partition_D(gdV); - Tensor tdKVrdKV = make_fragment_like(tdKVgdK); - clear(tdKVrdKV); + Tensor tdKVrdK = make_fragment_like(tdKVgdK); + Tensor tdKVrdV = make_fragment_like(tdKVgdV); + clear(tdKVrdK); + clear(tdKVrdV); // Construct identity layout for gdKV - Tensor cdKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cdK = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cdV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK_VO{})); // (BLK_M,BLK_K) -> (blk_m,blk_k) // Repeat the partitioning with identity layouts - Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV); - Tensor tdKVpdKV = make_tensor(make_shape(size<2>(tdKVgdK))); + Tensor tdKVcdK = gmem_thr_copy_dKV.partition_D(cdK); + Tensor tdKVcdV = gmem_thr_copy_dKV.partition_D(cdV); + Tensor tdKVpdK = make_tensor(make_shape(size<2>(tdKVgdK))); + Tensor tdKVpdV = make_tensor(make_shape(size<2>(tdKVgdV))); + #pragma unroll + for (int k = 0; k < size(tdKVpdK); ++k) { tdKVpdK(k) = get<1>(tdKVcdK(_0{}, _0{}, k)) < get<1>(params.shape_dK); } #pragma unroll - for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dK); } + for (int k = 0; k < size(tdKVpdV); ++k) { tdKVpdV(k) = get<1>(tdKVcdV(_0{}, _0{}, k)) < get<1>(params.shape_dV); } // Clear_OOB_K must be false since we don't want to write zeros to gmem flash::copy( - gmem_tiled_copy_dKV, tdKVrdKV, tdKVgdK, tdKVcdKV, tdKVpdKV, seqlen_info.seqlen - n_block * kBlockN + gmem_tiled_copy_dKV, tdKVrdK, tdKVgdK, tdKVcdK, tdKVpdK, seqlen_info.seqlen - n_block * kBlockN ); flash::copy( - gmem_tiled_copy_dKV, tdKVrdKV, tdKVgdV, tdKVcdKV, tdKVpdKV, seqlen_info.seqlen - n_block * kBlockN + gmem_tiled_copy_dKV, tdKVrdV, tdKVgdV, tdKVcdV, tdKVpdV, seqlen_info.seqlen - n_block * kBlockN ); } }; template + int NumEpilogueThreads_, bool Varlen_, bool Deterministic, class TileShape_MNK_VO_> struct CollectiveEpilogueBwdGQA { using TileShape_MNK = TileShape_MNK_; + using TileShape_MNK_VO = TileShape_MNK_VO_; using Element = ElementAccum; using ArchTag = ArchTag_; static constexpr int NumEpilogueThreads = NumEpilogueThreads_; @@ -324,6 +368,7 @@ struct CollectiveEpilogueBwdGQA { static constexpr int kBlockN = get<1>(TileShape_MNK{}); static constexpr int kHeadDim = get<2>(TileShape_MNK{}); + static constexpr int kHeadDim_VO = get<2>(TileShape_MNK_VO{}); static_assert(NumEpilogueThreads % cutlass::NumThreadsPerWarp == 0, "NumEpilogueThreads must be a multiple of NumThreadsPerWarp"); static constexpr int NumWarpGroups = NumEpilogueThreads / cutlass::NumThreadsPerWarpGroup; // Thread layout, 256 or 384 threads per row @@ -336,14 +381,16 @@ struct CollectiveEpilogueBwdGQA { using R2GTiledCopydKVaccum = decltype(make_tiled_copy(Copy_Atom, ElementAccum>{}, R2GLayoutAtomdKVaccum{}, Layout>{})); // Val layout, 1 vals per store - using SmemLayoutdKVaccum = Layout, Int>>; - using SmemLayoutdKVaccumFlat = Layout>>; + using SmemLayoutdKaccum = Layout, Int>>; + using SmemLayoutdKaccumFlat = Layout>>; + using SmemLayoutdVaccum = Layout, Int>>; + using SmemLayoutdVaccumFlat = Layout>>; // Strangely without this SmemAlignment, the total smem for hdim 128 (80 x 128) is 228KB even though we // only need 227KB. We use the same alignment as the non-GQA epilogue to avoid this issue. static constexpr int SmemAlignment = kHeadDim % 64 == 0 ? 1024 : (kHeadDim % 32 == 0 ? 512 : 256); struct TensorStorageTMA : cute::aligned_struct { - cute::array_aligned, SmemAlignment> smem_dkv; + cute::array_aligned, SmemAlignment> smem_dkv; }; struct TensorStorageSTG { cute::array smem_dkv; @@ -359,6 +406,7 @@ struct CollectiveEpilogueBwdGQA { ShapedKV const shape_dKaccum; StridedKV const stride_dKaccum; ElementAccum* ptr_dVaccum; + ShapedKV const shape_dVaccum; StridedKV const stride_dVaccum; int num_heads_q; int* dk_semaphore; @@ -373,6 +421,7 @@ struct CollectiveEpilogueBwdGQA { ShapedKV const shape_dKaccum; StridedKV const stride_dKaccum; ElementAccum* ptr_dVaccum; + ShapedKV const shape_dVaccum; StridedKV const stride_dVaccum; cutlass::FastDivmod qhead_per_khead_divmod; int* dk_semaphore; @@ -387,7 +436,8 @@ struct CollectiveEpilogueBwdGQA { assert(args.dk_semaphore != nullptr); assert(args.dv_semaphore != nullptr); } - return {args.ptr_dKaccum, args.shape_dKaccum, args.stride_dKaccum, args.ptr_dVaccum, args.stride_dVaccum, + return {args.ptr_dKaccum, args.shape_dKaccum, args.stride_dKaccum, + args.ptr_dVaccum, args.shape_dVaccum, args.stride_dVaccum, cutlass::FastDivmod(cute::ceil_div(args.num_heads_q, get<1>(args.shape_dKaccum))), args.dk_semaphore, args.dv_semaphore, args.cu_seqlens, args.seqused}; @@ -398,13 +448,14 @@ struct CollectiveEpilogueBwdGQA { static void prefetch_tma_descriptors(Params const& params) { } - template + template CUTLASS_DEVICE void store(Params const& params, - FrgTensorO const& tdKrdK, - FrgTensorO const& tdVrdV, + FrgTensor_dK const& tdKrdK, + FrgTensor_dV const& tdVrdV, SharedStorage& shared_storage, - TiledMma tiled_mma, + TiledMma_dK tiled_mma_dk, + TiledMma_dV tiled_mma_dv, int thread_idx, cute::tuple const& block_coord ) { @@ -412,20 +463,24 @@ struct CollectiveEpilogueBwdGQA { auto [n_block, bidh, bidb] = block_coord; int bidh_idx_in_group; int bidh_kv = params.qhead_per_khead_divmod.divmod(bidh_idx_in_group, bidh); - Tensor sdKV = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dkv.data()), SmemLayoutdKVaccum{}); - Tensor sdKV_flat = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dkv.data()), SmemLayoutdKVaccumFlat{}); - static constexpr int dKV_TMA_num_bytes = CUTE_STATIC_V(size(sdKV_flat)) * sizeof(ElementAccum); + Tensor sdK = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dkv.data()), SmemLayoutdKaccum{}); + Tensor sdK_flat = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dkv.data()), SmemLayoutdKaccumFlat{}); + static constexpr int dK_TMA_num_bytes = CUTE_STATIC_V(size(sdK_flat)) * sizeof(ElementAccum); + Tensor sdV = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dkv.data()), SmemLayoutdVaccum{}); + Tensor sdV_flat = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dkv.data()), SmemLayoutdVaccumFlat{}); + static constexpr int dV_TMA_num_bytes = CUTE_STATIC_V(size(sdV_flat)) * sizeof(ElementAccum); flash::SeqlenInfo seqlen_info{bidb, size<0>(params.shape_dKaccum), params.cu_seqlens, params.seqused}; bool const is_varlen = Varlen && params.cu_seqlens; Tensor mdKaccum = make_tensor(make_gmem_ptr(params.ptr_dKaccum), params.shape_dKaccum, params.stride_dKaccum)(_, bidh_kv, !is_varlen ? bidb : 0); - Tensor mdVaccum = make_tensor(make_gmem_ptr(params.ptr_dVaccum), params.shape_dKaccum, params.stride_dVaccum)(_, bidh_kv, !is_varlen ? bidb : 0); + Tensor mdVaccum = make_tensor(make_gmem_ptr(params.ptr_dVaccum), params.shape_dVaccum, params.stride_dVaccum)(_, bidh_kv, !is_varlen ? bidb : 0); Tensor gdKaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdKaccum), Shape>{}, make_coord(n_block)); // (M * K) - Tensor gdVaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdVaccum), Shape>{}, make_coord(n_block)); // (M * K) + Tensor gdVaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim_VO), mdVaccum), Shape>{}, make_coord(n_block)); // (M * K) R2STiledCopydKVaccum r2s_tiled_copy_dKVaccum; auto r2s_thr_copy_dKVaccum = r2s_tiled_copy_dKVaccum.get_thread_slice(thread_idx); - Tensor tdKVsdKVaccum = r2s_thr_copy_dKVaccum.partition_D(sdKV); + Tensor tdKVsdKaccum = r2s_thr_copy_dKVaccum.partition_D(sdK); + Tensor tdKVsdVaccum = r2s_thr_copy_dKVaccum.partition_D(sdV); // Only used if !Use_TMA R2GTiledCopydKVaccum r2g_tiled_copy_dKVaccum; @@ -436,7 +491,7 @@ struct CollectiveEpilogueBwdGQA { flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); if constexpr (Use_TMA) { Tensor taccdKVrdV = r2s_thr_copy_dKVaccum.retile_S(tdVrdV); // ((Atom,AtomNum), MMA_M, MMA_N) - cute::copy(r2s_tiled_copy_dKVaccum, taccdKVrdV, tdKVsdKVaccum); + cute::copy(r2s_tiled_copy_dKVaccum, taccdKVrdV, tdKVsdVaccum); } // int const num_batch = params.num_batch; @@ -455,7 +510,7 @@ struct CollectiveEpilogueBwdGQA { cutlass::arch::fence_view_async_shared(); cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); if (thread_idx == 0) { - SM90_BULK_REDUCE_ADD::copy(raw_pointer_cast(sdKV_flat.data()), raw_pointer_cast(gdVaccum.data()), dKV_TMA_num_bytes, static_cast(TMA::CacheHintSm90::EVICT_LAST)); + SM90_BULK_REDUCE_ADD::copy(raw_pointer_cast(sdV_flat.data()), raw_pointer_cast(gdVaccum.data()), dV_TMA_num_bytes, static_cast(TMA::CacheHintSm90::EVICT_LAST)); tma_store_arrive(); tma_store_wait<0>(); } @@ -473,7 +528,7 @@ struct CollectiveEpilogueBwdGQA { if constexpr (Use_TMA) { cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); Tensor taccdKVrdK = r2s_thr_copy_dKVaccum.retile_S(tdKrdK); // ((Atom,AtomNum), MMA_M, MMA_N) - cute::copy(r2s_tiled_copy_dKVaccum, taccdKVrdK, tdKVsdKVaccum); + cute::copy(r2s_tiled_copy_dKVaccum, taccdKVrdK, tdKVsdKaccum); } lock_ptr = !Deterministic ? nullptr : params.dk_semaphore + bidb * num_head_kv + bidh_kv; // if (thread_idx == 0) { printf("blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dk_semaphore = %p, num_batch = %d, num_head_kv = %d, n_block = %d, bihd_idx_in_group = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dk_semaphore, num_batch, num_head_kv, n_block, bidh_idx_in_group);} @@ -486,7 +541,7 @@ struct CollectiveEpilogueBwdGQA { cutlass::arch::fence_view_async_shared(); cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); if (thread_idx == 0) { - SM90_BULK_REDUCE_ADD::copy(raw_pointer_cast(sdKV_flat.data()), raw_pointer_cast(gdKaccum.data()), dKV_TMA_num_bytes, static_cast(TMA::CacheHintSm90::EVICT_LAST)); + SM90_BULK_REDUCE_ADD::copy(raw_pointer_cast(sdK_flat.data()), raw_pointer_cast(gdKaccum.data()), dK_TMA_num_bytes, static_cast(TMA::CacheHintSm90::EVICT_LAST)); tma_store_arrive(); tma_store_wait<0>(); } diff --git a/hopper/flash_bwd_kernel_sm90.h b/hopper/flash_bwd_kernel_sm90.h index 7aa32a8460f..2e31b328f97 100644 --- a/hopper/flash_bwd_kernel_sm90.h +++ b/hopper/flash_bwd_kernel_sm90.h @@ -35,8 +35,11 @@ class FlashAttnBwdSm90 { // Mainloop derived types using CollectiveMainloop = CollectiveMainloop_; using TileShape_MNK = typename CollectiveMainloop::TileShape_MNK; - using TiledMmaSdP = typename CollectiveMainloop::TiledMmaSdP; - using TiledMmadKV = typename CollectiveMainloop::TiledMmadKV; + using TileShape_MNK_VO = typename CollectiveMainloop::TileShape_MNK_VO; + using TiledMmaS = typename CollectiveMainloop::TiledMmaS; + using TiledMmadP = typename CollectiveMainloop::TiledMmadP; + using TiledMmadK = typename CollectiveMainloop::TiledMmadK; + using TiledMmadV = typename CollectiveMainloop::TiledMmadV; using ArchTag = typename CollectiveMainloop::ArchTag; using ClusterShape = typename CollectiveMainloop::ClusterShape; using MainloopArguments = typename CollectiveMainloop::Arguments; @@ -55,8 +58,8 @@ class FlashAttnBwdSm90 { using TileSchedulerParams = typename TileScheduler::Params; static constexpr uint32_t NumLoadWarpGroups = 1; - static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMmaSdP{})) / cutlass::NumThreadsPerWarpGroup; - static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMmaSdP{})) + (NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup); + static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMmaS{})) / cutlass::NumThreadsPerWarpGroup; + static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMmaS{})) + (NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup); static constexpr uint32_t MinBlocksPerMultiprocessor = 1; static_assert(NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3); @@ -244,7 +247,8 @@ class FlashAttnBwdSm90 { TileScheduler scheduler(reinterpret_cast(&shared_storage.pipelines.smem_scheduler)); // Initialize matmul objects. - TiledMmadKV tiled_mma_dKV; + TiledMmadK tiled_mma_dK; + TiledMmadV tiled_mma_dV; PipelineState smem_pipe_read; PipelineState_dO smem_pipe_read_do; @@ -262,13 +266,13 @@ class FlashAttnBwdSm90 { cute::tuple block_coord = {n_block, bidh, bidb}; // dK and dV output accumulator. - Tensor tdKrdK = partition_fragment_C(tiled_mma_dKV, select(TileShape_MNK{})); - Tensor tdVrdV = partition_fragment_C(tiled_mma_dKV, select(TileShape_MNK{})); + Tensor tdKrdK = partition_fragment_C(tiled_mma_dK, select(TileShape_MNK{})); + Tensor tdVrdV = partition_fragment_C(tiled_mma_dV, select(TileShape_MNK_VO{})); bool tile_valid = collective_mainloop.mma( params.mainloop, pipeline_q, pipeline_do, smem_pipe_read, smem_pipe_read_do, tdKrdK, tdVrdV, threadIdx.x - NumCopyThreads, work_idx, block_coord, shared_storage); if (tile_valid) { - collective_epilogue.store(params.epilogue, tdKrdK, tdVrdV, shared_storage, tiled_mma_dKV, + collective_epilogue.store(params.epilogue, tdKrdK, tdVrdV, shared_storage, tiled_mma_dK, tiled_mma_dV, threadIdx.x - NumCopyThreads, block_coord); } else { collective_epilogue.store_zero(params.epilogue, threadIdx.x - NumCopyThreads, block_coord); diff --git a/hopper/flash_bwd_launch_template.h b/hopper/flash_bwd_launch_template.h index bead0a790c0..75f9131e7e4 100644 --- a/hopper/flash_bwd_launch_template.h +++ b/hopper/flash_bwd_launch_template.h @@ -50,7 +50,7 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { using PreprocessKernel = flash::FlashAttnBwdPreprocess; typename PreprocessKernel::Arguments preprocess_args { static_cast(params.o_ptr), - {seqlen_q, params.d, params.h, batch_q}, // shape_O + {seqlen_q, params.d_vo, params.h, batch_q}, // shape_O {params.o_row_stride, _1{}, params.o_head_stride, !is_varlen_q ? params.o_batch_stride : 0}, // stride_O static_cast(params.do_ptr), {params.do_row_stride, _1{}, params.do_head_stride, !is_varlen_q ? params.do_batch_stride : 0}, // stride_dO @@ -76,6 +76,7 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { CHECK_CUDA_KERNEL_LAUNCH(); using TileShape_MNK = cute::Shape, Int, Int>; + using TileShape_MNK_VO = cute::Shape, Int, Int>; using ClusterShape = cute::Shape<_1, Int<1>, _1>; // Currently doesn't not support cluster // Stages_dS_or_QSm80 is Stages_dS if Sm90 and Stages if Sm80 static constexpr int Stages = Arch >= 90 ? 2 : Stages_dS_or_QSm80; @@ -84,15 +85,15 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { Arch >= 90, flash::CollectiveMainloopBwdSm90, + SdP_swapAB, dKV_swapAB, dQ_swapAB, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs, TileShape_MNK_VO>, flash::CollectiveMainloopBwdSm80 >; using CollectiveEpilogue = std::conditional_t< !GQA, - flash::CollectiveEpilogueBwd= 90 ? 1 : cutlass::NumWarpsPerWarpGroup) / AtomLayoutNdKV>, - flash::CollectiveEpilogueBwdGQA + flash::CollectiveEpilogueBwd= 90 ? 1 : cutlass::NumWarpsPerWarpGroup) / AtomLayoutNdKV, TileShape_MNK_VO>, + flash::CollectiveEpilogueBwdGQA >; using Scheduler = flash::SingleTileScheduler; using AttnKernel = std::conditional_t< @@ -109,8 +110,10 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { {seqlen_k, params.d, params.h_k, batch_k}, // shape_K {params.k_row_stride, _1{}, params.k_head_stride, !is_varlen_k ? params.k_batch_stride : 0}, // stride_K static_cast(params.v_ptr), + {seqlen_k, params.d_vo, params.h_k, batch_k}, // shape_V {params.v_row_stride, _1{}, params.v_head_stride, !is_varlen_k ? params.v_batch_stride : 0}, // stride_V static_cast(params.do_ptr), + {seqlen_q, params.d_vo, params.h, batch_q}, // shape_dO {params.do_row_stride, _1{}, params.do_head_stride, !is_varlen_q ? params.do_batch_stride : 0}, // stride_dO static_cast(params.dq_accum_ptr), {seqlen_q_rounded * params.d_rounded, params.h, batch_q}, // shape_dQaccum @@ -146,11 +149,18 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { } }(), static_cast(!GQA ? params.dv_ptr : params.dv_accum_ptr), + [&] { + if constexpr (!GQA) { + return typename CollectiveEpilogue::ShapedKV {seqlen_k, params.d_vo, params.h, batch_k}; // shape_dV + } else { + return typename CollectiveEpilogue::ShapedKV {seqlen_k_rounded * params.d_vo_rounded, params.h_k, batch_k}; // shape_dVaccum + } + }(), [&] { if constexpr (!GQA) { return typename CollectiveEpilogue::StridedKV {params.dv_row_stride, _1{}, params.dv_head_stride, !is_varlen_k ? params.dv_batch_stride : 0}; // stride_dV } else { - return typename CollectiveEpilogue::StridedKV {_1{}, params.d_rounded * seqlen_k_rounded, !is_varlen_k ? params.h_k * params.d_rounded * params.seqlen_k_rounded : 0}; // stride_dVaccum + return typename CollectiveEpilogue::StridedKV {_1{}, params.d_vo_rounded * seqlen_k_rounded, !is_varlen_k ? params.h_k * params.d_vo_rounded * params.seqlen_k_rounded : 0}; // stride_dVaccum } }(), params.h, @@ -237,13 +247,20 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { CHECK_CUDA_KERNEL_LAUNCH(); if constexpr (GQA) { - using TileShape_NK = cute::Shape, Int>; - using PostprocessKerneldKV = flash::FlashAttnBwdPostprocessConvertdQ, Int>; + using PostprocessKerneldK = flash::FlashAttnBwdPostprocessConvertdQ; - typename PostprocessKerneldKV::Arguments postprocess_dK_args { + using TileShape_NK_dV = cute::Shape, Int>; + using PostprocessKerneldV = flash::FlashAttnBwdPostprocessConvertdQ; + + typename PostprocessKerneldK::Arguments postprocess_dK_args { static_cast(params.dk_accum_ptr), {seqlen_k_rounded * params.d_rounded, params.h_k, batch_k}, // shape_dKaccum {_1{}, seqlen_k_rounded * params.d_rounded, !is_varlen_k ? params.d_rounded * params.seqlen_k_rounded * params.h_k : 0}, // stride_dKaccum @@ -254,28 +271,29 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { params.cu_seqlens_k, params.seqused_k }; - typename PostprocessKerneldKV::Params postprocess_dK_params = PostprocessKerneldKV::to_underlying_arguments(postprocess_dK_args); - typename PostprocessKerneldKV::Arguments postprocess_dV_args { + typename PostprocessKerneldK::Params postprocess_dK_params = PostprocessKerneldK::to_underlying_arguments(postprocess_dK_args); + typename PostprocessKerneldV::Arguments postprocess_dV_args { static_cast(params.dv_accum_ptr), - {seqlen_k_rounded * params.d_rounded, params.h_k, batch_k}, // shape_dVaccum - {_1{}, seqlen_k_rounded * params.d_rounded, !is_varlen_k ? params.d_rounded * params.seqlen_k_rounded * params.h_k : 0}, // stride_dVaccum + {seqlen_k_rounded * params.d_vo_rounded, params.h_k, batch_k}, // shape_dVaccum + {_1{}, seqlen_k_rounded * params.d_vo_rounded, !is_varlen_k ? params.d_vo_rounded * params.seqlen_k_rounded * params.h_k : 0}, // stride_dVaccum static_cast(params.dv_ptr), - {seqlen_k, params.d, params.h_k, batch_k}, // shape_dV + {seqlen_k, params.d_vo, params.h_k, batch_k}, // shape_dV {params.dv_row_stride, _1{}, params.dv_head_stride, params.dv_batch_stride}, // stride_dV 1.f, params.cu_seqlens_k, params.seqused_k }; - typename PostprocessKerneldKV::Params postprocess_dV_params = PostprocessKerneldKV::to_underlying_arguments(postprocess_dV_args); - int num_n_block_postprocess = cute::ceil_div(params.seqlen_k, get<0>(TileShape_NK{})); + typename PostprocessKerneldV::Params postprocess_dV_params = PostprocessKerneldV::to_underlying_arguments(postprocess_dV_args); + int num_n_block_postprocess = cute::ceil_div(params.seqlen_k, get<0>(TileShape_NK_dK{})); dim3 grid_n_postprocess(num_n_block_postprocess, params.h_k, params.b); - int smem_size_postprocess = PostprocessKerneldKV::SharedStorageSize; + int smem_size_postprocess = PostprocessKerneldK::SharedStorageSize; if (smem_size_postprocess >= 48 * 1024) { - CHECK_CUDA(cudaFuncSetAttribute(cutlass::device_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_postprocess)); + CHECK_CUDA(cudaFuncSetAttribute(cutlass::device_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, PostprocessKerneldK::SharedStorageSize)); + CHECK_CUDA(cudaFuncSetAttribute(cutlass::device_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, PostprocessKerneldV::SharedStorageSize)); } - cutlass::kernel_launch(grid_n_postprocess, PostprocessKerneldKV::MaxThreadsPerBlock, smem_size_postprocess, stream, postprocess_dK_params, false /*launch_with_pdl*/); + cutlass::kernel_launch(grid_n_postprocess, PostprocessKerneldK::MaxThreadsPerBlock, PostprocessKerneldK::SharedStorageSize, stream, postprocess_dK_params, false /*launch_with_pdl*/); CHECK_CUDA_KERNEL_LAUNCH(); - cutlass::kernel_launch(grid_n_postprocess, PostprocessKerneldKV::MaxThreadsPerBlock, smem_size_postprocess, stream, postprocess_dV_params, false /*launch_with_pdl*/); + cutlass::kernel_launch(grid_n_postprocess, PostprocessKerneldV::MaxThreadsPerBlock, PostprocessKerneldV::SharedStorageSize, stream, postprocess_dV_params, false /*launch_with_pdl*/); CHECK_CUDA_KERNEL_LAUNCH(); } @@ -354,11 +372,15 @@ template void run_mha_bwd_hdim192(Flash_bwd_params ¶ms, cudaStream_t stream) { CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] { if constexpr (Arch >= 90) { - run_mha_bwd_dispatch(params, stream); + if constexpr (kHeadDim_VO == 128) { + run_mha_bwd_dispatch(params, stream); + } else { + run_mha_bwd_dispatch(params, stream); + } } else if constexpr (Arch == 86 || Arch == 89) { - run_mha_bwd_dispatch(params, stream); + run_mha_bwd_dispatch(params, stream); } else { - run_mha_bwd_dispatch(params, stream); + run_mha_bwd_dispatch(params, stream); } }); } diff --git a/hopper/flash_bwd_postprocess_kernel.h b/hopper/flash_bwd_postprocess_kernel.h index c91e261507d..32ed66a2829 100644 --- a/hopper/flash_bwd_postprocess_kernel.h +++ b/hopper/flash_bwd_postprocess_kernel.h @@ -17,6 +17,14 @@ namespace flash { +template +__device__ +constexpr void +assert_eq() { + static_assert(A == B); +} + + using namespace cute; template @@ -206,6 +214,8 @@ class FlashAttnBwdPostprocessConvertdQ { // if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 1) { print(tiled_mma_dQ); printf("\n"); } // if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 1) { print(tdQsdQaccum); } // if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 1) { print(taccdQrdQaccum); } + assert_eq; + assert_eq; CUTE_STATIC_ASSERT_V(size(taccdQrdQaccum) == size(tdQsdQaccum)); Tensor tdQrdQaccum = s2r_thr_copy_dQaccum.retile_D(taccdQrdQaccum); cute::copy(s2r_tiled_copy_dQaccum, tdQsdQaccum, tdQrdQaccum); diff --git a/hopper/flash_bwd_preprocess_kernel.h b/hopper/flash_bwd_preprocess_kernel.h index 5b3d5606c83..a5f6bf9f89e 100644 --- a/hopper/flash_bwd_preprocess_kernel.h +++ b/hopper/flash_bwd_preprocess_kernel.h @@ -25,6 +25,7 @@ class FlashAttnBwdPreprocess { // Type Aliases using TileShape_MK = TileShape_MK_; + using TileShape_MK_VO = TileShape_MK_VO_; using ArchTag = ArchTag_; static_assert(std::is_same_v && ArchTag::kMinComputeCapability >= 75 || @@ -156,9 +157,9 @@ class FlashAttnBwdPreprocess { if (is_varlen && m_block * kBlockM >= seqlen_o) { return; } Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O), params.shape_O, params.stride_O)(_, _, bidh, !is_varlen ? bidb : 0); - Tensor gO = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mO), TileShape_MK{}, make_coord(m_block, _0{})); // (M, K) + Tensor gO = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mO), TileShape_MK_VO{}, make_coord(m_block, _0{})); // (M, K) Tensor mdO = make_tensor(make_gmem_ptr(params.ptr_dO), params.shape_O, params.stride_dO)(_, _, bidh, !is_varlen ? bidb : 0); - Tensor gdO = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdO), TileShape_MK{}, make_coord(m_block, _0{})); // (M, K) + Tensor gdO = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdO), TileShape_MK_VO{}, make_coord(m_block, _0{})); // (M, K) auto shape_LSE = select<0, 2, 3>(params.shape_O); Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE), shape_LSE, params.stride_LSE)(_, bidh, !is_varlen ? bidb : 0); @@ -172,7 +173,7 @@ class FlashAttnBwdPreprocess { Tensor tOgO = gmem_thr_copy_O.partition_S(gO); Tensor tOgdO = gmem_thr_copy_O.partition_S(gdO); // Construct identity layout for gO - Tensor cO = cute::make_identity_tensor(TileShape_MK{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cO = cute::make_identity_tensor(TileShape_MK_VO{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) // Repeat the partitioning with identity layouts Tensor tOcO = gmem_thr_copy_O.partition_D(cO); Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); diff --git a/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp index 393a6e5814b..5556b96177b 100644 --- a/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp @@ -30,7 +30,7 @@ template + bool Mma_dP_is_RS=false, class TileShape_MNK_VO_=TileShape_MNK_> struct CollectiveMainloopBwdSm90 { static constexpr int kStages = Stages; @@ -41,6 +41,7 @@ struct CollectiveMainloopBwdSm90 { static_assert(!Mma_dP_is_RS || SdP_swapAB_); // If Mma_dP_is_RS, we need SdP_SwapAB using ClusterShape = ClusterShape_; using TileShape_MNK = TileShape_MNK_; + using TileShape_MNK_VO = TileShape_MNK_VO_; using Element = Element_; using ElementAccum = ElementAccum_; using ArchTag = ArchTag_; @@ -59,6 +60,7 @@ struct CollectiveMainloopBwdSm90 { static constexpr int kBlockM = get<0>(TileShape_MNK{}); static constexpr int kBlockN = get<1>(TileShape_MNK{}); static constexpr int kHeadDim = get<2>(TileShape_MNK{}); + static constexpr int kHeadDim_VO = get<2>(TileShape_MNK_VO{}); static_assert(ArchTag::kMinComputeCapability >= 90); static_assert(get<0>(ClusterShape{}) == 1 && get<2>(ClusterShape{}) == 1); @@ -76,41 +78,75 @@ struct CollectiveMainloopBwdSm90 { // static constexpr GMMA::Major PdS_Major = GMMA::Major::MN; static constexpr GMMA::Major PdSt_Major = PdS_Major == GMMA::Major::K ? GMMA::Major::MN : GMMA::Major::K; - using TileShapeAtomSdP = std::conditional_t< + using TileShapeAtomS = std::conditional_t< !SdP_swapAB, Shape, Int, Int>, Shape, Int, Int> >; - using AtomLayoutSdP = std::conditional_t< + using AtomLayoutS = std::conditional_t< !SdP_swapAB, Layout, Int, _1>>, Layout, Int, _1>> >; - using TiledMmaSdP = decltype(cute::make_tiled_mma( - cute::GMMA::ss_op_selector(), - AtomLayoutSdP{})); + using TiledMmaS = decltype(cute::make_tiled_mma( + cute::GMMA::ss_op_selector(), + AtomLayoutS{})); - using TiledMmadPRS = decltype(cute::make_tiled_mma( - cute::GMMA::rs_op_selector(), - AtomLayoutSdP{})); + using TileShapeAtomdP = std::conditional_t< + !SdP_swapAB, + Shape, Int, Int>, + Shape, Int, Int> + >; + using AtomLayoutdP = std::conditional_t< + !SdP_swapAB, + Layout, Int, _1>>, + Layout, Int, _1>> + >; + using TiledMmadP_SS = decltype(cute::make_tiled_mma( + cute::GMMA::ss_op_selector(), + AtomLayoutdP{})); + + using TiledMmadP_RS = decltype(cute::make_tiled_mma( + cute::GMMA::rs_op_selector(), + AtomLayoutdP{})); - using TileShapeAtomdKV = std::conditional_t< + using TiledMmadP = std::conditional_t; + + using TileShapeAtomdK = std::conditional_t< !dKV_swapAB, Shape, Int, Int>, Shape, Int, Int> >; - using AtomLayoutdKV = std::conditional_t< + using AtomLayoutdK = std::conditional_t< + !dKV_swapAB, + Layout, Int, _1>>, + Layout, Int, _1>> + >; + using TiledMmadK = decltype(cute::make_tiled_mma( + std::conditional_t< + Mma_dKV_is_RS, + decltype(cute::GMMA::rs_op_selector()), + decltype(cute::GMMA::ss_op_selector()) + >{}, + AtomLayoutdK{})); + + using TileShapeAtomdV = std::conditional_t< + !dKV_swapAB, + Shape, Int, Int>, + Shape, Int, Int> + >; + using AtomLayoutdV = std::conditional_t< !dKV_swapAB, Layout, Int, _1>>, Layout, Int, _1>> >; - using TiledMmadKV = decltype(cute::make_tiled_mma( + using TiledMmadV = decltype(cute::make_tiled_mma( std::conditional_t< Mma_dKV_is_RS, - decltype(cute::GMMA::rs_op_selector()), - decltype(cute::GMMA::ss_op_selector()) + decltype(cute::GMMA::rs_op_selector()), + decltype(cute::GMMA::ss_op_selector()) >{}, - AtomLayoutdKV{})); + AtomLayoutdV{})); using TileShapeAtomdQ = std::conditional_t< !dQ_swapAB, @@ -141,15 +177,15 @@ struct CollectiveMainloopBwdSm90 { make_shape(shape<0>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); using SmemLayoutdO = decltype(tile_to_shape(SmemLayoutAtomQdO{}, - make_shape(shape<0>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); + make_shape(shape<0>(TileShape_MNK_VO{}), shape<2>(TileShape_MNK_VO{}), Int{}))); using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector, Int>()); using SmemLayoutK = decltype(tile_to_shape(SmemLayoutAtomK{}, select<1, 2>(TileShape_MNK{}))); using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - using SmemLayoutV = decltype(tile_to_shape(SmemLayoutAtomV{}, select<1, 2>(TileShape_MNK{}))); + decltype(cute::get<1>(TileShape_MNK_VO{})), decltype(cute::get<2>(TileShape_MNK_VO{}))>()); + using SmemLayoutV = decltype(tile_to_shape(SmemLayoutAtomV{}, select<1, 2>(TileShape_MNK_VO{}))); using SmemLayoutAtomPdS = decltype(cutlass::gemm::collective::detail::ss_smem_selector, @@ -176,8 +212,8 @@ struct CollectiveMainloopBwdSm90 { make_stride(Int{}, _1{}, Int{})))); using SmemLayoutdOt = decltype(cute::composition(SmemLayoutdO{}, - make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int{}), - make_stride(Int{}, _1{}, Int{})))); + make_layout(make_shape(get<2>(TileShape_MNK_VO{}), get<0>(TileShape_MNK_VO{}), Int{}), + make_stride(Int{}, _1{}, Int{})))); using SmemLayoutKt = decltype(cute::composition(SmemLayoutK{}, make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})), @@ -215,13 +251,20 @@ struct CollectiveMainloopBwdSm90 { using ShapedQaccum = cute::Shape; // (seqlen_q * d, head, batch) using StridedQaccum = cute::Stride<_1, int64_t, int64_t>; - using TMA_QdO = decltype(make_tma_copy_A_sm90( + using TMA_Q = decltype(make_tma_copy_A_sm90( GmemTiledCopyQdO{}, make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, StrideQKV{}), take<0, 2>(SmemLayoutQ{}), TileShape_MNK{}, ClusterShape{})); // mcast along N mode for this M load, if any + using TMA_dO = decltype(make_tma_copy_A_sm90( + GmemTiledCopyQdO{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, StrideQKV{}), + take<0, 2>(SmemLayoutdO{}), + TileShape_MNK_VO{}, + ClusterShape{})); // mcast along N mode for this M load, if any + using TMA_K = decltype(make_tma_copy_B_sm90( GmemTiledCopyKV{}, make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, StrideQKV{}), @@ -233,7 +276,7 @@ struct CollectiveMainloopBwdSm90 { GmemTiledCopyKV{}, make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, StrideQKV{}), SmemLayoutV{}, - TileShape_MNK{}, + TileShape_MNK_VO{}, ClusterShape{})); // no mcast for KV using MainloopPipeline = typename cutlass::PipelineTmaAsync; @@ -295,8 +338,10 @@ struct CollectiveMainloopBwdSm90 { ShapeQKV const shape_K; StrideQKV const stride_K; Element const* const ptr_V; + ShapeQKV const shape_V; StrideQKV const stride_V; Element const* const ptr_dO; + ShapeQKV const shape_dO; StrideQKV const stride_dO; ElementAccum* const ptr_dQaccum; ShapedQaccum const shape_dQaccum; @@ -321,11 +366,14 @@ struct CollectiveMainloopBwdSm90 { struct Params { ShapeQKV const shape_Q; ShapeQKV const shape_K; + ShapeQKV const shape_V; + ShapeQKV const shape_dO; ElementAccum* const ptr_dQaccum; ShapedQaccum const shape_dQaccum; StridedQaccum stride_dQaccum; cutlass::FastDivmod qhead_per_khead_divmod; - TMA_QdO tma_load_Q, tma_load_dO; + TMA_Q tma_load_Q; + TMA_dO tma_load_dO; TMA_K tma_load_K; TMA_V tma_load_V; float const* const ptr_LSE_log2; @@ -347,18 +395,18 @@ struct CollectiveMainloopBwdSm90 { static Params to_underlying_arguments(Arguments const& args) { Tensor mQ = make_tensor(make_gmem_ptr(args.ptr_Q), args.shape_Q, args.stride_Q); - TMA_QdO tma_load_Q = make_tma_copy_A_sm90( + TMA_Q tma_load_Q = make_tma_copy_A_sm90( GmemTiledCopyQdO{}, mQ, SmemLayoutQ{}(_, _, _0{}), TileShape_MNK{}, ClusterShape{}); // mcast along N mode for this M load, if any - Tensor mdO = make_tensor(make_gmem_ptr(args.ptr_dO), args.shape_Q, args.stride_dO); - TMA_QdO tma_load_dO = make_tma_copy_A_sm90( + Tensor mdO = make_tensor(make_gmem_ptr(args.ptr_dO), args.shape_dO, args.stride_dO); + TMA_dO tma_load_dO = make_tma_copy_A_sm90( GmemTiledCopyQdO{}, mdO, SmemLayoutdO{}(_, _, _0{}), - TileShape_MNK{}, + TileShape_MNK_VO{}, ClusterShape{}); // mcast along N mode for this M load, if any Tensor mK = make_tensor(make_gmem_ptr(args.ptr_K), args.shape_K, args.stride_K); TMA_K tma_load_K = make_tma_copy_B_sm90( @@ -367,12 +415,12 @@ struct CollectiveMainloopBwdSm90 { SmemLayoutK{}, TileShape_MNK{}, ClusterShape{}); // no mcast for KV - Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), args.shape_K, args.stride_V); + Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), args.shape_V, args.stride_V); TMA_V tma_load_V = make_tma_copy_B_sm90( GmemTiledCopyKV{}, mV, SmemLayoutV{}, - TileShape_MNK{}, + TileShape_MNK_VO{}, ClusterShape{}); // no mcast for KV if constexpr (Deterministic) { assert(args.dq_semaphore != nullptr); } // If there's tanh softcapping, we do tanh(scores * softmax_scale / softcap_val) * softcap_val. @@ -384,7 +432,7 @@ struct CollectiveMainloopBwdSm90 { // (1 - tanh^2) * softmax_scale / softcap_val * softcap_val = (1 - tanh^2) * softmax_scale. // Instead we multiply by (1 - tanh^2) and multiply dK and dV by params.softmax_scale // (the original softmax_scale) at the end. - return {args.shape_Q, args.shape_K, + return {args.shape_Q, args.shape_K, args.shape_V, args.shape_dO, args.ptr_dQaccum, args.shape_dQaccum, args.stride_dQaccum, cutlass::FastDivmod(cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K))), tma_load_Q, tma_load_dO, tma_load_K, tma_load_V, @@ -468,16 +516,16 @@ struct CollectiveMainloopBwdSm90 { bool const is_varlen_q = Varlen && params.cu_seqlens_q; bool const is_varlen_k = Varlen && params.cu_seqlens_k; Tensor mQ = params.tma_load_Q.get_tma_tensor(params.shape_Q)(_, _, bidh, !is_varlen_q ? bidb : 0); - Tensor mdO = params.tma_load_dO.get_tma_tensor(params.shape_Q)(_, _, bidh, !is_varlen_q ? bidb : 0); + Tensor mdO = params.tma_load_dO.get_tma_tensor(params.shape_dO)(_, _, bidh, !is_varlen_q ? bidb : 0); Tensor mK = params.tma_load_K.get_tma_tensor(params.shape_K)(_, _, bidh_kv, !is_varlen_k ? bidb : 0); - Tensor mV = params.tma_load_V.get_tma_tensor(params.shape_K)(_, _, bidh_kv, !is_varlen_k ? bidb : 0); + Tensor mV = params.tma_load_V.get_tma_tensor(params.shape_V)(_, _, bidh_kv, !is_varlen_k ? bidb : 0); Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE_log2), params.shape_LSE, params.stride_LSE_log2)(_, bidh, !is_varlen_q ? bidb : 0); Tensor mdPsum = make_tensor(make_gmem_ptr(params.ptr_dPsum), params.shape_LSE, params.stride_dPsum)(_, bidh, !is_varlen_q ? bidb : 0); Tensor gQ = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mQ), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (M, K, _) - Tensor gdO = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mdO), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (M, K, _) + Tensor gdO = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mdO), select<0, 2>(TileShape_MNK_VO{}), make_coord(_, _0{})); // (M, K, _) Tensor gK = local_tile(domain_offset(make_coord(seqlen_info.offset_k, _0{}), mK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (N, K) - Tensor gV = local_tile(domain_offset(make_coord(seqlen_info.offset_k, _0{}), mV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (N, K) + Tensor gV = local_tile(domain_offset(make_coord(seqlen_info.offset_k, _0{}), mV), select<1, 2>(TileShape_MNK_VO{}), make_coord(n_block, _0{})); // (N, K) Tensor gLSE = local_tile(domain_offset(make_coord(seqlen_info.offset_q_padded), mLSE), select<0>(TileShape_MNK{}), make_coord(_)); // (M, _) Tensor gdPsum = local_tile(domain_offset(make_coord(seqlen_info.offset_q_padded), mdPsum), select<0>(TileShape_MNK{}), make_coord(_)); // (M, _) @@ -675,21 +723,21 @@ struct CollectiveMainloopBwdSm90 { } } - template + template CUTLASS_DEVICE bool mma(Params const& params, MainloopPipeline pipeline_q, MainloopPipeline_dO pipeline_do, PipelineState& smem_pipe_read, PipelineState_dO& smem_pipe_read_do, - FrgTensordKV& tdKrdK, - FrgTensordKV& tdVrdV, + FrgTensordK& tdKrdK, + FrgTensordV& tdVrdV, int thread_idx, int &work_idx, cute::tuple block_coord, SharedStorage& shared_storage ) { - static_assert(is_rmem::value, "dK and dV tensor must be rmem resident."); + static_assert(is_rmem::value, "dK and dV tensor must be rmem resident."); int n_block = get<0>(block_coord); int bidb = get<2>(block_coord); @@ -722,10 +770,10 @@ struct CollectiveMainloopBwdSm90 { Tensor sLSEMma = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_lse.data()), SmemLayoutLSEMma{}); Tensor sdPsumMma = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_dpsum.data()), SmemLayoutLSEMma{}); - static_assert(stride<0>(typename TiledMmaSdP::ALayout{}) == 0 and - stride<0>(typename TiledMmaSdP::BLayout{}) == 0 and - size<0>(typename TiledMmaSdP::ALayout{}) == cutlass::NumThreadsPerWarpGroup and - size<0>(typename TiledMmaSdP::BLayout{}) == cutlass::NumThreadsPerWarpGroup, + static_assert(stride<0>(typename TiledMmaS::ALayout{}) == 0 and + stride<0>(typename TiledMmaS::BLayout{}) == 0 and + size<0>(typename TiledMmaS::ALayout{}) == cutlass::NumThreadsPerWarpGroup and + size<0>(typename TiledMmaS::BLayout{}) == cutlass::NumThreadsPerWarpGroup, "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); constexpr int MmaWarpGroups = NumMmaThreads / cutlass::NumThreadsPerWarpGroup; Layout warp_group_thread_layout = make_layout(make_shape(Int{}), @@ -734,19 +782,21 @@ struct CollectiveMainloopBwdSm90 { make_stride(Int{})); int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / cutlass::NumThreadsPerWarpGroup, 0); - TiledMmaSdP tiled_mma_SdP; - using TiledMmadP = std::conditional_t; + TiledMmaS tiled_mma_S; TiledMmadP tiled_mma_dP; - TiledMmadKV tiled_mma_dKV; + TiledMmadK tiled_mma_dK; + TiledMmadV tiled_mma_dV; TiledMmadQ tiled_mma_dQ; - auto wg_mma_SdP = tiled_mma_SdP.get_slice(warp_group_thread_layout(warp_group_idx)); + auto wg_mma_S = tiled_mma_S.get_slice(warp_group_thread_layout(warp_group_idx)); auto wg_mma_dP = tiled_mma_dP.get_slice(warp_group_thread_layout(warp_group_idx)); - auto thread_mma_SdP = tiled_mma_SdP.get_thread_slice(thread_idx); - auto wg_mma_dKV = tiled_mma_dKV.get_slice(warp_group_thread_layout(warp_group_idx)); + auto thread_mma_S = tiled_mma_S.get_thread_slice(thread_idx); + auto thread_mma_dP = tiled_mma_dP.get_thread_slice(thread_idx); + auto wg_mma_dK = tiled_mma_dK.get_slice(warp_group_thread_layout(warp_group_idx)); + auto wg_mma_dV = tiled_mma_dV.get_slice(warp_group_thread_layout(warp_group_idx)); auto wg_mma_dQ = tiled_mma_dQ.get_slice(warp_group_thread_layout_dq(warp_group_idx)); - auto smem_tiled_copy_PdS = make_tiled_copy_C(SmemCopyAtomPdS{}, tiled_mma_SdP); + auto smem_tiled_copy_PdS = make_tiled_copy_C(SmemCopyAtomPdS{}, tiled_mma_S); auto smem_thr_copy_PdS = smem_tiled_copy_PdS.get_thread_slice(thread_idx); R2STiledCopydQaccum r2s_tiled_copy_dQaccum; @@ -758,12 +808,12 @@ struct CollectiveMainloopBwdSm90 { // We have to use the templated mma_partition_fragment_AB instead of cute::conditional_return or lambda, // because some partition_fragment_A/B don't compile. // https://stackoverflow.com/questions/50051473/if-constexpr-in-c17-does-not-work-in-a-non-templated-function - Tensor tSrQ = mma_partition_fragment_AB(wg_mma_SdP, sQ); - Tensor tSrK = mma_partition_fragment_AB(wg_mma_SdP, sK); - Tensor tdPrdO = mma_partition_fragment_AB(wg_mma_SdP, sdO); + Tensor tSrQ = mma_partition_fragment_AB(wg_mma_S, sQ); + Tensor tSrK = mma_partition_fragment_AB(wg_mma_S, sK); + Tensor tdPrdO = mma_partition_fragment_AB(wg_mma_dP, sdO); Tensor tdPrV = mma_partition_fragment_AB(wg_mma_dP, sV); - Tensor tdVrdO = mma_partition_fragment_AB(wg_mma_dKV, sdOt); - Tensor tdKrQ = mma_partition_fragment_AB(wg_mma_dKV, sQt); + Tensor tdVrdO = mma_partition_fragment_AB(wg_mma_dV, sdOt); + Tensor tdKrQ = mma_partition_fragment_AB(wg_mma_dK, sQt); Tensor tdQrdS = mma_partition_fragment_AB(wg_mma_dQ, sdS); Tensor tdQrK = mma_partition_fragment_AB(wg_mma_dQ, sKt); @@ -774,11 +824,11 @@ struct CollectiveMainloopBwdSm90 { // thread_mma_SdP.partition_C(sLSEMma) has shape ((2, 2, V), MMA_M, MMA_N, PIPE), we only take the col indices // or row indices, depending on whether SdP_swapAB. Tensor tLSEsLSE = cute::conditional_return( - group_modes<0, 2>(thread_mma_SdP.partition_C(sLSEMma)(make_coord(_0{}, _, _0{}), _, _0{}, _)), // (2, MMA_M, PIPE) - group_modes<0, 3>(thread_mma_SdP.partition_C(sLSEMma)(make_coord(_, _0{}, _), _0{}, _, _))); // (2, V, MMA_N, PIPE) + group_modes<0, 2>(thread_mma_S.partition_C(sLSEMma)(make_coord(_0{}, _, _0{}), _, _0{}, _)), // (2, MMA_M, PIPE) + group_modes<0, 3>(thread_mma_S.partition_C(sLSEMma)(make_coord(_, _0{}, _), _0{}, _, _))); // (2, V, MMA_N, PIPE) Tensor tLSEsdPsum = cute::conditional_return( - group_modes<0, 2>(thread_mma_SdP.partition_C(sdPsumMma)(make_coord(_0{}, _, _0{}), _, _0{}, _)), - group_modes<0, 3>(thread_mma_SdP.partition_C(sdPsumMma)(make_coord(_, _0{}, _), _0{}, _, _))); + group_modes<0, 2>(thread_mma_dP.partition_C(sdPsumMma)(make_coord(_0{}, _, _0{}), _, _0{}, _)), + group_modes<0, 3>(thread_mma_dP.partition_C(sdPsumMma)(make_coord(_, _0{}, _), _0{}, _, _))); // if (blockIdx.x == 0 && threadIdx.x == 128) { print(sLSEMma); printf("\n"); print(tLSEsLSE); printf("\n"); } // If we want to split the stats among the 8 threads that share the same rows. static constexpr int kStatsPerThread = cute::ceil_div(decltype(size(tLSEsLSE))::value, 8); @@ -802,7 +852,7 @@ struct CollectiveMainloopBwdSm90 { Tensor tdQgdQaccum = r2s_thr_copy_dQaccum.partition_D(gdQaccum); // if (blockIdx.x == 0 && threadIdx.x == 128) { print(mdQaccum); printf("\n"); print(gdQaccum_); printf("\n"); print(gdQaccum); printf("\n"); print(tdQgdQaccum); printf("\n"); } - flash::Mask mask( + flash::Mask mask( thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, params.sink_token_length, params.qhead_per_khead_divmod ); @@ -826,9 +876,9 @@ struct CollectiveMainloopBwdSm90 { } auto bwd_step = [&](int m_block, auto mask_fn) { - Tensor tSrS = partition_fragment_C(tiled_mma_SdP, select(TileShape_MNK{})); + Tensor tSrS = partition_fragment_C(tiled_mma_S, select(TileShape_MNK{})); consumer_wait(pipeline_q, smem_pipe_read); - flash::gemm(tiled_mma_SdP, tSrQ(_, _, _, smem_pipe_read.index()), tSrK, tSrS); + flash::gemm(tiled_mma_S, tSrQ(_, _, _, smem_pipe_read.index()), tSrK, tSrS); Tensor tLSErLSE = cute::conditional_return(make_fragment_like(tLSEsLSE(_, _0{})), make_tensor(Int{})); if constexpr (!ShuffleLSE) { cute::copy(tLSEsLSE(_, smem_pipe_read.index()), tLSErLSE); @@ -839,7 +889,7 @@ struct CollectiveMainloopBwdSm90 { tLSErLSE(i) = tLSEsLSE((thread_idx % 32) / 4 + i * 8, smem_pipe_read.index()); } } - Tensor tdPrdP = partition_fragment_C(tiled_mma_SdP, select(TileShape_MNK{})); + Tensor tdPrdP = partition_fragment_C(tiled_mma_dP, select(TileShape_MNK_VO{})); PipelineState_dO smem_pipe_read_do_cur = cute::conditional_return(smem_pipe_read, smem_pipe_read_do); consumer_wait(pipeline_do, smem_pipe_read_do_cur); flash::gemm(tiled_mma_dP, tdPrdO(_, _, _, smem_pipe_read_do_cur.index()), tdPrV, tdPrdP); @@ -919,12 +969,12 @@ struct CollectiveMainloopBwdSm90 { if constexpr (!Slice_dQKV_Mma) { // Most cases take this path, except for hdim256 where we want to slice to reduce register pressure if constexpr (Mma_dKV_is_RS) { - Tensor tdVrP = make_tensor(rP.data(), convert_layout_acc_Aregs(tSrS.layout())); - flash::gemm(tiled_mma_dKV, tdVrP, tdVrdO(_, _, _, smem_pipe_read_do_cur.index()), tdVrdV); + Tensor tdVrP = make_tensor(rP.data(), convert_layout_acc_Aregs(tSrS.layout())); + flash::gemm(tiled_mma_dV, tdVrP, tdVrdO(_, _, _, smem_pipe_read_do_cur.index()), tdVrdV); } else { - Tensor tdVrP = mma_partition_fragment_AB(wg_mma_dKV, sPt); + Tensor tdVrP = mma_partition_fragment_AB(wg_mma_dV, sPt); Tensor tdVrP_cur = tdVrP(_, _, _, cute::conditional_return(_0{}, smem_pipe_read.index())); - flash::gemm(tiled_mma_dKV, tdVrP_cur, tdVrdO(_, _, _, smem_pipe_read_do_cur.index()), tdVrdV); + flash::gemm(tiled_mma_dV, tdVrP_cur, tdVrdO(_, _, _, smem_pipe_read_do_cur.index()), tdVrdV); } // SMEM fence to make sure sdS is written before it's read by WGMMA cutlass::arch::fence_view_async_shared(); @@ -935,12 +985,12 @@ struct CollectiveMainloopBwdSm90 { pipeline_do.consumer_release(smem_pipe_read_do_cur); // release dQ if constexpr (Mma_dKV_is_RS) { - Tensor tdKrdS = make_tensor(rdS.data(), convert_layout_acc_Aregs(tdPrdP.layout())); - flash::gemm(tiled_mma_dKV, tdKrdS, tdKrQ(_, _, _, smem_pipe_read.index()), tdKrdK); + Tensor tdKrdS = make_tensor(rdS.data(), convert_layout_acc_Aregs(tdPrdP.layout())); + flash::gemm(tiled_mma_dK, tdKrdS, tdKrQ(_, _, _, smem_pipe_read.index()), tdKrdK); } else { - Tensor tdKrdS = mma_partition_fragment_AB(wg_mma_dKV, sdSt); + Tensor tdKrdS = mma_partition_fragment_AB(wg_mma_dK, sdSt); Tensor tdKrdS_cur = tdKrdS(_, _, _, cute::conditional_return(_0{}, smem_pipe_read.index())); - flash::gemm(tiled_mma_dKV, tdKrdS_cur, tdKrQ(_, _, _, smem_pipe_read.index()), tdKrdK); + flash::gemm(tiled_mma_dK, tdKrdS_cur, tdKrQ(_, _, _, smem_pipe_read.index()), tdKrdK); } if constexpr (dQacc_use_TMA) { int const warp_group_idx = flash::canonical_warp_group_idx_nosync() - 1; @@ -961,31 +1011,31 @@ struct CollectiveMainloopBwdSm90 { } else { // Slice_dQKV_Mma static_assert(!(Slice_dQKV_Mma && Mma_dKV_is_RS)); - Tensor tdVrP = mma_partition_fragment_AB(wg_mma_dKV, sPt); + Tensor tdVrP = mma_partition_fragment_AB(wg_mma_dV, sPt); Tensor tdVrP_cur = tdVrP(_, _, _, cute::conditional_return(_0{}, smem_pipe_read.index())); - flash::gemm(tiled_mma_dKV, tdVrP_cur, tdVrdO(_, _, _, smem_pipe_read_do_cur.index()), tdVrdV); + flash::gemm(tiled_mma_dV, tdVrP_cur, tdVrdO(_, _, _, smem_pipe_read_do_cur.index()), tdVrdV); cutlass::arch::fence_view_async_shared(); cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(BwdNamedBarriers::PdS) /*id*/); Tensor tdQrdQ = partition_fragment_C(tiled_mma_dQ, select(TileShape_MNK{})); Tensor tdQrdS_cur = tdQrdS(_, _, _, cute::conditional_return(_0{}, smem_pipe_read.index())); flash::gemm(tiled_mma_dQ, tdQrdS_cur, tdQrK, tdQrdQ); - flash::gemm(tiled_mma_dKV, tdVrP_cur, tdVrdO(_, _, _, smem_pipe_read_do_cur.index()), tdVrdV); + flash::gemm(tiled_mma_dV, tdVrP_cur, tdVrdO(_, _, _, smem_pipe_read_do_cur.index()), tdVrdV); Tensor tdQrdQ_atomic = recast(r2s_thr_copy_dQaccum.retile_S(tdQrdQ)); Tensor tdQgdQaccum_atomic = recast(tdQgdQaccum(_, _, _, m_block)); #pragma unroll for (int i = 0; i < size(tdQrdQ_atomic) / 2; ++i) { atomicAdd(&tdQgdQaccum_atomic(i), tdQrdQ_atomic(i)); } - Tensor tdKrdS = mma_partition_fragment_AB(wg_mma_dKV, sdSt); + Tensor tdKrdS = mma_partition_fragment_AB(wg_mma_dK, sdSt); Tensor tdKrdS_cur = tdKrdS(_, _, _, cute::conditional_return(_0{}, smem_pipe_read.index())); - flash::gemm(tiled_mma_dKV, tdKrdS_cur, tdKrQ(_, _, _, smem_pipe_read.index()), tdKrdK); + flash::gemm(tiled_mma_dK, tdKrdS_cur, tdKrQ(_, _, _, smem_pipe_read.index()), tdKrdK); pipeline_do.consumer_release(smem_pipe_read_do_cur); // release dO flash::gemm(tiled_mma_dQ, tdQrdS_cur, tdQrK, tdQrdQ); #pragma unroll for (int i = size(tdQrdQ_atomic) / 2; i < size(tdQrdQ_atomic); ++i) { atomicAdd(&tdQgdQaccum_atomic(i), tdQrdQ_atomic(i)); } - flash::gemm(tiled_mma_dKV, tdKrdS_cur, tdKrQ(_, _, _, smem_pipe_read.index()), tdKrdK); + flash::gemm(tiled_mma_dK, tdKrdS_cur, tdKrQ(_, _, _, smem_pipe_read.index()), tdKrdK); } warpgroup_wait<0>();