[torch.compile] Add support for non-contiguous fused RMSNorm + group quant #36551

ProExpertProg wants to merge 11 commits into vllm-project:main
Conversation
Code Review
This pull request introduces support for non-contiguous inputs in the fused RMSNorm and group quantization kernels by adding an input_stride parameter, motivated by the need to support models like Deepseek. While the implementation appears robust with input_stride logic propagated through CUDA kernels and improved safety checks for tensor contiguity, two high-severity security issues were identified. These include integer truncation of the stride parameter, which can lead to out-of-bounds memory reads on large tensors, and potential misaligned memory access in vectorized CUDA kernels when the stride is not a multiple of the vectorization factor, potentially causing kernel crashes (Denial of Service). Addressing these security concerns is critical before merging.
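The integer-truncation concern can be illustrated with a small pure-Python sketch (hypothetical sizes, not taken from the PR; the real kernel computes offsets in C++/CUDA, where a 32-bit offset would silently wrap):

```python
def as_int32(x: int) -> int:
    """Emulate C++ int32 two's-complement wraparound."""
    x &= 0xFFFFFFFF
    return x - (1 << 32) if x >= (1 << 31) else x

# Hypothetical large input: 1M rows with a padded row stride of 8192 elements.
num_tokens, input_stride = 1_000_000, 8192

# Correct 64-bit element offset of the last row.
offset64 = (num_tokens - 1) * input_stride

# What a kernel that truncates the stride arithmetic to int32 would compute:
# the product exceeds 2^31, wraps negative, and indexes out of bounds.
offset32 = as_int32(offset64)
```

Under these numbers `offset32` comes out negative, which is exactly the out-of-bounds-read scenario the review flags for large tensors.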
csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
…ic per-token quant kernels (vllm-project#36552) Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: ProExpertProg <11367180+ProExpertProg@users.noreply.github.com>
…it test Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Force-pushed from 3bcc63d to 0ebf4e9 (compare)
```cpp
TORCH_CHECK(hidden_size % 4 == 0,
            "Hidden size must be divisible by 4 for vectorized access");
TORCH_CHECK(input_stride % 4 == 0,
            "Input stride must be divisible by 4 for vectorized access");
```
nit: maybe replace with `hidden_size % group_size == 0` and `group_size % 4 == 0`? These are the constraints that led to always vectorizing the blockwise kernel in the first place.
These constraints are actually required by the kernel in terms of vectorization though, right?
I'll add the group size check as well
> These constraints are actually required by the kernel in terms of vectorization though, right?

Yep, you can keep these checks as they are if you think it makes them easier to read this way.
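As a quick sanity sketch (in Python, with made-up sizes), the reviewer-suggested pair of checks does imply the existing `hidden_size % 4` check, since divisibility is transitive; the `input_stride % 4` check is independent and would still be needed:

```python
def checks_pass(hidden_size: int, group_size: int) -> bool:
    # Reviewer-suggested invariants for the blockwise kernel (sketch, not the
    # actual TORCH_CHECK code in the PR).
    return hidden_size % group_size == 0 and group_size % 4 == 0

# Whenever both invariants hold, the original vectorization constraint follows:
# group_size divides hidden_size and 4 divides group_size, so 4 divides hidden_size.
for hidden_size, group_size in [(128, 16), (4096, 128), (96, 8)]:
    if checks_pass(hidden_size, group_size):
        assert hidden_size % 4 == 0
```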
The kernel and testing parts look good to me!
… opcheck from unit test & convert returns to skips Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu (outdated)
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
yewentao256 left a comment
LGTM, thanks for the work!
Background
The fused `rms_norm + group fp8 quant` kernel only supports contiguous inputs. This is an issue in the Deepseek case, because the norm input is a slice of the `qkv_lora` tensor (`vllm/model_executor/layers/mla.py:134-139`).

The current `rms_norm + quant` fusion with `rms_norm` disabled (the default) inserts redundant type conversions in between, which prevent this error from happening by default. However, with `rms_norm` enabled, an error occurs (below). With vLLM IR (#32358), these redundant type conversions are gone, so the same error occurs there as well.
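The non-contiguity in question can be sketched with plain Python stride arithmetic (illustrative sizes, not the real model dimensions): slicing the first `hidden_size` columns out of a wider row-major buffer leaves a view whose rows are `input_stride` elements apart, which is why the kernel needs the stride as an explicit argument.

```python
# A row-major parent buffer of logical shape (num_tokens, input_stride); the
# norm input is the first hidden_size columns of each row, so consecutive rows
# of the sliced view are NOT adjacent in memory.
num_tokens, input_stride, hidden_size = 3, 6, 4
parent = [r * input_stride + c for r in range(num_tokens) for c in range(input_stride)]

def load(row: int, col: int) -> int:
    # Kernel-style addressing: advance rows by input_stride, not hidden_size.
    return parent[row * input_stride + col]

# Materialize the non-contiguous view row by row.
view = [[load(r, c) for c in range(hidden_size)] for r in range(num_tokens)]
```

A kernel that assumed contiguous rows (i.e. used `hidden_size` as the row pitch) would read the padding elements of earlier rows instead of the next row's data.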
Changes

Add an `input_stride` arg to allow for padded higher dims for the `rms_quant` input, and add appropriate unit tests. Also add Deepseek to the E2E fusion tests.

Test Plan
Validated locally, CI, lm_eval
Test Result
lm-eval appears broken for DSv3 (#36662), with a fix in #36296. Results below:
#36296:
#36296 + this PR:
Just this PR:
Perf
#36296 + this PR:
#36296: