diff --git a/.gitignore b/.gitignore
index f6ea22a78..e5ecee28f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -167,3 +167,8 @@ redis-data/*
sotopia/cli/install/redis-data/*
redis-stack-server-*/
examples/experimental/negotiation_arena/redis-data/*
+*.rdb
+*.dot
+logs/*
+experiments/rosters/*
+experiments/results/*
diff --git a/docs/pages/experimental/_meta.json b/docs/pages/experimental/_meta.json
index 68a2f45e1..89121eb91 100644
--- a/docs/pages/experimental/_meta.json
+++ b/docs/pages/experimental/_meta.json
@@ -4,5 +4,8 @@
},
"agents": {
"title": "Agents"
+ },
+ "social_game": {
+ "title": "Social Game Engine"
}
}
diff --git a/docs/pages/experimental/index.mdx b/docs/pages/experimental/index.mdx
index 7bcbe8fd3..e00c55fbf 100644
--- a/docs/pages/experimental/index.mdx
+++ b/docs/pages/experimental/index.mdx
@@ -33,6 +33,7 @@ The experimental APIs are in different states:
Here are the experimental APIs:
- [Agents](/experimental/agents) (*implemented*): aact-based asynchronous agents that don't follow OpenAI Gym's turn-based formulation.
+- [Social Game Engine](/experimental/social_game) (*implemented*): Engine for complex multi-agent social deduction games (e.g., Werewolves).
- Engines (*planned*): aact-based asynchronous environment engines. This would include
- [Orchestrator](https://github.com/sotopia-lab/sotopia/issues/231): an engine base class for engines that dictates the orders and turns of the agents.
- [Evaluator](https://github.com/sotopia-lab/sotopia/issues/232): an engine base class for engines that evaluates the agents' performance.
diff --git a/docs/pages/experimental/social_game.mdx b/docs/pages/experimental/social_game.mdx
new file mode 100644
index 000000000..a135fec4b
--- /dev/null
+++ b/docs/pages/experimental/social_game.mdx
@@ -0,0 +1,42 @@
+# Social Game Engine
+
+The Social Game Engine is a new experimental module in Sotopia designed for creating complex, multi-agent social simulations with structured phases, roles, and secret information.
+
+## Overview
+
+Unlike standard dyadic interactions, social games often involve:
+- **Multiple Agents**: More than 2 agents interacting simultaneously.
+- **Roles & Teams**: Agents have distinct roles (e.g., Villager, Werewolf) and conflicting goals.
+- **Dynamic Eras/Phases**: Games progress through distinct states (e.g., Day Discussion, Night Action).
+- **Private Information**: Agents have secrets and limited visibility of others' actions.
+
+## Key Classes
+
+### `SocialGame`
+The abstract base class for any multiplayer social game. It handles:
+- **Turn Management**: Supports `round-robin` (sequential) or `simultaneous` (parallel) actions.
+- **State Transitions**: Manages the flow of the game through defined states.
+
+### `SocialDeductionGame`
+A subclass specialized for games like Werewolves, Undercover, or Spyfall.
+- **Action Masking**: Enforces who can act in which phase.
+- **Visibility System**: Controls who sees what messages (Public, Team-only, Private).
+- **Environment Notifications**: Automatically broadcasts state changes to all agents, ensuring valid context even during private phases.
+
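+As a concrete, minimal sketch: a custom game typically pairs a `SocialDeductionGame` subclass with an `ActionHandler` that interprets raw agent actions and supplies phase-specific instructions. The snippet below mirrors the example games under `examples/experimental/games`; `MyVoteHandler`, `MyGameEnv`, and the `Round_vote` state name are illustrative placeholders, not part of the library API.
+
+```python
+from typing import Any
+
+from sotopia.envs.social_game import ActionHandler, SocialDeductionGame
+from sotopia.messages import AgentAction
+
+
+class MyVoteHandler(ActionHandler):
+    def handle_action(
+        self, env: SocialDeductionGame, agent_name: str, action: AgentAction
+    ) -> None:
+        # Record votes cast during the (illustrative) voting phase.
+        if env.current_state == "Round_vote" and action.action_type == "action":
+            env.internal_state.setdefault("votes", {})[agent_name] = action.argument
+
+    def get_action_instruction(self, env: SocialDeductionGame, agent_name: str) -> str:
+        # Phase-specific instruction shown to the acting agent.
+        if env.current_state == "Round_vote":
+            return "It is voting time. Use 'vote NAME' to vote for a player."
+        return ""
+
+
+class MyGameEnv(SocialDeductionGame):
+    def __init__(self, **kwargs: Any) -> None:
+        super().__init__(action_handler=MyVoteHandler(), **kwargs)
+```
+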
+## Example: Duskmire Werewolves
+
+We have implemented a full working example of a 6-player Werewolves game using this engine.
+
+
+The code lives in `examples/experimental/werewolves`.
+See the [README](https://github.com/sotopia-lab/sotopia/tree/main/examples/experimental/werewolves/README.md) in that directory for details on how to run it.
+
+
+### Key Features Demonstrated
+- **Sequential Discussion**: Agents speak one-by-one during the day, referencing previous speakers.
+- **Hidden Roles**: Role information is concealed upon elimination.
+- **Complex Logic**: Seer inspections, Witch potions, and Werewolf voting integration.
+
+## Usage
+
+To use the Social Game Engine, you typically subclass `SocialDeductionGame`, define your game states and transitions in a `config.json`, and wire up an `EnvironmentProfile`, evaluators, and LLM agents for your scenario.
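+
+The sketch below condenses the wiring used by the example launchers under `examples/experimental/games` (load the config, save an `EnvironmentProfile`, build the environment and `LLMAgent`s, then run one episode). `MyGameEnv` stands in for the illustrative subclass sketched above; the agent names, goals, and model names are placeholders, and a running Redis instance is assumed for the `.save()` calls and episode logging.
+
+```python
+import asyncio
+
+from sotopia.agents import LLMAgent
+from sotopia.database.persistent_profile import (
+    AgentProfile,
+    EnvironmentProfile,
+    RelationshipType,
+)
+from sotopia.envs.evaluators import SocialGameEndEvaluator
+from sotopia.envs.social_game import SocialDeductionGame, load_config
+from sotopia.server import arun_one_episode
+
+
+class MyGameEnv(SocialDeductionGame):
+    """Stand-in for the subclass sketched in the Key Classes section."""
+
+
+# Game states, transitions, and per-state action rules live in config.json.
+config = load_config("config.json")
+
+# One shared scenario profile; agent_goals are indexed by agent order.
+env_profile = EnvironmentProfile(
+    scenario=config["description"].format(agent_names="Alice, Bob"),
+    relationship=RelationshipType.acquaintance,
+    agent_goals=["Win the game.", "Win the game."],  # placeholder goals
+    tag="my_game",
+)
+env_profile.save()
+
+env = MyGameEnv(
+    env_profile=env_profile,
+    config=config,
+    model_name="gpt-4o",
+    evaluators=[SocialGameEndEvaluator(max_turn_number=20)],
+    terminal_evaluators=[],
+    hide_unknown=True,
+)
+
+agents = []
+for idx, name in enumerate(["Alice", "Bob"]):
+    profile = AgentProfile(first_name=name, last_name="")
+    profile.save()
+    agent = LLMAgent(
+        agent_name=name,
+        agent_profile=profile,
+        model_name="gpt-4o-mini",
+        strict_action_constraint=True,
+    )
+    agent.goal = env_profile.agent_goals[idx]
+    agents.append(agent)
+
+asyncio.run(
+    arun_one_episode(
+        env=env,
+        agent_list=agents,
+        omniscient=False,
+        script_like=False,
+        json_in_script=False,
+        tag="my_game",
+        push_to_db=True,
+    )
+)
+```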
diff --git a/examples/experimental/games/__init__.py b/examples/experimental/games/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/examples/experimental/games/prisoners_dilemma/__init__.py b/examples/experimental/games/prisoners_dilemma/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/examples/experimental/games/prisoners_dilemma/config.json b/examples/experimental/games/prisoners_dilemma/config.json
new file mode 100644
index 000000000..0b664338d
--- /dev/null
+++ b/examples/experimental/games/prisoners_dilemma/config.json
@@ -0,0 +1,45 @@
+{
+ "scenario": "Prisoner's Dilemma",
+ "description": "Two players must secretly choose to Cooperate or Defect. You play exactly 5 rounds. You cannot speak to the other player. You can only choose your action. The payoffs depend on the combination of choices. COOPERATE/COOPERATE: 3 points each. DEFECT/COOPERATE: Defector gets 5, Cooperator gets 0. DEFECT/DEFECT: 1 point each. GOAL: Maximize your own score over the 5 rounds. IMPORTANT: If both players score less than 10 points in total, it is a DRAW (0 points). Otherwise, the player with the higher score wins and the other loses. \n\n The players in this game are: {agent_names}.",
+ "role_goals": {
+ "Player": "Maximize your own score."
+ },
+ "role_secrets": {
+ "Player": "You can choose to Cooperate or Defect."
+ },
+ "payoff_matrix": {
+ "Cooperate": {
+ "Cooperate": [
+ 3,
+ 3
+ ],
+ "Defect": [
+ 0,
+ 5
+ ]
+ },
+ "Defect": {
+ "Cooperate": [
+ 5,
+ 0
+ ],
+ "Defect": [
+ 1,
+ 1
+ ]
+ }
+ },
+ "initial_state": "Next Round",
+ "state_transition": {
+ "Next Round": "Next Round"
+ },
+ "state_properties": {
+ "Next Round": {
+ "actions": [
+ "action"
+ ],
+ "action_order": "simultaneous",
+ "visibility": "public"
+ }
+ }
+}
diff --git a/examples/experimental/games/prisoners_dilemma/main.py b/examples/experimental/games/prisoners_dilemma/main.py
new file mode 100644
index 000000000..b77f05976
--- /dev/null
+++ b/examples/experimental/games/prisoners_dilemma/main.py
@@ -0,0 +1,461 @@
+"""Launcher for the Prisoner's Dilemma game."""
+
+from __future__ import annotations
+
+import asyncio
+import logging
+import os
+from pathlib import Path
+from typing import Any, Dict, List, Tuple
+
+import redis
+from rich.logging import RichHandler
+
+from sotopia.agents import LLMAgent
+from sotopia.agents.llm_agent import Agents
+from sotopia.database.persistent_profile import (
+ AgentProfile,
+ EnvironmentProfile,
+ RelationshipType,
+)
+from sotopia.envs.evaluators import SocialGameEndEvaluator
+from sotopia.envs.social_game import (
+ ActionHandler,
+ SocialDeductionGame,
+ SOCIAL_GAME_PROMPT_TEMPLATE,
+ load_config,
+)
+from sotopia.messages import AgentAction, Message, Observation, SimpleMessage
+from sotopia.server import arun_one_episode
+
+BASE_DIR = Path(__file__).resolve().parent
+CONFIG_PATH = BASE_DIR / "config.json"
+
+os.environ.setdefault("REDIS_OM_URL", "redis://:@localhost:6379")
+redis.Redis(host="localhost", port=6379)
+
+# Loggers (configured in main if running standalone)
+_gen_logger = logging.getLogger("sotopia.generation")
+_env_logger = logging.getLogger("sotopia.envs.social_game")
+
+
+# ============================================================================
+# Evaluator
+# ============================================================================
+
+
+class PrisonersDilemmaEvaluator(SocialGameEndEvaluator):
+ def __call__(
+ self, turn_number: int, messages: List[Tuple[str, Message]], **kwargs: Any
+ ) -> List[Tuple[str, Tuple[Tuple[str, int | float | bool], str]]]:
+ if turn_number >= self.max_turn_number:
+ env = kwargs.get("env")
+ if env:
+ scores = env.internal_state.get("scores", {})
+ response: List[
+ Tuple[str, Tuple[Tuple[str, int | float | bool], str]]
+ ] = [("environment", (("terminated", True), "Max (5) turns reached"))]
+
+ agent_names = list(env.agents)
+ scores_dict = {name: scores.get(name, 0) for name in agent_names}
+
+ # Logic:
+ # 1. If both < 10 -> Draw (0)
+ # 2. Otherwise -> Highest score wins (1), Loser (-1)
+ # 3. Tie >= 10 -> Draw (0)
+
+ values = list(scores_dict.values())
+ if len(values) == 2:
+ s1, s2 = values[0], values[1]
+ n1, n2 = agent_names[0], agent_names[1]
+
+ rewards = {n1: 0.0, n2: 0.0}
+
+ if s1 < 10 and s2 < 10:
+ # Draw (both failed threshold)
+ pass
+ elif s1 > s2:
+ rewards[n1] = 1.0
+ rewards[n2] = -1.0
+ elif s2 > s1:
+ rewards[n1] = -1.0
+ rewards[n2] = 1.0
+ else:
+ # Tie and at least one >= 10 (which implies both >= 10)
+ pass
+
+ for agent_name in agent_names:
+ try:
+ idx = agent_names.index(agent_name)
+ key = f"agent_{idx+1}"
+ raw_score = scores_dict[agent_name]
+ reward = rewards[agent_name]
+ response.append(
+ (
+ key,
+ (
+ ("complete_rating", reward),
+ f"Final Score: {raw_score}",
+ ),
+ )
+ )
+ except ValueError:
+ continue
+                    return response
+
+                # Fallback for != 2 agents
+ for agent_name in agent_names:
+ score = scores.get(agent_name, 0)
+ idx = agent_names.index(agent_name)
+ key = f"agent_{idx+1}"
+ response.append(
+ (key, (("complete_rating", score), f"Final Score: {score}"))
+ )
+                return response
+ return [("environment", (("terminated", True), "Max turns reached"))]
+
+ return [("environment", (("terminated", False), ""))]
+
+
+# ============================================================================
+# Action Handler
+# ============================================================================
+
+
+class PrisonersDilemmaActionHandler(ActionHandler):
+ def handle_action(
+ self, env: SocialDeductionGame, agent_name: str, action: AgentAction
+ ) -> None:
+ if isinstance(env, PrisonersDilemmaEnv):
+ if action.action_type in ["action", "speak"]:
+ move = action.argument.lower()
+ current_move = None
+ if "defect" in move:
+ current_move = "Defect"
+ elif "cooperate" in move:
+ current_move = "Cooperate"
+
+ if current_move:
+ env.internal_state["current_moves"][agent_name] = current_move
+
+ def get_action_instruction(self, env: SocialDeductionGame, agent_name: str) -> str:
+        return "You are playing Prisoner's Dilemma. Choose 'action: cooperate' or 'action: defect'. You cannot speak."
+
+
+# ============================================================================
+# Environment
+# ============================================================================
+
+
+class PrisonersDilemmaEnv(SocialDeductionGame):
+ def __init__(self, **kwargs: Any) -> None:
+ super().__init__(action_handler=PrisonersDilemmaActionHandler(), **kwargs)
+ self.internal_state: Dict[str, Any] = {
+ "round": 0,
+ "scores": {},
+ "current_moves": {},
+ }
+ self.payoff_matrix = self._config.get("payoff_matrix", {})
+
+ def reset(
+ self,
+ seed: int | None = None,
+ options: Dict[str, str] | None = None,
+ agents: Agents | None = None,
+ omniscient: bool = False,
+ lite: bool = False,
+ include_background_observations: bool = True,
+ ) -> Dict[str, Observation]:
+ obs = super().reset(
+ seed=seed,
+ options=options,
+ agents=agents,
+ omniscient=omniscient,
+ lite=lite,
+ include_background_observations=include_background_observations,
+ )
+ self.internal_state = {"round": 0, "scores": {}, "current_moves": {}}
+ for agent in self.agents:
+ self.internal_state["scores"][agent] = 0
+ return obs
+
+ async def astep(
+ self, actions: Dict[str, AgentAction] | Dict[str, Dict[str, int | str]]
+ ) -> Tuple[
+ Dict[str, Observation],
+ Dict[str, float],
+ Dict[str, bool],
+ Dict[str, bool],
+ Dict[str, Dict[Any, Any]],
+ ]:
+ # Filter for AgentAction
+ valid_actions: Dict[str, AgentAction] = {
+ k: v for k, v in actions.items() if isinstance(v, AgentAction)
+ }
+
+ # Process actions
+ if self.action_handler:
+ for agent_name, action in valid_actions.items():
+ self.action_handler.handle_action(self, agent_name, action)
+
+ # Check if everyone moved
+ # We need to access internal_state safely
+ moves = self.internal_state.get("current_moves", {})
+ scores = self.internal_state.get("scores", {})
+
+ if len(moves) == len(self.agents) and len(self.agents) == 2:
+ agents = list(self.agents)
+ a1, a2 = agents[0], agents[1]
+ m1 = moves.get(a1, "Cooperate")
+ m2 = moves.get(a2, "Cooperate")
+
+ # Use loaded payoff matrix
+ try:
+ payoffs = self.payoff_matrix[m1][m2]
+ r1, r2 = payoffs[0], payoffs[1]
+ except KeyError:
+                # Fallback: award no points this round if the matrix is missing or malformed
+ r1, r2 = 0, 0
+
+ # Update scores
+ scores[a1] = scores.get(a1, 0) + r1
+ scores[a2] = scores.get(a2, 0) + r2
+ self.internal_state["scores"] = scores
+
+ msg = f"Round {self.internal_state.get('round', 0)+1} Results:\n{a1} chose {m1}, gets {r1}.\n{a2} chose {m2}, gets {r2}.\nTotal Scores: {scores}"
+ self.recv_message("Environment", SimpleMessage(message=msg))
+
+ self.internal_state["round"] = self.internal_state.get("round", 0) + 1
+ self.internal_state["current_moves"] = {}
+
+ return await super().astep(actions)
+
+
+# ============================================================================
+# Setup helpers
+# ============================================================================
+
+
+def ensure_agent_profile(config: Dict[str, Any]) -> AgentProfile:
+ """Create or retrieve agent profile."""
+ name = config.get("name", "")
+ role = config.get("role", "")
+
+ first_name, _, last_name = name.partition(" ")
+ if not last_name:
+ last_name = ""
+
+ # Try to find existing
+ try:
+ existing = AgentProfile.find(
+ (AgentProfile.first_name == first_name)
+ & (AgentProfile.last_name == last_name)
+ ).all()
+ if existing:
+ return AgentProfile.get(existing[0].pk)
+ except Exception:
+ pass
+
+ # Create new
+ role_secret = config.get("role_secrets", {}).get(role, "")
+ profile = AgentProfile(
+ first_name=first_name,
+ last_name=last_name,
+ secret=role_secret,
+ )
+ profile.save()
+ return profile
+
+
+def create_environment(
+ env_profile: EnvironmentProfile, model_name: str, config: Dict[str, Any]
+) -> PrisonersDilemmaEnv:
+ """Create PD game environment."""
+ return PrisonersDilemmaEnv(
+ env_profile=env_profile,
+ config=config,
+ model_name=model_name,
+ evaluators=[PrisonersDilemmaEvaluator(max_turn_number=5)],
+ terminal_evaluators=[],
+ hide_unknown=True,
+ )
+
+
+def create_agents(
+ agent_profiles: List[AgentProfile],
+ env_profile: EnvironmentProfile,
+ model_name: str | Dict[str, str] | List[str],
+ config: Dict[str, Any],
+) -> List[LLMAgent]:
+ """Create LLM agents."""
+ agents = []
+ for idx, profile in enumerate(agent_profiles):
+ agent_name = f"{profile.first_name}{' ' + profile.last_name if profile.last_name else ''}"
+ role_goal = env_profile.agent_goals[idx]
+
+ # Get secret based on role from config, matching profile index
+ # Assumption: agents list in config matches profile order
+ role = config.get("agents", [])[idx].get("role", "")
+ secrets = config.get("role_secrets", {}).get(role, "")
+
+ filled_template = (
+ SOCIAL_GAME_PROMPT_TEMPLATE.replace("{description}", env_profile.scenario)
+ .replace("{secret}", f"Your secret info: {secrets}")
+ .replace("{goal}", role_goal)
+ )
+
+ if isinstance(model_name, dict):
+ this_agent_model = model_name.get(
+ agent_name, model_name.get("default", "gpt-4")
+ )
+ elif isinstance(model_name, list):
+ this_agent_model = model_name[idx]
+ else:
+ this_agent_model = model_name
+
+ agent = LLMAgent(
+ agent_name=agent_name,
+ agent_profile=profile,
+ model_name=this_agent_model,
+ strict_action_constraint=True,
+ custom_template=filled_template,
+ )
+ agent.goal = role_goal
+ agents.append(agent)
+ return agents
+
+
+def prepare_scenario(
+ env_model_name: str,
+ agent_model_name: str | Dict[str, str] | List[str],
+ config: Dict[str, Any] | None = None,
+) -> tuple[SocialDeductionGame, List[LLMAgent]]:
+ """Load config and create profiles."""
+ if config is None:
+ config = load_config(CONFIG_PATH)
+
+ # Create agent profiles
+ agent_profiles = []
+ agent_goals = []
+ for entry in config.get("agents", []):
+ profile = ensure_agent_profile(entry)
+ agent_profiles.append(profile)
+
+ role_goal = config.get("role_goals", {}).get(entry.get("role", ""), "")
+ agent_goals.append(role_goal)
+
+ # Create environment profile
+ agent_names = [entry.get("name", "") for entry in config.get("agents", [])]
+ scenario = config.get("description", "Prisoner's Dilemma").format(
+ agent_names=", ".join(agent_names)
+ )
+ env_profile = EnvironmentProfile(
+ scenario=scenario,
+ relationship=RelationshipType.acquaintance,
+ agent_goals=agent_goals,
+ tag="pd",
+ )
+ env_profile.save()
+
+ env = create_environment(env_profile, env_model_name, config)
+ agents = create_agents(agent_profiles, env_profile, agent_model_name, config)
+ return env, agents
+
+
+def print_roster(config: Dict[str, Any]) -> None:
+ """Print game roster."""
+ print("Participants & roles:")
+ for entry in config.get("agents", []):
+ name = entry.get("name", "Unknown")
+ role = entry.get("role", "Unknown")
+ print(f" - {name}: {role}")
+
+
+# ============================================================================
+# Main
+# ============================================================================
+
+
+def get_model_names(config: Dict[str, Any]) -> Dict[str, str]:
+ """Extract model names from config. Enforces strict requirement."""
+ model_map = {}
+ for entry in config.get("agents", []):
+ name = entry.get("name")
+ model = entry.get("agent_model")
+ if not name:
+ continue
+ if not model:
+ raise ValueError(
+                f"Agent '{name}' missing 'agent_model' in roster configuration."
+ )
+ model_map[name] = model
+ return model_map
+
+
+async def main() -> None:
+ """Run Prisoner's Dilemma game."""
+ import argparse
+
+ parser = argparse.ArgumentParser(description="Run Prisoner's Dilemma game.")
+ parser.add_argument(
+ "--config",
+ type=str,
+ default=str(CONFIG_PATH),
+ help="Path to configuration file",
+ )
+ parser.add_argument(
+ "--roster",
+ type=str,
+ default=str(BASE_DIR / "roster.json"),
+ help="Path to roster file",
+ )
+ args = parser.parse_args()
+
+ config_path = args.config
+ roster_path = args.roster
+
+ config = load_config(config_path)
+ roster = load_config(roster_path)
+
+ # Merge roster into config
+ config["agents"] = roster.get("agents", [])
+
+ agent_model_name = get_model_names(config)
+ env_model_name = "gpt-4o"
+
+ # We pass config explicitly to prepare_scenario
+ env, agents = prepare_scenario(env_model_name, agent_model_name, config)
+
+ print("⛓️ Prisoner's Dilemma")
+ print("=" * 60)
+ print_roster(config)
+ print("=" * 60)
+
+ # Run game
+ await arun_one_episode(
+ env=env,
+ agent_list=agents,
+ omniscient=False,
+ script_like=False,
+ json_in_script=False,
+ tag="test_pd",
+ push_to_db=True,
+ )
+
+
+if __name__ == "__main__":
+ # Configure logging for standalone execution
+ LOG_FILE = BASE_DIR / "pd_game_debug.log"
+ _fh = logging.FileHandler(LOG_FILE, mode="w", encoding="utf-8")
+ _fh.setFormatter(
+ logging.Formatter("%(asctime)s %(levelname)-7s %(name)s - %(message)s")
+ )
+
+ _gen_logger.setLevel(logging.DEBUG)
+ _gen_logger.addHandler(_fh)
+
+ _env_logger.setLevel(logging.INFO)
+ _env_logger.addHandler(_fh)
+ _env_logger.addHandler(RichHandler())
+
+ asyncio.run(main())
diff --git a/examples/experimental/games/prisoners_dilemma/roster.json b/examples/experimental/games/prisoners_dilemma/roster.json
new file mode 100644
index 000000000..d949e9596
--- /dev/null
+++ b/examples/experimental/games/prisoners_dilemma/roster.json
@@ -0,0 +1,14 @@
+{
+ "agents": [
+ {
+ "name": "Alice",
+ "role": "Player",
+ "agent_model": "gpt-4o-mini"
+ },
+ {
+ "name": "Bob",
+ "role": "Player",
+ "agent_model": "gpt-4o-mini"
+ }
+ ]
+}
diff --git a/examples/experimental/games/rock_paper_scissors/__init__.py b/examples/experimental/games/rock_paper_scissors/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/examples/experimental/games/rock_paper_scissors/config.json b/examples/experimental/games/rock_paper_scissors/config.json
new file mode 100644
index 000000000..045124326
--- /dev/null
+++ b/examples/experimental/games/rock_paper_scissors/config.json
@@ -0,0 +1,23 @@
+{
+ "scenario": "Rock Paper Scissors",
+ "description": "A classic game. Rock beats Scissors, Scissors beats Paper, Paper beats Rock. However in this game you play exactly 10 rounds. GOAL: Win as many rounds as possible. Each turn the turn winner gets a point. The player with the most points at the end wins. You cannot speak. You can only choose your action. \n\n The players in this game are: {agent_names}.",
+ "role_goals": {
+ "Player": "Win against the opponent."
+ },
+ "role_secrets": {
+ "Player": "Choose Rock, Paper, or Scissors secretly."
+ },
+ "initial_state": "Round",
+ "state_transition": {
+ "Round": "Round"
+ },
+ "state_properties": {
+ "Round": {
+ "actions": [
+ "action"
+ ],
+ "action_order": "simultaneous",
+ "visibility": "public"
+ }
+ }
+}
diff --git a/examples/experimental/games/rock_paper_scissors/main.py b/examples/experimental/games/rock_paper_scissors/main.py
new file mode 100644
index 000000000..1b29eb2c3
--- /dev/null
+++ b/examples/experimental/games/rock_paper_scissors/main.py
@@ -0,0 +1,400 @@
+"""Launcher for the Rock, Paper, Scissors game."""
+
+from __future__ import annotations
+
+import asyncio
+import logging
+import os
+from pathlib import Path
+from typing import Any, Dict, List, Tuple
+
+import redis
+from rich.logging import RichHandler
+
+from sotopia.agents import LLMAgent
+from sotopia.agents.llm_agent import Agents
+from sotopia.database.persistent_profile import (
+ AgentProfile,
+ EnvironmentProfile,
+ RelationshipType,
+)
+from sotopia.envs.evaluators import SocialGameEndEvaluator
+from sotopia.envs.social_game import (
+ ActionHandler,
+ SocialDeductionGame,
+ SOCIAL_GAME_PROMPT_TEMPLATE,
+ load_config,
+)
+from sotopia.messages import AgentAction, Message, Observation, SimpleMessage
+from sotopia.server import arun_one_episode
+
+BASE_DIR = Path(__file__).resolve().parent
+CONFIG_PATH = BASE_DIR / "config.json"
+
+os.environ.setdefault("REDIS_OM_URL", "redis://:@localhost:6379")
+redis.Redis(host="localhost", port=6379)
+
+# Loggers (configured in main if running standalone)
+_gen_logger = logging.getLogger("sotopia.generation")
+_env_logger = logging.getLogger("sotopia.envs.social_game")
+
+
+# ============================================================================
+# Evaluator
+# ============================================================================
+
+
+class RPSEvaluator(SocialGameEndEvaluator):
+ def __call__(
+ self, turn_number: int, messages: List[Tuple[str, Message]], **kwargs: Any
+ ) -> List[Tuple[str, Tuple[Tuple[str, int | float | bool], str]]]:
+ if turn_number >= self.max_turn_number:
+ env = kwargs.get("env")
+ if env:
+ scores = env.internal_state.get("scores", {})
+ response: List[
+ Tuple[str, Tuple[Tuple[str, int | float | bool], str]]
+ ] = [("environment", (("terminated", True), "Max (10) turns reached"))]
+
+ agent_names = list(env.agents)
+ scores_dict = {name: scores.get(name, 0) for name in agent_names}
+
+ # Logic:
+ # Highest score wins (1), Loser (-1). Tie -> 0.
+
+ values = list(scores_dict.values())
+ rewards = {name: 0.0 for name in agent_names}
+
+ if len(values) == 2:
+ s1, s2 = values[0], values[1]
+ n1, n2 = agent_names[0], agent_names[1]
+
+ if s1 > s2:
+ rewards[n1] = 1.0
+ rewards[n2] = -1.0
+ elif s2 > s1:
+ rewards[n1] = -1.0
+ rewards[n2] = 1.0
+
+ for agent_name in agent_names:
+ idx = agent_names.index(agent_name)
+ key = f"agent_{idx+1}"
+ raw_score = scores_dict.get(agent_name, 0)
+ reward = rewards.get(agent_name, 0.0)
+ response.append(
+ (
+ key,
+ (("complete_rating", reward), f"Final Score: {raw_score}"),
+ )
+ )
+ return response
+ return [("environment", (("terminated", True), "Max turns reached"))]
+ return [("environment", (("terminated", False), ""))]
+
+
+# ============================================================================
+# Action Handler
+# ============================================================================
+
+
+class RPSActionHandler(ActionHandler):
+ def handle_action(
+ self, env: SocialDeductionGame, agent_name: str, action: AgentAction
+ ) -> None:
+ if isinstance(env, RPSEnv) and action.action_type in ["action", "speak"]:
+ move_str = action.argument.lower()
+ current_move = None
+ if "rock" in move_str:
+ current_move = "Rock"
+ elif "paper" in move_str:
+ current_move = "Paper"
+ elif "scissors" in move_str or "scissor" in move_str:
+ current_move = "Scissors"
+
+ if current_move:
+ env.internal_state["current_moves"][agent_name] = current_move
+
+ def get_action_instruction(self, env: SocialDeductionGame, agent_name: str) -> str:
+ return "You are playing Rock-Paper-Scissors. Choose 'action: rock', 'action: paper', or 'action: scissors'. You cannot speak."
+
+
+# ============================================================================
+# Environment
+# ============================================================================
+
+
+class RPSEnv(SocialDeductionGame):
+ def __init__(self, **kwargs: Any) -> None:
+ super().__init__(action_handler=RPSActionHandler(), **kwargs)
+ self.internal_state = {"round": 0, "scores": {}, "current_moves": {}}
+
+ def reset(
+ self,
+ seed: int | None = None,
+ options: dict[str, str] | None = None,
+ agents: Agents | None = None,
+ omniscient: bool = False,
+ lite: bool = False,
+ include_background_observations: bool = True,
+ ) -> Dict[str, Observation]:
+ obs = super().reset(
+ seed=seed,
+ options=options,
+ agents=agents,
+ omniscient=omniscient,
+ lite=lite,
+ include_background_observations=include_background_observations,
+ )
+ self.internal_state = {"round": 0, "scores": {}, "current_moves": {}}
+ for agent in self.agents:
+ self.internal_state["scores"][agent] = 0
+ return obs
+
+ async def astep(
+ self, actions: Dict[str, AgentAction] | Dict[str, Dict[str, int | str]]
+ ) -> Tuple[
+ Dict[str, Any],
+ Dict[str, float],
+ Dict[str, bool],
+ Dict[str, bool],
+ Dict[str, Any],
+ ]:
+ if self.action_handler:
+ for agent_name, action in actions.items():
+ if isinstance(action, AgentAction):
+ self.action_handler.handle_action(self, agent_name, action)
+
+ moves = self.internal_state["current_moves"]
+ if len(moves) == len(self.agents) and len(self.agents) == 2:
+ agents = list(self.agents)
+ a1, a2 = agents[0], agents[1]
+ m1, m2 = moves[a1], moves[a2]
+
+ result, winner = "Draw", None
+ if m1 == m2:
+ result = "Draw"
+ elif (
+ (m1 == "Rock" and m2 == "Scissors")
+ or (m1 == "Scissors" and m2 == "Paper")
+ or (m1 == "Paper" and m2 == "Rock")
+ ):
+ winner = a1
+ else:
+ winner = a2
+
+ r1, r2 = (1, -1) if winner == a1 else (-1, 1) if winner == a2 else (0, 0)
+ result = f"{winner} wins!" if winner else "Draw"
+
+ self.internal_state["scores"][a1] += r1
+ self.internal_state["scores"][a2] += r2
+
+ msg = f"Round {self.internal_state['round']+1}: {a1} ({m1}) vs {a2} ({m2}) -> {result}"
+ self.recv_message("Environment", SimpleMessage(message=msg))
+
+ self.internal_state["round"] += 1
+ self.internal_state["current_moves"] = {}
+
+ return await super().astep(actions)
+
+
+# ============================================================================
+# Setup helpers
+# ============================================================================
+
+
+def ensure_agent_profile(config: Dict[str, Any]) -> AgentProfile:
+ """Create or retrieve agent profile."""
+ name = config.get("name", "")
+ role = config.get("role", "")
+ first_name, _, last_name = name.partition(" ")
+ if not last_name:
+ last_name = ""
+
+ try:
+ existing = AgentProfile.find(
+ (AgentProfile.first_name == first_name)
+ & (AgentProfile.last_name == last_name)
+ ).all()
+ if existing:
+ return AgentProfile.get(existing[0].pk)
+ except Exception:
+ pass
+
+ role_secret = config.get("role_secrets", {}).get(role, "")
+ profile = AgentProfile(
+ first_name=first_name, last_name=last_name, secret=role_secret
+ )
+ profile.save()
+ return profile
+
+
+def create_environment(
+ env_profile: EnvironmentProfile, model_name: str, config: Dict[str, Any]
+) -> RPSEnv:
+ return RPSEnv(
+ env_profile=env_profile,
+ config=config,
+ model_name=model_name,
+ # action_order is handled by config state_properties
+ evaluators=[RPSEvaluator(max_turn_number=10)],
+ terminal_evaluators=[],
+ hide_unknown=True,
+ )
+
+
+def create_agents(
+ agent_profiles: List[AgentProfile],
+ env_profile: EnvironmentProfile,
+ model_name: str | Dict[str, str] | List[str],
+ config: Dict[str, Any],
+) -> List[LLMAgent]:
+ agents = []
+ for idx, profile in enumerate(agent_profiles):
+ agent_name = f"{profile.first_name}{' ' + profile.last_name if profile.last_name else ''}"
+ role_goal = env_profile.agent_goals[idx]
+ role = config.get("agents", [])[idx].get("role", "")
+ secrets = config.get("role_secrets", {}).get(role, "")
+
+ filled_template = (
+ SOCIAL_GAME_PROMPT_TEMPLATE.replace("{description}", env_profile.scenario)
+ .replace("{secret}", f"Your secret info: {secrets}")
+ .replace("{goal}", role_goal)
+ )
+
+ if isinstance(model_name, dict):
+ this_agent_model = model_name.get(
+ agent_name, model_name.get("default", "gpt-4")
+ )
+ elif isinstance(model_name, list):
+ this_agent_model = model_name[idx]
+ else:
+ this_agent_model = model_name
+
+ agent = LLMAgent(
+ agent_name=agent_name,
+ agent_profile=profile,
+ model_name=this_agent_model,
+ strict_action_constraint=True,
+ custom_template=filled_template,
+ )
+ agent.goal = role_goal
+ agents.append(agent)
+ return agents
+
+
+def prepare_scenario(
+ env_model_name: str,
+ agent_model_name: str | Dict[str, str] | List[str],
+ config: Dict[str, Any] | None = None,
+) -> tuple[SocialDeductionGame, List[LLMAgent]]:
+ if config is None:
+ config = load_config(CONFIG_PATH)
+ agent_profiles = [ensure_agent_profile(entry) for entry in config.get("agents", [])]
+ agent_goals = [
+ config.get("role_goals", {}).get(entry.get("role", ""), "")
+ for entry in config.get("agents", [])
+ ]
+ agent_names = [entry.get("name", "") for entry in config.get("agents", [])]
+ scenario = config.get("description", "RPS").format(
+ agent_names=", ".join(agent_names)
+ )
+
+ env_profile = EnvironmentProfile(
+ scenario=scenario,
+ relationship=RelationshipType.acquaintance,
+ agent_goals=agent_goals,
+ tag="rps",
+ )
+ env_profile.save()
+
+ env = create_environment(env_profile, env_model_name, config)
+ agents = create_agents(agent_profiles, env_profile, agent_model_name, config)
+ return env, agents
+
+
+def print_roster(config: Dict[str, Any]) -> None:
+ print("Participants & roles:")
+ for entry in config.get("agents", []):
+ print(f" - {entry.get('name')}: {entry.get('role')}")
+
+
+def get_model_names(config: Dict[str, Any]) -> Dict[str, str]:
+ """Extract model names from config. Enforces strict requirement."""
+ model_map = {}
+ for entry in config.get("agents", []):
+ name = entry.get("name")
+ model = entry.get("agent_model")
+ if not name:
+ continue
+ if not model:
+ raise ValueError(
+                f"Agent '{name}' missing 'agent_model' in roster configuration."
+ )
+ model_map[name] = model
+ return model_map
+
+
+async def main() -> None:
+ import argparse
+
+ parser = argparse.ArgumentParser(description="Run Rock-Paper-Scissors game.")
+ parser.add_argument(
+ "--config",
+ type=str,
+ default=str(CONFIG_PATH),
+ help="Path to configuration file",
+ )
+ parser.add_argument(
+ "--roster",
+ type=str,
+ default=str(BASE_DIR / "roster.json"),
+ help="Path to roster file",
+ )
+ args = parser.parse_args()
+
+ config_path = args.config
+ roster_path = args.roster
+
+ config = load_config(config_path)
+ roster = load_config(roster_path)
+
+ # Merge roster into config for consistent access
+ config["agents"] = roster.get("agents", [])
+
+ agent_model_name = get_model_names(config)
+ env_model_name = "gpt-4o"
+
+ # We pass config explicitly to prepare_scenario
+ env, agents = prepare_scenario(env_model_name, agent_model_name, config)
+
+ print("✊✋✌️ Rock Paper Scissors")
+ print("=" * 60)
+ print_roster(config)
+ print("=" * 60)
+ await arun_one_episode(
+ env=env,
+ agent_list=agents,
+ omniscient=False,
+ script_like=False,
+ json_in_script=False,
+ tag="test_rps",
+ push_to_db=True,
+ )
+
+
+if __name__ == "__main__":
+ # Configure logging for standalone execution
+ LOG_FILE = BASE_DIR / "rps_game_debug.log"
+ _fh = logging.FileHandler(LOG_FILE, mode="w", encoding="utf-8")
+ _fh.setFormatter(
+ logging.Formatter("%(asctime)s %(levelname)-7s %(name)s - %(message)s")
+ )
+
+ _gen_logger.setLevel(logging.DEBUG)
+ _gen_logger.addHandler(_fh)
+
+ _env_logger.setLevel(logging.INFO)
+ _env_logger.addHandler(_fh)
+ _env_logger.addHandler(RichHandler())
+
+ asyncio.run(main())
diff --git a/examples/experimental/games/rock_paper_scissors/roster.json b/examples/experimental/games/rock_paper_scissors/roster.json
new file mode 100644
index 000000000..d949e9596
--- /dev/null
+++ b/examples/experimental/games/rock_paper_scissors/roster.json
@@ -0,0 +1,14 @@
+{
+ "agents": [
+ {
+ "name": "Alice",
+ "role": "Player",
+ "agent_model": "gpt-4o-mini"
+ },
+ {
+ "name": "Bob",
+ "role": "Player",
+ "agent_model": "gpt-4o-mini"
+ }
+ ]
+}
diff --git a/examples/experimental/games/spyfall/__init__.py b/examples/experimental/games/spyfall/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/examples/experimental/games/spyfall/config.json b/examples/experimental/games/spyfall/config.json
new file mode 100644
index 000000000..b4efb4c0e
--- /dev/null
+++ b/examples/experimental/games/spyfall/config.json
@@ -0,0 +1,47 @@
+{
+ "scenario": "Spyfall",
+ "description": "A social deduction game where players are at a specific location, but one player is a Spy who doesn't know where they are. GAME STRUCTURE: One Spy, multiple Non-Spies. WIN CONDITIONS: Non-Spies win by identifying the Spy. The Spy wins by guessing the location or avoiding detection. PHASES: (1) Questioning: Players take turns asking questions to each other to prove they know the location without giving it away to the Spy. (2) Voting: Players vote to eliminate the suspected Spy. STRATEGY: Non-Spies must be subtle; Spies must listen and blend in. \n\n The players in this game are: {agent_names}.",
+ "role_goals": {
+ "Non-Spy": "Identify the Spy without revealing the location too clearly.",
+ "Spy": "Figure out the location and blend in as a Non-Spy."
+ },
+ "role_secrets": {
+ "Non-Spy": "The location is: Space Station.",
+ "Spy": "You are the Spy. You do NOT know the location."
+ },
+ "initial_state": "Round_questioning",
+ "state_transition": {
+ "Round_questioning": "Round_vote",
+ "Round_vote": "Round_questioning"
+ },
+ "state_properties": {
+ "Round_questioning": {
+ "actions": [
+ "speak"
+ ],
+ "visibility": "public"
+ },
+ "Round_vote": {
+ "actions": [
+ "action"
+ ],
+ "action_order": "simultaneous",
+ "visibility": "public"
+ }
+ },
+ "end_conditions": [
+ {
+ "type": "team_eliminated",
+ "team": "Spy",
+ "winner": "Non-Spies",
+ "message": "[Game] The Spy has been caught! Non-Spies win."
+ },
+ {
+ "type": "parity",
+ "team": "Spy",
+ "other": "Non-Spies",
+ "winner": "Spy",
+ "message": "[Game] The Spy has evaded capture! Spy wins."
+ }
+ ]
+}
diff --git a/examples/experimental/games/spyfall/main.py b/examples/experimental/games/spyfall/main.py
new file mode 100644
index 000000000..b944dc0ed
--- /dev/null
+++ b/examples/experimental/games/spyfall/main.py
@@ -0,0 +1,493 @@
+"""Launcher for the Spyfall social game scenario."""
+
+from __future__ import annotations
+
+import asyncio
+import logging
+import os
+from pathlib import Path
+from typing import Any, Dict, List
+
+import redis
+from rich.logging import RichHandler
+
+from sotopia.agents import LLMAgent
+from sotopia.agents.llm_agent import Agents
+from sotopia.database.persistent_profile import (
+ AgentProfile,
+ EnvironmentProfile,
+ RelationshipType,
+)
+from sotopia.envs import SocialDeductionGame
+from sotopia.envs.evaluators import SocialGameEndEvaluator
+from sotopia.envs.social_game import (
+ SOCIAL_GAME_PROMPT_TEMPLATE,
+ ActionHandler,
+ load_config,
+)
+from sotopia.messages import AgentAction, Message, Observation, SimpleMessage
+from sotopia.server import arun_one_episode
+
+BASE_DIR = Path(__file__).resolve().parent
+CONFIG_PATH = BASE_DIR / "config.json"
+
+os.environ.setdefault("REDIS_OM_URL", "redis://:@localhost:6379")
+redis.Redis(host="localhost", port=6379)
+
+# Loggers (configured in main if running standalone)
+_gen_logger = logging.getLogger("sotopia.generation")
+_env_logger = logging.getLogger("sotopia.envs.social_game")
+
+
+# ============================================================================
+# Spyfall game-end evaluator
+# ============================================================================
+
+
+class SpyfallGameEndEvaluator(SocialGameEndEvaluator):
+ """Evaluator that checks spyfall win conditions."""
+
+ def _check_win_conditions( # type: ignore[override]
+ self, env: Any, turn_number: int, messages: List[tuple[str, Message]]
+ ) -> tuple[bool, str, Dict[str, float]]:
+ """Check if game has ended based on spyfall win conditions."""
+
+        # Count surviving players on each side; the Spy is eliminated once spy_count hits 0
+        non_spies_count = 0
+        spy_count = 0
+
+        for agent_name, alive in env.agent_alive.items():
+            if not alive:
+                continue
+            role = env.agent_to_role.get(agent_name, "")
+            if role == "Spy":
+                spy_count += 1
+            else:
+                non_spies_count += 1
+
+ # Check end conditions from config
+ end_conditions = env._config.get("end_conditions", [])
+ for condition in end_conditions:
+ cond_type = condition.get("type")
+
+ if cond_type == "team_eliminated":
+ team = condition.get("team", "") # e.g. "Spy"
+ # If target team is eliminated (Spy count is 0)
+ if team == "Spy" and spy_count == 0:
+ winner = condition.get("winner", "")
+ msg = condition.get("message", f"{winner} wins!")
+ env.recv_message("Environment", SimpleMessage(message=msg))
+
+ # Calculate rewards
+ rewards = {}
+ for agent_name in env.agents:
+ role = env.agent_to_role.get(agent_name, "")
+ team_name = env.role_to_team.get(role, "")
+ if team_name == winner: # Compare with team name "Non-Spies"
+ rewards[agent_name] = 1.0
+ else:
+ rewards[agent_name] = -1.0
+
+ return True, msg, rewards
+
+ elif cond_type == "parity":
+ # Spy wins if spy count >= non-spy count
+ team1 = condition.get("team", "")
+ team2 = condition.get("other", "")
+ # Count current alive
+ team1_count = 0
+ team2_count = 0
+ for agent_name, alive in env.agent_alive.items():
+ if alive:
+ role = env.agent_to_role.get(agent_name, "")
+ team = env.role_to_team.get(role, "")
+ if team == team1:
+ team1_count += 1
+ elif team == team2:
+ team2_count += 1
+
+ if team1_count >= team2_count and team1_count > 0:
+ winner = condition.get("winner", "")
+ msg = condition.get("message", f"{winner} wins!")
+ env.recv_message("Environment", SimpleMessage(message=msg))
+
+ # Calculate rewards
+ rewards = {}
+ for agent_name in env.agents:
+ role = env.agent_to_role.get(agent_name, "")
+ team_name = env.role_to_team.get(role, "")
+ if team_name == winner:
+ rewards[agent_name] = 1.0
+ else:
+ rewards[agent_name] = -1.0
+ return True, msg, rewards
+
+ return False, "", {}
+
+ def __call__(
+ self, turn_number: int, messages: List[tuple[str, Message]], **kwargs: Any
+ ) -> List[tuple[str, tuple[tuple[str, int | float | bool], str]]]:
+ # Check turn limit
+ if turn_number >= self.max_turn_number:
+ return [("environment", (("terminated", True), "Max turns reached"))]
+
+ # Extract environment from kwargs
+ env = kwargs.get("env")
+ if not env:
+ return [("environment", (("terminated", False), ""))]
+
+ # Check game-specific win conditions
+ terminated, reason, rewards = self._check_win_conditions(
+ env, turn_number, messages
+ )
+
+ response: List[tuple[str, tuple[tuple[str, int | float | bool], str]]] = [
+ ("environment", (("terminated", terminated), reason))
+ ]
+
+ if terminated and rewards:
+ agent_names = list(env.agents)
+ for agent_name, reward in rewards.items():
+ try:
+ idx = agent_names.index(agent_name)
+ generic_key = f"agent_{idx+1}"
+ response.append((generic_key, (("complete_rating", reward), "")))
+ except ValueError:
+ continue
+
+ return response
+
+
+# ============================================================================
+# Spyfall-specific game logic
+# ============================================================================
+
+
+class SpyfallActionHandler(ActionHandler):
+ """Handles actions for the Spyfall game."""
+
+ def handle_action(
+ self, env: SocialDeductionGame, agent_name: str, action: AgentAction
+ ) -> None:
+ """Handle a single action from an agent based on current state."""
+
+ if env.current_state == "Round_vote":
+ # Collect votes for elimination
+ if "votes" not in env.internal_state:
+ env.internal_state["votes"] = {}
+
+ if action.action_type == "action" and "vote" in action.argument.lower():
+ # Parse target from "vote Aurora" or "I vote for Aurora"
+ words = action.argument.split()
+ # Try to find a name (capitalized word)
+ target = next(
+ (w for w in words if w[0].isupper() and w in env.agents), None
+ )
+ if target:
+ env.internal_state["votes"][agent_name] = target
+
+ def get_action_instruction(self, env: SocialDeductionGame, agent_name: str) -> str:
+ """Get specific action instructions for an agent based on current state."""
+
+ if env.current_state == "Round_vote":
+            return "It is voting time. You MUST use the command 'vote NAME' to vote for the player you suspect is the Spy, e.g. 'vote Alice'."
+
+ return ""
+
+
+class SpyfallEnv(SocialDeductionGame):
+ """Spyfall game with questioning and voting."""
+
+ def __init__(self, **kwargs: Any) -> None:
+ super().__init__(action_handler=SpyfallActionHandler(), **kwargs)
+
+ def reset(
+ self,
+ seed: int | None = None,
+ options: Dict[str, str] | None = None,
+ agents: Agents | None = None,
+ omniscient: bool = False,
+ lite: bool = False,
+ include_background_observations: bool = True,
+ ) -> Dict[str, Observation]:
+ return super().reset(
+ seed=seed,
+ options=options,
+ agents=agents,
+ omniscient=omniscient,
+ lite=lite,
+ include_background_observations=include_background_observations,
+ )
+
+ def _check_eliminations(self) -> None:
+ """Apply eliminations based on collected actions."""
+ # Only apply eliminations if we are about to transition state
+ if not self._should_transition_state():
+ return
+
+ if self.current_state == "Round_vote":
+ # Tally votes and eliminate most voted player
+ votes = self.internal_state.get("votes", {})
+ if votes:
+ vote_counts: Dict[str, int] = {}
+ for target in votes.values():
+ vote_counts[target] = vote_counts.get(target, 0) + 1
+
+ if vote_counts:
+ # Find player with most votes
+ eliminated = max(vote_counts, key=vote_counts.get) # type: ignore
+ self.agent_alive[eliminated] = False
+ self.recv_message(
+ "Environment",
+ SimpleMessage(
+ message=f"[Game] {eliminated} was voted out! They were a {self.agent_to_role[eliminated]}."
+ ),
+ )
+ _gen_logger.info(
+ f"{eliminated} was voted out! They were a {self.agent_to_role[eliminated]}."
+ )
+ # Clear votes
+ self.internal_state["votes"] = {}
+
+
+# ============================================================================
+# Setup helpers
+# ============================================================================
+
+
+def ensure_agent_profile(config: Dict[str, Any]) -> AgentProfile:
+ """Create or retrieve agent profile."""
+ name = config.get("name", "")
+ role = config.get("role", "")
+
+ first_name, _, last_name = name.partition(" ")
+ if not last_name:
+ last_name = ""
+
+ # Try to find existing
+ try:
+ existing = AgentProfile.find(
+ (AgentProfile.first_name == first_name)
+ & (AgentProfile.last_name == last_name)
+ ).all()
+ if existing:
+ return AgentProfile.get(existing[0].pk)
+ except Exception:
+ pass
+
+ # Create new
+ role_secret = config.get("role_secrets", {}).get(role, "")
+ profile = AgentProfile(
+ first_name=first_name,
+ last_name=last_name,
+ secret=role_secret,
+ )
+ profile.save()
+ return profile
+
+
+def create_environment(
+ env_profile: EnvironmentProfile, model_name: str, config: Dict[str, Any]
+) -> SpyfallEnv:
+ """Create spyfall game environment."""
+ return SpyfallEnv(
+ env_profile=env_profile,
+ config=config,
+ model_name=model_name,
+ action_order="round-robin",
+ evaluators=[SpyfallGameEndEvaluator(max_turn_number=20)],
+ terminal_evaluators=[],
+ hide_unknown=True,
+ )
+
+
+def create_agents(
+ agent_profiles: List[AgentProfile],
+ env_profile: EnvironmentProfile,
+ model_name: str | Dict[str, str] | List[str],
+ config: Dict[str, Any],
+) -> List[LLMAgent]:
+ """Create LLM agents."""
+ agents = []
+ for idx, profile in enumerate(agent_profiles):
+ # Calculate secrets
+ agent_name = f"{profile.first_name}{' ' + profile.last_name if profile.last_name else ''}"
+ role_goal = env_profile.agent_goals[idx]
+
+ # Get secret based on role
+ role = config.get("agents", [])[idx].get("role", "")
+ # role_secrets is dictionary in config
+ secrets = config.get("role_secrets", {}).get(role, "")
+
+ # Fill template
+ filled_template = (
+ SOCIAL_GAME_PROMPT_TEMPLATE.replace("{description}", env_profile.scenario)
+ .replace("{secret}", f"Your secret info: {secrets}")
+ .replace(
+ "{goal}",
+ role_goal,
+ )
+ )
+
+ # Determine model
+ if isinstance(model_name, dict):
+ this_agent_model = model_name.get(
+ agent_name, model_name.get("default", "gpt-4")
+ )
+ elif isinstance(model_name, list):
+ this_agent_model = model_name[idx]
+ else:
+ this_agent_model = model_name
+
+ agent = LLMAgent(
+ agent_name=agent_name,
+ agent_profile=profile,
+ model_name=this_agent_model,
+ strict_action_constraint=True,
+ custom_template=filled_template,
+ )
+ agent.goal = env_profile.agent_goals[idx]
+ agents.append(agent)
+ return agents
+
+
+def prepare_scenario(
+ env_model_name: str,
+ agent_model_name: str | Dict[str, str] | List[str],
+ config: Dict[str, Any] | None = None,
+) -> tuple[SocialDeductionGame, List[LLMAgent]]:
+ """Load config and create profiles."""
+ if config is None:
+ config = load_config(CONFIG_PATH)
+
+ # Create agent profiles
+ agent_profiles = []
+ agent_goals = []
+ for entry in config.get("agents", []):
+ profile = ensure_agent_profile(entry)
+ agent_profiles.append(profile)
+
+ role_goal = config.get("role_goals", {}).get(entry.get("role", ""), "")
+ agent_goals.append(role_goal)
+
+ # Create environment profile
+ agent_names = [entry.get("name", "") for entry in config.get("agents", [])]
+ scenario = config.get("description", "Spyfall game").format(
+ agent_names=", ".join(agent_names)
+ )
+ env_profile = EnvironmentProfile(
+ scenario=scenario,
+ relationship=RelationshipType.acquaintance,
+ agent_goals=agent_goals,
+ tag="spyfall",
+ )
+ env_profile.save()
+
+ env = create_environment(env_profile, env_model_name, config)
+ agents = create_agents(agent_profiles, env_profile, agent_model_name, config)
+ return env, agents
+
+
+def print_roster(config: Dict[str, Any]) -> None:
+ """Print game roster."""
+ print("Participants & roles:")
+ for entry in config.get("agents", []):
+ name = entry.get("name", "Unknown")
+ role = entry.get("role", "Unknown")
+ print(f" - {name}: {role}")
+
+
+# ============================================================================
+# Main
+# ============================================================================
+
+
+def get_model_names(config: Dict[str, Any]) -> Dict[str, str]:
+ """Extract model names from config. Enforces strict requirement."""
+ model_map = {}
+ for entry in config.get("agents", []):
+ name = entry.get("name")
+ model = entry.get("agent_model")
+ if not name:
+ continue
+ if not model:
+ raise ValueError(
+                f"Agent '{name}' missing 'agent_model' in roster configuration."
+ )
+ model_map[name] = model
+ return model_map
+
+
+async def main() -> None:
+ """Run spyfall game."""
+ import argparse
+
+ parser = argparse.ArgumentParser(description="Run Spyfall game.")
+ parser.add_argument(
+ "--config",
+ type=str,
+ default=str(CONFIG_PATH),
+ help="Path to configuration file",
+ )
+ parser.add_argument(
+ "--roster",
+ type=str,
+ default=str(BASE_DIR / "roster.json"),
+ help="Path to roster file",
+ )
+ args = parser.parse_args()
+
+ config_path = args.config
+ roster_path = args.roster
+
+ config = load_config(config_path)
+ roster = load_config(roster_path)
+
+ # Merge roster into config
+ config["agents"] = roster.get("agents", [])
+
+ env_model_name = "gpt-4o"
+ agent_model_name = get_model_names(config)
+
+ # Setup
+ env, agents = prepare_scenario(env_model_name, agent_model_name, config)
+
+ # Display roster
+ # Config already loaded above
+ print("🕵️ Spyfall")
+ print("=" * 60)
+ print_roster(config)
+ print("=" * 60)
+
+ # Run game
+ await arun_one_episode(
+ env=env,
+ agent_list=agents,
+ omniscient=False,
+ script_like=False,
+ json_in_script=False,
+ tag="test_spyfall",
+ push_to_db=True,
+ )
+
+
+if __name__ == "__main__":
+ # Configure logging for standalone execution
+ LOG_FILE = BASE_DIR / "spyfall_game_debug.log"
+ _fh = logging.FileHandler(LOG_FILE, mode="w", encoding="utf-8")
+ _fh.setFormatter(
+ logging.Formatter("%(asctime)s %(levelname)-7s %(name)s - %(message)s")
+ )
+
+ _gen_logger.setLevel(logging.DEBUG)
+ _gen_logger.addHandler(_fh)
+
+ _env_logger.setLevel(logging.INFO)
+ _env_logger.addHandler(_fh)
+ _env_logger.addHandler(RichHandler())
+
+ asyncio.run(main())
diff --git a/examples/experimental/games/spyfall/roster.json b/examples/experimental/games/spyfall/roster.json
new file mode 100644
index 000000000..47385fef1
--- /dev/null
+++ b/examples/experimental/games/spyfall/roster.json
@@ -0,0 +1,28 @@
+{
+ "agents": [
+ {
+ "name": "Alice",
+ "role": "Non-Spy",
+ "team": "Non-Spies",
+ "agent_model": "gpt-4o"
+ },
+ {
+ "name": "Bob",
+ "role": "Spy",
+ "team": "Spy",
+ "agent_model": "gpt-4o"
+ },
+ {
+ "name": "Charlie",
+ "role": "Non-Spy",
+ "team": "Non-Spies",
+ "agent_model": "gpt-4o"
+ },
+ {
+ "name": "David",
+ "role": "Non-Spy",
+ "team": "Non-Spies",
+ "agent_model": "gpt-4o"
+ }
+ ]
+}
diff --git a/examples/experimental/games/undercover/__init__.py b/examples/experimental/games/undercover/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/examples/experimental/games/undercover/config.json b/examples/experimental/games/undercover/config.json
new file mode 100644
index 000000000..0bc2054b8
--- /dev/null
+++ b/examples/experimental/games/undercover/config.json
@@ -0,0 +1,47 @@
+{
+ "scenario": "Undercover",
+ "description": "A social deduction game where players describe a secret word. Most players (Civilians) have the same word, but two players (Undercovers) have a different but related word. GAME STRUCTURE: two Undercovers, four Civilians. WIN CONDITIONS: Civilians win by eliminating Undercovers. The Undercovers win if they survive until only two players remain (or parity). PHASES: (1) Description: Players take turns describing their word using one sentence, without revealing the word itself. (2) Voting: Players vote to eliminate suspected Undercovers. STRATEGY: Civilians must be vague enough to not help the Undercover but specific enough to prove their identity. The Undercover must infer the Civilians' word and blend in. \n\n The players in this game are: {agent_names}.",
+ "role_goals": {
+ "Civilian": "Describe your word subtly and identify the player with the different word.",
+ "Undercover": "Infer the Civilians' word and blend in without being caught."
+ },
+ "role_secrets": {
+ "Civilian": "Your secret word is: Cat.",
+ "Undercover": "Your secret word is: Dog."
+ },
+ "initial_state": "Round_description",
+ "state_transition": {
+ "Round_description": "Round_vote",
+ "Round_vote": "Round_description"
+ },
+ "state_properties": {
+ "Round_description": {
+ "actions": [
+ "speak"
+ ],
+ "visibility": "public"
+ },
+ "Round_vote": {
+ "actions": [
+ "action"
+ ],
+ "action_order": "simultaneous",
+ "visibility": "public"
+ }
+ },
+ "end_conditions": [
+ {
+ "type": "team_eliminated",
+ "team": "Undercover",
+ "winner": "Civilians",
+ "message": "[Game] The Undercover has been eliminated! Civilians win."
+ },
+ {
+ "type": "parity",
+ "team": "Undercover",
+ "other": "Civilians",
+ "winner": "Undercover",
+ "message": "[Game] The Undercover has survived! Undercover wins."
+ }
+ ]
+}
diff --git a/examples/experimental/games/undercover/main.py b/examples/experimental/games/undercover/main.py
new file mode 100644
index 000000000..0cfbedd97
--- /dev/null
+++ b/examples/experimental/games/undercover/main.py
@@ -0,0 +1,478 @@
+"""Launcher for the Undercover social game scenario."""
+
+from __future__ import annotations
+
+import asyncio
+import logging
+import os
+from pathlib import Path
+from typing import Any, Dict, List
+
+import redis
+from rich.logging import RichHandler
+
+from sotopia.agents import LLMAgent
+from sotopia.agents.llm_agent import Agents
+from sotopia.database.persistent_profile import (
+ AgentProfile,
+ EnvironmentProfile,
+ RelationshipType,
+)
+from sotopia.envs import SocialDeductionGame
+from sotopia.envs.evaluators import SocialGameEndEvaluator
+from sotopia.envs.social_game import (
+ SOCIAL_GAME_PROMPT_TEMPLATE,
+ ActionHandler,
+ load_config,
+)
+from sotopia.messages import AgentAction, Message, Observation, SimpleMessage
+from sotopia.server import arun_one_episode
+
+BASE_DIR = Path(__file__).resolve().parent
+CONFIG_PATH = BASE_DIR / "config.json"
+
+os.environ.setdefault("REDIS_OM_URL", "redis://:@localhost:6379")
+redis.Redis(host="localhost", port=6379)
+
+# Loggers (configured in main if running standalone)
+_gen_logger = logging.getLogger("sotopia.generation")
+_env_logger = logging.getLogger("sotopia.envs.social_game")
+
+
+# ============================================================================
+# Undercover game-end evaluator
+# ============================================================================
+
+
+class UndercoverGameEndEvaluator(SocialGameEndEvaluator):
+ """Evaluator that checks undercover win conditions."""
+
+ def _check_win_conditions( # type: ignore[override]
+ self, env: Any, turn_number: int, messages: List[tuple[str, Message]]
+ ) -> tuple[bool, str, Dict[str, float]]:
+ """Check if game has ended based on undercover win conditions."""
+
+ # Count current alive
+ civilians_count = 0
+ undercover_count = 0
+
+ for agent_name, alive in env.agent_alive.items():
+ if alive:
+ role = env.agent_to_role.get(agent_name, "")
+ if role == "Undercover":
+ undercover_count += 1
+ else:
+ civilians_count += 1
+
+ # 1. Civilians win if all Undercovers are eliminated
+ if undercover_count == 0:
+ msg = "[Game] All Undercovers have been eliminated! Civilians win."
+ env.recv_message("Environment", SimpleMessage(message=msg))
+
+ rewards = {}
+ for agent_name in env.agents:
+ role = env.agent_to_role.get(agent_name, "")
+ team_name = env.role_to_team.get(role, "")
+ if team_name == "Civilians":
+ rewards[agent_name] = 1.0
+ else:
+ rewards[agent_name] = -1.0
+ return True, msg, rewards
+
+ # 2. Undercovers win if all Civilians are eliminated
+ if civilians_count == 0:
+ msg = "[Game] All Civilians have been eliminated! Undercovers win."
+ env.recv_message("Environment", SimpleMessage(message=msg))
+
+ rewards = {}
+ for agent_name in env.agents:
+ role = env.agent_to_role.get(agent_name, "")
+ team_name = env.role_to_team.get(role, "")
+ if team_name == "Undercover":
+ rewards[agent_name] = 1.0
+ else:
+ rewards[agent_name] = -1.0
+ return True, msg, rewards
+
+ # 3. Undercovers win if exactly 1 Undercover and 1 Civilian remain
+ if undercover_count == 1 and civilians_count == 1:
+ msg = "[Game] Only 1 Civilian and 1 Undercover remain! Undercover wins."
+ env.recv_message("Environment", SimpleMessage(message=msg))
+
+ rewards = {}
+ for agent_name in env.agents:
+ role = env.agent_to_role.get(agent_name, "")
+ team_name = env.role_to_team.get(role, "")
+ if team_name == "Undercover":
+ rewards[agent_name] = 1.0
+ else:
+ rewards[agent_name] = -1.0
+ return True, msg, rewards
+
+ # Other cases (e.g. 2v2, 2v1, 1v2): Game continues
+ # Note: 2v2 continues because undercovers don't know each other.
+
+ return False, "", {}
+
+ def __call__(
+ self, turn_number: int, messages: List[tuple[str, Message]], **kwargs: Any
+ ) -> List[tuple[str, tuple[tuple[str, int | float | bool], str]]]:
+ # Check turn limit
+ if turn_number >= self.max_turn_number:
+ return [("environment", (("terminated", True), "Max turns reached"))]
+
+ env = kwargs.get("env")
+ if not env:
+ return [("environment", (("terminated", False), ""))]
+
+ terminated, reason, rewards = self._check_win_conditions(
+ env, turn_number, messages
+ )
+
+ response: List[tuple[str, tuple[tuple[str, int | float | bool], str]]] = [
+ ("environment", (("terminated", terminated), reason))
+ ]
+
+ if terminated and rewards:
+ agent_names = list(env.agents)
+ for agent_name, reward in rewards.items():
+ try:
+ idx = agent_names.index(agent_name)
+ generic_key = f"agent_{idx+1}"
+ response.append((generic_key, (("complete_rating", reward), "")))
+ except ValueError:
+ continue
+
+ return response
+
+
+# ============================================================================
+# Undercover-specific game logic
+# ============================================================================
+
+
+class UndercoverActionHandler(ActionHandler):
+ """Handles actions for the Undercover game."""
+
+ def handle_action(
+ self, env: SocialDeductionGame, agent_name: str, action: AgentAction
+ ) -> None:
+ """Handle a single action from an agent based on current state."""
+
+ if env.current_state == "Round_vote":
+ # Collect votes for elimination
+ if "votes" not in env.internal_state:
+ env.internal_state["votes"] = {}
+
+ if action.action_type == "action" and "vote" in action.argument.lower():
+ # Parse target from "vote Aurora" or "I vote for Aurora"
+ words = action.argument.split()
+ # Try to find a name (capitalized word)
+ target = next(
+ (w for w in words if w[0].isupper() and w in env.agents), None
+ )
+ if target:
+ env.internal_state["votes"][agent_name] = target
+
+ def get_action_instruction(self, env: SocialDeductionGame, agent_name: str) -> str:
+ """Get specific action instructions for an agent based on current state."""
+
+ if env.current_state == "Round_vote":
+ return "It is voting time. You MUST use the command 'vote NAME' to vote for the player you suspect is the Undercover. e.g. 'vote Alice'"
+
+ return ""
+
+
+class UndercoverEnv(SocialDeductionGame):
+ """Undercover game with description and voting."""
+
+ def __init__(self, **kwargs: Any) -> None:
+ super().__init__(action_handler=UndercoverActionHandler(), **kwargs)
+
+ def reset(
+ self,
+ seed: int | None = None,
+ options: Dict[str, str] | None = None,
+ agents: Agents | None = None,
+ omniscient: bool = False,
+ lite: bool = False,
+ include_background_observations: bool = True,
+ ) -> Dict[str, Observation]:
+ return super().reset(
+ seed=seed,
+ options=options,
+ agents=agents,
+ omniscient=omniscient,
+ lite=lite,
+ include_background_observations=include_background_observations,
+ )
+
+ def _check_eliminations(self) -> None:
+ """Apply eliminations based on collected actions."""
+ # Only apply eliminations if we are about to transition state
+ if not self._should_transition_state():
+ return
+
+ if self.current_state == "Round_vote":
+ # Tally votes and eliminate most voted player
+ votes = self.internal_state.get("votes", {})
+ if votes:
+ vote_counts: Dict[str, int] = {}
+ for target in votes.values():
+ vote_counts[target] = vote_counts.get(target, 0) + 1
+
+ if vote_counts:
+ # Find player with most votes
+ eliminated = max(vote_counts, key=vote_counts.get) # type: ignore
+ self.agent_alive[eliminated] = False
+ self.recv_message(
+ "Environment",
+ SimpleMessage(
+ message=f"[Game] {eliminated} was voted out! They were a {self.agent_to_role[eliminated]}."
+ ),
+ )
+ _gen_logger.info(
+ f"{eliminated} was voted out! They were a {self.agent_to_role[eliminated]}."
+ )
+ # Clear votes
+ self.internal_state["votes"] = {}
+
+
+# ============================================================================
+# Setup helpers
+# ============================================================================
+
+
+def ensure_agent_profile(config: Dict[str, Any]) -> AgentProfile:
+ """Create or retrieve agent profile."""
+ name = config.get("name", "")
+ role = config.get("role", "")
+
+ first_name, _, last_name = name.partition(" ")
+ if not last_name:
+ last_name = ""
+
+ # Try to find existing
+ try:
+ existing = AgentProfile.find(
+ (AgentProfile.first_name == first_name)
+ & (AgentProfile.last_name == last_name)
+ ).all()
+ if existing:
+ return AgentProfile.get(existing[0].pk)
+ except Exception:
+ pass
+
+ # Create new
+ role_secret = config.get("role_secrets", {}).get(role, "")
+ profile = AgentProfile(
+ first_name=first_name,
+ last_name=last_name,
+ secret=role_secret,
+ )
+ profile.save()
+ return profile
+
+
+def create_environment(
+ env_profile: EnvironmentProfile, model_name: str, config: Dict[str, Any]
+) -> UndercoverEnv:
+ """Create undercover game environment."""
+ return UndercoverEnv(
+ env_profile=env_profile,
+ config=config,
+ model_name=model_name,
+ action_order="round-robin",
+ evaluators=[UndercoverGameEndEvaluator(max_turn_number=100)],
+ terminal_evaluators=[],
+ hide_unknown=True,
+ )
+
+
+def create_agents(
+ agent_profiles: List[AgentProfile],
+ env_profile: EnvironmentProfile,
+ model_name: str | Dict[str, str] | List[str],
+ config: Dict[str, Any],
+) -> List[LLMAgent]:
+ """Create LLM agents."""
+ agents = []
+ for idx, profile in enumerate(agent_profiles):
+ # Calculate secrets
+ agent_name = f"{profile.first_name}{' ' + profile.last_name if profile.last_name else ''}"
+ role_goal = env_profile.agent_goals[idx]
+
+ # Get secret based on role
+ role = config.get("agents", [])[idx].get("role", "")
+ # role_secrets is dictionary in config
+ secrets = config.get("role_secrets", {}).get(role, "")
+
+        # Fill the prompt template; SOCIAL_GAME_PROMPT_TEMPLATE contains
+        # {description}, {secret}, and {goal} placeholders.
+ filled_template = (
+ SOCIAL_GAME_PROMPT_TEMPLATE.replace("{description}", env_profile.scenario)
+ .replace("{secret}", f"Your secret info: {secrets}")
+ .replace(
+ "{goal}",
+ role_goal,
+ )
+ )
+
+ # Determine model
+ if isinstance(model_name, dict):
+ this_agent_model = model_name.get(
+ agent_name, model_name.get("default", "gpt-4")
+ )
+ elif isinstance(model_name, list):
+ this_agent_model = model_name[idx]
+ else:
+ this_agent_model = model_name
+
+ agent = LLMAgent(
+ agent_name=agent_name,
+ agent_profile=profile,
+ model_name=this_agent_model,
+ strict_action_constraint=True,
+ custom_template=filled_template,
+ )
+ agent.goal = env_profile.agent_goals[idx]
+ agents.append(agent)
+ return agents
+
+
+def prepare_scenario(
+ env_model_name: str,
+ agent_model_name: str | Dict[str, str] | List[str],
+ config: Dict[str, Any] | None = None,
+) -> tuple[SocialDeductionGame, List[LLMAgent]]:
+ """Load config and create profiles."""
+ if config is None:
+ config = load_config(CONFIG_PATH)
+
+ # Create agent profiles
+ agent_profiles = []
+ agent_goals = []
+ for entry in config.get("agents", []):
+ profile = ensure_agent_profile(entry)
+ agent_profiles.append(profile)
+
+ role_goal = config.get("role_goals", {}).get(entry.get("role", ""), "")
+ agent_goals.append(role_goal)
+
+ # Create environment profile
+ agent_names = [entry.get("name", "") for entry in config.get("agents", [])]
+ scenario = config.get("description", "Undercover game").format(
+ agent_names=", ".join(agent_names)
+ )
+ env_profile = EnvironmentProfile(
+ scenario=scenario,
+ relationship=RelationshipType.acquaintance,
+ agent_goals=agent_goals,
+ tag="undercover",
+ )
+ env_profile.save()
+
+ env = create_environment(env_profile, env_model_name, config)
+ agents = create_agents(agent_profiles, env_profile, agent_model_name, config)
+ return env, agents
+
+
+def print_roster(config: Dict[str, Any]) -> None:
+ """Print game roster."""
+ print("Participants & roles:")
+ for entry in config.get("agents", []):
+ name = entry.get("name", "Unknown")
+ role = entry.get("role", "Unknown")
+ print(f" - {name}: {role}")
+
+
+# ============================================================================
+# Main
+# ============================================================================
+
+
+def get_model_names(config: Dict[str, Any]) -> Dict[str, str]:
+ """Extract model names from config. Enforces strict requirement."""
+ model_map = {}
+ for entry in config.get("agents", []):
+ name = entry.get("name")
+ model = entry.get("agent_model")
+ if not name:
+ continue
+ if not model:
+ raise ValueError(
+ f"Agent '{name}' missing 'agent_model' in config configuration."
+ )
+ model_map[name] = model
+ return model_map
+
+
+async def main() -> None:
+ """Run undercover game."""
+ import argparse
+
+ parser = argparse.ArgumentParser(description="Run Undercover game.")
+ parser.add_argument(
+ "--config",
+ type=str,
+ default=str(CONFIG_PATH),
+ help="Path to configuration file",
+ )
+ parser.add_argument(
+ "--roster",
+ type=str,
+ default=str(BASE_DIR / "roster.json"),
+ help="Path to roster file",
+ )
+ args = parser.parse_args()
+
+ config_path = args.config
+ roster_path = args.roster
+
+ config = load_config(config_path)
+ roster = load_config(roster_path)
+
+ # Merge roster into config
+ config["agents"] = roster.get("agents", [])
+
+ env_model_name = "gpt-4o"
+ agent_model_name = get_model_names(config)
+
+ # Setup
+ env, agents = prepare_scenario(env_model_name, agent_model_name, config)
+
+ # Display roster
+ # Config already loaded above
+ print("🕵️ Undercover")
+ print("=" * 60)
+ print_roster(config)
+ print("=" * 60)
+
+ # Run game
+ await arun_one_episode(
+ env=env,
+ agent_list=agents,
+ omniscient=False,
+ script_like=False,
+ json_in_script=False,
+ tag="test_undercover",
+ push_to_db=True,
+ )
+
+
+if __name__ == "__main__":
+ # Configure logging for standalone execution
+ LOG_FILE = BASE_DIR / "undercover_game_debug.log"
+ _fh = logging.FileHandler(LOG_FILE, mode="w", encoding="utf-8")
+ _fh.setFormatter(
+ logging.Formatter("%(asctime)s %(levelname)-7s %(name)s - %(message)s")
+ )
+
+ _gen_logger.setLevel(logging.DEBUG)
+ _gen_logger.addHandler(_fh)
+
+ _env_logger.setLevel(logging.INFO)
+ _env_logger.addHandler(_fh)
+ _env_logger.addHandler(RichHandler())
+
+ asyncio.run(main())
diff --git a/examples/experimental/games/undercover/roster.json b/examples/experimental/games/undercover/roster.json
new file mode 100644
index 000000000..e0a92cd6c
--- /dev/null
+++ b/examples/experimental/games/undercover/roster.json
@@ -0,0 +1,40 @@
+{
+ "agents": [
+ {
+ "name": "Alice",
+ "role": "Civilian",
+ "team": "Civilians",
+ "agent_model": "gpt-4o"
+ },
+ {
+ "name": "Bob",
+ "role": "Undercover",
+ "team": "Undercover",
+ "agent_model": "gpt-4o"
+ },
+ {
+ "name": "Charlie",
+ "role": "Civilian",
+ "team": "Civilians",
+ "agent_model": "gpt-4o"
+ },
+ {
+ "name": "David",
+ "role": "Civilian",
+ "team": "Civilians",
+ "agent_model": "gpt-4o"
+ },
+ {
+ "name": "Eve",
+ "role": "Undercover",
+ "team": "Undercover",
+ "agent_model": "gpt-4o"
+ },
+ {
+ "name": "Frank",
+ "role": "Civilian",
+ "team": "Civilians",
+ "agent_model": "gpt-4o"
+ }
+ ]
+}
diff --git a/examples/experimental/games/werewolves/README.md b/examples/experimental/games/werewolves/README.md
new file mode 100644
index 000000000..7adf42c66
--- /dev/null
+++ b/examples/experimental/games/werewolves/README.md
@@ -0,0 +1,43 @@
+# Duskmire Werewolves
+
+A text-based social deduction game built on top of `sotopia`. This experimental example demonstrates how to implement complex game phases (Day/Night), roles, and turn-based interactions using the Sotopia framework.
+
+## Overview
+
+In this 6-player game, players are assigned roles (Villager, Werewolf, Seer, Witch) and compete to eliminate the opposing team.
+
+- **Villagers**: Must identify and vote out Werewolves.
+- **Werewolves**: Must deceive Villagers and eliminate them at night.
+- **Seer**: Can inspect one player's role each night.
+- **Witch**: Has one potion to save a victim and one to poison a suspect.
+
+## Features
+
+- **Sequential Discussion**: Utilizes `round-robin` action order during the day, ensuring agents speak one after another and can reference previous arguments.
+- **Simultaneous Action**: Night phases and voting are simultaneous to preserve secrecy/fairness.
+- **Global Event Notifications**: Players receive system messages about state transitions (e.g., "Entering Night Phase") regardless of their role visibility settings.
+- **Safe Elimination**: Role information is hidden from players upon elimination to simulate realistic uncertainty (roles are only revealed in admin logs).
+
+## Running the Game
+
+1. Ensure you have the `sotopia` environment set up.
+2. Run the main script:
+ ```bash
+   python examples/experimental/games/werewolves/main.py
+ ```
+ *Note: Ensure your Redis server is running.*
+
+## Configuration
+
+The game is configured via `config.json`. Key settings include:
+
+- **`state_properties`**: Defines the phases (Day/Night).
+ - `action_order`: Set to `"round-robin"` for sequential phases (e.g., `Day_discussion`), or `"simultaneous"` for others (e.g., `Day_vote`).
+ - `visibility`: Controls who sees messages (`"public"`, `"team"`, `"private"`).
+- **`agents`**: Defines the roster and roles.
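+
+For a quick look at how these settings fit together, the sketch below loads `config.json` and prints each phase with its action order and visibility. It is only an illustrative snippet (the fallback label for phases without an explicit `action_order` is our own, not part of the engine):
+
+```python
+import json
+from pathlib import Path
+
+# Illustrative only: the path assumes you run from the repository root.
+config = json.loads(
+    Path("examples/experimental/games/werewolves/config.json").read_text()
+)
+for phase, props in config["state_properties"].items():
+    print(phase, props.get("action_order", "(engine default)"), props["visibility"])
+```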
+
+## Extending
+
+To modify the game logic, check:
+- `main.py`: Handles game initialization and elimination logic (`_check_eliminations`).
+- `config.json` and `sotopia/envs/social_game.py`: Adjust game balance, roles, and state transitions.
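+
+For instance, a new night power could be added by giving a role its own phase in `config.json` and extending the action-handler logic in `main.py`. The sketch below is hypothetical (the `Night_guard` phase and the `protect` command are not part of the shipped example); it only illustrates the `ActionHandler` pattern this game uses:
+
+```python
+from sotopia.envs import SocialDeductionGame
+from sotopia.envs.social_game import ActionHandler
+from sotopia.messages import AgentAction
+
+
+class GuardActionHandler(ActionHandler):
+    """Hypothetical handler for a 'Guard' role that protects one player per night."""
+
+    def handle_action(
+        self, env: SocialDeductionGame, agent_name: str, action: AgentAction
+    ) -> None:
+        if env.current_state == "Night_guard" and action.action_type == "action":
+            if "protect" in action.argument.lower():
+                words = action.argument.split()
+                target = next((w for w in words if w in env.agents), None)
+                if target:
+                    env.internal_state["protected_target"] = target
+
+    def get_action_instruction(self, env: SocialDeductionGame, agent_name: str) -> str:
+        if env.current_state == "Night_guard":
+            return "It is Night. You are the Guard. Use 'protect NAME' to shield a player."
+        return ""
+```
+
+A handler like this would be wired in the same way the shipped game does it: by passing `action_handler=GuardActionHandler()` (or a combined handler) to the environment subclass's `super().__init__` call.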
diff --git a/examples/experimental/games/werewolves/__init__.py b/examples/experimental/games/werewolves/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/examples/experimental/games/werewolves/config.json b/examples/experimental/games/werewolves/config.json
new file mode 100644
index 000000000..6665e40c3
--- /dev/null
+++ b/examples/experimental/games/werewolves/config.json
@@ -0,0 +1,83 @@
+{
+ "scenario": "Werewolves Game",
+ "description": "This game is called Werewolves (also known as Mafia), which is a social deduction game. GAME STRUCTURE: This game contains six players: two villagers, two werewolves, one seer, and one witch. Each cycle consists of Night phases followed by Day phases. WIN CONDITIONS: Villagers win by eliminating all Werewolves. Werewolves win when they equal or outnumber the Villagers. NIGHT PHASES: (1) Werewolves wake and privately choose one Villager to kill using 'kill NAME'. (2) The Seer wakes and inspects one player using 'inspect NAME' to learn if they are a Werewolf or Villager. (3) The Witch wakes and may use their one-time save potion with 'save NAME' (if someone died tonight) or their one-time poison potion with 'poison NAME' to kill someone. DAY PHASES: (1) Dawn: All players learn who died during the night (if any). (2) Discussion: All living players go *one* round of public discussion to discuss and debate who might be a Werewolf. Use 'speak' to share your thoughts and suspicions. (3) Voting: Each player votes to eliminate one suspected Werewolf using 'vote NAME'. The player with the most votes is executed. (4) Twilight: The execution result is announced and night returns. COMMUNICATION RULES: All day discussions are public. Dead players cannot speak or act. If 'action' is available, use commands like 'kill NAME', 'inspect NAME', 'save NAME', 'poison NAME', or 'vote NAME'. Note: Player actions in the nights are not seen by players of other roles (e.g. only wolves can see other wolves' actions). Day discussion is public. Voting requires an 'action' beginning with 'vote'.\n\n The players in this game are: {agent_names}, and during the day, they speak in the same order as they are listed.",
+ "role_goals": {
+ "Villager": "Act openly and collaboratively to identify werewolves.",
+ "Werewolf": "Deceive others, avoid detection, and eliminate villagers.",
+ "Seer": "Discreetly inspect players at night to uncover werewolves.",
+ "Witch": "Use save/poison potions strategically to aid villagers."
+ },
+ "role_secrets": {
+ "Werewolf": "I am a werewolf."
+ },
+ "initial_state": "Night_werewolf",
+ "state_transition": {
+ "Night_werewolf": "Night_seer",
+ "Night_seer": "Night_witch",
+ "Night_witch": "Day_discussion",
+ "Day_discussion": "Day_vote",
+ "Day_vote": "Night_werewolf"
+ },
+ "state_properties": {
+ "Night_werewolf": {
+ "acting_roles": [
+ "Werewolf"
+ ],
+ "actions": [
+ "action"
+ ],
+ "visibility": "team"
+ },
+ "Night_seer": {
+ "acting_roles": [
+ "Seer"
+ ],
+ "actions": [
+ "action"
+ ],
+ "visibility": "private"
+ },
+ "Night_witch": {
+ "acting_roles": [
+ "Witch"
+ ],
+ "actions": [
+ "action"
+ ],
+ "visibility": "private",
+ "internal_state": {
+ "save_available": true,
+ "poison_available": true
+ }
+ },
+ "Day_discussion": {
+ "actions": [
+ "speak"
+ ],
+ "action_order": "round-robin",
+ "visibility": "public"
+ },
+ "Day_vote": {
+ "actions": [
+ "action"
+ ],
+ "action_order": "simultaneous",
+ "visibility": "public"
+ }
+ },
+ "end_conditions": [
+ {
+ "type": "team_eliminated",
+ "team": "Werewolves",
+ "winner": "Villagers",
+ "message": "Villagers win; no werewolves remain."
+ },
+ {
+ "type": "parity",
+ "team": "Werewolves",
+ "other": "Villagers",
+ "winner": "Werewolves",
+ "message": "Werewolves win; they now match the village."
+ }
+ ]
+}
diff --git a/examples/experimental/games/werewolves/main.py b/examples/experimental/games/werewolves/main.py
new file mode 100644
index 000000000..b10061488
--- /dev/null
+++ b/examples/experimental/games/werewolves/main.py
@@ -0,0 +1,649 @@
+"""Launcher for the Duskmire Werewolves social game scenario."""
+
+from __future__ import annotations
+
+import asyncio
+import os
+from pathlib import Path
+import logging
+from typing import Any, Dict, List
+import random
+from collections import Counter
+
+from rich.logging import RichHandler
+import redis
+
+from sotopia.agents import LLMAgent
+from sotopia.agents.llm_agent import Agents
+from sotopia.database.persistent_profile import (
+ AgentProfile,
+ EnvironmentProfile,
+ RelationshipType,
+)
+from sotopia.envs import SocialDeductionGame
+from sotopia.envs.social_game import (
+ ActionHandler,
+ load_config,
+ SOCIAL_GAME_PROMPT_TEMPLATE,
+)
+from sotopia.envs.evaluators import SocialGameEndEvaluator
+from sotopia.server import arun_one_episode
+from sotopia.messages import AgentAction, SimpleMessage, Message, Observation
+
+BASE_DIR = Path(__file__).resolve().parent
+CONFIG_PATH = BASE_DIR / "config.json"
+
+os.environ.setdefault("REDIS_OM_URL", "redis://:@localhost:6379")
+redis.Redis(host="localhost", port=6379)
+
+# Loggers (configured in main if running standalone)
+_gen_logger = logging.getLogger("sotopia.generation")
+_env_logger = logging.getLogger("sotopia.envs.social_game")
+
+
+# ============================================================================
+# Werewolf game-end evaluator
+# ============================================================================
+
+
+class WerewolfGameEndEvaluator(SocialGameEndEvaluator):
+ """Evaluator that checks werewolf win conditions."""
+
+ def _check_win_conditions( # type: ignore[override]
+ self, env: Any, turn_number: int, messages: List[tuple[str, Message]]
+ ) -> tuple[bool, str, Dict[str, float]]:
+ """Check if game has ended based on werewolf win conditions."""
+ # Count alive players by team
+ team_counts: Dict[str, int] = {}
+ for agent_name, alive in env.agent_alive.items():
+ if alive:
+ role = env.agent_to_role.get(agent_name, "")
+ team = env.role_to_team.get(role, "")
+ team_counts[team] = team_counts.get(team, 0) + 1
+
+ # Check end conditions from config
+ end_conditions = env._config.get("end_conditions", [])
+ for condition in end_conditions:
+ cond_type = condition.get("type")
+
+ if cond_type == "team_eliminated":
+ team = condition.get("team", "")
+ if team_counts.get(team, 0) == 0:
+ winner = condition.get("winner", "")
+ msg = condition.get("message", f"{winner} wins!")
+ env.recv_message("Environment", SimpleMessage(message=msg))
+
+ # Calculate rewards
+ rewards = {}
+ for agent_name in env.agents:
+ role = env.agent_to_role.get(agent_name, "")
+ team = env.role_to_team.get(role, "")
+ if team == winner:
+ rewards[agent_name] = 1.0
+ else:
+ rewards[agent_name] = -1.0
+
+ return True, msg, rewards
+
+ elif cond_type == "parity":
+ team1 = condition.get("team", "")
+ team2 = condition.get("other", "")
+ if team_counts.get(team1, 0) >= team_counts.get(team2, 0):
+ winner = condition.get("winner", "")
+ msg = condition.get("message", f"{winner} wins!")
+ env.recv_message("Environment", SimpleMessage(message=msg))
+
+ # Calculate rewards
+ rewards = {}
+ for agent_name in env.agents:
+ role = env.agent_to_role.get(agent_name, "")
+ team = env.role_to_team.get(role, "")
+ if team == winner:
+ rewards[agent_name] = 1.0
+ else:
+ rewards[agent_name] = -1.0
+
+ return True, msg, rewards
+
+ return False, "", {}
+
+ def __call__(
+ self, turn_number: int, messages: List[tuple[str, Message]], **kwargs: Any
+ ) -> List[tuple[str, tuple[tuple[str, int | float | bool], str]]]:
+ # Check turn limit
+ if turn_number >= self.max_turn_number:
+ return [("environment", (("terminated", True), "Max turns reached"))]
+
+ # Extract environment from kwargs
+ env = kwargs.get("env")
+ if not env:
+ return [("environment", (("terminated", False), ""))]
+
+ # Check game-specific win conditions
+ terminated, reason, rewards = self._check_win_conditions(
+ env, turn_number, messages
+ )
+
+ response: List[tuple[str, tuple[tuple[str, int | float | bool], str]]] = [
+ ("environment", (("terminated", terminated), reason))
+ ]
+
+ if terminated and rewards:
+ agent_names = list(env.agents) # Ensure order matches env.agents
+ for agent_name, reward in rewards.items():
+ # Find index of agent in env.agents
+ try:
+ idx = agent_names.index(agent_name)
+ generic_key = f"agent_{idx+1}" # e.g. agent_1, agent_2
+ response.append((generic_key, (("complete_rating", reward), "")))
+ except ValueError:
+ # Should not happen if agent_name is in env.agents
+ continue
+
+ return response
+
+
+# ============================================================================
+# Werewolf-specific game logic
+# ============================================================================
+
+
+class WerewolfActionHandler(ActionHandler):
+ """Handles actions for the Werewolf game."""
+
+ def handle_action(
+ self, env: SocialDeductionGame, agent_name: str, action: AgentAction
+ ) -> None:
+ """Handle a single action from an agent based on current state."""
+
+ if env.current_state == "Day_vote":
+ # Collect votes for elimination
+ if "votes" not in env.internal_state:
+ env.internal_state["votes"] = {}
+
+ if action.action_type == "action" and "vote" in action.argument.lower():
+ # Parse target from "vote Aurora" or "I vote for Aurora"
+ words = action.argument.split()
+ # Try to find a name (capitalized word)
+ target = next(
+ (w for w in words if w[0].isupper() and w in env.agents), None
+ )
+ if target:
+ env.internal_state["votes"][agent_name] = target
+
+ elif env.current_state == "Night_werewolf":
+ # Werewolves choose kill target
+ role = env.agent_to_role.get(agent_name, "")
+ if role == "Werewolf" and action.action_type == "action":
+ if "kill" in action.argument.lower():
+ words = action.argument.split()
+ target = next(
+ (w for w in words if w[0].isupper() and w in env.agents),
+ None,
+ )
+ if target:
+ if "kill_target_proposals" not in env.internal_state:
+ env.internal_state["kill_target_proposals"] = {}
+ env.internal_state["kill_target_proposals"][agent_name] = target
+ # Update the werewolf kill result
+ kill_votes = env.internal_state.get("kill_target_proposals", {})
+ if kill_votes:
+ # Count votes
+ vote_counts = Counter(kill_votes.values())
+ if vote_counts:
+ # Find max votes
+ max_votes = max(vote_counts.values())
+ # Get all targets with max votes
+ candidates = [
+ t for t, c in vote_counts.items() if c == max_votes
+ ]
+ # Break tie randomly
+ env.internal_state["kill_target"] = random.choice(
+ candidates
+ )
+
+ elif env.current_state == "Night_seer":
+ # Seer inspects someone
+ role = env.agent_to_role.get(agent_name, "")
+ if role == "Seer" and action.action_type == "action":
+ if "inspect" in action.argument.lower():
+ words = action.argument.split()
+ target = next(
+ (w for w in words if w[0].isupper() and w in env.agents),
+ None,
+ )
+ if target:
+ # Reveal target's role to seer
+ target_role = env.agent_to_role.get(target, "Unknown")
+ target_team = env.role_to_team.get(target_role, "Unknown")
+ env.recv_message(
+ "Environment",
+ SimpleMessage(
+ message=f"[Private to {agent_name}] {target} is on team: {target_team}"
+ ),
+ receivers=[agent_name],
+ )
+
+ elif env.current_state == "Night_witch":
+ # Witch uses potions
+ role = env.agent_to_role.get(agent_name, "")
+ if role == "Witch" and action.action_type == "action":
+ if "save" in action.argument.lower():
+ env.internal_state["witch_have_save"] = False
+ words = action.argument.split()
+ target = next(
+ (w for w in words if w[0].isupper() and w in env.agents),
+ None,
+ )
+ if target:
+ env.internal_state["saved_target"] = target
+ elif "poison" in action.argument.lower():
+ env.internal_state["witch_have_posion"] = False
+ words = action.argument.split()
+ target = next(
+ (w for w in words if w[0].isupper() and w in env.agents),
+ None,
+ )
+ if target:
+ env.internal_state["poison_target"] = target
+
+ def get_action_instruction(self, env: SocialDeductionGame, agent_name: str) -> str:
+ """Get specific action instructions for an agent based on current state."""
+ role = env.agent_to_role.get(agent_name, "")
+
+ if env.current_state == "Day_vote":
+ return "It is voting time. You MUST use the command 'vote NAME' to vote for a player to eliminate. e.g. 'vote Alice'"
+
+ elif env.current_state == "Night_werewolf":
+ if role == "Werewolf":
+ return "It is Night. You are a Werewolf. You MUST use the command 'kill NAME' to choose a target. e.g. 'kill Bob'"
+ else:
+ return "It is Night. You are sleeping."
+
+ elif env.current_state == "Night_seer":
+ if role == "Seer":
+ return "It is Night. You are the Seer. You MUST use the command 'inspect NAME' to check a player's team. e.g. 'inspect Charlie'"
+ else:
+ return "It is Night. You are sleeping."
+
+ elif env.current_state == "Night_witch":
+ if role == "Witch":
+                if env.internal_state.get(
+                    "witch_have_poison", True
+                ) and env.internal_state.get("witch_have_save", True):
+                    use_potion_guide = "You can use 'save NAME' or 'poison NAME'. If you do not want to use a potion, reply with 'skip' as the action argument."
+                elif env.internal_state.get("witch_have_poison", True):
+                    use_potion_guide = "You can use 'poison NAME'. If you do not want to use a potion, reply with 'skip' as the action argument."
+                elif env.internal_state.get("witch_have_save", True):
+                    use_potion_guide = "You can use 'save NAME'. If you do not want to use a potion, reply with 'skip' as the action argument."
+ else:
+ use_potion_guide = (
+ "You can't use any potions as you don't have any left."
+ )
+ killed_message = ""
+ if kill_target := env.internal_state.get("kill_target", None):
+ killed_message = f"{kill_target} is killed by werewolves."
+ return f"It is Night. You are the Witch. {use_potion_guide} {killed_message}"
+ else:
+ return "It is Night. You are sleeping."
+
+ return ""
+
+
+class WerewolfEnv(SocialDeductionGame):
+ """Werewolf game with voting, kills, and special roles."""
+
+ def __init__(self, **kwargs: Any) -> None:
+ super().__init__(action_handler=WerewolfActionHandler(), **kwargs)
+
+ def reset(
+ self,
+ seed: int | None = None,
+ options: dict[str, str] | None = None,
+ agents: Agents | None = None,
+ omniscient: bool = False,
+ lite: bool = False,
+ include_background_observations: bool = True,
+ ) -> Dict[str, Observation]:
+ obs = super().reset(
+ seed=seed,
+ options=options,
+ agents=agents,
+ omniscient=omniscient,
+ lite=lite,
+ include_background_observations=include_background_observations,
+ )
+ # Witch has potions
+ self.internal_state["witch_have_posion"] = True
+ self.internal_state["witch_have_save"] = True
+ # Werewolves have kill targets
+ self.internal_state["kill_target_proposals"] = {}
+ return obs
+
+ def _check_eliminations(self) -> None:
+ """Apply eliminations based on collected actions."""
+ # Only apply eliminations if we are about to transition state
+ if not self._should_transition_state():
+ return
+
+ if self.current_state == "Day_vote":
+ # Tally votes and eliminate most voted player
+ votes = self.internal_state.get("votes", {})
+ if votes:
+ vote_counts: Dict[str, int] = {}
+ for target in votes.values():
+ vote_counts[target] = vote_counts.get(target, 0) + 1
+
+ if vote_counts:
+ # Find player with most votes
+ eliminated = max(vote_counts, key=vote_counts.get) # type: ignore
+ self.agent_alive[eliminated] = False
+ self.recv_message(
+ "Environment",
+ SimpleMessage(message=f"[Game] {eliminated} was voted out!"),
+ )
+ # Clear votes
+ self.internal_state["votes"] = {}
+ # log elimination
+ _gen_logger.info(
+ f"{eliminated} was voted out! They were a {self.agent_to_role[eliminated]}."
+ )
+ _gen_logger.info(f"Remaining players: {self.agent_alive}")
+
+ elif self.current_state == "Night_witch":
+ # Resolve Night actions (Werewolf kill + Witch save/poison)
+
+ kill_target = self.internal_state.get("kill_target")
+ saved_target = self.internal_state.get("saved_target")
+ poison_target = self.internal_state.get("poison_target")
+
+            # 1. Werewolf kill (cancelled if the Witch saved the target)
+ if kill_target and self.agent_alive.get(kill_target, False):
+ if kill_target != saved_target:
+ self.agent_alive[kill_target] = False
+ self.recv_message(
+ "Environment",
+ SimpleMessage(
+ message=f"[Game] {kill_target} was killed by werewolves!"
+ ),
+ )
+ _gen_logger.info(f"{kill_target} was killed by werewolves!")
+ _gen_logger.info(f"Remaining players: {self.agent_alive}")
+ else:
+ self.recv_message(
+ "Environment",
+ SimpleMessage(message="[Game] An attack was prevented!"),
+ )
+ _gen_logger.info(f"An attack to {kill_target} was prevented!")
+ _gen_logger.info(f"Remaining players: {self.agent_alive}")
+
+ # 2. Witch Poison
+ if poison_target and self.agent_alive.get(poison_target, False):
+ self.agent_alive[poison_target] = False
+ self.recv_message(
+ "Environment",
+ SimpleMessage(
+ message=f"[Game] {poison_target} died by witch's poison!"
+ ),
+ )
+ _gen_logger.info(f"{poison_target} died by witch's poison!")
+ _gen_logger.info(f"Remaining players: {self.agent_alive}")
+
+ # Clear night states
+ self.internal_state.pop("kill_target_proposals", None)
+ self.internal_state.pop("kill_target", None)
+ self.internal_state.pop("saved_target", None)
+ self.internal_state.pop("poison_target", None)
+
+
+# ============================================================================
+# Setup helpers
+# ============================================================================
+
+
+def ensure_agent_profile(config: Dict[str, Any]) -> AgentProfile:
+ """Create or retrieve agent profile."""
+ name = config.get("name", "")
+ role = config.get("role", "")
+
+ first_name, _, last_name = name.partition(" ")
+ if not last_name:
+ last_name = ""
+
+ # Try to find existing
+ try:
+ existing = AgentProfile.find(
+ (AgentProfile.first_name == first_name)
+ & (AgentProfile.last_name == last_name)
+ ).all()
+ if existing:
+ return AgentProfile.get(existing[0].pk)
+ except Exception:
+ pass
+
+ # Create new
+ role_secret = config.get("role_secrets", {}).get(role, "")
+ profile = AgentProfile(
+ first_name=first_name,
+ last_name=last_name,
+ secret=role_secret,
+ )
+ profile.save()
+ return profile
+
+
+def create_environment(
+ env_profile: EnvironmentProfile, model_name: str, config: Dict[str, Any]
+) -> WerewolfEnv:
+ """Create werewolf game environment."""
+ return WerewolfEnv(
+ env_profile=env_profile,
+ config=config,
+ model_name=model_name,
+ action_order="round-robin",
+ evaluators=[WerewolfGameEndEvaluator(max_turn_number=40)],
+ terminal_evaluators=[],
+ hide_unknown=True,
+ )
+
+
+def create_agents(
+    agent_profiles: List[AgentProfile],
+    env_profile: EnvironmentProfile,
+    model_name: str | Dict[str, str] | List[str],
+    config: Dict[str, Any],
+) -> List[LLMAgent]:
+    """Create LLM agents, giving each werewolf knowledge of its partners."""
+ # Identify werewolves for partner info
+ werewolf_goal_str = config.get("role_goals", {}).get("Werewolf", "")
+ werewolves = [
+ p.first_name + (" " + p.last_name if p.last_name else "")
+ for p in agent_profiles
+ if env_profile.agent_goals[agent_profiles.index(p)] == werewolf_goal_str
+ ]
+
+ agents = []
+ for idx, profile in enumerate(agent_profiles):
+ # Calculate secrets
+ agent_name = f"{profile.first_name}{' ' + profile.last_name if profile.last_name else ''}"
+ role_goal = env_profile.agent_goals[idx]
+ secrets = ""
+
+        # Check whether this agent is a werewolf. agent_goals stores the role's
+        # goal text, so compare against the werewolf goal string, not the role name.
+        is_werewolf = role_goal == werewolf_goal_str
+
+ if is_werewolf:
+ partners = [w for w in werewolves if w != agent_name]
+ if partners:
+ secrets = f"Your secret: You are a werewolf. Your partner(s) are: {', '.join(partners)}."
+ else:
+ secrets = "Your secret: You are a werewolf. You have no partners."
+ filled_template = (
+ SOCIAL_GAME_PROMPT_TEMPLATE.replace("{description}", env_profile.scenario)
+ .replace("{secret}", secrets)
+ .replace(
+ "{goal}",
+                role_goal,  # fill the {goal} placeholder
+ )
+ )
+        # Determine the model for this agent. model_name may be a single string,
+        # a list indexed by agent position, or a dict keyed by agent name with an
+        # optional "default" fallback, e.g. {"Aurora": "gpt-4", ...}.
+        if isinstance(model_name, dict):
+            this_agent_model = model_name.get(
+                agent_name, model_name.get("default", "gpt-4")
+            )
+ elif isinstance(model_name, list):
+ this_agent_model = model_name[idx]
+ else:
+ this_agent_model = model_name
+
+ agent = LLMAgent(
+ agent_name=f"{profile.first_name}{' ' + profile.last_name if profile.last_name else ''}",
+ agent_profile=profile,
+ model_name=this_agent_model,
+ strict_action_constraint=True,
+ custom_template=filled_template,
+ )
+ agent.goal = env_profile.agent_goals[idx]
+ agents.append(agent)
+ return agents
+
+
+def prepare_scenario(
+ env_model_name: str,
+ agent_model_name: str | Dict[str, str] | List[str],
+ config: Dict[str, Any] | None = None,
+) -> tuple[SocialDeductionGame, List[LLMAgent]]:
+ """Load config and create profiles."""
+ if config is None:
+ config = load_config(CONFIG_PATH)
+
+ # Create agent profiles
+ agent_profiles = []
+ agent_goals = []
+ for entry in config.get("agents", []):
+ profile = ensure_agent_profile(entry)
+ agent_profiles.append(profile)
+
+ role_goal = config.get("role_goals", {}).get(entry.get("role", ""), "")
+ agent_goals.append(role_goal)
+
+ # Create environment profile
+ agent_names = [entry.get("name", "") for entry in config.get("agents", [])]
+ scenario = config.get("description", "Werewolves game").format(
+ agent_names=", ".join(agent_names)
+ )
+ env_profile = EnvironmentProfile(
+ scenario=scenario,
+ relationship=RelationshipType.acquaintance,
+ agent_goals=agent_goals,
+ tag="werewolves",
+ )
+ env_profile.save()
+
+ env = create_environment(env_profile, env_model_name, config)
+ agents = create_agents(agent_profiles, env_profile, agent_model_name, config)
+ return env, agents
+
+
+def print_roster(config: Dict[str, Any]) -> None:
+ """Print game roster."""
+ print("Participants & roles:")
+ for entry in config.get("agents", []):
+ name = entry.get("name", "Unknown")
+ role = entry.get("role", "Unknown")
+ print(f" - {name}: {role}")
+
+
+# ============================================================================
+# Main
+# ============================================================================
+
+
+def get_model_names(config: Dict[str, Any]) -> Dict[str, str]:
+ """Extract model names from config. Enforces strict requirement."""
+ model_map = {}
+ for entry in config.get("agents", []):
+ name = entry.get("name")
+ model = entry.get("agent_model")
+ if not name:
+ continue
+ if not model:
+ raise ValueError(
+ f"Agent '{name}' missing 'agent_model' in config configuration."
+ )
+ model_map[name] = model
+ return model_map
+
+
+async def main() -> None:
+ """Run werewolf game."""
+ import argparse
+
+ parser = argparse.ArgumentParser(description="Run Werewolf game.")
+ parser.add_argument(
+ "--config",
+ type=str,
+ default=str(CONFIG_PATH),
+ help="Path to configuration file",
+ )
+ parser.add_argument(
+ "--roster",
+ type=str,
+ default=str(BASE_DIR / "roster.json"),
+ help="Path to roster file",
+ )
+ args = parser.parse_args()
+
+ config_path = args.config
+ roster_path = args.roster
+
+ config = load_config(config_path)
+ roster = load_config(roster_path)
+
+ # Merge roster into config
+ config["agents"] = roster.get("agents", [])
+
+ env_model_name = "gpt-4o"
+ agent_model_name = get_model_names(config)
+
+ # Setup
+ env, agents = prepare_scenario(env_model_name, agent_model_name, config)
+
+ # Display roster
+ # Config already loaded above
+ print("🌕 Duskmire Werewolves")
+ print("=" * 60)
+ print_roster(config)
+ print("=" * 60)
+
+ # Run game
+ await arun_one_episode(
+ env=env,
+ agent_list=agents,
+ omniscient=False,
+ script_like=False,
+ json_in_script=False,
+ tag="test_werewolves",
+ push_to_db=True,
+ )
+
+
+if __name__ == "__main__":
+ # Configure logging for standalone execution
+ LOG_FILE = BASE_DIR / "werewolves_game_debug.log"
+ _fh = logging.FileHandler(LOG_FILE, mode="w", encoding="utf-8")
+ _fh.setFormatter(
+ logging.Formatter("%(asctime)s %(levelname)-7s %(name)s - %(message)s")
+ )
+
+ _gen_logger.setLevel(logging.DEBUG)
+ _gen_logger.addHandler(_fh)
+
+ _env_logger.setLevel(logging.INFO)
+ _env_logger.addHandler(_fh)
+ _env_logger.addHandler(RichHandler())
+
+ asyncio.run(main())
diff --git a/examples/experimental/games/werewolves/roster.json b/examples/experimental/games/werewolves/roster.json
new file mode 100644
index 000000000..b3ad22b95
--- /dev/null
+++ b/examples/experimental/games/werewolves/roster.json
@@ -0,0 +1,40 @@
+{
+ "agents": [
+ {
+ "name": "Aurora",
+ "role": "Villager",
+ "team": "Villagers",
+ "agent_model": "gpt-4o"
+ },
+ {
+ "name": "Bram",
+ "role": "Werewolf",
+ "team": "Werewolves",
+ "agent_model": "gpt-4o"
+ },
+ {
+ "name": "Celeste",
+ "role": "Seer",
+ "team": "Villagers",
+ "agent_model": "gpt-4o"
+ },
+ {
+ "name": "Dorian",
+ "role": "Werewolf",
+ "team": "Werewolves",
+ "agent_model": "gpt-4o"
+ },
+ {
+ "name": "Elise",
+ "role": "Witch",
+ "team": "Villagers",
+ "agent_model": "gpt-4o"
+ },
+ {
+ "name": "Finn",
+ "role": "Villager",
+ "team": "Villagers",
+ "agent_model": "gpt-4o"
+ }
+ ]
+}
diff --git a/experiments/__init__.py b/experiments/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/experiments/calculate_elo.py b/experiments/calculate_elo.py
new file mode 100644
index 000000000..6173d1327
--- /dev/null
+++ b/experiments/calculate_elo.py
@@ -0,0 +1,571 @@
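+"""Compute Elo leaderboards (overall and per game) from tournament logs.
+
+Reads episode logs from the logs/ directory and writes CSV tables plus an HTML
+leaderboard into the experiments/ directory.
+"""
+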
+import json
+import glob
+import os
+import csv
+from collections import defaultdict
+from typing import Any, Tuple, Optional, Dict
+
+# Simple ELO implementation
+K_FACTOR = 32
+STARTING_ELO = 1200
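+
+# After each processed match, ratings are updated with the standard Elo rule:
+#   new_rating = old_rating + K_FACTOR * (actual_score - expected_score)
+# where actual_score is 1.0 for a win, 0.0 for a loss, and 0.5 for a draw.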
+
+
+def expected_score(rating_a: float, rating_b: float) -> float:
+ return 1 / (1 + 10 ** ((rating_b - rating_a) / 400))
+
+
+def generate_single_table_html(
+ title: str, stats: list[dict[str, Any]], show_split_elo: bool = True
+) -> str:
+ """Generates the HTML for a single leaderboard table."""
+ rows_html = ""
+ for item in stats:
+ rank = item["rank"]
+ rank_display = f"#{rank}"
+ if rank == 1:
+ rank_display = "🥇"
+ if rank == 2:
+ rank_display = "🥈"
+ if rank == 3:
+ rank_display = "🥉"
+
+ name = item["model"]
+ provider = "Unknown"
+ # Heuristic for provider
+ lower_name = name.lower()
+ if "gpt" in lower_name:
+ provider = "OpenAI"
+ elif "qwen" in lower_name:
+ provider = "Alibaba"
+ elif "gemini" in lower_name or "google" in lower_name:
+ provider = "Google"
+ elif "claude" in lower_name:
+ provider = "Anthropic"
+ elif "llama" in lower_name:
+ provider = "Meta"
+ elif "mistral" in lower_name:
+ provider = "Mistral"
+
+ wr_val = item["win_rate"]
+ wr_color = "#e55" if wr_val < 50 else "#2a9d8f"
+
+ split_elo_cells = ""
+ if show_split_elo:
+ split_elo_cells = f"""
+
{int(item['elo_w'])} |
+ {int(item['elo_v'])} |
+ """
+ else:
+ split_elo_cells = """
+ - |
+ - |
+ """
+
+ row = f"""
+
+ | {rank_display} |
+
+
+ {name}
+ {provider}
+
+ |
+ {int(item['elo'])} |
+ {split_elo_cells}
+ {item['win_rate']:.1f}% |
+ {item['matches']} |
+
+ """
+ rows_html += row
+
+ split_headers = ""
+ if show_split_elo:
+ split_headers = """
+ ELO-Alt (Wolf/Spy) |
+ ELO-Main (Vil/Civ) |
+ """
+ else:
+ split_headers = """
+ |
+ |
+ """
+
+ table_html = f"""
+
+
{title}
+
+
+
+ | Rank |
+ Model |
+ ELO |
+ {split_headers}
+ Win Rate |
+ Matches |
+
+
+
+ {rows_html}
+
+
+
+ """
+ return table_html
+
+
+def generate_html_report(tables_data: Dict[str, list[dict[str, Any]]]) -> str:
+ """
+ Generates the full HTML report with multiple tables.
+ tables_data: { "Title": stats_list, ... }
+ """
+
+ # Generate HTML for all tables
+ all_tables_html = ""
+
+ # Ensure "Overall" comes first if present
+ if "Overall" in tables_data:
+ all_tables_html += generate_single_table_html(
+ "Overall Leaderboard", tables_data["Overall"], show_split_elo=True
+ )
+
+ sorted_titles = sorted([t for t in tables_data.keys() if t != "Overall"])
+ for title in sorted_titles:
+ # Determine if we should show split ELO
+ # Symmetric games: RPS, Prisoners Dilemma -> No split
+ lower_title = title.lower()
+ is_symmetric = "rock" in lower_title or "prisoner" in lower_title
+ show_split = not is_symmetric
+
+ all_tables_html += generate_single_table_html(
+ f"{title} Leaderboard", tables_data[title], show_split_elo=show_split
+ )
+
+ html_template = f"""
+
+
+
+
+
+ Elo Leaderboard
+
+
+
+ Social Games Tournament Results
+
+ {all_tables_html}
+
+
+
+
+ """
+ return html_template
+
+
+def get_match_result(
+ model_mapping: dict[str, str],
+ agent_rewards: dict[str, float],
+ alt_model: str,
+ main_model: str,
+) -> Optional[Tuple[bool, float, float]]:
+ """
+ Determines if the Alt model won against the Main model.
+ Returns (alt_won: bool, alt_reward, main_reward) or None if inconclusive.
+ """
+
+    # Pick one representative agent per model and compare their rewards;
+    # team members receive the same reward in these games, so one agent suffices.
+
+ alt_agent = None
+ main_agent = None
+
+ for agent, model in model_mapping.items():
+ if model == alt_model and alt_agent is None:
+ alt_agent = agent
+ elif model == main_model and main_agent is None:
+ main_agent = agent
+
+ if alt_agent and main_agent:
+ break
+
+ if not alt_agent or not main_agent:
+ return None
+
+ r_alt = agent_rewards.get(alt_agent, 0.0)
+ r_main = agent_rewards.get(main_agent, 0.0)
+
+ return (r_alt > r_main), r_alt, r_main
+
+
+def process_logs(log_files: list[str]) -> list[dict[str, Any]]:
+ """
+ Process a list of log files and return stats for the leaderboard.
+ """
+ elo_overall: dict[str, float] = defaultdict(lambda: STARTING_ELO)
+ elo_wolf: dict[str, float] = defaultdict(lambda: STARTING_ELO) # Alt role
+ elo_villager: dict[str, float] = defaultdict(lambda: STARTING_ELO) # Main role
+
+ wins: dict[str, int] = defaultdict(int)
+ total_games: dict[str, int] = defaultdict(int)
+
+ count = 0
+ for filepath in log_files:
+ try:
+ with open(filepath, "r") as f:
+ data = json.load(f)
+
+ model_mapping = data.get("model_mapping", {})
+ rewards = data.get("rewards", [])
+ metadata = data.get("metadata", {})
+
+ if not model_mapping or not rewards:
+ continue
+
+ parsed_rewards = []
+ for r in rewards:
+ if isinstance(r, (list, tuple)):
+ parsed_rewards.append(float(r[0]))
+ else:
+ parsed_rewards.append(float(r))
+
+ if len(model_mapping) != len(parsed_rewards):
+ continue
+
+ # Map Agent Name -> Reward
+ # Relies on implicit ordering of keys vs list.
+ # Sotopia seems to maintain this consistency.
+ agent_rewards = {}
+ for i, agent_name in enumerate(model_mapping.keys()):
+ agent_rewards[agent_name] = parsed_rewards[i]
+
+ # --- Dispatch Logic based on Metadata Keys ---
+
+ check_processed = False
+
+ # 1. Werewolves
+ if "Werewolves_model" in metadata and "Villagers_model" in metadata:
+ m_alt = metadata["Werewolves_model"]
+ m_main = metadata["Villagers_model"]
+
+ res = get_match_result(model_mapping, agent_rewards, m_alt, m_main)
+ if res:
+ alt_won, r_alt, r_main = res
+ score_alt = 1.0 if alt_won else 0.0
+ score_main = 1.0 - score_alt
+
+ # Updates
+ elo_overall[m_alt] += K_FACTOR * (
+ score_alt
+ - expected_score(elo_overall[m_alt], elo_overall[m_main])
+ )
+ elo_overall[m_main] += K_FACTOR * (
+ score_main
+ - expected_score(elo_overall[m_main], elo_overall[m_alt])
+ )
+
+ elo_wolf[m_alt] += K_FACTOR * (
+ score_alt
+ - expected_score(elo_wolf[m_alt], elo_villager[m_main])
+ )
+ elo_villager[m_main] += K_FACTOR * (
+ score_main
+ - expected_score(elo_villager[m_main], elo_wolf[m_alt])
+ )
+
+ winner = m_alt if alt_won else m_main
+ wins[winner] += 1
+ total_games[m_alt] += 1
+ if m_alt != m_main:
+ total_games[m_main] += 1
+ check_processed = True
+
+ # 2. Spyfall
+ elif "Spy_model" in metadata and "Non-Spies_model" in metadata:
+ m_alt = metadata["Spy_model"]
+ m_main = metadata["Non-Spies_model"]
+
+ res = get_match_result(model_mapping, agent_rewards, m_alt, m_main)
+ if res:
+ alt_won, r_alt, r_main = res
+ score_alt = 1.0 if alt_won else 0.0
+ score_main = 1.0 - score_alt
+
+ elo_overall[m_alt] += K_FACTOR * (
+ score_alt
+ - expected_score(elo_overall[m_alt], elo_overall[m_main])
+ )
+ elo_overall[m_main] += K_FACTOR * (
+ score_main
+ - expected_score(elo_overall[m_main], elo_overall[m_alt])
+ )
+
+ elo_wolf[m_alt] += K_FACTOR * (
+ score_alt
+ - expected_score(elo_wolf[m_alt], elo_villager[m_main])
+ )
+ elo_villager[m_main] += K_FACTOR * (
+ score_main
+ - expected_score(elo_villager[m_main], elo_wolf[m_alt])
+ )
+
+ winner = m_alt if alt_won else m_main
+ wins[winner] += 1
+ total_games[m_alt] += 1
+ if m_alt != m_main:
+ total_games[m_main] += 1
+ check_processed = True
+
+ # 3. Undercover
+ elif "Undercover_model" in metadata and "Civilians_model" in metadata:
+ m_alt = metadata["Undercover_model"]
+ m_main = metadata["Civilians_model"]
+
+ res = get_match_result(model_mapping, agent_rewards, m_alt, m_main)
+ if res:
+ alt_won, r_alt, r_main = res
+ score_alt = 1.0 if alt_won else 0.0
+ score_main = 1.0 - score_alt
+
+ elo_overall[m_alt] += K_FACTOR * (
+ score_alt
+ - expected_score(elo_overall[m_alt], elo_overall[m_main])
+ )
+ elo_overall[m_main] += K_FACTOR * (
+ score_main
+ - expected_score(elo_overall[m_main], elo_overall[m_alt])
+ )
+
+ elo_wolf[m_alt] += K_FACTOR * (
+ score_alt
+ - expected_score(elo_wolf[m_alt], elo_villager[m_main])
+ )
+ elo_villager[m_main] += K_FACTOR * (
+ score_main
+ - expected_score(elo_villager[m_main], elo_wolf[m_alt])
+ )
+
+ winner = m_alt if alt_won else m_main
+ wins[winner] += 1
+ total_games[m_alt] += 1
+ if m_alt != m_main:
+ total_games[m_main] += 1
+ check_processed = True
+
+ # 4. Symmetric Fallback
+ else:
+                # No role-specific metadata keys: assume a symmetric game (RPS, PD, etc.)
+                # and score the first two entries of model_mapping against each other.
+
+ agents = list(model_mapping.keys())
+ if len(agents) >= 2:
+                    # model_mapping ordering is used deliberately here; metadata is
+                    # not guaranteed to name both players for symmetric games.
+
+ a1, a2 = agents[0], agents[1]
+ m1, m2 = model_mapping[a1], model_mapping[a2]
+ r1, r2 = agent_rewards[a1], agent_rewards[a2]
+
+ if r1 > r2:
+ s1, s2 = 1.0, 0.0
+ elif r2 > r1:
+ s1, s2 = 0.0, 1.0
+ else:
+ s1, s2 = 0.5, 0.5
+
+ elo_overall[m1] += K_FACTOR * (
+ s1 - expected_score(elo_overall[m1], elo_overall[m2])
+ )
+ elo_overall[m2] += K_FACTOR * (
+ s2 - expected_score(elo_overall[m2], elo_overall[m1])
+ )
+
+ if s1 > s2:
+ wins[m1] += 1
+ elif s2 > s1:
+ wins[m2] += 1
+
+ total_games[m1] += 1
+ if m1 != m2:
+ total_games[m2] += 1
+ check_processed = True
+
+ if check_processed:
+ count += 1
+
+ except Exception:
+ # print(f"Error processing {filepath}: {e}")
+ continue
+
+ # Generate Stats List
+ sorted_models = sorted(
+ elo_overall.keys(), key=lambda m: elo_overall[m], reverse=True
+ )
+ stats_list = []
+
+ for rank, model in enumerate(sorted_models, 1):
+ n_games = total_games[model]
+ win_rate = (wins[model] / n_games * 100) if n_games > 0 else 0.0
+ display_name = model.split("@")[0].replace("custom/", "")
+
+ stats_list.append(
+ {
+ "rank": rank,
+ "model": display_name,
+ "elo": elo_overall[model],
+ "elo_w": elo_wolf[model],
+ "elo_v": elo_villager[model],
+ "win_rate": win_rate,
+ "matches": n_games,
+ }
+ )
+
+ return stats_list
+
+
+def save_to_csv(title: str, stats: list[dict[str, Any]]) -> None:
+ """Saves the stats list to a CSV file."""
+ # Sanitize title to filename
+ safe_title = title.lower().replace(" ", "_").replace("/", "_")
+ filename = os.path.join("experiments", f"elo_results_{safe_title}.csv")
+
+ headers = [
+ "Rank",
+ "Model",
+ "ELO",
+ "ELO-Alt (Wolf/Spy)",
+ "ELO-Main (Vil/Civ)",
+ "Win Rate",
+ "Matches",
+ ]
+
+ with open(filename, "w", newline="") as f:
+ writer = csv.writer(f)
+ writer.writerow(headers)
+
+ for item in stats:
+ writer.writerow(
+ [
+ item["rank"],
+ item["model"],
+ int(item["elo"]),
+ int(item["elo_w"]),
+ int(item["elo_v"]),
+ f"{item['win_rate']:.1f}%",
+ item["matches"],
+ ]
+ )
+ print(f"Generated CSV: {filename}")
+
+
+def calculate_elo(log_dir: str = "logs") -> None:
+ print(f"Calculating ELO from logs in: {log_dir}")
+
+ # Gather logs
+ log_files = glob.glob(os.path.join(log_dir, "*.json"))
+
+    # Process all logs without filtering.
+    filtered_logs = log_files
+
+    print(f"Found {len(filtered_logs)} log files")
+
+ # 1. Group logs by Game
+ logs_by_game: Dict[str, list[str]] = defaultdict(list)
+
+    for filepath in filtered_logs:
+ try:
+ with open(filepath, "r") as f:
+ header = json.load(f)
+ metadata = header.get("metadata", {})
+ game_name = metadata.get("game_name", "Unknown")
+
+ # Robust detection if metadata missing but keys present
+ if game_name == "Unknown":
+ # Fallback check based on keys
+ if "Werewolves_model" in metadata:
+ game_name = "Werewolves"
+ elif "Spy_model" in metadata:
+ game_name = "Spyfall"
+ elif "Undercover_model" in metadata:
+ game_name = "Undercover"
+ else:
+ # Fallback to env string
+ print("Unknown game name, falling back to env string")
+ env = header.get("environment", "")
+ if "werewolf" in env.lower():
+ game_name = "werewolves"
+ elif "spyfall" in env.lower():
+ game_name = "spyfall"
+ elif "prison" in env.lower():
+ game_name = "prisoners_dilemma"
+ elif "rock" in env.lower():
+ game_name = "rock_paper_scissors"
+
+ logs_by_game[game_name].append(filepath)
+ except Exception:
+ continue
+
+ # 2. Calculate Stats
+ all_tables_data = {}
+
+ # Overall
+ print("Processing Overall...")
+    overall_stats = process_logs(filtered_logs)
+ all_tables_data["Overall"] = overall_stats
+ save_to_csv("Overall", overall_stats)
+
+ # Per Game
+ for game_name, game_logs in logs_by_game.items():
+ if not game_name or game_name == "Unknown":
+ continue
+ print(f"Processing {game_name} ({len(game_logs)} games)...")
+ # Format title case
+ title = game_name.replace("_", " ").title()
+ stats = process_logs(game_logs)
+ all_tables_data[title] = stats
+ save_to_csv(title, stats)
+
+ # 3. Generate HTML
+ html_content = generate_html_report(all_tables_data)
+ output_html = os.path.join("experiments", "elo_leaderboard.html")
+ with open(output_html, "w") as f:
+ f.write(html_content)
+
+ print(f"\nSuccessfully generated {output_html}")
+ print("Tables generated for:", ", ".join(all_tables_data.keys()))
+
+
+if __name__ == "__main__":
+ calculate_elo()
diff --git a/experiments/generate_sim_rosters.py b/experiments/generate_sim_rosters.py
new file mode 100644
index 000000000..3c7624d69
--- /dev/null
+++ b/experiments/generate_sim_rosters.py
@@ -0,0 +1,283 @@
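+"""Generate randomized roster files for the Elo tournament.
+
+For every ordered pair of competing models and every episode index, a roster
+JSON file is written under experiments/rosters/<game>/, assigning models to
+teams (or alternating seats for symmetric games) and randomized display names.
+
+Example invocation (the model names are illustrative):
+    python experiments/generate_sim_rosters.py --game werewolves undercover \
+        --models gpt-4o gpt-4o-mini --episodes 4
+"""
+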
+import random
+import os
+import sys
+import itertools
+import json
+import argparse
+import glob
+
+# Add project root to path to allow imports
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
+
+from experiments.utils import load_roster_template
+
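+# Pool of common first names; each generated roster assigns agents display names
+# drawn from a shuffled copy of this list.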
+NAME_POOL = [
+ "James",
+ "Mary",
+ "Robert",
+ "Patricia",
+ "John",
+ "Jennifer",
+ "Michael",
+ "Linda",
+ "David",
+ "Elizabeth",
+ "William",
+ "Barbara",
+ "Richard",
+ "Susan",
+ "Joseph",
+ "Jessica",
+ "Thomas",
+ "Sarah",
+ "Charles",
+ "Karen",
+ "Christopher",
+ "Nancy",
+ "Daniel",
+ "Lisa",
+ "Matthew",
+ "Betty",
+ "Anthony",
+ "Margaret",
+ "Mark",
+ "Sandra",
+ "Donald",
+ "Ashley",
+ "Steven",
+ "Kimberly",
+ "Paul",
+ "Emily",
+ "Andrew",
+ "Donna",
+ "Joshua",
+ "Michelle",
+ "Kenneth",
+ "Dorothy",
+ "Kevin",
+ "Carol",
+ "Brian",
+ "Amanda",
+ "George",
+ "Melissa",
+ "Edward",
+ "Deborah",
+ "Ronald",
+ "Stephanie",
+ "Timothy",
+ "Rebecca",
+ "Jason",
+ "Sharon",
+ "Jeffrey",
+ "Laura",
+ "Ryan",
+ "Cynthia",
+ "Jacob",
+ "Kathleen",
+ "Gary",
+ "Amy",
+ "Nicholas",
+ "Shirley",
+ "Eric",
+ "Angela",
+ "Jonathan",
+ "Helen",
+ "Stephen",
+ "Anna",
+ "Larry",
+ "Brenda",
+ "Justin",
+ "Pamela",
+ "Scott",
+ "Nicole",
+ "Brandon",
+ "Emma",
+]
+
+
+def generate_rosters(
+ game_names: list[str],
+ models: list[str],
+ n_episodes: int = 6,
+ overwrite: bool = False,
+ challenger: str | None = None,
+) -> None:
+ """
+ Generate randomized roster files for ELO tournament.
+ """
+ if not isinstance(game_names, list):
+ game_names = [game_names]
+
+ print(f"Generating rosters for games: {game_names}")
+ print(f"Competitors: {models}")
+ print(f"Episodes per matchup: {n_episodes}")
+ print("\n" + "=" * 50)
+
+ for game_name in game_names:
+ print(f"Processing game: {game_name}")
+
+ # Verify template exists
+ try:
+ load_roster_template(game_name)
+ except Exception as e:
+ print(f" Error loading template for '{game_name}': {e}")
+ continue
+
+ if len(models) < 2:
+ print(f" Skipping {game_name}: Need at least 2 models.")
+ continue
+
+ # Output directory
+ roster_output_dir = os.path.abspath(
+ os.path.join(os.path.dirname(__file__), "rosters", game_name)
+ )
+ os.makedirs(roster_output_dir, exist_ok=True)
+ print(f" Output directory: {roster_output_dir}")
+
+        # No directory-level emptiness check: new rosters can be added incrementally across runs.
+
+ count = 0
+ for pair_idx, (model_a, model_b) in enumerate(
+ itertools.permutations(models, 2)
+ ):
+ # Challenger Mode Filter
+ if challenger:
+ if challenger not in (model_a, model_b):
+ continue
+ for i in range(n_episodes):
+                # Short aliases for the two models competing in this matchup
+                m1 = model_a
+                m2 = model_b
+
+ # Parse model name: take what's before '@', then take what's after the last '/'
+ sanitized_m1 = model_a.split("@")[0].split("/")[-1]
+ sanitized_m2 = model_b.split("@")[0].split("/")[-1]
+
+ filename = f"roster_{game_name}_match{pair_idx}_ep{i}_{sanitized_m1}_vs_{sanitized_m2}.json"
+ file_path = os.path.join(roster_output_dir, filename)
+
+                # Skip if an equivalent roster already exists and --overwrite is not set.
+                # Matching is done on the episode/model suffix (ep{i}_{m1}_vs_{m2}.json)
+                # rather than the full filename, so shifted match indices do not create duplicates.
+ semantic_suffix = f"ep{i}_{sanitized_m1}_vs_{sanitized_m2}.json"
+ existing_match = glob.glob(
+ os.path.join(roster_output_dir, f"*{semantic_suffix}")
+ )
+
+ if existing_match and not overwrite:
+ # found existing file for this pair+episode
+ continue
+
+ # Double check specific path (though glob should cover it)
+ if os.path.exists(file_path) and not overwrite:
+ continue
+
+ # 2. Assign Models
+ try:
+ current_config = load_roster_template(game_name)
+ except Exception as e:
+ print(f"Error loading template for '{game_name}': {e}")
+ continue
+
+ agents = current_config["agents"]
+
+ # Check teams
+ teams = {a.get("team") for a in agents if a.get("team")}
+ unique_teams = sorted([t for t in teams if t])
+
+ if len(unique_teams) == 2:
+ # Asymmetric
+ t1, t2 = unique_teams[0], unique_teams[1]
+ for agent in agents:
+ if agent.get("team") == t1:
+ agent["agent_model"] = m1
+ elif agent.get("team") == t2:
+ agent["agent_model"] = m2
+ else:
+ agent["agent_model"] = m1
+ else:
+ # Symmetric or single-team
+ for idx, agent in enumerate(agents):
+ agent["agent_model"] = m1 if idx % 2 == 0 else m2
+
+ # 3. Randomization / Permutation
+ if len(agents) == 2:
+ # Deterministic toggle for 2-player games to strictly balance speaking order
+ # i=0: [A, B], i=1: [B, A]
+ if i % 2 != 0:
+ agents.reverse()
+ else:
+ # Random shuffle for multi-player games (>2 agents)
+ random.shuffle(agents)
+ # Assign Names
+ name_pool = NAME_POOL.copy()
+ random.shuffle(name_pool)
+ for idx, agent in enumerate(agents):
+ agent["name"] = name_pool[idx]
+
+ # 4. Save
+ with open(file_path, "w") as f:
+ json.dump(current_config, f, indent=4)
+
+ count += 1
+
+ print(f" Generated {count} NEW rosters for {game_name}")
+
+ print("\nGeneration Complete.")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Generate Rosters for ELO Tournament")
+ parser.add_argument(
+ "--game",
+ nargs="+",
+ default=[
+ "werewolves",
+ "spyfall",
+ "prisoners_dilemma",
+ "rock_paper_scissors",
+ "undercover",
+ ],
+ help="List of games",
+ )
+ parser.add_argument(
+ "--models",
+ nargs="+",
+ default=[
+ "gpt-4o-mini",
+ "gpt-4o",
+ "gpt-5",
+ "custom/google/gemma-3-1b@http://127.0.0.1:1234/v1",
+ "custom/qwen/qwen3-next-80b@http://127.0.0.1:1234/v1",
+ "custom/qwen/qwen3-4b-2507@http://127.0.0.1:1234/v1",
+ "custom/Qwen/Qwen3-8B@http://127.0.0.1:1235/v1",
+ ],
+ help="List of models to compete",
+ )
+ parser.add_argument(
+ "--episodes",
+ type=int,
+ default=6,
+ help="Number of episodes per matchup (default: 6)",
+ )
+ parser.add_argument(
+ "--overwrite",
+ action="store_true",
+ help="Allow generation even if output directory is not empty",
+ )
+ parser.add_argument(
+ "--challenger",
+ type=str,
+ default=None,
+ help="If set, only generate rosters involving this model",
+ )
+
+ args = parser.parse_args()
+
+ generate_rosters(
+ game_names=args.game,
+ models=args.models,
+ n_episodes=args.episodes,
+ overwrite=args.overwrite,
+ challenger=args.challenger,
+ )
diff --git a/experiments/run_elo_tournament.py b/experiments/run_elo_tournament.py
new file mode 100644
index 000000000..0c0795f8a
--- /dev/null
+++ b/experiments/run_elo_tournament.py
@@ -0,0 +1,239 @@
+import asyncio
+import logging
+import os
+import sys
+import json
+import glob
+from datetime import datetime
+from tqdm.asyncio import tqdm
+
+# Add project root to path to allow imports
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
+
+from sotopia.server import arun_one_episode
+from experiments.utils import get_game_module, load_game_config
+
+# Logger for this script
+logger = logging.getLogger(__name__)
+
+
+async def run_elo_tournament(
+ game_names: list[str],
+ tag: str = "elo_exp_v1",
+ push_to_db: bool = True,
+ concurrency_limit: int = 10,
+) -> None:
+ """
+ Run ELO tournament by executing pre-generated rosters found in experiments/rosters/.
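+
+ Example invocation (flags as defined in the CLI block below; values are illustrative):
+ python experiments/run_elo_tournament.py --game werewolves spyfall --tag elo_exp_v1 --concurrency 5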
+ """
+ if not isinstance(game_names, list):
+ game_names = [game_names]
+
+ print("Starting ELO Tournament Execution")
+ print(f"Target Games: {game_names}")
+ print(f"Concurrency: {concurrency_limit}")
+ print("\n" + "=" * 50)
+
+ semaphore = asyncio.Semaphore(concurrency_limit) # Limit concurrent API calls
+
+ # The worker coroutine is defined inside the per-game loop below so that it
+ # closes over the correct game module and scenario builder for each game.
+
+ for game_name in game_names:
+ print(f"Scanning rosters for game: {game_name}")
+
+ # 1. Scan EXISTING LOGS to identify played rosters
+ # We look for metadata["roster_file"] in the logs
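+ # Each episode log is expected to carry the roster it was run from, e.g.
+ # {"metadata": {"roster_file": "roster_werewolves_match0_ep0_gpt-4o_vs_gpt-5.json", ...}, ...}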
+ executed_rosters = set()
+ log_pattern = os.path.join("logs", "*.json")
+ existing_logs = glob.glob(log_pattern)
+
+ # We could restrict the scan to logs for this game, but loading every
+ # log's JSON is fast enough for a few thousand files.
+ for log_file in existing_logs:
+ try:
+ with open(log_file, "r") as f:
+ # A partial read would be faster, but a full JSON load is simple and safe here
+ data = json.load(f)
+ meta = data.get("metadata", {})
+ r_file = meta.get("roster_file")
+ if r_file:
+ executed_rosters.add(r_file)
+ except Exception as e:
+ print(f"Error reading log file {log_file}: {e}")
+ continue
+
+ print(
+ f" Found {len(executed_rosters)} already executed rosters (across all games)."
+ )
+
+ roster_dir = os.path.abspath(
+ os.path.join(os.path.dirname(__file__), "rosters", game_name)
+ )
+ roster_files = sorted(glob.glob(os.path.join(roster_dir, "*.json")))
+
+ if not roster_files:
+ print(f" No rosters found in {roster_dir}")
+ continue
+
+ print(f" Found {len(roster_files)} total rosters.")
+
+ # Filter out executed rosters
+ rosters_to_run = []
+ for r_path in roster_files:
+ r_filename = os.path.basename(r_path)
+ if r_filename in executed_rosters:
+ continue
+ rosters_to_run.append(r_path)
+
+ print(f" {len(rosters_to_run)} rosters remaining to execute.")
+
+ if not rosters_to_run:
+ continue
+
+ # Load Game Module
+ try:
+ game_module = get_game_module(game_name)
+ prepare_scenario = game_module.prepare_scenario
+ except Exception as e:
+ print(f" Error loading game module: {e}")
+ continue
+
+ async def _worker(roster_path: str) -> None:
+ async with semaphore:
+ filename = os.path.basename(roster_path)
+ # Re-checking executed_rosters here would guard against logs written
+ # mid-run, but that is not necessary at this scale.
+
+ try:
+ with open(roster_path, "r") as f:
+ roster_config = json.load(f)
+
+ base_config = load_game_config(game_name)
+ episode_config = base_config.copy()
+ episode_config.update(roster_config)
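+ # Shallow merge: roster keys (notably "agents") override the base game config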
+
+ agent_model_list = [
+ a["agent_model"] for a in episode_config["agents"]
+ ]
+ parts = filename.replace(".json", "").split("_")
+ match_str = "unknown"
+ for p in parts:
+ if p.startswith("match"):
+ match_str = p
+
+ agents_conf = episode_config["agents"]
+ teams = set(a.get("team") for a in agents_conf if a.get("team"))
+
+ unique_models = sorted(list(set(agent_model_list)))
+ model_a_log = (
+ unique_models[0] if len(unique_models) > 0 else "unknown"
+ )
+ model_b_log = (
+ unique_models[1] if len(unique_models) > 1 else "unknown"
+ )
+
+ metadata = {
+ "game_name": game_name,
+ "model_a": model_a_log,
+ "model_b": model_b_log,
+ "pair_idx": match_str,
+ "roster_file": filename,
+ }
+
+ # Add Team-specific model info (e.g. Civilians_model: gpt-4o)
+ if len(teams) > 1:
+ for team_name in teams:
+ # Find model(s) for this team
+ team_models = set(
+ a["agent_model"]
+ for a in agents_conf
+ if a.get("team") == team_name
+ )
+ if len(team_models) == 1:
+ metadata[f"{team_name}_model"] = list(team_models)[0]
+ else:
+ metadata[f"{team_name}_model"] = (
+ "mixed" # Should not happen
+ )
+
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
+ log_filename = (
+ f"episode_{tag}_{game_name}_{match_str}_{timestamp}.json"
+ )
+ log_path = os.path.join("logs", log_filename)
+
+ env, agents = prepare_scenario(
+ env_model_name="gpt-4o",
+ agent_model_name=agent_model_list,
+ config=episode_config,
+ )
+
+ os.makedirs("logs", exist_ok=True)
+
+ await arun_one_episode(
+ env=env,
+ agent_list=agents,
+ tag=tag,
+ push_to_db=push_to_db,
+ output_path=log_path,
+ metadata=metadata,
+ )
+ except Exception as e:
+ # tqdm redraws tend to swallow prints, so record failures via the module logger
+ logger.error(f"Error running roster {filename}: {e}")
+
+ # Create tasks
+ tasks = [asyncio.create_task(_worker(p)) for p in rosters_to_run]
+
+ # Use TQDM with asyncio
+ print(f" Queuing {len(tasks)} tasks...")
+ for f in tqdm(
+ asyncio.as_completed(tasks), total=len(tasks), desc=f"Playing {game_name}"
+ ):
+ await f
+
+ print("\nAll Scheduled Rosters Executed.")
+
+
+if __name__ == "__main__":
+ import argparse
+
+ # Reconfigure logging to suppress sotopia's verbose output
+ # 1. Root Logger
+ logging.basicConfig(level=logging.ERROR)
+
+ # 2. Silence sotopia experimental server and generation
+ server_logger = logging.getLogger("sotopia.experimental.server")
+ server_logger.setLevel(logging.ERROR)
+ generation_logger = logging.getLogger("sotopia.generation")
+ generation_logger.setLevel(logging.ERROR)
+
+ os.environ.setdefault("REDIS_OM_URL", "redis://:@localhost:6379")
+
+ parser = argparse.ArgumentParser(description="Run ELO Tournament Execution Phase")
+ parser.add_argument(
+ "--game",
+ nargs="+",
+ default=[
+ "werewolves",
+ "spyfall",
+ "prisoners_dilemma",
+ "rock_paper_scissors",
+ "undercover",
+ ],
+ help="List of games to execute rosters for",
+ )
+ parser.add_argument("--tag", type=str, default="elo_exp_v1", help="Experiment tag")
+ parser.add_argument(
+ "--concurrency", type=int, default=10, help="Max concurrent episodes"
+ )
+
+ args = parser.parse_args()
+
+ asyncio.run(
+ run_elo_tournament(
+ game_names=args.game, tag=args.tag, concurrency_limit=args.concurrency
+ )
+ )
diff --git a/experiments/utils.py b/experiments/utils.py
new file mode 100644
index 000000000..d72be66d2
--- /dev/null
+++ b/experiments/utils.py
@@ -0,0 +1,140 @@
+import json
+import os
+import importlib.util
+import sys
+from typing import Any, cast
+
+
+def get_game_module(game_name: str) -> Any:
+ """
+ Dynamically load the main module for a given game.
+ Assumes standard path: examples/experimental/games/{game_name}/main.py
+ """
+ project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+ module_path = os.path.join(
+ project_root, f"examples/experimental/games/{game_name}/main.py"
+ )
+
+ if not os.path.exists(module_path):
+ raise ValueError(f"Game module not found at {module_path}")
+
+ spec = importlib.util.spec_from_file_location(f"{game_name}_main", module_path)
+ if spec and spec.loader:
+ module = importlib.util.module_from_spec(spec)
+ sys.modules[f"{game_name}_main"] = module
+ spec.loader.exec_module(module)
+ return module
+ else:
+ raise ImportError(f"Could not load module from {module_path}")
+
+
+def load_roster_template(game_name: str) -> dict[str, Any]:
+ """Load the base roster.json for a game."""
+ project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+ roster_path = os.path.join(
+ project_root, f"examples/experimental/games/{game_name}/roster.json"
+ )
+
+ if not os.path.exists(roster_path):
+ raise FileNotFoundError(f"Roster file not found at {roster_path}")
+
+ with open(roster_path, "r") as f:
+ return cast(dict[str, Any], json.load(f))
+
+
+def load_game_config(game_name: str) -> dict[str, Any]:
+ """Load the base config.json for a game."""
+ project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+ config_path = os.path.join(
+ project_root, f"examples/experimental/games/{game_name}/config.json"
+ )
+
+ if os.path.exists(config_path):
+ with open(config_path, "r") as f:
+ return cast(dict[str, Any], json.load(f))
+ return {}
+
+
+def generate_roster(
+ game_name: str,
+ model_a: str,
+ model_b: str,
+ output_dir: str,
+ pair_id: str,
+ swap: bool = False,
+) -> str:
+ """
+ Generate a roster.json for a specific matchup.
+
+ Args:
+ game_name: Name of the game (e.g., 'werewolves')
+ model_a: Name of Model A
+ model_b: Name of Model B
+ output_dir: Directory to save the generated roster
+ pair_id: Identifier for the pair (e.g., 'match0')
+ swap: If True, swap the assignment (Model B takes Team 1/Slot 1, Model A takes Team 2/Slot 2)
+
+ Returns:
+ Path to the generated roster file.
+ """
+ # Load base config (rules, states, etc.)
+ base_config = load_game_config(game_name)
+
+ # Load roster template (agents list)
+ base_roster = load_roster_template(game_name)
+ agents = base_roster.get("agents", [])
+
+ if not agents:
+ raise ValueError("Roster must contain 'agents' list")
+
+ # Determine assignment logic
+ # Check if teams are present
+ teams = {a.get("team") for a in agents if a.get("team")}
+ # Filter out None if present
+ unique_teams = sorted([t for t in teams if t])
+
+ m1 = model_b if swap else model_a
+ m2 = model_a if swap else model_b
+
+ if len(unique_teams) == 2:
+ # Asymmetric / Team-based game (Werewolf, Spyfall)
+ team_1 = unique_teams[0] # e.g. "Civilians" or "Non-Spy" (alphabetical)
+ team_2 = unique_teams[1] # e.g. "Undercover" or "Spy"
+
+ for agent in agents:
+ team = agent.get("team")
+ if team == team_1:
+ agent["agent_model"] = m1
+ elif team == team_2:
+ agent["agent_model"] = m2
+ else:
+ # Agents without a team (neutral roles) default to model 1
+ agent["agent_model"] = m1
+
+ else:
+ # Symmetric / No-team game (PD, RPS)
+ for i, agent in enumerate(agents):
+ if i % 2 == 0:
+ agent["agent_model"] = m1
+ else:
+ agent["agent_model"] = m2
+
+ # Merge: Update agents in base_config
+ base_config["agents"] = agents
+
+ # If base_config was empty (no config.json), fall back to just roster structure
+ final_output = base_config if base_config else base_roster
+
+ # Save
+ os.makedirs(output_dir, exist_ok=True)
+ # Sanitize model names for filename
+ sanitized_m1 = m1.replace("/", "_").replace("@", "_").split("v1")[0][-10:]
+ sanitized_m2 = m2.replace("/", "_").replace("@", "_").split("v1")[0][-10:]
+
+ filename = f"roster_{game_name}_{pair_id}_{'swapped' if swap else 'normal'}_{sanitized_m1}_vs_{sanitized_m2}.json"
+ output_path = os.path.join(output_dir, filename)
+
+ with open(output_path, "w") as f:
+ json.dump(final_output, f, indent=4)
+
+ return output_path
diff --git a/sotopia/agents/base_agent.py b/sotopia/agents/base_agent.py
index 1454104e5..d80897251 100644
--- a/sotopia/agents/base_agent.py
+++ b/sotopia/agents/base_agent.py
@@ -19,14 +19,19 @@ def __init__(
MessengerMixin.__init__(self)
if agent_profile is not None:
self.profile = agent_profile
- self.agent_name = self.profile.first_name + " " + self.profile.last_name
+ self.agent_name = (
+ self.profile.first_name + " " + self.profile.last_name
+ ).strip()
+
elif uuid_str is not None:
# try retrieving profile from database
try:
self.profile = AgentProfile.get(pk=uuid_str)
except NotFoundError:
raise ValueError(f"Agent with uuid {uuid_str} not found in database")
- self.agent_name = self.profile.first_name + " " + self.profile.last_name
+ self.agent_name = (
+ self.profile.first_name + " " + self.profile.last_name
+ ).strip()
else:
assert (
agent_name is not None
diff --git a/sotopia/agents/llm_agent.py b/sotopia/agents/llm_agent.py
index 497954d7b..873f52d24 100644
--- a/sotopia/agents/llm_agent.py
+++ b/sotopia/agents/llm_agent.py
@@ -1,6 +1,6 @@
import asyncio
from concurrent.futures import ThreadPoolExecutor
-from typing import cast
+from typing import Any, cast
from sotopia.agents import BaseAgent
from sotopia.database import AgentProfile
@@ -8,6 +8,7 @@
agenerate_action,
agenerate_goal,
agenerate_script,
+ fill_template,
)
from sotopia.messages import AgentAction, Observation
from sotopia.messages.message_classes import ScriptBackground
@@ -28,6 +29,8 @@ def __init__(
agent_profile: AgentProfile | None = None,
model_name: str = "gpt-4o-mini",
script_like: bool = False,
+ strict_action_constraint: bool = False,
+ custom_template: str | None = None,
) -> None:
super().__init__(
agent_name=agent_name,
@@ -36,6 +39,9 @@ def __init__(
)
self.model_name = model_name
self.script_like = script_like
+ self.strict_action_constraint = strict_action_constraint
+ self.custom_template = custom_template
+ self.generation_history: list[dict[str, Any]] = []
@property
def goal(self) -> str:
@@ -68,7 +74,13 @@ async def aact(self, obs: Observation) -> AgentAction:
if len(obs.available_actions) == 1 and "none" in obs.available_actions:
return AgentAction(action_type="none", argument="")
else:
- action = await agenerate_action(
+ custom_template = self.custom_template
+ if custom_template:
+ custom_template = fill_template(
+ custom_template, action_instructions=obs.action_instruction
+ )
+
+ action, prompt, raw_response = await agenerate_action(
self.model_name,
history="\n".join(f"{y.to_natural_language()}" for x, y in self.inbox),
turn_number=obs.turn_number,
@@ -76,6 +88,18 @@ async def aact(self, obs: Observation) -> AgentAction:
agent=self.agent_name,
goal=self.goal,
script_like=self.script_like,
+ strict_action_constraint=self.strict_action_constraint,
+ custom_template=custom_template,
+ return_prompt_and_response=True,
+ )
+ self.generation_history.append(
+ {
+ "turn_number": obs.turn_number,
+ "prompt": prompt,
+ "response": raw_response,
+ "parsed_action": action.model_dump(),
+ "agent_name": self.agent_name,
+ }
)
# Temporary fix for mixtral-moe model for incorrect generation format
if "Mixtral-8x7B-Instruct-v0.1" in self.model_name:
diff --git a/sotopia/database/logs.py b/sotopia/database/logs.py
index cbbf13c4b..fdeea5591 100644
--- a/sotopia/database/logs.py
+++ b/sotopia/database/logs.py
@@ -8,7 +8,7 @@
from pydantic import model_validator, BaseModel
from redis_om import JsonModel
from redis_om.model.model import Field
-from typing import Literal
+from typing import Literal, Any
from sotopia.database.persistent_profile import AgentProfile
@@ -30,6 +30,7 @@ class BaseEpisodeLog(BaseModel):
reasoning: str = Field(default="")
rewards: list[tuple[float, dict[str, float]] | float] # Rewards arranged by turn
rewards_prompt: str = Field(default="")
+ metadata: dict[str, Any] = Field(default_factory=dict)
@model_validator(mode="after")
def agent_number_message_number_reward_number_turn_number_match(self) -> Self:
diff --git a/sotopia/database/persistent_profile.py b/sotopia/database/persistent_profile.py
index ab2f78fcb..23e1871e9 100644
--- a/sotopia/database/persistent_profile.py
+++ b/sotopia/database/persistent_profile.py
@@ -88,6 +88,10 @@ class BaseEnvironmentProfile(BaseModel):
agent_constraint: list[list[str]] | None = Field(
default_factory=lambda: None,
)
+ game_metadata: dict[str, Any] | None = Field(
+ default_factory=lambda: None,
+ description="Optional metadata for structured social games (rulebooks, config paths, etc.).",
+ )
tag: str = Field(
index=True,
default_factory=lambda: "",
diff --git a/sotopia/envs/__init__.py b/sotopia/envs/__init__.py
index fa56ad757..4cf6c0840 100644
--- a/sotopia/envs/__init__.py
+++ b/sotopia/envs/__init__.py
@@ -1,3 +1,4 @@
from .parallel import ParallelSotopiaEnv
+from .social_game import SocialDeductionGame, SocialGame
-__all__ = ["ParallelSotopiaEnv"]
+__all__ = ["ParallelSotopiaEnv", "SocialDeductionGame", "SocialGame"]
diff --git a/sotopia/envs/evaluators.py b/sotopia/envs/evaluators.py
index 4aeda1199..731f60670 100644
--- a/sotopia/envs/evaluators.py
+++ b/sotopia/envs/evaluators.py
@@ -1,7 +1,7 @@
import abc
import logging
from collections import defaultdict
-from typing import Generic, TypeVar
+from typing import Any, Generic, TypeVar
import gin
from pydantic import BaseModel, validate_call
@@ -33,17 +33,55 @@ def __init__(self) -> None:
@abc.abstractmethod
def __call__(
- self, turn_number: int, messages: list[tuple[str, Message]]
+ self, turn_number: int, messages: list[tuple[str, Message]], **kwargs: Any
) -> list[tuple[str, tuple[tuple[str, int | float | bool], str]]]:
raise NotImplementedError
@abc.abstractmethod
async def __acall__(
- self, turn_number: int, messages: list[tuple[str, Message]]
+ self, turn_number: int, messages: list[tuple[str, Message]], **kwargs: Any
) -> list[tuple[str, tuple[tuple[str, int | float | bool], str]]]:
raise NotImplementedError
+class SocialGameEndEvaluator(Evaluator):
+ """Base evaluator for social game win conditions.
+
+ Subclasses should implement _check_win_conditions() to check
+ game-specific win conditions using the environment state.
+ """
+
+ def __init__(self, max_turn_number: int = 100) -> None:
+ self.max_turn_number = max_turn_number
+
+ def __call__(
+ self, turn_number: int, messages: list[tuple[str, Message]], **kwargs: Any
+ ) -> list[tuple[str, tuple[tuple[str, int | float | bool], str]]]:
+ # Check turn limit
+ if turn_number >= self.max_turn_number:
+ return [("environment", (("terminated", True), "Max turns reached"))]
+
+ # Extract environment from kwargs
+ env = kwargs.get("env")
+ if not env:
+ return [("environment", (("terminated", False), ""))]
+
+ # Check game-specific win conditions
+ terminated, reason = self._check_win_conditions(env, turn_number, messages)
+ return [("environment", (("terminated", terminated), reason))]
+
+ async def __acall__(
+ self, turn_number: int, messages: list[tuple[str, Message]], **kwargs: Any
+ ) -> list[tuple[str, tuple[tuple[str, int | float | bool], str]]]:
+ return self.__call__(turn_number, messages, **kwargs)
+
+ def _check_win_conditions(
+ self, env: Any, turn_number: int, messages: list[tuple[str, Message]]
+ ) -> tuple[bool, str]:
+ """Check game-specific win conditions. Override in subclasses."""
+ return False, ""
+
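+# A hypothetical subclass sketch (concrete end evaluators are defined by each game's example code):
+# class WerewolfEndEvaluator(SocialGameEndEvaluator):
+#     def _check_win_conditions(self, env, turn_number, messages):
+#         wolves = [a for a in env.agents if env.get_agent_team(a) == "Werewolves" and env.agent_alive.get(a, False)]
+#         if not wolves:
+#             return True, "All werewolves eliminated: Villagers win"
+#         return False, ""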
+
class RuleBasedTerminatedEvaluator(Evaluator):
def __init__(self, max_turn_number: int = 20, max_stale_turn: int = 2) -> None:
self.max_turn_number = max_turn_number
@@ -51,7 +89,7 @@ def __init__(self, max_turn_number: int = 20, max_stale_turn: int = 2) -> None:
@validate_call
def __call__(
- self, turn_number: int, messages: list[tuple[str, Message]]
+ self, turn_number: int, messages: list[tuple[str, Message]], **kwargs: Any
) -> list[tuple[str, tuple[tuple[str, int | float | bool], str]]]:
# Rule 1: If the conversation is too long, terminate the conversation
conversation_too_long = turn_number >= self.max_turn_number
@@ -72,10 +110,16 @@ def __call__(
latest_action_by_agent[speaker] = msg.action_type
# If we haven't observed any agent messages yet, do not terminate early
- if observed_agents:
+ env = kwargs.get("env")
+ if env:
+ all_agents = set(env.agents)
+ else:
+ all_agents = observed_agents
+
+ if all_agents:
num_active_agents = sum(
1
- for agent in observed_agents
+ for agent in all_agents
if latest_action_by_agent.get(agent, "speak") != "leave"
)
else:
@@ -109,9 +153,9 @@ def __call__(
]
async def __acall__(
- self, turn_number: int, messages: list[tuple[str, Message]]
+ self, turn_number: int, messages: list[tuple[str, Message]], **kwargs: Any
) -> list[tuple[str, tuple[tuple[str, int | float | bool], str]]]:
- return self(turn_number, messages)
+ return self(turn_number, messages, **kwargs)
class EpisodeLLMEvaluator(Evaluator, Generic[T_eval_dim]):
@@ -125,7 +169,7 @@ def __init__(
self.response_format_class = response_format_class
def __call__(
- self, turn_number: int, messages: list[tuple[str, Message]]
+ self, turn_number: int, messages: list[tuple[str, Message]], **kwargs: Any
) -> list[tuple[str, tuple[tuple[str, int | float | bool], str]]]:
raise NotImplementedError(
"ReachGoalLLMEvaluator is not implemented for synchronous evaluation"
@@ -139,6 +183,7 @@ async def __acall__(
messages: list[tuple[str, Message]] | None,
history: str = "",
temperature: float | None = 0.0,
+ **kwargs: Any,
) -> list[tuple[str, tuple[tuple[str, int | float | bool], str]]]:
# filter did nothing
if not history and messages:
@@ -249,7 +294,7 @@ def _reduce(
if len(scores) and "overall_score" not in responses_dict:
scores = [x for x in scores if x is not None]
reduced_dict["overall_score"] = sum(scores) / len(scores)
- comments = "\n".join([f"{k}: {v}" for k, v in comments_dict.items()])
+ comments = "\n".join([f"{k}: {v}" for k, v in comments_dict.items() if v])
return reduced_dict, comments
@@ -268,7 +313,8 @@ def unweighted_aggregate_evaluate(
defaultdict(list)
)
for response in responses:
- assert response[0] == "environment" or response[0].startswith("agent")
+ # Relaxed assertion: allow any key for agents, not just "agent_X"
+ # assert response[0] == "environment" or response[0].startswith("agent")
responses_dict[response[0]].append(response[1])
environment_responses: tuple[dict[str, float | int | bool], str] = ({}, "")
@@ -323,4 +369,8 @@ def unweighted_aggregate_evaluate(
if agent_2_responses != ({}, "")
else None,
comments=comments,
+ rewards={
+ k: (v[0]["overall_score"] if "overall_score" in v[0] else 0, v[0])
+ for k, v in agent_responses.items()
+ },
)
diff --git a/sotopia/envs/parallel.py b/sotopia/envs/parallel.py
index a03f6945c..2970662bd 100644
--- a/sotopia/envs/parallel.py
+++ b/sotopia/envs/parallel.py
@@ -136,6 +136,8 @@ def __init__(
uuid_str: str | None = None,
env_profile: EnvironmentProfile | None = None,
background_class: Optional[Type[TBackground]] = None,
+ hide_unknown: bool = False,
+ include_turn_marker: bool = True,
) -> None:
"""A sotopia environment for parallel agents.
@@ -149,6 +151,8 @@ def __init__(
self.background_class = ScriptBackground
else:
self.background_class = background_class
+ self.hide_unknown = hide_unknown
+ self.include_turn_marker = include_turn_marker
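+ # include_turn_marker=False suppresses the automatic "Turn #N" environment messages
+ # (SocialDeductionGame uses this and posts its own state announcements instead);
+ # hide_unknown is forwarded to the per-agent background rendering in reset().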
self.background = self.background_class(
scenario="",
agent_names=[],
@@ -186,6 +190,7 @@ def reset(
agents: Agents | None = None,
omniscient: bool = False,
lite: bool = False,
+ include_background_observations: bool = True,
) -> dict[str, Observation]:
"""Starting a new episode. Must be called before step().
@@ -195,6 +200,7 @@ def reset(
"partial_background_file" (str): Path to a json file which need to contain a ScriptBackground object. The backgound can be incompleted ("unknown" for missing parts), and the missing parts will be filled in by the environment.
"full_background_file" (str): Path to a json file which need to contain a ScriptBackground object. The backgound must be completed (no "unknown" for missing parts).
omniscient (bool, optional): Whether the agents know the other agent's goal. Defaults to False.
+ include_background_observations (bool, optional): Whether to include the background (Environment's message) in the observation. Defaults to True.
"""
super().__init__()
MessengerMixin.reset_inbox(self)
@@ -246,7 +252,7 @@ def reset(
# Lite mode - clear backgrounds
raw_background.agent_backgrounds = [""] * num_agents
- # Create final rendered background (works for 2+ agents)
+ # Create final rendered background
self.background = self.background_class(
scenario=render_text_for_environment(raw_background.scenario),
agent_names=raw_background.agent_names,
@@ -288,6 +294,7 @@ def reset(
agent_goals=[
render_text_for_agent(goal, i) for goal in hidden_goals
],
+ hide_unknown=self.hide_unknown,
)
agent_backgrounds.append(agent_background)
@@ -309,36 +316,36 @@ def reset(
else:
self.action_mask = [True for _ in self.agents]
- self.recv_message("Environment", self.background)
-
# Create observations for each agent
observations = {}
- for i, agent_name in enumerate(self.agents):
- agent_bg = agent_backgrounds[i]
- observations[agent_name] = Observation(
- last_turn=agent_bg.to_natural_language(),
- turn_number=0,
- available_actions=list(self.available_action_types)
- if self.action_mask[i]
- else ["none"],
- )
+ if include_background_observations:
+ self.recv_message("Environment", self.background)
+ for i, agent_name in enumerate(self.agents):
+ agent_bg = agent_backgrounds[i]
+ observations[agent_name] = Observation(
+ last_turn=agent_bg.to_natural_language(),
+ turn_number=0,
+ available_actions=list(self.available_action_types)
+ if self.action_mask[i]
+ else ["none"],
+ )
+ else:
+ for i, agent_name in enumerate(self.agents):
+ observations[agent_name] = Observation(
+ last_turn="",
+ turn_number=0,
+ available_actions=list(self.available_action_types)
+ if self.action_mask[i]
+ else ["none"],
+ )
return observations
- @validate_call
- def step(
+ def _process_incoming_actions(
self, actions: dict[str, AgentAction] | dict[str, dict[str, int | str]]
- ) -> tuple[
- dict[str, Observation],
- dict[str, float],
- dict[str, bool],
- dict[str, bool],
- dict[str, dict[Any, Any]],
- ]:
- # Time step ++
- self.turn_number += 1
-
- # For action sampled from action space, it needs to be converted into AgentAction
+ ) -> dict[str, AgentAction]:
+ """Normalize actions, apply mask, and record to history."""
+ # Normalize actions to AgentAction objects
complied_actions: dict[str, AgentAction] = {}
for key in actions.keys():
action = actions[key]
@@ -355,12 +362,53 @@ def step(
if not self.action_mask[idx]:
complied_actions[agent] = AgentAction(action_type="none", argument="")
- self.recv_message(
- "Environment", SimpleMessage(message=f"Turn #{self.turn_number}")
- )
+ if self.include_turn_marker:
+ self.recv_message(
+ "Environment", SimpleMessage(message=f"Turn #{self.turn_number}")
+ )
for agent, action in complied_actions.items():
- self.recv_message(agent, action)
+ # Only record actions from agents that are in turn
+ idx = self.agents.index(agent)
+ if self.action_mask[idx]:
+ self.recv_message(agent, action)
+
+ return complied_actions
+
+ async def _run_evaluators(self, evaluators: list[Evaluator]) -> Any:
+ """Run evaluators and aggregate results."""
+ return unweighted_aggregate_evaluate(
+ list(
+ itertools.chain(
+ *await asyncio.gather(
+ *[
+ evaluator.__acall__(
+ turn_number=self.turn_number,
+ messages=self.inbox,
+ env=self,
+ )
+ for evaluator in evaluators
+ ]
+ )
+ )
+ )
+ )
+
+ @validate_call
+ def step(
+ self, actions: dict[str, AgentAction] | dict[str, dict[str, int | str]]
+ ) -> tuple[
+ dict[str, Observation],
+ dict[str, float],
+ dict[str, bool],
+ dict[str, bool],
+ dict[str, dict[Any, Any]],
+ ]:
+ # Time step ++
+ self.turn_number += 1
+
+ complied_actions = self._process_incoming_actions(actions)
+ # Sync evaluation (not refactored to helper as it's sync vs async)
response = unweighted_aggregate_evaluate(
list(
itertools.chain(
@@ -404,9 +452,26 @@ def step(
{
agent_name: {
"comments": response.comments or "",
- "complete_rating": 0,
+ "complete_rating": (
+ response.rewards.get(f"agent_{i+1}", (0, {}))[0] # type: ignore[index]
+ if response.rewards
+ else (
+ (response.p1_rate if i == 0 else response.p2_rate)
+ if isinstance(response.p1_rate, (int, float))
+ and isinstance(response.p2_rate, (int, float))
+ else (
+ response.p1_rate[0]
+ if i == 0 and isinstance(response.p1_rate, tuple)
+ else (
+ response.p2_rate[0]
+ if i == 1 and isinstance(response.p2_rate, tuple)
+ else 0
+ )
+ )
+ )
+ ),
}
- for agent_name in self.agents
+ for i, agent_name in enumerate(self.agents)
},
)
@@ -422,61 +487,12 @@ async def astep(
# Time step ++
self.turn_number += 1
- # For action sampled from action space, it needs to be converted into AgentAction
- complied_actions: dict[str, AgentAction] = {}
- for key in actions.keys():
- action = actions[key]
- if isinstance(action, AgentAction):
- complied_actions[key] = action
- else:
- action["action_type"] = self.available_action_types[
- int(action["action_type"])
- ]
- complied_actions[key] = AgentAction.parse_obj(action)
-
- # Masking actions from agent that are in turn
- for idx, agent in enumerate(self.agents):
- if not self.action_mask[idx]:
- complied_actions[agent] = AgentAction(action_type="none", argument="")
+ complied_actions = self._process_incoming_actions(actions)
- self.recv_message(
- "Environment", SimpleMessage(message=f"Turn #{self.turn_number}")
- )
- for agent, action in complied_actions.items():
- self.recv_message(agent, action)
-
- response = unweighted_aggregate_evaluate(
- list(
- itertools.chain(
- *await asyncio.gather(
- *[
- evaluator.__acall__(
- turn_number=self.turn_number,
- messages=self.inbox,
- )
- for evaluator in self.evaluators
- ]
- )
- )
- )
- )
+ response = await self._run_evaluators(self.evaluators)
if response.terminated:
- terminal_response = unweighted_aggregate_evaluate(
- list(
- itertools.chain(
- *await asyncio.gather(
- *[
- evaluator.__acall__(
- turn_number=self.turn_number,
- messages=self.inbox,
- )
- for evaluator in self.terminal_evaluators
- ]
- )
- )
- )
- )
+ terminal_response = await self._run_evaluators(self.terminal_evaluators)
# incorporate terminal response into response
response.p1_rate = response.p1_rate or terminal_response.p1_rate
response.p2_rate = response.p2_rate or terminal_response.p2_rate
@@ -494,12 +510,32 @@ async def astep(
self.action_mask = [True for _ in self.agents]
obs = _actions_to_natural_language(complied_actions)
# Create info dictionary for all agents
info = {
agent_name: {
"comments": response.comments or "",
- "complete_rating": 0,
+ "complete_rating": (
+ response.rewards.get(f"agent_{i+1}", (0, {}))[0]
+ if response.rewards
+ else (
+ (response.p1_rate if i == 0 else response.p2_rate)
+ if isinstance(response.p1_rate, (int, float))
+ and isinstance(response.p2_rate, (int, float))
+ else (
+ response.p1_rate[0]
+ if i == 0 and isinstance(response.p1_rate, tuple)
+ else (
+ response.p2_rate[0]
+ if i == 1 and isinstance(response.p2_rate, tuple)
+ else 0
+ )
+ )
+ )
+ ),
}
- for agent_name in self.agents
+ for i, agent_name in enumerate(self.agents)
}
if response.terminated:
info["rewards_prompt"] = {
diff --git a/sotopia/envs/social_game.py b/sotopia/envs/social_game.py
new file mode 100644
index 000000000..af8da54ba
--- /dev/null
+++ b/sotopia/envs/social_game.py
@@ -0,0 +1,625 @@
+from __future__ import annotations
+
+import asyncio
+import itertools
+import json
+import logging
+import random
+from collections import defaultdict
+from abc import ABC, abstractmethod
+from pathlib import Path
+from typing import Any, Dict, List, Literal, cast
+
+from sotopia.envs.parallel import ParallelSotopiaEnv
+from sotopia.envs.evaluators import Evaluator, unweighted_aggregate_evaluate
+from sotopia.agents.llm_agent import Agents
+from sotopia.database import EnvironmentProfile
+from sotopia.messages import AgentAction, Observation, SimpleMessage, Message
+
+logger = logging.getLogger(__name__)
+
+SOCIAL_GAME_PROMPT_TEMPLATE = """
+Imagine you are playing the game as {agent}.
+
+Here is the description of the game: {description}
+
+Your ({agent}'s) goal: {goal}
+{secret}
+
+Here is the context of the interaction:
+{history}
+
+Your available action type(s): [{action_list}].
+{action_instructions}
+
+Please only generate a JSON string including the action type and the argument.
+Your action should follow the given format:
+{format_instructions}
+"""
+
+
+class SocialGame(ParallelSotopiaEnv, ABC):
+ """Abstract base class for social games.
+
+ Defines the interface for building state, handling transitions, and building observations.
+ """
+
+ def __init__(
+ self,
+ env_profile: EnvironmentProfile,
+ action_handler: ActionHandler | None = None,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(env_profile=env_profile, **kwargs)
+ self.action_handler = action_handler
+
+ @abstractmethod
+ def build_state(self, actions: Dict[str, AgentAction]) -> None:
+ """Update game state based on agent actions."""
+ pass
+
+ @abstractmethod
+ def state_transition(self) -> None:
+ """Handle state transitions (e.g., FSM updates)."""
+ pass
+
+ @abstractmethod
+ def build_observation(self) -> Dict[str, Observation]:
+ """Generate observations for each agent."""
+ pass
+
+ async def astep(
+ self, actions: Dict[str, AgentAction] | Dict[str, Dict[str, int | str]]
+ ) -> tuple[
+ Dict[str, Observation],
+ Dict[str, float],
+ Dict[str, bool],
+ Dict[str, bool],
+ Dict[str, Dict[Any, Any]],
+ ]:
+ """Process one step: record actions, update state, build observations."""
+ self.turn_number += 1
+
+ # 1. Normalize actions and record to history
+ normalized_actions = self._process_incoming_actions(actions)
+
+ # 2. Build State (Update internal state, process actions, check eliminations, update masks)
+ self.build_state(normalized_actions)
+
+ # 3. State Transition (Check conditions, move FSM)
+ self.state_transition()
+
+ # 4. Run evaluators on the post-transition state
+ evaluator_response = await self._run_evaluators(self.evaluators)
+
+ # 5. Build Observation (Generate what agents see)
+ observations = self.build_observation()
+
+ # 6. Set termination
+ terminated = {agent: evaluator_response.terminated for agent in self.agents}
+
+ # 7. Terminal evaluators
+ if evaluator_response.terminated and self.terminal_evaluators:
+ terminal_response = await self._run_evaluators(self.terminal_evaluators)
+ if evaluator_response.comments and terminal_response.comments:
+ evaluator_response.comments += terminal_response.comments
+ elif terminal_response.comments:
+ evaluator_response.comments = terminal_response.comments
+
+ rewards = {agent: 0.0 for agent in self.agents}
+ truncations = {agent: False for agent in self.agents}
+ info = {
+ agent: {
+ "comments": evaluator_response.comments or "",
+ "complete_rating": (
+ evaluator_response.rewards.get(f"agent_{i+1}", (0, {}))[0]
+ if evaluator_response.rewards
+ else 0
+ ),
+ }
+ for i, agent in enumerate(self.agents)
+ }
+
+ return observations, rewards, terminated, truncations, info
+
+ async def _run_evaluators(self, evaluators: list[Evaluator]) -> Any:
+ """Run evaluators and aggregate results"""
+ return unweighted_aggregate_evaluate(
+ list(
+ itertools.chain(
+ *await asyncio.gather(
+ *[
+ evaluator.__acall__(
+ turn_number=self.turn_number,
+ messages=self.inbox,
+ env=self,
+ )
+ for evaluator in evaluators
+ ]
+ )
+ )
+ )
+ )
+
+
+class ActionHandler(ABC):
+ """Abstract base class for handling game-specific actions."""
+
+ @abstractmethod
+ def handle_action(
+ self, env: SocialDeductionGame, agent_name: str, action: AgentAction
+ ) -> None:
+ """Handle a single action from an agent based on current state.
+
+ Args:
+ env: The game environment instance.
+ agent_name: The name of the agent performing the action.
+ action: The action object.
+ """
+ pass
+
+ def get_action_instruction(self, env: SocialDeductionGame, agent_name: str) -> str:
+ """Get specific action instructions for an agent based on current state.
+
+ Args:
+ env: The game environment instance.
+ agent_name: The name of the agent.
+
+ Returns:
+ A string containing instructions, or empty string.
+ """
+ return ""
+
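+# A hypothetical handler sketch; concrete handlers are defined by each game (state names below are illustrative):
+#
+# class VoteHandler(ActionHandler):
+#     def handle_action(self, env, agent_name, action):
+#         if env.current_state == "Day Vote" and action.action_type == "action":
+#             env.internal_state.setdefault("votes", {})[agent_name] = action.argument
+#
+#     def get_action_instruction(self, env, agent_name):
+#         if env.current_state == "Day Vote":
+#             return "Vote by giving the name of one alive player as the argument."
+#         return ""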
+
+class SocialDeductionGame(SocialGame):
+ """Environment for social deduction games with states, roles, and private information.
+
+ Adds to SocialGame:
+ - FSM states (Night, Day, etc.)
+ - Role/team system
+ - Alive/dead status
+ - Private information visibility
+ - State transitions
+ - Turn management (round-robin vs simultaneous)
+ - Global Environment notifications (bypassing visibility filters)
+ """
+
+ def __init__(
+ self,
+ env_profile: EnvironmentProfile,
+ *,
+ config_path: str | None = None,
+ config: Dict[str, Any] | None = None,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(env_profile=env_profile, include_turn_marker=False, **kwargs)
+
+ # Load game configuration
+ self._config_path = Path(config_path) if config_path else None
+ self._config: Dict[str, Any] = config if config else {}
+
+ # Agent message buffer
+ self.agent_message_buffer: Dict[str, List[str]] = defaultdict(list)
+
+ # Game state
+ self.current_state: str = ""
+ self.agent_to_role: Dict[str, str] = {} # Aurora -> Villager
+ self.role_to_team: Dict[
+ str, str
+ ] = {} # Seer -> Villagers, Werewolf -> Werewolves
+ self.agent_alive: Dict[str, bool] = {} # Aurora -> True
+ self.internal_state: Dict[str, Any] = {} # votes, targets, etc.
+
+ def _load_config(self) -> None:
+ """Load game configuration from JSON file if not already loaded."""
+ if self._config:
+ pass
+ elif self._config_path:
+ if not self._config_path.exists():
+ raise FileNotFoundError(f"Config not found: {self._config_path}")
+ self._config = json.loads(self._config_path.read_text())
+ else:
+ raise ValueError("Neither config nor config_path provided")
+
+ # Build role -> team mapping
+ self.role_to_team = {}
+ for agent_entry in self._config.get("agents", []):
+ role = agent_entry.get("role")
+ team = agent_entry.get("team")
+ if role and team:
+ self.role_to_team.setdefault(role, team)
+
+ async def astep(
+ self, actions: Dict[str, AgentAction] | Dict[str, Dict[str, int | str]]
+ ) -> tuple[
+ Dict[str, Observation],
+ Dict[str, float],
+ Dict[str, bool],
+ Dict[str, bool],
+ Dict[str, Dict[Any, Any]],
+ ]:
+ """Process one step: update state counters and delegate."""
+ # Update state turn counter
+ if not hasattr(self, "_state_turn_count"):
+ self._state_turn_count: Dict[str, int] = {}
+
+ if self.current_state not in self._state_turn_count:
+ self._state_turn_count[self.current_state] = 0
+
+ self._state_turn_count[self.current_state] += 1
+
+ # Call super().astep to get results
+ (
+ observations,
+ rewards,
+ terminated,
+ truncations,
+ info,
+ ) = await super().astep(actions)
+
+ # Log termination
+ if all(terminated.values()):
+ # Extract comments/reasons from info
+ first_agent = list(self.agents)[0]
+ reason = info.get(first_agent, {}).get("comments", "Unknown reason")
+
+ log_msg = f"Game Ends:\n{reason}\n"
+ for agent_name in self.agents:
+ agent_rating = info.get(agent_name, {}).get("complete_rating", 0)
+ log_msg += f"{agent_name}: {agent_rating}\n"
+ logger.info(log_msg)
+
+ return observations, rewards, terminated, truncations, info
+
+ def reset(
+ self,
+ seed: int | None = None,
+ options: Dict[str, str] | None = None,
+ agents: Agents | None = None,
+ omniscient: bool = False,
+ lite: bool = False,
+ include_background_observations: bool = True,
+ ) -> dict[str, Observation]:
+ """Reset environment and initialize game state."""
+ # Call parent reset
+ base_obs = super().reset(
+ seed=seed,
+ options=options,
+ agents=agents,
+ omniscient=omniscient,
+ include_background_observations=False,
+ )
+
+ # Load config
+ self._load_config()
+
+ # Map agent names to roles from config
+ self.agent_to_role = {}
+ for name in self.agents:
+ role = next(
+ (
+ a.get("role", "Unknown")
+ for a in self._config.get("agents", [])
+ if a.get("name") == name.strip()
+ ),
+ "Unknown",
+ )
+ self.agent_to_role[name] = role
+
+ # Initialize alive status and state
+ self.agent_alive = {name: True for name in self.agents}
+ self.current_state = self._config.get("initial_state", "Unknown")
+ self.internal_state = {}
+ self._state_turn_count = {self.current_state: 0}
+
+ # Send initial system message
+ initial_msg_content = f"[Game] State: {self.current_state}. The game begins!"
+ logger.info(initial_msg_content)
+ self.recv_message(
+ "Environment",
+ SimpleMessage(message=initial_msg_content),
+ )
+
+ # Initialize action mask for first turn based on state
+ self._update_action_mask()
+
+ # Update available actions based on game state
+ for agent_name in self.agents:
+ base_obs[agent_name].available_actions = self._get_available_actions(
+ agent_name
+ )
+ # Inject initial action instruction if handler is present
+ if self.action_handler:
+ instruction = self.action_handler.get_action_instruction(
+ self, agent_name
+ )
+ if instruction:
+ base_obs[agent_name].action_instruction = instruction
+
+ return base_obs
+
+ def build_state(self, actions: Dict[str, AgentAction]) -> None:
+ """Update game state based on agent actions."""
+ # 1. Process game-specific logic
+ self._process_actions(actions)
+
+ # 2. Check for eliminations
+ self._check_eliminations()
+
+ def state_transition(self) -> None:
+ """Handle state transitions."""
+ should_transition = self._should_transition_state()
+ logger.debug(
+ f"Turn {self.turn_number}: state={self.current_state}, should_transition={should_transition}, state_turn_count={getattr(self, '_state_turn_count', {})}"
+ )
+ if should_transition:
+ self._perform_transition_state()
+ logger.debug(f"Transitioned to {self.current_state}")
+
+ # Update action mask for next turn based on (potentially new) state
+ state_props = self._config.get("state_properties", {}).get(
+ self.current_state, {}
+ )
+ action_order = state_props.get("action_order", self.action_order)
+ logger.debug(
+ f"About to update mask - state={self.current_state}, action_order={action_order}"
+ )
+ self._update_action_mask()
+ logger.debug(f"After update_action_mask - mask={self.action_mask}")
+
+ def build_observation(self) -> Dict[str, Observation]:
+ """Generate observations for each agent."""
+ return self._build_observations()
+
+ def _process_actions(self, actions: Dict[str, AgentAction]) -> None:
+ """Process actions by delegating to action_handler."""
+ if self.action_handler:
+ for agent_name, action in actions.items():
+ self.action_handler.handle_action(self, agent_name, action)
+
+ def _check_eliminations(self) -> None:
+ """Check if anyone should be eliminated (voted out, killed, etc.)."""
+ # Example: tally votes and eliminate player with most votes
+ pass
+
+ def _update_action_mask(self) -> None:
+ """Update action mask for next turn based on state configuration."""
+ # Get action_order for this state from config, or use environment default
+ state_props = self._config.get("state_properties", {}).get(
+ self.current_state, {}
+ )
+ action_order = state_props.get("action_order", self.action_order)
+ acting_roles = state_props.get("acting_roles", [])
+
+ # Determine which agents are eligible to act in this state
+ if acting_roles:
+ # Only agents with specific roles can act
+ eligible_indices = [
+ idx
+ for idx, agent_name in enumerate(self.agents)
+ if self.agent_alive.get(agent_name, False)
+ and self.agent_to_role.get(agent_name, "") in acting_roles
+ ]
+ else:
+ # All alive agents can act
+ eligible_indices = [
+ idx
+ for idx, agent_name in enumerate(self.agents)
+ if self.agent_alive.get(agent_name, False)
+ ]
+
+ # Update action mask based on action order
+ self.action_mask = [False for _ in self.agents]
+
+ if not eligible_indices:
+ # No eligible agents - keep all masks False
+ return
+
+ if action_order == "round-robin":
+ # Cycle through eligible agents only
+ if not hasattr(self, "_round_robin_idx"):
+ self._round_robin_idx = 0
+ # Get next eligible agent
+ acting_idx = eligible_indices[self._round_robin_idx % len(eligible_indices)]
+ self.action_mask[acting_idx] = True
+ self._round_robin_idx += 1
+ elif action_order == "random":
+ # Pick random eligible agent
+ acting_idx = random.choice(eligible_indices)
+ self.action_mask[acting_idx] = True
+ else:
+ # Simultaneous: all eligible agents can act
+ for idx in eligible_indices:
+ self.action_mask[idx] = True
+
+ def _should_transition_state(self) -> bool:
+ """Check if we should move to next state based on how many agents have acted."""
+ state_props = self._config.get("state_properties", {}).get(
+ self.current_state, {}
+ )
+ acting_roles = state_props.get("acting_roles", [])
+ action_order = state_props.get("action_order", self.action_order)
+
+ turns_in_state = self._state_turn_count.get(self.current_state, 0)
+
+ # Determine how many agents should act in this state
+ if acting_roles:
+ # Only specific roles act - count them
+ num_acting_agents = sum(
+ 1
+ for agent in self.agents
+ if self.agent_alive.get(agent, False)
+ and self.agent_to_role.get(agent, "") in acting_roles
+ )
+ else:
+ # All alive agents act
+ num_acting_agents = sum(1 for alive in self.agent_alive.values() if alive)
+
+ # Transition logic based on action order
+ if action_order == "simultaneous":
+ # All agents act at once - transition after 1 turn
+ return turns_in_state >= 1
+ elif action_order in ["round-robin", "random"]:
+ # Each agent acts once - transition after N turns
+ return turns_in_state >= num_acting_agents
+
+ return False
+
+ def _perform_transition_state(self) -> None:
+ """Transition to next state based on FSM."""
+ state_transition = self._config.get("state_transition", {})
+ next_state = state_transition.get(self.current_state)
+
+ if next_state:
+ self.current_state = next_state
+ # Reset turn counter for the new state
+ if hasattr(self, "_state_turn_count"):
+ self._state_turn_count[self.current_state] = 0
+ # Reset round-robin counter for the new state
+ if hasattr(self, "_round_robin_idx"):
+ self._round_robin_idx = 0
+ self.recv_message(
+ "Environment",
+ SimpleMessage(message=f"[Game] Entering state: {self.current_state}"),
+ )
+ logger.info(f"{'-'* 50}\nTurn to {self.current_state}\n{'-'* 50}")
+
+ def _build_observations(self) -> Dict[str, Observation]:
+ """Build observations for each agent based on visibility rules."""
+ observations = {}
+
+ for agent_name in self.agents:
+ observations[agent_name] = self._get_observation(agent_name)
+
+ return observations
+
+ def recv_message(
+ self, sender: str, message: Message, receivers: List[str] | None = None
+ ) -> None:
+ """Receive a message and distribute it to agents based on visibility."""
+ super().recv_message(sender, message)
+
+ # Determine visibility for each agent
+ state_props = self._config.get("state_properties", {}).get(
+ self.current_state, {}
+ )
+ visibility = state_props.get("visibility", "public")
+
+ for agent_name in self.agents:
+ should_see = False
+
+ # Check for explicit receivers
+ if receivers is not None:
+ if agent_name in receivers:
+ should_see = True
+ elif visibility == "public":
+ should_see = True
+ elif visibility == "team":
+ sender_team = self.role_to_team.get(
+ self.agent_to_role.get(sender, ""), ""
+ )
+ viewer_team = self.role_to_team.get(
+ self.agent_to_role.get(agent_name, ""), ""
+ )
+ should_see = sender_team == viewer_team
+ elif visibility == "private":
+ should_see = sender == agent_name
+
+ # Environment messages should be public unless explicitly targeted
+ if sender == "Environment" and receivers is None:
+ should_see = True
+
+ if should_see:
+ if sender == "Environment":
+ self.agent_message_buffer[agent_name].append(
+ message.to_natural_language()
+ )
+ else:
+ self.agent_message_buffer[agent_name].append(
+ f"{sender}: {message.to_natural_language()}"
+ )
+
+ def _get_observation(self, agent_name: str) -> Observation:
+ """Get observation for a specific agent."""
+ # Get visible history from buffer
+ visible_history = "\n".join(self.agent_message_buffer[agent_name])
+
+ # Clear buffer after reading: Observation usually only sends new content; agent's memory handles accumulation.
+ self.agent_message_buffer[agent_name].clear()
+
+ # Get available actions
+ available_actions = self._get_available_actions(agent_name)
+
+ # Add specific action instructions if handler is present
+ action_instruction = ""
+ if self.action_handler:
+ instruction = self.action_handler.get_action_instruction(self, agent_name)
+ if instruction:
+ action_instruction = instruction
+
+ return Observation(
+ last_turn=visible_history if visible_history else "[No recent activity]",
+ turn_number=self.turn_number,
+ available_actions=available_actions,
+ action_instruction=action_instruction,
+ )
+
+ def _get_available_actions(
+ self, agent_name: str
+ ) -> List[Literal["none", "speak", "non-verbal communication", "action", "leave"]]:
+ """Get available actions for this agent based on state and role, restricted to allowed literals."""
+ if not self.agent_alive.get(agent_name, False):
+ return ["none"]
+
+ state_props = self._config.get("state_properties", {}).get(
+ self.current_state, {}
+ )
+ acting_roles = state_props.get("acting_roles", [])
+ actions = state_props.get("actions", ["speak"])
+
+ # If state restricts by role, check if this agent can act
+ if acting_roles:
+ agent_role = self.agent_to_role.get(agent_name, "")
+ if agent_role not in acting_roles:
+ return ["none"]
+
+ # Check action mask (for round-robin/random ordering)
+ if self.action_mask:
+ try:
+ agent_idx = self.agents.index(agent_name)
+ if not self.action_mask[agent_idx]:
+ return ["none"]
+ except ValueError:
+ pass # Should not happen if agent_name is valid
+
+ allowed = {
+ "none",
+ "speak",
+ "non-verbal communication",
+ "action",
+ "leave",
+ }
+ filtered = [a for a in actions if a in allowed] or ["none"]
+ return cast(
+ List[
+ Literal["none", "speak", "non-verbal communication", "action", "leave"]
+ ],
+ filtered,
+ )
+
+ def get_agent_role(self, agent_name: str) -> str:
+ """Get the role of an agent."""
+ return self.agent_to_role.get(agent_name, "Unknown")
+
+ def get_agent_team(self, agent_name: str) -> str:
+ """Get the team of an agent."""
+ role = self.get_agent_role(agent_name)
+ return self.role_to_team.get(role, "Unknown")
+
+
+def load_config(config_path: str | Path) -> Dict[str, Any]:
+ """Load game configuration from JSON file."""
+ path = Path(config_path)
+ if not path.exists():
+ raise FileNotFoundError(f"Config not found: {path}")
+ return cast(Dict[str, Any], json.loads(path.read_text()))
diff --git a/sotopia/generation_utils/generate.py b/sotopia/generation_utils/generate.py
index c46d6496f..06dc08c0c 100644
--- a/sotopia/generation_utils/generate.py
+++ b/sotopia/generation_utils/generate.py
@@ -1,5 +1,6 @@
import logging
import os
+import re
import json
from dataclasses import dataclass
from litellm import acompletion
@@ -8,7 +9,7 @@
from litellm.litellm_core_utils.get_supported_openai_params import (
get_supported_openai_params,
)
-from typing import Any, cast
+from typing import Any, cast, Literal, overload
import gin
@@ -48,6 +49,14 @@
)
console_handler.setFormatter(formatter)
+
+def fill_template(template: str, **kwargs: str) -> str:
+ """Fill template with kwargs, ignoring missing keys."""
+ for k, v in kwargs.items():
+ template = template.replace(f"{{{k}}}", v)
+ return template
+
+
# Add handler to logger
log.addHandler(console_handler)
@@ -106,13 +115,43 @@ async def format_bad_output(
# Parse format_instructions to get the schema
try:
schema = json.loads(format_instructions)
+
+ def _fix_schema(s: dict[str, Any]) -> None:
+ if s.get("type") == "array":
+ if "prefixItems" in s:
+ # OpenAI doesn't support prefixItems (tuple validation).
+ # Convert to items: {anyOf: [...]} to satisfy "items must be a schema object"
+ # This allows valid tuple elements but loses positional validation, which is acceptable for strict=False.
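+ # e.g. {"type": "array", "prefixItems": [{"type": "string"}, {"type": "integer"}]}
+ # becomes {"type": "array", "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]}}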
+ prefix_items = s.pop("prefixItems")
+ s["items"] = {"anyOf": prefix_items}
+
+ if "items" in s and isinstance(s["items"], dict):
+ _fix_schema(s["items"])
+ elif "items" in s and isinstance(s["items"], list):
+ # Should not happen after the fix above, but handle legacy cases if any
+ for item in s["items"]:
+ _fix_schema(item)
+ elif s.get("type") == "object":
+ if "properties" in s:
+ for prop in s["properties"].values():
+ _fix_schema(prop)
+ if "additionalProperties" in s and isinstance(
+ s["additionalProperties"], dict
+ ):
+ _fix_schema(s["additionalProperties"])
+ if "$defs" in s:
+ for def_schema in s["$defs"].values():
+ _fix_schema(def_schema)
+
+ _fix_schema(schema)
+
# Build proper json_schema response_format
completion_kwargs["response_format"] = {
"type": "json_schema",
"json_schema": {
"name": "reformatted_output",
"schema": schema,
- "strict": True,
+ "strict": False,
},
}
except json.JSONDecodeError:
@@ -126,10 +165,40 @@ async def format_bad_output(
response = await acompletion(**completion_kwargs)
reformatted_output = response.choices[0].message.content
assert isinstance(reformatted_output, str)
+ log.debug(f"Model: {model_name}")
+ log.debug(f"Prompt: {content}")
log.info(f"Reformated output: {reformatted_output}")
return reformatted_output
+@overload
+async def agenerate(
+ model_name: str,
+ template: str,
+ input_values: dict[str, str],
+ output_parser: OutputParser[OutputType],
+ temperature: TemperatureSetting | float | None | object = _TEMPERATURE_SENTINEL,
+ structured_output: bool = False,
+ bad_output_process_model: str | None = None,
+ use_fixed_model_version: bool = True,
+ return_prompt_and_response: Literal[False] = False,
+) -> OutputType: ...
+
+
+@overload
+async def agenerate(
+ model_name: str,
+ template: str,
+ input_values: dict[str, str],
+ output_parser: OutputParser[OutputType],
+ temperature: TemperatureSetting | float | None | object = _TEMPERATURE_SENTINEL,
+ structured_output: bool = False,
+ bad_output_process_model: str | None = None,
+ use_fixed_model_version: bool = True,
+ return_prompt_and_response: Literal[True] = ...,
+) -> tuple[OutputType, list[dict[str, str]], str]: ...
+
+
@gin.configurable
@validate_call
async def agenerate(
@@ -141,7 +210,8 @@ async def agenerate(
structured_output: bool = False,
bad_output_process_model: str | None = None,
use_fixed_model_version: bool = True,
-) -> OutputType:
+ return_prompt_and_response: bool = False,
+) -> OutputType | tuple[OutputType, list[dict[str, str]], str]:
"""Generate text using LiteLLM instead of Langchain."""
# Format template with input values
if "format_instructions" not in input_values:
@@ -265,9 +335,19 @@ async def _call_with_retry(completion_kwargs: dict[str, Any]) -> Any:
# Include agent name in logs if available
agent_name = input_values.get("agent", "")
log_prefix = f" [{agent_name}]" if agent_name else ""
- log.info(f"Generated result{log_prefix}: {result}")
+ log.debug(f"Model: {model_name}")
+ log.debug(f"Prompt: {messages}")
+ try:
+ clean_result = json.dumps(json.loads(result.strip()), ensure_ascii=False)
+ except Exception:
+ clean_result = result.replace("\n", "").strip()
+ clean_result = re.sub(r"\s+", " ", clean_result)
+ log.info(f"Generated result{log_prefix}: {clean_result}")
assert isinstance(result, str)
- return cast(OutputType, output_parser.parse(result))
+ parsed = cast(OutputType, output_parser.parse(result))
+ if return_prompt_and_response:
+ return parsed, messages, result
+ return parsed
messages = [{"role": "user", "content": template}]
@@ -303,10 +383,43 @@ async def _call_with_retry(completion_kwargs: dict[str, Any]) -> Any:
# Include agent name in logs if available
agent_name = input_values.get("agent", "")
log_prefix = f" [{agent_name}]" if agent_name else ""
- log.info(f"Generated result{log_prefix}: {parsed_result}")
+ log.debug(f"Model: {model_name}")
+ log.debug(f"Prompt: {messages}")
+ try:
+ clean_result = json.dumps(json.loads(result.strip()), ensure_ascii=False)
+ except Exception:
+ clean_result = result.replace("\n", "").strip()
+ clean_result = re.sub(r"\s+", " ", clean_result)
+ log.info(f"Generated result{log_prefix}: {clean_result}")
+ if return_prompt_and_response:
+ return parsed_result, messages, result
return parsed_result
+@overload
+async def agenerate_env_profile(
+ model_name: str,
+ inspiration_prompt: str = "asking my boyfriend to stop being friends with his ex",
+ examples: str = "",
+ temperature: TemperatureSetting | float | None | object = _TEMPERATURE_SENTINEL,
+ bad_output_process_model: str | None = None,
+ use_fixed_model_version: bool = True,
+ return_prompt_and_response: Literal[False] = False,
+) -> EnvironmentProfile: ...
+
+
+@overload
+async def agenerate_env_profile(
+ model_name: str,
+ inspiration_prompt: str = "asking my boyfriend to stop being friends with his ex",
+ examples: str = "",
+ temperature: TemperatureSetting | float | None | object = _TEMPERATURE_SENTINEL,
+ bad_output_process_model: str | None = None,
+ use_fixed_model_version: bool = True,
+ return_prompt_and_response: Literal[True] = ...,
+) -> tuple[EnvironmentProfile, list[dict[str, str]], str]: ...
+
+
@gin.configurable
@validate_call
async def agenerate_env_profile(
@@ -316,28 +429,71 @@ async def agenerate_env_profile(
temperature: TemperatureSetting | float | None | object = _TEMPERATURE_SENTINEL,
bad_output_process_model: str | None = None,
use_fixed_model_version: bool = True,
-) -> EnvironmentProfile:
+ return_prompt_and_response: bool = False,
+) -> EnvironmentProfile | tuple[EnvironmentProfile, list[dict[str, str]], str]:
"""
Using langchain to generate the background
"""
- return await agenerate(
- model_name=model_name,
- template="""Please generate scenarios and goals based on the examples below as well as the inspirational prompt, when creating the goals, try to find one point that both sides may not agree upon initially and need to collaboratively resolve it.
- Examples:
- {examples}
- Inspirational prompt: {inspiration_prompt}
- Please use the following format:
- {format_instructions}
- """,
- input_values=dict(
- inspiration_prompt=inspiration_prompt,
- examples=examples,
- ),
- output_parser=PydanticOutputParser(pydantic_object=EnvironmentProfile),
- temperature=temperature,
- bad_output_process_model=bad_output_process_model,
- use_fixed_model_version=use_fixed_model_version,
- )
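+ # The call is duplicated per branch so that a Literal flag is passed,
+ # keeping the agenerate overloads' return types precise for type checkers.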
+ if return_prompt_and_response:
+ return await agenerate(
+ model_name=model_name,
+ template="""Please generate scenarios and goals based on the examples below as well as the inspirational prompt, when creating the goals, try to find one point that both sides may not agree upon initially and need to collaboratively resolve it.
+ Examples:
+ {examples}
+ Inspirational prompt: {inspiration_prompt}
+ Please use the following format:
+ {format_instructions}
+ """,
+ input_values=dict(
+ inspiration_prompt=inspiration_prompt,
+ examples=examples,
+ ),
+ output_parser=PydanticOutputParser(pydantic_object=EnvironmentProfile),
+ temperature=temperature,
+ bad_output_process_model=bad_output_process_model,
+ use_fixed_model_version=use_fixed_model_version,
+ return_prompt_and_response=True,
+ )
+ else:
+ return await agenerate(
+ model_name=model_name,
+ template="""Please generate scenarios and goals based on the examples below as well as the inspirational prompt, when creating the goals, try to find one point that both sides may not agree upon initially and need to collaboratively resolve it.
+ Examples:
+ {examples}
+ Inspirational prompt: {inspiration_prompt}
+ Please use the following format:
+ {format_instructions}
+ """,
+ input_values=dict(
+ inspiration_prompt=inspiration_prompt,
+ examples=examples,
+ ),
+ output_parser=PydanticOutputParser(pydantic_object=EnvironmentProfile),
+ temperature=temperature,
+ bad_output_process_model=bad_output_process_model,
+ use_fixed_model_version=use_fixed_model_version,
+ return_prompt_and_response=False,
+ )
+
+
+@overload
+async def agenerate_relationship_profile(
+ model_name: str,
+ agents_profiles: list[str],
+ bad_output_process_model: str | None = None,
+ use_fixed_model_version: bool = True,
+ return_prompt_and_response: Literal[False] = False,
+) -> RelationshipProfile: ...
+
+
+@overload
+async def agenerate_relationship_profile(
+ model_name: str,
+ agents_profiles: list[str],
+ bad_output_process_model: str | None = None,
+ use_fixed_model_version: bool = True,
+ return_prompt_and_response: Literal[True] = ...,
+) -> tuple[RelationshipProfile, list[dict[str, str]], str]: ...
@validate_call
@@ -346,25 +502,80 @@ async def agenerate_relationship_profile(
agents_profiles: list[str],
bad_output_process_model: str | None = None,
use_fixed_model_version: bool = True,
-) -> tuple[RelationshipProfile, str]:
+ return_prompt_and_response: bool = False,
+) -> RelationshipProfile | tuple[RelationshipProfile, list[dict[str, str]], str]:
"""
Using langchain to generate the background
"""
agent_profile = "\n".join(agents_profiles)
- return await agenerate(
- model_name=model_name,
- template="""Please generate relationship between two agents based on the agents' profiles below. Note that you generate
- {agent_profile}
- Please use the following format:
- {format_instructions}
- """,
- input_values=dict(
- agent_profile=agent_profile,
- ),
- output_parser=PydanticOutputParser(pydantic_object=RelationshipProfile),
- bad_output_process_model=bad_output_process_model,
- use_fixed_model_version=use_fixed_model_version,
- )
+ if return_prompt_and_response:
+ return await agenerate(
+ model_name=model_name,
+ template="""Please generate relationship between two agents based on the agents' profiles below. Note that you generate
+ {agent_profile}
+ Please use the following format:
+ {format_instructions}
+ """,
+ input_values=dict(
+ agent_profile=agent_profile,
+ ),
+ output_parser=PydanticOutputParser(pydantic_object=RelationshipProfile),
+ bad_output_process_model=bad_output_process_model,
+ use_fixed_model_version=use_fixed_model_version,
+ return_prompt_and_response=True,
+ )
+ else:
+ return await agenerate(
+ model_name=model_name,
+ template="""Please generate relationship between two agents based on the agents' profiles below. Note that you generate
+ {agent_profile}
+ Please use the following format:
+ {format_instructions}
+ """,
+ input_values=dict(
+ agent_profile=agent_profile,
+ ),
+ output_parser=PydanticOutputParser(pydantic_object=RelationshipProfile),
+ bad_output_process_model=bad_output_process_model,
+ use_fixed_model_version=use_fixed_model_version,
+ return_prompt_and_response=False,
+ )
+
+
+@overload
+async def agenerate_action(
+ model_name: str,
+ history: str,
+ turn_number: int,
+ action_types: list[ActionType],
+ agent: str,
+ goal: str,
+ temperature: TemperatureSetting | float | None | object = _TEMPERATURE_SENTINEL,
+ script_like: bool = False,
+ bad_output_process_model: str | None = None,
+ use_fixed_model_version: bool = True,
+ strict_action_constraint: bool = False,
+ custom_template: str | None = None,
+ return_prompt_and_response: Literal[False] = False,
+) -> AgentAction: ...
+
+
+@overload
+async def agenerate_action(
+ model_name: str,
+ history: str,
+ turn_number: int,
+ action_types: list[ActionType],
+ agent: str,
+ goal: str,
+ temperature: TemperatureSetting | float | None | object = _TEMPERATURE_SENTINEL,
+ script_like: bool = False,
+ bad_output_process_model: str | None = None,
+ use_fixed_model_version: bool = True,
+ strict_action_constraint: bool = False,
+ custom_template: str | None = None,
+ return_prompt_and_response: Literal[True] = ...,
+) -> tuple[AgentAction, list[dict[str, str]], str]: ...
@gin.configurable
@@ -380,12 +591,21 @@ async def agenerate_action(
script_like: bool = False,
bad_output_process_model: str | None = None,
use_fixed_model_version: bool = True,
-) -> AgentAction:
+ strict_action_constraint: bool = False,
+ custom_template: str | None = None,
+ return_prompt_and_response: bool = False,
+) -> AgentAction | tuple[AgentAction, list[dict[str, str]], str]:
"""
Using langchain to generate an example episode
"""
try:
- if script_like:
+ if custom_template:
+ if script_like:
+ raise ValueError(
+ "script_like and custom_template are mutually exclusive"
+ )
+ template = custom_template
+ elif script_like:
# model as playwright
template = """
Now you are a famous playwright, your task is to continue writing one turn for agent {agent} under a given background and history to help {agent} reach social goal. Please continue the script based on the previous turns. You can only generate one turn at a time.
@@ -418,24 +638,83 @@ async def agenerate_action(
Your action should follow the given format:
{format_instructions}
"""
- return await agenerate(
- model_name=model_name,
- template=template,
- input_values=dict(
- agent=agent,
- turn_number=str(turn_number),
- history=history,
- action_list=" ".join(action_types),
- ),
- output_parser=PydanticOutputParser(pydantic_object=AgentAction),
- temperature=temperature,
- structured_output=True,
- bad_output_process_model=bad_output_process_model,
- use_fixed_model_version=use_fixed_model_version,
- )
+
+ # Create dynamic AgentAction model with restricted ActionType
+ if strict_action_constraint and action_types:
+ # Create a dynamic Literal for the allowed action types
+ # Use __getitem__ to dynamically create Literal from list of strings
+ DynamicActionType = Literal.__getitem__(tuple(action_types))
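+ # e.g. action_types == ["speak", "none"] -> Literal["speak", "none"]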
+
+ # Create a dynamic Pydantic model
+ from pydantic import create_model, Field
+
+ DynamicAgentAction = create_model(
+ "AgentAction",
+ action_type=(
+ DynamicActionType,
+ Field(
+ ...,
+ description="whether to speak at this turn or choose to not do anything",
+ ),
+ ),
+ argument=(
+ str,
+ Field(
+ ...,
+ description="the utterance if choose to speak, the expression or gesture if choose non-verbal communication, or the physical action if choose action",
+ ),
+ ),
+ __base__=AgentAction,
+ )
+
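+ # Parse with the narrowed model so actions outside `action_types` fail
+ # Pydantic validation instead of slipping through as generic AgentActions.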
+ output_parser_obj: PydanticOutputParser[Any] = PydanticOutputParser(
+ pydantic_object=DynamicAgentAction
+ )
+ else:
+ output_parser_obj = PydanticOutputParser(pydantic_object=AgentAction)
+
+ if return_prompt_and_response:
+ return await agenerate(
+ model_name=model_name,
+ template=template,
+ input_values=dict(
+ agent=agent,
+ turn_number=str(turn_number),
+ history=history,
+ action_list=" ".join(action_types),
+ goal=goal,
+ ),
+ output_parser=output_parser_obj,
+ temperature=temperature,
+ structured_output=True,
+ bad_output_process_model=bad_output_process_model,
+ use_fixed_model_version=use_fixed_model_version,
+ return_prompt_and_response=True,
+ )
+ else:
+ return await agenerate(
+ model_name=model_name,
+ template=template,
+ input_values=dict(
+ agent=agent,
+ turn_number=str(turn_number),
+ history=history,
+ action_list=" ".join(action_types),
+ goal=goal,
+ ),
+ output_parser=output_parser_obj,
+ temperature=temperature,
+ structured_output=True,
+ bad_output_process_model=bad_output_process_model,
+ use_fixed_model_version=use_fixed_model_version,
+ return_prompt_and_response=False,
+ )
except Exception as e:
log.warning(f"Failed to generate action due to {e}")
- return AgentAction(action_type="none", argument="")
+ action = AgentAction(action_type="none", argument="")
+ if return_prompt_and_response:
+ return action, [], str(e)
+ return action
@gin.configurable
@@ -459,7 +738,7 @@ async def agenerate_script(
"""
try:
if single_step:
- return await agenerate(
+ result = await agenerate(
model_name=model_name,
template="""Now you are a famous playwright, your task is to continue writing one turn for agent {agent} under a given background and history to help {agent} reach social goal. Please continue the script based on the previous turns. You can only generate one turn at a time.
@@ -477,7 +756,7 @@ async def agenerate_script(
history=history,
agent=agent_name,
),
- output_parser=ScriptOutputParser( # type: ignore[arg-type]
+ output_parser=ScriptOutputParser(
agent_names=agent_names,
background=background.to_natural_language(),
single_turn=True,
@@ -486,9 +765,10 @@ async def agenerate_script(
bad_output_process_model=bad_output_process_model,
use_fixed_model_version=use_fixed_model_version,
)
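+ # agenerate may now return (parsed, prompt, response); keep the cast so the
+ # result is typed as the script tuple callers expect.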
+ return cast(tuple[ScriptInteractionReturnType, str], result)
else:
- return await agenerate(
+ result = await agenerate(
model_name=model_name,
template="""
Please write the script between two characters based on their social goals with a maximum of 20 turns.
@@ -501,7 +781,7 @@ async def agenerate_script(
input_values=dict(
background=background.to_natural_language(),
),
- output_parser=ScriptOutputParser( # type: ignore[arg-type]
+ output_parser=ScriptOutputParser(
agent_names=agent_names,
background=background.to_natural_language(),
single_turn=False,
@@ -510,6 +790,7 @@ async def agenerate_script(
bad_output_process_model=bad_output_process_model,
use_fixed_model_version=use_fixed_model_version,
)
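+ # Same narrowing as above: cast back to the script tuple type.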
+ return cast(tuple[ScriptInteractionReturnType, str], result)
except Exception as e:
# TODO raise(e) # Maybe we do not want to return anything?
print(f"Exception in agenerate {e}")
diff --git a/sotopia/messages/message_classes.py b/sotopia/messages/message_classes.py
index bb8fbc435..b6e33da26 100644
--- a/sotopia/messages/message_classes.py
+++ b/sotopia/messages/message_classes.py
@@ -33,6 +33,9 @@ class Observation(Message):
last_turn: str = Field(description="the last turn of the conversation")
turn_number: int = Field(description="the turn number of the conversation")
available_actions: list[ActionType] = Field(description="the available actions")
+ action_instruction: str = Field(
+ default="", description="instruction for the action"
+ )
def to_natural_language(self) -> str:
if self.turn_number == 0:
@@ -46,6 +49,9 @@ class ScriptBackground(Message):
agent_names: list[str] = Field(description="names of all participants")
agent_backgrounds: list[str] = Field(description="backgrounds of all participants")
agent_goals: list[str] = Field(description="goals of all participants")
+ hide_unknown: bool = Field(
+ default=False, description="whether to hide unknown background/goals"
+ )
def to_natural_language(self) -> str:
# Format participant names naturally with "and" before the last name
@@ -62,12 +68,20 @@ def to_natural_language(self) -> str:
if any(self.agent_backgrounds):
backgrounds_text = ""
for name, background in zip(self.agent_names, self.agent_backgrounds):
- bg_text = background if background else "Unknown"
- backgrounds_text += f"{name}'s background: {bg_text}\n"
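+ # With hide_unknown set, skip placeholder entries entirely instead of
+ # printing "Unknown" for agents whose background is not revealed.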
+ if self.hide_unknown:
+ if background and background != "Unknown":
+ backgrounds_text += f"{name}'s background: {background}\n"
+ else:
+ bg_text = background if background else "Unknown"
+ backgrounds_text += f"{name}'s background: {bg_text}\n"
goals_text = ""
for name, goal in zip(self.agent_names, self.agent_goals):
- goals_text += f"{name}'s goal: {goal}\n"
+ if self.hide_unknown:
+ if goal and goal != "Unknown":
+ goals_text += f"{name}'s goal: {goal}\n"
+ else:
+ goals_text += f"{name}'s goal: {goal}\n"
return format_docstring(
f"""Here is the context of this interaction:
@@ -105,6 +119,9 @@ class ScriptEnvironmentResponse(Message):
comments: str | None = Field(
description="All of the comments supporting the termination and rating"
)
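+ # Values are either a single overall score or an (overall, per-dimension scores) tuple.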
+ rewards: dict[str, float | tuple[float, dict[str, float]]] | None = Field(
+ description="dictionary of rewards for all agents", default=None
+ )
def to_natural_language(self) -> str:
reason_to_stop = format_docstring(
diff --git a/sotopia/samplers/uniform_sampler.py b/sotopia/samplers/uniform_sampler.py
index d519eee0d..bf7a9e968 100644
--- a/sotopia/samplers/uniform_sampler.py
+++ b/sotopia/samplers/uniform_sampler.py
@@ -1,5 +1,5 @@
-import logging
import random
+import logging
from typing import Any, Generator, Type, TypeVar
from sotopia.agents.base_agent import BaseAgent
diff --git a/sotopia/server.py b/sotopia/server.py
index d3710f280..d34ec4f69 100644
--- a/sotopia/server.py
+++ b/sotopia/server.py
@@ -1,8 +1,9 @@
import asyncio
import itertools
+import json
import logging
import re
-from typing import Literal, Sequence, Type, AsyncGenerator, Union
+from typing import Literal, Sequence, Type, AsyncGenerator, Union, Any
import gin
from pydantic import validate_call
@@ -63,7 +64,6 @@ def run_sync_server(
else:
environment_messages = env.reset()
agents = Agents()
- # agents_model_names = [model_name_dict["agent1"], model_name_dict["agent2"]]
# derive agent keys like agent1, agent2, … agentN
agent_keys = sorted(k for k in model_name_dict if re.fullmatch(r"agent\d+", k))
agents_model_names = [model_name_dict[k] for k in agent_keys]
@@ -128,6 +128,8 @@ async def arun_one_episode(
episode_pk: str | None = None,
streaming: bool = False,
simulation_status: NonStreamingSimulationStatus | None = None,
+ output_path: str | None = None,
+ metadata: dict[str, Any] | None = None,
) -> Union[
list[tuple[str, str, Message]],
AsyncGenerator[list[list[tuple[str, str, Message]]], None],
@@ -159,6 +161,7 @@ async def generate_messages() -> (
while not done:
# gather agent messages
agent_messages: dict[str, AgentAction] = dict()
+
actions = await asyncio.gather(
*[
agents[agent_name].aact(environment_messages[agent_name])
@@ -206,7 +209,6 @@ async def generate_messages() -> (
environment=env.profile.pk,
agents=[agent.profile.pk for agent in agent_list],
tag=tag,
- # models=[env.model_name, agent_list[0].model_name, agent_list[1].model_name],
models=[env.model_name] + [agent.model_name for agent in agent_list],
messages=[
[(m[0], m[1], m[2].to_natural_language()) for m in messages_in_turn]
@@ -214,6 +216,7 @@ async def generate_messages() -> (
],
reasoning=info[env.agents[0]]["comments"],
rewards=[info[agent_name]["complete_rating"] for agent_name in env.agents],
+ metadata=metadata or {},
)
if streaming:
@@ -239,6 +242,53 @@ async def generate_messages() -> (
except Exception as e:
logging.error(f"Failed to save episode log: {e}")
+ if output_path:
+ try:
+ # Construct simplified log
+ model_mapping = {}
+ for agent_name, agent in agents.items():
+ model_mapping[agent_name] = agent.model_name
+
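+ # Agents that record their LLM calls (e.g. by requesting
+ # return_prompt_and_response from agenerate) expose a generation_history
+ # list with turn_number / prompt / response entries.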
+ turns_list = []
+ for agent_name, agent in agents.items():
+ if hasattr(agent, "generation_history"):
+ for entry in agent.generation_history:
+ # Flatten and clean up entry
+ clean_entry = {
+ "turn_number": entry.get("turn_number"),
+ "agent_name": agent_name,
+ "model_name": model_mapping.get(agent_name, "Unknown"),
+ "prompt": entry.get("prompt"),
+ "response": entry.get("response"),
+ }
+ turns_list.append(clean_entry)
+
+ # Sort by turn number
+ turns_list.sort(
+ key=lambda x: int(x["turn_number"])
+ if x["turn_number"] is not None
+ else -1
+ )
+
+ custom_log = {
+ "pk": epilog.pk,
+ "tag": tag,
+ "metadata": metadata or {},
+ "model_mapping": model_mapping,
+ "rewards": epilog.rewards,
+ "turns": turns_list,
+ }
+
+ with open(output_path, "w") as f:
+ # Use json.dumps with indent for readability
+ f.write(json.dumps(custom_log, indent=4))
+ except Exception as e:
+ import traceback
+
+ logging.error(
+ f"Failed to save episode log to file: {e}\n{traceback.format_exc()}"
+ )
+
if streaming:
return generate_messages()
else:
@@ -311,11 +361,6 @@ def get_agent_class(
),
],
}
- # agents_model_dict = {
- # agent_name: model_name
- # for agent_name, model_name in model_dict.items()
- # if agent_name.startswith("agent")
- # }
agent_keys = sorted(k for k in model_dict if re.fullmatch(r"agent\d+", k))
agent_models = [model_dict[k] for k in agent_keys]
@@ -491,25 +536,44 @@ async def aevaluate_one_episode(
)
)
)
- info: dict[str, dict[str, str | ScriptEnvironmentResponse | float | None]] = {
- episode.agents[0]: {
+ info: dict[str, dict[str, str | ScriptEnvironmentResponse | float | None]] = {}
+ for i, agent_name in enumerate(episode.agents):
+ # Try to find reward in various possible locations
+ rating: float | tuple[float, dict[str, float]] | None = 0
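+ # response.rewards may map real agent names or positional keys
+ # ("agent_1", "agent_2", ...) to a float or a (score, per-dimension) tuple.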
+ if response.rewards:
+ # Check for direct name match or numeric index match (agent_1, agent_2...)
+ if agent_name in response.rewards:
+ rating = response.rewards[agent_name]
+ elif f"agent_{i+1}" in response.rewards:
+ rating = response.rewards[f"agent_{i+1}"]
+
+ # Fallback to legacy p1_rate/p2_rate for 2-agent cases if rewards missing
+ if rating == 0 or rating is None:
+ if i == 0 and response.p1_rate is not None:
+ rating = response.p1_rate
+ elif i == 1 and response.p2_rate is not None:
+ rating = response.p2_rate
+
+ # Unpack (overall score, per-dimension scores) tuples down to the overall score
+ if isinstance(rating, tuple):
+ rating = rating[0]
+
+ info[agent_name] = {
"comments": response.comments or "",
- "complete_rating": response.p1_rate or 0, # type: ignore
- },
- episode.agents[1]: {
- "comments": response.comments or "",
- "complete_rating": response.p2_rate or 0, # type: ignore
- },
- }
+ "complete_rating": rating or 0,
+ }
+
assert isinstance(episode.models, list)
+ # Generic model list: [env_model, agent1_model, agent2_model...]
+ log_models = ([model] + episode.models[1:]) if episode.models else [model]
+
epilog = EpisodeLog(
environment=episode.environment,
agents=episode.agents,
tag=tag,
- models=[model, episode.models[1], episode.models[2]],
+ models=log_models,
messages=episode.messages,
- reasoning=str(info[episode.agents[0]]["comments"])
- + str(info[episode.agents[1]]["comments"]),
+ reasoning="\n".join([str(info[agent]["comments"]) for agent in episode.agents]),
rewards=[info[agent_name]["complete_rating"] for agent_name in episode.agents],
rewards_prompt="TBD",
)