[ROCm][Deepseekv3.2] Refactor Sparse Indexer as CustomOp#29287
tjtanaa merged 11 commits into vllm-project:main from
💡 Codex Review
Here are some automated review suggestions for this pull request.
    self.paged_kv_last_page_len = torch.ones(
        max_num_seqs, dtype=torch.int32, device=device
    )

    # These two needs to be calculated in runtime,
Compute paged_kv_last_page_len for non-unit block sizes
ROCm sparse MLA now runs with block_size 64 (DeepseekV32IndexerBackend.supported_kernel_block_sizes was switched to [64]), but the ROCm metadata builder still initializes paged_kv_last_page_len to all ones and never derives the actual last-page lengths before passing it into mla_decode_fwd. With block sizes larger than 1, this tells the kernel that every last cache page has only a single valid token, so any decode where a sequence spans more than one cache page will mask out all but the first position of its final page, truncating attention for longer contexts.
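The review's point can be illustrated with a minimal sketch of how a last-page length could be derived from a sequence length and a block size. The function name `last_page_lengths` and the handling of empty sequences are assumptions made for illustration; this is not the code in the PR and uses plain Python lists rather than tensors.

```python
def last_page_lengths(seq_lens, block_size):
    """Return the number of valid tokens in the final cache page of each sequence.

    Hypothetical helper for illustration only; not taken from the PR.
    """
    if block_size <= 0:
        raise ValueError("block_size must be positive")
    lengths = []
    for seq_len in seq_lens:
        if seq_len <= 0:
            # Assumption: an empty sequence has no valid tokens in its last page.
            lengths.append(0)
        else:
            # A sequence that exactly fills its pages has a full last page,
            # so map the remainder 0 to block_size instead of 0.
            lengths.append((seq_len - 1) % block_size + 1)
    return lengths


print(last_page_lengths([1, 64, 65, 130], block_size=64))  # [1, 64, 1, 2]
print(last_page_lengths([5, 9], block_size=1))             # [1, 1]
```

With a block size of 1 every entry is 1, which is why an all-ones buffer is only correct in that case.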
@ganyi1996ppo Can you also include the AITER commit or requirements, and also information if the AITER commit in the
Should work fine after this aiter commit
@ganyi1996ppo please use 20-shot gsm-8k to verify correctness: https://docs.vllm.ai/projects/recipes/en/latest/DeepSeek/DeepSeek-V3_2-Exp.html#accuracy-benchmarking This model selects 2048 tokens, so the correctness evaluation should use requests longer than 2048 tokens.
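A toy sketch of why short requests do not exercise the sparse path: when the context is no longer than the selection budget, a top-k selection keeps every token and behaves like dense attention. The helper `select_topk` and the synthetic scores are assumptions for illustration; only the figure 2048 comes from the comment above.

```python
def select_topk(scores, k):
    """Return indices of the k highest-scoring tokens, in position order.

    Hypothetical toy selector; the real indexer uses fused GPU kernels.
    """
    if len(scores) <= k:
        # Nothing to drop: every token is selected.
        return list(range(len(scores)))
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return sorted(ranked[:k])


TOPK = 2048  # selection budget mentioned in the review comment

short_scores = [float(i % 7) for i in range(1000)]
long_scores = [float(i % 7) for i in range(5000)]

# A 1000-token context keeps all tokens, so a selection bug would go unnoticed.
print(len(select_topk(short_scores, TOPK)) == len(short_scores))  # True
# A 5000-token context is actually pruned down to the budget.
print(len(select_topk(long_scores, TOPK)))  # 2048
```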
    num_decode_tokens = attn_metadata.num_decode_tokens

    ops.indexer_k_quant_and_cache(
    indexer_k_quant_cache_and_cache_func = ops.indexer_k_quant_and_cache
Can you use CustomOp to select the kernel implementation for different platforms?
Hi @heheda12345, thanks for the suggestion. We are actually planning to replace many kernels in sparse_attn_indexer, so perhaps it's better to wrap the whole sparse_attn_indexer as a CustomOp. What are your thoughts?
Yes that works for me
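The pattern discussed in this thread, one op object that picks a per-platform implementation, can be sketched as follows. `PlatformDispatchOp` and `ToyIndexer` are hypothetical stand-ins written for this example; they are not vLLM's actual CustomOp class, and the method names are only assumed to resemble its per-platform hooks.

```python
class PlatformDispatchOp:
    """Hypothetical stand-in for a CustomOp-style base class (not vLLM's real class)."""

    def __init__(self, platform):
        # Resolve the implementation once, when the op is constructed.
        dispatch = {"cuda": self.forward_cuda, "rocm": self.forward_hip}
        self._forward = dispatch.get(platform, self.forward_native)

    def __call__(self, *args, **kwargs):
        return self._forward(*args, **kwargs)

    def forward_native(self, *args, **kwargs):
        raise NotImplementedError

    def forward_cuda(self, *args, **kwargs):
        # Default: fall back to the reference implementation.
        return self.forward_native(*args, **kwargs)

    def forward_hip(self, *args, **kwargs):
        # Default: reuse the CUDA path unless a subclass overrides it.
        return self.forward_cuda(*args, **kwargs)


class ToyIndexer(PlatformDispatchOp):
    """Toy op: returns which path ran plus the top-k indices."""

    def forward_native(self, scores, k):
        order = sorted(range(len(scores)), key=lambda i: -scores[i])
        return "native", sorted(order[:k])

    def forward_hip(self, scores, k):
        # A ROCm build would call its own kernels here; the toy reuses the
        # reference result and only changes the tag.
        _, indices = self.forward_native(scores, k)
        return "hip", indices


print(ToyIndexer("rocm")([0.1, 0.9, 0.5], 2))  # ('hip', [1, 2])
print(ToyIndexer("cuda")([0.1, 0.9, 0.5], 2))  # ('native', [1, 2])
```

Wrapping the whole indexer, rather than each kernel, keeps the platform choice in one place, which matches the author's argument that many kernels inside it are expected to change.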
Sure, thanks for the suggestion!
Update 20-shot gsm8k result:
b9b2b37 to 83df375
Hi @heheda12345 @tjtanaa, I just updated the code as we discussed before; please take a look.
And @tjtanaa, we might need to wait for the Triton update to 3.5.0 before merging this PR. Or the gluon version of
@gshtras do you know if there are any plans for triton updates and also aiter updates?
0f82d58 to f9d3e0f
f9d3e0f to 5111929
Signed-off-by: ganyi <ygan@amd.com>
5111929 to ee0a366
tjtanaa left a comment
I added some comments to expedite the review. I will add more in the next couple hours.
    "NHD",
    block_tile_size,
    head_tile_size,
    IS_FNUZ=current_platform.fp8_dtype() == torch.float8_e4m3fnuz,
There is a helper function, current_platform.is_fp8_fnuz(). Can you also add a cache decorator to is_fp8_fnuz?
I found that the cache decorator cannot be captured by torch.compile, so maybe we can leave this one? The host overhead should be minor for big models like dsv3.2.
        return importlib.util.find_spec("aiter.ops.triton.fp8_mqa_logits") is not None


    if rocm_aiter_ops.is_enabled() and has_mqa_logits_module():
        from aiter.ops.triton.attention.fp8_mqa_logits import fp8_mqa_logits
In the AITER version used in Dockerfile.rocm_base, the kernel already exists. However, its path differs from the one used in the latest main.
The AITER code in Dockerfile.rocm_base stores the op at this path: from aiter.ops.triton.fp8_mqa_logits import fp8_mqa_logits
Please take another look; I have wrapped them up to make the code compatible with both AITER versions.
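One way to stay compatible with both layouts is to try each module path in order and take the first that imports. The sketch below is an assumption about how such a wrapper could look, not the code merged in the PR; `import_first_available` is a hypothetical helper, while the two module paths are the ones quoted in this thread.

```python
import importlib


def import_first_available(module_paths, attr):
    """Return `attr` from the first importable module in `module_paths`, else None.

    Hypothetical helper for illustration only.
    """
    for path in module_paths:
        try:
            module = importlib.import_module(path)
        except ImportError:
            # Covers ModuleNotFoundError for a missing package or submodule.
            continue
        if hasattr(module, attr):
            return getattr(module, attr)
    return None


# Newer layout first, then the layout described for Dockerfile.rocm_base.
fp8_mqa_logits = import_first_available(
    [
        "aiter.ops.triton.attention.fp8_mqa_logits",
        "aiter.ops.triton.fp8_mqa_logits",
    ],
    "fp8_mqa_logits",
)

if fp8_mqa_logits is None:
    print("AITER fp8_mqa_logits not found; a fallback kernel would be used")
```

Using try/except around `importlib.import_module` avoids a subtlety of `importlib.util.find_spec`: for a dotted name it imports the parent packages first and raises ModuleNotFoundError if one of them is missing, rather than returning None.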
    if rocm_aiter_ops.is_enabled():
        from aiter.ops.triton.pa_mqa_logits import deepgemm_fp8_paged_mqa_logits_stage1
        from aiter.ops.triton.attention.pa_mqa_logits import (
The AITER version used in Dockerfile.rocm_base does not have this path yet. It should be
    from aiter.ops.triton.pa_mqa_logits import (
        deepgemm_fp8_paged_mqa_logits_stage1,
    )
This one is compatible too.
…t#29287) Signed-off-by: ganyi <ygan@amd.com> Signed-off-by: mohammad najafi <mohammad.najafi@amd.com>
…t#29287) Signed-off-by: ganyi <ygan@amd.com> Signed-off-by: 陈建华 <1647430658@qq.com>
…t#29287) Signed-off-by: ganyi <ygan@amd.com>
Purpose
This PR optimizes DeepSeek-V3.2 performance on AMD devices and separates SparseAttnIndexer out as a CustomOp, since it contains many heavy kernels such as fp8_mqa_logits and fp8_paged_mqa_logits. The best implementation of this indexer op may differ across platforms in order to achieve optimal performance in vLLM. The main changes include:
- Separate SparseAttnIndexer out as CustomOp.
- mla_decode_fwd to AiterMLASparseBackend to accelerate the performance on sparse MLA
- fetch_ragged_layout triton kernel to handle the dynamic shape issue and enable the full cudagraph on decode phase
- _indexer_k_quant_and_cache_kernel, _cp_gather_indexer_quant_cache_kernel triton kernel for preshuffle layout support
- fp8_paged_mqa_logits with preshuffle layout support
Test Plan
accuracy: gsm8k
performance: vllm bench
Test Result