From 0cfb2e24735e6cc47abc31d73833754bf0dda7eb Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Tue, 8 Apr 2025 13:50:52 -0700 Subject: [PATCH 1/6] feat: add batch inference API to llama stack inference --- docs/_static/llama-stack-spec.html | 297 +++++++++++++----- docs/_static/llama-stack-spec.yaml | 256 ++++++++++----- .../apis/batch_inference/batch_inference.py | 33 +- llama_stack/apis/inference/inference.py | 34 ++ llama_stack/distribution/resolver.py | 3 + llama_stack/distribution/routers/routers.py | 40 +++ llama_stack/models/llama/llama3/generation.py | 9 +- llama_stack/models/llama/llama4/generation.py | 2 +- .../inline/inference/meta_reference/config.py | 5 +- .../inference/meta_reference/generators.py | 129 ++------ .../inference/meta_reference/inference.py | 292 ++++++++++++----- .../meta_reference/model_parallel.py | 26 +- .../meta_reference/parallel_utils.py | 8 +- .../sentence_transformers.py | 23 ++ .../providers/registry/batch_inference.py | 39 +++ .../remote/inference/ollama/ollama.py | 22 ++ .../providers/remote/inference/vllm/vllm.py | 22 ++ .../utils/inference/litellm_openai_mixin.py | 22 ++ llama_stack/schema_utils.py | 3 + .../meta-reference-gpu/run-with-safety.yaml | 6 +- .../templates/meta-reference-gpu/run.yaml | 3 +- .../inference/test_batch_inference.py | 111 +++++++ .../test_cases/inference/chat_completion.json | 26 ++ .../test_cases/inference/completion.json | 13 + 24 files changed, 1044 insertions(+), 380 deletions(-) create mode 100644 llama_stack/providers/registry/batch_inference.py create mode 100644 tests/integration/inference/test_batch_inference.py diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 36bfad49e2..ff7f492e7f 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -85,7 +85,7 @@ } } }, - "/v1/batch-inference/chat-completion": { + "/v1/inference/batch-chat-completion": { "post": { "responses": { "200": { @@ -112,7 +112,7 @@ } }, "tags": [ - "BatchInference (Coming Soon)" + "Inference" ], "description": "", "parameters": [], @@ -128,7 +128,7 @@ } } }, - "/v1/batch-inference/completion": { + "/v1/batch-inference/chat-completion-inline": { "post": { "responses": { "200": { @@ -136,7 +136,7 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/BatchCompletionResponse" + "$ref": "#/components/schemas/BatchChatCompletionResponse" } } } @@ -159,6 +159,49 @@ ], "description": "", "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/BatchChatCompletionInlineRequest" + } + } + }, + "required": true + } + } + }, + "/v1/inference/batch-completion": { + "post": { + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/BatchCompletionResponse" + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest400" + }, + "429": { + "$ref": "#/components/responses/TooManyRequests429" + }, + "500": { + "$ref": "#/components/responses/InternalServerError500" + }, + "default": { + "$ref": "#/components/responses/DefaultError" + } + }, + "tags": [ + "Inference" + ], + "description": "", + "parameters": [], "requestBody": { "content": { "application/json": { @@ -171,6 +214,49 @@ } } }, + "/v1/batch-inference/completion-inline": { + "post": { + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/BatchCompletionResponse" + } + 
} + } + }, + "400": { + "$ref": "#/components/responses/BadRequest400" + }, + "429": { + "$ref": "#/components/responses/TooManyRequests429" + }, + "500": { + "$ref": "#/components/responses/InternalServerError500" + }, + "default": { + "$ref": "#/components/responses/DefaultError" + } + }, + "tags": [ + "BatchInference (Coming Soon)" + ], + "description": "", + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/BatchCompletionInlineRequest" + } + } + }, + "required": true + } + } + }, "/v1/post-training/job/cancel": { "post": { "responses": { @@ -4366,6 +4452,51 @@ ], "title": "ToolCall" }, + "ToolConfig": { + "type": "object", + "properties": { + "tool_choice": { + "oneOf": [ + { + "type": "string", + "enum": [ + "auto", + "required", + "none" + ], + "title": "ToolChoice", + "description": "Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model." + }, + { + "type": "string" + } + ], + "default": "auto", + "description": "(Optional) Whether tool use is automatic, required, or none. Can also specify a tool name to use a specific tool. Defaults to ToolChoice.auto." + }, + "tool_prompt_format": { + "type": "string", + "enum": [ + "json", + "function_tag", + "python_list" + ], + "description": "(Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model. - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object. - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a tag. - `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls." + }, + "system_message_behavior": { + "type": "string", + "enum": [ + "append", + "replace" + ], + "description": "(Optional) Config for how to override the default system prompt. - `SystemMessageBehavior.append`: Appends the provided system message to the default system prompt. - `SystemMessageBehavior.replace`: Replaces the default system prompt with the provided system message. The system message can include the string '{{function_definitions}}' to indicate where the function definitions should be inserted.", + "default": "append" + } + }, + "additionalProperties": false, + "title": "ToolConfig", + "description": "Configuration for tool use." + }, "ToolDefinition": { "type": "object", "properties": { @@ -4554,7 +4685,7 @@ "BatchChatCompletionRequest": { "type": "object", "properties": { - "model": { + "model_id": { "type": "string" }, "messages_batch": { @@ -4575,25 +4706,8 @@ "$ref": "#/components/schemas/ToolDefinition" } }, - "tool_choice": { - "type": "string", - "enum": [ - "auto", - "required", - "none" - ], - "title": "ToolChoice", - "description": "Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model." - }, - "tool_prompt_format": { - "type": "string", - "enum": [ - "json", - "function_tag", - "python_list" - ], - "title": "ToolPromptFormat", - "description": "Prompt format for calling custom / zero shot tools." 
+ "tool_config": { + "$ref": "#/components/schemas/ToolConfig" }, "response_format": { "$ref": "#/components/schemas/ResponseFormat" @@ -4613,7 +4727,7 @@ }, "additionalProperties": false, "required": [ - "model", + "model_id", "messages_batch" ], "title": "BatchChatCompletionRequest" @@ -4707,12 +4821,62 @@ "title": "TokenLogProbs", "description": "Log probabilities for generated tokens." }, - "BatchCompletionRequest": { + "BatchChatCompletionInlineRequest": { "type": "object", "properties": { "model": { "type": "string" }, + "messages_batch": { + "type": "array", + "items": { + "type": "array", + "items": { + "$ref": "#/components/schemas/Message" + } + } + }, + "sampling_params": { + "$ref": "#/components/schemas/SamplingParams" + }, + "tools": { + "type": "array", + "items": { + "$ref": "#/components/schemas/ToolDefinition" + } + }, + "tool_config": { + "$ref": "#/components/schemas/ToolConfig" + }, + "response_format": { + "$ref": "#/components/schemas/ResponseFormat" + }, + "logprobs": { + "type": "object", + "properties": { + "top_k": { + "type": "integer", + "default": 0, + "description": "How many tokens (for each position) to return log probabilities for." + } + }, + "additionalProperties": false, + "title": "LogProbConfig" + } + }, + "additionalProperties": false, + "required": [ + "model", + "messages_batch" + ], + "title": "BatchChatCompletionInlineRequest" + }, + "BatchCompletionRequest": { + "type": "object", + "properties": { + "model_id": { + "type": "string" + }, "content_batch": { "type": "array", "items": { @@ -4740,7 +4904,7 @@ }, "additionalProperties": false, "required": [ - "model", + "model_id", "content_batch" ], "title": "BatchCompletionRequest" @@ -4799,63 +4963,56 @@ "title": "CompletionResponse", "description": "Response from a completion request." }, - "CancelTrainingJobRequest": { + "BatchCompletionInlineRequest": { "type": "object", "properties": { - "job_uuid": { + "model": { "type": "string" + }, + "content_batch": { + "type": "array", + "items": { + "$ref": "#/components/schemas/InterleavedContent" + } + }, + "sampling_params": { + "$ref": "#/components/schemas/SamplingParams" + }, + "response_format": { + "$ref": "#/components/schemas/ResponseFormat" + }, + "logprobs": { + "type": "object", + "properties": { + "top_k": { + "type": "integer", + "default": 0, + "description": "How many tokens (for each position) to return log probabilities for." + } + }, + "additionalProperties": false, + "title": "LogProbConfig" } }, "additionalProperties": false, "required": [ - "job_uuid" + "model", + "content_batch" ], - "title": "CancelTrainingJobRequest" + "title": "BatchCompletionInlineRequest" }, - "ToolConfig": { + "CancelTrainingJobRequest": { "type": "object", "properties": { - "tool_choice": { - "oneOf": [ - { - "type": "string", - "enum": [ - "auto", - "required", - "none" - ], - "title": "ToolChoice", - "description": "Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model." - }, - { - "type": "string" - } - ], - "default": "auto", - "description": "(Optional) Whether tool use is automatic, required, or none. Can also specify a tool name to use a specific tool. Defaults to ToolChoice.auto." - }, - "tool_prompt_format": { - "type": "string", - "enum": [ - "json", - "function_tag", - "python_list" - ], - "description": "(Optional) Instructs the model how to format tool calls. 
By default, Llama Stack will attempt to use a format that is best adapted to the model. - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object. - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a tag. - `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls." - }, - "system_message_behavior": { - "type": "string", - "enum": [ - "append", - "replace" - ], - "description": "(Optional) Config for how to override the default system prompt. - `SystemMessageBehavior.append`: Appends the provided system message to the default system prompt. - `SystemMessageBehavior.replace`: Replaces the default system prompt with the provided system message. The system message can include the string '{{function_definitions}}' to indicate where the function definitions should be inserted.", - "default": "append" + "job_uuid": { + "type": "string" } }, "additionalProperties": false, - "title": "ToolConfig", - "description": "Configuration for tool use." + "required": [ + "job_uuid" + ], + "title": "CancelTrainingJobRequest" }, "ChatCompletionRequest": { "type": "object", diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 82faf450a0..279e240ee1 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -40,7 +40,7 @@ paths: schema: $ref: '#/components/schemas/AppendRowsRequest' required: true - /v1/batch-inference/chat-completion: + /v1/inference/batch-chat-completion: post: responses: '200': @@ -60,7 +60,7 @@ paths: default: $ref: '#/components/responses/DefaultError' tags: - - BatchInference (Coming Soon) + - Inference description: '' parameters: [] requestBody: @@ -69,7 +69,7 @@ paths: schema: $ref: '#/components/schemas/BatchChatCompletionRequest' required: true - /v1/batch-inference/completion: + /v1/batch-inference/chat-completion-inline: post: responses: '200': @@ -77,7 +77,7 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/BatchCompletionResponse' + $ref: '#/components/schemas/BatchChatCompletionResponse' '400': $ref: '#/components/responses/BadRequest400' '429': @@ -92,12 +92,70 @@ paths: - BatchInference (Coming Soon) description: '' parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/BatchChatCompletionInlineRequest' + required: true + /v1/inference/batch-completion: + post: + responses: + '200': + description: OK + content: + application/json: + schema: + $ref: '#/components/schemas/BatchCompletionResponse' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Inference + description: '' + parameters: [] requestBody: content: application/json: schema: $ref: '#/components/schemas/BatchCompletionRequest' required: true + /v1/batch-inference/completion-inline: + post: + responses: + '200': + description: OK + content: + application/json: + schema: + $ref: '#/components/schemas/BatchCompletionResponse' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - BatchInference (Coming Soon) + description: '' + parameters: [] + requestBody: + content: + application/json: + schema: + $ref: 
'#/components/schemas/BatchCompletionInlineRequest' + required: true /v1/post-training/job/cancel: post: responses: @@ -3009,6 +3067,54 @@ components: - tool_name - arguments title: ToolCall + ToolConfig: + type: object + properties: + tool_choice: + oneOf: + - type: string + enum: + - auto + - required + - none + title: ToolChoice + description: >- + Whether tool use is required or automatic. This is a hint to the model + which may not be followed. It depends on the Instruction Following + capabilities of the model. + - type: string + default: auto + description: >- + (Optional) Whether tool use is automatic, required, or none. Can also + specify a tool name to use a specific tool. Defaults to ToolChoice.auto. + tool_prompt_format: + type: string + enum: + - json + - function_tag + - python_list + description: >- + (Optional) Instructs the model how to format tool calls. By default, Llama + Stack will attempt to use a format that is best adapted to the model. + - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object. + - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a + tag. - `ToolPromptFormat.python_list`: The tool calls are output as Python + syntax -- a list of function calls. + system_message_behavior: + type: string + enum: + - append + - replace + description: >- + (Optional) Config for how to override the default system prompt. - `SystemMessageBehavior.append`: + Appends the provided system message to the default system prompt. - `SystemMessageBehavior.replace`: + Replaces the default system prompt with the provided system message. The + system message can include the string '{{function_definitions}}' to indicate + where the function definitions should be inserted. + default: append + additionalProperties: false + title: ToolConfig + description: Configuration for tool use. ToolDefinition: type: object properties: @@ -3145,7 +3251,7 @@ components: BatchChatCompletionRequest: type: object properties: - model: + model_id: type: string messages_batch: type: array @@ -3159,26 +3265,8 @@ components: type: array items: $ref: '#/components/schemas/ToolDefinition' - tool_choice: - type: string - enum: - - auto - - required - - none - title: ToolChoice - description: >- - Whether tool use is required or automatic. This is a hint to the model - which may not be followed. It depends on the Instruction Following capabilities - of the model. - tool_prompt_format: - type: string - enum: - - json - - function_tag - - python_list - title: ToolPromptFormat - description: >- - Prompt format for calling custom / zero shot tools. + tool_config: + $ref: '#/components/schemas/ToolConfig' response_format: $ref: '#/components/schemas/ResponseFormat' logprobs: @@ -3193,7 +3281,7 @@ components: title: LogProbConfig additionalProperties: false required: - - model + - model_id - messages_batch title: BatchChatCompletionRequest BatchChatCompletionResponse: @@ -3258,11 +3346,47 @@ components: - logprobs_by_token title: TokenLogProbs description: Log probabilities for generated tokens. 
- BatchCompletionRequest: + BatchChatCompletionInlineRequest: type: object properties: model: type: string + messages_batch: + type: array + items: + type: array + items: + $ref: '#/components/schemas/Message' + sampling_params: + $ref: '#/components/schemas/SamplingParams' + tools: + type: array + items: + $ref: '#/components/schemas/ToolDefinition' + tool_config: + $ref: '#/components/schemas/ToolConfig' + response_format: + $ref: '#/components/schemas/ResponseFormat' + logprobs: + type: object + properties: + top_k: + type: integer + default: 0 + description: >- + How many tokens (for each position) to return log probabilities for. + additionalProperties: false + title: LogProbConfig + additionalProperties: false + required: + - model + - messages_batch + title: BatchChatCompletionInlineRequest + BatchCompletionRequest: + type: object + properties: + model_id: + type: string content_batch: type: array items: @@ -3283,7 +3407,7 @@ components: title: LogProbConfig additionalProperties: false required: - - model + - model_id - content_batch title: BatchCompletionRequest BatchCompletionResponse: @@ -3326,63 +3450,43 @@ components: - stop_reason title: CompletionResponse description: Response from a completion request. - CancelTrainingJobRequest: + BatchCompletionInlineRequest: type: object properties: - job_uuid: + model: type: string + content_batch: + type: array + items: + $ref: '#/components/schemas/InterleavedContent' + sampling_params: + $ref: '#/components/schemas/SamplingParams' + response_format: + $ref: '#/components/schemas/ResponseFormat' + logprobs: + type: object + properties: + top_k: + type: integer + default: 0 + description: >- + How many tokens (for each position) to return log probabilities for. + additionalProperties: false + title: LogProbConfig additionalProperties: false required: - - job_uuid - title: CancelTrainingJobRequest - ToolConfig: + - model + - content_batch + title: BatchCompletionInlineRequest + CancelTrainingJobRequest: type: object properties: - tool_choice: - oneOf: - - type: string - enum: - - auto - - required - - none - title: ToolChoice - description: >- - Whether tool use is required or automatic. This is a hint to the model - which may not be followed. It depends on the Instruction Following - capabilities of the model. - - type: string - default: auto - description: >- - (Optional) Whether tool use is automatic, required, or none. Can also - specify a tool name to use a specific tool. Defaults to ToolChoice.auto. - tool_prompt_format: - type: string - enum: - - json - - function_tag - - python_list - description: >- - (Optional) Instructs the model how to format tool calls. By default, Llama - Stack will attempt to use a format that is best adapted to the model. - - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object. - - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a - tag. - `ToolPromptFormat.python_list`: The tool calls are output as Python - syntax -- a list of function calls. - system_message_behavior: + job_uuid: type: string - enum: - - append - - replace - description: >- - (Optional) Config for how to override the default system prompt. - `SystemMessageBehavior.append`: - Appends the provided system message to the default system prompt. - `SystemMessageBehavior.replace`: - Replaces the default system prompt with the provided system message. The - system message can include the string '{{function_definitions}}' to indicate - where the function definitions should be inserted. 
- default: append additionalProperties: false - title: ToolConfig - description: Configuration for tool use. + required: + - job_uuid + title: CancelTrainingJobRequest ChatCompletionRequest: type: object properties: diff --git a/llama_stack/apis/batch_inference/batch_inference.py b/llama_stack/apis/batch_inference/batch_inference.py index 330a683ba7..57fcd7ebb2 100644 --- a/llama_stack/apis/batch_inference/batch_inference.py +++ b/llama_stack/apis/batch_inference/batch_inference.py @@ -6,37 +6,24 @@ from typing import List, Optional, Protocol, runtime_checkable -from pydantic import BaseModel - from llama_stack.apis.inference import ( - ChatCompletionResponse, - CompletionResponse, + BatchChatCompletionResponse, + BatchCompletionResponse, InterleavedContent, LogProbConfig, Message, ResponseFormat, SamplingParams, - ToolChoice, + ToolConfig, ToolDefinition, - ToolPromptFormat, ) -from llama_stack.schema_utils import json_schema_type, webmethod - - -@json_schema_type -class BatchCompletionResponse(BaseModel): - batch: List[CompletionResponse] - - -@json_schema_type -class BatchChatCompletionResponse(BaseModel): - batch: List[ChatCompletionResponse] +from llama_stack.schema_utils import webmethod @runtime_checkable class BatchInference(Protocol): - @webmethod(route="/batch-inference/completion", method="POST") - async def batch_completion( + @webmethod(route="/batch-inference/completion-inline", method="POST") + async def batch_completion_inline( self, model: str, content_batch: List[InterleavedContent], @@ -45,16 +32,14 @@ async def batch_completion( logprobs: Optional[LogProbConfig] = None, ) -> BatchCompletionResponse: ... - @webmethod(route="/batch-inference/chat-completion", method="POST") - async def batch_chat_completion( + @webmethod(route="/batch-inference/chat-completion-inline", method="POST") + async def batch_chat_completion_inline( self, model: str, messages_batch: List[List[Message]], sampling_params: Optional[SamplingParams] = None, - # zero-shot tool definitions as input to the model tools: Optional[List[ToolDefinition]] = list, - tool_choice: Optional[ToolChoice] = ToolChoice.auto, - tool_prompt_format: Optional[ToolPromptFormat] = None, + tool_config: Optional[ToolConfig] = None, response_format: Optional[ResponseFormat] = None, logprobs: Optional[LogProbConfig] = None, ) -> BatchChatCompletionResponse: ... diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 3390a3fef6..21753ca236 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -681,6 +681,16 @@ class EmbeddingTaskType(Enum): document = "document" +@json_schema_type +class BatchCompletionResponse(BaseModel): + batch: List[CompletionResponse] + + +@json_schema_type +class BatchChatCompletionResponse(BaseModel): + batch: List[ChatCompletionResponse] + + @runtime_checkable @trace_protocol class Inference(Protocol): @@ -716,6 +726,17 @@ async def completion( """ ... 
+ @webmethod(route="/inference/batch-completion", method="POST", experimental=True) + async def batch_completion( + self, + model_id: str, + content_batch: List[InterleavedContent], + sampling_params: Optional[SamplingParams] = None, + response_format: Optional[ResponseFormat] = None, + logprobs: Optional[LogProbConfig] = None, + ) -> BatchCompletionResponse: + raise NotImplementedError("Batch completion is not implemented") + @webmethod(route="/inference/chat-completion", method="POST") async def chat_completion( self, @@ -756,6 +777,19 @@ async def chat_completion( """ ... + @webmethod(route="/inference/batch-chat-completion", method="POST", experimental=True) + async def batch_chat_completion( + self, + model_id: str, + messages_batch: List[List[Message]], + sampling_params: Optional[SamplingParams] = None, + tools: Optional[List[ToolDefinition]] = None, + tool_config: Optional[ToolConfig] = None, + response_format: Optional[ResponseFormat] = None, + logprobs: Optional[LogProbConfig] = None, + ) -> BatchChatCompletionResponse: + raise NotImplementedError("Batch chat completion is not implemented") + @webmethod(route="/inference/embeddings", method="POST") async def embeddings( self, diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index 33ad343ec8..177d20f2bb 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -400,6 +400,9 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None: mro = type(obj).__mro__ for name, value in inspect.getmembers(protocol): if inspect.isfunction(value) and hasattr(value, "__webmethod__"): + if value.__webmethod__.experimental: + continue + if not hasattr(obj, name): missing_methods.append((name, "missing")) elif not callable(getattr(obj, name)): diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index bc313036f6..b9623ef3ca 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -17,6 +17,8 @@ from llama_stack.apis.datasets import DatasetPurpose, DataSource from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job from llama_stack.apis.inference import ( + BatchChatCompletionResponse, + BatchCompletionResponse, ChatCompletionResponse, ChatCompletionResponseEventType, ChatCompletionResponseStreamChunk, @@ -334,6 +336,30 @@ async def stream_generator(): response.metrics = metrics if response.metrics is None else response.metrics + metrics return response + async def batch_chat_completion( + self, + model_id: str, + messages_batch: List[List[Message]], + tools: Optional[List[ToolDefinition]] = None, + tool_config: Optional[ToolConfig] = None, + sampling_params: Optional[SamplingParams] = None, + response_format: Optional[ResponseFormat] = None, + logprobs: Optional[LogProbConfig] = None, + ) -> BatchChatCompletionResponse: + logger.debug( + f"InferenceRouter.batch_chat_completion: {model_id=}, {len(messages_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}", + ) + provider = self.routing_table.get_provider_impl(model_id) + return await provider.batch_chat_completion( + model_id=model_id, + messages_batch=messages_batch, + tools=tools, + tool_config=tool_config, + sampling_params=sampling_params, + response_format=response_format, + logprobs=logprobs, + ) + async def completion( self, model_id: str, @@ -398,6 +424,20 @@ async def stream_generator(): response.metrics = metrics if response.metrics is None else response.metrics + metrics 
return response + async def batch_completion( + self, + model_id: str, + content_batch: List[InterleavedContent], + sampling_params: Optional[SamplingParams] = None, + response_format: Optional[ResponseFormat] = None, + logprobs: Optional[LogProbConfig] = None, + ) -> BatchCompletionResponse: + logger.debug( + f"InferenceRouter.batch_completion: {model_id=}, {len(content_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}", + ) + provider = self.routing_table.get_provider_impl(model_id) + return await provider.batch_completion(model_id, content_batch, sampling_params, response_format, logprobs) + async def embeddings( self, model_id: str, diff --git a/llama_stack/models/llama/llama3/generation.py b/llama_stack/models/llama/llama3/generation.py index 8c6aa242b9..98412a1d46 100644 --- a/llama_stack/models/llama/llama3/generation.py +++ b/llama_stack/models/llama/llama3/generation.py @@ -140,7 +140,12 @@ def build_model(): return Llama3(model, tokenizer, model_args) - def __init__(self, model: Transformer | CrossAttentionTransformer, tokenizer: Tokenizer, args: ModelArgs): + def __init__( + self, + model: Transformer | CrossAttentionTransformer, + tokenizer: Tokenizer, + args: ModelArgs, + ): self.args = args self.model = model self.tokenizer = tokenizer @@ -285,7 +290,7 @@ def generate( source="output", logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None), batch_idx=idx, - finished=eos_reached[idx], + finished=eos_reached[idx].item(), ignore_token=cur_pos < len(prompt_tokens[idx]), ) ) diff --git a/llama_stack/models/llama/llama4/generation.py b/llama_stack/models/llama/llama4/generation.py index 7a4087c8f7..8e94bb33a8 100644 --- a/llama_stack/models/llama/llama4/generation.py +++ b/llama_stack/models/llama/llama4/generation.py @@ -233,7 +233,7 @@ def generate( source="output", logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None), batch_idx=idx, - finished=eos_reached[idx], + finished=eos_reached[idx].item(), ignore_token=cur_pos < len(prompt_tokens[idx]), ) ) diff --git a/llama_stack/providers/inline/inference/meta_reference/config.py b/llama_stack/providers/inline/inference/meta_reference/config.py index 3156675067..6f796d0d4b 100644 --- a/llama_stack/providers/inline/inference/meta_reference/config.py +++ b/llama_stack/providers/inline/inference/meta_reference/config.py @@ -52,14 +52,17 @@ def sample_run_config( checkpoint_dir: str = "${env.CHECKPOINT_DIR:null}", quantization_type: str = "${env.QUANTIZATION_TYPE:bf16}", model_parallel_size: str = "${env.MODEL_PARALLEL_SIZE:0}", + max_batch_size: str = "${env.MAX_BATCH_SIZE:1}", + max_seq_len: str = "${env.MAX_SEQ_LEN:4096}", **kwargs, ) -> Dict[str, Any]: return { "model": model, - "max_seq_len": 4096, "checkpoint_dir": checkpoint_dir, "quantization": { "type": quantization_type, }, "model_parallel_size": model_parallel_size, + "max_batch_size": max_batch_size, + "max_seq_len": max_seq_len, } diff --git a/llama_stack/providers/inline/inference/meta_reference/generators.py b/llama_stack/providers/inline/inference/meta_reference/generators.py index 34dd58a9a3..0a928ce734 100644 --- a/llama_stack/providers/inline/inference/meta_reference/generators.py +++ b/llama_stack/providers/inline/inference/meta_reference/generators.py @@ -22,7 +22,7 @@ from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer from llama_stack.models.llama.llama4.generation import Llama4 from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer 
-from llama_stack.models.llama.sku_types import Model +from llama_stack.models.llama.sku_types import Model, ModelFamily from llama_stack.providers.utils.inference.prompt_adapter import ( ChatCompletionRequestWithRawContent, CompletionRequestWithRawContent, @@ -113,8 +113,7 @@ def _infer_tool_prompt_format(request: ChatCompletionRequestWithRawContent): return get_default_tool_prompt_format(request.model) -# TODO: combine Llama3 and Llama4 generators since they are almost identical now -class Llama4Generator: +class LlamaGenerator: def __init__( self, config: MetaReferenceInferenceConfig, @@ -144,7 +143,8 @@ def __init__( else: quantization_mode = None - self.inner_generator = Llama4.build( + cls = Llama4 if llama_model.model_family == ModelFamily.llama4 else Llama3 + self.inner_generator = cls.build( ckpt_dir=ckpt_dir, max_seq_len=config.max_seq_len, max_batch_size=config.max_batch_size, @@ -158,142 +158,55 @@ def __init__( def completion( self, - request: CompletionRequestWithRawContent, + request_batch: List[CompletionRequestWithRawContent], ) -> Generator: - sampling_params = request.sampling_params or SamplingParams() + first_request = request_batch[0] + sampling_params = first_request.sampling_params or SamplingParams() max_gen_len = sampling_params.max_tokens if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len: max_gen_len = self.args.max_seq_len - 1 temperature, top_p = _infer_sampling_params(sampling_params) for result in self.inner_generator.generate( - llm_inputs=[self.formatter.encode_content(request.content)], + llm_inputs=[self.formatter.encode_content(request.content) for request in request_batch], max_gen_len=max_gen_len, temperature=temperature, top_p=top_p, - logprobs=bool(request.logprobs), + logprobs=bool(first_request.logprobs), echo=False, logits_processor=get_logits_processor( self.tokenizer, self.args.vocab_size, - request.response_format, + first_request.response_format, ), ): - yield result[0] + yield result def chat_completion( self, - request: ChatCompletionRequestWithRawContent, + request_batch: List[ChatCompletionRequestWithRawContent], ) -> Generator: - sampling_params = request.sampling_params or SamplingParams() + first_request = request_batch[0] + sampling_params = first_request.sampling_params or SamplingParams() max_gen_len = sampling_params.max_tokens if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len: max_gen_len = self.args.max_seq_len - 1 temperature, top_p = _infer_sampling_params(sampling_params) for result in self.inner_generator.generate( - llm_inputs=[self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))], + llm_inputs=[ + self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request)) + for request in request_batch + ], max_gen_len=max_gen_len, temperature=temperature, top_p=top_p, - logprobs=bool(request.logprobs), + logprobs=bool(first_request.logprobs), echo=False, logits_processor=get_logits_processor( self.tokenizer, self.args.vocab_size, - request.response_format, + first_request.response_format, ), ): - yield result[0] - - -class Llama3Generator: - def __init__( - self, - config: MetaReferenceInferenceConfig, - model_id: str, - llama_model: Model, - ): - if config.checkpoint_dir and config.checkpoint_dir != "null": - ckpt_dir = config.checkpoint_dir - else: - resolved_model = resolve_model(model_id) - if resolved_model is None: - # if the model is not a native llama model, get the default checkpoint_dir based on 
model id - ckpt_dir = model_checkpoint_dir(model_id) - else: - # if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value - ckpt_dir = model_checkpoint_dir(resolved_model.descriptor()) - - if config.quantization: - if config.quantization.type == "fp8_mixed": - quantization_mode = QuantizationMode.fp8_mixed - elif config.quantization.type == "int4_mixed": - quantization_mode = QuantizationMode.int4_mixed - elif config.quantization.type == "bf16": - quantization_mode = None - else: - raise ValueError(f"Unsupported quantization mode {config.quantization}") - else: - quantization_mode = None - - self.inner_generator = Llama3.build( - ckpt_dir=ckpt_dir, - max_seq_len=config.max_seq_len, - max_batch_size=config.max_batch_size, - world_size=config.model_parallel_size or llama_model.pth_file_count, - quantization_mode=quantization_mode, - ) - self.tokenizer = self.inner_generator.tokenizer - self.args = self.inner_generator.args - self.formatter = self.inner_generator.formatter - - def completion( - self, - request: CompletionRequestWithRawContent, - ) -> Generator: - sampling_params = request.sampling_params or SamplingParams() - max_gen_len = sampling_params.max_tokens - if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len: - max_gen_len = self.args.max_seq_len - 1 - - temperature, top_p = _infer_sampling_params(sampling_params) - for result in self.inner_generator.generate( - model_inputs=[self.formatter.encode_content(request.content)], - max_gen_len=max_gen_len, - temperature=temperature, - top_p=top_p, - logprobs=bool(request.logprobs), - echo=False, - logits_processor=get_logits_processor( - self.tokenizer, - self.args.vocab_size, - request.response_format, - ), - ): - yield result[0] - - def chat_completion( - self, - request: ChatCompletionRequestWithRawContent, - ) -> Generator: - sampling_params = request.sampling_params or SamplingParams() - max_gen_len = sampling_params.max_tokens - if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len: - max_gen_len = self.args.max_seq_len - 1 - - temperature, top_p = _infer_sampling_params(sampling_params) - for result in self.inner_generator.generate( - model_inputs=[self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))], - max_gen_len=max_gen_len, - temperature=temperature, - top_p=top_p, - logprobs=bool(request.logprobs), - echo=False, - logits_processor=get_logits_processor( - self.tokenizer, - self.args.vocab_size, - request.response_format, - ), - ): - yield result[0] + yield result diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py index 3a7632065b..faf19a9c62 100644 --- a/llama_stack/providers/inline/inference/meta_reference/inference.py +++ b/llama_stack/providers/inline/inference/meta_reference/inference.py @@ -9,6 +9,7 @@ import os from typing import AsyncGenerator, List, Optional, Union +from pydantic import BaseModel from termcolor import cprint from llama_stack.apis.common.content_types import ( @@ -17,6 +18,8 @@ ToolCallParseStatus, ) from llama_stack.apis.inference import ( + BatchChatCompletionResponse, + BatchCompletionResponse, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseEvent, @@ -38,6 +41,7 @@ ToolConfig, ToolDefinition, ToolPromptFormat, + UserMessage, ) from llama_stack.apis.models import Model, ModelType from llama_stack.models.llama.llama3.chat_format import 
ChatFormat as Llama3ChatFormat @@ -65,7 +69,7 @@ ) from .config import MetaReferenceInferenceConfig -from .generators import Llama3Generator, Llama4Generator +from .generators import LlamaGenerator from .model_parallel import LlamaModelParallelGenerator log = logging.getLogger(__name__) @@ -74,12 +78,8 @@ SEMAPHORE = asyncio.Semaphore(1) -def llama3_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> Llama3Generator: - return Llama3Generator(config, model_id, llama_model) - - -def llama4_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> Llama4Generator: - return Llama4Generator(config, model_id, llama_model) +def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> LlamaGenerator: + return LlamaGenerator(config, model_id, llama_model) class MetaReferenceInferenceImpl( @@ -139,24 +139,12 @@ async def register_model(self, model: Model) -> Model: async def load_model(self, model_id, llama_model) -> None: log.info(f"Loading model `{model_id}`") - if llama_model.model_family in { - ModelFamily.llama3, - ModelFamily.llama3_1, - ModelFamily.llama3_2, - ModelFamily.llama3_3, - }: - builder_fn = llama3_builder_fn - elif llama_model.model_family == ModelFamily.llama4: - builder_fn = llama4_builder_fn - else: - raise ValueError(f"Unsupported model family: {llama_model.model_family}") - builder_params = [self.config, model_id, llama_model] if self.config.create_distributed_process_group: self.generator = LlamaModelParallelGenerator( model_parallel_size=self.config.model_parallel_size or llama_model.pth_file_count, - builder_fn=builder_fn, + builder_fn=llama_builder_fn, builder_params=builder_params, formatter=( Llama4ChatFormat(Llama4Tokenizer.get_instance()) @@ -166,11 +154,24 @@ async def load_model(self, model_id, llama_model) -> None: ) self.generator.start() else: - self.generator = builder_fn(*builder_params) + self.generator = llama_builder_fn(*builder_params) self.model_id = model_id self.llama_model = llama_model + print("Warming up...") + await self.completion( + model_id=model_id, + content="Hello, world!", + sampling_params=SamplingParams(max_tokens=10), + ) + await self.chat_completion( + model_id=model_id, + messages=[UserMessage(content="Hi how are you?")], + sampling_params=SamplingParams(max_tokens=20), + ) + print("Warmed up!") + def check_model(self, request) -> None: if self.model_id is None or self.llama_model is None: raise RuntimeError( @@ -208,7 +209,43 @@ async def completion( if request.stream: return self._stream_completion(request) else: - return await self._nonstream_completion(request) + results = await self._nonstream_completion([request]) + return results[0] + + async def batch_completion( + self, + model_id: str, + content_batch: List[InterleavedContent], + sampling_params: Optional[SamplingParams] = None, + response_format: Optional[ResponseFormat] = None, + stream: Optional[bool] = False, + logprobs: Optional[LogProbConfig] = None, + ) -> BatchCompletionResponse: + if sampling_params is None: + sampling_params = SamplingParams() + if logprobs: + assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}" + + content_batch = [ + augment_content_with_response_format_prompt(response_format, content) for content in content_batch + ] + + request_batch = [] + for content in content_batch: + request = CompletionRequest( + model=model_id, + content=content, + sampling_params=sampling_params, + response_format=response_format, + stream=stream, + 
logprobs=logprobs, + ) + self.check_model(request) + request = await convert_request_to_raw(request) + request_batch.append(request) + + results = await self._nonstream_completion(request_batch) + return BatchCompletionResponse(batch=results) async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator: tokenizer = self.generator.formatter.tokenizer @@ -253,37 +290,54 @@ def impl(): for x in impl(): yield x - async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse: + async def _nonstream_completion(self, request_batch: List[CompletionRequest]) -> List[CompletionResponse]: tokenizer = self.generator.formatter.tokenizer - def impl(): - tokens = [] - logprobs = [] - stop_reason = None + first_request = request_batch[0] - for token_result in self.generator.completion(request): - tokens.append(token_result.token) - if token_result.token == tokenizer.eot_id: - stop_reason = StopReason.end_of_turn - elif token_result.token == tokenizer.eom_id: - stop_reason = StopReason.end_of_message + class ItemState(BaseModel): + tokens: List[int] = [] + logprobs: List[TokenLogProbs] = [] + stop_reason: StopReason | None = None + finished: bool = False - if request.logprobs: - assert len(token_result.logprobs) == 1 - - logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]})) - - if stop_reason is None: - stop_reason = StopReason.out_of_tokens + def impl(): + states = [ItemState() for _ in request_batch] + + results = [] + for token_results in self.generator.completion(request_batch): + for result in token_results: + idx = result.batch_idx + state = states[idx] + if state.finished or result.ignore_token: + continue + + state.finished = result.finished + if first_request.logprobs: + state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]})) + + state.tokens.append(result.token) + if result.token == tokenizer.eot_id: + state.stop_reason = StopReason.end_of_turn + elif result.token == tokenizer.eom_id: + state.stop_reason = StopReason.end_of_message + + for state in states: + if state.stop_reason is None: + state.stop_reason = StopReason.out_of_tokens + + if state.tokens[-1] in self.generator.formatter.tokenizer.stop_tokens: + state.tokens = state.tokens[:-1] + content = self.generator.formatter.tokenizer.decode(state.tokens) + results.append( + CompletionResponse( + content=content, + stop_reason=state.stop_reason, + logprobs=state.logprobs if first_request.logprobs else None, + ) + ) - if tokens[-1] in self.generator.formatter.tokenizer.stop_tokens: - tokens = tokens[:-1] - content = self.generator.formatter.tokenizer.decode(tokens) - return CompletionResponse( - content=content, - stop_reason=stop_reason, - logprobs=logprobs if request.logprobs else None, - ) + return results if self.config.create_distributed_process_group: async with SEMAPHORE: @@ -318,7 +372,7 @@ async def chat_completion( response_format=response_format, stream=stream, logprobs=logprobs, - tool_config=tool_config, + tool_config=tool_config or ToolConfig(), ) self.check_model(request) @@ -334,44 +388,110 @@ async def chat_completion( if request.stream: return self._stream_chat_completion(request) else: - return await self._nonstream_chat_completion(request) + results = await self._nonstream_chat_completion([request]) + return results[0] - async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: - tokenizer = self.generator.formatter.tokenizer + async def batch_chat_completion( + 
self, + model_id: str, + messages_batch: List[List[Message]], + sampling_params: Optional[SamplingParams] = None, + response_format: Optional[ResponseFormat] = None, + tools: Optional[List[ToolDefinition]] = None, + stream: Optional[bool] = False, + logprobs: Optional[LogProbConfig] = None, + tool_config: Optional[ToolConfig] = None, + ) -> BatchChatCompletionResponse: + if sampling_params is None: + sampling_params = SamplingParams() + if logprobs: + assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}" - def impl(): - tokens = [] - logprobs = [] - stop_reason = None + # wrapper request to make it easier to pass around (internal only, not exposed to API) + request_batch = [] + for messages in messages_batch: + request = ChatCompletionRequest( + model=model_id, + messages=messages, + sampling_params=sampling_params, + tools=tools or [], + response_format=response_format, + logprobs=logprobs, + tool_config=tool_config or ToolConfig(), + ) + self.check_model(request) - for token_result in self.generator.chat_completion(request): - if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1": - cprint(token_result.text, "cyan", end="") + # augment and rewrite messages depending on the model + request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value) + # download media and convert to raw content so we can send it to the model + request = await convert_request_to_raw(request) + request_batch.append(request) - tokens.append(token_result.token) + if self.config.create_distributed_process_group: + if SEMAPHORE.locked(): + raise RuntimeError("Only one concurrent request is supported") - if token_result.token == tokenizer.eot_id: - stop_reason = StopReason.end_of_turn - elif token_result.token == tokenizer.eom_id: - stop_reason = StopReason.end_of_message + results = await self._nonstream_chat_completion(request_batch) + return BatchChatCompletionResponse(batch=results) - if request.logprobs: - assert len(token_result.logprobs) == 1 + async def _nonstream_chat_completion( + self, request_batch: List[ChatCompletionRequest] + ) -> List[ChatCompletionResponse]: + tokenizer = self.generator.formatter.tokenizer - logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]})) + first_request = request_batch[0] - if stop_reason is None: - stop_reason = StopReason.out_of_tokens + class ItemState(BaseModel): + tokens: List[int] = [] + logprobs: List[TokenLogProbs] = [] + stop_reason: StopReason | None = None + finished: bool = False - raw_message = self.generator.formatter.decode_assistant_message(tokens, stop_reason) - return ChatCompletionResponse( - completion_message=CompletionMessage( - content=raw_message.content, - stop_reason=raw_message.stop_reason, - tool_calls=raw_message.tool_calls, - ), - logprobs=logprobs if request.logprobs else None, - ) + def impl(): + states = [ItemState() for _ in request_batch] + + for token_results in self.generator.chat_completion(request_batch): + first = token_results[0] + if not first.finished: + if os.environ.get("LLAMA_MODELS_DEBUG", "0") in ("1", "2"): + cprint(first.text, "cyan", end="") + if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2": + cprint(f"<{first.token}>", "magenta", end="") + + for result in token_results: + idx = result.batch_idx + state = states[idx] + if state.finished or result.ignore_token: + continue + + state.finished = result.finished + if first_request.logprobs: + state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]})) + + 
state.tokens.append(result.token) + if result.token == tokenizer.eot_id: + state.stop_reason = StopReason.end_of_turn + elif result.token == tokenizer.eom_id: + state.stop_reason = StopReason.end_of_message + + results = [] + for state in states: + if state.stop_reason is None: + state.stop_reason = StopReason.out_of_tokens + + raw_message = self.generator.formatter.decode_assistant_message(state.tokens, state.stop_reason) + results.append( + ChatCompletionResponse( + completion_message=CompletionMessage( + content=raw_message.content, + stop_reason=raw_message.stop_reason, + tool_calls=raw_message.tool_calls, + ), + logprobs=state.logprobs if first_request.logprobs else None, + ) + ) + + return results if self.config.create_distributed_process_group: async with SEMAPHORE: @@ -398,6 +518,22 @@ def impl(): for token_result in self.generator.chat_completion(request): if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1": cprint(token_result.text, "cyan", end="") + if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2": + cprint(f"<{token_result.token}>", "magenta", end="") + + if token_result.token == tokenizer.eot_id: + stop_reason = StopReason.end_of_turn + text = "" + elif token_result.token == tokenizer.eom_id: + stop_reason = StopReason.end_of_message + text = "" + else: + text = token_result.text + + if request.logprobs: + assert len(token_result.logprobs) == 1 + + logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]})) tokens.append(token_result.token) diff --git a/llama_stack/providers/inline/inference/meta_reference/model_parallel.py b/llama_stack/providers/inline/inference/meta_reference/model_parallel.py index bed3025a8e..50640c6d1e 100644 --- a/llama_stack/providers/inline/inference/meta_reference/model_parallel.py +++ b/llama_stack/providers/inline/inference/meta_reference/model_parallel.py @@ -6,7 +6,7 @@ from copy import deepcopy from functools import partial -from typing import Any, Callable, Generator +from typing import Any, Callable, Generator, List from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat @@ -23,13 +23,13 @@ def __init__(self, llama): self.llama = llama # the `task` object is the same that is sent to `ModelParallelProcessGroup.run_inference()` - def __call__(self, req: Any): - if isinstance(req, ChatCompletionRequestWithRawContent): - return self.llama.chat_completion(req) - elif isinstance(req, CompletionRequestWithRawContent): - return self.llama.completion(req) + def __call__(self, task: Any): + if task[0] == "chat_completion": + return self.llama.chat_completion(task[1]) + elif task[0] == "completion": + return self.llama.completion(task[1]) else: - raise ValueError(f"Unexpected task type {type(req)}") + raise ValueError(f"Unexpected task type {task[0]}") def init_model_cb( @@ -82,16 +82,16 @@ def __exit__(self, exc_type, exc_value, exc_traceback): def completion( self, - request: CompletionRequestWithRawContent, + request_batch: List[CompletionRequestWithRawContent], ) -> Generator: - req_obj = deepcopy(request) - gen = self.group.run_inference(req_obj) + req_obj = deepcopy(request_batch) + gen = self.group.run_inference(("completion", req_obj)) yield from gen def chat_completion( self, - request: ChatCompletionRequestWithRawContent, + request_batch: List[ChatCompletionRequestWithRawContent], ) -> Generator: - req_obj = deepcopy(request) - gen = self.group.run_inference(req_obj) + req_obj = 
deepcopy(request_batch) + gen = self.group.run_inference(("chat_completion", req_obj)) yield from gen diff --git a/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py index 74fc49d5e5..8752f06f3f 100644 --- a/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +++ b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py @@ -19,7 +19,7 @@ import time import uuid from enum import Enum -from typing import Callable, Generator, Literal, Optional, Union +from typing import Callable, Generator, List, Literal, Optional, Tuple, Union import torch import zmq @@ -69,12 +69,12 @@ class CancelSentinel(BaseModel): class TaskRequest(BaseModel): type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request - task: Union[CompletionRequestWithRawContent, ChatCompletionRequestWithRawContent] + task: Tuple[str, List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent]] class TaskResponse(BaseModel): type: Literal[ProcessingMessageName.task_response] = ProcessingMessageName.task_response - result: GenerationResult + result: List[GenerationResult] class ExceptionResponse(BaseModel): @@ -331,7 +331,7 @@ def stop(self): def run_inference( self, - req: Union[CompletionRequestWithRawContent, ChatCompletionRequestWithRawContent], + req: Tuple[str, List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent]], ) -> Generator: assert not self.running, "inference already running" diff --git a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py index 9c370b6c56..5bc20e3c24 100644 --- a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py +++ b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py @@ -10,6 +10,7 @@ from llama_stack.apis.inference import ( CompletionResponse, Inference, + InterleavedContent, LogProbConfig, Message, ResponseFormat, @@ -80,3 +81,25 @@ async def chat_completion( tool_config: Optional[ToolConfig] = None, ) -> AsyncGenerator: raise ValueError("Sentence transformers don't support chat completion") + + async def batch_completion( + self, + model_id: str, + content_batch: List[InterleavedContent], + sampling_params: Optional[SamplingParams] = None, + response_format: Optional[ResponseFormat] = None, + logprobs: Optional[LogProbConfig] = None, + ): + raise NotImplementedError("Batch completion is not supported for Sentence Transformers") + + async def batch_chat_completion( + self, + model_id: str, + messages_batch: List[List[Message]], + sampling_params: Optional[SamplingParams] = None, + tools: Optional[List[ToolDefinition]] = None, + tool_config: Optional[ToolConfig] = None, + response_format: Optional[ResponseFormat] = None, + logprobs: Optional[LogProbConfig] = None, + ): + raise NotImplementedError("Batch chat completion is not supported for Sentence Transformers") diff --git a/llama_stack/providers/registry/batch_inference.py b/llama_stack/providers/registry/batch_inference.py new file mode 100644 index 0000000000..07b2167d6b --- /dev/null +++ b/llama_stack/providers/registry/batch_inference.py @@ -0,0 +1,39 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import List
+
+from llama_stack.providers.datatypes import (
+    Api,
+    InlineProviderSpec,
+    ProviderSpec,
+)
+
+META_REFERENCE_DEPS = [
+    "accelerate",
+    "blobfile",
+    "fairscale",
+    "torch",
+    "torchvision",
+    "transformers",
+    "zmq",
+    "lm-format-enforcer",
+    "sentence-transformers",
+    "torchao==0.5.0",
+    "fbgemm-gpu-genai==1.1.2",
+]
+
+
+def available_providers() -> List[ProviderSpec]:
+    return [
+        InlineProviderSpec(
+            api=Api.inference,
+            provider_type="inline::meta-reference",
+            pip_packages=META_REFERENCE_DEPS,
+            module="llama_stack.providers.inline.batch_inference.meta_reference",
+            config_class="llama_stack.providers.inline.batch_inference.meta_reference.MetaReferenceInferenceConfig",
+        ),
+    ]
diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index b8671197ef..33b48af461 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -437,6 +437,28 @@ async def openai_chat_completion(
         }
         return await self.openai_client.chat.completions.create(**params)  # type: ignore
 
+    async def batch_completion(
+        self,
+        model_id: str,
+        content_batch: List[InterleavedContent],
+        sampling_params: Optional[SamplingParams] = None,
+        response_format: Optional[ResponseFormat] = None,
+        logprobs: Optional[LogProbConfig] = None,
+    ):
+        raise NotImplementedError("Batch completion is not supported for Ollama")
+
+    async def batch_chat_completion(
+        self,
+        model_id: str,
+        messages_batch: List[List[Message]],
+        sampling_params: Optional[SamplingParams] = None,
+        tools: Optional[List[ToolDefinition]] = None,
+        tool_config: Optional[ToolConfig] = None,
+        response_format: Optional[ResponseFormat] = None,
+        logprobs: Optional[LogProbConfig] = None,
+    ):
+        raise NotImplementedError("Batch chat completion is not supported for Ollama")
+
 
 async def convert_message_to_openai_dict_for_ollama(message: Message) -> List[dict]:
     async def _convert_content(content) -> dict:
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 79f92adce7..0044d2e752 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -526,3 +526,25 @@ async def openai_chat_completion(
             user=user,
         )
         return await self.client.chat.completions.create(**params)  # type: ignore
+
+    async def batch_completion(
+        self,
+        model_id: str,
+        content_batch: List[InterleavedContent],
+        sampling_params: Optional[SamplingParams] = None,
+        response_format: Optional[ResponseFormat] = None,
+        logprobs: Optional[LogProbConfig] = None,
+    ):
+        raise NotImplementedError("Batch completion is not supported for vLLM")
+
+    async def batch_chat_completion(
+        self,
+        model_id: str,
+        messages_batch: List[List[Message]],
+        sampling_params: Optional[SamplingParams] = None,
+        tools: Optional[List[ToolDefinition]] = None,
+        tool_config: Optional[ToolConfig] = None,
+        response_format: Optional[ResponseFormat] = None,
+        logprobs: Optional[LogProbConfig] = None,
+    ):
+        raise NotImplementedError("Batch chat completion is not supported for vLLM")
diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
index 2d2f0400ac..cd0f4ec676 100644
--- 
a/llama_stack/providers/utils/inference/litellm_openai_mixin.py +++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py @@ -347,3 +347,25 @@ async def openai_chat_completion( user=user, ) return litellm.completion(**params) + + async def batch_completion( + self, + model_id: str, + content_batch: List[InterleavedContent], + sampling_params: Optional[SamplingParams] = None, + response_format: Optional[ResponseFormat] = None, + logprobs: Optional[LogProbConfig] = None, + ): + raise NotImplementedError("Batch completion is not supported for OpenAI Compat") + + async def batch_chat_completion( + self, + model_id: str, + messages_batch: List[List[Message]], + sampling_params: Optional[SamplingParams] = None, + tools: Optional[List[ToolDefinition]] = None, + tool_config: Optional[ToolConfig] = None, + response_format: Optional[ResponseFormat] = None, + logprobs: Optional[LogProbConfig] = None, + ): + raise NotImplementedError("Batch chat completion is not supported for OpenAI Compat") diff --git a/llama_stack/schema_utils.py b/llama_stack/schema_utils.py index 8fd55add07..40d604c3c7 100644 --- a/llama_stack/schema_utils.py +++ b/llama_stack/schema_utils.py @@ -20,6 +20,7 @@ class WebMethod: raw_bytes_request_body: Optional[bool] = False # A descriptive name of the corresponding span created by tracing descriptive_name: Optional[str] = None + experimental: Optional[bool] = False T = TypeVar("T", bound=Callable[..., Any]) @@ -33,6 +34,7 @@ def webmethod( response_examples: Optional[List[Any]] = None, raw_bytes_request_body: Optional[bool] = False, descriptive_name: Optional[str] = None, + experimental: Optional[bool] = False, ) -> Callable[[T], T]: """ Decorator that supplies additional metadata to an endpoint operation function. @@ -52,6 +54,7 @@ def wrap(func: T) -> T: response_examples=response_examples, raw_bytes_request_body=raw_bytes_request_body, descriptive_name=descriptive_name, + experimental=experimental, ) return func diff --git a/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml b/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml index 9f97158f8b..63177ab095 100644 --- a/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml +++ b/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml @@ -16,11 +16,12 @@ providers: provider_type: inline::meta-reference config: model: ${env.INFERENCE_MODEL} - max_seq_len: 4096 checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null} quantization: type: ${env.QUANTIZATION_TYPE:bf16} model_parallel_size: ${env.MODEL_PARALLEL_SIZE:0} + max_batch_size: ${env.MAX_BATCH_SIZE:1} + max_seq_len: ${env.MAX_SEQ_LEN:4096} - provider_id: sentence-transformers provider_type: inline::sentence-transformers config: {} @@ -28,11 +29,12 @@ providers: provider_type: inline::meta-reference config: model: ${env.SAFETY_MODEL} - max_seq_len: 4096 checkpoint_dir: ${env.SAFETY_CHECKPOINT_DIR:null} quantization: type: ${env.QUANTIZATION_TYPE:bf16} model_parallel_size: ${env.MODEL_PARALLEL_SIZE:0} + max_batch_size: ${env.MAX_BATCH_SIZE:1} + max_seq_len: ${env.MAX_SEQ_LEN:4096} vector_io: - provider_id: faiss provider_type: inline::faiss diff --git a/llama_stack/templates/meta-reference-gpu/run.yaml b/llama_stack/templates/meta-reference-gpu/run.yaml index eda332123d..380d83060a 100644 --- a/llama_stack/templates/meta-reference-gpu/run.yaml +++ b/llama_stack/templates/meta-reference-gpu/run.yaml @@ -16,11 +16,12 @@ providers: provider_type: inline::meta-reference config: model: ${env.INFERENCE_MODEL} - max_seq_len: 4096 checkpoint_dir: 
${env.INFERENCE_CHECKPOINT_DIR:null} quantization: type: ${env.QUANTIZATION_TYPE:bf16} model_parallel_size: ${env.MODEL_PARALLEL_SIZE:0} + max_batch_size: ${env.MAX_BATCH_SIZE:1} + max_seq_len: ${env.MAX_SEQ_LEN:4096} - provider_id: sentence-transformers provider_type: inline::sentence-transformers config: {} diff --git a/tests/integration/inference/test_batch_inference.py b/tests/integration/inference/test_batch_inference.py new file mode 100644 index 0000000000..f2bbd06989 --- /dev/null +++ b/tests/integration/inference/test_batch_inference.py @@ -0,0 +1,111 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + + +import pytest + +from llama_stack.models.llama.sku_list import resolve_model + +from ..test_cases.test_case import TestCase + +PROVIDER_LOGPROBS_TOP_K = {"remote::together", "remote::fireworks", "remote::vllm"} + + +def skip_if_model_doesnt_support_completion(client_with_models, model_id): + models = {m.identifier: m for m in client_with_models.models.list()} + models.update({m.provider_resource_id: m for m in client_with_models.models.list()}) + provider_id = models[model_id].provider_id + providers = {p.provider_id: p for p in client_with_models.providers.list()} + provider = providers[provider_id] + if provider.provider_type in ( + "remote::openai", + "remote::anthropic", + "remote::gemini", + "remote::groq", + "remote::llama-openai-compat", + ): + pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support completion") + + +def get_llama_model(client_with_models, model_id): + models = {} + for m in client_with_models.models.list(): + models[m.identifier] = m + models[m.provider_resource_id] = m + + assert model_id in models, f"Model {model_id} not found" + + model = models[model_id] + ids = (model.identifier, model.provider_resource_id) + for mid in ids: + if resolve_model(mid): + return mid + + return model.metadata.get("llama_model", None) + + +def get_llama_tokenizer(): + from llama_models.llama3.api.chat_format import ChatFormat + from llama_models.llama3.api.tokenizer import Tokenizer + + tokenizer = Tokenizer.get_instance() + formatter = ChatFormat(tokenizer) + return tokenizer, formatter + + +@pytest.mark.parametrize( + "test_case", + [ + "inference:completion:batch_completion", + ], +) +def test_batch_completion_non_streaming(client_with_models, text_model_id, test_case): + skip_if_model_doesnt_support_completion(client_with_models, text_model_id) + tc = TestCase(test_case) + + content_batch = tc["contents"] + response = client_with_models.inference.batch_completion( + content_batch=content_batch, + model_id=text_model_id, + sampling_params={ + "max_tokens": 50, + }, + ) + assert len(response.batch) == len(content_batch) + for i, r in enumerate(response.batch): + print(f"response {i}: {r.content}") + assert len(r.content) > 10 + + +@pytest.mark.parametrize( + "test_case", + [ + "inference:chat_completion:batch_completion", + ], +) +def test_batch_chat_completion_non_streaming(client_with_models, text_model_id, test_case): + tc = TestCase(test_case) + qa_pairs = tc["qa_pairs"] + + message_batch = [ + [ + { + "role": "user", + "content": qa["question"], + } + ] + for qa in qa_pairs + ] + + response = client_with_models.inference.batch_chat_completion( + messages_batch=message_batch, + model_id=text_model_id, + ) + assert len(response.batch) == len(qa_pairs) + for i, r in 
enumerate(response.batch): + print(f"response {i}: {r.completion_message.content}") + assert len(r.completion_message.content) > 0 + assert qa_pairs[i]["answer"].lower() in r.completion_message.content.lower() diff --git a/tests/integration/test_cases/inference/chat_completion.json b/tests/integration/test_cases/inference/chat_completion.json index 01956bd59c..5663089fb7 100644 --- a/tests/integration/test_cases/inference/chat_completion.json +++ b/tests/integration/test_cases/inference/chat_completion.json @@ -537,5 +537,31 @@ } ] } + }, + "batch_completion": { + "data": { + "qa_pairs": [ + { + "question": "What is the capital of France?", + "answer": "Paris" + }, + { + "question": "Who wrote the book '1984'?", + "answer": "George Orwell" + }, + { + "question": "Which planet has rings around it with a name starting with the letter S?", + "answer": "Saturn" + }, + { + "question": "When did the first moon landing happen?", + "answer": "1969" + }, + { + "question": "What word says 'hello' in Spanish?", + "answer": "Hola" + } + ] + } } } diff --git a/tests/integration/test_cases/inference/completion.json b/tests/integration/test_cases/inference/completion.json index 06abbdc8b2..731ceddbca 100644 --- a/tests/integration/test_cases/inference/completion.json +++ b/tests/integration/test_cases/inference/completion.json @@ -44,5 +44,18 @@ "year_retired": "2003" } } + }, + "batch_completion": { + "data": { + "contents": [ + "Michael Jordan was born in ", + "Roses are red, violets are ", + "If you had a million dollars, what would you do with it? ", + "All you need is ", + "The capital of France is ", + "It is a good day to ", + "The answer to the universe is " + ] + } } } From 73d927850e1b7b9ad813b1eb8e39e84ab451fd12 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Fri, 11 Apr 2025 16:15:59 -0700 Subject: [PATCH 2/6] updates --- docs/_static/llama-stack-spec.html | 182 +----------------- docs/_static/llama-stack-spec.yaml | 137 ++----------- .../apis/batch_inference/batch_inference.py | 32 +-- .../inference/meta_reference/inference.py | 8 +- 4 files changed, 43 insertions(+), 316 deletions(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index ff7f492e7f..3e9539f41f 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -128,49 +128,6 @@ } } }, - "/v1/batch-inference/chat-completion-inline": { - "post": { - "responses": { - "200": { - "description": "OK", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/BatchChatCompletionResponse" - } - } - } - }, - "400": { - "$ref": "#/components/responses/BadRequest400" - }, - "429": { - "$ref": "#/components/responses/TooManyRequests429" - }, - "500": { - "$ref": "#/components/responses/InternalServerError500" - }, - "default": { - "$ref": "#/components/responses/DefaultError" - } - }, - "tags": [ - "BatchInference (Coming Soon)" - ], - "description": "", - "parameters": [], - "requestBody": { - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/BatchChatCompletionInlineRequest" - } - } - }, - "required": true - } - } - }, "/v1/inference/batch-completion": { "post": { "responses": { "200": { "description": "OK", "content": { "application/json": { "schema": { "$ref": "#/components/schemas/BatchCompletionResponse" } } } }, "400": { "$ref": "#/components/responses/BadRequest400" }, "429": { "$ref": "#/components/responses/TooManyRequests429" }, "500": { "$ref": "#/components/responses/InternalServerError500" }, "default": { "$ref": "#/components/responses/DefaultError" } }, "tags": [ "Inference" ], "description": "", "parameters": [], "requestBody": { "content": { "application/json": { "schema": { "$ref": "#/components/schemas/BatchCompletionRequest" } } }, "required": true } } }, - "/v1/batch-inference/completion-inline": { - "post": { - "responses": { - "200": { - "description": "OK", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/BatchCompletionResponse" - } - } - } - }, - "400": { - "$ref": "#/components/responses/BadRequest400" - }, - "429": {
"$ref": "#/components/responses/TooManyRequests429" - }, - "500": { - "$ref": "#/components/responses/InternalServerError500" - }, - "default": { - "$ref": "#/components/responses/DefaultError" - } - }, - "tags": [ - "BatchInference (Coming Soon)" - ], - "description": "", - "parameters": [], - "requestBody": { - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/BatchCompletionInlineRequest" - } - } - }, - "required": true - } - } - }, "/v1/post-training/job/cancel": { "post": { "responses": { @@ -325,7 +239,7 @@ } }, "tags": [ - "Inference" + "BatchInference (Coming Soon)" ], "description": "Generate a chat completion for the given messages using the specified model.", "parameters": [], @@ -373,7 +287,7 @@ } }, "tags": [ - "Inference" + "BatchInference (Coming Soon)" ], "description": "Generate a completion for the given content using the specified model.", "parameters": [], @@ -4821,56 +4735,6 @@ "title": "TokenLogProbs", "description": "Log probabilities for generated tokens." }, - "BatchChatCompletionInlineRequest": { - "type": "object", - "properties": { - "model": { - "type": "string" - }, - "messages_batch": { - "type": "array", - "items": { - "type": "array", - "items": { - "$ref": "#/components/schemas/Message" - } - } - }, - "sampling_params": { - "$ref": "#/components/schemas/SamplingParams" - }, - "tools": { - "type": "array", - "items": { - "$ref": "#/components/schemas/ToolDefinition" - } - }, - "tool_config": { - "$ref": "#/components/schemas/ToolConfig" - }, - "response_format": { - "$ref": "#/components/schemas/ResponseFormat" - }, - "logprobs": { - "type": "object", - "properties": { - "top_k": { - "type": "integer", - "default": 0, - "description": "How many tokens (for each position) to return log probabilities for." - } - }, - "additionalProperties": false, - "title": "LogProbConfig" - } - }, - "additionalProperties": false, - "required": [ - "model", - "messages_batch" - ], - "title": "BatchChatCompletionInlineRequest" - }, "BatchCompletionRequest": { "type": "object", "properties": { @@ -4963,44 +4827,6 @@ "title": "CompletionResponse", "description": "Response from a completion request." }, - "BatchCompletionInlineRequest": { - "type": "object", - "properties": { - "model": { - "type": "string" - }, - "content_batch": { - "type": "array", - "items": { - "$ref": "#/components/schemas/InterleavedContent" - } - }, - "sampling_params": { - "$ref": "#/components/schemas/SamplingParams" - }, - "response_format": { - "$ref": "#/components/schemas/ResponseFormat" - }, - "logprobs": { - "type": "object", - "properties": { - "top_k": { - "type": "integer", - "default": 0, - "description": "How many tokens (for each position) to return log probabilities for." - } - }, - "additionalProperties": false, - "title": "LogProbConfig" - } - }, - "additionalProperties": false, - "required": [ - "model", - "content_batch" - ], - "title": "BatchCompletionInlineRequest" - }, "CancelTrainingJobRequest": { "type": "object", "properties": { @@ -11331,7 +11157,9 @@ "x-displayName": "Agents API for creating and interacting with agentic systems." }, { - "name": "BatchInference (Coming Soon)" + "name": "BatchInference (Coming Soon)", + "description": "This is an asynchronous API. 
If the request is successful, the response will be a job which can be polled for completion.\n\nNOTE: This API is not yet implemented and is subject to change in concert with other asynchronous APIs\nincluding (post-training, evals, etc).", + "x-displayName": "Batch inference API for generating completions and chat completions." }, { "name": "Benchmarks" diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 279e240ee1..0e632fcde8 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -69,35 +69,6 @@ paths: schema: $ref: '#/components/schemas/BatchChatCompletionRequest' required: true - /v1/batch-inference/chat-completion-inline: - post: - responses: - '200': - description: OK - content: - application/json: - schema: - $ref: '#/components/schemas/BatchChatCompletionResponse' - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - BatchInference (Coming Soon) - description: '' - parameters: [] - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/BatchChatCompletionInlineRequest' - required: true /v1/inference/batch-completion: post: responses: @@ -127,35 +98,6 @@ paths: schema: $ref: '#/components/schemas/BatchCompletionRequest' required: true - /v1/batch-inference/completion-inline: - post: - responses: - '200': - description: OK - content: - application/json: - schema: - $ref: '#/components/schemas/BatchCompletionResponse' - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - BatchInference (Coming Soon) - description: '' - parameters: [] - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/BatchCompletionInlineRequest' - required: true /v1/post-training/job/cancel: post: responses: @@ -206,7 +148,7 @@ paths: default: $ref: '#/components/responses/DefaultError' tags: - - Inference + - BatchInference (Coming Soon) description: >- Generate a chat completion for the given messages using the specified model. parameters: [] @@ -241,7 +183,7 @@ paths: default: $ref: '#/components/responses/DefaultError' tags: - - Inference + - BatchInference (Coming Soon) description: >- Generate a completion for the given content using the specified model. parameters: [] @@ -3346,42 +3288,6 @@ components: - logprobs_by_token title: TokenLogProbs description: Log probabilities for generated tokens. - BatchChatCompletionInlineRequest: - type: object - properties: - model: - type: string - messages_batch: - type: array - items: - type: array - items: - $ref: '#/components/schemas/Message' - sampling_params: - $ref: '#/components/schemas/SamplingParams' - tools: - type: array - items: - $ref: '#/components/schemas/ToolDefinition' - tool_config: - $ref: '#/components/schemas/ToolConfig' - response_format: - $ref: '#/components/schemas/ResponseFormat' - logprobs: - type: object - properties: - top_k: - type: integer - default: 0 - description: >- - How many tokens (for each position) to return log probabilities for. 
- additionalProperties: false - title: LogProbConfig - additionalProperties: false - required: - - model - - messages_batch - title: BatchChatCompletionInlineRequest BatchCompletionRequest: type: object properties: @@ -3450,34 +3356,6 @@ components: - stop_reason title: CompletionResponse description: Response from a completion request. - BatchCompletionInlineRequest: - type: object - properties: - model: - type: string - content_batch: - type: array - items: - $ref: '#/components/schemas/InterleavedContent' - sampling_params: - $ref: '#/components/schemas/SamplingParams' - response_format: - $ref: '#/components/schemas/ResponseFormat' - logprobs: - type: object - properties: - top_k: - type: integer - default: 0 - description: >- - How many tokens (for each position) to return log probabilities for. - additionalProperties: false - title: LogProbConfig - additionalProperties: false - required: - - model - - content_batch - title: BatchCompletionInlineRequest CancelTrainingJobRequest: type: object properties: @@ -7737,6 +7615,17 @@ tags: x-displayName: >- Agents API for creating and interacting with agentic systems. - name: BatchInference (Coming Soon) + description: >- + This is an asynchronous API. If the request is successful, the response will + be a job which can be polled for completion. + + + NOTE: This API is not yet implemented and is subject to change in concert with + other asynchronous APIs + + including (post-training, evals, etc). + x-displayName: >- + Batch inference API for generating completions and chat completions. - name: Benchmarks - name: DatasetIO - name: Datasets diff --git a/llama_stack/apis/batch_inference/batch_inference.py b/llama_stack/apis/batch_inference/batch_inference.py index 57fcd7ebb2..7a324128dd 100644 --- a/llama_stack/apis/batch_inference/batch_inference.py +++ b/llama_stack/apis/batch_inference/batch_inference.py @@ -6,40 +6,50 @@ from typing import List, Optional, Protocol, runtime_checkable +from llama_stack.apis.common.job_types import Job from llama_stack.apis.inference import ( - BatchChatCompletionResponse, - BatchCompletionResponse, InterleavedContent, LogProbConfig, Message, ResponseFormat, SamplingParams, - ToolConfig, + ToolChoice, ToolDefinition, + ToolPromptFormat, ) from llama_stack.schema_utils import webmethod @runtime_checkable class BatchInference(Protocol): - @webmethod(route="/batch-inference/completion-inline", method="POST") - async def batch_completion_inline( + """Batch inference API for generating completions and chat completions. + + This is an asynchronous API. If the request is successful, the response will be a job which can be polled for completion. + + NOTE: This API is not yet implemented and is subject to change in concert with other asynchronous APIs + including (post-training, evals, etc). + """ + + @webmethod(route="/batch-inference/completion", method="POST") + async def completion( self, model: str, content_batch: List[InterleavedContent], sampling_params: Optional[SamplingParams] = None, response_format: Optional[ResponseFormat] = None, logprobs: Optional[LogProbConfig] = None, - ) -> BatchCompletionResponse: ... + ) -> Job: ... 
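Both rewritten endpoints now hand back a `Job` instead of an inline batch response, so callers are expected to poll for completion. A minimal sketch of that flow, assuming a generated Python client; the `client.jobs.status(...)` helper, the terminal state names, and the model id are illustrative guesses, not part of this patch:

    # Hypothetical polling loop for the asynchronous batch-inference API.
    import time

    job = client.batch_inference.completion(
        model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model id
        content_batch=["The capital of France is "],
    )
    # Poll until the job reaches a terminal state; state names are assumed here.
    while client.jobs.status(job_id=job.job_id) not in ("completed", "failed"):
        time.sleep(5)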
- @webmethod(route="/batch-inference/chat-completion-inline", method="POST") - async def batch_chat_completion_inline( + @webmethod(route="/batch-inference/chat-completion", method="POST") + async def chat_completion( self, model: str, messages_batch: List[List[Message]], sampling_params: Optional[SamplingParams] = None, - tools: Optional[List[ToolDefinition]] = list, - tool_config: Optional[ToolConfig] = None, + # zero-shot tool definitions as input to the model + tools: Optional[List[ToolDefinition]] = None, + tool_choice: Optional[ToolChoice] = ToolChoice.auto, + tool_prompt_format: Optional[ToolPromptFormat] = None, response_format: Optional[ResponseFormat] = None, logprobs: Optional[LogProbConfig] = None, - ) -> BatchChatCompletionResponse: ... + ) -> Job: ... diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py index faf19a9c62..da5ded0f32 100644 --- a/llama_stack/providers/inline/inference/meta_reference/inference.py +++ b/llama_stack/providers/inline/inference/meta_reference/inference.py @@ -5,7 +5,6 @@ # the root directory of this source tree. import asyncio -import logging import os from typing import AsyncGenerator, List, Optional, Union @@ -44,6 +43,7 @@ UserMessage, ) from llama_stack.apis.models import Model, ModelType +from llama_stack.log import get_logger from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat @@ -72,7 +72,7 @@ from .generators import LlamaGenerator from .model_parallel import LlamaModelParallelGenerator -log = logging.getLogger(__name__) +log = get_logger(__name__, category="inference") # there's a single model parallel process running serving the model. for now, # we don't support multiple concurrent requests to this process. SEMAPHORE = asyncio.Semaphore(1) @@ -159,7 +159,7 @@ async def load_model(self, model_id, llama_model) -> None: self.model_id = model_id self.llama_model = llama_model - print("Warming up...") + log.info("Warming up...") await self.completion( model_id=model_id, content="Hello, world!", @@ -170,7 +170,7 @@ async def load_model(self, model_id, llama_model) -> None: messages=[UserMessage(content="Hi how are you?")], sampling_params=SamplingParams(max_tokens=20), ) - print("Warmed up!") + log.info("Warmed up!") def check_model(self, request) -> None: if self.model_id is None or self.llama_model is None: From 1d855461d55b0aec705cdc1a83324a5453838401 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Fri, 11 Apr 2025 16:21:21 -0700 Subject: [PATCH 3/6] kill batch inference registry --- .../providers/registry/batch_inference.py | 39 ------------------- 1 file changed, 39 deletions(-) delete mode 100644 llama_stack/providers/registry/batch_inference.py diff --git a/llama_stack/providers/registry/batch_inference.py b/llama_stack/providers/registry/batch_inference.py deleted file mode 100644 index 07b2167d6b..0000000000 --- a/llama_stack/providers/registry/batch_inference.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- -from typing import List - -from llama_stack.providers.datatypes import ( - Api, - InlineProviderSpec, - ProviderSpec, -) - -META_REFERENCE_DEPS = [ - "accelerate", - "blobfile", - "fairscale", - "torch", - "torchvision", - "transformers", - "zmq", - "lm-format-enforcer", - "sentence-transformers", - "torchao==0.5.0", - "fbgemm-gpu-genai==1.1.2", -] - - -def available_providers() -> List[ProviderSpec]: - return [ - InlineProviderSpec( - api=Api.inference, - provider_type="inline::meta-reference", - pip_packages=META_REFERENCE_DEPS, - module="llama_stack.providers.inline.batch_inference.meta_reference", - config_class="llama_stack.providers.inline.batch_inference.meta_reference.MetaReferenceInferenceConfig", - ), - ] From a3cee70014748cca17fc954272002ac430a39c84 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Fri, 11 Apr 2025 17:13:46 -0700 Subject: [PATCH 4/6] kill experimental attr on webmethod --- llama_stack/apis/inference/inference.py | 4 ++-- llama_stack/distribution/resolver.py | 3 --- llama_stack/schema_utils.py | 3 --- 3 files changed, 2 insertions(+), 8 deletions(-) diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 21753ca236..9eb3910c61 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -726,7 +726,7 @@ async def completion( """ ... - @webmethod(route="/inference/batch-completion", method="POST", experimental=True) + @webmethod(route="/inference/batch-completion", method="POST") async def batch_completion( self, model_id: str, @@ -777,7 +777,7 @@ async def chat_completion( """ ... - @webmethod(route="/inference/batch-chat-completion", method="POST", experimental=True) + @webmethod(route="/inference/batch-chat-completion", method="POST") async def batch_chat_completion( self, model_id: str, diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index 177d20f2bb..33ad343ec8 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -400,9 +400,6 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None: mro = type(obj).__mro__ for name, value in inspect.getmembers(protocol): if inspect.isfunction(value) and hasattr(value, "__webmethod__"): - if value.__webmethod__.experimental: - continue - if not hasattr(obj, name): missing_methods.append((name, "missing")) elif not callable(getattr(obj, name)): diff --git a/llama_stack/schema_utils.py b/llama_stack/schema_utils.py index 40d604c3c7..8fd55add07 100644 --- a/llama_stack/schema_utils.py +++ b/llama_stack/schema_utils.py @@ -20,7 +20,6 @@ class WebMethod: raw_bytes_request_body: Optional[bool] = False # A descriptive name of the corresponding span created by tracing descriptive_name: Optional[str] = None - experimental: Optional[bool] = False T = TypeVar("T", bound=Callable[..., Any]) @@ -34,7 +33,6 @@ def webmethod( response_examples: Optional[List[Any]] = None, raw_bytes_request_body: Optional[bool] = False, descriptive_name: Optional[str] = None, - experimental: Optional[bool] = False, ) -> Callable[[T], T]: """ Decorator that supplies additional metadata to an endpoint operation function. 
@@ -54,7 +52,6 @@ def wrap(func: T) -> T: response_examples=response_examples, raw_bytes_request_body=raw_bytes_request_body, descriptive_name=descriptive_name, - experimental=experimental, ) return func From 771daa4b911a52860a439457f0f553256b9da573 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Sat, 12 Apr 2025 10:51:43 -0700 Subject: [PATCH 5/6] fix test, fix llama3 generator --- llama_stack/models/llama/llama3/generation.py | 14 +++--- .../inference/test_batch_inference.py | 45 +++---------------- 2 files changed, 12 insertions(+), 47 deletions(-) diff --git a/llama_stack/models/llama/llama3/generation.py b/llama_stack/models/llama/llama3/generation.py index 98412a1d46..35c1407078 100644 --- a/llama_stack/models/llama/llama3/generation.py +++ b/llama_stack/models/llama/llama3/generation.py @@ -154,7 +154,7 @@ def __init__( @torch.inference_mode() def generate( self, - model_inputs: List[LLMInput], + llm_inputs: List[LLMInput], temperature: float = 0.6, top_p: float = 0.9, max_gen_len: Optional[int] = None, @@ -169,15 +169,15 @@ def generate( print_model_input = print_model_input or os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1" if print_model_input: - for inp in model_inputs: + for inp in llm_inputs: tokens_to_print = [self.formatter.vision_token if t == 128256 else t for t in inp.tokens] cprint( "Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n", "red", ) - prompt_tokens = [inp.tokens for inp in model_inputs] + prompt_tokens = [inp.tokens for inp in llm_inputs] - bsz = len(model_inputs) + bsz = len(llm_inputs) assert bsz <= params.max_batch_size, (bsz, params.max_batch_size) min_prompt_len = min(len(t) for t in prompt_tokens) @@ -198,8 +198,8 @@ def generate( is_vision = not isinstance(self.model, Transformer) if is_vision: - images = [inp.vision.images if inp.vision is not None else [] for inp in model_inputs] - mask = [inp.vision.mask if inp.vision is not None else [] for inp in model_inputs] + images = [inp.vision.images if inp.vision is not None else [] for inp in llm_inputs] + mask = [inp.vision.mask if inp.vision is not None else [] for inp in llm_inputs] xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks( batch_images=images, @@ -234,7 +234,7 @@ def generate( for cur_pos in range(min_prompt_len, total_len): if is_vision: position_ids = torch.arange(prev_pos, cur_pos, dtype=torch.long) - text_only_inference = all(inp.vision is None for inp in model_inputs) + text_only_inference = all(inp.vision is None for inp in llm_inputs) logits = self.model.forward( position_ids, tokens, diff --git a/tests/integration/inference/test_batch_inference.py b/tests/integration/inference/test_batch_inference.py index f2bbd06989..9a1a62ce0e 100644 --- a/tests/integration/inference/test_batch_inference.py +++ b/tests/integration/inference/test_batch_inference.py @@ -7,53 +7,17 @@ import pytest -from llama_stack.models.llama.sku_list import resolve_model - from ..test_cases.test_case import TestCase -PROVIDER_LOGPROBS_TOP_K = {"remote::together", "remote::fireworks", "remote::vllm"} - -def skip_if_model_doesnt_support_completion(client_with_models, model_id): +def skip_if_provider_doesnt_support_batch_inference(client_with_models, model_id): models = {m.identifier: m for m in client_with_models.models.list()} models.update({m.provider_resource_id: m for m in client_with_models.models.list()}) provider_id = models[model_id].provider_id providers = {p.provider_id: p for p in client_with_models.providers.list()} provider = 
providers[provider_id] - if provider.provider_type in ( - "remote::openai", - "remote::anthropic", - "remote::gemini", - "remote::groq", - "remote::llama-openai-compat", - ): - pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support completion") - - -def get_llama_model(client_with_models, model_id): - models = {} - for m in client_with_models.models.list(): - models[m.identifier] = m - models[m.provider_resource_id] = m - - assert model_id in models, f"Model {model_id} not found" - - model = models[model_id] - ids = (model.identifier, model.provider_resource_id) - for mid in ids: - if resolve_model(mid): - return mid - - return model.metadata.get("llama_model", None) - - -def get_llama_tokenizer(): - from llama_models.llama3.api.chat_format import ChatFormat - from llama_models.llama3.api.tokenizer import Tokenizer - - tokenizer = Tokenizer.get_instance() - formatter = ChatFormat(tokenizer) - return tokenizer, formatter + if provider.provider_type not in ("inline::meta-reference",): + pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support batch inference") @pytest.mark.parametrize( @@ -63,7 +27,7 @@ def get_llama_tokenizer(): ], ) def test_batch_completion_non_streaming(client_with_models, text_model_id, test_case): - skip_if_model_doesnt_support_completion(client_with_models, text_model_id) + skip_if_provider_doesnt_support_batch_inference(client_with_models, text_model_id) tc = TestCase(test_case) content_batch = tc["contents"] @@ -87,6 +51,7 @@ def test_batch_completion_non_streaming(client_with_models, text_model_id, test_ ], ) def test_batch_chat_completion_non_streaming(client_with_models, text_model_id, test_case): + skip_if_provider_doesnt_support_batch_inference(client_with_models, text_model_id) tc = TestCase(test_case) qa_pairs = tc["qa_pairs"] From 14ff4c647cde43815a135746eecc8a0505eec511 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Sat, 12 Apr 2025 11:23:25 -0700 Subject: [PATCH 6/6] include content in the message even if you have parsed out a tool call --- llama_stack/models/llama/llama3/chat_format.py | 1 - llama_stack/models/llama/llama4/chat_format.py | 1 - .../providers/inline/inference/meta_reference/inference.py | 2 +- 3 files changed, 1 insertion(+), 3 deletions(-) diff --git a/llama_stack/models/llama/llama3/chat_format.py b/llama_stack/models/llama/llama3/chat_format.py index f55cd5e1ce..fe7a7a8983 100644 --- a/llama_stack/models/llama/llama3/chat_format.py +++ b/llama_stack/models/llama/llama3/chat_format.py @@ -226,7 +226,6 @@ def decode_assistant_message_from_content(self, content: str, stop_reason: StopR arguments_json=json.dumps(tool_arguments), ) ) - content = "" return RawMessage( role="assistant", diff --git a/llama_stack/models/llama/llama4/chat_format.py b/llama_stack/models/llama/llama4/chat_format.py index 160bb00f84..9d60d00e9f 100644 --- a/llama_stack/models/llama/llama4/chat_format.py +++ b/llama_stack/models/llama/llama4/chat_format.py @@ -301,7 +301,6 @@ def decode_assistant_message_from_content(self, content: str, stop_reason: StopR arguments=tool_arguments, ) ) - content = "" return RawMessage( role="assistant", diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py index da5ded0f32..0b56ba1f71 100644 --- a/llama_stack/providers/inline/inference/meta_reference/inference.py +++ b/llama_stack/providers/inline/inference/meta_reference/inference.py @@ -452,7 +452,7 @@ def impl(): for token_results 
in self.generator.chat_completion(request_batch): first = token_results[0] - if not first.finished: + if not first.finished and not first.ignore_token: if os.environ.get("LLAMA_MODELS_DEBUG", "0") in ("1", "2"): cprint(first.text, "cyan", end="") if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
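After this series, the synchronous batch methods live on the `Inference` API and are implemented only by the inline meta-reference provider; the other providers stub them out with `NotImplementedError`. A minimal client-side sketch modeled on the integration test added in the first patch, assuming the standard `llama_stack_client` package; the server URL and model id are placeholders:

    from llama_stack_client import LlamaStackClient

    client = LlamaStackClient(base_url="http://localhost:8321")  # default llama-stack port
    response = client.inference.batch_chat_completion(
        model_id="meta-llama/Llama-3.1-8B-Instruct",  # placeholder; must be served by inline::meta-reference
        messages_batch=[
            [{"role": "user", "content": "What is the capital of France?"}],
            [{"role": "user", "content": "Who wrote the book '1984'?"}],
        ],
    )
    # One response per input conversation, in order.
    for r in response.batch:
        print(r.completion_message.content)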