-
-
Notifications
You must be signed in to change notification settings - Fork 14.2k
[torch.compile] Add support for non-contiguous fused RMSNorm + group quant #36551
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
02a7fab
5378e99
ab2e78d
f7769d6
90c46ea
dd3a399
0ebf4e9
f247a5f
f4109e9
37a46cf
96e18cf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,31 +15,33 @@ __device__ void rms_norm_dynamic_per_token_quant_vec( | |
| scalar_t const* __restrict__ input, // [..., hidden_size] | ||
| scalar_t const* __restrict__ weight, // [hidden_size] | ||
| float const* scale_ub, float const var_epsilon, int32_t const hidden_size, | ||
| scalar_t* __restrict__ residual = nullptr) { | ||
| int32_t const input_stride, scalar_t* __restrict__ residual = nullptr) { | ||
| float rms = 0.0f; | ||
| float token_scale = 0.0f; | ||
|
|
||
| // Compute rms | ||
| vllm::vectorized::compute_rms<scalar_t, has_residual>( | ||
| &rms, input, hidden_size, var_epsilon, residual); | ||
| &rms, input, hidden_size, input_stride, var_epsilon, residual); | ||
|
|
||
| // Compute scale | ||
| vllm::vectorized::compute_dynamic_per_token_scales<scalar_t, scalar_out_t, | ||
| has_residual>( | ||
| &token_scale, scales, input, weight, rms, scale_ub, hidden_size, | ||
| residual); | ||
| input_stride, residual); | ||
|
|
||
| // RMS Norm + Quant | ||
| if constexpr (std::is_same_v<scalar_out_t, int8_t>) { | ||
| token_scale = 1.0f / token_scale; | ||
| vllm::vectorized::norm_and_quant<scalar_t, scalar_out_t, true, | ||
| has_residual>( | ||
| out, input, weight, rms, &token_scale, hidden_size, residual); | ||
| has_residual>(out, input, weight, rms, | ||
| &token_scale, hidden_size, | ||
| input_stride, residual); | ||
| } else { | ||
| // FP8 - Do not invert token_scale for exact match with FBGemm | ||
| vllm::vectorized::norm_and_quant<scalar_t, scalar_out_t, false, | ||
| has_residual>( | ||
| out, input, weight, rms, &token_scale, hidden_size, residual); | ||
| has_residual>(out, input, weight, rms, | ||
| &token_scale, hidden_size, | ||
| input_stride, residual); | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -51,38 +53,40 @@ __global__ void rms_norm_dynamic_per_token_quant_kernel( | |
| scalar_t const* __restrict__ input, // [..., hidden_size] | ||
| scalar_t const* __restrict__ weight, // [hidden_size] | ||
| float const* scale_ub, float const var_epsilon, int32_t const hidden_size, | ||
| scalar_t* __restrict__ residual = nullptr) { | ||
| int32_t const input_stride, scalar_t* __restrict__ residual = nullptr) { | ||
| // For vectorization, token_input and token_output pointers need to be | ||
| // aligned at 8-byte and 4-byte addresses respectively. | ||
| bool const can_vectorize = hidden_size % 4 == 0; | ||
| bool const can_vectorize = hidden_size % 4 == 0 and input_stride % 4 == 0; | ||
|
|
||
| if (can_vectorize) { | ||
| return rms_norm_dynamic_per_token_quant_vec<scalar_t, scalar_out_t, | ||
| has_residual>( | ||
| out, scales, input, weight, scale_ub, var_epsilon, hidden_size, | ||
| residual); | ||
| input_stride, residual); | ||
| } | ||
|
|
||
| float rms = 0.0f; | ||
| float token_scale = 0.0f; | ||
|
|
||
| // Compute RMS | ||
| vllm::compute_rms<scalar_t, has_residual>(&rms, input, hidden_size, | ||
| var_epsilon, residual); | ||
| vllm::compute_rms<scalar_t, has_residual>( | ||
| &rms, input, hidden_size, input_stride, var_epsilon, residual); | ||
| // Compute Scale | ||
| vllm::compute_dynamic_per_token_scales<scalar_t, scalar_out_t, has_residual>( | ||
| &token_scale, scales, input, weight, rms, scale_ub, hidden_size, | ||
| residual); | ||
| input_stride, residual); | ||
|
|
||
| // RMS Norm + Quant | ||
| if constexpr (std::is_same_v<scalar_out_t, int8_t>) { | ||
| token_scale = 1.0f / token_scale; | ||
| vllm::norm_and_quant<scalar_t, scalar_out_t, true, has_residual>( | ||
| out, input, weight, rms, &token_scale, hidden_size, residual); | ||
| out, input, weight, rms, &token_scale, hidden_size, input_stride, | ||
| residual); | ||
| } else { | ||
| // FP8 - Do not invert token_scale for exact match with FBGemm | ||
| vllm::norm_and_quant<scalar_t, scalar_out_t, false, has_residual>( | ||
| out, input, weight, rms, &token_scale, hidden_size, residual); | ||
| out, input, weight, rms, &token_scale, hidden_size, input_stride, | ||
| residual); | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -97,19 +101,20 @@ __global__ void rms_norm_per_block_quant_kernel( | |
| scalar_t const* __restrict__ input, // [..., hidden_size] | ||
| scalar_t const* __restrict__ weight, // [hidden_size] | ||
| float const* scale_ub, float const var_epsilon, int32_t const hidden_size, | ||
| scalar_t* __restrict__ residual = nullptr, int64_t outer_scale_stride = 1) { | ||
| int32_t const input_stride, scalar_t* __restrict__ residual = nullptr, | ||
| int64_t outer_scale_stride = 1) { | ||
| float rms; | ||
| // Compute RMS | ||
| // Always able to vectorize due to constraints on hidden_size and input_stride | ||
| vllm::vectorized::compute_rms<scalar_t, has_residual>( | ||
| &rms, input, hidden_size, var_epsilon, residual); | ||
| &rms, input, hidden_size, input_stride, var_epsilon, residual); | ||
|
|
||
| // Compute Scale | ||
| // Always able to vectorize due to constraints on hidden_size and group_size | ||
| vllm::vectorized::compute_dynamic_per_token_scales< | ||
| scalar_t, scalar_out_t, has_residual, is_scale_transposed, group_size>( | ||
| nullptr, scales, input, weight, rms, scale_ub, hidden_size, residual, | ||
| outer_scale_stride); | ||
| nullptr, scales, input, weight, rms, scale_ub, hidden_size, input_stride, | ||
| residual, outer_scale_stride); | ||
|
|
||
| // RMS Norm + Quant | ||
| // Always able to vectorize due to constraints on hidden_size | ||
|
|
@@ -120,7 +125,7 @@ __global__ void rms_norm_per_block_quant_kernel( | |
| vllm::vectorized::norm_and_quant< | ||
| scalar_t, scalar_out_t, std::is_same_v<scalar_out_t, int8_t>, | ||
| has_residual, is_scale_transposed, group_size>( | ||
| out, input, weight, rms, scales, hidden_size, residual, | ||
| out, input, weight, rms, scales, hidden_size, input_stride, residual, | ||
| outer_scale_stride); | ||
| } | ||
|
|
||
|
|
@@ -137,6 +142,7 @@ void rms_norm_dynamic_per_token_quant_dispatch( | |
| std::optional<at::Tensor> const& scale_ub, | ||
| std::optional<at::Tensor>& residual) { | ||
| int32_t hidden_size = input.size(-1); | ||
| int32_t input_stride = input.view({-1, hidden_size}).stride(0); | ||
ProExpertProg marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| auto num_tokens = input.numel() / hidden_size; | ||
|
|
||
| dim3 grid(num_tokens); | ||
|
|
@@ -153,7 +159,7 @@ void rms_norm_dynamic_per_token_quant_dispatch( | |
| out.data_ptr<scalar_t>(), scales.data_ptr<float>(), | ||
| input.data_ptr<scalar_in_t>(), weight.data_ptr<scalar_in_t>(), | ||
| scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr, | ||
| var_epsilon, hidden_size, | ||
| var_epsilon, hidden_size, input_stride, | ||
| has_residual ? residual->data_ptr<scalar_in_t>() : nullptr); | ||
| }); | ||
| }); | ||
|
|
@@ -170,7 +176,9 @@ void rms_norm_dynamic_per_token_quant( | |
| ? c10::ScalarType::Float8_e4m3fn | ||
| : c10::ScalarType::Float8_e4m3fnuz; | ||
| TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8); | ||
| TORCH_CHECK(out.is_contiguous() && input.is_contiguous()); | ||
| TORCH_CHECK(out.is_contiguous()); | ||
| TORCH_CHECK(input.stride(-1) == 1, | ||
| "Input must be contiguous in the last dimension"); | ||
|
|
||
| if (scale_ub.has_value()) { | ||
| TORCH_CHECK(out.dtype() == kFp8Type); | ||
|
|
@@ -179,6 +187,7 @@ void rms_norm_dynamic_per_token_quant( | |
| TORCH_CHECK(scales.dtype() == torch::kFloat32); | ||
| if (residual) { | ||
| TORCH_CHECK(residual->scalar_type() == input.scalar_type()); | ||
| TORCH_CHECK(residual->is_contiguous()); | ||
| } | ||
|
|
||
| VLLM_DISPATCH_FLOATING_TYPES( | ||
|
|
@@ -200,6 +209,15 @@ void rms_norm_per_block_quant_dispatch( | |
| std::optional<at::Tensor> const& scale_ub, | ||
| std::optional<at::Tensor>& residual, bool is_scale_transposed) { | ||
| int32_t hidden_size = input.size(-1); | ||
| int32_t input_stride = input.view({-1, hidden_size}).stride(0); | ||
ProExpertProg marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| TORCH_CHECK(hidden_size % 4 == 0, | ||
| "Hidden size must be divisible by 4 for vectorized access"); | ||
| TORCH_CHECK(input_stride % 4 == 0, | ||
| "Input stride must be divisible by 4 for vectorized access"); | ||
| TORCH_CHECK(group_size % 4 == 0, | ||
| "Group size must be divisible by 4 for vectorized access"); | ||
|
|
||
|
Comment on lines
+214
to
+220
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit: maybe replace with [suggested snippet appears to be missing from this capture]
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. These constraints are actually required by the kernel in terms of vectorization though, right? I'll add the group size check as well.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Yep, you can keep these checks as they are if you think it makes them easier to read this way |
||
| auto num_tokens = input.numel() / hidden_size; | ||
|
|
||
| dim3 grid(num_tokens); | ||
|
|
@@ -225,7 +243,7 @@ void rms_norm_per_block_quant_dispatch( | |
| weight.data_ptr<scalar_in_t>(), | ||
| scale_ub.has_value() ? scale_ub->data_ptr<float>() | ||
| : nullptr, | ||
| var_epsilon, hidden_size, | ||
| var_epsilon, hidden_size, input_stride, | ||
| has_residual ? residual->data_ptr<scalar_in_t>() | ||
| : nullptr, | ||
| scales.stride(1)); | ||
|
|
@@ -246,7 +264,9 @@ void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input, | |
| ? c10::ScalarType::Float8_e4m3fn | ||
| : c10::ScalarType::Float8_e4m3fnuz; | ||
| TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8); | ||
| TORCH_CHECK(out.is_contiguous() && input.is_contiguous()); | ||
| TORCH_CHECK(out.is_contiguous()); | ||
| TORCH_CHECK(input.stride(-1) == 1, | ||
| "Input must be contiguous in the last dimension"); | ||
|
|
||
| if (scale_ub.has_value()) { | ||
| TORCH_CHECK(out.dtype() == kFp8Type); | ||
|
|
@@ -255,6 +275,7 @@ void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input, | |
| TORCH_CHECK(scales.dtype() == torch::kFloat32); | ||
| if (residual) { | ||
| TORCH_CHECK(residual->scalar_type() == input.scalar_type()); | ||
| TORCH_CHECK(residual->is_contiguous()); | ||
| } | ||
|
|
||
| TORCH_CHECK(group_size == 128 || group_size == 64, | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.