
[NO MERGE] Integrate CK varlen cross attention for small-seq (s_q=1, s_kv<=16)#461

Open
VeeraRajasekhar wants to merge 9 commits into dev from veergopu/fused-varlen-ck-smallseq-integration

Conversation

@VeeraRajasekhar (Contributor):

Integrate the CK team's unfused variable-length attention HIP kernels from varlen_attn/ into Transformer Engine's ROCm fused-attn path as a specialized path for cross-attention (Q length 1, KV length 2-16, large batch).

  • Add fused_attn_smallseq.hpp and fused_attn_smallseq.cpp under fused_attn_rocm/: declarations and implementation adapted from varlen_attn/attn_fwd.cpp and attn_bwd.cpp (scores, mask+softmax, output; grad_V, grad_attn, softmax bwd, grad_Q/grad_K). Runtime dispatch over max_seqlen_kv in {2,4,6,8,12,16}, head_dim 128, BF16.

  • Add fused_attn_smallseq.cpp to the ROCm fused-attn build in transformer_engine/common/CMakeLists.txt.

  • In fused_attn_ck_fwd: when THD and no bias, branch to small-seq path when max_seqlen_q==1 and 2<=max_seqlen_kv<=16. On shape query (Aux_CTX_Tensors->size == 0) skip get_runtime_max_seqlen (cu_seqlens pointers are null); use host max_seqlen_kv and set output_S to attention-weights shape {max_tokens_q, h_q, 1, runtime_max_seqlen_kv} and dtype QKV_type. On real run (size >= 2) call get_runtime_max_seqlen then fused_attn_smallseq_fwd. Use sequence count b_varlen = max_tokens_q (not segment count b) for get_runtime_max_seqlen, output_S shape, workspace size, and small-seq fwd so varlen kernel indexing matches Q and cu_seqlens_kv (THD may pass segment-level cu_seqlens; varlen kernel expects sequence-level batch).

  • In fused_attn_ck_bwd: same THD/small-seq condition. On workspace query (workspace->data.dptr == nullptr) skip get_runtime_max_seqlen and use host max_seqlen_kv; on real run call get_runtime_max_seqlen then fused_attn_smallseq_bwd. Use b_varlen = max_tokens_q_bwd for get_runtime_max_seqlen, workspace size, and small-seq bwd.

  • Reuse softmax LSE auxiliary buffer for attention weights in the small-seq path (forward write, backward read);

  • JAX attention.py: in NVTE_CK block, when THD and q_max_seqlen==1 and kv_max_seqlen<=16 set softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, kv_max_seqlen) and softmax_dtype = q_dtype so Python aux buffer matches C++ attention-weights convention.

  • Add test_ck_unfused_smallseq_backend in tests/jax/test_fused_attn.py (parametrized s_kv in {2,4,6,8,12,16}, b=30720, s_q=1, THD_THD_THD, SeqDescFormat.Seqlens) and optional NVTE_LOG_CK_SMALLSEQ debug logging in C++.
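The dispatch condition and the output_S shape convention described above can be sketched in Python (helper names here are hypothetical; the actual logic lives in the C++ fused_attn_ck_fwd path):

```python
def select_small_seq_path(max_seqlen_q, max_seqlen_kv, has_bias, is_thd):
    # Small-seq path applies only for THD layout, no bias, cross-attention
    # with Q length 1 and KV length in [2, 16].
    return is_thd and not has_bias and max_seqlen_q == 1 and 2 <= max_seqlen_kv <= 16

def small_seq_output_s_shape(max_tokens_q, h_q, runtime_max_seqlen_kv):
    # Attention-weights shape used for output_S: sequence count (max_tokens_q)
    # replaces the segment count b, so the varlen kernel's per-sequence
    # indexing matches Q and cu_seqlens_kv.
    return (max_tokens_q, h_q, 1, runtime_max_seqlen_kv)

print(select_small_seq_path(1, 8, has_bias=False, is_thd=True))    # True
print(select_small_seq_path(1, 32, has_bias=False, is_thd=True))   # False
print(small_seq_output_s_shape(30720, 16, 8))                      # (30720, 16, 1, 8)
```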

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@wangye805 (Collaborator) commented Feb 25, 2026:

Let's make this PR work for the JAX extension first; we can support PyTorch later.

One key difference between the JAX and PyTorch fused-attn dispatch is that PyTorch can calculate, request, and allocate softmax_aux and workspace at runtime, with the actual cu_seqlen_q/kv data. In the JAX extension, however, the softmax_aux and workspace sizes are computed ahead of time, without knowing the runtime cu_seqlen_q/kv:

```python
if backend == NVTE_Fused_Attn_Backend.NVTE_AOTriton:
    softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, config.max_segments_per_seq)
    softmax_dtype = dtypes.canonicalize_dtype(jnp.float32)
elif backend == NVTE_Fused_Attn_Backend.NVTE_CK:
    if config.qkv_layout.is_thd():
        softmax_shape = (*batch_shape, q_max_seqlen, attn_heads, 1)
    else:
        softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, 1)
    softmax_dtype = dtypes.canonicalize_dtype(jnp.float32)
else:
    raise ValueError(f"Unsupported {backend=}")
softmax_aux_aval = q_aval.update(shape=softmax_shape, dtype=softmax_dtype)
```

and in

```python
wkspace_info = transformer_engine_jax.get_fused_attn_fwd_workspace_sizes(
```

Aux tensors are prepared in

```cpp
void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack, const size_t input_batch,
```

also without knowledge of the runtime cu_seqlen_q/kv.

General guideline:
1. Pre-allocate a large enough softmax_aux and workspace ahead of time. Do not modify the aux preparation function or the C++-level aux/workspace calculation, since the softmax aux and workspace will be large enough for both flows, and the special flow only needs a valid start pointer address.
2. During the actual kernel dispatch, do a seqlen_q/kv check; if it satisfies the special cross-attn condition, launch the small-seq kernel there.
3. Use an env var to guard this new flow, and disable it when CP (context parallelism) is used.
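Guidelines (1) and (3) above can be sketched in Python (all names here are hypothetical; the real sizes come from the JAX abstract-eval code quoted above, and the sketch counts bytes under the assumption that the regular flow stores float32 LSE values while the small-seq flow stores 16-bit attention weights):

```python
import os

def softmax_aux_bytes(batch, heads, q_max_seqlen, kv_max_seqlen):
    # Regular CK flow: one float32 LSE value per (token, head) -> last dim 1.
    standard = batch * q_max_seqlen * heads * 1 * 4
    # Small-seq flow: full attention weights in the 16-bit QKV dtype,
    # last dim capped at 16 (hence the min in the PR's aux-size formula).
    smallseq = batch * heads * q_max_seqlen * min(kv_max_seqlen, 16) * 2
    # Pre-allocate enough for either flow; the special flow only needs
    # a valid start pointer into this buffer.
    return max(standard, smallseq)

def small_seq_enabled(using_context_parallel):
    # Guideline (3): env-guarded, and disabled when CP is used.
    return (os.environ.get("NVTE_FUSED_ATTN_CK_SMALLSEQ", "0") == "1"
            and not using_context_parallel)
```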

```cpp
NVTE_CHECK(workspace != nullptr, "small-seq bwd requires workspace.");

float sqr_dk_scale = attn_scale;
hipStream_t hip_stream = reinterpret_cast<hipStream_t>(stream);
```

Reviewer (Collaborator): Probably no need for this cast; cudaStream_t will be hipified correctly to hipStream_t.

- tests/jax: CK small-seq tests use fixture to set/restore
  NVTE_FUSED_ATTN_CK_SMALLSEQ=1; parametrize dtype (BF16/FP16) and add sequence-packing
  cases (2048-2-4, 2-4096-8192); when env set, num_segments_per_seq =
  max_seqlen_q for THD else 2.
- JAX attention.py: THD softmax shape/dtype uses small-seq path only when
  env=1, else original layout
- JAX attention.cpp: Added env guard
- fused_attn_smallseq: Use TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT for
  fwd/bwd; add FP16 (__half) support; fix __half*float with T(scale).
@VeeraRajasekhar (Contributor, Author):

ci/jax.sh (outdated):

```shell
@@ -58,6 +58,7 @@ run_test_config() {
    run_default_fa 1 test_custom_call_compute.py
    run_default_fa 1 test_functions.py
    run 1 test_fused_attn.py
```

Reviewer (Collaborator): Do we need to skip the small-seq tests for the regular ck/aiter flow?

Author (Contributor): Yes, skipped with both v2 and v3.

```python
)
self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets(self.segment_ids_kv)
else:
    if is_hip_extension() and os.environ.get("NVTE_FUSED_ATTN_CK_SMALLSEQ", "0") == "1":
```

Reviewer (Collaborator): Hmm, maybe we can make a dedicated generate_random_segment_ids_small_seq to separate it from the original generate_random_segment_ids function.

Author (Contributor): Updated with a new function that takes care of generating inputs for this new flow.

```cpp
}

if (runtime_max_seqlen_q == 1 && runtime_max_seqlen_kv >= 2 && runtime_max_seqlen_kv <= 16) {
  fused_attn_rocm::fused_attn_smallseq_fwd(
```

Reviewer (Collaborator): If NVTE_LOG_CK_CONFIG=1, also log whether we are running the CK small-seq flow or the regular ck/aiter flow.

Author (Contributor): Updated.

```cpp
}

if (runtime_max_seqlen_q == 1 && runtime_max_seqlen_kv >= 2 && runtime_max_seqlen_kv <= 16) {
```

Reviewer (Collaborator): Same here for logging which bwd flow to run.

Author (Contributor): Updated.

```cpp
const T* V_ptr = static_cast<const T*>(devPtrV);
T* O_ptr = static_cast<T*>(devPtrO);
T* attn_workspace = static_cast<T*>(attn_weights_buffer);
const int* cu_kv = static_cast<const int*>(devPtrCuSeqlensKV);
```

Reviewer (Collaborator): Will it run into issues if we don't pass in cu_seqlen_q/cu_seqlen_q_padded? For example, if there are several empty segments for q/kv, but for all non-empty ones s_q always equals 1?


```cpp
size_t runtime_max_seqlen_q = static_cast<size_t>(ck_fused_attn::get_runtime_max_seqlen(
    static_cast<uint64_t>(b), devPtrCuSeqlensQ, nullptr, max_seqlen_workspace, stream));
size_t runtime_max_seqlen_kv = static_cast<size_t>(ck_fused_attn::get_runtime_max_seqlen(
```

Reviewer (Collaborator): Here we don't pass cu_seqlen_padded into the runtime max-seqlen check. What if the max_seqlen without padding satisfies the CK kernel condition but with padding it does not? Can the CK kernel handle those corner cases?

```cpp
#define TRANSFORMER_ENGINE_FUSED_ATTN_ROCM_FUSED_ATTN_SMALLSEQ_H_

#include "../common.h"
#include "transformer_engine/fused_attn.h"
```

Reviewer (Collaborator): Check whether we really need those header files — in this .h or the .cpp?

Author (Contributor): Updated.

```python
)
ck_smallseq_softmax_aux_size = (
    batch_size * attn_heads * q_max_seqlen
    * min(kv_max_seqlen, 16) * 2
```

Reviewer (Contributor): Looking at the implementation, we only support kv_max_seqlen<=16, right? So should this be checked via an assertion instead of enforced via min?

Author (@VeeraRajasekhar, Mar 3, 2026): We cannot use an assertion there, because what matters is the runtime max seqlen, which we don't know at this point. For example:

```python
pytest.param(2048, 2, 4, 16, 16, 128, 128, id="seqpack-2048-2-4-16-16-128-128"),
pytest.param(2, 4096, 8192, 16, 16, 128, 128, id="seqpack-2-4096-8192-16-16-128-128"),
```

In these test cases s_kv is not 16, but the number of segments and the inputs are chosen in such a way that

```cpp
size_t runtime_max_seqlen_kv = static_cast<size_t>(ck_fused_attn::get_runtime_max_seqlen(
    b, devPtrCuSeqlensKV, nullptr, max_seqlen_workspace, stream));
```

returns <=16.
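The runtime check the author refers to reduces to taking the largest segment length implied by cu_seqlens. A pure-Python sketch of that reduction (the actual ck_fused_attn::get_runtime_max_seqlen runs on device; this is only an illustration of the arithmetic):

```python
def runtime_max_seqlen(cu_seqlens):
    # cu_seqlens holds cumulative sequence lengths: [0, l0, l0+l1, ...].
    # The runtime max seqlen is the largest adjacent difference.
    return max(b - a for a, b in zip(cu_seqlens, cu_seqlens[1:]))

# Many segments packed so every KV segment has length <= 16, even though
# the nominal s_kv of the test is much larger:
cu = [0, 4, 8, 20]  # segment lengths 4, 4, 12
assert runtime_max_seqlen(cu) == 12
```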

Reviewer (Contributor): So this test would break for any case where the runtime_max_seqlen_kv is actually >16?

Comment on lines +382 to +387:

```python
if ck_standard_softmax_aux_size >= ck_smallseq_softmax_aux_size:
    softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, 1)
    softmax_dtype = dtypes.canonicalize_dtype(q_dtype)
else:
    softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, min(kv_max_seqlen, 16))
    softmax_dtype = dtypes.canonicalize_dtype(q_dtype)
```

Reviewer (Contributor) suggested change:

```python
softmax_shape = (*batch_shape, attn_heads, q_max_seqlen)
softmax_dtype = dtypes.canonicalize_dtype(q_dtype)
if ck_standard_softmax_aux_size >= ck_smallseq_softmax_aux_size:
    softmax_shape += (1,)
else:
    softmax_shape += (min(kv_max_seqlen, 16),)
```

@VeeraRajasekhar force-pushed the veergopu/fused-varlen-ck-smallseq-integration branch from b5c5fb7 to c6e0eae on March 3, 2026, 21:17.