diff --git a/evaluation/deepseek_fp4/launch_deepseekr1_fp4_TP.sh b/evaluation/deepseek_fp4/launch_deepseekr1_fp4_TP.sh
index 26ef8cd4dce9..88277621d5fa 100644
--- a/evaluation/deepseek_fp4/launch_deepseekr1_fp4_TP.sh
+++ b/evaluation/deepseek_fp4/launch_deepseekr1_fp4_TP.sh
@@ -1,16 +1,16 @@
 export VLLM_USE_V1=1
-export VLLM_USE_TRITON_FLASH_ATTN=0
+export VLLM_USE_TRITON_FLASH_ATTN=1 # use triton mha
 # export VLLM_LOGGING_LEVEL=DEBUG
 export VLLM_RPC_TIMEOUT=1800000
 export VLLM_ROCM_USE_AITER=1
 export VLLM_ROCM_USE_AITER_MHA=0
-export VLLM_ROCM_USE_AITER_MLA=1
+export VLLM_ROCM_USE_AITER_MLA=0 # use triton mha
 export VLLM_ROCM_USE_AITER_MOE=1
 export VLLM_ROCM_USE_TRITON_ROPE=1 # add for acc
 export VLLM_DISABLE_COMPILE_CACHE=1
 # FIXME: for now disable fp4 asm gemm because of running issue
 export VLLM_ROCM_USE_AITER_FP4_ASM_GEMM=0
-#export VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS=0 # for now disable
+export VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS=0 # disable for acc
 export TRITON_HIP_ASYNC_COPY_BYPASS_PERMUTE=1
 export TRITON_HIP_USE_ASYNC_COPY=1

@@ -37,11 +37,12 @@
 vllm serve $model_path \
     --trust-remote-code \
     --no-enable-prefix-caching \
     --disable-log-requests \
-    --compilation-config '{"cudagraph_mode": "FULL_AND_PIECEWISE"}' \
-    --gpu_memory_utilization 0.8 \
+    --enforce-eager \
+    --gpu_memory_utilization 0.7 \
     --async-scheduling \
+    --block-size 16 \
     --load-format fastsafetensors \
     --seed 123 2>&1 | tee log.server.log &
-    # --enforce-eager \
+# --compilation-config '{"cudagraph_mode": "FULL_AND_PIECEWISE"}' \
 # --enable-expert-parallel \
diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py
index 781f77e96319..b2a1711613ad 100644
--- a/vllm/v1/attention/backends/mla/triton_mla.py
+++ b/vllm/v1/attention/backends/mla/triton_mla.py
@@ -129,14 +129,21 @@ def _flash_attn_varlen_diff_headdims(
                 q, k, v, softmax_scale=softmax_scale, **kwargs
             )
         else:
-            return super()._flash_attn_varlen_diff_headdims(
-                q,
-                k,
-                v,
-                return_softmax_lse=return_softmax_lse,
+            from aiter.ops.triton.mha import flash_attn_varlen_func
+
+            result = flash_attn_varlen_func(
+                q=q,
+                k=k,
+                v=v,
+                return_lse=return_softmax_lse,
                 softmax_scale=softmax_scale,
                 **kwargs,
            )
+            if type(result) is tuple and return_softmax_lse:
+                output, lse = result
+                lse = lse.T.contiguous()
+                return (output, lse)
+            return result

     def _forward_decode(
         self,
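
The triton_mla.py hunk routes the non-Triton-FA branch to aiter's Triton varlen attention and transposes the returned LSE before handing it back to the caller. Below is a minimal standalone sketch of that return-convention adapter; `fake_flash_attn_varlen` is a hypothetical stand-in for `aiter.ops.triton.mha.flash_attn_varlen_func`, and the (num_tokens, num_heads) LSE layout it returns is an assumption made for illustration, with the caller expected to receive (num_heads, num_tokens).

```python
# Sketch only: illustrates the tuple-unpack + LSE transpose pattern from the diff.
# fake_flash_attn_varlen is NOT a real aiter API; it mimics a kernel that returns
# (out, lse) with lse laid out as (num_tokens, num_heads) -- an assumption.
import torch


def fake_flash_attn_varlen(q, k, v, return_lse=False, softmax_scale=None, **kwargs):
    """Hypothetical kernel stand-in: returns out, and optionally an LSE tensor."""
    out = torch.zeros_like(q)
    if return_lse:
        lse = torch.zeros(q.shape[0], q.shape[1])  # (num_tokens, num_heads)
        return out, lse
    return out


def attn_with_lse_adapter(q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs):
    """Adapter: call the kernel, then flip the LSE layout for the caller."""
    result = fake_flash_attn_varlen(
        q, k, v, return_lse=return_softmax_lse, softmax_scale=softmax_scale, **kwargs
    )
    if isinstance(result, tuple) and return_softmax_lse:
        output, lse = result
        # Transpose to the (num_heads, num_tokens) layout the caller expects.
        return output, lse.T.contiguous()
    return result


q = torch.randn(8, 16, 64)  # (num_tokens, num_heads, head_dim)
out, lse = attn_with_lse_adapter(q, q, q, return_softmax_lse=True)
assert lse.shape == (16, 8)
```

Keeping the adaptation inside the wrapper means the caller's `return_softmax_lse` contract is unchanged; only the kernel call site and the LSE layout conversion differ.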