Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
229 changes: 142 additions & 87 deletions hopper/epilogue_bwd.hpp

Large diffs are not rendered by default.

40 changes: 20 additions & 20 deletions hopper/epilogue_fwd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ namespace flash {

using namespace cute;

template <class TileShape_MNK_, class ClusterShape_, class Element_, class ArchTag_,
template <class TileShape_MNK_VO_, class ClusterShape_, class Element_, class ArchTag_,
int NumEpilogueThreads_, bool Varlen_, bool PackGQA_, bool FP8PermuteCol=false>
struct CollectiveEpilogueFwd {

using TileShape_MNK = TileShape_MNK_;
using TileShape_MNK_VO = TileShape_MNK_VO_;
using ClusterShape = ClusterShape_;
using Element = Element_;
using ArchTag = ArchTag_;
Expand All @@ -37,18 +37,18 @@ struct CollectiveEpilogueFwd {
static_assert(ArchTag::kMinComputeCapability >= 80);
static_assert(ArchTag::kMinComputeCapability >= 90 || CUTE_STATIC_V(size(ClusterShape{})) == 1);

static constexpr int kBlockM = get<0>(TileShape_MNK{});
static constexpr int kHeadDim = get<2>(TileShape_MNK{});
static constexpr int kBlockM = get<0>(TileShape_MNK_VO{});
static constexpr int kHeadDim_VO = get<2>(TileShape_MNK_VO{});

using GmemTiledCopyOTMA = cute::SM90_TMA_STORE;

// These are for storing the output tensor without TMA (e.g., for setting output to zero)
static constexpr int kGmemElemsPerStore = sizeof(cute::uint128_t) / sizeof(Element);
static_assert(kHeadDim % kGmemElemsPerStore == 0, "Headdim must be a multiple of kGmemElemsPerStore");
static_assert(kHeadDim_VO % kGmemElemsPerStore == 0, "Headdim must be a multiple of kGmemElemsPerStore");
// We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). We want each thread to have 4 elements
// in the M direction and 2 elements in the K direction. In the case of PackGQA, this reduces the number of times
// we need to call divmod.
static constexpr int kBytePerRow = kHeadDim * sizeof(Element);
static constexpr int kBytePerRow = kHeadDim_VO * sizeof(Element);
static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element);
// static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
// static constexpr int kGmemThreadsPerRow = cutlass::gcd(kHeadDim / kGmemElemsPerStore, NumEpilogueThreads);
Expand All @@ -65,15 +65,15 @@ struct CollectiveEpilogueFwd {
Layout<Shape<_1, Int<kGmemElemsPerStore>>>{})); // Val layout, 8 or 16 vals per store

using SmemLayoutAtomOTMA = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutOTMA = decltype(tile_to_shape(SmemLayoutAtomOTMA{}, select<0, 2>(TileShape_MNK{})));
decltype(cute::get<0>(TileShape_MNK_VO{})), decltype(cute::get<2>(TileShape_MNK_VO{}))>());
using SmemLayoutOTMA = decltype(tile_to_shape(SmemLayoutAtomOTMA{}, select<0, 2>(TileShape_MNK_VO{})));
static constexpr int kSwizzle = kBlockKGmem == 128 ? 4 : (kBlockKGmem == 64 ? 3 : (kBlockKGmem == 32 ? 2 : 1));
static constexpr int kSwizzleBase = sizeof(Element) == 4 ? 2 : (sizeof(Element) == 2 ? 3 : 4);
using SmemLayoutAtomO = decltype(
composition(Swizzle<kSwizzle, kSwizzleBase, kSwizzleBase>{},
Layout<Shape<_8, Int<kBlockKGmem>>,
Stride<Int<kBlockKGmem>, _1>>{}));
using SmemLayoutOSTS = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{})));
using SmemLayoutOSTS = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK_VO{})));
using SmemLayoutO = std::conditional_t<ArchTag::kMinComputeCapability >= 90, SmemLayoutOTMA, SmemLayoutOSTS>;

using ShapeO = cute::Shape<int32_t, int32_t, int32_t, int32_t, int32_t>; // (seqlen_q, d, head, batch, num_splits)
Expand Down Expand Up @@ -109,7 +109,7 @@ struct CollectiveEpilogueFwd {
GmemTiledCopyOTMA{},
make_tensor(make_gmem_ptr(static_cast<Element*>(nullptr)), ShapeO{}, StrideO{}),
SmemLayoutOTMA{},
select<0, 2>(TileShape_MNK{}),
select<0, 2>(TileShape_MNK_VO{}),
_1{})), // no mcast for O
std::nullptr_t
>;
Expand Down Expand Up @@ -148,7 +148,7 @@ struct CollectiveEpilogueFwd {
Tensor mO = make_tensor(make_gmem_ptr(args.ptr_O), args.shape_O, args.stride_O);
TMA_O tma_store_O = [&]{
if constexpr (Use_TMA_O) {
return make_tma_copy(GmemTiledCopyOTMA{}, mO, SmemLayoutO{}, select<0, 2>(TileShape_MNK{}), _1{}); // no mcast
return make_tma_copy(GmemTiledCopyOTMA{}, mO, SmemLayoutO{}, select<0, 2>(TileShape_MNK_VO{}), _1{}); // no mcast
} else {
return nullptr;
}
Expand Down Expand Up @@ -243,14 +243,14 @@ struct CollectiveEpilogueFwd {
// Step 2: Write LSE from rmem -> gmem
auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
// (MMA,MMA_M,MMA_K)
Tensor taccOcO = thread_mma.partition_C(cute::make_identity_tensor(select<0, 2>(TileShape_MNK{})));
Tensor taccOcO = thread_mma.partition_C(cute::make_identity_tensor(select<0, 2>(TileShape_MNK_VO{})));
static_assert(decltype(size<0, 0>(taccOcO))::value == 2);
static_assert(decltype(size<0, 1>(taccOcO))::value == 2);
Tensor taccOcO_rowcol = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout()));
Tensor taccOcO_row = taccOcO_rowcol(_, _0{});
CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M

using PackGQAt = flash::PackGQAManager<get<0>(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumEpilogueThreads, Element>;
using PackGQAt = flash::PackGQAManager<get<0>(TileShape_MNK_VO{}), get<2>(TileShape_MNK_VO{}), NumEpilogueThreads, Element>;

Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE + offset_o * get<0>(params.stride_LSE)), params.shape_LSE_packed, params.stride_LSE_packed)(_, bidh, !is_varlen ? bidb : 0, split_idx);
// if (thread_idx == 0) { printf("Before LSE write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d\n", m_block, bidh, bidb, split_idx, offset_o, seqlen_o); print(mLSE); printf("\n"); }
Expand All @@ -267,7 +267,7 @@ struct CollectiveEpilogueFwd {
// Step 3: Write O from smem -> gmem
if constexpr (Use_TMA_O) {
Tensor mO = params.tma_store_O.get_tma_tensor(params.shape_O)(_, _, bidh, bidb, split_idx);
Tensor gO = local_tile(mO, select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K)
Tensor gO = local_tile(mO, select<0, 2>(TileShape_MNK_VO{}), make_coord(m_block, _0{})); // (M, K)
auto block_tma_O = params.tma_store_O.get_slice(_0{});
Tensor tOgO = block_tma_O.partition_D(gO); // (TMA, TMA_M, TMA_K)
Tensor tOsO = block_tma_O.partition_S(sO); // (TMA, TMA_M, TMA_K)
Expand All @@ -287,7 +287,7 @@ struct CollectiveEpilogueFwd {
}
} else { // Don't use TMA in Varlen case since we don't want to overwrite the output of another sequence
Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset_o * get<0>(params.stride_O)), params.shape_O_packed, params.stride_O_packed)(_, _, bidh, !is_varlen ? bidb : 0, split_idx);
Tensor gO = local_tile(mO, select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K)
Tensor gO = local_tile(mO, select<0, 2>(TileShape_MNK_VO{}), make_coord(m_block, _0{})); // (M, K)
// if (thread_idx == 0) { printf("Before O write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d, mO_addr = %p, addr diff = %d\n", m_block, bidh, bidb, split_idx, offset_o, seqlen_o, mO.data(), reinterpret_cast<int>(&mO(0)) - reinterpret_cast<int>(params.ptr_O)); }
if constexpr (Use_smem) {
GmemTiledCopyO gmem_tiled_copy_O;
Expand All @@ -305,7 +305,7 @@ struct CollectiveEpilogueFwd {
}
if constexpr (!PackGQA) {
// (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 2>(TileShape_MNK{})));
Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 2>(TileShape_MNK_VO{})));
Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOsO)));
#pragma unroll
for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); }
Expand Down Expand Up @@ -361,7 +361,7 @@ struct CollectiveEpilogueFwd {
int thread_idx,
cute::tuple<int32_t, int32_t, int32_t, int32_t> const& block_coord
) {
static constexpr int kBlockM = get<0>(TileShape_MNK{});
static constexpr int kBlockM = get<0>(TileShape_MNK_VO{});
auto [m_block, bidh, bidb, split_idx] = block_coord;
flash::SeqlenInfo<Varlen, kBlockM> seqlen_info{bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused};
bool const is_varlen = Varlen && params.cu_seqlens;
Expand Down Expand Up @@ -391,12 +391,12 @@ struct CollectiveEpilogueFwd {

GmemTiledCopyO gmem_tiled_copy_O;
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 2>(TileShape_MNK{})));
Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 2>(TileShape_MNK_VO{})));
if constexpr (!PackGQA) {
Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOcO)));
#pragma unroll
for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); }
Tensor gO = local_tile(mO, select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K)
Tensor gO = local_tile(mO, select<0, 2>(TileShape_MNK_VO{}), make_coord(m_block, _0{})); // (M, K)
Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
Tensor tOrO = make_fragment_like(tOgO);
cute::clear(tOrO);
Expand All @@ -406,7 +406,7 @@ struct CollectiveEpilogueFwd {
);
} else {
// If PackGQA, we split the work of compute O_ptr among threads in the same row
using PackGQAt = flash::PackGQAManager<get<0>(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumEpilogueThreads, Element>;
using PackGQAt = flash::PackGQAManager<get<0>(TileShape_MNK_VO{}), get<2>(TileShape_MNK_VO{}), NumEpilogueThreads, Element>;
Tensor tOrO = make_tensor<Element>(make_shape(Shape<_1, Int<kGmemElemsPerStore>>{}, size<1>(tOcO), size<2>(tOcO)));
cute::clear(tOrO);
PackGQAt::store_O(mO, tOrO, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block);
Expand Down
6 changes: 3 additions & 3 deletions hopper/flash.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ struct Flash_fwd_params : public Qkv_params {
index_t v_descale_head_stride;

// The dimensions.
int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim;
int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim, d_vo, d_vo_rounded;
int total_q, total_k, total_knew;
int b_k; // When having KV cache and with cache_batch_idx, K & V might have larger batch size than Q

Expand Down Expand Up @@ -197,9 +197,9 @@ struct Flash_bwd_params : public Flash_fwd_params {

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int Arch, typename T, int Headdim, bool Split, bool PagedKV, bool Has_softcap, bool PackGQA>
template <int Arch, typename T, int Headdim, bool Split, bool PagedKV, bool Has_softcap, bool PackGQA, int Headdim_VO=Headdim>
void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
template <int Arch, typename T, int Headdim, bool Has_softcap>
template <int Arch, typename T, int Headdim, bool Has_softcap, int Headdim_VO=Headdim>
void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
template <typename T, typename Tpartial, int Headdim>
void run_mha_fwd_combine_(Flash_fwd_params &params, cudaStream_t stream);
Loading