Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions .buildkite/test_areas/compile.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ steps:
- nvidia-smi
# Run all models and attn backends but only Inductor partition and native custom ops
- pytest -v -s tests/compile/fusions_e2e/test_tp1_quant.py -k "inductor_partition and not +rms_norm and not +quant_fp8"
# Qwen requires +quant_fp8 as -quant_fp8 rms+quant fusion is not supported
- pytest -v -s tests/compile/fusions_e2e/test_tp1_quant.py -k "inductor_partition and not +rms_norm and +quant_fp8 and qwen3"
# Qwen/Deepseek require +quant_fp8 because -quant_fp8 rms+quant fusion is not supported
- pytest -v -s tests/compile/fusions_e2e/test_tp1_quant.py -k "inductor_partition and not +rms_norm and +quant_fp8 and (qwen3 or deepseek)"

- label: Fusion E2E Config Sweep (H100)
timeout_in_minutes: 30
Expand Down Expand Up @@ -132,9 +132,9 @@ steps:
commands:
- nvidia-smi
# Run all models but only FLASHINFER, Inductor partition and native custom ops
# Qwen requires +quant_fp8 as -quant_fp8 rms+quant fusion is not supported
# Qwen/Deepseek require +quant_fp8 because -quant_fp8 rms+quant fusion is not supported
# Run just llama3 (fp8 & fp4) for all config combinations (only inductor partition)
- pytest -v -s tests/compile/fusions_e2e/test_tp1_quant.py -k "inductor_partition and (FLASHINFER and not +rms_norm and (not +quant_fp8 or +quant_fp8 and qwen3) or llama-3)"
- pytest -v -s tests/compile/fusions_e2e/test_tp1_quant.py -k "inductor_partition and (FLASHINFER and not +rms_norm and (not +quant_fp8 or +quant_fp8 and (qwen3 or deepseek)) or llama-3)"

- label: Fusion E2E TP2 Quick (H100)
timeout_in_minutes: 20
Expand All @@ -150,8 +150,8 @@ steps:
commands:
- nvidia-smi
# Run all models and attn backends but only Inductor partition and native custom ops
- pytest -v -s tests/compile/fusions_e2e/test_tp2_ar_rms.py -k "inductor_partition and not +rms_norm and not +quant_fp8"
- pytest -v -s tests/compile/fusions_e2e/test_tp2_async_tp.py -k "inductor_partition and not +rms_norm and not +quant_fp8"
- pytest -v -s tests/compile/fusions_e2e/test_tp2_ar_rms.py -k "inductor_partition and not +rms_norm and (not +quant_fp8 or +quant_fp8 and (qwen3 or deepseek))"
- pytest -v -s tests/compile/fusions_e2e/test_tp2_async_tp.py -k "inductor_partition and not +rms_norm and (not +quant_fp8 or +quant_fp8 and (qwen3 or deepseek))"

- label: Fusion E2E TP2 AR-RMS Config Sweep (H100)
timeout_in_minutes: 40
Expand Down Expand Up @@ -205,7 +205,7 @@ steps:
commands:
- nvidia-smi
# Run all models but only FLASHINFER, Inductor partition and native custom ops
# include qwen with +quant_fp8 as -quant_fp8 rms+quant fusion is not supported
# Include qwen/deepseek with +quant_fp8 because -quant_fp8 rms+quant fusion is not supported
# for ar-rms-quant-fp4, also sweep llama3
- pytest -v -s tests/compile/fusions_e2e/test_tp2_ar_rms.py -k "(FLASHINFER and inductor_partition and not +rms_norm and (not +quant_fp8 or +quant_fp8 and qwen3)) or Llama-3.1-8B-Instruct-FP4"
- pytest -v -s tests/compile/fusions_e2e/test_tp2_async_tp.py -k "FLASHINFER and inductor_partition and not +rms_norm and (not +quant_fp8 or +quant_fp8 and qwen3)"
- pytest -v -s tests/compile/fusions_e2e/test_tp2_ar_rms.py -k "(FLASHINFER and inductor_partition and not +rms_norm and (not +quant_fp8 or +quant_fp8 and (qwen3 or deepseek))) or Llama-3.1-8B-Instruct-FP4"
- pytest -v -s tests/compile/fusions_e2e/test_tp2_async_tp.py -k "FLASHINFER and inductor_partition and not +rms_norm and (not +quant_fp8 or +quant_fp8 and (qwen3 or deepseek))"
Original file line number Diff line number Diff line change
Expand Up @@ -15,31 +15,33 @@ __device__ void rms_norm_dynamic_per_token_quant_vec(
scalar_t const* __restrict__ input, // [..., hidden_size]
scalar_t const* __restrict__ weight, // [hidden_size]
float const* scale_ub, float const var_epsilon, int32_t const hidden_size,
scalar_t* __restrict__ residual = nullptr) {
int32_t const input_stride, scalar_t* __restrict__ residual = nullptr) {
float rms = 0.0f;
float token_scale = 0.0f;

// Compute rms
vllm::vectorized::compute_rms<scalar_t, has_residual>(
&rms, input, hidden_size, var_epsilon, residual);
&rms, input, hidden_size, input_stride, var_epsilon, residual);

// Compute scale
vllm::vectorized::compute_dynamic_per_token_scales<scalar_t, scalar_out_t,
has_residual>(
&token_scale, scales, input, weight, rms, scale_ub, hidden_size,
residual);
input_stride, residual);

// RMS Norm + Quant
if constexpr (std::is_same_v<scalar_out_t, int8_t>) {
token_scale = 1.0f / token_scale;
vllm::vectorized::norm_and_quant<scalar_t, scalar_out_t, true,
has_residual>(
out, input, weight, rms, &token_scale, hidden_size, residual);
has_residual>(out, input, weight, rms,
&token_scale, hidden_size,
input_stride, residual);
} else {
// FP8 - Do not invert token_scale for exact match with FBGemm
vllm::vectorized::norm_and_quant<scalar_t, scalar_out_t, false,
has_residual>(
out, input, weight, rms, &token_scale, hidden_size, residual);
has_residual>(out, input, weight, rms,
&token_scale, hidden_size,
input_stride, residual);
}
}

Expand All @@ -51,38 +53,40 @@ __global__ void rms_norm_dynamic_per_token_quant_kernel(
scalar_t const* __restrict__ input, // [..., hidden_size]
scalar_t const* __restrict__ weight, // [hidden_size]
float const* scale_ub, float const var_epsilon, int32_t const hidden_size,
scalar_t* __restrict__ residual = nullptr) {
int32_t const input_stride, scalar_t* __restrict__ residual = nullptr) {
// For vectorization, token_input and token_output pointers need to be
// aligned at 8-byte and 4-byte addresses respectively.
bool const can_vectorize = hidden_size % 4 == 0;
bool const can_vectorize = hidden_size % 4 == 0 and input_stride % 4 == 0;

if (can_vectorize) {
return rms_norm_dynamic_per_token_quant_vec<scalar_t, scalar_out_t,
has_residual>(
out, scales, input, weight, scale_ub, var_epsilon, hidden_size,
residual);
input_stride, residual);
}

float rms = 0.0f;
float token_scale = 0.0f;

// Compute RMS
vllm::compute_rms<scalar_t, has_residual>(&rms, input, hidden_size,
var_epsilon, residual);
vllm::compute_rms<scalar_t, has_residual>(
&rms, input, hidden_size, input_stride, var_epsilon, residual);
// Compute Scale
vllm::compute_dynamic_per_token_scales<scalar_t, scalar_out_t, has_residual>(
&token_scale, scales, input, weight, rms, scale_ub, hidden_size,
residual);
input_stride, residual);

// RMS Norm + Quant
if constexpr (std::is_same_v<scalar_out_t, int8_t>) {
token_scale = 1.0f / token_scale;
vllm::norm_and_quant<scalar_t, scalar_out_t, true, has_residual>(
out, input, weight, rms, &token_scale, hidden_size, residual);
out, input, weight, rms, &token_scale, hidden_size, input_stride,
residual);
} else {
// FP8 - Do not invert token_scale for exact match with FBGemm
vllm::norm_and_quant<scalar_t, scalar_out_t, false, has_residual>(
out, input, weight, rms, &token_scale, hidden_size, residual);
out, input, weight, rms, &token_scale, hidden_size, input_stride,
residual);
}
}

Expand All @@ -97,19 +101,20 @@ __global__ void rms_norm_per_block_quant_kernel(
scalar_t const* __restrict__ input, // [..., hidden_size]
scalar_t const* __restrict__ weight, // [hidden_size]
float const* scale_ub, float const var_epsilon, int32_t const hidden_size,
scalar_t* __restrict__ residual = nullptr, int64_t outer_scale_stride = 1) {
int32_t const input_stride, scalar_t* __restrict__ residual = nullptr,
int64_t outer_scale_stride = 1) {
float rms;
// Compute RMS
// Always able to vectorize due to constraints on hidden_size
vllm::vectorized::compute_rms<scalar_t, has_residual>(
&rms, input, hidden_size, var_epsilon, residual);
&rms, input, hidden_size, input_stride, var_epsilon, residual);

// Compute Scale
// Always able to vectorize due to constraints on hidden_size and group_size
vllm::vectorized::compute_dynamic_per_token_scales<
scalar_t, scalar_out_t, has_residual, is_scale_transposed, group_size>(
nullptr, scales, input, weight, rms, scale_ub, hidden_size, residual,
outer_scale_stride);
nullptr, scales, input, weight, rms, scale_ub, hidden_size, input_stride,
residual, outer_scale_stride);

// RMS Norm + Quant
// Always able to vectorize due to constraints on hidden_size
Expand All @@ -120,7 +125,7 @@ __global__ void rms_norm_per_block_quant_kernel(
vllm::vectorized::norm_and_quant<
scalar_t, scalar_out_t, std::is_same_v<scalar_out_t, int8_t>,
has_residual, is_scale_transposed, group_size>(
out, input, weight, rms, scales, hidden_size, residual,
out, input, weight, rms, scales, hidden_size, input_stride, residual,
outer_scale_stride);
}

Expand All @@ -137,6 +142,7 @@ void rms_norm_dynamic_per_token_quant_dispatch(
std::optional<at::Tensor> const& scale_ub,
std::optional<at::Tensor>& residual) {
int32_t hidden_size = input.size(-1);
int32_t input_stride = input.view({-1, hidden_size}).stride(0);
auto num_tokens = input.numel() / hidden_size;

dim3 grid(num_tokens);
Expand All @@ -153,7 +159,7 @@ void rms_norm_dynamic_per_token_quant_dispatch(
out.data_ptr<scalar_t>(), scales.data_ptr<float>(),
input.data_ptr<scalar_in_t>(), weight.data_ptr<scalar_in_t>(),
scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
var_epsilon, hidden_size,
var_epsilon, hidden_size, input_stride,
has_residual ? residual->data_ptr<scalar_in_t>() : nullptr);
});
});
Expand All @@ -170,7 +176,9 @@ void rms_norm_dynamic_per_token_quant(
? c10::ScalarType::Float8_e4m3fn
: c10::ScalarType::Float8_e4m3fnuz;
TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8);
TORCH_CHECK(out.is_contiguous() && input.is_contiguous());
TORCH_CHECK(out.is_contiguous());
TORCH_CHECK(input.stride(-1) == 1,
"Input must be contiguous in the last dimension");

if (scale_ub.has_value()) {
TORCH_CHECK(out.dtype() == kFp8Type);
Expand All @@ -179,6 +187,7 @@ void rms_norm_dynamic_per_token_quant(
TORCH_CHECK(scales.dtype() == torch::kFloat32);
if (residual) {
TORCH_CHECK(residual->scalar_type() == input.scalar_type());
TORCH_CHECK(residual->is_contiguous());
}

VLLM_DISPATCH_FLOATING_TYPES(
Expand All @@ -200,6 +209,15 @@ void rms_norm_per_block_quant_dispatch(
std::optional<at::Tensor> const& scale_ub,
std::optional<at::Tensor>& residual, bool is_scale_transposed) {
int32_t hidden_size = input.size(-1);
int32_t input_stride = input.view({-1, hidden_size}).stride(0);

TORCH_CHECK(hidden_size % 4 == 0,
"Hidden size must be divisible by 4 for vectorized access");
TORCH_CHECK(input_stride % 4 == 0,
"Input stride must be divisible by 4 for vectorized access");
TORCH_CHECK(group_size % 4 == 0,
"Group size must be divisible by 4 for vectorized access");

Comment on lines +214 to +220
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: maybe replace with hidden_size % group_size == 0 and group_size % 4 == 0? These are the constraints that led to always vectorizing the blockwise kernel in the first place

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These constraints are actually required by the kernel in terms of vectorization though, right?

I'll add the group size check as well

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These constraints are actually required by the kernel in terms of vectorization though, right?

Yep, you can keep these checks as they are if you think it makes them easier to read this way

auto num_tokens = input.numel() / hidden_size;

dim3 grid(num_tokens);
Expand All @@ -225,7 +243,7 @@ void rms_norm_per_block_quant_dispatch(
weight.data_ptr<scalar_in_t>(),
scale_ub.has_value() ? scale_ub->data_ptr<float>()
: nullptr,
var_epsilon, hidden_size,
var_epsilon, hidden_size, input_stride,
has_residual ? residual->data_ptr<scalar_in_t>()
: nullptr,
scales.stride(1));
Expand All @@ -246,7 +264,9 @@ void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input,
? c10::ScalarType::Float8_e4m3fn
: c10::ScalarType::Float8_e4m3fnuz;
TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8);
TORCH_CHECK(out.is_contiguous() && input.is_contiguous());
TORCH_CHECK(out.is_contiguous());
TORCH_CHECK(input.stride(-1) == 1,
"Input must be contiguous in the last dimension");

if (scale_ub.has_value()) {
TORCH_CHECK(out.dtype() == kFp8Type);
Expand All @@ -255,6 +275,7 @@ void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input,
TORCH_CHECK(scales.dtype() == torch::kFloat32);
if (residual) {
TORCH_CHECK(residual->scalar_type() == input.scalar_type());
TORCH_CHECK(residual->is_contiguous());
}

TORCH_CHECK(group_size == 128 || group_size == 64,
Expand Down
Loading