From 5595f5b9b80808d068b222823ed9b4cb9923199d Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Tue, 4 Feb 2025 10:27:05 -0800
Subject: [PATCH 1/5] tmp

---
 .../inline/agents/meta_reference/agent_instance.py      | 9 +++++++++
 .../providers/remote/inference/together/together.py     | 8 ++++++++
 llama_stack/providers/utils/inference/prompt_adapter.py | 1 +
 3 files changed, 18 insertions(+)

diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
index f5ddbab404..07f04c9d52 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
@@ -168,10 +168,17 @@ async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> Asyn
             if self.agent_config.instructions != "":
                 messages.append(SystemMessage(content=self.agent_config.instructions))
 
+            from rich.pretty import pprint
+
+            print("create_and_execute_turn")
+            pprint(request)
+
             for i, turn in enumerate(turns):
                 messages.extend(self.turn_to_messages(turn))
 
             messages.extend(request.messages)
+            print("create_and_execute_turn turn to messages")
+            pprint(messages)
 
             turn_id = str(uuid.uuid4())
             span.set_attribute("turn_id", turn_id)
@@ -360,6 +367,7 @@ async def _run(
         documents: Optional[List[Document]] = None,
         toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
     ) -> AsyncGenerator:
+        print("_run messages", input_messages)
         # TODO: simplify all of this code, it can be simpler
         toolgroup_args = {}
         toolgroups = set()
@@ -490,6 +498,7 @@ async def _run(
         stop_reason = None
 
         with tracing.span("inference") as span:
+            print("just before chat completion", input_messages)
             async for chunk in await self.inference_api.chat_completion(
                 self.agent_config.model,
                 input_messages,
diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py
index 0b965c861d..ab4779ad8e 100644
--- a/llama_stack/providers/remote/inference/together/together.py
+++ b/llama_stack/providers/remote/inference/together/together.py
@@ -196,6 +196,8 @@ async def chat_completion(
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
         model = await self.model_store.get_model(model_id)
+        print("inside together chat completion messages", messages)
+        breakpoint()
         request = ChatCompletionRequest(
             model=model.provider_resource_id,
             messages=messages,
@@ -223,6 +225,11 @@ async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> Ch
 
     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
+        from rich.pretty import pprint
+
+        print("together stream completion")
+        pprint(request)
+        pprint(params)
 
         # if we shift to TogetherAsyncClient, we won't need this wrapper
         async def _to_async_generator():
@@ -240,6 +247,7 @@ async def _to_async_generator():
     async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
         input_dict = {}
         media_present = request_has_media(request)
+        breakpoint()
         if isinstance(request, ChatCompletionRequest):
             if media_present:
                 input_dict["messages"] = [await convert_message_to_openai_dict(m) for m in request.messages]
diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py
index 49c6ac7a96..8156c73d34 100644
--- a/llama_stack/providers/utils/inference/prompt_adapter.py
+++ b/llama_stack/providers/utils/inference/prompt_adapter.py
@@ -375,6 +375,7 @@ def _process(c):
 def augment_messages_for_tools_llama_3_2(
     request: ChatCompletionRequest,
 ) -> List[Message]:
+    breakpoint()
     assert request.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported"
 
     existing_messages = request.messages

From b1492ecb4e5a9866ad412764d0ec81398d07fe54 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Tue, 4 Feb 2025 10:30:19 -0800
Subject: [PATCH 2/5] revert print

---
 .../inline/agents/meta_reference/agent_instance.py  | 9 ---------
 .../providers/remote/inference/together/together.py | 8 --------
 2 files changed, 17 deletions(-)

diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
index 07f04c9d52..f5ddbab404 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
@@ -168,17 +168,10 @@ async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> Asyn
             if self.agent_config.instructions != "":
                 messages.append(SystemMessage(content=self.agent_config.instructions))
 
-            from rich.pretty import pprint
-
-            print("create_and_execute_turn")
-            pprint(request)
-
             for i, turn in enumerate(turns):
                 messages.extend(self.turn_to_messages(turn))
 
             messages.extend(request.messages)
-            print("create_and_execute_turn turn to messages")
-            pprint(messages)
 
             turn_id = str(uuid.uuid4())
             span.set_attribute("turn_id", turn_id)
@@ -367,7 +360,6 @@ async def _run(
         documents: Optional[List[Document]] = None,
         toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
     ) -> AsyncGenerator:
-        print("_run messages", input_messages)
         # TODO: simplify all of this code, it can be simpler
         toolgroup_args = {}
         toolgroups = set()
@@ -498,7 +490,6 @@ async def _run(
         stop_reason = None
 
         with tracing.span("inference") as span:
-            print("just before chat completion", input_messages)
             async for chunk in await self.inference_api.chat_completion(
                 self.agent_config.model,
                 input_messages,
diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py
index ab4779ad8e..0b965c861d 100644
--- a/llama_stack/providers/remote/inference/together/together.py
+++ b/llama_stack/providers/remote/inference/together/together.py
@@ -196,8 +196,6 @@ async def chat_completion(
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
         model = await self.model_store.get_model(model_id)
-        print("inside together chat completion messages", messages)
-        breakpoint()
         request = ChatCompletionRequest(
             model=model.provider_resource_id,
             messages=messages,
@@ -225,11 +223,6 @@ async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> Ch
 
     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
-        from rich.pretty import pprint
-
-        print("together stream completion")
-        pprint(request)
-        pprint(params)
 
         # if we shift to TogetherAsyncClient, we won't need this wrapper
         async def _to_async_generator():
@@ -247,7 +240,6 @@ async def _to_async_generator():
     async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
         input_dict = {}
         media_present = request_has_media(request)
-        breakpoint()
     if isinstance(request, ChatCompletionRequest):
         if media_present:
             input_dict["messages"] = [await convert_message_to_openai_dict(m) for m in request.messages]

From 4d0746039898a21fba8aec36d629a91bebf663b3 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Tue, 4 Feb 2025 10:54:53 -0800
Subject: [PATCH 3/5] tmp

---
 llama_stack/apis/agents/agents.py       | 1 +
 llama_stack/apis/inference/inference.py | 2 ++
 2 files changed, 3 insertions(+)

diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py
index 50bea3d55c..ba383d1f76 100644
--- a/llama_stack/apis/agents/agents.py
+++ b/llama_stack/apis/agents/agents.py
@@ -165,6 +165,7 @@ class AgentConfig(AgentConfigCommon):
     instructions: str
     enable_session_persistence: bool
     response_format: Optional[ResponseFormat] = None
+    output_parser: Optional[OutputParser] = Field(default=OutputParser.auto)
 
 
 class AgentConfigOverridablePerTurn(AgentConfigCommon):
diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index 6398f74e80..637b39297b 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -19,6 +19,7 @@
 
 from llama_models.llama3.api.datatypes import (
     BuiltinTool,
+    ResponseOutputParser,
     SamplingParams,
     StopReason,
     ToolCall,
@@ -319,6 +320,7 @@ class ChatCompletionRequest(BaseModel):
     response_format: Optional[ResponseFormat] = None
     stream: Optional[bool] = False
     logprobs: Optional[LogProbConfig] = None
+    response_output_parser: Optional[ResponseOutputParser] = Field(default=ResponseOutputParser.auto)
 
 
 @json_schema_type

From 662f171e884f4646238890fd3160e5b82a9c3103 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Tue, 4 Feb 2025 11:56:50 -0800
Subject: [PATCH 4/5] api change

---
 llama_stack/apis/agents/agents.py                       | 3 ++-
 llama_stack/apis/inference/inference.py                 | 2 +-
 llama_stack/providers/utils/inference/prompt_adapter.py | 1 -
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py
index ba383d1f76..ea49c3479e 100644
--- a/llama_stack/apis/agents/agents.py
+++ b/llama_stack/apis/agents/agents.py
@@ -26,6 +26,7 @@
 from llama_stack.apis.inference import (
     CompletionMessage,
     ResponseFormat,
+    ResponseOutputParser,
     SamplingParams,
     ToolCall,
     ToolChoice,
@@ -165,7 +166,7 @@ class AgentConfig(AgentConfigCommon):
     instructions: str
     enable_session_persistence: bool
     response_format: Optional[ResponseFormat] = None
-    output_parser: Optional[OutputParser] = Field(default=OutputParser.auto)
+    response_output_parser: Optional[ResponseOutputParser] = Field(default=ResponseOutputParser.default)
 
 
 class AgentConfigOverridablePerTurn(AgentConfigCommon):
diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index 637b39297b..4b1ee82d94 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -320,7 +320,7 @@ class ChatCompletionRequest(BaseModel):
     response_format: Optional[ResponseFormat] = None
     stream: Optional[bool] = False
     logprobs: Optional[LogProbConfig] = None
-    response_output_parser: Optional[ResponseOutputParser] = Field(default=ResponseOutputParser.auto)
+    response_output_parser: Optional[ResponseOutputParser] = Field(default=ResponseOutputParser.default)
 
 
 @json_schema_type
diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py
index 8156c73d34..49c6ac7a96 100644
--- a/llama_stack/providers/utils/inference/prompt_adapter.py
+++ b/llama_stack/providers/utils/inference/prompt_adapter.py
@@ -375,7 +375,6 @@ def _process(c):
 def augment_messages_for_tools_llama_3_2(
     request: ChatCompletionRequest,
 ) -> List[Message]:
-    breakpoint()
     assert request.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported"
`ToolChoice.auto` supported" existing_messages = request.messages From 42f0e919090e90379aebf0dedb291c70b517b9cf Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Tue, 4 Feb 2025 17:33:08 -0800 Subject: [PATCH 5/5] remove response parser --- llama_stack/apis/agents/agents.py | 2 -- llama_stack/apis/inference/inference.py | 2 -- 2 files changed, 4 deletions(-) diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index ea49c3479e..50bea3d55c 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -26,7 +26,6 @@ from llama_stack.apis.inference import ( CompletionMessage, ResponseFormat, - ResponseOutputParser, SamplingParams, ToolCall, ToolChoice, @@ -166,7 +165,6 @@ class AgentConfig(AgentConfigCommon): instructions: str enable_session_persistence: bool response_format: Optional[ResponseFormat] = None - response_output_parser: Optional[ResponseOutputParser] = Field(default=ResponseOutputParser.default) class AgentConfigOverridablePerTurn(AgentConfigCommon): diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 4b1ee82d94..6398f74e80 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -19,7 +19,6 @@ from llama_models.llama3.api.datatypes import ( BuiltinTool, - ResponseOutputParser, SamplingParams, StopReason, ToolCall, @@ -320,7 +319,6 @@ class ChatCompletionRequest(BaseModel): response_format: Optional[ResponseFormat] = None stream: Optional[bool] = False logprobs: Optional[LogProbConfig] = None - response_output_parser: Optional[ResponseOutputParser] = Field(default=ResponseOutputParser.default) @json_schema_type