Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 96 additions & 23 deletions src/bedrock_agentcore/memory/integrations/strands/session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Optional, Set

import boto3
from botocore.config import Config as BotocoreConfig
Expand Down Expand Up @@ -101,7 +101,6 @@ def __init__(
self.config = agentcore_memory_config
self.memory_client = MemoryClient(region_name=region_name)
session = boto_session or boto3.Session(region_name=region_name)
self.has_existing_agent = False

# Override the clients if custom boto session or config is provided
# Add strands-agents to the request user agent
Expand All @@ -122,8 +121,52 @@ def __init__(
self.memory_client.gmdp_client = session.client(
"bedrock-agentcore", region_name=region_name or session.region_name, config=client_config
)

# Query existing branches and find root event
self._created_branches: Set[str] = set()
self._root_event_id: str = self._init_branches()

super().__init__(session_id=self.config.session_id, session_repository=self)

def _init_branches(self) -> str:
"""Initialize branch tracking. Returns root event ID."""
try:
# List existing events to find branches
response = self.memory_client.gmdp_client.list_events(
memoryId=self.config.memory_id,
actorId=self.config.actor_id,
sessionId=self.config.session_id,
maxResults=100,
)
events = response.get("events", [])

# Find existing branches and root event
root_event_id = None
for event in events:
branch = event.get("branch", {})
if branch.get("name"):
self._created_branches.add(branch["name"])
# First event without a branch name is the root
if not branch.get("name") and root_event_id is None:
root_event_id = event.get("eventId")

# If we found a root event, use it; otherwise create one
if root_event_id:
return root_event_id

except Exception as e:
logger.debug("No existing events found: %s", e)

# Create root event for new session
initial_event = self.memory_client.gmdp_client.create_event(
memoryId=self.config.memory_id,
actorId=self.config.actor_id,
sessionId=self.config.session_id,
payload=[{"blob": json.dumps({"type": "session_start"})}],
eventTimestamp=self._get_monotonic_timestamp(),
)
return initial_event.get("event", {}).get("eventId")

def _get_full_session_id(self, session_id: str) -> str:
"""Get the full session ID with the configured prefix.

Expand Down Expand Up @@ -156,6 +199,17 @@ def _get_full_agent_id(self, agent_id: str) -> str:
)
return full_agent_id

def _get_branch_for_agent(self, agent_id: str) -> dict[str, str]:
"""Get branch dict for create_event.

First event on a branch includes rootEventId to create it.
Subsequent events only need the branch name.
"""
branch_name = f"agent_{agent_id}"
if branch_name in self._created_branches:
return {"name": branch_name}
return {"name": branch_name, "rootEventId": self._root_event_id}

# region SessionRepository interface implementation
def create_session(self, session: Session, **kwargs: Any) -> Session:
"""Create a new session in AgentCore Memory.
Expand Down Expand Up @@ -358,24 +412,35 @@ def create_message(
original_timestamp = datetime.fromisoformat(session_message.created_at.replace("Z", "+00:00"))
monotonic_timestamp = self._get_monotonic_timestamp(original_timestamp)

# Get branch for this agent
branch = self._get_branch_for_agent(agent_id)

# messages[0] is a tuple (text, role)
text, role = messages[0]

if not AgentCoreMemoryConverter.exceeds_conversational_limit(messages[0]):
event = self.memory_client.create_event(
memory_id=self.config.memory_id,
actor_id=self.config.actor_id,
session_id=session_id,
messages=messages,
event_timestamp=monotonic_timestamp,
event = self.memory_client.gmdp_client.create_event(
memoryId=self.config.memory_id,
actorId=self.config.actor_id,
sessionId=session_id,
payload=[{"conversational": {"content": {"text": text}, "role": role.upper()}}],
eventTimestamp=monotonic_timestamp,
branch=branch,
)
else:
event = self.memory_client.gmdp_client.create_event(
memoryId=self.config.memory_id,
actorId=self.config.actor_id,
sessionId=session_id,
payload=[
{"blob": json.dumps(messages[0])},
],
payload=[{"blob": text}],
eventTimestamp=monotonic_timestamp,
branch=branch,
)

# Mark branch as created after successful event
branch_name = f"agent_{agent_id}"
self._created_branches.add(branch_name)

logger.debug("Created event: %s for message: %s", event.get("eventId"), session_message.message_id)
return event
except Exception as e:
Expand Down Expand Up @@ -449,18 +514,32 @@ def list_messages(

try:
max_results = (limit + offset) if limit else 100
events = self.memory_client.list_events(
memory_id=self.config.memory_id,
actor_id=self.config.actor_id,
session_id=session_id,
max_results=max_results,
)
branch_name = f"agent_{agent_id}"

# Build filter for branch
params = {
"memoryId": self.config.memory_id,
"actorId": self.config.actor_id,
"sessionId": session_id,
"maxResults": max_results,
"filter": {"branch": {"name": branch_name}},
}

response = self.memory_client.gmdp_client.list_events(**params)
events = response.get("events", [])

messages = AgentCoreMemoryConverter.events_to_messages(events)
if limit is not None:
return messages[offset : offset + limit]
else:
return messages[offset:]

except self.memory_client.gmdp_client.exceptions.ValidationException as e:
# Branch doesn't exist yet - return empty list
if "Branch not found" in str(e):
return []
logger.error("Failed to list messages from AgentCore Memory: %s", e)
return []
except Exception as e:
logger.error("Failed to list messages from AgentCore Memory: %s", e)
return []
Expand Down Expand Up @@ -564,12 +643,6 @@ def register_hooks(self, registry: HookRegistry, **kwargs) -> None:

@override
def initialize(self, agent: "Agent", **kwargs: Any) -> None:
if self.has_existing_agent:
logger.warning(
"An Agent already exists in session %s. We currently support one agent per session.", self.session_id
)
else:
self.has_existing_agent = True
RepositorySessionManager.initialize(self, agent, **kwargs)

# endregion RepositorySessionManager overrides
159 changes: 159 additions & 0 deletions test_multi_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
"""Test script for multi-agent session support with coordinator pattern."""

from strands import Agent, tool
from strands.models.bedrock import BedrockModel
from bedrock_agentcore.memory.integrations.strands.session_manager import AgentCoreMemorySessionManager
from bedrock_agentcore.memory.integrations.strands.config import AgentCoreMemoryConfig

# Tools for sub-agents
@tool
def get_weather(location: str) -> str:
"""Get the current weather for a location."""
weather_data = {
"seattle": "Rainy, 55°F", "san francisco": "Foggy, 62°F",
"new york": "Sunny, 68°F", "tokyo": "Clear, 72°F",
}
return weather_data.get(location.lower(), f"No weather data for {location}")

@tool
def search_restaurants(cuisine: str, city: str) -> list:
"""Search for restaurants by cuisine type in a city."""
restaurants = {
("italian", "seattle"): ["Altura", "Spinasse", "Il Corvo"],
("japanese", "seattle"): ["Shiro's", "Sushi Kashiba"],
("mexican", "new york"): ["Cosme", "Atla"],
}
return restaurants.get((cuisine.lower(), city.lower()), [f"No {cuisine} restaurants in {city}"])

@tool
def calculate_tip(bill_amount: float, tip_percentage: float = 18.0) -> str:
"""Calculate tip and total for a bill."""
tip = bill_amount * (tip_percentage / 100)
return f"Tip: ${tip:.2f}, Total: ${bill_amount + tip:.2f}"

@tool
def get_directions(from_loc: str, to_loc: str) -> str:
"""Get directions between two locations."""
return f"Directions from {from_loc} to {to_loc}: Head north, turn right. ~15 min."

# Global sub-agents (initialized in main)
weather_agent = None
restaurant_agent = None
travel_agent = None

# Coordinator tools to delegate to sub-agents
@tool
def ask_weather_specialist(question: str) -> str:
"""Delegate weather-related questions to the weather specialist."""
return str(weather_agent(question))

@tool
def ask_restaurant_specialist(question: str) -> str:
"""Delegate restaurant/food questions to the restaurant specialist."""
return str(restaurant_agent(question))

@tool
def ask_travel_specialist(question: str) -> str:
"""Delegate travel/directions questions to the travel specialist."""
return str(travel_agent(question))

def main():
global weather_agent, restaurant_agent, travel_agent

print("🚀 Multi-Agent Coordinator Test\n")

config = AgentCoreMemoryConfig(
memory_id="CustomerSupport-pkP616GF9D",
session_id="coordinator-test-002",
actor_id="coordinator-user",
)

print("Creating session manager...")
session_manager = AgentCoreMemorySessionManager(
agentcore_memory_config=config,
region_name="us-west-2",
)
print(f"Root event ID: {session_manager._root_event_id}\n")

model = BedrockModel(model_id="anthropic.claude-3-haiku-20240307-v1:0")

# Create sub-agents
print("Creating sub-agents...")
weather_agent = Agent(
agent_id="weather",
system_prompt="You are a weather specialist. Use get_weather tool. Be concise.",
session_manager=session_manager,
model=model,
tools=[get_weather],
)

restaurant_agent = Agent(
agent_id="restaurant",
system_prompt="You are a restaurant specialist. Use search_restaurants and calculate_tip tools. Be concise.",
session_manager=session_manager,
model=model,
tools=[search_restaurants, calculate_tip],
)

travel_agent = Agent(
agent_id="travel",
system_prompt="You are a travel specialist. Use get_directions tool. Be concise.",
session_manager=session_manager,
model=model,
tools=[get_directions],
)

# Create coordinator
print("Creating coordinator agent...")
coordinator = Agent(
agent_id="coordinator",
system_prompt="""You are a helpful coordinator. Route requests to specialists:
- Weather questions → ask_weather_specialist
- Restaurant/food/tip questions → ask_restaurant_specialist
- Travel/directions questions → ask_travel_specialist
Always use the appropriate specialist tool. Be concise.""",
session_manager=session_manager,
model=model,
tools=[ask_weather_specialist, ask_restaurant_specialist, ask_travel_specialist],
)

print(f"Branches created: {session_manager._created_branches}\n")

# Interactive loop
print("="*60)
print("🤖 Coordinator Ready! (delegates to weather/restaurant/travel)")
print("="*60)
print("Examples: 'weather in Seattle', 'Italian restaurants in Seattle',")
print(" 'calculate 20% tip on $50', 'directions to downtown'")
print("Commands: 'branches', 'messages', 'quit'")
print("="*60 + "\n")

while True:
try:
user_input = input("You: ").strip()
if not user_input:
continue
if user_input.lower() == 'quit':
break
if user_input.lower() == 'branches':
print(f"Branches: {session_manager._created_branches}\n")
continue
if user_input.lower() == 'messages':
for aid in ["coordinator", "weather", "restaurant", "travel"]:
msgs = session_manager.list_messages(config.session_id, aid)
print(f" {aid}: {len(msgs)} messages")
print()
continue

response = coordinator(user_input)
print(f"Coordinator: {response}\n")

except KeyboardInterrupt:
break
except Exception as e:
print(f"Error: {e}\n")

print("Goodbye!")

if __name__ == "__main__":
main()
Loading