From 36046a940e0d38675871f2076bdb322ce7833b63 Mon Sep 17 00:00:00 2001 From: ZhiweiYan-96 Date: Wed, 5 Nov 2025 02:28:40 +0000 Subject: [PATCH 1/3] Precompile, decode forward --- vllm/v1/attention/backends/mla/common.py | 69 +++++++++++++++++++++++- 1 file changed, 67 insertions(+), 2 deletions(-) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 45a1ee858ddd..34a39ed8aa67 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -230,7 +230,7 @@ split_decodes_and_prefills, ) from vllm.v1.kv_cache_interface import AttentionSpec - +from aiter.ops.triton.batched_gemm_afp4wfp4_pre_quant import batched_gemm_afp4wfp4_pre_quant class QueryLenSupport(Enum): """Defines the level of query length support for an attention backend's @@ -1141,7 +1141,9 @@ def get_and_maybe_dequant_weights(layer: LinearBase): # we currently do not have quantized bmm's which are needed for # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform # the bmm's in 16-bit, the extra memory overhead of this is fairly low + print("self.kv_b_proj:", self.kv_b_proj) kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T + print("self.kv_b_proj after dequant:", self.kv_b_proj) assert kv_b_proj_weight.shape == ( self.kv_lora_rank, self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), @@ -1165,6 +1167,8 @@ def get_and_maybe_dequant_weights(layer: LinearBase): if is_rocm_aiter_fp8bmm_enabled(): W_K = W_UK.transpose(0, 1) # 16 512 128 W_V = W_UV.permute(1, 2, 0) # 16 128 512 + print("W_K.dtype:", W_K.dtype) + print("W_V.dtype:", W_V.dtype) self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant( W_K, dtype=current_platform.fp8_dtype() ) @@ -1545,6 +1549,60 @@ def get_and_maybe_dequant_weights(layer: LinearBase): return dequant_weights.T return layer.weight + print("self.kv_b_proj:", self.kv_b_proj.shape) + print("self.qv_nope_head_dim:", self.qk_nope_head_dim, "self.v_head_dim:", 
self.v_head_dim) + if self.kv_b_proj.weight.dtype == torch.uint8: # mxfp4 elemnts packed in a byte + kv_b_proj_weight = self.kv_b_proj.weight.view( + self.kv_lora_rank, + self.num_heads, + self.qv_nope_head_dim // 2 + self.v_head_dim // 2, + ) + W_UK, W_UV = kv_b_proj_weight.split( + [self.qk_nope_head_dim // 2, self.v_head_dim // 2], dim=-1 + ) + if is_global_first_rank(): + max_batch_size = 1024 # [ToDo] Find the optimal upper limit + pre_compilation_list = list(range(1, max_batch_size + 1)) + pre_compilation_list = tqdm( + pre_compilation_list, + desc="[Aiter Triton] Pre-compiling fp8 BMM kernel", + total=max_batch_size, + ) + for m in pre_compilation_list: + x = torch.empty( + (self.W_K.shape[0], m, self.W_K.shape[2]), + dtype=torch.bfloat16, + device=self.W_K.device, + ) + out = torch.empty( + x.shape[0], x.shape[1], x.shape[2], device=x.device, dtype=torch.bfloat16 + ) + + # aiter_triton_fp8_bmm( + # x, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True + # ) + batched_gemm_afp4wfp4_pre_quant( + x, + self.W_K, + self.W_k_scale, + torch.bfloat16 + ) + + x = torch.empty( + (self.W_V.shape[0], m, self.W_V.shape[2]), + dtype=torch.bfloat16, + device=self.W_V.device, + ) + batched_gemm_afp4wfp4_pre_quant( + x, + self.W_V, + self.W_V_scale, + torch.bfloat16 + ) + # aiter_triton_fp8_bmm( + # x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True + # ) + # we currently do not have quantized bmm's which are needed for # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform # the bmm's in 16-bit, the extra memory overhead of this is fairly low @@ -1971,7 +2029,14 @@ def forward( decode_pe_padded.copy_(decode_q_pe) decode_q_pe = decode_pe_padded - if is_rocm_aiter_fp8bmm_enabled(): + if self.kv_b_proj.weight.dtype == torch.uint8: + decode_ql_nope = batched_gemm_afp4wfp4_pre_quant( + decode_q_nope, + self.W_K, + self.W_k_scale, + torch.bfloat16 + ) + elif is_rocm_aiter_fp8bmm_enabled(): # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, 
N, L) decode_ql_nope = aiter_triton_fp8_bmm( decode_q_nope, From e484e3c1c70850ec594b53afc714eacf84da1217 Mon Sep 17 00:00:00 2001 From: ZhiweiYan-96 Date: Thu, 6 Nov 2025 06:16:10 +0000 Subject: [PATCH 2/3] [fp4 bmm] runnable version --- .../deepseek_fp4/launch_deepseekr1_fp4_TP.sh | 1 - vllm/v1/attention/backends/mla/common.py | 182 +++++++++++++----- 2 files changed, 133 insertions(+), 50 deletions(-) diff --git a/evaluation/deepseek_fp4/launch_deepseekr1_fp4_TP.sh b/evaluation/deepseek_fp4/launch_deepseekr1_fp4_TP.sh index 26ef8cd4dce9..efa80c226d33 100644 --- a/evaluation/deepseek_fp4/launch_deepseekr1_fp4_TP.sh +++ b/evaluation/deepseek_fp4/launch_deepseekr1_fp4_TP.sh @@ -31,7 +31,6 @@ echo "running $model_path" # FIXME: for now use 0.8 for memory utilization vllm serve $model_path \ --host localhost \ - --port 9000 \ --tensor-parallel-size 8 \ --max-num-batched-tokens 32768 \ --trust-remote-code \ diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 34a39ed8aa67..bbd0c20b753d 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -232,6 +232,9 @@ from vllm.v1.kv_cache_interface import AttentionSpec from aiter.ops.triton.batched_gemm_afp4wfp4_pre_quant import batched_gemm_afp4wfp4_pre_quant + +ENABLE_FP4=True + class QueryLenSupport(Enum): """Defines the level of query length support for an attention backend's decode pipeline. 
@@ -275,7 +278,8 @@ def is_rocm_aiter_fp8bmm_enabled() -> bool: current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER_FP8BMM and envs.VLLM_ROCM_USE_AITER - ) + and (not ENABLE_FP4) + ) if is_rocm_aiter_fp8bmm_enabled(): @@ -1141,9 +1145,9 @@ def get_and_maybe_dequant_weights(layer: LinearBase): # we currently do not have quantized bmm's which are needed for # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform # the bmm's in 16-bit, the extra memory overhead of this is fairly low - print("self.kv_b_proj:", self.kv_b_proj) + print("kv_b_proj_wieght:", self.kv_b_proj) kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T - print("self.kv_b_proj after dequant:", self.kv_b_proj) + print("kv_b_proj_weight after dequant:", kv_b_proj_weight.dtype) assert kv_b_proj_weight.shape == ( self.kv_lora_rank, self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), @@ -1167,8 +1171,8 @@ def get_and_maybe_dequant_weights(layer: LinearBase): if is_rocm_aiter_fp8bmm_enabled(): W_K = W_UK.transpose(0, 1) # 16 512 128 W_V = W_UV.permute(1, 2, 0) # 16 128 512 - print("W_K.dtype:", W_K.dtype) - print("W_V.dtype:", W_V.dtype) + print("W_k dtype:", W_K.dtype) + print("W_V dtype:", W_V.dtype) self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant( W_K, dtype=current_platform.fp8_dtype() ) @@ -1215,9 +1219,36 @@ def get_and_maybe_dequant_weights(layer: LinearBase): def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor): # Convert from (B, N, L) to (N, B, L) + # x [num_heads, batch_size, kv_lora_rank] x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) + print("[Unified Path]", "out shape:", out.shape, "out dtype:", out.dtype) - if is_rocm_aiter_fp8bmm_enabled(): + if self.W_V.dtype == torch.uint8: + out = out.view(-1, self.num_heads, self.v_head_dim) + # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V) + # x [num_heads, batch_size, kv_lora_rank] + # WV [num_heads, v_head_dim // 2, kv_lora_rank] + out_buffer = torch.empty( + 
x.shape[0], # num_heads + x.shape[1], # batchsize + self.W_V.shape[2] * 2, # v + device=x.device, + dtype=torch.bfloat16) + print("In _v_up_proj:") + print("x.shape:", x.shape, " self.W_V.shape:", self.W_V.shape, + "out_buffer.shape:", out_buffer.shape, " out.shape:", + out.shape) + batched_gemm_afp4wfp4_pre_quant( + x, + self.W_V, + self.W_V_scale, + torch.bfloat16, + out_buffer + ) + out_buffer = out_buffer.transpose(0, 1) # [batchsize, num_heads, v] + #out = out.transpose(0, 1) # [num_heads, batch_size, v] + out.copy_(out_buffer) + elif is_rocm_aiter_fp8bmm_enabled() and (not ENABLE_FP4): out = out.view(-1, self.num_heads, self.v_head_dim) # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V) x = aiter_triton_fp8_bmm( @@ -1549,64 +1580,110 @@ def get_and_maybe_dequant_weights(layer: LinearBase): return dequant_weights.T return layer.weight - print("self.kv_b_proj:", self.kv_b_proj.shape) - print("self.qv_nope_head_dim:", self.qk_nope_head_dim, "self.v_head_dim:", self.v_head_dim) - if self.kv_b_proj.weight.dtype == torch.uint8: # mxfp4 elemnts packed in a byte - kv_b_proj_weight = self.kv_b_proj.weight.view( + print("self.kv_b_proj:", self.kv_b_proj.weight.shape) + print("self.qk_nope_head_dim:", self.qk_nope_head_dim, "self.v_head_dim:", self.v_head_dim) + + if self.kv_b_proj.weight.dtype == torch.uint8 and ENABLE_FP4: # mxfp4 elemnts packed in a byte + # self.kv_b_proj [num_heads * (qk_nope_head_dim + v_head_dim), q_lora_rank] + kv_b_proj_weight = self.kv_b_proj.weight.T + kv_b_proj_weight = kv_b_proj_weight.reshape( self.kv_lora_rank, self.num_heads, - self.qv_nope_head_dim // 2 + self.v_head_dim // 2, + self.qk_nope_head_dim // 2 + self.v_head_dim // 2, ) W_UK, W_UV = kv_b_proj_weight.split( [self.qk_nope_head_dim // 2, self.v_head_dim // 2], dim=-1 ) + # W_K [self.kv_lora_rank, num_heads, qk_nope_head_dim // 2] -> [num_heads, kv_lora_rank, qk_nope_head_dim //2] + self.W_K = W_UK.transpose(0, 1) + # W_V [kv_lora_rank, num_heads, v_head_dim // 2] 
-> [num_heads, v_head_dim // 2, kv_lora_rank] + #self.W_V = W_UV.permute(1, 2, 0) + self.W_V = W_UV.transpose(0, 1) + + # split w_scale + kv_b_proj_weight_sc = self.kv_b_proj.weight_scale + print("kv_b_proj_weight_sc.shape:", kv_b_proj_weight_sc.shape) + # Shape should be [num_headsx(qk_nope_head_dim+v_head_dim), kv_lora_rank // 32] + kv_b_proj_weight_sc = self.kv_b_proj.weight_scale.T.reshape( + self.kv_lora_rank, + self.num_heads, + self.qk_nope_head_dim // 32 + self.v_head_dim // 32 + ) + # self.W_K_scale [kv_lora_rank, num_heads, qk_nope_head_dim //32] + self.W_K_scale, self.W_V_scale = kv_b_proj_weight_sc.split( + [self.qk_nope_head_dim // 32, self.v_head_dim // 32], dim=-1) + self.W_K_scale = self.W_K_scale.transpose(0, 1) + self.W_V_scale = self.W_V_scale.permute(1, 2, 0) + max_batch_size = 1024 # [ToDo] Find the optimal upper limit + pre_compilation_list = list(range(1, max_batch_size + 1)) if is_global_first_rank(): - max_batch_size = 1024 # [ToDo] Find the optimal upper limit - pre_compilation_list = list(range(1, max_batch_size + 1)) pre_compilation_list = tqdm( pre_compilation_list, desc="[Aiter Triton] Pre-compiling fp8 BMM kernel", total=max_batch_size, ) - for m in pre_compilation_list: - x = torch.empty( - (self.W_K.shape[0], m, self.W_K.shape[2]), - dtype=torch.bfloat16, - device=self.W_K.device, - ) - out = torch.empty( - x.shape[0], x.shape[1], x.shape[2], device=x.device, dtype=torch.bfloat16 - ) + for m in pre_compilation_list: + #print("Pre-Compiling first kernel", flush=True) + #print("self.W_K.shape:", self.W_K.shape, flush=True) + #print("self.W_K_scale.shape:", self.W_K_scale.shape, flush=True) + # [ num_heads, m, qk_nope_head_dim // 2 * 2] + x = torch.empty( + (self.W_K.shape[0], m, self.W_K.shape[2] * 2), + dtype=torch.bfloat16, + device=self.W_K.device, + ) + #print("x.shape:", x.shape, flush=True) + # x shape [ num_heads, m , qk_nope_head_dim //2 * 2] + # W_K shape [num_heads, kv_lora_ranks, qk_nope_head_dim //2] + out = torch.empty( + 
x.shape[0], x.shape[1], self.W_K.shape[1], device=x.device, dtype=torch.bfloat16 + ) + #print("out.shape:", out.shape, flush=True) - # aiter_triton_fp8_bmm( - # x, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True - # ) - batched_gemm_afp4wfp4_pre_quant( - x, - self.W_K, - self.W_k_scale, - torch.bfloat16 - ) + # self.W_K [kv_lora_rank, num_heads, qk_nope_head_dim //32] + + batched_gemm_afp4wfp4_pre_quant( + x, + self.W_K, + self.W_K_scale, + torch.bfloat16, + out + ) + + print("Pre-Compiling second kernel", flush=True) + ## x [ num_heads, m, kv_lora_rank] + x = torch.empty( + (self.W_V.shape[0], m, self.W_V.shape[2] * 2), + dtype=torch.bfloat16, + device=self.W_V.device, + ) + print("x.shape:", x.shape, flush=True) + ## [num_heads, m, kv_lora_rank] x [ num_heads, v_head_dim // 2, kv_lora_rank] + ## [num_heads, m, v_head_dim //2] + out = torch.empty( + x.shape[0], x.shape[1], self.W_K.shape[1], device=x.device, dtype=torch.bfloat16) + print("out.shape:", out.shape, flush=True) + print("self.W_V.shape:", self.W_V.shape, flush=True) + print("self.W_V_scale.shape:", self.W_V_scale.shape, flush=True) + batched_gemm_afp4wfp4_pre_quant( + x, + self.W_V, + self.W_V_scale, + torch.bfloat16, + out + ) + return - x = torch.empty( - (self.W_V.shape[0], m, self.W_V.shape[2]), - dtype=torch.bfloat16, - device=self.W_V.device, - ) - batched_gemm_afp4wfp4_pre_quant( - x, - self.W_V, - self.W_V_scale, - torch.bfloat16 - ) - # aiter_triton_fp8_bmm( - # x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True - # ) # we currently do not have quantized bmm's which are needed for # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform # the bmm's in 16-bit, the extra memory overhead of this is fairly low + print("self.kv_b_proj:", self.kv_b_proj.weight.dtype) + print("self.kv_b_proj.shape:", self.kv_b_proj.weight.shape) + print("self.qk_nope_head_dim: ", self.qk_nope_head_dim, " self.v_head_dim: ", self.v_head_dim) + print("self.kv_lora_rank:", 
self.kv_lora_rank, " self.num_heads: ", self.num_heads) kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T + print("kv_b_proj_weight dtype after dequant:", kv_b_proj_weight.shape) assert kv_b_proj_weight.shape == ( self.kv_lora_rank, self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), @@ -1630,6 +1707,8 @@ def get_and_maybe_dequant_weights(layer: LinearBase): if is_rocm_aiter_fp8bmm_enabled(): W_K = W_UK.transpose(0, 1) # 16 512 128 W_V = W_UV.permute(1, 2, 0) # 16 128 512 + print("W_k.dtype:", W_K.dtype) + print("W_V.dtype:", W_V.dtype) self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant( W_K, dtype=current_platform.fp8_dtype() ) @@ -2029,13 +2108,17 @@ def forward( decode_pe_padded.copy_(decode_q_pe) decode_q_pe = decode_pe_padded - if self.kv_b_proj.weight.dtype == torch.uint8: - decode_ql_nope = batched_gemm_afp4wfp4_pre_quant( + if self.kv_b_proj.weight.dtype == torch.uint8 and ENABLE_FP4: + decode_ql_nope = torch.empty(decode_q_nope.shape[0], decode_q_nope.shape[1], self.W_K.shape[1], dtype=torch.bfloat16, device=decode_q_nope.device) + batched_gemm_afp4wfp4_pre_quant( decode_q_nope, self.W_K, - self.W_k_scale, - torch.bfloat16 + self.W_K_scale, + torch.bfloat16, + decode_ql_nope ) + decode_ql_nope = decode_ql_nope.transpose(0, 1) + print("[FP4 Path] decode_ql_nope.shape:", decode_ql_nope.shape, "dtype:", decode_ql_nope.dtype) elif is_rocm_aiter_fp8bmm_enabled(): # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L) decode_ql_nope = aiter_triton_fp8_bmm( @@ -2045,6 +2128,7 @@ def forward( group_size=128, transpose_bm=True, ) + print("[FP8 Path] decode_ql_nope.shape:", decode_ql_nope.shape, "dtype:", decode_ql_nope.dtype) else: # Pads the head_dim if necessary (for the underlying kernel) N, B, P = decode_q_nope.shape From 95308fd90169f0082d906366fdfa9aa3d201a665 Mon Sep 17 00:00:00 2001 From: ZhiweiYan-96 Date: Fri, 7 Nov 2025 02:26:27 +0000 Subject: [PATCH 3/3] clean code --- vllm/v1/attention/backends/mla/common.py | 
70 ++++++++---------------- 1 file changed, 23 insertions(+), 47 deletions(-) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index bbd0c20b753d..342d8e29e980 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -233,8 +233,6 @@ from aiter.ops.triton.batched_gemm_afp4wfp4_pre_quant import batched_gemm_afp4wfp4_pre_quant -ENABLE_FP4=True - class QueryLenSupport(Enum): """Defines the level of query length support for an attention backend's decode pipeline. @@ -1221,7 +1219,6 @@ def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor): # Convert from (B, N, L) to (N, B, L) # x [num_heads, batch_size, kv_lora_rank] x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) - print("[Unified Path]", "out shape:", out.shape, "out dtype:", out.dtype) if self.W_V.dtype == torch.uint8: out = out.view(-1, self.num_heads, self.v_head_dim) @@ -1231,13 +1228,9 @@ def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor): out_buffer = torch.empty( x.shape[0], # num_heads x.shape[1], # batchsize - self.W_V.shape[2] * 2, # v + self.W_V.shape[1], # v device=x.device, dtype=torch.bfloat16) - print("In _v_up_proj:") - print("x.shape:", x.shape, " self.W_V.shape:", self.W_V.shape, - "out_buffer.shape:", out_buffer.shape, " out.shape:", - out.shape) batched_gemm_afp4wfp4_pre_quant( x, self.W_V, @@ -1246,7 +1239,6 @@ def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor): out_buffer ) out_buffer = out_buffer.transpose(0, 1) # [batchsize, num_heads, v] - #out = out.transpose(0, 1) # [num_heads, batch_size, v] out.copy_(out_buffer) elif is_rocm_aiter_fp8bmm_enabled() and (not ENABLE_FP4): out = out.view(-1, self.num_heads, self.v_head_dim) @@ -1580,11 +1572,8 @@ def get_and_maybe_dequant_weights(layer: LinearBase): return dequant_weights.T return layer.weight - print("self.kv_b_proj:", self.kv_b_proj.weight.shape) - print("self.qk_nope_head_dim:", self.qk_nope_head_dim, 
"self.v_head_dim:", self.v_head_dim) - - if self.kv_b_proj.weight.dtype == torch.uint8 and ENABLE_FP4: # mxfp4 elemnts packed in a byte - # self.kv_b_proj [num_heads * (qk_nope_head_dim + v_head_dim), q_lora_rank] + if self.kv_b_proj.weight.dtype == torch.uint8: # mxfp4 elements packed in a byte + # kv_b_proj [num_heads * (qk_nope_head_dim + v_head_dim), q_lora_rank] kv_b_proj_weight = self.kv_b_proj.weight.T kv_b_proj_weight = kv_b_proj_weight.reshape( self.kv_lora_rank, @@ -1596,24 +1585,25 @@ def get_and_maybe_dequant_weights(layer: LinearBase): ) # W_K [self.kv_lora_rank, num_heads, qk_nope_head_dim // 2] -> [num_heads, kv_lora_rank, qk_nope_head_dim //2] self.W_K = W_UK.transpose(0, 1) - # W_V [kv_lora_rank, num_heads, v_head_dim // 2] -> [num_heads, v_head_dim // 2, kv_lora_rank] - #self.W_V = W_UV.permute(1, 2, 0) - self.W_V = W_UV.transpose(0, 1) + # W_V [kv_lora_rank, num_heads, v_head_dim // 2] -> [num_heads, v_head_dim, kv_lora_rank // 2] + # Always pack along the last dimension; need to check accuracy here.
+ self.W_V = W_UV.permute(1, 2, 0) + self.W_V = self.W_V.reshape(self.num_heads, self.v_head_dim, self.kv_lora_rank // 2) - # split w_scale kv_b_proj_weight_sc = self.kv_b_proj.weight_scale - print("kv_b_proj_weight_sc.shape:", kv_b_proj_weight_sc.shape) - # Shape should be [num_headsx(qk_nope_head_dim+v_head_dim), kv_lora_rank // 32] - kv_b_proj_weight_sc = self.kv_b_proj.weight_scale.T.reshape( - self.kv_lora_rank, + # kv_b_proj_weight_sc: [num_heads x (qk_nope_head_dim+v_head_dim), kv_lora_rank // 32] + + # Obtain W_V_Scale first + W_scale = self.kv_b_proj.weight_scale.view( self.num_heads, - self.qk_nope_head_dim // 32 + self.v_head_dim // 32 - ) - # self.W_K_scale [kv_lora_rank, num_heads, qk_nope_head_dim //32] - self.W_K_scale, self.W_V_scale = kv_b_proj_weight_sc.split( - [self.qk_nope_head_dim // 32, self.v_head_dim // 32], dim=-1) - self.W_K_scale = self.W_K_scale.transpose(0, 1) - self.W_V_scale = self.W_V_scale.permute(1, 2, 0) + self.qk_nope_head_dim + self.v_head_dim, + self.kv_lora_rank // 32) + self.W_K_scale, self.W_V_scale = W_scale.split([self.qk_nope_head_dim, self.v_head_dim], dim=1) + + # Obtain W_K_scale + self.W_K_scale = self.W_K_scale.view(self.num_heads, self.qk_nope_head_dim//32, self.kv_lora_rank) + self.W_K_scale = self.W_K_scale.permute(0, 2, 1) + max_batch_size = 1024 # [ToDo] Find the optimal upper limit pre_compilation_list = list(range(1, max_batch_size + 1)) if is_global_first_rank(): @@ -1623,24 +1613,17 @@ def get_and_maybe_dequant_weights(layer: LinearBase): total=max_batch_size, ) for m in pre_compilation_list: - #print("Pre-Compiling first kernel", flush=True) - #print("self.W_K.shape:", self.W_K.shape, flush=True) - #print("self.W_K_scale.shape:", self.W_K_scale.shape, flush=True) # [ num_heads, m, qk_nope_head_dim // 2 * 2] x = torch.empty( (self.W_K.shape[0], m, self.W_K.shape[2] * 2), dtype=torch.bfloat16, device=self.W_K.device, ) - #print("x.shape:", x.shape, flush=True) # x shape [ num_heads, m , qk_nope_head_dim //2 
* 2] # W_K shape [num_heads, kv_lora_ranks, qk_nope_head_dim //2] out = torch.empty( x.shape[0], x.shape[1], self.W_K.shape[1], device=x.device, dtype=torch.bfloat16 ) - #print("out.shape:", out.shape, flush=True) - - # self.W_K [kv_lora_rank, num_heads, qk_nope_head_dim //32] batched_gemm_afp4wfp4_pre_quant( x, @@ -1650,21 +1633,16 @@ def get_and_maybe_dequant_weights(layer: LinearBase): out ) - print("Pre-Compiling second kernel", flush=True) ## x [ num_heads, m, kv_lora_rank] x = torch.empty( - (self.W_V.shape[0], m, self.W_V.shape[2] * 2), + (self.W_V.shape[0], m, self.W_V.shape[2] ** 2), dtype=torch.bfloat16, device=self.W_V.device, ) - print("x.shape:", x.shape, flush=True) ## [num_heads, m, kv_lora_rank] x [ num_heads, v_head_dim // 2, kv_lora_rank] ## [num_heads, m, v_head_dim //2] out = torch.empty( - x.shape[0], x.shape[1], self.W_K.shape[1], device=x.device, dtype=torch.bfloat16) - print("out.shape:", out.shape, flush=True) - print("self.W_V.shape:", self.W_V.shape, flush=True) - print("self.W_V_scale.shape:", self.W_V_scale.shape, flush=True) + x.shape[0], x.shape[1], self.W_V.shape[1], device=x.device, dtype=torch.bfloat16) batched_gemm_afp4wfp4_pre_quant( x, self.W_V, @@ -1672,9 +1650,9 @@ def get_and_maybe_dequant_weights(layer: LinearBase): torch.bfloat16, out ) + # Early return, the left is for fp8 scenario. 
return - # we currently do not have quantized bmm's which are needed for # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform # the bmm's in 16-bit, the extra memory overhead of this is fairly low @@ -2108,7 +2086,7 @@ def forward( decode_pe_padded.copy_(decode_q_pe) decode_q_pe = decode_pe_padded - if self.kv_b_proj.weight.dtype == torch.uint8 and ENABLE_FP4: + if self.kv_b_proj.weight.dtype == torch.uint8: decode_ql_nope = torch.empty(decode_q_nope.shape[0], decode_q_nope.shape[1], self.W_K.shape[1], dtype=torch.bfloat16, device=decode_q_nope.device) batched_gemm_afp4wfp4_pre_quant( decode_q_nope, @@ -2118,7 +2096,6 @@ def forward( decode_ql_nope ) decode_ql_nope = decode_ql_nope.transpose(0, 1) - print("[FP4 Path] decode_ql_nope.shape:", decode_ql_nope.shape, "dtype:", decode_ql_nope.dtype) elif is_rocm_aiter_fp8bmm_enabled(): # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L) decode_ql_nope = aiter_triton_fp8_bmm( @@ -2128,7 +2105,6 @@ def forward( group_size=128, transpose_bm=True, ) - print("[FP8 Path] decode_ql_nope.shape:", decode_ql_nope.shape, "dtype:", decode_ql_nope.dtype) else: # Pads the head_dim if necessary (for the underlying kernel) N, B, P = decode_q_nope.shape