From 33173b0d3d386dc3b85ae877f7b80f0ecb1d5f1d Mon Sep 17 00:00:00 2001 From: Sage Ahrac Date: Tue, 24 Feb 2026 22:18:04 +0200 Subject: [PATCH 01/53] decouple base engine client logic from engine client Signed-off-by: Sage Ahrac --- vllm/engine/protocol.py | 50 +++++++++++++++++++++++++++-------------- 1 file changed, 33 insertions(+), 17 deletions(-) diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 91b1e41801a9..5f2686797e18 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -38,8 +38,15 @@ class StreamingInput: sampling_params: SamplingParams | None = None -class EngineClient(ABC): - """Protocol class for Clients to Engine""" +class BaseEngineClient(ABC): + """Engine client interface for non-inference operations. + + Contains only methods and attributes that don't require a running + inference engine: configuration, tokenization, health checks, and + status monitoring. + + See :class:`EngineClient` for the full interface including inference. + """ vllm_config: VllmConfig model_config: ModelConfig @@ -63,6 +70,30 @@ def errored(self) -> bool: ... @abstractmethod def dead_error(self) -> BaseException: ... + @abstractmethod + async def is_tracing_enabled(self) -> bool: ... + + @abstractmethod + async def do_log_stats(self) -> None: ... + + @abstractmethod + async def check_health(self) -> None: + """Raise if unhealthy""" + ... + + async def get_supported_tasks(self) -> tuple[SupportedTask, ...]: + """Get supported tasks""" + raise NotImplementedError + + +class EngineClient(BaseEngineClient): + """Full engine client interface including inference operations. + + Extends :class:`BaseEngineClient` with methods that require a running + inference engine: generation, encoding, profiling, cache management, + scheduler control, and weight transfer. + """ + @abstractmethod def generate( self, @@ -109,17 +140,6 @@ async def abort(self, request_id: str | Iterable[str]) -> None: """ ... - @abstractmethod - async def is_tracing_enabled(self) -> bool: ... - - @abstractmethod - async def do_log_stats(self) -> None: ... - - @abstractmethod - async def check_health(self) -> None: - """Raise if unhealthy""" - ... 
- @abstractmethod async def start_profile(self) -> None: """Start profiling the engine""" @@ -216,10 +236,6 @@ async def collective_rpc( """Perform a collective RPC call to the given path.""" raise NotImplementedError - async def get_supported_tasks(self) -> tuple[SupportedTask, ...]: - """Get supported tasks""" - raise NotImplementedError - async def init_weight_transfer_engine( self, init_request: WeightTransferInitRequest ) -> None: From ef863e0f17666e5be5a51080684bf2b6f0c83c87 Mon Sep 17 00:00:00 2001 From: Sage Ahrac Date: Tue, 24 Feb 2026 22:26:17 +0200 Subject: [PATCH 02/53] split openapi serving class Signed-off-by: Sage Ahrac --- .../openai/chat_completion/serving.py | 4 +- vllm/entrypoints/openai/completion/serving.py | 4 +- vllm/entrypoints/openai/engine/serving.py | 1728 +++++++++-------- vllm/entrypoints/openai/realtime/serving.py | 4 +- vllm/entrypoints/openai/responses/serving.py | 4 +- .../openai/speech_to_text/speech_to_text.py | 7 +- vllm/entrypoints/pooling/classify/serving.py | 4 +- vllm/entrypoints/pooling/embed/serving.py | 4 +- vllm/entrypoints/pooling/pooling/serving.py | 4 +- vllm/entrypoints/pooling/score/serving.py | 4 +- vllm/entrypoints/serve/disagg/serving.py | 7 +- vllm/entrypoints/serve/tokenize/serving.py | 4 +- 12 files changed, 906 insertions(+), 872 deletions(-) diff --git a/vllm/entrypoints/openai/chat_completion/serving.py b/vllm/entrypoints/openai/chat_completion/serving.py index 39f8635bf297..9ea03681852d 100644 --- a/vllm/entrypoints/openai/chat_completion/serving.py +++ b/vllm/entrypoints/openai/chat_completion/serving.py @@ -52,7 +52,7 @@ ) from vllm.entrypoints.openai.engine.serving import ( GenerationError, - OpenAIServing, + OpenAIServingInference, clamp_prompt_logprobs, ) from vllm.entrypoints.openai.models.serving import OpenAIServingModels @@ -85,7 +85,7 @@ logger = init_logger(__name__) -class OpenAIServingChat(OpenAIServing): +class OpenAIServingChat(OpenAIServingInference): def __init__( self, engine_client: EngineClient, diff --git a/vllm/entrypoints/openai/completion/serving.py b/vllm/entrypoints/openai/completion/serving.py index c6534489fd34..935560dd3b27 100644 --- a/vllm/entrypoints/openai/completion/serving.py +++ b/vllm/entrypoints/openai/completion/serving.py @@ -28,7 +28,7 @@ ) from vllm.entrypoints.openai.engine.serving import ( GenerationError, - OpenAIServing, + OpenAIServingInference, clamp_prompt_logprobs, ) from vllm.entrypoints.openai.models.serving import OpenAIServingModels @@ -46,7 +46,7 @@ logger = init_logger(__name__) -class OpenAIServingCompletion(OpenAIServing): +class OpenAIServingCompletion(OpenAIServingInference): def __init__( self, engine_client: EngineClient, diff --git a/vllm/entrypoints/openai/engine/serving.py b/vllm/entrypoints/openai/engine/serving.py index 3e376ba9c704..2be03402fa9d 100644 --- a/vllm/entrypoints/openai/engine/serving.py +++ b/vllm/entrypoints/openai/engine/serving.py @@ -21,7 +21,7 @@ import vllm.envs as envs from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function from vllm.config import ModelConfig -from vllm.engine.protocol import EngineClient +from vllm.engine.protocol import BaseEngineClient, EngineClient from vllm.entrypoints.chat_utils import ( ChatCompletionMessageParam, ChatTemplateContentFormatOption, @@ -229,7 +229,7 @@ class OpenAIServing: def __init__( self, - engine_client: EngineClient, + engine_client: BaseEngineClient, models: OpenAIServingModels, *, request_logger: RequestLogger | None, @@ -252,808 +252,1000 @@ def __init__( self.io_processor = 
engine_client.io_processor self.input_processor = engine_client.input_processor - async def beam_search( + def create_error_response( self, - prompt: ProcessorInputs, - request_id: str, - params: BeamSearchParams, - lora_request: LoRARequest | None = None, - trace_headers: Mapping[str, str] | None = None, - ) -> AsyncGenerator[RequestOutput, None]: - beam_width = params.beam_width - max_tokens = params.max_tokens - ignore_eos = params.ignore_eos - temperature = params.temperature - length_penalty = params.length_penalty - include_stop_str_in_output = params.include_stop_str_in_output + message: str | Exception, + err_type: str = "BadRequestError", + status_code: HTTPStatus = HTTPStatus.BAD_REQUEST, + param: str | None = None, + ) -> ErrorResponse: + exc: Exception | None = None - tokenizer = self.renderer.get_tokenizer() - eos_token_id = tokenizer.eos_token_id - sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty) + if isinstance(message, Exception): + exc = message - if prompt["type"] == "embeds": - raise NotImplementedError("Embedding prompt not supported for beam search") - if prompt["type"] == "enc_dec": - raise NotImplementedError( - "Encoder-decoder prompt not supported for beam search" - ) + from vllm.exceptions import VLLMValidationError - prompt_text = prompt.get("prompt") - prompt_token_ids = prompt["prompt_token_ids"] - tokenized_length = len(prompt_token_ids) + if isinstance(exc, VLLMValidationError): + err_type = "BadRequestError" + status_code = HTTPStatus.BAD_REQUEST + param = exc.parameter + elif isinstance(exc, (ValueError, TypeError, RuntimeError, OverflowError)): + # Common validation errors from user input + err_type = "BadRequestError" + status_code = HTTPStatus.BAD_REQUEST + param = None + elif isinstance(exc, NotImplementedError): + err_type = "NotImplementedError" + status_code = HTTPStatus.NOT_IMPLEMENTED + param = None + elif exc.__class__.__name__ == "TemplateError": + # jinja2.TemplateError (avoid importing jinja2) + err_type = "BadRequestError" + status_code = HTTPStatus.BAD_REQUEST + param = None + else: + err_type = "InternalServerError" + status_code = HTTPStatus.INTERNAL_SERVER_ERROR + param = None - logprobs_num = 2 * beam_width - sampling_params = SamplingParams( - logprobs=logprobs_num, - max_tokens=1, - temperature=temperature, - ) - all_beams = [ - BeamSearchSequence( - orig_prompt=prompt, - tokens=prompt_token_ids, - cum_logprob=0, - logprobs=[], - lora_request=lora_request, - ) - ] - completed = [] + message = str(exc) - for _ in range(max_tokens): - tasks = [] - request_id_batch = f"{request_id}-{random_uuid()}" + if self.log_error_stack: + exc_type, _, _ = sys.exc_info() + if exc_type is not None: + traceback.print_exc() + else: + traceback.print_stack() - for i, beam in enumerate(all_beams): - prompt_item = beam.get_prompt() - lora_request_item = beam.lora_request - request_id_item = f"{request_id_batch}-beam-{i}" - task = asyncio.create_task( - collect_from_async_generator( - self.engine_client.generate( - prompt_item, - sampling_params, - request_id_item, - lora_request=lora_request_item, - trace_headers=trace_headers, - ) - ) - ) - tasks.append(task) + return ErrorResponse( + error=ErrorInfo( + message=sanitize_message(message), + type=err_type, + code=status_code.value, + param=param, + ) + ) - output = [x[0] for x in await asyncio.gather(*tasks)] + def create_streaming_error_response( + self, + message: str | Exception, + err_type: str = "BadRequestError", + status_code: HTTPStatus = HTTPStatus.BAD_REQUEST, + param: 
str | None = None, + ) -> str: + json_str = json.dumps( + self.create_error_response( + message=message, + err_type=err_type, + status_code=status_code, + param=param, + ).model_dump() + ) + return json_str - new_beams = [] - # Store all new tokens generated by beam - all_beams_token_id = [] - # Store the cumulative probability of all tokens - # generated by beam search - all_beams_logprob = [] - # Iterate through all beam inference results - for i, result in enumerate(output): - current_beam = all_beams[i] + def _raise_if_error(self, finish_reason: str | None, request_id: str) -> None: + """Raise GenerationError if finish_reason indicates an error.""" + if finish_reason == "error": + logger.error( + "Request %s failed with an internal error during generation", + request_id, + ) + raise GenerationError("Internal server error") - # check for error finish reason and abort beam search - if result.outputs[0].finish_reason == "error": - # yield error output and terminate beam search - yield RequestOutput( - request_id=request_id, - prompt=prompt_text, - outputs=[ - CompletionOutput( - index=0, - text="", - token_ids=[], - cumulative_logprob=None, - logprobs=None, - finish_reason="error", - ) - ], - finished=True, - prompt_token_ids=prompt_token_ids, - prompt_logprobs=None, - ) - return + def _convert_generation_error_to_response( + self, e: GenerationError + ) -> ErrorResponse: + """Convert GenerationError to ErrorResponse.""" + return self.create_error_response( + str(e), + err_type="InternalServerError", + status_code=e.status_code, + ) - if result.outputs[0].logprobs is not None: - logprobs = result.outputs[0].logprobs[0] - all_beams_token_id.extend(list(logprobs.keys())) - all_beams_logprob.extend( - [ - current_beam.cum_logprob + obj.logprob - for obj in logprobs.values() - ] - ) + def _convert_generation_error_to_streaming_response( + self, e: GenerationError + ) -> str: + """Convert GenerationError to streaming error response.""" + return self.create_streaming_error_response( + str(e), + err_type="InternalServerError", + status_code=e.status_code, + ) - # Handle the token for the end of sentence (EOS) - all_beams_token_id = np.array(all_beams_token_id) - all_beams_logprob = np.array(all_beams_logprob) + async def _check_model( + self, + request: AnyRequest, + ) -> ErrorResponse | None: + error_response = None - if not ignore_eos: - # Get the index position of eos token in all generated results - eos_idx = np.where(all_beams_token_id == eos_token_id)[0] - for idx in eos_idx: - current_beam = all_beams[idx // logprobs_num] - result = output[idx // logprobs_num] - assert result.outputs[0].logprobs is not None - logprobs_entry = result.outputs[0].logprobs[0] - completed.append( - BeamSearchSequence( - orig_prompt=prompt, - tokens=current_beam.tokens + [eos_token_id] - if include_stop_str_in_output - else current_beam.tokens, - logprobs=current_beam.logprobs + [logprobs_entry], - cum_logprob=float(all_beams_logprob[idx]), - finish_reason="stop", - stop_reason=eos_token_id, - ) - ) - # After processing, set the log probability of the eos condition - # to negative infinity. 
- all_beams_logprob[eos_idx] = -np.inf + if self._is_model_supported(request.model): + return None + if request.model in self.models.lora_requests: + return None + if ( + envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING + and request.model + and (load_result := await self.models.resolve_lora(request.model)) + ): + if isinstance(load_result, LoRARequest): + return None + if ( + isinstance(load_result, ErrorResponse) + and load_result.error.code == HTTPStatus.BAD_REQUEST.value + ): + error_response = load_result - # Processing non-EOS tokens - # Get indices of the top beam_width probabilities - topn_idx = np.argpartition(np.negative(all_beams_logprob), beam_width)[ - :beam_width - ] + return error_response or self.create_error_response( + message=f"The model `{request.model}` does not exist.", + err_type="NotFoundError", + status_code=HTTPStatus.NOT_FOUND, + param="model", + ) - for idx in topn_idx: - current_beam = all_beams[idx // logprobs_num] - result = output[idx // logprobs_num] - token_id = int(all_beams_token_id[idx]) - assert result.outputs[0].logprobs is not None - logprobs_entry = result.outputs[0].logprobs[0] - new_beams.append( - BeamSearchSequence( - orig_prompt=prompt, - tokens=current_beam.tokens + [token_id], - logprobs=current_beam.logprobs + [logprobs_entry], - lora_request=current_beam.lora_request, - cum_logprob=float(all_beams_logprob[idx]), - ) - ) - - all_beams = new_beams - - completed.extend(all_beams) - sorted_completed = sorted(completed, key=sort_beams_key, reverse=True) - best_beams = sorted_completed[:beam_width] - - for beam in best_beams: - if beam.tokens[-1] == eos_token_id and not ignore_eos: - # Skip the eos token in the text. - tokens = beam.tokens[tokenized_length:-1] - else: - tokens = beam.tokens[tokenized_length:] - beam.text = tokenizer.decode(tokens) + def _get_active_default_mm_loras(self, request: AnyRequest) -> LoRARequest | None: + """Determine if there are any active default multimodal loras.""" + # TODO: Currently this is only enabled for chat completions + # to be better aligned with only being enabled for .generate + # when run offline. It would be nice to support additional + # tasks types in the future. + message_types = self._get_message_types(request) + default_mm_loras = set() - yield RequestOutput( - request_id=request_id, - prompt=prompt_text, - outputs=[ - CompletionOutput( - text=beam.text, # type: ignore - cumulative_logprob=beam.cum_logprob, - token_ids=beam.tokens[tokenized_length:], - index=i, - logprobs=beam.logprobs, - finish_reason=beam.finish_reason - if beam.finish_reason is not None - else "length", - stop_reason=beam.stop_reason, - ) - for (i, beam) in enumerate(best_beams) - ], - finished=True, - prompt_token_ids=prompt_token_ids, - prompt_logprobs=None, - ) + for lora in self.models.lora_requests.values(): + # Best effort match for default multimodal lora adapters; + # There is probably a better way to do this, but currently + # this matches against the set of 'types' in any content lists + # up until '_', e.g., to match audio_url -> audio + if lora.lora_name in message_types: + default_mm_loras.add(lora) - async def _preprocess( - self, - ctx: ServeContext, - ) -> ErrorResponse | None: - """ - Default preprocessing hook. Subclasses may override - to prepare `ctx` (classification, embedding, etc.). - """ + # Currently only support default modality specific loras if + # we have exactly one lora matched on the request. 
+ if len(default_mm_loras) == 1: + return default_mm_loras.pop() return None - def _build_response( + def _maybe_get_adapters( self, - ctx: ServeContext, - ) -> AnyResponse | ErrorResponse: - """ - Default response builder. Subclass may override this method - to return the appropriate response object. - """ - return self.create_error_response("unimplemented endpoint") + request: AnyRequest, + supports_default_mm_loras: bool = False, + ) -> LoRARequest | None: + if request.model in self.models.lora_requests: + return self.models.lora_requests[request.model] - async def handle( - self, - ctx: ServeContext, - ) -> AnyResponse | ErrorResponse: - async for response in self._pipeline(ctx): - return response + # Currently only support default modality specific loras + # if we have exactly one lora matched on the request. + if supports_default_mm_loras: + default_mm_lora = self._get_active_default_mm_loras(request) + if default_mm_lora is not None: + return default_mm_lora - return self.create_error_response("No response yielded from pipeline") + if self._is_model_supported(request.model): + return None - async def _pipeline( - self, - ctx: ServeContext, - ) -> AsyncGenerator[AnyResponse | ErrorResponse, None]: - """Execute the request processing pipeline yielding responses.""" - if error := await self._check_model(ctx.request): - yield error - if error := self._validate_request(ctx): - yield error + # if _check_model has been called earlier, this will be unreachable + raise ValueError(f"The model `{request.model}` does not exist.") - preprocess_ret = await self._preprocess(ctx) - if isinstance(preprocess_ret, ErrorResponse): - yield preprocess_ret + def _get_message_types(self, request: AnyRequest) -> set[str]: + """Retrieve the set of types from message content dicts up + until `_`; we use this to match potential multimodal data + with default per modality loras. 
+ """ + message_types: set[str] = set() - generators_ret = await self._prepare_generators(ctx) - if isinstance(generators_ret, ErrorResponse): - yield generators_ret + if not hasattr(request, "messages"): + return message_types - collect_ret = await self._collect_batch(ctx) - if isinstance(collect_ret, ErrorResponse): - yield collect_ret + messages = request.messages + if messages is None or isinstance(messages, (str, bytes)): + return message_types - yield self._build_response(ctx) + for message in messages: + if ( + isinstance(message, dict) + and "content" in message + and isinstance(message["content"], list) + ): + for content_dict in message["content"]: + if "type" in content_dict: + message_types.add(content_dict["type"].split("_")[0]) + return message_types - def _validate_request(self, ctx: ServeContext) -> ErrorResponse | None: - truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens", None) + def _validate_input( + self, + request: object, + input_ids: list[int], + input_text: str, + ) -> TokensPrompt: + token_num = len(input_ids) + max_model_len = self.model_config.max_model_len - if ( - truncate_prompt_tokens is not None - and truncate_prompt_tokens > self.model_config.max_model_len + # Note: EmbeddingRequest, ClassificationRequest, + # and ScoreRequest doesn't have max_tokens + if isinstance( + request, + ( + EmbeddingChatRequest, + EmbeddingCompletionRequest, + ScoreDataRequest, + ScoreTextRequest, + ScoreQueriesDocumentsRequest, + RerankRequest, + ClassificationCompletionRequest, + ClassificationChatRequest, + ), ): - return self.create_error_response( - "truncate_prompt_tokens value is " - "greater than max_model_len." - " Please, select a smaller truncation size." + # Note: input length can be up to the entire model context length + # since these requests don't generate tokens. + if token_num > max_model_len: + operations: dict[type[AnyRequest], str] = { + ScoreDataRequest: "score", + ScoreTextRequest: "score", + ScoreQueriesDocumentsRequest: "score", + ClassificationCompletionRequest: "classification", + ClassificationChatRequest: "classification", + } + operation = operations.get(type(request), "embedding generation") + raise VLLMValidationError( + f"This model's maximum context length is " + f"{max_model_len} tokens. However, you requested " + f"{token_num} tokens in the input for {operation}. " + f"Please reduce the length of the input.", + parameter="input_tokens", + value=token_num, + ) + return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids) + + # Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens + # and does not require model context length validation + if isinstance( + request, + (TokenizeCompletionRequest, TokenizeChatRequest, DetokenizeRequest), + ): + return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids) + + # chat completion endpoint supports max_completion_tokens + if isinstance(request, ChatCompletionRequest): + # TODO(#9845): remove max_tokens when field dropped from OpenAI API + max_tokens = request.max_completion_tokens or request.max_tokens + else: + max_tokens = getattr(request, "max_tokens", None) + + # Note: input length can be up to model context length - 1 for + # completion-like requests. + if token_num >= max_model_len: + raise VLLMValidationError( + f"This model's maximum context length is " + f"{max_model_len} tokens. However, your request has " + f"{token_num} input tokens. 
Please reduce the length of " + "the input messages.", + parameter="input_tokens", + value=token_num, ) - return None - def _create_pooling_params( - self, - ctx: ServeContext, - ) -> PoolingParams | ErrorResponse: - if not hasattr(ctx.request, "to_pooling_params"): - return self.create_error_response( - "Request type does not support pooling parameters" + if max_tokens is not None and token_num + max_tokens > max_model_len: + raise VLLMValidationError( + "'max_tokens' or 'max_completion_tokens' is too large: " + f"{max_tokens}. This model's maximum context length is " + f"{max_model_len} tokens and your request has " + f"{token_num} input tokens ({max_tokens} > {max_model_len}" + f" - {token_num}).", + parameter="max_tokens", + value=max_tokens, ) - return ctx.request.to_pooling_params() + return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids) - async def _prepare_generators( + def _validate_chat_template( self, - ctx: ServeContext, + request_chat_template: str | None, + chat_template_kwargs: dict[str, Any] | None, + trust_request_chat_template: bool, ) -> ErrorResponse | None: - """Schedule the request and get the result generator.""" - generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] - - try: - trace_headers = ( - None - if ctx.raw_request is None - else await self._get_trace_headers(ctx.raw_request.headers) + if not trust_request_chat_template and ( + request_chat_template is not None + or ( + chat_template_kwargs + and chat_template_kwargs.get("chat_template") is not None ) + ): + return self.create_error_response( + "Chat template is passed with request, but " + "--trust-request-chat-template is not set. " + "Refused request with untrusted chat template." + ) + return None - pooling_params = self._create_pooling_params(ctx) - if isinstance(pooling_params, ErrorResponse): - return pooling_params - - if ctx.engine_prompts is None: - return self.create_error_response("Engine prompts not available") - - for i, engine_prompt in enumerate(ctx.engine_prompts): - request_id_item = f"{ctx.request_id}-{i}" - - self._log_inputs( - request_id_item, - engine_prompt, - params=pooling_params, - lora_request=ctx.lora_request, - ) + @staticmethod + def _prepare_extra_chat_template_kwargs( + request_chat_template_kwargs: dict[str, Any] | None = None, + default_chat_template_kwargs: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """Helper to merge server-default and request-specific chat template kwargs.""" + request_chat_template_kwargs = request_chat_template_kwargs or {} + if default_chat_template_kwargs is None: + return request_chat_template_kwargs + # Apply server defaults first, then request kwargs override. 
+ return default_chat_template_kwargs | request_chat_template_kwargs - generator = self.engine_client.encode( - engine_prompt, - pooling_params, - request_id_item, - lora_request=ctx.lora_request, - trace_headers=trace_headers, - priority=getattr(ctx.request, "priority", 0), - ) + async def _preprocess_completion( + self, + request: RendererRequest, + prompt_input: str | list[str] | list[int] | list[list[int]] | None, + prompt_embeds: bytes | list[bytes] | None, + ) -> list[ProcessorInputs]: + prompts = list[SingletonPrompt | bytes]() + if prompt_embeds is not None: # embeds take higher priority + prompts.extend(prompt_to_seq(prompt_embeds)) + if prompt_input is not None: + prompts.extend(prompt_to_seq(prompt_input)) - generators.append(generator) + return await self._preprocess_cmpl(request, prompts) - ctx.result_generator = merge_async_iterators(*generators) + async def _preprocess_cmpl( + self, + request: RendererRequest, + prompts: Sequence[PromptType | bytes], + ) -> list[ProcessorInputs]: + renderer = self.renderer + model_config = self.model_config - return None + parsed_prompts = [ + ( + prompt + if isinstance(prompt, bytes) + else parse_model_prompt(model_config, prompt) + ) + for prompt in prompts + ] + tok_params = request.build_tok_params(model_config) - except Exception as e: - return self.create_error_response(e) + return await renderer.render_cmpl_async( + parsed_prompts, + tok_params, + prompt_extras={ + k: v + for k in ("mm_processor_kwargs", "cache_salt") + if (v := getattr(request, k, None)) is not None + }, + ) - async def _collect_batch( + async def _preprocess_chat( self, - ctx: ServeContext, - ) -> ErrorResponse | None: - """Collect batch results from the result generator.""" - try: - if ctx.engine_prompts is None: - return self.create_error_response("Engine prompts not available") + request: RendererChatRequest, + messages: list[ChatCompletionMessageParam], + default_template: str | None, + default_template_content_format: ChatTemplateContentFormatOption, + default_template_kwargs: dict[str, Any] | None, + tool_dicts: list[dict[str, Any]] | None = None, + tool_parser: Callable[[TokenizerLike], ToolParser] | None = None, + ) -> tuple[list[ConversationMessage], list[ProcessorInputs]]: + renderer = self.renderer - num_prompts = len(ctx.engine_prompts) - final_res_batch: list[PoolingRequestOutput | None] - final_res_batch = [None] * num_prompts + default_template_kwargs = merge_kwargs( + default_template_kwargs, + dict( + tools=tool_dicts, + tokenize=is_mistral_tokenizer(renderer.tokenizer), + ), + ) - if ctx.result_generator is None: - return self.create_error_response("Result generator not available") + tok_params = request.build_tok_params(self.model_config) + chat_params = request.build_chat_params( + default_template, default_template_content_format + ).with_defaults(default_template_kwargs) - async for i, res in ctx.result_generator: - final_res_batch[i] = res + (conversation,), (engine_prompt,) = await renderer.render_chat_async( + [messages], + chat_params, + tok_params, + prompt_extras={ + k: v + for k in ("mm_processor_kwargs", "cache_salt") + if (v := getattr(request, k, None)) is not None + }, + ) - if None in final_res_batch: - return self.create_error_response( - "Failed to generate results for all prompts" - ) + # tool parsing is done only if a tool_parser has been set and if + # tool_choice is not "none" (if tool_choice is "none" but a tool_parser + # is set, we want to prevent parsing a tool_call hallucinated by the LLM + if tool_parser is not None: + 
tool_choice = getattr(request, "tool_choice", "none") + if tool_choice != "none": + if not isinstance(request, ChatCompletionRequest | ResponsesRequest): + msg = ( + "Tool usage is only supported for Chat Completions API " + "or Responses API requests." + ) + raise NotImplementedError(msg) - ctx.final_res_batch = [res for res in final_res_batch if res is not None] + # TODO: Update adjust_request to accept ResponsesRequest + tokenizer = renderer.get_tokenizer() + request = tool_parser(tokenizer).adjust_request(request=request) # type: ignore[arg-type] - return None + return conversation, [engine_prompt] - except Exception as e: - return self.create_error_response(e) + def _extract_prompt_components(self, prompt: PromptType | ProcessorInputs): + return extract_prompt_components(self.model_config, prompt) - def create_error_response( - self, - message: str | Exception, - err_type: str = "BadRequestError", - status_code: HTTPStatus = HTTPStatus.BAD_REQUEST, - param: str | None = None, - ) -> ErrorResponse: - exc: Exception | None = None + def _extract_prompt_text(self, prompt: ProcessorInputs): + return self._extract_prompt_components(prompt).text - if isinstance(message, Exception): - exc = message + def _extract_prompt_len(self, prompt: ProcessorInputs): + return extract_prompt_len(self.model_config, prompt) - from vllm.exceptions import VLLMValidationError + async def _render_next_turn( + self, + request: ResponsesRequest, + messages: list[ResponseInputOutputItem], + tool_dicts: list[dict[str, Any]] | None, + tool_parser: Callable[[TokenizerLike], ToolParser] | None, + chat_template: str | None, + chat_template_content_format: ChatTemplateContentFormatOption, + ): + new_messages = construct_input_messages( + request_input=messages, + ) - if isinstance(exc, VLLMValidationError): - err_type = "BadRequestError" - status_code = HTTPStatus.BAD_REQUEST - param = exc.parameter - elif isinstance(exc, (ValueError, TypeError, RuntimeError, OverflowError)): - # Common validation errors from user input - err_type = "BadRequestError" - status_code = HTTPStatus.BAD_REQUEST - param = None - elif isinstance(exc, NotImplementedError): - err_type = "NotImplementedError" - status_code = HTTPStatus.NOT_IMPLEMENTED - param = None - elif exc.__class__.__name__ == "TemplateError": - # jinja2.TemplateError (avoid importing jinja2) - err_type = "BadRequestError" - status_code = HTTPStatus.BAD_REQUEST - param = None - else: - err_type = "InternalServerError" - status_code = HTTPStatus.INTERNAL_SERVER_ERROR - param = None + _, engine_prompts = await self._preprocess_chat( + request, + new_messages, + default_template=chat_template, + default_template_content_format=chat_template_content_format, + default_template_kwargs=None, + tool_dicts=tool_dicts, + tool_parser=tool_parser, + ) + return engine_prompts - message = str(exc) + def _log_inputs( + self, + request_id: str, + inputs: PromptType | ProcessorInputs, + params: SamplingParams | PoolingParams | BeamSearchParams | None, + lora_request: LoRARequest | None, + ) -> None: + if self.request_logger is None: + return - if self.log_error_stack: - exc_type, _, _ = sys.exc_info() - if exc_type is not None: - traceback.print_exc() - else: - traceback.print_stack() + components = self._extract_prompt_components(inputs) - return ErrorResponse( - error=ErrorInfo( - message=sanitize_message(message), - type=err_type, - code=status_code.value, - param=param, - ) + self.request_logger.log_inputs( + request_id, + components.text, + components.token_ids, + components.embeds, + 
params=params, + lora_request=lora_request, ) - def create_streaming_error_response( + async def _get_trace_headers( self, - message: str | Exception, - err_type: str = "BadRequestError", - status_code: HTTPStatus = HTTPStatus.BAD_REQUEST, - param: str | None = None, - ) -> str: - json_str = json.dumps( - self.create_error_response( - message=message, - err_type=err_type, - status_code=status_code, - param=param, - ).model_dump() - ) - return json_str + headers: Headers, + ) -> Mapping[str, str] | None: + is_tracing_enabled = await self.engine_client.is_tracing_enabled() - def _raise_if_error(self, finish_reason: str | None, request_id: str) -> None: - """Raise GenerationError if finish_reason indicates an error.""" - if finish_reason == "error": - logger.error( - "Request %s failed with an internal error during generation", - request_id, - ) - raise GenerationError("Internal server error") + if is_tracing_enabled: + return extract_trace_headers(headers) - def _convert_generation_error_to_response( - self, e: GenerationError - ) -> ErrorResponse: - """Convert GenerationError to ErrorResponse.""" - return self.create_error_response( - str(e), - err_type="InternalServerError", - status_code=e.status_code, - ) + if contains_trace_headers(headers): + log_tracing_disabled_warning() - def _convert_generation_error_to_streaming_response( - self, e: GenerationError - ) -> str: - """Convert GenerationError to streaming error response.""" - return self.create_streaming_error_response( - str(e), - err_type="InternalServerError", - status_code=e.status_code, - ) + return None - async def _check_model( - self, - request: AnyRequest, - ) -> ErrorResponse | None: - error_response = None + @staticmethod + def _base_request_id( + raw_request: Request | None, default: str | None = None + ) -> str | None: + """Pulls the request id to use from a header, if provided""" + if raw_request is not None and ( + (req_id := raw_request.headers.get("X-Request-Id")) is not None + ): + return req_id - if self._is_model_supported(request.model): + return random_uuid() if default is None else default + + @staticmethod + def _get_data_parallel_rank(raw_request: Request | None) -> int | None: + """Pulls the data parallel rank from a header, if provided""" + if raw_request is None: return None - if request.model in self.models.lora_requests: + + rank_str = raw_request.headers.get("X-data-parallel-rank") + if rank_str is None: return None - if ( - envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING - and request.model - and (load_result := await self.models.resolve_lora(request.model)) - ): - if isinstance(load_result, LoRARequest): - return None - if ( - isinstance(load_result, ErrorResponse) - and load_result.error.code == HTTPStatus.BAD_REQUEST.value - ): - error_response = load_result - return error_response or self.create_error_response( - message=f"The model `{request.model}` does not exist.", - err_type="NotFoundError", - status_code=HTTPStatus.NOT_FOUND, - param="model", - ) + try: + return int(rank_str) + except ValueError: + return None - def _get_active_default_mm_loras(self, request: AnyRequest) -> LoRARequest | None: - """Determine if there are any active default multimodal loras.""" - # TODO: Currently this is only enabled for chat completions - # to be better aligned with only being enabled for .generate - # when run offline. It would be nice to support additional - # tasks types in the future. 
- message_types = self._get_message_types(request) - default_mm_loras = set() + @staticmethod + def _parse_tool_calls_from_content( + request: ResponsesRequest | ChatCompletionRequest, + tokenizer: TokenizerLike | None, + enable_auto_tools: bool, + tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None, + content: str | None = None, + ) -> tuple[list[FunctionCall] | None, str | None]: + function_calls = list[FunctionCall]() + if request.tool_choice and isinstance(request.tool_choice, ToolChoiceFunction): + assert content is not None + # Forced Function Call + function_calls.append( + FunctionCall(name=request.tool_choice.name, arguments=content) + ) + content = None # Clear content since tool is called. + elif request.tool_choice and isinstance( + request.tool_choice, ChatCompletionNamedToolChoiceParam + ): + assert content is not None + # Forced Function Call + function_calls.append( + FunctionCall(name=request.tool_choice.function.name, arguments=content) + ) + content = None # Clear content since tool is called. + elif request.tool_choice == "required": + assert content is not None + tool_calls = TypeAdapter(list[FunctionDefinition]).validate_json(content) + function_calls.extend( + [ + FunctionCall( + name=tool_call.name, + arguments=json.dumps(tool_call.parameters, ensure_ascii=False), + ) + for tool_call in tool_calls + ] + ) + content = None # Clear content since tool is called. + elif ( + tool_parser_cls + and enable_auto_tools + and (request.tool_choice == "auto" or request.tool_choice is None) + ): + if tokenizer is None: + raise ValueError( + "Tokenizer not available when `skip_tokenizer_init=True`" + ) - for lora in self.models.lora_requests.values(): - # Best effort match for default multimodal lora adapters; - # There is probably a better way to do this, but currently - # this matches against the set of 'types' in any content lists - # up until '_', e.g., to match audio_url -> audio - if lora.lora_name in message_types: - default_mm_loras.add(lora) + # Automatic Tool Call Parsing + try: + tool_parser = tool_parser_cls(tokenizer) + except RuntimeError as e: + logger.exception("Error in tool parser creation.") + raise e + tool_call_info = tool_parser.extract_tool_calls( + content if content is not None else "", + request=request, # type: ignore + ) + if tool_call_info is not None and tool_call_info.tools_called: + # extract_tool_calls() returns a list of tool calls. + function_calls.extend( + FunctionCall( + id=tool_call.id, + name=tool_call.function.name, + arguments=tool_call.function.arguments, + ) + for tool_call in tool_call_info.tool_calls + ) + content = tool_call_info.content + if content and content.strip() == "": + content = None + else: + # No tool calls. + return None, content - # Currently only support default modality specific loras if - # we have exactly one lora matched on the request. - if len(default_mm_loras) == 1: - return default_mm_loras.pop() - return None + return function_calls, content - def _maybe_get_adapters( - self, - request: AnyRequest, - supports_default_mm_loras: bool = False, - ) -> LoRARequest | None: - if request.model in self.models.lora_requests: - return self.models.lora_requests[request.model] + @staticmethod + def _get_decoded_token( + logprob: Logprob, + token_id: int, + tokenizer: TokenizerLike | None, + return_as_token_id: bool = False, + ) -> str: + if return_as_token_id: + return f"token_id:{token_id}" - # Currently only support default modality specific loras - # if we have exactly one lora matched on the request. 
- if supports_default_mm_loras: - default_mm_lora = self._get_active_default_mm_loras(request) - if default_mm_lora is not None: - return default_mm_lora + if logprob.decoded_token is not None: + return logprob.decoded_token - if self._is_model_supported(request.model): - return None + if tokenizer is None: + raise ValueError( + "Unable to get tokenizer because `skip_tokenizer_init=True`" + ) - # if _check_model has been called earlier, this will be unreachable - raise ValueError(f"The model `{request.model}` does not exist.") + return tokenizer.decode([token_id]) - def _get_message_types(self, request: AnyRequest) -> set[str]: - """Retrieve the set of types from message content dicts up - until `_`; we use this to match potential multimodal data - with default per modality loras. - """ - message_types: set[str] = set() + def _is_model_supported(self, model_name: str | None) -> bool: + if not model_name: + return True + return self.models.is_base_model(model_name) - if not hasattr(request, "messages"): - return message_types - messages = request.messages - if messages is None or isinstance(messages, (str, bytes)): - return message_types +class OpenAIServingInference(OpenAIServing): + """OpenAIServing subclass for endpoints that require inference. - for message in messages: - if ( - isinstance(message, dict) - and "content" in message - and isinstance(message["content"], list) - ): - for content_dict in message["content"]: - if "type" in content_dict: - message_types.add(content_dict["type"].split("_")[0]) - return message_types + Extends :class:`OpenAIServing` by narrowing ``engine_client`` from + :class:`BaseEngineClient` to :class:`EngineClient` and adding methods + that call ``generate()`` / ``encode()`` on the engine. + """ - def _validate_input( + engine_client: EngineClient + + def __init__( self, - request: object, - input_ids: list[int], - input_text: str, - ) -> TokensPrompt: - token_num = len(input_ids) - max_model_len = self.model_config.max_model_len + engine_client: EngineClient, + models: OpenAIServingModels, + *, + request_logger: RequestLogger | None, + return_tokens_as_token_ids: bool = False, + log_error_stack: bool = False, + ): + super().__init__( + engine_client=engine_client, + models=models, + request_logger=request_logger, + return_tokens_as_token_ids=return_tokens_as_token_ids, + log_error_stack=log_error_stack, + ) - # Note: EmbeddingRequest, ClassificationRequest, - # and ScoreRequest doesn't have max_tokens - if isinstance( - request, - ( - EmbeddingChatRequest, - EmbeddingCompletionRequest, - ScoreDataRequest, - ScoreTextRequest, - ScoreQueriesDocumentsRequest, - RerankRequest, - ClassificationCompletionRequest, - ClassificationChatRequest, - ), - ): - # Note: input length can be up to the entire model context length - # since these requests don't generate tokens. - if token_num > max_model_len: - operations: dict[type[AnyRequest], str] = { - ScoreDataRequest: "score", - ScoreTextRequest: "score", - ScoreQueriesDocumentsRequest: "score", - ClassificationCompletionRequest: "classification", - ClassificationChatRequest: "classification", - } - operation = operations.get(type(request), "embedding generation") - raise VLLMValidationError( - f"This model's maximum context length is " - f"{max_model_len} tokens. However, you requested " - f"{token_num} tokens in the input for {operation}. 
" - f"Please reduce the length of the input.", - parameter="input_tokens", - value=token_num, - ) - return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids) + async def _preprocess( + self, + ctx: ServeContext, + ) -> ErrorResponse | None: + """ + Default preprocessing hook. Subclasses may override + to prepare `ctx` (classification, embedding, etc.). + """ + return None - # Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens - # and does not require model context length validation - if isinstance( - request, - (TokenizeCompletionRequest, TokenizeChatRequest, DetokenizeRequest), - ): - return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids) + def _build_response( + self, + ctx: ServeContext, + ) -> AnyResponse | ErrorResponse: + """ + Default response builder. Subclass may override this method + to return the appropriate response object. + """ + return self.create_error_response("unimplemented endpoint") - # chat completion endpoint supports max_completion_tokens - if isinstance(request, ChatCompletionRequest): - # TODO(#9845): remove max_tokens when field dropped from OpenAI API - max_tokens = request.max_completion_tokens or request.max_tokens - else: - max_tokens = getattr(request, "max_tokens", None) + async def handle( + self, + ctx: ServeContext, + ) -> AnyResponse | ErrorResponse: + async for response in self._pipeline(ctx): + return response - # Note: input length can be up to model context length - 1 for - # completion-like requests. - if token_num >= max_model_len: - raise VLLMValidationError( - f"This model's maximum context length is " - f"{max_model_len} tokens. However, your request has " - f"{token_num} input tokens. Please reduce the length of " - "the input messages.", - parameter="input_tokens", - value=token_num, - ) + return self.create_error_response("No response yielded from pipeline") - if max_tokens is not None and token_num + max_tokens > max_model_len: - raise VLLMValidationError( - "'max_tokens' or 'max_completion_tokens' is too large: " - f"{max_tokens}. 
This model's maximum context length is " - f"{max_model_len} tokens and your request has " - f"{token_num} input tokens ({max_tokens} > {max_model_len}" - f" - {token_num}).", - parameter="max_tokens", - value=max_tokens, - ) + async def _pipeline( + self, + ctx: ServeContext, + ) -> AsyncGenerator[AnyResponse | ErrorResponse, None]: + """Execute the request processing pipeline yielding responses.""" + if error := await self._check_model(ctx.request): + yield error + if error := self._validate_request(ctx): + yield error - return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids) + preprocess_ret = await self._preprocess(ctx) + if isinstance(preprocess_ret, ErrorResponse): + yield preprocess_ret - def _validate_chat_template( - self, - request_chat_template: str | None, - chat_template_kwargs: dict[str, Any] | None, - trust_request_chat_template: bool, - ) -> ErrorResponse | None: - if not trust_request_chat_template and ( - request_chat_template is not None - or ( - chat_template_kwargs - and chat_template_kwargs.get("chat_template") is not None - ) + generators_ret = await self._prepare_generators(ctx) + if isinstance(generators_ret, ErrorResponse): + yield generators_ret + + collect_ret = await self._collect_batch(ctx) + if isinstance(collect_ret, ErrorResponse): + yield collect_ret + + yield self._build_response(ctx) + + def _validate_request(self, ctx: ServeContext) -> ErrorResponse | None: + truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens", None) + + if ( + truncate_prompt_tokens is not None + and truncate_prompt_tokens > self.model_config.max_model_len ): return self.create_error_response( - "Chat template is passed with request, but " - "--trust-request-chat-template is not set. " - "Refused request with untrusted chat template." + "truncate_prompt_tokens value is " + "greater than max_model_len." + " Please, select a smaller truncation size." ) return None - @staticmethod - def _prepare_extra_chat_template_kwargs( - request_chat_template_kwargs: dict[str, Any] | None = None, - default_chat_template_kwargs: dict[str, Any] | None = None, - ) -> dict[str, Any]: - """Helper to merge server-default and request-specific chat template kwargs.""" - request_chat_template_kwargs = request_chat_template_kwargs or {} - if default_chat_template_kwargs is None: - return request_chat_template_kwargs - # Apply server defaults first, then request kwargs override. 
- return default_chat_template_kwargs | request_chat_template_kwargs + def _create_pooling_params( + self, + ctx: ServeContext, + ) -> PoolingParams | ErrorResponse: + if not hasattr(ctx.request, "to_pooling_params"): + return self.create_error_response( + "Request type does not support pooling parameters" + ) - async def _preprocess_completion( + return ctx.request.to_pooling_params() + + async def _collect_batch( self, - request: RendererRequest, - prompt_input: str | list[str] | list[int] | list[list[int]] | None, - prompt_embeds: bytes | list[bytes] | None, - ) -> list[ProcessorInputs]: - prompts = list[SingletonPrompt | bytes]() - if prompt_embeds is not None: # embeds take higher priority - prompts.extend(prompt_to_seq(prompt_embeds)) - if prompt_input is not None: - prompts.extend(prompt_to_seq(prompt_input)) + ctx: ServeContext, + ) -> ErrorResponse | None: + """Collect batch results from the result generator.""" + try: + if ctx.engine_prompts is None: + return self.create_error_response("Engine prompts not available") - return await self._preprocess_cmpl(request, prompts) + num_prompts = len(ctx.engine_prompts) + final_res_batch: list[PoolingRequestOutput | None] + final_res_batch = [None] * num_prompts - async def _preprocess_cmpl( + if ctx.result_generator is None: + return self.create_error_response("Result generator not available") + + async for i, res in ctx.result_generator: + final_res_batch[i] = res + + if None in final_res_batch: + return self.create_error_response( + "Failed to generate results for all prompts" + ) + + ctx.final_res_batch = [res for res in final_res_batch if res is not None] + + return None + + except Exception as e: + return self.create_error_response(e) + + async def beam_search( self, - request: RendererRequest, - prompts: Sequence[PromptType | bytes], - ) -> list[ProcessorInputs]: - renderer = self.renderer - model_config = self.model_config + prompt: ProcessorInputs, + request_id: str, + params: BeamSearchParams, + lora_request: LoRARequest | None = None, + trace_headers: Mapping[str, str] | None = None, + ) -> AsyncGenerator[RequestOutput, None]: + beam_width = params.beam_width + max_tokens = params.max_tokens + ignore_eos = params.ignore_eos + temperature = params.temperature + length_penalty = params.length_penalty + include_stop_str_in_output = params.include_stop_str_in_output - parsed_prompts = [ - ( - prompt - if isinstance(prompt, bytes) - else parse_model_prompt(model_config, prompt) + tokenizer = self.renderer.get_tokenizer() + eos_token_id = tokenizer.eos_token_id + sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty) + + if prompt["type"] == "embeds": + raise NotImplementedError("Embedding prompt not supported for beam search") + if prompt["type"] == "enc_dec": + raise NotImplementedError( + "Encoder-decoder prompt not supported for beam search" ) - for prompt in prompts - ] - tok_params = request.build_tok_params(model_config) - return await renderer.render_cmpl_async( - parsed_prompts, - tok_params, - prompt_extras={ - k: v - for k in ("mm_processor_kwargs", "cache_salt") - if (v := getattr(request, k, None)) is not None - }, + prompt_text = prompt.get("prompt") + prompt_token_ids = prompt["prompt_token_ids"] + tokenized_length = len(prompt_token_ids) + + logprobs_num = 2 * beam_width + sampling_params = SamplingParams( + logprobs=logprobs_num, + max_tokens=1, + temperature=temperature, ) + all_beams = [ + BeamSearchSequence( + orig_prompt=prompt, + tokens=prompt_token_ids, + cum_logprob=0, + logprobs=[], + 
lora_request=lora_request, + ) + ] + completed = [] - async def _preprocess_chat( - self, - request: RendererChatRequest, - messages: list[ChatCompletionMessageParam], - default_template: str | None, - default_template_content_format: ChatTemplateContentFormatOption, - default_template_kwargs: dict[str, Any] | None, - tool_dicts: list[dict[str, Any]] | None = None, - tool_parser: Callable[[TokenizerLike], ToolParser] | None = None, - ) -> tuple[list[ConversationMessage], list[ProcessorInputs]]: - renderer = self.renderer + for _ in range(max_tokens): + tasks = [] + request_id_batch = f"{request_id}-{random_uuid()}" - default_template_kwargs = merge_kwargs( - default_template_kwargs, - dict( - tools=tool_dicts, - tokenize=is_mistral_tokenizer(renderer.tokenizer), - ), + for i, beam in enumerate(all_beams): + prompt_item = beam.get_prompt() + lora_request_item = beam.lora_request + request_id_item = f"{request_id_batch}-beam-{i}" + task = asyncio.create_task( + collect_from_async_generator( + self.engine_client.generate( + prompt_item, + sampling_params, + request_id_item, + lora_request=lora_request_item, + trace_headers=trace_headers, + ) + ) + ) + tasks.append(task) + + output = [x[0] for x in await asyncio.gather(*tasks)] + + new_beams = [] + # Store all new tokens generated by beam + all_beams_token_id = [] + # Store the cumulative probability of all tokens + # generated by beam search + all_beams_logprob = [] + # Iterate through all beam inference results + for i, result in enumerate(output): + current_beam = all_beams[i] + + # check for error finish reason and abort beam search + if result.outputs[0].finish_reason == "error": + # yield error output and terminate beam search + yield RequestOutput( + request_id=request_id, + prompt=prompt_text, + outputs=[ + CompletionOutput( + index=0, + text="", + token_ids=[], + cumulative_logprob=None, + logprobs=None, + finish_reason="error", + ) + ], + finished=True, + prompt_token_ids=prompt_token_ids, + prompt_logprobs=None, + ) + return + + if result.outputs[0].logprobs is not None: + logprobs = result.outputs[0].logprobs[0] + all_beams_token_id.extend(list(logprobs.keys())) + all_beams_logprob.extend( + [ + current_beam.cum_logprob + obj.logprob + for obj in logprobs.values() + ] + ) + + # Handle the token for the end of sentence (EOS) + all_beams_token_id = np.array(all_beams_token_id) + all_beams_logprob = np.array(all_beams_logprob) + + if not ignore_eos: + # Get the index position of eos token in all generated results + eos_idx = np.where(all_beams_token_id == eos_token_id)[0] + for idx in eos_idx: + current_beam = all_beams[idx // logprobs_num] + result = output[idx // logprobs_num] + assert result.outputs[0].logprobs is not None + logprobs_entry = result.outputs[0].logprobs[0] + completed.append( + BeamSearchSequence( + orig_prompt=prompt, + tokens=current_beam.tokens + [eos_token_id] + if include_stop_str_in_output + else current_beam.tokens, + logprobs=current_beam.logprobs + [logprobs_entry], + cum_logprob=float(all_beams_logprob[idx]), + finish_reason="stop", + stop_reason=eos_token_id, + ) + ) + # After processing, set the log probability of the eos condition + # to negative infinity. 
+ all_beams_logprob[eos_idx] = -np.inf + + # Processing non-EOS tokens + # Get indices of the top beam_width probabilities + topn_idx = np.argpartition(np.negative(all_beams_logprob), beam_width)[ + :beam_width + ] + + for idx in topn_idx: + current_beam = all_beams[idx // logprobs_num] + result = output[idx // logprobs_num] + token_id = int(all_beams_token_id[idx]) + assert result.outputs[0].logprobs is not None + logprobs_entry = result.outputs[0].logprobs[0] + new_beams.append( + BeamSearchSequence( + orig_prompt=prompt, + tokens=current_beam.tokens + [token_id], + logprobs=current_beam.logprobs + [logprobs_entry], + lora_request=current_beam.lora_request, + cum_logprob=float(all_beams_logprob[idx]), + ) + ) + + all_beams = new_beams + + completed.extend(all_beams) + sorted_completed = sorted(completed, key=sort_beams_key, reverse=True) + best_beams = sorted_completed[:beam_width] + + for beam in best_beams: + if beam.tokens[-1] == eos_token_id and not ignore_eos: + # Skip the eos token in the text. + tokens = beam.tokens[tokenized_length:-1] + else: + tokens = beam.tokens[tokenized_length:] + beam.text = tokenizer.decode(tokens) + + yield RequestOutput( + request_id=request_id, + prompt=prompt_text, + outputs=[ + CompletionOutput( + text=beam.text, # type: ignore + cumulative_logprob=beam.cum_logprob, + token_ids=beam.tokens[tokenized_length:], + index=i, + logprobs=beam.logprobs, + finish_reason=beam.finish_reason + if beam.finish_reason is not None + else "length", + stop_reason=beam.stop_reason, + ) + for (i, beam) in enumerate(best_beams) + ], + finished=True, + prompt_token_ids=prompt_token_ids, + prompt_logprobs=None, ) - tok_params = request.build_tok_params(self.model_config) - chat_params = request.build_chat_params( - default_template, default_template_content_format - ).with_defaults(default_template_kwargs) + async def _prepare_generators( + self, + ctx: ServeContext, + ) -> ErrorResponse | None: + """Schedule the request and get the result generator.""" + generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] + + try: + trace_headers = ( + None + if ctx.raw_request is None + else await self._get_trace_headers(ctx.raw_request.headers) + ) - (conversation,), (engine_prompt,) = await renderer.render_chat_async( - [messages], - chat_params, - tok_params, - prompt_extras={ - k: v - for k in ("mm_processor_kwargs", "cache_salt") - if (v := getattr(request, k, None)) is not None - }, - ) + pooling_params = self._create_pooling_params(ctx) + if isinstance(pooling_params, ErrorResponse): + return pooling_params - # tool parsing is done only if a tool_parser has been set and if - # tool_choice is not "none" (if tool_choice is "none" but a tool_parser - # is set, we want to prevent parsing a tool_call hallucinated by the LLM - if tool_parser is not None: - tool_choice = getattr(request, "tool_choice", "none") - if tool_choice != "none": - if not isinstance(request, ChatCompletionRequest | ResponsesRequest): - msg = ( - "Tool usage is only supported for Chat Completions API " - "or Responses API requests." 
- ) - raise NotImplementedError(msg) + if ctx.engine_prompts is None: + return self.create_error_response("Engine prompts not available") - # TODO: Update adjust_request to accept ResponsesRequest - tokenizer = renderer.get_tokenizer() - request = tool_parser(tokenizer).adjust_request(request=request) # type: ignore[arg-type] + for i, engine_prompt in enumerate(ctx.engine_prompts): + request_id_item = f"{ctx.request_id}-{i}" - return conversation, [engine_prompt] + self._log_inputs( + request_id_item, + engine_prompt, + params=pooling_params, + lora_request=ctx.lora_request, + ) - def _extract_prompt_components(self, prompt: PromptType | ProcessorInputs): - return extract_prompt_components(self.model_config, prompt) + generator = self.engine_client.encode( + engine_prompt, + pooling_params, + request_id_item, + lora_request=ctx.lora_request, + trace_headers=trace_headers, + priority=getattr(ctx.request, "priority", 0), + ) - def _extract_prompt_text(self, prompt: ProcessorInputs): - return self._extract_prompt_components(prompt).text + generators.append(generator) - def _extract_prompt_len(self, prompt: ProcessorInputs): - return extract_prompt_len(self.model_config, prompt) + ctx.result_generator = merge_async_iterators(*generators) - async def _render_next_turn( - self, - request: ResponsesRequest, - messages: list[ResponseInputOutputItem], - tool_dicts: list[dict[str, Any]] | None, - tool_parser: Callable[[TokenizerLike], ToolParser] | None, - chat_template: str | None, - chat_template_content_format: ChatTemplateContentFormatOption, - ): - new_messages = construct_input_messages( - request_input=messages, - ) + return None - _, engine_prompts = await self._preprocess_chat( - request, - new_messages, - default_template=chat_template, - default_template_content_format=chat_template_content_format, - default_template_kwargs=None, - tool_dicts=tool_dicts, - tool_parser=tool_parser, - ) - return engine_prompts + except Exception as e: + return self.create_error_response(e) async def _generate_with_builtin_tools( self, @@ -1134,170 +1326,6 @@ async def _generate_with_builtin_tools( priority = orig_priority - 1 sub_request += 1 - def _log_inputs( - self, - request_id: str, - inputs: PromptType | ProcessorInputs, - params: SamplingParams | PoolingParams | BeamSearchParams | None, - lora_request: LoRARequest | None, - ) -> None: - if self.request_logger is None: - return - - components = self._extract_prompt_components(inputs) - - self.request_logger.log_inputs( - request_id, - components.text, - components.token_ids, - components.embeds, - params=params, - lora_request=lora_request, - ) - - async def _get_trace_headers( - self, - headers: Headers, - ) -> Mapping[str, str] | None: - is_tracing_enabled = await self.engine_client.is_tracing_enabled() - - if is_tracing_enabled: - return extract_trace_headers(headers) - - if contains_trace_headers(headers): - log_tracing_disabled_warning() - - return None - - @staticmethod - def _base_request_id( - raw_request: Request | None, default: str | None = None - ) -> str | None: - """Pulls the request id to use from a header, if provided""" - if raw_request is not None and ( - (req_id := raw_request.headers.get("X-Request-Id")) is not None - ): - return req_id - - return random_uuid() if default is None else default - - @staticmethod - def _get_data_parallel_rank(raw_request: Request | None) -> int | None: - """Pulls the data parallel rank from a header, if provided""" - if raw_request is None: - return None - - rank_str = 
raw_request.headers.get("X-data-parallel-rank") - if rank_str is None: - return None - - try: - return int(rank_str) - except ValueError: - return None - - @staticmethod - def _parse_tool_calls_from_content( - request: ResponsesRequest | ChatCompletionRequest, - tokenizer: TokenizerLike | None, - enable_auto_tools: bool, - tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None, - content: str | None = None, - ) -> tuple[list[FunctionCall] | None, str | None]: - function_calls = list[FunctionCall]() - if request.tool_choice and isinstance(request.tool_choice, ToolChoiceFunction): - assert content is not None - # Forced Function Call - function_calls.append( - FunctionCall(name=request.tool_choice.name, arguments=content) - ) - content = None # Clear content since tool is called. - elif request.tool_choice and isinstance( - request.tool_choice, ChatCompletionNamedToolChoiceParam - ): - assert content is not None - # Forced Function Call - function_calls.append( - FunctionCall(name=request.tool_choice.function.name, arguments=content) - ) - content = None # Clear content since tool is called. - elif request.tool_choice == "required": - assert content is not None - tool_calls = TypeAdapter(list[FunctionDefinition]).validate_json(content) - function_calls.extend( - [ - FunctionCall( - name=tool_call.name, - arguments=json.dumps(tool_call.parameters, ensure_ascii=False), - ) - for tool_call in tool_calls - ] - ) - content = None # Clear content since tool is called. - elif ( - tool_parser_cls - and enable_auto_tools - and (request.tool_choice == "auto" or request.tool_choice is None) - ): - if tokenizer is None: - raise ValueError( - "Tokenizer not available when `skip_tokenizer_init=True`" - ) - - # Automatic Tool Call Parsing - try: - tool_parser = tool_parser_cls(tokenizer) - except RuntimeError as e: - logger.exception("Error in tool parser creation.") - raise e - tool_call_info = tool_parser.extract_tool_calls( - content if content is not None else "", - request=request, # type: ignore - ) - if tool_call_info is not None and tool_call_info.tools_called: - # extract_tool_calls() returns a list of tool calls. - function_calls.extend( - FunctionCall( - id=tool_call.id, - name=tool_call.function.name, - arguments=tool_call.function.arguments, - ) - for tool_call in tool_call_info.tool_calls - ) - content = tool_call_info.content - if content and content.strip() == "": - content = None - else: - # No tool calls. 
- return None, content - - return function_calls, content - - @staticmethod - def _get_decoded_token( - logprob: Logprob, - token_id: int, - tokenizer: TokenizerLike | None, - return_as_token_id: bool = False, - ) -> str: - if return_as_token_id: - return f"token_id:{token_id}" - - if logprob.decoded_token is not None: - return logprob.decoded_token - - if tokenizer is None: - raise ValueError( - "Unable to get tokenizer because `skip_tokenizer_init=True`" - ) - - return tokenizer.decode([token_id]) - - def _is_model_supported(self, model_name: str | None) -> bool: - if not model_name: - return True - return self.models.is_base_model(model_name) - def clamp_prompt_logprobs( prompt_logprobs: PromptLogprobs | None, diff --git a/vllm/entrypoints/openai/realtime/serving.py b/vllm/entrypoints/openai/realtime/serving.py index d239968e75d2..4a00a50306d6 100644 --- a/vllm/entrypoints/openai/realtime/serving.py +++ b/vllm/entrypoints/openai/realtime/serving.py @@ -10,7 +10,7 @@ from vllm.engine.protocol import EngineClient, StreamingInput from vllm.entrypoints.logger import RequestLogger -from vllm.entrypoints.openai.engine.serving import OpenAIServing +from vllm.entrypoints.openai.engine.serving import OpenAIServingInference from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.inputs.data import PromptType from vllm.logger import init_logger @@ -20,7 +20,7 @@ logger = init_logger(__name__) -class OpenAIServingRealtime(OpenAIServing): +class OpenAIServingRealtime(OpenAIServingInference): """Realtime audio transcription service via WebSocket streaming. Provides streaming audio-to-text transcription by transforming audio chunks diff --git a/vllm/entrypoints/openai/responses/serving.py b/vllm/entrypoints/openai/responses/serving.py index 67f6fd35db3d..30e2851311b9 100644 --- a/vllm/entrypoints/openai/responses/serving.py +++ b/vllm/entrypoints/openai/responses/serving.py @@ -54,7 +54,7 @@ ) from vllm.entrypoints.openai.engine.serving import ( GenerationError, - OpenAIServing, + OpenAIServingInference, ) from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.openai.parser.harmony_utils import ( @@ -153,7 +153,7 @@ def _extract_allowed_tools_from_mcp_requests( return allowed_tools_map -class OpenAIServingResponses(OpenAIServing): +class OpenAIServingResponses(OpenAIServingInference): def __init__( self, engine_client: EngineClient, diff --git a/vllm/entrypoints/openai/speech_to_text/speech_to_text.py b/vllm/entrypoints/openai/speech_to_text/speech_to_text.py index 966e6d457162..4e3d22af13aa 100644 --- a/vllm/entrypoints/openai/speech_to_text/speech_to_text.py +++ b/vllm/entrypoints/openai/speech_to_text/speech_to_text.py @@ -22,7 +22,10 @@ RequestResponseMetadata, UsageInfo, ) -from vllm.entrypoints.openai.engine.serving import OpenAIServing, SpeechToTextRequest +from vllm.entrypoints.openai.engine.serving import ( + OpenAIServingInference, + SpeechToTextRequest, +) from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.openai.speech_to_text.protocol import ( TranscriptionResponse, @@ -76,7 +79,7 @@ logger = init_logger(__name__) -class OpenAISpeechToText(OpenAIServing): +class OpenAISpeechToText(OpenAIServingInference): """Base class for speech-to-text operations like transcription and translation.""" diff --git a/vllm/entrypoints/pooling/classify/serving.py b/vllm/entrypoints/pooling/classify/serving.py index 8cdbbde6d6f6..a3ba96549a5f 100644 --- a/vllm/entrypoints/pooling/classify/serving.py +++ 
b/vllm/entrypoints/pooling/classify/serving.py @@ -11,7 +11,7 @@ from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo -from vllm.entrypoints.openai.engine.serving import OpenAIServing, ServeContext +from vllm.entrypoints.openai.engine.serving import OpenAIServingInference, ServeContext from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.pooling.classify.protocol import ( ClassificationChatRequest, @@ -29,7 +29,7 @@ ClassificationServeContext: TypeAlias = ServeContext[ClassificationRequest] -class ServingClassification(OpenAIServing): +class ServingClassification(OpenAIServingInference): request_id_prefix = "classify" def __init__( diff --git a/vllm/entrypoints/pooling/embed/serving.py b/vllm/entrypoints/pooling/embed/serving.py index de4dca623503..cf812f7eb962 100644 --- a/vllm/entrypoints/pooling/embed/serving.py +++ b/vllm/entrypoints/pooling/embed/serving.py @@ -13,7 +13,7 @@ from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo -from vllm.entrypoints.openai.engine.serving import OpenAIServing, ServeContext +from vllm.entrypoints.openai.engine.serving import OpenAIServingInference, ServeContext from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.pooling.embed.protocol import ( EmbeddingBytesResponse, @@ -42,7 +42,7 @@ EmbeddingServeContext: TypeAlias = ServeContext[EmbeddingRequest] -class OpenAIServingEmbedding(OpenAIServing): +class OpenAIServingEmbedding(OpenAIServingInference): request_id_prefix = "embd" def __init__( diff --git a/vllm/entrypoints/pooling/pooling/serving.py b/vllm/entrypoints/pooling/pooling/serving.py index f27a27191f99..5427d7582ef5 100644 --- a/vllm/entrypoints/pooling/pooling/serving.py +++ b/vllm/entrypoints/pooling/pooling/serving.py @@ -16,7 +16,7 @@ from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo -from vllm.entrypoints.openai.engine.serving import OpenAIServing +from vllm.entrypoints.openai.engine.serving import OpenAIServingInference from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.pooling.pooling.protocol import ( IOProcessorRequest, @@ -43,7 +43,7 @@ logger = init_logger(__name__) -class OpenAIServingPooling(OpenAIServing): +class OpenAIServingPooling(OpenAIServingInference): def __init__( self, engine_client: EngineClient, diff --git a/vllm/entrypoints/pooling/score/serving.py b/vllm/entrypoints/pooling/score/serving.py index 3fe18ca8b3a5..7e86c553a346 100644 --- a/vllm/entrypoints/pooling/score/serving.py +++ b/vllm/entrypoints/pooling/score/serving.py @@ -14,7 +14,7 @@ ErrorResponse, UsageInfo, ) -from vllm.entrypoints.openai.engine.serving import OpenAIServing +from vllm.entrypoints.openai.engine.serving import OpenAIServingInference from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.pooling.score.protocol import ( RerankDocument, @@ -47,7 +47,7 @@ logger = init_logger(__name__) -class ServingScores(OpenAIServing): +class ServingScores(OpenAIServingInference): def __init__( self, engine_client: EngineClient, diff --git 
a/vllm/entrypoints/serve/disagg/serving.py b/vllm/entrypoints/serve/disagg/serving.py index f004e5269830..86b7ce351362 100644 --- a/vllm/entrypoints/serve/disagg/serving.py +++ b/vllm/entrypoints/serve/disagg/serving.py @@ -22,7 +22,10 @@ RequestResponseMetadata, UsageInfo, ) -from vllm.entrypoints.openai.engine.serving import OpenAIServing, clamp_prompt_logprobs +from vllm.entrypoints.openai.engine.serving import ( + OpenAIServingInference, + clamp_prompt_logprobs, +) from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.serve.disagg.protocol import ( GenerateRequest, @@ -38,7 +41,7 @@ logger = init_logger(__name__) -class ServingTokens(OpenAIServing): +class ServingTokens(OpenAIServingInference): """Provides Tokens IN <> Tokens OUT functionality to vLLM API.""" def __init__( diff --git a/vllm/entrypoints/serve/tokenize/serving.py b/vllm/entrypoints/serve/tokenize/serving.py index 55d7ea827c57..7c768ea9227d 100644 --- a/vllm/entrypoints/serve/tokenize/serving.py +++ b/vllm/entrypoints/serve/tokenize/serving.py @@ -6,7 +6,7 @@ import jinja2 from fastapi import Request -from vllm.engine.protocol import EngineClient +from vllm.engine.protocol import BaseEngineClient from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.engine.protocol import ErrorResponse @@ -30,7 +30,7 @@ class OpenAIServingTokenization(OpenAIServing): def __init__( self, - engine_client: EngineClient, + engine_client: BaseEngineClient, models: OpenAIServingModels, *, request_logger: RequestLogger | None, From cc7183704c31cab814be2f73f7a1f5efdcf9157a Mon Sep 17 00:00:00 2001 From: Sage Ahrac Date: Wed, 25 Feb 2026 11:19:49 +0200 Subject: [PATCH 03/53] render client Signed-off-by: Sage Ahrac --- vllm/engine/protocol.py | 8 ++-- vllm/entrypoints/openai/engine/serving.py | 52 ++++++++++++++-------- vllm/entrypoints/serve/tokenize/serving.py | 4 +- 3 files changed, 40 insertions(+), 24 deletions(-) diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 5f2686797e18..490dbb47d42b 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -38,8 +38,8 @@ class StreamingInput: sampling_params: SamplingParams | None = None -class BaseEngineClient(ABC): - """Engine client interface for non-inference operations. +class RendererClient(ABC): + """Client interface for the renderer layer (CPU-only operations). Contains only methods and attributes that don't require a running inference engine: configuration, tokenization, health checks, and @@ -86,10 +86,10 @@ async def get_supported_tasks(self) -> tuple[SupportedTask, ...]: raise NotImplementedError -class EngineClient(BaseEngineClient): +class EngineClient(RendererClient): """Full engine client interface including inference operations. - Extends :class:`BaseEngineClient` with methods that require a running + Extends :class:`RendererClient` with methods that require a running inference engine: generation, encoding, profiling, cache management, scheduler control, and weight transfer. 
""" diff --git a/vllm/entrypoints/openai/engine/serving.py b/vllm/entrypoints/openai/engine/serving.py index 2be03402fa9d..b6a9fe7ad866 100644 --- a/vllm/entrypoints/openai/engine/serving.py +++ b/vllm/entrypoints/openai/engine/serving.py @@ -21,7 +21,7 @@ import vllm.envs as envs from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function from vllm.config import ModelConfig -from vllm.engine.protocol import BaseEngineClient, EngineClient +from vllm.engine.protocol import EngineClient, RendererClient from vllm.entrypoints.chat_utils import ( ChatCompletionMessageParam, ChatTemplateContentFormatOption, @@ -229,7 +229,7 @@ class OpenAIServing: def __init__( self, - engine_client: BaseEngineClient, + engine_client: RendererClient, models: OpenAIServingModels, *, request_logger: RequestLogger | None, @@ -357,26 +357,12 @@ async def _check_model( self, request: AnyRequest, ) -> ErrorResponse | None: - error_response = None - if self._is_model_supported(request.model): return None if request.model in self.models.lora_requests: return None - if ( - envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING - and request.model - and (load_result := await self.models.resolve_lora(request.model)) - ): - if isinstance(load_result, LoRARequest): - return None - if ( - isinstance(load_result, ErrorResponse) - and load_result.error.code == HTTPStatus.BAD_REQUEST.value - ): - error_response = load_result - return error_response or self.create_error_response( + return self.create_error_response( message=f"The model `{request.model}` does not exist.", err_type="NotFoundError", status_code=HTTPStatus.NOT_FOUND, @@ -868,7 +854,7 @@ class OpenAIServingInference(OpenAIServing): """OpenAIServing subclass for endpoints that require inference. Extends :class:`OpenAIServing` by narrowing ``engine_client`` from - :class:`BaseEngineClient` to :class:`EngineClient` and adding methods + :class:`RendererClient` to :class:`EngineClient` and adding methods that call ``generate()`` / ``encode()`` on the engine. """ @@ -891,6 +877,36 @@ def __init__( log_error_stack=log_error_stack, ) + async def _check_model( + self, + request: AnyRequest, + ) -> ErrorResponse | None: + """Extend base model check with runtime LoRA resolution. + + Runtime LoRA loading requires :meth:`EngineClient.add_lora`, + which is only available on the full inference engine. 
+ """ + # Check known models/LoRAs first (base class) + error_response = await super()._check_model(request) + if error_response is None: + return None + + # Attempt runtime LoRA resolution (needs inference engine) + if ( + envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING + and request.model + and (load_result := await self.models.resolve_lora(request.model)) + ): + if isinstance(load_result, LoRARequest): + return None + if ( + isinstance(load_result, ErrorResponse) + and load_result.error.code == HTTPStatus.BAD_REQUEST.value + ): + return load_result + + return error_response + async def _preprocess( self, ctx: ServeContext, diff --git a/vllm/entrypoints/serve/tokenize/serving.py b/vllm/entrypoints/serve/tokenize/serving.py index 7c768ea9227d..40dbb7e5e29d 100644 --- a/vllm/entrypoints/serve/tokenize/serving.py +++ b/vllm/entrypoints/serve/tokenize/serving.py @@ -6,7 +6,7 @@ import jinja2 from fastapi import Request -from vllm.engine.protocol import BaseEngineClient +from vllm.engine.protocol import RendererClient from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.engine.protocol import ErrorResponse @@ -30,7 +30,7 @@ class OpenAIServingTokenization(OpenAIServing): def __init__( self, - engine_client: BaseEngineClient, + engine_client: RendererClient, models: OpenAIServingModels, *, request_logger: RequestLogger | None, From a762964c185209516a07a23447b1b2a015215f98 Mon Sep 17 00:00:00 2001 From: Sage Ahrac Date: Wed, 25 Feb 2026 12:25:28 +0200 Subject: [PATCH 04/53] remove inheritance between engine client and render client Signed-off-by: Sage Ahrac --- vllm/engine/protocol.py | 9 ++----- .../openai/chat_completion/serving.py | 8 +++--- vllm/entrypoints/openai/completion/serving.py | 8 +++--- vllm/entrypoints/openai/engine/serving.py | 26 +++++++++---------- vllm/entrypoints/openai/responses/serving.py | 8 +++--- .../openai/speech_to_text/serving.py | 6 ++++- .../openai/speech_to_text/speech_to_text.py | 8 +++--- vllm/v1/engine/async_llm.py | 4 +-- 8 files changed, 42 insertions(+), 35 deletions(-) diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 490dbb47d42b..1a9275ffaeff 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -86,13 +86,8 @@ async def get_supported_tasks(self) -> tuple[SupportedTask, ...]: raise NotImplementedError -class EngineClient(RendererClient): - """Full engine client interface including inference operations. - - Extends :class:`RendererClient` with methods that require a running - inference engine: generation, encoding, profiling, cache management, - scheduler control, and weight transfer. 
- """ +class EngineClient(ABC): + """Engine client interface for inference operations.""" @abstractmethod def generate( diff --git a/vllm/entrypoints/openai/chat_completion/serving.py b/vllm/entrypoints/openai/chat_completion/serving.py index 9ea03681852d..488909d8165b 100644 --- a/vllm/entrypoints/openai/chat_completion/serving.py +++ b/vllm/entrypoints/openai/chat_completion/serving.py @@ -15,7 +15,7 @@ from openai_harmony import Message as OpenAIMessage from partial_json_parser.core.options import Allow -from vllm.engine.protocol import EngineClient +from vllm.engine.protocol import EngineClient, RendererClient from vllm.entrypoints.chat_utils import ( ChatTemplateContentFormatOption, ConversationMessage, @@ -88,6 +88,7 @@ class OpenAIServingChat(OpenAIServingInference): def __init__( self, + renderer_client: RendererClient, engine_client: EngineClient, models: OpenAIServingModels, response_role: str, @@ -109,6 +110,7 @@ def __init__( default_chat_template_kwargs: dict[str, Any] | None = None, ) -> None: super().__init__( + renderer_client=renderer_client, engine_client=engine_client, models=models, request_logger=request_logger, @@ -232,8 +234,8 @@ async def render_chat_request( # If the engine is dead, raise the engine's DEAD_ERROR. # This is required for the streaming case, where we return a # success status before we actually start generating text :). - if self.engine_client.errored: - raise self.engine_client.dead_error + if self.renderer_client.errored: + raise self.renderer_client.dead_error try: tokenizer = self.renderer.tokenizer diff --git a/vllm/entrypoints/openai/completion/serving.py b/vllm/entrypoints/openai/completion/serving.py index 935560dd3b27..a375f80096af 100644 --- a/vllm/entrypoints/openai/completion/serving.py +++ b/vllm/entrypoints/openai/completion/serving.py @@ -10,7 +10,7 @@ import jinja2 from fastapi import Request -from vllm.engine.protocol import EngineClient +from vllm.engine.protocol import EngineClient, RendererClient from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.completion.protocol import ( CompletionLogProbs, @@ -49,6 +49,7 @@ class OpenAIServingCompletion(OpenAIServingInference): def __init__( self, + renderer_client: RendererClient, engine_client: EngineClient, models: OpenAIServingModels, *, @@ -59,6 +60,7 @@ def __init__( log_error_stack: bool = False, ): super().__init__( + renderer_client=renderer_client, engine_client=engine_client, models=models, request_logger=request_logger, @@ -95,8 +97,8 @@ async def render_completion_request( # If the engine is dead, raise the engine's DEAD_ERROR. # This is required for the streaming case, where we return a # success status before we actually start generating text :). - if self.engine_client.errored: - raise self.engine_client.dead_error + if self.renderer_client.errored: + raise self.renderer_client.dead_error # Return error for unsupported features. 
if request.suffix is not None: diff --git a/vllm/entrypoints/openai/engine/serving.py b/vllm/entrypoints/openai/engine/serving.py index b6a9fe7ad866..0bd8c20dc14d 100644 --- a/vllm/entrypoints/openai/engine/serving.py +++ b/vllm/entrypoints/openai/engine/serving.py @@ -229,7 +229,7 @@ class OpenAIServing: def __init__( self, - engine_client: RendererClient, + renderer_client: RendererClient, models: OpenAIServingModels, *, request_logger: RequestLogger | None, @@ -238,7 +238,7 @@ def __init__( ): super().__init__() - self.engine_client = engine_client + self.renderer_client = renderer_client self.models = models @@ -247,10 +247,10 @@ def __init__( self.log_error_stack = log_error_stack - self.model_config = engine_client.model_config - self.renderer = engine_client.renderer - self.io_processor = engine_client.io_processor - self.input_processor = engine_client.input_processor + self.model_config = renderer_client.model_config + self.renderer = renderer_client.renderer + self.io_processor = renderer_client.io_processor + self.input_processor = renderer_client.input_processor def create_error_response( self, @@ -710,7 +710,7 @@ async def _get_trace_headers( self, headers: Headers, ) -> Mapping[str, str] | None: - is_tracing_enabled = await self.engine_client.is_tracing_enabled() + is_tracing_enabled = await self.renderer_client.is_tracing_enabled() if is_tracing_enabled: return extract_trace_headers(headers) @@ -853,15 +853,14 @@ def _is_model_supported(self, model_name: str | None) -> bool: class OpenAIServingInference(OpenAIServing): """OpenAIServing subclass for endpoints that require inference. - Extends :class:`OpenAIServing` by narrowing ``engine_client`` from - :class:`RendererClient` to :class:`EngineClient` and adding methods - that call ``generate()`` / ``encode()`` on the engine. + Extends :class:`OpenAIServing` with a separate ``engine_client`` + (:class:`EngineClient`) for methods that call ``generate()`` / + ``encode()`` on the engine. 
""" - engine_client: EngineClient - def __init__( self, + renderer_client: RendererClient, engine_client: EngineClient, models: OpenAIServingModels, *, @@ -870,12 +869,13 @@ def __init__( log_error_stack: bool = False, ): super().__init__( - engine_client=engine_client, + renderer_client=renderer_client, models=models, request_logger=request_logger, return_tokens_as_token_ids=return_tokens_as_token_ids, log_error_stack=log_error_stack, ) + self.engine_client = engine_client async def _check_model( self, diff --git a/vllm/entrypoints/openai/responses/serving.py b/vllm/entrypoints/openai/responses/serving.py index 30e2851311b9..86049faa4fb8 100644 --- a/vllm/entrypoints/openai/responses/serving.py +++ b/vllm/entrypoints/openai/responses/serving.py @@ -40,7 +40,7 @@ from vllm import envs from vllm.config.utils import replace -from vllm.engine.protocol import EngineClient +from vllm.engine.protocol import EngineClient, RendererClient from vllm.entrypoints.chat_utils import ( ChatCompletionMessageParam, ChatTemplateContentFormatOption, @@ -156,6 +156,7 @@ def _extract_allowed_tools_from_mcp_requests( class OpenAIServingResponses(OpenAIServingInference): def __init__( self, + renderer_client: RendererClient, engine_client: EngineClient, models: OpenAIServingModels, *, @@ -173,6 +174,7 @@ def __init__( log_error_stack: bool = False, ) -> None: super().__init__( + renderer_client=renderer_client, engine_client=engine_client, models=models, request_logger=request_logger, @@ -339,8 +341,8 @@ async def create_responses( # If the engine is dead, raise the engine's DEAD_ERROR. # This is required for the streaming case, where we return a # success status before we actually start generating text :). - if self.engine_client.errored: - raise self.engine_client.dead_error + if self.renderer_client.errored: + raise self.renderer_client.dead_error if request.store and not self.enable_store: # Disable the store option. 
diff --git a/vllm/entrypoints/openai/speech_to_text/serving.py b/vllm/entrypoints/openai/speech_to_text/serving.py index b5ce17d0ef79..6ca135d720a2 100644 --- a/vllm/entrypoints/openai/speech_to_text/serving.py +++ b/vllm/entrypoints/openai/speech_to_text/serving.py @@ -4,7 +4,7 @@ from fastapi import Request -from vllm.engine.protocol import EngineClient +from vllm.engine.protocol import EngineClient, RendererClient from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.engine.protocol import ( ErrorResponse, @@ -35,6 +35,7 @@ class OpenAIServingTranscription(OpenAISpeechToText): def __init__( self, + renderer_client: RendererClient, engine_client: EngineClient, models: OpenAIServingModels, *, @@ -44,6 +45,7 @@ def __init__( enable_force_include_usage: bool = False, ): super().__init__( + renderer_client=renderer_client, engine_client=engine_client, models=models, request_logger=request_logger, @@ -108,6 +110,7 @@ class OpenAIServingTranslation(OpenAISpeechToText): def __init__( self, + renderer_client: RendererClient, engine_client: EngineClient, models: OpenAIServingModels, *, @@ -117,6 +120,7 @@ def __init__( enable_force_include_usage: bool = False, ): super().__init__( + renderer_client=renderer_client, engine_client=engine_client, models=models, request_logger=request_logger, diff --git a/vllm/entrypoints/openai/speech_to_text/speech_to_text.py b/vllm/entrypoints/openai/speech_to_text/speech_to_text.py index 4e3d22af13aa..ea0a60348893 100644 --- a/vllm/entrypoints/openai/speech_to_text/speech_to_text.py +++ b/vllm/entrypoints/openai/speech_to_text/speech_to_text.py @@ -14,7 +14,7 @@ from transformers import PreTrainedTokenizerBase import vllm.envs as envs -from vllm.engine.protocol import EngineClient +from vllm.engine.protocol import EngineClient, RendererClient from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.engine.protocol import ( DeltaMessage, @@ -85,6 +85,7 @@ class OpenAISpeechToText(OpenAIServingInference): def __init__( self, + renderer_client: RendererClient, engine_client: EngineClient, models: OpenAIServingModels, *, @@ -95,6 +96,7 @@ def __init__( enable_force_include_usage: bool = False, ): super().__init__( + renderer_client=renderer_client, engine_client=engine_client, models=models, request_logger=request_logger, @@ -478,8 +480,8 @@ async def _create_speech_to_text( # If the engine is dead, raise the engine's DEAD_ERROR. # This is required for the streaming case, where we return a # success status before we actually start generating text :). 
- if self.engine_client.errored: - raise self.engine_client.dead_error + if self.renderer_client.errored: + raise self.renderer_client.dead_error if request.response_format not in ["text", "json", "verbose_json"]: return self.create_error_response( diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 20da4c3b1a4a..a4793fe1723c 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -19,7 +19,7 @@ WeightTransferUpdateRequest, ) from vllm.engine.arg_utils import AsyncEngineArgs -from vllm.engine.protocol import EngineClient, StreamingInput +from vllm.engine.protocol import EngineClient, RendererClient, StreamingInput from vllm.inputs import ProcessorInputs, PromptType from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -67,7 +67,7 @@ def __init__(self, cause: Exception): super().__init__(str(cause)) -class AsyncLLM(EngineClient): +class AsyncLLM(RendererClient, EngineClient): """An asynchronous wrapper for the vLLM engine.""" def __init__( From e331104a1aae8d2a0faf5d4eafa52143cbd745af Mon Sep 17 00:00:00 2001 From: Sage Ahrac Date: Wed, 25 Feb 2026 12:46:15 +0200 Subject: [PATCH 05/53] remove OpenAIServingInference Signed-off-by: Sage Ahrac --- .../openai/chat_completion/serving.py | 4 +- vllm/entrypoints/openai/completion/serving.py | 4 +- vllm/entrypoints/openai/engine/serving.py | 1706 ++++++++--------- vllm/entrypoints/openai/models/serving.py | 12 +- vllm/entrypoints/openai/realtime/serving.py | 8 +- vllm/entrypoints/openai/responses/serving.py | 4 +- .../openai/speech_to_text/speech_to_text.py | 7 +- vllm/entrypoints/pooling/classify/serving.py | 8 +- vllm/entrypoints/pooling/embed/serving.py | 8 +- vllm/entrypoints/pooling/pooling/serving.py | 8 +- vllm/entrypoints/pooling/score/serving.py | 8 +- vllm/entrypoints/serve/disagg/serving.py | 15 +- 12 files changed, 879 insertions(+), 913 deletions(-) diff --git a/vllm/entrypoints/openai/chat_completion/serving.py b/vllm/entrypoints/openai/chat_completion/serving.py index 488909d8165b..cdb1ba32a2be 100644 --- a/vllm/entrypoints/openai/chat_completion/serving.py +++ b/vllm/entrypoints/openai/chat_completion/serving.py @@ -52,7 +52,7 @@ ) from vllm.entrypoints.openai.engine.serving import ( GenerationError, - OpenAIServingInference, + OpenAIServing, clamp_prompt_logprobs, ) from vllm.entrypoints.openai.models.serving import OpenAIServingModels @@ -85,7 +85,7 @@ logger = init_logger(__name__) -class OpenAIServingChat(OpenAIServingInference): +class OpenAIServingChat(OpenAIServing): def __init__( self, renderer_client: RendererClient, diff --git a/vllm/entrypoints/openai/completion/serving.py b/vllm/entrypoints/openai/completion/serving.py index a375f80096af..67c9a9a58ecf 100644 --- a/vllm/entrypoints/openai/completion/serving.py +++ b/vllm/entrypoints/openai/completion/serving.py @@ -28,7 +28,7 @@ ) from vllm.entrypoints.openai.engine.serving import ( GenerationError, - OpenAIServingInference, + OpenAIServing, clamp_prompt_logprobs, ) from vllm.entrypoints.openai.models.serving import OpenAIServingModels @@ -46,7 +46,7 @@ logger = init_logger(__name__) -class OpenAIServingCompletion(OpenAIServingInference): +class OpenAIServingCompletion(OpenAIServing): def __init__( self, renderer_client: RendererClient, diff --git a/vllm/entrypoints/openai/engine/serving.py b/vllm/entrypoints/openai/engine/serving.py index 0bd8c20dc14d..4c147ab4ef31 100644 --- a/vllm/entrypoints/openai/engine/serving.py +++ b/vllm/entrypoints/openai/engine/serving.py @@ -230,6 +230,7 @@ class 
OpenAIServing: def __init__( self, renderer_client: RendererClient, + engine_client: EngineClient, models: OpenAIServingModels, *, request_logger: RequestLogger | None, @@ -239,6 +240,7 @@ def __init__( super().__init__() self.renderer_client = renderer_client + self.engine_client = engine_client self.models = models @@ -252,646 +254,473 @@ def __init__( self.io_processor = renderer_client.io_processor self.input_processor = renderer_client.input_processor - def create_error_response( + async def beam_search( self, - message: str | Exception, - err_type: str = "BadRequestError", - status_code: HTTPStatus = HTTPStatus.BAD_REQUEST, - param: str | None = None, - ) -> ErrorResponse: - exc: Exception | None = None - - if isinstance(message, Exception): - exc = message - - from vllm.exceptions import VLLMValidationError + prompt: ProcessorInputs, + request_id: str, + params: BeamSearchParams, + lora_request: LoRARequest | None = None, + trace_headers: Mapping[str, str] | None = None, + ) -> AsyncGenerator[RequestOutput, None]: + beam_width = params.beam_width + max_tokens = params.max_tokens + ignore_eos = params.ignore_eos + temperature = params.temperature + length_penalty = params.length_penalty + include_stop_str_in_output = params.include_stop_str_in_output - if isinstance(exc, VLLMValidationError): - err_type = "BadRequestError" - status_code = HTTPStatus.BAD_REQUEST - param = exc.parameter - elif isinstance(exc, (ValueError, TypeError, RuntimeError, OverflowError)): - # Common validation errors from user input - err_type = "BadRequestError" - status_code = HTTPStatus.BAD_REQUEST - param = None - elif isinstance(exc, NotImplementedError): - err_type = "NotImplementedError" - status_code = HTTPStatus.NOT_IMPLEMENTED - param = None - elif exc.__class__.__name__ == "TemplateError": - # jinja2.TemplateError (avoid importing jinja2) - err_type = "BadRequestError" - status_code = HTTPStatus.BAD_REQUEST - param = None - else: - err_type = "InternalServerError" - status_code = HTTPStatus.INTERNAL_SERVER_ERROR - param = None + tokenizer = self.renderer.get_tokenizer() + eos_token_id = tokenizer.eos_token_id + sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty) - message = str(exc) + if prompt["type"] == "embeds": + raise NotImplementedError("Embedding prompt not supported for beam search") + if prompt["type"] == "enc_dec": + raise NotImplementedError( + "Encoder-decoder prompt not supported for beam search" + ) - if self.log_error_stack: - exc_type, _, _ = sys.exc_info() - if exc_type is not None: - traceback.print_exc() - else: - traceback.print_stack() + prompt_text = prompt.get("prompt") + prompt_token_ids = prompt["prompt_token_ids"] + tokenized_length = len(prompt_token_ids) - return ErrorResponse( - error=ErrorInfo( - message=sanitize_message(message), - type=err_type, - code=status_code.value, - param=param, - ) + logprobs_num = 2 * beam_width + sampling_params = SamplingParams( + logprobs=logprobs_num, + max_tokens=1, + temperature=temperature, ) + all_beams = [ + BeamSearchSequence( + orig_prompt=prompt, + tokens=prompt_token_ids, + cum_logprob=0, + logprobs=[], + lora_request=lora_request, + ) + ] + completed = [] - def create_streaming_error_response( - self, - message: str | Exception, - err_type: str = "BadRequestError", - status_code: HTTPStatus = HTTPStatus.BAD_REQUEST, - param: str | None = None, - ) -> str: - json_str = json.dumps( - self.create_error_response( - message=message, - err_type=err_type, - status_code=status_code, - param=param, - 
).model_dump() - ) - return json_str + for _ in range(max_tokens): + tasks = [] + request_id_batch = f"{request_id}-{random_uuid()}" - def _raise_if_error(self, finish_reason: str | None, request_id: str) -> None: - """Raise GenerationError if finish_reason indicates an error.""" - if finish_reason == "error": - logger.error( - "Request %s failed with an internal error during generation", - request_id, - ) - raise GenerationError("Internal server error") + for i, beam in enumerate(all_beams): + prompt_item = beam.get_prompt() + lora_request_item = beam.lora_request + request_id_item = f"{request_id_batch}-beam-{i}" + task = asyncio.create_task( + collect_from_async_generator( + self.engine_client.generate( + prompt_item, + sampling_params, + request_id_item, + lora_request=lora_request_item, + trace_headers=trace_headers, + ) + ) + ) + tasks.append(task) - def _convert_generation_error_to_response( - self, e: GenerationError - ) -> ErrorResponse: - """Convert GenerationError to ErrorResponse.""" - return self.create_error_response( - str(e), - err_type="InternalServerError", - status_code=e.status_code, - ) + output = [x[0] for x in await asyncio.gather(*tasks)] - def _convert_generation_error_to_streaming_response( - self, e: GenerationError - ) -> str: - """Convert GenerationError to streaming error response.""" - return self.create_streaming_error_response( - str(e), - err_type="InternalServerError", - status_code=e.status_code, - ) + new_beams = [] + # Store all new tokens generated by beam + all_beams_token_id = [] + # Store the cumulative probability of all tokens + # generated by beam search + all_beams_logprob = [] + # Iterate through all beam inference results + for i, result in enumerate(output): + current_beam = all_beams[i] - async def _check_model( - self, - request: AnyRequest, - ) -> ErrorResponse | None: - if self._is_model_supported(request.model): - return None - if request.model in self.models.lora_requests: - return None + # check for error finish reason and abort beam search + if result.outputs[0].finish_reason == "error": + # yield error output and terminate beam search + yield RequestOutput( + request_id=request_id, + prompt=prompt_text, + outputs=[ + CompletionOutput( + index=0, + text="", + token_ids=[], + cumulative_logprob=None, + logprobs=None, + finish_reason="error", + ) + ], + finished=True, + prompt_token_ids=prompt_token_ids, + prompt_logprobs=None, + ) + return - return self.create_error_response( - message=f"The model `{request.model}` does not exist.", - err_type="NotFoundError", - status_code=HTTPStatus.NOT_FOUND, - param="model", - ) + if result.outputs[0].logprobs is not None: + logprobs = result.outputs[0].logprobs[0] + all_beams_token_id.extend(list(logprobs.keys())) + all_beams_logprob.extend( + [ + current_beam.cum_logprob + obj.logprob + for obj in logprobs.values() + ] + ) - def _get_active_default_mm_loras(self, request: AnyRequest) -> LoRARequest | None: - """Determine if there are any active default multimodal loras.""" - # TODO: Currently this is only enabled for chat completions - # to be better aligned with only being enabled for .generate - # when run offline. It would be nice to support additional - # tasks types in the future. 
- message_types = self._get_message_types(request) - default_mm_loras = set() + # Handle the token for the end of sentence (EOS) + all_beams_token_id = np.array(all_beams_token_id) + all_beams_logprob = np.array(all_beams_logprob) - for lora in self.models.lora_requests.values(): - # Best effort match for default multimodal lora adapters; - # There is probably a better way to do this, but currently - # this matches against the set of 'types' in any content lists - # up until '_', e.g., to match audio_url -> audio - if lora.lora_name in message_types: - default_mm_loras.add(lora) + if not ignore_eos: + # Get the index position of eos token in all generated results + eos_idx = np.where(all_beams_token_id == eos_token_id)[0] + for idx in eos_idx: + current_beam = all_beams[idx // logprobs_num] + result = output[idx // logprobs_num] + assert result.outputs[0].logprobs is not None + logprobs_entry = result.outputs[0].logprobs[0] + completed.append( + BeamSearchSequence( + orig_prompt=prompt, + tokens=current_beam.tokens + [eos_token_id] + if include_stop_str_in_output + else current_beam.tokens, + logprobs=current_beam.logprobs + [logprobs_entry], + cum_logprob=float(all_beams_logprob[idx]), + finish_reason="stop", + stop_reason=eos_token_id, + ) + ) + # After processing, set the log probability of the eos condition + # to negative infinity. + all_beams_logprob[eos_idx] = -np.inf - # Currently only support default modality specific loras if - # we have exactly one lora matched on the request. - if len(default_mm_loras) == 1: - return default_mm_loras.pop() - return None + # Processing non-EOS tokens + # Get indices of the top beam_width probabilities + topn_idx = np.argpartition(np.negative(all_beams_logprob), beam_width)[ + :beam_width + ] - def _maybe_get_adapters( - self, - request: AnyRequest, - supports_default_mm_loras: bool = False, - ) -> LoRARequest | None: - if request.model in self.models.lora_requests: - return self.models.lora_requests[request.model] + for idx in topn_idx: + current_beam = all_beams[idx // logprobs_num] + result = output[idx // logprobs_num] + token_id = int(all_beams_token_id[idx]) + assert result.outputs[0].logprobs is not None + logprobs_entry = result.outputs[0].logprobs[0] + new_beams.append( + BeamSearchSequence( + orig_prompt=prompt, + tokens=current_beam.tokens + [token_id], + logprobs=current_beam.logprobs + [logprobs_entry], + lora_request=current_beam.lora_request, + cum_logprob=float(all_beams_logprob[idx]), + ) + ) - # Currently only support default modality specific loras - # if we have exactly one lora matched on the request. - if supports_default_mm_loras: - default_mm_lora = self._get_active_default_mm_loras(request) - if default_mm_lora is not None: - return default_mm_lora + all_beams = new_beams - if self._is_model_supported(request.model): - return None + completed.extend(all_beams) + sorted_completed = sorted(completed, key=sort_beams_key, reverse=True) + best_beams = sorted_completed[:beam_width] - # if _check_model has been called earlier, this will be unreachable - raise ValueError(f"The model `{request.model}` does not exist.") + for beam in best_beams: + if beam.tokens[-1] == eos_token_id and not ignore_eos: + # Skip the eos token in the text. 
+ tokens = beam.tokens[tokenized_length:-1] + else: + tokens = beam.tokens[tokenized_length:] + beam.text = tokenizer.decode(tokens) - def _get_message_types(self, request: AnyRequest) -> set[str]: - """Retrieve the set of types from message content dicts up - until `_`; we use this to match potential multimodal data - with default per modality loras. + yield RequestOutput( + request_id=request_id, + prompt=prompt_text, + outputs=[ + CompletionOutput( + text=beam.text, # type: ignore + cumulative_logprob=beam.cum_logprob, + token_ids=beam.tokens[tokenized_length:], + index=i, + logprobs=beam.logprobs, + finish_reason=beam.finish_reason + if beam.finish_reason is not None + else "length", + stop_reason=beam.stop_reason, + ) + for (i, beam) in enumerate(best_beams) + ], + finished=True, + prompt_token_ids=prompt_token_ids, + prompt_logprobs=None, + ) + + async def _preprocess( + self, + ctx: ServeContext, + ) -> ErrorResponse | None: """ - message_types: set[str] = set() + Default preprocessing hook. Subclasses may override + to prepare `ctx` (classification, embedding, etc.). + """ + return None - if not hasattr(request, "messages"): - return message_types + def _build_response( + self, + ctx: ServeContext, + ) -> AnyResponse | ErrorResponse: + """ + Default response builder. Subclass may override this method + to return the appropriate response object. + """ + return self.create_error_response("unimplemented endpoint") - messages = request.messages - if messages is None or isinstance(messages, (str, bytes)): - return message_types + async def handle( + self, + ctx: ServeContext, + ) -> AnyResponse | ErrorResponse: + async for response in self._pipeline(ctx): + return response - for message in messages: - if ( - isinstance(message, dict) - and "content" in message - and isinstance(message["content"], list) - ): - for content_dict in message["content"]: - if "type" in content_dict: - message_types.add(content_dict["type"].split("_")[0]) - return message_types + return self.create_error_response("No response yielded from pipeline") - def _validate_input( + async def _pipeline( self, - request: object, - input_ids: list[int], - input_text: str, - ) -> TokensPrompt: - token_num = len(input_ids) - max_model_len = self.model_config.max_model_len - - # Note: EmbeddingRequest, ClassificationRequest, - # and ScoreRequest doesn't have max_tokens - if isinstance( - request, - ( - EmbeddingChatRequest, - EmbeddingCompletionRequest, - ScoreDataRequest, - ScoreTextRequest, - ScoreQueriesDocumentsRequest, - RerankRequest, - ClassificationCompletionRequest, - ClassificationChatRequest, - ), - ): - # Note: input length can be up to the entire model context length - # since these requests don't generate tokens. - if token_num > max_model_len: - operations: dict[type[AnyRequest], str] = { - ScoreDataRequest: "score", - ScoreTextRequest: "score", - ScoreQueriesDocumentsRequest: "score", - ClassificationCompletionRequest: "classification", - ClassificationChatRequest: "classification", - } - operation = operations.get(type(request), "embedding generation") - raise VLLMValidationError( - f"This model's maximum context length is " - f"{max_model_len} tokens. However, you requested " - f"{token_num} tokens in the input for {operation}. 
" - f"Please reduce the length of the input.", - parameter="input_tokens", - value=token_num, - ) - return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids) + ctx: ServeContext, + ) -> AsyncGenerator[AnyResponse | ErrorResponse, None]: + """Execute the request processing pipeline yielding responses.""" + if error := await self._check_model(ctx.request): + yield error + if error := self._validate_request(ctx): + yield error - # Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens - # and does not require model context length validation - if isinstance( - request, - (TokenizeCompletionRequest, TokenizeChatRequest, DetokenizeRequest), - ): - return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids) + preprocess_ret = await self._preprocess(ctx) + if isinstance(preprocess_ret, ErrorResponse): + yield preprocess_ret - # chat completion endpoint supports max_completion_tokens - if isinstance(request, ChatCompletionRequest): - # TODO(#9845): remove max_tokens when field dropped from OpenAI API - max_tokens = request.max_completion_tokens or request.max_tokens - else: - max_tokens = getattr(request, "max_tokens", None) + generators_ret = await self._prepare_generators(ctx) + if isinstance(generators_ret, ErrorResponse): + yield generators_ret - # Note: input length can be up to model context length - 1 for - # completion-like requests. - if token_num >= max_model_len: - raise VLLMValidationError( - f"This model's maximum context length is " - f"{max_model_len} tokens. However, your request has " - f"{token_num} input tokens. Please reduce the length of " - "the input messages.", - parameter="input_tokens", - value=token_num, - ) + collect_ret = await self._collect_batch(ctx) + if isinstance(collect_ret, ErrorResponse): + yield collect_ret - if max_tokens is not None and token_num + max_tokens > max_model_len: - raise VLLMValidationError( - "'max_tokens' or 'max_completion_tokens' is too large: " - f"{max_tokens}. This model's maximum context length is " - f"{max_model_len} tokens and your request has " - f"{token_num} input tokens ({max_tokens} > {max_model_len}" - f" - {token_num}).", - parameter="max_tokens", - value=max_tokens, - ) + yield self._build_response(ctx) - return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids) + def _validate_request(self, ctx: ServeContext) -> ErrorResponse | None: + truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens", None) - def _validate_chat_template( - self, - request_chat_template: str | None, - chat_template_kwargs: dict[str, Any] | None, - trust_request_chat_template: bool, - ) -> ErrorResponse | None: - if not trust_request_chat_template and ( - request_chat_template is not None - or ( - chat_template_kwargs - and chat_template_kwargs.get("chat_template") is not None - ) + if ( + truncate_prompt_tokens is not None + and truncate_prompt_tokens > self.model_config.max_model_len ): return self.create_error_response( - "Chat template is passed with request, but " - "--trust-request-chat-template is not set. " - "Refused request with untrusted chat template." + "truncate_prompt_tokens value is " + "greater than max_model_len." + " Please, select a smaller truncation size." 
) return None - @staticmethod - def _prepare_extra_chat_template_kwargs( - request_chat_template_kwargs: dict[str, Any] | None = None, - default_chat_template_kwargs: dict[str, Any] | None = None, - ) -> dict[str, Any]: - """Helper to merge server-default and request-specific chat template kwargs.""" - request_chat_template_kwargs = request_chat_template_kwargs or {} - if default_chat_template_kwargs is None: - return request_chat_template_kwargs - # Apply server defaults first, then request kwargs override. - return default_chat_template_kwargs | request_chat_template_kwargs + def _create_pooling_params( + self, + ctx: ServeContext, + ) -> PoolingParams | ErrorResponse: + if not hasattr(ctx.request, "to_pooling_params"): + return self.create_error_response( + "Request type does not support pooling parameters" + ) - async def _preprocess_completion( + return ctx.request.to_pooling_params() + + async def _prepare_generators( self, - request: RendererRequest, - prompt_input: str | list[str] | list[int] | list[list[int]] | None, - prompt_embeds: bytes | list[bytes] | None, - ) -> list[ProcessorInputs]: - prompts = list[SingletonPrompt | bytes]() - if prompt_embeds is not None: # embeds take higher priority - prompts.extend(prompt_to_seq(prompt_embeds)) - if prompt_input is not None: - prompts.extend(prompt_to_seq(prompt_input)) - - return await self._preprocess_cmpl(request, prompts) - - async def _preprocess_cmpl( - self, - request: RendererRequest, - prompts: Sequence[PromptType | bytes], - ) -> list[ProcessorInputs]: - renderer = self.renderer - model_config = self.model_config + ctx: ServeContext, + ) -> ErrorResponse | None: + """Schedule the request and get the result generator.""" + generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] - parsed_prompts = [ - ( - prompt - if isinstance(prompt, bytes) - else parse_model_prompt(model_config, prompt) + try: + trace_headers = ( + None + if ctx.raw_request is None + else await self._get_trace_headers(ctx.raw_request.headers) ) - for prompt in prompts - ] - tok_params = request.build_tok_params(model_config) - - return await renderer.render_cmpl_async( - parsed_prompts, - tok_params, - prompt_extras={ - k: v - for k in ("mm_processor_kwargs", "cache_salt") - if (v := getattr(request, k, None)) is not None - }, - ) - - async def _preprocess_chat( - self, - request: RendererChatRequest, - messages: list[ChatCompletionMessageParam], - default_template: str | None, - default_template_content_format: ChatTemplateContentFormatOption, - default_template_kwargs: dict[str, Any] | None, - tool_dicts: list[dict[str, Any]] | None = None, - tool_parser: Callable[[TokenizerLike], ToolParser] | None = None, - ) -> tuple[list[ConversationMessage], list[ProcessorInputs]]: - renderer = self.renderer - default_template_kwargs = merge_kwargs( - default_template_kwargs, - dict( - tools=tool_dicts, - tokenize=is_mistral_tokenizer(renderer.tokenizer), - ), - ) + pooling_params = self._create_pooling_params(ctx) + if isinstance(pooling_params, ErrorResponse): + return pooling_params - tok_params = request.build_tok_params(self.model_config) - chat_params = request.build_chat_params( - default_template, default_template_content_format - ).with_defaults(default_template_kwargs) + if ctx.engine_prompts is None: + return self.create_error_response("Engine prompts not available") - (conversation,), (engine_prompt,) = await renderer.render_chat_async( - [messages], - chat_params, - tok_params, - prompt_extras={ - k: v - for k in ("mm_processor_kwargs", 
"cache_salt") - if (v := getattr(request, k, None)) is not None - }, - ) + for i, engine_prompt in enumerate(ctx.engine_prompts): + request_id_item = f"{ctx.request_id}-{i}" - # tool parsing is done only if a tool_parser has been set and if - # tool_choice is not "none" (if tool_choice is "none" but a tool_parser - # is set, we want to prevent parsing a tool_call hallucinated by the LLM - if tool_parser is not None: - tool_choice = getattr(request, "tool_choice", "none") - if tool_choice != "none": - if not isinstance(request, ChatCompletionRequest | ResponsesRequest): - msg = ( - "Tool usage is only supported for Chat Completions API " - "or Responses API requests." - ) - raise NotImplementedError(msg) + self._log_inputs( + request_id_item, + engine_prompt, + params=pooling_params, + lora_request=ctx.lora_request, + ) - # TODO: Update adjust_request to accept ResponsesRequest - tokenizer = renderer.get_tokenizer() - request = tool_parser(tokenizer).adjust_request(request=request) # type: ignore[arg-type] + generator = self.engine_client.encode( + engine_prompt, + pooling_params, + request_id_item, + lora_request=ctx.lora_request, + trace_headers=trace_headers, + priority=getattr(ctx.request, "priority", 0), + ) - return conversation, [engine_prompt] + generators.append(generator) - def _extract_prompt_components(self, prompt: PromptType | ProcessorInputs): - return extract_prompt_components(self.model_config, prompt) + ctx.result_generator = merge_async_iterators(*generators) - def _extract_prompt_text(self, prompt: ProcessorInputs): - return self._extract_prompt_components(prompt).text + return None - def _extract_prompt_len(self, prompt: ProcessorInputs): - return extract_prompt_len(self.model_config, prompt) + except Exception as e: + return self.create_error_response(e) - async def _render_next_turn( + async def _collect_batch( self, - request: ResponsesRequest, - messages: list[ResponseInputOutputItem], - tool_dicts: list[dict[str, Any]] | None, - tool_parser: Callable[[TokenizerLike], ToolParser] | None, - chat_template: str | None, - chat_template_content_format: ChatTemplateContentFormatOption, - ): - new_messages = construct_input_messages( - request_input=messages, - ) + ctx: ServeContext, + ) -> ErrorResponse | None: + """Collect batch results from the result generator.""" + try: + if ctx.engine_prompts is None: + return self.create_error_response("Engine prompts not available") - _, engine_prompts = await self._preprocess_chat( - request, - new_messages, - default_template=chat_template, - default_template_content_format=chat_template_content_format, - default_template_kwargs=None, - tool_dicts=tool_dicts, - tool_parser=tool_parser, - ) - return engine_prompts + num_prompts = len(ctx.engine_prompts) + final_res_batch: list[PoolingRequestOutput | None] + final_res_batch = [None] * num_prompts - def _log_inputs( - self, - request_id: str, - inputs: PromptType | ProcessorInputs, - params: SamplingParams | PoolingParams | BeamSearchParams | None, - lora_request: LoRARequest | None, - ) -> None: - if self.request_logger is None: - return + if ctx.result_generator is None: + return self.create_error_response("Result generator not available") - components = self._extract_prompt_components(inputs) + async for i, res in ctx.result_generator: + final_res_batch[i] = res - self.request_logger.log_inputs( - request_id, - components.text, - components.token_ids, - components.embeds, - params=params, - lora_request=lora_request, - ) + if None in final_res_batch: + return 
self.create_error_response( + "Failed to generate results for all prompts" + ) - async def _get_trace_headers( - self, - headers: Headers, - ) -> Mapping[str, str] | None: - is_tracing_enabled = await self.renderer_client.is_tracing_enabled() + ctx.final_res_batch = [res for res in final_res_batch if res is not None] - if is_tracing_enabled: - return extract_trace_headers(headers) + return None - if contains_trace_headers(headers): - log_tracing_disabled_warning() + except Exception as e: + return self.create_error_response(e) - return None + def create_error_response( + self, + message: str | Exception, + err_type: str = "BadRequestError", + status_code: HTTPStatus = HTTPStatus.BAD_REQUEST, + param: str | None = None, + ) -> ErrorResponse: + exc: Exception | None = None - @staticmethod - def _base_request_id( - raw_request: Request | None, default: str | None = None - ) -> str | None: - """Pulls the request id to use from a header, if provided""" - if raw_request is not None and ( - (req_id := raw_request.headers.get("X-Request-Id")) is not None - ): - return req_id + if isinstance(message, Exception): + exc = message - return random_uuid() if default is None else default + from vllm.exceptions import VLLMValidationError - @staticmethod - def _get_data_parallel_rank(raw_request: Request | None) -> int | None: - """Pulls the data parallel rank from a header, if provided""" - if raw_request is None: - return None + if isinstance(exc, VLLMValidationError): + err_type = "BadRequestError" + status_code = HTTPStatus.BAD_REQUEST + param = exc.parameter + elif isinstance(exc, (ValueError, TypeError, RuntimeError, OverflowError)): + # Common validation errors from user input + err_type = "BadRequestError" + status_code = HTTPStatus.BAD_REQUEST + param = None + elif isinstance(exc, NotImplementedError): + err_type = "NotImplementedError" + status_code = HTTPStatus.NOT_IMPLEMENTED + param = None + elif exc.__class__.__name__ == "TemplateError": + # jinja2.TemplateError (avoid importing jinja2) + err_type = "BadRequestError" + status_code = HTTPStatus.BAD_REQUEST + param = None + else: + err_type = "InternalServerError" + status_code = HTTPStatus.INTERNAL_SERVER_ERROR + param = None - rank_str = raw_request.headers.get("X-data-parallel-rank") - if rank_str is None: - return None + message = str(exc) - try: - return int(rank_str) - except ValueError: - return None + if self.log_error_stack: + exc_type, _, _ = sys.exc_info() + if exc_type is not None: + traceback.print_exc() + else: + traceback.print_stack() - @staticmethod - def _parse_tool_calls_from_content( - request: ResponsesRequest | ChatCompletionRequest, - tokenizer: TokenizerLike | None, - enable_auto_tools: bool, - tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None, - content: str | None = None, - ) -> tuple[list[FunctionCall] | None, str | None]: - function_calls = list[FunctionCall]() - if request.tool_choice and isinstance(request.tool_choice, ToolChoiceFunction): - assert content is not None - # Forced Function Call - function_calls.append( - FunctionCall(name=request.tool_choice.name, arguments=content) - ) - content = None # Clear content since tool is called. - elif request.tool_choice and isinstance( - request.tool_choice, ChatCompletionNamedToolChoiceParam - ): - assert content is not None - # Forced Function Call - function_calls.append( - FunctionCall(name=request.tool_choice.function.name, arguments=content) - ) - content = None # Clear content since tool is called. 
- elif request.tool_choice == "required": - assert content is not None - tool_calls = TypeAdapter(list[FunctionDefinition]).validate_json(content) - function_calls.extend( - [ - FunctionCall( - name=tool_call.name, - arguments=json.dumps(tool_call.parameters, ensure_ascii=False), - ) - for tool_call in tool_calls - ] - ) - content = None # Clear content since tool is called. - elif ( - tool_parser_cls - and enable_auto_tools - and (request.tool_choice == "auto" or request.tool_choice is None) - ): - if tokenizer is None: - raise ValueError( - "Tokenizer not available when `skip_tokenizer_init=True`" - ) - - # Automatic Tool Call Parsing - try: - tool_parser = tool_parser_cls(tokenizer) - except RuntimeError as e: - logger.exception("Error in tool parser creation.") - raise e - tool_call_info = tool_parser.extract_tool_calls( - content if content is not None else "", - request=request, # type: ignore + return ErrorResponse( + error=ErrorInfo( + message=sanitize_message(message), + type=err_type, + code=status_code.value, + param=param, ) - if tool_call_info is not None and tool_call_info.tools_called: - # extract_tool_calls() returns a list of tool calls. - function_calls.extend( - FunctionCall( - id=tool_call.id, - name=tool_call.function.name, - arguments=tool_call.function.arguments, - ) - for tool_call in tool_call_info.tool_calls - ) - content = tool_call_info.content - if content and content.strip() == "": - content = None - else: - # No tool calls. - return None, content - - return function_calls, content + ) - @staticmethod - def _get_decoded_token( - logprob: Logprob, - token_id: int, - tokenizer: TokenizerLike | None, - return_as_token_id: bool = False, + def create_streaming_error_response( + self, + message: str | Exception, + err_type: str = "BadRequestError", + status_code: HTTPStatus = HTTPStatus.BAD_REQUEST, + param: str | None = None, ) -> str: - if return_as_token_id: - return f"token_id:{token_id}" - - if logprob.decoded_token is not None: - return logprob.decoded_token + json_str = json.dumps( + self.create_error_response( + message=message, + err_type=err_type, + status_code=status_code, + param=param, + ).model_dump() + ) + return json_str - if tokenizer is None: - raise ValueError( - "Unable to get tokenizer because `skip_tokenizer_init=True`" + def _raise_if_error(self, finish_reason: str | None, request_id: str) -> None: + """Raise GenerationError if finish_reason indicates an error.""" + if finish_reason == "error": + logger.error( + "Request %s failed with an internal error during generation", + request_id, ) + raise GenerationError("Internal server error") - return tokenizer.decode([token_id]) - - def _is_model_supported(self, model_name: str | None) -> bool: - if not model_name: - return True - return self.models.is_base_model(model_name) - - -class OpenAIServingInference(OpenAIServing): - """OpenAIServing subclass for endpoints that require inference. - - Extends :class:`OpenAIServing` with a separate ``engine_client`` - (:class:`EngineClient`) for methods that call ``generate()`` / - ``encode()`` on the engine. 
- """ + def _convert_generation_error_to_response( + self, e: GenerationError + ) -> ErrorResponse: + """Convert GenerationError to ErrorResponse.""" + return self.create_error_response( + str(e), + err_type="InternalServerError", + status_code=e.status_code, + ) - def __init__( - self, - renderer_client: RendererClient, - engine_client: EngineClient, - models: OpenAIServingModels, - *, - request_logger: RequestLogger | None, - return_tokens_as_token_ids: bool = False, - log_error_stack: bool = False, - ): - super().__init__( - renderer_client=renderer_client, - models=models, - request_logger=request_logger, - return_tokens_as_token_ids=return_tokens_as_token_ids, - log_error_stack=log_error_stack, + def _convert_generation_error_to_streaming_response( + self, e: GenerationError + ) -> str: + """Convert GenerationError to streaming error response.""" + return self.create_streaming_error_response( + str(e), + err_type="InternalServerError", + status_code=e.status_code, ) - self.engine_client = engine_client async def _check_model( self, request: AnyRequest, ) -> ErrorResponse | None: - """Extend base model check with runtime LoRA resolution. + error_response = None - Runtime LoRA loading requires :meth:`EngineClient.add_lora`, - which is only available on the full inference engine. - """ - # Check known models/LoRAs first (base class) - error_response = await super()._check_model(request) - if error_response is None: + if self._is_model_supported(request.model): + return None + if request.model in self.models.lora_requests: return None - - # Attempt runtime LoRA resolution (needs inference engine) if ( envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING and request.model @@ -903,365 +732,330 @@ async def _check_model( isinstance(load_result, ErrorResponse) and load_result.error.code == HTTPStatus.BAD_REQUEST.value ): - return load_result + error_response = load_result - return error_response - - async def _preprocess( - self, - ctx: ServeContext, - ) -> ErrorResponse | None: - """ - Default preprocessing hook. Subclasses may override - to prepare `ctx` (classification, embedding, etc.). - """ - return None + return error_response or self.create_error_response( + message=f"The model `{request.model}` does not exist.", + err_type="NotFoundError", + status_code=HTTPStatus.NOT_FOUND, + param="model", + ) - def _build_response( - self, - ctx: ServeContext, - ) -> AnyResponse | ErrorResponse: - """ - Default response builder. Subclass may override this method - to return the appropriate response object. - """ - return self.create_error_response("unimplemented endpoint") + def _get_active_default_mm_loras(self, request: AnyRequest) -> LoRARequest | None: + """Determine if there are any active default multimodal loras.""" + # TODO: Currently this is only enabled for chat completions + # to be better aligned with only being enabled for .generate + # when run offline. It would be nice to support additional + # tasks types in the future. 
+ message_types = self._get_message_types(request) + default_mm_loras = set() - async def handle( - self, - ctx: ServeContext, - ) -> AnyResponse | ErrorResponse: - async for response in self._pipeline(ctx): - return response + for lora in self.models.lora_requests.values(): + # Best effort match for default multimodal lora adapters; + # There is probably a better way to do this, but currently + # this matches against the set of 'types' in any content lists + # up until '_', e.g., to match audio_url -> audio + if lora.lora_name in message_types: + default_mm_loras.add(lora) - return self.create_error_response("No response yielded from pipeline") + # Currently only support default modality specific loras if + # we have exactly one lora matched on the request. + if len(default_mm_loras) == 1: + return default_mm_loras.pop() + return None - async def _pipeline( + def _maybe_get_adapters( self, - ctx: ServeContext, - ) -> AsyncGenerator[AnyResponse | ErrorResponse, None]: - """Execute the request processing pipeline yielding responses.""" - if error := await self._check_model(ctx.request): - yield error - if error := self._validate_request(ctx): - yield error + request: AnyRequest, + supports_default_mm_loras: bool = False, + ) -> LoRARequest | None: + if request.model in self.models.lora_requests: + return self.models.lora_requests[request.model] - preprocess_ret = await self._preprocess(ctx) - if isinstance(preprocess_ret, ErrorResponse): - yield preprocess_ret + # Currently only support default modality specific loras + # if we have exactly one lora matched on the request. + if supports_default_mm_loras: + default_mm_lora = self._get_active_default_mm_loras(request) + if default_mm_lora is not None: + return default_mm_lora - generators_ret = await self._prepare_generators(ctx) - if isinstance(generators_ret, ErrorResponse): - yield generators_ret + if self._is_model_supported(request.model): + return None - collect_ret = await self._collect_batch(ctx) - if isinstance(collect_ret, ErrorResponse): - yield collect_ret + # if _check_model has been called earlier, this will be unreachable + raise ValueError(f"The model `{request.model}` does not exist.") - yield self._build_response(ctx) + def _get_message_types(self, request: AnyRequest) -> set[str]: + """Retrieve the set of types from message content dicts up + until `_`; we use this to match potential multimodal data + with default per modality loras. + """ + message_types: set[str] = set() - def _validate_request(self, ctx: ServeContext) -> ErrorResponse | None: - truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens", None) + if not hasattr(request, "messages"): + return message_types - if ( - truncate_prompt_tokens is not None - and truncate_prompt_tokens > self.model_config.max_model_len - ): - return self.create_error_response( - "truncate_prompt_tokens value is " - "greater than max_model_len." - " Please, select a smaller truncation size." 
- ) - return None - - def _create_pooling_params( - self, - ctx: ServeContext, - ) -> PoolingParams | ErrorResponse: - if not hasattr(ctx.request, "to_pooling_params"): - return self.create_error_response( - "Request type does not support pooling parameters" - ) + messages = request.messages + if messages is None or isinstance(messages, (str, bytes)): + return message_types - return ctx.request.to_pooling_params() + for message in messages: + if ( + isinstance(message, dict) + and "content" in message + and isinstance(message["content"], list) + ): + for content_dict in message["content"]: + if "type" in content_dict: + message_types.add(content_dict["type"].split("_")[0]) + return message_types - async def _collect_batch( + def _validate_input( self, - ctx: ServeContext, - ) -> ErrorResponse | None: - """Collect batch results from the result generator.""" - try: - if ctx.engine_prompts is None: - return self.create_error_response("Engine prompts not available") - - num_prompts = len(ctx.engine_prompts) - final_res_batch: list[PoolingRequestOutput | None] - final_res_batch = [None] * num_prompts - - if ctx.result_generator is None: - return self.create_error_response("Result generator not available") - - async for i, res in ctx.result_generator: - final_res_batch[i] = res + request: object, + input_ids: list[int], + input_text: str, + ) -> TokensPrompt: + token_num = len(input_ids) + max_model_len = self.model_config.max_model_len - if None in final_res_batch: - return self.create_error_response( - "Failed to generate results for all prompts" + # Note: EmbeddingRequest, ClassificationRequest, + # and ScoreRequest doesn't have max_tokens + if isinstance( + request, + ( + EmbeddingChatRequest, + EmbeddingCompletionRequest, + ScoreDataRequest, + ScoreTextRequest, + ScoreQueriesDocumentsRequest, + RerankRequest, + ClassificationCompletionRequest, + ClassificationChatRequest, + ), + ): + # Note: input length can be up to the entire model context length + # since these requests don't generate tokens. + if token_num > max_model_len: + operations: dict[type[AnyRequest], str] = { + ScoreDataRequest: "score", + ScoreTextRequest: "score", + ScoreQueriesDocumentsRequest: "score", + ClassificationCompletionRequest: "classification", + ClassificationChatRequest: "classification", + } + operation = operations.get(type(request), "embedding generation") + raise VLLMValidationError( + f"This model's maximum context length is " + f"{max_model_len} tokens. However, you requested " + f"{token_num} tokens in the input for {operation}. 
" + f"Please reduce the length of the input.", + parameter="input_tokens", + value=token_num, ) + return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids) - ctx.final_res_batch = [res for res in final_res_batch if res is not None] - - return None - - except Exception as e: - return self.create_error_response(e) - - async def beam_search( - self, - prompt: ProcessorInputs, - request_id: str, - params: BeamSearchParams, - lora_request: LoRARequest | None = None, - trace_headers: Mapping[str, str] | None = None, - ) -> AsyncGenerator[RequestOutput, None]: - beam_width = params.beam_width - max_tokens = params.max_tokens - ignore_eos = params.ignore_eos - temperature = params.temperature - length_penalty = params.length_penalty - include_stop_str_in_output = params.include_stop_str_in_output + # Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens + # and does not require model context length validation + if isinstance( + request, + (TokenizeCompletionRequest, TokenizeChatRequest, DetokenizeRequest), + ): + return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids) - tokenizer = self.renderer.get_tokenizer() - eos_token_id = tokenizer.eos_token_id - sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty) + # chat completion endpoint supports max_completion_tokens + if isinstance(request, ChatCompletionRequest): + # TODO(#9845): remove max_tokens when field dropped from OpenAI API + max_tokens = request.max_completion_tokens or request.max_tokens + else: + max_tokens = getattr(request, "max_tokens", None) - if prompt["type"] == "embeds": - raise NotImplementedError("Embedding prompt not supported for beam search") - if prompt["type"] == "enc_dec": - raise NotImplementedError( - "Encoder-decoder prompt not supported for beam search" + # Note: input length can be up to model context length - 1 for + # completion-like requests. + if token_num >= max_model_len: + raise VLLMValidationError( + f"This model's maximum context length is " + f"{max_model_len} tokens. However, your request has " + f"{token_num} input tokens. Please reduce the length of " + "the input messages.", + parameter="input_tokens", + value=token_num, ) - prompt_text = prompt.get("prompt") - prompt_token_ids = prompt["prompt_token_ids"] - tokenized_length = len(prompt_token_ids) - - logprobs_num = 2 * beam_width - sampling_params = SamplingParams( - logprobs=logprobs_num, - max_tokens=1, - temperature=temperature, - ) - all_beams = [ - BeamSearchSequence( - orig_prompt=prompt, - tokens=prompt_token_ids, - cum_logprob=0, - logprobs=[], - lora_request=lora_request, + if max_tokens is not None and token_num + max_tokens > max_model_len: + raise VLLMValidationError( + "'max_tokens' or 'max_completion_tokens' is too large: " + f"{max_tokens}. 
This model's maximum context length is " + f"{max_model_len} tokens and your request has " + f"{token_num} input tokens ({max_tokens} > {max_model_len}" + f" - {token_num}).", + parameter="max_tokens", + value=max_tokens, ) - ] - completed = [] - - for _ in range(max_tokens): - tasks = [] - request_id_batch = f"{request_id}-{random_uuid()}" - for i, beam in enumerate(all_beams): - prompt_item = beam.get_prompt() - lora_request_item = beam.lora_request - request_id_item = f"{request_id_batch}-beam-{i}" - task = asyncio.create_task( - collect_from_async_generator( - self.engine_client.generate( - prompt_item, - sampling_params, - request_id_item, - lora_request=lora_request_item, - trace_headers=trace_headers, - ) - ) - ) - tasks.append(task) - - output = [x[0] for x in await asyncio.gather(*tasks)] - - new_beams = [] - # Store all new tokens generated by beam - all_beams_token_id = [] - # Store the cumulative probability of all tokens - # generated by beam search - all_beams_logprob = [] - # Iterate through all beam inference results - for i, result in enumerate(output): - current_beam = all_beams[i] - - # check for error finish reason and abort beam search - if result.outputs[0].finish_reason == "error": - # yield error output and terminate beam search - yield RequestOutput( - request_id=request_id, - prompt=prompt_text, - outputs=[ - CompletionOutput( - index=0, - text="", - token_ids=[], - cumulative_logprob=None, - logprobs=None, - finish_reason="error", - ) - ], - finished=True, - prompt_token_ids=prompt_token_ids, - prompt_logprobs=None, - ) - return - - if result.outputs[0].logprobs is not None: - logprobs = result.outputs[0].logprobs[0] - all_beams_token_id.extend(list(logprobs.keys())) - all_beams_logprob.extend( - [ - current_beam.cum_logprob + obj.logprob - for obj in logprobs.values() - ] - ) - - # Handle the token for the end of sentence (EOS) - all_beams_token_id = np.array(all_beams_token_id) - all_beams_logprob = np.array(all_beams_logprob) + return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids) - if not ignore_eos: - # Get the index position of eos token in all generated results - eos_idx = np.where(all_beams_token_id == eos_token_id)[0] - for idx in eos_idx: - current_beam = all_beams[idx // logprobs_num] - result = output[idx // logprobs_num] - assert result.outputs[0].logprobs is not None - logprobs_entry = result.outputs[0].logprobs[0] - completed.append( - BeamSearchSequence( - orig_prompt=prompt, - tokens=current_beam.tokens + [eos_token_id] - if include_stop_str_in_output - else current_beam.tokens, - logprobs=current_beam.logprobs + [logprobs_entry], - cum_logprob=float(all_beams_logprob[idx]), - finish_reason="stop", - stop_reason=eos_token_id, - ) - ) - # After processing, set the log probability of the eos condition - # to negative infinity. - all_beams_logprob[eos_idx] = -np.inf + def _validate_chat_template( + self, + request_chat_template: str | None, + chat_template_kwargs: dict[str, Any] | None, + trust_request_chat_template: bool, + ) -> ErrorResponse | None: + if not trust_request_chat_template and ( + request_chat_template is not None + or ( + chat_template_kwargs + and chat_template_kwargs.get("chat_template") is not None + ) + ): + return self.create_error_response( + "Chat template is passed with request, but " + "--trust-request-chat-template is not set. " + "Refused request with untrusted chat template." 
+ ) + return None - # Processing non-EOS tokens - # Get indices of the top beam_width probabilities - topn_idx = np.argpartition(np.negative(all_beams_logprob), beam_width)[ - :beam_width - ] + @staticmethod + def _prepare_extra_chat_template_kwargs( + request_chat_template_kwargs: dict[str, Any] | None = None, + default_chat_template_kwargs: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """Helper to merge server-default and request-specific chat template kwargs.""" + request_chat_template_kwargs = request_chat_template_kwargs or {} + if default_chat_template_kwargs is None: + return request_chat_template_kwargs + # Apply server defaults first, then request kwargs override. + return default_chat_template_kwargs | request_chat_template_kwargs - for idx in topn_idx: - current_beam = all_beams[idx // logprobs_num] - result = output[idx // logprobs_num] - token_id = int(all_beams_token_id[idx]) - assert result.outputs[0].logprobs is not None - logprobs_entry = result.outputs[0].logprobs[0] - new_beams.append( - BeamSearchSequence( - orig_prompt=prompt, - tokens=current_beam.tokens + [token_id], - logprobs=current_beam.logprobs + [logprobs_entry], - lora_request=current_beam.lora_request, - cum_logprob=float(all_beams_logprob[idx]), - ) - ) + async def _preprocess_completion( + self, + request: RendererRequest, + prompt_input: str | list[str] | list[int] | list[list[int]] | None, + prompt_embeds: bytes | list[bytes] | None, + ) -> list[ProcessorInputs]: + prompts = list[SingletonPrompt | bytes]() + if prompt_embeds is not None: # embeds take higher priority + prompts.extend(prompt_to_seq(prompt_embeds)) + if prompt_input is not None: + prompts.extend(prompt_to_seq(prompt_input)) - all_beams = new_beams + return await self._preprocess_cmpl(request, prompts) - completed.extend(all_beams) - sorted_completed = sorted(completed, key=sort_beams_key, reverse=True) - best_beams = sorted_completed[:beam_width] + async def _preprocess_cmpl( + self, + request: RendererRequest, + prompts: Sequence[PromptType | bytes], + ) -> list[ProcessorInputs]: + renderer = self.renderer + model_config = self.model_config - for beam in best_beams: - if beam.tokens[-1] == eos_token_id and not ignore_eos: - # Skip the eos token in the text. 
- tokens = beam.tokens[tokenized_length:-1] - else: - tokens = beam.tokens[tokenized_length:] - beam.text = tokenizer.decode(tokens) + parsed_prompts = [ + ( + prompt + if isinstance(prompt, bytes) + else parse_model_prompt(model_config, prompt) + ) + for prompt in prompts + ] + tok_params = request.build_tok_params(model_config) - yield RequestOutput( - request_id=request_id, - prompt=prompt_text, - outputs=[ - CompletionOutput( - text=beam.text, # type: ignore - cumulative_logprob=beam.cum_logprob, - token_ids=beam.tokens[tokenized_length:], - index=i, - logprobs=beam.logprobs, - finish_reason=beam.finish_reason - if beam.finish_reason is not None - else "length", - stop_reason=beam.stop_reason, - ) - for (i, beam) in enumerate(best_beams) - ], - finished=True, - prompt_token_ids=prompt_token_ids, - prompt_logprobs=None, + return await renderer.render_cmpl_async( + parsed_prompts, + tok_params, + prompt_extras={ + k: v + for k in ("mm_processor_kwargs", "cache_salt") + if (v := getattr(request, k, None)) is not None + }, ) - async def _prepare_generators( + async def _preprocess_chat( self, - ctx: ServeContext, - ) -> ErrorResponse | None: - """Schedule the request and get the result generator.""" - generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] + request: RendererChatRequest, + messages: list[ChatCompletionMessageParam], + default_template: str | None, + default_template_content_format: ChatTemplateContentFormatOption, + default_template_kwargs: dict[str, Any] | None, + tool_dicts: list[dict[str, Any]] | None = None, + tool_parser: Callable[[TokenizerLike], ToolParser] | None = None, + ) -> tuple[list[ConversationMessage], list[ProcessorInputs]]: + renderer = self.renderer + + default_template_kwargs = merge_kwargs( + default_template_kwargs, + dict( + tools=tool_dicts, + tokenize=is_mistral_tokenizer(renderer.tokenizer), + ), + ) - try: - trace_headers = ( - None - if ctx.raw_request is None - else await self._get_trace_headers(ctx.raw_request.headers) - ) + tok_params = request.build_tok_params(self.model_config) + chat_params = request.build_chat_params( + default_template, default_template_content_format + ).with_defaults(default_template_kwargs) - pooling_params = self._create_pooling_params(ctx) - if isinstance(pooling_params, ErrorResponse): - return pooling_params + (conversation,), (engine_prompt,) = await renderer.render_chat_async( + [messages], + chat_params, + tok_params, + prompt_extras={ + k: v + for k in ("mm_processor_kwargs", "cache_salt") + if (v := getattr(request, k, None)) is not None + }, + ) - if ctx.engine_prompts is None: - return self.create_error_response("Engine prompts not available") + # tool parsing is done only if a tool_parser has been set and if + # tool_choice is not "none" (if tool_choice is "none" but a tool_parser + # is set, we want to prevent parsing a tool_call hallucinated by the LLM + if tool_parser is not None: + tool_choice = getattr(request, "tool_choice", "none") + if tool_choice != "none": + if not isinstance(request, ChatCompletionRequest | ResponsesRequest): + msg = ( + "Tool usage is only supported for Chat Completions API " + "or Responses API requests." 
+ ) + raise NotImplementedError(msg) - for i, engine_prompt in enumerate(ctx.engine_prompts): - request_id_item = f"{ctx.request_id}-{i}" + # TODO: Update adjust_request to accept ResponsesRequest + tokenizer = renderer.get_tokenizer() + request = tool_parser(tokenizer).adjust_request(request=request) # type: ignore[arg-type] - self._log_inputs( - request_id_item, - engine_prompt, - params=pooling_params, - lora_request=ctx.lora_request, - ) + return conversation, [engine_prompt] - generator = self.engine_client.encode( - engine_prompt, - pooling_params, - request_id_item, - lora_request=ctx.lora_request, - trace_headers=trace_headers, - priority=getattr(ctx.request, "priority", 0), - ) + def _extract_prompt_components(self, prompt: PromptType | ProcessorInputs): + return extract_prompt_components(self.model_config, prompt) - generators.append(generator) + def _extract_prompt_text(self, prompt: ProcessorInputs): + return self._extract_prompt_components(prompt).text - ctx.result_generator = merge_async_iterators(*generators) + def _extract_prompt_len(self, prompt: ProcessorInputs): + return extract_prompt_len(self.model_config, prompt) - return None + async def _render_next_turn( + self, + request: ResponsesRequest, + messages: list[ResponseInputOutputItem], + tool_dicts: list[dict[str, Any]] | None, + tool_parser: Callable[[TokenizerLike], ToolParser] | None, + chat_template: str | None, + chat_template_content_format: ChatTemplateContentFormatOption, + ): + new_messages = construct_input_messages( + request_input=messages, + ) - except Exception as e: - return self.create_error_response(e) + _, engine_prompts = await self._preprocess_chat( + request, + new_messages, + default_template=chat_template, + default_template_content_format=chat_template_content_format, + default_template_kwargs=None, + tool_dicts=tool_dicts, + tool_parser=tool_parser, + ) + return engine_prompts async def _generate_with_builtin_tools( self, @@ -1342,6 +1136,170 @@ async def _generate_with_builtin_tools( priority = orig_priority - 1 sub_request += 1 + def _log_inputs( + self, + request_id: str, + inputs: PromptType | ProcessorInputs, + params: SamplingParams | PoolingParams | BeamSearchParams | None, + lora_request: LoRARequest | None, + ) -> None: + if self.request_logger is None: + return + + components = self._extract_prompt_components(inputs) + + self.request_logger.log_inputs( + request_id, + components.text, + components.token_ids, + components.embeds, + params=params, + lora_request=lora_request, + ) + + async def _get_trace_headers( + self, + headers: Headers, + ) -> Mapping[str, str] | None: + is_tracing_enabled = await self.renderer_client.is_tracing_enabled() + + if is_tracing_enabled: + return extract_trace_headers(headers) + + if contains_trace_headers(headers): + log_tracing_disabled_warning() + + return None + + @staticmethod + def _base_request_id( + raw_request: Request | None, default: str | None = None + ) -> str | None: + """Pulls the request id to use from a header, if provided""" + if raw_request is not None and ( + (req_id := raw_request.headers.get("X-Request-Id")) is not None + ): + return req_id + + return random_uuid() if default is None else default + + @staticmethod + def _get_data_parallel_rank(raw_request: Request | None) -> int | None: + """Pulls the data parallel rank from a header, if provided""" + if raw_request is None: + return None + + rank_str = raw_request.headers.get("X-data-parallel-rank") + if rank_str is None: + return None + + try: + return int(rank_str) + except 
ValueError: + return None + + @staticmethod + def _parse_tool_calls_from_content( + request: ResponsesRequest | ChatCompletionRequest, + tokenizer: TokenizerLike | None, + enable_auto_tools: bool, + tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None, + content: str | None = None, + ) -> tuple[list[FunctionCall] | None, str | None]: + function_calls = list[FunctionCall]() + if request.tool_choice and isinstance(request.tool_choice, ToolChoiceFunction): + assert content is not None + # Forced Function Call + function_calls.append( + FunctionCall(name=request.tool_choice.name, arguments=content) + ) + content = None # Clear content since tool is called. + elif request.tool_choice and isinstance( + request.tool_choice, ChatCompletionNamedToolChoiceParam + ): + assert content is not None + # Forced Function Call + function_calls.append( + FunctionCall(name=request.tool_choice.function.name, arguments=content) + ) + content = None # Clear content since tool is called. + elif request.tool_choice == "required": + assert content is not None + tool_calls = TypeAdapter(list[FunctionDefinition]).validate_json(content) + function_calls.extend( + [ + FunctionCall( + name=tool_call.name, + arguments=json.dumps(tool_call.parameters, ensure_ascii=False), + ) + for tool_call in tool_calls + ] + ) + content = None # Clear content since tool is called. + elif ( + tool_parser_cls + and enable_auto_tools + and (request.tool_choice == "auto" or request.tool_choice is None) + ): + if tokenizer is None: + raise ValueError( + "Tokenizer not available when `skip_tokenizer_init=True`" + ) + + # Automatic Tool Call Parsing + try: + tool_parser = tool_parser_cls(tokenizer) + except RuntimeError as e: + logger.exception("Error in tool parser creation.") + raise e + tool_call_info = tool_parser.extract_tool_calls( + content if content is not None else "", + request=request, # type: ignore + ) + if tool_call_info is not None and tool_call_info.tools_called: + # extract_tool_calls() returns a list of tool calls. + function_calls.extend( + FunctionCall( + id=tool_call.id, + name=tool_call.function.name, + arguments=tool_call.function.arguments, + ) + for tool_call in tool_call_info.tool_calls + ) + content = tool_call_info.content + if content and content.strip() == "": + content = None + else: + # No tool calls. 
+ return None, content + + return function_calls, content + + @staticmethod + def _get_decoded_token( + logprob: Logprob, + token_id: int, + tokenizer: TokenizerLike | None, + return_as_token_id: bool = False, + ) -> str: + if return_as_token_id: + return f"token_id:{token_id}" + + if logprob.decoded_token is not None: + return logprob.decoded_token + + if tokenizer is None: + raise ValueError( + "Unable to get tokenizer because `skip_tokenizer_init=True`" + ) + + return tokenizer.decode([token_id]) + + def _is_model_supported(self, model_name: str | None) -> bool: + if not model_name: + return True + return self.models.is_base_model(model_name) + def clamp_prompt_logprobs( prompt_logprobs: PromptLogprobs | None, diff --git a/vllm/entrypoints/openai/models/serving.py b/vllm/entrypoints/openai/models/serving.py index e99d8f7ac767..d7d9a8097cbc 100644 --- a/vllm/entrypoints/openai/models/serving.py +++ b/vllm/entrypoints/openai/models/serving.py @@ -5,7 +5,7 @@ from collections import defaultdict from http import HTTPStatus -from vllm.engine.protocol import EngineClient +from vllm.engine.protocol import EngineClient, RendererClient from vllm.entrypoints.openai.engine.protocol import ( ErrorInfo, ErrorResponse, @@ -38,6 +38,7 @@ class OpenAIServingModels: def __init__( self, + renderer_client: RendererClient, engine_client: EngineClient, base_model_paths: list[BaseModelPath], *, @@ -45,6 +46,7 @@ def __init__( ): super().__init__() + self.renderer_client = renderer_client self.engine_client = engine_client self.base_model_paths = base_model_paths @@ -59,10 +61,10 @@ def __init__( ) self.lora_resolver_lock: dict[str, Lock] = defaultdict(Lock) - self.model_config = self.engine_client.model_config - self.renderer = self.engine_client.renderer - self.io_processor = self.engine_client.io_processor - self.input_processor = self.engine_client.input_processor + self.model_config = self.renderer_client.model_config + self.renderer = self.renderer_client.renderer + self.io_processor = self.renderer_client.io_processor + self.input_processor = self.renderer_client.input_processor async def init_static_loras(self): """Loads all static LoRA modules. diff --git a/vllm/entrypoints/openai/realtime/serving.py b/vllm/entrypoints/openai/realtime/serving.py index 4a00a50306d6..b9b155e08f28 100644 --- a/vllm/entrypoints/openai/realtime/serving.py +++ b/vllm/entrypoints/openai/realtime/serving.py @@ -8,9 +8,9 @@ import numpy as np -from vllm.engine.protocol import EngineClient, StreamingInput +from vllm.engine.protocol import EngineClient, RendererClient, StreamingInput from vllm.entrypoints.logger import RequestLogger -from vllm.entrypoints.openai.engine.serving import OpenAIServingInference +from vllm.entrypoints.openai.engine.serving import OpenAIServing from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.inputs.data import PromptType from vllm.logger import init_logger @@ -20,7 +20,7 @@ logger = init_logger(__name__) -class OpenAIServingRealtime(OpenAIServingInference): +class OpenAIServingRealtime(OpenAIServing): """Realtime audio transcription service via WebSocket streaming. 
Provides streaming audio-to-text transcription by transforming audio chunks @@ -29,6 +29,7 @@ class OpenAIServingRealtime(OpenAIServingInference): def __init__( self, + renderer_client: RendererClient, engine_client: EngineClient, models: OpenAIServingModels, *, @@ -36,6 +37,7 @@ def __init__( log_error_stack: bool = False, ): super().__init__( + renderer_client=renderer_client, engine_client=engine_client, models=models, request_logger=request_logger, diff --git a/vllm/entrypoints/openai/responses/serving.py b/vllm/entrypoints/openai/responses/serving.py index 86049faa4fb8..b43c0887684c 100644 --- a/vllm/entrypoints/openai/responses/serving.py +++ b/vllm/entrypoints/openai/responses/serving.py @@ -54,7 +54,7 @@ ) from vllm.entrypoints.openai.engine.serving import ( GenerationError, - OpenAIServingInference, + OpenAIServing, ) from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.openai.parser.harmony_utils import ( @@ -153,7 +153,7 @@ def _extract_allowed_tools_from_mcp_requests( return allowed_tools_map -class OpenAIServingResponses(OpenAIServingInference): +class OpenAIServingResponses(OpenAIServing): def __init__( self, renderer_client: RendererClient, diff --git a/vllm/entrypoints/openai/speech_to_text/speech_to_text.py b/vllm/entrypoints/openai/speech_to_text/speech_to_text.py index ea0a60348893..a764737289d8 100644 --- a/vllm/entrypoints/openai/speech_to_text/speech_to_text.py +++ b/vllm/entrypoints/openai/speech_to_text/speech_to_text.py @@ -22,10 +22,7 @@ RequestResponseMetadata, UsageInfo, ) -from vllm.entrypoints.openai.engine.serving import ( - OpenAIServingInference, - SpeechToTextRequest, -) +from vllm.entrypoints.openai.engine.serving import OpenAIServing, SpeechToTextRequest from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.openai.speech_to_text.protocol import ( TranscriptionResponse, @@ -79,7 +76,7 @@ logger = init_logger(__name__) -class OpenAISpeechToText(OpenAIServingInference): +class OpenAISpeechToText(OpenAIServing): """Base class for speech-to-text operations like transcription and translation.""" diff --git a/vllm/entrypoints/pooling/classify/serving.py b/vllm/entrypoints/pooling/classify/serving.py index a3ba96549a5f..8d42f7fbc081 100644 --- a/vllm/entrypoints/pooling/classify/serving.py +++ b/vllm/entrypoints/pooling/classify/serving.py @@ -7,11 +7,11 @@ import numpy as np from fastapi import Request -from vllm.engine.protocol import EngineClient +from vllm.engine.protocol import EngineClient, RendererClient from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo -from vllm.entrypoints.openai.engine.serving import OpenAIServingInference, ServeContext +from vllm.entrypoints.openai.engine.serving import OpenAIServing, ServeContext from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.pooling.classify.protocol import ( ClassificationChatRequest, @@ -29,11 +29,12 @@ ClassificationServeContext: TypeAlias = ServeContext[ClassificationRequest] -class ServingClassification(OpenAIServingInference): +class ServingClassification(OpenAIServing): request_id_prefix = "classify" def __init__( self, + renderer_client: RendererClient, engine_client: EngineClient, models: OpenAIServingModels, *, @@ -44,6 +45,7 @@ def __init__( log_error_stack: bool = False, ) -> None: super().__init__( + renderer_client=renderer_client, 
engine_client=engine_client, models=models, request_logger=request_logger, diff --git a/vllm/entrypoints/pooling/embed/serving.py b/vllm/entrypoints/pooling/embed/serving.py index cf812f7eb962..31b4c337ea27 100644 --- a/vllm/entrypoints/pooling/embed/serving.py +++ b/vllm/entrypoints/pooling/embed/serving.py @@ -9,11 +9,11 @@ from fastapi import Request from typing_extensions import assert_never -from vllm.engine.protocol import EngineClient +from vllm.engine.protocol import EngineClient, RendererClient from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo -from vllm.entrypoints.openai.engine.serving import OpenAIServingInference, ServeContext +from vllm.entrypoints.openai.engine.serving import OpenAIServing, ServeContext from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.pooling.embed.protocol import ( EmbeddingBytesResponse, @@ -42,11 +42,12 @@ EmbeddingServeContext: TypeAlias = ServeContext[EmbeddingRequest] -class OpenAIServingEmbedding(OpenAIServingInference): +class OpenAIServingEmbedding(OpenAIServing): request_id_prefix = "embd" def __init__( self, + renderer_client: RendererClient, engine_client: EngineClient, models: OpenAIServingModels, *, @@ -57,6 +58,7 @@ def __init__( log_error_stack: bool = False, ) -> None: super().__init__( + renderer_client=renderer_client, engine_client=engine_client, models=models, request_logger=request_logger, diff --git a/vllm/entrypoints/pooling/pooling/serving.py b/vllm/entrypoints/pooling/pooling/serving.py index 5427d7582ef5..d25e58351eb2 100644 --- a/vllm/entrypoints/pooling/pooling/serving.py +++ b/vllm/entrypoints/pooling/pooling/serving.py @@ -12,11 +12,11 @@ from fastapi import Request from typing_extensions import assert_never -from vllm.engine.protocol import EngineClient +from vllm.engine.protocol import EngineClient, RendererClient from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo -from vllm.entrypoints.openai.engine.serving import OpenAIServingInference +from vllm.entrypoints.openai.engine.serving import OpenAIServing from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.pooling.pooling.protocol import ( IOProcessorRequest, @@ -43,9 +43,10 @@ logger = init_logger(__name__) -class OpenAIServingPooling(OpenAIServingInference): +class OpenAIServingPooling(OpenAIServing): def __init__( self, + renderer_client: RendererClient, engine_client: EngineClient, models: OpenAIServingModels, *, @@ -56,6 +57,7 @@ def __init__( log_error_stack: bool = False, ) -> None: super().__init__( + renderer_client=renderer_client, engine_client=engine_client, models=models, request_logger=request_logger, diff --git a/vllm/entrypoints/pooling/score/serving.py b/vllm/entrypoints/pooling/score/serving.py index 7e86c553a346..f368bde87f4b 100644 --- a/vllm/entrypoints/pooling/score/serving.py +++ b/vllm/entrypoints/pooling/score/serving.py @@ -8,13 +8,13 @@ from fastapi import Request -from vllm.engine.protocol import EngineClient +from vllm.engine.protocol import EngineClient, RendererClient from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.engine.protocol import ( ErrorResponse, UsageInfo, ) -from vllm.entrypoints.openai.engine.serving import OpenAIServingInference 
+from vllm.entrypoints.openai.engine.serving import OpenAIServing from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.pooling.score.protocol import ( RerankDocument, @@ -47,9 +47,10 @@ logger = init_logger(__name__) -class ServingScores(OpenAIServingInference): +class ServingScores(OpenAIServing): def __init__( self, + renderer_client: RendererClient, engine_client: EngineClient, models: OpenAIServingModels, *, @@ -58,6 +59,7 @@ def __init__( log_error_stack: bool = False, ) -> None: super().__init__( + renderer_client=renderer_client, engine_client=engine_client, models=models, request_logger=request_logger, diff --git a/vllm/entrypoints/serve/disagg/serving.py b/vllm/entrypoints/serve/disagg/serving.py index 86b7ce351362..26fc835338fb 100644 --- a/vllm/entrypoints/serve/disagg/serving.py +++ b/vllm/entrypoints/serve/disagg/serving.py @@ -9,7 +9,7 @@ from fastapi import Request -from vllm.engine.protocol import EngineClient +from vllm.engine.protocol import EngineClient, RendererClient from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.chat_completion.protocol import ( ChatCompletionLogProb, @@ -22,10 +22,7 @@ RequestResponseMetadata, UsageInfo, ) -from vllm.entrypoints.openai.engine.serving import ( - OpenAIServingInference, - clamp_prompt_logprobs, -) +from vllm.entrypoints.openai.engine.serving import OpenAIServing, clamp_prompt_logprobs from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.serve.disagg.protocol import ( GenerateRequest, @@ -41,11 +38,12 @@ logger = init_logger(__name__) -class ServingTokens(OpenAIServingInference): +class ServingTokens(OpenAIServing): """Provides Tokens IN <> Tokens OUT functionality to vLLM API.""" def __init__( self, + renderer_client: RendererClient, engine_client: EngineClient, models: OpenAIServingModels, *, @@ -57,6 +55,7 @@ def __init__( enable_log_outputs: bool = False, ): super().__init__( + renderer_client=renderer_client, engine_client=engine_client, models=models, request_logger=request_logger, @@ -85,8 +84,8 @@ async def serve_tokens( # If the engine is dead, raise the engine's DEAD_ERROR. # This is required for the streaming case, where we return a # success status before we actually start generating text :). 
- if self.engine_client.errored: - raise self.engine_client.dead_error + if self.renderer_client.errored: + raise self.renderer_client.dead_error lora_request = None lora_request = self._maybe_get_adapters(request, supports_default_mm_loras=True) From 559ba405ab1b6b7a232ed8bc0119f1aea8c1e1b7 Mon Sep 17 00:00:00 2001 From: Sage Ahrac Date: Wed, 25 Feb 2026 13:31:04 +0200 Subject: [PATCH 06/53] Split EngineClient into RendererClient + EngineClient sibling ABCs Signed-off-by: Sage Ahrac --- vllm/entrypoints/anthropic/serving.py | 4 +- vllm/entrypoints/openai/api_server.py | 53 +++++++++++++++---- .../entrypoints/openai/generate/api_router.py | 32 ++++++----- .../entrypoints/openai/realtime/api_router.py | 8 +-- .../openai/speech_to_text/api_router.py | 13 +++-- vllm/entrypoints/pooling/__init__.py | 23 ++++---- vllm/entrypoints/serve/tokenize/serving.py | 6 ++- 7 files changed, 96 insertions(+), 43 deletions(-) diff --git a/vllm/entrypoints/anthropic/serving.py b/vllm/entrypoints/anthropic/serving.py index 8fb347aabed3..3f16d32a328c 100644 --- a/vllm/entrypoints/anthropic/serving.py +++ b/vllm/entrypoints/anthropic/serving.py @@ -13,7 +13,7 @@ from fastapi import Request -from vllm.engine.protocol import EngineClient +from vllm.engine.protocol import EngineClient, RendererClient from vllm.entrypoints.anthropic.protocol import ( AnthropicContentBlock, AnthropicDelta, @@ -51,6 +51,7 @@ class AnthropicServingMessages(OpenAIServingChat): def __init__( self, + renderer_client: RendererClient, engine_client: EngineClient, models: OpenAIServingModels, response_role: str, @@ -66,6 +67,7 @@ def __init__( enable_force_include_usage: bool = False, ): super().__init__( + renderer_client=renderer_client, engine_client=engine_client, models=models, response_role=response_role, diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index d76a7446d2a9..213290d391a7 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -22,7 +22,7 @@ import vllm.envs as envs from vllm.engine.arg_utils import AsyncEngineArgs -from vllm.engine.protocol import EngineClient +from vllm.engine.protocol import EngineClient, RendererClient from vllm.entrypoints.chat_utils import load_chat_template from vllm.entrypoints.launcher import serve_http from vllm.entrypoints.logger import RequestLogger @@ -289,12 +289,13 @@ def build_app( async def init_app_state( + renderer_client: RendererClient, engine_client: EngineClient, state: State, args: Namespace, supported_tasks: tuple["SupportedTask", ...] 
| None = None, ) -> None: - vllm_config = engine_client.vllm_config + vllm_config = renderer_client.vllm_config if supported_tasks is None: warnings.warn( "The 'supported_tasks' parameter was not provided to " @@ -319,6 +320,7 @@ async def init_app_state( BaseModelPath(name=name, model_path=args.model) for name in served_model_names ] + state.renderer_client = renderer_client state.engine_client = engine_client state.log_stats = not args.disable_log_stats state.vllm_config = vllm_config @@ -334,14 +336,16 @@ async def init_app_state( lora_modules = process_lora_modules(args.lora_modules, default_mm_loras) state.openai_serving_models = OpenAIServingModels( + renderer_client=renderer_client, engine_client=engine_client, base_model_paths=base_model_paths, lora_modules=lora_modules, ) await state.openai_serving_models.init_static_loras() state.openai_serving_tokenization = OpenAIServingTokenization( - engine_client, - state.openai_serving_models, + renderer_client=renderer_client, + engine_client=engine_client, + models=state.openai_serving_models, request_logger=request_logger, chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, @@ -353,7 +357,12 @@ async def init_app_state( from vllm.entrypoints.openai.generate.api_router import init_generate_state await init_generate_state( - engine_client, state, args, request_logger, supported_tasks + renderer_client, + engine_client, + state, + args, + request_logger, + supported_tasks, ) if "transcription" in supported_tasks: @@ -362,18 +371,37 @@ async def init_app_state( ) init_transcription_state( - engine_client, state, args, request_logger, supported_tasks + renderer_client, + engine_client, + state, + args, + request_logger, + supported_tasks, ) if "realtime" in supported_tasks: from vllm.entrypoints.openai.realtime.api_router import init_realtime_state - init_realtime_state(engine_client, state, args, request_logger, supported_tasks) + init_realtime_state( + renderer_client, + engine_client, + state, + args, + request_logger, + supported_tasks, + ) if any(task in POOLING_TASKS for task in supported_tasks): from vllm.entrypoints.pooling import init_pooling_state - init_pooling_state(engine_client, state, args, request_logger, supported_tasks) + init_pooling_state( + renderer_client, + engine_client, + state, + args, + request_logger, + supported_tasks, + ) state.enable_server_load_tracking = args.enable_server_load_tracking state.server_load_metrics = 0 @@ -491,11 +519,16 @@ async def run_server_worker( args, client_config=client_config, ) as engine_client: - supported_tasks = await engine_client.get_supported_tasks() + # In co-located mode, AsyncLLM implements both RendererClient + # (CPU ops) and EngineClient (inference ops). 
+ renderer_client = engine_client + supported_tasks = await renderer_client.get_supported_tasks() logger.info("Supported tasks: %s", supported_tasks) app = build_app(args, supported_tasks) - await init_app_state(engine_client, app.state, args, supported_tasks) + await init_app_state( + renderer_client, engine_client, app.state, args, supported_tasks + ) logger.info( "Starting vLLM API server %d on %s", diff --git a/vllm/entrypoints/openai/generate/api_router.py b/vllm/entrypoints/openai/generate/api_router.py index ac74c7582058..f8f92206789c 100644 --- a/vllm/entrypoints/openai/generate/api_router.py +++ b/vllm/entrypoints/openai/generate/api_router.py @@ -9,7 +9,7 @@ from starlette.datastructures import State - from vllm.engine.protocol import EngineClient + from vllm.engine.protocol import EngineClient, RendererClient from vllm.entrypoints.logger import RequestLogger from vllm.tasks import SupportedTask else: @@ -43,6 +43,7 @@ def register_generate_api_routers(app: FastAPI): async def init_generate_state( + renderer_client: "RendererClient", engine_client: "EngineClient", state: "State", args: "Namespace", @@ -74,8 +75,9 @@ async def init_generate_state( state.openai_serving_responses = ( OpenAIServingResponses( - engine_client, - state.openai_serving_models, + renderer_client=renderer_client, + engine_client=engine_client, + models=state.openai_serving_models, request_logger=request_logger, chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, @@ -94,9 +96,10 @@ async def init_generate_state( ) state.openai_serving_chat = ( OpenAIServingChat( - engine_client, - state.openai_serving_models, - args.response_role, + renderer_client=renderer_client, + engine_client=engine_client, + models=state.openai_serving_models, + response_role=args.response_role, request_logger=request_logger, chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, @@ -121,8 +124,9 @@ async def init_generate_state( await state.openai_serving_chat.warmup() state.openai_serving_completion = ( OpenAIServingCompletion( - engine_client, - state.openai_serving_models, + renderer_client=renderer_client, + engine_client=engine_client, + models=state.openai_serving_models, request_logger=request_logger, return_tokens_as_token_ids=args.return_tokens_as_token_ids, enable_prompt_tokens_details=args.enable_prompt_tokens_details, @@ -134,9 +138,10 @@ async def init_generate_state( ) state.anthropic_serving_messages = ( AnthropicServingMessages( - engine_client, - state.openai_serving_models, - args.response_role, + renderer_client=renderer_client, + engine_client=engine_client, + models=state.openai_serving_models, + response_role=args.response_role, request_logger=request_logger, chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, @@ -152,8 +157,9 @@ async def init_generate_state( ) state.serving_tokens = ( ServingTokens( - engine_client, - state.openai_serving_models, + renderer_client=renderer_client, + engine_client=engine_client, + models=state.openai_serving_models, request_logger=request_logger, return_tokens_as_token_ids=args.return_tokens_as_token_ids, log_error_stack=args.log_error_stack, diff --git a/vllm/entrypoints/openai/realtime/api_router.py b/vllm/entrypoints/openai/realtime/api_router.py index fb7decbd707a..0cf1b564d0ab 100644 --- a/vllm/entrypoints/openai/realtime/api_router.py +++ b/vllm/entrypoints/openai/realtime/api_router.py @@ -16,7 +16,7 @@ from 
starlette.datastructures import State - from vllm.engine.protocol import EngineClient + from vllm.engine.protocol import EngineClient, RendererClient from vllm.entrypoints.logger import RequestLogger from vllm.tasks import SupportedTask else: @@ -57,6 +57,7 @@ def attach_router(app: FastAPI): def init_realtime_state( + renderer_client: "RendererClient", engine_client: "EngineClient", state: "State", args: "Namespace", @@ -65,8 +66,9 @@ def init_realtime_state( ): state.openai_serving_realtime = ( OpenAIServingRealtime( - engine_client, - state.openai_serving_models, + renderer_client=renderer_client, + engine_client=engine_client, + models=state.openai_serving_models, request_logger=request_logger, log_error_stack=args.log_error_stack, ) diff --git a/vllm/entrypoints/openai/speech_to_text/api_router.py b/vllm/entrypoints/openai/speech_to_text/api_router.py index 7477b79c08b0..58496e408e87 100644 --- a/vllm/entrypoints/openai/speech_to_text/api_router.py +++ b/vllm/entrypoints/openai/speech_to_text/api_router.py @@ -30,7 +30,7 @@ from starlette.datastructures import State - from vllm.engine.protocol import EngineClient + from vllm.engine.protocol import EngineClient, RendererClient from vllm.entrypoints.logger import RequestLogger from vllm.tasks import SupportedTask else: @@ -129,6 +129,7 @@ def attach_router(app: FastAPI): def init_transcription_state( + renderer_client: "RendererClient", engine_client: "EngineClient", state: "State", args: "Namespace", @@ -137,8 +138,9 @@ def init_transcription_state( ): state.openai_serving_transcription = ( OpenAIServingTranscription( - engine_client, - state.openai_serving_models, + renderer_client=renderer_client, + engine_client=engine_client, + models=state.openai_serving_models, request_logger=request_logger, log_error_stack=args.log_error_stack, enable_force_include_usage=args.enable_force_include_usage, @@ -148,8 +150,9 @@ def init_transcription_state( ) state.openai_serving_translation = ( OpenAIServingTranslation( - engine_client, - state.openai_serving_models, + renderer_client=renderer_client, + engine_client=engine_client, + models=state.openai_serving_models, request_logger=request_logger, log_error_stack=args.log_error_stack, enable_force_include_usage=args.enable_force_include_usage, diff --git a/vllm/entrypoints/pooling/__init__.py b/vllm/entrypoints/pooling/__init__.py index 1108be175bc6..3aa0e4eef40e 100644 --- a/vllm/entrypoints/pooling/__init__.py +++ b/vllm/entrypoints/pooling/__init__.py @@ -10,7 +10,7 @@ from starlette.datastructures import State - from vllm.engine.protocol import EngineClient + from vllm.engine.protocol import EngineClient, RendererClient from vllm.entrypoints.logger import RequestLogger from vllm.tasks import SupportedTask else: @@ -48,6 +48,7 @@ def register_pooling_api_routers( def init_pooling_state( + renderer_client: "RendererClient", engine_client: "EngineClient", state: "State", args: "Namespace", @@ -66,8 +67,9 @@ def init_pooling_state( state.openai_serving_pooling = ( ( OpenAIServingPooling( - engine_client, - state.openai_serving_models, + renderer_client=renderer_client, + engine_client=engine_client, + models=state.openai_serving_models, request_logger=request_logger, chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, @@ -80,8 +82,9 @@ def init_pooling_state( ) state.openai_serving_embedding = ( OpenAIServingEmbedding( - engine_client, - state.openai_serving_models, + renderer_client=renderer_client, + engine_client=engine_client, + 
models=state.openai_serving_models, request_logger=request_logger, chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, @@ -93,8 +96,9 @@ def init_pooling_state( ) state.openai_serving_classification = ( ServingClassification( - engine_client, - state.openai_serving_models, + renderer_client=renderer_client, + engine_client=engine_client, + models=state.openai_serving_models, request_logger=request_logger, chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, @@ -110,8 +114,9 @@ def init_pooling_state( # - "token_embed" task (late interaction models like ColBERT) state.openai_serving_scores = ( ServingScores( - engine_client, - state.openai_serving_models, + renderer_client=renderer_client, + engine_client=engine_client, + models=state.openai_serving_models, request_logger=request_logger, score_template=resolved_chat_template, log_error_stack=args.log_error_stack, diff --git a/vllm/entrypoints/serve/tokenize/serving.py b/vllm/entrypoints/serve/tokenize/serving.py index 40dbb7e5e29d..23cd55fbce23 100644 --- a/vllm/entrypoints/serve/tokenize/serving.py +++ b/vllm/entrypoints/serve/tokenize/serving.py @@ -6,7 +6,7 @@ import jinja2 from fastapi import Request -from vllm.engine.protocol import RendererClient +from vllm.engine.protocol import EngineClient, RendererClient from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.engine.protocol import ErrorResponse @@ -30,7 +30,8 @@ class OpenAIServingTokenization(OpenAIServing): def __init__( self, - engine_client: RendererClient, + renderer_client: RendererClient, + engine_client: EngineClient, models: OpenAIServingModels, *, request_logger: RequestLogger | None, @@ -40,6 +41,7 @@ def __init__( log_error_stack: bool = False, ) -> None: super().__init__( + renderer_client=renderer_client, engine_client=engine_client, models=models, request_logger=request_logger, From f38e0bc3dd08368f4f834999a920f0e52a33e358 Mon Sep 17 00:00:00 2001 From: Sage Ahrac Date: Wed, 25 Feb 2026 15:27:39 +0200 Subject: [PATCH 07/53] AsyncRenderer Signed-off-by: Sage Ahrac --- .../entrypoints/openai/test_serving_models.py | 12 +-- .../openai/test_serving_responses.py | 33 +++++---- tests/lora/test_add_lora.py | 2 +- tests/lora/test_lora_functions.py | 2 +- vllm/benchmarks/throughput.py | 22 ++---- vllm/engine/protocol.py | 55 ++++++++------ vllm/entrypoints/openai/api_server.py | 43 ++++------- .../openai/chat_completion/serving.py | 4 +- vllm/entrypoints/openai/cli_args.py | 3 - vllm/entrypoints/openai/completion/serving.py | 4 +- vllm/entrypoints/openai/responses/serving.py | 4 +- vllm/entrypoints/openai/run_batch.py | 41 ++++++---- .../openai/speech_to_text/speech_to_text.py | 4 +- vllm/entrypoints/serve/disagg/serving.py | 4 +- vllm/v1/engine/async_llm.py | 74 +++++++++++++++---- 15 files changed, 178 insertions(+), 129 deletions(-) diff --git a/tests/entrypoints/openai/test_serving_models.py b/tests/entrypoints/openai/test_serving_models.py index f6755f489343..ea3d0ef4fb81 100644 --- a/tests/entrypoints/openai/test_serving_models.py +++ b/tests/entrypoints/openai/test_serving_models.py @@ -7,7 +7,7 @@ import pytest from vllm.config import ModelConfig -from vllm.engine.protocol import EngineClient +from vllm.engine.protocol import EngineClient, RendererClient from vllm.entrypoints.openai.engine.protocol import ( ErrorResponse, ) @@ -28,16 +28,18 @@ async def 
_async_serving_models_init() -> OpenAIServingModels: + mock_renderer_client = MagicMock(spec=RendererClient) mock_engine_client = MagicMock(spec=EngineClient) # Set the max_model_len attribute to avoid missing attribute mock_model_config = MagicMock(spec=ModelConfig) mock_model_config.max_model_len = 2048 - mock_engine_client.model_config = mock_model_config - mock_engine_client.input_processor = MagicMock() - mock_engine_client.io_processor = MagicMock() - mock_engine_client.renderer = MagicMock() + mock_renderer_client.model_config = mock_model_config + mock_renderer_client.input_processor = MagicMock() + mock_renderer_client.io_processor = MagicMock() + mock_renderer_client.renderer = MagicMock() serving_models = OpenAIServingModels( + renderer_client=mock_renderer_client, engine_client=mock_engine_client, base_model_paths=BASE_MODEL_PATHS, lora_modules=None, diff --git a/tests/entrypoints/openai/test_serving_responses.py b/tests/entrypoints/openai/test_serving_responses.py index 5cf07ac0f6a3..79c4d6d73a39 100644 --- a/tests/entrypoints/openai/test_serving_responses.py +++ b/tests/entrypoints/openai/test_serving_responses.py @@ -14,6 +14,7 @@ ) import vllm.envs as envs +from vllm.engine.protocol import RendererClient from vllm.entrypoints.mcp.tool_server import ToolServer from vllm.entrypoints.openai.engine.protocol import ( ErrorResponse, @@ -128,17 +129,18 @@ class TestInitializeToolSessions: async def serving_responses_instance(self): """Create a real OpenAIServingResponses instance for testing""" # Create minimal mocks for required dependencies + renderer_client = MagicMock(spec=RendererClient) engine_client = MagicMock() model_config = MagicMock() model_config.max_model_len = 100 model_config.hf_config.model_type = "test" model_config.get_diff_sampling_param.return_value = {} - engine_client.model_config = model_config + renderer_client.model_config = model_config - engine_client.input_processor = MagicMock() - engine_client.io_processor = MagicMock() - engine_client.renderer = MagicMock() + renderer_client.input_processor = MagicMock() + renderer_client.io_processor = MagicMock() + renderer_client.renderer = MagicMock() models = MagicMock() @@ -146,6 +148,7 @@ async def serving_responses_instance(self): # Create the actual instance instance = OpenAIServingResponses( + renderer_client=renderer_client, engine_client=engine_client, models=models, request_logger=None, @@ -216,22 +219,24 @@ class TestValidateGeneratorInput: async def serving_responses_instance(self): """Create a real OpenAIServingResponses instance for testing""" # Create minimal mocks for required dependencies + renderer_client = MagicMock(spec=RendererClient) engine_client = MagicMock() model_config = MagicMock() model_config.max_model_len = 100 model_config.hf_config.model_type = "test" model_config.get_diff_sampling_param.return_value = {} - engine_client.model_config = model_config + renderer_client.model_config = model_config - engine_client.input_processor = MagicMock() - engine_client.io_processor = MagicMock() - engine_client.renderer = MagicMock() + renderer_client.input_processor = MagicMock() + renderer_client.io_processor = MagicMock() + renderer_client.renderer = MagicMock() models = MagicMock() # Create the actual instance instance = OpenAIServingResponses( + renderer_client=renderer_client, engine_client=engine_client, models=models, request_logger=None, @@ -279,22 +284,24 @@ def get_vocab(self): # Force non-harmony, SimpleContext path monkeypatch.setattr(envs, "VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT", 
False) + renderer_client = MagicMock(spec=RendererClient) engine_client = MagicMock() model_config = MagicMock() model_config.hf_config.model_type = "test" model_config.hf_text_config = MagicMock() model_config.get_diff_sampling_param.return_value = {} - engine_client.model_config = model_config - engine_client.input_processor = MagicMock() - engine_client.io_processor = MagicMock() - engine_client.renderer = MagicMock() + renderer_client.model_config = model_config + renderer_client.input_processor = MagicMock() + renderer_client.io_processor = MagicMock() + renderer_client.renderer = MagicMock() tokenizer = FakeTokenizer() - engine_client.renderer.get_tokenizer.return_value = tokenizer + renderer_client.renderer.get_tokenizer.return_value = tokenizer models = MagicMock() serving = OpenAIServingResponses( + renderer_client=renderer_client, engine_client=engine_client, models=models, request_logger=None, diff --git a/tests/lora/test_add_lora.py b/tests/lora/test_add_lora.py index 9a82ab99ea9c..38d12fb2d36d 100644 --- a/tests/lora/test_add_lora.py +++ b/tests/lora/test_add_lora.py @@ -86,7 +86,7 @@ async def test_add_lora(chatglm3_lora_files): warmup_run_requests = lora_requests[part_size : part_size * 2] cold_run_requests = lora_requests[part_size * 2 :] - async with build_async_engine_client_from_engine_args(engine_args) as llm: + async with build_async_engine_client_from_engine_args(engine_args) as (_, llm): # Dummy run - So any 1-time functionality like triton kernel compilation # is complete here. await requests_processing_time(llm, dummy_run_requests) diff --git a/tests/lora/test_lora_functions.py b/tests/lora/test_lora_functions.py index 1c692630284d..c5419e09f74f 100644 --- a/tests/lora/test_lora_functions.py +++ b/tests/lora/test_lora_functions.py @@ -88,7 +88,7 @@ async def run_check(fn, args, expected: list): await fn(args) assert set(await llm.list_loras()) == set(expected) - async with build_async_engine_client_from_engine_args(engine_args) as llm: + async with build_async_engine_client_from_engine_args(engine_args) as (_, llm): await run_check(llm.add_lora, make_lora_request(1), [1]) await run_check(llm.add_lora, make_lora_request(2), [1, 2]) diff --git a/vllm/benchmarks/throughput.py b/vllm/benchmarks/throughput.py index 3c0fea8e0111..be60ae4ca6d6 100644 --- a/vllm/benchmarks/throughput.py +++ b/vllm/benchmarks/throughput.py @@ -180,7 +180,6 @@ async def run_vllm_async( n: int, engine_args: AsyncEngineArgs, do_profile: bool, - disable_frontend_multiprocessing: bool = False, disable_detokenize: bool = False, ) -> float: from vllm import SamplingParams @@ -190,9 +189,8 @@ async def run_vllm_async( async with build_async_engine_client_from_engine_args( engine_args, - disable_frontend_multiprocessing=disable_frontend_multiprocessing, - ) as llm: - model_config = llm.model_config + ) as (renderer_client, engine_client): + model_config = renderer_client.model_config assert all( model_config.max_model_len >= (request.prompt_len + request.expected_output_len) @@ -233,17 +231,19 @@ async def run_vllm_async( generators = [] start = time.perf_counter() if do_profile: - await llm.start_profile() + await engine_client.start_profile() for i, (prompt, sp, lr) in enumerate( zip(prompts, sampling_params, lora_requests) ): - generator = llm.generate(prompt, sp, lora_request=lr, request_id=f"test{i}") + generator = engine_client.generate( + prompt, sp, lora_request=lr, request_id=f"test{i}" + ) generators.append(generator) all_gens = merge_async_iterators(*generators) async for i, res in 
all_gens: pass if do_profile: - await llm.stop_profile() + await engine_client.stop_profile() end = time.perf_counter() return end - start @@ -745,12 +745,7 @@ def add_cli_args(parser: argparse.ArgumentParser): default=False, help="Use vLLM async engine rather than LLM class.", ) - parser.add_argument( - "--disable-frontend-multiprocessing", - action="store_true", - default=False, - help="Disable decoupled async engine frontend.", - ) + parser.add_argument( "--disable-detokenize", action="store_true", @@ -859,7 +854,6 @@ def main(args: argparse.Namespace): requests, args.n, AsyncEngineArgs.from_cli_args(args), - disable_frontend_multiprocessing=args.disable_frontend_multiprocessing, disable_detokenize=args.disable_detokenize, do_profile=args.profile, ) diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 1a9275ffaeff..179f383c675a 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -38,22 +38,12 @@ class StreamingInput: sampling_params: SamplingParams | None = None -class RendererClient(ABC): - """Client interface for the renderer layer (CPU-only operations). - - Contains only methods and attributes that don't require a running - inference engine: configuration, tokenization, health checks, and - status monitoring. +class Client(ABC): + """Base client interface for liveness and health monitoring. - See :class:`EngineClient` for the full interface including inference. + Shared by both :class:`RendererClient` and :class:`EngineClient`. """ - vllm_config: VllmConfig - model_config: ModelConfig - renderer: BaseRenderer - io_processor: IOProcessor | None - input_processor: InputProcessor - @property @abstractmethod def is_running(self) -> bool: ... @@ -71,24 +61,45 @@ def errored(self) -> bool: ... def dead_error(self) -> BaseException: ... @abstractmethod - async def is_tracing_enabled(self) -> bool: ... + async def check_health(self) -> None: + """Raise if unhealthy""" + ... + + +class RendererClient(Client): + """Client interface for the renderer layer (CPU-only operations). + + Covers configuration, tokenization, and tracing — everything that + does not require a running inference engine. + + See :class:`EngineClient` for the tok-in/tok-out inference interface. + """ + + vllm_config: VllmConfig + model_config: ModelConfig + renderer: BaseRenderer + io_processor: IOProcessor | None + input_processor: InputProcessor @abstractmethod - async def do_log_stats(self) -> None: ... + async def is_tracing_enabled(self) -> bool: ... + + +class EngineClient(Client): + """Engine client interface for tok-in/tok-out inference operations. + + Covers generation, encoding, LoRA management, and engine control. + Does not extend :class:`RendererClient`; the two interfaces are + independently implementable for disaggregated prefill deployments. + """ @abstractmethod - async def check_health(self) -> None: - """Raise if unhealthy""" - ... + async def do_log_stats(self) -> None: ... 
async def get_supported_tasks(self) -> tuple[SupportedTask, ...]: """Get supported tasks""" raise NotImplementedError - -class EngineClient(ABC): - """Engine client interface for inference operations.""" - @abstractmethod def generate( self, diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 213290d391a7..18677de11093 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -71,9 +71,8 @@ async def build_async_engine_client( args: Namespace, *, usage_context: UsageContext = UsageContext.OPENAI_API_SERVER, - disable_frontend_multiprocessing: bool | None = None, client_config: dict[str, Any] | None = None, -) -> AsyncIterator[EngineClient]: +) -> AsyncIterator[tuple[RendererClient, EngineClient]]: if os.getenv("VLLM_WORKER_MULTIPROC_METHOD") == "forkserver": # The executor is expected to be mp. # Pre-import heavy modules in the forkserver process @@ -90,16 +89,12 @@ async def build_async_engine_client( engine_args._api_process_count = client_config.get("client_count", 1) engine_args._api_process_rank = client_config.get("client_index", 0) - if disable_frontend_multiprocessing is None: - disable_frontend_multiprocessing = bool(args.disable_frontend_multiprocessing) - async with build_async_engine_client_from_engine_args( engine_args, usage_context=usage_context, - disable_frontend_multiprocessing=disable_frontend_multiprocessing, client_config=client_config, - ) as engine: - yield engine + ) as clients: + yield clients @asynccontextmanager @@ -107,25 +102,16 @@ async def build_async_engine_client_from_engine_args( engine_args: AsyncEngineArgs, *, usage_context: UsageContext = UsageContext.OPENAI_API_SERVER, - disable_frontend_multiprocessing: bool = False, client_config: dict[str, Any] | None = None, -) -> AsyncIterator[EngineClient]: - """ - Create EngineClient, either: - - in-process using the AsyncLLMEngine Directly - - multiprocess using AsyncLLMEngine RPC - - Returns the Client or None if the creation failed. - """ +) -> AsyncIterator[tuple[RendererClient, EngineClient]]: + """Create a co-located (RendererClient, EngineClient) pair backed by AsyncLLM.""" # Create the EngineConfig (determines if we can use V1). vllm_config = engine_args.create_engine_config(usage_context=usage_context) - if disable_frontend_multiprocessing: - logger.warning("V1 is enabled, but got --disable-frontend-multiprocessing.") - - from vllm.v1.engine.async_llm import AsyncLLM + from vllm.v1.engine.async_llm import AsyncLLM, AsyncRenderer + async_renderer: AsyncRenderer | None = None async_llm: AsyncLLM | None = None # Don't mutate the input client_config @@ -134,6 +120,7 @@ async def build_async_engine_client_from_engine_args( client_index = client_config.pop("client_index", 0) try: + async_renderer = AsyncRenderer(vllm_config) async_llm = AsyncLLM.from_vllm_config( vllm_config=vllm_config, usage_context=usage_context, @@ -146,13 +133,14 @@ async def build_async_engine_client_from_engine_args( ) # Don't keep the dummy data in memory - assert async_llm is not None await async_llm.reset_mm_cache() - yield async_llm + yield async_renderer, async_llm finally: if async_llm: async_llm.shutdown() + if async_renderer: + async_renderer.shutdown() def build_app( @@ -518,11 +506,8 @@ async def run_server_worker( async with build_async_engine_client( args, client_config=client_config, - ) as engine_client: - # In co-located mode, AsyncLLM implements both RendererClient - # (CPU ops) and EngineClient (inference ops). 
- renderer_client = engine_client - supported_tasks = await renderer_client.get_supported_tasks() + ) as (renderer_client, engine_client): + supported_tasks = await engine_client.get_supported_tasks() logger.info("Supported tasks: %s", supported_tasks) app = build_app(args, supported_tasks) @@ -532,7 +517,7 @@ async def run_server_worker( logger.info( "Starting vLLM API server %d on %s", - engine_client.vllm_config.parallel_config._api_process_rank, + renderer_client.vllm_config.parallel_config._api_process_rank, listen_address, ) shutdown_task = await serve_http( diff --git a/vllm/entrypoints/openai/chat_completion/serving.py b/vllm/entrypoints/openai/chat_completion/serving.py index cdb1ba32a2be..f05b7bb8109e 100644 --- a/vllm/entrypoints/openai/chat_completion/serving.py +++ b/vllm/entrypoints/openai/chat_completion/serving.py @@ -234,8 +234,8 @@ async def render_chat_request( # If the engine is dead, raise the engine's DEAD_ERROR. # This is required for the streaming case, where we return a # success status before we actually start generating text :). - if self.renderer_client.errored: - raise self.renderer_client.dead_error + if self.engine_client.errored: + raise self.engine_client.dead_error try: tokenizer = self.renderer.tokenizer diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 983040a89dcf..e0be36835e60 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -150,9 +150,6 @@ class FrontendArgs: """When `--max-logprobs` is specified, represents single tokens as strings of the form 'token_id:{token_id}' so that tokens that are not JSON-encodable can be identified.""" - disable_frontend_multiprocessing: bool = False - """If specified, will run the OpenAI frontend server in the same process as - the model serving engine.""" enable_request_id_headers: bool = False """If specified, API server will add X-Request-Id header to responses.""" enable_auto_tool_choice: bool = False diff --git a/vllm/entrypoints/openai/completion/serving.py b/vllm/entrypoints/openai/completion/serving.py index 67c9a9a58ecf..8fda80969715 100644 --- a/vllm/entrypoints/openai/completion/serving.py +++ b/vllm/entrypoints/openai/completion/serving.py @@ -97,8 +97,8 @@ async def render_completion_request( # If the engine is dead, raise the engine's DEAD_ERROR. # This is required for the streaming case, where we return a # success status before we actually start generating text :). - if self.renderer_client.errored: - raise self.renderer_client.dead_error + if self.engine_client.errored: + raise self.engine_client.dead_error # Return error for unsupported features. if request.suffix is not None: diff --git a/vllm/entrypoints/openai/responses/serving.py b/vllm/entrypoints/openai/responses/serving.py index b43c0887684c..d39b08a3299f 100644 --- a/vllm/entrypoints/openai/responses/serving.py +++ b/vllm/entrypoints/openai/responses/serving.py @@ -341,8 +341,8 @@ async def create_responses( # If the engine is dead, raise the engine's DEAD_ERROR. # This is required for the streaming case, where we return a # success status before we actually start generating text :). - if self.renderer_client.errored: - raise self.renderer_client.dead_error + if self.engine_client.errored: + raise self.engine_client.dead_error if request.store and not self.enable_store: # Disable the store option. 
diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py index 747025750e45..f9abae10272b 100644 --- a/vllm/entrypoints/openai/run_batch.py +++ b/vllm/entrypoints/openai/run_batch.py @@ -20,7 +20,7 @@ from tqdm import tqdm from vllm.engine.arg_utils import AsyncEngineArgs, optional_type -from vllm.engine.protocol import EngineClient +from vllm.engine.protocol import EngineClient, RendererClient from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.chat_completion.protocol import ( ChatCompletionRequest, @@ -672,6 +672,7 @@ async def transcription_wrapper( def build_endpoint_registry( + renderer_client: RendererClient, engine_client: EngineClient, args: Namespace, base_model_paths: list[BaseModelPath], @@ -682,6 +683,7 @@ def build_endpoint_registry( Build the endpoint registry with all serving objects and handler configurations. Args: + renderer_client: The renderer client engine_client: The engine client args: Command line arguments base_model_paths: List of base model paths @@ -691,10 +693,11 @@ def build_endpoint_registry( Returns: Dictionary mapping endpoint keys to their configurations """ - model_config = engine_client.model_config + model_config = renderer_client.model_config # Create the openai serving objects. openai_serving_models = OpenAIServingModels( + renderer_client=renderer_client, engine_client=engine_client, base_model_paths=base_model_paths, lora_modules=None, @@ -702,9 +705,10 @@ def build_endpoint_registry( openai_serving_chat = ( OpenAIServingChat( - engine_client, - openai_serving_models, - args.response_role, + renderer_client=renderer_client, + engine_client=engine_client, + models=openai_serving_models, + response_role=args.response_role, request_logger=request_logger, chat_template=None, chat_template_content_format="auto", @@ -721,8 +725,9 @@ def build_endpoint_registry( openai_serving_embedding = ( OpenAIServingEmbedding( - engine_client, - openai_serving_models, + renderer_client=renderer_client, + engine_client=engine_client, + models=openai_serving_models, request_logger=request_logger, chat_template=None, chat_template_content_format="auto", @@ -738,8 +743,9 @@ def build_endpoint_registry( openai_serving_scores = ( ServingScores( - engine_client, - openai_serving_models, + renderer_client=renderer_client, + engine_client=engine_client, + models=openai_serving_models, request_logger=request_logger, score_template=None, ) @@ -749,8 +755,9 @@ def build_endpoint_registry( openai_serving_transcription = ( OpenAIServingTranscription( - engine_client, - openai_serving_models, + renderer_client=renderer_client, + engine_client=engine_client, + models=openai_serving_models, request_logger=request_logger, enable_force_include_usage=args.enable_force_include_usage, ) @@ -760,8 +767,9 @@ def build_endpoint_registry( openai_serving_translation = ( OpenAIServingTranslation( - engine_client, - openai_serving_models, + renderer_client=renderer_client, + engine_client=engine_client, + models=openai_serving_models, request_logger=request_logger, enable_force_include_usage=args.enable_force_include_usage, ) @@ -842,6 +850,7 @@ def validate_run_batch_args(args): async def run_batch( + renderer_client: RendererClient, engine_client: EngineClient, args: Namespace, ) -> None: @@ -863,6 +872,7 @@ async def run_batch( logger.info("Supported tasks: %s", supported_tasks) endpoint_registry = build_endpoint_registry( + renderer_client=renderer_client, engine_client=engine_client, args=args, base_model_paths=base_model_paths, @@ 
-927,9 +937,8 @@ async def main(args: Namespace): async with build_async_engine_client( args, usage_context=UsageContext.OPENAI_BATCH_RUNNER, - disable_frontend_multiprocessing=False, - ) as engine_client: - await run_batch(engine_client, args) + ) as (renderer_client, engine_client): + await run_batch(renderer_client, engine_client, args) if __name__ == "__main__": diff --git a/vllm/entrypoints/openai/speech_to_text/speech_to_text.py b/vllm/entrypoints/openai/speech_to_text/speech_to_text.py index a764737289d8..1223ed80faf2 100644 --- a/vllm/entrypoints/openai/speech_to_text/speech_to_text.py +++ b/vllm/entrypoints/openai/speech_to_text/speech_to_text.py @@ -477,8 +477,8 @@ async def _create_speech_to_text( # If the engine is dead, raise the engine's DEAD_ERROR. # This is required for the streaming case, where we return a # success status before we actually start generating text :). - if self.renderer_client.errored: - raise self.renderer_client.dead_error + if self.engine_client.errored: + raise self.engine_client.dead_error if request.response_format not in ["text", "json", "verbose_json"]: return self.create_error_response( diff --git a/vllm/entrypoints/serve/disagg/serving.py b/vllm/entrypoints/serve/disagg/serving.py index 26fc835338fb..4f73df7d69bf 100644 --- a/vllm/entrypoints/serve/disagg/serving.py +++ b/vllm/entrypoints/serve/disagg/serving.py @@ -84,8 +84,8 @@ async def serve_tokens( # If the engine is dead, raise the engine's DEAD_ERROR. # This is required for the streaming case, where we return a # success status before we actually start generating text :). - if self.renderer_client.errored: - raise self.renderer_client.dead_error + if self.engine_client.errored: + raise self.engine_client.dead_error lora_request = None lora_request = self._maybe_get_adapters(request, supports_default_mm_loras=True) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index a4793fe1723c..a06c46a05c7c 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -31,7 +31,6 @@ from vllm.renderers.inputs.preprocess import extract_prompt_components from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.tasks import SupportedTask -from vllm.tokenizers import TokenizerLike from vllm.tracing import init_tracer from vllm.transformers_utils.config import maybe_register_config_serialize_by_value from vllm.usage.usage_lib import UsageContext @@ -67,7 +66,7 @@ def __init__(self, cause: Exception): super().__init__(str(cause)) -class AsyncLLM(RendererClient, EngineClient): +class AsyncLLM(EngineClient): """An asynchronous wrapper for the vLLM engine.""" def __init__( @@ -131,18 +130,18 @@ def __init__( "enabling logging without default stat loggers." ) - self.renderer = renderer = renderer_from_config(self.vllm_config) + self.renderer = renderer_from_config(self.vllm_config) self.io_processor = get_io_processor( self.vllm_config, self.model_config.io_processor_plugin, ) # Convert TokPrompt --> EngineCoreRequest. - self.input_processor = InputProcessor(self.vllm_config, renderer) + self.input_processor = InputProcessor(self.vllm_config, self.renderer) # Converts EngineCoreOutputs --> RequestOutput. 
self.output_processor = OutputProcessor( - renderer.tokenizer, + self.renderer.tokenizer, log_stats=self.log_stats, stream_interval=self.vllm_config.scheduler_config.stream_interval, tracing_enabled=tracing_endpoint is not None, @@ -845,16 +844,6 @@ async def encode( if q is not None: q.close() - @property - def tokenizer(self) -> TokenizerLike | None: - return self.renderer.tokenizer - - def get_tokenizer(self) -> TokenizerLike: - return self.renderer.get_tokenizer() - - async def is_tracing_enabled(self) -> bool: - return self.observability_config.otlp_traces_endpoint is not None - async def do_log_stats(self) -> None: if self.logger_manager: self.logger_manager.log() @@ -1052,3 +1041,58 @@ async def update_weights(self, request: WeightTransferUpdateRequest) -> None: await self.collective_rpc( "update_weights", kwargs={"update_info": update_info_dict} ) + + +class AsyncRenderer(RendererClient): + """Standalone RendererClient built directly from a VllmConfig. + + Owns the renderer, io_processor, and input_processor — all CPU-only + resources. Does not depend on :class:`AsyncLLM` or any inference engine. + In a disaggregated deployment this class would be replaced by a remote stub + that talks to a dedicated renderer process over the network. + """ + + def __init__(self, vllm_config: VllmConfig) -> None: + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.renderer = renderer = renderer_from_config(vllm_config) + self.io_processor = get_io_processor( + vllm_config, + self.model_config.io_processor_plugin, + ) + self.input_processor = InputProcessor(vllm_config, renderer) + self._observability_config = vllm_config.observability_config + + tracing_endpoint = self._observability_config.otlp_traces_endpoint + if tracing_endpoint is not None: + init_tracer("vllm.llm_engine", tracing_endpoint) + + # Client base (liveness) — renderer has no long-running background process + + @property + def is_running(self) -> bool: + return True + + @property + def is_stopped(self) -> bool: + return False + + @property + def errored(self) -> bool: + return False + + @property + def dead_error(self) -> BaseException: + raise RuntimeError("AsyncRenderer has no error state") + + async def check_health(self) -> None: + pass # no background process to check + + def shutdown(self) -> None: + if renderer := getattr(self, "renderer", None): + renderer.shutdown() + + # RendererClient methods --------------------------------------------------- + + async def is_tracing_enabled(self) -> bool: + return self._observability_config.otlp_traces_endpoint is not None From 4afc9248a313c100a058b1ec7133d38a9e3e91c9 Mon Sep 17 00:00:00 2001 From: Sage Ahrac Date: Wed, 25 Feb 2026 17:07:04 +0200 Subject: [PATCH 08/53] fix after merge Signed-off-by: Sage Ahrac --- vllm/entrypoints/openai/cli_args.py | 3 --- vllm/entrypoints/openai/run_batch.py | 2 +- vllm/v1/engine/async_llm.py | 2 -- 3 files changed, 1 insertion(+), 6 deletions(-) diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index eac581e5da9b..e4dcade889d1 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -105,9 +105,6 @@ class BaseFrontendArgs: """When `--max-logprobs` is specified, represents single tokens as strings of the form 'token_id:{token_id}' so that tokens that are not JSON-encodable can be identified.""" - disable_frontend_multiprocessing: bool = False - """If specified, will run the OpenAI frontend server in the same process as - the model serving 
engine.""" enable_auto_tool_choice: bool = False """Enable auto tool choice for supported models. Use `--tool-call-parser` to specify which parser to use.""" diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py index 91a29575b550..8a995cfc8371 100644 --- a/vllm/entrypoints/openai/run_batch.py +++ b/vllm/entrypoints/openai/run_batch.py @@ -676,7 +676,7 @@ async def build_endpoint_registry( # Initialize all serving objects using init_app_state # This provides full functionality including chat template processing, # LoRA support, tool servers, etc. - await init_app_state(engine_client, state, args, supported_tasks) + await init_app_state(renderer_client, engine_client, state, args, supported_tasks) # Get serving objects from state (defaulting to None if not set) openai_serving_chat = getattr(state, "openai_serving_chat", None) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index b011fe41b45f..e8206a456055 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -1097,7 +1097,5 @@ def shutdown(self) -> None: if renderer := getattr(self, "renderer", None): renderer.shutdown() - # RendererClient methods --------------------------------------------------- - async def is_tracing_enabled(self) -> bool: return self._observability_config.otlp_traces_endpoint is not None From cb3d2a8fd20b9dde1c8f19cd9f67cf242d984867 Mon Sep 17 00:00:00 2001 From: Sage Ahrac Date: Wed, 25 Feb 2026 17:22:46 +0200 Subject: [PATCH 09/53] revert disable_frontend_multiprocessing Signed-off-by: Sage Ahrac --- vllm/benchmarks/throughput.py | 3 +++ vllm/entrypoints/openai/api_server.py | 9 +++++++++ vllm/entrypoints/openai/cli_args.py | 3 +++ vllm/entrypoints/openai/run_batch.py | 1 + 4 files changed, 16 insertions(+) diff --git a/vllm/benchmarks/throughput.py b/vllm/benchmarks/throughput.py index be60ae4ca6d6..aa82672600d5 100644 --- a/vllm/benchmarks/throughput.py +++ b/vllm/benchmarks/throughput.py @@ -180,6 +180,7 @@ async def run_vllm_async( n: int, engine_args: AsyncEngineArgs, do_profile: bool, + disable_frontend_multiprocessing: bool = False, disable_detokenize: bool = False, ) -> float: from vllm import SamplingParams @@ -189,6 +190,7 @@ async def run_vllm_async( async with build_async_engine_client_from_engine_args( engine_args, + disable_frontend_multiprocessing=disable_frontend_multiprocessing, ) as (renderer_client, engine_client): model_config = renderer_client.model_config assert all( @@ -854,6 +856,7 @@ def main(args: argparse.Namespace): requests, args.n, AsyncEngineArgs.from_cli_args(args), + disable_frontend_multiprocessing=args.disable_frontend_multiprocessing, disable_detokenize=args.disable_detokenize, do_profile=args.profile, ) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 18677de11093..b4f99b6b0dcc 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -71,6 +71,7 @@ async def build_async_engine_client( args: Namespace, *, usage_context: UsageContext = UsageContext.OPENAI_API_SERVER, + disable_frontend_multiprocessing: bool | None = None, client_config: dict[str, Any] | None = None, ) -> AsyncIterator[tuple[RendererClient, EngineClient]]: if os.getenv("VLLM_WORKER_MULTIPROC_METHOD") == "forkserver": @@ -89,9 +90,13 @@ async def build_async_engine_client( engine_args._api_process_count = client_config.get("client_count", 1) engine_args._api_process_rank = client_config.get("client_index", 0) + if disable_frontend_multiprocessing is 
None: + disable_frontend_multiprocessing = bool(args.disable_frontend_multiprocessing) + async with build_async_engine_client_from_engine_args( engine_args, usage_context=usage_context, + disable_frontend_multiprocessing=disable_frontend_multiprocessing, client_config=client_config, ) as clients: yield clients @@ -102,6 +107,7 @@ async def build_async_engine_client_from_engine_args( engine_args: AsyncEngineArgs, *, usage_context: UsageContext = UsageContext.OPENAI_API_SERVER, + disable_frontend_multiprocessing: bool = False, client_config: dict[str, Any] | None = None, ) -> AsyncIterator[tuple[RendererClient, EngineClient]]: """Create a co-located (RendererClient, EngineClient) pair backed by AsyncLLM.""" @@ -109,6 +115,9 @@ async def build_async_engine_client_from_engine_args( # Create the EngineConfig (determines if we can use V1). vllm_config = engine_args.create_engine_config(usage_context=usage_context) + if disable_frontend_multiprocessing: + logger.warning("V1 is enabled, but got --disable-frontend-multiprocessing.") + from vllm.v1.engine.async_llm import AsyncLLM, AsyncRenderer async_renderer: AsyncRenderer | None = None diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index e4dcade889d1..eac581e5da9b 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -105,6 +105,9 @@ class BaseFrontendArgs: """When `--max-logprobs` is specified, represents single tokens as strings of the form 'token_id:{token_id}' so that tokens that are not JSON-encodable can be identified.""" + disable_frontend_multiprocessing: bool = False + """If specified, will run the OpenAI frontend server in the same process as + the model serving engine.""" enable_auto_tool_choice: bool = False """Enable auto tool choice for supported models. 
     Use `--tool-call-parser` to specify which parser to use."""
diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py
index 8a995cfc8371..2e24df55d8e3 100644
--- a/vllm/entrypoints/openai/run_batch.py
+++ b/vllm/entrypoints/openai/run_batch.py
@@ -825,6 +825,7 @@ async def main(args: Namespace):
     async with build_async_engine_client(
         args,
         usage_context=UsageContext.OPENAI_BATCH_RUNNER,
+        disable_frontend_multiprocessing=False,
     ) as (renderer_client, engine_client):
         await run_batch(renderer_client, engine_client, args)

From 7d267c25691780994c81f6f69ae0b521ea5db853 Mon Sep 17 00:00:00 2001
From: Sage Ahrac
Date: Wed, 25 Feb 2026 20:20:01 +0200
Subject: [PATCH 10/53] revert disable_frontend_multiprocessing flag in benchmarks

Signed-off-by: Sage Ahrac
---
 vllm/benchmarks/throughput.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/vllm/benchmarks/throughput.py b/vllm/benchmarks/throughput.py
index aa82672600d5..41fce28e9835 100644
--- a/vllm/benchmarks/throughput.py
+++ b/vllm/benchmarks/throughput.py
@@ -747,6 +747,12 @@ def add_cli_args(parser: argparse.ArgumentParser):
         default=False,
         help="Use vLLM async engine rather than LLM class.",
     )
+    parser.add_argument(
+        "--disable-frontend-multiprocessing",
+        action="store_true",
+        default=False,
+        help="Disable decoupled async engine frontend.",
+    )

     parser.add_argument(
         "--disable-detokenize",

From 9fb01a74c6d2775d56fbc0682810ac383f407dbc Mon Sep 17 00:00:00 2001
From: Sage Ahrac
Date: Mon, 2 Mar 2026 13:40:27 +0200
Subject: [PATCH 11/53] fix get_world_size

Signed-off-by: Sage Ahrac
---
 vllm/entrypoints/serve/rlhf/api_router.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/vllm/entrypoints/serve/rlhf/api_router.py b/vllm/entrypoints/serve/rlhf/api_router.py
index 64a1dd20fdc7..b4d432bc62e3 100644
--- a/vllm/entrypoints/serve/rlhf/api_router.py
+++ b/vllm/entrypoints/serve/rlhf/api_router.py
@@ -13,7 +13,7 @@
     WeightTransferInitRequest,
     WeightTransferUpdateRequest,
 )
-from vllm.engine.protocol import EngineClient
+from vllm.engine.protocol import EngineClient, RendererClient
 from vllm.logger import init_logger
 from vllm.v1.engine import PauseMode

@@ -24,6 +24,10 @@ def engine_client(request: Request) -> EngineClient:
     return request.app.state.engine_client


+def renderer_client(request: Request) -> RendererClient:
+    return request.app.state.renderer_client
+
+
 router = APIRouter()


@@ -158,7 +162,7 @@ async def get_world_size(
         data parallelism (TP * PP * DP). If False, returns the world size
         without data parallelism (TP * PP).
""" - parallel_config = engine_client(raw_request).vllm_config.parallel_config + parallel_config = renderer_client(raw_request).vllm_config.parallel_config if include_dp: world_size = parallel_config.world_size_across_dp else: From f516e9bd420ed1b3fbb965f01e49a89572901efc Mon Sep 17 00:00:00 2001 From: Sage Ahrac Date: Mon, 2 Mar 2026 14:15:33 +0200 Subject: [PATCH 12/53] asyncllm backward compatibility Signed-off-by: Sage Ahrac --- vllm/v1/engine/async_llm.py | 40 +++++++++++++++++++++++++++++++++++-- 1 file changed, 38 insertions(+), 2 deletions(-) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 685c6fa8c6ae..87c6745642a8 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -32,6 +32,7 @@ from vllm.renderers.inputs.preprocess import extract_prompt_components from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.tasks import SupportedTask +from vllm.tokenizers.protocol import TokenizerLike from vllm.tracing import init_tracer from vllm.transformers_utils.config import maybe_register_config_serialize_by_value from vllm.usage.usage_lib import UsageContext @@ -857,6 +858,34 @@ async def encode( if q is not None: q.close() + @property + def tokenizer(self) -> TokenizerLike | None: + warnings.warn( + "`AsyncLLM.tokenizer` is deprecated and will be removed in a " + "future version. Please use `AsyncRenderer.tokenizer` instead.", + DeprecationWarning, + stacklevel=2, + ) + return self.renderer.tokenizer + + def get_tokenizer(self) -> TokenizerLike: + warnings.warn( + "`AsyncLLM.get_tokenizer()` is deprecated and will be removed in a " + "future version. Please use `AsyncRenderer.get_tokenizer()` instead.", + DeprecationWarning, + stacklevel=2, + ) + return self.renderer.get_tokenizer() + + async def is_tracing_enabled(self) -> bool: + warnings.warn( + "`AsyncLLM.is_tracing_enabled()` is deprecated and will be removed in a " + "future version. 
Please use `AsyncRenderer.is_tracing_enabled()` instead.", + DeprecationWarning, + stacklevel=2, + ) + return self.observability_config.otlp_traces_endpoint is not None + async def do_log_stats(self) -> None: if self.logger_manager: self.logger_manager.log() @@ -1074,12 +1103,12 @@ class AsyncRenderer(RendererClient): def __init__(self, vllm_config: VllmConfig) -> None: self.vllm_config = vllm_config self.model_config = vllm_config.model_config - self.renderer = renderer = renderer_from_config(vllm_config) + self.renderer = renderer_from_config(vllm_config) self.io_processor = get_io_processor( vllm_config, self.model_config.io_processor_plugin, ) - self.input_processor = InputProcessor(vllm_config, renderer) + self.input_processor = InputProcessor(vllm_config, self.renderer) self._observability_config = vllm_config.observability_config tracing_endpoint = self._observability_config.otlp_traces_endpoint @@ -1111,5 +1140,12 @@ def shutdown(self) -> None: if renderer := getattr(self, "renderer", None): renderer.shutdown() + @property + def tokenizer(self) -> TokenizerLike | None: + return self.renderer.tokenizer + + def get_tokenizer(self) -> TokenizerLike: + return self.renderer.get_tokenizer() + async def is_tracing_enabled(self) -> bool: return self._observability_config.otlp_traces_endpoint is not None From b976c3c5cfe85eac0368e61dd8b96b693f3f6d30 Mon Sep 17 00:00:00 2001 From: Sage Ahrac Date: Mon, 2 Mar 2026 14:38:27 +0200 Subject: [PATCH 13/53] render client optional arg in init_app_state Signed-off-by: Sage Ahrac --- vllm/entrypoints/openai/api_server.py | 21 ++++++++++++++++++--- vllm/entrypoints/openai/run_batch.py | 2 +- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index b4f99b6b0dcc..c257bdd80c0e 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -12,7 +12,7 @@ from argparse import Namespace from collections.abc import AsyncIterator from contextlib import asynccontextmanager -from typing import Any +from typing import Any, cast import uvloop from fastapi import FastAPI, HTTPException @@ -286,12 +286,27 @@ def build_app( async def init_app_state( - renderer_client: RendererClient, engine_client: EngineClient, state: State, args: Namespace, supported_tasks: tuple["SupportedTask", ...] | None = None, + renderer_client: RendererClient | None = None, ) -> None: + if renderer_client is None: + # Backward compat: callers that only pass engine_client (e.g. external + # users such as open-instruct). AsyncLLM satisfies the RendererClient + # interface structurally (owns renderer, vllm_config, input_processor). + warnings.warn( + "Calling init_app_state without renderer_client is deprecated " + "and will be removed in a future version. 
" + "Pass the renderer explicitly: " + "init_app_state(engine_client, state, args, " + "renderer_client=renderer_client).", + DeprecationWarning, + stacklevel=2, + ) + renderer_client = cast(RendererClient, engine_client) + vllm_config = renderer_client.vllm_config if supported_tasks is None: warnings.warn( @@ -521,7 +536,7 @@ async def run_server_worker( app = build_app(args, supported_tasks) await init_app_state( - renderer_client, engine_client, app.state, args, supported_tasks + engine_client, app.state, args, supported_tasks, renderer_client ) logger.info( diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py index 2e24df55d8e3..b1621a3ec5e0 100644 --- a/vllm/entrypoints/openai/run_batch.py +++ b/vllm/entrypoints/openai/run_batch.py @@ -676,7 +676,7 @@ async def build_endpoint_registry( # Initialize all serving objects using init_app_state # This provides full functionality including chat template processing, # LoRA support, tool servers, etc. - await init_app_state(renderer_client, engine_client, state, args, supported_tasks) + await init_app_state(engine_client, state, args, supported_tasks, renderer_client) # Get serving objects from state (defaulting to None if not set) openai_serving_chat = getattr(state, "openai_serving_chat", None) From 2adc96c425c0b455e3b8567f70cfe11b33d79ff6 Mon Sep 17 00:00:00 2001 From: Sage Ahrac Date: Mon, 2 Mar 2026 14:51:50 +0200 Subject: [PATCH 14/53] fix io_processor init Signed-off-by: Sage Ahrac --- vllm/v1/engine/async_llm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 87c6745642a8..b6fc739160d7 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -1106,6 +1106,7 @@ def __init__(self, vllm_config: VllmConfig) -> None: self.renderer = renderer_from_config(vllm_config) self.io_processor = get_io_processor( vllm_config, + self.renderer, self.model_config.io_processor_plugin, ) self.input_processor = InputProcessor(vllm_config, self.renderer) From 8dc756fca25b8f997603004b422eb5cf24a78114 Mon Sep 17 00:00:00 2001 From: Sage Ahrac Date: Mon, 2 Mar 2026 16:35:23 +0200 Subject: [PATCH 15/53] Add co-author Co-authored-by: HyunKyun Moon Signed-off-by: Sage Ahrac From 1952ce9a7c842759e838fb0ed826f312cffd6a79 Mon Sep 17 00:00:00 2001 From: Sage Ahrac Date: Tue, 3 Mar 2026 09:30:06 +0200 Subject: [PATCH 16/53] build_async_clients_from_engine_args Signed-off-by: Sage Ahrac --- tests/lora/test_add_lora.py | 4 ++-- tests/lora/test_lora_functions.py | 4 ++-- vllm/benchmarks/throughput.py | 4 ++-- vllm/entrypoints/openai/api_server.py | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/lora/test_add_lora.py b/tests/lora/test_add_lora.py index 38d12fb2d36d..0fb156c48ab9 100644 --- a/tests/lora/test_add_lora.py +++ b/tests/lora/test_add_lora.py @@ -7,7 +7,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs from vllm.entrypoints.openai.api_server import ( - build_async_engine_client_from_engine_args, + build_async_clients_from_engine_args, ) from vllm.inputs import TextPrompt from vllm.lora.request import LoRARequest @@ -86,7 +86,7 @@ async def test_add_lora(chatglm3_lora_files): warmup_run_requests = lora_requests[part_size : part_size * 2] cold_run_requests = lora_requests[part_size * 2 :] - async with build_async_engine_client_from_engine_args(engine_args) as (_, llm): + async with build_async_clients_from_engine_args(engine_args) as (_, llm): # Dummy run - So any 1-time functionality like triton kernel 
compilation # is complete here. await requests_processing_time(llm, dummy_run_requests) diff --git a/tests/lora/test_lora_functions.py b/tests/lora/test_lora_functions.py index c5419e09f74f..6230ad86fab7 100644 --- a/tests/lora/test_lora_functions.py +++ b/tests/lora/test_lora_functions.py @@ -8,7 +8,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs from vllm.entrypoints.openai.api_server import ( - build_async_engine_client_from_engine_args, + build_async_clients_from_engine_args, ) from vllm.lora.request import LoRARequest from vllm.v1.engine.llm_engine import LLMEngine @@ -88,7 +88,7 @@ async def run_check(fn, args, expected: list): await fn(args) assert set(await llm.list_loras()) == set(expected) - async with build_async_engine_client_from_engine_args(engine_args) as (_, llm): + async with build_async_clients_from_engine_args(engine_args) as (_, llm): await run_check(llm.add_lora, make_lora_request(1), [1]) await run_check(llm.add_lora, make_lora_request(2), [1, 2]) diff --git a/vllm/benchmarks/throughput.py b/vllm/benchmarks/throughput.py index 41fce28e9835..f9b9b192fc8b 100644 --- a/vllm/benchmarks/throughput.py +++ b/vllm/benchmarks/throughput.py @@ -185,10 +185,10 @@ async def run_vllm_async( ) -> float: from vllm import SamplingParams from vllm.entrypoints.openai.api_server import ( - build_async_engine_client_from_engine_args, + build_async_clients_from_engine_args, ) - async with build_async_engine_client_from_engine_args( + async with build_async_clients_from_engine_args( engine_args, disable_frontend_multiprocessing=disable_frontend_multiprocessing, ) as (renderer_client, engine_client): diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index c257bdd80c0e..ce90f3e4d67e 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -93,7 +93,7 @@ async def build_async_engine_client( if disable_frontend_multiprocessing is None: disable_frontend_multiprocessing = bool(args.disable_frontend_multiprocessing) - async with build_async_engine_client_from_engine_args( + async with build_async_clients_from_engine_args( engine_args, usage_context=usage_context, disable_frontend_multiprocessing=disable_frontend_multiprocessing, @@ -103,7 +103,7 @@ async def build_async_engine_client( @asynccontextmanager -async def build_async_engine_client_from_engine_args( +async def build_async_clients_from_engine_args( engine_args: AsyncEngineArgs, *, usage_context: UsageContext = UsageContext.OPENAI_API_SERVER, From b8401cde0ebb8ea3896f809fc84d6e7ea5eb830e Mon Sep 17 00:00:00 2001 From: hallerite Date: Mon, 2 Mar 2026 23:32:15 -0800 Subject: [PATCH 17/53] add regression test (#35834) Signed-off-by: hallerite --- .../openai/test_tokenization_vlm.py | 61 +++++++++++++++++++ 1 file changed, 61 insertions(+) create mode 100644 tests/entrypoints/openai/test_tokenization_vlm.py diff --git a/tests/entrypoints/openai/test_tokenization_vlm.py b/tests/entrypoints/openai/test_tokenization_vlm.py new file mode 100644 index 000000000000..c84ac3cf7df7 --- /dev/null +++ b/tests/entrypoints/openai/test_tokenization_vlm.py @@ -0,0 +1,61 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Regression test: ``/tokenize`` must expand image placeholders for VLM models. + +Fixed by PR #34560 ("Move InputPreprocessor into Renderer (2/2)"). +Before that change, ``/tokenize`` returned ~26 tokens for a message with an +image instead of the expected 1451. 
Confirmed broken on 0.15.1 and 0.16.0. +""" + +import json + +import pytest +import requests + +from ...utils import RemoteOpenAIServer + +MODEL_NAME = "Qwen/Qwen2.5-VL-3B-Instruct" + + +@pytest.fixture(scope="module") +def server(): + args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "4096", + "--max-num-seqs", + "5", + "--enforce-eager", + "--limit-mm-per-prompt", + json.dumps({"image": 1}), + ] + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +def test_tokenize_chat_expands_image_placeholders( + server: RemoteOpenAIServer, + local_asset_server, +): + image_url = local_asset_server.url_for("stop_sign.jpg") + messages = [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": "Describe this image."}, + ], + } + ] + + response = requests.post( + server.url_for("tokenize"), + json={"model": MODEL_NAME, "messages": messages}, + ) + response.raise_for_status() + + # stop_sign.jpg (1300x876) produces 1451 tokens after expansion. + # Without expansion the count would be ~26 (text + one placeholder). + assert response.json()["count"] == 1451 From fd0298441140b61b469c45b0304cabf860539189 Mon Sep 17 00:00:00 2001 From: Sage Ahrac Date: Tue, 3 Mar 2026 09:35:00 +0200 Subject: [PATCH 18/53] decouple AsyncRenderer to a separate file Signed-off-by: Sage Ahrac --- vllm/v1/engine/async_llm.py | 63 +--------------------------- vllm/v1/engine/async_renderer.py | 70 ++++++++++++++++++++++++++++++++ 2 files changed, 71 insertions(+), 62 deletions(-) create mode 100644 vllm/v1/engine/async_renderer.py diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index b6fc739160d7..e390b4261658 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -19,7 +19,7 @@ WeightTransferUpdateRequest, ) from vllm.engine.arg_utils import AsyncEngineArgs -from vllm.engine.protocol import EngineClient, RendererClient, StreamingInput +from vllm.engine.protocol import EngineClient, StreamingInput from vllm.entrypoints.serve.elastic_ep.middleware import set_scaling_elastic_ep from vllm.inputs import ProcessorInputs, PromptType from vllm.logger import init_logger @@ -1089,64 +1089,3 @@ async def update_weights(self, request: WeightTransferUpdateRequest) -> None: await self.collective_rpc( "update_weights", kwargs={"update_info": update_info_dict} ) - - -class AsyncRenderer(RendererClient): - """Standalone RendererClient built directly from a VllmConfig. - - Owns the renderer, io_processor, and input_processor — all CPU-only - resources. Does not depend on :class:`AsyncLLM` or any inference engine. - In a disaggregated deployment this class would be replaced by a remote stub - that talks to a dedicated renderer process over the network. 
- """ - - def __init__(self, vllm_config: VllmConfig) -> None: - self.vllm_config = vllm_config - self.model_config = vllm_config.model_config - self.renderer = renderer_from_config(vllm_config) - self.io_processor = get_io_processor( - vllm_config, - self.renderer, - self.model_config.io_processor_plugin, - ) - self.input_processor = InputProcessor(vllm_config, self.renderer) - self._observability_config = vllm_config.observability_config - - tracing_endpoint = self._observability_config.otlp_traces_endpoint - if tracing_endpoint is not None: - init_tracer("vllm.llm_engine", tracing_endpoint) - - # Client base (liveness) — renderer has no long-running background process - - @property - def is_running(self) -> bool: - return True - - @property - def is_stopped(self) -> bool: - return False - - @property - def errored(self) -> bool: - return False - - @property - def dead_error(self) -> BaseException: - raise RuntimeError("AsyncRenderer has no error state") - - async def check_health(self) -> None: - pass # no background process to check - - def shutdown(self) -> None: - if renderer := getattr(self, "renderer", None): - renderer.shutdown() - - @property - def tokenizer(self) -> TokenizerLike | None: - return self.renderer.tokenizer - - def get_tokenizer(self) -> TokenizerLike: - return self.renderer.get_tokenizer() - - async def is_tracing_enabled(self) -> bool: - return self._observability_config.otlp_traces_endpoint is not None diff --git a/vllm/v1/engine/async_renderer.py b/vllm/v1/engine/async_renderer.py new file mode 100644 index 000000000000..258957cf5be9 --- /dev/null +++ b/vllm/v1/engine/async_renderer.py @@ -0,0 +1,70 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from vllm.config import VllmConfig +from vllm.engine.protocol import RendererClient +from vllm.plugins.io_processors import get_io_processor +from vllm.renderers import renderer_from_config +from vllm.tokenizers.protocol import TokenizerLike +from vllm.tracing import init_tracer +from vllm.v1.engine.input_processor import InputProcessor + + +class AsyncRenderer(RendererClient): + """Standalone RendererClient built directly from a VllmConfig. + + Owns the renderer, io_processor, and input_processor — all CPU-only + resources. Does not depend on :class:`AsyncLLM` or any inference engine. + In a disaggregated deployment this class would be replaced by a remote stub + that talks to a dedicated renderer process over the network. 
+ """ + + def __init__(self, vllm_config: VllmConfig) -> None: + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.renderer = renderer_from_config(vllm_config) + self.io_processor = get_io_processor( + vllm_config, + self.renderer, + self.model_config.io_processor_plugin, + ) + self.input_processor = InputProcessor(vllm_config, self.renderer) + self._observability_config = vllm_config.observability_config + + tracing_endpoint = self._observability_config.otlp_traces_endpoint + if tracing_endpoint is not None: + init_tracer("vllm.llm_engine", tracing_endpoint) + + # Client base (liveness) — renderer has no long-running background process + + @property + def is_running(self) -> bool: + return True + + @property + def is_stopped(self) -> bool: + return False + + @property + def errored(self) -> bool: + return False + + @property + def dead_error(self) -> BaseException: + raise RuntimeError("AsyncRenderer has no error state") + + async def check_health(self) -> None: + pass # no background process to check + + def shutdown(self) -> None: + if renderer := getattr(self, "renderer", None): + renderer.shutdown() + + @property + def tokenizer(self) -> TokenizerLike | None: + return self.renderer.tokenizer + + def get_tokenizer(self) -> TokenizerLike: + return self.renderer.get_tokenizer() + + async def is_tracing_enabled(self) -> bool: + return self._observability_config.otlp_traces_endpoint is not None From 8d3f4eeab83c080b47d4a0fd14bb89e872998796 Mon Sep 17 00:00:00 2001 From: Sage Ahrac Date: Tue, 3 Mar 2026 09:42:46 +0200 Subject: [PATCH 19/53] import fix Signed-off-by: Sage Ahrac --- vllm/entrypoints/openai/api_server.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index ce90f3e4d67e..8f70f8abf5f0 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -118,7 +118,8 @@ async def build_async_clients_from_engine_args( if disable_frontend_multiprocessing: logger.warning("V1 is enabled, but got --disable-frontend-multiprocessing.") - from vllm.v1.engine.async_llm import AsyncLLM, AsyncRenderer + from vllm.v1.engine.async_llm import AsyncLLM + from vllm.v1.engine.async_renderer import AsyncRenderer async_renderer: AsyncRenderer | None = None async_llm: AsyncLLM | None = None From 4beebfd14650b1c6a687e7ab496d501423a0e50d Mon Sep 17 00:00:00 2001 From: Szymon Reginis Date: Tue, 3 Mar 2026 12:48:24 +0100 Subject: [PATCH 20/53] [CI/Build][Intel] Add new performance benchmarks for Intel Gaudi 3 (#31025) Signed-off-by: Szymon Reginis Co-authored-by: Kunshang Ji --- .../tests/latency-tests-hpu.json | 51 ++++++++++++ .../tests/serving-tests-hpu.json | 79 +++++++++++++++++++ .../tests/throughput-tests-hpu.json | 62 +++++++++++++++ 3 files changed, 192 insertions(+) diff --git a/.buildkite/performance-benchmarks/tests/latency-tests-hpu.json b/.buildkite/performance-benchmarks/tests/latency-tests-hpu.json index 296380f72a66..3b3fb4bed801 100644 --- a/.buildkite/performance-benchmarks/tests/latency-tests-hpu.json +++ b/.buildkite/performance-benchmarks/tests/latency-tests-hpu.json @@ -51,5 +51,56 @@ "max-model-len": 256, "async-scheduling": "" } + }, + { + "test_name": "latency_deepseek_r1", + "environment_variables": { + "PT_HPU_LAZY_MODE": 1, + "PT_HPU_ENABLE_LAZY_COLLECTIVES": 1, + "VLLM_CONTIGUOUS_PA": 1, + "VLLM_DEFRAG": 1 + }, + "parameters": { + "model": "deepseek-ai/DeepSeek-R1", + "tensor_parallel_size": 8, + "load_format": 
"dummy", + "max-model-len": 2048, + "dtype": "bfloat16" + } + }, + { + "test_name": "latency_llama4_maverick_17b128e_instruct_fp8", + "environment_variables": { + "PT_HPU_LAZY_MODE": 1, + "PT_HPU_ENABLE_LAZY_COLLECTIVES": 1, + "VLLM_CONTIGUOUS_PA": 1, + "VLLM_DEFRAG": 1 + }, + "parameters": { + "model": "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8", + "tensor_parallel_size": 8, + "max-model-len": 512, + "max-num-seqs": 128, + "async-scheduling": "", + "gpu-memory-utilization": 0.95, + "enable_expert_parallel": "" + } + }, + { + "test_name": "latency_qwen3_8b", + "environment_variables": { + "PT_HPU_LAZY_MODE": 1, + "PT_HPU_ENABLE_LAZY_COLLECTIVES": 1, + "VLLM_CONTIGUOUS_PA": 1, + "VLLM_DEFRAG": 1 + }, + "parameters": { + "model": "Qwen/Qwen3-8B", + "tensor_parallel_size": 1, + "max-model-len": 2048, + "max-num-seqs": 128, + "dtype": "bfloat16", + "async-scheduling": "" + } } ] diff --git a/.buildkite/performance-benchmarks/tests/serving-tests-hpu.json b/.buildkite/performance-benchmarks/tests/serving-tests-hpu.json index 8c6b34bd9fa3..a2e42aa16fd3 100644 --- a/.buildkite/performance-benchmarks/tests/serving-tests-hpu.json +++ b/.buildkite/performance-benchmarks/tests/serving-tests-hpu.json @@ -78,5 +78,84 @@ "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", "num_prompts": 200 } + }, + { + "test_name": "serving_deepseek_r1", + "qps_list": [1, 4, 16, "inf"], + "server_environment_variables": { + "PT_HPU_LAZY_MODE": 1, + "PT_HPU_ENABLE_LAZY_COLLECTIVES": 1, + "VLLM_CONTIGUOUS_PA": 1, + "VLLM_DEFRAG": 1 + }, + "server_parameters": { + "model": "deepseek-ai/DeepSeek-R1", + "tensor_parallel_size": 8, + "swap_space": 16, + "disable_log_stats": "", + "load_format": "dummy", + "max-model-len": 2048, + "max-num-seqs": 200, + "async-scheduling": "", + "dtype": "bfloat16" + }, + "client_parameters": { + "model": "deepseek-ai/DeepSeek-R1", + "backend": "vllm", + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200 + } + }, + { + "test_name": "serving_llama4_maverick_17b128e_instruct_fp8", + "qps_list": [1, 4, 16, "inf"], + "server_environment_variables": { + "PT_HPU_LAZY_MODE": 1, + "PT_HPU_ENABLE_LAZY_COLLECTIVES": 1, + "VLLM_CONTIGUOUS_PA": 1, + "VLLM_DEFRAG": 1 + }, + "server_parameters": { + "model": "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8", + "tensor_parallel_size": 8, + "disable_log_stats": "", + "max-model-len": 2048, + "max-num-seqs": 128, + "async-scheduling": "", + "enable_expert_parallel": "", + "max-num-batched-tokens": 4096 + }, + "client_parameters": { + "model": "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8", + "backend": "vllm", + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200 + } + }, + { + "test_name": "serving_qwen3_8b", + "qps_list": [1, 4, 10, "inf"], + "server_environment_variables": { + "PT_HPU_LAZY_MODE": 1, + "PT_HPU_ENABLE_LAZY_COLLECTIVES": 1, + "VLLM_CONTIGUOUS_PA": 1, + "VLLM_DEFRAG": 1 + }, + "server_parameters": { + "model": "Qwen/Qwen-3-8B", + "tensor_parallel_size": 1, + "dtype": "bfloat16", + "disable_log_stats": "", + "async-scheduling": "" + }, + "client_parameters": { + "model": "Qwen/Qwen-3-8B", + "backend": "vllm", + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200 + } } ] diff --git a/.buildkite/performance-benchmarks/tests/throughput-tests-hpu.json b/.buildkite/performance-benchmarks/tests/throughput-tests-hpu.json index 
3127bf2f6bce..25344348bb39 100644 --- a/.buildkite/performance-benchmarks/tests/throughput-tests-hpu.json +++ b/.buildkite/performance-benchmarks/tests/throughput-tests-hpu.json @@ -57,5 +57,67 @@ "max-num-seqs": 512, "async-scheduling": "" } + }, + { + "test_name": "throughput_deepseek_r1", + "environment_variables": { + "PT_HPU_LAZY_MODE": 1, + "PT_HPU_ENABLE_LAZY_COLLECTIVES": 1, + "VLLM_CONTIGUOUS_PA": 1, + "VLLM_DEFRAG": 1 + }, + "parameters": { + "model": "deepseek-ai/DeepSeek-R1", + "tensor_parallel_size": 8, + "load_format": "dummy", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "dataset_name": "sharegpt", + "num_prompts": 1000, + "backend": "vllm", + "max-model-len": 2048, + "max-num-seqs": 384, + "async-scheduling": "" + } + }, + { + "test_name": "throughput_llama4_maverick_17b128e_instruct_fp8", + "environment_variables": { + "PT_HPU_LAZY_MODE": 1, + "PT_HPU_ENABLE_LAZY_COLLECTIVES": 1, + "VLLM_CONTIGUOUS_PA": 1, + "VLLM_DEFRAG": 1 + }, + "parameters": { + "model": "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8", + "tensor_parallel_size": 8, + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "dataset_name": "sharegpt", + "num_prompts": 1000, + "backend": "vllm", + "max-model-len": 2048, + "max-num-seqs": 512, + "async-scheduling": "", + "enable_expert_parallel": "" + } + }, + { + "test_name": "throughput_qwen3_8b", + "environment_variables": { + "PT_HPU_LAZY_MODE": 1, + "PT_HPU_ENABLE_LAZY_COLLECTIVES": 1, + "VLLM_CONTIGUOUS_PA": 1, + "VLLM_DEFRAG": 1 + }, + "parameters": { + "model": "Qwen/Qwen-3-8B", + "tensor_parallel_size": 1, + "load_format": "dummy", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "dataset_name": "sharegpt", + "num_prompts": 1000, + "max-num-seqs": 512, + "backend": "vllm", + "async-scheduling": "" + } } ] From ad9d09e2b8a601b50d07c76fb8736c2bbda2d6fb Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Tue, 3 Mar 2026 13:15:43 +0100 Subject: [PATCH 21/53] [Perf] [Hybrid] Copy num_accepted_tokens in non-blocking way when not using prefix caching (#35442) Signed-off-by: Thomas Parnell --- vllm/v1/worker/gpu_model_runner.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 8b818f67c3d2..c9d9ecf4ac67 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1191,13 +1191,14 @@ def _update_states_after_model_execute( return # Find the number of accepted tokens for each sequence. 
- num_accepted_tokens = ( + num_reqs = output_token_ids.size(0) + self.num_accepted_tokens.gpu[:num_reqs] = ( ( torch.cat( [ output_token_ids, torch.full( - (output_token_ids.size(0), 1), + (num_reqs, 1), -1, device=output_token_ids.device, ), @@ -1208,12 +1209,13 @@ def _update_states_after_model_execute( ) .int() .argmax(-1) - .cpu() - .numpy() ) - for i, num_tokens in enumerate(num_accepted_tokens): - self.input_batch.num_accepted_tokens_cpu[i] = num_tokens if self.cache_config.mamba_cache_mode == "align": + for i, num_tokens in enumerate( + self.num_accepted_tokens.gpu[:num_reqs].cpu().numpy() + ): + self.input_batch.num_accepted_tokens_cpu[i] = num_tokens + mamba_utils.postprocess_mamba( scheduler_output, self.kv_cache_config, @@ -1224,6 +1226,10 @@ def _update_states_after_model_execute( self.model.get_mamba_state_copy_func(), self._get_mamba_copy_bufs(), ) + else: + self.input_batch.num_accepted_tokens_cpu_tensor[:num_reqs].copy_( + self.num_accepted_tokens.gpu[:num_reqs], non_blocking=True + ) def _update_streaming_request( self, req_id: str, new_req_data: NewRequestData From fd4a90f337f7fe188581d71d4d3ec712767320c0 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Tue, 3 Mar 2026 21:15:51 +0800 Subject: [PATCH 22/53] [CI] And PPL test for Qwen3.5. (#35853) Signed-off-by: wang.yuqi Signed-off-by: wang.yuqi Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- .../language/generation_ppl_test/test_gemma.py | 6 +++--- .../language/generation_ppl_test/test_gpt.py | 2 +- .../language/generation_ppl_test/test_qwen.py | 18 ++++++++++++------ 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/tests/models/language/generation_ppl_test/test_gemma.py b/tests/models/language/generation_ppl_test/test_gemma.py index 5324de143d67..b846bb702064 100644 --- a/tests/models/language/generation_ppl_test/test_gemma.py +++ b/tests/models/language/generation_ppl_test/test_gemma.py @@ -7,9 +7,9 @@ from .ppl_utils import wikitext_ppl_test MODELS = [ - GenerateModelInfo("google/gemma-2b"), - GenerateModelInfo("google/gemma-2-2b"), - GenerateModelInfo("google/gemma-3-4b-it"), + GenerateModelInfo("google/gemma-2b", hf_ppl=21.48524284362793), + GenerateModelInfo("google/gemma-2-2b", hf_ppl=102.59290313720703), + GenerateModelInfo("google/gemma-3-4b-it", hf_ppl=27.79648208618164), ] diff --git a/tests/models/language/generation_ppl_test/test_gpt.py b/tests/models/language/generation_ppl_test/test_gpt.py index f3f9e55a2423..784f3e85a138 100644 --- a/tests/models/language/generation_ppl_test/test_gpt.py +++ b/tests/models/language/generation_ppl_test/test_gpt.py @@ -6,7 +6,7 @@ from .ppl_utils import wikitext_ppl_test -MODELS = [GenerateModelInfo("openai-community/gpt2-large")] +MODELS = [GenerateModelInfo("openai-community/gpt2-large", hf_ppl=19.457056045532227)] @pytest.mark.parametrize("model_info", MODELS) diff --git a/tests/models/language/generation_ppl_test/test_qwen.py b/tests/models/language/generation_ppl_test/test_qwen.py index 0d3127cbaac4..60e69c3f87a4 100644 --- a/tests/models/language/generation_ppl_test/test_qwen.py +++ b/tests/models/language/generation_ppl_test/test_qwen.py @@ -8,14 +8,20 @@ from .ppl_utils import wikitext_ppl_test MODELS = [ - GenerateModelInfo("Qwen/Qwen3-0.6B"), - GenerateModelInfo("Qwen/Qwen3-0.6B-FP8"), - # transformers: - # Loading a GPTQ quantized model requires optimum, gptqmodel - # GenerateModelInfo("Qwen/Qwen3-0.6B-GPTQ-Int8"), + # for Qwen3 + GenerateModelInfo("Qwen/Qwen3-0.6B", 
hf_ppl=23.864173889160156), + GenerateModelInfo("Qwen/Qwen3-0.6B-FP8", hf_ppl=24.313045501708984), + # for Qwen3.5 + GenerateModelInfo("Qwen/Qwen3.5-0.8B", hf_ppl=19.38858413696289), ] @pytest.mark.parametrize("model_info", MODELS) def test_ppl(hf_runner, vllm_runner, model_info: GenerateModelInfo): - wikitext_ppl_test(hf_runner, vllm_runner, model_info) + vllm_extra_kwargs = {} + if model_info.name == "Qwen/Qwen3.5-0.8B": + vllm_extra_kwargs["language_model_only"] = True + + wikitext_ppl_test( + hf_runner, vllm_runner, model_info, vllm_extra_kwargs=vllm_extra_kwargs + ) From 440f0e7dc6cb0adfc9c3c98076939668b90c4bf2 Mon Sep 17 00:00:00 2001 From: "Li, Jiang" Date: Tue, 3 Mar 2026 21:56:08 +0800 Subject: [PATCH 23/53] [Bugfix] Avoid src/dst as None in irecv/isend_tensor_dict (#35754) Signed-off-by: jiang1.li --- .../run-cpu-distributed-smoke-test.sh | 25 ++++++++++++++++--- vllm/distributed/parallel_state.py | 17 +++++++------ 2 files changed, 30 insertions(+), 12 deletions(-) diff --git a/.buildkite/scripts/hardware_ci/run-cpu-distributed-smoke-test.sh b/.buildkite/scripts/hardware_ci/run-cpu-distributed-smoke-test.sh index 3caa49832c3f..f289a43c6be4 100644 --- a/.buildkite/scripts/hardware_ci/run-cpu-distributed-smoke-test.sh +++ b/.buildkite/scripts/hardware_ci/run-cpu-distributed-smoke-test.sh @@ -1,26 +1,43 @@ #!/bin/bash set -euox pipefail +export VLLM_CPU_CI_ENV=0 echo "--- PP+TP" vllm serve meta-llama/Llama-3.2-3B-Instruct -tp=2 -pp=2 & server_pid=$! -timeout 600 bash -c "until curl localhost:8000/v1/models; do sleep 1; done" || exit 1 +timeout 600 bash -c "until curl localhost:8000/v1/models > /dev/null 2>&1; do sleep 1; done" || exit 1 vllm bench serve \ --backend vllm \ --dataset-name random \ --model meta-llama/Llama-3.2-3B-Instruct \ --num-prompts 20 \ + --result-dir ./test_results \ + --result-filename tp_pp.json \ + --save-result \ --endpoint /v1/completions -kill -s SIGTERM $server_pid & +kill -s SIGTERM $server_pid; wait $server_pid || true +failed_req=$(jq '.failed' ./test_results/tp_pp.json) +if [ "$failed_req" -ne 0 ]; then + echo "Some requests were failed!" + exit 1 +fi echo "--- DP+TP" vllm serve meta-llama/Llama-3.2-3B-Instruct -tp=2 -dp=2 & server_pid=$! -timeout 600 bash -c "until curl localhost:8000/v1/models; do sleep 1; done" || exit 1 +timeout 600 bash -c "until curl localhost:8000/v1/models > /dev/null 2>&1; do sleep 1; done" || exit 1 vllm bench serve \ --backend vllm \ --dataset-name random \ --model meta-llama/Llama-3.2-3B-Instruct \ --num-prompts 20 \ + --result-dir ./test_results \ + --result-filename dp_pp.json \ + --save-result \ --endpoint /v1/completions -kill -s SIGTERM $server_pid & +kill -s SIGTERM $server_pid; wait $server_pid || true +failed_req=$(jq '.failed' ./test_results/dp_pp.json) +if [ "$failed_req" -ne 0 ]; then + echo "Some requests were failed!" 
+ exit 1 +fi diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 40b797a1a8d9..fc554bd75694 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -851,6 +851,10 @@ def isend_tensor_dict( if self.world_size <= 1: return [] + if dst is None: + dst = (self.rank_in_group + 1) % self.world_size + assert dst < self.world_size, f"Invalid dst rank ({dst})" + if self.use_cpu_custom_send_recv: if self.device_communicator is None: raise ValueError("No device communicator found") @@ -868,10 +872,6 @@ def isend_tensor_dict( group = self.device_group metadata_group = self.cpu_group - if dst is None: - dst = (self.rank_in_group + 1) % self.world_size - assert dst < self.world_size, f"Invalid dst rank ({dst})" - metadata_list, tensor_list = _split_tensor_dict(tensor_dict) self.send_object(metadata_list, dst=dst) @@ -948,6 +948,11 @@ def irecv_tensor_dict( ]: if not torch.distributed.is_initialized() or self.world_size == 1: return None, [], [] + + if src is None: + src = (self.rank_in_group - 1) % self.world_size + assert src < self.world_size, f"Invalid src rank ({src})" + if self.use_cpu_custom_send_recv: if self.device_communicator is None: raise ValueError("No device communicator found") @@ -965,10 +970,6 @@ def irecv_tensor_dict( group = self.device_group metadata_group = self.cpu_group - if src is None: - src = (self.rank_in_group - 1) % self.world_size - assert src < self.world_size, f"Invalid src rank ({src})" - recv_metadata_list = self.recv_object(src=src) tensor_dict: dict[str, Any] = {} handles: list[Handle] = [] From ea463978bb987a4c15c9b51c0013d620a722aa67 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Tue, 3 Mar 2026 22:05:36 +0800 Subject: [PATCH 24/53] [Frontend][1/n] Improve pooling entrypoints | classify. 
(#35604) Signed-off-by: wang.yuqi Signed-off-by: wang.yuqi Co-authored-by: Cyrus Leung --- vllm/entrypoints/chat_utils.py | 8 + vllm/entrypoints/llm.py | 92 +++-- vllm/entrypoints/openai/engine/serving.py | 21 +- vllm/entrypoints/pooling/base/io_processor.py | 189 +++++++++ vllm/entrypoints/pooling/base/serving.py | 378 ++++++++++++++++++ .../pooling/classify/api_router.py | 31 +- .../pooling/classify/io_processor.py | 50 +++ vllm/entrypoints/pooling/classify/serving.py | 136 ++----- .../pooling/io_processor_factories.py | 31 ++ vllm/entrypoints/pooling/typing.py | 51 +++ vllm/entrypoints/sagemaker/api_router.py | 3 +- vllm/entrypoints/utils.py | 71 +++- 12 files changed, 890 insertions(+), 171 deletions(-) create mode 100644 vllm/entrypoints/pooling/base/io_processor.py create mode 100644 vllm/entrypoints/pooling/base/serving.py create mode 100644 vllm/entrypoints/pooling/classify/io_processor.py create mode 100644 vllm/entrypoints/pooling/io_processor_factories.py create mode 100644 vllm/entrypoints/pooling/typing.py diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index c48d7bea983c..1d10aa6b09e7 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -7,6 +7,7 @@ from abc import ABC, abstractmethod from collections import Counter, defaultdict from collections.abc import Awaitable, Callable, Iterable +from dataclasses import dataclass from functools import cached_property, lru_cache, partial from itertools import accumulate from pathlib import Path @@ -1024,6 +1025,13 @@ def parse_video(self, video_url: str | None, uuid: str | None = None) -> None: self._add_placeholder("video", placeholder) +@dataclass +class ChatTemplateConfig: + chat_template: str | None = None + chat_template_content_format: ChatTemplateContentFormatOption = "auto" + trust_request_chat_template: bool = False + + def validate_chat_template(chat_template: Path | str | None): """Raises if the provided chat template appears invalid.""" if chat_template is None: diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index b3260f9144ec..d5a51a6b95c7 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -3,6 +3,7 @@ import itertools from collections.abc import Callable, Iterable, Sequence +from pathlib import Path from typing import TYPE_CHECKING, Any import cloudpickle @@ -40,8 +41,11 @@ from vllm.engine.arg_utils import EngineArgs from vllm.entrypoints.chat_utils import ( ChatCompletionMessageParam, + ChatTemplateConfig, ChatTemplateContentFormatOption, + load_chat_template, ) +from vllm.entrypoints.pooling.io_processor_factories import init_pooling_io_processors from vllm.entrypoints.pooling.score.utils import ( ScoreData, ScoreMultiModalParam, @@ -145,6 +149,7 @@ class LLM: a tag name, or a commit id. tokenizer_revision: The specific tokenizer version to use. It can be a branch name, a tag name, or a commit id. + chat_template: The chat template to apply. seed: The seed to initialize the random number generator for sampling. gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache. 
Higher @@ -232,6 +237,7 @@ def __init__( quantization: QuantizationMethods | None = None, revision: str | None = None, tokenizer_revision: str | None = None, + chat_template: Path | str | None = None, seed: int = 0, gpu_memory_utilization: float = 0.9, swap_space: float = 4, @@ -384,9 +390,16 @@ def _make_config(value: Any, cls: type[_R]) -> _R: self.model_config = self.llm_engine.model_config self.renderer = self.llm_engine.renderer + self.chat_template = load_chat_template(chat_template) self.io_processor = self.llm_engine.io_processor self.input_processor = self.llm_engine.input_processor - + self.chat_template_config = ChatTemplateConfig(chat_template=self.chat_template) + self.init_pooling_io_processors = init_pooling_io_processors( + supported_tasks=supported_tasks, + model_config=self.model_config, + renderer=self.renderer, + chat_template_config=self.chat_template_config, + ) # Cache for __repr__ to avoid repeated collective_rpc calls self._cached_repr: str | None = None @@ -1086,7 +1099,7 @@ def encode( "pooling model." ) - if use_io_processor := (isinstance(prompts, dict) and "data" in prompts): + if isinstance(prompts, dict) and "data" in prompts: if self.io_processor is None: raise ValueError( "No IOProcessor plugin installed. Please refer " @@ -1120,6 +1133,31 @@ def encode( for p in params_seq: if p.task is None: p.task = "plugin" + + outputs = self._run_completion( + prompts=prompts_seq, + params=params_seq, + output_type=PoolingRequestOutput, + use_tqdm=use_tqdm, + lora_request=lora_request, + tokenization_kwargs=tokenization_kwargs, + ) + + # get the post-processed model outputs + assert self.io_processor is not None + processed_outputs = self.io_processor.post_process(outputs) + + return [ + PoolingRequestOutput[Any]( + request_id="", + outputs=processed_outputs, + num_cached_tokens=getattr( + processed_outputs, "num_cached_tokens", 0 + ), + prompt_token_ids=[], + finished=True, + ) + ] else: if pooling_params is None: # Use default pooling params. 
@@ -1137,32 +1175,36 @@ def encode( ) raise ValueError(msg) - outputs = self._run_completion( - prompts=prompts_seq, - params=params_seq, - output_type=PoolingRequestOutput, - use_tqdm=use_tqdm, - lora_request=lora_request, - tokenization_kwargs=tokenization_kwargs, - ) - - if use_io_processor: - # get the post-processed model outputs - assert self.io_processor is not None - processed_outputs = self.io_processor.post_process(outputs) + if pooling_task in self.init_pooling_io_processors: + io_processor = self.init_pooling_io_processors[pooling_task] + processor_inputs = io_processor.pre_process_offline( + prompts_seq, tokenization_kwargs + ) + seq_lora_requests = self._lora_request_to_seq( + lora_request, len(prompts_seq) + ) + seq_priority = self._priority_to_seq(None, len(prompts)) - return [ - PoolingRequestOutput[Any]( - request_id="", - outputs=processed_outputs, - num_cached_tokens=getattr( - processed_outputs, "num_cached_tokens", 0 - ), - prompt_token_ids=[], - finished=True, + self._render_and_add_requests( + prompts=processor_inputs, + params=params_seq, + lora_requests=seq_lora_requests, + priorities=seq_priority, ) - ] + outputs = self._run_engine( + use_tqdm=use_tqdm, output_type=PoolingRequestOutput + ) + outputs = io_processor.post_process(outputs) + else: + outputs = self._run_completion( + prompts=prompts_seq, + params=params_seq, + output_type=PoolingRequestOutput, + use_tqdm=use_tqdm, + lora_request=lora_request, + tokenization_kwargs=tokenization_kwargs, + ) return outputs def embed( diff --git a/vllm/entrypoints/openai/engine/serving.py b/vllm/entrypoints/openai/engine/serving.py index 3e376ba9c704..e864f562ee1e 100644 --- a/vllm/entrypoints/openai/engine/serving.py +++ b/vllm/entrypoints/openai/engine/serving.py @@ -62,11 +62,6 @@ TranscriptionResponse, TranslationRequest, ) -from vllm.entrypoints.pooling.classify.protocol import ( - ClassificationChatRequest, - ClassificationCompletionRequest, - ClassificationResponse, -) from vllm.entrypoints.pooling.embed.protocol import ( EmbeddingBytesResponse, EmbeddingChatRequest, @@ -161,7 +156,6 @@ def build_chat_params( | TokenizeCompletionRequest | DetokenizeRequest | EmbeddingCompletionRequest - | ClassificationCompletionRequest | RerankRequest | ScoreRequest | PoolingCompletionRequest @@ -171,7 +165,6 @@ def build_chat_params( ChatCompletionRequest | TokenizeChatRequest | EmbeddingChatRequest - | ClassificationChatRequest | PoolingChatRequest ) @@ -194,12 +187,10 @@ def build_chat_params( | TranscriptionResponse | TokenizeResponse | PoolingResponse - | ClassificationResponse | ScoreResponse | GenerateResponse ) - RequestT = TypeVar("RequestT", bound=AnyRequest) @@ -223,8 +214,8 @@ class ServeContext(Generic[RequestT]): class OpenAIServing: request_id_prefix: ClassVar[str] = """ - A short string prepended to every request’s ID (e.g. "embd", "classify") - so you can easily tell “this ID came from Embedding vs Classification.” + A short string prepended to every request’s ID (e.g. "embd") + so you can easily tell “this ID came from Embedding.” """ def __init__( @@ -456,7 +447,7 @@ async def _preprocess( ) -> ErrorResponse | None: """ Default preprocessing hook. Subclasses may override - to prepare `ctx` (classification, embedding, etc.). + to prepare `ctx` (embedding, etc.). 
""" return None @@ -817,7 +808,7 @@ def _validate_input( token_num = len(input_ids) max_model_len = self.model_config.max_model_len - # Note: EmbeddingRequest, ClassificationRequest, + # Note: EmbeddingRequest, # and ScoreRequest doesn't have max_tokens if isinstance( request, @@ -828,8 +819,6 @@ def _validate_input( ScoreTextRequest, ScoreQueriesDocumentsRequest, RerankRequest, - ClassificationCompletionRequest, - ClassificationChatRequest, ), ): # Note: input length can be up to the entire model context length @@ -839,8 +828,6 @@ def _validate_input( ScoreDataRequest: "score", ScoreTextRequest: "score", ScoreQueriesDocumentsRequest: "score", - ClassificationCompletionRequest: "classification", - ClassificationChatRequest: "classification", } operation = operations.get(type(request), "embedding generation") raise VLLMValidationError( diff --git a/vllm/entrypoints/pooling/base/io_processor.py b/vllm/entrypoints/pooling/base/io_processor.py new file mode 100644 index 000000000000..254c3d64a4bd --- /dev/null +++ b/vllm/entrypoints/pooling/base/io_processor.py @@ -0,0 +1,189 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Callable, Sequence +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Final + +from vllm import PoolingRequestOutput, PromptType +from vllm.config import ModelConfig +from vllm.entrypoints.chat_utils import ( + ChatCompletionMessageParam, + ChatTemplateConfig, + ChatTemplateContentFormatOption, + ConversationMessage, +) +from vllm.entrypoints.openai.engine.serving import RendererChatRequest, RendererRequest +from vllm.inputs import ProcessorInputs, SingletonPrompt +from vllm.renderers import BaseRenderer, merge_kwargs +from vllm.renderers.inputs import TokPrompt +from vllm.renderers.inputs.preprocess import parse_model_prompt, prompt_to_seq +from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers import ToolParser +from vllm.utils.mistral import is_mistral_tokenizer + + +class PoolingIOProcessor: + def __init__( + self, + model_config: ModelConfig, + renderer: BaseRenderer, + chat_template_config: ChatTemplateConfig, + ): + self._tokenizer_executor = ThreadPoolExecutor(max_workers=1) + + self.model_config = model_config + self.renderer = renderer + + self.chat_template = chat_template_config.chat_template + self.chat_template_content_format: Final = ( + chat_template_config.chat_template_content_format + ) + self.trust_request_chat_template = ( + chat_template_config.trust_request_chat_template + ) + + def pre_process_online(self, *args, **kwargs): + raise NotImplementedError + + async def pre_process_online_async(self, *args, **kwargs): + return self.pre_process_online(*args, **kwargs) + + def pre_process_offline(self, *args, **kwargs): + raise NotImplementedError + + async def pre_process_offline_async(self, *args, **kwargs): + return self.pre_process_offline(*args, **kwargs) + + def post_process( + self, outputs: list[PoolingRequestOutput] + ) -> list[PoolingRequestOutput]: + return outputs + + async def post_process_async( + self, outputs: list[PoolingRequestOutput] + ) -> list[PoolingRequestOutput]: + return self.post_process(outputs) + + def create_pooling_params(self, request): + return request.to_pooling_params() + + def _preprocess_completion_online( + self, + request: RendererRequest, + prompt_input: str | list[str] | list[int] | list[list[int]] | None, + prompt_embeds: bytes | list[bytes] | None, + ) -> list[TokPrompt]: + renderer = 
self.renderer + model_config = self.model_config + + prompts = list[SingletonPrompt | bytes]() + if prompt_embeds is not None: # embeds take higher priority + prompts.extend(prompt_to_seq(prompt_embeds)) + if prompt_input is not None: + prompts.extend(prompt_to_seq(prompt_input)) + + parsed_prompts = [ + ( + prompt + if isinstance(prompt, bytes) + else parse_model_prompt(model_config, prompt) + ) + for prompt in prompts + ] + tok_params = request.build_tok_params(model_config) + + return renderer.render_cmpl( + parsed_prompts, + tok_params, + prompt_extras={ + k: v + for k in ("mm_processor_kwargs", "cache_salt") + if (v := getattr(request, k, None)) is not None + }, + ) + + def _preprocess_chat_online( + self, + request: RendererChatRequest, + messages: list[ChatCompletionMessageParam], + default_template: str | None, + default_template_content_format: ChatTemplateContentFormatOption, + default_template_kwargs: dict[str, Any] | None, + tool_dicts: list[dict[str, Any]] | None = None, + tool_parser: Callable[[TokenizerLike], ToolParser] | None = None, + ) -> tuple[list[ConversationMessage], list[TokPrompt]]: + renderer = self.renderer + + default_template_kwargs = merge_kwargs( + default_template_kwargs, + dict( + tools=tool_dicts, + tokenize=is_mistral_tokenizer(renderer.tokenizer), + ), + ) + + tok_params = request.build_tok_params(self.model_config) + chat_params = request.build_chat_params( + default_template, default_template_content_format + ).with_defaults(default_template_kwargs) + + (conversation,), (engine_prompt,) = renderer.render_chat( + [messages], + chat_params, + tok_params, + prompt_extras={ + k: v + for k in ("mm_processor_kwargs", "cache_salt") + if (v := getattr(request, k, None)) is not None + }, + ) + + return conversation, [engine_prompt] + + def _preprocess_completion_offline( + self, + prompts: PromptType | Sequence[PromptType], + tokenization_kwargs: dict[str, Any] | None = None, + ) -> Sequence[ProcessorInputs]: + renderer = self.renderer + model_config = self.model_config + + prompts = prompt_to_seq(prompts) + + parsed_prompts = [ + ( + prompt + if isinstance(prompt, bytes) + else parse_model_prompt(model_config, prompt) + ) + for prompt in prompts + ] + tok_params = renderer.default_cmpl_tok_params.with_kwargs( + **(tokenization_kwargs or {}) + ) + + return renderer.render_cmpl( + parsed_prompts, + tok_params, + ) + + def _validate_chat_template( + self, + request_chat_template: str | None, + chat_template_kwargs: dict[str, Any] | None, + trust_request_chat_template: bool, + ): + if not trust_request_chat_template and ( + request_chat_template is not None + or ( + chat_template_kwargs + and chat_template_kwargs.get("chat_template") is not None + ) + ): + raise ValueError( + "Chat template is passed with request, but " + "--trust-request-chat-template is not set. " + "Refused request with untrusted chat template." 
+ ) + return None diff --git a/vllm/entrypoints/pooling/base/serving.py b/vllm/entrypoints/pooling/base/serving.py new file mode 100644 index 000000000000..813282d3d13f --- /dev/null +++ b/vllm/entrypoints/pooling/base/serving.py @@ -0,0 +1,378 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import time +from collections.abc import AsyncGenerator, Mapping +from dataclasses import dataclass, field +from http import HTTPStatus +from typing import ClassVar, Generic, TypeVar + +from fastapi import Request +from pydantic import ConfigDict +from starlette.datastructures import Headers +from starlette.responses import JSONResponse + +from vllm import ( + PoolingParams, + PoolingRequestOutput, + PromptType, + SamplingParams, + envs, +) +from vllm.config import ModelConfig +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.chat_utils import ( + ChatTemplateConfig, + ChatTemplateContentFormatOption, +) +from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.openai.engine.protocol import ErrorResponse +from vllm.entrypoints.openai.models.serving import OpenAIServingModels +from vllm.entrypoints.pooling.typing import AnyPoolingRequest, AnyPoolingResponse +from vllm.inputs import ProcessorInputs +from vllm.lora.request import LoRARequest +from vllm.renderers import BaseRenderer +from vllm.renderers.inputs.preprocess import extract_prompt_components +from vllm.sampling_params import BeamSearchParams +from vllm.tracing import ( + contains_trace_headers, + extract_trace_headers, + log_tracing_disabled_warning, +) +from vllm.utils import random_uuid +from vllm.utils.async_utils import merge_async_iterators + +from ...utils import create_error_response +from .io_processor import PoolingIOProcessor + +PoolingRequestT = TypeVar("PoolingRequestT", bound=AnyPoolingRequest) + + +@dataclass(kw_only=True) +class PoolingServeContext(Generic[PoolingRequestT]): + request: PoolingRequestT + raw_request: Request | None = None + model_name: str + request_id: str + created_time: int = field(default_factory=lambda: int(time.time())) + lora_request: LoRARequest | None = None + engine_prompts: list[ProcessorInputs] | None = None + + result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = ( + None + ) + final_res_batch: list[PoolingRequestOutput] = field(default_factory=list) + + model_config = ConfigDict(arbitrary_types_allowed=True) + + +class PoolingServing: + request_id_prefix: ClassVar[str] + + def __init__( + self, + engine_client: EngineClient, + models: OpenAIServingModels, + *, + request_logger: RequestLogger | None, + chat_template: str | None = None, + chat_template_content_format: ChatTemplateContentFormatOption = "auto", + trust_request_chat_template: bool = False, + return_tokens_as_token_ids: bool = False, + log_error_stack: bool = False, + ): + super().__init__() + self.engine_client = engine_client + self.models = models + self.model_config = models.model_config + self.max_model_len = self.model_config.max_model_len + self.request_logger = request_logger + self.return_tokens_as_token_ids = return_tokens_as_token_ids + self.log_error_stack = log_error_stack + self.chat_template_config = ChatTemplateConfig( + chat_template=chat_template, + chat_template_content_format=chat_template_content_format, + trust_request_chat_template=trust_request_chat_template, + ) + self.io_processor = self.init_io_processor( + model_config=models.model_config, + renderer=models.renderer, + 
chat_template_config=self.chat_template_config, + ) + + def init_io_processor( + self, + model_config: ModelConfig, + renderer: BaseRenderer, + chat_template_config: ChatTemplateConfig, + ) -> PoolingIOProcessor: + raise NotImplementedError + + async def __call__( + self, + request: AnyPoolingRequest, + raw_request: Request, + ) -> JSONResponse: + try: + model_name = self.models.model_name() + request_id = ( + f"{self.request_id_prefix}-{self._base_request_id(raw_request)}" + ) + + await self._check_model(request) + + ctx = PoolingServeContext( + request=request, + raw_request=raw_request, + model_name=model_name, + request_id=request_id, + ) + + self._validate_request(ctx) + self._maybe_get_adapters(ctx) + await self._preprocess(ctx) + await self._prepare_generators(ctx) + await self._collect_batch(ctx) + response = await self._build_response(ctx) + return JSONResponse(content=response.model_dump()) + except Exception as e: + error_response = create_error_response(e) + return JSONResponse( + content=error_response.model_dump(), + status_code=error_response.error.code, + ) + + async def _preprocess( + self, + ctx: PoolingServeContext, + ): + ctx.engine_prompts = await self.io_processor.pre_process_online_async( + ctx.request + ) + + async def _prepare_generators( + self, + ctx: PoolingServeContext, + ): + if ctx.engine_prompts is None: + raise ValueError("Engine prompts not available") + + generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] + + trace_headers = ( + None + if ctx.raw_request is None + else await self._get_trace_headers(ctx.raw_request.headers) + ) + + pooling_params = self.io_processor.create_pooling_params(ctx.request) + + for i, engine_prompt in enumerate(ctx.engine_prompts): + request_id_item = f"{ctx.request_id}-{i}" + + self._log_inputs( + request_id_item, + engine_prompt, + params=pooling_params, + lora_request=ctx.lora_request, + ) + + generator = self.engine_client.encode( + engine_prompt, + pooling_params, + request_id_item, + lora_request=ctx.lora_request, + trace_headers=trace_headers, + priority=getattr(ctx.request, "priority", 0), + ) + + generators.append(generator) + + ctx.result_generator = merge_async_iterators(*generators) + + async def _collect_batch( + self, + ctx: PoolingServeContext, + ): + if ctx.engine_prompts is None: + raise ValueError("Engine prompts not available") + + if ctx.result_generator is None: + raise ValueError("Result generator not available") + + num_prompts = len(ctx.engine_prompts) + final_res_batch: list[PoolingRequestOutput | None] + final_res_batch = [None] * num_prompts + + async for i, res in ctx.result_generator: + final_res_batch[i] = res + + if None in final_res_batch: + raise ValueError("Failed to generate results for all prompts") + + ctx.final_res_batch = [res for res in final_res_batch if res is not None] + + async def _build_response( + self, + ctx: PoolingServeContext, + ) -> AnyPoolingResponse: + raise NotImplementedError + + @staticmethod + def _base_request_id( + raw_request: Request | None, default: str | None = None + ) -> str | None: + """Pulls the request id to use from a header, if provided""" + if raw_request is not None and ( + (req_id := raw_request.headers.get("X-Request-Id")) is not None + ): + return req_id + + return random_uuid() if default is None else default + + def _is_model_supported(self, model_name: str | None) -> bool: + if not model_name: + return True + return self.models.is_base_model(model_name) + + async def _check_model( + self, + request: AnyPoolingRequest, + ) -> 
ErrorResponse | None: + if self._is_model_supported(request.model): + return None + if request.model in self.models.lora_requests: + return None + if ( + envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING + and request.model + and (load_result := await self.models.resolve_lora(request.model)) + ): + if isinstance(load_result, LoRARequest): + return None + if ( + isinstance(load_result, ErrorResponse) + and load_result.error.code == HTTPStatus.BAD_REQUEST.value + ): + raise ValueError(load_result.error.message) + return None + + def _validate_request(self, ctx: PoolingServeContext) -> None: + truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens", None) + + if ( + truncate_prompt_tokens is not None + and truncate_prompt_tokens > self.max_model_len + ): + raise ValueError( + "truncate_prompt_tokens value is " + "greater than max_model_len." + " Please, select a smaller truncation size." + ) + return None + + async def _get_trace_headers( + self, + headers: Headers, + ) -> Mapping[str, str] | None: + is_tracing_enabled = await self.engine_client.is_tracing_enabled() + + if is_tracing_enabled: + return extract_trace_headers(headers) + + if contains_trace_headers(headers): + log_tracing_disabled_warning() + + return None + + def _maybe_get_adapters( + self, + ctx: PoolingServeContext, + supports_default_mm_loras: bool = False, + ): + request = ctx.request + if request.model in self.models.lora_requests: + ctx.lora_request = self.models.lora_requests[request.model] + + # Currently only support default modality specific loras + # if we have exactly one lora matched on the request. + if supports_default_mm_loras: + default_mm_lora = self._get_active_default_mm_loras(request) + if default_mm_lora is not None: + ctx.lora_request = default_mm_lora + + if self._is_model_supported(request.model): + return None + + # if _check_model has been called earlier, this will be unreachable + raise ValueError(f"The model `{request.model}` does not exist.") + + def _get_active_default_mm_loras( + self, request: AnyPoolingRequest + ) -> LoRARequest | None: + """Determine if there are any active default multimodal loras.""" + # TODO: Currently this is only enabled for chat completions + # to be better aligned with only being enabled for .generate + # when run offline. It would be nice to support additional + # tasks types in the future. + message_types = self._get_message_types(request) + default_mm_loras = set() + + for lora in self.models.lora_requests.values(): + # Best effort match for default multimodal lora adapters; + # There is probably a better way to do this, but currently + # this matches against the set of 'types' in any content lists + # up until '_', e.g., to match audio_url -> audio + if lora.lora_name in message_types: + default_mm_loras.add(lora) + + # Currently only support default modality specific loras if + # we have exactly one lora matched on the request. + if len(default_mm_loras) == 1: + return default_mm_loras.pop() + return None + + def _get_message_types(self, request: AnyPoolingRequest) -> set[str]: + """Retrieve the set of types from message content dicts up + until `_`; we use this to match potential multimodal data + with default per modality loras. 
+ """ + message_types: set[str] = set() + + if not hasattr(request, "messages"): + return message_types + + messages = request.messages + if messages is None or isinstance(messages, (str, bytes)): + return message_types + + for message in messages: + if ( + isinstance(message, dict) + and "content" in message + and isinstance(message["content"], list) + ): + for content_dict in message["content"]: + if "type" in content_dict: + message_types.add(content_dict["type"].split("_")[0]) + return message_types + + def _log_inputs( + self, + request_id: str, + inputs: PromptType | ProcessorInputs, + params: SamplingParams | PoolingParams | BeamSearchParams | None, + lora_request: LoRARequest | None, + ) -> None: + if self.request_logger is None: + return + + components = extract_prompt_components(self.model_config, inputs) + + self.request_logger.log_inputs( + request_id, + components.text, + components.token_ids, + components.embeds, + params=params, + lora_request=lora_request, + ) diff --git a/vllm/entrypoints/pooling/classify/api_router.py b/vllm/entrypoints/pooling/classify/api_router.py index 8a1513ebc928..0e99a86fe1d1 100644 --- a/vllm/entrypoints/pooling/classify/api_router.py +++ b/vllm/entrypoints/pooling/classify/api_router.py @@ -3,16 +3,17 @@ from fastapi import APIRouter, Depends, Request from starlette.responses import JSONResponse -from typing_extensions import assert_never -from vllm.entrypoints.openai.engine.protocol import ErrorResponse from vllm.entrypoints.openai.utils import validate_json_request from vllm.entrypoints.pooling.classify.protocol import ( ClassificationRequest, - ClassificationResponse, ) from vllm.entrypoints.pooling.classify.serving import ServingClassification -from vllm.entrypoints.utils import load_aware_call, with_cancellation +from vllm.entrypoints.utils import ( + create_error_response, + load_aware_call, + with_cancellation, +) router = APIRouter() @@ -24,25 +25,17 @@ def classify(request: Request) -> ServingClassification | None: @router.post("/classify", dependencies=[Depends(validate_json_request)]) @with_cancellation @load_aware_call -async def create_classify(request: ClassificationRequest, raw_request: Request): +async def create_classify( + request: ClassificationRequest, raw_request: Request +) -> JSONResponse: handler = classify(raw_request) if handler is None: - base_server = raw_request.app.state.openai_serving_tokenization - return base_server.create_error_response( + error_response = create_error_response( message="The model does not support Classification API" ) - - try: - generator = await handler.create_classify(request, raw_request) - except Exception as e: - generator = handler.create_error_response(e) - - if isinstance(generator, ErrorResponse): return JSONResponse( - content=generator.model_dump(), status_code=generator.error.code + content=error_response.model_dump(), + status_code=error_response.error.code, ) - elif isinstance(generator, ClassificationResponse): - return JSONResponse(content=generator.model_dump()) - - assert_never(generator) + return await handler(request, raw_request) diff --git a/vllm/entrypoints/pooling/classify/io_processor.py b/vllm/entrypoints/pooling/classify/io_processor.py new file mode 100644 index 000000000000..90d5b0e4fe0d --- /dev/null +++ b/vllm/entrypoints/pooling/classify/io_processor.py @@ -0,0 +1,50 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Sequence +from typing import Any + +from vllm import 
PromptType +from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor +from vllm.entrypoints.pooling.classify.protocol import ( + ClassificationChatRequest, + ClassificationCompletionRequest, +) +from vllm.inputs import ProcessorInputs +from vllm.renderers.inputs import TokPrompt + + +class ClassifyIOProcessor(PoolingIOProcessor): + def pre_process_online( + self, request: ClassificationCompletionRequest | ClassificationChatRequest + ) -> list[TokPrompt] | None: + if isinstance(request, ClassificationChatRequest): + self._validate_chat_template( + request_chat_template=request.chat_template, + chat_template_kwargs=request.chat_template_kwargs, + trust_request_chat_template=self.trust_request_chat_template, + ) + _, engine_prompts = self._preprocess_chat_online( + request, + request.messages, + default_template=self.chat_template, + default_template_content_format=self.chat_template_content_format, + default_template_kwargs=None, + ) + elif isinstance(request, ClassificationCompletionRequest): + engine_prompts = self._preprocess_completion_online( + request, + prompt_input=request.input, + prompt_embeds=None, + ) + else: + raise ValueError("Invalid classification request type") + return engine_prompts + + def pre_process_offline( + self, + prompts: PromptType | Sequence[PromptType], + tokenization_kwargs: dict[str, Any] | None = None, + ) -> Sequence[ProcessorInputs]: + return self._preprocess_completion_offline( + prompts=prompts, tokenization_kwargs=tokenization_kwargs + ) diff --git a/vllm/entrypoints/pooling/classify/serving.py b/vllm/entrypoints/pooling/classify/serving.py index 8cdbbde6d6f6..efd4be77c527 100644 --- a/vllm/entrypoints/pooling/classify/serving.py +++ b/vllm/entrypoints/pooling/classify/serving.py @@ -1,116 +1,57 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Final, TypeAlias +from typing import TypeAlias -import jinja2 import numpy as np -from fastapi import Request - -from vllm.engine.protocol import EngineClient -from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption -from vllm.entrypoints.logger import RequestLogger -from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo -from vllm.entrypoints.openai.engine.serving import OpenAIServing, ServeContext -from vllm.entrypoints.openai.models.serving import OpenAIServingModels -from vllm.entrypoints.pooling.classify.protocol import ( - ClassificationChatRequest, - ClassificationCompletionRequest, + +from vllm import ClassificationOutput +from vllm.config import ModelConfig +from vllm.entrypoints.chat_utils import ChatTemplateConfig +from vllm.entrypoints.openai.engine.protocol import UsageInfo +from vllm.entrypoints.pooling.base.serving import PoolingServeContext, PoolingServing +from vllm.logger import init_logger +from vllm.renderers import BaseRenderer + +from .io_processor import ClassifyIOProcessor +from .protocol import ( ClassificationData, ClassificationRequest, ClassificationResponse, ) -from vllm.logger import init_logger -from vllm.outputs import ClassificationOutput logger = init_logger(__name__) -ClassificationServeContext: TypeAlias = ServeContext[ClassificationRequest] +ClassificationServeContext: TypeAlias = PoolingServeContext[ClassificationRequest] -class ServingClassification(OpenAIServing): +class ServingClassification(PoolingServing): request_id_prefix = "classify" - def __init__( + def init_io_processor( self, - engine_client: EngineClient, - models: 
OpenAIServingModels, - *, - request_logger: RequestLogger | None, - chat_template: str | None = None, - chat_template_content_format: ChatTemplateContentFormatOption = "auto", - trust_request_chat_template: bool = False, - log_error_stack: bool = False, - ) -> None: - super().__init__( - engine_client=engine_client, - models=models, - request_logger=request_logger, - log_error_stack=log_error_stack, + model_config: ModelConfig, + renderer: BaseRenderer, + chat_template_config: ChatTemplateConfig, + ) -> ClassifyIOProcessor: + return ClassifyIOProcessor( + model_config=model_config, + renderer=renderer, + chat_template_config=chat_template_config, ) - self.chat_template = chat_template - self.chat_template_content_format: Final = chat_template_content_format - self.trust_request_chat_template = trust_request_chat_template - - async def _preprocess( + async def _build_response( self, ctx: ClassificationServeContext, - ) -> ErrorResponse | None: - """ - Process classification inputs: tokenize text, resolve adapters, - and prepare model-specific inputs. - """ - try: - ctx.lora_request = self._maybe_get_adapters(ctx.request) - - if isinstance(ctx.request, ClassificationChatRequest): - error_check_ret = self._validate_chat_template( - request_chat_template=ctx.request.chat_template, - chat_template_kwargs=ctx.request.chat_template_kwargs, - trust_request_chat_template=self.trust_request_chat_template, - ) - if error_check_ret: - return error_check_ret - - _, ctx.engine_prompts = await self._preprocess_chat( - ctx.request, - ctx.request.messages, - default_template=self.chat_template, - default_template_content_format=self.chat_template_content_format, - default_template_kwargs=None, - ) - elif isinstance(ctx.request, ClassificationCompletionRequest): - ctx.engine_prompts = await self._preprocess_completion( - ctx.request, - prompt_input=ctx.request.input, - prompt_embeds=None, - ) - else: - return self.create_error_response("Invalid classification request type") - - return None - - except (ValueError, TypeError, jinja2.TemplateError) as e: - logger.exception("Error in preprocessing prompt inputs") - return self.create_error_response(str(e)) - - def _build_response( - self, - ctx: ClassificationServeContext, - ) -> ClassificationResponse | ErrorResponse: - """ - Convert model outputs to a formatted classification response - with probabilities and labels. 
- """ - id2label = getattr(self.model_config.hf_config, "id2label", {}) + ) -> ClassificationResponse: + final_res_batch_checked = await self.io_processor.post_process_async( + ctx.final_res_batch + ) - items: list[ClassificationData] = [] + id2label = getattr(self.model_config.hf_config, "id2label", {}) num_prompt_tokens = 0 - - final_res_batch_checked = ctx.final_res_batch - + items: list[ClassificationData] = [] for idx, final_res in enumerate(final_res_batch_checked): classify_res = ClassificationOutput.from_base(final_res.outputs) @@ -141,20 +82,3 @@ def _build_response( data=items, usage=usage, ) - - async def create_classify( - self, - request: ClassificationRequest, - raw_request: Request, - ) -> ClassificationResponse | ErrorResponse: - model_name = self.models.model_name() - request_id = f"{self.request_id_prefix}-{self._base_request_id(raw_request)}" - - ctx = ClassificationServeContext( - request=request, - raw_request=raw_request, - model_name=model_name, - request_id=request_id, - ) - - return await self.handle(ctx) # type: ignore[return-value] diff --git a/vllm/entrypoints/pooling/io_processor_factories.py b/vllm/entrypoints/pooling/io_processor_factories.py new file mode 100644 index 000000000000..97476768cc6e --- /dev/null +++ b/vllm/entrypoints/pooling/io_processor_factories.py @@ -0,0 +1,31 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +from vllm.config import ModelConfig +from vllm.entrypoints.chat_utils import ChatTemplateConfig +from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor +from vllm.renderers import BaseRenderer +from vllm.tasks import SupportedTask + + +def init_pooling_io_processors( + supported_tasks: tuple[SupportedTask, ...], + model_config: ModelConfig, + renderer: BaseRenderer, + chat_template_config: ChatTemplateConfig, +) -> dict[str, PoolingIOProcessor]: + pooling_io_processors: dict[str, PoolingIOProcessor] = {} + + if "classify" in supported_tasks: + from vllm.entrypoints.pooling.classify.io_processor import ( + ClassifyIOProcessor, + ) + + pooling_io_processors["classify"] = ClassifyIOProcessor( + model_config=model_config, + renderer=renderer, + chat_template_config=chat_template_config, + ) + + return pooling_io_processors diff --git a/vllm/entrypoints/pooling/typing.py b/vllm/entrypoints/pooling/typing.py new file mode 100644 index 000000000000..87d6487edb31 --- /dev/null +++ b/vllm/entrypoints/pooling/typing.py @@ -0,0 +1,51 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import TypeAlias + +from vllm.entrypoints.pooling.classify.protocol import ( + ClassificationChatRequest, + ClassificationCompletionRequest, + ClassificationResponse, +) +from vllm.entrypoints.pooling.embed.protocol import ( + EmbeddingBytesResponse, + EmbeddingChatRequest, + EmbeddingCompletionRequest, + EmbeddingResponse, +) +from vllm.entrypoints.pooling.pooling.protocol import ( + IOProcessorRequest, + PoolingChatRequest, + PoolingCompletionRequest, + PoolingResponse, +) +from vllm.entrypoints.pooling.score.protocol import ( + RerankRequest, + ScoreRequest, + ScoreResponse, +) + +PoolingCompletionLikeRequest: TypeAlias = ( + EmbeddingCompletionRequest + | ClassificationCompletionRequest + | RerankRequest + | ScoreRequest + | PoolingCompletionRequest +) + +PoolingChatLikeRequest: TypeAlias = ( + EmbeddingChatRequest | ClassificationChatRequest | PoolingChatRequest +) + +AnyPoolingRequest: TypeAlias = ( + 
PoolingCompletionLikeRequest | PoolingChatLikeRequest | IOProcessorRequest +) + +AnyPoolingResponse: TypeAlias = ( + ClassificationResponse + | EmbeddingResponse + | EmbeddingBytesResponse + | PoolingResponse + | ScoreResponse +) diff --git a/vllm/entrypoints/sagemaker/api_router.py b/vllm/entrypoints/sagemaker/api_router.py index 1138225c36fb..32faaa02e681 100644 --- a/vllm/entrypoints/sagemaker/api_router.py +++ b/vllm/entrypoints/sagemaker/api_router.py @@ -13,6 +13,7 @@ from vllm.entrypoints.openai.engine.protocol import ErrorResponse from vllm.entrypoints.openai.engine.serving import OpenAIServing from vllm.entrypoints.openai.utils import validate_json_request +from vllm.entrypoints.pooling.base.serving import PoolingServing from vllm.entrypoints.serve.instrumentator.basic import base from vllm.entrypoints.serve.instrumentator.health import health from vllm.tasks import POOLING_TASKS, SupportedTask @@ -20,7 +21,7 @@ # TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers # (requires typing_extensions >= 4.13) RequestType = Any -GetHandlerFn = Callable[[Request], OpenAIServing | None] +GetHandlerFn = Callable[[Request], OpenAIServing | PoolingServing | None] EndpointFn = Callable[[RequestType, Request], Awaitable[Any]] diff --git a/vllm/entrypoints/utils.py b/vllm/entrypoints/utils.py index 34df85f37a24..6390a72ce0e1 100644 --- a/vllm/entrypoints/utils.py +++ b/vllm/entrypoints/utils.py @@ -5,7 +5,10 @@ import dataclasses import functools import os +import sys +import traceback from argparse import Namespace +from http import HTTPStatus from logging import Logger from string import Template from typing import TYPE_CHECKING @@ -17,17 +20,23 @@ from vllm import envs from vllm.engine.arg_utils import EngineArgs +from vllm.exceptions import VLLMValidationError from vllm.logger import current_formatter_type, init_logger from vllm.platforms import current_platform from vllm.utils.argparse_utils import FlexibleArgumentParser if TYPE_CHECKING: - from vllm.entrypoints.openai.engine.protocol import StreamOptions + from vllm.entrypoints.openai.engine.protocol import ( + ErrorInfo, + ErrorResponse, + StreamOptions, + ) from vllm.entrypoints.openai.models.protocol import LoRAModulePath else: - StreamOptions = object + ErrorResponse = object + ErrorInfo = object LoRAModulePath = object - + StreamOptions = object logger = init_logger(__name__) @@ -291,3 +300,59 @@ def log_version_and_model(lgr: Logger, version: str, model_name: str) -> None: message = logo_template.substitute(colors) lgr.info(message, version, model_name) + + +def create_error_response( + message: str | Exception, + err_type: str = "BadRequestError", + status_code: HTTPStatus = HTTPStatus.BAD_REQUEST, + param: str | None = None, + log_error_stack: bool = False, +) -> "ErrorResponse": + exc: Exception | None = None + + from vllm.entrypoints.openai.engine.protocol import ErrorInfo, ErrorResponse + + if isinstance(message, Exception): + exc = message + + if isinstance(exc, VLLMValidationError): + err_type = "BadRequestError" + status_code = HTTPStatus.BAD_REQUEST + param = exc.parameter + elif isinstance(exc, (ValueError, TypeError, RuntimeError, OverflowError)): + # Common validation errors from user input + err_type = "BadRequestError" + status_code = HTTPStatus.BAD_REQUEST + param = None + elif isinstance(exc, NotImplementedError): + err_type = "NotImplementedError" + status_code = HTTPStatus.NOT_IMPLEMENTED + param = None + elif exc.__class__.__name__ == "TemplateError": + # jinja2.TemplateError (avoid importing 
jinja2) + err_type = "BadRequestError" + status_code = HTTPStatus.BAD_REQUEST + param = None + else: + err_type = "InternalServerError" + status_code = HTTPStatus.INTERNAL_SERVER_ERROR + param = None + + message = str(exc) + + if log_error_stack: + exc_type, _, _ = sys.exc_info() + if exc_type is not None: + traceback.print_exc() + else: + traceback.print_stack() + + return ErrorResponse( + error=ErrorInfo( + message=sanitize_message(message), + type=err_type, + code=status_code.value, + param=param, + ) + ) From fb7fdc49c4a0c629fd92a5e49c08ec86f5dd8ff9 Mon Sep 17 00:00:00 2001 From: TJian Date: Tue, 3 Mar 2026 22:24:21 +0800 Subject: [PATCH 25/53] [ROCm] [CI] Add new fusion test cases that are relevant to vLLM IR Ops (#34307) Signed-off-by: tjtanaa Signed-off-by: vllmellm Co-authored-by: vllmellm --- .buildkite/test-amd.yaml | 147 +++++++++++++----- tests/compile/fusions_e2e/common.py | 4 + tests/compile/fusions_e2e/conftest.py | 5 + tests/compile/fusions_e2e/models.py | 22 ++- tests/compile/fusions_e2e/test_tp1_quant.py | 42 ++++- tests/compile/fusions_e2e/test_tp2_ar_rms.py | 3 + .../compile/fusions_e2e/test_tp2_async_tp.py | 3 + .../distributed/test_sequence_parallelism.py | 2 + .../passes/test_silu_mul_quant_fusion.py | 28 +++- .../passes/fusion/rocm_aiter_fusion.py | 22 ++- 10 files changed, 217 insertions(+), 61 deletions(-) diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml index 4f0db88fe702..2b80937e8580 100644 --- a/.buildkite/test-amd.yaml +++ b/.buildkite/test-amd.yaml @@ -610,6 +610,8 @@ steps: --ignore=lora/test_qwen3moe_tp.py parallelism: 4 +##### .buildkite/test_areas/pytorch.yaml ##### +# corresponds to .buildkite/test_areas/pytorch.yaml - label: PyTorch Compilation Unit Tests # 15min timeout_in_minutes: 30 mirror_hardwares: [amdexperimental, amdproduction] @@ -627,6 +629,20 @@ steps: # they do not suffer from https://github.com/vllm-project/vllm/issues/28965 - "find compile/ -maxdepth 1 -name 'test_*.py' -exec pytest -s -v {} \\\\;" +# corresponds to .buildkite/test_areas/pytorch.yaml +- label: PyTorch Compilation Passes Unit Tests + timeout_in_minutes: 20 + mirror_hardwares: [amdexperimental, amdproduction] + agent_pool: mi325_1 + source_file_dependencies: + - vllm/ + - tests/compile/passes + commands: + # TODO: clean up this comment if not needed. It is used to + # keep track of the tests changes during vLLM IR Ops refactoring. + # Use `find` to launch multiple instances of pytest. 
+ - "find compile/passes -maxdepth 1 -name 'test_*.py' -exec pytest -s -v {} \\\\;" + - label: PyTorch Fullgraph Smoke Test # 15min timeout_in_minutes: 30 mirror_hardwares: [amdexperimental, amdproduction] @@ -1211,41 +1227,6 @@ steps: - pytest -v -s tests/kernels/moe/test_flashinfer.py - pytest -v -s tests/kernels/moe/test_cutedsl_moe.py -- label: Blackwell Fusion and Compile Tests # 30 min - timeout_in_minutes: 40 - working_dir: "/vllm-workspace/" - gpu: b200 - source_file_dependencies: - - csrc/quantization/fp4/ - - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py - - vllm/v1/attention/backends/flashinfer.py - - vllm/v1/worker/ - - vllm/v1/cudagraph_dispatcher.py - - vllm/compilation/ - # can affect pattern matching - - vllm/model_executor/layers/layernorm.py - - vllm/model_executor/layers/activation.py - - vllm/model_executor/layers/quantization/input_quant_fp8.py - - tests/compile/passes/test_fusion_attn.py - - tests/compile/passes/test_silu_mul_quant_fusion.py - - tests/compile/passes/distributed/test_fusion_all_reduce.py - - tests/compile/fullgraph/test_full_graph.py - commands: - - nvidia-smi - - pytest -v -s tests/compile/passes/test_fusion_attn.py - - pytest -v -s tests/compile/passes/test_silu_mul_quant_fusion.py - # this runner has 2 GPUs available even though num_gpus=2 is not set - - pytest -v -s tests/compile/passes/distributed/test_fusion_all_reduce.py - - # # Limit to Inductor partition, no custom ops, and allreduce & attn fusion to reduce running time - # # Wrap with quotes to escape yaml - # - "pytest -v -s tests/compile/distributed/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm -k 'True and not +quant_fp8 and not +rms_norm'" - # Old E2E tests were removed in https://github.com/vllm-project/vllm/pull/33293 - # in favor of new tests in fusions_e2e. We avoid replicating the new jobs in this file as it's deprecated. - - # test_fp8_kv_scale_compile requires FlashAttention (not supported on default L4/L40) - - pytest -v -s tests/compile/fullgraph/test_full_graph.py::test_fp8_kv_scale_compile - - label: Blackwell GPT-OSS Eval timeout_in_minutes: 60 working_dir: "/vllm-workspace/" @@ -1371,7 +1352,6 @@ steps: - pytest -v -s ./compile/test_wrapper.py - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed' - VLLM_TEST_SAME_HOST=1 VLLM_TEST_WITH_DEFAULT_DEVICE_SET=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed' - - pytest -v -s compile/correctness_e2e/test_sequence_parallel.py - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown - pytest -v -s v1/worker/test_worker_memory_snapshot.py @@ -1601,16 +1581,16 @@ steps: commands: - VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/passes/distributed/test_async_tp.py - pytest -v -s tests/compile/passes/distributed/test_sequence_parallelism.py - - pytest -v -s tests/compile/passes/distributed/test_fusion_all_reduce.py + # TODO: this test is not supported on ROCm, there are aiter kernels for this. + # - pytest -v -s tests/compile/passes/distributed/test_fusion_all_reduce.py #- pytest -v -s tests/compile/distributed/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm # - "VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/distributed/test_fusions_e2e.py -k 'not Llama-4'" # Old E2E tests were removed in https://github.com/vllm-project/vllm/pull/33293 # in favor of new tests in fusions_e2e. We avoid replicating the new jobs in this file as it's deprecated. 
- - - VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/correctness_e2e/test_sequence_parallel.py - pytest -v -s tests/distributed/test_context_parallel.py - HIP_VISIBLE_DEVICES=0,1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model=Qwen/Qwen1.5-MoE-A2.7B -tp=1 -dp=2 --max-model-len=2048 --all2all-backend=allgather_reducescatter --disable-nccl-for-dp-synchronization - - pytest -v -s tests/v1/distributed/test_dbo.py + # this test is not supported on ROCm + # - pytest -v -s tests/v1/distributed/test_dbo.py ##### B200 test ##### - label: Distributed Tests (B200) # optional @@ -1721,6 +1701,93 @@ steps: commands: - bash .buildkite/scripts/scheduled_integration_test/qwen3_next_mtp_async_eplb.sh 0.8 1319 8040 +##### .buildkite/test_areas/compile.yaml ##### +# Slowly setting up the tests so that it is also easier for the +# CI team to review and upstream to the pipelinev2. +# The following tests are important for vLLM IR Ops refactoring, +# which affects fusion passes on ROCm. So we have to +# enable them as soon as possible. + +## TODO: Enable the test in this group +# # corresponds to .buildkite/test_areas/compile.yaml +# - label: Fusion and Compile Unit Tests (2xMI325 GPUs) +# timeout_in_minutes: 20 +# working_dir: "/vllm-workspace/" +# mirror_hardwares: [amdexperimental, amdproduction, tj] +# agent_pool: mi325_1 # changed to 1 GPU until the fusion all reduce is enabled then only revert back to 2 GPUs +# source_file_dependencies: +# - csrc/quantization/fp4/ +# - vllm/model_executor/layers/quantization/ +# - vllm/model_executor/layers/layernorm.py +# - vllm/model_executor/layers/activation.py +# - vllm/model_executor/layers/attention/attention.py +# - vllm/v1/attention/backends/flashinfer.py +# - vllm/compilation/ # TODO(luka) limit to vllm/compilation/passes +# - tests/compile/test_fusion_attn.py +# - tests/compile/test_silu_mul_quant_fusion.py +# - tests/compile/distributed/test_fusion_all_reduce.py +# - tests/compile/fullgraph/test_full_graph.py +# commands: +# - rocm-smi +# # we run all backend tests on ROCm +# # These two tests are covered in "PyTorch Compilation Passes Unit Tests" +# # - "pytest -v -s tests/compile/passes/test_fusion_attn.py" +# # - "pytest -v -s tests/compile/passes/test_silu_mul_quant_fusion.py" +# # TODO: this test is not supported on ROCm, there are aiter kernels for this.
+# # - pytest -v -s tests/compile/passes/distributed/test_fusion_all_reduce.py +# # TODO: find out more details +# # - pytest -v -s tests/compile/fullgraph/test_full_graph.py::test_fp8_kv_scale_compile + +# corresponds to .buildkite/test_areas/compile.yaml +- label: Fusion E2E Quick (MI325) + timeout_in_minutes: 15 + working_dir: "/vllm-workspace/" + mirror_hardwares: [amdexperimental, amdproduction] + agent_pool: mi325_1 + num_devices: 1 + source_file_dependencies: + - csrc/quantization/ + - vllm/model_executor/ + - vllm/v1/attention/ + - vllm/compilation/ + - tests/compile/fusions_e2e/ + commands: + - rocm-smi + # Run all models and attn backends but only Inductor partition and native custom ops + - "pytest -v -s tests/compile/fusions_e2e/test_tp1_quant.py -k 'inductor_partition and not +rms_norm and not +quant_fp8'" + # Different from CUDA, Qwen requires +rms_norm and +quant_fp8 as rms+quant fusion is only supported on AITER + - "pytest -v -s tests/compile/fusions_e2e/test_tp1_quant.py -k 'inductor_partition and +rms_norm and +quant_fp8 and qwen3'" + +# corresponds to .buildkite/test_areas/compile.yaml +- label: Fusion E2E Config Sweep (MI325) + timeout_in_minutes: 30 + working_dir: "/vllm-workspace/" + mirror_hardwares: [amdexperimental, amdproduction] + agent_pool: mi325_1 + num_devices: 1 + source_file_dependencies: + - csrc/quantization/ + - vllm/compilation/ + # can affect pattern matching + - vllm/model_executor/layers/layernorm.py + - vllm/model_executor/layers/activation.py + - vllm/model_executor/layers/attention/attention.py + - vllm/model_executor/layers/quantization/input_quant_fp8.py + - tests/compile/fusions_e2e/ + commands: + - rocm-smi + # Run just llama3 (fp8) for all config combinations + - pytest -v -s tests/compile/fusions_e2e/test_tp1_quant.py -k "llama-3" + +## There are no ops on ROCm for these tests. +## The test still passes but the logs are not useful. 
+## fused ops just call torch.ops.symm_mem which +## exists in ROCm even though they don't work +# - label: AsyncTP Correctness Tests (2xMI325 GPUs) +# - label: Fusion E2E TP2 Quick (MI325) +# - label: Fusion E2E TP2 AsyncTP Config Sweep (MI325) +# - label: Fusion E2E TP2 (MI325) +# - label: Sequence Parallel Correctness Tests (2xMI325 GPUs) ##################################################################################################################################### diff --git a/tests/compile/fusions_e2e/common.py b/tests/compile/fusions_e2e/common.py index 284a9d66b957..2c6dc2b3ebbc 100644 --- a/tests/compile/fusions_e2e/common.py +++ b/tests/compile/fusions_e2e/common.py @@ -13,6 +13,7 @@ class Matches(NamedTuple): # simple pointwise + aiter_rms_quant_fusion: int = 0 rms_quant_fusion: int = 0 act_quant_fusion: int = 0 norm_rope_fusion: int = 0 @@ -82,6 +83,9 @@ def has_cuda_graph_wrapper_metadata() -> bool: ] FUSION_LOG_PATTERNS: dict[str, re.Pattern] = { + "aiter_rms_quant_fusion": re.compile( + r"RocmAiterRMSNormQuantFusionPass Replaced (\d+) patterns" + ), "rms_quant_fusion": re.compile(r"rms_quant_fusion.py:\d+] Replaced (\d+) patterns"), "act_quant_fusion": re.compile(r"act_quant_fusion.py:\d+] Replaced (\d+) patterns"), "norm_rope_fusion": re.compile( diff --git a/tests/compile/fusions_e2e/conftest.py b/tests/compile/fusions_e2e/conftest.py index 40b4de57f66f..d083b6f14e4b 100644 --- a/tests/compile/fusions_e2e/conftest.py +++ b/tests/compile/fusions_e2e/conftest.py @@ -63,9 +63,14 @@ def run( compilation_config: dict, matches_check: list[str], use_deepgemm: bool = False, + use_aiter: bool = False, tp_size: int = 1, ): monkeypatch.setenv("VLLM_USE_DEEP_GEMM", "1" if use_deepgemm else "0") + monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1" if use_aiter else "0") + from vllm._aiter_ops import rocm_aiter_ops + + rocm_aiter_ops.refresh_env_variables() # Disable, compile cache to make sure custom passes run. # Otherwise, we can't verify fusion happened through the logs. 
diff --git a/tests/compile/fusions_e2e/models.py b/tests/compile/fusions_e2e/models.py index f54f617c64d4..e18bc1ee5652 100644 --- a/tests/compile/fusions_e2e/models.py +++ b/tests/compile/fusions_e2e/models.py @@ -2,6 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pytest +from vllm._aiter_ops import is_aiter_found_and_supported +from vllm.platforms import current_platform from vllm.utils.flashinfer import has_flashinfer from vllm.v1.attention.backends.registry import AttentionBackendEnum @@ -24,6 +26,24 @@ AttentionBackendCase(backend=AttentionBackendEnum.TRITON_ATTN), id="TRITON_ATTN" ) +ROCM_ATTN = pytest.param( + AttentionBackendCase(backend=AttentionBackendEnum.ROCM_ATTN), + id="ROCM_ATTN", + marks=pytest.mark.skipif( + not current_platform.is_rocm(), + reason="ROCm attention only for AMD", + ), +) + +ROCM_AITER_UNIFIED_ATTN = pytest.param( + AttentionBackendCase(backend=AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN), + id="ROCM_AITER_UNIFIED_ATTN", + marks=pytest.mark.skipif( + not is_aiter_found_and_supported(), + reason="ROCM_AITER_UNIFIED_ATTN only for AMD when AITER is installed", + ), +) + # Models llama3_8b = ModelFusionInfo( model_name="meta-llama/Llama-3.1-8B-Instruct", @@ -49,7 +69,6 @@ llama3_8b_fp4 = ModelFusionInfo( model_name="nvidia/Llama-3.1-8B-Instruct-FP4", matches=lambda n_layers: Matches( - rms_quant_fusion=0, act_quant_fusion=n_layers, attn_quant_fusion=n_layers, ar_rms_fusion=n_layers * 2 + 1, @@ -79,7 +98,6 @@ model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-NVFP4", hf_overrides=lambda n_layers: {"text_config": {"num_hidden_layers": n_layers}}, matches=lambda n_layers: Matches( - rms_quant_fusion=0, attn_quant_fusion=n_layers, ar_rms_fusion=n_layers * 2, sequence_parallel=n_layers * 2, diff --git a/tests/compile/fusions_e2e/test_tp1_quant.py b/tests/compile/fusions_e2e/test_tp1_quant.py index f98400c2e26d..917116515f89 100644 --- a/tests/compile/fusions_e2e/test_tp1_quant.py +++ b/tests/compile/fusions_e2e/test_tp1_quant.py @@ -5,6 +5,7 @@ import pytest from vllm.config import PassConfig +from vllm.platforms import current_platform from vllm.utils.flashinfer import is_flashinfer_fp8_blockscale_gemm_supported from .common import ( @@ -16,6 +17,8 @@ ) from .models import ( FLASHINFER_ATTN, + ROCM_AITER_UNIFIED_ATTN, + ROCM_ATTN, TRITON_ATTN, llama3_8b_fp4, llama3_8b_fp8, @@ -29,12 +32,33 @@ "model_name, matches_fn, model_kwargs, hf_overrides, use_deepgemm", [ (*llama3_8b_fp8, False), - (*llama4_scout_fp8, False), (*qwen3_a3b_fp8, False), - (*qwen3_a3b_fp8, True), + pytest.param( + *llama4_scout_fp8, + False, + marks=pytest.mark.skipif( + not current_platform.is_cuda(), + reason="Llama4 Scout FP8 only supported on CUDA", + ), + ), + pytest.param( + *qwen3_a3b_fp8, + True, + marks=pytest.mark.skipif( + not current_platform.is_cuda(), reason="DeepGemm only supported on CUDA" + ), + ), + ], +) +@pytest.mark.parametrize( + "attn_backend", + [ + TRITON_ATTN, + FLASHINFER_ATTN, + ROCM_ATTN, + ROCM_AITER_UNIFIED_ATTN, ], ) -@pytest.mark.parametrize("attn_backend", [TRITON_ATTN, FLASHINFER_ATTN]) @pytest.mark.parametrize("n_layers", [6]) @pytest.mark.parametrize("custom_ops", custom_ops_combos("quant_fp8", "rms_norm")) @pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION) @@ -81,6 +105,8 @@ def test_tp1_fp8_fusions( ), ) + use_aiter = current_platform.is_rocm() and ("qwen" in model_name.lower()) + matches_check = [ "rms_quant_fusion", "act_quant_fusion", @@ -88,6 +114,15 @@ def test_tp1_fp8_fusions( 
"attn_quant_fusion", ] + if use_aiter: + matches_check[0] = "aiter_rms_quant_fusion" + + matches = matches._replace(aiter_rms_quant_fusion=matches.rms_quant_fusion) + # TODO: enable the `norm_rope_fusion` test, + # On ROCm norm_rope_fusion is only supported without + # enabling AITER. + matches_check.remove("norm_rope_fusion") + run_e2e_fusion_test( model_name, matches, @@ -96,6 +131,7 @@ def test_tp1_fp8_fusions( compilation_config, matches_check, use_deepgemm=use_deepgemm, + use_aiter=use_aiter, ) diff --git a/tests/compile/fusions_e2e/test_tp2_ar_rms.py b/tests/compile/fusions_e2e/test_tp2_ar_rms.py index 18b19565c1fc..ab4aefcaf79a 100644 --- a/tests/compile/fusions_e2e/test_tp2_ar_rms.py +++ b/tests/compile/fusions_e2e/test_tp2_ar_rms.py @@ -5,6 +5,7 @@ import pytest from vllm.config import PassConfig +from vllm.platforms import current_platform from ...utils import multi_gpu_test from .common import ( @@ -26,6 +27,8 @@ qwen3_a3b_fp8, ) +pytestmark = pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA") + @multi_gpu_test(num_gpus=2) @pytest.mark.parametrize( diff --git a/tests/compile/fusions_e2e/test_tp2_async_tp.py b/tests/compile/fusions_e2e/test_tp2_async_tp.py index 921839ea0692..9657d64b88f7 100644 --- a/tests/compile/fusions_e2e/test_tp2_async_tp.py +++ b/tests/compile/fusions_e2e/test_tp2_async_tp.py @@ -5,6 +5,7 @@ import pytest from vllm.config import PassConfig +from vllm.platforms import current_platform from ...utils import multi_gpu_test from .common import ( @@ -23,6 +24,8 @@ qwen3_a3b, ) +pytestmark = pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA") + @multi_gpu_test(num_gpus=2) @pytest.mark.parametrize( diff --git a/tests/compile/passes/distributed/test_sequence_parallelism.py b/tests/compile/passes/distributed/test_sequence_parallelism.py index 78c3cf92a067..a0fe717ba026 100644 --- a/tests/compile/passes/distributed/test_sequence_parallelism.py +++ b/tests/compile/passes/distributed/test_sequence_parallelism.py @@ -36,6 +36,8 @@ from vllm.utils.system_utils import update_environment_variables from vllm.utils.torch_utils import set_random_seed +pytestmark = pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA") + FP8_DTYPE = current_platform.fp8_dtype() prompts = [ "Hello, my name is", diff --git a/tests/compile/passes/test_silu_mul_quant_fusion.py b/tests/compile/passes/test_silu_mul_quant_fusion.py index cc06208ea758..a77b4e6de7bd 100644 --- a/tests/compile/passes/test_silu_mul_quant_fusion.py +++ b/tests/compile/passes/test_silu_mul_quant_fusion.py @@ -182,8 +182,24 @@ def ops_in_model_after(self): "model_class, enable_quant_fp8_custom_op, force_kernel", list(itertools.product([TestSiluMulFp8QuantModel], [True, False], TEST_KERNELS)) + [ - (TestSiluMulNvfp4QuantModel, False, None), - (TestSiluMulGroupFp8QuantModel, False, None), + pytest.param( + TestSiluMulNvfp4QuantModel, + False, + None, + marks=pytest.mark.skipif( + not current_platform.is_cuda(), reason="CUDA only" + ), + ), + # GroupFP8Quant fusion only works with AITER on ROCm. + # and the enable_quant_fp8_custom_op must be True. 
+ pytest.param( + TestSiluMulGroupFp8QuantModel, + True, + None, + marks=pytest.mark.skipif( + not current_platform.is_rocm(), reason="ROCm only" + ), + ), ], ) @pytest.mark.skipif( @@ -201,6 +217,7 @@ def test_fusion_silu_and_mul_quant( enable_silu_mul_custom_op: bool, enable_quant_fp8_custom_op: bool, force_kernel: FP8ScaledMMLinearKernel | None, + monkeypatch: pytest.MonkeyPatch, ): if model_class is TestSiluMulNvfp4QuantModel and not is_nvfp4_supported(): pytest.skip("NVFP4 is not supported on this GPU.") @@ -227,13 +244,16 @@ def test_fusion_silu_and_mul_quant( ), ) - with set_current_vllm_config(config): + with set_current_vllm_config(config), monkeypatch.context() as m: fusion_passes = [ActivationQuantFusionPass(config)] - if IS_AITER_FOUND: + if IS_AITER_FOUND and model_class is TestSiluMulGroupFp8QuantModel: + from vllm._aiter_ops import rocm_aiter_ops from vllm.compilation.passes.fusion.rocm_aiter_fusion import ( RocmAiterSiluMulFp8GroupQuantFusionPass, ) + m.setenv("VLLM_ROCM_USE_AITER", "1") + rocm_aiter_ops.refresh_env_variables() fusion_passes += [RocmAiterSiluMulFp8GroupQuantFusionPass(config)] passes = [NoOpEliminationPass(config), *fusion_passes, PostCleanupPass(config)] diff --git a/vllm/compilation/passes/fusion/rocm_aiter_fusion.py b/vllm/compilation/passes/fusion/rocm_aiter_fusion.py index d8131ce952d2..59c94db5e812 100644 --- a/vllm/compilation/passes/fusion/rocm_aiter_fusion.py +++ b/vllm/compilation/passes/fusion/rocm_aiter_fusion.py @@ -5,7 +5,6 @@ import torch._inductor.pattern_matcher as pm from torch import fx from torch._inductor.pattern_matcher import PatternMatcherPass -from torch._ops import OpOverload import vllm.model_executor.layers.quantization.utils.fp8_utils # noqa: F401 from vllm._aiter_ops import rocm_aiter_ops @@ -15,6 +14,7 @@ GroupShape, QuantKey, ScaleDesc, + kFp8Dynamic128Sym, ) from vllm.platforms import current_platform @@ -312,7 +312,9 @@ def __init__(self, config: VllmConfig) -> None: @VllmInductorPass.time_and_log def __call__(self, graph: fx.Graph) -> None: self.matched_count = self.patterns.apply(graph) - logger.debug("Replaced %s patterns", self.matched_count) + logger.debug( + "%s Replaced %s patterns", self.__class__.__name__, self.matched_count + ) def uuid(self) -> str: fusion_patterns = [ @@ -332,9 +334,11 @@ class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern): FUSED_SILU_MUL_QUANT_OP = rocm_aiter_ops.get_act_mul_fused_fp8_group_quant_op() - def __init__(self, quant_op: OpOverload) -> None: + def __init__(self) -> None: self.silu_and_mul_matcher = MatcherSiluAndMul() - self.quant_op = quant_op + self.quant_matcher = MatcherQuantFP8( + quant_key=kFp8Dynamic128Sym, match_rocm_aiter=True + ) def get_inputs(self) -> list[torch.Tensor]: return [ @@ -346,7 +350,7 @@ def pattern( input: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: at1 = self.silu_and_mul_matcher(input) - at2 = self.quant_op(at1, 128) + at2 = self.quant_matcher(at1) return at2[0], at2[1] def replacement( @@ -370,11 +374,6 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass): https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980 """ - AITER_GROUP_FP8_QUANT_OP = rocm_aiter_ops.get_group_quant_op() - TRITON_GROUP_FP8_QUANT_OP = torch.ops.vllm.triton_per_token_group_quant_fp8.default - - QUANT_OPS = [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP] - @enable_fake_mode def __init__(self, config: VllmConfig) -> None: super().__init__(config) @@ -383,8 +382,7 @@ def __init__(self, config: VllmConfig) -> None: 
pass_name="rocm_aiter_silu_mul_fp8_group_quant_fusion_pass" ) - for quant_op in self.QUANT_OPS: - AiterSiluMulFp8GroupQuantPattern(quant_op).register(self.patterns) + AiterSiluMulFp8GroupQuantPattern().register(self.patterns) self.dump_patterns(config, self.patterns) From 28ef9ba399340ea7013df8cd1c359b07acc0a302 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 3 Mar 2026 10:21:57 -0500 Subject: [PATCH 26/53] [BugFix] Add support for MTP num_speculative_tokens > 1 with sparse MLA (#34552) Signed-off-by: Lucas Wilkinson Signed-off-by: Matthew Bonanni Co-authored-by: Matthew Bonanni --- tests/v1/spec_decode/test_eagle.py | 53 ++--- tests/v1/spec_decode/test_mtp.py | 10 +- .../layers/sparse_attn_indexer.py | 6 + vllm/v1/attention/backends/mla/indexer.py | 140 ++++++++--- vllm/v1/spec_decode/eagle.py | 220 ++++++++---------- vllm/v1/worker/gpu_model_runner.py | 22 +- vllm/v1/worker/utils.py | 2 +- 7 files changed, 258 insertions(+), 195 deletions(-) diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py index 8b180168dffc..cdbbdb13ebe6 100644 --- a/tests/v1/spec_decode/test_eagle.py +++ b/tests/v1/spec_decode/test_eagle.py @@ -476,12 +476,12 @@ def test_set_inputs_first_pass_draft_model(): proposer.max_num_tokens, dtype=torch.bool, device=device ) - # Mock the attn_metadata_builder to avoid needing the full model setup + # Mock draft_attn_groups to avoid needing the full model setup mock_kv_cache_spec = mock.MagicMock() mock_kv_cache_spec.block_size = block_size - mock_builder = mock.MagicMock() - mock_builder.kv_cache_spec = mock_kv_cache_spec - proposer.attn_metadata_builder = mock_builder + mock_attn_group = mock.MagicMock() + mock_attn_group.kv_cache_spec = mock_kv_cache_spec + proposer.draft_attn_groups = [mock_attn_group] # Request 0: query_len=3 (but 1 rejected), Request 1: query_len=2 batch_spec = BatchSpec( @@ -616,12 +616,12 @@ def test_set_inputs_first_pass_parallel_drafting(): proposer.max_num_tokens, dtype=torch.bool, device=device ) - # Mock the attn_metadata_builder + # Mock draft_attn_groups mock_kv_cache_spec = mock.MagicMock() mock_kv_cache_spec.block_size = block_size - mock_builder = mock.MagicMock() - mock_builder.kv_cache_spec = mock_kv_cache_spec - proposer.attn_metadata_builder = mock_builder + mock_attn_group = mock.MagicMock() + mock_attn_group.kv_cache_spec = mock_kv_cache_spec + proposer.draft_attn_groups = [mock_attn_group] # Request 0: query_len=4 (1 rejected), Request 1: query_len=4 (all valid) batch_spec = BatchSpec( @@ -916,7 +916,7 @@ def create_deterministic_logits(token_ids): proposer.model = model_mock # Assign draft attn_layer_names since load_model is not invoked - proposer.attn_layer_names = ["layer.0"] + proposer._draft_attn_layer_names = {"layer.0"} # Create input tensors batch_spec = BatchSpec( @@ -961,20 +961,18 @@ def create_deterministic_logits(token_ids): attn_metadata_builder = attn_metadata_builder_cls( kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config), - layer_names=proposer.attn_layer_names, + layer_names=proposer._draft_attn_layer_names, vllm_config=proposer.vllm_config, device=device, ) - # Mock runner for attention metadata building + # Mock runner and draft_attn_groups for attention metadata building proposer.runner = mock.MagicMock() - proposer.runner.attn_groups.append([mock.MagicMock()]) - proposer.runner.attn_groups[0][ - 0 - ].get_metadata_builder.return_value = attn_metadata_builder - proposer._get_attention_metadata_builder = mock.MagicMock( - return_value=attn_metadata_builder - 
) + mock_attn_group = mock.MagicMock() + mock_attn_group.get_metadata_builder.return_value = attn_metadata_builder + mock_attn_group.layer_names = list(proposer._draft_attn_layer_names) + mock_attn_group.kv_cache_spec = attn_metadata_builder.kv_cache_spec + proposer.draft_attn_groups = [mock_attn_group] result = proposer.propose( target_token_ids=target_token_ids, @@ -1089,7 +1087,7 @@ def create_deterministic_logits(token_ids, k: int): proposer.model = model_mock # Assign draft attn_layer_names since load_model is not invoked - proposer.attn_layer_names = ["layer.0"] + proposer._draft_attn_layer_names = {"layer.0"} # Get the tree attention metadata builder. attn_metadata_builder_cls, _ = try_get_attention_backend( @@ -1097,21 +1095,18 @@ def create_deterministic_logits(token_ids, k: int): ) attn_metadata_builder = attn_metadata_builder_cls( kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config), - layer_names=proposer.attn_layer_names, + layer_names=proposer._draft_attn_layer_names, vllm_config=proposer.vllm_config, device=device, ) - # Mock runner for attention metadata building. + # Mock runner and draft_attn_groups for attention metadata building. proposer.runner = mock.MagicMock() - proposer.runner.attn_groups.append([mock.MagicMock()]) - proposer.runner.attn_groups[0][0].metadata_builders = [attn_metadata_builder] - proposer.runner.attn_groups[0][ - 0 - ].get_metadata_builder.return_value = attn_metadata_builder - proposer._get_attention_metadata_builder = mock.MagicMock( - return_value=attn_metadata_builder - ) + mock_attn_group = mock.MagicMock() + mock_attn_group.get_metadata_builder.return_value = attn_metadata_builder + mock_attn_group.layer_names = list(proposer._draft_attn_layer_names) + mock_attn_group.kv_cache_spec = attn_metadata_builder.kv_cache_spec + proposer.draft_attn_groups = [mock_attn_group] # Setup inputs for the proposer. 
target_token_ids = torch.randint(0, vocab_size, (total_tokens,), device=device) diff --git a/tests/v1/spec_decode/test_mtp.py b/tests/v1/spec_decode/test_mtp.py index 16f4fb0befe6..0a48b0e7b98c 100644 --- a/tests/v1/spec_decode/test_mtp.py +++ b/tests/v1/spec_decode/test_mtp.py @@ -162,7 +162,7 @@ def create_deterministic_logits(batch_size, vocab_size, token_offset): model_mock.compute_logits.side_effect = logits_returns proposer.model = model_mock - proposer.attn_layer_names = ["layer.0"] + proposer._draft_attn_layer_names = {"layer.0"} # Prepare inputs batch_spec = BatchSpec(seq_lens=seq_lens, query_lens=seq_lens) @@ -190,13 +190,17 @@ def create_deterministic_logits(batch_size, vocab_size, token_offset): attn_metadata_builder = attn_metadata_builder_cls( kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config), - layer_names=proposer.attn_layer_names, + layer_names=list(proposer._draft_attn_layer_names), vllm_config=proposer.vllm_config, device=device, ) proposer.runner = mock.MagicMock() - proposer.attn_metadata_builder = attn_metadata_builder + mock_attn_group = mock.MagicMock() + mock_attn_group.get_metadata_builder.return_value = attn_metadata_builder + mock_attn_group.layer_names = list(proposer._draft_attn_layer_names) + mock_attn_group.kv_cache_spec = attn_metadata_builder.kv_cache_spec + proposer.draft_attn_groups = [mock_attn_group] # Run propose result = proposer.propose( diff --git a/vllm/model_executor/layers/sparse_attn_indexer.py b/vllm/model_executor/layers/sparse_attn_indexer.py index f4ce6fca8d56..5383e2f11e19 100644 --- a/vllm/model_executor/layers/sparse_attn_indexer.py +++ b/vllm/model_executor/layers/sparse_attn_indexer.py @@ -79,6 +79,12 @@ def sparse_attn_indexer( has_prefill = attn_metadata.num_prefills > 0 num_decode_tokens = attn_metadata.num_decode_tokens + # During speculative decoding, k may be padded to the CUDA graph batch + # size while slot_mapping only covers actual tokens. Truncate k to avoid + # out-of-bounds reads in the kernel. + num_tokens = slot_mapping.shape[0] + k = k[:num_tokens] + ops.indexer_k_quant_and_cache( k, kv_cache, diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py index 7c81a4359223..e84312970989 100644 --- a/vllm/v1/attention/backends/mla/indexer.py +++ b/vllm/v1/attention/backends/mla/indexer.py @@ -12,6 +12,7 @@ get_paged_mqa_logits_metadata, is_deep_gemm_supported, ) +from vllm.utils.math_utils import cdiv from vllm.utils.platform_utils import num_compute_units from vllm.v1.attention.backend import ( AttentionBackend, @@ -24,6 +25,7 @@ split_decodes_and_prefills, split_prefill_chunks, ) +from vllm.v1.worker.cp_utils import get_total_cp_world_size logger = init_logger(__name__) @@ -214,20 +216,39 @@ def __init__(self, *args, **kwargs): if self.vllm_config.speculative_config else 0 ) - if self.num_speculative_tokens > 1: - raise ValueError( - "Sparse MLA only supports " - "num_speculative_tokens <= 1 because the DeepGEMM " - "fp8_paged_mqa_logits kernel does not support next_n > 2. " - f"Got num_speculative_tokens={self.num_speculative_tokens}." - ) self.reorder_batch_threshold += self.num_speculative_tokens sm_count = num_compute_units(self.device.index) self.num_sms = sm_count self.decode_lens_buffer = torch.empty( - (scheduler_config.max_num_seqs,), dtype=torch.int32, device=self.device + (scheduler_config.max_num_batched_tokens,), + dtype=torch.int32, + device=self.device, + ) + + # Pre-allocated buffers for flattening (spec decode). 
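+ # Allocated once up front and reused by build(), so expanding multi-token decode requests into per-token entries does not allocate new tensors every step.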
+ self.arange_buffer = torch.arange( + scheduler_config.max_num_seqs * (1 + self.num_speculative_tokens), + dtype=torch.int32, + device=self.device, + ) + self.expanded_seq_lens_buffer = torch.zeros( + (scheduler_config.max_num_batched_tokens,), + dtype=torch.int32, + device=self.device, + ) + max_num_blocks_per_req = cdiv( + self.vllm_config.model_config.max_model_len, + self.kv_cache_spec.block_size * get_total_cp_world_size(), + ) + self.expanded_block_table_buffer = torch.zeros( + ( + scheduler_config.max_num_batched_tokens, + max_num_blocks_per_req, + ), + dtype=torch.int32, + device=self.device, ) # See: DeepGMM/csrc/apis/attention.hpp @@ -326,42 +347,97 @@ def build( common_attn_metadata.query_start_loc_cpu[: num_decodes + 1] ) - # Use CPU to avoid GPU sync; breaking async scheduling - requires_padding = (decode_lens_cpu.max() > decode_lens_cpu.min()).item() + seq_lens = common_attn_metadata.seq_lens[:num_decodes] + block_table = common_attn_metadata.block_table_tensor[:num_decodes, ...] - # Decide which top-k kernel to use based on batch size and sequence length - batch_size = num_decodes - _is_large_context = common_attn_metadata.max_seq_len > 8192 + # Padded CUDA graph requests have block_table entries of -1. + # Clamp to 0 to prevent OOB access in the DeepGEMM kernel. + # This is safe because padded requests have seq_lens=0, so the + # kernel produces no meaningful output for those rows. + block_table.clamp_(min=0) - # Decision logic based on micro-benchmark results: - # - large_context_topk wins for batch <= 128 and seq_len > 8K - # - top_k_per_row_decode wins for batch > 128 or seq_len <= 8K - use_large_context_topk = batch_size <= 128 and _is_large_context + max_decode_len = int(decode_lens_cpu.max().item()) + if max_decode_len > 1: + # Flatten multi-token decode requests into single-token + # batch entries, expanding seq_lens and block tables so + # the kernel always sees next_n=1. - next_n = 1 + self.num_speculative_tokens - if next_n > 1: - offsets = torch.arange(next_n, device=self.device, dtype=torch.int32) - else: - offsets = None + # Assume 4 requests with seq_lens [10, 7, 12, 0] (the final req is + # padding) and decode_lens [3, 1, 4, 0] in the below example comments. + # The context lengths are therefore + # [10-3, 7-1, 12-4, 0-0] = [7, 6, 8, 0]. - seq_lens = common_attn_metadata.seq_lens[:num_decodes] + # 3 + 1 + 4 + 0 = 8 + actual_expanded = int(decode_lens_cpu.sum().item()) + + # [7, 6, 8, 0] -> [7, 7, 7, 6, 8, 8, 8, 8] + expanded_base = torch.repeat_interleave( + seq_lens - decode_lens, decode_lens + ) + + # [0, 3, 4, 8] -> [0, 0, 0, 3, 4, 4, 4, 4] + expanded_starts = torch.repeat_interleave( + common_attn_metadata.query_start_loc[:num_decodes], decode_lens + ) + + # [0, 1, 2, 0, 0, 1, 2, 3] + positions_within = ( + self.arange_buffer[:actual_expanded] - expanded_starts + ) + + # [8, 9, 10, 7, 9, 10, 11, 12, ...] where ... is unused buffer space + self.expanded_seq_lens_buffer[:actual_expanded] = ( + expanded_base + positions_within + 1 + ) + self.expanded_seq_lens_buffer[actual_expanded:] = 0 + seq_lens = self.expanded_seq_lens_buffer[:num_decode_tokens] + + # Give each of the flattened entries the same block table row as the + # original request. 
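+ # Continuing the example above: rows [r0, r0, r0, r1, r2, r2, r2, r2], where r_i is request i's block-table row (the padded request contributes no rows).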
+ self.expanded_block_table_buffer[:actual_expanded] = ( + torch.repeat_interleave(block_table, decode_lens, dim=0) + ) + if actual_expanded < num_decode_tokens: + self.expanded_block_table_buffer[ + actual_expanded:num_decode_tokens, 0 + ] = 0 + block_table = self.expanded_block_table_buffer[:num_decode_tokens] + + # All reqs now have decode_len=1 + self.decode_lens_buffer[:num_decode_tokens] = 1 + decode_lens = self.decode_lens_buffer[:num_decode_tokens] + offsets = None + batch_size = num_decode_tokens + else: + next_n = 1 + self.num_speculative_tokens + if next_n > 1: + offsets = torch.arange( + next_n, device=self.device, dtype=torch.int32 + ) + else: + offsets = None + batch_size = num_decodes # DeepGEMM is required for the paged MQA logits on CUDA devices if current_platform.is_cuda() and is_deep_gemm_supported(): self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata( - seq_lens, self.kv_cache_spec.block_size, self.num_sms + seq_lens, + self.kv_cache_spec.block_size, + self.num_sms, ) - block_table = common_attn_metadata.block_table_tensor[:num_decodes, ...] - # Padded CUDA graph requests have block_table entries of -1. - # Clamp to 0 to prevent OOB access in the DeepGEMM kernel. - # This is safe because padded requests have seq_lens=0, so the - # kernel produces no meaningful output for those rows. - block_table.clamp_(min=0) + + # Decide which top-k kernel to use based on batch size and sequence length + # Decision logic based on micro-benchmark results: + # - large_context_topk wins for batch <= 128 and seq_len > 8K + # - top_k_per_row_decode wins for batch > 128 or seq_len <= 8K + _is_large_context = common_attn_metadata.max_seq_len > 8192 + use_large_context_topk = batch_size <= 128 and _is_large_context + decode_metadata = DeepSeekV32IndexerDecodeMetadata( block_table=block_table, - seq_lens=common_attn_metadata.seq_lens[:num_decodes], + seq_lens=seq_lens, decode_lens=decode_lens, - requires_padding=requires_padding, + requires_padding=False, schedule_metadata=self.scheduler_metadata_buffer, use_large_context_topk=use_large_context_topk, offsets=offsets, diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index e53de6a1de4f..ca58c441f46d 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -20,17 +20,13 @@ from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.model_loader import get_model from vllm.model_executor.models import supports_multimodal -from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache from vllm.model_executor.models.interfaces import SupportsMultiModal from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.platforms import current_platform from vllm.triton_utils import triton from vllm.utils.platform_utils import is_pin_memory_available -from vllm.v1.attention.backend import ( - AttentionMetadataBuilder, - CommonAttentionMetadata, -) +from vllm.v1.attention.backend import CommonAttentionMetadata from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.v1.attention.backends.tree_attn import ( TreeAttentionMetadata, @@ -38,7 +34,7 @@ ) from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher -from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.kv_cache_interface import KVCacheConfig, UniformTypeKVCacheSpecs from vllm.v1.sample.metadata import 
SamplingMetadata from vllm.v1.sample.sampler import _SAMPLING_EPS from vllm.v1.spec_decode.metadata import SpecDecodeMetadata @@ -53,6 +49,7 @@ from vllm.v1.utils import CpuGpuBuffer from vllm.v1.worker.dp_utils import coordinate_batch_across_dp from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch +from vllm.v1.worker.utils import AttentionGroup logger = init_logger(__name__) @@ -113,10 +110,8 @@ def __init__( vllm_config.model_config ) - self.attn_metadata_builder: AttentionMetadataBuilder | None = None - self.draft_indexer_metadata_builder: AttentionMetadataBuilder | None = None - self.attn_layer_names: list[str] = [] - self.indexer_layer_names: list[str] = [] + self.draft_attn_groups: list[AttentionGroup] = [] + self.kv_cache_gid: int = -1 self.eagle3_use_aux_hidden_state: bool = ( self._get_eagle3_use_aux_hidden_state_from_config() ) @@ -353,7 +348,7 @@ def _get_slot_mapping( self._slot_mapping_buffer[num_actual:num_tokens].fill_(PADDING_SLOT_ID) view = self._slot_mapping_buffer[:num_tokens] - return {name: view for name in self.attn_layer_names + self.indexer_layer_names} + return {name: view for name in self._draft_attn_layer_names} def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None: """Initialize cudagraph dispatcher keys for eagle. @@ -420,33 +415,13 @@ def propose( assert self.runner is not None - if self.attn_metadata_builder is None: - attn_metadata_builder = self._get_attention_metadata_builder() - else: - attn_metadata_builder = self.attn_metadata_builder - - attn_metadata = attn_metadata_builder.build_for_drafting( - common_attn_metadata=common_attn_metadata, draft_index=0 - ) - # FIXME: support hybrid kv for draft model (remove separate indexer) - if self.draft_indexer_metadata_builder: - draft_indexer_metadata = ( - self.draft_indexer_metadata_builder.build_for_drafting( - common_attn_metadata=common_attn_metadata, - draft_index=0, - ) + per_layer_attn_metadata: dict[str, object] = {} + for attn_group in self.draft_attn_groups: + attn_metadata = attn_group.get_metadata_builder().build_for_drafting( + common_attn_metadata=common_attn_metadata, draft_index=0 ) - else: - draft_indexer_metadata = None - # At this moment, we assume all eagle layers belong to the same KV - # cache group, thus using the same attention metadata. - per_layer_attn_metadata = {} - for layer_name in self.attn_layer_names: - per_layer_attn_metadata[layer_name] = attn_metadata - - for layer_name in self.indexer_layer_names: - assert draft_indexer_metadata is not None - per_layer_attn_metadata[layer_name] = draft_indexer_metadata + for layer_name in attn_group.layer_names: + per_layer_attn_metadata[layer_name] = attn_metadata cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = ( self._determine_batch_execution_and_padding(num_tokens) @@ -503,12 +478,7 @@ def propose( positions = self.mrope_positions[:, token_indices_to_sample] else: positions = self.positions[token_indices_to_sample] - if self.method in ( - "deepseek_mtp", - "ernie_mtp", - "longcat_flash_mtp", - "pangu_ultra_moe_mtp", - ): + if self.method == "mtp": hidden_states = self.hidden_states[token_indices_to_sample] else: hidden_states = hidden_states[token_indices_to_sample] @@ -613,7 +583,8 @@ def propose( common_attn_metadata._num_computed_tokens_cpu += 1 # Compute the slot mapping. 
- block_size = attn_metadata_builder.kv_cache_spec.block_size + # Use the first draft attention group's kv_cache_spec for block_size + block_size = self.draft_attn_groups[0].kv_cache_spec.block_size if self.uses_mrope: # all dimensions of positions are the same block_numbers = clamped_positions[0] // block_size @@ -639,11 +610,13 @@ def propose( ) # Rebuild attention metadata - attn_metadata = attn_metadata_builder.build_for_drafting( # type: ignore - common_attn_metadata=common_attn_metadata, draft_index=token_index + 1 - ) - for layer_name in self.attn_layer_names: - per_layer_attn_metadata[layer_name] = attn_metadata + for attn_group in self.draft_attn_groups: + attn_metadata = attn_group.get_metadata_builder().build_for_drafting( + common_attn_metadata=common_attn_metadata, + draft_index=token_index + 1, + ) + for layer_name in attn_group.layer_names: + per_layer_attn_metadata[layer_name] = attn_metadata # copy inputs to buffer for cudagraph self.input_ids[:batch_size] = input_ids @@ -805,18 +778,17 @@ def set_inputs_first_pass( # 2. # Recompute the slot mapping based on the new positions and # rejection mask. - builder = ( - self._get_attention_metadata_builder() - if self.attn_metadata_builder is None - else self.attn_metadata_builder - ) + # Use the first draft attention group's kv_cache_spec for block_size + # (all draft layers share the same kv-cache group) + assert len(self.draft_attn_groups) > 0 + block_size = self.draft_attn_groups[0].kv_cache_spec.block_size new_slot_mapping = compute_new_slot_mapping( cad=cad, new_positions=self.positions[:total_num_output_tokens], is_rejected_token_mask=self.is_rejected_token_mask[ :total_num_output_tokens ], - block_size=builder.kv_cache_spec.block_size, + block_size=block_size, num_new_tokens=self.net_num_new_slots_per_request, max_model_len=self.max_model_len, ) @@ -1000,9 +972,7 @@ def propose_tree( | list[dict[str, torch.Tensor]] | None = None, ) -> list[torch.Tensor]: - tree_attn_metadata_builder = self.runner.attn_groups[0][ - 0 - ].get_metadata_builder() + tree_attn_metadata_builder = self.draft_attn_groups[0].get_metadata_builder() assert isinstance(tree_attn_metadata_builder, TreeAttentionMetadataBuilder) total_num_drafts = self.cu_drafts_per_level[0] @@ -1078,10 +1048,11 @@ def propose_tree( common_attn_metadata=common_attn_metadata, draft_index=level + 1 ) - # Apply new attention metadata to all layers. + # Apply new attention metadata to all draft layers. per_layer_attn_metadata = {} - for layer_name in self.attn_layer_names: - per_layer_attn_metadata[layer_name] = attn_metadata + for attn_group in self.draft_attn_groups: + for layer_name in attn_group.layer_names: + per_layer_attn_metadata[layer_name] = attn_metadata # Consider max model length. 
attn_metadata.max_seq_len = min( @@ -1288,43 +1259,17 @@ def load_model(self, target_model: nn.Module) -> None: AttentionLayerBase, # type: ignore[type-abstract] ).keys() ) - # FIXME: support hybrid kv for draft model - target_indexer_layer_names = set( - get_layers_from_vllm_config( - self.vllm_config, DeepseekV32IndexerCache - ).keys() - ) self.model = self._get_model() - draft_attn_layer_names = ( - get_layers_from_vllm_config( - self.vllm_config, - AttentionLayerBase, # type: ignore[type-abstract] - ).keys() - - target_attn_layer_names + # Find draft layers (attention layers added by draft model) + all_attn_layers = get_layers_from_vllm_config( + self.vllm_config, + AttentionLayerBase, # type: ignore[type-abstract] ) - indexer_layers = get_layers_from_vllm_config( - self.vllm_config, DeepseekV32IndexerCache + self._draft_attn_layer_names = ( + set(all_attn_layers.keys()) - target_attn_layer_names ) - draft_indexer_layer_names = indexer_layers.keys() - target_indexer_layer_names - self.attn_layer_names = list(draft_attn_layer_names - draft_indexer_layer_names) - self.indexer_layer_names = list(draft_indexer_layer_names) - - if self.indexer_layer_names: - first_layer = self.indexer_layer_names[0] - self.draft_indexer_metadata_builder = ( - indexer_layers[first_layer] - .get_attn_backend() - .get_builder_cls()( - indexer_layers[first_layer].get_kv_cache_spec(self.vllm_config), - self.indexer_layer_names, - self.vllm_config, - self.device, - ) - ) - else: - self.draft_indexer_metadata_builder = None if self.supports_mm_inputs: # Even if the target model is multimodal, we can also use @@ -1562,9 +1507,9 @@ def dummy_run( # Make sure to use EAGLE's own buffer during cudagraph capture. if ( - self.attn_layer_names + self._draft_attn_layer_names and slot_mappings is not None - and self.attn_layer_names[0] in slot_mappings + and next(iter(self._draft_attn_layer_names)) in slot_mappings ): slot_mapping_dict = self._get_slot_mapping(num_input_tokens) else: @@ -1594,31 +1539,6 @@ def dummy_run( kwargs["hidden_states"] = self.hidden_states[:num_input_tokens] self.model(**kwargs) - def _get_attention_metadata_builder(self) -> AttentionMetadataBuilder: - """Find and return the attention metadata builders for EAGLE layers. - - Returns: - The metadata builders for EAGLE layers. - - Raises: - AssertionError: If no metadata builders are found for EAGLE layers. - """ - builder = None - chosen_layer = self.attn_layer_names[0] - - for kv_cache_group in self.runner.attn_groups: - for attn_group in kv_cache_group: - if chosen_layer in attn_group.layer_names: - builder = attn_group.get_metadata_builder() - break - if builder is not None: - break - - assert builder is not None, ( - "Failed to find attention metadata builder for EAGLE layers." - ) - return builder - def _get_eagle3_use_aux_hidden_state_from_config(self) -> bool: """ Some eagle3 heads (e.g., nvidia/gpt-oss-120b-Eagle3-v2) do not use auxiliary @@ -1651,13 +1571,71 @@ def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None: set( [ kv_cache_groups[layer_name] - for layer_name in self.attn_layer_names + for layer_name in self._draft_attn_layer_names ] ) ) == 1 ), "All drafting layers should belong to the same kv cache group" + def initialize_attn_backend( + self, + kv_cache_config: KVCacheConfig, + kernel_block_sizes: list[int] | None = None, + ) -> None: + """ + Initialize AttentionGroups for draft layers using kv_cache_config. + Called from the model runner's initialize_metadata_builders. 
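+ Draft layers that share the same attention backend are grouped into a single AttentionGroup, each with its own metadata builders.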
+ """ + all_attn_layers = get_layers_from_vllm_config( + self.vllm_config, + AttentionLayerBase, # type: ignore[type-abstract] + ) + + # Find which kv_cache_group the draft layers belong to + self.validate_same_kv_cache_group(kv_cache_config) + kv_cache_spec = None + for gid, group in enumerate(kv_cache_config.kv_cache_groups): + if self._draft_attn_layer_names & set(group.layer_names): + self.kv_cache_gid = gid + kv_cache_spec = group.kv_cache_spec + break + + attention_groups: dict[tuple[str, str], AttentionGroup] = {} + if kv_cache_spec is not None: + for layer_name in self._draft_attn_layer_names: + attn_backend = all_attn_layers[layer_name].get_attn_backend() + backend_key = attn_backend.full_cls_name() + if backend_key not in attention_groups: + layer_kv_cache_spec = kv_cache_spec + if isinstance(layer_kv_cache_spec, UniformTypeKVCacheSpecs): + layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[ + layer_name + ] + + kernel_block_size = ( + kernel_block_sizes[self.kv_cache_gid] + if kernel_block_sizes is not None + and self.kv_cache_gid < len(kernel_block_sizes) + else None + ) + attn_group = AttentionGroup( + backend=attn_backend, + layer_names=[layer_name], + kv_cache_spec=layer_kv_cache_spec, + kv_cache_group_id=self.kv_cache_gid, + ) + attn_group.create_metadata_builders( + self.vllm_config, + self.device, + kernel_block_size=kernel_block_size, + ) + attention_groups[backend_key] = attn_group + else: + attention_groups[backend_key].layer_names.append(layer_name) + + self.draft_attn_groups = list(attention_groups.values()) + def _determine_batch_execution_and_padding( self, num_tokens: int, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index c9d9ecf4ac67..8c92aab266e6 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1936,7 +1936,7 @@ def _build_attn_group_metadata( if self.speculative_config and spec_decode_common_attn_metadata is None: if isinstance(self.drafter, EagleProposer): - if self.drafter.attn_layer_names[0] in kv_cache_group.layer_names: + if self.drafter.kv_cache_gid == kv_cache_gid: spec_decode_common_attn_metadata = cm else: spec_decode_common_attn_metadata = cm @@ -5559,6 +5559,14 @@ def initialize_metadata_builders( # because some of them change the threshold at init time. 
self.calculate_reorder_batch_threshold() + # Initialize drafter attention backend + if self.speculative_config and ( + self.speculative_config.use_eagle() + or self.speculative_config.uses_draft_model() + ): + assert isinstance(self.drafter, EagleProposer | DraftModelProposer) + self.drafter.initialize_attn_backend(kv_cache_config, kernel_block_sizes) + def _check_and_update_cudagraph_mode( self, attention_backends: list[set[type[AttentionBackend]]], @@ -6079,15 +6087,11 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: kv_cache_config, kernel_block_sizes ) - if self.speculative_config and ( - self.speculative_config.use_eagle() - or self.speculative_config.uses_draft_model() - or self.speculative_config.uses_extract_hidden_states() + if ( + self.speculative_config + and self.speculative_config.uses_extract_hidden_states() ): - assert isinstance( - self.drafter, - EagleProposer | DraftModelProposer | ExtractHiddenStatesProposer, - ) + assert isinstance(self.drafter, ExtractHiddenStatesProposer) # validate all draft model layers belong to the same kv cache # group self.drafter.validate_same_kv_cache_group(kv_cache_config) diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index 7280679800b4..bede06592f7d 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -48,7 +48,7 @@ def create_metadata_builders( self, vllm_config, device, - kernel_block_size: int | None, + kernel_block_size: int | None = None, num_metadata_builders: int = 1, ): kv_cache_spec_builder = ( From e05cb3b93e5db3afd510189651a128018c31c251 Mon Sep 17 00:00:00 2001 From: ojhaanshika Date: Tue, 3 Mar 2026 08:35:34 -0800 Subject: [PATCH 27/53] TRTLLM gen-full attn Test Coverage (#34986) Signed-off-by: Anshika Ojha Co-authored-by: Anshika Ojha --- .../attention/test_use_trtllm_attention.py | 196 ++++++++++ .../test_trtllm_attention_integration.py | 360 ++++++++++++++++++ 2 files changed, 556 insertions(+) create mode 100644 tests/kernels/attention/test_use_trtllm_attention.py create mode 100644 tests/v1/attention/test_trtllm_attention_integration.py diff --git a/tests/kernels/attention/test_use_trtllm_attention.py b/tests/kernels/attention/test_use_trtllm_attention.py new file mode 100644 index 000000000000..e24ad1018638 --- /dev/null +++ b/tests/kernels/attention/test_use_trtllm_attention.py @@ -0,0 +1,196 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from unittest.mock import patch + +import pytest +import torch + +from vllm.utils.flashinfer import ( + can_use_trtllm_attention, + supports_trtllm_attention, + use_trtllm_attention, +) + +MODEL_CONFIGS = { + "Llama-3-70B": dict(num_qo_heads=64, num_kv_heads=8), + "Llama-3-8B": dict(num_qo_heads=32, num_kv_heads=8), + "Qwen2.5-0.5B": dict(num_qo_heads=14, num_kv_heads=2), + "Mistral-7B": dict(num_qo_heads=32, num_kv_heads=8), + "Gemma-2-9B": dict(num_qo_heads=8, num_kv_heads=4), + "Falcon-40B": dict(num_qo_heads=128, num_kv_heads=8), +} + + +def get_config(model: str) -> dict: + """Return the attention config for a model.""" + return MODEL_CONFIGS[model] + + +DEFAULT_KWARGS = dict( + **get_config("Llama-3-70B"), + num_tokens=128, + max_seq_len=4096, + dcp_world_size=1, + kv_cache_dtype="auto", + q_dtype=torch.bfloat16, + is_prefill=False, + force_use_trtllm=None, + has_sinks=False, + has_spec=False, +) + + +def _call(**overrides) -> bool: + kwargs = {**DEFAULT_KWARGS, **overrides} + return use_trtllm_attention(**kwargs) + + +@pytest.fixture(autouse=True) +def 
_clear_supports_cache(): + """Clear functools.cache to ensure each test runs independently.""" + supports_trtllm_attention.cache_clear() + + +# supports_trtllm_attention + + +@patch("vllm.utils.flashinfer.vllm_is_batch_invariant", return_value=True) +def test_supports_batch_invariant_disables(_mock): + assert supports_trtllm_attention() is False + + +@patch("vllm.utils.flashinfer.vllm_is_batch_invariant", return_value=False) +@patch( + "vllm.utils.flashinfer.current_platform.is_device_capability_family", + return_value=True, +) +@patch("vllm.utils.flashinfer.has_nvidia_artifactory", return_value=True) +def test_supports_sm100_with_artifactory(_art, _cap, _bi): + assert supports_trtllm_attention() is True + + +@patch("vllm.utils.flashinfer.vllm_is_batch_invariant", return_value=False) +@patch( + "vllm.utils.flashinfer.current_platform.is_device_capability_family", + return_value=False, +) +def test_supports_non_sm100_platform(_cap, _bi): + assert supports_trtllm_attention() is False + + +@patch("vllm.utils.flashinfer.vllm_is_batch_invariant", return_value=False) +@patch( + "vllm.utils.flashinfer.current_platform.is_device_capability_family", + return_value=True, +) +@patch("vllm.utils.flashinfer.has_nvidia_artifactory", return_value=False) +def test_supports_sm100_without_artifactory(_art, _cap, _bi): + assert supports_trtllm_attention() is False + + +# can_use_trtllm_attention + + +@patch("vllm.utils.flashinfer.force_use_trtllm_attention", return_value=False) +def test_can_use_force_disabled(_mock): + cfg = get_config("Llama-3-70B") + assert can_use_trtllm_attention(cfg["num_qo_heads"], cfg["num_kv_heads"]) is False + + +@patch("vllm.utils.flashinfer.force_use_trtllm_attention", return_value=None) +@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True) +def test_can_use_compatible_heads(_sup, _force): + cfg = get_config("Llama-3-70B") + assert can_use_trtllm_attention(cfg["num_qo_heads"], cfg["num_kv_heads"]) is True + + +@patch("vllm.utils.flashinfer.force_use_trtllm_attention", return_value=None) +@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True) +def test_can_use_incompatible_heads(_sup, _force): + assert can_use_trtllm_attention(40, 6) is False + + +@pytest.mark.parametrize("model", list(MODEL_CONFIGS.keys())) +@patch("vllm.utils.flashinfer.force_use_trtllm_attention", return_value=None) +@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=False) +def test_can_use_platform_unsupported(_sup, _force, model): + cfg = get_config(model) + assert can_use_trtllm_attention(cfg["num_qo_heads"], cfg["num_kv_heads"]) is False + + +# use_trtllm_attention + + +@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True) +def test_use_force_off(_mock): + assert _call(force_use_trtllm=False) is False + + +@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True) +def test_use_dcp_fallback(_mock): + assert _call(dcp_world_size=2) is False + + +@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=False) +def test_use_platform_unsupported(_mock): + assert _call() is False + + +@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=False) +def test_use_platform_unsupported_force_on_still_false(_mock): + assert _call(force_use_trtllm=True) is False + + +@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True) +def test_use_incompatible_heads(_mock): + assert _call(num_qo_heads=40, num_kv_heads=6) is False + + 
+@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True) +def test_use_incompatible_heads_force_on_still_false(_mock): + assert _call(num_qo_heads=40, num_kv_heads=6, force_use_trtllm=True) is False + + +@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True) +def test_use_spec_decode_enables(_mock): + assert _call(has_spec=True, is_prefill=False) is True + + +@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True) +@patch( + "vllm.utils.flashinfer.current_platform.fp8_dtype", + return_value=torch.float8_e4m3fn, +) +def test_use_fp8_query_forces_trtllm(_fp8, _sup): + assert _call(q_dtype=torch.float8_e4m3fn) is True + + +@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True) +def test_use_sinks_force_trtllm(_mock): + assert _call(has_sinks=True) is True + + +@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True) +def test_use_auto_prefill_kv_auto(_mock): + assert _call(is_prefill=True, kv_cache_dtype="auto") is True + + +@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True) +def test_use_auto_prefill_kv_fp8(_mock): + assert _call(is_prefill=True, kv_cache_dtype="fp8") is False + + +@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True) +def test_use_auto_decode_small_batch(_mock): + assert _call(is_prefill=False, num_tokens=128, kv_cache_dtype="auto") is True + + +@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True) +def test_use_auto_decode_large_batch(_mock): + assert _call(is_prefill=False, num_tokens=512, kv_cache_dtype="auto") is False + + +@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True) +def test_use_force_on(_mock): + assert _call(force_use_trtllm=True) is True diff --git a/tests/v1/attention/test_trtllm_attention_integration.py b/tests/v1/attention/test_trtllm_attention_integration.py new file mode 100644 index 000000000000..50a2c8625313 --- /dev/null +++ b/tests/v1/attention/test_trtllm_attention_integration.py @@ -0,0 +1,360 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Integration tests for TRTLLM gen-full attention through FlashInfer.""" + +import unittest.mock +from functools import partial + +import pytest +import torch +from torch.nn.attention.flex_attention import create_block_mask, flex_attention + +from tests.v1.attention.utils import ( + BatchSpec, + create_common_attn_metadata, + create_vllm_config, +) +from vllm.config import set_current_vllm_config +from vllm.platforms import current_platform +from vllm.utils.math_utils import cdiv +from vllm.utils.torch_utils import set_random_seed +from vllm.v1.attention.backends.utils import ( + PerLayerParameters, + get_kv_cache_layout, + set_kv_cache_layout, +) +from vllm.v1.kv_cache_interface import FullAttentionSpec + +if not current_platform.is_device_capability_family(100): + pytest.skip( + "TRTLLM integration tests require NVIDIA Blackwell (SM100).", + allow_module_level=True, + ) + +from vllm.v1.attention.backends.flashinfer import ( # noqa: E402 + FlashInferImpl, + FlashInferMetadataBuilder, + TRTLLMDecode, + TRTLLMPrefill, +) + + +class MockAttentionLayer: + """Minimal mock of an attention layer for testing.""" + + def __init__(self, device: torch.device): + self._q_scale = torch.tensor(1.0, device=device) + self._k_scale = torch.tensor(1.0, device=device) + self._v_scale = torch.tensor(1.0, device=device) + self._q_scale_float = 1.0 + 
self._k_scale_float = 1.0 + self._v_scale_float = 1.0 + self._o_scale_float = None + + +MODEL = "Qwen/Qwen2.5-0.5B" +BLOCK_SIZE = 16 +NUM_GPU_BLOCKS = 8192 + +BATCH_SPECS = { + "decode_only": BatchSpec( + seq_lens=[128, 256, 512], + query_lens=[1, 1, 1], + ), + "prefill_only": BatchSpec( + seq_lens=[64, 128, 256], + query_lens=[16, 32, 16], + ), + "mixed": BatchSpec( + seq_lens=[128, 256, 512, 128], + query_lens=[1, 1, 8, 16], + ), +} + + +def _mock_get_per_layer_parameters(vllm_config, layer_names, impl_cls): + head_size = vllm_config.model_config.get_head_size() + return { + name: PerLayerParameters( + window_left=-1, + logits_soft_cap=0.0, + sm_scale=1.0 / (head_size**0.5), + ) + for name in layer_names + } + + +def _create_hnd_kv_cache( + k_contexts, + v_contexts, + block_size, + num_kv_heads, + head_size, + dtype, + device, + num_blocks, + common_attn_metadata, +): + """Create and populate a KV cache with HND-compatible strides. + + The returned tensor has logical shape + (num_blocks, 2, block_size, num_kv_heads, head_size) but is physically + laid out as (num_blocks, 2, num_kv_heads, block_size, head_size) so that + ``kv_cache.permute(0, 1, 3, 2, 4)`` yields a contiguous HND view. + """ + seq_lens = common_attn_metadata.seq_lens.cpu() + query_lens = ( + common_attn_metadata.query_start_loc_cpu[1:] + - common_attn_metadata.query_start_loc_cpu[:-1] + ) + block_table = common_attn_metadata.block_table_tensor + slot_mapping = common_attn_metadata.slot_mapping + batch_size = len(k_contexts) + + # Build cache in (2, num_blocks, block_size, num_kv_heads, head_size) + # then convert to HND format (same approach as test_attention_backends.py). + kv_cache_raw = torch.zeros( + 2, + num_blocks, + block_size, + num_kv_heads, + head_size, + dtype=dtype, + device=device, + ) + kv_cache_flat = kv_cache_raw.view(2, -1, num_kv_heads, head_size) + + start_block_idx = 1 + for i in range(batch_size): + k_ctx, v_ctx = k_contexts[i], v_contexts[i] + start = start_block_idx * block_size + end = start + k_ctx.shape[0] + kv_cache_flat[0, start:end] = k_ctx + kv_cache_flat[1, start:end] = v_ctx + start_block_idx += cdiv(int(seq_lens[i]), block_size) + + blocks_end = start_block_idx + + # Randomly permute blocks (starting from block 1; block 0 is null). + perm = torch.randperm(blocks_end - 1) + 1 + inv_perm = torch.zeros(blocks_end, dtype=torch.long, device=device) + inv_perm[1:] = torch.argsort(perm) + 1 + kv_cache_raw[:, 1:blocks_end] = kv_cache_raw[:, perm] + + # Build block table. + start_block_idx = 1 + for i in range(batch_size): + n_blocks = cdiv(int(seq_lens[i]), block_size) + block_table[i, :n_blocks] = inv_perm[ + start_block_idx : start_block_idx + n_blocks + ] + start_block_idx += n_blocks + + # Build slot mapping that is consistent with the block table. + for i in range(batch_size): + ctx_len = int(seq_lens[i]) - int(query_lens[i]) + token_offsets = torch.arange(int(query_lens[i])) + ctx_len + block_indices = token_offsets // block_size + intra_block_offsets = token_offsets % block_size + start = common_attn_metadata.query_start_loc_cpu[i] + end = common_attn_metadata.query_start_loc_cpu[i + 1] + slot_mapping[start:end] = block_table[ + i, block_indices + ] * block_size + intra_block_offsets.to(device) + + # Transpose to FlashInfer logical shape then make HND-strided. 
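+ # The double transpose below keeps the logical (num_blocks, 2, block_size, num_kv_heads, head_size) shape while storing data as (num_blocks, 2, num_kv_heads, block_size, head_size), so downstream permutes see a contiguous HND view.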
+ kv_cache = kv_cache_raw.transpose(0, 1) + kv_cache = kv_cache.transpose(2, 3).contiguous().transpose(2, 3) + return kv_cache + + +def _run_trtllm_integration(batch_spec): + """Run TRTLLM attention through the full FlashInfer pipeline + and compare against an SDPA reference.""" + set_random_seed(42) + device = torch.device("cuda:0") + + vllm_config = create_vllm_config( + model_name=MODEL, + max_model_len=max(batch_spec.seq_lens), + block_size=BLOCK_SIZE, + num_gpu_blocks=NUM_GPU_BLOCKS, + ) + vllm_config.attention_config.use_trtllm_attention = True + + num_q_heads = vllm_config.model_config.get_num_attention_heads( + vllm_config.parallel_config + ) + num_kv_heads = vllm_config.model_config.get_num_kv_heads( + vllm_config.parallel_config + ) + head_size = vllm_config.model_config.get_head_size() + dtype = vllm_config.model_config.dtype + scale = 1.0 / (head_size**0.5) + + # 1. Generate data and compute SDPA reference + all_q, all_k, all_v = [], [], [] + all_sdpa_out = [] + k_contexts, v_contexts = [], [] + + for i in range(batch_spec.batch_size): + s_len = batch_spec.seq_lens[i] + q_len = batch_spec.query_lens[i] + ctx_len = s_len - q_len + + q = torch.randn(q_len, num_q_heads, head_size, dtype=dtype, device=device) + k_full = torch.randn(s_len, num_kv_heads, head_size, dtype=dtype, device=device) + v_full = torch.randn(s_len, num_kv_heads, head_size, dtype=dtype, device=device) + + # SDPA reference (N=1, H, L, D) + q_sdpa = q.unsqueeze(0).transpose(1, 2) + k_sdpa = k_full.unsqueeze(0).transpose(1, 2) + v_sdpa = v_full.unsqueeze(0).transpose(1, 2) + + if num_q_heads != num_kv_heads: + repeats = num_q_heads // num_kv_heads + k_sdpa = k_sdpa.repeat_interleave(repeats, dim=1) + v_sdpa = v_sdpa.repeat_interleave(repeats, dim=1) + + def causal_mask_mod(b, h, q_idx, kv_idx, *, context_len): + return (q_idx + context_len) >= kv_idx + + mask_fn = partial(causal_mask_mod, context_len=ctx_len) + block_mask = create_block_mask( + mask_fn, B=None, H=None, Q_LEN=q_len, KV_LEN=s_len, device=device + ) + sdpa_out = flex_attention( + q_sdpa, + k_sdpa, + v_sdpa, + block_mask=block_mask, + scale=scale, + enable_gqa=True, + ) + all_sdpa_out.append(sdpa_out.transpose(1, 2).squeeze(0)) + + all_q.append(q) + all_k.append(k_full[ctx_len:]) + all_v.append(v_full[ctx_len:]) + k_contexts.append(k_full[:ctx_len]) + v_contexts.append(v_full[:ctx_len]) + + query_vllm = torch.cat(all_q, dim=0) + key_vllm = torch.cat(all_k, dim=0) + value_vllm = torch.cat(all_v, dim=0) + sdpa_output = torch.cat(all_sdpa_out, dim=0) + + common_attn_metadata = create_common_attn_metadata(batch_spec, BLOCK_SIZE, device) + + # 2. Create HND KV cache + kv_cache = _create_hnd_kv_cache( + k_contexts, + v_contexts, + BLOCK_SIZE, + num_kv_heads, + head_size, + dtype, + device, + NUM_GPU_BLOCKS, + common_attn_metadata, + ) + + # 3. 
Run through FlashInfer with TRTLLM enabled + set_kv_cache_layout("HND") + get_kv_cache_layout.cache_clear() + + try: + kv_cache_spec = FullAttentionSpec( + block_size=BLOCK_SIZE, + num_kv_heads=num_kv_heads, + head_size=head_size, + dtype=dtype, + ) + layer_names = ["test_layer_0"] + + with ( + set_current_vllm_config(vllm_config), + unittest.mock.patch( + "vllm.utils.flashinfer.supports_trtllm_attention", + return_value=True, + ), + unittest.mock.patch( + "vllm.v1.attention.backends.flashinfer.get_per_layer_parameters", + _mock_get_per_layer_parameters, + ), + ): + builder = FlashInferMetadataBuilder( + kv_cache_spec, layer_names, vllm_config, device + ) + attn_metadata = builder.build( + common_prefix_len=0, + common_attn_metadata=common_attn_metadata, + ) + + # Verify the correct TRTLLM metadata types were produced. + has_prefills = any(ql > 1 for ql in batch_spec.query_lens) + has_decodes = any(ql == 1 for ql in batch_spec.query_lens) + + if has_prefills: + assert isinstance(attn_metadata.prefill, TRTLLMPrefill), ( + f"Expected TRTLLMPrefill, got {type(attn_metadata.prefill)}" + ) + if has_decodes: + assert isinstance(attn_metadata.decode, TRTLLMDecode), ( + f"Expected TRTLLMDecode, got {type(attn_metadata.decode)}" + ) + + impl = FlashInferImpl( + num_heads=num_q_heads, + head_size=head_size, + scale=scale, + num_kv_heads=num_kv_heads, + alibi_slopes=None, + sliding_window=None, + kv_cache_dtype="auto", + ) + + mock_layer = MockAttentionLayer(device) + output = torch.empty_like(query_vllm) + + impl.do_kv_cache_update( + mock_layer, + key_vllm, + value_vllm, + kv_cache, + attn_metadata.slot_mapping, + ) + + output = impl.forward( + mock_layer, + query_vllm, + key_vllm, + value_vllm, + kv_cache, + attn_metadata, + output=output, + ) + + # 4. Compare against SDPA reference + torch.testing.assert_close( + output, + sdpa_output, + atol=1e-2, + rtol=1e-2, + ) + + finally: + set_kv_cache_layout(None) + get_kv_cache_layout.cache_clear() + + +@pytest.mark.parametrize( + "batch_spec_name", + list(BATCH_SPECS.keys()), +) +@torch.inference_mode() +def test_trtllm_gen_full_attention_integration(batch_spec_name: str): + """Test TRTLLM gen-full attention through the full FlashInfer + MetadataBuilder.build() -> FlashInferImpl.forward() pipeline, + with real TRTLLM kernels on Blackwell.""" + _run_trtllm_integration(BATCH_SPECS[batch_spec_name]) From ae88468bcc88773d548122dc05f041a1b3670745 Mon Sep 17 00:00:00 2001 From: JasonCohere Date: Tue, 3 Mar 2026 16:47:39 +0000 Subject: [PATCH 28/53] fix: Ensure invalid audio files return 400 error (#34715) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Jason Ozuzu Co-authored-by: Nicolò Lucchesi --- .../test_transcription_validation_whisper.py | 17 +++++++++++++++ .../openai/speech_to_text/speech_to_text.py | 21 ++++++++++++++++--- 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/tests/entrypoints/openai/test_transcription_validation_whisper.py b/tests/entrypoints/openai/test_transcription_validation_whisper.py index 2d5468c87c5e..cbee032a7ae7 100644 --- a/tests/entrypoints/openai/test_transcription_validation_whisper.py +++ b/tests/entrypoints/openai/test_transcription_validation_whisper.py @@ -108,6 +108,23 @@ async def test_long_audio_request(mary_had_lamb, whisper_client): assert out_usage["seconds"] == 161, out_usage["seconds"] +@pytest.mark.asyncio +async def test_invalid_audio_file(whisper_client): + """Corrupted audio should surface as HTTP 400.""" + invalid_audio = 
io.BytesIO(b"not a valid audio file") + invalid_audio.name = "invalid.wav" + + with pytest.raises(openai.BadRequestError) as exc_info: + await whisper_client.audio.transcriptions.create( + model=MODEL_NAME, + file=invalid_audio, + language="en", + ) + + assert exc_info.value.status_code == 400 + assert "Invalid or unsupported audio file" in exc_info.value.message + + @pytest.mark.asyncio async def test_completion_endpoints(whisper_client): # text to text model diff --git a/vllm/entrypoints/openai/speech_to_text/speech_to_text.py b/vllm/entrypoints/openai/speech_to_text/speech_to_text.py index 966e6d457162..1c56f092029d 100644 --- a/vllm/entrypoints/openai/speech_to_text/speech_to_text.py +++ b/vllm/entrypoints/openai/speech_to_text/speech_to_text.py @@ -11,6 +11,7 @@ import numpy as np from fastapi import Request +from soundfile import LibsndfileError from transformers import PreTrainedTokenizerBase import vllm.envs as envs @@ -57,6 +58,14 @@ except ImportError: librosa = PlaceholderModule("librosa") # type: ignore[assignment] +# Public libsndfile error codes exposed via `soundfile.LibsndfileError.code`, soundfile +# being librosa's main backend. Used to validate if an audio loading error is due to a +# server error vs a client error (invalid audio file). +# 1 = unrecognised format (file is not a supported audio container) +# 3 = malformed file (corrupt or structurally invalid audio) +# 4 = unsupported encoding (codec not supported by this libsndfile build) +_BAD_SF_CODES = {1, 3, 4} + SpeechToTextResponse: TypeAlias = TranscriptionResponse | TranslationResponse SpeechToTextResponseVerbose: TypeAlias = ( TranscriptionResponseVerbose | TranslationResponseVerbose @@ -315,9 +324,15 @@ async def _preprocess_speech_to_text( ) with io.BytesIO(audio_data) as bytes_: - # NOTE resample to model SR here for efficiency. This is also a - # pre-requisite for chunking, as it assumes Whisper SR. - y, sr = librosa.load(bytes_, sr=self.asr_config.sample_rate) + try: + # NOTE resample to model SR here for efficiency. This is also a + # pre-requisite for chunking, as it assumes Whisper SR. 
+ y, sr = librosa.load(bytes_, sr=self.asr_config.sample_rate) + except LibsndfileError as exc: + # Distinguish client errors (invalid audio) from server errors + if exc.code in _BAD_SF_CODES: + raise ValueError("Invalid or unsupported audio file.") from exc + raise duration = librosa.get_duration(y=y, sr=sr) do_split_audio = ( From 8e1fd5baf0ff272936618bf578533d9aa7080a27 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Tue, 3 Mar 2026 12:26:44 -0500 Subject: [PATCH 29/53] [CI] Bump `num_speculative_tokens` to 3 in nightly DeepSeek tests (#35882) Signed-off-by: Matthew Bonanni --- tests/evals/gsm8k/configs/DeepSeek-R1-DP.yaml | 2 +- tests/evals/gsm8k/configs/DeepSeek-R1-TP.yaml | 2 +- tests/evals/gsm8k/configs/DeepSeek-V3.2-DP.yaml | 2 +- tests/evals/gsm8k/configs/DeepSeek-V3.2-TP.yaml | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/evals/gsm8k/configs/DeepSeek-R1-DP.yaml b/tests/evals/gsm8k/configs/DeepSeek-R1-DP.yaml index f351a1722064..0c6a598a8a90 100644 --- a/tests/evals/gsm8k/configs/DeepSeek-R1-DP.yaml +++ b/tests/evals/gsm8k/configs/DeepSeek-R1-DP.yaml @@ -8,4 +8,4 @@ server_args: >- --max-model-len 4096 --data-parallel-size 8 --enable-expert-parallel - --speculative-config '{"method":"mtp","num_speculative_tokens":1}' + --speculative-config '{"method":"mtp","num_speculative_tokens":3}' diff --git a/tests/evals/gsm8k/configs/DeepSeek-R1-TP.yaml b/tests/evals/gsm8k/configs/DeepSeek-R1-TP.yaml index ba3463463b5e..f6ab81008588 100644 --- a/tests/evals/gsm8k/configs/DeepSeek-R1-TP.yaml +++ b/tests/evals/gsm8k/configs/DeepSeek-R1-TP.yaml @@ -8,4 +8,4 @@ server_args: >- --max-model-len 4096 --tensor-parallel-size 8 --enable-expert-parallel - --speculative-config '{"method":"mtp","num_speculative_tokens":1}' + --speculative-config '{"method":"mtp","num_speculative_tokens":3}' diff --git a/tests/evals/gsm8k/configs/DeepSeek-V3.2-DP.yaml b/tests/evals/gsm8k/configs/DeepSeek-V3.2-DP.yaml index d7d1df974aab..c0e2e8f044be 100644 --- a/tests/evals/gsm8k/configs/DeepSeek-V3.2-DP.yaml +++ b/tests/evals/gsm8k/configs/DeepSeek-V3.2-DP.yaml @@ -8,4 +8,4 @@ server_args: >- --max-model-len 4096 --data-parallel-size 8 --enable-expert-parallel - --speculative-config '{"method":"mtp","num_speculative_tokens":1}' + --speculative-config '{"method":"mtp","num_speculative_tokens":3}' diff --git a/tests/evals/gsm8k/configs/DeepSeek-V3.2-TP.yaml b/tests/evals/gsm8k/configs/DeepSeek-V3.2-TP.yaml index 83687594d415..d31c63b8d764 100644 --- a/tests/evals/gsm8k/configs/DeepSeek-V3.2-TP.yaml +++ b/tests/evals/gsm8k/configs/DeepSeek-V3.2-TP.yaml @@ -8,4 +8,4 @@ server_args: >- --max-model-len 4096 --tensor-parallel-size 8 --enable-expert-parallel - --speculative-config '{"method":"mtp","num_speculative_tokens":1}' + --speculative-config '{"method":"mtp","num_speculative_tokens":3}' From 881a6b011b76bddf159b1a635586064e34e221b0 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> Date: Tue, 3 Mar 2026 13:36:15 -0500 Subject: [PATCH 30/53] [CI] Temporarily Disable Llama4 MoE Refactor Test (#35870) Signed-off-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> --- tests/evals/gsm8k/configs/moe-refactor/config-h100.txt | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/evals/gsm8k/configs/moe-refactor/config-h100.txt b/tests/evals/gsm8k/configs/moe-refactor/config-h100.txt index 563d5d42cd0f..7397fc4e4626 100644 --- a/tests/evals/gsm8k/configs/moe-refactor/config-h100.txt +++ 
b/tests/evals/gsm8k/configs/moe-refactor/config-h100.txt @@ -8,8 +8,5 @@ Qwen3-30B-A3B-Fp8-CT-Block-marlin.yaml Qwen3-30B-A3B-Fp8-CT-Block-triton.yaml Qwen3-30B-A3B-Fp8-CT-Channel-marlin.yaml Qwen3-30B-A3B-Fp8-CT-Channel-vllm-cutlass.yaml -Llama-4-Scout-Fp8-ModelOpt-fi-cutlass.yaml -Llama-4-Scout-Fp8-ModelOpt-marlin.yaml -Llama-4-Scout-Fp8-ModelOpt-triton.yaml Qwen3-30B-A3B-BF16-fi-cutlass.yaml Qwen3-30B-A3B-BF16-triton.yaml From 97995f6376fd3dae7a67624055ddf038233e181e Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> Date: Tue, 3 Mar 2026 13:39:50 -0500 Subject: [PATCH 31/53] [MoE Refactor] Create MK for TRTLLM Kernels (#32564) Signed-off-by: Robert Shaw Signed-off-by: Robert Shaw Signed-off-by: Robert Shaw Co-authored-by: Robert Shaw Co-authored-by: Robert Shaw --- .buildkite/test_areas/kernels.yaml | 3 +- .../kernels/benchmark_cutlass_moe_fp8.py | 28 +- .../kernels/benchmark_cutlass_moe_nvfp4.py | 35 +- .../kernels/benchmark_grouped_gemm_cutlass.py | 50 +- benchmarks/kernels/benchmark_moe.py | 54 +- docs/design/dbo.md | 2 +- docs/design/fused_moe_modular_kernel.md | 104 +-- docs/design/moe_kernel_features.md | 16 +- .../moe/modular_kernel_tools/cli_args.py | 4 +- .../moe/modular_kernel_tools/common.py | 9 +- .../moe/modular_kernel_tools/mk_objects.py | 26 +- .../profile_modular_kernel.py | 2 +- tests/kernels/moe/test_batched_deepgemm.py | 17 +- tests/kernels/moe/test_block_fp8.py | 37 +- tests/kernels/moe/test_cutlass_moe.py | 58 +- tests/kernels/moe/test_deepep_deepgemm_moe.py | 18 +- tests/kernels/moe/test_deepep_moe.py | 10 +- tests/kernels/moe/test_deepgemm.py | 27 +- tests/kernels/moe/test_flashinfer.py | 70 +- tests/kernels/moe/test_flashinfer_moe.py | 22 +- .../moe/test_marlin_vs_trtllm_mxint4.py | 21 +- .../moe/test_modular_kernel_combinations.py | 5 +- .../moe/test_modular_oai_triton_moe.py | 28 +- tests/kernels/moe/test_moe.py | 8 +- tests/kernels/moe/test_nvfp4_moe.py | 32 +- tests/kernels/moe/utils.py | 54 +- tests/quantization/test_blackwell_moe.py | 80 +++ vllm/lora/layers/fused_moe.py | 19 +- .../layers/fused_moe/__init__.py | 8 +- .../layers/fused_moe/all2all_utils.py | 15 +- .../layers/fused_moe/batched_deep_gemm_moe.py | 2 +- .../model_executor/layers/fused_moe/config.py | 6 + .../layers/fused_moe/cutlass_moe.py | 15 +- .../layers/fused_moe/deep_gemm_moe.py | 2 +- .../fused_moe/deepep_ht_prepare_finalize.py | 3 +- .../fused_moe/deepep_ll_prepare_finalize.py | 4 +- .../layers/fused_moe/experts/__init__.py | 0 .../fused_moe/experts/trtllm_fp8_moe.py | 335 +++++++++ .../fused_moe/experts/trtllm_nvfp4_moe.py | 326 +++++++++ .../layers/fused_moe/fallback.py | 12 +- .../flashinfer_a2a_prepare_finalize.py | 6 +- .../fused_moe/flashinfer_cutedsl_moe.py | 2 +- .../fused_moe/flashinfer_cutlass_moe.py | 2 +- .../layers/fused_moe/flashinfer_trtllm_moe.py | 298 -------- .../layers/fused_moe/fused_batched_moe.py | 6 +- .../layers/fused_moe/fused_marlin_moe.py | 2 +- .../layers/fused_moe/fused_moe.py | 4 +- .../layers/fused_moe/fused_moe_method_base.py | 51 +- .../fused_moe/fused_moe_modular_method.py | 18 +- .../fused_moe/gpt_oss_triton_kernels_moe.py | 2 +- .../layers/fused_moe/modular_kernel.py | 674 ++++++++++++++---- .../layers/fused_moe/mori_prepare_finalize.py | 2 +- .../layers/fused_moe/oracle/fp8.py | 130 ++-- .../layers/fused_moe/oracle/nvfp4.py | 136 ++-- .../layers/fused_moe/oracle/unquantized.py | 20 +- .../layers/fused_moe/prepare_finalize.py | 209 ------ .../fused_moe/prepare_finalize/__init__.py | 22 + 
.../fused_moe/prepare_finalize/naive_dp_ep.py | 253 +++++++ .../fused_moe/prepare_finalize/no_dp_ep.py | 141 ++++ .../layers/fused_moe/rocm_aiter_fused_moe.py | 2 +- .../layers/fused_moe/router/base_router.py | 2 +- .../fused_moe/runner/default_moe_runner.py | 69 +- .../fused_moe/topk_weight_and_reduce.py | 4 +- .../layers/fused_moe/triton_cutlass_moe.py | 6 +- .../layers/fused_moe/triton_deep_gemm_moe.py | 6 +- .../layers/fused_moe/trtllm_moe.py | 2 +- .../fused_moe/unquantized_fused_moe_method.py | 14 +- .../layers/fused_moe/xpu_fused_moe.py | 2 +- .../compressed_tensors_moe.py | 239 ++----- .../model_executor/layers/quantization/fp8.py | 97 +-- .../layers/quantization/modelopt.py | 229 ++---- .../layers/quantization/mxfp4.py | 14 +- .../quantization/utils/flashinfer_fp4_moe.py | 296 +------- .../quantization/utils/flashinfer_utils.py | 119 +--- .../model_executor/warmup/deep_gemm_warmup.py | 2 +- vllm/model_executor/warmup/kernel_warmup.py | 11 +- vllm/utils/flashinfer.py | 1 + 77 files changed, 2574 insertions(+), 2086 deletions(-) create mode 100644 vllm/model_executor/layers/fused_moe/experts/__init__.py create mode 100644 vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py create mode 100644 vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py delete mode 100644 vllm/model_executor/layers/fused_moe/prepare_finalize.py create mode 100644 vllm/model_executor/layers/fused_moe/prepare_finalize/__init__.py create mode 100644 vllm/model_executor/layers/fused_moe/prepare_finalize/naive_dp_ep.py create mode 100644 vllm/model_executor/layers/fused_moe/prepare_finalize/no_dp_ep.py diff --git a/.buildkite/test_areas/kernels.yaml b/.buildkite/test_areas/kernels.yaml index e1ecfeb8415f..566f4f222888 100644 --- a/.buildkite/test_areas/kernels.yaml +++ b/.buildkite/test_areas/kernels.yaml @@ -44,7 +44,8 @@ steps: - vllm/envs.py - vllm/config commands: - - pytest -v -s kernels/moe --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT + - pytest -v -s kernels/moe --ignore=kernels/moe/test_modular_oai_triton_moe.py --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT + - pytest -v -s kernels/moe/test_modular_oai_triton_moe.py --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT parallelism: 2 - label: Kernels Mamba Test diff --git a/benchmarks/kernels/benchmark_cutlass_moe_fp8.py b/benchmarks/kernels/benchmark_cutlass_moe_fp8.py index b33282523db5..bd116e36a716 100644 --- a/benchmarks/kernels/benchmark_cutlass_moe_fp8.py +++ b/benchmarks/kernels/benchmark_cutlass_moe_fp8.py @@ -12,12 +12,12 @@ from tests.kernels.moe.utils import make_dummy_moe_config from vllm import _custom_ops as ops from vllm.model_executor.layers.fused_moe.activation import MoEActivation +from vllm.model_executor.layers.fused_moe.all2all_utils import ( + maybe_make_prepare_finalize, +) from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8 from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk -from vllm.model_executor.layers.fused_moe.prepare_finalize import ( - MoEPrepareAndFinalizeNoEP, -) from vllm.platforms import current_platform from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.v1.worker.workspace import init_workspace_manager @@ -137,15 +137,21 @@ def bench_run( per_out_ch_quant=per_out_ch, ) - fn = mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP(), + 
moe_config = make_dummy_moe_config( + num_experts=num_experts, + hidden_dim=k, + intermediate_size_per_partition=n, + in_dtype=a.dtype, + ) + fn = mk.FusedMoEKernel( + maybe_make_prepare_finalize( + moe=moe_config, + quant_config=quant_config, + allow_new_interface=True, + use_monolithic=False, + ), CutlassExpertsFp8( - moe_config=make_dummy_moe_config( - num_experts=num_experts, - hidden_dim=k, - intermediate_size_per_partition=n, - in_dtype=a.dtype, - ), + moe_config=moe_config, quant_config=quant_config, ), ) diff --git a/benchmarks/kernels/benchmark_cutlass_moe_nvfp4.py b/benchmarks/kernels/benchmark_cutlass_moe_nvfp4.py index c1f4f0aa9fce..cfb1489dadf2 100644 --- a/benchmarks/kernels/benchmark_cutlass_moe_nvfp4.py +++ b/benchmarks/kernels/benchmark_cutlass_moe_nvfp4.py @@ -15,6 +15,9 @@ from tests.kernels.moe.utils import make_dummy_moe_config from vllm import _custom_ops as ops from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.all2all_utils import ( + maybe_make_prepare_finalize, +) from vllm.model_executor.layers.fused_moe.config import ( fp8_w8a8_moe_quant_config, nvfp4_moe_quant_config, @@ -23,9 +26,6 @@ CutlassExpertsFp4, ) from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk -from vllm.model_executor.layers.fused_moe.prepare_finalize import ( - MoEPrepareAndFinalizeNoEP, -) from vllm.scalar_type import scalar_types from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.v1.worker.workspace import init_workspace_manager @@ -196,10 +196,21 @@ def run_cutlass_moe_fp4( g2_alphas=w2_gs, ) - kernel = mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP(), + moe_config = make_dummy_moe_config( + num_experts=num_experts, + hidden_dim=k, + intermediate_size_per_partition=n, + in_dtype=a.dtype, + ) + kernel = mk.FusedMoEKernel( + maybe_make_prepare_finalize( + moe=moe_config, + quant_config=quant_config, + allow_new_interface=True, + use_monolithic=False, + ), CutlassExpertsFp4( - make_dummy_moe_config(), + moe_config=moe_config, quant_config=quant_config, ), ) @@ -240,11 +251,17 @@ def run_cutlass_from_graph( g1_alphas=w1_gs, g2_alphas=w2_gs, ) + moe_config = make_dummy_moe_config() - kernel = mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP(), + kernel = mk.FusedMoEKernel( + maybe_make_prepare_finalize( + moe=moe_config, + quant_config=quant_config, + allow_new_interface=True, + use_monolithic=False, + ), CutlassExpertsFp4( - make_dummy_moe_config(), + moe_config=moe_config, quant_config=quant_config, ), ) diff --git a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py index 7b5daa62eb34..60ec94b878ce 100644 --- a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py +++ b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py @@ -9,15 +9,15 @@ from tests.kernels.moe.utils import make_dummy_moe_config from vllm import _custom_ops as ops from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.all2all_utils import ( + maybe_make_prepare_finalize, +) from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8 from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_experts, fused_topk, ) -from vllm.model_executor.layers.fused_moe.prepare_finalize import ( - MoEPrepareAndFinalizeNoEP, -) from vllm.utils.argparse_utils import FlexibleArgumentParser 
from vllm.v1.worker.workspace import init_workspace_manager @@ -131,16 +131,22 @@ def run_cutlass_moe( w2_scale=w2_scale, per_act_token_quant=per_act_token, ) + moe_config = make_dummy_moe_config( + num_experts=w2.shape[0], + hidden_dim=w2.shape[1], + intermediate_size_per_partition=w2.shape[2], + in_dtype=a.dtype, + ) - fn = mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP(), + fn = mk.FusedMoEKernel( + maybe_make_prepare_finalize( + moe=moe_config, + quant_config=quant_config, + allow_new_interface=True, + use_monolithic=False, + ), CutlassExpertsFp8( - moe_config=make_dummy_moe_config( - num_experts=w2.shape[0], - hidden_dim=w2.shape[1], - intermediate_size_per_partition=w2.shape[2], - in_dtype=a.dtype, - ), + moe_config=moe_config, quant_config=quant_config, ), ) @@ -163,16 +169,22 @@ def run_cutlass_from_graph( w2_scale=w2_scale, per_act_token_quant=per_act_token, ) + moe_config = make_dummy_moe_config( + num_experts=w2.shape[0], + hidden_dim=w2.shape[1], + intermediate_size_per_partition=w2.shape[2], + in_dtype=a.dtype, + ) - fn = mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP(), + fn = mk.FusedMoEKernel( + maybe_make_prepare_finalize( + moe=moe_config, + quant_config=quant_config, + allow_new_interface=True, + use_monolithic=False, + ), CutlassExpertsFp8( - moe_config=make_dummy_moe_config( - num_experts=w2.shape[0], - hidden_dim=w2.shape[1], - intermediate_size_per_partition=w2.shape[2], - in_dtype=a.dtype, - ), + moe_config=moe_config, quant_config=quant_config, ), ) diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index e086a109f394..4abeaefd774a 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -17,6 +17,9 @@ from vllm.model_executor.layers.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.activation import MoEActivation +from vllm.model_executor.layers.fused_moe.all2all_utils import ( + maybe_make_prepare_finalize, +) from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEParallelConfig, @@ -242,24 +245,33 @@ def run(): deep_gemm_experts = None if use_deep_gemm: - deep_gemm_experts = mk.FusedMoEModularKernel( - prepare_finalize=MoEPrepareAndFinalizeNoEP(), + moe_config = ( + FusedMoEConfig( + num_experts=num_experts, + experts_per_token=topk, + hidden_dim=hidden_size, + intermediate_size_per_partition=shard_intermediate_size, + num_local_experts=num_experts, + num_logical_experts=num_experts, + activation=MoEActivation.SILU, + moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(), + in_dtype=init_dtype, + routing_method=RoutingMethodType.TopK, + device="cuda", + ), + ) + deep_gemm_experts = mk.FusedMoEKernel( + prepare_finalize=maybe_make_prepare_finalize( + moe=moe_config, + quant_config=quant_config, + allow_new_interface=True, + use_monolithic=False, + ), fused_experts=TritonOrDeepGemmExperts( - moe_config=FusedMoEConfig( - num_experts=num_experts, - experts_per_token=topk, - hidden_dim=hidden_size, - intermediate_size_per_partition=shard_intermediate_size, - num_local_experts=num_experts, - num_logical_experts=num_experts, - activation=MoEActivation.SILU, - moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(), - in_dtype=init_dtype, - routing_method=RoutingMethodType.TopK, - device="cuda", - ), + moe_config=moe_config, quant_config=quant_config, ), + inplace=not disable_inplace(), ) with override_config(config): @@ -269,8 +281,16 @@ def run(): inplace = not disable_inplace() if use_deep_gemm: - return 
deep_gemm_experts( - x, w1, w2, topk_weights, topk_ids, inplace=inplace + return deep_gemm_experts.apply( + x, + w1, + w2, + topk_weights, + topk_ids, + activation=MoEActivation.SILU, + global_num_experts=num_experts, + apply_router_weight_on_input=False, + expert_map=False, ) return fused_experts( x, diff --git a/docs/design/dbo.md b/docs/design/dbo.md index f2d98ccd063f..43b3ce0bb5a7 100644 --- a/docs/design/dbo.md +++ b/docs/design/dbo.md @@ -81,7 +81,7 @@ The current implementation has all `dbo_yield` and `dbo_maybe_run_recv_hook` cal The `make_ubatch_context` function initializes two `UBatchContexts`, one for each UBatch thread. It takes two CUDA streams, the preexisting `ForwardContexts` and a CPU thread barrier. This function should be used exclusively to instantiate `UBatchContexts`. It will handle all of the event initialization. -The `dbo_register_recv_hook` method registers a callback that can be returned by the `FusedMoEPrepareAndFinalize` class in the other UBatch thread’s `UBatchContext`. The callback will be run when the other thread calls `dbo_maybe_run_recv_hook`. This is typically used to wait on an all-to-all kernel. +The `dbo_register_recv_hook` method registers a callback that can be returned by the `FusedMoEPrepareAndFinalizeModular` class in the other UBatch thread’s `UBatchContext`. The callback will be run when the other thread calls `dbo_maybe_run_recv_hook`. This is typically used to wait on an all-to-all kernel. The `dbo_maybe_run_recv_hook` method runs a callback that’s set by the `dbo_register_recv_hook` function if that callback exists. diff --git a/docs/design/fused_moe_modular_kernel.md b/docs/design/fused_moe_modular_kernel.md index 9db356cdf531..7f356262bb2d 100644 --- a/docs/design/fused_moe_modular_kernel.md +++ b/docs/design/fused_moe_modular_kernel.md @@ -37,31 +37,31 @@ The rest of the document will focus on the Contiguous / Non-Batched case. Extrap FusedMoEModularKernel splits the FusedMoE operation into 3 parts, 1. TopKWeightAndReduce -2. FusedMoEPrepareAndFinalize -3. FusedMoEPermuteExpertsUnpermute +2. FusedMoEPrepareAndFinalizeModular +3. FusedMoEExpertsModular ### TopKWeightAndReduce -The TopK Weight Application and Reduction components happen right after the Unpermute operation and before the All2All Combine. Note that the `FusedMoEPermuteExpertsUnpermute` is responsible for the Unpermute and `FusedMoEPrepareAndFinalize` is responsible for the All2All Combine. There is value in doing the TopK Weight Application and Reduction in the `FusedMoEPermuteExpertsUnpermute`. But some implementations choose to do it `FusedMoEPrepareAndFinalize`. In order to enable this flexibility, we have a TopKWeightAndReduce abstract class. +The TopK Weight Application and Reduction components happen right after the Unpermute operation and before the All2All Combine. Note that the `FusedMoEExpertsModular` is responsible for the Unpermute and `FusedMoEPrepareAndFinalizeModular` is responsible for the All2All Combine. There is value in doing the TopK Weight Application and Reduction in the `FusedMoEExpertsModular`. But some implementations choose to do it `FusedMoEPrepareAndFinalizeModular`. In order to enable this flexibility, we have a TopKWeightAndReduce abstract class. Please find the implementations of TopKWeightAndReduce [here](../../vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py). -`FusedMoEPrepareAndFinalize::finalize()` method accepts a `TopKWeightAndReduce` argument that is invoked inside the method. 
-The `FusedMoEModularKernel` acts as a bridge between the `FusedMoEPermuteExpertsUnpermute` and `FusedMoEPerpareAndFinalize` implementations to determine where the TopK Weight Application and Reduction happens.
+`FusedMoEPrepareAndFinalizeModular::finalize()` method accepts a `TopKWeightAndReduce` argument that is invoked inside the method.
+The `FusedMoEModularKernel` acts as a bridge between the `FusedMoEExpertsModular` and `FusedMoEPrepareAndFinalizeModular` implementations to determine where the TopK Weight Application and Reduction happens.

-* `FusedMoEPermuteExpertsUnpermute::finalize_weight_and_reduce_impl` method returns `TopKWeightAndReduceNoOp` if the `FusedMoEPermuteExpertsUnpermute` implementation does the weight application and reduction itself.
-* `FusedMoEPermuteExpertsUnpermute::finalize_weight_and_reduce_impl` method returns `TopKWeightAndReduceContiguous` / `TopKWeightAndReduceNaiveBatched` / `TopKWeightAndReduceDelegate` if the `FusedMoEPermuteExpertsUnpermute` implementation needs the `FusedMoEPrepareAndFinalize::finalize()` to do the weight application and reduction.
+* `FusedMoEExpertsModular::finalize_weight_and_reduce_impl` method returns `TopKWeightAndReduceNoOp` if the `FusedMoEExpertsModular` implementation does the weight application and reduction itself.
+* `FusedMoEExpertsModular::finalize_weight_and_reduce_impl` method returns `TopKWeightAndReduceContiguous` / `TopKWeightAndReduceNaiveBatched` / `TopKWeightAndReduceDelegate` if the `FusedMoEExpertsModular` implementation needs the `FusedMoEPrepareAndFinalizeModular::finalize()` to do the weight application and reduction.

-### FusedMoEPrepareAndFinalize
+### FusedMoEPrepareAndFinalizeModular

-The `FusedMoEPrepareAndFinalize` abstract class exposes `prepare`, `prepare_no_receive` and `finalize` functions.
-The `prepare` function is responsible for input activation Quantization and All2All Dispatch. If implemented, The `prepare_no_receive` is like `prepare` except it does not wait to receive results from other workers. Instead it returns a "receiver" callback that must be invoked to wait for the final results of worker. It is not required that this method is supported by all `FusedMoEPrepareAndFinalize` classes, but if it is available, it can be used to interleave work with the initial all to all communication, e.g. interleaving shared experts with fused experts. The `finalize` function is responsible for invoking the All2All Combine. Additionally the `finalize` function may or may not do the TopK weight application and reduction (Please refer to the TopKWeightAndReduce section)
+The `FusedMoEPrepareAndFinalizeModular` abstract class exposes `prepare`, `prepare_no_receive` and `finalize` functions.
+The `prepare` function is responsible for input activation Quantization and All2All Dispatch. If implemented, the `prepare_no_receive` method is like `prepare` except it does not wait to receive results from other workers. Instead, it returns a "receiver" callback that must be invoked to wait for the final results of the worker. It is not required that this method is supported by all `FusedMoEPrepareAndFinalizeModular` classes, but if it is available, it can be used to interleave work with the initial all-to-all communication, e.g. interleaving shared experts with fused experts. The `finalize` function is responsible for invoking the All2All Combine. 
Additionally the `finalize` function may or may not do the TopK weight application and reduction (Please refer to the TopKWeightAndReduce section) -![FusedMoEPrepareAndFinalize Blocks](../assets/design/fused_moe_modular_kernel/prepare_and_finalize_blocks.png) +![FusedMoEPrepareAndFinalizeModular Blocks](../assets/design/fused_moe_modular_kernel/prepare_and_finalize_blocks.png) -### FusedMoEPermuteExpertsUnpermute +### FusedMoEExpertsModular -The `FusedMoEPermuteExpertsUnpermute` class is where the crux of the MoE operations happen. The `FusedMoEPermuteExpertsUnpermute` abstract class exposes a few important functions, +The `FusedMoEExpertsModular` class is where the crux of the MoE operations happen. The `FusedMoEExpertsModular` abstract class exposes a few important functions, * apply() * workspace_shapes() @@ -81,25 +81,25 @@ The `apply` method is where the implementations perform #### workspace_shapes() -The core FusedMoE implementation performs a series of operations. It would be inefficient to create output memory for each of these operations separately. To that effect, implementations are required to declare 2 workspace shapes, the workspace datatype and the FusedMoE output shape as outputs of the workspace_shapes() method. This information is used to allocate the workspace tensors and the output tensor in `FusedMoEModularKernel::forward()` and passed on to the `FusedMoEPermuteExpertsUnpermute::apply()` method. The workspaces could then be used as intermediate buffers in the FusedMoE implementation. +The core FusedMoE implementation performs a series of operations. It would be inefficient to create output memory for each of these operations separately. To that effect, implementations are required to declare 2 workspace shapes, the workspace datatype and the FusedMoE output shape as outputs of the workspace_shapes() method. This information is used to allocate the workspace tensors and the output tensor in `FusedMoEModularKernel::forward()` and passed on to the `FusedMoEExpertsModular::apply()` method. The workspaces could then be used as intermediate buffers in the FusedMoE implementation. #### finalize_weight_and_reduce_impl() -It is sometimes efficient to perform TopK weight application and Reduction inside the `FusedMoEPermuteExpertsUnpermute::apply()`. Find an example [here](https://github.com/vllm-project/vllm/pull/20228). We have a `TopKWeightAndReduce` abstract class to facilitate such implementations. Please refer to the TopKWeightAndReduce section. -`FusedMoEPermuteExpertsUnpermute::finalize_weight_and_reduce_impl()` returns the `TopKWeightAndReduce` object that the implementation wants the `FusedMoEPrepareAndFinalize::finalize()` to use. +It is sometimes efficient to perform TopK weight application and Reduction inside the `FusedMoEExpertsModular::apply()`. Find an example [here](https://github.com/vllm-project/vllm/pull/20228). We have a `TopKWeightAndReduce` abstract class to facilitate such implementations. Please refer to the TopKWeightAndReduce section. +`FusedMoEExpertsModular::finalize_weight_and_reduce_impl()` returns the `TopKWeightAndReduce` object that the implementation wants the `FusedMoEPrepareAndFinalizeModular::finalize()` to use. 
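+
+As a rough illustrative sketch (the subclass name and body here are hypothetical, not an existing vLLM class), an experts implementation that already applies the topk weights and reduces inside `apply()` would return `TopKWeightAndReduceNoOp` so that `finalize()` does not repeat that work:
+
+```py
+class MyFusedExperts(FusedMoEExpertsModular):
+    def finalize_weight_and_reduce_impl(self) -> TopKWeightAndReduce:
+        # apply() already multiplied by the topk weights and reduced across
+        # the topk dimension, so finalize() is left as a no-op.
+        return TopKWeightAndReduceNoOp()
+```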
-![FusedMoEPermuteExpertsUnpermute Blocks](../assets/design/fused_moe_modular_kernel/fused_experts_blocks.png) +![FusedMoEExpertsModular Blocks](../assets/design/fused_moe_modular_kernel/fused_experts_blocks.png) ### FusedMoEModularKernel -`FusedMoEModularKernel` is composed of the `FusedMoEPrepareAndFinalize` and `FusedMoEPermuteExpertsUnpermute` objects. +`FusedMoEModularKernel` is composed of the `FusedMoEPrepareAndFinalizeModular` and `FusedMoEExpertsModular` objects. `FusedMoEModularKernel` pseudocode/sketch, ```py class FusedMoEModularKernel: def __init__(self, - prepare_finalize: FusedMoEPrepareAndFinalize, - fused_experts: FusedMoEPermuteExpertsUnpermute): + prepare_finalize: FusedMoEPrepareAndFinalizeModular, + fused_experts: FusedMoEExpertsModular): self.prepare_finalize = prepare_finalize self.fused_experts = fused_experts @@ -128,53 +128,53 @@ class FusedMoEModularKernel: ## How-To -### How To Add a FusedMoEPrepareAndFinalize Type +### How To Add a FusedMoEPrepareAndFinalizeModular Type -Typically a FusedMoEPrepareAndFinalize type is backed by an All2All Dispatch & Combine implementation / kernel. For example, +Typically a FusedMoEPrepareAndFinalizeModular type is backed by an All2All Dispatch & Combine implementation / kernel. For example, * DeepEPHTPrepareAndFinalize type is backed by DeepEP High-Throughput All2All kernels, and * DeepEPLLPrepareAndFinalize type is backed by DeepEP Low-Latency All2All kernels. #### Step 1: Add an All2All manager -The purpose of the All2All Manager is to set up the All2All kernel implementations. The `FusedMoEPrepareAndFinalize` implementations typically fetch a kernel-implementation "handle" from the All2All Manager to invoke the Dispatch and Combine functions. Please look at the All2All Manager implementations [here](../../vllm/distributed/device_communicators/all2all.py). +The purpose of the All2All Manager is to set up the All2All kernel implementations. The `FusedMoEPrepareAndFinalizeModular` implementations typically fetch a kernel-implementation "handle" from the All2All Manager to invoke the Dispatch and Combine functions. Please look at the All2All Manager implementations [here](../../vllm/distributed/device_communicators/all2all.py). -#### Step 2: Add a FusedMoEPrepareAndFinalize Type +#### Step 2: Add a FusedMoEPrepareAndFinalizeModular Type -This section describes the significance of the various functions exposed by the `FusedMoEPrepareAndFinalize` abstract class. +This section describes the significance of the various functions exposed by the `FusedMoEPrepareAndFinalizeModular` abstract class. -`FusedMoEPrepareAndFinalize::prepare()`: The prepare method implements the Quantization and All2All Dispatch. Typically the Dispatch function from the relevant All2All Manager is invoked. +`FusedMoEPrepareAndFinalizeModular::prepare()`: The prepare method implements the Quantization and All2All Dispatch. Typically the Dispatch function from the relevant All2All Manager is invoked. -`FusedMoEPrepareAndFinalize::has_prepare_no_receive()`: Indicates whether or not this subclass implements `prepare_no_receive`. Defaults to False. +`FusedMoEPrepareAndFinalizeModular::has_prepare_no_receive()`: Indicates whether or not this subclass implements `prepare_no_receive`. Defaults to False. -`FusedMoEPrepareAndFinalize::prepare_no_receive()`: The prepare_no_receive method implements the Quantization and All2All Dispatch. 
It does not wait for the result of the dispatch operation but instead returns a thunk that can be invoked to wait for the final results. Typically the Dispatch function from the relevant All2All Manager is invoked. +`FusedMoEPrepareAndFinalizeModular::prepare_no_receive()`: The prepare_no_receive method implements the Quantization and All2All Dispatch. It does not wait for the result of the dispatch operation but instead returns a thunk that can be invoked to wait for the final results. Typically the Dispatch function from the relevant All2All Manager is invoked. -`FusedMoEPrepareAndFinalize::finalize()`: Maybe perform TopK Weight Application and Reduction and All2All Combine. Typically the Combine function from the relevant All2AllManager is invoked. +`FusedMoEPrepareAndFinalizeModular::finalize()`: Maybe perform TopK Weight Application and Reduction and All2All Combine. Typically the Combine function from the relevant All2AllManager is invoked. -`FusedMoEPrepareAndFinalize::activation_format()`: Return `FusedMoEActivationFormat.BatchedExperts` if the output of the prepare method (i.e. the All2All dispatch) is Batched. Return `FusedMoEActivationFormat.Standard` otherwise. +`FusedMoEPrepareAndFinalizeModular::activation_format()`: Return `FusedMoEActivationFormat.BatchedExperts` if the output of the prepare method (i.e. the All2All dispatch) is Batched. Return `FusedMoEActivationFormat.Standard` otherwise. -`FusedMoEPrepareAndFinalize::topk_indices_dtype()`: Data type of the TopK ids. Some All2All kernels have strict requirements pertaining to the data type of the TopK ids. This requirement is passed on to the `FusedMoe::select_experts` function so it could be respected. If there are no strict requirements return None. +`FusedMoEPrepareAndFinalizeModular::topk_indices_dtype()`: Data type of the TopK ids. Some All2All kernels have strict requirements pertaining to the data type of the TopK ids. This requirement is passed on to the `FusedMoe::select_experts` function so it could be respected. If there are no strict requirements return None. -`FusedMoEPrepareAndFinalize::max_num_tokens_per_rank()`: This is the maximum number of tokens that would be submitted to the All2All Dispatch at once. +`FusedMoEPrepareAndFinalizeModular::max_num_tokens_per_rank()`: This is the maximum number of tokens that would be submitted to the All2All Dispatch at once. -`FusedMoEPrepareAndFinalize::num_dispatchers()`: Total number of dispatching units. This value determines the size of the Dispatch output. The Dispatch output is of shape (num_local_experts, max_num_tokens, K). Here max_num_tokens = num_dispatchers() * max_num_tokens_per_rank(). +`FusedMoEPrepareAndFinalizeModular::num_dispatchers()`: Total number of dispatching units. This value determines the size of the Dispatch output. The Dispatch output is of shape (num_local_experts, max_num_tokens, K). Here max_num_tokens = num_dispatchers() * max_num_tokens_per_rank(). -We suggest picking an already existing `FusedMoEPrepareAndFinalize` implementation that matches your All2All implementation closely and using it as a reference. +We suggest picking an already existing `FusedMoEPrepareAndFinalizeModular` implementation that matches your All2All implementation closely and using it as a reference. -### How To Add a FusedMoEPermuteExpertsUnpermute Type +### How To Add a FusedMoEExpertsModular Type -FusedMoEPermuteExpertsUnpermute performs the core of the FusedMoE operations. 
The various functions exposed by the abstract class and their significance is as follows, +FusedMoEExpertsModular performs the core of the FusedMoE operations. The various functions exposed by the abstract class and their significance is as follows, -`FusedMoEPermuteExpertsUnpermute::activation_formats()`: Return the supported Input and Output activation formats. i.e. Contiguous / Batched format. +`FusedMoEExpertsModular::activation_formats()`: Return the supported Input and Output activation formats. i.e. Contiguous / Batched format. -`FusedMoEPermuteExpertsUnpermute::supports_chunking()`: Return True if the implementation supports chunking. Typically +`FusedMoEExpertsModular::supports_chunking()`: Return True if the implementation supports chunking. Typically implementations that input `FusedMoEActivationFormat.Standard` support chunking and `FusedMoEActivationFormat.BatchedExperts` do not. -`FusedMoEPermuteExpertsUnpermute::supports_expert_map()`: Return True if the implementation supports expert map. +`FusedMoEExpertsModular::supports_expert_map()`: Return True if the implementation supports expert map. -`FusedMoEPermuteExpertsUnpermute::workspace_shapes()` / -`FusedMoEPermuteExpertsUnpermute::finalize_weight_and_reduce_impl` / -`FusedMoEPermuteExpertsUnpermute::apply`: Refer to `FusedMoEPermuteExpertsUnpermute` section above. +`FusedMoEExpertsModular::workspace_shapes()` / +`FusedMoEExpertsModular::finalize_weight_and_reduce_impl` / +`FusedMoEExpertsModular::apply`: Refer to `FusedMoEExpertsModular` section above. ### FusedMoEModularKernel Initialization @@ -186,14 +186,14 @@ implementations that input `FusedMoEActivationFormat.Standard` support chunking #### maybe_make_prepare_finalize -The `maybe_make_prepare_finalize` method is responsible for constructing an instance of `FusedMoEPrepareAndFinalize` when appropriate based on the current all2all backend, e.g. when EP + DP is enabled. The base class method currently constructs all the `FusedMoEPrepareAndFinalize` objects for the EP+DP case. Derived classes can override this method to construct prepare/finalize objects for different scenarios, e.g. `ModelOptNvFp4FusedMoE` can construct a `FlashInferCutlassMoEPrepareAndFinalize` for the EP+TP case. +The `maybe_make_prepare_finalize` method is responsible for constructing an instance of `FusedMoEPrepareAndFinalizeModular` when appropriate based on the current all2all backend, e.g. when EP + DP is enabled. The base class method currently constructs all the `FusedMoEPrepareAndFinalizeModular` objects for the EP+DP case. Derived classes can override this method to construct prepare/finalize objects for different scenarios, e.g. `ModelOptNvFp4FusedMoE` can construct a `FlashInferCutlassMoEPrepareAndFinalize` for the EP+TP case. Please refer to the implementations in, * `ModelOptNvFp4FusedMoE` #### select_gemm_impl -The `select_gemm_impl` method is undefined in the base class. It is the responsibility of the derived class to implement a method that constructs a valid/appropriate `FusedMoEPermuteExpertsUnpermute` object. +The `select_gemm_impl` method is undefined in the base class. It is the responsibility of the derived class to implement a method that constructs a valid/appropriate `FusedMoEExpertsModular` object. Please refer to the implementations in, * `UnquantizedFusedMoEMethod` @@ -205,7 +205,7 @@ derived classes. #### init_prepare_finalize -Based on the input and env settings, the `init_prepare_finalize` method creates the appropriate `FusedMoEPrepareAndFinalize` object. 
The method then queries `select_gemm_impl` for the appropriate `FusedMoEPermuteExpertsUnpermute` object and builds the `FusedMoEModularKernel` object +Based on the input and env settings, the `init_prepare_finalize` method creates the appropriate `FusedMoEPrepareAndFinalizeModular` object. The method then queries `select_gemm_impl` for the appropriate `FusedMoEExpertsModular` object and builds the `FusedMoEModularKernel` object Please take a look at [init_prepare_finalize](https://github.com/vllm-project/vllm/blob/1cbf951ba272c230823b947631065b826409fa62/vllm/model_executor/layers/fused_moe/layer.py#L188). **Important**: The `FusedMoEMethodBase` derived classes use the `FusedMoEMethodBase::fused_experts` object in their `apply` methods. When settings permit the construction of a valid `FusedMoEModularKernel` object, we override `FusedMoEMethodBase::fused_experts` with it. This essentially makes the derived classes agnostic to what FusedMoE implementation is used. @@ -214,9 +214,9 @@ Please take a look at [init_prepare_finalize](https://github.com/vllm-project/vl We have `FusedMoEModularKernel` unit tests at [test_modular_kernel_combinations.py](../../tests/kernels/moe/test_modular_kernel_combinations.py). -The unit test iterates through all combinations of `FusedMoEPrepareAndFinalize` and `FusedMoEPremuteExpertsUnpermute` types and if they are +The unit test iterates through all combinations of `FusedMoEPrepareAndFinalizeModular` and `FusedMoEPremuteExpertsUnpermute` types and if they are compatible, runs some correctness tests. -If you are adding some `FusedMoEPrepareAndFinalize` / `FusedMoEPermuteExpertsUnpermute` implementations, +If you are adding some `FusedMoEPrepareAndFinalizeModular` / `FusedMoEExpertsModular` implementations, 1. Add the implementation type to `MK_ALL_PREPARE_FINALIZE_TYPES` and `MK_FUSED_EXPERT_TYPES` in [mk_objects.py](../../tests/kernels/moe/modular_kernel_tools/mk_objects.py) respectively. 2. Update `Config::is_batched_prepare_finalize()`, `Config::is_batched_fused_experts()`, `Config::is_standard_fused_experts()`, @@ -225,24 +225,24 @@ If you are adding some `FusedMoEPrepareAndFinalize` / `FusedMoEPermuteExpertsUnp Doing this will add the new implementation to the test suite. -### How To Check `FusedMoEPrepareAndFinalize` & `FusedMoEPermuteExpertsUnpermute` Compatibility +### How To Check `FusedMoEPrepareAndFinalizeModular` & `FusedMoEExpertsModular` Compatibility The unit test file [test_modular_kernel_combinations.py](../../tests/kernels/moe/test_modular_kernel_combinations.py) can also be executed as a standalone script. Example: `python3 -m tests.kernels.moe.test_modular_kernel_combinations --pf-type DeepEPLLPrepareAndFinalize --experts-type BatchedTritonExperts` -As a side effect, this script can be used to test `FusedMoEPrepareAndFinalize` & `FusedMoEPermuteExpertsUnpermute` compatibility. When invoked +As a side effect, this script can be used to test `FusedMoEPrepareAndFinalizeModular` & `FusedMoEExpertsModular` compatibility. When invoked with incompatible types, the script will error. ### How To Profile Please take a look at [profile_modular_kernel.py](../../tests/kernels/moe/modular_kernel_tools/profile_modular_kernel.py) The script can be used to generate Torch traces for a single `FusedMoEModularKernel::forward()` call for any compatible -`FusedMoEPrepareAndFinalize` and `FusedMoEPermuteExpertsUnpermute` types. +`FusedMoEPrepareAndFinalizeModular` and `FusedMoEExpertsModular` types. 
Example: `python3 -m tests.kernels.moe.modular_kernel_tools.profile_modular_kernel --pf-type DeepEPLLPrepareAndFinalize --experts-type BatchedTritonExperts` -## FusedMoEPrepareAndFinalize Implementations +## FusedMoEPrepareAndFinalizeModular Implementations See [Fused MoE Kernel features](./moe_kernel_features.md#fused-moe-modular-all2all-backends) for a list of all the available modular prepare and finalize subclasses. -## FusedMoEPermuteExpertsUnpermute +## FusedMoEExpertsModular See [Fused MoE Kernel features](./moe_kernel_features.md#fused-moe-experts-kernels) for a list of all the available modular experts. diff --git a/docs/design/moe_kernel_features.md b/docs/design/moe_kernel_features.md index ac5acb66bdbf..0c92e597582e 100644 --- a/docs/design/moe_kernel_features.md +++ b/docs/design/moe_kernel_features.md @@ -4,17 +4,17 @@ The purpose of this document is to provide an overview of the various MoE kernel ## Fused MoE Modular All2All backends -There are a number of all2all communication backends that are used to implement expert parallelism (EP) for the `FusedMoE` layer. The different `FusedMoEPrepareAndFinalize` subclasses provide an interface for each all2all backend. +There are a number of all2all communication backends that are used to implement expert parallelism (EP) for the `FusedMoE` layer. The different `FusedMoEPrepareAndFinalizeModular` subclasses provide an interface for each all2all backend. The following table describes the relevant features of each backend, i.e. activation format, supported quantization schemes and async support. -The output activation format (standard or batched) corresponds to the output of the prepare step of the `FusedMoEPrepareAndFinalize` subclass, and the finalize step requires the same format. All the backend `prepare` methods expect activations in the standard format and all the `finalize` methods return activations in standard format. More details on the formats can be found in the [Fused MoE Modular Kernel](./fused_moe_modular_kernel.md) document. +The output activation format (standard or batched) corresponds to the output of the prepare step of the `FusedMoEPrepareAndFinalizeModular` subclass, and the finalize step requires the same format. All the backend `prepare` methods expect activations in the standard format and all the `finalize` methods return activations in standard format. More details on the formats can be found in the [Fused MoE Modular Kernel](./fused_moe_modular_kernel.md) document. -The quantization types and formats enumerate which quantization schemes are supported by each `FusedMoEPrepareAndFinalize` class. The quantization can happen before or after the dispatch based on the format the all2all backend supports, e.g. deepep_high_throughput supports only block-quantized fp8 format. Any other format will result in dispatching in higher precision and quantizing afterwards. The output of the prepare step for each backend is the quantized type. The finalize step generally requires the same input type as the original activations, e.g. if the original input is bfloat16 and the quantization scheme is fp8 with per-tensor scales, `prepare` will return fp8/per-tensor scale activations and `finalize` will take bfloat16 activations. See the diagrams in [Fused MoE Modular Kernel](./fused_moe_modular_kernel.md) for more details on the types and formats of activations at each step of the MoE process. If no quantization type is specified, the kernel operates on float16 and/or bfloat16. 
+The quantization types and formats enumerate which quantization schemes are supported by each `FusedMoEPrepareAndFinalizeModular` class. The quantization can happen before or after the dispatch based on the format the all2all backend supports, e.g. deepep_high_throughput supports only block-quantized fp8 format. Any other format will result in dispatching in higher precision and quantizing afterwards. The output of the prepare step for each backend is the quantized type. The finalize step generally requires the same input type as the original activations, e.g. if the original input is bfloat16 and the quantization scheme is fp8 with per-tensor scales, `prepare` will return fp8/per-tensor scale activations and `finalize` will take bfloat16 activations. See the diagrams in [Fused MoE Modular Kernel](./fused_moe_modular_kernel.md) for more details on the types and formats of activations at each step of the MoE process. If no quantization type is specified, the kernel operates on float16 and/or bfloat16. Async backends support the use of DBO (Dual Batch Overlap) and shared expert overlap (where shared experts are computed during the combine step). -Certain models require the topk weights to be applied to the input activations rather than the output activations when topk==1, e.g. Llama. For modular kernels, this feature is supported by the `FusedMoEPrepareAndFinalize` subclass. For non-modular kernels, it is up to the experts function to deal with this flag. +Certain models require the topk weights to be applied to the input activations rather than the output activations when topk==1, e.g. Llama. For modular kernels, this feature is supported by the `FusedMoEPrepareAndFinalizeModular` subclass. For non-modular kernels, it is up to the experts function to deal with this flag. Unless otherwise specified, backends are controlled via the `--all2all-backend` command-line argument (or the `all2all_backend` parameter in `ParallelConfig`). All backends except `flashinfer` only work with EP+DP or EP+TP. `Flashinfer` can work with EP or DP without EP. @@ -36,8 +36,6 @@ th { | deepep_high_throughput | standard | fp8 | G(128),A,T2 | Y | Y | [`DeepEPHTPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize.DeepEPHTPrepareAndFinalize] | | deepep_low_latency | batched | fp8 | G(128),A,T3 | Y | Y | [`DeepEPLLPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize.DeepEPLLPrepareAndFinalize] | | flashinfer_all2allv | standard | nvfp4,fp8 | G,A,T | N | N | [`FlashInferA2APrepareAndFinalize`][vllm.model_executor.layers.fused_moe.flashinfer_a2a_prepare_finalize.FlashInferA2APrepareAndFinalize] | -| MoEPrepareAndFinalizeNoEP5 | standard | fp8,int8 | G,A,T | N | Y | [`MoEPrepareAndFinalizeNoEP`][vllm.model_executor.layers.fused_moe.prepare_finalize.MoEPrepareAndFinalizeNoEP] | -| BatchedPrepareAndFinalize5 | batched | fp8,int8 | G,A,T | N | Y | [`BatchedPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.fused_batched_moe.BatchedPrepareAndFinalize] | !!! info "Table key" 1. All types: mxfp4, nvfp4, int4, int8, fp8 @@ -75,9 +73,9 @@ Each experts kernel supports one or more activation functions, e.g. silu or gelu As with the backends, some experts support applying topk weights on the input activations. The entries in the column in this table only apply to the non-modular experts. -Most experts flavors include an equivalent modular interface which will be a subclass of `FusedMoEPermuteExpertsUnpermute`. 
+Most experts flavors include an equivalent modular interface which will be a subclass of `FusedMoEExpertsModular`. -To be used with a particular `FusedMoEPrepareAndFinalize` subclass, MoE kernels must have compatible activation formats, quantization types and quantization formats. +To be used with a particular `FusedMoEPrepareAndFinalizeModular` subclass, MoE kernels must have compatible activation formats, quantization types and quantization formats. | Kernel | Input act. format | Quant. types | Quant. format | Activation function | Apply Weight On Input | Modular | Source | |--------|-------------------|--------------|---------------|---------------------|-----------------------|---------|--------| @@ -106,7 +104,7 @@ To be used with a particular `FusedMoEPrepareAndFinalize` subclass, MoE kernels The following table shows "families" of modular kernels that are intended to work together. There are some combinations which may work but have not yet been tested, e.g. flashinfer with other fp8 experts. Note that the "naive" backend will work with any non-modular experts. -| backend | `FusedMoEPrepareAndFinalize` subclasses | `FusedMoEPermuteExpertsUnpermute` subclasses | +| backend | `FusedMoEPrepareAndFinalizeModular` subclasses | `FusedMoEExpertsModular` subclasses | |---------|-----------------------------------------|----------------------------------------------| | deepep_high_throughput | `DeepEPHTPrepareAndFinalize` | `DeepGemmExperts`,
`TritonExperts`,<br>`TritonOrDeepGemmExperts`,<br>`CutlassExpertsFp8`,<br>`MarlinExperts` |
| deepep_low_latency | `DeepEPLLPrepareAndFinalize` | `BatchedDeepGemmExperts`,<br>`BatchedTritonExperts`,<br>`CutlassBatchedExpertsFp8`,<br>
`BatchedMarlinExperts` | diff --git a/tests/kernels/moe/modular_kernel_tools/cli_args.py b/tests/kernels/moe/modular_kernel_tools/cli_args.py index 34c6ca1f999c..375dfa748956 100644 --- a/tests/kernels/moe/modular_kernel_tools/cli_args.py +++ b/tests/kernels/moe/modular_kernel_tools/cli_args.py @@ -17,13 +17,13 @@ def make_config_arg_parser(description: str): - def to_pf_class_type(s: str) -> mk.FusedMoEPrepareAndFinalize: + def to_pf_class_type(s: str) -> mk.FusedMoEPrepareAndFinalizeModular: for pf in MK_ALL_PREPARE_FINALIZE_TYPES: if pf.__name__ == s: return pf raise ValueError(f"Cannot find a PrepareFinalize type that matches {s}") - def to_experts_class_type(s: str) -> mk.FusedMoEPermuteExpertsUnpermute: + def to_experts_class_type(s: str) -> mk.FusedMoEExpertsModular: for fe in MK_FUSED_EXPERT_TYPES: if fe.__name__ == s: return fe diff --git a/tests/kernels/moe/modular_kernel_tools/common.py b/tests/kernels/moe/modular_kernel_tools/common.py index 9f67129616f9..4b2b1653babe 100644 --- a/tests/kernels/moe/modular_kernel_tools/common.py +++ b/tests/kernels/moe/modular_kernel_tools/common.py @@ -66,7 +66,7 @@ class Config: quant_config: TestMoEQuantConfig | None prepare_finalize_type: mk.FusedMoEPrepareAndFinalize - fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute + fused_experts_type: mk.FusedMoEExperts fused_moe_chunk_size: int | None world_size: int @@ -566,7 +566,7 @@ def make_modular_kernel( config: Config, vllm_config: VllmConfig, quant_config: FusedMoEQuantConfig, -) -> mk.FusedMoEModularKernel: +) -> mk.FusedMoEKernel: def next_power_of_2(x): import math @@ -613,7 +613,7 @@ def next_power_of_2(x): config.N, ) - modular_kernel = mk.FusedMoEModularKernel( + modular_kernel = mk.FusedMoEKernel( prepare_finalize=prepare_finalize, fused_experts=fused_experts, inplace=False, @@ -667,6 +667,7 @@ def run_modular_kernel( "w2": rank_weights.w2, "topk_weights": rank_tensors.topk_weights, "topk_ids": topk_ids, + "activation": MoEActivation.SILU, "expert_map": rank_tensors.expert_map, "global_num_experts": config.E, "apply_router_weight_on_input": config.topk == 1 @@ -684,6 +685,6 @@ def run_modular_kernel( num_tokens=num_tokens, num_tokens_across_dp=num_tokens_across_dp, ): - out = mk.forward(**mk_kwargs) + out = mk.apply(**mk_kwargs) return out diff --git a/tests/kernels/moe/modular_kernel_tools/mk_objects.py b/tests/kernels/moe/modular_kernel_tools/mk_objects.py index 0ea414c3af41..ee4190859e4c 100644 --- a/tests/kernels/moe/modular_kernel_tools/mk_objects.py +++ b/tests/kernels/moe/modular_kernel_tools/mk_objects.py @@ -20,7 +20,7 @@ NaiveBatchedExperts, ) from vllm.model_executor.layers.fused_moe.prepare_finalize import ( - MoEPrepareAndFinalizeNoEP, + MoEPrepareAndFinalizeNoDPEPModular, ) from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( TritonOrDeepGemmExperts, @@ -71,12 +71,14 @@ class ExpertInfo: needs_aiter: bool = False -PREPARE_FINALIZE_INFO: dict[mk.FusedMoEPrepareAndFinalize, PrepareFinalizeInfo] = {} -EXPERT_INFO: dict[mk.FusedMoEPermuteExpertsUnpermute, ExpertInfo] = {} -MK_ALL_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = [] -MK_MULTI_GPU_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = [] -MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = [] -MK_FUSED_EXPERT_TYPES: list[mk.FusedMoEPermuteExpertsUnpermute] = [] +PREPARE_FINALIZE_INFO: dict[ + mk.FusedMoEPrepareAndFinalizeModular, PrepareFinalizeInfo +] = {} +EXPERT_INFO: dict[mk.FusedMoEExpertsModular, ExpertInfo] = {} 
+MK_ALL_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalizeModular] = [] +MK_MULTI_GPU_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalizeModular] = [] +MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalizeModular] = [] +MK_FUSED_EXPERT_TYPES: list[mk.FusedMoEExpertsModular] = [] standard_format = mk.FusedMoEActivationFormat.Standard batched_format = mk.FusedMoEActivationFormat.BatchedExperts @@ -162,7 +164,7 @@ def expert_info(kind) -> ExpertInfo: register_prepare_and_finalize( - MoEPrepareAndFinalizeNoEP, + MoEPrepareAndFinalizeNoDPEPModular, standard_format, common_float_types, blocked_quantization_support=True, @@ -239,14 +241,14 @@ def expert_info(kind) -> ExpertInfo: if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability(100): from vllm.model_executor.layers.fused_moe.flashinfer_a2a_prepare_finalize import ( # noqa: E501 - FlashInferCutlassMoEPrepareAndFinalize, + FlashInferA2APrepareAndFinalize, ) from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( FlashInferExperts, ) register_prepare_and_finalize( - FlashInferCutlassMoEPrepareAndFinalize, + FlashInferA2APrepareAndFinalize, standard_format, nvfp4_types + fp8_types, blocked_quantization_support=True, @@ -430,12 +432,12 @@ def make_cutlass_strides( def make_fused_experts( - fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute, + fused_experts_type: mk.FusedMoEExpertsModular, moe: FusedMoEConfig, quant_config: FusedMoEQuantConfig, num_dispatchers: int, N: int, -) -> mk.FusedMoEPermuteExpertsUnpermute: +) -> mk.FusedMoEExpertsModular: if ( fused_experts_type.activation_format() == mk.FusedMoEActivationFormat.BatchedExperts diff --git a/tests/kernels/moe/modular_kernel_tools/profile_modular_kernel.py b/tests/kernels/moe/modular_kernel_tools/profile_modular_kernel.py index 702584f9da53..2554c4fce933 100644 --- a/tests/kernels/moe/modular_kernel_tools/profile_modular_kernel.py +++ b/tests/kernels/moe/modular_kernel_tools/profile_modular_kernel.py @@ -72,7 +72,7 @@ def profile_modular_kernel( "apply_router_weight_on_input": config.topk == 1, } - do_profile(mk.forward, mk_kwargs, pgi, config) + do_profile(mk.apply, mk_kwargs, pgi, config) def rank_worker( diff --git a/tests/kernels/moe/test_batched_deepgemm.py b/tests/kernels/moe/test_batched_deepgemm.py index 2c6c45a5f234..20763b91dfd9 100644 --- a/tests/kernels/moe/test_batched_deepgemm.py +++ b/tests/kernels/moe/test_batched_deepgemm.py @@ -4,6 +4,7 @@ import pytest import torch +from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( BatchedDeepGemmExperts, ) @@ -12,7 +13,7 @@ BatchedPrepareAndFinalize, BatchedTritonExperts, ) -from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel +from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEKernel from vllm.utils.deep_gemm import calc_diff, is_deep_gemm_supported from .test_deepgemm import make_block_quant_fp8_weights @@ -74,19 +75,22 @@ def test_batched_deepgemm_vs_triton( quant_config=quant_config, moe_config=make_dummy_moe_config(), ) - mk_triton = FusedMoEModularKernel( + mk_triton = FusedMoEKernel( prep_finalize, triton_experts, inplace=False, ) - out_triton = mk_triton( + out_triton = mk_triton.apply( hidden_states=a, w1=w1, w2=w2, topk_weights=topk_weights, topk_ids=topk_ids, + activation=MoEActivation.SILU, global_num_experts=E, + expert_map=None, + apply_router_weight_on_input=False, ) # deepgemm @@ -96,19 +100,22 
@@ def test_batched_deepgemm_vs_triton( quant_config=quant_config, moe_config=make_dummy_moe_config(), ) - mk_deepgemm = FusedMoEModularKernel( + mk_deepgemm = FusedMoEKernel( prep_finalize, deepgemm_experts, inplace=False, ) - out_deepgemm = mk_deepgemm( + out_deepgemm = mk_deepgemm.apply( hidden_states=a, w1=w1, w2=w2, topk_weights=topk_weights, topk_ids=topk_ids, + activation=MoEActivation.SILU, global_num_experts=E, + expert_map=None, + apply_router_weight_on_input=False, ) diff = calc_diff(out_deepgemm, out_triton) diff --git a/tests/kernels/moe/test_block_fp8.py b/tests/kernels/moe/test_block_fp8.py index 66508568ed2c..a74e739c55e4 100644 --- a/tests/kernels/moe/test_block_fp8.py +++ b/tests/kernels/moe/test_block_fp8.py @@ -21,15 +21,16 @@ fused_experts, fused_topk, ) +from vllm.model_executor.layers.fused_moe.activation import MoEActivation +from vllm.model_executor.layers.fused_moe.all2all_utils import ( + maybe_make_prepare_finalize, +) from vllm.model_executor.layers.fused_moe.config import ( fp8_w8a8_moe_quant_config, ) from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( _valid_deep_gemm_shape, ) -from vllm.model_executor.layers.fused_moe.prepare_finalize import ( - MoEPrepareAndFinalizeNoEP, -) from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( TritonOrDeepGemmExperts, ) @@ -193,7 +194,17 @@ def test_w8a8_block_fp8_fused_moe( a, w1, w2, topk_weights, topk_ids, quant_config=quant_config ) - m_out = m_fused_moe(a, w1, w2, topk_weights, topk_ids) + m_out = m_fused_moe.apply( + a, + w1, + w2, + topk_weights, + topk_ids, + activation=MoEActivation.SILU, + apply_router_weight_on_input=False, + expert_map=None, + global_num_experts=w1.shape[0], + ) # 0.039 only needed for M >= 8192 tol = 0.035 if M < 8192 else 0.039 @@ -252,23 +263,33 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch) w2_scale=w2_s, block_shape=block_size, ) + moe_config = make_dummy_moe_config() - deep_gemm_experts = mk.FusedMoEModularKernel( - prepare_finalize=MoEPrepareAndFinalizeNoEP(), + deep_gemm_experts = mk.FusedMoEKernel( + prepare_finalize=maybe_make_prepare_finalize( + moe=moe_config, + quant_config=quant_config, + allow_new_interface=True, + use_monolithic=False, + ), fused_experts=TritonOrDeepGemmExperts( - moe_config=make_dummy_moe_config(), + moe_config=moe_config, quant_config=quant_config, ), inplace=False, ) def deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids): - return deep_gemm_experts( + return deep_gemm_experts.apply( hidden_states=a, w1=w1, w2=w2, topk_weights=topk_weights, topk_ids=topk_ids, + global_num_experts=E, + activation=MoEActivation.SILU, + apply_router_weight_on_input=False, + expert_map=False, ) # Set the context to avoid lots of warning spam. 
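The test hunks above and below all make the same two-part migration: the prepare/finalize object is obtained from `maybe_make_prepare_finalize(...)` rather than by instantiating `MoEPrepareAndFinalizeNoEP()` directly, and the kernel is driven through `FusedMoEKernel.apply(...)` with explicit `activation`, `global_num_experts`, `expert_map`, and `apply_router_weight_on_input` arguments instead of calling the kernel object (the old `forward`/`__call__` path). The sketch below condenses that pattern using the unquantized Triton experts; it borrows `make_dummy_moe_config` from the test utilities, assumes a CUDA device, and uses placeholder shapes, so treat it as illustrative rather than a drop-in test.

```python
# Minimal sketch of the new modular-kernel call pattern (placeholder shapes,
# unquantized path); it mirrors the updated tests rather than prescribing an API.
import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from tests.kernels.moe.utils import make_dummy_moe_config
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.all2all_utils import (
    maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import FUSED_MOE_UNQUANTIZED_CONFIG
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
from vllm.model_executor.layers.fused_moe.router.fused_topk_router import fused_topk

M, K, N, E, topk = 32, 128, 256, 8, 2
a = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") / 10
w1 = torch.randn(E, 2 * N, K, dtype=torch.bfloat16, device="cuda") / 10
w2 = torch.randn(E, K, N, dtype=torch.bfloat16, device="cuda") / 10
score = torch.randn(M, E, dtype=torch.bfloat16, device="cuda")
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)

moe_config = make_dummy_moe_config()
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG

# Previously: mk.FusedMoEModularKernel(MoEPrepareAndFinalizeNoEP(), experts)(...)
kernel = mk.FusedMoEKernel(
    prepare_finalize=maybe_make_prepare_finalize(
        moe=moe_config,
        quant_config=quant_config,
        allow_new_interface=True,
        use_monolithic=False,
    ),
    fused_experts=TritonExperts(moe_config, quant_config),
    inplace=False,
)

out = kernel.apply(
    hidden_states=a,
    w1=w1,
    w2=w2,
    topk_weights=topk_weights,
    topk_ids=topk_ids,
    activation=MoEActivation.SILU,      # now passed explicitly
    global_num_experts=E,
    expert_map=None,
    apply_router_weight_on_input=False,
)
```

The same keyword set appears in every `.apply(...)` call in this series, which is why so many hunks only add `activation=...`, `expert_map=None`, and `apply_router_weight_on_input=False` to otherwise unchanged calls.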
diff --git a/tests/kernels/moe/test_cutlass_moe.py b/tests/kernels/moe/test_cutlass_moe.py index ec23008dfa1f..1ec2c614ca80 100644 --- a/tests/kernels/moe/test_cutlass_moe.py +++ b/tests/kernels/moe/test_cutlass_moe.py @@ -13,6 +13,9 @@ from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk from vllm.model_executor.layers.fused_moe.activation import MoEActivation +from vllm.model_executor.layers.fused_moe.all2all_utils import ( + maybe_make_prepare_finalize, +) from vllm.model_executor.layers.fused_moe.config import ( FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEQuantConfig, @@ -22,9 +25,6 @@ CutlassExpertsFp8, run_cutlass_moe_fp8, ) -from vllm.model_executor.layers.fused_moe.prepare_finalize import ( - MoEPrepareAndFinalizeNoEP, -) from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input from vllm.platforms import current_platform from vllm.utils.torch_utils import set_random_seed @@ -197,20 +197,26 @@ def slice_experts(): for kwargs, new_quant_config in slice_experts(): w2 = kwargs["w2"] a = kwargs["hidden_states"] - kernel = mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP(), + moe_config = make_dummy_moe_config( + num_experts=w2.shape[0], + hidden_dim=w2.shape[1], + intermediate_size_per_partition=w2.shape[2], + in_dtype=a.dtype, + ) + kernel = mk.FusedMoEKernel( + maybe_make_prepare_finalize( + moe=moe_config, + quant_config=new_quant_config, + allow_new_interface=True, + use_monolithic=False, + ), CutlassExpertsFp8( - moe_config=make_dummy_moe_config( - num_experts=w2.shape[0], - hidden_dim=w2.shape[1], - intermediate_size_per_partition=w2.shape[2], - in_dtype=a.dtype, - ), + moe_config=moe_config, quant_config=new_quant_config, ), inplace=False, ) - out_tensor = out_tensor + kernel(**kwargs) + out_tensor = out_tensor + kernel.apply(**kwargs) return out_tensor @@ -252,25 +258,35 @@ def run_8_bit( "w2": moe_tensors.w2_q, # type: ignore[union-attr] "topk_weights": topk_weights, "topk_ids": topk_ids, + "global_num_experts": moe_tensors.w1_q.shape[0], # type: ignore[union-attr] + "activation": MoEActivation.SILU, + "expert_map": None, + "apply_router_weight_on_input": False, } num_experts = moe_tensors.w1.size(0) # type: ignore[attr-defined] with_ep = num_local_experts is not None or num_local_experts == num_experts if not with_ep: - kernel = mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP(), + moe_config = make_dummy_moe_config( + num_experts=moe_tensors.w2_q.shape[0], # type: ignore[union-attr] + hidden_dim=moe_tensors.w2_q.shape[1], # type: ignore[union-attr] + intermediate_size_per_partition=moe_tensors.w2_q.shape[2], # type: ignore[union-attr] + in_dtype=moe_tensors.a.dtype, + ) + kernel = mk.FusedMoEKernel( + maybe_make_prepare_finalize( + moe=moe_config, + quant_config=quant_config, + allow_new_interface=True, + use_monolithic=False, + ), CutlassExpertsFp8( - moe_config=make_dummy_moe_config( - num_experts=moe_tensors.w2_q.shape[0], # type: ignore[union-attr] - hidden_dim=moe_tensors.w2_q.shape[1], # type: ignore[union-attr] - intermediate_size_per_partition=moe_tensors.w2_q.shape[2], # type: ignore[union-attr] - in_dtype=moe_tensors.a.dtype, - ), + moe_config=moe_config, quant_config=quant_config, ), inplace=False, ) - return kernel(**kwargs) + return kernel.apply(**kwargs) assert num_local_experts is not None return run_with_expert_maps( diff --git a/tests/kernels/moe/test_deepep_deepgemm_moe.py b/tests/kernels/moe/test_deepep_deepgemm_moe.py index 
2b8240482829..a01fb1a452ea 100644 --- a/tests/kernels/moe/test_deepep_deepgemm_moe.py +++ b/tests/kernels/moe/test_deepep_deepgemm_moe.py @@ -22,7 +22,7 @@ fp8_w8a8_moe_quant_config, ) from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts -from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel +from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEKernel from vllm.utils.deep_gemm import ( get_mk_alignment_for_contiguous_layout, is_deep_gemm_e8m0_used, @@ -170,7 +170,7 @@ def make_ll_modular_kernel( q_dtype: torch.dtype | None, test_config: TestConfig, quant_config: FusedMoEQuantConfig, -) -> FusedMoEModularKernel: +) -> FusedMoEKernel: assert test_config.low_latency assert test_config.use_fp8_dispatch is not None @@ -195,7 +195,7 @@ def make_ll_modular_kernel( quant_config=quant_config, moe_config=make_dummy_moe_config(), ) - return FusedMoEModularKernel( + return FusedMoEKernel( prepare_finalize=a2a, fused_experts=fused_experts, inplace=False, @@ -210,7 +210,7 @@ def make_ht_modular_kernel( q_dtype: torch.dtype | None, test_config: TestConfig, quant_config: FusedMoEQuantConfig, -) -> FusedMoEModularKernel: +) -> FusedMoEKernel: assert not test_config.low_latency assert test_config.use_fp8_dispatch is None @@ -228,7 +228,7 @@ def make_ht_modular_kernel( moe_config=make_dummy_moe_config(), quant_config=quant_config, ) - return FusedMoEModularKernel( + return FusedMoEKernel( prepare_finalize=a2a, fused_experts=fused_experts, inplace=False, @@ -242,11 +242,11 @@ def make_modular_kernel( num_local_experts: int, test_tensors: TestTensors, quant_config: FusedMoEQuantConfig, -) -> FusedMoEModularKernel: +) -> FusedMoEKernel: q_dtype = torch.float8_e4m3fn test_config = test_tensors.config - mk: FusedMoEModularKernel + mk: FusedMoEKernel # Make modular kernel if test_config.low_latency: max_tokens_per_rank = max(64, next_power_of_2(test_tensors.rank_tokens.size(0))) @@ -307,7 +307,7 @@ def build_expert_map(): ) # Make modular kernel - mk: FusedMoEModularKernel = make_modular_kernel( + mk: FusedMoEKernel = make_modular_kernel( pg=pg, pgi=pgi, dp_size=dp_size, @@ -319,7 +319,7 @@ def build_expert_map(): with with_dp_metadata( M=test_tensors.rank_tokens.size(0), world_size=pgi.world_size ): - out = mk.forward( + out = mk.apply( hidden_states=test_tensors.rank_tokens, w1=w1, w2=w2, diff --git a/tests/kernels/moe/test_deepep_moe.py b/tests/kernels/moe/test_deepep_moe.py index 01f340730af3..362b71a40f2d 100644 --- a/tests/kernels/moe/test_deepep_moe.py +++ b/tests/kernels/moe/test_deepep_moe.py @@ -20,7 +20,7 @@ FusedMoEQuantConfig, ) from vllm.model_executor.layers.fused_moe.fused_batched_moe import BatchedTritonExperts -from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel +from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEKernel from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8, ) @@ -135,7 +135,7 @@ def make_modular_kernel( q_dtype: torch.dtype | None, use_fp8_dispatch: bool, quant_config: FusedMoEQuantConfig, -) -> FusedMoEModularKernel: +) -> FusedMoEKernel: ht_args: DeepEPHTArgs | None = None ll_args: DeepEPLLArgs | None = None @@ -180,7 +180,7 @@ def make_modular_kernel( quant_config=quant_config, ) - mk = FusedMoEModularKernel( + mk = FusedMoEKernel( prepare_finalize=a2a, fused_experts=fused_experts, inplace=False, @@ -242,7 +242,7 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): ) # Make modular kernel - mk: 
FusedMoEModularKernel = make_modular_kernel( + mk: FusedMoEKernel = make_modular_kernel( pg, pgi, low_latency_mode, @@ -255,7 +255,7 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): quant_config, ) - out = mk.forward( + out = mk.apply( hidden_states=rank_tokens_chunk, w1=w1, w2=w2, diff --git a/tests/kernels/moe/test_deepgemm.py b/tests/kernels/moe/test_deepgemm.py index 7f9bccb739ef..c2949391c798 100644 --- a/tests/kernels/moe/test_deepgemm.py +++ b/tests/kernels/moe/test_deepgemm.py @@ -14,13 +14,16 @@ # vLLM fused-expert reference (Triton fallback + DeepGEMM option) import vllm.model_executor.layers.fused_moe.modular_kernel as mk from tests.kernels.moe.utils import make_dummy_moe_config +from vllm.model_executor.layers.fused_moe.activation import ( + MoEActivation, +) +from vllm.model_executor.layers.fused_moe.all2all_utils import ( + maybe_make_prepare_finalize, +) from vllm.model_executor.layers.fused_moe.config import ( fp8_w8a8_moe_quant_config, ) from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts -from vllm.model_executor.layers.fused_moe.prepare_finalize import ( - MoEPrepareAndFinalizeNoEP, -) from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( TritonOrDeepGemmExperts, ) @@ -108,11 +111,17 @@ def run_single_case(m, n, k, topk, num_experts, block_size): a1_scale=a1_scale, block_shape=block_size, ) + moe_config = make_dummy_moe_config() - deep_gemm_experts = mk.FusedMoEModularKernel( - prepare_finalize=MoEPrepareAndFinalizeNoEP(), + deep_gemm_experts = mk.FusedMoEKernel( + prepare_finalize=maybe_make_prepare_finalize( + moe=moe_config, + quant_config=quant_config, + allow_new_interface=True, + use_monolithic=False, + ), fused_experts=TritonOrDeepGemmExperts( - moe_config=make_dummy_moe_config(), + moe_config=moe_config, quant_config=quant_config, ), inplace=False, @@ -130,12 +139,16 @@ def run_single_case(m, n, k, topk, num_experts, block_size): ) # DeepGemm - out_deepgemm = deep_gemm_experts( + out_deepgemm = deep_gemm_experts.apply( hidden_states=tokens_bf16, w1=w1, w2=w2, topk_weights=topk_weights, topk_ids=topk_ids, + global_num_experts=num_experts, + activation=MoEActivation.SILU, + apply_router_weight_on_input=False, + expert_map=None, ) diff = calc_diff(out_deepgemm, out_triton) assert diff < 0.001, f"Diff exceeded 1%: {diff}" diff --git a/tests/kernels/moe/test_flashinfer.py b/tests/kernels/moe/test_flashinfer.py index d524b5667047..6a51853c0022 100644 --- a/tests/kernels/moe/test_flashinfer.py +++ b/tests/kernels/moe/test_flashinfer.py @@ -8,6 +8,9 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.model_executor.layers.fused_moe.activation import MoEActivation +from vllm.model_executor.layers.fused_moe.all2all_utils import ( + maybe_make_prepare_finalize, +) from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEParallelConfig, @@ -15,16 +18,14 @@ RoutingMethodType, fp8_w8a8_moe_quant_config, ) +from vllm.model_executor.layers.fused_moe.experts.trtllm_fp8_moe import ( + TrtLlmFp8Experts, +) from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( FlashInferExperts, ) from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts -from vllm.model_executor.layers.fused_moe.prepare_finalize import ( - MoEPrepareAndFinalizeNoEP, -) from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( - apply_fi_trtllm_fp8_per_tensor_moe, - 
register_scales_for_trtllm_fp8_per_tensor_moe, rotate_weights_for_fi_trtllm_fp8_per_tensor_moe, swap_w13_to_w31, ) @@ -115,6 +116,7 @@ def make_moe_tensors_8bit( e: int, is_trtllm: bool, activation: MoEActivation = MoEActivation.SILU, + topk: int = 1, ) -> "TestData": is_gated = activation.is_gated @@ -152,13 +154,6 @@ def make_moe_tensors_8bit( rotate_weights_for_fi_trtllm_fp8_per_tensor_moe( layer.w13_weight, layer.w2_weight, is_gated ) - register_scales_for_trtllm_fp8_per_tensor_moe( - layer, - layer.w13_weight_scale, - layer.w13_input_scale, - layer.w2_weight_scale, - layer.w2_input_scale, - ) layer.custom_routing_function = Llama4MoE.custom_routing_function layer.routing_method_type = RoutingMethodType.Llama4 layer.renormalize = False @@ -166,6 +161,21 @@ def make_moe_tensors_8bit( layer.ep_rank = 0 layer.local_num_experts = e + layer.moe = FusedMoEConfig( + num_experts=e, + experts_per_token=topk, + hidden_dim=k, + intermediate_size_per_partition=n, + num_local_experts=e, + num_logical_experts=e, + moe_parallel_config=layer.moe_parallel_config, + in_dtype=hidden_states.dtype, + is_act_and_mul=is_gated, + routing_method=layer.routing_method_type, + activation=activation, + device=w13_quantized.device, + ) + return TestData( hidden_states=hidden_states, w13_quantized=w13_quantized, @@ -230,16 +240,29 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph( quant_config=quant_config, ) - flashinfer_output = apply_fi_trtllm_fp8_per_tensor_moe( - layer=td.layer, + kernel = mk.FusedMoEKernel( + maybe_make_prepare_finalize( + moe=td.layer.moe, + quant_config=quant_config, + allow_new_interface=True, + use_monolithic=True, + ), + TrtLlmFp8Experts( + moe_config=td.layer.moe, + quant_config=quant_config, + ), + ) + + flashinfer_output = kernel.apply_monolithic( hidden_states=td.hidden_states, + w1=td.layer.w13_weight, + w2=td.layer.w2_weight, router_logits=score, - routing_bias=None, + activation=activation, global_num_experts=e, - top_k=topk, - num_expert_group=None, - topk_group=None, + expert_map=None, apply_router_weight_on_input=True, + routed_scaling_factor=1.0, ) check_accuracy( @@ -329,8 +352,13 @@ def get_fused_moe_quant_config(n: torch.nn.Module) -> FusedMoEQuantConfig: routing_method=RoutingMethodType.TopK, ) - kernel = mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP(), + kernel = mk.FusedMoEKernel( + maybe_make_prepare_finalize( + moe=moe_config, + quant_config=quant_config, + allow_new_interface=True, + use_monolithic=False, + ), FlashInferExperts( moe_config=moe_config, quant_config=quant_config, @@ -338,7 +366,7 @@ def get_fused_moe_quant_config(n: torch.nn.Module) -> FusedMoEQuantConfig: inplace=False, ) - flashinfer_cutlass_output = kernel( + flashinfer_cutlass_output = kernel.apply( td.hidden_states, td.layer.w13_weight, td.layer.w2_weight, diff --git a/tests/kernels/moe/test_flashinfer_moe.py b/tests/kernels/moe/test_flashinfer_moe.py index 1f1349cff841..a3fb474f1517 100644 --- a/tests/kernels/moe/test_flashinfer_moe.py +++ b/tests/kernels/moe/test_flashinfer_moe.py @@ -14,6 +14,9 @@ from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.model_executor.layers.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.activation import MoEActivation +from vllm.model_executor.layers.fused_moe.all2all_utils import ( + maybe_make_prepare_finalize, +) from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEParallelConfig, @@ -23,10 +26,7 @@ FlashInferExperts, is_valid_flashinfer_cutlass_fused_moe, ) -from 
vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel -from vllm.model_executor.layers.fused_moe.prepare_finalize import ( - MoEPrepareAndFinalizeNoEP, -) +from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEKernel from vllm.platforms import current_platform from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe from vllm.utils.torch_utils import set_random_seed @@ -107,19 +107,27 @@ def test_flashinfer_fp4_moe_no_graph( routing_method=RoutingMethodType.TopK, ) - flashinfer_experts = FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP(), + flashinfer_experts = FusedMoEKernel( + maybe_make_prepare_finalize( + moe=moe_config, + quant_config=quant_config, + allow_new_interface=True, + use_monolithic=False, + ), FlashInferExperts(moe_config=moe_config, quant_config=quant_config), inplace=False, ) - flashinfer_output = flashinfer_experts( + flashinfer_output = flashinfer_experts.apply( hidden_states=a, w1=w1_q, w2=w2_q, topk_weights=topk_weights, topk_ids=topk_ids, activation=activation, + global_num_experts=e, + expert_map=None, + apply_router_weight_on_input=False, ) # Reference check: diff --git a/tests/kernels/moe/test_marlin_vs_trtllm_mxint4.py b/tests/kernels/moe/test_marlin_vs_trtllm_mxint4.py index d6735b126e2f..aaf255ca8b6a 100644 --- a/tests/kernels/moe/test_marlin_vs_trtllm_mxint4.py +++ b/tests/kernels/moe/test_marlin_vs_trtllm_mxint4.py @@ -221,16 +221,16 @@ def test_marlin_vs_trtllm_mxint4_moe_kimik2(monkeypatch, m, n, k, e, topk, group ) marlin_output = fused_marlin_moe( - a, - w1_marlin, - w2_marlin, - None, - None, - w1_scales_marlin, - w2_scales_marlin, - None, # gating_output not needed when topk_weights/ids provided - topk_weights, - topk_ids, + hidden_states=a, + w1=w1_marlin, + w2=w2_marlin, + bias1=None, + bias2=None, + w1_scale=w1_scales_marlin, + w2_scale=w2_scales_marlin, + topk_weights=topk_weights, + topk_ids=topk_ids, + quant_type_id=scalar_types.uint4b8.id, global_num_experts=e, expert_map=None, global_scale1=None, @@ -244,7 +244,6 @@ def test_marlin_vs_trtllm_mxint4_moe_kimik2(monkeypatch, m, n, k, e, topk, group w1_zeros=None, w2_zeros=None, input_dtype=dtype, - quant_type_id=scalar_types.uint4b8.id, is_k_full=True, ) diff --git a/tests/kernels/moe/test_modular_kernel_combinations.py b/tests/kernels/moe/test_modular_kernel_combinations.py index cd1d0a0afe9f..cac22a185fe9 100644 --- a/tests/kernels/moe/test_modular_kernel_combinations.py +++ b/tests/kernels/moe/test_modular_kernel_combinations.py @@ -168,7 +168,6 @@ def run(config: Config, verbose: bool): def is_nyi_config(config: Config) -> bool: # We know these configs to be legitimate. but still fail. info = expert_info(config.fused_experts_type) - if info.needs_matching_quant: # The triton kernels expect both per-act-token-quant and # per-out-ch-quant or neither. 
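The `test_flashinfer.py` hunks above replace the standalone `apply_fi_trtllm_fp8_per_tensor_moe` helper with the same kernel object driven through `apply_monolithic(...)`, where routing happens inside the fused kernel from raw router logits. Below is a rough sketch of that path; it assumes the `TrtLlmFp8Experts` class added later in this series and a `FusedMoEConfig`/fp8 quant config prepared as in that test, and the Llama4-style `apply_router_weight_on_input=True` setting is taken from the per-tensor test case.

```python
# Sketch of the monolithic path shown in the updated test_flashinfer.py; the
# caller is assumed to supply a config and weights prepared as in that test.
import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.all2all_utils import (
    maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.experts.trtllm_fp8_moe import (
    TrtLlmFp8Experts,
)


def trtllm_fp8_moe_via_kernel(
    moe_config: FusedMoEConfig,
    quant_config: FusedMoEQuantConfig,
    hidden_states: torch.Tensor,
    w13_weight: torch.Tensor,
    w2_weight: torch.Tensor,
    router_logits: torch.Tensor,
    activation: MoEActivation,
) -> torch.Tensor:
    """Monolithic path: routing and experts run inside one fused kernel call."""
    kernel = mk.FusedMoEKernel(
        maybe_make_prepare_finalize(
            moe=moe_config,
            quant_config=quant_config,
            allow_new_interface=True,
            use_monolithic=True,
        ),
        TrtLlmFp8Experts(moe_config=moe_config, quant_config=quant_config),
    )
    return kernel.apply_monolithic(
        hidden_states=hidden_states,
        w1=w13_weight,
        w2=w2_weight,
        router_logits=router_logits,        # raw logits; no precomputed topk ids
        activation=activation,
        global_num_experts=moe_config.num_experts,
        expert_map=None,
        apply_router_weight_on_input=True,  # Llama4-style routing, as in the test
        routed_scaling_factor=1.0,
    )
```

Unlike `.apply(...)`, this path consumes `router_logits` rather than precomputed `topk_weights`/`topk_ids`, which is what the new `use_monolithic=True` flag on `maybe_make_prepare_finalize` signals.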
@@ -259,7 +258,7 @@ def test_modular_kernel_combinations_multigpu( dtype: torch.dtype, quant_config: TestMoEQuantConfig | None, prepare_finalize_type: mk.FusedMoEPrepareAndFinalize, - fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute, + fused_experts_type: mk.FusedMoEExperts, chunk_size: int | None, world_size: int, pytestconfig, @@ -301,7 +300,7 @@ def test_modular_kernel_combinations_singlegpu( dtype: torch.dtype, quant_config: TestMoEQuantConfig | None, prepare_finalize_type: mk.FusedMoEPrepareAndFinalize, - fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute, + fused_experts_type: mk.FusedMoEExperts, chunk_size: int | None, world_size: int, pytestconfig, diff --git a/tests/kernels/moe/test_modular_oai_triton_moe.py b/tests/kernels/moe/test_modular_oai_triton_moe.py index 99d96e970ed0..b071e72dafbb 100644 --- a/tests/kernels/moe/test_modular_oai_triton_moe.py +++ b/tests/kernels/moe/test_modular_oai_triton_moe.py @@ -7,6 +7,7 @@ import pytest import torch +from tests.utils import wait_for_gpu_memory_to_clear from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.utils.import_utils import has_triton_kernels @@ -24,15 +25,15 @@ from triton_kernels.testing import assert_close from vllm.config import VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.all2all_utils import ( + maybe_make_prepare_finalize, +) from vllm.model_executor.layers.fused_moe.config import mxfp4_w4a16_moe_quant_config from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( OAITritonExperts, UnfusedOAITritonExperts, ) -from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel -from vllm.model_executor.layers.fused_moe.prepare_finalize import ( - MoEPrepareAndFinalizeNoEP, -) +from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEKernel from vllm.platforms import current_platform from vllm.utils.torch_utils import set_random_seed @@ -174,19 +175,25 @@ def oai_triton_moe_impl( w1_scale=w1_scale, w2_scale=w2_scale, ) + moe_config = make_dummy_moe_config() if unfused: - fused_experts = UnfusedOAITritonExperts(make_dummy_moe_config(), quant_config) + fused_experts = UnfusedOAITritonExperts(moe_config, quant_config) else: - fused_experts = OAITritonExperts(make_dummy_moe_config(), quant_config) - - mk = FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP(), + fused_experts = OAITritonExperts(moe_config, quant_config) + + mk = FusedMoEKernel( + maybe_make_prepare_finalize( + moe=moe_config, + quant_config=quant_config, + allow_new_interface=True, + use_monolithic=False, + ), fused_experts, inplace=False, ) - return mk.forward( + return mk.apply( hidden_states=x, w1=w1, w2=w2, @@ -217,6 +224,7 @@ def test_oai_triton_moe( unfused: bool, workspace_init, ): + wait_for_gpu_memory_to_clear(devices=[0], threshold_ratio=0.1) set_random_seed(0) ( w1, diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index eb3d9f8a8f6b..cda0b5c11040 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -346,14 +346,16 @@ def m_fused_moe( expert_map: torch.Tensor | None = None, ) -> torch.Tensor: topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) - return m_fused_moe_fn( + return m_fused_moe_fn.apply( a, w1, w2, topk_weights, topk_ids, + activation=MoEActivation.SILU, global_num_experts=global_num_experts, expert_map=expert_map, + apply_router_weight_on_input=False, ) fused_moe_fn = functools.partial(fused_moe, renormalize=False) @@ -500,14 +502,16 @@ def 
m_fused_moe( expert_map: torch.Tensor | None = None, ) -> torch.Tensor: topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) - return m_fused_moe_fn( + return m_fused_moe_fn.apply( a, w1, w2, topk_weights, topk_ids, + activation=MoEActivation.SILU, global_num_experts=global_num_experts, expert_map=expert_map, + apply_router_weight_on_input=False, ) fused_moe_fn = functools.partial(fused_moe, renormalize=False) diff --git a/tests/kernels/moe/test_nvfp4_moe.py b/tests/kernels/moe/test_nvfp4_moe.py index af47ca91a79f..e12659729c9c 100644 --- a/tests/kernels/moe/test_nvfp4_moe.py +++ b/tests/kernels/moe/test_nvfp4_moe.py @@ -15,12 +15,15 @@ from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.model_executor.layers.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.activation import MoEActivation +from vllm.model_executor.layers.fused_moe.all2all_utils import ( + maybe_make_prepare_finalize, +) from vllm.model_executor.layers.fused_moe.config import nvfp4_moe_quant_config from vllm.model_executor.layers.fused_moe.cutlass_moe import ( CutlassExpertsFp4, ) from vllm.model_executor.layers.fused_moe.prepare_finalize import ( - MoEPrepareAndFinalizeNoEP, + make_moe_prepare_and_finalize_no_dp_ep, ) from vllm.platforms import current_platform from vllm.utils.torch_utils import set_random_seed @@ -89,22 +92,32 @@ def test_cutlass_fp4_moe_no_graph( w1_scale=w1_blockscale, w2_scale=w2_blockscale, ) + moe_config = make_dummy_moe_config() - kernel = mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP(), + kernel = mk.FusedMoEKernel( + maybe_make_prepare_finalize( + moe=moe_config, + quant_config=quant_config, + allow_new_interface=True, + use_monolithic=False, + ), CutlassExpertsFp4( - moe_config=make_dummy_moe_config(), + moe_config=moe_config, quant_config=quant_config, ), inplace=False, ) - cutlass_output = kernel( + cutlass_output = kernel.apply( hidden_states=a, w1=w1_q, w2=w2_q, topk_weights=topk_weights, topk_ids=topk_ids, + global_num_experts=e, + activation=mk.MoEActivation.SILU, + apply_router_weight_on_input=False, + expert_map=None, ) # Reference check: @@ -207,8 +220,8 @@ def test_cutlass_fp4_moe_swiglustep( w2_scale=w2_blockscale, ) - kernel = mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP(), + kernel = mk.FusedMoEKernel( + make_moe_prepare_and_finalize_no_dp_ep(use_monolithic=False), CutlassExpertsFp4( moe_config=make_dummy_moe_config(), quant_config=quant_config, @@ -216,13 +229,16 @@ def test_cutlass_fp4_moe_swiglustep( inplace=False, ) - cutlass_output = kernel( + cutlass_output = kernel.apply( hidden_states=a, w1=w1_q, w2=w2_q, topk_weights=topk_weights, topk_ids=topk_ids, activation=MoEActivation.SWIGLUSTEP, + global_num_experts=e, + expert_map=None, + apply_router_weight_on_input=False, ) # Reference: dequantize everything and run torch_moe with swiglustep diff --git a/tests/kernels/moe/utils.py b/tests/kernels/moe/utils.py index e0a234111fe8..4b693d8c8a55 100644 --- a/tests/kernels/moe/utils.py +++ b/tests/kernels/moe/utils.py @@ -8,6 +8,9 @@ from tests.kernels.quantization.nvfp4_utils import FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe.activation import MoEActivation +from vllm.model_executor.layers.fused_moe.all2all_utils import ( + maybe_make_prepare_finalize, +) from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEParallelConfig, @@ -23,10 +26,7 @@ TritonExperts, fused_experts, ) -from 
vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel -from vllm.model_executor.layers.fused_moe.prepare_finalize import ( - MoEPrepareAndFinalizeNoEP, -) +from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEKernel from vllm.model_executor.layers.fused_moe.router.fused_topk_router import fused_topk from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input from vllm.utils.deep_gemm import per_block_cast_to_fp8 @@ -125,7 +125,9 @@ def batched_moe( a2_scale=a2_scale, ) - fused_experts = FusedMoEModularKernel( + moe_config = make_dummy_moe_config() + + fused_experts = FusedMoEKernel( BatchedPrepareAndFinalize( max_num_tokens, num_dispatchers=1, num_local_experts=w1.shape[0], rank=0 ), @@ -133,12 +135,22 @@ def batched_moe( max_num_tokens=max_num_tokens, num_dispatchers=1, quant_config=quant_config, - moe_config=make_dummy_moe_config(), + moe_config=moe_config, ), inplace=False, ) - return fused_experts(a, w1, w2, topk_weight, topk_ids) + return fused_experts.apply( + a, + w1, + w2, + topk_weight, + topk_ids, + global_num_experts=w1.shape[0], + activation=moe_config.activation, + apply_router_weight_on_input=False, + expert_map=None, + ) def naive_batched_moe( @@ -166,8 +178,9 @@ def naive_batched_moe( a1_scale=a1_scale, a2_scale=a2_scale, ) + moe_config = make_dummy_moe_config() - fused_experts = FusedMoEModularKernel( + fused_experts = FusedMoEKernel( BatchedPrepareAndFinalize( max_num_tokens, num_dispatchers=1, num_local_experts=w1.shape[0], rank=0 ), @@ -175,12 +188,22 @@ def naive_batched_moe( max_num_tokens=max_num_tokens, num_dispatchers=1, quant_config=quant_config, - moe_config=make_dummy_moe_config(), + moe_config=moe_config, ), inplace=False, ) - return fused_experts(a, w1, w2, topk_weight, topk_ids) + return fused_experts.apply( + a, + w1, + w2, + topk_weight, + topk_ids, + global_num_experts=w1.shape[0], + activation=moe_config.activation, + apply_router_weight_on_input=False, + expert_map=None, + ) def chunk_scales( @@ -581,9 +604,14 @@ def modular_triton_fused_moe( moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig, shared_experts: torch.nn.Module | None = None, -) -> FusedMoEModularKernel: - return FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP(), +) -> FusedMoEKernel: + return FusedMoEKernel( + maybe_make_prepare_finalize( + moe=moe_config, + quant_config=quant_config, + allow_new_interface=True, + use_monolithic=False, + ), TritonExperts(moe_config, quant_config), shared_experts, inplace=False, diff --git a/tests/quantization/test_blackwell_moe.py b/tests/quantization/test_blackwell_moe.py index 3a44ff4236a1..fe44017a04ee 100644 --- a/tests/quantization/test_blackwell_moe.py +++ b/tests/quantization/test_blackwell_moe.py @@ -127,6 +127,14 @@ def test_deepseek_fp8_block_moe_deep_gemm(monkeypatch: pytest.MonkeyPatch): ) +def test_deepseek_fp8_block_moe_vllm_triton(monkeypatch: pytest.MonkeyPatch): + can_initialize( + "deepseek-ai/DeepSeek-V3.1", + hf_overrides=HF_OVERRIDE_TEXT, + extra_args=["--moe-backend=triton"], + ) + + @pytest.mark.skip( reason=( "Known issue: lack of kernel support. 
" @@ -149,6 +157,14 @@ def test_deepseek_fp8_block_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatc ) +def test_deepseek_nvfp4_moe_flashinfer_vllm(monkeypatch: pytest.MonkeyPatch): + can_initialize( + "nvidia/DeepSeek-R1-0528-FP4-v2", + hf_overrides=HF_OVERRIDE_TEXT, + extra_args=["--moe-backend=cutlass"], + ) + + def test_deepseek_nvfp4_moe_flashinfer_cutlass(monkeypatch: pytest.MonkeyPatch): can_initialize( "nvidia/DeepSeek-R1-0528-FP4-v2", @@ -200,3 +216,67 @@ def test_qwen3_next_bf16_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch): hf_overrides=HF_OVERRIDE_TEXT, extra_args=["--moe-backend=flashinfer_trtllm"], ) + + +## NemoTron ## + + +def test_nemotron_fp8_moe_flashinfer_throughput(monkeypatch: pytest.MonkeyPatch): + can_initialize( + "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8", + hf_overrides=HF_OVERRIDE_TEXT, + extra_args=["--moe-backend=flashinfer_cutlass"], + ) + + +@pytest.mark.skip( + reason=( + "FP8 MoE backend FLASHINFER_TRTLLM does not support the " + "deployment configuration since kernel does not support " + "no act_and_mul MLP layer." + ) +) +def test_nemotron_fp8_moe_flashinfer_latency(monkeypatch: pytest.MonkeyPatch): + can_initialize( + "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8", + hf_overrides=HF_OVERRIDE_TEXT, + extra_args=["--moe-backend=flashinfer_trtllm"], + ) + + +@pytest.mark.skip( + reason=( + "FP8 MoE backend TRITON does not support the " + "deployment configuration since kernel does not support " + "no act_and_mul MLP layer." + ) +) +def test_nemotron_fp8_moe_vllm_triton(monkeypatch: pytest.MonkeyPatch): + can_initialize( + "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8", + hf_overrides=HF_OVERRIDE_TEXT, + extra_args=["--moe-backend=triton"], + ) + + +def test_nemotron_fp4_moe_flashinfer_throughput(monkeypatch: pytest.MonkeyPatch): + can_initialize( + "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4", + hf_overrides=HF_OVERRIDE_TEXT, + extra_args=["--moe-backend=flashinfer_cutlass"], + ) + + +@pytest.mark.skip( + reason=( + "FP4 MoE backend FLASHINFER_TRTLLM does not support the " + "deployment configuration since kernel does not support " + "hidden_dim % 512 != 0." + ) +) +def test_nemotron_fp4_moe_flashinfer_latency(monkeypatch: pytest.MonkeyPatch): + can_initialize( + "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4", + hf_overrides=HF_OVERRIDE_TEXT, + extra_args=["--moe-backend=flashinfer_trtllm"], + ) diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py index c13ed44e6f70..eff05b575856 100644 --- a/vllm/lora/layers/fused_moe.py +++ b/vllm/lora/layers/fused_moe.py @@ -32,10 +32,10 @@ UnfusedOAITritonExperts, ) from vllm.model_executor.layers.fused_moe.modular_kernel import ( - FusedMoEModularKernel, + FusedMoEKernel, ) from vllm.model_executor.layers.fused_moe.prepare_finalize import ( - MoEPrepareAndFinalizeNoEP, + MoEPrepareAndFinalizeNoDPEPModular, ) from .utils import _get_lora_device, try_get_optimal_moe_lora_config @@ -136,7 +136,7 @@ def _inject_lora_into_fused_moe(self): if getattr(self.base_layer.quant_method, "supports_internal_mk", False): # Use the existing modular kernel from the quant method - m_fused_moe_fn = self.base_layer.quant_method.moe_mk + m_fused_moe_fn = self.base_layer.quant_method.moe_kernel # Don't let the kernel own shared experts so the runner can # overlap them with routed experts via a separate CUDA stream. m_fused_moe_fn.shared_experts = None @@ -144,8 +144,8 @@ def _inject_lora_into_fused_moe(self): # Create a new modular kernel via select_gemm_impl. 
# Don't pass shared_experts to the kernel so the runner can # overlap them with routed experts via a separate CUDA stream. - prepare_finalize = MoEPrepareAndFinalizeNoEP() - m_fused_moe_fn = FusedMoEModularKernel( + prepare_finalize = MoEPrepareAndFinalizeNoDPEPModular() + m_fused_moe_fn = FusedMoEKernel( prepare_finalize, self.base_layer.quant_method.select_gemm_impl( prepare_finalize, self.base_layer @@ -154,10 +154,11 @@ def _inject_lora_into_fused_moe(self): if quant_config.use_mxfp4_w4a16: assert isinstance( - m_fused_moe_fn.fused_experts, (MarlinExperts, UnfusedOAITritonExperts) + m_fused_moe_fn.impl.fused_experts, + (MarlinExperts, UnfusedOAITritonExperts), ) else: - assert isinstance(m_fused_moe_fn.fused_experts, TritonExperts) + assert isinstance(m_fused_moe_fn.impl.fused_experts, TritonExperts) def fwd_decorator(layer, func): def wrapper(*args, **kwargs): @@ -337,9 +338,9 @@ def wrapper(*args, **kwargs): return wrapper - fused_experts = m_fused_moe_fn.fused_experts + fused_experts = m_fused_moe_fn.impl.fused_experts - m_fused_moe_fn.forward = fwd_decorator(self.base_layer, m_fused_moe_fn.forward) + m_fused_moe_fn.apply = fwd_decorator(self.base_layer, m_fused_moe_fn.apply) fused_experts.activation = act_decorator( self.base_layer, fused_experts.activation ) diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index be901bd24490..f56a2e63bf40 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -22,8 +22,8 @@ ) from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEActivationFormat, - FusedMoEPermuteExpertsUnpermute, - FusedMoEPrepareAndFinalize, + FusedMoEExpertsModular, + FusedMoEPrepareAndFinalizeModular, ) from vllm.model_executor.layers.fused_moe.router.fused_moe_router import ( FusedMoERouter, @@ -62,9 +62,9 @@ def get_config() -> dict[str, Any] | None: "MoEActivation", "UnquantizedFusedMoEMethod", "FusedMoeWeightScaleSupported", - "FusedMoEPermuteExpertsUnpermute", + "FusedMoEExpertsModular", "FusedMoEActivationFormat", - "FusedMoEPrepareAndFinalize", + "FusedMoEPrepareAndFinalizeModular", "GateLinear", "RoutingMethodType", "SharedFusedMoE", diff --git a/vllm/model_executor/layers/fused_moe/all2all_utils.py b/vllm/model_executor/layers/fused_moe/all2all_utils.py index 8c1bfe1c3675..47ca95ee54cb 100644 --- a/vllm/model_executor/layers/fused_moe/all2all_utils.py +++ b/vllm/model_executor/layers/fused_moe/all2all_utils.py @@ -21,8 +21,8 @@ FusedMoEPrepareAndFinalize, ) from vllm.model_executor.layers.fused_moe.prepare_finalize import ( - MoEPrepareAndFinalizeNaiveEP, - MoEPrepareAndFinalizeNoEP, + make_moe_prepare_and_finalize_naive_dp_ep, + make_moe_prepare_and_finalize_no_dp_ep, ) from vllm.platforms import current_platform from vllm.utils.import_utils import has_deep_ep, has_mori @@ -77,6 +77,7 @@ def maybe_make_prepare_finalize( quant_config: FusedMoEQuantConfig | None, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, allow_new_interface: bool = False, + use_monolithic: bool = False, ) -> FusedMoEPrepareAndFinalize | None: # NOTE(rob): we are migrating each quant_method to hold the MK # in all cases. The allow_new_interface=False flag allow us to fall @@ -102,14 +103,15 @@ def maybe_make_prepare_finalize( "Detected DP deployment with no --enable-expert-parallel. " "Falling back to AllGather+ReduceScatter dispatch/combine." 
) - return MoEPrepareAndFinalizeNaiveEP( + return make_moe_prepare_and_finalize_naive_dp_ep( is_sequence_parallel=moe.moe_parallel_config.is_sequence_parallel, num_dispatchers=( get_ep_group().device_communicator.all2all_manager.world_size ), + use_monolithic=use_monolithic, ) else: - return MoEPrepareAndFinalizeNoEP() + return make_moe_prepare_and_finalize_no_dp_ep(use_monolithic) all2all_manager = get_ep_group().device_communicator.all2all_manager assert all2all_manager is not None @@ -201,8 +203,9 @@ def maybe_make_prepare_finalize( ) elif moe.use_naive_all2all_kernels and allow_new_interface: - prepare_finalize = MoEPrepareAndFinalizeNaiveEP( - is_sequence_parallel=(moe.moe_parallel_config.is_sequence_parallel), + prepare_finalize = make_moe_prepare_and_finalize_naive_dp_ep( + use_monolithic=use_monolithic, + is_sequence_parallel=moe.moe_parallel_config.is_sequence_parallel, num_dispatchers=all2all_manager.world_size, ) diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index 405965c5395b..539712587a71 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -261,7 +261,7 @@ def persistent_masked_m_silu_mul_quant( return y_q, y_s -class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): +class BatchedDeepGemmExperts(mk.FusedMoEExpertsModular): def __init__( self, moe_config: FusedMoEConfig, diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 33d69b57a934..e0ed9130c2ce 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -228,6 +228,7 @@ class FusedMoEQuantConfig: _a2: FusedMoEQuantDesc _w1: FusedMoEQuantDesc _w2: FusedMoEQuantDesc + is_nvfp4_scale_swizzled: bool = True def __post_init__(self): assert not self.per_act_token_quant or self.block_shape is None, ( @@ -475,6 +476,7 @@ def make( w1_zp: torch.Tensor | None = None, w2_zp: torch.Tensor | None = None, weight_dtype: torch.dtype | str | None = None, + is_nvfp4_scale_swizzled: bool = True, ) -> "FusedMoEQuantConfig": """ General builder function for a FusedMoEQuantConfig. @@ -504,6 +506,7 @@ def make( - w2_bias: Optional biases for w1 (GPT OSS Triton). - w1_zp: Optional w1 zero points for int4/int8 quantization. - w2_zp: Optional w2 zero points for int4/int8 quantization. + - is_nvfp4_scale_swizzled: Whether to swizzle the nvfp4 scale swizzling. """ assert not isinstance(quant_dtype, str) or quant_dtype in { "nvfp4", @@ -536,6 +539,7 @@ def make( _w2=FusedMoEQuantDesc( weight_dtype, w_shape, w2_scale, g2_alphas, w2_zp, w2_bias ), + is_nvfp4_scale_swizzled=is_nvfp4_scale_swizzled, ) assert quant_config.per_act_token_quant == per_act_token_quant assert quant_config.per_out_ch_quant == per_out_ch_quant @@ -737,6 +741,7 @@ def nvfp4_moe_quant_config( w2_scale: torch.Tensor, w1_bias: torch.Tensor | None = None, w2_bias: torch.Tensor | None = None, + is_nvfp4_scale_swizzled: bool = True, ) -> FusedMoEQuantConfig: """ Construct a quant config for mxfp4 activations and nvp4 weights. 
@@ -754,6 +759,7 @@ def nvfp4_moe_quant_config( per_act_token_quant=False, per_out_ch_quant=False, block_shape=None, + is_nvfp4_scale_swizzled=is_nvfp4_scale_swizzled, ) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index ac9ba56a6b70..64848bf931ae 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -21,7 +21,7 @@ moe_unpermute, ) from vllm.model_executor.layers.fused_moe.prepare_finalize import ( - MoEPrepareAndFinalizeNoEP, + MoEPrepareAndFinalizeNoDPEPModular, ) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceDelegate, @@ -262,7 +262,7 @@ def run_cutlass_moe_fp8( ) -class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute): +class CutlassExpertsFp8Base(mk.FusedMoEExpertsModular): def __init__( self, moe_config: FusedMoEConfig, @@ -661,7 +661,7 @@ def run_cutlass_moe_fp4( return -class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): +class CutlassExpertsFp4(mk.FusedMoEExpertsModular): """CUTLASS FP4 fused MoE expert implementation.""" @property @@ -928,7 +928,7 @@ def run_cutlass_moe_w4a8_fp8( ) -class CutlassExpertsW4A8Fp8(mk.FusedMoEPermuteExpertsUnpermute): +class CutlassExpertsW4A8Fp8(mk.FusedMoEExpertsModular): def __init__( self, out_dtype: torch.dtype | None, @@ -1170,8 +1170,8 @@ def cutlass_moe_w4a8_fp8( num_experts = global_num_experts if global_num_experts != -1 else w1_q.size(0) - fn = mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP(), + fn = mk.FusedMoEKernel( + MoEPrepareAndFinalizeNoDPEPModular(), CutlassExpertsW4A8Fp8( out_dtype=a.dtype, a_strides1=a_strides1, @@ -1186,10 +1186,9 @@ def cutlass_moe_w4a8_fp8( quant_config=quant_config, group_size=group_size, ), - inplace=False, ) - return fn( + return fn.apply( a, w1_q, w2_q, diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 69ca7c91cfda..8af439a0d435 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -113,7 +113,7 @@ def _valid_deep_gemm( return True -class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): +class DeepGemmExperts(mk.FusedMoEExpertsModular): """DeepGemm-based fused MoE expert implementation.""" def __init__(self, moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig): diff --git a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py index 514aa205a3cb..63312557d85d 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py @@ -25,7 +25,7 @@ ) -class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): +class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular): """ Prepare/Finalize using DeepEP High-Throughput kernels. 
""" @@ -239,6 +239,7 @@ def _receiver( quant_dtype=quant_config.quant_dtype, per_act_token_quant=False, block_shape=quant_config.block_shape, + is_fp4_scale_swizzled=quant_config.is_nvfp4_scale_swizzled, ) return ( diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py index a4cee76f7167..a22b89415364 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py @@ -49,7 +49,7 @@ def dequant_fp8( return (expert_x_fp32 * expert_x_scales).view(expert_x_fp8.size()) -class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): +class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular): """ Prepare/Finalize using DeepEP low-latency kernels. """ @@ -119,7 +119,7 @@ def _maybe_cast(tensor: torch.Tensor | None) -> torch.Tensor | None: # time. This setting is handled by post_init_setup. self.use_ue8m0_dispatch = False - def post_init_setup(self, fused_experts: mk.FusedMoEPermuteExpertsUnpermute): + def post_init_setup(self, fused_experts: mk.FusedMoEExperts): if not fused_experts.supports_packed_ue8m0_act_scales(): # Early exit. return diff --git a/vllm/model_executor/layers/fused_moe/experts/__init__.py b/vllm/model_executor/layers/fused_moe/experts/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py b/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py new file mode 100644 index 000000000000..febb3b2ef0d7 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py @@ -0,0 +1,335 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.activation import MoEActivation +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEParallelConfig, + FusedMoEQuantConfig, + RoutingMethodType, +) +from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( + activation_to_flashinfer_int, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, + kFp8Dynamic128Sym, + kFp8Static128BlockSym, + kFp8StaticTensorSym, +) +from vllm.platforms import current_platform + + +class TrtLlmFp8Experts(mk.FusedMoEExpertsMonolithic): + """ + Fp8 TRTLLM-Gen MoE kernels. Supports monolithic interface. + """ + + def __init__( + self, + moe_config: FusedMoEConfig, + quant_config: FusedMoEQuantConfig, + ): + super().__init__(moe_config, quant_config) + + if moe_config.moe_parallel_config.use_ep and quant_config.is_per_tensor: + raise NotImplementedError( + "EP parallelism is not supported with TRTLLM" + "per-tensor FP8 quantization." + ) + + self.routing_method_type = moe_config.routing_method + self.topk = moe_config.experts_per_token + self.intermediate_size_per_partition = ( + moe_config.intermediate_size_per_partition + ) + self.hidden_dim = moe_config.hidden_dim + self.local_num_experts = moe_config.num_local_experts + self.ep_rank = moe_config.moe_parallel_config.ep_rank + + # Make additional scales for per-tensor interface. 
+ if self.quant_config.is_per_tensor: + w1_scale = self.quant_config.w1_scale + assert w1_scale is not None + a1_scale = self.quant_config.a1_scale + assert a1_scale is not None + w2_scale = self.quant_config.w2_scale + assert w2_scale is not None + a2_scale = self.quant_config.a2_scale + assert a2_scale is not None + + self._g1_alphas = (w1_scale * a1_scale).squeeze() + self._g2_alphas = (w2_scale * a2_scale).squeeze() + self._g1_scale_c = ( + self._g1_alphas / self.quant_config.a2_scale + if moe_config.is_act_and_mul + else torch.ones_like(self._g1_alphas) / self.quant_config.a2_scale + ) + + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard + + @staticmethod + def _supports_current_device() -> bool: + """Supports only Blackwell-family GPUs.""" + p = current_platform + # Add check flashinfer trtllm is available + return p.is_cuda() and p.is_device_capability_family(100) + + @staticmethod + def _supports_no_act_and_mul() -> bool: + """Does not support non-gated MoE (i.e. Nanotron-3-Nano).""" + return True + + @staticmethod + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + """Supports Fp8 per-tensor and Fp8 block.""" + SUPPORTED_W_A = [ + (kFp8Static128BlockSym, kFp8Dynamic128Sym), + (kFp8StaticTensorSym, kFp8StaticTensorSym), + ] + return (weight_key, activation_key) in SUPPORTED_W_A + + @staticmethod + def _supports_activation(activation: MoEActivation) -> bool: + """Supports only SiLU and RELU^2 non-gated activation.""" + return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL] + + @staticmethod + def _supports_routing_method( + routing_method: RoutingMethodType, + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + """Monolithic kernels need to express router support.""" + # NOTE(dbari): TopK routing could also be enabled, but need to validate models + # NOTE(dbari): Default is not implemented and should not be enabled until it is + if (weight_key, activation_key) == (kFp8Static128BlockSym, kFp8Dynamic128Sym): + # NOTE(rob): potentially allow others here. This is a conservative list. + return routing_method in [ + RoutingMethodType.DeepSeekV3, + RoutingMethodType.Renormalize, + RoutingMethodType.RenormalizeNaive, + ] + elif (weight_key, activation_key) == (kFp8StaticTensorSym, kFp8StaticTensorSym): + # NOTE(dbari): as above, potentially allow others here. + return routing_method in [ + RoutingMethodType.DeepSeekV3, + RoutingMethodType.Llama4, + RoutingMethodType.Renormalize, + RoutingMethodType.RenormalizeNaive, + ] + else: + raise ValueError("Unsupported quantization scheme.") + + @staticmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: + """Monolithic kernel so only use with naive DP/EP and TP.""" + return ( + not moe_parallel_config.use_all2all_kernels + or moe_parallel_config.use_naive_all2all_kernels + ) and not moe_parallel_config.enable_eplb + + @staticmethod + def _supports_router_logits_dtype( + router_logits_dtype: torch.dtype | None, + routing_method: RoutingMethodType, + ) -> bool: + """ + The FlashInfer TRTLLM FP8 kernel expects bfloat16 router_logits by default. + Only DeepSeekV3 routing supports float32 router_logits (which is converted + internally in the kernel). 
+ """ + if router_logits_dtype == torch.float32: + # Only DeepSeekV3 routing handles float32 logits + # https://github.com/flashinfer-ai/flashinfer/issues/2469 + return routing_method == RoutingMethodType.DeepSeekV3 + return True + + def supports_chunking(self) -> bool: + return False + + def supports_expert_map(self) -> bool: + return False + + def _apply_per_block( + self, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + router_logits: torch.Tensor, + activation: MoEActivation, + global_num_experts: int, + expert_map: torch.Tensor | None, + a1q_scale: torch.Tensor | None, + apply_router_weight_on_input: bool, + # grouped topk + fused topk bias parameters + num_expert_group: int | None = None, + e_score_correction_bias: torch.Tensor | None = None, + routed_scaling_factor: float | None = None, + topk_group: int | None = None, + ) -> torch.Tensor: + # Delay import for non-CUDA. + import flashinfer + + assert not apply_router_weight_on_input + assert activation == MoEActivation.SILU + + if e_score_correction_bias is not None: + e_score_correction_bias = e_score_correction_bias.to(hidden_states.dtype) + + if self.routing_method_type == RoutingMethodType.DeepSeekV3: + router_logits = router_logits.to(torch.float32) + + assert self.topk <= global_num_experts + assert self.topk <= 10 + assert global_num_experts % 4 == 0 + assert self.quant_config.block_shape == [128, 128] + # Routing kernel expects #experts <= #threads 512 + assert global_num_experts <= 512 + + # Kernel requires transposed hidden state scales + # TODO: fuse into the quant kernel. + assert a1q_scale is not None + a1q_scale_t = a1q_scale.t().contiguous() + + return flashinfer.fused_moe.trtllm_fp8_block_scale_moe( + routing_logits=router_logits, + routing_bias=e_score_correction_bias, + hidden_states=hidden_states, + hidden_states_scale=a1q_scale_t, + gemm1_weights=w1, + gemm1_weights_scale=self.quant_config.w1_scale, + gemm2_weights=w2, + gemm2_weights_scale=self.quant_config.w2_scale, + num_experts=global_num_experts, + top_k=self.topk, + n_group=(num_expert_group or 0), + topk_group=(topk_group or 0), + intermediate_size=self.intermediate_size_per_partition, + local_expert_offset=self.ep_rank * self.local_num_experts, + local_num_experts=self.local_num_experts, + routed_scaling_factor=routed_scaling_factor, + routing_method_type=self.routing_method_type, + use_shuffled_weight=False, + ) + + def _apply_per_tensor( + self, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + router_logits: torch.Tensor, + activation: MoEActivation, + global_num_experts: int, + expert_map: torch.Tensor | None, + a1q_scale: torch.Tensor | None, + apply_router_weight_on_input: bool, + # grouped topk + fused topk bias parameters + num_expert_group: int | None = None, + e_score_correction_bias: torch.Tensor | None = None, + routed_scaling_factor: float | None = None, + topk_group: int | None = None, + ) -> torch.Tensor: + # Delay import for non-CUDA. + import flashinfer + from flashinfer.fused_moe.core import ActivationType + + # Confirm supported activation function. + assert activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL] + + activation_type = ActivationType(activation_to_flashinfer_int(activation)) + + # Confirm Llama-4 routing is proper. + if self.routing_method_type == RoutingMethodType.Llama4: + assert apply_router_weight_on_input + else: + assert not apply_router_weight_on_input + + # The DeepSeekV3 routing method requires float32 router logits. 
+ if self.routing_method_type == RoutingMethodType.DeepSeekV3: + router_logits = router_logits.to(torch.float32) + + out = flashinfer.fused_moe.trtllm_fp8_per_tensor_scale_moe( + routing_logits=router_logits, + routing_bias=e_score_correction_bias, + hidden_states=hidden_states, + gemm1_weights=w1, + output1_scales_scalar=self._g1_scale_c, + output1_scales_gate_scalar=self._g1_alphas, + gemm2_weights=w2, + output2_scales_scalar=self._g2_alphas, + num_experts=global_num_experts, + top_k=self.topk, + n_group=num_expert_group or 0, + topk_group=topk_group or 0, + intermediate_size=self.intermediate_size_per_partition, + local_expert_offset=self.ep_rank * self.local_num_experts, + local_num_experts=self.local_num_experts, + routed_scaling_factor=routed_scaling_factor, + use_routing_scales_on_input=apply_router_weight_on_input, + routing_method_type=self.routing_method_type, + activation_type=activation_type, + ) + return out + + def apply( + self, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + router_logits: torch.Tensor, + activation: MoEActivation, + global_num_experts: int, + expert_map: torch.Tensor | None, + a1q_scale: torch.Tensor | None, + apply_router_weight_on_input: bool, + # grouped topk + fused topk bias parameters + num_expert_group: int | None = None, + e_score_correction_bias: torch.Tensor | None = None, + routed_scaling_factor: float | None = None, + topk_group: int | None = None, + ) -> torch.Tensor: + if self.quant_config.block_shape is not None: + return self._apply_per_block( + hidden_states, + w1, + w2, + router_logits, + activation, + global_num_experts, + expert_map, + a1q_scale, + apply_router_weight_on_input, + num_expert_group=num_expert_group, + e_score_correction_bias=e_score_correction_bias, + routed_scaling_factor=routed_scaling_factor, + topk_group=topk_group, + ) + elif self.quant_config.is_per_tensor: + return self._apply_per_tensor( + hidden_states, + w1, + w2, + router_logits, + activation, + global_num_experts, + expert_map, + a1q_scale, + apply_router_weight_on_input, + num_expert_group=num_expert_group, + e_score_correction_bias=e_score_correction_bias, + routed_scaling_factor=routed_scaling_factor, + ) + else: + raise NotImplementedError( + "Only per-block and per-tensor quantization are supported in " + f"{self.__class__.__name__}." + ) diff --git a/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py b/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py new file mode 100644 index 000000000000..502671766400 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py @@ -0,0 +1,326 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import flashinfer +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.activation import MoEActivation +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEParallelConfig, + FusedMoEQuantConfig, + RoutingMethodType, +) +from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( + TopKWeightAndReduceNoOP, +) +from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( + activation_to_flashinfer_int, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, + kNvfp4Dynamic, + kNvfp4Static, +) +from vllm.platforms import current_platform + + +class TrtLlmNvFp4ExpertsBase: + """ + NvFp4 TRTLLM-Gen MoE kernels. 
Supports modular and monolithic interface. + """ + + def __init__( + self, + moe_config: FusedMoEConfig, + quant_config: FusedMoEQuantConfig, + ): + self.moe_config = moe_config + self.quant_config = quant_config + + self.routing_method_type = self.moe_config.routing_method + self.topk = moe_config.experts_per_token + self.intermediate_size_per_partition = ( + moe_config.intermediate_size_per_partition + ) + self.hidden_dim = moe_config.hidden_dim + self.local_num_experts = moe_config.num_local_experts + self.ep_rank = moe_config.moe_parallel_config.ep_rank + + assert self.quant_config.g1_alphas is not None + assert self.quant_config.a2_gscale is not None + if moe_config.is_act_and_mul: + # g1_alpha_s = a13_scale * w13_scale_2 + # a2_gscale = (1 / a2_scale) + # g1_scale_c = a13_scale * w13_scale_2 / a2_scale + self.g1_scale_c = self.quant_config.g1_alphas * self.quant_config.a2_gscale + else: + self.g1_scale_c = ( + torch.ones_like(self.quant_config.a1_gscale) + * self.quant_config.a2_gscale + ) + + @staticmethod + def _supports_current_device() -> bool: + """Supports only Blackwell-family GPUs.""" + p = current_platform + return p.is_cuda() and p.is_device_capability_family(100) + + @staticmethod + def _supports_no_act_and_mul() -> bool: + """Supports non-gated MoE (i.e. Nemotron-Nano).""" + return True + + @staticmethod + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + """Supports Nvfp4 quantization.""" + SUPPORTED_W_A = [ + (kNvfp4Static, kNvfp4Dynamic), + ] + return (weight_key, activation_key) in SUPPORTED_W_A + + @staticmethod + def _supports_activation(activation: MoEActivation) -> bool: + """Supports only SiLU and RELU^2 non-gated activation.""" + return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL] + + @staticmethod + def _supports_shape(hidden_dim: int) -> bool: + """Requires hidden dim to be multiple of 512.""" + return hidden_dim % 512 == 0 + + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard + + def supports_chunking(self) -> bool: + return False + + def supports_expert_map(self) -> bool: + return False + + +class TrtLlmNvFp4ExpertsModular(TrtLlmNvFp4ExpertsBase, mk.FusedMoEExpertsModular): + """ + Modular version of the implementation (just the experts). + """ + + @staticmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: + """The modular implementation supports all parallel configs.""" + return True + + def workspace_shapes( + self, + M: int, + N: int, + K: int, + topk: int, + global_num_experts: int, + local_num_experts: int, + expert_tokens_meta: mk.ExpertTokensMetadata | None, + activation: MoEActivation, + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: + # The workspaces for this implementation are managed by flashinfer. + workspace1 = (0,) + workspace2 = (0,) + + # Hidden states are Nvfp4, packed into int8 dtype, so we + # need to multiply K by 2 to get the output shape right. 
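+        # e.g. with hidden_dim = 4096, a1q arrives with K = 2048 uint8 columns
+        # (two FP4 values packed per byte), while the unpacked output below is
+        # allocated at the full (M, 4096) width.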
+ assert self.hidden_dim == K * 2 + output = (M, self.hidden_dim) + + return (workspace1, workspace2, output) + + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: + return TopKWeightAndReduceNoOP() + + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: MoEActivation, + global_num_experts: int, + expert_map: torch.Tensor | None, + a1q_scale: torch.Tensor | None, + a2_scale: torch.Tensor | None, + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_tokens_meta: mk.ExpertTokensMetadata | None, + apply_router_weight_on_input: bool, + ): + assert activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL] + assert a1q_scale is not None + assert self.quant_config.w1_scale is not None + assert self.quant_config.w2_scale is not None + + # Pack topk ids and weights into format expected by the kernel. + packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to( + torch.bfloat16 + ).view(torch.int16) + + # trtllm_fp4_block_scale_routed_moe does not support autotuning + # so skip this kernel during dummy run for autotuning. + import vllm.utils.flashinfer as fi_utils + + if fi_utils._is_fi_autotuning: + return hidden_states + + # Invoke kernel. + flashinfer.fused_moe.trtllm_fp4_block_scale_routed_moe( + topk_ids=packed_tensor, + routing_bias=None, + hidden_states=hidden_states, + hidden_states_scale=a1q_scale.view(torch.float8_e4m3fn).reshape( + *hidden_states.shape[:-1], -1 + ), + gemm1_weights=w1, + gemm1_weights_scale=self.quant_config.w1_scale.view(torch.float8_e4m3fn), + gemm1_bias=None, + gemm1_alpha=None, + gemm1_beta=None, + gemm1_clamp_limit=None, + gemm2_weights=w2, + gemm2_weights_scale=self.quant_config.w2_scale.view(torch.float8_e4m3fn), + gemm2_bias=None, + output1_scale_scalar=self.g1_scale_c, + output1_scale_gate_scalar=self.quant_config.g1_alphas, + output2_scale_scalar=self.quant_config.g2_alphas, + num_experts=global_num_experts, + top_k=self.topk, + n_group=0, + topk_group=0, + intermediate_size=self.intermediate_size_per_partition, + local_expert_offset=self.ep_rank * self.local_num_experts, + local_num_experts=self.local_num_experts, + routed_scaling_factor=None, + routing_method_type=1, + do_finalize=True, + activation_type=activation_to_flashinfer_int(activation), + output=output, + ) + + +class TrtLlmNvFp4ExpertsMonolithic( + TrtLlmNvFp4ExpertsBase, mk.FusedMoEExpertsMonolithic +): + """ + Monolithic version of the kernel (router + experts). + """ + + @staticmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: + """The modular implementation should be used for the Dp/Ep or EPLB case.""" + return ( + not moe_parallel_config.use_all2all_kernels + and not moe_parallel_config.enable_eplb + ) + + @staticmethod + def _supports_routing_method( + routing_method_type: RoutingMethodType, + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + # NOTE(rob): this is a conservative list. + return routing_method_type in [ + RoutingMethodType.DeepSeekV3, + RoutingMethodType.Renormalize, + RoutingMethodType.RenormalizeNaive, + RoutingMethodType.Llama4, + ] + + @staticmethod + def _supports_router_logits_dtype( + router_logits_dtype: torch.dtype | None, + routing_method: RoutingMethodType, + ) -> bool: + """ + The FlashInfer TRTLLM NvFp4 kernel expects bfloat16 router_logits by default. 
+ Only DeepSeekV3 routing supports float32 router_logits (which is converted + internally in the kernel). + """ + if router_logits_dtype == torch.float32: + # Only DeepSeekV3 routing handles float32 logits + # https://github.com/flashinfer-ai/flashinfer/issues/2469 + return routing_method == RoutingMethodType.DeepSeekV3 + return True + + def apply( + self, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + router_logits: torch.Tensor, + activation: MoEActivation, + global_num_experts: int, + expert_map: torch.Tensor | None, + a1q_scale: torch.Tensor | None, + apply_router_weight_on_input: bool, + # grouped topk + fused topk bias parameters + num_expert_group: int | None = None, + e_score_correction_bias: torch.Tensor | None = None, + routed_scaling_factor: float | None = None, + topk_group: int | None = None, + ) -> torch.Tensor: + assert activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL] + assert a1q_scale is not None + assert self.quant_config.w1_scale is not None + assert self.quant_config.w2_scale is not None + assert ( + apply_router_weight_on_input + and self.routing_method_type == RoutingMethodType.Llama4 + ) or ( + not apply_router_weight_on_input + and self.routing_method_type != RoutingMethodType.Llama4 + ) + + # Prepare routing bias into kernel format. + routing_bias = e_score_correction_bias + if routing_bias is not None: + routing_bias = routing_bias.to(torch.bfloat16) + router_logits = ( + router_logits.to(torch.float32) + if self.routing_method_type == RoutingMethodType.DeepSeekV3 + else router_logits + ) + + # Invoke kernel. + return flashinfer.fused_moe.trtllm_fp4_block_scale_moe( + routing_logits=router_logits, + routing_bias=routing_bias, + hidden_states=hidden_states, + hidden_states_scale=a1q_scale.view(torch.float8_e4m3fn).reshape( + *hidden_states.shape[:-1], -1 + ), + gemm1_weights=w1, + gemm1_weights_scale=self.quant_config.w1_scale.view(torch.float8_e4m3fn), + gemm1_bias=None, + gemm1_alpha=None, + gemm1_beta=None, + gemm1_clamp_limit=None, + gemm2_weights=w2, + gemm2_weights_scale=self.quant_config.w2_scale.view(torch.float8_e4m3fn), + gemm2_bias=None, + output1_scale_scalar=self.g1_scale_c, + output1_scale_gate_scalar=self.quant_config.g1_alphas, + output2_scale_scalar=self.quant_config.g2_alphas, + num_experts=global_num_experts, + top_k=self.topk, + n_group=(num_expert_group or 0), + topk_group=(topk_group or 0), + intermediate_size=self.intermediate_size_per_partition, + local_expert_offset=self.ep_rank * self.local_num_experts, + local_num_experts=self.local_num_experts, + routed_scaling_factor=routed_scaling_factor, + routing_method_type=self.routing_method_type, + do_finalize=True, + )[0] diff --git a/vllm/model_executor/layers/fused_moe/fallback.py b/vllm/model_executor/layers/fused_moe/fallback.py index 4b6458e7fd33..403a71e20761 100644 --- a/vllm/model_executor/layers/fused_moe/fallback.py +++ b/vllm/model_executor/layers/fused_moe/fallback.py @@ -11,13 +11,13 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey -class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC): +class FallbackExperts(mk.FusedMoEExpertsModular, ABC): """Base class for runtime dispatching of expert implementations.""" def __init__( self, - experts: mk.FusedMoEPermuteExpertsUnpermute, - fallback_experts: mk.FusedMoEPermuteExpertsUnpermute, + experts: mk.FusedMoEExpertsModular, + fallback_experts: mk.FusedMoEExpertsModular, ): super().__init__( moe_config=experts.moe_config, quant_config=experts.quant_config @@ 
-27,8 +27,8 @@ def __init__( @staticmethod def get_clses() -> tuple[ - type[mk.FusedMoEPermuteExpertsUnpermute], - type[mk.FusedMoEPermuteExpertsUnpermute], + type[mk.FusedMoEExpertsModular], + type[mk.FusedMoEExpertsModular], ]: """ Get the cls for the experts and fallback experts. @@ -149,7 +149,7 @@ def _select_experts_impl( hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, - ) -> mk.FusedMoEPermuteExpertsUnpermute: + ) -> mk.FusedMoEExpertsModular: raise NotImplementedError def apply( diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py index 39b373861d03..465d0ae8f2c4 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py @@ -18,7 +18,7 @@ def get_local_sizes(): return get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank() -class FlashInferA2APrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): +class FlashInferA2APrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular): """Base class for FlashInfer MoE prepare and finalize operations.""" def __init__( @@ -185,8 +185,8 @@ def flashinfer_alltoall_dispatch( ep_size, ) - # Swizzle after the A2A if nvfp4. - if quant_config.quant_dtype == "nvfp4": + # Swizzle after the A2A if MoE kernel expects swizzled scales. + if quant_config.quant_dtype == "nvfp4" and quant_config.is_nvfp4_scale_swizzled: if x_sf.element_size() == 1: x_sf = x_sf.view(torch.uint8) x_sf = nvfp4_block_scale_interleave(x_sf) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py index d0cf7533d70f..730dc0c5df3c 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py @@ -30,7 +30,7 @@ logger = init_logger(__name__) -class FlashInferCuteDSLExperts(mk.FusedMoEPermuteExpertsUnpermute): +class FlashInferCuteDSLExperts(mk.FusedMoEExpertsModular): def __init__( self, moe_config: FusedMoEConfig, diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index b9566a3a921a..02c31fd39dac 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -60,7 +60,7 @@ def is_valid_flashinfer_cutlass_fused_moe( return True -class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): +class FlashInferExperts(mk.FusedMoEExpertsModular): def __init__( self, moe_config: mk.FusedMoEConfig, diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py index 732ab8e929ca..6765e3613f7f 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py @@ -10,16 +10,6 @@ FusedMoEParallelConfig, RoutingMethodType, ) -from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input -from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - per_token_group_quant_fp8, -) -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - QuantKey, - kFp8Dynamic128Sym, - kFp8Static128BlockSym, - kFp8StaticTensorSym, -) from vllm.platforms import current_platform from vllm.utils.torch_utils import direct_register_custom_op @@ -39,49 +29,10 @@ 
def _supports_no_act_and_mul() -> bool: return True -def _supports_quant_scheme( - weight_key: QuantKey | None, - activation_key: QuantKey | None, -) -> bool: - """Supports Fp8 per-tensor and Fp8 block.""" - SUPPORTED_W_A = [ - (kFp8Static128BlockSym, kFp8Dynamic128Sym), - (kFp8StaticTensorSym, kFp8StaticTensorSym), - ] - return (weight_key, activation_key) in SUPPORTED_W_A - - def _supports_activation(activation: MoEActivation) -> bool: return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL] -def _supports_routing_method( - weight_key: QuantKey | None, - activation_key: QuantKey | None, - routing_method: RoutingMethodType, -) -> bool: - """Monolithic kernels need to express router support.""" - # NOTE(dbari): TopK routing could also be enabled, but need to validate models - # NOTE(dbari): Default is not implemented and should not be enabled until it is - if (weight_key, activation_key) == (kFp8Static128BlockSym, kFp8Dynamic128Sym): - # NOTE(rob): potentially allow others here. This is a conservative list. - return routing_method in [ - RoutingMethodType.DeepSeekV3, - RoutingMethodType.Renormalize, - RoutingMethodType.RenormalizeNaive, - ] - elif (weight_key, activation_key) == (kFp8StaticTensorSym, kFp8StaticTensorSym): - # NOTE(dbari): as above, potentially allow others here. - return routing_method in [ - RoutingMethodType.DeepSeekV3, - RoutingMethodType.Llama4, - RoutingMethodType.Renormalize, - RoutingMethodType.RenormalizeNaive, - ] - else: - raise ValueError("Unsupported quantization scheme.") - - def _supports_routing_method_bf16( routing_method: RoutingMethodType, ) -> bool: @@ -99,62 +50,6 @@ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bo return not moe_parallel_config.enable_eplb -def _supports_router_logits_dtype( - router_logits_dtype: torch.dtype | None, - routing_method: RoutingMethodType, -) -> bool: - """ - The FlashInfer TRTLLM FP8 kernel expects bfloat16 router_logits by default. - Only DeepSeekV3 routing supports float32 router_logits (which is converted - internally in the kernel). 
- """ - if router_logits_dtype == torch.float32: - # Only DeepSeekV3 routing handles float32 logits - # https://github.com/flashinfer-ai/flashinfer/issues/2469 - return routing_method == RoutingMethodType.DeepSeekV3 - return True - - -def is_supported_config_trtllm_fp8( - moe_config: FusedMoEConfig, - weight_key: QuantKey | None, - activation_key: QuantKey | None, - activation_format: mk.FusedMoEActivationFormat, -) -> tuple[bool, str | None]: - """ - This method mirrors mk.FusedMoEPermuteExpertsUnpermute.is_supported_config - """ - - def _make_reason(reason: str) -> str: - return f"kernel does not support {reason}" - - if not _supports_current_device(): - return False, _make_reason(f"current device {current_platform.device_name}") - elif not (moe_config.is_act_and_mul or _supports_no_act_and_mul()): - return False, _make_reason("no act_and_mul MLP layer") - elif not _supports_activation(moe_config.activation): - return False, _make_reason(f"{moe_config.activation} activation") - elif not _supports_quant_scheme(weight_key, activation_key): - return False, _make_reason(f"quantization scheme {weight_key}x{activation_key}") - elif not _supports_parallel_config(moe_config.moe_parallel_config): - return False, _make_reason(f"parallel config {moe_config.moe_parallel_config}") - elif not _supports_routing_method( - weight_key, activation_key, moe_config.routing_method - ): - return False, _make_reason(f"routing method {moe_config.routing_method}") - elif activation_format != mk.FusedMoEActivationFormat.Standard: - return False, _make_reason(f"activation format {activation_format}") - elif not _supports_router_logits_dtype( - moe_config.router_logits_dtype, moe_config.routing_method - ): - return False, _make_reason( - "float32 router_logits with non-DeepSeekV3 routing " - f"{moe_config.router_logits_dtype}x{moe_config.routing_method}" - ) - - return True, None - - def is_supported_config_trtllm_bf16( moe_config: FusedMoEConfig, activation_format: mk.FusedMoEActivationFormat, @@ -183,199 +78,6 @@ def _make_reason(reason: str) -> str: return True, None -def flashinfer_fused_moe_blockscale_fp8( - routing_logits: torch.Tensor, - routing_bias: torch.Tensor | None, - x: torch.Tensor, - w13_weight: torch.Tensor, - w13_weight_scale_inv: torch.Tensor, - w2_weight: torch.Tensor, - w2_weight_scale_inv: torch.Tensor, - global_num_experts: int, - top_k: int, - num_expert_group: int | None, - topk_group: int | None, - intermediate_size: int, - expert_offset: int, - local_num_experts: int, - block_shape: list[int], - routing_method_type: int, - routed_scaling: float | None = 1.0, -) -> torch.Tensor: - from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe - - num_expert_group = num_expert_group if num_expert_group is not None else 0 - topk_group = topk_group if topk_group is not None else 0 - assert top_k <= global_num_experts - assert top_k <= 10 - assert global_num_experts % 4 == 0 - assert block_shape == [128, 128] - # Routing kernel expects #experts <= #threads 512 - assert global_num_experts <= 512 - - # The DeepSeekV3 routing method requires float32 router logits. - if routing_method_type == RoutingMethodType.DeepSeekV3: - routing_logits = routing_logits.to(torch.float32) - - if routing_bias is not None: - routing_bias = routing_bias.to(x.dtype) - - a_q, a_sf = per_token_group_quant_fp8(x, block_shape[1]) - # NOTE: scales of hidden states have to be transposed! 
- a_sf_t = a_sf.t().contiguous() - return flashinfer_trtllm_fp8_block_scale_moe( - routing_logits=routing_logits, - routing_bias=routing_bias, - hidden_states=a_q, - hidden_states_scale=a_sf_t, - gemm1_weights=w13_weight, - gemm1_weights_scale=w13_weight_scale_inv, - gemm2_weights=w2_weight, - gemm2_weights_scale=w2_weight_scale_inv, - num_experts=global_num_experts, - top_k=top_k, - n_group=num_expert_group, - topk_group=topk_group, - intermediate_size=intermediate_size, - local_expert_offset=expert_offset, - local_num_experts=local_num_experts, - routed_scaling_factor=routed_scaling, - routing_method_type=routing_method_type, - use_shuffled_weight=False, - ) - - -def flashinfer_fused_moe_blockscale_fp8_fake( - routing_logits: torch.Tensor, - routing_bias: torch.Tensor | None, - x: torch.Tensor, - w13_weight: torch.Tensor, - w13_weight_scale_inv: torch.Tensor, - w2_weight: torch.Tensor, - w2_weight_scale_inv: torch.Tensor, - global_num_experts: int, - top_k: int, - num_expert_group: int, - topk_group: int, - intermediate_size: int, - expert_offset: int, - local_num_experts: int, - block_shape: list[int], - routing_method_type: int, - routed_scaling: float = 1.0, -) -> torch.Tensor: - return torch.empty_like(x) - - -# TODO(bnell): Does this really need to be a torch.op? -direct_register_custom_op( - op_name="flashinfer_fused_moe_blockscale_fp8", - op_func=flashinfer_fused_moe_blockscale_fp8, - fake_impl=flashinfer_fused_moe_blockscale_fp8_fake, - tags=(torch.Tag.needs_fixed_stride_order,), -) - - -def fi_trtllm_fp8_per_tensor_moe( - routing_logits: torch.Tensor, - routing_bias: torch.Tensor | None, - hidden_states: torch.Tensor, - input_scale: torch.Tensor, - gemm1_weights: torch.Tensor, - gemm2_weights: torch.Tensor, - output1_scales_scalar: torch.Tensor, - output1_scales_gate_scalar: torch.Tensor, - output2_scales_scalar: torch.Tensor, - num_experts: int, - top_k: int, - num_expert_group: int | None, - topk_group: int | None, - intermediate_size: int, - local_expert_offset: int, - local_num_experts: int, - use_routing_scales_on_input: bool, - routing_method_type: int, - activation_type: int, - routed_scaling_factor: float = 1.0, -) -> torch.Tensor: - num_expert_group = num_expert_group if num_expert_group is not None else 0 - topk_group = topk_group if topk_group is not None else 0 - - quant_hidden_states, _ = moe_kernel_quantize_input( - hidden_states, - input_scale, - quant_dtype=torch.float8_e4m3fn, - per_act_token_quant=False, - ) - - from flashinfer.fused_moe.core import ActivationType - - from vllm.utils.flashinfer import flashinfer_trtllm_fp8_per_tensor_scale_moe - - # The DeepSeekV3 routing method requires float32 router logits. 
- if routing_method_type == RoutingMethodType.DeepSeekV3: - routing_logits = routing_logits.to(torch.float32) - - return flashinfer_trtllm_fp8_per_tensor_scale_moe( - routing_logits=routing_logits, - routing_bias=routing_bias, - hidden_states=quant_hidden_states, - gemm1_weights=gemm1_weights, - output1_scales_scalar=output1_scales_scalar, - output1_scales_gate_scalar=output1_scales_gate_scalar, - gemm2_weights=gemm2_weights, - output2_scales_scalar=output2_scales_scalar, - num_experts=num_experts, - top_k=top_k, - n_group=num_expert_group, - topk_group=topk_group, - intermediate_size=intermediate_size, - local_expert_offset=local_expert_offset, - local_num_experts=local_num_experts, - routed_scaling_factor=routed_scaling_factor, - use_routing_scales_on_input=use_routing_scales_on_input, - routing_method_type=routing_method_type, - # TODO: enum type Required for flashinfer==0.6.3, remove with update - # https://github.com/flashinfer-ai/flashinfer/pull/2508 - activation_type=ActivationType(activation_type), - ) - - -def fi_trtllm_fp8_per_tensor_moe_fake( - routing_logits: torch.Tensor, - routing_bias: torch.Tensor | None, - hidden_states: torch.Tensor, - input_scale: torch.Tensor, - gemm1_weights: torch.Tensor, - gemm2_weights: torch.Tensor, - output1_scales_scalar: torch.Tensor, - output1_scales_gate_scalar: torch.Tensor, - output2_scales_scalar: torch.Tensor, - num_experts: int, - top_k: int, - num_expert_group: int | None, - topk_group: int | None, - intermediate_size: int, - local_expert_offset: int, - local_num_experts: int, - use_routing_scales_on_input: bool, - routing_method_type: int, - activation_type: int, - routed_scaling_factor: float = 1.0, -) -> torch.Tensor: - return torch.empty_like(hidden_states) - - -# TODO(bnell): Does this really need to be a torch.op? -direct_register_custom_op( - op_name="fi_trtllm_fp8_per_tensor_moe", - op_func=fi_trtllm_fp8_per_tensor_moe, - mutates_args=["hidden_states"], - fake_impl=fi_trtllm_fp8_per_tensor_moe_fake, - tags=(torch.Tag.needs_fixed_stride_order,), -) - - def flashinfer_fused_moe_bf16( routing_logits: torch.Tensor, routing_bias: torch.Tensor | None, diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 24ae2d3c82c6..68393f768dcc 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -489,7 +489,7 @@ def invoke_moe_batched_triton_kernel( ) -class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): +class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular): """ A reference prepare/finalize class that reorganizes the tokens into expert batched format, i.e. E x max_num_tokens x K. This is the format @@ -645,7 +645,7 @@ def finalize( ) -class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): +class NaiveBatchedExperts(mk.FusedMoEExpertsModular): """ A reference MoE expert class that operates on expert batched format, i.e. E x max_num_tokens x K. This is the format that the batched @@ -877,7 +877,7 @@ def batched_moe_kernel_quantize_input( return A_q, A_q_scale -class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): +class BatchedTritonExperts(mk.FusedMoEExpertsModular): """ A Triton based MoE expert class that operates on expert batched format, i.e. E x max_num_tokens x K. 
This is the format that the batched diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 4a8f31255ac6..280d090795e2 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -526,7 +526,7 @@ def batched_fused_marlin_moe( return output -class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute): +class MarlinExpertsBase(mk.FusedMoEExpertsModular): def __init__( self, moe_config: FusedMoEConfig, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 07a9a0a8b522..023cdd0b4340 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1736,7 +1736,7 @@ def fused_experts_impl( intermediate_cache3 = cache13[: M * top_k_num * K].view(M, top_k_num, K) # This needs separate memory since it's used concurrently with cache1 - activation_out_dim = mk.FusedMoEPermuteExpertsUnpermute.adjust_N_for_activation( + activation_out_dim = mk.FusedMoEExpertsModular.adjust_N_for_activation( N, activation_enum ) intermediate_cache2 = torch.empty( @@ -1924,7 +1924,7 @@ def fused_experts_impl( return out_hidden_states -class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): +class TritonExperts(mk.FusedMoEExpertsModular): """Triton-based fused MoE expert implementation.""" def __init__( diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py index ac7c71e52b2b..88cd173fe6a8 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py @@ -12,8 +12,8 @@ FusedMoEQuantConfig, ) from vllm.model_executor.layers.fused_moe.modular_kernel import ( - FusedMoEPermuteExpertsUnpermute, - FusedMoEPrepareAndFinalize, + FusedMoEExpertsModular, + FusedMoEPrepareAndFinalizeModular, ) from vllm.model_executor.layers.quantization.base_config import ( QuantizeMethodBase, @@ -27,19 +27,21 @@ def __init__(self, moe: FusedMoEConfig): super().__init__() self.moe: FusedMoEConfig = moe self.moe_quant_config: FusedMoEQuantConfig | None = None - self.moe_mk: mk.FusedMoEModularKernel | None = None + self.moe_kernel: mk.FusedMoEKernel | None = None @property def supports_internal_mk(self) -> bool: # NOTE(rob): temporary attribute to indicate support for # completed migration to the new internal MK interface. - return self.moe_mk is not None + return self.moe_kernel is not None @property def mk_owns_shared_expert(self) -> bool: # NOTE(rob): temporary attribute to indicate support for # completed migration to the new internal MK interface. 
- return self.moe_mk is not None and self.moe_mk.shared_experts is not None + return ( + self.moe_kernel is not None and self.moe_kernel.shared_experts is not None + ) @abstractmethod def create_weights( @@ -66,35 +68,25 @@ def uses_weight_scale_2_pattern(self) -> bool: def maybe_make_prepare_finalize( self, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, - ) -> FusedMoEPrepareAndFinalize | None: + ) -> FusedMoEPrepareAndFinalizeModular | None: from .all2all_utils import maybe_make_prepare_finalize - return maybe_make_prepare_finalize( + pf = maybe_make_prepare_finalize( self.moe, self.moe_quant_config, routing_tables ) + assert pf is None or isinstance(pf, FusedMoEPrepareAndFinalizeModular) + return pf def select_gemm_impl( self, - prepare_finalize: FusedMoEPrepareAndFinalize, + prepare_finalize: FusedMoEPrepareAndFinalizeModular, layer: torch.nn.Module, - ) -> FusedMoEPermuteExpertsUnpermute: + ) -> FusedMoEExpertsModular: # based on the all2all implementation, select the appropriate # gemm implementation - raise NotImplementedError( - f"{self.__class__.__name__} must select appropriate gemm " - "implementation based on the prepare_finalize" - ) - - def prepare_dp_allgather_tensor( - self, - layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - ) -> tuple[torch.Tensor, list[torch.Tensor]]: - """Hook to prepare tensors and extra tensors for DP allgather + EP dispatch.""" - raise NotImplementedError( - "Method 'prepare_dp_allgather_tensor' is not implemented in " - f"{self.__class__.__name__}." + raise ValueError( + f"{self.__class__.__name__} uses the new modular kernel initialization " + "logic. This function should not be called." ) @abstractmethod @@ -105,8 +97,8 @@ def get_fused_moe_quant_config( @property def topk_indices_dtype(self) -> torch.dtype | None: - if self.moe_mk is not None: - return self.moe_mk.prepare_finalize.topk_indices_dtype() + if self.moe_kernel is not None: + return self.moe_kernel.prepare_finalize.topk_indices_dtype() return None @property @@ -119,7 +111,12 @@ def method_name(self) -> str: @property def is_monolithic(self) -> bool: - return False + if self.moe_kernel is None: + if hasattr(self, "experts_cls"): + return self.experts_cls.is_monolithic() + else: + return False + return self.moe_kernel.is_monolithic def apply( self, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py index 187464ce8e09..0065c11f3163 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py @@ -13,8 +13,8 @@ FusedMoEMethodBase, ) from vllm.model_executor.layers.fused_moe.modular_kernel import ( - FusedMoEModularKernel, - FusedMoEPrepareAndFinalize, + FusedMoEKernel, + FusedMoEPrepareAndFinalizeModular, ) logger = init_logger(__name__) @@ -26,15 +26,15 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): # --8<-- [end:modular_fused_moe] def __init__( - self, old_quant_method: FusedMoEMethodBase, experts: FusedMoEModularKernel + self, old_quant_method: FusedMoEMethodBase, moe_kernel: FusedMoEKernel ): super().__init__(old_quant_method.moe) self.moe_quant_config = old_quant_method.moe_quant_config - self.moe_mk = experts + self.moe_kernel = moe_kernel self.disable_expert_map = getattr( old_quant_method, "disable_expert_map", - not self.moe_mk.supports_expert_map(), + not 
self.moe_kernel.supports_expert_map(), ) self.old_quant_method = old_quant_method logger.debug("Swapping out %s", self.old_quant_method.__class__.__name__) @@ -43,13 +43,13 @@ def __init__( def make( moe_layer: torch.nn.Module, old_quant_method: FusedMoEMethodBase, - prepare_finalize: FusedMoEPrepareAndFinalize, + prepare_finalize: FusedMoEPrepareAndFinalizeModular, shared_experts: torch.nn.Module | None, inplace: bool = False, ) -> "FusedMoEModularMethod": return FusedMoEModularMethod( old_quant_method, - FusedMoEModularKernel( + FusedMoEKernel( prepare_finalize, old_quant_method.select_gemm_impl(prepare_finalize, moe_layer), shared_experts, @@ -90,8 +90,8 @@ def apply( topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - assert self.moe_mk is not None - return self.moe_mk( + assert self.moe_kernel is not None + return self.moe_kernel.apply( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py index 2fcb7f193785..8d6f716e2632 100644 --- a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py +++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py @@ -511,7 +511,7 @@ def make_routing_data( return routing_data, gather_indx, scatter_indx -class BaseOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute): +class BaseOAITritonExperts(mk.FusedMoEExpertsModular): @staticmethod def _supports_current_device() -> bool: raise NotImplementedError( diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 043b5ef2669b..7b49282fd2ca 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -20,6 +20,7 @@ FusedMoEConfig, FusedMoEParallelConfig, FusedMoEQuantConfig, + RoutingMethodType, ) from vllm.model_executor.layers.fused_moe.utils import ( _resize_cache, @@ -56,25 +57,25 @@ # MoE kernel implementations. # # The following main classes are defined: -# * FusedMoEPrepareAndFinalize - an abstract base class for preparation of MoE +# * FusedMoEPrepareAndFinalizeModular - an abstract base class for preparation of MoE # inputs (e.g. quantization, distribution) and finalization of Moe outputs. # The prepare method must take care of any needed quantization and the -# finalize method, informed by the FusedMoEPermuteExpertsUnpermute method, +# finalize method, informed by the FusedMoEExpertsModular method, # may apply weights and/or do the final reduction of the output. -# * FusedMoEPermuteExpertsUnpermute - an abstract base class for the main fused +# * FusedMoEExpertsModular - an abstract base class for the main fused # MoE operation, i.e matmul + act_mul + optionally quant + matmul. -# Some FusedMoEPermuteExpertsUnpermute implementations may choose to do +# Some FusedMoEExpertsModular implementations may choose to do # the weight application and/or reduction. The class communicates this # to [Finalize] via a TopKWeightAndReduce object. # * FusedMoEModularKernel - an interface class that combines a -# FusedMoEPrepareAndFinalize and a FusedMoEPermuteExpertsUnpermute to +# FusedMoEPrepareAndFinalizeModular and a FusedMoEExpertsModular to # provide the standard fused MoE kernel interface. 
# * TopKWeightAndReduce - A TopKWeightAndReduce implementation chosen -# by the FusedMoEPermuteExpertsUnpermute implementation that is passed +# by the FusedMoEExpertsModular implementation that is passed # on to [Finalize]. # # [Quantize-Prepare] and [Finalize] functionality are bundled into a single -# class `FusedMoEPrepareAndFinalize` since they could use collective +# class `FusedMoEPrepareAndFinalizeModular` since they could use collective # communication mechanisms that need to be consistent. # @@ -155,25 +156,96 @@ def apply( torch.Tensor | None, ] +# +# PrepareResultType is a tuple of: +# - quantized + dispatched a. +# - quantized + dispatched a1_scales. +# - dispatched router logits. +# +# See `prepare_monolithic` method below. +# +PrepareMonolithicResultType = tuple[ + torch.Tensor, + torch.Tensor | None, + torch.Tensor, +] + ReceiverType = Callable[[], PrepareResultType] +################################################################################ +# Prepare/Finalize +################################################################################ + -# TODO: pass FusedMoEParallelConfig in as ctor parameter? class FusedMoEPrepareAndFinalize(ABC): """ An abstract base class for the [Quantize-Prepare] and [Finalize] steps described above. + + There are two variants of this class: + * FusedMoEPrepareAndFinalizeModular - this operates on topk ids and weights + * FusedMoEPrepareAndFinalizeMonolithic - the operates on router_logits """ - def post_init_setup(self, fused_experts: "FusedMoEPermuteExpertsUnpermute"): + def post_init_setup(self, fused_experts: "FusedMoEExperts"): """ - Initialize FusedMoEPrepareAndFinalize settings that depend on - FusedMoEPermuteExpertsUnpermute experts object. - The FusedMoEPrepareAndFinalize implementations that have such + Initialize FusedMoEPrepareAndFinalizeModular settings that depend on + FusedMoEExpertsModular experts object. + The FusedMoEPrepareAndFinalizeModular implementations that have such dependencies may choose to override this function. """ return + @property + @abstractmethod + def activation_format(self) -> FusedMoEActivationFormat: + """ + A property indicating the output format of the activations for the + 'prepare' method. + """ + raise NotImplementedError + + @abstractmethod + def topk_indices_dtype(self) -> torch.dtype | None: + """ + The PrepareFinalize All2All implementations generally constrain the + dtype of the topk_ids they support. This function returns the + required topk indices dtype so it can be respected. + Return None if there are no such restrictions. + """ + raise NotImplementedError + + @abstractmethod + def max_num_tokens_per_rank(self) -> int | None: + """ + Some PrepareFinalize All2All implementations are batched. Meaning, + they can process only as set of tokens at a time. This + function returns the batch size i.e the maximum number of tokens + the implementation can process at a time. + Return None if there are no such restrictions. + """ + raise NotImplementedError + + @abstractmethod + def num_dispatchers(self) -> int: + raise NotImplementedError + + @abstractmethod + def output_is_reduced(self) -> bool: + """ + Indicates whether or not the output of finalize is reduced across all + ranks. + """ + raise NotImplementedError + + +# TODO: pass FusedMoEParallelConfig in as ctor parameter? +class FusedMoEPrepareAndFinalizeModular(FusedMoEPrepareAndFinalize): + """ + An abstract base class for the [Quantize-Prepare] and [Finalize] steps + described above for the Modular case. 
+ """ + @abstractmethod def prepare( self, @@ -198,7 +270,7 @@ def prepare( activations, before quantization + dispatching. - quant_config: Quantization info provided by the fused experts. - defer_input_quant: Runtime parameter indicating whether or not to - defer input quantization to the FusedMoEPermuteExpertsUnpermute + defer input quantization to the FusedMoEExpertsModular in cases where the compute kernel expects unquantized inputs Returns a tuple of: @@ -245,7 +317,7 @@ def prepare_async( - apply_router_weight_on_input: When True, apply the weights to the activations, before quantization + dispatching. - defer_input_quant: Runtime parameter indicating whether or not to - defer input quantization to the FusedMoEPermuteExpertsUnpermute + defer input quantization to the FusedMoEExpertsModular in cases where the compute kernel expects unquantized inputs Returns a callback or a hook callback pair that when invoked waits for @@ -338,56 +410,58 @@ def finalize_async( """ raise NotImplementedError - @property - @abstractmethod - def activation_format(self) -> FusedMoEActivationFormat: - """ - A property indicating the output format of the activations for the - 'prepare' method. - """ - raise NotImplementedError + +class FusedMoEPrepareAndFinalizeMonolithic(FusedMoEPrepareAndFinalize): + """ + An abstract base class for the [Quantize-Prepare] and [Finalize] steps + described above for the monolithic case. + """ @abstractmethod - def topk_indices_dtype(self) -> torch.dtype | None: + def prepare( + self, + a1: torch.Tensor, + router_logits: torch.Tensor, + quant_config: FusedMoEQuantConfig, + defer_input_quant: bool = False, + ) -> PrepareMonolithicResultType: """ - The PrepareFinalize All2All implementations generally constrain the - dtype of the topk_ids they support. This function returns the - required topk indices dtype so it can be respected. - Return None if there are no such restrictions. + Optional method for subclasses compatible with monolithic + FusedMoEExpertsModular kernels. + + Perform any quantization (and/or) dispatching needed for this kernel. + - a1: The (unquantized) input to the MoE layer. + - quant_config: Quantization info provided by the fused experts. + - defer_input_quant: Runtime parameter indicating whether or not to + defer input quantization to the FusedMoEExpertsModular + + Returns a tuple of: + - quantized + dispatched a. + - Optional quantized + dispatched a1_scales. """ raise NotImplementedError @abstractmethod - def max_num_tokens_per_rank(self) -> int | None: + def finalize(self, fused_expert_output: torch.Tensor) -> torch.Tensor: """ - Some PrepareFinalize All2All implementations are batched. Meaning, - they can process only as set of tokens at a time. This - function returns the batch size i.e the maximum number of tokens - the implementation can process at a time. - Return None if there are no such restrictions. + Optional method for subclasses compatible with monolithic + FusedMoEExpertsModular kernels. + + Perform any combine plus apply weights and perform a reduction on the + fused experts output. + - fused_expert_output: The unweighted, unreduced output of the fused + experts, it will have (M, topk, K) shape. """ raise NotImplementedError - @abstractmethod - def num_dispatchers(self) -> int: - raise NotImplementedError - @abstractmethod - def output_is_reduced(self) -> bool: - """ - Indicates whether or not the output of finalize is reduced across all - ranks. 
- """ - raise NotImplementedError +################################################################################ +# Experts +################################################################################ # TODO: add supported activations method (return string) -class FusedMoEPermuteExpertsUnpermute(ABC): - """ - An abstract base class for the [Permute-Experts-Unpermute] step described - above. - """ - +class FusedMoEExperts(ABC): def __init__( self, moe_config: FusedMoEConfig, @@ -419,6 +493,10 @@ def __init__( self.max_num_tokens = max_num_tokens self.num_dispatchers = num_dispatchers + @staticmethod + def is_monolithic() -> bool: + raise NotImplementedError("Implemented by subclasses.") + @property def expects_unquantized_inputs(self) -> bool: """ @@ -439,49 +517,6 @@ def activation_format() -> FusedMoEActivationFormat: """ raise NotImplementedError - def moe_problem_size( - self, - a1: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_ids: torch.Tensor, - ) -> tuple[int, int, int, int, int]: - """ - Extract the MoE problem size from the given tensor arguments: - - a: The hidden states, input to the MoE layer. - - w1: The first set of expert weights. - - w2: The second set of expert weights. - - topk_ids: The topk ids. - - Note: extracting the problem shape from the weight and activation - tensors is not obvious. It needs to be done this way specifically - due to subtle issues with particular kernels, e.g. the int4 kernels - divide the trailing dimension by two, so it's not "correct" to - extract N or K from the trailing dimension of w1 or w2. Similarly, - some kernels transpose the weights, so this needs to be kept in mind. - - Note: This implementation covers most cases. However, if experts - require a specialized implementation, like MarlinExperts, they are free - to override this function. - """ - assert w1.dim() == 3 and w2.dim() == 3 - E, N, _ = w1.size() - K = a1.size(-1) - - if a1.dim() == 2: - # Make sure we are using the correct a1 (pre-permute). - assert topk_ids.size(0) == a1.size(0), f"{topk_ids.size(0)} != {a1.size(0)}" - M = a1.size(0) - else: - assert a1.dim() == 3 - assert a1.size(0) == E, f"{a1.size(0)} == {E}" - M = a1.size(1) # This is max_num_tokens - - assert topk_ids.dim() == 2 - topk = topk_ids.size(1) - - return E, M, N, K, topk - # # Various helpers for registering support for various features. # Used by the oracle to select a particular kernel for a deployment. 
@@ -489,7 +524,7 @@ def moe_problem_size( @staticmethod def is_supported_config( - cls: type["FusedMoEPermuteExpertsUnpermute"], + cls: type["FusedMoEExperts"], moe_config: FusedMoEConfig, weight_key: QuantKey | None, activation_key: QuantKey | None, @@ -512,6 +547,21 @@ def _make_reason(reason: str) -> str: return False, _make_reason( f"parallel config {moe_config.moe_parallel_config}" ) + elif not cls._supports_routing_method( + moe_config.routing_method, weight_key, activation_key + ): + return False, _make_reason(f"routing method {moe_config.routing_method}") + elif not cls._supports_router_logits_dtype( + moe_config.router_logits_dtype, + moe_config.routing_method, + ): + return False, _make_reason( + f"router logits dtype {moe_config.router_logits_dtype}" + ) + elif not cls._supports_shape(moe_config.hidden_dim): + return False, _make_reason( + f"{moe_config.hidden_dim} hidden dim is not supported" + ) elif activation_format != cls.activation_format(): return False, _make_reason(f"{activation_format.value} activation format") return True, None @@ -554,10 +604,48 @@ def _supports_activation(activation: MoEActivation) -> bool: @abstractmethod def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: """ - Whether the kernel supports deployment in expert parallel. + Whether the kernel supports deployment in particular parallel config. + + Can be overriden if a kernel does not support EP, SP or some other + configuration. """ raise NotImplementedError + @staticmethod + def _supports_routing_method( + routing_method: RoutingMethodType, + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + """ + Whether the kernel supports a routing method (e.g. GroupedTopK). + + Can be overriden by monolithic kernels that execute the router + in addition to the experts if certain routers are not supported. + """ + return True + + @staticmethod + def _supports_router_logits_dtype( + router_logits_dtype: torch.dtype | None, + routing_method: RoutingMethodType, + ) -> bool: + """ + Whether a kernel supports a particular dtype for router logits input. + + Can be overriden by monolithic kernels that execute the router + in addition to the experts if certain dtypes are not supported. + """ + return True + + @staticmethod + def _supports_shape(hidden_dim: int) -> bool: + """ + Whether a kernel supports a particular shape. Can be overridden if a kernel + has specific shape requirements. + """ + return True + # # Various helpers for accessing quantization parameters from the # quant_config. @@ -654,6 +742,65 @@ def supports_packed_ue8m0_act_scales(self) -> bool: """ return False + def enable_chunking(self): + return ( + envs.VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING and self.supports_chunking() + ) + + +class FusedMoEExpertsModular(FusedMoEExperts): + """ + An abstract base class for the [Permute-Experts-Unpermute] step described + above. + """ + + @staticmethod + def is_monolithic() -> bool: + return False + + def moe_problem_size( + self, + a1: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + ) -> tuple[int, int, int, int, int]: + """ + Extract the MoE problem size from the given tensor arguments: + - a: The hidden states, input to the MoE layer. + - w1: The first set of expert weights. + - w2: The second set of expert weights. + - topk_ids: The topk ids. + + Note: extracting the problem shape from the weight and activation + tensors is not obvious. 
It needs to be done this way specifically + due to subtle issues with particular kernels, e.g. the int4 kernels + divide the trailing dimension by two, so it's not "correct" to + extract N or K from the trailing dimension of w1 or w2. Similarly, + some kernels transpose the weights, so this needs to be kept in mind. + + Note: This implementation covers most cases. However, if experts + require a specialized implementation, like MarlinExperts, they are free + to override this function. + """ + assert w1.dim() == 3 and w2.dim() == 3 + E, N, _ = w1.size() + K = a1.size(-1) + + if a1.dim() == 2: + # Make sure we are using the correct a1 (pre-permute). + assert topk_ids.size(0) == a1.size(0), f"{topk_ids.size(0)} != {a1.size(0)}" + M = a1.size(0) + else: + assert a1.dim() == 3 + assert a1.size(0) == E, f"{a1.size(0)} == {E}" + M = a1.size(1) # This is max_num_tokens + + assert topk_ids.dim() == 2 + topk = topk_ids.size(1) + + return E, M, N, K, topk + def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype: """ Workspace type: The dtype to use for the workspace tensors. @@ -726,11 +873,7 @@ def activation( ) -> None: apply_moe_activation(activation, output, input) - def enable_chunking(self): - return ( - envs.VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING and self.supports_chunking() - ) - + @abstractmethod def finalize_weight_and_reduce_impl(self) -> TopKWeightAndReduce: raise NotImplementedError @@ -791,6 +934,67 @@ def apply( raise NotImplementedError +class FusedMoEExpertsMonolithic(FusedMoEExperts): + """ + An abstract base class for the [Permute-Experts-Unpermute] step described + above, but with the monolithic interface (accepts router logits + rather than topk ids and weights). + """ + + @staticmethod + def _supports_routing_method( + routing_method: RoutingMethodType, + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + """ + Whether the kernel supports a routing method (e.g. GroupedTopK). + + Monolithic kernels should explicitly opt-in to support. + """ + raise NotImplementedError + + @staticmethod + def _supports_router_logits_dtype( + router_logits_dtype: torch.dtype | None, + routing_method: RoutingMethodType, + ) -> bool: + """ + Whether the kernel supports a dtype for router logits. + + Modular kernels should opt-in to support. + """ + raise NotImplementedError + + @staticmethod + def is_monolithic() -> bool: + return True + + def apply( + self, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + router_logits: torch.Tensor, + activation: MoEActivation, + global_num_experts: int, + expert_map: torch.Tensor | None, + a1q_scale: torch.Tensor | None, + apply_router_weight_on_input: bool, + # grouped topk + fused topk bias parameters + num_expert_group: int | None = None, + e_score_correction_bias: torch.Tensor | None = None, + routed_scaling_factor: float | None = None, + topk_group: int | None = None, + ) -> torch.Tensor: + """ + Same as apply(), except uses router_logits as opposed + to the topk_ids and topk_weights. This is useful for kernels + with fused router and fused_experts (e.g. FLASHINFER_TRTLLM). 
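+
+        Unlike FusedMoEExpertsModular.apply, this takes no output buffer or
+        workspaces and returns the result tensor directly.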
+ """ + raise NotImplementedError + + def _slice_scales( scales: torch.Tensor | None, start: int, end: int ) -> torch.Tensor | None: @@ -802,75 +1006,32 @@ def _slice_scales( return None -@final -class FusedMoEModularKernel(torch.nn.Module): - """ - This class combines a FusedMoEPrepareAndFinalize instance and - a FusedMoEPermuteExpertsUnpermute to provide an interface that - is compatible with the `fused_experts` function in fused_moe.py. +################################################################################ +# Kernel +################################################################################ - It takes care of managing any required scratch space. - - Note: Instances of this class should only be used for a single model - layer due to any layer specific state that may be used by the component - objects. - """ +@final +class FusedMoEKernelModularImpl: def __init__( self, - prepare_finalize: FusedMoEPrepareAndFinalize, - fused_experts: FusedMoEPermuteExpertsUnpermute, + prepare_finalize: FusedMoEPrepareAndFinalizeModular, + fused_experts: FusedMoEExpertsModular, shared_experts: torch.nn.Module | None = None, moe_parallel_config: FusedMoEParallelConfig | None = None, inplace: bool = False, ): - super().__init__() self.prepare_finalize = prepare_finalize self.fused_experts = fused_experts self.shared_experts = shared_experts + self.moe_parallel_config = moe_parallel_config self.inplace = inplace - - # prefer an explicit FusedMoEParallelConfig when available (from - # FusedMoE layers / tests). - # if not provided, assume this kernel is - # running in a non-DP+EP context - self.moe_parallel_config: FusedMoEParallelConfig | None = moe_parallel_config self.is_dp_ep = ( moe_parallel_config is not None and moe_parallel_config.dp_size > 1 and moe_parallel_config.use_ep ) - self._post_init_setup() - assert ( - prepare_finalize.activation_format == fused_experts.activation_format() - ), ( - f"{prepare_finalize.__class__.__name__}." - f"{prepare_finalize.activation_format} == " - f"{fused_experts.__class__.__name__}." - f"{fused_experts.activation_format()}" - ) - - def _post_init_setup(self): - """ - Resolve any leftover setup dependencies between self.prepare_finalize - and self.fused_experts here. - """ - self.prepare_finalize.post_init_setup(self.fused_experts) - - def supports_expert_map(self) -> bool: - """ - A flag indicating whether or not this class supports expert maps. - """ - return self.fused_experts.supports_expert_map() - - def output_is_reduced(self) -> bool: - """ - Indicates whether or not the output of fused MoE kernel - is reduced across all ranks. - """ - return self.prepare_finalize.output_is_reduced() - def _chunk_info(self, M: int) -> tuple[int, int]: """ Compute number of chunks and chunk size for given M. @@ -919,7 +1080,7 @@ def _allocate_buffers( workspace_dtype = self.fused_experts.workspace_dtype(out_dtype) # Force worst-case allocation in profiling run for - # "mk.FusedMoEModularKernel.Standard" formats where this is only bounded + # "mk.FusedMoEKernel.Standard" formats where this is only bounded # by `VLLM_FUSED_MOE_CHUNK_SIZE` and may not be seen during profiling with # DP+EP due to the random token routing. 
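As a side note on the chunk bound referenced here, a standalone sketch of the ceil-division arithmetic behind splitting M tokens into chunks no larger than a fixed cap. The helper name and the 8192 cap are made up for illustration; this is not vLLM's _chunk_info.

def chunk_info(M: int, chunk_cap: int = 8192) -> tuple[int, int]:
    # Illustrative only: number of chunks and worst-case per-chunk size,
    # mirroring the idea behind VLLM_FUSED_MOE_CHUNK_SIZE.
    num_chunks = -(-M // chunk_cap)      # ceil(M / chunk_cap)
    chunk_size = min(M, chunk_cap)       # per-chunk worst case
    return num_chunks, chunk_size


assert chunk_info(20000, 8192) == (3, 8192)
assert chunk_info(100, 8192) == (1, 100)

The profiling-run special case above exists precisely because the second value is the worst case: a profile run must allocate for the cap even if the sampled batch never reaches it.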
is_profile_run = ( @@ -1313,13 +1474,13 @@ def _finalize( assert shared_output is not None return shared_output, output - def forward( + def apply( self, hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, + topk_weights: torch.Tensor, activation: MoEActivation = MoEActivation.SILU, global_num_experts: int = -1, expert_map: torch.Tensor | None = None, @@ -1334,8 +1495,7 @@ def forward( - hidden_states: (torch.Tensor): The input tensor to the MoE layer. - w1 (torch.Tensor): The first set of expert weights. - w2 (torch.Tensor): The second set of expert weights. - - topk_weights (torch.Tensor): The topk weights applied at the end of - the layer. + - topk_weights (torch.Tensor): The topk weights applied at the end of the layer. - topk_ids (torch.Tensor): A map of row to expert id. - activation (MoEActivation): The activation function to apply after the first MoE layer. @@ -1354,7 +1514,6 @@ def forward( Returns: - torch.Tensor: The output tensor after applying the MoE layer. """ - if self.inplace: assert self.shared_experts is None assert not disable_inplace() @@ -1400,3 +1559,206 @@ def forward( apply_router_weight_on_input, shared_experts_input=shared_experts_input, ) + + +@final +class FusedMoEKernelMonolithicImpl: + def __init__( + self, + prepare_finalize: FusedMoEPrepareAndFinalizeMonolithic, + fused_experts: FusedMoEExpertsMonolithic, + ): + self.prepare_finalize = prepare_finalize + self.fused_experts = fused_experts + + def apply( + self, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + router_logits: torch.Tensor, + activation: MoEActivation, + global_num_experts: int, + expert_map: torch.Tensor | None, + apply_router_weight_on_input: bool, + # grouped topk + fused topk bias parameters + num_expert_group: int | None = None, + e_score_correction_bias: torch.Tensor | None = None, + routed_scaling_factor: float | None = None, + topk_group: int | None = None, + ) -> torch.Tensor: + """ + Same as forward(), except uses router_logits as opposed + to the topk_ids and topk_weights. This is used for kernels + that have fused router + experts (e.g. FLASHINFER_TRTLLM). + """ + + # TODO(rob): add inplace support. + a1q, a1q_scale, router_logits = self.prepare_finalize.prepare( + hidden_states, + router_logits=router_logits, + quant_config=self.fused_experts.quant_config, + defer_input_quant=self.fused_experts.expects_unquantized_inputs, + ) + + fused_out = self.fused_experts.apply( + hidden_states=a1q, + w1=w1, + w2=w2, + router_logits=router_logits, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + a1q_scale=a1q_scale, + # grouped topk + fused topk bias parameters + num_expert_group=num_expert_group, + e_score_correction_bias=e_score_correction_bias, + routed_scaling_factor=routed_scaling_factor, + topk_group=topk_group, + ) + + output = self.prepare_finalize.finalize(fused_out) + + return output + + +@final +class FusedMoEKernel: + def __init__( + self, + prepare_finalize: FusedMoEPrepareAndFinalize, + fused_experts: FusedMoEExperts, + shared_experts: torch.nn.Module | None = None, + moe_parallel_config: FusedMoEParallelConfig | None = None, + inplace: bool = False, + ): + super().__init__() + self.shared_experts = shared_experts # NOTE: check if we can remove + + # Initialize the implementation (monolithic or modular). 
+ self.impl: FusedMoEKernelModularImpl | FusedMoEKernelMonolithicImpl + if isinstance( + prepare_finalize, FusedMoEPrepareAndFinalizeModular + ) and isinstance(fused_experts, FusedMoEExpertsModular): + self.impl = FusedMoEKernelModularImpl( + prepare_finalize, + fused_experts, + shared_experts, + moe_parallel_config, + inplace, + ) + + elif isinstance( + prepare_finalize, FusedMoEPrepareAndFinalizeMonolithic + ) and isinstance(fused_experts, FusedMoEExpertsMonolithic): + assert shared_experts is None + assert not inplace + self.impl = FusedMoEKernelMonolithicImpl( + prepare_finalize, + fused_experts, + ) + + else: + raise ValueError( + "prepare_finalize and fused_experts must both be either monolithic " + f"or non-monolithic but got {prepare_finalize.__class__.__name__} " + f"and {fused_experts.__class__.__name__}" + ) + + self._post_init_setup() + + @property + def is_monolithic(self) -> bool: + return isinstance(self.impl, FusedMoEKernelMonolithicImpl) + + @property + def prepare_finalize(self) -> FusedMoEPrepareAndFinalize: + return self.impl.prepare_finalize + + @property + def fused_experts(self) -> FusedMoEExperts: + return self.impl.fused_experts + + def _post_init_setup(self): + """ + Resolve any leftover setup dependencies between self.prepare_finalize + and self.fused_experts here. + """ + self.prepare_finalize.post_init_setup(self.impl.fused_experts) + assert ( + self.prepare_finalize.activation_format + == self.fused_experts.activation_format() + ) + + def supports_expert_map(self) -> bool: + """ + A flag indicating whether or not this class supports expert maps. + """ + return self.fused_experts.supports_expert_map() + + def output_is_reduced(self) -> bool: + """ + Indicates whether or not the output of fused MoE kernel + is reduced across all ranks. 
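For orientation, a sketch of the two call paths a caller of the new wrapper takes, based on the constructor and methods in this hunk. The tensors and arguments below are placeholders rather than a working configuration.

# Placeholder usage sketch; prepare_finalize / fused_experts / tensors are
# assumed to already exist and match the selected (modular or monolithic) pair.
kernel = FusedMoEKernel(prepare_finalize, fused_experts, shared_experts=None)

if kernel.is_monolithic:
    # Monolithic path: the kernel consumes raw router logits.
    out = kernel.apply_monolithic(
        hidden_states, w1, w2,
        router_logits=router_logits,
        activation=MoEActivation.SILU,
        global_num_experts=global_num_experts,
        expert_map=None,
        apply_router_weight_on_input=False,
    )
else:
    # Modular path: top-k selection happens outside the kernel.
    out = kernel.apply(
        hidden_states, w1, w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        activation=MoEActivation.SILU,
        global_num_experts=global_num_experts,
        expert_map=None,
        apply_router_weight_on_input=False,
    )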
+ """ + return self.prepare_finalize.output_is_reduced() + + def apply_monolithic( + self, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + router_logits: torch.Tensor | tuple[torch.Tensor, torch.Tensor], + activation: MoEActivation, + global_num_experts: int, + expert_map: torch.Tensor | None, + apply_router_weight_on_input: bool, + # grouped topk + fused topk bias parameters + num_expert_group: int | None = None, + e_score_correction_bias: torch.Tensor | None = None, + routed_scaling_factor: float | None = None, + topk_group: int | None = None, + ) -> torch.Tensor: + assert isinstance(self.impl, FusedMoEKernelMonolithicImpl) + return self.impl.apply( + hidden_states=hidden_states, + w1=w1, + w2=w2, + router_logits=router_logits, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + num_expert_group=num_expert_group, + e_score_correction_bias=e_score_correction_bias, + routed_scaling_factor=routed_scaling_factor, + topk_group=topk_group, + ) + + def apply( + self, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: MoEActivation, + global_num_experts: int, + expert_map: torch.Tensor | None, + apply_router_weight_on_input: bool, + shared_experts_input: torch.Tensor | None = None, + ) -> torch.Tensor: + assert isinstance(self.impl, FusedMoEKernelModularImpl) + return self.impl.apply( + hidden_states=hidden_states, + w1=w1, + w2=w2, + topk_weights=topk_weights, + topk_ids=topk_ids, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + shared_experts_input=shared_experts_input, + ) diff --git a/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py index dc0f32dc1992..164605dde3c0 100644 --- a/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py @@ -12,7 +12,7 @@ logger = init_logger(__name__) -class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): +class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular): """ Prepare/Finalize using MoRI kernels. 
""" diff --git a/vllm/model_executor/layers/fused_moe/oracle/fp8.py b/vllm/model_executor/layers/fused_moe/oracle/fp8.py index 9edd15eede63..0ed159b93695 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/fp8.py +++ b/vllm/model_executor/layers/fused_moe/oracle/fp8.py @@ -18,13 +18,9 @@ fp8_w8a8_moe_quant_config, fp8_w8a16_moe_quant_config, ) -from vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe import ( - is_supported_config_trtllm_fp8, -) from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( FlashinferMoeBackend, get_flashinfer_moe_backend, - make_fp8_moe_alpha_scales_for_fi, prepare_fp8_moe_layer_for_fi, ) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( @@ -103,9 +99,13 @@ def _move_to_front(backends: list[Fp8MoeBackend], backend: Fp8MoeBackend) -> Non def backend_to_kernel_cls( backend: Fp8MoeBackend, -) -> type[mk.FusedMoEPermuteExpertsUnpermute]: +) -> type[mk.FusedMoEExperts]: if backend == Fp8MoeBackend.FLASHINFER_TRTLLM: - raise NotImplementedError + from vllm.model_executor.layers.fused_moe.experts.trtllm_fp8_moe import ( # noqa: E501 + TrtLlmFp8Experts, + ) + + return TrtLlmFp8Experts elif backend == Fp8MoeBackend.FLASHINFER_CUTLASS: from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( @@ -205,13 +205,11 @@ def select_fp8_moe_backend( weight_key: QuantKey | None, activation_key: QuantKey | None, allow_vllm_cutlass: bool = False, -) -> tuple[Fp8MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute] | None]: +) -> tuple[Fp8MoeBackend, type[mk.FusedMoEExperts] | None]: """ Select the primary FP8 MoE backend Note: Shape-specific fallbacks may still occur at runtime. """ - k_cls: type[mk.FusedMoEPermuteExpertsUnpermute] | None = None - if config.is_lora_enabled: return Fp8MoeBackend.TRITON, backend_to_kernel_cls(Fp8MoeBackend.TRITON) @@ -252,7 +250,7 @@ def _return_or_raise( weight_key: QuantKey | None, activation_key: QuantKey | None, activation_format: mk.FusedMoEActivationFormat, - ) -> tuple[Fp8MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute]]: + ) -> tuple[Fp8MoeBackend, type[mk.FusedMoEExperts]]: k_cls = backend_to_kernel_cls(backend) supported, reason = k_cls.is_supported_config( k_cls, config, weight_key, activation_key, activation_format @@ -287,16 +285,6 @@ def _return_or_raise( "vLLM CUTLASS FP8 MoE backend is disabled for this configuration." ) - # Handle FLASHINFER_TRTLLM specially (no kernel class). - if requested_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: - supported, reason = is_supported_config_trtllm_fp8( - config, weight_key, activation_key, activation_format - ) - if supported: - logger.info_once(_make_log_backend(requested_backend)) - return requested_backend, None - raise ValueError(_make_log_unsupported(requested_backend, reason)) - return _return_or_raise( requested_backend, config, weight_key, activation_key, activation_format ) @@ -311,51 +299,32 @@ def _return_or_raise( elif envs.is_set("VLLM_FLASHINFER_MOE_BACKEND"): # If user is explicit about backend, validate it. 
fi_backend = get_flashinfer_moe_backend() - - if fi_backend == FlashinferMoeBackend.TENSORRT_LLM: - backend = Fp8MoeBackend.FLASHINFER_TRTLLM - supported, reason = is_supported_config_trtllm_fp8( - config, weight_key, activation_key, activation_format - ) - if supported: - logger.info_once(_make_log_backend(backend)) - return backend, None - else: - raise ValueError(_make_log_unsupported(backend, reason)) - - elif fi_backend == FlashinferMoeBackend.CUTLASS: + if fi_backend == FlashinferMoeBackend.CUTLASS: backend = Fp8MoeBackend.FLASHINFER_CUTLASS - return _return_or_raise( - backend, config, weight_key, activation_key, activation_format - ) - + elif fi_backend == FlashinferMoeBackend.TENSORRT_LLM: + backend = Fp8MoeBackend.FLASHINFER_TRTLLM else: - assert fi_backend == FlashinferMoeBackend.CUTEDSL - raise ValueError("FlashInfer MaskedGEMM not supported for FP8") - + raise ValueError( + f"FlashInfer MOE backend {fi_backend} does not support FP8 MoE." + ) + k_cls = backend_to_kernel_cls(backend) + return _return_or_raise( + backend, config, weight_key, activation_key, activation_format + ) else: # If the user is not explicit about the backend, try both. for backend in [ Fp8MoeBackend.FLASHINFER_TRTLLM, Fp8MoeBackend.FLASHINFER_CUTLASS, ]: - if backend == Fp8MoeBackend.FLASHINFER_TRTLLM: - k_cls = None - supported, reason = is_supported_config_trtllm_fp8( - config, - weight_key, - activation_key, - activation_format, - ) - else: - k_cls = backend_to_kernel_cls(backend) - supported, reason = k_cls.is_supported_config( - k_cls, - config, - weight_key, - activation_key, - activation_format, - ) + k_cls = backend_to_kernel_cls(backend) + supported, reason = k_cls.is_supported_config( + k_cls, + config, + weight_key, + activation_key, + activation_format, + ) if supported: logger.info_once(_make_log_backend(backend), scope="local") @@ -408,23 +377,14 @@ def _return_or_raise( # Select kernels in order of backend. for backend in AVAILABLE_BACKENDS: - if backend == Fp8MoeBackend.FLASHINFER_TRTLLM: - k_cls = None - supported, reason = is_supported_config_trtllm_fp8( - config, - weight_key, - activation_key, - activation_format, - ) - else: - k_cls = backend_to_kernel_cls(backend) - supported, reason = k_cls.is_supported_config( - k_cls, - config, - weight_key, - activation_key, - activation_format, - ) + k_cls = backend_to_kernel_cls(backend) + supported, reason = k_cls.is_supported_config( + k_cls, + config, + weight_key, + activation_key, + activation_format, + ) if supported: logger.info_once(_make_log_backend(backend), scope="local") @@ -510,7 +470,7 @@ def make_fp8_moe_quant_config( block_shape: list[int] | None = None, per_act_token_quant: bool = False, per_out_ch_quant: bool = False, -) -> FusedMoEQuantConfig | None: +) -> FusedMoEQuantConfig: """ Create FusedMoEQuantConfig for the specified FP8 Backend. The FusedMoEQuantConfig holds the scales that are used @@ -523,9 +483,6 @@ def make_fp8_moe_quant_config( In a future PR, we will have this function should be a method of the modular kernel itself. """ - # TRTLLM does not use Modular Kernel abstraction yet. - if fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: - return None # MARLIN is mixed precision W8A16 config. if fp8_backend == Fp8MoeBackend.MARLIN: @@ -539,12 +496,6 @@ def make_fp8_moe_quant_config( # (alpha = w_scale * a_scale) and inverse a2 scale. 
if fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS and block_shape is None: assert a1_scale is not None and a2_scale is not None - g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi( - w1_scale, - a1_scale, - w2_scale, - a2_scale, - ) return fp8_w8a8_moe_quant_config( w1_scale=w1_scale, w2_scale=w2_scale, @@ -552,8 +503,8 @@ def make_fp8_moe_quant_config( a2_scale=a2_scale, a1_gscale=(1.0 / a1_scale), a2_gscale=(1.0 / a2_scale), - g1_alphas=g1_alphas, - g2_alphas=g2_alphas, + g1_alphas=(w1_scale * a1_scale).squeeze(), + g2_alphas=(w2_scale * a2_scale).squeeze(), ) # All other backends use normal config. return fp8_w8a8_moe_quant_config( @@ -570,17 +521,18 @@ def make_fp8_moe_quant_config( def make_fp8_moe_kernel( moe_quant_config: FusedMoEQuantConfig, moe_config: FusedMoEConfig, - experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute], + experts_cls: type[mk.FusedMoEExperts], fp8_backend: Fp8MoeBackend, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, shared_experts: torch.nn.Module | None = None, -) -> mk.FusedMoEModularKernel: +) -> mk.FusedMoEKernel: # Create Prepare/Finalize. prepare_finalize = maybe_make_prepare_finalize( moe=moe_config, quant_config=moe_quant_config, routing_tables=routing_tables, allow_new_interface=True, + use_monolithic=issubclass(experts_cls, mk.FusedMoEExpertsMonolithic), ) assert prepare_finalize is not None @@ -605,7 +557,7 @@ def make_fp8_moe_kernel( # NOTE(rob): we only want the mk to control the shared_expert # if using all2all (for SBO). bnell is making this explicit in # the new MoE runner class. - kernel = mk.FusedMoEModularKernel( + kernel = mk.FusedMoEKernel( prepare_finalize, experts, shared_experts=( diff --git a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py index d48def361936..dd1a24d863de 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py +++ b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py @@ -19,7 +19,6 @@ nvfp4_w4a16_moe_quant_config, ) from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( - is_supported_config_trtllm, prepare_nvfp4_moe_layer_for_fi_or_cutlass, ) from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( @@ -67,39 +66,46 @@ def is_global_sf_supported_for_nvfp4_backend(backend: NvFp4MoeBackend) -> bool: def backend_to_kernel_cls( backend: NvFp4MoeBackend, -) -> type[mk.FusedMoEPermuteExpertsUnpermute]: +) -> list[type[mk.FusedMoEExperts]]: if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM: - raise NotImplementedError( - "FLASHINFER_TRTLLM doesn't support Modular Kernel Interface" + from vllm.model_executor.layers.fused_moe.experts.trtllm_nvfp4_moe import ( + TrtLlmNvFp4ExpertsModular, + TrtLlmNvFp4ExpertsMonolithic, ) + # NOTE: prefer Monolthic > Modular, so return Monolithic first. 
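For the inlined alpha computation above, a standalone numeric example of the per-expert product of per-tensor scales; the shapes and values are assumptions for illustration only, not taken from a real checkpoint.

import torch

E = 4
w1_scale = torch.full((E, 1), 0.02)   # assumed per-expert, per-tensor weight scale
a1_scale = torch.tensor(0.5)          # assumed static per-tensor activation scale

g1_alphas = (w1_scale * a1_scale).squeeze()
assert g1_alphas.shape == (E,)
assert torch.allclose(g1_alphas, torch.full((E,), 0.01))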
+ return [ + TrtLlmNvFp4ExpertsMonolithic, + TrtLlmNvFp4ExpertsModular, + ] + elif backend == NvFp4MoeBackend.FLASHINFER_CUTLASS: from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( FlashInferExperts, ) - return FlashInferExperts + return [FlashInferExperts] elif backend == NvFp4MoeBackend.FLASHINFER_CUTEDSL: from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import ( FlashInferCuteDSLExperts, ) - return FlashInferCuteDSLExperts + return [FlashInferCuteDSLExperts] elif backend == NvFp4MoeBackend.VLLM_CUTLASS: from vllm.model_executor.layers.fused_moe.cutlass_moe import ( CutlassExpertsFp4, ) - return CutlassExpertsFp4 + return [CutlassExpertsFp4] elif backend == NvFp4MoeBackend.MARLIN: from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( MarlinExperts, ) - return MarlinExperts + return [MarlinExperts] else: raise ValueError(f"Unknown NvFP4 MoE backend: {backend.value}") @@ -125,7 +131,7 @@ def select_nvfp4_moe_backend( config: FusedMoEConfig, weight_key: QuantKey | None, activation_key: QuantKey | None, -) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute] | None]: +) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEExperts]]: """ Select the primary NvFP4 MoE backend Note: Shape-specific fallbacks may still occur at runtime. @@ -175,29 +181,21 @@ def _return_or_raise( weight_key: QuantKey | None, activation_key: QuantKey | None, activation_format: mk.FusedMoEActivationFormat, - ) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute]]: - k_cls = backend_to_kernel_cls(backend) - supported, reason = k_cls.is_supported_config( - k_cls, config, weight_key, activation_key, activation_format - ) - if supported: - logger.info_once(_make_log_backend(backend)) - return backend, k_cls + ) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEExperts]]: + for k_cls in backend_to_kernel_cls(backend): + supported, reason = k_cls.is_supported_config( + k_cls, config, weight_key, activation_key, activation_format + ) + if supported: + logger.info_once(_make_log_backend(backend)) + return backend, k_cls + raise ValueError(_make_log_unsupported(backend, reason)) # Handle explicit moe_backend from user. runner_backend = config.moe_backend if runner_backend != "auto": requested_backend = map_nvfp4_backend(runner_backend) - if requested_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM: - supported, reason = is_supported_config_trtllm( - config, weight_key, activation_key, activation_format - ) - if supported: - logger.info_once(_make_log_backend(requested_backend)) - return requested_backend, None - raise ValueError(_make_log_unsupported(requested_backend, reason)) - return _return_or_raise( requested_backend, config, weight_key, activation_key, activation_format ) @@ -210,36 +208,14 @@ def _return_or_raise( elif envs.is_set("VLLM_FLASHINFER_MOE_BACKEND"): # If user is explicit about backend, validate it. 
- fi_backend = get_flashinfer_moe_backend() - - if fi_backend == FlashinferMoeBackend.TENSORRT_LLM: - backend = NvFp4MoeBackend.FLASHINFER_TRTLLM - supported, reason = is_supported_config_trtllm( - config, weight_key, activation_key, activation_format - ) - if supported: - logger.info_once(_make_log_backend(backend)) - return backend, None - else: - raise ValueError(_make_log_unsupported(backend, reason)) - else: - backend = fi_2_vllm_backend_map[fi_backend] - return _return_or_raise( - backend, config, weight_key, activation_key, activation_format - ) + backend = fi_2_vllm_backend_map[get_flashinfer_moe_backend()] + return _return_or_raise( + backend, config, weight_key, activation_key, activation_format + ) else: # If the user is not explicit about the backend, try each. for backend in FLASHINFER_NVFP4_MOE_BACKENDS: - if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM: - k_cls = None - supported, reason = is_supported_config_trtllm( - config, - weight_key, - activation_key, - activation_format, - ) - else: - k_cls = backend_to_kernel_cls(backend) + for k_cls in backend_to_kernel_cls(backend): supported, reason = k_cls.is_supported_config( k_cls, config, @@ -247,13 +223,13 @@ def _return_or_raise( activation_key, activation_format, ) - if supported: - logger.info_once(_make_log_backend(backend), scope="local") - return backend, None - else: - logger.debug_once( - _make_log_unsupported(backend, reason), scope="local" - ) + if supported: + logger.info_once(_make_log_backend(backend), scope="local") + return backend, k_cls + else: + logger.debug_once( + _make_log_unsupported(backend, reason), scope="local" + ) raise NotImplementedError( "Found VLLM_USE_FLASHINFER_MOE_FP4=1, but no " @@ -268,16 +244,7 @@ def _return_or_raise( # Select kernels in order of backend. for backend in AVAILABLE_BACKENDS: - if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM: - k_cls = None # type: ignore[assignment] - supported, reason = is_supported_config_trtllm( - config, - weight_key, - activation_key, - activation_format, - ) - else: - k_cls = backend_to_kernel_cls(backend) + for k_cls in backend_to_kernel_cls(backend): supported, reason = k_cls.is_supported_config( k_cls, config, @@ -286,11 +253,11 @@ def _return_or_raise( activation_format, ) - if supported: - logger.info_once(_make_log_backend(backend), scope="local") - return backend, k_cls - else: - logger.debug_once(_make_log_unsupported(backend, reason), scope="local") + if supported: + logger.info_once(_make_log_backend(backend), scope="local") + return backend, k_cls + else: + logger.debug_once(_make_log_unsupported(backend, reason), scope="local") raise NotImplementedError( "No NvFp4 MoE backend supports the deployment configuration." @@ -398,12 +365,8 @@ def make_nvfp4_moe_quant_config( w2_scale_2: torch.Tensor, a13_scale: torch.Tensor, a2_scale: torch.Tensor, -) -> FusedMoEQuantConfig | None: - UNSUPPORTED = [NvFp4MoeBackend.FLASHINFER_TRTLLM] - if backend in UNSUPPORTED: - return None - - elif backend == NvFp4MoeBackend.MARLIN: +) -> FusedMoEQuantConfig: + if backend == NvFp4MoeBackend.MARLIN: return nvfp4_w4a16_moe_quant_config( g1_alphas=w13_scale_2, g2_alphas=w2_scale_2, @@ -420,22 +383,27 @@ def make_nvfp4_moe_quant_config( a2_gscale=(1.0 / a2_scale), w1_scale=w13_scale, w2_scale=w2_scale, + # NOTE(rob): this is a hack until the MoE kernels + # create their own quant configs. TRTLLM kernel + # does not accept swizzled input quant scales. 
+ is_nvfp4_scale_swizzled=(backend != NvFp4MoeBackend.FLASHINFER_TRTLLM), ) def make_nvfp4_moe_kernel( moe_quant_config: FusedMoEQuantConfig, moe_config: FusedMoEConfig, - experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute], + experts_cls: type[mk.FusedMoEExperts], routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, shared_experts: torch.nn.Module | None = None, -) -> mk.FusedMoEModularKernel: +) -> mk.FusedMoEKernel: # Create Prepare/Finalize. prepare_finalize = maybe_make_prepare_finalize( moe=moe_config, quant_config=moe_quant_config, routing_tables=routing_tables, allow_new_interface=True, + use_monolithic=issubclass(experts_cls, mk.FusedMoEExpertsMonolithic), ) assert prepare_finalize is not None @@ -460,7 +428,7 @@ def make_nvfp4_moe_kernel( # NOTE(rob): we only want the mk to control the shared_expert # if using all2all (for SBO). bnell is making this explicit in # the new MoE runner class. - kernel = mk.FusedMoEModularKernel( + kernel = mk.FusedMoEKernel( prepare_finalize, experts, shared_experts=( diff --git a/vllm/model_executor/layers/fused_moe/oracle/unquantized.py b/vllm/model_executor/layers/fused_moe/oracle/unquantized.py index 1c582bcdc53e..9c31da10dd94 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/unquantized.py +++ b/vllm/model_executor/layers/fused_moe/oracle/unquantized.py @@ -19,7 +19,7 @@ is_supported_config_trtllm_bf16, ) from vllm.model_executor.layers.fused_moe.prepare_finalize import ( - MoEPrepareAndFinalizeNoEP, + MoEPrepareAndFinalizeNoDPEPModular, ) from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( swap_w13_to_w31, @@ -209,7 +209,7 @@ def make_unquantized_moe_kernel( backend: UnquantizedMoeBackend, quant_config: FusedMoEQuantConfig, moe_config: FusedMoEConfig, -) -> mk.FusedMoEModularKernel | None: +) -> mk.FusedMoEKernel | None: if backend in UNSUPPORTED_BACKEND: return None @@ -218,8 +218,8 @@ def make_unquantized_moe_kernel( FlashInferExperts, ) - kernel = mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP(), + kernel = mk.FusedMoEKernel( + MoEPrepareAndFinalizeNoDPEPModular(), FlashInferExperts( moe_config=moe_config, quant_config=quant_config, @@ -232,8 +232,8 @@ def make_unquantized_moe_kernel( AiterExperts, ) - kernel = mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP(), + kernel = mk.FusedMoEKernel( + MoEPrepareAndFinalizeNoDPEPModular(), AiterExperts( moe_config=moe_config, quant_config=quant_config, @@ -243,8 +243,8 @@ def make_unquantized_moe_kernel( elif backend == UnquantizedMoeBackend.TRITON: from vllm.model_executor.layers.fused_moe import TritonExperts - kernel = mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP(), + kernel = mk.FusedMoEKernel( + MoEPrepareAndFinalizeNoDPEPModular(), TritonExperts( moe_config=moe_config, quant_config=quant_config, @@ -254,8 +254,8 @@ def make_unquantized_moe_kernel( elif backend == UnquantizedMoeBackend.XPU: from vllm.model_executor.layers.fused_moe import XPUExperts - kernel = mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP(), + kernel = mk.FusedMoEKernel( + MoEPrepareAndFinalizeNoDPEPModular(), XPUExperts( moe_config=moe_config, quant_config=quant_config, diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize.py b/vllm/model_executor/layers/fused_moe/prepare_finalize.py deleted file mode 100644 index 7b8dd3b775ee..000000000000 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize.py +++ /dev/null @@ -1,209 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the 
vLLM project - -import torch - -import vllm.model_executor.layers.fused_moe.modular_kernel as mk -from vllm.distributed import get_ep_group -from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig -from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( - TopKWeightAndReduceContiguous, - TopKWeightAndReduceDelegate, -) -from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input -from vllm.utils.flashinfer import nvfp4_block_scale_interleave - - -class MoEPrepareAndFinalizeNaiveEP(mk.FusedMoEPrepareAndFinalize): - def __init__( - self, - is_sequence_parallel: bool = False, - num_dispatchers: int = 1, - ) -> None: - super().__init__() - self.is_sequence_parallel = is_sequence_parallel - self._num_dispatchers = num_dispatchers - - @property - def activation_format(self) -> mk.FusedMoEActivationFormat: - return mk.FusedMoEActivationFormat.Standard - - def max_num_tokens_per_rank(self) -> int | None: - return None - - def topk_indices_dtype(self) -> torch.dtype | None: - return None - - def num_dispatchers(self) -> int: - return self._num_dispatchers - - def output_is_reduced(self) -> bool: - return False - - def prepare( - self, - a1: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - num_experts: int, - expert_map: torch.Tensor | None, - apply_router_weight_on_input: bool, - quant_config: FusedMoEQuantConfig, - defer_input_quant: bool = False, - ) -> mk.PrepareResultType: - if apply_router_weight_on_input: - topk = topk_ids.size(1) - assert topk == 1, ( - "apply_router_weight_on_input is only implemented for topk=1" - ) - # Note: do not use inplace for shared experts overlap - a1 = a1 * topk_weights.to(a1.dtype) - - # Defer input quantization to the MoE kernel. - use_nvfp4 = quant_config.use_nvfp4_w4a4 - if defer_input_quant: - a1q = a1 - a1q_scale = None - else: - a1q, a1q_scale = moe_kernel_quantize_input( - a1, - quant_config.a1_gscale if use_nvfp4 else quant_config.a1_scale, - quant_config.quant_dtype, - quant_config.per_act_token_quant, - quant_config.block_shape, - # NOTE: swizzling pads the scales to multiple of 128 - # which makes the scales tensor different shape than - # the hidden states, breaking the A2A kernel. So, we - # delay the swizzling until after the A2A. - is_fp4_scale_swizzled=False, - ) - - # Skip gathering scales if we have static quantization - # (the scale is a scalar, replicated on all ranks) or - # if quantization is deferred. 
- skip_gather_scales = a1q_scale is None or a1q_scale.ndim == 0 - scales = None if skip_gather_scales else [a1q_scale] - - res = get_ep_group().dispatch( - a1q, - topk_weights, - topk_ids, - is_sequence_parallel=self.is_sequence_parallel, - extra_tensors=scales, - ) - if skip_gather_scales: - a1q, topk_weights, topk_ids = res - else: - a1q, topk_weights, topk_ids, scales = res - assert scales is not None and len(scales) == 1 - a1q_scale = scales[0] - if quant_config.quant_dtype == "nvfp4": - assert a1q_scale is not None - if a1q_scale.element_size() == 1: - a1q_scale = a1q_scale.view(torch.uint8) - a1q_scale = nvfp4_block_scale_interleave(a1q_scale) - - return a1q, a1q_scale, None, topk_ids, topk_weights - - def finalize( - self, - output: torch.Tensor, - fused_expert_output: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - apply_router_weight_on_input: bool, - weight_and_reduce_impl: mk.TopKWeightAndReduce, - ) -> None: - if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate): - weight_and_reduce_impl = TopKWeightAndReduceContiguous() - - out = weight_and_reduce_impl.apply( - output=None, - fused_expert_output=fused_expert_output, - topk_weights=topk_weights, - topk_ids=topk_ids, - apply_router_weight_on_input=apply_router_weight_on_input, - ) - - output.copy_( - get_ep_group().combine(out, is_sequence_parallel=self.is_sequence_parallel) - ) - - -class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): - """MoE prepare and finalize without expert parallelism.""" - - @property - def activation_format(self) -> mk.FusedMoEActivationFormat: - return mk.FusedMoEActivationFormat.Standard - - def max_num_tokens_per_rank(self) -> int | None: - return None - - def topk_indices_dtype(self) -> torch.dtype | None: - return None - - def num_dispatchers(self) -> int: - return 1 - - def output_is_reduced(self) -> bool: - return False - - def prepare( - self, - a1: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - num_experts: int, - expert_map: torch.Tensor | None, - apply_router_weight_on_input: bool, - quant_config: FusedMoEQuantConfig, - defer_input_quant: bool = False, - ) -> mk.PrepareResultType: - if apply_router_weight_on_input: - topk = topk_ids.size(1) - # TODO: this only works for topK=1, will need to update for topK>1 - assert topk == 1, ( - "apply_router_weight_on_input is only implemented for topk=1" - ) - # Note: do not use inplace for shared experts overlap - a1 = a1 * topk_weights.to(a1.dtype) - - # Defer input quant to moe kernel for backends (e.g. AITER, FI) - # which use a single kernel call for quant + experts. 
- if defer_input_quant: - return a1, None, None, None, None - - input_sf = ( - quant_config.a1_gscale - if quant_config.use_nvfp4_w4a4 - else quant_config.a1_scale - ) - a1q, a1q_scale = moe_kernel_quantize_input( - a1, - input_sf, - quant_config.quant_dtype, - quant_config.per_act_token_quant, - quant_config.block_shape, - ) - - return a1q, a1q_scale, None, None, None - - def finalize( - self, - output: torch.Tensor, - fused_expert_output: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - apply_router_weight_on_input: bool, - weight_and_reduce_impl: mk.TopKWeightAndReduce, - ) -> None: - if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate): - weight_and_reduce_impl = TopKWeightAndReduceContiguous() - weight_and_reduce_impl.apply( - output=output, - fused_expert_output=fused_expert_output, - topk_weights=topk_weights, - topk_ids=topk_ids, - apply_router_weight_on_input=apply_router_weight_on_input, - ) diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize/__init__.py b/vllm/model_executor/layers/fused_moe/prepare_finalize/__init__.py new file mode 100644 index 000000000000..03fea7c6d78b --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize/__init__.py @@ -0,0 +1,22 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm.model_executor.layers.fused_moe.prepare_finalize.naive_dp_ep import ( + MoEPrepareAndFinalizeNaiveDPEPModular, + MoEPrepareAndFinalizeNaiveDPEPMonolithic, + make_moe_prepare_and_finalize_naive_dp_ep, +) +from vllm.model_executor.layers.fused_moe.prepare_finalize.no_dp_ep import ( + MoEPrepareAndFinalizeNoDPEPModular, + MoEPrepareAndFinalizeNoDPEPMonolithic, + make_moe_prepare_and_finalize_no_dp_ep, +) + +__all__ = [ + "MoEPrepareAndFinalizeNaiveDPEPMonolithic", + "MoEPrepareAndFinalizeNaiveDPEPModular", + "make_moe_prepare_and_finalize_naive_dp_ep", + "MoEPrepareAndFinalizeNoDPEPMonolithic", + "MoEPrepareAndFinalizeNoDPEPModular", + "make_moe_prepare_and_finalize_no_dp_ep", +] diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize/naive_dp_ep.py b/vllm/model_executor/layers/fused_moe/prepare_finalize/naive_dp_ep.py new file mode 100644 index 000000000000..6dc9f6958048 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize/naive_dp_ep.py @@ -0,0 +1,253 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.distributed import get_ep_group +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( + TopKWeightAndReduceContiguous, + TopKWeightAndReduceDelegate, +) +from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input +from vllm.utils.flashinfer import nvfp4_block_scale_interleave + + +def _quantize_and_setup_dispatch( + a1: torch.Tensor, + quant_config: FusedMoEQuantConfig, + defer_input_quant: bool = False, +) -> tuple[torch.Tensor, list[torch.Tensor] | None]: + # Defer input quantization to the MoE kernel. + if defer_input_quant: + a1q = a1 + a1q_scale = None + else: + input_sf = ( + quant_config.a1_gscale + if quant_config.use_nvfp4_w4a4 + else quant_config.a1_scale + ) + + # NOTE: swizzling pads the scales to multiple of 128 + # which makes the scales tensor different shape than + # the hidden states, breaking the A2A kernel. 
So, we + # delay the swizzling until after the A2A. + a1q, a1q_scale = a1q, a1q_scale = moe_kernel_quantize_input( + a1, + input_sf, + quant_dtype=quant_config.quant_dtype, + per_act_token_quant=quant_config.per_act_token_quant, + block_shape=quant_config.block_shape, + is_fp4_scale_swizzled=False, + ) + + # Skip gathering scales if we have static quantization + # (the scale is a scalar, replicated on all ranks) or + # if quantization is deferred. + skip_gather_scales = a1q_scale is None or a1q_scale.ndim == 0 + scales = None if skip_gather_scales else [a1q_scale] + + return a1q, scales + + +def _unwrap_scale_and_prepare_for_moe( + scales: list[torch.Tensor] | None, + quant_config: FusedMoEQuantConfig, +) -> torch.Tensor: + assert scales is not None and len(scales) == 1 + a1q_scale = scales[0] + # Apply swizzling after a2a if the MoE kernel needs it. + if quant_config.quant_dtype == "nvfp4" and quant_config.is_nvfp4_scale_swizzled: + assert a1q_scale is not None + if a1q_scale.element_size() == 1: + a1q_scale = a1q_scale.view(torch.uint8) + a1q_scale = nvfp4_block_scale_interleave(a1q_scale) + + return a1q_scale + + +class MoEPrepareAndFinalizeNaiveDPEPModular(mk.FusedMoEPrepareAndFinalizeModular): + """ + Naive Prepare/Finalize for Dp/Ep case for Modular Kernels. + + Uses Torch AR/RS or AR for dispatch/combine operations, applied + to the topk weights and ids. + """ + + def __init__( + self, + is_sequence_parallel: bool = False, + num_dispatchers: int = 1, + ) -> None: + super().__init__() + self.is_sequence_parallel = is_sequence_parallel + self._num_dispatchers = num_dispatchers + + @property + def activation_format(self) -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard + + def max_num_tokens_per_rank(self) -> int | None: + return None + + def topk_indices_dtype(self) -> torch.dtype | None: + return None + + def num_dispatchers(self) -> int: + return self._num_dispatchers + + def output_is_reduced(self) -> bool: + return False + + def prepare( + self, + a1: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: torch.Tensor | None, + apply_router_weight_on_input: bool, + quant_config: FusedMoEQuantConfig, + defer_input_quant: bool = False, + ) -> mk.PrepareResultType: + """Quantize and Dispatch Topk Weights and Topk Ids.""" + + if apply_router_weight_on_input: + topk = topk_ids.size(1) + assert topk == 1, ( + "apply_router_weight_on_input is only implemented for topk=1" + ) + # Note: do not use inplace for shared experts overlap + a1 = a1 * topk_weights.to(a1.dtype) + + a1q, scales = _quantize_and_setup_dispatch(a1, quant_config, defer_input_quant) + + res = get_ep_group().dispatch( + a1q, + topk_weights, + topk_ids, + is_sequence_parallel=self.is_sequence_parallel, + extra_tensors=scales, + ) + + if scales is None: + a1q, topk_weights, topk_ids = res + a1q_scale = None + else: + a1q, topk_weights, topk_ids, scales = res + a1q_scale = _unwrap_scale_and_prepare_for_moe(scales, quant_config) + + return a1q, a1q_scale, None, topk_ids, topk_weights + + def finalize( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + ) -> None: + if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate): + weight_and_reduce_impl = TopKWeightAndReduceContiguous() + + out = weight_and_reduce_impl.apply( + output=None, + fused_expert_output=fused_expert_output, + 
topk_weights=topk_weights, + topk_ids=topk_ids, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + + output.copy_( + get_ep_group().combine(out, is_sequence_parallel=self.is_sequence_parallel) + ) + + +class MoEPrepareAndFinalizeNaiveDPEPMonolithic(mk.FusedMoEPrepareAndFinalizeMonolithic): + """ + Naive Prepare/Finalize for Dp/Ep case for Modular Kernels. + + Uses Torch AR/RS or AR for dispatch/combine operations, applied + to the router logits (the MoE kernel runs the router internally). + """ + + def __init__( + self, + is_sequence_parallel: bool = False, + num_dispatchers: int = 1, + ) -> None: + super().__init__() + self.is_sequence_parallel = is_sequence_parallel + self._num_dispatchers = num_dispatchers + + @property + def activation_format(self) -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard + + def max_num_tokens_per_rank(self) -> int | None: + return None + + def topk_indices_dtype(self) -> torch.dtype | None: + return None + + def num_dispatchers(self) -> int: + return self._num_dispatchers + + def output_is_reduced(self) -> bool: + return False + + def prepare( + self, + a1: torch.Tensor, + router_logits: torch.Tensor, + quant_config: FusedMoEQuantConfig, + defer_input_quant: bool = False, + ) -> mk.PrepareMonolithicResultType: + """Quantize and Dispatch Router Logits.""" + + a1q, scales = _quantize_and_setup_dispatch(a1, quant_config, defer_input_quant) + + res = get_ep_group().dispatch_router_logits( + a1q, + router_logits, + is_sequence_parallel=self.is_sequence_parallel, + extra_tensors=scales, + ) + + if scales is None: + a1q, router_logits = res + a1q_scale = None + else: + a1q, router_logits, scales = res + a1q_scale = _unwrap_scale_and_prepare_for_moe(scales, quant_config) + + return a1q, a1q_scale, router_logits + + def finalize( + self, + fused_expert_output: torch.Tensor, + ) -> torch.Tensor: + out = get_ep_group().combine( + fused_expert_output, is_sequence_parallel=self.is_sequence_parallel + ) + return out + + +def make_moe_prepare_and_finalize_naive_dp_ep( + use_monolithic: bool, + is_sequence_parallel: bool = False, + num_dispatchers: int = 1, +) -> MoEPrepareAndFinalizeNaiveDPEPModular | MoEPrepareAndFinalizeNaiveDPEPMonolithic: + return ( + MoEPrepareAndFinalizeNaiveDPEPMonolithic( + is_sequence_parallel=is_sequence_parallel, + num_dispatchers=num_dispatchers, + ) + if use_monolithic + else MoEPrepareAndFinalizeNaiveDPEPModular( + is_sequence_parallel=is_sequence_parallel, + num_dispatchers=num_dispatchers, + ) + ) diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize/no_dp_ep.py b/vllm/model_executor/layers/fused_moe/prepare_finalize/no_dp_ep.py new file mode 100644 index 000000000000..b9d57da08326 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize/no_dp_ep.py @@ -0,0 +1,141 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( + TopKWeightAndReduceContiguous, + TopKWeightAndReduceDelegate, +) +from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input + + +def _quantize_input( + a1: torch.Tensor, + quant_config: FusedMoEQuantConfig, + defer_input_quant: bool = False, +) -> tuple[torch.Tensor, torch.Tensor | None]: + # Defer input quant to moe kernel for backends (e.g. 
AITER, FI) + # which use a single kernel call for quant + experts. + if defer_input_quant: + return a1, None + + input_sf = ( + quant_config.a1_gscale if quant_config.use_nvfp4_w4a4 else quant_config.a1_scale + ) + a1q, a1q_scale = moe_kernel_quantize_input( + a1, + input_sf, + quant_dtype=quant_config.quant_dtype, + per_act_token_quant=quant_config.per_act_token_quant, + block_shape=quant_config.block_shape, + is_fp4_scale_swizzled=quant_config.is_nvfp4_scale_swizzled, + ) + + return a1q, a1q_scale + + +class MoEPrepareAndFinalizeNoDPEPModular(mk.FusedMoEPrepareAndFinalizeModular): + @property + def activation_format(self) -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard + + def max_num_tokens_per_rank(self) -> int | None: + return None + + def topk_indices_dtype(self) -> torch.dtype | None: + return None + + def num_dispatchers(self) -> int: + return 1 + + def output_is_reduced(self) -> bool: + return False + + def prepare( + self, + a1: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: torch.Tensor | None, + apply_router_weight_on_input: bool, + quant_config: FusedMoEQuantConfig, + defer_input_quant: bool = False, + ) -> mk.PrepareResultType: + if apply_router_weight_on_input: + topk = topk_ids.size(1) + # TODO: this only works for topK=1, will need to update for topK>1 + assert topk == 1, ( + "apply_router_weight_on_input is only implemented for topk=1" + ) + # Note: do not use inplace for shared experts overlap + a1 = a1 * topk_weights.to(a1.dtype) + + a1q, a1q_scale = _quantize_input(a1, quant_config, defer_input_quant) + + return a1q, a1q_scale, None, None, None + + def finalize( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + ) -> None: + if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate): + weight_and_reduce_impl = TopKWeightAndReduceContiguous() + weight_and_reduce_impl.apply( + output=output, + fused_expert_output=fused_expert_output, + topk_weights=topk_weights, + topk_ids=topk_ids, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + + +class MoEPrepareAndFinalizeNoDPEPMonolithic(mk.FusedMoEPrepareAndFinalizeMonolithic): + @property + def activation_format(self) -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard + + def max_num_tokens_per_rank(self) -> int | None: + return None + + def topk_indices_dtype(self) -> torch.dtype | None: + return None + + def num_dispatchers(self) -> int: + return 1 + + def output_is_reduced(self) -> bool: + return False + + def prepare( + self, + a1: torch.Tensor, + router_logits: torch.Tensor, + quant_config: FusedMoEQuantConfig, + defer_input_quant: bool = False, + ) -> mk.PrepareMonolithicResultType: + a1q, a1q_scale = _quantize_input(a1, quant_config, defer_input_quant) + return a1q, a1q_scale, router_logits + + def finalize( + self, + fused_expert_output: torch.Tensor, + ) -> torch.Tensor: + return fused_expert_output + + +def make_moe_prepare_and_finalize_no_dp_ep( + use_monolithic: bool, +) -> MoEPrepareAndFinalizeNoDPEPModular | MoEPrepareAndFinalizeNoDPEPMonolithic: + return ( + MoEPrepareAndFinalizeNoDPEPMonolithic() + if use_monolithic + else MoEPrepareAndFinalizeNoDPEPModular() + ) diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index 
8c8439decbbb..c550cad9e892 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -292,7 +292,7 @@ def rocm_aiter_fused_experts( ) -class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute): +class AiterExperts(mk.FusedMoEExpertsModular): @property def expects_unquantized_inputs(self) -> bool: return True diff --git a/vllm/model_executor/layers/fused_moe/router/base_router.py b/vllm/model_executor/layers/fused_moe/router/base_router.py index 52005d40d525..6332827d1d09 100644 --- a/vllm/model_executor/layers/fused_moe/router/base_router.py +++ b/vllm/model_executor/layers/fused_moe/router/base_router.py @@ -64,7 +64,7 @@ def eplb_map_to_physical_and_record( # TODO(bowen): When using `FusedMoEModularKernel`, this # can be done in a more unified way, since - # `FusedMoEPrepareAndFinalize` will return the expert + # `FusedMoEPrepareAndFinalizeModular` will return the expert # token count, in some cases directly from the kernel. # However, now there are many code paths not using # the modular kernel, e.g. calling `fused_experts`, diff --git a/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py b/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py index 274929c071ac..e9e849b25910 100644 --- a/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py +++ b/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py @@ -320,8 +320,8 @@ def must_reduce_shared_expert_outputs(self) -> bool: """ assert self.quant_method is not None return ( - self.quant_method.moe_mk is not None - and self.quant_method.moe_mk.output_is_reduced() + self.quant_method.moe_kernel is not None + and self.quant_method.moe_kernel.output_is_reduced() ) def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tensor): @@ -640,45 +640,6 @@ def forward_impl( ) with sp_ctx: - extra_tensors = None - if do_naive_dispatch_combine: - post_quant_allgather = ( - self.quant_method is not None - and self.moe_config.dp_size > 1 - and self.moe_config.use_ep - and getattr(self.quant_method, "do_post_quant_allgather", False) - ) - if post_quant_allgather: - hidden_states_to_dispatch, extra_tensors = ( - self.quant_method.prepare_dp_allgather_tensor( - layer, hidden_states, router_logits - ) - ) - else: - hidden_states_to_dispatch = hidden_states - - dispatch_res = get_ep_group().dispatch_router_logits( - hidden_states_to_dispatch, - router_logits, - self.moe_config.is_sequence_parallel, - extra_tensors=extra_tensors, - ) - if extra_tensors is not None: - ( - orig_hidden_states, - router_logits, - extra_tensors_combined, - ) = dispatch_res - hidden_states_combined = ( - orig_hidden_states, - extra_tensors_combined[0], - ) - else: - hidden_states_combined, router_logits = dispatch_res - orig_hidden_states = hidden_states_combined - else: - orig_hidden_states = hidden_states - # Run shared experts before matrix multiply. # because matrix multiply maybe modify the hidden_states. if has_separate_shared_experts and not use_shared_experts_stream: @@ -688,6 +649,17 @@ def forward_impl( ) shared_output = self.shared_experts(shared_input) + # For naive dispatch/combine Dp/Ep, dispatch the hidden states and + # router logits to all experts. + # NOTE: this will be removed once all kernels are migrated into the + # MoEKernel framework. 
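As a rough mental model for the naive dispatch/combine path referenced in this hunk, a single-process toy that only mimics the shape bookkeeping; the real path uses collective ops via get_ep_group(), and nothing below is vLLM code.

import torch

ranks = 2
tokens_per_rank = 3
hidden = 8

per_rank = [torch.randn(tokens_per_rank, hidden) for _ in range(ranks)]
dispatched = torch.cat(per_rank, dim=0)               # "dispatch": (6, 8) seen by all experts
expert_out = dispatched * 2.0                         # stand-in for the fused MoE kernel
combined = expert_out.split(tokens_per_rank, dim=0)   # "combine": back to per-rank slices

assert combined[0].shape == (tokens_per_rank, hidden)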
+ if do_naive_dispatch_combine: + hidden_states, router_logits = get_ep_group().dispatch_router_logits( + hidden_states, + router_logits, + self.moe_config.is_sequence_parallel, + ) + # NOTE: Similar with DP, PCP also needs dispatch and combine. For # simplicity, AgRsAll2All was added separately for PCP here. Maybe # we should modify All2AllManager abstract to better support PCP. @@ -701,31 +673,22 @@ def forward_impl( dim=0, ) - # TODO(bnell): deal with fp4 flashinfer tuple hidden states hack (#30014). - # Figure out nicer way to do this. - if do_naive_dispatch_combine: - x = hidden_states_combined - x_orig = orig_hidden_states - else: - x = hidden_states - x_orig = hidden_states - # Matrix multiply. if self.quant_method.is_monolithic: final_hidden_states = self.quant_method.apply_monolithic( layer=layer, - x=x, + x=hidden_states, router_logits=router_logits, ) else: topk_weights, topk_ids = self.router.select_experts( - hidden_states=x_orig, + hidden_states=hidden_states, router_logits=router_logits, ) final_hidden_states = self.quant_method.apply( layer=layer, - x=x, # The type signture of this is wrong due to the hack. + x=hidden_states, topk_weights=topk_weights, topk_ids=topk_ids, shared_experts_input=shared_input, diff --git a/vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py b/vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py index d7b50aea2ad6..4cebe608a6b4 100644 --- a/vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py +++ b/vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py @@ -10,7 +10,7 @@ class TopKWeightAndReduceDelegate(mk.TopKWeightAndReduce): """ - Useful in the case when some FusedMoEPermuteExpertsUnpermute + Useful in the case when some FusedMoEExpertsModular implementation does not perform weight application and reduction but cannot address the needs of all the compatible PrepareAndFinalize implementations. @@ -62,7 +62,7 @@ def apply( if output is None: return fused_expert_output - # MoEPrepareAndFinalizeNoEP needs the output to be in the `output` + # MoEPrepareAndFinalizeNoDPEPModular needs the output to be in the `output` # tensor. assert output.size() == fused_expert_output.size(), ( "output shape is expected to match the fused_expert_output shape. " diff --git a/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py index 21a3d05f4cd2..4aa396d24b0c 100644 --- a/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py @@ -32,8 +32,8 @@ def __init__( @staticmethod def get_clses() -> tuple[ - type[mk.FusedMoEPermuteExpertsUnpermute], - type[mk.FusedMoEPermuteExpertsUnpermute], + type[mk.FusedMoEExpertsModular], + type[mk.FusedMoEExpertsModular], ]: return (CutlassExpertsFp8, TritonExperts) @@ -77,7 +77,7 @@ def _select_experts_impl( hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, - ) -> mk.FusedMoEPermuteExpertsUnpermute: + ) -> mk.FusedMoEExpertsModular: # Small batch fallback for sm100. 
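To make the modular-path contract concrete, a toy stand-in for what the router hands to apply(): per-token top-k weights and ids derived from the router logits. This is a plain softmax + topk, not vLLM's select_experts implementation.

import torch

M, num_experts, topk = 4, 8, 2
router_logits = torch.randn(M, num_experts)

probs = router_logits.softmax(dim=-1)
topk_weights, topk_ids = probs.topk(topk, dim=-1)

assert topk_weights.shape == (M, topk)
assert topk_ids.shape == (M, topk)

The monolithic path skips this step entirely and passes router_logits straight to the kernel.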
if self.is_sm100 and hidden_states.shape[0] <= 8: return self.fallback_experts diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index a3f2f59c5b3c..b601806b067a 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -32,8 +32,8 @@ def __init__(self, moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig @staticmethod def get_clses() -> tuple[ - type[mk.FusedMoEPermuteExpertsUnpermute], - type[mk.FusedMoEPermuteExpertsUnpermute], + type[mk.FusedMoEExpertsModular], + type[mk.FusedMoEExpertsModular], ]: return (DeepGemmExperts, TritonExperts) @@ -79,7 +79,7 @@ def _select_experts_impl( hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, - ) -> mk.FusedMoEPermuteExpertsUnpermute: + ) -> mk.FusedMoEExpertsModular: if is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2): return self.experts else: diff --git a/vllm/model_executor/layers/fused_moe/trtllm_moe.py b/vllm/model_executor/layers/fused_moe/trtllm_moe.py index 2bd4cd79e031..5160840a2f31 100644 --- a/vllm/model_executor/layers/fused_moe/trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/trtllm_moe.py @@ -18,7 +18,7 @@ ) -class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute): +class TrtLlmGenExperts(mk.FusedMoEExpertsModular): """TensorRT-LLM-based fused MoE expert implementation.""" def __init__( diff --git a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py index 5c86064a928f..95b6f7b77fa0 100644 --- a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py +++ b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py @@ -24,8 +24,8 @@ ) from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEActivationFormat, - FusedMoEPermuteExpertsUnpermute, - FusedMoEPrepareAndFinalize, + FusedMoEExpertsModular, + FusedMoEPrepareAndFinalizeModular, ) from vllm.model_executor.layers.fused_moe.oracle.unquantized import ( UnquantizedMoeBackend, @@ -70,7 +70,7 @@ def __init__(self, moe: FusedMoEConfig): self.rocm_aiter_moe_enabled = ( rocm_aiter_ops.is_fused_moe_enabled() and moe.is_act_and_mul ) - self.kernel: mk.FusedMoEModularKernel | None = None + self.kernel: mk.FusedMoEKernel | None = None self._is_monolithic = ( current_platform.is_cpu() or self.unquantized_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM @@ -107,7 +107,7 @@ def supports_eplb(self) -> bool: def maybe_make_prepare_finalize( self, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, - ) -> FusedMoEPrepareAndFinalize | None: + ) -> FusedMoEPrepareAndFinalizeModular | None: if self.unquantized_backend == UnquantizedMoeBackend.AITER: return None else: @@ -115,9 +115,9 @@ def maybe_make_prepare_finalize( def select_gemm_impl( self, - prepare_finalize: FusedMoEPrepareAndFinalize, + prepare_finalize: FusedMoEPrepareAndFinalizeModular, layer: torch.nn.Module, - ) -> FusedMoEPermuteExpertsUnpermute: + ) -> FusedMoEExpertsModular: assert self.moe_quant_config is not None if ( prepare_finalize.activation_format @@ -325,7 +325,7 @@ def forward_cuda( ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert self.kernel is not None - return self.kernel( + return self.kernel.apply( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, diff --git a/vllm/model_executor/layers/fused_moe/xpu_fused_moe.py 
b/vllm/model_executor/layers/fused_moe/xpu_fused_moe.py index e6f8b8efa804..0693a25468fd 100644 --- a/vllm/model_executor/layers/fused_moe/xpu_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/xpu_fused_moe.py @@ -23,7 +23,7 @@ from vllm_xpu_kernels.fused_moe_interface import xpu_fused_moe -class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute): +class XPUExperts(mk.FusedMoEExpertsModular): def __init__( self, moe_config: FusedMoEConfig, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 097d0bc01891..8b7fc57d0409 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -19,8 +19,8 @@ from vllm.model_executor.layers.fused_moe import ( FusedMoE, FusedMoEActivationFormat, + FusedMoEExpertsModular, FusedMoEMethodBase, - FusedMoEPermuteExpertsUnpermute, FusedMoeWeightScaleSupported, UnquantizedFusedMoEMethod, ) @@ -40,7 +40,6 @@ fused_marlin_moe, ) from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( - Fp8MoeBackend, convert_to_fp8_moe_kernel_format, make_fp8_moe_kernel, make_fp8_moe_quant_config, @@ -59,18 +58,11 @@ WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP, ) -from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( - flashinfer_trtllm_fp4_moe, - flashinfer_trtllm_fp4_routed_moe, -) from vllm.model_executor.layers.quantization.utils.flashinfer_mxint4_moe import ( flashinfer_trtllm_mxint4_moe, is_flashinfer_mxint4_moe_available, prepare_static_weights_for_trtllm_mxint4_moe, ) -from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( - apply_fi_trtllm_fp8_per_tensor_moe, -) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( process_fp8_input_tensor_strategy_moe, process_fp8_weight_tensor_strategy_moe, @@ -336,7 +328,7 @@ def process_weights_after_loading(self, layer: FusedMoE) -> None: self.moe_quant_config = self.get_fused_moe_quant_config(layer) if self.moe_quant_config is not None: - self.moe_mk = make_nvfp4_moe_kernel( + self.moe_kernel = make_nvfp4_moe_kernel( moe_quant_config=self.moe_quant_config, moe_config=self.moe, experts_cls=self.experts_cls, @@ -352,8 +344,8 @@ def apply( topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - assert self.moe_mk is not None - return self.moe_mk( + assert self.moe_kernel is not None + return self.moe_kernel.apply( x, layer.w13_weight, layer.w2_weight, @@ -562,43 +554,27 @@ def process_weights_after_loading(self, layer: FusedMoE) -> None: layer.w13_input_scale = a13_scale layer.w2_input_scale = a2_scale - # Setup modular kernel for TP case and naive DP/EP case. - # In non-naive DP/EP case, we will create a ModularKernelMethod. - # TODO(rob): unify these so FP8MoEMethod owns the ModularKernel - # in both cases. + # Setup modular kernel. 
self.moe_quant_config = self.get_fused_moe_quant_config(layer) - if self.moe_quant_config: - assert self.experts_cls is not None - self.moe_mk = make_nvfp4_moe_kernel( - moe_quant_config=self.moe_quant_config, - moe_config=self.moe, - experts_cls=self.experts_cls, - shared_experts=layer.shared_experts, - routing_tables=layer._maybe_init_expert_routing_tables(), - ) + assert self.experts_cls is not None + self.moe_kernel = make_nvfp4_moe_kernel( + moe_quant_config=self.moe_quant_config, + moe_config=self.moe, + experts_cls=self.experts_cls, + shared_experts=layer.shared_experts, + routing_tables=layer._maybe_init_expert_routing_tables(), + ) def maybe_make_prepare_finalize( self, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, - ) -> mk.FusedMoEPrepareAndFinalize | None: - raise ValueError( - f"{self.__class__.__name__} uses the new modular kernel initialization " - "logic. This function should not be called." - ) - - def select_gemm_impl( - self, - prepare_finalize: mk.FusedMoEPrepareAndFinalize, - layer: torch.nn.Module, - ) -> mk.FusedMoEPermuteExpertsUnpermute: + ) -> mk.FusedMoEPrepareAndFinalizeModular | None: raise ValueError( f"{self.__class__.__name__} uses the new modular kernel initialization " "logic. This function should not be called." ) - def get_fused_moe_quant_config( - self, layer: torch.nn.Module - ) -> FusedMoEQuantConfig | None: + def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig: return make_nvfp4_moe_quant_config( backend=self.nvfp4_backend, w13_scale=layer.w13_weight_scale, @@ -609,13 +585,6 @@ def get_fused_moe_quant_config( a2_scale=layer.w2_input_scale, ) - @property - def is_monolithic(self) -> bool: - return ( - self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM - and not self.moe.moe_parallel_config.enable_eplb - ) - def apply_monolithic( self, layer: FusedMoE, @@ -623,24 +592,20 @@ def apply_monolithic( router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert self.is_monolithic - assert layer.activation == MoEActivation.SILU, ( - f"Only SiLU activation is supported, not {layer.activation}." 
- ) - assert ( - self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM - and not layer.enable_eplb - ) - return flashinfer_trtllm_fp4_moe( - layer=layer, - x=x, - router_logits=router_logits, - top_k=layer.top_k, + assert self.moe_kernel is not None + return self.moe_kernel.apply_monolithic( + x, + layer.w13_weight, + layer.w2_weight, + router_logits, activation=layer.activation, global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, + apply_router_weight_on_input=layer.apply_router_weight_on_input, num_expert_group=layer.num_expert_group, topk_group=layer.topk_group, - custom_routing_function=layer.custom_routing_function, e_score_correction_bias=layer.e_score_correction_bias, + routed_scaling_factor=layer.routed_scaling_factor, ) def apply( @@ -651,34 +616,19 @@ def apply( topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - assert not self.is_monolithic - - # EPLB path - if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM: - assert layer.enable_eplb - return flashinfer_trtllm_fp4_routed_moe( - layer=layer, - x=x, - topk_ids=topk_ids, - topk_weights=topk_weights, - top_k=layer.top_k, - activation=layer.activation, - global_num_experts=layer.global_num_experts, - ) - else: - assert self.moe_mk is not None - return self.moe_mk( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights, - topk_ids, - activation=layer.activation, - global_num_experts=layer.global_num_experts, - expert_map=layer.expert_map, - apply_router_weight_on_input=layer.apply_router_weight_on_input, - shared_experts_input=shared_experts_input, - ) + assert self.moe_kernel is not None + return self.moe_kernel.apply( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + activation=layer.activation, + global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + shared_experts_input=shared_experts_input, + ) class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): @@ -966,7 +916,7 @@ def process_weights_after_loading(self, layer: FusedMoE) -> None: self.moe_quant_config = self.get_fused_moe_quant_config(layer) if self.moe_quant_config: assert self.experts_cls is not None - self.moe_mk = make_fp8_moe_kernel( + self.moe_kernel = make_fp8_moe_kernel( moe_quant_config=self.moe_quant_config, moe_config=self.moe, fp8_backend=self.fp8_backend, @@ -978,94 +928,47 @@ def process_weights_after_loading(self, layer: FusedMoE) -> None: def maybe_make_prepare_finalize( self, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, - ) -> mk.FusedMoEPrepareAndFinalize | None: - raise ValueError( - f"{self.__class__.__name__} uses the new modular kernel initialization " - "logic. This function should not be called." - ) - - def select_gemm_impl( - self, - prepare_finalize: mk.FusedMoEPrepareAndFinalize, - layer: torch.nn.Module, - ) -> mk.FusedMoEPermuteExpertsUnpermute: + ) -> mk.FusedMoEPrepareAndFinalizeModular | None: raise ValueError( f"{self.__class__.__name__} uses the new modular kernel initialization " "logic. This function should not be called." 
) - def get_fused_moe_quant_config( - self, layer: torch.nn.Module - ) -> FusedMoEQuantConfig | None: - w1_scale = layer.w13_weight_scale - w2_scale = layer.w2_weight_scale - a1_scale = layer.w13_input_scale - a2_scale = layer.w2_input_scale - + def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig: + is_per_token = self.input_quant.strategy == QuantizationStrategy.TOKEN return make_fp8_moe_quant_config( fp8_backend=self.fp8_backend, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a1_scale, - a2_scale=a2_scale, - per_act_token_quant=( - self.input_quant.strategy == QuantizationStrategy.TOKEN - ), - per_out_ch_quant=(self.input_quant.strategy == QuantizationStrategy.TOKEN), + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + per_act_token_quant=is_per_token, + per_out_ch_quant=is_per_token, block_shape=self.weight_block_size, ) - @property - def is_monolithic(self) -> bool: - return self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM - def apply_monolithic( self, layer: FusedMoE, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - assert self.is_monolithic - assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM - assert layer.activation == MoEActivation.SILU, ( - f"Only SiLU activation is supported, not {layer.activation}." + assert self.moe_kernel is not None + return self.moe_kernel.apply_monolithic( + x, + layer.w13_weight, + layer.w2_weight, + router_logits, + activation=layer.activation, + global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + num_expert_group=layer.num_expert_group, + topk_group=layer.topk_group, + e_score_correction_bias=layer.e_score_correction_bias, + routed_scaling_factor=layer.routed_scaling_factor, ) - if self.block_quant: - import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401 - - return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8( - routing_logits=router_logits, - routing_bias=layer.e_score_correction_bias, - x=x, - w13_weight=layer.w13_weight, - w13_weight_scale_inv=layer.w13_weight_scale, - w2_weight=layer.w2_weight, - w2_weight_scale_inv=layer.w2_weight_scale, - global_num_experts=layer.global_num_experts, - top_k=layer.top_k, - num_expert_group=layer.num_expert_group, - topk_group=layer.topk_group, - intermediate_size=layer.intermediate_size_per_partition, - expert_offset=layer.ep_rank * layer.local_num_experts, - local_num_experts=layer.local_num_experts, - block_shape=self.weight_block_size, - routing_method_type=layer.routing_method_type, - routed_scaling=layer.routed_scaling_factor, - ) - else: - return apply_fi_trtllm_fp8_per_tensor_moe( - layer=layer, - hidden_states=x, - router_logits=router_logits, - routing_bias=layer.e_score_correction_bias, - global_num_experts=layer.global_num_experts, - top_k=layer.top_k, - num_expert_group=layer.num_expert_group, - topk_group=layer.topk_group, - apply_router_weight_on_input=layer.apply_router_weight_on_input, - ) - def apply( self, layer: FusedMoE, @@ -1075,8 +978,8 @@ def apply( shared_experts_input: torch.Tensor | None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert not self.is_monolithic - assert self.moe_mk is not None - return self.moe_mk( + assert self.moe_kernel is not None + return self.moe_kernel.apply( x, layer.w13_weight, layer.w2_weight, @@ -1652,9 +1555,9 @@ def 
get_fused_moe_quant_config( def select_gemm_impl( self, - prepare_finalize: mk.FusedMoEPrepareAndFinalize, + prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular, layer: torch.nn.Module, - ) -> mk.FusedMoEPermuteExpertsUnpermute: + ) -> mk.FusedMoEExpertsModular: assert self.num_bits == 4, "only supporting w4" layer.w13_weight = layer.w13_weight_packed layer.w2_weight = layer.w2_weight_packed @@ -1943,9 +1846,9 @@ def get_fused_moe_quant_config( def select_gemm_impl( self, - prepare_finalize: mk.FusedMoEPrepareAndFinalize, + prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular, layer: torch.nn.Module, - ) -> mk.FusedMoEPermuteExpertsUnpermute: + ) -> mk.FusedMoEExpertsModular: if self.moe.is_lora_enabled: assert self.moe_quant_config is not None from vllm.triton_utils import HAS_TRITON @@ -2527,7 +2430,7 @@ def process_weights_after_loading(self, layer): def maybe_make_prepare_finalize( self, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, - ) -> mk.FusedMoEPrepareAndFinalize | None: + ) -> mk.FusedMoEPrepareAndFinalizeModular | None: return super().maybe_make_prepare_finalize(routing_tables) def get_fused_moe_quant_config( @@ -2548,9 +2451,9 @@ def get_fused_moe_quant_config( def select_gemm_impl( self, - prepare_finalize: mk.FusedMoEPrepareAndFinalize, + prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular, layer: torch.nn.Module, - ) -> mk.FusedMoEPermuteExpertsUnpermute: + ) -> mk.FusedMoEExpertsModular: assert self.moe_quant_config is not None assert ( prepare_finalize.activation_format == FusedMoEActivationFormat.Standard @@ -2558,7 +2461,7 @@ def select_gemm_impl( from vllm.model_executor.layers.fused_moe import CutlassExpertsW4A8Fp8 - experts: FusedMoEPermuteExpertsUnpermute + experts: FusedMoEExpertsModular logger.debug("CutlassExpertsW4A8Fp8(%s)", self.__class__.__name__) experts = CutlassExpertsW4A8Fp8( diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index e3174ba995ff..5101347cd02a 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -23,17 +23,13 @@ from vllm.model_executor.layers.fused_moe import ( FusedMoE, FusedMoEMethodBase, - FusedMoEPermuteExpertsUnpermute, - FusedMoEPrepareAndFinalize, FusedMoeWeightScaleSupported, - MoEActivation, ) from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, ) from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( - Fp8MoeBackend, convert_to_fp8_moe_kernel_format, make_fp8_moe_kernel, make_fp8_moe_quant_config, @@ -50,9 +46,6 @@ QuantizeMethodBase, ) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod -from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( - apply_fi_trtllm_fp8_per_tensor_moe, -) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( W8A8BlockFp8LinearOp, create_fp8_input_scale, @@ -860,14 +853,10 @@ def _setup_kernel( replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_scale) replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_scale) - # Setup modular kernel for TP case and naive DP/EP case. - # In non-naive DP/EP case, we will create a ModularKernelMethod. - # TODO(rob): unify these so FP8MoEMethod owns the ModularKernel - # in both cases. 
self.moe_quant_config = self.get_fused_moe_quant_config(layer) if self.moe_quant_config: assert self.experts_cls is not None - self.moe_mk = make_fp8_moe_kernel( + self.moe_kernel = make_fp8_moe_kernel( moe_quant_config=self.moe_quant_config, moe_config=self.moe, fp8_backend=self.fp8_backend, @@ -930,29 +919,13 @@ def process_weights_after_loading(self, layer: Module) -> None: def maybe_make_prepare_finalize( self, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, - ) -> mk.FusedMoEPrepareAndFinalize | None: - raise ValueError( - f"{self.__class__.__name__} uses the new modular kernel initialization " - "logic. This function should not be called." - ) - - def select_gemm_impl( - self, - prepare_finalize: FusedMoEPrepareAndFinalize, - layer: torch.nn.Module, - ) -> FusedMoEPermuteExpertsUnpermute: + ) -> mk.FusedMoEPrepareAndFinalizeModular | None: raise ValueError( f"{self.__class__.__name__} uses the new modular kernel initialization " "logic. This function should not be called." ) - def get_fused_moe_quant_config( - self, layer: torch.nn.Module - ) -> FusedMoEQuantConfig | None: - # TRTLLM does not use Modular Kernel. - if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: - return None - + def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig: w1_scale = getattr(layer, f"w13_{self.weight_scale_name}") w2_scale = getattr(layer, f"w2_{self.weight_scale_name}") a1_scale = layer.w13_input_scale @@ -983,10 +956,6 @@ def get_fused_moe_quant_config( def supports_eplb(self) -> bool: return True - @property - def is_monolithic(self) -> bool: - return self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM - def apply_monolithic( self, layer: FusedMoE, @@ -994,50 +963,22 @@ def apply_monolithic( router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert self.is_monolithic - assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM - - # TODO(rob): convert this to MK. 
- if layer.enable_eplb: - raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.") - assert layer.activation == MoEActivation.SILU, ( - f"Expected 'silu' activation but got {layer.activation}" + assert self.moe_kernel is not None + return self.moe_kernel.apply_monolithic( + x, + layer.w13_weight, + layer.w2_weight, + router_logits, + activation=layer.activation, + global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + num_expert_group=layer.num_expert_group, + topk_group=layer.topk_group, + e_score_correction_bias=layer.e_score_correction_bias, + routed_scaling_factor=layer.routed_scaling_factor, ) - if self.block_quant: - import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401 - - return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8( - routing_logits=router_logits, - routing_bias=layer.e_score_correction_bias, - x=x, - w13_weight=layer.w13_weight, - w13_weight_scale_inv=layer.w13_weight_scale_inv, - w2_weight=layer.w2_weight, - w2_weight_scale_inv=layer.w2_weight_scale_inv, - global_num_experts=layer.global_num_experts, - top_k=layer.top_k, - num_expert_group=layer.num_expert_group, - topk_group=layer.topk_group, - intermediate_size=layer.intermediate_size_per_partition, - expert_offset=layer.ep_rank * layer.local_num_experts, - local_num_experts=layer.local_num_experts, - block_shape=self.weight_block_size, - routing_method_type=layer.routing_method_type, - routed_scaling=layer.routed_scaling_factor, - ) - else: - return apply_fi_trtllm_fp8_per_tensor_moe( - layer=layer, - hidden_states=x, - router_logits=router_logits, - routing_bias=layer.e_score_correction_bias, - global_num_experts=layer.global_num_experts, - top_k=layer.top_k, - num_expert_group=layer.num_expert_group, - topk_group=layer.topk_group, - apply_router_weight_on_input=layer.apply_router_weight_on_input, - ) - def apply( self, layer: FusedMoE, @@ -1046,9 +987,9 @@ def apply( topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - assert self.moe_mk is not None assert not self.is_monolithic - return self.moe_mk( + assert self.moe_kernel is not None + return self.moe_kernel.apply( x, layer.w13_weight, layer.w2_weight, diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 999bb6325040..f167e2134470 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -13,7 +13,6 @@ init_fp8_linear_kernel, ) from vllm.model_executor.layers.attention import Attention, MLAAttention -from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEQuantConfig, @@ -24,14 +23,12 @@ FusedMoeWeightScaleSupported, ) from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( - Fp8MoeBackend, convert_to_fp8_moe_kernel_format, make_fp8_moe_kernel, make_fp8_moe_quant_config, select_fp8_moe_backend, ) from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import ( - NvFp4MoeBackend, convert_to_nvfp4_moe_kernel_format, is_global_sf_supported_for_nvfp4_backend, make_nvfp4_moe_kernel, @@ -49,13 +46,6 @@ QuantizeMethodBase, ) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod -from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( - flashinfer_trtllm_fp4_moe, - 
flashinfer_trtllm_fp4_routed_moe, -) -from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( - apply_fi_trtllm_fp8_per_tensor_moe, -) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( W8A8BlockFp8LinearOp, process_fp8_input_tensor_strategy_moe, @@ -746,7 +736,7 @@ def __init__( def maybe_make_prepare_finalize( self, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, - ) -> mk.FusedMoEPrepareAndFinalize | None: + ) -> mk.FusedMoEPrepareAndFinalizeModular | None: raise ValueError( f"{self.__class__.__name__} uses the new modular kernel initialization " "logic. This function should not be called." @@ -754,9 +744,9 @@ def maybe_make_prepare_finalize( def select_gemm_impl( self, - prepare_finalize: mk.FusedMoEPrepareAndFinalize, + prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular, layer: torch.nn.Module, - ) -> mk.FusedMoEPermuteExpertsUnpermute: + ) -> mk.FusedMoEExpertsModular: raise ValueError( f"{self.__class__.__name__} uses the new modular kernel initialization " "logic. This function should not be called." @@ -871,16 +861,15 @@ def _setup_kernel( # Setup modular kernel. self.moe_quant_config = self.get_fused_moe_quant_config(layer) - if self.moe_quant_config: - assert self.experts_cls is not None - self.moe_mk = make_fp8_moe_kernel( - moe_quant_config=self.moe_quant_config, - moe_config=self.moe, - fp8_backend=self.fp8_backend, - experts_cls=self.experts_cls, - routing_tables=layer._maybe_init_expert_routing_tables(), - shared_experts=layer.shared_experts, - ) + assert self.experts_cls is not None + self.moe_kernel = make_fp8_moe_kernel( + moe_quant_config=self.moe_quant_config, + moe_config=self.moe, + fp8_backend=self.fp8_backend, + experts_cls=self.experts_cls, + routing_tables=layer._maybe_init_expert_routing_tables(), + shared_experts=layer.shared_experts, + ) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: w13 = layer.w13_weight @@ -913,9 +902,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale ) - def get_fused_moe_quant_config( - self, layer: torch.nn.Module - ) -> FusedMoEQuantConfig | None: + def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig: w1_scale = layer.w13_weight_scale w2_scale = layer.w2_weight_scale a1_scale = layer.w13_input_scale @@ -929,10 +916,6 @@ def get_fused_moe_quant_config( a2_scale=a2_scale, ) - @property - def is_monolithic(self) -> bool: - return self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM - def apply_monolithic( self, layer: FusedMoE, @@ -940,28 +923,20 @@ def apply_monolithic( router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert self.is_monolithic - assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM - if layer.enable_eplb: - raise NotImplementedError( - "EPLB not supported for FlashInfer TRTLLM FP8 MoE Backend." - ) - # TODO(rob): this validation should happen at kernel selection - # time in the oracle rather than here. - SUPPORTED_ACTIVATIONS = [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL] - assert layer.activation in SUPPORTED_ACTIVATIONS, ( - f"Only {SUPPORTED_ACTIVATIONS} activations are supported for FlashInfer " - f"TRTLLM FP4 MoE, {layer.activation} found instead." 
- ) - return apply_fi_trtllm_fp8_per_tensor_moe( - layer=layer, - hidden_states=x, - router_logits=router_logits, - routing_bias=layer.e_score_correction_bias, + assert self.moe_kernel is not None + return self.moe_kernel.apply_monolithic( + x, + layer.w13_weight, + layer.w2_weight, + router_logits, + activation=layer.activation, global_num_experts=layer.global_num_experts, - top_k=layer.top_k, + expert_map=layer.expert_map, + apply_router_weight_on_input=layer.apply_router_weight_on_input, num_expert_group=layer.num_expert_group, topk_group=layer.topk_group, - apply_router_weight_on_input=layer.apply_router_weight_on_input, + e_score_correction_bias=layer.e_score_correction_bias, + routed_scaling_factor=layer.routed_scaling_factor, ) def apply( @@ -973,25 +948,13 @@ def apply( shared_experts_input: torch.Tensor | None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert not self.is_monolithic - - # TODO(rob): this validation should happen at kernel selection - # time in the oracle rather than here. - if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS: - assert layer.activation in ( - MoEActivation.SILU, - MoEActivation.RELU2_NO_MUL, - ), ( - "Expected activation to be in ('silu', 'relu2_no_mul')," - f"but got {layer.activation}" - ) - - assert self.moe_mk is not None - return self.moe_mk( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, + assert self.moe_kernel is not None + return self.moe_kernel.apply( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, activation=layer.activation, global_num_experts=layer.global_num_experts, expert_map=layer.expert_map, @@ -1235,17 +1198,7 @@ def __init__( def maybe_make_prepare_finalize( self, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, - ) -> mk.FusedMoEPrepareAndFinalize | None: - raise ValueError( - f"{self.__class__.__name__} uses the new modular kernel initialization " - "logic. This function should not be called." - ) - - def select_gemm_impl( - self, - prepare_finalize: mk.FusedMoEPrepareAndFinalize, - layer: torch.nn.Module, - ) -> mk.FusedMoEPermuteExpertsUnpermute: + ) -> mk.FusedMoEPrepareAndFinalizeModular | None: raise ValueError( f"{self.__class__.__name__} uses the new modular kernel initialization " "logic. This function should not be called." @@ -1420,51 +1373,18 @@ def process_weights_after_loading(self, layer: FusedMoE) -> None: replace_parameter(layer, "w2_weight_scale_2", w2_scale_2) replace_parameter(layer, "w2_input_scale", a2_scale) - # Setup modular kernel for TP case and naive DP/EP case. - # In non-naive DP/EP case, we will create a ModularKernelMethod. - # TODO(rob): unify these so FP8MoEMethod owns the ModularKernel - # in both cases. + # Setup modular kernel. 
self.moe_quant_config = self.get_fused_moe_quant_config(layer) - if self.moe_quant_config: - assert self.experts_cls is not None - self.moe_mk = make_nvfp4_moe_kernel( - moe_quant_config=self.moe_quant_config, - moe_config=self.moe, - experts_cls=self.experts_cls, - shared_experts=layer.shared_experts, - routing_tables=layer._maybe_init_expert_routing_tables(), - ) - - @property - def do_post_quant_allgather(self): - return self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM - - def prepare_dp_allgather_tensor( - self, - layer: FusedMoE, - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - ) -> tuple[torch.Tensor, list[torch.Tensor]]: - """Optionally prepare extra tensors to carry through DP allgather/EP.""" - if self.nvfp4_backend != NvFp4MoeBackend.FLASHINFER_TRTLLM: - raise RuntimeError( - "prepare_dp_allgather_tensor is only supported for " - "FlashInfer TRTLLM NVFP4 MoE backend." - ) - - import flashinfer - - hidden_states_fp4, hidden_states_sf = flashinfer.fp4_quantize( - hidden_states, - layer.a1_gscale, - is_sf_swizzled_layout=False, + assert self.experts_cls is not None + self.moe_kernel = make_nvfp4_moe_kernel( + moe_quant_config=self.moe_quant_config, + moe_config=self.moe, + experts_cls=self.experts_cls, + shared_experts=layer.shared_experts, + routing_tables=layer._maybe_init_expert_routing_tables(), ) - extra_tensors: list[torch.Tensor] = [hidden_states_sf] - return hidden_states_fp4, extra_tensors - def get_fused_moe_quant_config( - self, layer: torch.nn.Module - ) -> FusedMoEQuantConfig | None: + def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig: return make_nvfp4_moe_quant_config( backend=self.nvfp4_backend, w13_scale=layer.w13_weight_scale, @@ -1479,13 +1399,6 @@ def get_fused_moe_quant_config( def supports_eplb(self) -> bool: return True - @property - def is_monolithic(self) -> bool: - return ( - self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM - and not self.moe.moe_parallel_config.enable_eplb - ) - def apply_monolithic( self, layer: FusedMoE, @@ -1493,22 +1406,20 @@ def apply_monolithic( router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert self.is_monolithic - assert ( - self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM - and not layer.enable_eplb - ) - - return flashinfer_trtllm_fp4_moe( - layer=layer, - x=x, - router_logits=router_logits, - top_k=layer.top_k, + assert self.moe_kernel is not None + return self.moe_kernel.apply_monolithic( + x, + layer.w13_weight, + layer.w2_weight, + router_logits, activation=layer.activation, global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, + apply_router_weight_on_input=layer.apply_router_weight_on_input, num_expert_group=layer.num_expert_group, topk_group=layer.topk_group, - custom_routing_function=layer.custom_routing_function, e_score_correction_bias=layer.e_score_correction_bias, + routed_scaling_factor=layer.routed_scaling_factor, ) def apply( @@ -1520,33 +1431,19 @@ def apply( shared_experts_input: torch.Tensor | None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert not self.is_monolithic - - # EPLB path - if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM: - assert layer.enable_eplb - return flashinfer_trtllm_fp4_routed_moe( - layer=layer, - x=x, - topk_ids=topk_ids, - topk_weights=topk_weights, - top_k=layer.top_k, - activation=layer.activation, - global_num_experts=layer.global_num_experts, - ) - else: - assert self.moe_mk is not None - return self.moe_mk( - hidden_states=x, - 
w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - activation=layer.activation, - global_num_experts=layer.global_num_experts, - expert_map=layer.expert_map, - apply_router_weight_on_input=layer.apply_router_weight_on_input, - shared_experts_input=shared_experts_input, - ) + assert self.moe_kernel is not None + return self.moe_kernel.apply( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + activation=layer.activation, + global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + shared_experts_input=shared_experts_input, + ) ModelOptNvFp4Config.LinearMethodCls = ModelOptNvFp4LinearMethod diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 8856eb1e2e49..97d60178c849 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -266,7 +266,7 @@ def __init__(self, moe: FusedMoEConfig): ) self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {} # Initialized in process_weights_after_loading for CUTLASS/SM90 backends - self.moe_mk: mk.FusedMoEModularKernel | None = None + self.moe_kernel: mk.FusedMoEKernel | None = None def create_weights( self, @@ -440,7 +440,7 @@ def process_weights_after_loading(self, layer): ) assert prepare_finalize is not None - self.moe_mk = mk.FusedMoEModularKernel( + self.moe_kernel = mk.FusedMoEKernel( prepare_finalize, MarlinExperts( self.moe, @@ -789,7 +789,7 @@ def _interleave_mxfp4_cutlass_sm90(w): ) assert prepare_finalize is not None - self.moe_mk = mk.FusedMoEModularKernel( + self.moe_kernel = mk.FusedMoEKernel( prepare_finalize, FlashInferExperts( moe_config=self.moe, @@ -954,9 +954,9 @@ def get_fused_moe_quant_config( def select_gemm_impl( self, - prepare_finalize: mk.FusedMoEPrepareAndFinalize, + prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular, layer: torch.nn.Module, - ) -> mk.FusedMoEPermuteExpertsUnpermute: + ) -> mk.FusedMoEExpertsModular: if ( prepare_finalize.activation_format == mk.FusedMoEActivationFormat.BatchedExperts @@ -1043,8 +1043,8 @@ def apply( or self.mxfp4_backend == Mxfp4Backend.MARLIN ) - assert self.moe_mk is not None - return self.moe_mk( + assert self.moe_kernel is not None + return self.moe_kernel.apply( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py index fadf56be1d4e..42677a5927b3 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py @@ -6,28 +6,18 @@ import torch -import vllm.model_executor.layers.fused_moe.modular_kernel as mk -from vllm import _custom_ops as ops +import vllm.envs as envs from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe.activation import MoEActivation -from vllm.model_executor.layers.fused_moe.config import ( - FusedMoEConfig, - FusedMoEParallelConfig, - RoutingMethodType, -) from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( - activation_to_flashinfer_int, align_fp4_moe_weights_for_fi, ) from vllm.model_executor.layers.quantization.utils.nvfp4_utils import ( swizzle_blockscale, ) -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - QuantKey, - kNvfp4Dynamic, - kNvfp4Static, -) from 
vllm.platforms import current_platform +from vllm.utils.flashinfer import ( + has_flashinfer_cutlass_fused_moe, +) if TYPE_CHECKING: from vllm.model_executor.layers.fused_moe.layer import FusedMoE @@ -42,92 +32,15 @@ "reorder_w1w3_to_w3w1", ] -# -# Methods used by the oracle for kernel selection. -# - - -def _supports_current_device() -> bool: - """Supports only Blackwell-family GPUs.""" - p = current_platform - return p.is_cuda() and p.is_device_capability_family(100) - - -def _supports_no_act_and_mul() -> bool: - """Supports non-gated MoE.""" - return True - - -def _supports_quant_scheme( - weight_key: QuantKey | None, - activation_key: QuantKey | None, -) -> bool: - """Supports Nvfp4 quantization.""" - SUPPORTED_W_A = [ - (kNvfp4Static, kNvfp4Dynamic), - ] - return (weight_key, activation_key) in SUPPORTED_W_A - - -def _supports_activation(activation: MoEActivation) -> bool: - return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL] - - -def _supports_routing_method( - routing_method: RoutingMethodType, -) -> bool: - """Monolithic kernels need to express router support.""" - # NOTE(rob): potentially allow others here. This is a conservative list. - return routing_method in [ - RoutingMethodType.DeepSeekV3, - RoutingMethodType.Renormalize, - RoutingMethodType.RenormalizeNaive, - RoutingMethodType.Llama4, - ] - - -def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: - """ - TRTLLM is a monolithic kernel that requires dispatch_router_logits() for - the naive dispatch/combine path. DeepEP HT only implements dispatch() for - the modular kernel path, so TRTLLM is incompatible with DeepEP HT. - """ - return not moe_parallel_config.use_deepep_ht_kernels - - -def is_supported_config_trtllm( - moe_config: FusedMoEConfig, - weight_key: QuantKey | None, - activation_key: QuantKey | None, - activation_format: mk.FusedMoEActivationFormat, -) -> tuple[bool, str | None]: - """ - This method mirrors mk.FusedMoEPermuteExpertsUnpermute.is_supported_config - """ - - def _make_reason(reason: str) -> str: - return f"kernel does not support {reason}" - - if not _supports_current_device(): - return False, _make_reason(f"current device {current_platform.device_name}") - elif not (moe_config.is_act_and_mul or _supports_no_act_and_mul()): - return False, _make_reason("no act_and_mul MLP layer") - elif not _supports_activation(moe_config.activation): - return False, _make_reason(f"{moe_config.activation} activation") - elif not _supports_quant_scheme(weight_key, activation_key): - return False, _make_reason(f"quantization scheme {weight_key}x{activation_key}") - elif not _supports_parallel_config(moe_config.moe_parallel_config): - return False, _make_reason(f"parallel config {moe_config.moe_parallel_config}") - elif not _supports_routing_method(moe_config.routing_method): - return False, _make_reason(f"routing method {moe_config.routing_method}") - elif activation_format != mk.FusedMoEActivationFormat.Standard: - return False, _make_reason(f"activation format {activation_format}") - elif moe_config.hidden_dim % 512 != 0: - return False, _make_reason( - f"hidden_dim must be divisible by 512, found {moe_config.hidden_dim}" - ) - return True, None +def is_flashinfer_fp4_cutlass_moe_available() -> bool: + """Return `True` when FlashInfer CUTLASS NV-FP4 kernels can be used.""" + return ( + envs.VLLM_USE_FLASHINFER_MOE_FP4 + and has_flashinfer_cutlass_fused_moe() + and current_platform.is_cuda() + and current_platform.has_device_capability(100) + ) def reorder_w1w3_to_w3w1( @@ 
-276,190 +189,6 @@ def prepare_static_weights_for_trtllm_fp4_moe( ) -def flashinfer_trtllm_fp4_moe( - layer: torch.nn.Module, - x: torch.Tensor | tuple[torch.Tensor, torch.Tensor], - router_logits: torch.Tensor, - top_k: int, - activation: MoEActivation, - global_num_experts: int, - num_expert_group: int | None, - topk_group: int | None, - custom_routing_function: object | None, - e_score_correction_bias: torch.Tensor | None, -) -> torch.Tensor: - """ - Apply FlashInfer TensorRT-LLM FP4 MoE kernel. - - Args: - layer: The MoE layer with weights and scales - x: Input tensor - router_logits: Router logits for expert selection - top_k: Number of experts to select per token - activation: Activation function to use - global_num_experts: Total number of experts across all ranks - num_expert_group: Number of expert groups (for grouped routing) - topk_group: Top-k within each group - custom_routing_function: Custom routing function (e.g., Llama4) - e_score_correction_bias: Optional routing bias correction - - Returns: - Output tensor from the MoE layer - """ - import flashinfer - - from vllm.model_executor.models.llama4 import Llama4MoE - - SUPPORTED_ACTIVATIONS = [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL] - assert activation in SUPPORTED_ACTIVATIONS, ( - f"Only {SUPPORTED_ACTIVATIONS} activations are supported for FlashInfer " - f"TRTLLM FP4 MoE, {activation} found instead." - ) - - # Quantize input to FP4 - if isinstance(x, tuple): - hidden_states_fp4, hidden_states_scale_linear_fp4 = x - else: - # hidden_states is the already quantized - (hidden_states_fp4, hidden_states_scale_linear_fp4) = ops.scaled_fp4_quant( - x, layer.a1_gscale, is_sf_swizzled_layout=False - ) - - # Determine routing method type - use_llama4_routing = custom_routing_function is Llama4MoE.custom_routing_function - routing_method_type = layer.routing_method_type - if use_llama4_routing: - routing_method_type = flashinfer.RoutingMethodType.Llama4 - - # Cast to Fp32 (required by kernel). 
- router_logits = ( - router_logits.to(torch.float32) - if routing_method_type == RoutingMethodType.DeepSeekV3 - else router_logits - ) - - # Determine activation type - activation_type = activation_to_flashinfer_int(layer.activation) - - # Call TRT-LLM FP4 block-scale MoE kernel - out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe( - routing_logits=router_logits, - routing_bias=e_score_correction_bias, - hidden_states=hidden_states_fp4, - hidden_states_scale=hidden_states_scale_linear_fp4.view( - torch.float8_e4m3fn - ).reshape(*hidden_states_fp4.shape[:-1], -1), - gemm1_weights=layer.w13_weight.data, - gemm1_weights_scale=layer.w13_weight_scale.data.view(torch.float8_e4m3fn), - gemm1_bias=None, - gemm1_alpha=None, - gemm1_beta=None, - gemm1_clamp_limit=None, - gemm2_weights=layer.w2_weight.data, - gemm2_weights_scale=layer.w2_weight_scale.data.view(torch.float8_e4m3fn), - gemm2_bias=None, - output1_scale_scalar=layer.g1_scale_c.data, - output1_scale_gate_scalar=layer.g1_alphas.data, - output2_scale_scalar=layer.g2_alphas.data, - num_experts=global_num_experts, - top_k=top_k, - n_group=num_expert_group if num_expert_group is not None else 0, - topk_group=topk_group if topk_group is not None else 0, - intermediate_size=layer.intermediate_size_per_partition, - local_expert_offset=layer.ep_rank * layer.local_num_experts, - local_num_experts=layer.local_num_experts, - routed_scaling_factor=None, - routing_method_type=routing_method_type, - do_finalize=True, - activation_type=activation_type, - )[0] - - return out - - -def flashinfer_trtllm_fp4_routed_moe( - layer: torch.nn.Module, - x: torch.Tensor, - topk_ids: torch.Tensor, - topk_weights: torch.Tensor, - top_k: int, - activation: MoEActivation, - global_num_experts: int, -) -> torch.Tensor: - """ - Apply FlashInfer TensorRT-LLM FP4 MoE kernel. Uses packed - input top k expert indices and scores rather than computing - top k expert indices from scores. - - Args: - layer: The MoE layer with weights and scales - x: Input tensor - topk_ids: Ids of selected experts - top_k: Number of experts to select per token - activation: Activation function to use - global_num_experts: Total number of experts across all ranks - - Returns: - Output tensor from the MoE layer - """ - import flashinfer - - # https://github.com/flashinfer-ai/flashinfer/blob/f0277fd1bff90e309e5c19cab36c5dae056d685d/flashinfer/fused_moe/core.py#L2535 - assert activation == MoEActivation.SILU, ( - "Only SiLU activation is supported for FlashInfer TRTLLM FP4 Routed MoE. " - f"{activation} found instead." 
- ) - - # Pack top k ids and expert weights into a single int32 tensor, as - # required by TRT-LLM - packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to( - torch.bfloat16 - ).view(torch.int16) - - if isinstance(x, tuple): - # Hidden_states is the already quantized - hidden_states_fp4, hidden_states_scale_linear_fp4 = x - else: - # Quantize input to FP4 - (hidden_states_fp4, hidden_states_scale_linear_fp4) = ops.scaled_fp4_quant( - x, layer.a1_gscale, is_sf_swizzled_layout=False - ) - - # Call TRT-LLM FP4 block-scale MoE kernel - out = flashinfer.fused_moe.trtllm_fp4_block_scale_routed_moe( - topk_ids=packed_tensor, - routing_bias=None, - hidden_states=hidden_states_fp4, - hidden_states_scale=hidden_states_scale_linear_fp4.view( - torch.float8_e4m3fn - ).reshape(*hidden_states_fp4.shape[:-1], -1), - gemm1_weights=layer.w13_weight.data, - gemm1_weights_scale=layer.w13_weight_scale.data.view(torch.float8_e4m3fn), - gemm1_bias=None, - gemm1_alpha=None, - gemm1_beta=None, - gemm1_clamp_limit=None, - gemm2_weights=layer.w2_weight.data, - gemm2_weights_scale=layer.w2_weight_scale.data.view(torch.float8_e4m3fn), - gemm2_bias=None, - output1_scale_scalar=layer.g1_scale_c.data, - output1_scale_gate_scalar=layer.g1_alphas.data, - output2_scale_scalar=layer.g2_alphas.data, - num_experts=global_num_experts, - top_k=top_k, - n_group=0, - topk_group=0, - intermediate_size=layer.intermediate_size_per_partition, - local_expert_offset=layer.ep_rank * layer.local_num_experts, - local_num_experts=layer.local_num_experts, - routed_scaling_factor=None, - routing_method_type=1, - do_finalize=True, - )[0] - - return out - - def prepare_nvfp4_moe_layer_for_fi_or_cutlass( backend: "NvFp4MoeBackend", layer: "FusedMoE", @@ -526,6 +255,7 @@ def prepare_nvfp4_moe_layer_for_fi_or_cutlass( ) ) layer.intermediate_size_per_partition = padded_intermediate + layer.moe_config.intermediate_size_per_partition = padded_intermediate w13, w13_scale, w2, w2_scale = prepare_static_weights_for_trtllm_fp4_moe( w13, diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py index 3d7d8e68fdcd..a8be1d61ac24 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from enum import Enum +from typing import TYPE_CHECKING import torch @@ -10,6 +11,9 @@ from vllm.platforms import current_platform from vllm.utils.math_utils import round_up +if TYPE_CHECKING: + from flashinfer.fused_moe.core import ActivationType + logger = init_logger(__name__) @@ -20,6 +24,10 @@ class FlashinferMoeBackend(Enum): def activation_to_flashinfer_int(activation: MoEActivation) -> int: + return activation_to_flashinfer_type(activation).value + + +def activation_to_flashinfer_type(activation: MoEActivation) -> "ActivationType": from flashinfer.fused_moe.core import ActivationType # silu and gelu are mapped to their gated versions SwiGLU and GeGLU respectively @@ -30,7 +38,7 @@ def activation_to_flashinfer_int(activation: MoEActivation) -> int: MoEActivation.GELU: ActivationType.Geglu, MoEActivation.RELU2_NO_MUL: ActivationType.Relu2, } - return ACTIVATION_TO_FI_ACTIVATION[activation].value + return ACTIVATION_TO_FI_ACTIVATION[activation] def swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor: @@ -87,104 +95,6 @@ def 
rotate_weights_for_fi_trtllm_fp8_per_tensor_moe( ) -def register_scales_for_trtllm_fp8_per_tensor_moe( - layer: torch.nn.Module, - w13_scale: torch.Tensor, - w13_input_scale: torch.Tensor, - w2_scale: torch.Tensor, - w2_input_scale: torch.Tensor, -) -> None: - """Register necessary scales for FlashInfer TRTLLM FP8 MoE kernel""" - g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi( - w13_scale=w13_scale, - w13_input_scale=w13_input_scale, - w2_scale=w2_scale, - w2_input_scale=w2_input_scale, - ) - layer.w2_input_scale_inv = 1.0 / w2_input_scale - layer.output1_scales_gate_scalar = g1_alphas - - if layer.activation.is_gated: - layer.output1_scales_scalar = g1_alphas * layer.w2_input_scale_inv - else: - layer.output1_scales_scalar = ( - torch.ones_like(g1_alphas) * layer.w2_input_scale_inv - ) - layer.output2_scales_scalar = g2_alphas - - -def apply_fi_trtllm_fp8_per_tensor_moe( - layer: torch.nn.Module, - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - routing_bias: torch.Tensor | None, - top_k: int, - num_expert_group: int | None, - topk_group: int | None, - global_num_experts: int, - apply_router_weight_on_input: bool, -) -> torch.Tensor: - from flashinfer.fused_moe import RoutingMethodType - - import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401 - from vllm.model_executor.models.llama4 import Llama4MoE - - # Added to the layer by: register_scales_for_trtllm_fp8_per_tensor_moe - assert ( - hasattr(layer, "output1_scales_scalar") - and hasattr(layer, "output1_scales_gate_scalar") - and hasattr(layer, "output2_scales_scalar") - ) - - if layer.routing_method_type == RoutingMethodType.Llama4: - assert ( - not layer.renormalize - and layer.custom_routing_function == Llama4MoE.custom_routing_function - ), ( - "FusedMoE flashinfer kernels with Llama4 routing method are only " - "supported for Llama4" - ) - else: - assert layer.custom_routing_function is None, ( - "Custom routing function is only supported for Llama4" - ) - activation_type = activation_to_flashinfer_int(layer.activation) - - return torch.ops.vllm.fi_trtllm_fp8_per_tensor_moe( - routing_logits=router_logits, - routing_bias=routing_bias, - hidden_states=hidden_states, - input_scale=layer.w13_input_scale, - gemm1_weights=layer.w13_weight, - gemm2_weights=layer.w2_weight, - output1_scales_scalar=layer.output1_scales_scalar, - output1_scales_gate_scalar=layer.output1_scales_gate_scalar, - output2_scales_scalar=layer.output2_scales_scalar, - num_experts=global_num_experts, - top_k=top_k, - num_expert_group=num_expert_group, - topk_group=topk_group, - intermediate_size=layer.intermediate_size_per_partition, - local_expert_offset=layer.ep_rank * layer.local_num_experts, - local_num_experts=layer.local_num_experts, - use_routing_scales_on_input=apply_router_weight_on_input, - routing_method_type=layer.routing_method_type, - activation_type=activation_type, - ) - - -def make_fp8_moe_alpha_scales_for_fi( - w13_scale: torch.Tensor, - w13_input_scale: torch.Tensor, - w2_scale: torch.Tensor, - w2_input_scale: torch.Tensor, -) -> tuple[torch.Tensor, torch.Tensor]: - g1_alphas = (w13_scale * w13_input_scale).squeeze() - g2_alphas = (w2_scale * w2_input_scale).squeeze() - - return g1_alphas, g2_alphas - - def get_flashinfer_moe_backend() -> FlashinferMoeBackend: backend_map = { "throughput": FlashinferMoeBackend.CUTLASS, @@ -432,6 +342,7 @@ def prepare_fp8_moe_layer_for_fi( min_alignment, ) layer.intermediate_size_per_partition = new_intermediate + layer.moe_config.intermediate_size_per_partition 
= new_intermediate # FI kernels require W31 layout rather than W13. if layer.moe_config.is_act_and_mul: @@ -440,20 +351,12 @@ def prepare_fp8_moe_layer_for_fi( w13_scale = swap_w13_to_w31(w13_scale) # FI TRT-LLM FP8 per-tensor MoE kernel requires weight shuffle - # and registration of alpha scales. Note that we do not register - # as nn.Parameters since they are not needed for weight-reloading. + # and registration of alpha scales. if is_trtllm and not block_quant: assert w13_input_scale is not None assert w2_input_scale is not None rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(w13, w2, is_gated) - register_scales_for_trtllm_fp8_per_tensor_moe( - layer, - w13_scale=w13_scale, - w13_input_scale=w13_input_scale, - w2_scale=w2_scale, - w2_input_scale=w2_input_scale, - ) # Clamp block scales to avoid NaN from the FlashInfer CUTLASS kernel. # Some FP8 models have near-zero block scales (~1e-23) for dead/unused diff --git a/vllm/model_executor/warmup/deep_gemm_warmup.py b/vllm/model_executor/warmup/deep_gemm_warmup.py index f7df8f81347d..41854b628133 100644 --- a/vllm/model_executor/warmup/deep_gemm_warmup.py +++ b/vllm/model_executor/warmup/deep_gemm_warmup.py @@ -172,7 +172,7 @@ def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool: # Further check if the ModularKernel implementation uses the DeepGemmExperts return isinstance( - module.quant_method.moe_mk, (DeepGemmExperts, TritonOrDeepGemmExperts) + module.quant_method.moe_kernel, (DeepGemmExperts, TritonOrDeepGemmExperts) ) diff --git a/vllm/model_executor/warmup/kernel_warmup.py b/vllm/model_executor/warmup/kernel_warmup.py index 1ba5981906ca..70abd8a6c503 100644 --- a/vllm/model_executor/warmup/kernel_warmup.py +++ b/vllm/model_executor/warmup/kernel_warmup.py @@ -88,9 +88,14 @@ def flashinfer_autotune(runner: "GPUModelRunner") -> None: Without autotuning, FlashInfer will rely on heuristics, which may be significantly slower. """ - from vllm.utils.flashinfer import autotune + import vllm.utils.flashinfer as fi_utils + + with torch.inference_mode(), fi_utils.autotune(): + # Certain FlashInfer kernels (e.g. nvfp4 routed moe) are + # incompatible with autotuning. This state is used to skip + # those kernels during the autotuning process. + fi_utils._is_fi_autotuning = True - with torch.inference_mode(), autotune(): # We skip EPLB here since we don't want to record dummy metrics # When autotuning with number of tokens m, flashinfer will autotune # operations for all number of tokens up to m. 
@@ -100,3 +105,5 @@ def flashinfer_autotune(runner: "GPUModelRunner") -> None: skip_eplb=True, is_profile=True, ) + + fi_utils._is_fi_autotuning = False diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 8ed9e11187cd..c3ac839c21d1 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -140,6 +140,7 @@ def wrapper(*args, **kwargs): "autotune", fallback_fn=lambda *args, **kwargs: contextlib.nullcontext(), ) +_is_fi_autotuning: bool = False @functools.cache From 3a8eef5869b8997af22f7b204eba56f9e654875e Mon Sep 17 00:00:00 2001 From: Rohan Potdar <66227218+Rohan138@users.noreply.github.com> Date: Tue, 3 Mar 2026 13:43:56 -0600 Subject: [PATCH 32/53] [ROCm][Bugfix]: Disable AITER Triton ROPE by default (#35601) Signed-off-by: Rohan138 --- vllm/envs.py | 6 +++--- vllm/platforms/rocm.py | 5 ++--- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index 8c6eef3e7770..02fcd998a031 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -106,7 +106,7 @@ VLLM_ROCM_USE_AITER_MLA: bool = True VLLM_ROCM_USE_AITER_MHA: bool = True VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: bool = False - VLLM_ROCM_USE_AITER_TRITON_ROPE: bool = True + VLLM_ROCM_USE_AITER_TRITON_ROPE: bool = False VLLM_ROCM_USE_AITER_FP8BMM: bool = True VLLM_ROCM_USE_AITER_FP4BMM: bool = True VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: bool = False @@ -949,9 +949,9 @@ def _get_or_set_default() -> str: os.getenv("VLLM_ROCM_USE_AITER_FP4_ASM_GEMM", "False").lower() in ("true", "1") ), # Whether to use aiter rope. - # By default is enabled. + # By default is disabled. "VLLM_ROCM_USE_AITER_TRITON_ROPE": lambda: ( - os.getenv("VLLM_ROCM_USE_AITER_TRITON_ROPE", "True").lower() in ("true", "1") + os.getenv("VLLM_ROCM_USE_AITER_TRITON_ROPE", "False").lower() in ("true", "1") ), # Whether to use aiter triton fp8 bmm kernel # By default is enabled. diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index ab4c3e0740a9..94675e3c96be 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -592,7 +592,6 @@ def apply_config_platform_defaults(cls, vllm_config: "VllmConfig") -> None: use_aiter_rms_norm = rocm_aiter_ops.is_rmsnorm_enabled() use_aiter_fp8_linear = rocm_aiter_ops.is_linear_fp8_enabled() use_aiter_fused_se = rocm_aiter_ops.is_fusion_moe_shared_experts_enabled() - use_aiter_triton_rope = rocm_aiter_ops.is_triton_rotary_embed_enabled() # Aiter rms norm perform best when CUDA Graph capture is enabled. 
if ( use_aiter_rms_norm @@ -619,9 +618,9 @@ def apply_config_platform_defaults(cls, vllm_config: "VllmConfig") -> None: and "-grouped_topk" not in compilation_config.custom_ops ): compilation_config.custom_ops.append("+grouped_topk") - # Enable rotary embedding when using AITER if its not disabled by user + # Enable rotary embedding customop when using AITER if not disabled by user if ( - use_aiter_triton_rope + rocm_aiter_ops.is_enabled() and "+rotary_embedding" not in compilation_config.custom_ops and "-rotary_embedding" not in compilation_config.custom_ops ): From e7213003cbf64d3f35b97d711eb595aa9e47039c Mon Sep 17 00:00:00 2001 From: Micah Williamson Date: Tue, 3 Mar 2026 14:57:34 -0600 Subject: [PATCH 33/53] [ROCm][CI] Fix TP size issue for `test_gpt_oss` (#35887) Signed-off-by: Micah Williamson --- tests/models/quantization/test_gpt_oss.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/models/quantization/test_gpt_oss.py b/tests/models/quantization/test_gpt_oss.py index 6fab653d009a..7599a5a5ee4c 100644 --- a/tests/models/quantization/test_gpt_oss.py +++ b/tests/models/quantization/test_gpt_oss.py @@ -21,6 +21,8 @@ import pytest from packaging import version +from vllm.utils.torch_utils import cuda_device_count_stateless + MODEL_ACCURACIES = { # Full quantization: attention linears and MoE linears "amd/gpt-oss-20b-WFP8-AFP8-KVFP8": 0.89, @@ -83,6 +85,9 @@ def get_model_args(self, tp_size: int): def test_gpt_oss_attention_quantization( model_name: str, tp_size: int, expected_accuracy: float ): + if tp_size > cuda_device_count_stateless(): + pytest.skip("Not enough GPUs to run this test case") + model_args = EvaluationConfig(model_name).get_model_args(tp_size) extra_run_kwargs = { From a9b8b13e5cdc52aa7f4472d4d21f178e3805bcdd Mon Sep 17 00:00:00 2001 From: bnellnm <49004751+bnellnm@users.noreply.github.com> Date: Tue, 3 Mar 2026 16:29:57 -0500 Subject: [PATCH 34/53] [Bugfix] Fix misnamed parameter in compressed_tensors_moe.py (#35813) Signed-off-by: Bill Nell Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> --- .../quantization/compressed_tensors/compressed_tensors_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 8b7fc57d0409..f6c0009a5a41 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -887,7 +887,7 @@ def process_weights_after_loading(self, layer: FusedMoE) -> None: w13, w13_scale, shard_size=layer.intermediate_size_per_partition, - num_experts=layer.num_local_experts, + num_experts=layer.local_num_experts, is_act_and_mul=self.moe.is_act_and_mul, ) From 467886a0c48b37552c8a2f3bdea99e96f2e98f8c Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 3 Mar 2026 13:47:45 -0800 Subject: [PATCH 35/53] [Model Runner V2] Fix inputs_embeds=None bug for MM models (#35917) Signed-off-by: Woosuk Kwon --- vllm/v1/worker/gpu/model_runner.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index 35dd617eeba0..17a5be7d7060 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -907,9 +907,11 @@ def execute_model( ) inputs_embeds = None - if self.supports_mm_inputs and 
self.is_first_pp_rank and not dummy_run: + if self.supports_mm_inputs and self.is_first_pp_rank: # Run MM encoder (if needed) and get multimodal embeddings. # Only first PP rank prepares multimodal embeddings. + # NOTE(woosuk): We must call get_mm_embeddings even during dummy runs + # to obtain inputs_embeds, because the compiled model expects this input. inputs_embeds = self.model_state.get_mm_embeddings( scheduler_output.scheduled_encoder_inputs, input_batch, From 12b38c0f4560e33b32cd5fbe50881d4d2e97470e Mon Sep 17 00:00:00 2001 From: Amr Mahdi Date: Tue, 3 Mar 2026 14:30:47 -0800 Subject: [PATCH 36/53] [CI/Build] Allow mounting AWS credentials for sccache S3 auth (#35912) Signed-off-by: Amr Mahdi --- docker/Dockerfile | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docker/Dockerfile b/docker/Dockerfile index 495a480b7582..ac6494ae9e58 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -262,7 +262,9 @@ RUN --mount=type=cache,target=/root/.cache/uv \ # Build the vLLM wheel # if USE_SCCACHE is set, use sccache to speed up compilation +# AWS credentials mounted at ~/.aws/credentials for sccache S3 auth (optional) RUN --mount=type=cache,target=/root/.cache/uv \ + --mount=type=secret,id=aws-credentials,target=/root/.aws/credentials,required=false \ if [ "$USE_SCCACHE" = "1" ]; then \ echo "Installing sccache..." \ && case "${TARGETPLATFORM}" in \ From 97286a20ed5803583c50af3dd1f45268346be0e8 Mon Sep 17 00:00:00 2001 From: zhrrr <43847754+izhuhaoran@users.noreply.github.com> Date: Wed, 4 Mar 2026 07:19:45 +0800 Subject: [PATCH 37/53] [Model Runner V2] support dp & ep for spec decoding (#35294) Signed-off-by: Giancarlo Delfin Signed-off-by: zhuhaoran Co-authored-by: Giancarlo Delfin --- vllm/v1/worker/gpu/model_runner.py | 56 +++++++--- .../worker/gpu/spec_decode/eagle/cudagraph.py | 20 ++++ .../gpu/spec_decode/eagle/speculator.py | 105 +++++++++++------- 3 files changed, 124 insertions(+), 57 deletions(-) diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index 17a5be7d7060..9267e187415f 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -57,10 +57,7 @@ from vllm.v1.worker.gpu.buffer_utils import async_copy_to_gpu from vllm.v1.worker.gpu.cp_utils import prepare_dcp_local_seq_lens from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager -from vllm.v1.worker.gpu.dp_utils import ( - get_cudagraph_and_dp_padding, - make_num_tokens_across_dp, -) +from vllm.v1.worker.gpu.dp_utils import get_cudagraph_and_dp_padding from vllm.v1.worker.gpu.input_batch import ( InputBatch, InputBuffers, @@ -265,7 +262,7 @@ def load_model(self, *args, **kwargs) -> None: prepare_communication_buffer_for_model(self.model) if self.speculator is not None: - prepare_communication_buffer_for_model(self.speculator) + prepare_communication_buffer_for_model(self.speculator.model) # Initialize the components that require the model. self.model_state = init_model_state( @@ -382,8 +379,41 @@ def _dummy_run( return None, None assert self.execute_model_state is not None - input_batch, _, _, _, hidden_states, _, _ = self.execute_model_state + ( + input_batch, + model_inputs, + attn_metadata, + slot_mappings_by_layer, + hidden_states, + aux_hidden_states, + kv_connector_output, + num_tokens_across_dp, + ) = self.execute_model_state self.execute_model_state = None + + # dummy run the eagle speculator's propose to ensure DP/EP sync. 
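A rough standalone illustration of why the dummy run also has to drive the speculator under data/expert parallelism: collectives are matched by call order across ranks, so an idle rank that skips propose() would leave its peers blocked. The helper below is hypothetical bookkeeping, not a vLLM API.

    def collectives_per_step(has_requests: bool, run_dummy_speculator: bool) -> int:
        calls = 1  # target-model forward (all-reduce / all-to-all in EP layers)
        if has_requests or run_dummy_speculator:
            calls += 1  # speculator forward issues its own collectives
        return calls

    # Rank 0 has real work this step, rank 1 is padding-only.
    skipping = [collectives_per_step(rank == 0, run_dummy_speculator=False) for rank in (0, 1)]
    lockstep = [collectives_per_step(rank == 0, run_dummy_speculator=True) for rank in (0, 1)]

    assert len(set(skipping)) > 1   # mismatched call counts -> ranks would hang
    assert len(set(lockstep)) == 1  # every rank issues the same collectives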
+ if self.speculator is not None: + self.speculator.propose( + input_batch=input_batch, + attn_metadata=attn_metadata, + slot_mappings=slot_mappings_by_layer, + last_hidden_states=hidden_states, + aux_hidden_states=aux_hidden_states, + num_sampled=torch.ones( + input_batch.num_reqs, dtype=torch.int32, device=self.device + ), + num_rejected=torch.zeros( + input_batch.num_reqs, dtype=torch.int32, device=self.device + ), + last_sampled=self.req_states.last_sampled_tokens, + next_prefill_tokens=self.req_states.next_prefill_tokens, + temperature=self.sampler.sampling_states.temperature.gpu, + seeds=self.sampler.sampling_states.seeds.gpu, + num_tokens_across_dp=num_tokens_across_dp, + dummy_run=True, + skip_attn_for_dummy_run=skip_attn, + ) + assert hidden_states is not None # Last PP rank always has hidden_states sample_hidden_states = hidden_states[input_batch.logits_indices] return hidden_states, sample_hidden_states @@ -431,17 +461,6 @@ def profile_run(self) -> None: else: self._dummy_pooler_run(hidden_states) - if self.speculator is not None: - num_tokens_across_dp = make_num_tokens_across_dp( - self.parallel_config.data_parallel_size, self.max_num_tokens - ) - self.speculator.run_model( - self.max_num_tokens, - attn_metadata=None, - slot_mappings=None, - num_tokens_across_dp=num_tokens_across_dp, - ) - torch.cuda.synchronize() del hidden_states, sample_hidden_states gc.collect() @@ -979,6 +998,7 @@ def execute_model( hidden_states, aux_hidden_states, kv_connector_output, + num_tokens_across_dp, ) if not self.is_last_pp_rank: @@ -1005,6 +1025,7 @@ def sample_tokens( hidden_states, aux_hidden_states, kv_connector_output, + num_tokens_across_dp, ) = self.execute_model_state self.execute_model_state = None @@ -1078,6 +1099,7 @@ def sample_tokens( self.req_states.next_prefill_tokens, self.sampler.sampling_states.temperature.gpu, self.sampler.sampling_states.seeds.gpu, + num_tokens_across_dp=num_tokens_across_dp, ) self.req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens self.draft_tokens_handler.set_draft_tokens(input_batch, draft_tokens) diff --git a/vllm/v1/worker/gpu/spec_decode/eagle/cudagraph.py b/vllm/v1/worker/gpu/spec_decode/eagle/cudagraph.py index 77dddf3ada1c..157ed1182485 100644 --- a/vllm/v1/worker/gpu/spec_decode/eagle/cudagraph.py +++ b/vllm/v1/worker/gpu/spec_decode/eagle/cudagraph.py @@ -55,6 +55,26 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): def get_cudagraph_size(self, num_tokens: int) -> int | None: return self.cudagraph_sizes.get(num_tokens) + def get_cudagraph_runtime_mode( + self, num_tokens: int + ) -> tuple[CUDAGraphMode, int | None]: + cudagraph_size = self.get_cudagraph_size(num_tokens) + if cudagraph_size is None: + cudagraph_mode = CUDAGraphMode.NONE + else: + cudagraph_mode = self.cudagraph_mode + + if ( + cudagraph_mode == CUDAGraphMode.FULL + and cudagraph_size is not None + and cudagraph_size not in self.graphs + ): + # If graph wasn't captured yet, fall back to eager. + # This might happen when the dummy run is called before capture. 
+ cudagraph_mode = CUDAGraphMode.NONE + cudagraph_size = None + return cudagraph_mode, cudagraph_size + def capture_graph( self, num_tokens: int, diff --git a/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py b/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py index 9ea84386bdce..9185850dcb62 100644 --- a/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py +++ b/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py @@ -16,6 +16,7 @@ build_slot_mappings_by_layer, ) from vllm.v1.worker.gpu.block_table import BlockTables +from vllm.v1.worker.gpu.dp_utils import get_cudagraph_and_dp_padding from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers from vllm.v1.worker.gpu.model_states.interface import ModelState from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample @@ -48,6 +49,10 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self.vocab_size = self.draft_model_config.get_vocab_size() self.dtype = vllm_config.model_config.dtype + # DP configuration + self.dp_size = vllm_config.parallel_config.data_parallel_size + self.dp_rank = vllm_config.parallel_config.data_parallel_rank + self.input_buffers = InputBuffers( max_num_reqs=self.max_num_reqs, max_num_tokens=self.max_num_tokens, @@ -122,8 +127,8 @@ def generate_draft( self, num_reqs: int, num_tokens_padded: int, - attn_metadata: dict[str, Any], - slot_mappings: dict[str, torch.Tensor], + attn_metadata: dict[str, Any] | None, + slot_mappings: dict[str, torch.Tensor] | None, num_tokens_across_dp: torch.Tensor | None, cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, ) -> None: @@ -164,9 +169,10 @@ def generate_draft( self.hidden_states, self.max_model_len, ) - self.block_tables.compute_slot_mappings( - idx_mapping, query_start_loc, pos - ) + if attn_metadata is not None: + self.block_tables.compute_slot_mappings( + idx_mapping, query_start_loc, pos + ) def capture_model(self) -> None: if self.num_speculative_steps == 1: @@ -203,6 +209,9 @@ def propose( temperature: torch.Tensor, # [max_num_reqs] seeds: torch.Tensor, + num_tokens_across_dp: torch.Tensor | None = None, + dummy_run: bool = False, + skip_attn_for_dummy_run: bool = False, ) -> torch.Tensor: # NOTE(woosuk): To avoid CPU-GPU synchronization without CPU knowing the # number of rejected tokens, we maintain the size of eagle's input_ids and @@ -236,7 +245,7 @@ def propose( num_tokens, attn_metadata, slot_mappings, - num_tokens_across_dp=None, # FIXME + num_tokens_across_dp=num_tokens_across_dp, ) sample_hidden_states = last_hidden_states[last_token_indices] logits = self.model.compute_logits(sample_hidden_states) @@ -282,48 +291,64 @@ def propose( self.max_model_len, self.max_num_reqs, ) - query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1] - slot_mappings = self.block_tables.compute_slot_mappings( - idx_mapping, query_start_loc, pos - ) - cudagraph_size = self.cudagraph_manager.get_cudagraph_size(num_reqs) - cudagraph_mode = self.cudagraph_manager.cudagraph_mode - if cudagraph_size is not None and cudagraph_mode == CUDAGraphMode.FULL: + if not (dummy_run and skip_attn_for_dummy_run): + query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1] + slot_mappings = self.block_tables.compute_slot_mappings( + idx_mapping, query_start_loc, pos + ) + + cudagraph_mode, cudagraph_size = ( + self.cudagraph_manager.get_cudagraph_runtime_mode(num_reqs) + ) + num_tokens_padded, num_tokens_across_dp, synced_cudagraph_mode = ( + get_cudagraph_and_dp_padding( + num_reqs, + cudagraph_size, + cudagraph_mode.value, + self.dp_size, + 
self.dp_rank, + ) + ) + cudagraph_mode = CUDAGraphMode(synced_cudagraph_mode) + if cudagraph_mode == CUDAGraphMode.FULL: # Run full CUDA graph. - self.cudagraph_manager.run_fullgraph(cudagraph_size) + self.cudagraph_manager.run_fullgraph(num_tokens_padded) return self.draft_tokens[:num_reqs] # Run eager or piecewise CUDA graph. - num_tokens_padded = cudagraph_size if cudagraph_size is not None else num_reqs - query_start_loc_cpu = torch.arange( - num_reqs + 1, dtype=torch.int32, device="cpu" - ) - block_tables = [x[:num_reqs] for x in self.block_tables.input_block_tables] - - # FIXME(woosuk): This is UNSAFE!! - attn_metadata = build_attn_metadata( - attn_groups=self.attn_groups, - num_reqs=num_reqs, - num_tokens=num_reqs, - query_start_loc_gpu=query_start_loc, - query_start_loc_cpu=query_start_loc_cpu, - max_query_len=1, - seq_lens=self.input_buffers.seq_lens[:num_reqs], - max_seq_len=self.max_model_len, - block_tables=block_tables, - slot_mappings=slot_mappings, - kv_cache_config=self.kv_cache_config, - ) - slot_mappings_by_layer = build_slot_mappings_by_layer( - slot_mappings, self.kv_cache_config - ) + attn_metadata_updated = None + slot_mappings_updated = None + if not (dummy_run and skip_attn_for_dummy_run): + query_start_loc_cpu = torch.arange( + num_reqs + 1, dtype=torch.int32, device="cpu" + ) + block_tables = [x[:num_reqs] for x in self.block_tables.input_block_tables] + + # FIXME(woosuk): This is UNSAFE!! + attn_metadata_updated = build_attn_metadata( + attn_groups=self.attn_groups, + num_reqs=num_reqs, + num_tokens=num_reqs, + query_start_loc_gpu=query_start_loc, + query_start_loc_cpu=query_start_loc_cpu, + max_query_len=1, + seq_lens=self.input_buffers.seq_lens[:num_reqs], + max_seq_len=self.max_model_len, + block_tables=block_tables, + slot_mappings=slot_mappings, + kv_cache_config=self.kv_cache_config, + ) + slot_mappings_updated = build_slot_mappings_by_layer( + slot_mappings, self.kv_cache_config + ) + self.generate_draft( num_reqs, num_tokens_padded, - attn_metadata, - slot_mappings_by_layer, - num_tokens_across_dp=None, # FIXME + attn_metadata_updated, + slot_mappings_updated, + num_tokens_across_dp=num_tokens_across_dp, cudagraph_runtime_mode=cudagraph_mode, ) return self.draft_tokens[:num_reqs] From d15c3b90fc70ba8d787ee2b172caf5b978909fe9 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Tue, 3 Mar 2026 15:31:59 -0800 Subject: [PATCH 38/53] [Core] Move save_tensorized_model logic to Worker (#35825) Signed-off-by: Nick Hill --- vllm/v1/worker/gpu_model_runner.py | 13 +------------ vllm/v1/worker/gpu_worker.py | 10 +++++----- 2 files changed, 6 insertions(+), 17 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 8c92aab266e6..e4ddefc8180f 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -58,7 +58,7 @@ MRotaryEmbedding, XDRotaryEmbedding, ) -from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader +from vllm.model_executor.model_loader import get_model_loader from vllm.model_executor.model_loader.reload import ( finalize_layerwise_reload, initialize_layerwise_reload, @@ -194,7 +194,6 @@ ) if TYPE_CHECKING: - from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput from vllm.v1.spec_decode.ngram_proposer import NgramProposer @@ -4510,16 +4509,6 @@ def reload_weights( weights_not_loaded, ) - def save_tensorized_model( - self, - tensorizer_config: "TensorizerConfig", - ) -> 
None: - TensorizerLoader.save_model( - self.get_model(), - tensorizer_config=tensorizer_config, - model_config=self.model_config, - ) - def _get_prompt_logprobs_dict( self, hidden_states: torch.Tensor, diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 62f0433eff61..c0654abd53a2 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -57,6 +57,7 @@ from vllm.v1.worker.worker_base import WorkerBase from vllm.v1.worker.workspace import init_workspace_manager +from ...model_executor.model_loader import TensorizerLoader from .gpu.warmup import warmup_kernels from .utils import request_memory @@ -836,12 +837,11 @@ def save_sharded_state( max_size=max_size, ) - def save_tensorized_model( - self, - tensorizer_config: "TensorizerConfig", - ) -> None: - self.model_runner.save_tensorized_model( + def save_tensorized_model(self, tensorizer_config: "TensorizerConfig") -> None: + TensorizerLoader.save_model( + self.get_model(), tensorizer_config=tensorizer_config, + model_config=self.model_config, ) def init_weight_transfer_engine(self, init_info: dict) -> None: From f22ff2958c398ae0950598cdbb9c677c027fa5db Mon Sep 17 00:00:00 2001 From: Jaewon <52840625+jaewonlee-fb@users.noreply.github.com> Date: Tue, 3 Mar 2026 16:10:11 -0800 Subject: [PATCH 39/53] [Bugfix] Fix coord_socket assertion in DPEngineCoreProc for offline DP mode (#35916) Signed-off-by: Jaewon Lee --- vllm/v1/engine/core.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 4de3e4ea7d3a..0c5cc29bf20f 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -1571,7 +1571,11 @@ def add_request(self, request: Request, request_wave: int = 0): def resume_scheduler(self): super().resume_scheduler() - if not self.engines_running and self.scheduler.has_unfinished_requests(): + if ( + self.has_coordinator + and not self.engines_running + and self.scheduler.has_unfinished_requests() + ): # Wake up other DP engines. 
self.output_queue.put_nowait( (-1, EngineCoreOutputs(start_wave=self.current_wave)) From f7da9cdffca2d7f11882249550cfa20605a0ca04 Mon Sep 17 00:00:00 2001 From: Andreas Karatzas Date: Tue, 3 Mar 2026 19:44:14 -0600 Subject: [PATCH 40/53] [ROCm][CI] Support async weight transfer example with platform-aware determinism (#35710) Signed-off-by: Andreas Karatzas --- .buildkite/test-amd.yaml | 12 +- .../new_weight_syncing/rlhf_async_new_apis.py | 112 +++++++++++++----- 2 files changed, 91 insertions(+), 33 deletions(-) diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml index 2b80937e8580..9130026e1c14 100644 --- a/.buildkite/test-amd.yaml +++ b/.buildkite/test-amd.yaml @@ -1339,6 +1339,7 @@ steps: - tests/v1/entrypoints/openai/test_multi_api_servers.py - tests/v1/shutdown - tests/v1/worker/test_worker_memory_snapshot.py + - examples/offline_inference/new_weight_syncing/ commands: # Work around HIP bug tracked here: https://github.com/ROCm/hip/issues/3876 # TODO: Remove when the bug is fixed in a future ROCm release @@ -1970,8 +1971,10 @@ steps: - label: Distributed Tests (4 GPUs) # 35min timeout_in_minutes: 50 - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi355_4 + optional: true + # grade: Blocking working_dir: "/vllm-workspace/tests" num_gpus: 4 source_file_dependencies: @@ -2025,7 +2028,8 @@ steps: - popd # NEW rlhf examples - pushd ../examples/offline_inference/new_weight_syncing - - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py + - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_nccl.py + - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_ipc.py - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_async_new_apis.py - popd @@ -2989,8 +2993,10 @@ steps: - label: Distributed Tests (2 GPUs) # 68min timeout_in_minutes: 90 - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi355_2 + optional: true + # grade: Blocking working_dir: "/vllm-workspace/tests" num_gpus: 2 source_file_dependencies: diff --git a/examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py b/examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py index e9bc06180069..5b72bf15934d 100644 --- a/examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py +++ b/examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py @@ -47,12 +47,14 @@ NCCLWeightTransferInitInfo, NCCLWeightTransferUpdateInfo, ) +from vllm.platforms import current_platform from vllm.utils.network_utils import get_ip, get_open_port from vllm.v1.executor import Executor MODEL_NAME_V1 = "Qwen/Qwen3-1.7B-Base" MODEL_NAME_V2 = "Qwen/Qwen3-1.7B" PAUSE_TOKEN_THRESHOLD = 10 +ATTN_BACKEND = "TRITON_ATTN" if current_platform.is_rocm() else "FLASH_ATTN" class MyLLM(vllm.AsyncLLMEngine): @@ -116,10 +118,16 @@ def __init__(self, model_name: str): from vllm.model_executor.layers.batch_invariant import ( init_batch_invariance, ) + from vllm.platforms import current_platform from vllm.v1.attention.backends.registry import AttentionBackendEnum # need to init all env vars for batch invariance which affect nccl ops - init_batch_invariance(AttentionBackendEnum.FLASH_ATTN) + attn_backend = ( + AttentionBackendEnum.TRITON_ATTN + if current_platform.is_rocm() + else AttentionBackendEnum.FLASH_ATTN + ) + init_batch_invariance(attn_backend) self.model = AutoModelForCausalLM.from_pretrained( model_name, dtype=torch.bfloat16 @@ -175,39 +183,56 @@ def generate(self, token_ids: list[int], max_new_tokens: int) -> list[int]: 
return new_token_ids -ray.init( - runtime_env={ - "env_vars": { - # enable batch invariance for deterministic outputs - "VLLM_BATCH_INVARIANT": "1", - # prevent ray from setting CUDA_VISIBLE_DEVICES - "RAY_EXPERIMENTAL_NOSET_CUDA_ENV_VAR": "1", - } - } -) +# Build platform-specific env vars for Ray +ray_env_vars = { + # Prevent Ray from setting CUDA_VISIBLE_DEVICES + "RAY_EXPERIMENTAL_NOSET_CUDA_ENV_VAR": "1", +} + +if current_platform.is_rocm(): + # For ROCm, BATCH_INVARIANT vllm is not supported + ray_env_vars["VLLM_ROCM_USE_SKINNY_GEMM"] = "0" +else: + # Enable batch invariance for deterministic outputs on NVIDIA + ray_env_vars["VLLM_BATCH_INVARIANT"] = "1" + +ray.init(runtime_env={"env_vars": ray_env_vars}) # Launch the training model actor. Ray's resource scheduler will allocate # 1 GPU (via num_gpus=1 in the decorator), ensuring pg_inference gets different GPUs. train_model = TrainModel.remote(MODEL_NAME_V2) -# Launch the vLLM inference engine. The `enforce_eager` flag reduces -# start-up latency. -# With data_parallel_backend="ray", vLLM's CoreEngineActorManager creates -# its own placement groups internally for each DP rank, so we must NOT -# create an outer placement group (it would reserve GPUs and hide them -# from the internal DP resource check). -llm = ray.remote( - num_cpus=0, - num_gpus=0, -)(MyLLM).remote( +rocm_determinism_kwargs = {} +if current_platform.is_rocm(): + # ROCm: To minimize non-determinism, we set fixed seed, no prefix caching, and + # sequential request processing (max_num_seqs=1). + rocm_determinism_kwargs = { + "seed": 0, + "enable_prefix_caching": False, + "max_num_seqs": 1, + } + +# Build platform-specific LLM kwargs +llm_kwargs = dict( model=MODEL_NAME_V1, enforce_eager=True, max_model_len=8192, distributed_executor_backend="ray", - attention_backend="FLASH_ATTN", + attention_backend=ATTN_BACKEND, gpu_memory_utilization=0.75, weight_transfer_config=WeightTransferConfig(backend="nccl"), ) +llm_kwargs.update(rocm_determinism_kwargs) + +# Launch the vLLM inference engine. +# With data_parallel_backend="ray", vLLM's CoreEngineActorManager creates +# its own placement groups internally for each DP rank, so we must NOT +# create an outer placement group (it would reserve GPUs and hide them +# from the internal DP resource check). +llm = ray.remote( + num_cpus=0, + num_gpus=0, +)(MyLLM).remote(**llm_kwargs) PROMPTS = [ "The president of the United States is", @@ -304,25 +329,42 @@ def generate(self, token_ids: list[int], max_new_tokens: int) -> list[int]: print(f" New weights ({n_after} tokens): {after_text!r}") # ── Phase 2: validate with a fresh V2 vLLM instance ──────────────── +# This validation relies on batch-invariant (deterministic) generation to +# compare outputs from the weight-synced engine against a fresh V2 instance. +# On NVIDIA, batch invariance is fully supported, so we require 100% exact +# token match. On ROCm, batch invariance is not yet fully implemented +# (see https://github.com/vllm-project/vllm/issues/27433 and +# https://github.com/vllm-project/vllm/issues/33123), so residual +# non-determinism (e.g. GEMM accumulation order, missing kernel overrides) +# can cause single-token divergences that don't indicate a weight-sync +# failure. We relax the pass rate to 90% on ROCm to accommodate this; a +# real regression (broken weight transfer) would cause ~0% pass rate, not 90%+. 
+MIN_PASS_RATE = 1.0 if not current_platform.is_rocm() else 0.9 + print(f"\n{'=' * 50}") print("VALIDATION: comparing weight-synced vLLM with fresh V2 instance") +if current_platform.is_rocm(): + print(f" (ROCm mode: requiring >= {MIN_PASS_RATE:.0%} exact match rate)") print(f"{'=' * 50}") ray.get(llm.shutdown.remote()) ray.kill(llm) ray.kill(train_model) -llm_v2 = ray.remote( - num_cpus=0, - num_gpus=0, -)(MyLLM).remote( +llm_v2_kwargs = dict( model=MODEL_NAME_V2, enforce_eager=True, max_model_len=8192, gpu_memory_utilization=0.75, distributed_executor_backend="ray", - attention_backend="FLASH_ATTN", + attention_backend=ATTN_BACKEND, ) +llm_v2_kwargs.update(rocm_determinism_kwargs) + +llm_v2 = ray.remote( + num_cpus=0, + num_gpus=0, +)(MyLLM).remote(**llm_v2_kwargs) val_futures = [ llm_v2.do_generate.remote( @@ -335,16 +377,17 @@ def generate(self, token_ids: list[int], max_new_tokens: int) -> list[int]: ] val_results = ray.get(val_futures) -all_pass = True +num_pass = 0 +num_total = len(results) for i, ((output, pause_idx), (val_output, _)) in enumerate(zip(results, val_results)): expected = list(output.outputs[0].token_ids)[pause_idx:] actual = list(val_output.outputs[0].token_ids) match = actual == expected if match: + num_pass += 1 print(f" [PASS] {PROMPTS[i]!r}") else: - all_pass = False print(f" [FAIL] {PROMPTS[i]!r}") print(f" weight-synced vLLM: {tokenizer.decode(expected)!r}") print(f" V2 vLLM: {tokenizer.decode(actual)!r}") @@ -359,5 +402,14 @@ def generate(self, token_ids: list[int], max_new_tokens: int) -> list[int]: ray.get(llm_v2.shutdown.remote()) ray.kill(llm_v2) -assert all_pass, "Some prompts failed validation, see above for details" + +pass_rate = num_pass / num_total +print(f"\n Result: {num_pass}/{num_total} prompts passed ({pass_rate:.0%})") +print(f" Required: >= {MIN_PASS_RATE:.0%}") + +assert pass_rate >= MIN_PASS_RATE, ( + f"Validation pass rate {pass_rate:.0%} ({num_pass}/{num_total}) " + f"is below the required {MIN_PASS_RATE:.0%} threshold. " + f"See failures above for details." +) print("=" * 50) From 9a9d4424649fc360346bc63fd395c3f62731b7cf Mon Sep 17 00:00:00 2001 From: xjx <30485581+flutist@users.noreply.github.com> Date: Wed, 4 Mar 2026 09:46:47 +0800 Subject: [PATCH 41/53] Enable bnb for multiple indices weight (#35838) Signed-off-by: xjx <493337577@qq.com> Signed-off-by: Isotr0py Co-authored-by: Isotr0py --- vllm/model_executor/layers/linear.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index f0d06e179f33..bfcdaa4c0cd2 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -744,10 +744,14 @@ def weight_loader( ) current_shard_offset = 0 use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) - if use_bitsandbytes_4bit and isinstance(loaded_shard_id, tuple): + if ( + use_bitsandbytes_4bit + and isinstance(loaded_shard_id, tuple) + and self.tp_size > 1 + ): raise NotImplementedError( "Shard id with multiple indices is not supported " - "for BNB quantization yet." + "for BNB quantization with TP yet." 
) shard_offsets: list[tuple[int, int, int]] = [] for i, output_size in enumerate(output_sizes): @@ -815,9 +819,14 @@ def weight_loader( is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit if use_bitsandbytes_4bit: - shard_size = loaded_weight.shape[output_dim] - shard_offset = loaded_weight.shape[output_dim] * loaded_shard_id - + index = list(itertools.accumulate([0] + self.output_sizes)) + orig_offsets = { + str(i): (index[i], size) for i, size in enumerate(self.output_sizes) + } + orig_offsets["total"] = (self.output_size, 0) + shard_size, shard_offset = adjust_bitsandbytes_4bit_shard( + param, orig_offsets, str(loaded_shard_id) + ) param_data = param_data.narrow(output_dim, shard_offset, shard_size) start_idx = self.tp_rank * shard_size if not is_sharded_weight: From 70c73df69ee28fec37781d9bc82a994619ab95b1 Mon Sep 17 00:00:00 2001 From: William Zhang <133824995+2ez4bz@users.noreply.github.com> Date: Tue, 3 Mar 2026 18:18:11 -0800 Subject: [PATCH 42/53] [Bugfix] Fix EVS implementation for Qwen3 VL (#33607) Signed-off-by: 2ez4bz <133824995+2ez4bz@users.noreply.github.com> --- tests/model_executor/test_qwen3_vl_mrope.py | 237 ++++++ vllm/model_executor/models/qwen2_5_vl.py | 7 + vllm/model_executor/models/qwen2_vl.py | 1 + vllm/model_executor/models/qwen3_5.py | 21 +- vllm/model_executor/models/qwen3_vl.py | 821 ++++++++++++++------ vllm/model_executor/models/qwen3_vl_moe.py | 2 + vllm/multimodal/evs.py | 78 +- 7 files changed, 895 insertions(+), 272 deletions(-) create mode 100644 tests/model_executor/test_qwen3_vl_mrope.py diff --git a/tests/model_executor/test_qwen3_vl_mrope.py b/tests/model_executor/test_qwen3_vl_mrope.py new file mode 100644 index 000000000000..90d9fd6e4ff8 --- /dev/null +++ b/tests/model_executor/test_qwen3_vl_mrope.py @@ -0,0 +1,237 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import dataclasses +import random +from dataclasses import dataclass + +import pytest +import torch + +from vllm.model_executor.models.qwen3_vl import Qwen3VLForConditionalGeneration +from vllm.multimodal.inputs import ( + MultiModalFeatureSpec, + MultiModalFieldElem, + MultiModalKwargsItem, + PlaceholderRange, +) + + +@pytest.fixture(autouse=True, scope="module") +def _force_cpu_default_device(): + # _get_mrope_input_positions returns CPU tensors (via torch.from_numpy). + # Ensure the default device is CPU so the rest of the test tensors match. + original = torch.get_default_device() + torch.set_default_device("cpu") + yield + torch.set_default_device(original) + + +IMAGE_TOKEN_ID = 999 +VIDEO_TOKEN_ID = 888 +VISION_START_TOKEN_ID = 777 +VISION_END_TOKEN_ID = 778 + + +@dataclass +class DummyVisionConfig: + spatial_merge_size: int = 1 + + +@dataclass +class DummyConfig: + image_token_id: int = IMAGE_TOKEN_ID + video_token_id: int = VIDEO_TOKEN_ID + vision_start_token_id: int = VISION_START_TOKEN_ID + vision_end_token_id: int = VISION_END_TOKEN_ID + vision_config: DummyVisionConfig = dataclasses.field( + default_factory=DummyVisionConfig + ) + + +def make_video_embedding( + t, h, w, interleave_text_tokens: tuple[int, int], video_pruning_rate: float = 0.0 +): + """ + Helper function to make a video embedding for a given video size and pruning rate. + + Args: + t: Number of frames. + h: Number of rows. + w: Number of columns. + interleave_text_tokens: Tuple of minimum and maximum number of text tokens to + interleave with the video. + video_pruning_rate: Pruning rate for the video. 
+ + Returns: + Tuple of (unpruned_tokens_sequence, pruned_tokens_sequence, retention_mask) + """ + unpruned_tokens_sequence = [] + population = list(range(1, 100)) + + for _ in range(t): + num_prefix_tokens = random.randint( + interleave_text_tokens[0], interleave_text_tokens[1] + ) + + prefix_tokens = random.choices(population, k=num_prefix_tokens) + vision_tokens = ( + [VISION_START_TOKEN_ID] + [VIDEO_TOKEN_ID] * h * w + [VISION_END_TOKEN_ID] + ) + + unpruned_tokens_sequence.extend(prefix_tokens) + unpruned_tokens_sequence.extend(vision_tokens) + + unpruned_tokens_sequence = torch.tensor(unpruned_tokens_sequence, dtype=torch.long) + video_token_mask = unpruned_tokens_sequence == VIDEO_TOKEN_ID + + pruning_mask = torch.bernoulli(video_token_mask.float() * video_pruning_rate).bool() # type: ignore[attr-defined] + # Sanity check that we don't prune what should not be pruned. + assert not pruning_mask[~video_token_mask].any() + + retention_mask = ~pruning_mask + pruned_tokens_sequence = unpruned_tokens_sequence[retention_mask] + return unpruned_tokens_sequence, pruned_tokens_sequence, retention_mask + + +@pytest.mark.parametrize("spatial_merge_size", [1, 2]) +@pytest.mark.parametrize("grid_thw", [[3, 8, 7], [128, 10, 12]]) +@pytest.mark.parametrize("num_prefix_tokens", [1, 11]) +@pytest.mark.parametrize("num_suffix_tokens", [0, 7]) +@pytest.mark.parametrize("video_pruning_rate", [0, 0.25, 0.75]) +@pytest.mark.parametrize("interleave_text_tokens", [(0, 0), (1, 4)]) +def test_match_qwen3vl_mrope_evs_on( + spatial_merge_size: int, + num_prefix_tokens: int, + grid_thw: tuple[int, int, int], + num_suffix_tokens: int, + video_pruning_rate: float, + interleave_text_tokens: tuple[int, int], +): + hf_config = DummyConfig() + hf_config.vision_config.spatial_merge_size = spatial_merge_size + + t, h, w = grid_thw + population = list(range(1, 100)) + prefix_tokens = random.choices(population, k=num_prefix_tokens) + suffix_tokens = random.choices(population, k=num_suffix_tokens) + + video_tokens, video_tokens_pruned, retention_mask = make_video_embedding( + t, + h // spatial_merge_size, + w // spatial_merge_size, + interleave_text_tokens=interleave_text_tokens, + video_pruning_rate=video_pruning_rate, + ) + assert len(video_tokens) == len(retention_mask) + + input_tokens = prefix_tokens + video_tokens.tolist() + suffix_tokens + input_tokens_pruned = prefix_tokens + video_tokens_pruned.tolist() + suffix_tokens + + whole_sequence_retention_mask = torch.cat( + [ + torch.ones(len(prefix_tokens), dtype=torch.bool), + retention_mask, + torch.ones(len(suffix_tokens), dtype=torch.bool), + ], + dim=0, + ) + + # Build the GT mrope for unpruned input. + mm_feature = MultiModalFeatureSpec( + data=MultiModalKwargsItem( + { + "video_grid_thw": MultiModalFieldElem( + data=torch.tensor(grid_thw), + field=None, # HACK. + ), + } + ), + modality="video", + identifier="DUMMY", + mm_position=PlaceholderRange(offset=0, length=len(input_tokens)), + ) + expected_mrope, _ = Qwen3VLForConditionalGeneration._get_mrope_input_positions( + input_tokens=input_tokens, + mm_features=[mm_feature], + config=hf_config, + ) + + # Compute mrope for a video-only media (unpruned). + mm_feature = MultiModalFeatureSpec( + data=MultiModalKwargsItem( + { + "video_grid_thw": MultiModalFieldElem( + data=torch.tensor(grid_thw), + field=None, # HACK. 
+ ), + } + ), + modality="video", + identifier="DUMMY", + mm_position=PlaceholderRange(offset=0, length=video_tokens.numel()), + ) + video_mrope, _ = Qwen3VLForConditionalGeneration._get_mrope_input_positions( + input_tokens=video_tokens.tolist(), + mm_features=[mm_feature], + config=hf_config, + ) + video_mrope = video_mrope.permute(1, 0) # [N, 3] + hidden_size = 16 + + is_video_embed = torch.isin( + video_tokens_pruned, torch.tensor([VIDEO_TOKEN_ID], dtype=torch.long) + ) + + expanded_positions = torch.full( + (len(video_tokens_pruned), 5), + fill_value=-100, + device=video_mrope.device, + dtype=torch.long, + ) + expanded_positions[is_video_embed, :3] = video_mrope[retention_mask][is_video_embed] + expanded_positions[~is_video_embed, :3] = video_mrope[retention_mask][ + ~is_video_embed + ] + + is_vision_start = video_tokens_pruned == VISION_START_TOKEN_ID + expanded_positions[..., 3] = is_vision_start + expanded_positions[..., 4] = is_video_embed + + # Check that all positions were filled, since we initialized them as negative. + assert (expanded_positions >= 0).all() + + video_embeddings = torch.empty( + (len(video_tokens_pruned), hidden_size), device=video_mrope.device + ) + + video_embeddings = torch.cat( + [ + video_embeddings, + expanded_positions.float(), + ], + dim=1, + ) + multimodal_embeddings = [video_embeddings] + + expected_mrope_masked = expected_mrope[:, whole_sequence_retention_mask] + + # Initialize computed_mrope with sequential positions for all prefix tokens + computed_mrope = torch.empty((3, len(input_tokens_pruned)), dtype=torch.long) + computed_mrope[:, 0 : len(prefix_tokens)] = expected_mrope[ + :, 0 : len(prefix_tokens) + ] + + # Paranoia check that computed_mrope is wrong. + assert not torch.equal(computed_mrope, expected_mrope_masked) + + _, actual_mrope, _ = Qwen3VLForConditionalGeneration._recompute_mrope_positions( + input_ids=input_tokens_pruned, + multimodal_embeddings=multimodal_embeddings, + mrope_positions=computed_mrope, + num_computed_tokens=len(prefix_tokens), + vision_start_token_id=hf_config.vision_start_token_id, + image_token_id=hf_config.image_token_id, + video_token_id=hf_config.video_token_id, + ) + + assert torch.equal(actual_mrope, expected_mrope_masked) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 3eeefbb3f26b..cd5c5356e558 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -195,6 +195,8 @@ class Qwen2_5_VLVideoPixelInputs(TensorSchema): - second_per_grid_ts: The video time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs. Returned when `videos` is not `None`. + - timestamps: List of timestamp values (in seconds) for each frame + after merging. Length equals the temporal dimension after merging. """ type: Literal["pixel_values_videos"] @@ -214,6 +216,8 @@ class Qwen2_5_VLVideoPixelInputs(TensorSchema): TensorShape("nv"), ] + timestamps: list[list[float]] | None = None + class Qwen2_5_VLVideoEmbeddingInputs(TensorSchema): """ @@ -232,6 +236,8 @@ class Qwen2_5_VLVideoEmbeddingInputs(TensorSchema): - second_per_grid_ts: The video time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs. Returned when `videos` is not `None`. + - timestamps: List of timestamp values (in seconds) for each frame + after merging. Length equals the temporal dimension after merging. 
""" type: Literal["video_embeds"] @@ -250,6 +256,7 @@ class Qwen2_5_VLVideoEmbeddingInputs(TensorSchema): torch.Tensor | None, TensorShape("nv"), ] = None + timestamps: list[list[float]] | None = None Qwen2_5_VLVideoInputs: TypeAlias = ( diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index c4c71faf3958..aeacd99eb665 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -755,6 +755,7 @@ def _qwen2vl_field_config(hf_inputs: Mapping[str, torch.Tensor]): "video", video_embed_grid_sizes ), video_grid_thw=MultiModalFieldConfig.batched("video", keep_on_cpu=True), + timestamps=MultiModalFieldConfig.batched("video", keep_on_cpu=True), ) return _qwen2vl_field_config diff --git a/vllm/model_executor/models/qwen3_5.py b/vllm/model_executor/models/qwen3_5.py index 66d8ff8e1b0a..30823ada1ee7 100644 --- a/vllm/model_executor/models/qwen3_5.py +++ b/vllm/model_executor/models/qwen3_5.py @@ -628,6 +628,9 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: dummy_inputs=Qwen3VLDummyInputsBuilder, ) class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid): + # Qwen3.5 does not support multimodal pruning (EVS). + supports_multimodal_pruning = False + packed_modules_mapping = Qwen3VLForConditionalGeneration.packed_modules_mapping | { "in_proj_qkvz": ["in_proj_qkv", "in_proj_z"], "in_proj_ba": ["in_proj_b", "in_proj_a"], @@ -643,10 +646,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"): self.config = config self.multimodal_config = multimodal_config self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" - self.video_pruning_rate = multimodal_config.video_pruning_rate - self.is_multimodal_pruning_enabled = ( - multimodal_config.is_multimodal_pruning_enabled() - ) + # Qwen3.5 does not support multimodal pruning (EVS). + self.is_multimodal_pruning_enabled = False with self._mark_tower_model(vllm_config, {"image", "video"}): self.visual = Qwen3_VisionTransformer( @@ -693,6 +694,12 @@ def embed_input_ids( return inputs_embeds + def recompute_mrope_positions(self, *args, **kwargs): + raise NotImplementedError( + "Qwen3.5 does not support multimodal pruning (EVS). " + "recompute_mrope_positions should never be called." + ) + def forward( self, input_ids: torch.Tensor, @@ -851,10 +858,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"): self.config = config self.multimodal_config = multimodal_config self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" - self.video_pruning_rate = multimodal_config.video_pruning_rate - self.is_multimodal_pruning_enabled = ( - multimodal_config.is_multimodal_pruning_enabled() - ) + # Qwen3.5 does not support multimodal pruning (EVS). 
+ self.is_multimodal_pruning_enabled = False with self._mark_tower_model(vllm_config, {"image", "video"}): self.visual = Qwen3_VisionTransformer( diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index e5bdbd8029c2..b19811977bbc 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -79,6 +79,7 @@ MultiModalDataDict, MultiModalFeatureSpec, MultiModalFieldConfig, + MultiModalFieldElem, MultiModalKwargsItem, MultiModalKwargsItems, PlaceholderRange, @@ -93,6 +94,8 @@ PromptUpdateDetails, ) from vllm.sequence import IntermediateTensors +from vllm.tokenizers.protocol import TokenizerLike +from vllm.tokenizers.registry import cached_tokenizer_from_config from vllm.utils.collection_utils import is_list_of from vllm.utils.math_utils import round_up @@ -763,7 +766,6 @@ def _calculate_timestamps( def _get_video_second_idx( self, metadata: dict[str, Any], - out_item: MultiModalKwargsItem, do_sample_frames: bool | None = None, sampled_fps: float | None = None, ) -> list[int]: @@ -956,6 +958,7 @@ def _call_hf_processor( if videos := mm_data.pop("videos", []): video_grid_thw_lst = [] pixel_values_videos_lst = [] + timestamps_per_video = [] for item in videos: video_array, metadata = item @@ -979,6 +982,14 @@ def _call_hf_processor( **{k: metadata[k] for k in metadata if k != "do_sample_frames"} ) + # Compute timestamps here where we have access to metadata + timestamps = self.info._get_video_second_idx( + metadata=metadata, + do_sample_frames=video_mm_kwargs["do_sample_frames"], + sampled_fps=video_mm_kwargs.get("fps"), + ) + timestamps_per_video.append(timestamps) + video_mm_data = dict() video_mm_data["videos"] = [[video_array]] video_mm_data["video_metadata"] = [[metadata]] @@ -989,6 +1000,49 @@ def _call_hf_processor( mm_kwargs=video_mm_kwargs, tok_kwargs=tok_kwargs, ) + + merge_size = processor.video_processor.merge_size + # Get video grid info for EVS calculation. + video_grid_thw = video_outputs["video_grid_thw"] + num_frames = int(video_grid_thw[0, 0]) + tokens_per_frame_base = int(video_grid_thw[0, 1:].prod()) // ( + merge_size**2 + ) + + # Apply EVS if enabled. + video_pruning_rate = self.info.ctx.get_mm_config().video_pruning_rate + if video_pruning_rate is not None and video_pruning_rate > 0.0: + num_tokens = compute_retained_tokens_count( + tokens_per_frame=tokens_per_frame_base, + num_frames=num_frames, + q=video_pruning_rate, + ) + # Here we just need placeholders that won't actually be replaced - + # we just need to make sure the total number of tokens is correct + # assign all tokens to the first frame. 
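A rough numeric sketch of the placeholder bookkeeping described above. The real count comes from vllm.multimodal.evs.compute_retained_tokens_count; the stand-in formula here (keep frame 0, prune the remaining frames at rate q) is only an assumption used to show the shape of tokens_per_frame.

    def retained_tokens(tokens_per_frame: int, num_frames: int, q: float) -> int:
        # hypothetical stand-in for compute_retained_tokens_count
        return tokens_per_frame + round((num_frames - 1) * tokens_per_frame * (1 - q))

    num_frames, tokens_per_frame_base, q = 4, 100, 0.75
    num_tokens = retained_tokens(tokens_per_frame_base, num_frames, q)

    # Only the total matters at this stage, so all retained tokens are booked
    # against frame 0 and the remaining frames get zero-length placeholders.
    tokens_per_frame = [num_tokens] + [0] * (num_frames - 1)
    print(tokens_per_frame)  # [175, 0, 0, 0] with the assumed formula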
+ tokens_per_frame = [num_tokens] + [0] * (num_frames - 1) + select_token_id = False + else: + tokens_per_frame = [tokens_per_frame_base] * num_frames + select_token_id = True + + # Generate the video replacement with EVS-adjusted token counts + tokenizer = self.info.get_tokenizer() + hf_config = self.info.get_hf_config() + video_repl = Qwen3VLMultiModalProcessor.get_video_repl( + tokens_per_frame=tokens_per_frame, + timestamps=timestamps, + tokenizer=tokenizer, + vision_start_token_id=hf_config.vision_start_token_id, + vision_end_token_id=hf_config.vision_end_token_id, + video_token_id=hf_config.video_token_id, + select_token_id=select_token_id, + ) + + # Convert token IDs to text for the HF processor flow + video_placeholder = tokenizer.decode( + video_repl.full, skip_special_tokens=False + ) input_ids = video_outputs.pop("input_ids") video_placeholder = processor.tokenizer.batch_decode(input_ids)[0] prompt = prompt.replace( @@ -1002,6 +1056,7 @@ def _call_hf_processor( video_outputs = dict( pixel_values_videos=torch.cat(pixel_values_videos_lst), video_grid_thw=torch.cat(video_grid_thw_lst), + timestamps=timestamps_per_video, ) else: video_outputs = dict() @@ -1057,60 +1112,42 @@ def get_video_replacement_qwen3vl(item_idx: int): grid_thw = out_item["video_grid_thw"].data assert isinstance(grid_thw, torch.Tensor) - video, metadata = mm_items["video"][item_idx] - do_sample_frames = hf_processor_mm_kwargs.get("do_sample_frames") sampled_fps = hf_processor_mm_kwargs.get("fps") if is_list_of(sampled_fps, float): sampled_fps = sampled_fps[item_idx] - timestamps = self.info._get_video_second_idx( - metadata, out_item, do_sample_frames, sampled_fps - ) + timestamps = out_item["timestamps"].data assert len(timestamps) == grid_thw[0], ( f"The timestamps length({len(timestamps)}) should be equal " f"video length ({grid_thw[0]})." 
) - frames_idx_token = [ - tokenizer.encode(f"<{curr_time:.1f} seconds>", add_special_tokens=False) - for curr_time in timestamps - ] - tokens_per_frame = int(grid_thw[1:].prod()) // merge_length - per_frame_token_counts = [tokens_per_frame for _ in frames_idx_token] + # Compute tokens per frame, with EVS support + num_frames = int(grid_thw[0]) + tokens_per_frame_base = int(grid_thw[1:].prod()) // merge_length video_pruning_rate = self.info.ctx.get_mm_config().video_pruning_rate if video_pruning_rate is not None and video_pruning_rate > 0.0: - total_retained = compute_retained_tokens_count( - tokens_per_frame, - len(frames_idx_token), - video_pruning_rate, + num_tokens = compute_retained_tokens_count( + tokens_per_frame=tokens_per_frame_base, + num_frames=num_frames, + q=video_pruning_rate, ) - if len(frames_idx_token) == 0: - per_frame_token_counts = [] - elif len(frames_idx_token) == 1: - per_frame_token_counts = [tokens_per_frame] - else: - first_frame_tokens = tokens_per_frame - remaining_tokens = max(total_retained - first_frame_tokens, 0) - base = remaining_tokens // (len(frames_idx_token) - 1) - remainder = remaining_tokens % (len(frames_idx_token) - 1) - per_frame_token_counts = [first_frame_tokens] - for frame_idx in range(1, len(frames_idx_token)): - extra = base + (1 if (frame_idx - 1) < remainder else 0) - per_frame_token_counts.append(extra) - - placeholder = [] - for frame_idx, timestamp_tokens in enumerate(frames_idx_token): - placeholder.extend(timestamp_tokens) - tokens_this_frame = per_frame_token_counts[ - frame_idx if frame_idx < len(per_frame_token_counts) else -1 - ] - placeholder.extend( - [vision_start_token_id] - + [video_token_id] * tokens_this_frame - + [vision_end_token_id] - ) - return PromptUpdateDetails.select_token_id(placeholder, video_token_id) + tokens_per_frame = [num_tokens] + [0] * (num_frames - 1) + select_token_id = False + else: + tokens_per_frame = [tokens_per_frame_base] * num_frames + select_token_id = True + + return Qwen3VLMultiModalProcessor.get_video_repl( + tokens_per_frame=tokens_per_frame, + timestamps=timestamps, + tokenizer=tokenizer, + vision_start_token_id=vision_start_token_id, + vision_end_token_id=vision_end_token_id, + video_token_id=video_token_id, + select_token_id=select_token_id, + ) return [ PromptReplacement( @@ -1127,6 +1164,69 @@ def get_video_replacement_qwen3vl(item_idx: int): ), ] + @staticmethod + def get_video_repl( + *, + tokens_per_frame: list[int], + timestamps: list[float | int], + tokenizer: TokenizerLike, + vision_start_token_id: int, + vision_end_token_id: int, + video_token_id: int, + select_token_id: bool = False, + ) -> PromptUpdateDetails[list[int]]: + """Build prompt replacement for a video in Qwen3VL format. + + The replacement structure for each frame is: + timestamp_tokens + vision_start_token + video_tokens + vision_end_token + + Args: + tokens_per_frame: Number of video tokens per frame (can vary per frame for + EVS). + timestamps: List of timestamps in seconds for each frame + tokenizer: Tokenizer to encode timestamp strings + vision_start_token_id: Token ID for vision start marker + vision_end_token_id: Token ID for vision end marker + video_token_id: Token ID for video content + + Returns: + PromptUpdateDetails with full token sequence + """ + assert len(timestamps) == len(tokens_per_frame), ( + "timestamps and tokens_per_frame must have the same length" + ) + + # Tokenize timestamp strings independently to avoid tokenizer merging + # tokens across boundaries. 
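A worked example of the per-frame replacement layout this helper produces, using toy token ids rather than the real Qwen3-VL vocabulary; the second frame is fully pruned by EVS, so it contributes only its timestamp plus the start/end markers.

    TS_FRAME_0, TS_FRAME_1 = [101, 102], [103, 104]  # pretend timestamp token ids
    VISION_START, VISION_END, VIDEO = 9001, 9002, 9000

    tokens_per_frame = [4, 0]
    repl: list[int] = []
    for ts_ids, n in zip([TS_FRAME_0, TS_FRAME_1], tokens_per_frame):
        repl += ts_ids + [VISION_START] + [VIDEO] * n + [VISION_END]

    # [101, 102, 9001, 9000, 9000, 9000, 9000, 9002, 103, 104, 9001, 9002]
    print(repl)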
+ # TODO: switch to `_seq2tokens` which has some caching. + timestamp_token_ids = [ + tokenizer.encode(f"<{timestamp:.1f} seconds>", add_special_tokens=False) + for timestamp in timestamps + ] + + # Build the full token sequence + all_token_ids = [] + for frame_timestamp_ids, num_tokens in zip( + timestamp_token_ids, tokens_per_frame + ): + # Add timestamp tokens + all_token_ids.extend(frame_timestamp_ids) + + # Add vision tokens: vision_start + video_tokens + vision_end + all_token_ids.append(vision_start_token_id) + all_token_ids.extend([video_token_id] * num_tokens) + all_token_ids.append(vision_end_token_id) + + if select_token_id: + return PromptUpdateDetails.select_token_id(all_token_ids, video_token_id) + + # NOTE: we use `from_seq` instead of `select_token_id` because we want all + # tokens in the placeholder to be initially marked as candidates. Then + # in `get_input_embeddings``, we refine the mask to only replace + # `video_token_id` / `image_token_id`` positions with video/image embeddings, + # keeping text embeddings for timestamps and structural tokens. + return PromptUpdateDetails.from_seq(all_token_ids) + @support_torch_compile( dynamic_arg_dims={ @@ -1280,6 +1380,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"): multimodal_config = vllm_config.model_config.multimodal_config self.config = config + self._tokenizer = cached_tokenizer_from_config(vllm_config.model_config) self.multimodal_config = multimodal_config self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" self.video_pruning_rate = multimodal_config.video_pruning_rate @@ -1419,6 +1520,7 @@ def _parse_and_validate_video_input( video_embeds = kwargs.pop("video_embeds", None) video_grid_thw = kwargs.pop("video_grid_thw", None) second_per_grid_ts = kwargs.pop("second_per_grid_ts", None) + timestamps = kwargs.pop("timestamps", None) if pixel_values_videos is None and video_embeds is None: return None @@ -1429,6 +1531,7 @@ def _parse_and_validate_video_input( pixel_values_videos=pixel_values_videos, video_grid_thw=video_grid_thw, second_per_grid_ts=second_per_grid_ts, + timestamps=timestamps, ) if video_embeds is not None: @@ -1436,6 +1539,7 @@ def _parse_and_validate_video_input( type="video_embeds", video_embeds=video_embeds, video_grid_thw=video_grid_thw, + timestamps=timestamps, ) def _process_image_input( @@ -1502,19 +1606,29 @@ def _postprocess_image_embeds_evs( Returns: Tuple of image embeddings for each image item. - Resulting embeddings will have extra 4 channels for - computed mrope positions. + Resulting embeddings will have extra 5 channels for + computed mrope positions, consistent with video embeddings. 
""" - merge_size = self.visual.spatial_merge_size - grid_thw = image_input["image_grid_thw"] - grid_thw_list = grid_thw.tolist() - image_embeds_out = [] - for emb, size in zip(image_embeds_split, grid_thw_list): - positions = compute_mrope_for_media(size, merge_size).to(emb.device) - emb = torch.cat([emb, positions], dim=1) - image_embeds_out.append(emb) - image_embeds_split = image_embeds_out - return tuple(image_embeds_split) + if self.is_multimodal_pruning_enabled: + merge_size = self.visual.spatial_merge_size + grid_thw = image_input["image_grid_thw"] + grid_thw_list = grid_thw.tolist() + image_embeds_out = [] + for emb, size in zip(image_embeds_split, grid_thw_list): + positions = compute_mrope_for_media(size, merge_size).to(emb.device) + positions = torch.cat( + [ + positions, + torch.zeros_like( + positions[:, 0:1] + ), # Dummy extra fifth channel + ], + dim=1, + ) + emb = torch.cat([emb, positions], dim=1) + image_embeds_out.append(emb) + image_embeds_split = tuple(image_embeds_out) + return image_embeds_split def _postprocess_video_embeds_evs( self, @@ -1531,62 +1645,218 @@ def _postprocess_video_embeds_evs( Returns: Tuple of video embeddings for each video item. - Resulting embeddings will have extra 4 channels for - computed mrope positions. + Resulting embeddings will have extra 5 channels for computed mrope + positions, and whether the index corresponds to a video embedding. """ grid_thw = video_input["video_grid_thw"] assert grid_thw.ndim == 2 grid_thw_list = grid_thw.tolist() merge_size = self.visual.spatial_merge_size - # Cast to long to match the original code - # https://github.com/huggingface/transformers/blob/41980ce93e775f6c88500c51c8db7946fc6a2add/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py#L491 # noqa - second_per_grid_ts = video_input.get("second_per_grid_ts") - if second_per_grid_ts is None: - # For Qwen3-VL, second_per_grid_ts might not be available - # Use default value of 1.0 for each video - second_per_grid_ts = torch.ones(len(grid_thw_list), dtype=torch.long) + # Apply EVS to each video. + video_embeds_out = [] + for video_idx, (emb, size) in enumerate(zip(video_embeds_split, grid_thw_list)): + # Compute positions. + timestamps = video_input.timestamps[video_idx] + num_frames = len(timestamps) + + t, h, w = size + if self.is_multimodal_pruning_enabled: + # For each video, compute retention mask using EVS. + # retention_mask: [11424]. + retention_mask = compute_retention_mask( + emb, + size, + spatial_merge_size=self.visual.spatial_merge_size, + q=self.video_pruning_rate, + ) + # Apply retention mask. + emb = emb[retention_mask] + + # Calculate the actual number of retained tokens per frame. 
+ num_frames, rows, cols = ( + t, + h // merge_size, + w // merge_size, + ) + retention_mask_thw = retention_mask.reshape(num_frames, rows, cols) + num_tokens_per_frame = ( + retention_mask_thw.sum(dim=(1, 2)).long().tolist() + ) + else: + feature_size = emb.shape[0] // num_frames + num_tokens_per_frame = [feature_size] * num_frames + retention_mask = None + + emb = self._create_final_video_embeddings( + video_embeddings=emb, + num_tokens_per_frame=num_tokens_per_frame, + timestamps=timestamps, + video_grid_thw=size, + retention_mask=retention_mask, + ) + + video_embeds_out.append(emb) + + return tuple(video_embeds_out) + + def _create_final_video_embeddings( + self, + video_embeddings: torch.Tensor, + num_tokens_per_frame: list[int], + timestamps: list[float], + video_grid_thw: list[int], + retention_mask: torch.Tensor, + ) -> torch.Tensor: + """Create final embeddings that combine video embeddings with + text embeddings of indicator tokens. + + These final embeddings contain: + - Actual video embeddings in positions corresponding to video content + - Text embeddings for indicator tokens (, , and + frame separation text) in their respective positions + + These embeddings will replace the placeholder embeddings to create + input_embeds for the LLM. + """ + device = video_embeddings.device + + # Generate video replacement token IDs using get_video_repl + # This tokenizes each frame separator independently, then uses pre-tokenized + # special tokens to ensure consistent tokenization regardless of + # num_tokens_per_frame values. + video_repl = Qwen3VLMultiModalProcessor.get_video_repl( + tokens_per_frame=num_tokens_per_frame, + tokenizer=self._tokenizer, + timestamps=timestamps, + vision_start_token_id=self.config.vision_start_token_id, + vision_end_token_id=self.config.vision_end_token_id, + video_token_id=self.config.video_token_id, + select_token_id=self.is_multimodal_pruning_enabled, + ) + + repl_token_ids = torch.tensor(video_repl.full, device=device) + embed_token_id = _cached_tensor(self.config.video_token_id, device=device) + is_video_embed = torch.isin(repl_token_ids, embed_token_id) + + # Get text embeddings for indicator tokens (has only `visual_dim``). 
+ text_embeddings = self.get_language_model().embed_input_ids(repl_token_ids) + + if self.use_deepstack: + ( + deepstack_input_embeds, + multimodal_embeddings, + ) = self._compute_deepstack_embeds( + inputs_embeds=text_embeddings, + multimodal_embeddings=[video_embeddings], + is_multimodal=is_video_embed, + ) else: - second_per_grid_ts = second_per_grid_ts.long() - tokens_per_second = getattr(self.config.vision_config, "tokens_per_second", 1.0) + deepstack_input_embeds = None + multimodal_embeddings = [video_embeddings] - video_embeds_out = [] - for emb, size, video_second_per_grid_t in zip( - video_embeds_split, grid_thw_list, second_per_grid_ts - ): - # For each video, we compute retention mask using EVS - retention_mask = compute_retention_mask( - emb, - size, - spatial_merge_size=self.visual.spatial_merge_size, - q=self.video_pruning_rate, + merged_embeddings = _merge_multimodal_embeddings( + inputs_embeds=text_embeddings, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_video_embed, + ) + + to_concat = [merged_embeddings] + if deepstack_input_embeds is not None: + to_concat.append( + deepstack_input_embeds.permute(1, 0, 2).reshape( + deepstack_input_embeds.shape[1], -1 + ) ) - # Debug logging for EVS pruning - logger.debug( - "EVS: Video tokens pruned from %d to %d (T=%d,H=%d,W=%d, " - "pruning_rate=%.2f, reduction=%.1f%%)", - emb.shape[0], - retention_mask.sum().item(), - size[0], - size[1], - size[2], - self.video_pruning_rate, - (1 - retention_mask.float().mean().item()) * 100, + expanded_positions = None + if self.is_multimodal_pruning_enabled: + is_vision_start = repl_token_ids.eq(self.config.vision_start_token_id) + expanded_positions = self._get_expanded_positions( + device=merged_embeddings.device, + seq_len=merged_embeddings.shape[0], + video_grid_thw=video_grid_thw, + num_tokens_per_frame=num_tokens_per_frame, + timestamps=timestamps, + is_video_embed=is_video_embed, + is_vision_start=is_vision_start, + retention_mask=retention_mask, ) + to_concat.append(expanded_positions) - positions = compute_mrope_for_media( - size, - merge_size, - tokens_per_second=tokens_per_second, - video_second_per_grid=video_second_per_grid_t.item(), - ).to(emb.device) + final_video_embeddings = torch.cat(to_concat, dim=-1) - emb = emb[retention_mask] - positions = positions[retention_mask] - emb = torch.cat([emb, positions], dim=1) - video_embeds_out.append(emb) - return tuple(video_embeds_out) + return final_video_embeddings + + def _get_expanded_positions( + self, + device, + seq_len, + video_grid_thw, + num_tokens_per_frame, + timestamps, + is_video_embed, + is_vision_start, + retention_mask, + ): + embed_token_id = _cached_tensor(self.config.video_token_id, device=device) + + # Expand positions to match the full sequence length + # (includes both video tokens and indicator tokens) + # Shape: [full_length, 5] where positions are filled for video tokens + # and zeros for indicator tokens. + # Channel 3 flags VISION_START tokens so that + # recompute_mrope_positions can reliably count timestamp tokens + # (even when early frames have all video tokens pruned). + # Channel 4 flags video-embedding tokens. 
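A minimal sketch (toy values, assumed shapes) of the channel packing used here: each row of the pruned video embedding carries five extra trailing channels [t_index, h_index, w_index, is_vision_start, is_video] that are concatenated onto the hidden states and sliced back off later when mrope positions are recomputed.

    import torch

    hidden, seq_len = 8, 4
    emb = torch.randn(seq_len, hidden)
    positions = torch.tensor(
        [[0, 0, 0, 1, 0],   # vision_start marker
         [0, 0, 0, 0, 1],   # retained video token at (t=0, h=0, w=0)
         [0, 0, 1, 0, 1],   # retained video token at (t=0, h=0, w=1)
         [0, 0, 0, 0, 0]]   # vision_end marker
    )
    packed = torch.cat([emb, positions.float()], dim=-1)  # [seq_len, hidden + 5]
    emb_back, pos_back = packed[:, :hidden], packed[:, hidden:].long()
    assert torch.equal(pos_back, positions)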
+ expanded_positions = torch.zeros( + seq_len, + 5, # [t_index, h_index, w_index, is_vision_start, is_video] + device=device, + dtype=torch.long, + ) + _, h, w = video_grid_thw + merge_size = self.visual.spatial_merge_size + num_frames = len(num_tokens_per_frame) + unpruned_token_ids = Qwen3VLMultiModalProcessor.get_video_repl( + tokens_per_frame=[(h // merge_size) * (w // merge_size)] * num_frames, + tokenizer=self._tokenizer, + timestamps=timestamps, + vision_start_token_id=self.config.vision_start_token_id, + vision_end_token_id=self.config.vision_end_token_id, + video_token_id=self.config.video_token_id, + ).full + unpruned_token_ids_tensor = torch.tensor(unpruned_token_ids, device=device) + mm_feature = MultiModalFeatureSpec( + data=MultiModalKwargsItem( + { + "video_grid_thw": MultiModalFieldElem( + data=torch.tensor(video_grid_thw), + field=None, # HACK. + ), + } + ), + modality="video", + identifier="DUMMY", + mm_position=PlaceholderRange(offset=0, length=len(unpruned_token_ids)), + ) + original_mrope = ( + self.get_mrope_input_positions( + input_tokens=unpruned_token_ids, + mm_features=[mm_feature], + )[0] + .to(device) + .permute(1, 0) + ) + full_is_video_embed = unpruned_token_ids_tensor == embed_token_id + expanded_positions[is_video_embed, :3] = original_mrope[full_is_video_embed][ + retention_mask + ] + expanded_positions[~is_video_embed, :3] = original_mrope[~full_is_video_embed] + expanded_positions[..., 3] = is_vision_start + expanded_positions[..., 4] = is_video_embed + + return expanded_positions def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: mm_input_by_modality = {} @@ -1607,66 +1877,77 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: ) return mm_input_by_modality - def iter_mm_grid_hw( - self, input_tokens: list[int], mm_features: list[MultiModalFeatureSpec] - ) -> Iterator[tuple[int, int, int]]: - """ - Iterate over multimodal features and yield grid information. - - For videos with EVS (Efficient Video Sampling) enabled, this function - computes the offset based on the pruned token count rather than relying - on input_tokens.index(), which would fail when tokens are pruned. + @staticmethod + def _iter_mm_grid_hw( + input_tokens: list[int], + mm_features: list[MultiModalFeatureSpec], + video_token_id: int, + vision_start_token_id: int, + vision_end_token_id: int, + spatial_merge_size: int, + ) -> Iterator[tuple[int, int, int, int]]: + """Iterate over multimodal features and yield position info. Args: - input_tokens: List of token IDs in the prompt - mm_features: List of multimodal feature specifications + input_tokens: List of token IDs in the input sequence. + mm_features: List of multimodal feature specifications containing + image/video data and position information. + video_token_id: Token ID used for video tokens. + vision_start_token_id: Token ID marking the start of a vision sequence. + vision_end_token_id: Token ID marking the end of a vision sequence. + spatial_merge_size: Size of the spatial merge operation used to + compute logical grid dimensions from the original feature grid. Yields: - Tuple of (offset, grid_h, grid_w) for each frame/image + offset: Position of the first video/image token in the sequence. + llm_grid_h: Logical grid height (may not match actual token count with EVS). + llm_grid_w: Logical grid width (may not match actual token count with EVS). + actual_num_tokens: Actual number of video/image tokens in the placeholder. 
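+                With EVS this can be 0 (frame fully pruned) or larger than
+                llm_grid_h * llm_grid_w (multiple frames' tokens lumped into
+                the first placeholder).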
""" - video_token_id = self.config.video_token_id - spatial_merge_size = self.config.vision_config.spatial_merge_size for mm_feature in sorted(mm_features, key=lambda f: f.mm_position.offset): offset = mm_feature.mm_position.offset if mm_feature.modality == "image": t, h, w = mm_feature.data["image_grid_thw"].data.tolist() assert t == 1, f"Image must have 1 frame, got {t}" - yield offset, h // spatial_merge_size, w // spatial_merge_size + llm_grid_h = h // spatial_merge_size + llm_grid_w = w // spatial_merge_size + yield offset, llm_grid_h, llm_grid_w, llm_grid_h * llm_grid_w elif mm_feature.modality == "video": t, h, w = mm_feature.data["video_grid_thw"].data.tolist() llm_grid_h = h // spatial_merge_size llm_grid_w = w // spatial_merge_size - # Check if EVS (Efficient Video Sampling) is enabled - is_evs_enabled = ( - hasattr(self, "video_pruning_rate") - and self.video_pruning_rate is not None - and self.video_pruning_rate > 0.0 - ) - - if is_evs_enabled: - frame_offsets = self._extract_frame_offsets_from_mask( - mm_feature.mm_position, t - ) - if frame_offsets is not None: - for rel_offset in frame_offsets: - yield offset + rel_offset, llm_grid_h, llm_grid_w - continue - - # If EVS is enabled but mask is missing, this indicates a bug - # in the prompt processing pipeline. The is_embed mask should - # always be present when video_pruning_rate > 0. - raise RuntimeError( - f"EVS is enabled (pruning_rate={self.video_pruning_rate}) " - "but is_embed mask is missing from mm_position. " - "This indicates a bug in prompt processing." - ) - else: - # Non-EVS mode: Use original logic with input_tokens.index() - for _ in range(t): - offset = input_tokens.index(video_token_id, offset) - yield offset, llm_grid_h, llm_grid_w - offset += llm_grid_h * llm_grid_w + for _ in range(t): + # When EVS is enabled, some frames may have 0 video tokens in the + # placeholder. We use `vision_start_token_id` to locate each frame + # since it is always present for every frame. + # We then look for the first `video_token_id` after + # `vision_start_token_id` and before `vision_end_token_id`. + offset = input_tokens.index(vision_start_token_id, offset) + vision_end_offset = input_tokens.index(vision_end_token_id, offset) + + try: + actual_num_tokens = 0 + video_offset = input_tokens.index( + video_token_id, offset, vision_end_offset + ) + # NOTE: looking at the + # `Qwen3VLMultiModalProcessor.get_video_repl` code, we can + # see that we can use the below formula to get the token + # count, since everything in between `video_offset` and + # `vision_end_offset` is populated as `video_token_id`. + # This saves us from manually counting the number tokens + # that match `video_token_id` in between. + actual_num_tokens += vision_end_offset - video_offset + except ValueError: + # No `video_token_id` in this frame (EVS with 0 tokens for + # this frame) -> use `offset + 1`` to move past + # `vision_start_token_id`. + video_offset = offset + 1 + + yield video_offset, llm_grid_h, llm_grid_w, actual_num_tokens + # Move offset past this frame for next iteration. 
+ offset = vision_end_offset + 1 else: raise ValueError(f"Unsupported modality: {mm_feature.modality}") @@ -1771,13 +2052,100 @@ def _get_actual_frame_token_counts( return [len(seg) for seg in segments] + def get_mrope_input_positions( + self, + input_tokens: list[int], + mm_features: list[MultiModalFeatureSpec], + ) -> tuple[torch.Tensor, int]: + return self._get_mrope_input_positions( + input_tokens=input_tokens, + mm_features=mm_features, + config=self.config, + ) + + @staticmethod + def _get_mrope_input_positions( + input_tokens: list[int], + mm_features: list[MultiModalFeatureSpec], + config: Qwen3VLConfig, + ): + llm_pos_ids_list = [] + st = 0 + for ( + offset, + llm_grid_h, + llm_grid_w, + actual_num_tokens, + ) in Qwen3VLForConditionalGeneration._iter_mm_grid_hw( + input_tokens, + mm_features, + video_token_id=config.video_token_id, + vision_start_token_id=config.vision_start_token_id, + vision_end_token_id=config.vision_end_token_id, + spatial_merge_size=config.vision_config.spatial_merge_size, + ): + # Skip frames with 0 tokens (EVS placeholder with tokens lumped elsewhere) + if actual_num_tokens == 0: + continue + + text_len = offset - st + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append( + np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx + ) + + # Check if this is a "lumped placeholder" (all tokens from multiple frames + # assigned to the 0-th frame - see + # `Qwen3VLMultiModalProcessor.get_video_repl`. + expected_tokens_per_frame = llm_grid_h * llm_grid_w + if actual_num_tokens > expected_tokens_per_frame: + # Lumped placeholder: create grid positions for all "logical" frames + # represented. + num_logical_frames = actual_num_tokens // expected_tokens_per_frame + remainder = actual_num_tokens % expected_tokens_per_frame + + # Create positions for complete frames. + for _ in range(num_logical_frames): + grid_indices = np.indices((1, llm_grid_h, llm_grid_w)).reshape( + 3, -1 + ) + llm_pos_ids_list.append(grid_indices + text_len + st_idx) + st_idx = llm_pos_ids_list[-1].max() + 1 + text_len = 0 # No text between frames within the lump + + # Handle remainder tokens if any (partial frame). + # NOTE: this should never be the case. Should we have an assert? + if remainder > 0: + # Create a partial grid - take first 'remainder' positions + full_grid = np.indices((1, llm_grid_h, llm_grid_w)).reshape(3, -1) + grid_indices = full_grid[:, :remainder] + llm_pos_ids_list.append(grid_indices + text_len + st_idx) + else: + # Normal case: frame has exactly the expected tokens (after actual EVS + # pruning). 
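+                # e.g. llm_grid_h=2, llm_grid_w=3 yields positions
+                # t=[0]*6, h=[0,0,0,1,1,1], w=[0,1,2,0,1,2].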
+ grid_indices = np.indices((1, llm_grid_h, llm_grid_w)).reshape(3, -1) + llm_pos_ids_list.append(grid_indices + text_len + st_idx) + + st = offset + actual_num_tokens + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append( + np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx + ) + + llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1) + mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() + return torch.from_numpy(llm_positions), mrope_position_delta + def recompute_mrope_positions( self, input_ids: list[int], - multimodal_embeddings: tuple[torch.Tensor, ...], + multimodal_embeddings: MultiModalEmbeddings, mrope_positions: torch.LongTensor, num_computed_tokens: int, - ) -> tuple[tuple[torch.Tensor, ...], torch.Tensor, int]: + ) -> tuple[MultiModalEmbeddings, torch.Tensor, int]: """ Update part of input mrope positions (starting with num_computed_tokens index). Original mrope_positions are computed @@ -1786,9 +2154,10 @@ def recompute_mrope_positions( mrope_positions before we feed it to LLM. Args: - input_ids: (N,) All input tokens of the prompt (Containing - entire sequence). - multimodal_embeddings: Tuple of multimodal embeddings. + input_ids: (N,) All input tokens of the prompt containing + entire sequence. + multimodal_embeddings: Tuple of multimodal embeddings that + fits into the prefill chunk that is being processed. mrope_positions: Existing mrope positions (3, N) for entire sequence num_computed_tokens: A number of computed tokens so far. @@ -1797,10 +2166,26 @@ def recompute_mrope_positions( Tuple of (multimodal_embeddings, mrope_positions, mrope_position_delta). """ - image_token_id = self.config.image_token_id - video_token_id = self.config.video_token_id - vision_start_token_id = self.config.vision_start_token_id + return self._recompute_mrope_positions( + input_ids=input_ids, + multimodal_embeddings=multimodal_embeddings, + mrope_positions=mrope_positions, + num_computed_tokens=num_computed_tokens, + image_token_id=self.config.image_token_id, + video_token_id=self.config.video_token_id, + vision_start_token_id=self.config.vision_start_token_id, + ) + @staticmethod + def _recompute_mrope_positions( + input_ids: list[int], + multimodal_embeddings: MultiModalEmbeddings, + mrope_positions: torch.LongTensor, + num_computed_tokens: int, + vision_start_token_id: int, + image_token_id: int, + video_token_id: int, + ) -> tuple[MultiModalEmbeddings, torch.Tensor, int]: # Device device = ( multimodal_embeddings[0].device @@ -1811,10 +2196,21 @@ def recompute_mrope_positions( # Tensors input_ids_t = torch.as_tensor(input_ids, device=device, dtype=torch.long) - mm_embeddings_out = [mm[:, :-4] for mm in multimodal_embeddings] - mm_embeddings_pos = [ - mm[:, -4:].permute(1, 0).long() for mm in multimodal_embeddings - ] + mm_embeddings_out = [] + mm_embeddings_pos = [] + # Strip position information from embeddings (last 5 channels) + # For Qwen3 VL, handle potentially empty frames (from unpacking) + for mm in multimodal_embeddings: + if mm.shape[0] > 0: # Only process non-empty frames + mm_embeddings_out.append(mm[:, :-5]) + mm_embeddings_pos.append(mm[:, -5:].permute(1, 0).long()) + else: + # Empty frame - keep as is + mm_embeddings_out.append(mm) + # Create empty position tensor with correct shape + mm_embeddings_pos.append( + torch.empty(5, 0, device=device, dtype=torch.long) + ) positions, mrope_positions_delta = 
recompute_mrope_positions( input_ids_t, @@ -1828,107 +2224,14 @@ def recompute_mrope_positions( return tuple(mm_embeddings_out), positions, mrope_positions_delta - def get_mrope_input_positions( - self, - input_tokens: list[int], - mm_features: list[MultiModalFeatureSpec], - ) -> tuple[torch.Tensor, int]: - # Pre-collect actual frame token counts for EVS mode - frame_token_counts_map = {} - for mm_feature in mm_features: - if mm_feature.modality == "video": - is_evs_enabled = ( - hasattr(self, "video_pruning_rate") - and self.video_pruning_rate is not None - and self.video_pruning_rate > 0.0 - ) - if is_evs_enabled: - t = mm_feature.data["video_grid_thw"].data.tolist()[0] - token_counts = self._get_actual_frame_token_counts( - mm_feature.mm_position, t - ) - assert token_counts is not None, ( - "EVS enabled but failed to extract frame token counts " - "from is_embed mask" - ) - frame_token_counts_map[mm_feature.mm_position.offset] = token_counts - - llm_pos_ids_list = [] - st = 0 - frame_counts_idx = {} - - for offset, llm_grid_h, llm_grid_w in self.iter_mm_grid_hw( - input_tokens, mm_features - ): - text_len = offset - st - st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 - - # Determine actual token count for this frame - base_offset = None - for feat_offset in frame_token_counts_map: - if offset >= feat_offset: - base_offset = feat_offset - - if base_offset is not None: - # EVS mode: use actual token count from is_embed mask - assert base_offset in frame_token_counts_map, ( - f"Found base_offset {base_offset} but not in frame_token_counts_map" - ) - - if base_offset not in frame_counts_idx: - frame_counts_idx[base_offset] = 0 - - counts = frame_token_counts_map[base_offset] - idx = frame_counts_idx[base_offset] - - assert idx < len(counts), ( - f"EVS frame index {idx} out of range (total frames: {len(counts)})" - ) - - actual_frame_tokens = counts[idx] - frame_counts_idx[base_offset] += 1 - else: - # Non-EVS mode (or image): use theoretical grid size - actual_frame_tokens = llm_grid_h * llm_grid_w - - # Add text segment - text_positions = ( - np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx - ) - llm_pos_ids_list.append(text_positions) - st_idx += text_len - - # Add frame segment with actual token count (not theoretical) - grid_indices = np.indices((1, llm_grid_h, llm_grid_w)).reshape(3, -1) - # Only take the first actual_frame_tokens positions - frame_positions = grid_indices[:, :actual_frame_tokens] + st_idx - llm_pos_ids_list.append(frame_positions) - - # Update st using actual token count - st = offset + actual_frame_tokens - - # Handle final text segment - if st < len(input_tokens): - st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 - text_len = len(input_tokens) - st - final_text_positions = ( - np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx - ) - llm_pos_ids_list.append(final_text_positions) - - llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1) - mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() - - return torch.from_numpy(llm_positions), mrope_position_delta - def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None: mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs) if not mm_input_by_modality: return None # The result multimodal_embeddings is tuple of tensors, with each - # tensor correspoending to a multimodal data item (image or video). - multimodal_embeddings: tuple[torch.Tensor, ...] 
= () + # tensor corresponding to a multimodal data item (image or video). + multimodal_embeddings: list[torch.Tensor] = [] # NOTE: It is important to iterate over the keys in this dictionary # to preserve the order of the modalities. @@ -1936,19 +2239,20 @@ def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None: multimodal_input = mm_input_by_modality[modality] if modality == "image": image_embeddings = self._process_image_input(multimodal_input) - if self.is_multimodal_pruning_enabled: - image_embeddings = self._postprocess_image_embeds_evs( - image_embeddings, multimodal_input - ) - multimodal_embeddings += tuple(image_embeddings) + image_embeddings = self._postprocess_image_embeds_evs( + image_embeddings, multimodal_input + ) + multimodal_embeddings.extend(image_embeddings) if modality == "video": video_embeddings = self._process_video_input(multimodal_input) if self.is_multimodal_pruning_enabled: video_embeddings = self._postprocess_video_embeds_evs( video_embeddings, multimodal_input ) - multimodal_embeddings += tuple(video_embeddings) - return multimodal_embeddings + multimodal_embeddings.extend(video_embeddings) + + embeddings_tuple = tuple(multimodal_embeddings) + return embeddings_tuple def _compute_deepstack_embeds( self, @@ -2128,3 +2432,8 @@ def get_num_mm_connector_tokens( vision_config = hf_config.vision_config merge_size = vision_config.spatial_merge_size return num_vision_tokens // merge_size**2 + + +@lru_cache +def _cached_tensor(x, device) -> torch.Tensor: + return torch.tensor(x, device=device) diff --git a/vllm/model_executor/models/qwen3_vl_moe.py b/vllm/model_executor/models/qwen3_vl_moe.py index 80815616bb7d..e6fc7d4093f4 100644 --- a/vllm/model_executor/models/qwen3_vl_moe.py +++ b/vllm/model_executor/models/qwen3_vl_moe.py @@ -45,6 +45,7 @@ ) from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sequence import IntermediateTensors +from vllm.tokenizers.registry import cached_tokenizer_from_config from .interfaces import MixtureOfExperts from .qwen3_moe import ( @@ -415,6 +416,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): multimodal_config = vllm_config.model_config.multimodal_config self.config = config + self._tokenizer = cached_tokenizer_from_config(vllm_config.model_config) self.multimodal_config = multimodal_config self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" self.video_pruning_rate = multimodal_config.video_pruning_rate diff --git a/vllm/multimodal/evs.py b/vllm/multimodal/evs.py index 8a36ea415da4..62611c89719a 100644 --- a/vllm/multimodal/evs.py +++ b/vllm/multimodal/evs.py @@ -170,9 +170,9 @@ def recompute_mrope_positions( multimodal_embeddings may contain zero, some or even some part of all multimodal_embeddings for a given prompt. - Each multimodal_positions has 4 extra channels - (First 3 channels corresponds to original 3 mrope positions, last channel - is the maximum width of the media repeated). Provided multimodal_positions + Each multimodal_positions has 4 or 5 extra channels + (first 3 channels correspond to the original 3 mrope positions; + remaining channels vary by model — see below). Provided multimodal_positions do not reflect location of media position in sequence - they are computed like the media is in the 0-th position in the sequence. @@ -186,6 +186,16 @@ def recompute_mrope_positions( Args: input_ids: (N,) All input tokens of the prompt (entire sequence). multimodal_positions: List of mrope positions for each media. 
+ If a given element is of shape (4, N), it is assumed to only describe + positions for video / image embeddings. This is the case of e.g. Qwen2.5 VL, + where each multimodal input is a contiguous chunk of embeddings. + The expected channels are [t, h, w, max_width]. + If it is of shape (5, N), it is assumed to possibly describe positions for + both video / image embeddings, as well as text embeddings. This is the case + of e.g. Qwen3 VL, where each video inputs are comprised of individual + frames' embeddings, interleaved with embeddings for timestamp tokens, + and vision start / end tokens. The expected channels are + [t, h, w, is_vision_start, is_vision]. mrope_positions: Existing mrope positions (4, N) for entire sequence. num_computed_tokens: A number of computed tokens so far. vision_start_token_id: Token indicating start of vision media. @@ -233,6 +243,21 @@ def recompute_mrope_positions( # - Current prefill chunk has no vision start indexes at all # - Vision start token appeared in previous prefill round # - Regular case + has_video_tokens = False + num_timestamp_tokens = 0 + if mm_pos.shape[0] == 5 and mm_pos.shape[1] > 0: + # mm_pos[4, :] indicates which positions are for video embeddings. + # If there are no video embeddings, skip timestamp adjustment. + has_video_tokens = torch.any(mm_pos[4, :]).item() + if has_video_tokens: + # Channel 3 flags VISION_START tokens. Timestamp tokens + # precede the first VISION_START, so its index gives us the + # exact timestamp count. This is robust even when early + # frames have all their video tokens pruned (which would + # push argmax(channel 4) far into a later frame). + first_vs = (mm_pos[3, :] == 1).nonzero(as_tuple=True)[0] + num_timestamp_tokens = first_vs[0].item() if len(first_vs) > 0 else 0 + seen_vision_start_indices = vision_start_indices[ vision_start_indices < num_computed_tokens ] @@ -249,6 +274,18 @@ def recompute_mrope_positions( in_the_middle_of_media = ( seen_mm_tokens > seem_mm_tokens_before_last_vision_start ) + # For Qwen3 VL, we can be inside a media segment even before any + # video tokens appear (timestamp tokens are text). If we've passed + # the last vision_start token but haven't reached the first video + # embedding, treat this as "in the middle of media". + if ( + not in_the_middle_of_media + and has_video_tokens + and num_computed_tokens > last_vision_start_token + and num_computed_tokens + <= last_vision_start_token + num_timestamp_tokens + 1 + ): + in_the_middle_of_media = True if in_the_middle_of_media: mm_embeddings_seen = ( @@ -274,14 +311,39 @@ def recompute_mrope_positions( mm_embeddings_seen = 0 global_mm_start = next_vision_start_token - # Offset right after vision_start_token - base = positions[-1, global_mm_start] + 1 - local_start = global_mm_start + 1 + mm_embeddings_seen + # For Qwen3 VL, mm_pos includes timestamp tokens before vision_start + # when starting a new media. Adjust global_mm_start to point to where + # the sequence actually begins (before timestamp tokens). + adjusted_for_timestamps = False + if mm_pos.shape[0] == 5 and mm_embeddings_seen == 0 and has_video_tokens: + # NOTE: -1 is because there is a vision start token right after + # timestamp tokens before any video embeddings appear. + + # Adjust global_mm_start to point to the first timestamp token + # instead of the vision_start token. 
+ global_mm_start -= num_timestamp_tokens + adjusted_for_timestamps = True + + # Offset calculation depends on whether we adjusted for timestamp tokens + if adjusted_for_timestamps: + # Start from position before the first timestamp token + base = positions[-1, global_mm_start - 1] + 1 + local_start = global_mm_start + mm_embeddings_seen + else: + # Original logic: start after vision_start_token + base = positions[-1, global_mm_start] + 1 + local_start = global_mm_start + 1 + mm_embeddings_seen + local_end = local_start + mm_pos.shape[1] positions[:, local_start:local_end] = mm_pos[0:3] + base - # mm_pos[3, 0] is the max width of the media - offset = mm_pos[3, 0] + base + # For Qwen3 VL (5-channel), use the maximum position reached across + # all tokens (both video and text) in all dimensions (t, h, w). + # For Qwen2.5 VL (4-channel), mm_pos[3, 0] is the max width. + if mm_pos.shape[0] == 5: + offset = mm_pos[0:3, :].max() + base + 1 + else: + offset = mm_pos[3, 0] + base text_pos_sum = torch.cumsum(text_mask[local_end:].long(), dim=0) From 77e6dcbbfad8cfca6867663b164f038820f7a0be Mon Sep 17 00:00:00 2001 From: Shanshan Shen <467638484@qq.com> Date: Wed, 4 Mar 2026 11:41:27 +0800 Subject: [PATCH 43/53] [PluggableLayer][MM] Add PluggableLayer for RelPosAttention (#33753) Signed-off-by: shen-shanshan <467638484@qq.com> --- docs/design/custom_op.md | 2 ++ vllm/model_executor/models/deepencoder.py | 7 ++++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/docs/design/custom_op.md b/docs/design/custom_op.md index 034736ec6671..a62d033072b1 100644 --- a/docs/design/custom_op.md +++ b/docs/design/custom_op.md @@ -54,6 +54,8 @@ For example: --8<-- "vllm/model_executor/layers/attention/mm_encoder_attention.py:mm_encoder_attn" --8<-- "vllm/model_executor/layers/mla.py:multi_head_latent_attention" + +--8<-- "vllm/model_executor/models/deepencoder.py:rel_pos_attention" ``` **2. 
Activation:** diff --git a/vllm/model_executor/models/deepencoder.py b/vllm/model_executor/models/deepencoder.py index f7ae4264f696..68c101460d53 100644 --- a/vllm/model_executor/models/deepencoder.py +++ b/vllm/model_executor/models/deepencoder.py @@ -18,6 +18,7 @@ import torch.nn.functional as F from transformers import CLIPVisionConfig +from vllm.model_executor.custom_op import PluggableLayer from vllm.model_executor.layers.attention import MMEncoderAttention from vllm.model_executor.layers.conv import Conv2dLayer from vllm.model_executor.layers.quantization import QuantizationConfig @@ -263,9 +264,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -class RelPosAttention(nn.Module): +# --8<-- [start:rel_pos_attention] +@PluggableLayer.register("rel_pos_attention") +class RelPosAttention(PluggableLayer): """Multi-head Attention block with relative position embeddings.""" + # --8<-- [end:rel_pos_attention] + def __init__( self, dim: int, From c1d963403c4f09cc0d5a25573c45d7405cd09abb Mon Sep 17 00:00:00 2001 From: AllenDou Date: Wed, 4 Mar 2026 11:41:30 +0800 Subject: [PATCH 44/53] [model] support FireRedASR2 (#35727) Signed-off-by: zixiao Signed-off-by: Isotr0py Co-authored-by: zixiao Co-authored-by: Isotr0py --- docs/models/supported_models.md | 1 + requirements/common.txt | 1 + tests/models/registry.py | 3 + vllm/model_executor/models/fireredasr2.py | 829 ++++++++++++++++++ vllm/model_executor/models/registry.py | 4 + .../transformers_utils/processors/__init__.py | 4 + .../processors/fireredasr2_processor.py | 341 +++++++ 7 files changed, 1183 insertions(+) create mode 100644 vllm/model_executor/models/fireredasr2.py create mode 100644 vllm/transformers_utils/processors/fireredasr2_processor.py diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 534411c63fb9..98d2a08d957c 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -793,6 +793,7 @@ Speech2Text models trained specifically for Automatic Speech Recognition. | Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | |--------------|--------|-------------------|----------------------|---------------------------| +| `FireRedASR2ForConditionalGeneration` | FireRedASR2 | `allendou/FireRedASR2-LLM-vllm`, etc. | | | | `FunASRForConditionalGeneration` | FunASR | `allendou/Fun-ASR-Nano-2512-vllm`, etc. | | | | `Gemma3nForConditionalGeneration` | Gemma3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. 
| | | | `GlmAsrForConditionalGeneration` | GLM-ASR | `zai-org/GLM-ASR-Nano-2512` | ✅︎ | ✅︎ | diff --git a/requirements/common.txt b/requirements/common.txt index ec7ce5df9e85..9ee1b71512b1 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -57,3 +57,4 @@ opentelemetry-sdk >= 1.27.0 opentelemetry-api >= 1.27.0 opentelemetry-exporter-otlp >= 1.27.0 opentelemetry-semantic-conventions-ai >= 0.4.1 +kaldi-native-fbank >= 1.18.7 diff --git a/tests/models/registry.py b/tests/models/registry.py index 08f1a14d77b6..88017805f5f6 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -743,6 +743,9 @@ def check_available_online( "baidu/ERNIE-4.5-VL-28B-A3B-PT", trust_remote_code=True, ), + "FireRedASR2ForConditionalGeneration": _HfExamplesInfo( + "allendou/FireRedASR2-LLM-vllm", + ), "FunASRForConditionalGeneration": _HfExamplesInfo( "allendou/Fun-ASR-Nano-2512-vllm", ), diff --git a/vllm/model_executor/models/fireredasr2.py b/vllm/model_executor/models/fireredasr2.py new file mode 100644 index 000000000000..f0d3e124c03b --- /dev/null +++ b/vllm/model_executor/models/fireredasr2.py @@ -0,0 +1,829 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import math +from collections.abc import Iterable, Mapping, Sequence +from typing import Annotated, Literal, cast + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from transformers import ( + BatchFeature, + Qwen2Config, +) + +from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig +from vllm.config.multimodal import BaseDummyOptions +from vllm.inputs.data import PromptType +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY +from vllm.model_executor.layers.linear import ( + ReplicatedLinear, +) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.models.whisper_utils import ( + ISO639_1_SUPPORTED_LANGS, +) +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) +from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser +from vllm.multimodal.processing import ( + BaseDummyInputsBuilder, + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) +from vllm.transformers_utils.processor import cached_processor_from_config +from vllm.transformers_utils.processors.fireredasr2_processor import ( + FireRedASR2FeatureExtractor, +) +from vllm.utils.tensor_schema import TensorSchema, TensorShape + +from .interfaces import ( + MultiModalEmbeddings, + SupportsMultiModal, + SupportsTranscription, + _require_is_multimodal, +) +from .qwen2 import Qwen2ForCausalLM +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + _merge_multimodal_embeddings, + maybe_prefix, +) + +logger = init_logger(__name__) + + +class FireRedASR2AudioInputs(TensorSchema): + """ + Dimensions: + - b: Batch size + - nmb: Number of mel bins + - t: Time frames (M) + """ + + input_features: Annotated[ + list[torch.Tensor] | None, + TensorShape("b", "nmb", "t"), + ] + speech_lengths: Annotated[ + list[torch.Tensor] | None, + TensorShape("b"), + ] + fake_token_lengths: Annotated[ + list[torch.Tensor] | None, + TensorShape("b"), + ] + + +class Swish(nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x * torch.sigmoid(x) + + +class 
Conv2dSubsampling(nn.Module): + def __init__(self, idim: int, d_model: int, out_channels: int = 32): + super().__init__() + self.conv = nn.Sequential( + nn.Conv2d(1, out_channels, 3, 2), + nn.ReLU(), + nn.Conv2d(out_channels, out_channels, 3, 2), + nn.ReLU(), + ) + subsample_idim = ((idim - 1) // 2 - 1) // 2 + self.out = ReplicatedLinear( + input_size=out_channels * subsample_idim, + output_size=d_model, + bias=True, + ) + + self.subsampling = 4 + left_context = right_context = 3 # both exclude currect frame + self.context = left_context + 1 + right_context # 7 + + def forward( + self, x: torch.Tensor, x_mask: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + x = x.unsqueeze(1) + x = self.conv(x) + N, C, T, D = x.size() + x, _ = self.out(x.transpose(1, 2).contiguous().view(N, T, C * D)) + mask = x_mask[:, :, :-2:2][:, :, :-2:2] + input_lengths = mask[:, -1, :].sum(dim=-1) + return x, input_lengths, mask + + +class RelPositionalEncoding(nn.Module): + def __init__(self, d_model: int, max_len: int = 5000): + super().__init__() + pe_positive = torch.zeros(max_len, d_model, requires_grad=False) + pe_negative = torch.zeros(max_len, d_model, requires_grad=False) + position = torch.arange(0, max_len).unsqueeze(1).float() + div_term = torch.exp( + torch.arange(0, d_model, 2).float() + * -(torch.log(torch.tensor(10000.0)).item() / d_model) + ) + pe_positive[:, 0::2] = torch.sin(position * div_term) + pe_positive[:, 1::2] = torch.cos(position * div_term) + pe_negative[:, 0::2] = torch.sin(-1 * position * div_term) + pe_negative[:, 1::2] = torch.cos(-1 * position * div_term) + + pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0) + pe_negative = pe_negative[1:].unsqueeze(0) + self.pe = torch.cat([pe_positive, pe_negative], dim=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Tmax = 2 * max_len - 1 + Tmax, T = self.pe.size(1), x.size(1) + pos_emb = self.pe[:, Tmax // 2 - T + 1 : Tmax // 2 + T].clone().detach() + return pos_emb + + +class ConformerFeedForward(nn.Module): + def __init__(self, d_model: int): + super().__init__() + self.pre_layer_norm = nn.LayerNorm(d_model) + self.linear_expand = ReplicatedLinear( + input_size=d_model, + output_size=d_model * 4, + bias=True, + ) + self.nonlinear = Swish() + self.linear_project = ReplicatedLinear( + input_size=d_model * 4, + output_size=d_model, + bias=True, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + x = self.pre_layer_norm(x) + x, _ = self.linear_expand(x) + x = self.nonlinear(x) + x, _ = self.linear_project(x) + output = x + residual + return output + + +class EncoderMultiHeadAttention(nn.Module): + def __init__(self, n_head: int, d_model: int): + super().__init__() + assert d_model % n_head == 0 + self.n_head = n_head + self.d_k = d_model // n_head + self.d_v = self.d_k + + self.w_qs = ReplicatedLinear( + input_size=d_model, output_size=n_head * self.d_k, bias=False + ) + self.w_ks = ReplicatedLinear( + input_size=d_model, output_size=n_head * self.d_k, bias=False + ) + self.w_vs = ReplicatedLinear( + input_size=d_model, output_size=n_head * self.d_v, bias=False + ) + + self.layer_norm_q = nn.LayerNorm(d_model) + self.layer_norm_k = nn.LayerNorm(d_model) + self.layer_norm_v = nn.LayerNorm(d_model) + + self.fc = ReplicatedLinear( + input_size=n_head * self.d_v, output_size=d_model, bias=False + ) + + def forward_qkv( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + d_k, d_v, n_head = self.d_k, self.d_v, self.n_head + 
sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1) + + q = self.layer_norm_q(q) + k = self.layer_norm_k(k) + v = self.layer_norm_v(v) + + q = self.w_qs(q)[0].view(sz_b, len_q, n_head, d_k) + k = self.w_ks(k)[0].view(sz_b, len_k, n_head, d_k) + v = self.w_vs(v)[0].view(sz_b, len_v, n_head, d_v) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + return q, k, v + + def forward_output( + self, output: torch.Tensor, residual: torch.Tensor, sz_b: int, len_q: int + ) -> torch.Tensor: + output = output.transpose(1, 2).contiguous().view(sz_b, len_q, -1) + fc_out, _ = self.fc(output) + output = fc_out + output = output + residual + return output + + def forward_attention( + self, attn: torch.Tensor, v: torch.Tensor, mask: torch.Tensor | None = None + ) -> tuple[torch.Tensor, torch.Tensor]: + if mask is not None: + mask = mask.unsqueeze(1) + mask = mask.eq(0) + attn = attn.masked_fill(mask, -float("inf")) + attn = torch.softmax(attn, dim=-1).masked_fill(mask, 0.0) + else: + attn = torch.softmax(attn, dim=-1) + + d_attn = attn + output = torch.matmul(d_attn, v) + + return output, attn + + +class RelPosMultiHeadAttention(EncoderMultiHeadAttention): + def __init__(self, n_head: int, d_model: int): + super().__init__(n_head, d_model) + d_k = d_model // n_head + self.scale = 1.0 / (d_k**0.5) + self.linear_pos = ReplicatedLinear( + input_size=d_model, output_size=n_head * d_k, bias=False + ) + self.pos_bias_u = nn.Parameter(torch.empty([n_head, d_k])) + self.pos_bias_v = nn.Parameter(torch.empty([n_head, d_k])) + + def _rel_shift(self, x): + N, H, T1, T2 = x.size() + zero_pad = torch.zeros((N, H, T1, 1), device=x.device, dtype=x.dtype) + x_padded = torch.cat([zero_pad, x], dim=-1) + + x_padded = x_padded.view(N, H, T2 + 1, T1) + x = x_padded[:, :, 1:].view_as(x) + x = x[:, :, :, : x.size(-1) // 2 + 1] + return x + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + pos_emb: torch.Tensor, + mask: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + sz_b, len_q = q.size(0), q.size(1) + + residual = q + q, k, v = self.forward_qkv(q, k, v) + + q = q.transpose(1, 2) + n_batch_pos = pos_emb.size(0) + p = self.linear_pos(pos_emb)[0].view(n_batch_pos, -1, self.n_head, self.d_k) + p = p.transpose(1, 2) + + q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) + q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) + + matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) + + matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) + matrix_bd = self._rel_shift(matrix_bd) + + attn_scores = matrix_ac + matrix_bd + attn_scores.mul_(self.scale) + + output, attn = self.forward_attention(attn_scores, v, mask=mask) + + output = self.forward_output(output, residual, sz_b, len_q) + return output, attn + + +class ConformerConvolution(nn.Module): + def __init__(self, d_model: int, kernel_size: int = 33): + super().__init__() + assert kernel_size % 2 == 1 + self.pre_layer_norm = nn.LayerNorm(d_model) + self.pointwise_conv1 = nn.Conv1d( + d_model, d_model * 4, kernel_size=1, bias=False + ) + self.padding = (kernel_size - 1) // 2 + self.depthwise_conv = nn.Conv1d( + d_model * 2, + d_model * 2, + kernel_size, + stride=1, + padding=self.padding, + groups=d_model * 2, + bias=False, + ) + self.batch_norm = nn.LayerNorm(d_model * 2) + self.swish = Swish() + self.pointwise_conv2 = nn.Conv1d( + d_model * 2, d_model, kernel_size=1, bias=False + ) + + def forward( + self, x: torch.Tensor, mask: torch.Tensor | None = None + ) -> 
torch.Tensor: + residual = x + out = self.pre_layer_norm(x) + out = out.transpose(1, 2) + if mask is not None: + out.masked_fill_(mask.ne(1), 0.0) + out = self.pointwise_conv1(out) + out = F.glu(out, dim=1) + out = self.depthwise_conv(out) + + out = out.transpose(1, 2) + out = self.swish(self.batch_norm(out)) + out = out.transpose(1, 2) + + out = self.pointwise_conv2(out) + if mask is not None: + out.masked_fill_(mask.ne(1), 0.0) + out = out.transpose(1, 2) + return out + residual + + +class RelPosEmbConformerBlock(nn.Module): + def __init__(self, d_model, n_head, kernel_size=33): + super().__init__() + self.ffn1 = ConformerFeedForward(d_model) + self.mhsa = RelPosMultiHeadAttention(n_head, d_model) + self.conv = ConformerConvolution(d_model, kernel_size) + self.ffn2 = ConformerFeedForward(d_model) + self.layer_norm = nn.LayerNorm(d_model) + + def forward( + self, + x: torch.Tensor, + pos_emb: torch.Tensor, + slf_attn_mask: torch.Tensor | None = None, + pad_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + out = 0.5 * x + 0.5 * self.ffn1(x) + out = self.mhsa(out, out, out, pos_emb, mask=slf_attn_mask)[0] + out = self.conv(out, pad_mask) + out = 0.5 * out + 0.5 * self.ffn2(out) + out = self.layer_norm(out) + return out + + +class ConformerEncoder(nn.Module): + def __init__( + self, + idim: int, + n_layers_enc: int, + n_head: int, + d_model: int, + kernel_size: int = 33, + pe_maxlen: int = 5000, + ): + super().__init__() + self.odim = d_model + + self.input_preprocessor = Conv2dSubsampling(idim, d_model) + self.positional_encoding = RelPositionalEncoding(d_model) + + self.layer_stack = nn.ModuleList() + for _ in range(n_layers_enc): + block = RelPosEmbConformerBlock(d_model, n_head, kernel_size) + self.layer_stack.append(block) + + def forward( + self, padded_input: torch.Tensor, input_lengths: torch.Tensor, pad: bool = True + ): + if pad: + padded_input = F.pad( + padded_input, + (0, 0, 0, self.input_preprocessor.context - 1), + "constant", + 0.0, + ) + src_mask = self.padding_position_is_0(padded_input, input_lengths) + + embed_output, input_lengths, src_mask = self.input_preprocessor( + padded_input, src_mask + ) + enc_output = embed_output + + pos_emb = self.positional_encoding(embed_output) + + enc_outputs = [] + for enc_layer in self.layer_stack: + enc_output = enc_layer( + enc_output, pos_emb, slf_attn_mask=src_mask, pad_mask=src_mask + ) + enc_outputs.append(enc_output) + + return enc_output, input_lengths, src_mask + + def padding_position_is_0( + self, padded_input: torch.Tensor, input_lengths: torch.Tensor + ) -> torch.Tensor: + N, T = padded_input.size()[:2] + mask = torch.ones((N, T)).to(padded_input.device) + for i in range(N): + mask[i, input_lengths[i] :] = 0 + mask = mask.unsqueeze(dim=1) + return mask.to(torch.uint8) + + +class FireRedASR2Adapter(nn.Module): + def __init__(self, encoder_dim: int, llm_dim: int, downsample_rate: int = 2): + super().__init__() + self.ds = downsample_rate + self.linear1 = ReplicatedLinear( + input_size=encoder_dim * downsample_rate, + output_size=llm_dim, + bias=True, + ) + self.relu = _ACTIVATION_REGISTRY["relu"] + self.linear2 = ReplicatedLinear( + input_size=llm_dim, + output_size=llm_dim, + bias=True, + ) + + def forward(self, x, x_lens): + batch_size, seq_len, feat_dim = x.size() + num_frames_to_discard = seq_len % self.ds + if num_frames_to_discard > 0: + x = x[:, :-num_frames_to_discard, :] + seq_len = x.size(1) + + x = x.contiguous() + x = x.view(batch_size, seq_len // self.ds, feat_dim * self.ds) + + x, _ = self.linear1(x) + x = 
self.relu(x) + x, _ = self.linear2(x) + + new_x_lens = torch.clamp(x_lens, max=seq_len) // self.ds + return x, new_x_lens + + +class FireRedASR2Encoder(nn.Module): + def __init__( + self, + *, + vllm_config: VllmConfig, + ): + super().__init__() + self.audio_encoder = ConformerEncoder( + **vllm_config.model_config.hf_config.audio_encoder_conf + ) + + +class FireRedASR2Model(nn.Module): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + self.encoder = FireRedASR2Encoder( + vllm_config=vllm_config, + ) + encoder_dim = self.encoder.audio_encoder.odim + llm_dim = vllm_config.model_config.hf_config.hidden_size + self.encoder_projector = FireRedASR2Adapter( + encoder_dim, + llm_dim, + vllm_config.model_config.hf_config.encoder_downsample_rate, + ) + + self.decoder = Qwen2ForCausalLM( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "decoder") + ) + + def forward( + self, + input_ids: torch.Tensor | None, + positions: torch.Tensor, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor: + decoder_outputs = self.decoder( + input_ids=input_ids, + positions=positions, + inputs_embeds=inputs_embeds, + ) + return decoder_outputs + + def get_encoder_outputs( + self, + speech: torch.Tensor | list[torch.Tensor] | None, + speech_lengths: torch.Tensor | list[torch.Tensor] | None, + ) -> torch.Tensor | None: + encoder_outs, enc_lengths, enc_mask = self.encoder.audio_encoder( + speech, speech_lengths + ) + speech_features, speech_lens = self.encoder_projector(encoder_outs, enc_lengths) + return speech_features + + +class FireRedASR2ProcessingInfo(BaseProcessingInfo): + def get_hf_config(self) -> Qwen2Config: + return self.ctx.get_hf_config(Qwen2Config) + + def get_supported_mm_limits(self) -> Mapping[str, int | None]: + return {"audio": 1} + + def get_feature_extractor(self, **kwargs: object) -> FireRedASR2FeatureExtractor: + hf_processor = self.get_hf_processor(**kwargs) + feature_extractor = hf_processor.feature_extractor # type: ignore + assert isinstance(feature_extractor, FireRedASR2FeatureExtractor) + return feature_extractor + + def get_data_parser(self) -> MultiModalDataParser: + feature_extractor = self.get_feature_extractor() + return MultiModalDataParser( + target_sr=feature_extractor.sampling_rate, + target_channels=self.get_target_channels(), + ) + + def get_target_channels(self) -> int: + return 1 + + +class FireRedASR2DummyInputsBuilder(BaseDummyInputsBuilder[FireRedASR2ProcessingInfo]): + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_audios = mm_counts.get("audio", 0) + + return "<|AUDIO|>" * num_audios + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + mm_options: Mapping[str, BaseDummyOptions], + ) -> MultiModalDataDict: + feature_extractor = self.info.get_feature_extractor() + + sampling_rate = feature_extractor.sampling_rate + audio_len = feature_extractor.chunk_length * sampling_rate + num_audios = mm_counts.get("audio", 0) + + audio_overrides = mm_options.get("audio") + + ret = { + "audio": self._get_dummy_audios( + length=audio_len, num_audios=num_audios, overrides=audio_overrides + ) + } + return ret + + +class FireRedASR2MultiModalProcessor( + BaseMultiModalProcessor[FireRedASR2ProcessingInfo] +): + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + if mm_data: + feature_extractor = self.info.get_feature_extractor(**mm_kwargs) + mm_data = 
dict(audio=mm_data.pop("audios")) + mm_kwargs = dict( + **mm_kwargs, + sampling_rate=feature_extractor.sampling_rate, + ) + processed_outputs = super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, + ) + if "labels" in processed_outputs: + processed_outputs["input_ids"] = processed_outputs.pop("labels") + return processed_outputs + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + input_features=MultiModalFieldConfig.batched("audio"), + speech_lengths=MultiModalFieldConfig.batched("audio"), + fake_token_lengths=MultiModalFieldConfig.batched("audio"), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + tokenizer = self.info.get_tokenizer() + vocab = tokenizer.get_vocab() + + audio_token = getattr(processor, "audio_token", "<|AUDIO|>") + + audio_token_id = vocab[audio_token] + + out_mm_data = out_mm_kwargs.get_data() + + fake_token_lengths = out_mm_data.get("fake_token_lengths") + + if fake_token_lengths is None: + audio_output_lengths = [] + else: + assert isinstance(fake_token_lengths, torch.Tensor) + + audio_output_lengths = fake_token_lengths.tolist() + + def get_replacement_fireredasr2_audio(item_idx: int): + num_features = audio_output_lengths[item_idx] + + audio_tokens = [audio_token_id] * int(num_features) + + return PromptUpdateDetails.select_token_id( + audio_tokens, + embed_token_id=audio_token_id, + ) + + return [ + PromptReplacement( + modality="audio", + target=[audio_token_id], + replacement=get_replacement_fireredasr2_audio, + ) + ] + + +@MULTIMODAL_REGISTRY.register_processor( + FireRedASR2MultiModalProcessor, + info=FireRedASR2ProcessingInfo, + dummy_inputs=FireRedASR2DummyInputsBuilder, +) +class FireRedASR2ForConditionalGeneration( + nn.Module, SupportsTranscription, SupportsMultiModal +): + packed_modules_mapping = { + "self_attn.qkv_proj": [ + "self_attn.q_proj", + "self_attn.k_proj", + "self_attn.v_proj", + ], + "encoder_attn.kv_proj": ["encoder_attn.k_proj", "encoder_attn.v_proj"], + } + + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + "llm.": "model.decoder.", + "encoder.": "model.encoder.audio_encoder.", + "encoder_projector.": "model.encoder_projector.", + "net.0": "pre_layer_norm", + "net.1": "linear_expand", + "net.4": "linear_project", + } + ) + + supports_transcription_only = True + supports_segment_timestamp = True + supported_languages = ISO639_1_SUPPORTED_LANGS + + @classmethod + def validate_language(cls, language: str | None) -> str | None: + if language is None: + # TODO language should be optional and can be guessed. + # For now we default to en. See + # https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520 + logger.warning( + "Defaulting to language='en'. If you wish to transcribe " + "audio in a different language, pass the `language` field " + "in the TranscriptionRequest." 
+ ) + language = "en" + return super().validate_language(language) + + @classmethod + def get_generation_prompt( + cls, + audio: np.ndarray, + model_config: ModelConfig, # not needed here + stt_config: SpeechToTextConfig, + language: str | None, + task_type: Literal["transcribe", "translate"], + request_prompt: str, + to_language: str | None, + ) -> PromptType: + if language is None: + raise ValueError( + "Language must be specified when creating the fireredasr2 prompt" + ) + + prompt_str = "<|im_start|>user\n<|AUDIO|>请转写音频为文字<|im_end|>\n<|im_start|>assistant\n" # noqa: E501 + prompt = { + "prompt": prompt_str, + "multi_modal_data": { + "audio": (audio, stt_config.sample_rate), + }, + } + return cast(PromptType, prompt) + + @classmethod + def get_speech_to_text_config( + cls, model_config: ModelConfig, task_type: str + ) -> SpeechToTextConfig: + processor = cached_processor_from_config(model_config) + + return SpeechToTextConfig( + max_audio_clip_s=processor.feature_extractor.chunk_length, + sample_rate=processor.feature_extractor.sampling_rate, + ) + + @classmethod + def get_num_audio_tokens( + cls, + audio_duration_s: float, + stt_config: SpeechToTextConfig, + model_config: ModelConfig, + ) -> int | None: + processor = cached_processor_from_config(model_config) + hop_length = processor.feature_extractor.hop_length + assert hop_length is not None + return math.ceil(audio_duration_s * stt_config.sample_rate / hop_length) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + self.config = config + self.dtype = vllm_config.model_config.dtype + + self.model = FireRedASR2Model( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model"), + ) + logit_scale = getattr(config, "logit_scale", 1.0) + + self.logits_processor = LogitsProcessor(config.vocab_size, scale=logit_scale) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + inputs_embeds: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor: + decoder_outputs = self.model( + input_ids=input_ids, + positions=positions, + inputs_embeds=inputs_embeds, + ) + return decoder_outputs + + def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: + audio_input = self._parse_and_validate_audio_input(**kwargs) + + speech = audio_input["input_features"] + speech_lengths = audio_input["speech_lengths"].to(torch.int32) + enc_output = self.model.get_encoder_outputs( + speech=speech, speech_lengths=speech_lengths + ) + + return enc_output + + def embed_input_ids( + self, + input_ids: torch.Tensor, + multimodal_embeddings: MultiModalEmbeddings | None = None, + *, + is_multimodal: torch.Tensor | None = None, + handle_oov_mm_token: bool = False, + ) -> torch.Tensor: + inputs_embeds = self.model.decoder.embed_input_ids(input_ids) + + ret = _merge_multimodal_embeddings( + inputs_embeds=inputs_embeds, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=_require_is_multimodal(is_multimodal), + ) + return ret + + def _parse_and_validate_audio_input( + self, **kwargs: object + ) -> FireRedASR2AudioInputs: + input_features = kwargs.pop("input_features", None) + speech_lengths = kwargs.pop("speech_lengths", None) + fake_token_lengths = kwargs.pop("fake_token_lengths", None) + + return FireRedASR2AudioInputs( + input_features=input_features, + speech_lengths=speech_lengths, + fake_token_lengths=fake_token_lengths, + ) + + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: + logits = 
self.logits_processor(self.model.decoder.lm_head, hidden_states) + return logits + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, skip_prefixes=["model.encoder.audio_encoder.positional_encoding.pe"] + ) + + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 7f6b7e300227..1e5accaf38bd 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -341,6 +341,10 @@ "ernie45_vl", "Ernie4_5_VLMoeForConditionalGeneration", ), + "FireRedASR2ForConditionalGeneration": ( + "fireredasr2", + "FireRedASR2ForConditionalGeneration", + ), "FunASRForConditionalGeneration": ("funasr", "FunASRForConditionalGeneration"), # noqa: E501 "FunAudioChatForConditionalGeneration": ( "funaudiochat", diff --git a/vllm/transformers_utils/processors/__init__.py b/vllm/transformers_utils/processors/__init__.py index d726fd39a40e..0660a62ea262 100644 --- a/vllm/transformers_utils/processors/__init__.py +++ b/vllm/transformers_utils/processors/__init__.py @@ -10,6 +10,9 @@ from vllm.transformers_utils.processors.bagel import BagelProcessor from vllm.transformers_utils.processors.deepseek_vl2 import DeepseekVLV2Processor +from vllm.transformers_utils.processors.fireredasr2_processor import ( + FireRedASR2Processor, +) from vllm.transformers_utils.processors.funasr_processor import FunASRProcessor from vllm.transformers_utils.processors.hunyuan_vl import HunYuanVLProcessor from vllm.transformers_utils.processors.hunyuan_vl_image import HunYuanVLImageProcessor @@ -19,6 +22,7 @@ __all__ = [ "BagelProcessor", "DeepseekVLV2Processor", + "FireRedASR2Processor", "FunASRProcessor", "HunYuanVLProcessor", "HunYuanVLImageProcessor", diff --git a/vllm/transformers_utils/processors/fireredasr2_processor.py b/vllm/transformers_utils/processors/fireredasr2_processor.py new file mode 100644 index 000000000000..67c74ab15921 --- /dev/null +++ b/vllm/transformers_utils/processors/fireredasr2_processor.py @@ -0,0 +1,341 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +import kaldi_native_fbank as knf +import numpy as np +import torch +import torch.nn.functional as F +from transformers import ( + AutoFeatureExtractor, + AutoProcessor, + BatchFeature, +) +from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor +from transformers.processing_utils import ProcessorMixin +from transformers.utils import TensorType + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class CMVN: + def __init__(self, dim, means, inverse_std_variences): + self.dim, self.means, self.inverse_std_variences = ( + dim, + np.array(means), + np.array(inverse_std_variences), + ) + + def __call__(self, x): + assert x.shape[-1] == self.dim, "CMVN dim mismatch" + out = x - self.means + out = out * self.inverse_std_variences + return out + + +class KaldifeatFbank: + def __init__(self, num_mel_bins=80, frame_length=25, frame_shift=10, dither=1.0): + self.dither = dither + opts = knf.FbankOptions() + opts.frame_opts.dither = dither + opts.mel_opts.num_bins = num_mel_bins + opts.frame_opts.snip_edges = True + opts.mel_opts.debug_mel = False + self.opts = opts + + def __call__(self, sample_rate, wav_np, is_train=False): + dither = self.dither if is_train else 0.0 + self.opts.frame_opts.dither = dither + fbank = knf.OnlineFbank(self.opts) + + 
fbank.accept_waveform(sample_rate, wav_np.tolist()) + feat = [] + for i in range(fbank.num_frames_ready): + feat.append(fbank.get_frame(i)) + if len(feat) == 0: + print("Check data, len(feat) == 0", wav_np, flush=True) + return np.zeros((0, self.opts.mel_opts.num_bins)) + feat = np.vstack(feat) + return feat + + +class FireRedASR2FeatureExtractor(SequenceFeatureExtractor): + r""" + Constructs a FireRedASR2 feature extractor. + + This feature extractor inherits from [`~feature_extraction_sequence_ + utils.SequenceFeatureExtractor`] which contains most of the main + methods. Users should refer to this superclass for more information + regarding those methods. + + This class extracts mel-filter bank features from raw speech using a custom + numpy implementation of the `Short Time Fourier Transform` which should + match pytorch's `torch.stft` equivalent. + + Args: + feature_size (`int`, *optional*, defaults to 80): + The feature dimension of the extracted features. + sampling_rate (`int`, *optional*, defaults to 16000): + The sampling rate at which the audio files should be digitalized + expressed in hertz (Hz). + chunk_length (`int`, *optional*, defaults to 30): + The maximum number of chunks of `sampling_rate` samples used to + trim and pad longer or shorter audio sequences. + padding_value (`float`, *optional*, defaults to 0.0): + Padding value used to pad the audio. Should correspond to silences. + dither (`float`, *optional*, defaults to 0.0): + Adds dithering. In other words, adds a small Gaussian noise to each frame. + E.g. use 0.0001 to add dithering with a normal distribution centered + around 0.0 with standard deviation 0.0001 (assuming [-1,+1] range + of raw_speech). The value 0.0 means no dithering. + Dithering has similar effect as `spectrogram(mel_floor=...)`. It reduces + the high log_mel_fbank values for signals with hard-zero sections, + when VAD cutoff is present in the signal. 
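+        downsample_rate (`int`, *optional*, defaults to 2):
+            Temporal downsampling factor of the adapter, used to estimate the
+            number of audio placeholder tokens per utterance.
+        left_context (`int`, *optional*, defaults to 3):
+            Left context frames of the convolutional subsampling front end.
+        right_context (`int`, *optional*, defaults to 3):
+            Right context frames of the convolutional subsampling front end.
+        max_length (`int`, *optional*, defaults to 3000):
+            Maximum number of feature frames used when padding batched features.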
+ """ + + model_input_names = ["input_features"] + + def __init__( + self, + feature_size=80, + sampling_rate=16000, + chunk_length=30, + padding_value=0.0, + return_attention_mask=False, + dim=80, + means=None, + inverse_std_variences=None, + num_mel_bins=80, + frame_length=25, + frame_shift=10, + dither=0.0, + max_length=3000, + downsample_rate=2, + left_context=3, + right_context=3, + **kwargs, + ): + super().__init__( + feature_size=feature_size, + sampling_rate=sampling_rate, + padding_value=padding_value, + return_attention_mask=return_attention_mask, + **kwargs, + ) + self.chunk_length = chunk_length + self.max_length = max_length + self.dim = dim + self.means = means + self.inverse_std_variences = inverse_std_variences + self.num_mel_bins = num_mel_bins + self.frame_length = frame_length + self.frame_shift = frame_shift + self.dither = dither + self.sampling_rate = sampling_rate + self.downsample_rate = downsample_rate + self.context = left_context + 1 + right_context + + def __call__( + self, + raw_speech: np.ndarray | list[float] | list[np.ndarray] | list[list[float]], + truncation: bool = True, + pad_to_multiple_of: int | None = None, + return_tensors: str | TensorType | None = None, + return_attention_mask: bool | None = None, + padding: str | None = "max_length", + max_length: int | None = None, + sampling_rate: int | None = None, + do_normalize: bool | None = None, + **kwargs, + ) -> BatchFeature: + if sampling_rate != self.sampling_rate: + raise ValueError( + f"The model corresponding to this feature extractor: " + f"{self.__class__.__name__} was trained using a sampling " + f"rate of {self.sampling_rate}. Please make sure that the " + f"provided `raw_speech` input was sampled with " + f"{self.sampling_rate} and not {sampling_rate}." + ) + + def padding_position_is_0(padded_input, input_lengths): + N, T = padded_input.size()[:2] + mask = torch.ones((N, T)).to(padded_input.device) + for i in range(N): + mask[i, input_lengths[i] :] = 0 + mask = mask.unsqueeze(dim=1) + return mask.to(torch.uint8) + + # initialize the CMVN and Fbank objects + self.cmvn = CMVN(self.dim, self.means, self.inverse_std_variences) + self.fbank = KaldifeatFbank( + num_mel_bins=self.num_mel_bins, + frame_length=self.frame_length, + frame_shift=self.frame_shift, + dither=self.dither, + ) + + feats = [] + speech_lengths = [] + fake_token_lengths = [] + for speech in raw_speech: + """ + We must multiply by 32768 here because FireRedASR2 loads audio data + using kaldiio.load_mat, while vLLM loads audio data using librosa. 
+ """ + speech = speech * 32768 + fbank = self.fbank(sampling_rate, speech) + fbank = self.cmvn(fbank) + fbank = torch.from_numpy(fbank).float() + length = fbank.size(0) + feats.append(fbank) + speech_lengths.append(length) + padded_input2 = fbank + padded_input2 = F.pad( + padded_input2, (0, 0, 0, self.context - 1), "constant", 0.0 + ) + src_mask = padding_position_is_0( + padded_input2[None, :, :], torch.tensor([length], dtype=torch.int32) + ) + x_mask = src_mask + mask = x_mask[:, :, :-2:2][:, :, :-2:2] + input_lengths = mask[:, -1, :].sum(dim=-1) + input_lengths = input_lengths // self.downsample_rate + fake_token_len = torch.clamp(input_lengths, min=1) + fake_token_lengths.append(fake_token_len) + + feats = torch.stack(feats, dim=0) + batched_speech = self.pad( + BatchFeature({"input_features": feats}), + padding=padding, + max_length=max_length if max_length else self.max_length, + truncation=truncation, + pad_to_multiple_of=pad_to_multiple_of, + return_attention_mask=return_attention_mask or do_normalize, + ) + + if return_tensors is not None: + batched_speech = batched_speech.convert_to_tensors(return_tensors) + + batched_speech["speech_lengths"] = torch.tensor(speech_lengths) + batched_speech["fake_token_lengths"] = torch.concat(fake_token_lengths) + return batched_speech + + +class FireRedASR2Processor(ProcessorMixin): + r""" + Constructs a FireRedASR2 processor which wraps a FireRedASR2 feature extractor and + a FireRedASR2 tokenizer into a single processor. + + [`FireRedASR2Processor`] offers all the functionalities of + [`FireRedASR2FeatureExtractor`] and [`Qwen2Tokenizer`]. See the + [`~FireRedASR2Processor.__call__`] and [`~FireRedASR2Processor.decode`] for more + information. + + Args: + feature_extractor (`FireRedASR2FeatureExtractor`): An instance of + [`FireRedASR2FeatureExtractor`]. + The feature extractor is a required input. + tokenizer (`Qwen2Tokenizer`): + An instance of [`Qwen2Tokenizer`]. The tokenizer is a required + input. + """ + + feature_extractor_class = "FireRedASR2FeatureExtractor" + tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast") + + def __init__( + self, + feature_extractor, + tokenizer, + audio_token="<|AUDIO|>", + ): + super().__init__(feature_extractor, tokenizer) + self.current_processor = self.feature_extractor + self._in_target_context_manager = False + self.audio_token = ( + tokenizer.audio_token if hasattr(tokenizer, "audio_token") else audio_token + ) + self.audio_token_id = tokenizer.convert_tokens_to_ids(self.audio_token) + + def get_decoder_prompt_ids(self, task=None, language=None, no_timestamps=True): + return self.tokenizer.get_decoder_prompt_ids( + task=task, language=language, no_timestamps=no_timestamps + ) + + def __call__(self, *args, **kwargs): + """ + Forwards the `audio` argument to FireRedASR2FeatureExtractor's + [`~FireRedASR2FeatureExtractor.__call__`] and the `text` argument to + [`~Qwen2Tokenizer.__call__`]. Please refer to the docstring of the + above two methods for more information. + """ + if self._in_target_context_manager: + return self.current_processor(*args, **kwargs) + + audio = kwargs.pop("audio", None) + sampling_rate = kwargs.pop("sampling_rate", None) + text = kwargs.pop("text", None) + if len(args) > 0: + audio = args[0] + args = args[1:] + + if text is None: + raise ValueError("You need to specify `text` input to process.") + elif isinstance(text, str): + text = [text] + elif not isinstance(text, list) and not isinstance(text[0], str): + raise ValueError( + "Invalid input text. 
Please provide a string, or a list of strings" + ) + + if audio is not None: + # ensure we have as much audios as audio tokens + num_audio_tokens = sum(sample.count(self.audio_token) for sample in text) + num_audios = 1 if type(audio) is np.ndarray else len(audio) + if num_audio_tokens != num_audios: + raise ValueError( + f"Found {num_audio_tokens} {self.audio_token} token{'s' if num_audio_tokens > 1 else ''} in provided text but received {num_audios} audio{'s' if num_audios > 1 else ''}" # noqa: E501 + ) + inputs = self.feature_extractor( + audio, *args, sampling_rate=sampling_rate, **kwargs + ) + + expanded_text = [] + for sample in text: + replace_str = [] + while self.audio_token in sample: + num_audio_tokens = int(inputs["fake_token_lengths"].item()) + + expanded_audio_token = self.audio_token * num_audio_tokens + + replace_str.append(expanded_audio_token) + sample = sample.replace(self.audio_token, "", 1) + + while "" in sample: + sample = sample.replace("", replace_str.pop(0), 1) + expanded_text.append(sample) + text = expanded_text + + if text is not None: + encodings = self.tokenizer(text, **kwargs) + + if text is None: + return inputs + + elif audio is None: + return encodings + else: + inputs["labels"] = encodings["input_ids"] + + return inputs + + def get_prompt_ids(self, text: str, return_tensors="np"): + return self.tokenizer.get_prompt_ids(text, return_tensors=return_tensors) + + +AutoFeatureExtractor.register( + "FireRedASR2FeatureExtractor", FireRedASR2FeatureExtractor +) +AutoProcessor.register("FireRedASR2Processor", FireRedASR2Processor) From 6e9f21e8a2ba1e53ee4f1cff4844e11ce600f7fa Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Wed, 4 Mar 2026 11:50:58 +0800 Subject: [PATCH 45/53] [Chore] Remove debug code in model implementation (#35883) Signed-off-by: Isotr0py --- vllm/model_executor/models/funaudiochat.py | 80 --------------- .../model_executor/models/nano_nemotron_vl.py | 98 ------------------- 2 files changed, 178 deletions(-) diff --git a/vllm/model_executor/models/funaudiochat.py b/vllm/model_executor/models/funaudiochat.py index 5bcb49e075b3..2265d0424e43 100644 --- a/vllm/model_executor/models/funaudiochat.py +++ b/vllm/model_executor/models/funaudiochat.py @@ -13,7 +13,6 @@ from __future__ import annotations -import os from collections.abc import Iterable, Mapping, Sequence from functools import cached_property from typing import Any @@ -924,53 +923,6 @@ def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: f"sequence of Tensors (got {type(speech_attention_mask)})" ) - debug = os.getenv("VLLM_FUN_AUDIOCHAT_DEBUG", "") == "1" - if debug: - print( - f"[FunAudioChat] embed_multimodal speech_ids={tuple(speech_ids.shape)} " - f"speech_attention_mask={tuple(speech_attention_mask.shape)}", - flush=True, - ) - attn_impl = getattr( - self.continuous_audio_tower.config, "_attn_implementation", None - ) - print( - f"[FunAudioChat] audio_attn_impl={attn_impl}", - flush=True, - ) - if hasattr(self.continuous_audio_tower, "conv1"): - conv1_w = self.continuous_audio_tower.conv1.weight - print( - f"[FunAudioChat] conv1_w_norm={float(conv1_w.norm().item()):.6g}", - flush=True, - ) - try: - attn0 = self.continuous_audio_tower.layers[0].self_attn - q_norm = float(attn0.q_proj.weight.norm().item()) - k_norm = float(attn0.k_proj.weight.norm().item()) - v_norm = float(attn0.v_proj.weight.norm().item()) - o_norm = float(attn0.out_proj.weight.norm().item()) - print( - f"[FunAudioChat] attn0_q_norm={q_norm:.6g} " - f"k_norm={k_norm:.6g} " - f"v_norm={v_norm:.6g} " - 
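# (Illustrative sketch, not part of this diff.) The env-var-gated print() calls
# removed in this commit are usually replaced by level-gated logging; a minimal
# sketch assuming the stdlib-style logger returned by vllm.logger.init_logger:
from vllm.logger import init_logger

logger = init_logger(__name__)
# Emitted only when debug logging is enabled (e.g. VLLM_LOGGING_LEVEL=DEBUG);
# the %-style arguments are not formatted unless the record is actually emitted.
logger.debug("embed_multimodal speech_ids=%s", (1, 80))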
f"o_norm={o_norm:.6g}", - flush=True, - ) - except Exception: - pass - if isinstance(input_features, torch.Tensor): - print( - f"[FunAudioChat] input_features={tuple(input_features.shape)}", - flush=True, - ) - if isinstance(feature_attention_mask, torch.Tensor): - print( - "[FunAudioChat] feature_attention_mask=" - f"{tuple(feature_attention_mask.shape)}", - flush=True, - ) - group_size = int(self.audio_tower.group_size) speech_maxlen = int(speech_ids.shape[-1]) @@ -1019,38 +971,6 @@ def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: embeds = tuple( audio_features[i, : int(length)] for i, length in enumerate(lengths) ) - if debug: - embed_lens = [int(t.shape[0]) for t in embeds] - print(f"[FunAudioChat] embed_multimodal out_lens={embed_lens}", flush=True) - if embeds: - t0 = embeds[0] - print( - f"[FunAudioChat] embed0 dtype={t0.dtype} device={t0.device} " - f"nan={bool(torch.isnan(t0).any())} " - f"norm={float(t0.norm().item()):.6g}", - flush=True, - ) - dump_path = os.getenv("VLLM_FUN_AUDIOCHAT_DUMP_PATH", "") - if ( - dump_path - and speech_ids.shape[0] == 1 - and len(embeds) == 1 - and embed_lens[0] > 10 - ): - if not os.path.exists(dump_path): - np.save(dump_path, embeds[0].detach().float().cpu().numpy()) - print(f"[FunAudioChat] dumped embeds to {dump_path}", flush=True) - cont_path = dump_path.replace(".npy", "_cont.npy") - if continuous_audio_features is not None and not os.path.exists( - cont_path - ): - np.save( - cont_path, - continuous_audio_features.detach().float().cpu().numpy(), - ) - print( - f"[FunAudioChat] dumped continuous to {cont_path}", flush=True - ) return embeds def forward( diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py index 51b36b1cae38..82422e89f0b3 100644 --- a/vllm/model_executor/models/nano_nemotron_vl.py +++ b/vllm/model_executor/models/nano_nemotron_vl.py @@ -2225,104 +2225,6 @@ def is_sound_weights(name: str) -> bool: assert len(sound_weights) > 0 self.sound_encoder.load_weights(sound_weights) - def print_architecture(self, detailed: bool = True, save_to_file: str = None): - """ - Print model architecture with parameter names, shapes, and sizes. 
- - Args: - detailed: If True, show detailed parameter breakdown - save_to_file: If provided, save output to this file path - """ - import sys - from io import StringIO - - # Capture output if saving to file - original_stdout = sys.stdout - if save_to_file: - sys.stdout = StringIO() - - try: - print("=" * 100) - print("NemotronH_Nano_VL_V2 Model Architecture") - print("=" * 100) - - total_params = 0 - param_groups = { - "language_model": [], - "vision_model": [], - "mlp1": [], - "other": [], - } - - for name, param in self.named_parameters(): - param_size = param.numel() - total_params += param_size - - # Group parameters by main component - if name.startswith("language_model"): - param_groups["language_model"].append( - (name, param.shape, param_size, param.dtype) - ) - elif name.startswith("vision_model"): - param_groups["vision_model"].append( - (name, param.shape, param_size, param.dtype) - ) - elif name.startswith("mlp1"): - param_groups["mlp1"].append( - (name, param.shape, param_size, param.dtype) - ) - else: - param_groups["other"].append( - (name, param.shape, param_size, param.dtype) - ) - - if detailed: - print( - f"{name:<70} | Shape: {str(param.shape):<25} | " - f"Size: {param_size:>12,} | Dtype: {param.dtype}" - ) - - print("=" * 100) - print("Summary by Component:") - print("-" * 60) - - for component, params in param_groups.items(): - if params: # Only show components that have parameters - component_total = sum(size for _, _, size, _ in params) - percentage = ( - (component_total / total_params) * 100 - if total_params > 0 - else 0 - ) - print( - f"{component:<20} | Parameters: {len(params):>4} | " - f"Total Size: {component_total:>15,} | " - f"{percentage:>6.2f}%" - ) - - print("-" * 60) - print(f"{'Total Parameters':<20} | {total_params:>15,}") - - # Estimate memory usage (assuming bfloat16 = 2 bytes per parameter) - memory_mb = total_params * 2 / (1024**2) - memory_gb = memory_mb / 1024 - print(f"{'Est. Memory (MB)':<20} | {memory_mb:>15.2f}") - print(f"{'Est. 
Memory (GB)':<20} | {memory_gb:>15.2f}") - print("=" * 100) - - # Save to file if requested - if save_to_file: - output = sys.stdout.getvalue() - sys.stdout = original_stdout - with open(save_to_file, "w") as f: - f.write(output) - print(f"Architecture saved to: {save_to_file}") - print(output) # Also print to console - - finally: - if save_to_file and sys.stdout != original_stdout: - sys.stdout = original_stdout - def get_vit_model_from_radio_config(self, hf_config): hf_config_vision = hf_config.vision_config model_name = hf_config_vision.args.get("model") From e3793961674af8bf01208b2216542ad00ae325e6 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 4 Mar 2026 11:53:53 +0800 Subject: [PATCH 46/53] [Refactor] Clean up processor kwargs extraction (#35872) Signed-off-by: DarkLight1337 --- tests/transformers_utils/test_processor.py | 11 +-- vllm/transformers_utils/processor.py | 94 +++++++++++----------- 2 files changed, 55 insertions(+), 50 deletions(-) diff --git a/tests/transformers_utils/test_processor.py b/tests/transformers_utils/test_processor.py index 95ff9a557fa0..a3a1c7841865 100644 --- a/tests/transformers_utils/test_processor.py +++ b/tests/transformers_utils/test_processor.py @@ -7,7 +7,8 @@ from typing_extensions import Unpack from vllm.transformers_utils.processor import ( - get_processor_kwargs_from_processor, + get_processor_kwargs_keys, + get_processor_kwargs_type, ) @@ -35,7 +36,7 @@ def _assert_has_all_expected(keys: set[str]) -> None: assert k in keys -# Path 1: __call__ method has kwargs: Unpack[*ProcessingKwargs] +# Path 1: __call__ method has kwargs: Unpack[*ProcessorKwargs] class _ProcWithUnpack: def __call__(self, *args, **kwargs: Unpack[_FakeProcessorKwargs]): # type: ignore return None @@ -43,11 +44,11 @@ def __call__(self, *args, **kwargs: Unpack[_FakeProcessorKwargs]): # type: igno def test_get_processor_kwargs_from_processor_unpack_path_returns_full_union(): proc = _ProcWithUnpack() - keys = get_processor_kwargs_from_processor(proc) + keys = get_processor_kwargs_keys(get_processor_kwargs_type(proc)) _assert_has_all_expected(keys) -# ---- Path 2: No Unpack, fallback to scanning *ProcessingKwargs in module ---- +# ---- Path 2: No Unpack, fallback to scanning *ProcessorKwargs in module ---- class _ProcWithoutUnpack: @@ -62,5 +63,5 @@ def test_get_processor_kwargs_from_processor_module_scan_returns_full_union(): assert hasattr(mod, "_FakeProcessorKwargs") proc = _ProcWithoutUnpack() - keys = get_processor_kwargs_from_processor(proc) + keys = get_processor_kwargs_keys(get_processor_kwargs_type(proc)) _assert_has_all_expected(keys) diff --git a/vllm/transformers_utils/processor.py b/vllm/transformers_utils/processor.py index 9bedefd19d20..9190c82f50e6 100644 --- a/vllm/transformers_utils/processor.py +++ b/vllm/transformers_utils/processor.py @@ -111,29 +111,6 @@ def _get_processor_factory_fn(processor_cls: type | tuple[type, ...]): return processor_cls -@lru_cache -def _collect_dynamic_keys_from_processing_kwargs(kwargs_cls: type) -> set[str]: - dynamic_kwargs: set[str] = set() - if kwargs_cls is None: - return dynamic_kwargs - # get kwargs annotations in processor - # merge text_kwargs / images_kwargs / videos_kwargs / audio_kwargs - kwargs_type_annotations = get_type_hints(kwargs_cls) - for kw_type in ("text_kwargs", "images_kwargs", "videos_kwargs", "audio_kwargs"): - if kw_type in kwargs_type_annotations: - # Use __annotations__ instead of get_type_hints() to avoid - # NameError from unresolved forward references (e.g. - # PILImageResampling). 
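# (Illustrative sketch, not part of this diff.) "Path 1" from the test above:
# when a processor annotates __call__ as **kwargs: Unpack[SomeProcessorKwargs],
# the kwargs TypedDict can be recovered from the signature without calling the
# processor. _DemoProcessorKwargs and _DemoProcessor are made-up stand-ins:
import inspect
from typing import TypedDict, get_args

from typing_extensions import Unpack


class _DemoProcessorKwargs(TypedDict, total=False):
    padding: bool


class _DemoProcessor:
    def __call__(self, *args, **kwargs: Unpack[_DemoProcessorKwargs]):
        return None


param = inspect.signature(_DemoProcessor.__call__).parameters["kwargs"]
kwargs_cls = get_args(param.annotation)[0]
print(kwargs_cls is _DemoProcessorKwargs)   # True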
We only need key names, not types. - kw_cls = kwargs_type_annotations[kw_type] - kw_annotations: dict[str, Any] = {} - for base in reversed(kw_cls.__mro__): - kw_annotations.update(getattr(base, "__annotations__", {})) - for kw_name in kw_annotations: - dynamic_kwargs.add(kw_name) - dynamic_kwargs |= {"text_kwargs", "images_kwargs", "videos_kwargs", "audio_kwargs"} - return dynamic_kwargs - - def _merge_mm_kwargs( model_config: "ModelConfig", processor_cls: type | tuple[type, ...], @@ -224,38 +201,63 @@ def get_processor( @lru_cache -def get_processor_kwargs_from_processor(processor: _P) -> set[str]: +def get_processor_kwargs_type( + processor: ProcessorMixin, +) -> type[processing_utils.ProcessingKwargs]: try: # get kwargs annotations in processor - call_kwargs = inspect.signature(type(processor).__call__).parameters.get( - "kwargs" - ) + call_params = inspect.signature(type(processor).__call__).parameters + call_kwargs = call_params.get("kwargs") call_kwargs_annotations = call_kwargs.annotation if call_kwargs else None + # if the processor has explicit kwargs annotation, use it if call_kwargs_annotations not in (None, inspect._empty): # get_type_hints will parse all type annotations at runtime, # and if an annotation refers to a type or # name that hasn’t been imported or defined, it will raise an error. # So we use __annotations__ to get the raw annotations directly. - return _collect_dynamic_keys_from_processing_kwargs( - get_args(call_kwargs_annotations)[0] - ) - # otherwise, try to get from ProcessingKwargs - else: - module_name = type(processor).__module__ - mod = importlib.import_module(module_name) - # find *ProcessingKwargs in the module - processor_kwargs: set[str] = set() - for name, obj in vars(mod).items(): - if name.endswith("ProcessingKwargs"): - processor_kwargs = ( - processor_kwargs - | _collect_dynamic_keys_from_processing_kwargs(obj) - ) - return processor_kwargs + return get_args(call_kwargs_annotations)[0] + + # otherwise, try to get from ProcessorKwargs + module_name = type(processor).__module__ + mod = importlib.import_module(module_name) + for name, obj in vars(mod).items(): + if name.endswith("ProcessorKwargs"): + return obj + except Exception: logger.exception("Failed to collect processor kwargs") - return set() + + return processing_utils.ProcessingKwargs + + +@lru_cache +def get_processor_kwargs_keys( + kwargs_cls: type[processing_utils.ProcessingKwargs], +) -> set[str]: + dynamic_kwargs: set[str] = set() + modality_kwargs = {"text_kwargs", "images_kwargs", "videos_kwargs", "audio_kwargs"} + + try: + # get kwargs annotations in processor + # merge text_kwargs / images_kwargs / videos_kwargs / audio_kwargs + kwargs_type_annotations = get_type_hints(kwargs_cls) + for kw_type in modality_kwargs: + if kw_type in kwargs_type_annotations: + # Use __annotations__ instead of get_type_hints() to avoid + # NameError from unresolved forward references (e.g. + # PILImageResampling). We only need key names, not types. 
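# (Illustrative sketch, not part of this diff.) Why the comment above prefers
# __annotations__ over get_type_hints(): a forward reference that cannot be
# resolved at runtime (e.g. "PILImageResampling") makes get_type_hints() raise
# NameError, while walking __annotations__ across the MRO still yields the key
# names. _DemoKwargs is a made-up stand-in:
from typing import TypedDict, get_type_hints


class _DemoKwargs(TypedDict, total=False):
    resample: "PILImageResampling"  # noqa: F821 - deliberately unresolvable
    size: int


try:
    get_type_hints(_DemoKwargs)
except NameError as exc:
    print("get_type_hints failed:", exc)

keys: set[str] = set()
for base in reversed(_DemoKwargs.__mro__):
    keys.update(getattr(base, "__annotations__", {}))
print(keys)  # {'resample', 'size'} -- names only, types never evaluated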
+ kw_cls = kwargs_type_annotations[kw_type] + kw_annotations: dict[str, Any] = {} + for base in reversed(kw_cls.__mro__): + kw_annotations.update(getattr(base, "__annotations__", {})) + for kw_name in kw_annotations: + dynamic_kwargs.add(kw_name) + + except Exception: + logger.exception("Failed to collect processor kwargs") + + return dynamic_kwargs | modality_kwargs def cached_get_processor_without_dynamic_kwargs( @@ -275,7 +277,9 @@ def cached_get_processor_without_dynamic_kwargs( ) # Step 2: use temporary processor collect dynamic keys - dynamic_keys = get_processor_kwargs_from_processor(processor) + dynamic_keys = get_processor_kwargs_keys( + get_processor_kwargs_type(processor) # type: ignore[arg-type] + ) # Step 3: use dynamic_keys filter kwargs filtered_kwargs = {k: v for k, v in kwargs.items() if k not in dynamic_keys} From edba15045a7419922e7a2e21e5a684682b5b8e05 Mon Sep 17 00:00:00 2001 From: Andreas Karatzas Date: Tue, 3 Mar 2026 22:12:51 -0600 Subject: [PATCH 47/53] [Bugfix] Guard mm_token_type_ids kwarg in get_mrope_input_positions (#35711) Signed-off-by: Andreas Karatzas --- .../models/transformers/multimodal.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/transformers/multimodal.py b/vllm/model_executor/models/transformers/multimodal.py index 3360ce59a763..beacb8266e59 100644 --- a/vllm/model_executor/models/transformers/multimodal.py +++ b/vllm/model_executor/models/transformers/multimodal.py @@ -474,7 +474,19 @@ def get_mrope_input_positions( # can't accept arbitrary args, even if its value is `None` kwargs = {} if mm_token_type_ids: - kwargs["mm_token_type_ids"] = torch.cat(mm_token_type_ids) + if not hasattr(self, "_get_rope_index_accepts_mm_token_type_ids"): + import inspect + + sig = inspect.signature(self.model.get_rope_index) + params = sig.parameters + self._get_rope_index_accepts_mm_token_type_ids = ( + "mm_token_type_ids" in params + or any( + p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values() + ) + ) + if self._get_rope_index_accepts_mm_token_type_ids: + kwargs["mm_token_type_ids"] = torch.cat(mm_token_type_ids) mrope_positions, mrope_position_delta = self.model.get_rope_index( input_ids=torch.tensor(input_tokens).unsqueeze(0), From 3c85cd9d74627735413065c40676205085d76085 Mon Sep 17 00:00:00 2001 From: Charlie Fu Date: Tue, 3 Mar 2026 22:50:13 -0600 Subject: [PATCH 48/53] [Rocm][CI] Fix ROCm LM Eval Large Models (8 Card) (#35913) Signed-off-by: charlifu --- .buildkite/lm-eval-harness/configs/models-large-rocm.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/.buildkite/lm-eval-harness/configs/models-large-rocm.txt b/.buildkite/lm-eval-harness/configs/models-large-rocm.txt index a9a60f348d6a..4fb0b84bc4d8 100644 --- a/.buildkite/lm-eval-harness/configs/models-large-rocm.txt +++ b/.buildkite/lm-eval-harness/configs/models-large-rocm.txt @@ -1,2 +1 @@ Meta-Llama-4-Maverick-17B-128E-Instruct-FP8.yaml -Qwen3-235B-A22B-Instruct-2507-FP8.yaml From 7cdba98edf15f695d74f50a0fbe6882eb393f5cf Mon Sep 17 00:00:00 2001 From: ShiJie Zhong <62382570+ZhongsJie@users.noreply.github.com> Date: Wed, 4 Mar 2026 13:24:46 +0800 Subject: [PATCH 49/53] [BugFix] Support tool_choice=none in the Anthropic API (#35835) Signed-off-by: ZhongsJie Co-authored-by: Chauncey --- vllm/entrypoints/anthropic/protocol.py | 2 +- vllm/entrypoints/anthropic/serving.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/entrypoints/anthropic/protocol.py b/vllm/entrypoints/anthropic/protocol.py index 
19ca28f1d495..c541db5139d3 100644 --- a/vllm/entrypoints/anthropic/protocol.py +++ b/vllm/entrypoints/anthropic/protocol.py @@ -77,7 +77,7 @@ def validate_input_schema(cls, v): class AnthropicToolChoice(BaseModel): """Tool Choice definition""" - type: Literal["auto", "any", "tool"] + type: Literal["auto", "any", "tool", "none"] name: str | None = None @model_validator(mode="after") diff --git a/vllm/entrypoints/anthropic/serving.py b/vllm/entrypoints/anthropic/serving.py index f0110de38cb4..85232e9185f5 100644 --- a/vllm/entrypoints/anthropic/serving.py +++ b/vllm/entrypoints/anthropic/serving.py @@ -349,6 +349,8 @@ def _convert_tool_choice( req.tool_choice = "auto" elif tool_choice_type == "any": req.tool_choice = "required" + elif tool_choice_type == "none": + req.tool_choice = "none" elif tool_choice_type == "tool": req.tool_choice = ChatCompletionNamedToolChoiceParam.model_validate( { From 097eb544e9a22810c9b7a59e586b61627b308362 Mon Sep 17 00:00:00 2001 From: lailoo <1811866786@qq.com> Date: Wed, 4 Mar 2026 13:54:32 +0800 Subject: [PATCH 50/53] [Bugfix] Improve engine ready timeout error message (#35616) Signed-off-by: damaozi <1811866786@qq.com> --- vllm/v1/engine/core_client.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index e19b31396b9b..7e1f1cf418bf 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -609,8 +609,13 @@ def __init__( timeout=VLLM_ENGINE_READY_TIMEOUT_S * 1000 # convert to ms ): raise TimeoutError( - "Timed out waiting for engines to send " - "initial message on input socket." + f"Timed out waiting for engine core processes to " + f"start. This is often caused by slow weight loading " + f"for large models. Waited " + f"{VLLM_ENGINE_READY_TIMEOUT_S}s (configured by " + f"VLLM_ENGINE_READY_TIMEOUT_S). To increase the " + f"timeout, set the environment variable: " + f"VLLM_ENGINE_READY_TIMEOUT_S=" ) identity, _ = sync_input_socket.recv_multipart() identities.remove(identity) @@ -1586,8 +1591,12 @@ async def _scale_up_elastic_ep( timeout=VLLM_ENGINE_READY_TIMEOUT_S * 1000 # convert to ms ): raise TimeoutError( - "Timed out waiting for new engines to send initial " - "message on input socket." + f"Timed out waiting for new engine core processes to " + f"start. Waited " + f"{VLLM_ENGINE_READY_TIMEOUT_S}s (configured by " + f"VLLM_ENGINE_READY_TIMEOUT_S). To increase the " + f"timeout, set the environment variable: " + f"VLLM_ENGINE_READY_TIMEOUT_S=" ) identity, _ = sync_input_socket.recv_multipart() new_engine_identities.discard(identity) From 9e0f44bec449df17d30ed9abef7aeedc059ddfde Mon Sep 17 00:00:00 2001 From: Komal Kumar Teru <162363718+kkt-cohere@users.noreply.github.com> Date: Wed, 4 Mar 2026 12:50:15 +0530 Subject: [PATCH 51/53] [cohere][fix][spec-decode]: fix crash when allowed_token_ids is set without penalties (#35654) Signed-off-by: kkt-cohere --- vllm/v1/sample/rejection_sampler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 278d421eb910..d3e8573458b1 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -271,7 +271,7 @@ def apply_logits_processors( # Calculate indices of target logits. 
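# (Illustrative sketch, not part of this diff.) The hunk below repeats each
# request index once per draft token, so the index range must be sized from the
# same per-request draft-token list (metadata.num_draft_tokens) rather than from
# output_token_ids. With illustrative counts:
import torch

num_draft_tokens = [2, 0, 3]                      # three requests
counts = torch.tensor(num_draft_tokens, device="cpu")
original_indices = torch.arange(len(num_draft_tokens), device="cpu")
repeat_indices = original_indices.repeat_interleave(counts)
print(repeat_indices)   # tensor([0, 0, 2, 2, 2]) -- request index per draft token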
if sampling_metadata.allowed_token_ids_mask is not None or has_penalties: - num_requests = len(sampling_metadata.output_token_ids) + num_requests = len(metadata.num_draft_tokens) num_draft_tokens = torch.tensor(metadata.num_draft_tokens, device="cpu") original_indices = torch.arange(num_requests, device="cpu") repeat_indices_cpu = original_indices.repeat_interleave(num_draft_tokens) From 5d199ac8f25a56495c24dcd8e6a63843002bba40 Mon Sep 17 00:00:00 2001 From: Andrii Skliar Date: Wed, 4 Mar 2026 08:20:33 +0100 Subject: [PATCH 52/53] Support Audio Extraction from MP4 Video for Nemotron Nano VL (#35539) Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com> Signed-off-by: Andrii Skliar Signed-off-by: Lucas Wilkinson Signed-off-by: Matthew Bonanni Signed-off-by: Lucas Wilkinson Signed-off-by: wangxiyuan Signed-off-by: Andrii Co-authored-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com> Co-authored-by: Andrii Skliar Co-authored-by: Andrii Co-authored-by: root Co-authored-by: Roger Wang Co-authored-by: root Co-authored-by: Lucas Wilkinson Co-authored-by: Matthew Bonanni Co-authored-by: Tyler Michael Smith Co-authored-by: wangxiyuan Co-authored-by: root --- setup.py | 1 + vllm/model_executor/models/config.py | 10 ++ .../model_executor/models/nano_nemotron_vl.py | 130 +++++++++++++++++- vllm/multimodal/media/audio.py | 58 ++++++++ vllm/multimodal/video.py | 27 ++++ 5 files changed, 225 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 556a511a3429..f31b4cf24f7e 100644 --- a/setup.py +++ b/setup.py @@ -1056,6 +1056,7 @@ def _read_requirements(filename: str) -> list[str]: "scipy", "soundfile", "mistral_common[audio]", + "av", ], # Required for audio processing "video": [], # Kept for backwards compatibility "flashinfer": [], # Kept for backwards compatibility diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index ef241d545c8c..ec03d283fed1 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -622,6 +622,15 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None: cache_config.mamba_ssm_cache_dtype = mamba_ssm_cache_dtype +class NemotronHNanoVLV2Config(VerifyAndUpdateConfig): + @staticmethod + def verify_and_update_model_config(model_config: "ModelConfig") -> None: + mm_config = model_config.multimodal_config + if mm_config is not None: + video_kwargs = mm_config.media_io_kwargs.setdefault("video", {}) + video_kwargs.setdefault("video_backend", "nemotron_vl") + + class Qwen3_5ForConditionalGenerationConfig(VerifyAndUpdateConfig): @staticmethod def verify_and_update_config(vllm_config: "VllmConfig") -> None: @@ -661,6 +670,7 @@ def verify_and_update_model_config(model_config: "ModelConfig") -> None: "GteNewModel": GteNewModelConfig, "GteNewForSequenceClassification": GteNewModelConfig, "Gemma3TextModel": Gemma3TextModelConfig, + "NemotronH_Nano_VL_V2": NemotronHNanoVLV2Config, "LlamaBidirectionalForSequenceClassification": LlamaBidirectionalConfig, "LlamaBidirectionalModel": LlamaBidirectionalConfig, "LlamaNemotronVLModel": LlamaNemotronVLConfig, diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py index 82422e89f0b3..9b9beadc099e 100644 --- a/vllm/model_executor/models/nano_nemotron_vl.py +++ b/vllm/model_executor/models/nano_nemotron_vl.py @@ -59,9 +59,11 @@ AudioItem, MultiModalDataDict, MultiModalFieldConfig, + MultiModalInputs, MultiModalKwargsItems, VideoItem, ) +from 
vllm.multimodal.media.audio import extract_audio_from_video_bytes from vllm.multimodal.parse import ( AudioProcessorItems, ImageEmbeddingItems, @@ -69,8 +71,13 @@ ImageSize, MultiModalDataItems, MultiModalDataParser, + VideoProcessorItems, +) +from vllm.multimodal.processing import ( + BaseDummyInputsBuilder, + ProcessorInputs, + TimingContext, ) -from vllm.multimodal.processing import BaseDummyInputsBuilder from vllm.multimodal.processing.processor import ( BaseMultiModalProcessor, BaseProcessingInfo, @@ -1381,6 +1388,127 @@ class NanoNemotronVLMultiModalProcessor( ): """MultiModalProcessor extended for video support""" + def _extract_audio_from_videos( + self, + mm_items: MultiModalDataItems, + ) -> tuple[MultiModalDataItems, list[AudioItem]]: + """Extract audio tracks from video bytes in *mm_items*. + + Returns: + The augmented *mm_items* (with audio added) and the list of + extracted audio items. + """ + videos = mm_items.get_items("video", VideoProcessorItems) + assert isinstance(videos.metadata, list) + metadata_list = videos.metadata + + audio_items: list[AudioItem] = [] + for metadata in metadata_list: + video_bytes = metadata.get("original_video_bytes") + if video_bytes is None or len(video_bytes) == 0: + raise ValueError( + "Cannot extract audio from video: original_video_bytes is " + "missing or empty. When using use_audio_in_video=True, " + "video must be loaded with keep_video_bytes=True (e.g. via " + "the chat API with a model that sets use_audio_in_video)." + ) + audio_items.append(extract_audio_from_video_bytes(video_bytes)) + + # Create a new VideoProcessorItems with metadata that does not contain + # the large video bytes, to avoid modifying the input `mm_items`. + new_metadata_list = [ + {k: v for k, v in meta.items() if k != "original_video_bytes"} + for meta in metadata_list + ] + new_videos = VideoProcessorItems(data=videos.data, metadata=new_metadata_list) + + audio_parsed = self.data_parser.parse_mm_data({"audio": audio_items}) + + # Create a new MultiModalDataItems with the new video and audio items. + new_mm_items_dict = {**mm_items, **audio_parsed, "video": new_videos} + mm_items = MultiModalDataItems(new_mm_items_dict) + + return mm_items, audio_items + + def apply( + self, + processor_inputs: ProcessorInputs, + timing_ctx: TimingContext | None = None, + ) -> MultiModalInputs: + if (hf_processor_mm_kwargs := processor_inputs.hf_processor_mm_kwargs) is None: + hf_processor_mm_kwargs = {} + + use_audio_in_video = bool( + hf_processor_mm_kwargs.get("use_audio_in_video", False) + ) + + hf_processor_mm_kwargs = { + k: v for k, v in hf_processor_mm_kwargs.items() if k != "use_audio_in_video" + } + + processor_inputs.hf_processor_mm_kwargs = hf_processor_mm_kwargs + + if not ( + use_audio_in_video + and "video" in processor_inputs.mm_data_items + and "audio" not in processor_inputs.mm_data_items + ): + return super().apply( + processor_inputs, + timing_ctx, + ) + + mm_items, audio_items = self._extract_audio_from_videos( + processor_inputs.mm_data_items + ) + processor_inputs.mm_data_items = mm_items + + prompt = processor_inputs.prompt + tokenizer = self.info.get_tokenizer() + if not isinstance(prompt, str): + prompt = tokenizer.decode(prompt, skip_special_tokens=False) + + for _ in audio_items: + prompt = prompt.replace("