Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions evaluation/deepseek_fp4/launch_deepseekr1_fp4_DP_EP.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ export VLLM_ROCM_USE_TRITON_ROPE=1 # add for acc
export VLLM_DISABLE_COMPILE_CACHE=1
# FIXME: for now disable fp4 asm gemm because of running issue
export VLLM_ROCM_USE_AITER_FP4_ASM_GEMM=0
export VLLM_ROCM_USE_AITER_BMM=1
export VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS=0 # disable for acc

export TRITON_HIP_ASYNC_COPY_BYPASS_PERMUTE=1
Expand Down
1 change: 1 addition & 0 deletions evaluation/deepseek_fp4/launch_deepseekr1_fp4_TP.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ export VLLM_ROCM_USE_TRITON_ROPE=1 # add for acc
export VLLM_DISABLE_COMPILE_CACHE=1
# FIXME: for now disable fp4 asm gemm because of running issue
export VLLM_ROCM_USE_AITER_FP4_ASM_GEMM=0
export VLLM_ROCM_USE_AITER_BMM=1
export VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS=0 # disable for acc

export TRITON_HIP_ASYNC_COPY_BYPASS_PERMUTE=1
Expand Down
8 changes: 4 additions & 4 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@
VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: bool = False
VLLM_ROCM_USE_TRITON_ROPE: bool = True
VLLM_ROCM_USE_AITER_CUSTOM_ALL_REDUCE: bool = True
VLLM_ROCM_USE_AITER_FP8BMM: bool = True
VLLM_ROCM_USE_AITER_BMM: bool = True
VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: bool = False
VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: bool = False
VLLM_ROCM_USE_SKINNY_GEMM: bool = True
Expand Down Expand Up @@ -929,8 +929,8 @@ def get_vllm_port() -> int | None:
),
# Whether to use aiter triton fp8 bmm kernel
# By default is enabled.
"VLLM_ROCM_USE_AITER_FP8BMM": lambda: (
os.getenv("VLLM_ROCM_USE_AITER_FP8BMM", "True").lower() in ("true", "1")
"VLLM_ROCM_USE_AITER_BMM": lambda: (
os.getenv("VLLM_ROCM_USE_AITER_BMM", "True").lower() in ("true", "1")
),
# Use AITER triton unified attention for V1 attention
"VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION": lambda: (
Expand Down Expand Up @@ -1579,7 +1579,7 @@ def compute_hash() -> str:
"VLLM_ROCM_USE_AITER_MHA",
"VLLM_ROCM_USE_AITER_FP4_ASM_GEMM",
"VLLM_ROCM_USE_TRITON_ROPE",
"VLLM_ROCM_USE_AITER_FP8BMM",
"VLLM_ROCM_USE_AITER_BMM",
"VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION",
"VLLM_ROCM_USE_SKINNY_GEMM",
"VLLM_ROCM_FP8_PADDING",
Expand Down
107 changes: 107 additions & 0 deletions vllm/model_executor/layers/quantization/quark/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from typing import Any

import regex as re
import torch
from aiter.ops.triton.quant import dynamic_mxfp4_quant


def deep_compare(dict1: Any, dict2: Any) -> bool:
Expand Down Expand Up @@ -103,3 +105,108 @@ def _is_equal_or_regex_match(
elif target == value:
return True
return False


def quant_to_mxfp4(x):
    """Quantize a 3-D tensor to packed MXFP4.

    Args:
        x: tensor of shape ``(h, b, d)``; ``d`` must be divisible by 32
           (MXFP4 carries one shared scale per 32-element block).

    Returns:
        Tuple ``(packed, scales)``:
          - ``packed``: shape ``(h, b, d // 2)`` — two 4-bit values per byte,
          - ``scales``: shape ``(h, b, d // 32)`` — one block scale per
            32 elements.
    """
    h, b, d = x.shape
    # dynamic_mxfp4_quant operates on a 2-D input: collapse the leading
    # dims, quantize, then restore the (h, b, ...) layout on both outputs.
    x, x_scales = dynamic_mxfp4_quant(x.reshape(-1, d))
    return x.view(h, b, d // 2), x_scales.view(h, b, d // 32)


def dequant_mxfp4_to_fp32(x, is_threed):
    """Dequantize packed MXFP4 (two fp4 codes per uint8) to float32.

    Args:
        x: uint8 tensor; each byte packs two fp4 (e2m1) codes, low nibble
           first. ``is_threed`` selects ellipsis indexing for 3-D inputs
           vs. plain 2-D indexing.
        is_threed: True when ``x`` is 3-D.

    Returns:
        float32 tensor with the last dim doubled (one element per nibble).
        Values are the raw fp4 magnitudes; block scales are NOT applied
        here (see convert_e8m0_to_fp32 for the scale decode).
    """
    # Duplicate each byte so even slots can hold the low nibble and odd
    # slots the high nibble.
    x = x.repeat_interleave(2, dim=-1)
    if is_threed:
        x[..., ::2] = x[..., ::2] & 0xF
        x[..., 1::2] = x[..., 1::2] >> 4
    else:
        x[:, ::2] = x[:, ::2] & 0xF
        x[:, 1::2] = x[:, 1::2] >> 4

    # fp4 e2m1 decode table, indexed by the 4-bit code (sign bit = code >= 8).
    mxfp4_in_f32 = torch.tensor(
        [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
         -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0],
        dtype=torch.float32,
        # Bug fix: was hard-coded to device="cuda", which broke CPU inputs
        # and non-default CUDA devices. Follow the input's device instead.
        device=x.device,
    )
    return mxfp4_in_f32[x.long()]


def convert_e8m0_to_fp32(x):
    """Convert e8m0 exponent codes to float32.

    e8m0 is an 8-bit exponent-only format (bias 127, no mantissa), so a
    code ``e`` decodes to ``2 ** (e - 127)``. The all-ones code (255) is
    the NaN sentinel per the OCP MX spec.

    Args:
        x: integer tensor of e8m0 codes.

    Returns:
        float32 tensor of decoded scales, with code 255 mapped to NaN.
    """
    # Decode 2^(e - 127) in float32.
    x_f32 = 2 ** ((x.to(torch.float32)) - 127)

    # Bug fix: the NaN sentinel is the *code* 255, not the decoded value
    # 128. The old check `x_f32 == 128` wrongly NaN'd code 134
    # (2^(134-127) == 128) and left code 255 as +inf.
    x_f32[x == 255] = float("nan")
    return x_f32


def quark_post_load_weights(
qk_nope_head_dim: int,
v_head_dim: int,
weight: torch.Tensor,
weight_scale: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Post load weights for quark MXFP4 BMM
"""

def _quant_and_split_weight(loaded_weight: torch.Tensor):
W_UK, W_UV = loaded_weight.unflatten(
0, (-1, (qk_nope_head_dim + v_head_dim))
).split([qk_nope_head_dim, v_head_dim], dim=1)
W_UK, W_UK_scale = quant_to_mxfp4(W_UK.transpose(-2, -1))
W_UV, W_UV_scale = quant_to_mxfp4(W_UV)
W_UK_scale = W_UK_scale.contiguous()
W_UV_scale = W_UV_scale.contiguous()
return W_UK, W_UK_scale, W_UV, W_UV_scale

# weight: [kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim)]
# for the model with BF16 weight to use MXFP4 BMM,
# quant the weight to U8 packed format(MXFP4*2)
if weight.dtype == torch.bfloat16:
W_UK, W_UK_scale, W_UV, W_UV_scale = _quant_and_split_weight(weight)
elif weight.dtype == torch.uint8:
assert weight_scale is not None, (
"[Error][ROCm] weight_scale is required for U8 weight"
)
weight = dequant_mxfp4_to_fp32(weight, True).to(torch.bfloat16)
weight_scale = weight_scale.repeat_interleave(32, dim=-1)
weight_scale = convert_e8m0_to_fp32(weight_scale).to(torch.bfloat16)
weight = weight * weight_scale
W_UK, W_UK_scale, W_UV, W_UV_scale = _quant_and_split_weight(weight)
else:
raise ValueError("[Error][ROCm] Unsupported weight dtype: ", weight.dtype)

return W_UK, W_UK_scale, W_UV, W_UV_scale
Loading