1 change: 0 additions & 1 deletion evaluation/deepseek_fp4/launch_deepseekr1_fp4_TP.sh
@@ -31,7 +31,6 @@ echo "running $model_path"
# FIXME: for now use 0.8 for memory utilization
vllm serve $model_path \
--host localhost \
--port 9000 \
--tensor-parallel-size 8 \
--max-num-batched-tokens 32768 \
--trust-remote-code \
131 changes: 128 additions & 3 deletions vllm/v1/attention/backends/mla/common.py
@@ -230,6 +230,7 @@
split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec
from aiter.ops.triton.batched_gemm_afp4wfp4_pre_quant import batched_gemm_afp4wfp4_pre_quant
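# batched_gemm_afp4wfp4_pre_quant (AITER Triton): batched GEMM that quantizes the
# bf16 activation to fp4 on the fly and multiplies it against a byte-packed fp4
# weight with per-32-element block scales (behavior inferred from the call sites below).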


class QueryLenSupport(Enum):
@@ -275,8 +276,9 @@
current_platform.is_rocm()
and envs.VLLM_ROCM_USE_AITER_FP8BMM
and envs.VLLM_ROCM_USE_AITER
)
and (not ENABLE_FP4)
)

Check failure (GitHub Actions / pre-commit), Ruff F821: vllm/v1/attention/backends/mla/common.py:281:18: Undefined name `ENABLE_FP4`
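Both F821 failures in this PR point at `ENABLE_FP4` being used without a definition. A minimal sketch of one way to resolve it, assuming the flag is meant to be an environment toggle; the variable name `VLLM_ROCM_USE_AITER_FP4` is hypothetical, not an existing vllm env:

import os

# Hypothetical definition; the PR still needs to define or import ENABLE_FP4.
ENABLE_FP4 = os.environ.get("VLLM_ROCM_USE_AITER_FP4", "0") == "1"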

if is_rocm_aiter_fp8bmm_enabled():
from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import ( # noqa: E501
@@ -1141,7 +1143,9 @@
# we currently do not have quantized bmm's which are needed for
# `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform
# the bmm's in 16-bit, the extra memory overhead of this is fairly low
print("kv_b_proj_wieght:", self.kv_b_proj)
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
print("kv_b_proj_weight after dequant:", kv_b_proj_weight.dtype)
assert kv_b_proj_weight.shape == (
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
@@ -1165,6 +1169,8 @@
if is_rocm_aiter_fp8bmm_enabled():
W_K = W_UK.transpose(0, 1) # 16 512 128
W_V = W_UV.permute(1, 2, 0) # 16 128 512
print("W_k dtype:", W_K.dtype)
print("W_V dtype:", W_V.dtype)
self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant(
W_K, dtype=current_platform.fp8_dtype()
)
@@ -1211,11 +1217,32 @@

def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
# Convert from (B, N, L) to (N, B, L)
# x [num_heads, batch_size, kv_lora_rank]
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)

if is_rocm_aiter_fp8bmm_enabled():
if self.W_V.dtype == torch.uint8:
out = out.view(-1, self.num_heads, self.v_head_dim)
# Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V)
# x [num_heads, batch_size, kv_lora_rank]
# W_V [num_heads, v_head_dim, kv_lora_rank // 2] (fp4 packed along the last dim)
out_buffer = torch.empty(
x.shape[0], # num_heads
x.shape[1], # batchsize
self.W_V.shape[1], # v
device=x.device,
dtype=torch.bfloat16)
batched_gemm_afp4wfp4_pre_quant(
x,
self.W_V,
self.W_V_scale,
torch.bfloat16,
out_buffer
)
out_buffer = out_buffer.transpose(0, 1) # [batchsize, num_heads, v]
out.copy_(out_buffer)
elif is_rocm_aiter_fp8bmm_enabled() and (not ENABLE_FP4):
out = out.view(-1, self.num_heads, self.v_head_dim)
# Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V)

Check failure (GitHub Actions / pre-commit), Ruff F821: vllm/v1/attention/backends/mla/common.py:1245:54: Undefined name `ENABLE_FP4`
x = aiter_triton_fp8_bmm(
x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True, YQ=out
)
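For reference, a minimal bf16 sketch (not part of the diff) of the multiply-and-transpose both branches above implement, with the custom kernels replaced by a plain torch.bmm; the sizes are illustrative only:

import torch

N, B, L, V = 16, 4, 512, 128  # num_heads, batch, kv_lora_rank, v_head_dim (illustrative)
x = torch.randn(N, B, L, dtype=torch.bfloat16)    # attention output, (N, B, L)
W_V = torch.randn(N, L, V, dtype=torch.bfloat16)  # v up-projection weight, (N, L, V)
out = torch.bmm(x, W_V).transpose(0, 1)           # (N, B, V) -> (B, N, V)
assert out.shape == (B, N, V)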
@@ -1545,10 +1572,96 @@
return dequant_weights.T
return layer.weight

if self.kv_b_proj.weight.dtype == torch.uint8:  # two mxfp4 elements packed per byte
# kv_b_proj weight: [num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
kv_b_proj_weight = self.kv_b_proj.weight.T
kv_b_proj_weight = kv_b_proj_weight.reshape(
self.kv_lora_rank,
self.num_heads,
self.qk_nope_head_dim // 2 + self.v_head_dim // 2,
)
W_UK, W_UV = kv_b_proj_weight.split(
[self.qk_nope_head_dim // 2, self.v_head_dim // 2], dim=-1
)
# W_K [self.kv_lora_rank, num_heads, qk_nope_head_dim // 2] -> [num_heads, kv_lora_rank, qk_nope_head_dim //2]
self.W_K = W_UK.transpose(0, 1)
# W_V [kv_lora_rank, num_heads, v_head_dim // 2] -> [num_heads, v_head_dim // 2, kv_lora_rank],
# then reshaped below to [num_heads, v_head_dim, kv_lora_rank // 2]

Check failure (GitHub Actions / pre-commit), Ruff E501: vllm/v1/attention/backends/mla/common.py:1588:89: Line too long (122 > 88)
# Always pack along the last dimension; accuracy still needs to be checked here.
self.W_V = W_UV.permute(1, 2, 0)

Check failure (GitHub Actions / pre-commit), Ruff E501: vllm/v1/attention/backends/mla/common.py:1590:89: Line too long (106 > 88)
self.W_V = self.W_V.reshape(self.num_heads, self.v_head_dim, self.kv_lora_rank // 2)
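# The reshape reinterprets the permuted bytes so the packed fp4 pairs land on the
# last (kv_lora_rank) dimension; this reinterpretation is exactly what the
# accuracy check flagged above needs to validate.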

kv_b_proj_weight_sc = self.kv_b_proj.weight_scale

Check failure (GitHub Actions / pre-commit), Ruff E501: vllm/v1/attention/backends/mla/common.py:1593:89: Line too long (96 > 88)
# kv_b_proj_weight_sc: [num_heads x (qk_nope_head_dim+v_head_dim), kv_lora_rank // 32]

Check failure (GitHub Actions / pre-commit), Ruff F841: vllm/v1/attention/backends/mla/common.py:1595:13: Local variable `kv_b_proj_weight_sc` is assigned to but never used
# Obtain W_V_Scale first

Check failure (GitHub Actions / pre-commit), Ruff E501: vllm/v1/attention/backends/mla/common.py:1596:89: Line too long (98 > 88)
W_scale = self.kv_b_proj.weight_scale.view(
self.num_heads,
self.qk_nope_head_dim + self.v_head_dim,
self.kv_lora_rank // 32)
self.W_K_scale, self.W_V_scale = W_scale.split([self.qk_nope_head_dim, self.v_head_dim], dim=1)
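# W_K_scale: [num_heads, qk_nope_head_dim, kv_lora_rank // 32]
# W_V_scale: [num_heads, v_head_dim, kv_lora_rank // 32]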

# Obtain W_K_scale

Check failure (GitHub Actions / pre-commit), Ruff E501: vllm/v1/attention/backends/mla/common.py:1603:89: Line too long (107 > 88)
self.W_K_scale = self.W_K_scale.view(self.num_heads, self.qk_nope_head_dim//32, self.kv_lora_rank)
self.W_K_scale = self.W_K_scale.permute(0, 2, 1)
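# W_K_scale is now [num_heads, kv_lora_rank, qk_nope_head_dim // 32]; the view
# above appears to reinterpret the per-head scale block rather than transpose it,
# so this also belongs in the accuracy check noted earlier.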

Check failure (GitHub Actions / pre-commit), Ruff E501: vllm/v1/attention/backends/mla/common.py:1606:89: Line too long (110 > 88)
max_batch_size = 1024 # [ToDo] Find the optimal upper limit
pre_compilation_list = list(range(1, max_batch_size + 1))
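# Pre-compiling the Triton kernel for each batch size up front avoids paying
# JIT compilation latency on the first real decode request of that size.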
if is_global_first_rank():
pre_compilation_list = tqdm(
pre_compilation_list,
desc="[Aiter Triton] Pre-compiling fp8 BMM kernel",
total=max_batch_size,
)
for m in pre_compilation_list:
# [ num_heads, m, qk_nope_head_dim // 2 * 2]
x = torch.empty(
(self.W_K.shape[0], m, self.W_K.shape[2] * 2),
dtype=torch.bfloat16,
device=self.W_K.device,
)
# x shape [ num_heads, m , qk_nope_head_dim //2 * 2]
# W_K shape [num_heads, kv_lora_ranks, qk_nope_head_dim //2]
out = torch.empty(
x.shape[0], x.shape[1], self.W_K.shape[1], device=x.device, dtype=torch.bfloat16
)

Check failure (GitHub Actions / pre-commit), Ruff E501: vllm/v1/attention/backends/mla/common.py:1627:89: Line too long (100 > 88)
batched_gemm_afp4wfp4_pre_quant(
x,
self.W_K,
self.W_K_scale,
torch.bfloat16,
out
)

## x [ num_heads, m, kv_lora_rank]
x = torch.empty(
(self.W_V.shape[0], m, self.W_V.shape[2] * 2),
dtype=torch.bfloat16,
device=self.W_V.device,
)
## [num_heads, m, kv_lora_rank] x [ num_heads, v_head_dim // 2, kv_lora_rank]
## [num_heads, m, v_head_dim //2]
out = torch.empty(
x.shape[0], x.shape[1], self.W_V.shape[1], device=x.device, dtype=torch.bfloat16)
batched_gemm_afp4wfp4_pre_quant(
x,
self.W_V,
self.W_V_scale,
torch.bfloat16,
out
)
# Early return; everything below is the fp8 path.
return

# we currently do not have quantized bmm's which are needed for
# `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform
# the bmm's in 16-bit, the extra memory overhead of this is fairly low
print("self.kv_b_proj:", self.kv_b_proj.weight.dtype)
print("self.kv_b_proj.shape:", self.kv_b_proj.weight.shape)
print("self.qk_nope_head_dim: ", self.qk_nope_head_dim, " self.v_head_dim: ", self.v_head_dim)
print("self.kv_lora_rank:", self.kv_lora_rank, " self.num_heads: ", self.num_heads)
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
print("kv_b_proj_weight dtype after dequant:", kv_b_proj_weight.shape)
assert kv_b_proj_weight.shape == (
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
@@ -1572,6 +1685,8 @@
if is_rocm_aiter_fp8bmm_enabled():
W_K = W_UK.transpose(0, 1) # 16 512 128
W_V = W_UV.permute(1, 2, 0) # 16 128 512
print("W_k.dtype:", W_K.dtype)
print("W_V.dtype:", W_V.dtype)
self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant(
W_K, dtype=current_platform.fp8_dtype()
)
@@ -1971,7 +2086,17 @@
decode_pe_padded.copy_(decode_q_pe)
decode_q_pe = decode_pe_padded

if is_rocm_aiter_fp8bmm_enabled():
if self.kv_b_proj.weight.dtype == torch.uint8:
decode_ql_nope = torch.empty(
    decode_q_nope.shape[0],
    decode_q_nope.shape[1],
    self.W_K.shape[1],
    dtype=torch.bfloat16,
    device=decode_q_nope.device,
)
batched_gemm_afp4wfp4_pre_quant(
decode_q_nope,
self.W_K,
self.W_K_scale,
torch.bfloat16,
decode_ql_nope
)
decode_ql_nope = decode_ql_nope.transpose(0, 1)
elif is_rocm_aiter_fp8bmm_enabled():
# Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L)
decode_ql_nope = aiter_triton_fp8_bmm(
decode_q_nope,