diff --git a/a2a/weather_service/src/weather_service/agent.py b/a2a/weather_service/src/weather_service/agent.py index b7574be..3d83afe 100644 --- a/a2a/weather_service/src/weather_service/agent.py +++ b/a2a/weather_service/src/weather_service/agent.py @@ -122,10 +122,20 @@ async def execute(self, context: RequestContext, event_queue: EventQueue): # Here we just run the agent logic - spans from LangChain are auto-captured output = None + # Forward inbound Authorization header to outbound MCP tool calls. + # This enables transparent token exchange when deployed behind a waypoint + # or AuthBridge proxy (same pattern as git_issue_agent, see c8ebde1). + mcp_headers = None + if context.call_context and (context.call_context.state or {}).get("headers", {}).get("authorization"): + mcp_headers = {"Authorization": context.call_context.state["headers"]["authorization"]} + logger.info("Forwarding inbound Authorization header to MCP tool calls") + else: + logger.warning("No inbound Authorization header; MCP tool calls will be unauthenticated") + # Test MCP connection first logger.info(f"Attempting to connect to MCP server at: {os.getenv('MCP_URL', 'http://localhost:8000/sse')}") - mcpclient = get_mcpclient() + mcpclient = get_mcpclient(headers=mcp_headers) # Try to get tools to verify connection try: diff --git a/a2a/weather_service/src/weather_service/graph.py b/a2a/weather_service/src/weather_service/graph.py index cd9e8cf..71a41f8 100644 --- a/a2a/weather_service/src/weather_service/graph.py +++ b/a2a/weather_service/src/weather_service/graph.py @@ -16,15 +16,15 @@ class ExtendedMessagesState(MessagesState): final_answer: str = "" -def get_mcpclient(): - return MultiServerMCPClient( - { - "math": { - "url": os.getenv("MCP_URL", "http://localhost:8000/mcp"), - "transport": os.getenv("MCP_TRANSPORT", "streamable_http"), - } - } - ) +def get_mcpclient(headers=None): + """Create MCP client, optionally forwarding headers (e.g. Authorization).""" + mcp_config = { + "url": os.getenv("MCP_URL", "http://localhost:8000/mcp"), + "transport": os.getenv("MCP_TRANSPORT", "streamable_http"), + } + if headers: + mcp_config["headers"] = headers + return MultiServerMCPClient({"math": mcp_config}) async def get_graph(client) -> StateGraph: