diff --git a/interactive_search/README.md b/interactive_search/README.md new file mode 100644 index 00000000..0a693484 --- /dev/null +++ b/interactive_search/README.md @@ -0,0 +1,26 @@ +# Script for Working with the Marketplace Search API + +This directory provides convenience scripts for interacting with the search API of the +Marketplace, and for running basic evaluations. + +### Interactive Search + +To run the interactive search script, use the following command: + +```bash +python interactive_search.py --data-dir ../data/mexican_3_9 +``` + +### Evaluation + +Evaluation is done by randomly sampling a menu item from EACH restaurant in the dataset, +and for each item, determining the rank of the FIRST restaurant that contains that item. + +From these ranks we compute [mean reciprocal rank](https://en.wikipedia.org/wiki/Mean_reciprocal_rank) (MRR) as the evaluation metric. + +Values closer to 1.0 are better, with 1.0 being perfect. + +To run the evaluation script, use the following command: +``` +python menu_mrr.py --data-dir ../data/mexican_3_9 +``` \ No newline at end of file diff --git a/interactive_search/interactive_search.py b/interactive_search/interactive_search.py new file mode 100644 index 00000000..538b09b0 --- /dev/null +++ b/interactive_search/interactive_search.py @@ -0,0 +1,85 @@ +"""A simple interactive search client for the agentic-economics marketplace.""" + +import argparse +import asyncio + +from search_launcher import SearchMarketLauncher + + +async def main( + data_dir: str, + postgres_host: str, + postgres_port: int, + postgres_password: str, + search_algorithm: str = "lexical", + show_all_searchable_text: bool = False, +) -> None: + """Run a simple interactive search client for the agentic-economics marketplace.""" + search_launcher = SearchMarketLauncher( + data_dir=data_dir, + postgres_host=postgres_host, + postgres_port=postgres_port, + postgres_password=postgres_password, + search_algorithm=search_algorithm, + ) + + async with search_launcher.start() as _: + # Start interactive search loop + await search_launcher.interactive_search( + show_all_searchable_text=show_all_searchable_text + ) + + # Shutdown + await search_launcher.stop() + + +if __name__ == "__main__": + # Add argument parsing for data dir using argparse + parser = argparse.ArgumentParser( + description="Interactive search client for the multi-agent-marketplace" + ) + parser.add_argument( + "--data-dir", help="Path to the dataset directory", required=True + ) + parser.add_argument( + "--search-algorithm", + default="lexical", + help="Search algorithm to use (default: lexical)", + ) + parser.add_argument( + "--postgres-host", + default="localhost", + help="PostgreSQL host (default: localhost)", + ) + + parser.add_argument( + "--postgres-port", + type=int, + default=5432, + help="PostgreSQL port (default: 5432)", + ) + + parser.add_argument( + "--postgres-password", + default="postgres", + help="PostgreSQL password (default: postgres)", + ) + + parser.add_argument( + "--show-all-searchable-text", + action="store_true", + help="Show all searchable text for searched businesses for debugging (default: False)", + ) + + args = parser.parse_args() + + asyncio.run( + main( + args.data_dir, + args.postgres_host, + args.postgres_port, + args.postgres_password, + args.search_algorithm, + args.show_all_searchable_text, + ) + ) diff --git a/interactive_search/menu_mrr.py b/interactive_search/menu_mrr.py new file mode 100644 index 00000000..ab1ba762 --- /dev/null +++ b/interactive_search/menu_mrr.py @@ -0,0 +1,126 @@ +"""Run search queries based on menu items and evaluate using Mean Reciprocal Rank (MRR).""" + +import argparse +import asyncio +import random +from pathlib import Path + +from magentic_marketplace.experiments.utils.yaml_loader import load_businesses_from_yaml +from magentic_marketplace.marketplace.shared.models import BusinessAgentProfile +from search_launcher import SearchMarketLauncher + + +def has_items(items: list[str], business: BusinessAgentProfile) -> bool: + """Check if the business has all the specified menu items.""" + menu_items = set(business.business.menu_features.keys()) + return set(items).issubset(menu_items) + + +async def main( + data_dir: str, + postgres_host: str, + postgres_port: int, + postgres_password: str, + order_size: int = 1, + search_algorithm: str = "lexical", +): + """Evaluate search functionality using Mean Reciprocal Rank (MRR) based on menu items.""" + # Start the search market server before running this script + search_launcher = SearchMarketLauncher( + data_dir=data_dir, + postgres_host=postgres_host, + postgres_port=postgres_port, + postgres_password=postgres_password, + search_algorithm=search_algorithm, + ) + async with search_launcher.start() as _: + reciprocal_ranks: list[float] = [] + businesses_dir = Path(args.data_dir) / "businesses" + + businesses = load_businesses_from_yaml(businesses_dir) + + for business in businesses: + menu_items = list(business.menu_features.keys()) + sampled_items = random.sample(menu_items, min(order_size, len(menu_items))) + + query = " ".join(sampled_items) + print(f"Searching for: {query}") + results = await search_launcher.search(query=query) + + found = False + rank = 0 + for rank in range(len(results)): + result = results[rank] + if has_items(sampled_items, result): + print( + f"Found matching business at rank {rank + 1}: {result.business.name}" + ) + found = True + break + if found: + reciprocal_ranks.append(1 / (rank + 1)) + else: + print(f"No matching business found in top {len(results)} results.") + reciprocal_ranks.append(1 / (len(results) + 1)) + print() + + print("--- Evaluation Complete ---") + print( + f"Mean Reciprocal Rank (MRR): {sum(reciprocal_ranks) / len(reciprocal_ranks) if reciprocal_ranks else 0:.4f}" + ) + + # Shutdown + print("Shutting down...") + await search_launcher.stop() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Compute Mean Reciprocal Rank (MRR) for search queries based on menu items" + ) + parser.add_argument( + "--data-dir", help="Path to the dataset directory", required=True + ) + parser.add_argument( + "--order-size", + type=int, + default=1, + help="Number of menu items to include in each search query (default: 1)", + ) + parser.add_argument( + "--search-algorithm", + type=str, + default="lexical", + help="Search algorithm to use (default: lexical)", + ) + parser.add_argument( + "--postgres-host", + default="localhost", + help="PostgreSQL host (default: localhost)", + ) + + parser.add_argument( + "--postgres-port", + type=int, + default=5432, + help="PostgreSQL port (default: 5432)", + ) + + parser.add_argument( + "--postgres-password", + default="postgres", + help="PostgreSQL password (default: postgres)", + ) + + args = parser.parse_args() + + asyncio.run( + main( + args.data_dir, + args.postgres_host, + args.postgres_port, + args.postgres_password, + args.order_size, + args.search_algorithm, + ) + ) diff --git a/interactive_search/search_launcher.py b/interactive_search/search_launcher.py new file mode 100644 index 00000000..db8f7e8a --- /dev/null +++ b/interactive_search/search_launcher.py @@ -0,0 +1,231 @@ +"""Module to launch a marketplace server with businesses and a single customer agent for interactive search experiments.""" + +import asyncio +import traceback +from contextlib import asynccontextmanager +from datetime import datetime +from pathlib import Path + +import requests +from magentic_marketplace.experiments.utils.yaml_loader import ( + load_businesses_from_yaml, + load_customers_from_yaml, +) +from magentic_marketplace.marketplace.actions.actions import ( + Search, + SearchAlgorithm, + SearchResponse, +) +from magentic_marketplace.marketplace.agents.business.agent import BusinessAgent +from magentic_marketplace.marketplace.agents.customer.agent import CustomerAgent +from magentic_marketplace.marketplace.protocol.protocol import SimpleMarketplaceProtocol +from magentic_marketplace.marketplace.shared.models import BusinessAgentProfile +from magentic_marketplace.platform.database.postgresql.postgresql import ( + connect_to_postgresql_database, +) +from magentic_marketplace.platform.launcher import MarketplaceLauncher + + +class SearchMarketLauncher: + """Class to manage launching the marketplace server for interactive search experiments with a set of business agents and a single customer agent.""" + + def __init__( + self, + data_dir: str, + postgres_host: str, + postgres_port: int, + postgres_password: str, + search_algorithm: str = "lexical", + ): + """Initialize the launcher with empty lists for agents and tasks.""" + self.business_agents = [] + self.customer_agent = None + self.tasks = [] + self.marketplace_launcher = None + + # Get the SearchAlgorithm enum value from the string value provided + if search_algorithm.lower() == "lexical": + self.search_algorithm = SearchAlgorithm.LEXICAL + elif search_algorithm.lower() == "optimal": + self.search_algorithm = SearchAlgorithm.OPTIMAL + elif search_algorithm.lower() == "filtered": + self.search_algorithm = SearchAlgorithm.FILTERED + elif search_algorithm.lower() == "simple": + self.search_algorithm = SearchAlgorithm.SIMPLE + elif search_algorithm.lower() == "rnr": + self.search_algorithm = SearchAlgorithm.RNR + else: + raise ValueError(f"Invalid search algorithm: {search_algorithm}") + + self.search_algorithm = search_algorithm + + self.business_profiles = [] + self.customer_profiles = [] + + self.load_data( + data_dir=data_dir, + postgres_host=postgres_host, + postgres_port=postgres_port, + postgres_password=postgres_password, + ) + + def load_data( + self, + data_dir: str, + postgres_host: str, + postgres_port: int, + postgres_password: str, + ): + """Load businesses and customers from YAML files.""" + businesses_dir = Path(data_dir) / "businesses" + customers_dir = Path(data_dir) / "customers" + + print(f"Loading data from: {data_dir}") + self.business_profiles = load_businesses_from_yaml(businesses_dir) + self.customer_profiles = load_customers_from_yaml(customers_dir) + + print(f"Loaded {len(self.business_profiles)} businesses") + print(f"Loaded {len(self.customer_profiles)} customers") + + experiment_name = f"marketplace_interactive_search_{len(self.business_profiles)}_{int(datetime.now().timestamp() * 1000)}" + + def database_factory(): + return connect_to_postgresql_database( + schema=experiment_name, + host=postgres_host, + port=postgres_port, + password=postgres_password, + mode="create_new", + ) + + self.marketplace_launcher = MarketplaceLauncher( + protocol=SimpleMarketplaceProtocol(), + database_factory=database_factory, + server_log_level="warning", + experiment_name=experiment_name, + ) + + print( + f"Using protocol: {self.marketplace_launcher.protocol.__class__.__name__}" + ) + + @asynccontextmanager + async def start(self): + """Startup platform, businesses, and customer task.""" + async with self.marketplace_launcher: + print( + f"Marketplace server running at: {self.marketplace_launcher.server_url}" + ) + + # Create agents from loaded profiles + business_agents = [ + BusinessAgent(business, self.marketplace_launcher.server_url) + for business in self.business_profiles + ] + self.business_agents.extend(business_agents) + + # only create one customer agent for interactive search + customer_agent = CustomerAgent( + self.customer_profiles[0], + self.marketplace_launcher.server_url, + search_algorithm=self.search_algorithm, + ) + self.customer_agent = customer_agent + + # Create agent launcher and run agents with dependency management + try: + # Startup business agents tasks only + primary_tasks = [ + asyncio.create_task(agent.run()) for agent in business_agents + ] + self.tasks.extend(primary_tasks) + print(f"Started {len(primary_tasks)} tasks for business agents") + + # Startup customer agent task + customer_task = asyncio.create_task(customer_agent.run()) + self.tasks.append(customer_task) + + print("Started task for customer agent") + + await asyncio.sleep(1) + except KeyboardInterrupt: + print("Marketplace interrupted by user") + + yield self.marketplace_launcher + + async def stop(self): + """Stop all running tasks and agents.""" + for agent in self.business_agents: + agent.shutdown() + + if self.customer_agent: + self.customer_agent.shutdown() + + await asyncio.sleep(0.1) + + await asyncio.gather(*self.tasks) + + await asyncio.sleep(0.2) + + async def search(self, query) -> list[BusinessAgentProfile]: + """Issue search queries using the customer agent and return the resulting business profiles.""" + try: + response = await self.customer_agent.execute_action( + Search( + query=query, + search_algorithm=self.search_algorithm, + limit=10, + page=1, + ) + ) + if response.is_error: + print(f"Search action failed: {response.error_message}") + return + + parsed_response = SearchResponse.model_validate(response.content) + businesses_results = parsed_response.businesses + + # Print results + return businesses_results + + except requests.RequestException as e: + print(f"Request failed: {e}") + traceback.print_exc() + + async def interactive_search(self, show_all_searchable_text: bool = False): + """Issue search queries interactively from the customer agent.""" + while True: + query = input("Query (or 'exit' to quit): ") + if query.lower() == "exit": + break + try: + print(f"Searching for: {query}") + response = await self.customer_agent.execute_action( + Search( + query=query, + search_algorithm=self.search_algorithm, + limit=10, + page=1, + ) + ) + if response.is_error: + print(f"Search action failed: {response.error_message}") + continue + + print("Search action succeeded") + parsed_response = SearchResponse.model_validate(response.content) + businesses_results = parsed_response.businesses + print(f"Found {len(businesses_results)} businesses:") + + # Print results + for b in businesses_results: + if show_all_searchable_text: + print( + f"- {b.business.name}: {b.business.get_searchable_text()}" + ) + else: + print(f"- {b.business.name}") + + except requests.RequestException as e: + print(f"Request failed: {e}") + traceback.print_exc()