diff --git a/docs/PLANNER_EXECUTOR_AGENT.md b/docs/PLANNER_EXECUTOR_AGENT.md new file mode 100644 index 0000000..d694f47 --- /dev/null +++ b/docs/PLANNER_EXECUTOR_AGENT.md @@ -0,0 +1,1063 @@ +# PlannerExecutorAgent User Manual + +The `PlannerExecutorAgent` is a two-tier agent architecture for browser automation that separates planning from execution: + +- **Planner**: Generates JSON execution plans with verification predicates (uses 7B+ model) +- **Executor**: Executes each step with snapshot-first verification (uses 3B-7B model) + +## Table of Contents + +1. [Quick Start](#quick-start) +2. [Core Concepts](#core-concepts) +3. [AutomationTask](#automationtask) +4. [Configuration](#configuration) +5. [CAPTCHA Handling](#captcha-handling) +6. [Permissions](#permissions) +7. [Modal and Dialog Handling](#modal-and-dialog-handling) +8. [Recovery and Rollback](#recovery-and-rollback) +9. [Custom Heuristics](#custom-heuristics) +10. [Tracing](#tracing) +11. [Examples](#examples) + +--- + +## Quick Start + +```python +from predicate.agents import PlannerExecutorAgent, PlannerExecutorConfig, AutomationTask +from predicate.llm_provider import OpenAIProvider +from predicate import AsyncPredicateBrowser, AgentRuntime + +# Initialize LLM providers +planner = OpenAIProvider(model="gpt-4o") +executor = OpenAIProvider(model="gpt-4o-mini") + +# Create agent +agent = PlannerExecutorAgent( + planner=planner, + executor=executor, + config=PlannerExecutorConfig(), +) + +# Run automation +async with AsyncPredicateBrowser() as browser: + runtime = AgentRuntime.from_browser(browser) + + result = await agent.run( + runtime=runtime, + task="Search for laptops on Amazon", + start_url="https://amazon.com", + ) + + print(f"Success: {result.success}") +``` + +--- + +## Core Concepts + +### Snapshot-First Verification + +The agent uses a **snapshot-first** approach: +1. Capture DOM snapshot before each action +2. Select element using heuristics or LLM +3. Execute action +4. Verify result using predicates + +### Predicate Verification + +Each plan step includes verification predicates: + +```json +{ + "id": 1, + "goal": "Click search button", + "action": "CLICK", + "intent": "search", + "verify": [ + {"predicate": "url_contains", "args": ["/search"]}, + {"predicate": "exists", "args": ["role=list"]} + ] +} +``` + +Available predicates: +- `url_contains(substring)`: URL contains the given string +- `url_matches(pattern)`: URL matches regex pattern +- `exists(selector)`: Element matching selector exists +- `not_exists(selector)`: Element does not exist +- `element_count(selector, min, max)`: Element count within range +- `any_of(predicates...)`: Any predicate is true +- `all_of(predicates...)`: All predicates are true + +Selectors: `role=button`, `role=link`, `text~'text'`, `role=textbox`, etc. + +### Snapshot Escalation and Scroll-after-Escalation + +When capturing DOM snapshots, the agent uses **incremental limit escalation** to ensure it captures enough elements: + +1. Start with `limit_base` (default: 60 elements) +2. If element count is low (<10), escalate by `limit_step` (default: 30) +3. Continue until `limit_max` (default: 200) is reached + +After exhausting limit escalation, if the target element is still not found, the agent can use **scroll-after-escalation** to find elements outside the current viewport. This feature only triggers when ALL conditions are met: + +1. `scroll_after_escalation=True` (default) +2. The step action is `CLICK` (not TYPE_AND_SUBMIT, NAVIGATE, etc.) +3. The step has a specific `intent` field +4. Custom `intent_heuristics` are injected into the agent + +When triggered: +1. Scroll down (up to `scroll_max_attempts` times, default: 3) +2. Take a new snapshot after each scroll +3. Check if target element is now visible using intent heuristics +4. If still not found, scroll up +5. Return the best snapshot found + +This is particularly useful for elements like "Add to Cart" buttons that may be below the initial viewport on product pages, when you have custom heuristics to detect them. + +```python +from predicate.agents import SnapshotEscalationConfig + +# Default behavior: escalation + scroll enabled +config = SnapshotEscalationConfig() + +# Disable scroll-after-escalation (only use limit escalation) +config = SnapshotEscalationConfig(scroll_after_escalation=False) + +# Custom scroll settings +config = SnapshotEscalationConfig( + scroll_after_escalation=True, + scroll_max_attempts=5, # More scrolls per direction + scroll_directions=("down",), # Only scroll down +) + +# Try scrolling up first (useful for elements at top of page) +config = SnapshotEscalationConfig( + scroll_directions=("up", "down"), # Try up before down +) +``` + +--- + +## AutomationTask + +The `AutomationTask` model provides a flexible way to define browser automation tasks: + +```python +from predicate.agents import AutomationTask, TaskCategory, SuccessCriteria + +# Basic task +task = AutomationTask( + task_id="purchase-laptop-001", + starting_url="https://amazon.com", + task="Find a laptop under $1000 with good reviews and add to cart", +) + +# Task with category hint +task = AutomationTask( + task_id="purchase-laptop-001", + starting_url="https://amazon.com", + task="Find a laptop under $1000 with good reviews and add to cart", + category=TaskCategory.TRANSACTION, + max_steps=50, + enable_recovery=True, +) + +# Task with success criteria +task = task.with_success_criteria( + {"predicate": "url_contains", "args": ["/cart"]}, + {"predicate": "exists", "args": [".cart-item"]}, +) + +# Run with AutomationTask +result = await agent.run(runtime, task) +``` + +### TaskCategory + +Categories help the planner and executor make better decisions: + +| Category | Use Case | +|----------|----------| +| `NAVIGATION` | Navigate to a destination | +| `SEARCH` | Search and find information | +| `FORM_FILL` | Fill out forms | +| `EXTRACTION` | Extract data from pages | +| `TRANSACTION` | Purchase, submit, create actions | +| `VERIFICATION` | Verify state/information exists | + +### Extraction Tasks + +For data extraction tasks: + +```python +from predicate.agents import ExtractionSpec + +task = AutomationTask( + task_id="extract-product-info", + starting_url="https://amazon.com/dp/B0...", + task="Extract the product name, price, and rating", + category=TaskCategory.EXTRACTION, + extraction_spec=ExtractionSpec( + output_schema={"name": "str", "price": "float", "rating": "float"}, + format="json", + ), +) +``` + +--- + +## Configuration + +### PlannerExecutorConfig + +```python +from predicate.agents import ( + PlannerExecutorConfig, + SnapshotEscalationConfig, + RetryConfig, + RecoveryNavigationConfig, +) +from predicate.agents.browser_agent import VisionFallbackConfig, CaptchaConfig + +config = PlannerExecutorConfig( + # Snapshot escalation: progressively increase limit on low element count + # After exhausting limit escalation, scrolls to find elements outside viewport + snapshot=SnapshotEscalationConfig( + enabled=True, + limit_base=60, # Initial snapshot limit + limit_step=30, # Increment per escalation + limit_max=200, # Maximum limit + # Scroll-after-escalation: find elements below/above viewport + scroll_after_escalation=True, # Enable scrolling after limit exhaustion + scroll_max_attempts=3, # Max scrolls per direction + scroll_directions=("down", "up"), # Directions to try + ), + + # Retry configuration + retry=RetryConfig( + verify_timeout_s=10.0, + verify_poll_s=0.5, + verify_max_attempts=5, + executor_repair_attempts=2, + max_replans=1, + ), + + # Vision fallback for canvas pages or low-confidence snapshots + vision=VisionFallbackConfig( + enabled=True, + max_vision_calls=3, + trigger_requires_vision=True, + trigger_canvas_or_low_actionables=True, + ), + + # CAPTCHA handling + captcha=CaptchaConfig(), # See CAPTCHA section below + + # Recovery navigation + recovery=RecoveryNavigationConfig( + enabled=True, + max_recovery_attempts=2, + track_successful_urls=True, + ), + + # LLM settings + planner_max_tokens=2048, + planner_temperature=0.0, + executor_max_tokens=96, + executor_temperature=0.0, + + # Stabilization + stabilize_enabled=True, + stabilize_poll_s=0.35, + stabilize_max_attempts=6, + + # Pre-step verification: skip step if predicates already pass + pre_step_verification=True, + + # Tracing + trace_screenshots=True, + trace_screenshot_format="jpeg", + trace_screenshot_quality=80, +) + +agent = PlannerExecutorAgent( + planner=planner, + executor=executor, + config=config, +) +``` + +--- + +## CAPTCHA Handling + +The SDK provides flexible CAPTCHA handling through the `CaptchaConfig` system. + +### CAPTCHA Policies + +| Policy | Behavior | +|--------|----------| +| `abort` | Fail immediately when CAPTCHA is detected (default) | +| `callback` | Invoke a handler and wait for resolution | + +### Using a CAPTCHA Handler + +```python +from predicate.agents.browser_agent import CaptchaConfig +from predicate.captcha_strategies import ( + HumanHandoffSolver, + ExternalSolver, + VisionSolver, +) + +# Option 1: Human handoff - waits for manual solve +config = PlannerExecutorConfig( + captcha=CaptchaConfig( + policy="callback", + handler=HumanHandoffSolver( + message="Please solve the CAPTCHA in the browser window", + timeout_ms=120_000, # 2 minute timeout + ), + ), +) + +# Option 2: External solver integration (e.g., 2Captcha, CapSolver) +def solve_with_2captcha(ctx): + # ctx.url - current page URL + # ctx.screenshot_path - path to screenshot + # ctx.captcha - CaptchaDiagnostics with type info + + # Call your CAPTCHA solving service here + # Return when solved or raise on failure + pass + +config = PlannerExecutorConfig( + captcha=CaptchaConfig( + policy="callback", + handler=ExternalSolver( + resolver=solve_with_2captcha, + message="Solving CAPTCHA via 2Captcha", + timeout_ms=180_000, + ), + ), +) + +# Option 3: Vision-based solving (requires vision model) +config = PlannerExecutorConfig( + captcha=CaptchaConfig( + policy="callback", + handler=VisionSolver( + message="Attempting vision-based CAPTCHA solve", + timeout_ms=60_000, + ), + ), +) +``` + +### CaptchaContext + +When the CAPTCHA handler is invoked, it receives a `CaptchaContext`: + +```python +@dataclass +class CaptchaContext: + run_id: str # Current run ID + step_index: int # Current step index + url: str # Current page URL + source: CaptchaSource # "extension" | "gateway" | "runtime" + captcha: CaptchaDiagnostics # CAPTCHA type and details + screenshot_path: str | None # Path to screenshot + frames_dir: str | None # Directory with frame images + snapshot_path: str | None # Path to snapshot + live_session_url: str | None # URL for live debugging + meta: dict | None # Additional metadata + page_control: PageControlHook | None # JS evaluation hook +``` + +### CaptchaResolution Actions + +| Action | Behavior | +|--------|----------| +| `abort` | Stop automation immediately | +| `retry_new_session` | Clear cookies and retry | +| `wait_until_cleared` | Poll until CAPTCHA is cleared | + +### Implementing a Custom CAPTCHA Handler + +```python +from predicate.captcha import CaptchaContext, CaptchaResolution, CaptchaHandler + +async def my_captcha_handler(ctx: CaptchaContext) -> CaptchaResolution: + """Custom CAPTCHA handler with external solving service.""" + + # Example: Integrate with 2Captcha + import requests + + # Read screenshot for solving + if ctx.screenshot_path: + with open(ctx.screenshot_path, 'rb') as f: + image_data = f.read() + + # Submit to solving service + # response = requests.post("https://2captcha.com/in.php", ...) + # solution = poll_for_solution(response['captcha_id']) + + # Inject solution if page_control is available + if ctx.page_control: + await ctx.page_control.evaluate_js(f""" + document.getElementById('captcha-input').value = '{solution}'; + document.getElementById('captcha-form').submit(); + """) + + return CaptchaResolution( + action="wait_until_cleared", + message="CAPTCHA solved via 2Captcha", + handled_by="customer_system", + timeout_ms=30_000, + poll_ms=2_000, + ) + +config = PlannerExecutorConfig( + captcha=CaptchaConfig( + policy="callback", + handler=my_captcha_handler, + timeout_ms=180_000, + min_confidence=0.7, + ), +) +``` + +### Dependencies for CAPTCHA Solving + +For external CAPTCHA solving services: + +```bash +# For 2Captcha integration +pip install 2captcha-python + +# For CapSolver integration +pip install capsolver +``` + +Example with 2Captcha: + +```python +from twocaptcha import TwoCaptcha + +solver = TwoCaptcha('YOUR_API_KEY') + +def solve_with_2captcha(ctx: CaptchaContext): + if ctx.captcha.type == "recaptcha": + result = solver.recaptcha( + sitekey=ctx.captcha.sitekey, + url=ctx.url, + ) + # Inject the solution + if ctx.page_control: + import asyncio + asyncio.get_event_loop().run_until_complete( + ctx.page_control.evaluate_js(f""" + document.getElementById('g-recaptcha-response').innerHTML = '{result["code"]}'; + """) + ) + return True +``` + +--- + +## Permissions + +Chrome browser permissions (geolocation, notifications, etc.) are handled at two levels: + +### Startup Permissions + +Applied at browser/context creation: + +```python +from predicate.permissions import PermissionPolicy + +policy = PermissionPolicy( + default="grant", # "clear" | "deny" | "grant" + auto_grant=["geolocation", "notifications", "camera", "microphone"], + geolocation={"latitude": 37.7749, "longitude": -122.4194}, # Mock location + origin="https://example.com", +) + +# Apply via browser configuration (implementation-dependent) +``` + +### Recovery Permissions + +For handling permission prompts during automation: + +```python +from predicate.agents.browser_agent import PermissionRecoveryConfig + +config = PlannerExecutorConfig( + # Other config... +) + +# PermissionRecoveryConfig is used at agent level +permission_recovery = PermissionRecoveryConfig( + enabled=True, + max_restarts=1, + auto_grant=["geolocation", "notifications"], + geolocation={"latitude": 37.7749, "longitude": -122.4194}, + origin="https://example.com", +) +``` + +### Common Permissions + +| Permission | Description | +|------------|-------------| +| `geolocation` | Access to device location | +| `notifications` | Push notification access | +| `camera` | Camera access | +| `microphone` | Microphone access | +| `clipboard-read` | Read clipboard | +| `clipboard-write` | Write clipboard | + +--- + +## Modal and Dialog Handling + +Modal and dialog handling is done through plan steps with heuristic hints: + +### Common Modal Patterns + +The SDK includes common hints for dismissing modals: + +```python +from predicate.agents import COMMON_HINTS, get_common_hint + +# Built-in hints for common patterns +close_hint = get_common_hint("close") # "close", "dismiss", "x", "cancel" +accept_cookies = get_common_hint("accept_cookies") # "accept", "allow", "agree" +``` + +### Handling Cookie Consent + +```python +task = AutomationTask( + task_id="example", + starting_url="https://example.com", + task="Accept cookies and then search for products", +) + +# The planner will generate steps with heuristic hints: +# { +# "id": 1, +# "goal": "Accept cookie consent", +# "action": "CLICK", +# "intent": "accept_cookies", +# "heuristic_hints": [ +# { +# "intent_pattern": "accept_cookies", +# "text_patterns": ["accept", "accept all", "allow", "agree"], +# "role_filter": ["button"] +# } +# ] +# } +``` + +### Custom Modal Handling Heuristics + +For site-specific modals, provide custom heuristics: + +```python +class ModalDismissHeuristics: + def find_element_for_intent(self, intent, elements, url, goal): + if "dismiss" in intent.lower() or "close" in intent.lower(): + # Look for close buttons + for el in elements: + text = (getattr(el, "text", "") or "").lower() + role = getattr(el, "role", "") or "" + + # Common close button patterns + if role == "button": + if any(p in text for p in ["close", "dismiss", "x", "no thanks", "cancel"]): + return getattr(el, "id", None) + + # Close icon (×) + if text in ["×", "x", "✕", "✖"]: + return getattr(el, "id", None) + + return None + + def priority_order(self): + return ["close", "dismiss", "accept_cookies"] + +agent = PlannerExecutorAgent( + planner=planner, + executor=executor, + intent_heuristics=ModalDismissHeuristics(), +) +``` + +### Optional Substeps for Modals + +The planner can generate optional substeps for edge cases: + +```json +{ + "id": 2, + "goal": "Search for laptops", + "action": "TYPE_AND_SUBMIT", + "input": "laptop", + "verify": [{"predicate": "url_contains", "args": ["/search"]}], + "optional_substeps": [ + { + "id": 201, + "goal": "Dismiss any modal that may have appeared", + "action": "CLICK", + "intent": "close" + } + ] +} +``` + +### Automatic Modal Dismissal + +The agent includes automatic modal/drawer dismissal that triggers after DOM changes. This handles common blocking scenarios: + +- Product protection/warranty upsells (e.g., Amazon's "Add Protection Plan") +- Cookie consent banners +- Newsletter signup popups +- Promotional overlays +- Cart upsell drawers + +**How It Works:** + +When a CLICK action triggers a significant DOM change (5+ new elements), the agent: +1. Detects that a modal may have appeared +2. Scans for dismissal buttons using common text patterns +3. Clicks the best matching button to clear the overlay +4. Continues with the task + +**Default Configuration:** + +```python +from predicate.agents import PlannerExecutorConfig, ModalDismissalConfig + +# Default: enabled with common English patterns +config = PlannerExecutorConfig() +print(config.modal.enabled) # True +print(config.modal.dismiss_patterns[:3]) # ('no thanks', 'no, thanks', 'no thank you') +``` + +**Disabling Modal Dismissal:** + +```python +config = PlannerExecutorConfig( + modal=ModalDismissalConfig(enabled=False), +) +``` + +**Custom Patterns for Internationalization:** + +The dismissal patterns are fully configurable for non-English sites: + +```python +# German site configuration +config = PlannerExecutorConfig( + modal=ModalDismissalConfig( + dismiss_patterns=( + "nein danke", # No thanks + "nicht jetzt", # Not now + "abbrechen", # Cancel + "schließen", # Close + "überspringen", # Skip + "später", # Later + ), + dismiss_icons=("x", "×", "✕"), # Icons are universal + ), +) + +# Spanish site configuration +config = PlannerExecutorConfig( + modal=ModalDismissalConfig( + dismiss_patterns=( + "no gracias", # No thanks + "ahora no", # Not now + "cancelar", # Cancel + "cerrar", # Close + "omitir", # Skip + ), + ), +) + +# French site configuration +config = PlannerExecutorConfig( + modal=ModalDismissalConfig( + dismiss_patterns=( + "non merci", # No thanks + "pas maintenant", # Not now + "annuler", # Cancel + "fermer", # Close + "passer", # Skip + ), + ), +) + +# Japanese site configuration +config = PlannerExecutorConfig( + modal=ModalDismissalConfig( + dismiss_patterns=( + "いいえ", # No + "後で", # Later + "閉じる", # Close + "キャンセル", # Cancel + "スキップ", # Skip + ), + ), +) +``` + +**Multilingual Configuration:** + +For sites that may show modals in multiple languages: + +```python +config = PlannerExecutorConfig( + modal=ModalDismissalConfig( + dismiss_patterns=( + # English + "no thanks", "not now", "close", "skip", "cancel", + # German + "nein danke", "schließen", "abbrechen", + # Spanish + "no gracias", "cerrar", "cancelar", + # French + "non merci", "fermer", "annuler", + ), + ), +) +``` + +**Pattern Matching:** + +- **Word boundary matching**: Patterns use word boundary matching to avoid false positives (e.g., "close" won't match "enclosed") +- **Icon exact matching**: Short patterns like "x", "×" require exact match +- **Pattern ordering**: Earlier patterns in the list have higher priority + +--- + +## Recovery and Rollback + +The agent tracks checkpoints for recovery when steps fail: + +### How Recovery Works + +1. After each successful step verification, a checkpoint is recorded +2. If a step fails repeatedly, the agent attempts recovery: + - Navigate back to the last successful URL + - Re-verify the page state + - Resume from the checkpoint step +3. Limited by `max_recovery_attempts` + +### Recovery Configuration + +```python +from predicate.agents import AutomationTask + +task = AutomationTask( + task_id="checkout-flow", + starting_url="https://shop.com", + task="Complete checkout process", + enable_recovery=True, # Enable recovery + max_recovery_attempts=2, # Max attempts +) +``` + +### RecoveryState API + +For advanced use cases: + +```python +from predicate.agents import RecoveryState, RecoveryCheckpoint + +state = RecoveryState(max_recovery_attempts=2) + +# Record checkpoint after successful step +checkpoint = state.record_checkpoint( + url="https://shop.com/cart", + step_index=2, + snapshot_digest="abc123", + predicates_passed=["url_contains('/cart')"], +) + +# Check if recovery is possible +if state.can_recover(): + checkpoint = state.consume_recovery_attempt() + # Navigate to checkpoint.url and resume +``` + +--- + +## Custom Heuristics + +### IntentHeuristics Protocol + +Implement domain-specific element selection: + +```python +class EcommerceHeuristics: + def find_element_for_intent(self, intent, elements, url, goal): + intent_lower = intent.lower() + + if "add to cart" in intent_lower or "add_to_cart" in intent_lower: + for el in elements: + text = (getattr(el, "text", "") or "").lower() + role = getattr(el, "role", "") or "" + if role == "button" and any(p in text for p in ["add to cart", "add to bag"]): + return getattr(el, "id", None) + + if "checkout" in intent_lower: + for el in elements: + text = (getattr(el, "text", "") or "").lower() + if "checkout" in text or "proceed" in text: + return getattr(el, "id", None) + + return None # Fall back to LLM + + def priority_order(self): + return ["add_to_cart", "checkout", "search", "login"] + +agent = PlannerExecutorAgent( + planner=planner, + executor=executor, + intent_heuristics=EcommerceHeuristics(), +) +``` + +### ExecutorOverride Protocol + +Validate or override executor choices: + +```python +class SafetyOverride: + def validate_choice(self, element_id, action, elements, goal): + # Block clicks on delete buttons + for el in elements: + if getattr(el, "id", None) == element_id: + text = (getattr(el, "text", "") or "").lower() + if "delete" in text and action == "CLICK": + return False, None, "blocked_delete_button" + return True, None, None + +agent = PlannerExecutorAgent( + planner=planner, + executor=executor, + executor_override=SafetyOverride(), +) +``` + +### ComposableHeuristics + +The agent uses `ComposableHeuristics` internally to compose from multiple sources: + +1. Planner-provided `HeuristicHint` (per step, highest priority) +2. Common hints for well-known patterns +3. Static `IntentHeuristics` (user-injected) +4. `TaskCategory` defaults (lowest priority) + +--- + +## Tracing + +Enable tracing for Predicate Studio visualization: + +```python +from predicate.tracing import Tracer + +tracer = Tracer(output_dir="./traces") + +agent = PlannerExecutorAgent( + planner=planner, + executor=executor, + tracer=tracer, + config=PlannerExecutorConfig( + trace_screenshots=True, + trace_screenshot_format="jpeg", + trace_screenshot_quality=80, + ), +) + +# Run automation +result = await agent.run(runtime, task) + +# Trace files saved to ./traces/ +``` + +--- + +## Examples + +### E-commerce Purchase Flow + +```python +from predicate.agents import ( + PlannerExecutorAgent, + PlannerExecutorConfig, + AutomationTask, + TaskCategory, +) +from predicate.agents.browser_agent import CaptchaConfig +from predicate.captcha_strategies import HumanHandoffSolver +from predicate.llm_provider import OpenAIProvider + +# Setup +planner = OpenAIProvider(model="gpt-4o") +executor = OpenAIProvider(model="gpt-4o-mini") + +config = PlannerExecutorConfig( + captcha=CaptchaConfig( + policy="callback", + handler=HumanHandoffSolver(timeout_ms=120_000), + ), +) + +agent = PlannerExecutorAgent( + planner=planner, + executor=executor, + config=config, +) + +task = AutomationTask( + task_id="buy-laptop", + starting_url="https://amazon.com", + task="Search for 'laptop under $500', add the first result to cart, proceed to checkout", + category=TaskCategory.TRANSACTION, + enable_recovery=True, + max_recovery_attempts=2, +) + +task = task.with_success_criteria( + {"predicate": "url_contains", "args": ["/cart"]}, + {"predicate": "exists", "args": [".cart-item"]}, +) + +async with AsyncPredicateBrowser() as browser: + runtime = AgentRuntime.from_browser(browser) + result = await agent.run(runtime, task) + print(f"Success: {result.success}, Steps: {result.steps_completed}/{result.steps_total}") +``` + +### Form Filling with Extraction + +```python +from predicate.agents import AutomationTask, TaskCategory, ExtractionSpec + +task = AutomationTask( + task_id="fill-contact-form", + starting_url="https://example.com/contact", + task="Fill the contact form with name 'John Doe', email 'john@example.com', and message 'Hello'", + category=TaskCategory.FORM_FILL, +).with_success_criteria( + {"predicate": "exists", "args": [".success-message"]}, +) + +result = await agent.run(runtime, task) +``` + +### Data Extraction + +```python +task = AutomationTask( + task_id="extract-prices", + starting_url="https://shop.com/products", + task="Extract all product names and prices from the product listing", + category=TaskCategory.EXTRACTION, + extraction_spec=ExtractionSpec( + output_schema={ + "products": [{"name": "str", "price": "float"}] + }, + format="json", + ), +) + +result = await agent.run(runtime, task) +``` + +--- + +## API Reference + +### PlannerExecutorAgent + +```python +class PlannerExecutorAgent: + def __init__( + self, + *, + planner: LLMProvider, # LLM for generating plans + executor: LLMProvider, # LLM for executing steps + vision_executor: LLMProvider | None = None, + vision_verifier: LLMProvider | None = None, + config: PlannerExecutorConfig | None = None, + tracer: Tracer | None = None, + context_formatter: Callable | None = None, + intent_heuristics: IntentHeuristics | None = None, + executor_override: ExecutorOverride | None = None, + ) + + async def run( + self, + runtime: AgentRuntime, + task: AutomationTask | str, + *, + start_url: str | None = None, + run_id: str | None = None, + ) -> RunOutcome + + async def plan( + self, + task: str, + *, + start_url: str | None = None, + max_attempts: int = 2, + ) -> Plan + + async def step( + self, + runtime: AgentRuntime, + step: PlanStep, + step_index: int = 0, + ) -> StepOutcome +``` + +### RunOutcome + +```python +@dataclass +class RunOutcome: + run_id: str + task: str + success: bool + steps_completed: int + steps_total: int + replans_used: int + step_outcomes: list[StepOutcome] + total_duration_ms: int + error: str | None +``` + +### StepOutcome + +```python +@dataclass +class StepOutcome: + step_id: int + goal: str + status: StepStatus # SUCCESS, FAILED, SKIPPED, VISION_FALLBACK + action_taken: str | None + verification_passed: bool + used_vision: bool + error: str | None + duration_ms: int + url_before: str | None + url_after: str | None +``` diff --git a/examples/planner-executor/README.md b/examples/planner-executor/README.md index 7ce9835..137e1bc 100644 --- a/examples/planner-executor/README.md +++ b/examples/planner-executor/README.md @@ -3,11 +3,15 @@ This directory contains examples for the `PlannerExecutorAgent`, a two-tier agent architecture with separate Planner (7B+) and Executor (3B-7B) models. +> **See also**: [Full User Manual](../../docs/PLANNER_EXECUTOR_AGENT.md) for comprehensive documentation. + ## Examples | File | Description | |------|-------------| | `minimal_example.py` | Basic usage with OpenAI models | +| `automation_task_example.py` | Using AutomationTask for flexible task definition | +| `captcha_example.py` | CAPTCHA handling with different solvers | | `local_models_example.py` | Using local HuggingFace/MLX models | | `custom_config_example.py` | Custom configuration (escalation, retry, vision) | | `tracing_example.py` | Full tracing integration for Predicate Studio | @@ -139,3 +143,57 @@ agent = PlannerExecutorAgent( tracer.close() # Upload trace to Studio ``` + +## AutomationTask + +Use `AutomationTask` for flexible task definition with built-in recovery: + +```python +from predicate.agents import AutomationTask, TaskCategory + +# Basic task +task = AutomationTask( + task_id="search-products", + starting_url="https://amazon.com", + task="Search for laptops and add the first result to cart", + category=TaskCategory.TRANSACTION, + enable_recovery=True, +) + +# Add success criteria +task = task.with_success_criteria( + {"predicate": "url_contains", "args": ["/cart"]}, + {"predicate": "exists", "args": [".cart-item"]}, +) + +result = await agent.run(runtime, task) +``` + +## CAPTCHA Handling + +Configure CAPTCHA solving with different strategies: + +```python +from predicate.agents.browser_agent import CaptchaConfig +from predicate.captcha_strategies import HumanHandoffSolver, ExternalSolver + +# Human handoff: wait for manual solve +config = PlannerExecutorConfig( + captcha=CaptchaConfig( + policy="callback", + handler=HumanHandoffSolver(timeout_ms=120_000), + ), +) + +# External solver: integrate with 2Captcha, CapSolver, etc. +def solve_captcha(ctx): + # Call your CAPTCHA solving service + pass + +config = PlannerExecutorConfig( + captcha=CaptchaConfig( + policy="callback", + handler=ExternalSolver(resolver=solve_captcha), + ), +) +``` diff --git a/examples/planner-executor/automation_task_example.py b/examples/planner-executor/automation_task_example.py new file mode 100644 index 0000000..b4514d4 --- /dev/null +++ b/examples/planner-executor/automation_task_example.py @@ -0,0 +1,242 @@ +""" +AutomationTask Example + +Demonstrates using AutomationTask for flexible task definition with: +- Task categories for better heuristics +- Success criteria for verification +- Recovery configuration for rollback +- Extraction specification for data extraction tasks + +Prerequisites: + pip install predicate-sdk openai + export OPENAI_API_KEY=sk-... +""" + +import asyncio + +from predicate import AsyncPredicateBrowser +from predicate.agent_runtime import AgentRuntime +from predicate.agents import ( + AutomationTask, + ExtractionSpec, + PlannerExecutorAgent, + PlannerExecutorConfig, + TaskCategory, +) +from predicate.llm_provider import OpenAIProvider + + +async def basic_task_example(): + """Basic AutomationTask usage.""" + print("\n=== Basic AutomationTask Example ===\n") + + # Create LLM providers + planner = OpenAIProvider(model="gpt-4o") + executor = OpenAIProvider(model="gpt-4o-mini") + + # Create agent + agent = PlannerExecutorAgent( + planner=planner, + executor=executor, + ) + + # Create a basic task + task = AutomationTask( + task_id="search-example", + starting_url="https://example.com", + task="Find the main heading on the page", + ) + + print(f"Task ID: {task.task_id}") + print(f"Starting URL: {task.starting_url}") + print(f"Task: {task.task}") + + async with AsyncPredicateBrowser() as browser: + page = await browser.new_page() + await page.goto(task.starting_url) + + runtime = AgentRuntime.from_page(page) + result = await agent.run(runtime, task) + + print(f"\nResult: {'Success' if result.success else 'Failed'}") + print(f"Steps completed: {result.steps_completed}/{result.steps_total}") + + +async def transaction_task_example(): + """E-commerce transaction task with recovery.""" + print("\n=== Transaction Task Example ===\n") + + planner = OpenAIProvider(model="gpt-4o") + executor = OpenAIProvider(model="gpt-4o-mini") + + agent = PlannerExecutorAgent( + planner=planner, + executor=executor, + ) + + # Create a transaction task with category and recovery + task = AutomationTask( + task_id="purchase-laptop", + starting_url="https://amazon.com", + task="Search for 'laptop under $500' and add the first result to cart", + category=TaskCategory.TRANSACTION, # Helps with element selection + enable_recovery=True, # Enable rollback on failure + max_recovery_attempts=2, + max_steps=50, + ) + + # Add success criteria + task = task.with_success_criteria( + {"predicate": "url_contains", "args": ["/cart"]}, + {"predicate": "exists", "args": [".cart-item, .sc-list-item"]}, + ) + + print(f"Task: {task.task}") + print(f"Category: {task.category}") + print(f"Recovery enabled: {task.enable_recovery}") + print(f"Success criteria: {task.success_criteria}") + + async with AsyncPredicateBrowser() as browser: + page = await browser.new_page() + await page.goto(task.starting_url) + + runtime = AgentRuntime.from_page(page) + result = await agent.run(runtime, task) + + print(f"\nResult: {'Success' if result.success else 'Failed'}") + print(f"Steps completed: {result.steps_completed}/{result.steps_total}") + print(f"Replans used: {result.replans_used}") + + if result.error: + print(f"Error: {result.error}") + + +async def extraction_task_example(): + """Data extraction task with schema.""" + print("\n=== Extraction Task Example ===\n") + + planner = OpenAIProvider(model="gpt-4o") + executor = OpenAIProvider(model="gpt-4o-mini") + + agent = PlannerExecutorAgent( + planner=planner, + executor=executor, + ) + + # Create an extraction task with output schema + task = AutomationTask( + task_id="extract-product-info", + starting_url="https://amazon.com/dp/B0EXAMPLE", + task="Extract the product name, price, and rating", + category=TaskCategory.EXTRACTION, + extraction_spec=ExtractionSpec( + output_schema={ + "name": "str", + "price": "float", + "rating": "float", + "num_reviews": "int", + }, + format="json", + require_evidence=True, + ), + ) + + print(f"Task: {task.task}") + print(f"Category: {task.category}") + print(f"Output schema: {task.extraction_spec.output_schema}") + print(f"Format: {task.extraction_spec.format}") + + # Note: This example won't run successfully as the URL is fake + # In real usage, provide a valid product URL + + +async def form_fill_task_example(): + """Form filling task example.""" + print("\n=== Form Fill Task Example ===\n") + + planner = OpenAIProvider(model="gpt-4o") + executor = OpenAIProvider(model="gpt-4o-mini") + + agent = PlannerExecutorAgent( + planner=planner, + executor=executor, + ) + + # Create a form fill task + task = AutomationTask( + task_id="contact-form", + starting_url="https://example.com/contact", + task="Fill the contact form with name 'John Doe', email 'john@example.com', and message 'Hello, I have a question'", + category=TaskCategory.FORM_FILL, + ) + + # Add success criteria for form submission + task = task.with_success_criteria( + {"predicate": "any_of", "args": [ + {"predicate": "exists", "args": [".success-message"]}, + {"predicate": "url_contains", "args": ["/thank-you"]}, + ]}, + ) + + print(f"Task: {task.task}") + print(f"Category: {task.category}") + + +async def from_string_example(): + """Create task from simple string.""" + print("\n=== From String Example ===\n") + + # Quick task creation from string + task = AutomationTask.from_string( + "Search for 'headphones' and filter by price under $50", + "https://amazon.com", + category=TaskCategory.SEARCH, + ) + + print(f"Task ID: {task.task_id}") # Auto-generated UUID + print(f"Task: {task.task}") + print(f"Starting URL: {task.starting_url}") + print(f"Category: {task.category}") + + +async def with_extraction_example(): + """Add extraction to existing task.""" + print("\n=== With Extraction Example ===\n") + + # Create basic task + task = AutomationTask( + task_id="product-search", + starting_url="https://amazon.com", + task="Search for the cheapest laptop", + ) + + # Add extraction specification using fluent API + task_with_extraction = task.with_extraction( + output_schema={"product_name": "str", "price": "float"}, + format="json", + ) + + print(f"Original category: {task.category}") + print(f"After with_extraction: {task_with_extraction.category}") + print(f"Extraction spec: {task_with_extraction.extraction_spec}") + + +async def main(): + """Run all examples.""" + print("=" * 60) + print("AutomationTask Examples") + print("=" * 60) + + # Show task creation patterns (no browser needed) + await from_string_example() + await with_extraction_example() + + # These require browser and API keys + # Uncomment to run: + # await basic_task_example() + # await transaction_task_example() + # await form_fill_task_example() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/planner-executor/captcha_example.py b/examples/planner-executor/captcha_example.py new file mode 100644 index 0000000..16432e6 --- /dev/null +++ b/examples/planner-executor/captcha_example.py @@ -0,0 +1,312 @@ +""" +CAPTCHA Handling Example + +Demonstrates different CAPTCHA solving strategies: +1. Abort policy: Fail immediately on CAPTCHA +2. Human handoff: Wait for manual solve +3. External solver: Integrate with 2Captcha, CapSolver, etc. +4. Custom handler: Implement your own solving logic + +Prerequisites: + pip install predicate-sdk openai + +For external solvers: + pip install 2captcha-python # For 2Captcha + pip install capsolver # For CapSolver +""" + +import asyncio +from typing import Any + +from predicate import AsyncPredicateBrowser +from predicate.agent_runtime import AgentRuntime +from predicate.agents import ( + AutomationTask, + PlannerExecutorAgent, + PlannerExecutorConfig, + TaskCategory, +) +from predicate.agents.browser_agent import CaptchaConfig +from predicate.captcha import CaptchaContext, CaptchaResolution +from predicate.captcha_strategies import ( + ExternalSolver, + HumanHandoffSolver, + VisionSolver, +) +from predicate.llm_provider import OpenAIProvider + + +def create_abort_config() -> PlannerExecutorConfig: + """ + Abort policy: Fail immediately when CAPTCHA is detected. + + This is the default behavior. Use this when: + - You want to fail fast and handle CAPTCHA externally + - Your automation should not encounter CAPTCHAs (e.g., authenticated sessions) + """ + return PlannerExecutorConfig( + captcha=CaptchaConfig( + policy="abort", + min_confidence=0.7, # Confidence threshold for CAPTCHA detection + ), + ) + + +def create_human_handoff_config() -> PlannerExecutorConfig: + """ + Human handoff: Wait for manual CAPTCHA solve. + + Use this when: + - Running with a visible browser window + - A human operator can solve CAPTCHAs manually + - Using live session URLs for remote debugging + """ + return PlannerExecutorConfig( + captcha=CaptchaConfig( + policy="callback", + handler=HumanHandoffSolver( + message="Please solve the CAPTCHA in the browser window", + timeout_ms=180_000, # 3 minute timeout + poll_ms=2_000, # Check every 2 seconds + ), + ), + ) + + +def create_external_solver_config() -> PlannerExecutorConfig: + """ + External solver: Integrate with solving services. + + This example shows integration with 2Captcha. + Similar patterns work for CapSolver, Anti-Captcha, etc. + """ + + def solve_with_2captcha(ctx: CaptchaContext) -> bool: + """ + Solve CAPTCHA using 2Captcha service. + + This is a simplified example. In production: + - Handle different CAPTCHA types (reCAPTCHA, hCaptcha, etc.) + - Properly inject solutions into the page + - Handle errors and timeouts + """ + try: + # Import 2Captcha (pip install 2captcha-python) + # from twocaptcha import TwoCaptcha + # solver = TwoCaptcha('YOUR_API_KEY') + + # Get CAPTCHA type from diagnostics + captcha_type = getattr(ctx.captcha, "type", "unknown") + print(f"CAPTCHA detected: {captcha_type}") + print(f"URL: {ctx.url}") + print(f"Screenshot: {ctx.screenshot_path}") + + # Example for reCAPTCHA v2 + if captcha_type == "recaptcha": + sitekey = getattr(ctx.captcha, "sitekey", None) + if sitekey: + # result = solver.recaptcha(sitekey=sitekey, url=ctx.url) + # solution = result['code'] + + # Inject solution using page_control + if ctx.page_control: + # await ctx.page_control.evaluate_js(f""" + # document.getElementById('g-recaptcha-response').innerHTML = '{solution}'; + # """) + pass + + return True # Signal that solving was attempted + + except Exception as e: + print(f"CAPTCHA solve failed: {e}") + return False + + return PlannerExecutorConfig( + captcha=CaptchaConfig( + policy="callback", + handler=ExternalSolver( + resolver=solve_with_2captcha, + message="Solving CAPTCHA via external service", + timeout_ms=180_000, + poll_ms=5_000, + ), + ), + ) + + +def create_custom_handler_config() -> PlannerExecutorConfig: + """ + Custom handler: Full control over CAPTCHA handling. + + Use this for: + - Complex solving logic + - Integration with internal systems + - Custom retry/escalation strategies + """ + + async def custom_captcha_handler(ctx: CaptchaContext) -> CaptchaResolution: + """ + Custom CAPTCHA handler with full context access. + + CaptchaContext provides: + - ctx.run_id: Current automation run ID + - ctx.step_index: Current step being executed + - ctx.url: Page URL where CAPTCHA appeared + - ctx.source: Where CAPTCHA was detected + - ctx.captcha: CaptchaDiagnostics with type, sitekey, etc. + - ctx.screenshot_path: Path to screenshot + - ctx.frames_dir: Directory with frame images + - ctx.snapshot_path: Path to DOM snapshot + - ctx.live_session_url: URL for live debugging + - ctx.page_control: Hook for JS evaluation + """ + + print(f"[Custom Handler] CAPTCHA at step {ctx.step_index}") + print(f"[Custom Handler] URL: {ctx.url}") + print(f"[Custom Handler] Type: {getattr(ctx.captcha, 'type', 'unknown')}") + + # Example: Check if we have a live session for manual intervention + if ctx.live_session_url: + print(f"[Custom Handler] Live session: {ctx.live_session_url}") + + # Return wait_until_cleared for human intervention + return CaptchaResolution( + action="wait_until_cleared", + message="CAPTCHA detected - please solve manually via live session", + handled_by="human", + timeout_ms=120_000, + poll_ms=3_000, + ) + + # Example: For certain sites, retry with new session + if "problematic-site.com" in ctx.url: + return CaptchaResolution( + action="retry_new_session", + message="Retrying with fresh session", + handled_by="unknown", + ) + + # Default: Abort if we can't handle + return CaptchaResolution( + action="abort", + message="Cannot handle CAPTCHA automatically", + handled_by="unknown", + ) + + return PlannerExecutorConfig( + captcha=CaptchaConfig( + policy="callback", + handler=custom_captcha_handler, + timeout_ms=300_000, # 5 minute overall timeout + min_confidence=0.7, + ), + ) + + +async def run_with_captcha_handling(): + """Example: Run automation with CAPTCHA handling.""" + print("\n=== CAPTCHA Handling Example ===\n") + + planner = OpenAIProvider(model="gpt-4o") + executor = OpenAIProvider(model="gpt-4o-mini") + + # Choose a CAPTCHA handling strategy + # config = create_abort_config() # Fail on CAPTCHA + config = create_human_handoff_config() # Wait for manual solve + # config = create_external_solver_config() # Use 2Captcha + # config = create_custom_handler_config() # Custom logic + + agent = PlannerExecutorAgent( + planner=planner, + executor=executor, + config=config, + ) + + task = AutomationTask( + task_id="captcha-test", + starting_url="https://example.com", # Replace with actual site + task="Complete the signup form", + category=TaskCategory.FORM_FILL, + enable_recovery=True, + ) + + async with AsyncPredicateBrowser() as browser: + page = await browser.new_page() + await page.goto(task.starting_url) + + runtime = AgentRuntime.from_page(page) + result = await agent.run(runtime, task) + + print(f"\nResult: {'Success' if result.success else 'Failed'}") + if result.error: + print(f"Error: {result.error}") + + +async def demonstrate_captcha_configs(): + """Show different CAPTCHA configurations without running.""" + print("=" * 60) + print("CAPTCHA Configuration Examples") + print("=" * 60) + + print("\n1. ABORT Policy (default)") + print("-" * 40) + config = create_abort_config() + print(f" Policy: {config.captcha.policy}") + print(f" Min confidence: {config.captcha.min_confidence}") + + print("\n2. HUMAN HANDOFF") + print("-" * 40) + config = create_human_handoff_config() + print(f" Policy: {config.captcha.policy}") + print(" Handler: HumanHandoffSolver") + print(" - Waits for manual CAPTCHA solve") + print(" - Useful with visible browser or live sessions") + + print("\n3. EXTERNAL SOLVER") + print("-" * 40) + config = create_external_solver_config() + print(f" Policy: {config.captcha.policy}") + print(" Handler: ExternalSolver") + print(" - Integrates with 2Captcha, CapSolver, etc.") + print(" - Requires API key from solving service") + + print("\n4. CUSTOM HANDLER") + print("-" * 40) + config = create_custom_handler_config() + print(f" Policy: {config.captcha.policy}") + print(" Handler: custom async function") + print(" - Full control over solving logic") + print(" - Access to CaptchaContext with all details") + + print("\n" + "=" * 60) + print("CaptchaContext fields:") + print("=" * 60) + print(" - run_id: Current automation run ID") + print(" - step_index: Current step being executed") + print(" - url: Page URL where CAPTCHA appeared") + print(" - source: Where CAPTCHA was detected") + print(" - captcha: CaptchaDiagnostics (type, sitekey, etc.)") + print(" - screenshot_path: Path to screenshot") + print(" - frames_dir: Directory with frame images") + print(" - live_session_url: URL for live debugging") + print(" - page_control: Hook for JS evaluation") + + print("\n" + "=" * 60) + print("CaptchaResolution actions:") + print("=" * 60) + print(" - abort: Stop automation immediately") + print(" - retry_new_session: Clear cookies and retry") + print(" - wait_until_cleared: Poll until CAPTCHA is gone") + + +async def main(): + """Run examples.""" + # Show configuration options (no browser needed) + await demonstrate_captcha_configs() + + # Uncomment to run with actual browser: + # await run_with_captcha_handling() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/planner-executor/custom_config_example.py b/examples/planner-executor/custom_config_example.py index af852b3..11b9f40 100644 --- a/examples/planner-executor/custom_config_example.py +++ b/examples/planner-executor/custom_config_example.py @@ -4,10 +4,12 @@ This example demonstrates various configuration options: - Snapshot escalation (enable/disable, custom step sizes) +- Scroll-after-escalation (find elements outside viewport) - Retry configuration (timeouts, max attempts) - Vision fallback settings - Pre-step verification (skip if predicates pass) - Recovery navigation (track last good URL) +- Modal dismissal (auto-dismiss overlays, custom patterns for i18n) Usage: export OPENAI_API_KEY="sk-..." @@ -22,6 +24,8 @@ from predicate import AsyncPredicateBrowser from predicate.agent_runtime import AgentRuntime from predicate.agents import ( + CheckoutDetectionConfig, + ModalDismissalConfig, PlannerExecutorAgent, PlannerExecutorConfig, RecoveryNavigationConfig, @@ -34,9 +38,10 @@ async def example_default_config() -> None: - """Default configuration: escalation enabled, step=30.""" + """Default configuration: escalation enabled, step=30, scroll enabled.""" print("\n--- Example 1: Default Config ---") print("Escalation: 60 -> 90 -> 120 -> 150 -> 180 -> 200") + print("Scroll-after-escalation: down (x3), up (x3)") config = PlannerExecutorConfig() @@ -44,6 +49,9 @@ async def example_default_config() -> None: print(f" snapshot.limit_base: {config.snapshot.limit_base}") print(f" snapshot.limit_step: {config.snapshot.limit_step}") print(f" snapshot.limit_max: {config.snapshot.limit_max}") + print(f" snapshot.scroll_after_escalation: {config.snapshot.scroll_after_escalation}") + print(f" snapshot.scroll_max_attempts: {config.snapshot.scroll_max_attempts}") + print(f" snapshot.scroll_directions: {config.snapshot.scroll_directions}") async def example_disabled_escalation() -> None: @@ -91,6 +99,43 @@ async def example_custom_limits() -> None: print(f" snapshot.limit_max: {config.snapshot.limit_max}") +async def example_scroll_after_escalation() -> None: + """Scroll-after-escalation configuration.""" + print("\n--- Example 4b: Scroll-after-Escalation ---") + print("After exhausting limit escalation, scroll to find elements outside viewport") + + # Default: scroll down first, then up + config_default = PlannerExecutorConfig() + print(f" Default scroll_directions: {config_default.snapshot.scroll_directions}") + + # Disable scroll-after-escalation (only use limit escalation) + config_no_scroll = PlannerExecutorConfig( + snapshot=SnapshotEscalationConfig( + scroll_after_escalation=False, + ), + ) + print(f" Disabled: scroll_after_escalation={config_no_scroll.snapshot.scroll_after_escalation}") + + # Custom: more scroll attempts, down only (useful for infinite scroll pages) + config_down_only = PlannerExecutorConfig( + snapshot=SnapshotEscalationConfig( + scroll_after_escalation=True, + scroll_max_attempts=5, # More scrolls + scroll_directions=("down",), # Only scroll down + ), + ) + print(f" Down-only: scroll_directions={config_down_only.snapshot.scroll_directions}") + print(f" Down-only: scroll_max_attempts={config_down_only.snapshot.scroll_max_attempts}") + + # Custom: try up first (e.g., for elements at top of page) + config_up_first = PlannerExecutorConfig( + snapshot=SnapshotEscalationConfig( + scroll_directions=("up", "down"), # Try up before down + ), + ) + print(f" Up-first: scroll_directions={config_up_first.snapshot.scroll_directions}") + + async def example_retry_config() -> None: """Custom retry configuration.""" print("\n--- Example 5: Retry Config ---") @@ -163,17 +208,169 @@ async def example_recovery_navigation() -> None: print(" Tracks last_known_good_url for recovery when agent gets off-track") +async def example_modal_dismissal() -> None: + """Modal dismissal configuration for auto-dismissing overlays.""" + print("\n--- Example 9: Modal Dismissal ---") + print("Auto-dismiss blocking modals, drawers, and popups") + + # Default: enabled with common English patterns + config_default = PlannerExecutorConfig() + print(f" Default enabled: {config_default.modal.enabled}") + print(f" Default patterns: {config_default.modal.dismiss_patterns[:5]}...") + + # Disable modal dismissal + config_disabled = PlannerExecutorConfig( + modal=ModalDismissalConfig(enabled=False), + ) + print(f" Disabled: modal.enabled={config_disabled.modal.enabled}") + + # Custom patterns for German sites + config_german = PlannerExecutorConfig( + modal=ModalDismissalConfig( + dismiss_patterns=( + "nein danke", # No thanks + "nicht jetzt", # Not now + "abbrechen", # Cancel + "schließen", # Close + "überspringen", # Skip + "später", # Later + "ablehnen", # Decline + "weiter", # Continue + ), + dismiss_icons=("x", "×", "✕"), # Icons are universal + ), + ) + print(f" German patterns: {config_german.modal.dismiss_patterns[:4]}...") + + # Custom patterns for Spanish sites + config_spanish = PlannerExecutorConfig( + modal=ModalDismissalConfig( + dismiss_patterns=( + "no gracias", # No thanks + "ahora no", # Not now + "cancelar", # Cancel + "cerrar", # Close + "omitir", # Skip + "más tarde", # Later + "rechazar", # Reject + "continuar", # Continue + ), + ), + ) + print(f" Spanish patterns: {config_spanish.modal.dismiss_patterns[:4]}...") + + # Custom patterns for French sites + config_french = PlannerExecutorConfig( + modal=ModalDismissalConfig( + dismiss_patterns=( + "non merci", # No thanks + "pas maintenant", # Not now + "annuler", # Cancel + "fermer", # Close + "passer", # Skip + "plus tard", # Later + "refuser", # Refuse + "continuer", # Continue + ), + ), + ) + print(f" French patterns: {config_french.modal.dismiss_patterns[:4]}...") + + # Custom patterns for Japanese sites + config_japanese = PlannerExecutorConfig( + modal=ModalDismissalConfig( + dismiss_patterns=( + "いいえ", # No + "後で", # Later + "閉じる", # Close + "キャンセル", # Cancel + "スキップ", # Skip + "続ける", # Continue + "結構です", # No thank you + ), + ), + ) + print(f" Japanese patterns: {config_japanese.modal.dismiss_patterns[:4]}...") + + # Combined multilingual config + config_multilingual = PlannerExecutorConfig( + modal=ModalDismissalConfig( + dismiss_patterns=( + # English + "no thanks", "not now", "close", "skip", "cancel", + # German + "nein danke", "schließen", "abbrechen", + # Spanish + "no gracias", "cerrar", "cancelar", + # French + "non merci", "fermer", "annuler", + ), + ), + ) + print(f" Multilingual: {len(config_multilingual.modal.dismiss_patterns)} patterns") + + +async def example_checkout_detection() -> None: + """Checkout page detection configuration.""" + print("\n--- Example 10: Checkout Detection ---") + print("Auto-detect checkout pages and trigger continuation replanning") + + # Default: enabled with common checkout patterns + config_default = PlannerExecutorConfig() + print(f" Default enabled: {config_default.checkout.enabled}") + print(f" Default URL patterns: {config_default.checkout.url_patterns[:5]}...") + + # Disable checkout detection + config_disabled = PlannerExecutorConfig( + checkout=CheckoutDetectionConfig(enabled=False), + ) + print(f" Disabled: checkout.enabled={config_disabled.checkout.enabled}") + + # Custom patterns for German e-commerce sites + config_german = PlannerExecutorConfig( + checkout=CheckoutDetectionConfig( + url_patterns=( + "/warenkorb", # Cart + "/kasse", # Checkout + "/zahlung", # Payment + "/bestellung", # Order + "/anmelden", # Sign-in + ), + element_patterns=( + "zur kasse", # To checkout + "warenkorb", # Shopping cart + "jetzt kaufen", # Buy now + "anmelden", # Sign in + ), + ), + ) + print(f" German URL patterns: {config_german.checkout.url_patterns[:3]}...") + + # Disable replan trigger (just detect, don't act) + config_detect_only = PlannerExecutorConfig( + checkout=CheckoutDetectionConfig( + enabled=True, + trigger_replan=False, # Only detect, don't trigger continuation + ), + ) + print(f" Detect-only: trigger_replan={config_detect_only.checkout.trigger_replan}") + + async def example_full_custom() -> None: """Full custom configuration with all options.""" - print("\n--- Example 9: Full Custom Config ---") + print("\n--- Example 11: Full Custom Config ---") config = PlannerExecutorConfig( - # Snapshot escalation + # Snapshot escalation with scroll-after-escalation snapshot=SnapshotEscalationConfig( enabled=True, limit_base=80, limit_step=40, limit_max=240, + # Scroll to find elements outside viewport + scroll_after_escalation=True, + scroll_max_attempts=3, + scroll_directions=("down", "up"), ), # Retry settings retry=RetryConfig( @@ -192,6 +389,16 @@ async def example_full_custom() -> None: enabled=True, max_recovery_attempts=2, ), + # Modal dismissal (auto-dismiss blocking overlays) + modal=ModalDismissalConfig( + enabled=True, + max_attempts=2, + ), + # Checkout detection (continue workflow on checkout pages) + checkout=CheckoutDetectionConfig( + enabled=True, + trigger_replan=True, + ), # Pre-step verification pre_step_verification=True, # Planner settings @@ -208,15 +415,18 @@ async def example_full_custom() -> None: print(" Full config created successfully!") print(f" Escalation: {config.snapshot.limit_base} -> ... -> {config.snapshot.limit_max}") + print(f" Scroll-after-escalation: {config.snapshot.scroll_after_escalation}") print(f" Max replans: {config.retry.max_replans}") print(f" Vision enabled: {config.vision.enabled}") print(f" Pre-step verification: {config.pre_step_verification}") print(f" Recovery enabled: {config.recovery.enabled}") + print(f" Modal dismissal: {config.modal.enabled}") + print(f" Checkout detection: {config.checkout.enabled}") async def example_run_with_config() -> None: """Run agent with custom config.""" - print("\n--- Example 10: Run Agent with Custom Config ---") + print("\n--- Example 12: Run Agent with Custom Config ---") openai_key = os.getenv("OPENAI_API_KEY") if not openai_key: @@ -275,10 +485,13 @@ async def main() -> None: await example_disabled_escalation() await example_custom_step_size() await example_custom_limits() + await example_scroll_after_escalation() await example_retry_config() await example_vision_fallback() await example_pre_step_verification() await example_recovery_navigation() + await example_modal_dismissal() + await example_checkout_detection() await example_full_custom() await example_run_with_config() diff --git a/predicate/agent_runtime.py b/predicate/agent_runtime.py index d4ad2d4..8f82362 100644 --- a/predicate/agent_runtime.py +++ b/predicate/agent_runtime.py @@ -326,6 +326,152 @@ async def get_url(self) -> str: self._cached_url = url return url + async def get_viewport_height(self) -> int: + """ + Get current viewport height in pixels. + + Returns: + Viewport height in pixels, or 800 as fallback if unavailable + """ + try: + # Try refresh_page_info first (PlaywrightBackend) + refresh_fn = getattr(self.backend, "refresh_page_info", None) + if callable(refresh_fn): + info = await refresh_fn() + height = getattr(info, "height", None) + if height and height > 0: + return int(height) + + # Try evaluating JavaScript directly + eval_fn = getattr(self.backend, "eval", None) + if callable(eval_fn): + height = await eval_fn("window.innerHeight") + if height and height > 0: + return int(height) + except Exception: + pass + + # Fallback to reasonable default + return 800 + + # ------------------------------------------------------------------------- + # Action methods for PlannerExecutorAgent compatibility + # ------------------------------------------------------------------------- + + async def click(self, element_id: int) -> None: + """ + Click an element by its snapshot ID. + + Args: + element_id: Element ID from snapshot + """ + from .actions import click_async + + # Get element bounds from last snapshot + if self.last_snapshot is None: + raise RuntimeError("No snapshot available. Call snapshot() first.") + + element = None + for el in self.last_snapshot.elements or []: + if getattr(el, "id", None) == element_id: + element = el + break + + if element is None: + raise ValueError(f"Element {element_id} not found in snapshot") + + # Use the backend page for clicking + page = getattr(self.backend, "page", None) or getattr(self.backend, "_page", None) + if page is None: + raise RuntimeError("No page available in backend") + + # Get element center coordinates (Element model uses 'bbox' not 'bounds') + bbox = getattr(element, "bbox", None) + if bbox: + x = bbox.x + bbox.width / 2 + y = bbox.y + bbox.height / 2 + await self.backend.mouse_click(x=x, y=y, button="left", click_count=1) + else: + # Fall back to evaluating click via page + # Note: click_async expects AsyncSentienceBrowser, not page + await page.evaluate(f"window.sentience?.clickElement({element_id})") + + await self.record_action(f"CLICK({element_id})") + + async def type(self, element_id: int, text: str) -> None: + """ + Type text into an element. + + Args: + element_id: Element ID from snapshot + text: Text to type + """ + # First click to focus + await self.click(element_id) + + # Then type + await self.backend.type_text(text) + await self.record_action(f"TYPE({element_id}, '{text[:20]}...')" if len(text) > 20 else f"TYPE({element_id}, '{text}')") + + async def press(self, key: str) -> None: + """ + Press a keyboard key. + + Args: + key: Key to press (e.g., "Enter", "Tab", "Escape") + """ + page = getattr(self.backend, "page", None) or getattr(self.backend, "_page", None) + if page is None: + raise RuntimeError("No page available in backend") + + await page.keyboard.press(key) + await self.record_action(f"PRESS({key})") + + async def goto(self, url: str) -> None: + """ + Navigate to a URL. + + Args: + url: URL to navigate to + """ + page = getattr(self.backend, "page", None) or getattr(self.backend, "_page", None) + if page is None: + raise RuntimeError("No page available in backend") + + await page.goto(url) + await page.wait_for_load_state("domcontentloaded") + self._cached_url = url + await self.record_action(f"NAVIGATE({url})") + + async def scroll(self, direction: str = "down", amount: int = 500) -> None: + """ + Scroll the page. + + Args: + direction: "up" or "down" + amount: Pixels to scroll + """ + dy = amount if direction == "down" else -amount + await self.backend.wheel(delta_y=float(dy)) + await self.record_action(f"SCROLL({direction})") + + async def stabilize(self, timeout_s: float = 5.0, poll_s: float = 0.5) -> None: + """ + Wait for page to stabilize (network idle, no pending animations). + + Args: + timeout_s: Maximum wait time + poll_s: Poll interval + """ + page = getattr(self.backend, "page", None) or getattr(self.backend, "_page", None) + if page is None: + return + + try: + await page.wait_for_load_state("networkidle", timeout=int(timeout_s * 1000)) + except Exception: + pass # Best effort + async def snapshot(self, emit_trace: bool = True, **kwargs: Any) -> Snapshot: """ Take a snapshot of the current page state. diff --git a/predicate/agents/__init__.py b/predicate/agents/__init__.py index f3d8e5e..bdfad68 100644 --- a/predicate/agents/__init__.py +++ b/predicate/agents/__init__.py @@ -8,8 +8,26 @@ Agent types: - PredicateBrowserAgent: Single-executor agent with manual step definitions - PlannerExecutorAgent: Two-tier agent with LLM-generated plans + +Task abstractions: +- AutomationTask: Generic task model for browser automation +- TaskCategory: Task category hints for heuristics selection + +Heuristics: +- HeuristicHint: Planner-generated hints for element selection +- ComposableHeuristics: Dynamic heuristics composition + +Recovery: +- RecoveryState: Checkpoint tracking for rollback recovery +- RecoveryCheckpoint: Individual recovery checkpoint """ +from .automation_task import ( + AutomationTask, + ExtractionSpec, + SuccessCriteria, + TaskCategory, +) from .browser_agent import ( CaptchaConfig, PermissionRecoveryConfig, @@ -17,9 +35,14 @@ PredicateBrowserAgentConfig, VisionFallbackConfig, ) +from .composable_heuristics import ComposableHeuristics +from .heuristic_spec import COMMON_HINTS, HeuristicHint, get_common_hint from .planner_executor_agent import ( + AuthBoundaryConfig, + CheckoutDetectionConfig, ExecutorOverride, IntentHeuristics, + ModalDismissalConfig, Plan, PlanStep, PlannerExecutorAgent, @@ -35,17 +58,31 @@ normalize_plan, validate_plan_smoothness, ) +from .recovery import RecoveryCheckpoint, RecoveryState __all__ = [ + # Automation Task + "AutomationTask", + "ExtractionSpec", + "SuccessCriteria", + "TaskCategory", # Browser Agent "CaptchaConfig", "PermissionRecoveryConfig", "PredicateBrowserAgent", "PredicateBrowserAgentConfig", "VisionFallbackConfig", + # Heuristics + "COMMON_HINTS", + "ComposableHeuristics", + "HeuristicHint", + "get_common_hint", # Planner + Executor Agent + "AuthBoundaryConfig", + "CheckoutDetectionConfig", "ExecutorOverride", "IntentHeuristics", + "ModalDismissalConfig", "Plan", "PlanStep", "PlannerExecutorAgent", @@ -60,5 +97,7 @@ "StepStatus", "normalize_plan", "validate_plan_smoothness", + # Recovery + "RecoveryCheckpoint", + "RecoveryState", ] - diff --git a/predicate/agents/automation_task.py b/predicate/agents/automation_task.py new file mode 100644 index 0000000..a0d0497 --- /dev/null +++ b/predicate/agents/automation_task.py @@ -0,0 +1,336 @@ +""" +AutomationTask: Generic task model for browser automation. + +This module provides a task abstraction that generalizes WebBenchTask to support +broad web automation use cases like "buy a laptop on xyz.com". + +Key features: +- Natural language task description with optional structured goal +- Task category hints for heuristics selection +- Budget constraints (timeout, max_steps, max_replans) +- Extraction specification for data extraction tasks +- Human-defined success criteria (optional override) +- Recovery configuration for rollback on failure +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Literal + +from pydantic import BaseModel, Field + + +class TaskCategory(str, Enum): + """ + Task category for heuristics and verification selection. + + Categories help the planner and executor make better decisions about: + - Element selection strategies + - Verification stringency + - Recovery behavior + """ + + NAVIGATION = "navigation" # Navigate to a destination + SEARCH = "search" # Search and find information + FORM_FILL = "form_fill" # Fill out forms + EXTRACTION = "extraction" # Extract data from pages + TRANSACTION = "transaction" # Purchase, submit, create actions + VERIFICATION = "verification" # Verify state/information exists + + +class ExtractionSpec(BaseModel): + """ + Specification for data extraction tasks. + + Used when the task involves extracting structured data from pages, + such as product information, search results, or table data. + """ + + output_schema: dict[str, Any] | None = Field( + default=None, + description="JSON schema for expected extraction output", + ) + target_selectors: list[str] = Field( + default_factory=list, + description="Suggested selectors for extraction targets", + ) + format: Literal["json", "text", "markdown", "table"] = Field( + default="json", + description="Output format for extracted data", + ) + require_evidence: bool = Field( + default=True, + description="Whether to require grounding evidence for extractions", + ) + + class Config: + extra = "allow" + + +class SuccessCriteria(BaseModel): + """ + Human-defined success criteria for task completion. + + If provided, these criteria override planner-proposed verification. + This allows users to define exactly what "success" means for their task. + + Example: + SuccessCriteria( + predicates=[ + {"predicate": "url_contains", "args": ["/confirmation"]}, + {"predicate": "exists", "args": ["[data-testid='order-number']"]}, + ], + require_all=True, + ) + """ + + predicates: list[dict[str, Any]] = Field( + default_factory=list, + description="PredicateSpec definitions for success verification", + ) + require_all: bool = Field( + default=True, + description="If True, all predicates must pass. If False, any passing is success.", + ) + + class Config: + extra = "allow" + + +@dataclass(frozen=True) +class AutomationTask: + """ + Generic automation task for the PlannerExecutorAgent. + + This replaces WebBenchTask with a more flexible structure that supports: + - Broad tasks like "buy a laptop on xyz.com" + - Optional structured goals and extraction specs + - Timeout and budget constraints + - Category hints for heuristics selection + - Recovery configuration + + Example: + task = AutomationTask( + task_id="purchase-laptop-001", + starting_url="https://amazon.com", + task="Find a laptop under $1000 with good reviews and add to cart", + category=TaskCategory.TRANSACTION, + timeout_s=300.0, + max_steps=50, + ) + + Example with extraction: + task = AutomationTask( + task_id="extract-product-info", + starting_url="https://amazon.com/dp/B0...", + task="Extract the product name, price, and rating", + category=TaskCategory.EXTRACTION, + extraction_spec=ExtractionSpec( + schema={"name": "str", "price": "float", "rating": "float"}, + format="json", + ), + ) + + Example with human-defined success criteria: + task = AutomationTask( + task_id="checkout-flow", + starting_url="https://shop.com/cart", + task="Complete the checkout process", + category=TaskCategory.TRANSACTION, + success_criteria=SuccessCriteria( + predicates=[ + {"predicate": "url_contains", "args": ["/confirmation"]}, + {"predicate": "exists", "args": [".order-number"]}, + ], + require_all=True, + ), + ) + """ + + # Required fields + task_id: str + starting_url: str + task: str # Natural language task description + + # Optional: Structured goal for more precise planning + goal: dict[str, Any] | None = None + + # Optional: Category hint for heuristics/verification selection + category: TaskCategory | None = None + + # Budget constraints + timeout_s: float | None = None + max_steps: int = 50 + max_replans: int = 2 + max_vision_calls: int = 3 + + # Extraction specification (for data extraction tasks) + extraction_spec: ExtractionSpec | None = None + + # Human-defined success criteria (optional override) + success_criteria: SuccessCriteria | None = None + + # Recovery configuration + enable_recovery: bool = True + max_recovery_attempts: int = 2 + + # Domain hints for heuristics (e.g., ["ecommerce", "amazon"]) + domain_hints: tuple[str, ...] = field(default_factory=tuple) + + @classmethod + def from_webbench_task(cls, task: Any) -> "AutomationTask": + """ + Factory method to convert WebBenchTask to AutomationTask. + + Preserves backward compatibility with webbench. + + Args: + task: WebBenchTask instance with id, starting_url, task, category + + Returns: + AutomationTask instance + + Example: + from webbench.models import WebBenchTask + + wb_task = WebBenchTask( + id="task-001", + starting_url="https://example.com", + task="Click the login button", + category="CREATE", + ) + automation_task = AutomationTask.from_webbench_task(wb_task) + """ + # Map WebBench categories to TaskCategory + category_map = { + "READ": TaskCategory.EXTRACTION, + "CREATE": TaskCategory.TRANSACTION, + "UPDATE": TaskCategory.FORM_FILL, + "DELETE": TaskCategory.TRANSACTION, + "FILE_MANIPULATION": TaskCategory.TRANSACTION, + } + wb_category = getattr(task, "category", None) + category = category_map.get(wb_category) if wb_category else None + + # For READ tasks, create a basic extraction spec + extraction_spec = None + if wb_category == "READ": + extraction_spec = ExtractionSpec( + format="json", + require_evidence=True, + ) + + return cls( + task_id=task.id, + starting_url=task.starting_url, + task=task.task, + category=category, + extraction_spec=extraction_spec, + ) + + @classmethod + def from_string( + cls, + task: str, + starting_url: str, + *, + task_id: str | None = None, + category: TaskCategory | None = None, + ) -> "AutomationTask": + """ + Create an AutomationTask from a simple string description. + + Args: + task: Natural language task description + starting_url: URL to start automation from + task_id: Optional task ID (auto-generated if not provided) + category: Optional task category hint + + Returns: + AutomationTask instance + + Example: + task = AutomationTask.from_string( + "Search for 'laptop' and add the first result to cart", + "https://amazon.com", + ) + """ + import uuid + + return cls( + task_id=task_id or str(uuid.uuid4()), + starting_url=starting_url, + task=task, + category=category, + ) + + def with_success_criteria(self, *predicates: dict[str, Any], require_all: bool = True) -> "AutomationTask": + """ + Return a new AutomationTask with the specified success criteria. + + Args: + *predicates: PredicateSpec dictionaries + require_all: If True, all predicates must pass + + Returns: + New AutomationTask with success_criteria set + + Example: + task = task.with_success_criteria( + {"predicate": "url_contains", "args": ["/success"]}, + {"predicate": "exists", "args": [".confirmation"]}, + ) + """ + return AutomationTask( + task_id=self.task_id, + starting_url=self.starting_url, + task=self.task, + goal=self.goal, + category=self.category, + timeout_s=self.timeout_s, + max_steps=self.max_steps, + max_replans=self.max_replans, + max_vision_calls=self.max_vision_calls, + extraction_spec=self.extraction_spec, + success_criteria=SuccessCriteria( + predicates=list(predicates), + require_all=require_all, + ), + enable_recovery=self.enable_recovery, + max_recovery_attempts=self.max_recovery_attempts, + domain_hints=self.domain_hints, + ) + + def with_extraction( + self, + output_schema: dict[str, Any] | None = None, + format: Literal["json", "text", "markdown", "table"] = "json", + ) -> "AutomationTask": + """ + Return a new AutomationTask with extraction specification. + + Args: + output_schema: JSON schema for expected output + format: Output format + + Returns: + New AutomationTask with extraction_spec set + """ + return AutomationTask( + task_id=self.task_id, + starting_url=self.starting_url, + task=self.task, + goal=self.goal, + category=self.category or TaskCategory.EXTRACTION, + timeout_s=self.timeout_s, + max_steps=self.max_steps, + max_replans=self.max_replans, + max_vision_calls=self.max_vision_calls, + extraction_spec=ExtractionSpec(output_schema=output_schema, format=format), + success_criteria=self.success_criteria, + enable_recovery=self.enable_recovery, + max_recovery_attempts=self.max_recovery_attempts, + domain_hints=self.domain_hints, + ) diff --git a/predicate/agents/composable_heuristics.py b/predicate/agents/composable_heuristics.py new file mode 100644 index 0000000..819c9f6 --- /dev/null +++ b/predicate/agents/composable_heuristics.py @@ -0,0 +1,322 @@ +""" +ComposableHeuristics: Dynamic heuristics composition for element selection. + +This module provides a heuristics implementation that composes from multiple sources: +1. Planner-provided HeuristicHints (per step, highest priority) +2. Static IntentHeuristics (user-injected at agent construction) +3. TaskCategory defaults (lowest priority) + +This allows the planner to dynamically guide element selection without +requiring changes to user-provided heuristics. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable + +from .automation_task import TaskCategory +from .heuristic_spec import COMMON_HINTS, HeuristicHint + +if TYPE_CHECKING: + pass + + +@runtime_checkable +class IntentHeuristics(Protocol): + """ + Protocol for pluggable domain-specific element selection heuristics. + + This protocol is duplicated here to avoid circular imports with + planner_executor_agent.py. The actual protocol is defined there. + """ + + def find_element_for_intent( + self, + intent: str, + elements: list[Any], + url: str, + goal: str, + ) -> int | None: + """Find element ID for a given intent.""" + ... + + def priority_order(self) -> list[str]: + """Return list of intent patterns in priority order.""" + ... + + +class ComposableHeuristics: + """ + Heuristics implementation that composes from multiple sources. + + Priority order (highest to lowest): + 1. Planner-provided HeuristicHints for current step + 2. Common hints for well-known patterns (add_to_cart, checkout, etc.) + 3. Static IntentHeuristics (user-injected) + 4. TaskCategory defaults + + Example: + heuristics = ComposableHeuristics( + static_heuristics=my_ecommerce_heuristics, + task_category=TaskCategory.TRANSACTION, + ) + + # Before executing each step, set the step's hints + heuristics.set_step_hints(step.heuristic_hints) + + # Find element for intent + element_id = heuristics.find_element_for_intent( + intent="add_to_cart", + elements=snapshot.elements, + url=snapshot.url, + goal="Add laptop to cart", + ) + """ + + def __init__( + self, + *, + static_heuristics: IntentHeuristics | None = None, + task_category: TaskCategory | None = None, + use_common_hints: bool = True, + ) -> None: + """ + Initialize ComposableHeuristics. + + Args: + static_heuristics: User-provided IntentHeuristics (optional) + task_category: Task category for default heuristics (optional) + use_common_hints: Whether to use COMMON_HINTS as fallback + """ + self._static = static_heuristics + self._category = task_category + self._use_common_hints = use_common_hints + self._current_hints: list[HeuristicHint] = [] + + def set_step_hints(self, hints: list[HeuristicHint] | list[dict] | None) -> None: + """ + Set hints for the current step. + + Called before each step execution with hints from the plan. + + Args: + hints: List of HeuristicHint objects or dicts + """ + if not hints: + self._current_hints = [] + return + + # Convert dicts to HeuristicHint if needed + parsed_hints: list[HeuristicHint] = [] + for h in hints: + if isinstance(h, HeuristicHint): + parsed_hints.append(h) + elif isinstance(h, dict): + try: + parsed_hints.append(HeuristicHint(**h)) + except Exception: + # Skip invalid hints + pass + + # Sort by priority (highest first) + self._current_hints = sorted(parsed_hints, key=lambda h: -h.priority) + + def clear_step_hints(self) -> None: + """Clear hints for the current step.""" + self._current_hints = [] + + def find_element_for_intent( + self, + intent: str, + elements: list[Any], + url: str, + goal: str, + ) -> int | None: + """ + Find element ID for a given intent using composed heuristics. + + Tries sources in priority order: + 1. Planner-provided hints for current step + 2. Common hints for well-known patterns + 3. Static heuristics (user-provided) + 4. TaskCategory defaults + + Args: + intent: The intent hint from the plan step + elements: List of snapshot elements + url: Current page URL + goal: Human-readable goal for context + + Returns: + Element ID if found, None to fall back to LLM executor + """ + if not intent or not elements: + return None + + # 1. Try planner-provided hints first + for hint in self._current_hints: + if hint.matches_intent(intent): + element_id = self._match_hint(hint, elements) + if element_id is not None: + return element_id + + # 2. Try common hints for well-known patterns + if self._use_common_hints: + common_hint = self._get_common_hint_for_intent(intent) + if common_hint: + element_id = self._match_hint(common_hint, elements) + if element_id is not None: + return element_id + + # 3. Try static heuristics + if self._static is not None: + try: + element_id = self._static.find_element_for_intent( + intent, elements, url, goal + ) + if element_id is not None: + return element_id + except Exception: + # Don't let static heuristics crash the flow + pass + + # 4. Try category-based defaults + return self._category_default_match(intent, elements) + + def _match_hint(self, hint: HeuristicHint, elements: list[Any]) -> int | None: + """ + Match elements against a hint's criteria. + + Args: + hint: HeuristicHint to match against + elements: List of snapshot elements + + Returns: + Element ID if match found, None otherwise + """ + for el in elements: + if hint.matches_element(el): + element_id = getattr(el, "id", None) + if element_id is not None: + return element_id + return None + + def _get_common_hint_for_intent(self, intent: str) -> HeuristicHint | None: + """Get common hint for well-known intents.""" + intent_normalized = intent.lower().replace(" ", "_").replace("-", "_") + + # Direct match + if intent_normalized in COMMON_HINTS: + return COMMON_HINTS[intent_normalized] + + # Partial match + for key, hint in COMMON_HINTS.items(): + if key in intent_normalized or intent_normalized in key: + return hint + + return None + + def _category_default_match( + self, intent: str, elements: list[Any] + ) -> int | None: + """ + Apply category-based default matching. + + Uses TaskCategory to apply sensible defaults for common patterns. + + Args: + intent: The intent string + elements: List of snapshot elements + + Returns: + Element ID if match found, None otherwise + """ + if not self._category: + return None + + intent_lower = intent.lower() + + if self._category == TaskCategory.TRANSACTION: + # Transaction patterns: add to cart, checkout, buy, submit + transaction_keywords = [ + "add to cart", + "add to bag", + "buy now", + "checkout", + "proceed", + "submit", + "confirm", + "place order", + ] + for el in elements: + text = (getattr(el, "text", "") or "").lower() + role = getattr(el, "role", "") or "" + if role in ("button", "link"): + for keyword in transaction_keywords: + if keyword in text: + return getattr(el, "id", None) + + elif self._category == TaskCategory.FORM_FILL: + # Form patterns: submit, next, continue + form_keywords = ["submit", "next", "continue", "save", "update"] + for el in elements: + text = (getattr(el, "text", "") or "").lower() + role = getattr(el, "role", "") or "" + if role == "button": + for keyword in form_keywords: + if keyword in text: + return getattr(el, "id", None) + + elif self._category == TaskCategory.SEARCH: + # Search patterns: search button, go + if "search" in intent_lower: + for el in elements: + text = (getattr(el, "text", "") or "").lower() + role = getattr(el, "role", "") or "" + if role in ("button", "textbox") and "search" in text: + return getattr(el, "id", None) + + elif self._category == TaskCategory.NAVIGATION: + # Navigation: links matching intent + for el in elements: + text = (getattr(el, "text", "") or "").lower() + role = getattr(el, "role", "") or "" + if role == "link" and intent_lower in text: + return getattr(el, "id", None) + + return None + + def priority_order(self) -> list[str]: + """ + Return list of intent patterns in priority order. + + Combines patterns from all sources. + + Returns: + List of intent pattern strings + """ + patterns: list[str] = [] + + # Add current step hints + patterns.extend(h.intent_pattern for h in self._current_hints) + + # Add common hints + if self._use_common_hints: + patterns.extend(COMMON_HINTS.keys()) + + # Add static heuristics patterns + if self._static is not None: + try: + patterns.extend(self._static.priority_order()) + except Exception: + pass + + # Deduplicate while preserving order + seen = set() + result = [] + for p in patterns: + if p not in seen: + seen.add(p) + result.append(p) + + return result diff --git a/predicate/agents/heuristic_spec.py b/predicate/agents/heuristic_spec.py new file mode 100644 index 0000000..662737f --- /dev/null +++ b/predicate/agents/heuristic_spec.py @@ -0,0 +1,186 @@ +""" +HeuristicSpec: Planner-generated hints for element selection. + +This module provides models for dynamic heuristics composition. The planner +can generate HeuristicHint objects alongside execution plans, allowing +element selection without requiring an LLM call. + +Key concepts: +- HeuristicHint: A single hint with intent pattern, text patterns, and role filters +- Hints are generated per-step by the planner +- ComposableHeuristics (separate module) uses these hints at runtime +""" + +from __future__ import annotations + +from pydantic import BaseModel, Field + + +class HeuristicHint(BaseModel): + """ + Planner-generated hint for element selection. + + The planner can propose these hints alongside plan steps to guide + element selection without requiring an LLM executor call. + + Attributes: + intent_pattern: Pattern to match against step intent (e.g., "add_to_cart") + text_patterns: Text patterns to search in element text (case-insensitive) + role_filter: Allowed element roles (e.g., ["button", "link"]) + priority: Priority order (higher = try first) + attribute_patterns: Optional attribute patterns to match (e.g., {"data-action": "add-to-cart"}) + + Example: + HeuristicHint( + intent_pattern="add_to_cart", + text_patterns=["add to cart", "add to bag", "buy now"], + role_filter=["button"], + priority=10, + ) + + Example in plan JSON: + { + "id": 3, + "goal": "Add item to cart", + "action": "CLICK", + "intent": "add_to_cart", + "heuristic_hints": [ + { + "intent_pattern": "add_to_cart", + "text_patterns": ["add to cart", "add to bag"], + "role_filter": ["button"], + "priority": 10 + } + ], + "verify": [{"predicate": "url_contains", "args": ["/cart"]}] + } + """ + + intent_pattern: str = Field( + ..., + description="Intent pattern to match (e.g., 'add_to_cart', 'checkout', 'login')", + ) + text_patterns: list[str] = Field( + default_factory=list, + description="Text patterns to search in elements (case-insensitive)", + ) + role_filter: list[str] = Field( + default_factory=list, + description="Allowed element roles (e.g., ['button', 'link'])", + ) + priority: int = Field( + default=0, + description="Priority order (higher = try first)", + ) + attribute_patterns: dict[str, str] = Field( + default_factory=dict, + description="Attribute patterns to match (e.g., {'data-action': 'add-to-cart'})", + ) + + class Config: + extra = "allow" + + def matches_intent(self, intent: str) -> bool: + """ + Check if this hint matches the given intent. + + Args: + intent: The intent string from the plan step + + Returns: + True if the hint's intent_pattern is found in the intent + """ + if not intent: + return False + return self.intent_pattern.lower() in intent.lower() + + def matches_element(self, element: object) -> bool: + """ + Check if an element matches this hint's criteria. + + Args: + element: Snapshot element with text, role, and attributes + + Returns: + True if the element matches all criteria + """ + # Check role filter + role = getattr(element, "role", "") or "" + if self.role_filter and role.lower() not in [r.lower() for r in self.role_filter]: + return False + + # Check text patterns + text = (getattr(element, "text", "") or "").lower() + if self.text_patterns: + if not any(pattern.lower() in text for pattern in self.text_patterns): + return False + + # Check attribute patterns + if self.attribute_patterns: + attributes = getattr(element, "attributes", {}) or {} + for attr_name, attr_pattern in self.attribute_patterns.items(): + attr_value = attributes.get(attr_name, "") + if attr_pattern.lower() not in (attr_value or "").lower(): + return False + + return True + + +# Common heuristic hints for well-known patterns +COMMON_HINTS = { + "add_to_cart": HeuristicHint( + intent_pattern="add_to_cart", + text_patterns=["add to cart", "add to bag", "add to basket", "buy now"], + role_filter=["button"], + priority=10, + ), + "checkout": HeuristicHint( + intent_pattern="checkout", + text_patterns=["checkout", "proceed to checkout", "go to checkout"], + role_filter=["button", "link"], + priority=10, + ), + "login": HeuristicHint( + intent_pattern="login", + text_patterns=["log in", "login", "sign in", "signin"], + role_filter=["button", "link"], + priority=10, + ), + "submit": HeuristicHint( + intent_pattern="submit", + text_patterns=["submit", "send", "continue", "next", "confirm"], + role_filter=["button"], + priority=5, + ), + "search": HeuristicHint( + intent_pattern="search", + text_patterns=["search", "find", "go"], + role_filter=["button", "textbox"], + priority=5, + ), + "close": HeuristicHint( + intent_pattern="close", + text_patterns=["close", "dismiss", "x", "cancel"], + role_filter=["button"], + priority=3, + ), + "accept_cookies": HeuristicHint( + intent_pattern="accept_cookies", + text_patterns=["accept", "accept all", "allow", "agree", "ok", "got it"], + role_filter=["button"], + priority=8, + ), +} + + +def get_common_hint(intent: str) -> HeuristicHint | None: + """ + Get a common heuristic hint for well-known intents. + + Args: + intent: Intent string (e.g., "add_to_cart", "checkout") + + Returns: + HeuristicHint if a common hint exists, None otherwise + """ + return COMMON_HINTS.get(intent.lower().replace(" ", "_").replace("-", "_")) diff --git a/predicate/agents/planner_executor_agent.py b/predicate/agents/planner_executor_agent.py index 70e6ee2..13774ae 100644 --- a/predicate/agents/planner_executor_agent.py +++ b/predicate/agents/planner_executor_agent.py @@ -14,6 +14,7 @@ from __future__ import annotations +import asyncio import base64 import hashlib import json @@ -45,7 +46,11 @@ url_matches, ) +from .automation_task import AutomationTask, ExtractionSpec, SuccessCriteria, TaskCategory from .browser_agent import CaptchaConfig, VisionFallbackConfig +from .composable_heuristics import ComposableHeuristics +from .heuristic_spec import HeuristicHint +from .recovery import RecoveryCheckpoint, RecoveryState # --------------------------------------------------------------------------- @@ -210,12 +215,23 @@ class SnapshotEscalationConfig: # Larger initial limit, smaller steps config = SnapshotEscalationConfig(limit_base=100, limit_step=20, limit_max=180) + + # Enable scroll-after-escalation to find elements below/above viewport + config = SnapshotEscalationConfig(scroll_after_escalation=True, scroll_directions=("down", "up")) + + # Custom scroll amount as fraction of viewport height (default: 0.4 = 40%) + config = SnapshotEscalationConfig(scroll_viewport_fraction=0.5) # 50% of viewport """ enabled: bool = True limit_base: int = 60 limit_step: int = 30 limit_max: int = 200 + # Scroll after exhausting limit escalation to find elements in different viewports + scroll_after_escalation: bool = True + scroll_max_attempts: int = 3 # Max scrolls per direction + scroll_directions: tuple[str, ...] = ("down", "up") # Directions to try + scroll_viewport_fraction: float = 0.4 # Scroll by 40% of viewport height (adaptive to screen size) @dataclass(frozen=True) @@ -250,6 +266,213 @@ class RecoveryNavigationConfig: track_successful_urls: bool = True +@dataclass(frozen=True) +class ModalDismissalConfig: + """ + Configuration for automatic modal/drawer dismissal after DOM changes. + + When a CLICK action triggers a DOM change (e.g., modal/drawer appears), + this feature attempts to dismiss blocking overlays using common patterns. + + This handles common blocking scenarios: + - Product protection/warranty upsells (Amazon, etc.) + - Cookie consent banners + - Newsletter signup popups + - Promotional overlays + - Cart upsell drawers + + The dismissal logic looks for buttons with common dismissal text patterns + and clicks them to clear the overlay. + + Attributes: + enabled: If True, attempt to dismiss modals after DOM change detection. + dismiss_patterns: Text patterns to match dismissal buttons (case-insensitive). + role_filter: Element roles to consider for dismissal buttons. + max_attempts: Maximum dismissal attempts per modal. + min_new_elements: Minimum new DOM elements to trigger modal detection. + + Example: + # Default: enabled with common dismissal patterns + config = ModalDismissalConfig() + + # Disable modal dismissal + config = ModalDismissalConfig(enabled=False) + + # Custom patterns (e.g., for non-English sites) + config = ModalDismissalConfig( + dismiss_patterns=("nein danke", "schließen", "abbrechen"), + ) + """ + + enabled: bool = True + # Patterns that require word boundary matching (longer patterns) + # Ordered by preference: decline > close > accept (we prefer not to accept upsells) + dismiss_patterns: tuple[str, ...] = ( + # Decline/Skip patterns (highest priority - user is declining an offer) + "no thanks", + "no, thanks", + "no thank you", + "not now", + "not interested", + "maybe later", + "skip", + "decline", + "reject", + "deny", + "continue without", + # Close patterns + "close", + "close dialog", + "close modal", + "close popup", + "dismiss", + "dismiss banner", + "dismiss dialog", + "cancel", + # Continue patterns (when modal offers upgrade vs continue) + "continue", + "continue to", + "proceed", + ) + # Icon characters that require exact match (entire label is just this character) + dismiss_icons: tuple[str, ...] = ( + "x", + "×", # Unicode multiplication sign + "✕", # Unicode X mark + "✖", # Heavy multiplication X + "✗", # Ballot X + "╳", # Box drawings + ) + role_filter: tuple[str, ...] = ("button", "link") + max_attempts: int = 2 + min_new_elements: int = 5 # Same threshold as DOM change fallback + + +@dataclass(frozen=True) +class CheckoutDetectionConfig: + """ + Configuration for checkout page detection. + + After modal dismissal or action completion, the agent checks if the + current page is a checkout-relevant page. If detected, this triggers + a replan to continue the checkout flow. + + This handles scenarios where: + - Agent clicks "Add to Cart" and modal is dismissed, but agent stops + - Agent lands on cart page but plan doesn't include checkout steps + - Agent reaches login page during checkout flow + + Attributes: + enabled: If True, detect checkout pages and trigger continuation. + url_patterns: URL patterns that indicate checkout-relevant pages. + element_patterns: Element text patterns that indicate checkout pages. + trigger_replan: If True, trigger replanning when checkout page detected. + + Example: + # Default: enabled with common checkout patterns + config = CheckoutDetectionConfig() + + # Disable checkout detection + config = CheckoutDetectionConfig(enabled=False) + + # Custom patterns + config = CheckoutDetectionConfig( + url_patterns=("/warenkorb", "/kasse"), # German + ) + """ + + enabled: bool = True + # URL patterns that indicate checkout-relevant pages + url_patterns: tuple[str, ...] = ( + # Cart pages + "/cart", + "/basket", + "/bag", + "/shopping-cart", + "/gp/cart", # Amazon + # Checkout pages + "/checkout", + "/buy", + "/order", + "/payment", + "/pay", + "/purchase", + "/gp/buy", # Amazon + "/gp/checkout", # Amazon + # Sign-in during checkout + "/signin", + "/sign-in", + "/login", + "/ap/signin", # Amazon + "/auth", + "/authenticate", + ) + # Element text patterns that indicate checkout pages (case-insensitive) + element_patterns: tuple[str, ...] = ( + "proceed to checkout", + "proceed to buy", + "go to checkout", + "view cart", + "shopping cart", + "your cart", + "sign in to checkout", + "continue to payment", + "place your order", + "buy now", + "checkout", + ) + # If True, trigger replanning when checkout page is detected + trigger_replan: bool = True + + +@dataclass(frozen=True) +class AuthBoundaryConfig: + """ + Configuration for authentication boundary detection. + + When the agent reaches a login/sign-in page and doesn't have credentials, + it should stop gracefully instead of failing or spinning endlessly. + + This is a "terminal state" - the agent has successfully navigated as far + as possible without authentication. + + Attributes: + enabled: If True, detect auth boundaries and stop gracefully. + url_patterns: URL patterns indicating authentication pages. + stop_on_auth: If True, mark run as successful when auth boundary reached. + auth_success_message: Message to include in outcome when stopping at auth. + + Example: + # Default: enabled, stop gracefully at login pages + config = AuthBoundaryConfig() + + # Disable (try to continue past auth pages) + config = AuthBoundaryConfig(enabled=False) + """ + + enabled: bool = True + # URL patterns that indicate authentication/login pages + url_patterns: tuple[str, ...] = ( + "/signin", + "/sign-in", + "/login", + "/log-in", + "/auth", + "/authenticate", + "/ap/signin", # Amazon sign-in + "/ap/register", # Amazon registration + "/ax/claim", # Amazon CAPTCHA/verification + "/account/login", + "/accounts/login", + "/user/login", + ) + # If True, mark the run as successful when auth boundary is reached + # (since the agent did everything it could without credentials) + stop_on_auth: bool = True + # Message to include when stopping at auth boundary + auth_success_message: str = "Reached authentication boundary (login required)" + + @dataclass(frozen=True) class PlannerExecutorConfig: """ @@ -260,6 +483,7 @@ class PlannerExecutorConfig: - Retry/verification settings - Vision fallback settings - Recovery navigation settings + - Modal dismissal settings - Planner/Executor LLM settings - Tracing settings """ @@ -284,6 +508,15 @@ class PlannerExecutorConfig: # Recovery navigation recovery: RecoveryNavigationConfig = RecoveryNavigationConfig() + # Modal dismissal (for blocking overlays after DOM changes) + modal: ModalDismissalConfig = ModalDismissalConfig() + + # Checkout page detection (continue workflow when reaching checkout pages) + checkout: CheckoutDetectionConfig = CheckoutDetectionConfig() + + # Authentication boundary detection (stop gracefully at login pages) + auth_boundary: AuthBoundaryConfig = AuthBoundaryConfig() + # Planner LLM settings planner_max_tokens: int = 2048 planner_temperature: float = 0.0 @@ -300,11 +533,19 @@ class PlannerExecutorConfig: # Pre-step verification (skip step if predicates already pass) pre_step_verification: bool = True + # Scroll-to-find: automatically scroll to find elements when not in viewport + scroll_to_find_enabled: bool = True + scroll_to_find_max_scrolls: int = 3 # Max scroll attempts per direction + scroll_to_find_directions: tuple[str, ...] = ("down", "up") # Try down first, then up + # Tracing trace_screenshots: bool = True trace_screenshot_format: str = "jpeg" trace_screenshot_quality: int = 80 + # Verbose mode (print plan, prompts, and LLM responses to stdout) + verbose: bool = False + # --------------------------------------------------------------------------- # Plan Models (Pydantic for validation) @@ -334,6 +575,10 @@ class PlanStep(BaseModel): required: bool = Field(True, description="If True, step failure triggers replan") stop_if_true: bool = Field(False, description="If True, stop execution when verification passes") optional_substeps: list["PlanStep"] = Field(default_factory=list, description="Optional fallback steps") + heuristic_hints: list[dict[str, Any]] = Field( + default_factory=list, + description="Planner-generated hints for element selection", + ) class Config: extra = "allow" @@ -390,9 +635,10 @@ def should_use_vision(self) -> bool: def digest(self) -> str: """Compute a digest for loop/change detection.""" + title = getattr(self.snapshot, "title", None) or "" parts = [ self.snapshot.url[:200] if self.snapshot.url else "", - self.snapshot.title[:200] if self.snapshot.title else "", + title[:200] if title else "", f"count:{len(self.snapshot.elements or [])}", ] for el in (self.snapshot.elements or [])[:100]: @@ -532,6 +778,108 @@ def build_predicate(spec: PredicateSpec | dict[str, Any]) -> Predicate: # --------------------------------------------------------------------------- +def _parse_string_predicate(pred_str: str) -> dict[str, Any] | None: + """ + Parse a string predicate into a normalized dict. + + LLMs sometimes output predicates as strings like: + - "url_contains('amazon.com')" -> {"predicate": "url_contains", "args": ["amazon.com"]} + - "url_matches(^https://www\\.amazon\\.com/.*)" -> {"predicate": "url_matches", "args": ["^https://www\\.amazon\\.com/.*"]} + - "exists(role=button)" -> {"predicate": "exists", "args": ["role=button"]} + + Args: + pred_str: String representation of a predicate + + Returns: + Normalized predicate dict or None if parsing fails + """ + import re + + pred_str = pred_str.strip() + + # Try to match function-call style: predicate_name(args) + match = re.match(r'^(\w+)\s*\(\s*(.+?)\s*\)$', pred_str, re.DOTALL) + if match: + pred_name = match.group(1) + args_str = match.group(2) + + # Strip quotes from args if present + args_str = args_str.strip() + if (args_str.startswith("'") and args_str.endswith("'")) or \ + (args_str.startswith('"') and args_str.endswith('"')): + args_str = args_str[1:-1] + + return { + "predicate": pred_name, + "args": [args_str], + } + + # Try simple predicate name without args + if re.match(r'^[\w_]+$', pred_str): + return { + "predicate": pred_str, + "args": [], + } + + return None + + +def _normalize_verify_predicate(pred: dict[str, Any]) -> dict[str, Any]: + """ + Normalize a verify predicate to the expected format. + + LLMs may output predicates in various formats: + - {"url_contains": "amazon.com"} -> {"predicate": "url_contains", "args": ["amazon.com"]} + - {"exists": "text~'Logitech'"} -> {"predicate": "exists", "args": ["text~'Logitech'"]} + - {"predicate": "url_contains", "input": "x"} -> {"predicate": "url_contains", "args": ["x"]} + - {"type": "url_contains", "input": "x"} -> {"predicate": "url_contains", "args": ["x"]} + + Args: + pred: Raw predicate dictionary + + Returns: + Normalized predicate with "predicate" and "args" fields + """ + # Handle "type" field as alternative to "predicate" (common LLM variation) + if "type" in pred and "predicate" not in pred: + pred["predicate"] = pred.pop("type") + + # Already has predicate field - normalize args + if "predicate" in pred: + # Handle "input" field as alternative to "args" + if "args" not in pred or not pred["args"]: + if "input" in pred: + pred["args"] = [pred.pop("input")] + elif "value" in pred: + pred["args"] = [pred.pop("value")] + elif "pattern" in pred: # For url_matches + pred["args"] = [pred.pop("pattern")] + elif "substring" in pred: # For url_contains + pred["args"] = [pred.pop("substring")] + elif "selector" in pred: # For exists/not_exists + pred["args"] = [pred.pop("selector")] + return pred + + # Predicate type is a key in the dict (e.g., {"url_contains": "amazon.com"}) + known_predicates = [ + "url_contains", "url_equals", "url_matches", + "exists", "not_exists", + "element_count", "element_visible", + "any_of", "all_of", + "text_contains", "text_equals", + ] + + for pred_type in known_predicates: + if pred_type in pred: + return { + "predicate": pred_type, + "args": [pred[pred_type]] if pred[pred_type] else [], + } + + # Unknown format - return as-is and let validation fail with clear error + return pred + + def normalize_plan(plan_dict: dict[str, Any]) -> dict[str, Any]: """ Normalize plan dictionary to handle LLM output variations. @@ -540,6 +888,7 @@ def normalize_plan(plan_dict: dict[str, Any]) -> dict[str, Any]: - url vs target field names - action aliases (click vs CLICK) - step id variations (string vs int) + - verify predicate format variations Args: plan_dict: Raw plan dictionary from LLM @@ -580,6 +929,24 @@ def normalize_plan(plan_dict: dict[str, Any]) -> dict[str, Any]: except ValueError: pass + # Normalize verify predicates + if "verify" in step and isinstance(step["verify"], list): + normalized_verify = [] + for pred in step["verify"]: + if isinstance(pred, dict): + normalized_verify.append(_normalize_verify_predicate(pred)) + elif isinstance(pred, str): + # Try to parse string predicates like "url_contains('text')" + parsed = _parse_string_predicate(pred) + if parsed: + normalized_verify.append(parsed) + else: + # Keep as-is, let validation fail with clear error + normalized_verify.append({"predicate": "unknown", "args": [pred]}) + else: + normalized_verify.append(pred) + step["verify"] = normalized_verify + # Normalize optional_substeps recursively if "optional_substeps" in step: for substep in step["optional_substeps"]: @@ -587,6 +954,21 @@ def normalize_plan(plan_dict: dict[str, Any]) -> dict[str, Any]: substep["action"] = substep["action"].upper() if "url" in substep and "target" not in substep: substep["target"] = substep.pop("url") + # Normalize verify in substeps too + if "verify" in substep and isinstance(substep["verify"], list): + normalized_verify = [] + for pred in substep["verify"]: + if isinstance(pred, dict): + normalized_verify.append(_normalize_verify_predicate(pred)) + elif isinstance(pred, str): + parsed = _parse_string_predicate(pred) + if parsed: + normalized_verify.append(parsed) + else: + normalized_verify.append({"predicate": "unknown", "args": [pred]}) + else: + normalized_verify.append(pred) + substep["verify"] = normalized_verify return plan_dict @@ -669,31 +1051,97 @@ def build_planner_prompt( strict_note = "\nReturn ONLY a JSON object. No explanations, no code fences.\n" if strict else "" schema_note = f"\nSchema errors from last attempt:\n{schema_errors}\n" if schema_errors else "" + # Detect task type for domain-specific guidance + task_lower = task.lower() + is_ecommerce_task = any(keyword in task_lower for keyword in [ + "buy", "purchase", "add to cart", "checkout", "order", "shop", + "amazon", "ebay", "walmart", "target", "bestbuy", + "cart", "mouse", "keyboard", "laptop", "product", + ]) + is_search_task = any(keyword in task_lower for keyword in [ + "search", "find", "look for", "look up", "google", + ]) + + # Build domain-specific planning guidance + domain_guidance = "" + if is_ecommerce_task: + domain_guidance = """ + +IMPORTANT: E-Commerce Task Planning Rules +========================================= +For shopping/purchase tasks, include ALL necessary steps in order: +1. NAVIGATE to the site (if not already there) +2. TYPE_AND_SUBMIT search query in search box +3. CLICK on specific product from results (not filters or categories) +4. CLICK "Add to Cart" button on product page +5. CLICK "Proceed to Checkout" or cart icon +6. Handle login/signup if required (may need CLICK + TYPE_AND_SUBMIT) +7. CLICK through checkout process + +Common mistakes to AVOID: +- Do NOT skip "Add to Cart" step - clicking a product link is NOT adding to cart +- Do NOT combine multiple distinct actions into one step +- Do NOT confuse filter/category clicks with product selection +- Each distinct user action should be its own step + +Intent hints are critical - use clear hints like: +- intent: "Click Add to Cart button" +- intent: "Click Proceed to Checkout" +- intent: "Click on product title or image" +- intent: "Click sign in button" +""" + elif is_search_task: + domain_guidance = """ + +IMPORTANT: Search Task Planning Rules +===================================== +For search tasks, include steps to: +1. NAVIGATE to the search site (if not already there) +2. TYPE_AND_SUBMIT the search query +3. Wait for/verify search results +4. If selecting a result: CLICK on specific result item +""" + system = f"""You are the PLANNER. Output a JSON execution plan for the web automation task. {strict_note} -Your output must be a valid JSON object with: -- task: string (the task description) -- notes: list of strings (assumptions, constraints) -- steps: list of step objects - -Each step must have: -- id: int (1-indexed, contiguous) -- goal: string (human-readable goal) -- action: string (NAVIGATE, CLICK, TYPE_AND_SUBMIT, or SCROLL) -- verify: list of predicate specs - -Available predicates: -- url_contains(substring): URL contains the given string -- url_matches(pattern): URL matches regex pattern -- exists(selector): Element matching selector exists -- not_exists(selector): Element matching selector does not exist -- element_count(selector, min, max): Element count within range -- any_of(predicates...): Any predicate is true -- all_of(predicates...): All predicates are true - -Selectors: role=button, role=link, text~'text', role=textbox, etc. - -Return ONLY valid JSON. No prose, no code fences.""" +Your output must be a valid JSON object with this EXACT structure: + +{{ + "task": "description of task", + "notes": ["assumption 1", "assumption 2"], + "steps": [ + {{ + "id": 1, + "goal": "Navigate to website", + "action": "NAVIGATE", + "target": "https://example.com", + "verify": [{{"predicate": "url_contains", "args": ["example.com"]}}] + }}, + {{ + "id": 2, + "goal": "Search for product", + "action": "TYPE_AND_SUBMIT", + "input": "search query", + "verify": [{{"predicate": "url_contains", "args": ["search"]}}] + }}, + {{ + "id": 3, + "goal": "Click on result", + "action": "CLICK", + "intent": "Click on product title", + "verify": [{{"predicate": "url_contains", "args": ["/product/"]}}] + }} + ] +}} + +CRITICAL: Each verify predicate MUST be an object with "predicate" and "args" keys: +- {{"predicate": "url_contains", "args": ["substring"]}} +- {{"predicate": "exists", "args": ["role=button"]}} +- {{"predicate": "not_exists", "args": ["text~'error'"]}} + +DO NOT use string format like "url_contains('text')" - use object format only. +{domain_guidance} +Return ONLY valid JSON. No prose, no code fences, no markdown.""" user = f"""Task: {task} {schema_note} @@ -701,7 +1149,7 @@ def build_planner_prompt( Site type: {site_type} Auth state: {auth_state} -Output a JSON plan to accomplish this task.""" +Output a JSON plan to accomplish this task. Each step should represent ONE distinct action.""" return system, user @@ -710,14 +1158,22 @@ def build_executor_prompt( goal: str, intent: str | None, compact_context: str, + input_text: str | None = None, ) -> tuple[str, str]: """ Build system and user prompts for the Executor LLM. + Args: + goal: Human-readable goal for this step + intent: Intent hint for element selection (optional) + compact_context: Compact representation of page elements + input_text: Text to type for TYPE_AND_SUBMIT actions (optional) + Returns: (system_prompt, user_prompt) """ intent_line = f"Intent: {intent}\n" if intent else "" + input_line = f"Text to type: \"{input_text}\"\n" if input_text else "" system = """You are a careful web automation executor. You must respond with exactly ONE action in this format: @@ -732,7 +1188,7 @@ def build_executor_prompt( user = f"""You are controlling a browser via element IDs. Goal: {goal} -{intent_line} +{intent_line}{input_line} Elements (ID|role|text|importance|clickable|...): {compact_context} @@ -858,22 +1314,261 @@ def __init__( self._run_id: str | None = None self._last_known_good_url: str | None = None + # Recovery state (initialized per-run) + self._recovery_state: RecoveryState | None = None + + # Composable heuristics (wraps static heuristics with dynamic hints) + self._composable_heuristics: ComposableHeuristics | None = None + + # Current automation task (for run-level context) + self._current_task: AutomationTask | None = None + def _format_context(self, snap: Snapshot, goal: str) -> str: - """Format snapshot for LLM context.""" + """ + Format snapshot for LLM context. + + Uses compact format: id|role|text|importance|is_primary|bg|clickable|nearby_text|ord|DG|href + Same format as documented in reddit_post_planner_executor_local_llm.md + """ if self._context_formatter is not None: return self._context_formatter(snap, goal) - # Default compact format + import re + + # Filter to interactive elements + interactive_roles = { + "button", "link", "textbox", "searchbox", "combobox", + "checkbox", "radio", "slider", "tab", "menuitem", + "option", "switch", "cell", "a", "input", "select", "textarea", + } + + elements = [] + for el in snap.elements: + role = (getattr(el, "role", "") or "").lower() + if role in interactive_roles or getattr(el, "clickable", False): + elements.append(el) + + # Sort by importance + elements.sort(key=lambda el: getattr(el, "importance", 0) or 0, reverse=True) + + # Build dominant group rank map + rank_in_group_map: dict[int, int] = {} + dg_key = getattr(snap, "dominant_group_key", None) + dg_elements = [el for el in elements if getattr(el, "in_dominant_group", False)] + if not dg_elements and dg_key: + dg_elements = [el for el in elements if getattr(el, "group_key", None) == dg_key] + + # Sort dominant group by position + def rank_sort_key(el): + doc_y = getattr(el, "doc_y", None) or float("inf") + bbox = getattr(el, "bbox", None) + bbox_y = bbox.y if bbox else float("inf") + bbox_x = bbox.x if bbox else float("inf") + neg_imp = -(getattr(el, "importance", 0) or 0) + return (doc_y, bbox_y, bbox_x, neg_imp) + + dg_elements.sort(key=rank_sort_key) + for rank, el in enumerate(dg_elements): + rank_in_group_map[el.id] = rank + + # Format lines - take top elements by importance + from dominant group + by position + selected_ids: set[int] = set() + selected: list = [] + + # Top 40 by importance + for el in elements[:40]: + if el.id not in selected_ids: + selected_ids.add(el.id) + selected.append(el) + + # Top 20 from dominant group + for el in dg_elements[:20]: + if el.id not in selected_ids: + selected_ids.add(el.id) + selected.append(el) + + # Top 20 by position + elements_by_pos = sorted(elements, key=lambda el: ( + getattr(el, "doc_y", None) or float("inf"), + -(getattr(el, "importance", 0) or 0) + )) + for el in elements_by_pos[:20]: + if el.id not in selected_ids: + selected_ids.add(el.id) + selected.append(el) + + def compress_href(href: str | None) -> str: + if not href: + return "" + # Extract domain or last path segment + href = href.strip() + if href.startswith("/"): + parts = href.split("/") + return parts[-1][:20] if parts[-1] else "" + try: + from urllib.parse import urlparse + parsed = urlparse(href) + if parsed.path and parsed.path != "/": + parts = parsed.path.rstrip("/").split("/") + return parts[-1][:20] if parts[-1] else parsed.netloc[:15] + return parsed.netloc[:15] + except Exception: + return href[:20] + + def get_bg_color(el) -> str: + """Extract background color name from visual_cues.""" + visual_cues = getattr(el, "visual_cues", None) + if not visual_cues: + return "" + # Try bg_color_name first, then bg_fallback + bg = getattr(visual_cues, "bg_color_name", None) or "" + if not bg: + bg = getattr(visual_cues, "bg_fallback", None) or "" + return bg[:10] if bg else "" + + def get_nearby_text(el) -> str: + """Extract nearby text for context.""" + nearby = getattr(el, "nearby_text", None) or "" + if not nearby: + # Try aria_label or placeholder as fallback + nearby = getattr(el, "aria_label", None) or "" + if not nearby: + nearby = getattr(el, "placeholder", None) or "" + # Truncate and normalize + nearby = re.sub(r"\s+", " ", nearby.strip()) + return nearby[:20] if nearby else "" + lines = [] - for el in snap.elements[:120]: - eid = getattr(el, "id", "?") - role = getattr(el, "role", "") - text = (getattr(el, "text", "") or "")[:50] - importance = getattr(el, "importance", 0) - clickable = 1 if getattr(el, "clickable", False) else 0 - lines.append(f"{eid}|{role}|{text}|{importance}|{clickable}") + for el in selected: + eid = el.id + role = getattr(el, "role", "") or "" + if getattr(el, "href", None): + role = "link" + + # Truncate and normalize text + text = getattr(el, "text", "") or "" + text = re.sub(r"\s+", " ", text.strip()) + if len(text) > 30: + text = text[:27] + "..." + + importance = getattr(el, "importance", 0) or 0 + + # is_primary from visual_cues + is_primary = "0" + visual_cues = getattr(el, "visual_cues", None) + if visual_cues and getattr(visual_cues, "is_primary", False): + is_primary = "1" + + # bg: background color + bg = get_bg_color(el) + + # clickable flag + clickable = "1" if getattr(el, "clickable", False) else "0" + + # nearby_text + nearby_text = get_nearby_text(el) + + # in dominant group + in_dg = getattr(el, "in_dominant_group", False) + if not in_dg and dg_key: + in_dg = getattr(el, "group_key", None) == dg_key + + # ord: rank in dominant group + ord_val = rank_in_group_map.get(eid, "") if in_dg else "" + dg_flag = "1" if in_dg else "0" + + # href + href = compress_href(getattr(el, "href", None)) + + # Format: id|role|text|importance|is_primary|bg|clickable|nearby_text|ord|DG|href + line = f"{eid}|{role}|{text}|{importance}|{is_primary}|{bg}|{clickable}|{nearby_text}|{ord_val}|{dg_flag}|{href}" + lines.append(line) + return "\n".join(lines) + async def _attempt_recovery( + self, + runtime: AgentRuntime, + ) -> bool: + """ + Attempt to recover to the last known good state. + + Navigates back to the checkpoint URL and verifies we're in a recoverable state. + + Args: + runtime: AgentRuntime instance + + Returns: + True if recovery succeeded, False otherwise + """ + if self._recovery_state is None: + return False + + checkpoint = self._recovery_state.consume_recovery_attempt() + if checkpoint is None: + return False + + # Emit recovery event + if self.tracer: + self.tracer.emit( + "recovery_attempt", + { + "target_url": checkpoint.url, + "target_step_index": checkpoint.step_index, + "attempt": self._recovery_state.recovery_attempts_used, + }, + ) + + try: + # Navigate to checkpoint URL + await runtime.goto(checkpoint.url) + + # Wait for page to settle + await runtime.stabilize() + + # Verify we're at the expected URL (basic check) + snapshot = await runtime.snapshot() + current_url = snapshot.url or "" + + # Check if URL matches (allowing for minor variations) + url_matches = ( + checkpoint.url in current_url + or current_url in checkpoint.url + or checkpoint.url.rstrip("/") == current_url.rstrip("/") + ) + + if url_matches: + if self.tracer: + self.tracer.emit( + "recovery_success", + { + "recovered_to_url": checkpoint.url, + "actual_url": current_url, + }, + ) + return True + else: + if self.tracer: + self.tracer.emit( + "recovery_url_mismatch", + { + "expected_url": checkpoint.url, + "actual_url": current_url, + }, + ) + return False + + except Exception as e: + if self.tracer: + self.tracer.emit( + "recovery_failed", + { + "error": str(e), + "checkpoint_url": checkpoint.url, + }, + ) + return False + def _extract_json(self, text: str) -> dict[str, Any]: """Extract JSON from LLM response, handling code fences and think tags.""" # Remove ... tags @@ -891,31 +1586,41 @@ def _extract_json(self, text: str) -> dict[str, Any]: raise ValueError("No JSON object found in response") def _parse_action(self, text: str) -> tuple[str, list[Any]]: - """Parse action from executor response.""" + """Parse action from executor response. + + Handles various LLM output formats: + - CLICK(42) + - - CLICK(42) (with leading dash/bullet) + - TYPE(42, "text") + - - TYPE(42, "Logitech mouse") + """ text = text.strip() + # Strip common prefixes (bullets, dashes, asterisks) + text = re.sub(r"^[-*•]\s*", "", text) + # CLICK() - match = re.match(r"CLICK\((\d+)\)", text) + match = re.search(r"CLICK\((\d+)\)", text) if match: return "CLICK", [int(match.group(1))] - # TYPE(, "text") - match = re.match(r'TYPE\((\d+),\s*["\'](.+?)["\']\)', text) + # TYPE(, "text") - also handle without quotes + match = re.search(r'TYPE\((\d+),\s*["\']?([^"\']+?)["\']?\)', text) if match: - return "TYPE", [int(match.group(1)), match.group(2)] + return "TYPE", [int(match.group(1)), match.group(2).strip()] # PRESS('key') - match = re.match(r"PRESS\(['\"](.+?)['\"]\)", text) + match = re.search(r"PRESS\(['\"]?(.+?)['\"]?\)", text) if match: return "PRESS", [match.group(1)] # SCROLL(direction) - match = re.match(r"SCROLL\((\w+)\)", text) + match = re.search(r"SCROLL\((\w+)\)", text) if match: return "SCROLL", [match.group(1)] # FINISH() - if text.startswith("FINISH"): + if "FINISH" in text: return "FINISH", [] return "UNKNOWN", [text] @@ -1107,15 +1812,25 @@ async def _snapshot_with_escalation( runtime: AgentRuntime, goal: str, capture_screenshot: bool = True, + step: PlanStep | None = None, ) -> SnapshotContext: """ - Capture snapshot with incremental limit escalation. + Capture snapshot with incremental limit escalation and optional scroll-to-find. Progressively increases snapshot limit if elements are missing or confidence is low. Escalation can be disabled via config. + After exhausting limit escalation, if scroll_after_escalation is enabled, + scrolls down/up to find elements that may be outside the current viewport. + When escalation is disabled (config.snapshot.enabled=False), only a single snapshot at limit_base is captured. + + Args: + runtime: AgentRuntime for capturing snapshots and scrolling + goal: Goal string for context formatting + capture_screenshot: Whether to capture screenshot + step: Optional PlanStep for goal-aware element detection during scroll """ cfg = self.config.snapshot current_limit = cfg.limit_base @@ -1132,6 +1847,7 @@ async def _snapshot_with_escalation( limit=current_limit, screenshot=capture_screenshot, goal=goal, + show_overlay=True, # Show visual overlay for debugging ) if snap is None: if not cfg.enabled: @@ -1163,6 +1879,9 @@ async def _snapshot_with_escalation( break # Check element count - if sufficient, no need to escalate + # NOTE: Limit escalation is based on element COUNT only, not on whether + # a specific target element was found. Intent heuristics are only used + # for scroll-after-escalation AFTER limit escalation is exhausted. elements = getattr(snap, "elements", []) or [] if len(elements) >= 10: break @@ -1178,6 +1897,106 @@ async def _snapshot_with_escalation( break # No escalation on error current_limit = min(current_limit + cfg.limit_step, max_limit + 1) + # Scroll-after-escalation: if we have a step and scroll is enabled, + # try scrolling to find elements that may be outside the viewport + # + # IMPORTANT: Only trigger scroll-after-escalation for CLICK actions with + # specific intents that suggest an element may be below the viewport + # (e.g., "add_to_cart", "checkout"). Do NOT scroll for TYPE_AND_SUBMIT + # or generic actions where elements are typically visible. + should_try_scroll = ( + cfg.scroll_after_escalation + and step is not None + and last_snap is not None + and not requires_vision + and step.action == "CLICK" # Only for CLICK actions + and step.intent # Must have a specific intent + and self._intent_heuristics is not None # Must have heuristics to detect + ) + + if should_try_scroll: + # Check if we can find the target element using intent heuristics + elements = getattr(last_snap, "elements", []) or [] + url = getattr(last_snap, "url", "") or "" + found_element = await self._try_intent_heuristics(step, elements, url) + + if found_element is None: + # Element not found in current viewport - try scrolling + if self.config.verbose: + print(f" [SNAPSHOT-ESCALATION] Target element not found, trying scroll-after-escalation...", flush=True) + + # Get viewport height and calculate scroll delta + viewport_height = await runtime.get_viewport_height() + scroll_delta = viewport_height * cfg.scroll_viewport_fraction + + for direction in cfg.scroll_directions: + # Map direction to dy (pixels): down=positive, up=negative + scroll_dy = scroll_delta if direction == "down" else -scroll_delta + + for scroll_num in range(cfg.scroll_max_attempts): + if self.config.verbose: + print(f" [SNAPSHOT-ESCALATION] Scrolling {direction} ({scroll_num + 1}/{cfg.scroll_max_attempts})...", flush=True) + + # Scroll with deterministic verification + # scroll_by() returns False if scroll had no effect (reached page boundary) + scroll_effective = await runtime.scroll_by( + dy=scroll_dy, + verify=True, + min_delta_px=50.0, + js_fallback=True, + required=False, # Don't fail the task if scroll doesn't work + timeout_s=5.0, + ) + + if not scroll_effective: + if self.config.verbose: + print(f" [SNAPSHOT-ESCALATION] Scroll {direction} had no effect (reached boundary), skipping remaining attempts", flush=True) + break # No point trying more scrolls in this direction + + # Wait for stabilization after successful scroll + if self.config.stabilize_enabled: + await asyncio.sleep(self.config.stabilize_poll_s) + + # Take new snapshot at max limit (we already escalated) + try: + snap = await runtime.snapshot( + limit=cfg.limit_max, + screenshot=capture_screenshot, + goal=goal, + show_overlay=True, + ) + if snap is None: + continue + + last_snap = snap + last_compact = self._format_context(snap, goal) + + # Extract screenshot + if capture_screenshot: + raw_screenshot = getattr(snap, "screenshot", None) + if raw_screenshot: + screenshot_b64 = raw_screenshot + + # Check if target element is now visible + elements = getattr(snap, "elements", []) or [] + url = getattr(snap, "url", "") or "" + found_element = await self._try_intent_heuristics(step, elements, url) + + if found_element is not None: + if self.config.verbose: + print(f" [SNAPSHOT-ESCALATION] Found target element {found_element} after scrolling {direction}", flush=True) + # Break out of both loops + break + except Exception: + continue + + # If found, break out of direction loop + if found_element is not None: + break + + if found_element is None and self.config.verbose: + print(f" [SNAPSHOT-ESCALATION] Target element not found after scrolling", flush=True) + # Fallback for failed capture if last_snap is None: last_snap = Snapshot( @@ -1236,6 +2055,14 @@ async def plan( schema_errors=last_errors or None, ) + if self.config.verbose: + print("\n" + "=" * 60, flush=True) + print(f"[PLANNER] Attempt {attempt}/{max_attempts}", flush=True) + print("=" * 60, flush=True) + print(f"Task: {task}", flush=True) + print(f"Start URL: {start_url}", flush=True) + print("-" * 40, flush=True) + resp = self.planner.generate( sys_prompt, user_prompt, @@ -1244,6 +2071,11 @@ async def plan( ) last_output = resp.content + if self.config.verbose: + print("\n--- Planner Response ---", flush=True) + print(resp.content, flush=True) + print("--- End Response ---\n", flush=True) + try: plan_dict = self._extract_json(resp.content) @@ -1266,10 +2098,18 @@ async def plan( # Emit trace event self._emit_plan_event(plan, last_output) + if self.config.verbose: + import json + print("\n=== PLAN GENERATED ===", flush=True) + print(json.dumps(plan.model_dump(), indent=2, default=str), flush=True) + print("=== END PLAN ===\n", flush=True) + return plan except Exception as e: last_errors = str(e) + if self.config.verbose: + print(f"[PLANNER] Parse error: {e}", flush=True) continue raise RuntimeError(f"Planner failed after {max_attempts} attempts. Last output:\n{last_output[:500]}") @@ -1426,6 +2266,85 @@ async def _try_intent_heuristics( except Exception: return None + async def _scroll_to_find_element( + self, + runtime: AgentRuntime, + step: PlanStep, + goal: str, + ) -> tuple[int | None, "SnapshotContext | None"]: + """ + Scroll the page to find an element that matches the step's goal/intent. + + This is used when the initial snapshot doesn't contain the target element + (e.g., "Add to Cart" button below the viewport). + + Returns: + (element_id, new_snapshot_context) if found, (None, None) otherwise + """ + if not self.config.scroll_to_find_enabled: + return None, None + + for direction in self.config.scroll_to_find_directions: + for scroll_num in range(self.config.scroll_to_find_max_scrolls): + if self.config.verbose: + print(f" [SCROLL-TO-FIND] Scrolling {direction} ({scroll_num + 1}/{self.config.scroll_to_find_max_scrolls})...", flush=True) + + # Scroll + await runtime.scroll(direction) + + # Wait for stabilization + if self.config.stabilize_enabled: + await asyncio.sleep(self.config.stabilize_poll_s) + + # Take new snapshot (don't pass step to avoid recursive scroll-to-find) + ctx = await self._snapshot_with_escalation( + runtime, + goal=goal, + capture_screenshot=self.config.trace_screenshots, + step=None, # Avoid recursive scroll-after-escalation + ) + + # Try heuristics on new snapshot + elements = getattr(ctx.snapshot, "elements", []) or [] + url = getattr(ctx.snapshot, "url", "") or "" + element_id = await self._try_intent_heuristics(step, elements, url) + + if element_id is not None: + if self.config.verbose: + print(f" [SCROLL-TO-FIND] Found element {element_id} after scrolling {direction}", flush=True) + return element_id, ctx + + # Try LLM executor on new snapshot + sys_prompt, user_prompt = build_executor_prompt( + step.goal, + step.intent, + ctx.compact_representation, + input_text=step.input, + ) + resp = self.executor.generate( + sys_prompt, + user_prompt, + temperature=self.config.executor_temperature, + max_new_tokens=self.config.executor_max_tokens, + ) + parsed_action, parsed_args = self._parse_action(resp.content) + + if parsed_action == "CLICK" and parsed_args: + element_id = parsed_args[0] + # Verify element exists in snapshot + if any(getattr(el, "id", None) == element_id for el in elements): + if self.config.verbose: + print(f" [SCROLL-TO-FIND] LLM found element {element_id} after scrolling {direction}", flush=True) + return element_id, ctx + + # If LLM suggests SCROLL, it means element still not visible + if parsed_action == "SCROLL": + continue + + if self.config.verbose: + print(f" [SCROLL-TO-FIND] Element not found after scrolling", flush=True) + return None, None + async def _execute_optional_substeps( self, substeps: list[PlanStep], @@ -1452,6 +2371,7 @@ async def _execute_optional_substeps( runtime, goal=substep.goal, capture_screenshot=False, + step=substep, # Enable scroll-after-escalation for substeps ) # Determine element and action @@ -1470,6 +2390,7 @@ async def _execute_optional_substeps( substep.goal, substep.intent, ctx.compact_representation, + input_text=substep.input, ) resp = self.executor.generate( sys_prompt, @@ -1508,6 +2429,329 @@ async def _execute_optional_substeps( return outcomes + def _word_boundary_match(self, pattern: str, text: str) -> bool: + """ + Match pattern as a word boundary to avoid false positives. + + For example, "x" should match "x" but not "mexico" or "boxer". + "close" should match "close" or "Close Dialog" but not "enclosed". + """ + import re + + if not pattern or not text: + return False + + pattern_lower = pattern.lower() + text_lower = text.lower() + + # For very short patterns (1-2 chars like "x", "×"), require exact match + if len(pattern) <= 2: + return text_lower == pattern_lower or text_lower.strip() == pattern_lower + + # For longer patterns, use word boundary matching + try: + return bool(re.search(r'\b' + re.escape(pattern_lower) + r'\b', text_lower)) + except re.error: + # Fall back to contains if regex fails + return pattern_lower in text_lower + + async def _attempt_modal_dismissal( + self, + runtime: AgentRuntime, + post_snap: Any, + ) -> bool: + """ + Attempt to dismiss a modal/drawer overlay after DOM change detection. + + This handles common blocking scenarios like: + - Product protection/warranty upsells + - Cookie consent banners + - Newsletter signup popups + - Promotional overlays + - Cart upsell drawers + + The method looks for buttons with common dismissal text patterns + and clicks them to clear the overlay. + + Uses word boundary matching to avoid false positives like: + - "mexico" matching "x" + - "enclosed" matching "close" + - "boxer" matching "x" + + Args: + runtime: The AgentRuntime for browser control + post_snap: The snapshot after DOM change was detected + + Returns: + True if a dismissal was attempted, False otherwise + """ + if not self.config.modal.enabled: + return False + + cfg = self.config.modal + elements = getattr(post_snap, "elements", []) or [] + + # Find candidates that match dismissal patterns + candidates: list[tuple[int, int, str]] = [] # (element_id, score, matched_pattern) + + for el in elements: + el_id = getattr(el, "id", None) + if el_id is None: + continue + + role = (getattr(el, "role", "") or "").lower() + if role not in cfg.role_filter: + continue + + # Get text and aria_label for matching + text = (getattr(el, "text", "") or "").lower().strip() + aria_label = (getattr(el, "aria_label", "") or getattr(el, "ariaLabel", "") or "").lower().strip() + labels = [lbl for lbl in [text, aria_label] if lbl] + + if not labels: + continue + + matched = False + + # Check icon patterns first (require exact match) + for i, icon in enumerate(cfg.dismiss_icons): + icon_lower = icon.lower() + score = 200 + len(cfg.dismiss_icons) - i # Icons get high priority + + for lbl in labels: + if lbl == icon_lower: + candidates.append((el_id, score, icon)) + matched = True + break + if matched: + break + + if matched: + continue + + # Check word boundary patterns + for i, pattern in enumerate(cfg.dismiss_patterns): + score = len(cfg.dismiss_patterns) - i # Higher score for earlier patterns + + for lbl in labels: + if self._word_boundary_match(pattern, lbl): + candidates.append((el_id, score, pattern)) + matched = True + break + if matched: + break + + if not candidates: + if self.config.verbose: + print(" [MODAL] No dismissal candidates found", flush=True) + return False + + # Sort by score (highest first) and try the best candidates + candidates.sort(key=lambda x: x[1], reverse=True) + + for attempt in range(min(cfg.max_attempts, len(candidates))): + el_id, score, pattern = candidates[attempt] + if self.config.verbose: + print(f" [MODAL] Attempting dismissal: clicking element {el_id} (pattern: '{pattern}')", flush=True) + + try: + await runtime.click(el_id) + # Wait briefly for modal to close + await asyncio.sleep(0.3) + + # Check if modal was dismissed (DOM changed again) + try: + new_snap = await runtime.snapshot(emit_trace=False) + new_elements = set(getattr(el, "id", 0) for el in (new_snap.elements or [])) + post_elements = set(getattr(el, "id", 0) for el in (post_snap.elements or [])) + removed_elements = post_elements - new_elements + + if len(removed_elements) > 3: + if self.config.verbose: + print(f" [MODAL] Dismissal successful ({len(removed_elements)} elements removed)", flush=True) + return True + except Exception: + pass # Continue with next attempt + + except Exception as e: + if self.config.verbose: + print(f" [MODAL] Dismissal click failed: {e}", flush=True) + continue + + if self.config.verbose: + print(" [MODAL] All dismissal attempts exhausted", flush=True) + return False + + async def _detect_checkout_page( + self, + runtime: AgentRuntime, + snapshot: Any | None = None, + ) -> tuple[bool, str | None]: + """ + Detect if the current page is a checkout-relevant page. + + This detects pages that indicate the agent should continue + the checkout flow, such as: + - Cart pages + - Checkout pages + - Sign-in pages during checkout + + Args: + runtime: The AgentRuntime for browser control + snapshot: Optional snapshot to check (avoids re-capture) + + Returns: + (is_checkout_page, page_type) - page_type is 'cart', 'checkout', 'login', etc. + """ + if not self.config.checkout.enabled: + return False, None + + cfg = self.config.checkout + + # Get current URL + try: + current_url = await runtime.get_url() if hasattr(runtime, "get_url") else None + except Exception: + current_url = None + + if not current_url: + return False, None + + url_lower = current_url.lower() + + # Check URL patterns + for pattern in cfg.url_patterns: + if pattern.lower() in url_lower: + page_type = self._classify_checkout_page(pattern) + if self.config.verbose: + print(f" [CHECKOUT] Detected {page_type} page (URL: {pattern})", flush=True) + return True, page_type + + # Check element patterns if we have a snapshot + if snapshot: + elements = getattr(snapshot, "elements", []) or [] + for el in elements: + text = (getattr(el, "text", "") or "").lower() + for pattern in cfg.element_patterns: + if pattern.lower() in text: + page_type = self._classify_checkout_page(pattern) + if self.config.verbose: + print(f" [CHECKOUT] Detected {page_type} page (element: '{pattern}')", flush=True) + return True, page_type + + return False, None + + def _classify_checkout_page(self, pattern: str) -> str: + """Classify the type of checkout page based on the matched pattern.""" + pattern_lower = pattern.lower() + + if any(kw in pattern_lower for kw in ["cart", "basket", "bag"]): + return "cart" + elif any(kw in pattern_lower for kw in ["signin", "sign-in", "login", "auth"]): + return "login" + elif any(kw in pattern_lower for kw in ["payment", "pay"]): + return "payment" + elif any(kw in pattern_lower for kw in ["order", "place"]): + return "order" + else: + return "checkout" + + def _should_continue_after_checkout_detection( + self, + page_type: str | None, + current_step_index: int, + total_steps: int, + ) -> bool: + """ + Determine if the agent should trigger a replan after detecting a checkout page. + + Returns True if: + - We're on the last step or near the end + - The page type indicates we need more steps (e.g., login, payment) + """ + if not page_type: + return False + + # If we're on the last few steps and detected checkout, we may need to continue + steps_remaining = total_steps - current_step_index - 1 + + # Always continue if we hit a login page (authentication required) + if page_type == "login": + return True + + # Continue if we're on cart/checkout and have no more steps + if page_type in ("cart", "checkout", "payment", "order") and steps_remaining <= 1: + return True + + return False + + def _build_checkout_continuation_task( + self, + original_task: str, + page_type: str | None, + ) -> str: + """ + Build a task description for continuing from a checkout page. + + Args: + original_task: The original task description + page_type: Type of checkout page detected + + Returns: + A new task description that continues from the current page + """ + if page_type == "login": + return f"Continue from sign-in page: Complete sign-in if credentials are available, or proceed as guest if possible. Original task: {original_task}" + elif page_type == "cart": + return f"Continue from cart page: Proceed to checkout to complete the purchase. Original task: {original_task}" + elif page_type == "payment": + return f"Continue from payment page: Complete payment details and place order. Original task: {original_task}" + elif page_type == "order": + return f"Continue from order page: Review and place the order. Original task: {original_task}" + else: + return f"Continue checkout process from current page. Original task: {original_task}" + + async def _detect_auth_boundary( + self, + runtime: AgentRuntime, + ) -> bool: + """ + Detect if the current page is an authentication boundary. + + An auth boundary is a login/sign-in page where the agent cannot + proceed without credentials. This is a terminal state. + + Args: + runtime: The AgentRuntime for browser control + + Returns: + True if on an authentication page, False otherwise + """ + if not self.config.auth_boundary.enabled: + return False + + cfg = self.config.auth_boundary + + # Get current URL + try: + current_url = await runtime.get_url() if hasattr(runtime, "get_url") else None + except Exception: + current_url = None + + if not current_url: + return False + + url_lower = current_url.lower() + + # Check URL patterns + for pattern in cfg.url_patterns: + if pattern.lower() in url_lower: + if self.config.verbose: + print(f" [AUTH] Detected authentication boundary (URL: {pattern})", flush=True) + return True + + return False + async def _execute_step( self, step: PlanStep, @@ -1519,6 +2763,16 @@ async def _execute_step( pre_url = await runtime.get_url() if hasattr(runtime, "get_url") else None step_id = self._emit_step_start(step, step_index, pre_url) + if self.config.verbose: + print(f"\n[STEP {step.id}] {step.goal}", flush=True) + print(f" Action: {step.action}", flush=True) + if step.intent: + print(f" Intent: {step.intent}", flush=True) + if step.input: + print(f" Input: {step.input}", flush=True) + if step.target: + print(f" Target: {step.target}", flush=True) + llm_response: str | None = None action_taken: str | None = None used_vision = False @@ -1562,14 +2816,59 @@ async def _execute_step( return outcome - # Capture snapshot with escalation + # Pre-step auth boundary check: stop early if on signin page without credentials + # This prevents executing login steps that would fail or use fake credentials + if self.config.auth_boundary.enabled and self.config.auth_boundary.stop_on_auth: + is_auth_page = await self._detect_auth_boundary(runtime) + if is_auth_page: + if self.config.verbose: + print(f" [AUTH] Auth boundary detected at step start - stopping gracefully", flush=True) + + outcome = StepOutcome( + step_id=step.id, + goal=step.goal, + status=StepStatus.SUCCESS, # Graceful stop = success + action_taken="AUTH_BOUNDARY_REACHED", + verification_passed=True, + error=self.config.auth_boundary.auth_success_message, + duration_ms=int((time.time() - start_time) * 1000), + url_before=pre_url, + url_after=pre_url, + ) + + self._emit_step_end( + step_id=step_id, + step_index=step_index, + step=step, + outcome=outcome, + pre_url=pre_url, + post_url=pre_url, + llm_response=None, + snapshot_digest=None, + ) + + # Return special outcome that signals run completion + return outcome + + # Wait for page to stabilize before snapshot + if self.config.stabilize_enabled: + await runtime.stabilize() + + # Capture snapshot with escalation (includes scroll-after-escalation if enabled) ctx = await self._snapshot_with_escalation( runtime, goal=step.goal, capture_screenshot=self.config.trace_screenshots, + step=step, # Enable scroll-after-escalation to find elements outside viewport ) self._snapshot_context = ctx + if self.config.verbose: + elements = getattr(ctx.snapshot, "elements", []) or [] + print(f" [SNAPSHOT] Elements: {len(elements)}, URL: {getattr(ctx.snapshot, 'url', 'N/A')}", flush=True) + if ctx.requires_vision: + print(f" [SNAPSHOT] Vision required: {ctx.requires_vision}", flush=True) + # Emit snapshot trace self._emit_snapshot(ctx, step_id, step_index) @@ -1582,8 +2881,10 @@ async def _execute_step( # For now, fall through to standard executor # Determine element and action + original_action = step.action # Keep original plan action for submit logic action_type = step.action element_id: int | None = None + executor_text: str | None = None # Text from executor response (for TYPE actions) if action_type in ("CLICK", "TYPE_AND_SUBMIT"): # Try intent heuristics first (if available) @@ -1594,14 +2895,22 @@ async def _execute_step( if element_id is not None: used_heuristics = True action_taken = f"{action_type}({element_id}) [heuristic]" + if self.config.verbose: + print(f" [EXECUTOR] Using heuristic: {action_taken}", flush=True) else: # Fall back to LLM executor sys_prompt, user_prompt = build_executor_prompt( step.goal, step.intent, ctx.compact_representation, + input_text=step.input, ) + if self.config.verbose: + print("\n--- Compact Context (Snapshot) ---", flush=True) + print(ctx.compact_representation, flush=True) + print("--- End Compact Context ---\n", flush=True) + resp = self.executor.generate( sys_prompt, user_prompt, @@ -1610,6 +2919,9 @@ async def _execute_step( ) llm_response = resp.content + if self.config.verbose: + print(f" [EXECUTOR] LLM Response: {resp.content}", flush=True) + # Parse action parsed_action, parsed_args = self._parse_action(resp.content) action_type = parsed_action @@ -1617,9 +2929,13 @@ async def _execute_step( element_id = parsed_args[0] elif parsed_action == "TYPE" and len(parsed_args) >= 2: element_id = parsed_args[0] + executor_text = parsed_args[1] # Store text from executor response action_taken = f"{action_type}({', '.join(str(a) for a in parsed_args)})" + if self.config.verbose: + print(f" [EXECUTOR] Parsed action: {action_taken}", flush=True) + # Apply executor override if configured if self._executor_override and element_id is not None: try: @@ -1633,36 +2949,106 @@ async def _execute_step( if override_id is not None: element_id = override_id action_taken = f"{action_type}({element_id}) [override]" + if self.config.verbose: + print(f" [EXECUTOR] Override: {action_taken}", flush=True) else: error = f"Executor override rejected: {reason}" except Exception: pass # Ignore override errors elif action_type == "NAVIGATE": - action_taken = f"NAVIGATE({step.target})" + action_taken = f"NAVIGATE({step.target or 'current'})" elif action_type == "SCROLL": action_taken = "SCROLL(down)" else: action_taken = action_type + # If executor suggested SCROLL or selected a non-actionable element for CLICK, + # try scroll-to-find to locate the actual target element + if error is None and original_action == "CLICK": + should_scroll_to_find = False + reason = "" + + if action_type == "SCROLL": + # Executor explicitly said to scroll - element not visible + should_scroll_to_find = True + reason = "executor suggested SCROLL" + elif element_id is not None: + # Check if selected element is a non-actionable type (input fields) + selected_element = None + for el in elements: + if getattr(el, "id", None) == element_id: + selected_element = el + break + if selected_element: + role = (getattr(selected_element, "role", "") or "").lower() + if role in {"searchbox", "textbox", "combobox", "input", "textarea"}: + should_scroll_to_find = True + reason = f"selected input element ({role})" + + if should_scroll_to_find and self.config.scroll_to_find_enabled: + if self.config.verbose: + print(f" [SCROLL-TO-FIND] Triggering because {reason}", flush=True) + found_id, new_ctx = await self._scroll_to_find_element(runtime, step, step.goal) + if found_id is not None and new_ctx is not None: + element_id = found_id + ctx = new_ctx + action_type = "CLICK" + action_taken = f"CLICK({element_id}) [scroll-to-find]" + # Execute action via runtime if error is None: if action_type == "CLICK" and element_id is not None: await runtime.click(element_id) + if self.config.verbose: + print(f" [ACTION] CLICK({element_id})", flush=True) elif action_type == "TYPE" and element_id is not None: - text = step.input or "" + # Use text from executor response first, then fall back to step.input + text = executor_text or step.input or "" await runtime.type(element_id, text) + # If original plan action was TYPE_AND_SUBMIT, press Enter to submit + if original_action == "TYPE_AND_SUBMIT": + await runtime.press("Enter") + if self.config.verbose: + print(f" [ACTION] TYPE_AND_SUBMIT({element_id}, '{text}')", flush=True) + await runtime.stabilize() + elif self.config.verbose: + print(f" [ACTION] TYPE({element_id}, '{text[:30]}...')" if len(text) > 30 else f" [ACTION] TYPE({element_id}, '{text}')", flush=True) elif action_type == "TYPE_AND_SUBMIT" and element_id is not None: - text = step.input or "" + # Use text from executor response first, then step.input, then extract from goal + text = executor_text or step.input or "" + if not text: + # Try to extract from goal like "Search for Logitech mouse" + import re + match = re.search(r"[Ss]earch\s+(?:for\s+)?['\"]?([^'\"]+)['\"]?", step.goal) + if match: + text = match.group(1).strip() await runtime.type(element_id, text) await runtime.press("Enter") + if self.config.verbose: + print(f" [ACTION] TYPE_AND_SUBMIT({element_id}, '{text}')", flush=True) + # Wait for page to load after submit + await runtime.stabilize() elif action_type == "PRESS": key = "Enter" # Default await runtime.press(key) + if self.config.verbose: + print(f" [ACTION] PRESS({key})", flush=True) elif action_type == "SCROLL": await runtime.scroll("down") - elif action_type == "NAVIGATE" and step.target: - await runtime.goto(step.target) + if self.config.verbose: + print(f" [ACTION] SCROLL(down)", flush=True) + elif action_type == "NAVIGATE": + if step.target: + await runtime.goto(step.target) + if self.config.verbose: + print(f" [ACTION] NAVIGATE({step.target})", flush=True) + # Wait for page to load after navigation + await runtime.stabilize() + else: + # No target URL - we're already at the page, just verify + if self.config.verbose: + print(f" [ACTION] NAVIGATE(skip - already at page)", flush=True) elif action_type == "FINISH": pass # No action needed elif action_type not in ("CLICK", "TYPE", "TYPE_AND_SUBMIT") or element_id is None: @@ -1677,7 +3063,27 @@ async def _execute_step( # Run verifications if step.verify and error is None: + if self.config.verbose: + print(f" [VERIFY] Running {len(step.verify)} verification predicates...", flush=True) verification_passed = await self._verify_step(runtime, step) + if self.config.verbose: + print(f" [VERIFY] Predicate result: {'PASS' if verification_passed else 'FAIL'}", flush=True) + + # For successful CLICK actions, check if a modal/drawer appeared and dismiss it + # This handles cases like Amazon's "Add Protection" drawer after Add to Cart + # where verification passes (e.g., "Proceed to checkout" button exists in drawer) + # but we need to dismiss the overlay before continuing + if verification_passed and original_action == "CLICK" and self.config.modal.enabled: + try: + post_snap = await runtime.snapshot(emit_trace=False) + pre_elements = set(getattr(el, "id", 0) for el in (ctx.snapshot.elements or [])) + post_elements = set(getattr(el, "id", 0) for el in (post_snap.elements or [])) + new_elements = post_elements - pre_elements + if len(new_elements) >= self.config.modal.min_new_elements: + # Significant DOM change after CLICK - might be a modal/drawer + await self._attempt_modal_dismissal(runtime, post_snap) + except Exception: + pass # Ignore snapshot errors # If verification failed and we have optional substeps, try them if not verification_passed and step.optional_substeps: @@ -1689,8 +3095,60 @@ async def _execute_step( # Re-run verification after substeps if any(o.status == StepStatus.SUCCESS for o in substep_outcomes): verification_passed = await self._verify_step(runtime, step) + + # Fallback: For navigation-causing actions, if URL changed significantly, + # consider the action successful even if predicate verification failed. + # This handles cases where local LLMs generate imprecise predicates. + if not verification_passed and original_action in ("TYPE_AND_SUBMIT", "CLICK"): + current_url = await runtime.get_url() if hasattr(runtime, "get_url") else None + if current_url and pre_url and current_url != pre_url: + # URL changed - the action likely achieved navigation + if self.config.verbose: + print(f" [VERIFY] Predicate failed but URL changed: {pre_url} -> {current_url}", flush=True) + print(f" [VERIFY] Accepting {original_action} as successful (URL change fallback)", flush=True) + verification_passed = True + elif original_action == "CLICK" and error is None and element_id is not None: + # For CLICK actions that don't change URL, check if DOM changed + # (e.g., modal appeared, cart drawer opened) + # But first verify we clicked an actionable element, not an input field + clicked_element = None + for el in (ctx.snapshot.elements or []): + if getattr(el, "id", None) == element_id: + clicked_element = el + break + + # Don't accept DOM change fallback for input/search elements + # (clicking them causes focus changes but doesn't complete actions) + clicked_role = (getattr(clicked_element, "role", "") or "").lower() if clicked_element else "" + input_roles = {"searchbox", "textbox", "combobox", "input", "textarea"} + if clicked_role in input_roles: + if self.config.verbose: + print(f" [VERIFY] Clicked input element ({clicked_role}), not accepting DOM change fallback", flush=True) + else: + try: + post_snap = await runtime.snapshot(emit_trace=False) + pre_elements = set(getattr(el, "id", 0) for el in (ctx.snapshot.elements or [])) + post_elements = set(getattr(el, "id", 0) for el in (post_snap.elements or [])) + new_elements = post_elements - pre_elements + if len(new_elements) >= self.config.modal.min_new_elements: # Significant DOM change + if self.config.verbose: + print(f" [VERIFY] Predicate failed but DOM changed ({len(new_elements)} new elements)", flush=True) + print(f" [VERIFY] Accepting CLICK as successful (DOM change fallback)", flush=True) + verification_passed = True + + # Attempt modal dismissal if enabled + # This handles blocking overlays like product protection drawers + if self.config.modal.enabled: + await self._attempt_modal_dismissal(runtime, post_snap) + except Exception: + pass # Ignore snapshot errors else: verification_passed = error is None + if self.config.verbose: + if not step.verify: + print(f" [VERIFY] No predicates defined, using error check: {'PASS' if verification_passed else 'FAIL'}", flush=True) + elif error: + print(f" [VERIFY] Skipped due to error: {error}", flush=True) # Track successful URL for recovery if verification_passed and self.config.recovery.track_successful_urls: @@ -1749,6 +3207,12 @@ async def _verify_step( pred = build_predicate(verify_spec) label = f"verify_{step.id}" + # Log what predicate we're checking + if self.config.verbose: + import json + spec_str = json.dumps(verify_spec, default=str) if isinstance(verify_spec, dict) else str(verify_spec) + print(f" [VERIFY] Checking predicate: {spec_str[:100]}...", flush=True) + # Use runtime's check with eventually for required verifications if step.required: ok = await runtime.check(pred, label=label, required=True).eventually( @@ -1760,7 +3224,12 @@ async def _verify_step( ok = runtime.assert_(pred, label=label, required=False) if not ok: + if self.config.verbose: + print(f" [VERIFY] Predicate FAILED", flush=True) return False + else: + if self.config.verbose: + print(f" [VERIFY] Predicate PASSED", flush=True) return True @@ -1771,7 +3240,7 @@ async def _verify_step( async def run( self, runtime: AgentRuntime, - task: str, + task: AutomationTask | str, *, start_url: str | None = None, run_id: str | None = None, @@ -1781,42 +3250,150 @@ async def run( Args: runtime: AgentRuntime instance - task: Task description - start_url: Starting URL (optional) + task: AutomationTask instance or task description string + start_url: Starting URL (only needed if task is a string) run_id: Run ID for tracing (optional) Returns: RunOutcome with execution results + + Example with AutomationTask: + task = AutomationTask( + task_id="purchase-laptop", + starting_url="https://amazon.com", + task="Find a laptop under $1000 and add to cart", + category=TaskCategory.TRANSACTION, + ) + result = await agent.run(runtime, task) + + Example with string: + result = await agent.run( + runtime, + "Search for laptops", + start_url="https://amazon.com", + ) """ - self._run_id = run_id or str(uuid.uuid4()) + # Normalize task to AutomationTask + if isinstance(task, str): + if start_url is None: + raise ValueError("start_url is required when task is a string") + automation_task = AutomationTask( + task_id=run_id or str(uuid.uuid4()), + starting_url=start_url, + task=task, + ) + else: + automation_task = task + if start_url is None: + start_url = automation_task.starting_url + + self._current_task = automation_task + self._run_id = run_id or automation_task.task_id self._replans_used = 0 self._vision_calls = 0 start_time = time.time() + # Initialize recovery state if enabled + if automation_task.enable_recovery: + self._recovery_state = RecoveryState( + max_recovery_attempts=automation_task.max_recovery_attempts, + ) + else: + self._recovery_state = None + + # Initialize composable heuristics + self._composable_heuristics = ComposableHeuristics( + static_heuristics=self._intent_heuristics, + task_category=automation_task.category, + ) + + # Use task description for planning + task_description = automation_task.task + # Emit run start - self._emit_run_start(task, start_url) + self._emit_run_start(task_description, start_url) step_outcomes: list[StepOutcome] = [] error: str | None = None try: # Generate plan - plan = await self.plan(task, start_url=start_url) + plan = await self.plan(task_description, start_url=start_url) # Execute steps step_index = 0 while step_index < len(plan.steps): step = plan.steps[step_index] + # Set step-specific heuristic hints + if self._composable_heuristics and step.heuristic_hints: + self._composable_heuristics.set_step_hints(step.heuristic_hints) + outcome = await self._execute_step(step, runtime, step_index) step_outcomes.append(outcome) + # Check if auth boundary was reached at step start (graceful termination) + if outcome.action_taken == "AUTH_BOUNDARY_REACHED": + if self.config.verbose: + print(f" [AUTH] Run completed at authentication boundary", flush=True) + break # Graceful termination - no error + + # Record checkpoint on success (for recovery) + if outcome.status in (StepStatus.SUCCESS, StepStatus.VISION_FALLBACK): + if self._recovery_state is not None and outcome.url_after: + self._recovery_state.record_checkpoint( + url=outcome.url_after, + step_index=step_index, + snapshot_digest=hashlib.sha256( + (outcome.url_after or "").encode() + ).hexdigest()[:16], + predicates_passed=[], + ) + # Handle failure if outcome.status == StepStatus.FAILED and step.required: + # Check if we've reached an authentication boundary + # This is a graceful terminal state - agent did all it could + if self.config.auth_boundary.enabled: + is_auth_page = await self._detect_auth_boundary(runtime) + if is_auth_page and self.config.auth_boundary.stop_on_auth: + if self.config.verbose: + print(f" [AUTH] Stopping at authentication boundary", flush=True) + # Mark as success since we reached the auth page successfully + # The "failure" is just that we can't proceed without credentials + error = None # Clear any error + # Update the outcome to indicate auth boundary + outcome = StepOutcome( + step_id=outcome.step_id, + goal=outcome.goal, + status=StepStatus.SUCCESS, # Treat as success + action_taken=outcome.action_taken, + verification_passed=True, + used_vision=outcome.used_vision, + error=self.config.auth_boundary.auth_success_message, + duration_ms=outcome.duration_ms, + url_before=outcome.url_before, + url_after=outcome.url_after, + ) + step_outcomes[-1] = outcome # Replace the failed outcome + break # Stop execution gracefully + + # Try recovery first if available + if self._recovery_state and self._recovery_state.can_recover(): + recovered = await self._attempt_recovery(runtime) + if recovered: + # Resume from checkpoint step + checkpoint = self._recovery_state.current_recovery_target + if checkpoint: + step_index = checkpoint.step_index + 1 + self._recovery_state.clear_recovery_target() + continue + + # Try replanning if self._replans_used < self.config.retry.max_replans: try: plan = await self.replan( - task, + task_description, step, outcome.error or "verification_failed", ) @@ -1835,6 +3412,41 @@ async def run( step_index += 1 + # Check for checkout page detection at end of plan + # This handles cases where modal dismissal leaves us on a checkout page + # but the plan has no more steps + if ( + self.config.checkout.enabled + and self.config.checkout.trigger_replan + and step_index >= len(plan.steps) # Just finished last step + and outcome.status in (StepStatus.SUCCESS, StepStatus.VISION_FALLBACK) + and self._replans_used < self.config.retry.max_replans + ): + # Check if we're on a checkout-relevant page + is_checkout, page_type = await self._detect_checkout_page(runtime) + if is_checkout and self._should_continue_after_checkout_detection( + page_type, step_index - 1, len(plan.steps) + ): + if self.config.verbose: + print(f" [CHECKOUT] Triggering replan for {page_type} page", flush=True) + try: + # Replan with context about where we are + continuation_task = self._build_checkout_continuation_task( + task_description, page_type + ) + plan = await self.plan(continuation_task, start_url=None) + step_index = 0 # Start from beginning of new plan + self._replans_used += 1 + continue + except Exception as e: + if self.config.verbose: + print(f" [CHECKOUT] Replan failed: {e}", flush=True) + # Continue without replanning + + # Clear step hints for next iteration + if self._composable_heuristics: + self._composable_heuristics.clear_step_hints() + except Exception as e: error = str(e) @@ -1847,7 +3459,7 @@ async def run( run_outcome = RunOutcome( run_id=self._run_id, - task=task, + task=task_description, success=success, steps_completed=len(step_outcomes), steps_total=len(self._current_plan.steps) if self._current_plan else 0, diff --git a/predicate/agents/recovery.py b/predicate/agents/recovery.py new file mode 100644 index 0000000..e9cbf1c --- /dev/null +++ b/predicate/agents/recovery.py @@ -0,0 +1,233 @@ +""" +Recovery: Checkpoint and rollback mechanisms for automation recovery. + +This module provides state tracking and recovery mechanisms for when +automation gets off-track. Key concepts: + +- RecoveryCheckpoint: A snapshot of known-good state (URL, step, digest) +- RecoveryState: Tracks checkpoints and manages recovery attempts + +Recovery flow: +1. After each successful step verification, record a checkpoint +2. If verification fails repeatedly, attempt recovery to last checkpoint +3. Navigate back to checkpoint URL and re-verify +4. If recovery succeeds, resume from checkpoint step +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from datetime import datetime + + +@dataclass +class RecoveryCheckpoint: + """ + Checkpoint for rollback recovery. + + Created after each successful step verification to enable rollback + if subsequent steps fail. + + Attributes: + url: The URL at this checkpoint + step_index: The step index that was just completed (0-indexed) + snapshot_digest: Hash of the snapshot for state verification + timestamp: When the checkpoint was created + predicates_passed: Labels of predicates that passed at this checkpoint + """ + + url: str + step_index: int + snapshot_digest: str + timestamp: datetime + predicates_passed: list[str] = field(default_factory=list) + + def __post_init__(self) -> None: + # Ensure predicates_passed is a list + if self.predicates_passed is None: + self.predicates_passed = [] + + +@dataclass +class RecoveryState: + """ + Tracks recovery state for rollback mechanism. + + Checkpoints are created after each successful step verification. + Recovery can be attempted when steps fail repeatedly. + + Attributes: + checkpoints: List of recorded checkpoints (most recent last) + recovery_attempts_used: Number of recovery attempts consumed + max_recovery_attempts: Maximum allowed recovery attempts + current_recovery_target: The checkpoint being recovered to (if any) + max_checkpoints: Maximum checkpoints to retain (bounds memory) + + Example: + state = RecoveryState(max_recovery_attempts=2) + + # After successful step + state.record_checkpoint( + url="https://shop.com/cart", + step_index=2, + snapshot_digest="abc123", + predicates_passed=["url_contains('/cart')"], + ) + + # On repeated failure + if state.can_recover(): + checkpoint = state.consume_recovery_attempt() + # Navigate to checkpoint.url and resume + """ + + checkpoints: list[RecoveryCheckpoint] = field(default_factory=list) + recovery_attempts_used: int = 0 + max_recovery_attempts: int = 2 + current_recovery_target: RecoveryCheckpoint | None = None + max_checkpoints: int = 10 + + def record_checkpoint( + self, + url: str, + step_index: int, + snapshot_digest: str, + predicates_passed: list[str] | None = None, + ) -> RecoveryCheckpoint: + """ + Record a successful checkpoint. + + Called after step verification passes to enable future rollback. + + Args: + url: Current page URL + step_index: Index of the step that just completed + snapshot_digest: Hash of current snapshot for verification + predicates_passed: Labels of predicates that passed + + Returns: + The created RecoveryCheckpoint + """ + checkpoint = RecoveryCheckpoint( + url=url, + step_index=step_index, + snapshot_digest=snapshot_digest, + timestamp=datetime.now(), + predicates_passed=predicates_passed or [], + ) + self.checkpoints.append(checkpoint) + + # Keep only last N checkpoints to bound memory + if len(self.checkpoints) > self.max_checkpoints: + self.checkpoints = self.checkpoints[-self.max_checkpoints :] + + return checkpoint + + def get_recovery_target(self) -> RecoveryCheckpoint | None: + """ + Get the most recent checkpoint for recovery. + + Returns: + Most recent RecoveryCheckpoint, or None if no checkpoints exist + """ + if not self.checkpoints: + return None + return self.checkpoints[-1] + + def get_checkpoint_at_step(self, step_index: int) -> RecoveryCheckpoint | None: + """ + Get checkpoint at a specific step index. + + Args: + step_index: The step index to find + + Returns: + RecoveryCheckpoint at that step, or None if not found + """ + for checkpoint in reversed(self.checkpoints): + if checkpoint.step_index == step_index: + return checkpoint + return None + + def get_checkpoint_before_step(self, step_index: int) -> RecoveryCheckpoint | None: + """ + Get the most recent checkpoint before a given step. + + Args: + step_index: The step index to find checkpoint before + + Returns: + Most recent checkpoint with step_index < given index, or None + """ + for checkpoint in reversed(self.checkpoints): + if checkpoint.step_index < step_index: + return checkpoint + return None + + def can_recover(self) -> bool: + """ + Check if recovery is still possible. + + Returns: + True if recovery attempts remain and checkpoints exist + """ + return ( + self.recovery_attempts_used < self.max_recovery_attempts + and len(self.checkpoints) > 0 + ) + + def consume_recovery_attempt(self) -> RecoveryCheckpoint | None: + """ + Consume a recovery attempt and return target checkpoint. + + Increments recovery_attempts_used and sets current_recovery_target. + + Returns: + The checkpoint to recover to, or None if recovery not possible + """ + if not self.can_recover(): + return None + + self.recovery_attempts_used += 1 + self.current_recovery_target = self.get_recovery_target() + return self.current_recovery_target + + def clear_recovery_target(self) -> None: + """Clear the current recovery target after recovery completes.""" + self.current_recovery_target = None + + def reset(self) -> None: + """Reset recovery state for a new run.""" + self.checkpoints.clear() + self.recovery_attempts_used = 0 + self.current_recovery_target = None + + def pop_checkpoint(self) -> RecoveryCheckpoint | None: + """ + Remove and return the most recent checkpoint. + + Useful when recovery fails and we want to try an earlier checkpoint. + + Returns: + The removed checkpoint, or None if no checkpoints exist + """ + if not self.checkpoints: + return None + return self.checkpoints.pop() + + @property + def last_successful_url(self) -> str | None: + """Get the URL of the most recent successful checkpoint.""" + if not self.checkpoints: + return None + return self.checkpoints[-1].url + + @property + def last_successful_step(self) -> int | None: + """Get the step index of the most recent successful checkpoint.""" + if not self.checkpoints: + return None + return self.checkpoints[-1].step_index + + def __len__(self) -> int: + """Return number of checkpoints.""" + return len(self.checkpoints) diff --git a/predicate/backends/sentience_context.py b/predicate/backends/sentience_context.py index d24f8ae..4298b1d 100644 --- a/predicate/backends/sentience_context.py +++ b/predicate/backends/sentience_context.py @@ -382,8 +382,8 @@ def rank_sort_key(el: Element) -> tuple[float, float, float, float]: # DG: 1 if dominant group, else 0 dg_flag = "1" if in_dg else "0" - # href: short token (domain or last path segment, or blank) - href = self._compress_href(el.href) + # href: meaningful path portion (e.g., "/dp/B0FC5SJNQX" for products) + href = self._compress_href(el.href, snapshot.url) # Ultra-compact format: ID|role|text|imp|is_primary|docYq|ord|DG|href line = f"{el.id}|{role}|{name}|{importance}|{is_primary_flag}|{doc_yq}|{ord_val}|{dg_flag}|{href}" @@ -438,37 +438,91 @@ async def _wait_for_extension( logger.warning("Sentience extension not ready after %dms timeout", timeout_ms) return False - def _compress_href(self, href: str | None) -> str: + def _compress_href(self, href: str | None, page_url: str | None = None) -> str: """ - Compress href into a short token for minimal tokens. + Compress href into a meaningful short token for minimal tokens. + + Extracts the path portion that's useful for navigation decisions, + stripping the base domain if it matches the current page. Args: href: Full URL or None + page_url: Current page URL (to detect same-domain links) Returns: - Short token (domain second-level or last path segment) + Compressed href with meaningful path info (e.g., "B0FC5SJNQX" or "mouse") """ if not href: return "" try: parsed = urlparse(href) - if parsed.netloc: - # Extract second-level domain (e.g., "github" from "github.com") + page_parsed = urlparse(page_url) if page_url else None + + # Check if same domain as current page + is_same_domain = ( + page_parsed + and parsed.netloc + and (parsed.netloc == page_parsed.netloc + or parsed.netloc.endswith("." + page_parsed.netloc.split(".")[-2] + "." + page_parsed.netloc.split(".")[-1]) + if len(page_parsed.netloc.split(".")) >= 2 else False) + ) + + if parsed.netloc and not is_same_domain: + # External link - show domain (truncate to 10 chars) parts = parsed.netloc.split(".") if len(parts) >= 2: return parts[-2][:10] return parsed.netloc[:10] - elif parsed.path: - # Use last path segment - segments = [s for s in parsed.path.split("/") if s] - if segments: - return segments[-1][:10] - return "item" + + # Same domain or relative link - extract meaningful path + path = parsed.path or "" + + # For product pages, extract key identifiers (just the ID, not the path prefix) + # Amazon: /dp/XXXXX, /gp/product/XXXXX + # Generic: /product/XXX, /item/XXX, /p/XXX + import re + product_patterns = [ + r"/dp/([A-Z0-9]+)", # Amazon product + r"/gp/product/([A-Z0-9]+)", # Amazon alt + r"/product/([^/]+)", # Generic product + r"/item/([^/]+)", # Generic item + r"/p/([^/]+)", # Short product + ] + for pattern in product_patterns: + match = re.search(pattern, path, re.IGNORECASE) + if match: + return match.group(1)[:30] + + # For search pages, extract search indicator + if "/s?" in href or "/search?" in href or "?q=" in href or "?k=" in href: + query = parsed.query or "" + # Extract search term if present + for param in ["k", "q", "query", "search"]: + match = re.search(rf"{param}=([^&]+)", query) + if match: + term = match.group(1)[:15] + return f"/s?{param}={term}" + return "/search" + + # For cart/checkout pages + if "/cart" in path.lower(): + return "/cart" + if "/checkout" in path.lower(): + return "/checkout" + + # Fallback: use last meaningful path segment only (no leading slash) + segments = [s for s in path.split("/") if s and len(s) > 1] + if segments: + # Return only the last segment (max 30 chars) + return segments[-1][:30] + + return path[:30] if path else "" + except Exception: pass - return "item" + return "" # --------------------------------------------------------------------------- diff --git a/predicate/llm_provider.py b/predicate/llm_provider.py index fe84e65..cfe02f6 100644 --- a/predicate/llm_provider.py +++ b/predicate/llm_provider.py @@ -186,11 +186,18 @@ def generate( temperature: Sampling temperature (0.0 = deterministic, 1.0 = creative) max_tokens: Maximum tokens to generate json_mode: Enable JSON response format (requires model support) - **kwargs: Additional OpenAI API parameters + **kwargs: Additional OpenAI API parameters (max_new_tokens is mapped to max_tokens) Returns: LLMResponse object """ + # Handle max_new_tokens -> max_tokens mapping for cross-provider compatibility + if "max_new_tokens" in kwargs: + if max_tokens is None: + max_tokens = kwargs.pop("max_new_tokens") + else: + kwargs.pop("max_new_tokens") # max_tokens takes precedence + messages = [] if system_prompt: messages.append({"role": "system", "content": system_prompt}) @@ -415,11 +422,19 @@ def generate( user_prompt: User query temperature: Sampling temperature max_tokens: Maximum tokens to generate (required by Anthropic) - **kwargs: Additional Anthropic API parameters + **kwargs: Additional Anthropic API parameters (max_new_tokens is mapped to max_tokens) Returns: LLMResponse object """ + # Handle max_new_tokens -> max_tokens mapping for cross-provider compatibility + if "max_new_tokens" in kwargs: + # Use max_new_tokens value if max_tokens is still at default + if max_tokens == 1024: + max_tokens = kwargs.pop("max_new_tokens") + else: + kwargs.pop("max_new_tokens") # Explicit max_tokens takes precedence + # Build API parameters api_params = { "model": self._model_name, diff --git a/tests/test_automation_task.py b/tests/test_automation_task.py new file mode 100644 index 0000000..4502361 --- /dev/null +++ b/tests/test_automation_task.py @@ -0,0 +1,480 @@ +""" +Unit tests for AutomationTask and related models. +""" + +import pytest +from dataclasses import dataclass + +from predicate.agents import ( + AutomationTask, + ExtractionSpec, + SuccessCriteria, + TaskCategory, + HeuristicHint, + RecoveryCheckpoint, + RecoveryState, + ComposableHeuristics, + COMMON_HINTS, + get_common_hint, +) + + +class TestAutomationTask: + """Tests for AutomationTask dataclass.""" + + def test_basic_creation(self): + """Test creating a basic AutomationTask.""" + task = AutomationTask( + task_id="test-001", + starting_url="https://example.com", + task="Click the login button", + ) + assert task.task_id == "test-001" + assert task.starting_url == "https://example.com" + assert task.task == "Click the login button" + assert task.category is None + assert task.max_steps == 50 + assert task.enable_recovery is True + + def test_with_category(self): + """Test creating AutomationTask with category.""" + task = AutomationTask( + task_id="test-002", + starting_url="https://shop.com", + task="Add item to cart", + category=TaskCategory.TRANSACTION, + ) + assert task.category == TaskCategory.TRANSACTION + + def test_with_success_criteria(self): + """Test creating AutomationTask with success criteria.""" + task = AutomationTask( + task_id="test-003", + starting_url="https://shop.com", + task="Complete checkout", + success_criteria=SuccessCriteria( + predicates=[ + {"predicate": "url_contains", "args": ["/confirmation"]}, + {"predicate": "exists", "args": [".order-number"]}, + ], + require_all=True, + ), + ) + assert task.success_criteria is not None + assert len(task.success_criteria.predicates) == 2 + assert task.success_criteria.require_all is True + + def test_with_extraction_spec(self): + """Test creating AutomationTask with extraction spec.""" + task = AutomationTask( + task_id="test-004", + starting_url="https://shop.com/product", + task="Extract product details", + category=TaskCategory.EXTRACTION, + extraction_spec=ExtractionSpec( + output_schema={"name": "str", "price": "float"}, + format="json", + ), + ) + assert task.extraction_spec is not None + assert task.extraction_spec.format == "json" + assert task.extraction_spec.output_schema == {"name": "str", "price": "float"} + + def test_from_string(self): + """Test creating AutomationTask from string.""" + task = AutomationTask.from_string( + "Search for laptops", + "https://amazon.com", + category=TaskCategory.SEARCH, + ) + assert task.task == "Search for laptops" + assert task.starting_url == "https://amazon.com" + assert task.category == TaskCategory.SEARCH + assert task.task_id is not None + + def test_with_success_criteria_method(self): + """Test with_success_criteria method.""" + task = AutomationTask( + task_id="test-005", + starting_url="https://shop.com", + task="Checkout", + ) + task_with_criteria = task.with_success_criteria( + {"predicate": "url_contains", "args": ["/success"]}, + {"predicate": "exists", "args": [".confirmation"]}, + ) + assert task_with_criteria.success_criteria is not None + assert len(task_with_criteria.success_criteria.predicates) == 2 + # Original task unchanged + assert task.success_criteria is None + + def test_with_extraction_method(self): + """Test with_extraction method.""" + task = AutomationTask( + task_id="test-006", + starting_url="https://shop.com", + task="Get product info", + ) + task_with_extraction = task.with_extraction( + output_schema={"name": "str"}, + format="json", + ) + assert task_with_extraction.extraction_spec is not None + assert task_with_extraction.category == TaskCategory.EXTRACTION + + def test_from_webbench_task(self): + """Test converting from WebBenchTask-like object.""" + @dataclass(frozen=True) + class MockWebBenchTask: + id: str + starting_url: str + task: str + category: str | None = None + + wb_task = MockWebBenchTask( + id="wb-001", + starting_url="https://example.com", + task="Read the page title", + category="READ", + ) + automation_task = AutomationTask.from_webbench_task(wb_task) + + assert automation_task.task_id == "wb-001" + assert automation_task.starting_url == "https://example.com" + assert automation_task.task == "Read the page title" + assert automation_task.category == TaskCategory.EXTRACTION + assert automation_task.extraction_spec is not None + + +class TestHeuristicHint: + """Tests for HeuristicHint model.""" + + def test_basic_creation(self): + """Test creating a basic HeuristicHint.""" + hint = HeuristicHint( + intent_pattern="add_to_cart", + text_patterns=["add to cart", "add to bag"], + role_filter=["button"], + priority=10, + ) + assert hint.intent_pattern == "add_to_cart" + assert len(hint.text_patterns) == 2 + assert hint.role_filter == ["button"] + assert hint.priority == 10 + + def test_matches_intent(self): + """Test intent matching.""" + hint = HeuristicHint( + intent_pattern="checkout", + text_patterns=["checkout"], + role_filter=["button"], + ) + assert hint.matches_intent("checkout") is True + assert hint.matches_intent("proceed_to_checkout") is True + assert hint.matches_intent("CHECKOUT") is True + assert hint.matches_intent("login") is False + + def test_matches_element(self): + """Test element matching.""" + hint = HeuristicHint( + intent_pattern="add_to_cart", + text_patterns=["add to cart"], + role_filter=["button"], + ) + + @dataclass + class MockElement: + id: int + role: str + text: str + + matching = MockElement(id=1, role="button", text="Add to Cart") + non_matching_role = MockElement(id=2, role="link", text="Add to Cart") + non_matching_text = MockElement(id=3, role="button", text="Remove") + + assert hint.matches_element(matching) is True + assert hint.matches_element(non_matching_role) is False + assert hint.matches_element(non_matching_text) is False + + def test_common_hints(self): + """Test COMMON_HINTS dictionary.""" + assert "add_to_cart" in COMMON_HINTS + assert "checkout" in COMMON_HINTS + assert "login" in COMMON_HINTS + assert "submit" in COMMON_HINTS + + def test_get_common_hint(self): + """Test get_common_hint function.""" + hint = get_common_hint("add_to_cart") + assert hint is not None + assert hint.intent_pattern == "add_to_cart" + + # Test with dashes/spaces + hint2 = get_common_hint("add-to-cart") + assert hint2 is not None + + # Test non-existent + hint3 = get_common_hint("nonexistent_pattern") + assert hint3 is None + + +class TestRecoveryState: + """Tests for RecoveryState and RecoveryCheckpoint.""" + + def test_basic_creation(self): + """Test creating a basic RecoveryState.""" + state = RecoveryState(max_recovery_attempts=3) + assert state.max_recovery_attempts == 3 + assert state.recovery_attempts_used == 0 + assert len(state.checkpoints) == 0 + + def test_record_checkpoint(self): + """Test recording checkpoints.""" + state = RecoveryState() + checkpoint = state.record_checkpoint( + url="https://shop.com/cart", + step_index=2, + snapshot_digest="abc123", + predicates_passed=["url_contains('/cart')"], + ) + + assert checkpoint.url == "https://shop.com/cart" + assert checkpoint.step_index == 2 + assert len(state) == 1 + + def test_get_recovery_target(self): + """Test getting recovery target.""" + state = RecoveryState() + + # No checkpoints + assert state.get_recovery_target() is None + + # Add checkpoints + state.record_checkpoint("https://a.com", 0, "a", []) + state.record_checkpoint("https://b.com", 1, "b", []) + + target = state.get_recovery_target() + assert target is not None + assert target.url == "https://b.com" + + def test_can_recover(self): + """Test can_recover logic.""" + state = RecoveryState(max_recovery_attempts=2) + + # No checkpoints + assert state.can_recover() is False + + # Add checkpoint + state.record_checkpoint("https://a.com", 0, "a", []) + assert state.can_recover() is True + + # Consume attempts + state.consume_recovery_attempt() + assert state.can_recover() is True + + state.consume_recovery_attempt() + assert state.can_recover() is False + + def test_consume_recovery_attempt(self): + """Test consuming recovery attempts.""" + state = RecoveryState(max_recovery_attempts=2) + state.record_checkpoint("https://a.com", 0, "a", []) + + checkpoint = state.consume_recovery_attempt() + assert checkpoint is not None + assert state.recovery_attempts_used == 1 + assert state.current_recovery_target == checkpoint + + def test_max_checkpoints_limit(self): + """Test that checkpoints are bounded.""" + state = RecoveryState(max_checkpoints=3) + + for i in range(5): + state.record_checkpoint(f"https://a.com/{i}", i, str(i), []) + + # Should only keep last 3 + assert len(state) == 3 + assert state.checkpoints[0].url == "https://a.com/2" + + def test_reset(self): + """Test resetting state.""" + state = RecoveryState() + state.record_checkpoint("https://a.com", 0, "a", []) + state.consume_recovery_attempt() + + state.reset() + + assert len(state) == 0 + assert state.recovery_attempts_used == 0 + assert state.current_recovery_target is None + + def test_last_successful_url(self): + """Test last_successful_url property.""" + state = RecoveryState() + assert state.last_successful_url is None + + state.record_checkpoint("https://a.com", 0, "a", []) + assert state.last_successful_url == "https://a.com" + + +class TestComposableHeuristics: + """Tests for ComposableHeuristics.""" + + def test_basic_creation(self): + """Test creating ComposableHeuristics.""" + heuristics = ComposableHeuristics() + assert heuristics.priority_order() is not None + + def test_with_task_category(self): + """Test with task category.""" + heuristics = ComposableHeuristics( + task_category=TaskCategory.TRANSACTION, + ) + + @dataclass + class MockElement: + id: int + role: str + text: str + + elements = [ + MockElement(id=1, role="button", text="Add to Cart"), + MockElement(id=2, role="link", text="Home"), + ] + + # Category defaults should find transaction-related buttons + element_id = heuristics.find_element_for_intent( + "add_item", + elements, + "https://shop.com", + "Add item to cart", + ) + assert element_id == 1 + + def test_set_step_hints(self): + """Test setting step hints.""" + heuristics = ComposableHeuristics() + + heuristics.set_step_hints([ + {"intent_pattern": "custom", "text_patterns": ["custom text"], "role_filter": ["button"]}, + ]) + + # Verify hints are set + order = heuristics.priority_order() + assert "custom" in order + + def test_clear_step_hints(self): + """Test clearing step hints.""" + heuristics = ComposableHeuristics() + heuristics.set_step_hints([ + {"intent_pattern": "test", "text_patterns": ["test"]}, + ]) + + heuristics.clear_step_hints() + + # Should not have "test" in priority anymore (unless it's a common hint) + order = heuristics.priority_order() + # Common hints should still be there + assert "add_to_cart" in order + + def test_hint_priority_over_static(self): + """Test that planner hints take priority.""" + class StaticHeuristics: + def find_element_for_intent(self, intent, elements, url, goal): + return 999 # Always return 999 + + def priority_order(self): + return ["static"] + + heuristics = ComposableHeuristics( + static_heuristics=StaticHeuristics(), + ) + + @dataclass + class MockElement: + id: int + role: str + text: str + + elements = [ + MockElement(id=1, role="button", text="Custom Button"), + MockElement(id=2, role="button", text="Other"), + ] + + # Set hints that match element 1 + heuristics.set_step_hints([ + HeuristicHint( + intent_pattern="custom", + text_patterns=["custom button"], + role_filter=["button"], + priority=100, + ), + ]) + + # Hints should be tried first + element_id = heuristics.find_element_for_intent( + "custom_action", + elements, + "https://example.com", + "Do custom action", + ) + assert element_id == 1 # Not 999 from static heuristics + + +class TestTaskCategory: + """Tests for TaskCategory enum.""" + + def test_all_categories_exist(self): + """Test all expected categories exist.""" + assert TaskCategory.NAVIGATION == "navigation" + assert TaskCategory.SEARCH == "search" + assert TaskCategory.FORM_FILL == "form_fill" + assert TaskCategory.EXTRACTION == "extraction" + assert TaskCategory.TRANSACTION == "transaction" + assert TaskCategory.VERIFICATION == "verification" + + def test_category_values(self): + """Test category string values.""" + assert TaskCategory.TRANSACTION.value == "transaction" + assert str(TaskCategory.EXTRACTION) == "TaskCategory.EXTRACTION" + + +class TestExtractionSpec: + """Tests for ExtractionSpec model.""" + + def test_defaults(self): + """Test default values.""" + spec = ExtractionSpec() + assert spec.output_schema is None + assert spec.target_selectors == [] + assert spec.format == "json" + assert spec.require_evidence is True + + def test_with_schema(self): + """Test with schema.""" + spec = ExtractionSpec( + output_schema={"name": "str", "price": "float"}, + format="json", + ) + assert spec.output_schema == {"name": "str", "price": "float"} + + +class TestSuccessCriteria: + """Tests for SuccessCriteria model.""" + + def test_defaults(self): + """Test default values.""" + criteria = SuccessCriteria() + assert criteria.predicates == [] + assert criteria.require_all is True + + def test_with_predicates(self): + """Test with predicates.""" + criteria = SuccessCriteria( + predicates=[ + {"predicate": "url_contains", "args": ["/success"]}, + ], + require_all=False, + ) + assert len(criteria.predicates) == 1 + assert criteria.require_all is False diff --git a/tests/unit/test_planner_executor_agent.py b/tests/unit/test_planner_executor_agent.py index 02a579b..8393821 100644 --- a/tests/unit/test_planner_executor_agent.py +++ b/tests/unit/test_planner_executor_agent.py @@ -17,18 +17,89 @@ from typing import Any from predicate.agents.planner_executor_agent import ( + AuthBoundaryConfig, ExecutorOverride, IntentHeuristics, + ModalDismissalConfig, Plan, PlannerExecutorConfig, PlanStep, PredicateSpec, RecoveryNavigationConfig, + SnapshotEscalationConfig, + build_executor_prompt, normalize_plan, validate_plan_smoothness, ) +# --------------------------------------------------------------------------- +# Test build_executor_prompt +# --------------------------------------------------------------------------- + + +class TestBuildExecutorPrompt: + """Tests for the build_executor_prompt function.""" + + def test_basic_prompt_structure(self) -> None: + sys_prompt, user_prompt = build_executor_prompt( + goal="Click the submit button", + intent=None, + compact_context="123|button|Submit|100|1|0|-|0|", + ) + assert "CLICK()" in sys_prompt + assert "TYPE(" in sys_prompt + assert "Goal: Click the submit button" in user_prompt + assert "123|button|Submit" in user_prompt + + def test_includes_intent_when_provided(self) -> None: + sys_prompt, user_prompt = build_executor_prompt( + goal="Click on product", + intent="Click the first product link", + compact_context="456|link|Product|100|1|0|-|0|", + ) + assert "Intent: Click the first product link" in user_prompt + + def test_no_intent_line_when_none(self) -> None: + sys_prompt, user_prompt = build_executor_prompt( + goal="Click button", + intent=None, + compact_context="789|button|OK|100|1|0|-|0|", + ) + assert "Intent:" not in user_prompt + + def test_includes_input_text_when_provided(self) -> None: + """Input text should be included for TYPE_AND_SUBMIT actions.""" + sys_prompt, user_prompt = build_executor_prompt( + goal="Search for Logitech mouse", + intent=None, + compact_context="167|searchbox|Search|100|1|0|-|0|", + input_text="Logitech mouse", + ) + assert 'Text to type: "Logitech mouse"' in user_prompt + + def test_no_input_line_when_none(self) -> None: + """No input text line when not provided.""" + sys_prompt, user_prompt = build_executor_prompt( + goal="Click button", + intent=None, + compact_context="123|button|Submit|100|1|0|-|0|", + input_text=None, + ) + assert "Text to type:" not in user_prompt + + def test_includes_both_intent_and_input(self) -> None: + """Both intent and input can be present.""" + sys_prompt, user_prompt = build_executor_prompt( + goal="Search for laptop", + intent="search_box", + compact_context="100|searchbox|Search|100|1|0|-|0|", + input_text="laptop", + ) + assert "Intent: search_box" in user_prompt + assert 'Text to type: "laptop"' in user_prompt + + # --------------------------------------------------------------------------- # Test normalize_plan # --------------------------------------------------------------------------- @@ -490,6 +561,82 @@ def test_custom_values(self) -> None: assert config.track_successful_urls is False +# --------------------------------------------------------------------------- +# Test ModalDismissalConfig +# --------------------------------------------------------------------------- + + +class TestModalDismissalConfig: + """Tests for ModalDismissalConfig.""" + + def test_default_values(self) -> None: + config = ModalDismissalConfig() + assert config.enabled is True + assert config.max_attempts == 2 + assert config.min_new_elements == 5 + assert "button" in config.role_filter + assert "link" in config.role_filter + + def test_default_patterns_exist(self) -> None: + config = ModalDismissalConfig() + # Should have common dismissal patterns + assert "no thanks" in config.dismiss_patterns + assert "close" in config.dismiss_patterns + assert "skip" in config.dismiss_patterns + assert "cancel" in config.dismiss_patterns + + def test_default_icons_exist(self) -> None: + config = ModalDismissalConfig() + # Should have common close icons + assert "x" in config.dismiss_icons + assert "×" in config.dismiss_icons + assert "✕" in config.dismiss_icons + + def test_can_be_disabled(self) -> None: + config = ModalDismissalConfig(enabled=False) + assert config.enabled is False + + def test_custom_patterns_for_i18n(self) -> None: + # German patterns + config_german = ModalDismissalConfig( + dismiss_patterns=( + "nein danke", + "schließen", + "abbrechen", + ), + ) + assert "nein danke" in config_german.dismiss_patterns + assert "schließen" in config_german.dismiss_patterns + assert config_german.dismiss_patterns[0] == "nein danke" + + def test_custom_icons(self) -> None: + config = ModalDismissalConfig( + dismiss_icons=("x", "×", "✖"), + ) + assert len(config.dismiss_icons) == 3 + assert "✖" in config.dismiss_icons + + def test_custom_max_attempts(self) -> None: + config = ModalDismissalConfig(max_attempts=5) + assert config.max_attempts == 5 + + def test_planner_executor_config_has_modal(self) -> None: + config = PlannerExecutorConfig() + assert hasattr(config, "modal") + assert config.modal.enabled is True + + def test_planner_executor_config_custom_modal(self) -> None: + config = PlannerExecutorConfig( + modal=ModalDismissalConfig( + enabled=True, + dismiss_patterns=("custom pattern",), + max_attempts=3, + ), + ) + assert config.modal.max_attempts == 3 + assert "custom pattern" in config.modal.dismiss_patterns + + # --------------------------------------------------------------------------- # Test PlannerExecutorConfig with new options # --------------------------------------------------------------------------- @@ -573,3 +720,634 @@ def test_optional_substeps_nested_structure(self) -> None: assert len(step.optional_substeps) == 2 assert step.optional_substeps[0].action == "SCROLL" assert step.optional_substeps[1].action == "CLICK" + + +# --------------------------------------------------------------------------- +# Test SnapshotEscalationConfig +# --------------------------------------------------------------------------- + + +class TestSnapshotEscalationConfig: + """Tests for SnapshotEscalationConfig including scroll-after-escalation.""" + + def test_default_values(self) -> None: + config = SnapshotEscalationConfig() + assert config.enabled is True + assert config.limit_base == 60 + assert config.limit_step == 30 + assert config.limit_max == 200 + + def test_scroll_after_escalation_defaults(self) -> None: + config = SnapshotEscalationConfig() + assert config.scroll_after_escalation is True + assert config.scroll_max_attempts == 3 + assert config.scroll_directions == ("down", "up") + + def test_scroll_after_escalation_can_be_disabled(self) -> None: + config = SnapshotEscalationConfig(scroll_after_escalation=False) + assert config.scroll_after_escalation is False + + def test_custom_scroll_settings(self) -> None: + config = SnapshotEscalationConfig( + scroll_after_escalation=True, + scroll_max_attempts=5, + scroll_directions=("down",), # Only scroll down + ) + assert config.scroll_max_attempts == 5 + assert config.scroll_directions == ("down",) + + def test_scroll_viewport_fraction_default(self) -> None: + config = SnapshotEscalationConfig() + assert config.scroll_viewport_fraction == 0.4 + + def test_scroll_viewport_fraction_custom(self) -> None: + config = SnapshotEscalationConfig(scroll_viewport_fraction=0.5) + assert config.scroll_viewport_fraction == 0.5 + + def test_scroll_directions_can_be_reordered(self) -> None: + # Try up first, then down + config = SnapshotEscalationConfig( + scroll_directions=("up", "down"), + ) + assert config.scroll_directions == ("up", "down") + + def test_limit_escalation_steps(self) -> None: + # Default: 60 -> 90 -> 120 -> 150 -> 180 -> 200 + config = SnapshotEscalationConfig() + limits = [] + current = config.limit_base + while current <= config.limit_max: + limits.append(current) + current = min(current + config.limit_step, config.limit_max + 1) + if current > config.limit_max and current - config.limit_step < config.limit_max: + limits.append(config.limit_max) + break + assert limits == [60, 90, 120, 150, 180, 200] or limits == [60, 90, 120, 150, 180] + + def test_custom_limit_step(self) -> None: + # Custom: 60 -> 110 -> 160 -> 200 + config = SnapshotEscalationConfig(limit_step=50) + assert config.limit_step == 50 + + +# --------------------------------------------------------------------------- +# Test PlannerExecutorConfig with SnapshotEscalationConfig +# --------------------------------------------------------------------------- + + +class TestPlannerExecutorConfigSnapshot: + """Tests for PlannerExecutorConfig snapshot escalation settings.""" + + def test_default_snapshot_config(self) -> None: + config = PlannerExecutorConfig() + assert config.snapshot is not None + assert config.snapshot.enabled is True + assert config.snapshot.scroll_after_escalation is True + + def test_custom_snapshot_config(self) -> None: + config = PlannerExecutorConfig( + snapshot=SnapshotEscalationConfig( + enabled=True, + limit_base=100, + limit_max=300, + scroll_after_escalation=True, + scroll_max_attempts=5, + ), + ) + assert config.snapshot.limit_base == 100 + assert config.snapshot.limit_max == 300 + assert config.snapshot.scroll_max_attempts == 5 + + def test_disable_snapshot_escalation(self) -> None: + config = PlannerExecutorConfig( + snapshot=SnapshotEscalationConfig(enabled=False), + ) + assert config.snapshot.enabled is False + + def test_scroll_to_find_config_exists(self) -> None: + # Verify scroll_to_find settings exist in PlannerExecutorConfig + config = PlannerExecutorConfig() + assert hasattr(config, "scroll_to_find_enabled") + assert hasattr(config, "scroll_to_find_max_scrolls") + assert hasattr(config, "scroll_to_find_directions") + + +# --------------------------------------------------------------------------- +# Test _snapshot_with_escalation scroll-after-escalation behavior (async) +# --------------------------------------------------------------------------- + + +class MockElement: + """Mock element for testing.""" + + def __init__(self, id: int, role: str = "button", text: str = ""): + self.id = id + self.role = role + self.text = text + self.name = text + + +class MockSnapshot: + """Mock snapshot for testing.""" + + def __init__(self, elements: list[MockElement], url: str = "https://example.com"): + # Ensure we have at least 3 elements to pass detect_snapshot_failure + if len(elements) < 3: + # Add filler elements to prevent "too_few_elements" detection + for i in range(3 - len(elements)): + elements.append(MockElement(100 + i, "generic", f"Filler {i}")) + self.elements = elements + self.url = url + self.title = "Test Page" + self.status = "success" # Use "success" to pass detect_snapshot_failure + self.screenshot = None + self.diagnostics = None # No diagnostics to avoid confidence checks + + +class MockRuntime: + """Mock runtime for testing scroll-after-escalation and auth boundary.""" + + def __init__( + self, + snapshots_by_scroll: dict[int, MockSnapshot] | None = None, + url: str = "https://example.com", + ): + self.scroll_count = 0 + self.scroll_directions: list[str] = [] + self.snapshots_by_scroll = snapshots_by_scroll or {} + self.default_snapshot = MockSnapshot([MockElement(1, "button", "Submit")]) + self._url = url + + async def get_url(self) -> str: + """Return the current URL for auth boundary detection.""" + return self._url + + async def get_viewport_height(self) -> int: + """Return mock viewport height.""" + return 800 # Standard viewport height + + async def snapshot( + self, + limit: int = 60, + screenshot: bool = False, + goal: str = "", + show_overlay: bool = False, + ) -> MockSnapshot: + # Return snapshot based on scroll count + return self.snapshots_by_scroll.get(self.scroll_count, self.default_snapshot) + + async def scroll(self, direction: str = "down") -> None: + self.scroll_count += 1 + self.scroll_directions.append(direction) + + async def scroll_by( + self, + dy: float, + *, + verify: bool = True, + min_delta_px: float = 50.0, + js_fallback: bool = True, + required: bool = True, + timeout_s: float = 10.0, + **kwargs: Any, + ) -> bool: + """Mock scroll_by with verification - always returns True (scroll effective).""" + direction = "down" if dy > 0 else "up" + self.scroll_count += 1 + self.scroll_directions.append(direction) + return True # Scroll always effective in tests + + async def stabilize(self) -> None: + pass + + +class MockLLMProvider: + """Mock LLM provider for testing.""" + + def generate(self, *args: Any, **kwargs: Any) -> Any: + class MockResponse: + content = '{"task": "test", "steps": []}' + + return MockResponse() + + +class MockIntentHeuristics: + """Mock intent heuristics for testing scroll-after-escalation.""" + + def __init__(self, element_map: dict[str, int] | None = None): + """ + Args: + element_map: Maps intent names to element IDs to return. + If intent not in map, returns None. + """ + self.element_map = element_map or {} + + def find_element_for_intent( + self, + intent: str, + elements: list[Any], + url: str, + goal: str, + ) -> int | None: + return self.element_map.get(intent) + + def priority_order(self) -> list[str]: + return list(self.element_map.keys()) + + +class TestSnapshotWithEscalationScrollBehavior: + """Tests for _snapshot_with_escalation scroll-after-escalation behavior.""" + + @pytest.mark.asyncio + async def test_scroll_after_escalation_disabled_no_scroll(self) -> None: + """When scroll_after_escalation is disabled, no scrolling should occur.""" + from predicate.agents.planner_executor_agent import PlannerExecutorAgent + + config = PlannerExecutorConfig( + snapshot=SnapshotEscalationConfig( + enabled=True, + scroll_after_escalation=False, # Disabled + ), + ) + agent = PlannerExecutorAgent( + planner=MockLLMProvider(), + executor=MockLLMProvider(), + config=config, + ) + + runtime = MockRuntime() + step = PlanStep( + id=1, + goal="Click Add to Cart", + action="CLICK", + intent="add_to_cart", + verify=[], + ) + + ctx = await agent._snapshot_with_escalation( + runtime, + goal=step.goal, + capture_screenshot=False, + step=step, + ) + + # No scrolling should have occurred + assert runtime.scroll_count == 0 + assert ctx.snapshot is not None + + @pytest.mark.asyncio + async def test_scroll_after_escalation_scrolls_when_element_not_found(self) -> None: + """When element not found after limit escalation, should scroll to find it.""" + from predicate.agents.planner_executor_agent import PlannerExecutorAgent + from unittest.mock import AsyncMock, patch, MagicMock + + config = PlannerExecutorConfig( + snapshot=SnapshotEscalationConfig( + enabled=False, # Disable limit escalation to test scroll directly + scroll_after_escalation=True, + scroll_max_attempts=2, + scroll_directions=("down",), + ), + verbose=True, # Enable verbose to see debug output + stabilize_enabled=False, + ) + + # Must inject intent_heuristics for scroll-after-escalation to trigger + # (scroll only happens for CLICK actions with intent and heuristics) + mock_heuristics = MockIntentHeuristics() + + agent = PlannerExecutorAgent( + planner=MockLLMProvider(), + executor=MockLLMProvider(), + config=config, + intent_heuristics=mock_heuristics, # Required for scroll-after-escalation + ) + + # Initial snapshot has no "Add to Cart" button + # After 1 scroll, "Add to Cart" appears + initial_snap = MockSnapshot([MockElement(1, "button", "Submit")]) + after_scroll_snap = MockSnapshot([ + MockElement(1, "button", "Submit"), + MockElement(2, "button", "Add to Cart"), + ]) + + runtime = MockRuntime(snapshots_by_scroll={0: initial_snap, 1: after_scroll_snap}) + + step = PlanStep( + id=1, + goal="Click Add to Cart button", + action="CLICK", + intent="add_to_cart", + verify=[], + ) + + # Mock _try_intent_heuristics to simulate finding element after scroll + call_count = 0 + + async def mock_try_heuristics(step: Any, elements: list, url: str) -> int | None: + nonlocal call_count + call_count += 1 + # First call (before scroll): element not found + # Second call (after scroll): element found + if call_count == 1: + return None + return 2 # Found element 2 after scrolling + + # Mock _format_context to avoid issues with snapshot formatting + def mock_format_context(snap: Any, goal: str) -> str: + return "mocked context" + + with patch.object(agent, "_try_intent_heuristics", side_effect=mock_try_heuristics): + with patch.object(agent, "_format_context", side_effect=mock_format_context): + ctx = await agent._snapshot_with_escalation( + runtime, + goal=step.goal, + capture_screenshot=False, + step=step, + ) + + # Should have scrolled to find the element + assert runtime.scroll_count >= 1 + assert "down" in runtime.scroll_directions + + @pytest.mark.asyncio + async def test_scroll_tries_multiple_directions(self) -> None: + """Should try all configured directions when element not found.""" + from predicate.agents.planner_executor_agent import PlannerExecutorAgent + from unittest.mock import AsyncMock, patch, MagicMock + + config = PlannerExecutorConfig( + snapshot=SnapshotEscalationConfig( + enabled=False, # Disable limit escalation + scroll_after_escalation=True, + scroll_max_attempts=2, + scroll_directions=("down", "up"), + ), + verbose=True, # Enable verbose to see debug output + stabilize_enabled=False, + ) + + # Must inject intent_heuristics for scroll-after-escalation to trigger + mock_heuristics = MockIntentHeuristics() + + agent = PlannerExecutorAgent( + planner=MockLLMProvider(), + executor=MockLLMProvider(), + config=config, + intent_heuristics=mock_heuristics, # Required for scroll-after-escalation + ) + + # Element never found - should try all directions + snap = MockSnapshot([MockElement(1, "button", "Submit")]) + runtime = MockRuntime(snapshots_by_scroll={i: snap for i in range(10)}) + + step = PlanStep( + id=1, + goal="Click non-existent button", + action="CLICK", + intent="nonexistent", + verify=[], + ) + + # Mock _try_intent_heuristics to always return None (element never found) + async def mock_try_heuristics(step: Any, elements: list, url: str) -> int | None: + return None # Element never found + + # Mock _format_context to avoid issues with snapshot formatting + def mock_format_context(snap: Any, goal: str) -> str: + return "mocked context" + + with patch.object(agent, "_try_intent_heuristics", side_effect=mock_try_heuristics): + with patch.object(agent, "_format_context", side_effect=mock_format_context): + ctx = await agent._snapshot_with_escalation( + runtime, + goal=step.goal, + capture_screenshot=False, + step=step, + ) + + # Should have tried both directions + assert "down" in runtime.scroll_directions + assert "up" in runtime.scroll_directions + # Max attempts per direction is 2, so total scrolls should be 4 + assert runtime.scroll_count == 4 + + @pytest.mark.asyncio + async def test_no_scroll_when_step_is_none(self) -> None: + """When step is None, scroll-after-escalation should not trigger.""" + from predicate.agents.planner_executor_agent import PlannerExecutorAgent + + config = PlannerExecutorConfig( + snapshot=SnapshotEscalationConfig( + enabled=False, + scroll_after_escalation=True, + ), + ) + agent = PlannerExecutorAgent( + planner=MockLLMProvider(), + executor=MockLLMProvider(), + config=config, + ) + + runtime = MockRuntime() + + ctx = await agent._snapshot_with_escalation( + runtime, + goal="Some goal", + capture_screenshot=False, + step=None, # No step provided + ) + + # No scrolling should occur without a step + assert runtime.scroll_count == 0 + + @pytest.mark.asyncio + async def test_no_scroll_for_type_and_submit_actions(self) -> None: + """Scroll-after-escalation should NOT trigger for TYPE_AND_SUBMIT actions.""" + from predicate.agents.planner_executor_agent import PlannerExecutorAgent + + config = PlannerExecutorConfig( + snapshot=SnapshotEscalationConfig( + enabled=False, + scroll_after_escalation=True, # Enabled, but should not trigger + ), + ) + + # Even with intent_heuristics, TYPE_AND_SUBMIT should not trigger scroll + mock_heuristics = MockIntentHeuristics() + + agent = PlannerExecutorAgent( + planner=MockLLMProvider(), + executor=MockLLMProvider(), + config=config, + intent_heuristics=mock_heuristics, + ) + + runtime = MockRuntime() + + step = PlanStep( + id=1, + goal="Search for Logitech mouse", + action="TYPE_AND_SUBMIT", # Not a CLICK action + intent="search", + input="Logitech mouse", + verify=[], + ) + + ctx = await agent._snapshot_with_escalation( + runtime, + goal=step.goal, + capture_screenshot=False, + step=step, + ) + + # No scrolling should occur for TYPE_AND_SUBMIT + assert runtime.scroll_count == 0 + + @pytest.mark.asyncio + async def test_no_scroll_without_intent_heuristics(self) -> None: + """Scroll-after-escalation should NOT trigger without intent_heuristics.""" + from predicate.agents.planner_executor_agent import PlannerExecutorAgent + + config = PlannerExecutorConfig( + snapshot=SnapshotEscalationConfig( + enabled=False, + scroll_after_escalation=True, # Enabled, but should not trigger + ), + ) + + # No intent_heuristics injected + agent = PlannerExecutorAgent( + planner=MockLLMProvider(), + executor=MockLLMProvider(), + config=config, + # intent_heuristics=None # Not provided + ) + + runtime = MockRuntime() + + step = PlanStep( + id=1, + goal="Click Add to Cart", + action="CLICK", + intent="add_to_cart", + verify=[], + ) + + ctx = await agent._snapshot_with_escalation( + runtime, + goal=step.goal, + capture_screenshot=False, + step=step, + ) + + # No scrolling should occur without intent_heuristics + assert runtime.scroll_count == 0 + + +# --------------------------------------------------------------------------- +# Test AuthBoundaryConfig +# --------------------------------------------------------------------------- + + +class TestAuthBoundaryConfig: + """Tests for AuthBoundaryConfig.""" + + def test_default_values(self) -> None: + config = AuthBoundaryConfig() + assert config.enabled is True + assert config.stop_on_auth is True + assert "/signin" in config.url_patterns + assert "/ap/signin" in config.url_patterns + assert "/ax/claim" in config.url_patterns # Amazon CAPTCHA + + def test_default_url_patterns_include_common_patterns(self) -> None: + config = AuthBoundaryConfig() + expected_patterns = [ + "/signin", + "/sign-in", + "/login", + "/log-in", + "/auth", + "/authenticate", + "/ap/signin", # Amazon + "/ap/register", # Amazon + "/ax/claim", # Amazon CAPTCHA + "/account/login", + "/accounts/login", + "/user/login", + ] + for pattern in expected_patterns: + assert pattern in config.url_patterns, f"Missing pattern: {pattern}" + + def test_can_be_disabled(self) -> None: + config = AuthBoundaryConfig(enabled=False) + assert config.enabled is False + + def test_stop_on_auth_can_be_disabled(self) -> None: + config = AuthBoundaryConfig(stop_on_auth=False) + assert config.stop_on_auth is False + + def test_custom_url_patterns(self) -> None: + config = AuthBoundaryConfig( + url_patterns=("/custom/login", "/my-signin"), + ) + assert config.url_patterns == ("/custom/login", "/my-signin") + + def test_custom_auth_success_message(self) -> None: + config = AuthBoundaryConfig( + auth_success_message="Custom: Login required", + ) + assert config.auth_success_message == "Custom: Login required" + + def test_planner_executor_config_has_auth_boundary(self) -> None: + config = PlannerExecutorConfig() + assert config.auth_boundary is not None + assert config.auth_boundary.enabled is True + + def test_planner_executor_config_custom_auth_boundary(self) -> None: + config = PlannerExecutorConfig( + auth_boundary=AuthBoundaryConfig( + enabled=True, + url_patterns=("/custom-signin",), + stop_on_auth=True, + ), + ) + assert config.auth_boundary.url_patterns == ("/custom-signin",) + + +# --------------------------------------------------------------------------- +# Test Modal Dismissal After Successful CLICK +# --------------------------------------------------------------------------- + + +class TestModalDismissalAfterSuccessfulClick: + """Tests for modal dismissal when verification passes after CLICK.""" + + def test_modal_dismissal_config_min_new_elements_default(self) -> None: + """Default min_new_elements should be 5 for DOM change detection.""" + config = ModalDismissalConfig() + assert config.min_new_elements == 5 + + def test_modal_enabled_by_default_in_planner_executor_config(self) -> None: + """Modal dismissal should be enabled by default.""" + config = PlannerExecutorConfig() + assert config.modal.enabled is True + assert config.modal.min_new_elements == 5 + + def test_modal_dismissal_patterns_include_no_thanks(self) -> None: + """'no thanks' should be in default patterns for drawer dismissal.""" + config = ModalDismissalConfig() + # This is the pattern that dismisses Amazon's product protection drawer + assert "no thanks" in config.dismiss_patterns + + def test_modal_config_has_required_fields_for_drawer_dismissal(self) -> None: + """Config should have all fields needed for drawer dismissal logic.""" + config = ModalDismissalConfig() + # These are all used in _attempt_modal_dismissal + assert hasattr(config, "enabled") + assert hasattr(config, "dismiss_patterns") + assert hasattr(config, "dismiss_icons") + assert hasattr(config, "role_filter") + assert hasattr(config, "max_attempts") + assert hasattr(config, "min_new_elements")