[tx] Skip full logits computation during prefill #878
base: main
Conversation
During prefill, only the last token's logits are needed to start decoding. This optimization avoids materializing the full [B, T, V] logits tensor when prompt_logprobs is not requested.
Add parametrized test in test_models_common.py that verifies both llama3 and qwen3 models produce correct output shape and matching logits when using last_token_logits_only=True. Also tests generation equivalence with and without prompt_logprobs.
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
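Roughly, the change amounts to slicing the hidden states to the final position before the vocabulary projection. The sketch below is illustrative only, not the PR's code: `TinyCausalLM`, `backbone`, and `lm_head` are made-up stand-ins, and only the flag name `last_token_logits_only` comes from this PR.

```python
import torch
import torch.nn as nn

class TinyCausalLM(nn.Module):
    """Toy stand-in for a decoder-only LM; only the head logic matters here."""

    def __init__(self, hidden_size: int = 64, vocab_size: int = 1024):
        super().__init__()
        self.backbone = nn.Linear(hidden_size, hidden_size)   # placeholder for the decoder stack
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, x: torch.Tensor, last_token_logits_only: bool = False) -> torch.Tensor:
        hidden = self.backbone(x)                # [B, T, H]
        if last_token_logits_only:
            hidden = hidden[:, -1:, :]           # [B, 1, H]: keep only the final prompt position
        return self.lm_head(hidden)              # [B, 1, V] instead of [B, T, V]

model = TinyCausalLM()
prompt_hidden = torch.randn(2, 16, 64)           # B=2, T=16, H=64
assert model(prompt_hidden, last_token_logits_only=True).shape == (2, 1, 1024)
```

Decoding is unaffected because sampling only ever consumes the last position's logits; the full tensor is needed only when prompt_logprobs is requested.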
Code Review
This pull request introduces a valuable optimization to skip the computation of full logits during the prefill stage when prompt log probabilities are not required. The changes are implemented by adding a last_token_logits_only parameter to the Llama3 and Qwen3 models and leveraging it in the generator's prefill logic. The implementation is clean and the logic is sound. A new parametrized test file is added, which thoroughly validates the new functionality by checking output shapes, value equivalence, and generation consistency. I have one suggestion to improve the clarity of the new test code.
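For context, the generator-side wiring described above might look roughly like the following; `prefill`, `prompt_ids`, and `prompt_logprobs` as a plain bool are assumed names for illustration, not the repository's actual API.

```python
import torch

def prefill(model, prompt_ids: torch.Tensor, prompt_logprobs: bool):
    # Hypothetical prefill step: only materialize full [B, T, V] logits when
    # prompt log-probabilities were requested; otherwise the model returns [B, 1, V].
    logits = model(prompt_ids, last_token_logits_only=not prompt_logprobs)
    next_token_logits = logits[:, -1, :]          # valid for both logits shapes
    return next_token_logits, (logits if prompt_logprobs else None)
```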
base_config = (
    AutoConfig.from_pretrained(model_name)
    if config_cls == Llama3Config
    else PretrainedConfig.from_pretrained(model_name)
)
The conditional logic for loading the base configuration is a bit confusing and contains unreachable code. Both Llama3Config and Qwen3Config are aliases for ModelConfig, so the condition config_cls == Llama3Config will be true for both parameterized test cases, making the else branch dead code.
To simplify and improve clarity, you can consistently use AutoConfig.from_pretrained(model_name), which is the standard approach for loading configurations from the HuggingFace Hub.
base_config = AutoConfig.from_pretrained(model_name)
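Applied to the test, the simplified loading could look like the sketch below; the checkpoint names and the assertion are illustrative guesses, not the actual contents of test_models_common.py.

```python
import pytest
from transformers import AutoConfig

@pytest.mark.parametrize(
    "model_name",
    ["meta-llama/Llama-3.2-1B", "Qwen/Qwen3-0.6B"],  # illustrative checkpoints only
)
def test_base_config_loads(model_name):
    # A single loading path covers both architectures, per the suggestion above.
    base_config = AutoConfig.from_pretrained(model_name)
    assert base_config.vocab_size > 0
```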
/gemini review
Code Review
This pull request introduces a valuable optimization to skip computing full logits during prefill when they are not needed, controlled by a new last_token_logits_only parameter. The implementation is clean and correctly applied to both Llama3 and Qwen3 models. The logic is correctly wired into the generator utility, activating the optimization when prompt_logprobs is false. The addition of a new parameterized test in test_models_common.py is excellent, as it thoroughly verifies the correctness of the change by checking output shapes, value equivalence, and end-to-end generation consistency. I have one suggestion to improve the new test for better failure diagnostics.
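As a guess at the kind of diagnostic improvement meant here (not the reviewer's actual suggestion): comparing tensors with torch.testing.assert_close surfaces mismatch counts and max/mean differences on failure, for example:

```python
import torch

def assert_last_token_logits_match(full_logits: torch.Tensor, last_only_logits: torch.Tensor) -> None:
    # full_logits: [B, T, V] from the default path; last_only_logits: [B, 1, V]
    # from last_token_logits_only=True. On mismatch, assert_close reports the
    # number of differing elements plus the greatest absolute and relative
    # differences, which is far more useful than a bare `assert torch.allclose(...)`.
    torch.testing.assert_close(full_logits[:, -1:, :], last_only_logits, rtol=1e-4, atol=1e-4)
```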
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Summary
Skip full logits computation during prefill.
Motivation
During prefill, only the last token's logits are needed to start decoding. Computing logits for all prompt tokens requires a [B, T, V] matmul where V (vocab size) is typically 32K-128K. This is wasteful when prompt_logprobs is not requested.
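To make the cost concrete, here is a back-of-the-envelope calculation under assumed sizes (B=1, T=4096, V=128,256, 2-byte activations; all numbers illustrative, not measured on this repository):

```python
B, T, V = 1, 4096, 128_256       # batch, prompt length, vocab size (assumed)
bytes_per_elem = 2               # fp16 / bf16

full_logits_bytes = B * T * V * bytes_per_elem   # [B, T, V] materialized during prefill
last_only_bytes = B * 1 * V * bytes_per_elem     # [B, 1, V] with last_token_logits_only=True

print(f"full logits:      {full_logits_bytes / 2**20:8.1f} MiB")   # ~1002.0 MiB
print(f"last-token only:  {last_only_bytes / 2**20:8.2f} MiB")     # ~0.24 MiB
```

The vocabulary-projection matmul shrinks by the same factor of T, since only one hidden-state row is projected through the head instead of T.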
This optimization benefits:
Test plan