diff --git a/agent/actor/__init__.py b/agent/actor/__init__.py new file mode 100644 index 0000000..47a374c --- /dev/null +++ b/agent/actor/__init__.py @@ -0,0 +1,5 @@ +"""Actor Agent components for action execution and validation.""" + +from .action_executor import ActionExecutor + +__all__ = ["ActionExecutor"] \ No newline at end of file diff --git a/agent/actor/action_executor.py b/agent/actor/action_executor.py new file mode 100644 index 0000000..baf751b --- /dev/null +++ b/agent/actor/action_executor.py @@ -0,0 +1,117 @@ +"""Action execution and validation for Actor Agent.""" + +from typing import Any, Dict, List + +from browser_env import Action + + +class ActionExecutor: + """Executes and validates actions for the Actor Agent.""" + + def __init__(self, action_set_tag: str) -> None: + self.action_set_tag = action_set_tag + self.execution_history: List[Dict[str, Any]] = [] + + def validate_action(self, action: Action) -> Dict[str, Any]: + """Validate an action format and content before execution. 
+ + Args: + action: The action to validate + + Returns: + Dictionary containing validation results + """ + try: + # Validate action format and content + validation_result = self._validate_action_format(action) + + # Store validation history + validation_record = { + "action": action, + "validation_passed": validation_result["valid"], + "validation_details": validation_result, + } + self.execution_history.append(validation_record) + + return { + "valid": validation_result["valid"], + "action": action, + "validation_details": validation_result, + } + + except Exception as e: + # Record failed validation + validation_record = { + "action": action, + "validation_passed": False, + "error": str(e), + } + self.execution_history.append(validation_record) + + return { + "valid": False, + "error": str(e), + "action": action, + "validation_details": {"valid": False, "error": str(e)}, + } + + def _validate_action_format(self, action: Action) -> Dict[str, Any]: + """Validate the format and content of an action. 
+ + Args: + action: The action to validate + + Returns: + Dictionary containing validation results + """ + required_fields = ["action_type"] + validation_result = { + "valid": True, + "missing_fields": [], + "invalid_fields": [], + "warnings": [], + } + + # Check required fields + for field in required_fields: + if field not in action: + validation_result["valid"] = False + validation_result["missing_fields"].append(field) + + # Validate action type + if "action_type" in action: + action_type = action["action_type"] + valid_types = [ + "CLICK", "TYPE", "SCROLL", "KEY_PRESS", "GOTO_URL", + "NEW_TAB", "PAGE_CLOSE", "GO_BACK", "GO_FORWARD", + "PAGE_FOCUS", "CLEAR", "UPLOAD", "STOP", "NONE", "HOVER" + ] + + if action_type not in valid_types: + validation_result["valid"] = False + validation_result["invalid_fields"].append(f"Invalid action_type: {action_type}") + + # Type-specific validations + if action_type == "TYPE" and "element_id" not in action: + validation_result["valid"] = False + validation_result["missing_fields"].append("element_id for TYPE action") + + if action_type == "CLICK" and "element_id" not in action: + validation_result["valid"] = False + validation_result["missing_fields"].append("element_id for CLICK action") + + if action_type == "SCROLL" and "direction" not in action: + validation_result["valid"] = False + validation_result["missing_fields"].append("direction for SCROLL action") + + # Check for potential issues (warnings) + if "element_id" in action: + element_id = action["element_id"] + if isinstance(element_id, str) and not element_id.strip(): + validation_result["warnings"].append("Empty element_id detected") + + return validation_result + + def reset_execution_history(self) -> None: + """Reset execution history for a new task.""" + self.execution_history.clear() diff --git a/agent/actor_agent.py b/agent/actor_agent.py new file mode 100644 index 0000000..81d7ed5 --- /dev/null +++ b/agent/actor_agent.py @@ -0,0 +1,156 @@ +"""Actor Agent for 
executing high-level intentions with specific browser actions.""" + +from typing import Any, Dict, List, Optional + +from PIL import Image + +from browser_env import Trajectory +from browser_env.utils import Observation +from llms import lm_config + +from agent import PromptAgent # Import existing PromptAgent +from .actor.action_executor import ActionExecutor + + +class ActorAgent(PromptAgent): + """Executes high-level intentions using specific browser actions. + + Extends the existing PromptAgent to work with high-level intentions from + the Planner Agent while maintaining compatibility with the existing codebase. + """ + + def __init__( + self, + action_set_tag: str, + lm_config: lm_config.LMConfig, + prompt_constructor, + captioning_fn=None, + ) -> None: + """Initialize Actor Agent with enhanced capabilities.""" + # Initialize parent PromptAgent with existing parameters + super().__init__( + action_set_tag=action_set_tag, + lm_config=lm_config, + prompt_constructor=prompt_constructor, + captioning_fn=captioning_fn, + ) + + # Initialize action executor for validation and tracking + self.action_executor = ActionExecutor(action_set_tag) + + # Track intention execution history + self.intention_history: List[Dict[str, Any]] = [] + + def execute_intention( + self, + intention: str, + current_observation: Observation, + trajectory: Trajectory, + meta_data: Optional[Dict[str, Any]] = None, + images: Optional[List[Image.Image]] = None, + ) -> Dict[str, Any]: + """Execute a high-level intention and generate specific actions. 
+ + Args: + intention: High-level intention from Planner Agent + current_observation: Current page observation + trajectory: Current execution trajectory + meta_data: Additional metadata for execution + images: Optional input images + + Returns: + Dictionary containing execution results + """ + # Record intention execution attempt + execution_record = { + "intention": intention, + "timestamp": None, # Would be set in actual implementation + "observation_before": current_observation, + } + + try: + # Create a simple intention message that works with the existing prompt system + intention_message = f"Execute browser actions to fulfill this intention: {intention}" + + # Use existing PromptAgent's next_action method with the intention message + try: + action = self.next_action( + trajectory=trajectory, + intent=intention_message, + meta_data=meta_data or {}, + images=images, + output_response=False, + ) + except Exception as next_action_error: + print(f"🎬 Actor Error: {str(next_action_error)[:200]}") + print(f"🎬 Error Type: {type(next_action_error).__name__}") + raise next_action_error + + # Extract LLM raw response from action + llm_response = action.get("raw_prediction", "No LLM response available") + + # Validate the generated action (execution will be handled externally) + validation_result = self.action_executor.validate_action(action) + + # Record validation results + execution_record.update({ + "generated_action": action, + "validation_result": validation_result, + "llm_response": llm_response, + # intention_fulfilled will be determined after actual execution + }) + + # Store in intention history + self.intention_history.append(execution_record) + + return { + "action": action, + "validation_result": validation_result, + "intention": intention, + # intention_fulfilled will be determined by actual browser execution + "intention_fulfilled": False, # Default to False, will be updated after execution + "execution_history_length": len(self.intention_history), + 
"llm_response": llm_response, + "response": f"LLM Response: {llm_response[:200]}{'...' if len(llm_response) > 200 else ''}", + } + + except Exception as e: + # Provide more detailed error information + error_details = str(e) + if "prompt_constructor" in error_details.lower(): + error_details += " (Prompt constructor issue)" + elif "next_action" in error_details.lower(): + error_details += " (next_action method failure)" + elif "traject" in error_details.lower(): + error_details += " (Trajectory processing issue)" + + # Record failed execution + execution_record.update({ + "error": error_details, + "exception_type": type(e).__name__, + }) + self.intention_history.append(execution_record) + + return { + "error": error_details, + "intention": intention, + "intention_fulfilled": False, + "exception_type": type(e).__name__, + "response": f"Execution failed: {error_details}", + } + + def reset_intention_history(self) -> None: + """Reset intention execution history for a new task.""" + self.intention_history.clear() + self.action_executor.reset_execution_history() + + def get_recent_intentions(self, count: int = 5) -> List[Dict[str, Any]]: + """Get the most recent intention executions. + + Args: + count: Number of recent intentions to return + + Returns: + List of recent intention execution records + """ + return self.intention_history[-count:] if self.intention_history else [] \ No newline at end of file diff --git a/agent/agent.py b/agent/agent.py index 5fbcba9..38d2235 100644 --- a/agent/agent.py +++ b/agent/agent.py @@ -116,7 +116,7 @@ def __init__( self.captioning_fn = captioning_fn # Check if the model is multimodal. 
- if ("gemini" in lm_config.model or "gpt-4" in lm_config.model and "vision" in lm_config.model) and type(prompt_constructor) == MultimodalCoTPromptConstructor: + if type(prompt_constructor) == MultimodalCoTPromptConstructor: self.multimodal_inputs = True else: self.multimodal_inputs = False @@ -132,9 +132,14 @@ def next_action( # Create page screenshot image for multimodal models. if self.multimodal_inputs: page_screenshot_arr = trajectory[-1]["observation"]["image"] - page_screenshot_img = Image.fromarray( - page_screenshot_arr - ) # size = (viewport_width, viewport_width) + if page_screenshot_arr is not None: + page_screenshot_img = Image.fromarray( + page_screenshot_arr + ) # size = (viewport_width, viewport_width) + else: + # Fallback: create empty image if image is None + print("WARNING: No page screenshot image found, creating empty image.") + page_screenshot_img = Image.new('RGB', (1280, 720), color='white') # Caption the input image, if provided. if images is not None and len(images) > 0: diff --git a/agent/context/__init__.py b/agent/context/__init__.py new file mode 100644 index 0000000..bd0cc91 --- /dev/null +++ b/agent/context/__init__.py @@ -0,0 +1,86 @@ +"""Context manager for agent execution history and progress tracking.""" + +from typing import Any, Dict, List + +from browser_env import Action +from browser_env.utils import Observation + + +class StateManager: + """Manages execution history for context awareness.""" + + def __init__(self) -> None: + self.observations: List[Observation] = [] + self.actions: List[Action] = [] + self.reflections: List[Dict[str, Any]] = [] + self.intentions: List[str] = [] + self.user_goal: str = "" + + def add_observation(self, observation: Observation) -> None: + """Add a new observation to the history.""" + self.observations.append(observation) + + def add_action(self, action: Action) -> None: + """Add a new action to the history.""" + self.actions.append(action) + + def add_reflection(self, reflection: Dict[str, 
Any]) -> None: + """Add a new reflection to the history.""" + self.reflections.append(reflection) + + def add_intention(self, intention: str) -> None: + """Add a new intention to the history.""" + self.intentions.append(intention) + + def get_all_observations(self) -> List[Observation]: + """Get all observations.""" + return self.observations + + def get_all_actions(self) -> List[Action]: + """Get all actions.""" + return self.actions + + def get_all_reflections(self) -> List[Dict[str, Any]]: + """Get all reflections.""" + return self.reflections + + def get_all_intentions(self) -> List[str]: + """Get all intentions.""" + return self.intentions + + def get_latest_observation(self) -> Observation: + """Get the most recent observation.""" + return self.observations[-1] if self.observations else None + + def get_latest_action(self) -> Action: + """Get the most recent action.""" + return self.actions[-1] if self.actions else None + + def get_history(self) -> Dict[str, Any]: + """Get complete execution history.""" + return { + "observations": self.observations, + "actions": self.actions, + "reflections": self.reflections, + "intentions": self.intentions, + "total_steps": len(self.actions), + "total_observations": len(self.observations), + "total_reflections": len(self.reflections), + "total_intentions": len(self.intentions), + } + + def set_user_goal(self, user_goal: str) -> None: + """Set the user goal for this task.""" + self.user_goal = user_goal + + def get_user_goal(self) -> str: + """Get the user goal for this task.""" + return self.user_goal + + def clear(self) -> None: + """Clear all history.""" + self.observations.clear() + self.actions.clear() + self.reflections.clear() + self.intentions.clear() + self.user_goal = "" \ No newline at end of file diff --git a/agent/context/summary_generator.py b/agent/context/summary_generator.py new file mode 100644 index 0000000..0a27465 --- /dev/null +++ b/agent/context/summary_generator.py @@ -0,0 +1,189 @@ +"""Context 
summary generation for maintaining agent awareness.""" + +from typing import Any, Dict, List + +from browser_env import Action +from browser_env.utils import Observation +from llms import lm_config, call_llm +from ..prompts.prompt_loader import load_prompt_template + + +class SummaryGenerator: + """Generates context summaries for agent coordination.""" + + def __init__(self, lm_config: lm_config.LMConfig) -> None: + self.lm_config = lm_config + + def generate_summary( + self, + user_goal: str, + observations: List[Observation], + actions: List[Action], + reflections: List[Dict[str, Any]], + ) -> str: + """Generate a comprehensive context summary. + + Args: + user_goal: Original user goal + observations: List of all observations + actions: List of all actions taken + reflections: List of all reflections + + Returns: + Tuple of (summary, observation_summary, action_summary, reflection_summary) strings + """ + # Extract key information + total_steps = len(actions) + + # Get recent context + recent_observations = observations[-2:] if len(observations) >= 2 else observations + recent_actions = actions[-3:] if len(actions) >= 3 else actions + recent_reflections = reflections[-2:] if len(reflections) >= 2 else reflections + + # Build context string + observation_summary = self._summarize_observations(recent_observations) + action_summary = self._summarize_actions(recent_actions) + reflection_summary = self._summarize_reflections(recent_reflections) + + # Generate summary using prompt template + prompt = load_prompt_template( + "context_agent", + "summary_generation", + user_goal=user_goal, + total_steps=total_steps, + observation_summary=observation_summary, + action_summary=action_summary, + reflection_summary=reflection_summary + ) + + try: + summary = call_llm( + self.lm_config, [{"role": "user", "content": prompt}] + ).strip() + except Exception as e: + # Fallback summary + print(f"caught exception {e}") + print("Context Agent call_llm failed, use fallback summary") + summary = f"""Task execution after {total_steps} 
steps. +Recent observations: {len(recent_observations)} pages viewed. +Recent actions: {len(recent_actions)} actions taken. +Recent reflections: {len(recent_reflections)} performance analyses.""" + + return summary, observation_summary, action_summary, reflection_summary + + def _summarize_observations(self, observations: List[Observation]) -> str: + """Summarize recent observations using LLM.""" + if not observations: + return "No page observations available." + + # Extract text content from recent observations + observation_texts = [] + for i, obs in enumerate(observations[-3:], 1): # Last 3 observations + text = obs.get("text", "") + if text: + # Truncate for brevity but keep meaningful content + truncated_text = text[:800] if len(text) > 800 else text + observation_texts.append(f"Page {i}: {truncated_text}") + + if not observation_texts: + return "No meaningful page content." + + # Use LLM to generate intelligent summary of observations + try: + prompt = load_prompt_template( + "context_agent", + "observation_summarization", + observations_text="\n\n".join(observation_texts) + ) + + summary = call_llm( + self.lm_config, [{"role": "user", "content": prompt}] + ).strip() + return summary + except Exception as e: + # Fallback to simple concatenation if LLM fails + print(f"Observation summarization LLM call failed: {e}") + return " | ".join(observation_texts) + + def _summarize_actions(self, actions: List[Action]) -> str: + """Summarize recent actions.""" + if not actions: + return "No actions taken yet." + + action_summaries = [] + for i, action in enumerate(actions[-5:], 1): # Last 5 actions + action_type = action.get("action_type", "unknown") + element_id = action.get("element_id", "N/A") + + if action_type == "TYPE": + text = action.get("text", [])[:3] # First 3 characters + action_summaries.append(f"{i}. Type[{element_id}]: {''.join(text)}...") + elif action_type == "CLICK": + action_summaries.append(f"{i}. 
Click[{element_id}]") + elif action_type == "SCROLL": + direction = action.get("direction", "unknown") + action_summaries.append(f"{i}. Scroll[{direction}]") + else: + action_summaries.append(f"{i}. {action_type}") + + return " | ".join(action_summaries) if action_summaries else "No recent actions." + + def _summarize_reflections(self, reflections: List[Dict[str, Any]]) -> str: + """Summarize recent reflections using LLM.""" + if not reflections: + return "No performance reflections available." + + # Extract text content from recent reflections + reflection_texts = [] + for i, reflection in enumerate(reflections[-3:], 1): # Last 3 reflections + text_parts = [] + + # Add reflection number and current intention + text_parts.append(f"Reflection {reflection.get('reflection_number', i)}:") + text_parts.append(f"Intention: {reflection.get('current_intention', 'N/A')}") + + # Add effectiveness analysis + effectiveness = reflection.get("effectiveness_analyzer", "") + if effectiveness: + text_parts.append(f"Effectiveness: {effectiveness}") + + # Add triple summary + triple_summary = reflection.get("triple_summary", "") + if triple_summary: + text_parts.append(f"State Transition: {triple_summary}") + + # Add pattern detection summary + pattern_detector = reflection.get("pattern_detector", {}) + if pattern_detector.get("patterns_detected", False): + detection_summary = pattern_detector.get("detection_summary", "") + if detection_summary: + text_parts.append(f"Patterns: {detection_summary}") + + # Add latest action summary + latest_action = reflection.get("latest_action", {}) + if latest_action: + action_type = latest_action.get("action_type", "UNKNOWN") + element_id = latest_action.get("element_id", "N/A") + text_parts.append(f"Action: {action_type} on {element_id}") + + reflection_texts.append("\n".join(text_parts)) + + if not reflection_texts: + return "No meaningful reflection content available." 
+ + # Use LLM to generate intelligent summary of reflections + try: + prompt = load_prompt_template( + "context_agent", + "reflection_summarization", + reflections_text="\n\n".join(reflection_texts) + ) + + summary = call_llm( + self.lm_config, [{"role": "user", "content": prompt}] + ).strip() + return summary + except Exception as e: + # Fallback to simple concatenation if LLM fails + print(f"Reflection summarization LLM call failed: {e}") + return " | ".join(reflection_texts[:2]) # Show first 2 reflections \ No newline at end of file diff --git a/agent/context_agent.py b/agent/context_agent.py new file mode 100644 index 0000000..b5ad8a6 --- /dev/null +++ b/agent/context_agent.py @@ -0,0 +1,252 @@ +"""Context Agent for managing global state and progress tracking.""" + +from typing import Any, Dict, List, Optional + +import torch + +from browser_env import Action, Trajectory +from browser_env.utils import Observation +from llms import lm_config + +from .context.summary_generator import SummaryGenerator +from .context import StateManager +from .prompts.prompt_loader import generate_llm_prompt_from_template +from .memory import MemoryBank, MemoryGenerator + + +class ContextAgent: + """Manages global state and context summarization. + + Responsible for maintaining task execution history and generating + comprehensive context summaries for other agents. 
+ """ + + def __init__(self, lm_config: lm_config.LMConfig, memory_config: Dict[str, Any]) -> None: + self.lm_config = lm_config + self.state_manager = StateManager() + self.summary_generator = SummaryGenerator(lm_config) + + + # Initialize memory system + self.enable_memory = memory_config.get("enable_memory", False) + self.enable_memory_store = memory_config.get("enable_memory_store", False) + self.memory_content = "" + if self.enable_memory or self.enable_memory_store: + device = torch.device("cuda") if torch.cuda.is_available() else "cpu" + self.memory_bank = MemoryBank( + memory_dir=memory_config.get('memory_dir', 'agent_memories'), + embedding_model=memory_config.get('embedding_model', 'sentence-transformers/all-MiniLM-L6-v2'), + top_k=memory_config.get('top_k', 3), + device=device + ) + self.memory_generator = MemoryGenerator(lm_config) + self.window_size = memory_config.get('window_size', 3) + + + def update_state(self, + current_observation: Optional[Observation] = None, + latest_intention: Optional[str] = None, + latest_action: Optional[Action] = None, + latest_reflection: Optional[Dict[str, Any]] = None, + ) -> None: + """Update context state.""" + if current_observation: + self.state_manager.add_observation(current_observation) + + if latest_intention: + self.state_manager.add_intention(latest_intention) + + if latest_action: + self.state_manager.add_action(latest_action) + + if latest_reflection: + self.state_manager.add_reflection(latest_reflection) + + def update_context( + self, + trajectory: Trajectory, + user_goal: str, + current_observation: Optional[Observation] = None, + latest_intention: Optional[str] = None, + latest_action: Optional[Action] = None, + latest_reflection: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + """Update context state and generate comprehensive summary. 
+ + Args: + trajectory: Current execution trajectory + user_goal: Original user goal/task + current_observation: Latest page observation + latest_intention: Most recent intention from the Planner Agent + latest_action: Most recent action taken + latest_reflection: Most recent reflection from Reflector Agent + + Returns: + Dictionary containing updated context information + """ + # Update state manager with new information (avoid duplicates) + + self.update_state( + current_observation=current_observation, + latest_intention=latest_intention, + latest_action=latest_action, + latest_reflection=latest_reflection, + ) + + # Get complete execution history + history = self.state_manager.get_history() + + # Generate context summary without progress metrics + summary, observation_summary, action_summary, reflection_summary = self.summary_generator.generate_summary( + user_goal=user_goal, + observations=self.state_manager.get_all_observations(), + actions=self.state_manager.get_all_actions(), + reflections=self.state_manager.get_all_reflections(), + ) + + # Return comprehensive context information + return { + "summary": summary, + "memory_content": self.memory_content, + "observation_summary": observation_summary, + "action_summary": action_summary, + "reflection_summary": reflection_summary, + "state_history": history, + "latest_observation": self.state_manager.get_latest_observation(), + "latest_action": self.state_manager.get_latest_action(), + } + + def reset(self) -> None: + """Reset all context state for a new task.""" + self.state_manager.clear() + + def get_current_state(self) -> Dict[str, Any]: + """Get current context state without updating.""" + history = self.state_manager.get_history() + return { + "state_history": history, + "latest_observation": self.state_manager.get_latest_observation(), + "latest_action": self.state_manager.get_latest_action(), + "total_steps": history.get("total_steps", 0), + } + + def check_task_completion(self, user_goal: str) -> bool: + """Check if the task is considered complete based on current 
state. + + This is a simplified check - task completion is primarily determined + by reaching max_steps or explicit STOP action in the coordinator. + + Args: + user_goal: Original user goal + + Returns: + True if task appears complete, False otherwise + """ + # Task completion is determined by the coordinator based on: + # 1. Max steps reached + # 2. Actor generating a STOP action + # This method provides a basic fallback check + history = self.state_manager.get_history() + actions = history.get("actions", []) + + # Check if the last action was a STOP action + if actions and actions[-1].get("action_type") == "STOP": + return True + + return False + + def initialize_task_memory(self, user_goal: str) -> Dict[str, Any]: + """Initialize memory tracking for a new task. + + Args: + user_goal: The user's goal/task description + + Returns: + Relevant memories from past similar tasks + """ + if not self.enable_memory: + self.memory_content = "" + return {"memory_content": "", "relevant_memories": []} + + # Store user goal in state manager for memory generation + self.state_manager.set_user_goal(user_goal) + + # Retrieve relevant memories for this task + try: + relevant_memories = self.memory_bank.search_memories(query=user_goal) + + self.memory_content = self._get_mem_str(relevant_memories) + return { + "memory_content": self.memory_content, + "relevant_memories": relevant_memories, + } + except Exception as e: + print(f"Error retrieving memories: {e}") + self.memory_content = "" + return { + "memory_content": "", + "relevant_memories": [], + } + + def _get_mem_str(self, memories: List[Dict[str, Any]]) -> str: + """Convert list of memories to a formatted string. 
+ + Args: + memories: List of memory dictionaries from memory bank + + Returns: + Formatted string of memories + """ + if not memories: + return "" + + mem_str = "" + for i, mem in enumerate(memories): + mem_str += f"Memory {i+1}:\n" + mem_str += f"Title: {mem['title']}\n" + mem_str += f"Description: {mem['description']}\n" + mem_str += f"Content: {mem['content']}\n\n" + + return mem_str.strip() + + def generate_and_store_memory(self, task_completed: bool) -> Optional[str]: + """Generate and store memory from completed task. + + Args: + task_completed: Whether the task was completed successfully + + Returns: + Memory ID if successful, None otherwise + """ + if not self.enable_memory_store: + return None + + try: + # Get user goal from state manager + user_goal = self.state_manager.get_user_goal() + + # Generate memory from trajectory using state manager data + memory_data = self.memory_generator.generate_memory_from_trajectory( + user_goal=user_goal, + observations=self.state_manager.get_all_observations(), + intentions=self.state_manager.get_all_intentions(), + actions=self.state_manager.get_all_actions(), + reflections=self.state_manager.get_all_reflections(), + task_completed=task_completed, + window_size=self.window_size, + ) + + # Store memory in bank + memory_id = self.memory_bank.add_memory( + title=memory_data["title"], + description=memory_data["description"], + content=memory_data["content"], + success_rate=memory_data.get("success_rate", 0.0), + ) + + print(f"💾 Memory stored: {memory_data['title'][:50]}... 
(ID: {memory_id})") + return memory_id + + except Exception as e: + print(f"Error generating memory: {e}") + return None + diff --git a/agent/coordinator/__init__.py b/agent/coordinator/__init__.py new file mode 100644 index 0000000..d352ad7 --- /dev/null +++ b/agent/coordinator/__init__.py @@ -0,0 +1,6 @@ +"""Multi-agent coordination components.""" + +from .workflow_manager import WorkflowManager +from .communication_hub import CommunicationHub + +__all__ = ["WorkflowManager", "CommunicationHub"] \ No newline at end of file diff --git a/agent/coordinator/communication_hub.py b/agent/coordinator/communication_hub.py new file mode 100644 index 0000000..6ab2365 --- /dev/null +++ b/agent/coordinator/communication_hub.py @@ -0,0 +1,70 @@ +"""Communication hub for multi-agent coordination.""" + +from typing import Any, Dict, Optional + + +class CommunicationHub: + """Manages shared context between different agents.""" + + def __init__(self) -> None: + self.agent_states: Dict[str, Any] = {} + self.shared_context: Dict[str, Any] = {} + + def register_agent(self, agent_name: str, initial_state: Dict[str, Any]) -> None: + """Register an agent with the communication hub. + + Args: + agent_name: Name of the agent + initial_state: Initial state of the agent + """ + self.agent_states[agent_name] = { + "state": initial_state, + } + + def update_agent_state(self, agent_name: str, state_update: Dict[str, Any]) -> None: + """Update the state of a registered agent. + + Args: + agent_name: Name of the agent + state_update: State update information + """ + if agent_name in self.agent_states: + self.agent_states[agent_name]["state"].update(state_update) + + def update_shared_context(self, context_key: str, context_value: Any) -> None: + """Update shared context accessible to all agents. 
+ + Args: + context_key: Key for the shared context + context_value: Value for the shared context + """ + self.shared_context[context_key] = context_value + + def get_shared_context(self, context_key: Optional[str] = None) -> Any: + """Get shared context information. + + Args: + context_key: Optional specific key to retrieve + + Returns: + Shared context value or entire context dictionary + """ + if context_key: + return self.shared_context.get(context_key) + return self.shared_context.copy() + + def get_agent_state(self, agent_name: str) -> Optional[Dict[str, Any]]: + """Get the current state of a registered agent. + + Args: + agent_name: Name of the agent + + Returns: + Current state of the agent or None if not found + """ + return self.agent_states.get(agent_name, {}).get("state") + + def reset(self) -> None: + """Reset the communication hub for a new task.""" + self.agent_states.clear() + self.shared_context.clear() diff --git a/agent/coordinator/workflow_manager.py b/agent/coordinator/workflow_manager.py new file mode 100644 index 0000000..e997269 --- /dev/null +++ b/agent/coordinator/workflow_manager.py @@ -0,0 +1,161 @@ +"""Workflow management for multi-agent coordination.""" + +from typing import Any, Dict, List, Optional + +from browser_env import Action +from browser_env.utils import Observation + + +class WorkflowManager: + """Manages the workflow and execution flow for multiple agents.""" + + def __init__(self) -> None: + self.current_step = 0 + self.max_steps = 30 + self.workflow_state = "running" + self.execution_history: List[Dict[str, Any]] = [] + + def initialize_workflow(self, max_steps: int = 30) -> Dict[str, Any]: + """Initialize the workflow for a new task. 
+ + Args: + max_steps: Maximum number of execution steps + + Returns: + Dictionary containing workflow initialization info + """ + self.current_step = 0 + self.max_steps = max_steps + self.workflow_state = "running" + self.execution_history.clear() + + return { + "max_steps": max_steps, + "current_step": 0, + "workflow_state": "initialized", + } + + def should_continue_execution( + self, context_summary: Dict[str, Any] + ) -> Dict[str, Any]: + """Determine if execution should continue. + + Args: + context_summary: Current context from Context Agent (currently unused) + + Returns: + Dictionary containing continuation decision and reasoning + """ + # Check step limit - this is the primary stopping condition + if self.current_step >= self.max_steps: + return { + "should_continue": False, + "reason": "Maximum steps reached", + "stop_type": "step_limit", + } + + # Default: continue execution + return { + "should_continue": True, + "reason": "Execution should continue", + "stop_type": "none", + } + + def record_execution_step( + self, + step_number: int, + intention: str, + action: Action, + observation: Observation, + reflection: Dict[str, Any], + execution_time: Optional[float] = None, + ) -> None: + """Record a complete execution step. + + Args: + step_number: Current step number + intention: The intention that was being fulfilled + action: The action that was executed + observation: The result observation + reflection: The reflection on the execution + execution_time: Time taken for this step + """ + step_record = { + "step_number": step_number, + "intention": intention, + "action": action, + "observation": observation, + "reflection": reflection, + "execution_time": execution_time, + } + + self.execution_history.append(step_record) + self.current_step = step_number + + def get_workflow_statistics(self) -> Dict[str, Any]: + """Get statistics about the current workflow execution. 
+ + Returns: + Dictionary containing workflow statistics + """ + if not self.execution_history: + return { + "total_steps": 0, + "workflow_state": self.workflow_state, + "message": "No execution history available", + } + + total_steps = len(self.execution_history) + + # Count action types + action_types = {} + for record in self.execution_history: + action = record.get("action", {}) + action_type = action.get("action_type", "UNKNOWN") + action_types[action_type] = action_types.get(action_type, 0) + 1 + + return { + "total_steps": total_steps, + "current_step": self.current_step, + "max_steps": self.max_steps, + "workflow_state": self.workflow_state, + "progress_percentage": (self.current_step / self.max_steps) * 100, + "action_type_distribution": action_types, + } + + def finalize_workflow(self, final_state: str, completion_reason: str) -> Dict[str, Any]: + """Finalize the workflow execution. + + Args: + final_state: Final state of the workflow + completion_reason: Reason for workflow completion + + Returns: + Dictionary containing workflow finalization info + """ + self.workflow_state = final_state + + finalization_record = { + "final_state": final_state, + "completion_reason": completion_reason, + "total_steps": self.current_step, + "max_steps": self.max_steps, + "statistics": self.get_workflow_statistics(), + "execution_summary": self._generate_execution_summary(), + } + + return finalization_record + + def _generate_execution_summary(self) -> str: + """Generate a summary of the workflow execution.""" + if not self.execution_history: + return "No execution steps recorded" + + stats = self.get_workflow_statistics() + return f"Workflow completed with {stats['total_steps']} steps out of {stats['max_steps']} maximum." 
+ + def reset_workflow(self) -> None: + """Reset the workflow for a new task.""" + self.current_step = 0 + self.workflow_state = "ready" + self.execution_history.clear() diff --git a/agent/memory/__init__.py b/agent/memory/__init__.py new file mode 100644 index 0000000..2696bf8 --- /dev/null +++ b/agent/memory/__init__.py @@ -0,0 +1,9 @@ +"""Memory module for multi-agent experience storage and retrieval.""" + +from .memory_bank import MemoryBank +from .memory_generator import MemoryGenerator + +__all__ = [ + 'MemoryBank', + 'MemoryGenerator' +] \ No newline at end of file diff --git a/agent/memory/memory_bank.py b/agent/memory/memory_bank.py new file mode 100644 index 0000000..2e0c96e --- /dev/null +++ b/agent/memory/memory_bank.py @@ -0,0 +1,376 @@ +"""Memory Bank for storing and managing agent experiences.""" + +import json +import os +from datetime import datetime +from typing import Any, Dict, List, Optional, Tuple +import pickle + +import numpy as np +import torch + + +class MemoryBank: + """Memory Bank for storing and managing agent experiences. + + Supports memory storage, retrieval, and management with optimized numpy-based embeddings. + Uses an internal ID mapping system to handle deletions without breaking embeddings. 
+ """ + + def __init__(self, memory_dir: str = "agent_memories", + embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2", + top_k: int = 3, + device: str = "cpu") -> None: + self.memory_dir = memory_dir + self.embedding_model = embedding_model + self.device = device + self.top_k = top_k + + # Create memory directory if it doesn't exist + os.makedirs(memory_dir, exist_ok=True) + + # Memory storage files + self.memory_file = os.path.join(memory_dir, "memories.json") + self.embedding_file = os.path.join(memory_dir, "embeddings.npy") + self.metadata_file = os.path.join(memory_dir, "metadata.json") + self.id_mapping_file = os.path.join(memory_dir, "id_mapping.json") + self.embeddings = None + + # Load existing memories + self.memories = self._load_memories() + self.embeddings = self._load_embeddings() + self.metadata = self._load_metadata() + self.id_mapping = self._load_id_mapping() + + # Initialize embedding manager (lazy loading) + self._embedding_manager = None + + # Cache reverse mapping for search optimization + self._emb_id_to_mem_id = None + self._mapping_dirty = True # Flag to indicate if mapping needs update + + def _get_embedding_manager(self): + """Lazy load embedding manager.""" + if self._embedding_manager is None: + try: + from sentence_transformers import SentenceTransformer + self._embedding_manager = SentenceTransformer(self.embedding_model, device=self.device) + except ImportError: + print("Warning: sentence-transformers not installed. 
Using simple text similarity.") + self._embedding_manager = None + return self._embedding_manager + + def _load_memories(self) -> Dict[str, Any]: + """Load memories from JSON file.""" + if os.path.exists(self.memory_file): + try: + with open(self.memory_file, 'r', encoding='utf-8') as f: + return json.load(f) + except Exception as e: + print(f"Error loading memories: {e}") + return {} + return {} + + def _load_embeddings(self) -> Optional[np.ndarray]: + """Load embeddings from numpy file.""" + if os.path.exists(self.embedding_file): + try: + embeddings = np.load(self.embedding_file, allow_pickle=True) + # Convert to proper numpy array if loaded as object array + if embeddings.dtype == object: + embeddings = np.array([np.array(e) for e in embeddings]) + return embeddings + except Exception as e: + print(f"Error loading embeddings: {e}") + return None + return None + + def _load_metadata(self) -> Dict[str, Any]: + """Load metadata from JSON file.""" + if os.path.exists(self.metadata_file): + try: + with open(self.metadata_file, 'r', encoding='utf-8') as f: + return json.load(f) + except Exception as e: + print(f"Error loading metadata: {e}") + return {"total_memories": 0, "last_updated": None, "next_mem_id": 0} + return {"total_memories": 0, "last_updated": None, "next_mem_id": 0} + + def _load_id_mapping(self) -> Dict[int, int]: + """Load ID mapping from JSON file.""" + if os.path.exists(self.id_mapping_file): + try: + with open(self.id_mapping_file, 'r', encoding='utf-8') as f: + mapping = json.load(f) + # Convert string keys back to int + return {int(k): v for k, v in mapping.items()} + except Exception as e: + print(f"Error loading ID mapping: {e}") + + # If no mapping exists, create one-to-one mapping for existing memories + mapping = {} + mem_ids = [int(id_str) for id_str in self.memories.keys()] + mem_ids.sort() + for i, mem_id in enumerate(mem_ids): + mapping[mem_id] = i + return mapping + + def _save_id_mapping(self) -> None: + """Save ID mapping to JSON 
file.""" + try: + with open(self.id_mapping_file, 'w', encoding='utf-8') as f: + json.dump(self.id_mapping, f, indent=2) + except Exception as e: + print(f"Error saving ID mapping: {e}") + + def _save_memories(self) -> None: + """Save memories to JSON file.""" + try: + with open(self.memory_file, 'w', encoding='utf-8') as f: + json.dump(self.memories, f, indent=2, ensure_ascii=False) + except Exception as e: + print(f"Error saving memories: {e}") + + def _save_embeddings(self) -> None: + """Save embeddings to numpy file.""" + if self.embeddings is not None: + try: + np.save(self.embedding_file, self.embeddings) + except Exception as e: + print(f"Error saving embeddings: {e}") + + def _save_metadata(self) -> None: + """Save metadata to JSON file.""" + try: + self.metadata["total_memories"] = len(self.memories) + self.metadata["last_updated"] = datetime.now().isoformat() + with open(self.metadata_file, 'w', encoding='utf-8') as f: + json.dump(self.metadata, f, indent=2, ensure_ascii=False) + except Exception as e: + print(f"Error saving metadata: {e}") + + def _create_embedding(self, text: str) -> Optional[np.ndarray]: + """Create embedding for text.""" + manager = self._get_embedding_manager() + if manager is None: + return None + + try: + embedding = manager.encode(text, convert_to_numpy=True) + return embedding + except Exception as e: + print(f"Error creating embedding: {e}") + return None + + def _update_reverse_mapping(self) -> None: + """Update the cached reverse mapping from embedding index to memory ID.""" + self._emb_id_to_mem_id = {v: k for k, v in self.id_mapping.items()} + self._mapping_dirty = False + + def add_memory(self, + title: str, + description: str, + content: str, + success_rate: float = 0.0,) -> str: + """Add a new memory to the bank. 
+ + Args: + title: Short title for the memory + description: Brief description of the memory + content: Detailed content/experience + success_rate: Success rate for this type of task (0.0-1.0) + + Returns: + Memory ID (integer starting from 0) + """ + # Get next memory ID + mem_id = self.metadata.get("next_mem_id", 0) + + memory = { + "title": title, + "description": description, + "content": content, + "success_rate": success_rate, + "created_at": datetime.now().isoformat(), + "access_count": 0, + "last_accessed": None + } + + # Create embedding for search + search_text = f"{title} {description} {content}" + embedding = self._create_embedding(search_text) + + # Add embedding to numpy array + if embedding is not None: + if self.embeddings is None: + self.embeddings = np.array([embedding]) + embedding_index = 0 + else: + embedding_index = len(self.embeddings) + self.embeddings = np.vstack([self.embeddings, embedding]) + + # Update ID mapping + self.id_mapping[mem_id] = embedding_index + # Mark mapping as dirty since it changed + self._mapping_dirty = True + + # Add to memories + self.memories[str(mem_id)] = memory + + # Update next memory ID + self.metadata["next_mem_id"] = mem_id + 1 + + # Save changes + self._save_memories() + self._save_embeddings() + self._save_metadata() + self._save_id_mapping() + + return str(mem_id) + + def get_memory(self, memory_id: int) -> Optional[Dict[str, Any]]: + """Get memory by ID.""" + memory = self.memories.get(str(memory_id)) + if memory is None: + return None + + # Update access statistics + memory["access_count"] += 1 + memory["last_accessed"] = datetime.now().isoformat() + self._save_memories() + + return memory + + def get_all_memories(self) -> List[Dict[str, Any]]: + """Get all memories.""" + return list(self.memories.values()) + + def update_memory(self, memory_id: int, updates: Dict[str, Any]) -> bool: + """Update memory with new data.""" + memory = self.memories.get(str(memory_id)) + if memory is None: + return False + + 
# Update fields + for key, value in updates.items(): + if key in memory: + memory[key] = value + + # Update embedding if content changed + if "title" in updates or "description" in updates or "content" in updates: + search_text = f"{memory['title']} {memory['description']} {memory['content']}" + embedding = self._create_embedding(search_text) + if embedding is not None and self.embeddings is not None and memory_id in self.id_mapping: + embedding_idx = self.id_mapping[memory_id] + if 0 <= embedding_idx < len(self.embeddings): + self.embeddings[embedding_idx] = embedding + + self._save_memories() + self._save_embeddings() + return True + + def search_memories(self, query: str) -> List[Dict[str, Any]]: + """Search memories by text using PyTorch-optimized vector operations.""" + if not self.memories or self.embeddings is None: + return [] + + # Update reverse mapping cache if needed + if self._mapping_dirty or self._emb_id_to_mem_id is None: + self._update_reverse_mapping() + + # Create query embedding + query_embedding = self._create_embedding(query) + if query_embedding is None: + return [] + + # Convert to PyTorch tensors + query_tensor = torch.from_numpy(query_embedding).float().to(self.device) + embeddings_tensor = torch.from_numpy(self.embeddings).float().to(self.device) + + # Ensure proper shape + if query_tensor.dim() == 1: + query_tensor = query_tensor.unsqueeze(0) # [1, embedding_dim] + + # Normalize using PyTorch operations for GPU acceleration + query_norm = torch.nn.functional.normalize(query_tensor, p=2, dim=1) + embeddings_norm = torch.nn.functional.normalize(embeddings_tensor, p=2, dim=1) + + # Calculate cosine similarities using matrix multiplication + similarities = torch.mm(query_norm, embeddings_norm.T).squeeze(0) + + # Get top-k values and indices + top_k = min(self.top_k, len(similarities)) + top_values, top_indices = torch.topk(similarities, top_k) + + # Convert to CPU for processing + top_indices_cpu = top_indices.cpu().numpy() + 
similarities_cpu = top_values.cpu().numpy() + + # Return corresponding memories + result_memories = [] + for idx, embedding_idx in enumerate(top_indices_cpu): + if similarities_cpu[idx] > 0: # Only return memories with some similarity + mem_id = self._emb_id_to_mem_id.get(int(embedding_idx)) + if mem_id is not None: + memory = self.get_memory(mem_id) + if memory: + # Add similarity score to memory + memory_copy = memory.copy() + memory_copy['similarity_score'] = float(similarities_cpu[idx]) + result_memories.append(memory_copy) + + return result_memories + + + def delete_memory(self, memory_id: int) -> bool: + """Delete a memory from the bank. + + Args: + memory_id: ID of the memory to delete + + Returns: + True if memory was deleted, False if memory was not found + """ + memory_str_id = str(memory_id) + if memory_str_id not in self.memories: + return False + + # Remove from memories + del self.memories[memory_str_id] + + # Update embeddings and ID mapping + if memory_id in self.id_mapping and self.embeddings is not None: + embedding_idx = self.id_mapping[memory_id] + + # Remove embedding from the array + self.embeddings = np.delete(self.embeddings, embedding_idx, axis=0) + + # Remove from ID mapping + del self.id_mapping[memory_id] + + # Update all ID mappings that were after the deleted embedding + for mem_id_key, emb_idx in list(self.id_mapping.items()): + if emb_idx > embedding_idx: + self.id_mapping[mem_id_key] = emb_idx - 1 + + # Mark mapping as dirty since it changed + self._mapping_dirty = True + + # Save all changes + self._save_memories() + self._save_embeddings() + self._save_metadata() + self._save_id_mapping() + + return True + + def get_statistics(self) -> Dict[str, Any]: + """Get memory bank statistics.""" + return { + "total_memories": len(self.memories), + "total_embeddings": self.embeddings.shape[0] if self.embeddings is not None else 0, + "embedding_dim": self.embeddings.shape[1] if self.embeddings is not None else 0, + "last_updated": 
self.metadata.get("last_updated"), + "next_mem_id": self.metadata.get("next_mem_id", 0) + } + diff --git a/agent/memory/memory_generator.py b/agent/memory/memory_generator.py new file mode 100644 index 0000000..1d05533 --- /dev/null +++ b/agent/memory/memory_generator.py @@ -0,0 +1,383 @@ +"""Memory Generator for creating experiences from trajectories.""" + +from typing import Any, Dict, List, Optional +from datetime import datetime +import json + +from browser_env.utils import Observation + +from llms import lm_config, call_llm +from ..prompts.prompt_loader import load_prompt_template +from browser_env import ( + Action, + action2str, +) + + +class MemoryGenerator: + """Memory Generator for creating experiences from trajectories. + + Implements iterative memory generation with two-phase approach: + 1. Initial memory generation from early trajectory + 2. Iterative refinement with full trajectory + """ + + def __init__(self, lm_config: lm_config.LMConfig) -> None: + self.lm_config = lm_config + + def generate_memory_from_trajectory(self, + user_goal: str, + observations: List[Observation], + intentions: List[str], + actions: List[Action], + reflections: List[Dict[str, Any]], + task_completed: bool, + window_size: int = 3 + ) -> Dict[str, Any]: + """Generate memory from complete trajectory. 
+ + Args: + user_goal: Original user goal + intentions: List of intentions generated + actions: List of actions taken + reflections: List of reflections + task_completed: Whether the task was completed successfully + + Returns: + Generated memory dictionary + """ + if len(actions) == 0: + return {} + + + assert(len(intentions) == len(actions)) + assert(len(reflections) == len(actions)) + assert( (len(observations)-1) == len(actions)) + + + # Phase 1: Initial memory generation from first part of trajectory + memory = self._generate_initial_memory( + user_goal, + observations[:window_size+1], # 多包涵一个初始页 + intentions[:window_size], + actions[:window_size], + reflections[:window_size], + task_completed + ) + + memory.update({ + "success_rate": 1.0 if task_completed else 0.0, + "intentions_count": len(intentions), + "actions_count": len(actions), + "reflections_count": len(reflections) + }) + + # Phase 2: Iterative refinement with full trajectory + start_idx = window_size + while start_idx < len(actions): + end_idx = min(start_idx + window_size, len(actions)) + + partial_observations = observations[start_idx:end_idx+1] # 多包涵一个结束页 + partial_intentions = intentions[start_idx:end_idx] + partial_actions = actions[start_idx:end_idx] + partial_reflections = reflections[start_idx:end_idx] + history_intentions = intentions[:start_idx] + history_actions = actions[:start_idx] + + memory = self._refine_memory_with_trajectory( + memory, + user_goal, + partial_observations, + partial_intentions, + partial_actions, + partial_reflections, + history_intentions, + history_actions, + start_idx, + task_completed + ) + start_idx = end_idx + + return memory + + def _generate_initial_memory(self, + user_goal: str, + partial_observations: List[Observation], + partial_intentions: List[str], + partial_actions: List[Action], + partial_reflections: List[Dict[str, Any]], + task_completed: bool + ) -> Dict[str, Any]: + """Generate initial memory from partial trajectory.""" + try: + # Create 
summarized trajectory data + trajectory_summary = self._summarize_trajectory( + partial_observations, + partial_intentions, + partial_actions, + partial_reflections, + start_idx = 0 + ) + + # Generate initial memory using LLM + prompt = load_prompt_template( + "memory_generator", + "memory_initial_generation", + user_goal=user_goal, + trajectory_summary=trajectory_summary, + task_completed=task_completed + ) + + response = call_llm( + self.lm_config, [{"role": "user", "content": prompt}] + ).replace("```json", "").replace("```", "").strip() + + # Parse response + try: + memory_data = json.loads(response) + except json.JSONDecodeError as e: + print(f"JSON 解析错误: {e}") + # Fallback parsing if response is not valid JSON + memory_data = self._parse_memory_response(response) + + # Ensure required fields + return { + "title": memory_data.get("title", f"Task: {user_goal[:50]}..."), + "description": memory_data.get("description", f"Experience with task: {user_goal}"), + "content": memory_data.get("content", trajectory_summary[:800]), + "phase": "initial" + } + + except Exception as e: + print(f"Error in initial memory generation: {e}") + # Fallback memory + return { + "title": f"Task Experience: {user_goal[:50]}...", + "description": f"Automatically generated memory for: {user_goal}", + "content": f"Experience with task: {user_goal}", + "phase": "initial_fallback" + } + + def _refine_memory_with_trajectory(self, + initial_memory: Dict[str, Any], + user_goal: str, + partial_observations: List[Observation], + partial_intentions: List[str], + partial_actions: List[Action], + partial_reflections: List[Dict[str, Any]], + history_intentions: List[str], + history_actions: List[Action], + start_idx: int, + task_completed: bool) -> Dict[str, Any]: + """Refine initial memory with full trajectory data.""" + + + try: + # Create comprehensive summaries + trajectory_summary = self._summarize_trajectory( + partial_observations, + partial_intentions, + partial_actions, + 
partial_reflections, + start_idx = start_idx + ) + + history_intention_text = self._get_history_intention_text(history_intentions) + history_action_text = self._get_history_action_text(history_actions) + + # Generate refined memory using LLM + prompt = load_prompt_template( + "memory_generator", + "memory_refinement", + initial_memory=initial_memory, + user_goal=user_goal, + trajectory_summary=trajectory_summary, + history_intentions=history_intention_text, + history_actions=history_action_text, + task_completed=task_completed + ) + + response = call_llm( + self.lm_config, [{"role": "user", "content": prompt}] + ).replace("```json", "").replace("```", "").strip() + + # Parse refined memory data + try: + refined_data = json.loads(response) + except json.JSONDecodeError as e: + print(f"JSON 解析错误: {e}") + refined_data = self._parse_memory_response(response) + + # Merge with initial memory + refined_memory = initial_memory.copy() + refined_memory.update({ + "title": refined_data.get("title", initial_memory["title"]), + "description": refined_data.get("description", initial_memory["description"]), + "content": refined_data.get("content", initial_memory["content"]), + "phase": "refined" + }) + + return refined_memory + + except Exception as e: + print(f"Error in memory refinement: {e}") + # Return initial memory with minimal refinement + refined_memory = initial_memory.copy() + refined_memory.update({ + "phase": "refined_fallback", + }) + return refined_memory + + def _get_obs_text(self, obs: Observation, idx: int) -> str: + """Extract text from observation, truncating if necessary.""" + text = obs.get("text", "") + if text: + # Truncate for brevity but keep meaningful content + truncated_text = text[:800] + "..." 
if len(text) > 800 else text + return f"Page {idx}: \n{truncated_text}" + return f"Page {idx}: None" + + def _get_intention_text(self, intention: str, idx: int) -> str: + """Extract text from intention, truncating if necessary.""" + if intention: + # Truncate for brevity but keep meaningful content + truncated_intention = intention[:200] + "..." if len(intention) > 200 else intention + return f"Intention {idx}: {truncated_intention}" + return f"Intention {idx}: None" + + def _get_action_text(self, action: Action, idx: int) -> str: + """Extract text from action, truncating if necessary.""" + + action_str = action2str(action, "som", "") + action_text = f"Action {idx}: {action_str}" + + return action_text + + def _get_reflection_text(self, reflection: Dict[str, Any], idx: int) -> str: + text_parts = [] + + # Add effectiveness analysis + effectiveness = reflection.get("effectiveness_analyzer", "") + + if effectiveness: + effectiveness = effectiveness[:200] + "..." if len(effectiveness) > 200 else effectiveness + text_parts.append(f"Effectiveness: {effectiveness}") + + # Add triple summary + triple_summary = reflection.get("triple_summary", "") + if triple_summary: + triple_summary = triple_summary[:200] + "..." if len(triple_summary) > 200 else triple_summary + text_parts.append(f"State Transition: {triple_summary}") + + # Add pattern detection summary + pattern_detector = reflection.get("pattern_detector", {}) + if pattern_detector.get("patterns_detected", False): + detection_summary = pattern_detector.get("detection_summary", "") + if detection_summary: + detection_summary = detection_summary[:200] + "..." 
if len(detection_summary) > 200 else detection_summary + text_parts.append(f"Patterns: {detection_summary}") + + reflection_text = f"Reflection {idx}: \n" + "\n".join(text_parts) + + return reflection_text + + def _get_history_intention_text(self, intentions: List[str]) -> str: + """Extract text from intentions, truncating if necessary.""" + intention_texts = [self._get_intention_text(intention, i+1) for i, intention in enumerate(intentions)] + return "\n".join(intention_texts) + + def _get_history_action_text(self, actions: List[Action]) -> str: + """Extract text from actions, truncating if necessary.""" + action_texts = [self._get_action_text(action, i+1) for i, action in enumerate(actions)] + return "\n".join(action_texts) + + def _summarize_trajectory(self, + observations: List[Observation], + intentions: List[str], + actions: List[Action], + reflections: List[Dict[str, Any]], + start_idx: int = 0) -> str: + + """Create a summary of the trajectory.""" + if not observations: + return "Empty trajectory" + + trajectory_texts = [] + trajectory_texts.append(self._get_obs_text(observations[0], start_idx)) + observations = observations[1:] + + for i, (obs, intention, action, reflection) in enumerate( + zip(observations, intentions, actions, reflections), + 1): + idx = start_idx + i + trajectory_texts.append(self._get_intention_text(intention, idx)) + trajectory_texts.append(self._get_action_text(action, idx)) + trajectory_texts.append(self._get_obs_text(obs, idx)) + trajectory_texts.append(self._get_reflection_text(reflection, idx)) + + return "\n\n".join(trajectory_texts) + + def _parse_memory_response(self, response: str) -> Dict[str, Any]: + """Parse memory response when JSON parsing fails with a simplified approach. 
+ + Args: + response: Raw text response from LLM + + Returns: + Dictionary with parsed memory components + """ + import re + + # Initialize result with default values + result = { + 'title': '', + 'description': '', + 'content': '', + } + + if not response: + return result + + # Normalize response + response = response.strip() + lines = response.split('\n') + + # Check for possible JSON fragment first + json_match = re.search(r'(\{[^}]*\})', response, re.DOTALL) + if json_match: + try: + json_data = json.loads(json_match.group(1)) + for key in result.keys(): + if key in json_data: + result[key] = json_data[key] + except json.JSONDecodeError: + pass + + # Process each line for key-value pairs or section headers + for line in lines: + line = line.strip() + if not line: + continue + + # Check for key:value format + if ':' in line and not line.startswith('-') and not line.startswith('*'): + parts = line.split(':', 1) + key = parts[0].strip().lower() + value = parts[1].strip() if len(parts) > 1 else '' + + if 'title' in key and not result['title']: + result['title'] = value + elif 'description' in key and not result['description']: + result['description'] = value + elif 'content' in key and not result['content']: + result['content'] = value + + + # Set defaults if still empty + if not result['title'] and lines: + result['title'] = lines[0][:200].strip() + + if not result['content']: + result['content'] = response[:500].strip() + ('...' 
if len(response) > 500 else '') + + return result diff --git a/agent/multi_agent_coordinator.py b/agent/multi_agent_coordinator.py new file mode 100644 index 0000000..50052fb --- /dev/null +++ b/agent/multi_agent_coordinator.py @@ -0,0 +1,932 @@ +"""Multi-Agent Coordinator for managing collaborative agent execution.""" + +from typing import Any, Dict, List, Optional, Union +import json +import os +import traceback +from datetime import datetime + +from PIL import Image + +from browser_env import Action, Trajectory +from browser_env.helper_functions import get_action_description +from llms import lm_config + +# Try importing specific Observation types +try: + from browser_env.utils import Observation as BrowserObservation + ObservationType = Union[BrowserObservation, dict] +except ImportError: + # Fallback types for when browser_env is not available + BrowserObservation = None + ObservationType = Union[dict, Dict[str, Any]] + +# Type aliases for better type hints +ObservationTypeAlias = ObservationType + +import copy + +from .context_agent import ContextAgent +from .planner_agent import PlannerAgent +from .actor_agent import ActorAgent +from .reflector_agent import ReflectorAgent +from .coordinator.workflow_manager import WorkflowManager +from .coordinator.communication_hub import CommunicationHub + + +class MultiAgentCoordinator: + """Coordinates multiple agents for collaborative task execution. + + Manages the interaction between Context, Planner, Actor, and Reflector agents + to achieve complex web automation tasks through coordinated execution. 
+ """ + + def __init__(self, lm_config: lm_config.LMConfig, + existing_prompt_agent, + browser_env=None, + result_dir: str = "results", + memory_config: Dict[str, Any]= {}, + webjudge_result_root: Optional[str] = None) -> None: + self.lm_config = lm_config + + # Get action set tag from existing agent or use default + action_set_tag = getattr(existing_prompt_agent, 'action_set_tag', 'som') + + self.enable_memory = memory_config.get("enable_memory", False) + self.enable_memory_store = memory_config.get("enable_memory_store", False) + + # Initialize individual agents with memory enabled if specified + self.context_agent = ContextAgent(lm_config, memory_config) + self.planner_agent = PlannerAgent(lm_config) + self.actor_agent = ActorAgent( + action_set_tag=action_set_tag, + lm_config=lm_config, + prompt_constructor=getattr(existing_prompt_agent, 'prompt_constructor', None), + captioning_fn=getattr(existing_prompt_agent, 'captioning_fn', None) + ) + self.reflector_agent = ReflectorAgent(lm_config) + + # Initialize coordination components + self.workflow_manager = WorkflowManager() + self.communication_hub = CommunicationHub() + + # Browser environment for action execution + self.browser_env = browser_env + + # Result directory and logging setup + self.result_dir = result_dir + # WebJudge 结果根目录(与现有结果分开保存) + self.webjudge_result_root = webjudge_result_root or os.path.join(self.result_dir, "webjudge_results") # WebJudge 输出根目录 + self.log_file_path = os.path.join(result_dir, "agent_responses.log") + self.observation_log_path = os.path.join(result_dir, "observations.json") + self.images_dir = os.path.join(result_dir, "images") + self._setup_logging() + + # Execution state + self.trajectory: Trajectory = [] + self.intentions: List[str] = [] + self.actions: List[Action] = [] + self.reflections: List[Dict[str, Any]] = [] + self.webjudge_action_history: List[str] = [] # 存动作文本历史 + self.webjudge_thoughts: List[str] = [] # 存每步意图/思考 + self.task_metadata: Dict[str, Any] = {} # 任务元信息缓存 
+ self.task_output_dir: Optional[str] = None # 当前任务输出目录 + self.trajectory_dir: Optional[str] = None # 截图存放目录 + self.result_json_path: Optional[str] = None # result.json 路径 + self.current_task_id: str = "" # 当前任务 ID + + # Meta data for action history tracking (required by DirectPromptConstructor) + # Initialize with "None" as the first action, matching run.py implementation + self.meta_data: Dict[str, Any] = {"action_history": ["None"]} + + # Task configuration + self.user_goal: str = "" + self.max_steps: int = 30 + self.current_observation: Optional["ObservationTypeAlias"] = None + + def _setup_logging(self) -> None: + """Setup logging for agent responses.""" + try: + # Ensure result directory exists + os.makedirs(self.result_dir, exist_ok=True) + + # Create images directory + os.makedirs(self.images_dir, exist_ok=True) + + # Create or clear the log file + with open(self.log_file_path, 'w', encoding='utf-8') as f: + f.write(f"Multi-Agent Execution Log - Started at {datetime.now().isoformat()}\n") + f.write("=" * 80 + "\n\n") + + # Initialize observation log file with empty list + with open(self.observation_log_path, 'w', encoding='utf-8') as f: + json.dump([], f, indent=2) + + except Exception as e: + print(f"Warning: Failed to setup logging: {e}") + + def log_agent_response(self, agent_name: str, step_number: int, response_data: Dict[str, Any]) -> None: + """Log agent response summary to file.""" + try: + timestamp = datetime.now().isoformat() + + with open(self.log_file_path, 'a', encoding='utf-8') as f: + f.write(f"[{timestamp}] Step {step_number} - {agent_name.upper()} Agent Response\n") + f.write("-" * 60 + "\n") + f.write(json.dumps(response_data, indent=2, ensure_ascii=False)) + f.write("\n\n") + except Exception as e: + print(f"Warning: Failed to log {agent_name} response: {e}") + + def log_observation(self, step_number: int, observation: Dict[str, Any]) -> None: + """Log observation text and save screenshot image.""" + try: + # Load existing observations + 
with open(self.observation_log_path, 'r', encoding='utf-8') as f: + observations = json.load(f) + + # Add new observation + obs_entry = { + "step": step_number, + "timestamp": datetime.now().isoformat(), + "text": observation.get("text", ""), + "has_image": observation.get("image") is not None + } + observations.append(obs_entry) + + # Save updated observations + with open(self.observation_log_path, 'w', encoding='utf-8') as f: + json.dump(observations, f, indent=2, ensure_ascii=False) + + # Save screenshot image if available + if observation.get("image") is not None: + image_path = os.path.join(self.images_dir, f"step_{step_number:03d}.png") + # Convert numpy array to PIL Image and save + from PIL import Image + import numpy as np + + if isinstance(observation["image"], np.ndarray): + img = Image.fromarray(observation["image"]) + img.save(image_path) + + except Exception as e: + print(f"Warning: Failed to log observation for step {step_number}: {e}") + + + def _prepare_webjudge_output(self, task_metadata: Optional[Dict[str, Any]], webjudge_root: Optional[str]) -> None: + """Create per-task directories for WebJudge-compatible outputs.""" + self.task_metadata = task_metadata or {} # 记录任务元信息 + self.current_task_id = self.task_metadata.get("task_id") or datetime.now().strftime("%Y%m%d_%H%M%S") # 若无 task_id 用时间戳代替 + base_root = webjudge_root or self.webjudge_result_root # 选择输出根目录 + self.task_output_dir = os.path.join(base_root, self.current_task_id) # 当前任务输出目录 + self.trajectory_dir = os.path.join(self.task_output_dir, "trajectory") # 截图子目录 + os.makedirs(self.trajectory_dir, exist_ok=True) + self.result_json_path = os.path.join(self.task_output_dir, "result.json") # 结果文件路径 + self.webjudge_action_history = [] # 重置动作记录 + self.webjudge_thoughts = [] # 重置思考记录 + + def _capture_and_save_screenshot(self, step_number: int) -> Optional[str]: + """Capture full-page screenshot directly from Playwright page (no SOM).""" + if self.browser_env is None or not hasattr(self.browser_env, 
"page"): + return None # 无浏览器实例时跳过 + if not self.trajectory_dir: + return None # 未初始化目录时跳过 + screenshot_path = os.path.join(self.trajectory_dir, f"step_{step_number:03d}.png") + try: + self.browser_env.page.screenshot(path=screenshot_path, full_page=True) # 直接截全页 + return screenshot_path + except Exception as e: + print(f"Warning: Failed to capture screenshot for step {step_number}: {e}") + return None + + def _format_action_for_webjudge(self, executed_action: Action, info: Optional[Dict[str, Any]]) -> str: + """Format action string for WebJudge action_history.""" + action_type = executed_action.get("action_type", "UNKNOWN") + description = executed_action.get("raw_prediction") or action_type # 默认用原始预测 + observation_metadata = None + if info and isinstance(info, dict): + observation_metadata = info.get("observation_metadata") + + # 将每步的 observation_metadata 单独存文件,便于按步排查 + # step_meta_path = os.path.join( + # self.trajectory_dir, + # f"observation_metadata_step_{self.workflow_manager.current_step:03d}.json" + # ) + # step_meta_path_action = os.path.join( + # self.trajectory_dir, + # f"observation_metadata_step_{self.workflow_manager.current_step:03d}_action.txt" + # ) + # try: + # with open(step_meta_path, "w", encoding="utf-8") as f: + # json.dump(observation_metadata, f, ensure_ascii=False, indent=2) + # with open(step_meta_path_action, "w", encoding="utf-8") as f: + # f.write(str(executed_action.get("element_id"))) + # except Exception as e: + # print(f"Warning: Failed to save observation metadata for step {self.workflow_manager.current_step}: {e}") + + if observation_metadata: + try: + description = get_action_description( + executed_action, + observation_metadata, + getattr(self.actor_agent, "action_set_tag", "id_accessibility_tree"), + getattr(self.actor_agent, "prompt_constructor", None), + ) + except Exception: + pass + return f"{description} -> {action_type}" + + def _format_action_for_webjudge_html(self, executed_action: Action, info: Optional[Dict[str, 
Any]]) -> str: + """备用:尽量还原原始 HTML 的动作描述(如 文本 -> CLICK)。 + 如需启用,将调用处的 _format_action_for_webjudge 替换为本函数即可。 + """ + action_type = executed_action.get("action_type", "UNKNOWN") + observation_metadata = None + if info and isinstance(info, dict): + observation_metadata = info.get("observation_metadata") + + node = None + if observation_metadata and "text" in observation_metadata: + text_meta = observation_metadata["text"] + node = text_meta.get("obs_nodes_info", {}).get(executed_action.get("element_id")) + + if node: + tag = node.get("tag") or node.get("nodeName") or "div" + attrs = [] + for k in ["id", "class", "href", "url", "name", "aria-label", "placeholder", "type", "value"]: + v = node.get(k) or node.get(k.replace("-", "_")) + if v: + attrs.append(f'{k}="{v}"') + role = node.get("role") + if role: + attrs.append(f'role="{role}"') + attr_str = (" " + " ".join(attrs)) if attrs else "" + inner = node.get("text") or node.get("alt") or node.get("name") or node.get("value") or tag + return f"<{tag}{attr_str}>{inner} -> {action_type}" + + # SOM 或缺元数据时退回简单描述 + elem_id = executed_action.get("element_id", "unknown") + return f" -> {action_type}" + + def _save_webjudge_result(self, final_result_response: str) -> None: + """Persist WebJudge style result.json.""" + if not self.result_json_path: + return + payload = { + "task_id": self.task_metadata.get("task_id", self.current_task_id), + "task": self.task_metadata.get("task") or self.user_goal, + "final_result_response": final_result_response, # 最终回答 + "action_history": self.webjudge_action_history, # 动作列表 + "thoughts": self.webjudge_thoughts, # 思考列表 + } + for key in ["website", "reference_length", "level", "confirmed_task"]: + if key in self.task_metadata: + payload[key] = self.task_metadata[key] # 可选元数据透传 + + os.makedirs(os.path.dirname(self.result_json_path), exist_ok=True) + with open(self.result_json_path, 'w', encoding='utf-8') as f: + json.dump(payload, f, ensure_ascii=False, indent=2) + + + def execute_task( + self, 
+ user_goal: str, + start_observation: Optional["ObservationTypeAlias"] = None, + max_steps: int = 30, + images: Optional[List[Image.Image]] = None, + task_metadata: Optional[Dict[str, Any]] = None, + webjudge_root: Optional[str] = None, + ) -> Dict[str, Any]: + """Execute a complete task using coordinated multi-agent approach. + + Args: + user_goal: The user's goal/task description + start_observation: Initial page observation + max_steps: Maximum number of execution steps + images: Optional input images for the task + + Returns: + Dictionary containing complete execution results + """ + # Initialize task + self.user_goal = user_goal + self.max_steps = max_steps + self.current_observation = start_observation + + # Prepare WebJudge output directories + self._prepare_webjudge_output(task_metadata, webjudge_root) # 为本任务创建输出目录 + + # Reset meta_data for new task execution (required by DirectPromptConstructor) + # Initialize with "None" as the first action, matching run.py implementation + self.meta_data = {"action_history": ["None"]} + + # Initialize workflow and monitoring + workflow_init = self.workflow_manager.initialize_workflow(max_steps) + + # Register agents with communication hub + self._register_agents() + + # Update shared context + self.communication_hub.update_shared_context("user_goal", user_goal) + self.communication_hub.update_shared_context("max_steps", max_steps) + + if self.enable_memory: + # Initialize memory system for this task + print("🧠 Initializing memory system...") + memory_initialization = self.context_agent.initialize_task_memory(user_goal) + if memory_initialization.get("relevant_memories"): + print(f"📚 Found {len(memory_initialization['relevant_memories'])} relevant memories") + if memory_initialization.get("memory_content"): + print(f"📝 Memory content: {memory_initialization['memory_content'][:100]}...") + else: + print("📚 No relevant memories found for this task") + + # Initialize trajectory with initial state + # This is required 
because DirectPromptConstructor expects trajectory[-1] to exist + if not self.trajectory: + # Handle different formats of start_observation + if start_observation is None: + # No observation provided, create empty state + initial_observation = {"text": "", "image": None} + initial_info = {"page": type('SimplePage', (), {'url': ''})(), "observation_metadata": {}} + elif isinstance(start_observation, dict) and "observation" in start_observation and "info" in start_observation: + # start_observation is already in StateInfo format: {"observation": obs, "info": info} + initial_observation = start_observation["observation"] + # Deep copy observation_metadata to avoid reference sharing issues + source_info = start_observation["info"] + initial_info = { + "page": source_info.get("page"), + "fail_error": source_info.get("fail_error", ""), + "observation_metadata": copy.deepcopy(source_info.get("observation_metadata", {})) + } + # Ensure observation has both text and image fields + if isinstance(initial_observation, dict): + if "text" not in initial_observation: + initial_observation["text"] = "" + if "image" not in initial_observation: + initial_observation["image"] = None + else: + # start_observation is the observation itself + initial_observation = start_observation + # Ensure observation is a dict with "text" and "image" keys + if isinstance(initial_observation, dict): + if "text" not in initial_observation: + initial_observation["text"] = "" + if "image" not in initial_observation: + initial_observation["image"] = None + else: + initial_observation = {"text": str(initial_observation) if initial_observation else "", "image": None} + + # Create a simple page-like object with url attribute + class SimplePage: + def __init__(self, url: str = ""): + self.url = url + initial_info = { + "page": SimplePage(url=""), # Will be updated when browser is initialized + "observation_metadata": {} + } + + initial_state_info = { + "observation": copy.deepcopy(initial_observation), + 
"info": copy.deepcopy(initial_info) + } + self.trajectory.append(initial_state_info) + self.current_observation = initial_observation + + # Save initial screenshot as step_000.png (playwright full page, no SOM) + self._capture_and_save_screenshot(0) # 初始页截图留档 + if initial_observation.get("image") is not None: + initial_image_path = os.path.join(self.images_dir, "step_000.png") + # Convert numpy array to PIL Image and save + from PIL import Image + import numpy as np + + if isinstance(initial_observation["image"], np.ndarray): + img = Image.fromarray(initial_observation["image"]) + img.save(initial_image_path) + print(f"📸 Saved initial screenshot as {initial_image_path}") + + step_result = None + # Main execution loop + while True: + try: + # Check if we should continue + context_summary = self._get_current_context_summary() + continuation_decision = self.workflow_manager.should_continue_execution(context_summary) + + if not continuation_decision["should_continue"]: + self.context_agent.update_state( + current_observation=self.current_observation, + latest_intention=self.intentions[-1] if self.intentions else None, + latest_action=self.actions[-1] if self.actions else None, + latest_reflection=self.reflections[-1] if self.reflections else None, + ) + break + + # Handle intervention requirements + if continuation_decision.get("requires_intervention"): + pass + + # Execute one coordination cycle + step_result = self._execute_coordination_cycle(images) + + # Check for early termination + if step_result.get("should_terminate", False): + break + + # Update current observation + if step_result.get("new_observation"): + self.current_observation = step_result["new_observation"] + + + except Exception as e: + # Record error and continue + print(f"Error during agent execution: {e}") + traceback.print_exc() + break + + # Finalize execution + final_context_summary = self._get_current_context_summary() + workflow_final = self.workflow_manager.finalize_workflow( + 
final_state="completed", + completion_reason=continuation_decision.get("reason", "Execution completed") + ) + + # Log final execution summary - use context agent completion check + task_completed = self.context_agent.check_task_completion(self.user_goal) + + if self.enable_memory_store: + self.context_agent.generate_and_store_memory(task_completed) + + final_summary = { + "total_steps_executed": len(self.actions), + "task_completed": task_completed, + "completion_percentage": "Completed" if task_completed else "In Progress", + "total_intentions": len(self.intentions), + "total_actions": len(self.actions), + "total_reflections": len(self.reflections) + } + self.log_agent_response("execution_summary", len(self.actions), final_summary) + + # Save WebJudge style result + final_result_response = step_result.get("execution_result", {}).get("action", {}).get("answer",{}) or ("Task completed" if task_completed else "Task incomplete") # 生成最终回答 + self._save_webjudge_result(final_result_response) # 写出 WebJudge 结果 + + # Return comprehensive execution result + return { + "task_info": { + "user_goal": self.user_goal, + "max_steps": self.max_steps, + "total_steps_executed": len(self.actions), + "task_completed": task_completed, + "completion_percentage": "Completed" if task_completed else "In Progress", + }, + "execution_trajectory": { + "intentions": self.intentions, + "actions": self.actions, + "reflections": self.reflections, + "trajectory_length": len(self.trajectory), + }, + "workflow_results": workflow_final, + "final_context": final_context_summary, + } + + def _register_agents(self) -> None: + """Register all agents with the communication hub.""" + self.communication_hub.register_agent( + "context_agent", {"state": "ready", "capabilities": ["context_management", "summary_generation"]} + ) + self.communication_hub.register_agent( + "planner_agent", {"state": "ready", "capabilities": ["task_planning", "intention_generation"]} + ) + self.communication_hub.register_agent( + 
"actor_agent", {"state": "ready", "capabilities": ["action_execution", "browser_interaction"]} + ) + self.communication_hub.register_agent( + "reflector_agent", {"state": "ready", "capabilities": ["reflection", "validation", "recovery"]} + ) + + def _execute_coordination_cycle(self, images: Optional[List[Image.Image]] = None) -> Dict[str, Any]: + """Execute one complete coordination cycle. + + Args: + images: Optional input images + + Returns: + Dictionary containing cycle results + """ + step_number = self.workflow_manager.current_step + 1 + print(f"🔄 Executing step {step_number}") + + + # 1. Context Agent updates context + print("🧠 Context Agent: Updating context...") + try: + context_result = self.context_agent.update_context( + trajectory=self.trajectory, + user_goal=self.user_goal, + current_observation=self.current_observation, + latest_intention=self.intentions[-1] if self.intentions else None, + latest_action=self.actions[-1] if self.actions else None, + latest_reflection=self.reflections[-1] if self.reflections else None, + ) + # Show only key context information + summary = context_result.get("summary", "No summary") + print(f"🧠 Context: {summary[:100]}{'...' if len(summary) > 100 else ''}") + + # Store response information for execution summary + + # Log context agent response summary with detailed breakdown + context_response = { + "summary": summary, + "observation_summary": context_result.get("observation_summary", ""), + "action_summary": context_result.get("action_summary", ""), + "reflection_summary": context_result.get("reflection_summary", "") + } + self.log_agent_response("context_agent", step_number, context_response) + + except Exception as e: + print(f"🧠 Context Error: {str(e)[:100]}{'...' 
if len(str(e)) > 100 else ''}") + context_result = { + "summary": "Error generating context", + "observation_summary": "", + "action_summary": "", + "reflection_summary": "" + } + + # Log context agent error summary + error_response = { + "error": str(e), + "summary": "Error generating context", + "observation_summary": "", + "action_summary": "", + "reflection_summary": "" + } + self.log_agent_response("context_agent", step_number, error_response) + + + # Check for task completion using context agent's completion check + if self.context_agent.check_task_completion(self.user_goal): + return { + "should_terminate": True, + "termination_reason": "Task completed", + "context_result": context_result, + } + + # 2. Planner Agent generates intention + print("🎯 Planner Agent: Generating intention...") + try: + planning_result = self.planner_agent.generate_intention( + user_goal=self.user_goal, + context_summary=context_result, + current_observation=self.current_observation, + previous_intentions=self.intentions, + ) + + current_intention = planning_result["intention"] + self.intentions.append(current_intention) + + # Show key planner information + current_subtask = planning_result.get("current_subtask", "") + next_atomic_action = planning_result.get("next_atomic_action", "") + reasoning = planning_result.get("reasoning", "") + all_subtasks = planning_result.get("all_subtasks", []) + current_step_index = planning_result.get("current_step_index", 0) + total_subtasks = planning_result.get("total_subtasks", 0) + + print(f"🎯 Current Subtask: {current_subtask[:100]}{'...' if len(current_subtask) > 100 else ''}") + if next_atomic_action != current_subtask: + print(f"🎯 Next Atomic Action: {next_atomic_action[:100]}{'...' 
if len(next_atomic_action) > 100 else ''}") + # print(f"🎯 Progress: Step {current_step_index + 1}/{total_subtasks}") + + # Show all subtasks overview + if all_subtasks and total_subtasks > 0: + print(f"🎯 Task Overview ({total_subtasks} subtasks):") + for i, subtask in enumerate(all_subtasks, 1): + status = "✅" if i <= current_step_index + 1 else "⏳" + print(f" {status} {i}. {subtask[:80]}{'...' if len(subtask) > 80 else ''}") + + # Show selected intention + print(f"✅ Selected Intention: {current_intention[:100]}{'...' if len(current_intention) > 100 else ''}") + + + # Log planner agent response summary + planner_response = { + "intention": current_intention, + "current_subtask": current_subtask, + "next_atomic_action": next_atomic_action, + "reasoning": reasoning, + "all_subtasks": all_subtasks, + "current_step_index": current_step_index, + "total_subtasks": total_subtasks, + "task_decomposed": planning_result.get("task_decomposed", False), + "response": planning_result.get("response", "") + } + self.log_agent_response("planner_agent", step_number, planner_response) + except Exception as e: + print(f"🎯 Planner Error: {str(e)[:100]}{'...' 
if len(str(e)) > 100 else ''}") + # Create fallback planning result for error case + planning_result = { + "intention": f"Continue working on: {self.user_goal}", + "current_subtask": f"Continue working on: {self.user_goal}", + "next_atomic_action": f"Continue working on: {self.user_goal}", + "reasoning": f"Fallback intention due to error: {str(e)}", + "all_subtasks": [f"Complete the task: {self.user_goal}"], + "current_step_index": 0, + "total_subtasks": 1, + "task_decomposed": False + } + + # Log planner agent error summary + error_response = { + "error": str(e), + "intention": planning_result["intention"], + "current_subtask": planning_result["current_subtask"], + "next_atomic_action": planning_result["next_atomic_action"], + "reasoning": planning_result["reasoning"], + "task_decomposed": planning_result["task_decomposed"] + } + self.log_agent_response("planner_agent", step_number, error_response) + + # 3. Actor Agent executes intention + print("🎬 Actor Agent: Executing intention...") + try: + # Merge step_number into meta_data while preserving action_history + # This matches the pattern used in run.py + meta_data_for_action = self.meta_data.copy() + meta_data_for_action["step_number"] = step_number + + # Ensure current_observation has both text and image fields + current_obs = self.current_observation or {"text": "", "image": None} + if isinstance(current_obs, dict): + if "text" not in current_obs: + current_obs["text"] = "" + if "image" not in current_obs: + current_obs["image"] = None + + execution_result = self.actor_agent.execute_intention( + intention=current_intention, + current_observation=current_obs, + trajectory=self.trajectory, + meta_data=meta_data_for_action, + images=images, + ) + + # Check if execution was successful and contains action + if "action" in execution_result: + executed_action = execution_result["action"] + self.actions.append(copy.deepcopy(executed_action)) + + # Execute action in browser environment if available + if self.browser_env 
is not None: + try: + print(f"🔍 Executing action in browser: {executed_action.get('action_type', 'UNKNOWN')}") + obs, reward, terminated, truncated, info = self.browser_env.step(executed_action) + + # Ensure observation has both text and image fields + if isinstance(obs, dict): + if "text" not in obs: + obs["text"] = "" + if "image" not in obs: + obs["image"] = None + else: + obs = {"text": str(obs) if obs else "", "image": None} + + # Update current observation with new browser state + self.current_observation = obs # Keep as observation format + new_observation = obs # Keep as observation format + + # Log observation (text and image) + self.log_observation(step_number, obs) + # Capture playwright full-page screenshot for WebJudge + self._capture_and_save_screenshot(step_number) # 保存当前步截图 + + # Determine if intention is fulfilled based on execution success + intention_fulfilled = reward == 1.0 # reward is 1.0 for success, 0.0 for failure + + # # Update trajectory with new state + # state_info = {"observation": obs, "info": info} + # self.trajectory.append(state_info) + + print(f"✅ Browser execution successful - URL: {info.get('page', {}).url if 'page' in info else 'Unknown'}") + except Exception as e: + print(f"❌ Browser execution failed: {str(e)}") + intention_fulfilled = False + info = None + new_observation = self.current_observation + else: + # No browser environment - simulate success for compatibility + intention_fulfilled = False # Will be determined by reflection + info = None + new_observation = self.current_observation + + # Update execution result with browser execution results + execution_result["intention_fulfilled"] = intention_fulfilled + + else: + # Create a fallback action for failed executions + executed_action = { + "action_type": "NONE", + "error": execution_result.get("error", "Unknown execution error"), + "intention": current_intention + } + self.actions.append(executed_action) + info = None + new_observation = self.current_observation + + # 
Show key actor execution information + action_type = executed_action.get("action_type", "UNKNOWN") + intention_fulfilled = execution_result.get("intention_fulfilled", False) + print(f"🎬 Actor: {action_type} - Fulfilled: {intention_fulfilled}") + + # Show LLM response if available (truncated) + llm_response = execution_result.get("llm_response", "") + if llm_response: + print(f" LLM Response: {llm_response[:500]}{'...' if len(llm_response) > 500 else ''}") + + + # Log actor agent response summary + actor_response = { + 'response': llm_response, + "action_type": action_type, + "fulfilled": intention_fulfilled + } + self.log_agent_response("actor_agent", step_number, actor_response) + + except Exception as e: + executed_action = { + "action_type": "NONE", + "error": str(e), + "intention": current_intention + } + self.actions.append(executed_action) + info = None + new_observation = self.current_observation + + print(f"🎬 Actor Error: {str(e)[:100]}{'...' if len(str(e)) > 100 else ''}") + + + # Store error response for execution summary, try to extract LLM response + error_details = str(e) + llm_response = "No LLM response available due to exception" + + # Try to get LLM response from execution_result if available + if hasattr(execution_result, 'get') and execution_result.get("llm_response"): + llm_response = execution_result.get("llm_response") + elif hasattr(execution_result, 'get') and execution_result.get("response"): + llm_response = execution_result.get("response") + + if hasattr(execution_result, 'get') and execution_result.get("error"): + error_details = execution_result.get("error") + elif hasattr(execution_result, 'get') and execution_result.get("exception_type"): + error_details = f"{execution_result.get('exception_type', 'Exception')}: {error_details}" + + + # Log actor agent error summary + self.log_agent_response("actor_agent", step_number, {"error": error_details}) + + # 4. 
Action execution is now handled above with browser environment + + # Add to trajectory (following run.py pattern) + # trajectory should contain: StateInfo, Action, StateInfo, Action, StateInfo... + # trajectory[-1] should already be the current StateInfo, so we just add action and new state + prev_state_info = None + + if self.trajectory and isinstance(self.trajectory[-1], dict): + prev_state_info = self.trajectory[-1] + + else: + prev_state_info = None + + + # observation_metadata = self.trajectory[-1].get("info", {}).get("observation_metadata", {}) + + # step_meta_path = os.path.join( + # self.trajectory_dir, + # f"observation_metadata_step_{self.workflow_manager.current_step:03d}_id_01.json" + # ) + # try: + # with open(step_meta_path, "w", encoding="utf-8") as f: + # json.dump(observation_metadata, f, ensure_ascii=False, indent=2) + # except Exception as e: + # print(f"Warning: Failed to save observation metadata for step {self.workflow_manager.current_step}: {e}") + + + # 在trajectory更新前,保存所有偶数位置的observation_metadata + # for pos in range(0, len(self.trajectory), 2): # 偶数位置是StateInfo + # if pos < len(self.trajectory): + # state_info = self.trajectory[pos] + # if isinstance(state_info, dict): + # observation_metadata = state_info.get("info", {}).get("observation_metadata", {}) + + # step_meta_path = os.path.join( + # self.trajectory_dir, + # f"observation_metadata_step_{self.workflow_manager.current_step:03d}_trajectory_pos_{pos:02d}.json" + # ) + # try: + # with open(step_meta_path, "w", encoding="utf-8") as f: + # json.dump(observation_metadata, f, ensure_ascii=False, indent=2) + # except Exception as e: + # print(f"Warning: Failed to save trajectory observation metadata for step {self.workflow_manager.current_step}, pos {pos}: {e}") + + prev_info_for_desc = prev_state_info.get("info") if isinstance(prev_state_info, dict) else None + + self.trajectory.append(executed_action) + if info is None: + info = { + "page": type('Page', (), {'url': ''})(), + 
"observation_metadata": {} + } + + # Create new state_info from new_observation + # 深拷贝info以避免trajectory中所有StateInfo引用同一个对象 + new_state_info = { + "observation": copy.deepcopy(new_observation), # new_observation is now observation format + "info": copy.deepcopy(info) + } + self.trajectory.append(new_state_info) + + # WebJudge logging + self.webjudge_thoughts.append(current_intention) # 记录本步意图 + use_html_format = False # 如需 HTML 标签样式,改为 True + formatter = self._format_action_for_webjudge_html if use_html_format else self._format_action_for_webjudge + self.webjudge_action_history.append(formatter(executed_action, prev_info_for_desc)) # 记录本步动作(使用执行前的 info 避免 DOM 变动导致缺元素) + self._capture_and_save_screenshot(step_number) # 记录本步截图 + + # Update current observation to stay in sync with trajectory + self.current_observation = new_observation + + # Update action_history in meta_data (required by DirectPromptConstructor) + # For multi-agent simulation, use simplified action descriptions to avoid dependency on complex observation_metadata + action_type = executed_action.get("action_type", "UNKNOWN") + + if action_type == "NONE": + action_str = f"Failed action: {executed_action.get('error', 'Unknown error')}" + elif action_type == "GOTO_URL": + url = executed_action.get("url", "unknown") + action_str = f"Navigate to {url}" + elif action_type == "CLICK": + element_id = executed_action.get("element_id", "unknown") + action_str = f"Click on element {element_id}" + elif action_type == "TYPE": + element_id = executed_action.get("element_id", "unknown") + text = executed_action.get("text", [""])[0] if executed_action.get("text") else "" + action_str = f"Type '{text}' into element {element_id}" + elif action_type == "SCROLL": + direction = executed_action.get("direction", "unknown") + action_str = f"Scroll {direction}" + elif action_type == "HOVER": + element_id = executed_action.get("element_id", "unknown") + action_str = f"Hover over element {element_id}" + else: + action_str = 
f"Action: {action_type}" + + self.meta_data["action_history"].append(action_str) + + # 5. Reflector Agent reflects on execution + reflection_result = self.reflector_agent.reflect_execution( + trajectory=self.trajectory, + intentions=self.intentions, + actions=self.actions, + current_intention=current_intention, + latest_action=executed_action, + current_observation=new_observation, + context_summary=context_result, + ) + + self.reflections.append(reflection_result) + + # Log reflector agent response summary + reflector_response = { + "effectiveness_result": reflection_result.get("effectiveness_analyzer", ""), + "pattern_result": reflection_result.get("pattern_detector", ""), + "triple_summary": reflection_result.get("triple_summary", "") + } + self.log_agent_response("reflector_agent", step_number, reflector_response) + + # 6. Record workflow step + self.workflow_manager.record_execution_step( + step_number=step_number, + intention=current_intention, + action=executed_action, + observation=new_observation, + reflection=reflection_result + ) + + return { + "should_terminate": False, + "step_number": step_number, + "context_result": context_result, + "planning_result": planning_result, + "execution_result": execution_result, + "reflection_result": reflection_result, + "new_observation": new_observation, + } + + + def _get_current_context_summary(self) -> Dict[str, Any]: + """Get current context summary without full update.""" + return self.context_agent.get_current_state() + + + \ No newline at end of file diff --git a/agent/planner/__init__.py b/agent/planner/__init__.py new file mode 100644 index 0000000..93e2c6a --- /dev/null +++ b/agent/planner/__init__.py @@ -0,0 +1,6 @@ +"""Planner Agent components for task decomposition and state analysis.""" + +from .task_decomposer import TaskDecomposer +from .current_state_analyzer import CurrentStateAnalyzer + +__all__ = ["TaskDecomposer", "CurrentStateAnalyzer"] \ No newline at end of file diff --git 
"""Current state analysis for planning Agent."""

from typing import Any, Dict, List


class CurrentStateAnalyzer:
    """Analyzes current state to determine current subtask and next atomic action."""

    def __init__(self, lm_config: "lm_config.LMConfig") -> None:
        # Language-model configuration used for every analysis call.
        self.lm_config = lm_config

    def analyze_current_state(
        self,
        user_goal: str,
        subtasks: List[str],
        current_observation: "Observation",
        context_summary: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Analyze current state to determine current subtask and next atomic action.

        Args:
            user_goal: Original user goal.
            subtasks: List of decomposed subtasks.
            current_observation: Current page observation.
            context_summary: Current context from Context Agent.

        Returns:
            Dictionary containing state analysis results.
        """
        # Project deps resolved lazily so the module stays importable without
        # the full runtime environment (e.g. for unit tests).
        from llms import call_llm
        from ..prompts.prompt_loader import load_prompt_template

        # Accept both StateInfo ({"observation": ..., "info": ...}) and a
        # bare observation.
        if isinstance(current_observation, dict) and "observation" in current_observation:
            obs_data = current_observation["observation"]
        else:
            obs_data = current_observation
        current_page_text = obs_data.get("text", "") if isinstance(obs_data, dict) else str(obs_data)

        # Context information from the Context Agent.
        observation_summary = context_summary.get("observation_summary", "")
        action_summary = context_summary.get("action_summary", "")
        reflection_summary = context_summary.get("reflection_summary", "")

        # Numbered subtask list for the prompt.
        subtasks_str = "\n".join(f"{i + 1}. {subtask}" for i, subtask in enumerate(subtasks))

        memory = context_summary.get("memory_content", "")
        # Choose the memory-augmented template when memory is available.
        if memory != "":
            prompt = load_prompt_template(
                "planner_agent",
                "current_state_analysis_w_mem",
                memory=memory,
                user_goal=user_goal,
                subtasks=subtasks_str,
                current_page_text=current_page_text,
                observation_summary=observation_summary,
                action_summary=action_summary,
                reflection_summary=reflection_summary,
            )
        else:
            prompt = load_prompt_template(
                "planner_agent",
                "current_state_analysis",
                user_goal=user_goal,
                subtasks=subtasks_str,
                current_page_text=current_page_text,
                observation_summary=observation_summary,
                action_summary=action_summary,
                reflection_summary=reflection_summary,
            )

        try:
            response = call_llm(
                self.lm_config, [{"role": "user", "content": prompt}]
            ).strip()
            # Parse the LLM response into a structured format.
            analysis = self._parse_analysis_response(response, subtasks)
        except Exception as e:
            # Fallback analysis when the LLM call or parsing fails.
            print(f"🎯 Current State Analyzer: Error in analysis: {str(e)}")
            analysis = self._generate_fallback_analysis(user_goal, subtasks, str(e))

        return analysis

    def _parse_analysis_response(self, response: str, subtasks: List[str]) -> Dict[str, Any]:
        """Parse the XML-tagged LLM state-analysis response into a structured dict."""
        import re

        # Default analysis structure.
        analysis = {
            "current_subtask": "",
            "next_atomic_action": "",
            "reasoning": "",
            "response": response,
        }

        try:
            # FIX: the patterns had lost their XML tags (they matched a bare
            # "(.*?)"); tags restored from the match-variable names.
            think_match = re.search(r"<think>(.*?)</think>", response, re.DOTALL | re.IGNORECASE)
            if think_match:
                analysis["reasoning"] = think_match.group(1).strip()

            current_subtask_match = re.search(
                r"<current_subtask>(.*?)</current_subtask>", response, re.DOTALL | re.IGNORECASE
            )
            if current_subtask_match:
                analysis["current_subtask"] = current_subtask_match.group(1).strip()

            next_action_match = re.search(
                r"<next_atomic_action>(.*?)</next_atomic_action>", response, re.DOTALL | re.IGNORECASE
            )
            if next_action_match:
                analysis["next_atomic_action"] = next_action_match.group(1).strip()
        except Exception:
            # If parsing fails, fall back to using the entire response.
            analysis["reasoning"] = response
            analysis["next_atomic_action"] = response

        # No subtask extracted: keyword-match the reasoning against the list.
        if not analysis["current_subtask"] and subtasks:
            reasoning_lower = analysis["reasoning"].lower()
            for subtask in subtasks:
                # Match when any significant word of the subtask appears.
                subtask_words = [w for w in subtask.lower().split() if len(w) > 3]
                if any(w in reasoning_lower for w in subtask_words):
                    analysis["current_subtask"] = subtask
                    break
            else:
                # Default to the first subtask when nothing matches.
                analysis["current_subtask"] = subtasks[0]

        # No next action extracted: use a simple fallback.
        if not analysis["next_atomic_action"]:
            print(f"🎯 Current State Analyzer: No next atomic action extracted, using current subtask: {analysis['current_subtask']}")
            analysis["next_atomic_action"] = f"Continue working on: {analysis['current_subtask']}"

        return analysis

    def _generate_fallback_analysis(self, user_goal: str, subtasks: List[str], error: str) -> Dict[str, Any]:
        """Generate a fallback state analysis when the LLM fails."""
        if subtasks:
            current_subtask = subtasks[0]  # default to the first subtask
        else:
            current_subtask = f"Work on: {user_goal}"

        next_action = f"Continue working on: {current_subtask}"

        return {
            "current_subtask": current_subtask,
            "next_atomic_action": next_action,
            "reasoning": f"Fallback analysis due to error: {error}",
        }
"""Task decomposition for planning Agent."""

from typing import Any, Dict, List

from browser_env.utils import Observation
from llms import lm_config, call_llm
from ..prompts.prompt_loader import load_prompt_template


class TaskDecomposer:
    """Decomposes complex tasks into manageable subtasks."""

    def __init__(self, lm_config: lm_config.LMConfig) -> None:
        self.lm_config = lm_config

    def decompose_task(
        self,
        user_goal: str,
        current_observation: Observation,
        context_summary: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Decompose user goal into 3-5 manageable subtasks.

        Args:
            user_goal: Original user goal
            current_observation: Current page observation (either a bare
                observation or a StateInfo-style ``{"observation": ...}`` dict)
            context_summary: Context info; ``"memory_content"`` (if present)
                selects the memory-aware prompt template

        Returns:
            Dictionary containing task decomposition results
        """
        # Accept both StateInfo format ({"observation": obs, "info": info})
        # and a direct observation payload.
        obs_data = current_observation
        if isinstance(current_observation, dict) and "observation" in current_observation:
            obs_data = current_observation["observation"]

        if isinstance(obs_data, dict):
            current_page_text = obs_data.get("text", "")
        else:
            current_page_text = str(obs_data)

        memory = context_summary.get("memory_content", "")
        # Pick the template variant depending on whether prior memory exists.
        if memory != "":
            prompt = load_prompt_template(
                "planner_agent",
                "task_decomposition_w_mem",
                memory=memory,
                user_goal=user_goal,
                current_page_text=current_page_text,
            )
        else:
            prompt = load_prompt_template(
                "planner_agent",
                "task_decomposition",
                user_goal=user_goal,
                current_page_text=current_page_text,
            )

        try:
            raw_response = call_llm(
                self.lm_config, [{"role": "user", "content": prompt}]
            ).strip()
            # Parse the LLM response into structured form.
            decomposition = self._parse_decomposition_response(raw_response)
        except Exception as err:
            # Any LLM failure falls back to heuristic decomposition.
            decomposition = self._generate_fallback_decomposition(user_goal, str(err))

        return decomposition

    def _parse_decomposition_response(self, response: str) -> Dict[str, Any]:
        """Parse LLM task decomposition response into structured format."""
        parsed: Dict[str, Any] = {
            "subtasks": [],
            "reasoning": response,
        }

        for raw_line in response.split('\n'):
            line = raw_line.strip()

            # Skip blank lines and list headers.
            if not line or line.lower().startswith(('here are', 'subtasks:', 'steps:', 'breakdown:')):
                continue

            # Strip leading numbering and bullet characters.
            candidate = line.lstrip('0123456789.-* ').lstrip('- ').lstrip('• ')

            # Drop a single recognised label prefix, if any.
            for label in ('Subtask:', 'Step:', 'Task:'):
                if candidate.startswith(label):
                    candidate = candidate[len(label):].strip()
                    break

            if candidate and len(candidate) > 10:  # Reasonable minimum length
                parsed["subtasks"].append(candidate)

        # If nothing parsed out, fall back to the whole response.
        if not parsed["subtasks"] and len(response.strip()) > 10:
            parsed["subtasks"] = [response.strip()]

        # Keep at most 5 subtasks.
        parsed["subtasks"] = parsed["subtasks"][:5]

        return parsed

    def _generate_fallback_decomposition(self, user_goal: str, error: str) -> Dict[str, Any]:
        """Generate fallback task decomposition when LLM fails."""
        goal_lower = user_goal.lower()
        subtasks: List[str] = []

        # Match the goal against common web task patterns.
        if any(keyword in goal_lower for keyword in ('search', 'find', 'look for')):
            subtasks.extend([
                "Navigate to search functionality",
                "Enter search query",
                "Review search results",
                "Select relevant option",
            ])
        elif any(keyword in goal_lower for keyword in ('buy', 'purchase', 'order', 'cart')):
            subtasks.extend([
                "Locate product or service to purchase",
                "Add item to shopping cart",
                "Proceed to checkout process",
                "Complete purchase",
            ])
        elif any(keyword in goal_lower for keyword in ('information', 'details', 'about')):
            subtasks.extend([
                "Look for information sections or links",
                "Navigate to relevant information pages",
                "Extract and review requested information",
            ])
        else:
            # Generic subtasks when no pattern matched.
            subtasks = [
                f"Get started with the task: {user_goal}",
                f"Make progress on: {user_goal}",
                f"Complete the task: {user_goal}",
            ]

        return {
            "subtasks": subtasks[:5],  # Limit to 5 subtasks
            "reasoning": f"Fallback decomposition due to error: {error}",
        }
"""Planner Agent for task decomposition and current state analysis."""

import re
from typing import Any, Dict, List, Optional

from browser_env.utils import Observation
from llms import lm_config

from .planner.task_decomposer import TaskDecomposer
from .planner.current_state_analyzer import CurrentStateAnalyzer


class PlannerAgent:
    """Decomposes complex tasks and analyzes current state for execution planning.

    Responsible for:
    1. Initial task decomposition into manageable subtasks
    2. Current state analysis to determine progress and next atomic action
    """

    def __init__(self, lm_config: lm_config.LMConfig) -> None:
        self.lm_config = lm_config
        self.task_decomposer = TaskDecomposer(lm_config)
        self.state_analyzer = CurrentStateAnalyzer(lm_config)

        # Planning state
        self.subtasks: List[str] = []        # decomposed subtask list
        self.current_step_index: int = 0     # index of the subtask believed current
        self.task_decomposed: bool = False   # decomposition runs once per task

    def generate_intention(
        self,
        user_goal: str,
        context_summary: Dict[str, Any],
        current_observation: Optional[Observation] = None,
        previous_intentions: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """Generate the next execution intention based on current state.

        Args:
            user_goal: Original user goal/task description
            context_summary: Current context from Context Agent
            current_observation: Current page observation
            previous_intentions: List of intentions already completed
                (not used in the current approach; kept for compatibility)

        Returns:
            Dictionary containing the selected intention and metadata
        """
        # Step 1: Perform task decomposition only on the first step.
        if not self.task_decomposed:
            print("🎯 Planner Agent: Decomposing task...")
            decomposition_result = self.task_decomposer.decompose_task(
                user_goal=user_goal,
                current_observation=current_observation or {"text": ""},
                context_summary=context_summary,
            )
            self.subtasks = decomposition_result.get("subtasks", [])
            self.task_decomposed = True

            print(f"🎯 Task decomposed into {len(self.subtasks)} subtasks:")
            for i, subtask in enumerate(self.subtasks, 1):
                print(f" {i}. {subtask}")
        else:
            print("🎯 Planner Agent: Analyzing current state...")

        # Step 2: Analyze current state to determine the current subtask
        # and the next atomic action.
        state_analysis = self.state_analyzer.analyze_current_state(
            user_goal=user_goal,
            subtasks=self.subtasks,
            current_observation=current_observation or {"text": ""},
            context_summary=context_summary,
        )

        current_subtask = state_analysis.get("current_subtask", "")
        next_atomic_action = state_analysis.get("next_atomic_action", "")
        reasoning = state_analysis.get("reasoning", "")
        response = state_analysis.get("response", "")

        # Fix: keep current_step_index in sync with the analyzer-reported
        # subtask whenever one is available. Previously the sync only ran
        # when no atomic action was present, so the index (exposed below in
        # planning_result) went stale in the common case.
        if current_subtask:
            matched_index = self._find_subtask_index(current_subtask)
            if matched_index is not None:
                self.current_step_index = matched_index

        # Select the intention: prefer the atomic action for precise
        # execution, then the subtask, then positional/goal fallbacks.
        if next_atomic_action:
            selected_intention = next_atomic_action
        elif current_subtask:
            selected_intention = current_subtask
        elif self.current_step_index < len(self.subtasks):
            selected_intention = self.subtasks[self.current_step_index]
        else:
            print(f"🎯 Planner Agent: No subtasks available, continuing with user goal: {user_goal}")
            selected_intention = f"Continue working on: {user_goal}"

        # Build planning result with comprehensive information.
        planning_result = {
            "intention": selected_intention,
            "current_subtask": current_subtask,
            "next_atomic_action": next_atomic_action,
            "reasoning": reasoning,
            "all_subtasks": self.subtasks,
            "current_step_index": self.current_step_index,
            "total_subtasks": len(self.subtasks),
            "task_decomposed": self.task_decomposed,
            "state_analysis": state_analysis,
            "user_goal": user_goal,
            "response": response,
        }

        return planning_result

    def reset_planning_state(self) -> None:
        """Reset planning state for a new task."""
        self.subtasks.clear()
        self.current_step_index = 0
        self.task_decomposed = False

    def _find_subtask_index(self, current_subtask: str) -> Optional[int]:
        """Find the index of a subtask, tolerating LLM-added numbering prefixes.

        Args:
            current_subtask: The subtask text from LLM analysis (may include
                numbering such as "1. " or "(2) ")

        Returns:
            Index of the matching subtask, or None if not found
        """
        # First try an exact match.
        if current_subtask in self.subtasks:
            return self.subtasks.index(current_subtask)

        def _strip_numbering(text: str) -> str:
            # Remove prefixes like "1. ", "2 ", then "(3) ".
            text = re.sub(r'^\s*\d+\.?\s*', '', text).strip()
            return re.sub(r'^\s*\(\d+\)\s*', '', text).strip()

        cleaned_subtask = _strip_numbering(current_subtask)

        # Exact match on the cleaned text.
        if cleaned_subtask in self.subtasks:
            return self.subtasks.index(cleaned_subtask)

        for i, subtask in enumerate(self.subtasks):
            # Remove numbering from the stored subtask too.
            cleaned_stored = _strip_numbering(subtask)

            # Case-insensitive exact match after numbering removal.
            if cleaned_subtask.lower().strip() == cleaned_stored.lower().strip():
                return i

            # Containment fallback for sufficiently long subtask texts.
            if len(cleaned_subtask) > 10 and cleaned_subtask.lower() in cleaned_stored.lower():
                return i

        return None
Ensure your response is valid JSON that can be directly parsed:\n{{\n \"title\": \"Brief, descriptive title for this experience (10-15 words)\",\n \"description\": \"One-sentence description of what was learned\",\n \"content\": \"Detailed description of the experience, including what worked, what didn\\'t, and key insights that could help with similar tasks\"\n}}\n\nImportant notes:\n1. Respond with ONLY the JSON object, no additional text or explanations\n2. Ensure all strings are properly escaped\n3. Focus on extracting meaningful patterns from the trajectory data\n4. Provide concrete insights that could help with similar tasks\n5. Keep the content informative but concise (2-3 paragraphs maximum)\n\nThis is an initial memory that will be refined later with more trajectory data.", + "variables": ["user_goal", "trajectory_summary", "task_completed"], + "temperature": 0.6, + "max_tokens": 2048 + }, + "memory_refinement": { + "name": "Memory Refinement", + "task": "Refine memory based on recent observations and actions", + "template": "You are an AI assistant that generates structured memories from web automation task execution.\n\nBased on the latest execution trajectory, refine and improve the memory entry by incorporating historical context and new information.\n\nTASK GOAL:\n{user_goal}\n\nINITIAL MEMORY:\n{initial_memory}\n\nHISTORICAL INTENTIONS:\n{history_intentions}\n\nHISTORICAL ACTIONS:\n{history_actions}\n\nLATEST EXECUTION TRAJECTORY:\n{trajectory_summary}\n\nTASK COMPLETION STATUS:\n{task_completed}\n\nPlease refine the initial memory based on the complete execution data, including historical context and latest trajectory. Please generate ONLY the structured memory entry in the following JSON format. 
Ensure your response is valid JSON that can be directly parsed:\n\n{{\n \"title\": \"Brief, descriptive title for this experience\",\n \"description\": \"Updated description incorporating results and learnings\",\n \"content\": \"Comprehensive description of the experience, including what worked, what didn\\'t, and key insights\"\n}}\n\nFocus on these key areas:\n1. How initial patterns evolved throughout the entire execution\n2. The complete journey from start to now, connecting historical actions with latest developments\n3. What ultimately led to success or failure, with specific references to critical actions\n4. Key lessons that would help with similar tasks in the future\n5. Unexpected discoveries, obstacles encountered, and how they were overcome\n6. Most effective strategies identified during the entire execution\n7. Integration of observations, actions, intentions and reflections from both historical and latest trajectory\n\nWhen refining:\n- Preserve valuable insights from the initial memory\n- Expand with new information from the latest trajectory\n- Consider the full context by connecting historical and latest events\n- Ensure the refined memory tells a coherent story of the entire task execution\n- Make the content actionable and useful for future similar tasks\n- Highlight patterns that emerged throughout the execution\n\nThe refined memory should be comprehensive, accurate, and provide valuable guidance for future task executions of similar nature.", + "variables": [ + "initial_memory", "user_goal", "history_intentions", "history_actions", "trajectory_summary", "task_completed" + ], + "temperature": 0.6, + "max_tokens": 2048 + } + }, + "context_agent": { + "observation_summarization": { + "name": "Observation Summarization", + "task": "Generate intelligent summary of web page observations", + "template": "Generate a concise but comprehensive summary of these web page observations for context awareness.\n\nPAGE 
OBSERVATIONS:\n{observations_text}\n\nFocus on:\n1. Key page content and structure\n2. Important elements, forms, or interactive components\n3. Navigation paths or menus\n4. Search functionality or filters\n5. Product/service information if applicable\n6. Any error messages or unusual elements\n\nProvide a 3-4 sentence summary that captures the essential information from these observations that would be most useful for planning next actions in a web automation task:", + "variables": ["observations_text"], + "temperature": 0.6, + "max_tokens": 2048 + }, + "summary_generation": { + "name": "Context Summary Generation", + "task": "Generate concise context summaries for agent coordination", + "template": "Generate a concise context summary for a web automation task.\n\nTASK INFORMATION:\nUser Goal: {user_goal}\nTotal Steps Taken: {total_steps}\n\nRECENT CONTEXT:\nRecent Page States: {observation_summary}\nRecent Actions Taken: {action_summary}\nRecent Performance: {reflection_summary}\n\nGenerate a brief summary (2-5 sentences) that captures:\n1. Current task progress based on observations and actions taken\n2. Key insights from recent observations and performance patterns\n", + "variables": [ + "user_goal", "total_steps", "observation_summary", "action_summary", "reflection_summary" + ], + "temperature": 0.7, + "max_tokens": 2048 + }, + "reflection_summarization": { + "name": "Reflection Summarization", + "task": "Generate intelligent summary of performance reflections", + "template": "Generate a concise but comprehensive summary of these performance reflections for context awareness.\n\nREFLECTION DATA:\n{reflections_text}\n\nFocus on:\n1. Overall effectiveness of recent actions and decision-making\n2. Detected execution patterns (repetitive actions, failures, stuck behaviors)\n3. Progress toward task completion and areas of success\n4. Potential issues or concerns that need attention\n5. 
Key insights that could improve future actions\n\nProvide a 3-4 sentence summary that captures the essential performance insights from these reflections that would be most useful for planning next actions in a web automation task:", + "variables": ["reflections_text"], + "temperature": 0.6, + "max_tokens": 2048 + } + }, + "planner_agent": { + "task_decomposition": { + "name": "Task Decomposition", + "task": "Decompose complex user goal into manageable subtasks", + "template": "Decompose the following web automation task into 3-5 manageable subtasks.\n\nUSER GOAL: {user_goal}\n\nCURRENT PAGE STATE:\n{current_page_text}\n\nFocus on breaking down the goal into logical, sequential steps that can be executed through web browser interactions. Each subtask should:\n1. Be specific and actionable\n2. Represent a meaningful step toward the overall goal\n3. Be executable through web browser interactions\n4. Move the task forward in a logical sequence\n\nExamples of good subtasks:\n- \"Navigate to search functionality\"\n- \"Enter search query and review results\"\n- \"Locate and select relevant items\"\n- \"Complete purchase process\"\n- \"Extract and verify required information\"\n\nGenerate 3-5 subtasks, each on a new line:", + "variables": ["user_goal", "current_page_text"], + "temperature": 0.7, + "max_tokens": 2048 + }, + "task_decomposition_w_mem": { + "name": "Task Decomposition with Memory", + "task": "Decompose complex user goal into manageable subtasks", + "template": "Decompose the following web automation task into 3-5 manageable subtasks.\n\nHere are some experiences and memories from previous related tasks for reference:\n{memory}\n\nCURRENT USER GOAL: {user_goal}\n\nCURRENT PAGE STATE:\n{current_page_text}\n\nFocus on breaking down the goal into logical, sequential steps that can be executed through web browser interactions. Each subtask should:\n1. Be specific and actionable\n2. Represent a meaningful step toward the overall goal\n3. 
Be executable through web browser interactions\n4. Move the task forward in a logical sequence\n\nExamples of good subtasks:\n- \"Navigate to search functionality\"\n- \"Enter search query and review results\"\n- \"Locate and select relevant items\"\n- \"Complete purchase process\"\n- \"Extract and verify required information\"\n\nGenerate 3-5 subtasks, each on a new line:", + "variables": ["memory","user_goal", "current_page_text"], + "temperature": 0.7, + "max_tokens": 2048 + }, + "current_state_analysis": { + "name": "Current State Analysis", + "task": "Analyze current state and determine next atomic action", + "template": "Analyze the current execution state briefly to determine which subtask is currently being worked on and what the next atomic action should be.\n\nUSER GOAL: {user_goal}\n\nTASK SUBTASKS:\n{subtasks}\n\nCURRENT PAGE STATE:\n{current_page_text}\n\nRECENT CONTEXT:\nObservations Summary: {observation_summary}\nActions Summary: {action_summary}\nPerformance Summary: {reflection_summary}\n\nINSTRUCTIONS:\nYou must respond with a structured analysis using the following exact XML format:\n\n\nYour step-by-step reasoning about the current state and next action\n\n\nThe exact subtask currently being worked on (must match one from the TASK SUBTASKS list)\n\n\nSpecific atomic action to execute next (e.g., \"click the search button\", \"type 'hello' into the input field\")\n\n\nEXAMPLE:\n\nThe user wants to search for products. We're currently on the homepage and need to start the search process. The first subtask is to navigate to search functionality, which we haven't completed yet.\n\n\nNavigate to search functionality\n\n\nClick on the search box to activate it\n\n\nANOTHER EXAMPLE:\n\nWe've successfully navigated to the search page and entered the query. Now we need to review the search results and select relevant items. 
This matches subtask 2.\n\n\nEnter search query and review results\n\n\nScroll down to view all search results\n", + "variables": ["user_goal", "subtasks", "current_page_text", "observation_summary", "action_summary", "reflection_summary"], + "temperature": 0.6, + "max_tokens": 2048 + }, + "current_state_analysis_w_mem": { + "name": "Current State Analysis with Memory", + "task": "Analyze current state and determine next atomic action", + "template": "Analyze the current execution state briefly to determine which subtask is currently being worked on and what the next atomic action should be.\n\nHere are some experiences and memories from previous related tasks for reference:\n{memory}\n\nCURRENT USER GOAL: {user_goal}\n\nTASK SUBTASKS:\n{subtasks}\n\nCURRENT PAGE STATE:\n{current_page_text}\n\nRECENT CONTEXT:\nObservations Summary: {observation_summary}\nActions Summary: {action_summary}\nPerformance Summary: {reflection_summary}\n\nINSTRUCTIONS:\nYou must respond with a structured analysis using the following exact XML format:\n\n\nYour step-by-step reasoning about the current state and next action\n\n\nThe exact subtask currently being worked on (must match one from the TASK SUBTASKS list)\n\n\nSpecific atomic action to execute next (e.g., \"click the search button\", \"type 'hello' into the input field\")\n\n\nEXAMPLE:\n\nThe user wants to search for products. We're currently on the homepage and need to start the search process. The first subtask is to navigate to search functionality, which we haven't completed yet.\n\n\nNavigate to search functionality\n\n\nClick on the search box to activate it\n\n\nANOTHER EXAMPLE:\n\nWe've successfully navigated to the search page and entered the query. Now we need to review the search results and select relevant items. 
This matches subtask 2.\n\n\nEnter search query and review results\n\n\nScroll down to view all search results\n", + "variables": ["memory","user_goal", "subtasks", "current_page_text", "observation_summary", "action_summary", "reflection_summary"], + "temperature": 0.6, + "max_tokens": 2048 + } + }, + "actor_agent": { + "action_execution": { + "name": "Browser Action Execution", + "task": "Generate specific browser actions to fulfill high-level intentions", + "template": "HIGH-LEVEL TASK INTENTION: {intention}\n\nCURRENT PAGE CONTEXT:\n{page_context}\n\nYour task is to execute specific browser actions to fulfill the stated intention.\n\nRequirements:\n1. Generate specific, actionable browser operations\n2. Use appropriate action types (CLICK, TYPE, GOTO, SCROLL, etc.)\n3. Include specific element identifiers when available\n4. Focus on efficiency and directness\n5. Handle any errors or page issues appropriately\n\n2. Generate specific, actionable browser operations\n3. Focus on completing the stated intention directly and efficiently\n\nExecute action: {intention}", + "variables": ["intention", "page_context"], + "temperature": 0.7, + "max_tokens": 2048 + } + }, + "reflector_agent": { + "action_validation": { + "name": "Action Execution Validation", + "task": "Validate whether web actions successfully completed their intended purpose", + "template": "Validate whether this web action successfully completed its intended purpose.\n\nINTENDED ACTION: {intended_action}\n\nEXECUTED ACTION:\n- Type: {action_type}\n- Element ID: {element_id}\n- Text: {action_text}\n\nBEFORE EXECUTION: {obs_before}...\nAFTER EXECUTION: {obs_after}...\n\nEvaluate:\n1. Did the action type match what was intended?\n2. Was the action executed successfully?\n3. Did the result show progress toward the intended action?\n4. 
Are there any error indicators in the result?\n\nProvide a validation result with:\n- success: true/false (Did it work as intended?)\n- confidence: 0.0-1.0 (How confident are you?)\n- reasoning: Brief explanation\n- issues: Any problems detected (list)", + "variables": ["intended_action", "action_type", "element_id", "action_text", "obs_before", "obs_after"], + "temperature": 0.3, + "max_tokens": 2048 + }, + "effectiveness_analysis": { + "name": "Action Effectiveness Analysis", + "task": "Analyze whether actions are making progress toward task completion", + "template": "Analyze the effectiveness of this web automation action in natural language.\n\nCURRENT INTENTION: {current_intention}\n\nEXECUTED ACTION:\n{latest_action}\n\nCURRENT PROGRESS METRICS:\n{context_summary}\n\nRECENT EXECUTION CONTEXT:\n{trajectory_summary}\n\nProvide a natural language analysis of the action's effectiveness. Focus on:\n1. Did it meaningfully progress the task?\n2. Was it the right action for the current situation?\n3. Did it avoid creating new problems?\n4. How well did it align with the intention?\n\nReturn a concise natural language assessment (2-4 sentences) describing whether the action was effective and why.", + "variables": ["current_intention", "latest_action", "context_summary", "trajectory_summary"], + "temperature": 0.5, + "max_tokens": 2048 + }, + "triple_summary": { + "name": "State Transition Triple Summary", + "task": "Generate enhanced summary of state transition (O_{t-1}, I_t, A_t, O_t)", + "template": "Generate an enhanced summary of this state transition in the web automation task.\n\nPREVIOUS OBSERVATION (O_[t-1]):\n{t-1}\n\nCURRENT INTENTION (I_t):\n{current_intention}\n\nEXECUTED ACTION (A_t):\n- Type: {action_type}\n- Element: {element_id}\n- Text: {action_text}\n\nCURRENT OBSERVATION (O_t):\n{current_observation}\n\nProvide a concise but informative summary that describes:\n1. What was the state before the action?\n2. 
What was the intention and action taken?\n3. What was the resulting state after the action?\n4. Did this represent meaningful progress toward the goal?\n\nGenerate a natural language summary (2-3 sentences) that captures the essence of this state transition.", + "variables": ["t-1", "current_intention", "action_type", "element_id", "action_text", "current_observation"], + "temperature": 0.6, + "max_tokens": 2048 + } + } + } +} \ No newline at end of file diff --git a/agent/prompts/prompt_constructor.py b/agent/prompts/prompt_constructor.py index 5c50c9d..ebd8dc7 100644 --- a/agent/prompts/prompt_constructor.py +++ b/agent/prompts/prompt_constructor.py @@ -380,19 +380,20 @@ def get_lm_api_input( "image_url": {"url": pil_to_b64(page_screenshot_img)}, }, ] - for image_i, image in enumerate(images): - content.extend( - [ - { - "type": "text", - "text": f"({image_i+2}) input image {image_i+1}", - }, - { - "type": "image_url", - "image_url": {"url": pil_to_b64(image)}, - }, - ] - ) + if images is not None: + for image_i, image in enumerate(images): + content.extend( + [ + { + "type": "text", + "text": f"({image_i+2}) input image {image_i+1}", + }, + { + "type": "image_url", + "image_url": {"url": pil_to_b64(image)}, + }, + ] + ) content = [{"type": "text", "text": current_prompt}] + content message.append({"role": "user", "content": content}) diff --git a/agent/prompts/prompt_loader.py b/agent/prompts/prompt_loader.py new file mode 100644 index 0000000..66321b0 --- /dev/null +++ b/agent/prompts/prompt_loader.py @@ -0,0 +1,317 @@ +"""Modular prompt loader for multi-agent system. + +This module provides a centralized way to load and manage prompt templates +for all agents in the multi-agent coordination system. 
+""" + +import json +import os +from pathlib import Path +from typing import Dict, Any, Optional, List +from llms import lm_config, call_llm + + +class PromptLoader: + """Centralized prompt loading and management system.""" + + def __init__(self, prompts_file: Optional[str] = None): + """Initialize the prompt loader. + + Args: + prompts_file: Path to the prompts JSON configuration file + """ + self.prompts_file = prompts_file or self._get_default_prompts_file() + self.prompts_data = self._load_prompts() + + def _get_default_prompts_file(self) -> str: + """Get the default prompts file path.""" + return str(Path(__file__).parent / "multi_agent_prompts_fixed.json") + + def _load_prompts(self) -> Dict[str, Any]: + """Load prompts from the JSON configuration file.""" + try: + with open(self.prompts_file, 'r', encoding='utf-8') as f: + return json.load(f) + except FileNotFoundError: + print(f"⚠️ Prompts file not found: {self.prompts_file}") + print("⚠️ Using fallback prompts") + return self._get_fallback_prompts() + except json.JSONDecodeError as e: + print(f"⚠️ Invalid JSON in prompts file: {e}") + print("⚠️ Using fallback prompts") + return self._get_fallback_prompts() + + def _get_fallback_prompts(self) -> Dict[str, Any]: + """Get minimal fallback prompts when file loading fails.""" + return { + "context_agent": { + "summary_generation": { + "name": "Context Summary Generation", + "task": "Generate a brief context summary for web automation task.", + "template": "Task progress: Step {current_step}/{total_steps}, {completion_pct:.0%} complete." 
+ } + }, + "planner_agent": { + "intention_generation": { + "name": "Intention Generation", + "task": "Generate next action for web automation", + "template": "Generate a specific web action to complete: {user_goal}" + } + }, + "actor_agent": { + "action_execution": { + "name": "Action Execution", + "task": "Execute browser actions", + "template": "Execute browser action to: {intention}" + } + }, + "reflector_agent": { + "action_validation": { + "name": "Action Validation", + "task": "Validate web actions", + "template": "Did this action complete the intended task?" + } + } + } + + def get_prompt_template(self, agent_type: str, prompt_name: str) -> Optional[Dict[str, Any]]: + """Get a specific prompt template. + + Args: + agent_type: Type of agent (context_agent, planner_agent, etc.) + prompt_name: Name of the prompt template + + Returns: + Prompt template dictionary or None if not found + """ + try: + return self.prompts_data["agents"][agent_type][prompt_name] + except KeyError: + print(f"⚠️ Prompt not found: {agent_type}.{prompt_name}") + return None + + def format_prompt(self, agent_type: str, prompt_name: str, **kwargs) -> str: + """Format a prompt template with provided variables. 
+ + Args: + agent_type: Type of agent + prompt_name: Name of the prompt template + **kwargs: Variables to substitute in the template + + Returns: + Formatted prompt string + """ + template_data = self.get_prompt_template(agent_type, prompt_name) + if not template_data: + return f"Prompt template not found: {agent_type}.{prompt_name}" + + template = template_data.get("template", "") + if not template: + return f"No template found for: {agent_type}.{prompt_name}" + + try: + return template.format(**kwargs) + except KeyError as e: + print(f"⚠️ Missing variable in prompt template: {e}") + return f"Template error: missing variable {e} in {agent_type}.{prompt_name}" + + def generate_llm_prompt( + self, + agent_type: str, + prompt_name: str, + lm_config: lm_config.LMConfig, + context_vars: Optional[Dict[str, Any]] = None + ) -> str: + """Generate an LLM prompt using the template system. + + Args: + agent_type: Type of agent + prompt_name: Name of the prompt template + lm_config: Language model configuration + context_vars: Variables to substitute in template + + Returns: + Generated prompt string + """ + template_data = self.get_prompt_template(agent_type, prompt_name) + if not template_data: + return f"Prompt template not found: {agent_type}.{prompt_name}" + + template = template_data.get("template", "") + if not template: + return f"No template found for: {agent_type}.{prompt_name}" + + # Add context variables to template if provided + if context_vars: + try: + template = template.format(**context_vars) + except KeyError as e: + print(f"⚠️ Missing variable in prompt template: {e}") + template = f"Template error: missing variable {e}" + + # Get LLM parameters from template + temperature = template_data.get("temperature", 0.7) + max_tokens = template_data.get("max_tokens", 500) + + # Generate the actual prompt using LLM + full_prompt = f"""Task: {template_data.get('task', '')} + +Context: +{template} + +Requirements: +- Be specific and actionable +- Focus on completing 
the stated intention +- Use appropriate web automation actions +- Handle any errors or issues appropriately""" + + if lm_config.mode == "chat": + messages = [ + {"role": "system", "content": f"You are a web automation assistant. {template_data.get('task', '')}"}, + {"role": "user", "content": full_prompt} + ] + return messages + else: + return full_prompt + + def get_all_prompts_for_agent(self, agent_type: str) -> Dict[str, Any]: + """Get all prompt templates for a specific agent type. + + Args: + agent_type: Type of agent + + Returns: + Dictionary of all prompt templates for the agent + """ + try: + return self.prompts_data["agents"][agent_type] + except KeyError: + print(f"⚠️ Agent type not found in prompts: {agent_type}") + return {} + + def reload_prompts(self) -> None: + """Reload the prompts file.""" + self.prompts_data = self._load_prompts() + print(f"🔄 Reloaded prompts from: {self.prompts_file}") + + def list_available_prompts(self) -> None: + """List all available prompt templates.""" + print("📋 Available Prompt Templates:") + print("=" * 50) + + agents = self.prompts_data.get("agents", {}) + for agent_type, agent_prompts in agents.items(): + print(f"\n🤖 {agent_type}:") + for prompt_name, prompt_data in agent_prompts.items(): + task = prompt_data.get("task", "No task description") + print(f" 📝 {prompt_name}: {task}") + + def validate_prompt_template(self, template_str: str, variables: List[str]) -> Dict[str, Any]: + """Validate that a template contains all required variables. 
+ + Args: + template_str: Template string to validate + variables: List of required variables + + Returns: + Validation result with missing variables + """ + missing_vars = [] + for var in variables: + if f"{{{var}}}" not in template_str: + missing_vars.append(var) + + return { + "valid": len(missing_vars) == 0, + "missing_variables": missing_vars, + "template": template_str + } + + def create_custom_prompt( + self, + agent_type: str, + prompt_name: str, + template: str, + task_description: str, + temperature: float = 0.7, + max_tokens: int = 500 + ) -> Dict[str, Any]: + """Create a custom prompt template and add it to the system. + + Args: + agent_type: Type of agent + prompt_name: Name for the new prompt + template: Template string with variable placeholders + task_description: Description of what the prompt does + temperature: LLM temperature for this prompt + max_tokens: Maximum tokens for this prompt + + Returns: + Created prompt template + """ + custom_prompt = { + "name": prompt_name, + "task": task_description, + "template": template, + "temperature": temperature, + "max_tokens": max_tokens + } + + # Add to prompts data + if "agents" not in self.prompts_data: + self.prompts_data["agents"] = {} + if agent_type not in self.prompts_data["agents"]: + self.prompts_data["agents"][agent_type] = {} + + self.prompts_data["agents"][agent_type][prompt_name] = custom_prompt + + print(f"✅ Created custom prompt: {agent_type}.{prompt_name}") + return custom_prompt + + +# Global prompt loader instance +_global_prompt_loader = None + + +def get_prompt_loader(prompts_file: Optional[str] = None) -> PromptLoader: + """Get the global prompt loader instance.""" + global _global_prompt_loader + if _global_prompt_loader is None: + _global_prompt_loader = PromptLoader(prompts_file) + return _global_prompt_loader + + +def load_prompt_template(agent_type: str, prompt_name: str, **kwargs) -> str: + """Quick function to load and format a prompt template. 
+ + Args: + agent_type: Type of agent + prompt_name: Name of the prompt template + **kwargs: Variables to substitute in template + + Returns: + Formatted prompt string + """ + loader = get_prompt_loader() + return loader.format_prompt(agent_type, prompt_name, **kwargs) + + +def generate_llm_prompt_from_template( + agent_type: str, + prompt_name: str, + lm_config: lm_config.LMConfig, + context_vars: Optional[Dict[str, Any]] = None +) -> str: + """Quick function to generate LLM prompt from template. + + Args: + agent_type: Type of agent + prompt_name: Name of the prompt template + lm_config: Language model configuration + context_vars: Variables to substitute in template + + Returns: + Generated prompt for LLM + """ + loader = get_prompt_loader() + return loader.generate_llm_prompt(agent_type, prompt_name, lm_config, context_vars) \ No newline at end of file diff --git a/agent/reflector/__init__.py b/agent/reflector/__init__.py new file mode 100644 index 0000000..6f02fa4 --- /dev/null +++ b/agent/reflector/__init__.py @@ -0,0 +1,6 @@ +"""Reflector Agent components for execution validation and reflection.""" + +from .effectiveness_analyzer import EffectivenessAnalyzer +from .pattern_detector import PatternDetector + +__all__ = ["EffectivenessAnalyzer", "PatternDetector"] \ No newline at end of file diff --git a/agent/reflector/effectiveness_analyzer.py b/agent/reflector/effectiveness_analyzer.py new file mode 100644 index 0000000..a76ca24 --- /dev/null +++ b/agent/reflector/effectiveness_analyzer.py @@ -0,0 +1,98 @@ +"""Simplified Effectiveness analysis for Reflector Agent.""" + +from typing import Any, Dict, List + +from browser_env import Action, Trajectory +from llms import lm_config, call_llm +from ..prompts.prompt_loader import load_prompt_template + + +class EffectivenessAnalyzer: + """Analyzes the effectiveness of actions in progressing toward the goal.""" + + def __init__(self, lm_config: lm_config.LMConfig) -> None: + self.lm_config = lm_config + + def 
analyze( + self, + trajectory: Trajectory, + current_intention: str, + latest_action: Action, + context_summary: Dict[str, Any], + ) -> str: + """Analyze the effectiveness of the latest action. + + Args: + trajectory: Current execution trajectory + current_intention: The intention that was being fulfilled + latest_action: The most recently executed action + context_summary: Current context from Context Agent + + Returns: + Natural language response about action effectiveness + """ + # Get recent trajectory context + recent_context = self._extract_recent_context(trajectory, latest_action) + + # Convert text IDs back to readable string + text_ids = latest_action.get("text", []) + if isinstance(text_ids, list) and text_ids: + try: + # Import the ID to key mapping from browser_env + from browser_env.actions import _id2key + action_text = ''.join(_id2key[id_num] if 0 <= id_num < len(_id2key) else '?' for id_num in text_ids) + except (ImportError, IndexError): + # Fallback: try to convert IDs to characters directly + action_text = ''.join(chr(id_num) if 32 <= id_num <= 126 else '?' 
for id_num in text_ids) + else: + action_text = "N/A" + + # Get context summary text + summary_text = context_summary.get("summary", "No context summary available") + + # Build analysis prompt using template + prompt = load_prompt_template( + "reflector_agent", + "effectiveness_analysis", + user_goal="Web automation task", + trajectory_summary=recent_context, + current_intention=current_intention, + latest_action=f"Type: {latest_action.get('action_type', 'UNKNOWN')}, Element: {latest_action.get('element_id', 'N/A')}, Details: {action_text}", + context_summary=summary_text[:500] if len(summary_text) > 500 else summary_text + ) + + try: + response = call_llm( + self.lm_config, [{"role": "user", "content": prompt}] + ).strip() + return response + except Exception as e: + # Fallback effectiveness analysis + action_type = latest_action.get("action_type", "UNKNOWN") + return f"Fallback analysis: Action '{action_type}' executed for intention: {current_intention[:100]}." + + def _extract_recent_context(self, trajectory: Trajectory, latest_action: Action) -> str: + """Extract relevant context from the recent trajectory.""" + context_parts = [] + + # Get last 3 action-observation pairs + recent_steps = trajectory[-6:] # Last 3 pairs (obs-action-obs-action-obs-action-obs) + + page_count = 1 + action_count = 1 + + for step in recent_steps: + if isinstance(step, dict): + # Check if it's an observation (StateInfo) by looking for 'observation' key + if 'observation' in step: + obs_text = step.get("observation", {}).get("text", "")[:200] + context_parts.append(f"Page {page_count}: {obs_text}...") + page_count += 1 + # Check if it's an action by looking for 'action_type' key + elif 'action_type' in step: + action_type = step.get("action_type", "UNKNOWN") + element_id = step.get("element_id", "N/A") + context_parts.append(f"Action {action_count}: {action_type} on {element_id}") + action_count += 1 + + return " | ".join(context_parts) \ No newline at end of file diff --git 
a/agent/reflector/pattern_detector.py b/agent/reflector/pattern_detector.py new file mode 100644 index 0000000..a72079e --- /dev/null +++ b/agent/reflector/pattern_detector.py @@ -0,0 +1,330 @@ +"""Pattern detection for Reflector Agent.""" + +from typing import Any, Dict, List + +from browser_env import Action + + +class PatternDetector: + """Detects patterns in action execution and intention sequences.""" + + def __init__(self) -> None: + self.detection_history: List[Dict[str, Any]] = [] + + def detect_patterns( + self, actions: List[Action], intentions: List[str] + ) -> Dict[str, Any]: + """Detect patterns in the execution history. + + Args: + actions: List of executed actions + intentions: List of intentions that were being fulfilled + + Returns: + Dictionary containing detected patterns + """ + if len(actions) < 3: + return { + "patterns_detected": False, + "message": "Insufficient data for pattern detection", + "total_actions": len(actions), + "total_intentions": len(intentions), + } + + patterns = {} + + # Detect repetitive action patterns + patterns["repetitive_action_patterns"] = self._detect_repetitive_actions(actions) + + # Detect repetitive intention patterns + patterns["repetitive_intention_patterns"] = self._detect_repetitive_intentions(intentions) + + # Detect action sequence patterns + patterns["action_sequence_patterns"] = self._detect_action_sequences(actions) + + # Detect failure patterns + patterns["failure_patterns"] = self._detect_failure_patterns(actions) + + # Detect stuck patterns (no progress) + patterns["stuck_patterns"] = self._detect_stuck_patterns(actions, intentions) + + # Calculate overall pattern metrics + patterns["pattern_metrics"] = self._calculate_pattern_metrics(patterns) + + # Store detection results + detection_record = { + "timestamp": None, # Would be set in actual implementation + "total_actions": len(actions), + "total_intentions": len(intentions), + "detected_patterns": patterns, + } + 
self.detection_history.append(detection_record) + + return { + "patterns_detected": True, + "detected_patterns": patterns, + "total_actions": len(actions), + "total_intentions": len(intentions), + "detection_summary": self._generate_summary(patterns), + } + + def _detect_repetitive_actions(self, actions: List[Action]) -> List[Dict[str, Any]]: + """Detect patterns of repetitive actions.""" + repetitive_patterns = [] + + # Group actions by type + action_groups = {} + for i, action in enumerate(actions): + action_type = action.get("action_type", "UNKNOWN") + if action_type not in action_groups: + action_groups[action_type] = [] + action_groups[action_type].append((i, action)) + + # Look for consecutive repetitions + for action_type, action_list in action_groups.items(): + if len(action_list) < 3: # Need at least 3 to detect pattern + continue + + # Check for consecutive repetitions + consecutive_count = 1 + max_consecutive = 1 + for i in range(1, len(action_list)): + if action_list[i][0] == action_list[i-1][0] + 1: # Consecutive indices + consecutive_count += 1 + max_consecutive = max(max_consecutive, consecutive_count) + else: + consecutive_count = 1 + + if max_consecutive >= 3: # 3 or more consecutive same actions + repetitive_patterns.append({ + "action_type": action_type, + "max_consecutive": max_consecutive, + "total_count": len(action_list), + "pattern_type": "consecutive_repetition", + }) + + # Check for cyclic patterns + if len(actions) >= 6: + last_6_actions = [a.get("action_type", "UNKNOWN") for a in actions[-6:]] + if len(set(last_6_actions)) <= 2: # Only 2 or fewer unique action types + repetitive_patterns.append({ + "action_types": list(set(last_6_actions)), + "sequence": last_6_actions, + "pattern_type": "cyclic_pattern", + }) + + return repetitive_patterns + + def _detect_repetitive_intentions(self, intentions: List[str]) -> List[Dict[str, Any]]: + """Detect patterns of repetitive intentions.""" + repetitive_patterns = [] + + if len(intentions) < 3: + 
return repetitive_patterns + + # Check for duplicate intentions + intention_counts = {} + for i, intention in enumerate(intentions): + # Normalize intention text for comparison + normalized = intention.lower().strip() + if normalized not in intention_counts: + intention_counts[normalized] = [] + intention_counts[normalized].append((i, intention)) + + # Find repeated intentions + for normalized, occurrences in intention_counts.items(): + if len(occurrences) > 1: + repetitive_patterns.append({ + "intention": occurrences[0][1], # Original text + "count": len(occurrences), + "positions": [occ[0] for occ in occurrences], + "pattern_type": "repeated_intention", + }) + + # Check for very similar intentions + for i in range(len(intentions) - 2): + current = intentions[i].lower() + next_one = intentions[i + 1].lower() + next_two = intentions[i + 2].lower() + + # Calculate similarity (simple word overlap) + similarity_score = self._calculate_similarity(current, next_two) + if similarity_score > 0.8: # High similarity + repetitive_patterns.append({ + "intentions": [intentions[i], intentions[i + 2]], + "similarity_score": similarity_score, + "pattern_type": "similar_intentions", + }) + + return repetitive_patterns + + def _detect_action_sequences(self, actions: List[Action]) -> List[Dict[str, Any]]: + """Detect common action sequences.""" + sequence_patterns = [] + + if len(actions) < 4: + return sequence_patterns + + # Look for 3-action sequences + action_types = [a.get("action_type", "UNKNOWN") for a in actions] + + # Count 3-action sequences + sequence_counts = {} + for i in range(len(action_types) - 2): + sequence = tuple(action_types[i:i+3]) + if sequence not in sequence_counts: + sequence_counts[sequence] = 0 + sequence_counts[sequence] += 1 + + # Find repeated sequences + for sequence, count in sequence_counts.items(): + if count >= 2: # Sequence appeared at least twice + sequence_patterns.append({ + "sequence": list(sequence), + "count": count, + "pattern_type": 
"repeated_sequence", + }) + + return sequence_patterns + + def _detect_failure_patterns(self, actions: List[Action]) -> List[Dict[str, Any]]: + """Detect patterns related to action failures.""" + failure_patterns = [] + + # Look for NONE actions (indicating failures) + none_actions = [(i, a) for i, a in enumerate(actions) if a.get("action_type") == "NONE"] + + if len(none_actions) > 0: + failure_patterns.append({ + "pattern_type": "none_actions", + "count": len(none_actions), + "positions": [pos for pos, _ in none_actions], + "failure_rate": len(none_actions) / len(actions), + }) + + # Look for repeated failed action types + if len(none_actions) >= 2: + # Check if NONE actions are clustered + for i in range(len(none_actions) - 1): + current_pos = none_actions[i][0] + next_pos = none_actions[i + 1][0] + + if next_pos - current_pos <= 3: # Clustered failures + failure_patterns.append({ + "pattern_type": "clustered_failures", + "start_position": current_pos, + "end_position": next_pos, + "span": next_pos - current_pos, + }) + + return failure_patterns + + def _detect_stuck_patterns(self, actions: List[Action], intentions: List[str]) -> List[Dict[str, Any]]: + """Detect patterns indicating agent is stuck.""" + stuck_patterns = [] + + # Check for high repetition of simple actions + simple_actions = ["CLICK", "SCROLL", "KEY_PRESS"] + simple_action_count = sum(1 for a in actions if a.get("action_type") in simple_actions) + + if simple_action_count / len(actions) > 0.8: # 80% or more are simple actions + stuck_patterns.append({ + "pattern_type": "excessive_simple_actions", + "simple_action_ratio": simple_action_count / len(actions), + "threshold": 0.8, + }) + + # Check for lack of diverse action types + unique_action_types = len(set(a.get("action_type", "UNKNOWN") for a in actions)) + if unique_action_types <= 2 and len(actions) >= 5: + stuck_patterns.append({ + "pattern_type": "low_action_diversity", + "unique_types": unique_action_types, + "total_actions": len(actions), 
+ }) + + # Check for repeating similar intentions + if len(intentions) >= 4: + recent_intentions = intentions[-4:] + normalized = [i.lower().strip() for i in recent_intentions] + unique_normalized = set(normalized) + + if len(unique_normalized) <= 2: # Only 2 or fewer unique intentions + stuck_patterns.append({ + "pattern_type": "repetitive_intentions", + "unique_count": len(unique_normalized), + "total_count": len(recent_intentions), + }) + + return stuck_patterns + + def _calculate_pattern_metrics(self, patterns: Dict[str, Any]) -> Dict[str, Any]: + """Calculate overall metrics about detected patterns.""" + metrics = { + "total_pattern_types": 0, + "high_severity_patterns": 0, + "pattern_density": 0.0, + } + + # Count different pattern types + for pattern_type, pattern_list in patterns.items(): + if pattern_list: # Non-empty pattern list + metrics["total_pattern_types"] += 1 + + # Identify high severity patterns + high_severity_indicators = [ + "clustered_failures", "excessive_simple_actions", "low_action_diversity" + ] + + for pattern_list in patterns.values(): + for pattern in pattern_list: + if pattern.get("pattern_type") in high_severity_indicators: + metrics["high_severity_patterns"] += 1 + + # Calculate pattern density (patterns per action) + total_patterns = sum(len(pattern_list) for pattern_list in patterns.values()) + total_actions = patterns.get("pattern_metrics", {}).get("total_actions", 1) + metrics["pattern_density"] = total_patterns / total_actions if total_actions > 0 else 0.0 + + return metrics + + def _calculate_similarity(self, text1: str, text2: str) -> float: + """Calculate similarity between two text strings.""" + words1 = set(text1.split()) + words2 = set(text2.split()) + + if not words1 and not words2: + return 1.0 + if not words1 or not words2: + return 0.0 + + intersection = words1 & words2 + union = words1 | words2 + + return len(intersection) / len(union) + + def _generate_summary(self, patterns: Dict[str, Any]) -> str: + """Generate a 
human-readable summary of detected patterns.""" + summary_parts = [] + + for pattern_type, pattern_list in patterns.items(): + if pattern_list: + if pattern_type == "repetitive_action_patterns": + summary_parts.append(f"Found {len(pattern_list)} repetitive action patterns") + elif pattern_type == "failure_patterns": + summary_parts.append(f"Found {len(pattern_list)} failure-related patterns") + elif pattern_type == "stuck_patterns": + summary_parts.append(f"Found {len(pattern_list)} patterns suggesting agent is stuck") + + return "; ".join(summary_parts) if summary_parts else "No significant patterns detected" + + def get_pattern_history(self, count: int = 5) -> List[Dict[str, Any]]: + """Get the most recent pattern detections. + + Args: + count: Number of recent detections to return + + Returns: + List of recent pattern detection records + """ + return self.detection_history[-count:] if self.detection_history else [] \ No newline at end of file diff --git a/agent/reflector_agent.py b/agent/reflector_agent.py new file mode 100644 index 0000000..59baacb --- /dev/null +++ b/agent/reflector_agent.py @@ -0,0 +1,159 @@ +"""Reflector Agent for execution validation, analysis, and recovery suggestions.""" + +from typing import Any, Dict, List, Optional + +from browser_env import Action, Trajectory +from browser_env.utils import Observation +from llms import lm_config, call_llm + +from .reflector.effectiveness_analyzer import EffectivenessAnalyzer +from .reflector.pattern_detector import PatternDetector +from .prompts.prompt_loader import load_prompt_template + + +class ReflectorAgent: + """Simplified reflector agent for execution analysis. + + Responsible for analyzing action effectiveness and detecting execution patterns + to provide insights for better decision making. 
+ """ + + def __init__(self, lm_config: lm_config.LMConfig) -> None: + self.lm_config = lm_config + self.effectiveness_analyzer = EffectivenessAnalyzer(lm_config) + self.pattern_detector = PatternDetector() + + # Reflection history + self.reflection_history: List[Dict[str, Any]] = [] + + def reflect_execution( + self, + trajectory: Trajectory, + intentions: List[str], + actions: List[Action], + current_intention: str, + latest_action: Action, + current_observation: Observation, + context_summary: Dict[str, Any], + ) -> Dict[str, Any]: + """Reflect on the execution of the latest action. + + Args: + trajectory: Current execution trajectory + intentions: List of all intentions so far + actions: List of all actions executed so far + current_intention: The intention that was being fulfilled + latest_action: The most recently executed action + current_observation: The observation after action execution + context_summary: Current context from Context Agent + + Returns: + Dictionary containing simplified reflection results + """ + try: + # 1. Analyze action effectiveness (natural language only) + effectiveness_result = self.effectiveness_analyzer.analyze( + trajectory=trajectory, + current_intention=current_intention, + latest_action=latest_action, + context_summary=context_summary, + ) + + # 2. Detect execution patterns + pattern_result = self.pattern_detector.detect_patterns(actions, intentions) + + # 3. 
Generate triple summary + triple_summary = self._generate_enhanced_triple_summary( + trajectory, current_intention, latest_action, current_observation + ) + + # Create simplified reflection + reflection = { + "effectiveness_analyzer": effectiveness_result, + "pattern_detector": pattern_result, + "triple_summary": triple_summary, + "current_intention": current_intention, + "latest_action": latest_action, + "reflection_number": len(self.reflection_history) + 1, + } + + # Store in reflection history + self.reflection_history.append(reflection) + + return reflection + + except Exception as e: + # Create error reflection + error_reflection = { + "effectiveness_analyzer": f"Error during effectiveness analysis: {str(e)}", + "pattern_detector": f"Error during pattern detection: {str(e)}", + "triple_summary": f"Error during triple summary generation: {str(e)}", + "current_intention": current_intention, + "latest_action": latest_action, + "reflection_number": len(self.reflection_history) + 1, + } + + self.reflection_history.append(error_reflection) + return error_reflection + + def _generate_enhanced_triple_summary( + self, + trajectory: Trajectory, + current_intention: str, + latest_action: Action, + current_observation: Observation, + ) -> str: + """Generate an enhanced (O_{t-1}, I_t, A_t, O_t) triple summary using LLM.""" + + # Get previous observation from trajectory + obs_before = "No previous observation available" + if len(trajectory) >= 2: + obs_before = trajectory[-3].get("observation", {}).get("text", "")[:200] + if len(trajectory[-3].get("observation", {}).get("text", "")) > 200: + obs_before += "..." + + obs_after = current_observation.get("text", "")[:200] + if len(current_observation.get("text", "")) > 200: + obs_after += "..." 
+ + action_type = latest_action.get("action_type", "UNKNOWN") + element_id = latest_action.get("element_id", "N/A") + + # Correctly convert text IDs back to string + text_ids = latest_action.get("text", []) + if isinstance(text_ids, list) and text_ids: + try: + # Import the ID to key mapping from browser_env + from browser_env.actions import _id2key + action_text = ''.join(_id2key[id_num] if 0 <= id_num < len(_id2key) else '?' for id_num in text_ids) + except (ImportError, IndexError): + # Fallback: try to convert IDs to characters directly + action_text = ''.join(chr(id_num) if 32 <= id_num <= 126 else '?' for id_num in text_ids) + else: + action_text = "N/A" + + # Build enhanced triple summary prompt + prompt = load_prompt_template( + "reflector_agent", + "triple_summary", + **{"t-1": obs_before}, + current_intention=current_intention, + action_type=action_type, + element_id=element_id, + action_text=action_text, + current_observation=obs_after + ) + + try: + response = call_llm( + self.lm_config, [{"role": "user", "content": prompt}] + ).strip() + return response + except Exception as e: + # Fallback simple triple summary + success_indicator = "Success" if action_type not in ["NONE", "STOP"] else "Failed" + return f"Intent: {current_intention[:50]}{'...' 
if len(current_intention) > 50 else ''} | Action: {action_type} on {element_id} | Result: {success_indicator} (Enhanced summary unavailable: {str(e)})" + + def reset_reflection_history(self) -> None: + """Reset reflection history for a new task.""" + self.reflection_history.clear() \ No newline at end of file diff --git a/browser_env/envs.py b/browser_env/envs.py index ef326bb..f7a8f3f 100644 --- a/browser_env/envs.py +++ b/browser_env/envs.py @@ -137,14 +137,17 @@ def __init__( ) @beartype - def setup(self, config_file: Path | None = None) -> None: + def setup(self, config_file: Path | None = None, instance_config: dict | None = None) -> None: self.context_manager = sync_playwright() self.playwright = self.context_manager.__enter__() self.browser = self.playwright.chromium.launch( headless=self.headless, slow_mo=self.slow_mo ) - if config_file: + if instance_config is not None: + # Use provided instance config + pass + elif config_file: with open(config_file, "r") as f: instance_config = json.load(f) else: @@ -235,19 +238,26 @@ def reset( Reset the environment. :param options: options for the environment. The current supported options are: - "storage_state": the storage state of the browser. It is a file path to a json file. + - "start_url": the initial URL to navigate to. 
""" super().reset(seed=seed, options=options) if self.reset_finished: self.context_manager.__exit__() - if options is not None and "config_file" in options: + if options is not None and "config_file" in options and options["config_file"] is not None: config_file = Path(options["config_file"]) if config_file.exists(): self.setup(config_file=config_file) else: raise ValueError(f"Config file {config_file} does not exist.") else: - self.setup() + # Create instance config from options + instance_config = {} + if options is not None: + # Copy all options to instance_config to maintain compatibility + # with original config file structure + instance_config.update(options) + self.setup(config_file=None, instance_config=instance_config) self.reset_finished = True self.page.wait_for_timeout(int(self.sleep_after_execution * 1000)) diff --git a/browser_env/helper_functions.py b/browser_env/helper_functions.py index 54dce12..0ced7dc 100644 --- a/browser_env/helper_functions.py +++ b/browser_env/helper_functions.py @@ -57,9 +57,11 @@ def get_render_action( case "som": text_meta_data = observation_metadata["text"] if action["element_id"] in text_meta_data["obs_nodes_info"]: - node_content = text_meta_data["obs_nodes_info"][ + node_info = text_meta_data["obs_nodes_info"][ action["element_id"] ] + # Extract the text representation from the node info + node_content = node_info.get("text", str(node_info)) else: node_content = "No match found" diff --git a/browser_env/processors.py b/browser_env/processors.py index f9eb8bc..900fa61 100644 --- a/browser_env/processors.py +++ b/browser_env/processors.py @@ -619,6 +619,10 @@ def clean_accesibility_tree(tree_str: str) -> str: return "\n".join(clean_lines) def fetch_image_related(self, page: Page, browser_info: BrowserInfo) -> str: + # Skip captioning for image_som observation type when no captioning function is available + if self.observation_type == "image_som" and self.captioning_fn is None: + return "" + # Check if the current page is 
an image url if page.url.endswith((".jpg", ".jpeg", ".png")): print("NOTE: We are on an image page!!!") diff --git a/config_vlm.json b/config_vlm.json new file mode 100644 index 0000000..fc2f70e --- /dev/null +++ b/config_vlm.json @@ -0,0 +1,32 @@ +{ + "task": { + "start_url": "https://www.baidu.com", + "intent": "Search for information about Yao Ming and Shaquille O'Neal, then calculate the sum of their ages", + "max_steps": 1 + }, + "model": { + "provider": "openai", + "model": "qwen3-vl-max", + "mode": "chat", + "temperature": 1.0, + "top_p": 0.9, + "max_tokens": 8192 + }, + "browser": { + "headless": true, + "slow_mo": 0, + "viewport_width": 1280, + "viewport_height": 2048 + }, + "observation": { + "observation_type": "image_som", + "action_set_tag": "som", + "current_viewport_only": true + }, + "output": { + "result_dir": "demo_multi_agent_vlm", + "save_trace": true, + "render_screenshot": true, + "verbose": true + } +} \ No newline at end of file diff --git a/llms/tokenizers.py b/llms/tokenizers.py index 53d3858..f6661af 100644 --- a/llms/tokenizers.py +++ b/llms/tokenizers.py @@ -7,7 +7,17 @@ class Tokenizer(object): def __init__(self, provider: str, model_name: str) -> None: if provider == "openai": - self.tokenizer = tiktoken.encoding_for_model(model_name) + # Handle custom models that are not recognized by tiktoken + try: + self.tokenizer = tiktoken.encoding_for_model(model_name) + except KeyError: + # For deepseek-chat and other models not recognized by tiktoken, + # use cl100k_base (GPT-4 tokenizer) as default + if "deepseek" in model_name.lower(): + self.tokenizer = tiktoken.get_encoding("cl100k_base") + else: + # Fallback to cl100k_base for other unknown models + self.tokenizer = tiktoken.get_encoding("cl100k_base") elif provider == "huggingface": self.tokenizer = LlamaTokenizer.from_pretrained(model_name) # turn off adding special tokens automatically diff --git a/llms/utils.py b/llms/utils.py index 0b94c52..2df49d3 100644 --- a/llms/utils.py +++ 
b/llms/utils.py @@ -25,24 +25,37 @@ def call_llm( if lm_config.provider == "openai": if lm_config.mode == "chat": assert isinstance(prompt, list) + # Safely get configuration parameters with defaults + temperature = lm_config.gen_config.get("temperature", 1.0) + top_p = lm_config.gen_config.get("top_p", 0.9) + context_length = lm_config.gen_config.get("context_length", 4096) + max_tokens = lm_config.gen_config.get("max_tokens", 384) + stop_token = lm_config.gen_config.get("stop_token", None) + response = generate_from_openai_chat_completion( messages=prompt, model=lm_config.model, - temperature=lm_config.gen_config["temperature"], - top_p=lm_config.gen_config["top_p"], - context_length=lm_config.gen_config["context_length"], - max_tokens=lm_config.gen_config["max_tokens"], - stop_token=None, + temperature=temperature, + top_p=top_p, + context_length=context_length, + max_tokens=max_tokens, + stop_token=stop_token, ) elif lm_config.mode == "completion": assert isinstance(prompt, str) + # Safely get configuration parameters with defaults + temperature = lm_config.gen_config.get("temperature", 1.0) + max_tokens = lm_config.gen_config.get("max_tokens", 384) + top_p = lm_config.gen_config.get("top_p", 0.9) + stop_token = lm_config.gen_config.get("stop_token", None) + response = generate_from_openai_completion( prompt=prompt, engine=lm_config.model, - temperature=lm_config.gen_config["temperature"], - max_tokens=lm_config.gen_config["max_tokens"], - top_p=lm_config.gen_config["top_p"], - stop_token=lm_config.gen_config["stop_token"], + temperature=temperature, + max_tokens=max_tokens, + top_p=top_p, + stop_token=stop_token, ) else: raise ValueError( @@ -50,25 +63,37 @@ def call_llm( ) elif lm_config.provider == "huggingface": assert isinstance(prompt, str) + # Safely get configuration parameters with defaults + model_endpoint = lm_config.gen_config.get("model_endpoint", "") + temperature = lm_config.gen_config.get("temperature", 1.0) + top_p = 
lm_config.gen_config.get("top_p", 0.9) + stop_sequences = lm_config.gen_config.get("stop_sequences", []) + max_new_tokens = lm_config.gen_config.get("max_new_tokens", 384) + response = generate_from_huggingface_completion( prompt=prompt, - model_endpoint=lm_config.gen_config["model_endpoint"], - temperature=lm_config.gen_config["temperature"], - top_p=lm_config.gen_config["top_p"], - stop_sequences=lm_config.gen_config["stop_sequences"], - max_new_tokens=lm_config.gen_config["max_new_tokens"], + model_endpoint=model_endpoint, + temperature=temperature, + top_p=top_p, + stop_sequences=stop_sequences, + max_new_tokens=max_new_tokens, ) elif lm_config.provider == "google": assert isinstance(prompt, list) assert all( [isinstance(p, str) or isinstance(p, Image) for p in prompt] ) + # Safely get configuration parameters with defaults + temperature = lm_config.gen_config.get("temperature", 1.0) + max_tokens = lm_config.gen_config.get("max_tokens", 384) + top_p = lm_config.gen_config.get("top_p", 0.9) + response = generate_from_gemini_completion( prompt=prompt, engine=lm_config.model, - temperature=lm_config.gen_config["temperature"], - max_tokens=lm_config.gen_config["max_tokens"], - top_p=lm_config.gen_config["top_p"], + temperature=temperature, + max_tokens=max_tokens, + top_p=top_p, ) else: raise NotImplementedError( diff --git a/multi_agent_config_example.json b/multi_agent_config_example.json new file mode 100644 index 0000000..de0cf37 --- /dev/null +++ b/multi_agent_config_example.json @@ -0,0 +1,38 @@ +{ + "task": { + "start_url": "https://www.baidu.com", + "intent": "Search for information about Yao Ming and Shaquille O'Neal, then calculate the sum of their ages" + }, + "model": { + "provider": "openai", + "model": "qwen3-vl-plus", + "mode": "chat", + "temperature": 1.0, + "top_p": 0.9, + "max_tokens": 2048 + }, + "browser": { + "headless": true, + "slow_mo": 0, + "viewport_width": 1280, + "viewport_height": 2048 + }, + "observation": { + "observation_type": 
"image_som", + "action_set_tag": "som", + "current_viewport_only": true + }, + "memory": { + "memory_dir": "agent_memories", + "embedding_model": "sentence-transformers/all-MiniLM-L6-v2", + "top_k": 3, + "window_size": 2, + "enable_memory": false, + "enable_memory_store": false + }, + "output": { + "result_dir": "demo_multi_agent", + "render_screenshot": true, + "verbose": true + } +} \ No newline at end of file diff --git a/online_valid.json b/online_valid.json new file mode 100644 index 0000000..e35defb --- /dev/null +++ b/online_valid.json @@ -0,0 +1,422 @@ +[ + { + "task_id": "b320c68bffc1f3c7f2a8dc9d5478fb27", + "confirmed_task": "Find a walkthrough for the game \"The Legend of Zelda: Breath of the Wild\" on ign.", + "website": "https://www.ign.com/", + "reference_length": 6, + "level": "medium" + }, + { + "task_id": "aa4b5cb7114fcc138ade82b4b9716d24", + "confirmed_task": "Find an editor's choice review with a score of 10 in the boardgame category on ign.", + "website": "https://www.ign.com/", + "reference_length": 8, + "level": "medium" + }, + { + "task_id": "644a856c3897665e475e0dce50bf217d", + "confirmed_task": "Find a pair of wireless headphones on Amazon with active noise canceling for $100 or less and add them to the cart.", + "website": "https://www.amazon.com/", + "reference_length": 8, + "level": "medium" + }, + { + "task_id": "561693d6eec7bbfba3fefe9e4b26decb", + "confirmed_task": "Browse Marriott Bonvoy credit cards on Marriott.", + "website": "https://www.marriott.com/", + "reference_length": 4, + "level": "easy" + }, + { + "task_id": "34ccd15a8ea8fd3895af83f5ccf62369", + "confirmed_task": "Find out what to do when I lose an item on a bus on us.megabus.", + "website": "https://us.megabus.com/", + "reference_length": 3, + "level": "easy" + }, + { + "task_id": "d71be72aa25c3eab8eea47a0e60382e2", + "confirmed_task": "Find technical specs for the latest Macbook Air on Apple.", + "website": "https://www.apple.com/", + "reference_length": 4, + "level": 
"easy" + }, + { + "task_id": "bf3b311cc8dce16d3de844f4b5875dfd", + "confirmed_task": "Compare Apple watches and learn more about the ultra version on apple.", + "website": "https://www.apple.com/", + "reference_length": 4, + "level": "easy" + }, + { + "task_id": "816851ff92ff0219acf4364dcc2c4692_110325", + "confirmed_task": "Show me a list of blue baby boys' pajamas under $40, sorted by rating.", + "website": "https://www.jcpenney.com/", + "reference_length": 9, + "level": "medium" + }, + { + "task_id": "20a460a8fe1971b84411c5b1e6ac4186", + "confirmed_task": "Show theatre events for Las Vegas and select one.", + "website": "https://www.stubhub.com/", + "reference_length": 3, + "level": "easy" + }, + { + "task_id": "75146b7b67388b9244e0f21a1527c022", + "confirmed_task": "Find a male senior boxer near zip code 90028.", + "website": "https://www.adoptapet.com/", + "reference_length": 10, + "level": "medium" + }, + { + "task_id": "eb323dc584156d0eb3a2b90bb8c4b791_110325", + "confirmed_task": "Find the cheapest 2 bed and 3+ bath apartment listing for rent in New York.", + "website": "https://www.compass.com/", + "reference_length": 15, + "level": "hard" + }, + { + "task_id": "2dd41b1d0e8f389d0683f4a4627abfe6", + "confirmed_task": "Show houses for sale in Maryland with a maximum price of $60,000.", + "website": "https://www.landwatch.com/", + "reference_length": 7, + "level": "medium" + }, + { + "task_id": "9c04b71bb8db6cf8e743b2290cbc8797", + "confirmed_task": "Find a UPS drop-off point near Miami Florida.", + "website": "https://www.ups.com/", + "reference_length": 6, + "level": "medium" + }, + { + "task_id": "e7f6cca9a8875f98fee3b711ead3a444_110325", + "confirmed_task": "Find the posts liked by the user Taylor Swift on Tumblr.", + "website": "https://www.tumblr.com/", + "reference_length": 4, + "level": "easy" + }, + { + "task_id": "2207bb4f21786690cfed20b37253fb8b", + "confirmed_task": "Check the current wind speed in Calgary, Alberta.", + "website": 
"https://www.theweathernetwork.com/", + "reference_length": 2, + "level": "easy" + }, + { + "task_id": "7fff82864f21ddeccf4104a220892824", + "confirmed_task": "Find the lowest 27\"-32\" Samsung or LG computer monitors nearby which have 4k, IPS display.", + "website": "https://www.google.com/shopping?udm=28", + "reference_length": 10, + "level": "medium" + }, + { + "task_id": "ce616721ce9aeda69890fbccb29677a6", + "confirmed_task": "Calculate the price to ship a large flat-rate box from 77449 to 77084 at the first available date and time.", + "website": "https://www.usps.com/", + "reference_length": 8, + "level": "medium" + }, + { + "task_id": "c181f903ec1107b850032c17cad88393", + "confirmed_task": "Help me identify a pink round pill with 150 written on it.", + "website": "https://www.webmd.com/", + "reference_length": 7, + "level": "medium" + }, + { + "task_id": "11abb668c751dd56bb41f296a8bb3a13", + "confirmed_task": "Find a store near zip 30010 that provides authorized Apple services for imacs and make this one my store.", + "website": "https://www.bestbuy.com/", + "reference_length": 10, + "level": "medium" + }, + { + "task_id": "8103786e0e5976ebf961bd062d5f39cd", + "confirmed_task": "Find possible causes for the symptoms of chest pain which is sharp which is accompanied by anxiety.", + "website": "https://www.mayoclinic.org/", + "reference_length": 9, + "level": "medium" + }, + { + "task_id": "987bad7c6d4726d64232a8a1c3386888", + "confirmed_task": "Find the seller info and seller's notes about the used car model 2011 BMW 135 with a max price of $30000.", + "website": "https://www.cars.com/", + "reference_length": 11, + "level": "hard" + }, + { + "task_id": "fd787623166785d84093565bf945fd24", + "confirmed_task": "Check the interaction between Novolin N and Novolin R.", + "website": "https://www.drugs.com/", + "reference_length": 6, + "level": "medium" + }, + { + "task_id": "c3307a70bb12ebf56cc9ec926b368f15", + "confirmed_task": "Find the interactions between 
Eulexin and hepatic dysfunction.", + "website": "https://www.drugs.com/", + "reference_length": 5, + "level": "easy" + }, + { + "task_id": "5e4e89c9b6fdaee7a41aca5601b82e04", + "confirmed_task": "Identify a pill with a pink color and oval shape with 894 5 number on it.", + "website": "https://www.drugs.com/", + "reference_length": 8, + "level": "medium" + }, + { + "task_id": "60cbbbd58eb9d28b053aef945f464228", + "confirmed_task": "Look up if the phone number 555555555 is a scam.", + "website": "https://www.bbb.org/", + "reference_length": 6, + "level": "medium" + }, + { + "task_id": "1df24ec81137386d6476bcf343a79012", + "confirmed_task": "Search for NordicTrack with the lowest price.", + "website": "https://www.bestbuy.com/", + "reference_length": 6, + "level": "medium" + }, + { + "task_id": "b1ce968a361e1088ce8d2ade6c2c9af0", + "confirmed_task": "Find young cats in Seattle and show off the newest additions.", + "website": "https://www.petfinder.com/", + "reference_length": 6, + "level": "medium" + }, + { + "task_id": "4e3f6a538cc1f7321cfc50260db9545d", + "confirmed_task": "Look up the current temperature for zip code 10019.", + "website": "https://www.theweathernetwork.com/", + "reference_length": 2, + "level": "easy" + }, + { + "task_id": "3ae28b3c440efe87dc700480b78ac608", + "confirmed_task": "Find the closest 5-star rated dentist to zip code 98011.", + "website": "https://www.healthgrades.com/", + "reference_length": 9, + "level": "medium" + }, + { + "task_id": "690d7b4a285fdb1e9dabf973bf46ae4d", + "confirmed_task": "Browse iPhone X for sale that is in good condition, has a max price of 400, and searches in titles only.", + "website": "https://craigslist.org/", + "reference_length": 8, + "level": "medium" + }, + { + "task_id": "73d08420706ae205a9c5be28b6d4e80f", + "confirmed_task": "Show me the rules and cancellation for Alley Spring.", + "website": "https://www.recreation.gov/", + "reference_length": 3, + "level": "easy" + }, + { + "task_id": 
"27fa3ac20745d3d35e89fae157f63069", + "confirmed_task": "Browse the class schedule of graduate-level chemistry courses on Monday afternoons in the winter of 2023.", + "website": "https://www.stanford.edu/", + "reference_length": 11, + "level": "hard" + }, + { + "task_id": "b4aa7315e31dfcdc52baf7771be260c9", + "confirmed_task": "Find the HGX H100 driver for Ubuntu 22.04 on AMD64 CPU.", + "website": "https://www.nvidia.com/", + "reference_length": 11, + "level": "hard" + }, + { + "task_id": "9ed3827266b3b804f485859c3d00401e", + "confirmed_task": "If I'm 30, plan to retire at 65, and can save $300/month, with a 3% annual return, 13% current tax rate, and 24% retirement tax rate, show the comparison chart between Traditional and Roth IRA.", + "website": "https://www.chase.com/", + "reference_length": 12, + "level": "hard" + }, + { + "task_id": "7c09c2c7c87cf6bb1138701eb54284ea", + "confirmed_task": "Find the comments for the most popular news in the past month under the Quantum Physics topic.", + "website": "https://phys.org/", + "reference_length": 8, + "level": "medium" + }, + { + "task_id": "afcebfed28bea091d58f49ea6cb8194b", + "confirmed_task": "Find the most reviewed gluten-free multivitamins from CVS Health Brand under $15.", + "website": "https://www.cvs.com/", + "reference_length": 12, + "level": "hard" + }, + { + "task_id": "33bd2cdcea4fcc42a09a8a1e4e5841c6_110325", + "confirmed_task": "Add a Box Combo to my bag with Diet Coke as the drink, and a Kids Combo with milk as the drink. 
Select the store closest to ZIP 10001 for pickup tomorrow at 12:00 PM.", + "website": "https://raisingcanes.com/", + "reference_length": 20, + "level": "hard" + }, + { + "task_id": "47186fac8e7c7277af01144644eb4e0b", + "confirmed_task": "What is the ownership cost of the first car in the list \"top buys 2025\"?", + "website": "https://www.parkers.co.uk/", + "reference_length": 3, + "level": "easy" + }, + { + "task_id": "fa9adb815b85d259f943d81874a052e5", + "confirmed_task": "Browse a user homepage that reposted the top song from the Top 50 Rock chart.", + "website": "https://soundcloud.com/", + "reference_length": 6, + "level": "medium" + }, + { + "task_id": "b922508886ded315c9835457a6eb43ea", + "confirmed_task": "Browse tenured/tenure-track faculty positions in Computer Sciences & Technology in California.", + "website": "https://jobs.chronicle.com", + "reference_length": 6, + "level": "medium" + }, + { + "task_id": "01abae9608f2d8752a83e08f136f720c", + "confirmed_task": "Show me the code for the company that is the top mover in the Cboe Europe Technology Sector Index (BEPTEC) as of the latest market close.", + "website": "https://www.cboe.com/", + "reference_length": 6, + "level": "medium" + }, + { + "task_id": "da8f3823a827c7d3a492f383808e7912", + "confirmed_task": "Find and open the earliest press release.", + "website": "https://www.instructure.com/", + "reference_length": 6, + "level": "medium" + }, + { + "task_id": "3dca7cbe7d086619d837ff9f5312cebc_110325", + "confirmed_task": "Can you show me products under the category path 'zara home' -> 'rug', with an additional filter for the color beige?", + "website": "https://zara.com/us", + "reference_length": 5, + "level": "easy" + }, + { + "task_id": "c7c07ec10c668625a21ba64165d719bb", + "confirmed_task": "Find the total monthly price for four prepaid unlimited lines without autopay discounts.", + "website": "https://www.verizon.com/", + "reference_length": 8, + "level": "medium" + }, + { + "task_id": 
"d9d8b7d84a3f8d057e368254fe8d65e2", + "confirmed_task": "Find the first commit submitted by NielsRogge to the official repository of the SAM2 model.", + "website": "https://github.com/", + "reference_length": 8, + "level": "medium" + }, + { + "task_id": "157f4a79d55e8fa3fd55ba772ba40fbc", + "confirmed_task": "Find the most popular blue Lilo & Stitch toys.", + "website": "https://www.disney.com/", + "reference_length": 9, + "level": "medium" + }, + { + "task_id": "62c8d970b3d13891f355911e5a8f4030", + "confirmed_task": "Find the top game listed in the Steam Deck's top-played list over the past year. Then, browse reviews for that game from players who have played over 100 hours and primarily use a Steam Deck.", + "website": "https://store.steampowered.com/", + "reference_length": 9, + "level": "medium" + }, + { + "task_id": "47bfe8a7e0e4e7efc837287b407fbe90", + "confirmed_task": "Compare the first and second most popular smartphones manufactured by Xiaomi and show the comparison chart.", + "website": "https://versus.com/", + "reference_length": 10, + "level": "medium" + }, + { + "task_id": "fe33894188d20d7469f37a9fd855e7ff_110325", + "confirmed_task": "Show me the most popular open-source software on SourceForge that runs on Linux and macOS, belongs to the category of “Agentic AI”, is licensed under MIT, and has a production/stable development status.", + "website": "https://sourceforge.net/", + "reference_length": 11, + "level": "hard" + }, + { + "task_id": "71f8de1834599fba443f40dbbfab8edd", + "confirmed_task": "Search for papers related to reinforcement learning under the topics of computer science and mathematics on arxiv, with recent submission dates between September 2024 and January 2025.", + "website": "https://arxiv.org/", + "reference_length": 11, + "level": "hard" + }, + { + "task_id": "c8c1ff115879b3afd14280beb1559b13", + "confirmed_task": "Find the latest Doraemon video in MP4 format that is over 20 minutes long and has a medium file size.", + "website": 
"https://www.4shared.com/", + "reference_length": 12, + "level": "hard" + }, + { + "task_id": "d4fb78b7e74508cd3b33f01cf9200997", + "confirmed_task": "Show the figure comparing Occupational Fatalities Trends between Ohio and New York.", + "website": "https://www.americashealthrankings.org/", + "reference_length": 12, + "level": "hard" + }, + { + "task_id": "c3a333968fc3c43d7f2688f425a0d633", + "confirmed_task": "Find the cheapest certified pre-owned Porsche 911 with a model year of 2019 or newer, within a 200-mile radius of ZIP code 97007.", + "website": "https://www.porsche.com/", + "reference_length": 15, + "level": "hard" + }, + { + "task_id": "c6c9dc6079677cef594cec2fa6b16602", + "confirmed_task": "Add the cheapest black sofa with at least three seats, a leather finish, and at least four stars to my cart.", + "website": "https://www.ikea.com/", + "reference_length": 16, + "level": "hard" + }, + { + "task_id": "c39d6c245f8243993e707d54d2f4acec", + "confirmed_task": "Browse the final skin in the list for the champion Ahri.", + "website": "https://www.leagueoflegends.com/", + "reference_length": 18, + "level": "hard" + }, + { + "task_id": "ba01ea557b73f864c35ebba0dd6f3cb2", + "confirmed_task": "Find the top-rated hotel in Manhattan, NY, suitable for 4 guests, and identify the fastest public transportation option from the hotel to LGA airport.", + "website": "https://www.google.com/maps/", + "reference_length": 14, + "level": "hard" + }, + { + "task_id": "a96fca87a17d792644e736d1d10d3cbe", + "confirmed_task": "View the pricing plan for 'Business'. Specifically, we have 100 users. We need a 1PB storage quota and a 50 TB transfer quota.", + "website": "https://mega.io/", + "reference_length": 5, + "level": "easy" + }, + { + "task_id": "d1970c16271496cbbe166ecbecc0a1d8", + "confirmed_task": "I'm 25 and located in Texas. 
Shop for 2020 made dry red wine made in United States priced between 15-20 dollars and add 5 bottles to the cart.", + "website": "https://macyswineshop.com/", + "reference_length": 13, + "level": "hard" + }, + { + "task_id": "28e7574e7bd6d14f36d2988a5ef2bd23", + "confirmed_task": "Get a part-time job within 5 miles of Moscow, Idaho in the accommodation and food services industry, as a chef, and show jobs for corporate only.", + "website": "https://ohiomeansjobs.ohio.gov/", + "reference_length": 12, + "level": "hard" + }, + { + "task_id": "1223b07536a87e0170ff87cbbebd1d3c", + "confirmed_task": "Complete a multiplication quiz on https://www.coolmath4kids.com/, covering multiplication facts for 11-12. The quiz should consist of 10 questions, with unlimited time allowed for each. The goal is to achieve a perfect score of 10 out of 10.", + "website": "https://www.coolmath4kids.com/", + "reference_length": 24, + "level": "hard" + } +] \ No newline at end of file diff --git a/run_ma_simple.sh b/run_ma_simple.sh new file mode 100755 index 0000000..89c7541 --- /dev/null +++ b/run_ma_simple.sh @@ -0,0 +1,47 @@ +#!/bin/bash + +# Simple Multi-Agent Runner Script +# Usage: ./run_ma_simple.sh + +export DATASET=visualwebarena +export HF_ENDPOINT=https://hf-mirror.com +export CLASSIFIEDS=":9980" +export CLASSIFIEDS_RESET_TOKEN="4b61655535e7ed388f0d40a93600254c" # Default reset token for classifieds site, change if you edited its docker-compose.yml +export SHOPPING=":7770" +export REDDIT=":9999" +export WIKIPEDIA=":8888" +export HOMEPAGE=":4399" + +export SHOPPING_ADMIN=":7780/admin" +export GITLAB=":8023" +export MAP=":3000" + +# # Environment variables (modify as needed) +# export OPENAI_API_KEY=<your-deepseek-api-key> +# export OPENAI_BASE_URL=https://api.deepseek.com/v1 + +export OPENAI_API_KEY="<your-dashscope-api-key>" +export OPENAI_BASE_URL="https://dashscope.aliyuncs.com/compatible-mode/v1" + +echo "🚀 Starting Multi-Agent Web Arena" +echo "📋 Task: 
Search for Yao Ming's age and Shaquille O'Neal's age, then calculate the sum" +echo "🌐 URL: https://www.baidu.com" +echo "🧠 Model: qwen3-vl-plus" + +# Run the multi-agent script +# python run_multi_agent.py \ +# --start_url "https://www.baidu.com" \ +# --intent "Search yaoming's age and shaquille o'neal's age, tell me the sum of their ages" \ +# --max_steps 4 \ +# --config_file multi_agent_config_example.json +# # --task_list_file Online_Mind2Web.json +# # --start_id 0 +# # --end_id 10 + +python run_multi_agent_multiple_tasks.py \ + --start_url "https://www.baidu.com" \ + --intent "Search yaoming's age and shaquille o'neal's age, tell me the sum of their ages" \ + --config_file multi_agent_config_example.json \ + --task_list_file online_valid.json \ + --start_id 0 \ + --end_id 3 \ No newline at end of file diff --git a/run_multi_agent.py b/run_multi_agent.py new file mode 100644 index 0000000..23188ed --- /dev/null +++ b/run_multi_agent.py @@ -0,0 +1,370 @@ +"""Multi-Agent Web Arena Runner. + +This module implements a multi-agent system for web automation tasks, +using Context, Planner, Actor, and Reflector agents. +""" + +import argparse +import json +import os +from pathlib import Path +from typing import Any, Dict, List, Optional + + + +import torch + +from browser_env import ( + ScriptBrowserEnv, + Action, + create_id_based_action, +) +from browser_env.utils import Observation +from agent import PromptAgent +from llms import lm_config, call_llm +from PIL import Image +from evaluation_harness import image_utils + +# Import the full multi-agent coordinator +from agent.multi_agent_coordinator import MultiAgentCoordinator +from agent.prompts.prompt_constructor import PromptConstructor, DirectPromptConstructor, CoTPromptConstructor, MultimodalCoTPromptConstructor +from llms import lm_config + + +def generate_execution_summary(execution_result: Dict[str, Any]) -> str: + """Generate a human-readable execution summary from execution result. 
+ + Args: + execution_result: The execution result from multi-agent coordinator + + Returns: + Human-readable summary string + """ + lines = [] + lines.append("Multi-Agent Web Automation Execution Summary") + lines.append("=" * 50) + lines.append("") + + # Task information + if "goal" in execution_result: + goal = execution_result["goal"] + lines.append("Task Information:") + lines.append(f" Goal: {goal}") + lines.append(f" Completed: {'Yes' if execution_result.get('success_rate', 0) >= 0.8 else 'No'}") + lines.append(f" Completion: {execution_result.get('success_rate', 0) * 100:.1f}%") + lines.append(f" Steps Executed: {execution_result.get('total_steps', 0)}") + if 'execution_time_formatted' in execution_result: + lines.append(f" Execution Time: {execution_result['execution_time_formatted']}") + lines.append("") + + # Agent performance + if 'reflections' in execution_result and execution_result['reflections']: + lines.append("Agent Performance:") + + # Count successful actions + actions = execution_result.get('actions', []) + successful_actions = sum(1 for action in actions if action.get('action_type') != 'NONE') + total_actions = len(actions) + + lines.append(f" actor_agent:") + lines.append(f" total_intentions: {len(execution_result.get('intentions', []))}") + lines.append(f" successful_actions: {successful_actions}") + lines.append(f" failed_actions: {total_actions - successful_actions}") + if total_actions > 0: + fulfillment_rate = successful_actions / total_actions + lines.append(f" fulfillment_rate: {fulfillment_rate * 100:.1f}%") + + # Count reflection success + reflections = execution_result['reflections'] + successful_reflections = sum(1 for reflection in reflections if reflection.get('success', False)) + helpful_reflections = sum(1 for reflection in reflections if reflection.get('helpful', False)) + stuck_reflections = sum(1 for reflection in reflections if reflection.get('stuck', False)) + + lines.append(f" reflector_agent:") + lines.append(f" 
total_reflections: {len(reflections)}") + lines.append(f" successful_reflections: {successful_reflections}") + lines.append(f" helpful_reflections: {helpful_reflections}") + lines.append(f" stuck_reflections: {stuck_reflections}") + if len(reflections) > 0: + success_rate = successful_reflections / len(reflections) + helpful_rate = helpful_reflections / len(reflections) + lines.append(f" success_rate: {success_rate * 100:.1f}%") + lines.append(f" helpful_rate: {helpful_rate * 100:.1f}%") + lines.append(f" stuck_rate: {stuck_reflections / len(reflections) * 100:.1f}%") + + lines.append("") + + # Overall assessment + lines.append("Overall Assessment:") + success_rate = execution_result.get('success_rate', 0) + lines.append(f" Success: {'Yes' if success_rate >= 0.8 else 'No'}") + lines.append(f" Success Rate: {success_rate * 100:.1f}%") + lines.append(f" Total Actions: {len(execution_result.get('actions', []))}") + lines.append(f" Total Steps: {execution_result.get('total_steps', 0)}") + lines.append("") + + return "\n".join(lines) + + +def config(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser(description="Multi-Agent Web Arena Runner") + + # Required arguments + parser.add_argument("--config_file", type=str, required=True, + help="Path to JSON configuration file") + + # Commonly overridden arguments (for convenience) + parser.add_argument("--start_url", type=str, + help="Override starting URL (overrides config file)") + parser.add_argument("--intent", type=str, + help="Override task intent (overrides config file)") + parser.add_argument("--max_steps", type=int, + help="Override maximum steps (overrides config file)") + parser.add_argument("--result_dir", type=str, + help="Override result directory (overrides config file)") + + # Debugging options + parser.add_argument("--verbose", action="store_true", + help="Enable verbose output (overrides config file)") + parser.add_argument("--dry_run", action="store_true", + help="Show configuration without executing") 
+ + return parser.parse_args() + + +def load_config_file(config_file: str) -> Dict[str, Any]: + """Load configuration from JSON file.""" + try: + with open(config_file, 'r') as f: + return json.load(f) + except FileNotFoundError: + return {} + except json.JSONDecodeError as e: + return {} + + +def merge_config_with_args(config: Dict[str, Any], args) -> Dict[str, Any]: + """Merge loaded config with command line arguments.""" + merged = config.copy() + + # Only process a simplified set of command line arguments + for key, value in vars(args).items(): + if value is not None and key not in ["config_file", "dry_run"]: + # Handle overrides for commonly changed parameters + if key in ["start_url", "intent", "max_steps"]: + if "task" not in merged: + merged["task"] = {} + merged["task"][key] = value + elif key in ["result_dir", "verbose"]: + if "output" not in merged: + merged["output"] = {} + merged["output"][key] = value + else: + # Direct override for any other arguments + merged[key] = value + + return merged + + +def test(args, config_file): + """Run the multi-agent system.""" + + # Load and merge configuration + file_config = load_config_file(config_file) + config = merge_config_with_args(file_config, args) + + # Handle dry run + if args.dry_run: + return + + # Setup result directory + result_dir = config.get('output', {}).get('result_dir', 'results') + if not Path(result_dir).exists(): + Path(result_dir).mkdir(parents=True, exist_ok=True) + print(f"Created result directory: {result_dir}") + + # Add result_dir to config for coordinator + if 'output' not in config: + config['output'] = {} + config['output']['result_dir'] = result_dir + + # Import the full multi-agent coordinator + from agent.multi_agent_coordinator import MultiAgentCoordinator + from agent import PromptAgent + from agent.prompts.prompt_constructor import PromptConstructor + from llms import lm_config + + # Create LM config + try: + # Extract model config from the config dictionary + model_config = 
config.get('model', {}) + + # Create LMConfig directly from the dictionary + lm_cfg = lm_config.LMConfig( + provider=model_config.get('provider', 'openai'), + model=model_config.get('model', 'gpt-4'), + mode=model_config.get('mode', 'chat') + ) + + # Add generation config if available + if model_config: + lm_cfg.gen_config.update({ + 'temperature': model_config.get('temperature', 1.0), + 'top_p': model_config.get('top_p', 0.9), + 'max_tokens': model_config.get('max_tokens', 384), + 'context_length': model_config.get('context_length', 0), + 'stop_token': model_config.get('stop_token', None), + 'max_obs_length': model_config.get('max_obs_length', 0), + 'max_retry': model_config.get('max_retry', 3) + }) + except (KeyError, AttributeError) as e: + # Fallback to minimal config if required fields missing + lm_cfg = lm_config.LMConfig( + provider=config.get('model', {}).get('provider', 'openai'), + model=config.get('model', {}).get('model', 'gpt-4'), + mode=config.get('model', {}).get('mode', 'chat') + ) + + # Get browser environment configuration + browser_config = config.get('browser', {}) + observation_type = config.get('observation', {}).get('observation_type', 'accessibility_tree') + + # Load captioning model if needed (similar to run.py) + caption_image_fn = None + if observation_type in [ + "accessibility_tree_with_captioner", + # "image_som", + ]: + device = torch.device("cuda") if torch.cuda.is_available() else "cpu" + dtype = torch.float16 if torch.cuda.is_available() else torch.float32 + captioning_model = config.get('model', {}).get('captioning_model', 'Salesforce/blip2-flan-t5-xl') + caption_image_fn = image_utils.get_captioning_fn( + device, dtype, captioning_model + ) + + # Build viewport_size from config + viewport_size = { + "width": browser_config.get('viewport_width', 1280), + "height": browser_config.get('viewport_height', 720), + } + + # Create browser environment + env = ScriptBrowserEnv( + headless=browser_config.get('headless', False), # Set to 
False for debugging + slow_mo=browser_config.get('slow_mo', 100), + observation_type=observation_type, + current_viewport_only=browser_config.get('current_viewport_only', True), + viewport_size=viewport_size, + save_trace_enabled=browser_config.get('save_trace_enabled', True), + sleep_after_execution=browser_config.get('sleep_after_execution', 0.5), + captioning_fn=caption_image_fn, + ) + + # Determine if model is multimodal and select appropriate prompt constructor + from llms.tokenizers import Tokenizer + + model_name = lm_cfg.model.lower() + is_multimodal_model = ( + "gemini" in model_name or + ("gpt-4" in model_name and "vision" in model_name) + ) + is_image_observation = observation_type in ["image", "image_som"] + + # Get instruction path from config or use default + instruction_path = config.get('instruction_path') + if not instruction_path: + # Select default instruction path based on observation type and model + if is_multimodal_model and is_image_observation: + instruction_path = 'agent/prompts/jsons/p_multimodal_cot_id_actree_3s.json' + else: + instruction_path = 'agent/prompts/jsons/p_cot_id_actree_3s.json' + + # Load instruction to check prompt_constructor type + with open(instruction_path) as f: + instruction_data = json.load(f) + constructor_type = instruction_data.get("meta_data", {}).get("prompt_constructor", "DirectPromptConstructor") + + # Create appropriate prompt constructor + tokenizer = Tokenizer(lm_cfg.provider, lm_cfg.model) + if constructor_type == "MultimodalCoTPromptConstructor": + prompt_constructor = MultimodalCoTPromptConstructor( + instruction_path=instruction_path, + lm_config=lm_cfg, + tokenizer=tokenizer + ) + elif constructor_type == "CoTPromptConstructor": + prompt_constructor = CoTPromptConstructor( + instruction_path=instruction_path, + lm_config=lm_cfg, + tokenizer=tokenizer + ) + else: + prompt_constructor = DirectPromptConstructor( + instruction_path=instruction_path, + lm_config=lm_cfg, + tokenizer=tokenizer + ) + + # 
Create base prompt agent for multi-agent coordinator + # Use action_set_tag from configuration instead of hardcoding + action_set_tag = config.get('observation', {}).get('action_set_tag', 'id_accessibility_tree') + base_agent = PromptAgent( + action_set_tag=action_set_tag, + lm_config=lm_cfg, + prompt_constructor=prompt_constructor, + captioning_fn=caption_image_fn if observation_type == "accessibility_tree_with_captioner" else None, + ) + + + task_cfg = config.get('task', {}) + output_cfg = config.get('output', {}) + task_metadata_base = config.get('task_metadata') or { + "task_id": task_cfg.get('task_id'), + "task": task_cfg.get('intent'), + "confirmed_task": task_cfg.get('confirmed_task'), + "website": task_cfg.get('website') or task_cfg.get('start_url'), + "reference_length": task_cfg.get('reference_length'), + "level": task_cfg.get('level'), + } + task_metadata_base = {k: v for k, v in task_metadata_base.items() if v is not None} # 只保留有值的元数据 + webjudge_root = output_cfg.get('webjudge_root') # 可自定义 WebJudge 输出根目录 + + # Create multi-agent coordinator with browser environment + coordinator = MultiAgentCoordinator(lm_cfg, + base_agent, + browser_env=env, + result_dir=result_dir, + memory_config=config.get('memory', {}), + webjudge_result_root=webjudge_root) + + # Execute workflow with initial observation from browser + # Use start_url from config if available + start_url = config.get('task', {}).get('start_url') + reset_options = {} + if start_url: + reset_options["start_url"] = start_url + + initial_obs, initial_info = env.reset(options=reset_options if reset_options else None) + initial_observation = {"observation": initial_obs, "info": initial_info} + + + + + + result = coordinator.execute_task( + user_goal=config.get('task', {}).get('intent', 'Not specified'), + start_observation=initial_observation, + max_steps=config.get('task', {}).get('max_steps', 3), + task_metadata=task_metadata_base, + webjudge_root=webjudge_root + + ) + + # Return the execution result 
def generate_execution_summary(execution_result: Dict[str, Any]) -> str:
    """Generate a human-readable execution summary from an execution result.

    Args:
        execution_result: The execution result produced by the multi-agent
            coordinator. Recognized keys (all optional): ``goal``,
            ``success_rate``, ``total_steps``, ``execution_time_formatted``,
            ``actions``, ``intentions``, ``reflections``.

    Returns:
        A multi-line, human-readable summary string.
    """
    lines = []
    lines.append("Multi-Agent Web Automation Execution Summary")
    lines.append("=" * 50)
    lines.append("")

    # Task information (only emitted when a goal was recorded).
    if "goal" in execution_result:
        goal = execution_result["goal"]
        lines.append("Task Information:")
        lines.append(f"  Goal: {goal}")
        # A run counts as "completed" when the success rate reaches 80%.
        lines.append(f"  Completed: {'Yes' if execution_result.get('success_rate', 0) >= 0.8 else 'No'}")
        lines.append(f"  Completion: {execution_result.get('success_rate', 0) * 100:.1f}%")
        lines.append(f"  Steps Executed: {execution_result.get('total_steps', 0)}")
        if 'execution_time_formatted' in execution_result:
            lines.append(f"  Execution Time: {execution_result['execution_time_formatted']}")
        lines.append("")

    # Agent performance (only emitted when reflections were recorded).
    if 'reflections' in execution_result and execution_result['reflections']:
        lines.append("Agent Performance:")

        # Any action other than the NONE placeholder counts as successful.
        actions = execution_result.get('actions', [])
        successful_actions = sum(1 for action in actions if action.get('action_type') != 'NONE')
        total_actions = len(actions)

        lines.append("  actor_agent:")
        lines.append(f"    total_intentions: {len(execution_result.get('intentions', []))}")
        lines.append(f"    successful_actions: {successful_actions}")
        lines.append(f"    failed_actions: {total_actions - successful_actions}")
        if total_actions > 0:
            fulfillment_rate = successful_actions / total_actions
            lines.append(f"    fulfillment_rate: {fulfillment_rate * 100:.1f}%")

        # Tally reflection outcomes by their boolean flags.
        reflections = execution_result['reflections']
        successful_reflections = sum(1 for reflection in reflections if reflection.get('success', False))
        helpful_reflections = sum(1 for reflection in reflections if reflection.get('helpful', False))
        stuck_reflections = sum(1 for reflection in reflections if reflection.get('stuck', False))

        lines.append("  reflector_agent:")
        lines.append(f"    total_reflections: {len(reflections)}")
        lines.append(f"    successful_reflections: {successful_reflections}")
        lines.append(f"    helpful_reflections: {helpful_reflections}")
        lines.append(f"    stuck_reflections: {stuck_reflections}")
        if len(reflections) > 0:
            success_rate = successful_reflections / len(reflections)
            helpful_rate = helpful_reflections / len(reflections)
            # BUG FIX: stuck_rate was previously printed as the raw count
            # times 100 (`stuck_reflections * 100`); it must be a ratio of
            # the reflection count like the other two rates.
            stuck_rate = stuck_reflections / len(reflections)
            lines.append(f"    success_rate: {success_rate * 100:.1f}%")
            lines.append(f"    helpful_rate: {helpful_rate * 100:.1f}%")
            lines.append(f"    stuck_rate: {stuck_rate * 100:.1f}%")

        lines.append("")

    # Overall assessment, always emitted.
    lines.append("Overall Assessment:")
    success_rate = execution_result.get('success_rate', 0)
    lines.append(f"  Success: {'Yes' if success_rate >= 0.8 else 'No'}")
    lines.append(f"  Success Rate: {success_rate * 100:.1f}%")
    lines.append(f"  Total Actions: {len(execution_result.get('actions', []))}")
    lines.append(f"  Total Steps: {execution_result.get('total_steps', 0)}")
    lines.append("")

    return "\n".join(lines)
config file)") + parser.add_argument("--dry_run", action="store_true", + help="Show configuration without executing") + parser.add_argument("--start_id", type=int, + help="the start id of the task list") + parser.add_argument("--end_id", type=int, + help="the end id of the task list") #设置起始di和终止id + + return parser.parse_args() + + +def load_config_file(config_file: str) -> Dict[str, Any]: + """Load configuration from JSON file.""" + try: + with open(config_file, 'r') as f: + return json.load(f) + except FileNotFoundError: + return {} + except json.JSONDecodeError as e: + return {} + + +def merge_config_with_args(config: Dict[str, Any], args) -> Dict[str, Any]: + """Merge loaded config with command line arguments.""" + merged = config.copy() + + # Only process a simplified set of command line arguments + for key, value in vars(args).items(): + if value is not None and key not in ["config_file", "dry_run"]: + # Handle overrides for commonly changed parameters + if key in ["start_url", "intent", "max_steps"]: + if "task" not in merged: + merged["task"] = {} + merged["task"][key] = value + elif key in ["result_dir", "verbose"]: + if "output" not in merged: + merged["output"] = {} + merged["output"][key] = value + elif key == "task_list_file": + merged["task_list_file"] = value + else: + # Direct override for any other arguments + merged[key] = value + + return merged + + +def test(args, config_file): + """Run the multi-agent system.""" + + # Load and merge configuration + file_config = load_config_file(config_file) + config = merge_config_with_args(file_config, args) + + # Handle dry run + if args.dry_run: + return + + # Setup result directory + result_dir = config.get('output', {}).get('result_dir', 'results') + if not Path(result_dir).exists(): + Path(result_dir).mkdir(parents=True, exist_ok=True) + print(f"Created result directory: {result_dir}") + + # Add result_dir to config for coordinator + if 'output' not in config: + config['output'] = {} + 
config['output']['result_dir'] = result_dir #因为result_dir是从命令行中读取进来的 + + # Import the full multi-agent coordinator + from agent.multi_agent_coordinator import MultiAgentCoordinator + from agent import PromptAgent + from agent.prompts.prompt_constructor import PromptConstructor + from llms import lm_config + + # Create LM config + try: + # Extract model config from the config dictionary + model_config = config.get('model', {}) + + # Create LMConfig directly from the dictionary + lm_cfg = lm_config.LMConfig( + provider=model_config.get('provider', 'openai'), + model=model_config.get('model', 'gpt-4'), + mode=model_config.get('mode', 'chat') + ) + + # Add generation config if available + if model_config: + lm_cfg.gen_config.update({ + 'temperature': model_config.get('temperature', 1.0), + 'top_p': model_config.get('top_p', 0.9), + 'max_tokens': model_config.get('max_tokens', 384), + 'context_length': model_config.get('context_length', 0), + 'stop_token': model_config.get('stop_token', None), + 'max_obs_length': model_config.get('max_obs_length', 0), + 'max_retry': model_config.get('max_retry', 3) + }) + except (KeyError, AttributeError) as e: + # Fallback to minimal config if required fields missing + lm_cfg = lm_config.LMConfig( + provider=config.get('model', {}).get('provider', 'openai'), + model=config.get('model', {}).get('model', 'gpt-4'), + mode=config.get('model', {}).get('mode', 'chat') + ) + + # Get browser environment configuration + browser_config = config.get('browser', {}) + observation_type = config.get('observation', {}).get('observation_type', 'accessibility_tree') + + # Load captioning model if needed (similar to run.py) + caption_image_fn = None + if observation_type in [ + "accessibility_tree_with_captioner", + ]: + device = torch.device("cuda") if torch.cuda.is_available() else "cpu" + dtype = torch.float16 if torch.cuda.is_available() else torch.float32 + captioning_model = config.get('model', {}).get('captioning_model', 
'Salesforce/blip2-flan-t5-xl') + caption_image_fn = image_utils.get_captioning_fn( + device, dtype, captioning_model + ) + + # Build viewport_size from config + viewport_size = { + "width": browser_config.get('viewport_width', 1280), + "height": browser_config.get('viewport_height', 720), + } + + # Get observation config for browser parameters + observation_config = config.get('observation', {}) + output_config = config.get('output', {}) + + # Create browser environment + env = ScriptBrowserEnv( + headless=browser_config.get('headless', False), # Set to False for debugging + slow_mo=browser_config.get('slow_mo', 100), + observation_type=observation_type, + current_viewport_only=observation_config.get('current_viewport_only', True), + viewport_size=viewport_size, + save_trace_enabled=output_config.get('save_trace_enabled', True), + sleep_after_execution=browser_config.get('sleep_after_execution', 0.5), + captioning_fn=caption_image_fn, + ) + + # Determine if model is multimodal and select appropriate prompt constructor + from llms.tokenizers import Tokenizer + + model_name = lm_cfg.model.lower() + is_multimodal_model = ( + "gemini" in model_name or + ("gpt-4" in model_name and "vision" in model_name) or + ("gpt-4o" in model_name) or (True) + ) + is_image_observation = observation_type in ["image", "image_som"] + + # Get instruction path from config or use default + instruction_path = config.get('instruction_path') + if not instruction_path: + # Select default instruction path based on observation type and model + if is_multimodal_model and is_image_observation: + instruction_path = 'agent/prompts/jsons/p_multimodal_cot_id_actree_3s.json' + else: + instruction_path = 'agent/prompts/jsons/p_cot_id_actree_3s.json' + + # Load instruction to check prompt_constructor type + with open(instruction_path) as f: + instruction_data = json.load(f) + constructor_type = instruction_data.get("meta_data", {}).get("prompt_constructor", "DirectPromptConstructor") + + # Create 
appropriate prompt constructor + tokenizer = Tokenizer(lm_cfg.provider, lm_cfg.model) + if constructor_type == "MultimodalCoTPromptConstructor": + prompt_constructor = MultimodalCoTPromptConstructor( + instruction_path=instruction_path, + lm_config=lm_cfg, + tokenizer=tokenizer + ) + elif constructor_type == "CoTPromptConstructor": + prompt_constructor = CoTPromptConstructor( + instruction_path=instruction_path, + lm_config=lm_cfg, + tokenizer=tokenizer + ) + else: + prompt_constructor = DirectPromptConstructor( + instruction_path=instruction_path, + lm_config=lm_cfg, + tokenizer=tokenizer + ) + + # Create base prompt agent for multi-agent coordinator + # Use action_set_tag from configuration instead of hardcoding + action_set_tag = config.get('observation', {}).get('action_set_tag', 'id_accessibility_tree') + base_agent = PromptAgent( + action_set_tag=action_set_tag, + lm_config=lm_cfg, + prompt_constructor=prompt_constructor, + captioning_fn=caption_image_fn if observation_type == "accessibility_tree_with_captioner" else None, + ) + + task_cfg = config.get('task', {}) + output_cfg = config.get('output', {}) + task_metadata_base = config.get('task_metadata') or { + "task_id": task_cfg.get('task_id'), + "task": task_cfg.get('intent'), + "confirmed_task": task_cfg.get('confirmed_task'), + "website": task_cfg.get('website') or task_cfg.get('start_url'), + "reference_length": task_cfg.get('reference_length'), + "level": task_cfg.get('level'), + } + task_metadata_base = {k: v for k, v in task_metadata_base.items() if v is not None} # 只保留有值的元数据 + webjudge_root = output_cfg.get('webjudge_root') # 可自定义 WebJudge 输出根目录 + + # 任务列表模式:从 task_list_file 读取并依次执行 + task_list_path = config.get("task_list_file") + task_list = [] + if task_list_path: + with open(task_list_path, "r", encoding="utf-8") as f: + task_list = json.load(f) + print(f"Loaded {len(task_list)} tasks from {task_list_path}") + + def run_single_task(single_task_meta: Dict[str, Any]): + # 将列表里的字段映射到运行所需的 task 
配置与元数据 + per_task_cfg = config.get('task', {}).copy() if isinstance(config.get('task', {}), dict) else {} + per_task_cfg['intent'] = single_task_meta.get('confirmed_task') or single_task_meta.get('task') or per_task_cfg.get('intent', 'Not specified') + per_task_cfg['start_url'] = single_task_meta.get('website') or per_task_cfg.get('start_url') + per_task_cfg['task_id'] = single_task_meta.get('task_id') + per_task_cfg['reference_length'] = single_task_meta.get('reference_length') + per_task_cfg['level'] = single_task_meta.get('level') + # 如未指定 max_steps,则尝试用 reference_length 作为上限 + if per_task_cfg.get('max_steps') is None and single_task_meta.get('reference_length'): + per_task_cfg['max_steps'] = single_task_meta['reference_length'] + + # 合成任务元数据 + tm = task_metadata_base.copy() + tm.update({k: v for k, v in single_task_meta.items() if v is not None}) + tm['task'] = per_task_cfg.get('intent') + + # 创建新的协调器以清空内部轨迹 + coordinator = MultiAgentCoordinator(lm_cfg, + base_agent, + browser_env=env, + result_dir=result_dir, + memory_config=config.get('memory', {}), + webjudge_result_root=webjudge_root) + + # 加载输入图片(若有) + image_paths = per_task_cfg.get('image') + images = [] + if image_paths is not None: + if isinstance(image_paths, str): + image_paths = [image_paths] + for image_path in image_paths: + if image_path.startswith("http"): + input_image = Image.open(requests.get(image_path, stream=True).raw) + else: + input_image = Image.open(image_path) + images.append(input_image) + + # 重置浏览器到指定起始页 + start_url = per_task_cfg.get('start_url') + reset_options = {"start_url": start_url} if start_url else None + initial_obs, initial_info = env.reset(options=reset_options) + initial_observation = {"observation": initial_obs, "info": initial_info} + + # 执行任务 + return coordinator.execute_task( + user_goal=per_task_cfg.get('intent', 'Not specified'), + start_observation=initial_observation, + max_steps=per_task_cfg.get('max_steps', 3), + images=images if images else None, + 
task_metadata=tm, # 传递任务元信息供落盘 + webjudge_root=webjudge_root # 指定 WebJudge 输出根目录 + ) + + # 若提供任务列表则顺序执行,否则执行单任务 + if task_list: + results = [] + for idx, t in enumerate(task_list): + if idx < args.start_id or idx >= args.end_id: + continue + print(f"Running task {idx+1}/{len(task_list)}: {t.get('task_id')}") + try: + results.append(run_single_task(t)) + except Exception as e: + print(f"Error: {e}") + traceback.print_exc() + results.append(None) + return results + else: + return run_single_task(task_metadata_base) + + +if __name__ == "__main__": + args = config() + try: + test(args, args.config_file) + except Exception as e: + print(f"Error: {e}") + traceback.print_exc() \ No newline at end of file