-
Notifications
You must be signed in to change notification settings - Fork 23
[NO MERGE] Integrate CK varlen cross attention for small-seq (s_q=1, s_kv<=16) #461
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
10f7ee6
b3ef62c
db685c4
b6a5ee8
75f7cfa
c737072
4537cce
c6e0eae
366945e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,6 +9,7 @@ | |
| from functools import partial | ||
| from math import sqrt | ||
| from typing import Tuple, Optional, Dict | ||
| import os | ||
| import random | ||
|
|
||
| import jax | ||
|
|
@@ -329,7 +330,11 @@ class FusedAttnRunner: | |
| # generating zero-length ragged tensors. This setting adjusts the test to avoid the zero-length cases. | ||
| def _get_max_segments_per_sequence(self): | ||
| if self.qkv_layout.is_thd(): | ||
| if 90400 <= get_cudnn_version() < 90500: | ||
| if ( | ||
| 90400 <= get_cudnn_version() < 90500 | ||
| or ( is_hip_extension() and | ||
| os.environ.get("NVTE_FUSED_ATTN_CK_SMALLSEQ", "0") == "1") | ||
| ): | ||
| return self.num_segments_per_seq | ||
| else: | ||
| # +1 for testing runtime_segments < max_segments | ||
|
|
@@ -418,6 +423,57 @@ def _check_configs(self): | |
| "the F16_arbitrary_seqlen backend." | ||
| ) | ||
|
|
||
| def _setup_thd_segments_ck_smallseq(self, generate_random_segment_ids): | ||
| """ | ||
| Build THD segment descriptors for the CK small-seq path (NVTE_FUSED_ATTN_CK_SMALLSEQ=1). | ||
|
|
||
| Uses num_segments_per_seq = max_seqlen_q for both Q and KV. For Q: if max_seqlen_q == 1, | ||
| uses a fixed layout (one token per batch, cu_seqlens [0,1,...,batch_size]); otherwise | ||
| generates random segments. For KV: always generates random segments. | ||
| """ | ||
| num_segments_per_seq = self.max_seqlen_q | ||
| if self.max_seqlen_q == 1: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Will it run into problems if we call generate_random_segment_ids directly when self.max_seqlen_q == 1? |
||
| # Q: deterministic - one segment of length 1 per batch -> cu_seqlen [0,1,2,...,batch_size] | ||
| segment_ids_q = jnp.ones((self.batch_size, self.max_seqlen_q), dtype=jnp.int32) | ||
| segment_pos_q = jnp.zeros((self.batch_size, self.max_seqlen_q), dtype=jnp.int32) | ||
| pad_q = jnp.zeros((self.batch_size, self.max_seqlen_q), dtype=jnp.int32) | ||
| seqlens_q = jnp.ones((self.batch_size, 1), dtype=jnp.int32) | ||
| offsets_q = jnp.concatenate( | ||
| [ | ||
| jnp.arange(self.batch_size, dtype=jnp.int32)[:, None], | ||
| jnp.full((self.batch_size, 1), -1, dtype=jnp.int32), | ||
| ], | ||
| axis=1, | ||
| ) | ||
| else: | ||
| segment_ids_q, segment_pos_q, pad_q = generate_random_segment_ids( | ||
| self.batch_size, self.max_seqlen_q, num_segments_per_seq, seed=42 | ||
| ) | ||
| seqlens_q, offsets_q = get_seqlens_and_offsets(segment_ids_q) | ||
|
|
||
| min_segment_len = None if self.window_size is None else seqlens_q | ||
| segment_ids_kv, segment_pos_kv, pad_kv = generate_random_segment_ids( | ||
| self.batch_size, | ||
| self.max_seqlen_kv, | ||
| num_segments_per_seq, | ||
| seed=2024, | ||
| min_segment_len=min_segment_len, | ||
| ) | ||
| seqlens_kv, offsets_kv = get_seqlens_and_offsets(segment_ids_kv) | ||
| return ( | ||
| num_segments_per_seq, | ||
| segment_ids_q, | ||
| segment_pos_q, | ||
| pad_q, | ||
| seqlens_q, | ||
| offsets_q, | ||
| segment_ids_kv, | ||
| segment_pos_kv, | ||
| pad_kv, | ||
| seqlens_kv, | ||
| offsets_kv, | ||
| ) | ||
|
|
||
| def _setup_inputs(self): | ||
| self._check_configs() | ||
|
|
||
|
|
@@ -539,27 +595,42 @@ def generate_random_segment_ids( | |
| return segment_ids, segment_pos, segment_pad | ||
|
|
||
| if self.qkv_layout.is_thd(): | ||
| self.num_segments_per_seq = 2 | ||
| self.segment_ids_q, self.segment_pos_q, self.pad_q = generate_random_segment_ids( | ||
| self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42 | ||
| ) | ||
| self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(self.segment_ids_q) | ||
| # TODO(rewang): record only self attention and find the reason of cross attention | ||
| if self.qkv_layout == QKVLayout.T3HD or self.max_seqlen_q == self.max_seqlen_kv: | ||
| self.segment_ids_kv = self.segment_ids_q | ||
| self.segment_pos_kv = self.segment_pos_q | ||
| self.pad_kv = self.pad_q | ||
| else: | ||
| # Force kv_len >= q_len for swa, otherwise, cuDNN kernels don't support | ||
| min_segment_len = None if self.window_size is None else self.seqlens_q | ||
| self.segment_ids_kv, self.segment_pos_kv, self.pad_kv = generate_random_segment_ids( | ||
| self.batch_size, | ||
| self.max_seqlen_kv, | ||
| if is_hip_extension() and os.environ.get("NVTE_FUSED_ATTN_CK_SMALLSEQ", "0") == "1": | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Maybe put the small_seq flag into self.config to replace the ENV-variable check? |
||
| ( | ||
| self.num_segments_per_seq, | ||
| seed=2024, | ||
| min_segment_len=min_segment_len, | ||
| self.segment_ids_q, | ||
| self.segment_pos_q, | ||
| self.pad_q, | ||
| self.seqlens_q, | ||
| self.offsets_q, | ||
| self.segment_ids_kv, | ||
| self.segment_pos_kv, | ||
| self.pad_kv, | ||
| self.seqlens_kv, | ||
| self.offsets_kv, | ||
| ) = self._setup_thd_segments_ck_smallseq(generate_random_segment_ids) | ||
| else: | ||
| self.num_segments_per_seq = 2 | ||
| self.segment_ids_q, self.segment_pos_q, self.pad_q = generate_random_segment_ids( | ||
| self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42 | ||
| ) | ||
| self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets(self.segment_ids_kv) | ||
| self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(self.segment_ids_q) | ||
| # TODO(rewang): record only self attention and find the reason of cross attention | ||
| if self.qkv_layout == QKVLayout.T3HD or self.max_seqlen_q == self.max_seqlen_kv: | ||
| self.segment_ids_kv = self.segment_ids_q | ||
| self.segment_pos_kv = self.segment_pos_q | ||
| self.pad_kv = self.pad_q | ||
| else: | ||
| # Force kv_len >= q_len for swa, otherwise, cuDNN kernels don't support | ||
| min_segment_len = None if self.window_size is None else self.seqlens_q | ||
| self.segment_ids_kv, self.segment_pos_kv, self.pad_kv = generate_random_segment_ids( | ||
| self.batch_size, | ||
| self.max_seqlen_kv, | ||
| self.num_segments_per_seq, | ||
| seed=2024, | ||
| min_segment_len=min_segment_len, | ||
| ) | ||
| self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets(self.segment_ids_kv) | ||
| else: | ||
| self.num_segments_per_seq = 1 | ||
| self.segment_ids_q, self.pad_q = gen_valid( | ||
|
|
@@ -1214,3 +1285,61 @@ def test_jax_new_rng(): | |
| ) | ||
| runner = FusedAttnRunner(**kwargs) | ||
| runner.test_forward() | ||
|
|
||
|
|
||
| # ROCm CK small-seq varlen tests. | ||
| @pytest.fixture | ||
| def ck_smallseq_env(monkeypatch): | ||
| """Enable CK small-seq path and disable XLA GPU graphs for these tests.""" | ||
| if "xla_gpu_graph_level=0" not in os.environ.get("XLA_FLAGS", ""): | ||
| pytest.skip("Test must be run with XLA_FLAGS='--xla_gpu_graph_level=0'") | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Not sure whether the new XLA_FLAG for cudagraph is needed due to a ROCm change or a JAX change. If it's due to the JAX change, we can use a JAX version check. |
||
| monkeypatch.setenv("NVTE_FUSED_ATTN_CK_SMALLSEQ", "1") | ||
| yield | ||
|
|
||
| @pytest.mark.parametrize("dtype", [jnp.bfloat16, jnp.float16], ids=["BF16", "FP16"]) | ||
| @pytest.mark.parametrize( | ||
| "b, s_q, s_kv, h_q, h_kv, d_qk, d_v", | ||
| [ | ||
| pytest.param(4000, 1, 2, 16, 16, 128, 128, id="4000-1-2-16-16-128-128"), | ||
| pytest.param(4000, 1, 4, 16, 16, 128, 128, id="4000-1-4-16-16-128-128"), | ||
| pytest.param(4000, 1, 6, 16, 16, 128, 128, id="4000-1-6-16-16-128-128"), | ||
| pytest.param(4000, 1, 8, 16, 16, 128, 128, id="4000-1-8-16-16-128-128"), | ||
| pytest.param(4000, 1, 12, 16, 16, 128, 128, id="4000-1-12-16-16-128-128"), | ||
| pytest.param(4000, 1, 16, 16, 16, 128, 128, id="4000-1-16-16-16-128-128"), | ||
| pytest.param(2048, 2, 4, 16, 16, 128, 128, id="seqpack-2048-2-4-16-16-128-128"), | ||
| pytest.param(2, 4096, 8192, 16, 16, 128, 128, id="seqpack-2-4096-8192-16-16-128-128"), | ||
| ], | ||
| ) | ||
| @pytest.mark.skipif( | ||
| not is_hip_extension(), reason="CK unfused smallseq backend only available on AMD hardware" | ||
| ) | ||
| def test_ck_unfused_smallseq_backend( | ||
Micky774 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| ck_smallseq_env, b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype | ||
| ): | ||
| """ | ||
| Test the CK unfused small-seq (varlen) path on ROCm: s_q=1, s_kv<=16, THD layout. | ||
| Uses THD_THD_THD (Q,K,V all THD). ck_smallseq_env sets NVTE_FUSED_ATTN_CK_SMALLSEQ=1 and | ||
| restores it after the test. | ||
| """ | ||
wangye805 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| runner = FusedAttnRunner( | ||
| batch_size=b, | ||
| max_seqlen_q=s_q, | ||
| max_seqlen_kv=s_kv, | ||
| num_heads_q=h_q, | ||
| num_heads_kv=h_kv, | ||
| head_dim_qk=d_qk, | ||
| head_dim_v=d_v, | ||
| attn_bias_type=AttnBiasType.NO_BIAS, | ||
| attn_mask_type=AttnMaskType.PADDING_MASK, | ||
| dropout_prob=0.0, | ||
| use_old_rng=True, | ||
| dtype=dtype, | ||
| is_training=True, | ||
| qkv_layout=QKVLayout.THD_THD_THD, | ||
| bias_shape=None, | ||
| window_size=None, | ||
| seq_desc_format=SeqDescFormat.Seqlens, | ||
| ) | ||
| runner._setup_inputs() | ||
| # runner.test_forward() | ||
| runner.test_backward() | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,5 @@ | ||
| /************************************************************************* | ||
| * Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. | ||
| * Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. | ||
| * | ||
| * License for AMD contributions = MIT. See LICENSE for more information | ||
| ************************************************************************/ | ||
|
|
@@ -9,6 +9,7 @@ | |
| #include <numeric> // Required for std::accumulate | ||
| #ifdef USE_FUSED_ATTN_CK | ||
| #include <ck_fused_attn/ck_fused_attn.hpp> | ||
| #include "fused_attn_smallseq.h" | ||
| #endif // USE_FUSED_ATTN_CK | ||
| #include "../util/cuda_runtime.h" | ||
| #include "../util/system.h" | ||
|
|
@@ -614,6 +615,40 @@ void fused_attn_ck_fwd_impl( | |
| // denote the next available section of workspace from upstream | ||
| void* workspace_next = workspace; | ||
|
|
||
| const char* nvte_smallseq = std::getenv("NVTE_FUSED_ATTN_CK_SMALLSEQ"); | ||
| if (is_ragged && nvte_smallseq && std::string(nvte_smallseq) == "1") { | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Let's add another filter, s_q != s_kv, here. |
||
| void* max_seqlen_workspace = workspace_next; | ||
| size_t runtime_max_seqlen_q = static_cast<size_t>(ck_fused_attn::get_runtime_max_seqlen( | ||
| static_cast<uint64_t>(b), devPtrCuSeqlensQ, nullptr, max_seqlen_workspace, stream)); | ||
| size_t runtime_max_seqlen_kv = static_cast<size_t>(ck_fused_attn::get_runtime_max_seqlen( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Here we don't pass cu_seqlen_padded into the runtime max-seqlen check. What if the max_seqlen without padding satisfies the CK kernel condition but the padded one does not? Can the CK kernel handle those corner cases? |
||
| static_cast<uint64_t>(b), devPtrCuSeqlensKV, nullptr, max_seqlen_workspace, stream)); | ||
| workspace_next = static_cast<void*>(static_cast<int8_t*>(workspace_next) + sizeof(uint64_t)); | ||
|
|
||
| if (nvte_log_ck_config) { | ||
Micky774 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| std::cout << std::endl << "attn_fwd(ck small-seq): "; | ||
| std::cout << "b: " << b << ", "; | ||
| std::cout << "runtime_max_seqlen_q: " << runtime_max_seqlen_q << ", "; | ||
| std::cout << "runtime_max_seqlen_kv: " << runtime_max_seqlen_kv << ", "; | ||
| std::cout << "flow: " | ||
| << (runtime_max_seqlen_q == 1 && runtime_max_seqlen_kv >= 2 && | ||
| runtime_max_seqlen_kv <= 16 | ||
| ? "ck-smallseq" | ||
| : "regular ck/aiter") | ||
| << std::endl; | ||
| } | ||
|
|
||
| if (runtime_max_seqlen_q == 1 && runtime_max_seqlen_kv >= 2 && runtime_max_seqlen_kv <= 16) { | ||
wangye805 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| fused_attn_rocm::fused_attn_smallseq_fwd( | ||
wangye805 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| b, h, hg, runtime_max_seqlen_kv, d_qk, d_v, | ||
| is_training, scaling_factor, dropout_probability, | ||
| devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxAux, | ||
| devPtrCuSeqlensKV, devPtrSeqOffsetsKV, | ||
| devPtrDropoutSeed, devPtrDropoutOffset, | ||
| dtype, workspace, workspace_size, stream); | ||
| return; | ||
| } | ||
| } | ||
|
|
||
| std::array<uint64_t, 4> q_stride; | ||
| std::array<uint64_t, 4> k_stride; | ||
| std::array<uint64_t, 4> v_stride; | ||
|
|
@@ -916,6 +951,40 @@ void fused_attn_ck_bwd_impl( | |
| // denote the next available section of workspace from upstream | ||
| void* workspace_next = workspace; | ||
|
|
||
| const char* nvte_smallseq = std::getenv("NVTE_FUSED_ATTN_CK_SMALLSEQ"); | ||
| if (is_ragged && nvte_smallseq && std::string(nvte_smallseq) == "1") { | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same here; let's add s_q != s_kv here. |
||
| void* max_seqlen_workspace = workspace_next; | ||
| size_t runtime_max_seqlen_q = static_cast<size_t>(ck_fused_attn::get_runtime_max_seqlen( | ||
| b, devPtrCuSeqlensQ, nullptr, max_seqlen_workspace, stream)); | ||
| size_t runtime_max_seqlen_kv = static_cast<size_t>(ck_fused_attn::get_runtime_max_seqlen( | ||
| b, devPtrCuSeqlensKV, nullptr, max_seqlen_workspace, stream)); | ||
| workspace_next = static_cast<void*>(static_cast<int8_t*>(workspace_next) + sizeof(uint64_t)); | ||
|
|
||
| if (nvte_log_ck_config) { | ||
Micky774 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| std::cout << std::endl << "attn_bwd(ck small-seq): "; | ||
| std::cout << "b: " << b << ", "; | ||
| std::cout << "runtime_max_seqlen_q: " << runtime_max_seqlen_q << ", "; | ||
| std::cout << "runtime_max_seqlen_kv: " << runtime_max_seqlen_kv << ", "; | ||
| std::cout << "flow: " | ||
| << (runtime_max_seqlen_q == 1 && runtime_max_seqlen_kv >= 2 && | ||
| runtime_max_seqlen_kv <= 16 | ||
| ? "ck-smallseq" | ||
| : "regular ck/aiter") | ||
| << std::endl; | ||
| } | ||
|
|
||
| if (runtime_max_seqlen_q == 1 && runtime_max_seqlen_kv >= 2 && runtime_max_seqlen_kv <= 16) { | ||
| fused_attn_rocm::fused_attn_smallseq_bwd( | ||
| b, h, hg, runtime_max_seqlen_kv, d_qk, d_v, | ||
| scaling_factor, dropout_probability, | ||
| devPtrQ, devPtrK, devPtrV, devPtrO, devPtrdO, devPtrSoftmaxAux, | ||
| devPtrdQ, devPtrdK, devPtrdV, | ||
| devPtrCuSeqlensKV, devPtrSeqOffsetsKV, | ||
| dtype, workspace, workspace_size, stream); | ||
| return; | ||
| } | ||
| } | ||
|
|
||
| std::array<uint64_t, 4> q_stride; | ||
| std::array<uint64_t, 4> k_stride; | ||
| std::array<uint64_t, 4> v_stride; | ||
|
|
@@ -1828,7 +1897,7 @@ void fused_attn_ck_fwd( | |
| size_t max_tokens_q = std::accumulate((input_Q->data).shape.begin(), (input_Q->data).shape.end(), static_cast<size_t>(1), std::multiplies<size_t>())/h_q/d_qk; | ||
| size_t max_tokens_kv = std::accumulate((input_K->data).shape.begin(), (input_K->data).shape.end(), static_cast<size_t>(1), std::multiplies<size_t>())/h_kv/d_qk; | ||
|
|
||
| bool is_ragged = nvte_get_qkv_format(qkv_layout)==NVTE_QKV_Format::NVTE_THD; | ||
| bool is_ragged = nvte_get_qkv_format(qkv_layout)==NVTE_QKV_Format::NVTE_THD; | ||
| if (Aux_CTX_Tensors->size == 0) { | ||
| if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { | ||
| Aux_CTX_Tensors->size = 3; | ||
|
|
@@ -1883,7 +1952,6 @@ void fused_attn_ck_fwd( | |
| bool is_padding = (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || | ||
| attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || | ||
| attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); | ||
|
|
||
| fused_attn_ck_fwd_impl( | ||
| b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, bias_b, bias_h, | ||
| max_tokens_q, max_tokens_kv, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For our new rocm7.2 image, the xla cudagraph disabling need to use