diff --git a/.buildkite/test_areas/compile.yaml b/.buildkite/test_areas/compile.yaml index f9eccdcbbeee..5da7b64ac304 100644 --- a/.buildkite/test_areas/compile.yaml +++ b/.buildkite/test_areas/compile.yaml @@ -101,8 +101,8 @@ steps: - nvidia-smi # Run all models and attn backends but only Inductor partition and native custom ops - pytest -v -s tests/compile/fusions_e2e/test_tp1_quant.py -k "inductor_partition and not +rms_norm and not +quant_fp8" - # Qwen requires +quant_fp8 as -quant_fp8 rms+quant fusion is not supported - - pytest -v -s tests/compile/fusions_e2e/test_tp1_quant.py -k "inductor_partition and not +rms_norm and +quant_fp8 and qwen3" + # Qwen/Deepseek requires +quant_fp8 as -quant_fp8 rms+quant fusion is not supported + - pytest -v -s tests/compile/fusions_e2e/test_tp1_quant.py -k "inductor_partition and not +rms_norm and +quant_fp8 and (qwen3 or deepseek)" - label: Fusion E2E Config Sweep (H100) timeout_in_minutes: 30 @@ -132,9 +132,9 @@ steps: commands: - nvidia-smi # Run all models but only FLASHINFER, Inductor partition and native custom ops - # Qwen requires +quant_fp8 as -quant_fp8 rms+quant fusion is not supported + # Qwen/Deepseek requires +quant_fp8 as -quant_fp8 rms+quant fusion is not supported # Run just llama3 (fp8 & fp4) for all config combinations (only inductor partition) - - pytest -v -s tests/compile/fusions_e2e/test_tp1_quant.py -k "inductor_partition and (FLASHINFER and not +rms_norm and (not +quant_fp8 or +quant_fp8 and qwen3) or llama-3)" + - pytest -v -s tests/compile/fusions_e2e/test_tp1_quant.py -k "inductor_partition and (FLASHINFER and not +rms_norm and (not +quant_fp8 or +quant_fp8 and (qwen3 or deepseek)) or llama-3)" - label: Fusion E2E TP2 Quick (H100) timeout_in_minutes: 20 @@ -150,8 +150,8 @@ steps: commands: - nvidia-smi # Run all models and attn backends but only Inductor partition and native custom ops - - pytest -v -s tests/compile/fusions_e2e/test_tp2_ar_rms.py -k "inductor_partition and not +rms_norm and not +quant_fp8" - - pytest -v -s tests/compile/fusions_e2e/test_tp2_async_tp.py -k "inductor_partition and not +rms_norm and not +quant_fp8" + - pytest -v -s tests/compile/fusions_e2e/test_tp2_ar_rms.py -k "inductor_partition and not +rms_norm and (not +quant_fp8 or +quant_fp8 and (qwen3 or deepseek))" + - pytest -v -s tests/compile/fusions_e2e/test_tp2_async_tp.py -k "inductor_partition and not +rms_norm and (not +quant_fp8 or +quant_fp8 and (qwen3 or deepseek))" - label: Fusion E2E TP2 AR-RMS Config Sweep (H100) timeout_in_minutes: 40 @@ -205,7 +205,7 @@ steps: commands: - nvidia-smi # Run all models but only FLASHINFER, Inductor partition and native custom ops - # include qwen with +quant_fp8 as -quant_fp8 rms+quant fusion is not supported + # include qwen/deepseek with +quant_fp8 as -quant_fp8 rms+quant fusion is not supported # for ar-rms-quant-fp4, also sweep llama3 - - pytest -v -s tests/compile/fusions_e2e/test_tp2_ar_rms.py -k "(FLASHINFER and inductor_partition and not +rms_norm and (not +quant_fp8 or +quant_fp8 and qwen3)) or Llama-3.1-8B-Instruct-FP4" - - pytest -v -s tests/compile/fusions_e2e/test_tp2_async_tp.py -k "FLASHINFER and inductor_partition and not +rms_norm and (not +quant_fp8 or +quant_fp8 and qwen3)" + - pytest -v -s tests/compile/fusions_e2e/test_tp2_ar_rms.py -k "(FLASHINFER and inductor_partition and not +rms_norm and (not +quant_fp8 or +quant_fp8 and (qwen3 or deepseek))) or Llama-3.1-8B-Instruct-FP4" + - pytest -v -s tests/compile/fusions_e2e/test_tp2_async_tp.py -k "FLASHINFER and inductor_partition and not +rms_norm and (not +quant_fp8 or +quant_fp8 and (qwen3 or deepseek))" diff --git a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu index b9a9b5cc7e43..e178f252624f 100644 --- a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu +++ b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu @@ -15,31 +15,33 @@ __device__ void rms_norm_dynamic_per_token_quant_vec( scalar_t const* __restrict__ input, // [..., hidden_size] scalar_t const* __restrict__ weight, // [hidden_size] float const* scale_ub, float const var_epsilon, int32_t const hidden_size, - scalar_t* __restrict__ residual = nullptr) { + int32_t const input_stride, scalar_t* __restrict__ residual = nullptr) { float rms = 0.0f; float token_scale = 0.0f; // Compute rms vllm::vectorized::compute_rms( - &rms, input, hidden_size, var_epsilon, residual); + &rms, input, hidden_size, input_stride, var_epsilon, residual); // Compute scale vllm::vectorized::compute_dynamic_per_token_scales( &token_scale, scales, input, weight, rms, scale_ub, hidden_size, - residual); + input_stride, residual); // RMS Norm + Quant if constexpr (std::is_same_v) { token_scale = 1.0f / token_scale; vllm::vectorized::norm_and_quant( - out, input, weight, rms, &token_scale, hidden_size, residual); + has_residual>(out, input, weight, rms, + &token_scale, hidden_size, + input_stride, residual); } else { // FP8 - Do not invert token_scale for exact match with FBGemm vllm::vectorized::norm_and_quant( - out, input, weight, rms, &token_scale, hidden_size, residual); + has_residual>(out, input, weight, rms, + &token_scale, hidden_size, + input_stride, residual); } } @@ -51,38 +53,40 @@ __global__ void rms_norm_dynamic_per_token_quant_kernel( scalar_t const* __restrict__ input, // [..., hidden_size] scalar_t const* __restrict__ weight, // [hidden_size] float const* scale_ub, float const var_epsilon, int32_t const hidden_size, - scalar_t* __restrict__ residual = nullptr) { + int32_t const input_stride, scalar_t* __restrict__ residual = nullptr) { // For vectorization, token_input and token_output pointers need to be // aligned at 8-byte and 4-byte addresses respectively. - bool const can_vectorize = hidden_size % 4 == 0; + bool const can_vectorize = hidden_size % 4 == 0 and input_stride % 4 == 0; if (can_vectorize) { return rms_norm_dynamic_per_token_quant_vec( out, scales, input, weight, scale_ub, var_epsilon, hidden_size, - residual); + input_stride, residual); } float rms = 0.0f; float token_scale = 0.0f; // Compute RMS - vllm::compute_rms(&rms, input, hidden_size, - var_epsilon, residual); + vllm::compute_rms( + &rms, input, hidden_size, input_stride, var_epsilon, residual); // Compute Scale vllm::compute_dynamic_per_token_scales( &token_scale, scales, input, weight, rms, scale_ub, hidden_size, - residual); + input_stride, residual); // RMS Norm + Quant if constexpr (std::is_same_v) { token_scale = 1.0f / token_scale; vllm::norm_and_quant( - out, input, weight, rms, &token_scale, hidden_size, residual); + out, input, weight, rms, &token_scale, hidden_size, input_stride, + residual); } else { // FP8 - Do not invert s_token_scale for exact match with FBGemm vllm::norm_and_quant( - out, input, weight, rms, &token_scale, hidden_size, residual); + out, input, weight, rms, &token_scale, hidden_size, input_stride, + residual); } } @@ -97,19 +101,20 @@ __global__ void rms_norm_per_block_quant_kernel( scalar_t const* __restrict__ input, // [..., hidden_size] scalar_t const* __restrict__ weight, // [hidden_size] float const* scale_ub, float const var_epsilon, int32_t const hidden_size, - scalar_t* __restrict__ residual = nullptr, int64_t outer_scale_stride = 1) { + int32_t const input_stride, scalar_t* __restrict__ residual = nullptr, + int64_t outer_scale_stride = 1) { float rms; // Compute RMS // Always able to vectorize due to constraints on hidden_size vllm::vectorized::compute_rms( - &rms, input, hidden_size, var_epsilon, residual); + &rms, input, hidden_size, input_stride, var_epsilon, residual); // Compute Scale // Always able to vectorize due to constraints on hidden_size and group_size vllm::vectorized::compute_dynamic_per_token_scales< scalar_t, scalar_out_t, has_residual, is_scale_transposed, group_size>( - nullptr, scales, input, weight, rms, scale_ub, hidden_size, residual, - outer_scale_stride); + nullptr, scales, input, weight, rms, scale_ub, hidden_size, input_stride, + residual, outer_scale_stride); // RMS Norm + Quant // Always able to vectorize due to constraints on hidden_size @@ -120,7 +125,7 @@ __global__ void rms_norm_per_block_quant_kernel( vllm::vectorized::norm_and_quant< scalar_t, scalar_out_t, std::is_same_v, has_residual, is_scale_transposed, group_size>( - out, input, weight, rms, scales, hidden_size, residual, + out, input, weight, rms, scales, hidden_size, input_stride, residual, outer_scale_stride); } @@ -137,6 +142,7 @@ void rms_norm_dynamic_per_token_quant_dispatch( std::optional const& scale_ub, std::optional& residual) { int32_t hidden_size = input.size(-1); + int32_t input_stride = input.view({-1, hidden_size}).stride(0); auto num_tokens = input.numel() / hidden_size; dim3 grid(num_tokens); @@ -153,7 +159,7 @@ void rms_norm_dynamic_per_token_quant_dispatch( out.data_ptr(), scales.data_ptr(), input.data_ptr(), weight.data_ptr(), scale_ub.has_value() ? scale_ub->data_ptr() : nullptr, - var_epsilon, hidden_size, + var_epsilon, hidden_size, input_stride, has_residual ? residual->data_ptr() : nullptr); }); }); @@ -170,7 +176,9 @@ void rms_norm_dynamic_per_token_quant( ? c10::ScalarType::Float8_e4m3fn : c10::ScalarType::Float8_e4m3fnuz; TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8); - TORCH_CHECK(out.is_contiguous() && input.is_contiguous()); + TORCH_CHECK(out.is_contiguous()); + TORCH_CHECK(input.stride(-1) == 1, + "Input must be contiguous in the last dimension"); if (scale_ub.has_value()) { TORCH_CHECK(out.dtype() == kFp8Type); @@ -179,6 +187,7 @@ void rms_norm_dynamic_per_token_quant( TORCH_CHECK(scales.dtype() == torch::kFloat32); if (residual) { TORCH_CHECK(residual->scalar_type() == input.scalar_type()); + TORCH_CHECK(residual->is_contiguous()); } VLLM_DISPATCH_FLOATING_TYPES( @@ -200,6 +209,15 @@ void rms_norm_per_block_quant_dispatch( std::optional const& scale_ub, std::optional& residual, bool is_scale_transposed) { int32_t hidden_size = input.size(-1); + int32_t input_stride = input.view({-1, hidden_size}).stride(0); + + TORCH_CHECK(hidden_size % 4 == 0, + "Hidden size must be divisible by 4 for vectorized access"); + TORCH_CHECK(input_stride % 4 == 0, + "Input stride must be divisible by 4 for vectorized access"); + TORCH_CHECK(group_size % 4 == 0, + "Group size must be divisible by 4 for vectorized access"); + auto num_tokens = input.numel() / hidden_size; dim3 grid(num_tokens); @@ -225,7 +243,7 @@ void rms_norm_per_block_quant_dispatch( weight.data_ptr(), scale_ub.has_value() ? scale_ub->data_ptr() : nullptr, - var_epsilon, hidden_size, + var_epsilon, hidden_size, input_stride, has_residual ? residual->data_ptr() : nullptr, scales.stride(1)); @@ -246,7 +264,9 @@ void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input, ? c10::ScalarType::Float8_e4m3fn : c10::ScalarType::Float8_e4m3fnuz; TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8); - TORCH_CHECK(out.is_contiguous() && input.is_contiguous()); + TORCH_CHECK(out.is_contiguous()); + TORCH_CHECK(input.stride(-1) == 1, + "Input must be contiguous in the last dimension"); if (scale_ub.has_value()) { TORCH_CHECK(out.dtype() == kFp8Type); @@ -255,6 +275,7 @@ void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input, TORCH_CHECK(scales.dtype() == torch::kFloat32); if (residual) { TORCH_CHECK(residual->scalar_type() == input.scalar_type()); + TORCH_CHECK(residual->is_contiguous()); } TORCH_CHECK(group_size == 128 || group_size == 64, diff --git a/csrc/quantization/fused_kernels/layernorm_utils.cuh b/csrc/quantization/fused_kernels/layernorm_utils.cuh index edf4024f0d49..1f0d583523c8 100644 --- a/csrc/quantization/fused_kernels/layernorm_utils.cuh +++ b/csrc/quantization/fused_kernels/layernorm_utils.cuh @@ -16,14 +16,17 @@ namespace vllm { // has_residual must be true, if residual is not a nullptr template __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input, - int32_t const hidden_size, float const epsilon, + int32_t const hidden_size, + int32_t const input_stride, float const epsilon, scalar_t const* __restrict__ residual = nullptr) { + int64_t const input_token_offset = + blockIdx.x * static_cast(input_stride); int64_t const token_offset = blockIdx.x * static_cast(hidden_size); // sum of squares float ss = 0.0f; for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) { - float x = static_cast(input[token_offset + i]); + float x = static_cast(input[input_token_offset + i]); if constexpr (has_residual) { x += static_cast(residual[token_offset + i]); } @@ -73,15 +76,20 @@ __device__ void compute_dynamic_per_token_scales( float* __restrict__ token_scale, float* __restrict__ all_token_scales, scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight, float const rms, float const* __restrict__ scale_ub, - int32_t const hidden_size, scalar_t const* __restrict__ residual = nullptr, + int32_t const hidden_size, int32_t const input_stride, + scalar_t const* __restrict__ residual = nullptr, int32_t const group_size = 0, int64_t outer_scale_stride = 1) { float block_absmax_val_maybe = 0.0f; constexpr scalar_out_t qmax{quant_type_max_v}; __syncthreads(); + + int64_t const input_token_offset = + blockIdx.x * static_cast(input_stride); + int64_t const token_offset = blockIdx.x * static_cast(hidden_size); + if (group_size > 0) { - __shared__ float s_max_vals[1024]; - int64_t const token_offset = blockIdx.x * static_cast(hidden_size); int64_t num_groups = hidden_size / group_size; + __shared__ float s_max_vals[1024]; int64_t const threads_per_group = blockDim.x / num_groups; int64_t const thread_in_group = threadIdx.x % threads_per_group; int64_t const group_offset = threadIdx.x / threads_per_group * group_size; @@ -89,7 +97,7 @@ __device__ void compute_dynamic_per_token_scales( int64_t const thread_end = min(group_offset + group_size, static_cast(hidden_size)); for (auto i = thread_offset; i < thread_end; i += threads_per_group) { - float x = static_cast(input[token_offset + i]); + float x = static_cast(input[input_token_offset + i]); if constexpr (has_residual) { x += static_cast(residual[token_offset + i]); } @@ -144,10 +152,8 @@ __device__ void compute_dynamic_per_token_scales( } __syncthreads(); } else { - int64_t const token_offset = blockIdx.x * static_cast(hidden_size); - for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) { - float x = static_cast(input[token_offset + i]); + float x = static_cast(input[input_token_offset + i]); if constexpr (has_residual) { x += static_cast(residual[token_offset + i]); } @@ -185,12 +191,15 @@ template (input_stride); int64_t const token_offset = blockIdx.x * static_cast(hidden_size); for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) { - float x = static_cast(input[token_offset + i]); + float x = static_cast(input[input_token_offset + i]); if constexpr (has_residual) { x += static_cast(residual[token_offset + i]); residual[token_offset + i] = static_cast(x); @@ -224,13 +233,16 @@ namespace vectorized { // hidden_size must be a multiple of 4 template __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input, - int32_t const hidden_size, float const epsilon, + int32_t const hidden_size, + int32_t const input_stride, float const epsilon, scalar_t const* __restrict__ residual = nullptr) { + int64_t const input_token_offset = + blockIdx.x * static_cast(input_stride); int64_t const token_offset = blockIdx.x * static_cast(hidden_size); // Vectorized input/output to better utilize memory bandwidth. vec4_t const* vec_input = - reinterpret_cast const*>(&input[token_offset]); + reinterpret_cast const*>(&input[input_token_offset]); vec4_t const* vec_residual = nullptr; if constexpr (has_residual) { vec_residual = @@ -288,7 +300,8 @@ __device__ void compute_dynamic_per_token_scales( float* __restrict__ token_scale, float* __restrict__ all_token_scales, scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight, float const rms, float const* __restrict__ scale_ub, - int32_t const hidden_size, scalar_t const* __restrict__ residual = nullptr, + int32_t const hidden_size, int32_t const input_stride, + scalar_t const* __restrict__ residual = nullptr, int64_t outer_scale_stride = 1) { constexpr scalar_out_t qmax{quant_type_max_v}; @@ -300,10 +313,13 @@ __device__ void compute_dynamic_per_token_scales( vec4_t const* vec_weight = nullptr; vec4_t const* vec_residual = nullptr; + int64_t const input_token_offset = + blockIdx.x * static_cast(input_stride); + int64_t const token_offset = blockIdx.x * static_cast(hidden_size); + if constexpr (group_size > 0) { __shared__ float s_max_vals[1024]; - int64_t const token_offset = blockIdx.x * static_cast(hidden_size); int64_t const num_groups = hidden_size / group_size; int64_t const threads_per_group = blockDim.x / num_groups; int64_t const thread_in_group = threadIdx.x % threads_per_group; @@ -312,7 +328,8 @@ __device__ void compute_dynamic_per_token_scales( int64_t const thread_offset = group_offset + thread_in_group; int64_t const thread_end = min(group_offset + (group_size >> 2), static_cast(hidden_size >> 2)); - vec_input = reinterpret_cast const*>(&input[token_offset]); + vec_input = + reinterpret_cast const*>(&input[input_token_offset]); vec_weight = reinterpret_cast const*>(weight); if constexpr (has_residual) { vec_residual = @@ -396,8 +413,8 @@ __device__ void compute_dynamic_per_token_scales( __syncthreads(); } else { - int64_t const token_offset = blockIdx.x * static_cast(hidden_size); - vec_input = reinterpret_cast const*>(&input[token_offset]); + vec_input = + reinterpret_cast const*>(&input[input_token_offset]); vec_weight = reinterpret_cast const*>(weight); if constexpr (has_residual) { vec_residual = @@ -462,18 +479,18 @@ __device__ void compute_dynamic_per_token_scales( template -__device__ void norm_and_quant(scalar_out_t* __restrict__ output, - scalar_t const* __restrict__ input, - scalar_t const* __restrict__ weight, - float const rms, float* const scale, - int32_t const hidden_size, - scalar_t* __restrict__ residual = nullptr, - int64_t outer_scale_stride = 1) { +__device__ void norm_and_quant( + scalar_out_t* __restrict__ output, scalar_t const* __restrict__ input, + scalar_t const* __restrict__ weight, float const rms, float* const scale, + int32_t const hidden_size, int32_t const input_stride, + scalar_t* __restrict__ residual = nullptr, int64_t outer_scale_stride = 1) { + int64_t const input_token_offset = + blockIdx.x * static_cast(input_stride); int64_t const token_offset = blockIdx.x * static_cast(hidden_size); // Vectorized input/output/weight/residual to better utilize memory bandwidth. vec4_t const* vec_input = - reinterpret_cast const*>(&input[token_offset]); + reinterpret_cast const*>(&input[input_token_offset]); vec4_t const* vec_weight = reinterpret_cast const*>(weight); q8x4_t* vec_output = diff --git a/tests/compile/fusions_e2e/conftest.py b/tests/compile/fusions_e2e/conftest.py index 29eb8425183c..873f92cfe6ce 100644 --- a/tests/compile/fusions_e2e/conftest.py +++ b/tests/compile/fusions_e2e/conftest.py @@ -72,6 +72,16 @@ def run( rocm_aiter_ops.refresh_env_variables() + # Filter here to reduce code duplication + requires_mla = "deepseek" in model_name.lower() + is_mla = "mla" in attn_backend.backend.name.lower() + + if requires_mla != is_mla: + pytest.skip( + f"Incompatible model '{model_name}' and " + f"attention backend '{attn_backend.backend.name}'" + ) + # Disable, compile cache to make sure custom passes run. # Otherwise, we can't verify fusion happened through the logs. monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") diff --git a/tests/compile/fusions_e2e/models.py b/tests/compile/fusions_e2e/models.py index e18bc1ee5652..9d6c202648e2 100644 --- a/tests/compile/fusions_e2e/models.py +++ b/tests/compile/fusions_e2e/models.py @@ -44,6 +44,20 @@ ), ) +FLASHINFER_MLA_ATTN = pytest.param( + AttentionBackendCase(backend=AttentionBackendEnum.FLASHINFER_MLA), + id="FLASHINFER_MLA", + marks=pytest.mark.skipif( + not is_blackwell() or not has_flashinfer(), + reason="FI backend requires Blackwell and FlashInfer", + ), +) + +TRITON_MLA_ATTN = pytest.param( + AttentionBackendCase(backend=AttentionBackendEnum.TRITON_MLA), + id="TRITON_MLA", +) + # Models llama3_8b = ModelFusionInfo( model_name="meta-llama/Llama-3.1-8B-Instruct", @@ -126,3 +140,25 @@ async_tp=n_layers * 2, ), ) + +deepseek_v3_fp8 = ModelFusionInfo( + model_name="deepseek-ai/DeepSeek-V3", + matches=lambda n_layers: Matches( + # 3 per dense layer (first 3): + # - input_rms + qkv_proj + # - q_a_layernorm + q_b_proj (inside MLA wrapper) + # - post_attn_layernorm + MLP + # 2 per MoE layer (remaining) due to MoE wrapping + rms_quant_fusion=n_layers * 2 + min(3, n_layers), # add for 3 dense layers + # TODO silu+block quant + # act_quant_fusion=min(3, n_layers), # dense layers only + act_quant_fusion=0, + # MLA attn + quant not supported yet: + # https://github.com/vllm-project/vllm/issues/35792 + attn_quant_fusion=0, + ar_rms_fusion=n_layers * 2 + 1, + # TODO + # sequence_parallel= n_layers * 2 + 1, + # async_tp=n_layers * 2, + ), +) diff --git a/tests/compile/fusions_e2e/test_tp1_quant.py b/tests/compile/fusions_e2e/test_tp1_quant.py index 917116515f89..8895dadcecc9 100644 --- a/tests/compile/fusions_e2e/test_tp1_quant.py +++ b/tests/compile/fusions_e2e/test_tp1_quant.py @@ -17,9 +17,12 @@ ) from .models import ( FLASHINFER_ATTN, + FLASHINFER_MLA_ATTN, ROCM_AITER_UNIFIED_ATTN, ROCM_ATTN, TRITON_ATTN, + TRITON_MLA_ATTN, + deepseek_v3_fp8, llama3_8b_fp4, llama3_8b_fp8, llama4_scout_fp4, @@ -33,6 +36,9 @@ [ (*llama3_8b_fp8, False), (*qwen3_a3b_fp8, False), + (*qwen3_a3b_fp8, True), + (*deepseek_v3_fp8, False), + (*deepseek_v3_fp8, True), pytest.param( *llama4_scout_fp8, False, @@ -41,13 +47,6 @@ reason="Llama4 Scout FP8 only supported on CUDA", ), ), - pytest.param( - *qwen3_a3b_fp8, - True, - marks=pytest.mark.skipif( - not current_platform.is_cuda(), reason="DeepGemm only supported on CUDA" - ), - ), ], ) @pytest.mark.parametrize( @@ -57,6 +56,8 @@ FLASHINFER_ATTN, ROCM_ATTN, ROCM_AITER_UNIFIED_ATTN, + FLASHINFER_MLA_ATTN, + TRITON_MLA_ATTN, ], ) @pytest.mark.parametrize("n_layers", [6]) @@ -75,6 +76,9 @@ def test_tp1_fp8_fusions( run_e2e_fusion_test, monkeypatch, ): + if use_deepgemm and not current_platform.is_cuda(): + pytest.skip("DeepGemm only supported on CUDA") + if use_deepgemm and is_flashinfer_fp8_blockscale_gemm_supported(): # Flashinfer block FP8 GEMM has internal quantization, so it can't # be fused with other ops. @@ -86,7 +90,8 @@ def test_tp1_fp8_fusions( matches = matches_fn(n_layers) - if "qwen" in model_name.lower() and "-quant_fp8" in custom_ops: + block_fp8 = "qwen" in model_name.lower() or "deepseek" in model_name.lower() + if block_fp8 and "-quant_fp8" in custom_ops: # This is why config forces +quant_fp8 by default pytest.skip("native QuantFP8 matching not supported for group quant") diff --git a/tests/compile/fusions_e2e/test_tp2_ar_rms.py b/tests/compile/fusions_e2e/test_tp2_ar_rms.py index ab4aefcaf79a..8ffadbfaf298 100644 --- a/tests/compile/fusions_e2e/test_tp2_ar_rms.py +++ b/tests/compile/fusions_e2e/test_tp2_ar_rms.py @@ -17,7 +17,9 @@ ) from .models import ( FLASHINFER_ATTN, + FLASHINFER_MLA_ATTN, TRITON_ATTN, + deepseek_v3_fp8, llama3_8b, llama3_8b_fp4, llama3_8b_fp8, @@ -33,10 +35,12 @@ @multi_gpu_test(num_gpus=2) @pytest.mark.parametrize( "model_name, matches_fn, model_kwargs, hf_overrides", - # qwen3-fp8 should still fuse AR+rms even though group quant is not yet supported - [llama3_8b_fp8, llama4_scout_fp8, qwen3_a3b_fp8], + # qwen3 & dsv3 should still fuse AR+rms even though group quant is not yet supported + [llama3_8b_fp8, llama4_scout_fp8, qwen3_a3b_fp8, deepseek_v3_fp8], +) +@pytest.mark.parametrize( + "attn_backend", [TRITON_ATTN, FLASHINFER_ATTN, FLASHINFER_MLA_ATTN] ) -@pytest.mark.parametrize("attn_backend", [TRITON_ATTN, FLASHINFER_ATTN]) @pytest.mark.parametrize("n_layers", [4]) @pytest.mark.parametrize("custom_ops", custom_ops_combos("quant_fp8", "rms_norm")) @pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION) @@ -54,7 +58,8 @@ def test_tp2_ar_rms_fp8_fusions( ): matches = matches_fn(n_layers) - if "qwen" in model_name.lower() and "-quant_fp8" in custom_ops: + block_fp8 = "qwen" in model_name.lower() or "deepseek" in model_name.lower() + if block_fp8 and "-quant_fp8" in custom_ops: # This is why config forces +quant_fp8 by default pytest.skip("native QuantFP8 matching not supported for group quant") diff --git a/tests/kernels/core/test_fused_quant_layernorm.py b/tests/kernels/core/test_fused_quant_layernorm.py index 751f17dd960e..b7e6ce386b84 100644 --- a/tests/kernels/core/test_fused_quant_layernorm.py +++ b/tests/kernels/core/test_fused_quant_layernorm.py @@ -162,6 +162,7 @@ def ops_impl( ) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("strided_input", [False, True]) @torch.inference_mode() def test_rms_norm( default_vllm_config, @@ -175,6 +176,7 @@ def test_rms_norm( tma_alignment: int, seed: int, device: str, + strided_input: bool, ) -> None: torch.random.manual_seed(seed) if torch.cuda.is_available(): @@ -184,17 +186,17 @@ def test_rms_norm( if group_size is not None and hidden_size % group_size[1] != 0: # skip - return + pytest.skip("Skip non-divisible group sizes") if group_size is not None and has_scale_ub: # blockwise baseline doesn't support scale_ub - return + pytest.skip("scale_ub not supported for blockwise/group quantization") if ( group_size is None or quant_dtype != current_platform.fp8_dtype() ) and tma_alignment != 0: # TMA alignment is only supported for groupwise fp8 kernels - return + pytest.skip("tma alignment not supported for per-token or int8 quantization") if ( group_size is not None @@ -202,21 +204,36 @@ def test_rms_norm( and hidden_size // group_size[1] % tma_alignment == 0 ): # Skip tests where TMA alignment doesn't create extra padding to save time - return + pytest.skip("Skip TMA alignment cases where no extra padding is added") if has_scale_ub and quant_dtype != current_platform.fp8_dtype(): # skip - return + pytest.skip("scale_ub only supported for fp8 quantization") layer = RMSNorm(hidden_size, EPS).to(dtype=dtype) # Make weights layer.weight.data.normal_(mean=1.0, std=0.1) - # Make inputs + # Make inputs: use a wider tensor and slice to create a non-contiguous + # (strided) input when strided_input=True. The last dimension stride + # remains 1, which the kernel requires. scale = 1 / (hidden_size) - x = torch.randn(num_tokens, hidden_size, dtype=dtype) * scale - residual = torch.randn_like(x) * scale if add_residual else None + last_dim = 2 * hidden_size if strided_input else hidden_size + x = torch.randn(num_tokens, last_dim, dtype=dtype) * scale + x = x[:, :hidden_size] + + # dim 1 gets special-cased + x_is_strided = strided_input and num_tokens != 1 + # check that the input is strided iff we expect it to be + assert x.is_contiguous() != x_is_strided + + # Residual must still be contiguous + residual = ( + torch.randn(num_tokens, hidden_size, dtype=dtype) * scale + if add_residual + else None + ) if has_scale_ub: rms_x, _ = ref_rms_norm(layer, x, residual) scale_ub = torch.mean(rms_x).to(dtype=torch.float32, device="cuda") @@ -260,12 +277,33 @@ def test_rms_norm( if add_residual: assert torch.allclose(ref_residual, ops_residual) - output = torch.empty_like(x, dtype=quant_dtype) + output = torch.empty(x.shape, dtype=quant_dtype, device=x.device) scales = torch.empty( (x.numel() // x.shape[-1], 1), device=x.device, dtype=torch.float32 ) - opcheck( - torch.ops._C.rms_norm_dynamic_per_token_quant, - (output, x, layer.weight, scales, 1e-5, scale_ub, residual), - ) + if group_size is None: + opcheck( + torch.ops._C.rms_norm_dynamic_per_token_quant, + (output, x, layer.weight, scales, 1e-5, scale_ub, residual), + ) + else: + # TODO(luka/eliza) opcheck is broken? + # Somehow the cloned args are getting mutated in-place, + # which causes the opcheck to fail. + # https://github.com/vllm-project/vllm/issues/36688 + return + opcheck( + torch.ops._C.rms_norm_per_block_quant, + ( + output, + x, + layer.weight, + scales, + 1e-5, + scale_ub, + residual, + group_size[1], + True, # is_scale_transposed + ), + ) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index dd2cca9b7443..fdc468d3b25d 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -427,7 +427,7 @@ def rms_norm_dynamic_per_token_quant( scale_ub: torch.Tensor | None = None, residual: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: - output = torch.empty_like(input, dtype=quant_dtype) + output = torch.empty(input.shape, dtype=quant_dtype, device=input.device) scales = torch.empty( (input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32 ) @@ -451,7 +451,7 @@ def rms_norm_per_block_quant( tma_alignment: int = 0, ) -> tuple[torch.Tensor, torch.Tensor]: assert len(group_size) == 2 - output = torch.empty_like(input, dtype=quant_dtype) + output = torch.empty(input.shape, dtype=quant_dtype, device=input.device) if is_scale_transposed: if tma_alignment == 0: scales = torch.empty(