From c21d9ca5b4e05d8f234e5afebac36e9aed91d0db Mon Sep 17 00:00:00 2001 From: SentienceDEV Date: Sat, 14 Mar 2026 18:45:37 -0700 Subject: [PATCH 1/3] upload trace enablement --- predicate/tracer_factory.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/predicate/tracer_factory.py b/predicate/tracer_factory.py index 39c1b2d..45b5ab1 100644 --- a/predicate/tracer_factory.py +++ b/predicate/tracer_factory.py @@ -49,7 +49,7 @@ def create_tracer( run_id: str | None = None, api_url: str | None = None, logger: SentienceLogger | None = None, - upload_trace: bool = False, + upload_trace: bool | None = None, goal: str | None = None, agent_type: str | None = None, llm_model: str | None = None, @@ -71,9 +71,10 @@ def create_tracer( run_id: Unique identifier for this agent run. If not provided, generates UUID. api_url: Sentience API base URL (default: https://api.sentienceapi.com) logger: Optional logger instance for logging file sizes and errors - upload_trace: Enable cloud trace upload (default: False). When True and api_key + upload_trace: Enable cloud trace upload. When None (default), automatically + enables cloud upload if api_key is provided. When True and api_key is provided, traces will be uploaded to cloud. When False, traces - are saved locally only. + are saved locally only regardless of api_key. goal: User's goal/objective for this trace run. This will be displayed as the trace name in the frontend. Should be descriptive and action-oriented. Example: "Add wireless headphones to cart on Amazon" @@ -133,12 +134,16 @@ def create_tracer( if api_url is None: api_url = PREDICATE_API_URL + # Default upload_trace to True when api_key is provided + # This ensures tracing is enabled automatically for Pro/Enterprise tiers + should_upload = upload_trace if upload_trace is not None else (api_key is not None) + # 0. Check for orphaned traces from previous crashes (if api_key provided and upload enabled) - if api_key and upload_trace: + if api_key and should_upload: _recover_orphaned_traces(api_key, api_url) # 1. Try to initialize Cloud Sink (Pro/Enterprise tier) if upload enabled - if api_key and upload_trace: + if api_key and should_upload: try: # Build metadata object for trace initialization # Only include non-empty fields to avoid sending empty strings From 7fadb13ebeee78fd9d743d9b6815a5c73b9c092d Mon Sep 17 00:00:00 2001 From: SentienceDEV Date: Sat, 14 Mar 2026 23:10:35 -0700 Subject: [PATCH 2/3] improve planner_executor agent with lessons learned --- .../planner-executor/custom_config_example.py | 54 +- .../intent_heuristics_example.py | 395 +++++++++++ predicate/agents/__init__.py | 10 + predicate/agents/planner_executor_agent.py | 645 ++++++++++++++++-- tests/unit/test_planner_executor_agent.py | 575 ++++++++++++++++ 5 files changed, 1639 insertions(+), 40 deletions(-) create mode 100644 examples/planner-executor/intent_heuristics_example.py create mode 100644 tests/unit/test_planner_executor_agent.py diff --git a/examples/planner-executor/custom_config_example.py b/examples/planner-executor/custom_config_example.py index af8ad37..af852b3 100644 --- a/examples/planner-executor/custom_config_example.py +++ b/examples/planner-executor/custom_config_example.py @@ -6,6 +6,8 @@ - Snapshot escalation (enable/disable, custom step sizes) - Retry configuration (timeouts, max attempts) - Vision fallback settings +- Pre-step verification (skip if predicates pass) +- Recovery navigation (track last good URL) Usage: export OPENAI_API_KEY="sk-..." @@ -22,6 +24,7 @@ from predicate.agents import ( PlannerExecutorAgent, PlannerExecutorConfig, + RecoveryNavigationConfig, RetryConfig, SnapshotEscalationConfig, ) @@ -124,9 +127,45 @@ async def example_vision_fallback() -> None: print(f" vision.max_vision_calls: {config.vision.max_vision_calls}") +async def example_pre_step_verification() -> None: + """Pre-step verification configuration.""" + print("\n--- Example 7: Pre-step Verification ---") + + # Default: enabled (steps are skipped if predicates already pass) + config_enabled = PlannerExecutorConfig( + pre_step_verification=True, # Default + ) + print(f" pre_step_verification (default): {config_enabled.pre_step_verification}") + + # Disabled: always execute steps even if predicates pass + config_disabled = PlannerExecutorConfig( + pre_step_verification=False, + ) + print(f" pre_step_verification (disabled): {config_disabled.pre_step_verification}") + print(" When disabled, all steps are executed even if already satisfied") + + +async def example_recovery_navigation() -> None: + """Recovery navigation configuration.""" + print("\n--- Example 8: Recovery Navigation ---") + + config = PlannerExecutorConfig( + recovery=RecoveryNavigationConfig( + enabled=True, + max_recovery_attempts=3, + track_successful_urls=True, + ), + ) + + print(f" recovery.enabled: {config.recovery.enabled}") + print(f" recovery.max_recovery_attempts: {config.recovery.max_recovery_attempts}") + print(f" recovery.track_successful_urls: {config.recovery.track_successful_urls}") + print(" Tracks last_known_good_url for recovery when agent gets off-track") + + async def example_full_custom() -> None: """Full custom configuration with all options.""" - print("\n--- Example 7: Full Custom Config ---") + print("\n--- Example 9: Full Custom Config ---") config = PlannerExecutorConfig( # Snapshot escalation @@ -148,6 +187,13 @@ async def example_full_custom() -> None: enabled=True, max_vision_calls=3, ), + # Recovery navigation + recovery=RecoveryNavigationConfig( + enabled=True, + max_recovery_attempts=2, + ), + # Pre-step verification + pre_step_verification=True, # Planner settings planner_max_tokens=3000, planner_temperature=0.0, @@ -164,11 +210,13 @@ async def example_full_custom() -> None: print(f" Escalation: {config.snapshot.limit_base} -> ... -> {config.snapshot.limit_max}") print(f" Max replans: {config.retry.max_replans}") print(f" Vision enabled: {config.vision.enabled}") + print(f" Pre-step verification: {config.pre_step_verification}") + print(f" Recovery enabled: {config.recovery.enabled}") async def example_run_with_config() -> None: """Run agent with custom config.""" - print("\n--- Example 8: Run Agent with Custom Config ---") + print("\n--- Example 10: Run Agent with Custom Config ---") openai_key = os.getenv("OPENAI_API_KEY") if not openai_key: @@ -229,6 +277,8 @@ async def main() -> None: await example_custom_limits() await example_retry_config() await example_vision_fallback() + await example_pre_step_verification() + await example_recovery_navigation() await example_full_custom() await example_run_with_config() diff --git a/examples/planner-executor/intent_heuristics_example.py b/examples/planner-executor/intent_heuristics_example.py new file mode 100644 index 0000000..9f36bda --- /dev/null +++ b/examples/planner-executor/intent_heuristics_example.py @@ -0,0 +1,395 @@ +#!/usr/bin/env python3 +""" +PlannerExecutorAgent example with IntentHeuristics and ExecutorOverride. + +This example demonstrates the pluggable heuristics system: +- IntentHeuristics: Domain-specific element selection without LLM +- ExecutorOverride: Validate or override executor element choices +- Pre-step verification: Skip steps if predicates already pass + +These features allow the SDK to remain generic while supporting specialized +behavior for different sites (e-commerce, forms, etc.). + +Usage: + export OPENAI_API_KEY="sk-..." + python intent_heuristics_example.py +""" + +from __future__ import annotations + +import asyncio +import os +from typing import Any + +from predicate import AsyncPredicateBrowser +from predicate.agent_runtime import AgentRuntime +from predicate.agents import ( + ExecutorOverride, + IntentHeuristics, + PlannerExecutorAgent, + PlannerExecutorConfig, + RecoveryNavigationConfig, +) +from predicate.backends.playwright_backend import PlaywrightBackend +from predicate.llm_provider import OpenAIProvider + + +# --------------------------------------------------------------------------- +# Example IntentHeuristics Implementation +# --------------------------------------------------------------------------- + + +class EcommerceHeuristics: + """ + Example IntentHeuristics for e-commerce sites. + + This heuristics class provides domain-specific element selection for + common e-commerce actions like adding to cart, checkout, etc. + + When the agent receives a step with an intent like "add_to_cart", + this class tries to find the matching element without calling the LLM. + If no match is found, the agent falls back to the LLM executor. + """ + + def find_element_for_intent( + self, + intent: str, + elements: list[Any], + url: str, + goal: str, + ) -> int | None: + """ + Find element ID for a given intent using domain-specific patterns. + + Args: + intent: The intent hint from the plan step + elements: List of snapshot elements + url: Current page URL + goal: Human-readable goal for context + + Returns: + Element ID if match found, None to fall back to LLM + """ + intent_lower = intent.lower() + + # Add to cart patterns + if "add_to_cart" in intent_lower or "add to cart" in intent_lower: + for el in elements: + text = (getattr(el, "text", "") or "").lower() + role = getattr(el, "role", "") + if role == "button" and "add to cart" in text: + return getattr(el, "id", None) + + # Search box patterns + if "search" in intent_lower and "box" in intent_lower: + for el in elements: + role = getattr(el, "role", "") + if role in ("searchbox", "combobox", "textbox"): + text = (getattr(el, "text", "") or "").lower() + if "search" in text: + return getattr(el, "id", None) + + # Checkout patterns + if "checkout" in intent_lower or "proceed" in intent_lower: + for el in elements: + text = (getattr(el, "text", "") or "").lower() + role = getattr(el, "role", "") + if role == "button" and ("checkout" in text or "proceed" in text): + return getattr(el, "id", None) + + # First product link (common in search results) + if "first_product" in intent_lower: + # Find links containing product indicators + for el in elements: + role = getattr(el, "role", "") + if role == "link": + # Check if it looks like a product link (has /dp/ or product in href) + href = getattr(el, "href", "") or "" + if "/dp/" in href or "/product/" in href: + return getattr(el, "id", None) + + return None # Fall back to LLM executor + + def priority_order(self) -> list[str]: + """ + Return intent patterns in priority order. + + This helps the agent prioritize certain actions when multiple + matching elements are found. + """ + return [ + "checkout", + "proceed_to_checkout", + "add_to_cart", + "search_box", + "first_product", + "quantity", + ] + + +# --------------------------------------------------------------------------- +# Example ExecutorOverride Implementation +# --------------------------------------------------------------------------- + + +class SafetyOverride: + """ + Example ExecutorOverride for safety validation. + + This override validates executor element choices before actions are + executed, providing safety checks like: + - Block clicks on delete/remove buttons + - Block form submissions that might cause data loss + - Block navigation to external sites + """ + + def __init__(self, blocked_patterns: list[str] | None = None): + """ + Initialize SafetyOverride. + + Args: + blocked_patterns: Text patterns to block (case-insensitive) + """ + self.blocked_patterns = blocked_patterns or [ + "delete", + "remove account", + "cancel order", + "unsubscribe", + ] + + def validate_choice( + self, + element_id: int, + action: str, + elements: list[Any], + goal: str, + ) -> tuple[bool, int | None, str | None]: + """ + Validate or override the executor's element choice. + + Args: + element_id: The element ID chosen by the executor + action: The action type (CLICK, TYPE, etc.) + elements: List of snapshot elements + goal: Human-readable goal + + Returns: + Tuple of (is_valid, override_element_id, rejection_reason) + """ + # Find the selected element + selected_element = None + for el in elements: + if getattr(el, "id", None) == element_id: + selected_element = el + break + + if selected_element is None: + # Element not found, allow (might be a valid ID) + return True, None, None + + text = (getattr(selected_element, "text", "") or "").lower() + + # Check against blocked patterns + for pattern in self.blocked_patterns: + if pattern.lower() in text: + return False, None, f"blocked_pattern:{pattern}" + + # All checks passed + return True, None, None + + +# --------------------------------------------------------------------------- +# Main Example +# --------------------------------------------------------------------------- + + +async def example_with_heuristics() -> None: + """Run agent with IntentHeuristics.""" + print("\n--- Example 1: IntentHeuristics ---") + + openai_key = os.getenv("OPENAI_API_KEY") + if not openai_key: + print(" Skipping (no OPENAI_API_KEY)") + return + + # Create heuristics instance + heuristics = EcommerceHeuristics() + + # Create agent with heuristics + agent = PlannerExecutorAgent( + planner=OpenAIProvider(model="gpt-4o"), + executor=OpenAIProvider(model="gpt-4o-mini"), + config=PlannerExecutorConfig(), + intent_heuristics=heuristics, # Plug in domain-specific heuristics + ) + + print(" Agent created with EcommerceHeuristics") + print(" Heuristics priority order:", heuristics.priority_order()) + + +async def example_with_safety_override() -> None: + """Run agent with ExecutorOverride for safety.""" + print("\n--- Example 2: ExecutorOverride (Safety) ---") + + openai_key = os.getenv("OPENAI_API_KEY") + if not openai_key: + print(" Skipping (no OPENAI_API_KEY)") + return + + # Create safety override + safety = SafetyOverride( + blocked_patterns=[ + "delete", + "remove", + "cancel subscription", + ], + ) + + # Create agent with safety override + agent = PlannerExecutorAgent( + planner=OpenAIProvider(model="gpt-4o"), + executor=OpenAIProvider(model="gpt-4o-mini"), + config=PlannerExecutorConfig(), + executor_override=safety, # Add safety validation + ) + + print(" Agent created with SafetyOverride") + print(" Blocked patterns:", safety.blocked_patterns) + + +async def example_pre_step_verification() -> None: + """Demonstrate pre-step verification skipping.""" + print("\n--- Example 3: Pre-step Verification ---") + + # Pre-step verification is enabled by default + config = PlannerExecutorConfig( + pre_step_verification=True, # Default + ) + print(f" pre_step_verification: {config.pre_step_verification}") + print(" When enabled, steps are skipped if their verification predicates already pass") + print(" This saves time when the desired state is already achieved") + + # Example: If a step's goal is 'go to checkout' with verify=[url_contains('checkout')] + # and the browser is already on a checkout page, the step will be skipped + + +async def example_recovery_navigation() -> None: + """Demonstrate recovery navigation config.""" + print("\n--- Example 4: Recovery Navigation ---") + + config = PlannerExecutorConfig( + recovery=RecoveryNavigationConfig( + enabled=True, + max_recovery_attempts=3, + track_successful_urls=True, + ), + ) + + print(f" recovery.enabled: {config.recovery.enabled}") + print(f" recovery.max_recovery_attempts: {config.recovery.max_recovery_attempts}") + print(f" recovery.track_successful_urls: {config.recovery.track_successful_urls}") + print(" When enabled, the agent tracks last_known_good_url for recovery") + + +async def example_combined() -> None: + """Combined example with all new features.""" + print("\n--- Example 5: Combined Features ---") + + openai_key = os.getenv("OPENAI_API_KEY") + if not openai_key: + print(" Skipping (no OPENAI_API_KEY)") + return + + # Create instances + heuristics = EcommerceHeuristics() + safety = SafetyOverride() + + # Create agent with all features + agent = PlannerExecutorAgent( + planner=OpenAIProvider(model="gpt-4o"), + executor=OpenAIProvider(model="gpt-4o-mini"), + config=PlannerExecutorConfig( + pre_step_verification=True, + recovery=RecoveryNavigationConfig( + enabled=True, + max_recovery_attempts=2, + ), + ), + intent_heuristics=heuristics, + executor_override=safety, + ) + + print(" Agent created with:") + print(" - EcommerceHeuristics (domain-specific element selection)") + print(" - SafetyOverride (action validation)") + print(" - Pre-step verification (skip if already satisfied)") + print(" - Recovery navigation (track good URLs)") + + +async def example_run_with_features() -> None: + """Actually run the agent with new features.""" + print("\n--- Example 6: Run with Features ---") + + openai_key = os.getenv("OPENAI_API_KEY") + if not openai_key: + print(" Skipping (no OPENAI_API_KEY)") + return + + predicate_api_key = os.getenv("PREDICATE_API_KEY") + + # Create agent with heuristics + heuristics = EcommerceHeuristics() + + agent = PlannerExecutorAgent( + planner=OpenAIProvider(model="gpt-4o"), + executor=OpenAIProvider(model="gpt-4o-mini"), + config=PlannerExecutorConfig( + pre_step_verification=True, + ), + intent_heuristics=heuristics, + ) + + # Simple task on example.com + task = "Navigate to example.com and verify the page loaded" + + async with AsyncPredicateBrowser( + api_key=predicate_api_key, + headless=True, + ) as browser: + page = await browser.new_page() + await page.goto("https://example.com") + + backend = PlaywrightBackend(page) + runtime = AgentRuntime(backend=backend) + + result = await agent.run( + runtime=runtime, + task=task, + ) + + print(f" Success: {result.success}") + print(f" Steps: {result.steps_completed}/{result.steps_total}") + + # Check if any steps were skipped due to pre-step verification + for outcome in result.step_outcomes: + if outcome.action_taken and "SKIPPED" in outcome.action_taken: + print(f" Skipped step {outcome.step_id}: {outcome.goal}") + + +async def main() -> None: + print("PlannerExecutorAgent - New Features Examples") + print("=" * 50) + + await example_with_heuristics() + await example_with_safety_override() + await example_pre_step_verification() + await example_recovery_navigation() + await example_combined() + await example_run_with_features() + + print("\n" + "=" * 50) + print("Done!") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/predicate/agents/__init__.py b/predicate/agents/__init__.py index b6f5824..f3d8e5e 100644 --- a/predicate/agents/__init__.py +++ b/predicate/agents/__init__.py @@ -18,17 +18,22 @@ VisionFallbackConfig, ) from .planner_executor_agent import ( + ExecutorOverride, + IntentHeuristics, Plan, PlanStep, PlannerExecutorAgent, PlannerExecutorConfig, PredicateSpec, + RecoveryNavigationConfig, RetryConfig, RunOutcome, SnapshotContext, SnapshotEscalationConfig, StepOutcome, StepStatus, + normalize_plan, + validate_plan_smoothness, ) __all__ = [ @@ -39,16 +44,21 @@ "PredicateBrowserAgentConfig", "VisionFallbackConfig", # Planner + Executor Agent + "ExecutorOverride", + "IntentHeuristics", "Plan", "PlanStep", "PlannerExecutorAgent", "PlannerExecutorConfig", "PredicateSpec", + "RecoveryNavigationConfig", "RetryConfig", "RunOutcome", "SnapshotContext", "SnapshotEscalationConfig", "StepOutcome", "StepStatus", + "normalize_plan", + "validate_plan_smoothness", ] diff --git a/predicate/agents/planner_executor_agent.py b/predicate/agents/planner_executor_agent.py index dcde12f..6832ea9 100644 --- a/predicate/agents/planner_executor_agent.py +++ b/predicate/agents/planner_executor_agent.py @@ -20,11 +20,12 @@ import re import time import uuid +from abc import ABC, abstractmethod from collections.abc import Callable from dataclasses import dataclass, field from datetime import datetime from enum import Enum -from typing import Any, Literal +from typing import Any, Literal, Protocol, runtime_checkable from pydantic import BaseModel, Field @@ -47,6 +48,133 @@ from .browser_agent import CaptchaConfig, VisionFallbackConfig +# --------------------------------------------------------------------------- +# IntentHeuristics Protocol +# --------------------------------------------------------------------------- + + +@runtime_checkable +class IntentHeuristics(Protocol): + """ + Protocol for pluggable domain-specific element selection heuristics. + + Developers can implement this protocol to provide domain-specific logic + for selecting elements based on the step intent. This allows the SDK to + remain generic while supporting specialized behavior for different sites. + + Example implementation for an e-commerce site: + + class EcommerceHeuristics: + def find_element_for_intent( + self, + intent: str, + elements: list[Any], + url: str, + goal: str, + ) -> int | None: + if "add to cart" in intent.lower(): + for el in elements: + text = getattr(el, "text", "") or "" + if "add to cart" in text.lower(): + return getattr(el, "id", None) + return None # Fall back to LLM + + def priority_order(self) -> list[str]: + return ["add_to_cart", "checkout", "search"] + + # Usage: + agent = PlannerExecutorAgent( + planner=planner, + executor=executor, + intent_heuristics=EcommerceHeuristics(), + ) + """ + + def find_element_for_intent( + self, + intent: str, + elements: list[Any], + url: str, + goal: str, + ) -> int | None: + """ + Find element ID for a given intent using domain-specific heuristics. + + Args: + intent: The intent hint from the plan step (e.g., "add_to_cart", "checkout") + elements: List of snapshot elements with id, role, text, etc. + url: Current page URL + goal: Human-readable goal for context + + Returns: + Element ID if a match is found, None to fall back to LLM executor + """ + ... + + def priority_order(self) -> list[str]: + """ + Return list of intent patterns in priority order. + + The agent will try heuristics for each intent pattern in order. + This helps prioritize certain actions (e.g., checkout over add-to-cart). + + Returns: + List of intent pattern strings + """ + ... + + +class ExecutorOverride(Protocol): + """ + Protocol for validating or overriding executor element choices. + + This allows developers to add validation logic or override the executor's + choice before an action is executed. Useful for safety checks or + domain-specific corrections. + + Example: + class SafetyOverride: + def validate_choice( + self, + element_id: int, + action: str, + elements: list[Any], + goal: str, + ) -> tuple[bool, int | None, str | None]: + # Block clicks on delete buttons + for el in elements: + if getattr(el, "id", None) == element_id: + text = getattr(el, "text", "") or "" + if "delete" in text.lower(): + return False, None, "blocked_delete_button" + return True, None, None + """ + + def validate_choice( + self, + element_id: int, + action: str, + elements: list[Any], + goal: str, + ) -> tuple[bool, int | None, str | None]: + """ + Validate or override the executor's element choice. + + Args: + element_id: The element ID chosen by the executor + action: The action type (CLICK, TYPE, etc.) + elements: List of snapshot elements + goal: Human-readable goal + + Returns: + Tuple of (is_valid, override_element_id, rejection_reason) + - is_valid: True if choice is acceptable + - override_element_id: Alternative element ID, or None + - rejection_reason: Reason for rejection, or None + """ + ... + + # --------------------------------------------------------------------------- # Configuration # --------------------------------------------------------------------------- @@ -102,6 +230,25 @@ class RetryConfig: max_replans: int = 1 +@dataclass(frozen=True) +class RecoveryNavigationConfig: + """ + Configuration for recovery navigation when agent gets off-track. + + The agent tracks the last known good URL (where verification passed) + and can navigate back if subsequent steps fail. + + Attributes: + enabled: If True, track last_known_good_url and attempt recovery navigation. + max_recovery_attempts: Maximum navigation recovery attempts per step. + recovery_predicates: Optional predicates to verify recovery succeeded. + """ + + enabled: bool = True + max_recovery_attempts: int = 2 + track_successful_urls: bool = True + + @dataclass(frozen=True) class PlannerExecutorConfig: """ @@ -111,6 +258,7 @@ class PlannerExecutorConfig: - Snapshot escalation settings - Retry/verification settings - Vision fallback settings + - Recovery navigation settings - Planner/Executor LLM settings - Tracing settings """ @@ -132,6 +280,9 @@ class PlannerExecutorConfig: # CAPTCHA handling captcha: CaptchaConfig = CaptchaConfig() + # Recovery navigation + recovery: RecoveryNavigationConfig = RecoveryNavigationConfig() + # Planner LLM settings planner_max_tokens: int = 2048 planner_temperature: float = 0.0 @@ -145,6 +296,9 @@ class PlannerExecutorConfig: stabilize_poll_s: float = 0.35 stabilize_max_attempts: int = 6 + # Pre-step verification (skip step if predicates already pass) + pre_step_verification: bool = True + # Tracing trace_screenshots: bool = True trace_screenshot_format: str = "jpeg" @@ -372,6 +526,125 @@ def build_predicate(spec: PredicateSpec | dict[str, Any]) -> Predicate: raise ValueError(f"Unsupported predicate: {name}") +# --------------------------------------------------------------------------- +# Plan Normalization and Validation +# --------------------------------------------------------------------------- + + +def normalize_plan(plan_dict: dict[str, Any]) -> dict[str, Any]: + """ + Normalize plan dictionary to handle LLM output variations. + + This function handles common variations in LLM output: + - url vs target field names + - action aliases (click vs CLICK) + - step id variations (string vs int) + + Args: + plan_dict: Raw plan dictionary from LLM + + Returns: + Normalized plan dictionary + """ + # Normalize steps + if "steps" in plan_dict: + for step in plan_dict["steps"]: + # Normalize action names to uppercase + if "action" in step: + action = step["action"].upper() + # Handle common aliases + action_aliases = { + "CLICK_ELEMENT": "CLICK", + "CLICK_BUTTON": "CLICK", + "CLICK_LINK": "CLICK", + "INPUT": "TYPE_AND_SUBMIT", + "TYPE_TEXT": "TYPE_AND_SUBMIT", + "ENTER_TEXT": "TYPE_AND_SUBMIT", + "GOTO": "NAVIGATE", + "GO_TO": "NAVIGATE", + "OPEN": "NAVIGATE", + "SCROLL_DOWN": "SCROLL", + "SCROLL_UP": "SCROLL", + } + step["action"] = action_aliases.get(action, action) + + # Normalize url -> target for NAVIGATE actions + if "url" in step and "target" not in step: + step["target"] = step.pop("url") + + # Ensure step id is int + if "id" in step and isinstance(step["id"], str): + try: + step["id"] = int(step["id"]) + except ValueError: + pass + + # Normalize optional_substeps recursively + if "optional_substeps" in step: + for substep in step["optional_substeps"]: + if "action" in substep: + substep["action"] = substep["action"].upper() + if "url" in substep and "target" not in substep: + substep["target"] = substep.pop("url") + + return plan_dict + + +def validate_plan_smoothness(plan: "Plan") -> list[str]: + """ + Validate plan quality and smoothness. + + Checks for common issues that indicate a low-quality plan: + - Missing verification predicates + - Consecutive same actions + - Empty or too short plans + - Missing required fields + + Args: + plan: Parsed Plan object + + Returns: + List of warning strings (empty if plan is smooth) + """ + warnings: list[str] = [] + + # Check for empty plan + if not plan.steps: + warnings.append("Plan has no steps") + return warnings + + # Check for very short plans (might be incomplete) + if len(plan.steps) < 2: + warnings.append("Plan has only one step - might be incomplete") + + # Check each step + prev_action = None + for i, step in enumerate(plan.steps): + # Check for missing verification + if not step.verify and step.required: + warnings.append(f"Step {step.id} has no verification predicates") + + # Check for consecutive same actions (might indicate loop) + if step.action == prev_action and step.action == "CLICK": + warnings.append(f"Steps {step.id - 1} and {step.id} both use {step.action}") + + # Check for NAVIGATE without target + if step.action == "NAVIGATE" and not step.target: + warnings.append(f"Step {step.id} is NAVIGATE but has no target URL") + + # Check for CLICK without intent + if step.action == "CLICK" and not step.intent: + warnings.append(f"Step {step.id} is CLICK but has no intent hint") + + # Check for TYPE_AND_SUBMIT without input + if step.action == "TYPE_AND_SUBMIT" and not step.input: + warnings.append(f"Step {step.id} is TYPE_AND_SUBMIT but has no input") + + prev_action = step.action + + return warnings + + # --------------------------------------------------------------------------- # Prompt Builders # --------------------------------------------------------------------------- @@ -486,6 +759,13 @@ class PlannerExecutorAgent: - SnapshotContext sharing to avoid redundant captures - Full tracing integration for Predicate Studio visualization - Replanning on step failure + - Pre-step verification: skip execution if predicates already pass + - Optional substeps: fallback steps for edge cases (scroll, close drawer) + - Plan normalization: handles LLM output variations (url vs target, etc.) + - Plan smoothness validation: quality checks on generated plans + - Pluggable IntentHeuristics: domain-specific element selection without LLM + - ExecutorOverride: validate/override executor element choices + - Recovery navigation: track last known good URL for off-track recovery Example: >>> from predicate.agents import PlannerExecutorAgent, PlannerExecutorConfig @@ -508,6 +788,24 @@ class PlannerExecutorAgent: ... start_url="https://amazon.com", ... ) ... print(f"Success: {result.success}") + + Example with IntentHeuristics: + >>> class EcommerceHeuristics: + ... def find_element_for_intent(self, intent, elements, url, goal): + ... if "add to cart" in intent.lower(): + ... for el in elements: + ... if "add to cart" in (getattr(el, "text", "") or "").lower(): + ... return getattr(el, "id", None) + ... return None # Fall back to LLM + ... + ... def priority_order(self): + ... return ["add_to_cart", "checkout"] + >>> + >>> agent = PlannerExecutorAgent( + ... planner=planner, + ... executor=executor, + ... intent_heuristics=EcommerceHeuristics(), + ... ) """ def __init__( @@ -520,6 +818,8 @@ def __init__( config: PlannerExecutorConfig | None = None, tracer: Tracer | None = None, context_formatter: Callable[[Snapshot, str], str] | None = None, + intent_heuristics: IntentHeuristics | None = None, + executor_override: ExecutorOverride | None = None, ) -> None: """ Initialize PlannerExecutorAgent. @@ -532,6 +832,11 @@ def __init__( config: Agent configuration tracer: Tracer for Predicate Studio visualization context_formatter: Custom function to format snapshot for LLM + intent_heuristics: Optional pluggable heuristics for domain-specific + element selection. When provided, the agent tries heuristics + before falling back to the LLM executor. + executor_override: Optional hook to validate or override executor + element choices before action execution. """ self.planner = planner self.executor = executor @@ -540,6 +845,8 @@ def __init__( self.config = config or PlannerExecutorConfig() self.tracer = tracer self._context_formatter = context_formatter + self._intent_heuristics = intent_heuristics + self._executor_override = executor_override # State tracking self._current_plan: Plan | None = None @@ -548,6 +855,7 @@ def __init__( self._vision_calls: int = 0 self._snapshot_context: SnapshotContext | None = None self._run_id: str | None = None + self._last_known_good_url: str | None = None def _format_context(self, snap: Snapshot, goal: str) -> str: """Format snapshot for LLM context.""" @@ -937,7 +1245,20 @@ async def plan( try: plan_dict = self._extract_json(resp.content) + + # Normalize plan to handle LLM output variations + plan_dict = normalize_plan(plan_dict) + plan = Plan.model_validate(plan_dict) + + # Validate plan smoothness (warnings only, don't fail) + warnings = validate_plan_smoothness(plan) + if warnings and self.tracer: + try: + self.tracer.emit("plan_warnings", {"warnings": warnings}) + except Exception: + pass + self._current_plan = plan self._step_index = 0 @@ -1044,13 +1365,155 @@ async def replan( # Step Execution # ------------------------------------------------------------------------- + async def _check_pre_step_verification( + self, + runtime: AgentRuntime, + step: PlanStep, + ) -> bool: + """ + Check if step verification predicates already pass before executing. + + This optimization skips step execution if the desired state is already + achieved (e.g., already on checkout page when step goal is "go to checkout"). + + Returns: + True if all predicates pass (step can be skipped), False otherwise + """ + if not step.verify: + return False + + for verify_spec in step.verify: + try: + pred = build_predicate(verify_spec) + # Quick check without retries + snap = await runtime.snapshot(limit=30, screenshot=False, goal=step.goal) + if snap is None: + return False + if not pred.evaluate(snap): + return False + except Exception: + return False + + return True + + async def _try_intent_heuristics( + self, + step: PlanStep, + elements: list[Any], + url: str, + ) -> int | None: + """ + Try pluggable intent heuristics to find element without LLM. + + Returns: + Element ID if heuristics found a match, None otherwise + """ + if self._intent_heuristics is None: + return None + + if not step.intent: + return None + + try: + element_id = self._intent_heuristics.find_element_for_intent( + intent=step.intent, + elements=elements, + url=url, + goal=step.goal, + ) + return element_id + except Exception: + return None + + async def _execute_optional_substeps( + self, + substeps: list[PlanStep], + runtime: AgentRuntime, + parent_step_index: int, + ) -> list[StepOutcome]: + """ + Execute optional substeps (fallback steps for edge cases). + + Optional substeps are executed when the main step's verification fails. + They handle scenarios like scroll-to-reveal, closing drawers, etc. + + Returns: + List of substep outcomes + """ + outcomes: list[StepOutcome] = [] + + for i, substep in enumerate(substeps): + substep_index = parent_step_index * 100 + i + 1 # e.g., 101, 102 for step 1's substeps + + # Execute substep with simplified logic + try: + ctx = await self._snapshot_with_escalation( + runtime, + goal=substep.goal, + capture_screenshot=False, + ) + + # Determine element and action + action_type = substep.action + element_id: int | None = None + + if action_type in ("CLICK", "TYPE_AND_SUBMIT"): + # Try heuristics first + elements = getattr(ctx.snapshot, "elements", []) or [] + url = getattr(ctx.snapshot, "url", "") or "" + element_id = await self._try_intent_heuristics(substep, elements, url) + + if element_id is None: + # Fall back to executor + sys_prompt, user_prompt = build_executor_prompt( + substep.goal, + substep.intent, + ctx.compact_representation, + ) + resp = self.executor.generate( + sys_prompt, + user_prompt, + temperature=self.config.executor_temperature, + max_new_tokens=self.config.executor_max_tokens, + ) + parsed_action, parsed_args = self._parse_action(resp.content) + if parsed_action == "CLICK" and parsed_args: + element_id = parsed_args[0] + + # Execute the action + if action_type == "CLICK" and element_id is not None: + await runtime.click(element_id) + elif action_type == "SCROLL": + direction = "down" # Default + await runtime.scroll(direction) + elif action_type == "NAVIGATE" and substep.target: + await runtime.goto(substep.target) + + outcomes.append(StepOutcome( + step_id=substep.id, + goal=substep.goal, + status=StepStatus.SUCCESS, + action_taken=f"{action_type}({element_id})" if element_id else action_type, + verification_passed=True, + )) + + except Exception as e: + outcomes.append(StepOutcome( + step_id=substep.id, + goal=substep.goal, + status=StepStatus.FAILED, + error=str(e), + )) + + return outcomes + async def _execute_step( self, step: PlanStep, runtime: AgentRuntime, step_index: int, ) -> StepOutcome: - """Execute a single plan step.""" + """Execute a single plan step with pre-verification, heuristics, and optional substeps.""" start_time = time.time() pre_url = await runtime.get_url() if hasattr(runtime, "get_url") else None step_id = self._emit_step_start(step, step_index, pre_url) @@ -1058,10 +1521,46 @@ async def _execute_step( llm_response: str | None = None action_taken: str | None = None used_vision = False + used_heuristics = False error: str | None = None verification_passed = False try: + # Pre-step verification check: skip if predicates already pass + if self.config.pre_step_verification and step.verify: + if await self._check_pre_step_verification(runtime, step): + # Step already satisfied, skip execution + verification_passed = True + action_taken = "SKIPPED(pre_verification_passed)" + + # Track successful URL for recovery + if self.config.recovery.track_successful_urls: + self._last_known_good_url = pre_url + + outcome = StepOutcome( + step_id=step.id, + goal=step.goal, + status=StepStatus.SKIPPED, + action_taken=action_taken, + verification_passed=True, + duration_ms=int((time.time() - start_time) * 1000), + url_before=pre_url, + url_after=pre_url, + ) + + self._emit_step_end( + step_id=step_id, + step_index=step_index, + step=step, + outcome=outcome, + pre_url=pre_url, + post_url=pre_url, + llm_response=None, + snapshot_digest=None, + ) + + return outcome + # Capture snapshot with escalation ctx = await self._snapshot_with_escalation( runtime, @@ -1081,52 +1580,122 @@ async def _execute_step( # Vision execution would go here # For now, fall through to standard executor - # Build executor prompt - sys_prompt, user_prompt = build_executor_prompt( - step.goal, - step.intent, - ctx.compact_representation, - ) - - # Call executor - resp = self.executor.generate( - sys_prompt, - user_prompt, - temperature=self.config.executor_temperature, - max_new_tokens=self.config.executor_max_tokens, - ) - llm_response = resp.content + # Determine element and action + action_type = step.action + element_id: int | None = None - # Parse and execute action - action_type, action_args = self._parse_action(resp.content) - action_taken = f"{action_type}({', '.join(str(a) for a in action_args)})" + if action_type in ("CLICK", "TYPE_AND_SUBMIT"): + # Try intent heuristics first (if available) + elements = getattr(ctx.snapshot, "elements", []) or [] + url = getattr(ctx.snapshot, "url", "") or "" + element_id = await self._try_intent_heuristics(step, elements, url) - # Execute action via runtime - if action_type == "CLICK" and action_args: - element_id = action_args[0] - await runtime.click(element_id) - elif action_type == "TYPE" and len(action_args) >= 2: - element_id, text = action_args[0], action_args[1] - await runtime.type(element_id, text) - elif action_type == "PRESS" and action_args: - key = action_args[0] - await runtime.press(key) + if element_id is not None: + used_heuristics = True + action_taken = f"{action_type}({element_id}) [heuristic]" + else: + # Fall back to LLM executor + sys_prompt, user_prompt = build_executor_prompt( + step.goal, + step.intent, + ctx.compact_representation, + ) + + resp = self.executor.generate( + sys_prompt, + user_prompt, + temperature=self.config.executor_temperature, + max_new_tokens=self.config.executor_max_tokens, + ) + llm_response = resp.content + + # Parse action + parsed_action, parsed_args = self._parse_action(resp.content) + action_type = parsed_action + if parsed_action == "CLICK" and parsed_args: + element_id = parsed_args[0] + elif parsed_action == "TYPE" and len(parsed_args) >= 2: + element_id = parsed_args[0] + + action_taken = f"{action_type}({', '.join(str(a) for a in parsed_args)})" + + # Apply executor override if configured + if self._executor_override and element_id is not None: + try: + is_valid, override_id, reason = self._executor_override.validate_choice( + element_id=element_id, + action=action_type, + elements=elements, + goal=step.goal, + ) + if not is_valid: + if override_id is not None: + element_id = override_id + action_taken = f"{action_type}({element_id}) [override]" + else: + error = f"Executor override rejected: {reason}" + except Exception: + pass # Ignore override errors + + elif action_type == "NAVIGATE": + action_taken = f"NAVIGATE({step.target})" elif action_type == "SCROLL": - direction = action_args[0] if action_args else "down" - await runtime.scroll(direction) - elif action_type == "FINISH": - pass # No action needed + action_taken = "SCROLL(down)" else: - error = f"Unknown action: {action_type}" + action_taken = action_type + + # Execute action via runtime + if error is None: + if action_type == "CLICK" and element_id is not None: + await runtime.click(element_id) + elif action_type == "TYPE" and element_id is not None: + text = step.input or "" + await runtime.type(element_id, text) + elif action_type == "TYPE_AND_SUBMIT" and element_id is not None: + text = step.input or "" + await runtime.type(element_id, text) + await runtime.press("Enter") + elif action_type == "PRESS": + key = "Enter" # Default + await runtime.press(key) + elif action_type == "SCROLL": + await runtime.scroll("down") + elif action_type == "NAVIGATE" and step.target: + await runtime.goto(step.target) + elif action_type == "FINISH": + pass # No action needed + elif action_type not in ("CLICK", "TYPE", "TYPE_AND_SUBMIT") or element_id is None: + if action_type in ("CLICK", "TYPE", "TYPE_AND_SUBMIT"): + error = f"No element ID for {action_type}" + else: + error = f"Unknown action: {action_type}" # Record action for tracing - await runtime.record_action(action_taken) + if action_taken: + await runtime.record_action(action_taken) # Run verifications - if step.verify: + if step.verify and error is None: verification_passed = await self._verify_step(runtime, step) + + # If verification failed and we have optional substeps, try them + if not verification_passed and step.optional_substeps: + substep_outcomes = await self._execute_optional_substeps( + step.optional_substeps, + runtime, + step_index, + ) + # Re-run verification after substeps + if any(o.status == StepStatus.SUCCESS for o in substep_outcomes): + verification_passed = await self._verify_step(runtime, step) else: - verification_passed = True + verification_passed = error is None + + # Track successful URL for recovery + if verification_passed and self.config.recovery.track_successful_urls: + post_url = await runtime.get_url() if hasattr(runtime, "get_url") else None + if post_url: + self._last_known_good_url = post_url except Exception as e: error = str(e) diff --git a/tests/unit/test_planner_executor_agent.py b/tests/unit/test_planner_executor_agent.py new file mode 100644 index 0000000..02a579b --- /dev/null +++ b/tests/unit/test_planner_executor_agent.py @@ -0,0 +1,575 @@ +""" +Unit tests for PlannerExecutorAgent. + +Tests for: +- IntentHeuristics protocol +- ExecutorOverride protocol +- Pre-step verification (skip if predicates pass) +- Optional substeps execution +- Plan normalization +- Plan smoothness validation +- RecoveryNavigationConfig +""" + +from __future__ import annotations + +import pytest +from typing import Any + +from predicate.agents.planner_executor_agent import ( + ExecutorOverride, + IntentHeuristics, + Plan, + PlannerExecutorConfig, + PlanStep, + PredicateSpec, + RecoveryNavigationConfig, + normalize_plan, + validate_plan_smoothness, +) + + +# --------------------------------------------------------------------------- +# Test normalize_plan +# --------------------------------------------------------------------------- + + +class TestNormalizePlan: + """Tests for the normalize_plan function.""" + + def test_normalizes_action_to_uppercase(self) -> None: + plan_dict = { + "task": "test", + "steps": [ + {"id": 1, "goal": "click button", "action": "click", "verify": []}, + ], + } + result = normalize_plan(plan_dict) + assert result["steps"][0]["action"] == "CLICK" + + def test_normalizes_action_aliases(self) -> None: + test_cases = [ + ("CLICK_ELEMENT", "CLICK"), + ("CLICK_BUTTON", "CLICK"), + ("CLICK_LINK", "CLICK"), + ("INPUT", "TYPE_AND_SUBMIT"), + ("TYPE_TEXT", "TYPE_AND_SUBMIT"), + ("ENTER_TEXT", "TYPE_AND_SUBMIT"), + ("GOTO", "NAVIGATE"), + ("GO_TO", "NAVIGATE"), + ("OPEN", "NAVIGATE"), + ("SCROLL_DOWN", "SCROLL"), + ("SCROLL_UP", "SCROLL"), + ] + + for alias, expected in test_cases: + plan_dict = { + "task": "test", + "steps": [ + {"id": 1, "goal": "test", "action": alias, "verify": []}, + ], + } + result = normalize_plan(plan_dict) + assert result["steps"][0]["action"] == expected, f"Failed for alias: {alias}" + + def test_normalizes_url_to_target(self) -> None: + plan_dict = { + "task": "test", + "steps": [ + { + "id": 1, + "goal": "navigate", + "action": "NAVIGATE", + "url": "https://example.com", + "verify": [], + }, + ], + } + result = normalize_plan(plan_dict) + assert "target" in result["steps"][0] + assert result["steps"][0]["target"] == "https://example.com" + assert "url" not in result["steps"][0] + + def test_preserves_existing_target(self) -> None: + plan_dict = { + "task": "test", + "steps": [ + { + "id": 1, + "goal": "navigate", + "action": "NAVIGATE", + "target": "https://example.com", + "url": "https://other.com", + "verify": [], + }, + ], + } + result = normalize_plan(plan_dict) + assert result["steps"][0]["target"] == "https://example.com" + + def test_converts_string_id_to_int(self) -> None: + plan_dict = { + "task": "test", + "steps": [ + {"id": "1", "goal": "test", "action": "CLICK", "verify": []}, + {"id": "2", "goal": "test2", "action": "CLICK", "verify": []}, + ], + } + result = normalize_plan(plan_dict) + assert result["steps"][0]["id"] == 1 + assert result["steps"][1]["id"] == 2 + + def test_normalizes_optional_substeps(self) -> None: + plan_dict = { + "task": "test", + "steps": [ + { + "id": 1, + "goal": "test", + "action": "CLICK", + "verify": [], + "optional_substeps": [ + { + "id": 1, + "goal": "scroll", + "action": "scroll_down", + "url": "https://example.com", + }, + ], + }, + ], + } + result = normalize_plan(plan_dict) + substep = result["steps"][0]["optional_substeps"][0] + assert substep["action"] == "SCROLL_DOWN" + assert substep["target"] == "https://example.com" + + +# --------------------------------------------------------------------------- +# Test validate_plan_smoothness +# --------------------------------------------------------------------------- + + +class TestValidatePlanSmoothness: + """Tests for the validate_plan_smoothness function.""" + + def test_warns_on_empty_plan(self) -> None: + plan = Plan(task="test", steps=[]) + warnings = validate_plan_smoothness(plan) + assert "Plan has no steps" in warnings + + def test_warns_on_single_step_plan(self) -> None: + plan = Plan( + task="test", + steps=[ + PlanStep(id=1, goal="test", action="CLICK", verify=[]), + ], + ) + warnings = validate_plan_smoothness(plan) + assert any("only one step" in w for w in warnings) + + def test_warns_on_missing_verification_for_required_step(self) -> None: + plan = Plan( + task="test", + steps=[ + PlanStep(id=1, goal="step1", action="CLICK", verify=[], required=True), + PlanStep(id=2, goal="step2", action="CLICK", verify=[], required=True), + ], + ) + warnings = validate_plan_smoothness(plan) + assert any("no verification" in w for w in warnings) + + def test_no_warning_for_optional_step_without_verification(self) -> None: + plan = Plan( + task="test", + steps=[ + PlanStep( + id=1, + goal="step1", + action="CLICK", + verify=[PredicateSpec(predicate="url_contains", args=["test"])], + required=True, + ), + PlanStep(id=2, goal="step2", action="CLICK", verify=[], required=False), + ], + ) + warnings = validate_plan_smoothness(plan) + # Only one warning about step 2 having no verification + verification_warnings = [w for w in warnings if "no verification" in w] + assert len(verification_warnings) == 0 # required=False means no warning + + def test_warns_on_navigate_without_target(self) -> None: + plan = Plan( + task="test", + steps=[ + PlanStep(id=1, goal="go", action="NAVIGATE", target=None, verify=[]), + PlanStep(id=2, goal="done", action="FINISH", verify=[]), + ], + ) + warnings = validate_plan_smoothness(plan) + assert any("NAVIGATE but has no target" in w for w in warnings) + + def test_warns_on_click_without_intent(self) -> None: + plan = Plan( + task="test", + steps=[ + PlanStep(id=1, goal="click", action="CLICK", intent=None, verify=[]), + PlanStep(id=2, goal="done", action="FINISH", verify=[]), + ], + ) + warnings = validate_plan_smoothness(plan) + assert any("CLICK but has no intent" in w for w in warnings) + + def test_warns_on_type_and_submit_without_input(self) -> None: + plan = Plan( + task="test", + steps=[ + PlanStep( + id=1, + goal="type", + action="TYPE_AND_SUBMIT", + input=None, + verify=[], + ), + PlanStep(id=2, goal="done", action="FINISH", verify=[]), + ], + ) + warnings = validate_plan_smoothness(plan) + assert any("TYPE_AND_SUBMIT but has no input" in w for w in warnings) + + def test_warns_on_consecutive_click_actions(self) -> None: + plan = Plan( + task="test", + steps=[ + PlanStep( + id=1, + goal="click1", + action="CLICK", + intent="btn1", + verify=[PredicateSpec(predicate="url_contains", args=["a"])], + ), + PlanStep( + id=2, + goal="click2", + action="CLICK", + intent="btn2", + verify=[PredicateSpec(predicate="url_contains", args=["b"])], + ), + ], + ) + warnings = validate_plan_smoothness(plan) + assert any("both use CLICK" in w for w in warnings) + + def test_no_warnings_for_good_plan(self) -> None: + plan = Plan( + task="test", + steps=[ + PlanStep( + id=1, + goal="navigate", + action="NAVIGATE", + target="https://example.com", + verify=[PredicateSpec(predicate="url_contains", args=["example"])], + ), + PlanStep( + id=2, + goal="click", + action="CLICK", + intent="submit", + verify=[PredicateSpec(predicate="exists", args=["role=button"])], + ), + ], + ) + warnings = validate_plan_smoothness(plan) + # May have some warnings (consecutive actions check is lenient) + # but no critical issues + assert "Plan has no steps" not in warnings + + +# --------------------------------------------------------------------------- +# Test IntentHeuristics Protocol +# --------------------------------------------------------------------------- + + +class TestIntentHeuristicsProtocol: + """Tests for the IntentHeuristics protocol.""" + + def test_protocol_check_passes_for_valid_implementation(self) -> None: + class ValidHeuristics: + def find_element_for_intent( + self, + intent: str, + elements: list[Any], + url: str, + goal: str, + ) -> int | None: + return None + + def priority_order(self) -> list[str]: + return [] + + heuristics = ValidHeuristics() + assert isinstance(heuristics, IntentHeuristics) + + def test_heuristics_can_return_element_id(self) -> None: + class MockHeuristics: + def find_element_for_intent( + self, + intent: str, + elements: list[Any], + url: str, + goal: str, + ) -> int | None: + if intent == "add_to_cart": + for el in elements: + if getattr(el, "text", "").lower() == "add to cart": + return getattr(el, "id", None) + return None + + def priority_order(self) -> list[str]: + return ["add_to_cart", "checkout"] + + class MockElement: + def __init__(self, id: int, text: str): + self.id = id + self.text = text + + heuristics = MockHeuristics() + elements = [ + MockElement(1, "Some text"), + MockElement(2, "Add to Cart"), + MockElement(3, "Other button"), + ] + + result = heuristics.find_element_for_intent( + intent="add_to_cart", + elements=elements, + url="https://example.com", + goal="add item to cart", + ) + assert result == 2 + + def test_heuristics_returns_none_for_unknown_intent(self) -> None: + class MockHeuristics: + def find_element_for_intent( + self, + intent: str, + elements: list[Any], + url: str, + goal: str, + ) -> int | None: + return None + + def priority_order(self) -> list[str]: + return [] + + heuristics = MockHeuristics() + result = heuristics.find_element_for_intent( + intent="unknown", + elements=[], + url="https://example.com", + goal="test", + ) + assert result is None + + +# --------------------------------------------------------------------------- +# Test ExecutorOverride Protocol +# --------------------------------------------------------------------------- + + +class TestExecutorOverrideProtocol: + """Tests for the ExecutorOverride protocol.""" + + def test_protocol_check_passes_for_valid_implementation(self) -> None: + class ValidOverride: + def validate_choice( + self, + element_id: int, + action: str, + elements: list[Any], + goal: str, + ) -> tuple[bool, int | None, str | None]: + return True, None, None + + override = ValidOverride() + assert isinstance(override, ExecutorOverride) + + def test_override_can_block_action(self) -> None: + class SafetyOverride: + def validate_choice( + self, + element_id: int, + action: str, + elements: list[Any], + goal: str, + ) -> tuple[bool, int | None, str | None]: + for el in elements: + if getattr(el, "id", None) == element_id: + text = getattr(el, "text", "").lower() + if "delete" in text: + return False, None, "blocked_delete" + return True, None, None + + class MockElement: + def __init__(self, id: int, text: str): + self.id = id + self.text = text + + override = SafetyOverride() + elements = [ + MockElement(1, "Submit"), + MockElement(2, "Delete Account"), + ] + + # Should allow submit button + is_valid, override_id, reason = override.validate_choice( + element_id=1, + action="CLICK", + elements=elements, + goal="submit form", + ) + assert is_valid is True + + # Should block delete button + is_valid, override_id, reason = override.validate_choice( + element_id=2, + action="CLICK", + elements=elements, + goal="delete account", + ) + assert is_valid is False + assert reason == "blocked_delete" + + def test_override_can_suggest_alternative(self) -> None: + class CorrectionOverride: + def validate_choice( + self, + element_id: int, + action: str, + elements: list[Any], + goal: str, + ) -> tuple[bool, int | None, str | None]: + # Always suggest element 5 instead + return False, 5, "corrected" + + override = CorrectionOverride() + is_valid, override_id, reason = override.validate_choice( + element_id=1, + action="CLICK", + elements=[], + goal="test", + ) + assert is_valid is False + assert override_id == 5 + assert reason == "corrected" + + +# --------------------------------------------------------------------------- +# Test RecoveryNavigationConfig +# --------------------------------------------------------------------------- + + +class TestRecoveryNavigationConfig: + """Tests for RecoveryNavigationConfig.""" + + def test_default_values(self) -> None: + config = RecoveryNavigationConfig() + assert config.enabled is True + assert config.max_recovery_attempts == 2 + assert config.track_successful_urls is True + + def test_custom_values(self) -> None: + config = RecoveryNavigationConfig( + enabled=False, + max_recovery_attempts=5, + track_successful_urls=False, + ) + assert config.enabled is False + assert config.max_recovery_attempts == 5 + assert config.track_successful_urls is False + + +# --------------------------------------------------------------------------- +# Test PlannerExecutorConfig with new options +# --------------------------------------------------------------------------- + + +class TestPlannerExecutorConfigNewOptions: + """Tests for new config options in PlannerExecutorConfig.""" + + def test_pre_step_verification_default_enabled(self) -> None: + config = PlannerExecutorConfig() + assert config.pre_step_verification is True + + def test_pre_step_verification_can_be_disabled(self) -> None: + config = PlannerExecutorConfig(pre_step_verification=False) + assert config.pre_step_verification is False + + def test_recovery_config_present(self) -> None: + config = PlannerExecutorConfig() + assert config.recovery is not None + assert config.recovery.enabled is True + + def test_custom_recovery_config(self) -> None: + config = PlannerExecutorConfig( + recovery=RecoveryNavigationConfig( + enabled=True, + max_recovery_attempts=3, + ), + ) + assert config.recovery.max_recovery_attempts == 3 + + +# --------------------------------------------------------------------------- +# Test PlanStep with optional_substeps +# --------------------------------------------------------------------------- + + +class TestPlanStepOptionalSubsteps: + """Tests for PlanStep optional_substeps field.""" + + def test_optional_substeps_default_empty(self) -> None: + step = PlanStep(id=1, goal="test", action="CLICK", verify=[]) + assert step.optional_substeps == [] + + def test_optional_substeps_can_be_set(self) -> None: + substep = PlanStep(id=1, goal="scroll", action="SCROLL", verify=[], required=False) + step = PlanStep( + id=1, + goal="click", + action="CLICK", + verify=[], + optional_substeps=[substep], + ) + assert len(step.optional_substeps) == 1 + assert step.optional_substeps[0].action == "SCROLL" + + def test_optional_substeps_nested_structure(self) -> None: + step_dict = { + "id": 1, + "goal": "click product", + "action": "CLICK", + "intent": "first_product", + "verify": [{"predicate": "url_contains", "args": ["/dp/"]}], + "optional_substeps": [ + { + "id": 1, + "goal": "scroll down", + "action": "SCROLL", + "required": False, + }, + { + "id": 2, + "goal": "retry click", + "action": "CLICK", + "intent": "first_product", + "verify": [{"predicate": "url_contains", "args": ["/dp/"]}], + "required": False, + }, + ], + } + step = PlanStep.model_validate(step_dict) + assert len(step.optional_substeps) == 2 + assert step.optional_substeps[0].action == "SCROLL" + assert step.optional_substeps[1].action == "CLICK" From 1bce000985400ccae430161a371cd5e9d5c2d3c7 Mon Sep 17 00:00:00 2001 From: SentienceDEV Date: Sat, 14 Mar 2026 23:49:20 -0700 Subject: [PATCH 3/3] fix tests --- predicate/agents/planner_executor_agent.py | 1 + 1 file changed, 1 insertion(+) diff --git a/predicate/agents/planner_executor_agent.py b/predicate/agents/planner_executor_agent.py index 6832ea9..70e6ee2 100644 --- a/predicate/agents/planner_executor_agent.py +++ b/predicate/agents/planner_executor_agent.py @@ -124,6 +124,7 @@ def priority_order(self) -> list[str]: ... +@runtime_checkable class ExecutorOverride(Protocol): """ Protocol for validating or overriding executor element choices.