
Conversation

@devpatelio (Collaborator) commented on Jan 15, 2026

#877 When using top_p and/or top_k sampling, we reuse the mask generated by vLLM and apply it to the logits in the trainer's forward pass. To achieve this, InferenceEngineOutput and TrainingInput were modified to store the sampling mask and hand it off to the trainer.

In the future, we can remove the patch and open an issue in vLLM requesting a public API to access the sampling mask.
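
A minimal sketch of the data-structure side of the change, assuming the field name sampling_masks and a per-request, per-step list of kept token ids (the exact types and surrounding fields in the PR may differ):

from dataclasses import dataclass
from typing import List, Optional

# Per request: one inner list per generated token, holding the ids of the tokens
# that survived top-k/top-p filtering at that step.
SamplingMask = List[List[int]]


@dataclass
class InferenceEngineOutput:
    # ... existing generation fields elided ...
    sampling_masks: Optional[List[SamplingMask]] = None  # one entry per request


@dataclass
class TrainingInput:
    # ... existing trainer fields elided ...
    sampling_masks: Optional[List[SamplingMask]] = None  # handed to the trainer unchanged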

gemini-code-assist[bot] left an earlier comment that was marked as outdated.

@devpatelio (Collaborator, Author) commented:

/gemini review

@gemini-code-assist[bot] (Contributor) left a comment:

Code Review

This pull request introduces functionality to capture and apply sampling masks from the inference engine to the trainer. Key changes include adding sampling_masks to various data structures (GeneratorOutput, InferenceEngineOutput, Experience, TrainingInput) and implementing the logic to process these masks into tensors. A significant portion of the changes involves a temporary hack to extract sampling masks from vLLM's internal sampler by dynamically patching its functions. While this enables the feature, it introduces a high degree of fragility and maintainability risk. Additionally, there's a minor issue with a misleading type hint in the apply_sampling_mask function.
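
As an illustration of the "process these masks into tensors" step, a sketch of a possible helper (name and padding convention assumed here, not taken from the PR):

import torch
from torch.nn.utils.rnn import pad_sequence


def masks_to_padded_tensor(per_step_indices: list[list[int]], pad_value: int = -1) -> torch.Tensor:
    # per_step_indices: for one request, the kept token ids at each generated position.
    # Returns a (seqlen, mask_size) LongTensor, padded with pad_value where a step kept
    # fewer tokens; downstream code must ignore pad_value entries before masking logits.
    if not per_step_indices:
        return torch.empty(0, 0, dtype=torch.long)
    rows = [torch.tensor(ids, dtype=torch.long) for ids in per_step_indices]
    return pad_sequence(rows, batch_first=True, padding_value=pad_value)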

Comment on lines +41 to +87
# TODO(devpatel): This is a hack to get the sampling masks. We should find a better way to do this... fast
# Thread-local storage for the per-step masks captured while vLLM samples a batch.
_sampling_masks = threading.local()
_sampler_patched = False


def _reset_sampling_masks() -> None:
    _sampling_masks.items = []


def _append_sampling_mask(mask: torch.Tensor) -> None:
    if not hasattr(_sampling_masks, "items"):
        _sampling_masks.items = []
    _sampling_masks.items.append(mask)


def _consume_sampling_masks() -> Optional[List[torch.Tensor]]:
    # Return whatever was captured so far and clear the buffer for the next generation call.
    masks = getattr(_sampling_masks, "items", None)
    _sampling_masks.items = []
    return masks


def _patch_vllm_sampler() -> None:
    # Monkey-patch vLLM's internal top-k/top-p ops so that every call also records which
    # token ids survive filtering (i.e. whose logits are still finite after masking).
    global _sampler_patched
    if _sampler_patched:
        return
    try:
        from vllm.v1.sample.ops import topk_topp_sampler as sampler
    except Exception as exc:
        logger.warning(f"Could not import vLLM topk_topp_sampler op and/or Sampler class: {exc}")
        return

    original_top_k_top_p = sampler.apply_top_k_top_p
    original_top_k_only = sampler.apply_top_k_only

    def _wrapped_top_k_top_p(logits: torch.Tensor, k: torch.Tensor | None, p: torch.Tensor | None) -> torch.Tensor:
        output = original_top_k_top_p(logits, k, p)
        _append_sampling_mask(torch.isfinite(output).to(dtype=torch.bool).cpu())
        return output

    def _wrapped_top_k_only(logits: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
        output = original_top_k_only(logits, k)
        _append_sampling_mask(torch.isfinite(output).to(dtype=torch.bool).cpu())
        return output

    sampler.apply_top_k_top_p = _wrapped_top_k_top_p
    sampler.apply_top_k_only = _wrapped_top_k_only
    _sampler_patched = True

Severity: high

The dynamic patching of vLLM's internal sampler functions (apply_top_k_top_p and apply_top_k_only) using threading.local() is a highly fragile approach. This relies on specific internal implementation details of vLLM, which are not part of its public API and can change without warning in future updates. This could lead to unexpected behavior, crashes, or incorrect sampling mask generation if vLLM's internal structure changes. While the TODO comment acknowledges this is a hack, it poses a significant risk to the maintainability and stability of the system. It would be preferable to find a more robust and officially supported way to extract this information from vLLM, or to encapsulate this hack more thoroughly with version checks and fallback mechanisms.
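
One way to harden the patch along the lines suggested above is a signature guard before patching. This is a sketch, not part of the PR, and it assumes only that the parameter names match those used by the PR's wrappers:

import inspect
import logging

logger = logging.getLogger(__name__)


def _patch_is_safe(sampler_module) -> bool:
    # Best-effort guard: only patch when the internal functions still exist and their
    # parameter lists match what the wrappers assume.
    fn_kp = getattr(sampler_module, "apply_top_k_top_p", None)
    fn_k = getattr(sampler_module, "apply_top_k_only", None)
    if fn_kp is None or fn_k is None:
        return False
    try:
        if list(inspect.signature(fn_kp).parameters) != ["logits", "k", "p"]:
            return False
        if list(inspect.signature(fn_k).parameters) != ["logits", "k"]:
            return False
    except (TypeError, ValueError):
        return False
    return True

# Inside _patch_vllm_sampler(), one could then do:
#     if not _patch_is_safe(sampler):
#         logger.warning("vLLM sampler internals changed; generating without sampling masks.")
#         return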

Comment on lines +272 to +278
# TODO(devpatel): We don't have the request_ids in the sampling metadata, so order by index.
for output_idx in range(len(outputs)):
    per_request = []
    for step_mask in masks:
        if output_idx < step_mask.shape[0]:
            per_request.append(step_mask[output_idx].nonzero(as_tuple=False).squeeze(-1).tolist())
    sampling_masks.append(per_request)

Severity: medium

The comment "We don't have the request_ids in the sampling metadata, so order by index" indicates an implicit assumption about the order of masks matching the order of outputs. If vLLM processes requests asynchronously or reorders them internally, this assumption could lead to incorrect sampling masks being associated with the wrong outputs. This needs to be explicitly guaranteed by vLLM's behavior or handled more robustly (e.g., by associating request_id with sampling masks if possible).
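
A defensive check in this spirit (hypothetical, and assuming every captured step mask is expected to cover the whole in-flight batch, which may not hold under continuous batching) would at least make a mismatch loud instead of silently pairing masks with the wrong requests:

# Hypothetical guard before the index-ordered loop above:
if masks and any(step_mask.shape[0] < len(outputs) for step_mask in masks):
    logger.warning(
        "Captured sampling masks do not cover all outputs; dropping them rather than "
        "risking a wrong request/mask pairing."
    )
    masks = None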

Comment on lines +190 to +192
    logits: Float[torch.Tensor, "batch_size seqlen top_tokens"],
    sampling_mask: Integer[torch.Tensor, "batch_size seqlen mask_size"],
) -> Float[torch.Tensor, "batch_size seqlen top_tokens"]:

Severity: medium

The type hint for logits is Float[torch.Tensor, "batch_size seqlen top_tokens"]. However, the implementation assumes logits.shape[2] represents the full vocab_size when creating valid_token_mask (line 201). If logits has already been truncated to top_tokens (a subset of the full vocabulary), and sampling_mask contains indices from the full vocabulary, then sampling_mask indices could exceed logits.shape[2], leading to an out-of-bounds error during the scatter_ operation. Please clarify the expected shape of logits and ensure consistency between the type hint and the actual vocab_size used for masking.

def apply_sampling_mask(
    logits: Float[torch.Tensor, "batch_size seqlen vocab_size"],
    sampling_mask: Integer[torch.Tensor, "batch_size seqlen mask_size"],
) -> Float[torch.Tensor, "batch_size seqlen vocab_size"]:
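
To make the shape concern concrete, here is a minimal sketch of the masking step under the suggested signature; the body and the assumption that sampling_mask holds only valid token ids are illustrative, not the PR's actual implementation. scatter_ indexes the last dimension of logits with the ids stored in sampling_mask, so that dimension must be the full vocab_size; a truncated top_tokens axis would make out-of-range ids raise during the scatter.

import torch


def apply_sampling_mask(logits: torch.Tensor, sampling_mask: torch.Tensor) -> torch.Tensor:
    # logits: (batch, seqlen, vocab_size); sampling_mask: (batch, seqlen, mask_size) holding
    # the token ids kept by top-k/top-p at each position (assumed to contain only valid ids).
    valid_token_mask = torch.zeros_like(logits, dtype=torch.bool)
    valid_token_mask.scatter_(-1, sampling_mask.long(), True)
    # Tokens the inference engine could never have sampled get -inf, so the trainer's
    # softmax is restricted to the same filtered support used at generation time.
    return logits.masked_fill(~valid_token_mask, float("-inf"))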

@devpatelio changed the title from "[draft] Sampling mask from Inference Engine Applies to Trainer" to "[algorithm] Sampling mask from Inference Engine Applies to Trainer" on Jan 15, 2026