
[torch.compile] Add support for non-contiguous fused RMSNorm + group quant#36551

Open
ProExpertProg wants to merge 11 commits into vllm-project:main from neuralmagic:luka/fix-rms-quant-non-contiguous

Conversation


ProExpertProg (Collaborator) commented Mar 9, 2026

Background

The fused rms_norm + group fp8 quant kernel only supports contiguous inputs. This is an issue in the DeepSeek case, because the norm input is a slice of the qkv_lora tensor:

vllm/model_executor/layers/mla.py:134-139:

q_c, kv_lora = qkv_lora.split(
    [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
    dim=-1,
)
q_c = self.q_a_layernorm(q_c)
q = self.q_b_proj(q_c)[0]

The current rms_norm + quant fusion with the rms_norm custom op disabled (the default) inserts redundant type conversions in between, which happen to prevent this error by default. With rms_norm enabled, however, the error below occurs. With vLLM IR (#32358), those redundant type conversions are gone, so the same error occurs there as well.

$ vllm serve deepseek-ai/DeepSeek-V3 -cc.custom_ops+=+rms_norm -tp=8
(Worker pid=1951505) (Worker_TP5 pid=1951505) ERROR 03-10 13:34:42 [multiproc_executor.py:932] RuntimeError: Expected out.is_contiguous() && input.is_contiguous() to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)

Changes

Add an input_stride argument to allow padded higher dimensions for the rms_quant input (i.e., a row stride larger than the hidden size), and add appropriate unit tests. Also add DeepSeek to the E2E fusion tests.
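To illustrate the problem and the fix, here is a minimal pure-Python sketch (using small illustrative sizes and a hypothetical `load_sliced` helper, not the actual kernel code): splitting qkv_lora along the last dim yields a view whose row stride is the parent row width, not the slice's hidden size, so a kernel that assumes stride == hidden_size reads the wrong elements.

```python
q_lora_rank, kv_width = 4, 3            # illustrative sizes, not DeepSeek's real dims
parent_width = q_lora_rank + kv_width   # row stride of any last-dim slice of qkv_lora
num_tokens = 2

# Flattened storage of a [num_tokens, parent_width] tensor
storage = list(range(num_tokens * parent_width))

def load_sliced(row, col, input_stride):
    """How a stride-aware kernel reads element [row, col] of the q_c slice."""
    return storage[row * input_stride + col]

# Assuming contiguity (stride == hidden_size) reads the wrong element:
assert load_sliced(1, 2, q_lora_rank) != storage[1 * parent_width + 2]
# Passing the true input_stride recovers the correct element:
assert load_sliced(1, 2, parent_width) == storage[1 * parent_width + 2]
```

This is why an explicit input_stride parameter, rather than an `is_contiguous()` check, lets the kernel consume last-dim slices directly.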

Test Plan

Validated locally, CI, lm_eval

Test Result

$ vllm serve qwen/qwen3-30b-a3b-fp8 
$ lm_eval --model local-completions --model_args pretrained=qwen/qwen3-30b-a3b-fp8,base_url=http://0.0.0.0:8000/v1/completions,num_concurrent=50,max_retries=3 --tasks gsm8k --num_fewshot 5 --batch_size auto --limit 100
| Tasks | Version | Filter           | n-shot | Metric      | Value | Stderr   |
|-------|---------|------------------|--------|-------------|-------|----------|
| gsm8k | 3       | flexible-extract | 5      | exact_match | 0.86  | ± 0.0349 |
|       |         | strict-match     | 5      | exact_match | 0.91  | ± 0.0288 |

lm-eval appears broken for DSv3 (#36662); a fix is in #36296. The results below include this PR.

$ vllm serve deepseek-ai/DeepSeek-V3
$ lm_eval --model local-completions --model_args pretrained=deepseek-ai/DeepSeek-V3,base_url=http://0.0.0.0:8000/v1/completions,num_concurrent=50,max_retries=3 --tasks gsm8k --num_fewshot 5 --batch_size auto

#36296:

vllm serve deepseek-ai/DeepSeek-V3 -tp=8

# no deepgemm
local-completions (pretrained=deepseek-ai/DeepSeek-V3,base_url=http://0.0.0.0:8000/v1/completions,num_concurrent=50,max_retries=3), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
| Tasks | Version | Filter           | n-shot | Metric      | Value  | Stderr   |
|-------|---------|------------------|--------|-------------|--------|----------|
| gsm8k | 3       | flexible-extract | 5      | exact_match | 0.9462 | ± 0.0062 |
|       |         | strict-match     | 5      | exact_match | 0.9462 | ± 0.0062 |
# with deepgemm
local-completions (pretrained=deepseek-ai/DeepSeek-V3,base_url=http://0.0.0.0:8000/v1/completions,num_concurrent=50,max_retries=3), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
| Tasks | Version | Filter           | n-shot | Metric      | Value | Stderr  |
|-------|---------|------------------|--------|-------------|-------|---------|
| gsm8k | 3       | flexible-extract | 5      | exact_match | 0.95  | ± 0.006 |
|       |         | strict-match     | 5      | exact_match | 0.95  | ± 0.006 |

#36296 + this PR:

vllm serve deepseek-ai/DeepSeek-V3 -tp=8

local-completions (pretrained=deepseek-ai/DeepSeek-V3,base_url=http://0.0.0.0:8000/v1/completions,num_concurrent=50,max_retries=3), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
| Tasks | Version | Filter           | n-shot | Metric      | Value  | Stderr   |
|-------|---------|------------------|--------|-------------|--------|----------|
| gsm8k | 3       | flexible-extract | 5      | exact_match | 0.9477 | ± 0.0061 |
|       |         | strict-match     | 5      | exact_match | 0.9477 | ± 0.0061 |
vllm serve deepseek-ai/DeepSeek-V3 -cc.custom_ops+=+rms_norm -tp=8

local-completions (pretrained=deepseek-ai/DeepSeek-V3,base_url=http://0.0.0.0:8000/v1/completions,num_concurrent=50,max_retries=3), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
| Tasks | Version | Filter           | n-shot | Metric      | Value  | Stderr  |
|-------|---------|------------------|--------|-------------|--------|---------|
| gsm8k | 3       | flexible-extract | 5      | exact_match | 0.9507 | ± 0.006 |
|       |         | strict-match     | 5      | exact_match | 0.9507 | ± 0.006 |

Just this PR:

| Tasks | Version | Filter           | n-shot | Metric      | Value | Stderr |
|-------|---------|------------------|--------|-------------|-------|--------|
| gsm8k | 3       | flexible-extract | 5      | exact_match | 0.01  | ± 0.01 |
|       |         | strict-match     | 5      | exact_match | 0.00  | ± 0.00 |

Perf

vllm bench serve --dataset-name random --ignore-eos --model=deepseek-ai/DeepSeek-V3 --num-prompts 120 --request-rate 1

#36296 + this PR:

vllm serve deepseek-ai/DeepSeek-V3 -tp=8

============ Serving Benchmark Result ============
Successful requests:                     120       
Failed requests:                         0         
Request rate configured (RPS):           1.00      
Benchmark duration (s):                  121.17    
Total input tokens:                      122760    
Total generated tokens:                  15360     
Request throughput (req/s):              0.99      
Output token throughput (tok/s):         126.77    
Peak output token throughput (tok/s):    462.00    
Peak concurrent requests:                9.00      
Total token throughput (tok/s):          1139.92   
---------------Time to First Token----------------
Mean TTFT (ms):                          106.44    
Median TTFT (ms):                        106.73    
P99 TTFT (ms):                           177.41    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          10.05     
Median TPOT (ms):                        9.72      
P99 TPOT (ms):                           13.91     
---------------Inter-token Latency----------------
Mean ITL (ms):                           10.05     
Median ITL (ms):                         9.15      
P99 ITL (ms):                            81.27     
==================================================

# Note that +rms_norm slows down rms_norm, so a perf hit is expected
# (even though fusion benefits from removing the cast that gets inserted when -rms_norm).
# vLLM IR will resolve these kinds of issues cleanly.
vllm serve deepseek-ai/DeepSeek-V3 -cc.custom_ops+=+rms_norm -tp=8

============ Serving Benchmark Result ============
Successful requests:                     120       
Failed requests:                         0         
Request rate configured (RPS):           1.00      
Benchmark duration (s):                  121.17    
Total input tokens:                      122760    
Total generated tokens:                  15360     
Request throughput (req/s):              0.99      
Output token throughput (tok/s):         126.77    
Peak output token throughput (tok/s):    452.00    
Peak concurrent requests:                9.00      
Total token throughput (tok/s):          1139.93   
---------------Time to First Token----------------
Mean TTFT (ms):                          107.80    
Median TTFT (ms):                        105.39    
P99 TTFT (ms):                           182.84    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          10.16     
Median TPOT (ms):                        9.82      
P99 TPOT (ms):                           14.03     
---------------Inter-token Latency----------------
Mean ITL (ms):                           10.16     
Median ITL (ms):                         9.27      
P99 ITL (ms):                            80.03     
==================================================

#36296:

vllm serve deepseek-ai/DeepSeek-V3 -tp=8

============ Serving Benchmark Result ============
Successful requests:                     120       
Failed requests:                         0         
Request rate configured (RPS):           1.00      
Benchmark duration (s):                  121.12    
Total input tokens:                      122760    
Total generated tokens:                  15360     
Request throughput (req/s):              0.99      
Output token throughput (tok/s):         126.82    
Peak output token throughput (tok/s):    468.00    
Peak concurrent requests:                9.00      
Total token throughput (tok/s):          1140.36   
---------------Time to First Token----------------
Mean TTFT (ms):                          192.34    
Median TTFT (ms):                        112.36    
P99 TTFT (ms):                           2195.44   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          10.36     
Median TPOT (ms):                        9.57      
P99 TPOT (ms):                           24.47     
---------------Inter-token Latency----------------
Mean ITL (ms):                           10.36     
Median ITL (ms):                         8.88      
P99 ITL (ms):                            84.25     
==================================================

# With deepgemm (vllm docker image)

============ Serving Benchmark Result ============
Successful requests:                     120       
Failed requests:                         0         
Request rate configured (RPS):           1.00      
Benchmark duration (s):                  121.09    
Total input tokens:                      122760    
Total generated tokens:                  15360     
Request throughput (req/s):              0.99      
Output token throughput (tok/s):         126.85    
Peak output token throughput (tok/s):    454.00    
Peak concurrent requests:                9.00      
Total token throughput (tok/s):          1140.63   
---------------Time to First Token----------------
Mean TTFT (ms):                          106.40    
Median TTFT (ms):                        107.15    
P99 TTFT (ms):                           182.38    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          9.52      
Median TPOT (ms):                        9.21      
P99 TPOT (ms):                           13.34     
---------------Inter-token Latency----------------
Mean ITL (ms):                           9.52      
Median ITL (ms):                         8.66      
P99 ITL (ms):                            84.35     
==================================================

mergify bot added the ci/build label Mar 9, 2026

gemini-code-assist bot (Contributor) left a comment


Code Review

This pull request adds support for non-contiguous inputs in the fused RMSNorm and group quantization kernels by introducing an input_stride parameter, motivated by the need to support models like DeepSeek. The implementation appears robust, with the input_stride logic propagated through the CUDA kernels and improved contiguity checks, but two high-severity issues were identified: integer truncation of the stride parameter, which can lead to out-of-bounds memory reads on large tensors, and potential misaligned memory access in the vectorized CUDA kernels when the stride is not a multiple of the vectorization factor, which can crash the kernel (denial of service). These concerns should be addressed before merging.

ProExpertProg marked this pull request as draft March 9, 2026 22:57
ProExpertProg changed the title from "[torch.compile] Add support for fused RMSNorm + group quant" to "[torch.compile] Add support for non-contiguous fused RMSNorm + group quant" Mar 10, 2026
ProExpertProg marked this pull request as ready for review March 10, 2026 14:10
ProExpertProg added the ready (ONLY add when PR is ready to merge/full CI is needed) label Mar 10, 2026
ProExpertProg and others added 7 commits March 10, 2026 11:53
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
…ic per-token quant kernels (vllm-project#36552)

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: ProExpertProg <11367180+ProExpertProg@users.noreply.github.com>
…it test

Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Comment on lines +214 to +218
TORCH_CHECK(hidden_size % 4 == 0,
            "Hidden size must be divisible by 4 for vectorized access");
TORCH_CHECK(input_stride % 4 == 0,
            "Input stride must be divisible by 4 for vectorized access");

Contributor


nit: maybe replace with hidden_size % group_size == 0 and group_size % 4 == 0? These are the constraints that led to always vectorizing the blockwise kernel in the first place

Collaborator Author


These constraints are actually required by the kernel in terms of vectorization though, right?

I'll add the group size check as well

Contributor


> These constraints are actually required by the kernel in terms of vectorization though, right?

Yep, you can keep these checks as they are if you think it makes them easier to read this way.
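The constraints discussed in this thread can be sketched as a hypothetical Python mirror of the kernel's checks (the function name, the group size of 128, and the DeepSeek-like dimensions below are illustrative assumptions, not taken from the PR):

```python
def check_vectorization_constraints(hidden_size: int, input_stride: int,
                                    group_size: int = 128) -> None:
    """Hypothetical Python mirror of the C++ TORCH_CHECKs.

    The reviewer's suggested form: hidden_size % group_size == 0 together
    with group_size % 4 == 0 implies hidden_size % 4 == 0, so either
    phrasing enforces the vectorized-access requirement.
    """
    assert hidden_size % group_size == 0, "hidden size must be a multiple of group size"
    assert group_size % 4 == 0, "group size must be divisible by 4 for vectorized access"
    assert input_stride % 4 == 0, "input stride must be divisible by 4 for vectorized access"

# Illustrative DeepSeek-V3-like shapes: hidden_size = q_lora_rank (1536),
# input_stride = q_lora_rank + kv_lora_rank + qk_rope_head_dim (1536 + 512 + 64)
check_vectorization_constraints(hidden_size=1536, input_stride=1536 + 512 + 64)
```

Either set of checks enforces the same alignment requirement; the stride check is the new piece this PR needs, since a sliced input's stride is independent of its hidden size.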

ElizaWszola (Contributor) commented:

The kernel and testing parts look good to me!

… opcheck from unit test & convert returns to skips

Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>

yewentao256 (Member) left a comment


Thanks for the work!
Could you remove --limit 100 in lm_eval so we can get the exact acc score?
Also, could you add bench metrics so that we don't hurt performance, or show how much perf we could gain?


yewentao256 (Member) left a comment


Thanks for the work!

Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>

yewentao256 (Member) left a comment


LGTM, thanks for the work!

ProExpertProg enabled auto-merge (squash) March 10, 2026 23:41