diff --git a/atom/entrypoints/openai_server.py b/atom/entrypoints/openai_server.py index 9ed9c3b45..75dc2e766 100644 --- a/atom/entrypoints/openai_server.py +++ b/atom/entrypoints/openai_server.py @@ -32,6 +32,7 @@ # Constants DEFAULT_TEMPERATURE = 1.0 +DEFAULT_TOP_K = -1 DEFAULT_TOP_P = 1.0 DEFAULT_MAX_TOKENS = 256 CHAT_COMPLETION_OBJECT = "chat.completion" @@ -63,6 +64,7 @@ class ChatCompletionRequest(BaseModel): messages: Optional[List[ChatMessage]] = None prompt: Optional[List[ChatMessage]] = None # Accept 'prompt' as alias temperature: Optional[float] = DEFAULT_TEMPERATURE + top_k: Optional[int] = DEFAULT_TOP_K top_p: Optional[float] = DEFAULT_TOP_P max_tokens: Optional[int] = DEFAULT_MAX_TOKENS stop: Optional[List[str]] = None @@ -86,6 +88,7 @@ class CompletionRequest(BaseModel): model: Optional[str] = None prompt: str temperature: Optional[float] = DEFAULT_TEMPERATURE + top_k: Optional[int] = DEFAULT_TOP_K top_p: Optional[float] = DEFAULT_TOP_P max_tokens: Optional[int] = DEFAULT_MAX_TOKENS stop: Optional[List[str]] = None @@ -253,9 +256,13 @@ def _build_sampling_params( max_tokens: int, stop_strings: Optional[List[str]], ignore_eos: bool, + top_k: int = -1, + top_p: float = 1.0, ) -> SamplingParams: return SamplingParams( temperature=temperature, + top_k=top_k, + top_p=top_p, max_tokens=max_tokens, stop_strings=stop_strings, ignore_eos=ignore_eos, @@ -667,6 +674,8 @@ async def chat_completions(request: ChatCompletionRequest): max_tokens=request.max_tokens, stop_strings=request.stop, ignore_eos=request.ignore_eos, + top_k=request.top_k, + top_p=request.top_p, ) request_id = f"chatcmpl-{uuid.uuid4().hex}" @@ -749,6 +758,8 @@ async def completions(request: CompletionRequest): max_tokens=request.max_tokens, stop_strings=request.stop, ignore_eos=request.ignore_eos, + top_k=request.top_k, + top_p=request.top_p, ) request_id = f"cmpl-{uuid.uuid4().hex}" diff --git a/atom/model_engine/model_runner.py b/atom/model_engine/model_runner.py index bad557093..727442b5c 
100644 --- a/atom/model_engine/model_runner.py +++ b/atom/model_engine/model_runner.py @@ -852,6 +852,8 @@ def allocate_forward_vars(self): "input_ids": self.tokenID_processor.input_ids, "positions": CpuGpuBuffer(self.max_num_batched_tokens, **i64_kwargs), "temperatures": CpuGpuBuffer(self.max_bs, **f32_kwargs), + "top_ks": CpuGpuBuffer(self.max_bs, **i32_kwargs), + "top_ps": CpuGpuBuffer(self.max_bs, **f32_kwargs), # Keep enough space for MTP decode (max_q_len > 1). "outputs": torch.empty( self.max_num_batched_tokens, hidden_size, dtype=hidden_type @@ -1305,23 +1307,56 @@ def prepare_inputs(self, batch: ScheduledBatch, input_ids: torch.Tensor = None): ) return graph_bs - def prepare_sample(self, batch: ScheduledBatch) -> torch.Tensor: + def prepare_sample( + self, batch: ScheduledBatch + ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]: bs = batch.total_seqs_num - buffer = self.forward_vars["temperatures"] - buffer.np[:bs] = batch.temperatures - return buffer.copy_to_gpu(bs) + + temp_buffer = self.forward_vars["temperatures"] + temp_buffer.np[:bs] = batch.temperatures + temperatures = temp_buffer.copy_to_gpu(bs) + + # Check on CPU whether filtering is needed to avoid GPU sync in sampler. + # If no filtering needed, return None to skip GPU copy entirely. 
+ needs_topk = (batch.top_ks != -1).any() + needs_topp = (batch.top_ps < 1.0).any() + + if needs_topk: + top_k_buffer = self.forward_vars["top_ks"] + top_k_buffer.np[:bs] = batch.top_ks + # If all values are the same, only copy one element to save bandwidth + if bs > 1 and (batch.top_ks == batch.top_ks[0]).all(): + top_ks = top_k_buffer.copy_to_gpu(1) + else: + top_ks = top_k_buffer.copy_to_gpu(bs) + else: + top_ks = None + + if needs_topp: + top_p_buffer = self.forward_vars["top_ps"] + top_p_buffer.np[:bs] = batch.top_ps + # If all values are the same, only copy one element to save bandwidth + if bs > 1 and (batch.top_ps == batch.top_ps[0]).all(): + top_ps = top_p_buffer.copy_to_gpu(1) + else: + top_ps = top_p_buffer.copy_to_gpu(bs) + else: + top_ps = None + + return temperatures, top_ks, top_ps def prepare_model(self, batch: ScheduledBatch): total_tokens_num = batch.total_tokens_num assert total_tokens_num > 0 - temperatures = self.prepare_sample(batch) + temperatures, top_ks, top_ps = self.prepare_sample(batch) input_ids = self.tokenID_processor.prepare_input_ids(batch) - # self.debug(f"{input_ids=}") self.prepare_inputs(batch, input_ids) return ( input_ids, temperatures, + top_ks, + top_ps, ) def run_model(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: @@ -1348,13 +1383,15 @@ def postprocess( batch: ScheduledBatch, logits: torch.Tensor, temperatures: torch.Tensor, + top_ks: torch.Tensor, + top_ps: torch.Tensor, # following for draft hidden_states: torch.Tensor, ) -> ScheduledBatchOutput: spec_decode_metadata = get_forward_context().spec_decode_metadata bs = batch.total_seqs_num if spec_decode_metadata is None: - sampled_tokens = self.sampler(logits, temperatures) + sampled_tokens = self.sampler(logits, temperatures, top_ks, top_ps) num_reject_tokens = self.tokenID_processor.default_num_rejected_tokens[:bs] next_token_locs = num_reject_tokens else: @@ -1367,6 +1404,8 @@ def postprocess( bonus_token_ids = self.sampler( logits=bonus_logits, 
temperatures=temperatures, + top_ks=top_ks, + top_ps=top_ps, ) # Validate shapes match expectations if target_logits.shape[0] != len(spec_decode_metadata.draft_token_ids): @@ -1429,12 +1468,14 @@ def postprocess( @torch.inference_mode() def forward(self, batch: ScheduledBatch) -> ScheduledBatchOutput: - input_ids, temperatures = self.prepare_model(batch) + input_ids, temperatures, top_ks, top_ps = self.prepare_model(batch) logits, hidden_states = self.run_model(input_ids) fwd_output = self.postprocess( batch, logits, temperatures, + top_ks, + top_ps, hidden_states, ) reset_forward_context() diff --git a/atom/model_engine/scheduler.py b/atom/model_engine/scheduler.py index 46405682f..c32a7ae1a 100644 --- a/atom/model_engine/scheduler.py +++ b/atom/model_engine/scheduler.py @@ -147,6 +147,12 @@ def __init__( self.mamba_block_tables = [ seq.mamba_block_table for seq in seqs.values() if seq.mamba_block_table ] + self.top_ks = np.asarray( + [seq.top_k for seq in seqs.values()], dtype=np.int32 + ) + self.top_ps = np.asarray( + [seq.top_p for seq in seqs.values()], dtype=np.float32 + ) offs = self.context_lens - self.num_rejected - self.num_scheduled_tokens self.scheduled_tokens = np.empty(total_tokens_num, dtype=np.int32) diff --git a/atom/model_engine/sequence.py b/atom/model_engine/sequence.py index ee00d2ee9..617db9726 100644 --- a/atom/model_engine/sequence.py +++ b/atom/model_engine/sequence.py @@ -58,6 +58,8 @@ def __init__( self.block_table = [] self.mamba_block_table = [] self.temperature = sampling_params.temperature + self.top_k = sampling_params.top_k + self.top_p = sampling_params.top_p self.max_tokens = sampling_params.max_tokens self.ignore_eos = sampling_params.ignore_eos self.stop_strings = sampling_params.stop_strings diff --git a/atom/model_ops/sampler.py b/atom/model_ops/sampler.py index 0ea4df9f3..154ea9ebb 100644 --- a/atom/model_ops/sampler.py +++ b/atom/model_ops/sampler.py @@ -1,12 +1,32 @@ # SPDX-License-Identifier: MIT # Copyright (C) 2024-2025, 
Advanced Micro Devices, Inc. All rights reserved. +import warnings + import torch from aiter import mixed_sample_outer_exponential from aiter.ops.triton.softmax import softmax from aiter.ops.triton.topk import topk from torch import nn +# Try to import aiter top-k/top-p sampling ops +try: + import aiter.ops.sampling # noqa: F401 + + aiter_ops = torch.ops.aiter + AITER_TOPK_TOPP_AVAILABLE = True +except ImportError: + AITER_TOPK_TOPP_AVAILABLE = False + warnings.warn( + "aiter.ops.sampling not available. Top-k/top-p sampling will use " + "experimental native PyTorch implementation as fallback.", + UserWarning, + stacklevel=1, + ) + +# Track whether we've already warned about native sampling being used +_NATIVE_SAMPLING_WARNING_ISSUED = False + class Sampler(nn.Module): @@ -16,14 +36,55 @@ def __init__(self): def forward( self, - logits: torch.Tensor, # (token_num, vocab_size) - temperatures: torch.Tensor, # (token_num,) - ) -> torch.Tensor: # (token_num,) + logits: torch.Tensor, # (num_tokens, vocab_size) + temperatures: torch.Tensor, # (num_tokens,) + top_ks: torch.Tensor | None = None, # (num_tokens,) int32, -1 means disabled + top_ps: torch.Tensor | None = None, # (num_tokens,) float32, 1.0 means disabled + ) -> torch.Tensor: # (num_tokens,) + """ + Sample tokens from logits using temperature or top-k top-p filtering. 
+ + Args: + logits: Raw logits from model (num_tokens, vocab_size) + temperatures: Temperature for each token (num_tokens,) + top_ks: Top-k value per token, -1 means disabled (num_tokens,) + top_ps: Top-p value per token, 1.0 means disabled (num_tokens,) + + Returns: + Sampled token IDs (num_tokens,) + """ + # No Top-K Top-P parameters, perform temperature-based sampling + if not self._needs_filtering(top_ks, top_ps): + return self._temperature_sample(logits, temperatures) + + # Apply top-k/top-p filtering + return self._topk_topp_sample(logits, temperatures, top_ks, top_ps) + + def _needs_filtering( + self, + top_ks: torch.Tensor | None, + top_ps: torch.Tensor | None, + ) -> bool: + """Check if any request needs top-k or top-p filtering. + + This check is O(1) - the actual filtering check is done on CPU in + model_runner.prepare_sample(), which passes None if no filtering needed. + """ + return top_ks is not None or top_ps is not None + + def _temperature_sample( + self, + logits: torch.Tensor, + temperatures: torch.Tensor, + ) -> torch.Tensor: + """Temperature-based Gumbel-max sampling.""" sampled_tokens = torch.empty( logits.size(0), dtype=torch.int, device=logits.device ) exponential = ( - torch.empty((1, logits.shape[-1]), dtype=torch.float, device=logits.device) + torch.empty( + (1, logits.shape[-1]), dtype=torch.float, device=logits.device + ) .exponential_(1) .expand(*logits.shape) ) @@ -31,11 +92,162 @@ def forward( sampled_tokens, logits, exponential, temperatures, eps=self.eps ) return sampled_tokens - logits = logits.float() - return torch.where( - temperatures == 0, self.greedy_sample(logits), self.random_sample(logits) - ).to(torch.int) + def _topk_topp_sample( + self, + logits: torch.Tensor, + temperatures: torch.Tensor, + top_ks: torch.Tensor | None, + top_ps: torch.Tensor | None, + ) -> torch.Tensor: + """Top-K/Top-P sampling with temperature scaling.""" + # Fast path: if ALL requests are greedy (temperature=0), just do argmax + # This avoids the 
overhead of softmax and top-k/top-p filtering + all_greedy = (temperatures == 0).all() + if all_greedy: + return logits.argmax(dim=-1).to(torch.int) + + # Apply temperature scaling + # Clamp to avoid division by zero; temperature=0 handled separately as greedy + scaled_logits = logits / temperatures.unsqueeze(-1).clamp(min=self.eps) + probs = scaled_logits.softmax(dim=-1, dtype=torch.float32).contiguous() + + # model_runner.prepare_sample passes None if filtering not needed for that type + has_topk = top_ks is not None + has_topp = top_ps is not None + + if AITER_TOPK_TOPP_AVAILABLE: + return self._aiter_sample( + probs, top_ks, top_ps, has_topk, has_topp, temperatures + ) + else: + return self._native_sample(probs, top_ks, top_ps, temperatures) + + def _to_tensor_scalar(self, x: torch.Tensor): + """Convert to (tensor, scalar) tuple for aiter ops. + + If tensor has size 1 (uniform value optimization from model_runner), + extract the scalar value for more efficient aiter kernel dispatch. 
+ """ + if x is None: + return (None, 0) + if x.numel() == 1: # Uniform value - use scalar for efficiency + return (None, x[0].item()) + return (x, 0) + + def _aiter_sample( + self, + probs: torch.Tensor, + top_ks: torch.Tensor, + top_ps: torch.Tensor, + has_topk: bool, + has_topp: bool, + temperatures: torch.Tensor, + ) -> torch.Tensor: + """Use aiter optimized ops for top-k/top-p sampling.""" + # Convert to tensor/scalar format for aiter + k_tensor, k_scalar = self._to_tensor_scalar(top_ks) + p_tensor, p_scalar = self._to_tensor_scalar(top_ps) + + if has_topk and has_topp: + # Joint k+p path + next_tokens = aiter_ops.top_k_top_p_sampling_from_probs( + probs, + None, + k_tensor, + k_scalar, + p_tensor, + p_scalar, + deterministic=True, + ) + elif has_topp: + # Top-p only + next_tokens = aiter_ops.top_p_sampling_from_probs( + probs, None, p_tensor, p_scalar, deterministic=True + ) + elif has_topk: + # Top-k only: renormalize and multinomial + renorm_probs = aiter_ops.top_k_renorm_probs(probs, k_tensor, k_scalar) + next_tokens = torch.multinomial(renorm_probs, num_samples=1) + else: + # Neither - just multinomial from probs + next_tokens = torch.multinomial(probs, num_samples=1) + + # Handle greedy sampling (temperature=0) + greedy_mask = temperatures == 0 + if greedy_mask.any(): + next_tokens[greedy_mask] = probs[greedy_mask].argmax(dim=-1).unsqueeze(-1) + + return next_tokens.view(-1).to(torch.int) + + def _native_sample( + self, + probs: torch.Tensor, + top_ks: torch.Tensor, + top_ps: torch.Tensor, + temperatures: torch.Tensor, + ) -> torch.Tensor: + """ + EXPERIMENTAL: Native PyTorch fallback for top-k/top-p sampling. + + This implementation has not been thoroughly tested and may produce + different results compared to the optimized aiter implementation. + Use aiter.ops.sampling for production workloads. + """ + global _NATIVE_SAMPLING_WARNING_ISSUED + if not _NATIVE_SAMPLING_WARNING_ISSUED: + warnings.warn( + "Using experimental native top-k/top-p sampling. 
" + "Install aiter.ops.sampling for optimized performance.", + UserWarning, + stacklevel=2, + ) + _NATIVE_SAMPLING_WARNING_ISSUED = True + + batch_size, vocab_size = probs.shape + device = probs.device + + # Sort probs descending + sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True) + cumsum_probs = torch.cumsum(sorted_probs, dim=-1) + + # Top-p mask: keep tokens until cumsum exceeds top_p + # The mask keeps tokens where cumsum - current_prob <= top_p + # (i.e., before we exceed the threshold) + if top_ps is not None: + topp_mask = (cumsum_probs - sorted_probs) <= top_ps.unsqueeze(-1) + else: + topp_mask = torch.ones_like(sorted_probs, dtype=torch.bool) + + # Top-k mask: keep first k tokens + if top_ks is not None: + indices = torch.arange(vocab_size, device=device).unsqueeze(0) + effective_k = torch.where(top_ks == -1, vocab_size, top_ks) + topk_mask = indices < effective_k.unsqueeze(-1) + else: + topk_mask = torch.ones_like(sorted_probs, dtype=torch.bool) + + # Combined filtering + mask = topp_mask & topk_mask + mask[:, 0] = True # Always keep at least one token + + filtered_probs = sorted_probs * mask.float() + filtered_probs = filtered_probs / filtered_probs.sum( + dim=-1, keepdim=True + ).clamp(min=self.eps) + + # Sample and map back to original indices + sampled_idx = torch.multinomial(filtered_probs, num_samples=1).squeeze(-1) + next_tokens = sorted_indices.gather(1, sampled_idx.unsqueeze(-1)).squeeze(-1) + + # Handle greedy (temperature=0) + greedy_mask = temperatures == 0 + if greedy_mask.any(): + next_tokens[greedy_mask] = probs[greedy_mask].argmax(dim=-1) + + return next_tokens.to(torch.int) + + # Legacy methods kept for reference def greedy_sample( self, logits: torch.Tensor # (token_num, vocab_size) ) -> torch.Tensor: # (token_num,) diff --git a/atom/sampling_params.py b/atom/sampling_params.py index 4f8e0fa16..7c0327023 100644 --- a/atom/sampling_params.py +++ b/atom/sampling_params.py @@ -8,6 +8,14 @@ @dataclass class 
@dataclass
class SamplingParams:
    """Per-request sampling configuration.

    Attributes:
        temperature: Softmax temperature; 0 selects greedy decoding.
        top_k: Keep only the k most likely tokens; -1 disables the filter.
        top_p: Keep the smallest prefix of tokens whose cumulative
            probability reaches top_p; 1.0 disables the filter.
        max_tokens: Cap on the number of generated tokens.
        ignore_eos: Passed through to the engine's stop handling.
        stop_strings: Strings that terminate generation, or None.
    """

    temperature: float = 1.0
    top_k: int = -1  # -1 means disabled (keep all tokens)
    top_p: float = 1.0  # 1.0 means disabled (keep all tokens)
    max_tokens: int = 64
    ignore_eos: bool = False
    stop_strings: Optional[list[str]] = None

    def __post_init__(self) -> None:
        """Validate filter parameters.

        Raises:
            ValueError: if top_k or top_p is outside its valid range.
        """
        if self.top_k != -1 and self.top_k < 1:
            # Include the offending value to make bad-request errors actionable.
            raise ValueError(
                f"top_k must be -1 (disabled) or >= 1, got {self.top_k}"
            )
        if not (0.0 < self.top_p <= 1.0):
            raise ValueError(
                f"top_p must be in range (0.0, 1.0], got {self.top_p}"
            )