diff --git a/evaluation/deepseek_fp4/launch_deepseekr1_fp4_TP.sh b/evaluation/deepseek_fp4/launch_deepseekr1_fp4_TP.sh
index 26ef8cd4dce9..efa80c226d33 100644
--- a/evaluation/deepseek_fp4/launch_deepseekr1_fp4_TP.sh
+++ b/evaluation/deepseek_fp4/launch_deepseekr1_fp4_TP.sh
@@ -31,7 +31,6 @@ echo "running $model_path"
 # FIXME: for now use 0.8 for memory utilization
 vllm serve $model_path \
     --host localhost \
-    --port 9000 \
     --tensor-parallel-size 8 \
     --max-num-batched-tokens 32768 \
     --trust-remote-code \
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index 45a1ee858ddd..342d8e29e980 100755
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -230,6 +230,7 @@
     split_decodes_and_prefills,
 )
 from vllm.v1.kv_cache_interface import AttentionSpec
+from aiter.ops.triton.batched_gemm_afp4wfp4_pre_quant import batched_gemm_afp4wfp4_pre_quant


 class QueryLenSupport(Enum):
@@ -275,7 +276,8 @@ def is_rocm_aiter_fp8bmm_enabled() -> bool:
         current_platform.is_rocm()
         and envs.VLLM_ROCM_USE_AITER_FP8BMM
         and envs.VLLM_ROCM_USE_AITER
-    )
+        and (not ENABLE_FP4)
+    )


 if is_rocm_aiter_fp8bmm_enabled():
@@ -1141,7 +1143,9 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
         # we currently do not have quantized bmm's which are needed for
         # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform
         # the bmm's in 16-bit, the extra memory overhead of this is fairly low
+        print("kv_b_proj:", self.kv_b_proj)
         kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
+        print("kv_b_proj_weight dtype after dequant:", kv_b_proj_weight.dtype)
         assert kv_b_proj_weight.shape == (
             self.kv_lora_rank,
             self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
@@ -1165,6 +1169,8 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
         if is_rocm_aiter_fp8bmm_enabled():
             W_K = W_UK.transpose(0, 1)  # 16 512 128
             W_V = W_UV.permute(1, 2, 0)  # 16 128 512
+            print("W_K dtype:", W_K.dtype)
+            print("W_V dtype:", W_V.dtype)
             self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant(
                 W_K, dtype=current_platform.fp8_dtype()
             )
@@ -1211,9 +1217,30 @@ def get_and_maybe_dequant_weights(layer: LinearBase):

     def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
         # Convert from (B, N, L) to (N, B, L)
+        # x: [num_heads, batch_size, kv_lora_rank] after the transpose
         x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
-        if is_rocm_aiter_fp8bmm_enabled():
+        if self.W_V.dtype == torch.uint8:
+            out = out.view(-1, self.num_heads, self.v_head_dim)
+            # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V)
+            # x: [num_heads, batch_size, kv_lora_rank]
+            # W_V: [num_heads, v_head_dim, kv_lora_rank // 2] (fp4, two elements packed per byte)
+            out_buffer = torch.empty(
+                x.shape[0],  # num_heads
+                x.shape[1],  # batch_size
+                self.W_V.shape[1],  # v_head_dim
+                device=x.device,
+                dtype=torch.bfloat16)
+            batched_gemm_afp4wfp4_pre_quant(
+                x,
+                self.W_V,
+                self.W_V_scale,
+                torch.bfloat16,
+                out_buffer
+            )
+            out_buffer = out_buffer.transpose(0, 1)  # [batch_size, num_heads, v_head_dim]
+            out.copy_(out_buffer)
+        elif is_rocm_aiter_fp8bmm_enabled() and (not ENABLE_FP4):
             out = out.view(-1, self.num_heads, self.v_head_dim)
             # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V)
             x = aiter_triton_fp8_bmm(
@@ -1545,10 +1572,96 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
                 return dequant_weights.T
             return layer.weight

+        if self.kv_b_proj.weight.dtype == torch.uint8:  # mxfp4: two fp4 elements packed per byte
+            # kv_b_proj.weight: [num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank // 2] (fp4 packed)
+            kv_b_proj_weight = self.kv_b_proj.weight.T
+            kv_b_proj_weight = kv_b_proj_weight.reshape(
+                self.kv_lora_rank,
+                self.num_heads,
+                self.qk_nope_head_dim // 2 + self.v_head_dim // 2,
+            )
+            W_UK, W_UV = kv_b_proj_weight.split(
+                [self.qk_nope_head_dim // 2, self.v_head_dim // 2], dim=-1
+            )
+            # W_K: [kv_lora_rank, num_heads, qk_nope_head_dim // 2] -> [num_heads, kv_lora_rank, qk_nope_head_dim // 2]
+            self.W_K = W_UK.transpose(0, 1)
+            # W_V: [kv_lora_rank, num_heads, v_head_dim // 2] -> [num_heads, v_head_dim, kv_lora_rank // 2]
+            # Always keep the packed fp4 pair on the last dimension; accuracy of this layout still needs checking.
+            self.W_V = W_UV.permute(1, 2, 0)
+            self.W_V = self.W_V.reshape(self.num_heads, self.v_head_dim, self.kv_lora_rank // 2)
+
+            kv_b_proj_weight_sc = self.kv_b_proj.weight_scale
+            # kv_b_proj_weight_sc: [num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank // 32]
+
+            # Split the per-32-element block scales into W_K_scale and W_V_scale
+            W_scale = self.kv_b_proj.weight_scale.view(
+                self.num_heads,
+                self.qk_nope_head_dim + self.v_head_dim,
+                self.kv_lora_rank // 32)
+            self.W_K_scale, self.W_V_scale = W_scale.split([self.qk_nope_head_dim, self.v_head_dim], dim=1)

+            # Rearrange W_K_scale to match the [num_heads, kv_lora_rank, ...] layout of W_K
+            self.W_K_scale = self.W_K_scale.view(self.num_heads, self.qk_nope_head_dim // 32, self.kv_lora_rank)
+            self.W_K_scale = self.W_K_scale.permute(0, 2, 1)
+
+            max_batch_size = 1024  # TODO: find the optimal upper limit
+            pre_compilation_list = list(range(1, max_batch_size + 1))
+            if is_global_first_rank():
+                pre_compilation_list = tqdm(
+                    pre_compilation_list,
+                    desc="[Aiter Triton] Pre-compiling fp4 BMM kernel",
+                    total=max_batch_size,
+                )
+            for m in pre_compilation_list:
+                # x: [num_heads, m, qk_nope_head_dim] (unpacked bf16 activation)
+                x = torch.empty(
+                    (self.W_K.shape[0], m, self.W_K.shape[2] * 2),
+                    dtype=torch.bfloat16,
+                    device=self.W_K.device,
+                )
+                # x shape: [num_heads, m, qk_nope_head_dim // 2 * 2]
+                # W_K shape: [num_heads, kv_lora_rank, qk_nope_head_dim // 2]
+                out = torch.empty(
+                    x.shape[0], x.shape[1], self.W_K.shape[1], device=x.device, dtype=torch.bfloat16
+                )
+
+                batched_gemm_afp4wfp4_pre_quant(
+                    x,
+                    self.W_K,
+                    self.W_K_scale,
+                    torch.bfloat16,
+                    out
+                )
+
+                # x: [num_heads, m, kv_lora_rank]
+                x = torch.empty(
+                    (self.W_V.shape[0], m, self.W_V.shape[2] * 2),
+                    dtype=torch.bfloat16,
+                    device=self.W_V.device,
+                )
+                # [num_heads, m, kv_lora_rank] x [num_heads, v_head_dim, kv_lora_rank // 2]
+                # -> [num_heads, m, v_head_dim]
+                out = torch.empty(
+                    x.shape[0], x.shape[1], self.W_V.shape[1], device=x.device, dtype=torch.bfloat16)
+                batched_gemm_afp4wfp4_pre_quant(
+                    x,
+                    self.W_V,
+                    self.W_V_scale,
+                    torch.bfloat16,
+                    out
+                )
+            # Early return; the code below handles the fp8 / bf16 path.
+            return
+
         # we currently do not have quantized bmm's which are needed for
         # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform
         # the bmm's in 16-bit, the extra memory overhead of this is fairly low
+        print("self.kv_b_proj:", self.kv_b_proj.weight.dtype)
+        print("self.kv_b_proj.shape:", self.kv_b_proj.weight.shape)
+        print("self.qk_nope_head_dim: ", self.qk_nope_head_dim, " self.v_head_dim: ", self.v_head_dim)
+        print("self.kv_lora_rank:", self.kv_lora_rank, " self.num_heads: ", self.num_heads)
         kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
+        print("kv_b_proj_weight shape after dequant:", kv_b_proj_weight.shape)
         assert kv_b_proj_weight.shape == (
             self.kv_lora_rank,
             self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
@@ -1572,6 +1685,8 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
         if is_rocm_aiter_fp8bmm_enabled():
             W_K = W_UK.transpose(0, 1)  # 16 512 128
             W_V = W_UV.permute(1, 2, 0)  # 16 128 512
+            print("W_K.dtype:", W_K.dtype)
+            print("W_V.dtype:", W_V.dtype)
             self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant(
                 W_K, dtype=current_platform.fp8_dtype()
             )
@@ -1971,7 +2086,17 @@ def forward(
             decode_pe_padded.copy_(decode_q_pe)
             decode_q_pe = decode_pe_padded

-            if is_rocm_aiter_fp8bmm_enabled():
+            if self.kv_b_proj.weight.dtype == torch.uint8:
+                decode_ql_nope = torch.empty(decode_q_nope.shape[0], decode_q_nope.shape[1], self.W_K.shape[1], dtype=torch.bfloat16, device=decode_q_nope.device)
+                batched_gemm_afp4wfp4_pre_quant(
+                    decode_q_nope,
+                    self.W_K,
+                    self.W_K_scale,
+                    torch.bfloat16,
+                    decode_ql_nope
+                )
+                decode_ql_nope = decode_ql_nope.transpose(0, 1)
+            elif is_rocm_aiter_fp8bmm_enabled():
                 # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L)
                 decode_ql_nope = aiter_triton_fp8_bmm(
                     decode_q_nope,
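
Note (not part of the patch): every fp4 branch above funnels through the same call pattern, batched_gemm_afp4wfp4_pre_quant(activation_bf16, weight_fp4_packed, weight_scale, torch.bfloat16, out). The sketch below is a minimal, standalone illustration of the shape contract the patch assumes for the decode W_K up-projection: a bf16 activation [num_heads, m, qk_nope_head_dim], an mxfp4 weight packed two elements per uint8 byte with shape [num_heads, kv_lora_rank, qk_nope_head_dim // 2], one block scale per 32 reduction elements, and a preallocated bf16 output. Only the shapes and argument order are taken from the patch; the placeholder tensor values, the uint8 (e8m0-style) scale dtype, and the "cuda" device are assumptions, and running it requires a ROCm build of aiter.

# Standalone shape sketch; values are placeholders, not real quantized weights.
import torch
from aiter.ops.triton.batched_gemm_afp4wfp4_pre_quant import batched_gemm_afp4wfp4_pre_quant

num_heads, m, kv_lora_rank, qk_nope_head_dim = 16, 8, 512, 128

# bf16 activation laid out (num_heads, batch, qk_nope_head_dim), as in the decode path.
x = torch.randn(num_heads, m, qk_nope_head_dim, dtype=torch.bfloat16, device="cuda")
# W_K packed fp4: two elements per uint8 byte along the reduction dimension (assumed layout).
w_k = torch.zeros(num_heads, kv_lora_rank, qk_nope_head_dim // 2, dtype=torch.uint8, device="cuda")
# One scale byte per 32-element block of the reduction dimension (e8m0-as-uint8 is an assumption).
w_k_scale = torch.ones(num_heads, kv_lora_rank, qk_nope_head_dim // 32, dtype=torch.uint8, device="cuda")

# Preallocated bf16 output, (num_heads, m, kv_lora_rank), matching the patch's usage.
decode_ql_nope = torch.empty(num_heads, m, kv_lora_rank, dtype=torch.bfloat16, device="cuda")
batched_gemm_afp4wfp4_pre_quant(x, w_k, w_k_scale, torch.bfloat16, decode_ql_nope)
print(decode_ql_nope.shape)  # torch.Size([16, 8, 512])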