diff --git a/src/bedrock_agentcore/memory/integrations/strands/session_manager.py b/src/bedrock_agentcore/memory/integrations/strands/session_manager.py index d77db53..03dae28 100644 --- a/src/bedrock_agentcore/memory/integrations/strands/session_manager.py +++ b/src/bedrock_agentcore/memory/integrations/strands/session_manager.py @@ -5,7 +5,7 @@ import threading from concurrent.futures import ThreadPoolExecutor, as_completed from datetime import datetime, timedelta, timezone -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Optional, Set import boto3 from botocore.config import Config as BotocoreConfig @@ -101,7 +101,6 @@ def __init__( self.config = agentcore_memory_config self.memory_client = MemoryClient(region_name=region_name) session = boto_session or boto3.Session(region_name=region_name) - self.has_existing_agent = False # Override the clients if custom boto session or config is provided # Add strands-agents to the request user agent @@ -122,8 +121,52 @@ def __init__( self.memory_client.gmdp_client = session.client( "bedrock-agentcore", region_name=region_name or session.region_name, config=client_config ) + + # Query existing branches and find root event + self._created_branches: Set[str] = set() + self._root_event_id: str = self._init_branches() + super().__init__(session_id=self.config.session_id, session_repository=self) + def _init_branches(self) -> str: + """Initialize branch tracking. Returns root event ID.""" + try: + # List existing events to find branches + response = self.memory_client.gmdp_client.list_events( + memoryId=self.config.memory_id, + actorId=self.config.actor_id, + sessionId=self.config.session_id, + maxResults=100, + ) + events = response.get("events", []) + + # Find existing branches and root event + root_event_id = None + for event in events: + branch = event.get("branch", {}) + if branch.get("name"): + self._created_branches.add(branch["name"]) + # First event without a branch name is the root + if not branch.get("name") and root_event_id is None: + root_event_id = event.get("eventId") + + # If we found a root event, use it; otherwise create one + if root_event_id: + return root_event_id + + except Exception as e: + logger.debug("No existing events found: %s", e) + + # Create root event for new session + initial_event = self.memory_client.gmdp_client.create_event( + memoryId=self.config.memory_id, + actorId=self.config.actor_id, + sessionId=self.config.session_id, + payload=[{"blob": json.dumps({"type": "session_start"})}], + eventTimestamp=self._get_monotonic_timestamp(), + ) + return initial_event.get("event", {}).get("eventId") + def _get_full_session_id(self, session_id: str) -> str: """Get the full session ID with the configured prefix. @@ -156,6 +199,17 @@ def _get_full_agent_id(self, agent_id: str) -> str: ) return full_agent_id + def _get_branch_for_agent(self, agent_id: str) -> dict[str, str]: + """Get branch dict for create_event. + + First event on a branch includes rootEventId to create it. + Subsequent events only need the branch name. + """ + branch_name = f"agent_{agent_id}" + if branch_name in self._created_branches: + return {"name": branch_name} + return {"name": branch_name, "rootEventId": self._root_event_id} + # region SessionRepository interface implementation def create_session(self, session: Session, **kwargs: Any) -> Session: """Create a new session in AgentCore Memory. @@ -358,24 +412,35 @@ def create_message( original_timestamp = datetime.fromisoformat(session_message.created_at.replace("Z", "+00:00")) monotonic_timestamp = self._get_monotonic_timestamp(original_timestamp) + # Get branch for this agent + branch = self._get_branch_for_agent(agent_id) + + # messages[0] is a tuple (text, role) + text, role = messages[0] + if not AgentCoreMemoryConverter.exceeds_conversational_limit(messages[0]): - event = self.memory_client.create_event( - memory_id=self.config.memory_id, - actor_id=self.config.actor_id, - session_id=session_id, - messages=messages, - event_timestamp=monotonic_timestamp, + event = self.memory_client.gmdp_client.create_event( + memoryId=self.config.memory_id, + actorId=self.config.actor_id, + sessionId=session_id, + payload=[{"conversational": {"content": {"text": text}, "role": role.upper()}}], + eventTimestamp=monotonic_timestamp, + branch=branch, ) else: event = self.memory_client.gmdp_client.create_event( memoryId=self.config.memory_id, actorId=self.config.actor_id, sessionId=session_id, - payload=[ - {"blob": json.dumps(messages[0])}, - ], + payload=[{"blob": text}], eventTimestamp=monotonic_timestamp, + branch=branch, ) + + # Mark branch as created after successful event + branch_name = f"agent_{agent_id}" + self._created_branches.add(branch_name) + logger.debug("Created event: %s for message: %s", event.get("eventId"), session_message.message_id) return event except Exception as e: @@ -449,18 +514,32 @@ def list_messages( try: max_results = (limit + offset) if limit else 100 - events = self.memory_client.list_events( - memory_id=self.config.memory_id, - actor_id=self.config.actor_id, - session_id=session_id, - max_results=max_results, - ) + branch_name = f"agent_{agent_id}" + + # Build filter for branch + params = { + "memoryId": self.config.memory_id, + "actorId": self.config.actor_id, + "sessionId": session_id, + "maxResults": max_results, + "filter": {"branch": {"name": branch_name}}, + } + + response = self.memory_client.gmdp_client.list_events(**params) + events = response.get("events", []) + messages = AgentCoreMemoryConverter.events_to_messages(events) if limit is not None: return messages[offset : offset + limit] else: return messages[offset:] + except self.memory_client.gmdp_client.exceptions.ValidationException as e: + # Branch doesn't exist yet - return empty list + if "Branch not found" in str(e): + return [] + logger.error("Failed to list messages from AgentCore Memory: %s", e) + return [] except Exception as e: logger.error("Failed to list messages from AgentCore Memory: %s", e) return [] @@ -564,12 +643,6 @@ def register_hooks(self, registry: HookRegistry, **kwargs) -> None: @override def initialize(self, agent: "Agent", **kwargs: Any) -> None: - if self.has_existing_agent: - logger.warning( - "An Agent already exists in session %s. We currently support one agent per session.", self.session_id - ) - else: - self.has_existing_agent = True RepositorySessionManager.initialize(self, agent, **kwargs) # endregion RepositorySessionManager overrides diff --git a/test_multi_agent.py b/test_multi_agent.py new file mode 100644 index 0000000..f737466 --- /dev/null +++ b/test_multi_agent.py @@ -0,0 +1,159 @@ +"""Test script for multi-agent session support with coordinator pattern.""" + +from strands import Agent, tool +from strands.models.bedrock import BedrockModel +from bedrock_agentcore.memory.integrations.strands.session_manager import AgentCoreMemorySessionManager +from bedrock_agentcore.memory.integrations.strands.config import AgentCoreMemoryConfig + +# Tools for sub-agents +@tool +def get_weather(location: str) -> str: + """Get the current weather for a location.""" + weather_data = { + "seattle": "Rainy, 55°F", "san francisco": "Foggy, 62°F", + "new york": "Sunny, 68°F", "tokyo": "Clear, 72°F", + } + return weather_data.get(location.lower(), f"No weather data for {location}") + +@tool +def search_restaurants(cuisine: str, city: str) -> list: + """Search for restaurants by cuisine type in a city.""" + restaurants = { + ("italian", "seattle"): ["Altura", "Spinasse", "Il Corvo"], + ("japanese", "seattle"): ["Shiro's", "Sushi Kashiba"], + ("mexican", "new york"): ["Cosme", "Atla"], + } + return restaurants.get((cuisine.lower(), city.lower()), [f"No {cuisine} restaurants in {city}"]) + +@tool +def calculate_tip(bill_amount: float, tip_percentage: float = 18.0) -> str: + """Calculate tip and total for a bill.""" + tip = bill_amount * (tip_percentage / 100) + return f"Tip: ${tip:.2f}, Total: ${bill_amount + tip:.2f}" + +@tool +def get_directions(from_loc: str, to_loc: str) -> str: + """Get directions between two locations.""" + return f"Directions from {from_loc} to {to_loc}: Head north, turn right. ~15 min." + +# Global sub-agents (initialized in main) +weather_agent = None +restaurant_agent = None +travel_agent = None + +# Coordinator tools to delegate to sub-agents +@tool +def ask_weather_specialist(question: str) -> str: + """Delegate weather-related questions to the weather specialist.""" + return str(weather_agent(question)) + +@tool +def ask_restaurant_specialist(question: str) -> str: + """Delegate restaurant/food questions to the restaurant specialist.""" + return str(restaurant_agent(question)) + +@tool +def ask_travel_specialist(question: str) -> str: + """Delegate travel/directions questions to the travel specialist.""" + return str(travel_agent(question)) + +def main(): + global weather_agent, restaurant_agent, travel_agent + + print("🚀 Multi-Agent Coordinator Test\n") + + config = AgentCoreMemoryConfig( + memory_id="CustomerSupport-pkP616GF9D", + session_id="coordinator-test-002", + actor_id="coordinator-user", + ) + + print("Creating session manager...") + session_manager = AgentCoreMemorySessionManager( + agentcore_memory_config=config, + region_name="us-west-2", + ) + print(f"Root event ID: {session_manager._root_event_id}\n") + + model = BedrockModel(model_id="anthropic.claude-3-haiku-20240307-v1:0") + + # Create sub-agents + print("Creating sub-agents...") + weather_agent = Agent( + agent_id="weather", + system_prompt="You are a weather specialist. Use get_weather tool. Be concise.", + session_manager=session_manager, + model=model, + tools=[get_weather], + ) + + restaurant_agent = Agent( + agent_id="restaurant", + system_prompt="You are a restaurant specialist. Use search_restaurants and calculate_tip tools. Be concise.", + session_manager=session_manager, + model=model, + tools=[search_restaurants, calculate_tip], + ) + + travel_agent = Agent( + agent_id="travel", + system_prompt="You are a travel specialist. Use get_directions tool. Be concise.", + session_manager=session_manager, + model=model, + tools=[get_directions], + ) + + # Create coordinator + print("Creating coordinator agent...") + coordinator = Agent( + agent_id="coordinator", + system_prompt="""You are a helpful coordinator. Route requests to specialists: +- Weather questions → ask_weather_specialist +- Restaurant/food/tip questions → ask_restaurant_specialist +- Travel/directions questions → ask_travel_specialist +Always use the appropriate specialist tool. Be concise.""", + session_manager=session_manager, + model=model, + tools=[ask_weather_specialist, ask_restaurant_specialist, ask_travel_specialist], + ) + + print(f"Branches created: {session_manager._created_branches}\n") + + # Interactive loop + print("="*60) + print("🤖 Coordinator Ready! (delegates to weather/restaurant/travel)") + print("="*60) + print("Examples: 'weather in Seattle', 'Italian restaurants in Seattle',") + print(" 'calculate 20% tip on $50', 'directions to downtown'") + print("Commands: 'branches', 'messages', 'quit'") + print("="*60 + "\n") + + while True: + try: + user_input = input("You: ").strip() + if not user_input: + continue + if user_input.lower() == 'quit': + break + if user_input.lower() == 'branches': + print(f"Branches: {session_manager._created_branches}\n") + continue + if user_input.lower() == 'messages': + for aid in ["coordinator", "weather", "restaurant", "travel"]: + msgs = session_manager.list_messages(config.session_id, aid) + print(f" {aid}: {len(msgs)} messages") + print() + continue + + response = coordinator(user_input) + print(f"Coordinator: {response}\n") + + except KeyboardInterrupt: + break + except Exception as e: + print(f"Error: {e}\n") + + print("Goodbye!") + +if __name__ == "__main__": + main()