From 02a7fabdcaabc28f20735e2de4b8bc2eff621786 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Mon, 9 Mar 2026 16:45:50 -0400 Subject: [PATCH 01/11] Add support for non-contiguous input for rms-quant (dynamic & block) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- ...fused_layernorm_dynamic_per_token_quant.cu | 49 ++++++++++++------- .../fused_kernels/layernorm_utils.cuh | 31 ++++++------ 2 files changed, 46 insertions(+), 34 deletions(-) diff --git a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu index b9a9b5cc7e43..bff593af64bb 100644 --- a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu +++ b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu @@ -15,13 +15,13 @@ __device__ void rms_norm_dynamic_per_token_quant_vec( scalar_t const* __restrict__ input, // [..., hidden_size] scalar_t const* __restrict__ weight, // [hidden_size] float const* scale_ub, float const var_epsilon, int32_t const hidden_size, - scalar_t* __restrict__ residual = nullptr) { + int32_t const input_stride, scalar_t* __restrict__ residual = nullptr) { float rms = 0.0f; float token_scale = 0.0f; // Compute rms vllm::vectorized::compute_rms( - &rms, input, hidden_size, var_epsilon, residual); + &rms, input, hidden_size, input_stride, var_epsilon, residual); // Compute scale vllm::vectorized::compute_dynamic_per_token_scales) { token_scale = 1.0f / token_scale; vllm::vectorized::norm_and_quant( - out, input, weight, rms, &token_scale, hidden_size, residual); + has_residual>(out, input, weight, rms, + &token_scale, hidden_size, + input_stride, residual); } else { // FP8 - Do not invert token_scale for exact match with FBGemm vllm::vectorized::norm_and_quant( - out, input, weight, rms, &token_scale, hidden_size, residual); + has_residual>(out, input, weight, rms, + &token_scale, hidden_size, + input_stride, residual); } } @@ -51,7 +53,7 @@ __global__ void rms_norm_dynamic_per_token_quant_kernel( scalar_t const* __restrict__ input, // [..., hidden_size] scalar_t const* __restrict__ weight, // [hidden_size] float const* scale_ub, float const var_epsilon, int32_t const hidden_size, - scalar_t* __restrict__ residual = nullptr) { + int32_t const input_stride, scalar_t* __restrict__ residual = nullptr) { // For vectorization, token_input and token_output pointers need to be // aligned at 8-byte and 4-byte addresses respectively. bool const can_vectorize = hidden_size % 4 == 0; @@ -60,15 +62,15 @@ __global__ void rms_norm_dynamic_per_token_quant_kernel( return rms_norm_dynamic_per_token_quant_vec( out, scales, input, weight, scale_ub, var_epsilon, hidden_size, - residual); + input_stride, residual); } float rms = 0.0f; float token_scale = 0.0f; // Compute RMS - vllm::compute_rms(&rms, input, hidden_size, - var_epsilon, residual); + vllm::compute_rms( + &rms, input, hidden_size, input_stride, var_epsilon, residual); // Compute Scale vllm::compute_dynamic_per_token_scales( &token_scale, scales, input, weight, rms, scale_ub, hidden_size, @@ -78,11 +80,13 @@ __global__ void rms_norm_dynamic_per_token_quant_kernel( if constexpr (std::is_same_v) { token_scale = 1.0f / token_scale; vllm::norm_and_quant( - out, input, weight, rms, &token_scale, hidden_size, residual); + out, input, weight, rms, &token_scale, hidden_size, input_stride, + residual); } else { // FP8 - Do not invert s_token_scale for exact match with FBGemm vllm::norm_and_quant( - out, input, weight, rms, &token_scale, hidden_size, residual); + out, input, weight, rms, &token_scale, hidden_size, input_stride, + residual); } } @@ -97,12 +101,13 @@ __global__ void rms_norm_per_block_quant_kernel( scalar_t const* __restrict__ input, // [..., hidden_size] scalar_t const* __restrict__ weight, // [hidden_size] float const* scale_ub, float const var_epsilon, int32_t const hidden_size, - scalar_t* __restrict__ residual = nullptr, int64_t outer_scale_stride = 1) { + int32_t const input_stride, scalar_t* __restrict__ residual = nullptr, + int64_t outer_scale_stride = 1) { float rms; // Compute RMS // Always able to vectorize due to constraints on hidden_size vllm::vectorized::compute_rms( - &rms, input, hidden_size, var_epsilon, residual); + &rms, input, hidden_size, input_stride, var_epsilon, residual); // Compute Scale // Always able to vectorize due to constraints on hidden_size and group_size @@ -120,7 +125,7 @@ __global__ void rms_norm_per_block_quant_kernel( vllm::vectorized::norm_and_quant< scalar_t, scalar_out_t, std::is_same_v, has_residual, is_scale_transposed, group_size>( - out, input, weight, rms, scales, hidden_size, residual, + out, input, weight, rms, scales, hidden_size, input_stride, residual, outer_scale_stride); } @@ -137,6 +142,7 @@ void rms_norm_dynamic_per_token_quant_dispatch( std::optional const& scale_ub, std::optional& residual) { int32_t hidden_size = input.size(-1); + int32_t input_stride = input.view({-1, hidden_size}).stride(0); auto num_tokens = input.numel() / hidden_size; dim3 grid(num_tokens); @@ -153,7 +159,7 @@ void rms_norm_dynamic_per_token_quant_dispatch( out.data_ptr(), scales.data_ptr(), input.data_ptr(), weight.data_ptr(), scale_ub.has_value() ? scale_ub->data_ptr() : nullptr, - var_epsilon, hidden_size, + var_epsilon, hidden_size, input_stride, has_residual ? residual->data_ptr() : nullptr); }); }); @@ -170,7 +176,9 @@ void rms_norm_dynamic_per_token_quant( ? c10::ScalarType::Float8_e4m3fn : c10::ScalarType::Float8_e4m3fnuz; TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8); - TORCH_CHECK(out.is_contiguous() && input.is_contiguous()); + TORCH_CHECK(out.is_contiguous()); + TORCH_CHECK(input.stride(-1) == 1, + "Input must be contiguous in the last dimension"); if (scale_ub.has_value()) { TORCH_CHECK(out.dtype() == kFp8Type); @@ -200,6 +208,7 @@ void rms_norm_per_block_quant_dispatch( std::optional const& scale_ub, std::optional& residual, bool is_scale_transposed) { int32_t hidden_size = input.size(-1); + int32_t input_stride = input.view({-1, hidden_size}).stride(0); auto num_tokens = input.numel() / hidden_size; dim3 grid(num_tokens); @@ -225,7 +234,7 @@ void rms_norm_per_block_quant_dispatch( weight.data_ptr(), scale_ub.has_value() ? scale_ub->data_ptr() : nullptr, - var_epsilon, hidden_size, + var_epsilon, hidden_size, input_stride, has_residual ? residual->data_ptr() : nullptr, scales.stride(1)); @@ -246,7 +255,9 @@ void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input, ? c10::ScalarType::Float8_e4m3fn : c10::ScalarType::Float8_e4m3fnuz; TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8); - TORCH_CHECK(out.is_contiguous() && input.is_contiguous()); + TORCH_CHECK(out.is_contiguous()); + TORCH_CHECK(input.stride(-1) == 1, + "Input must be contiguous in the last dimension"); if (scale_ub.has_value()) { TORCH_CHECK(out.dtype() == kFp8Type); diff --git a/csrc/quantization/fused_kernels/layernorm_utils.cuh b/csrc/quantization/fused_kernels/layernorm_utils.cuh index edf4024f0d49..0397c13d340a 100644 --- a/csrc/quantization/fused_kernels/layernorm_utils.cuh +++ b/csrc/quantization/fused_kernels/layernorm_utils.cuh @@ -16,9 +16,10 @@ namespace vllm { // has_residual must be true, if residual is not a nullptr template __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input, - int32_t const hidden_size, float const epsilon, + int32_t const hidden_size, + int32_t const input_stride, float const epsilon, scalar_t const* __restrict__ residual = nullptr) { - int64_t const token_offset = blockIdx.x * static_cast(hidden_size); + int64_t const token_offset = blockIdx.x * static_cast(input_stride); // sum of squares float ss = 0.0f; @@ -185,9 +186,10 @@ template (hidden_size); + int32_t const hidden_size, int32_t const input_stride, + scalar_t* __restrict__ residual = nullptr, int32_t const group_size = 0, + int64_t outer_scale_stride = 1) { + int64_t const token_offset = blockIdx.x * static_cast(input_stride); for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) { float x = static_cast(input[token_offset + i]); @@ -224,9 +226,10 @@ namespace vectorized { // hidden_size must be a multiple of 4 template __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input, - int32_t const hidden_size, float const epsilon, + int32_t const hidden_size, + int32_t const input_stride, float const epsilon, scalar_t const* __restrict__ residual = nullptr) { - int64_t const token_offset = blockIdx.x * static_cast(hidden_size); + int64_t const token_offset = blockIdx.x * static_cast(input_stride); // Vectorized input/output to better utilize memory bandwidth. vec4_t const* vec_input = @@ -462,14 +465,12 @@ __device__ void compute_dynamic_per_token_scales( template -__device__ void norm_and_quant(scalar_out_t* __restrict__ output, - scalar_t const* __restrict__ input, - scalar_t const* __restrict__ weight, - float const rms, float* const scale, - int32_t const hidden_size, - scalar_t* __restrict__ residual = nullptr, - int64_t outer_scale_stride = 1) { - int64_t const token_offset = blockIdx.x * static_cast(hidden_size); +__device__ void norm_and_quant( + scalar_out_t* __restrict__ output, scalar_t const* __restrict__ input, + scalar_t const* __restrict__ weight, float const rms, float* const scale, + int32_t const hidden_size, int32_t const input_stride, + scalar_t* __restrict__ residual = nullptr, int64_t outer_scale_stride = 1) { + int64_t const token_offset = blockIdx.x * static_cast(input_stride); // Vectorized input/output/weight/residual to better utilize memory bandwidth. vec4_t const* vec_input = From 5378e99edce701f76335b9de390d7b5e9203e758 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Mon, 9 Mar 2026 17:32:18 -0400 Subject: [PATCH 02/11] Fix kernel offset calculation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- ...fused_layernorm_dynamic_per_token_quant.cu | 9 +-- .../fused_kernels/layernorm_utils.cuh | 56 ++++++++++++------- 2 files changed, 41 insertions(+), 24 deletions(-) diff --git a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu index bff593af64bb..09bb89fec2d8 100644 --- a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu +++ b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu @@ -27,7 +27,7 @@ __device__ void rms_norm_dynamic_per_token_quant_vec( vllm::vectorized::compute_dynamic_per_token_scales( &token_scale, scales, input, weight, rms, scale_ub, hidden_size, - residual); + input_stride, residual); // RMS Norm + Quant if constexpr (std::is_same_v) { @@ -74,7 +74,7 @@ __global__ void rms_norm_dynamic_per_token_quant_kernel( // Compute Scale vllm::compute_dynamic_per_token_scales( &token_scale, scales, input, weight, rms, scale_ub, hidden_size, - residual); + input_stride, residual); // RMS Norm + Quant if constexpr (std::is_same_v) { @@ -113,8 +113,8 @@ __global__ void rms_norm_per_block_quant_kernel( // Always able to vectorize due to constraints on hidden_size and group_size vllm::vectorized::compute_dynamic_per_token_scales< scalar_t, scalar_out_t, has_residual, is_scale_transposed, group_size>( - nullptr, scales, input, weight, rms, scale_ub, hidden_size, residual, - outer_scale_stride); + nullptr, scales, input, weight, rms, scale_ub, hidden_size, input_stride, + residual, outer_scale_stride); // RMS Norm + Quant // Always able to vectorize due to constraints on hidden_size @@ -266,6 +266,7 @@ void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input, TORCH_CHECK(scales.dtype() == torch::kFloat32); if (residual) { TORCH_CHECK(residual->scalar_type() == input.scalar_type()); + TORCH_CHECK(residual->is_contiguous()); } TORCH_CHECK(group_size == 128 || group_size == 64, diff --git a/csrc/quantization/fused_kernels/layernorm_utils.cuh b/csrc/quantization/fused_kernels/layernorm_utils.cuh index 0397c13d340a..1f0d583523c8 100644 --- a/csrc/quantization/fused_kernels/layernorm_utils.cuh +++ b/csrc/quantization/fused_kernels/layernorm_utils.cuh @@ -19,12 +19,14 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input, int32_t const hidden_size, int32_t const input_stride, float const epsilon, scalar_t const* __restrict__ residual = nullptr) { - int64_t const token_offset = blockIdx.x * static_cast(input_stride); + int64_t const input_token_offset = + blockIdx.x * static_cast(input_stride); + int64_t const token_offset = blockIdx.x * static_cast(hidden_size); // sum of squares float ss = 0.0f; for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) { - float x = static_cast(input[token_offset + i]); + float x = static_cast(input[input_token_offset + i]); if constexpr (has_residual) { x += static_cast(residual[token_offset + i]); } @@ -74,15 +76,20 @@ __device__ void compute_dynamic_per_token_scales( float* __restrict__ token_scale, float* __restrict__ all_token_scales, scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight, float const rms, float const* __restrict__ scale_ub, - int32_t const hidden_size, scalar_t const* __restrict__ residual = nullptr, + int32_t const hidden_size, int32_t const input_stride, + scalar_t const* __restrict__ residual = nullptr, int32_t const group_size = 0, int64_t outer_scale_stride = 1) { float block_absmax_val_maybe = 0.0f; constexpr scalar_out_t qmax{quant_type_max_v}; __syncthreads(); + + int64_t const input_token_offset = + blockIdx.x * static_cast(input_stride); + int64_t const token_offset = blockIdx.x * static_cast(hidden_size); + if (group_size > 0) { - __shared__ float s_max_vals[1024]; - int64_t const token_offset = blockIdx.x * static_cast(hidden_size); int64_t num_groups = hidden_size / group_size; + __shared__ float s_max_vals[1024]; int64_t const threads_per_group = blockDim.x / num_groups; int64_t const thread_in_group = threadIdx.x % threads_per_group; int64_t const group_offset = threadIdx.x / threads_per_group * group_size; @@ -90,7 +97,7 @@ __device__ void compute_dynamic_per_token_scales( int64_t const thread_end = min(group_offset + group_size, static_cast(hidden_size)); for (auto i = thread_offset; i < thread_end; i += threads_per_group) { - float x = static_cast(input[token_offset + i]); + float x = static_cast(input[input_token_offset + i]); if constexpr (has_residual) { x += static_cast(residual[token_offset + i]); } @@ -145,10 +152,8 @@ __device__ void compute_dynamic_per_token_scales( } __syncthreads(); } else { - int64_t const token_offset = blockIdx.x * static_cast(hidden_size); - for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) { - float x = static_cast(input[token_offset + i]); + float x = static_cast(input[input_token_offset + i]); if constexpr (has_residual) { x += static_cast(residual[token_offset + i]); } @@ -189,10 +194,12 @@ __device__ void norm_and_quant( int32_t const hidden_size, int32_t const input_stride, scalar_t* __restrict__ residual = nullptr, int32_t const group_size = 0, int64_t outer_scale_stride = 1) { - int64_t const token_offset = blockIdx.x * static_cast(input_stride); + int64_t const input_token_offset = + blockIdx.x * static_cast(input_stride); + int64_t const token_offset = blockIdx.x * static_cast(hidden_size); for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) { - float x = static_cast(input[token_offset + i]); + float x = static_cast(input[input_token_offset + i]); if constexpr (has_residual) { x += static_cast(residual[token_offset + i]); residual[token_offset + i] = static_cast(x); @@ -229,11 +236,13 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input, int32_t const hidden_size, int32_t const input_stride, float const epsilon, scalar_t const* __restrict__ residual = nullptr) { - int64_t const token_offset = blockIdx.x * static_cast(input_stride); + int64_t const input_token_offset = + blockIdx.x * static_cast(input_stride); + int64_t const token_offset = blockIdx.x * static_cast(hidden_size); // Vectorized input/output to better utilize memory bandwidth. vec4_t const* vec_input = - reinterpret_cast const*>(&input[token_offset]); + reinterpret_cast const*>(&input[input_token_offset]); vec4_t const* vec_residual = nullptr; if constexpr (has_residual) { vec_residual = @@ -291,7 +300,8 @@ __device__ void compute_dynamic_per_token_scales( float* __restrict__ token_scale, float* __restrict__ all_token_scales, scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight, float const rms, float const* __restrict__ scale_ub, - int32_t const hidden_size, scalar_t const* __restrict__ residual = nullptr, + int32_t const hidden_size, int32_t const input_stride, + scalar_t const* __restrict__ residual = nullptr, int64_t outer_scale_stride = 1) { constexpr scalar_out_t qmax{quant_type_max_v}; @@ -303,10 +313,13 @@ __device__ void compute_dynamic_per_token_scales( vec4_t const* vec_weight = nullptr; vec4_t const* vec_residual = nullptr; + int64_t const input_token_offset = + blockIdx.x * static_cast(input_stride); + int64_t const token_offset = blockIdx.x * static_cast(hidden_size); + if constexpr (group_size > 0) { __shared__ float s_max_vals[1024]; - int64_t const token_offset = blockIdx.x * static_cast(hidden_size); int64_t const num_groups = hidden_size / group_size; int64_t const threads_per_group = blockDim.x / num_groups; int64_t const thread_in_group = threadIdx.x % threads_per_group; @@ -315,7 +328,8 @@ __device__ void compute_dynamic_per_token_scales( int64_t const thread_offset = group_offset + thread_in_group; int64_t const thread_end = min(group_offset + (group_size >> 2), static_cast(hidden_size >> 2)); - vec_input = reinterpret_cast const*>(&input[token_offset]); + vec_input = + reinterpret_cast const*>(&input[input_token_offset]); vec_weight = reinterpret_cast const*>(weight); if constexpr (has_residual) { vec_residual = @@ -399,8 +413,8 @@ __device__ void compute_dynamic_per_token_scales( __syncthreads(); } else { - int64_t const token_offset = blockIdx.x * static_cast(hidden_size); - vec_input = reinterpret_cast const*>(&input[token_offset]); + vec_input = + reinterpret_cast const*>(&input[input_token_offset]); vec_weight = reinterpret_cast const*>(weight); if constexpr (has_residual) { vec_residual = @@ -470,11 +484,13 @@ __device__ void norm_and_quant( scalar_t const* __restrict__ weight, float const rms, float* const scale, int32_t const hidden_size, int32_t const input_stride, scalar_t* __restrict__ residual = nullptr, int64_t outer_scale_stride = 1) { - int64_t const token_offset = blockIdx.x * static_cast(input_stride); + int64_t const input_token_offset = + blockIdx.x * static_cast(input_stride); + int64_t const token_offset = blockIdx.x * static_cast(hidden_size); // Vectorized input/output/weight/residual to better utilize memory bandwidth. vec4_t const* vec_input = - reinterpret_cast const*>(&input[token_offset]); + reinterpret_cast const*>(&input[input_token_offset]); vec4_t const* vec_weight = reinterpret_cast const*>(weight); q8x4_t* vec_output = From ab2e78d8b735a069eb8035b23bfc6332b1dd492d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Mon, 9 Mar 2026 17:36:49 -0400 Subject: [PATCH 03/11] Add deepseek to tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/fusions_e2e/conftest.py | 10 ++++++ tests/compile/fusions_e2e/models.py | 36 ++++++++++++++++++++ tests/compile/fusions_e2e/test_tp1_quant.py | 21 +++++++----- tests/compile/fusions_e2e/test_tp2_ar_rms.py | 13 ++++--- 4 files changed, 68 insertions(+), 12 deletions(-) diff --git a/tests/compile/fusions_e2e/conftest.py b/tests/compile/fusions_e2e/conftest.py index 29eb8425183c..873f92cfe6ce 100644 --- a/tests/compile/fusions_e2e/conftest.py +++ b/tests/compile/fusions_e2e/conftest.py @@ -72,6 +72,16 @@ def run( rocm_aiter_ops.refresh_env_variables() + # Filter here to reduce code duplication + requires_mla = "deepseek" in model_name.lower() + is_mla = "mla" in attn_backend.backend.name.lower() + + if requires_mla != is_mla: + pytest.skip( + f"Incompatible model '{model_name}' and " + f"attention backend '{attn_backend.backend.name}'" + ) + # Disable, compile cache to make sure custom passes run. # Otherwise, we can't verify fusion happened through the logs. monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") diff --git a/tests/compile/fusions_e2e/models.py b/tests/compile/fusions_e2e/models.py index e18bc1ee5652..526892a5bf1b 100644 --- a/tests/compile/fusions_e2e/models.py +++ b/tests/compile/fusions_e2e/models.py @@ -44,6 +44,20 @@ ), ) +FLASHINFER_MLA_ATTN = pytest.param( + AttentionBackendCase(backend=AttentionBackendEnum.FLASHINFER_MLA), + id="FLASHINFER_MLA", + marks=pytest.mark.skipif( + not is_blackwell() or not has_flashinfer(), + reason="FI backend requires Blackwell and FlashInfer", + ), +) + +TRITON_MLA_ATTN = pytest.param( + AttentionBackendCase(backend=AttentionBackendEnum.TRITON_MLA), + id="TRITON_MLA", +) + # Models llama3_8b = ModelFusionInfo( model_name="meta-llama/Llama-3.1-8B-Instruct", @@ -126,3 +140,25 @@ async_tp=n_layers * 2, ), ) + +deepseek_v3_fp8 = ModelFusionInfo( + model_name="deepseek-ai/DeepSeek-V3", + matches=lambda n_layers: Matches( + # 3 per dense layer (first 3): + # - input_rms + qkv_proj + # - q_a_layernorm + q_b_proj (inside MLA wrapper) + # - post_attn_layernorm + MLP + # 2 per MoE layer (remaining) due to MoE wrapping + rms_quant_fusion=n_layers * 2 + + min(3, n_layers), # add extra for 3 dense layers + # TODO silu+block quant + # act_quant_fusion=min(3, n_layers), # dense layers only + # MLA attn + quant not supported yet: + # https://github.com/vllm-project/vllm/issues/35792 + attn_quant_fusion=0, + ar_rms_fusion=n_layers * 2 + 1, + # TODO + # sequence_parallel= n_layers * 2 + 1, + # async_tp=n_layers * 2, + ), +) diff --git a/tests/compile/fusions_e2e/test_tp1_quant.py b/tests/compile/fusions_e2e/test_tp1_quant.py index 917116515f89..8895dadcecc9 100644 --- a/tests/compile/fusions_e2e/test_tp1_quant.py +++ b/tests/compile/fusions_e2e/test_tp1_quant.py @@ -17,9 +17,12 @@ ) from .models import ( FLASHINFER_ATTN, + FLASHINFER_MLA_ATTN, ROCM_AITER_UNIFIED_ATTN, ROCM_ATTN, TRITON_ATTN, + TRITON_MLA_ATTN, + deepseek_v3_fp8, llama3_8b_fp4, llama3_8b_fp8, llama4_scout_fp4, @@ -33,6 +36,9 @@ [ (*llama3_8b_fp8, False), (*qwen3_a3b_fp8, False), + (*qwen3_a3b_fp8, True), + (*deepseek_v3_fp8, False), + (*deepseek_v3_fp8, True), pytest.param( *llama4_scout_fp8, False, @@ -41,13 +47,6 @@ reason="Llama4 Scout FP8 only supported on CUDA", ), ), - pytest.param( - *qwen3_a3b_fp8, - True, - marks=pytest.mark.skipif( - not current_platform.is_cuda(), reason="DeepGemm only supported on CUDA" - ), - ), ], ) @pytest.mark.parametrize( @@ -57,6 +56,8 @@ FLASHINFER_ATTN, ROCM_ATTN, ROCM_AITER_UNIFIED_ATTN, + FLASHINFER_MLA_ATTN, + TRITON_MLA_ATTN, ], ) @pytest.mark.parametrize("n_layers", [6]) @@ -75,6 +76,9 @@ def test_tp1_fp8_fusions( run_e2e_fusion_test, monkeypatch, ): + if use_deepgemm and not current_platform.is_cuda(): + pytest.skip("DeepGemm only supported on CUDA") + if use_deepgemm and is_flashinfer_fp8_blockscale_gemm_supported(): # Flashinfer block FP8 GEMM has internal quantization, so it can't # be fused with other ops. @@ -86,7 +90,8 @@ def test_tp1_fp8_fusions( matches = matches_fn(n_layers) - if "qwen" in model_name.lower() and "-quant_fp8" in custom_ops: + block_fp8 = "qwen" in model_name.lower() or "deepseek" in model_name.lower() + if block_fp8 and "-quant_fp8" in custom_ops: # This is why config forces +quant_fp8 by default pytest.skip("native QuantFP8 matching not supported for group quant") diff --git a/tests/compile/fusions_e2e/test_tp2_ar_rms.py b/tests/compile/fusions_e2e/test_tp2_ar_rms.py index ab4aefcaf79a..8ffadbfaf298 100644 --- a/tests/compile/fusions_e2e/test_tp2_ar_rms.py +++ b/tests/compile/fusions_e2e/test_tp2_ar_rms.py @@ -17,7 +17,9 @@ ) from .models import ( FLASHINFER_ATTN, + FLASHINFER_MLA_ATTN, TRITON_ATTN, + deepseek_v3_fp8, llama3_8b, llama3_8b_fp4, llama3_8b_fp8, @@ -33,10 +35,12 @@ @multi_gpu_test(num_gpus=2) @pytest.mark.parametrize( "model_name, matches_fn, model_kwargs, hf_overrides", - # qwen3-fp8 should still fuse AR+rms even though group quant is not yet supported - [llama3_8b_fp8, llama4_scout_fp8, qwen3_a3b_fp8], + # qwen3 & dsv3 should still fuse AR+rms even though group quant is not yet supported + [llama3_8b_fp8, llama4_scout_fp8, qwen3_a3b_fp8, deepseek_v3_fp8], +) +@pytest.mark.parametrize( + "attn_backend", [TRITON_ATTN, FLASHINFER_ATTN, FLASHINFER_MLA_ATTN] ) -@pytest.mark.parametrize("attn_backend", [TRITON_ATTN, FLASHINFER_ATTN]) @pytest.mark.parametrize("n_layers", [4]) @pytest.mark.parametrize("custom_ops", custom_ops_combos("quant_fp8", "rms_norm")) @pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION) @@ -54,7 +58,8 @@ def test_tp2_ar_rms_fp8_fusions( ): matches = matches_fn(n_layers) - if "qwen" in model_name.lower() and "-quant_fp8" in custom_ops: + block_fp8 = "qwen" in model_name.lower() or "deepseek" in model_name.lower() + if block_fp8 and "-quant_fp8" in custom_ops: # This is why config forces +quant_fp8 by default pytest.skip("native QuantFP8 matching not supported for group quant") From f7769d6e34b94a8c4695257c040c582ae4d2ba6f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Mon, 9 Mar 2026 17:55:40 -0400 Subject: [PATCH 04/11] Add residual contiguous assert check MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- .../fused_kernels/fused_layernorm_dynamic_per_token_quant.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu index 09bb89fec2d8..723b903b6c20 100644 --- a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu +++ b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu @@ -187,6 +187,7 @@ void rms_norm_dynamic_per_token_quant( TORCH_CHECK(scales.dtype() == torch::kFloat32); if (residual) { TORCH_CHECK(residual->scalar_type() == input.scalar_type()); + TORCH_CHECK(residual->is_contiguous()); } VLLM_DISPATCH_FLOATING_TYPES( From 90c46ea6021bcb1406e91acf6ffe83810065d279 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Mon, 9 Mar 2026 17:59:43 -0400 Subject: [PATCH 05/11] Add E2E tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- .buildkite/test_areas/compile.yaml | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/.buildkite/test_areas/compile.yaml b/.buildkite/test_areas/compile.yaml index f9eccdcbbeee..5da7b64ac304 100644 --- a/.buildkite/test_areas/compile.yaml +++ b/.buildkite/test_areas/compile.yaml @@ -101,8 +101,8 @@ steps: - nvidia-smi # Run all models and attn backends but only Inductor partition and native custom ops - pytest -v -s tests/compile/fusions_e2e/test_tp1_quant.py -k "inductor_partition and not +rms_norm and not +quant_fp8" - # Qwen requires +quant_fp8 as -quant_fp8 rms+quant fusion is not supported - - pytest -v -s tests/compile/fusions_e2e/test_tp1_quant.py -k "inductor_partition and not +rms_norm and +quant_fp8 and qwen3" + # Qwen/Deepseek requires +quant_fp8 as -quant_fp8 rms+quant fusion is not supported + - pytest -v -s tests/compile/fusions_e2e/test_tp1_quant.py -k "inductor_partition and not +rms_norm and +quant_fp8 and (qwen3 or deepseek)" - label: Fusion E2E Config Sweep (H100) timeout_in_minutes: 30 @@ -132,9 +132,9 @@ steps: commands: - nvidia-smi # Run all models but only FLASHINFER, Inductor partition and native custom ops - # Qwen requires +quant_fp8 as -quant_fp8 rms+quant fusion is not supported + # Qwen/Deepseek requires +quant_fp8 as -quant_fp8 rms+quant fusion is not supported # Run just llama3 (fp8 & fp4) for all config combinations (only inductor partition) - - pytest -v -s tests/compile/fusions_e2e/test_tp1_quant.py -k "inductor_partition and (FLASHINFER and not +rms_norm and (not +quant_fp8 or +quant_fp8 and qwen3) or llama-3)" + - pytest -v -s tests/compile/fusions_e2e/test_tp1_quant.py -k "inductor_partition and (FLASHINFER and not +rms_norm and (not +quant_fp8 or +quant_fp8 and (qwen3 or deepseek)) or llama-3)" - label: Fusion E2E TP2 Quick (H100) timeout_in_minutes: 20 @@ -150,8 +150,8 @@ steps: commands: - nvidia-smi # Run all models and attn backends but only Inductor partition and native custom ops - - pytest -v -s tests/compile/fusions_e2e/test_tp2_ar_rms.py -k "inductor_partition and not +rms_norm and not +quant_fp8" - - pytest -v -s tests/compile/fusions_e2e/test_tp2_async_tp.py -k "inductor_partition and not +rms_norm and not +quant_fp8" + - pytest -v -s tests/compile/fusions_e2e/test_tp2_ar_rms.py -k "inductor_partition and not +rms_norm and (not +quant_fp8 or +quant_fp8 and (qwen3 or deepseek))" + - pytest -v -s tests/compile/fusions_e2e/test_tp2_async_tp.py -k "inductor_partition and not +rms_norm and (not +quant_fp8 or +quant_fp8 and (qwen3 or deepseek))" - label: Fusion E2E TP2 AR-RMS Config Sweep (H100) timeout_in_minutes: 40 @@ -205,7 +205,7 @@ steps: commands: - nvidia-smi # Run all models but only FLASHINFER, Inductor partition and native custom ops - # include qwen with +quant_fp8 as -quant_fp8 rms+quant fusion is not supported + # include qwen/deepseek with +quant_fp8 as -quant_fp8 rms+quant fusion is not supported # for ar-rms-quant-fp4, also sweep llama3 - - pytest -v -s tests/compile/fusions_e2e/test_tp2_ar_rms.py -k "(FLASHINFER and inductor_partition and not +rms_norm and (not +quant_fp8 or +quant_fp8 and qwen3)) or Llama-3.1-8B-Instruct-FP4" - - pytest -v -s tests/compile/fusions_e2e/test_tp2_async_tp.py -k "FLASHINFER and inductor_partition and not +rms_norm and (not +quant_fp8 or +quant_fp8 and qwen3)" + - pytest -v -s tests/compile/fusions_e2e/test_tp2_ar_rms.py -k "(FLASHINFER and inductor_partition and not +rms_norm and (not +quant_fp8 or +quant_fp8 and (qwen3 or deepseek))) or Llama-3.1-8B-Instruct-FP4" + - pytest -v -s tests/compile/fusions_e2e/test_tp2_async_tp.py -k "FLASHINFER and inductor_partition and not +rms_norm and (not +quant_fp8 or +quant_fp8 and (qwen3 or deepseek))" From dd3a399e5c36375d5afd06314ebb695fe3ab75dd Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Mon, 9 Mar 2026 18:51:02 -0400 Subject: [PATCH 06/11] Add non-contiguous input tests for rms_norm_per_block_quant and dynamic per-token quant kernels (#36552) Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: ProExpertProg <11367180+ProExpertProg@users.noreply.github.com> --- .../core/test_fused_quant_layernorm.py | 36 +++++++++++++------ vllm/_custom_ops.py | 4 +-- 2 files changed, 27 insertions(+), 13 deletions(-) diff --git a/tests/kernels/core/test_fused_quant_layernorm.py b/tests/kernels/core/test_fused_quant_layernorm.py index 751f17dd960e..13f9ee02b7de 100644 --- a/tests/kernels/core/test_fused_quant_layernorm.py +++ b/tests/kernels/core/test_fused_quant_layernorm.py @@ -162,6 +162,7 @@ def ops_impl( ) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("strided_input", [False, True]) @torch.inference_mode() def test_rms_norm( default_vllm_config, @@ -175,6 +176,7 @@ def test_rms_norm( tma_alignment: int, seed: int, device: str, + strided_input: bool, ) -> None: torch.random.manual_seed(seed) if torch.cuda.is_available(): @@ -213,10 +215,20 @@ def test_rms_norm( # Make weights layer.weight.data.normal_(mean=1.0, std=0.1) - # Make inputs + # Make inputs: use a wider tensor and slice to create a non-contiguous + # (strided) input when strided_input=True. The last dimension stride + # remains 1, which the kernel requires. scale = 1 / (hidden_size) - x = torch.randn(num_tokens, hidden_size, dtype=dtype) * scale - residual = torch.randn_like(x) * scale if add_residual else None + last_dim = 2 * hidden_size if strided_input else hidden_size + x = torch.randn(num_tokens, last_dim, dtype=dtype) * scale + x = x[:, :hidden_size] + assert x.is_contiguous() != strided_input + # Residual must be contiguous since the kernel requires it. + residual = ( + torch.randn(num_tokens, hidden_size, dtype=dtype) * scale + if add_residual + else None + ) if has_scale_ub: rms_x, _ = ref_rms_norm(layer, x, residual) scale_ub = torch.mean(rms_x).to(dtype=torch.float32, device="cuda") @@ -260,12 +272,14 @@ def test_rms_norm( if add_residual: assert torch.allclose(ref_residual, ops_residual) - output = torch.empty_like(x, dtype=quant_dtype) - scales = torch.empty( - (x.numel() // x.shape[-1], 1), device=x.device, dtype=torch.float32 - ) + # opcheck uses contiguous tensors (strided inputs are tested above). + if not strided_input: + output = torch.empty(x.shape, dtype=quant_dtype, device=x.device) + scales = torch.empty( + (x.numel() // x.shape[-1], 1), device=x.device, dtype=torch.float32 + ) - opcheck( - torch.ops._C.rms_norm_dynamic_per_token_quant, - (output, x, layer.weight, scales, 1e-5, scale_ub, residual), - ) + opcheck( + torch.ops._C.rms_norm_dynamic_per_token_quant, + (output, x, layer.weight, scales, 1e-5, scale_ub, residual), + ) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index dd2cca9b7443..fdc468d3b25d 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -427,7 +427,7 @@ def rms_norm_dynamic_per_token_quant( scale_ub: torch.Tensor | None = None, residual: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: - output = torch.empty_like(input, dtype=quant_dtype) + output = torch.empty(input.shape, dtype=quant_dtype, device=input.device) scales = torch.empty( (input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32 ) @@ -451,7 +451,7 @@ def rms_norm_per_block_quant( tma_alignment: int = 0, ) -> tuple[torch.Tensor, torch.Tensor]: assert len(group_size) == 2 - output = torch.empty_like(input, dtype=quant_dtype) + output = torch.empty(input.shape, dtype=quant_dtype, device=input.device) if is_scale_transposed: if tma_alignment == 0: scales = torch.empty( From 0ebf4e969b43d99c240fd085703ea1ed97897499 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Tue, 10 Mar 2026 11:53:44 -0400 Subject: [PATCH 07/11] Add checks for hidden size and input stride divisibility by 4, fix unit test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- ...fused_layernorm_dynamic_per_token_quant.cu | 6 +++ .../core/test_fused_quant_layernorm.py | 37 ++++++++++++++----- 2 files changed, 34 insertions(+), 9 deletions(-) diff --git a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu index 723b903b6c20..1dcd464eb2e1 100644 --- a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu +++ b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu @@ -210,6 +210,12 @@ void rms_norm_per_block_quant_dispatch( std::optional& residual, bool is_scale_transposed) { int32_t hidden_size = input.size(-1); int32_t input_stride = input.view({-1, hidden_size}).stride(0); + + TORCH_CHECK(hidden_size % 4 == 0, + "Hidden size must be divisible by 4 for vectorized access"); + TORCH_CHECK(input_stride % 4 == 0, + "Input stride must be divisible by 4 for vectorized access"); + auto num_tokens = input.numel() / hidden_size; dim3 grid(num_tokens); diff --git a/tests/kernels/core/test_fused_quant_layernorm.py b/tests/kernels/core/test_fused_quant_layernorm.py index 13f9ee02b7de..58992bdd0563 100644 --- a/tests/kernels/core/test_fused_quant_layernorm.py +++ b/tests/kernels/core/test_fused_quant_layernorm.py @@ -186,7 +186,7 @@ def test_rms_norm( if group_size is not None and hidden_size % group_size[1] != 0: # skip - return + pytest.skip("Skip non-divisible group sizes") if group_size is not None and has_scale_ub: # blockwise baseline doesn't support scale_ub @@ -222,8 +222,13 @@ def test_rms_norm( last_dim = 2 * hidden_size if strided_input else hidden_size x = torch.randn(num_tokens, last_dim, dtype=dtype) * scale x = x[:, :hidden_size] - assert x.is_contiguous() != strided_input - # Residual must be contiguous since the kernel requires it. + + # dim 1 gets special-cased + x_is_strided = strided_input and num_tokens != 1 + # check that the input is strided iff we expect it to be + assert x.is_contiguous() != x_is_strided + + # Residual must still be contiguous residual = ( torch.randn(num_tokens, hidden_size, dtype=dtype) * scale if add_residual @@ -272,14 +277,28 @@ def test_rms_norm( if add_residual: assert torch.allclose(ref_residual, ops_residual) - # opcheck uses contiguous tensors (strided inputs are tested above). - if not strided_input: - output = torch.empty(x.shape, dtype=quant_dtype, device=x.device) - scales = torch.empty( - (x.numel() // x.shape[-1], 1), device=x.device, dtype=torch.float32 - ) + output = torch.empty(x.shape, dtype=quant_dtype, device=x.device) + scales = torch.empty( + (x.numel() // x.shape[-1], 1), device=x.device, dtype=torch.float32 + ) + if group_size is None: opcheck( torch.ops._C.rms_norm_dynamic_per_token_quant, (output, x, layer.weight, scales, 1e-5, scale_ub, residual), ) + else: + opcheck( + torch.ops._C.rms_norm_per_block_quant, + ( + output, + x, + layer.weight, + scales, + 1e-5, + scale_ub, + residual, + group_size[1], + True, # is_scale_transposed + ), + ) From f247a5f37d7bb666a1c145a06bbadc696605bddd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Tue, 10 Mar 2026 13:56:10 -0400 Subject: [PATCH 08/11] Add checks for hidden size and input stride divisibility by 4, remove opcheck from unit test & convert returns to skips MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- .../fused_layernorm_dynamic_per_token_quant.cu | 2 ++ tests/kernels/core/test_fused_quant_layernorm.py | 13 +++++++++---- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu index 1dcd464eb2e1..2aab120935e0 100644 --- a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu +++ b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu @@ -215,6 +215,8 @@ void rms_norm_per_block_quant_dispatch( "Hidden size must be divisible by 4 for vectorized access"); TORCH_CHECK(input_stride % 4 == 0, "Input stride must be divisible by 4 for vectorized access"); + TORCH_CHECK(group_size % 4 == 0, + "Group size must be divisible by 4 for vectorized access"); auto num_tokens = input.numel() / hidden_size; diff --git a/tests/kernels/core/test_fused_quant_layernorm.py b/tests/kernels/core/test_fused_quant_layernorm.py index 58992bdd0563..b7e6ce386b84 100644 --- a/tests/kernels/core/test_fused_quant_layernorm.py +++ b/tests/kernels/core/test_fused_quant_layernorm.py @@ -190,13 +190,13 @@ def test_rms_norm( if group_size is not None and has_scale_ub: # blockwise baseline doesn't support scale_ub - return + pytest.skip("scale_ub not supported for blockwise/group quantization") if ( group_size is None or quant_dtype != current_platform.fp8_dtype() ) and tma_alignment != 0: # TMA alignment is only supported for groupwise fp8 kernels - return + pytest.skip("tma alignment not supported for per-token or int8 quantization") if ( group_size is not None @@ -204,11 +204,11 @@ def test_rms_norm( and hidden_size // group_size[1] % tma_alignment == 0 ): # Skip tests where TMA alignment doesn't create extra padding to save time - return + pytest.skip("Skip TMA alignment cases where no extra padding is added") if has_scale_ub and quant_dtype != current_platform.fp8_dtype(): # skip - return + pytest.skip("scale_ub only supported for fp8 quantization") layer = RMSNorm(hidden_size, EPS).to(dtype=dtype) @@ -288,6 +288,11 @@ def test_rms_norm( (output, x, layer.weight, scales, 1e-5, scale_ub, residual), ) else: + # TODO(luka/eliza) opcheck is broken? + # Somehow the cloned args are getting mutated in-place, + # which causes the opcheck to fail. + # https://github.com/vllm-project/vllm/issues/36688 + return opcheck( torch.ops._C.rms_norm_per_block_quant, ( From f4109e9de23ba2104d212abba388d45a772af87f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Tue, 10 Mar 2026 14:14:26 -0400 Subject: [PATCH 09/11] Apply suggestion from @ProExpertProg MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/fusions_e2e/models.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/compile/fusions_e2e/models.py b/tests/compile/fusions_e2e/models.py index 526892a5bf1b..5138c3b86db5 100644 --- a/tests/compile/fusions_e2e/models.py +++ b/tests/compile/fusions_e2e/models.py @@ -149,8 +149,7 @@ # - q_a_layernorm + q_b_proj (inside MLA wrapper) # - post_attn_layernorm + MLP # 2 per MoE layer (remaining) due to MoE wrapping - rms_quant_fusion=n_layers * 2 - + min(3, n_layers), # add extra for 3 dense layers + rms_quant_fusion=n_layers * 2 + min(3, n_layers), # add for 3 dense layers # TODO silu+block quant # act_quant_fusion=min(3, n_layers), # dense layers only # MLA attn + quant not supported yet: From 37a46cf7f3d48d5a337bb258ed1719df8a939e65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Tue, 10 Mar 2026 14:16:51 -0400 Subject: [PATCH 10/11] Apply suggestion from @ProExpertProg MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- tests/compile/fusions_e2e/models.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/compile/fusions_e2e/models.py b/tests/compile/fusions_e2e/models.py index 5138c3b86db5..9d6c202648e2 100644 --- a/tests/compile/fusions_e2e/models.py +++ b/tests/compile/fusions_e2e/models.py @@ -152,6 +152,7 @@ rms_quant_fusion=n_layers * 2 + min(3, n_layers), # add for 3 dense layers # TODO silu+block quant # act_quant_fusion=min(3, n_layers), # dense layers only + act_quant_fusion=0, # MLA attn + quant not supported yet: # https://github.com/vllm-project/vllm/issues/35792 attn_quant_fusion=0, From 96e18cfa17337594f43fa714a39a6431431ca5e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Tue, 10 Mar 2026 18:04:00 -0400 Subject: [PATCH 11/11] Check input_stride for vectorization MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- .../fused_kernels/fused_layernorm_dynamic_per_token_quant.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu index 2aab120935e0..e178f252624f 100644 --- a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu +++ b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu @@ -56,7 +56,7 @@ __global__ void rms_norm_dynamic_per_token_quant_kernel( int32_t const input_stride, scalar_t* __restrict__ residual = nullptr) { // For vectorization, token_input and token_output pointers need to be // aligned at 8-byte and 4-byte addresses respectively. - bool const can_vectorize = hidden_size % 4 == 0; + bool const can_vectorize = hidden_size % 4 == 0 and input_stride % 4 == 0; if (can_vectorize) { return rms_norm_dynamic_per_token_quant_vec