From af59ebacc8e9fc40deec1420e6f275aa49882b5a Mon Sep 17 00:00:00 2001 From: Aryaman Gupta Date: Fri, 20 Feb 2026 16:54:05 +0000 Subject: [PATCH 1/6] refactor for top-k top-p sampling support --- atom/entrypoints/openai_server.py | 11 ++ atom/model_engine/model_runner.py | 37 ++++-- atom/model_engine/scheduler.py | 2 + atom/model_engine/sequence.py | 2 + atom/model_ops/sampler.py | 193 ++++++++++++++++++++++++++++-- atom/sampling_params.py | 8 ++ 6 files changed, 238 insertions(+), 15 deletions(-) diff --git a/atom/entrypoints/openai_server.py b/atom/entrypoints/openai_server.py index a65807497..07e6905dd 100644 --- a/atom/entrypoints/openai_server.py +++ b/atom/entrypoints/openai_server.py @@ -32,6 +32,7 @@ # Constants DEFAULT_TEMPERATURE = 1.0 +DEFAULT_TOP_K = -1 DEFAULT_TOP_P = 1.0 DEFAULT_MAX_TOKENS = 256 CHAT_COMPLETION_OBJECT = "chat.completion" @@ -63,6 +64,7 @@ class ChatCompletionRequest(BaseModel): messages: Optional[List[ChatMessage]] = None prompt: Optional[List[ChatMessage]] = None # Accept 'prompt' as alias temperature: Optional[float] = DEFAULT_TEMPERATURE + top_k: Optional[int] = DEFAULT_TOP_K top_p: Optional[float] = DEFAULT_TOP_P max_tokens: Optional[int] = DEFAULT_MAX_TOKENS stop: Optional[List[str]] = None @@ -86,6 +88,7 @@ class CompletionRequest(BaseModel): model: Optional[str] = None prompt: str temperature: Optional[float] = DEFAULT_TEMPERATURE + top_k: Optional[int] = DEFAULT_TOP_K top_p: Optional[float] = DEFAULT_TOP_P max_tokens: Optional[int] = DEFAULT_MAX_TOKENS stop: Optional[List[str]] = None @@ -253,9 +256,13 @@ def _build_sampling_params( max_tokens: int, stop_strings: Optional[List[str]], ignore_eos: bool, + top_k: int = -1, + top_p: float = 1.0, ) -> SamplingParams: return SamplingParams( temperature=temperature, + top_k=top_k, + top_p=top_p, max_tokens=max_tokens, stop_strings=stop_strings, ignore_eos=ignore_eos, @@ -667,6 +674,8 @@ async def chat_completions(request: ChatCompletionRequest): max_tokens=request.max_tokens, 
stop_strings=request.stop, ignore_eos=request.ignore_eos, + top_k=request.top_k, + top_p=request.top_p, ) request_id = f"chatcmpl-{uuid.uuid4().hex}" @@ -749,6 +758,8 @@ async def completions(request: CompletionRequest): max_tokens=request.max_tokens, stop_strings=request.stop, ignore_eos=request.ignore_eos, + top_k=request.top_k, + top_p=request.top_p, ) request_id = f"cmpl-{uuid.uuid4().hex}" diff --git a/atom/model_engine/model_runner.py b/atom/model_engine/model_runner.py index 93d6290f7..81626a7e6 100644 --- a/atom/model_engine/model_runner.py +++ b/atom/model_engine/model_runner.py @@ -854,6 +854,8 @@ def allocate_forward_vars(self): "input_ids": self.tokenID_processor.input_ids, "positions": CpuGpuBuffer(self.max_num_batched_tokens, **i64_kwargs), "temperatures": CpuGpuBuffer(self.max_bs, **f32_kwargs), + "top_ks": CpuGpuBuffer(self.max_bs, **i32_kwargs), + "top_ps": CpuGpuBuffer(self.max_bs, **f32_kwargs), # Keep enough space for MTP decode (max_q_len > 1). "outputs": torch.empty( self.max_num_batched_tokens, hidden_size, dtype=hidden_type @@ -1175,11 +1177,24 @@ def prepare_inputs(self, batch: ScheduledBatch, input_ids: torch.Tensor = None): ) return graph_bs - def prepare_sample(self, batch: ScheduledBatch) -> torch.Tensor: + def prepare_sample( + self, batch: ScheduledBatch + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: bs = batch.total_seqs_num - buffer = self.forward_vars["temperatures"] - buffer.np[:bs] = batch.temperatures - return buffer.copy_to_gpu(bs) + + temp_buffer = self.forward_vars["temperatures"] + temp_buffer.np[:bs] = batch.temperatures + temperatures = temp_buffer.copy_to_gpu(bs) + + top_k_buffer = self.forward_vars["top_ks"] + top_k_buffer.np[:bs] = batch.top_ks + top_ks = top_k_buffer.copy_to_gpu(bs) + + top_p_buffer = self.forward_vars["top_ps"] + top_p_buffer.np[:bs] = batch.top_ps + top_ps = top_p_buffer.copy_to_gpu(bs) + + return temperatures, top_ks, top_ps def prepare_model(self, batch: ScheduledBatch): total_tokens_num = 
batch.total_tokens_num @@ -1187,10 +1202,12 @@ def prepare_model(self, batch: ScheduledBatch): input_ids = self.tokenID_processor.prepare_input_ids(batch) self.prepare_inputs(batch, input_ids) - temperatures = self.prepare_sample(batch) + temperatures, top_ks, top_ps = self.prepare_sample(batch) return ( input_ids, temperatures, + top_ks, + top_ps, ) def run_model(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: @@ -1215,12 +1232,14 @@ def postprocess( batch: ScheduledBatch, logits: torch.Tensor, temperatures: torch.Tensor, + top_ks: torch.Tensor, + top_ps: torch.Tensor, # following for draft hidden_states: torch.Tensor, ) -> tuple[dict[int, tuple[int, ...]], Optional[torch.Tensor]]: spec_decode_metadata = get_forward_context().spec_decode_metadata if spec_decode_metadata is None: - sampled_tokens = self.sampler(logits, temperatures) + sampled_tokens = self.sampler(logits, temperatures, top_ks, top_ps) else: assert logits is not None bonus_logits_indices = spec_decode_metadata.bonus_logits_indices @@ -1230,6 +1249,8 @@ def postprocess( bonus_token_ids = self.sampler( logits=bonus_logits, temperatures=temperatures, + top_ks=top_ks, + top_ps=top_ps, ) # Just like `bonus_logits`, `target_logits` is a new tensor with # separate storage from the original `logits` tensor. 
Therefore, @@ -1290,12 +1311,14 @@ def postprocess( @torch.inference_mode() def forward(self, batch: ScheduledBatch) -> ScheduledBatchOutput: - input_ids, temperatures = self.prepare_model(batch) + input_ids, temperatures, top_ks, top_ps = self.prepare_model(batch) logits, hidden_states = self.run_model(input_ids) fwd_output = self.postprocess( batch, logits, temperatures, + top_ks, + top_ps, hidden_states, ) reset_forward_context() diff --git a/atom/model_engine/scheduler.py b/atom/model_engine/scheduler.py index 658acdda6..dc7c43ff5 100644 --- a/atom/model_engine/scheduler.py +++ b/atom/model_engine/scheduler.py @@ -38,6 +38,8 @@ def __init__( # print(f"{num_scheduled_tokens=}") # print(f"{self.scheduled_tokens=}") self.temperatures = [seq.temperature for seq in seqs.values()] + self.top_ks = [seq.top_k for seq in seqs.values()] + self.top_ps = [seq.top_p for seq in seqs.values()] self.context_lens = [seq.num_tokens for seq in seqs.values()] self.block_tables = [ seq.block_table for seq in seqs.values() if seq.block_table diff --git a/atom/model_engine/sequence.py b/atom/model_engine/sequence.py index 8c4c9f092..d1d6e01cf 100644 --- a/atom/model_engine/sequence.py +++ b/atom/model_engine/sequence.py @@ -51,6 +51,8 @@ def __init__( self.num_cached_tokens = 0 self.block_table = [] self.temperature = sampling_params.temperature + self.top_k = sampling_params.top_k + self.top_p = sampling_params.top_p self.max_tokens = sampling_params.max_tokens self.ignore_eos = sampling_params.ignore_eos self.stop_strings = sampling_params.stop_strings diff --git a/atom/model_ops/sampler.py b/atom/model_ops/sampler.py index 0ea4df9f3..a5785fa06 100644 --- a/atom/model_ops/sampler.py +++ b/atom/model_ops/sampler.py @@ -7,6 +7,15 @@ from aiter.ops.triton.topk import topk from torch import nn +# Try to import aiter top-k/top-p sampling ops +try: + import aiter.ops.sampling # noqa: F401 + + aiter_ops = torch.ops.aiter + AITER_TOPK_TOPP_AVAILABLE = True +except ImportError: + 
AITER_TOPK_TOPP_AVAILABLE = False + class Sampler(nn.Module): @@ -16,14 +25,57 @@ def __init__(self): def forward( self, - logits: torch.Tensor, # (token_num, vocab_size) - temperatures: torch.Tensor, # (token_num,) - ) -> torch.Tensor: # (token_num,) + logits: torch.Tensor, # (num_tokens, vocab_size) + temperatures: torch.Tensor, # (num_tokens,) + top_ks: torch.Tensor = None, # (num_tokens,) int32, -1 means disabled + top_ps: torch.Tensor = None, # (num_tokens,) float32, 1.0 means disabled + ) -> torch.Tensor: # (num_tokens,) + """ + Sample tokens from logits with optional top-k and top-p filtering. + + Args: + logits: Raw logits from model (num_tokens, vocab_size) + temperatures: Temperature for each token (num_tokens,) + top_ks: Top-k value per token, -1 means disabled (num_tokens,) + top_ps: Top-p value per token, 1.0 means disabled (num_tokens,) + + Returns: + Sampled token IDs (num_tokens,) + """ + # Fast path: no filtering needed, use existing optimized sampler + if not self._needs_filtering(top_ks, top_ps): + return self._temperature_sample(logits, temperatures) + + # Slow path: apply top-k/top-p filtering + return self._topk_topp_sample(logits, temperatures, top_ks, top_ps) + + def _needs_filtering( + self, + top_ks: torch.Tensor, + top_ps: torch.Tensor, + ) -> bool: + """Check if any request needs top-k or top-p filtering.""" + if top_ks is None and top_ps is None: + return False + + needs_topk = top_ks is not None and (top_ks != -1).any() + needs_topp = top_ps is not None and (top_ps < 1.0).any() + + return needs_topk or needs_topp + + def _temperature_sample( + self, + logits: torch.Tensor, + temperatures: torch.Tensor, + ) -> torch.Tensor: + """Original temperature-based Gumbel-max sampling (fast path).""" sampled_tokens = torch.empty( logits.size(0), dtype=torch.int, device=logits.device ) exponential = ( - torch.empty((1, logits.shape[-1]), dtype=torch.float, device=logits.device) + torch.empty( + (1, logits.shape[-1]), dtype=torch.float, 
device=logits.device + ) .exponential_(1) .expand(*logits.shape) ) @@ -31,11 +83,136 @@ def forward( sampled_tokens, logits, exponential, temperatures, eps=self.eps ) return sampled_tokens - logits = logits.float() - return torch.where( - temperatures == 0, self.greedy_sample(logits), self.random_sample(logits) - ).to(torch.int) + def _topk_topp_sample( + self, + logits: torch.Tensor, + temperatures: torch.Tensor, + top_ks: torch.Tensor, + top_ps: torch.Tensor, + ) -> torch.Tensor: + """Top-K/Top-P sampling with temperature scaling.""" + # Apply temperature scaling + # Clamp to avoid division by zero; temperature=0 handled separately as greedy + scaled_logits = logits / temperatures.unsqueeze(-1).clamp(min=self.eps) + probs = scaled_logits.softmax(dim=-1, dtype=torch.float32).contiguous() + + # Determine which filtering is needed + has_topk = top_ks is not None and (top_ks != -1).any() + has_topp = top_ps is not None and (top_ps < 1.0).any() + + if AITER_TOPK_TOPP_AVAILABLE: + return self._aiter_sample( + probs, top_ks, top_ps, has_topk, has_topp, temperatures + ) + else: + return self._native_sample(probs, top_ks, top_ps, temperatures) + + def _to_tensor_scalar(self, x: torch.Tensor): + """Convert to (tensor, scalar) tuple for aiter ops.""" + if x is None: + return (None, 0) + if (x == x[0]).all(): # Uniform value - use scalar for efficiency + return (None, x[0].item()) + return (x, 0) + + def _aiter_sample( + self, + probs: torch.Tensor, + top_ks: torch.Tensor, + top_ps: torch.Tensor, + has_topk: bool, + has_topp: bool, + temperatures: torch.Tensor, + ) -> torch.Tensor: + """Use aiter optimized ops for top-k/top-p sampling.""" + # Convert to tensor/scalar format for aiter + k_tensor, k_scalar = self._to_tensor_scalar(top_ks) + p_tensor, p_scalar = self._to_tensor_scalar(top_ps) + + if has_topk and has_topp: + # Joint k+p path + next_tokens = aiter_ops.top_k_top_p_sampling_from_probs( + probs, + None, + k_tensor, + k_scalar, + p_tensor, + p_scalar, + 
deterministic=True, + ) + elif has_topp: + # Top-p only + next_tokens = aiter_ops.top_p_sampling_from_probs( + probs, None, p_tensor, p_scalar, deterministic=True + ) + elif has_topk: + # Top-k only: renormalize and multinomial + renorm_probs = aiter_ops.top_k_renorm_probs(probs, k_tensor, k_scalar) + next_tokens = torch.multinomial(renorm_probs, num_samples=1) + else: + # Neither - just multinomial from probs + next_tokens = torch.multinomial(probs, num_samples=1) + + # Handle greedy sampling (temperature=0) + greedy_mask = temperatures == 0 + if greedy_mask.any(): + next_tokens[greedy_mask] = probs[greedy_mask].argmax(dim=-1).unsqueeze(-1) + + return next_tokens.view(-1).to(torch.int) + + def _native_sample( + self, + probs: torch.Tensor, + top_ks: torch.Tensor, + top_ps: torch.Tensor, + temperatures: torch.Tensor, + ) -> torch.Tensor: + """Native PyTorch fallback for top-k/top-p sampling.""" + batch_size, vocab_size = probs.shape + device = probs.device + + # Sort probs descending + sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True) + cumsum_probs = torch.cumsum(sorted_probs, dim=-1) + + # Top-p mask: keep tokens until cumsum exceeds top_p + # The mask keeps tokens where cumsum - current_prob <= top_p + # (i.e., before we exceed the threshold) + if top_ps is not None: + topp_mask = (cumsum_probs - sorted_probs) <= top_ps.unsqueeze(-1) + else: + topp_mask = torch.ones_like(sorted_probs, dtype=torch.bool) + + # Top-k mask: keep first k tokens + if top_ks is not None: + indices = torch.arange(vocab_size, device=device).unsqueeze(0) + effective_k = torch.where(top_ks == -1, vocab_size, top_ks) + topk_mask = indices < effective_k.unsqueeze(-1) + else: + topk_mask = torch.ones_like(sorted_probs, dtype=torch.bool) + + # Combined filtering + mask = topp_mask & topk_mask + mask[:, 0] = True # Always keep at least one token + + filtered_probs = sorted_probs * mask.float() + filtered_probs = filtered_probs / filtered_probs.sum( + dim=-1, keepdim=True 
+ ).clamp(min=self.eps) + + # Sample and map back to original indices + sampled_idx = torch.multinomial(filtered_probs, num_samples=1).squeeze(-1) + next_tokens = sorted_indices.gather(1, sampled_idx.unsqueeze(-1)).squeeze(-1) + + # Handle greedy (temperature=0) + greedy_mask = temperatures == 0 + if greedy_mask.any(): + next_tokens[greedy_mask] = probs[greedy_mask].argmax(dim=-1) + + return next_tokens.to(torch.int) + + # Legacy methods kept for reference def greedy_sample( self, logits: torch.Tensor # (token_num, vocab_size) ) -> torch.Tensor: # (token_num,) diff --git a/atom/sampling_params.py b/atom/sampling_params.py index 4f8e0fa16..7c0327023 100644 --- a/atom/sampling_params.py +++ b/atom/sampling_params.py @@ -8,6 +8,14 @@ @dataclass class SamplingParams: temperature: float = 1.0 + top_k: int = -1 # -1 means disabled (keep all tokens) + top_p: float = 1.0 # 1.0 means disabled (keep all tokens) max_tokens: int = 64 ignore_eos: bool = False stop_strings: Optional[list[str]] = None + + def __post_init__(self): + if self.top_k != -1 and self.top_k < 1: + raise ValueError("top_k must be -1 (disabled) or >= 1") + if not (0.0 < self.top_p <= 1.0): + raise ValueError("top_p must be in range (0.0, 1.0]") From 9d19d260f1d722bbfa4cb966b8102ecb9626edbb Mon Sep 17 00:00:00 2001 From: Aryaman Gupta Date: Mon, 23 Feb 2026 16:40:25 +0000 Subject: [PATCH 2/6] sampler.py: change in comments and default param syntax --- atom/model_ops/sampler.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/atom/model_ops/sampler.py b/atom/model_ops/sampler.py index a5785fa06..c755e6ca7 100644 --- a/atom/model_ops/sampler.py +++ b/atom/model_ops/sampler.py @@ -27,11 +27,11 @@ def forward( self, logits: torch.Tensor, # (num_tokens, vocab_size) temperatures: torch.Tensor, # (num_tokens,) - top_ks: torch.Tensor = None, # (num_tokens,) int32, -1 means disabled - top_ps: torch.Tensor = None, # (num_tokens,) float32, 1.0 means disabled + top_ks: torch.Tensor | None 
= None, # (num_tokens,) int32, -1 means disabled + top_ps: torch.Tensor | None = None, # (num_tokens,) float32, 1.0 means disabled ) -> torch.Tensor: # (num_tokens,) """ - Sample tokens from logits with optional top-k and top-p filtering. + Sample tokens from logits using temperature or top-k top-p filtering. Args: logits: Raw logits from model (num_tokens, vocab_size) @@ -42,11 +42,11 @@ def forward( Returns: Sampled token IDs (num_tokens,) """ - # Fast path: no filtering needed, use existing optimized sampler + # No Top-K Top-P parameters, perform temperature-based sampling if not self._needs_filtering(top_ks, top_ps): return self._temperature_sample(logits, temperatures) - # Slow path: apply top-k/top-p filtering + # Apply top-k/top-p filtering return self._topk_topp_sample(logits, temperatures, top_ks, top_ps) def _needs_filtering( @@ -68,7 +68,7 @@ def _temperature_sample( logits: torch.Tensor, temperatures: torch.Tensor, ) -> torch.Tensor: - """Original temperature-based Gumbel-max sampling (fast path).""" + """Temperature-based Gumbel-max sampling.""" sampled_tokens = torch.empty( logits.size(0), dtype=torch.int, device=logits.device ) From 1e58f53c7b8e8a91f208a690fca14c129f23994c Mon Sep 17 00:00:00 2001 From: Aryaman Gupta Date: Tue, 24 Feb 2026 13:29:50 +0000 Subject: [PATCH 3/6] sampler.py: adds fast path for all temperature=0 --- atom/model_ops/sampler.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/atom/model_ops/sampler.py b/atom/model_ops/sampler.py index c755e6ca7..851aba87b 100644 --- a/atom/model_ops/sampler.py +++ b/atom/model_ops/sampler.py @@ -92,6 +92,12 @@ def _topk_topp_sample( top_ps: torch.Tensor, ) -> torch.Tensor: """Top-K/Top-P sampling with temperature scaling.""" + # Fast path: if ALL requests are greedy (temperature=0), just do argmax + # This avoids the overhead of softmax and top-k/top-p filtering + all_greedy = (temperatures == 0).all() + if all_greedy: + return logits.argmax(dim=-1).to(torch.int) + # Apply temperature 
scaling # Clamp to avoid division by zero; temperature=0 handled separately as greedy scaled_logits = logits / temperatures.unsqueeze(-1).clamp(min=self.eps) From db9755249bda193fec46b48fac4cb12a75b15671 Mon Sep 17 00:00:00 2001 From: Aryaman Gupta Date: Wed, 25 Feb 2026 11:19:32 +0000 Subject: [PATCH 4/6] sampler.py: adds warning for native Pytorch implementation of top-k top-p --- atom/model_ops/sampler.py | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/atom/model_ops/sampler.py b/atom/model_ops/sampler.py index 851aba87b..8781f17f8 100644 --- a/atom/model_ops/sampler.py +++ b/atom/model_ops/sampler.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: MIT # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +import warnings + import torch from aiter import mixed_sample_outer_exponential from aiter.ops.triton.softmax import softmax @@ -15,6 +17,15 @@ AITER_TOPK_TOPP_AVAILABLE = True except ImportError: AITER_TOPK_TOPP_AVAILABLE = False + warnings.warn( + "aiter.ops.sampling not available. Top-k/top-p sampling will use " + "experimental native PyTorch implementation as fallback.", + UserWarning, + stacklevel=1, + ) + +# Track whether we've already warned about native sampling being used +_NATIVE_SAMPLING_WARNING_ISSUED = False class Sampler(nn.Module): @@ -174,7 +185,23 @@ def _native_sample( top_ps: torch.Tensor, temperatures: torch.Tensor, ) -> torch.Tensor: - """Native PyTorch fallback for top-k/top-p sampling.""" + """ + EXPERIMENTAL: Native PyTorch fallback for top-k/top-p sampling. + + This implementation has not been thoroughly tested and may produce + different results compared to the optimized aiter implementation. + Use aiter.ops.sampling for production workloads. + """ + global _NATIVE_SAMPLING_WARNING_ISSUED + if not _NATIVE_SAMPLING_WARNING_ISSUED: + warnings.warn( + "Using experimental native top-k/top-p sampling. 
" + "Install aiter.ops.sampling for optimized performance.", + UserWarning, + stacklevel=2, + ) + _NATIVE_SAMPLING_WARNING_ISSUED = True + batch_size, vocab_size = probs.shape device = probs.device From b74f38295279171f6c86c6dbd9f625d61d13ca23 Mon Sep 17 00:00:00 2001 From: Aryaman Gupta Date: Thu, 26 Feb 2026 13:53:37 +0000 Subject: [PATCH 5/6] only copy one element to GPU if all ks or ps the same --- atom/model_engine/model_runner.py | 12 ++++++++++-- atom/model_ops/sampler.py | 8 ++++++-- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/atom/model_engine/model_runner.py b/atom/model_engine/model_runner.py index 81626a7e6..1f354a964 100644 --- a/atom/model_engine/model_runner.py +++ b/atom/model_engine/model_runner.py @@ -1186,13 +1186,21 @@ def prepare_sample( temp_buffer.np[:bs] = batch.temperatures temperatures = temp_buffer.copy_to_gpu(bs) + # For top_ks and top_ps, check uniformity on CPU before GPU copy. + # If all values are the same, only copy a single element to save bandwidth. 
top_k_buffer = self.forward_vars["top_ks"] top_k_buffer.np[:bs] = batch.top_ks - top_ks = top_k_buffer.copy_to_gpu(bs) + if bs > 1 and all(k == batch.top_ks[0] for k in batch.top_ks): + top_ks = top_k_buffer.copy_to_gpu(1) + else: + top_ks = top_k_buffer.copy_to_gpu(bs) top_p_buffer = self.forward_vars["top_ps"] top_p_buffer.np[:bs] = batch.top_ps - top_ps = top_p_buffer.copy_to_gpu(bs) + if bs > 1 and all(p == batch.top_ps[0] for p in batch.top_ps): + top_ps = top_p_buffer.copy_to_gpu(1) + else: + top_ps = top_p_buffer.copy_to_gpu(bs) return temperatures, top_ks, top_ps diff --git a/atom/model_ops/sampler.py b/atom/model_ops/sampler.py index 8781f17f8..656ad14fa 100644 --- a/atom/model_ops/sampler.py +++ b/atom/model_ops/sampler.py @@ -126,10 +126,14 @@ def _topk_topp_sample( return self._native_sample(probs, top_ks, top_ps, temperatures) def _to_tensor_scalar(self, x: torch.Tensor): - """Convert to (tensor, scalar) tuple for aiter ops.""" + """Convert to (tensor, scalar) tuple for aiter ops. + + If tensor has size 1 (uniform value optimization from model_runner), + extract the scalar value for more efficient aiter kernel dispatch. 
+ """ if x is None: return (None, 0) - if (x == x[0]).all(): # Uniform value - use scalar for efficiency + if x.numel() == 1: # Uniform value - use scalar for efficiency return (None, x[0].item()) return (x, 0) From 73e788f9fd280e7436961a9e824417585dd854e4 Mon Sep 17 00:00:00 2001 From: Aryaman Gupta Date: Fri, 6 Mar 2026 17:29:51 +0000 Subject: [PATCH 6/6] check if topk or topp filtering is required in model_runner.py --- atom/model_engine/model_runner.py | 39 ++++++++++++++++++++----------- atom/model_ops/sampler.py | 26 ++++++++++----------- 2 files changed, 37 insertions(+), 28 deletions(-) diff --git a/atom/model_engine/model_runner.py b/atom/model_engine/model_runner.py index 4ce9e31b5..727442b5c 100644 --- a/atom/model_engine/model_runner.py +++ b/atom/model_engine/model_runner.py @@ -1309,28 +1309,39 @@ def prepare_inputs(self, batch: ScheduledBatch, input_ids: torch.Tensor = None): def prepare_sample( self, batch: ScheduledBatch - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]: bs = batch.total_seqs_num temp_buffer = self.forward_vars["temperatures"] temp_buffer.np[:bs] = batch.temperatures temperatures = temp_buffer.copy_to_gpu(bs) - # For top_ks and top_ps, check uniformity on CPU before GPU copy. - # If all values are the same, only copy a single element to save bandwidth. - top_k_buffer = self.forward_vars["top_ks"] - top_k_buffer.np[:bs] = batch.top_ks - if bs > 1 and all(k == batch.top_ks[0] for k in batch.top_ks): - top_ks = top_k_buffer.copy_to_gpu(1) + # Check on CPU whether filtering is needed to avoid GPU sync in sampler. + # If no filtering needed, return None to skip GPU copy entirely. 
+        needs_topk = any(k != -1 for k in batch.top_ks)
+        needs_topp = any(p < 1.0 for p in batch.top_ps)
+
+        if needs_topk:
+            top_k_buffer = self.forward_vars["top_ks"]
+            top_k_buffer.np[:bs] = batch.top_ks
+            # If all values are the same, only copy one element to save bandwidth
+            if bs > 1 and all(k == batch.top_ks[0] for k in batch.top_ks):
+                top_ks = top_k_buffer.copy_to_gpu(1)
+            else:
+                top_ks = top_k_buffer.copy_to_gpu(bs)
         else:
-            top_ks = top_k_buffer.copy_to_gpu(bs)
-
-        top_p_buffer = self.forward_vars["top_ps"]
-        top_p_buffer.np[:bs] = batch.top_ps
-        if bs > 1 and all(p == batch.top_ps[0] for p in batch.top_ps):
-            top_ps = top_p_buffer.copy_to_gpu(1)
+            top_ks = None
+
+        if needs_topp:
+            top_p_buffer = self.forward_vars["top_ps"]
+            top_p_buffer.np[:bs] = batch.top_ps
+            # If all values are the same, only copy one element to save bandwidth
+            if bs > 1 and all(p == batch.top_ps[0] for p in batch.top_ps):
+                top_ps = top_p_buffer.copy_to_gpu(1)
+            else:
+                top_ps = top_p_buffer.copy_to_gpu(bs)
         else:
-            top_ps = top_p_buffer.copy_to_gpu(bs)
+            top_ps = None
 
         return temperatures, top_ks, top_ps
 
diff --git a/atom/model_ops/sampler.py b/atom/model_ops/sampler.py
index 656ad14fa..154ea9ebb 100644
--- a/atom/model_ops/sampler.py
+++ b/atom/model_ops/sampler.py
@@ -62,17 +62,15 @@ def forward(
 
     def _needs_filtering(
         self,
-        top_ks: torch.Tensor,
-        top_ps: torch.Tensor,
+        top_ks: torch.Tensor | None,
+        top_ps: torch.Tensor | None,
     ) -> bool:
-        """Check if any request needs top-k or top-p filtering."""
-        if top_ks is None and top_ps is None:
-            return False
-
-        needs_topk = top_ks is not None and (top_ks != -1).any()
-        needs_topp = top_ps is not None and (top_ps < 1.0).any()
+        """Check if any request needs top-k or top-p filtering.
 
-        return needs_topk or needs_topp
+        This check is O(1) - the actual filtering check is done on CPU in
+        model_runner.prepare_sample(), which passes None if no filtering needed.
+ """ + return top_ks is not None or top_ps is not None def _temperature_sample( self, @@ -99,8 +97,8 @@ def _topk_topp_sample( self, logits: torch.Tensor, temperatures: torch.Tensor, - top_ks: torch.Tensor, - top_ps: torch.Tensor, + top_ks: torch.Tensor | None, + top_ps: torch.Tensor | None, ) -> torch.Tensor: """Top-K/Top-P sampling with temperature scaling.""" # Fast path: if ALL requests are greedy (temperature=0), just do argmax @@ -114,9 +112,9 @@ def _topk_topp_sample( scaled_logits = logits / temperatures.unsqueeze(-1).clamp(min=self.eps) probs = scaled_logits.softmax(dim=-1, dtype=torch.float32).contiguous() - # Determine which filtering is needed - has_topk = top_ks is not None and (top_ks != -1).any() - has_topp = top_ps is not None and (top_ps < 1.0).any() + # model_runner.prepare_sample passes None if filtering not needed for that type + has_topk = top_ks is not None + has_topp = top_ps is not None if AITER_TOPK_TOPP_AVAILABLE: return self._aiter_sample(