From a055062e7c8bb7f97b2c34ef653c7b30f88a342e Mon Sep 17 00:00:00 2001 From: William Reed Date: Fri, 30 May 2025 11:50:12 -0400 Subject: [PATCH 1/3] feat(anthropic): correctly implement prompt caching for messages api --- src/mcp_agent/core/request_params.py | 9 ++++ src/mcp_agent/llm/augmented_llm.py | 22 ++++++-- src/mcp_agent/llm/augmented_llm_slow.py | 42 +++++++++++++++ src/mcp_agent/llm/model_factory.py | 4 ++ .../llm/providers/augmented_llm_anthropic.py | 54 +++++++++++++------ .../sampling/fastagent.config.yaml | 6 ++- tests/integration/sampling/live.py | 5 +- .../sampling/sampling_test_server.py | 41 ++++++++++++++ 8 files changed, 162 insertions(+), 21 deletions(-) create mode 100644 src/mcp_agent/llm/augmented_llm_slow.py diff --git a/src/mcp_agent/core/request_params.py b/src/mcp_agent/core/request_params.py index 7b087829b..9325d892f 100644 --- a/src/mcp_agent/core/request_params.py +++ b/src/mcp_agent/core/request_params.py @@ -52,3 +52,12 @@ class RequestParams(CreateMessageRequestParams): """ Optional dictionary of template variables for dynamic templates. Currently only works for TensorZero inference backend """ + + prompt_caching: bool = Field( + default=False, description="Enable prompt caching for supported LLM providers." + ) + """ + Enable prompt caching if the underlying LLM provider supports it. + For Anthropic, this adds 'cache_control': {'type': 'ephemeral'} to relevant message parts. + Behavior for other providers will depend on their specific caching mechanisms. + """ diff --git a/src/mcp_agent/llm/augmented_llm.py b/src/mcp_agent/llm/augmented_llm.py index 6ae9b6468..0b4b94ea8 100644 --- a/src/mcp_agent/llm/augmented_llm.py +++ b/src/mcp_agent/llm/augmented_llm.py @@ -95,8 +95,10 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT PARAM_USE_HISTORY = "use_history" PARAM_MAX_ITERATIONS = "max_iterations" PARAM_TEMPLATE_VARS = "template_vars" + PARAM_PROMPT_CACHING = "prompt_caching" + # Base set of fields that should always be excluded - BASE_EXCLUDE_FIELDS = {PARAM_METADATA} + BASE_EXCLUDE_FIELDS = {PARAM_METADATA, PARAM_PROMPT_CACHING} """ The basic building block of agentic systems is an LLM enhanced with augmentations @@ -361,16 +363,28 @@ def prepare_provider_arguments( # Start with base arguments arguments = base_args.copy() - # Use provided exclude_fields or fall back to base exclusions - exclude_fields = exclude_fields or self.BASE_EXCLUDE_FIELDS.copy() + # Combine base exclusions with provider-specific exclusions + final_exclude_fields = self.BASE_EXCLUDE_FIELDS.copy() + if exclude_fields: + final_exclude_fields.update(exclude_fields) # Add all fields from params that aren't explicitly excluded - params_dict = request_params.model_dump(exclude=exclude_fields) + # Ensure model_dump only includes set fields if that's the desired behavior, + # or adjust exclude_unset=True/False as needed. + # Default Pydantic v2 model_dump is exclude_unset=False + params_dict = request_params.model_dump(exclude=final_exclude_fields) + for key, value in params_dict.items(): + # Only add if not None and not already in base_args (base_args take precedence) + # or if None is a valid value for the provider, this logic might need adjustment. 
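+            # Example: with base_args={"model": "x", "system": None} and
+            # params_dict={"model": "y", "system": "You are helpful"}, the merge below
+            # keeps model="x" (base_args wins) but fills in system="You are helpful".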
if value is not None and key not in arguments: arguments[key] = value + elif value is not None and key in arguments and arguments[key] is None: + # Allow overriding a None in base_args with a set value from params + arguments[key] = value # Finally, add any metadata fields as a last layer of overrides + # This ensures metadata can override anything previously set if keys conflict. if request_params.metadata: arguments.update(request_params.metadata) diff --git a/src/mcp_agent/llm/augmented_llm_slow.py b/src/mcp_agent/llm/augmented_llm_slow.py new file mode 100644 index 000000000..b40265077 --- /dev/null +++ b/src/mcp_agent/llm/augmented_llm_slow.py @@ -0,0 +1,42 @@ +import asyncio +from typing import Any, List, Optional, Union + +from mcp_agent.llm.augmented_llm import ( + MessageParamT, + RequestParams, +) +from mcp_agent.llm.augmented_llm_passthrough import PassthroughLLM +from mcp_agent.llm.provider_types import Provider +from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart + + +class SlowLLM(PassthroughLLM): + """ + A specialized LLM implementation that sleeps for 3 seconds before responding like PassthroughLLM. + + This is useful for testing scenarios where you want to simulate slow responses + or for debugging timing-related issues in parallel workflows. + """ + + def __init__( + self, provider=Provider.FAST_AGENT, name: str = "Slow", **kwargs: dict[str, Any] + ) -> None: + super().__init__(name=name, provider=provider, **kwargs) + + async def generate_str( + self, + message: Union[str, MessageParamT, List[MessageParamT]], + request_params: Optional[RequestParams] = None, + ) -> str: + """Sleep for 3 seconds then return the input message as a string.""" + await asyncio.sleep(3) + return await super().generate_str(message, request_params) + + async def _apply_prompt_provider_specific( + self, + multipart_messages: List["PromptMessageMultipart"], + request_params: RequestParams | None = None, + ) -> PromptMessageMultipart: + """Sleep for 3 seconds then apply prompt like PassthroughLLM.""" + await asyncio.sleep(3) + return await super()._apply_prompt_provider_specific(multipart_messages, request_params) diff --git a/src/mcp_agent/llm/model_factory.py b/src/mcp_agent/llm/model_factory.py index 318dab00b..1e7e1fe4f 100644 --- a/src/mcp_agent/llm/model_factory.py +++ b/src/mcp_agent/llm/model_factory.py @@ -8,6 +8,7 @@ from mcp_agent.core.request_params import RequestParams from mcp_agent.llm.augmented_llm_passthrough import PassthroughLLM from mcp_agent.llm.augmented_llm_playback import PlaybackLLM +from mcp_agent.llm.augmented_llm_slow import SlowLLM from mcp_agent.llm.provider_types import Provider from mcp_agent.llm.providers.augmented_llm_anthropic import AnthropicAugmentedLLM from mcp_agent.llm.providers.augmented_llm_azure import AzureOpenAIAugmentedLLM @@ -29,6 +30,7 @@ Type[OpenAIAugmentedLLM], Type[PassthroughLLM], Type[PlaybackLLM], + Type[SlowLLM], Type[DeepSeekAugmentedLLM], Type[OpenRouterAugmentedLLM], Type[TensorZeroAugmentedLLM], @@ -73,6 +75,7 @@ class ModelFactory: DEFAULT_PROVIDERS = { "passthrough": Provider.FAST_AGENT, "playback": Provider.FAST_AGENT, + "slow": Provider.FAST_AGENT, "gpt-4o": Provider.OPENAI, "gpt-4o-mini": Provider.OPENAI, "gpt-4.1": Provider.OPENAI, @@ -139,6 +142,7 @@ class ModelFactory: # This overrides the provider-based class selection MODEL_SPECIFIC_CLASSES: Dict[str, LLMClass] = { "playback": PlaybackLLM, + "slow": SlowLLM, } @classmethod diff --git a/src/mcp_agent/llm/providers/augmented_llm_anthropic.py 
b/src/mcp_agent/llm/providers/augmented_llm_anthropic.py index b719a6cf9..d484d341b 100644 --- a/src/mcp_agent/llm/providers/augmented_llm_anthropic.py +++ b/src/mcp_agent/llm/providers/augmented_llm_anthropic.py @@ -63,6 +63,7 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]): AugmentedLLM.PARAM_MAX_ITERATIONS, AugmentedLLM.PARAM_PARALLEL_TOOL_CALLS, AugmentedLLM.PARAM_TEMPLATE_VARS, + AugmentedLLM.PARAM_PROMPT_CACHING, } def __init__(self, *args, **kwargs) -> None: @@ -117,7 +118,33 @@ async def _anthropic_completion( # if use_history is True messages.extend(self.history.get(include_completion_history=params.use_history)) - messages.append(message_param) + messages.append(message_param) # message_param is the current user turn + + # If prompt caching is enabled, apply cache_control to the last content block + # of the current user message (which is now the last in 'messages'). + # This modification affects the 'messages' list that will be used in the first API call within the loop. + if params.prompt_caching and messages: + last_message_in_prompt = messages[-1] # This is a MessageParam (dict) + if ( + isinstance(last_message_in_prompt, dict) + and "content" in last_message_in_prompt + and isinstance(last_message_in_prompt["content"], list) + and last_message_in_prompt["content"] + ): + last_content_block = last_message_in_prompt["content"][-1] + if isinstance(last_content_block, dict): # Content blocks are dicts + self.logger.debug( + "Prompt caching enabled: Applying cache_control to the last content block of the current user message." + ) + last_content_block["cache_control"] = {"type": "ephemeral"} + else: + self.logger.warning( + "Could not apply cache_control to current user message: Last content block is not a dictionary." + ) + else: + self.logger.warning( + "Could not apply cache_control to current user message: It has no content, content is not a list, or message is not a dict." 
+ ) tool_list: ListToolsResult = await self.aggregator.list_tools() available_tools: List[ToolParam] = [ @@ -178,13 +205,13 @@ async def _anthropic_completion( # Convert other errors to text response error_message = f"Error during generation: {error_details}" response = Message( - id="error", # Required field - model="error", # Required field + id="error", + model="error", role="assistant", type="message", content=[TextBlock(type="text", text=error_message)], - stop_reason="end_turn", # Must be one of the allowed values - usage=Usage(input_tokens=0, output_tokens=0), # Required field + stop_reason="end_turn", + usage=Usage(input_tokens=0, output_tokens=0), ) self.logger.debug( @@ -194,7 +221,7 @@ async def _anthropic_completion( response_as_message = self.convert_message_to_message_param(response) messages.append(response_as_message) - if response.content[0].type == "text": + if response.content and response.content[0].type == "text": responses.append(TextContent(type="text", text=response.content[0].text)) if response.stop_reason == "end_turn": @@ -254,12 +281,13 @@ async def _anthropic_completion( # Process all tool calls and collect results tool_results = [] - for i, content in enumerate(tool_uses): - tool_name = content.name - tool_args = content.input - tool_use_id = content.id + # Use a different loop variable for tool enumeration if 'i' is outer loop counter + for tool_idx, content_block in enumerate(tool_uses): + tool_name = content_block.name + tool_args = content_block.input + tool_use_id = content_block.id - if i == 0: # Only show message for first tool use + if tool_idx == 0: # Only show message for first tool use await self.show_assistant_message(message_text, tool_name) self.show_tool_call(available_tools, tool_name, tool_args) @@ -284,11 +312,7 @@ async def _anthropic_completion( if params.use_history: # Get current prompt messages prompt_messages = self.history.get(include_completion_history=False) - - # Calculate new conversation messages (excluding prompts) new_messages = messages[len(prompt_messages) :] - - # Update conversation history self.history.set(new_messages) self._log_chat_finished(model=model) diff --git a/tests/integration/sampling/fastagent.config.yaml b/tests/integration/sampling/fastagent.config.yaml index 8c7cffa1e..03d962cba 100644 --- a/tests/integration/sampling/fastagent.config.yaml +++ b/tests/integration/sampling/fastagent.config.yaml @@ -23,7 +23,11 @@ mcp: args: ["run", "sampling_test_server.py"] sampling: model: "passthrough" - + slow_sampling: + command: "uv" + args: ["run", "sampling_test_server.py"] + sampling: + model: "slow" sampling_test_no_config: command: "uv" args: ["run", "sampling_test_server.py"] diff --git a/tests/integration/sampling/live.py b/tests/integration/sampling/live.py index 732aebaae..93bc577ae 100644 --- a/tests/integration/sampling/live.py +++ b/tests/integration/sampling/live.py @@ -7,13 +7,16 @@ # Define the agent -@fast.agent(servers=["sampling_test"]) +@fast.agent(servers=["sampling_test", "slow_sampling"]) async def main(): # use the --model command line switch or agent arguments to change model async with fast.run() as agent: result = await agent.send('***CALL_TOOL sampling_test-sample {"to_sample": "123foo"}') print(f"RESULT: {result}") + result = await agent.send('***CALL_TOOL slow_sampling-sample_parallel') + print(f"RESULT: {result}") + if __name__ == "__main__": asyncio.run(main()) diff --git a/tests/integration/sampling/sampling_test_server.py b/tests/integration/sampling/sampling_test_server.py index 
b26585c58..89d98d543 100644
--- a/tests/integration/sampling/sampling_test_server.py
+++ b/tests/integration/sampling/sampling_test_server.py
@@ -61,6 +61,47 @@ async def sample_many(ctx: Context) -> CallToolResult:
         return CallToolResult(content=[TextContent(type="text", text=str(result))])
 
 
+@mcp.tool()
+async def sample_parallel(ctx: Context, count: int = 5) -> CallToolResult:
+    """Tool that makes multiple concurrent sampling requests to test parallel processing"""
+    try:
+        logger.info(f"Making {count} concurrent sampling requests")
+
+        # Create multiple concurrent sampling requests
+        import asyncio
+
+        async def _send_sampling(request: int):
+            return await ctx.session.create_message(
+                max_tokens=100,
+                messages=[SamplingMessage(
+                    role="user",
+                    content=TextContent(type="text", text=f"Parallel request {request+1}")
+                )],
+            )
+
+
+        tasks = []
+        for i in range(count):
+            task = _send_sampling(i)
+            tasks.append(task)
+
+        # Execute all requests concurrently
+        results = await asyncio.gather(*tasks)
+
+        # Combine results
+        response_texts = [result.content.text for result in results]
+        combined_response = f"Completed {len(results)} parallel requests: " + ", ".join(response_texts[:3])
+        if len(response_texts) > 3:
+            combined_response += f"... and {len(response_texts) - 3} more"
+
+        logger.info(f"Parallel sampling completed: {combined_response}")
+        return CallToolResult(content=[TextContent(type="text", text=combined_response)])
+
+    except Exception as e:
+        logger.error(f"Error in sample_parallel tool: {e}", exc_info=True)
+        return CallToolResult(isError=True, content=[TextContent(type="text", text=f"Error: {str(e)}")])
+
+
 if __name__ == "__main__":
     logger.info("Starting sampling test server...")
     mcp.run()

From 9faf69a6b739fd3c80b67376ee509dbd3736e62a Mon Sep 17 00:00:00 2001
From: Tom X Nguyen
Date: Sun, 1 Jun 2025 18:21:19 +0700
Subject: [PATCH 2/3] feat: Implement Anthropic cache_mode with refined 'auto' behavior

---
 src/mcp_agent/config.py                      | 10 ++-
 src/mcp_agent/core/request_params.py         |  9 ---
 src/mcp_agent/llm/augmented_llm.py           |  3 +-
 .../llm/providers/augmented_llm_anthropic.py | 77 +++++++++++++++----
 4 files changed, 72 insertions(+), 27 deletions(-)

diff --git a/src/mcp_agent/config.py b/src/mcp_agent/config.py
index 767e65a2a..4000be36c 100644
--- a/src/mcp_agent/config.py
+++ b/src/mcp_agent/config.py
@@ -113,6 +113,14 @@ class AnthropicSettings(BaseModel):
 
     base_url: str | None = None
 
+    cache_mode: Literal["off", "prompt", "auto"] = "off"
+    """
+    Controls how caching is applied for Anthropic models when prompt_caching is enabled globally.
+    - "off": No caching, even if global prompt_caching is true.
+    - "prompt": Caches the initial system/user prompt. Useful for large, static prompts.
+    - "auto": Caches the last user message. Default behavior if prompt_caching is true.
+    """
+
     model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True)
 
 
@@ -291,7 +299,7 @@ class Settings(BaseSettings):
     Default model for agents. Format is provider.model_name., for example openai.o3-mini.low
     Aliases are provided for common models e.g. sonnet, haiku, gpt-4.1, o3-mini etc.
""" - + auto_sampling: bool = True """Enable automatic sampling model selection if not explicitly configured""" diff --git a/src/mcp_agent/core/request_params.py b/src/mcp_agent/core/request_params.py index 9325d892f..7b087829b 100644 --- a/src/mcp_agent/core/request_params.py +++ b/src/mcp_agent/core/request_params.py @@ -52,12 +52,3 @@ class RequestParams(CreateMessageRequestParams): """ Optional dictionary of template variables for dynamic templates. Currently only works for TensorZero inference backend """ - - prompt_caching: bool = Field( - default=False, description="Enable prompt caching for supported LLM providers." - ) - """ - Enable prompt caching if the underlying LLM provider supports it. - For Anthropic, this adds 'cache_control': {'type': 'ephemeral'} to relevant message parts. - Behavior for other providers will depend on their specific caching mechanisms. - """ diff --git a/src/mcp_agent/llm/augmented_llm.py b/src/mcp_agent/llm/augmented_llm.py index 0b4b94ea8..7258af293 100644 --- a/src/mcp_agent/llm/augmented_llm.py +++ b/src/mcp_agent/llm/augmented_llm.py @@ -95,10 +95,9 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT PARAM_USE_HISTORY = "use_history" PARAM_MAX_ITERATIONS = "max_iterations" PARAM_TEMPLATE_VARS = "template_vars" - PARAM_PROMPT_CACHING = "prompt_caching" # Base set of fields that should always be excluded - BASE_EXCLUDE_FIELDS = {PARAM_METADATA, PARAM_PROMPT_CACHING} + BASE_EXCLUDE_FIELDS = {PARAM_METADATA} """ The basic building block of agentic systems is an LLM enhanced with augmentations diff --git a/src/mcp_agent/llm/providers/augmented_llm_anthropic.py b/src/mcp_agent/llm/providers/augmented_llm_anthropic.py index d484d341b..526ddc181 100644 --- a/src/mcp_agent/llm/providers/augmented_llm_anthropic.py +++ b/src/mcp_agent/llm/providers/augmented_llm_anthropic.py @@ -63,7 +63,6 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]): AugmentedLLM.PARAM_MAX_ITERATIONS, AugmentedLLM.PARAM_PARALLEL_TOOL_CALLS, AugmentedLLM.PARAM_TEMPLATE_VARS, - AugmentedLLM.PARAM_PROMPT_CACHING, } def __init__(self, *args, **kwargs) -> None: @@ -120,30 +119,64 @@ async def _anthropic_completion( messages.append(message_param) # message_param is the current user turn - # If prompt caching is enabled, apply cache_control to the last content block - # of the current user message (which is now the last in 'messages'). - # This modification affects the 'messages' list that will be used in the first API call within the loop. 
- if params.prompt_caching and messages: - last_message_in_prompt = messages[-1] # This is a MessageParam (dict) + # Prepare for caching based on provider-specific cache_mode + apply_cache_to_system_prompt = False + messages_to_cache_indices: List[int] = [] + + if self.context.config and self.context.config.anthropic: + cache_mode = self.context.config.anthropic.cache_mode + self.logger.debug(f"Anthropic cache_mode: {cache_mode}") + + if cache_mode == "auto": + apply_cache_to_system_prompt = True # Cache system prompt + if messages: # If there are any messages + messages_to_cache_indices.append( + len(messages) - 1 + ) # Cache only the last message + self.logger.debug( + f"Auto mode: Caching system prompt (if present) and last message at index: {messages_to_cache_indices}" + ) + elif cache_mode == "prompt": + # Find the first user message in the fully constructed messages list + for idx, msg in enumerate(messages): + if isinstance(msg, dict) and msg.get("role") == "user": + messages_to_cache_indices.append(idx) + self.logger.debug( + f"Prompt mode: Caching first user message in constructed prompt at index: {idx}" + ) + break + elif cache_mode == "off": + self.logger.debug("Anthropic cache_mode is 'off'. No caching will be applied.") + else: # Should not happen due to Literal validation + self.logger.warning( + f"Unknown Anthropic cache_mode: {cache_mode}. No caching will be applied." + ) + else: + self.logger.debug("Anthropic settings not found. No caching will be applied.") + + # Apply cache_control to selected messages + for msg_idx in messages_to_cache_indices: + message_to_cache = messages[msg_idx] if ( - isinstance(last_message_in_prompt, dict) - and "content" in last_message_in_prompt - and isinstance(last_message_in_prompt["content"], list) - and last_message_in_prompt["content"] + isinstance(message_to_cache, dict) + and "content" in message_to_cache + and isinstance(message_to_cache["content"], list) + and message_to_cache["content"] ): - last_content_block = last_message_in_prompt["content"][-1] - if isinstance(last_content_block, dict): # Content blocks are dicts + # Apply to the last content block of the message + last_content_block = message_to_cache["content"][-1] + if isinstance(last_content_block, dict): self.logger.debug( - "Prompt caching enabled: Applying cache_control to the last content block of the current user message." + f"Applying cache_control to last content block of message at index {msg_idx}." ) last_content_block["cache_control"] = {"type": "ephemeral"} else: self.logger.warning( - "Could not apply cache_control to current user message: Last content block is not a dictionary." + f"Could not apply cache_control to message at index {msg_idx}: Last content block is not a dictionary." ) else: self.logger.warning( - "Could not apply cache_control to current user message: It has no content, content is not a list, or message is not a dict." + f"Could not apply cache_control to message at index {msg_idx}: Invalid message structure or no content." 
) tool_list: ListToolsResult = await self.aggregator.list_tools() @@ -171,6 +204,20 @@ async def _anthropic_completion( "tools": available_tools, } + # Apply cache_control to system prompt for "auto" mode + if apply_cache_to_system_prompt and base_args["system"]: + if isinstance(base_args["system"], str): + base_args["system"] = [ + { + "type": "text", + "text": base_args["system"], + "cache_control": {"type": "ephemeral"}, + } + ] + self.logger.debug( + "Applying cache_control to system prompt by wrapping it in a list of content blocks." + ) + if params.maxTokens is not None: base_args["max_tokens"] = params.maxTokens From 58a891fa4e3e996c914309920ed2fe79773bcc8d Mon Sep 17 00:00:00 2001 From: Tom X Nguyen Date: Wed, 11 Jun 2025 10:26:32 +0700 Subject: [PATCH 3/3] feat(cache): cache last 3 messages for anthropic auto mode --- src/mcp_agent/config.py | 2 +- src/mcp_agent/llm/providers/augmented_llm_anthropic.py | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/mcp_agent/config.py b/src/mcp_agent/config.py index c5165b850..2d7bad254 100644 --- a/src/mcp_agent/config.py +++ b/src/mcp_agent/config.py @@ -118,7 +118,7 @@ class AnthropicSettings(BaseModel): Controls how caching is applied for Anthropic models when prompt_caching is enabled globally. - "off": No caching, even if global prompt_caching is true. - "prompt": Caches the initial system/user prompt. Useful for large, static prompts. - - "auto": Caches the last user message. Default behavior if prompt_caching is true. + - "auto": Caches the last three messages. Default behavior if prompt_caching is true. """ model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True) diff --git a/src/mcp_agent/llm/providers/augmented_llm_anthropic.py b/src/mcp_agent/llm/providers/augmented_llm_anthropic.py index 526ddc181..133111fdd 100644 --- a/src/mcp_agent/llm/providers/augmented_llm_anthropic.py +++ b/src/mcp_agent/llm/providers/augmented_llm_anthropic.py @@ -130,11 +130,12 @@ async def _anthropic_completion( if cache_mode == "auto": apply_cache_to_system_prompt = True # Cache system prompt if messages: # If there are any messages - messages_to_cache_indices.append( - len(messages) - 1 - ) # Cache only the last message + # Cache the last 3 messages + messages_to_cache_indices.extend( + range(max(0, len(messages) - 3), len(messages)) + ) self.logger.debug( - f"Auto mode: Caching system prompt (if present) and last message at index: {messages_to_cache_indices}" + f"Auto mode: Caching system prompt (if present) and last three messages at indices: {messages_to_cache_indices}" ) elif cache_mode == "prompt": # Find the first user message in the fully constructed messages list