Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions packages/magentic-marketplace/src/magentic_marketplace/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ def run_experiment_command(args):
export_sqlite=args.export,
export_dir=args.export_dir,
export_filename=args.export_filename,
drop_empty_fetch_messages=args.drop_empty_fetch_messages,
drop_all_fetch_messages=args.drop_all_fetch_messages,
)
)

Expand Down Expand Up @@ -297,6 +299,19 @@ def main():
help="Output filename for SQLite export (default: <experiment_name>.db). Only used with --export.",
)

# Add mutually exclusive group for fetch messages persistence options
fetch_messages_group = experiment_parser.add_mutually_exclusive_group()
fetch_messages_group.add_argument(
"--drop-empty-fetch-messages",
action="store_true",
help="Don't save empty FetchMessages actions to the database (saves only non-empty fetches).",
)
fetch_messages_group.add_argument(
"--drop-all-fetch-messages",
action="store_true",
help="Don't save any FetchMessages actions to the database.",
)

# analytics subcommand
analytics_parser = subparsers.add_parser(
"analyze", help="Analyze marketplace simulation data"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
load_customers_from_yaml,
)
from magentic_marketplace.marketplace.agents import BusinessAgent, CustomerAgent
from magentic_marketplace.marketplace.protocol.protocol import SimpleMarketplaceProtocol
from magentic_marketplace.marketplace.protocol.protocol import (
FetchMessagesPersistence,
SimpleMarketplaceProtocol,
)
from magentic_marketplace.platform.database import (
connect_to_postgresql_database,
)
Expand All @@ -35,6 +38,8 @@ async def run_marketplace_experiment(
export_sqlite: bool = False,
export_dir: str | None = None,
export_filename: str | None = None,
drop_empty_fetch_messages: bool = False,
drop_all_fetch_messages: bool = False,
):
"""Run a marketplace experiment using YAML configuration files."""
# Load businesses and customers from YAML files
Expand Down Expand Up @@ -69,8 +74,18 @@ def database_factory():
server_port = s.getsockname()[1]
print(f"Auto-assigned server port: {server_port}")

# Determine fetch messages persistence mode based on CLI flags
if drop_all_fetch_messages:
fetch_messages_persistence = FetchMessagesPersistence.NONE
elif drop_empty_fetch_messages:
fetch_messages_persistence = FetchMessagesPersistence.NON_EMPTY
else:
fetch_messages_persistence = FetchMessagesPersistence.ALL

marketplace_launcher = MarketplaceLauncher(
protocol=SimpleMarketplaceProtocol(),
protocol=SimpleMarketplaceProtocol(
fetch_messages_persistence=fetch_messages_persistence
),
database_factory=database_factory,
host=server_host,
port=server_port,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""FetchMessages action implementation for the simple marketplace."""

from enum import Enum

from magentic_marketplace.platform.database.base import BaseDatabaseController
from magentic_marketplace.platform.database.models import ActionRow
from magentic_marketplace.platform.database.queries.base import (
Expand All @@ -15,11 +17,21 @@
from ..database import queries


class FetchMessagesPersistence(str, Enum):
"""Enum for controlling FetchMessages action persistence."""

ALL = "all" # Save all fetch messages actions
NON_EMPTY = "non_empty" # Save only non-empty fetch messages
NONE = "none" # Don't save any fetch messages actions


async def execute_fetch_messages(
fetch_messages: FetchMessages,
agent: AgentProfile,
database: BaseDatabaseController,
) -> ActionExecutionResult:
agent_last_fetch_messages_count: dict[str, int],
persistence: FetchMessagesPersistence,
) -> tuple[ActionExecutionResult, bool]:
"""Execute a fetch messages action.

This function implements the message fetching functionality that was previously
Expand All @@ -29,21 +41,35 @@ async def execute_fetch_messages(
fetch_messages: The fetch messages action containing query parameters
agent: The agent fetching messages
database: Database controller for accessing data
agent_last_fetch_messages_count: Dictionary mapping agent ids to the length of the messages returned by the last fetch
persistence: Strategy for when to persist FetchMessages actions

Returns:
ActionExecutionResult containing the fetched messages
A tuple of (ActionExecutionResult, bool) where:
- ActionExecutionResult contains the fetched messages
- bool indicates whether the action should be persisted

"""
messages, has_more = await _fetch_messages_from_database(
fetch_messages, agent, database
)

response = FetchMessagesResponse(
messages=messages,
has_more=has_more,
)
last_messages_count = agent_last_fetch_messages_count[agent.id]
agent_last_fetch_messages_count[agent.id] = len(messages)

content = FetchMessagesResponse(messages=messages, has_more=has_more)
result = ActionExecutionResult(content=content.model_dump(mode="json"))

# Determine if we should persist this FetchMessages action based on the persistence mode
if persistence == FetchMessagesPersistence.ALL:
persist = True
elif persistence == FetchMessagesPersistence.NON_EMPTY:
# Are there more messages now than before?
persist = last_messages_count < len(messages)
else: # FetchMessagesPersistence.NONE
persist = False

return ActionExecutionResult(content=response.model_dump(mode="json"))
return result, persist


async def _fetch_messages_from_database(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Simple marketplace protocol implementation."""

from collections import defaultdict

from magentic_marketplace.platform.database.base import BaseDatabaseController
from magentic_marketplace.platform.protocol.base import BaseMarketplaceProtocol
from magentic_marketplace.platform.shared.models import (
Expand All @@ -14,16 +16,30 @@
Search,
SendMessage,
)
from .fetch_messages import execute_fetch_messages
from .fetch_messages import FetchMessagesPersistence, execute_fetch_messages
from .search import execute_search
from .send_message import execute_send_message


class SimpleMarketplaceProtocol(BaseMarketplaceProtocol):
"""Marketplace protocol."""

def __init__(self):
"""Initialize the marketplace protocol."""
def __init__(
self,
fetch_messages_persistence: FetchMessagesPersistence = FetchMessagesPersistence.ALL,
):
"""Initialize the marketplace protocol.

Args:
fetch_messages_persistence: Controls which FetchMessages actions are persisted to database.
- ALL (default): Save all FetchMessages actions
- NON_EMPTY: Save only FetchMessages that returned messages
- NONE: Don't save any FetchMessages actions

"""
self.fetch_messages_persistence = fetch_messages_persistence
# Track how many messages were fetched by an agent in the last count, use it to determine if "new" messages were provided or not
self._to_agent_id_last_fetch_messages_count: dict[str, int] = defaultdict(int)

def get_actions(self):
"""Define available actions in the marketplace."""
Expand All @@ -35,19 +51,32 @@ async def execute_action(
agent: AgentProfile,
action: ActionExecutionRequest,
database: BaseDatabaseController,
) -> ActionExecutionResult:
"""Execute an action."""
) -> tuple[ActionExecutionResult, bool]:
"""Execute an action.

Returns:
A tuple of (ActionExecutionResult, bool) where:
- ActionExecutionResult contains the action execution result
- bool indicates whether the action should be persisted to the database

"""
parsed_action = ActionAdapter.validate_python(action.parameters)

if isinstance(parsed_action, SendMessage):
return await execute_send_message(parsed_action, database)
return await execute_send_message(parsed_action, database), True

elif isinstance(parsed_action, FetchMessages):
return await execute_fetch_messages(parsed_action, agent, database)
return await execute_fetch_messages(
parsed_action,
agent,
database,
self._to_agent_id_last_fetch_messages_count,
self.fetch_messages_persistence,
)

elif isinstance(parsed_action, Search):
return await execute_search(
search=parsed_action, agent=agent, database=database
)
), True
else:
raise ValueError(f"Unknown action type: {parsed_action.type}")
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,13 @@ async def execute_action(
agent: AgentProfile,
action: ActionExecutionRequest,
database: BaseDatabaseController,
) -> ActionExecutionResult:
"""Execute a specific action with the given name and parameters."""
) -> tuple[ActionExecutionResult, bool]:
"""Execute a specific action with the given name and parameters.

Returns:
A tuple of (ActionExecutionResult, bool) where:
- ActionExecutionResult contains the action execution result
- bool indicates whether the action should be persisted to the database

"""
...
Original file line number Diff line number Diff line change
Expand Up @@ -54,22 +54,23 @@ async def execute_action(
agent_with_id = authenticated_agent.data.model_copy()
agent_with_id.id = authenticated_agent.id # TODO why is this necessary?

result = await protocol.execute_action(
result, persist = await protocol.execute_action(
agent=agent_with_id,
action=request,
database=db,
)

action_data = ActionRowData(
agent_id=authenticated_agent.id, request=request, result=result
)
db_action = ActionRow(
id="", # auto-generated by DB
created_at=datetime.now(UTC),
data=action_data,
)
if persist:
action_data = ActionRowData(
agent_id=authenticated_agent.id, request=request, result=result
)
db_action = ActionRow(
id="", # auto-generated by DB
created_at=datetime.now(UTC),
data=action_data,
)

await db.actions.create(db_action)
await db.actions.create(db_action)

return result

Expand Down