[algorithm] Sampling mask from Inference Engine Applies to Trainer #883
Conversation
/gemini review
Code Review
This pull request introduces functionality to capture and apply sampling masks from the inference engine to the trainer. Key changes include adding sampling_masks to various data structures (GeneratorOutput, InferenceEngineOutput, Experience, TrainingInput) and implementing the logic to process these masks into tensors. A significant portion of the changes involves a temporary hack to extract sampling masks from vLLM's internal sampler by dynamically patching its functions. While this enables the feature, it introduces a high degree of fragility and maintainability risk. Additionally, there's a minor issue with a misleading type hint in the apply_sampling_mask function.
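For orientation, here is a minimal sketch of the kind of field addition the summary describes. The class and field names mirror those listed above, but the exact definitions in the PR (which may be TypedDicts rather than dataclasses, with different companion fields) are assumptions here.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class InferenceEngineOutput:
    # Illustrative subset of fields; the real class likely carries more.
    response_ids: List[List[int]]
    # For each response and each generated token: the token ids that survived
    # top-k/top-p filtering in the inference engine; None when masking is off.
    sampling_masks: Optional[List[List[List[int]]]] = None
```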
```python
# TODO(devpatel): This is a hack to get the sampling masks. We should find a better way to do this... fast
_sampling_masks = threading.local()
_sampler_patched = False


def _reset_sampling_masks() -> None:
    _sampling_masks.items = []


def _append_sampling_mask(mask: torch.Tensor) -> None:
    if not hasattr(_sampling_masks, "items"):
        _sampling_masks.items = []
    _sampling_masks.items.append(mask)


def _consume_sampling_masks() -> Optional[List[torch.Tensor]]:
    masks = getattr(_sampling_masks, "items", None)
    _sampling_masks.items = []
    return masks


def _patch_vllm_sampler() -> None:
    global _sampler_patched
    if _sampler_patched:
        return
    try:
        from vllm.v1.sample.ops import topk_topp_sampler as sampler
    except Exception as exc:
        logger.warning(f"Could not import vLLM topk_topp_sampler op and/or Sampler class: {exc}")
        return

    original_top_k_top_p = sampler.apply_top_k_top_p
    original_top_k_only = sampler.apply_top_k_only

    def _wrapped_top_k_top_p(logits: torch.Tensor, k: torch.Tensor | None, p: torch.Tensor | None) -> torch.Tensor:
        output = original_top_k_top_p(logits, k, p)
        _append_sampling_mask(torch.isfinite(output).to(dtype=torch.bool).cpu())
        return output

    def _wrapped_top_k_only(logits: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
        output = original_top_k_only(logits, k)
        _append_sampling_mask(torch.isfinite(output).to(dtype=torch.bool).cpu())
        return output

    sampler.apply_top_k_top_p = _wrapped_top_k_top_p
    sampler.apply_top_k_only = _wrapped_top_k_only
    _sampler_patched = True
```
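A hedged sketch of how these helpers might be wired around a generation call. The `llm`, `prompts`, and `sampling_params` names are assumptions (e.g. a `vllm.LLM` instance), and the flow assumes the sampler ops execute on the same thread as the caller, which is what the `threading.local()` storage requires; engine configurations that run the model in a separate worker would need a different hand-off.

```python
_patch_vllm_sampler()        # idempotent: wraps the top-k/top-p ops once
_reset_sampling_masks()      # drop any masks left over from a previous call

outputs = llm.generate(prompts, sampling_params)

# One boolean [num_running_requests, vocab_size] tensor per decode step, or
# None if the wrapped ops were never invoked on this thread.
masks = _consume_sampling_masks()
```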
The dynamic patching of vLLM's internal sampler functions (apply_top_k_top_p and apply_top_k_only) using threading.local() is a highly fragile approach. This relies on specific internal implementation details of vLLM, which are not part of its public API and can change without warning in future updates. This could lead to unexpected behavior, crashes, or incorrect sampling mask generation if vLLM's internal structure changes. While the TODO comment acknowledges this is a hack, it poses a significant risk to the maintainability and stability of the system. It would be preferable to find a more robust and officially supported way to extract this information from vLLM, or to encapsulate this hack more thoroughly with version checks and fallback mechanisms.
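One way to encapsulate the hack as the comment suggests is to gate the patch on a known-good vLLM version range and fall back to running without masks otherwise. The helper name and version bounds below are placeholders, not tested values, and `logger` is assumed to come from the surrounding module.

```python
from packaging import version


def _patch_vllm_sampler_guarded() -> bool:
    """Apply the sampler patch only on vLLM versions it was written against."""
    import vllm

    # Placeholder bounds; pin these to whatever versions the patch was verified on.
    supported = version.parse("0.6.0") <= version.parse(vllm.__version__) < version.parse("0.8.0")
    if not supported:
        logger.warning(
            "Sampling-mask capture disabled: vLLM %s is outside the tested range.",
            vllm.__version__,
        )
        return False
    _patch_vllm_sampler()
    return _sampler_patched
```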
```python
# TODO(devpatel): We don't have the request_ids in the sampling metadata, so order by index.
for output_idx in range(len(outputs)):
    per_request = []
    for step_mask in masks:
        if output_idx < step_mask.shape[0]:
            per_request.append(step_mask[output_idx].nonzero(as_tuple=False).squeeze(-1).tolist())
    sampling_masks.append(per_request)
```
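To make the resulting structure concrete, here is a toy walk-through of this loop with invented shapes and values: two decode steps, with only one row left at step 1, which the order-by-index assumption attributes to index 0.

```python
import torch

# Each step's mask is [num_running_requests, vocab_size] (vocab_size = 4 here).
masks = [
    torch.tensor([[True, False, True, False],    # step 0, row 0
                  [False, True, True, True]]),   # step 0, row 1
    torch.tensor([[True, True, False, False]]),  # step 1, only one row left
]
outputs = ["req-0", "req-1"]  # stand-ins for vLLM request outputs

sampling_masks = []
for output_idx in range(len(outputs)):
    per_request = []
    for step_mask in masks:
        if output_idx < step_mask.shape[0]:
            per_request.append(step_mask[output_idx].nonzero(as_tuple=False).squeeze(-1).tolist())
    sampling_masks.append(per_request)

print(sampling_masks)  # [[[0, 2], [0, 1]], [[1, 2, 3]]]
```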
The comment "We don't have the request_ids in the sampling metadata, so order by index" indicates an implicit assumption about the order of masks matching the order of outputs. If vLLM processes requests asynchronously or reorders them internally, this assumption could lead to incorrect sampling masks being associated with the wrong outputs. This needs to be explicitly guaranteed by vLLM's behavior or handled more robustly (e.g., by associating request_id with sampling masks if possible).
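If vLLM exposed the per-row request identity at the point where these ops run, the masks could be keyed by request_id instead of position. The sketch below is purely hypothetical: per the TODO above, the sampling metadata does not currently provide request_ids to the wrapped functions.

```python
from collections import defaultdict
from typing import Dict, List

import torch

# Hypothetical alternative to the positional accumulation above.
_sampling_masks_by_request: Dict[str, List[torch.Tensor]] = defaultdict(list)


def _append_sampling_mask_keyed(mask: torch.Tensor, request_ids: List[str]) -> None:
    """Associate each row of a per-step mask with the request that produced it.

    Assumes request_ids[i] names the request behind row i of mask; vLLM does
    not currently hand this information to the top-k/top-p ops.
    """
    for row_idx, request_id in enumerate(request_ids):
        _sampling_masks_by_request[request_id].append(mask[row_idx].cpu())
```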
```python
    logits: Float[torch.Tensor, "batch_size seqlen top_tokens"],
    sampling_mask: Integer[torch.Tensor, "batch_size seqlen mask_size"],
) -> Float[torch.Tensor, "batch_size seqlen top_tokens"]:
```
The type hint for logits is Float[torch.Tensor, "batch_size seqlen top_tokens"]. However, the implementation assumes logits.shape[2] represents the full vocab_size when creating valid_token_mask (line 201). If logits has already been truncated to top_tokens (a subset of the full vocabulary), and sampling_mask contains indices from the full vocabulary, then sampling_mask indices could exceed logits.shape[2], leading to an out-of-bounds error during the scatter_ operation. Please clarify the expected shape of logits and ensure consistency between the type hint and the actual vocab_size used for masking.
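One way to surface the mismatch early, rather than failing inside scatter_, is a bounds check. The helper below is hypothetical and simply illustrates the guard this comment asks for.

```python
import torch


def _check_sampling_mask_bounds(logits: torch.Tensor, sampling_mask: torch.Tensor) -> None:
    """Fail fast if sampling_mask references token ids beyond logits' last dim.

    sampling_mask holds ids from the full vocabulary, so the logits passed to
    apply_sampling_mask must cover the full vocab_size, not just top_tokens.
    """
    vocab_size = logits.shape[-1]
    max_index = int(sampling_mask.max().item())
    if max_index >= vocab_size:
        raise ValueError(
            f"sampling_mask references token id {max_index}, but logits only "
            f"cover {vocab_size} entries; pass full-vocabulary logits."
        )
```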
Suggested change:

```python
def apply_sampling_mask(
    logits: Float[torch.Tensor, "batch_size seqlen vocab_size"],
    sampling_mask: Integer[torch.Tensor, "batch_size seqlen mask_size"],
) -> Float[torch.Tensor, "batch_size seqlen vocab_size"]:
```
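For reference, a minimal sketch of what the masking step could look like with the corrected signature. This is an illustration consistent with the review discussion (building a valid-token mask via scatter and setting disallowed logits to -inf), not the PR's exact implementation, and the -1 padding convention for `sampling_mask` is an assumption.

```python
import torch
from jaxtyping import Float, Integer


def apply_sampling_mask(
    logits: Float[torch.Tensor, "batch_size seqlen vocab_size"],
    sampling_mask: Integer[torch.Tensor, "batch_size seqlen mask_size"],
) -> Float[torch.Tensor, "batch_size seqlen vocab_size"]:
    """Set every logit the inference engine filtered out to -inf.

    Assumes sampling_mask lists the token ids kept by top-k/top-p at each
    position, right-padded with -1 up to mask_size.
    """
    batch_size, seqlen, vocab_size = logits.shape
    # Scatter into vocab_size + 1 columns so padded entries (-1) can be routed
    # to a scratch column that is discarded afterwards.
    valid = torch.zeros(
        batch_size, seqlen, vocab_size + 1, dtype=torch.bool, device=logits.device
    )
    indices = sampling_mask.long()
    indices = torch.where(indices < 0, torch.full_like(indices, vocab_size), indices)
    valid.scatter_(-1, indices, True)
    valid_token_mask = valid[..., :vocab_size]
    return logits.masked_fill(~valid_token_mask, float("-inf"))
```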
#877 When using top_p and/or top_k sampling, we capture the mask generated by vLLM and apply it to the trainer by masking the logits in the forward pass. To achieve this, InferenceEngineOutput and TrainingInput were modified to store the sampling mask and hand it off to the trainer.
In the future, we can remove this patch and open an issue in vLLM requesting a public API for accessing the sampling mask.