
[ROCm][Deepseekv3.2] Refactor Sparse Indexer as CustomOp#29287

Merged
tjtanaa merged 11 commits into vllm-project:main from ROCm:ganyi/optimized_dsv3.2
Jan 21, 2026

Conversation


@ganyi1996ppo ganyi1996ppo commented Nov 24, 2025

Purpose

This PR optimizes DeepSeek V3.2 performance on AMD devices and separates SparseAttnIndexer out as a CustomOp, since it contains many heavy kernels such as fp8_mqa_logits and fp8_paged_mqa_logits, and the optimal implementation of this indexer op can vary across platforms. The main changes include:

  • Separate SparseAttnIndexer out as a CustomOp.
  • Integrate mla_decode_fwd into AiterMLASparseBackend to accelerate sparse MLA.
  • Add a fetch_ragged_layout Triton kernel to handle the dynamic-shape issue and enable full CUDA graph capture in the decode phase.
  • Add _indexer_k_quant_and_cache_kernel and _cp_gather_indexer_quant_cache_kernel Triton kernels for preshuffle layout support.
  • Integrate the Gluon implementation of fp8_paged_mqa_logits with preshuffle layout support.
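The first bullet is the core refactor: vLLM's CustomOp pattern binds forward() to a platform-specific method once, at construction time, so per-call dispatch is free. The following is a minimal standalone sketch of that dispatch pattern, not vLLM's actual CustomOp class; the class and method names (SparseAttnIndexerSketch, the kernel-name strings) are illustrative assumptions.

```python
class CustomOpSketch:
    """Minimal sketch of CustomOp-style dispatch: pick the platform
    implementation once in __init__ instead of branching per call."""

    def __init__(self, platform: str):
        # Real vLLM resolves this from current_platform; we take a string.
        if platform == "rocm":
            self._forward = self.forward_hip
        elif platform == "cuda":
            self._forward = self.forward_cuda
        else:
            self._forward = self.forward_native

    def forward(self, *args, **kwargs):
        return self._forward(*args, **kwargs)

    def forward_native(self, x):
        raise NotImplementedError

    # Platform variants default to the native path unless overridden.
    forward_cuda = forward_native
    forward_hip = forward_native


class SparseAttnIndexerSketch(CustomOpSketch):
    """Hypothetical indexer op: each platform selects its own heavy
    kernel (stand-ins for fp8_mqa_logits / fp8_paged_mqa_logits)."""

    def forward_cuda(self, x):
        return ("deepgemm_fp8_mqa_logits", x)

    def forward_hip(self, x):
        return ("aiter_fp8_mqa_logits", x)
```

With this shape, adding an AMD-optimized indexer only means overriding forward_hip, which is the structure this PR gives SparseAttnIndexer.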

Test Plan

accuracy: gsm8k
performance: vllm bench

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@mergify mergify bot added deepseek Related to DeepSeek models rocm Related to AMD ROCm v1 labels Nov 24, 2025
@ganyi1996ppo ganyi1996ppo marked this pull request as ready for review November 24, 2025 13:43
@chatgpt-codex-connector chatgpt-codex-connector bot left a comment
💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines +186 to +190
self.paged_kv_last_page_len = torch.ones(
max_num_seqs, dtype=torch.int32, device=device
)

# These two needs to be calculated in runtime,

P1: Compute paged_kv_last_page_len for non-unit block sizes

ROCm sparse MLA now runs with block_size 64 (DeepseekV32IndexerBackend.supported_kernel_block_sizes was switched to [64]), but the ROCM metadata builder still initializes paged_kv_last_page_len to all ones and never derives the actual last-page lengths before passing it into mla_decode_fwd. With block sizes larger than 1, this tells the kernel every last cache page has only a single valid token, so any decode where a sequence spans more than one cache page will mask out all but the first position of its final page, truncating attention for longer contexts.
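The derivation Codex asks for is simple page arithmetic. Below is an illustrative sketch, not the PR's actual metadata-builder code; seq_lens and block_size are assumed inputs.

```python
def last_page_lens(seq_lens, block_size):
    """Number of valid tokens in the final cache page of each sequence.

    A sequence of length L occupies ceil(L / block_size) pages; its last
    page holds ((L - 1) % block_size) + 1 tokens (exactly block_size when
    L is a multiple of block_size). Length-0 sequences have no valid tokens.
    """
    return [((L - 1) % block_size) + 1 if L > 0 else 0 for L in seq_lens]
```

For block_size 64, `last_page_lens([1, 64, 65, 130], 64)` gives `[1, 64, 1, 2]`; initializing the tensor to all ones is only correct for the first of those cases, which is the truncation the comment describes.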



tjtanaa commented Nov 24, 2025

@ganyi1996ppo Can you also include the AITER commit or requirements, and also information if the AITER commit in the Dockerfile.rocm_base supports the feature? It will be easier for us to keep track and plan the merge.

@ganyi1996ppo ganyi1996ppo changed the title [ROCm][Deepseekv3.2][Perf] Performance optimize of deepseek v3.2 on AMD device [ROCm][Deepseekv3.2][Perf] Performance optimization of deepseek v3.2 on AMD device Nov 24, 2025

ganyi1996ppo commented Nov 25, 2025

@ganyi1996ppo Can you also include the AITER commit or requirements, and also information if the AITER commit in the Dockerfile.rocm_base supports the feature? It will be easier for us to keep track and plan the merge.

It should work fine after this AITER commit: e2a1a6f7c8628e14b28c09844ee25ef0b6f9b19d

@heheda12345
Collaborator

@ganyi1996ppo please use 20-shot gsm-8k to verify correctness. https://docs.vllm.ai/projects/recipes/en/latest/DeepSeek/DeepSeek-V3_2-Exp.html#accuracy-benchmarking

This model selects 2048 tokens, so correctness evaluation should use requests longer than 2048 tokens.

num_decode_tokens = attn_metadata.num_decode_tokens

ops.indexer_k_quant_and_cache(
indexer_k_quant_cache_and_cache_func = ops.indexer_k_quant_and_cache
Collaborator

can you use CustomOp to select the kernel implementation for different platforms?

Contributor Author

hi @heheda12345, thanks for the suggestion. We're actually planning to replace many of the kernels in sparse_attn_indexer, so perhaps it's better to wrap sparse_attn_indexer itself as a CustomOp. What do you think?

Collaborator

Yes, that works for me.

@ganyi1996ppo
Contributor Author

@ganyi1996ppo please use 20-shot gsm-8k to verify correctness. https://docs.vllm.ai/projects/recipes/en/latest/DeepSeek/DeepSeek-V3_2-Exp.html#accuracy-benchmarking
This model selects 2048 tokens so correctness evaluation should use requests longer than 2048 tokens

Sure, thanks for the suggestion!

@ganyi1996ppo
Contributor Author

Update 20-shot gsm8k result:

# 20-shot
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|    20|exact_match|↑  |0.9507|±  |0.0060|
|     |       |strict-match    |    20|exact_match|↑  |0.9515|±  |0.0059|

@ganyi1996ppo
Contributor Author

hi @heheda12345 @tjtanaa, I just updated the code as we discussed; please take a look.

@ganyi1996ppo
Contributor Author

And @tjtanaa, we might need to wait for Triton to be updated to 3.5.0 before merging this PR, or the Gluon version of fp8_paged_mqa_logits might simply fail to compile. Do you have any idea about the Triton update plan on the ROCm platform?


tjtanaa commented Dec 11, 2025

And @tjtanaa , we might need to wait triton update to 3.5.0 before merge this PR. Or the gluon version of fp8_paged_mqa_logits might just not be able to compile. Do you have any clue on the triton update plan on rocm platform?

@gshtras do you know if there are any plans for triton updates and also aiter updates?

Signed-off-by: ganyi <ygan@amd.com>
@ganyi1996ppo ganyi1996ppo force-pushed the ganyi/optimized_dsv3.2 branch from 5111929 to ee0a366 on January 21, 2026 02:25
@tjtanaa tjtanaa left a comment

I added some comments to expedite the review. I will add more in the next couple of hours.

"NHD",
block_tile_size,
head_tile_size,
IS_FNUZ=current_platform.fp8_dtype() == torch.float8_e4m3fnuz,
Collaborator

There is a helper function, current_platform.is_fp8_fnuz(); can you also help add a cache decorator to is_fp8_fnuz?

@ganyi1996ppo ganyi1996ppo Jan 21, 2026

I found that a cache decorator cannot be captured by torch.compile; maybe we can leave this one as is? The host overhead should be minor for big models like dsv3.2.

Collaborator

OK, sure.

return importlib.util.find_spec("aiter.ops.triton.fp8_mqa_logits") is not None

if rocm_aiter_ops.is_enabled() and has_mqa_logits_module():
from aiter.ops.triton.attention.fp8_mqa_logits import fp8_mqa_logits
Collaborator

In the AITER version used in Dockerfile.rocm_base, the kernel already exists; however, its path differs from the one used in the latest main.

The AITER code in Dockerfile.rocm_base stores the op at the path from aiter.ops.triton.fp8_mqa_logits import fp8_mqa_logits

Contributor Author

Please take a look again; I have wrapped them up to make the code compatible with both AITER versions.
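A compatibility wrapper like the one described can be a simple try-the-new-path, fall-back-to-the-old import helper. The sketch below is generic and not the PR's actual code; since AITER is not importable here, the usage example falls back to a stdlib module to show the mechanism.

```python
import importlib


def import_from_first(attr: str, *module_paths: str):
    """Return `attr` from the first module path that imports and defines it.

    Mirrors the AITER version shim: newer AITER moved kernels (e.g.
    aiter.ops.triton.fp8_mqa_logits to aiter.ops.triton.attention.*), so
    code tries the new path first and falls back to the old one.
    """
    for path in module_paths:
        try:
            module = importlib.import_module(path)
        except ImportError:
            # Missing package or submodule: try the next candidate path.
            continue
        if hasattr(module, attr):
            return getattr(module, attr)
    raise ImportError(f"{attr!r} not found in any of {module_paths}")


# Usage: the first path does not exist in this environment, so the helper
# falls back to the stdlib `math` module purely to demonstrate the pattern.
sqrt = import_from_first("sqrt", "aiter.ops.triton.attention.fp8_mqa_logits", "math")
```

The same helper resolves fp8_mqa_logits regardless of which AITER layout is installed, which is why both the Dockerfile.rocm_base version and latest main can work.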


if rocm_aiter_ops.is_enabled():
from aiter.ops.triton.pa_mqa_logits import deepgemm_fp8_paged_mqa_logits_stage1
from aiter.ops.triton.attention.pa_mqa_logits import (
Collaborator

The AITER used in the Dockerfile.rocm_base does not have this path yet. It should be

        from aiter.ops.triton.pa_mqa_logits import (
            deepgemm_fp8_paged_mqa_logits_stage1,
        )

Contributor Author

This one is compatible with both versions as well.

Signed-off-by: ganyi <ygan@amd.com>
@tjtanaa tjtanaa left a comment

LGTM.

@tjtanaa tjtanaa changed the title [ROCm][Deepseekv3.2][Perf] Performance optimization of deepseek v3.2 on AMD device [ROCm][Deepseekv3.2] Refactor Sparse Indexer as CustomOp Jan 21, 2026
@tjtanaa tjtanaa merged commit 6c20e89 into vllm-project:main Jan 21, 2026
65 checks passed
monajafi-amd pushed a commit to monajafi-amd/vllm that referenced this pull request Jan 23, 2026

Signed-off-by: ganyi <ygan@amd.com>
Signed-off-by: mohammad najafi <mohammad.najafi@amd.com>
cwazai pushed a commit to cwazai/vllm that referenced this pull request Jan 25, 2026

Signed-off-by: ganyi <ygan@amd.com>
Signed-off-by: 陈建华 <1647430658@qq.com>
lapy pushed a commit to lapy/vllm that referenced this pull request Jan 27, 2026
ItzDEXX pushed a commit to ItzDEXX/vllm that referenced this pull request Feb 19, 2026

Labels

deepseek Related to DeepSeek models ready ONLY add when PR is ready to merge/full CI is needed rocm Related to AMD ROCm v1


5 participants