[OOT Plugin][Performance] Optimize metadata prepare#263
ganyi1996ppo wants to merge 4 commits into main
Conversation
Signed-off-by: ganyi <ygan@amd.com>
Can you help attach the accuracy check result for Qwen with OOT mode?
Pull request overview
Optimizes host-side attention metadata preparation for the OOT plugin (targeting faster decode-only and prefill-only paths) by avoiding unnecessary CPU computations and simplifying metadata.
Changes:
- Removed `min_query_len` from plugin-mode flash-attn metadata dataclasses.
- Added decode-only / prefill-only fast paths to skip `seq_lens.cpu()` and query-length derivation for non-mixed batches.
- Adjusted how `prefill_metadata` and `num_actual_kv_tokens` are produced during `build()`.
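The fast-path idea in the list above can be sketched roughly as follows. This is a minimal illustration, not the code in this PR: `BatchSummary`, `host_lengths`, and `fetch_seq_lens` are hypothetical names, and the callable only stands in for an expensive device-to-host copy such as `seq_lens.cpu()`. Extend requests are left out for brevity.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass
class BatchSummary:
    """Hypothetical stand-in for the fields a metadata builder reads."""

    num_decodes: int
    num_prefills: int
    max_query_len: int  # assumed to be known on the host already
    max_seq_len: int    # assumed to be known on the host already


def host_lengths(
    batch: BatchSummary,
    fetch_seq_lens: Callable[[], List[int]],
) -> Optional[List[int]]:
    """Return per-request sequence lengths only when the batch is mixed.

    fetch_seq_lens models the costly copy (for example seq_lens.cpu());
    the decode-only and prefill-only paths never call it.
    """
    mixed = batch.num_decodes > 0 and batch.num_prefills > 0
    if not mixed:
        # Non-mixed batch: the batch-level maxima are enough, skip the copy.
        return None
    return fetch_seq_lens()
```

In this sketch a decode-only batch returns `None` without invoking the callable, so the caller falls back to `max_query_len` and `max_seq_len`; only a mixed batch pays for the per-request lengths.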
atom/plugin/attention.py
Outdated
        else query_lens_cpu[:num_decodes].max().item()
    ),
    max_seq_len=(
        common_attn_metadata.max_seq_len
        if prefill_only
        else query_lens_cpu[:num_decodes].max().item()
prefill_metadata.max_seq_len is computed from query_lens_cpu[:num_decodes] in the mixed-request path. This is both the wrong slice (decode, not prefill) and the wrong source (query lengths, not seq lengths), and will under-report max_seqlen_k passed to flash_attn_varlen_func.
Compute it from the prefill portion’s sequence lengths (or equivalent prefill query lengths if those are guaranteed equal here).
Suggested change
-         else query_lens_cpu[:num_decodes].max().item()
-     ),
-     max_seq_len=(
-         common_attn_metadata.max_seq_len
-         if prefill_only
-         else query_lens_cpu[:num_decodes].max().item()
+         else query_lens_cpu[num_decodes + num_extends :].max().item()
+     ),
+     max_seq_len=(
+         common_attn_metadata.max_seq_len
+         if prefill_only
+         else seq_lens_cpu[num_decodes + num_extends :].max().item()
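To make the slicing concern concrete, here is a small sketch using plain Python lists instead of tensors. It assumes requests are ordered as decodes, then extends, then prefills, which is inferred from the `num_decodes + num_extends :` slice in the suggestion; `prefill_maxima` is a hypothetical helper, not a function in the codebase.

```python
from typing import List, Tuple


def prefill_maxima(
    query_lens: List[int],
    seq_lens: List[int],
    num_decodes: int,
    num_extends: int,
) -> Tuple[int, int]:
    """Return (max query length, max sequence length) over prefill requests.

    Assumes the batch layout [decodes | extends | prefills].
    """
    start = num_decodes + num_extends
    prefill_q = query_lens[start:]
    prefill_k = seq_lens[start:]
    if not prefill_q:
        raise ValueError("batch contains no prefill requests")
    return max(prefill_q), max(prefill_k)


# Two decodes, one extend, two prefills (made-up lengths).
query_lens = [1, 1, 4, 16, 9]
seq_lens = [33, 70, 20, 16, 9]

correct = prefill_maxima(query_lens, seq_lens, num_decodes=2, num_extends=1)
wrong_max_seq_len = max(query_lens[:2])  # the decode slice of query lengths
```

With these made-up lengths the prefill slice yields a maximum sequence length of 16, while the decode slice of the query lengths yields 1, which is the kind of under-reported `max_seqlen_k` the review comment warns about.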
    seq_lens = common_attn_metadata.seq_lens.cpu()
    query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
    if mixed_request:
        seq_lens = common_attn_metadata.seq_lens.cpu()
Doesn't vLLM have something like seq_lens_cpu?
vLLM deprecated seq_lens_cpu in common_metadata a while ago.
@wuhuikx Accuracy data attached in the PR description.
Signed-off-by: ganyi <ygan@amd.com>
Pull request overview
Copilot reviewed 1 out of 1 changed files in this pull request and generated no new comments.
Motivation
Optimize metadata preparation for MHA in the OOT plugin. For the decode-only or prefill-only case, host-side metadata preparation should be reduced to 50~100 us.
Technical Details
Test Plan
gsm8k
Test Result
Submission Checklist