[NO MERGE] Integrate CK varlen cross attention for small-seq (s_q=1, s_kv<=16) #461
VeeraRajasekhar wants to merge 9 commits into dev
Conversation
Integrate the CK team's unfused variable-length attention HIP kernels from
varlen_attn/ into Transformer Engine's ROCm fused-attn path as a specialized
path for cross-attention (Q length 1, KV length 2-16, large batch).
- Add fused_attn_smallseq.hpp and fused_attn_smallseq.cpp under
fused_attn_rocm/: declarations and implementation adapted from
varlen_attn/attn_fwd.cpp and attn_bwd.cpp (scores, mask+softmax, output;
grad_V, grad_attn, softmax bwd, grad_Q/grad_K). Runtime dispatch over
max_seqlen_kv in {2,4,6,8,12,16}, head_dim 128, BF16.
- Add fused_attn_smallseq.cpp to the ROCm fused-attn build in
transformer_engine/common/CMakeLists.txt.
- In fused_attn_ck_fwd: when THD and no bias, branch to small-seq path when
max_seqlen_q==1 and 2<=max_seqlen_kv<=16. On shape query (Aux_CTX_Tensors->size
== 0) skip get_runtime_max_seqlen (cu_seqlens pointers are null); use host
max_seqlen_kv and set output_S to attention-weights shape {max_tokens_q,
h_q, 1, runtime_max_seqlen_kv} and dtype QKV_type. On real run (size >= 2)
call get_runtime_max_seqlen then fused_attn_smallseq_fwd. Use sequence
count b_varlen = max_tokens_q (not segment count b) for get_runtime_max_seqlen,
output_S shape, workspace size, and small-seq fwd so varlen kernel indexing
matches Q and cu_seqlens_kv (THD may pass segment-level cu_seqlens; varlen
kernel expects sequence-level batch).
- In fused_attn_ck_bwd: same THD/small-seq condition. On workspace query
(workspace->data.dptr == nullptr) skip get_runtime_max_seqlen and use host
max_seqlen_kv; on real run call get_runtime_max_seqlen then
fused_attn_smallseq_bwd. Use b_varlen = max_tokens_q_bwd for
get_runtime_max_seqlen, workspace size, and small-seq bwd.
- Reuse the softmax LSE auxiliary buffer for attention weights in the small-seq
path (forward write, backward read).
- JAX attention.py: in NVTE_CK block, when THD and q_max_seqlen==1 and
kv_max_seqlen<=16 set softmax_shape = (*batch_shape, attn_heads,
q_max_seqlen, kv_max_seqlen) and softmax_dtype = q_dtype so Python aux
buffer matches C++ attention-weights convention.
- Add test_ck_unfused_smallseq_backend in tests/jax/test_fused_attn.py
(parametrized s_kv in {2,4,6,8,12,16}, b=30720, s_q=1, THD_THD_THD,
SeqDescFormat.Seqlens) and optional NVTE_LOG_CK_SMALLSEQ debug logging in
C++.
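For reference, the forward dispatch condition described in the bullets above can be sketched in Python. This is an illustrative model of the branch predicate, not the actual C++ code; the function and parameter names are invented for the sketch:

```python
def use_ck_smallseq(is_thd_layout, has_bias, max_seqlen_q, max_seqlen_kv):
    """Mirror of the small-seq branch condition described above:
    THD layout, no bias, s_q == 1 and 2 <= s_kv <= 16."""
    return (
        is_thd_layout
        and not has_bias
        and max_seqlen_q == 1
        and 2 <= max_seqlen_kv <= 16
    )

print(use_ck_smallseq(True, False, 1, 8))   # small-seq path taken
print(use_ck_smallseq(True, False, 1, 32))  # falls back to regular CK/AITER
```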
Let's make this PR work for the JAX extension first; later we can support PyTorch. One key difference between the JAX and PyTorch fused-attn dispatch is that PyTorch can calculate, request, and allocate softmax_aux and workspace at runtime with the actual cu_seqlen_q/kv data. In the JAX extension, however, the softmax_aux and workspace calculation is done in transformer_engine/jax/cpp_extensions/attention.py (lines 364 to 375 at b685686). General guideline:
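To make the constraint above concrete, here is a minimal Python sketch of picking the aux-buffer shape at trace time, when only host-side max_seqlen values are known. The helper name and signature are hypothetical, not the actual attention.py code:

```python
def softmax_aux_shape(batch_shape, attn_heads, q_max_seqlen, kv_max_seqlen,
                      smallseq):
    """Hypothetical sketch: JAX must commit to an aux-buffer shape at trace
    time from host max_seqlen values; cu_seqlens data exists only at runtime."""
    if smallseq:
        # attention-weights convention used by the small-seq path
        return (*batch_shape, attn_heads, q_max_seqlen, kv_max_seqlen)
    # standard softmax-LSE convention
    return (*batch_shape, attn_heads, q_max_seqlen, 1)

print(softmax_aux_shape((30720,), 16, 1, 8, smallseq=True))   # (30720, 16, 1, 8)
print(softmax_aux_shape((30720,), 16, 1, 8, smallseq=False))  # (30720, 16, 1, 1)
```

PyTorch avoids this problem entirely because it can defer the allocation until the real cu_seqlens are on hand.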
NVTE_CHECK(workspace != nullptr, "small-seq bwd requires workspace.");

float sqr_dk_scale = attn_scale;
hipStream_t hip_stream = reinterpret_cast<hipStream_t>(stream);
Probably no need for this cast; cudaStream_t will be hipified correctly to hipStream_t.
…port to small-seq kernels
- tests/jax: CK small-seq tests use a fixture to set/restore NVTE_FUSED_ATTN_CK_SMALLSEQ=1; parametrize dtype (BF16/FP16) and add sequence-packing cases (2048-2-4, 2-4096-8192); when the env var is set, num_segments_per_seq = max_seqlen_q for THD, else 2.
- JAX attention.py: THD softmax shape/dtype uses the small-seq path only when env=1, else the original layout.
- JAX attention.cpp: added env guard.
- fused_attn_smallseq: use TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT for fwd/bwd; add FP16 (__half) support; fix __half*float with T(scale).
ci/jax.sh
@@ -58,6 +58,7 @@ run_test_config() {
run_default_fa 1 test_custom_call_compute.py
run_default_fa 1 test_functions.py
run 1 test_fused_attn.py
Do we need to skip the small seq tests for regular ck/aiter flow?
Yes, skipped with both v2 and v3
tests/jax/test_fused_attn.py
)
self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets(self.segment_ids_kv)
else:
    if is_hip_extension() and os.environ.get("NVTE_FUSED_ATTN_CK_SMALLSEQ", "0") == "1":
Emm, maybe we can make a dedicated generate_random_segment_ids_small_seq to separate it from the original generate_random_segment_ids function
Updated with a new function call which takes care of generating inputs for this new flow
}

if (runtime_max_seqlen_q == 1 && runtime_max_seqlen_kv >= 2 && runtime_max_seqlen_kv <= 16) {
  fused_attn_rocm::fused_attn_smallseq_fwd(
If NVTE_LOG_CK_CONFIG=1, also log whether we are running the CK small-seq flow or the regular ck/aiter flow.
}

if (runtime_max_seqlen_q == 1 && runtime_max_seqlen_kv >= 2 && runtime_max_seqlen_kv <= 16) {
Same here for logging which bwd flow to run
const T* V_ptr = static_cast<const T*>(devPtrV);
T* O_ptr = static_cast<T*>(devPtrO);
T* attn_workspace = static_cast<T*>(attn_weights_buffer);
const int* cu_kv = static_cast<const int*>(devPtrCuSeqlensKV);
Will it run into issues if we don't pass in cu_seqlen_q/cu_seqlen_q_padded? For example, if there are several empty segments for q/kv, but for all non-empty ones s_q always equals 1?
size_t runtime_max_seqlen_q = static_cast<size_t>(ck_fused_attn::get_runtime_max_seqlen(
    static_cast<uint64_t>(b), devPtrCuSeqlensQ, nullptr, max_seqlen_workspace, stream));
size_t runtime_max_seqlen_kv = static_cast<size_t>(ck_fused_attn::get_runtime_max_seqlen(
Here we don't pass cu_seqlen_padded into the runtime max-seqlen check. What if the max_seqlen without padding satisfies the CK kernel condition, but with padding it does not? Can the CK kernel handle those corner cases?
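The question above hinges on what get_runtime_max_seqlen actually measures. A host-side Python analogue (illustrative only, not the CK implementation) makes the padded-vs-unpadded distinction concrete:

```python
def host_max_seqlen(cu_seqlens):
    """Max segment length from a prefix-sum cu_seqlens array,
    the host-side analogue of what get_runtime_max_seqlen computes."""
    return max(end - start for start, end in zip(cu_seqlens, cu_seqlens[1:]))

# Hypothetical example: unpadded lengths satisfy s_kv <= 16,
# but the padded lengths would not.
cu_kv        = [0, 8, 16, 24]   # every segment has length 8
cu_kv_padded = [0, 32, 64, 96]  # padded to 32 per segment
print(host_max_seqlen(cu_kv))         # 8  -> small-seq condition holds
print(host_max_seqlen(cu_kv_padded))  # 32 -> condition would fail
```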
#define TRANSFORMER_ENGINE_FUSED_ATTN_ROCM_FUSED_ATTN_SMALLSEQ_H_

#include "../common.h"
#include "transformer_engine/fused_attn.h"
Check whether we really need those header files, and whether they belong in this .h or in the .cpp.
)
ck_smallseq_softmax_aux_size = (
    batch_size * attn_heads * q_max_seqlen
    * min(kv_max_seqlen, 16) * 2
Looking at the implementation, we only support kv_max_seqlen<=16, right? So should this be checked via an assertion instead of enforced via min?
We cannot keep that, because we care about runtime_max_seqlen, and here we don't know the runtime_max_seqlen. For example:
TransformerEngine/tests/jax/test_fused_attn.py, lines 1280 to 1281 at 4537cce
TransformerEngine/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp, lines 954 to 955 at 4537cce
So this test would break for any case where the runtime_max_seqlen_kv is actually >16?
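The two buffer sizes being compared in this thread can be sketched as element counts in Python. This is illustrative only: the function name is invented, dtype sizes are ignored, and the trailing *2 factor is copied from the quoted diff as-is (its purpose is not explained in this thread):

```python
def ck_aux_sizes(batch_size, attn_heads, q_max_seqlen, kv_max_seqlen):
    """Illustrative element counts for the two aux-buffer conventions:
    standard LSE layout (trailing dim 1) vs small-seq attention weights
    (trailing dim min(kv_max_seqlen, 16), with the *2 from the quoted diff)."""
    standard = batch_size * attn_heads * q_max_seqlen * 1
    smallseq = batch_size * attn_heads * q_max_seqlen * min(kv_max_seqlen, 16) * 2
    return standard, smallseq

# With the PR's test configuration (b=30720, s_q=1, s_kv=8, 16 heads assumed):
std, small = ck_aux_sizes(30720, 16, 1, 8)
print(std, small)  # 491520 7864320 -> small-seq layout is chosen
```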
Current:
if ck_standard_softmax_aux_size >= ck_smallseq_softmax_aux_size:
    softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, 1)
    softmax_dtype = dtypes.canonicalize_dtype(q_dtype)
else:
    softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, min(kv_max_seqlen, 16))
    softmax_dtype = dtypes.canonicalize_dtype(q_dtype)

Suggested change:
softmax_shape = (*batch_shape, attn_heads, q_max_seqlen)
softmax_dtype = dtypes.canonicalize_dtype(q_dtype)
if ck_standard_softmax_aux_size >= ck_smallseq_softmax_aux_size:
    softmax_shape += (1,)
else:
    softmax_shape += (min(kv_max_seqlen, 16),)
b5c5fb7 to c6e0eae