diff --git a/ci/jax.sh b/ci/jax.sh index 81d994585..f048492ba 100755 --- a/ci/jax.sh +++ b/ci/jax.sh @@ -57,8 +57,9 @@ run_test_config() { export NVTE_JAX_UNITTEST_LEVEL=L0 # this env variable controls parameters set for some tests run_default_fa 1 test_custom_call_compute.py run_default_fa 1 test_functions.py - run 1 test_fused_attn.py - NVTE_CK_USES_FWD_V3=0 NVTE_CK_USES_BWD_V3=0 run_default_fa_lbl "v2" 3 test_fused_attn.py # Using FAv2 for forward and backward pass + run 1 test_fused_attn.py -k 'not test_ck_unfused_smallseq_backend' # skip smallseq in normal flow + XLA_FLAGS='--xla_gpu_graph_level=0' run 1 test_fused_attn.py -k 'test_ck_unfused_smallseq_backend' # CK smallseq with GPU graph disabled + NVTE_CK_USES_FWD_V3=0 NVTE_CK_USES_BWD_V3=0 run_default_fa_lbl "v2" 3 test_fused_attn.py -k 'not test_ck_unfused_smallseq_backend' # Using FAv2 for forward and backward pass run_default_fa 1 test_helper.py run_default_fa 1 test_layer.py #it effectevly always uses unfused attention run_default_fa 1 test_sanity_import.py diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 4d7718cd0..8e2684a1b 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -9,6 +9,7 @@ from functools import partial from math import sqrt from typing import Tuple, Optional, Dict +import os import random import jax @@ -329,7 +330,11 @@ class FusedAttnRunner: # generating zero-length ragged tensors. This setting adjusts the test to avoid the zero-length cases. def _get_max_segments_per_sequence(self): if self.qkv_layout.is_thd(): - if 90400 <= get_cudnn_version() < 90500: + if ( + 90400 <= get_cudnn_version() < 90500 + or ( is_hip_extension() and + os.environ.get("NVTE_FUSED_ATTN_CK_SMALLSEQ", "0") == "1") + ): return self.num_segments_per_seq else: # +1 for testing runtime_segments < max_segments @@ -418,6 +423,57 @@ def _check_configs(self): "the F16_arbitrary_seqlen backend." 
) + def _setup_thd_segments_ck_smallseq(self, generate_random_segment_ids): + """ + Build THD segment descriptors for the CK small-seq path (NVTE_FUSED_ATTN_CK_SMALLSEQ=1). + + Uses num_segments_per_seq = max_seqlen_q for both Q and KV. For Q: if max_seqlen_q == 1, + uses a fixed layout (one token per batch, cu_seqlens [0,1,...,batch_size]); otherwise + generates random segments. For KV: always generates random segments. + """ + num_segments_per_seq = self.max_seqlen_q + if self.max_seqlen_q == 1: + # Q: deterministic - one segment of length 1 per batch -> cu_seqlen [0,1,2,...,batch_size] + segment_ids_q = jnp.ones((self.batch_size, self.max_seqlen_q), dtype=jnp.int32) + segment_pos_q = jnp.zeros((self.batch_size, self.max_seqlen_q), dtype=jnp.int32) + pad_q = jnp.zeros((self.batch_size, self.max_seqlen_q), dtype=jnp.int32) + seqlens_q = jnp.ones((self.batch_size, 1), dtype=jnp.int32) + offsets_q = jnp.concatenate( + [ + jnp.arange(self.batch_size, dtype=jnp.int32)[:, None], + jnp.full((self.batch_size, 1), -1, dtype=jnp.int32), + ], + axis=1, + ) + else: + segment_ids_q, segment_pos_q, pad_q = generate_random_segment_ids( + self.batch_size, self.max_seqlen_q, num_segments_per_seq, seed=42 + ) + seqlens_q, offsets_q = get_seqlens_and_offsets(segment_ids_q) + + min_segment_len = None if self.window_size is None else seqlens_q + segment_ids_kv, segment_pos_kv, pad_kv = generate_random_segment_ids( + self.batch_size, + self.max_seqlen_kv, + num_segments_per_seq, + seed=2024, + min_segment_len=min_segment_len, + ) + seqlens_kv, offsets_kv = get_seqlens_and_offsets(segment_ids_kv) + return ( + num_segments_per_seq, + segment_ids_q, + segment_pos_q, + pad_q, + seqlens_q, + offsets_q, + segment_ids_kv, + segment_pos_kv, + pad_kv, + seqlens_kv, + offsets_kv, + ) + def _setup_inputs(self): self._check_configs() @@ -539,27 +595,42 @@ def generate_random_segment_ids( return segment_ids, segment_pos, segment_pad if self.qkv_layout.is_thd(): - self.num_segments_per_seq = 2 - 
self.segment_ids_q, self.segment_pos_q, self.pad_q = generate_random_segment_ids( - self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42 - ) - self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(self.segment_ids_q) - # TODO(rewang): record only self attention and find the reason of cross attention - if self.qkv_layout == QKVLayout.T3HD or self.max_seqlen_q == self.max_seqlen_kv: - self.segment_ids_kv = self.segment_ids_q - self.segment_pos_kv = self.segment_pos_q - self.pad_kv = self.pad_q - else: - # Force kv_len >= q_len for swa, otherwise, cuDNN kernels don't support - min_segment_len = None if self.window_size is None else self.seqlens_q - self.segment_ids_kv, self.segment_pos_kv, self.pad_kv = generate_random_segment_ids( - self.batch_size, - self.max_seqlen_kv, + if is_hip_extension() and os.environ.get("NVTE_FUSED_ATTN_CK_SMALLSEQ", "0") == "1": + ( self.num_segments_per_seq, - seed=2024, - min_segment_len=min_segment_len, + self.segment_ids_q, + self.segment_pos_q, + self.pad_q, + self.seqlens_q, + self.offsets_q, + self.segment_ids_kv, + self.segment_pos_kv, + self.pad_kv, + self.seqlens_kv, + self.offsets_kv, + ) = self._setup_thd_segments_ck_smallseq(generate_random_segment_ids) + else: + self.num_segments_per_seq = 2 + self.segment_ids_q, self.segment_pos_q, self.pad_q = generate_random_segment_ids( + self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42 ) - self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets(self.segment_ids_kv) + self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(self.segment_ids_q) + # TODO(rewang): record only self attention and find the reason of cross attention + if self.qkv_layout == QKVLayout.T3HD or self.max_seqlen_q == self.max_seqlen_kv: + self.segment_ids_kv = self.segment_ids_q + self.segment_pos_kv = self.segment_pos_q + self.pad_kv = self.pad_q + else: + # Force kv_len >= q_len for swa, otherwise, cuDNN kernels don't support + min_segment_len = None if self.window_size 
is None else self.seqlens_q + self.segment_ids_kv, self.segment_pos_kv, self.pad_kv = generate_random_segment_ids( + self.batch_size, + self.max_seqlen_kv, + self.num_segments_per_seq, + seed=2024, + min_segment_len=min_segment_len, + ) + self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets(self.segment_ids_kv) else: self.num_segments_per_seq = 1 self.segment_ids_q, self.pad_q = gen_valid( @@ -1214,3 +1285,61 @@ def test_jax_new_rng(): ) runner = FusedAttnRunner(**kwargs) runner.test_forward() + + +# ROCm CK small-seq varlen tests. +@pytest.fixture +def ck_smallseq_env(monkeypatch): + """Enable CK small-seq path and disable XLA GPU graphs for these tests.""" + if "xla_gpu_graph_level=0" not in os.environ.get("XLA_FLAGS", ""): + pytest.skip("Test must be run with XLA_FLAGS='--xla_gpu_graph_level=0'") + monkeypatch.setenv("NVTE_FUSED_ATTN_CK_SMALLSEQ", "1") + yield + +@pytest.mark.parametrize("dtype", [jnp.bfloat16, jnp.float16], ids=["BF16", "FP16"]) +@pytest.mark.parametrize( + "b, s_q, s_kv, h_q, h_kv, d_qk, d_v", + [ + pytest.param(4000, 1, 2, 16, 16, 128, 128, id="4000-1-2-16-16-128-128"), + pytest.param(4000, 1, 4, 16, 16, 128, 128, id="4000-1-4-16-16-128-128"), + pytest.param(4000, 1, 6, 16, 16, 128, 128, id="4000-1-6-16-16-128-128"), + pytest.param(4000, 1, 8, 16, 16, 128, 128, id="4000-1-8-16-16-128-128"), + pytest.param(4000, 1, 12, 16, 16, 128, 128, id="4000-1-12-16-16-128-128"), + pytest.param(4000, 1, 16, 16, 16, 128, 128, id="4000-1-16-16-16-128-128"), + pytest.param(2048, 2, 4, 16, 16, 128, 128, id="seqpack-2048-2-4-16-16-128-128"), + pytest.param(2, 4096, 8192, 16, 16, 128, 128, id="seqpack-2-4096-8192-16-16-128-128"), + ], +) +@pytest.mark.skipif( + not is_hip_extension(), reason="CK unfused smallseq backend only available on AMD hardware" +) +def test_ck_unfused_smallseq_backend( + ck_smallseq_env, b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype +): + """ + Test the CK unfused small-seq (varlen) path on ROCm: s_q=1, s_kv<=16, THD layout. 
+ Uses THD_THD_THD (Q,K,V all THD). ck_smallseq_env sets NVTE_FUSED_ATTN_CK_SMALLSEQ=1 and + restores it after the test. + """ + runner = FusedAttnRunner( + batch_size=b, + max_seqlen_q=s_q, + max_seqlen_kv=s_kv, + num_heads_q=h_q, + num_heads_kv=h_kv, + head_dim_qk=d_qk, + head_dim_v=d_v, + attn_bias_type=AttnBiasType.NO_BIAS, + attn_mask_type=AttnMaskType.PADDING_MASK, + dropout_prob=0.0, + use_old_rng=True, + dtype=dtype, + is_training=True, + qkv_layout=QKVLayout.THD_THD_THD, + bias_shape=None, + window_size=None, + seq_desc_format=SeqDescFormat.Seqlens, + ) + runner._setup_inputs() + # runner.test_forward() + runner.test_backward() diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 50dcf90a0..6774acfd2 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -200,6 +200,7 @@ else() fused_attn_rocm/fused_attn.cpp fused_attn_rocm/fused_attn_aotriton.cpp fused_attn_rocm/fused_attn_ck.cpp + fused_attn_rocm/fused_attn_smallseq.cpp fused_attn_rocm/utils.cpp gemm/rocm_gemm.cu amd_detail/system.cpp) diff --git a/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp b/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp index 54ee94786..47528c020 100644 --- a/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp +++ b/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. * * License for AMD contributions = MIT. 
See LICENSE for more information ************************************************************************/ @@ -168,6 +168,12 @@ hipError_t ck_attn_varlen_bwd( int how_v3_bf16_cvt, hipStream_t stream); +uint64_t get_runtime_max_seqlen(uint64_t b, + const void* cu_seqlen_ptr, + const void* cu_seqlen_padded_ptr, + void* workspace, + hipStream_t stream); + }//namespace ck_fused_attn #endif // CK_FUSED_ATTN_H diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp index 7ca6fc95f..2af841581 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. * * License for AMD contributions = MIT. See LICENSE for more information ************************************************************************/ @@ -9,6 +9,7 @@ #include // Required for std::accumulate #ifdef USE_FUSED_ATTN_CK #include +#include "fused_attn_smallseq.h" #endif // USE_FUSED_ATTN_CK #include "../util/cuda_runtime.h" #include "../util/system.h" @@ -614,6 +615,40 @@ void fused_attn_ck_fwd_impl( // denote the next available section of workspace from upstream void* workspace_next = workspace; + const char* nvte_smallseq = std::getenv("NVTE_FUSED_ATTN_CK_SMALLSEQ"); + if (is_ragged && nvte_smallseq && std::string(nvte_smallseq) == "1") { + void* max_seqlen_workspace = workspace_next; + size_t runtime_max_seqlen_q = static_cast(ck_fused_attn::get_runtime_max_seqlen( + static_cast(b), devPtrCuSeqlensQ, nullptr, max_seqlen_workspace, stream)); + size_t runtime_max_seqlen_kv = static_cast(ck_fused_attn::get_runtime_max_seqlen( + static_cast(b), devPtrCuSeqlensKV, nullptr, max_seqlen_workspace, stream)); + 
workspace_next = static_cast(static_cast(workspace_next) + sizeof(uint64_t)); + + if (nvte_log_ck_config) { + std::cout << std::endl << "attn_fwd(ck small-seq): "; + std::cout << "b: " << b << ", "; + std::cout << "runtime_max_seqlen_q: " << runtime_max_seqlen_q << ", "; + std::cout << "runtime_max_seqlen_kv: " << runtime_max_seqlen_kv << ", "; + std::cout << "flow: " + << (runtime_max_seqlen_q == 1 && runtime_max_seqlen_kv >= 2 && + runtime_max_seqlen_kv <= 16 + ? "ck-smallseq" + : "regular ck/aiter") + << std::endl; + } + + if (runtime_max_seqlen_q == 1 && runtime_max_seqlen_kv >= 2 && runtime_max_seqlen_kv <= 16) { + fused_attn_rocm::fused_attn_smallseq_fwd( + b, h, hg, runtime_max_seqlen_kv, d_qk, d_v, + is_training, scaling_factor, dropout_probability, + devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxAux, + devPtrCuSeqlensKV, devPtrSeqOffsetsKV, + devPtrDropoutSeed, devPtrDropoutOffset, + dtype, workspace, workspace_size, stream); + return; + } + } + std::array q_stride; std::array k_stride; std::array v_stride; @@ -916,6 +951,40 @@ void fused_attn_ck_bwd_impl( // denote the next available section of workspace from upstream void* workspace_next = workspace; + const char* nvte_smallseq = std::getenv("NVTE_FUSED_ATTN_CK_SMALLSEQ"); + if (is_ragged && nvte_smallseq && std::string(nvte_smallseq) == "1") { + void* max_seqlen_workspace = workspace_next; + size_t runtime_max_seqlen_q = static_cast(ck_fused_attn::get_runtime_max_seqlen( + b, devPtrCuSeqlensQ, nullptr, max_seqlen_workspace, stream)); + size_t runtime_max_seqlen_kv = static_cast(ck_fused_attn::get_runtime_max_seqlen( + b, devPtrCuSeqlensKV, nullptr, max_seqlen_workspace, stream)); + workspace_next = static_cast(static_cast(workspace_next) + sizeof(uint64_t)); + + if (nvte_log_ck_config) { + std::cout << std::endl << "attn_bwd(ck small-seq): "; + std::cout << "b: " << b << ", "; + std::cout << "runtime_max_seqlen_q: " << runtime_max_seqlen_q << ", "; + std::cout << "runtime_max_seqlen_kv: " << 
runtime_max_seqlen_kv << ", "; + std::cout << "flow: " + << (runtime_max_seqlen_q == 1 && runtime_max_seqlen_kv >= 2 && + runtime_max_seqlen_kv <= 16 + ? "ck-smallseq" + : "regular ck/aiter") + << std::endl; + } + + if (runtime_max_seqlen_q == 1 && runtime_max_seqlen_kv >= 2 && runtime_max_seqlen_kv <= 16) { + fused_attn_rocm::fused_attn_smallseq_bwd( + b, h, hg, runtime_max_seqlen_kv, d_qk, d_v, + scaling_factor, dropout_probability, + devPtrQ, devPtrK, devPtrV, devPtrO, devPtrdO, devPtrSoftmaxAux, + devPtrdQ, devPtrdK, devPtrdV, + devPtrCuSeqlensKV, devPtrSeqOffsetsKV, + dtype, workspace, workspace_size, stream); + return; + } + } + std::array q_stride; std::array k_stride; std::array v_stride; @@ -1828,7 +1897,7 @@ void fused_attn_ck_fwd( size_t max_tokens_q = std::accumulate((input_Q->data).shape.begin(), (input_Q->data).shape.end(), static_cast(1), std::multiplies())/h_q/d_qk; size_t max_tokens_kv = std::accumulate((input_K->data).shape.begin(), (input_K->data).shape.end(), static_cast(1), std::multiplies())/h_kv/d_qk; - bool is_ragged = nvte_get_qkv_format(qkv_layout)==NVTE_QKV_Format::NVTE_THD; + bool is_ragged = nvte_get_qkv_format(qkv_layout)==NVTE_QKV_Format::NVTE_THD; if (Aux_CTX_Tensors->size == 0) { if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { Aux_CTX_Tensors->size = 3; @@ -1883,7 +1952,6 @@ void fused_attn_ck_fwd( bool is_padding = (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); - fused_attn_ck_fwd_impl( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, bias_b, bias_h, max_tokens_q, max_tokens_kv, diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.cpp new file mode 100644 index 000000000..789beffa2 --- /dev/null +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.cpp @@ -0,0 +1,972 
@@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +/*! \file fused_attn_smallseq.cpp + * \brief Unfused small-seq (varlen) attention: seq_q=1, max_seqlen_kv<=16, THD only. + */ + +#include +#include +#include + +#include +#include +#include + +#include "../common.h" +#include "../util/cuda_runtime.h" +#include "fused_attn_smallseq.h" +#include "utils.h" + +// Macros to avoid repeating dispatch switch cases for max_seqlen_kv in [2, 16]. +// T, bi, hi and the pointer/scale args must be in scope where these are used. +#define SMALLSEQ_DISPATCH_FWD_CASE(N) \ + case N: \ + dispatch_fwd(bi, hi, Q_ptr, K_ptr, V_ptr, dropout_mask, dropout, \ + sqr_dk_scale, O_ptr, attn_workspace, cu_kv, cu_kv_p, \ + stream); \ + break; +#define SMALLSEQ_DISPATCH_BWD_CASE(N) \ + case N: \ + dispatch_bwd(bi, hi, Q_ptr, K_ptr, V_ptr, dO_ptr, attn_ptr, \ + dropout_mask, dropout, sqr_dk_scale, dQ_ptr, dK_ptr, \ + dV_ptr, workspace_ptr, cu_kv, cu_kv_p, stream); \ + break; + +namespace transformer_engine { +namespace fused_attn_rocm { + +enum class CausalMaskType { DISABLE = 0, TOP_LEFT = 1, BOTTOM_RIGHT = 2 }; + +template +struct SmallSeqConfig { + static constexpr int seq_q = 1; + static constexpr int max_seq_kv = MAX_SEQ_KV; + static constexpr int head_dim = HEAD_DIM; + static constexpr int step2_block_size = STEP2_BLOCK_SIZE; + static constexpr bool enable_dropout_mask = ENABLE_DROPOUT_MASK; + static constexpr CausalMaskType mask_type = MASK_TYPE; +}; + +/* MAX_SEQ_KV and HEAD_DIM are compile-time so kernels can use fixed stack arrays + * (e.g. float results[max_seq_kv], T attn[max_seq_kv]) and constexpr grid/block + * sizes. 
This matches varlen_attn/attn_fwd.cpp (FmhaKernelConfig<..., MAX_SEQ_KV, HEAD_DIM>) + * and INTEGRATION_TASK.md: seq_q==1, max_seq_kv<=16; head_dim=128 is the only + * value tested in varlen_attn (main() uses TestRunner<2,16>::run<..., 128, ...>). */ + +// ----- Forward kernels (with runtime batch_size, head_num) ----- + +template +__global__ void compute_scores_kernel(const T* Q, + const T* K, + T* scores, + float scale, + const int* cu_seqlens_kv, + const int* cu_seqlens_kv_padded, + int batch_size, + int head_num) +{ + constexpr int seq_q = Config::seq_q; + constexpr int max_seq_kv = Config::max_seq_kv; + constexpr int head_dim = Config::head_dim; + constexpr int block_k = 64; + constexpr int thread_block_size = 64; + constexpr int tasks_per_block = TASKS_PER_BLOCK; + + int base_block_offset = blockIdx.x * thread_block_size * tasks_per_block; + int thread_id = threadIdx.x; + + for (int task = 0; task < tasks_per_block; task++) { + int cur_batch_idx = base_block_offset + task * thread_block_size + thread_id; + int batch_idx = cur_batch_idx / (seq_q * head_num); + int seq_head_idx = cur_batch_idx % (seq_q * head_num); + int seq_idx = seq_head_idx / head_num; + int head_idx = seq_head_idx % head_num; + + if (batch_idx >= batch_size) + continue; + + int seq_kv = cu_seqlens_kv[batch_idx + 1] - cu_seqlens_kv[batch_idx]; + int kv_offset = cu_seqlens_kv_padded[batch_idx]; + + float results[max_seq_kv]; + T fetch_Q[block_k]; + T fetch_K[block_k]; + T* Q_ptr = (T*)&Q[(batch_idx * seq_q * head_num + seq_idx * head_num + head_idx) * head_dim]; + T* K_ptr = (T*)&K[(kv_offset * head_num + head_idx) * head_dim]; + T* score_ptr = (T*)&scores[cur_batch_idx * max_seq_kv]; + uint4 ls_dwordx4_tmp_var; + for (int i = 0; i < seq_kv; i++) + results[i] = 0.0f; + for (int dim_offset = 0; dim_offset < head_dim; dim_offset += block_k) { + if constexpr (std::is_same::value || std::is_same::value) { + for (int k = 0; k < block_k / 8; k++) { + ls_dwordx4_tmp_var = *((uint4*)&Q_ptr[dim_offset 
+ k * 8]); + fetch_Q[k * 8 + 0] = ((T*)&ls_dwordx4_tmp_var.x)[0]; + fetch_Q[k * 8 + 1] = ((T*)&ls_dwordx4_tmp_var.x)[1]; + fetch_Q[k * 8 + 2] = ((T*)&ls_dwordx4_tmp_var.y)[0]; + fetch_Q[k * 8 + 3] = ((T*)&ls_dwordx4_tmp_var.y)[1]; + fetch_Q[k * 8 + 4] = ((T*)&ls_dwordx4_tmp_var.z)[0]; + fetch_Q[k * 8 + 5] = ((T*)&ls_dwordx4_tmp_var.z)[1]; + fetch_Q[k * 8 + 6] = ((T*)&ls_dwordx4_tmp_var.w)[0]; + fetch_Q[k * 8 + 7] = ((T*)&ls_dwordx4_tmp_var.w)[1]; + } + for (int kv_idx = 0; kv_idx < seq_kv; kv_idx++) { + for (int k = 0; k < block_k / 8; k++) { + ls_dwordx4_tmp_var = + *((uint4*)&K_ptr[kv_idx * head_num * head_dim + dim_offset + k * 8]); + fetch_K[k * 8 + 0] = ((T*)&ls_dwordx4_tmp_var.x)[0]; + fetch_K[k * 8 + 1] = ((T*)&ls_dwordx4_tmp_var.x)[1]; + fetch_K[k * 8 + 2] = ((T*)&ls_dwordx4_tmp_var.y)[0]; + fetch_K[k * 8 + 3] = ((T*)&ls_dwordx4_tmp_var.y)[1]; + fetch_K[k * 8 + 4] = ((T*)&ls_dwordx4_tmp_var.z)[0]; + fetch_K[k * 8 + 5] = ((T*)&ls_dwordx4_tmp_var.z)[1]; + fetch_K[k * 8 + 6] = ((T*)&ls_dwordx4_tmp_var.w)[0]; + fetch_K[k * 8 + 7] = ((T*)&ls_dwordx4_tmp_var.w)[1]; + } +#pragma unroll + for (int k = 0; k < block_k; k++) + results[kv_idx] += static_cast(fetch_Q[k]) * static_cast(fetch_K[k]); + } + } else { + for (int k = 0; k < block_k / 4; k++) { + ls_dwordx4_tmp_var = *((uint4*)&Q_ptr[dim_offset + k * 4]); + fetch_Q[k * 4 + 0] = *((T*)&ls_dwordx4_tmp_var.x); + fetch_Q[k * 4 + 1] = *((T*)&ls_dwordx4_tmp_var.y); + fetch_Q[k * 4 + 2] = *((T*)&ls_dwordx4_tmp_var.z); + fetch_Q[k * 4 + 3] = *((T*)&ls_dwordx4_tmp_var.w); + } + for (int kv_idx = 0; kv_idx < seq_kv; kv_idx++) { + for (int k = 0; k < block_k / 4; k++) { + ls_dwordx4_tmp_var = + *((uint4*)&K_ptr[kv_idx * head_num * head_dim + dim_offset + k * 4]); + fetch_K[k * 4 + 0] = *((T*)&ls_dwordx4_tmp_var.x); + fetch_K[k * 4 + 1] = *((T*)&ls_dwordx4_tmp_var.y); + fetch_K[k * 4 + 2] = *((T*)&ls_dwordx4_tmp_var.z); + fetch_K[k * 4 + 3] = *((T*)&ls_dwordx4_tmp_var.w); + } +#pragma unroll + for (int k = 0; k < block_k; 
k++) + results[kv_idx] += fetch_Q[k] * fetch_K[k]; + } + } + } + for (int i = 0; i < seq_kv; i++) + score_ptr[i] = T(results[i] * scale); + for (int i = seq_kv; i < max_seq_kv; i++) + score_ptr[i] = T(-1e9f); + } +} + +template +__global__ void apply_mask_and_softmax_kernel(T* scores, + const T* dropout_mask, + float dropout_scale, + const int* cu_seqlens_kv, + int batch_size, + int head_num) +{ + const uint32_t block_id = blockIdx.x; + const uint32_t thread_id = threadIdx.x; + constexpr int seq_q = Config::seq_q; + constexpr int max_seq_kv = Config::max_seq_kv; + constexpr int block_size = Config::step2_block_size; + constexpr int per_score_size = seq_q * max_seq_kv; + constexpr int valid_thread_range = block_size / per_score_size * per_score_size; + const uint32_t cur_block_offset = block_id * valid_thread_range + thread_id; + const uint32_t total_elt = static_cast(batch_size) * head_num * seq_q * max_seq_kv; + bool is_tail = block_id * valid_thread_range + block_size >= total_elt; + int real_row_num = + is_tail ? (total_elt - block_id * valid_thread_range) / max_seq_kv + : valid_thread_range / max_seq_kv; + + if (cur_block_offset < total_elt && thread_id < valid_thread_range) { + __shared__ T tmp_scores[valid_thread_range]; + constexpr int row_num = valid_thread_range / max_seq_kv; + __shared__ T row_max[row_num]; + __shared__ T row_sum[row_num]; + + int global_row_idx = cur_block_offset / max_seq_kv; + int batch_idx = global_row_idx / (seq_q * head_num); + int k_idx = cur_block_offset % max_seq_kv; + + int seq_kv = (batch_idx < batch_size) + ? 
(cu_seqlens_kv[batch_idx + 1] - cu_seqlens_kv[batch_idx]) + : max_seq_kv; + + T score_value = scores[cur_block_offset]; + tmp_scores[thread_id] = score_value; + + if constexpr (Config::mask_type == CausalMaskType::TOP_LEFT) { + int q_idx = (cur_block_offset % (seq_q * max_seq_kv)) / max_seq_kv; + if (k_idx > q_idx || k_idx >= seq_kv) + tmp_scores[thread_id] = T(-1e9f); + } else if constexpr (Config::mask_type == CausalMaskType::BOTTOM_RIGHT) { + int q_idx = (cur_block_offset % (seq_q * max_seq_kv)) / max_seq_kv; + if (k_idx < q_idx || k_idx >= seq_kv) + tmp_scores[thread_id] = T(-1e9f); + } else { + if (k_idx >= seq_kv) + tmp_scores[thread_id] = T(-1e9f); + } + __syncthreads(); + + if (thread_id < real_row_num) { + T max_val = T(-1e9f); +#pragma unroll + for (int i = 0; i < max_seq_kv; i++) + max_val = fmaxf(static_cast(max_val), + static_cast(tmp_scores[thread_id * max_seq_kv + i])); + row_max[thread_id] = max_val; + } + __syncthreads(); + + T exp_val = T(expf(static_cast(tmp_scores[thread_id] - + row_max[thread_id / max_seq_kv]))); + tmp_scores[thread_id] = exp_val; + __syncthreads(); + + if (thread_id < real_row_num) { + T sum = T(0.0f); +#pragma unroll + for (int i = 0; i < max_seq_kv; i++) + sum += tmp_scores[thread_id * max_seq_kv + i]; + row_sum[thread_id] = sum; + } + __syncthreads(); + + T attn_weight = tmp_scores[thread_id] / row_sum[thread_id / max_seq_kv]; + if constexpr (Config::enable_dropout_mask) { + attn_weight = attn_weight * dropout_mask[cur_block_offset] * dropout_scale; + } + scores[cur_block_offset] = attn_weight; + } +} + +template +__global__ void compute_output_kernel(const T* attn_weights, + const T* V, + T* O, + const int* cu_seqlens_kv, + const int* cu_seqlens_kv_padded, + int batch_size, + int head_num) +{ + constexpr int seq_q = Config::seq_q; + constexpr int max_seq_kv = Config::max_seq_kv; + constexpr int head_dim = Config::head_dim; + constexpr int block_k = BLOCK_K; + constexpr int dwordx4_load_elt = 16 / sizeof(T); + constexpr int 
warp_size = 64; + constexpr int process_head_per_warp = warp_size / (head_dim / block_k); + constexpr int tasks_per_block = TASKS_PER_BLOCK; + + int base_block_offset = blockIdx.x * process_head_per_warp * tasks_per_block; + int thread_id = threadIdx.x; + int thread_batch_offset = thread_id / (head_dim / block_k); + int thread_head_offset = thread_id % (head_dim / block_k) * block_k; + + uint4 load_dwordx4_tmp_var[block_k / dwordx4_load_elt], + store_dwordx4_tmp_var[block_k / dwordx4_load_elt]; + T attn[max_seq_kv]; + + for (int task = 0; task < tasks_per_block; task++) { + int block_batch_head_idx = base_block_offset + task * process_head_per_warp; + int cur_idx = block_batch_head_idx + thread_batch_offset; + + int batch_idx = cur_idx / (seq_q * head_num); + int seq_head_idx = cur_idx % (seq_q * head_num); + int seq_q_idx = seq_head_idx / head_num; + int head_idx = seq_head_idx % head_num; + + if (batch_idx >= batch_size) + continue; + + int seq_kv = cu_seqlens_kv[batch_idx + 1] - cu_seqlens_kv[batch_idx]; + int kv_offset = cu_seqlens_kv_padded[batch_idx]; + +#pragma unroll + for (int i = 0; i < block_k / dwordx4_load_elt; i++) { + store_dwordx4_tmp_var[i].x = 0; + store_dwordx4_tmp_var[i].y = 0; + store_dwordx4_tmp_var[i].z = 0; + store_dwordx4_tmp_var[i].w = 0; + } +#pragma unroll + for (int i = 0; i < max_seq_kv; i++) + attn[i] = attn_weights[cur_idx * max_seq_kv + i]; + for (int j = 0; j < seq_kv; j++) { +#pragma unroll + for (int i = 0; i < block_k / dwordx4_load_elt; i++) { + load_dwordx4_tmp_var[i] = + *((uint4*)&V[((kv_offset + j) * head_num + head_idx) * head_dim + thread_head_offset + + i * dwordx4_load_elt]); + } +#pragma unroll + for (int b = 0; b < block_k; b++) + ((T*)&store_dwordx4_tmp_var[b / dwordx4_load_elt])[b % dwordx4_load_elt] += + attn[j] * + ((T*)&load_dwordx4_tmp_var[b / dwordx4_load_elt])[b % dwordx4_load_elt]; + } +#pragma unroll + for (int i = 0; i < block_k / dwordx4_load_elt; i++) + *((uint4*)&O[(batch_idx * seq_q * head_num + 
seq_q_idx * head_num + head_idx) * head_dim + + thread_head_offset + i * dwordx4_load_elt]) = store_dwordx4_tmp_var[i]; + } +} + +// ----- Forward launcher ----- + +template +void run_attn_fwd_impl(int b, + int head_num, + const T* Q, + const T* K, + const T* V, + const T* dropout_mask, + float dropout_p, + float sqr_dk_scale, + T* O, + T* workspace, + const int* cu_seqlens_kv, + const int* cu_seqlens_kv_padded, + hipStream_t stream) +{ + constexpr int seq_q = Config::seq_q; + constexpr int max_seq_kv = Config::max_seq_kv; + constexpr int head_dim = Config::head_dim; + constexpr int warp_size = 64; + + int merge_bs = b * head_num; + float scale = sqr_dk_scale; + float dropout_scale = (dropout_p > 0.0f) ? (1.0f / (1.0f - dropout_p)) : 1.0f; + + constexpr int kernel1_threads = 64; + dim3 block(kernel1_threads); + dim3 grid((merge_bs + kernel1_threads - 1) / kernel1_threads); + compute_scores_kernel<<>>( + Q, K, workspace, scale, cu_seqlens_kv, cu_seqlens_kv_padded, b, head_num); + + constexpr int work_thread_num = + Config::step2_block_size / (seq_q * max_seq_kv) * (seq_q * max_seq_kv); + dim3 grid2((merge_bs * seq_q * max_seq_kv + work_thread_num - 1) / work_thread_num); + dim3 block2(Config::step2_block_size); + apply_mask_and_softmax_kernel<<>>( + workspace, dropout_mask, dropout_scale, cu_seqlens_kv, b, head_num); + + constexpr int kernel3_block_k = 8; + constexpr int kernel3_threads = 64; + constexpr int process_head_per_warp = warp_size / (head_dim / kernel3_block_k); + + dim3 block3(kernel3_threads); + dim3 grid3((merge_bs / process_head_per_warp + 2 - 1) / 2); + compute_output_kernel<<>>( + workspace, V, O, cu_seqlens_kv, cu_seqlens_kv_padded, b, head_num); +} + +// ----- Backward kernels (with runtime batch_size, head_num) ----- + +template +__global__ void compute_grad_v_kernel(const T* attn_weights, + const T* grad_O, + T* grad_V, + const int* cu_seqlens_kv, + const int* cu_seqlens_kv_padded, + int batch_size, + int head_num) +{ + constexpr int seq_q = 
Config::seq_q; + constexpr int max_seq_kv = Config::max_seq_kv; + constexpr int head_dim = Config::head_dim; + constexpr int block_k = BLOCK_K; + constexpr int dwordx4_load_elt = 16 / sizeof(T); + constexpr int warp_size = 64; + constexpr int process_head_per_warp = warp_size / (head_dim / block_k); + constexpr int tasks_per_block = TASKS_PER_BLOCK; + + int base_block_offset = blockIdx.x * process_head_per_warp * tasks_per_block; + int thread_id = threadIdx.x; + int thread_batch_offset = thread_id / (head_dim / block_k); + int thread_head_offset = thread_id % (head_dim / block_k) * block_k; + + uint4 load_dwordx4_tmp_var[block_k / dwordx4_load_elt]; + T attn[max_seq_kv]; + + for (int task = 0; task < tasks_per_block; task++) { + int block_batch_head_idx = base_block_offset + task * process_head_per_warp; + int cur_idx = block_batch_head_idx + thread_batch_offset; + + int batch_idx = cur_idx / (seq_q * head_num); + int seq_head_idx = cur_idx % (seq_q * head_num); + int seq_q_idx = seq_head_idx / head_num; + int head_idx = seq_head_idx % head_num; + + if (batch_idx >= batch_size) + continue; + + int seq_kv = cu_seqlens_kv[batch_idx + 1] - cu_seqlens_kv[batch_idx]; + +#pragma unroll + for (int i = 0; i < max_seq_kv; i++) + attn[i] = attn_weights[cur_idx * max_seq_kv + i]; + + for (int j = 0; j < seq_kv; j++) { + uint4 store_dwordx4_tmp_var[block_k / dwordx4_load_elt]; +#pragma unroll + for (int i = 0; i < block_k / dwordx4_load_elt; i++) { + store_dwordx4_tmp_var[i].x = 0; + store_dwordx4_tmp_var[i].y = 0; + store_dwordx4_tmp_var[i].z = 0; + store_dwordx4_tmp_var[i].w = 0; + } + +#pragma unroll + for (int i = 0; i < block_k / dwordx4_load_elt; i++) { + load_dwordx4_tmp_var[i] = + *((uint4*)&grad_O[(batch_idx * seq_q * head_num + seq_q_idx * head_num + head_idx) * + head_dim + + thread_head_offset + i * dwordx4_load_elt]); + } + +#pragma unroll + for (int b = 0; b < block_k; b++) { + ((T*)&store_dwordx4_tmp_var[b / dwordx4_load_elt])[b % dwordx4_load_elt] += + attn[j] 
* + ((T*)&load_dwordx4_tmp_var[b / dwordx4_load_elt])[b % dwordx4_load_elt]; + } + +#pragma unroll + for (int i = 0; i < block_k / dwordx4_load_elt; i++) { + int grad_v_idx = (cu_seqlens_kv_padded[batch_idx] + j) * head_num * head_dim + + head_idx * head_dim + thread_head_offset + i * dwordx4_load_elt; + *((uint4*)&grad_V[grad_v_idx]) = store_dwordx4_tmp_var[i]; + } + } + } +} + +/* compute_grad_attn_kernel: for every flattened (batch, seq_q, head) row, accumulates in float the dot product over head_dim of that row of grad_O with each of the batch's seq_kv rows of V, and stores the seq_kv results in grad_attn; the remaining entries up to max_seq_kv are stored as zero. */ template +__global__ void compute_grad_attn_kernel(const T* grad_O, + const T* V, + T* grad_attn, + const int* cu_seqlens_kv, + const int* cu_seqlens_kv_padded, + int batch_size, + int head_num) +{ + constexpr int seq_q = Config::seq_q; + constexpr int max_seq_kv = Config::max_seq_kv; + constexpr int head_dim = Config::head_dim; + constexpr int block_k = 64; + constexpr int thread_block_size = 64; + constexpr int tasks_per_block = TASKS_PER_BLOCK; + + int base_block_offset = blockIdx.x * thread_block_size * tasks_per_block; + int thread_id = threadIdx.x; + + /* Each thread processes one row per task iteration; a block advances by thread_block_size rows per task. */ for (int task = 0; task < tasks_per_block; task++) { + int cur_batch_idx = base_block_offset + task * thread_block_size + thread_id; + int batch_idx = cur_batch_idx / (seq_q * head_num); + int seq_head_idx = cur_batch_idx % (seq_q * head_num); + int seq_idx = seq_head_idx / head_num; + int head_idx = seq_head_idx % head_num; + + /* Row indices past the last batch have nothing to compute. */ if (batch_idx >= batch_size) + continue; + + /* NOTE(review): seq_kv is read from cu_seqlens_kv and is not clamped to max_seq_kv, although it bounds the writes to results[max_seq_kv] and grad_attn_ptr below - confirm the caller guarantees seq_kv <= max_seq_kv. */ int seq_kv = cu_seqlens_kv[batch_idx + 1] - cu_seqlens_kv[batch_idx]; + + float results[max_seq_kv]; + T fetch_grad_O[block_k]; + T fetch_V[block_k]; + + T* grad_O_ptr = (T*)&grad_O[(batch_idx * seq_q * head_num + seq_idx * head_num + head_idx) * + head_dim]; + + /* V rows of this batch start at the padded offset and are head_num * head_dim elements apart. */ const T* V_base = + &V[cu_seqlens_kv_padded[batch_idx] * head_num * head_dim + head_idx * head_dim]; + int V_stride = head_num * head_dim; + + T* grad_attn_ptr = (T*)&grad_attn[cur_batch_idx * max_seq_kv]; + + uint4 ls_dwordx4_tmp_var; + + for (int i = 0; i < seq_kv; i++) + results[i] = 0.0f; + + /* NOTE(review): head_dim is walked in steps of block_k (64) with no remainder handling - presumably head_dim is a multiple of 64; confirm. */ for (int dim_offset = 0; dim_offset < head_dim; dim_offset += block_k) { + /* This branch unpacks each 16-byte uint4 load as 8 elements of T; the else branch unpacks it as 4. */ if constexpr (std::is_same::value ||
std::is_same::value) { + for (int k = 0; k < block_k / 8; k++) { + ls_dwordx4_tmp_var = *((uint4*)&grad_O_ptr[dim_offset + k * 8]); + fetch_grad_O[k * 8 + 0] = ((T*)&ls_dwordx4_tmp_var.x)[0]; + fetch_grad_O[k * 8 + 1] = ((T*)&ls_dwordx4_tmp_var.x)[1]; + fetch_grad_O[k * 8 + 2] = ((T*)&ls_dwordx4_tmp_var.y)[0]; + fetch_grad_O[k * 8 + 3] = ((T*)&ls_dwordx4_tmp_var.y)[1]; + fetch_grad_O[k * 8 + 4] = ((T*)&ls_dwordx4_tmp_var.z)[0]; + fetch_grad_O[k * 8 + 5] = ((T*)&ls_dwordx4_tmp_var.z)[1]; + fetch_grad_O[k * 8 + 6] = ((T*)&ls_dwordx4_tmp_var.w)[0]; + fetch_grad_O[k * 8 + 7] = ((T*)&ls_dwordx4_tmp_var.w)[1]; + } + /* The grad_O chunk loaded above is reused against every KV row. */ for (int kv_idx = 0; kv_idx < seq_kv; kv_idx++) { + for (int k = 0; k < block_k / 8; k++) { + ls_dwordx4_tmp_var = + *((uint4*)&V_base[kv_idx * V_stride + dim_offset + k * 8]); + fetch_V[k * 8 + 0] = ((T*)&ls_dwordx4_tmp_var.x)[0]; + fetch_V[k * 8 + 1] = ((T*)&ls_dwordx4_tmp_var.x)[1]; + fetch_V[k * 8 + 2] = ((T*)&ls_dwordx4_tmp_var.y)[0]; + fetch_V[k * 8 + 3] = ((T*)&ls_dwordx4_tmp_var.y)[1]; + fetch_V[k * 8 + 4] = ((T*)&ls_dwordx4_tmp_var.z)[0]; + fetch_V[k * 8 + 5] = ((T*)&ls_dwordx4_tmp_var.z)[1]; + fetch_V[k * 8 + 6] = ((T*)&ls_dwordx4_tmp_var.w)[0]; + fetch_V[k * 8 + 7] = ((T*)&ls_dwordx4_tmp_var.w)[1]; + } +#pragma unroll + for (int k = 0; k < block_k; k++) + results[kv_idx] += + static_cast(fetch_grad_O[k]) * static_cast(fetch_V[k]); + } + } else { + /* Same computation with 4 elements of T per uint4 load. */ for (int k = 0; k < block_k / 4; k++) { + ls_dwordx4_tmp_var = *((uint4*)&grad_O_ptr[dim_offset + k * 4]); + fetch_grad_O[k * 4 + 0] = *((T*)&ls_dwordx4_tmp_var.x); + fetch_grad_O[k * 4 + 1] = *((T*)&ls_dwordx4_tmp_var.y); + fetch_grad_O[k * 4 + 2] = *((T*)&ls_dwordx4_tmp_var.z); + fetch_grad_O[k * 4 + 3] = *((T*)&ls_dwordx4_tmp_var.w); + } + for (int kv_idx = 0; kv_idx < seq_kv; kv_idx++) { + for (int k = 0; k < block_k / 4; k++) { + ls_dwordx4_tmp_var = + *((uint4*)&V_base[kv_idx * V_stride + dim_offset + k * 4]); + fetch_V[k * 4 + 0] = *((T*)&ls_dwordx4_tmp_var.x); + fetch_V[k * 4 + 1] =
*((T*)&ls_dwordx4_tmp_var.y); + fetch_V[k * 4 + 2] = *((T*)&ls_dwordx4_tmp_var.z); + fetch_V[k * 4 + 3] = *((T*)&ls_dwordx4_tmp_var.w); + } +#pragma unroll + for (int k = 0; k < block_k; k++) + results[kv_idx] += fetch_grad_O[k] * fetch_V[k]; + } + } + } + for (int i = 0; i < seq_kv; i++) + grad_attn_ptr[i] = T(results[i]); + for (int i = seq_kv; i < max_seq_kv; i++) + grad_attn_ptr[i] = T(0.0f); + } +} + +/* softmax_backward_kernel: converts grad_attn in place from the gradient w.r.t. the attention weights to the gradient w.r.t. the scores. Per element: g = grad_attn * attn_weights (grad_attn is first multiplied by dropout_mask * dropout_scale when Config::enable_dropout_mask is set), then g -= attn_weights * (sum of g over the row of max_seq_kv entries), then masked or out-of-length key positions are set to zero. Each block covers valid_thread_range consecutive elements, the largest multiple of seq_q * max_seq_kv that fits in block_size. */ template +__global__ void softmax_backward_kernel(const T* attn_weights, + const T* dropout_mask, + T* grad_attn, + float dropout_scale, + const int* cu_seqlens_kv, + int batch_size, + int head_num) +{ + const uint32_t block_id = blockIdx.x; + const uint32_t thread_id = threadIdx.x; + constexpr int seq_q = Config::seq_q; + constexpr int max_seq_kv = Config::max_seq_kv; + constexpr int block_size = Config::step2_block_size; + constexpr int per_grad_attn_size = seq_q * max_seq_kv; + constexpr int valid_thread_range = block_size / per_grad_attn_size * per_grad_attn_size; + const uint32_t cur_block_offset = block_id * valid_thread_range + thread_id; + const uint32_t total_elt = static_cast(batch_size) * head_num * seq_q * max_seq_kv; + /* Blocks flagged as tail derive their row count from the elements that remain. */ bool is_tail = block_id * valid_thread_range + block_size >= total_elt; + int real_row_num = + is_tail ? (total_elt - block_id * valid_thread_range) / max_seq_kv + : valid_thread_range / max_seq_kv; + + /* NOTE(review): both __syncthreads() calls below sit inside this thread-dependent branch, so threads that fail the test never reach the barrier; a barrier in divergent code is undefined by the HIP/CUDA programming model - consider hoisting the guard so every thread of the block synchronises. */ if (cur_block_offset < total_elt && thread_id < valid_thread_range) { + __shared__ T tmp_grad_score[valid_thread_range]; + constexpr int row_num = valid_thread_range / max_seq_kv; + __shared__ T reduce_grad_score[row_num]; + + int global_row_idx = cur_block_offset / max_seq_kv; + int batch_idx = global_row_idx / (seq_q * head_num); + int k_idx = cur_block_offset % max_seq_kv; + + int seq_kv = (batch_idx < batch_size) + ?
(cu_seqlens_kv[batch_idx + 1] - cu_seqlens_kv[batch_idx]) + : max_seq_kv; + + T grad_attn_value = grad_attn[cur_block_offset]; + if constexpr (Config::enable_dropout_mask) + grad_attn_value = grad_attn_value * dropout_mask[cur_block_offset] * dropout_scale; + T attn_weight = attn_weights[cur_block_offset]; + T grad_score = grad_attn_value * attn_weight; + tmp_grad_score[thread_id] = grad_score; + __syncthreads(); + + /* The first real_row_num threads each sum one row of products from shared memory. NOTE(review): the sum is accumulated in T rather than float - check precision for 16-bit T. */ if (thread_id < real_row_num) { + T sum = T(0.0f); +#pragma unroll + for (int i = 0; i < max_seq_kv; i++) + sum += tmp_grad_score[thread_id * max_seq_kv + i]; + reduce_grad_score[thread_id] = sum; + } + __syncthreads(); + + grad_score -= attn_weight * reduce_grad_score[thread_id / max_seq_kv]; + + /* NOTE(review): q_idx below equals (row index % seq_q), i.e. it treats seq_q as the fastest-varying row dimension, whereas compute_grad_attn_kernel flattens rows as (batch, seq_q, head); the two agree when seq_q == 1 - confirm for other values. */ if constexpr (Config::mask_type == CausalMaskType::TOP_LEFT) { + int q_idx = (cur_block_offset % (seq_q * max_seq_kv)) / max_seq_kv; + if (k_idx > q_idx || k_idx >= seq_kv) + grad_score = T(0.0f); + } else if constexpr (Config::mask_type == CausalMaskType::BOTTOM_RIGHT) { + int q_idx = (cur_block_offset % (seq_q * max_seq_kv)) / max_seq_kv; + if (k_idx < q_idx || k_idx >= seq_kv) + grad_score = T(0.0f); + } else { + if (k_idx >= seq_kv) + grad_score = T(0.0f); + } + grad_attn[cur_block_offset] = grad_score; + } +} + +/* compute_grad_qk_kernel: for every flattened (batch, seq_q, head) row, writes grad_Q = scale * (sum over j below seq_kv of grad_scores[j] * K[j]) and, for each j below seq_kv, grad_K[j] = scale * grad_scores[j] * Q. head_dim / block_k threads share a row, each handling block_k consecutive head_dim elements. */ template +__global__ void compute_grad_qk_kernel(const T* grad_scores, + const T* Q, + const T* K, + T* grad_Q, + T* grad_K, + float scale, + const int* cu_seqlens_kv, + const int* cu_seqlens_kv_padded, + int batch_size, + int head_num) +{ + constexpr int seq_q = Config::seq_q; + constexpr int max_seq_kv = Config::max_seq_kv; + constexpr int head_dim = Config::head_dim; + constexpr int block_k = BLOCK_K; + constexpr int dwordx4_load_elt = 16 / sizeof(T); + constexpr int warp_size = 64; + constexpr int process_head_per_warp = warp_size / (head_dim / block_k); + constexpr int tasks_per_block = TASKS_PER_BLOCK; + + int base_block_offset = blockIdx.x * process_head_per_warp * tasks_per_block; + int thread_id = threadIdx.x; + int thread_batch_offset =
thread_id / (head_dim / block_k); + int thread_head_offset = thread_id % (head_dim / block_k) * block_k; + + uint4 load_dwordx4_tmp_var[block_k / dwordx4_load_elt]; + T grad_score_vals[max_seq_kv]; + + /* Each task iteration moves the block forward by process_head_per_warp rows; thread_batch_offset selects this thread's row among them. */ for (int task = 0; task < tasks_per_block; task++) { + int block_batch_head_idx = base_block_offset + task * process_head_per_warp; + int cur_idx = block_batch_head_idx + thread_batch_offset; + + int batch_idx = cur_idx / (seq_q * head_num); + int seq_head_idx = cur_idx % (seq_q * head_num); + int seq_q_idx = seq_head_idx / head_num; + int head_idx = seq_head_idx % head_num; + + if (batch_idx >= batch_size) + continue; + + /* NOTE(review): seq_kv is not clamped to max_seq_kv, the length of grad_score_vals, which is indexed by j below seq_kv further down - confirm the caller guarantees seq_kv <= max_seq_kv. */ int seq_kv = cu_seqlens_kv[batch_idx + 1] - cu_seqlens_kv[batch_idx]; + +#pragma unroll + for (int i = 0; i < max_seq_kv; i++) + grad_score_vals[i] = grad_scores[cur_idx * max_seq_kv + i]; + + uint4 store_dwordx4_tmp_var[block_k / dwordx4_load_elt]; +#pragma unroll + for (int i = 0; i < block_k / dwordx4_load_elt; i++) { + store_dwordx4_tmp_var[i].x = 0; + store_dwordx4_tmp_var[i].y = 0; + store_dwordx4_tmp_var[i].z = 0; + store_dwordx4_tmp_var[i].w = 0; + } + /* grad_Q slice: accumulate grad_score_vals[j] * K[j] over the KV positions, in type T, inside the zero-initialised uint4 staging array. */ for (int j = 0; j < seq_kv; j++) { +#pragma unroll + for (int i = 0; i < block_k / dwordx4_load_elt; i++) { + int k_idx = (cu_seqlens_kv_padded[batch_idx] + j) * head_num * head_dim + + head_idx * head_dim + thread_head_offset + i * dwordx4_load_elt; + load_dwordx4_tmp_var[i] = *((uint4*)&K[k_idx]); + } +#pragma unroll + for (int b = 0; b < block_k; b++) { + ((T*)&store_dwordx4_tmp_var[b / dwordx4_load_elt])[b % dwordx4_load_elt] += + grad_score_vals[j] * + ((T*)&load_dwordx4_tmp_var[b / dwordx4_load_elt])[b % dwordx4_load_elt]; + } + } +#pragma unroll + for (int i = 0; i < block_k / dwordx4_load_elt; i++) { + T* grad_Q_ptr = &grad_Q[(batch_idx * seq_q * head_num + seq_q_idx * head_num + head_idx) * + head_dim + + thread_head_offset + i * dwordx4_load_elt]; + for (int b = 0; b < dwordx4_load_elt; b++) + grad_Q_ptr[b] = ((T*)&store_dwordx4_tmp_var[i])[b] * T(scale); + } +#pragma unroll + for
(int i = 0; i < block_k / dwordx4_load_elt; i++) { + load_dwordx4_tmp_var[i] = + *((uint4*)&Q[(batch_idx * seq_q * head_num + seq_q_idx * head_num + head_idx) * head_dim + + thread_head_offset + i * dwordx4_load_elt]); + } + /* grad_K slice. NOTE(review): grad_K is stored with '=', not accumulated, so rows that share batch and head but differ in seq_q_idx write the same addresses - presumably relies on seq_q == 1; confirm. */ for (int j = 0; j < seq_kv; j++) { +#pragma unroll + for (int b = 0; b < block_k; b++) { + T val = grad_score_vals[j] * + ((T*)&load_dwordx4_tmp_var[b / dwordx4_load_elt])[b % dwordx4_load_elt] * + T(scale); + int grad_k_idx = (cu_seqlens_kv_padded[batch_idx] + j) * head_num * head_dim + + head_idx * head_dim + thread_head_offset + b; + grad_K[grad_k_idx] = val; + } + } + } +} + +/* run_attn_bwd_impl: launches the backward kernels in this order - compute_grad_v_kernel, compute_grad_attn_kernel (output in workspace), softmax_backward_kernel (in place on workspace), compute_grad_qk_kernel (reads workspace). sqr_dk_scale is handed to the last kernel as scale; dropout_scale is 1 / (1 - dropout_p) when dropout_p is positive, otherwise 1. */ template +void run_attn_bwd_impl(int b, + int head_num, + const T* Q, + const T* K, + const T* V, + const T* grad_O, + const T* attn_weights, + const T* dropout_mask, + float dropout_p, + float sqr_dk_scale, + T* grad_Q, + T* grad_K, + T* grad_V, + T* workspace, + const int* cu_seqlens_kv, + const int* cu_seqlens_kv_padded, + hipStream_t stream) +{ + constexpr int seq_q = Config::seq_q; + constexpr int max_seq_kv = Config::max_seq_kv; + constexpr int head_dim = Config::head_dim; + constexpr int warp_size = 64; + + int merge_bs = b * head_num; + float scale = sqr_dk_scale; + float dropout_scale = (dropout_p > 0.0f) ?
(1.0f / (1.0f - dropout_p)) : 1.0f; + + dim3 block(warp_size); + /* NOTE(review): grid_v and grid_qk divide the row count by tasks_per_block only, while the kernels' base_block_offset advances by process_head_per_warp * tasks_per_block rows per block; if the kernels' TASKS_PER_BLOCK equals these constants, blocks beyond the data are launched and exit on the batch_idx check - confirm intended. */ constexpr int tasks_per_block_v = 16; + dim3 grid_v((b * seq_q * head_num + tasks_per_block_v - 1) / tasks_per_block_v); + compute_grad_v_kernel<<>>( + attn_weights, grad_O, grad_V, cu_seqlens_kv, cu_seqlens_kv_padded, b, head_num); + + /* NOTE(review): compute_grad_attn_kernel advances thread_block_size (64) * TASKS_PER_BLOCK rows per block, whereas this grid divides by tasks_per_block_attn * process_head_per_warp; the two match only when head_dim / 64 == 1, and head_dim / 64 is zero for head_dim below 64 - confirm the supported head_dim values. */ constexpr int tasks_per_block_attn = 16; + constexpr int process_head_per_warp = warp_size / (head_dim / 64); + dim3 grid_grad_attn((b * seq_q * head_num + tasks_per_block_attn * process_head_per_warp - 1) / + (tasks_per_block_attn * process_head_per_warp)); + compute_grad_attn_kernel<<>>( + grad_O, V, workspace, cu_seqlens_kv, cu_seqlens_kv_padded, b, head_num); + + /* work_thread_num mirrors the kernel's valid_thread_range: the largest multiple of seq_q * max_seq_kv not exceeding step2_block_size. */ constexpr int work_thread_num = + Config::step2_block_size / (seq_q * max_seq_kv) * (seq_q * max_seq_kv); + dim3 grid_softmax((merge_bs * seq_q * max_seq_kv + work_thread_num - 1) / work_thread_num); + dim3 block_softmax(Config::step2_block_size); + softmax_backward_kernel<<>>( + attn_weights, dropout_mask, workspace, dropout_scale, cu_seqlens_kv, b, head_num); + + constexpr int tasks_per_block_qk = 4; + dim3 grid_qk((b * seq_q * head_num + tasks_per_block_qk - 1) / tasks_per_block_qk); + compute_grad_qk_kernel<<>>( + workspace, Q, K, grad_Q, grad_K, scale, cu_seqlens_kv, cu_seqlens_kv_padded, b, head_num); +} + +/* Returns b * h_q * min(max_seqlen_kv, 16) * 2 bytes. The dtype argument is not consulted; the element size is fixed at 2. */ size_t fused_attn_smallseq_bwd_workspace_size(size_t b, + size_t h_q, + size_t max_seqlen_kv, + DType dtype) { + constexpr size_t elt_size = 2u; // BF16 and FP16 are 2 bytes + return b * h_q * 1 * std::min(max_seqlen_kv, size_t(16)) * elt_size; +} + +/* dispatch_fwd: forwards its arguments unchanged to run_attn_fwd_impl. */ template +static void dispatch_fwd(int b, int h_q, const T* Q, const T* K, const T* V, const T* dropout_mask, + float dropout, float scale, T* O, T* workspace, const int* cu_kv, + const int* cu_kv_p, hipStream_t stream) { + run_attn_fwd_impl>( + b, h_q, Q, K, V, dropout_mask, dropout, scale, O, workspace, cu_kv, cu_kv_p, stream); +} + +/* dispatch_bwd: forwards its arguments unchanged to run_attn_bwd_impl. */ template +static void dispatch_bwd(int b, int h_q, const T* Q, const T* K, const T* V, const T* grad_O, + const T*
attn_weights, const T* dropout_mask, float dropout, float scale, + T* grad_Q, T* grad_K, T* grad_V, T* workspace, const int* cu_kv, + const int* cu_kv_p, hipStream_t stream) { + run_attn_bwd_impl>( + b, h_q, Q, K, V, grad_O, attn_weights, dropout_mask, dropout, scale, + grad_Q, grad_K, grad_V, workspace, cu_kv, cu_kv_p, stream); +} + +/* fused_attn_smallseq_fwd: when NVTE_LOG_CK_CONFIG is "1", prints the call configuration to std::cout; then casts the untyped device pointers for the element type T chosen by TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT and dispatches on max_seqlen_kv, which must be in [2, 16] (any other value raises through NVTE_ERROR). The dropout mask pointer is always nullptr here. NOTE(review): h_kv, d_qk and d_v are only logged in the visible code - confirm that h_kv == h_q and the head dimension are validated by the caller or by the dispatch macro. */ void fused_attn_smallseq_fwd(size_t b, + size_t h_q, + size_t h_kv, + size_t max_seqlen_kv, + size_t d_qk, + size_t d_v, + bool is_training, + float attn_scale, + float dropout, + const void* devPtrQ, + const void* devPtrK, + const void* devPtrV, + void* devPtrO, + void* attn_weights_buffer, + const void* devPtrCuSeqlensKV, + const void* devPtrSeqOffsetsKV, + const void* rng_seed, + const void* rng_offset, + DType qkv_dtype, + void* workspace, + size_t* workspace_size, + cudaStream_t stream) +{ + const char* nvte_smallseq = std::getenv("NVTE_LOG_CK_CONFIG"); + if (nvte_smallseq && std::string(nvte_smallseq) == "1") { + std::cout << std::endl << "attn_fwd(small-seq kernel): "; + std::cout << "b: " << b << ", "; + std::cout << "h_q: " << h_q << ", "; + std::cout << "h_kv: " << h_kv << ", "; + std::cout << "max_seqlen_kv: " << max_seqlen_kv << ", "; + std::cout << "d_qk: " << d_qk << ", "; + std::cout << "d_v: " << d_v << ", "; + std::cout << "is_training: " << is_training << ", "; + std::cout << "attn_scale: " << attn_scale << ", "; + std::cout << "dropout: " << dropout << ", "; + std::cout << "qkv_dtype: " + << (qkv_dtype == DType::kBFloat16 ? "BF16" : qkv_dtype == DType::kFloat16 ?
"FP16" : "?") + << std::endl; + } + + float sqr_dk_scale = attn_scale; + + TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(qkv_dtype, T, + const T* Q_ptr = static_cast(devPtrQ); + const T* K_ptr = static_cast(devPtrK); + const T* V_ptr = static_cast(devPtrV); + T* O_ptr = static_cast(devPtrO); + T* attn_workspace = static_cast(attn_weights_buffer); + const int* cu_kv = static_cast(devPtrCuSeqlensKV); + const int* cu_kv_p = static_cast(devPtrSeqOffsetsKV); + const T* dropout_mask = nullptr; + int bi = static_cast(b); + int hi = static_cast(h_q); + + switch (max_seqlen_kv) { + SMALLSEQ_DISPATCH_FWD_CASE(2) + SMALLSEQ_DISPATCH_FWD_CASE(3) + SMALLSEQ_DISPATCH_FWD_CASE(4) + SMALLSEQ_DISPATCH_FWD_CASE(5) + SMALLSEQ_DISPATCH_FWD_CASE(6) + SMALLSEQ_DISPATCH_FWD_CASE(7) + SMALLSEQ_DISPATCH_FWD_CASE(8) + SMALLSEQ_DISPATCH_FWD_CASE(9) + SMALLSEQ_DISPATCH_FWD_CASE(10) + SMALLSEQ_DISPATCH_FWD_CASE(11) + SMALLSEQ_DISPATCH_FWD_CASE(12) + SMALLSEQ_DISPATCH_FWD_CASE(13) + SMALLSEQ_DISPATCH_FWD_CASE(14) + SMALLSEQ_DISPATCH_FWD_CASE(15) + SMALLSEQ_DISPATCH_FWD_CASE(16) + default: + NVTE_ERROR("Unsupported max_seqlen_kv for small-seq: expected 2 <= max_seqlen_kv <= 16."); + } + ); + +} + +/* fused_attn_smallseq_bwd: when NVTE_LOG_CK_CONFIG is "1", prints the call configuration to std::cout; then casts the untyped device pointers for the element type T chosen by TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT and dispatches on max_seqlen_kv, which must be in [2, 16] (any other value raises through NVTE_ERROR). The dropout mask pointer is always nullptr here. NOTE(review): h_kv, d_qk and d_v are only logged in the visible code - confirm they are validated by the caller or by the dispatch macro. */ void fused_attn_smallseq_bwd(size_t b, + size_t h_q, + size_t h_kv, + size_t max_seqlen_kv, + size_t d_qk, + size_t d_v, + float attn_scale, + float dropout, + const void* devPtrQ, + const void* devPtrK, + const void* devPtrV, + const void* devPtrO, + const void* devPtrdO, + const void* attn_weights, + void* devPtrdQ, + void* devPtrdK, + void* devPtrdV, + const void* devPtrCuSeqlensKV, + const void* devPtrSeqOffsetsKV, + DType qkv_dtype, + void* workspace, + size_t* workspace_size, + cudaStream_t stream) +{ + const char* nvte_smallseq = std::getenv("NVTE_LOG_CK_CONFIG"); + if (nvte_smallseq && std::string(nvte_smallseq) == "1") { + std::cout << std::endl << "attn_bwd(ck small-seq kernel): "; + std::cout << "b: " << b << ", "; + std::cout << "h_q: " << h_q << ", "; + std::cout << "h_kv: " << h_kv << ", "; + std::cout <<
"max_seqlen_kv: " << max_seqlen_kv << ", "; + std::cout << "d_qk: " << d_qk << ", "; + std::cout << "d_v: " << d_v << ", "; + std::cout << "attn_scale: " << attn_scale << ", "; + std::cout << "dropout: " << dropout << ", "; + std::cout << "qkv_dtype: " + << (qkv_dtype == DType::kBFloat16 ? "BF16" : qkv_dtype == DType::kFloat16 ? "FP16" : "?") + << std::endl; + } + + float sqr_dk_scale = attn_scale; + + TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(qkv_dtype, T, + const T* Q_ptr = static_cast(devPtrQ); + const T* K_ptr = static_cast(devPtrK); + const T* V_ptr = static_cast(devPtrV); + const T* O_ptr = static_cast(devPtrO); + const T* dO_ptr = static_cast(devPtrdO); + const T* attn_ptr = static_cast(attn_weights); + T* dQ_ptr = static_cast(devPtrdQ); + T* dK_ptr = static_cast(devPtrdK); + T* dV_ptr = static_cast(devPtrdV); + T* workspace_ptr = static_cast(workspace); + const int* cu_kv = static_cast(devPtrCuSeqlensKV); + const int* cu_kv_p = static_cast(devPtrSeqOffsetsKV); + const T* dropout_mask = nullptr; + int bi = static_cast(b); + int hi = static_cast(h_q); + + switch (max_seqlen_kv) { + SMALLSEQ_DISPATCH_BWD_CASE(2) + SMALLSEQ_DISPATCH_BWD_CASE(3) + SMALLSEQ_DISPATCH_BWD_CASE(4) + SMALLSEQ_DISPATCH_BWD_CASE(5) + SMALLSEQ_DISPATCH_BWD_CASE(6) + SMALLSEQ_DISPATCH_BWD_CASE(7) + SMALLSEQ_DISPATCH_BWD_CASE(8) + SMALLSEQ_DISPATCH_BWD_CASE(9) + SMALLSEQ_DISPATCH_BWD_CASE(10) + SMALLSEQ_DISPATCH_BWD_CASE(11) + SMALLSEQ_DISPATCH_BWD_CASE(12) + SMALLSEQ_DISPATCH_BWD_CASE(13) + SMALLSEQ_DISPATCH_BWD_CASE(14) + SMALLSEQ_DISPATCH_BWD_CASE(15) + SMALLSEQ_DISPATCH_BWD_CASE(16) + default: + NVTE_ERROR("Unsupported max_seqlen_kv for small-seq: expected 2 <= max_seqlen_kv <= 16."); + } + ); +} + +} // namespace fused_attn_rocm +} // namespace transformer_engine diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.h b/transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.h new file mode 100644 index 000000000..818b5448a --- /dev/null +++
b/transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.h @@ -0,0 +1,81 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +/*! \file fused_attn_smallseq.h + * \brief Small-seq (varlen) attention for ROCm: seq_q=1, 2<=max_seqlen_kv<=16, THD only. + */ + +#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_ROCM_FUSED_ATTN_SMALLSEQ_H_ +#define TRANSFORMER_ENGINE_FUSED_ATTN_ROCM_FUSED_ATTN_SMALLSEQ_H_ + +#include + +namespace transformer_engine { +namespace fused_attn_rocm { + +/** Workspace size in bytes for small-seq backward path: b * h_q * min(max_seqlen_kv, 16) * 2. The dtype argument is currently not consulted. */ +size_t fused_attn_smallseq_bwd_workspace_size(size_t b, + size_t h_q, + size_t max_seqlen_kv, + DType dtype); + +/** Forward: Q,K,V -> O; attention weights written to attn_weights_buffer (same as output_S). + * attn_weights_buffer is also used as internal workspace (scores then overwritten by attn + * weights). No separate workspace required for the launcher; caller may use workspace for + * get_runtime_max_seqlen (8 bytes). */ +void fused_attn_smallseq_fwd(size_t b, + size_t h_q, + size_t h_kv, + size_t max_seqlen_kv, + size_t d_qk, + size_t d_v, + bool is_training, + float attn_scale, + float dropout, + const void* devPtrQ, + const void* devPtrK, + const void* devPtrV, + void* devPtrO, + void* attn_weights_buffer, + const void* devPtrCuSeqlensKV, + const void* devPtrSeqOffsetsKV, + const void* rng_seed, + const void* rng_offset, + DType qkv_dtype, + void* workspace, + size_t* workspace_size, + cudaStream_t stream); + +/** Backward: dO, O, attn_weights -> dQ, dK, dV. attn_weights is the buffer from forward + * (output_S). workspace must be at least fused_attn_smallseq_bwd_workspace_size.
*/ +void fused_attn_smallseq_bwd(size_t b, + size_t h_q, + size_t h_kv, + size_t max_seqlen_kv, + size_t d_qk, + size_t d_v, + float attn_scale, + float dropout, + const void* devPtrQ, + const void* devPtrK, + const void* devPtrV, + const void* devPtrO, + const void* devPtrdO, + const void* attn_weights, + void* devPtrdQ, + void* devPtrdK, + void* devPtrdV, + const void* devPtrCuSeqlensKV, + const void* devPtrSeqOffsetsKV, + DType qkv_dtype, + void* workspace, + size_t* workspace_size, + cudaStream_t stream); + +} // namespace fused_attn_rocm +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_FUSED_ATTN_ROCM_FUSED_ATTN_SMALLSEQ_H_ diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 45d3d8b59..6b9b0a30a 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -366,12 +366,39 @@ def abstract( softmax_dtype = dtypes.canonicalize_dtype(jnp.float32) elif backend == NVTE_Fused_Attn_Backend.NVTE_CK: if config.qkv_layout.is_thd(): - softmax_shape = (*batch_shape, q_max_seqlen, attn_heads, 1) + # THD only: check env; run small-seq logic only when enabled + if os.environ.get("NVTE_FUSED_ATTN_CK_SMALLSEQ", "0") != "1": + softmax_shape = (*batch_shape, q_max_seqlen, attn_heads, 1) + softmax_dtype = dtypes.canonicalize_dtype(jnp.float32) + else: + batch_size = reduce(operator.mul, batch_shape) + ck_standard_softmax_aux_size = ( + batch_size * attn_heads * q_max_seqlen * 1 + ) # NOTE(review): this is an element count, while the small-seq size below is multiplied by 2 bytes, so for a non-empty batch with kv_max_seqlen >= 1 the >= test below cannot be true - confirm the intended units + ck_smallseq_softmax_aux_size = ( + batch_size * attn_heads * q_max_seqlen + * min(kv_max_seqlen, 16) * 2 + ) # 2 bytes for bf16/fp16 + if ck_standard_softmax_aux_size >= ck_smallseq_softmax_aux_size: + softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, 1) + softmax_dtype = dtypes.canonicalize_dtype(q_dtype) + else: + softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, min(kv_max_seqlen, 16)) + softmax_dtype =
dtypes.canonicalize_dtype(q_dtype) else: softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, 1) - softmax_dtype = dtypes.canonicalize_dtype(jnp.float32) + softmax_dtype = dtypes.canonicalize_dtype(jnp.float32) else: raise ValueError(f"Unsupported {backend=}") + + if os.environ.get("NVTE_LOG_CK_CONFIG", "0") == "1": # NOTE(review): jax.debug.print is called here with static Python values during abstract evaluation; a plain print or logging call may be what is intended - confirm + jax.debug.print( + "attn_fwd(ck small-seq JAX abstract): batch_shape: {}, softmax_shape: {}, softmax_dtype: {}", + batch_shape, + softmax_shape, + softmax_dtype, + ) + softmax_aux_aval = q_aval.update(shape=softmax_shape, dtype=softmax_dtype) # JAX does not enable 64-bit int by default so we get XLA to allocate x8 memory with diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 342953746..55f5575ed 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -509,6 +509,33 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( nvte_tensor_pack_destroy(&aux_input_tensors); auto work_shape = MakeShapeVector(query_workspace_tensor.shape()); + + /* When NVTE_FUSED_ATTN_CK_SMALLSEQ is "1" and the layout is ragged, the reported workspace is grown to at least input_batch * attn_heads * 16 * 2 bytes; a workspace that is already large enough is left unchanged. */ const char* nvte_smallseq = std::getenv("NVTE_FUSED_ATTN_CK_SMALLSEQ"); + if (is_ragged && nvte_smallseq && std::string(nvte_smallseq) == "1") { + size_t workspace_elems = product(work_shape); + size_t elt_size = transformer_engine::typeToSize(query_workspace_tensor.dtype()); + size_t workspace_bytes = workspace_elems * elt_size; + size_t unfused_small_seq_workspace = input_batch * attn_heads * 16 * 2; // min for unfused small-seq (bf16/fp16); NOTE(review): hard-codes the same 16 * 2 sizing as fused_attn_smallseq_bwd_workspace_size - keep the two in sync + + if (workspace_bytes < unfused_small_seq_workspace) { + /* Round the byte requirement up to whole workspace elements and report a flat 1-D shape. */ size_t min_elems = (unfused_small_seq_workspace + elt_size - 1) / elt_size; + work_shape = std::vector{min_elems}; + workspace_elems = min_elems; + workspace_bytes = workspace_elems * elt_size; + } + + const char* nvte_log_ck_config = std::getenv("NVTE_LOG_CK_CONFIG"); + if (nvte_log_ck_config && std::string(nvte_log_ck_config) == "1") { + std::cout << std::endl <<
"attn_bwd(ck unfused small-seq workspace size): "; + std::cout << "input_batch: " << input_batch << ", "; + std::cout << "is_ragged: " << is_ragged << ", "; + std::cout << "workspace_elems: " << workspace_elems << ", "; + std::cout << "workspace_bytes: " << workspace_bytes << ", "; + std::cout << "unfused_small_seq_min_bytes: " << unfused_small_seq_workspace << ", "; + std::cout << "workspace_bytes >= unfused_small_seq_workspace: " + << (workspace_bytes >= unfused_small_seq_workspace ? "true" : "false") << std::endl; + } + } return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype()); }